In [1]:
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    QuantoConfig,
    LlamaConfig,
    LlamaForCausalLM,
)
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call, grad, vmap
from torch.utils.data import Dataset, DataLoader
import numpy as np
from einops import einsum, pack, rearrange, reduce, repeat, unpack
from einops.layers.torch import Rearrange
from tensordict import TensorDict

from atlas import LlamaMemoryAsLayer, NeuralMemory

In [2]:
# Small randomly-initialised Llama used to smoke-test the memory-as-layer (MAL) wrapper.
config = LlamaConfig(
    vocab_size=32000,
    hidden_size=128,
    intermediate_size=512,
    num_attention_heads=4,
    num_hidden_layers=4,
    max_position_embeddings=256,
)

# NeuralMemory hyperparameters. input/output dims are tied to the model width instead
# of repeating the magic number (presumably LlamaMemoryAsLayer needs them to match the
# hidden size — TODO confirm in atlas).
mal_params = {
    "input_dim": config.hidden_size,
    "hidden_dim": 256,
    "output_dim": config.hidden_size,
    "n_hidden_layers": 3,
    "meta_memory_dim": 16,
    "learning_rate": 4e-4,
    "max_adaptive_lr": 1e-2,
    "num_attention_heads": 4,
    "attention_window_size": 7,
    "n_chunks": 10,
}

# Seed before any random init / sampling so the printed loss is reproducible.
torch.manual_seed(0)

model = LlamaForCausalLM(config)
lmm = NeuralMemory(**mal_params)

# Prefer Apple MPS, then CUDA, then CPU: hard-coding "mps" makes any later
# `.to(device)` fail on machines without that backend.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Replace the second-to-last decoder layer with a memory-augmented wrapper around it,
# and also expose the memory module directly as `model.lmm`.
original_layer = model.model.layers[-2]
model.model.layers[-2] = LlamaMemoryAsLayer(original_layer, lmm)
model.lmm = lmm

batch_size = 4
seq_len = 256

# Random token ids; the forward pass runs on CPU here (the model is not moved to `device`).
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
output = model(input_ids, labels=input_ids)
print(f"Loss with MAL: {output.loss.item()}")

layer_sizes=[(np.int64(128), np.int64(256)), (np.int64(256), np.int64(256)), (np.int64(256), np.int64(256)), (np.int64(256), np.int64(128))]
0
torch.Size([4, 28, 128]) torch.Size([128, 256])
torch.Size([4, 28, 128]) torch.Size([128, 256])
1
torch.Size([4, 28, 256]) torch.Size([256, 256])
torch.Size([4, 28, 256]) torch.Size([256, 256])
2
torch.Size([4, 28, 256]) torch.Size([256, 256])
torch.Size([4, 28, 256]) torch.Size([256, 256])
3
torch.Size([4, 28, 256]) torch.Size([256, 128])
torch.Size([4, 28, 256]) torch.Size([256, 128])
0
torch.Size([4, 280, 128]) torch.Size([128, 256])
torch.Size([4, 280, 128]) torch.Size([128, 256])
1
torch.Size([4, 280, 256]) torch.Size([256, 256])
torch.Size([4, 280, 256]) torch.Size([256, 256])
2
torch.Size([4, 280, 256]) torch.Size([256, 256])
torch.Size([4, 280, 256]) torch.Size([256, 256])
3
torch.Size([4, 280, 256]) torch.Size([256, 128])
torch.Size([4, 280, 256]) torch.Size([256, 128])
Loss with MAL: 10.402139663696289


