In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, AdamW
import numpy as np

# Dummy Dataset Class (Replace with actual Wikimedia dataset)
class WikimediaDataset(Dataset):
    def __init__(self, texts, tokenizer, max_length=512):
        self.tokenizer = tokenizer
        self.inputs = []
        for text in texts:
            tokenized = tokenizer(text, truncation=True, max_length=max_length, return_tensors='pt')
            self.inputs.append(tokenized)
    
    def __len__(self):
        return len(self.inputs)
    
    def __getitem__(self, idx):
        return self.inputs[idx]

# Function to compute per-token loss
def compute_per_token_loss(logits, labels, ignore_index=-100):
    """
    Compute the negative log-likelihood loss per token.
    """
    loss_fct = nn.CrossEntropyLoss(reduction='none', ignore_index=ignore_index)
    # Shift logits and labels for causal LM
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    loss = loss_fct(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1)
    )
    loss = loss.view(shift_labels.size())
    return loss

def main():
    # Configuration

    # Load the dataset
    # dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
    batch_size = 4
    epochs = 3
    learning_rate = 5e-5
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Initialize Tokenizer
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")

    # Initialize Models
    reference_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-1.5B").to(device)
    reference_model.eval()  # Reference model is not trained further

    small_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B").to(device)
    small_model.train()

    # Freeze reference model parameters
    for param in reference_model.parameters():
        param.requires_grad = False

    # Prepare Dataset and DataLoader (Replace with actual data)
    dummy_texts = [
        "This is a sample sentence for training.",
        "Another example of Wikimedia text data.",
        "The quick brown fox jumps over the lazy dog.",
        "Artificial intelligence is transforming the world."
    ]
    dataset = WikimediaDataset(dummy_texts, tokenizer)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Optimizer
    optimizer = AdamW(small_model.parameters(), lr=learning_rate)

    # Training Loop
    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")
        for batch_idx, batch in enumerate(dataloader):
            # Move inputs to device
            input_ids = batch['input_ids'].squeeze(1).to(device)
            attention_mask = batch['attention_mask'].squeeze(1).to(device)
            labels = input_ids.clone()

            # Forward pass with small model
            outputs_small = small_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss_small = outputs_small.loss  # This is the averaged loss
            logits_small = outputs_small.logits

            # Compute per-token training loss for small model
            per_token_train_loss = compute_per_token_loss(logits_small, labels)  # Shape: [batch_size, seq_len -1]

            # Forward pass with reference model (inference mode)
            with torch.no_grad():
                outputs_ref = reference_model(input_ids=input_ids, attention_mask=attention_mask)
                logits_ref = outputs_ref.logits
                per_token_ref_loss = compute_per_token_loss(logits_ref, labels)  # Shape: [batch_size, seq_len -1]

            # Compute excess loss
            excess_loss = per_token_train_loss - per_token_ref_loss  # Shape: [batch_size, seq_len -1]

            # Determine threshold for top 30% excess loss
            batch_size_current, seq_len = excess_loss.size()
            top_k = int(0.3 * seq_len)
            # If top_k is 0 due to small seq_len, default to at least 1
            top_k = max(top_k, 1)

            # For each sequence in the batch, create a mask for top 30% excess loss
            masks = torch.zeros_like(excess_loss)
            for i in range(batch_size_current):
                # Get the excess loss for the current sequence
                excess_loss_seq = excess_loss[i]
                # Get the top_k indices
                if top_k >= seq_len:
                    top_indices = torch.arange(seq_len)
                else:
                    top_indices = torch.topk(excess_loss_seq, top_k).indices
                # Set mask to 1 for top_k tokens
                masks[i, top_indices] = 1

            # Apply mask to per-token training loss
            masked_train_loss = per_token_train_loss * masks  # Shape: [batch_size, seq_len -1]

            # Compute the final loss as the average of masked training loss
            # To avoid division by zero, add a small epsilon
            epsilon = 1e-8
            loss = masked_train_loss.sum() / (masks.sum() + epsilon)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (batch_idx + 1) % 10 == 0 or (batch_idx + 1) == len(dataloader):
                print(f"Batch {batch_idx + 1}/{len(dataloader)} - Loss: {loss.item():.4f}")

    # Save the trained small model
    small_model.save_pretrained('trained_qwen2.5-0.5B')
    tokenizer.save_pretrained('trained_qwen2.5-0.5B')

if __name__ == "__main__":
    main()

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Epoch 1/3


RuntimeError: stack expects each tensor to be equal size, but got [1, 8] at entry 0 and [1, 10] at entry 1