### E-LlaMA-13B Expert Creation
### Binary Mask

In [None]:
# Dependencies
#!pip install --upgrade pip
#!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121  # For CUDA 12.1
#!pip install transformers accelerate safetensors bitsandbytes xformers
#!pip install scipy sentencepiece
#!pip install ipython rich matplotlib pandas tqdm
#sudo apt-get install gcsfuse
#sudo apt-get update
#sudo apt-get install fuse
#sudo modprobe fuse

In [1]:
# load dependencies
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
import json
import torch
import time
from tqdm.notebook import tqdm
import torch.nn as nn
import numpy as np
from safetensors.torch import save_file

In [2]:
# Set model path. Override by exporting MELLAMA_MODEL_PATH before starting the
# kernel; otherwise the original /mnt/models location is used.
model_path = os.environ.get("MELLAMA_MODEL_PATH", "/mnt/models/MeLLaMA-13B")

In [None]:
# Print the installed CUDA compiler (nvcc) version - should be 12.4
!nvcc --version

In [None]:
# Activate MoeME env

# Mount SSD to VM
!sudo ln -s /mnt/models ~/models

# Ensure models folder is visible in explorer
!sudo ln -s /mnt/models ~/models

In [None]:
# Get GPU info: list the GPUs that nvidia-smi can see
!nvidia-smi -L

In [None]:
# Install Torch for CUDA 12.4
#!pip3 install torch torchvision torchaudio

In [None]:
# Load the baseline MeLLaMA-13B checkpoint (the model that will be pruned) in float32
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float32)

# Keep the weights on the CPU; as the last expression of the cell this also
# displays the returned module
model.to("cpu")

Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

In [None]:
# View all params - prints the name and shape of every parameter tensor in the model
for name, param in model.named_parameters():
    print(name, param.shape)

In [None]:
def prune_ffn_layers(model, sparsity_percentage):
    """
    Apply a random binary mask to sparsify the FFN weight matrices of the model.

    Every module whose qualified name contains "mlp" or "ffn" (case-insensitive)
    is visited, and each of its parameters whose name contains "weight" is
    multiplied in place by a freshly drawn 0/1 mask. Each parameter tensor is
    masked exactly once, even though it is reachable from both its own module
    and every enclosing module (e.g. ``...mlp`` and ``...mlp.gate_proj``).

    Args:
        model: The pre-trained transformer model (modified in place)
        sparsity_percentage: Float between 0 and 1 indicating the fraction of
            FFN weight entries to zero out

    Returns:
        model: The same model object, with masked FFN weights

    Raises:
        ValueError: If sparsity_percentage is not within [0, 1]
    """
    if not 0.0 <= sparsity_percentage <= 1.0:
        raise ValueError(
            f"sparsity_percentage must be between 0 and 1, got {sparsity_percentage}"
        )

    print(f"Starting pruning with sparsity level: {sparsity_percentage}")

    # Count (trainable) parameters before pruning
    orig_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Original parameter count: {orig_params:,}")

    pruned_count = 0
    total_ffn_params = 0

    # named_modules() yields a container *and* each of its children, so the same
    # weight would otherwise be masked (and counted) once per matching ancestor,
    # pushing the real sparsity to 1 - (1 - s)^2. Track identities to avoid that.
    seen_params = set()

    # In-place edits of leaf parameters must not be recorded by autograd
    with torch.no_grad():
        for name, module in model.named_modules():
            # Target FFN layers in transformer blocks
            # This pattern needs to be adjusted based on your specific model architecture
            lowered = name.lower()
            if "mlp" not in lowered and "ffn" not in lowered:
                continue

            for subname, param in module.named_parameters():
                if "weight" not in subname:  # Focus on weight matrices
                    continue
                if id(param) in seen_params:
                    continue
                seen_params.add(id(param))

                total_ffn_params += param.numel()

                # Create binary mask (True for keep, False for prune)
                mask = torch.rand_like(param, dtype=torch.float) > sparsity_percentage

                # Apply mask (hard pruning). Casting the mask to the parameter's
                # dtype and multiplying in place keeps fp16/bf16 weights in
                # their original dtype instead of promoting them to float32.
                param.mul_(mask.to(param.dtype))

                # Count pruned parameters
                pruned_count += param.numel() - int(mask.sum().item())

    # Guard the ratio for models without trainable parameters
    pruned_fraction = pruned_count / orig_params if orig_params else 0.0

    print(f"FFN parameters before pruning: {total_ffn_params:,}")
    print(f"Parameters pruned: {pruned_count:,} ({pruned_fraction:.2%} of total)")

    # Count parameters after pruning (note: this doesn't change since we're just zeroing values)
    remaining_params = orig_params - pruned_count
    print(f"Effective parameter count after pruning: {remaining_params:,}")

    return model

def remove_pruned_parameters(model):
    """
    Placeholder for turning a masked model into a physically smaller one.

    No parameters are removed: the function only prints a two-line notice
    explaining that real structural pruning would need the architecture to be
    rebuilt, then hands the model back untouched.

    Args:
        model: The pruned (masked) model

    Returns:
        model: The very same object that was passed in
    """
    notice_lines = (
        "Note: Converting masked model to physically smaller model would require",
        "rebuilding the model architecture based on the specific transformer implementation.",
    )
    for notice in notice_lines:
        print(notice)

    return model

def save_pruned_model(model, output_dir, tokenizer):
    """
    Write the pruned model weights and the tokenizer into ``output_dir``.

    The weights are taken from ``model.state_dict()`` and stored as a single
    ``model.safetensors`` file; the tokenizer is saved next to it through its
    own ``save_pretrained``. Nothing else (e.g. a model config) is written here.

    Args:
        model: The pruned model whose state dict is saved
        output_dir: Destination directory (created if missing)
        tokenizer: Tokenizer object exposing ``save_pretrained``
    """
    # Make sure the destination exists; an existing directory is reused
    os.makedirs(output_dir, exist_ok=True)

    # Dump the whole state dict into one safetensors file
    weights_path = f"{output_dir}/model.safetensors"
    save_file(model.state_dict(), weights_path)

    # Store the tokenizer files alongside the weights
    tokenizer.save_pretrained(output_dir)

    print(f"Pruned model saved to {output_dir}")


In [None]:
# Parameters
output_dir = "pruned_model"
sparsity = 0.5  # 50% of FFN parameters will be pruned
mask_seed = 42  # fixes the random mask so the same expert can be regenerated

# The binary mask is drawn from torch's global RNG, so seed it right before pruning
torch.manual_seed(mask_seed)

# Apply pruning. This modifies `model` in place and returns the same object, so
# `pruned_model` is an alias of `model`; re-running this cell masks the weights again.
pruned_model = prune_ffn_layers(model, sparsity)