In [None]:
# Simplified imports for single GPU
from logging import getLogger
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from memory_layers import HashingMemory, MemoryLayerMonitorAndCheckpoint, load_and_process_dataset

# Notebook-wide logger; getLogger() called without a name returns the root logger.
logger = getLogger()

In [None]:
# Pick the GPU when present; everything below also has to work on the CPU fallback.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load Qwen2.5-0.5B-Instruct in half precision and move it to the target device.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", dtype=torch.float16).to(device)

# Tokenizers have no dtype. The original call passed dtype=torch.float16 here, and the
# TypeError recorded later in this notebook ("_batch_encode_plus() got an unexpected
# keyword argument 'dtype'") is raised from a plain tokenizer call, which is consistent
# with that stray kwarg being kept and forwarded at encode time.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

# Read the width from the loaded config instead of hard-coding it
# (896 for Qwen2.5-0.5B, which has 24 decoder layers).
hidden_dim = model.config.hidden_size
layers_to_replace = [6, 12, 18]  # Which FFN layers to replace

# Replace FFNs with Memory Layers
for layer_idx in layers_to_replace:
    layer = model.model.layers[layer_idx]

    # Create memory layer
    memory_layer = HashingMemory(
        input_dim=hidden_dim,
        output_dim=hidden_dim,
        # Following the n_keys^2 sizing rule stated in the original note, 128 keys
        # give 128^2 = 16,384 memory entries (the old comment said 512^2 = 262k,
        # which did not match the value passed). Confirm against HashingMemory.
        mem_n_keys=128,
        mem_heads=4,
        mem_knn=16,
        mem_k_dim=256,
        mem_v_dim=-1,            # Auto: uses output_dim
        swilu_projection=True,
        value_fixed_lr=0.001,
        mem_share_values=False,  # Don't share across layers for fine-tuning
    )

    # Initialize the memory layer, then place it on the same device as the model.
    memory_layer.reset_parameters()
    memory_layer.to(device)

    # Swap the FFN (MLP) for the memory layer. The previous MLP is not kept:
    # the original stored it in an unused local that was overwritten every iteration.
    layer.mlp = memory_layer

    print(f"Replaced layer {layer_idx} FFN with memory layer")

# FREEZE EVERYTHING EXCEPT MEMORY LAYERS
# A parameter stays trainable only if its qualified name contains 'mlp' and sits
# under one of the replaced decoder layers (e.g. "model.layers.6.mlp.keys").
for name, param in model.named_parameters():
    if 'mlp' in name and any(f'layers.{idx}.' in name for idx in layers_to_replace):
        # This is a memory layer parameter - keep trainable
        param.requires_grad = True
        print(f"âœ“ Trainable: {name}")
    else:
        # Freeze all other parameters
        param.requires_grad = False

# Verify what's trainable
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f"\nTrainable: {trainable_params:,} / {total_params:,} ({100*trainable_params/total_params:.2f}%)")

  from .autonotebook import tqdm as notebook_tqdm


Replaced layer 6 FFN with memory layer
Replaced layer 12 FFN with memory layer
Replaced layer 18 FFN with memory layer
âœ“ Trainable: model.layers.6.mlp.keys
âœ“ Trainable: model.layers.6.mlp.values.weight
âœ“ Trainable: model.layers.6.mlp.value_proj.weight
âœ“ Trainable: model.layers.6.mlp.value_proj.bias
âœ“ Trainable: model.layers.6.mlp.swilu_projection.weight
âœ“ Trainable: model.layers.6.mlp.swilu_projection.bias
âœ“ Trainable: model.layers.6.mlp.query_proj.query_mlps.0.weight
âœ“ Trainable: model.layers.6.mlp.query_proj.query_mlps.0.bias
âœ“ Trainable: model.layers.12.mlp.keys
âœ“ Trainable: model.layers.12.mlp.values.weight
âœ“ Trainable: model.layers.12.mlp.value_proj.weight
âœ“ Trainable: model.layers.12.mlp.value_proj.bias
âœ“ Trainable: model.layers.12.mlp.swilu_projection.weight
âœ“ Trainable: model.layers.12.mlp.swilu_projection.bias
âœ“ Trainable: model.layers.12.mlp.query_proj.query_mlps.0.weight
âœ“ Trainable: model.layers.12.mlp.query_proj.query_mlps.0.bias
âœ“ Trainab

In [None]:
# Build the tokenized training corpus with the project helper.
# NOTE(review): sample_size is passed straight through; its exact meaning (rows drawn
# before vs. after filtering) lives in load_and_process_dataset - confirm there.
DATASET_SAMPLE_SIZE = 20000
tokenized = load_and_process_dataset(tokenizer, sample_size=DATASET_SAMPLE_SIZE)

