In [None]:
from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call, grad, vmap
from torch.utils.data import Dataset, DataLoader
import numpy as np
from einops import einsum, pack, rearrange, reduce, repeat, unpack
from einops.layers.torch import Rearrange
from tensordict import TensorDict

from atlas import ResLinear, SlidingWindowAttention, LinearProjection, AdaptiveLR

In [None]:
dataset = load_dataset("tiny_shakespeare")["train"]["text"][0]
chars = sorted(list(set(dataset)))
vocab_size = len(chars)
char_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_char = {i: ch for i, ch in enumerate(chars)}


class TSDataset(Dataset):
    def __init__(self, text, seq_len=32):
        self.data = torch.tensor([char_to_idx[c] for c in text], dtype=torch.long)
        self.seq_len = seq_len

    def __len__(self):
        return len(self.data) - self.seq_len

    def __getitem__(self, idx):
        input_ids = self.data[idx : idx + self.seq_len]
        labels = self.data[idx + 1 : idx + self.seq_len + 1]
        return input_ids, labels


toy_dataset = TSDataset(dataset)
dataloader = DataLoader(toy_dataset, batch_size=16, shuffle=True)

In [None]:
class NeuralMemory(nn.Module):

    def __init__(
        self,
        layer_size: int,
        input_dim: int,
        n_hidden_layers: int,
        learning_rate: float,
        weight_decay: float,
        max_adaptive_lr: float,
        meta_memory_dim: int,
        num_attention_heads: int,
        attention_window_size: int,
        n_chunks: int,
    ) -> None:
        # TODO: add momentum & past surprises

        # DONE: add SWA
        # DONE: add persistent memory
        # DONE: add adaptive learning rate
        # DONE: add chunking
        # DONE: vectorize the loss

        super(NeuralMemory, self).__init__()
        self.input_dim = input_dim
        self.layer_size = layer_size
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.max_adaptive_lr = max_adaptive_lr
        self.meta_memory_dim = meta_memory_dim
        self.num_attention_heads = num_attention_heads
        self.attention_window_size = attention_window_size
        self.n_chunks = n_chunks

        self.lmm = ResLinear(input_dim, n_hidden_layers)
        self.key_projection = LinearProjection(input_dim, layer_size, n_chunks)
        self.query_projection = LinearProjection(input_dim, layer_size, 1)
        self.value_projection = LinearProjection(input_dim, layer_size, n_chunks)
        self.adaptive_lr_projection = AdaptiveLR(
            input_dim, 1, n_chunks, max_adaptive_lr
        )
        self.meta_memory = nn.Parameter(torch.randn(meta_memory_dim, input_dim))

        self.optimizer = torch.optim.AdamW(
            self.lmm.parameters(), learning_rate, weight_decay=weight_decay
        )
        self.swa = SlidingWindowAttention(
            input_dim, num_attention_heads, attention_window_size
        )

    def _associative_loss(self, params, inputs, targets, weights) -> float:
        preds = functional_call(self.lmm, params, inputs)
        loss = torch.pow(preds - targets, 2).mean(dim=-1)
        print(inputs.shape, weights.shape, targets.shape, loss.shape)
        weighted_loss = loss * weights.squeeze()
        print(weighted_loss.shape)
        return weighted_loss.sum(), loss

    def _inject_meta_memory(self, x: torch.Tensor) -> torch.Tensor:
        batch_size = x.size(0)
        meta_memory = self.meta_memory.expand(batch_size, -1, -1)
        meta_x = torch.concat([meta_memory, x], dim=1)
        return meta_x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.optimizer.zero_grad()
        params = self.lmm.named_parameters()

        x = self._inject_meta_memory(x)

        queries = self.query_projection(x)
        keys = self.key_projection(x)
        values = self.value_projection(x)
        adaptive_lr = self.adaptive_lr_projection(x)

        grad_fn = grad(self._associative_loss, has_aux=True)
        per_chunk_grad_fn = vmap(grad_fn, in_dims=(None, 2, 2, 2))
        grads, _ = per_chunk_grad_fn(dict(params), keys, values, adaptive_lr)
        grads = TensorDict(grads).apply(lambda g: g.mean(0) if g.ndim == 3 else g)
        surprises = grads.mul(-1)  # TODO: store surprises

        for name, param in self.lmm.named_parameters():
            if grads.get(name) is not None:
                param.grad = grads.get(name)
        self.optimizer.step()

        retrieved = self.lmm(queries)
        retrieved = self.swa(retrieved)

        output = retrieved[:, self.meta_memory_dim :, :]  # discard meta-memory

        return output

In [163]:
class LayerWithNeuralMemory(nn.Module):
    def __init__(self, original_layer, mal):
        super().__init__()
        self.original_layer = original_layer
        self.mal = mal

    def forward(self, hidden_states, attention_mask=None, **kwargs):
        output = self.original_layer(
            hidden_states, attention_mask=attention_mask, **kwargs
        )
        attn_output = output[0]
        print(attn_output.shape)
        mal_output = self.mal(attn_output)
        return (mal_output,) + output[1:]


config = LlamaConfig(
    vocab_size=32000,
    hidden_size=128,
    intermediate_size=512,
    num_attention_heads=4,
    num_hidden_layers=4,
    max_position_embeddings=256,
)

mal_params = {
    "layer_size": 128,
    "n_hidden_layers": 2,
    "meta_memory_dim": 16,
    "input_dim": 128,
    "learning_rate": 4e-4,
    "weight_decay": 0.1,
    "max_adaptive_lr": 1e-2,
    "num_attention_heads": 4,
    "attention_window_size": 7,
    "n_chunks": 8
}

model = LlamaForCausalLM(config)
mal = NeuralMemory(**mal_params)


original_layer = model.model.layers[-2]
model.model.layers[-2] = LayerWithNeuralMemory(original_layer, mal)

batch_size = 4
seq_len = 64

input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
output = model(input_ids, labels=input_ids)
print(f"Loss with MAL: {output.loss.item()}")

torch.Size([4, 64, 128])
torch.Size([4, 10, 128]) torch.Size([4, 10, 1]) torch.Size([4, 10, 128]) torch.Size([4, 10])
torch.Size([4, 10])
Loss with MAL: 10.41355037689209
