# Model utilities module

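> Utilities for the world model (state representation, environment dynamics, and prediction). Currently this module builds the JEPA model, the message encoder, and the message/observation predictors, moves them to a device, and logs their trainable-parameter counts.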

In [None]:
#| default_exp dist.model_utils

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore import *
from fastcore.utils import *
import torch

In [None]:
#| export
from mawm.models.jepa import JEPA
from mawm.models.vision import SemanticEncoder
from mawm.models.misc import MsgPred, ObsPred
import logging

# Imported explicitly instead of relying on `from fastcore.utils import *`
# happening to re-export `sys`.
import sys

# Route INFO-and-above records to stdout. `basicConfig` configures the *root*
# logger for the whole process, and (without `force=True`) it is a no-op if the
# root logger already has handlers.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

# NOTE(review): this is the process-wide root logger. `logging.getLogger(__name__)`
# is the usual choice for a library module, but switching would change the logger
# name that appears in emitted records, so it is kept as-is here.
logger = logging.getLogger()

def init_models_dist(
        cfg,
        device,
        *,
        input_dim=(3, 42, 42),
        num_primitives=5,
        latent_dim=32,
        h_dim=32,
):
    """Build the world-model components, move them to `device`, and log their sizes.

    Args:
        cfg: Experiment config. Only `cfg.model` (forwarded to `JEPA`) and
            `cfg.data.action_dim` (forwarded as `JEPA`'s ``action_dim``) are read here.
        device: Target passed to each model's ``.to(...)`` call.
        input_dim: Forwarded to `JEPA` as ``input_dim``. Defaults to
            ``(3, 42, 42)``, the value this function previously hard-coded.
        num_primitives: Forwarded to `SemanticEncoder` as ``num_primitives``
            (previously hard-coded to 5).
        latent_dim: Forwarded to `SemanticEncoder` as ``latent_dim``
            (previously hard-coded to 32).
        h_dim: Forwarded to both `MsgPred` and `ObsPred` as ``h_dim``
            (previously hard-coded to 32).

    Returns:
        Tuple ``(jepa, msg_enc, msg_pred, obs_pred)``.
    """
    # NOTE(review): `latent_dim` and `h_dim` share the default 32; if MsgPred /
    # ObsPred consume SemanticEncoder's latent, the two must stay equal — confirm.
    jepa = JEPA(cfg.model, input_dim=input_dim, action_dim=cfg.data.action_dim)
    msg_enc = SemanticEncoder(num_primitives=num_primitives, latent_dim=latent_dim)
    msg_pred = MsgPred(h_dim=h_dim)
    obs_pred = ObsPred(h_dim=h_dim)

    # (log label, model) pairs, in the order they are moved and reported.
    named_models = (
        ("JEPA", jepa),
        ("SemanticEncoder", msg_enc),
        ("MsgPred", msg_pred),
        ("ObsPred", obs_pred),
    )

    # Move all models before logging anything; `.to()`'s return value is not
    # used, so each model is assumed to be moved in place (as nn.Module does).
    for _, model in named_models:
        model.to(device)

    for label, model in named_models:
        # Only trainable parameters are counted; frozen ones are excluded.
        n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        logger.info("%s Parameters: %d", label, n_trainable)

    return jepa, msg_enc, msg_pred, obs_pred

In [None]:
#| hide
# Run nbdev's export step so the `#| export` cells are written out to the
# library package.
import nbdev

nbdev.nbdev_export()