In [None]:
!pip install adam-mini

import torch
import torch.nn as nn
from torch.optim.optimizer import Optimizer
import warnings

try:
    import triton
    import triton.language as tl
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("Triton not installed. Triton optimizer unavailable.")

# Map string to Triton math and PyTorch dtypes
_DTYPE_MAP = {}
if HAS_TRITON:
    _DTYPE_MAP = {
        "bf16": (tl.bfloat16, torch.bfloat16),
        "fp16": (tl.float16, torch.float16),
        "fp32": (tl.float32, torch.float32),
        "fp8e4": (tl.float8e4nv, getattr(torch, "float8_e4m3fn", None))
    }

# Messages that have already been emitted through warn_once().
_warn_once_cache = set()
def warn_once(msg: str):
    """Emit ``msg`` via ``warnings.warn`` the first time it is seen; later calls are no-ops."""
    if msg in _warn_once_cache:
        return
    warnings.warn(msg)
    # Recorded only after warn() returns, so a message whose warning raised
    # (e.g. under an "error" filter) is not marked as already shown.
    _warn_once_cache.add(msg)

# Triton kernels
if HAS_TRITON:
    @triton.jit
    def _reduce_g2_kernel_parallel(g_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        # Each program sums g^2 over one tile of BLOCK_SIZE elements and
        # atomically adds its partial sum into the scalar at output_ptr.
        pid = tl.program_id(axis=0)
        offset = pid * BLOCK_SIZE
        idx = offset + tl.arange(0, BLOCK_SIZE)
        mask = idx < n_elements
        # Lanes past the end load 0.0, so they add nothing to the sum.
        g = tl.load(g_ptr + idx, mask=mask, other=0.0).to(tl.float32)
        sum_g2 = tl.sum(g * g, axis=0)
        tl.atomic_add(output_ptr, sum_g2)

    @triton.jit
    def _update_kernel(
        p_ptr, g_ptr, m_ptr, amax_m_ptr,
        lr_p, beta1_p, weight_decay_p, bias_correction1_p,
        step_scale_ptr, n_elements,
        BLOCK_SIZE: tl.constexpr,
        T_MATH: tl.constexpr, T_STATE: tl.constexpr, STATE_IS_FP8: tl.constexpr
    ):
        """
        Triton kernel to fuse one optimization step over a flat tile:
          - Loads parameter p, gradient g, momentum m (optionally in FP8), each as vectors.
          - Applies AdamW-style weight decay: p = p * (1 - lr*wd).
          - Updates momentum m_new = beta1*m + (1-beta1)*g.
          - Computes m_hat = m_new * bias_correction1_p (the factor is multiplied in,
            so the caller passes the reciprocal of the bias correction).
          - Updates parameter p_new = p - step_scale * m_hat, where step_scale is a
            single value read from step_scale_ptr.
          - Writes back p_new (original dtype) and m_new (cast to T_STATE).
          - Atomically raises the value at amax_m_ptr to max(|m_new|) over the
            valid lanes of this tile, for the FP8 saturation check.
        """
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements

        # Load parameter and gradient
        p = tl.load(p_ptr + offsets, mask=mask).to(T_MATH)
        g = tl.load(g_ptr + offsets, mask=mask).to(T_MATH)
        # Load previous momentum (FP8 or not)
        if STATE_IS_FP8:
            m = tl.load(m_ptr + offsets, mask=mask).to(T_STATE).to(T_MATH)
        else:
            m = tl.load(m_ptr + offsets, mask=mask).to(T_MATH)

        # Broadcast scalars
        lr = tl.full((), lr_p, T_MATH)
        beta1 = tl.full((), beta1_p, T_MATH)
        wd = tl.full((), weight_decay_p, T_MATH)
        bc1 = tl.full((), bias_correction1_p, T_MATH)
        step_scale = tl.load(step_scale_ptr).to(T_MATH)

        # Decoupled weight decay (p = p * (1 - lr*wd))
        if wd != 0:
            p = p * (1 - lr * wd)

        # Momentum update (no extra grad scaling)
        m_new = beta1 * m + (1 - beta1) * g
        m_hat = m_new * bc1
        p_new = p - step_scale * m_hat

        # Track max(|m_new|) for the FP8 overflow check. The loads above are
        # masked without an `other` value, so lanes past n_elements hold
        # undefined data; zero them before reducing so they cannot win the max.
        abs_m = tl.abs(m_new.to(tl.float32))
        max_m_val = tl.max(tl.where(mask, abs_m, 0.0), axis=0)
        tl.atomic_max(amax_m_ptr, max_m_val)

        # Write updated parameter and momentum
        tl.store(p_ptr + offsets, p_new.to(p_ptr.dtype.element_ty), mask=mask)
        tl.store(m_ptr + offsets, m_new.to(T_STATE), mask=mask)

# PyTorch (reference) implementation of Adam-mini
class AdamMiniPyTorch(Optimizer):
    """Reference (pure PyTorch) implementation of Adam-mini.

    Every parameter tensor is treated as its own block: the first moment
    (``exp_avg``) is kept per element, while the second moment is a single
    scalar per parameter (``v_bar``, an EMA of ``mean(grad**2)``).

    Per-parameter bookkeeping lives in two places:
      * ``self.state[p]``: ``exp_avg`` and a ``telemetry`` dict whose
        ``max_update`` / ``max_m`` lists grow by one entry per step.
      * ``self.block_states[id(p)]``: the scalar ``v_bar`` and the ``step`` count.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)
        self.block_states = {}

    def _init_groups_and_states(self):
        """Create state for every parameter that does not have it yet.

        Safe to call repeatedly: existing entries are left untouched, so
        parameters added later through ``add_param_group`` are picked up.
        """
        for group in self.param_groups:
            for p in group["params"]:
                state = self.state[p]
                # First moment (exp_avg) per parameter
                if "exp_avg" not in state:
                    state["exp_avg"] = torch.zeros_like(p)
                # v_bar per "block" (here each parameter is its own block)
                if id(p) not in self.block_states:
                    self.block_states[id(p)] = {
                        "v_bar": torch.zeros(1, device=p.device, dtype=torch.float32),
                        "step": 0
                    }

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss. It is run with gradients enabled before the update.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad, so re-enable autograd for the closure.
            with torch.enable_grad():
                loss = closure()

        self._init_groups_and_states()

        for group in self.param_groups:
            beta1, beta2 = group["betas"]
            lr, eps, wd = group["lr"], group["eps"], group["weight_decay"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                state = self.state[p]
                block = self.block_states[id(p)]

                # Increment step for this block
                block["step"] += 1
                step = block["step"]

                # Block-wise second moment: moving average of mean(g^2)
                g2 = (grad.to(torch.float32) ** 2).mean()
                v_bar = block["v_bar"]
                v_bar.mul_(beta2).add_(g2, alpha=1 - beta2)

                # Bias corrections
                bc1 = 1.0 - beta1 ** step
                bc2 = 1.0 - beta2 ** step

                # Compute step size = lr / (sqrt(v_bar/bc2) + eps)
                denom = (v_bar / bc2).sqrt() + eps
                step_scale = lr / denom

                # Decoupled weight decay (AdamW)
                if wd != 0:
                    p.mul_(1 - lr * wd)

                # Update first moment
                exp_avg = state["exp_avg"]
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                m_hat = exp_avg / bc1

                # Parameter update; the update tensor is reused by the telemetry below
                update = step_scale * m_hat
                p.sub_(update)

                # Telemetry for debugging
                telemetry = state.setdefault("telemetry", {
                    "steps": 0,
                    "max_update": [],
                    "max_m": [],
                    "dtype": str(exp_avg.dtype)
                })
                telemetry["steps"] += 1
                # Record max parameter update magnitude (|step_scale * m_hat|)
                telemetry["max_update"].append(update.abs().max().item())
                # Record max momentum magnitude
                telemetry["max_m"].append(exp_avg.abs().max().item())

        return loss

# Triton-optimized Adam-mini
class AdamMiniTriton(Optimizer):
    """Adam-mini whose element-wise update runs in the fused Triton ``_update_kernel``.

    Each parameter tensor is its own block. The scalar second moment (``v_bar``)
    and the step size are computed in PyTorch; weight decay, the momentum update
    and the parameter write-back happen inside the kernel.

    Args:
        math_dtype: key into ``_DTYPE_MAP`` selecting the dtype used for
            arithmetic inside the kernel.
        state_dtype: key into ``_DTYPE_MAP`` selecting the storage dtype of the
            momentum buffer. ``"fp8e4"`` falls back to ``"bf16"`` when this torch
            build has no FP8 dtype, and a block is switched to bf16 at run time
            once its momentum magnitude exceeds 400.

    NOTE(review): the kernel indexes tensors as flat memory, so it presumably
    requires contiguous parameters and gradients — confirm with callers.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0,
                 math_dtype="bf16", state_dtype="bf16"):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)
        self.math_dtype_str = math_dtype
        self.state_dtype_str = state_dtype
        self.block_states = {}

    def _init_groups_and_states(self):
        """Create optimizer state for every parameter that does not have it yet.

        Idempotent, so parameters added later via ``add_param_group`` are
        initialized on the next step.
        """
        for group in self.param_groups:
            for p in group["params"]:
                if id(p) in self.block_states:
                    continue
                # Determine storage dtype for momentum
                resolved_state = self.state_dtype_str
                if resolved_state == "fp8e4" and _DTYPE_MAP["fp8e4"][1] is None:
                    resolved_state = "bf16"
                self.state[p]["exp_avg"] = torch.zeros_like(p, dtype=_DTYPE_MAP[resolved_state][1])
                self.state[p]["amax_m"] = torch.zeros(1, device=p.device, dtype=torch.float32)
                self.block_states[id(p)] = {
                    "v_bar": torch.zeros(1, device=p.device, dtype=torch.float32),
                    "step": 0,
                    "state_is_fp8": (resolved_state == "fp8e4"),
                    # Dtype key actually in use for this block. It can differ from
                    # self.state_dtype_str after an FP8 fallback or saturation switch.
                    "state_dtype": resolved_state,
                }

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss. It is run with gradients enabled before the update.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad, so re-enable autograd for the closure.
            with torch.enable_grad():
                loss = closure()

        self._init_groups_and_states()

        for group in self.param_groups:
            beta1, beta2 = group["betas"]
            lr, eps, wd = group["lr"], group["eps"], group["weight_decay"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                state = self.state[p]
                block = self.block_states[id(p)]

                # Update step and v_bar
                block["step"] += 1
                step = block["step"]
                g2 = (grad.to(torch.float32) ** 2).mean()
                v_bar = block["v_bar"]
                v_bar.mul_(beta2).add_(g2, alpha=1 - beta2)

                # Bias corrections
                bc1 = 1.0 - beta1 ** step
                bc2 = 1.0 - beta2 ** step
                # The kernel multiplies by this factor; it is reused for telemetry.
                bc1_inv = 1.0 / bc1

                # Step scale for this block
                denom = (v_bar / bc2).sqrt() + eps
                step_scale = (lr / denom).to(torch.bfloat16)
                step_scale_val = float(step_scale)

                m_buffer = state["exp_avg"]
                amax_m = state["amax_m"].zero_()

                # Launch Triton kernel (one program per tile of 1024 elements).
                # T_STATE follows the dtype the buffer really has, so a block that
                # fell back to bf16 is not rounded through FP8 on every store.
                _update_kernel[(triton.cdiv(p.numel(), 1024),)](
                    p, grad, m_buffer, amax_m,
                    lr, beta1, wd, bc1_inv,
                    step_scale, p.numel(),
                    BLOCK_SIZE=1024,
                    T_MATH=_DTYPE_MAP[self.math_dtype_str][0],
                    T_STATE=_DTYPE_MAP[block["state_dtype"]][0],
                    STATE_IS_FP8=block["state_is_fp8"]
                )

                # FP8 overflow check: switch to BF16 if needed. The flag is tested
                # first so non-FP8 blocks skip the device sync implied by .item().
                if block["state_is_fp8"] and amax_m.item() > 400.0:
                    warn_once("FP8 momentum saturated; switching block to BF16.")
                    block["state_is_fp8"] = False
                    block["state_dtype"] = "bf16"
                    state["exp_avg"] = state["exp_avg"].to(torch.bfloat16)

                # Telemetry
                telemetry = state.setdefault("telemetry", {
                    "steps": 0,
                    "max_update": [],
                    "max_m": [],
                    "dtype": str(m_buffer.dtype)
                })
                telemetry["steps"] += 1
                # Compute actual max parameter update magnitude, analogous to PyTorch version
                m_fp32 = m_buffer.to(torch.float32)
                max_update_val = (step_scale_val * m_fp32 * bc1_inv).abs().max().item()
                max_m_val = m_fp32.abs().max().item()
                telemetry["max_update"].append(max_update_val)
                telemetry["max_m"].append(max_m_val)

        return loss

# ---------------------------------------------------------------------
# Toy model & parity test
# ---------------------------------------------------------------------
class SimpleModel(nn.Module):
    """Toy model for the parity test: a single 16 -> 16 linear layer."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(in_features=16, out_features=16)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result."""
        out = self.linear(x)
        return out

def parity_check(optA_cls, optB_cls, steps=5):
    """Train two identical SimpleModel copies with two optimizers and print how far they drift.

    Both models start from the same weights and see the same random batches for
    ``steps`` iterations. Afterwards the max absolute weight difference and each
    optimizer's per-parameter ``telemetry`` dict are printed.

    Args:
        optA_cls: optimizer class for model A, constructed as ``cls(params, lr=1e-3)``.
        optB_cls: optimizer class for model B, constructed the same way.
        steps: number of training iterations to run.
    """
    torch.manual_seed(0)
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"

    modelA = SimpleModel().to(device)
    modelB = SimpleModel().to(device)
    # Start both models from identical weights.
    modelB.load_state_dict(modelA.state_dict())

    optA = optA_cls(modelA.parameters(), lr=1e-3)
    optB = optB_cls(modelB.parameters(), lr=1e-3)

    loss_fn = nn.MSELoss()

    def train_once(model, optimizer, inputs, targets):
        # forward -> zero_grad -> backward -> optimizer step
        loss = loss_fn(model(inputs), targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    for _ in range(steps):
        data = torch.randn(8, 16, device=device)
        target = torch.randn(8, 16, device=device)
        train_once(modelA, optA, data, target)
        train_once(modelB, optB, data, target)

    # Compare weights and telemetry, parameter by parameter.
    paired = zip(modelA.named_parameters(), modelB.named_parameters())
    for (nA, pA), (_, pB) in paired:
        diff = (pA - pB).abs().max().item()
        print(f"Param {nA}: max diff {diff:.6e}", "\n")
        telA = optA.state[pA]["telemetry"]
        telB = optB.state[pB]["telemetry"]
        print(f"  [Telemetry] Steps={telA['steps']}, max_update_A={telA['max_update']}, max_m_A={telA['max_m']}, dtype_A={telA['dtype']}")
        print(f"              Steps={telB['steps']}, max_update_B={telB['max_update']}, max_m_B={telB['max_m']}, dtype_B={telB['dtype']}")
        print("\n\n")

# ---------------------------------------------------------------------
# MAIN
# ---------------------------------------------------------------------
if __name__ == "__main__":
    # Script entry point: compare the reference and the Triton optimizer on the toy model.
    banner = "Running parity test between AdamMiniPyTorch and AdamMiniTriton...\n"
    print(banner)
    parity_check(optA_cls=AdamMiniPyTorch, optB_cls=AdamMiniTriton)


Running parity test between AdamMiniPyTorch and AdamMiniTriton...

Param linear.weight: max diff 2.246886e-03 

  [Telemetry] Steps=5, max_update_A=[0.004441387485712767, 0.0023785014636814594, 0.00216480134986341, 0.0015761416871100664, 0.0015192623250186443], max_m_A=[0.02251160517334938, 0.026076965034008026, 0.03351600840687752, 0.029425164684653282, 0.03513770550489426], dtype_A=torch.float32
              Steps=5, max_update_B=[0.004538297653198242, 0.002416654722765088, 0.002185906982049346, 0.001596447778865695, 0.0015406586462631822], max_m_B=[0.02294921875, 0.0264892578125, 0.033935546875, 0.02978515625, 0.03564453125], dtype_B=torch.bfloat16



Param linear.bias: max diff 1.433402e-03 

  [Telemetry] Steps=5, max_update_A=[0.002155275084078312, 0.001265847124159336, 0.001437773578800261, 0.0010939586209133267, 0.0008378218044526875], max_m_A=[0.009643001481890678, 0.015522426925599575, 0.023487450554966927, 0.022799750789999962, 0.021312423050403595], dtype_A=torch.float32
 

In [None]:
!pip install adam-mini

# === START OF FILE ===
import torch
import torch.nn as nn
from torch.optim.optimizer import Optimizer
import torch.optim as optim
import time
from typing import Callable, List, Tuple, Dict, Optional, Iterable
import warnings
from collections import defaultdict
import math

# ======================================================================
# Setup: Triton import (optional) and dtype map
# ======================================================================
try:
    import triton
    import triton.language as tl
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print('Triton is not installed. Fused kernel will not be available.')

_DTYPE_MAP = {}
if HAS_TRITON:
    _DTYPE_MAP = {
        "bf16": (tl.bfloat16, torch.bfloat16), "fp16": (tl.float16, torch.float16),
        "fp32": (tl.float32, torch.float32), "fp8e4": (tl.float8e4nv, getattr(torch, 'float8e4m3fn', None))
    }

# Messages already emitted through warn_once().
_warn_once_cache = set()
def warn_once(msg: str):
    """Issue ``msg`` as a warning only the first time this exact text is passed."""
    already_shown = msg in _warn_once_cache
    if not already_shown:
        warnings.warn(msg)
        _warn_once_cache.add(msg)

# ======================================================================
# Feature: Hessian-Aware Partitioner for Transformers
# ======================================================================
def adam_mini_transformer_partition(model: nn.Module) -> List[Tuple[str, List[Tuple[str, torch.Tensor]], Dict]]:
    """Split a model's trainable parameters into Adam-mini blocks by parameter name.

    ``model`` must carry a ``param_groups`` attribute (a list of optimizer
    param-group dicts); the optimizer in this file attaches it before calling.
    Only parameters that appear in one of those groups are partitioned, so
    trainable parameters the optimizer does not manage are skipped instead of
    raising ``KeyError``.

    Name matching is case-insensitive and the first matching rule wins:
      * contains ``bias`` / ``norm`` / ``ln``  -> all pooled in ``bias_and_norm``
      * embedding / output keywords            -> own block ``embd_output_<name>``
      * query / key projection keywords        -> own block ``q_and_k_<name>``
      * value / output projection keywords     -> own block ``v_and_proj_<name>``
      * contains ``mlp``                       -> own block ``mlp_<name>``
      * anything else                          -> own block ``other_<name>``

    Returns:
        A list of ``(block_key, [(param_name, param), ...], param_group)``
        tuples, where ``param_group`` is the group of the block's first parameter.
    """
    embed_output_keys = ("embed", "wte", "wpe", "lm_head", "output")
    query_key_keys = ("k_proj", "q_proj", "wk", "wq")
    value_proj_keys = ("v_proj", "wv", "o_proj", "wo")

    # Map each trainable parameter to the optimizer group that owns it.
    param_to_group_map = {}
    for group in model.param_groups:
        for p in group['params']:
            if p.requires_grad:
                param_to_group_map[p] = group

    partitions = defaultdict(list)
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if p not in param_to_group_map:
            # Not handed to the optimizer: nothing to update, so no block.
            continue
        low_name = name.lower()
        if 'bias' in low_name or 'norm' in low_name or 'ln' in low_name:
            partitions['bias_and_norm'].append((name, p))
        elif any(key in low_name for key in embed_output_keys):
            partitions[f'embd_output_{name}'] = [(name, p)]
        elif any(key in low_name for key in query_key_keys):
            partitions[f'q_and_k_{name}'] = [(name, p)]
        elif any(key in low_name for key in value_proj_keys):
            partitions[f'v_and_proj_{name}'] = [(name, p)]
        elif 'mlp' in low_name:
            partitions[f'mlp_{name}'] = [(name, p)]
        else:
            partitions[f'other_{name}'] = [(name, p)]

    final_blocks = []
    for block_key, params_list in partitions.items():
        if not params_list:
            continue
        # The block uses the hyperparameters of its first parameter's group.
        first_param_group = param_to_group_map[params_list[0][1]]
        final_blocks.append((block_key, params_list, first_param_group))

    print(f"Hessian-aware partitioner created {len(final_blocks)} blocks.")
    return final_blocks

# ======================================================================
# Triton Kernels (Unchanged)
# ======================================================================
if HAS_TRITON:
    @triton.jit
    def _reduce_g2_kernel_parallel(g_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        # Each program sums g^2 over one tile of BLOCK_SIZE elements and
        # atomically adds its partial sum into the scalar at output_ptr.
        pid = tl.program_id(axis=0)
        offset = pid * BLOCK_SIZE
        idx = offset + tl.arange(0, BLOCK_SIZE)
        mask = idx < n_elements
        # Lanes past the end load 0.0, so they add nothing to the sum.
        g = tl.load(g_ptr + idx, mask=mask, other=0.0).to(tl.float32)
        sum_g2 = tl.sum(g * g, axis=0)
        tl.atomic_add(output_ptr, sum_g2)

    # The kernel updates p and m in place. Autotuning runs it repeatedly on the
    # real buffers, so both are restored after each benchmark run; otherwise the
    # first launch for a new n_elements would apply the update many times.
    @triton.autotune(
        configs=[
            triton.Config({'BLOCK_SIZE': 1024}, num_warps=8),
            triton.Config({'BLOCK_SIZE': 2048}, num_warps=8),
        ],
        key=['n_elements'],
        restore_value=['p_ptr', 'm_ptr'],
    )
    @triton.jit
    def _update_kernel(
        p_ptr, g_ptr, m_ptr, amax_m_ptr, lr_p, beta1_p, weight_decay_p, bias_correction1_p,
        step_scale_ptr, state_scale_p, n_elements,
        BLOCK_SIZE: tl.constexpr, T_MATH: tl.constexpr, T_STATE: tl.constexpr, STATE_IS_FP8: tl.constexpr,
    ):
        """
        Fused Adam-mini update over one flat tile of BLOCK_SIZE elements:
          - Loads p, g and the stored momentum (stored value = momentum * state_scale).
          - Applies decoupled weight decay: p = p * (1 - lr * weight_decay).
          - Un-scales the momentum, then m_new = beta1 * m + (1 - beta1) * g.
          - m_hat = m_new * bias_correction1_p (the factor is multiplied in).
          - p_new = p - step_scale * m_hat, with step_scale read from step_scale_ptr.
          - Stores p_new in p's dtype. Stores m_new * state_scale cast to T_STATE
            when STATE_IS_FP8, otherwise m_new in the buffer's own dtype.
          - Atomically raises the value at amax_m_ptr to max(|m_new|) (un-scaled)
            over the valid lanes of this tile.
        """
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements

        p = tl.load(p_ptr + offsets, mask=mask).to(T_MATH)
        g = tl.load(g_ptr + offsets, mask=mask).to(T_MATH)
        if STATE_IS_FP8:
            m_scaled = tl.load(m_ptr + offsets, mask=mask).to(T_STATE).to(T_MATH)
        else:
            m_scaled = tl.load(m_ptr + offsets, mask=mask).to(T_MATH)

        # Scalars in the math dtype
        lr = tl.full((), lr_p, T_MATH)
        beta1 = tl.full((), beta1_p, T_MATH)
        weight_decay = tl.full((), weight_decay_p, T_MATH)
        bias_correction1 = tl.full((), bias_correction1_p, T_MATH)
        state_scale = tl.full((), state_scale_p, T_MATH)
        step_scale = tl.load(step_scale_ptr).to(T_MATH)

        if weight_decay != 0:
            p = p * (1 - lr * weight_decay)

        m_unscaled = m_scaled / state_scale
        m_new_unscaled = beta1 * m_unscaled + (1 - beta1) * g
        m_hat = m_new_unscaled * bias_correction1
        p_new = p - step_scale * m_hat

        # The loads above are masked without an `other` value, so lanes past
        # n_elements hold undefined data; zero them so they cannot win the max.
        abs_m = tl.abs(m_new_unscaled.to(tl.float32))
        max_m_val = tl.max(tl.where(mask, abs_m, 0.0), axis=0)
        tl.atomic_max(amax_m_ptr, max_m_val)

        tl.store(p_ptr + offsets, p_new.to(p_ptr.dtype.element_ty), mask=mask)
        if STATE_IS_FP8:
            tl.store(m_ptr + offsets, (m_new_unscaled * state_scale).to(T_STATE), mask=mask)
        else:
            tl.store(m_ptr + offsets, m_new_unscaled.to(m_ptr.dtype.element_ty), mask=mask)

# ======================================================================
# Final Optimizer Class
# ======================================================================
class AdamMiniTriton(Optimizer):
    """Adam-mini with block-wise second moments and a fused Triton update kernel.

    Parameters are grouped into blocks. Each block keeps one scalar second
    moment (``v_bar_fp32``), a step counter and, for FP8 state, a scale factor
    applied to the stored momentum. The momentum itself is stored per parameter.

    Args:
        params: parameters or param groups, as for any ``torch.optim.Optimizer``.
        lr, betas, eps, weight_decay: per-group hyperparameters.
        math_dtype: ``_DTYPE_MAP`` key for the arithmetic dtype inside the kernel.
        state_dtype: ``_DTYPE_MAP`` key for the momentum storage dtype. Only
            whether it is an FP8 type matters at step time: non-FP8 state is
            always kept in bf16, and FP8 falls back to bf16 when unsupported.
        partition_fn: optional callable ``fn(model)`` returning a list of
            ``(block_name, [(param_name, param), ...], param_group)`` tuples.
        model_for_partitioning: model passed to ``partition_fn``; required when
            ``partition_fn`` is given.

    Without ``partition_fn`` each param group becomes a single block.
    """

    def __init__(self, params, lr: float = 1e-3, betas: tuple = (0.9, 0.999), eps: float = 1e-6,
                 weight_decay: float = 0.0, math_dtype: str = "bf16", state_dtype: str = "bf16",
                 partition_fn: Optional[Callable] = None, model_for_partitioning: Optional[nn.Module] = None):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)
        self.math_dtype_str = math_dtype
        self.state_dtype_str = state_dtype
        self.partition_fn = partition_fn
        self.model_for_partitioning = model_for_partitioning
        self.block_states = {}
        self.telemetry = defaultdict(list)
        self._blocks_built = False

    def get_telemetry(self):
        """Return a plain-dict copy of the telemetry recorded by the latest ``step``."""
        return dict(self.telemetry)

    def _resolve_state_dtype(self, requested_dtype, device):
        """Return the dtype key to use for state, falling back to ``"bf16"`` for unusable FP8.

        FP8 is rejected when torch lacks the dtype or when the CUDA device's
        major compute capability is below 9.
        """
        if "fp8" in requested_dtype:
            if _DTYPE_MAP[requested_dtype][1] is None:
                warn_once(f"PyTorch version lacks {requested_dtype}. Falling back to bf16.")
                return "bf16"
            major, _ = torch.cuda.get_device_capability(device)
            if major < 9:
                warn_once("FP8 requires SM90+; falling back to bf16 state.")
                return "bf16"
        return requested_dtype

    def _init_groups_and_states(self):
        """Build the parameter blocks and allocate block and per-parameter state."""
        if self.partition_fn:
            if self.model_for_partitioning is None:
                raise ValueError("model_for_partitioning must be provided for partition_fn.")
            # Ensure the model has access to the optimizer's param_groups
            self.model_for_partitioning.param_groups = self.param_groups
            self.param_blocks = self.partition_fn(self.model_for_partitioning)
        else:
            self.param_blocks = []
            for group in self.param_groups:
                # Use every trainable parameter, not only those that happen to
                # have a gradient on the first step; step() skips the ones
                # without a gradient.
                trainable = [p for p in group['params'] if p.requires_grad]
                if trainable:
                    named = [(p.name if hasattr(p, 'name') else f'param_{i}', p) for i, p in enumerate(trainable)]
                    self.param_blocks.append((f"group_{id(group)}", named, group))

        for block_name, block_params, group in self.param_blocks:
            first_param = block_params[0][1]
            device = first_param.device
            resolved_state_dtype_str = self._resolve_state_dtype(self.state_dtype_str, device)
            _, torch_state = _DTYPE_MAP[resolved_state_dtype_str]
            is_fp8 = "fp8" in resolved_state_dtype_str
            self.block_states[block_name] = {
                'v_bar_fp32': torch.zeros(1, device=device, dtype=torch.float32),
                'step': 0,
                'state_is_fp8': is_fp8,
                'consecutive_soft_breaches': 0,
                'state_scale': torch.ones(1, device=device, dtype=torch.float32),
            }
            for _, p in block_params:
                self.state[p]['exp_avg'] = torch.zeros_like(p, dtype=torch_state)
                self.state[p]['amax_m'] = torch.zeros(1, device=device, dtype=torch.float32)
        self._blocks_built = True

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step over all blocks.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss. It is run with gradients enabled before the update.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad, so re-enable autograd for the closure.
            with torch.enable_grad():
                loss = closure()

        if not self._blocks_built:
            self._init_groups_and_states()
        self.telemetry.clear()

        math_tl = _DTYPE_MAP[self.math_dtype_str][0]

        for block_name, block_params, group in self.param_blocks:
            block_state = self.block_states[block_name]
            beta1, beta2 = group['betas']
            lr, eps, wd = group['lr'], group['eps'], group['weight_decay']
            block_state['step'] += 1
            step = block_state['step']
            first_param = block_params[0][1]

            # Sum of squared gradients over every parameter in the block.
            sum_g2 = torch.zeros(1, device=first_param.device, dtype=torch.float32)
            total_params_in_block = 0
            for _, p in block_params:
                if p.grad is None:
                    continue
                if not p.grad.is_leaf:
                    warn_once("A gradient is not a leaf tensor. If using GradScaler, call scaler.unscale_(optimizer) before step().")
                n = p.numel()
                total_params_in_block += n
                _reduce_g2_kernel_parallel[(triton.cdiv(n, 1024),)](p.grad, sum_g2, n, BLOCK_SIZE=1024)
            if total_params_in_block == 0:
                continue

            mean_g2 = sum_g2 / total_params_in_block
            v_bar = block_state['v_bar_fp32']
            v_bar.mul_(beta2).add_(mean_g2, alpha=1.0 - beta2)
            bias_correction1 = 1.0 - beta1 ** step
            bias_correction2 = 1.0 - beta2 ** step
            v_bar_hat = v_bar / bias_correction2
            denom = torch.sqrt(v_bar_hat) + eps
            step_scale = (lr / denom).to(torch.bfloat16)

            state_scale = block_state['state_scale']
            state_is_fp8 = block_state['state_is_fp8']
            # Non-FP8 state is always held in bf16.
            state_tl, state_torch = _DTYPE_MAP['fp8e4' if state_is_fp8 else 'bf16']
            # Read the scale once per block instead of once per parameter.
            scale_value = state_scale.item()

            block_amax_m = 0.0
            for _, p in block_params:
                if p.grad is None:
                    continue
                param_state = self.state[p]
                m_buffer = param_state['exp_avg']
                # Convert only toward the dtype this block uses. The bf16
                # conversion used to run unconditionally, undoing FP8 storage.
                if m_buffer.dtype != state_torch:
                    m_buffer = param_state['exp_avg'] = m_buffer.to(state_torch)
                amax_m_buffer = param_state['amax_m'].zero_()

                n = p.numel()
                # BLOCK_SIZE is chosen by the autotuner, so size the grid from it.
                grid = lambda meta, n=n: (triton.cdiv(n, meta['BLOCK_SIZE']),)
                _update_kernel[grid](
                    p, p.grad, m_buffer, amax_m_buffer, lr, beta1, wd, 1.0 / bias_correction1,
                    step_scale, scale_value, n,
                    T_MATH=math_tl, T_STATE=state_tl, STATE_IS_FP8=state_is_fp8)

                block_amax_m = max(block_amax_m, amax_m_buffer.item())

            if state_is_fp8:
                target_range = 120.0
                new_scale = target_range / max(block_amax_m, 1e-6)
                state_scale.mul_(0.9).add_(new_scale, alpha=0.1)
                # The buffers were written with the old scale and the next step
                # divides by the new one, so re-express them in the new scale.
                ratio = state_scale.item() / scale_value
                for _, p in block_params:
                    param_state = self.state[p]
                    stored = param_state['exp_avg']
                    param_state['exp_avg'] = (stored.to(torch.float32) * ratio).to(stored.dtype)

            self.telemetry[block_name].append({
                'amax_m': block_amax_m,
                'state_dtype': 'fp8' if block_state['state_is_fp8'] else 'bf16',
                'scale': state_scale.item(),
            })

        return loss

# ======================================================================
# Main Execution Block
# ======================================================================
if __name__ == '__main__':
    # Benchmark script: times AdamW, the official Adam-mini (if installed) and the
    # Triton Adam-mini (bf16 and FP8 state) on a small transformer-like model.
    if not torch.cuda.is_available():
        print("This benchmark requires a CUDA-enabled environment.")
        # SystemExit is always available, unlike the interactive exit() helper.
        raise SystemExit(0)
    try:
        from adam_mini import Adam_mini as AdamMiniOfficial
    except ImportError:
        print("Official Adam-mini not found. Please `pip install adam-mini` to run the full benchmark.")
        AdamMiniOfficial = None

    device = "cuda"
    dtype = torch.bfloat16

    class SimpleTransformer(nn.Module):
        """Single transformer-like block built from linear layers (no real attention)."""

        def __init__(self, dim=1024, n_heads=16, mlp_dim=4096):
            super().__init__()
            self.dim = dim
            self.n_heads = n_heads
            self.embed = nn.Linear(512, dim)
            self.q_proj = nn.Linear(dim, dim)
            self.k_proj = nn.Linear(dim, dim)
            self.v_proj = nn.Linear(dim, dim)
            self.o_proj = nn.Linear(dim, dim)
            self.mlp_1 = nn.Linear(dim, mlp_dim)
            self.mlp_2 = nn.Linear(mlp_dim, dim)
            self.norm1 = nn.LayerNorm(dim)
            self.norm2 = nn.LayerNorm(dim)
            self.lm_head = nn.Linear(dim, 512)

        def forward(self, x):
            x = self.embed(x)
            # Normalize once and feed the same tensor to all three projections.
            normed = self.norm1(x)
            x = x + self.o_proj(self.q_proj(normed) + self.k_proj(normed) + self.v_proj(normed))
            x = x + self.mlp_2(torch.nn.functional.relu(self.mlp_1(self.norm2(x))))
            return self.lm_head(x)

    def model_definition():
        """Build a fresh benchmark model on the target device and dtype."""
        return SimpleTransformer().to(device).to(dtype)

    def benchmark_optimizer(optimizer_class, name, steps=5000, **kwargs):
        """Time ``steps`` training iterations with one optimizer and print time and peak memory.

        ``name`` selects how the optimizer is constructed: the official Adam-mini
        takes named parameters plus ``dim`` / ``n_heads``, the Triton variants
        take the model for partitioning, anything else takes plain parameters.
        """
        # Print the header first so anything the optimizer constructor logs
        # appears under its own benchmark rather than the previous one.
        print(f"\n--- Benchmarking {name} ---")
        model = model_definition()

        if name == "Adam-mini (Official)":
            official_kwargs = {**kwargs, 'dim': model.dim, 'n_heads': model.n_heads}
            optimizer = optimizer_class(model.named_parameters(), **official_kwargs)
        elif name.startswith("Adam-mini (Triton"):
            optimizer = optimizer_class(model.parameters(), model_for_partitioning=model, **kwargs)
        else:
            optimizer = optimizer_class(model.parameters(), **kwargs)

        def train_step():
            inp = torch.randn(64, 512, device=device, dtype=dtype)
            optimizer.zero_grad(set_to_none=True)
            loss = model(inp).sum()
            loss.backward()
            optimizer.step()

        # Warm-up: triggers lazy state allocation and kernel compilation/autotuning.
        for _ in range(10):
            train_step()

        torch.cuda.reset_peak_memory_stats(device)
        torch.cuda.synchronize(device)
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_time = time.perf_counter()
        for _ in range(steps):
            train_step()
        torch.cuda.synchronize(device)
        end_time = time.perf_counter()

        peak_memory = torch.cuda.max_memory_allocated(device) / 1e6
        total_time = end_time - start_time
        print(f"Total Time ({steps} steps): {total_time:.3f} s | Peak Memory: {peak_memory:.2f} MB")

    hyperparams = {'lr': 1e-3, 'eps': 1e-6, 'betas': (0.9, 0.999)}
    print("\n\n" + "="*50)
    print(" " * 10 + "Running Full Optimizer Benchmark")
    print("="*50)

    benchmark_optimizer(torch.optim.AdamW, "AdamW (PyTorch)", **hyperparams)

    if AdamMiniOfficial:
        benchmark_optimizer(AdamMiniOfficial, "Adam-mini (Official)", **hyperparams)

    benchmark_optimizer(AdamMiniTriton, "Adam-mini (Triton, BF16 State)", state_dtype="bf16", partition_fn=adam_mini_transformer_partition, **hyperparams)
    benchmark_optimizer(AdamMiniTriton, "Adam-mini (Triton, FP8 State)", state_dtype="fp8e4", partition_fn=adam_mini_transformer_partition, **hyperparams)

Collecting adam-mini
  Downloading adam_mini-1.1.1-py3-none-any.whl.metadata (2.9 kB)
Downloading adam_mini-1.1.1-py3-none-any.whl (13 kB)
Installing collected packages: adam-mini
Successfully installed adam-mini-1.1.1


          Running Full Optimizer Benchmark

--- Benchmarking AdamW (PyTorch) ---
Total Time (5000 steps): 31.564 s | Peak Memory: 153.57 MB
Adam-mini found the param block with name: embed.weight torch.Size([1024, 512])
Adam-mini found the param block with name: embed.bias torch.Size([1024])
Adam-mini found the param block with name: q_proj.weight torch.Size([1024, 1024])
Adam-mini found the param block with name: q_proj.bias torch.Size([1024])
Adam-mini found the param block with name: k_proj.weight torch.Size([1024, 1024])
Adam-mini found the param block with name: k_proj.bias torch.Size([1024])
Adam-mini found the param block with name: v_proj.weight torch.Size([1024, 1024])
Adam-mini found the param block with name: v_proj.bias torch.Size([1024])
Adam-mini found th



Total Time (5000 steps): 46.086 s | Peak Memory: 99.13 MB
