In [1]:
import warnings
warnings.filterwarnings("ignore")

import os
import pickle
import sys
from multiprocessing import cpu_count
from typing import List, Dict

import numpy as np
import torch
import yaml
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader

# custom
# NOTE(review): `transforms` (used as transforms.Compose) and `save_image` are not
# imported explicitly; presumably they are re-exported by `from transforms import *`.
from checkpointer import EarlyStopping, MetricsLogger
from data_loader import HologramDataset
from optimizers import LookaheadDiffGrad
from transforms import *
from models import count_parameters, CNN_VAE
from visual import generate_video
from losses import loss_fn

In [2]:
# Pick the compute device: the currently selected CUDA device when available, else CPU.
is_cuda = torch.cuda.is_available()
if is_cuda:
    device = torch.device(torch.cuda.current_device())
else:
    device = torch.device("cpu")

In [3]:
if "ipykernel_launcher" not in sys.argv[0]:
    # Not running under an IPython kernel: use tqdm's progress-bar class as-is.
    from tqdm import tqdm
else:
    # Under Jupyter: wrap tqdm so that, before each new bar is created, every
    # bar still registered on the class is deregistered. Presumably this keeps
    # leftover bars (e.g. from an interrupted cell) from disturbing the display.
    # Relies on tqdm's private `_instances` / `_decr_instances` attributes.
    from tqdm import tqdm as tqdm_base

    def tqdm(*args, **kwargs):
        for stale_bar in list(getattr(tqdm_base, "_instances", ())):
            tqdm_base._decr_instances(stale_bar)
        return tqdm_base(*args, **kwargs)

In [4]:
config_file = 'config.yml'  # YAML run configuration, resolved relative to the working directory

In [5]:
# Parse the YAML run configuration into `config`.
# The file handle gets its own name (`fid`): binding it to `config_file` would
# replace the path string with a closed file object, so re-running this cell
# would fail in open().
with open(config_file) as fid:
    config = yaml.load(fid, Loader=yaml.FullLoader)

In [6]:
# Ensure the output directory exists. exist_ok=True makes this idempotent on
# re-runs while still surfacing genuine failures (bad path, permissions, a
# missing "path_save" key) that the previous bare `except: pass` silently hid.
os.makedirs(config["path_save"], exist_ok=True)

In [7]:
# Paths and sizes read from the YAML config.
path_data = config["path_data"]
path_save = config["path_save"]
num_particles = config["num_particles"]

# Fixed dataset options used by this notebook.
split, subset = 'train', False
output_cols = ["x", "y", "z", "d", "hid"]

batch_size = config["conv2d_network"]["batch_size"]

input_shape = (600, 400, 1)

# Alias of num_particles (both come from config["num_particles"]).
n_particles = num_particles
# All output columns except the last ("hid").
output_channels = len(output_cols) - 1

In [8]:
# Per-sample preprocessing: Rescale(384) -> Normalize() -> ToTensor(device).
# RandomCrop(224) and Standardize() were previously listed here, commented out.
# NOTE(review): ToTensor receives `device`; if it moves samples to CUDA, that
# happens inside DataLoader worker processes (num_workers > 0), where CUDA
# typically cannot be initialised in forked workers — confirm ToTensor's behaviour.
_preprocessing_steps = [Rescale(384), Normalize(), ToTensor(device)]
transform = transforms.Compose(_preprocessing_steps)

In [9]:
# Training split of the hologram dataset (at most 3 particles per hologram).
train_gen = HologramDataset(
    path_data,
    num_particles,
    "train",
    subset,
    output_cols,
    transform=transform,
    maxnum_particles=3,
)

In [10]:
# Shuffled mini-batches over the training split.
# Worker count is capped at the number of CPUs: requesting more processes than
# cores only oversubscribes the machine (PyTorch warns about it), and the
# previously hard-coded 24 ignored the already-imported cpu_count.
dataloader = DataLoader(
    train_gen,
    batch_size=batch_size,
    shuffle=True,
    num_workers=min(24, cpu_count()),
)

In [11]:
# Whatever train_gen.get_transform() returns is handed to the test split as
# `scaler`, presumably so both splits share training-set statistics — TODO confirm.
train_scalers = train_gen.get_transform()

valid_gen = HologramDataset(
    path_data,
    num_particles,
    "test",
    subset,
    output_cols,
    transform=transform,
    maxnum_particles=3,
    scaler=train_scalers,
)

In [12]:
# Deterministic (unshuffled) mini-batches over the test split.
# Worker count is capped at the number of CPUs rather than a fixed 24, so the
# notebook does not oversubscribe smaller machines.
valid_dataloader = DataLoader(
    valid_gen,
    batch_size=batch_size,
    shuffle=False,
    num_workers=min(24, cpu_count()),
)

