# Utilities for Optimization

> This module provides learning-rate and weight-decay schedulers (linear warmup, cosine decay, warmup-stable-decay and linear decay) that update an optimizer's parameter groups during training.

In [None]:
#| default_exp optimizers.schedulers

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore import *
from fastcore.utils import *

In [None]:
#| export
from functools import partial
from torch.optim import Optimizer

### Cosine Learning Rate Scheduler Adjustment

In [None]:
#| export
import torch
import math
class Scheduler:
    """Linear-warmup + cosine-decay learning-rate schedule, applied per param group.

    The schedule itself is stateless: the caller passes the global step to
    `adjust_learning_rate`, which overwrites ``param_group["lr"]`` in place for
    every group of the wrapped optimizer. Each group is scheduled relative to
    its own ``base_lr`` (captured from the group's ``lr`` on first use), so
    groups created with different learning rates keep their relative scale.
    """

    # Fraction of the total number of steps spent warming up linearly from 0.
    WARMUP_FRACTION = 0.10
    # Floor of the cosine decay, expressed as a fraction of a group's base LR.
    END_LR_FACTOR = 0.001
    # One line per param group is printed whenever `step` is a multiple of this.
    LOG_EVERY = 100

    def __init__(
        self,
        schedule: str,
        base_lr: float,
        data_loader,
        epochs: int,
        optimizer,
        batch_steps=None,
        batch_size=None,
    ):
        """Store the schedule configuration.

        Args:
            schedule: Schedule name. ``"constant"`` disables any adjustment;
                every other value selects the warmup + cosine schedule.
            base_lr: Value returned by `adjust_learning_rate`. The per-group
                schedule uses each group's own ``base_lr``/``lr`` instead.
            data_loader: Only consulted for defaults: ``len(data_loader)`` when
                `batch_steps` is None and ``data_loader.config.batch_size``
                when `batch_size` is None.
            epochs: Number of training epochs; with `batch_steps` it defines
                the total schedule length.
            optimizer: Object exposing ``param_groups`` (a list of dicts).
            batch_steps: Steps per epoch. Defaults to ``len(data_loader)``.
            batch_size: Batch size. Stored on the instance but not used by the
                current schedule.
        """
        self.schedule = schedule
        self.base_lr = base_lr
        self.data_loader = data_loader
        self.epochs = epochs
        self.optimizer = optimizer

        if batch_size is None:
            self.batch_size = data_loader.config.batch_size
        else:
            self.batch_size = batch_size

        if batch_steps is None:
            self.batch_steps = len(data_loader)
        else:
            self.batch_steps = batch_steps

    def adjust_learning_rate(self, step: int):
        """Set every param group's ``lr`` for global step `step`.

        For the ``"constant"`` schedule nothing is modified. Otherwise the LR
        ramps linearly from 0 to the group's base LR during the first
        ``WARMUP_FRACTION`` of ``epochs * batch_steps`` steps, then follows a
        half cosine down to ``END_LR_FACTOR * base_lr``. Steps beyond the end
        of the schedule stay at that floor instead of following the cosine
        back up.

        Side effects: writes ``param_group["lr"]``, records
        ``param_group["base_lr"]`` the first time a group is seen, and prints
        one line per group every ``LOG_EVERY`` steps.

        Args:
            step: Global optimisation step (0-based).

        Returns:
            ``self.base_lr`` (the configured base LR, not a group-specific LR).
        """
        if self.schedule == "constant":
            return self.base_lr

        total_max_steps = self.epochs * self.batch_steps
        warmup_steps = int(self.WARMUP_FRACTION * total_max_steps)

        # Only an empty schedule (0 total steps) makes these equal; force a
        # one-step warmup so the division below is always well defined.
        if total_max_steps == warmup_steps:
            warmup_steps = max(1, warmup_steps - 1)

        decay_period = total_max_steps - warmup_steps
        in_warmup = step < warmup_steps

        # The cosine coefficient is identical for every group: compute it once.
        # It stays None when there is no decay phase to speak of.
        q = None
        if not in_warmup and decay_period > 0:
            # Clamp so that steps past the end of training hold the final LR
            # rather than riding the cosine back up.
            current_decay_step = min(step - warmup_steps, decay_period)
            q = 0.5 * (1 + math.cos(math.pi * current_decay_step / decay_period))

        should_log = step % self.LOG_EVERY == 0

        for param_group in self.optimizer.param_groups:
            # Remember the group's ORIGINAL lr; later calls overwrite "lr".
            if "base_lr" not in param_group:
                param_group["base_lr"] = param_group["lr"]
            group_base_lr = param_group["base_lr"]

            if in_warmup:
                lr = group_base_lr * step / warmup_steps
            elif q is None:
                lr = group_base_lr
            else:
                end_lr = group_base_lr * self.END_LR_FACTOR
                lr = group_base_lr * q + end_lr * (1 - q)

            param_group["lr"] = lr

            if should_log:
                group_name = param_group.get("name", "unnamed")
                print(f"Step {step}, Group {group_name}: LR = {lr:.6e}")

        # Return the base LR (not group-specific)
        return self.base_lr

