In [1]:
import numpy as np
import pickle
import os, sys
import pathlib
import jax
import pandas as pd
import optax
import jax.numpy as jnp
import sklearn
import matplotlib.pyplot as plt

from jax.example_libraries import stax
from jax import grad, jit, vmap
import time

# Root of the local DiffSim checkout; Producer uses it to locate the trained
# simulator parameters.
# NOTE(review): absolute, machine-specific path — make it configurable before
# sharing this notebook.
dir_dir = '/Users/mxd6118/Desktop/DiffSim/'
#src_dir = os.path.dirname(dir_dir) + "/src/"
# Put the checkout at the front of sys.path; presumably this is what lets the
# `Plots` and `src.*` imports that follow resolve — confirm.
sys.path.insert(0,'/Users/mxd6118/Desktop/DiffSim')
from Plots import *

from src.simulator.NEW_Simulator_normal import simulate_waveforms, init_params
from jax import random

In [2]:
def get_data(data, name_data):
    """Pickle ``data`` into the file ``<name_data>.pickle``.

    Despite its name this function only writes; nothing is read back.

    Args:
        data: Any picklable object.
        name_data: Path (without extension) of the output file; ``.pickle``
            is appended. An existing file is overwritten.

    Returns:
        None.
    """
    destination = f'{name_data}.pickle'
    with open(destination, 'wb') as sink:
        # The context manager closes the handle on exit.
        pickle.dump(data, sink)

class Producer():
    """Build one batch that pairs real detector data with simulated waveforms.

    On construction the class loads trained simulator parameters from a
    pickle, pulls the first batch from the krypton dataloader, runs
    ``simulate_waveforms`` on that batch's ``energy_deposits`` and stores the
    merged result in ``self.data_set``.

    Attributes:
        key: JAX PRNG key; split every time ``arrays`` is called.
        dataloader: The krypton dataloader built by ``build_dataloader``.
        data_set: The real batch merged with the keys ``"PMT_FAKE"``,
            ``"SIPM_FAKE"`` and ``"Label_tempo"``.
    """

    def __init__(self, number_of_events=None, seed=None):
        """Load parameters, fetch a real batch and simulate its counterpart.

        Args:
            number_of_events: Batch size requested from the dataloader. When
                omitted the user is prompted on stdin (original behaviour).
            seed: Integer seed for the JAX PRNG key. When omitted the current
                Unix time is used (original behaviour), which makes runs
                non-reproducible.
        """
        if seed is None:
            seed = int(time.time())
        self.key = random.PRNGKey(seed)

        if number_of_events is None:
            number_of_events = int(input('How many events to be produced?'))
        number_of_events = int(number_of_events)

        # Build the path from the shared checkout root instead of string
        # concatenation (which produced a doubled '/').
        params_path = os.path.join(
            dir_dir, "bin", "output", "8530", "krypton",
            "test_noise_constant_uniform", "trained_params.pickle")
        params = self.load_state(params_path)

        self.dataloader = self.build_dataloader(number_of_events)

        # Only the first batch produced by the loader is used.
        batch_real = next(self.dataloader.iterate())
        batch_real['Label'] = np.ones(len(batch_real['energy_deposits']))

        # Simulate from the real events' deposits; build_random_batch() is an
        # alternative source of deposits that is currently unused.
        energy_depo = batch_real['energy_deposits']

        start = time.time()
        produced_pmt, produced_sipm = self.arrays(energy_depo, params)
        time_taken = time.time() - start

        batch_fake = {"energy_deposits": energy_depo,
                      "PMT_FAKE": np.array(produced_pmt),
                      "SIPM_FAKE": np.array(produced_sipm),
                      "Label_tempo": np.zeros(len(energy_depo))}

        # On the shared "energy_deposits" key the right-hand side wins; both
        # sides hold the same object so nothing is lost.
        self.data_set = batch_real | batch_fake

        #get_data(data_set,f'batch_produced_{number_of_events}')

        print(f'All done, time taken {time_taken} sec')

    def build_dataloader(self, number_of_events):
        """Create the krypton dataloader for run 8530.

        Args:
            number_of_events: Passed to the loader as ``batch_size``.

        Returns:
            The ``krypton`` dataloader instance.
        """
        # Project-local import kept inside the method, as in the original.
        from src.utils.dataloaders.krypton_DATES_CUSTOM_DROPOUT import krypton

        # Load the sipm database:
        sipm_db = pd.read_pickle(os.path.join(dir_dir, "database", "new_sipm.pkl"))

        # NOTE(review): the meaning of `drop` and `z_slice` is defined by the
        # krypton loader, which is not visible here — confirm before changing.
        dl = krypton(
            batch_size  = number_of_events,
            db          = sipm_db,
            path        = os.path.join(dir_dir, "kdst"),
            run         = 8530,
            shuffle     = True,
            drop        = 0,
            z_slice     = 0,
            )

        return dl

    def arrays(self, monitor_data, params):
        """Run ``monitor_data`` through the simulator with a fresh PRNG subkey.

        Args:
            monitor_data: Input forwarded to ``simulate_waveforms``.
            params: Simulator parameters forwarded to ``simulate_waveforms``.

        Returns:
            Tuple ``(simulated_pmts, simulated_sipms)``.
        """
        # Advance the stored key so successive calls use different subkeys.
        self.key, subkey = jax.random.split(self.key)
        simulated_pmts, simulated_sipms = simulate_waveforms(monitor_data, params, subkey)

        return simulated_pmts, simulated_sipms

    def load_state(self, file):
        """Unpickle and return the object stored at ``file``.

        SECURITY: ``pickle.load`` executes arbitrary code embedded in the
        file. Only load parameter files from a trusted source.
        """
        with open(file, "rb") as f:
            params = pickle.load(f)
        return params

    def build_random_batch(self, number_of_events):
        """Draw random deposits, one per event.

        Each event is a ``(2, 4)`` block: row 0 holds three uniform draws in
        ``[-150, 150)``, ``[-150, 150)`` and ``[20, 500)`` followed by the
        constant ``0.0415575``; row 1 is all zeros.
        (Presumably x, y, z and an energy — TODO confirm units with the
        simulator.)

        Args:
            number_of_events: Number of events to generate; may be 0.

        Returns:
            ``numpy.ndarray`` of shape ``(number_of_events, 2, 4)`` and dtype
            float64. Zero events yield an empty array of shape ``(0, 2, 4)``
            rather than a shapeless empty array.
        """
        lows = np.array([-150.0, -150.0, 20.0])
        highs = np.array([150.0, 150.0, 500.0])

        batch = np.zeros((number_of_events, 2, 4))
        # One vectorised draw replaces the per-event Python loop; the columns
        # are filled in the same order as before.
        batch[:, 0, :3] = np.random.uniform(low=lows, high=highs,
                                            size=(number_of_events, 3))
        batch[:, 0, 3] = 0.0415575

        return batch

