In [None]:
import torch
from transformers import LlamaTokenizer, LlamaForCausalLM
from datasets import load_dataset
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM

# Configuration: the small model being trained and the (frozen) reference model
# that scores how "learnable" each token is.
TINY_MODEL_NAME = "Qwen/Qwen2.5-0.5B"
REFERENCE_MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"

# Load the dataset
dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')

# Load the tiny (trained) and reference (scoring-only) Qwen models.
# NOTE(review): the 7B reference is loaded at the library's default dtype,
# which may need a lot of memory; consider passing torch_dtype=... if that is
# a problem.
tiny_tokenizer = AutoTokenizer.from_pretrained(TINY_MODEL_NAME)
tiny_model = AutoModelForCausalLM.from_pretrained(TINY_MODEL_NAME)
reference_tokenizer = AutoTokenizer.from_pretrained(REFERENCE_MODEL_NAME)
reference_model = AutoModelForCausalLM.from_pretrained(REFERENCE_MODEL_NAME)

# The reference losses below are keyed by token id and looked up with ids
# produced by the *tiny* tokenizer, so both tokenizers must map text to the
# same ids. Fail fast here rather than silently comparing unrelated tokens.
tokenizer_probe = "The quick brown fox jumps over the lazy dog."
assert (
    tiny_tokenizer(tokenizer_probe)['input_ids']
    == reference_tokenizer(tokenizer_probe)['input_ids']
), "tiny and reference tokenizers disagree; token-id keyed reference losses would be meaningless"


In [None]:


# Reference loss per token id, scored by the frozen reference model.
def calculate_reference_loss(dataset, model=None, tokenizer=None):
    """Mean per-token reference loss (NLL), keyed by target token id.

    Every non-empty example is run through the reference model once. For each
    position t the negative log-likelihood of token t+1 given its prefix is
    computed (the same shifted objective HF uses for ``outputs.loss``), but
    kept per position instead of averaged over the sequence. The value stored
    for a token id is the mean of those NLLs over every position where that id
    was the prediction target.

    This runs the reference model over every line of ``dataset``, so it can be
    expensive for large models.

    Args:
        dataset: iterable of mappings with a ``'text'`` string field.
        model: causal LM used as reference; defaults to the global
            ``reference_model``.
        tokenizer: tokenizer for ``model``; defaults to the global
            ``reference_tokenizer``.

    Returns:
        dict[int, float]: token id -> mean reference NLL (a positive number;
        lower means the reference finds the token easier). Ids that never
        occur as a prediction target are absent.
    """
    if model is None:
        model = reference_model
    if tokenizer is None:
        tokenizer = reference_tokenizer

    device = next(model.parameters()).device
    loss_sums = {}
    loss_counts = {}
    for example in dataset:
        text = example['text']
        # wikitext has many blank lines; they tokenize to zero tokens and
        # would crash the forward pass.
        if not text.strip():
            continue
        inputs = tokenizer(text, return_tensors='pt')
        inputs = {key: value.to(device) for key, value in inputs.items()}
        input_ids = inputs['input_ids']
        # With a single token there is no next-token target, hence no loss.
        if input_ids.shape[1] < 2:
            continue
        with torch.no_grad():
            logits = model(**inputs).logits
            # Position t predicts token t+1; keep one NLL per position.
            targets = input_ids[0, 1:]
            token_nll = torch.nn.functional.cross_entropy(
                logits[0, :-1].float(), targets, reduction='none'
            )
        for token_id, nll in zip(targets.tolist(), token_nll.tolist()):
            loss_sums[token_id] = loss_sums.get(token_id, 0.0) + nll
            loss_counts[token_id] = loss_counts.get(token_id, 0) + 1
    return {token_id: loss_sums[token_id] / loss_counts[token_id] for token_id in loss_sums}

# Score every token of the training split once with the reference model.
reference_loss_dict = calculate_reference_loss(dataset=dataset)

In [None]:
# Selective training of the tiny model on its highest-excess-loss tokens.
def train_tiny_llama(tiny_model, dataset, reference_loss_dict, tokenizer=None,
                     num_epochs=3, lr=1e-4, select_ratio=0.3):
    """Fine-tune ``tiny_model`` on only its highest-excess-loss tokens.

    For each example the tiny model's per-position NLL (position t predicting
    token t+1) is compared with ``reference_loss_dict[target_token_id]``
    (0 when the id is missing). Tokens whose excess loss
    ``tiny_nll - reference_nll`` is at or above the ``1 - select_ratio``
    quantile are kept. The gradient step uses the mean NLL of those selected
    tokens only.

    Args:
        tiny_model: causal LM to train in place; its forward must return an
            object with ``.logits``.
        dataset: iterable of mappings with a ``'text'`` string field.
        reference_loss_dict: token id -> reference NLL (positive).
        tokenizer: tokenizer for ``tiny_model``; defaults to the global
            ``tiny_tokenizer``.
        num_epochs: passes over ``dataset``.
        lr: Adam learning rate.
        select_ratio: fraction of tokens (by excess loss) to train on, in
            (0, 1]. The default 0.3 keeps the top 30%.

    Raises:
        ValueError: if ``select_ratio`` is outside (0, 1].
    """
    if tokenizer is None:
        tokenizer = tiny_tokenizer
    if not 0.0 < select_ratio <= 1.0:
        raise ValueError(f"select_ratio must be in (0, 1], got {select_ratio}")

    device = next(tiny_model.parameters()).device
    optimizer = torch.optim.Adam(tiny_model.parameters(), lr=lr)
    tiny_model.train()

    for _ in range(num_epochs):
        for example in dataset:
            text = example['text']
            # Blank wikitext lines produce zero tokens and would crash.
            if not text.strip():
                continue
            inputs = tokenizer(text, return_tensors='pt')
            inputs = {key: value.to(device) for key, value in inputs.items()}
            input_ids = inputs['input_ids']
            # A single token has no next-token target.
            if input_ids.shape[1] < 2:
                continue

            # Per-token loss: outputs.loss is a scalar mean and cannot be
            # split per token, so compute the shifted NLL with reduction='none'.
            logits = tiny_model(**inputs).logits
            targets = input_ids[0, 1:]
            token_nll = torch.nn.functional.cross_entropy(
                logits[0, :-1].float(), targets, reduction='none'
            )

            # Excess loss only drives selection, so it is computed detached.
            reference_nll = torch.tensor(
                [reference_loss_dict.get(token_id, 0.0) for token_id in targets.tolist()],
                dtype=token_nll.dtype,
                device=token_nll.device,
            )
            excess = token_nll.detach() - reference_nll
            threshold = torch.quantile(excess, 1.0 - select_ratio)
            selected = excess >= threshold  # at least the max element qualifies

            # Index the graph-connected loss so gradients reach the model.
            modified_loss = token_nll[selected].mean()
            optimizer.zero_grad()
            modified_loss.backward()
            optimizer.step()

# Train the tiny model in place against the precomputed reference losses.
train_tiny_llama(tiny_model=tiny_model, dataset=dataset, reference_loss_dict=reference_loss_dict)