In [9]:
# Simplified imports for single GPU
from logging import getLogger
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from memory_layers import HashingMemory, MemoryLayerMonitorAndCheckpoint, load_and_process_dataset, ModelEvaluator

logger = getLogger()

In [10]:
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load Qwen0.5 Instruct
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", dtype=torch.float16).to(device)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", dtype=torch.float16)

# Qwen0.5 specs: 896 hidden_dim, 24 layers
hidden_dim = 896
layers_to_replace = [6, 12, 18]  # Which FFN layers to replace

# Replace FFNs with Memory Layers
for layer_idx in layers_to_replace:
    layer = model.model.layers[layer_idx]
    
    # Create memory layer
    memory_layer = HashingMemory(
        input_dim=hidden_dim,
        output_dim=hidden_dim,
        mem_n_keys=128,          # Memory size = 512¬≤ = 262k entries
        mem_heads=4,
        mem_knn=16,
        mem_k_dim=256,
        mem_v_dim=-1,            # Auto: uses output_dim
        swilu_projection=True,
        value_fixed_lr=0.001,
        mem_share_values=False,  # Don't share across layers for fine-tuning
    )
    
    # Initialize the memory layer
    memory_layer.reset_parameters()
    # Ensure memory layer matches model dtype (float16)
    memory_layer.to(device, dtype=model.dtype)
    
    # Replace the FFN (MLP) with memory layer
    original_mlp = layer.mlp
    layer.mlp = memory_layer
    
    print(f"Replaced layer {layer_idx} FFN with memory layer")

# FREEZE EVERYTHING EXCEPT MEMORY LAYERS
for name, param in model.named_parameters():
    if 'mlp' in name and any(f'layers.{idx}.' in name for idx in layers_to_replace):
        # This is a memory layer parameter - keep trainable
        param.requires_grad = True
        print(f"‚úì Trainable: {name}")
    else:
        # Freeze all other parameters
        param.requires_grad = False

# Verify what's trainable
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f"\nTrainable: {trainable_params:,} / {total_params:,} ({100*trainable_params/total_params:.2f}%)")

Replaced layer 6 FFN with memory layer
Replaced layer 12 FFN with memory layer
Replaced layer 18 FFN with memory layer
‚úì Trainable: model.layers.6.mlp.keys
‚úì Trainable: model.layers.6.mlp.values.weight
‚úì Trainable: model.layers.6.mlp.value_proj.weight
‚úì Trainable: model.layers.6.mlp.value_proj.bias
‚úì Trainable: model.layers.6.mlp.swilu_projection.weight
‚úì Trainable: model.layers.6.mlp.swilu_projection.bias
‚úì Trainable: model.layers.6.mlp.query_proj.query_mlps.0.weight
‚úì Trainable: model.layers.6.mlp.query_proj.query_mlps.0.bias
‚úì Trainable: model.layers.12.mlp.keys
‚úì Trainable: model.layers.12.mlp.values.weight
‚úì Trainable: model.layers.12.mlp.value_proj.weight
‚úì Trainable: model.layers.12.mlp.value_proj.bias
‚úì Trainable: model.layers.12.mlp.swilu_projection.weight
‚úì Trainable: model.layers.12.mlp.swilu_projection.bias
‚úì Trainable: model.layers.12.mlp.query_proj.query_mlps.0.weight
‚úì Trainable: model.layers.12.mlp.query_proj.query_mlps.0.bias
‚úì Trainab

In [11]:
# Load and process dataset
# First we will finetune on background data to populate memory

tokenized = load_and_process_dataset(tokenizer, sample_size=20000)


Filtered dataset size: 22596
Tokenized dataset: Dataset({
    features: ['input_ids', 'attention_mask'],
    num_rows: 20000
})


In [12]:
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling

