# Dynamics trainer

> This module implements the LeJEPA training procedure for the dynamics model.

In [None]:
#| default_exp trainers.trainer_dynamics

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore import *
from fastcore.utils import *
from torchvision.utils import save_image
import torch
import os
from torch import nn
import pandas as pd

## Dynamics Trainer

In [None]:
#| export
from mawm.trainers.trainer import Trainer
from mawm.models.utils import save_checkpoint
from mawm.losses.sigreg import SIGReg
from mawm.losses.vicreg import VICReg

class DynamicsTrainer(Trainer):
    """Trainer for a JEPA dynamics model with a message encoder and a projector.

    `model` is a mapping that must provide the keys ``'jepa'``, ``'msg_enc'``
    and ``'projector'``. The trainer combines a VICReg loss on the JEPA
    outputs with an invariance + SIGReg loss on the projected sender
    embeddings and encoded messages (see `train_epoch` / `eval_epoch`).
    """

    def __init__(self, cfg, model, train_loader, val_loader=None,
                 optimizer=None, device=None,earlystopping=None, 
                 scheduler=None, writer= None):
        """Store the training components and create the output directory.

        Args:
            cfg: configuration object; reads `cfg.loss.lambda_`,
                `cfg.env.agents` and `cfg.log_dir`.
            model: mapping with keys 'jepa', 'msg_enc' and 'projector'.
            train_loader: iterable of training batches.
            val_loader: iterable of validation batches (optional).
            optimizer: optimizer stepped once per training batch.
            device: torch device; falls back to CPU when not given.
            earlystopping: stored as-is; not used inside this class body.
            scheduler: object exposing `adjust_learning_rate(epoch)`.
            writer: logger exposing `write(dict)` and `finish()`.
        """
        # NOTE(review): Trainer.__init__ is not invoked here — confirm this is intentional.
        self.cfg = cfg

        # The device must be resolved before anything is moved onto it below.
        self.device = device if device else torch.device('cpu')

        self.train_loader = train_loader
        self.val_loader = val_loader

        self.model = model['jepa']
        self.msg_encoder = model['msg_enc']
        self.projector = model['projector']

        self.optimizer = optimizer
        self.earlystopping = earlystopping
        self.scheduler = scheduler

        self.writer = writer

        self.sigreg = SIGReg().to(self.device)
        self.vicreg = VICReg().to(self.device) # TODO: ensure .cuda comment for projector is returned back
        # Weight of the SIGReg term relative to the invariance term.
        self.lambda_ = self.cfg.loss.lambda_

        # One name per configured agent: "agent_0", "agent_1", ...
        self.agents = [f"agent_{i}" for i in range(len(self.cfg.env.agents))]

        # Checkpoint/output directory; created together with missing parents,
        # and tolerant of the directory already existing.
        self.dmpc_dir = os.path.join(self.cfg.log_dir, 'dmpc_marlrid')
        os.makedirs(self.dmpc_dir, exist_ok=True)

    

