# Import the libraries

In [None]:
import torch.nn as nn
import torch
import pandas as pd
import numpy as np
import torch.optim as optim
from IPython.display import clear_output
clear = lambda: clear_output(wait=True)
import time
import datetime
import matplotlib.pyplot as plt
from sklearn import preprocessing
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
import torchvision.transforms
from io import BytesIO
writer = SummaryWriter('runs/bfwg-gp')
# !tensorboard --logdir=runs

In [None]:
from PIL import 

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Device:', device)

# Read the data

In [None]:
full_df = pd.read_csv('beamfile.txt', sep=':')
print('Total number of points in the dataset:', full_df.size)
full_df.head(3)

In [None]:
df = full_df[full_df['particleFlag']==2].drop(columns = ['stepNumber', 'particleFlag', 'Z']) # only core
# df = full_df.drop(columns = ['stepNumber', 'particleFlag', 'Z']) # all points
print('Number of points from beam core (i.e. flag=2):', df.size)
df.head(3)

In [None]:
real_data_raw = df.to_numpy()
scaler = preprocessing.StandardScaler().fit(real_data_raw)
real_data = scaler.transform(real_data_raw)
dataset = torch.utils.data.TensorDataset(torch.FloatTensor(real_data))

# Define generator and discriminator

In [54]:
n_dim = 5 # number of variables in data

In [55]:
class SequentialLinearModule(nn.Module):


    def __init__(self, in_dim, out_dim, layers, activation, out_activation, bias):
        
        '''
        if layers is list: list containing the size of the linear layers between input and output
        if layers is int: number of layers with exponentially decreasing size
        '''

        super().__init__()
        assert type(layers) is list or type(layers) is int, '"layers" argument should be list or str, more in doctstring'

        if type(layers) is list:
            layer_sizes = layers
            layers.insert(0, in_dim)
            num_layers = len(layer_sizes)

        if type(layers) is int:
            layer_sizes = []
            num_layers = layers
            # exponential decrease
            for i in range(num_layers):
                layer_sizes.append(int(in_dim/(2**i)))

        # stack the layers
        layer_sizes.append(out_dim)
        layers_list = []
        for i in range(num_layers):
            if i!=0:
                layers_list.append(activation())
            layers_list.append(nn.Linear(layer_sizes[i], layer_sizes[i+1], bias=bias))
        layers_list.append(out_activation())
        self.f = nn.Sequential(*layers_list)

    def forward(self, x):
        return self.f(x)

In [56]:
class Generator(SequentialLinearModule):
    def __init__(self, noise_dim, n_dim, layers):
        super().__init__(noise_dim, n_dim, layers, nn.ReLU, nn.Identity, True)

In [57]:
class Critic(SequentialLinearModule):
    def __init__(self, n_dim, layers):
        super().__init__(n_dim, 1, layers, nn.LeakyReLU, nn.LeakyReLU, True)

# Parameters

In [58]:
lr = 1e-4
noise_dim = 128
batch_size = 512
batch_progress_period = int( len(dataset)/batch_size * 0.1 )
num_epochs = 50
lr_critic = lr
lr_gen = lr/4

critic_iterations = 5
lambda_gp = 10

# Initialize main objects

In [59]:
critic = Critic(n_dim, [32 for _ in range(2)]).to(device)
gen = Generator(noise_dim, n_dim, [32 for _ in range(3)]).to(device)
critic_loss_list = []
gen_loss_list = []

In [60]:
# print(critic)
# print('\n'*3)
# print(gen)

In [61]:
opt_critic = optim.Adam(critic.parameters(), lr=lr_critic, betas=(0.0, 0.9))
opt_gen = optim.Adam(gen.parameters(), lr=lr_gen, betas=(0.0, 0.9))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Training loop

In [62]:
def gradient_penalty(critic, real, fake):
    batch_size, n_dim = real.shape
    epsilon = torch.rand((batch_size, 1)).repeat(1, n_dim).to(device)
    interpolated_images = real*epsilon + fake*(1-epsilon)

    mixed_scores = critic(interpolated_images)
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0] # (batch_size, n_dim=5)

    gradient_norm = gradient.norm(2, dim=1) # (batch_size)
    
    return torch.mean((gradient_norm - 1) ** 2)

