# World Model trainer

> This module implements the LeJEPA training procedure with three predictors and two input modalities.

In [None]:
#| default_exp trainers.wm_trainer

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore import *
from fastcore.utils import *
from torchvision.utils import save_image
import torch
import os
from torch import nn
import pandas as pd

## World Model Trainer

In [None]:
#| export
from mawm.trainers.trainer import Trainer
from mawm.models.utils import save_checkpoint
from mawm.logger.base import AverageMeter
from mawm.losses.sigreg import SIGReg, SIGRegDistributed
from mawm.losses.vicreg import VICReg
from mawm.models.utils import flatten_conv_output

class WMTrainer(Trainer):
    """Trainer for the multi-agent world model.

    Holds the JEPA world model together with the message encoder/predictor
    modules, the optimisation helpers and the SIGReg regularisers. The
    training logic (``criterion``, ``train_epoch``, ``fit``) is attached to
    this class elsewhere in the notebook via ``fastcore.patch``.
    """

    def __init__(self, cfg, model, train_loader, sampler,
                 optimizer=None, device=None, earlystopping=None,
                 scheduler=None, writer=None, verbose=None, logger=None):
        """Store collaborators and read loss weights from ``cfg``.

        Args:
            cfg: experiment config; this constructor reads
                ``cfg.loss.lambda_``, ``cfg.loss.W_H_PRED``,
                ``cfg.loss.W_SIM_T``, ``cfg.env.agents``, ``cfg.log_dir``
                and ``cfg.now``.
            model: mapping with the keys ``'jepa'``, ``'msg_encoder'``,
                ``'msg_predictor'`` and ``'obs_predictor'``.
            train_loader: iterable of training batches.
            sampler: stored as ``self.sampler``.
            optimizer, device, earlystopping, scheduler, writer, verbose,
            logger: stored unchanged as attributes of the same name.
        """
        self.cfg, self.device = cfg, device
        self.train_loader, self.sampler = train_loader, sampler

        # Unpack the world-model bundle into its sub-modules.
        (self.model,
         self.msg_encoder,
         self.msg_predictor,
         self.obs_predictor) = (model[key] for key in
                                ('jepa', 'msg_encoder', 'msg_predictor', 'obs_predictor'))

        # Optimisation helpers.
        self.optimizer = optimizer
        self.earlystopping = earlystopping
        self.scheduler = scheduler

        # Reporting.
        self.writer, self.verbose, self.logger = writer, verbose, logger

        # SIGReg regularisers; ``criterion`` uses the ``disSigReg`` instance.
        self.sigreg = SIGReg().to(self.device)
        self.disSigReg = SIGRegDistributed().to(self.device)

        # Loss weights.
        loss_cfg = self.cfg.loss
        self.lambda_ = loss_cfg.lambda_
        self.W_H_PRED = loss_cfg.W_H_PRED
        self.W_SIM_T = loss_cfg.W_SIM_T

        n_agents = len(self.cfg.env.agents)
        self.agents = [f"agent_{idx}" for idx in range(n_agents)]

        # Per-run output directory for checkpoints.
        run_dir = os.path.join(self.cfg.log_dir, 'dmpc_marlrid', self.cfg.now)
        self.dmpc_dir = run_dir
        if not os.path.exists(run_dir):
            os.makedirs(run_dir, exist_ok=True)

    

