In [8]:
import numpy as np
from numpy import expand_dims as ed
from poenta.jitted import C_mu_Sigma2,dC_dmu_dSigma2
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [15]:

def R_matrix2(gamma1, gamma2, phi1, phi2, theta1, varphi1, zeta1, zeta2, theta, varphi, old_state):
    batch, cutoff, _ = old_state.shape
    dtype = old_state.dtype
    
    C, mu ,Sigma = C_mu_Sigma2(gamma1, gamma2, phi1, phi2, theta1, varphi1, zeta1, zeta2, theta, varphi)

    sqrt = np.sqrt(np.arange(cutoff, dtype = dtype))
    sqrtT = sqrt.reshape(-1, 1)

    R = np.zeros((batch, cutoff, cutoff, cutoff+1, cutoff+1), dtype = dtype)
    G_00pq = np.zeros((cutoff, cutoff+1), dtype = dtype)
    
    
    #G_mn00
    G_00pq[0,0] = C
    for q in range(1, cutoff):
        G_00pq[0,q] = (mu[3]*G_00pq[0,q-1] - Sigma[3,3]*sqrt[q-1]*G_00pq[0,q-2])/sqrt[q]

    
    for p in range(1,cutoff):
        for q in range(0,cutoff):
            G_00pq[p,q] = (mu[2]*G_00pq[p-1,q] - Sigma[2,2]*sqrt[p-1]*G_00pq[p-2,q] - Sigma[2,3]*sqrt[q]*G_00pq[p-1,q-1])/sqrt[p]
              
#     for p in range(1,cutoff):
#         G_00pq[p,:-1] = (mu[2]*G_00pq[p-1,:-1] - Sigma[2,2]*sqrt[p-1]*G_00pq[p-2,:-1] - Sigma[2,3]*G_00pq[p-1,1:])/sqrt[p]


    # R_00^jk = a_dagger^j \G_00pq> b^k  * |old_state>
    G_00pq2 = G_00pq[:,:-1]
    for j in range(cutoff):
        G_00pq3 = G_00pq2
        for k in range(cutoff):
            R[0,0,j,k] = np.sum(G_00pq3*old_state[:,j:,k:]) #!!here only works for batch = 1, if not the np.sum needs to be changed axis=-1 for several times = batch.
            G_00pq3 = G_00pq3[:,:-1]*sqrt[k+1:]
        G_00pq2 = sqrtT[j+1:]*G_00pq2[:-1,:]

    #R_0n^jk
    
#     for n in range(1,cutoff):
#         for k in range(0,cutoff):
#             for j in range(0,cutoff):
#                 R[:,0,n,j,k] = mu[1]/sqrt[n]*R[:,0,n-1,j,k] - Sigma[1,1]/sqrt[n]*sqrt[n-1]*R[:,0,n-2,j,k] - Sigma[1,2]/sqrt[n]*R[:,0,n-1,j+1,k] - Sigma[1,3]/sqrt[n]*R[:,0,n-1,j,k+1]
   
    for n in range(1,cutoff):
        R[:,0,n,:-1,:-1] = mu[1]/sqrt[n]*R[:,0,n-1,:-1,:-1] - Sigma[1,1]/sqrt[n]*sqrt[n-1]*R[:,0,n-2,:-1,:-1] - Sigma[1,2]/sqrt[n]*R[:,0,n-1,1:,:-1] - Sigma[1,3]/sqrt[n]*R[:,0,n-1,:-1,1:]


    #R_mn^jk
    for m in range(1,cutoff):
        for n in range(0,cutoff):
            for j in range(0,cutoff-m):
                for k in range(0,cutoff-m-j):
                    R[:,m,n,j,k] = mu[0]/sqrt[m]*R[:,m-1,n,j,k] - Sigma[0,0]/sqrt[m]*sqrt[m-1]*R[:,m-2,n,j,k] - Sigma[0,1]*sqrt[n]/sqrt[m]*R[:,m-1,n-1,j,k] - Sigma[0,2]/sqrt[m]*R[:,m-1,n,j+1,k] - Sigma[0,3]/sqrt[m]*R[:,m-1,n,j,k+1]   
    return R

In [16]:
old_state = np.ones((1,2,2),dtype=np.complex128)

In [17]:
gamma1 = 1+1j
gamma2= 1+1j
phi1 =0.1
phi2 =0.2
theta1 = 0.4
varphi1 = 0.4
zeta1= 1+1j
zeta2= 1+1j
theta=0.6
varphi=0.2

In [18]:
R1 = R_matrix2(gamma1, gamma2, phi1, phi2, theta1, varphi1, zeta1, zeta2, theta, varphi, old_state)

