In [1]:
import warnings
warnings.filterwarnings("ignore")

import os
import sys
import yaml
import tqdm
import torch
import pickle
import logging

from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import *
from torch import nn

from typing import List, Dict, Callable, Union, Any, TypeVar, Tuple
from multiprocessing import cpu_count
from shutil import copyfile

# custom
from holodecml.vae.checkpointer import *
from holodecml.vae.data_loader import *
from holodecml.vae.optimizers import *
from holodecml.vae.transforms import *
from holodecml.vae.spectral import *

from holodecml.vae.trainers import *
from holodecml.vae.models import *
from holodecml.vae.visual import *
from holodecml.vae.losses import *

In [2]:
with open("config.yml") as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)

data_save = config["log"]
model_type = config["type"]

In [3]:
root = logging.getLogger()
root.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(levelname)s:%(name)s:%(message)s')

# Stream output to stdout
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
ch.setFormatter(formatter)
root.addHandler(ch)

In [4]:
is_cuda = torch.cuda.is_available()
device = torch.device(torch.cuda.current_device()) if is_cuda else torch.device("cpu")
    
logging.info(f'Preparing to use device {device}')

INFO:root:Preparing to use device cpu


In [5]:
transform = LoadTransformations(config["transforms"], device = device)

INFO:holodecml.vae.transforms:Loaded Rescale transformation with output size 600
INFO:holodecml.vae.transforms:Loaded Normalize transformation that normalizes data in the range 0 to 1
INFO:holodecml.vae.transforms:Loaded ToTensor transformation, putting tensors on device cpu


In [6]:
train_gen = LoadReader(
    reader_type = model_type, 
    split = "test", 
    transform = transform,
    scaler = None,
    config = config["data"]
)

valid_gen = LoadReader(
    reader_type = model_type, 
    split = "test", 
    transform = transform, 
    scaler = train_gen.get_transform(),
    config = config["data"],
)

INFO:holodecml.vae.data_loader:Loading reader type encoder-vae
INFO:holodecml.vae.data_loader:Loaded test hologram data containing 1000 images
INFO:holodecml.vae.data_loader:Loading reader type encoder-vae
INFO:holodecml.vae.data_loader:Loaded test hologram data containing 1000 images


In [7]:
logging.info(f"Loading training data iterator using {config['iterator']['num_workers']} workers")

dataloader = DataLoader(
    train_gen,
    **config["iterator"]
)

valid_dataloader = DataLoader(
    valid_gen,
    **config["iterator"]
)

INFO:root:Loading training data iterator using 0 workers


In [8]:
model = LoadModel(model_type, config["model"], device)

INFO:holodecml.vae.models:Loading model type encoder-vae
INFO:holodecml.vae.models:Loaded a self-attentive encoder-decoder VAE model
INFO:root:Loaded VAE weights test/pretrained.pt and froze these parameters
INFO:holodecml.vae.models:The model contains 551800 trainable parameters


In [9]:
retrain = config["trainer"]["start_epoch"] > 0

optimizer_config = config["optimizer"]
learning_rate = optimizer_config["lr"] #if not retrain else model_dict["lr"]
optimizer_type = optimizer_config["type"]

optimizer = LoadOptimizer(optimizer_type, model.parameters(), learning_rate)

#if retrain:
#   optimizer.load_state_dict(model_dict["optimizer_state_dict"])
#   logging.info(f"Loaded optimizer weights from {model_dict}")

INFO:holodecml.vae.optimizers:Loaded the lookahead-diffgrad optimizer with learning rate 0.0006426221935682487 and L2 penalty 0.0


In [10]:
trainer = LoadTrainer(
    model_type, 
    model = model,
    optimizer = optimizer,
    train_gen = train_gen,
    valid_gen = train_gen, 
    dataloader = dataloader,
    valid_dataloader = valid_dataloader,
    device = device,
    config = config["trainer"]
)

INFO:holodecml.vae.trainers:Loading trainer type encoder-vae
INFO:holodecml.vae.trainers:Clipping gradients to range [-1.0, 1.0]


In [11]:
# Initialize LR annealing scheduler 
if "ReduceLROnPlateau" in config["callbacks"]:
    schedule_config = config["callbacks"]["ReduceLROnPlateau"]
    scheduler = ReduceLROnPlateau(trainer.optimizer, **schedule_config)
    logging.info(
        f"Loaded ReduceLROnPlateau learning rate annealer with patience {schedule_config['patience']}"
    )
elif "ExponentialLR" in config["callbacks"]:
    schedule_config = config["callbacks"]["ExponentialLR"]
    scheduler = ExponentialLR(trainer.optimizer, **schedule_config)
    logging.info(
        f"Loaded ExponentialLR learning rate annealer with reduce factor {schedule_config['gamma']}"
    )

# Early stopping
checkpoint_config = config["callbacks"]["EarlyStopping"]
early_stopping = EarlyStopping(**checkpoint_config)

# Write metrics to csv each epoch
metrics_logger = MetricsLogger(**config["callbacks"]["MetricsLogger"])

INFO:root:Loaded ExponentialLR learning rate annealer with reduce factor 0.95
INFO:holodecml.vae.checkpointer:Loaded EarlyStopping checkpointer with patience 100
INFO:holodecml.vae.checkpointer:Loaded a metrics logger test/training_log.csv to track the training results


In [12]:
trainer.train(scheduler, early_stopping, metrics_logger)

INFO:holodecml.vae.trainers:Training the model for up to 500 epochs starting at epoch 0
loss: 0.127:  11%|█         | 7/63 [00:30<04:06,  4.40s/it]


KeyboardInterrupt: 