# Import the libraries

In [199]:
import torch.nn as nn
import torch
import pandas as pd
import numpy as np
import torch.optim as optim
from IPython.display import clear_output
clear = lambda: clear_output(wait=True)
import time
import datetime
import matplotlib.pyplot as plt
from sklearn import preprocessing
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
import torchvision.transforms
import utils

import importlib
importlib.reload(utils)
# !tensorboard --logdir=runs

<module 'utils' from 'C:\\Users\\Arsen\\dev\\bachelor\\utils.py'>

In [175]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Device:', device)

Device: cuda


# Read the data

In [176]:
full_df = pd.read_csv('beamfile.txt', sep=':')
print('Total number of points in the dataset:', full_df.size)
full_df.head(3)

Total number of points in the dataset: 4205032


Unnamed: 0,stepNumber,particleFlag,X,Y,Z,dXdZ,dYdZ,P
0,127141,2,0.982458,-0.269195,1130.724,-0.909843,-0.61807,99.899742
1,127141,2,8.08756,6.26015,1130.724,0.228221,-1.226108,105.682022
2,127141,2,-0.482854,6.407821,1130.724,0.396534,-0.9081,106.127846


In [177]:
flag = 'core' # core/halo/all

if flag == 'core':
    df = full_df[full_df['particleFlag']==2].drop(columns = ['stepNumber', 'particleFlag', 'Z'])
elif flag == 'halo':
    df = full_df[full_df['particleFlag']==3].drop(columns = ['stepNumber', 'particleFlag', 'Z'])
elif flag == 'all':
    df = full_df.drop(columns = ['stepNumber', 'particleFlag', 'Z'])
else:
    raise Exception('Choose the flag')
    
print('Resulting number of points:', df.size)
df.head(3)

Resulting number of points: 1984365


Unnamed: 0,X,Y,dXdZ,dYdZ,P
0,0.982458,-0.269195,-0.909843,-0.61807,99.899742
1,8.08756,6.26015,0.228221,-1.226108,105.682022
2,-0.482854,6.407821,0.396534,-0.9081,106.127846


In [178]:
real_data_raw = df.to_numpy()
scaler = preprocessing.StandardScaler().fit(real_data_raw)
real_data = scaler.transform(real_data_raw)
dataset = torch.utils.data.TensorDataset(torch.FloatTensor(real_data))

# Define architecture

In [179]:
n_dim = 5 # number of variables in data

In [180]:
class SequentialLinearModule(nn.Module):


    def __init__(self, in_dim, out_dim, layer_sizes, activation, out_activation, bias=True):

        super().__init__()
        
        # layers is the list that contains the sizes of linear layers between input and output
        # [a, b, c]
        layer_sizes.insert(0, in_dim)
        layer_sizes.append(out_dim)
        num_layers = len(layer_sizes)-1
        
        layers = [] # list of the modules that we put into sequential
        for i in range(num_layers):
            if i!=0:
                layers.append(activation()) # activation after the last linear
                layers.append(nn.BatchNorm1d(layer_sizes[i])) # batchnorm
            layers.append(nn.Linear(layer_sizes[i], layer_sizes[i+1], bias=bias)) # linear
            
            
        layers.append(out_activation()) # final activation
        self.f = nn.Sequential(*layers) # combine everything

    def forward(self, x):
        return self.f(x)

In [181]:
class Generator(SequentialLinearModule):
    def __init__(self, noise_dim, n_dim, layer_sizes):
        super().__init__(noise_dim, n_dim, layer_sizes, nn.ReLU, nn.Identity, True)
        self.noise_dim = noise_dim

In [182]:
class Critic(SequentialLinearModule):
    def __init__(self, n_dim, layer_sizes):
        super().__init__(n_dim, 1, layer_sizes, nn.LeakyReLU, nn.LeakyReLU, True)

# Parameters

In [183]:
lr = 1e-4
noise_dim = 128
batch_size = 512
batch_progress_period = int( len(dataset)/batch_size * 0.1 )
num_epochs = 50
lr_critic = lr
lr_gen = lr/4
critic_iterations = 5
lambda_gp = 10

# Initialize main objects

In [184]:
critic = Critic(n_dim, [32 for _ in range(2)]).to(device)
gen = Generator(noise_dim, n_dim, [32 for _ in range(3)]).to(device)

