In [1]:
import warnings
warnings.filterwarnings("ignore")

import os
import sys
import yaml
import tqdm
import torch
import pickle
import logging

from torchvision.utils import save_image
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau

from typing import List, Dict, Callable, Union, Any, TypeVar, Tuple
from multiprocessing import cpu_count

# custom
from holodecml.vae.checkpointer import *
from holodecml.vae.data_loader import *
from holodecml.vae.optimizers import *
from holodecml.vae.transforms import *
from holodecml.vae.models import *
from holodecml.vae.visual import *
from holodecml.vae.losses import *

from torch import nn
from abc import abstractmethod

In [2]:
criterion = nn.BCELoss(reduction='sum')

def loss_fn(recon_x, x, mu, logvar):
    BCE = criterion(recon_x, x)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD, BCE, KLD

class Trainer:
    
    def __init__(self, 
                 model, 
                 optimizer,
                 train_gen, 
                 valid_gen, 
                 dataloader, 
                 valid_dataloader,
                 batch_size,
                 path_save,
                 device,
                 test_image = None):
        
        self.model = model
        self.optimizer = optimizer
        self.train_gen = train_gen
        self.valid_gen = valid_gen
        self.dataloader = dataloader
        self.valid_dataloader = valid_dataloader
        self.batch_size = batch_size
        self.path_save = path_save
        self.device = device
        self.test_image = test_image
        
        
    def train_one_epoch(self, epoch):

        self.model.train()
        batches_per_epoch = int(np.ceil(self.train_gen.__len__() / self.batch_size))
        batch_group_generator = tqdm.tqdm(
            enumerate(self.dataloader),
            total=batches_per_epoch, 
            leave=True
        )

        epoch_losses = {"loss": [], "bce": [], "kld": []}
        for idx, images in batch_group_generator:

            images = images.to(self.device)
            recon_images, mu, logvar = self.model(images)
            loss, bce, kld = loss_fn(recon_images, images, mu, logvar)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            batch_loss = loss.item() / batch_size
            bce_loss = bce.item() / batch_size
            kld_loss = kld.item() / batch_size

            epoch_losses["loss"].append(batch_loss)
            epoch_losses["bce"].append(bce_loss)
            epoch_losses["kld"].append(kld_loss)

            loss = np.mean(epoch_losses["loss"])
            bce = np.mean(epoch_losses["bce"])
            kld = np.mean(epoch_losses["kld"])

            to_print = "loss: {:.3f} bce: {:.3f} kld: {:.3f}".format(loss, bce, kld)
            batch_group_generator.set_description(to_print)
            batch_group_generator.update()

        return loss, bce, kld


    def test(self, epoch):

        self.model.eval()
        batches_per_epoch = int(np.ceil(self.valid_gen.__len__() / self.batch_size))

        with torch.no_grad():

            batch_group_generator = tqdm.tqdm(
                enumerate(self.valid_dataloader),
                total=batches_per_epoch, 
                leave=True
            )

            epoch_losses = {"loss": [], "bce": [], "kld": []}
            for idx, images in batch_group_generator:

                images = images.to(self.device)
                recon_images, mu, logvar = self.model(images)
                loss, bce, kld = loss_fn(recon_images, images, mu, logvar)

                batch_loss = loss.item() / batch_size
                bce_loss = bce.item() / batch_size
                kld_loss = kld.item() / batch_size

                epoch_losses["loss"].append(batch_loss)
                epoch_losses["bce"].append(bce_loss)
                epoch_losses["kld"].append(kld_loss)

                loss = np.mean(epoch_losses["loss"])
                bce = np.mean(epoch_losses["bce"])
                kld = np.mean(epoch_losses["kld"])

                to_print = "val_loss: {:.3f} val_bce: {:.3f} val_kld: {:.3f}".format(loss, bce, kld)
                batch_group_generator.set_description(to_print)
                batch_group_generator.update()

            if os.path.isfile(self.test_image):
                with open(self.test_image, "rb") as fid:
                    pic = pickle.load(fid)
                self.compare(epoch, pic)

        return loss, bce, kld
    
    
    def compare(self, epoch, x):
        x = x.to(self.device)
        recon_x, _, _ = self.model(x)
        compare_x = torch.cat([x, recon_x])
        save_image(compare_x.data.cpu(), f'{self.path_save}/image_epoch_{epoch}.png')

