# VicReg Jepa trainer

> This module implements the VicReg JEPA training procedure, which jointly trains an image encoder and a program encoder.

In [None]:
#| default_exp trainers.trainer_vic_reg

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore import *
from fastcore.utils import *
from torchvision.utils import save_image
import torch
import os
from torch import nn
import pandas as pd

## VicReg Jepa Trainer

In [None]:
#| export
from mawm.trainers.trainer import Trainer
from mawm.core import *
from mawm.models.utils import save_checkpoint
from mawm.losses.sigreg import SIGReg
from mawm.models.program.creator import create_specs_from_image, batchify_programs

class VicRegJepaTrainer(Trainer):
    """Jointly train an image encoder and a program encoder.

    The objective used by the epoch methods mixes an invariance term between
    image and program embeddings with SIGReg regularisation of both
    embeddings, weighted by ``cfg.lambda_``.

    Args:
        cfg: configuration object; this class reads ``cfg.lambda_`` and
            ``cfg.log_dir`` (``fit`` additionally reads ``cfg.epochs``).
        v_encoder: module applied to observation batches.
        p_encoder: module applied to ``(prim_ids, param_tensor)`` batches.
        train_loader: loader whose dataset exposes ``load_next_buffer()`` and
            ``reset_buffer()``.
        val_loader: loader with the same dataset interface, used for evaluation.
        criterion: stored for API compatibility; not used by the visible code.
        optimizer: optimizer over the encoders' parameters.
        device: device the models and batches are moved to.
        scheduler: stored; not stepped by the visible code.
        writer: logger object exposing ``write(dict)`` and ``finish()``.
    """

    def __init__(self, cfg, v_encoder, p_encoder, train_loader, val_loader=None,
                 criterion=None, optimizer=None,
                 device=None, scheduler=None, writer=None):

        self.cfg = cfg
        self.v_encoder = v_encoder
        self.p_encoder = p_encoder
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.device = device
        self.scheduler = scheduler
        self.writer = writer
        self.sigreg = SIGReg().to(self.device)
        # Weight of the SIGReg term; the invariance term gets (1 - lambda_).
        self.lambda_ = self.cfg.lambda_

        # Checkpoint directory. makedirs also creates a missing cfg.log_dir and
        # tolerates the directory already existing (no check-then-create race).
        self.prog_lejepa_dir = os.path.join(self.cfg.log_dir, 'prog_lejepa_marlrid')
        os.makedirs(self.prog_lejepa_dir, exist_ok=True)

    

In [None]:
#| export
@patch
def train_epoch(self: VicRegJepaTrainer, epoch):
    """Run one training epoch over every buffer of the training dataset.

    ``load_next_buffer()`` is called until it raises. After each successful
    load, the loader is iterated. Terminal frames (``dones`` true) are dropped.
    Each remaining image is converted to a program spec, and both encoders are
    updated with ``(1 - lambda_) * invariance + lambda_ * SIGReg``.

    Args:
        epoch: epoch index, used only in progress messages.

    Returns:
        Per-sample average training loss for the epoch, or NaN if no
        non-terminal sample was seen.
    """
    self.v_encoder.train()
    self.p_encoder.train()

    def denormalize(tensor):
        # Algebraic inverse of (x - 0.5) / 0.5; presumably the dataset's
        # normalisation (confirm) so program extraction sees [0, 1] images.
        return tensor * 0.5 + 0.5

    train_loss = 0.0
    actual_len = 0

    while True:
        try:
            self.train_loader.dataset.load_next_buffer()
        except Exception:
            # The loop relies on load_next_buffer() raising once no buffers
            # remain. Catch Exception (not a bare except) so Ctrl-C still works.
            break

        for batch_idx, (obs, dones, _agent_id) in enumerate(self.train_loader):
            mask = ~dones.bool()  # keep only non-terminal frames
            if not mask.any():
                continue  # entire batch is terminals

            obs = obs[mask]
            batch_size = obs.size(0)

            programs = [create_specs_from_image(denormalize(img).permute(1, 2, 0).numpy()) for img in obs]
            batch_prim_ids, batch_param_tensor = batchify_programs(programs)

            batch_prim_ids = batch_prim_ids.to(self.device)
            batch_param_tensor = batch_param_tensor.to(self.device)
            obs = obs.to(self.device)

            self.optimizer.zero_grad()

            img_proj = self.v_encoder(obs)
            prog_proj = self.p_encoder(batch_prim_ids, batch_param_tensor)

            sigreg_loss = self.sigreg(img_proj) + self.sigreg(prog_proj)
            # Treat image i and program i as two views of the same sample and
            # penalise each view's distance to their per-sample centre, so the
            # pairing between image and program is actually enforced.
            views = torch.stack([img_proj, prog_proj])
            inv_loss = (views.mean(0) - views).square().mean()

            loss = (1 - self.lambda_) * inv_loss + self.lambda_ * sigreg_loss

            loss.backward()
            self.optimizer.step()

            batch_loss = loss.item()
            # inv_loss is a batch mean; weight by batch size so the epoch
            # average below is a per-sample mean.
            train_loss += batch_loss * batch_size
            actual_len += batch_size

            if batch_idx % 20 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * batch_size, len(self.train_loader.dataset),
                    100. * batch_idx / len(self.train_loader),
                    batch_loss))

    # NaN rather than ZeroDivisionError when every frame was terminal.
    avg_loss = train_loss / actual_len if actual_len else float('nan')
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, avg_loss))

    return avg_loss
       

