# Non-Linear Transfer Maps

This notebook computes transfer maps via the Lie Operator.

The Lie operator is defined as:
$$                                                                              
  \colon f \colon = \sum^n_{i=1} \left(\frac{\partial f}{\partial x_i} \frac{\partial}{\partial p_i}
                             - \frac{\partial f}{\partial p_i} \frac{\partial}{\partial x_i}
                        \right)                                                   
$$        

In [452]:
import sympy
from sympy import symbols, Function, sqrt, cos, print_latex, diff, Matrix, Rational, Add, re, I
from IPython.display import Markdown, Latex
from functools import partial
from math import factorial

def poisson_bracket(planes, momenta, h1, h2):
    '''
    Poisson bracket [h1, h2] over the given canonical pairs:

        [h1, h2] = sum_i ( dh1/dq_i * dh2/dp_i - dh1/dp_i * dh2/dq_i )

    parameter:
        - planes: indexable collection of coordinate symbols, eg. (x, y)
        - momenta: indexable collection of the matching momenta, eg. (px, py);
                   read at the same positions as `planes`
        - h1 and h2: the two expressions the bracket is applied on

    Returns 0 when `planes` is empty.
    '''
    total = 0
    for idx, coord in enumerate(planes):
        # Momentum conjugate to `coord` sits at the same position
        mom = momenta[idx]

        direct = diff(h1, coord) * diff(h2, mom)
        crossed = diff(h1, mom) * diff(h2, coord)
        total += direct - crossed
    return total

## Creating the transfer map

In [646]:
def transfer_map(elements, var, order):
    """
    Returns the result of a map of hamiltonians acting on a variable (like x, px, etc)

    Every element f acts through the truncated Lie exponential
        e^{:f:} g = g + [f, g] + 1/2! [f, [f, g]] + ... + 1/order! :f:^order g
    where the bracket is the module-level `pb`.

    parameter:
        - elements: sequence of hamiltonians; they are applied to `var` starting
                    from the last one and finishing with the first one
        - var: the expression the map acts on
        - order: highest power of :f: kept for each element; with 0 the
                 variable is returned untouched

    The result is not expanded.
    """

    # first element (g)
    g = var

    # Iterate on the elements from last to first; the output of one is the input of the next
    for f in reversed(elements):
        # nested holds :f:^k g for the g that ENTERED this element. It must not
        # be restarted from the partial sum, otherwise higher-order brackets
        # get applied to the lower-order terms already accumulated.
        nested = g
        series = g

        for k in range(1, order + 1):
            # One more bracket on top of the previous power: :f:^k g = [f, :f:^(k-1) g]
            nested = pb(f, nested)

            # Add the k-th term of the exponential series
            series += Rational(1, factorial(k)) * nested

        g = series

    return g


def get_terms(expr, term):
    """
    Helper function to get the additive terms of an expression that contain `term`
    Example (with sympy symbols a, b, c, x):
        get_terms(a*x + b - 5*c*x**2, x**2) returns -5*c*x**2

    The result is the sum of the matching terms, or 0 if none matches.
    """
    # Add.make_args gives the summands of a sum and a 1-tuple for anything else.
    # Plain `expr.args` would split a single product into its factors and drop
    # the coefficient of a one-term expression.
    return Add(*[summand for summand in Add.make_args(expr) if summand.has(term)])


def get_like_expr(expr, multipole, var):
    """
    Returns the terms of `expr` that look like a multipole differentiated by "var"

    The multipole is differentiated, its numerical/symbolic factors are dropped
    and only the bare monomials x**a * y**b (a + b > 0) are kept. Every term of
    `expr` containing one of those monomials is then collected.
    Example for x:
        an octupole (x^4 - 6x^2y^2 + y^4) selects the terms in x^3 and xy^2

    Uses the module-level symbols `x` and `y`.
    Returns 0 when the derivative has no non-constant monomial.
    """

    derivative = diff(multipole, var)

    # Strip the factors: keep one bare monomial per non-constant term
    monomials = [x**a * y**b
                 for a, b in sympy.Poly(derivative, x, y).monoms()
                 if a + b > 0]

    # Iterate over the list itself. Summing the monomials and reading `.args`
    # back fails for zero monomials and splits a lone monomial into factors.
    return sum(get_terms(expr, monomial) for monomial in monomials)
    

## Defining some Hamiltonians

In [647]:
# Phase-space symbols: two coordinates and their momenta, all declared real
x, px, y, py = symbols("x p_x y p_y", real=True)

# Poisson bracket with the (x, y) / (px, py) pairs already bound: pb(h1, h2)
pb = partial(poisson_bracket, [x, y], [px, py])

# Multipole strengths K_1 ... K_7, indexed by order (K[1] = dipole)
K = {order: symbols(f"K_{order}", real=True) for order in range(1, 8)}

def hamiltonian(n):
    """
    Normal (no skew) multipole hamiltonian of order n:
        K[n] / n! * Re((x + i*y)**n)
    Uses the module-level symbols x, y and the strength table K.
    """
    scale = Rational(1, factorial(n))
    strength = K[n]
    field = re((x + I * y) ** n)
    return scale * strength * field
    
# One hamiltonian per magnet type, orders 1 (dipole) through 6 (dodecapole)
dipole, quadrupole, sextupole, octupole, decapole, dodecapole = (
    hamiltonian(n) for n in range(1, 7)
)

## Defining elements

In [676]:
# Length of drift
L = symbols("L")

# K3 strength for each sextupole
# Raw strings: "\," is a LaTeX thin space, not a Python escape sequence
K3_1 = symbols(r"K_{3\,h1}", real=True)
K3_2 = symbols(r"K_{3\,h2}", real=True)