In [3]:
class Trainer:
    
    def __init__(self, 
                 model, 
                 optimizer,
                 train_gen, 
                 valid_gen, 
                 dataloader, 
                 valid_dataloader,
                 batch_size,
                 path_save,
                 device,
                 test_image = None):
        
        self.model = model
        self.optimizer = optimizer
        self.train_gen = train_gen
        self.valid_gen = valid_gen
        self.dataloader = dataloader
        self.valid_dataloader = valid_dataloader
        self.batch_size = batch_size
        self.path_save = path_save
        self.device = device
        self.test_image = test_image
        
        
    def train_one_epoch(self, epoch):

        self.model.train()
        batches_per_epoch = int(np.ceil(self.train_gen.__len__() / self.batch_size))
        batch_group_generator = tqdm.tqdm(
            enumerate(self.dataloader),
            total=batches_per_epoch, 
            leave=True
        )

        epoch_losses = {"loss": [], "bce": [], "kld": []}
        for idx, images in batch_group_generator:

            images = images.to(self.device)
            recon_images, mu, logvar = self.model(images)
            
            loss, bce, kld = loss_fn(recon_images, images, mu, logvar)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            batch_loss = loss.item() / batch_size
            bce_loss = bce.item() / batch_size
            kld_loss = kld.item() / batch_size

            epoch_losses["loss"].append(batch_loss)
            epoch_losses["bce"].append(bce_loss)
            epoch_losses["kld"].append(kld_loss)

            loss = np.mean(epoch_losses["loss"])
            bce = np.mean(epoch_losses["bce"])
            kld = np.mean(epoch_losses["kld"])

            to_print = "loss: {:.3f} bce: {:.3f} kld: {:.3f}".format(loss, bce, kld)
            batch_group_generator.set_description(to_print)
            batch_group_generator.update()

        return loss, bce, kld


    def test(self, epoch):

        self.model.eval()
        batches_per_epoch = int(np.ceil(self.valid_gen.__len__() / self.batch_size))

        with torch.no_grad():

            batch_group_generator = tqdm.tqdm(
                enumerate(self.valid_dataloader),
                total=batches_per_epoch, 
                leave=True
            )

            epoch_losses = {"loss": [], "bce": [], "kld": []}
            for idx, images in batch_group_generator:

                images = images.to(self.device)
                recon_images, mu, logvar = self.model(images)
                loss, bce, kld = loss_fn(recon_images, images, mu, logvar)

                batch_loss = loss.item() / batch_size
                bce_loss = bce.item() / batch_size
                kld_loss = kld.item() / batch_size

                epoch_losses["loss"].append(batch_loss)
                epoch_losses["bce"].append(bce_loss)
                epoch_losses["kld"].append(kld_loss)

                loss = np.mean(epoch_losses["loss"])
                bce = np.mean(epoch_losses["bce"])
                kld = np.mean(epoch_losses["kld"])

                to_print = "val_loss: {:.3f} val_bce: {:.3f} val_kld: {:.3f}".format(loss, bce, kld)
                batch_group_generator.set_description(to_print)
                batch_group_generator.update()

            if os.path.isfile(self.test_image):
                with open(self.test_image, "rb") as fid:
                    pic = pickle.load(fid)
                self.compare(epoch, pic)

        return loss, bce, kld
    
    
    def compare(self, epoch, x):
        x = x.to(self.device)
        recon_x, _, _ = self.model(x)
        compare_x = torch.cat([x, recon_x])
        save_image(compare_x.data.cpu(), f'{self.path_save}/image_epoch_{epoch}.png')

