In [None]:
import torch
import torch.nn as nn
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LlamaConfig,
    LlamaForCausalLM,
    QuantoConfig,
)

from hercules import LlamaMemoryAsLayer, NeuralMemory

In [2]:
# Seed so the random weight init and the random input_ids below are reproducible.
torch.manual_seed(0)

config = LlamaConfig(
    vocab_size=32000,
    hidden_size=128,
    intermediate_size=512,
    num_attention_heads=4,
    num_hidden_layers=4,
    max_position_embeddings=256,
)

# input_dim / output_dim mirror the backbone width (128) and are derived from the
# config so they stay in sync if hidden_size changes.
# NOTE(review): assumes LlamaMemoryAsLayer requires memory in/out width == hidden_size.
mal_params = {
    "input_dim": config.hidden_size,
    "hidden_dim": 256,
    "output_dim": config.hidden_size,
    "n_hidden_layers": 3,
    "meta_memory_dim": 16,
    "learning_rate": 4e-4,
    "weight_decay": 0.1,
    "max_adaptive_lr": 1e-2,
    "num_attention_heads": 4,
    "attention_window_size": 7,
    "n_chunks": 10,
}

model = LlamaForCausalLM(config)
lmm = NeuralMemory(**mal_params)
# Fall back to CPU when Apple's MPS backend is unavailable.
# (This cell itself runs on CPU; `device` is reused by later cells.)
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Swap the second-to-last decoder layer for a memory-as-layer wrapper around it.
original_layer = model.model.layers[-2]
model.model.layers[-2] = LlamaMemoryAsLayer(original_layer, lmm)
model.lmm = lmm

batch_size = 4
seq_len = 256

input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
output = model(input_ids, labels=input_ids)
print(f"Loss with MAL: {output.loss.item()}")

Loss with MAL: 10.402226448059082


In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from dataclasses import dataclass, asdict
from datasets import load_dataset

In [7]:
class BabiDataset(Dataset):
    """Loads one split of a bAbI task and tokenizes it word by word.

    Each item is ``(context_tokens, answer_tokens)`` as 1-D LongTensors, where
    the context is the joined story text followed by the question. Tokens come
    from a whitespace-split vocabulary with ``<pad>`` = 0 and ``<unk>`` = 1.

    Args:
        split: Dataset split to load (e.g. ``"train"``).
        task: Passed to ``load_dataset`` as ``type``.
        word_to_idx: Optional existing vocabulary, typically the training
            split's. When given, a copy is used as-is so token ids agree across
            splits, and words missing from it map to ``<unk>``. When omitted, a
            vocabulary is built from this split (the original behavior).

    NOTE(review): assumes each sample has ``story["text"]`` (a list of
    sentences) plus top-level ``question`` and ``answer`` strings. Confirm
    against the installed ``facebook/babi_qa`` schema.
    """

    def __init__(self, split="train", task="en-10k-qa1", word_to_idx=None):
        # Load the specified task from the bAbI dataset
        dataset = load_dataset("facebook/babi_qa", type=task)[split]

        if word_to_idx is None:
            self.word_to_idx = {"<pad>": 0, "<unk>": 1}
            self._build_vocab(dataset)
        else:
            # Reuse a shared vocabulary; copy so this instance never aliases it.
            self.word_to_idx = dict(word_to_idx)

        self.stories = []
        self.questions = []
        self.answers = []
        for sample in dataset:
            self.stories.append(" ".join(sample["story"]["text"]))
            self.questions.append(sample["question"])
            self.answers.append(sample["answer"])

    def _build_vocab(self, dataset):
        """Add every whitespace-separated word of the split to ``word_to_idx``.

        Words are added in sorted order after the existing entries, so ids are
        deterministic for a given split.
        """
        vocab = set()
        for example in dataset:
            for sentence in example["story"]["text"]:
                vocab.update(sentence.split())
            vocab.update(example["question"].split())
            vocab.update(example["answer"].split())

        for word in sorted(vocab):
            if word not in self.word_to_idx:
                self.word_to_idx[word] = len(self.word_to_idx)

    def _encode(self, text):
        """Map whitespace-separated words to ids, using 1 (``<unk>``) for unknown words."""
        return [self.word_to_idx.get(w, 1) for w in text.split()]

    def get_vocab_size(self):
        return len(self.word_to_idx)

    def __len__(self):
        return len(self.stories)

    def __getitem__(self, idx):
        # Combine story and question to form the input context
        context = f"{self.stories[idx]} {self.questions[idx]}"
        context_tokens = self._encode(context)
        answer_tokens = self._encode(self.answers[idx])
        return torch.LongTensor(context_tokens), torch.LongTensor(answer_tokens)

