In [10]:
import math
from typing import Callable, Dict, Hashable, Iterable, Optional, Union

import torch
import torch.nn as nn

Memory Modules: optimizers viewed as associative memories (GD, Momentum, Adam)

In [11]:
class GDMemory:
    """Plain gradient descent seen as a memoryless update rule.

    ``step`` computes ``param - lr * grad`` and keeps no state between calls.
    """

    def __init__(self, lr: float = 1e-2):
        # Coerce to a Python float so ints / numpy scalars behave uniformly.
        self.lr = float(lr)

    def step(self, param: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
        """Return the updated parameter; ``param`` itself is not modified."""
        scaled_grad = self.lr * grad
        return param - scaled_grad

In [12]:

class MomentumMemory:
    """Momentum viewed as an associative memory (per-parameter states).

    Keeps one exponential moving average of gradients per parameter,
    ``m <- beta * m + (1 - beta) * grad``, and returns ``param - lr * m``.

    State is keyed by ``key`` when given, otherwise by ``id(param)``. Because
    ``step`` returns a *new* tensor, callers that rebind the parameter
    (``p = mem.step(p, g)``) should pass a stable ``key``. The returned
    tensor has a different id, and CPython may recycle the old id for an
    unrelated tensor, so id-keyed state would silently restart or be reused.
    """

    def __init__(self, lr: float = 1e-2, beta: float = 0.9, device=None):
        self.lr = float(lr)
        self.beta = float(beta)
        # Momentum buffers, one per parameter key.
        self.state: Dict[Hashable, torch.Tensor] = {}
        # Kept for API compatibility; buffers are created on ``grad``'s device.
        self.device = device

    def step(self, param: torch.Tensor, grad: torch.Tensor,
             key: Optional[Hashable] = None) -> torch.Tensor:
        """Update the momentum buffer for ``param`` and return the new parameter.

        Args:
            param: current parameter value (not modified in place).
            grad: gradient for ``param``.
            key: optional stable identifier for this parameter's state.
                Defaults to ``id(param)``.

        Returns:
            ``param - lr * m`` where ``m`` is the updated momentum buffer.
        """
        pid = id(param) if key is None else key
        m = self.state.get(pid)
        # Start from zeros on first use, or when the stored buffer no longer
        # matches the gradient's shape (e.g. a recycled id), rather than
        # broadcasting stale state into the update.
        if m is None or m.shape != grad.shape:
            m = torch.zeros_like(grad)
        m = self.beta * m + (1.0 - self.beta) * grad
        self.state[pid] = m
        return param - self.lr * m

In [14]:
class AdamMemory:
    """Adam reinterpreted as multi-level memory (m, v) per param.

    ``m`` is an EMA of gradients (rate ``b1``) and ``v`` an EMA of squared
    gradients (rate ``b2``). Both are bias-corrected with the per-key step
    count ``t`` before forming the update ``lr * m_hat / (sqrt(v_hat) + eps)``.

    State is keyed by ``key`` when given, otherwise by ``id(param)``. Pass a
    stable ``key`` when rebinding the parameter to the returned tensor
    (``p = mem.step(p, g)``). The returned tensor has a new id, so id-keyed
    state would restart, and CPython may recycle old ids for other tensors.
    """

    def __init__(self, lr: float = 1e-3, b1: float = 0.9, b2: float = 0.999, eps: float = 1e-8):
        self.lr = float(lr)
        self.b1 = float(b1)
        self.b2 = float(b2)
        self.eps = float(eps)
        # First/second moment buffers per parameter key.
        self.state: Dict[Hashable, Dict[str, torch.Tensor]] = {}
        # Step counter per parameter key (drives bias correction).
        self.t: Dict[Hashable, int] = {}

    def step(self, param: torch.Tensor, grad: torch.Tensor,
             key: Optional[Hashable] = None) -> torch.Tensor:
        """Apply one Adam update and return the new parameter.

        Args:
            param: current parameter value (not modified in place).
            grad: gradient for ``param``.
            key: optional stable identifier for this parameter's state.
                Defaults to ``id(param)``.

        Returns:
            ``param - lr * m_hat / (sqrt(v_hat) + eps)``.
        """
        pid = id(param) if key is None else key
        s = self.state.get(pid)
        # (Re)initialise on first use, or if the key now refers to a tensor of
        # a different shape (e.g. a recycled id), instead of broadcasting stale
        # moments into the update.
        if s is None or s['m'].shape != grad.shape:
            s = {
                'm': torch.zeros_like(grad),
                'v': torch.zeros_like(grad),
            }
            self.state[pid] = s
            self.t[pid] = 0
        self.t[pid] += 1
        t = self.t[pid]
        s['m'] = self.b1 * s['m'] + (1 - self.b1) * grad
        s['v'] = self.b2 * s['v'] + (1 - self.b2) * (grad * grad)
        # Bias correction compensates for the zero initialisation of m and v.
        m_hat = s['m'] / (1 - self.b1 ** t)
        v_hat = s['v'] / (1 - self.b2 ** t)
        update = self.lr * m_hat / (torch.sqrt(v_hat) + self.eps)
        return param - update

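Two NL-inspired extensions from the paper: Deep Momentum (an MLP as the momentum memory) and Preconditioned Momentum (gradients transformed by a preconditioner before being stored)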

In [17]:
class DeepMomentumMemory(nn.Module):
    """
    Deep Momentum: the momentum memory is an MLP that maps the incoming gradient
    (or gradient features) to a momentum-like update. This adds representational power.
    This class is a Module so it can be trained jointly if you want (requires hooking into
    a higher-level training loop).

    ``step`` flattens the gradient and computes a momentum signal from it:
    - if the flat size equals ``dim`` (the MLP's input width), the signal is
      ``mlp(flat_grad)``;
    - otherwise it falls back to the elementwise placeholder ``tanh(flat_grad)``.
    The signal is then folded into a per-parameter EMA
    ``m <- beta * m + (1 - beta) * signal``, and ``param - lr * m`` is returned
    in ``param``'s shape.
    """
    def __init__(self, dim, hidden=64, lr=1e-2, beta=0.9):
        super().__init__()
        self.lr = float(lr)
        self.beta = float(beta)
        # small MLP to "compress" gradients into a learned momentum signal
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, dim)
        )
        # Per-parameter momentum buffers (flat), keyed by ``key`` or id(param).
        self.state: Dict[Hashable, torch.Tensor] = {}

    def _momentum_signal(self, flat_grad: torch.Tensor) -> torch.Tensor:
        """Map a flat gradient to the raw signal fed into the momentum EMA."""
        if flat_grad.shape[0] == self.mlp[0].in_features:
            # MLP path: [N] -> [1, N] -> mlp -> [N]
            return self.mlp(flat_grad.unsqueeze(0)).squeeze(0)
        # Fallback when the flat size does not match the MLP width:
        # an elementwise placeholder compression.
        return torch.tanh(flat_grad)

    def step(self, param: torch.Tensor, grad: torch.Tensor,
             key: Optional[Hashable] = None) -> torch.Tensor:
        """Update this parameter's momentum and return the new parameter.

        Args:
            param: current parameter value (not modified in place).
            grad: gradient for ``param``; must have the same number of elements.
                It may be non-contiguous (it is flattened with ``reshape``).
            key: optional stable identifier for this parameter's state.
                Defaults to ``id(param)``. Pass one when rebinding the parameter
                to the returned tensor, since that tensor has a new id.

        Returns:
            ``param - lr * m`` reshaped to ``param``'s shape.
        """
        pid = id(param) if key is None else key
        flat_grad = grad.reshape(-1)
        signal = self._momentum_signal(flat_grad)
        prev = self.state.get(pid)
        # Fresh zeros on first use or if the stored buffer no longer matches
        # (e.g. a recycled id pointing at a different-sized tensor).
        if prev is None or prev.shape != signal.shape:
            prev = torch.zeros_like(signal)
        m = self.beta * prev + (1.0 - self.beta) * signal
        # NOTE(review): on the MLP path ``m`` keeps its autograd history, so the
        # stored state chains graphs across steps; detach it if the MLP is not
        # being trained through multiple steps.
        self.state[pid] = m
        # The update is applied to the parameter, not to the gradient.
        new_flat = param.reshape(-1) - self.lr * m
        return new_flat.view_as(param)

In [18]:
class PreconditionedMomentumMemory:
    """
    Momentum memory that applies a learnable or provided preconditioner P to gradients before storing.
    For simplicity we show diagonal preconditioner (vector), but could be matrix or low-rank.

    Supported ``preconditioner`` values:
    - ``None``: gradients are used unchanged;
    - a callable ``f(grad) -> tensor``: applied to each gradient;
    - a Python int/float: the gradient is scaled by it;
    - a sequence/tensor of floats: used as a diagonal, i.e. multiplied
      elementwise. A vector with the same number of elements as the gradient
      is reshaped to the gradient's shape. A single-element value acts as a
      scalar. Any other shape follows torch broadcasting rules.
    """
    def __init__(self, lr=1e-2, beta=0.9,
                 preconditioner: Optional[Union[float, Iterable[float], torch.Tensor,
                                                Callable[[torch.Tensor], torch.Tensor]]] = None):
        self.lr = float(lr)
        self.beta = float(beta)
        # Momentum buffers, one per parameter key (``key`` or id(param)).
        self.state: Dict[Hashable, torch.Tensor] = {}
        self.pre = preconditioner

    def _apply_pre(self, grad: torch.Tensor) -> torch.Tensor:
        """Return ``grad`` transformed by the configured preconditioner."""
        pre = self.pre
        if pre is None:
            return grad
        if callable(pre):
            return pre(grad)
        if isinstance(pre, (int, float)):
            return grad * float(pre)
        # Diagonal preconditioner: materialise on grad's dtype/device.
        diag = torch.as_tensor(pre, dtype=grad.dtype, device=grad.device)
        if diag.numel() == 1:
            # Behave like a scalar so the result keeps grad's shape.
            diag = diag.reshape(())
        elif diag.numel() == grad.numel() and diag.shape != grad.shape:
            # Allow a flat diagonal for a multi-dimensional parameter.
            diag = diag.reshape(grad.shape)
        return grad * diag

    def step(self, param: torch.Tensor, grad: torch.Tensor,
             key: Optional[Hashable] = None) -> torch.Tensor:
        """Store the preconditioned gradient in momentum and return the new param.

        Args:
            param: current parameter value (not modified in place).
            grad: gradient for ``param``.
            key: optional stable identifier for this parameter's state.
                Defaults to ``id(param)``. Pass one when rebinding the parameter
                to the returned tensor, since that tensor has a new id.

        Returns:
            ``param - lr * m`` with ``m <- beta * m + (1 - beta) * P(grad)``.
        """
        grad_pre = self._apply_pre(grad)
        pid = id(param) if key is None else key
        m = self.state.get(pid)
        # Fresh zeros on first use or if the stored buffer no longer matches.
        if m is None or m.shape != grad_pre.shape:
            m = torch.zeros_like(grad_pre)
        m = self.beta * m + (1.0 - self.beta) * grad_pre
        self.state[pid] = m
        return param - self.lr * m