In [None]:
import os
import numpy as np
import random
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
import time
import matplotlib.pyplot as plt

import nflows
from nflows.flows.base import Flow
from nflows.transforms.standard import AffineTransform
from nflows.transforms.permutations import ReversePermutation
from nflows.transforms.autoregressive import MaskedAffineAutoregressiveTransform
from nflows.transforms.base import CompositeTransform

## Define the model class

In [None]:
class PC_MAF(nn.Module):
    def __init__(self, 
                 dim_condition,
                 dim_input,
                 num_coupling_layers=1,
                 hidden_size=128,
                 device='cpu',
                 weight_particles=False,
                 num_blocks_mat = 2,     
                 activation = 'relu',
                 random_mask = False):
        
        '''
        Conditional masked autoregressive flow (MAF) model from
        https://papers.nips.cc/paper/2017/hash/6c1da886822c67822bcf3679d04369fa-Abstract.html

        The flow is a stack of ``num_coupling_layers`` pairs
        (ReversePermutation, MaskedAffineAutoregressiveTransform) on top of a
        standard-normal base distribution, conditioned on a context vector.

        Args:
            dim_condition (int): dimensionality of the condition (context) vector
            dim_input (int): dimensionality of the input vector
            num_coupling_layers (int): number of permutation + MADE-affine layer pairs
            hidden_size (int): number of hidden units per hidden layer in the MADE subnetworks
            device (str): "cpu" or "cuda"; the flow is moved there on construction
            weight_particles (bool): stored on the instance; not used by this class itself
            num_blocks_mat (int): ``num_blocks`` passed to each MaskedAffineAutoregressiveTransform
            activation (str): one of 'relu', 'sigmoid', 'tanh', 'elu', 'silu',
                'leaky_relu', 'gelu'
            random_mask (bool): ``random_mask`` passed to each MaskedAffineAutoregressiveTransform

        Raises:
            ValueError: if ``activation`` is not one of the supported names.
        '''

        super().__init__()
        self.device = device
        self.num_coupling_layers = num_coupling_layers
        self.hidden_size = hidden_size

        self.dim_input = dim_input
        self.dim_condition = dim_condition
        
        self.num_blocks_mat = num_blocks_mat
        self.random_mask = random_mask
        
        # Map the user-facing activation name to the callable expected by nflows.
        activation_functions = {
            'relu': F.relu,
            'sigmoid': torch.sigmoid,
            'tanh': torch.tanh,
            'elu': F.elu,
            'silu': F.silu,
            'leaky_relu': F.leaky_relu,
            'gelu': F.gelu
        }
        self.activation = activation_functions.get(activation)
        
        if self.activation is None:
            raise ValueError("Unsupported activation function")
            
        self.model = self.init_model().to(self.device)
        self.weight_particles = weight_particles
    
    def init_model(self):
        '''Build the nflows ``Flow``: permutation/MADE pairs over a standard-normal base.'''
        base_dist = nflows.distributions.normal.StandardNormal(shape=[self.dim_input])
        
        transforms = []
        for _ in range(self.num_coupling_layers):
            # Reverse the feature order between MADE layers so every feature
            # gets to condition on the others across the stack.
            transforms.append(ReversePermutation(features=self.dim_input))
            transforms.append(MaskedAffineAutoregressiveTransform(features=self.dim_input, 
                                                                  hidden_features=self.hidden_size, 
                                                                  context_features=self.dim_condition,
                                                                  use_residual_blocks=True,  
                                                                  num_blocks = self.num_blocks_mat,
                                                                  activation = self.activation,
                                                                  random_mask = self.random_mask))
        transform = CompositeTransform(transforms)

        return Flow(transform, base_dist).to(self.device)

    def forward(self, x, p):
        '''
        Per-sample negative log-likelihood ``-log p(x | p)``.

        The wrapped nflows ``Flow`` is a ``Distribution``, and calling it directly
        (``self.model(...)``) raises, so the density is evaluated through
        ``log_prob``. This returns the same quantity the training loop minimises
        (before taking the mean).

        Args:
            x: input batch, shape (batch, dim_input)
            p: condition batch, shape (batch, dim_condition)

        Returns:
            Tensor of shape (batch,) with the negative log-likelihood of each row.
        '''
        return -self.model.log_prob(inputs=x, context=p)

