# World Model trainer

> This module implements the LeJEPA training procedure with three predictors and two input modalities.

In [None]:
#| default_exp trainers.findgoal_trainer

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore import *
from fastcore.utils import *
from torchvision.utils import save_image
import torch
import os
from torch import nn
import torch.nn.functional as F
import pandas as pd

## WorldModel Trainer

In [None]:
#| export
from mawm.models.utils import save_checkpoint
from mawm.loggers.base import AverageMeter
from mawm.losses.sigreg import SIGReg
from mawm.losses.idm import IDMLoss    
from mawm.models.utils import flatten_conv_output
from einops import rearrange
from torch.nn.parallel import DistributedDataParallel

class WMTrainer:
    """Trainer for the multi-agent world model (receiver JEPA + sender JEPA + IDM).

    Holds the receiver world model (``model['rec']['jepa']``), the sender
    modules (observation encoder, message encoder, projection head and
    communication module), a SIGReg regulariser and an inverse-dynamics (IDM)
    loss. The training logic is attached to this class with ``@patch`` in the
    cells below (``criterion``, ``sender_jepa``, ``rec_jepa``, ``train_epoch``,
    ``fit``).

    Args:
        cfg: configuration object; this constructor reads ``cfg.loss.idm``,
            ``cfg.distributed``, ``cfg.env.agents``, ``cfg.log_dir``,
            ``cfg.log_subdir`` and ``cfg.now``.
        model: nested dict with ``model['rec']['jepa']`` and
            ``model['send'][...]`` for ``obs_enc``, ``msg_enc``, ``proj`` and
            ``comm_module``.
        train_loader: iterable of training batches.
        sampler: optional sampler exposing ``set_epoch`` (e.g. a distributed sampler).
        optimizer: torch optimizer over the model parameters; the IDM action
            predictor's parameters are appended to it as an extra param group.
        device: device the SIGReg module and IDM action predictor are moved to.
        earlystopping: stored as-is; not used by the code in this module.
        scheduler: object exposing ``adjust_learning_rate(step)``.
        writer: metrics writer exposing ``write(dict)`` / ``finish()``.
        verbose: truthy to enable extra logging, metric writing and checkpointing.
        logger: logging-style object used for progress messages.
    """

    def __init__(self, cfg, model, train_loader, sampler,
                 optimizer=None, device=None, earlystopping=None,
                 scheduler=None, writer=None, verbose=None, logger=None):
        self.cfg = cfg
        self.device = device

        self.train_loader = train_loader
        self.sampler = sampler

        # Receiver side: the world model.
        self.jepa = model['rec']['jepa']
        # Sender side: encoders, projection head and message predictor.
        self.obs_enc = model["send"]["obs_enc"]
        self.msg_enc = model["send"]["msg_enc"]
        self.proj = model["send"]["proj"]
        self.comm_module = model["send"]["comm_module"]

        self.optimizer = optimizer
        self.earlystopping = earlystopping
        self.scheduler = scheduler

        self.writer = writer
        self.verbose = verbose
        self.logger = logger

        self.sigreg = SIGReg().to(self.device)

        # NOTE(review): (32, 15, 15) is hard-coded; presumably the per-step
        # latent shape (c, h, w) produced by the encoder -- confirm.
        self.idm = IDMLoss(cfg.loss.idm, (32, 15, 15), device=self.device)

        # Move to the target device *before* any DDP wrapping: DDP with
        # ``device_ids`` requires the module to already live on that device.
        action_predictor = self.idm.action_predictor.to(self.device)
        if self.cfg.distributed:
            action_predictor = DistributedDataParallel(
                action_predictor,
                device_ids=[self.device],
                find_unused_parameters=True,
            )
        self.idm.action_predictor = action_predictor

        # The IDM head is created here, so the caller's optimizer does not
        # know its parameters yet; register them as an additional group.
        if self.optimizer is not None:
            self.optimizer.add_param_group({
                'params': self.idm.action_predictor.parameters(),
                'lr': 0.001,
                'weight_decay': 1e-4,
            })

        # Scheduled-sampling window used by ``get_sampling_prob``.
        self.schedule_start_epoch = 5   # start mixing in predicted messages
        self.schedule_end_epoch = 20    # only predicted messages from here on

        self.agents = [f"agent_{i}" for i in range(len(self.cfg.env.agents))]

        # Run directory for checkpoints.
        self.dmpc_dir = os.path.join(self.cfg.log_dir, self.cfg.log_subdir, self.cfg.now)
        os.makedirs(self.dmpc_dir, exist_ok=True)

    

