In [1]:
from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call, grad, vmap
from torch.utils.data import Dataset, DataLoader
import numpy as np
from einops import einsum, pack, rearrange, reduce, repeat, unpack
from einops.layers.torch import Rearrange
from tensordict import TensorDict
import triton
import triton.language as tl

from atlas import ResLinear, SlidingWindowAttention, LinearProjection, AdaptiveWeight

In [2]:
dataset = load_dataset("tiny_shakespeare")["train"]["text"][0]
chars = sorted(list(set(dataset)))
vocab_size = len(chars)
char_to_idx = {ch: i for i, ch in enumerate(chars)}


class TSDataset(Dataset):
    def __init__(self, text, seq_len=32):
        self.data = torch.tensor([char_to_idx[c] for c in text], dtype=torch.long)
        self.seq_len = seq_len

    def __len__(self):
        return len(self.data) - self.seq_len

    def __getitem__(self, idx):
        input_ids = self.data[idx : idx + self.seq_len]
        labels = self.data[idx + 1 : idx + self.seq_len + 1]
        return input_ids, labels


toy_dataset = TSDataset(dataset)
dataloader = DataLoader(toy_dataset, batch_size=16, shuffle=True)

In [None]:
# class NeuralMemory(nn.Module):

#     def __init__(
#         self,
#         layer_size: int,
#         input_dim: int,
#         n_hidden_layers: int,
#         learning_rate: float,
#         weight_decay: float,
#         momentum: float,
#         max_adaptive_lr: float,
#         meta_memory_dim: int,
#         num_attention_heads: int,
#         attention_window_size: int,
#         chunk_size: int,
#     ) -> None:
#         # TODO: add momentum, past surprises, associative scan

#         # DONE: add SWA
#         # DONE: add persistent memory
#         # DONE: add adaptive learning rate
#         # DONE: add chunking
#         # DONE: pad inputs and replace n_chunks by chunk_size
#         # DONE: vectorize the loss

#         super(NeuralMemory, self).__init__()
#         self.input_dim = input_dim
#         self.layer_size = layer_size
#         self.learning_rate = learning_rate
#         self.weight_decay = weight_decay
#         self.momentum = momentum
#         self.max_adaptive_lr = max_adaptive_lr
#         self.meta_memory_dim = meta_memory_dim
#         self.num_attention_heads = num_attention_heads
#         self.attention_window_size = attention_window_size
#         self.chunk_size = chunk_size
#         self.added_padding = 0

#         self.lmm = ResLinear(input_dim, n_hidden_layers)
#         self.key_projection = LinearProjection(input_dim, layer_size, chunk_size)
#         self.query_projection = LinearProjection(input_dim, layer_size, 1)
#         self.value_projection = LinearProjection(input_dim, layer_size, chunk_size)
#         self.adaptive_lr_projection = AdaptiveLR(
#             input_dim, 1, chunk_size, max_adaptive_lr
#         )
#         self.meta_memory = nn.Parameter(torch.randn(meta_memory_dim, input_dim))

#         self.optimizer = torch.optim.AdamW(
#             self.lmm.parameters(), learning_rate, weight_decay=weight_decay
#         )
#         self.swa = SlidingWindowAttention(
#             input_dim, num_attention_heads, attention_window_size
#         )
#         self.register_buffer("surprises", torch.zeros((input_dim, input_dim)))

#     def _pad_to_chunk_size(self, x: torch.Tensor) -> torch.Tensor:
#         seq_len = x.size(1)
#         if not seq_len % self.chunk_size == 0:
#             pad = (
#                 (seq_len // self.chunk_size) * self.chunk_size
#                 + self.chunk_size
#                 - seq_len
#             )
#             self.added_padding = pad
#             x = F.pad(x, (0, 0, 0, pad))
#         return x

#     def _associative_loss(self, params, inputs, targets, weights) -> float:
#         preds = functional_call(self.lmm, params, inputs)
#         loss = torch.pow(preds - targets, 2).mean(dim=-1)
#         weighted_loss = loss * weights.squeeze()
#         print("LOSS:", preds.shape, weights.shape, targets.shape, loss.shape)
#         return weighted_loss.sum(), loss

