In [1]:
# Standard library
import copy
import itertools
import logging
import os
import pickle
import random
import warnings

# Third-party
import joblib
import nltk
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml
from torch import optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import *

# Project packages (wildcard imports kept in their original order, because
# later modules can shadow names exported by earlier ones)
from holodecml.torch.utils import *
from holodecml.torch.losses import *
from holodecml.torch.visual import *
from holodecml.torch.models import *
from holodecml.torch.trainers import *
from holodecml.torch.transforms import *
from holodecml.torch.optimizers import *
from holodecml.torch.data_loader import *
from holodecml.torch.beam_search import *

from aimlutils.hyper_opt.base_objective import *
from aimlutils.torch.checkpoint import *
#from aimlutils.torch.losses import *
from aimlutils.utils.tqdm import *

# Kept after the wildcard imports so these typing names are not shadowed
from typing import List, Callable, Tuple, Dict, Union

In [2]:
# Configure the root logger: accept everything from DEBUG upwards and let the
# console handler below decide what is actually shown.
root = logging.getLogger()
root.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(levelname)s:%(name)s:%(message)s')

# Re-running this cell must not stack a second console handler on the root
# logger (every record would then be printed twice), so drop any handler a
# previous execution of this cell installed under the same name.
_CONSOLE_HANDLER_NAME = "notebook-console"
for _stale_handler in [h for h in root.handlers if h.get_name() == _CONSOLE_HANDLER_NAME]:
    root.removeHandler(_stale_handler)

# Stream INFO-and-above records to the console. Note that StreamHandler()
# without an explicit stream writes to sys.stderr, not stdout.
ch = logging.StreamHandler()
ch.set_name(_CONSOLE_HANDLER_NAME)
ch.setLevel(logging.INFO)
ch.setFormatter(formatter)
root.addHandler(ch)

In [3]:
# Pick the compute device: the currently selected CUDA device when one is
# available, otherwise fall back to the CPU.
is_cuda = torch.cuda.is_available()
if is_cuda:
    device = torch.device(torch.cuda.current_device())
else:
    device = torch.device("cpu")

In [4]:
# Read the experiment configuration into `conf`.
# NOTE(review): yaml.FullLoader is kept from the original cell; it should only
# be used on trusted, locally authored YAML files.
config_path = "results/test/model.yml"
with open(config_path) as config_file:
    conf = yaml.load(config_file, Loader=yaml.FullLoader)

### Load data readers

In [None]:
# Build the training and validation transformation pipelines for the
# images / (x, y, z, d) targets from their respective config sections.
train_transform, valid_transform = (
    LoadTransformations(conf[section], device=device)
    for section in ("train_transforms", "validation_transforms")
)

In [5]:
# Path of the persisted scalers, stored inside the trainer's save directory.
_scaler_dir = conf["trainer"]["path_save"]
scaler_path = os.path.join(_scaler_dir, "scalers.save")

In [6]:
# Reuse the persisted scalers when they exist. joblib.load is pickle-based, so
# scaler_path must only ever point at a file you trust.
if os.path.isfile(scaler_path):
    train_scaler = joblib.load(scaler_path)
else:
    # No saved scalers yet: pass True, exactly as the original call did
    # (presumably this asks LoadReader to fit new scalers -- TODO confirm).
    train_scaler = True

train_gen = LoadReader(
    transform=train_transform,
    scaler=train_scaler,
    config=conf["train_data"],
)

# Persist the reader's scalers the first time through so that later runs
# load them from disk instead.
if not os.path.isfile(scaler_path):
    joblib.dump(train_gen.scaler, scaler_path)

INFO:holodecml.torch.transforms:Loaded RandomVerticalFlip transformation with probability 0.5
INFO:holodecml.torch.transforms:Loaded RandomHorizontalFlip transformation with probability 0.5
INFO:holodecml.torch.transforms:Loaded Normalize transformation that normalizes data in the range 0 to 1
INFO:holodecml.torch.transforms:Loaded ToTensor transformation, putting tensors on device cuda:0
INFO:holodecml.torch.data_loader:Loading reader-type multi
INFO:holodecml.torch.data_loader:Loaded data scaler transformation {'x': StandardScaler(copy=True, with_mean=True, with_std=True), 'y': StandardScaler(copy=True, with_mean=True, with_std=True), 'z': StandardScaler(copy=True, with_mean=True, with_std=True), 'd': StandardScaler(copy=True, with_mean=True, with_std=True)}
INFO:holodecml.torch.data_loader:Loaded ['/glade/p/cisl/aiml/ai4ess_hackathon/holodec/synthetic_holograms_3particle_training.nc'] hologram data containing 50000 images


Loaded data scaler transformation {'x': StandardScaler(copy=True, with_mean=True, with_std=True), 'y': StandardScaler(copy=True, with_mean=True, with_std=True), 'z': StandardScaler(copy=True, with_mean=True, with_std=True), 'd': StandardScaler(copy=True, with_mean=True, with_std=True)}


