In [1]:
# Automatically reload edited Python modules (e.g. the mawm package) before
# executing each cell, so library changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

# Main VAE

> Main notebook for training the VAE model.

In [2]:
from fastcore import *
from fastcore.utils import *
import torch

In [3]:

import argparse
import os
import time
from os import mkdir
from os.path import exists, join

import torch
import torch.utils.data
from dotenv import load_dotenv
from omegaconf import OmegaConf
from torch import optim
from torch.nn import functional as F
from torchvision import transforms


In [4]:

from mawm.core import get_cls

from mawm.data.utils import transform_train, transform_test
from mawm.data.loaders import RolloutObservationDataset

from mawm.optimizer.utils import ReduceLROnPlateau, EarlyStopping
from mawm.trainers.vae_trainer import VAETrainer
from mawm.writers.wandb_writer import WandbWriter

In [5]:
# Default VAE configuration for interactive runs (path relative to the working
# directory). When run as a script, the config is reloaded from --config.
DEFAULT_CFG_PATH = join("../cfgs", "vae", "cfg.yaml")
cfg = OmegaConf.load(DEFAULT_CFG_PATH)

In [6]:
# Show the loaded configuration.
cfg

{'project_name': 'vae_meltingpot', 'epochs': 1000, 'loss_fn': 'CrossEntropyLoss', 'writer': 'WandbWriter', 'noreload': False, 'root_dir': '/scratch/project_2009050/', 'save_dir': 'models', 'res_dir': 'results', 'log_dir': '../logs', 'state_dir': 'aggregated_model_', 'data': {'data_dir': '/scratch/project_2009050/datasets/meltingpot_data', 'batch_size': 32, 'name': 'meltingpot'}, 'model': {'name': 'VAE', 'channels': 3, 'img_size': 40, 'latent_size': 512, 'grad_norm_clip': 1.0}, 'optimizer': {'name': 'Adam', 'lr': 0.001}}

In [7]:
# Smoke-test overrides for quick interactive runs: 2 epochs, batch size 1,
# a small latent space, and a local data directory.
# NOTE(review): these mutate the config loaded above; remove or guard them
# before a real training run from the notebook.
cfg.epochs = 2
cfg.data.data_dir = "../meltingpot_data/"
cfg.data.batch_size = 1
cfg.model.latent_size = 32

In [None]:

# Command-line interface, used when this notebook runs as a script
# (--config and --timestamp are required, so this cell is not meant to be
# executed inside Jupyter).
# Tunable flags default to None: when a flag is omitted, the value from the
# YAML config is kept by the override cell further down. `--lr` and
# `--optimizer` are declared here because that cell reads them.
parser = argparse.ArgumentParser(description='VAE Training')
parser.add_argument('--config', type=str, help='Path to the YAML config file', required=True)
parser.add_argument('--timestamp', type=str, help='Time stamp', required=True)
parser.add_argument('--env_file', type=str, help='Path to the .env file', required=False)


parser.add_argument('--batch-size', type=int, default=None, metavar='N',
                    help='input batch size for training (default: value from config)')
parser.add_argument('--epochs', type=int, default=None, metavar='N',
                    help='number of epochs to train (default: value from config)')
parser.add_argument('--lr', type=float, default=None,
                    help='learning rate (default: value from config)')
parser.add_argument('--optimizer', type=str, default=None,
                    help='name of a torch.optim optimizer class, e.g. Adam (default: value from config)')
parser.add_argument('--log_dir', type=str, help='Directory where results are logged')
parser.add_argument('--noreload', action='store_true',
                    help='Best model is not reloaded if specified')
parser.add_argument('--nosamples', action='store_true',
                    help='Does not save samples during training if specified')


args = parser.parse_args()



In [None]:

# Load secrets (e.g. WANDB_API_KEY, HF_SECRET_CODE) from an explicit .env file
# when one is given. load_dotenv writes the variables into os.environ itself,
# so no extra copying is needed afterwards.
if args.env_file:
    load_dotenv(args.env_file)