In [None]:
#| hide
# Scratch check (hidden from docs): F.cross_entropy accepts K-dimensional
# logits [N, C, H, W] with integer class targets [N, H, W] once the two
# leading axes are merged with flatten(0, 1) -- the same layout the
# message-prediction loss in `criterion` relies on.
import torch
import torch.nn.functional as F
logits = torch.randn(8, 17, 5, 7, 7)  # mimics [B, T, C=5, 7, 7] message logits
targets = torch.randint(0, 5, (8, 17, 7, 7))  # mimics [B, T, 7, 7] class indices in [0, 5)
print(logits.flatten(0,1).shape)
F.cross_entropy(logits.flatten(0,1), targets.flatten(0,1))

torch.Size([136, 5, 7, 7])


tensor(1.9662)

In [None]:
#| export
@patch
def criterion(self: WMTrainer, global_step, z0, z, actions, msg_target, msg_hat, proj_h, proj_z, mask_t):
    """Compute every (unweighted) loss term for one sender -> receiver pair.

    Args:
        global_step: training step, forwarded to SIGReg.
        z0: receiver encodings; assumed ``[T, B, c, h, w]`` (the
            ``mean(dim=(2, 3, 4))`` below requires 5 dims).
        z: receiver predictions, same layout as ``z0``.
        actions: receiver actions, forwarded to the IDM loss.
        msg_target: message class targets; assumed ``[B, T, 7, 7]`` int64 -- TODO confirm.
        msg_hat: message logits ``[B, T, C, 7, 7]`` from the communication module.
        proj_h: projected message embeddings of the sender.
        proj_z: projected observation embeddings of the sender.
        mask_t: ``[T, B]`` validity mask (1.0 where the step is not done).

    Returns:
        dict mapping loss name -> scalar tensor; weighting happens in the caller.
    """
    # ---- Receiver losses ----
    flat_encodings = flatten_conv_output(z0)  # [T, B, c, h, w] => [T, B, D]
    sigreg_img = self.sigreg(flat_encodings, global_step=global_step, across_dim=(0, 1), distributed=self.cfg.distributed)
    sigreg_time = self.sigreg(flat_encodings, global_step=global_step, across_dim=0, distributed=self.cfg.distributed)

    # A transition t-1 -> t only counts when both endpoints are valid.
    transition_mask = mask_t[1:] * mask_t[:-1]
    diff = (z0[1:] - z[1:]).pow(2).mean(dim=(2, 3, 4))  # (T-1, B)
    sim_loss = (diff * transition_mask).sum() / transition_mask.sum().clamp(min=1)

    idm_loss = self.idm(embeddings=z0, predictions=z, actions=actions)

    # ---- Sender losses ----
    sigreg_msg = self.sigreg(proj_h, global_step=global_step, across_dim=(0, 1), distributed=self.cfg.distributed)
    sigreg_obs = self.sigreg(proj_z, global_step=global_step, across_dim=(0, 1), distributed=self.cfg.distributed)

    inv_loss_sender = (proj_z - proj_h).square().mean()

    # ``WMTrainer`` never defines a ``cross_entropy`` attribute, so the former
    # ``self.cross_entropy(...)`` call raised AttributeError. Use the
    # functional loss unless a custom one has been attached to the instance.
    cross_entropy = getattr(self, "cross_entropy", F.cross_entropy)
    # msg_hat: [B*T, C, 7, 7] logits, msg_target: [B*T, 7, 7] class indices.
    msg_pred_loss = cross_entropy(msg_hat.flatten(0, 1), msg_target.flatten(0, 1))

    return {
        'sigreg_img': sigreg_img,
        'sigreg_msg': sigreg_msg,
        'sigreg_obs': sigreg_obs,
        'sigreg_time': sigreg_time,
        'sim_loss_dynamics': sim_loss,
        'inv_loss_sender': inv_loss_sender,
        'msg_pred_loss': msg_pred_loss,
        'idm_loss': idm_loss,
    }
    