In [27]:
from numba import njit
@njit()
def dPsi2(gamma1, gamma2, phi1, phi2, theta1, varphi1, zeta1, zeta2, theta, varphi, state_in, G00, R):
    """
    Computes the gradient of the new state with respect to
    gamma, gamma*, phi, z, z* but not with respect to the old state

    Arguments:
        gamma1 (complex): displacement parameter1
        gamma2 (complex): displacement parameter2
        phi1 (float): phase rotation parameter1
        phi2 (float): phase rotation parameter2
        
        theta1(float): transmissivity angle of the beamsplitter1
        varphi1(float): reflection phase of the beamsplitter1
        
        zeta1 (complex): squeezing parameter1
        zeta2 (complex): squeezing parameter2
        
        theta(float): transmissivity angle of the beamsplitter
        varphi(float): reflection phase of the beamsplitter
        
        state_in: (complex array[bath,D,D]): old state
        G00 (complex array[D,D]): G[0,0,:,:] of the G matrix
        R (complex array[bath, D,D,D,D]): complete R matrix R[:,:,:,:] (!not really complete....)

    Returns:
        (complex array[batch, D, D, 14]): gradient of the new state with respect to
                                    gamma1, gamma1*, gamma2, gamma2*, phi1, phi2, theta1, varphi1, zeta1, zeta1*, zeta2, zeta2*, theta, varphi
    """
    batch, cutoff, _ = state_in.shape
    dtype = state_in.dtype
    
    C, mu, Sigma = C_mu_Sigma2(gamma1, gamma2, phi1, phi2, theta1, varphi1, zeta1, zeta2, theta, varphi)
    dC, dmu, dSigma = dC_dmu_dSigma2(gamma1, gamma2, phi1, phi2, theta1, varphi1, zeta1, zeta2, theta, varphi)
    
    sqrt = np.sqrt(np.arange(cutoff, dtype=dtype))
    sqrtT = sqrt.reshape(-1, 1)
    
    dR = np.zeros((batch, cutoff, cutoff, cutoff+1 , cutoff+1, 14),dtype = dtype)
    dG00 = np.zeros((cutoff, cutoff, 14),dtype = dtype)
    
    dG00[0,0] = dC
    for q in range(1, cutoff):
        dG00[0,q] = (dmu[3]*G00[0,q-1]+mu[3]*dG00[0,q-1] - dSigma[3,3]*sqrt[q-1]*G00[0,q-2]- Sigma[3,3]*sqrt[q-1]*dG00[0,q-2])/sqrt[q]


    for p in range(1,cutoff):
        for q in range(0,cutoff):
            dG00[p,q] = (dmu[2]*G00[p-1,q]+ mu[2]*dG00[p-1,q] - dSigma[2,2]*sqrt[p-1]*G00[p-2,q]- Sigma[2,2]*sqrt[p-1]*dG00[p-2,q] - dSigma[2,3]*sqrt[q]*G00[p-1,q-1]- Sigma[2,3]*sqrt[q]*dG00[p-1,q-1])/sqrt[p]
                    
    dG002 = dG00
    for j in range(cutoff):
        dG003 = dG002
        for k in range(cutoff):
        #dG003[D,D,14]*state_in[batch,D,D] - > we want [batch,14]
            test = ed(dG003,0)*ed(state_in[:,j:,k:],-1)
            dR[:,0,0,j,k] = test.sum(axis=1).sum(axis=1)
            dG003 = dG003[:,:-1]*ed(sqrt[k+1:],1)
        dG002 = ed(sqrtT[j+1:],1)*dG002[:-1,:]

    for n in range(1,cutoff):
        for k in range(0,cutoff):
            for j in range(0,cutoff):
            #dR[batch,D,D,D,D,14] R[batch,D,D,D,D] dmu[4,14] dSigma[4,4,14]
            # dR[:,0,n,j,k] = [batch,14]
#                 dR[:,0,n,j,k] = (dmu[1]*R[:,0,n-1,j,k] + ed(mu[1],0)*dR[:,0,n-1,j,k] - dSigma[1,1]*sqrt[n-1]*R[:,0,n-2,j,k] - Sigma[1,1]*sqrt[n-1]*dR[:,0,n-2,j,k] - dSigma[1,2]*R[:,0,n-1,j+1,k] - Sigma[1,2]*dR[:,0,n-1,j+1,k] - dSigma[1,3]*R[:,0,n-1,j,k+1] - Sigma[1,3]*dR[:,0,n-1,j,k+1]
#                                 )/sqrt[n]
                dR[:,0,n,j,k] = (ed(dmu[1],0)*R[:,0,n-1,j,k]+ mu[1]*dR[:,0,n-1,j,k]- ed(dSigma[1,1],0)*sqrt[n-1]*R[:,0,n-2,j,k]- Sigma[1,1]*sqrt[n-1]*dR[:,0,n-2,j,k] - ed(dSigma[1,2],0)*R[:,0,n-1,j+1,k] - Sigma[1,2]*dR[:,0,n-1,j+1,k] - ed(dSigma[1,3],0)*R[:,0,n-1,j,k+1] - Sigma[1,3]*dR[:,0,n-1,j,k+1])/sqrt[n]


    for m in range(1,cutoff):
        for n in range(0,cutoff):
            for j in range(0,cutoff-m):
                for k in range(0,cutoff-m-j):