In [None]:
#| export
from mawm.models.utils import flatten_conv_output
@patch
def train_epoch(self: DynamicsTrainer, epoch):
    """Run one training epoch and return the average loss.

    For every batch, each agent acts as the receiver in turn: its own
    observations, positions and actions are combined with the message,
    observations and positions of another agent (the sender). Gradients of
    all receivers are accumulated and the optimizer is stepped once per batch.

    Args:
        epoch: epoch number, used only for progress printing.

    Returns:
        Sum of per-batch losses weighted by batch size, divided by the number
        of samples in batches that contributed; NaN if no batch contributed.
    """
    self.model.train()

    train_loss = 0
    actual_len = 0

    for batch_idx, data in enumerate(self.train_loader):
        self.optimizer.zero_grad()

        loss = None          # last loss computed in this batch, if any
        batch_samples = 0

        # `self.agents` holds agent names; iterate the names themselves.
        for agent_id in self.agents:
            obs, pos, _, act, _, dones = data[agent_id].values()
            batch_samples = obs.size(0)
            mask = ~dones.bool()     # keep only where done is False

            if mask.sum() == 0: # CHECK: mask is determined per the reciever agent
                continue  # entire batch is terminals

            # The receiver's own message is not used: the message fed to the
            # model comes from the sender selected below.
            obs = obs[mask].to(self.device)
            pos = pos[mask].to(self.device)
            act = act[mask].to(self.device)

            msg = obs_sender = pos_sender = None
            # With more than two agents the last non-receiver agent is the
            # sender, because each iteration overwrites the previous one.
            for other_agent in self.agents:
                if other_agent == agent_id:
                    continue
                obs_sender, pos_sender, msg, _, _, _ = data[other_agent].values()

                # Sender tensors are filtered with the receiver's mask.
                msg = msg[mask].to(self.device)
                obs_sender = obs_sender[mask].to(self.device)
                pos_sender = pos_sender[mask].to(self.device)

            if msg is None:
                continue  # no other agent available to act as sender

            C = self.msg_encoder(msg[:-1]) # [B, T-1, C, H, W] => [T-1, B, dim=32]

            Z0, Z = self.model(x= obs, pos= pos, actions= act[:-1], msgs= C, T= act.size(0)-1)  # Z0, Z: [T-1, B, c, h, w]
            vicreg_loss = self.vicreg(Z0, Z)

            z_sender = self.model.backbone(obs_sender[:-1], position= pos_sender[:-1])  # [T-1, B, c, h, w]
            proj_z, proj_c = self.projector(z_sender, C)

            inv_loss = (proj_z - proj_c).square().mean()

            sigreg_img = self.sigreg(proj_z)
            sigreg_msg = self.sigreg(proj_c)

            sigreg_loss = (sigreg_img + sigreg_msg ) / 2.0
            lejepa_loss = (1- self.lambda_) * inv_loss + self.lambda_ * sigreg_loss 

            # Dividing by the agent count makes the accumulated gradient the
            # mean over receivers.
            loss = (lejepa_loss + vicreg_loss['total_loss']) / len(self.agents)
            loss.backward()

            train_loss += loss.item() * batch_samples            

        if loss is None:
            # Every agent was skipped: nothing to step on or to count.
            continue

        actual_len += batch_samples
        self.optimizer.step()

        if batch_idx % 20 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(obs), len(self.train_loader.dataset),
                100. * batch_idx / len(self.train_loader),
                loss.item() / len(obs)))

    # Avoid ZeroDivisionError when no batch contributed any sample.
    avg_loss = train_loss / actual_len if actual_len else float('nan')

    print('====> Epoch: {} Average loss: {:.4f}'.format(
            epoch, avg_loss))

    return avg_loss
       

In [None]:
#| export
from einops import rearrange
@patch
def eval_epoch(self: DynamicsTrainer):
    """Evaluate on `self.val_loader` and return the average loss.

    Mirrors the loss computed in `train_epoch` (VICReg on the JEPA outputs
    plus the invariance/SIGReg loss on projected sender embeddings and
    encoded messages) without gradient tracking or optimizer updates.

    Returns:
        Loss averaged over all (agent, sample) pairs that were evaluated;
        NaN if nothing was evaluated.
    """
    self.model.eval()

    val_loss = 0
    actual_len = 0

    with torch.no_grad():

        for batch_idx, data in enumerate(self.val_loader):
            # `self.agents` holds agent names; iterate the names themselves.
            for agent_id in self.agents:

                obs, pos, _, act, _, dones = data[agent_id].values()
                mask = ~dones.bool()     # keep only where done is False

                if mask.sum() == 0:
                    continue  # entire batch is terminals

                # The receiver's own message is not used: the message fed to
                # the model comes from the sender selected below.
                obs = obs[mask].to(self.device)
                pos = pos[mask].to(self.device)
                act = act[mask].to(self.device)

                msg = obs_sender = pos_sender = None
                # With more than two agents the last non-receiver agent is
                # the sender, because each iteration overwrites the previous.
                for other_agent in self.agents:
                    if other_agent == agent_id:
                        continue
                    obs_sender, pos_sender, msg, _, _, _ = data[other_agent].values()

                    # Sender tensors are filtered with the receiver's mask.
                    msg = msg[mask].to(self.device)
                    obs_sender = obs_sender[mask].to(self.device)
                    pos_sender = pos_sender[mask].to(self.device)

                if msg is None:
                    continue  # no other agent available to act as sender

                C = self.msg_encoder(msg[:-1]) # [B, T-1, dim=32]
                # NOTE(review): train_epoch passes T=act.size(0)-1 here — confirm which is intended.
                Z0, Z = self.model(x= obs, pos= pos, actions= act[:-1], msgs= C, T= self.cfg.data.seq_len - 1)
                vicreg_loss = self.vicreg(Z0, Z)

                # Positions are sliced like the observations so both cover
                # the same steps, as in train_epoch.
                z_sender = self.model.backbone(obs_sender[:-1], position= pos_sender[:-1])

                # Use the single joint projector stored on the trainer, called
                # the same way as in train_epoch.
                proj_z, proj_c = self.projector(z_sender, C)

                inv_loss = (proj_z - proj_c).square().mean()

                sigreg_img = self.sigreg(proj_z)
                sigreg_msg = self.sigreg(proj_c)

                sigreg_loss = (sigreg_img + sigreg_msg ) / 2.0
                lejepa_loss = (1- self.lambda_) * inv_loss + self.lambda_ * sigreg_loss 

                loss = lejepa_loss + vicreg_loss['total_loss']
                val_loss += loss.item() * obs.size(0)
                actual_len += obs.size(0)

    # Avoid ZeroDivisionError when nothing was evaluated.
    val_loss = val_loss / actual_len if actual_len else float('nan')
    print('====> Test set loss: {:.4f}'.format(val_loss))
    return val_loss