### VAE model

In [13]:
# Training length: the loop runs epochs start_epoch .. epochs-1.
start_epoch = 0
epochs = 100

# VAE sizes passed to CNN_VAE.
image_channels = 1
h_dim = 50_176  # NOTE(review): presumably must equal the flattened encoder output size — confirm against CNN_VAE
z_dim = 1_000   # latent dimensionality

In [14]:
# Build the VAE and move it to the selected device.
vae = CNN_VAE(
    image_channels=image_channels,
    h_dim=h_dim,
    z_dim=z_dim,
)
vae = vae.to(device)

In [15]:
n_params = count_parameters(vae)
print(n_params)

104995849


In [16]:
# Project optimizer (optimizers.LookaheadDiffGrad) over all VAE parameters.
learning_rate = 1e-3
optimizer = LookaheadDiffGrad(vae.parameters(), lr=learning_rate)

In [17]:
# LR annealing: multiply the LR by 0.5 once the monitored value has not
# decreased for 5 consecutive steps, never going below 1e-14.
scheduler = ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=5, min_lr=1e-14, verbose=True
)

# Early stopping (project class) with patience 10; checkpoints go to
# <path_save>/checkpoint.pt.
model_save_path = os.path.join(str(path_save), "checkpoint.pt")
early_stopping = EarlyStopping(path=model_save_path, patience=10, verbose=True)

In [18]:
def train_one_epoch(epoch):
    """Run one optimisation pass of ``vae`` over ``dataloader``.

    Uses the notebook-level globals ``vae``, ``dataloader``, ``optimizer``,
    ``loss_fn``, ``device`` and ``tqdm``.

    Args:
        epoch: Current epoch index. Not used in the computation; kept so the
            signature mirrors ``test(epoch)``.

    Returns:
        Tuple ``(loss, bce, kld)`` of floats: each loss term summed over the
        epoch and divided by the number of samples seen.

    Raises:
        RuntimeError: If ``dataloader`` yields no batches.
    """
    vae.train()
    progress = tqdm(dataloader, total=len(dataloader), leave=True)

    totals = {"loss": 0.0, "bce": 0.0, "kld": 0.0}
    n_seen = 0
    for images in progress:
        images = images.to(device)
        recon_images, mu, logvar = vae(images)
        loss, bce, kld = loss_fn(recon_images, images, mu, logvar)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Normalise by the samples actually in this batch: the last batch can be
        # smaller than `batch_size`. Assumes loss_fn sums over the batch (the
        # original per-sample division implies this) — TODO confirm.
        n_seen += images.size(0)
        totals["loss"] += loss.item()
        totals["bce"] += bce.item()
        totals["kld"] += kld.item()

        to_print = "loss: {:.3f} bce: {:.3f} kld: {:.3f}".format(
            totals["loss"] / n_seen, totals["bce"] / n_seen, totals["kld"] / n_seen
        )
        progress.set_description(to_print)
        # No manual progress.update(): tqdm already advances once per item
        # when it wraps an iterable; an extra update() corrupts the count.

    if n_seen == 0:
        raise RuntimeError("train_one_epoch: dataloader produced no batches")
    return tuple(totals[key] / n_seen for key in ("loss", "bce", "kld"))

In [19]:
def compare(epoch, x):
    """Save ``x`` together with its ``vae`` reconstruction as one image file.

    Uses the notebook-level globals ``vae``, ``device``, ``path_save`` and
    ``save_image``.

    Args:
        epoch: Epoch index, used in the output file name.
        x: Batch of input images; moved to ``device`` before the forward pass.
            Presumably shaped (N, C, H, W) as ``save_image`` expects — TODO confirm.

    Side effects:
        Writes ``{path_save}/image_epoch_{epoch}.png``. Inputs and
        reconstructions are concatenated along dim 0 (inputs first).
    """
    x = x.to(device)
    # Inference only: disable autograd here because callers may invoke this
    # outside a torch.no_grad() block, which would otherwise build a graph.
    with torch.no_grad():
        recon_x, _, _ = vae(x)
        compare_x = torch.cat([x, recon_x])
    save_image(compare_x.cpu(), f'{path_save}/image_epoch_{epoch}.png')