# Training arguments optimized for memory layers only
training_args = TrainingArguments(
    output_dir="./qwen_memory_finetuned",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    num_train_epochs=2,
    learning_rate=5e-4,  # Higher LR since only training memory
    warmup_steps=100,
    lr_scheduler_type="cosine",
    logging_steps=10,
    logging_first_step=True,  # Log immediately
    logging_dir="./logs",
    save_steps=500,
    eval_strategy="steps",
    eval_steps=250,
    # Performance
    fp16=False,
    gradient_checkpointing=False,  # Not needed with frozen base
    dataloader_num_workers=2,

    # Monitoring
    report_to="tensorboard",  # or "wandb" if you have it
    # load_best_model_at_end=True,
    metric_for_best_model="loss",
    save_strategy="no",

    # Memory optimization
    optim="adamw_torch_fused",  # Faster optimizer
    max_grad_norm=1.0,
)

# Initialize evaluator
evaluator = ModelEvaluator(model, tokenizer, device=device)

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
)

# Initialize callback
memory_monitor = MemoryLayerMonitorAndCheckpoint(
    model=model,
    layers_to_check=layers_to_replace,
    save_every=500,
    keep_last=2,
    monitor_every=50,
    evaluator=evaluator,
    eval_every=100,     # Run evaluation every 100 steps
    eval_samples=50     # Small sample size for speed during training
)

# Create trainer with callback
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized,
    eval_dataset=tokenized.select(range(1000)),  # Use 1k for validation
    data_collator=data_collator,
    callbacks=[memory_monitor],  # Add our custom monitor
)

Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [13]:
print("\nüöÄ Starting training...")
print(f"Total steps: {len(tokenized) // (training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps) * training_args.num_train_epochs}")

# Train! 
trainer.train()

print("\n‚úÖ Training complete!")


üöÄ Starting training...
Total steps: 2500


Step,Training Loss,Validation Loss



üîç MEMORY LAYER HEALTH CHECK - Step 50

üìä Layer 6 Memory:
  Parameters:
    Keys:   mean=+nan, std=nan
    Values: mean=+nan, std=nan
  Changes since start:
    Keys:   nan ‚ùå FROZEN
    Values: nan ‚ùå FROZEN
  Gradient norms:
    Keys:   0.0000 ‚ùå NO GRAD
    Values: 0.0000 ‚ùå NO GRAD

üìä Layer 12 Memory:
  Parameters:
    Keys:   mean=+nan, std=nan
    Values: mean=+nan, std=nan
  Changes since start:
    Keys:   nan ‚ùå FROZEN
    Values: nan ‚ùå FROZEN
  Gradient norms:
    Keys:   0.0000 ‚ùå NO GRAD
    Values: 0.0000 ‚ùå NO GRAD

üìä Layer 18 Memory:
  Parameters:
    Keys:   mean=+nan, std=nan
    Values: mean=+nan, std=nan
  Changes since start:
    Keys:   nan ‚ùå FROZEN
    Values: nan ‚ùå FROZEN
  Gradient norms:
    Keys:   0.0000 ‚ùå NO GRAD
    Values: 0.0000 ‚ùå NO GRAD

‚ö†Ô∏è  Some memory layers need attention!