In [None]:
#| export
@patch
def criterion(self: WMTrainer, global_step, Z0, Z, h, h_hat, mask_t, mask):
    """Compute the individual loss terms for one receiving agent.

    Args:
        global_step: optimisation step, forwarded to ``self.disSigReg``.
        Z0: encoder latents, indexed here as ``[T, B, c, h, w]``
            (time-first layout assumed from the model output — TODO confirm).
        Z: predicted latents, same layout as ``Z0``.
        h: encoded messages, ``[B, T, d]``.
        h_hat: messages predicted from the sender's latent, same shape as ``h``.
        mask_t: validity mask ``[T, B]`` (1 = valid, 0 = after termination).
        mask: the same validity mask laid out ``[B, T]``.

    Returns:
        dict with the loss tensors ``sigreg_img``, ``sigreg_msg``,
        ``sim_loss``, ``sim_loss_t`` and ``h_pred_loss``.
    """
    # SIGReg on the flattened image latents of the first time step only,
    # and on the message embeddings.
    flat_encodings = flatten_conv_output(Z0)  # [T, B, c', h', w'] => [T, B, d]
    sigreg_img = self.disSigReg(flat_encodings[:1], global_step=global_step)
    sigreg_msg = self.disSigReg(h, global_step=global_step)

    # A transition t-1 -> t counts only if both of its endpoints are valid.
    transition_mask = mask_t[1:] * mask_t[:-1]  # [T-1, B]
    n_transitions = transition_mask.sum().clamp_min(1)

    # Latent prediction loss for steps 1..T-1, averaged over valid transitions.
    diff = (Z0[1:] - Z[1:]).pow(2).mean(dim=(2, 3, 4))  # [T-1, B]
    sim_loss = (diff * transition_mask).sum() / n_transitions

    # Optional temporal-consistency term between consecutive encodings.
    if self.cfg.loss.vicreg.sim_coeff_t:
        diff_t = (Z0[1:] - Z0[:-1]).pow(2).mean(dim=(2, 3, 4))  # [T-1, B]
        sim_loss_t = (diff_t * transition_mask).sum() / n_transitions
    else:
        # 0-dim like the other terms, so adding it keeps the total loss a
        # scalar instead of broadcasting it to shape [1].
        sim_loss_t = torch.zeros((), device=self.device)

    # Message prediction loss, averaged over valid (batch, step) entries.
    h_pred_loss = (h - h_hat).square().mean(dim=-1)  # [B, T, d] => [B, T]
    h_pred_loss = (h_pred_loss * mask).sum() / mask.sum().clamp_min(1)

    return {
        'sigreg_img': sigreg_img,
        'sigreg_msg': sigreg_msg,
        'sim_loss': sim_loss,
        'sim_loss_t': sim_loss_t,
        'h_pred_loss': h_pred_loss,
    }
            