In [3]:
from jax.example_libraries import optimizers as jax_opt

class GAN:
    """Train a discriminator that separates real from simulated SiPM slices.

    Only a discriminator network is created and optimised by this class; the
    simulated samples come from ``Producer``.

    Attributes:
        batch: Dict with ``'Train'`` (stacked slices) and ``'Labels'``
            (one-hot rows), as returned by ``Chanteclair``.
        key, subkey: JAX PRNG keys; ``subkey`` initialises the network.
        dis_apply: The stax apply function of the discriminator.
        trainer: The ``GAN_trainer`` that performs the optimisation steps.
    """

    def __init__(self, seed=None):
        """Produce the data, build the discriminator and its trainer.

        Args:
            seed: Integer seed for the network-initialisation key. When
                omitted the current Unix time is used (original behaviour).
                ``Producer()`` is still constructed without arguments, so it
                prompts for the number of events.
        """
        prod = Producer()

        data = prod.data_set

        self.batch = self.Chanteclair(data)

        if seed is None:
            seed = int(time.time())
        self.key = random.PRNGKey(seed)

        self.key, self.subkey = random.split(self.key)

        parameters, self.dis_apply = self.init_params(self.subkey)

        self.trainer = self.build_trainer(self.batch, self.dis_apply, parameters)

    def Chanteclair(self, data):
        """Turn the produced data set into a shuffled, labelled training batch.

        For every event, each index ``z`` along the last axis at which
        ``data['S2Si'][n]`` has a non-zero entry contributes one real slice
        ``data['S2Si'][n, :, :, z]`` and the matching simulated slice
        ``data['SIPM_FAKE'][n, :, :, z]``.

        Args:
            data: Mapping with ``'energy_deposits'``, ``'S2Si'`` and
                ``'SIPM_FAKE'``.

        Returns:
            Dict with ``'Train'`` (real slices followed by simulated slices,
            then shuffled) and ``'Labels'`` (``[1, 0]`` for real, ``[0, 1]``
            for simulated, shuffled consistently with ``'Train'``).

        Raises:
            ValueError: If no event contains a non-zero ``S2Si`` slice.
        """
        real_slices = []
        fake_slices = []

        for n in range(len(data['energy_deposits'])):
            for z in np.unique(np.where(data['S2Si'][n] != 0)[2]):
                real_slices.append(data['S2Si'][n, :, :, z])
                fake_slices.append(data['SIPM_FAKE'][n, :, :, z])

        n_pairs = len(real_slices)
        if n_pairs == 0:
            # Fail here with a clear message instead of a length-mismatch
            # error from the shuffle further down.
            raise ValueError("No non-zero 'S2Si' slices found; cannot build a training batch.")

        samples = np.concatenate((np.stack(real_slices), np.stack(fake_slices)), axis=0)

        # First n_pairs rows are real ([1, 0]); the remaining rows are simulated ([0, 1]).
        one_hot = np.concatenate((np.tile((1, 0), (n_pairs, 1)),
                                  np.tile((0, 1), (n_pairs, 1))), axis=0)

        train, labels = sklearn.utils.shuffle(samples, one_hot)

        batch = {'Train': train, 'Labels': labels}

        return batch

    def init_params(self, subkey):
        """Create the discriminator network and its initial parameters.

        The network flattens its input and applies Dense(128) -> Sigmoid ->
        Dense(16) -> Sigmoid -> Dense(2) -> Softmax. It is initialised for
        inputs of shape ``(1, 47, 47)``.
        (Presumably 47x47 is the SiPM plane size — TODO confirm.)

        Args:
            subkey: JAX PRNG key used for weight initialisation.

        Returns:
            Tuple ``(parameters, dis_apply)`` where ``parameters`` is
            ``{'D_parameters': <stax params>}``.
        """
        dis_init, dis_apply = stax.serial(
            stax.Flatten,
            stax.Dense(128), stax.Sigmoid,
            stax.Dense(16), stax.Sigmoid,
            stax.Dense(2), stax.Softmax
        )

        # The reported output shape is not needed.
        _, dis_network_params = dis_init(subkey, (1, 47, 47))

        parameters = {
            'D_parameters': dis_network_params
        }

        return parameters, dis_apply

    def build_trainer(self, batch, fn, params):
        """Wrap the batch, apply function and parameters in a ``GAN_trainer``."""
        # Shouldn't reach this portion unless training.
        trainer = GAN_trainer(batch, fn, params)

        return trainer

    def train(self, num_steps=101, print_every=1):
        """Run the optimisation loop and pickle the final parameters.

        Args:
            num_steps: Number of training iterations; steps are numbered from
                0. The default of 101 reproduces the original
                ``while c <= 100`` loop.
            print_every: Print a progress line every this many steps; a falsy
                value disables printing.

        Side effects:
            Writes the final parameters to ``D_params.pickle`` via
            ``get_data``.
        """
        for step in range(num_steps):
            start = time.time()

            train_metrics, _, acc = self.trainer.train_iteration(self.batch, step)

            metrics = {}
            metrics.update(train_metrics)
            metrics['time'] = time.time() - start
            metrics['accuracy'] = acc

            if print_every and step % print_every == 0:
                print(f"step = {step}, loss = {metrics['loss/loss']:.3f}, acc = {metrics['accuracy']:.3f}, time = {metrics['time']:.3f}",flush = True)

        # Read the state from the trainer so this also works when the loop
        # body never ran.
        get_data(self.trainer.get_params(self.trainer.opt_state), 'D_params')
        