In [None]:
#| export
@patch
def get_sampling_prob(self: WMTrainer, epoch):
    """Scheduled-sampling probability of feeding *predicted* messages at ``epoch``.

    Returns 0.0 before ``schedule_start_epoch`` (ground-truth messages only),
    1.0 from ``schedule_end_epoch`` onward (predicted messages only), and a
    linear ramp between the two.
    """
    start = self.schedule_start_epoch
    end = self.schedule_end_epoch

    if epoch < start:
        return 0.0
    if epoch >= end:
        return 1.0
    # Fraction of the ramp already covered.
    return (epoch - start) / (end - start)

In [None]:
#| export
@patch
def sender_jepa(self: WMTrainer, data, sampling_prob, step):
    """Run the sender branch and pick the message embedding for the receiver.

    Args:
        data: iterable of 7 per-agent tensors; only the first four
            (obs, pos, msg, msg_target) are used here.
        sampling_prob: probability of replacing the ground-truth message with
            a straight-through one-hot sample of the predicted message.
        step: integer seed for the sampling decision; the same ``step`` always
            yields the same decision (so ranks passing equal steps agree).

    Returns:
        ``(msg_hat, msg_target, h_for_receiver, proj_z, proj_h)``.
    """
    obs_sender, pos_sender, msg, msg_target, _, _, _ = data
    obs_sender = obs_sender.to(self.device)
    pos_sender = pos_sender.to(self.device)
    msg = msg.to(self.device)
    msg_target = msg_target.to(self.device)

    # Dedicated, step-seeded CPU generator: the decision does not consume
    # from (or depend on) the global RNG state.
    g = torch.Generator().manual_seed(step)
    decision_rand = torch.rand(1, generator=g).item()

    self.logger.info(f"device used for other agent data: {obs_sender.device}, {msg.device}")

    z = self.obs_enc(obs_sender, position=pos_sender)  # [B, T, c, h, w] => [B, T, c', h', w']
    h_target = self.msg_enc(msg)  # [B, T, C, H, W] => [B, T, dim]
    proj_z, proj_h = self.proj(z, h_target)  # JEPA alignment between obs and msg

    msg_hat = self.comm_module(z.detach())  # [B, T, c', h', w'] => [B, T, C, H, W] logits
    if decision_rand < sampling_prob:
        num_classes = msg_hat.size(2)
        sample = F.one_hot(msg_hat.argmax(dim=2), num_classes=num_classes)  # [B, T, H, W, C]
        sample = rearrange(sample, 'b t h w c -> b t c h w')  # [B, T, C, H, W]
        probs = F.softmax(msg_hat, dim=2)
        # Straight-through estimator: forward value is the one-hot sample,
        # gradients flow through the softmax probabilities.
        msg_used = sample + probs - probs.detach()
        h_for_receiver = self.msg_enc(msg_used.to(probs.dtype))
    else:
        # BUG FIX: this assignment used to sit after the if/else, so it
        # always overwrote the sampled embedding and scheduled sampling
        # never reached the receiver.
        h_for_receiver = h_target

    return msg_hat, msg_target, h_for_receiver, proj_z, proj_h

