In [11]:
from pathlib import Path
from importlib import reload

import pandas as pd
import numpy as np
from tqdm.auto import trange, tqdm
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import QuantileTransformer, MinMaxScaler
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data as data
import sys
sys.path.insert(0, '/app')
import lib.collections as lc
from torchvision.transforms import Compose, ToTensor
from torch.utils.tensorboard import SummaryWriter

In [15]:
# Scratch check: a tuple of section sizes splits the 10-element tensor into
# chunks of exactly those lengths (here 5 and 5).
torch.split(torch.ones(10), split_size_or_sections=(5, 5))

(tensor([1., 1., 1., 1., 1.]), tensor([1., 1., 1., 1., 1.]))

In [3]:
from models.gans import GAN, CGAN, WGAN, MLPDiscriminator, MLPGenerator

In [4]:
# Run configuration shared by the cells below.
seed = 42          # passed as random_state to the split / transformer and used to seed the RNGs
batch_size = 256   # mini-batch size for the DataLoaders
epoch_num = 30     # epoch count for the training run

# Seed every global RNG the notebook's imports expose, not only torch's,
# so NumPy-based randomness is reproducible as well.
torch.manual_seed(seed)
np.random.seed(seed)

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
# Directory with the input tables; the particle name is taken from the second
# '_'-separated token of each file's stem.
data_dir = Path('/_data/data_csv/')

# Read only regular, non-hidden files: a bare iterdir() also yields
# sub-directories and dot-files (e.g. checkpoint folders), which would break
# the name parsing or the CSV reader. Sorting makes the load order stable and
# gives tqdm a known total.
data_files = sorted(
    path for path in data_dir.iterdir()
    if path.is_file() and not path.name.startswith('.')
)
if not data_files:
    # Fail here rather than with a KeyError in a later cell.
    raise FileNotFoundError(f'no data files found in {data_dir}')

# dataframes[particle] -> {'train': DataFrame, 'test': DataFrame}
dataframes = {}
for file in tqdm(data_files):
    name = file.stem.split('_')[1]
    train, test = train_test_split(
        pd.read_csv(file.as_posix(), dtype=np.float32),
        random_state=seed,
    )
    dataframes[name] = {'train': train, 'test': test}

# Columns fed to the models as conditioning inputs and as generation targets.
condition_cols = ['TrackP', 'TrackEta', 'NumLongTracks']
target_cols = ['RichDLLbt', 'RichDLLk', 'RichDLLmu', 'RichDLLp', 'RichDLLe']

# One fitted transformer per particle, kept so the mapping can be reused later.
transformers = dict()

for particle in tqdm(dataframes):
    tr = QuantileTransformer(output_distribution='normal', random_state=seed)
    splits = dataframes[particle]

    # Fit on the training split only, then apply the same mapping to the test
    # split. New frames are built instead of assigning through `frame[:]`, so
    # the objects returned by train_test_split are never written to in place.
    train_values = tr.fit_transform(splits['train']).astype(np.float32)
    test_values = tr.transform(splits['test']).astype(np.float32)

    splits['train'] = pd.DataFrame(
        train_values, index=splits['train'].index, columns=splits['train'].columns
    )
    splits['test'] = pd.DataFrame(
        test_values, index=splits['test'].index, columns=splits['test'].columns
    )
    transformers[particle] = tr
    

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6), HTML(value='')))




In [6]:
# Network sizes. The condition/target widths are derived from the column lists
# so they cannot drift from the data actually fed to the model.
latent_dim = 50
condition_dim = len(condition_cols)
target_dim = len(target_cols)
d_hidden_dims = [32, 64, 128, 128]
g_hidden_dims = [32, 64, 128, 128]

# Prefer the first GPU, but fall back to the CPU instead of failing on a host
# without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
generator = MLPGenerator(latent_dim, condition_dim, g_hidden_dims, target_dim,).to(device)
discriminator = MLPDiscriminator(target_dim, condition_dim, d_hidden_dims).to(device)

# Both networks use Adam with lr=1e-4 and betas=(0, 0.9).
# NOTE(review): these, together with lambda_gp=10, look like the usual WGAN-GP
# settings - confirm against models.gans.WGAN.
generator_opt = optim.Adam(generator.parameters(), lr=1e-4, betas=(0, 0.9))
discriminator_opt = optim.Adam(discriminator.parameters(), lr=1e-4, betas=(0, 0.9))


model = WGAN(
    generator,
    discriminator,
    generator_opt,
    discriminator_opt,
    lambda_gp=10
)

In [7]:
# datasets[particle][phase]: an lc.Dataset built from the condition columns'
# values (first argument) and the target columns' values (second argument).
datasets = {}
for particle, splits in dataframes.items():
    datasets[particle] = {
        phase: lc.Dataset(frame[condition_cols].values, frame[target_cols].values)
        for phase, frame in splits.items()
    }


# dataloaders[particle][phase]: a DataLoader over the matching dataset. Every
# phase shares the same options.
# NOTE(review): shuffle=True also applies to the 'test' phase - confirm this is intended.
loader_options = dict(
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)
dataloaders = {}
for particle, phase_datasets in datasets.items():
    dataloaders[particle] = {
        phase: data.DataLoader(dataset, **loader_options)
        for phase, dataset in phase_datasets.items()
    }

In [10]:
# Train on the kaon loaders, logging to TensorBoard. The writer is used as a
# context manager so it is closed (and its pending events flushed) even if
# training raises.
# NOTE(review): start_epoch=4 suggests this resumes an earlier run, but no
# checkpoint load is visible in this notebook - confirm before a fresh run.
with SummaryWriter('/_data/richgan/runs/wgan_gp') as writer:
    model.train(
        dataloaders['kaon'],
        writer=writer,
        start_epoch=4,
        num_epochs=epoch_num,  # config constant instead of a repeated literal 30
        log_grad_norms=True,
        plot_dists=True,
    )

HBox(children=(IntProgress(value=0, max=30), HTML(value='')))

HBox(children=(IntProgress(value=0, description='train', max=2930, style=ProgressStyle(description_width='init…

HBox(children=(IntProgress(value=0, description='test', max=2930, style=ProgressStyle(description_width='initi…

Epoch  4 score:  0.05043620174480701
Epoch  5 score:  0.046592186368745514
Epoch  6 score:  0.04004816019264079
Epoch  7 score:  0.06681626726506906
Epoch  8 score:  0.10805643222572892
Epoch  9 score:  0.17062868251473007
Epoch  10 score:  0.04591218364873462
Epoch  11 score:  0.04106016424065695
Epoch  12 score:  0.04225216900867601
Epoch  13 score:  0.04235216940867759
Epoch  14 score:  0.04517218068872275
Epoch  15 score:  0.040912163648654604
Epoch  16 score:  0.03532814131256523
Epoch  17 score:  0.04719618878475512
Epoch  18 score:  0.031032124128496513
Epoch  19 score:  0.025812103248413043
Epoch  20 score:  0.021432085728342876
Epoch  21 score:  0.037376149504598
Epoch  22 score:  0.023360093440373764
Epoch  23 score:  0.031860127440509756
Epoch  24 score:  0.03731214924859699
Epoch  25 score:  0.02448809795239182
Epoch  26 score:  0.0480401921607686
Epoch  27 score:  0.050288201152804635
Epoch  28 score:  0.0387321549286197
Epoch  29 score:  0.030040120160480743
Epoch  30 sco