In [7]:
# Validation reader: shares the training reader's scalers so both splits are
# scaled with the same scaler objects.
valid_gen = LoadReader(
    config=conf["validation_data"],
    transform=valid_transform,
    scaler=train_gen.scaler,
)

INFO:holodecml.torch.transforms:Loaded Normalize transformation that normalizes data in the range 0 to 1
INFO:holodecml.torch.transforms:Loaded ToTensor transformation, putting tensors on device cuda:0
INFO:holodecml.torch.data_loader:Loading reader-type multi
INFO:holodecml.torch.data_loader:Loaded data scaler transformation {'x': StandardScaler(copy=True, with_mean=True, with_std=True), 'y': StandardScaler(copy=True, with_mean=True, with_std=True), 'z': StandardScaler(copy=True, with_mean=True, with_std=True), 'd': StandardScaler(copy=True, with_mean=True, with_std=True)}
INFO:holodecml.torch.data_loader:Loaded ['/glade/p/cisl/aiml/ai4ess_hackathon/holodec/synthetic_holograms_3particle_validation.nc'] hologram data containing 10000 images


Loaded data scaler transformation {'x': StandardScaler(copy=True, with_mean=True, with_std=True), 'y': StandardScaler(copy=True, with_mean=True, with_std=True), 'z': StandardScaler(copy=True, with_mean=True, with_std=True), 'd': StandardScaler(copy=True, with_mean=True, with_std=True)}


### Load Torch's iterator class

In [8]:
# Wrap both readers in PyTorch DataLoaders; all iterator options come
# straight from the corresponding sections of the configuration.
train_iterator_kwargs = conf["train_iterator"]
valid_iterator_kwargs = conf["valid_iterator"]

train_dataloader = DataLoader(train_gen, **train_iterator_kwargs)
valid_dataloader = DataLoader(valid_gen, **valid_iterator_kwargs)

### Load trainer

In [11]:
# Build the trainer from the readers, their DataLoaders, the target device
# and the full configuration (arguments are positional, order matters).
trainer = LoadTrainer(
    train_gen, valid_gen,
    train_dataloader, valid_dataloader,
    device, conf,
)

INFO:holodecml.torch.trainers:Loading trainer-type decoder-vae
INFO:holodecml.torch.models:Loading model-type att-vae with settings
INFO:holodecml.torch.models:weights: /glade/work/schreck/repos/holodec-ml/scripts/schreck/vae/results/double_channel_1221/best.pt
INFO:holodecml.torch.models:image_channels: 1
INFO:holodecml.torch.models:out_image_channels: 2
INFO:holodecml.torch.models:hidden_dims: [50, 73, 96, 304, 906, 1820]
INFO:holodecml.torch.models:z_dim: 1146
INFO:holodecml.torch.models.cnn:Loaded a self-attentive encoder-decoder VAE model
INFO:holodecml.torch.models.cnn:The model contains 264609625 trainable parameters
INFO:holodecml.torch.models.cnn:Loading weights from /glade/work/schreck/repos/holodec-ml/scripts/schreck/vae/results/double_channel_1221/best.pt
INFO:holodecml.torch.trainers:Updating the output size of the RNN decoder to 124
INFO:holodecml.torch.models:Loading model-type gru-decoder with settings
INFO:holodecml.torch.models:hidden_size: 1146
INFO:holodecml.torch.m

### Load metrics and callbacks

In [12]:
# Initialize the LR annealing schedulers.
#
# Both names are reset first so that re-running this cell with a different
# configuration can never silently reuse a scheduler left over from an earlier
# execution (which would still be bound to the previous optimizer).
scheduler_rnn = None
scheduler_linear = None

callbacks_conf = conf["callbacks"]

# Processed in this order so that, when both scheduler types are configured,
# ExponentialLR takes precedence (same outcome as the original if-blocks).
for _scheduler_name, _scheduler_cls in (
    ("ReduceLROnPlateau", ReduceLROnPlateau),
    ("ExponentialLR", ExponentialLR),
):
    if _scheduler_name not in callbacks_conf:
        continue
    _scheduler_conf = callbacks_conf[_scheduler_name]
    if "decoder" in _scheduler_conf:
        # Scheduler attached to the trainer's RNN-decoder optimizer
        schedule_config1 = _scheduler_conf["decoder"]
        scheduler_rnn = _scheduler_cls(trainer.rnn_optimizer, **schedule_config1)
    if "regressor" in _scheduler_conf:
        # Scheduler attached to the trainer's particle-regressor optimizer
        schedule_config2 = _scheduler_conf["regressor"]
        scheduler_linear = _scheduler_cls(trainer.particle_optimizer, **schedule_config2)

# Fail here with an actionable message instead of a NameError in the
# training cell when the configuration does not define both schedulers.
if scheduler_rnn is None or scheduler_linear is None:
    raise ValueError(
        "conf['callbacks'] must configure a 'decoder' and a 'regressor' "
        "scheduler under 'ReduceLROnPlateau' and/or 'ExponentialLR'"
    )

# Early stopping (one tracker per sub-model)
early_stopping_rnn = EarlyStopping(**callbacks_conf["EarlyStopping"]["decoder"])
early_stopping_linear = EarlyStopping(**callbacks_conf["EarlyStopping"]["regressor"])