## Functions for normalization, sampling, plotting, and saving checkpoints

In [None]:
def normalize_columns(original_array):
    '''
    Min-max scale the first three columns (x, y, z) of a 2-D array to [0, 1].

    Columns 3 onward are passed through unchanged. The input array is NOT
    modified; a new array is returned. A column whose values are all equal
    (zero range) maps to 0 instead of producing NaN from 0/0.

    Args:
        original_array: 2-D array with at least three columns.

    Returns:
        New array of the same shape with columns 0..2 scaled to [0, 1].
    '''
    xyz = original_array[:, :3]
    mins = xyz.min(axis=0)
    spans = xyz.max(axis=0) - mins
    # Guard constant columns: (v - min) is 0 there, so dividing by 1 yields 0.
    spans = np.where(spans == 0, np.ones_like(spans), spans)
    # Arithmetic creates a new array, so the caller's data is never overwritten.
    scaled = (xyz - mins) / spans
    return np.concatenate((scaled, original_array[:, 3:]), axis=1)

def denormalize_columns(normalized_array, gt):
    '''
    Undo ``normalize_columns`` on the first three columns using the ranges of ``gt``.

    Each of columns 0..2 is mapped back as ``v * (max - min) + min``, where
    min/max are taken from the same column of ``gt``. Columns 3 onward are
    passed through unchanged. ``normalized_array`` is NOT modified.

    Args:
        normalized_array: 2-D array whose first three columns are in normalized units.
        gt: 2-D reference array (at least three columns) whose x/y/z ranges define the scale.

    Returns:
        New array with columns 0..2 in the units of ``gt``.
    '''
    reference = np.asarray(gt)[:, :3]
    mins = reference.min(axis=0)
    spans = reference.max(axis=0) - mins
    # Arithmetic produces a new array, so the caller's array is left untouched.
    xyz = normalized_array[:, :3] * spans + mins
    return np.concatenate((xyz, normalized_array[:, 3:]), axis=1)

def sample_pointcloud(model, num_samples, cond):
    '''
    Draw ``num_samples`` samples from the wrapped flow, conditioned on ``cond``.

    The inner flow (``model.model``) is switched to eval mode and sampled with
    gradient tracking disabled. The flow is left in eval mode afterwards.
    '''
    flow = model.model
    flow.eval()
    with torch.no_grad():
        samples = flow.sample(num_samples, cond)
    return samples

def random_sample(data, sample_size):
    '''
    Pick ``sample_size`` distinct rows of ``data`` uniformly at random.

    Uses the global NumPy RNG (``np.random.choice`` without replacement).

    Raises:
        ValueError: if ``sample_size`` is larger than the number of rows.
    '''
    n_rows = data.shape[0]
    if n_rows < sample_size:
        raise ValueError("Sample size exceeds the number of points in the data")

    chosen_rows = np.random.choice(n_rows, sample_size, replace=False)
    return data[chosen_rows]

def save_checkpoint(model, optimizer, path, last_loss, min_valid_loss, epoch, wandb_run_id):
    '''
    Save model/optimizer state plus bookkeeping to ``<path>/model_<epoch>``.

    Args:
        model: module whose ``state_dict()`` is stored under 'model'.
        optimizer: optimizer whose ``state_dict()`` is stored under 'optimizer'.
        path (str): existing directory to write into.
        last_loss: last loss value, either a one-element tensor or a plain number.
        min_valid_loss: best loss seen so far (stored as given).
        epoch (int): epoch number; also used in the file name.
        wandb_run_id: identifier stored alongside the state.

    Returns:
        str: path of the written checkpoint file.
    '''
    state = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        # float() accepts both a one-element tensor and a plain Python number.
        'last_loss': float(last_loss),
        'epoch': epoch,
        'min_valid_loss': min_valid_loss,
        'wandb_run_id': wandb_run_id,
    }

    checkpoint_path = os.path.join(path, 'model_' + str(epoch))
    torch.save(state, checkpoint_path)
    return checkpoint_path
        
        
