In [None]:
# Simplified imports for single GPU
from logging import getLogger
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer
from memory_layers import HashingMemory, MemoryLayerMonitorAndCheckpoint, load_and_process_dataset, load_wikitext_dataset, ModelEvaluator

logger = getLogger()

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load Qwen0.5 Instruct
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", dtype=torch.float16).to(device)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", dtype=torch.float16)

# Qwen0.5 specs: 896 hidden_dim, 24 layers
hidden_dim = 896
layers_to_replace = [6, 12, 18]  # Which FFN layers to replace

In [None]:
# Load and process dataset
# Using WikiText for distillation as requested
full_dataset = load_wikitext_dataset(tokenizer, sample_size=20000)

# Split into train and eval for IterableDataset
# Since load_wikitext_dataset returns an IterableDataset (streaming=True)
eval_dataset = full_dataset.take(1000)
train_dataset = full_dataset.skip(1000)

print("Datasets loaded (Iterable).")

Filtered dataset size: 24406
Tokenized dataset: Dataset({
    features: ['input_ids', 'attention_mask'],
    num_rows: 20000
})
Training set size: 18000
Evaluation set size: 2000


In [4]:
# from transformers import Trainer, DataCollatorForLanguageModeling

# print("ðŸ“Š Evaluating Base Instruct Model Baseline...")
# # Create a temporary trainer just for evaluation
# base_trainer = Trainer(
#     model=model,
#     eval_dataset=eval_dataset,
#     data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
#     # report_to="tensorboard",
# )
# base_metrics = base_trainer.evaluate()
# print(f"Base Instruct Model Eval Loss: {base_metrics['eval_loss']:.4f}")
# print(f"Base Instruct Model Perplexity: {torch.exp(torch.tensor(base_metrics['eval_loss'])):.4f}")

In [None]:
# Configuration
LAYERS_TO_REPLACE = [6, 12, 18]
HIDDEN_DIM = 896
BATCH_SIZE = 4
STEPS = 500 # Reduced for demo/speed, increase for real training
LEARNING_RATE = 1e-3

# Helper to capture input/output of the specific FFN layer
class IOCollector:
    def __init__(self):
        self.input = None
        self.target = None

    def hook(self, module, input, output):
        self.input = input[0].detach() # The hidden state entering FFN
        self.target = output.detach()  # The output of the FFN

# Distillation Loop Function
def distill_layer(layer_idx, model, tokenizer, dataset, device):
    print(f"\n{'='*40}")
    print(f"Starting Distillation for Layer {layer_idx}")
    print(f"{'='*40}")

    # 1. Initialize Student Memory Layer
    student_memory = HashingMemory(
        input_dim=HIDDEN_DIM,
        output_dim=HIDDEN_DIM,
        mem_n_keys=128,
        mem_heads=4,
        mem_knn=16,
        mem_k_dim=256,
        mem_v_dim=-1,
        swilu_projection=True,
        value_fixed_lr=0.001,
        mem_share_values=False,
    ).to(device)
    # Important: Cast to model's dtype
    student_memory = student_memory.to(dtype=model.dtype)
    student_memory.train()

    # 2. Setup Optimizer
    optimizer = torch.optim.AdamW(student_memory.parameters(), lr=LEARNING_RATE)

    # 3. Attach hook to the target layer's MLP
    target_layer = model.model.layers[layer_idx].mlp
    collector = IOCollector()
    hook_handle = target_layer.register_forward_hook(collector.hook)

    # 4. Data Loader
    # Create a fresh dataloader for each layer to ensure we iterate through data
    # Since dataset is Iterable, we just iterate it directly
    
    step = 0
    total_loss = 0
    
    # Iterate through dataset
    for batch in dataset:
        if step >= STEPS: break
        
        # A. Run Teacher (Forward pass up to the layer)
        # We only need to run the model; the hook captures the data automatically.
        inputs = torch.tensor(batch['input_ids']).unsqueeze(0).to(device)
        
        with torch.no_grad():
            # We can stop early to save compute if we knew how, but for now run full forward
            # Ideally we'd use a custom forward that stops at layer_idx
            model(inputs)
        
        # Retrieve captured data
        x_input = collector.input  # [Batch, Seq, Dim]
        y_target = collector.target # [Batch, Seq, Dim]
        
        # B. Run Student (Memory Layer)
        # We pass the EXACT same input the FFN saw
        # HashingMemory returns just the output tensor in forward()
        y_pred = student_memory(x_input)
        
        # C. Calculate Loss
        # Main Task: Reconstruction (MSE)
        mse_loss = F.mse_loss(y_pred, y_target)
        
        # Aux Task: Load Balancing (if available in your implementation)
        # Assuming standard HashingMemory implementation might not return aux loss in forward
        # If it does, adjust here. For now, just MSE.
        
        loss = mse_loss
        
        # D. Optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        
        if step % 50 == 0:
            print(f"Step {step}: MSE={loss.item():.6f}")
        
        step += 1

    # Cleanup
    hook_handle.remove()
    avg_loss = total_loss / step
    print(f"Layer {layer_idx} Distillation Complete. Avg MSE: {avg_loss:.6f}")
    
    return student_memory