In [None]:
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling

# fp16 mixed precision and the fused AdamW implementation are GPU features. Gate them
# on CUDA so the CPU fallback chosen for `device` does not fail at argument validation.
use_cuda = torch.cuda.is_available()

# Hold out a validation split. The original evaluated on tokenized.select(range(1000)),
# i.e. on rows that were also trained on, so eval loss could not show overfitting
# (and the select() call raised for datasets with fewer than 1000 rows).
EVAL_SIZE = 1000
SPLIT_SEED = 42
dataset_splits = tokenized.train_test_split(
    test_size=min(EVAL_SIZE, max(1, len(tokenized) // 5)),
    seed=SPLIT_SEED,
)
train_dataset = dataset_splits["train"]
eval_dataset = dataset_splits["test"]

# Training arguments optimized for memory layers only
training_args = TrainingArguments(
    output_dir="./qwen_memory_finetuned",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    num_train_epochs=3,
    learning_rate=5e-4,  # Higher LR since only training memory
    warmup_steps=100,
    lr_scheduler_type="cosine",
    logging_steps=10,
    logging_first_step=True,  # Log immediately
    logging_dir="./logs",
    save_steps=500,  # Only takes effect if save_strategy is switched to "steps"
    eval_strategy="steps",
    eval_steps=250,
    # Performance
    fp16=use_cuda,
    gradient_checkpointing=False,  # Not needed with frozen base
    dataloader_num_workers=2,

    # Monitoring
    report_to="tensorboard",  # or "wandb" if you have it
    # load_best_model_at_end=True,
    metric_for_best_model="loss",
    save_strategy="no",  # Trainer checkpoints are off; the callback below is given its own save_every

    # Memory optimization
    optim="adamw_torch_fused" if use_cuda else "adamw_torch",  # Fused kernel only on GPU
    max_grad_norm=1.0,
)

# Causal-LM collation (mlm=False): no masked-language-modeling objective.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
)

# Initialize callback
memory_monitor = MemoryLayerMonitorAndCheckpoint(
    model=model,
    layers_to_check=layers_to_replace,
    save_every=500,
    keep_last=2,
    monitor_every=50,
)

# Create trainer with callback
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,  # Held-out rows, disjoint from train_dataset
    data_collator=data_collator,
    callbacks=[memory_monitor],  # Add our custom monitor
)

Filtered dataset size: 7669
Tokenized dataset: Dataset({
    features: ['input_ids', 'attention_mask'],
    num_rows: 7669
})


In [None]:
print("\nðŸš€ Starting training...")

# Estimate the number of optimizer steps from the dataset the trainer actually holds,
# so the figure stays right even if the training set differs from `tokenized`.
# This is an estimate: it floors the partial last step of each epoch.
train_examples = len(trainer.train_dataset)
examples_per_step = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
print(f"Total steps: {train_examples // examples_per_step * training_args.num_train_epochs}")

# Train!
trainer.train()

# Persist the final weights where the evaluation cell expects them. Nothing else
# visible in this notebook writes ./qwen_memory_final (save_strategy is "no" and
# output_dir points elsewhere), so without this a fresh Restart & Run All cannot
# load the fine-tuned model.
FINAL_MODEL_DIR = "./qwen_memory_final"
trainer.save_model(FINAL_MODEL_DIR)

print("\nâœ… Training complete!")

In [None]:
from safetensors.torch import load_file

import os

device = "cuda" if torch.cuda.is_available() else "cpu"
layers_to_replace = [6, 12, 18]
MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
CHECKPOINT_DIR = "./qwen_memory_final"

# Reload model for testing
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    dtype=torch.float16,
).to(device)
hidden_dim = model.config.hidden_size  # 896 for Qwen2.5-0.5B

# Load a clean tokenizer with no extra kwargs, so this cell does not depend on the one
# built earlier. The TypeError recorded in this notebook's output complains about an
# unexpected 'dtype' keyword on a plain tokenizer call.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Add memory layers (same hyperparameters as used for training)
# NOTE(review): the base model is float16 while these layers are created with their
# default dtype and only moved to `device`. If HashingMemory does not cast its inputs,
# generation may hit a dtype mismatch - confirm, and cast the layers if needed.
for idx in layers_to_replace:
    model.model.layers[idx].mlp = HashingMemory(
        input_dim=hidden_dim, output_dim=hidden_dim, mem_n_keys=128, mem_heads=4,
        mem_knn=16, mem_k_dim=256, mem_v_dim=-1, swilu_projection=True,
        value_fixed_lr=0.001, mem_share_values=False
    ).to(device)