def create_position_density_plots(x, y, z,
                                  x_pr, y_pr, z_pr,
                                  bins=100, t=1000, path=''):
    '''
    Plot 2-D position histograms (XY, XZ, YZ): ground truth on top, prediction below.

    Bin edges are taken from the ground-truth ranges and shared by both rows, so
    predicted points outside the ground-truth range are not counted.

    Args:
        x, y, z: ground-truth coordinate arrays.
        x_pr, y_pr, z_pr: predicted coordinate arrays.
        bins (int): number of bin edges per axis (``np.linspace`` count).
        t: time step shown in the titles and used in the file name.
        path (str): if non-empty, directory where ``density_plots_<t>.png`` is saved.
    '''
    ground_truth = {'X': x, 'Y': y, 'Z': z}
    predicted = {'X': x_pr, 'Y': y_pr, 'Z': z_pr}
    edges = {axis: np.linspace(np.min(values), np.max(values), bins)
             for axis, values in ground_truth.items()}
    planes = [('X', 'Y', 'Blues'), ('X', 'Z', 'Greens'), ('Y', 'Z', 'Reds')]
    rows = [(ground_truth, 'Ground Truth'), (predicted, 'Prediction')]

    fig = plt.figure(figsize=(15, 10))
    for row, (data, row_label) in enumerate(rows):
        for col, (a, b, cmap) in enumerate(planes):
            plt.subplot(2, 3, 3 * row + col + 1)
            plt.hist2d(data[a], data[b], bins=[edges[a], edges[b]], cmap=cmap)
            plt.colorbar(label='Density')
            plt.xlabel(a)
            plt.ylabel(b)
            plt.title('{}{} Plane {} at t = {}'.format(a, b, row_label, t))

    plt.tight_layout()

    # Save BEFORE show(): inline backends close the figure on show(), so a
    # savefig() afterwards would write an empty image.
    if path:
        fig.savefig(os.path.join(path, 'density_plots_{}.png'.format(t)))
    plt.show()

def create_momentum_density_plots(px, py, pz,
                                  px_pr, py_pr, pz_pr,
                                  bins=100, t=1000, path=''):
    '''
    Plot 2-D momentum histograms (px-py, px-pz, py-pz): ground truth on top, prediction below.

    Bin edges are taken from the ground-truth ranges and shared by both rows, so
    predicted points outside the ground-truth range are not counted.

    Args:
        px, py, pz: ground-truth momentum components.
        px_pr, py_pr, pz_pr: predicted momentum components.
        bins (int): number of bin edges per axis (``np.linspace`` count).
        t: time step shown in the titles and used in the file name.
        path (str): if non-empty, directory where ``momentum_density_plots_<t>.png`` is saved.
    '''
    ground_truth = {'px': px, 'py': py, 'pz': pz}
    predicted = {'px': px_pr, 'py': py_pr, 'pz': pz_pr}
    edges = {axis: np.linspace(np.min(values), np.max(values), bins)
             for axis, values in ground_truth.items()}
    planes = [('px', 'py', 'Blues'), ('px', 'pz', 'Greens'), ('py', 'pz', 'Reds')]
    # (data, title word, axis-label suffix) for each row of panels.
    rows = [(ground_truth, 'Ground Truth', ''), (predicted, 'Prediction', '_pr')]

    fig = plt.figure(figsize=(15, 10))
    for row, (data, row_label, suffix) in enumerate(rows):
        for col, (a, b, cmap) in enumerate(planes):
            plt.subplot(2, 3, 3 * row + col + 1)
            plt.hist2d(data[a], data[b], bins=[edges[a], edges[b]], cmap=cmap)
            plt.colorbar(label='Density')
            plt.xlabel(a + suffix)
            plt.ylabel(b + suffix)
            plt.title('{}-{} Plane {} at t = {}'.format(a, b, row_label, t))

    plt.tight_layout()

    # Save BEFORE show(): inline backends close the figure on show(), so a
    # savefig() afterwards would write an empty image.
    if path:
        fig.savefig(os.path.join(path, 'momentum_density_plots_{}.png'.format(t)))
    plt.show()