In [8]:
def collate_fn(batch):
    """Right-pad every context and every answer in a batch to a common length.

    ``batch`` is a sequence of ``(context, answer)`` 1-D tensor pairs. Padding
    uses id 0. Returns ``(contexts, answers)``, each shaped ``(batch, max_len)``.
    """

    def _pad(seqs):
        return pad_sequence(seqs, batch_first=True, padding_value=0)

    context_seqs, answer_seqs = zip(*batch)
    return _pad(context_seqs), _pad(answer_seqs)

In [None]:
def train_one_epoch(
    model,
    dataloader,
    outer_optimizer,
    loss_fn,
    device,
    target_embedding_layer,
):
    """Run one outer-loop training epoch and return the mean per-batch loss.

    For every ``(sequences, targets)`` batch, ``model(sequences)`` is regressed
    onto ``target_embedding_layer(targets)`` with ``loss_fn``, and
    ``outer_optimizer`` takes one step. Target embeddings are computed under
    ``torch.no_grad()``, so no gradient flows through the target path.

    Args:
        model: Module whose output is compared to the target embeddings.
        dataloader: Iterable of ``(sequences, targets)`` tensor pairs.
        outer_optimizer: Optimizer stepped once per batch.
        loss_fn: Supervising loss, called as ``loss_fn(prediction, target)``.
        device: Device the batches are moved to.
        target_embedding_layer: Maps ``targets`` to target vectors.

    Returns:
        The average of the per-batch ``loss.item()`` values, or ``nan`` if
        the loader produced no batches.
    """
    model.train()
    total_loss = 0.0
    n_batches = 0
    for sequences, targets in dataloader:
        sequences, targets = sequences.to(device), targets.to(device)

        outer_optimizer.zero_grad()

        # Get the embedding for the single retrieved value
        retrieved_embedding = model(sequences)

        # Ground-truth embedding for the target tokens, treated as a constant
        with torch.no_grad():
            target_embedding = target_embedding_layer(targets)

        # Calculate the supervising loss
        loss = loss_fn(retrieved_embedding, target_embedding)

        loss.backward()
        outer_optimizer.step()

        total_loss += loss.item()
        n_batches += 1

    # Count batches instead of using len(dataloader). This avoids a
    # ZeroDivisionError on an empty loader and supports iterables without __len__.
    return total_loss / n_batches if n_batches else float("nan")


def evaluate(
    model,
    dataloader,
    loss_fn,
    device,
    target_embedding_layer,
):
    """Compute the mean per-batch supervising loss without updating the model.

    Puts ``model`` in eval mode and runs every batch under ``torch.no_grad()``.
    Each batch's loss is ``loss_fn(model(sequences), target_embedding_layer(targets))``.

    Args:
        model: Module whose output is compared to the target embeddings.
        dataloader: Iterable of ``(sequences, targets)`` tensor pairs.
        loss_fn: Supervising loss, called as ``loss_fn(prediction, target)``.
        device: Device the batches are moved to.
        target_embedding_layer: Maps ``targets`` to target vectors.

    Returns:
        The average of the per-batch ``loss.item()`` values, or ``nan`` if
        the loader produced no batches.
    """
    model.eval()
    total_loss = 0.0
    n_batches = 0
    with torch.no_grad():
        for sequences, targets in dataloader:
            sequences, targets = sequences.to(device), targets.to(device)
            retrieved_embedding = model(sequences)
            target_embedding = target_embedding_layer(targets)
            loss = loss_fn(retrieved_embedding, target_embedding)
            total_loss += loss.item()
            n_batches += 1

    # Count batches instead of using len(dataloader). This avoids a
    # ZeroDivisionError on an empty loader and supports iterables without __len__.
    return total_loss / n_batches if n_batches else float("nan")

