In [1]:
# Particle species to model; selects which calibration-sample files the data cell loads.
PARTICLE_TYPE = "pion"

In [2]:
# Run identifier, e.g. "FastFastRICH_Cramer_pion_5layers".
MODEL_NAME = f"FastFastRICH_Cramer_{PARTICLE_TYPE}_5layers"

In [3]:
from comet_ml import Experiment

In [4]:
from sklearn.model_selection import train_test_split
import torch
from torch import nn
import torch.nn.functional as F
import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
import os
from IPython.display import clear_output
import scipy
from tqdm import tqdm_notebook, tqdm
from torch.autograd import Variable, grad
%matplotlib inline

In [5]:
import rich_utils.torch_utils_rich_mrartemev as utils_rich

## Data

In [6]:
# Load and merge the calibration samples for PARTICLE_TYPE; yields the train and
# validation frames plus the fitted scaler (see the log printed below).
data_train, data_val, scaler = utils_rich.get_merged_typed_dataset(
    PARTICLE_TYPE,
    dtype=np.float32,
    log=True,
)

Reading and concatenating datasets:
	.././RichGAN/data_calibsample/pion2_+_down_2016_.csv
	.././RichGAN/data_calibsample/pion2_-_up_2016_.csv
	.././RichGAN/data_calibsample/pion_+_up_2016_.csv
	.././RichGAN/data_calibsample/pion_+_down_2016_.csv
	.././RichGAN/data_calibsample/pion2_+_up_2016_.csv
	.././RichGAN/data_calibsample/pion_-_up_2016_.csv
	.././RichGAN/data_calibsample/pion_-_down_2016_.csv
	.././RichGAN/data_calibsample/pion2_-_down_2016_.csv
splitting to train/val/test
fitting the scaler
scaler train sample size: 2000000
scaler n_quantiles: 100000, time = 1.7589631080627441
scaling train set
scaling test set
converting dtype to <class 'numpy.float32'>


In [7]:
BATCH_SIZE = 1000        # rows per DataLoader batch
LATENT_DIMENSIONS = 64   # width of the noise vector the generator concatenates to its input

In [8]:
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
from torch import Tensor

