In [8]:
# Enable IPython's autoreload extension in mode 2 so that edits to the imported
# project modules are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [1]:
import pysr
from pysr import PySRRegressor

import sys
import os
import argparse
import glob

from handlers import trainer, evaluation, annealing, sr_trainer
from handlers.args import setup_argparse

from ml_utils import losses
from ml_utils import surrogates
from ml_utils.optimizers import optim

from preprocessing.dataloaders import train_load, test_load
from preprocessing.datasets import SimpleIterDataset

from postprocessing.io_writer import _write_outputs_to_root

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import mplhep as mh
plt.style.use(mh.style.CMS)

from importlib.util import spec_from_file_location, module_from_spec

from weaver.utils.logger import _logger, warn_n_times, _configLogger
import copy
from pprint import pformat
import time

from main import import_module, assemble_loaders, initialize_models

Detected IPython. Loading juliacall extension. See https://juliapy.github.io/PythonCall.jl/stable/compat/#IPython


In [2]:
# Project root, read from the WORKDIR environment variable.
# The value is None when WORKDIR is not set.
workdir = os.environ.get('WORKDIR')

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [3]:
class Args:
    """Plain attribute container for the run options used in this notebook.

    Every option starts at the default listed in ``__init__``. Each keyword
    argument either replaces the default of the same name or, when the name
    is not among the defaults, is stored as an additional attribute.

    NOTE(review): presumably mirrors the command-line options that the
    project's argument parser produces -- verify against ``handlers.args``.
    """

    def __init__(self, **kwargs):
        # Built on every call so that the list-valued defaults are fresh
        # objects and never shared between instances.
        defaults = {
            'data_train': [],
            'data_test': [],
            'data_val': [],
            'num_workers': 0,
            'num_epochs': 0,
            'data_config': '',
            'file_fraction': 1,
            'data_fraction': 1,
            'batch_size': 0,
            'local_rank': None,
            'model_prefix': None,
            'lr_finder': None,
            'optimizer_option': [],
            'optimizer': 'ranger',
            'start_lr': 1e-3,
            'final_lr': 1e-6,
            'lr_scheduler': 'flat+decay',
            'kl_weight': 0.1,
            'class_weight': 1.0,
            'kl_anneal': False,
            'alpha': 0,
            'beta': 0,
            'gamma': 0,
            'bit_size': None,
        }

        # Overrides keep the slot of the default they replace; names that are
        # not among the defaults are appended after them.
        for option, setting in {**defaults, **kwargs}.items():
            setattr(self, option, setting)

In [4]:
# Locations of the data configuration and of the model wrapper, both resolved
# relative to `workdir`.
yaml_config = f'{workdir}/data_config/JetClass/JetClass_TTBar.yaml'
encoder_path = f'{workdir}/wrappers/vae.py'

# One directory per data split.
jc_paths = {
    'train': f'{workdir}/datasets/JetClass/Pythia/train_100M',
    'val': f'{workdir}/datasets/JetClass/Pythia/val_5M',
    'test': f'{workdir}/datasets/JetClass/Pythia/test_20M'
}

num_classes = 2  # presumably signal vs. background -- TODO confirm
# Glob patterns appended to each split directory. Either a single pattern
# string or a list of pattern strings is accepted.
signal = '/TTBar_*.root'
background = '/ZJetsToNuNu_*.root'


def _matching_files(directory, patterns):
    """Return the files under `directory` that match `patterns`.

    Args:
        directory: Path prefix that each pattern is appended to.
        patterns: A single glob pattern string, or an iterable of them.

    Returns:
        A list of matching paths, sorted within each pattern so the result
        is reproducible; `glob.glob` itself gives no ordering guarantee.
        The list is empty when nothing matches.
    """
    if isinstance(patterns, str):
        patterns = [patterns]

    matches = []
    for pattern in patterns:
        matches.extend(sorted(glob.glob(directory + pattern)))
    return matches


# Per split: signal files first, then background files.
datasets = {}

for name, path in jc_paths.items():
    signal_files = _matching_files(path, signal)
    background_files = _matching_files(path, background)

    datasets[name] = signal_files + background_files

# NOTE(review): not used in the cells visible here; meaning of the values is
# not shown -- confirm against the model wrapper.
complexity = {
    'particle_attn': 2,
    'class_attn': 1
}

In [7]:
# Run options for this session; anything not listed keeps the Args default.
args = Args(
    # input files and their configuration
    data_config=yaml_config,
    data_train=datasets['train'],
    data_val=datasets['val'],
    data_test=datasets['test'],
    file_fraction=1,
    data_fraction=0.01,
    # model definition and output prefix
    model_network=encoder_path,
    model_prefix='throaway_folder/VAE',
    # training schedule
    num_epochs=3,
    batch_size=128,
    # loss settings
    kl_anneal=True,
    alpha=1,
    beta=4,
    gamma=1,
    bit_size=1,
)

In [8]:
# Create the model objects and the data loaders from the run options in `args`,
# using the project helpers imported from `main`.
# NOTE(review): the meaning of the positional `True` is not visible here --
# check the signature of `main.initialize_models`.
vae_dict = initialize_models(args, True, encoder_path)
loader_dict = assemble_loaders(args)

In [9]:
# NOTE(review): this import would fit better in the imports cell, next to the
# other `from main import ...` names.
from main import vae

# Run `main.vae` with the options, loaders and models created in the cells above.
# The recorded output below shows train/val progress bars with the loss terms
# and ends in a KeyboardInterrupt, i.e. the run was stopped before finishing.
vae(args, loader_dict, vae_dict)

0it [00:00, ?it/s]

=== Restarting DataIter train, seed=None ===


1562it [07:17,  3.57it/s, lr=1.00e-03, Loss=2.04584, MIL=-0.00269, TCL=0.00423, DWKL=1.00000, Recon Loss=1.03160, Beta=4.00, AvgMIL=0.01382, AvgTCL=0.00446, AvgDWKL=1.00970, AvgReconLoss=1.43530]
0it [00:00, ?it/s]

=== Restarting DataIter val, seed=None ===


78it [00:10,  7.71it/s, Loss=2.24797, MIL=0.01358, TCL=0.08951, DWKL=1.00000, ReconLoss=1.22587, Beta=0.10, AvgMIL=0.01664, AvgTCL=0.08949, AvgDWKL=1.00000, AvgReconLoss=1.11992, Avg Loss=2.13943] 


[0.00048754012, 3.6425285e-05, 0.00017259546, 0.0042684427, 0.000472121, 0.0040507587, 4.41674e-05, 0.00032635452, 0.00019425122, 0.00022657446, 0.0008704075, 0.0026241133, 0.0005513293, 5.1851282e-05, 0.00055494154, 0.00056339824]


0it [00:00, ?it/s]

=== Restarting DataIter train, seed=None ===


35it [00:21,  1.65it/s, lr=1.00e-03, Loss=2.10526, MIL=0.02515, TCL=0.00434, DWKL=1.00000, Recon Loss=1.07924, Beta=0.20, AvgMIL=0.01374, AvgTCL=0.00446, AvgDWKL=1.00969, AvgReconLoss=1.42865] 


KeyboardInterrupt: 