# Train Variational Autoencoder

In [1]:
from argparse import Namespace

import h5py

from PIL import Image

import torchvision

import torch
import torch.utils
import torch.utils.data

# from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
# from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim.lr_scheduler import ExponentialLR

from torchvision import transforms

import wandb

import pandas as pd
import numpy as np

In [2]:
from accelerate import Accelerator
from accelerate.utils import GradientAccumulationPlugin
from accelerate.utils import set_seed

In [3]:
from accelerate.utils import write_basic_config

write_basic_config()

Configuration already exists at /Users/pavankantharaju/.cache/huggingface/accelerate/default_config.yaml, will not override. Run `accelerate config` manually or pass a different `save_location`.


False

### Create Config

In [22]:
# DEVICE = torch.device(
#     'cuda' if torch.cuda.is_available() \
#         else 'mps' if torch.backends.mps.is_available() else 'cpu')
DEVICE = 'cpu'

HIDDEN_DIMS = 64

CONFIG = Namespace(
    project_name="3dshapes",
    run_name='3dshapes-run-1',
    model_name=f'3dshapes-{HIDDEN_DIMS}-model-v1',
    hidden_dims=HIDDEN_DIMS,
    horizontal_flip_prob=0.5,
    gaussian_blur_kernel_size=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=15,
    learning_rate=4e-4,
    seed=1,
    beta_schedule='squaredcos_cap_v2',
    lr_exp_schedule_gamma=0.85,
    lr_warmup_steps=500,
    train_limit=-1,
    save_model=True,
    mixed_precision=None,
    grad_accumulation_steps=4
    )
CONFIG.device = DEVICE

### Create Dataset

In [62]:
class ShapeDataset(torch.utils.data.Dataset):

    def __init__(self, h5_data: h5py.File) -> None:
        super().__init__()

        self.h5_data = h5_data
        self.image_shape = self.h5_data['images'][0].shape
        self.num_labels = self.h5_data['labels'][0].shape[0]
        self.normalize_transform = transforms.Normalize([0.5], [0.5]) # Map to (-1, 1)

    def __len__(self):
        return len(self.h5_data['images'])

    def __getitem__(self, index):
        img_array = self.h5_data['images'][index]
        labels = self.h5_data['labels'][index]

        img_tensor = torch.tensor(img_array, dtype=torch.float32)

        img_tensor = img_tensor.transpose(0, 2)
        img_tensor = img_tensor/255
        # print(f"Image tensor before: {img_tensor}")
        img_tensor = self.normalize_transform(img_tensor)
        # print(f"Image tensor after: {img_tensor}")
        # raise
        labels_tensor = torch.tensor(labels)

        # print(f"Image tensor shape: {img_tensor.shape}")
        # print(f"Labels tensor shape: {labels_tensor.shape}")

        output = {
            'image': img_tensor,
            'labels': labels_tensor
            }

        return output
    
def create_dataset():
    """
    Create dataset
    """

    data = h5py.File('3dshapes.h5', 'r')
    shape_dataset = ShapeDataset(data)
    return shape_dataset

def prepare_dataloader(config: Namespace):
    """
    Prepare dataloader
    """

    shape_dataset = create_dataset()

    generator = torch.Generator().manual_seed(config.seed)
    train_dataset, val_dataset = torch.utils.data.random_split(shape_dataset, [0.8, 0.2], generator)

    train_gen = torch.Generator().manual_seed(config.seed)
    val_gen = torch.Generator().manual_seed(config.seed)

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=config.per_device_train_batch_size,
        shuffle=True, generator=train_gen)

    val_dataloader = torch.utils.data.DataLoader(
        val_dataset, batch_size=config.per_device_eval_batch_size,
        shuffle=True, generator=val_gen)

    return train_dataloader, val_dataloader, shape_dataset

### Create Model

In [63]:
import torch.nn.functional as F

