# Baseline

## Load datasets

In [1]:
import torch
from torch import nn
from torch.utils import data
import torchvision as vis
import sys

# torch.manual_seed(117850791)
# Environment probes reused by the DataLoader and training cells further down.
has_cuda = torch.cuda.is_available()
is_windows = (sys.platform == "win32")

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
if has_cuda:
  device = torch.device("cuda:0")
else:
  device = torch.device("cpu")
device

device(type='cuda', index=0)

In [2]:
from PIL import Image
import os

class FlatImageData(vis.datasets.VisionDataset):
  """Dataset over a flat directory of frame images, split into two index ranges.

  After sorting, the first ``validation_reserved_images`` files form the validation
  range and the remaining files form the training range. ``training_mode`` selects
  which of the two ranges ``__len__`` and ``__getitem__`` expose.
  """

  def __init__(self, root, transform, validation_reserved_images=31136):
    """
    :param root: directory whose entries are the image files (listed with ``os.listdir``)
    :param transform: callable applied to every loaded RGB image, or ``None``
    :param validation_reserved_images: number of leading (sorted) images held out
                                       for validation
    """
    self.root = root
    self.images = os.listdir(root)
    # Sort by frame number: the key strips a 6-character prefix and a 5-character
    # suffix from the file name and parses the rest as an int
    # (presumably names look like "frame_<n>.jpeg" -- TODO confirm).
    self.images.sort(key=lambda x: int(x[6:-5]))
    self.transform = transform
    self.training_mode = True
    self.reserved_images = validation_reserved_images

  def __len__(self):
    """Number of images in the currently selected range (never negative)."""
    total = len(self.images)
    if self.training_mode:
      return max(0, total - self.reserved_images)
    return min(self.reserved_images, total)

  def pil_loader(self, path: str) -> Image.Image:
    """Load the file at ``path`` and return it converted to an RGB PIL image."""
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
      img = Image.open(f)
      return img.convert('RGB')

  def __getitem__(self, index):
    """Return the (optionally transformed) image at ``index`` of the active range.

    Negative indices count from the end of the active range.

    :raises IndexError: if ``index`` lies outside the active range. Without this
                        check validation mode would silently serve training images
                        and negative training indices would serve validation images.
    """
    size = self.__len__()
    if index < 0:
      index += size
    if not 0 <= index < size:
      raise IndexError(f"index {index} is out of range for {size} images")

    if self.training_mode:
      # training images are stored after the block reserved for validation
      index += self.reserved_images

    image_path = os.path.join(self.root, self.images[index])
    image = self.pil_loader(image_path)
    if self.transform is not None:
      image = self.transform(image)

    return image
    
class ImageWindow(data.Dataset):
  """Groups consecutive items of a dataset into fixed-size, non-overlapping windows.

  Item ``i`` is the stack of ``wide_window_size`` consecutive dataset items starting
  at position ``window_offset + i * wide_window_size``.
  """

  def __init__(self, dataset, wide_window_size):
    self.dataset, self.wide_window_size = dataset, wide_window_size
    # windows start at the very first item until shuffle() picks another offset
    self.window_offset = 0

  def shuffle(self):
    """Pick a new random window offset in ``[0, wide_window_size)``."""
    drawn = torch.randint(0, self.wide_window_size, (1,))
    self.window_offset = drawn.item()

  def __len__(self):
    """Number of complete windows that fit after the current offset."""
    usable_items = len(self.dataset) - self.window_offset
    return usable_items // self.wide_window_size

  def __getitem__(self, index):
    """Return window ``index`` as one tensor with the window items stacked on dim 0."""
    first = self.window_offset + index * self.wide_window_size
    frames = []
    for position in range(first, first + self.wide_window_size):
      frames.append(self.dataset[position])
    return torch.stack(frames, dim=0)

In [3]:
# Augmentation applied as a unit with probability 0.5:
# RandomAffine(degrees=15) followed by a 200x200 centre crop.
random_affine_crop = vis.transforms.RandomApply(
  nn.ModuleList([
    vis.transforms.RandomAffine(degrees=15),
    vis.transforms.CenterCrop((200, 200)),
  ]),
  p=0.5,
)