# --config is required, so a config that cannot be loaded is fatal. The old
# bare `except:` printed a message and kept going with whatever `cfg` was
# defined earlier (e.g. smoke-test overrides). It also hid YAML errors and
# KeyboardInterrupt.
try:
    cfg = OmegaConf.load(args.config)
except Exception as err:
    raise SystemExit(f"Invalid config file path: {args.config} ({err})") from err

In [9]:

# Pick up secrets from a .env file found by python-dotenv's default search.
# Variables already present in the environment are not overridden by default.
load_dotenv()
key, hf_secret = (os.getenv(var, None) for var in ("WANDB_API_KEY", "HF_SECRET_CODE"))

# Re-export every secret that is set.
for var, value in (("WANDB_API_KEY", key), ("HF_SECRET_CODE", hf_secret)):
    if value:
        os.environ[var] = value



In [None]:

# Command-line values take precedence over the YAML config. A flag that was
# not given (None) leaves the config value untouched. getattr() guards against
# flags the parser may not declare, which previously raised AttributeError
# for `lr` / `optimizer`.
cfg.now = args.timestamp

lr = getattr(args, "lr", None)
if lr is not None:
    cfg.optimizer.lr = float(lr)

batch_size = getattr(args, "batch_size", None)
if batch_size is not None:
    cfg.data.batch_size = int(batch_size)

optimizer_name = getattr(args, "optimizer", None)
if optimizer_name:
    cfg.optimizer.name = optimizer_name

# These flags were parsed but never applied before.
epochs = getattr(args, "epochs", None)
if epochs is not None:
    cfg.epochs = int(epochs)

log_dir = getattr(args, "log_dir", None)
if log_dir:
    cfg.log_dir = log_dir

if getattr(args, "noreload", False):
    cfg.noreload = True


In [10]:

# Fixed seed so weight init and data shuffling are repeatable.
SEED = 123
torch.manual_seed(SEED)

cuda = torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")

# Let cuDNN benchmark and cache the fastest convolution algorithms for the
# input shapes it sees. This speeds up training with fixed-size images, but it
# may make GPU results non-deterministic between runs.
torch.backends.cudnn.benchmark = True


In [11]:
# Check which data directory the datasets below will read from.
cfg.data.data_dir

'../meltingpot_data/'

In [12]:

# Train/validation datasets share one data directory. They are split by the
# `train` flag and use train-time vs. test-time transforms respectively.
dataset_train, dataset_test = (
    RolloutObservationDataset(cfg.data.data_dir, tfm, train=is_train)
    for tfm, is_train in ((transform_train, True), (transform_test, False))
)

# Both loaders use identical settings (shuffled, 2 worker processes).
loader_kwargs = dict(batch_size=cfg.data.batch_size, shuffle=True, num_workers=2)
train_loader = torch.utils.data.DataLoader(dataset_train, **loader_kwargs)
val_loader = torch.utils.data.DataLoader(dataset_test, **loader_kwargs)

# Resolve the model class from its configured name (e.g. "VAE" is looked up
# in module MAWM.models.vae), then build it and move it to the device.
model_name = cfg.model.name
model_cls = get_cls(f"MAWM.models.{model_name.lower()}", model_name)
model = model_cls(cfg.model.channels, cfg.model.latent_size)
model = model.to(device)


Loading file buffer ...: 100%|██████████| 200/200 
Loading file buffer ...: 100%|██████████| 200/200 


In [None]:
#| hide
# Quick look at the architecture, built with hard-coded arguments (3 channels,
# latent size 32). The preview uses its own name so it does not replace
# `model`, the config-built network on `device` that the optimizer and trainer
# use. The module path is built without nested double quotes inside the
# f-string, which is a SyntaxError before Python 3.12.
preview_name = "VAE"
preview_cls = get_cls(f"MAWM.models.{preview_name.lower()}", preview_name)
vae_preview = preview_cls(3, 32)
vae_preview

