# Weighted Mutual Learning (WML), an alternative to distillation ? 

### 1. Pruning function : 
This function does the following:

1. Calculates filter importance using the SNIP method.
2. Computes layer importance based on average filter importance.
3. Normalizes layer importance.
4. Determines the number of parameters to prune per layer.
5. Prunes filters in each layer based on their importance.

In [1]:
import torch
import torch.nn as nn
from tqdm import tqdm


In [2]:
def extended_snip_pruning_gpt2(model, X, Y, pruning_ratio, n_head, layer_bound=0.9):
    # Enable gradients for all parameters
    for param in model.parameters():
        param.requires_grad = True
    
    # Count non-zero parameters before pruning
    total_params_before = sum(p.numel() for p in model.parameters())
    non_zero_params_before = sum((p != 0).sum().item() for p in model.parameters())
    print(f"Total parameters before pruning: {total_params_before:,}")
    print(f"Non-zero parameters before pruning: {non_zero_params_before:,}")

    # Forward pass
    loss = model(X, Y)

    # Backward pass
    loss[1].backward()

    # Calculate importance for attention heads and feed-forward layers
    attention_importance = {}
    ffn_importance = {}
    
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            if "attn.c_attn" in name:
                importance = torch.sum(torch.abs(module.weight.grad * module.weight), dim=1)
                attention_importance[name] = importance.view(3, n_head, -1).sum(dim=(0, 2))
            elif "attn.c_proj" in name:
                importance = torch.sum(torch.abs(module.weight.grad * module.weight), dim=0)
                attention_importance[name] = importance.view(n_head, -1).sum(dim=1)
            elif "mlp.c_fc" in name:
                importance = torch.sum(torch.abs(module.weight.grad * module.weight), dim=1)
                ffn_importance[name] = importance
            elif "mlp.c_proj" in name:
                importance = torch.sum(torch.abs(module.weight.grad * module.weight), dim=0)
                ffn_importance[name] = importance

    # Calculate layer importance
    layer_importance = {}
    for name, importance in attention_importance.items():
        layer_name = name.rsplit(".", 2)[0]
        layer_importance[layer_name] = torch.mean(importance)
    
    for name, importance in ffn_importance.items():
        layer_name = name.rsplit(".", 2)[0]
        layer_importance[layer_name] = (layer_importance.get(layer_name, 0) + torch.mean(importance)) / 2

    # Normalize layer importance
    total_importance = sum(layer_importance.values())
    normalized_importance = {name: imp / total_importance for name, imp in layer_importance.items()}

    # Calculate number of elements to prune per layer
    total_elements = sum(p.numel() for p in model.parameters() if p.requires_grad)
    elements_to_prune = int(pruning_ratio * total_elements)
    layer_prune_elements = {name: int(imp * elements_to_prune) for name, imp in normalized_importance.items()}

    # Prune attention heads and feed-forward neurons
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            layer_name = name.rsplit(".", 2)[0]
            if "attn.c_attn" in name:
                head_importance = attention_importance[name]
                heads_to_keep = max(1, int((1 - layer_bound) * n_head))
                _, indices = torch.topk(head_importance, k=heads_to_keep, largest=True)
                mask = torch.zeros(n_head, device=module.weight.device)
                mask[indices] = 1
                mask = mask.repeat_interleave(module.weight.size(0) // n_head).unsqueeze(1).expand_as(module.weight)
                module.weight.data *= mask
                if module.bias is not None:
                    module.bias.data *= mask.squeeze()
            elif "attn.c_proj" in name:
                head_importance = attention_importance[name]
                heads_to_keep = max(1, int((1 - layer_bound) * n_head))
                _, indices = torch.topk(head_importance, k=heads_to_keep, largest=True)
                mask = torch.zeros(n_head, device=module.weight.device)
                mask[indices] = 1
                mask = mask.repeat_interleave(module.weight.size(1) // n_head).unsqueeze(0).expand_as(module.weight)
                module.weight.data *= mask
            elif "mlp.c_fc" in name:
                neuron_importance = ffn_importance[name]
                neurons_to_keep = max(1, int((1 - layer_bound) * module.out_features))
                _, indices = torch.topk(neuron_importance, k=neurons_to_keep, largest=True)
                mask = torch.zeros(module.out_features, device=module.weight.device)
                mask[indices] = 1
                module.weight.data *= mask.unsqueeze(1).expand_as(module.weight)
                if module.bias is not None:
                    module.bias.data *= mask

            elif "mlp.c_proj" in name:
                neuron_importance = ffn_importance[name]
                neurons_to_keep = max(1, int((1 - layer_bound) * module.in_features))
                _, indices = torch.topk(neuron_importance, k=neurons_to_keep, largest=True)
                mask = torch.zeros(module.in_features, device=module.weight.device)
                mask[indices] = 1
                module.weight.data *= mask.unsqueeze(0).expand_as(module.weight)

    # Reset gradients
    model.zero_grad()

    # Count non-zero parameters after pruning
    total_params_after = sum(p.numel() for p in model.parameters())
    non_zero_params_after = sum((p != 0).sum().item() for p in model.parameters())
    print(f"Total parameters after pruning: {total_params_after:,}")
    print(f"Non-zero parameters after pruning: {non_zero_params_after:,}")
    print(f"Effectively pruned {non_zero_params_before - non_zero_params_after:,} parameters")
    print(f"Effective pruning ratio: {(non_zero_params_before - non_zero_params_after) / non_zero_params_before:.2%}")


    return model

In [3]:
import torch
import torch.nn.functional as F

def wml_loss(outputs, labels, peer_outputs, weights, alpha):
    """
    Compute the Weighted Mutual Learning loss for GPT-2.
    
    :param outputs: Logits from the current peer model
    :param labels: True labels (input_ids for GPT-2)
    :param peer_outputs: List of logits from other peer models
    :param weights: Weights for each peer model
    :param alpha: Balancing factor between CE loss and KL divergence
    """
    # Cross-entropy loss
    ce_loss = F.cross_entropy(outputs.squeeze(1), labels)
    
    # KL divergence loss
    kl_loss = 0
    for i, peer_output in enumerate(peer_outputs):
        kl_loss += weights[i] * F.kl_div(
            F.log_softmax(outputs, dim=-1),
            F.softmax(peer_output, dim=-1),
            reduction='batchmean'
        )
    
    # Combine losses
    loss = (1 - alpha) * ce_loss + alpha * kl_loss
    return loss

In [4]:
def update_peer_weights(model, peer_models, val_loader, current_weights, learning_rate, device):
    model.eval()
    for peer in peer_models:
        peer.eval()
    
    gradients = torch.zeros_like(current_weights)
    
    for batch in tqdm(val_loader):
        inputs = batch[0][0].squeeze(1).to(device) #need to select logits have highest dimension (when MRL is enabled)
        labels = batch[1][0].squeeze(1).to(device)
        
        with torch.cuda.amp.autocast():
            ensemble_output = sum(w * peer(inputs)[0][0].detach() for w, peer in zip(current_weights, peer_models))
            loss = torch.nn.functional.cross_entropy(ensemble_output.view(-1, ensemble_output.size(-1)), labels.squeeze().reshape(labels.shape[1], -1))
            
            for i, peer in enumerate(peer_models):
                peer_output = peer(inputs)[0][0]
                peer_loss = torch.nn.functional.cross_entropy(peer_output.view(-1, peer_output.size(-1)), labels.squeeze().reshape(labels.shape[1], -1))
                grad = torch.autograd.grad(peer_loss, peer.parameters(), allow_unused=True)
                gradients[i] += sum(g.norm() if g is not None else 0 for g in grad)
    
    # Mirror descent update
    new_weights = current_weights * torch.exp(-learning_rate * gradients)
    new_weights /= new_weights.sum()  # Normalize
    
    return new_weights

In [5]:
import os
import numpy as np
import torch
from torch.utils.data import IterableDataset, DataLoader
import math

class BinaryFileDataset(IterableDataset):
    def __init__(self, file_path, block_size, model, device):
        self.file_path = file_path
        self.block_size = block_size
        self.device = device
        self.data = None
        self.pretrained_model = model

    def load_adjusted_memmap(self):
        with open(self.file_path, 'rb') as f:
            data_bytes = f.read()
        trimmed_length = (len(data_bytes) // 2) * 2
        trimmed_data = data_bytes[:trimmed_length]
        data = np.frombuffer(trimmed_data, dtype=np.uint16)
        return data

    def __iter__(self):
        self.data = self.load_adjusted_memmap()
        while True:
            ix = torch.randint(len(self.data) - self.block_size, (1,))
            x = torch.from_numpy(self.data[ix:ix+self.block_size].astype(np.int64))
            with torch.no_grad():
                self.pretrained_model.eval()
                y, _ = self.pretrained_model(x.unsqueeze(0).to(self.device))
            yield x, y
            
    def estimate_length(self):
        return len(self.load_adjusted_memmap()) - self.block_size + 1

def get_dataloader(split, model, device, batch_size=256, block_size=64):
    if split == 'train':
        file_path = os.path.join('/Users/krishnaiyer/generative-ai-research-babylm/data/processed/train_10M/processed_encoded_train.bin')
    else:
        file_path = os.path.join('/Users/krishnaiyer/generative-ai-research-babylm/data/processed/train_10M/processed_encoded_val.bin')
    dataset = BinaryFileDataset(file_path, block_size, model, device)
    dataloader = DataLoader(dataset, batch_size=batch_size)

    # Estimate the number of batches
    estimated_samples = dataset.estimate_length()
    estimated_batches = math.ceil(estimated_samples / batch_size)

    return dataloader,estimated_batches

In [6]:
import torch
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR
from torch.cuda.amp import autocast, GradScaler
import sys
from pathlib import Path
base_path = Path('.').resolve().parent
sys.path.append(str(base_path))
import babylm as blm
import hydra
from hydra import initialize, compose
from omegaconf import OmegaConf

def create_peer_model(base_model, args, prune_ratio):
    peer = base_model
    X, Y = blm.gpt_2.utils.get_batch(split='train',args=args)
    return extended_snip_pruning_gpt2(peer,X,Y,prune_ratio,args.train.n_head)

def main(args):
    # Hyperparameters
    num_peers = 2
    prune_ratios = [0.2, 0.4]
    alpha = 0.5
    num_epochs = 10
    learning_rate = 1e-4
    weight_update_frequency = 100  # Update weights every 100 steps
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    batch_size = 256  # Reduced batch size
    accumulation_steps = 4  # Gradient accumulation steps
    max_sequence_length = 16  # Reduced sequence length

    # Load base model and tokenizer
    checkpoint_path = "MRL_mean_loss_2000_ckpt.pt"
    vocab_size = blm.gpt_2.utils.get_vocab_size(args)
    base_model = blm.eval.utils.load_checkpoint(args,checkpoint_path,vocab_size)

    # Create peer models
    peer_models = [create_peer_model(base_model, args, ratio).to(device) for ratio in prune_ratios]

    # Initialize peer weights
    peer_weights = torch.ones(num_peers, device=device) / num_peers

    # You should replace this with your actual dataset
    train_dataloader,num_batches_train =  get_dataloader('train', base_model, device)
    val_dataloader,num_batches_val = get_dataloader('val', base_model, device)

    # Optimizer and scheduler
    optimizer = torch.optim.Adam(sum([list(model.parameters()) for model in peer_models], []), lr=learning_rate)
    scheduler = StepLR(optimizer, step_size=1, gamma=0.95)
    scaler = GradScaler()  # For mixed precision training

    # Training loop
    for epoch in range(num_epochs):
        # Training
        for model in peer_models:
            model.train()
        
        train_loss = 0
        optimizer.zero_grad()
        for step, batch in tqdm(enumerate(train_dataloader),desc=f"Running batches"):
            if args.MRL.enable:
                inputs = batch[0].to(device) 
                labels = batch[1][0].squeeze().reshape(batch_size, -1).to(device) #need to select logits have highest dimension (when MRL is enabled)
            else:
                inputs = batch[0].squeeze(1).to(device) 
                labels = batch[1].squeeze().reshape(batch_size, -1).to(device) 
            
            with autocast():
                # Forward pass for all peers
                if args.MRL.enable:
                    peer_outputs = [model(inputs)[0][0] for model in peer_models]
                else:
                    peer_outputs = [model(inputs)[0] for model in peer_models]
                
                # Compute loss for each peer
                losses = [wml_loss(outputs, labels, peer_outputs[:i] + peer_outputs[i+1:], 
                                   torch.cat([peer_weights[:i], peer_weights[i+1:]]), alpha) 
                          for i, outputs in enumerate(peer_outputs)]
                
                total_loss = sum(losses) / accumulation_steps

            # Backward pass
            for i, loss in enumerate(losses):
                if i == len(losses) - 1 and (step + 1) % accumulation_steps == 0:
                    scaler.scale(loss / accumulation_steps).backward()
                else:
                    scaler.scale(loss / accumulation_steps).backward(retain_graph=True)
            
            if (step + 1) % accumulation_steps == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            train_loss += total_loss.item()

            # Update peer weights
            if (step+1) % weight_update_frequency == 0:
                peer_weights = update_peer_weights(base_model, peer_models, val_dataloader, peer_weights, learning_rate, device)

            train_loss /= num_batches_train * num_peers

            # Update peer weights
            if (step+1) % weight_update_frequency == 0:
                peer_weights = update_peer_weights(base_model, peer_models, val_dataloader, peer_weights, learning_rate, device)

        train_loss /= num_batches_train * num_peers

        # Validation
        for model in peer_models:
            model.eval()
        
        val_loss = 0
        with torch.no_grad():
            for batch in val_dataloader:
                inputs = batch[0].to(device)
                labels = inputs.clone()
                if args.MRL.enable:
                    outputs = sum(w * model(inputs)[0][0] for w, model in zip(peer_weights, peer_models))
                else:
                    outputs = sum(w * model(inputs)[0] for w, model in zip(peer_weights, peer_models))
                val_loss += torch.nn.functional.cross_entropy(outputs.view(-1, outputs.size(-1)), labels.view(-1)).item()
        
        val_loss /= num_batches_val

        print(f"Epoch {epoch+1}/{num_epochs} completed. Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")
        
        # Step the learning rate scheduler
        scheduler.step()

    # Output final models and weights
    for i, (model, weight) in enumerate(zip(peer_models, peer_weights)):
        print(f"Peer {i+1} weight: {weight.item():.4f}")
        model.save_pretrained(f"/Users/krishnaiyer/generative-ai-research-babylm/models/WML/peer_model_{i+1}")

if __name__ == "__main__":
    # Initialize Hydra
    initialize(version_base=None, config_path="../conf")

    # Compose the configuration
    cfg = compose(config_name="blm-main.yaml")

    main(cfg)

  from .autonotebook import tqdm as notebook_tqdm


Total parameters before pruning: 196,121,096
Non-zero parameters before pruning: 196,121,096
Total parameters after pruning: 196,121,096
Non-zero parameters after pruning: 172,055,048
Effectively pruned 24,066,048 parameters
Effective pruning ratio: 12.27%
Total parameters before pruning: 196,121,096
Non-zero parameters before pruning: 172,055,048


  return torch._dynamo.disable(fn, recursive)(*args, **kwargs)


Total parameters after pruning: 196,121,096
Non-zero parameters after pruning: 169,502,984
Effectively pruned 2,552,064 parameters
Effective pruning ratio: 1.48%


Running batches: 3it [00:21,  7.02s/it]


KeyboardInterrupt: 

In [22]:
import torch
import torch.nn as nn
from torch.nn.utils import prune
import copy
import logging
logger = logging.getLogger(__name__)
import sys
from pathlib import Path
base_path = Path('.').resolve().parent
sys.path.append(str(base_path))
import babylm as blm


def prune_gpt_model(base_model, amount=0.3, importance='l1'):
    model = copy.deepcopy(base_model)

    for module in model.modules():
        print(module)

    parameters_to_prune = []
    for name, module in model.named_modules():
        if isinstance(module, blm.gpt_2.attention.CausalSelfAttention):
            parameters_to_prune.extend([
                (module.c_attn, 'weight'),
                (module.c_proj, 'weight')
            ])
        elif isinstance(module, blm.gpt_2.elements.MLP):
            parameters_to_prune.extend([
                (module.c_fc, 'weight'),
                (module.c_proj, 'weight')
            ])
    
    print(parameters_to_prune)
    # Count total parameters to be pruned
    total_params_to_prune = sum(p.numel() for module, _ in parameters_to_prune for p in [getattr(module, _)])
    print(f"total params to prune {total_params_to_prune}")
    
    # Count total model parameters
    total_model_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"total_model_params {total_model_params}")

    # Adjust pruning amount to achieve desired overall sparsity
    adjusted_amount = (amount * total_model_params) / total_params_to_prune
    print(f"adjusted_amount {adjusted_amount}")
    
    # Select pruning method
    if importance == 'l1':
        prune_method = prune.L1Unstructured
    elif importance == 'random':
        prune_method = prune.RandomUnstructured
    else:
        raise ValueError("Unsupported importance method. Choose 'l1' or 'random'.")

    # Apply global unstructured pruning
    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune_method,
        amount=amount
    )

    # Make the pruning permanent
    for module, _ in parameters_to_prune:
        prune.remove(module, 'weight')

    return model

def print_gpt_sparsity(model):
    total_params = 0
    zero_params = 0

    for name, module in model.named_modules():
        if isinstance(module, (blm.gpt_2.attention.CausalSelfAttention, blm.gpt_2.elements.MLP)):
            for param_name, param in module.named_parameters():
                if 'weight' in param_name:
                    layer_total = param.nelement()
                    layer_zero = torch.sum(param == 0).item()
                    layer_sparsity = 100.0 * layer_zero / layer_total
                    print(f"{name}.{param_name}: {layer_sparsity:.2f}% sparsity")
                    total_params += layer_total
                    zero_params += layer_zero
    overall_sparsity = 100.0 * zero_params / total_params
    print(f"Overall model sparsity: {overall_sparsity:.2f}%")


In [2]:

from hydra import initialize, compose

def main(args):
    # Load base model and tokenizer
    checkpoint_path = "GPT2_MRL_500_ckpt.pt"
    vocab_size = blm.gpt_2.utils.get_vocab_size(args)
    base_model = blm.eval.utils.load_checkpoint(args,checkpoint_path,vocab_size)
    
    #pruned_gpt_model = prune_gpt_model(base_model, amount=0.5, importance='l1')
    #print_gpt_sparsity(pruned_gpt_model)    
    return base_model

if __name__ == "__main__":
    # Initialize Hydra
    initialize(version_base=None, config_path="../conf")

    # Compose the configuration
    cfg = compose(config_name="blm-main.yaml")

    base_model = main(cfg)



In [34]:
layer = base_model.transformer.h[0]        

In [44]:
layer.attn.n_embd

768

### New method for pruning 

In [33]:
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import copy

def prune_gpt_model(base_model, layers, num_heads, prune_ratio, importance='l1'):
    """
    Prune a GPT-style PyTorch model, focusing on attention heads and MLP layers.
    
    Args:
    - base_model (nn.Module): The GPT model to be pruned.
    - layers (list): List of layer indices to prune.
    - num_heads (int): Number of attention heads to select for pruning.
    - prune_ratio (float): The ratio of parameters to prune (0.0 to 1.0).
    - importance (str): The importance measure for pruning ('l1' or 'random').
    
    Returns:
    - pruned_model (nn.Module): The pruned model.
    """
    model = copy.deepcopy(base_model)
    pruned_modules = set()
    
    if importance == 'l1':
        prune_method = prune.L1Unstructured
    elif importance == 'random':
        prune_method = prune.RandomUnstructured
    else:
        raise ValueError("Unsupported importance method. Choose 'l1' or 'random'.")
    
    for layer_idx in layers:
        layer = model.transformer.h[layer_idx]
        
        # Prune attention heads
        attn = layer.attn
        
        # Prune query, key, value projections
        prune.ln_structured(attn.c_attn, name='weight', amount=prune_ratio, 
                            n=1, dim=0)  # Prune across Q, K, V
        pruned_modules.add(attn.c_attn)
        
        # Prune output projection
        prune.ln_structured(attn.c_proj, name='weight', amount=prune_ratio, 
                            n=1, dim=1)
        pruned_modules.add(attn.c_proj)
        
        # Prune MLP
        mlp = layer.mlp
        prune.l1_unstructured(mlp.c_fc, name='weight', amount=prune_ratio)
        prune.l1_unstructured(mlp.c_proj, name='weight', amount=prune_ratio)
        pruned_modules.add(mlp.c_fc)
        pruned_modules.add(mlp.c_proj)
    
    # Make pruning permanent
    for module in pruned_modules:
        prune.remove(module, 'weight')
    
    return model

# Usage example:
# model = YourGPTModel()
# pruned_model = prune_gpt_model(model, layers=[0, 1, 2], num_heads=4, prune_ratio=0.3)

In [31]:
pruned_model = prune_gpt_model(base_model, layers=[0, 1, 2, 3], num_heads=4, prune_ratio=0.3)

### pruning of specific heads using custom mask

In [96]:
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import copy

def prune_gpt_model(base_model, config, importance='l1'):
    """
    Prune a GPT-style PyTorch model, focusing on top important attention heads and MLP layers.
    
    Args:
    - base_model (nn.Module): The GPT model to be pruned.
    - layers (list): List of layer indices to prune.
    - num_heads (int): Number of top important attention heads to prune.
    - prune_ratio (float): The ratio of parameters to prune (0.0 to 1.0).
    - importance (str): The importance measure for pruning ('l1' or 'l2').
    
    Returns:
    - pruned_model (nn.Module): The pruned model.
    """
    layers = [item[0] for item in config]       
    num_heads = [item[1] for item in config]    
    pruning_ratios = [item[2] for item in config]
    
    model = copy.deepcopy(base_model)
    pruned_modules = set()

    for layer_idx, num_head, prune_ratio in zip(layers,num_heads,pruning_ratios):
        layer = model.transformer.h[layer_idx]
        attn = layer.attn
        
        # Calculate head importance
        weight = attn.c_attn.weight
        head_size = int(attn.n_embd / attn.n_head)
        num_heads_total = attn.n_head
        head_importance = []
        for i in range(num_heads_total):
            start = i * head_size
            end = (i + 1) * head_size
            head_weights = weight[start:end, :]
            if importance == 'l1':
                head_imp = torch.norm(head_weights, p=1)
            elif importance == 'l2':
                head_imp = torch.norm(head_weights, p=2)
            else:
                raise ValueError("Unsupported importance method. Choose 'l1' or 'l2'.")
            head_importance.append((i, head_imp.item()))
        
        # Sort heads by importance (descending) and select top num_heads
        head_importance.sort(key=lambda x: x[1], reverse=True)
        heads_to_prune = head_importance[:num_head]
        
        # Create pruning mask for attention
        attn_mask = torch.ones_like(weight)
        for head_idx, _ in heads_to_prune:
            start = head_idx * head_size
            end = (head_idx + 1) * head_size
            attn_mask[start:end, :] = 0
        
        # Apply pruning to attention
        prune.custom_from_mask(attn.c_attn, name='weight', mask=attn_mask)
        pruned_modules.add(attn.c_attn)
        
        # Prune corresponding parts in the output projection
        proj_mask = torch.ones_like(attn.c_proj.weight)
        for head_idx, _ in heads_to_prune:
            start = head_idx * head_size
            end = (head_idx + 1) * head_size
            proj_mask[:, start:end] = 0
        prune.custom_from_mask(attn.c_proj, name='weight', mask=proj_mask)
        pruned_modules.add(attn.c_proj)
        
        # Prune MLP
        mlp = layer.mlp
        prune.l1_unstructured(mlp.c_fc, name='weight', amount=prune_ratio)
        prune.l1_unstructured(mlp.c_proj, name='weight', amount=prune_ratio)
        pruned_modules.add(mlp.c_fc)
        pruned_modules.add(mlp.c_proj)
    
    # Make pruning permanent
    for module in pruned_modules:
        prune.remove(module, 'weight')
    
    return model



### peer model generator using bayesian opt

In [66]:
!pip install bayesian-optimization

Collecting bayesian-optimization
  Downloading bayesian_optimization-1.5.1-py3-none-any.whl.metadata (16 kB)
Downloading bayesian_optimization-1.5.1-py3-none-any.whl (28 kB)
Installing collected packages: bayesian-optimization
Successfully installed bayesian-optimization-1.5.1


In [72]:
import random
import math
from typing import List, Tuple
import torch
import torch.nn as nn
from bayes_opt import BayesianOptimization
import itertools

In [110]:
class TreeNode:
    def __init__(self, layer_idx: int, num_heads: int):
        self.layer_idx = layer_idx
        self.num_heads = num_heads
        self.left = None
        self.right = None

In [116]:
#create a binary tree
def create_pruning_tree(num_layers: int, base_heads: int, max_depth: int = 6) -> TreeNode:
    def create_node(layer_idx: int, num_heads: int, depth: int) -> TreeNode:
        if layer_idx > num_layers or depth > max_depth:
            return None
        
        node = TreeNode(layer_idx, num_heads)
        
        if layer_idx + 1 < num_layers:
            node.left = create_node(layer_idx + 1, num_heads, depth + 1)
            node.right = create_node(layer_idx + 1, max(1, num_heads // 2), depth + 1)
        
        return node
    
    return create_node(0, base_heads, 0)

In [135]:
!pip install graphviz

Collecting graphviz
  Downloading graphviz-0.20.3-py3-none-any.whl.metadata (12 kB)
Downloading graphviz-0.20.3-py3-none-any.whl (47 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m47.1/47.1 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: graphviz
Successfully installed graphviz-0.20.3


In [136]:
from graphviz import Digraph

In [140]:
def visualize_pruning_tree(root: TreeNode, filename: str = "pruning_tree.png"):
    dot = Digraph(comment='Pruning Tree')
    dot.attr(rankdir='TB', size='8,8')
    
    def add_nodes_edges(node: TreeNode):
        if node:
            node_id = f"{node.layer_idx}_{node.num_heads}"
            dot.node(node_id, f"Layer {node.layer_idx}\nHeads {node.num_heads}")
            if node.left:
                left_id = f"{node.left.layer_idx}_{node.left.num_heads}"
                dot.edge(node_id, left_id, 'left')
                add_nodes_edges(node.left)
            if node.right:
                right_id = f"{node.right.layer_idx}_{node.right.num_heads}"
                dot.edge(node_id, right_id, 'right')
                add_nodes_edges(node.right)
    
    add_nodes_edges(root)
    dot.render(filename, view=True, format='png')

In [141]:
pruning_tree = create_pruning_tree(4, 4)
visualize_pruning_tree(pruning_tree)


NotADirectoryError: [Errno 20] Not a directory: PosixPath('dot')

In [122]:
#traverse the binary tree and extract the config
def generate_peer_model_config(root: TreeNode, prune_ratio: List) -> List[Tuple[int, int, float]]:
    config = []
    node = root
    layer_idx = 0
    while node and layer_idx < len(prune_ratio):
        config.append((node.layer_idx, node.num_heads, prune_ratio[layer_idx]))
        node = node.left if random.random() < 0.5 else node.right
        layer_idx += 1
    return config

### Bayesian optimisation on normalised max difference of sparsities between peer models 

In [112]:
def compute_sparsity(pruned_model):
    total_params = 0
    zero_params = 0
    for name, module in pruned_model.named_modules():
        if isinstance(module, (blm.gpt_2.attention.CausalSelfAttention, blm.gpt_2.elements.MLP)):
            for param_name, param in module.named_parameters():
                if 'weight' in param_name:
                    layer_total = param.nelement()
                    layer_zero = torch.sum(param == 0).item()
                    total_params += layer_total
                    zero_params += layer_zero
    return zero_params / total_params

In [131]:
def optimize_peer_models(base_model: nn.Module, num_peers: int, num_layers: int, base_heads: int) -> List[List[Tuple[int, int, float]]]:
    pruning_tree = create_pruning_tree(num_layers, base_heads)
    
    def objective(**kwargs):
        all_prune_ratios = [[kwargs[f'peer_{p}_layer_{l}'] for l in range(num_layers)] for p in range(num_peers)]
        configs = [generate_peer_model_config(pruning_tree, prune_ratios) for prune_ratios in all_prune_ratios]
        sparsities = []
        for config in configs:
            pruned_model = prune_gpt_model(base_model,config)
            sparsities.append(compute_sparsity(pruned_model))
        
        # Calculate the sum of absolute differences between all pairs
        diff_sum = sum(abs(s1 - s2) for s1, s2 in itertools.combinations(sparsities, 2))
        
        # Normalize by the number of pairs
        normalized_diff = (2 / (num_peers * (num_peers - 1))) * diff_sum if num_peers > 1 else 0
        
        # Include average sparsity in the objective
        avg_sparsity = sum(sparsities) / num_peers
        
        # Combine normalized difference and average sparsity
        return normalized_diff
    
    pbounds = {f'peer_{p}_layer_{l}': (0.1, 0.5) for p in range(num_peers) for l in range(num_layers)}
    
    optimizer = BayesianOptimization(
        f=objective,
        pbounds=pbounds,  
        random_state=1,
    )
    
    optimizer.maximize(init_points=10, n_iter=100)
    
    best_prune_ratios = [[optimizer.max['params'][f'peer_{p}_layer_{l}'] for l in range(num_layers)] for p in range(num_peers)]
    best_configs = [generate_peer_model_config(pruning_tree, prune_ratios) for prune_ratios in best_prune_ratios]
    return best_configs

In [132]:
best_configs = optimize_peer_models(base_model, num_peers=3, num_layers=4, base_heads=4)

|   iter    |  target   | peer_0... | peer_0... | peer_0... | peer_0... | peer_1... | peer_1... | peer_1... | peer_1... | peer_2... | peer_2... | peer_2... | peer_2... |
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------
| [39m1        [39m | [39m0.07029  [39m | [39m0.2668   [39m | [39m0.3881   [39m | [39m0.1      [39m | [39m0.2209   [39m | [39m0.1587   [39m | [39m0.1369   [39m | [39m0.1745   [39m | [39m0.2382   [39m | [39m0.2587   [39m | [39m0.3155   [39m | [39m0.2677   [39m | [39m0.3741   [39m |
| [39m2        [39m | [39m0.0649   [39m | [39m0.1818   [39m | [39m0.4512   [39m | [39m0.111    [39m | [39m0.3682   [39m | [39m0.2669   [39m | [39m0.3235   [39m | [39m0.1562   [39m | [39m0.1792   [39m | [39m0.4203   [39m | [39m0.4873   [39m | [39m0.2254   [39m | [39m0.3769   [39m |
| [39m3        [39m | [39m0.0271   [

In [129]:
def print_gpt_sparsity(model):
    total_params = 0
    zero_params = 0

    for name, module in model.named_modules():
        if isinstance(module, (blm.gpt_2.attention.CausalSelfAttention, blm.gpt_2.elements.MLP)):
            for param_name, param in module.named_parameters():
                if 'weight' in param_name:
                    layer_total = param.nelement()
                    layer_zero = torch.sum(param == 0).item()
                    layer_sparsity = 100.0 * layer_zero / layer_total
                    print(f"{name}.{param_name}: {layer_sparsity:.2f}% sparsity")
                    total_params += layer_total
                    zero_params += layer_zero
    overall_sparsity = 100.0 * zero_params / total_params
    print(f"Overall model sparsity: {overall_sparsity:.2f}%")

In [133]:
for i, config in enumerate(best_configs):
    print(f"pruned model {i+1}")
    pruned_model = prune_gpt_model(base_model,config)
    print_gpt_sparsity(pruned_model)

pruned model 1
transformer.h.0.attn.c_attn.weight: 33.33% sparsity
transformer.h.0.attn.c_proj.weight: 100.00% sparsity
transformer.h.0.mlp.c_fc.weight: 49.40% sparsity
transformer.h.0.mlp.c_proj.weight: 49.40% sparsity
transformer.h.1.attn.c_attn.weight: 33.33% sparsity
transformer.h.1.attn.c_proj.weight: 100.00% sparsity
transformer.h.1.mlp.c_fc.weight: 19.67% sparsity
transformer.h.1.mlp.c_proj.weight: 19.67% sparsity
transformer.h.2.attn.c_attn.weight: 16.67% sparsity
transformer.h.2.attn.c_proj.weight: 50.00% sparsity
transformer.h.2.mlp.c_fc.weight: 38.89% sparsity
transformer.h.2.mlp.c_proj.weight: 38.89% sparsity
transformer.h.3.attn.c_attn.weight: 8.33% sparsity
transformer.h.3.attn.c_proj.weight: 25.00% sparsity
transformer.h.3.mlp.c_fc.weight: 44.60% sparsity
transformer.h.3.mlp.c_proj.weight: 44.60% sparsity
Overall model sparsity: 36.88%
pruned model 2
transformer.h.0.attn.c_attn.weight: 33.33% sparsity
transformer.h.0.attn.c_proj.weight: 100.00% sparsity
transformer.h.0.m