In [None]:
if __name__ == '__main__':

    # NOTE(review): `Stage1Model` is not defined anywhere in this notebook.
    # Also, `config` here is the LlamaConfig from the setup cell, which has no
    # `batch_size`, `device`, `epochs` or `outer_learning_rate` attributes.
    # Define a dedicated Stage 1 training config and model before running this cell.

    # Setup datasets. BabiDataset accepts only `split` and `task`. Each
    # instance builds its vocabulary from its own split, so the validation set
    # is pointed at the training vocabulary to keep token ids consistent.
    train_dataset = BabiDataset(split="train")
    val_dataset = BabiDataset(split="test")  # assumes the dataset has a "test" split — confirm
    val_dataset.word_to_idx = train_dataset.word_to_idx
    # NOTE(review): the model's token embedding must cover
    # train_dataset.get_vocab_size() ids.

    # Samples have different lengths, so batches are padded by collate_fn.
    # The default collate cannot stack ragged tensors.
    train_loader = DataLoader(
        train_dataset, batch_size=config.batch_size, shuffle=True, collate_fn=collate_fn
    )
    val_loader = DataLoader(
        val_dataset, batch_size=config.batch_size, shuffle=False, collate_fn=collate_fn
    )

    # Setup model
    model = Stage1Model(config).to(config.device)

    # Setup the main "outer loop" optimizer.
    # IMPORTANT: This optimizer trains the ProxyLLM and the projection layers inside
    # NeuralMemory, but NOT the LMM parameters themselves (self.lmm), which are
    # updated by their own internal optimizer.
    outer_params = [
        p for n, p in model.named_parameters() if not n.startswith('memory_module.lmm.')
    ]
    outer_optimizer = optim.AdamW(outer_params, lr=config.outer_learning_rate)

    # Loss function for the outer loop (supervising loss)
    supervising_loss_fn = nn.MSELoss()

    # Embedding layer that produces the target vectors for both loops
    target_embedding_layer = model.proxy_model.embedding

    print("Starting Stage 1 Training...")
    for epoch in range(config.epochs):
        train_loss = train_one_epoch(
            model,
            train_loader,
            outer_optimizer,
            supervising_loss_fn,
            config.device,
            target_embedding_layer,
        )

        val_loss = evaluate(
            model,
            val_loader,
            supervising_loss_fn,
            config.device,
            target_embedding_layer,
        )

        print(f"Epoch {epoch+1}/{config.epochs} | Train Loss: {train_loss:.4f}")
        print(f"Epoch {epoch+1}/{config.epochs} | Val Loss: {val_loss:.4f} ")

In [None]:
# A bare `assert` with no condition stood here. It is a SyntaxError and broke
# Restart & Run All. TODO: add a real invariant check once the Stage 1 run exists.

