# World Model trainer

> This module implements the LeJEPA training procedure with three predictors and two input modalities.

In [None]:
#| default_exp trainers.findgoal_trainer

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore import *
from fastcore.utils import *
from torchvision.utils import save_image
import torch
import os
from torch import nn
import torch.nn.functional as F
import pandas as pd

## WorldModel Trainer

In [None]:
#| export
from mawm.models.utils import save_checkpoint
from mawm.loggers.base import AverageMeter
from mawm.losses.sigreg import SIGRegFunctional
from mawm.losses.idm import IDMLoss    
from mawm.models.utils import flatten_conv_output
from einops import rearrange

class WMTrainer:
    """Trainer for the multi-agent world model (sender branch + receiver JEPA).

    Only construction lives here; the training logic (`criterion`,
    `get_sampling_prob`, `sender_jepa`, `rec_jepa`, `train_epoch`, `fit`) is
    attached to this class with fastcore's ``@patch`` in later cells.

    Args:
        cfg: Experiment config. Read here: ``cfg.loss.idm``, ``cfg.loss.lambda_``,
            ``cfg.loss.W_H_PRED``, ``cfg.loss.W_SIM_T``, ``cfg.env.agents``,
            ``cfg.log_dir``, ``cfg.log_subdir`` and ``cfg.now``.
        model: Nested mapping holding the receiver JEPA at ``model['rec']['jepa']``
            and the sender modules under ``model['send']`` (``obs_enc``,
            ``msg_enc``, ``proj``, ``comm_module``).
        train_loader: Iterable of training batches.
        sampler: Sampler backing ``train_loader``; stored for the training loop.
        optimizer, earlystopping, scheduler: Stored as-is for the training loop.
        device: Device the loss modules are moved to.
        writer, verbose, logger: Metric writer, logging switch and logger.
        schedule_start_epoch: First epoch at which the sender's predicted messages
            may replace the ground-truth ones (scheduled sampling).
        schedule_end_epoch: Epoch from which predicted messages are always used.
    """

    def __init__(self, cfg, model, train_loader, sampler,
                 optimizer=None, device=None, earlystopping=None,
                 scheduler=None, writer=None, verbose=None, logger=None,
                 *, schedule_start_epoch=5, schedule_end_epoch=20):

        self.cfg = cfg
        self.device = device

        self.train_loader = train_loader
        self.sampler = sampler

        # Receiver-side JEPA and sender-side modules.
        self.jepa = model['rec']['jepa']
        self.obs_enc = model["send"]["obs_enc"]
        self.msg_enc = model["send"]["msg_enc"]
        self.proj = model["send"]["proj"]
        self.comm_module = model["send"]["comm_module"]

        self.optimizer = optimizer
        self.earlystopping = earlystopping
        self.scheduler = scheduler

        self.writer = writer
        self.verbose = verbose
        self.logger = logger

        self.sigreg = SIGRegFunctional().to(self.device)
        self.cross_entropy = nn.CrossEntropyLoss()
        # NOTE(review): (32, 15, 15) is a hard-coded latent shape, presumably the
        # obs encoder's (C, H, W) output — TODO confirm / derive from cfg.
        self.idm = IDMLoss(cfg.loss.idm, (32, 15, 15), device=self.device)

        # Loss weights.
        self.lambda_ = self.cfg.loss.lambda_
        self.W_H_PRED = self.cfg.loss.W_H_PRED
        self.W_SIM_T = self.cfg.loss.W_SIM_T

        # Scheduled-sampling window (defaults: start mixing at epoch 5, fully
        # use predictions by epoch 20).
        self.schedule_start_epoch = schedule_start_epoch
        self.schedule_end_epoch = schedule_end_epoch

        # Agent keys "agent_0" ... "agent_{n-1}"; the training loop indexes batches with them.
        self.agents = [f"agent_{i}" for i in range(len(self.cfg.env.agents))]

        # Output directory for checkpoints; exist_ok makes a prior existence check unnecessary.
        self.dmpc_dir = os.path.join(self.cfg.log_dir, self.cfg.log_subdir, self.cfg.now)
        os.makedirs(self.dmpc_dir, exist_ok=True)

    

In [None]:
#| export
def _masked_mean(values, mask):
    """Mean of `values` over the entries where `mask` is set.

    `values` and `mask` must have broadcast-compatible shapes. The denominator is
    clamped to at least 1, so an all-zero mask yields 0 instead of NaN.
    """
    return (values * mask).sum() / mask.sum().clamp_min(1)


