# Utilities for Optimization

> This module provides optimization utilities: building a `torch.optim` optimizer from a config, an early-stopping helper, and a joint optimizer over several models.

In [None]:
#| default_exp optimizers.utils

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore import *
from fastcore.utils import *

In [None]:
#| export
from functools import partial
from torch.optim import Optimizer

In [None]:
#| export
import importlib
def get_cls(module_name, class_name):
    """Import `module_name` and return its attribute named `class_name`.

    `module_name` is a dotted import path (e.g. ``"torch.optim"``) handed to
    `importlib.import_module`; `class_name` is looked up on the resulting
    module with `getattr`, so a missing attribute raises `AttributeError`.
    """
    return getattr(importlib.import_module(module_name), class_name)

In [None]:
#| export
def init_opt(cfg, model):
    """Instantiate the `torch.optim` optimizer named in `cfg` for `model`.

    Reads `cfg.optimizer.name` (the optimizer class name inside
    ``torch.optim``) and `cfg.optimizer.lr` (passed as the ``lr`` keyword).
    All of `model.parameters()` are handed to the optimizer; no other
    hyper-parameters are forwarded.
    """
    # Resolve the class by name so the optimizer type is config-driven.
    opt_factory = get_cls("torch.optim", cfg.optimizer.name)
    return opt_factory(model.parameters(), lr=cfg.optimizer.lr)

In [None]:
#| export
# Source - https://stackoverflow.com/a
# Posted by isle_of_gods, modified by community. See post 'Timeline' for change history
# Retrieved 2025-11-15, License - CC BY-SA 4.0

class EarlyStopper:
    """Track the best validation loss and signal when training should stop.

    `patience` is how many non-improving checks are tolerated before
    `early_stop` returns True. A check counts as non-improving only when the
    loss exceeds the best loss seen so far by more than `min_delta`.
    """

    def __init__(self, patience=1, min_delta=0):
        # Best loss starts at +inf so the first finite loss becomes the best.
        self.min_validation_loss = float('inf')
        self.counter = 0
        self.patience = patience
        self.min_delta = min_delta

    def early_stop(self, validation_loss):
        """Record `validation_loss`; return True once patience is exhausted.

        A new strict minimum resets the counter. A loss above
        ``min_validation_loss + min_delta`` increments it. A loss between
        those two bounds leaves the counter unchanged.
        """
        # New best: remember it and start counting from scratch.
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
            return False

        tolerated_ceiling = self.min_validation_loss + self.min_delta
        # Within tolerance of the best loss: neither reset nor count.
        if not (validation_loss > tolerated_ceiling):
            return False

        self.counter += 1
        if self.counter >= self.patience:
            return True
        return False


In [None]:
# #| export
# def init_models_dist(
#     device,
#     patch_size=16,
#     max_num_frames=16,
#     tubelet_size=2,
#     model_name="vit_base",
#     crop_size=224,
#     pred_depth=6,
#     pred_num_heads=None,
#     pred_embed_dim=384,
#     uniform_power=False,
#     use_sdpa=False,
#     use_rope=False,
#     use_silu=False,
#     use_pred_silu=False,
#     wide_silu=False,
#     pred_is_frame_causal=True,
#     use_activation_checkpointing=False,
#     return_all_tokens=False,
#     action_embed_dim=7,
#     use_extrinsics=False,
#     old_pred=False,
# ):
#     encoder = video_vit.__dict__[model_name](
#         img_size=crop_size,
#         patch_size=patch_size,
#         num_frames=max_num_frames,
#         tubelet_size=tubelet_size,
#         uniform_power=uniform_power,
#         use_sdpa=use_sdpa,
#         use_silu=use_silu,
#         wide_silu=wide_silu,
#         use_activation_checkpointing=use_activation_checkpointing,
#         use_rope=use_rope,
#     )

#     predictor = vit_ac_pred.__dict__["vit_ac_predictor"](
#         img_size=crop_size,
#         patch_size=patch_size,
#         num_frames=max_num_frames,
#         tubelet_size=tubelet_size,
#         embed_dim=encoder.embed_dim,
#         predictor_embed_dim=pred_embed_dim,
#         action_embed_dim=action_embed_dim,
#         depth=pred_depth,
#         is_frame_causal=pred_is_frame_causal,
#         num_heads=encoder.num_heads if pred_num_heads is None else pred_num_heads,
#         uniform_power=uniform_power,
#         use_rope=use_rope,
#         use_sdpa=use_sdpa,
#         use_silu=use_pred_silu,
#         wide_silu=wide_silu,
#         use_extrinsics=use_extrinsics,
#         use_activation_checkpointing=use_activation_checkpointing,
#     )

#     encoder.to(device)
#     predictor.to(device)
#     logger.info(encoder)
#     logger.info(predictor)

#     def count_parameters(model):
#         return sum(p.numel() for p in model.parameters() if p.requires_grad)

#     logger.info(f"Encoder number of parameters: {count_parameters(encoder)}")
#     logger.info(f"Predictor number of parameters: {count_parameters(predictor)}")

#     return encoder, predictor


In [None]:
#| export
# from mawm.optimizers.schedulers import WSDSchedule, CosineWDSchedule
import torch
def init_opt_dis(
    cfg,
    jepa,
    msg_encoder,
    msg_pred,
    obs_pred,
    betas=(0.9, 0.999),
    eps=1e-8,
):
    """Create one AdamW optimizer covering the parameters of all four models.

    Parameters of `jepa`, `msg_encoder`, `msg_pred` and `obs_pred` are
    concatenated, in that order, into a single parameter list. `betas` and
    `eps` are forwarded to `torch.optim.AdamW`; every other hyper-parameter
    is left at the AdamW default.

    NOTE: `cfg` is accepted but not read by this function. The learning-rate
    schedule (WSDSchedule), weight-decay schedule (CosineWDSchedule) and AMP
    grad scaler are not created here; only the optimizer is returned.
    """
    joint_params = [
        param
        for module in (jepa, msg_encoder, msg_pred, obs_pred)
        for param in module.parameters()
    ]
    return torch.optim.AdamW(joint_params, betas=betas, eps=eps)


In [None]:
#| hide
import nbdev; nbdev.nbdev_export()