In [1]:
import warnings
warnings.filterwarnings("ignore")

import os
import sys
import time
import yaml
import tqdm
import torch
import random
import pickle
import logging

import numpy as np

from collections import defaultdict

from sklearn.metrics import accuracy_score as Accuracy

from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import *
from torch import nn

from typing import List, Dict, Callable, Union, Any, TypeVar, Tuple
from multiprocessing import cpu_count, Manager, Process, Pool
from shutil import copyfile

# custom
from holodecml.vae.checkpointer import *
from holodecml.vae.data_loader import *
from holodecml.vae.optimizers import *
from holodecml.vae.transforms import *
from holodecml.vae.spectral import *

from holodecml.vae.trainers import *
from holodecml.vae.models import *
from holodecml.vae.visual import *
from holodecml.vae.losses import *

from aimlutils.gpu import gpu_report

In [20]:
# GPU id -> value reported by gpu_report(), largest value first
# (presumably free memory in MB — confirm against aimlutils).
gpu_usage = gpu_report()
sorted(gpu_usage.items(), key=lambda entry: entry[1], reverse=True)

[(0, 24860)]

In [3]:
# CPUs on this machine (multiprocessing.cpu_count); informs the pool sizes used below.
n_cpus = cpu_count()
n_cpus

72

In [4]:
# Experiment configuration; later cells read their settings from `config`.
# NOTE(review): config.yml is a local file; yaml.safe_load would suffice if it
# uses no Python-specific tags.
with open("config.yml") as config_stream:
    config = yaml.load(config_stream, Loader=yaml.FullLoader)

data_save = config["log"]  # not referenced by any visible cell
model_type = config["type"]

In [5]:
# Root logger records everything; the console handler below only shows INFO and up.
root = logging.getLogger()
root.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(levelname)s:%(name)s:%(message)s')

# Drop the handler installed by a previous run of this cell, otherwise re-running
# the cell attaches a second handler and every log line is printed twice.
CONSOLE_HANDLER_NAME = "notebook-stdout"
for handler in list(root.handlers):
    if handler.name == CONSOLE_HANDLER_NAME:
        root.removeHandler(handler)

# Stream output to stdout (StreamHandler() with no argument writes to stderr).
ch = logging.StreamHandler(sys.stdout)
ch.set_name(CONSOLE_HANDLER_NAME)
ch.setLevel(logging.INFO)
ch.setFormatter(formatter)
root.addHandler(ch)

In [6]:
# Use the current CUDA device when one is available, otherwise fall back to the CPU.
is_cuda = torch.cuda.is_available()
if is_cuda:
    device = torch.device(torch.cuda.current_device())
else:
    device = torch.device("cpu")

logging.info(f'Preparing to use device {device}')

INFO:root:Preparing to use device cuda:0


In [7]:
# Preprocessing pipeline passed to both the training and validation readers.
transform_config = config["transforms"]
transform = LoadTransformations(transform_config, device = device)

INFO:holodecml.vae.transforms:Loaded Rescale transformation with output size 600
INFO:holodecml.vae.transforms:Loaded Normalize transformation that normalizes data in the range 0 to 1
INFO:holodecml.vae.transforms:Loaded ToTensor transformation, putting tensors on device cuda:0


In [8]:
# Both readers share the reader type and transform; the validation reader gets
# train_gen.get_transform() as its scaler so it is scaled like the training data.
# NOTE(review): validation data comes from the "test" split and val_acc is the
# metric handed to early stopping later on — confirm a separate held-out split
# exists for final evaluation.
shared_reader_kwargs = dict(
    reader_type = model_type,
    transform = transform,
)

train_gen = LoadReader(
    split = "train",
    scaler = None,
    config = config["data"],
    **shared_reader_kwargs,
)

valid_gen = LoadReader(
    split = "test",
    scaler = train_gen.get_transform(),
    config = config["data"],
    **shared_reader_kwargs,
)

INFO:holodecml.vae.data_loader:Loading reader-type encoder-vae
INFO:holodecml.vae.data_loader:Loaded train hologram data containing 5000 images
INFO:holodecml.vae.data_loader:Loading reader-type encoder-vae
INFO:holodecml.vae.data_loader:Loaded test hologram data containing 1000 images


