In [None]:
import gc
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import torch
from safetensors.torch import load_file
from transformers import AutoModelForCausalLM, AutoTokenizer

from memory_layers import HashingMemory, ModelEvaluator

# Configuration
# Prefer the GPU whenever PyTorch can see one; otherwise run on CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

model_id = "Qwen/Qwen2.5-0.5B-Instruct"

# Width used for the HashingMemory input/output; the memory layers replace each
# selected layer's MLP, so this has to equal the model's hidden size.
hidden_dim = 896

# Indices of the decoder layers whose MLP gets swapped for a memory layer.
layers_to_replace = [6, 12, 18]

num_samples = 50  # Small sample for quick testing, increase for full eval

In [None]:
# 1. Evaluate Base Model
print("Loading Base Model...")
base_model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.float16).to(device)
# A tokenizer has no dtype; only the model weights are cast to float16.
tokenizer = AutoTokenizer.from_pretrained(model_id)

evaluator = ModelEvaluator(base_model, tokenizer, device=device)
base_results = evaluator.evaluate_all(num_samples=num_samples)

print("\nBase Model Results:")
print(base_results)

# Free up memory before the fine-tuned model is loaded in the next cell.
# `del` only drops these names; gc.collect() reclaims objects kept alive by
# reference cycles so that empty_cache() can actually release cached GPU blocks.
del base_model
del evaluator
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
# 2. Evaluate Fine-tuned Memory Model
print("Loading Fine-tuned Memory Model...")
model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.float16).to(device)

# The memory layers below are sized from `hidden_dim` and replace each selected
# layer's MLP, so the configured width and layer indices must fit this model.
assert hidden_dim == model.config.hidden_size, (
    f"hidden_dim={hidden_dim} does not match model.config.hidden_size={model.config.hidden_size}"
)
num_decoder_layers = len(model.model.layers)
assert all(0 <= idx < num_decoder_layers for idx in layers_to_replace), (
    f"layers_to_replace={layers_to_replace} out of range for {num_decoder_layers} decoder layers"
)

# Add memory layers: swap the MLP of each selected decoder layer for a HashingMemory module.
for idx in layers_to_replace:
    mem_layer = HashingMemory(
        input_dim=hidden_dim, output_dim=hidden_dim, mem_n_keys=128, mem_heads=4,
        mem_knn=16, mem_k_dim=256, mem_v_dim=-1, swilu_projection=True,
        value_fixed_lr=0.001, mem_share_values=False
    )
    model.model.layers[idx].mlp = mem_layer.to(device, dtype=model.dtype)

# Load fine-tuned weights: prefer safetensors, fall back to the PyTorch .bin file.
# Explicit existence checks (instead of a bare `except:`) keep real load errors,
# such as a corrupt file, visible rather than silently taking the fallback path.
checkpoint_dir = Path("./qwen_memory_final")
safetensors_path = checkpoint_dir / "model.safetensors"
bin_path = checkpoint_dir / "pytorch_model.bin"
if safetensors_path.exists():
    state_dict = load_file(str(safetensors_path))
elif bin_path.exists():
    # weights_only=True refuses to unpickle arbitrary Python objects (a plain
    # tensor state dict still loads); map_location="cpu" keeps a GPU-saved
    # checkpoint loadable on CPU-only machines, matching load_file's default.
    state_dict = torch.load(bin_path, map_location="cpu", weights_only=True)
else:
    raise FileNotFoundError(
        f"No checkpoint found: expected {safetensors_path} or {bin_path}"
    )

# strict=False tolerates keys that may legitimately be absent from the checkpoint
# (NOTE(review): e.g. tied weights — confirm), but every memory-layer tensor must
# come from it; otherwise the evaluation would silently use randomly initialised
# memory layers.
load_result = model.load_state_dict(state_dict, strict=False)
memory_prefixes = tuple(f"model.layers.{idx}.mlp." for idx in layers_to_replace)
missing_memory_keys = [k for k in load_result.missing_keys if k.startswith(memory_prefixes)]
if missing_memory_keys:
    raise RuntimeError(
        f"Checkpoint is missing {len(missing_memory_keys)} memory-layer tensors, "
        f"e.g. {missing_memory_keys[:5]}"
    )
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        f"load_state_dict: {len(load_result.missing_keys)} missing keys, "
        f"{len(load_result.unexpected_keys)} unexpected keys"
    )

# Freshly constructed nn.Modules default to training mode, and the memory layers
# were attached after from_pretrained returned, so put the whole model in eval mode.
model.eval()

evaluator = ModelEvaluator(model, tokenizer, device=device)
ft_results = evaluator.evaluate_all(num_samples=num_samples)

print("\nFine-tuned Model Results:")
print(ft_results)

In [None]:
# 3. Compare Results
# One row per model; columns come from the result objects
# (NOTE(review): presumably evaluate_all returns a flat metric -> score dict — confirm).
df = pd.DataFrame([base_results, ft_results], index=['Base', 'Memory-Augmented'])
print("\nComparison Table:")
print(df)

# Plot: transpose so metrics sit on the x-axis with one bar per model.
fig, ax = plt.subplots(figsize=(10, 6))
df.T.plot(kind='bar', ax=ax, rot=45)
ax.set_title("Model Performance Comparison")
# NOTE(review): y-axis says "Accuracy" — confirm every metric returned is an accuracy.
ax.set_ylabel("Accuracy")
fig.tight_layout()
plt.show()