# Drift vs Advanced Drift

In [None]:
# Automatically reload edited modules (e.g. pmd_beamphysics) before each cell runs
%load_ext autoreload
%autoreload 2

In [None]:
from pmd_beamphysics.wavefront.wavefront import Wavefront
from pmd_beamphysics.wavefront.propagators import (
    drift_wavefront,
    drift_wavefront_advanced,
)
from pmd_beamphysics.wavefront.gaussian import add_gaussian
from pmd_beamphysics.units import Z0, c_light
import numpy as np

import matplotlib.pyplot as plt

# Gaussian pulse 

In [None]:
# Grid spacing and carrier wavelength of the (initially empty) field
_grid_params = dict(
    dx=5.999999999999999e-06,
    dy=5.999999999999999e-06,
    dz=1.46724e-09,
    wavelength=1.2227e-10,
)
# Empty field on a 101 x 101 x 1368 grid (axis order presumably x, y, z — confirm)
W0 = Wavefront(Ex=np.zeros((101, 101, 1368)), **_grid_params)

# Gaussian waist parameter and the corresponding Rayleigh range pi * w0^2 / lambda
w0 = 50e-6
zR = np.pi * w0 ** 2 / W0.wavelength

# Work on a copy so W0 stays empty; add_gaussian is called for its effect on W
# (its return value is not used). energy presumably in J — confirm.
W = W0.copy()
add_gaussian(W, z=0, w0=w0, energy=1.2345e-6)

W.plot()

## Check the total energy of the initial wavefront `W`

In [None]:
# Total energy of the initial pulse; should match the energy=1.2345e-6 passed to add_gaussian
W.energy

In [None]:
# Optional: uncomment to export the initial field to an HDF5 file
# (presumably in Genesis4 field format, per the method name — confirm).
# W.write_genesis4('field_test_gaussian.h5')

## Propagate 420 m with the standard drift

In [None]:
# Standard drift of the initial pulse W over 420 m
drift_distance = 420
W1 = drift_wavefront(W, drift_distance)
W1.plot()

## Propagate 420 m with the advanced propagator

In [None]:
# Same 420 m drift with the advanced propagator. Rcurv=200 is passed through
# unchanged (presumably a wavefront curvature radius in m — confirm in the docs).
advanced_distance = 420
W2 = drift_wavefront_advanced(W, advanced_distance, Rcurv=200)
W2.plot()

In [None]:
# Energy after the advanced drift; compare with W.energy of the input pulse
W2.energy

## Check energy conservation

In [None]:
# Energy conservation check: free-space propagation ideally preserves energy,
# so both propagated pulses should report (nearly) the input energy.
# Displayed as (input W, standard drift W1, advanced drift W2).
W.energy, W1.energy, W2.energy

## Phase errors: drift vs advanced drift

In [None]:
# Compare the transverse phase of the two propagated fields. For each (x, y)
# point the wrapped phase np.angle(Ex) is summed over axis 2, then a line-out
# along x (axis 0, plotted against xvec) is taken through the central y index.
fig, ax = plt.subplots(figsize=(12, 3))

# Central index along axis 1: 50 for a 101-point grid (a hard-coded 51 is one
# pixel off-axis). Computed per field in case the two grids differ in size.
iy1 = W1.Ex.shape[1] // 2
iy2 = W2.Ex.shape[1] // 2
phase1 = np.sum(np.angle(W1.Ex), axis=2)[:, iy1]
phase2 = np.sum(np.angle(W2.Ex), axis=2)[:, iy2]

# xvec is in m (same units as dx); shown in µm like the other plots
ax.plot(1e6 * W1.xvec, phase1, label="drift")
ax.plot(1e6 * W2.xvec, phase2, label="advanced drift")
ax.set_xlabel(r"$x$ (µm)")
ax.set_ylabel("summed phase (rad)")
ax.legend()

## Check Gaussian beam propagation: rms size versus distance

In [None]:
%%time
Zlist = np.linspace(0, 100, 20)
Wlist = [drift_wavefront(W, z) for z in Zlist]

Wlist2 = [drift_wavefront_advanced(W, z, Rcurv=40) for z in Zlist]

sizes = np.array([w.sigma_x for w in Wlist])
sizes2 = np.array([w.sigma_x for w in Wlist2])

In [None]:
# Energy of every advanced-drift result, one entry per distance in Zlist
[wavefront.energy for wavefront in Wlist2]

In [None]:
# Cross-check the .energy attribute of one advanced-drift result by recomputing
# it from the field: sum(|Ex|^2) * dx * dy * dz / (2 * Z0 * c).
# The arithmetic order matches the original expression exactly.
W_check = Wlist2[2]
energy_from_field = (
    np.sum(np.abs(W_check.Ex) ** 2) * W_check.dx * W_check.dy * W_check.dz
    / (2 * Z0 * c_light)
)
# Display (recomputed, attribute) side by side so they can be compared
energy_from_field, W_check.energy

In [None]:
# Analytic Gaussian-beam reference: the rms size grows from its initial value
# by sqrt(1 + (z / zR)^2), assuming the waist sits at z = 0.
sigma_x0 = W.sigma_x
beam_growth = np.sqrt(1 + (Zlist / zR) ** 2)
expected_w = sigma_x0 * beam_growth

In [None]:
# Measured rms size for both propagators versus the analytic expectation (µm)
fig, ax = plt.subplots()
ax.plot(Zlist, 1e6 * expected_w, label="expected")
for size_um, curve_label in ((1e6 * sizes, "drift"), (1e6 * sizes2, "advanced drift")):
    ax.plot(Zlist, size_um, "--", label=curve_label)
ax.set_xlabel(r"$z$ (m)")
ax.set_ylabel(r"$\sigma_x$ (µm)")
ax.legend()

# K-space (angular) representation

In [None]:
# Transform the initial wavefront to its k-space (angular) representation and plot it
Wk = W.to_kspace()
Wk.plot()

In [None]:
# rms angular spread in x (from each wavefront's k-space form), per distance:
# first for the standard drift results, then for the advanced drift results
sigma_thetax, sigma_thetax2 = (
    np.array([float(wf.to_kspace().sigma_thetax) for wf in group])
    for group in (Wlist, Wlist2)
)

In [None]:
# rms divergence versus distance. In free space the divergence does not change
# with z. For the Gaussian model behind expected_w, it equals sigma_x0 / zR:
# the large-z slope of sigma_x0 * sqrt(1 + (z / zR)^2).
# NOTE(review): assumes sigma_thetax is the rms angle (rad) of the intensity
# distribution, comparable to sigma_x0 / zR — confirm against to_kspace().
expected_thetax = sigma_x0 / zR

fig, ax = plt.subplots()
ax.axhline(1e6 * expected_thetax, color="k", lw=1, label="expected")
ax.plot(Zlist, 1e6 * sigma_thetax, "--", label="drift")
# z = 0 is omitted for the advanced drift
ax.plot(Zlist[1:], 1e6 * sigma_thetax2[1:], "o-", label="advanced drift")
ax.set_xlabel(r"$z$ (m)")

ax.set_ylabel(r"$\sigma_\theta$ (µrad)")

ax.set_ylim(0, 1)
ax.legend()