# Write metrics to csv each epoch
metrics_logger = MetricsLogger(**callbacks_conf["MetricsLogger"])

INFO:aimlutils.torch.checkpoint.checkpointer:Loaded EarlyStopping checkpointer with patience 5
INFO:aimlutils.torch.checkpoint.checkpointer:Loaded EarlyStopping checkpointer with patience 3
INFO:aimlutils.torch.checkpoint.checkpointer:Loaded a metrics logger /glade/work/schreck/repos/holodec-ml/scripts/schreck/decoder/results/test/training_log.csv to track the training results


### Train the model

In [13]:
# Run the training loop with both schedulers, both early-stopping trackers
# and the metrics logger; `results` holds whatever trainer.train returns.
results = trainer.train(
    scheduler_rnn,
    scheduler_linear,
    early_stopping_rnn,
    early_stopping_linear,
    metrics_logger,
)

Epoch 0 train_bce: 0.071 train_mse: 0.458 train_acc: 0.439 train_stop_acc: 0.996 train_seq_acc: 0.537: 100%|██████████| 2084/2084 [08:33<00:00,  4.05it/s]
Epoch 0 val_bce: 0.052 val_mae: 0.407 val_acc: 0.486 val_stop_acc: 1.000 val_seq_acc: 0.692: 100%|██████████| 313/313 [01:51<00:00,  2.80it/s]
INFO:aimlutils.torch.checkpoint.checkpointer:Validation loss decreased on epoch 0 (inf --> 0.307941).  Saving model.
INFO:aimlutils.torch.checkpoint.checkpointer:Validation loss decreased on epoch 0 (inf --> 0.406501).  Saving model.
Epoch 1 train_bce: 0.050 train_mse: 0.292 train_acc: 0.587 train_stop_acc: 0.999 train_seq_acc: 0.718: 100%|██████████| 2084/2084 [08:33<00:00,  4.06it/s]
Epoch 1 val_bce: 0.044 val_mae: 0.304 val_acc: 0.516 val_stop_acc: 1.000 val_seq_acc: 0.707: 100%|██████████| 313/313 [01:50<00:00,  2.84it/s]
INFO:aimlutils.torch.checkpoint.checkpointer:Validation loss decreased on epoch 1 (0.307941 --> 0.292998).  Saving model.
INFO:aimlutils.torch.checkpoint.checkpointer:Val

Epoch    14: reducing learning rate of group 0 to 2.4009e-04.


INFO:aimlutils.torch.checkpoint.checkpointer:EarlyStopping counter: 2 out of 3
Epoch 14 train_bce: 0.016 train_mse: 0.159 train_acc: 0.869 train_stop_acc: 1.000 train_seq_acc: 0.929: 100%|██████████| 2084/2084 [08:31<00:00,  4.08it/s]
Epoch 14 val_bce: 0.030 val_mae: 0.184 val_acc: 0.633 val_stop_acc: 1.000 val_seq_acc: 0.793: 100%|██████████| 313/313 [01:49<00:00,  2.87it/s]
INFO:aimlutils.torch.checkpoint.checkpointer:Validation loss decreased on epoch 14 (0.210796 --> 0.206836).  Saving model.
INFO:aimlutils.torch.checkpoint.checkpointer:Validation loss decreased on epoch 14 (0.203470 --> 0.183844).  Saving model.
Epoch 15 train_bce: 0.015 train_mse: 0.152 train_acc: 0.878 train_stop_acc: 1.000 train_seq_acc: 0.933: 100%|██████████| 2084/2084 [08:33<00:00,  4.06it/s]
Epoch 15 val_bce: 0.030 val_mae: 0.185 val_acc: 0.634 val_stop_acc: 1.000 val_seq_acc: 0.793: 100%|██████████| 313/313 [01:48<00:00,  2.89it/s]
INFO:aimlutils.torch.checkpoint.checkpointer:Validation loss decreased on e

Epoch    17: reducing learning rate of group 0 to 4.8017e-05.


Epoch 17 train_bce: 0.013 train_mse: 0.146 train_acc: 0.898 train_stop_acc: 1.000 train_seq_acc: 0.943: 100%|██████████| 2084/2084 [08:31<00:00,  4.07it/s]
Epoch 17 val_bce: 0.030 val_mae: 0.183 val_acc: 0.636 val_stop_acc: 1.000 val_seq_acc: 0.793: 100%|██████████| 313/313 [01:48<00:00,  2.89it/s]
INFO:aimlutils.torch.checkpoint.checkpointer:EarlyStopping counter: 2 out of 5
INFO:aimlutils.torch.checkpoint.checkpointer:Validation loss decreased on epoch 17 (0.183844 --> 0.183050).  Saving model.


Epoch    18: reducing learning rate of group 0 to 1.1578e-04.


Epoch 18 train_bce: 0.010 train_mse: 0.146 train_acc: 0.918 train_stop_acc: 1.000 train_seq_acc: 0.950:   9%|▉         | 186/2084 [00:46<07:58,  3.96it/s]


KeyboardInterrupt: 