In [1]:
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch.nn.utils.prune as prune
import gc
import time
import os
import json
from tqdm import tqdm
import matplotlib.pyplot as plt

# Memory management
def free_memory():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

def load_model(model_name="gpt2-large"):
    """Load the model with memory optimization"""
    print(f"Loading {model_name}...")
    start_time = time.time()
    
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    
    model = GPT2LMHeadModel.from_pretrained(
        model_name, 
        torch_dtype=dtype
    )
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    
    print(f"Model loaded in {time.time() - start_time:.2f} seconds")
    return model, tokenizer

def count_parameters(model):
    """Count total and non-zero parameters in the model"""
    total_params = 0
    nonzero_params = 0
    
    for name, param in model.named_parameters():
        if 'weight' in name:  # Focus on weights for pruning analysis
            total = param.numel()
            nonzero = torch.count_nonzero(param).item()
            total_params += total
            nonzero_params += nonzero
            print(f"{name}: {nonzero}/{total} ({nonzero/total*100:.2f}% non-zero)")
    
    print(f"Total: {nonzero_params}/{total_params} " 
         f"({nonzero_params/total_params*100:.2f}% non-zero, "
         f"{100-nonzero_params/total_params*100:.2f}% sparsity)")
    
    return total_params, nonzero_params

def prune_model(model, amount=0.5, method="l1_unstructured"):

    # Prune the model using the specified method and amount
    print(f"Pruning model with {method} at {amount*100}% sparsity...")
    start_time = time.time()
    
    # Count parameters before pruning
    print("Before pruning:")
    total_before, nonzero_before = count_parameters(model)
    
    # Apply pruning
    # for name, module in model.named_modules():
    for name, module in tqdm(list(model.named_modules()), desc="Pruning modules"):
        # Prune the linear layers in the MLP (feed-forward network) blocks
        if 'mlp.c_fc' in name or 'mlp.c_proj' in name:
            print("MODULE NAME ->-->--->---->", name)
        # if isinstance(module, torch.nn.Linear):
        # if isinstance(module, Conv1D):
            if method == "l1_unstructured":
                prune.l1_unstructured(module, name='weight', amount=amount)
            elif method == "random_unstructured":
                prune.random_unstructured(module, name='weight', amount=amount)
            elif method == "magnitude_unstructured":  # Equivalent to l1 for our purposes
                prune.l1_unstructured(module, name='weight', amount=amount)


    # Count parameters after pruning
    print("\nAfter pruning before mask removal:")
    total_after, nonzero_after = count_parameters(model)

    # Make pruning permanent to save memory
    model = make_pruning_permanent(model)
    
    # Count parameters after pruning
    print("\nAfter pruning:")
    total_after, nonzero_after = count_parameters(model)
    
    print(f"Pruning completed in {time.time() - start_time:.2f} seconds")
    
    # Return pruning statistics
    return {
        "total_params": total_after,
        "nonzero_params": nonzero_after,
        "sparsity": 1 - (nonzero_after / total_after),
        "pruning_method": method,
        "target_sparsity": amount
    }

def make_pruning_permanent(model):
    """Convert pruning masks to permanent pruning"""
    print("Making pruning permanent...")
    for name, module in tqdm(list(model.named_modules()), desc="Removing pruning masks"):
        if 'mlp.c_fc' in name or 'mlp.c_proj' in name:
            try:
                prune.remove(module, "weight")
            except:
                # If the module has no pruning mask, skip it
                pass
    return model
def run_pruning_experiment():
    """Run the pruning experiment with evaluation"""
        # Define pruning parameters
    pruning_method = "random_unstructured"  # Options: l1_unstructured, random_unstructured
    pruning_amount = 0.5  # 50% sparsity
    # Create output directory
    output_dir = ("pruned_gpt2_large_" + pruning_method + "_" + str(pruning_amount))
    os.makedirs(output_dir, exist_ok=True)
    
    # Clear GPU memory
    free_memory()
    
    # Load model
    model, tokenizer = load_model()
    model.to(device)

    # Prune the model
    try:
        pruning_stats = prune_model(model, amount=pruning_amount, method=pruning_method)
    except Exception as e:
        print(f"Pruning failed: {e}")
        return
   
    # Save pruned model
    print("Saving pruned model...")
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    
    print(f"Pruned model saved to {output_dir}")
    print(f"Model pruned to {pruning_stats['sparsity']*100:.2f}% sparsity")
    
    return model, tokenizer