VAE(
  (encoder): Encoder(
    (conv1): Conv2d(3, 32, kernel_size=(4, 4), stride=(2, 2))
    (conv2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
    (conv3): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2))
    (conv4): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2))
    (fc_mu): Linear(in_features=1024, out_features=128, bias=True)
    (fc_logsigma): Linear(in_features=1024, out_features=128, bias=True)
  )
  (decoder): Decoder(
    (fc1): Linear(in_features=128, out_features=1024, bias=True)
    (deconv1): ConvTranspose2d(1024, 128, kernel_size=(5, 5), stride=(2, 2))
    (deconv2): ConvTranspose2d(128, 64, kernel_size=(5, 5), stride=(2, 2))
    (deconv3): ConvTranspose2d(64, 32, kernel_size=(6, 6), stride=(2, 2))
    (deconv4): ConvTranspose2d(32, 3, kernel_size=(6, 6), stride=(2, 2))
  )
)

In [11]:
# Confirm the latent size in effect after the overrides above.
cfg.model.latent_size

32

In [13]:
# export
# The optimizer class is looked up by name in torch.optim (e.g. "Adam").
optimizer_cls = get_cls("torch.optim", cfg.optimizer.name)
optimizer = optimizer_cls(model.parameters(), lr=cfg.optimizer.lr)

# Plateau LR decay and early stopping, both in 'min' mode (the tracked metric
# should decrease). NOTE(review): these are mawm's own ReduceLROnPlateau /
# EarlyStopping; patience is presumably counted in scheduler steps (epochs) —
# confirm against mawm.optimizer.utils.
LR_DECAY_FACTOR, LR_PATIENCE, EARLY_STOP_PATIENCE = 0.5, 5, 30
scheduler = ReduceLROnPlateau(optimizer, 'min', factor=LR_DECAY_FACTOR, patience=LR_PATIENCE)
earlystopping = EarlyStopping('min', patience=EARLY_STOP_PATIENCE)


In [14]:

def criterion(recon_x, x, mu, logsigma):
    """VAE loss: summed reconstruction error plus KL divergence.

    Args:
        recon_x: reconstruction produced by the decoder; must match ``x`` in shape.
        x: target images.
        mu: mean of the diagonal Gaussian posterior.
        logsigma: log standard deviation of that posterior
            (so ``sigma**2 == exp(2 * logsigma)``).

    Returns:
        Scalar tensor: sum-reduced MSE reconstruction loss plus
        KL(N(mu, sigma^2) || N(0, I)). Both terms are summed over all
        elements, not averaged.
    """
    # Summed squared error. `reduction='sum'` replaces the deprecated
    # `size_average=False`. This term is MSE, not BCE.
    recon_loss = F.mse_loss(recon_x, x, reduction='sum')

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # -KL = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2), with log(sigma^2) = 2 * logsigma
    kld = -0.5 * torch.sum(1 + 2 * logsigma - mu.pow(2) - (2 * logsigma).exp())
    return recon_loss + kld

In [15]:
# Log directory for interactive runs; same value as `log_dir` in the config shown above.
cfg.log_dir = "../logs"

In [17]:
# Inspect the model section of the config as it will be used for training.
cfg.model

{'name': 'VAE', 'channels': 3, 'img_size': 40, 'latent_size': 32, 'grad_norm_clip': 1.0}

In [16]:
# Train the VAE and collect the results frame.
# In script runs, the required --timestamp flag (parsed into `args`) names the
# run. Previously it was silently overwritten here. Interactive runs have no
# parsed `args`, so they get a fresh local timestamp.
cli_args = globals().get("args")
now = getattr(cli_args, "timestamp", None) or time.strftime("%Y%m%d-%H%M%S")
cfg.now = now
writer = WandbWriter(cfg)
trainer = VAETrainer(cfg, model, train_loader, val_loader, criterion, 
                     optimizer, device, dataset_train, dataset_test,
                     earlystopping, scheduler, writer)

df_res = trainer.fit()


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/ahmed/.netrc
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Network error (ConnectionError), entering retry loop.
[34m[1mwandb[0m: Network error (ConnectionError), entering retry loop.


Loading file buffer ...: 100%|██████████| 200/200 




KeyboardInterrupt: 