This code computes the probability of Schwinger pair production by an arbitrary $A_\mu(t,x,y,z)$, assuming only that it vanishes as any of the coordinates goes to infinity, using the method in:

*Momentum correlation in pair production by spacetime dependent fields from scattered wave functions*

arXiv:2509.17770

G. Torgrimsson

# Import

diffrax provides the JAX-based ODE solver used to integrate the Dirac equation (on a GPU when one is available).

In [1]:
# this is needed when running this code on Colab
# %%capture output
# %pip install diffrax

In [2]:
import jax
import jax.numpy as np
from diffrax import diffeqsolve, ODETerm, Dopri5, Tsit5, PIDController, SaveAt

import matplotlib.pyplot as plt
from scipy.integrate import simpson
from scipy.interpolate import make_interp_spline, RegularGridInterpolator
from tqdm import tqdm

In [3]:
# Floating-point precision of the simulation: 32 or 64 bits.
precision=64

if precision==32:
    # Undo x64 mode in case a previous run of this cell enabled it in the same kernel.
    jax.config.update("jax_enable_x64", False)
    realType=np.float32
    compType=np.complex64

elif precision==64:
    # JAX defaults to 32-bit types unless x64 mode is enabled explicitly.
    jax.config.update("jax_enable_x64", True)
    realType=np.float64
    compType=np.complex128

else:
    raise ValueError(f"precision must be 32 or 64, got {precision!r}")

In [4]:
# def toMathematica(arr):
#   mathematica_str=str(arr.tolist()).replace('[','{').replace(']','}').replace("e","*10^").replace("j","I").replace("(","").replace(")","")
#   print(mathematica_str)

In [5]:
@jax.jit
def chop(x, tol=1e-10):
    """Replace entries of ``x`` with magnitude below ``tol`` by zero.

    Used on the field arrays and on the ODE right-hand side to suppress round-off noise.
    """
    is_negligible = np.abs(x) < tol
    return np.where(is_negligible, 0.0, x)

In [6]:
@jax.jit
def simpson_jax(y, dx=1.0):
    """Composite Simpson integral of ``y`` along its last axis for uniform spacing ``dx``.

    Odd sample counts use Simpson's 1/3 rule. Even counts use the 1/3 rule on the
    first n-3 samples and Simpson's 3/8 rule on the last four samples (for n == 4 only
    the 3/8 rule remains). n == 2 falls back to the trapezoidal rule, and n <= 1 (a
    zero-width interval) integrates to 0.

    Parameters
    ----------
    y : array; the last axis is integrated, leading axes are kept.
    dx : uniform grid spacing.

    Returns
    -------
    Array with the last axis of ``y`` removed.
    """
    n = y.shape[-1]  # static under jit, so Python branching on it is fine

    def _simpson13(v):
        # Simpson's 1/3 rule; needs an odd number (>= 3) of samples along axis -1.
        return (dx/3) * (
            v[..., 0]
            + 4 * np.sum(v[..., 1:-1:2], axis=-1)
            + 2 * np.sum(v[..., 2:-2:2], axis=-1)
            + v[..., -1]
        )

    if n <= 1:
        return np.zeros_like(np.sum(y, axis=-1) * dx)
    if n == 2:
        return (dx/2) * (y[..., 0] + y[..., 1])
    if n % 2 == 1:
        return _simpson13(y)

    # Even n >= 4: the 3/8 rule covers samples n-4..n-1, and the 1/3 rule covers
    # 0..n-4 (sharing sample n-4). For n == 4 that first piece has zero width.
    tail = (3*dx/8) * (
        y[..., -4]
        + 3 * y[..., -3]
        + 3 * y[..., -2]
        + y[..., -1]
    )
    main = _simpson13(y[..., :-3]) if n > 4 else 0.0
    return main + tail

In [7]:
from datetime import datetime
import time
from IPython import get_ipython

# Initial value so post_run_cell has a start time even before pre_run_cell first fires.
start_time = time.perf_counter()

In [8]:
def pre_run_cell(info):
    """IPython ``pre_run_cell`` hook: remember when the cell started."""
    global start_time
    start_time = time.perf_counter()

def post_run_cell(result):
    """IPython ``post_run_cell`` hook: print the current time and the cell's elapsed time."""
    end_time = time.perf_counter()
    elapsed = end_time - start_time
    print(f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} | elapsed time: {elapsed:.3f} s")

ip = get_ipython()
if ip is not None:
    # Every re-run of this cell creates new function objects. Without this cleanup the
    # previous hooks stay registered and each cell prints its timing several times.
    for event, callback in globals().get('_timingHooks', []):
        try:
            ip.events.unregister(event, callback)
        except ValueError:
            pass  # already removed
    _timingHooks = [('pre_run_cell', pre_run_cell), ('post_run_cell', post_run_cell)]
    for event, callback in _timingHooks:
        ip.events.register(event, callback)

2025-11-23 20:01:17 | elapsed time: 0.008 s


# Canonical example

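The field and the scattered wave function should be contained in $-x_b<x,y,z<x_b$. $n_x$ is the number of grid points in each of the $x$, $y$ and $z$ directions, and $n_k$ ($\geq n_x$) is the number of points in Fourier space. The unpadded solver below requires $n_k=n_x$; $n_k>n_x$ (zero padding) is handled in the padding section at the end.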

$-D_0\psi=(\gamma^0\gamma^k D_k+i\gamma^0)\psi$, where $D_\mu=\partial_\mu+iA_\mu$, so
$-\partial_t\psi=(iA_0+i\beta+\alpha^k\partial_k+iA_k\alpha^k)\psi$

In [9]:
# number of spatial dimensions (1, 2 or 3); run only the xDim cell for the case of interest
xDim=2

2025-11-23 20:01:23 | elapsed time: 0.000 s


In [9]:
# alternative to the cell above: three spatial dimensions (3+1)
xDim=3

2025-11-23 18:35:00 | elapsed time: 0.000 s


In [10]:
# simple gauge: only the scalar potential A0 is nonzero

gamma=1
E0=1/3
omega=gamma*E0
kappa=gamma*E0

def A0(t,x,y,z):
  """Scalar potential (E0/kappa)*exp(-(omega t)^2 - kappa^2 r^2)."""
  exponent=-(omega*t)**2-kappa*kappa*(x**2+y**2+z**2)
  return (E0/kappa)*np.exp(exponent)

def _zeroComponent(t,x,y,z):
  """Vector-potential component that vanishes identically in this gauge."""
  return 0

A1=_zeroComponent
A2=_zeroComponent
A3=_zeroComponent

2025-11-23 20:01:27 | elapsed time: 0.000 s


In [11]:
# spatial box [-xb, xb] with nx points per axis; nk=nx means no zero padding in k space
xb, nx = 25, 128
nk = nx
# the scattered wave is integrated from tout back to tin
tin, tout = -10, 10

2025-11-23 20:01:29 | elapsed time: 0.000 s


In [None]:
# xb=25; nx=201; nk=303
# tin=-12; tout=10

In [None]:
# another gauge for checking gauge invariance:
# A_mu -> A_mu + d_mu chi with chi = -(E0/(2*kappa))*exp(-(omega*t)**2 - kappa**2*r**2)

gamma=1
E0=1/3
omega=gamma*E0
kappa=gamma*E0

def _gaussian(t,x,y,z):
  """Common envelope exp(-(omega t)^2 - kappa^2 r^2) of all four components."""
  return np.exp(-(omega*t)**2-kappa*kappa*(x**2+y**2+z**2))

def A0(t,x,y,z):
  return (1+omega*omega*t)*(E0/kappa)*_gaussian(t,x,y,z)

def A1(t,x,y,z):
  return (kappa*kappa*x)*(E0/kappa)*_gaussian(t,x,y,z)

def A2(t,x,y,z):
  return (kappa*kappa*y)*(E0/kappa)*_gaussian(t,x,y,z)

def A3(t,x,y,z):
  return (kappa*kappa*z)*(E0/kappa)*_gaussian(t,x,y,z)

2025-11-23 18:58:05 | elapsed time: 0.001 s


# 4D spinors, x basis

In [12]:
#4D spin up and down along x

gDim=4

g0=np.array([[0,0,1,0],[0,0,0,1],[1,0,0,0],[0,1,0,0]])
g3=np.array([[0,0,0,1],[0,0,1,0],[0,-1,0,0],[-1,0,0,0]])
g2=np.array([[0,0,0,-1j],[0,0,1j,0],[0,1j,0,0],[-1j,0,0,0]])
g1=np.array([[0,0,1,0],[0,0,0,-1],[-1,0,0,0],[0,1,0,0]])
id=np.array([[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]])

g=[g0,g1,g2,g3]
eta=np.array([[1,0,0,0],[0,-1,0,0],[0,0,-1,0],[0,0,0,-1]])

alpha1=g0@g1; alpha2=g0@g2; alpha3=g0@g3; beta=g0;

def _xBasisEntry(P1,P2,P3,a,b,sign):
  """Entry [a,b] of (sign*(P0 g0 + P1 g1 + P2 g2 + P3 g3) + 1) / sqrt(2 P0 (P0 + P1))."""
  P0=np.sqrt(1+P1*P1+P2*P2+P3*P3)
  slash=P0*g0[a,b]+P1*g1[a,b]+P2*g2[a,b]+P3*g3[a,b]
  numerator=(slash if sign>0 else -slash)+id[a,b]
  return numerator/np.sqrt(2*P0*(P0+P1))

def uab(P1,P2,P3,a,b):
  """Component a of the u spinor with spin index b (b=1,2), spin basis along x."""
  return _xBasisEntry(P1,P2,P3,a,b,+1)

def vab(P1,P2,P3,a,b):
  """Component a of the v spinor with spin index b (b=1,2), spin basis along x."""
  return _xBasisEntry(P1,P2,P3,a,b,-1)

2025-11-23 20:01:43 | elapsed time: 0.633 s


# 4D spinors, z basis

In [None]:
#4D spin up and down along z

gDim=4

g0=np.array([[0,0,1,0],[0,0,0,1],[1,0,0,0],[0,1,0,0]])
g1=np.array([[0,0,0,1],[0,0,1,0],[0,-1,0,0],[-1,0,0,0]])
g2=np.array([[0,0,0,-1j],[0,0,1j,0],[0,1j,0,0],[-1j,0,0,0]])
g3=np.array([[0,0,1,0],[0,0,0,-1],[-1,0,0,0],[0,1,0,0]])
id=np.array([[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]])

g=[g0,g1,g2,g3]
eta=np.array([[1,0,0,0],[0,-1,0,0],[0,0,-1,0],[0,0,0,-1]])

alpha1=g0@g1; alpha2=g0@g2; alpha3=g0@g3; beta=g0;

def _zBasisEntry(P1,P2,P3,a,b,sign):
  """Entry [a,b] of (sign*(P0 g0 + P1 g1 + P2 g2 + P3 g3) + 1) / sqrt(2 P0 (P0 + P3))."""
  P0=np.sqrt(1+P1*P1+P2*P2+P3*P3)
  slash=P0*g0[a,b]+P1*g1[a,b]+P2*g2[a,b]+P3*g3[a,b]
  numerator=(slash if sign>0 else -slash)+id[a,b]
  return numerator/np.sqrt(2*P0*(P0+P3))

def uab(P1,P2,P3,a,b):
  """Component a of the u spinor with spin index b (b=1,2), spin basis along z."""
  return _zBasisEntry(P1,P2,P3,a,b,+1)

def vab(P1,P2,P3,a,b):
  """Component a of the v spinor with spin index b (b=1,2), spin basis along z."""
  return _zBasisEntry(P1,P2,P3,a,b,-1)

# 2D spinors

In [None]:
#2D: two-component spinors (used for the 1+1 and 2+1 cases)

gDim=2

alpha1=np.array([[0,1],[1,0]])
alpha2=np.array([[0,-1j],[1j,0]])
alpha3=np.array([[0,0],[0,0]])
beta=np.array([[1,0],[0,-1]])
id=np.array([[1,0],[0,1]])

def uab(P1,P2,P3,a,b):
  """Component a (0 or 1) of the two-component u spinor; P3 and the spin index b are unused."""
  P0=np.sqrt(1+P1*P1+P2*P2)
  if a==0:
    return 0.j+(1+P0)/np.sqrt(2*P0*(1+P0))
  elif a==1:
    return (-P1-1j*P2)/np.sqrt(2*P0*(1+P0))
  raise ValueError(f"spinor component index must be 0 or 1, got {a!r}")

def vab(P1,P2,P3,a,b):
  """Component a (0 or 1) of the two-component v spinor; P3 and the spin index b are unused.

  With p^2 = P1^2 + P2^2 and P0 - 1 = p^2/(P0 + 1), the components
      a=0: (-P1 + i P2) sqrt((P0-1)/(2 P0 p^2)) = (-P1 + i P2)/sqrt(2 P0 (P0+1))
      a=1: sqrt(p^2/(2 P0 (P0-1)))            = sqrt((P0+1)/(2 P0))
  are evaluated in the right-hand forms. These avoid the cancellation in P0-1 for
  small p and the 0/0 at p=0, where they give v = (0, 1).
  """
  P0=np.sqrt(1+P1*P1+P2*P2)
  if a==0:
    return (-P1+1j*P2)/np.sqrt(2*P0*(P0+1))
  elif a==1:
    return 0.j+np.sqrt((P0+1)/(2*P0))
  raise ValueError(f"spinor component index must be 0 or 1, got {a!r}")

2025-11-23 18:51:29 | elapsed time: 0.032 s


# Grid, ODE, projection and integral definitions

In [13]:
# Unpadded setup: the FFT derivative in makeDer and the uFFT/vFFT projections multiply
# nx-point transforms by nk-point k grids, so both sizes must agree here.
if nk!=nx:
  raise ValueError(f"the unpadded solver needs nk == nx (got nk={nk}, nx={nx}); use the padding section for nk > nx")
if xDim not in (1,2,3):
  raise ValueError(f"xDim must be 1, 2 or 3, got {xDim!r}")

# Uniform spatial grid on [-xb, xb] and the matching FFT wave numbers.
x0=-xb; x1=xb; dx=(x1-x0)/(nx-1)
xGrid=np.linspace(x0,x1,nx)
dk=2*np.pi/(nk*dx); kMax=dk*(nk-1)/2
kGrid=2*np.pi*np.fft.fftfreq(nk,dx)  # FFT ordering: 0, dk, 2dk, ..., -dk
nH=round((nx-1)/2)

# Scattered wave at t=tout (zero), flattened as [real parts, imaginary parts].
psiOut=np.zeros(2*gDim*(nx**xDim),dtype=realType)

# Sparse meshgrids; coordinates and momenta along absent dimensions are set to 0.
if xDim==1:
  X=xGrid; Y=0; Z=0;
  K1=kGrid; K2=0; K3=0;
elif xDim==2:
  X,Y=np.meshgrid(xGrid,xGrid,indexing='ij',sparse=True)
  Z=0
  K1,K2=np.meshgrid(kGrid,kGrid,indexing='ij',sparse=True)
  K3=0
else:
  X,Y,Z=np.meshgrid(xGrid,xGrid,xGrid,indexing='ij',sparse=True)
  K1,K2,K3=np.meshgrid(kGrid,kGrid,kGrid,indexing='ij',sparse=True)

Ushape=[gDim]+xDim*[nx]

# prefactor applied to the 2+1 results
pre2D=2/((2*np.pi)**5)

2025-11-23 20:01:54 | elapsed time: 0.401 s


The spatial derivatives, $\partial_i$, are calculated by Fourier transforming, multiplying by $-ik_i$, and then Fourier transforming back to $x^i$.
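dPsi gives the right-hand side of the Dirac equation for the scattered wave, $\partial_t\psi_\text{scat.}=dPsi$. The overall minus sign comes from writing the equation as $-\partial_t\psi=(\dots)\psi$. The integration runs backwards in time, from $t_{\rm out}$ to $t_{\rm in}$, which diffrax handles through t0 > t1 and a negative dt0.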

In [14]:
yShape=[2,gDim]+xDim*[nx]

def makeDer(P1,P2,P3,spin,sign):
  """Return the jitted right-hand side dPsi(t, y, args) for the scattered wave.

  The total wave is phi2 = phi + w*exp(-i*sign*(P0 t + P1 x + P2 y + P3 z)). Here phi is
  the scattered part held in the ODE state, and w is uab(P, spin) for sign=+1 or
  vab(P, spin) for sign=-1. dPsi evaluates
      d_t phi = -[i A0 phi2 + B phi + alpha^k d_k phi + i A_k alpha^k phi2],
  where B = i beta - i*sign*P_j alpha^j summed over the dimensions j that are not on the grid.

  Parameters
  ----------
  P1, P2, P3 : momentum of the asymptotic plane wave.
  spin : spin index forwarded to uab/vab.
  sign : +1 (u plane wave) or -1 (v plane wave).

  Returns
  -------
  A function with diffrax's ODETerm signature. It acts on the flattened real vector
  holding Re(phi) and Im(phi), which reshapes to yShape.

  Raises
  ------
  ValueError : if sign is not +1/-1 or xDim is not 1, 2 or 3.
  """
  if sign not in (1,-1):
    raise ValueError(f"sign must be +1 or -1, got {sign!r}")
  P0=np.sqrt(1+P1*P1+P2*P2+P3*P3)

  spinorFn=uab if sign==1 else vab
  uOut=[spinorFn(P1,P2,P3,a,spin) for a in range(gDim)]

  # Momentum components along directions that are not on the grid enter as constants.
  if xDim==1:
    B=1j*beta-1j*sign*(P2*alpha2+P3*alpha3)
  elif xDim==2:
    B=1j*beta-1j*sign*(P3*alpha3)
  elif xDim==3:
    B=1j*beta
  else:
    raise ValueError(f"xDim must be 1, 2 or 3, got {xDim!r}")

  @jax.jit
  def dPsi(t,y,args):
    y=y.reshape(yShape)
    A0ar=chop(A0(t,X,Y,Z))
    A1ar=chop(A1(t,X,Y,Z))
    A2ar=chop(A2(t,X,Y,Z))
    A3ar=chop(A3(t,X,Y,Z))

    planeWave=np.exp(-1j*sign*(P0*t+P1*X+P2*Y+P3*Z))
    phi=[y[0,b]+1j*y[1,b] for b in range(gDim)]
    phi2=[phi[b]+uOut[b]*planeWave for b in range(gDim)]  # scattered + asymptotic wave
    # Spectral derivative: ifftn, multiply by -i k, fftn. fftn is linear, so the sum over
    # b of alpha^k[a,b] d_k phi[b] needs only one forward transform per component a.
    phiK=[np.fft.ifftn(phi[b]) for b in range(gDim)]

    dy=[]
    for a in range(gDim):
      kSum=sum(-1j*(alpha1[a,b]*K1+alpha2[a,b]*K2+alpha3[a,b]*K3)*phiK[b] for b in range(gDim))
      rhs=1j*A0ar*phi2[a]+np.fft.fftn(kSum)
      for b in range(gDim):
        rhs=rhs+B[a,b]*phi[b]+1j*(alpha1[a,b]*A1ar+alpha2[a,b]*A2ar+alpha3[a,b]*A3ar)*phi2[b]
      dy.append(-chop(rhs))

    dy=np.concatenate(dy)
    return np.concatenate([np.real(dy),np.imag(dy)]).ravel()

  return dPsi

2025-11-23 20:01:54 | elapsed time: 0.001 s


These are $u_r^\dagger({\bf q})$ and $v_r^\dagger(-{\bf q})$ on the FFT grid, used for projecting with $(U_{\rm in}[r,{\bf q}]|$ and $(V_{\rm in}[r,-{\bf q}]|$. $r=1,2$ is a spin index, which is irrelevant for the 2D spinors.

In [15]:
# In-state spinors on the k grid used by the projections: uab at (K1,-K2,K3) with phase
# exp(+i E tin), and vab at (-K1,K2,-K3) with phase exp(-i E tin), for spin indices 1 and 2.
energyK=np.sqrt(1+K1*K1+K2*K2+K3*K3)
uPhaseIn=np.exp(1j*energyK*tin)
vPhaseIn=np.exp(-1j*energyK*tin)

uFFT1,uFFT2=([uab(K1,-K2,K3,a,s)*uPhaseIn for a in range(gDim)] for s in (1,2))
vFFT1,vFFT2=([vab(-K1,K2,-K3,a,s)*vPhaseIn for a in range(gDim)] for s in (1,2))

2025-11-23 20:01:55 | elapsed time: 1.112 s


Project $\psi_\text{scat.}=\Delta U$ and $\psi_\text{scat.}=\Delta V$ onto in-states: $(U_{\rm in}|\Delta U)$, $(V_{\rm in}|\Delta U)$, $(U_{\rm in}|\Delta V)$ and $(V_{\rm in}|\Delta V)$. We can compute these using FFTs. First, $f({\bf q}):=(U_{\rm in}[{\bf q}]|\Delta U)=U_{\rm in}^\dagger(r,{\bf q},t_{\rm in},{\bf x}\to 0)\Delta U({\bf q})$. Then, for the next step in the calculation of $N_1$, we express $f({\bf q})=:\int d^3{\bf x}\;e^{i{\bf q}\cdot{\bf x}}f({\bf x})$. U1U etc. below correspond to $f({\bf x})$.

Solve the Dirac equation, project onto in states, and then perform the integrals in

$$
N_{1\text{vers1}}=\big|{}_m(U_\infty|\Delta V)_n+{}_m(\Delta U|U_\infty)(U_\infty|\Delta V)_n\big|_{t=t_{\rm in}}^2
$$
$$
N_{1\text{vers2}}=\big|{}_m(\Delta U|V_\infty)_n+{}_m(\Delta U|V_\infty)(V_\infty|\Delta V)_n\big|_{t=t_{\rm in}}^2
$$

We compute the integrals over ${\bf q}$ using
$$
\int\frac{d^3{\bf q}}{(2\pi)^3}f^\dagger({\bf q})g({\bf q})=\int d^3{\bf x}f^\dagger({\bf x})g({\bf x})
$$
where $f$ and $g$ are the functions calculated with projectU() and projectV().
We should find $N_{1\text{vers1}}=N_{1\text{vers2}}$ to within the numerical precision.

In [16]:
solver = Dopri5()
controller = PIDController(rtol=1e-5, atol=1e-10)

def doAll(p1,p2,p3,pp1,pp2,pp3):
  """Solve both scattering problems and evaluate N1 in its two equivalent forms.

  Usol is the scattered wave for the u plane wave (p1,p2,p3; spin uSpin), and Vsol the
  one for the v plane wave (pp1,pp2,pp3; spin vSpin). Both are integrated from tout
  back to tin. The function sets the globals Usol, Vsol, N1vers1 and N1vers2, plus
  terms1/terms2, which hold the four integrated pieces entering each version.
  """
  global Usol, Vsol, N1vers1, N1vers2, terms1, terms2

  def solveBack(P1,P2,P3,spin,sign):
    # zero scattered wave at tout, integrated backwards to tin
    solution=diffeqsolve(ODETerm(makeDer(P1,P2,P3,spin,sign)),solver,t0=tout,t1=tin,dt0=-1e-2,y0=psiOut,
      stepsize_controller=controller)
    realPart,imagPart=np.split(solution.ys[0],2)
    return (realPart+1j*imagPart).reshape(Ushape)

  def project(sol):
    # projections onto (U1, U2, V1, V2), each transformed back to x space
    solK=[np.fft.ifftn(sol[a]) for a in range(gDim)]
    return [np.fft.fftn(sum(spinors[a]*solK[a] for a in range(gDim)))
            for spinors in (uFFT1,uFFT2,vFFT1,vFFT2)]

  def integrate(arr):
    # Simpson over each spatial axis; every simpson_jax call consumes the last axis
    for _ in range(xDim):
      arr=simpson_jax(arr,dx)
    return arr

  Usol=solveBack(p1,p2,p3,uSpin,1)
  Vsol=solveBack(pp1,pp2,pp3,vSpin,-1)

  U1U,U2U,V1U,V2U=project(Usol)
  U1V,U2V,V1V,V2V=project(Vsol)

  uPhase=np.exp(1j*(p1*X+p2*Y+p3*Z))
  vPhase=np.exp(1j*(pp1*X+pp2*Y+pp3*Z))

  terms1=[integrate(arr) for arr in (uPhase*U1V, uPhase*U2V,
                                     np.conjugate(U1U)*U1V, np.conjugate(U2U)*U2V)]
  terms2=[integrate(arr) for arr in (vPhase*np.conjugate(V1U), vPhase*np.conjugate(V2U),
                                     np.conjugate(V1U)*V1V, np.conjugate(V2U)*V2V)]
  U1Vdelta,U2Vdelta,UU1U1V,UU2U2V=terms1
  V1Udelta,V2Udelta,UV1V1V,UV2V2V=terms2

  if uSpin==1:
    delta1=U1Vdelta
  elif uSpin==2:
    delta1=U2Vdelta

  if vSpin==1:
    delta2=V1Udelta
  elif vSpin==2:
    delta2=V2Udelta

  # the spin index is irrelevant for 2-component spinors, so the second spin term is dropped there
  secondSpin=1 if gDim==4 else 0

  N1vers1=np.abs(delta1+UU1U1V+secondSpin*UU2U2V)**2
  N1vers2=np.abs(delta2+UV1V1V+secondSpin*UV2V2V)**2

2025-11-23 20:01:55 | elapsed time: 0.002 s


In [17]:
@jax.jit
def doAllB(p1,p2,p3,pp1,pp2,pp3):
  """Jitted N1vers1-only variant of doAll: returns |delta + sum_r (dU|U_r)(U_r|dV)|^2.

  Only the projections onto U in-states are formed (both spins when gDim==4), and no
  globals are written, so the function can be wrapped in jax.vmap over momenta.
  """
  def solveBack(P1,P2,P3,spin,sign):
    # zero scattered wave at tout, integrated backwards to tin
    solution=diffeqsolve(ODETerm(makeDer(P1,P2,P3,spin,sign)),solver,t0=tout,t1=tin,dt0=-1e-2,y0=psiOut,
      stepsize_controller=controller)
    realPart,imagPart=np.split(solution.ys[0],2)
    return (realPart+1j*imagPart).reshape(Ushape)

  def toX(spinors,solK):
    # k-space projection onto one spinor set, transformed back to x space
    return np.fft.fftn(sum(spinors[a]*solK[a] for a in range(gDim)))

  Usol=solveBack(p1,p2,p3,uSpin,1)
  Vsol=solveBack(pp1,pp2,pp3,vSpin,-1)

  uSpinors=[uFFT1,uFFT2] if gDim==4 else [uFFT1]
  UsolK=[np.fft.ifftn(Usol[a]) for a in range(gDim)]
  VsolK=[np.fft.ifftn(Vsol[a]) for a in range(gDim)]
  UonU=[toX(s,UsolK) for s in uSpinors]
  UonV=[toX(s,VsolK) for s in uSpinors]

  if uSpin==1:
    delta=np.exp(1j*(p1*X+p2*Y+p3*Z))*UonV[0]
  elif uSpin==2:
    delta=np.exp(1j*(p1*X+p2*Y+p3*Z))*UonV[1]

  UUUV=sum(np.conjugate(onU)*onV for onU,onV in zip(UonU,UonV))

  for _ in range(xDim):
    UUUV=simpson_jax(UUUV,dx)
    delta=simpson_jax(delta,dx)

  return np.abs(delta+UUUV)**2

2025-11-23 20:01:55 | elapsed time: 0.002 s


# Check definitions

In [18]:
# Momenta for the checks: p = -P + delta/2 and pp = P + delta/2, so pp - p = 2P and p + pp = delta.
P1, P2, P3 = .35, 0, 0
delta1, delta2, delta3 = 0, 0, 0

halfDiffAndSum = ((P1, delta1), (P2, delta2), (P3, delta3))
p1, p2, p3 = (-P + d/2 for P, d in halfDiffAndSum)
pp1, pp2, pp3 = (P + d/2 for P, d in halfDiffAndSum)
uSpin, vSpin = 1, 1

2025-11-23 20:01:59 | elapsed time: 0.000 s


In [None]:
# setup summary: [kappa*dx, dk, kMax, nx, nk, xDim, gDim]
[kappa*dx,dk,kMax,nx,nk,xDim,gDim]

[0.13123359580052493, 0.1246819584393449, 7.9173043608984015, 128, 128, 1, 2]

2025-11-23 18:23:09 | elapsed time: 0.003 s


In [None]:
# solve both scattering problems; sets the globals N1vers1, N1vers2, terms1, terms2
doAll(p1,p2,p3,pp1,pp2,pp3)

2025-11-23 18:23:13 | elapsed time: 3.047 s


In [None]:
# both expressions for N1 and their relative difference (expected to be ~0)
[N1vers1,N1vers2,N1vers1/N1vers2-1]

[Array(0.00024684, dtype=float64),
 Array(0.00024685, dtype=float64),
 Array(-5.17398875e-05, dtype=float64)]

2025-11-23 18:23:13 | elapsed time: 0.070 s


In [None]:
# jitted variant; should reproduce N1vers1 from doAll
doAllB(p1,p2,p3,pp1,pp2,pp3)

Array(0.00024684, dtype=float64)

2025-11-23 18:23:28 | elapsed time: 1.489 s


In [None]:
# setup summary: [kappa*dx, dk, kMax, nx, nk, xDim, gDim]
[kappa*dx,dk,kMax,nx,nk,xDim,gDim]

[0.13123359580052493, 0.1246819584393449, 7.9173043608984015, 128, 128, 2, 2]

2025-11-23 18:51:38 | elapsed time: 0.003 s


In [None]:
# solve both scattering problems; sets the globals N1vers1, N1vers2, terms1, terms2
doAll(p1,p2,p3,pp1,pp2,pp3)

2025-11-23 18:51:46 | elapsed time: 3.821 s


In [None]:
# both expressions for N1 and their relative difference (expected to be ~0)
[N1vers1,N1vers2,N1vers1/N1vers2-1]

[Array(0.00219057, dtype=float64),
 Array(0.00219047, dtype=float64),
 Array(4.70527193e-05, dtype=float64)]

2025-11-23 18:51:48 | elapsed time: 0.074 s


In [None]:
# jitted variant; should reproduce N1vers1 from doAll
doAllB(p1,p2,p3,pp1,pp2,pp3)

Array(0.00219057, dtype=float64)

2025-11-23 18:51:52 | elapsed time: 2.140 s


Check gauge invariance in 2+1 dimensions, using the alternative gauge defined above:

In [None]:
# a nonzero value confirms the alternative gauge (A1 != 0) is the one currently defined
print(A1(0,1,0,0))

0.0994265907571522
2025-11-23 19:01:44 | elapsed time: 0.002 s


In [None]:
# setup summary: [kappa*dx, dk, kMax, nx, nk, xDim, gDim]
[kappa*dx,dk,kMax,nx,nk,xDim,gDim]

[0.13123359580052493, 0.1246819584393449, 7.9173043608984015, 128, 128, 2, 4]

2025-11-23 18:59:18 | elapsed time: 0.002 s


In [None]:
# N1vers1 in the alternative gauge, to compare with the simple-gauge value above
doAllB(p1,p2,p3,pp1,pp2,pp3)

Array(0.00219116, dtype=float64)

2025-11-23 18:59:23 | elapsed time: 3.846 s


In [None]:
# relative difference of the two 2+1 doAllB results above (numbers copied from their printed outputs)
0.00219116/0.00219057-1

0.00026933629146741467

2025-11-23 18:59:46 | elapsed time: 0.002 s


In [21]:
# setup summary: [kappa*dx, dk, kMax, nx, nk, xDim, gDim]
[kappa*dx,dk,kMax,nx,nk,xDim,gDim]

[0.13123359580052493, 0.1246819584393449, 7.9173043608984015, 128, 128, 3, 4]

2025-11-23 18:35:57 | elapsed time: 0.002 s


In [22]:
# solve both scattering problems; sets the globals N1vers1, N1vers2, terms1, terms2
doAll(p1,p2,p3,pp1,pp2,pp3)

2025-11-23 18:36:15 | elapsed time: 16.864 s


In [23]:
# both expressions for N1 and their relative difference (expected to be ~0)
[N1vers1,N1vers2,N1vers1/N1vers2-1]

[Array(0.01747039, dtype=float64),
 Array(0.01747064, dtype=float64),
 Array(-1.47275488e-05, dtype=float64)]

2025-11-23 18:36:21 | elapsed time: 0.088 s


In [24]:
# jitted variant; should reproduce N1vers1 from doAll
doAllB(p1,p2,p3,pp1,pp2,pp3)

Array(0.01747039, dtype=float64)

2025-11-23 18:36:50 | elapsed time: 20.918 s


# 1+1

In [None]:
# Momentum scan grid: nSteps+1 equally spaced points including both endpoints.
# np.linspace is used because np.arange with a float step can gain or lose the endpoint to round-off.
pmin=-1.2; pmax=1.2; nSteps=100; dp=(pmax-pmin)/nSteps
pList=np.linspace(pmin,pmax,nSteps+1)

2025-11-23 18:23:37 | elapsed time: 0.011 s


parallel-sequential: parallelization for solving the Dirac equation, but running over momentum grid points sequentially

In [None]:
%%time
# sequential scan over p1 = -pp1: one jitted doAllB call per momentum
tab = np.array([doAllB(p, p2, p3, -p, pp2, pp3) for p in tqdm(pList)])

100%|██████████| 101/101 [00:27<00:00,  3.69it/s]

CPU times: user 19.8 s, sys: 7.57 s, total: 27.3 s
Wall time: 27.4 s
2025-11-23 18:25:12 | elapsed time: 26.102 s





parallel-parallel: run over momentum grid points in parallel - much, much faster

In [None]:
%%time
# same scan, but all momenta in a single vmapped and jitted call
batched_doAll = jax.jit(jax.vmap(lambda p: doAllB(p,p2,p3,-p,pp2,pp3)))
tabB = batched_doAll(pList)

CPU times: user 1.52 s, sys: 220 ms, total: 1.74 s
Wall time: 1.85 s
2025-11-23 18:49:02 | elapsed time: 1.852 s


In [None]:
# maximal relative deviation between the vmapped and the sequential scan
np.max(np.abs(tabB/tab-1))

Array(4.47997771e-07, dtype=float64)

2025-11-23 18:49:15 | elapsed time: 0.340 s


In [None]:
# prefactor applied to the 1+1 results before plotting
pre1D=2/((2*np.pi)**4)

2025-11-23 18:49:21 | elapsed time: 0.000 s


In [None]:
# Compare the sequential scan (tab) with the vmapped scan (tabB) along p1 = -pp1.
# Both are the same doAllB quantity, so the two curves should coincide.
inter1=make_interp_spline(pList,pre1D*tab, k=3)
inter2=make_interp_spline(pList,pre1D*tabB, k=3)

pDense=np.linspace(pList.min(), pList.max(), 300)
dense1=inter1(pDense)
dense2=inter2(pDense)

plt.plot(pDense,dense1,label='sequential loop')
plt.plot(pDense,dense2,linestyle="--",label='jax.vmap')
plt.xlabel('p1 (= -pp1)')
plt.ylabel('pre1D * N1')
plt.legend()
plt.show()

2D $(p_1,p_1')$ grid

In [None]:
%%time

# Full (p1, pp1) grid via nested vmap: inner over pp1, outer over p1 (p2, p3, pp2, pp3 fixed).
def inner(p1):
    """doAllB for one p1 and every pp1 in pList."""
    return jax.vmap(lambda pp1: doAllB(p1, p2, p3, pp1, pp2, pp3))(pList)

batched_doAll2D = jax.jit(jax.vmap(inner))

# tab2D[i, j] corresponds to p1 = pList[i], pp1 = pList[j]
tab2D = batched_doAll2D(pList)

CPU times: user 1.52 s, sys: 261 ms, total: 1.78 s
Wall time: 1.68 s
2025-11-23 18:49:39 | elapsed time: 1.679 s


It took less than 2 seconds to compute this 101×101 grid!

In [None]:
# cubic interpolant of the weighted (p1, pp1) grid; fill_value=None extrapolates outside pList
inter2D=RegularGridInterpolator([pList,pList],pre1D*tab2D,method='cubic',bounds_error=False, fill_value=None)

2025-11-23 18:50:14 | elapsed time: 0.066 s


In [None]:
# diagonal pp1 = -p1 of the 2D interpolant, for comparison with the direct 1D scan
check=np.array([inter2D((p,-p)) for p in pList])

2025-11-23 18:50:15 | elapsed time: 0.215 s


In [None]:
# Compare the direct pp1 = -p1 scan with the diagonal of the 2D (p1, pp1) interpolant.
inter1=make_interp_spline(pList,pre1D*tab, k=3)
inter2=make_interp_spline(pList,check, k=3)

pDense=np.linspace(pList.min(), pList.max(), 300)
dense1=inter1(pDense)
dense2=inter2(pDense)

plt.plot(pDense,dense1,label='direct scan, pp1 = -p1')
plt.plot(pDense,dense2,linestyle="--",label='2D grid, diagonal pp1 = -p1')
plt.xlabel('p1')
plt.legend()
plt.show()

In [None]:
npDense=300
pDense=np.linspace(pList.min(), pList.max(),npDense)
p1d, pp1d = np.meshgrid(pDense,pDense, indexing='ij')
# one (p1, pp1) row per dense point, flattened in C order of the mesh
p1pp1d = np.column_stack([p1d.ravel(), pp1d.ravel()])
N2Ddense=inter2D(p1pp1d).reshape(npDense,npDense)

2025-11-23 18:50:21 | elapsed time: 0.209 s


In [None]:
fig, ax = plt.subplots(figsize=(8, 6))
# transposed so that p1 runs horizontally and pp1 vertically
image = ax.imshow(N2Ddense.T, aspect='auto', origin='lower',
                  extent=[pList[0], pList[-1], pList[0], pList[-1]],
                  cmap='viridis')
ax.set_xlabel('p1')
ax.set_ylabel('pp1')
fig.colorbar(image, ax=ax)
plt.show()

# 2+1

In [None]:
# Momentum scan grid: nSteps+1 equally spaced points including both endpoints.
# np.linspace is used because np.arange with a float step can gain or lose the endpoint to round-off.
pmin=-1.2; pmax=1.2; nSteps=100; dp=(pmax-pmin)/nSteps
pList=np.linspace(pmin,pmax,nSteps+1)

2025-11-23 18:52:03 | elapsed time: 0.012 s


parallel-sequential: parallelization for solving the Dirac equation, but running over momentum grid points sequentially

In [None]:
%%time
# sequential scan over p1 = -pp1: one jitted doAllB call per momentum
tab = np.array([doAllB(p, p2, p3, -p, pp2, pp3) for p in tqdm(pList)])

100%|██████████| 101/101 [01:06<00:00,  1.53it/s]

CPU times: user 52.7 s, sys: 13.6 s, total: 1min 6s
Wall time: 1min 6s
2025-11-23 18:53:17 | elapsed time: 64.272 s





parallel-parallel: run over momentum grid points in parallel - much faster

In [None]:
%%time
# same scan, but all momenta in a single vmapped and jitted call
batched_doAll = jax.jit(jax.vmap(lambda p: doAllB(p,p2,p3,-p,pp2,pp3)))
tabB = batched_doAll(pList)

CPU times: user 17.8 s, sys: 299 ms, total: 18.1 s
Wall time: 19.1 s
2025-11-23 18:53:59 | elapsed time: 17.905 s


In [None]:
# maximal relative deviation between the vmapped and the sequential scan
np.max(np.abs(tabB/tab-1))

Array(5.24525281e-08, dtype=float64)

2025-11-23 18:54:11 | elapsed time: 0.146 s


In [23]:
# prefactor applied to the 2+1 results before plotting
pre2D=2/((2*np.pi)**5)

2025-11-23 20:06:00 | elapsed time: 0.000 s


In [None]:
# Compare the sequential scan (tab) with the vmapped scan (tabB) along p1 = -pp1.
# Both are the same doAllB quantity, so the two curves should coincide.
inter1=make_interp_spline(pList,pre2D*tab, k=3)
inter2=make_interp_spline(pList,pre2D*tabB, k=3)

pDense=np.linspace(pList.min(), pList.max(), 300)
dense1=inter1(pDense)
dense2=inter2(pDense)

plt.plot(pDense,dense1,label='sequential loop')
plt.plot(pDense,dense2,linestyle="--",label='jax.vmap')
plt.xlabel('p1 (= -pp1)')
plt.ylabel('pre2D * N1')
plt.legend()
plt.show()

2D $(p_1,p_1')$ grid

In [None]:
%%time

# Full (p1, pp1) grid via nested vmap: inner over pp1, outer over p1 (p2, p3, pp2, pp3 fixed).
def inner(p1):
    """doAllB for one p1 and every pp1 in pList."""
    return jax.vmap(lambda pp1: doAllB(p1, p2, p3, pp1, pp2, pp3))(pList)

batched_doAll2D = jax.jit(jax.vmap(inner))

# tab2D[i, j] corresponds to p1 = pList[i], pp1 = pList[j]
tab2D = batched_doAll2D(pList)

CPU times: user 17.9 s, sys: 296 ms, total: 18.2 s
Wall time: 17.8 s
2025-11-23 18:54:52 | elapsed time: 17.850 s


It took less than 20 seconds to compute this 101×101 grid!

In [None]:
# cubic interpolant of the weighted (p1, pp1) grid; fill_value=None extrapolates outside pList
inter2D=RegularGridInterpolator([pList,pList],pre2D*tab2D,method='cubic',bounds_error=False, fill_value=None)

2025-11-23 18:55:11 | elapsed time: 0.062 s


In [None]:
# diagonal pp1 = -p1 of the 2D interpolant, for comparison with the direct 1D scan
check=np.array([inter2D((p,-p)) for p in pList])

2025-11-23 18:55:13 | elapsed time: 0.206 s


In [None]:
# Compare the direct pp1 = -p1 scan with the diagonal of the 2D (p1, pp1) interpolant.
inter1=make_interp_spline(pList,pre2D*tab, k=3)
inter2=make_interp_spline(pList,check, k=3)

pDense=np.linspace(pList.min(), pList.max(), 300)
dense1=inter1(pDense)
dense2=inter2(pDense)

plt.plot(pDense,dense1,label='direct scan, pp1 = -p1')
plt.plot(pDense,dense2,linestyle="--",label='2D grid, diagonal pp1 = -p1')
plt.xlabel('p1')
plt.legend()
plt.show()

In [None]:
npDense=300
pDense=np.linspace(pList.min(), pList.max(),npDense)
p1d, pp1d = np.meshgrid(pDense,pDense, indexing='ij')
# one (p1, pp1) row per dense point, flattened in C order of the mesh
p1pp1d = np.column_stack([p1d.ravel(), pp1d.ravel()])
N2Ddense=inter2D(p1pp1d).reshape(npDense,npDense)

fig, ax = plt.subplots(figsize=(8, 6))
# transposed so that p1 runs horizontally and pp1 vertically
image = ax.imshow(N2Ddense.T, aspect='auto', origin='lower',
                  extent=[pList[0], pList[-1], pList[0], pList[-1]],
                  cmap='viridis')
ax.set_xlabel('p1')
ax.set_ylabel('pp1')
fig.colorbar(image, ax=ax)
plt.show()

In [20]:
# Momentum scan grid: nSteps+1 equally spaced points including both endpoints.
# np.linspace is used because np.arange with a float step can gain or lose the endpoint to round-off.
pmin=-1; pmax=1; nSteps=50; dp=(pmax-pmin)/nSteps
pList=np.linspace(pmin,pmax,nSteps+1)

2025-11-23 20:04:18 | elapsed time: 0.014 s


In [22]:
# scan over p2 with pp2 = -p2; p1, p3, pp1, pp3 keep their current global values
# (set in the definitions-check cell) -- NOTE(review): confirm they have not been overwritten since
batched_doAll = jax.jit(jax.vmap(lambda p: doAllB(p1,p,p3,pp1,-p,pp3)))
tabB = batched_doAll(pList)

2025-11-23 20:05:51 | elapsed time: 20.679 s


In [None]:
# Scan over p2 with pp2 = -p2 at fixed p1, pp1.
inter1=make_interp_spline(pList,pre2D*tabB, k=3)

pDense=np.linspace(pList.min(), pList.max(), 200)
dense1=inter1(pDense)

plt.plot(pDense,dense1)
plt.xlabel('p2 (= -pp2)')
plt.grid(True)
plt.show()

GPU on laptop:

In [None]:
# Name of the accelerator JAX runs on (via jax.devices(), so CuPy is not required).
deviceName=jax.devices()[0].device_kind
deviceName

'NVIDIA GeForce RTX 5070 Laptop GPU'

2025-11-23 19:32:44 | elapsed time: 0.002 s


# 3+1

The laptop GPU does not have enough memory for the 3+1 case, so this part was run on Colab:

In [33]:
# Name of the accelerator JAX runs on (via jax.devices(), so CuPy is not required).
deviceName=jax.devices()[0].device_kind
deviceName

'NVIDIA A100-SXM4-80GB'

2025-11-23 18:46:31 | elapsed time: 0.890 s


18:46:31 UTC means 19:46:31 CET

In [25]:
# Momentum scan grid: nSteps+1 equally spaced points including both endpoints.
# np.linspace is used because np.arange with a float step can gain or lose the endpoint to round-off.
pmin=0; pmax=1.2; nSteps=20; dp=(pmax-pmin)/nSteps
pList=np.linspace(pmin,pmax,nSteps+1)

2025-11-23 18:37:20 | elapsed time: 0.019 s


In [26]:
%%time
# 3+1 scan over p1 = -pp1, all momenta in a single vmapped and jitted call
batched_doAll = jax.jit(jax.vmap(lambda p: doAllB(p,p2,p3,-p,pp2,pp3)))
tabB = batched_doAll(pList)

CPU times: user 3min 34s, sys: 3.61 s, total: 3min 37s
Wall time: 3min 38s
2025-11-23 18:41:05 | elapsed time: 218.606 s


In [27]:
# prefactor for the 3+1 results; follows the 2/(2*pi)**(xDim+3) pattern of pre1D (xDim=1) and pre2D (xDim=2)
pre4D=2/((2*np.pi)**6)

2025-11-23 18:43:23 | elapsed time: 0.000 s


In [None]:
inter1 = make_interp_spline(pList, pre4D*tabB, k=3)
pDense = np.linspace(pList.min(), pList.max(), 100)
dense1 = inter1(pDense)

fig, ax = plt.subplots()
ax.plot(pDense, dense1)
# computed grid points on top of the spline
ax.scatter(pList, pre4D*tabB, s=40, marker='o')

ax.set_xlabel('p')
ax.grid(True)
plt.show()

# Checking spin definitions, z basis

In [None]:
# generic momentum p and its opposite pp = -p, with the corresponding energies
p1, p2, p3 = -.35, .11, .2
pp1, pp2, pp3 = -p1, -p2, -p3
p0, pp0 = (np.sqrt(1+a*a+b*b+c*c) for a, b, c in ((p1, p2, p3), (pp1, pp2, pp3)))

In [None]:
# u spinors at p and v spinors at pp for both spin indices, as component vectors
ut1, ut2 = (np.array([uab(p1,p2,p3,a,spinIndex) for a in range(gDim)]) for spinIndex in (1,2))
vt1, vt2 = (np.array([vab(pp1,pp2,pp3,a,spinIndex) for a in range(gDim)]) for spinIndex in (1,2))

ut=[ut1,ut2,vt1,vt2]

In [None]:
# orthonormality of (u1, u2, v1, v2): every printed value is expected to be 0
for i, left in enumerate(ut):
  for j, right in enumerate(ut):
    overlap = np.conjugate(left)@right
    print(chop(np.abs(overlap-1 if i==j else overlap)))

In [None]:
# (p0 g0 + p.g - 1) u1 should vanish
chop(np.abs((p0*g0+p1*g1+p2*g2+p3*g3-id)@ut1))

Array([0., 0., 0., 0.], dtype=float64)

In [None]:
# (pp0 g0 + pp.g + 1) v2 should vanish
chop(np.abs((pp0*g0+pp1*g1+pp2*g2+pp3*g3+id)@vt2))

Array([0., 0., 0., 0.], dtype=float64)

In [None]:
# Clifford algebra {g_mu, g_nu} = 2 eta_{mu nu}: every printed row is expected to be zero
for mu in range(4):
  for nu in range(4):
    anticomm = g[mu]@g[nu]+g[nu]@g[mu]
    print(chop(anticomm-2*id*eta[mu,nu] if mu==nu else anticomm).ravel())

In [None]:
# g_mu^dagger = g0 g_mu g0: every printed matrix is expected to be zero
for gammaMu in g:
  print(gammaMu.conj().T-g0@gammaMu@g0)

# Padding: Grid, ODE, projection and integral definitions

In [None]:
# Padded setup: wave functions live on nx points per axis and are zero-padded to nk points for FFTs.
if nk<nx:
  raise ValueError(f"nk ({nk}) must be >= nx ({nx}) for zero padding")
if xDim not in (1,2,3):
  raise ValueError(f"xDim must be 1, 2 or 3, got {xDim!r}")

x0=-xb; x1=xb; dx=(x1-x0)/(nx-1)
xGrid=np.linspace(x0,x1,nx)
dk=2*np.pi/(nk*dx); kMax=dk*(nk-1)/2
kGrid=2*np.pi*np.fft.fftfreq(nk,dx)  # FFT ordering: 0, dk, 2dk, ..., -dk
nPad=(nk-nx)//2  # the physical nx points start at this index of the nk-point grid
nH=round((nx-1)/2)

# Scattered wave at t=tout (zero), flattened as [real parts, imaginary parts].
psiOut=np.zeros(2*gDim*(nx**xDim),dtype=realType)

# Dense meshgrids; coordinates and momenta along absent dimensions are set to 0.
if xDim==1:
  X=xGrid; Y=0; Z=0;
  K1=kGrid; K2=0; K3=0;
elif xDim==2:
  X,Y=np.meshgrid(xGrid,xGrid,indexing='ij')
  Z=0
  K1,K2=np.meshgrid(kGrid,kGrid,indexing='ij')
  K3=0
else:
  X,Y,Z=np.meshgrid(xGrid,xGrid,xGrid,indexing='ij')
  K1,K2,K3=np.meshgrid(kGrid,kGrid,kGrid,indexing='ij')

Ushape=[gDim]+xDim*[nx]

The spatial derivatives, $\partial_i$, are calculated by zero-padding to $n_k$ points, Fourier transforming, multiplying by $-ik_i$, Fourier transforming back to $x^i$, and cropping to the original $n_x$ points.

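dPsi gives the right-hand side of the Dirac equation for the scattered wave, $\partial_t\psi_\text{scat.}=dPsi$. The overall minus sign comes from writing the equation as $-\partial_t\psi=(\dots)\psi$. The integration runs backwards in time, from $t_{\rm out}$ to $t_{\rm in}$, which diffrax handles through t0 > t1 and a negative dt0.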

In [None]:
yShape=[2,gDim]+xDim*[nx]
# Zero padding from nx to nk points per axis: nPad zeros in front, the rest behind, so odd nk-nx works.
padWidth=[(nPad,nk-nx-nPad)]*xDim
# Crop back to the nx physical points. slice(nPad,-nPad) would be empty when nPad == 0 (nk == nx).
unPad=tuple(slice(nPad,nPad+nx) for _ in range(xDim))

def makeDer(P1,P2,P3,spin,sign):
  """Return the jitted right-hand side dPsi(t, y, args) for the scattered wave (padded FFTs).

  The total wave is phi2 = phi + w*exp(-i*sign*(P0 t + P1 x + P2 y + P3 z)). Here phi is
  the scattered part held in the ODE state, and w is uab(P, spin) for sign=+1 or
  vab(P, spin) for sign=-1. dPsi evaluates
      d_t phi = -[i A0 phi2 + B phi + alpha^k d_k phi + i A_k alpha^k phi2],
  where B = i beta - i*sign*P_j alpha^j summed over the dimensions j that are not on the grid.
  The derivative is taken spectrally on the zero-padded nk grid and cropped back to nx points.

  Parameters
  ----------
  P1, P2, P3 : momentum of the asymptotic plane wave.
  spin : spin index forwarded to uab/vab.
  sign : +1 (u plane wave) or -1 (v plane wave).

  Returns
  -------
  A function with diffrax's ODETerm signature. It acts on the flattened real vector
  holding Re(phi) and Im(phi), which reshapes to yShape.

  Raises
  ------
  ValueError : if sign is not +1/-1 or xDim is not 1, 2 or 3.
  """
  if sign not in (1,-1):
    raise ValueError(f"sign must be +1 or -1, got {sign!r}")
  P0=np.sqrt(1+P1*P1+P2*P2+P3*P3)

  spinorFn=uab if sign==1 else vab
  uOut=[spinorFn(P1,P2,P3,a,spin) for a in range(gDim)]

  # Momentum components along directions that are not on the grid enter as constants.
  if xDim==1:
    B=1j*beta-1j*sign*(P2*alpha2+P3*alpha3)
  elif xDim==2:
    B=1j*beta-1j*sign*(P3*alpha3)
  elif xDim==3:
    B=1j*beta
  else:
    raise ValueError(f"xDim must be 1, 2 or 3, got {xDim!r}")

  @jax.jit
  def dPsi(t,y,args):
    y=y.reshape(yShape)
    A0ar=chop(A0(t,X,Y,Z))
    A1ar=chop(A1(t,X,Y,Z))
    A2ar=chop(A2(t,X,Y,Z))
    A3ar=chop(A3(t,X,Y,Z))

    planeWave=np.exp(-1j*sign*(P0*t+P1*X+P2*Y+P3*Z))
    phi=[y[0,b]+1j*y[1,b] for b in range(gDim)]
    phi2=[phi[b]+uOut[b]*planeWave for b in range(gDim)]  # scattered + asymptotic wave
    # Spectral derivative on the padded grid. fftn is linear, so the sum over b of
    # alpha^k[a,b] d_k phi[b] needs only one forward transform per component a.
    phiK=[np.fft.ifftn(np.pad(phi[b],padWidth)) for b in range(gDim)]

    dy=[]
    for a in range(gDim):
      kSum=sum(-1j*(alpha1[a,b]*K1+alpha2[a,b]*K2+alpha3[a,b]*K3)*phiK[b] for b in range(gDim))
      rhs=1j*A0ar*phi2[a]+np.fft.fftn(kSum)[unPad]
      for b in range(gDim):
        rhs=rhs+B[a,b]*phi[b]+1j*(alpha1[a,b]*A1ar+alpha2[a,b]*A2ar+alpha3[a,b]*A3ar)*phi2[b]
      dy.append(-chop(rhs))

    dy=np.concatenate(dy)
    return np.concatenate([np.real(dy),np.imag(dy)]).ravel()

  return dPsi

These are $u_r^\dagger({\bf q})$ and $v_r^\dagger(-{\bf q})$ on the FFT grid, used for projecting with $(U_{\rm in}[r,{\bf q}]|$ and $(V_{\rm in}[r,-{\bf q}]|$. $r=1,2$ is a spin index, which is irrelevant for the 2D spinors.

In [None]:
# In-state spinors on the k grid used by the projections: uab at (K1,-K2,K3) with phase
# exp(+i E tin), and vab at (-K1,K2,-K3) with phase exp(-i E tin), for spin indices 1 and 2.
energyK=np.sqrt(1+K1*K1+K2*K2+K3*K3)
uPhaseIn=np.exp(1j*energyK*tin)
vPhaseIn=np.exp(-1j*energyK*tin)

uFFT1,uFFT2=([uab(K1,-K2,K3,a,s)*uPhaseIn for a in range(gDim)] for s in (1,2))
vFFT1,vFFT2=([vab(-K1,K2,-K3,a,s)*vPhaseIn for a in range(gDim)] for s in (1,2))

Project $\psi_\text{scat.}=\Delta U$ and $\psi_\text{scat.}=\Delta V$ onto in-states: $(U_{\rm in}|\Delta U)$, $(V_{\rm in}|\Delta U)$, $(U_{\rm in}|\Delta V)$ and $(V_{\rm in}|\Delta V)$. We can compute these using FFTs. First, $f({\bf q}):=(U_{\rm in}[{\bf q}]|\Delta U)=U_{\rm in}^\dagger(r,{\bf q},t_{\rm in},{\bf x}\to 0)\Delta U({\bf q})$. Then, for the next step in the calculation of $N_1$, we express $f({\bf q})=:\int d^3{\bf x}\;e^{i{\bf q}\cdot{\bf x}}f({\bf x})$. U1U etc. below correspond to $f({\bf x})$.

In [None]:
def _projectPadded(sol):
  """Project sol onto the spinor sets (uFFT1, uFFT2, vFFT1, vFFT2); returns [U1, U2, V1, V2] in x space.

  Each component of sol is zero-padded from nx to nk points per axis before the FFT, and
  every projection is cropped back to the nx physical points after transforming back.
  """
  # nPad zeros in front and the remainder behind, so an odd nk-nx still gives nk points.
  padWidth=[(nPad,nk-nx-nPad)]*xDim
  # Explicit stop index; slice(nPad,-nPad) would select nothing when nPad == 0.
  crop=tuple(slice(nPad,nPad+nx) for _ in range(xDim))
  solK=[np.fft.ifftn(np.pad(sol[a],padWidth)) for a in range(gDim)]
  return [np.fft.fftn(sum(spinors[a]*solK[a] for a in range(gDim)))[crop]
          for spinors in (uFFT1,uFFT2,vFFT1,vFFT2)]

def projectU():
  """Project Usol onto the in-states; sets the globals U1U, U2U, V1U, V2U."""
  global U1U, V1U, U2U, V2U
  U1U,U2U,V1U,V2U=_projectPadded(Usol)

def projectV():
  """Project Vsol onto the in-states; sets the globals U1V, U2V, V1V, V2V."""
  global U1V, V1V, U2V, V2V
  U1V,U2V,V1V,V2V=_projectPadded(Vsol)

Solve the Dirac equation, project onto in states, and then perform the integrals in

$$
N_{1\text{vers1}}=\big|{}_m(U_\infty|\Delta V)_n+{}_m(\Delta U|U_\infty)(U_\infty|\Delta V)_n\big|_{t=t_{\rm in}}^2
$$
$$
N_{1\text{vers2}}=\big|{}_m(\Delta U|V_\infty)_n+{}_m(\Delta U|V_\infty)(V_\infty|\Delta V)_n\big|_{t=t_{\rm in}}^2
$$

We compute the integrals over ${\bf q}$ using
$$
\int\frac{d^3{\bf q}}{(2\pi)^3}f^\dagger({\bf q})g({\bf q})=\int d^3{\bf x}f^\dagger({\bf x})g({\bf x})
$$
where $f$ and $g$ are the functions calculated with projectU() and projectV().
We should find $N_{1\text{vers1}}=N_{1\text{vers2}}$ to within the numerical precision.

In [None]:
def finalInts():
  """Padded-grid counterpart of the integrals in the unpadded doAll.

  Calls projectU()/projectV() and integrates the projections over x with scipy's
  Simpson rule on xGrid. Stores N1vers1, N1vers2, terms1 and terms2 as globals.
  """
  global N1vers1, N1vers2, terms1, terms2

  projectU()
  projectV()

  def overX(values):
    # one Simpson pass per spatial axis; each pass removes the last axis
    for _ in range(xDim):
      values=simpson(values,x=xGrid)
    return values

  uWave=np.exp(1j*(p1*X+p2*Y+p3*Z))
  vWave=np.exp(1j*(pp1*X+pp2*Y+pp3*Z))

  terms1=[overX(v) for v in (uWave*U1V, uWave*U2V,
                             np.conjugate(U1U)*U1V, np.conjugate(U2U)*U2V)]
  terms2=[overX(v) for v in (vWave*np.conjugate(V1U), vWave*np.conjugate(V2U),
                             np.conjugate(V1U)*V1V, np.conjugate(V2U)*V2V)]
  U1Vdelta,U2Vdelta,UU1U1V,UU2U2V=terms1
  V1Udelta,V2Udelta,UV1V1V,UV2V2V=terms2

  if uSpin==1:
    delta1=U1Vdelta
  elif uSpin==2:
    delta1=U2Vdelta

  if vSpin==1:
    delta2=V1Udelta
  elif vSpin==2:
    delta2=V2Udelta

  # the spin index is irrelevant for 2-component spinors, so the second spin term is dropped there
  secondSpin=1 if gDim==4 else 0

  N1vers1=np.abs(delta1+UU1U1V+secondSpin*UU2U2V)**2
  N1vers2=np.abs(delta2+UV1V1V+secondSpin*UV2V2V)**2

In [None]:
solver = Dopri5()
controller = PIDController(rtol=1e-5, atol=1e-10)

def doAll():
  """Padded-grid driver: solve for Usol and Vsol, then evaluate N1 via finalInts().

  Reads the globals p1..p3, pp1..pp3, uSpin and vSpin. Sets Usol and Vsol; finalInts
  sets N1vers1, N1vers2, terms1 and terms2.
  """
  global Usol, Vsol, N1vers1, N1vers2, terms1, terms2

  def solveBack(P1,P2,P3,spin,sign):
    # zero scattered wave at tout, integrated backwards to tin
    solution=diffeqsolve(ODETerm(makeDer(P1,P2,P3,spin,sign)),solver,t0=tout,t1=tin,dt0=-1e-2,y0=psiOut,
      stepsize_controller=controller)
    realPart,imagPart=np.split(solution.ys[0],2)
    return (realPart+1j*imagPart).reshape(Ushape)

  Usol=solveBack(p1,p2,p3,uSpin,1)
  Vsol=solveBack(pp1,pp2,pp3,vSpin,-1)

  finalInts()