# Models utilities module

> This module handles all aspects of the world model, including state representation, environment dynamics, and prediction.

In [None]:
#| default_exp models.__init__

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore import *
from fastcore.utils import *
import torch

In [None]:
#| export
from functools import reduce
from torch.nn.parallel import DistributedDataParallel
from mawm.models.jepa import JEPA
from mawm.models.comm import MSGEnc, CommModule
from mawm.models.misc import JepaProjector
import logging

# Send log records to stdout at INFO level so they show up in notebook output.
# NOTE(review): `sys` is not imported explicitly in the visible cells; presumably
# it arrives through the `fastcore` star imports -- confirm, or add `import sys`.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
# No name is passed, so this is the root logger (records are tagged "root").
logger = logging.getLogger()

def _count_parameters(module):
    """Return the number of trainable (``requires_grad``) elements in *module*."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


def _ddp_device_ids(device):
    """Return the ``device_ids`` value to hand to DistributedDataParallel.

    DDP requires ``device_ids=None`` for CPU modules; for any other device the
    single-element list ``[device]`` is returned.
    """
    if isinstance(device, int):
        # A bare integer is an accelerator index, never the CPU.
        return [device]
    if torch.device(device).type == "cpu":
        return None
    return [device]


def init_models(
    cfg,
    device,
    distributed: bool = True
) -> dict:
    """Build the per-agent model bundle described by *cfg*.

    For every agent in ``cfg.env.agents`` a JEPA, a CommModule, an MSGEnc and a
    JepaProjector are created, moved to *device*, and their trainable-parameter
    counts are logged.

    Args:
        cfg: Configuration object; this function reads ``cfg.env.agents``,
            ``cfg.model`` (passed to ``JEPA``), ``cfg.model.channels``,
            ``cfg.model.img_size`` and ``cfg.model.predictor.action_dim``.
        device: Target device for all modules (anything ``Module.to`` accepts).
        distributed: When true, each module is wrapped in
            ``DistributedDataParallel`` with ``find_unused_parameters=True``.

    Returns:
        ``{agent: {"jepa": ..., "comm_module": ..., "msg_enc": ..., "proj": ...}}``.
        The dict is empty when ``cfg.env.agents`` is empty.
    """
    # NOTE(review): these sizes are hard-coded rather than read from cfg --
    # confirm they match the JEPA latent / message vocabulary in use.
    latent_channels = 32
    num_primitives = 5

    # Resolved once: DDP rejects an explicit device list for CPU modules.
    ddp_device_ids = _ddp_device_ids(device) if distributed else None

    model = {}
    total_params = {}
    for agent in cfg.env.agents:
        jepa = JEPA(
            cfg.model,
            input_dim=(cfg.model.channels, cfg.model.img_size, cfg.model.img_size),
            action_dim=cfg.model.predictor.action_dim,
        )
        comm_module = CommModule(input_channel=latent_channels, num_primitives=num_primitives)
        msg_enc = MSGEnc(num_primitives=num_primitives, latent_dim=latent_channels)
        proj = JepaProjector(z_channels=latent_channels, c_input_dim=msg_enc.latent_dim)

        # Insertion order here is the key order of the returned per-agent dict.
        modules = {
            "jepa": jepa,
            "comm_module": comm_module,
            "msg_enc": msg_enc,
            "proj": proj,
        }
        for module in modules.values():
            module.to(device)

        # Count each module once and reuse the numbers for logging and totals.
        counts = {name: _count_parameters(module) for name, module in modules.items()}

        logger.info("Agent: %s", agent)
        logger.info("JEPA Parameters: %d", counts["jepa"])
        logger.info("CommModule Parameters: %d", counts["comm_module"])
        logger.info("MSgEncoder Parameters: %d", counts["msg_enc"])
        logger.info("Projector Parameters: %d", counts["proj"])

        logger.info("-"*50)
        total_params[agent] = sum(counts.values())

        if distributed:
            model[agent] = {
                name: DistributedDataParallel(
                    module,
                    device_ids=ddp_device_ids,
                    find_unused_parameters=True,
                )
                for name, module in modules.items()
            }
        else:
            model[agent] = modules

    grand_total = sum(total_params.values())
    logger.info("Total Parameters: %d", grand_total)

    # Skip the average when there are no agents instead of dividing by zero.
    num_agents = len(cfg.env.agents)
    if num_agents:
        logger.info("Parameters per Agent: %s", grand_total / num_agents)

    return model

In [None]:
#| hide
from omegaconf import OmegaConf
# Notebook-only smoke test: load the MPC JEPA config from a path relative to
# the notebook's working directory.
cfg = OmegaConf.load("../cfgs/MPCJepa/mpc.yaml")

# Bare attribute access; it is not the cell's last expression, so nothing is displayed.
cfg.model.backbone
# Build the models on CPU without DDP wrapping; the parameter counts are logged below.
models = init_models(cfg, "cpu", distributed= False)

INFO:root:Agent: agent_0
INFO:root:JEPA Parameters: 98560
INFO:root:CommModule Parameters: 56005
INFO:root:MSgEncoder Parameters: 32608
INFO:root:Projector Parameters: 2241536
INFO:root:--------------------------------------------------
INFO:root:Agent: agent_1
INFO:root:JEPA Parameters: 98560
INFO:root:CommModule Parameters: 56005
INFO:root:MSgEncoder Parameters: 32608
INFO:root:Projector Parameters: 2241536
INFO:root:--------------------------------------------------
INFO:root:Total Parameters: 4857418
INFO:root:Parameters per Agent: 2428709.0


In [None]:
#| hide
# Show the agent names and the per-agent module names returned by init_models.
models.keys(), models['agent_0'].keys()

(dict_keys(['agent_0', 'agent_1']),
 dict_keys(['jepa', 'comm_module', 'msg_enc', 'proj']))

In [None]:
#| hide
# Run nbdev's export step for this project (writes the `#| export` cells to the library).
import nbdev; nbdev.nbdev_export()