# Run the experiment
if __name__ == "__main__":
    run_pruning_experiment()


2025-03-23 02:30:07.453025: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-03-23 02:30:07.453090: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-03-23 02:30:07.454442: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-23 02:30:07.462417: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Using device: cuda
Loading gpt2-large...
Model loaded in 3.98 seconds
Pruning model with random_unstructured at 50.0% sparsity...
Before pruning:
transformer.wte.weight: 64328938/64328960 (100.00% non-zero)
transformer.wpe.weight: 1310717/1310720 (100.00% non-zero)
transformer.h.0.ln_1.weight: 1280/1280 (100.00% non-zero)
transformer.h.0.attn.c_attn.weight: 4915197/4915200 (100.00% non-zero)
transformer.h.0.attn.c_proj.weight: 1638398/1638400 (100.00% non-zero)
transformer.h.0.ln_2.weight: 1280/1280 (100.00% non-zero)
transformer.h.0.mlp.c_fc.weight: 6553598/6553600 (100.00% non-zero)
transformer.h.0.mlp.c_proj.weight: 6553596/6553600 (100.00% non-zero)
transformer.h.1.ln_1.weight: 1280/1280 (100.00% non-zero)
transformer.h.1.attn.c_attn.weight: 4915197/4915200 (100.00% non-zero)
transformer.h.1.attn.c_proj.weight: 1638400/1638400 (100.00% non-zero)
transformer.h.1.ln_2.weight: 1280/1280 (100.00% non-zero)
transformer.h.1.mlp.c_fc.weight: 6553595/6553600 (100.00% non-zero)
transformer.

Pruning modules:   0%|          | 0/476 [00:00<?, ?it/s]

MODULE NAME ->-->--->----> transformer.h.0.mlp.c_fc
MODULE NAME ->-->--->----> transformer.h.0.mlp.c_proj
MODULE NAME ->-->--->----> transformer.h.1.mlp.c_fc
MODULE NAME ->-->--->----> transformer.h.1.mlp.c_proj
MODULE NAME ->-->--->----> transformer.h.2.mlp.c_fc
MODULE NAME ->-->--->----> transformer.h.2.mlp.c_proj
MODULE NAME ->-->--->----> transformer.h.3.mlp.c_fc
MODULE NAME ->-->--->----> transformer.h.3.mlp.c_proj
MODULE NAME ->-->--->----> transformer.h.4.mlp.c_fc
MODULE NAME ->-->--->----> transformer.h.4.mlp.c_proj
MODULE NAME ->-->--->----> transformer.h.5.mlp.c_fc


Pruning modules:  17%|█▋        | 82/476 [00:00<00:00, 782.98it/s]

MODULE NAME ->-->--->----> transformer.h.5.mlp.c_proj
MODULE NAME ->-->--->----> transformer.h.6.mlp.c_fc
MODULE NAME ->-->--->----> transformer.h.6.mlp.c_proj
MODULE NAME ->-->--->----> transformer.h.7.mlp.c_fc
MODULE NAME ->-->--->----> transformer.h.7.mlp.c_proj
MODULE NAME ->-->--->----> transformer.h.8.mlp.c_fc
MODULE NAME ->-->--->----> transformer.h.8.mlp.c_proj
MODULE NAME ->-->--->----> transformer.h.9.mlp.c_fc
MODULE NAME ->-->--->----> transformer.h.9.mlp.c_proj
MODULE NAME ->-->--->----> transformer.h.10.mlp.c_fc
MODULE NAME ->-->--->----> transformer.h.10.mlp.c_proj
MODULE NAME ->-->--->----> transformer.h.11.mlp.c_fc
MODULE NAME ->-->--->----> transformer.h.11.mlp.c_proj
MODULE NAME ->-->--->----> transformer.h.12.mlp.c_fc
MODULE NAME ->-->--->----> transformer.h.12.mlp.c_proj
MODULE NAME ->-->--->----> transformer.h.13.mlp.c_fc
MODULE NAME ->-->--->----> transformer.h.13.mlp.c_proj
MODULE NAME ->-->--->----> transformer.h.14.mlp.c_fc
MODULE NAME ->-->--->----> transforme

Pruning modules:  44%|████▍     | 211/476 [00:00<00:00, 1042.74it/s]

