In [1]:
from transformers import (
    AutoModelForCausalLM,
    QuantoConfig,
)
import torch
import torch.nn as nn

from hercules import LlamaMemoryAsLayer, NeuralMemory, MemoryLlama
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

In [2]:
# Hyper-parameters for the neural-memory module, passed to MemoryLlama below.
# NOTE(review): per-key meanings are inferred from the key names only —
# confirm against the hercules NeuralMemory implementation.
neural_memory_config = dict(
    n_hidden_layers=1,
    meta_memory_dim=32,
    # 2048 equals the last dimension of the Llama activations printed by the
    # forward pass below (torch.Size([4, 256, 2048])); presumably input/output
    # widths must match the Llama hidden size — confirm.
    input_dim=2048,
    hidden_dim=256,
    output_dim=2048,
    learning_rate=4e-4,
    max_adaptive_lr=1e-2,
    num_attention_heads=4,
    attention_window_size=7,
    n_chunks=10,
)

In [73]:
# Build the memory-augmented Llama from the Llama-3.2-1B checkpoint and the
# neural-memory hyper-parameters defined in the previous cell.
# NOTE(review): the argument descriptions below are inferred from their names
# only — confirm against the hercules MemoryLlama implementation.
model = MemoryLlama(
    llama_hf_path="meta-llama/Llama-3.2-1B",    # Hugging Face model id / path of the base Llama
    neural_memory_config=neural_memory_config,  # hyper-parameters from the cell above
    memory_layer_ids=-1,                        # presumably which layer(s) get memory; -1 likely means the last one
    freeze_llama_layers=True,                   # presumably disables gradients for the base Llama weights
    quantize=True,                              # presumably quantizes the base model (QuantoConfig is imported above)
)

In [6]:
# Smoke test: run one forward pass of the memory-augmented model on random
# token ids and report the loss it returns.
SEED = 0  # fixed so the random inputs, and therefore the reported loss, are reproducible
torch.manual_seed(SEED)

device = torch.device("cpu")
model.to(device)

batch_size = 4
seq_len = 256

# Uniformly random token ids over the whole vocabulary, created directly on the
# target device. The same ids are passed as labels.
# NOTE(review): if labels are shifted internally (HF-style), the next token is
# unpredictable here, so a loss near ln(vocab_size) is the expected baseline —
# confirm against MemoryLlama.forward.
input_ids = torch.randint(
    0, model.llama.config.vocab_size, (batch_size, seq_len), device=device
)

# No torch.no_grad() on purpose: the neural-memory module may rely on autograd
# during the forward pass — confirm before disabling gradients.
output = model(input_ids, labels=input_ids)
loss = output.loss.item()  # extract the Python float once, reuse for printing
print(f"Loss with MAL: {loss}")

Llama inputs: torch.Size([4, 256, 2048])
Llama outputs: torch.Size([4, 256, 2048])
lmm output: torch.Size([4, 256, 2048])
Loss with MAL: 11.761780738830566


In [70]:
def log_memory_model(model: nn.Module) -> None:
    """Print how many of ``model``'s parameters are trainable vs. frozen.

    A parameter counts as trainable when ``requires_grad`` is True and as
    frozen otherwise. Counts are element counts (``numel``) summed over
    ``model.parameters()`` and are printed in scientific notation with
    three decimals.

    Args:
        model: Any ``torch.nn.Module``. The cell below passes ``model.llama``.
            The previous ``AutoModelForCausalLM`` hint was wrong because that
            class is only a factory and is never the type of a real model
            instance.

    Returns:
        None. The summary is written to stdout.
    """
    trainable = 0
    frozen = 0
    # Single pass over the parameters instead of iterating them twice.
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
        else:
            frozen += param.numel()
    print(f"Trainable parameters: {trainable:.3e}\nFrozen parameters: {frozen:.3e}")

In [None]:
# Disable gradients for every parameter that `model` registers under a name
# starting with "llama", then report the trainable/frozen split of the
# `model.llama` submodule.
# NOTE(review): this cell shows `In [None]`, so the output below may be stale
# and may not reflect this freezing — re-run to confirm the counts.
for name, param in model.named_parameters():
    if not name.startswith("llama"):
        continue
    param.requires_grad = False

log_memory_model(model.llama)

Trainable parameters: 3.154e+07
Frozen parameters: 1.236e+09
