In [1]:
%%capture output
# Install diffrax (ODE solvers used below); pip's console output is captured into `output`.
%pip install diffrax

In [2]:
import jax.numpy as np
from diffrax import diffeqsolve, ODETerm, Dopri5, SaveAt, PIDController
import matplotlib.pyplot as plt
from scipy.integrate import simpson
from scipy.interpolate import make_interp_spline
from tqdm import tqdm

# Definitions

In [3]:
# Dirac matrices in the chiral (Weyl) representation:
#   g0 = [[0, I], [I, 0]],   gk = [[0, sigma_k], [-sigma_k, 0]]  (k = 1, 2, 3)
g0 = np.array([[0, 0, 1, 0],
               [0, 0, 0, 1],
               [1, 0, 0, 0],
               [0, 1, 0, 0]])
g1 = np.array([[ 0,  0, 0, 1],
               [ 0,  0, 1, 0],
               [ 0, -1, 0, 0],
               [-1,  0, 0, 0]])
g2 = np.array([[  0,   0,   0, -1j],
               [  0,   0,  1j,   0],
               [  0,  1j,   0,   0],
               [-1j,   0,   0,   0]])
g3 = np.array([[ 0, 0, 1,  0],
               [ 0, 0, 0, -1],
               [-1, 0, 0,  0],
               [ 0, 1, 0,  0]])
# NOTE: `id` shadows Python's builtin id(); kept because later cells use this name.
id = np.array([[1, 0, 0, 0],
               [0, 1, 0, 0],
               [0, 0, 1, 0],
               [0, 0, 0, 1]])

g = [g0, g1, g2, g3]
# Minkowski metric diag(+1, -1, -1, -1)
eta = np.diag(np.array([1, -1, -1, -1]))

# alpha_k = g0 @ gk, the matrices multiplying the spatial derivatives in the RHS
a1, a2, a3 = (g0 @ gk for gk in g[1:])

In [4]:
# Reference vectors selecting the spin-1 / spin-2 column of (+-slashP + 1).
R1=np.array([0,1,0,0])
R2=np.array([0,0,1,0])

def _spinor4D(P1,P2,P3,spin,sign):
  """Shared construction behind u4D (sign=+1) and v4D (sign=-1).

  Returns (sign*(P0*g0+P1*g1+P2*g2+P3*g3) + id) @ R / sqrt(2*P0*(P0+P3)),
  where P0 = sqrt(1 + P1^2 + P2^2 + P3^2) and R = R1 for spin 1, R2 for spin 2.

  Raises:
    ValueError: if spin is not 1 or 2.
  """
  if spin==1:
    R=R1
  elif spin==2:
    R=R2
  else:
    raise ValueError(f"spin must be 1 or 2, got {spin!r}")
  P0=np.sqrt(1+P1*P1+P2*P2+P3*P3)
  slashP=P0*g0+P1*g1+P2*g2+P3*g3
  # With the chiral matrices g0..g3, the selected column of (+-slashP + 1) has
  # squared norm 2*P0*(P0+P3), so the returned 4-vector has unit norm.
  return (sign*slashP+id)@R/np.sqrt(2*P0*(P0+P3))

def u4D(P1,P2,P3,spin):
  """u-type spinor (slashP + 1) R / N for momentum (P1,P2,P3) and spin in {1, 2}.

  Since slashP @ slashP = P0^2 - |P|^2 = 1, it satisfies (slashP - 1) u = 0.
  Raises ValueError for any other spin.
  """
  return _spinor4D(P1,P2,P3,spin,1)

def v4D(P1,P2,P3,spin):
  """v-type spinor (-slashP + 1) R / N for momentum (P1,P2,P3) and spin in {1, 2}.

  Since slashP @ slashP = 1, it satisfies (slashP + 1) v = 0.
  Raises ValueError for any other spin.
  """
  return _spinor4D(P1,P2,P3,spin,-1)

In [5]:
# Momentum p used for the u spinors and pp used for the v spinors (pp1 = -p1); mass is 1.
p1, p2, p3 = -.35, 0, 0
p0 = np.sqrt(1 + p1*p1 + p2*p2 + p3*p3)
pp1, pp2, pp3 = -p1, 0, 0
pp0 = np.sqrt(1 + pp1*pp1 + pp2*pp2 + pp3*pp3)

In [6]:
# Field amplitude and Gaussian rates in time (omega) and space (kappa)
gamma = 1
E0 = 1/3
omega = gamma*E0
kappa = gamma*E0

# Real-space half-width, points per axis, and FFT box size per axis
zb = 22
nz = 201
nk = 401

# Solver time range (t0, t1 passed to diffeqsolve)
tin, tout = -10, 10

# Magnitudes below this threshold are replaced by exact zeros
chop = 1e-7

def A0(t, x, y, z):
  """Gaussian potential (E0/kappa) * exp(-(omega*t)^2 - kappa^2 * (x^2 + y^2 + z^2))."""
  r2 = x*x + y*y + z*z
  return (E0/kappa)*np.exp(-(omega*t)**2 - kappa*kappa*r2)