In [None]:
#| export
from mawm.models.utils import flatten_conv_output
from einops import rearrange
@patch
def train_epoch(self: WMTrainer, epoch):
    """Run one training epoch over ``self.train_loader``.

    For every batch and every receiving agent, the message of another agent
    (the last other agent in ``self.agents`` order) is encoded. The JEPA model
    is rolled out on the receiver's observations/actions conditioned on that
    message, and the sender's backbone latent is used to predict the message.
    Per-agent losses are averaged over ``len(self.agents)`` and one optimiser
    step is taken per batch.

    Args:
        epoch: epoch index (``fit`` passes 1-based values); used for
            ``sampler.set_epoch`` and to derive the global step.

    Returns:
        float: mean per-agent loss over the epoch, weighted by the number of
        valid (non-terminal) steps, i.e. on the same scale as the per-batch
        ``loss``. ``nan`` if the epoch contained no valid step.

    Raises:
        ValueError: if fewer than two agents are configured (no sender).
    """
    if len(self.agents) < 2:
        raise ValueError("WMTrainer needs at least two agents to exchange messages")

    self.model.train()
    # The message modules are optimised jointly; make sure they are not left
    # in eval mode.
    self.msg_encoder.train()
    self.msg_predictor.train()

    n_agents = len(self.agents)
    # Each receiver listens to the last other agent (as in the original loop;
    # with more than two agents only that sender's message is used).
    sender_of = {a: [b for b in self.agents if b != a][-1] for a in self.agents}
    # The JEPA model may be wrapped (e.g. DistributedDataParallel); the
    # backbone is reached through the underlying module in that case.
    jepa = self.model.module if hasattr(self.model, 'module') else self.model

    total_running_loss = 0.0
    total_valid_steps = 0

    self.logger.info(f"Device used: {self.device}")
    if self.sampler is not None:
        self.sampler.set_epoch(epoch)

    for batch_idx, data in enumerate(self.train_loader):
        global_step = epoch * len(self.train_loader) + batch_idx
        self.optimizer.zero_grad()
        batch_loss = None  # stays None if every agent in the batch is skipped

        for agent_id in self.agents:
            obs, _, _, act, _, dones = data[agent_id].values()
            mask = (~dones.bool()).float().to(self.device)  # [B, T, d=1]
            mask = rearrange(mask, 'b t d-> b (t d)')  # [B, T]
            mask_t = rearrange(mask, 'b t -> t b')  # [T, B]

            if mask.sum() == 0:
                continue  # every step of this agent is terminal

            obs = obs.to(self.device)
            act = act.to(self.device)
            self.logger.info(f"device used for main agent data: {mask_t.device} {obs.device}, {act.device}")

            obs_sender, _, msg, _, _, _ = data[sender_of[agent_id]].values()
            obs_sender = obs_sender.to(self.device)
            msg = msg.to(self.device)
            self.logger.info(f"device used for other agent data: {obs_sender.device}, {msg.device}")

            h = self.msg_encoder(msg)  # [B, T, ...] => [B, T, dim]
            # [B, T, c, h, w] => [T, B, c, h, w]
            Z0, Z = self.model(x=obs, pos=None, actions=act, msgs=h, T=act.size(1) - 1)

            z_sender = jepa.backbone(obs_sender, position=None)  # [B, T, c', h', w']
            h_hat = self.msg_predictor(z_sender)  # => [B, T, dim]

            losses = self.criterion(global_step, Z0, Z, h, h_hat, mask_t, mask)

            self.writer.write({
                f'{agent_id}/train/sigreg_img': losses['sigreg_img'].item(),
                f'{agent_id}/train/sigreg_msg': losses['sigreg_msg'].item(),
                f'{agent_id}/train/sim_loss': losses['sim_loss'].item(),
                f'{agent_id}/train/sim_loss_t': losses['sim_loss_t'].item(),
                f'{agent_id}/train/h_pred_loss': losses['h_pred_loss'].item(),
            })
            self.logger.info("Losses: %s" % str({k: v.item() for k, v in losses.items()}))

            # JEPA 1: latent dynamics prediction + SIGReg on image latents.
            jepa_1_loss = (1 - self.lambda_) * losses['sim_loss'] + self.lambda_ * losses['sigreg_img']
            # JEPA 2: message prediction + SIGReg on message embeddings.
            jepa_2_loss = (1 - self.lambda_) * losses['h_pred_loss'] + self.lambda_ * losses['sigreg_msg']
            self.logger.info(f"JEPA Losses: jepa_1_loss: {jepa_1_loss.item():.4f}, jepa_2_loss: {jepa_2_loss.item():.4f}, sim_loss_t: {losses['sim_loss_t'].item():.4f}")

            agent_loss = jepa_1_loss + jepa_2_loss + self.W_SIM_T * losses['sim_loss_t']
            self.logger.info(f"Agent: {agent_id}, agent_loss: {agent_loss.item():.4f}")

            batch_loss = agent_loss if batch_loss is None else batch_loss + agent_loss

            num_valid = mask.sum().item()
            total_running_loss += agent_loss.item() * num_valid
            total_valid_steps += num_valid

        if batch_loss is None:
            # Every agent was fully terminal: there is no graph to
            # back-propagate through, so skip the optimiser step.
            # NOTE(review): under DDP, skipping on one rank only may
            # desynchronise ranks — confirm the data pipeline avoids this.
            continue

        loss = batch_loss / n_agents
        loss.backward()
        self.optimizer.step()

        if batch_idx % 20 == 0:
            self.logger.info(f'Train Epoch: {epoch} [{batch_idx * len(obs)}/{len(self.train_loader.dataset)} '
                  f'({100. * batch_idx / len(self.train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

    if total_valid_steps == 0:
        final_epoch_loss = float('nan')  # nothing was measured this epoch
    else:
        # Valid-step-weighted mean of per-agent losses. This already has the
        # scale of ``loss`` (a mean over agents), so no further division by
        # the number of agents.
        final_epoch_loss = total_running_loss / total_valid_steps
    self.logger.info(f'====> Epoch: {epoch} Average loss: {final_epoch_loss:.4f}')

    return final_epoch_loss
       

In [None]:
# #| export
# from einops import rearrange
# @patch
# def eval_epoch(self: WMTrainer):
#     self.model.eval()

#     total_running_loss = 0.0
#     total_valid_steps = 0

#     with torch.no_grad():
#         for batch_idx, data in enumerate(self.train_loader):
#             batch_loss = 0
            
#             for agent_id in self.agents:
#                 obs, pos, _, act, _, dones = data[agent_id].values()
#                 mask = (~dones.bool()).float().to(self.device) # [B, T, d=1]
#                 mask = rearrange(mask, 'b t d-> b (t d)')  # [T, B]
#                 mask_t = rearrange(mask, 'b t -> t b')  # [T, B]
                    
#                 agent_loss = 0
                
#                 if mask.sum() == 0: # CHECK: mask is determined per the reciever agent
#                     continue  # entire batch is terminals

#                 obs = obs.to(self.device)
#                 pos = pos.to(self.device)
#                 act = act.to(self.device)

#                 for other_agent in self.agents:
#                     if other_agent != agent_id:
#                         obs_sender, pos_sender, msg, _, _,_ = data[other_agent].values()

#                         obs_sender = obs_sender.to(self.device)
#                         pos_sender = pos_sender.to(self.device)
#                         msg = msg.to(self.device)
                
#                 h = self.msg_encoder(msg) # [B, T, C, H, W] => [B, T, dim=32]

#                 Z0, Z = self.model(x= obs, pos= pos, actions= act, msgs= h, T= act.size(1)-1)#[B, T, c, h, w] =>  [T, B, c, h, w]
#                 vicreg_loss = self.vicreg(Z0, Z, mask= mask_t)
                
#                 z_sender = self.model.backbone(obs_sender, position = pos_sender)  #[B, T, c, h, w] => [B, T, c`, h`, w`]
#                 z_sender_hat = self.obs_predictor(h)
                
#                 z_sender = self.model.backbone(obs_sender, position = pos_sender)  #[B, T, c, h, w] => [B, T, c`, h`, w`]
#                 z_sender_hat = self.obs_predictor(h)

#                 z_sender_flat = flatten_conv_output(z_sender)  # [B, T, c`, h`, w`] => [B, T,d]
#                 z_sender_hat = flatten_conv_output(z_sender_hat)  # [B, T, d]

#                 z_pred_loss = (z_sender_flat - z_sender_hat).square().mean(dim= -1)  # [B, T, d] => [B, T]
#                 z_pred_loss = (z_pred_loss * mask).sum() / mask.sum().clamp_min(1) 

#                 h_hat = self.msg_predictor(z_sender[:, :, :-2]) # [B, T, d] => [B, T, dim=32]
#                 h_pred_loss = (h - h_hat).square().mean(dim= -1)  # [B, T, dim=32] => [B, T]
#                 h_pred_loss = (h_pred_loss * mask).sum() / mask.sum().clamp_min(1) 
                
#                 agent_loss = self.lambda_ * z_pred_loss + (1 - self.lambda_) * h_pred_loss + vicreg_loss['total_loss']
#                 batch_loss += agent_loss

#                 num_valid = mask.sum().item()
#                 total_running_loss += agent_loss.item() * num_valid
#                 total_valid_steps += num_valid

#     final_epoch_loss = (total_running_loss / total_valid_steps) / len(self.agents)
#     print(f'====>  Test set loss: {final_epoch_loss:.4f}')
#     return final_epoch_loss
       

In [None]:
#| export
import wandb
CHECKPOINT_FREQ = 1  # write ``latest.pt`` every this many epochs


@patch
def fit(self: WMTrainer):
    """Train for ``cfg.epochs`` epochs, logging and checkpointing along the way.

    Each epoch:
    - adjusts the learning rate via ``self.scheduler``;
    - runs ``train_epoch``;
    - writes ``latest.pt`` every ``CHECKPOINT_FREQ`` epochs and after the
      final epoch;
    - writes an ``e{epoch}.pt`` snapshot every ``cfg.save_every_freq``
      epochs (when that value is > 0).

    Checkpoints are only written when ``self.verbose`` is truthy.

    Returns:
        pandas.DataFrame with columns ``epoch`` and ``train_loss`` (one row
        per epoch; empty if ``cfg.epochs`` < 1).
    """
    folder = self.dmpc_dir
    latest_path = os.path.join(folder, "latest.pt")

    loss_meter = AverageMeter()
    records = []

    def _get_state(m):
        # Unwrap modules exposing ``.module`` (e.g. DDP) before saving.
        return m.module.state_dict() if hasattr(m, 'module') else m.state_dict()

    def _save_checkpoint(epoch, path, train_loss, lr):
        # Best-effort save of all modules + optimiser; no-op unless verbose.
        if not self.verbose:
            return
        save_dict = {
            'epoch': epoch,
            'jepa': _get_state(self.model),
            'msg_encoder': _get_state(self.msg_encoder),
            "msg_predictor": _get_state(self.msg_predictor),
            "obs_predictor": _get_state(self.obs_predictor),
            'train_loss': train_loss,
            'optimizer': self.optimizer.state_dict(),
            "lr": lr,
        }
        try:
            torch.save(save_dict, path)
            self.logger.info(f"Successfully saved checkpoint to {path}")
        except Exception as e:
            self.logger.info(f"Encountered exception when saving checkpoint: {e}")

    for epoch in range(1, self.cfg.epochs + 1):
        self.logger.info("Epoch %d" % (epoch))
        lr = self.scheduler.adjust_learning_rate(epoch)
        train_loss = self.train_epoch(epoch)
        loss_meter.update(train_loss)
        self.logger.info("avg. loss %.3f" % loss_meter.avg)

        records.append({"epoch": epoch, "train_loss": train_loss})

        # Epochs are 1-based, so the final one is ``cfg.epochs`` itself.
        # The stored epoch is ``epoch + 1`` — presumably the epoch to resume
        # from (TODO confirm against the checkpoint loader).
        if epoch % CHECKPOINT_FREQ == 0 or epoch == self.cfg.epochs:
            _save_checkpoint(epoch + 1, latest_path, train_loss, lr)

        # Periodic snapshots are independent of the ``latest.pt`` frequency.
        if self.cfg.save_every_freq > 0 and epoch % self.cfg.save_every_freq == 0:
            save_every_path = os.path.join(folder, f"e{epoch}.pt")
            _save_checkpoint(epoch + 1, save_every_path, train_loss, lr)

        if self.verbose:
            self.writer.write({"train_loss": train_loss})

    # Same frame as concatenating per-epoch rows and resetting the index,
    # but also valid (empty) when no epoch ran.
    df_reset = pd.DataFrame.from_records(records, columns=["epoch", "train_loss"])
    if self.verbose:
        self.writer.write({'Train Loss Table': wandb.Table(dataframe=df_reset)})

    self.writer.finish()
    return df_reset

In [None]:
#| hide
# Export the notebook cells marked ``#| export`` to library modules.
import nbdev
nbdev.nbdev_export()