#                     dR[:,m,n,j,k] = (dmu[0]*R[m-1,n,j,k] + mu[0]*dR[:,m-1,n,j,k] - dSigma[0,0]*sqrt[m-1]*R[m-2,n,j,k] - Sigma[0,0]*sqrt[m-1]*dR[:,m-2,n,j,k] - dSigma[0,1]*sqrt[n]*R[m-1,n-1,j,k] - Sigma[0,1]*sqrt[n]*dR[:,m-1,n-1,j,k] - dSigma[0,2]*R[m-1,n,j+1,k] - Sigma[0,2]*dR[:,m-1,n,j+1,k] - dSigma[0,3]*R[m-1,n,j,k+1] - Sigma[0,3]*dR[:,m-1,n,j,k+1])/sqrt[m]
                    dR[:,m,n,j,k] = (ed(dmu[0],0)*R[:,m-1,n,j,k] + mu[0]*dR[:,m-1,n,j,k]  - ed(dSigma[0,0],0)*sqrt[m-1]*R[:,m-2,n,j,k]  - Sigma[0,0]*sqrt[m-1]*dR[:,m-2,n,j,k]- ed(dSigma[0,1],0)*sqrt[n]*R[:,m-1,n-1,j,k] - Sigma[0,1]*sqrt[n]*dR[:,m-1,n-1,j,k] - ed(dSigma[0,2],0)*R[:,m-1,n,j+1,k] - Sigma[0,2]*dR[:,m-1,n,j+1,k]  - ed(dSigma[0,3],0)*R[:,m-1,n,j,k+1]  - Sigma[0,3]*dR[:,m-1,n,j,k+1])/sqrt[m]
    return np.transpose(dR[:,:,:,0,0,:],(3,0,1,2))

In [28]:
gamma1 = 1+1j
gamma2= 1+1j
phi1 =0.1
phi2 =0.2
theta1 = 0.4
varphi1 = 0.4
zeta1= 1+1j
zeta2= 1+1j
theta=0.6
varphi=0.2
state_in = np.ones((1,2,2),dtype=np.complex128)
state_in[0,0,0] = 5
state_in[0,0,1] = 6
state_in[0,1,0] = 7
state_in[0,1,1] = 8
G00 = np.array([[1j,2,],[3,4]],dtype=np.complex128)
R=R1

In [29]:
dpsi2 = dPsi2(gamma1, gamma2, phi1, phi2, theta1, varphi1, zeta1, zeta2, theta, varphi, state_in, G00, R)

In [30]:
dpsi2

array([[[[-1.93329724e-02-2.34987092e-02j,
          -4.73216389e+00-2.49363226e+00j,
          -1.93329724e-02-2.34987092e-02j,
          -3.44841597e+00-5.06139770e+00j,
          -2.28136331e+00-7.22996189e+00j,
           1.57015005e+00-8.51397941e+00j,
           7.85857585e-01-4.16802658e+00j,
           1.39401867e-01-5.81273501e+00j,
           4.10993078e+00-4.97853250e+00j,
           4.48872835e+00+1.85123967e+00j,
           5.20785969e-01-5.78364169e-01j,
           1.04070991e-01+2.82042761e+00j,
          -3.54222102e-01-1.01204305e+00j,
          -2.53503348e+00+2.08247592e+00j],
         [-1.57263891e-02-8.39983065e-02j,
          -4.62104594e+00-1.33906310e+01j,
          -1.32088724e-02-7.77028725e-02j,
           3.90829095e+00-1.78436972e+01j,
           8.66986041e+00-1.80799489e+01j,
           2.16406289e+01-1.40072913e+01j,
           1.09553393e+01-3.98314221e+00j,
           1.22104591e+01-1.17154408e+01j,
           1.75872027e+01-1.86332021e+00j,
          

In [31]:
dpsi2.shape

(1, 2, 2, 14)

In [32]:
np.transpose(dpsi2,(3,0,1,2))

array([[[[-1.93329724e-02-2.34987092e-02j,
          -1.57263891e-02-8.39983065e-02j],
         [-1.32028427e-02-5.53354356e-02j,
           3.15446738e-02-1.69276179e-01j]]],


       [[[-4.73216389e+00-2.49363226e+00j,
          -4.62104594e+00-1.33906310e+01j],
         [-6.39288819e+00-1.05079180e+01j,
           5.19052562e+00-3.18448972e+01j]]],


       [[[-1.93329724e-02-2.34987092e-02j,
          -1.32088724e-02-7.77028725e-02j],
         [-1.57203594e-02-6.16308696e-02j,
           3.37302721e-02-1.73934895e-01j]]],


       [[[-3.44841597e+00-5.06139770e+00j,
           3.90829095e+00-1.78436972e+01j],
         [-1.96245341e-01-1.49610552e+01j,
           3.08649311e+01-3.17750861e+01j]]],


       [[[-2.28136331e+00-7.22996189e+00j,
           8.66986041e+00-1.80799489e+01j],
         [ 4.03070280e+00-1.69481442e+01j,
           3.68701026e+01-2.68998279e+01j]]],


       [[[ 1.57015005e+00-8.51397941e+00j,
           2.16406289e+01-1.40072913e+01j],
         [ 1.46874586e+