In [3]:
class MemoryLlama(nn.Module):
    """A pretrained Hugging Face causal LM with a NeuralMemory module attached.

    Args:
        memory_architecture: One of ``"layer"``, ``"gate"`` or ``"context"``.
            Only ``"layer"`` is wired into the backbone: the second-to-last
            decoder layer is wrapped in ``LlamaMemoryAsLayer``.
        llama_hf_path: Hub id or local path of the backbone and its tokenizer.
        freeze_llama_layers: If True, every backbone parameter gets
            ``requires_grad = False``. This happens before the memory module
            is created, so the memory module's own parameters are not frozen.
        neural_memory_config: Keyword arguments for ``NeuralMemory``. The dict
            is copied, and ``input_dim`` is set to the backbone's hidden size.
        quantize: If True, load the backbone with Quanto int4 weights.
    """

    def __init__(
        self,
        memory_architecture: str,
        llama_hf_path: str,
        freeze_llama_layers: bool,
        neural_memory_config: dict,
        quantize: bool,
    ):
        super().__init__()
        self.MEMORY_ARCHITECTURES = ["layer", "gate", "context"]

        assert (
            memory_architecture in self.MEMORY_ARCHITECTURES
        ), f"Memory architecture must be one of {self.MEMORY_ARCHITECTURES}, got {memory_architecture}"

        self.memory_architecture = memory_architecture

        self.quantization_config = QuantoConfig(weights="int4") if quantize else None
        self.llama = AutoModelForCausalLM.from_pretrained(
            llama_hf_path, quantization_config=self.quantization_config, device_map="auto"
        )
        # Load the tokenizer. The previous code loaded a second full copy of the
        # causal LM here, doubling memory and yielding no usable tokenizer.
        self.tokenizer = AutoTokenizer.from_pretrained(llama_hf_path)
        self.config = self.llama.config

        if freeze_llama_layers:
            for param in self.llama.parameters():
                param.requires_grad = False

        # Copy instead of mutating the caller's dict. The memory's input width
        # must follow the backbone's hidden size.
        self.neural_memory_config = {
            **neural_memory_config,
            "input_dim": self.config.hidden_size,
        }
        self.neural_memory = NeuralMemory(**self.neural_memory_config)

        # NOTE(review): only "layer" attaches the memory to the backbone. With
        # "gate" or "context", `neural_memory` is constructed but never used by
        # forward(). TODO implement or reject those options.
        if memory_architecture == "layer":
            original_layer = self.llama.model.layers[-2]
            self.llama.model.layers[-2] = LlamaMemoryAsLayer(
                original_layer, self.neural_memory
            )

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Delegate to the wrapped causal LM and return its output unchanged.

        Extra keyword arguments are forwarded as-is. The notebook reads
        ``.loss`` from the result when ``labels`` is passed.
        """
        return self.llama(
            input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs
        )


In [4]:
# Hyper-parameters for the NeuralMemory attached to Llama-3.2-1B.
# MemoryLlama sets input_dim from the backbone's hidden_size when it builds the memory.
neural_memory_config = dict(
    n_hidden_layers=1,
    meta_memory_dim=32,
    input_dim=2048,
    hidden_dim=256,
    output_dim=2048,
    learning_rate=4e-4,
    weight_decay=0.1,
    max_adaptive_lr=1e-2,
    num_attention_heads=4,
    attention_window_size=7,
    n_chunks=10,
)

In [5]:
# Llama-3.2-1B with the memory wired in as a layer, a frozen backbone and
# int4-quantized weights.
model = MemoryLlama(
    memory_architecture="layer",
    llama_hf_path="meta-llama/Llama-3.2-1B",
    freeze_llama_layers=True,
    neural_memory_config=neural_memory_config,
    quantize=True,
)

layer_sizes=[(np.int64(2048), np.int64(256)), (np.int64(256), np.int64(2048))]


In [6]:
# Fall back to CPU when Apple's MPS backend is unavailable.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
# NOTE(review): the backbone was loaded with device_map="auto". Moving it again
# with .to() may conflict with accelerate's dispatch — confirm.
model.to(device)
# NOTE(review): `input_ids` comes from the tiny-model setup cell (random ids
# below 32000). It is reused here only as a smoke-test input.
input_ids = input_ids.to(device)
output = model(input_ids, labels=input_ids)
print(f"Loss with MAL: {output.loss.item()}")

0
torch.Size([4, 29, 2048]) torch.Size([2048, 256])
torch.Size([4, 29, 2048]) torch.Size([2048, 256])
1
torch.Size([4, 29, 256]) torch.Size([256, 2048])
torch.Size([4, 29, 256]) torch.Size([256, 2048])
0
torch.Size([4, 290, 2048]) torch.Size([2048, 256])
torch.Size([4, 290, 2048]) torch.Size([2048, 256])
1
torch.Size([4, 290, 256]) torch.Size([256, 2048])
torch.Size([4, 290, 256]) torch.Size([256, 2048])
Loss with MAL: 12.350139617919922


```python
# Draft Triton kernel for the decayed recurrence
#     S_t = eta_t * S_{t-1} - theta_u_t
# evaluated sequentially over `n_chunks` chunks. The program id is split into
# (batch_idx, h_idx) via // and % by `input_dim`, so the launch grid needs
# batch_size * input_dim programs. Every intermediate S_t is stored to `output`.
#
# NOTE(review): this draft does not look runnable as written. Verify before use:
#   * `eta[t]` / `theta_u[t]` index a loaded block tensor with a loop variable.
#     I believe Triton does not support this; per-step scalars would instead be
#     loaded with e.g. `tl.load(eta_base + t * stride_eta_c)`.
#   * `chunk_range` / `mask` span the chunk axis (arange up to BLOCK_SIZE, masked
#     by n_chunks). The store, however, reuses them as offsets along the
#     output's `w` axis (stride_out_w), and theta_u has no `w` stride at all.
#     The intended layouts of theta_u and output need to be pinned down
#     (TODO confirm) before this can be fixed.
#   * Chunks at index >= BLOCK_SIZE are never loaded, so BLOCK_SIZE must be
#     >= n_chunks.
@triton.jit
def associative_scan_kernel(
    eta_pointer,
    theta_u_pointer,
    output_pointer,
    batch_size,
    n_chunks,
    input_dim,
    stride_eta_b,
    stride_eta_c,
    stride_theta_b,
    stride_theta_c,
    stride_theta_h,
    stride_out_b,
    stride_out_c,
    stride_out_h,
    stride_out_w,
    BLOCK_SIZE: tl.constexpr,
):
    # Program ID: One block per batch and input_dim pair
    pid = tl.program_id(0)
    batch_idx = pid // input_dim
    h_idx = pid % input_dim  # Row of weight matrix

    # Guard against surplus programs. h_idx is pid % input_dim, so the second
    # condition can never be true; only batch_idx can go out of range.
    if batch_idx >= batch_size or h_idx >= input_dim:
        return

    # Base pointers
    eta_base = eta_pointer + batch_idx * stride_eta_b
    theta_u_base = theta_u_pointer + batch_idx * stride_theta_b + h_idx * stride_theta_h
    output_base = output_pointer + batch_idx * stride_out_b + h_idx * stride_out_h

    # Load eta and theta_u for all chunks. Masked-out lanes get neutral values:
    # 1.0 for the multiplicative decay and 0.0 for the subtracted term.
    chunk_range = tl.arange(0, BLOCK_SIZE)
    mask = chunk_range < n_chunks
    eta = tl.load(eta_base + chunk_range * stride_eta_c, mask=mask, other=1.0)
    theta_u = tl.load(theta_u_base + chunk_range * stride_theta_c, mask=mask, other=0.0)

    # Initialize S_t (S_0 = 0), accumulated in float32
    S_t = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

    # Associative scan: Sequential within block, parallel across blocks
    # (the loop below is a plain sequential recurrence over chunks, not a tree scan)
    for t in range(n_chunks):
        S_t_prev = S_t
        S_t = eta[t] * S_t_prev - theta_u[t]
        # Store S_t for this chunk
        tl.store(
            output_base + t * stride_out_c + chunk_range * stride_out_w, S_t, mask=mask
        )
```