In [1]:
# Simplified imports for single GPU
from logging import getLogger
import math
from typing import Optional
from dataclasses import dataclass
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

# Import the embedding bag
# from memory_layers.xformer_embeddingbag import xFormerEmbeddingBag

In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from memory_layers.memory import HashingMemory, ProductKeyArgs
device = "cuda"
# Load Qwen0.5 Instruct
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", dtype=torch.float16).to(device)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", dtype=torch.float16)

# Qwen0.5 specs: 896 hidden_dim, 24 layers
hidden_dim = 896
layers_to_replace = [6, 12, 18]  # Which FFN layers to replace

# Replace FFNs with Memory Layers
for layer_idx in layers_to_replace:
    layer = model.model.layers[layer_idx]
    
    # Create memory layer
    memory_layer = HashingMemory(
        input_dim=hidden_dim,
        output_dim=hidden_dim,
        mem_n_keys=128,          # Memory size = 512¬≤ = 262k entries
        mem_heads=4,
        mem_knn=16,
        mem_k_dim=256,
        mem_v_dim=-1,            # Auto: uses output_dim
        swilu_projection=True,
        value_fixed_lr=0.001,
        mem_share_values=False,  # Don't share across layers for fine-tuning
    )
    
    # Initialize the memory layer
    memory_layer.reset_parameters()
    memory_layer.to(device)
    
    # Replace the FFN (MLP) with memory layer
    original_mlp = layer.mlp
    layer.mlp = memory_layer
    
    print(f"Replaced layer {layer_idx} FFN with memory layer")

# FREEZE EVERYTHING EXCEPT MEMORY LAYERS
for name, param in model.named_parameters():
    if 'mlp' in name and any(f'layers.{idx}.' in name for idx in layers_to_replace):
        # This is a memory layer parameter - keep trainable
        param.requires_grad = True
        print(f"‚úì Trainable: {name}")
    else:
        # Freeze all other parameters
        param.requires_grad = False

# Verify what's trainable
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f"\nTrainable: {trainable_params:,} / {total_params:,} ({100*trainable_params/total_params:.2f}%)")

  from .autonotebook import tqdm as notebook_tqdm


Replaced layer 6 FFN with memory layer
Replaced layer 12 FFN with memory layer
Replaced layer 18 FFN with memory layer
‚úì Trainable: model.layers.6.mlp.keys
‚úì Trainable: model.layers.6.mlp.values.weight
‚úì Trainable: model.layers.6.mlp.value_proj.weight
‚úì Trainable: model.layers.6.mlp.value_proj.bias
‚úì Trainable: model.layers.6.mlp.swilu_projection.weight
‚úì Trainable: model.layers.6.mlp.swilu_projection.bias
‚úì Trainable: model.layers.6.mlp.query_proj.query_mlps.0.weight
‚úì Trainable: model.layers.6.mlp.query_proj.query_mlps.0.bias
‚úì Trainable: model.layers.12.mlp.keys
‚úì Trainable: model.layers.12.mlp.values.weight
‚úì Trainable: model.layers.12.mlp.value_proj.weight
‚úì Trainable: model.layers.12.mlp.value_proj.bias
‚úì Trainable: model.layers.12.mlp.swilu_projection.weight
‚úì Trainable: model.layers.12.mlp.swilu_projection.bias
‚úì Trainable: model.layers.12.mlp.query_proj.query_mlps.0.weight
‚úì Trainable: model.layers.12.mlp.query_proj.query_mlps.0.bias
‚úì Trainab

In [3]:
from transformers import Trainer, TrainingArguments, TrainerCallback
import torch
import wandb  # Optional but highly recommended
from transformers import TrainerCallback
import torch
import os
import shutil