# Hamiltonians: two sextupoles with their own strength, separated by a drift
h1 = sextupole.subs(K[3], K3_1)
h2 = sextupole.subs(K[3], K3_2)
d1 = -Rational(1,2) * L * (px**2 + py**2)  # drift

display(h1)
display(h2)
display(d1)

# List of elements composing the map; transfer_map applies them from last to first
elements = [h1, d1, h2]


K_{3,h1}*(x**3 - 3*x*y**2)/6

K_{3,h2}*(x**3 - 3*x*y**2)/6

-L*(p_x**2 + p_y**2)/2

## Final Transfer Map

In [677]:
# Apply the transfer map to px
# Order 0 => no effect on any coordinate
display(Markdown("### Map to the order 0:"))
# Named per order instead of `map`, which would shadow the Python builtin
px_map_order0 = transfer_map(elements, px, order=0).expand()
display(px_map_order0)

### Map to the order 0:

p_x

In [681]:
# Apply the transfer map to px
# Order 1
display(Markdown("### Map to the order 1:"))
# Named per order instead of `map`, which would shadow the Python builtin
px_map_order1 = transfer_map(elements, px, order=1).expand()
display(px_map_order1)

# Keep only the terms matching the monomials of d(octupole)/dx
display(Markdown("#### Octupolar component:"))
display(get_like_expr(px_map_order1, octupole, x))

### Map to the order 1:

K_{3,h1}*K_{3,h2}*L*x**3/2 + K_{3,h1}*K_{3,h2}*L*x*y**2/2 + K_{3,h1}*x**2/2 - K_{3,h1}*y**2/2 + K_{3,h2}*L*p_x*x - K_{3,h2}*L*p_y*y + K_{3,h2}*x**2/2 - K_{3,h2}*y**2/2 + p_x

#### Octupolar component:

K_{3,h1}*K_{3,h2}*L*x**3/2 + K_{3,h1}*K_{3,h2}*L*x*y**2/2

In [682]:
# Apply the transfer map to px
# Order 2
display(Markdown("### Map to the order 2:"))
# Named per order instead of `map`, which would shadow the Python builtin
px_map_order2 = transfer_map(elements, px, order=2).expand()
display(px_map_order2)

# Look at the decapolar like terms: d(x^5 - 10x^3y^2 + 5xy^4) / dx
display(Markdown("#### Decapolar component:"))
display(get_like_expr(px_map_order2, decapole, x))

### Map to the order 2:

K_{3,h1}**2*K_{3,h2}*L**2*x**4/8 - 3*K_{3,h1}**2*K_{3,h2}*L**2*x**2*y**2/4 + K_{3,h1}**2*K_{3,h2}*L**2*y**4/8 + K_{3,h1}*K_{3,h2}*L**2*p_x*x**2/2 - K_{3,h1}*K_{3,h2}*L**2*p_x*y**2/2 + K_{3,h1}*K_{3,h2}*L**2*p_y*x*y + K_{3,h1}*K_{3,h2}*L*x**3/2 + K_{3,h1}*K_{3,h2}*L*x*y**2/2 + K_{3,h1}*x**2/2 - K_{3,h1}*y**2/2 + K_{3,h2}*L**2*p_x**2/2 - K_{3,h2}*L**2*p_y**2/2 + K_{3,h2}*L*p_x*x - K_{3,h2}*L*p_y*y + K_{3,h2}*x**2/2 - K_{3,h2}*y**2/2 + p_x

#### Decapolar component:

K_{3,h1}**2*K_{3,h2}*L**2*x**4/8 - 3*K_{3,h1}**2*K_{3,h2}*L**2*x**2*y**2/4 + K_{3,h1}**2*K_{3,h2}*L**2*y**4/8

In [680]:
# Apply the transfer map to px
# Order 3 => no difference from order 2
display(Markdown("### Map to the order 3:"))
truncated_at_2 = transfer_map(elements, px, order=2)
truncated_at_3 = transfer_map(elements, px, order=3)
# `==` between sympy expressions compares their structure, not their value;
# expanding the difference and comparing with 0 checks mathematical equality.
print("Up to order 2 == Up to order 3: ", (truncated_at_2 - truncated_at_3).expand() == 0)

### Map to the order 3:

Up to order 2 == Up to order 3:  True


## Tests

In [649]:
# Manual transfer map of the elements to order 1
# (elements are applied from the last one, h2, to the first one, h1)
e1 = px + pb(h2, px)
e2 = e1 + pb(d1, e1)
e3 = e2 + pb(h1, e2)

# Must be the same as transfer_map(elements, px, order=1): checked, not just eyeballed
assert (e3 - transfer_map(elements, px, order=1)).expand() == 0
display(e3.expand()) 

K_{3,h1}*K_{3,h2}*L*x**3/2 + K_{3,h1}*K_{3,h2}*L*x*y**2/2 + K_{3,h1}*x**2/2 - K_{3,h1}*y**2/2 + K_{3,h2}*L*p_x*x - K_{3,h2}*L*p_y*y + K_{3,h2}*x**2/2 - K_{3,h2}*y**2/2 + p_x

In [653]:
# Bracket of the generic sextupole with py; with this bracket convention it reduces to d(sextupole)/dy
pb(sextupole, py)

-K_3*x*y

In [667]:
# Second-order term of h1 acting on px: h1 has no momentum dependence and
# [h1, px] only contains x and y, so the nested bracket vanishes
pb(h1, pb(h1, px))

0

In [664]:
# First-order map of the sextupole h1 alone acting on px: px + [h1, px]
px + pb(h1, px)

K_{3,h1}*(3*x**2 - 3*y**2)/6 + p_x