<a href="https://colab.research.google.com/github/ZahraFayyaz/3dshape-vqvae-pyTorch/blob/main/3dshape_vqvae.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Import modules

In [2]:
import sys, os, yaml

# connect to google drive
from google.colab import drive
drive.mount('/content/drive')

# Enter the foldername where all modules are stored together
FOLDERNAME = 'INI - Generative Episodic Memory/'
assert FOLDERNAME is not None, "[!] Enter the foldername."

# Change the working/current directory
sys.path.append('/content/drive/My Drive/{}'.format(FOLDERNAME))
%cd /content/drive/My\ Drive/$FOLDERNAME/

Mounted at /content/drive
/content/drive/My Drive/INI - Generative Episodic Memory


In [3]:
import torch
import torch.nn as nn
from data_loader import load_data
from vqvae import get_model
# import argparse
import random
# import shutil
import cv2
import torchvision
import numpy as np
# from tqdm import tqdm
# from vqvae import get_model
from torch.optim import Adam
from torchvision.utils import make_grid
from torch.optim.lr_scheduler import ReduceLROnPlateau

## Modify PyTorch Configuration

In [4]:
# Enable CuDNN benchmark mode for faster runtime optimizations
torch.backends.cudnn.benchmark = True

# Ensure deterministic behavior in CuDNN operations
torch.backends.cudnn.deterministic = True

# Optionally disable CuDNN's tensor core usage if encountering issues
torch.backends.cudnn.enabled = True

# Load Data

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def train_for_one_epoch(epoch_idx, model, data_loader, optimizer, crtierion, config):
    r"""
    Method to run the training for one epoch.
    :param epoch_idx: iteration number of current epoch
    :param model: VQVAE model
    :param data_loader: Data loder for mnist
    :param optimizer: optimzier to be used taken from config
    :param crtierion: For computing the loss
    :param config: configuration for the current run
    :return:
    """
    recon_losses = []
    codebook_losses = []
    commitment_losses = []
    losses = []
    # We ignore the label for VQVAE
    count = 0
    for im, _ in data_loader:
        im = im.float().to(device)
        optimizer.zero_grad()
        model_output = model(im)
        output = model_output['generated_image']
        quantize_losses = model_output['quantized_losses']
        # z_q = model_output['quantized_output']
        # indices = model_output['quantized_indices']

        # if config['train_params']['save_training_image']:
        #     cv2.imwrite('input.jpeg', (255 * (im.detach() + 1) / 2).cpu().permute(0, 1, 2, 3).numpy().astype(np.uint8)) #(255 * (im.detach() + 1) / 2).cpu().permute((0, 2, 3, 1)).numpy()[0]
        #     cv2.imwrite('output.jpeg', (255 * (output.detach() + 1) / 2).cpu().permute(0, 1, 2, 3).numpy().astype(np.uint8)) #(255 * (output.detach() + 1) / 2).cpu().permute((0, 2, 3, 1)).numpy()[0]

        recon_loss = crtierion(output, im)
        loss = (config['train_params']['reconstruction_loss_weight']*recon_loss +
                config['train_params']['codebook_loss_weight']*quantize_losses['codebook_loss'] +
                config['train_params']['commitment_loss_weight']*quantize_losses['commitment_loss'])
        recon_losses.append(recon_loss.item())
        codebook_losses.append(config['train_params']['codebook_loss_weight']*quantize_losses['codebook_loss'].item())
        commitment_losses.append(quantize_losses['commitment_loss'].item())
        losses.append(loss.item())
        loss.backward()
        optimizer.step()
    print('Finished epoch: {} | Recon Loss : {:.4f} | Codebook Loss : {:.4f} | Commitment Loss : {:.4f}'.
          format(epoch_idx + 1,
                 np.mean(recon_losses),
                 np.mean(codebook_losses),
                 np.mean(commitment_losses)))
    return np.mean(losses)


