In [1]:
from pathlib import Path
from importlib import reload
import matplotlib.pyplot as plt
from collections import defaultdict

import pandas as pd
import numpy as np
from tqdm.auto import trange, tqdm
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import QuantileTransformer
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import sys
sys.path.insert(0, '/app')
from torch.utils.tensorboard import SummaryWriter
from lib.utils import plot_hist, score_func

In [2]:
from models.model import Model
import models.gans as gan
import models.vae as vae

In [3]:
# Global run configuration.
seed = 52  # shared by train_test_split, the QuantileTransformers and the RNGs below
logdir = Path('/_data/richgan/runs')  # TensorBoard root; runs go to logdir/tag/<particle>
tag = 'cramer'  # experiment tag used in log and picture output paths

# Seed the global RNGs too. Previously `seed` only reached sklearn, so torch weight
# initialisation, training noise and any unseeded numpy randomness (possibly
# inside score_func -- TODO confirm) differed on every run.
torch.manual_seed(seed)
np.random.seed(seed)
# Helper function that builds identically configured models for each particle.
def get_model(device=None):
    """Build a fresh CramerGAN wrapped in ``Model`` for a single particle type.

    Every call creates new generator/critic networks and optimisers with the
    same architecture and hyper-parameters, so each particle gets its own
    independently trained but identically configured model.

    Parameters
    ----------
    device : torch.device or str, optional
        Device to place both networks on. When omitted, ``cuda:0`` is used if
        CUDA is available, otherwise the CPU. (Previously ``cuda:0`` was
        hard-coded, so the cell failed on machines without a GPU.)

    Returns
    -------
    Model
        Wrapper around the CramerGAN, given two ``QuantileTransformer``s
        (normal output, seeded with ``seed``) -- presumably one for the
        conditions and one for the targets, TODO confirm against
        ``models.model.Model`` -- with error-code simulation enabled.
    """
    latent_dim = 16  # size of the generator's latent input
    condition_dim = 3  # matches condition_cols: TrackP, TrackEta, NumLongTracks
    target_dim = 5  # matches target_cols: the five RichDLL* columns
    d_hidden_dims = [64, 64, 128, 128]
    g_hidden_dims = [64, 64, 128, 128]
    output_dim = 128  # critic output size (presumably the Cramer critic embedding dim)

    if device is None:
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    generator = gan.MLPGenerator(latent_dim, condition_dim, g_hidden_dims, target_dim).to(device)
    discriminator = gan.MLPDiscriminator(target_dim, condition_dim, d_hidden_dims, output_dim).to(device)

    # betas=(0, 0.9): Adam without first-moment momentum, identical for both nets.
    generator_opt = optim.Adam(generator.parameters(), lr=1e-4, betas=(0, 0.9))
    discriminator_opt = optim.Adam(discriminator.parameters(), lr=1e-4, betas=(0, 0.9))

    model = gan.CramerGAN(
        generator,
        discriminator,
        generator_opt,
        discriminator_opt,
        lambda_gp=10,  # gradient-penalty weight, judging by the name
    )

    return Model(
        model,
        QuantileTransformer(output_distribution='normal', random_state=seed),
        QuantileTransformer(output_distribution='normal', random_state=seed),
        simulate_error_codes=True,
    )


In [4]:
data_dir = Path('/_data/data_csv/')

# Load each per-particle CSV once and split it into train/test (sklearn's default
# 75/25 split, seeded). The particle key is the second '_'-separated token of the
# file stem, e.g. something like "data_proton..." -> "proton".
dataframes = {}
# sorted() gives a deterministic load order, and the '*.csv' filter skips stray
# non-CSV entries (e.g. .ipynb_checkpoints) that previously broke
# pd.read_csv / split('_')[1] when iterdir() returned them.
for file in tqdm(sorted(data_dir.glob('*.csv'))):
    name = file.stem.split('_')[1]
    full_frame = pd.read_csv(file.as_posix(), dtype=np.float32)
    train, test = train_test_split(full_frame, random_state=seed)
    dataframes[name] = {'train': train, 'test': test}

# Conditioning inputs fed to the generator.
condition_cols = ['TrackP', 'TrackEta', 'NumLongTracks']
# RICH DLL columns the generator has to reproduce.
target_cols = ['RichDLLbt', 'RichDLLk', 'RichDLLmu', 'RichDLLp', 'RichDLLe']

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