MODULE NAME ->-->--->----> transformer.h.15.mlp.c_fc
MODULE NAME ->-->--->----> transformer.h.15.mlp.c_proj
MODULE NAME ->-->--->----> transformer.h.16.mlp.c_fc
MODULE NAME ->-->--->----> transformer.h.16.mlp.c_proj
MODULE NAME ->-->--->----> transformer.h.17.mlp.c_fc
MODULE NAME ->-->--->----> transformer.h.17.mlp.c_proj
MODULE NAME ->-->--->----> transformer.h.18.mlp.c_fc
MODULE NAME ->-->--->----> transformer.h.18.mlp.c_proj
MODULE NAME ->-->--->----> transformer.h.19.mlp.c_fc
MODULE NAME ->-->--->----> transformer.h.19.mlp.c_proj
MODULE NAME ->-->--->----> transformer.h.20.mlp.c_fc
MODULE NAME ->-->--->----> transformer.h.20.mlp.c_proj
MODULE NAME ->-->--->----> transformer.h.21.mlp.c_fc
MODULE NAME ->-->--->----> transformer.h.21.mlp.c_proj
MODULE NAME ->-->--->----> transformer.h.22.mlp.c_fc
MODULE NAME ->-->--->----> transformer.h.22.mlp.c_proj
MODULE NAME ->-->--->----> transformer.h.23.mlp.c_fc
MODULE NAME ->-->--->----> transformer.h.23.mlp.c_proj


Pruning modules:  69%|██████▉   | 329/476 [00:00<00:00, 1085.57it/s]

MODULE NAME ->-->--->----> transformer.h.24.mlp.c_fc
MODULE NAME ->-->--->----> transformer.h.24.mlp.c_proj
MODULE NAME ->-->--->----> transformer.h.25.mlp.c_fc
MODULE NAME ->-->--->----> transformer.h.25.mlp.c_proj
MODULE NAME ->-->--->----> transformer.h.26.mlp.c_fc
MODULE NAME ->-->--->----> transformer.h.26.mlp.c_proj
MODULE NAME ->-->--->----> transformer.h.27.mlp.c_fc
MODULE NAME ->-->--->----> transformer.h.27.mlp.c_proj
MODULE NAME ->-->--->----> transformer.h.28.mlp.c_fc
MODULE NAME ->-->--->----> transformer.h.28.mlp.c_proj
MODULE NAME ->-->--->----> transformer.h.29.mlp.c_fc
MODULE NAME ->-->--->----> transformer.h.29.mlp.c_proj
MODULE NAME ->-->--->----> transformer.h.30.mlp.c_fc
MODULE NAME ->-->--->----> transformer.h.30.mlp.c_proj
MODULE NAME ->-->--->----> transformer.h.31.mlp.c_fc
MODULE NAME ->-->--->----> transformer.h.31.mlp.c_proj
MODULE NAME ->-->--->----> transformer.h.32.mlp.c_fc
MODULE NAME ->-->--->----> transformer.h.32.mlp.c_proj
MODULE NAME ->-->--->----> t

Pruning modules: 100%|██████████| 476/476 [00:00<00:00, 1096.14it/s]

MODULE NAME ->-->--->----> transformer.h.33.mlp.c_proj
MODULE NAME ->-->--->----> transformer.h.34.mlp.c_fc
MODULE NAME ->-->--->----> transformer.h.34.mlp.c_proj
MODULE NAME ->-->--->----> transformer.h.35.mlp.c_fc
MODULE NAME ->-->--->----> transformer.h.35.mlp.c_proj

After pruning before mask removal:
transformer.wte.weight: 64328938/64328960 (100.00% non-zero)
transformer.wpe.weight: 1310717/1310720 (100.00% non-zero)
transformer.h.0.ln_1.weight: 1280/1280 (100.00% non-zero)
transformer.h.0.attn.c_attn.weight: 4915197/4915200 (100.00% non-zero)
transformer.h.0.attn.c_proj.weight: 1638398/1638400 (100.00% non-zero)
transformer.h.0.ln_2.weight: 1280/1280 (100.00% non-zero)
transformer.h.0.mlp.c_fc.weight_orig: 6553598/6553600 (100.00% non-zero)
transformer.h.0.mlp.c_proj.weight_orig: 6553596/6553600 (100.00% non-zero)
transformer.h.1.ln_1.weight: 1280/1280 (100.00% non-zero)
transformer.h.1.attn.c_attn.weight: 4915197/4915200 (100.00% non-zero)
transformer.h.1.attn.c_proj.weight: 16




