# Gauge transformation on the GPU, including the z dependence

### Set parameters and environment variables

In [4]:
import numpy as np

# Conversion constant hbar * c [GeV * fm]
hbarc = 0.197326

# --- Simulation box ---
# L is divided by N to obtain the lattice spacing in simulate();
# DTS is the number of time steps per lattice spacing (DT = 1 / DTS).
# Alternative lattice size left disabled in the original: N = 128
L, N = 2, 16
tau_sim, DTS = 1, 4

# --- Glasma fields ---
su_group = "su3"
Qs, factor = 2, 0.8
ns = 50
# Quantities derived from Qs; the expressions are kept exactly as written so
# that the floating-point results do not change.
g2mu = Qs / factor
g = np.pi * np.sqrt(1 / np.log(Qs / 0.2))
mu = g2mu / g**2
ir = 0.1 * g**2 * mu
uv = 10.0

# TODO: Run more events
nevents = 1

In [5]:
import os

# Configure the backend ("cuda"), the floating-point precision ("double") and
# the gauge group through environment variables. They are set before the
# curraun imports below; presumably curraun reads them at import time, so the
# order of this cell matters -- TODO confirm.
os.environ["MY_NUMBA_TARGET"] = "cuda"
os.environ["PRECISION"] = "double"
os.environ['GAUGE_GROUP'] = su_group

# Import relevant modules
import sys
# Add the parent directory to the module search path (assumed to be where the
# curraun package lives -- TODO confirm).
sys.path.append('..')

# Glasma modules
import curraun.core as core
import curraun.mv as mv
import curraun.initial as initial
# Switch off the debug flag of the initial-conditions module.
initial.DEBUG = False

import curraun.su as su
from curraun.numba_target import use_cuda
# numba.cuda is only imported when the CUDA target is active.
if use_cuda:
    from numba import cuda

# NOTE(review): duplicate of the `import curraun.su as su` a few lines above;
# harmless, but one of the two can be dropped.
import curraun.su as su
import curraun.lc_gauge as lc

### We define the simulation routine

In [None]:
from tqdm import tqdm

# Simulation routine
def simulate():
    """Evolve one Glasma event and collect the gauge-transformed field.

    Reads the module-level parameters (``L``, ``N``, ``tau_sim``, ``DTS``,
    ``g``, ``mu``, ``ir``, ``uv``, ``ns``, ``hbarc``), builds a
    ``core.Simulation`` initialised from two ``mv.wilson`` configurations and
    advances it with ``core.evolve_leapfrog`` for ``maxt`` steps.

    Every ``DTS`` steps the ``lc.LCGaugeTransf`` object is advanced with
    ``evolve_lc``; for every such sample except the first one, a copy of its
    ``up_lc`` attribute is stored in ``uplus_lc``.

    The results are gathered in the local ``output`` dictionary under the
    keys ``"nplus"`` and ``"uplus_lc"``.
    """
    output = {}

    # Derived parameters
    a = L / N
    # Used below to rescale mu, ir and uv before they are passed to mv.wilson
    E0 = N / L * hbarc
    DT = 1.0 / DTS
    maxt = int(tau_sim / a * DTS)

    #TODO: only to compare with CPU, remove after
    mv.set_seed(24)

    # We create the simulation object
    s = core.Simulation(N, DT, g)

    # We initialize the Glasma fields
    va = mv.wilson(s, mu=mu / E0, m=ir / E0, uv=uv / E0, num_sheets=ns)
    vb = mv.wilson(s, mu=mu / E0, m=ir / E0, uv=uv / E0, num_sheets=ns)
    initial.init(s, va, vb)

    # We create the necessary objects for the gauge transformations.
    # The local must NOT be called `lc`: assigning to `lc` inside this function
    # would make it a local name and hide the imported `lc` module, so that
    # `lc.LCGaugeTransf` raises UnboundLocalError.
    nplus = maxt // DTS
    lc_transf = lc.LCGaugeTransf(s, nplus)

    # We create the object where we will store the transformed field:
    # one entry per stored sample (the first sample, xplus == 0, is skipped).
    uplus_lc = su.GROUP_TYPE(np.zeros((nplus - 1, N * N, su.GROUP_ELEMENTS)))

    with tqdm(total=maxt) as pbar:
        for t in range(maxt):
            # Evolve Glasma fields
            core.evolve_leapfrog(s)

            if t % DTS == 0:
                xplus = t // DTS
                lc_transf.evolve_lc(xplus, a)

                # NOTE(review): index xplus - 1 stays within nplus - 1 entries
                # only if maxt is a multiple of DTS -- confirm for other
                # parameter choices.
                if xplus != 0:
                    uplus_lc[xplus - 1] = lc_transf.up_lc.copy()

            # Advance the progress bar; without this it stays at 0 / maxt.
            pbar.update(1)

    if use_cuda:
        # Clear the pending deallocations of the current CUDA context
        cuda.current_context().deallocations.clear()

    # We write the transformed fields in a dictionary
    output["nplus"] = nplus
    output["uplus_lc"] = uplus_lc