In [2]:
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset

class TextDataset(Dataset):
    def __init__(self, texts, tokenizer, max_length):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        encoding = self.tokenizer(text, max_length=self.max_length, padding='max_length', truncation=True, return_tensors='pt')
        return {
            'input_ids': encoding['input_ids'].squeeze(),
            'attention_mask': encoding['attention_mask'].squeeze()
        }

def load_calibration_data(tokenizer, max_length=512, batch_size=32):
    # Load a small subset of C4 dataset
    dataset = load_dataset("c4", "en", split="train[:1%]")
    texts = dataset['text']
    dataset = TextDataset(texts, tokenizer, max_length)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

def load_train_data(tokenizer, max_length=512, batch_size=32):
    # Load the full C4 dataset
    dataset = load_dataset("c4", "en", split="train")
    texts = dataset['text']
    dataset = TextDataset(texts, tokenizer, max_length)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

def load_val_data(tokenizer, max_length=512, batch_size=32):
    # Load the validation split of C4 dataset
    dataset = load_dataset("c4", "en", split="validation")
    texts = dataset['text']
    dataset = TextDataset(texts, tokenizer, max_length)
    return DataLoader(dataset, batch_size=batch_size, shuffle=False)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Define a simple transformer model (you'd use a full LLM implementation in practice)
class SimpleTransformer(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, nhead), num_layers
        )
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        x = self.embedding(x)
        x = self.transformer(x)
        return self.fc(x)

# Importance estimation function for width pruning
def estimate_importance(model, calibration_data):
    importances = {
        'heads': [],
        'neurons': [],
        'emb_channels': []
    }
    
    model.eval()
    with torch.no_grad():
        for batch in calibration_data:
            # Compute importances based on activations
            # This is a simplified version - you'd need to implement the full logic for each layer
            for layer in model.transformer.layers:
                # Compute head importance
                q, k, v = layer.self_attn.in_proj_weight.chunk(3, dim=0)
                qkv = torch.cat([q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)], dim=0)
                head_importance = torch.norm(qkv, p=2, dim=-1).sum(dim=-1)
                importances['heads'].append(head_importance)
                
                # Compute neuron importance
                neuron_importance = torch.norm(layer.linear1.weight, p=2, dim=1)
                importances['neurons'].append(neuron_importance)
            
            # Compute embedding channel importance
            emb_importance = torch.norm(model.embedding.weight, p=2, dim=0)
            importances['emb_channels'].append(emb_importance)

    # Aggregate importances
    for key in importances:
        importances[key] = torch.stack(importances[key]).sum(dim=0)

    return importances

# Pruning function
def prune_model(model, importances, prune_ratio):
    for name, param in model.named_parameters():
        if 'weight' in name:
            if 'self_attn.in_proj_weight' in name:
                # Prune attention heads
                num_heads = model.transformer.layers[0].self_attn.num_heads
                head_importance = importances['heads']
                k = int(num_heads * (1 - prune_ratio))
                top_heads = torch.topk(head_importance, k).indices
                mask = torch.zeros_like(param)
                for head in top_heads:
                    start = head * param.size(0) // num_heads
                    end = (head + 1) * param.size(0) // num_heads
                    mask[start:end] = 1
                param.data *= mask
            elif 'linear1.weight' in name:
                # Prune neurons
                neuron_importance = importances['neurons']
                k = int(param.size(0) * (1 - prune_ratio))
                top_neurons = torch.topk(neuron_importance, k).indices
                mask = torch.zeros_like(param)
                mask[top_neurons] = 1
                param.data *= mask
            elif 'embedding.weight' in name:
                # Prune embedding channels
                emb_importance = importances['emb_channels']
                k = int(param.size(1) * (1 - prune_ratio))
                top_channels = torch.topk(emb_importance, k).indices
                mask = torch.zeros_like(param)
                mask[:, top_channels] = 1
                param.data *= mask

# Knowledge distillation loss
def distillation_loss(student_logits, teacher_logits, temperature=1.0):
    return nn.KLDivLoss(reduction='batchmean')(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1)
    ) * (temperature ** 2)

# Main training loop
def train(teacher_model, student_model, train_data, val_data, epochs):
    optimizer = optim.Adam(student_model.parameters())
    teacher_model.eval()
    for epoch in range(epochs):
        student_model.train()
        for batch in train_data:
            optimizer.zero_grad()
            student_output = student_model(batch)
            with torch.no_grad():
                teacher_output = teacher_model(batch)
            # Compute distillation loss
            loss = distillation_loss(student_output, teacher_output)
            loss.backward()
            optimizer.step()
        
        # Evaluate on validation data
        validate(student_model, val_data)

# Validation function
def validate(model, val_data):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for batch in val_data:
            output = model(batch)
            # Compute validation loss (e.g., cross-entropy)
            loss = F.cross_entropy(output, batch['targets'])
            total_loss += loss.item()
    print(f"Validation Loss: {total_loss / len(val_data)}")

# Main process
def main():
    # Initialize teacher model (pretrained LLM)
    teacher_model = SimpleTransformer(vocab_size=30000, d_model=512, nhead=8, num_layers=6)
    
    # Create student model (copy of teacher)
    student_model = SimpleTransformer(vocab_size=30000, d_model=512, nhead=8, num_layers=6)
    student_model.load_state_dict(teacher_model.state_dict())
    
    # Load calibration data
    calibration_data = load_calibration_data()
    
    # Estimate importance
    importances = estimate_importance(teacher_model, calibration_data)
    
    # Prune student model
    prune_model(student_model, importances, prune_ratio=0.5)
    
    # Load training and validation data
    train_data = load_train_data()
    val_data = load_val_data()
    
    # Train student model
    train(teacher_model, student_model, train_data, val_data, epochs=10)

if __name__ == "__main__":
    main()



TypeError: load_calibration_data() missing 1 required positional argument: 'tokenizer'