# Load weights: prefer safetensors and fall back to the pickle checkpoint only when the
# safetensors file is absent. The original bare `except:` fell back on *any* error,
# hiding e.g. a corrupt file.
safetensors_path = os.path.join(CHECKPOINT_DIR, "model.safetensors")
if os.path.exists(safetensors_path):
    state_dict = load_file(safetensors_path)
else:
    # weights_only=True restricts unpickling to tensors and plain containers, so the
    # checkpoint cannot run arbitrary code; a state dict needs nothing more.
    state_dict = torch.load(
        os.path.join(CHECKPOINT_DIR, "pytorch_model.bin"),
        map_location="cpu",
        weights_only=True,
    )

# strict=False tolerates keys that are absent from the file, but the memory-layer
# weights are the whole point of this checkpoint: fail loudly if any are missing
# rather than testing randomly initialised layers.
load_result = model.load_state_dict(state_dict, strict=False)
memory_prefixes = tuple(f"model.layers.{idx}.mlp." for idx in layers_to_replace)
missing_memory_keys = [key for key in load_result.missing_keys if key.startswith(memory_prefixes)]
if missing_memory_keys:
    raise RuntimeError(
        f"Checkpoint in {CHECKPOINT_DIR} is missing memory-layer weights: {missing_memory_keys}"
    )
print("\nâœ… Model loaded successfully!")

# Test generation
def test_model(prompt):
    """Sample a completion for ``prompt`` from the notebook-level ``model``.

    The prompt is encoded with the notebook-level ``tokenizer`` and moved to
    ``device``; generation runs under ``torch.no_grad()`` with sampling enabled.

    Args:
        prompt: Text to encode and feed to ``model.generate``.

    Returns:
        The first generated sequence, decoded with special tokens skipped.
    """
    sampling_options = {
        "max_new_tokens": 100,
        "temperature": 0.7,
        "top_p": 0.9,
        "do_sample": True,
    }
    encoded_prompt = tokenizer(prompt, return_tensors="pt").to(device)

    # No gradients are needed for inference.
    with torch.no_grad():
        generated = model.generate(**encoded_prompt, **sampling_options)

    return tokenizer.decode(generated[0], skip_special_tokens=True)

# Sample prompts used to smoke-test the fine-tuned model
test_prompts = [
    "Explain quantum computing in simple terms:",
    "Write a Python function to sort a list:",
    "What are the health benefits of exercise?",
]

# For each prompt: print a banner with the prompt, then the generated text.
for prompt in test_prompts:
    print(f"\n{'='*80}")
    print(f"Prompt: {prompt}")
    print(f"{'='*80}")
    print(test_model(prompt))


âœ… Model loaded successfully!

Prompt: Explain quantum computing in simple terms:


TypeError: PreTrainedTokenizerFast._batch_encode_plus() got an unexpected keyword argument 'dtype'

In [None]:
# Load original Qwen model for comparison.
# Uses the same `dtype=` keyword as the other from_pretrained model calls in this
# notebook (this one previously used `torch_dtype=`), then moves it to `device`.
base_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct",
    dtype=torch.float16,
).to(device)

def compare_models(prompt):
    """Print base-model and fine-tuned completions for ``prompt`` side by side.

    The prompt is encoded once with the notebook-level ``tokenizer``; the
    fine-tuned ``model`` generates first, then ``base_model``, both under
    ``torch.no_grad()`` with up to 100 new tokens. Returns nothing.

    Args:
        prompt: Text passed to both models.
    """
    encoded_prompt = tokenizer(prompt, return_tensors="pt").to(device)

    def _complete(language_model):
        # Generate from the shared encoded prompt and decode the first sequence.
        generated = language_model.generate(**encoded_prompt, max_new_tokens=100)
        return tokenizer.decode(generated[0], skip_special_tokens=True)

    with torch.no_grad():
        # Order matters when sampling is on: fine-tuned model first, base second.
        ft_response = _complete(model)
        base_response = _complete(base_model)

    report_lines = (
        f"\n{'='*80}",
        f"Prompt: {prompt}",
        f"{'='*80}",
        f"\nðŸ”· BASE MODEL:",
        base_response,
        f"\nðŸ”¶ FINE-TUNED (with memory layers):",
        ft_response,
        f"{'='*80}\n",
    )
    for line in report_lines:
        print(line)

# Test: print the base vs. fine-tuned comparison for one sample prompt
compare_models("Explain machine learning:")