# Weighted Mutual Learning (WML), an alternative to distillation ? 

### 1. Pruning function : 
This function does the following:

1. Calculates filter importance using the SNIP method.
2. Computes layer importance based on average filter importance.
3. Normalizes layer importance.
4. Determines the number of parameters to prune per layer.
5. Prunes filters in each layer based on their importance.

In [1]:
import torch
import torch.nn as nn
from tqdm import tqdm


In [2]:
def extended_snip_pruning_gpt2(model, X, Y, pruning_ratio, n_head, layer_bound=0.9):
    # Enable gradients for all parameters
    for param in model.parameters():
        param.requires_grad = True
    
    # Count non-zero parameters before pruning
    total_params_before = sum(p.numel() for p in model.parameters())
    non_zero_params_before = sum((p != 0).sum().item() for p in model.parameters())
    print(f"Total parameters before pruning: {total_params_before:,}")
    print(f"Non-zero parameters before pruning: {non_zero_params_before:,}")

    # Forward pass
    loss = model(X, Y)

    # Backward pass
    loss[1].backward()

    # Calculate importance for attention heads and feed-forward layers
    attention_importance = {}
    ffn_importance = {}
    
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            if "attn.c_attn" in name:
                importance = torch.sum(torch.abs(module.weight.grad * module.weight), dim=1)
                attention_importance[name] = importance.view(3, n_head, -1).sum(dim=(0, 2))
            elif "attn.c_proj" in name:
                importance = torch.sum(torch.abs(module.weight.grad * module.weight), dim=0)
                attention_importance[name] = importance.view(n_head, -1).sum(dim=1)
            elif "mlp.c_fc" in name:
                importance = torch.sum(torch.abs(module.weight.grad * module.weight), dim=1)
                ffn_importance[name] = importance
            elif "mlp.c_proj" in name:
                importance = torch.sum(torch.abs(module.weight.grad * module.weight), dim=0)
                ffn_importance[name] = importance

    # Calculate layer importance
    layer_importance = {}
    for name, importance in attention_importance.items():
        layer_name = name.rsplit(".", 2)[0]
        layer_importance[layer_name] = torch.mean(importance)
    
    for name, importance in ffn_importance.items():
        layer_name = name.rsplit(".", 2)[0]
        layer_importance[layer_name] = (layer_importance.get(layer_name, 0) + torch.mean(importance)) / 2

    # Normalize layer importance
    total_importance = sum(layer_importance.values())
    normalized_importance = {name: imp / total_importance for name, imp in layer_importance.items()}

    # Calculate number of elements to prune per layer
    total_elements = sum(p.numel() for p in model.parameters() if p.requires_grad)
    elements_to_prune = int(pruning_ratio * total_elements)
    layer_prune_elements = {name: int(imp * elements_to_prune) for name, imp in normalized_importance.items()}

    # Prune attention heads and feed-forward neurons
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            layer_name = name.rsplit(".", 2)[0]
            if "attn.c_attn" in name:
                head_importance = attention_importance[name]
                heads_to_keep = max(1, int((1 - layer_bound) * n_head))
                _, indices = torch.topk(head_importance, k=heads_to_keep, largest=True)
                mask = torch.zeros(n_head, device=module.weight.device)
                mask[indices] = 1
                mask = mask.repeat_interleave(module.weight.size(0) // n_head).unsqueeze(1).expand_as(module.weight)
                module.weight.data *= mask
                if module.bias is not None:
                    module.bias.data *= mask.squeeze()
            elif "attn.c_proj" in name:
                head_importance = attention_importance[name]
                heads_to_keep = max(1, int((1 - layer_bound) * n_head))
                _, indices = torch.topk(head_importance, k=heads_to_keep, largest=True)
                mask = torch.zeros(n_head, device=module.weight.device)
                mask[indices] = 1
                mask = mask.repeat_interleave(module.weight.size(1) // n_head).unsqueeze(0).expand_as(module.weight)
                module.weight.data *= mask
            elif "mlp.c_fc" in name:
                neuron_importance = ffn_importance[name]
                neurons_to_keep = max(1, int((1 - layer_bound) * module.out_features))
                _, indices = torch.topk(neuron_importance, k=neurons_to_keep, largest=True)
                mask = torch.zeros(module.out_features, device=module.weight.device)
                mask[indices] = 1
                module.weight.data *= mask.unsqueeze(1).expand_as(module.weight)
                if module.bias is not None:
                    module.bias.data *= mask

            elif "mlp.c_proj" in name:
                neuron_importance = ffn_importance[name]
                neurons_to_keep = max(1, int((1 - layer_bound) * module.in_features))
                _, indices = torch.topk(neuron_importance, k=neurons_to_keep, largest=True)
                mask = torch.zeros(module.in_features, device=module.weight.device)
                mask[indices] = 1
                module.weight.data *= mask.unsqueeze(0).expand_as(module.weight)

    # Reset gradients
    model.zero_grad()

    # Count non-zero parameters after pruning
    total_params_after = sum(p.numel() for p in model.parameters())
    non_zero_params_after = sum((p != 0).sum().item() for p in model.parameters())
    print(f"Total parameters after pruning: {total_params_after:,}")
    print(f"Non-zero parameters after pruning: {non_zero_params_after:,}")
    print(f"Effectively pruned {non_zero_params_before - non_zero_params_after:,} parameters")
    print(f"Effective pruning ratio: {(non_zero_params_before - non_zero_params_after) / non_zero_params_before:.2%}")


    return model

In [3]:
import torch
import torch.nn.functional as F

def wml_loss(outputs, labels, peer_outputs, weights, alpha):
    """
    Compute the Weighted Mutual Learning loss for GPT-2.
    
    :param outputs: Logits from the current peer model
    :param labels: True labels (input_ids for GPT-2)
    :param peer_outputs: List of logits from other peer models
    :param weights: Weights for each peer model
    :param alpha: Balancing factor between CE loss and KL divergence
    """
    # Cross-entropy loss
    ce_loss = F.cross_entropy(outputs.squeeze(1), labels)
    
    # KL divergence loss
    kl_loss = 0
    for i, peer_output in enumerate(peer_outputs):
        kl_loss += weights[i] * F.kl_div(
            F.log_softmax(outputs, dim=-1),
            F.softmax(peer_output, dim=-1),
            reduction='batchmean'
        )
    
    # Combine losses
    loss = (1 - alpha) * ce_loss + alpha * kl_loss
    return loss

In [4]:
def update_peer_weights(model, peer_models, val_loader, current_weights, learning_rate, device):
    model.eval()
    for peer in peer_models:
        peer.eval()
    
    gradients = torch.zeros_like(current_weights)
    
    for batch in tqdm(val_loader):
        inputs = batch[0][0].squeeze(1).to(device) #need to select logits have highest dimension (when MRL is enabled)
        labels = batch[1][0].squeeze(1).to(device)
        
        with torch.cuda.amp.autocast():
            ensemble_output = sum(w * peer(inputs)[0][0].detach() for w, peer in zip(current_weights, peer_models))
            loss = torch.nn.functional.cross_entropy(ensemble_output.view(-1, ensemble_output.size(-1)), labels.squeeze().reshape(labels.shape[1], -1))
            
            for i, peer in enumerate(peer_models):
                peer_output = peer(inputs)[0][0]
                peer_loss = torch.nn.functional.cross_entropy(peer_output.view(-1, peer_output.size(-1)), labels.squeeze().reshape(labels.shape[1], -1))
                grad = torch.autograd.grad(peer_loss, peer.parameters(), allow_unused=True)
                gradients[i] += sum(g.norm() if g is not None else 0 for g in grad)
    
    # Mirror descent update
    new_weights = current_weights * torch.exp(-learning_rate * gradients)
    new_weights /= new_weights.sum()  # Normalize
    
    return new_weights

In [5]:
import os
import numpy as np
import torch
from torch.utils.data import IterableDataset, DataLoader
import math

class BinaryFileDataset(IterableDataset):
    def __init__(self, file_path, block_size, model, device):
        self.file_path = file_path
        self.block_size = block_size
        self.device = device
        self.data = None
        self.pretrained_model = model

    def load_adjusted_memmap(self):
        with open(self.file_path, 'rb') as f:
            data_bytes = f.read()
        trimmed_length = (len(data_bytes) // 2) * 2
        trimmed_data = data_bytes[:trimmed_length]
        data = np.frombuffer(trimmed_data, dtype=np.uint16)
        return data

    def __iter__(self):
        self.data = self.load_adjusted_memmap()
        while True:
            ix = torch.randint(len(self.data) - self.block_size, (1,))
            x = torch.from_numpy(self.data[ix:ix+self.block_size].astype(np.int64))
            with torch.no_grad():
                self.pretrained_model.eval()
                y, _ = self.pretrained_model(x.unsqueeze(0).to(self.device))
            yield x, y
            
    def estimate_length(self):
        return len(self.load_adjusted_memmap()) - self.block_size + 1

def get_dataloader(split, model, device, batch_size=256, block_size=64):
    if split == 'train':
        file_path = os.path.join('/Users/krishnaiyer/generative-ai-research-babylm/data/processed/train_10M/processed_encoded_train.bin')
    else:
        file_path = os.path.join('/Users/krishnaiyer/generative-ai-research-babylm/data/processed/train_10M/processed_encoded_val.bin')
    dataset = BinaryFileDataset(file_path, block_size, model, device)
    dataloader = DataLoader(dataset, batch_size=batch_size)

    # Estimate the number of batches
    estimated_samples = dataset.estimate_length()
    estimated_batches = math.ceil(estimated_samples / batch_size)

    return dataloader,estimated_batches

In [6]:
import torch
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR
from torch.cuda.amp import autocast, GradScaler
import sys
from pathlib import Path
base_path = Path('.').resolve().parent
sys.path.append(str(base_path))
import babylm as blm
import hydra
from hydra import initialize, compose
from omegaconf import OmegaConf

def create_peer_model(base_model, args, prune_ratio):
    peer = base_model
    X, Y = blm.gpt_2.utils.get_batch(split='train',args=args)
    return extended_snip_pruning_gpt2(peer,X,Y,prune_ratio,args.train.n_head)

def main(args):
    # Hyperparameters
    num_peers = 2
    prune_ratios = [0.2, 0.4]
    alpha = 0.5
    num_epochs = 10
    learning_rate = 1e-4
    weight_update_frequency = 100  # Update weights every 100 steps
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    batch_size = 256  # Reduced batch size
    accumulation_steps = 4  # Gradient accumulation steps
    max_sequence_length = 16  # Reduced sequence length

    # Load base model and tokenizer
    checkpoint_path = "MRL_mean_loss_2000_ckpt.pt"
    vocab_size = blm.gpt_2.utils.get_vocab_size(args)
    base_model = blm.eval.utils.load_checkpoint(args,checkpoint_path,vocab_size)

    # Create peer models
    peer_models = [create_peer_model(base_model, args, ratio).to(device) for ratio in prune_ratios]

    # Initialize peer weights
    peer_weights = torch.ones(num_peers, device=device) / num_peers

    # You should replace this with your actual dataset
    train_dataloader,num_batches_train =  get_dataloader('train', base_model, device)
    val_dataloader,num_batches_val = get_dataloader('val', base_model, device)

    # Optimizer and scheduler
    optimizer = torch.optim.Adam(sum([list(model.parameters()) for model in peer_models], []), lr=learning_rate)
    scheduler = StepLR(optimizer, step_size=1, gamma=0.95)
    scaler = GradScaler()  # For mixed precision training

    # Training loop
    for epoch in range(num_epochs):
        # Training
        for model in peer_models:
            model.train()
        
        train_loss = 0
        optimizer.zero_grad()
        for step, batch in tqdm(enumerate(train_dataloader),desc=f"Running batches"):
            if args.MRL.enable:
                inputs = batch[0].to(device) 
                labels = batch[1][0].squeeze().reshape(batch_size, -1).to(device) #need to select logits have highest dimension (when MRL is enabled)
            else:
                inputs = batch[0].squeeze(1).to(device) 
                labels = batch[1].squeeze().reshape(batch_size, -1).to(device) 
            
            with autocast():
                # Forward pass for all peers
                if args.MRL.enable:
                    peer_outputs = [model(inputs)[0][0] for model in peer_models]
                else:
                    peer_outputs = [model(inputs)[0] for model in peer_models]
                
                # Compute loss for each peer
                losses = [wml_loss(outputs, labels, peer_outputs[:i] + peer_outputs[i+1:], 
                                   torch.cat([peer_weights[:i], peer_weights[i+1:]]), alpha) 
                          for i, outputs in enumerate(peer_outputs)]
                
                total_loss = sum(losses) / accumulation_steps

            # Backward pass
            for i, loss in enumerate(losses):
                if i == len(losses) - 1 and (step + 1) % accumulation_steps == 0:
                    scaler.scale(loss / accumulation_steps).backward()
                else:
                    scaler.scale(loss / accumulation_steps).backward(retain_graph=True)
            
            if (step + 1) % accumulation_steps == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            train_loss += total_loss.item()

            # Update peer weights
            if (step+1) % weight_update_frequency == 0:
                peer_weights = update_peer_weights(base_model, peer_models, val_dataloader, peer_weights, learning_rate, device)

            train_loss /= num_batches_train * num_peers

            # Update peer weights
            if (step+1) % weight_update_frequency == 0:
                peer_weights = update_peer_weights(base_model, peer_models, val_dataloader, peer_weights, learning_rate, device)

        train_loss /= num_batches_train * num_peers

        # Validation
        for model in peer_models:
            model.eval()
        
        val_loss = 0
        with torch.no_grad():
            for batch in val_dataloader:
                inputs = batch[0].to(device)
                labels = inputs.clone()
                if args.MRL.enable:
                    outputs = sum(w * model(inputs)[0][0] for w, model in zip(peer_weights, peer_models))
                else:
                    outputs = sum(w * model(inputs)[0] for w, model in zip(peer_weights, peer_models))
                val_loss += torch.nn.functional.cross_entropy(outputs.view(-1, outputs.size(-1)), labels.view(-1)).item()
        
        val_loss /= num_batches_val

        print(f"Epoch {epoch+1}/{num_epochs} completed. Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")
        
        # Step the learning rate scheduler
        scheduler.step()

    # Output final models and weights
    for i, (model, weight) in enumerate(zip(peer_models, peer_weights)):
        print(f"Peer {i+1} weight: {weight.item():.4f}")
        model.save_pretrained(f"/Users/krishnaiyer/generative-ai-research-babylm/models/WML/peer_model_{i+1}")

if __name__ == "__main__":
    # Initialize Hydra
    initialize(version_base=None, config_path="../conf")

    # Compose the configuration
    cfg = compose(config_name="blm-main.yaml")

    main(cfg)

  from .autonotebook import tqdm as notebook_tqdm


Total parameters before pruning: 196,121,096
Non-zero parameters before pruning: 196,121,096
Total parameters after pruning: 196,121,096
Non-zero parameters after pruning: 172,055,048
Effectively pruned 24,066,048 parameters
Effective pruning ratio: 12.27%
Total parameters before pruning: 196,121,096
Non-zero parameters before pruning: 172,055,048


  return torch._dynamo.disable(fn, recursive)(*args, **kwargs)


Total parameters after pruning: 196,121,096
Non-zero parameters after pruning: 169,502,984
Effectively pruned 2,552,064 parameters
Effective pruning ratio: 1.48%


Running batches: 3it [00:21,  7.02s/it]


KeyboardInterrupt: 