# Full per-image pipeline; the final adaptive pooling brings every image to 64x64.
frame_transform = vis.transforms.Compose([
  vis.transforms.RandomHorizontalFlip(),
  random_affine_crop,
  vis.transforms.ToTensor(),
  nn.AdaptiveAvgPool2d((64, 64)),
])

dataset = FlatImageData(root="/home/ubuntu/data/knnw-256p", transform=frame_transform)
# dataset.training_mode=False
dataset

Dataset FlatImageData
    Number of datapoints: 160745
    Root location: /home/ubuntu/data/knnw-256p

In [4]:
# Serve the dataset in non-overlapping groups of 3 consecutive images (wide_window_size=3).
window = ImageWindow(dataset, 3)

## Train Model

In [5]:
import os
# Root directory under which every model id gets its own checkpoint folder.
model_store = "model_checkpoints"

class StoredModel:
  """Plain container bundling the objects written to a single checkpoint file."""

  def __init__(self, model, optimizer, scheduler, criterion):
    (self.model,
     self.optimizer,
     self.scheduler,
     self.criterion) = (model, optimizer, scheduler, criterion)

In [6]:
from torch.nn import functional as F
from typing import List, Callable, Union, Any, TypeVar, Tuple
Tensor = TypeVar('torch.tensor')