In [None]:
# Train one CramerGAN per particle type and evaluate it on that particle's test split.
#   models[particle]     -> fitted Model
#   scores[particle]     -> score_func(generated, reference, n_slices=1000)
#   figs[particle][col]  -> plot_hist figure for each target column
models = dict()
figs = defaultdict(dict)
scores = dict()

for particle in tqdm(['proton', 'muon', 'kaon', 'pion']):

    # Conditions are the model inputs; targets are the columns it must generate.
    c_train = dataframes[particle]['train'][condition_cols]
    x_train = dataframes[particle]['train'][target_cols]
    c_test = dataframes[particle]['test'][condition_cols]
    x_test = dataframes[particle]['test'][target_cols]

    model = get_model()
    # One TensorBoard run directory per particle: <logdir>/<tag>/<particle>.
    writer = SummaryWriter(log_dir=Path(logdir, tag, particle))
    try:
        model.fit(
            c_train,
            x_train,
            start_epoch=0,
            num_epochs=50,
            n_critic=1,
            batch_size=512,
            writer=writer,
            num_workers=6,
        )
    finally:
        # The writer used to be created inline and never closed, leaving its event
        # file open and possibly not fully flushed. Close it even if training fails.
        writer.close()
    models[particle] = model

    predicted = model.predict(c_test)
    # Real and generated targets are both scored together with their conditions
    # (columns concatenated with np.c_).
    reference = np.c_[x_test.values, c_test.values]
    generated = np.c_[predicted.values, c_test.values]
    score = score_func(generated, reference, n_slices=1000)
    scores[particle] = score
    print(particle, ': ', score)
    for col in target_cols:
        print(col)
        fig = plot_hist(x_test[col].values, predicted[col].values)
        figs[particle][col] = fig
        display(fig)
    print('='*100)
        

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=50), HTML(value='')))

HBox(children=(IntProgress(value=0, description='train', max=1444, style=ProgressStyle(description_width='init…

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



In [None]:
# Re-evaluate every trained model on test rows whose targets are NOT pure error
# codes, with error-code simulation switched off during prediction.
figs_filtered = defaultdict(dict)
scores_filtered = dict()
for particle in tqdm(['proton', 'muon', 'kaon', 'pion']):

    c_test = dataframes[particle]['test'][condition_cols]
    x_test = dataframes[particle]['test'][target_cols]
    # A row is dropped when all of its targets are -999 or all of them are 0.
    all_minus_999 = (x_test == -999).values.all(axis=1)
    all_zero = (x_test == 0).values.all(axis=1)
    error_rows = all_minus_999 | all_zero
    c_test = c_test[~error_rows]
    x_test = x_test[~error_rows]

    model = models[particle]
    # Disable error-code simulation only for this prediction. Previously the flag
    # stayed False on the stored model, so re-running the earlier evaluation (or
    # anything else using `models`) silently behaved differently.
    previous_flag = model.simulate_error_codes
    model.simulate_error_codes = False
    try:
        predicted = model.predict(c_test)
    finally:
        model.simulate_error_codes = previous_flag

    reference = np.c_[x_test.values, c_test.values]
    generated = np.c_[predicted.values, c_test.values]
    score = score_func(generated, reference, n_slices=1000)
    scores_filtered[particle] = score
    print(particle, ': ', score)
    for col in target_cols:
        print(col)
        fig = plot_hist(x_test[col].values, predicted[col].values)
        figs_filtered[particle][col] = fig
        display(fig)
    print('='*100)
        

In [None]:
# Per-particle score_func results on the full test split (error-code simulation on).
scores

In [None]:
# Per-particle score_func results on test rows without all -999 / all 0 targets,
# predicted with error-code simulation disabled.
scores_filtered

In [None]:
# Save each per-column histogram as /_data/richgan/pics/<tag>/<particle>/<col>.png.
for particle, particle_figs in figs.items():
    # Create the output directory once per particle rather than once per column.
    p = Path(f'/_data/richgan/pics/{tag}', particle)
    p.mkdir(parents=True, exist_ok=True)
    for col, fig in particle_figs.items():
        fig.savefig(Path(p, col).with_suffix('.png').as_posix(), format='png')

In [None]:
# Save each filtered histogram as /_data/richgan/pics/<tag>/<particle>/<col>-filtered.png.
for particle, particle_figs in figs_filtered.items():
    # Create the output directory once per particle rather than once per column.
    p = Path(f'/_data/richgan/pics/{tag}', particle)
    p.mkdir(parents=True, exist_ok=True)
    for col, fig in particle_figs.items():
        fig.savefig(Path(p, f'{col}-filtered').with_suffix('.png').as_posix(), format='png')