In [185]:
print(critic)
print('\n'*3)
print(gen)

Critic(
  (f): Sequential(
    (0): Linear(in_features=5, out_features=32, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Linear(in_features=32, out_features=32, bias=True)
    (4): LeakyReLU(negative_slope=0.01)
    (5): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): Linear(in_features=32, out_features=1, bias=True)
    (7): LeakyReLU(negative_slope=0.01)
  )
)




Generator(
  (f): Sequential(
    (0): Linear(in_features=128, out_features=32, bias=True)
    (1): ReLU()
    (2): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Linear(in_features=32, out_features=32, bias=True)
    (4): ReLU()
    (5): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): Linear(in_features=32, out_features=32, bias=True)
    (7): ReLU()
    (8): BatchNorm1d(32, eps=1e-05, momentum=0.1, 

In [186]:
opt_critic = optim.Adam(critic.parameters(), lr=lr_critic, betas=(0.0, 0.9))
opt_gen = optim.Adam(gen.parameters(), lr=lr_gen, betas=(0.0, 0.9))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Training loop

In [196]:
def train():
    writer = SummaryWriter(f'runs/bfwg-gp_{datetime.datetime.now().strftime("%d.%m.%Y_%H.%M.%S")}')
    
    running_critic_loss = 0
    running_gen_loss = 0
    
    last_time = None
    eta = None

    for epoch in range(num_epochs):

        # batch is stored in "real" variable (as opposed to "fake" generated data)
        for batch_idx, (real, ) in enumerate(dataloader):

            real = real.to(device)

            for _ in range(critic_iterations):

                # generate fake data
                noise = torch.randn(real.shape[0], noise_dim).to(device) # noise for generator
                fake = gen(noise) # fake data

                # critic training
                real_preds = critic(real)
                fake_preds = critic(fake)
                gp = utils.gradient_penalty(critic, real, fake, device)
                loss_critic = torch.mean(fake_preds) - torch.mean(real_preds) + lambda_gp*gp
                critic.zero_grad()
                # without retain_graph all graph variables will be freed to optimize for memory
                loss_critic.backward(retain_graph=True)
                opt_critic.step()
                running_critic_loss += loss_critic.detach().cpu().numpy()

            # generator training
            fake_preds = critic(fake)
            loss_gen = -torch.mean(fake_preds)
            gen.zero_grad()
            loss_gen.backward()
            opt_gen.step()
            running_gen_loss += loss_gen.detach().cpu().numpy()
            
            if (batch_progress_period!=0 and batch_idx % batch_progress_period == 0) or batch_idx==0:
                
                current_step = epoch*len(dataloader) + batch_idx
                
                text = ''
                
                # epoch num and %
                text += f'Epoch {epoch+1}/{num_epochs}\n'
                text += f'{100*epoch/num_epochs:.2f}%\n\n'
                # batch num and %
                text += f'Batch {batch_idx+1}/{len(dataloader)}\n'
                text += f'{100*batch_idx/len(dataloader):.2f}%\n\n'
                # eta
                if last_time and batch_progress_period!=0:
                    epochs_rem = num_epochs-epoch-1
                    batches_rem = epochs_rem*len(dataloader) + (len(dataloader) - batch_idx)
                    eta = int(batches_rem * (time.time() - last_time)/batch_progress_period)
                    text += f'ETA {datetime.timedelta(seconds=int(eta))}\n\n'                   
                else:
                    text += 'ETA N/A\n\n'
                last_time = time.time()
                writer.add_text('progress', text, current_step)
                
                
                # histograms
                img_buf = utils.draw_hist(real_data, gen, device, df)
                img = Image.open(img_buf)
                img = torchvision.transforms.ToTensor()(img)
                writer.add_image('histograms', img, current_step)
                
                
                # memory usage
                writer.add_scalar('memory', torch.cuda.max_memory_allocated()/1024**3, current_step)
                torch.cuda.reset_peak_memory_stats()
                
                
                # loss plots
                running_critic_loss /= batch_progress_period*critic_iterations
                running_gen_loss /= batch_progress_period
                writer.add_scalar('loss/critic', running_critic_loss, current_step)
                writer.add_scalar('loss/generator', running_gen_loss, current_step)
                running_critic_loss = 0
                running_gen_loss = 0