In [1]:
# NOTE: the relative order of the imports below is kept as-is on purpose:
# the wildcard imports can shadow one another, so reordering them could
# silently change which definition a name resolves to.
import logging
import os
import warnings

import copy
import yaml
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import *

import nltk
import itertools
import pickle
import random
import joblib
import numpy as np

from holodecml.torch.utils import *
from holodecml.torch.losses import *
from holodecml.torch.visual import *
from holodecml.torch.models import *
from holodecml.torch.trainers import *
from holodecml.torch.transforms import *
from holodecml.torch.optimizers import *
from holodecml.torch.data_loader import *
from holodecml.torch.beam_search import *

from aimlutils.hyper_opt.base_objective import *
from aimlutils.torch.checkpoint import *
#from aimlutils.torch.losses import *
from aimlutils.utils.tqdm import *

from typing import List, Callable, Tuple, Dict, Union

In [2]:
# Configure the root logger for the notebook. The root logger itself accepts
# everything from DEBUG upwards; the stream handler attached below only emits
# records at INFO and above.
root = logging.getLogger()
root.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(levelname)s:%(name)s:%(message)s')

# Stream output to the console. logging.StreamHandler() without an argument
# writes to sys.stderr (not stdout).
# The handler is tagged with a name and reused when it is already attached,
# so re-running this cell does not add a second handler (which would print
# every log record twice).
STREAM_HANDLER_NAME = "notebook-stream"
ch = next((h for h in root.handlers if h.get_name() == STREAM_HANDLER_NAME), None)
if ch is None:
    ch = logging.StreamHandler()
    ch.set_name(STREAM_HANDLER_NAME)
    root.addHandler(ch)
ch.setLevel(logging.INFO)
ch.setFormatter(formatter)

In [3]:
# Pick the compute device: the currently selected CUDA device when CUDA is
# available, otherwise the CPU.
is_cuda = torch.cuda.is_available()
if is_cuda:
    device = torch.device(torch.cuda.current_device())
else:
    device = torch.device("cpu")

In [4]:
# Parse the experiment configuration from model.yml (resolved relative to the
# current working directory) into `conf`.
# NOTE(review): FullLoader can build more than plain scalars/lists/dicts; if
# model.yml only holds plain data, yaml.safe_load would be the stricter
# choice -- confirm before switching.
with open("model.yml", "r") as config_file:
    conf = yaml.load(stream=config_file, Loader=yaml.FullLoader)

### Load data readers

In [6]:
# Build the training and validation transformation pipelines from their
# respective config sections (training first, then validation), both bound
# to the selected device.
train_transform, valid_transform = (
    LoadTransformations(conf[section], device = device)
    for section in ("train_transforms", "validation_transforms")
)

INFO:holodecml.torch.transforms:Loaded RandomVerticalFlip transformation with probability 0.5
INFO:holodecml.torch.transforms:Loaded RandomHorizontalFlip transformation with probability 0.5
INFO:holodecml.torch.transforms:Loaded Normalize transformation that normalizes data in the range 0 to 1
INFO:holodecml.torch.transforms:Loaded ToTensor transformation, putting tensors on device cuda:0
INFO:holodecml.torch.transforms:Loaded Normalize transformation that normalizes data in the range 0 to 1
INFO:holodecml.torch.transforms:Loaded ToTensor transformation, putting tensors on device cuda:0


In [7]:
# Path of the persisted scaler file, located inside the trainer's save
# directory taken from the config.
trainer_save_dir = conf["trainer"]["path_save"]
scaler_path = os.path.join(trainer_save_dir, "scalers.save")

In [8]:
# Check once whether a scaler was persisted by a previous run, so that the
# "load" decision and the "save" decision below can never disagree.
scaler_is_cached = os.path.isfile(scaler_path)

# Reuse the persisted scaler when available; otherwise pass True.
# NOTE(review): True presumably asks the reader to fit a fresh scaler --
# confirm against LoadReader.
# NOTE(review): joblib.load unpickles the file, so scaler_path must point to
# a trusted location.
train_gen = LoadReader(
    transform = train_transform,
    scaler = joblib.load(scaler_path) if scaler_is_cached else True,
    config = conf["train_data"]
)