@patch
def criterion(self: WMTrainer, global_step, z0, z, actions, msg_target, msg_hat, proj_h, proj_z, mask_t):
    """Compute every (unweighted) loss term for one sender -> receiver pair.

    Args:
        global_step: Training step, forwarded to the SIGReg loss.
        z0: Receiver encoder latents, time-major, e.g. [T, B, c`, h`, w`].
        z: Receiver predicted latents, same layout as `z0`.
        actions: Receiver actions, forwarded to the inverse-dynamics loss.
        msg_target: Integer message targets, e.g. [B, T, 7, 7] (long dtype).
        msg_hat: Message logits with the class dim at position 2, e.g. [B, T, 5, 7, 7].
        proj_h: Projected message embedding, [T, B, d].
        proj_z: Projected observation embedding, [T, B, d].
        mask_t: Float validity mask, [T, B] (1 = step not done).

    Returns:
        dict of loss tensors keyed by 'sigreg_img', 'sigreg_msg', 'sigreg_obs',
        'sim_loss_dynamics', 'sim_loss_t', 'inv_loss_sender', 'msg_pred_loss'
        and 'idm_loss'.
    """
    # SIGReg on the flattened image embeddings of every valid step.
    flat_encodings = flatten_conv_output(z0)  # [T, B, c`, h`, w`] => [T, B, D]
    sigreg_img = self.sigreg(flat_encodings, global_step=global_step, mask=mask_t)

    # A transition t -> t+1 counts only if both endpoints are valid.
    transition_mask = mask_t[1:] * mask_t[:-1]  # (T-1, B)

    # Dynamics error: predicted vs. encoded latents for steps 1..T-1.
    diff = (z0[1:] - z[1:]).pow(2).mean(dim=(2, 3, 4))  # (T-1, B)
    sim_loss = _masked_mean(diff, transition_mask)

    # Optional temporal smoothness between consecutive encoder latents.
    if self.cfg.loss.vicreg.sim_coeff_t:
        diff_t = (z0[1:] - z0[:-1]).pow(2).mean(dim=(2, 3, 4))  # (T-1, B)
        sim_loss_t = _masked_mean(diff_t, transition_mask)
    else:
        # 0-dim, like the other losses, so the summed loss stays a scalar
        # instead of broadcasting to shape [1].
        sim_loss_t = torch.zeros((), device=self.device)

    idm_loss = self.idm(embeddings=z0, predictions=z, actions=actions)

    # SENDER LOSSES
    # NOTE(review): unlike the other terms, this is not masked by `mask_t`, so
    # done steps also contribute — confirm that is intended.
    msg_pred_loss = self.cross_entropy(msg_hat.flatten(0, 1), msg_target.flatten(0, 1))

    sigreg_msg = self.sigreg(proj_h, global_step=global_step, mask=mask_t)
    sigreg_obs = self.sigreg(proj_z, global_step=global_step, mask=mask_t)

    # Sender invariance between observation and message projections.
    inv_loss_sender = (proj_z - proj_h).square().mean(dim=-1)  # [T, B, d] => [T, B]
    inv_loss_sender = _masked_mean(inv_loss_sender, mask_t)

    return {
        'sigreg_img': sigreg_img,
        'sigreg_msg': sigreg_msg,
        'sigreg_obs': sigreg_obs,
        'sim_loss_dynamics': sim_loss,
        'sim_loss_t': sim_loss_t,
        'inv_loss_sender': inv_loss_sender,
        'msg_pred_loss': msg_pred_loss,
        'idm_loss': idm_loss
    }
    

In [None]:
# #| export
# @patch
# def criterion(
#     self: WMTrainer,
#     global_step,
#     z0,            # encoded latents [T, B, C, H, W] (stop-grad)
#     z_hat,         # predicted latents [T, B, C, H, W]
#     msg_target,
#     msg_hat,
#     proj_h,
#     proj_z,
#     mask_t,
#     actions=None,          # [T-1, B, A]  (for action-sep)
#     anchor_target=None     # optional (dx, dy) or moved flag
# ):
#     """
#     Combined JEPA loss with:
#     - delta dynamics (B)
#     - action separation (A)
#     - weak control anchor (C)
#     """

#     losses = {}