In [4]:
with open("config.yml") as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)

In [5]:
try:
    os.makedirs(config["path_save"])
except:
    pass

In [6]:
root = logging.getLogger()
root.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(levelname)s:%(name)s:%(message)s')

# Stream output to stdout
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
ch.setFormatter(formatter)
root.addHandler(ch)

# Save the log file
logger_name = os.path.join(config["path_save"], "log.txt")
fh = logging.FileHandler(logger_name,
                         mode="w",
                         encoding='utf-8')
fh.setLevel(logging.DEBUG)
fh.setFormatter(formatter)
root.addHandler(fh)

In [7]:
logging.info(f'Reading parameters from config.yml')
    
path_data = config["path_data"]
path_save = config["path_save"]
num_particles = config["num_particles"]
maxnum_particles = config["maxnum_particles"]
output_cols = config["output_cols"]
subset = config["subset"]
test_image = config["test_image"]

batch_size = config["batch_size"]
workers = min(config["workers"], cpu_count())
epochs = config["epochs"]
retrain = False if "retrain" not in config else config["retrain"]

model_save_path = os.path.join(f"{path_save}", "checkpoint.pt")

start_epoch = 0
if retrain:
    saved_model_optimizer = torch.load(model_save_path)
    start_epoch = saved_model_optimizer["epoch"] + 1

INFO:root:Reading parameters from config.yml


In [8]:
is_cuda = torch.cuda.is_available()
device = torch.device(torch.cuda.current_device()) if is_cuda else torch.device("cpu")
    
logging.info(f'Preparing to use device {device}')

INFO:root:Preparing to use device cuda:0


In [9]:
tforms = []
transform_config = config["transforms"]

if "Rescale" in transform_config:
    rescale = transform_config["Rescale"]
    tforms.append(Rescale(rescale))
if "Normalize" in transform_config:
    tforms.append(Normalize())
if "ToTensor" in transform_config:
    tforms.append(ToTensor(device))
if "RandomCrop" in transform_config:
    tforms.append(RandomCrop())
if "Standardize" in transform_config:
    tforms.append(Standardize())

transform = transforms.Compose(tforms)

INFO:holodecml.vae.transforms:Loaded Rescale transformation with output size 384
INFO:holodecml.vae.transforms:Loaded Normalize transformation that normalizes data in the range 0 to 1
INFO:holodecml.vae.transforms:Loaded ToTensor transformation, putting tensors on device cuda:0


In [10]:
train_gen = HologramDataset(
    path_data, num_particles, "train", subset, 
    output_cols, maxnum_particles = maxnum_particles, 
    transform = transform
)

train_scalers = train_gen.get_transform()

valid_gen = HologramDataset(
    path_data, num_particles, "test", subset, 
    output_cols, scaler = train_scalers, 
    maxnum_particles = maxnum_particles,
    transform = transform
)

INFO:holodecml.vae.data_loader:Loaded train hologram data containing 50000 images
INFO:holodecml.vae.data_loader:Loaded test hologram data containing 10000 images


In [11]:
logging.info(f"Loading training data iterator using {workers} workers")
    
dataloader = DataLoader(
    train_gen,
    batch_size = batch_size,
    shuffle = True,
    num_workers = workers
)

valid_dataloader = DataLoader(
    valid_gen,
    batch_size = batch_size,
    shuffle = False,
    num_workers = workers
)

INFO:root:Loading training data iterator using 24 workers


In [12]:
kernel_size = 4
stride = 1
padding = 0
init_kernel = 16 # initial number of filters

# Based off of https://debuggercafe.com/face-image-generation-using-convolutional-variational-autoencoder-and-pytorch/