In [3]:
class MemoryLlama(nn.Module):
    """Pretrained Hugging Face causal LM with a ``NeuralMemory`` module attached.

    Args:
        memory_architecture: How the memory is attached; must be one of
            ``MEMORY_ARCHITECTURES``. Only ``"layer"`` is implemented: it replaces
            the second-to-last decoder layer with ``LlamaMemoryAsLayer`` wrapping it.
        llama_hf_path: Model id / path passed to ``from_pretrained`` for both the
            model and its tokenizer.
        freeze_llama_layers: If True, set ``requires_grad=False`` on every parameter
            of the loaded base model. This happens before the memory is attached, so
            parameters added afterwards are not affected.
        neural_memory_config: Keyword arguments for ``NeuralMemory``. A copy is
            stored in ``self.neural_memory_config`` with ``"input_dim"`` overridden
            by the model's ``hidden_size``; the caller's dict is left unchanged.
        quantize: If True, load the base weights with ``QuantoConfig(weights="int4")``.

    Raises:
        AssertionError: if ``memory_architecture`` is not in ``MEMORY_ARCHITECTURES``.
        NotImplementedError: for accepted-but-unimplemented architectures
            (``"gate"``, ``"context"``), raised before any weights are loaded.
    """

    MEMORY_ARCHITECTURES = ["layer", "gate", "context"]
    # Subset of MEMORY_ARCHITECTURES that is actually wired up below.
    _IMPLEMENTED_ARCHITECTURES = ("layer",)

    def __init__(
        self,
        memory_architecture: str,
        llama_hf_path: str,
        freeze_llama_layers: bool,
        neural_memory_config: dict,
        quantize: bool,
    ):
        super().__init__()

        assert (
            memory_architecture in self.MEMORY_ARCHITECTURES
        ), f"Memory architecture must be one of {self.MEMORY_ARCHITECTURES}, got {memory_architecture}"
        # Fail fast instead of silently returning a plain model with no memory attached.
        if memory_architecture not in self._IMPLEMENTED_ARCHITECTURES:
            raise NotImplementedError(
                f"Memory architecture {memory_architecture!r} is not implemented yet; "
                f"implemented: {list(self._IMPLEMENTED_ARCHITECTURES)}"
            )

        self.memory_architecture = memory_architecture

        self.quantization_config = QuantoConfig(weights="int4") if quantize else None
        self.llama = AutoModelForCausalLM.from_pretrained(
            llama_hf_path,
            quantization_config=self.quantization_config,
            device_map="auto",
        )
        # The tokenizer must come from AutoTokenizer; AutoModelForCausalLM would load
        # a second full copy of the model weights.
        self.tokenizer = AutoTokenizer.from_pretrained(llama_hf_path)
        self.config = self.llama.config

        if freeze_llama_layers:
            for param in self.llama.parameters():
                param.requires_grad = False

        # Copy so the override below does not mutate the caller's dict.
        self.neural_memory_config = dict(neural_memory_config)
        self.neural_memory_config["input_dim"] = self.config.hidden_size
        self.neural_memory = NeuralMemory(**self.neural_memory_config)

        if memory_architecture == "layer":
            original_layer = self.llama.model.layers[-2]
            self.llama.model.layers[-2] = LlamaMemoryAsLayer(
                original_layer, self.neural_memory
            )

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Delegate to the wrapped causal LM; extra kwargs are forwarded unchanged."""
        return self.llama(
            input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs
        )

In [4]:
# NeuralMemory hyperparameters for the Llama-3.2-1B MemoryLlama run below.
# MemoryLlama overrides "input_dim" with the loaded model's hidden_size, so the value
# here is only a placeholder; output_dim presumably has to match that hidden size too
# (the logged activations are 2048-wide) — TODO confirm.
neural_memory_config = dict(
    n_hidden_layers=1,
    meta_memory_dim=32,
    input_dim=2048,
    hidden_dim=256,
    output_dim=2048,
    learning_rate=4e-4,
    max_adaptive_lr=1e-2,
    num_attention_heads=4,
    attention_window_size=7,
    n_chunks=10,
)

In [5]:
# Llama-3.2-1B with int4 weights and frozen base parameters; the neural memory is
# attached using the "layer" architecture.
model = MemoryLlama(
    memory_architecture="layer",
    llama_hf_path="meta-llama/Llama-3.2-1B",
    freeze_llama_layers=True,
    neural_memory_config=neural_memory_config,
    quantize=True,
)

layer_sizes=[(np.int64(2048), np.int64(256)), (np.int64(256), np.int64(2048))]


In [6]:
# Prefer Apple MPS, then CUDA, then CPU: a hard-coded "mps" device makes `.to(device)`
# fail on machines without that backend.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model.to(device)
# `input_ids` are the random token ids sampled in the toy-model cell above.
input_ids = input_ids.to(device)
output = model(input_ids, labels=input_ids)
print(f"Loss with MAL: {output.loss.item()}")

0
torch.Size([4, 29, 2048]) torch.Size([2048, 256])
torch.Size([4, 29, 2048]) torch.Size([2048, 256])
1
torch.Size([4, 29, 256]) torch.Size([256, 2048])
torch.Size([4, 29, 256]) torch.Size([256, 2048])
0
torch.Size([4, 290, 2048]) torch.Size([2048, 256])
torch.Size([4, 290, 2048]) torch.Size([2048, 256])
1
torch.Size([4, 290, 256]) torch.Size([256, 2048])
torch.Size([4, 290, 256]) torch.Size([256, 2048])
Loss with MAL: 14.177559852600098


```python
# Draft Triton kernel for the per-chunk linear recurrence
#     S_t = eta[t] * S_{t-1} - theta_u[t],   S_{-1} = 0,
# writing every S_t to the output (one slice per chunk index t).
#
# Launch layout (from the pid decomposition below): a 1-D grid where each program owns
# one (batch_idx, h_idx) pair with pid = batch_idx * input_dim + h_idx — presumably
# launched with grid = (batch_size * input_dim,); TODO confirm at the call site.
#
# NOTE(review): this lives in a markdown cell (not executed) and likely does not work as
# written — verify before use:
#   * `eta[t]` / `theta_u[t]` index a block tensor with a loop variable; Triton generally
#     does not support scalar/dynamic indexing of block tensors. Loading the per-chunk
#     value directly (e.g. tl.load(eta_base + t * stride_eta_c)) is the usual pattern.
#   * theta_u is loaded with no width offset (there is no stride_theta_w), yet S_t is a
#     BLOCK_SIZE vector stored along stride_out_w using chunk_range. The chunk axis
#     appears to be reused as the output-width axis; TODO confirm the intended layouts of
#     theta_u and output.
@triton.jit
def associative_scan_kernel(
    eta_pointer,
    theta_u_pointer,
    output_pointer,
    batch_size,
    n_chunks,
    input_dim,
    stride_eta_b,
    stride_eta_c,
    stride_theta_b,
    stride_theta_c,
    stride_theta_h,
    stride_out_b,
    stride_out_c,
    stride_out_h,
    stride_out_w,
    BLOCK_SIZE: tl.constexpr,
):
    # Program ID: One block per batch and input_dim pair
    pid = tl.program_id(0)
    batch_idx = pid // input_dim
    h_idx = pid % input_dim  # Row of weight matrix

    # Guards against extra programs when the grid exceeds batch_size * input_dim.
    # (`h_idx >= input_dim` can never be true since h_idx = pid % input_dim.)
    if batch_idx >= batch_size or h_idx >= input_dim:
        return

    # Base pointers for this (batch, row): eta is per batch only; theta_u and output
    # are additionally offset by the row h_idx.
    eta_base = eta_pointer + batch_idx * stride_eta_b
    theta_u_base = theta_u_pointer + batch_idx * stride_theta_b + h_idx * stride_theta_h
    output_base = output_pointer + batch_idx * stride_out_b + h_idx * stride_out_h

    # Load eta and theta_u for all chunks (BLOCK_SIZE must be >= n_chunks to cover them).
    # Masked-out lanes get eta=1.0 and theta_u=0.0, which leave S unchanged under
    # S = eta * S - theta_u.
    chunk_range = tl.arange(0, BLOCK_SIZE)
    mask = chunk_range < n_chunks
    eta = tl.load(eta_base + chunk_range * stride_eta_c, mask=mask, other=1.0)
    theta_u = tl.load(theta_u_base + chunk_range * stride_theta_c, mask=mask, other=0.0)

    # Initialize S_t (float32 accumulator)
    S_t = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

    # Sequential scan over chunks inside each program; parallelism comes only from
    # running one program per (batch, row) pair.
    for t in range(n_chunks):
        S_t_prev = S_t
        S_t = eta[t] * S_t_prev - theta_u[t]
        # Store S_t for this chunk
        tl.store(
            output_base + t * stride_out_c + chunk_range * stride_out_w, S_t, mask=mask
        )
```