def create_force_density_plots(fx, fy, fz,
                               fx_pr, fy_pr, fz_pr,
                               bins=100, t=1000, path=''):
    '''
    Plot 2-D force histograms (fx-fy, fx-fz, fy-fz): ground truth on top, prediction below.

    Bin edges are taken from the ground-truth ranges and shared by both rows, so
    predicted points outside the ground-truth range are not counted.

    Args:
        fx, fy, fz: ground-truth force components.
        fx_pr, fy_pr, fz_pr: predicted force components.
        bins (int): number of bin edges per axis (``np.linspace`` count).
        t: time step shown in the titles and used in the file name.
        path (str): if non-empty, directory where ``force_density_plots_<t>.png`` is saved.
    '''
    ground_truth = {'fx': fx, 'fy': fy, 'fz': fz}
    predicted = {'fx': fx_pr, 'fy': fy_pr, 'fz': fz_pr}
    edges = {axis: np.linspace(np.min(values), np.max(values), bins)
             for axis, values in ground_truth.items()}
    planes = [('fx', 'fy', 'Blues'), ('fx', 'fz', 'Greens'), ('fy', 'fz', 'Reds')]
    # (data, title word, axis-label suffix) for each row of panels.
    rows = [(ground_truth, 'Ground Truth', ''), (predicted, 'Prediction', '_pr')]

    fig = plt.figure(figsize=(15, 10))
    for row, (data, row_label, suffix) in enumerate(rows):
        for col, (a, b, cmap) in enumerate(planes):
            plt.subplot(2, 3, 3 * row + col + 1)
            plt.hist2d(data[a], data[b], bins=[edges[a], edges[b]], cmap=cmap)
            plt.colorbar(label='Density')
            plt.xlabel(a + suffix)
            plt.ylabel(b + suffix)
            plt.title('{}-{} Plane {} at t = {}'.format(a, b, row_label, t))

    plt.tight_layout()

    # Save BEFORE show(): inline backends close the figure on show(), so a
    # savefig() afterwards would write an empty image.
    if path:
        fig.savefig(os.path.join(path, 'force_density_plots_{}.png'.format(t)))
    plt.show()
    

def inference(gpu_index, t_index, flow_model=None, config=None, sample_size=10000, bins=100):
    '''
    Sample one predicted point cloud for a simulation box and plot it against ground truth.

    Args:
        gpu_index (int): box index along the first axis of both the particle and
            radiation files.
        t_index (int): time step, formatted into ``config["pathpattern1"]`` and
            ``config["pathpattern2"]``.
        flow_model: PC_MAF to sample from; defaults to the notebook-global ``model``.
        config (dict): must provide "pathpattern1" and "pathpattern2"; defaults to the
            notebook-global ``hyperparameter_defaults``.
        sample_size (int): number of ground-truth particles drawn for the plots.
        bins (int): passed to the density-plot helpers.
    '''
    if flow_model is None:
        flow_model = model
    if config is None:
        config = hyperparameter_defaults

    # Ground truth for the chosen box only. Keep the full box: training normalised
    # x/y/z with the full box's ranges, so those are the right bounds to undo it.
    boxes = np.load(config["pathpattern1"].format(t_index), allow_pickle=True)
    box_full = np.asarray(boxes[gpu_index], dtype=np.float32)
    p_gt = random_sample(box_full, sample_size=sample_size)

    # Condition: drop index 0 on axis 1, flatten, encode as [amplitude | phase]
    # the same way the training Loader does. np.complex128 replaces np.cfloat,
    # which NumPy 2.0 removed.
    radiation = torch.from_numpy(
        np.load(config["pathpattern2"].format(t_index)).astype(np.complex128))
    radiation = radiation[gpu_index, 1:, :].reshape(1, -1)
    cond = torch.cat((torch.abs(radiation), torch.angle(radiation)), dim=1).to(torch.float32)
    cond = cond.to(flow_model.device)

    pc_pr = sample_pointcloud(flow_model, 1, cond)
    # The flow emits one flat vector per sample; unflatten to (particles, features)
    # using the ground truth's feature count instead of a hard-coded shape.
    pc_pr = pc_pr.squeeze().cpu().numpy().reshape(-1, p_gt.shape[1])
    pc_pr = denormalize_columns(pc_pr, box_full)

    # Columns: 0-2 position, 3-5 momentum, 6-8 force.
    x, y, z = p_gt[:, 0:3].T
    px, py, pz = p_gt[:, 3:6].T
    fx, fy, fz = p_gt[:, 6:9].T

    x_pr, y_pr, z_pr = pc_pr[:, 0:3].T
    px_pr, py_pr, pz_pr = pc_pr[:, 3:6].T
    fx_pr, fy_pr, fz_pr = pc_pr[:, 6:9].T

    create_position_density_plots(x, y, z, x_pr, y_pr, z_pr, bins=bins, t=t_index)

    create_momentum_density_plots(px, py, pz, px_pr, py_pr, pz_pr, bins=bins, t=t_index)

    create_force_density_plots(fx, fy, fz, fx_pr, fy_pr, fz_pr, bins=bins, t=t_index)
      