#     # ---------------------------------------------------------
#     # 1. SIGReg on image / message / obs (UNCHANGED)
#     # ---------------------------------------------------------
#     flat_encodings = flatten_conv_output(z0)  # [T, B, d]
#     losses['sigreg_img'] = self.disSigReg(flat_encodings[:1], global_step)
#     losses['sigreg_msg'] = self.disSigReg(proj_h, global_step)
#     losses['sigreg_obs'] = self.disSigReg(proj_z, global_step)

#     # ---------------------------------------------------------
#     # 2. Transition mask
#     # ---------------------------------------------------------
#     transition_mask = mask_t[1:] * mask_t[:-1]   # [T-1, B]

#     # ---------------------------------------------------------
#     # 3. DELTA DYNAMICS LOSS (Option B)
#     # ---------------------------------------------------------
#     delta_hat = z_hat[1:] - z_hat[:-1]
#     delta_true = z0[1:] - z0[:-1]

#     diff_delta = (delta_hat - delta_true).pow(2).mean(dim=(2, 3, 4))
#     losses['sim_loss_dynamics'] = (
#         (diff_delta * transition_mask).sum()
#         / transition_mask.sum().clamp_min(1)
#     )

#     # ---------------------------------------------------------
#     # 4. GATED TIME-SMOOTHNESS (important!)
#     # ---------------------------------------------------------
#     if self.cfg.loss.vicreg.sim_coeff_t:
#         # penalize only unexplained change
#         resid = (z0[1:] - z0[:-1]) - delta_hat.detach()
#         diff_t = resid.pow(2).mean(dim=(2, 3, 4))

#         losses['sim_loss_t'] = (
#             (diff_t * transition_mask).sum()
#             / transition_mask.sum().clamp_min(1)
#         )
#     else:
#         losses['sim_loss_t'] = torch.zeros(1, device=self.device)

#     # ---------------------------------------------------------
#     # 5. MESSAGE PREDICTION (UNCHANGED)
#     # ---------------------------------------------------------
#     losses['msg_pred_loss'] = self.cross_entropy(
#         msg_hat.flatten(0, 1),
#         msg_target.flatten(0, 1)
#     )

#     # ---------------------------------------------------------
#     # 6. SENDER / RECEIVER INVARIANCE (UNCHANGED)
#     # ---------------------------------------------------------
#     inv_loss_sender = (proj_z - proj_h).square().mean(dim=-1)
#     losses['inv_loss_sender'] = (
#         (inv_loss_sender * transition_mask).sum()
#         / transition_mask.sum().clamp_min(1)
#     )

#     # ---------------------------------------------------------
#     # # 7. ACTION SEPARATION LOSS (Option A)
#     # # ---------------------------------------------------------
#     # if actions is not None:
#     #     # sample a second action (shuffle within batch)
#     #     perm = torch.randperm(actions.shape[1])
#     #     actions_alt = actions[:, perm]

#     #     delta_a1 = self.predict_delta(z_hat[:-1], actions)
#     #     delta_a2 = self.predict_delta(z_hat[:-1], actions_alt)

#     #     act_sep = (delta_a1 - delta_a2).pow(2).mean(dim=(2, 3, 4))
#     #     losses['action_separation'] = -(
#     #         (act_sep * transition_mask).sum()
#     #         / transition_mask.sum().clamp_min(1)
#     #     )
#     # else:
#     #     losses['action_separation'] = torch.zeros(1, device=self.device)

#     # ---------------------------------------------------------
#     # 8. WEAK CONTROL ANCHOR (Option C)
#     # ---------------------------------------------------------
#     # if anchor_target is not None:
#     #     latents= flatten_conv_output(z0)  # [T, B, d]
#     #     latents = rearrange(latents, 't b d -> (t b) d')  # [(T*B), d]
#     #     anchor_pred = self.anchor_head(latents[:-1])
#     #     anchor_loss = self.anchor_loss(anchor_pred, anchor_target)

#     #     losses['anchor_loss'] = (
#     #         (anchor_loss * transition_mask).sum()
#     #         / transition_mask.sum().clamp_min(1)
#     #     )
#     # else:
#     #     losses['anchor_loss'] = torch.zeros(1, device=self.device)

#     return losses


