In [1]:
%%capture
%pip install datasets

Collecting datasets
  Downloading datasets-3.3.2-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Downloading datasets-3.3.2-py3-none-any.whl (485 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m485.4/485.4 kB[0m [31m8.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading multiprocess-0.70.16-py311-none-any.whl (143 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.5/143.5 kB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading xx

In [2]:
from transformers import LlamaConfig, LlamaForCausalLM
from datasets  import load_dataset
from typing import Callable
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call, grad, vmap
from torch.utils.data import Dataset, DataLoader
import numpy as np
from einops import einsum, pack, rearrange, reduce, repeat, unpack
from einops.layers.torch import Rearrange

In [3]:
config = LlamaConfig(
    vocab_size=32000,
    hidden_size=128,
    intermediate_size=512,
    num_attention_heads=4,
    num_hidden_layers=4,
    max_position_embeddings=256,
)

model = LlamaForCausalLM(config)
print(model.num_parameters())

9241728


In [4]:
dataset = load_dataset("tiny_shakespeare")["train"]["text"][0]
chars = sorted(list(set(dataset)))
vocab_size = len(chars)
char_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_char = {i: ch for i, ch in enumerate(chars)}


class TSDataset(Dataset):
    def __init__(self, text, seq_len=32):
        self.data = torch.tensor([char_to_idx[c] for c in text], dtype=torch.long)
        self.seq_len = seq_len

    def __len__(self):
        return len(self.data) - self.seq_len

    def __getitem__(self, idx):
        input_ids = self.data[idx : idx + self.seq_len]
        labels = self.data[idx + 1 : idx + self.seq_len + 1]
        return input_ids, labels


toy_dataset = TSDataset(dataset)
dataloader = DataLoader(toy_dataset, batch_size=16, shuffle=True)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/6.10k [00:00<?, ?B/s]

tiny_shakespeare.py:   0%|          | 0.00/3.73k [00:00<?, ?B/s]

The repository for tiny_shakespeare contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/tiny_shakespeare.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] y


Downloading data:   0%|          | 0.00/1.12M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1 [00:00<?, ? examples/s]

In [5]:
for input_ids, labels in dataloader:
    output = model(input_ids, labels=labels)
    print(f"Loss: {output['loss'].item()}")
    break

Loss: 10.36716365814209


In [15]:
class ResLinear(nn.Module):
    """Residual MLP with SiLU activation."""

    def __init__(self, layer_size: int, num_layers: int):
        super(ResLinear, self).__init__()
        dims = np.tile([layer_size], num_layers)
        self.weights = nn.ParameterList([nn.Parameter(torch.randn(d, d)) for d in dims])
        for w in self.weights:
            nn.init.xavier_uniform_(w)

    def forward(self, x: torch.Tensor):
        for idx, w in enumerate(self.weights):
            first_layer = idx == 0
            if not first_layer:
                x = F.silu(x)
            residual = x
            x = x @ w + residual

        return x


class LinearProjection(nn.Module):
    """Linear Layer with no bias."""

    def __init__(self, in_dim: int, out_dim: int) -> None:
        super(LinearProjection, self).__init__()
        self.linear = nn.Linear(in_dim, out_dim, bias=False)
        nn.init.xavier_uniform_(self.linear.weight)

    def forward(self, x: torch.Tensor):
        return F.silu(self.linear(x))


class MemoryAsLayer(nn.Module):
    def __init__(
        self,
        layer_size: int,
        input_dim: int,
        n_hidden_layers: int,
        learning_rate: float,
        weight_decay: float,
    ) -> None:
        # TODO: add chunking
        # TODO: add multihead processing
        # TODO: add adaptive learning rate, momentum
        # TODO: add persistent memory
        super(MemoryAsLayer, self).__init__()
        self.input_dim = input_dim
        self.layer_size = layer_size
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

        self.lmm = ResLinear(input_dim, n_hidden_layers)
        self.key_projection = LinearProjection(input_dim, layer_size)
        self.query_projection = LinearProjection(input_dim, layer_size)
        self.value_projection = LinearProjection(input_dim, layer_size)

        self.optimizer = torch.optim.AdamW(
            self.lmm.parameters(), self.learning_rate, weight_decay=self.weight_decay
        )

    def _associative_loss(self, params, inputs, targets, weights) -> float:
        pred = torch.func.functional_call(self.lmm, params, inputs)
        loss = torch.pow(pred - targets, 2).mean(dim=-1)

        if weights == None:
            weights = torch.ones_like(loss)  # TODO: pass actual weights

        weighted_loss = loss * weights
        return weighted_loss.sum(), loss

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        params = self.lmm.named_parameters()
        keys = self.key_projection(x)
        queries = self.query_projection(x)
        values = self.value_projection(x)

        keys_flat = keys.view(-1, self.layer_size)
        values_flat = values.view(-1, self.layer_size)

        grad_fn = grad(self._associative_loss, has_aux=True)
        grads, unweighted_loss = self._associative_loss(
            dict(params), keys_flat, values_flat, None
        )
        self.optimizer.zero_grad()

        with torch.no_grad():
          for name, param in params:
            if grads[name] is not None:
              param.grad = grads[name]

        self.optimizer.step()

        surprises = grads.mul(-1)
        retrieved = self.lmm(queries)

        return retrieved, surprises

In [16]:
mal_params = {
    "layer_size":128,
    "n_hidden_layers":2,
    "input_dim":128,
    "learning_rate":4e-4,
    "weight_decay":0.1
}
mal = MemoryAsLayer(**mal_params)

target_layer = model.model.layers[-2]
def mal_forward_hook(module, input, output):
    attn_output = output[0]  # (batch_size, seq_len, hidden_size)
    return mal(attn_output)

handle = target_layer.register_forward_hook(mal_forward_hook)

# Test forward pass
input_ids = torch.randint(0, config.vocab_size, (2, 32))
output = model(input_ids, labels=input_ids)
print(f"Loss with MAL: {output.loss.item()}")
handle.remove()

Loss with MAL: 10.41081714630127
