In [1]:
import warnings
warnings.filterwarnings("ignore")

import os
import sys
import yaml
import tqdm
import torch
import pickle
import logging

from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import *
from torch import nn

from typing import List, Dict, Callable, Union, Any, TypeVar, Tuple
from multiprocessing import cpu_count
from shutil import copyfile

# custom
from holodecml.vae.checkpointer import *
from holodecml.vae.data_loader import *
from holodecml.vae.optimizers import *
from holodecml.vae.transforms import *
from holodecml.vae.spectral import *
from holodecml.vae.trainers import *
from holodecml.vae.models import *
from holodecml.vae.visual import *
from holodecml.vae.losses import *

In [2]:
# Experiment configuration (YAML). Later cells read its "transforms", "data",
# "iterator", "model", "optimizer", "trainer" and "callbacks" sections.
CONFIG_PATH = "/glade/work/schreck/repos/holodec-ml/scripts/schreck/vae/results/50_100/attention/config.yml"

# NOTE(review): yaml.FullLoader accepts more than plain scalars/lists/dicts;
# if this file only holds plain data, yaml.safe_load is the stricter choice — confirm.
with open(CONFIG_PATH) as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)

In [3]:
# Configure the root logger for the whole notebook. The root level is DEBUG so
# that each handler decides what it shows; the handler below shows INFO and up.
root = logging.getLogger()
root.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(levelname)s:%(name)s:%(message)s')

# Re-running this cell must not stack another handler on the root logger
# (every extra handler prints each record one more time), so remove any
# handler a previous run of this cell installed before adding a fresh one.
NOTEBOOK_HANDLER_NAME = "notebook-stream"
for stale_handler in [h for h in root.handlers if h.get_name() == NOTEBOOK_HANDLER_NAME]:
    root.removeHandler(stale_handler)

# Stream output to the StreamHandler default stream (sys.stderr)
ch = logging.StreamHandler()
ch.set_name(NOTEBOOK_HANDLER_NAME)
ch.setLevel(logging.INFO)
ch.setFormatter(formatter)
root.addHandler(ch)

In [4]:
# Use the currently selected CUDA device when one is available, else the CPU.
is_cuda = torch.cuda.is_available()
if is_cuda:
    device = torch.device(torch.cuda.current_device())
else:
    device = torch.device("cpu")

logging.info(f'Preparing to use device {device}')

INFO:root:Preparing to use device cuda:0


In [5]:
# Assemble the preprocessing / augmentation pipeline from the "transforms"
# section of the config: a transform is included only if its key is present.
transform_config = config["transforms"]

# (config key, zero-argument builder) pairs, in the order the transforms are
# added to the pipeline. Builders are lazy so that only the transforms named
# in the config are ever constructed (and logged).
transform_builders = [
    ("RandomVerticalFlip", lambda: RandVerticalFlip(0.5)),
    ("RandomHorizontalFlip", lambda: RandHorizontalFlip(0.5)),
    ("Rescale", lambda: Rescale(transform_config["Rescale"])),
    ("Normalize", lambda: Normalize(transform_config["Normalize"])),
    ("ToTensor", lambda: ToTensor(device)),
    ("RandomCrop", lambda: RandomCrop()),
    ("Standardize", lambda: Standardize()),
]

tforms = [build() for key, build in transform_builders if key in transform_config]

transform = transforms.Compose(tforms)

INFO:holodecml.vae.transforms:Loaded RandomVerticalFlip transformation with probability 0.5
INFO:holodecml.vae.transforms:Loaded RandomHorizontalFlip transformation with probability 0.5
INFO:holodecml.vae.transforms:Loaded Rescale transformation with output size 600
INFO:holodecml.vae.transforms:Loaded Normalize transformation that normalizes data in the range 0 to 1
INFO:holodecml.vae.transforms:Loaded ToTensor transformation, putting tensors on device cuda:0


In [6]:
# Both splits share the "data" section of the config and the same transform
# pipeline; only the split name differs.
dataset_kwargs = dict(transform=transform, **config["data"])

train_gen = HologramDataset(split="train", **dataset_kwargs)

# presumably the scalers fitted on the training split — confirm against
# HologramDataset.get_transform
train_scalers = train_gen.get_transform()

# NOTE(review): the held-out split reuses `transform`, which may contain the
# random flips configured earlier — confirm augmenting validation data is intended.
valid_gen = HologramDataset(split="test", **dataset_kwargs)

INFO:holodecml.vae.data_loader:Loaded train hologram data containing 5000 images
INFO:holodecml.vae.data_loader:Loaded test hologram data containing 1000 images


In [7]:
# The training and validation loaders are built from the same "iterator"
# section of the config.
iterator_config = config["iterator"]

logging.info(f"Loading training data iterator using {config['iterator']['num_workers']} workers")

dataloader = DataLoader(train_gen, **iterator_config)
valid_dataloader = DataLoader(valid_gen, **iterator_config)

INFO:root:Loading training data iterator using 8 workers


In [8]:
class Self_Attention(nn.Module):

    """ Self attention layer over the spatial positions of a feature map.
        Based on https://github.com/heykeetae/Self-Attention-GAN/blob/master/sagan_models.py

        Query, key and value are 1x1 convolutions, each wrapped in SpectralNorm,
        that keep the channel count at in_dim. The attended features are scaled
        by the learnable scalar gamma (initialised to zero) and added to the input.
    """

    def __init__(self, in_dim):
        """
            inputs :
                in_dim : number of channels of the incoming feature maps
        """
        super().__init__()

        self.chanel_in = in_dim

        # Create the three 1x1 projections first (query, key, value, in that
        # order) and only then wrap them, so the construction order is unchanged.
        query, key, value = (
            nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
            for _ in range(3)
        )
        self.query_conv = SpectralNorm(query)
        self.key_conv = SpectralNorm(key)
        self.value_conv = SpectralNorm(value)

        self.gamma = nn.Parameter(torch.zeros(1))
        # Normalise over the last axis of the energy matrix
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
            inputs :
                x : input feature maps (B x C x W x H)
            returns :
                out : gamma * attended features + input feature maps (B x C x W x H)
                attention : B x N x N (N is Width*Height), rows sum to 1
        """
        batch, channels, width, height = x.size()
        n_positions = width * height

        queries = self.query_conv(x).view(batch, -1, n_positions).permute(0, 2, 1)  # B x N x C
        keys = self.key_conv(x).view(batch, -1, n_positions)  # B x C x N
        energy = torch.bmm(queries, keys)  # B x N x N
        attention = self.softmax(energy)  # B x N x N
        values = self.value_conv(x).view(batch, -1, n_positions)  # B x C x N

        attended = torch.bmm(values, attention.permute(0, 2, 1))  # B x C x N
        attended = attended.view(batch, channels, width, height)

        # Residual connection: with gamma == 0 the layer is the identity
        out = self.gamma * attended + x
        return out, attention

In [9]:
class ATTENTION_VAE(nn.Module):

    """Convolutional variational autoencoder with self-attention layers.

    Six strided convolution blocks encode the image, a linear bottleneck
    produces ``mu`` / ``logvar`` of size ``z_dim``, and six transposed
    convolution blocks mirror the encoder to decode a latent sample back to
    image space. Self-attention is applied after encoder blocks 3-5 and after
    decoder blocks 1-3.

    The bottleneck flattens the encoder output into ``hidden_dims[-1]``
    features, so the input must shrink to 1 x 1 after the sixth encoder block;
    with the kernel/stride settings used here a 600 x 400 input (dims 2 and 3
    of the batch) does.

    inputs :
        image_channels : number of channels of the input / reconstructed image
        hidden_dims : six channel sizes, one per encoder block
        z_dim : size of the latent vector
    raises :
        ValueError : if ``hidden_dims`` does not hold exactly six entries
    """

    def __init__(self,
                 image_channels=1,
                 hidden_dims=[8, 16, 32, 64, 128, 256],
                 z_dim=10):

        super(ATTENTION_VAE, self).__init__()

        # Copy so that neither the shared default list nor a caller's list is
        # aliased by the instance.
        hidden_dims = list(hidden_dims)
        if len(hidden_dims) != 6:
            # The layer layout below is hard-wired to six blocks.
            raise ValueError(
                "hidden_dims must contain exactly 6 channel sizes "
                f"(one per encoder block), got {len(hidden_dims)}"
            )

        self.image_channels = image_channels
        self.hidden_dims = hidden_dims
        self.z_dim = z_dim

        # NOTE: encoder_atten1/2 and decoder_atten4/5 are built but not used in
        # encode()/decode(). They are kept because removing them would change
        # the parameter set and the state_dict keys of saved checkpoints.
        self.encoder_block1 = self.encoder_block(self.image_channels, self.hidden_dims[0], 4, 2, 1)
        self.encoder_atten1 = Self_Attention(self.hidden_dims[0])
        self.encoder_block2 = self.encoder_block(self.hidden_dims[0], self.hidden_dims[1], 4, 2, 1)
        self.encoder_atten2 = Self_Attention(self.hidden_dims[1])
        self.encoder_block3 = self.encoder_block(self.hidden_dims[1], self.hidden_dims[2], 4, 2, 1)
        self.encoder_atten3 = Self_Attention(self.hidden_dims[2])
        self.encoder_block4 = self.encoder_block(self.hidden_dims[2], self.hidden_dims[3], (3,2), (3,2), 0)
        self.encoder_atten4 = Self_Attention(self.hidden_dims[3])
        self.encoder_block5 = self.encoder_block(self.hidden_dims[3], self.hidden_dims[4], 5, 5, 0)
        self.encoder_atten5 = Self_Attention(self.hidden_dims[4])
        self.encoder_block6 = self.encoder_block(self.hidden_dims[4], self.hidden_dims[5], 5, 5, 0)

        # fc1 -> mu, fc2 -> logvar, fc3 -> latent sample back to decoder features
        self.fc1 = nn.Linear(self.hidden_dims[-1], self.z_dim)
        self.fc2 = nn.Linear(self.hidden_dims[-1], self.z_dim)
        self.fc3 = nn.Linear(self.z_dim, self.hidden_dims[-1])

        self.decoder_block1 = self.decoder_block(self.hidden_dims[5], self.hidden_dims[4], 5, 5, 0)
        self.decoder_atten1 = Self_Attention(self.hidden_dims[4])
        self.decoder_block2 = self.decoder_block(self.hidden_dims[4], self.hidden_dims[3], 5, 5, 0)
        self.decoder_atten2 = Self_Attention(self.hidden_dims[3])
        self.decoder_block3 = self.decoder_block(self.hidden_dims[3], self.hidden_dims[2], (3,2), (3,2), 0)
        self.decoder_atten3 = Self_Attention(self.hidden_dims[2])
        self.decoder_block4 = self.decoder_block(self.hidden_dims[2], self.hidden_dims[1], 4, 2, 1)
        self.decoder_atten4 = Self_Attention(self.hidden_dims[1])
        self.decoder_block5 = self.decoder_block(self.hidden_dims[1], self.hidden_dims[0], 4, 2, 1)
        self.decoder_atten5 = Self_Attention(self.hidden_dims[0])
        self.decoder_block6 = self.decoder_block(self.hidden_dims[0], self.image_channels, 4, 2, 1)

        # Use this module's own logger rather than a `logger` name that only
        # exists because of a star import from another package.
        logging.getLogger(__name__).info("Loaded a self-attentive encoder-decoder VAE model")

    def encoder_block(self, dim1, dim2, kernel_size, stride, padding):
        """Spectrally-normalised Conv2d (dim1 -> dim2) + BatchNorm2d + LeakyReLU."""
        return nn.Sequential(
            SpectralNorm(
                nn.Conv2d(dim1, dim2, kernel_size=kernel_size, stride=stride, padding=padding)
            ),
            nn.BatchNorm2d(dim2),
            nn.LeakyReLU()
        )

    def decoder_block(self, dim1, dim2, kernel_size, stride, padding):
        """Spectrally-normalised ConvTranspose2d (dim1 -> dim2) + BatchNorm2d + LeakyReLU."""
        return nn.Sequential(
            SpectralNorm(
                nn.ConvTranspose2d(dim1, dim2, kernel_size=kernel_size, stride=stride, padding=padding)
            ),
            nn.BatchNorm2d(dim2),
            nn.LeakyReLU()
        )

    def reparameterize(self, mu, logvar):
        """Draw z = mu + std * eps with eps ~ N(0, 1) and std = exp(logvar / 2)."""
        std = torch.exp(0.5 * logvar)
        # Sample the noise directly on std's device and in std's dtype instead
        # of sampling on the CPU in the default dtype and copying it over.
        # The draw therefore comes from that device's random generator.
        eps = torch.randn(mu.size(), dtype=std.dtype, device=std.device)
        return mu + std * eps

    def bottleneck(self, h):
        """Map flattened encoder features to (z, mu, logvar)."""
        mu, logvar = self.fc1(h), self.fc2(h)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar

    def encode(self, x):
        """Encode an image batch and return (z, mu, logvar)."""
        h = self.encoder_block1(x)
        #h, att_map1 = self.encoder_atten1(h)   # attention disabled at this stage
        h = self.encoder_block2(h)
        #h, att_map2 = self.encoder_atten2(h)   # attention disabled at this stage
        h = self.encoder_block3(h)
        h, att_map3 = self.encoder_atten3(h)
        h = self.encoder_block4(h)
        h, att_map4 = self.encoder_atten4(h)
        h = self.encoder_block5(h)
        h, att_map5 = self.encoder_atten5(h)
        h = self.encoder_block6(h)
        h = h.view(h.size(0), -1) # flatten
        z, mu, logvar = self.bottleneck(h)
        return z, mu, logvar#, att_map

    def decode(self, z):
        """Decode latent vectors of size z_dim back to image space."""
        z = self.fc3(z)
        z = z.view(z.size(0), self.hidden_dims[-1], 1, 1) # reshape to B x C x 1 x 1
        z = self.decoder_block1(z)
        z, att_map1 = self.decoder_atten1(z)
        z = self.decoder_block2(z)
        z, att_map2 = self.decoder_atten2(z)
        z = self.decoder_block3(z)
        z, att_map3 = self.decoder_atten3(z)
        z = self.decoder_block4(z)
        #z, att_map4 = self.decoder_atten4(z)   # attention disabled at this stage
        z = self.decoder_block5(z)
        #z, att_map5 = self.decoder_atten5(z)   # attention disabled at this stage
        z = self.decoder_block6(z)
        return z#, att_map

    def forward(self, x):
        """Return (reconstruction, mu, logvar) for an image batch x."""
        z, mu, logvar = self.encode(x)
        z = self.decode(z)
        return z, mu, logvar#, encoder_att_map, decoder_att_map

In [10]:
# Build the attention VAE from the "model" section of the config, then move it
# to the selected device.
#vae = ConvVAE().to(device)
vae = ATTENTION_VAE(**config["model"])
vae = vae.to(device)

# Print the total number of model parameters
logging.info(f"The model contains {count_parameters(vae)} parameters")

INFO:holodecml.vae.losses:Loaded a self-attentive encoder-decoder VAE model
INFO:root:The model contains 41987627 parameters


In [11]:
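# Uncomment to resume from a saved checkpoint: the optimizer cell below reads
# `model_dict` whenever config["trainer"]["start_epoch"] > 0.
# model_dict = torch.load("attention/best.pt", map_location=lambda storage, loc: storage)
# vae.load_state_dict(model_dict["model_state_dict"])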

In [12]:
# Resuming (start_epoch > 0) reuses the learning rate and the optimizer state
# stored in the checkpoint dictionary `model_dict`.
retrain = config["trainer"]["start_epoch"] > 0

# `model_dict` is only created by the (commented-out) checkpoint-loading cell.
# Fail with an actionable message instead of a bare NameError, or of silently
# reusing a stale `model_dict` left over in the kernel from an earlier run.
if retrain and "model_dict" not in globals():
    raise RuntimeError(
        "config['trainer']['start_epoch'] > 0 but `model_dict` is not defined; "
        "load the checkpoint with torch.load before running this cell."
    )

optimizer_config = config["optimizer"]
learning_rate = optimizer_config["lr"] if not retrain else model_dict["lr"]
optimizer_type = optimizer_config["type"]

# Pick the optimizer class named in the config; every choice is built with
# the model parameters and the learning rate only.
if optimizer_type == "lookahead-diffgrad":
    optimizer = LookaheadDiffGrad(vae.parameters(), lr=learning_rate)
elif optimizer_type == "diffgrad":
    optimizer = DiffGrad(vae.parameters(), lr=learning_rate)
elif optimizer_type == "lookahead-radam":
    optimizer = LookaheadRAdam(vae.parameters(), lr=learning_rate)
elif optimizer_type == "radam":
    optimizer = RAdam(vae.parameters(), lr=learning_rate)
elif optimizer_type == "adam":
    optimizer = torch.optim.Adam(vae.parameters(), lr=learning_rate)
elif optimizer_type == "sgd":
    optimizer = torch.optim.SGD(vae.parameters(), lr=learning_rate)
else:
    logging.warning(
        f"Optimzer type {optimizer_type} is unknown. Exiting with error."
    )
    sys.exit(1)

logging.info(
    f"Loaded the {optimizer_type} optimizer with learning rate {learning_rate}"
)

# Restore the optimizer's internal state when resuming from a checkpoint
if retrain:
    optimizer.load_state_dict(model_dict["optimizer_state_dict"])

#logging.info(f"Loaded optimizer weights from {model_dict}")

INFO:root:Loaded the lookahead-diffgrad optimizer with learning rate 0.001


In [13]:
# Objects created in the earlier cells, plus every setting from the "trainer"
# section of the config, are handed to the trainer as keyword arguments.
trainer_kwargs = dict(
    model=vae,
    optimizer=optimizer,
    train_gen=train_gen,
    valid_gen=valid_gen,
    dataloader=dataloader,
    valid_dataloader=valid_dataloader,
    device=device,
    **config["trainer"]
)

trainer = BaseTrainer(**trainer_kwargs)

INFO:holodecml.vae.losses:Loaded Symmetric MSE loss ...
INFO:holodecml.vae.losses:... with alpha = 1.0, gamma = 1.0, and kld_weight = 0.0064
INFO:holodecml.vae.losses:Loaded Symmetric MSE loss ...
INFO:holodecml.vae.losses:... with alpha = 1.0, gamma = 1.0, and kld_weight = 0.032
INFO:holodecml.vae.trainers:Clipping gradients to range [-1.0, 1.0]


In [14]:
# Initialize LR annealing scheduler. ReduceLROnPlateau takes precedence when
# both schedulers are present in the "callbacks" section of the config.
if "ReduceLROnPlateau" in config["callbacks"]:
    schedule_config = config["callbacks"]["ReduceLROnPlateau"]
    scheduler = ReduceLROnPlateau(trainer.optimizer, **schedule_config)
    logging.info(
        f"Loaded ReduceLROnPlateau learning rate annealer with patience {schedule_config['patience']}"
    )
elif "ExponentialLR" in config["callbacks"]:
    schedule_config = config["callbacks"]["ExponentialLR"]
    scheduler = ExponentialLR(trainer.optimizer, **schedule_config)
    logging.info(
        f"Loaded ExponentialLR learning rate annealer with reduce factor {schedule_config['gamma']}"
    )
else:
    # Without this branch `scheduler` would stay unbound (NameError when the
    # training cell runs) or, worse, silently point at a scheduler from a
    # previous run that is tied to an old optimizer.
    raise ValueError(
        "config['callbacks'] must define either 'ReduceLROnPlateau' or "
        "'ExponentialLR'; no learning rate scheduler was configured."
    )

# Early stopping
checkpoint_config = config["callbacks"]["EarlyStopping"]
early_stopping = EarlyStopping(**checkpoint_config)

# Write metrics to csv each epoch
metrics_logger = MetricsLogger(**config["callbacks"]["MetricsLogger"])

INFO:root:Loaded ExponentialLR learning rate annealer with reduce factor 0.95
INFO:holodecml.vae.checkpointer:Loaded EarlyStopping checkpointer with patience 10000000
INFO:holodecml.vae.checkpointer:Loaded a metrics logger /glade/work/schreck/repos/holodec-ml/scripts/schreck/vae/results/50_100/attention/training_log.csv to track the training results


In [15]:
# Run the training loop with the LR scheduler, early-stopping checkpointer and
# metrics logger configured above.
trainer.train(scheduler, early_stopping, metrics_logger)

INFO:holodecml.vae.trainers:Training the model for up to 100 epochs starting at epoch 0
loss: 525624.485 bce: 477338.731 kld: 7544649.418: 100%|██████████| 157/157 [01:12<00:00,  2.17it/s]
val_loss: 302711.242 val_bce: 293364.262 val_kld: 292093.090: 100%|██████████| 32/32 [00:06<00:00,  4.91it/s]
INFO:holodecml.vae.checkpointer:Validation loss decreased on epoch 0 (inf --> 302711.242310).  Saving model.
loss: 343985.422 bce: 339558.788 kld: 691661.538: 100%|██████████| 157/157 [01:12<00:00,  2.17it/s]
val_loss: 343817.781 val_bce: 321395.951 val_kld: 700682.086: 100%|██████████| 32/32 [00:06<00:00,  4.67it/s] 
INFO:holodecml.vae.checkpointer:EarlyStopping counter: 1 out of 10000000
loss: 275243.692 bce: 272796.998 kld: 382296.083: 100%|██████████| 157/157 [01:11<00:00,  2.18it/s]
val_loss: 341882.345 val_bce: 335105.186 val_kld: 211786.183: 100%|██████████| 32/32 [00:06<00:00,  4.82it/s]
INFO:holodecml.vae.checkpointer:EarlyStopping counter: 2 out of 10000000
loss: 235419.699 bce: 234

KeyboardInterrupt: 

In [None]:
#generate_video(f"{path_save}", "generated_hologram.avi") 