In [4]:
def binary_cross_entropy(y_true, y_pred):
    """Mean element-wise binary cross-entropy between targets and predictions.

    Predictions are clipped to ``[epsilon, 1 - epsilon]`` before the
    logarithms are taken so that neither ``log(0)`` nor ``0 * log(0)`` can
    occur.

    Args:
        y_true: Target values in ``[0, 1]``, broadcastable against ``y_pred``.
        y_pred: Predicted probabilities.

    Returns:
        Scalar JAX array holding the mean loss over all elements.
    """
    y_pred = jnp.asarray(y_pred)
    if not jnp.issubdtype(y_pred.dtype, jnp.floating):
        # finfo is only defined for floating dtypes.
        y_pred = y_pred.astype(jnp.result_type(float))

    # The clip margin must be representable in the prediction dtype. In
    # float32 (the JAX default) 1 - 1e-8 rounds to exactly 1.0, which left the
    # upper bound ineffective and produced NaN for saturated predictions.
    # Keep 1e-8 wherever it is representable (float64).
    epsilon = max(1e-8, float(jnp.finfo(y_pred.dtype).eps))
    y_pred = jnp.clip(y_pred, epsilon, 1.0 - epsilon)
    loss = -(y_true * jnp.log(y_pred) + (1 - y_true) * jnp.log(1 - y_pred))
    return jnp.mean(loss)
    