In [None]:
#| export
@patch
def rec_jepa(self: WMTrainer, data, h):
    """Roll the receiver world model forward, conditioned on message embedding ``h``.

    Args:
        data: iterable of 7 per-agent tensors; obs, pos, act and dones are used.
        h: message embedding from the sender branch, passed to the model as ``msgs``.

    Returns:
        ``(z0, z, act, mask_t, mask, len(obs))``, where ``mask`` is ``[B, T]``
        (1.0 while not done) and ``mask_t`` is its ``[T, B]`` transpose; or
        ``None`` when the mask is all zeros (nothing to train on).
    """
    obs, pos, _, _, act, _, dones = data

    # Validity mask: 1.0 for every step that is not flagged ``done``.
    not_done = ~dones.bool()
    mask = not_done.float().to(self.device).clone()  # [B, T, 1]
    mask = rearrange(mask, 'b t d-> b (t d)', d=1)    # [B, T]
    mask_t = rearrange(mask, 'b t -> t b')            # [T, B]

    # Every step is already done: skip the forward pass entirely.
    if mask.sum() == 0:
        return None

    obs, pos, act = (tensor.to(self.device) for tensor in (obs, pos, act))

    self.logger.info(f"device used for main agent data: {mask_t.device} {obs.device}, {act.device}")

    horizon = act.size(1) - 1
    z0, z = self.jepa(x=obs, pos=pos, actions=act, msgs=h, T=horizon)

    return z0, z, act, mask_t, mask, len(obs)

