# Этап 12: Wanda (Dual-Track) — nanoGPT vs Llama

Метод **Wanda** (arXiv:2306.11695) оценивает важность каждого веса, используя и величину самого веса, и норму соответствующей входной активации.
Здесь мы применим его параллельно к обеим моделям.

### Формула важности:
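$$\mathrm{Score}_{i,j} = |W_{i,j}| \cdot \|X_j\|_2$$

где $W_{i,j}$ — вес, связывающий $j$-й входной признак с $i$-м выходом, а $\|X_j\|_2$ — L2-норма $j$-го входного признака по всем токенам калибровочного батча.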

In [None]:
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from src.model import GPTLanguageModel, device, get_batch

# 1. Load both models.
# nanoGPT: build the project model on `device` and restore its weights
# from the local checkpoint file.
nanogpt = GPTLanguageModel().to(device)
nanogpt.load_state_dict(
    torch.load('model_ckpt.pt', map_location=device)
)

# TinyLlama: fetch the pretrained weights in float16, move them to `device`,
# and load the tokenizer published under the same model id.
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
llama = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_id)

def get_activation_norms(model, layer, input_data, is_llama=False):
    """Return the per-feature L2 norm of the input that reaches ``layer``.

    A temporary forward hook on ``layer`` captures the first positional input
    during a single no-grad call ``model(input_data)``. The norm is reduced
    over every dimension except the last one, so the result has one entry per
    input feature regardless of how many leading (batch/sequence) dimensions
    the activation has.

    Args:
        model: Callable model that invokes ``layer`` during its forward pass.
        layer: Module whose input activations are measured.
        input_data: Passed unchanged as the only argument to ``model``.
        is_llama: Kept for backward compatibility. Both model families are
            invoked the same way, so the flag does not change behavior.

    Returns:
        1-D float32 tensor with ``x.shape[-1]`` elements, where ``x`` is the
        first input seen by ``layer``. If ``layer`` runs more than once, the
        statistics of the first call are returned.

    Raises:
        RuntimeError: If ``layer`` was never executed during the forward pass.
    """
    captured = []

    def hook(module, inputs, output):
        x = inputs[0].detach().float()
        # Collapse all leading dims so 2-D and 3-D activations are handled alike.
        flat = x.reshape(-1, x.shape[-1])
        captured.append(torch.linalg.vector_norm(flat, ord=2, dim=0))

    # Calibration statistics must not be perturbed by dropout, so run in eval
    # mode and restore whatever mode the caller had afterwards.
    is_module = isinstance(model, torch.nn.Module)
    was_training = model.training if is_module else False
    if is_module:
        model.eval()

    handle = layer.register_forward_hook(hook)
    try:
        with torch.no_grad():
            model(input_data)
    finally:
        # Always detach the hook, even if the forward pass failed.
        handle.remove()
        if is_module:
            model.train(was_training)

    if not captured:
        raise RuntimeError(
            "layer was not executed during the forward pass; "
            "no activations were captured"
        )
    return captured[0]

def apply_wanda(layer, norms, sparsity=0.5):
    """Prune ``layer.weight`` in place using the Wanda importance score.

    Each weight is scored as ``|W[i, j]| * norms[j]``. Within every output row
    exactly ``int(in_features * sparsity)`` lowest-scoring weights are set to
    zero. Selection is index-based, so tied scores never cause more weights
    to be removed than requested.

    Args:
        layer: Module with a 2-D ``weight`` of shape (out_features, in_features).
        norms: Tensor with ``in_features`` elements holding the activation
            norm of each input feature.
        sparsity: Fraction of weights to remove in each row, in ``[0, 1]``.

    Returns:
        Boolean tensor with the shape of ``layer.weight``; ``True`` marks the
        weights that were kept.

    Raises:
        ValueError: If ``sparsity`` is outside ``[0, 1]`` or ``norms`` does not
            have one element per input feature.
    """
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError(f"sparsity must be in [0, 1], got {sparsity}")

    with torch.no_grad():
        # Score in float32 even when the layer itself is stored in float16.
        W = layer.weight.data.float()
        in_features = W.size(1)

        # reshape (not view) tolerates non-contiguous norms; align the device too.
        scale = norms.detach().to(device=W.device, dtype=W.dtype).reshape(1, -1)
        if scale.size(1) != in_features:
            raise ValueError(
                f"norms has {scale.size(1)} elements, "
                f"expected {in_features} (in_features of the layer)"
            )

        score = W.abs() * scale
        k = int(in_features * sparsity)

        mask = torch.ones_like(score, dtype=torch.bool)
        if k > 0:
            # Indices of the k smallest scores per row; exactly k are pruned.
            _, prune_idx = torch.topk(score, k, dim=1, largest=False)
            mask.scatter_(1, prune_idx, False)

        layer.weight.data.mul_(mask.to(layer.weight.dtype))
        return mask

# 2. Apply Wanda to nanoGPT.
# Take one validation batch; the second element of the returned pair is unused.
xb, _ = get_batch('val')
# Target module: nanogpt.blocks[0].ffwd.net[0].
nano_layer = nanogpt.blocks[0].ffwd.net[0]
nano_norms = get_activation_norms(
    model=nanogpt,
    layer=nano_layer,
    input_data=xb,
)
# Prune in place with the default sparsity; the returned mask is not kept.
apply_wanda(layer=nano_layer, norms=nano_norms)
print("nanoGPT: Wanda applied to blocks[0].ffwd")

# 3. Apply Wanda to Llama.
# A single tokenized prompt serves as the calibration input.
prompt = "The concept of model compression is"
encoded = tokenizer(prompt, return_tensors="pt").to(device)
inputs = encoded.input_ids
# Target module: llama.model.layers[0].mlp.gate_proj.
llama_layer = llama.model.layers[0].mlp.gate_proj
llama_norms = get_activation_norms(
    llama,
    llama_layer,
    inputs,
    is_llama=True,
)
# Prune in place with the default sparsity; the returned mask is not kept.
apply_wanda(llama_layer, llama_norms)
print("Llama: Wanda applied to layers[0].mlp.gate_proj")