class ShapeModelEncoder(torch.nn.Module):

    def __init__(self, in_channels: int, dims: int) -> None:
        super().__init__()

        self.conv_1 = torch.nn.Conv2d(
            in_channels, dims, kernel_size=3, padding='same')
        self.max_pool_1 = torch.nn.MaxPool2d(kernel_size=2)

        self.conv_2 = torch.nn.Conv2d(
            dims, 2*dims, kernel_size=3, padding='same')
        self.max_pool_2 = torch.nn.MaxPool2d(kernel_size=2)

        self.conv_3 = torch.nn.Conv2d(
            2*dims, 2*dims, kernel_size=3, padding='same')
        self.max_pool_3 = torch.nn.MaxPool2d(kernel_size=2)

    def forward(self, x: torch.Tensor):
        """
        Forward pass
        """

        x_ = self.conv_1(x)
        # print(f"Output of conv 1: {x_.shape}")
        x_ = self.max_pool_1(x_)
        # print(f"Output of conv & max pool 1: {x_.shape}")

        x_ = self.conv_2(x_)
        # print(f"Output of conv 2: {x_.shape}")
        x_ = self.max_pool_2(x_)
        # print(f"Output of conv & max pool 2: {x_.shape}")

        x_ = self.conv_3(x_)
        # print(f"Output of conv 3: {x_.shape}")
        x_ = self.max_pool_3(x_)
        # print(f"Output of conv & max pool 3: {x_.shape}")

        return x_

class ShapeModelDecoder(torch.nn.Module):

    def __init__(self, num_factors: int, dims: int, in_channels: int) -> None:
        super().__init__()

        # NOTE: I still don't understand how transpose convolution works
        # self.upsample_3 = torch.nn.Upsample(scale_factor=2)
        self.deconv_1 = torch.nn.ConvTranspose2d(num_factors, 2*dims, 2, stride=2)

        # self.upsample_2 = torch.nn.Upsample(scale_factor=2)
        self.deconv_2 = torch.nn.ConvTranspose2d(2*dims, dims, 2, stride=2)

        # self.upsample_1 = torch.nn.Upsample(scale_factor=2)
        self.deconv_3 = torch.nn.ConvTranspose2d(dims, in_channels, 2, stride=2)

    def forward(self, x: torch.Tensor):
        """
        Forward pass
        """

        # x_ = self.upsample_3(x_)
        # print(f"Output of upsample 3: {x_.shape}")
        x_ = self.deconv_1(x)
        # print(f"Output of deconv 1: {x_.shape}")
  
        # x_ = self.upsample_2(x_)
        # print(f"Output of upsample 2: {x_.shape}")
        x_ = self.deconv_2(x_)
        # print(f"Output of  deconv 2: {x_.shape}")

        # x_ = self.upsample_1(x_)
        # print(f"Output of upsample 1: {x_.shape}")
        x_ = self.deconv_3(x_)
        # print(f"Output of deconv 3: {x_.shape}")

        return x_

class ShapeModel(torch.nn.Module):

    def __init__(self, in_channels: int, dims: int, num_factors: int):
        super().__init__()

        self.encoder = ShapeModelEncoder(in_channels, dims)

        self.mean = torch.nn.Linear(2*dims*8*8, num_factors)
        self.log_var = torch.nn.Linear(2*dims*8*8, num_factors)

        self.rng_state = torch.Generator()

        self.upsampler = torch.nn.Upsample(scale_factor=8)

        self.decoder = ShapeModelDecoder(num_factors, dims, in_channels)

    def reparameterize(self, mean: torch.Tensor, log_var: torch.Tensor):
        """
        """

        std_val = torch.exp(log_var/2)

        # Sample noise from gaussian
        cov_matrix = torch.eye(mean.shape[-1])
        noise = torch.normal(0.0, cov_matrix)

        return mean + torch.matmul(std_val, noise)

    def forward(self, x: torch.Tensor):
        """
        Forward pass
        """

        # print(f"Input shape: {x.shape}")

        encoder_output = F.relu(self.encoder(x))
        # print(f"Encoder output shape: {encoder_output.shape}")

        flattened_output = torch.flatten(encoder_output, start_dim=1)
        # print(f"Flattened output: {flattened_output.shape}")

        mean = F.relu(self.mean(flattened_output))
        log_var = F.relu(self.log_var(flattened_output))

        latent = self.reparameterize(mean, log_var)
        # print(f"Latent shape: {latent.shape}")

        decoder_input = latent.unsqueeze(-1).unsqueeze(-1)
        # print(f"Decoder input: {decoder_input.shape}")

        decoder_input = self.upsampler(decoder_input)
        # print(f"Decoder input: {decoder_input.shape}")

        decoder_output = F.relu(self.decoder(decoder_input))
        # print(f"Decoder output shape: {decoder_output.shape}")

        return decoder_output, mean, log_var

    # def sample(self):
    #     """
    #     Sample an image
    #     """

    #     decoder_output = self.decoder(encoder_output)
    #     return decoder_output