In [None]:
#| export
@patch
def train_epoch(self: WMTrainer, epoch):
    """Run one epoch over ``self.train_loader`` and return the mean pair loss.

    For each batch, every ordered (receiver, sender) pair of distinct agents
    is processed:
    1. The sender branch yields a message embedding.
    2. The receiver world model is rolled out with that embedding.
    3. The weighted pair loss, scaled by ``1 / num_pairs``, is back-propagated.

    Gradients of all pairs accumulate into a single optimizer step per batch.

    Args:
        epoch: epoch index as passed by ``fit`` (1-based there).

    Returns:
        float: ``sum(pair_loss * num_valid_steps) / total_valid_steps`` over
        the epoch, or 0.0 if no valid step was seen.
    """
    self.logger.info(f"Inside train_epoch: Starting epoch {epoch}")
    total_running_loss = 0.0
    total_valid_steps = 0
    num_pairs = len(self.agents) * (len(self.agents) - 1)
    # Batch size of the last processed pair, for the progress log. Initialised
    # so that log cannot hit an unbound name when every pair gets skipped.
    len_obs = 0

    if self.sampler:
        self.logger.info(f"Setting epoch {epoch} for sampler")
        if epoch > 0:
            self.sampler.set_epoch(epoch)

    sampling_prob = self.get_sampling_prob(epoch)

    self.logger.info(f"Sampling probability for epoch {epoch}: {sampling_prob:.4f}")
    for batch_idx, data in enumerate(self.train_loader):
        self.logger.info(f"Starting batch {batch_idx} of epoch {epoch}")
        # NOTE(review): ``fit`` passes 1-based epochs, so the first global step
        # is ``len(self.train_loader)`` rather than 0 -- confirm the scheduler
        # and SIGReg expect this offset.
        global_step = epoch * len(self.train_loader) + batch_idx
        self.scheduler.adjust_learning_rate(global_step)

        if self.verbose and epoch == 1 and batch_idx == 0:
            self.logger.info(f"\n=== LR DIAGNOSTIC ===")
            self.logger.info(f"Epoch: {epoch}")
            self.logger.info(f"Batch idx: {batch_idx}")
            self.logger.info(f"Global step passed to scheduler: {global_step}")
            self.logger.info(f"Total batches per epoch: {len(self.train_loader)}")
            self.logger.info(f"Total epochs: {self.cfg.epochs}")
            self.logger.info(f"Batch size: {self.scheduler.batch_size}")
            self.logger.info(f"Base LR: {self.scheduler.base_lr}")

        self.optimizer.zero_grad()

        batch_log_accumulator = {}
        if self.verbose and batch_idx == 0:
            obs, _, msg, _, _, _, _ = data[self.agents[0]].values()
            self.logger.info(f"Input data stats:")
            self.logger.info(f"  obs: min={obs.min():.3f}, max={obs.max():.3f}, mean={obs.mean():.3f}, std={obs.std():.3f}")
            self.logger.info(f"  msg: min={msg.min():.3f}, max={msg.max():.3f}, mean={msg.mean():.3f}")

        for rec in self.agents:
            for sender in self.agents:
                if sender == rec:
                    continue

                # Seed the scheduled-sampling decision with the global step.
                # The former ``epoch + batch_idx`` seed repeated across epochs
                # (e.g. epoch 1/batch 1 == epoch 2/batch 0), so every epoch
                # replayed a shifted copy of the previous epoch's decisions.
                msg_hat, msg_target, h, proj_z, proj_h = self.sender_jepa(
                    data[sender].values(), sampling_prob, global_step
                )

                rec_out = self.rec_jepa(data[rec].values(), h)
                if rec_out is None:
                    # rec_jepa returns None when every step is done; unpacking
                    # that into six names used to raise a TypeError.
                    self.logger.info(f"Skipping pair {sender}->{rec}: no valid steps in batch {batch_idx}")
                    continue
                z0, z, act, mask_t, mask, len_obs = rec_out

                if self.verbose:
                    if torch.isnan(z0).any():
                        self.logger.error(f"NaN detected in z0!")
                    if torch.isnan(z).any():
                        self.logger.error(f"NaN detected in z!")

                losses = self.criterion(global_step, z0, z, act, msg_target, msg_hat, proj_h, proj_z, mask_t)

                # Sender JEPA: isotropy of both projections + obs/msg alignment.
                s_jepa = (self.cfg.loss.sigreg.msg * (losses['sigreg_obs'] + losses['sigreg_msg'])
                          + self.cfg.loss.inv_loss.coeff * losses['inv_loss_sender'])

                # Receiver JEPA: isotropy of encodings + latent prediction error.
                r_jepa = self.cfg.loss.sigreg.img * losses['sigreg_img'] + losses['sim_loss_dynamics']

                collapse_loss = (self.cfg.loss.msg_pred.coeff * losses['msg_pred_loss']
                                 + self.cfg.loss.sigreg.time * losses['sigreg_time']
                                 + self.cfg.loss.idm.coeff * losses['idm_loss'])

                pair_loss = s_jepa + r_jepa + collapse_loss

                if self.verbose and batch_idx % 5 == 0:
                    self.logger.info(f"\nBatch {batch_idx}, Pair {sender}->{rec}:")
                    self.logger.info(f"  Individual losses:")
                    for k, v in losses.items():
                        self.logger.info(f"    {k}: {v.item():.4f}")
                    self.logger.info(f"  Combined losses:")
                    self.logger.info(f"    s_jepa: {s_jepa.item():.4f}")
                    self.logger.info(f"    r_jepa: {r_jepa.item():.4f}")
                    self.logger.info(f"    collapse_loss: {collapse_loss.item():.4f}")
                    self.logger.info(f"    TOTAL pair_loss: {pair_loss.item():.4f}")

                # Average gradients over pairs; they accumulate until the
                # single optimizer step below.
                scaled_loss = pair_loss / num_pairs
                scaled_loss.backward()

                num_valid = mask.sum().item()
                total_running_loss += pair_loss.item() * num_valid
                total_valid_steps += num_valid

                if self.verbose:
                    for k, v in losses.items():
                        batch_log_accumulator[f'pair_{sender}_to_{rec}/{k}'] = v.item()

        self.optimizer.step()

        if self.verbose:
            self.writer.write(batch_log_accumulator)

            processed_samples = batch_idx * len_obs
            self.logger.info(f'Train Epoch: {epoch} [{processed_samples}/{len(self.train_loader.dataset)} '
                             f'({100. * batch_idx / len(self.train_loader):.0f}%)]\t')

    final_epoch_loss = (total_running_loss / total_valid_steps) if total_valid_steps > 0 else 0.0
    return final_epoch_loss