## Define the loader

In [None]:
class Loader:
    '''
    Time-batched loader pairing particle point clouds with radiation conditions.

    ``loader[i]`` returns a fresh ``Epoch`` (``i`` is ignored) with a new random
    order of the time steps ``t0 .. t1-1``. ``epoch[tb]`` loads ``timebatchsize``
    time steps from disk and returns a ``Timebatch``. ``timebatch[b]`` returns a
    ``(particles, radiation)`` mini-batch of ``particlebatchsize`` rows.
    Out-of-range indices raise IndexError, so both can also be iterated with ``for``.
    '''

    class Timebatch:
        '''Shuffled mini-batches over the rows of one loaded time batch.'''

        def __init__(self, particles, radiation, batchsize):
            self.batchsize = batchsize
            self.particles = particles
            self.radiation = radiation

            self.perm = torch.randperm(self.radiation.shape[0])

        def __len__(self):
            # Incomplete trailing batch is dropped.
            return self.radiation.shape[0] // self.batchsize

        def __getitem__(self, batch):
            if not 0 <= batch < len(self):
                raise IndexError("batch index {} out of range for {} batches".format(batch, len(self)))
            start = self.batchsize * batch
            rows = self.perm[start:start + self.batchsize]
            return self.particles[rows], self.radiation[rows]

    class Epoch:
        '''One pass over the time steps in a random order, grouped into time batches.'''

        def __init__(self, loader, t0, t1, timebatchsize=20, particlebatchsize=10240):
            self.perm = torch.randperm(len(loader))
            self.loader = loader
            self.t0 = t0
            self.t1 = t1
            self.timebatchsize = timebatchsize
            self.particlebatchsize = particlebatchsize

        def __len__(self):
            # Time steps that do not fill a whole time batch are dropped.
            return len(self.loader) // self.timebatchsize

        def __getitem__(self, timebatch):
            if not 0 <= timebatch < len(self):
                raise IndexError("timebatch index {} out of range for {} time batches".format(timebatch, len(self)))
            start = self.timebatchsize * timebatch
            particles = []
            radiation = []
            for offset in self.perm[start:start + self.timebatchsize]:
                index = int(offset) + self.t0
                particles.append(self.loader._load_particles(index))
                radiation.append(self.loader._load_radiation(index))

            return self.loader.Timebatch(torch.cat(particles), torch.cat(radiation), self.particlebatchsize)

    def __init__(self, pathpattern1="/bigdata/hplsim/aipp/Jeyhun/khi/particle_box/40_80_80_160_0_2/{}.npy", pathpattern2="/bigdata/hplsim/aipp/Jeyhun/khi/part_rad/radiation_ex/{}.npy", t0=0, t1=100, timebatchsize=20, particlebatchsize=10240, samples_per_box=10000):
        '''
        Args:
            pathpattern1 (str): format pattern for the particle file of a time step.
            pathpattern2 (str): format pattern for the radiation file of a time step.
            t0, t1 (int): time-step range ``[t0, t1)``.
            timebatchsize (int): time steps loaded per time batch.
            particlebatchsize (int): rows per mini-batch within a time batch.
            samples_per_box (int): particles randomly drawn from each box.
        '''
        self.pathpattern1 = pathpattern1
        self.pathpattern2 = pathpattern2

        self.t0 = t0
        self.t1 = t1

        self.timebatchsize = timebatchsize
        self.particlebatchsize = particlebatchsize
        self.samples_per_box = samples_per_box

        # A time step is only usable if BOTH its particle and radiation files exist.
        num_files = t1 - t0
        missing_files = [i for i in range(t0, t1)
                         if not (os.path.exists(pathpattern1.format(i))
                                 and os.path.exists(pathpattern2.format(i)))]
        num_missing = len(missing_files)

        if num_missing == 0:
            print("All {} files from {} to {} exist in the directory.".format(num_files, t0, t1))
        else:
            print("{} files are missing out of {} in the directory.".format(num_missing, num_files))

    def __len__(self):
        return self.t1 - self.t0

    def __getitem__(self, idx):
        # idx is ignored: every call starts a new epoch with a fresh permutation.
        return self.Epoch(self, self.t0, self.t1, self.timebatchsize, self.particlebatchsize)

    def _load_particles(self, index):
        '''Load all boxes of a time step, normalize x/y/z per box, subsample, flatten per box.'''
        boxes = np.load(self.pathpattern1.format(index), allow_pickle=True)
        sampled = [random_sample(normalize_columns(box), sample_size=self.samples_per_box)
                   for box in boxes]
        p = torch.from_numpy(np.array(sampled, dtype=np.float32))
        return p.view(p.shape[0], -1)

    def _load_radiation(self, index):
        '''Load the radiation of a time step as [amplitude | phase] float32 rows, one per box.'''
        # np.complex128 replaces np.cfloat, which NumPy 2.0 removed.
        r = torch.from_numpy(np.load(self.pathpattern2.format(index)).astype(np.complex128))
        r = r[:, 1:, :]
        r = r.reshape(r.shape[0], -1)
        return torch.cat((torch.abs(r), torch.angle(r)), dim=1).to(torch.float32)

