# Models utilities module

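> This module handles all aspects of the world model, including state representation, environment dynamics, and prediction. It exposes `init_models`, which builds the JEPA model together with the observation encoder, message encoder, projector, and communication module, and optionally wraps each of them in `DistributedDataParallel`.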

In [None]:
#| default_exp models.__init__

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore import *
from fastcore.utils import *
import torch

In [None]:
#| export
from functools import reduce
from torch.nn.parallel import DistributedDataParallel
from mawm.models.jepa import JEPA
from mawm.models.vision import MeNet6
from mawm.models.comm import MSGEnc, CommModule
from mawm.models.misc import JepaProjector
import logging

import sys  # explicit import: do not rely on `sys` leaking in through the fastcore star-imports

# Configure the root logger to write INFO-and-above records to stdout.
# NOTE: `basicConfig` does nothing if the root logger already has handlers.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
# Root logger (no name passed), shared by the functions defined in this module.
logger = logging.getLogger()

def init_models(
    cfg,
    device,
    distributed=True,
):
    """Build all sub-models, move them to `device`, and group them by role.

    Args:
        cfg: Configuration object. This function reads `cfg.model`,
            `cfg.model.channels`, `cfg.model.img_size`,
            `cfg.model.predictor.action_dim` and `cfg.model.backbone`.
        device: Target handed to each model's `.to(...)`. When `distributed`
            is true and the device is not a CPU device, it is also used as the
            single entry of `device_ids` for `DistributedDataParallel`.
        distributed: If true, every model is wrapped in
            `DistributedDataParallel` with `find_unused_parameters=True`.
            The wrapper requires an initialised process group.

    Returns:
        A nested dict: `model["rec"]` holds `"jepa"`; `model["send"]` holds
        `"obs_enc"`, `"msg_enc"`, `"proj"` and `"comm_module"`.
    """
    jepa = JEPA(
        cfg.model,
        input_dim=(cfg.model.channels, cfg.model.img_size, cfg.model.img_size),
        action_dim=cfg.model.predictor.action_dim,
    )
    jepa.to(device)

    # NOTE(review): num_primitives=5 and the 32-wide latent are hard-coded here
    # and repeated for the projector / comm module below — confirm they should
    # not come from cfg.
    msg_enc = MSGEnc(num_primitives=5, latent_dim=32)
    msg_enc.to(device)

    # NOTE(review): the observation encoder is built with 3 input channels while
    # JEPA uses cfg.model.channels — confirm this difference is intentional.
    obs_enc = MeNet6(cfg.model.backbone, (3, cfg.model.img_size, cfg.model.img_size))
    obs_enc.to(device)

    proj = JepaProjector(z_channels=32, c_input_dim=msg_enc.latent_dim)
    proj.to(device)

    comm_module = CommModule(input_channel=32, num_primitives=5)
    comm_module.to(device)

    def count_parameters(module):
        # Only trainable parameters are counted.
        return sum(p.numel() for p in module.parameters() if p.requires_grad)

    # Count each model once; the same numbers feed both the per-model log lines
    # and the total.
    n_jepa = count_parameters(jepa)
    n_comm = count_parameters(comm_module)
    n_msg = count_parameters(msg_enc)
    n_proj = count_parameters(proj)
    n_obs = count_parameters(obs_enc)

    logger.info("JEPA Parameters: %d", n_jepa)
    logger.info("CommModule Parameters: %d", n_comm)
    logger.info("MSgEncoder Parameters: %d", n_msg)
    logger.info("Projector Parameters: %d", n_proj)
    # obs_enc is part of the total, so report it too; otherwise the listed
    # counts do not add up to "Total Parameters".
    logger.info("ObsEncoder Parameters: %d", n_obs)

    logger.info("-" * 50)
    total_params = n_jepa + n_comm + n_msg + n_proj + n_obs

    if distributed:
        # DistributedDataParallel requires device_ids=None for CPU modules;
        # a one-element list is only valid for a single accelerator device.
        on_cpu = not isinstance(device, int) and torch.device(device).type == "cpu"
        device_ids = None if on_cpu else [device]

        def wrap(module):
            return DistributedDataParallel(
                module, device_ids=device_ids, find_unused_parameters=True
            )
    else:
        def wrap(module):
            return module

    model = {
        "rec": {
            "jepa": wrap(jepa),
        },
        "send": {
            "obs_enc": wrap(obs_enc),
            "msg_enc": wrap(msg_enc),
            "proj": wrap(proj),
            "comm_module": wrap(comm_module),
        },
    }

    logger.info("Total Parameters: %d", total_params)

    return model

In [None]:
#| hide
from omegaconf import OmegaConf
# Load the MPC-JEPA configuration from a path relative to this notebook.
cfg = OmegaConf.load("../cfgs/MPCJepa/mpc.yaml")

# Bare attribute access: nothing is displayed because it is not the cell's last
# expression. NOTE(review): presumably left over from inspecting the config.
cfg.model.backbone
# Build every model on CPU without DistributedDataParallel wrappers.
models = init_models(cfg, "cpu", distributed= False)

INFO:root:JEPA Parameters: 98560
INFO:root:CommModule Parameters: 56005
INFO:root:MSgEncoder Parameters: 32608
INFO:root:Projector Parameters: 2241536
INFO:root:--------------------------------------------------
INFO:root:Total Parameters: 2462245


In [None]:
#| hide
# Show which sub-models ended up in the "send" and "rec" groups.
models['send'].keys(), models['rec'].keys()

(dict_keys(['obs_enc', 'msg_enc', 'proj', 'comm_module']), dict_keys(['jepa']))

In [None]:
#| hide
# Export the `#| export` cells of the project's notebooks to Python modules.
import nbdev; nbdev.nbdev_export()