# Dynamics trainer

> This module implements the LeJEPA training procedure.

In [None]:
#| default_exp trainers.trainer_dynamics

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore import *
from fastcore.utils import *
from torchvision.utils import save_image
import torch
import os
from torch import nn
import pandas as pd

## Dynamics Trainer

In [None]:
#| export
from mawm.trainers.trainer import Trainer
from mawm.models.utils import save_checkpoint
from mawm.losses.sigreg import SIGReg
from mawm.losses.vicreg import VICReg

class DynamicsTrainer(Trainer):
    """Trainer that jointly optimises a JEPA model and its message modules.

    The training / evaluation / fitting logic is attached to this class in
    later cells via `@patch` (`train_epoch`, `eval_epoch`, `fit`).
    """

    def __init__(self, cfg, model, train_loader, val_loader=None,
                 optimizer=None, device=None,earlystopping=None, 
                 scheduler=None, writer= None):
        """Store the training components and prepare the checkpoint directory.

        Args:
            cfg: Configuration object. This constructor reads
                `cfg.loss.lambda_`, `cfg.env.agents` and `cfg.log_dir`, and
                passes the whole object on to `VICReg`.
            model: Mapping with the keys `'jepa'`, `'msg_encoder'`,
                `'msg_predictor'` and `'obs_predictor'`.
            train_loader: Iterable of training batches.
            val_loader: Optional iterable of validation batches.
            optimizer: Optimizer stepped once per training batch.
            device: Device for the loss modules; falls back to CPU when falsy.
            earlystopping: Stored as-is on the instance.
            scheduler: Learning-rate scheduler; stored as-is on the instance.
            writer: Metric logger; stored as-is on the instance.
        """
        # NOTE(review): `Trainer.__init__` is not invoked here — confirm that
        # the base class does not need to initialise any state of its own.
        self.cfg = cfg
        self.device = device if device else torch.device('cpu')

        self.train_loader = train_loader
        self.val_loader = val_loader

        self.model = model['jepa']
        self.msg_encoder = model['msg_encoder']
        # self.projector = model['projector']
        self.msg_predictor = model['msg_predictor']
        self.obs_predictor = model['obs_predictor']

        self.optimizer = optimizer
        self.earlystopping = earlystopping
        self.scheduler = scheduler

        self.writer = writer

        self.sigreg = SIGReg().to(self.device)
        # TODO: ensure the .cuda comment for the projector is restored
        self.vicreg = VICReg(self.cfg, repr_dim=self.model.backbone.repr_dim).to(self.device)
        # Weight that balances the two alignment terms of the agent loss.
        self.lambda_ = self.cfg.loss.lambda_

        # Only the *number* of configured agents is used; ids are regenerated
        # as "agent_0", "agent_1", ... and used as keys into each batch.
        self.agents = [f"agent_{i}" for i in range(len(self.cfg.env.agents))]

        # Checkpoints are written below <log_dir>/dmpc_marlrid.
        # `exist_ok=True` avoids the check-then-create race of an explicit
        # `os.path.exists` test when several processes start together.
        self.dmpc_dir = os.path.join(self.cfg.log_dir, 'dmpc_marlrid')
        os.makedirs(self.dmpc_dir, exist_ok=True)

    

In [None]:
# #| export
# from mawm.models.utils import flatten_conv_output
# from einops import rearrange
# @patch
# def train_epoch(self: DynamicsTrainer, epoch):
#     self.model.train()
    
#     total_running_loss = 0.0
#     total_valid_steps = 0

#     for batch_idx, data in enumerate(self.train_loader):
#         self.optimizer.zero_grad()
#         batch_loss = 0
        
#         for agent_id in self.agents:
#             obs, pos, _, act, _, dones = data[agent_id].values()
#             mask = (~dones.bool()).float().to(self.device) # [B, T, d=1]
#             mask = rearrange(mask, 'b t d-> t (b d)')  # [T, B]
#             mask_t = mask[:-1].clone() #rearrange(mask[:-1], 't b -> t b') # [T-1, B]
            
#             agent_loss = 0
#             actual_len = 0

#             if mask.sum() == 0: # CHECK: mask is determined per the reciever agent
#                 continue  # entire batch is terminals
#             num_valid = mask[:, :-1].sum().item()