In [None]:
# #| export
# @patch
# def setup_ddp_hooks(self, sender, rec):
#     # Map the symbols to the actual DDP-wrapped modules
#     modules_to_track = {
#         "phi_sender": self.get_module(self.model[sender].backbone),
#         "p_theta": self.get_module(self.comm_module),
#         "q_theta": self.get_module(self.model[sender].msg_enc),
#         "f_theta_rec": self.get_module(self.model[rec].jepa.dynamics) 
#     }

#     for name, m in modules_to_track.items():
#         m.register_full_backward_hook(self._hook_fn(name))

# def get_module(self, m):
#     """Helper to unwrap DDP module if necessary."""
#     return m.module if hasattr(m, 'module') else m

In [None]:
#| export
@patch
def get_sampling_prob(self: WMTrainer, epoch):
    """Probability of feeding the receiver the sender's predicted message.

    Scheduled sampling: 0.0 before ``self.schedule_start_epoch``, 1.0 from
    ``self.schedule_end_epoch`` onwards, and a linear ramp in between.
    """
    start = self.schedule_start_epoch
    end = self.schedule_end_epoch

    if epoch < start:
        return 0.0  # warm-up: always use ground-truth messages
    if epoch >= end:
        return 1.0  # always use predicted messages

    # Reaching here implies start <= epoch < end, so end - start > 0.
    return (epoch - start) / (end - start)

In [None]:
#| export
@patch
def sender_jepa(self: WMTrainer, data, sampling_prob):
    """Run the sender branch and build the message embedding for the receiver.

    Args:
        data: Iterable unpacking into 7 items for the sender agent; only the
            first four are used: observations, positions, message, message target.
        sampling_prob: Probability of giving the receiver an embedding of the
            sender's own (hard, straight-through) predicted message instead of the
            ground-truth message embedding.

    Returns:
        ``(msg_hat, msg_target, h_for_receiver, proj_z, proj_h)``: message logits,
        message target (on ``self.device``), message embedding for the receiver,
        and the projected observation / message embeddings.
    """
    obs_sender, pos_sender, msg, msg_target, _, _, _ = data
    obs_sender = obs_sender.to(self.device)
    pos_sender = pos_sender.to(self.device)
    msg = msg.to(self.device)
    msg_target = msg_target.to(self.device)

    self.logger.info(f"device used for other agent data: {obs_sender.device}, {msg.device}")

    z = self.obs_enc(obs_sender, position=pos_sender)  # [B, T, c, h, w] => [B, T, c`, h`, w`]
    h_target = self.msg_enc(msg)  # [B, T, C, H, W] => [B, T, dim=32]
    proj_z, proj_h = self.proj(z, h_target)  # True JEPA alignment

    msg_hat = self.comm_module(z)  # [B, T, c`, h`, w`] => [B, T, C=5, H=7, W=7]

    if torch.rand(1).item() < sampling_prob:
        # Hard one-hot of the prediction, with the number of classes taken from
        # the logits so it always matches `probs` below.
        num_classes = msg_hat.size(2)
        sample = F.one_hot(msg_hat.argmax(dim=2), num_classes=num_classes)  # [B, T, H, W, C]
        sample = rearrange(sample, 'b t h w c -> b t c h w')  # [B, T, C, H, W]
        probs = F.softmax(msg_hat, dim=2)  # [B, T, C, H, W]
        # Straight-through estimator: forward uses the one-hot, backward the softmax.
        msg_used = sample + probs - probs.detach()
        h_for_receiver = self.msg_enc(msg_used.to(probs.dtype))  # [B, T, C, H, W] => [B, T, dim=32]
    else:
        h_for_receiver = h_target  # ground-truth message embedding

    return msg_hat, msg_target, h_for_receiver, proj_z, proj_h

In [None]:
#| export
@patch
def rec_jepa(self: WMTrainer, data, h):
    """Run the receiver JEPA on one agent's batch, conditioned on message embedding `h`.

    Args:
        data: Iterable unpacking into 7 items for the receiver agent; uses
            observations (0), positions (1), actions (4) and done flags (6).
        h: Message embedding from the sender branch, passed to the JEPA as `msgs`.

    Returns:
        ``(z0, z, act, mask_t, mask, batch_size)`` where `mask` is the [B, T]
        not-done mask (1.0 = valid) and `mask_t` its [T, B] transpose; or
        ``None`` when no step in the batch is valid.
    """
    obs, pos, _, _, act, _, dones = data

    valid = (~dones.bool()).float().to(self.device)  # [B, T, d=1]
    mask = rearrange(valid, 'b t d-> b (t d)', d=1)  # [B, T]
    mask_t = rearrange(mask, 'b t -> t b')  # [T, B]

    # Every step is done: nothing to train on.
    if mask.sum() == 0:
        return None

    obs, pos, act = (tensor.to(self.device) for tensor in (obs, pos, act))

    self.logger.info(f"device used for main agent data: {mask_t.device} {obs.device}, {act.device}")

    horizon = act.size(1) - 1
    # [B, T, c, h, w] => [T, B, c, h, w]
    z0, z = self.jepa(x=obs, pos=pos, actions=act, msgs=h, T=horizon)

    batch_size = len(obs)
    return z0, z, act, mask_t, mask, batch_size

