# Utilities for Optimization

> This module provides optimization utilities: dynamic lookup of optimizer classes, an early-stopping helper, and construction of an AdamW optimizer with per-component learning rates.

In [None]:
#| default_exp optimizers.utils

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore import *
from fastcore.utils import *

In [None]:
#| export
from functools import partial
from torch.optim import Optimizer

In [None]:
#| export
import importlib
def get_cls(module_name, class_name):
    """Import `module_name` and return its attribute named `class_name`.

    Raises whatever `importlib.import_module` raises if the module cannot be
    imported, and `AttributeError` if the module has no such attribute.
    """
    return getattr(importlib.import_module(module_name), class_name)

In [None]:
#| export
def get_opt(cfg, model):
    """Instantiate the `torch.optim` optimizer named in the config.

    The class is looked up by `cfg.optimizer.name` inside `torch.optim` and
    constructed over `model.parameters()` with learning rate
    `cfg.optimizer.lr`. No other optimizer options are forwarded.
    """
    opt_cfg = cfg.optimizer
    return get_cls("torch.optim", opt_cfg.name)(model.parameters(), lr=opt_cfg.lr)

In [None]:
#| export
# Source - https://stackoverflow.com/a
# Posted by isle_of_gods, modified by community. See post 'Timeline' for change history
# Retrieved 2025-11-15, License - CC BY-SA 4.0

class EarlyStopper:
    """Track validation loss and signal when training should stop.

    A call counts as "bad" when the loss exceeds the best loss seen so far by
    more than `min_delta`. Once `patience` bad calls have accumulated since the
    last new best, `early_stop` returns True.
    """

    def __init__(self, patience=1, min_delta=0):
        # Number of bad calls tolerated before stopping is signalled.
        self.patience = patience
        # Slack above the best loss that is not counted as a bad call.
        self.min_delta = min_delta
        # Bad calls accumulated since the last new best loss.
        self.counter = 0
        # Lowest validation loss observed so far.
        self.min_validation_loss = float('inf')

    def early_stop(self, validation_loss):
        """Record `validation_loss`; return True if training should stop."""
        # A strictly lower loss becomes the new best and clears the counter.
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
            return False

        # Inside the tolerance band: neither an improvement nor a bad call.
        tolerance_limit = self.min_validation_loss + self.min_delta
        if not validation_loss > tolerance_limit:
            return False

        self.counter += 1
        return True if self.counter >= self.patience else False


In [None]:
#| export
import torch
def init_opt(
    cfg,
    models,
    *,
    weight_decay=1e-4,
):
    """Build an AdamW optimizer with one parameter group per model component.

    Args:
        cfg: Config object exposing `cfg.optimizer.lr`, the base learning rate.
        models: Nested mapping of modules. The entries read are
            `models['rec']['jepa']` and, under `models['send']`, `'obs_enc'`,
            `'msg_enc'`, `'comm_module'` and `'proj'`.
        weight_decay: Weight decay applied to every parameter group.
            Defaults to 1e-4, the value that was previously hard-coded.

    Returns:
        A `torch.optim.AdamW` with four groups, in order: `jepa` (0.5x base
        lr), `encoders` (1x), `comm_module` (1x) and `proj` (0.5x). Each group
        carries a `'name'` key so it can be identified in
        `optimizer.param_groups`.

    Raises:
        KeyError: If any of the expected entries is missing from `models`.
    """
    base_lr = cfg.optimizer.lr

    def params_of(side, *keys):
        # Concatenate the parameters of the named modules into one flat list.
        return [p for key in keys for p in models[side][key].parameters()]

    param_groups = [
        {'params': params_of('rec', 'jepa'), 'lr': 0.5 * base_lr, 'name': 'jepa'},
        # The observation and message encoders share one group and one lr.
        {'params': params_of('send', 'obs_enc', 'msg_enc'), 'lr': base_lr, 'name': 'encoders'},
        {'params': params_of('send', 'comm_module'), 'lr': base_lr * 1.0, 'name': 'comm_module'},
        {'params': params_of('send', 'proj'), 'lr': base_lr * 0.5, 'name': 'proj'},
    ]

    return torch.optim.AdamW(param_groups, weight_decay=weight_decay)


In [None]:
#| hide
from mawm.models import init_models
from omegaconf import OmegaConf
# Load the configuration from the YAML file (path is relative to this notebook).
cfg = OmegaConf.load("../cfgs/MPCJepa/mpc.yaml")

# Build the models from the config; the result is passed to init_opt in the next cell.
# NOTE(review): "cpu" is presumably the target device — confirm against init_models.
model = init_models(cfg, "cpu", distributed= False)

INFO:root:JEPA Parameters: 98560
INFO:root:CommModule Parameters: 56005
INFO:root:MSgEncoder Parameters: 32608
INFO:root:Projector Parameters: 2241536
INFO:root:--------------------------------------------------
INFO:root:Total Parameters: 2462245


In [None]:
#| hide
# Smoke check: build the grouped optimizer from the loaded config and models.
optimizer = init_opt(cfg, model)
# Bare expression so the notebook displays the optimizer's parameter groups.
optimizer

AdamW (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    decoupled_weight_decay: True
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.0001
    maximize: False
    name: jepa
    weight_decay: 0.01

Parameter Group 1
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    decoupled_weight_decay: True
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.00015000000000000001
    maximize: False
    name: encoders
    weight_decay: 0.01

Parameter Group 2
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    decoupled_weight_decay: True
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.0002
    maximize: False
    name: comm_module
    weight_decay: 0.01

Parameter Group 3
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    decoupled_weight_decay: True
    differentiable: False
    eps: 1e-08
    foreach: 

In [None]:
#| hide
# Trigger nbdev's notebook export.
import nbdev; nbdev.nbdev_export()