# VAE trainer

> This module trains the VAE (the encoder–decoder model with its latent space): it runs the training and validation epochs, saves checkpoints, reloads the best model, and logs the losses.

In [None]:
#| default_exp trainers.vae_trainer

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore import *
from fastcore.utils import *
from torchvision.utils import save_image
import torch
import os
from torch import nn
import pandas as pd

In [None]:
#| export
from MAWM.trainers.trainer import Trainer
from MAWM.models.utils import save_checkpoint

class VAETrainer(Trainer):
    """Trainer specialisation for a VAE.

    Extends the base ``Trainer`` with optional early stopping, a learning-rate
    scheduler and a metrics writer, and prepares the output directory used for
    checkpoints and samples.

    Args:
        cfg: Run configuration; ``cfg.log_dir`` is read here to locate the
            output directory.
        model: The model to train.
        train_loader: Loader providing the training batches.
        val_loader: Optional loader providing the validation batches.
        criterion: Loss callable.
        optimizer: Optimizer updating ``model``.
        device: Device the batches are moved to.
        earlystopping: Optional early-stopping helper.
        scheduler: Optional learning-rate scheduler.
        writer: Optional metrics writer.
    """

    def __init__(self, cfg, model, train_loader, val_loader=None, 
                 criterion=None, optimizer=None, device=None,
                 earlystopping=None, scheduler=None, writer= None):
        
        # The base class is assumed to store `cfg` as `self.cfg` (it is read below).
        super().__init__(cfg, model, train_loader, val_loader, criterion, optimizer, device)
        self.earlystopping = earlystopping
        self.scheduler = scheduler
        self.writer = writer

        self.vae_dir = os.path.join(self.cfg.log_dir, 'vae_marlrid')
        # makedirs creates any missing parents (including log_dir itself) and,
        # with exist_ok, also repairs a run directory that exists but lacks
        # its 'samples' sub-directory.
        os.makedirs(os.path.join(self.vae_dir, 'samples'), exist_ok=True)

    

In [None]:
#| export
@patch
def reload(self: VAETrainer):
    """Restore training state from ``<vae_dir>/best.pth`` if it is available.

    Nothing happens when ``cfg.noreload`` is set or when no best checkpoint
    exists. Otherwise the model and optimizer states are restored, and the
    scheduler / early-stopping states are restored when both the trainer has
    the component and the checkpoint carries a state for it.
    """
    reload_file = os.path.join(self.vae_dir, 'best.pth')
    if self.cfg.noreload or not os.path.exists(reload_file):
        return

    # map_location lets a checkpoint written on one device (e.g. a GPU) be
    # restored on the device this trainer actually runs on.
    state = torch.load(reload_file, map_location=self.device)
    print("Reloading model at epoch {}"
        ", with test error {}".format(
            state['epoch'],
            state['precision']))

    self.model.load_state_dict(state['state_dict'])
    self.optimizer.load_state_dict(state['optimizer'])

    # The scheduler and early-stopping helpers are optional on the trainer,
    # and checkpoints may have been written without their state; only restore
    # what is present on both sides instead of raising.
    scheduler_state = state.get('scheduler')
    if self.scheduler is not None and scheduler_state is not None:
        self.scheduler.load_state_dict(scheduler_state)

    earlystopping_state = state.get('earlystopping')
    if self.earlystopping is not None and earlystopping_state is not None:
        self.earlystopping.load_state_dict(earlystopping_state)



In [None]:
#| export
@patch
def train_epoch(self: VAETrainer, epoch):
    """Run one training pass and return the summed loss divided by the dataset size.

    Samples flagged as done are dropped from every batch before the forward
    pass; a batch consisting only of done samples is skipped entirely.

    Args:
        epoch: Epoch number, used only in the progress messages.

    Returns:
        The accumulated batch losses divided by ``len(train_loader.dataset)``.
    """
    loader = self.train_loader
    self.model.train()
    loader.dataset.load_next_buffer()
    running_loss = 0

    for step, (obs, dones, _agent_id) in enumerate(loader):
        alive = ~dones.bool()  # True for samples that are not terminal
        if not alive.any():
            # nothing left to train on in this batch
            continue

        obs = obs[alive].to(self.device)

        self.optimizer.zero_grad()
        recon, mu, logvar = self.model(obs)
        loss = self.criterion(recon, obs, mu, logvar)
        loss.backward()
        batch_loss = loss.item()
        running_loss += batch_loss
        self.optimizer.step()

        # progress report on every 20th batch only
        if step % 20 != 0:
            continue
        n_kept = len(obs)
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch,
            step * n_kept,
            len(loader.dataset),
            100. * step / len(loader),
            batch_loss / n_kept))

    epoch_avg = running_loss / len(loader.dataset)
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, epoch_avg))
    return epoch_avg
       