class MemoryLayerMonitorAndCheckpoint(TrainerCallback):
    """
    Combined callback for:
    1.  Monitoring memory layer training health
    2. Safe checkpoint saving with safetensors
    """
    
    def __init__(self, model, layers_to_check=[6, 12, 18], 
                 save_every=500, keep_last=2, monitor_every=50):
        # Monitoring
        self.model = model
        self.layers_to_check = layers_to_check
        self.monitor_every = monitor_every
        self.initial_params = {}
        
        # Checkpointing
        self.save_every = save_every
        self.keep_last = keep_last
        self.checkpoints = []
        
        # Store initial parameter values for monitoring
        for idx in layers_to_check:
            layer = model.model.layers[idx].mlp
            self.initial_params[f"layer_{idx}_keys"] = layer.keys.data.clone()
            self.initial_params[f"layer_{idx}_values"] = layer.values.weight.data.clone()
    
    def on_step_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
        step = state.global_step
        
        # ================================================================
        # MONITORING (every N steps)
        # ================================================================
        if step % self.monitor_every == 0 and step > 0:
            self._monitor_health(step)
        
        # ================================================================
        # CHECKPOINTING (every M steps)
        # ================================================================
        if step % self.save_every == 0 and step > 0:
            self._save_checkpoint(step, state, model, tokenizer)
    
    def _monitor_health(self, step):
        """Monitor memory layer training health"""
        print(f"\n{'='*80}")
        print(f"üîç MEMORY LAYER HEALTH CHECK - Step {step}")
        print(f"{'='*80}")
        
        all_healthy = True
        
        for idx in self.layers_to_check:
            layer = self.model.model.layers[idx].mlp
            
            # Check parameter changes
            keys_diff = (
                layer.keys.data - self.initial_params[f"layer_{idx}_keys"]
            ).abs().mean().item()
            values_diff = (
                layer.values.weight.data - self.initial_params[f"layer_{idx}_values"]
            ).abs().mean().item()
            
            # Check gradients
            keys_grad = layer.keys.grad.norm().item() if layer.keys.grad is not None else 0.0
            values_grad = (
                layer.values.weight.grad.norm().item() 
                if layer.values.weight.grad is not None else 0.0
            )
            
            # Parameter statistics
            keys_mean = layer.keys.data.mean().item()
            keys_std = layer.keys.data.std().item()
            values_mean = layer.values.weight.data.mean().item()
            values_std = layer.values.weight.data.std().item()
            
            print(f"\nüìä Layer {idx} Memory:")
            print(f"  Parameters:")
            print(f"    Keys:   mean={keys_mean:+.4f}, std={keys_std:.4f}")
            print(f"    Values: mean={values_mean:+.4f}, std={values_std:.4f}")
            print(f"  Changes since start:")
            print(f"    Keys:   {keys_diff:.6f} {'‚úÖ' if keys_diff > 1e-6 else '‚ùå FROZEN'}")
            print(f"    Values: {values_diff:.6f} {'‚úÖ' if values_diff > 1e-6 else '‚ùå FROZEN'}")
            print(f"  Gradient norms:")
            print(f"    Keys:   {keys_grad:.4f} {'‚úÖ' if keys_grad > 0 else '‚ùå NO GRAD'}")
            print(f"    Values: {values_grad:.4f} {'‚úÖ' if values_grad > 0 else '‚ùå NO GRAD'}")
            
            # Health checks
            if keys_diff < 1e-8 and step > 100:
                print(f"  ‚ö†Ô∏è  WARNING: Keys not updating!")
                all_healthy = False
            if values_diff < 1e-8 and step > 100:
                print(f"  ‚ö†Ô∏è  WARNING: Values not updating!")
                all_healthy = False
            if keys_grad == 0.0:
                print(f"  ‚ö†Ô∏è  WARNING: No gradient flow to keys!")
                all_healthy = False
            if values_grad == 0.0:
                print(f"  ‚ö†Ô∏è  WARNING: No gradient flow to values!")
                all_healthy = False
        
        if all_healthy:
            print(f"\n‚úÖ All memory layers healthy!")
        else:
            print(f"\n‚ö†Ô∏è  Some memory layers need attention!")
        
        print(f"{'='*80}\n")
    
    def _save_checkpoint(self, step, state, model, tokenizer):
        """Save checkpoint safely with safetensors"""
        checkpoint_dir = f"./checkpoints/step-{step}"
        os.makedirs(checkpoint_dir, exist_ok=True)
        
        print(f"\nüíæ Saving checkpoint at step {step}...")
        
        try:
            # Save model with safetensors (no JSON serialization issues)
            model.save_pretrained(
                checkpoint_dir, 
                safe_serialization=True
            )
            
            # Save tokenizer
            if tokenizer:
                tokenizer.save_pretrained(checkpoint_dir)
            
            # Save minimal training state (safe to serialize)
            training_state = {
                'step': step,
                'epoch': state.epoch,
                'global_step': state.global_step,
            }
            
            # Add last loss if available
            if state.log_history:
                last_log = state.log_history[-1]
                if 'loss' in last_log:
                    training_state['loss'] = last_log['loss']
            
            torch.save(
                training_state, 
                os.path.join(checkpoint_dir, 'training_state.pt')
            )
            
            # Track checkpoints
            self.checkpoints.append(checkpoint_dir)
            
            # Remove old checkpoints (keep only last N)
            if len(self.checkpoints) > self.keep_last:
                old_checkpoint = self.checkpoints.pop(0)
                if os.path.exists(old_checkpoint):
                    shutil.rmtree(old_checkpoint)
                    print(f"  üóëÔ∏è  Removed old checkpoint: {os.path.basename(old_checkpoint)}")
            
            print(f"  ‚úÖ Checkpoint saved: {checkpoint_dir}")
            
        except Exception as e:
            print(f"  ‚ùå Failed to save checkpoint: {e}")
            # Continue training even if checkpoint fails
    
    def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
        """Save final model at end of training"""
        print(f"\n{'='*80}")
        print("üèÅ TRAINING COMPLETE - Saving final model")
        print(f"{'='*80}\n")
        
        final_dir = "./qwen_memory_final"
        os.makedirs(final_dir, exist_ok=True)
        
        model.save_pretrained(final_dir, safe_serialization=True)
        if tokenizer:
            tokenizer.save_pretrained(final_dir)
        
        # Save final statistics
        final_stats = {
            'total_steps': state.global_step,
            'total_epochs': state.epoch,
        }
        
        if state.log_history:
            losses = [log['loss'] for log in state.log_history if 'loss' in log]
            if losses:
                final_stats['final_loss'] = losses[-1]
                final_stats['initial_loss'] = losses[0]
                final_stats['loss_improvement'] = losses[0] - losses[-1]
        
        torch.save(final_stats, os.path.join(final_dir, 'final_stats.pt'))
        
        print(f"‚úÖ Final model saved to: {final_dir}")
        print(f"   Total steps: {state.global_step}")
        print(f"   Total epochs: {state.epoch:.2f}")
        if 'loss_improvement' in final_stats:
            print(f"   Loss improvement: {final_stats['loss_improvement']:.4f}")
        print(f"\n{'='*80}\n")