# First run: persist the reader's scaler so later runs reuse it. Create the
# target directory first, otherwise joblib.dump fails on a missing directory
# only after the data has already been loaded.
if not scaler_is_cached:
    scaler_dir = os.path.dirname(scaler_path)
    if scaler_dir:
        os.makedirs(scaler_dir, exist_ok=True)
    joblib.dump(train_gen.scaler, scaler_path)

INFO:holodecml.torch.data_loader:Loading reader-type multi
INFO:holodecml.torch.data_loader:Loaded data scaler transformation {'x': StandardScaler(copy=True, with_mean=True, with_std=True), 'y': StandardScaler(copy=True, with_mean=True, with_std=True), 'z': StandardScaler(copy=True, with_mean=True, with_std=True), 'd': StandardScaler(copy=True, with_mean=True, with_std=True)}
INFO:holodecml.torch.data_loader:Loaded ['/glade/p/cisl/aiml/ai4ess_hackathon/holodec/synthetic_holograms_multiparticle_training.nc', '/glade/p/cisl/aiml/ai4ess_hackathon/holodec/synthetic_holograms_12-25particle_gamma_600x400_training.nc'] hologram data containing 130000 images


Loaded data scaler transformation {'x': StandardScaler(copy=True, with_mean=True, with_std=True), 'y': StandardScaler(copy=True, with_mean=True, with_std=True), 'z': StandardScaler(copy=True, with_mean=True, with_std=True), 'd': StandardScaler(copy=True, with_mean=True, with_std=True)}


In [9]:
# Validation reader: uses the validation transformations and shares the
# scaler object held by the training reader.
valid_reader_kwargs = {
    "transform": valid_transform,
    "scaler": train_gen.scaler,
    "config": conf["validation_data"],
}
valid_gen = LoadReader(**valid_reader_kwargs)

INFO:holodecml.torch.data_loader:Loading reader-type multi
INFO:holodecml.torch.data_loader:Loaded data scaler transformation {'x': StandardScaler(copy=True, with_mean=True, with_std=True), 'y': StandardScaler(copy=True, with_mean=True, with_std=True), 'z': StandardScaler(copy=True, with_mean=True, with_std=True), 'd': StandardScaler(copy=True, with_mean=True, with_std=True)}
INFO:holodecml.torch.data_loader:Loaded ['/glade/p/cisl/aiml/ai4ess_hackathon/holodec/synthetic_holograms_multiparticle_validation.nc', '/glade/p/cisl/aiml/ai4ess_hackathon/holodec/synthetic_holograms_12-25particle_gamma_600x400_validation.nc'] hologram data containing 20000 images


Loaded data scaler transformation {'x': StandardScaler(copy=True, with_mean=True, with_std=True), 'y': StandardScaler(copy=True, with_mean=True, with_std=True), 'z': StandardScaler(copy=True, with_mean=True, with_std=True), 'd': StandardScaler(copy=True, with_mean=True, with_std=True)}


### Load Torch's iterator class

In [10]:
# Wrap both readers in PyTorch DataLoaders; all loader options come straight
# from the "train_iterator" / "valid_iterator" config sections.
train_dataloader = DataLoader(dataset=train_gen, **conf["train_iterator"])
valid_dataloader = DataLoader(dataset=valid_gen, **conf["valid_iterator"])

### Load trainer

In [11]:
# Build the trainer from the readers, their data loaders, the compute device
# and the full configuration. Arguments are positional, in this exact order.
trainer_args = (
    train_gen,
    valid_gen,
    train_dataloader,
    valid_dataloader,
    device,
    conf,
)
trainer = LoadTrainer(*trainer_args)