## Set the hyperparameters of the model and initialise it

In [None]:
hyperparameter_defaults = dict(
    t0 = 1990,                   # first time step (inclusive)
    t1 = 2001,                   # end time step (exclusive): Loader uses range(t0, t1)
    dim_input = 90000,           # flattened particles per box (10000 sampled particles x 9 columns)
    timebatchsize = 4,           # time steps loaded per time batch
    particlebatchsize = 32,      # rows per optimisation step
    dim_condition = 2048,        # width of the [amplitude | phase] radiation condition
    num_coupling_layers = 3,
    hidden_size = 64,
    lr = 0.001,
    num_epochs = 10,
    num_blocks_mat = 2,
    activation = 'gelu',
    pathpattern1 = "/bigdata/hplsim/aipp/Jeyhun/khi/part_rad/particle_002/{}.npy",
    pathpattern2 = "/bigdata/hplsim/aipp/Jeyhun/khi/part_rad/radiation_ex_002/{}.npy"
)

enable_wandb = False
start_epoch = 0
min_valid_loss = np.inf

# Seed every RNG the pipeline draws from (random_sample uses np.random,
# the Loader shuffles with torch.randperm) so runs are reproducible.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

if enable_wandb:
    # Imported lazily so wandb only needs to be installed when logging is enabled.
    import wandb
    print('New session...')
    # Pass your defaults to wandb.init
    wandb.init(entity="jeyhun", config=hyperparameter_defaults, project="khi_public")

    # Access all hyperparameter values through wandb.config
    config = wandb.config

l = Loader(pathpattern1 = hyperparameter_defaults["pathpattern1"],
           pathpattern2 = hyperparameter_defaults["pathpattern2"],
           t0 = hyperparameter_defaults["t0"],
           t1 = hyperparameter_defaults["t1"],
           timebatchsize = hyperparameter_defaults["timebatchsize"],
           particlebatchsize = hyperparameter_defaults["particlebatchsize"])