In [9]:
# Disabled: the plain torch DataLoader path. `dataloader` and `valid_dataloader`
# are instead built by the AsynchronousDataLoader / AsynchronousDataReciever
# pair in the cells below; uncomment this (and skip those cells) to switch back.
# logging.info(f"Loading training data iterator using {config['iterator']['num_workers']} workers")

# dataloader = DataLoader(
#     train_gen,
#     **config["iterator"]
# )

# valid_dataloader = DataLoader(
#     valid_gen,
#     **config["iterator"]
# ) 

In [10]:
class AsynchronousDataLoader:
    """Producer that reads items in a process pool, collates them into batches
    and puts the batches on a shared queue.

    Meant to be the ``target`` of a ``multiprocessing.Process``; an
    ``AsynchronousDataReciever`` reading the same queue consumes the batches.
    At the end of every epoch the string ``'stop'`` is put on the queue so the
    receiver ends that epoch's iteration.

    Args:
        DataReader: object exposing ``__len__`` and ``__getitem__``. Each item
            is unpacked as ``(x, y, w)``: ``x`` is passed to ``torch.stack``
            as-is, ``y`` is a dict whose values (like ``w``) are numpy arrays
            converted with ``torch.from_numpy``.
        Queue: queue shared with the receiver (e.g. ``Manager().Queue()``).
        workers: number of pool processes used to read items.
        epochs: number of passes over the data before the producer exits.
        batch_size: items per emitted batch (the last batch of an epoch may be
            smaller).
        max_queue_size: the producer waits while the queue holds at least this
            many entries.
        shuffle: shuffle the item order initially and after every epoch.
        chunk_size: approximate number of items read per pool task.
    """

    def __init__(self, DataReader, Queue, workers = 1, epochs = 100, batch_size = 32,
                 max_queue_size = 32, shuffle = True, chunk_size = 100):
        self.DataReader = DataReader
        self.q = Queue
        self.workers = workers
        self.batch_size = batch_size
        self.max_queue_size = max_queue_size
        self.total_items = list(range(self.DataReader.__len__()))
        self.epochs = epochs
        self.shuffle = shuffle
        self.chunk_size = chunk_size

        if self.shuffle:
            random.shuffle(self.total_items)

    def local(self, batch):
        """Read every index in ``batch`` from the reader (runs inside a pool worker)."""
        return [self.DataReader.__getitem__(item) for item in batch]

    def _chunks(self):
        """Split the current item order into pool tasks of about ``chunk_size`` items."""
        # At least one section: with fewer than chunk_size items the floor division
        # is 0, and np.array_split raises ValueError when asked for 0 sections.
        n_chunks = max(1, len(self.total_items) // max(1, self.chunk_size))
        return np.array_split(self.total_items, n_chunks)

    def _put_batch(self, x_batch, y_batch, w_batch):
        """Stack accumulated items into tensors and enqueue ``(x, y, w)``.

        ``y_batch`` (a dict of lists) is stacked in place and sent as-is. Waits
        while the queue already holds ``max_queue_size`` entries so the producer
        cannot run arbitrarily far ahead of the consumer.
        """
        x = torch.stack(x_batch, 0)
        w = torch.stack(w_batch, 0)
        for key in y_batch:
            y_batch[key] = torch.stack(y_batch[key], 0)
        while self.q.qsize() >= self.max_queue_size:
            time.sleep(0.1)
        self.q.put((x, y_batch, w))

    def __call__(self):
        """Run ``epochs`` passes over the data, ending each pass with ``'stop'``."""
        for _ in range(self.epochs):
            x_batch, y_batch, w_batch = [], defaultdict(list), []
            with Pool(self.workers) as p:
                # NOTE(review): imap over the bound method self.local pickles this
                # object (DataReader included) for each task — can be costly if the
                # reader holds its data in memory.
                for chunk in p.imap(self.local, self._chunks()):
                    for (x, y, w) in chunk:
                        x_batch.append(x)
                        for key, value in y.items():
                            y_batch[key].append(torch.from_numpy(value))
                        w_batch.append(torch.from_numpy(w))
                        if len(x_batch) == self.batch_size:
                            self._put_batch(x_batch, y_batch, w_batch)
                            x_batch, y_batch, w_batch = [], defaultdict(list), []

                # Flush the final partial batch, with the same back-pressure wait
                # as full batches.
                if x_batch:
                    self._put_batch(x_batch, y_batch, w_batch)

                # End-of-epoch sentinel: the receiver stops iterating when it reads
                # 'stop'; looping over it again consumes the next epoch.
                self.q.put('stop')

            if self.shuffle:
                random.shuffle(self.total_items)
            
class AsynchronousDataReciever:
    """Iterator over the batches an ``AsynchronousDataLoader`` puts on ``Queue``.

    Iteration yields queued items until the string ``'stop'`` is read, which the
    producer sends at the end of every epoch. Because ``__iter__`` returns
    ``self``, the same receiver can be looped over again for the next epoch.

    Args:
        Queue: queue shared with the producer.
        batch_size: stored on the instance; not read by this class.
    """

    # Sentinel the producer puts on the queue to mark the end of an epoch.
    _STOP = 'stop'

    def __init__(self, Queue, batch_size):
        self.q = Queue
        # Kept for backward compatibility; not read by this class.
        self.wait = True
        self.batch_size = batch_size

    def __iter__(self):
        return self

    def __next__(self):
        # Queue.get() already blocks until an item arrives; polling empty() with
        # sleep(0.1) first only added up to 100 ms of latency per batch.
        result = self.q.get()
        # isinstance guard: batches are tuples of tensors, never strings.
        if isinstance(result, str) and result == self._STOP:
            raise StopIteration
        return result

In [11]:
# One manager server process backs both queues; every Manager() call starts its
# own server process, so calling it once per queue spawned two. Manager queue
# proxies can be passed to the loader processes started in the next cells.
queue_manager = Manager()
train_queue = queue_manager.Queue()
valid_queue = queue_manager.Queue()

In [12]:
# Background producer: reads training items in a 16-worker pool and fills train_queue.
# NOTE(review): the loader runs its default 100 epochs independently of the
# trainer; if the trainer asks for more epochs the receiver will wait forever.
# The process is left non-daemonic because the loader itself creates a Pool
# (daemonic processes may not have children).
train_dataload = AsynchronousDataLoader(
    train_gen,
    train_queue,
    workers=16,
    batch_size=32,
    max_queue_size=128,
)
p1 = Process(target=train_dataload)
p1.start()

# Consumer side, handed to the trainer as its training iterator.
dataloader = AsynchronousDataReciever(train_queue, batch_size=32)

In [13]:
# Background producer for the validation data, feeding valid_queue.
# NOTE(review): same fixed-epoch caveat as the training loader above.
valid_dataload = AsynchronousDataLoader(
    valid_gen,
    valid_queue,
    workers=16,
    batch_size=32,
    max_queue_size=128,
)
p2 = Process(target=valid_dataload)
p2.start()

# Consumer side, handed to the trainer as its validation iterator.
valid_dataloader = AsynchronousDataReciever(valid_queue, batch_size=32)

In [14]:
# Build the model described by config["model"] on the selected device.
model_config = config["model"]
model = LoadModel(model_type, model_config, device)

INFO:holodecml.vae.models:Loading model-type encoder-vae
INFO:holodecml.vae.models:Loaded a self-attentive encoder-decoder VAE model
INFO:root:Loaded VAE weights pretrained/pretrained.pt and froze these parameters
INFO:holodecml.vae.models:The model contains 6890500 trainable parameters


In [15]:
# A start_epoch above 0 means training is meant to resume from an earlier run.
retrain = config["trainer"]["start_epoch"] > 0

optimizer_config = config["optimizer"]
learning_rate = optimizer_config["lr"]
optimizer_type = optimizer_config["type"]

optimizer = LoadOptimizer(optimizer_type, model.parameters(), learning_rate)

# TODO: when resuming, take the learning rate and optimizer state from the saved
# checkpoint (model_dict["lr"], model_dict["optimizer_state_dict"]). No
# model_dict is loaded in the visible cells, so for now make the gap loud
# instead of silently restarting the optimizer.
if retrain:
    logging.warning(
        "start_epoch > 0 but no optimizer checkpoint is restored; "
        "resuming with a freshly initialized optimizer at the configured learning rate"
    )

INFO:holodecml.vae.optimizers:Loaded the lookahead-diffgrad optimizer with learning rate 0.000631 and L2 penalty 0.0


In [16]:
# Assemble the trainer: model and optimizer, the raw readers, the asynchronous
# batch iterators built above, and the runtime settings.
trainer = LoadTrainer(
    model_type,
    model = model,
    optimizer = optimizer,
    train_gen = train_gen,
    valid_gen = valid_gen,
    dataloader = dataloader,
    valid_dataloader = valid_dataloader,
    device = device,
    config = config["trainer"],
)

INFO:holodecml.vae.trainers:Loading trainer-type encoder-vae
INFO:holodecml.vae.trainers:Clipping gradients to range [-1.0, 1.0]


In [17]:
callbacks_config = config["callbacks"]

# Initialize LR annealing scheduler. If both are configured, ReduceLROnPlateau wins.
if "ReduceLROnPlateau" in callbacks_config:
    schedule_config = callbacks_config["ReduceLROnPlateau"]
    scheduler = ReduceLROnPlateau(trainer.optimizer, **schedule_config)
    logging.info(
        f"Loaded ReduceLROnPlateau learning rate annealer with patience {schedule_config['patience']}"
    )
elif "ExponentialLR" in callbacks_config:
    schedule_config = callbacks_config["ExponentialLR"]
    scheduler = ExponentialLR(trainer.optimizer, **schedule_config)
    logging.info(
        f"Loaded ExponentialLR learning rate annealer with reduce factor {schedule_config['gamma']}"
    )
else:
    # Fail here with a clear message instead of a NameError (or a stale scheduler
    # left over from an earlier run) when trainer.train is called.
    raise ValueError(
        "config['callbacks'] must define either 'ReduceLROnPlateau' or 'ExponentialLR'"
    )

# Early stopping
checkpoint_config = callbacks_config["EarlyStopping"]
early_stopping = EarlyStopping(**checkpoint_config)

# Write metrics to csv each epoch
metrics_logger = MetricsLogger(**callbacks_config["MetricsLogger"])

INFO:root:Loaded ExponentialLR learning rate annealer with reduce factor 0.95
INFO:holodecml.vae.checkpointer:Loaded EarlyStopping checkpointer with patience 5
INFO:holodecml.vae.checkpointer:Loaded a metrics logger test/training_log.csv to track the training results


In [18]:
try:
    training_results = trainer.train(scheduler, early_stopping, metrics_logger, metric = "val_acc")
finally:
    # The loader processes run a fixed number of epochs and know nothing about
    # early stopping, so they can still be alive (with their worker pools) here.
    for loader_process in (p1, p2):
        if loader_process.is_alive():
            loader_process.terminate()
        loader_process.join()

training_results

INFO:holodecml.vae.trainers:Training the model for up to 100 epochs starting at epoch 0
Epoch: 0 loss: 0.191 mse: 0.194 bce: 0.163 acc: 0.846: 100%|██████████| 157/157 [01:10<00:00,  2.21it/s]
Epoch: 0 val_loss: 0.547 val_mse: 0.273 val_bce: 0.274 val_acc: 0.858: 100%|██████████| 32/32 [00:04<00:00,  7.59it/s]
INFO:holodecml.vae.checkpointer:Validation loss decreased on epoch 0 (inf --> -0.857607).  Saving model.
Epoch: 1 loss: 0.174 mse: 0.179 bce: 0.130 acc: 0.862: 100%|██████████| 157/157 [01:14<00:00,  2.11it/s]
Epoch: 1 val_loss: 0.561 val_mse: 0.274 val_bce: 0.287 val_acc: 0.854: 100%|██████████| 32/32 [00:04<00:00,  7.10it/s]
INFO:holodecml.vae.checkpointer:EarlyStopping counter: 1 out of 5
Epoch: 2 loss: 0.168 mse: 0.172 bce: 0.124 acc: 0.868: 100%|██████████| 157/157 [01:13<00:00,  2.13it/s]
Epoch: 2 val_loss: 0.545 val_mse: 0.272 val_bce: 0.273 val_acc: 0.864: 100%|██████████| 32/32 [00:03<00:00,  8.27it/s]
INFO:holodecml.vae.checkpointer:Validation loss decreased on epoch 2 

{'epoch': 8,
 'train_loss': 0.14090711884437854,
 'train_mse': 0.14639340872597542,
 'train_bce': 0.0860442709486196,
 'valid_loss': 0.5567045044153929,
 'valid_mse': 0.2807715078815818,
 'valid_bce': 0.27593300212174654,
 'lr': 0.00041861829214339834,
 'train_acc': 0.9149840550058207,
 'valid_acc': 0.8661230225116014}