#             obs = obs.to(self.device)
#             pos = pos.to(self.device)
#             act = act.to(self.device)

#             for other_agent in self.agents:
#                 if other_agent != agent_id:
#                     obs_sender, pos_sender, msg, _, _,_ = data[other_agent].values()

#                     obs_sender = obs_sender.to(self.device)
#                     pos_sender = pos_sender.to(self.device)
#                     msg = msg.to(self.device)
            
            
#             h = self.msg_encoder(msg) # [B, T, C, H, W] => [B, T, dim=32]

#             Z0, Z = self.model(x= obs, pos= pos, actions= act, msgs= h, T= act.size(1)-1)#[B, T, c, h, w] =>  [T, B, c, h, w]
#             vicreg_loss = self.vicreg(Z0, Z, mask= mask)
            
#             z_sender = self.model.backbone(obs_sender, position = pos_sender)  #[B, T, c, h, w] => [B, T, c`, h`, w`]
#             z_sender_hat = self.obs_predictor(h)
            
#             z_sender = flatten_conv_output(z_sender)  # [B, T, c`, h`, w`] => [B, T,d]
#             z_sender_hat = flatten_conv_output(z_sender_hat)  # [B, T, d]

#             z_pred_loss = (z_sender - z_sender_hat).square().mean(dim= -1)  # [B, T, d] => [B, T]
#             z_pred_loss = (z_pred_loss * mask[:-1]).sum() / mask[:-1].sum().clamp_min(1) 

#             h_hat = self.msg_predictor(z_sender) # [B, T, d] => [B, T, dim=32]
#             h_pred_loss = (h - h_hat).square().mean(dim= -1)  # [B, T, dim=32] => [B, T]
#             h_pred_loss = (h_pred_loss * mask[:-1]).sum() / mask[:-1].sum().clamp_min(1) 



#             # proj_z, proj_c = self.projector(z_sender, h)# [B, T, *] => [T-1, B, dim=128]
            
#             # inv_loss = (proj_z - proj_c).square().mean(dim= -1)  # [T-1, B, dim=128] => [T-1, B]
#             # inv_loss = (inv_loss * mask[:-1]).sum() / mask[:-1].sum().clamp_min(1) 
            
#             valid_idx = mask_t.bool()
#             # sigreg_img = self.sigreg(proj_z[valid_idx])
#             # sigreg_msg = self.sigreg(proj_c[valid_idx])

#             # sigreg_loss = 0.5 * (sigreg_img + sigreg_msg )
#             # lejepa_loss = (1- self.lambda_) * inv_loss + self.lambda_ * sigreg_loss 

#             # agent_loss = lejepa_loss + vicreg_loss['total_loss']
#             agent_loss = self.lambda_ * z_pred_loss + (1 - self.lambda_) * h_pred_loss + vicreg_loss['total_loss']
#             batch_loss += agent_loss

#             num_valid = mask.sum().item()
#             total_running_loss += agent_loss.item() * num_valid
#             total_valid_steps += num_valid

            
#         loss = batch_loss / len(self.agents)
#         loss.backward()
#         self.optimizer.step()
       

#         if batch_idx % 20 == 0:
#             print(f'Train Epoch: {epoch} [{batch_idx * len(obs)}/{len(self.train_loader.dataset)} '
#                   f'({100. * batch_idx / len(self.train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

#     # divide by len(self.agents) to match the 'loss' scale used in training
#     final_epoch_loss = (total_running_loss / total_valid_steps) / len(self.agents)
#     print(f'====> Epoch: {epoch} Average loss: {final_epoch_loss:.4f}')

#     return final_epoch_loss
       