In [None]:
# #| export
# from mawm.models.utils import flatten_conv_output
# from einops import rearrange
# @patch
# def train_epoch(self: WMTrainer, epoch):
#     self.model.train()
#     self.msg_enc.train()
#     self.comm_module.train()
#     self.proj.train()
    
#     total_running_loss = 0.0
#     total_valid_steps = 0

#     self.logger.info(f"Device used: {self.device}")
#     self.sampler.set_epoch(epoch) if epoch > 0 else None

#     sampling_prob = self.get_sampling_prob(epoch)
#     for batch_idx, data in enumerate(self.train_loader):
        
#         global_step = epoch * len(self.train_loader) + batch_idx
#         self.optimizer.zero_grad()
#         batch_loss = 0

#         for rec in self.agents:
#             #### SENDER JEPA
#             for sender in self.agents:
#                 if sender == rec: continue
#                 msg_hat, msg_target, h, proj_z, proj_h = self.sender_jepa(data[sender].values(), sender, sampling_prob)
            
#             ### RECEIVER JEPA
#             z0, z, act, mask_t, mask, len_obs = self.rec_jepa(data[rec].values(), rec, h)
            
#             losses = self.criterion(global_step, z0, z, act, msg_target, msg_hat, proj_h, proj_z, mask_t)
            
#             if self.verbose:
#                 to_log = {
#                     f'{rec}/sigreg_img': losses['sigreg_img'].item(),
#                     f'{rec}/sim_loss': losses['sim_loss_dynamics'].item(),
#                     f'{rec}/sim_loss_t': losses['sim_loss_t'].item(),
#                     f'{rec}/idm_loss': losses['idm_loss'].item(),

#                     f'{sender}/sigreg_obs': losses['sigreg_obs'].item(),

#                     'shared/msg_pred_loss': losses['msg_pred_loss'].item(),
#                     'shared/sigreg_msg': losses['sigreg_msg'].item(),
#                     'shared/inv_loss_sender': losses['inv_loss_sender'].item(),
#                 }

#                 if self.verbose:
#                     self.writer.write(to_log)
            
#             self.logger.info("Losses: %s" % str({k: v.item() for k, v in losses.items()}))
            
#             s_jepa = self.lambda_ * (losses['sigreg_obs'] + losses['sigreg_msg']) + (1 - self.lambda_) * losses['inv_loss_sender']
#             r_jepa =  self.lambda_ * losses['sigreg_img'] + (1 - self.lambda_) * losses['sim_loss_dynamics']
#             task_loss = (self.W_H_PRED * losses['msg_pred_loss'] + 
#                          self.W_SIM_T * losses['sim_loss_t'] + 
#                          self.cfg.loss.idm.coeff * losses['idm_loss'])

#             self.logger.info(f"JEPA Losses: sender_jepa_loss: {s_jepa.item():.4f}, rec_jepa_loss: {r_jepa.item():.4f}, task_loss: {task_loss.item():.4f}")

#             pair_loss = s_jepa + r_jepa + task_loss
#             num_pairs = len(self.agents) * (len(self.agents) - 1)
#             scaled_loss = pair_loss / num_pairs #len(self.agents)
            
#             scaled_loss.backward()
#             self.logger.info(f"Agent: {rec}, agent_loss: {pair_loss.item():.4f}")
            
#             batch_loss += scaled_loss

#             num_valid = mask.sum().item()
#             total_running_loss += pair_loss.item() * num_valid
#             total_valid_steps += num_valid
            
#         loss = batch_loss
#         self.optimizer.step()