# Initialize callback
memory_monitor = MemoryLayerMonitorAndCheckpoint(model=model,
    layers_to_check=[6, 12, 18],    # Your memory layer indices
    save_every=500,                  # Save checkpoint every 500 steps
    keep_last=2,                     # Keep only 2 checkpoints
    monitor_every=50,)

In [4]:
from datasets import load_dataset

# Load and filter OpenAssistant
dataset = load_dataset("OpenAssistant/oasst1", split="train")

# Keep only high-quality English assistant responses
filtered = dataset.filter(
    lambda x: (
        x['lang'] == 'en' and 
        x['role'] == 'assistant' and 
        x['rank'] == 0.0 and
        len(x['text']) > 50  # Filter out very short responses
    )
)

print(f"Filtered dataset size: {len(filtered)}")

# Take subset
dataset = filtered.select(range(min(20000, len(filtered))))

# Tokenize
def tokenize(examples):
    return tokenizer(
        examples['text'],
        truncation=True,
        max_length=2048,
        padding=False,
    )

tokenized = dataset.map(
    tokenize, 
    batched=True, 
    remove_columns=dataset.column_names,
    num_proc=4  # Speed up with multiprocessing
)

print(f"Tokenized dataset: {tokenized}")

Filtered dataset size: 7669
Tokenized dataset: Dataset({
    features: ['input_ids', 'attention_mask'],
    num_rows: 7669
})


In [5]:
from transformers import Trainer, TrainingArguments

# Training arguments optimized for memory layers only
training_args = TrainingArguments(
    output_dir="./qwen_memory_finetuned",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    num_train_epochs=3,
    learning_rate=5e-4,  # Higher LR since only training memory
    warmup_steps=100,
    lr_scheduler_type="cosine",
    logging_steps=10,
    logging_first_step=True,  # Log immediately
    logging_dir="./logs",
    save_steps=500,
    eval_strategy="steps",
    eval_steps=250,   
    # Performance
    fp16=True,
    gradient_checkpointing=False,  # Not needed with frozen base
    dataloader_num_workers=2,
    
    
    # Monitoring
    report_to="tensorboard",  # or "wandb" if you have it
    # load_best_model_at_end=True,
    metric_for_best_model="loss",
    save_strategy="no",
    
    # Memory optimization
    optim="adamw_torch_fused",  # Faster optimizer
    max_grad_norm=1.0,
)


In [6]:
from transformers import Trainer, DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
)


# Create trainer with callback
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized,
    eval_dataset=tokenized.select(range(1000)),  # Use 1k for validation
    data_collator=data_collator,
    callbacks=[memory_monitor],  # Add our custom monitor
)

print("\nüöÄ Starting training...")
print(f"Total steps: {len(tokenized) // (training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps) * training_args.num_train_epochs}")