In [None]:
#| export
@patch
def eval_epoch(self: VAETrainer):
    """Evaluate the model on the validation loader without gradient tracking.

    Samples flagged as done are dropped from every batch; batches containing
    only done samples are skipped.

    Returns:
        The accumulated batch losses divided by ``len(val_loader.dataset)``.
    """
    loader = self.val_loader
    self.model.eval()
    loader.dataset.load_next_buffer()

    total = 0
    with torch.no_grad():
        for obs, dones, _agent_id in loader:
            keep = ~dones.bool()  # True for samples that are not terminal
            if not keep.any():
                # the whole batch is terminal, nothing to score
                continue

            obs = obs[keep].to(self.device)
            recon, mu, logvar = self.model(obs)
            batch_loss = self.criterion(recon, obs, mu, logvar)
            total += batch_loss.item()

    test_loss = total / len(loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))
    return test_loss


In [None]:
#| export
import wandb


@patch
def fit(self: VAETrainer):
    """Train for up to ``cfg.epochs`` epochs with checkpointing and logging.

    Each epoch runs ``train_epoch`` and ``eval_epoch``, steps the scheduler on
    the validation loss, writes ``checkpoint.pth`` (and ``best.pth`` through
    ``save_checkpoint`` when the validation loss improved), logs both losses,
    and finally consults the early-stopping helper. The scheduler, writer and
    early-stopping helper are optional and are skipped when ``None``.

    Returns:
        A DataFrame with columns ``epoch``, ``train_loss`` and ``test_loss``,
        one row per completed epoch (empty when no epoch ran).
    """
    # The checkpoint paths do not depend on the epoch, so build them once.
    best_filename = os.path.join(self.vae_dir, 'best.pth')
    filename = os.path.join(self.vae_dir, 'checkpoint.pth')

    cur_best = None
    records = []

    for epoch in range(1, self.cfg.epochs + 1):
        train_loss = self.train_epoch(epoch)
        test_loss = self.eval_epoch()
        if self.scheduler is not None:
            self.scheduler.step(test_loss)

        # Compare against None explicitly: a best loss of exactly 0.0 is a
        # valid value and must not be mistaken for "no best yet".
        is_best = cur_best is None or test_loss < cur_best
        if is_best:
            cur_best = test_loss

        # Early-stopping state is not persisted in the checkpoint.
        save_checkpoint({
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'precision': test_loss,
            'optimizer': self.optimizer.state_dict(),
            'scheduler': (self.scheduler.state_dict()
                          if self.scheduler is not None else None),
        }, is_best, filename, best_filename)

        # Record the epoch BEFORE the early-stopping check so the metrics of
        # the epoch that triggers the stop are not lost.
        if self.writer is not None:
            self.writer.write({
                "train_loss": train_loss,
                "test_loss": test_loss,
            })
        records.append({
            "epoch": epoch,
            "train_loss": train_loss,
            "test_loss": test_loss,
        })

        if self.earlystopping is not None and self.earlystopping.early_stop(test_loss):
            break

    # Building the frame from the collected rows also works for zero rows,
    # whereas concatenating an empty list of frames raises.
    df_reset = pd.DataFrame.from_records(
        records, columns=["epoch", "train_loss", "test_loss"])

    if self.writer is not None:
        if records:
            self.writer.write({'Train-Val Loss Table': wandb.Table(dataframe=df_reset)})
        self.writer.finish()
    return df_reset

In [None]:
#| hide
import pandas as pd

# Scratch check: single-row frames indexed by "epoch" stack into one table
# when passed to pd.concat.
_rows = (
    {"epoch": 2, "train_loss": 3, "test_loss": 3},
    {"epoch": 3, "train_loss": 4, "test_loss": 4},
)
df, df2 = (pd.DataFrame.from_records([row], index="epoch") for row in _rows)
pd.concat([df, df2])

Unnamed: 0_level_0,train_loss,test_loss
epoch,Unnamed: 1_level_1,Unnamed: 2_level_1
2,3,3
3,4,4


In [None]:
#| hide
# Run nbdev's export step so the cells marked `#| export` are written out
# to the library module.
import nbdev; nbdev.nbdev_export()