In [None]:
#| export
from mawm.models.utils import flatten_conv_output
from einops import rearrange
@patch
def train_epoch(self: DynamicsTrainer, epoch):
    """Run one optimisation pass over `self.train_loader`.

    For every batch each agent in `self.agents` takes the receiver role:

    * the message of another agent is encoded with `self.msg_encoder`;
    * `self.model` is run on the receiver's observations, positions, actions
      and the encoded message, and its two outputs are fed to `self.vicreg`;
    * `self.obs_predictor` tries to recover the sender's backbone features
      from the encoded message (z-alignment) and `self.msg_predictor` tries
      to recover the encoded message from those features (h-alignment).

    The agent loss is `vicreg_total + lambda_ * z_alignment +
    (1 - lambda_) * h_alignment`; one optimizer step is taken per batch on the
    sum of the agent losses divided by `len(self.agents)`.

    Args:
        epoch: Epoch number, used only in the progress messages.

    Returns:
        float: Valid-step-weighted mean of the agent losses, divided by
        `len(self.agents)`. `nan` when no valid (non-terminal) step was seen.

    Raises:
        ValueError: If an agent has no other agent to receive a message from.
    """
    # Fixed scale factors applied to the two alignment terms (training only).
    z_scale = 250.0
    h_scale = 2000.0
    log_every = 20

    # All four modules are optimised together, so all must be in train mode.
    for module in (self.model, self.msg_encoder, self.msg_predictor, self.obs_predictor):
        module.train()

    total_running_loss = 0.0
    total_valid_steps = 0

    for batch_idx, data in enumerate(self.train_loader):
        self.optimizer.zero_grad()
        batch_loss = 0
        contributed = False  # becomes True once any agent adds a loss term
        verbose = batch_idx % log_every == 0

        for agent_id in self.agents:
            obs, pos, _, act, _, dones = data[agent_id].values()
            mask = (~dones.bool()).float().to(self.device)  # [B, T, d=1]
            mask = rearrange(mask, 'b t d-> b (t d)')  # [B, T]
            mask_t = rearrange(mask, 'b t -> t b')  # [T, B]

            # CHECK: mask is determined per the receiver agent
            num_valid = mask.sum().item()
            if num_valid == 0:
                continue  # entire batch is terminals
            valid_count = mask.sum().clamp_min(1)

            obs = obs.to(self.device)
            pos = pos.to(self.device)
            act = act.to(self.device)

            # NOTE(review): with more than two agents only the LAST other
            # agent is used as the sender — confirm this is intended.
            senders = [other for other in self.agents if other != agent_id]
            if not senders:
                raise ValueError(
                    f"{agent_id} has no other agent to receive a message from")
            obs_sender, pos_sender, msg, _, _, _ = data[senders[-1]].values()
            obs_sender = obs_sender.to(self.device)
            pos_sender = pos_sender.to(self.device)
            msg = msg.to(self.device)

            h = self.msg_encoder(msg)  # [B, T, C, H, W] => [B, T, dim=32]

            # [B, T, c, h, w] => [T, B, c, h, w]
            Z0, Z = self.model(x= obs, pos= pos, actions= act, msgs= h, T= act.size(1)-1)
            vicreg_loss = self.vicreg(Z0, Z, mask= mask_t)

            # [B, T, c, h, w] => [B, T, c`, h`, w`]
            z_sender = self.model.backbone(obs_sender, position = pos_sender)
            z_sender_hat = self.obs_predictor(h)  # [B, T, d=32] => [B, T, C, H, W]

            z_sender_flat = flatten_conv_output(z_sender)  # => [B, T, d]
            z_sender_hat = flatten_conv_output(z_sender_hat)  # => [B, T, d]

            # Masked mean squared error over the valid steps only.
            z_pred_loss = (z_sender_flat - z_sender_hat).square().mean(dim= -1)  # [B, T]
            z_pred_loss = (z_pred_loss * mask).sum() / valid_count

            # The last two entries along dim 2 are dropped before prediction —
            # presumably position channels appended by the backbone; TODO confirm.
            h_hat = self.msg_predictor(z_sender[:, :, :-2])  # => [B, T, dim=32]
            h_pred_loss = (h - h_hat).square().mean(dim= -1)  # [B, T]
            h_pred_loss = (h_pred_loss * mask).sum() / valid_count

            z_alignment = z_pred_loss * z_scale
            h_alignment = h_pred_loss * h_scale

            agent_loss = (
                vicreg_loss['total_loss'] +
                self.lambda_ * z_alignment +
                (1 - self.lambda_) * h_alignment
            )

            if verbose:
                print(f"Agent: {agent_id}, z_pred_loss: {z_alignment.item():.4f}, h_pred_loss: {h_alignment.item():.4f}, vicreg_loss: {vicreg_loss['total_loss'].item():.4f}, agent_loss: {agent_loss.item():.4f}")

            batch_loss += agent_loss
            contributed = True

            total_running_loss += agent_loss.item() * num_valid
            total_valid_steps += num_valid

        if not contributed:
            # Every agent was skipped: `batch_loss` is still the int 0 and
            # has no graph to back-propagate through.
            continue

        loss = batch_loss / len(self.agents)
        loss.backward()
        self.optimizer.step()

        if verbose:
            print(f'Train Epoch: {epoch} [{batch_idx * len(obs)}/{len(self.train_loader.dataset)} '
                  f'({100. * batch_idx / len(self.train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

    if total_valid_steps == 0:
        print('Warning: no valid (non-terminal) steps seen during training epoch.')
        final_epoch_loss = float('nan')
    else:
        # divide by len(self.agents) to match the 'loss' scale used in training
        final_epoch_loss = (total_running_loss / total_valid_steps) / len(self.agents)
    print(f'====> Epoch: {epoch} Average loss: {final_epoch_loss:.4f}')

    return final_epoch_loss
       

In [None]:
#| export
from einops import rearrange
@patch
def eval_epoch(self: DynamicsTrainer):
    """Compute the mean loss over the validation data without gradient updates.

    Batches come from `self.val_loader`; when no validation loader was given
    the training loader is used instead. Each agent in `self.agents` takes the
    receiver role exactly as in training, and the agent loss is
    `lambda_ * z_pred_loss + (1 - lambda_) * h_pred_loss + vicreg_total`.

    NOTE(review): unlike `train_epoch`, the two alignment terms are NOT
    multiplied by fixed scale factors here, so train and validation losses
    are on different scales — confirm this is intended.

    Returns:
        float: Valid-step-weighted mean of the agent losses, divided by
        `len(self.agents)`. `nan` when no valid (non-terminal) step was seen.

    Raises:
        ValueError: If an agent has no other agent to receive a message from.
    """
    loader = self.val_loader if self.val_loader is not None else self.train_loader

    # Evaluate every module in eval mode. The auxiliary modules get their
    # previous mode back afterwards so this method does not silently change
    # how the next training epoch runs.
    aux_modules = (self.msg_encoder, self.msg_predictor, self.obs_predictor)
    prior_modes = [module.training for module in aux_modules]
    self.model.eval()
    for module in aux_modules:
        module.eval()

    total_running_loss = 0.0
    total_valid_steps = 0

    try:
        with torch.no_grad():
            for data in loader:
                for agent_id in self.agents:
                    obs, pos, _, act, _, dones = data[agent_id].values()
                    mask = (~dones.bool()).float().to(self.device)  # [B, T, d=1]
                    mask = rearrange(mask, 'b t d-> b (t d)')  # [B, T]
                    mask_t = rearrange(mask, 'b t -> t b')  # [T, B]

                    # CHECK: mask is determined per the receiver agent
                    num_valid = mask.sum().item()
                    if num_valid == 0:
                        continue  # entire batch is terminals
                    valid_count = mask.sum().clamp_min(1)

                    obs = obs.to(self.device)
                    pos = pos.to(self.device)
                    act = act.to(self.device)

                    # NOTE(review): with more than two agents only the LAST
                    # other agent is used as the sender — confirm.
                    senders = [other for other in self.agents if other != agent_id]
                    if not senders:
                        raise ValueError(
                            f"{agent_id} has no other agent to receive a message from")
                    obs_sender, pos_sender, msg, _, _, _ = data[senders[-1]].values()
                    obs_sender = obs_sender.to(self.device)
                    pos_sender = pos_sender.to(self.device)
                    msg = msg.to(self.device)

                    h = self.msg_encoder(msg)  # [B, T, C, H, W] => [B, T, dim=32]

                    # [B, T, c, h, w] => [T, B, c, h, w]
                    Z0, Z = self.model(x= obs, pos= pos, actions= act, msgs= h, T= act.size(1)-1)
                    vicreg_loss = self.vicreg(Z0, Z, mask= mask_t)

                    # [B, T, c, h, w] => [B, T, c`, h`, w`]
                    z_sender = self.model.backbone(obs_sender, position = pos_sender)
                    z_sender_hat = self.obs_predictor(h)

                    z_sender_flat = flatten_conv_output(z_sender)  # => [B, T, d]
                    z_sender_hat = flatten_conv_output(z_sender_hat)  # => [B, T, d]

                    # Masked mean squared error over the valid steps only.
                    z_pred_loss = (z_sender_flat - z_sender_hat).square().mean(dim= -1)  # [B, T]
                    z_pred_loss = (z_pred_loss * mask).sum() / valid_count

                    # The last two entries along dim 2 are dropped before
                    # prediction — presumably position channels; TODO confirm.
                    h_hat = self.msg_predictor(z_sender[:, :, :-2])  # => [B, T, dim=32]
                    h_pred_loss = (h - h_hat).square().mean(dim= -1)  # [B, T]
                    h_pred_loss = (h_pred_loss * mask).sum() / valid_count

                    agent_loss = self.lambda_ * z_pred_loss + (1 - self.lambda_) * h_pred_loss + vicreg_loss['total_loss']

                    total_running_loss += agent_loss.item() * num_valid
                    total_valid_steps += num_valid
    finally:
        for module, mode in zip(aux_modules, prior_modes):
            module.train(mode)

    if total_valid_steps == 0:
        print('Warning: no valid (non-terminal) steps seen during evaluation.')
        final_epoch_loss = float('nan')
    else:
        final_epoch_loss = (total_running_loss / total_valid_steps) / len(self.agents)
    print(f'====>  Test set loss: {final_epoch_loss:.4f}')
    return final_epoch_loss
       

In [None]:
#| export
import wandb

@patch
def fit(self: DynamicsTrainer):
    """Train for `cfg.epochs` epochs, checkpointing and logging after each one.

    Every epoch runs `train_epoch` then `eval_epoch`, saves a checkpoint with
    the state of all four modules and the optimizer (flagged as best when the
    validation loss improves), and reports both losses to `self.writer` when
    one is configured.

    Returns:
        pandas.DataFrame: One row per epoch with the columns `epoch`,
        `train_loss` and `val_loss`.
    """
    import math

    for module in (self.model, self.msg_encoder, self.msg_predictor, self.obs_predictor):
        module.to(self.device)
    # self.projector.to(self.device)

    # The checkpoint paths do not change between epochs.
    best_filename = os.path.join(self.dmpc_dir, 'best.pth')
    filename = os.path.join(self.dmpc_dir, 'checkpoint.pth')

    cur_best = None
    records = []

    for epoch in range(1, self.cfg.epochs + 1):
        if self.scheduler is not None:
            lr = self.scheduler.adjust_learning_rate(epoch)
        else:
            # No scheduler configured: record the optimizer's current rate.
            lr = self.optimizer.param_groups[0]['lr']

        train_loss = self.train_epoch(epoch)
        val_loss = self.eval_epoch()

        # Compare against None explicitly: a legitimate best loss of 0.0 is
        # falsy and must not be mistaken for "no best yet". A NaN best is
        # replaced by any later value because nothing compares below NaN.
        is_best = (
            cur_best is None
            or (isinstance(cur_best, float) and math.isnan(cur_best))
            or val_loss < cur_best
        )
        if is_best:
            cur_best = val_loss

        state = {
            'epoch': epoch,
            'jepa': self.model.state_dict(),
            'msg_encoder': self.msg_encoder.state_dict(),
            # 'projector': self.projector.state_dict(),
            "msg_predictor": self.msg_predictor.state_dict(),
            "obs_predictor": self.obs_predictor.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
            'optimizer': self.optimizer.state_dict(),
            "lr": lr,
        }
        save_checkpoint(state= state, is_best= is_best, filename= filename, best_filename= best_filename)

        if self.writer is not None:
            self.writer.write({
                "train_loss": train_loss,
                "val_loss": val_loss,
            })
        records.append({"epoch": epoch, "train_loss": train_loss, "val_loss": val_loss})

    # Built in one go so that zero epochs yields an empty table rather than
    # the "No objects to concatenate" error of `pd.concat([])`.
    df_reset = pd.DataFrame(records, columns=["epoch", "train_loss", "val_loss"])

    if self.writer is not None:
        self.writer.write({'Train-Val Loss Table': wandb.Table(dataframe= df_reset)})
        self.writer.finish()
    return df_reset

In [None]:
#| hide
# Run nbdev's export step for this project's notebooks.
import nbdev
nbdev.nbdev_export()