In [20]:
def test(epoch):
    """Evaluate ``vae`` on ``valid_dataloader`` and save a reconstruction image.

    Uses the notebook-level globals ``vae``, ``valid_dataloader``, ``loss_fn``,
    ``device``, ``tqdm`` and ``compare``. A reference batch is loaded from
    ``image.pkl`` in the current working directory and passed to ``compare``.

    Args:
        epoch: Epoch index, forwarded to ``compare`` to name the output image.

    Returns:
        Tuple ``(loss, bce, kld)`` of floats: each loss term summed over the
        validation set and divided by the number of samples seen.

    Raises:
        RuntimeError: If ``valid_dataloader`` yields no batches.
    """
    vae.eval()
    totals = {"loss": 0.0, "bce": 0.0, "kld": 0.0}
    n_seen = 0

    with torch.no_grad():
        progress = tqdm(valid_dataloader, total=len(valid_dataloader), leave=True)
        for images in progress:
            images = images.to(device)
            recon_images, mu, logvar = vae(images)
            loss, bce, kld = loss_fn(recon_images, images, mu, logvar)

            # Normalise by the real batch size (the last batch may be short).
            # Assumes loss_fn sums over the batch — TODO confirm.
            n_seen += images.size(0)
            totals["loss"] += loss.item()
            totals["bce"] += bce.item()
            totals["kld"] += kld.item()

            to_print = "val_loss: {:.3f} val_bce: {:.3f} val_kld: {:.3f}".format(
                totals["loss"] / n_seen, totals["bce"] / n_seen, totals["kld"] / n_seen
            )
            progress.set_description(to_print)
            # tqdm advances automatically when iterating; no manual update().

        if n_seen == 0:
            raise RuntimeError("test: valid_dataloader produced no batches")

        # NOTE(review): pickle.load can execute arbitrary code — only load a
        # trusted image.pkl.
        with open("image.pkl", "rb") as fid:
            pic = pickle.load(fid)
        # Kept inside no_grad so the reconstruction does not build an autograd graph.
        compare(epoch, pic)

    return tuple(totals[key] / n_seen for key in ("loss", "bce", "kld"))

In [21]:
# Logger built on the save directory; the training loop hands it one result
# dict per epoch via .update().
log_dir = path_save
metrics_logger = MetricsLogger(log_dir)

In [22]:
for epoch in range(start_epoch, epochs):
    train_loss, train_bce, train_kld = train_one_epoch(epoch)
    test_loss, test_bce, test_kld = test(epoch)

    # Both LR annealing and early stopping / checkpointing track validation loss.
    scheduler.step(test_loss)
    early_stopping(epoch, test_loss, vae, optimizer)

    # One row of metrics per epoch for the logger.
    result = dict(
        epoch=epoch,
        train_loss=train_loss,
        train_bce=train_bce,
        train_kld=train_kld,
        valid_loss=test_loss,
        valid_bce=test_bce,
        valid_kld=test_kld,
        lr=early_stopping.print_learning_rate(optimizer),
    )
    metrics_logger.update(result)

    if not early_stopping.early_stop:
        continue
    print("Early stopping")
    break

loss: 68719.289 bce: 68719.251 kld: 0.038:   6%|▌         | 22/391 [00:17<05:00,  1.23it/s] 


KeyboardInterrupt: 

In [30]:
# Stitch the per-epoch comparison frames into a video. compare() writes the
# frames to path_save, so read them from there rather than the working
# directory. os.path.join(..., "") keeps the trailing separator that the old
# "./" argument had, in case generate_video concatenates paths as strings.
# NOTE(review): assumes generate_video's first argument is the frame directory.
generate_video(os.path.join(path_save, ""), "generated_hologram.avi")

['image_epoch_0.png', 'image_epoch_1.png', 'image_epoch_2.png', 'image_epoch_3.png', 'image_epoch_4.png', 'image_epoch_5.png', 'image_epoch_6.png', 'image_epoch_7.png', 'image_epoch_8.png', 'image_epoch_9.png', 'image_epoch_10.png', 'image_epoch_11.png', 'image_epoch_12.png', 'image_epoch_13.png', 'image_epoch_14.png', 'image_epoch_15.png', 'image_epoch_16.png', 'image_epoch_17.png', 'image_epoch_18.png', 'image_epoch_19.png', 'image_epoch_20.png', 'image_epoch_21.png', 'image_epoch_22.png', 'image_epoch_23.png', 'image_epoch_24.png', 'image_epoch_25.png', 'image_epoch_26.png', 'image_epoch_27.png', 'image_epoch_28.png', 'image_epoch_29.png', 'image_epoch_30.png', 'image_epoch_31.png', 'image_epoch_32.png', 'image_epoch_33.png', 'image_epoch_34.png', 'image_epoch_35.png', 'image_epoch_36.png', 'image_epoch_37.png', 'image_epoch_38.png', 'image_epoch_39.png', 'image_epoch_40.png', 'image_epoch_41.png', 'image_epoch_42.png', 'image_epoch_43.png', 'image_epoch_44.png', 'image_epoch_45.png