In [None]:
#| hide
from omegaconf import OmegaConf

In [None]:
#| hide
# Load the YAML experiment config (path is relative to this notebook's directory).
cfg = OmegaConf.load("../cfgs/findgoal/mawm/main/mawm-seq-40.yaml")

In [None]:
#| hide
from mawm.data.utils import init_data
# Only the first value returned by init_data is kept (as `dl`); the second is discarded.
dl, _ = init_data(cfg)

  from .autonotebook import tqdm as notebook_tqdm


Data path found for hostname: local
Using all 10 rollouts in dataset.
Using all 10 rollouts in dataset.


In [None]:
#| hide
from mawm.models import init_models
from omegaconf import OmegaConf
# NOTE: this rebinds `cfg` to a different config file than the one used to build `dl`.
cfg = OmegaConf.load("../cfgs/MPCJepa/mpc.yaml")

# Build the model(s) on CPU with distributed mode turned off.
model = init_models(cfg, "cpu", distributed= False)

INFO:root:JEPA Parameters: 98560
INFO:root:CommModule Parameters: 56005
INFO:root:MSgEncoder Parameters: 32608
INFO:root:Projector Parameters: 2241536
INFO:root:--------------------------------------------------
INFO:root:Total Parameters: 2462245


In [None]:
#| hide
from mawm.optimizers.utils import init_opt
# Create the optimizer for `model` from the currently loaded config.
optimizer = init_opt(cfg, model)

In [None]:
#| hide
# Build the scheduler from the config. `batch_steps` is not passed, so the
# number of steps per epoch defaults to len(dl).
scheduler = Scheduler(
        schedule=cfg.optimizer.scheduler.name,
        base_lr=cfg.optimizer.lr,
        data_loader=dl,
        epochs=cfg.epochs,
        optimizer=optimizer,
        batch_size=cfg.data.batch_size,
)

In [None]:
#| hide
# Smoke test: print the first param group's LR before and after adjusting the
# schedule at the first step of epoch 300.
epoch = 300
global_step = epoch * len(dl) + 0
print(optimizer.param_groups[0]["lr"])
# The return value is the scheduler's base LR, not the per-group LR printed below.
lr = scheduler.adjust_learning_rate(global_step)
print(optimizer.param_groups[0]["lr"])

1.0000000000000001e-07
Step 300, Group jepa: LR = 1.000000e-07
Step 300, Group encoders: LR = 1.500000e-07
Step 300, Group comm_module: LR = 2.000000e-07
Step 300, Group proj: LR = 1.500000e-07
1.0000000000000001e-07


### V-JEPA schedulers

In [None]:
#| export
import math


class WSDSchedule(object):
    """Warmup-stable-decay learning-rate schedule.

    The LR ramps linearly from `start_lr` to `ref_lr` during the warmup steps,
    is held at `ref_lr`, and is finally annealed linearly to `final_lr` over
    the last `anneal_steps` steps. Once the schedule is exhausted the LR stays
    at `final_lr`.

    Args:
        optimizer: Object exposing ``param_groups`` (a list of dicts).
        warmup_steps: Number of warmup steps.
        anneal_steps: Number of steps in the final linear anneal.
        T_max: Total number of steps of the schedule (warmup + stable + anneal).
        start_lr: LR the warmup starts from.
        ref_lr: LR held during the stable phase.
        final_lr: LR reached at the end of the anneal.
    """

    def __init__(self, optimizer, warmup_steps, anneal_steps, T_max, start_lr, ref_lr, final_lr=0.0):
        self.optimizer = optimizer
        self.start_lr = start_lr
        self.ref_lr = ref_lr
        self.final_lr = final_lr
        self.anneal_steps = anneal_steps
        self.warmup_steps = warmup_steps
        # From here on, T_max is the length of the *stable* phase only.
        self.T_max = T_max - warmup_steps - anneal_steps
        self._step = 0.0

    def step(self):
        """Advance one step and write the new LR into every param group.

        Groups carrying an ``"lr_scale"`` entry receive ``new_lr * lr_scale``.

        Returns:
            The unscaled learning rate for this step.
        """
        self._step += 1
        stable_end = self.T_max + self.warmup_steps

        if self._step < self.warmup_steps:
            progress = float(self._step) / float(max(1, self.warmup_steps))
            new_lr = self.start_lr + progress * (self.ref_lr - self.start_lr)
        elif self._step < stable_end:
            new_lr = self.ref_lr
        else:
            elapsed = self._step - stable_end
            # Clamp to 1 so that stepping past the end of the schedule holds
            # final_lr instead of extrapolating beyond it (e.g. a negative LR
            # when final_lr is 0).
            progress = min(1.0, float(elapsed) / float(max(1, self.anneal_steps)))
            new_lr = self.ref_lr + progress * (self.final_lr - self.ref_lr)

        for group in self.optimizer.param_groups:
            group["lr"] = new_lr
            if "lr_scale" in group:
                group["lr"] *= group["lr_scale"]

        return new_lr