transformer.h.31.mlp.c_fc.weight_orig: 6553600/6553600 (100.00% non-zero)
transformer.h.31.mlp.c_proj.weight_orig: 6553597/6553600 (100.00% non-zero)
transformer.h.32.ln_1.weight: 1280/1280 (100.00% non-zero)
transformer.h.32.attn.c_attn.weight: 4915197/4915200 (100.00% non-zero)
transformer.h.32.attn.c_proj.weight: 1638400/1638400 (100.00% non-zero)
transformer.h.32.ln_2.weight: 1280/1280 (100.00% non-zero)
transformer.h.32.mlp.c_fc.weight_orig: 6553596/6553600 (100.00% non-zero)
transformer.h.32.mlp.c_proj.weight_orig: 6553598/6553600 (100.00% non-zero)
transformer.h.33.ln_1.weight: 1280/1280 (100.00% non-zero)
transformer.h.33.attn.c_attn.weight: 4915193/4915200 (100.00% non-zero)
transformer.h.33.attn.c_proj.weight: 1638399/1638400 (100.00% non-zero)
transformer.h.33.ln_2.weight: 1280/1280 (100.00% non-zero)
transformer.h.33.mlp.c_fc.weight_orig: 6553599/6553600 (100.00% non-zero)
transformer.h.33.mlp.c_proj.weight_orig: 6553598/6553600 (100.00% non-zero)
transformer.h.34.ln_1.weig

Removing pruning masks: 100%|██████████| 476/476 [00:00<00:00, 85940.71it/s]


After pruning:
transformer.wte.weight: 64328938/64328960 (100.00% non-zero)
transformer.wpe.weight: 1310717/1310720 (100.00% non-zero)
transformer.h.0.ln_1.weight: 1280/1280 (100.00% non-zero)
transformer.h.0.attn.c_attn.weight: 4915197/4915200 (100.00% non-zero)
transformer.h.0.attn.c_proj.weight: 1638398/1638400 (100.00% non-zero)
transformer.h.0.ln_2.weight: 1280/1280 (100.00% non-zero)
transformer.h.0.mlp.c_fc.weight: 3276800/6553600 (50.00% non-zero)
transformer.h.0.mlp.c_proj.weight: 3276797/6553600 (50.00% non-zero)
transformer.h.1.ln_1.weight: 1280/1280 (100.00% non-zero)
transformer.h.1.attn.c_attn.weight: 4915197/4915200 (100.00% non-zero)
transformer.h.1.attn.c_proj.weight: 1638400/1638400 (100.00% non-zero)
transformer.h.1.ln_2.weight: 1280/1280 (100.00% non-zero)
transformer.h.1.mlp.c_fc.weight: 3276797/6553600 (50.00% non-zero)
transformer.h.1.mlp.c_proj.weight: 3276800/6553600 (50.00% non-zero)
transformer.h.2.ln_1.weight: 1280/1280 (100.00% non-zero)
transformer.h.2.at




transformer.h.34.attn.c_attn.weight: 4915199/4915200 (100.00% non-zero)
transformer.h.34.attn.c_proj.weight: 1638400/1638400 (100.00% non-zero)
transformer.h.34.ln_2.weight: 1280/1280 (100.00% non-zero)
transformer.h.34.mlp.c_fc.weight: 3276797/6553600 (50.00% non-zero)
transformer.h.34.mlp.c_proj.weight: 3276799/6553600 (50.00% non-zero)
transformer.h.35.ln_1.weight: 1280/1280 (100.00% non-zero)
transformer.h.35.attn.c_attn.weight: 4915197/4915200 (100.00% non-zero)
transformer.h.35.attn.c_proj.weight: 1638400/1638400 (100.00% non-zero)
transformer.h.35.ln_2.weight: 1280/1280 (100.00% non-zero)
transformer.h.35.mlp.c_fc.weight: 3276797/6553600 (50.00% non-zero)
transformer.h.35.mlp.c_proj.weight: 3276799/6553600 (50.00% non-zero)
transformer.ln_f.weight: 1280/1280 (100.00% non-zero)
Total: 537592050/773521920 (69.50% non-zero, 30.50% sparsity)
Pruning completed in 0.72 seconds
Saving pruned model...
Pruned model saved to pruned_gpt2_large_random_unstructured_0.5
Model pruned to 30.50%