### Applying Model Optimization Techniques (Pruning & Quantization) to a Pre-trained Generative Model

In [1]:
import torch
import torch.nn.utils.prune as prune
import time
import psutil
from transformers import AutoModelForCausalLM, AutoTokenizer

In [2]:
def get_memory_usage():
    process = psutil.Process()
    return process.memory_info().rss / 1024 / 1024  # MB

def prune_model(model):
    # Prune 20% of weights in linear layers
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            prune.l1_unstructured(module, name='weight', amount=0.2)
    return model

def quantize_model(model):
    # Dynamic quantization for better compatibility
    model.eval()
    quantized_model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )
    return quantized_model

In [3]:
def benchmark_inference(model, tokenizer, prompt="Hello world", num_runs=5):
    inputs = tokenizer(prompt, return_tensors="pt")
    start_time = time.time()
    memory_before = get_memory_usage()
    
    for _ in range(num_runs):
        with torch.no_grad():
            outputs = model(**inputs)
    
    end_time = time.time()
    memory_after = get_memory_usage()
    
    avg_time = (end_time - start_time) / num_runs
    memory_used = memory_after - memory_before
    return avg_time, memory_used

In [4]:
model_name = "distilgpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
original_model = AutoModelForCausalLM.from_pretrained(model_name)
    
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    
print("Original model:")
orig_time, orig_mem = benchmark_inference(original_model, tokenizer)
print(f"  Inference time: {orig_time:.4f} s")
print(f"  Memory usage: {orig_mem:.2f} MB")
    
# Apply pruning to a copy
pruned_model = AutoModelForCausalLM.from_pretrained(model_name)
pruned_model = prune_model(pruned_model)
print("\nAfter pruning:")
prune_time, prune_mem = benchmark_inference(pruned_model, tokenizer)
print(f"  Inference time: {prune_time:.4f} s")
print(f"  Memory usage: {prune_mem:.2f} MB")
    
# Apply quantization to original
quantized_model = quantize_model(original_model)
print("\nAfter quantization:")
quant_time, quant_mem = benchmark_inference(quantized_model, tokenizer)
print(f"  Inference time: {quant_time:.4f} s")
print(f"  Memory usage: {quant_mem:.2f} MB")

Original model:
  Inference time: 0.0945 s
  Memory usage: 311.01 MB

After pruning:
  Inference time: 0.0764 s
  Memory usage: 220.63 MB


For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  quantized_model = torch.quantization.quantize_dynamic(



After quantization:
  Inference time: 0.0356 s
  Memory usage: 3.64 MB