def create_model(in_dimensions: int, dims: int, num_factors: int):
    """
    Create model
    """

    model = ShapeModel(in_dimensions, dims, num_factors)
    return model

In [65]:
def compute_loss(input: torch.Tensor, output: torch.Tensor, mean: torch.Tensor, log_var: torch.Tensor):
    """
    Compute VAE Loss
    """

    # Sum over each subset & average over each batch
    mse_loss_fn = torch.nn.MSELoss(reduction='mean')

    mse_loss = mse_loss_fn(output, input)
    kl_loss = - 0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())

    loss = mse_loss + 0.002 * kl_loss

    return loss, mse_loss, kl_loss

@torch.no_grad()
def eval_loop(epoch: int, model, dataloader,
              wandb_run, accelerator: Accelerator):
    """
    Evaluation loop
    """

    dataframe = []

    avg_total_loss = 0
    avg_mse_loss = 0
    avg_kl_loss = 0

    for _, batch in enumerate(dataloader):

        print(batch['image'])
        raise
        pred_image, mean, log_var = model(batch['image'])
        # labels = batch['labels']
        loss, mse_loss, kl_loss = compute_loss(batch['image'], pred_image, mean, log_var)

        avg_total_loss += loss.item()
        avg_mse_loss += mse_loss.item()
        avg_kl_loss += kl_loss.item()

        # TODO: Add FID Score

        images = []
        pred_images = []

        for j in range(batch['image'].shape[0]):
            images.append(batch['image'][j, :])
            pred_images.append(pred_image[j, :])
    
        batch_dataframe = pd.DataFrame()
        batch_dataframe['epoch'] = [epoch for _ in range(len(images))]
        batch_dataframe['image'] = \
            [wandb.Image(image) for image in images]
        batch_dataframe['pred_image'] = \
            [wandb.Image(image) for image in pred_images]

        dataframe.append(batch_dataframe)

    dataframe = pd.concat(dataframe, axis=0, ignore_index=True)

    avg_total_loss = avg_total_loss/len(dataloader)
    avg_mse_loss = avg_mse_loss/len(dataloader)
    avg_kl_loss = avg_kl_loss/len(dataloader)

    metrics_str = f"Val total loss: {avg_total_loss} - MSE loss: {avg_mse_loss} - KL loss: {avg_kl_loss}"
    if accelerator:
        accelerator.print(metrics_str)
    else:
        print(metrics_str)

    table = wandb.Table(data=dataframe)
    # wandb_run.log({'accuracy': acc}, commit=False)
    if wandb_run:
        wandb_run.log({'val-total-loss': avg_total_loss}, commit=False)
        wandb_run.log({'val-mse-loss': avg_mse_loss}, commit=False)
        wandb_run.log({'val-kl-loss': avg_kl_loss}, commit=False)

        wandb_run.log({'eval-table': table})

def training_loop(config: Namespace, debug_mode=False):
    """
    Training loop
    """

    wandb_run = None
    if not debug_mode:
        wandb_run = wandb.init(project=config.project_name, entity=None,
                            job_type='training',
                            name=config.run_name,
                            config=config)

    accelerator = None
    if not debug_mode:
        set_seed(config.seed)

        grad_accumulation_plugin = GradientAccumulationPlugin(
            num_steps=config.grad_accumulation_steps,
            adjust_scheduler=True,
            sync_with_dataloader=True)

        accelerator = Accelerator(
            mixed_precision=config.mixed_precision,
            gradient_accumulation_plugin=grad_accumulation_plugin,
            cpu=(config.device == 'cpu'))

    train_dataloader, val_dataloader, dataset = prepare_dataloader(config)
    model = create_model(dataset.image_shape[-1], config.hidden_dims, dataset.num_labels)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)

#     scheduler = CosineAnnealingLR(
#         optimizer,
#         T_max=config.num_train_epochs)
    scheduler = ExponentialLR(
        optimizer,
        config.lr_exp_schedule_gamma)