# Run distillation for each layer
distilled_layers = {}

# We need a fresh iterator for each layer or reset it
# Since it's an iterable dataset, we can't easily reset. 
# We'll re-instantiate the dataset iterator or just continue if it's large enough.
# Given sample_size=20000 and STEPS=500, we have plenty of data.
# We can just continue iterating the same dataset object if we are careful.
# Or better, reload it to be safe and consistent.

for idx in LAYERS_TO_REPLACE:
    # Reload dataset to ensure fresh stream for each layer
    # (Optional but good for consistency)
    current_dataset = load_wikitext_dataset(tokenizer, sample_size=5000) 
    
    distilled_memory = distill_layer(idx, model, tokenizer, current_dataset, device)
    distilled_layers[idx] = distilled_memory
    
    # Save individual layer
    torch.save(distilled_memory.state_dict(), f"distilled_memory_layer_{idx}.pt")

print("\nAll layers distilled and saved.")

Replaced layer 6 FFN with memory layer
Replaced layer 12 FFN with memory layer
Replaced layer 18 FFN with memory layer

Trainable: 506,820,736 / 506,820,736 (100.00%)


In [None]:
# Integration: Swap in the distilled layers
print("\nIntegrating distilled layers into the model...")

for layer_idx, memory_layer in distilled_layers.items():
    # Ensure it's on the right device/dtype
    memory_layer = memory_layer.to(device, dtype=model.dtype)
    
    # Replace
    model.model.layers[layer_idx].mlp = memory_layer
    print(f"Replaced layer {layer_idx} with distilled memory layer")

# Save the final retrofitted model
output_path = "./qwen_memory_retrofitted"
model.save_pretrained(output_path)
tokenizer.save_pretrained(output_path)
print(f"Saved retrofitted model to {output_path}")

In [11]:
del model
torch.cuda.empty_cache()

In [None]:
from safetensors.torch import load_file

# Reload model for verification
print("\nVerifying retrofitted model...")
device = "cuda" if torch.cuda.is_available() else "cpu"
hidden_dim = 896
layers_to_replace = [6, 12, 18]

# Load base model again
model_verify = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct",
    dtype=torch.float16,
).to(device)

# Add memory layers structure
for idx in layers_to_replace:
    mem_layer = HashingMemory(
        input_dim=hidden_dim, output_dim=hidden_dim, mem_n_keys=128, mem_heads=4,
        mem_knn=16, mem_k_dim=256, mem_v_dim=-1, swilu_projection=True,
        value_fixed_lr=0.001, mem_share_values=False
    )
    model_verify.model.layers[idx].mlp = mem_layer.to(device, dtype=model_verify.dtype)

# Load the saved weights
try:
    state_dict = load_file("./qwen_memory_retrofitted/model.safetensors")
    print("Loaded from safetensors")
except:
    state_dict = torch.load("./qwen_memory_retrofitted/pytorch_model.bin", weights_only=False)
    print("Loaded from pytorch_model.bin")

model_verify.load_state_dict(state_dict, strict=False)
print("âœ… Retrofitted model loaded successfully!")

# Test generation
def test_model(model, prompt):
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=50,
            temperature=0.7,
            do_sample=True,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

print("\nTest Generation:")
print(test_model(model_verify, "Explain quantum computing in one sentence:"))

Loaded from safetensors

âœ… Model loaded successfully!


In [13]:
print("ðŸ“Š Evaluating Base Model With Init Memory Values...")
# Create a temporary trainer just for evaluation
base_trainer = Trainer(
    model=model,
    eval_dataset=eval_dataset,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
base_metrics = base_trainer.evaluate()
print(f"Base Model With Init Memory Values Eval Loss: {base_metrics['eval_loss']:.4f}")
print(f"Base Model With Init Memory Values Perplexity: {torch.exp(torch.tensor(base_metrics['eval_loss'])):.4f}")

Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


ðŸ“Š Evaluating Base Model With Init Memory Values...


Base Model With Init Memory Values Eval Loss: 2.0576
Base Model With Init Memory Values Perplexity: 7.8269