In [None]:
#| export
import wandb
CHECKPOINT_FREQ = 1  # save ``latest.pt`` every CHECKPOINT_FREQ epochs


@patch
def fit(self: WMTrainer):
    """Train for ``cfg.epochs`` epochs (1-based), checkpointing and logging.

    Each epoch:
    - calls ``train_epoch``;
    - overwrites ``<dmpc_dir>/latest.pt``;
    - every ``cfg.save_every_freq`` epochs (when > 0), also writes ``e{epoch}.pt``.

    Checkpoints and writer output are only produced when ``self.verbose`` is
    truthy.

    Returns:
        pandas.DataFrame with columns ``epoch`` and ``train_loss``, one row per epoch.
    """
    for module in (self.jepa, self.obs_enc, self.msg_enc, self.comm_module,
                   self.proj, self.idm.action_predictor):
        module.train()

    latest_path = os.path.join(self.dmpc_dir, "latest.pt")

    loss_meter = AverageMeter()
    lst_dfs = []

    def get_state(m):
        # Unwrap DDP-style wrappers (which expose ``.module``) before saving.
        return m.module.state_dict() if hasattr(m, 'module') else m.state_dict()

    def write_checkpoint(epoch, path, train_loss):
        # Named so it does not shadow the imported ``save_checkpoint`` helper.
        if not self.verbose:
            return

        save_dict = {
            'epoch': epoch,
            'jepa': get_state(self.jepa),
            'obs_enc': get_state(self.obs_enc),
            'msg_enc': get_state(self.msg_enc),
            'comm_module': get_state(self.comm_module),
            'proj': get_state(self.proj),
            # The IDM head is trained (its params are in the optimizer), so
            # its weights must be saved too or a resume would lose them.
            'idm': get_state(self.idm.action_predictor),
            'train_loss': train_loss,
            'optimizer': self.optimizer.state_dict(),
        }

        try:
            torch.save(save_dict, path)
            self.logger.info(f"Successfully saved checkpoint to {path}")
        except Exception as e:
            # Best-effort save: report the failure loudly but keep training.
            self.logger.error(f"Encountered exception when saving checkpoint: {e}")

    for epoch in range(1, self.cfg.epochs + 1):
        self.logger.info("Epoch %d" % (epoch))
        train_loss = self.train_epoch(epoch)
        loss_meter.update(train_loss)

        self.logger.info("avg. loss %.3f" % loss_meter.avg)

        df = pd.DataFrame.from_records([{"epoch": epoch, "train_loss": train_loss}], index="epoch")
        lst_dfs.append(df)

        # ``epoch + 1`` is stored, presumably as the epoch to resume from.
        # The last epoch is ``self.cfg.epochs`` (1-based loop), not ``epochs - 1``.
        if epoch % CHECKPOINT_FREQ == 0 or epoch == self.cfg.epochs:
            write_checkpoint(epoch + 1, latest_path, train_loss)
            if self.cfg.save_every_freq > 0 and epoch % self.cfg.save_every_freq == 0:
                save_every_path = os.path.join(self.dmpc_dir, f"e{epoch}.pt")
                write_checkpoint(epoch + 1, save_every_path, train_loss)

        if self.verbose:
            self.writer.write({"train_loss": train_loss})

    # pd.concat raises on an empty list (cfg.epochs == 0).
    if lst_dfs:
        df_reset = pd.concat(lst_dfs).reset_index()
    else:
        df_reset = pd.DataFrame(columns=["epoch", "train_loss"])

    if self.verbose:
        self.writer.write({'Train Loss Table': wandb.Table(dataframe=df_reset)})
        self.writer.finish()
    return df_reset

In [None]:
#| hide
# Export every `#| export` cell of this notebook into the library module.
import nbdev
nbdev.nbdev_export()