def At(t):
  """Time envelope A0(t, origin) / A0(0, origin); equals 1 at t = 0."""
  peak = A0(0, 0, 0, 0)
  return A0(t, 0, 0, 0)/peak

In [7]:
# Real-space grid: nz points per axis on [-zb, zb]
z0, z1 = -zb, zb
zGrid = np.linspace(z0, z1, nz)
dz = (z1 - z0)/(nz - 1)

# Momentum grids for the zero-padded FFT box with nk points per axis
dk = 2*np.pi/(nk*dz)
kMax = dk*(nk - 1)/2
kGrid = dk*np.linspace(-(nk - 1)/2, (nk - 1)/2, nk)
kFFT = 2*np.pi*np.fft.fftfreq(nk, dz)
kFFTT = kFFT.reshape(-1, 1)
# Zero padding per side that fDer adds to each axis
nPad = (nk - nz)//2

# Flat real ODE state: real and imaginary parts of 4 spinor components on the nz^3 grid
psiOut = np.zeros(2*4*nz**3, dtype=np.float32)

# Coordinate and wavenumber meshes (axis 0 = x, axis 1 = y, axis 2 = z)
X, Y, Z = np.meshgrid(*(zGrid,)*3, indexing='ij')
K1, K2, K3 = np.meshgrid(*(kFFT,)*3, indexing='ij')

# Static potential on the grid, negligible values chopped to exactly zero
AzGrid = A0(0, X, Y, Z)
AzGrid = np.where(np.abs(AzGrid) < chop, 0.+0.j, AzGrid)

# Module-level placeholders; fDer rebinds kA1, kA2, xDer, yDer, zDer
xA = np.zeros((nz,)*3, dtype=np.complex64)
inho = xDer = yDer = zDer = xA
kA1 = np.zeros((nk,)*3, dtype=np.complex64)
kA2 = kA1

In [8]:
def fDer(phi):
  """Spectral first derivatives of a 3-D field along axes 0, 1, 2 (x, y, z).

  phi is zero-padded up to the nk^3 FFT box, transformed with ifftn,
  multiplied by -1j*K (K1/K2/K3 are the fftfreq wavenumber meshes of that box),
  transformed back with fftn and cropped to phi's original extent. Entries
  with magnitude below `chop` are set to zero.

  The results are stored in the module globals xDer, yDer, zDer (existing
  callers read them there) and are also returned as (xDer, yDer, zDer).

  NOTE(review): when this runs inside a JAX-traced function (e.g. the vector
  field given to diffrax.diffeqsolve), assigning globals leaks tracers out of
  the trace; prefer the return value in new code.
  """
  global kA1, kA2, xDer, yDer, zDer

  # Pad every axis to exactly nk points (also correct when nk - n is odd).
  padWidth=[(nPad, nk-n-nPad) for n in phi.shape]
  # Explicit crop window; the old slice nPad:-nPad was empty when nPad == 0.
  crop=tuple(slice(nPad, nPad+n) for n in phi.shape)

  kA1=np.fft.ifftn(np.pad(phi,padWidth))

  ders=[]
  for K in (K1, K2, K3):
    # ifftn/fftn pair: phi = sum_k c_k exp(-i k x), so d/dx -> -1j*k.
    kA2=np.fft.fftn(-1j*K*kA1)
    d=kA2[crop]
    ders.append(np.where(np.abs(d)<chop,0.+0.j,d))

  xDer,yDer,zDer=ders
  return xDer,yDer,zDer