class BetaVAE(nn.Module):
    """Beta-VAE over windows of frames.

    Every frame of a window is passed through the convolutional encoder, the per-frame
    features are averaged over the window, and a single image is decoded from the
    resulting latent code.

    The fully connected layers assume the encoder produces ``hidden_dims[-1] x 2 x 2``
    features per frame (e.g. 64x64 inputs with the default five stride-2 stages).
    """

    # Class-level default for the iteration counter; ``loss`` rebinds it on the
    # instance the first time it increments it.
    num_iter = 0

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 beta: int = 4,
                 gamma:float = 1000.,
                 max_capacity: int = 25,
                 Capacity_max_iter: int = 1e5,
                 loss_type:str = 'B',
                 **kwargs) -> None:
        """
        :param in_channels: (int) Number of channels of each input frame
        :param latent_dim: (int) Size of the latent code
        :param hidden_dims: (List) Output channels of the encoder stages, in order;
                            defaults to [32, 64, 128, 256, 512]. The list is copied,
                            never modified.
        :param beta: Weight of the KL term for loss type 'H'
        :param gamma: Weight of the capacity term for loss type 'B'
        :param max_capacity: Upper bound of the capacity C for loss type 'B'
        :param Capacity_max_iter: Number of ``loss`` calls over which C grows to its maximum
        :param loss_type: 'H' or 'B' (see ``loss``)
        """
        super(BetaVAE, self).__init__()

        self.latent_dim = latent_dim
        self.beta = beta
        self.gamma = gamma
        self.loss_type = loss_type
        self.C_max = torch.Tensor([max_capacity])
        self.C_stop_iter = Capacity_max_iter

        # Work on a private copy so the caller's list is never reordered.
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        else:
            hidden_dims = list(hidden_dims)

        # Build Encoder
        modules = []
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        modules.append(nn.Flatten())

        self.encoder = nn.Sequential(*modules)

        # "* 4" is the 2 x 2 spatial extent expected at the end of the encoder.
        self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)

        # Build Decoder
        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

        decoder_dims = hidden_dims[::-1]

        modules = []
        for i in range(len(decoder_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(decoder_dims[i],
                                       decoder_dims[i + 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(decoder_dims[i + 1]),
                    nn.LeakyReLU())
            )

        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
                            nn.ConvTranspose2d(decoder_dims[-1],
                                               decoder_dims[-1],
                                               kernel_size=3,
                                               stride=2,
                                               padding=1,
                                               output_padding=1),
                            nn.BatchNorm2d(decoder_dims[-1]),
                            nn.LeakyReLU(),
                            nn.Conv2d(decoder_dims[-1], out_channels=3,
                                      kernel_size=3, padding=1))

    def encode(self, inputs: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Encodes every frame of every window, averages the per-frame features over the
        window and maps the result to the parameters of the latent Gaussian.
        :param inputs: (Tensor) Input tensor to encoder [N x Window x C x H x W]
        :return: Tuple (inputs, mu, log_var); mu and log_var are [N x latent_dim]
        """
        batch_size, window_size, C, H, W = inputs.shape

        # Fold the window dimension into the batch; reshape also accepts
        # non-contiguous inputs, which view would reject.
        result = self.encoder(inputs.reshape(-1, C, H, W))

        # use average
        combined_features = result.view(batch_size, window_size, -1).mean(dim=1)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(combined_features)
        log_var = self.fc_var(combined_features)

        return (inputs, mu, log_var)

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps latent codes to images (raw logits, no final activation).
        :param z: (Tensor) [N x latent_dim]
        :return: (Tensor) [N x 3 x H x W]
        """
        result = self.decoder_input(z)
        # Channel count is derived from the projection layer rather than hard-coded, so
        # custom hidden_dims work; reading it from an existing module (not a new
        # attribute) keeps models unpickled from older checkpoints usable.
        channels = self.decoder_input.out_features // 4
        result = result.view(-1, channels, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick: draws z = mu + eps * std with eps ~ N(0, I).
        A single sample per input is drawn.
        :param mu: (Tensor) Mean of the latent Gaussian
        :param logvar: (Tensor) Log-variance of the latent Gaussian
        :return: (Tensor) Sampled latent code, same shape as mu
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, inputs: Tensor, **kwargs) -> Tensor:
        """
        Encodes the windows, samples a latent code and decodes one image per window.
        The inputs, mu, log_var and reconstruction are cached on the module because
        ``loss`` reads them afterwards.
        :param inputs: (Tensor) [N x Window x C x H x W]
        :return: (Tensor) Reconstruction logits [N x 3 x H x W]
        """
        inputs, mu, log_var = self.encode(inputs)
        z = self.reparameterize(mu, log_var)

        self.current_inputs = inputs
        self.current_mu = mu
        self.current_log_var = log_var
        self.current_recon = self.decode(z)

        return self.current_recon

    def loss(self, *args, **kwargs) -> dict:
        """
        Computes the loss for the tensors cached by the latest ``forward`` call and
        increments the iteration counter.
        :param kld_weight: (keyword, required) Scale applied to the KL term
        :return: dict with keys 'loss', 'Reconstruction_Loss' and 'KLD'
        :raises ValueError: if ``loss_type`` is neither 'H' nor 'B'
        """
        self.num_iter += 1
        recons = self.current_recon
        inputs = self.current_inputs
        mu = self.current_mu
        log_var = self.current_log_var
        kld_weight = kwargs['kld_weight']  # Account for the minibatch samples from the dataset

        batch_size = recons.shape[0]
        window_size = inputs.shape[1]
        # The single reconstruction is compared against every frame of its window;
        # the summed loss is averaged over batch and window.
        recons_loss = F.binary_cross_entropy_with_logits(
            recons.unsqueeze(dim=1).expand_as(inputs),
            inputs, reduction='sum') / batch_size / window_size

        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)

        if self.loss_type == 'H': # https://openreview.net/forum?id=Sy2fzU9gl
            loss = recons_loss + self.beta * kld_weight * kld_loss
        elif self.loss_type == 'B': # https://arxiv.org/pdf/1804.03599.pdf
            self.C_max = self.C_max.to(inputs.device)
            # Capacity grows linearly with the iteration count and is capped at C_max.
            C = torch.clamp(self.C_max/self.C_stop_iter * self.num_iter, 0, self.C_max.data[0])
            loss = recons_loss + self.gamma * kld_weight * (kld_loss - C).abs()
        else:
            raise ValueError('Undefined loss type.')

        return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':kld_loss}

    def sample(self,
               num_samples:int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and return the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = torch.randn(num_samples,
                        self.latent_dim)

        z = z.to(current_device)

        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given a batch of input windows x, returns the reconstructed images.
        :param x: (Tensor) [B x Window x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """
        # forward() returns the reconstruction tensor itself; indexing it with [0]
        # would return only the first image of the batch.
        return self.forward(x)

In [7]:
def load_model(model_id, specific_epoch = None):
  """Load a checkpoint of ``model_id`` from ``model_store``.

  Checkpoint files are the directory entries whose name starts with "epoch"; the
  epoch number is the second "_"-separated field of the file name.

  Side effect: on success the module-level ``optimizer`` and ``scheduler`` are
  rebound to the objects stored in the checkpoint.

  :param model_id: name of the sub-directory of ``model_store`` to search
  :param specific_epoch: epoch to load; ``None`` selects the highest epoch found
  :return: ``(epoch_start, model, criterion)`` where ``epoch_start`` is the epoch
           after the loaded one, or ``(-1, None, None)`` when no matching
           checkpoint exists (always a 3-tuple so callers can unpack it uniformly)
  """
  global optimizer, scheduler
  epoch_start = -1
  last_checkpoint = None
  for checkpoint in os.listdir(f"{model_store}/{model_id}"):
    if not checkpoint.startswith("epoch"):
      continue
    epoch = int(checkpoint.split("_")[1])
    if specific_epoch is None:
      # find the latest
      if epoch > epoch_start:
        epoch_start = epoch
        last_checkpoint = checkpoint
    elif epoch == specific_epoch:
      epoch_start = epoch
      last_checkpoint = checkpoint
      break

  if epoch_start == -1:
    print(f"No checkpoints available for {model_id}!")
    return -1, None, None

  epoch_start += 1
  print(f"resuming from last checkpoint {last_checkpoint}")
  # NOTE(review): checkpoints are whole pickled objects, so only load files from a
  # trusted source; recent PyTorch releases may additionally require
  # weights_only=False for this kind of file -- confirm against the installed version.
  # map_location lets a checkpoint written on a GPU machine load on the current device.
  stored = torch.load(f"{model_store}/{model_id}/{last_checkpoint}", map_location=device)

  model = stored.model
  optimizer = stored.optimizer
  scheduler = stored.scheduler
  criterion = stored.criterion

  model.to(device)
  return epoch_start, model, criterion

### Resume from checkpoint or a new model?

#### train a new model

In [8]:
from torchsummary import summary_string
# NOTE(review): the id says "10_latent" while the model below is built with a latent
# size of 128 -- confirm which one is intended before relying on the name.
model_id = "many_to_one_mean_B_loss_64x64_10_latent"
epoch_start = 0  # fresh model, so training starts at the first epoch

model = BetaVAE(in_channels=3, latent_dim=128, loss_type="B")
model.to(device)
print(model)

# Summary input size is one sample without the batch dimension:
# (window, channels, height, width).
summary_result = summary_string(model, (3, 3, 64, 64))
model_spec = summary_result[0]
print(model_spec)

BetaVAE(
  (encoder): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (1): Sequential(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (3): Sequential(
      (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (4): Seque

In [9]:
# Create the checkpoint directory for this model. makedirs with exist_ok=True also
# creates the parent `model_store` directory when it is missing and does not fail when
# the cell is run a second time (os.mkdir raised in both situations).
os.makedirs(f"{model_store}/{model_id}", exist_ok=True)
# save model summary to a txt file
with open(f"{model_store}/{model_id}/model_spec.txt", "w") as file:
  file.write(str(model) + "\n")
  file.write(model_spec)

#### load a trained model from checkpoint 

In [26]:
# Resume a previously trained model instead of creating a new one.
# load_model is called without an epoch, so it picks the highest epoch found for this
# id; it returns the next epoch to train, the model and the criterion, and rebinds the
# module-level `optimizer` and `scheduler` as a side effect.
model_id = "many_to_one_mean_H_loss_128x128_128_latent_beta_2.5"
epoch_start, model, criterion = load_model(model_id)
print(model)

resuming from last checkpoint epoch_23_tr-loss_25402.493092
BetaVAE(
  (encoder): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (1): Sequential(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (3): Sequential(
      (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  

### Start training

In [None]:
# Clear the CUDA cache before training starts; skipped entirely on machines
# where no CUDA device was detected.
if has_cuda:
  torch.cuda.empty_cache()

In [10]:
# DataLoader settings: worker processes are only configured when CUDA is available,
# and Windows always loads in the main process (num_workers=0).
train_dataloader_args = {"batch_size": 64}
if has_cuda:
  train_dataloader_args["num_workers"] = 0 if is_windows else 4
train_dataloader_args["shuffle"] = True

train_dataloader = data.DataLoader(window, **train_dataloader_args)

In [11]:
from torch import optim
from itertools import chain

num_epochs = 40

# The optimizer and scheduler are only created for a fresh run; when resuming,
# the ones restored from the checkpoint are kept.
if epoch_start == 0:
  regularization = 2e-5
  learning_rate = 1e-3
  # (an SGD + StepLR setup with learning_rate = 1e-1 was tried earlier and replaced by this one)
  optimizer = optim.Adam(
    model.parameters(),
    lr=learning_rate,
    weight_decay=regularization,
  )
  scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=2,
    threshold=0.001,
  )

# gradient scaler for mixed-precision training
scaler = torch.cuda.amp.GradScaler()

# Record the training configuration next to the checkpoints.
with open(f"{model_store}/{model_id}/training_params.txt", "w") as file:
  file.writelines([
    f"num_epochs = {num_epochs}\n",
    f"optimizer = {optimizer}\n",
    f"scheduler = {type(scheduler).__name__}({scheduler.state_dict()})\n",
  ])

In [15]:
# Manual reset: put every parameter group back to a learning rate of 1e-3 and replace
# the scheduler with a fresh ReduceLROnPlateau, discarding the previous scheduler's state.
for param_group in optimizer.param_groups:
  param_group['lr'] = 1e-3
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2,
                                                threshold=0.001)

In [12]:
from tqdm import tqdm
import sys
import json

print(f"Model: {model_id}. Training for {num_epochs} epochs", file=sys.stderr)

# NOTE(review): window.shuffle() is never called in the visible cells, so the window
# offset keeps its initial value for every epoch -- confirm whether re-drawing the
# offset once per epoch was intended.
for epoch in range(epoch_start, num_epochs):
  print(f"Epoch {epoch}", file=sys.stderr)

  # set model in training mode
  model.train()
  training_loss = 0.0
  reconstruction_loss = 0.0
  kld_loss = 0.0

  for x in tqdm(train_dataloader, desc="Train"):
    optimizer.zero_grad() # clear calculated gradients

    x = x.to(device)

    with torch.cuda.amp.autocast():
      # the forward pass caches the tensors that model.loss() reads
      output = model(x)
      all_loss = model.loss(kld_weight=1.0)
      loss = all_loss["loss"]

    # backpropagate the scaled loss and accumulate loss statistics
    scaler.scale(loss).backward()

    training_loss += loss.detach().item() # otherwise this would be a tensor
    reconstruction_loss += all_loss['Reconstruction_Loss'].detach().item()
    kld_loss += all_loss['KLD'].detach().item()

    scaler.step(optimizer)
    scaler.update()

  # turn the accumulated sums into per-batch averages for this epoch
  num_batches = len(train_dataloader)
  training_loss /= num_batches
  reconstruction_loss /= num_batches
  kld_loss /= num_batches

  # let the plateau scheduler see this epoch's average training loss
  scheduler.step(training_loss)

  # Read the learning rates from the optimizer (public API) instead of the
  # scheduler's private `_last_lr` attribute; after step() both hold the same values.
  current_lr = [param_group["lr"] for param_group in optimizer.param_groups]

  log_str = json.dumps({
    "Epoch": epoch,
    "training loss": round(training_loss, 6),
    "reconstruction loss": round(reconstruction_loss, 6),
    "KLD loss": round(kld_loss, 6),
    "Learning rate": current_lr
  })

  with open(f"{model_store}/{model_id}/training_logs.txt", "a") as log_file:
    log_file.write(log_str + "\n")
  print(log_str, file=sys.stderr)

  # one checkpoint per epoch; the file name carries the epoch and its training loss
  torch.save(StoredModel(model, optimizer, scheduler, None),
             f"{model_store}/{model_id}/epoch_{epoch:02d}" +\
             f"_tr-loss_{training_loss:.6f}")

Model: many_to_one_mean_B_loss_64x64_10_latent. Training for 40 epochs
Epoch 0
Train: 100%|██████████| 838/838 [02:14<00:00,  6.25it/s]
{"Epoch": 0, "training loss": 8747.507996, "reconstruction loss": 8210.095212, "KLD loss": 0.641461, "Learning rate": [0.001]}
Epoch 1
Train: 100%|██████████| 838/838 [02:14<00:00,  6.23it/s]
{"Epoch": 1, "training loss": 8087.458632, "reconstruction loss": 7995.213994, "KLD loss": 0.39451, "Learning rate": [0.001]}
Epoch 2
Train: 100%|██████████| 838/838 [02:14<00:00,  6.23it/s]
{"Epoch": 2, "training loss": 7923.292402, "reconstruction loss": 7856.261709, "KLD loss": 0.567289, "Learning rate": [0.001]}
Epoch 3
Train: 100%|██████████| 838/838 [02:14<00:00,  6.24it/s]
{"Epoch": 3, "training loss": 7791.024082, "reconstruction loss": 7731.145271, "KLD loss": 0.762325, "Learning rate": [0.001]}
Epoch 4
Train: 100%|██████████| 838/838 [02:14<00:00,  6.24it/s]
{"Epoch": 4, "training loss": 7578.101603, "reconstruction loss": 7532.429114, "KLD loss": 0.9652

In [None]:
from tqdm import tqdm

validataion_dataloader_args = dict(batch_size=128,
                                   num_workers=0 if is_windows else 4) if has_cuda else dict(batch_size=64)
validataion_dataloader_args["shuffle"] = False

# NOTE(review): this iterates `dataset` as currently configured (training_mode and the
# random augmentations are left untouched) -- confirm that is the intended input.
validataion_dataloader = data.DataLoader(dataset, **validataion_dataloader_args)

# set model in evaluation mode
model.eval()

latent_mu = list()
latent_log_var = list()

# No gradients are needed to collect latent codes; no_grad avoids building the
# autograd graph for every batch.
with torch.no_grad():
  for x in tqdm(validataion_dataloader, desc="Validate"):
    x = x.to(device)

    # `dataset` yields single images, so a batch is [N x C x H x W], while
    # model.encode unpacks [N x Window x C x H x W]; treat each image as a window of one.
    if x.dim() == 4:
      x = x.unsqueeze(1)

    _, mus, log_vars = model.encode(x)
    latent_mu.append(mus.detach().cpu())
    latent_log_var.append(log_vars.detach().cpu())

In [None]:
# Stack the per-batch results row-wise and store them as one (mu, log_var) tuple in a
# file named after the model; the "latent_vectors" directory must already exist.
torch.save((torch.vstack(latent_mu), torch.vstack(latent_log_var)), f"latent_vectors/{model_id}")

In [None]:
# L2 distance between every pair of consecutive dataset items, computed on flattened images.
L2_divergence_raw = []

image_1 = dataset[0].to(device)

for i in tqdm(range(1, len(dataset)), desc="L2"):
  image_2 = dataset[i].to(device)

  distance = torch.linalg.norm((image_1 - image_2).flatten(), 2)
  L2_divergence_raw.append(distance.cpu().item())

  # the current image becomes the reference for the next pair
  image_1 = image_2

In [None]:
# Store the consecutive-item L2 distances as a 1-D tensor; the directory
# temp_store/<model_id> must already exist.
torch.save(torch.tensor(L2_divergence_raw), f"temp_store/{model_id}/l2_divergence_raw")

In [None]:
def normalize(X, mn, mx):
  """Return a list with every element x of X rescaled to (x - mn) / (mx - mn)."""
  scaled = []
  for x in X:
    scaled.append((x - mn) / (mx - mn))
  return scaled