#     def _inject_meta_memory(self, x: torch.Tensor) -> torch.Tensor:
#         batch_size = x.size(0)
#         meta_memory = self.meta_memory.expand(batch_size, -1, -1)
#         meta_x = torch.concat([meta_memory, x], dim=1)
#         return meta_x

#     def forward(self, x: torch.Tensor) -> torch.Tensor:
#         self.optimizer.zero_grad()
#         params = self.lmm.named_parameters()

#         x = self._inject_meta_memory(x)
#         x = self._pad_to_chunk_size(x)

#         queries = self.query_projection(x)
#         keys = self.key_projection(x)
#         values = self.value_projection(x)
#         adaptive_lr = self.adaptive_lr_projection(x)

#         grad_fn = grad(self._associative_loss, has_aux=True)
#         per_chunk_grad_fn = vmap(grad_fn, in_dims=(None, 2, 2, 2))
#         print("PRE LOSS: ", keys.shape, adaptive_lr.shape, values.shape)
#         grads, _ = per_chunk_grad_fn(dict(params), keys, values, adaptive_lr)

#         grads = TensorDict(grads).apply(lambda g: g.mean(0) if g.ndim == 3 else g)
#         print(grads)
#         surprises = grads.mul(-1)  # TODO: store past surprises

#         for name, param in self.lmm.named_parameters():
#             if grads.get(name) is not None:
#                 param.grad = grads.get(name)
#         self.optimizer.step()

#         retrieved = self.lmm(queries)
#         retrieved = self.swa(retrieved)

#         output = retrieved[
#             :, self.meta_memory_dim : -self.added_padding, :
#         ]  # discard meta-memory and padding

#         return output