class ConvVAE(nn.Module):
    def __init__(self):
        super(ConvVAE, self).__init__()
 
        # encoder
        self.enc1 = nn.Conv2d(
            in_channels=1, out_channels=init_kernel, kernel_size=kernel_size, 
            stride=stride, padding=padding
        )
        self.enc2 = nn.Conv2d(
            in_channels=init_kernel, out_channels=init_kernel*2, kernel_size=kernel_size, 
            stride=stride, padding=padding
        )
        self.enc3 = nn.Conv2d(
            in_channels=init_kernel*2, out_channels=init_kernel*4, kernel_size=kernel_size, 
            stride=stride, padding=padding
        )
        self.enc4 = nn.Conv2d(
            in_channels=init_kernel*4, out_channels=init_kernel*8, kernel_size=kernel_size, 
            stride=stride, padding=padding
        )
        self.enc5 = nn.Conv2d(
            in_channels=init_kernel*8, out_channels=init_kernel, kernel_size=kernel_size, 
            stride=stride, padding=padding
        )
        
        # decoder 
        self.dec1 = nn.ConvTranspose2d(
            in_channels=init_kernel, out_channels=init_kernel*8, kernel_size=kernel_size, 
            stride=stride, padding=padding
        )
        self.dec2 = nn.ConvTranspose2d(
            in_channels=init_kernel*8, out_channels=init_kernel*4, kernel_size=kernel_size, 
            stride=stride, padding=padding
        )
        self.dec3 = nn.ConvTranspose2d(
            in_channels=init_kernel*4, out_channels=init_kernel*2, kernel_size=kernel_size, 
            stride=stride, padding=padding
        )
        self.dec4 = nn.ConvTranspose2d(
            in_channels=init_kernel*2, out_channels=init_kernel, kernel_size=kernel_size, 
            stride=stride, padding=padding
        )
        self.dec5 = nn.ConvTranspose2d(
            in_channels=init_kernel, out_channels=1, kernel_size=kernel_size, 
            stride=stride, padding=padding
        )
    def reparameterize(self, mu, log_var):
        """
        :param mu: mean from the encoder's latent space
        :param log_var: log variance from the encoder's latent space
        """
        std = torch.exp(0.5*log_var) # standard deviation
        eps = torch.randn_like(std) # `randn_like` as we need the same size
        sample = mu + (eps * std) # sampling
        return sample
 
    def forward(self, x):
        # encoding
        x = F.relu(self.enc1(x))
        x = F.relu(self.enc2(x))
        x = F.relu(self.enc3(x))
        x = F.relu(self.enc4(x))
        x = self.enc5(x)
        # get `mu` and `log_var`
        mu = x
        log_var = x
        # get the latent vector through reparameterization
        z = self.reparameterize(mu, log_var)
 
        # decoding
        x = F.relu(self.dec1(z))
        x = F.relu(self.dec2(x))
        x = F.relu(self.dec3(x))
        x = F.relu(self.dec4(x))
        reconstruction = torch.sigmoid(self.dec5(x))
        return reconstruction, mu, log_var

In [13]:
train_gen.__getitem__(0).shape

torch.Size([1, 384, 256])

In [14]:
vae = ConvVAE().to(device)
    
# Print the total number of model parameters
logging.info(
    f"The model contains {count_parameters(vae)} parameters"
)

if retrain:
    vae = vae.load_state_dict(
        saved_model_optimizer["model_state_dict"], map_location=device
    )
    logging.info(f"Loaded model weights from {model_save_path}")

INFO:root:The model contains 410609 parameters


In [15]:
optimizer_config = config["optimizer"]
learning_rate = optimizer_config["lr"] if not retrain else saved_model_optimizer["lr"]
optimizer_type = optimizer_config["type"]

if optimizer_type == "lookahead-diffgrad":
    optimizer = LookaheadDiffGrad(vae.parameters(), lr=learning_rate)
elif optimizer_type == "diffgrad":
    optimizer = DiffGrad(vae.parameters(), lr=learning_rate)
elif optimizer_type == "lookahead-radam":
    optimizer = LookaheadRAdam(vae.parameters(), lr=learning_rate)