In [None]:
#| export
import wandb

@patch
def fit(self: DynamicsTrainer):
    """Train for `cfg.epochs` epochs, checkpointing and logging every epoch.

    Each epoch adjusts the learning rate (when a scheduler is set), runs
    `train_epoch` and `eval_epoch`, saves a checkpoint (flagging the best
    validation loss so far) and writes the losses to the writer (when set).

    Returns:
        DataFrame with one row per epoch and the columns
        ``epoch``, ``train_loss`` and ``val_loss``.
    """
    self.model.to(self.device)
    self.msg_encoder.to(self.device)
    # The trainer stores one joint projector (see __init__).
    self.projector.to(self.device)

    cur_best = None
    lst_dfs = []

    # Checkpoints go to `lejepa_dir` when the instance defines one, otherwise
    # to the directory created in __init__.
    ckpt_dir = getattr(self, 'lejepa_dir', self.dmpc_dir)
    best_filename = os.path.join(ckpt_dir, 'best.pth')
    filename = os.path.join(ckpt_dir, 'checkpoint.pth')

    for epoch in range(1, self.cfg.epochs + 1):
        if self.scheduler is not None:
            self.scheduler.adjust_learning_rate(epoch)
        train_loss = self.train_epoch(epoch)
        val_loss = self.eval_epoch()

        # checkpointing
        # `is None` rather than truthiness: a best loss of 0.0 is still a best.
        is_best = cur_best is None or val_loss < cur_best
        if is_best:
            cur_best = val_loss

        state = {
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            # Extra keys so the other trained modules can be restored too.
            'msg_encoder': self.msg_encoder.state_dict(),
            'projector': self.projector.state_dict(),
            'val_loss': val_loss,
            'optimizer': self.optimizer.state_dict(),
        }
        save_checkpoint(state= state, is_best= is_best, filename= filename, best_filename= best_filename)

        to_log = {
            "train_loss": train_loss, 
            "val_loss": val_loss,
        }

        if self.writer is not None:
            self.writer.write(to_log)
        df = pd.DataFrame.from_records([{"epoch": epoch ,"train_loss": train_loss, "val_loss":val_loss}], index= "epoch")
        lst_dfs.append(df)

    if lst_dfs:
        df_reset = pd.concat(lst_dfs).reset_index()
    else:
        # pd.concat raises on an empty list (cfg.epochs == 0).
        df_reset = pd.DataFrame(columns=["epoch", "train_loss", "val_loss"])

    if self.writer is not None:
        self.writer.write({'Train-Val Loss Table': wandb.Table(dataframe= df_reset)})
        self.writer.finish()

    return df_reset

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()