def make_data_loader(data_train, batch_size=None, shuffle=True):
    """Wrap a DataFrame whose last column is a per-row weight in a DataLoader.

    Every batch is the 6-tuple unpacked by the training loop:
        (full, w_full, x_1, w_x_1, x_2, w_x_2)
    where
        full  = all columns except the last,
        x_*   = columns ``utils_rich.y_count`` .. second-to-last,
        w_*   = the last column.

    Args:
        data_train: DataFrame laid out as described above.
        batch_size: rows per batch; ``None`` uses the notebook-level ``BATCH_SIZE``.
        shuffle: reshuffle rows every epoch (default ``True``, as before).

    Returns:
        torch.utils.data.DataLoader over a TensorDataset of six tensors.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE

    values = data_train.values
    full = Tensor(values[:, :-1])
    weights = Tensor(values[:, -1])
    features = Tensor(values[:, utils_rich.y_count:-1])

    # The x/w views are built once and reused. Collation stacks fresh tensors
    # per batch field, so the yielded x_1 / x_2 never alias each other.
    # NOTE(review): x_1 and x_2 are the *same rows* as `full`. If independent
    # conditioning samples were intended, they need separately shuffled loaders.
    all_data = TensorDataset(full, weights, features, weights, features, weights)
    return DataLoader(all_data, batch_size=batch_size, shuffle=shuffle)

In [9]:
# Shuffled loader over the training split, yielding the 6-tuples consumed by the training loop.
all_dataloader = make_data_loader(data_train=data_train)

## Model

In [10]:
NUM_LAYERS = 5    # depth: hidden Linear+ReLU layers in the Generator and Critic
CRAMER_DIM = 256  # output width of the critic embedding used by the Cramer distance

def _make_mlp(in_features, out_features, hidden_dim, num_layers):
    """Build ``num_layers`` x (Linear -> ReLU) followed by a final Linear projection.

    Layers are appended in the same order as the previous hand-written stacks.
    For ``num_layers=5, hidden_dim=128`` the nn.Sequential indices (and hence
    state_dict keys such as ``main.0.weight`` ... ``main.10.bias``) are unchanged,
    so existing checkpoints still load.
    """
    if num_layers < 1:
        raise ValueError("num_layers must be >= 1, got {}".format(num_layers))
    layers = [nn.Linear(in_features, hidden_dim), nn.ReLU()]
    for _ in range(num_layers - 1):
        layers += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU()]
    layers.append(nn.Linear(hidden_dim, out_features))
    return nn.Sequential(*layers)


class Generator(nn.Module):
    """Conditional generator: maps ``[noise, x]`` to ``y_dim`` generated targets.

    All arguments are optional. ``None`` falls back to the notebook globals this
    class previously read implicitly.

    Args:
        x_dim: conditioning-feature width (default ``data_train.shape[1] - 1 - utils_rich.y_count``).
        y_dim: number of generated outputs (default ``utils_rich.y_count``).
        latent_dim: noise width (default ``LATENT_DIMENSIONS``).
        num_layers: hidden Linear+ReLU layers (default ``NUM_LAYERS``).
        hidden_dim: width of every hidden layer.
        noise_std: std of the zero-mean Gaussian noise drawn in ``forward``.
    """

    def __init__(self, x_dim=None, y_dim=None, latent_dim=None, num_layers=None,
                 hidden_dim=128, noise_std=3.0):
        super(Generator, self).__init__()
        if x_dim is None:
            x_dim = data_train.shape[1] - 1 - utils_rich.y_count
        if y_dim is None:
            y_dim = utils_rich.y_count
        if latent_dim is None:
            latent_dim = LATENT_DIMENSIONS
        if num_layers is None:
            num_layers = NUM_LAYERS
        # Captured here so forward() no longer depends on a mutable global.
        self.latent_dim = latent_dim
        self.noise_std = noise_std
        self.main = _make_mlp(latent_dim + x_dim, y_dim, hidden_dim, num_layers)

    def forward(self, input):
        """Generate outputs for the conditioning batch ``input`` of shape (batch, x_dim)."""
        noise = torch.empty(input.shape[0], self.latent_dim,
                            device=input.device).normal_(mean=0, std=self.noise_std)
        return self.main(torch.cat((noise, input), dim=1))


class Critic(nn.Module):
    """Critic embedding h(.) used by the Cramer distance.

    It maps full rows (all data columns except the trailing weight) to an
    ``out_dim``-dimensional vector.

    Args:
        in_dim: input width (default ``data_train.shape[1] - 1``).
        out_dim: embedding width (default ``CRAMER_DIM``).
        num_layers: hidden Linear+ReLU layers (default ``NUM_LAYERS``).
        hidden_dim: width of every hidden layer.
    """

    def __init__(self, in_dim=None, out_dim=None, num_layers=None, hidden_dim=128):
        super(Critic, self).__init__()
        if in_dim is None:
            in_dim = data_train.shape[1] - 1
        if out_dim is None:
            out_dim = CRAMER_DIM
        if num_layers is None:
            num_layers = NUM_LAYERS
        self.main = _make_mlp(in_dim, out_dim, hidden_dim, num_layers)

    def forward(self, input):
        """Return the embedding h(input), shape (batch, out_dim)."""
        return self.main(input)

def init_weights(m):
    """Initialise an ``nn.Linear`` (or subclass) with Xavier-uniform weights and 0.01 bias.

    Meant for ``module.apply(init_weights)``. Modules that are not Linear layers
    are left untouched.
    """
    if isinstance(m, nn.Linear):
        # `xavier_uniform` (no trailing underscore) is deprecated; use the in-place API.
        nn.init.xavier_uniform_(m.weight)
        # Linear(bias=False) has bias None; the previous code crashed on it.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.01)

In [11]:
device = torch.device("cpu")

# Create both networks on `device`, then re-initialise their Linear layers with
# init_weights: critic first, then generator.
netG = Generator().to(device)
netC = Critic().to(device)
for _net in (netC, netG):
    _net.apply(init_weights)

print('Ok')

Ok




In [12]:
# RMSprop for both players. Each optimiser gets an ExponentialLR scheduler that
# multiplies its learning rate by 0.98 on every scheduler .step() call.
optC = torch.optim.RMSprop(netC.parameters(), lr=1e-3)
optG = torch.optim.RMSprop(netG.parameters(), lr=1e-3)
lr_C = torch.optim.lr_scheduler.ExponentialLR(optC, gamma=0.98)
lr_G = torch.optim.lr_scheduler.ExponentialLR(optG, gamma=0.98)

## Train

In [13]:
LOGDIR = "./log"              # NOTE: the training loop passes this to torch.save as a *file* path
CRITIC_ITERATIONS_CONST = 15
CRITIC_ITERATIONS_VAR = 0
TOTAL_ITERATIONS = 100000
VALIDATION_INTERVAL = 1000


def critic_policy(i):
    """Number of critic updates to run before the generator update of iteration ``i``.

    The count starts at ``CRITIC_ITERATIONS_CONST + CRITIC_ITERATIONS_VAR``. It
    decays linearly (integer division) towards ``CRITIC_ITERATIONS_CONST`` as
    ``i`` approaches ``TOTAL_ITERATIONS``.
    """
    remaining = TOTAL_ITERATIONS - i
    extra = (CRITIC_ITERATIONS_VAR * remaining) // TOTAL_ITERATIONS
    return CRITIC_ITERATIONS_CONST + extra

In [14]:
def cramer_critic(x, y, critic=None):
    """Row-wise Cramer-GAN critic f(x) = ||h(x) - h(y)||_2 - ||h(x)||_2.

    Args:
        x: batch to score, one sample per row.
        y: reference batch (a generated sample in the training loop), same row count as ``x``.
        critic: the embedding h; ``None`` uses the notebook-level ``netC`` (previous behaviour).

    Returns:
        1-D tensor with one score per row.
    """
    h = netC if critic is None else critic
    h_x = h(x)  # computed once, used in both norms
    return torch.norm(h_x - h(y), dim=1) - torch.norm(h_x, dim=1)

In [15]:
def lambda_pt(i):
    """Gradient-penalty weight for iteration ``i``: (20 / pi) * 2 * arctan(i / 1e4).

    The weight is 0 at ``i = 0``, 10 at ``i = 1e4``, and approaches 20 as ``i``
    grows. The result is a float32 scalar tensor on ``device``.
    """
    step = torch.tensor(i, dtype=torch.float32, device=device)
    return 20 / np.pi * 2 * torch.atan(step / 1e4)

In [16]:
N_VAL = int(3e5)

# Random subset of the validation split used for the diagnostic histograms,
# stored as a tensor on `device`. No autograd context is needed: nothing here
# records a graph.
validation_np = data_val.sample(N_VAL).values
val = torch.tensor(validation_np, device=device)

In [17]:
# SECURITY: never commit API keys. The key that was hard-coded here has been
# exposed and should be revoked.
# Supply the key through the COMET_API_KEY environment variable instead.
# NOTE(review): with api_key=None, comet_ml is expected to fall back to its own
# configuration lookup (environment / config file). Confirm for the installed version.
experiment = Experiment(api_key=os.environ.get("COMET_API_KEY"),
                        project_name="general", workspace="nzinci")

COMET INFO: Experiment is live on comet.ml https://www.comet.ml/nzinci/general/26b1d03df60040f2af6d8090077853aa



In [None]:
def _infinite_batches(loader):
    """Yield batches from ``loader`` forever, starting a new reshuffled pass when one ends.

    Calling ``next(iter(loader))`` per batch built a new iterator every time. With
    shuffle=True, that meant a fresh permutation of the whole training set per batch.
    """
    while True:
        for batch in loader:
            yield batch


def _next_batch(batches):
    """Return the next (full, w_full, x_1, w_x_1, x_2, w_x_2) batch, moved to ``device``."""
    return tuple(t.to(device) for t in next(batches))


def _with_generated(x):
    """Return ``[netG(x), x]``: generated outputs followed by their conditioning features."""
    return torch.cat((netG(x), x), dim=1)


def _cramer_objective(train_full, w_full, gen_full_1, w_x_1, gen_full_2, w_x_2):
    """Weighted Cramer objective: mean(f(real)*w_full*w_x_2 - f(gen_1)*w_x_1*w_x_2).

    The critic maximises it (minus the gradient penalty); the generator minimises it.
    """
    return torch.mean(cramer_critic(train_full, gen_full_2) * w_full * w_x_2
                      - cramer_critic(gen_full_1, gen_full_2) * w_x_1 * w_x_2)


def _critic_step(i, batches):
    """Run one critic update on a fresh batch and return the critic loss tensor."""
    train_full, w_full, train_x_1, w_x_1, train_x_2, w_x_2 = _next_batch(batches)
    # Only the critic is updated here, so generated samples are treated as constants.
    with torch.no_grad():
        gen_full_1 = _with_generated(train_x_1)
        gen_full_2 = _with_generated(train_x_2)

    optC.zero_grad()
    objective = _cramer_objective(train_full, w_full, gen_full_1, w_x_1, gen_full_2, w_x_2)

    # One-sided gradient penalty at random points on the real <-> generated
    # segments, so alpha ~ U[0, 1].
    alpha = torch.rand(train_full.shape[0], 1, device=device)
    interpolates = (alpha * train_full + (1.0 - alpha) * gen_full_1).requires_grad_(True)
    disc_interpolates = cramer_critic(interpolates, gen_full_2)
    # create_graph=True keeps the penalty differentiable w.r.t. the critic weights.
    # Without it, the penalty term contributes no gradient at all.
    gradients = grad(outputs=disc_interpolates, inputs=interpolates,
                     grad_outputs=torch.ones_like(disc_interpolates),
                     create_graph=True)[0]
    # Per-sample gradient norm: one slope per row of the batch.
    slopes = torch.norm(gradients.reshape(gradients.shape[0], -1), dim=1)
    gradient_penalty = torch.mean(torch.clamp(slopes - 1.0, min=0.0) ** 2)

    critic_loss = lambda_pt(i) * gradient_penalty - objective
    critic_loss.backward()
    optC.step()
    return critic_loss


def _generator_step(batches):
    """Run one generator update on a fresh batch and return the generator loss tensor."""
    train_full, w_full, train_x_1, w_x_1, train_x_2, w_x_2 = _next_batch(batches)
    optG.zero_grad()
    # Generate from *this* batch's conditions. Reusing the last critic step's
    # samples paired stale generations with unrelated real rows and weights.
    gen_full_1 = _with_generated(train_x_1)
    gen_full_2 = _with_generated(train_x_2)
    generator_loss = _cramer_objective(train_full, w_full, gen_full_1, w_x_1, gen_full_2, w_x_2)
    # backward() also leaves gradients on netC; optC.zero_grad() clears them
    # before the next critic update.
    generator_loss.backward()
    optG.step()
    return generator_loss


def _save_checkpoint():
    """Write model, optimiser and scheduler state to the file path LOGDIR."""
    torch.save({'netC_state_dict': netC.state_dict(),
                'netG_state_dict': netG.state_dict(),
                'optC_state_dict': optC.state_dict(),
                'optG_state_dict': optG.state_dict(),
                'lr_C_state_dict': lr_C.state_dict(),
                'lr_G_state_dict': lr_G.state_dict()
               }, LOGDIR)


def _plot_validation():
    """Plot real vs generated validation histograms, log the figure to Comet, and show it."""
    with torch.no_grad():
        y_t = netG(val[:, utils_rich.y_count:-1])
    weights = val[:, -1].cpu().numpy()
    fig, axes = plt.subplots(2, 2, figsize=(15, 15))
    for col, ax in zip((0, 1, 3, 4), axes.flatten()):
        # `normed` was removed from Matplotlib's hist(); `density` replaces it.
        _, bins, _ = ax.hist(val[:, col].cpu().numpy(), bins=100, label="data",
                             density=True, weights=weights)
        ax.hist(y_t[:, col].cpu().numpy(), bins=bins, label="generated", alpha=0.5,
                density=True, weights=weights)
        ax.legend()
        ax.set_title(utils_rich.dll_columns[col])
    experiment.log_figure()
    plt.show()
    plt.close(fig)


batches = _infinite_batches(all_dataloader)
critic_loss = torch.tensor(float("nan"))  # defined even if critic_policy(i) == 0
for i in tqdm(range(TOTAL_ITERATIONS), position=0, leave=True):
    for _ in range(critic_policy(i)):
        critic_loss = _critic_step(i, batches)
    generator_loss = _generator_step(batches)
    experiment.log_metrics({'Generator loss': generator_loss.item(),
                            'Critic loss': critic_loss.item()},
                           step=i)

    # Periodic work runs once per VALIDATION_INTERVAL.
    # - Decaying the LR by 0.98 on *every* iteration drove it below 1e-11 within
    #   the first 1000 of the TOTAL_ITERATIONS iterations.
    # - Plotting and checkpointing every iteration dominated run time.
    if (i + 1) % VALIDATION_INTERVAL == 0:
        lr_C.step()
        lr_G.step()
        _save_checkpoint()
        clear_output(False)
        _plot_validation()
experiment.end()

In [None]:
# Final diagnostic: weighted, density-normalised histograms of real vs generated
# values on the validation sample for columns 0, 1, 3 and 4, each titled from
# utils_rich.dll_columns.
with torch.no_grad():
    y_t = netG(val[:, utils_rich.y_count:-1])

weights_np = val[:, -1].cpu().numpy()
fig, axes = plt.subplots(2, 2, figsize=(15, 15))
for col, ax in zip((0, 1, 3, 4), axes.flatten()):
    # `normed` was removed from Matplotlib's hist(); `density` is its replacement.
    _, bins, _ = ax.hist(val[:, col].cpu().numpy(), bins=100, label="data",
                         density=True, weights=weights_np)
    ax.hist(y_t[:, col].cpu().numpy(), bins=bins, label="generated", alpha=0.5,
            density=True, weights=weights_np)
    ax.legend()
    ax.set_title(utils_rich.dll_columns[col])
plt.show()