#         if batch_idx % 20 == 0:
#             self.logger.info(f'Train Epoch: {epoch} [{batch_idx * len_obs}/{len(self.train_loader.dataset)} '
#                   f'({100. * batch_idx / len(self.train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

#     final_epoch_loss = (total_running_loss / total_valid_steps) / len(self.agents) if total_valid_steps > 0 else 0.0
#     self.logger.info(f'====> Epoch: {epoch} Average loss: {final_epoch_loss:.4f}')

#     return final_epoch_loss
       

In [None]:
#| export
from mawm.models.utils import flatten_conv_output
from einops import rearrange
@patch
def train_epoch(self: WMTrainer, epoch):
    """Run one training epoch over every ordered (sender, receiver) agent pair.

    For each batch, every ordered pair of distinct agents is processed.
    `sender_jepa` produces the message embedding fed to `rec_jepa`, and
    `criterion` computes the loss terms. They are combined into
    ``pair_loss = s_jepa + r_jepa + task_loss``. Each pair loss is divided by the
    number of pairs before ``backward()``, so the gradient accumulated over a
    batch is that of the mean pair loss. The optimizer steps once per batch.
    Pairs whose receiver has no valid (not-done) step are skipped.

    Args:
        epoch: Current epoch. Used for the sampler's ``set_epoch``, the
            scheduled-sampling probability and the global step.

    Returns:
        Average pair loss over the epoch, weighted by the number of valid
        receiver steps; 0.0 if no valid step was seen.
    """
    total_running_loss = 0.0
    total_valid_steps = 0
    num_pairs = len(self.agents) * (len(self.agents) - 1)  # ordered pairs; 2 for two agents
    num_batches = len(self.train_loader)

    if self.sampler is not None and epoch > 0:
        self.sampler.set_epoch(epoch)
    sampling_prob = self.get_sampling_prob(epoch)

    for batch_idx, data in enumerate(self.train_loader):
        global_step = epoch * num_batches + batch_idx
        self.optimizer.zero_grad()

        batch_log_accumulator = {}
        batch_pair_losses = []  # unscaled pair losses of this batch, for logging
        len_obs = 0  # receiver batch size; stays 0 if every pair is skipped

        for rec in self.agents:
            for sender in self.agents:
                if sender == rec:
                    continue

                msg_hat, msg_target, h, proj_z, proj_h = self.sender_jepa(
                    data[sender].values(), sampling_prob
                )

                rec_out = self.rec_jepa(data[rec].values(), h)
                if rec_out is None:
                    # rec_jepa returns None when every receiver step is done.
                    # NOTE(review): under DDP, skipping backward on only some
                    # ranks can desynchronise gradient reduction — confirm.
                    self.logger.warning(
                        f"Skipping pair {sender} -> {rec} at batch {batch_idx}: "
                        f"no valid receiver steps"
                    )
                    continue
                z0, z, act, mask_t, mask, len_obs = rec_out

                losses = self.criterion(global_step, z0, z, act, msg_target, msg_hat, proj_h, proj_z, mask_t)
                self.logger.info("Losses: %s" % str({k: v.item() for k, v in losses.items()}))

                # lambda_ trades SIGReg regularisation against the invariance / prediction term.
                s_jepa = self.lambda_ * (losses['sigreg_obs'] + losses['sigreg_msg']) + (1 - self.lambda_) * losses['inv_loss_sender']
                r_jepa = self.lambda_ * losses['sigreg_img'] + (1 - self.lambda_) * losses['sim_loss_dynamics']
                task_loss = (self.W_H_PRED * losses['msg_pred_loss'] +
                             self.W_SIM_T * losses['sim_loss_t'] +
                             self.cfg.loss.idm.coeff * losses['idm_loss'])

                pair_loss = s_jepa + r_jepa + task_loss

                self.logger.info(f"JEPA Losses: sender_jepa_loss: {s_jepa.item():.4f}, rec_jepa_loss: {r_jepa.item():.4f}, task_loss: {task_loss.item():.4f}")

                # Scale so the batch gradient corresponds to the mean over pairs.
                (pair_loss / num_pairs).backward()

                pair_loss_value = pair_loss.item()
                batch_pair_losses.append(pair_loss_value)

                num_valid = mask.sum().item()
                total_running_loss += pair_loss_value * num_valid
                total_valid_steps += num_valid

                if self.verbose:
                    # NOTE(review): keys omit the sender, so with more than two
                    # agents only the last sender per receiver is logged.
                    for k, v in losses.items():
                        batch_log_accumulator[f'{rec}_as_rec/{k}'] = v.item()

        self.optimizer.step()

        if self.verbose:
            self.writer.write(batch_log_accumulator)

            avg_pair_loss = (sum(batch_pair_losses) / len(batch_pair_losses)
                             if batch_pair_losses else float('nan'))
            processed_samples = batch_idx * len_obs
            self.logger.info(f'Train Epoch: {epoch} [{processed_samples}/{len(self.train_loader.dataset)} '
                             f'({100. * batch_idx / num_batches:.0f}%)]\t'
                             f'Pair Avg Loss: {avg_pair_loss:.6f}')

    final_epoch_loss = (total_running_loss / total_valid_steps) if total_valid_steps > 0 else 0.0
    return final_epoch_loss