In [63]:
def train():
    running_critic_loss = 0
    running_gen_loss = 0
    
    last_time = None
    eta = None

    for epoch in range(num_epochs):

        # batch is stored in "real" variable (as opposed to "fake" generated data)
        for batch_idx, (real, ) in enumerate(dataloader):

            real = real.to(device)

            for _ in range(critic_iterations):

                # generate fake data
                noise = torch.randn(real.shape[0], noise_dim).to(device) # noise for generator
                fake = gen(noise) # fake data

                # critic training
                real_preds = critic(real)
                fake_preds = critic(fake)
                gp = gradient_penalty(critic, real, fake)
                loss_critic = torch.mean(fake_preds) - torch.mean(real_preds) + lambda_gp*gp
                critic.zero_grad()
                loss_critic.backward(retain_graph=True) # without retain_graph all graph variables will be freed to optimize memory consumption
                opt_critic.step()
                running_critic_loss += loss_critic.detach().cpu().numpy()

            # generator training
            fake_preds = critic(fake)
            loss_gen = -torch.mean(fake_preds)
            gen.zero_grad()
            loss_gen.backward()
            opt_gen.step()
            running_gen_loss += loss_gen.detach().cpu().numpy()
            
            if (batch_progress_period!=0 and batch_idx % batch_progress_period == 0) or batch_idx==0:

                text = ''
                # epoch num and %
                text += f'Epoch {epoch+1}/{num_epochs}\n'
                text += f'{100*epoch/num_epochs:.2f}%\n\n'
                # batch num and %
                text += f'Batch {batch_idx+1}/{len(dataloader)}\n'
                text += str(100*batch_idx/len(dataloader))[:5] + '%'
                text += f'{100*batch_idx/len(dataloader):.2f}%\n\n'
        
                # eta
                if last_time and batch_progress_period!=0:
                    eta = (((len(dataloader)-batch_idx) + (num_epochs-epoch-1)*len(dataloader))/batch_progress_period*(time.time()-last_time))
                last_time = time.time()
                if type(eta) is float:
                    eta_string = str(datetime.timedelta(seconds=eta))
                    text += f"ETA {':'.join([i.split('.')[0] for i in eta_string.split(':')])}\n"
                else:
                    text += 'ETA N/A\n'
                
                writer.add_text('progress', text, epoch*len(dataloader) + batch_idx)
                
                img_buf = draw_hist()
                img = Image.open(img_buf)
                img = torchvision.transforms.ToTensor()(img)
                writer.add_image('histograms', img, epoch*len(dataloader) + batch_idx)
                
                # memory usage
                writer.add_scalar('memory', torch.cuda.max_memory_allocated()/1024**3, epoch*len(dataloader) + batch_idx)
                torch.cuda.reset_peak_memory_stats()
                
                # loss plots
                running_critic_loss /= batch_progress_period*critic_iterations
                running_gen_loss /= batch_progress_period
    
                writer.add_scalar('loss/critic', running_critic_loss, epoch*len(dataloader) + batch_idx)
                writer.add_scalar('loss/generator', running_gen_loss, epoch*len(dataloader) + batch_idx)
                
                running_critic_loss = 0
                running_gen_loss = 0

# Result

In [64]:
def draw_hist():
    
    noise = torch.randn(real_data.shape[0], noise_dim).to(device)
    fake = gen(noise).detach().cpu().numpy()
    real = real_data
    bins = 100
        
    plotnum = 1
    fig, ax = plt.subplots(figsize=(14, 70)) # size in inches
    for i in range(5):
        for j in range(i+1, 5):
            # (i,j) are the columns that we are going to draw a histogram of

            # real data
            plt.subplot(10, 2, plotnum)
            plt.title(f'real, {df.columns[i]}/{df.columns[j]}')
            plotnum += 1
            _, xedges, yedges, _ = plt.hist2d(real[:,i], real[:,j], bins=bins,)# range = [[x_min,x_max], [y_min,y_max]])
            
            x_min = xedges[0]
            x_max = xedges[-1]
            y_min = yedges[0]
            y_max = yedges[-1]

            # fake data
            plt.subplot(10, 2, plotnum)
            plt.title(f'fake, {df.columns[i]}/{df.columns[j]}')
            plotnum += 1
            plt.hist2d(fake[:,i], fake[:,j], bins=bins, range = [[x_min,x_max], [y_min,y_max]])
            
    buf = BytesIO()
    plt.savefig(buf, format='png')
    buf.seek(0)
    return buf

In [None]:
train()

  fig, ax = plt.subplots(figsize=(14, 70)) # size in inches