elif optimizer_type == "radam":
    optimizer = RAdam(vae.parameters(), lr=learning_rate)
elif optimizer_type == "adam":
    optimizer = torch.optim.Adam(vae.parameters(), lr=learning_rate)
elif optimizer_type == "sgd":
    optimizer = torch.optim.SGD(vae.parameters(), lr=learning_rate)
else:
    logging.warning(
        f"Optimzer type {optimizer_type} is unknown. Exiting with error."
    )
    sys.exit(1)

logging.info(
    f"Loaded the {optimizer_type} optimizer with learning rate {learning_rate}"
)

if retrain:
    optimizer = optimizer.load_state_dict(
        saved_model_optimizer["optimizer_state_dict"], map_location=device
    )
    logging.info(f"Loaded optimizer weights from {model_save_path}")

INFO:root:Loaded the lookahead-diffgrad optimizer with learning rate 0.001


In [16]:
# Initialize LR annealing scheduler 
schedule_config = config["callbacks"]["ReduceLROnPlateau"]

logging.info(
    f"Loaded ReduceLROnPlateau learning rate annealer with patience {schedule_config['patience']}"
)

scheduler = ReduceLROnPlateau(optimizer,
                              mode=schedule_config["mode"],
                              patience=schedule_config["patience"],
                              factor=schedule_config["factor"],
                              min_lr=schedule_config["min_lr"],
                              verbose=schedule_config["verbose"])

# Early stopping
checkpoint_config = config["callbacks"]["EarlyStopping"]
early_stopping = EarlyStopping(path=model_save_path, 
                               patience=checkpoint_config["patience"], 
                               verbose=checkpoint_config["verbose"])

# Write metrics to csv each epoch
metrics_logger = MetricsLogger(path_save, reload = retrain)

INFO:root:Loaded ReduceLROnPlateau learning rate annealer with patience 5
INFO:holodecml.vae.checkpointer:Loaded EarlyStopping checkpointer with patience 10
INFO:holodecml.vae.checkpointer:Loaded a metrics logger test/training_log.csv to track the training results


In [17]:
logging.info("Loading trainer object")
    
trainer = Trainer(
    vae,
    optimizer,
    train_gen,
    valid_gen, 
    dataloader, 
    valid_dataloader,
    batch_size,
    path_save,
    device,
    test_image
)

INFO:root:Loading trainer object


In [None]:
logging.info(
    f"Training the model for up to {epochs} epochs starting at epoch {start_epoch}"
)

for epoch in range(start_epoch, epochs):

    train_loss, train_bce, train_kld = trainer.train_one_epoch(epoch)
    test_loss, test_bce, test_kld = trainer.test(epoch)

    scheduler.step(test_loss)
    early_stopping(epoch, test_loss, trainer.model, trainer.optimizer)

    # Write results to the callback logger 
    result = {
        "epoch": epoch,
        "train_loss": train_loss,
        "train_bce": train_bce,
        "train_kld": train_kld,
        "valid_loss": test_loss,
        "valid_bce": test_bce,
        "valid_kld": test_kld,
        "lr": early_stopping.print_learning_rate(trainer.optimizer)
    }
    metrics_logger.update(result)

    if early_stopping.early_stop:
        print("Early stopping")
        break

INFO:root:Training the model for up to 100 epochs starting at epoch 0
loss: 67428.523 bce: 66929.622 kld: 498.902: 100%|██████████| 1563/1563 [16:55<00:00,  1.54it/s]
val_loss: 67173.021 val_bce: 66552.117 val_kld: 620.904: 100%|██████████| 313/313 [01:08<00:00,  4.54it/s]
INFO:holodecml.vae.checkpointer:Validation loss decreased on epoch 0 (inf --> 67173.021491).  Saving model.
loss: 67182.075 bce: 66568.344 kld: 613.732:  63%|██████▎   | 978/1563 [10:37<06:20,  1.54it/s]

In [None]:
generate_video(f"{path_save}", "generated_hologram.avi") 