class GAN_trainer():
    """Optimise the discriminator parameters with Adamax.

    Attributes:
        key: JAX PRNG key, split once per training iteration.
        dis_apply: Apply function of the discriminator network.
        gradient_fn: Jitted function returning ``(loss, gradients)`` with
            respect to the parameter dict.
        opt_state: Current optimiser state.
        opt_update, get_params: Functions returned by ``jax_opt.adamax``.
    """

    def __init__(self, batch, fn, parameters, seed=None):
        """Set up the loss, its gradient and the optimiser.

        Args:
            batch: Unused at construction time; kept for interface
                compatibility.
            fn: Discriminator apply function ``fn(params, inputs)``.
            parameters: Dict with key ``'D_parameters'`` holding the initial
                network parameters.
            seed: Integer seed for the PRNG key. When omitted the current
                Unix time is used (original behaviour).
        """
        if seed is None:
            seed = int(time.time())
        self.key = jax.random.PRNGKey(seed)
        self.key, _ = jax.random.split(self.key)
        self.dis_apply = fn

        # No separate @jit here: the function is compiled once through
        # jit(value_and_grad(...)) below.
        def forward_pass(batch, parameters, key):
            # `key` is accepted for signature compatibility but not used.
            predictions = self.dis_apply(parameters['D_parameters'], batch['Train'])

            loss = binary_cross_entropy(batch['Labels'], predictions)

            return loss

        # Differentiate with respect to the parameter dict (argument 1).
        self.gradient_fn = jit(jax.value_and_grad(forward_pass, argnums=1))

        opt_init, opt_update, get_params = jax_opt.adamax(2e-3)

        self.opt_state = opt_init(parameters)

        self.opt_update = opt_update
        self.get_params = get_params

    def parameters(self):
        """Return the current discriminator parameters (``'D_parameters'``)."""
        p = self.get_params(self.opt_state)

        parameters = p['D_parameters']

        return parameters

    def train_iteration(self, batch, c):
        """Perform one optimiser step on ``batch``.

        Args:
            batch: Dict with ``'Train'`` inputs and ``'Labels'`` targets.
            c: Step index forwarded to the optimiser update.

        Returns:
            Tuple ``(metrics, opt_state, accuracy)``. ``metrics`` contains
            ``'loss/loss'`` plus the parameter dict used for this step (i.e.
            the parameters *before* the update); ``accuracy`` is evaluated
            with those same parameters.
        """
        metrics = {}

        # Keep the parameters in a local variable: assigning them to
        # `self.parameters` would replace the `parameters()` method with a
        # dict after the first step.
        current_params = self.get_params(self.opt_state)

        self.key, subkey = jax.random.split(self.key)

        loss, gradients = self.gradient_fn(batch, current_params, subkey)

        self.opt_state = self.opt_update(c, gradients, self.opt_state)

        metrics['loss/loss'] = loss

        metrics.update(current_params)

        accuracy = self.acc(current_params, self.dis_apply, batch)

        return metrics, self.opt_state, accuracy

    def acc(self, parameters, fn, batch):
        """Fraction of samples whose predicted class equals the labelled class.

        Args:
            parameters: Dict with key ``'D_parameters'``.
            fn: Apply function ``fn(params, inputs)`` returning per-class scores.
            batch: Dict with ``'Train'`` inputs and one-hot ``'Labels'``.

        Returns:
            Python float in ``[0, 1]``.
        """
        predictions = fn(parameters['D_parameters'], batch['Train'])

        hits = np.argmax(predictions, axis=1) == np.argmax(batch['Labels'], axis=1)

        accuracy = int(np.sum(hits)) / len(batch['Labels'])

        return accuracy

In [6]:
 # Build the data set (Producer prompts for the number of events), train the
 # discriminator and pickle the final parameters to D_params.pickle.
 GAN().train()

How many events to be produced?20
All done, time taken 2.5022847652435303 sec
step = 0, loss = 0.713, acc = 0.500, time = 0.222
step = 1, loss = 0.671, acc = 0.604, time = 0.008
step = 2, loss = 0.647, acc = 0.766, time = 0.007
step = 3, loss = 0.628, acc = 0.746, time = 0.008
step = 4, loss = 0.609, acc = 0.768, time = 0.008
step = 5, loss = 0.591, acc = 0.802, time = 0.009
step = 6, loss = 0.573, acc = 0.812, time = 0.009
step = 7, loss = 0.556, acc = 0.826, time = 0.009
step = 8, loss = 0.540, acc = 0.841, time = 0.009
step = 9, loss = 0.525, acc = 0.845, time = 0.008
step = 10, loss = 0.510, acc = 0.853, time = 0.009
step = 11, loss = 0.496, acc = 0.857, time = 0.008
step = 12, loss = 0.481, acc = 0.865, time = 0.010
step = 13, loss = 0.466, acc = 0.867, time = 0.010
step = 14, loss = 0.452, acc = 0.867, time = 0.008
step = 15, loss = 0.439, acc = 0.867, time = 0.009
step = 16, loss = 0.426, acc = 0.867, time = 0.009
step = 17, loss = 0.413, acc = 0.870, time = 0.009
step = 18, los