In [None]:
#| export
@patch
def eval_epoch(self: VicRegJepaTrainer):
    """Compute the per-sample validation loss over every buffer of the val dataset.

    This mirrors ``train_epoch``: buffers are loaded until ``load_next_buffer()``
    raises, terminal frames are dropped, and the same invariance + SIGReg
    objective is evaluated. It runs without gradient tracking and never touches
    the optimizer, so it works even when ``optimizer`` is None.

    Returns:
        Per-sample average validation loss, or NaN if no non-terminal sample
        was seen.
    """
    self.v_encoder.eval()
    self.p_encoder.eval()

    def denormalize(tensor):
        # Algebraic inverse of (x - 0.5) / 0.5; presumably the dataset's
        # normalisation (confirm).
        return tensor * 0.5 + 0.5

    test_loss = 0.0
    actual_len = 0

    with torch.no_grad():
        while True:
            try:
                self.val_loader.dataset.load_next_buffer()
            except Exception:
                # The loop relies on load_next_buffer() raising once no buffers
                # remain. Catch Exception (not a bare except) so Ctrl-C still works.
                break

            for obs, dones, _agent_id in self.val_loader:
                mask = ~dones.bool()  # keep only non-terminal frames
                if not mask.any():
                    continue  # entire batch is terminals

                obs = obs[mask]
                batch_size = obs.size(0)

                programs = [create_specs_from_image(denormalize(img).permute(1, 2, 0).numpy()) for img in obs]
                batch_prim_ids, batch_param_tensor = batchify_programs(programs)

                batch_prim_ids = batch_prim_ids.to(self.device)
                batch_param_tensor = batch_param_tensor.to(self.device)
                obs = obs.to(self.device)

                img_proj = self.v_encoder(obs)
                prog_proj = self.p_encoder(batch_prim_ids, batch_param_tensor)

                sigreg_loss = self.sigreg(img_proj) + self.sigreg(prog_proj)
                # Pair image i with program i: distance of each view to the
                # per-sample centre of the two views.
                views = torch.stack([img_proj, prog_proj])
                inv_loss = (views.mean(0) - views).square().mean()

                loss = (1 - self.lambda_) * inv_loss + self.lambda_ * sigreg_loss
                # Weight the batch-mean loss by batch size for a per-sample mean.
                test_loss += loss.item() * batch_size
                actual_len += batch_size

    test_loss = test_loss / actual_len if actual_len else float('nan')
    print('====> Test set loss: {:.4f}'.format(test_loss))
    return test_loss


In [None]:
#| export
import wandb

@patch
def fit(self: VicRegJepaTrainer):
    """Train for ``cfg.epochs`` epochs, checkpointing and logging after each one.

    After every epoch:

    - both encoders and the optimizer state are passed to ``save_checkpoint``,
      with ``checkpoint.pth``/``best.pth`` paths under ``prog_lejepa_dir``;
    - the train and test losses are written to ``self.writer``;
    - both datasets' buffers are reset.

    At the end, a loss table is logged and the writer is finished.

    Returns:
        DataFrame with columns ``epoch``, ``train_loss`` and ``test_loss``
        (one row per epoch, default integer index).
    """
    import math

    best_filename = os.path.join(self.prog_lejepa_dir, 'best.pth')
    filename = os.path.join(self.prog_lejepa_dir, 'checkpoint.pth')

    cur_best = None
    records = []

    for epoch in range(1, self.cfg.epochs + 1):
        train_loss = self.train_epoch(epoch)
        test_loss = self.eval_epoch()

        # Compare against None explicitly. A best loss of exactly 0.0 is falsy
        # and would otherwise make every later epoch "best". A NaN loss (no
        # validation samples) never counts as an improvement.
        is_best = not math.isnan(test_loss) and (cur_best is None or test_loss < cur_best)
        if is_best:
            cur_best = test_loss

        state = {
            'epoch': epoch,
            # The trainer has no single `model`; save both encoders so the
            # checkpoint can restore the whole system.
            'state_dict': {
                'v_encoder': self.v_encoder.state_dict(),
                'p_encoder': self.p_encoder.state_dict(),
            },
            'test_loss': test_loss,
            'optimizer': self.optimizer.state_dict(),
        }
        save_checkpoint(state=state, is_best=is_best, filename=filename, best_filename=best_filename)

        self.writer.write({
            "train_loss": train_loss,
            "test_loss": test_loss,
        })
        records.append({"epoch": epoch, "train_loss": train_loss, "test_loss": test_loss})

        self.train_loader.dataset.reset_buffer()
        self.val_loader.dataset.reset_buffer()

    # Building from records (with explicit columns) also works for zero epochs,
    # where pd.concat([]) would raise.
    df_reset = pd.DataFrame.from_records(records, columns=["epoch", "train_loss", "test_loss"])
    self.writer.write({'Train-Val Loss Table': wandb.Table(dataframe=df_reset)})

    self.writer.finish()
    return df_reset

In [None]:
#| hide
# Regenerate the library source from the project's notebooks, including this notebook's `#| export` cells.
import nbdev; nbdev.nbdev_export()