#     scheduler = CosineAnnealingWarmRestarts(
#         optimizer,
#         T_0=config.lr_warmup_steps)
        # last_epoch=config.num_train_epochs*len(train_dataloader))

    if accelerator:
        model, optimizer, train_dataloader, val_dataloader, scheduler = accelerator.prepare(
            model, optimizer, train_dataloader, val_dataloader, scheduler)

    num_steps = 0
    for epoch in range(config.num_train_epochs):
        model.train()

        epoch_str = f"---------------- Epoch {epoch} ----------------"
        if accelerator:
            accelerator.print(epoch_str)
        else:
            print(epoch_str)

        epoch_total_loss = 0
        epoch_mse_loss = 0
        epoch_kl_loss = 0

        num_iters = 0

        for _, batch in enumerate(train_dataloader):
            # with accelerator.accumulate(model):
    
            optimizer.zero_grad()
            pred_image, mean, log_var = model(batch['image'])
            # labels = batch['labels']
            # print(mean, log_var)

            loss, mse_loss, kl_loss = compute_loss(batch['image'], pred_image, mean, log_var)

            # accelerator.print(f"Loss: {loss.item()}")
            if accelerator:
                accelerator.backward(loss)
                accelerator.clip_grad_norm_(model.parameters(), 1.0)
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            epoch_total_loss += loss.item()
            epoch_mse_loss += mse_loss.item()
            epoch_kl_loss += kl_loss.item()

            if wandb_run:
                wandb_run.log({'loss': loss.item()}, commit=False, step=num_steps)
                wandb_run.log({'mse-loss': mse_loss.item()}, commit=False, step=num_steps)
                wandb_run.log({'kl-loss': kl_loss.item()}, commit=False, step=num_steps)

                wandb_run.log({'lr': scheduler.get_lr()[0]}, commit=False, step=num_steps)
            else:
                print(f"Batch: {loss.item()}, {mse_loss.item()}, {kl_loss.item()}, {scheduler.get_lr()[0]}")

            num_steps += 1
            num_iters += 1

            # Update the model parameters with the optimizer
            optimizer.step()
        scheduler.step()

        # Validate model
        eval_model_str = "Evaluating model"
        if accelerator:
            accelerator.print(eval_model_str)
        else:
            print(eval_model_str)

        eval_loop(epoch, model, val_dataloader, wandb_run, accelerator)

        if wandb_run:
            wandb_run.log({'epoch-total-loss': epoch_total_loss/num_iters})
            wandb_run.log({'epoch-mse-loss': epoch_mse_loss/num_iters})
            wandb_run.log({'epoch-kl-loss': epoch_kl_loss/num_iters})
        else:
            print(f"Epoch: {epoch_total_loss}, {epoch_mse_loss}, {epoch_kl_loss}")

    if config.save_model:
        # Save model to W&Bs
        model_art = wandb.Artifact(config.model_name, type='model')
        torch.save(model.state_dict(), 'model.pt')

        model_art.add_file('model.pt')
        wandb_run.log_artifact(model_art)
    wandb_run.finish()

### Train model

In [66]:
training_loop(CONFIG, debug_mode=True)

---------------- Epoch 0 ----------------
Batch: 0.6280412077903748, 0.6276530623435974, 0.194059818983078, 0.0004
Batch: 0.6822159290313721, 0.6822144985198975, 0.000708162784576416, 0.0004
Batch: 0.6380550861358643, 0.6380550861358643, -0.0, 0.0004
Batch: 0.6321941614151001, 0.6321941614151001, -0.0, 0.0004
Batch: 0.6048686504364014, 0.6048686504364014, -0.0, 0.0004
Batch: 0.6682799458503723, 0.6682799458503723, -0.0, 0.0004
Batch: 0.6395112872123718, 0.6395108103752136, 0.0002484321594238281, 0.0004
Batch: 0.6079604029655457, 0.6079604029655457, -0.0, 0.0004
Batch: 0.6508994698524475, 0.6508994698524475, -0.0, 0.0004
Batch: 0.6501573920249939, 0.6501573920249939, -0.0, 0.0004
Batch: 0.6152146458625793, 0.6152146458625793, -0.0, 0.0004


KeyboardInterrupt: 

In [None]:
# from accelerate import notebook_launcher

# notebook_launcher(training_loop, (CONFIG, ), num_processes=1)

In [None]:
# # pil_to_tensor = torchvision.transforms.PILToTensor()
# # tensor_to_pil = torchvision.transforms.ToPILImage()
# img_tensor = pil_to_tensor(img).type(torch.float32)
# img_tensor = img_tensor.unsqueeze(0)
# output_img = None
# with torch.no_grad():
#     output = model(img_tensor)
#     output_img = output.squeeze(0).transpose(0, 2)
# tensor_to_pil(output_img.type(torch.uint8).numpy())