def train(config_path, sample=None):
    ######## Read the config file #######
    with open(config_path, 'r') as file:
        try:
            config = yaml.safe_load(file)
        except yaml.YAMLError as exc:
    #         print(exc)
    # print(config)
    #######################################

    ######## Set the desired seed value #######
    seed = config['train_params']['seed']
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if device == 'cuda':
        torch.cuda.manual_seed_all(seed)
        # print(args.seed)
    #######################################

    # Create the model and dataset
    model = get_model(config).to(device)
    data_loader = load_data(config['train_params']['path'],sample_data=sample)
    num_epochs = config['train_params']['epochs']
    optimizer = Adam(model.parameters(), lr=config['train_params']['lr'])
    scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=1)
    criterion = {
        'l1': torch.nn.L1Loss(),
        'l2': torch.nn.MSELoss()
    }.get(config['train_params']['crit'])

    # Create output directories
    if not os.path.exists(config['train_params']['task_name']):
        os.mkdir(config['train_params']['task_name'])
    if not os.path.exists(os.path.join(config['train_params']['task_name'],
                                       config['train_params']['output_train_dir'])):
        os.mkdir(os.path.join(config['train_params']['task_name'],
                              config['train_params']['output_train_dir']))

    # Load checkpoint if found
    if os.path.exists(os.path.join(config['train_params']['task_name'],
                                                        config['train_params']['ckpt_name'])):
        print('Loading checkpoint')
        model.load_state_dict(torch.load(os.path.join(config['train_params']['task_name'],
                                                      config['train_params']['ckpt_name']), map_location=device))
    best_loss = np.inf

    for epoch_idx in range(num_epochs):
        mean_loss = train_for_one_epoch(epoch_idx, model, data_loader, optimizer, criterion, config)
        scheduler.step(mean_loss)
        # Simply update checkpoint if found better version
        if mean_loss < best_loss:
            print('Improved Loss to {:.4f} .... Saving Model'.format(mean_loss))
            torch.save(model.state_dict(), os.path.join(config['train_params']['task_name'],
                                                        config['train_params']['ckpt_name']))
            best_loss = mean_loss
        else:
            print('No Loss Improvement')

    return model


trained_model = train(config_path='hyperparameters.yaml')
# kf = KFold(n_splits=5, shuffle=True, random_state=seed)
# for fold, (train_idx, val_idx) in enumerate(kf.split(data)):
#     print(f"Fold {fold + 1}")

#     # Create data samplers
#     train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
#     val_sampler = torch.utils.data.SubsetRandomSampler(val_idx)

#     # Create data loaders
#     train_loader = load_data(config['train_params']['path'],sample_data=train_sampler)
#     val_loader = load_data(config['train_params']['path'],sample_data=val_sampler)

{'model_params': {'in_channels': 3, 'convbn_blocks': 4, 'conv_kernel_size': [3, 3, 3, 2], 'conv_kernel_strides': [2, 2, 1, 1], 'convbn_channels': [3, 6, 12, 24, 72], 'conv_activation_fn': 'leaky', 'transpose_bn_blocks': 4, 'transposebn_channels': [72, 24, 12, 6, 3], 'transpose_kernel_size': [3, 3, 3, 2], 'transpose_kernel_strides': [2, 2, 1, 1], 'transpose_activation_fn': 'leaky', 'latent_dim': 72, 'codebook_size': 10}, 'train_params': {'task_name': 'vqvae_latent_72_codebook_10_nnLayers_4', 'batch_size': 72, 'epochs': 10, 'lr': 0.005, 'crit': 'l2', 'reconstruction_loss_weight': 1, 'codebook_loss_weight': 1, 'commitment_loss_weight': 0.2, 'ckpt_name': 'best_vqvae_latent_72_codebook_10.pth', 'seed': 42, 'save_training_image': True, 'path': '/content/drive/MyDrive/Data/3dshapes.h5', 'output_train_dir': 'output'}}
{'model_params': {'in_channels': 3, 'convbn_blocks': 4, 'conv_kernel_size': [3, 3, 3, 2], 'conv_kernel_strides': [2, 2, 1, 1], 'convbn_channels': [3, 6, 12, 24, 72], 'conv_activa



In [2]:
trained_model

NameError: name 'trained_model' is not defined