In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class SelfModifyingMLP(nn.Module):
    """
    Linear layer whose weight is modified on the fly by a rank-1 update.

    Implements W' = W + u(h) @ v(h)^T, where u and v are bias-free linear
    projections of a fast context signal h, averaged over the batch.

    Shapes:
        x:        [..., D]   input fed through the modified linear layer
        fast_h:   [B, D]     fast context (e.g. from CMS or attention);
                             an unbatched [D] vector or extra leading dims
                             are also accepted and averaged over.

    Args:
        dim: feature size D (the base layer maps D -> D).
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

        # base slow weight
        self.base = nn.Linear(dim, dim)

        # fast update projections
        self.to_u = nn.Linear(dim, dim, bias=False)
        self.to_v = nn.Linear(dim, dim, bias=False)

    def forward(self, x, fast_h):
        """
        Apply the base layer with a rank-1 modified weight.

        Args:
            x:        [..., D] input.
            fast_h:   [B, D] fast context signal. Its leading size does not
                      have to match x, because the update is averaged into
                      a single ΔW shared by every row of x.

        Returns:
            (out, delta): out has the same leading shape as x with D
            features; delta is the detached [D, D] update ΔW, returned for
            inspection only (no gradient flows through the returned copy).
        """
        # compute low-rank update vectors
        u = self.to_u(fast_h)      # [B, D]
        v = self.to_v(fast_h)      # [B, D]

        # average updates across batch (simple, stable); flattening first
        # keeps [B, D] behaviour unchanged and also accepts [D] or [B, T, D]
        u_mean = u.reshape(-1, self.dim).mean(dim=0)     # [D]
        v_mean = v.reshape(-1, self.dim).mean(dim=0)     # [D]

        # outer product gives ΔW (rank-1); torch.outer replaces the
        # deprecated alias torch.ger
        delta = torch.outer(u_mean, v_mean)   # [D, D]

        # apply the modified weight: W' = W + ΔW
        # (the sum is differentiable w.r.t. base, to_u and to_v)
        W_mod = self.base.weight + delta

        out = F.linear(x, W_mod, self.base.bias)
        return out, delta.detach()   # return ΔW for inspection


Paper-faithful “rank-k” version (faster and more expressive).
The real model uses rank-k updates, typically with k = 4 or 8.

In [3]:
class SelfModifyingMLP_RankK(nn.Module):
    """
    Rank-k self modifying MLP:
    ΔW = sum_i u_i ⊗ v_i

    The k pairs (u_i, v_i) are produced by bias-free linear projections of
    a fast context signal and averaged over its batch, giving one [D, D]
    update that is added to the base layer's weight.

    Args:
        dim:  feature size D (the base layer maps D -> D).
        rank: number k of rank-1 terms summed into ΔW.
    """
    def __init__(self, dim, rank=4):
        super().__init__()
        self.dim = dim
        self.rank = rank

        self.base = nn.Linear(dim, dim)

        # project to a (rank * dim) vector → reshaped into (k, dim) in forward
        self.to_u = nn.Linear(dim, dim * rank, bias=False)
        self.to_v = nn.Linear(dim, dim * rank, bias=False)

    def forward(self, x, fast_h):
        """
        Apply the base layer with a rank-k modified weight.

        Args:
            x:        [..., D] input.
            fast_h:   [B, D] fast context signal. Its batch size is
                      independent of x: the update is averaged into one ΔW
                      shared by every row of x.

        Returns:
            (out, delta): out has the same leading shape as x with D
            features; delta is the detached [D, D] update ΔW for inspection.
        """
        # Take the leading size from fast_h itself (-1) rather than from x,
        # so a context batch that differs from x's batch does not break the
        # reshape.
        u = self.to_u(fast_h).reshape(-1, self.rank, self.dim)   # [B, k, D]
        v = self.to_v(fast_h).reshape(-1, self.rank, self.dim)   # [B, k, D]

        # batch average
        u_mean = u.mean(dim=0)    # [k, D]
        v_mean = v.mean(dim=0)    # [k, D]

        # compute ΔW: delta[d, e] = sum_k u_mean[k, d] * v_mean[k, e]
        delta = torch.einsum("kd,ke->de", u_mean, v_mean)   # [D, D]

        # W' = W + ΔW, then the usual affine map with the base bias
        W_mod = self.base.weight + delta
        return F.linear(x, W_mod, self.base.bias), delta.detach()