# Fall back to CPU so model construction also works on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = PC_MAF(dim_condition = hyperparameter_defaults["dim_condition"],
               dim_input = hyperparameter_defaults["dim_input"],
               num_coupling_layers = hyperparameter_defaults["num_coupling_layers"],
               hidden_size = hyperparameter_defaults["hidden_size"],
               device = device,
               num_blocks_mat = hyperparameter_defaults["num_blocks_mat"],
               activation = hyperparameter_defaults["activation"])

# Calculate the total number of parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {total_params}")

optimizer = optim.Adam(model.parameters(), lr=hyperparameter_defaults["lr"])
# NOTE(review): step_size counts scheduler.step() calls. The training loop steps
# once per epoch, so with num_epochs=10 the learning rate never decays.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.9)

if enable_wandb:
    directory = '/bigdata/hplsim/aipp/Jeyhun/khi/checkpoints/' + str(wandb.run.id)

    if os.path.exists(directory):
        print(f"Directory '{directory}' already exists.")
    else:
        os.makedirs(directory)
        print(f"Directory '{directory}' created.")

## Training loop

In [None]:
start_time = time.time()
for i_epoch in range(start_epoch, hyperparameter_defaults["num_epochs"]):
    # Loader.__getitem__ builds a fresh Epoch with a new random time-step order;
    # request one per epoch instead of reusing a single permutation for all epochs.
    epoch = l[i_epoch]
    # sample_pointcloud() puts the flow in eval mode; make sure we train in train mode.
    model.train()

    loss_overall = []
    for tb in range(len(epoch)):
        loss_avg = []
        timebatch = epoch[tb]

        start_timebatch = time.time()
        for b in range(len(timebatch)):
            optimizer.zero_grad()
            phase_space, radiation = timebatch[b]

            # Negative log-likelihood of the particles given the radiation condition.
            loss = - model.model.log_prob(inputs=phase_space.to(model.device), context=radiation.to(model.device))

            loss = loss.mean()
            loss_avg.append(loss.item())
            loss.backward()
            optimizer.step()

        elapsed_timebatch = time.time() - start_timebatch

        # A time batch smaller than particlebatchsize yields no batches; skip it
        # instead of dividing by zero.
        if not loss_avg:
            print('i_epoch:{}, tb: {}, no complete particle batch, skipped'.format(i_epoch, tb))
            continue

        loss_timebatch_avg = sum(loss_avg) / len(loss_avg)
        loss_overall.append(loss_timebatch_avg)
        print('i_epoch:{}, tb: {}, last timebatch loss: {}, avg_loss: {}, time: {}'.format(i_epoch, tb, loss_avg[-1], loss_timebatch_avg, elapsed_timebatch))

    if not loss_overall:
        raise ValueError("No training batches were produced: need at least timebatchsize "
                         "time steps and particlebatchsize rows per time batch")

    loss_overall_avg = sum(loss_overall) / len(loss_overall)

    if min_valid_loss > loss_overall_avg:
        message = f'Training Loss Decreased({min_valid_loss:.6f}--->{loss_overall_avg:.6f})'
        min_valid_loss = loss_overall_avg
        # A checkpoint directory only exists when wandb is enabled.
        if enable_wandb:
            torch.save(model.state_dict(), os.path.join(directory, 'best_model_'))
            message += ' \t Saving The Model'
        print(message)

    if (i_epoch + 1) % 10 == 0 and enable_wandb:
        save_checkpoint(model, optimizer, directory, loss, min_valid_loss, i_epoch, wandb.run.id)

    scheduler.step()

    if enable_wandb:
        # Log the loss values at the end of each epoch
        wandb.log({
            "Epoch": i_epoch,
            "loss_timebatch_avg_loss": loss_timebatch_avg,
            "loss_overall_avg": loss_overall_avg,
            "min_valid_loss": min_valid_loss,
        })

end_time = time.time()

elapsed_time = end_time - start_time
print(f"Total elapsed time: {elapsed_time:.6f} seconds")


## Inference

In [None]:
# Pick the time step (file index) and the simulation box (GPU index) to visualise.
# NOTE(review): t_index = 2000 lies inside the training range [t0, t1) = [1990, 2001),
# so this visualises a time step the model was trained on.
t_index = 2000
gpu_index = 19

inference(gpu_index=gpu_index, t_index=t_index)