# Train! 
trainer.train()

print("\n‚úÖ Training complete!")

Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.



üöÄ Starting training...
Total steps: 1437


Step,Training Loss,Validation Loss
250,1.9712,1.873475
500,1.7182,1.789377
750,1.8301,1.702344
1000,1.6194,1.635654
1250,1.6419,1.603854



üîç MEMORY LAYER HEALTH CHECK - Step 50

üìä Layer 6 Memory:
  Parameters:
    Keys:   mean=+0.0001, std=0.0360
    Values: mean=+0.0000, std=0.0334
  Changes since start:
    Keys:   0.000846 ‚úÖ
    Values: 0.000810 ‚úÖ
  Gradient norms:
    Keys:   0.0000 ‚ùå NO GRAD
    Values: 0.0000 ‚ùå NO GRAD

üìä Layer 12 Memory:
  Parameters:
    Keys:   mean=+0.0000, std=0.0361
    Values: mean=-0.0000, std=0.0334
  Changes since start:
    Keys:   0.000830 ‚úÖ
    Values: 0.000776 ‚úÖ
  Gradient norms:
    Keys:   0.0000 ‚ùå NO GRAD
    Values: 0.0000 ‚ùå NO GRAD

üìä Layer 18 Memory:
  Parameters:
    Keys:   mean=-0.0000, std=0.0361
    Values: mean=+0.0000, std=0.0334
  Changes since start:
    Keys:   0.000801 ‚úÖ
    Values: 0.000749 ‚úÖ
  Gradient norms:
    Keys:   0.0000 ‚ùå NO GRAD
    Values: 0.0000 ‚ùå NO GRAD

‚ö†Ô∏è  Some memory layers need attention!


üîç MEMORY LAYER HEALTH CHECK - Step 100

üìä Layer 6 Memory:
  Parameters:
    Keys:   mean=+0.0000, std=0.0360
    Va

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from memory_layers.memory import HashingMemory
from safetensors.torch import load_file

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct",
    dtype=torch.float16,
)
model.to("cuda")

# Add memory layers
for idx in [6, 12, 18]:
    model.model.layers[idx].mlp = HashingMemory(
        input_dim=896, output_dim=896, mem_n_keys=128, mem_heads=4,
        mem_knn=16, mem_k_dim=256, mem_v_dim=-1, swilu_projection=True,
        value_fixed_lr=0.001, mem_share_values=False
    ).to("cuda")

# Load weights (use safetensors if available, otherwise pytorch with weights_only=False)
try:
    state_dict = load_file("./qwen_memory_final/model.safetensors")
except:
    state_dict = torch.load("./qwen_memory_final/pytorch_model.bin", 
                           weights_only=False)  # ‚Üê Fix here

model.load_state_dict(state_dict, strict=False)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", dtype=torch.float16)

print("\n‚úÖ Model loaded successfully!")

# Test generation
def test_model(prompt):
    inputs = tokenizer(prompt, return_tensors="pt").to(device = "cuda")
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=100,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
        )
    
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response

# Try some prompts
test_prompts = [
    "Explain quantum computing in simple terms:",
    "Write a Python function to sort a list:",
    "What are the health benefits of exercise?",
]

for prompt in test_prompts:
    print(f"\n{'='*80}")
    print(f"Prompt: {prompt}")
    print(f"{'='*80}")
    response = test_model(prompt)
    print(response)


‚úÖ Model loaded successfully!

Prompt: Explain quantum computing in simple terms:


TypeError: PreTrainedTokenizerFast._batch_encode_plus() got an unexpected keyword argument 'dtype'

In [None]:
# Load original Qwen model for comparison
base_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2. 5-0.5B-Instruct",
    torch_dtype=torch.float16,
)
base_model.to("cuda")

def compare_models(prompt):
    # Your fine-tuned model
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    
    with torch. no_grad():
        # Fine-tuned
        ft_outputs = model.generate(**inputs, max_new_tokens=100)
        ft_response = tokenizer.decode(ft_outputs[0], skip_special_tokens=True)
        
        # Base
        base_outputs = base_model.generate(**inputs, max_new_tokens=100)
        base_response = tokenizer.decode(base_outputs[0], skip_special_tokens=True)
    
    print(f"\n{'='*80}")
    print(f"Prompt: {prompt}")
    print(f"{'='*80}")
    print(f"\nüî∑ BASE MODEL:")
    print(base_response)
    print(f"\nüî∂ FINE-TUNED (with memory layers):")
    print(ft_response)
    print(f"{'='*80}\n")

# Test
compare_models("Explain machine learning:")