In [5]:
class NeuralMemory(nn.Module):

    def __init__(
        self,
        layer_size: int,
        input_dim: int,
        n_layers: int,
        learning_rate: float,
        weight_decay: float,
        max_adaptive_lr: float,
        max_momentum: float,
        meta_memory_dim: int,
        num_attention_heads: int,
        attention_window_size: int,
        n_chunks: int,
    ) -> None:
        # TODO: implement associative scan kernel
        # TODO: add learned gating

        # DONE: add SWA
        # DONE: add persistent memory
        # DONE: add adaptive learning rate
        # DONE: add chunking
        # DONE: pad inputs and replace n_chunks by chunk_size
        # DONE: vectorize the loss
        # DONE: add momentum, past surprises

        super(NeuralMemory, self).__init__()
        self.input_dim = input_dim
        self.layer_size = layer_size
        self.n_layers = n_layers
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.max_adaptive_lr = max_adaptive_lr
        self.max_momentum = max_momentum
        self.meta_memory_dim = meta_memory_dim
        self.num_attention_heads = num_attention_heads
        self.attention_window_size = attention_window_size
        self.n_chunks = n_chunks
        self.added_padding = None

        self.lmm = ResLinear(input_dim, n_layers)
        self.key_projection = LinearProjection(input_dim, layer_size, n_chunks)
        self.query_projection = LinearProjection(input_dim, layer_size, 1)
        self.value_projection = LinearProjection(input_dim, layer_size, n_chunks)
        self.adaptive_lr_projection = AdaptiveWeight(
            input_dim, 1, n_chunks, max_adaptive_lr
        )
        self.adaptive_momentum_projection = AdaptiveWeight(
            input_dim, 1, n_chunks, max_weight=1.0
        )
        self.meta_memory = nn.Parameter(torch.randn(meta_memory_dim, input_dim))

        self.optimizer = torch.optim.AdamW(
            self.lmm.parameters(), learning_rate, weight_decay=weight_decay
        )
        self.swa = SlidingWindowAttention(
            input_dim, num_attention_heads, attention_window_size
        )

        self.register_buffer(
            "surprises",
            torch.zeros((self.n_layers, self.n_chunks, self.input_dim, self.input_dim)),
        )

    def _pad_to_chunk_size(self, x: torch.Tensor) -> torch.Tensor:
        seq_len = x.size(1)
        print(seq_len, self.n_chunks)
        if not seq_len % self.n_chunks == 0:
            pad = (seq_len // self.n_chunks) * self.n_chunks + self.n_chunks - seq_len
            self.added_padding = -pad
            x = F.pad(x, (0, 0, 0, pad))
        return x

    def _inject_meta_memory(self, x: torch.Tensor) -> torch.Tensor:
        batch_size = x.size(0)
        meta_memory = self.meta_memory.expand(batch_size, -1, -1)
        meta_x = torch.concat([meta_memory, x], dim=1)
        return meta_x

    def _associative_memory_loss(self, params, inputs, targets, weights) -> float:
        preds = functional_call(self.lmm, params, inputs)
        loss = torch.pow(preds - targets, 2).mean(dim=-1)
        weighted_loss = loss * weights.squeeze()
        return weighted_loss.sum(), loss

    @torch.no_grad
    def _compute_surprises(
        self, theta_t: torch.Tensor, eta_t, per_sample_grads: TensorDict
    ) -> None:
        eta_prod = torch.cumprod(eta_t, dim=1).flip(1)
        print(eta_prod)
        for layer_idx in range(self.n_layers):
            layer_id = f"weights.{layer_idx}"
            if layer_id in per_sample_grads:
                u_t = per_sample_grads[layer_id]  # (n_chunks, input_dim, input_dim)
                # Vectorized S_t = -sum(theta_t * u_t * eta_prod)
                theta_u = (
                    theta_t.unsqueeze(-1).unsqueeze(-1) * u_t
                )  # (batch_size, n_chunks, input_dim, input_dim)
                weighted_theta_u = theta_u * eta_prod.unsqueeze(-1).unsqueeze(-1)
                S_t = -torch.cumsum(
                    weighted_theta_u.mean(dim=0), dim=0
                )  # (n_chunks, input_dim, input_dim)
                self.surprises[layer_idx] = S_t

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.optimizer.zero_grad()
        params = self.lmm.named_parameters()

        x = self._inject_meta_memory(x)
        x = self._pad_to_chunk_size(x)

        queries = self.query_projection(x)
        keys = self.key_projection(x)
        values = self.value_projection(x)
        adaptive_lr = self.adaptive_lr_projection(x)
        adaptive_momentum = self.adaptive_momentum_projection(x)

        grad_fn = grad(self._associative_memory_loss, has_aux=True)
        per_chunk_grad_fn = vmap(grad_fn, in_dims=(None, 2, 2, 2))
        per_sample_grads, _ = per_chunk_grad_fn(dict(params), keys, values, adaptive_lr)
        per_sample_grads = TensorDict(per_sample_grads)

        theta_t = adaptive_lr.mean(dim=1).squeeze(-1)
        eta_t = adaptive_momentum.mean(dim=1).squeeze(-1)
        self._compute_surprises(theta_t, eta_t, per_sample_grads)

        for idx, (name, param) in enumerate(self.lmm.named_parameters()):
            if per_sample_grads.get(name) is not None:
                param.grad = (per_sample_grads.get(name) + self.surprises[idx]).mean(0)
        self.optimizer.step()

        retrieved = self.lmm(queries.squeeze())
        retrieved = self.swa(retrieved)

        # discard meta-memory and padding
        output = retrieved[:, self.meta_memory_dim : self.added_padding, :]

        return output

In [6]:
class LayerWithNeuralMemory(nn.Module):
    def __init__(self, original_layer, mal):
        super().__init__()
        self.original_layer = original_layer
        self.mal = mal

    def forward(self, hidden_states, attention_mask=None, **kwargs):
        output = self.original_layer(
            hidden_states, attention_mask=attention_mask, **kwargs
        )
        attn_output = output[0]
        print(attn_output.shape)
        mal_output = self.mal(attn_output)
        return (mal_output,) + output[1:]


config = LlamaConfig(
    vocab_size=32000,
    hidden_size=128,
    intermediate_size=512,
    num_attention_heads=4,
    num_hidden_layers=4,
    max_position_embeddings=256,
)

mal_params = {
    "layer_size": 128,
    "n_layers": 3,
    "meta_memory_dim": 16,
    "input_dim": 128,
    "learning_rate": 4e-4,
    "weight_decay": 0.1,
    "max_momentum": 0.9,
    "max_adaptive_lr": 1e-2,
    "num_attention_heads": 4,
    "attention_window_size": 7,
    "n_chunks": 10,
}

model = LlamaForCausalLM(config)
mal = NeuralMemory(**mal_params)


original_layer = model.model.layers[-2]
model.model.layers[-2] = LayerWithNeuralMemory(original_layer, mal)

batch_size = 4
seq_len = 64

input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
output = model(input_ids, labels=input_ids)
print(f"Loss with MAL: {output.loss.item()}")

torch.Size([4, 64, 128])
80 10
tensor([[0.0014, 0.0027, 0.0053, 0.0097, 0.0197, 0.0390, 0.0728, 0.1453, 0.2627,
         0.5014],
        [0.0013, 0.0026, 0.0051, 0.0094, 0.0193, 0.0383, 0.0716, 0.1443, 0.2614,
         0.5021],
        [0.0011, 0.0022, 0.0044, 0.0083, 0.0173, 0.0351, 0.0668, 0.1371, 0.2526,
         0.4925],
        [0.0012, 0.0024, 0.0047, 0.0087, 0.0180, 0.0361, 0.0683, 0.1388, 0.2543,
         0.4946]])
Loss with MAL: 10.41256332397461


In [None]:
@triton.jit
def associative_scan_kernel(
    eta_pointer,
    theta_u_pointer,
    output_pointer,
    batch_size,
    n_chunks,
    input_dim,
    stride_eta_b,
    stride_eta_c,
    stride_theta_b,
    stride_theta_c,
    stride_theta_h,
    stride_out_b,
    stride_out_c,
    stride_out_h,
    stride_out_w,
    BLOCK_SIZE: tl.constexpr,
):
    # Program ID: One block per batch and input_dim pair
    pid = tl.program_id(0)
    batch_idx = pid // input_dim
    h_idx = pid % input_dim  # Row of weight matrix

    if batch_idx >= batch_size or h_idx >= input_dim:
        return

    # Base pointers
    eta_base = eta_pointer + batch_idx * stride_eta_b
    theta_u_base = theta_u_pointer + batch_idx * stride_theta_b + h_idx * stride_theta_h
    output_base = output_pointer + batch_idx * stride_out_b + h_idx * stride_out_h

    # Load eta and theta_u for all chunks
    chunk_range = tl.arange(0, BLOCK_SIZE)
    mask = chunk_range < n_chunks
    eta = tl.load(eta_base + chunk_range * stride_eta_c, mask=mask, other=1.0)
    theta_u = tl.load(theta_u_base + chunk_range * stride_theta_c, mask=mask, other=0.0)

    # Initialize S_t
    S_t = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

    # Associative scan: Sequential within block, parallel across blocks
    for t in range(n_chunks):
        S_t_prev = S_t
        S_t = eta[t] * S_t_prev - theta_u[t]
        # Store S_t for this chunk
        tl.store(
            output_base + t * stride_out_c + chunk_range * stride_out_w, S_t, mask=mask
        )

In [None]:
associative_scan_kernel[grid](
    momentum,
    adaptive_grads,
    output,
    batch_size,
    mal.n_chunks,
    mal.input_dim,
    momentum_strides[0],
    momentum_strides[1],
    ag_strides[0],
    ag_strides[1],
    ag_strides[2],
    output_strides[0],
    output_strides[1],
    output_strides[2],
    output_strides[3],
    BLOCK_SIZE=mal.n_chunks,  # Match n_chunks for simplicity
)

TypeError: unsupported operand type(s) for %: 'TensorDict' and 'int'