In [16]:
def makeDer(P1,P2,P3,spin,sign):
  """Build the diffrax vector field dPsi(mt, y, args) for a plane-wave source.

  Parameters
  ----------
  P1, P2, P3 : momentum components of the source; P0 = sqrt(1 + |P|^2).
  spin : spin label forwarded to u4D / v4D.
  sign : +1 uses u4D(P1,P2,P3,spin), -1 uses v4D(P1,P2,P3,spin).

  Returns
  -------
  dPsi(mt, y, args), which interprets the solver time as t = -mt. The state
  y is flat with length 2*4*nz**3: real parts of the four spinor components
  (each reshaped to (nz, nz, nz)) followed by their imaginary parts; the
  return value uses the same layout. For each component b it evaluates

      1j*sum_a(a1[b,a]*dx psi_a + a2[b,a]*dy psi_a + a3[b,a]*dz psi_a
               - g0[b,a]*psi_a)
      - 1j*At(t)*(AzGrid*psi_b + uOut[b]*exp(-1j*sign*P0*t)*inho)

  with inho = A0(0,X,Y,Z)*exp(-1j*sign*(P1*X+P2*Y+P3*Z)), then sets entries
  with magnitude below `chop` to zero.

  Raises
  ------
  ValueError
      If sign is not 1 or -1.
  """
  if sign==1:
    uOut=u4D(P1,P2,P3,spin)
  elif sign==-1:
    uOut=v4D(P1,P2,P3,spin)
  else:
    raise ValueError(f"sign must be 1 or -1, got {sign!r}")

  P0=np.sqrt(1+P1*P1+P2*P2+P3*P3)

  # Chop the source by its own magnitude (testing the all-zero global
  # placeholder xA here masked out the entire source term).
  inho=A0(0,X,Y,Z)*np.exp(-1j*sign*(P1*X+P2*Y+P3*Z))
  inho=np.where(np.abs(inho)<chop,0.+0.j,inho)

  def dPsi(mt,y,args):
    t=-mt
    n3=nz**3
    # Time-dependent factors are the same for every component: compute once.
    envelope=At(t)
    phase=np.exp(-1j*sign*P0*t)

    dy=[np.zeros((nz,nz,nz),dtype=np.complex64) for _ in range(4)]

    for a in range(4):
      offRe=a*n3
      offIm=(4+a)*n3
      # Local variable, so this RHS does not rebind any module-level name.
      psiA=(y[offRe:offRe+n3]+1j*y[offIm:offIm+n3]).reshape((nz,nz,nz))
      # fDer publishes the spatial derivatives in the globals xDer/yDer/zDer.
      fDer(psiA)
      for b in range(4):
        dy[b]=dy[b]+1j*(a1[b,a]*xDer+a2[b,a]*yDer+a3[b,a]*zDer-g0[b,a]*psiA)
      # Potential and source terms are diagonal: they only act on component a.
      dy[a]=dy[a]-1j*envelope*(AzGrid*psiA+uOut[a]*phase*inho)

    # Chop every component once, after all contributions have been summed.
    dy=[np.where(np.abs(d)<chop,0.+0.j,d) for d in dy]

    return np.array([[np.real(d) for d in dy],[np.imag(d) for d in dy]]).ravel()

  return dPsi

# Computation

In [17]:
UdPsi = makeDer(p1, p2, p3, spin=1, sign=1)  # RHS for the u spinor, spin 1

In [18]:
test = UdPsi(mt=0, y=psiOut, args=0)  # smoke test: one RHS evaluation on the zero state

In [19]:
np.shape(test)

(64964808,)

In [105]:
2 * 4 * nz**3  # expected length of the flat state: (re, im) x 4 components x nz^3

64964808

In [21]:
# Vector field for the u spinor (sign=1) with spin 1
UdPsi = makeDer(p1, p2, p3, spin=1, sign=1)
term = ODETerm(UdPsi)

# Adaptive Dopri5 with a PID step-size controller
solver = Dopri5()
controller = PIDController(rtol=1e-5, atol=1e-10)

# NOTE(review): this call previously failed with RESOURCE_EXHAUSTED (~9 GB
# allocation). The state has 2*4*nz**3 entries and each RHS evaluation runs
# FFTs over nk**3 boxes for all four components; presumably reducing nz/nk
# (or using a larger device) is needed — confirm.
solutionU = diffeqsolve(
  term, solver,
  t0=tin, t1=tout, dt0=1e-3,
  y0=psiOut,
  saveat=SaveAt(t1=True),
  stepsize_controller=controller,
)

XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 9044783600 bytes.

In [20]:
p1  # first momentum component used for the u spinors above

-0.35

# Checking definitions

In [None]:
# Clifford-algebra check: g[i] g[j] + g[j] g[i] = 2 * eta[i,j] * id for all i, j.
# One 4x4 table of maximum deviations (expected all 0) instead of 16 printed arrays.
anticommErr=np.array([[np.max(np.abs(g[i]@g[j]+g[j]@g[i]-2*id*eta[i,j])) for j in range(4)]
                      for i in range(4)])
assert np.all(anticommErr<1e-6), anticommErr
anticommErr

In [None]:
# Orthonormality of u(p, s) and v(pp, s) for s = 1, 2 (assumes pp = -p, as set
# up with pp1 = -p1): the table of |<w_i, w_j>| should be the 4x4 identity.
# np.vdot conjugates its first argument (Hermitian inner product); np.dot does
# not, which gives a spurious nonzero u1.u2 overlap as soon as p2 != 0.
spinors=[u4D(p1,p2,p3,1),u4D(p1,p2,p3,2),v4D(pp1,pp2,pp3,1),v4D(pp1,pp2,pp3,2)]
gram=np.array([[np.abs(np.vdot(w1,w2)) for w2 in spinors] for w1 in spinors])
assert np.allclose(gram,np.eye(4),atol=1e-5), gram
gram

In [None]:
np.abs((p0*g0 + p1*g1 + p2*g2 + p3*g3 - id) @ u4D(p1, p2, p3, 2)).max()  # residual of (slashP - 1) u = 0

Array(5.9604645e-08, dtype=float32)

In [None]:
np.abs((pp0*g0 + pp1*g1 + pp2*g2 + pp3*g3 + id) @ v4D(pp1, pp2, pp3, 1)).max()  # residual of (slashP + 1) v = 0

Array(5.9604645e-08, dtype=float32)