In [None]:
#| export
import wandb
CHECKPOINT_FREQ = 1  # write `latest.pt` every CHECKPOINT_FREQ epochs (and always on the final epoch)


@patch
def fit(self: WMTrainer):
    """Train for ``cfg.epochs`` epochs with checkpointing and loss logging.

    Puts all sub-modules in train mode, then for epochs ``1..cfg.epochs``:
    - adjusts the learning rate through the scheduler, if one was given;
    - runs `train_epoch`;
    - saves checkpoints under ``self.dmpc_dir``.

    ``latest.pt`` is written every `CHECKPOINT_FREQ` epochs and on the final
    epoch. ``e{epoch}.pt`` is also written every ``cfg.save_every_freq`` epochs
    when that value is > 0. Checkpoints and metrics are only written when
    ``self.verbose`` is truthy.

    Returns:
        DataFrame with columns ``epoch`` and ``train_loss`` (one row per epoch).
    """
    for module in (self.jepa, self.obs_enc, self.msg_enc, self.comm_module, self.proj):
        module.train()

    folder = self.dmpc_dir
    latest_path = os.path.join(folder, "latest.pt")

    def _state(m):
        # Unwrap a `.module` wrapper (e.g. DDP) so the state dict has bare keys.
        return m.module.state_dict() if hasattr(m, 'module') else m.state_dict()

    def _save(stored_epoch, path, train_loss, lr):
        # Best-effort checkpoint write; only the verbose process saves.
        if not self.verbose:
            return

        save_dict = {
            'epoch': stored_epoch,
            'jepa': _state(self.jepa),
            'obs_enc': _state(self.obs_enc),
            'msg_enc': _state(self.msg_enc),
            'comm_module': _state(self.comm_module),
            'proj': _state(self.proj),
            'train_loss': train_loss,
            'optimizer': self.optimizer.state_dict(),
            "lr": lr,
        }

        try:
            torch.save(save_dict, path)
            self.logger.info(f"Successfully saved checkpoint to {path}")
        except Exception as e:  # a failed save must not abort training
            self.logger.error(f"Encountered exception when saving checkpoint: {e}")

    loss_meter = AverageMeter()
    records = []

    for epoch in range(1, self.cfg.epochs + 1):
        self.logger.info("Epoch %d" % (epoch))
        lr = self.scheduler.adjust_learning_rate(epoch) if self.scheduler is not None else None
        train_loss = self.train_epoch(epoch)
        loss_meter.update(train_loss)

        self.logger.info("avg. loss %.3f" % loss_meter.avg)

        records.append({"epoch": epoch, "train_loss": train_loss})

        # Epochs run 1..cfg.epochs, so the final epoch is cfg.epochs.
        if epoch % CHECKPOINT_FREQ == 0 or epoch == self.cfg.epochs:
            # NOTE(review): epoch + 1 is stored, presumably as the epoch to resume from — confirm with the loader.
            _save(epoch + 1, latest_path, train_loss, lr)
            if self.cfg.save_every_freq > 0 and epoch % self.cfg.save_every_freq == 0:
                _save(epoch + 1, os.path.join(folder, f"e{epoch}.pt"), train_loss, lr)

        if self.verbose:
            self.writer.write({"train_loss": train_loss})

    # Built in one go so zero epochs yields an empty table instead of a concat error.
    df_reset = pd.DataFrame(records, columns=["epoch", "train_loss"])
    if self.verbose:
        self.writer.write({'Train Loss Table': wandb.Table(dataframe=df_reset)})
        self.writer.finish()
    return df_reset

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()