INFO:holodecml.torch.trainers:Loading trainer-type decoder-vae
INFO:holodecml.torch.models:Loading model-type att-vae with settings
INFO:holodecml.torch.models:weights: False
INFO:holodecml.torch.models:image_channels: 1
INFO:holodecml.torch.models:out_image_channels: 2
INFO:holodecml.torch.models:hidden_dims: [50, 73, 96, 304, 906, 1820]
INFO:holodecml.torch.models:z_dim: 1146
INFO:holodecml.torch.models.cnn:Loaded a self-attentive encoder-decoder VAE model
INFO:holodecml.torch.models.cnn:The model contains 264609625 trainable parameters
INFO:holodecml.torch.models.cnn:Setting tunable parameter weights according to Xavier's uniform initialization
INFO:holodecml.torch.trainers:Updating the output size of the RNN decoder to 124
INFO:holodecml.torch.models:Loading model-type gru-decoder with settings
INFO:holodecml.torch.models:hidden_size: 1146
INFO:holodecml.torch.models:output_size: 124
INFO:holodecml.torch.models:n_layers: 3
INFO:holodecml.torch.models:dropout: 0.20486082287158408
IN

### Load metrics and callbacks

In [12]:
# Initialize the LR annealing schedulers.
#
# For each supported scheduler type found under conf["callbacks"], a
# "decoder" entry builds the scheduler for trainer.rnn_optimizer and a
# "regressor" entry builds the one for trainer.particle_optimizer.
# The types are processed in the order listed; when both are configured for
# the same optimizer, the later one (ExponentialLR) replaces the earlier one.
#
# Both names are reset first so that a re-run with a changed config can never
# silently reuse a scheduler left over from a previous execution.
scheduler_rnn = None
scheduler_linear = None

supported_schedulers = (
    ("ReduceLROnPlateau", ReduceLROnPlateau),
    ("ExponentialLR", ExponentialLR),
)
for scheduler_name, scheduler_cls in supported_schedulers:
    if scheduler_name not in conf["callbacks"]:
        continue
    # An empty YAML section parses to None; treat it as "nothing configured".
    scheduler_conf = conf["callbacks"][scheduler_name] or {}
    if "decoder" in scheduler_conf:
        scheduler_rnn = scheduler_cls(trainer.rnn_optimizer, **scheduler_conf["decoder"])
    if "regressor" in scheduler_conf:
        scheduler_linear = scheduler_cls(trainer.particle_optimizer, **scheduler_conf["regressor"])

# Fail here with an actionable message instead of a NameError when the
# schedulers are later handed to the trainer.
if scheduler_rnn is None or scheduler_linear is None:
    raise ValueError(
        "conf['callbacks'] must configure a 'decoder' and a 'regressor' entry "
        "under 'ReduceLROnPlateau' and/or 'ExponentialLR'"
    )

# Early-stopping callbacks: one built from the "decoder" settings and one
# from the "regressor" settings of the EarlyStopping config section.
early_stopping_conf = conf["callbacks"]["EarlyStopping"]
early_stopping_rnn = EarlyStopping(**early_stopping_conf["decoder"])
early_stopping_linear = EarlyStopping(**early_stopping_conf["regressor"])

# Metrics logger, configured from the MetricsLogger config section.
metrics_logger_conf = conf["callbacks"]["MetricsLogger"]
metrics_logger = MetricsLogger(**metrics_logger_conf)

INFO:aimlutils.torch.checkpoint.checkpointer:Loaded EarlyStopping checkpointer with patience 5
INFO:aimlutils.torch.checkpoint.checkpointer:Loaded EarlyStopping checkpointer with patience 3
INFO:aimlutils.torch.checkpoint.checkpointer:Loaded a metrics logger /glade/work/schreck/repos/holodec-ml/scripts/schreck/decoder/results/multi_particle/training_log.csv to track the training results


### Train the model

In [None]:
# Run training, handing the trainer both LR schedulers, both early-stopping
# callbacks and the metrics logger (positional, in this order).
results = trainer.train(
    scheduler_rnn,
    scheduler_linear,
    early_stopping_rnn,
    early_stopping_linear,
    metrics_logger,
)

Epoch 0 train_bce: 0.577 train_mse: 0.877 train_acc: 0.003 train_stop_acc: 0.675 train_seq_acc: 0.014:   3%|▎         | 73/2724 [00:41<23:34,  1.87it/s] 