class WarmupCosineSchedule(object):
    """Linear warmup followed by a half-cosine decay of the learning rate.

    The LR ramps linearly from `start_lr` to `ref_lr` during the warmup steps
    and then follows a half cosine from `ref_lr` down to `final_lr` over the
    remaining ``T_max - warmup_steps`` steps. Once the schedule is exhausted
    the LR stays at `final_lr`.

    Args:
        optimizer: Object exposing ``param_groups`` (a list of dicts).
        warmup_steps: Number of warmup steps.
        start_lr: LR the warmup starts from.
        ref_lr: Peak LR reached at the end of the warmup.
        T_max: Total number of steps of the schedule, warmup included.
        last_epoch: Accepted but not used.
        final_lr: LR reached at the end of the cosine decay.
    """

    def __init__(self, optimizer, warmup_steps, start_lr, ref_lr, T_max, last_epoch=-1, final_lr=0.0):
        self.optimizer = optimizer
        self.start_lr = start_lr
        self.ref_lr = ref_lr
        self.final_lr = final_lr
        self.warmup_steps = warmup_steps
        # From here on, T_max is the length of the cosine phase only.
        self.T_max = T_max - warmup_steps
        self._step = 0.0

    def step(self):
        """Advance one step and write the new LR into every param group.

        Returns:
            The learning rate for this step.
        """
        self._step += 1
        if self._step < self.warmup_steps:
            progress = float(self._step) / float(max(1, self.warmup_steps))
            new_lr = self.start_lr + progress * (self.ref_lr - self.start_lr)
        else:
            # -- progress after warmup, clamped to 1 so that stepping past the
            # -- end of the schedule holds final_lr instead of rising again
            elapsed = self._step - self.warmup_steps
            progress = min(1.0, float(elapsed) / float(max(1, self.T_max)))
            cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
            new_lr = max(
                self.final_lr,
                self.final_lr + (self.ref_lr - self.final_lr) * cosine,
            )

        for group in self.optimizer.param_groups:
            group["lr"] = new_lr

        return new_lr


class CosineWDSchedule(object):
    """Half-cosine schedule for the optimizer's weight decay.

    The weight decay moves from `ref_wd` to `final_wd` along a half cosine over
    `T_max` steps (either increasing or decreasing) and stays at `final_wd`
    afterwards. Param groups with a truthy ``"WD_exclude"`` entry are left
    untouched.

    Args:
        optimizer: Object exposing ``param_groups`` (a list of dicts).
        ref_wd: Weight decay at the start of the schedule.
        T_max: Number of steps over which the weight decay is changed.
        final_wd: Weight decay at the end of the schedule.
    """

    def __init__(self, optimizer, ref_wd, T_max, final_wd=0.0):
        self.optimizer = optimizer
        self.ref_wd = ref_wd
        self.final_wd = final_wd
        self.T_max = T_max
        self._step = 0.0

    def step(self):
        """Advance one step and update ``weight_decay`` of non-excluded groups.

        Returns:
            The weight decay for this step.
        """
        self._step += 1
        # max(1, ...) avoids a ZeroDivisionError for T_max == 0; min(1.0, ...)
        # holds final_wd after T_max instead of letting the cosine swing back.
        progress = min(1.0, self._step / max(1, self.T_max))
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        new_wd = self.final_wd + (self.ref_wd - self.final_wd) * cosine

        # Never overshoot final_wd, whichever direction the schedule runs.
        if self.final_wd <= self.ref_wd:
            new_wd = max(self.final_wd, new_wd)
        else:
            new_wd = min(self.final_wd, new_wd)

        for group in self.optimizer.param_groups:
            if ("WD_exclude" not in group) or not group["WD_exclude"]:
                group["weight_decay"] = new_wd
        return new_wd


class LinearDecaySchedule(object):
    """Linear learning-rate decay from `ref_lr` to `final_lr` over `T_max` steps.

    Once `T_max` steps have been taken the LR stays at `final_lr`.

    Args:
        optimizer: Object exposing ``param_groups`` (a list of dicts).
        ref_lr: LR at the start of the schedule.
        T_max: Number of steps over which the LR is decayed.
        last_epoch: Accepted but not used.
        final_lr: LR reached at the end of the schedule.
    """

    def __init__(self, optimizer, ref_lr, T_max, last_epoch=-1, final_lr=0.0):
        self.optimizer = optimizer
        self.ref_lr = ref_lr
        self.final_lr = final_lr
        self.T_max = T_max
        self._step = 0.0

    def step(self):
        """Advance one step and write the new LR into every param group.

        Returns:
            The learning rate for this step.
        """
        self._step += 1
        # Clamp to 1 so that stepping past T_max holds final_lr instead of
        # extrapolating beyond it (e.g. a negative LR when final_lr is 0).
        progress = min(1.0, float(self._step) / float(max(1, self.T_max)))
        new_lr = self.ref_lr + progress * (self.final_lr - self.ref_lr)
        for group in self.optimizer.param_groups:
            group["lr"] = new_lr

        return new_lr


In [None]:
#| hide
# Export the cells marked `#| export` to the Python module named by `#| default_exp`.
import nbdev; nbdev.nbdev_export()