# 03 - Pruning Estructural y Cuantización

En este notebook aplicaremos pruning a las cabezas de atención y cuantización post-entrenamiento.

In [1]:
!pip install transformers datasets torch

Collecting transformers
  Downloading transformers-4.52.4-py3-none-any.whl.metadata (38 kB)
Collecting datasets
  Using cached datasets-3.6.0-py3-none-any.whl.metadata (19 kB)
Collecting torch
  Downloading torch-2.7.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (29 kB)
Collecting filelock (from transformers)
  Using cached filelock-3.18.0-py3-none-any.whl.metadata (2.9 kB)
Collecting huggingface-hub<1.0,>=0.30.0 (from transformers)
  Downloading huggingface_hub-0.33.0-py3-none-any.whl.metadata (14 kB)
Collecting numpy>=1.17 (from transformers)
  Using cached numpy-2.3.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (62 kB)
Collecting pyyaml>=5.1 (from transformers)
  Using cached PyYAML-6.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.1 kB)
Collecting regex!=2019.12.17 (from transformers)
  Using cached regex-2024.11.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (40 kB)
Collecting requests (from transformers)
  Downloading reques

In [2]:
import os
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from datasets import load_dataset
import torch.nn.utils.prune as prune

# Cargar modelo fine-tuneado
model = AutoModelForSequenceClassification.from_pretrained('../outputs/models/finetuned')
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')

  from .autonotebook import tqdm as notebook_tqdm


HFValidationError: Repo id must be in the form 'repo_name' or 'namespace/repo_name': '../outputs/models/finetuned'. Use `repo_type` argument if needed.

In [None]:
# Pruning estructural de cabezas de atención
PRUNE_HEADS = {  # Ejemplo: podar cabezas en capas 1 y 2
    1: [0, 1],
    2: [2, 3]
}
model.distilbert.prune_heads(PRUNE_HEADS)
torch.save(model.state_dict(), '../outputs/models/pruned/model_pruned.pt')

In [None]:
# Evaluar modelo podado
dataset = load_dataset('glue', 'mrpc')
def preprocess(batch):
    return tokenizer(batch['sentence1'], batch['sentence2'], truncation=True, padding='max_length')
dataset = dataset.map(preprocess, batched=True)
dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
model.eval()
acc = 0
total = 0
for batch in dataset['validation']:
    inputs = {k: v.unsqueeze(0) for k, v in batch.items() if k in ['input_ids','attention_mask']}
    labels = batch['label'].unsqueeze(0)
    with torch.no_grad():
        outputs = model(**inputs)
    preds = outputs.logits.argmax(dim=-1)
    acc += (preds == labels).sum().item()
    total += labels.size(0)
print(f'Accuracy modelo podado: {acc/total:.4f}')

In [None]:
# Cuantización dinámica post-entrenamiento
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
torch.save(quantized_model.state_dict(), '../outputs/models/quantized/model_quantized.pt')

In [None]:
# Evaluar modelo cuantizado
quantized_model.eval()
acc_q = 0
total_q = 0
for batch in dataset['validation']:
    inputs = {k: v.unsqueeze(0) for k, v in batch.items() if k in ['input_ids','attention_mask']}
    labels = batch['label'].unsqueeze(0)
    with torch.no_grad():
        outputs = quantized_model(**inputs)
    preds = outputs.logits.argmax(dim=-1)
    acc_q += (preds == labels).sum().item()
    total_q += labels.size(0)
print(f'Accuracy modelo cuantizado: {acc_q/total_q:.4f}')

In [None]:
# Comparativa de tamaños
def size_mb(path): return os.path.getsize(path)/1e6
original = size_mb('../outputs/models/finetuned/pytorch_model.bin')
pruned = size_mb('../outputs/models/pruned/model_pruned.pt')
quant = size_mb('../outputs/models/quantized/model_quantized.pt')
print(f'Original: {original:.2f} MB')
print(f'Pruned: {pruned:.2f} MB')
print(f'Quantized: {quant:.2f} MB')

In [None]:
# Export a ONNX
from transformers import pipeline
onnx_path = '../outputs/models/quantized/model_quantized.onnx'
dummy_input = ('This is a test sentence.', 'Another sentence.')
nlp = pipeline('text-classification', model=quantized_model, tokenizer=tokenizer)
nlp.model.eval()
torch.onnx.export(
    nlp.model,
    (torch.tensor([tokenizer.encode(dummy_input[0], dummy_input[1])]),),
    onnx_path,
    opset_version=11
)