`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'trivia_qa' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.



üîç MEMORY LAYER HEALTH CHECK - Step 100

üìä Layer 6 Memory:
  Parameters:
    Keys:   mean=+nan, std=nan
    Values: mean=+nan, std=nan
  Changes since start:
    Keys:   nan ‚ùå FROZEN
    Values: nan ‚ùå FROZEN
  Gradient norms:
    Keys:   0.0000 ‚ùå NO GRAD
    Values: 0.0000 ‚ùå NO GRAD

üìä Layer 12 Memory:
  Parameters:
    Keys:   mean=+nan, std=nan
    Values: mean=+nan, std=nan
  Changes since start:
    Keys:   nan ‚ùå FROZEN
    Values: nan ‚ùå FROZEN
  Gradient norms:
    Keys:   0.0000 ‚ùå NO GRAD
    Values: 0.0000 ‚ùå NO GRAD

üìä Layer 18 Memory:
  Parameters:
    Keys:   mean=+nan, std=nan
    Values: mean=+nan, std=nan
  Changes since start:
    Keys:   nan ‚ùå FROZEN
    Values: nan ‚ùå FROZEN
  Gradient norms:
    Keys:   0.0000 ‚ùå NO GRAD
    Values: 0.0000 ‚ùå NO GRAD

‚ö†Ô∏è  Some memory layers need attention!


üìä RUNNING BENCHMARK EVALUATION - Step 100
Evaluating on TriviaQA (validation, 50 samples)...


Generating train split: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 138384/138384 [00:01<00:00, 136853.56 examples/s]
Generating validation split: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 17944/17944 [00:00<00:00, 139900.05 examples/s]
Generating test split: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 17210/17210 [00:00<00:00, 420358.56 examples/s]
  0%|          | 0/50 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 50/50 [00:41<00:00,  1.21it/s]


TriviaQA Accuracy: 0.00%
Evaluating on GSM8K (test, 50 samples)...


Generating train split: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7473/7473 [00:00<00:00, 189244.71 examples/s]
Generating test split: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1319/1319 [00:00<00:00, 131627.10 examples/s]





KeyboardInterrupt: 

In [3]:
SPARSE_TOP_T = hidden_dim // 2  # top-t memory slots updated per layer, per batch
BACKGROUND_MAX_BATCHES = 200 # how many background batches to use for DF

class MemoryAccessLogger:
    """
    Forward hook on each memory.values table that counts how often
    each memory index is used in a batch.
    """
    def __init__(self, values_module):
        self.module = values_module
        self.mem_size = values_module.weight.shape[0]
        self.device = values_module.weight.device
        self.batch_counts = torch.zeros(self.mem_size, dtype=torch.int64, device=self.device)
        self.handle = self.module.register_forward_hook(self.hook)

    def reset(self):
        self.batch_counts.zero_()

    def hook(self, module, inputs, output):
        # inputs[0] is 'indices' passed to xFormerEmbeddingBag.forward
        indices = inputs[0]  # shape: [N, bag_size]
        with torch.no_grad():
            flat = indices.reshape(-1).detach().to("cpu")
            counts_cpu = torch.bincount(flat, minlength=self.mem_size)
            counts = counts_cpu.to(self.batch_counts.device)
            self.batch_counts += counts

    def close(self):
        self.handle.remove()

# Create one logger per memory layer
mem_loggers = {}
mem_sizes = {}

for idx in layers_to_replace:  # e.g. [6, 12, 18]
    mem_layer = model.model.layers[idx].mlp  # this is HashingMemory
    mem_loggers[idx] = MemoryAccessLogger(mem_layer.values)
    mem_sizes[idx] = mem_layer.values.weight.shape[0]

print("\n‚úÖ MemoryAccessLogger attached for layers:", layers_to_replace)


‚úÖ MemoryAccessLogger attached for layers: [6, 12, 18]


In [None]:
from torch.utils.data import DataLoader
from transformers import Trainer

# 1) Compute background document frequency (DF) over memory slots
background_dataset = tokenized

bg_loader = DataLoader(
    background_dataset,
    batch_size=1, # small batch to avoid OOM
    shuffle=False,
    collate_fn=data_collator,
)

bg_df = {
    idx: torch.zeros(mem_sizes[idx], dtype=torch.int64, device=device)
    for idx in layers_to_replace
}
bg_num_batches = {idx: 0 for idx in layers_to_replace}

model.eval()
torch.cuda.empty_cache()
print("Computing background DF stats for sparse memory finetuning...")

with torch.no_grad():
    for step, batch in enumerate(bg_loader):
        if step >= BACKGROUND_MAX_BATCHES:
            break

        # Reset per-batch counts for each memory layer
        for logger in mem_loggers.values():
            logger.reset()

        # Move to device and DROP labels to avoid computing loss
        batch = {k: v.to(device) for k, v in batch.items() if k in ["input_ids", "attention_mask"]}

        _ = model(**batch)  # forward only

        # For each layer, mark which slots were touched in THIS batch
        for idx, logger in mem_loggers.items():
            used = (logger.batch_counts > 0).to(bg_df[idx].dtype)
            bg_df[idx] += used
            bg_num_batches[idx] += 1

        torch.cuda.empty_cache()

for idx in layers_to_replace:
    print(f"Layer {idx}: background batches = {bg_num_batches[idx]}")

model.train()
torch.cuda.empty_cache()

# 2) Per-slot trainability masks and gradient hooks:
#    only the top-t (TF-IDF) slots will get gradient each batch.
slot_train_masks = {
    idx: torch.zeros(mem_sizes[idx], dtype=torch.bool, device=device)
    for idx in layers_to_replace
}

def make_grad_hook(layer_idx):
    def hook(grad):
        """
        grad: [num_slots, value_dim] for this layer's values.weight
        We zero-out rows whose mask is False, so only top-t slots get gradient.
        """
        mask = slot_train_masks[layer_idx].to(grad.device)  # [num_slots]
        return grad * mask.unsqueeze(-1)
    return hook

# Attach hooks to each memory values table
for idx in layers_to_replace:
    values_param = model.model.layers[idx].mlp.values.weight
    values_param.register_hook(make_grad_hook(idx))

print("\nTrainable parameters after adding sparse hooks:")
for n, p in model.named_parameters():
    if p.requires_grad:
        print("  ", n)

# 3) SparseMemoryTrainer: override compute_loss with the new signature
#    NOTE: num_items_in_batch is added to match HF Trainer API.

class SparseMemoryTrainer(Trainer):
    def compute_loss(
        self,
        model,
        inputs,
        return_outputs: bool = False,
        num_items_in_batch: int = None,
    ):
        # Reset per-batch counts for loggers
        for logger in mem_loggers.values():
            logger.reset()

        # Standard forward pass (with labels, so we get loss)
        outputs = model(**inputs)
        loss = outputs.loss

        # Compute TF-IDF score for each memory slot, per layer, and choose top-t
        with torch.no_grad():
            for idx in layers_to_replace:
                logger = mem_loggers[idx]
                counts = logger.batch_counts # c(i) on this batch
                total = counts.sum()

                if total == 0:
                    slot_train_masks[idx].fill_(False)
                    continue

                # Term frequency: c(i) / sum_j c(j)
                tf = counts.float() / total

                # Document frequency from background: df(i) = #background batches where slot i was used
                df = bg_df[idx].float()
                N = float(bg_num_batches[idx])  # |B|

                # Paper‚Äôs TF-IDF:
                # c(i)/sum_j c(j) * log( (|B| + 1) / (df(i) + 1) )
                idf = torch.log((N + 1.0) / (df + 1.0))
                tfidf = tf * idf.to(tf.device)

                # Don't pick slots that weren't used in this batch
                tfidf = tfidf.masked_fill(counts == 0, float("-inf"))

                num_active = int((counts > 0).sum().item())
                k = min(SPARSE_TOP_T, num_active) if num_active > 0 else 0

                if k == 0:
                    slot_train_masks[idx].fill_(False)
                    continue

                # Top-t indices by TF-IDF
                _, topk_idx = torch.topk(tfidf, k=k)
                mask = torch.zeros_like(slot_train_masks[idx])
                mask[topk_idx] = True
                slot_train_masks[idx] = mask

        if return_outputs:
            return loss, outputs
        return loss

# Initialize Evaluator
evaluator = ModelEvaluator(model, tokenizer, device=device)

# Initialize callback
memory_monitor = MemoryLayerMonitorAndCheckpoint(
    model=model,
    layers_to_check=layers_to_replace,
    save_every=500,
    keep_last=2,
    monitor_every=50,
    evaluator=evaluator,
    eval_every=100,     # Run evaluation every 100 steps
    eval_samples=20     # Small sample size for speed during training
)

# 4) Recreate 'trainer' using SparseMemoryTrainer (overwrites the earlier Trainer)
trainer = SparseMemoryTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized,
    eval_dataset=tokenized.select(range(1000)), # same eval subset as before
    data_collator=data_collator,
    callbacks=[memory_monitor], # keep your monitor/checkpoints
)

print("\n‚úÖ SparseMemoryTrainer is set up. trainer.train() cell will now do sparse memory finetuning.")

Computing background DF stats for sparse memory finetuning...
Layer 6: background batches = 200
Layer 12: background batches = 200
Layer 18: background batches = 200

Trainable parameters after adding sparse hooks:
   model.layers.6.mlp.keys
   model.layers.6.mlp.values.weight
   model.layers.6.mlp.value_proj.weight
   model.layers.6.mlp.value_proj.bias
   model.layers.6.mlp.swilu_projection.weight
   model.layers.6.mlp.swilu_projection.bias
   model.layers.6.mlp.query_proj.query_mlps.0.weight
   model.layers.6.mlp.query_proj.query_mlps.0.bias
   model.layers.12.mlp.keys
   model.layers.12.mlp.values.weight
   model.layers.12.mlp.value_proj.weight
   model.layers.12.mlp.value_proj.bias
   model.layers.12.mlp.swilu_projection.weight
   model.layers.12.mlp.swilu_projection.bias
   model.layers.12.mlp.query_proj.query_mlps.0.weight
   model.layers.12.mlp.query_proj.query_mlps.0.bias
   model.layers.18.mlp.keys
   model.layers.18.mlp.values.weight
   model.layers.18.mlp.value_proj.weight
 

In [51]:
print("\nüöÄ Starting training...")
print(f"Total steps: {len(tokenized) // (training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps) * training_args.num_train_epochs}")

# Train!
trainer.train()

print("\n‚úÖ Training complete!")


üöÄ Starting training...
Total steps: 1437


Step,Training Loss,Validation Loss


OutOfMemoryError: CUDA out of memory. Tried to allocate 536.00 MiB. GPU 0 has a total capacity of 14.74 GiB of which 152.12 MiB is free. Process 8929 has 14.59 GiB memory in use. Of the allocated memory 13.94 GiB is allocated by PyTorch, and 526.29 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
from safetensors.torch import load_file

device = "cuda" if torch.cuda.is_available() else "cpu"
hidden_dim = 896
layers_to_replace = [6, 12, 18]

# Reload model for testing
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct",
    dtype=torch.float16,
).to(device)

# Add memory layers
for idx in layers_to_replace:
    # Initialize and cast to correct device/dtype
    mem_layer = HashingMemory(
        input_dim=hidden_dim, output_dim=hidden_dim, mem_n_keys=128, mem_heads=4,
        mem_knn=16, mem_k_dim=256, mem_v_dim=-1, swilu_projection=True,
        value_fixed_lr=0.001, mem_share_values=False
    )
    # Important: Cast to model's dtype (float16) to avoid "Half and Float" errors
    model.model.layers[idx].mlp = mem_layer.to(device, dtype=model.dtype)

# Load weights
try:
    state_dict = load_file("./qwen_memory_final/model.safetensors")
except:
    state_dict = torch.load("./qwen_memory_final/pytorch_model.bin",
                           weights_only=False)

model.load_state_dict(state_dict, strict=False)
print("\n‚úÖ Model loaded successfully!")

# Test generation
def test_model(prompt):
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=100,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
        )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response

# Try some prompts
test_prompts = [
    "Explain quantum computing in simple terms:",
    "Write a Python function to sort a list:",
    "What are the health benefits of exercise?",
]

for prompt in test_prompts:
    print(f"\n{'='*80}")
    print(f"Prompt: {prompt}")
    print(f"{'='*80}")
    response = test_model(prompt)
    print(response)

In [None]:
# Load original Qwen model for comparison
base_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct",
    torch_dtype=torch.float16,
)
base_model.to(device)

def compare_models(prompt):
    # Your fine-tuned model
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        # Fine-tuned
        ft_outputs = model.generate(**inputs, max_new_tokens=100)
        ft_response = tokenizer.decode(ft_outputs[0], skip_special_tokens=True)

        # Base
        base_outputs = base_model.generate(**inputs, max_new_tokens=100)
        base_response = tokenizer.decode(base_outputs[0], skip_special_tokens=True)

    print(f"\n{'='*80}")
    print(f"Prompt: {prompt}")
    print(f"{'='*80}")
    print(f"\nüî∑ BASE MODEL:")
    print(base_response)
    print(f"\nüî∂ FINE-TUNED (with memory layers):")
    print(ft_response)
    print(f"{'='*80}\n")

# Test
compare_models("Explain machine learning:")