# SELM Pruning Experiment

This notebook experiments with pruning and quantization for SELM. We first prune the model's linear layers, then quantize the pruned model to reduce its size and inference cost for deployment.

## 1. Setup
Import necessary libraries and load the trained model.

In [1]:
import torch
from transformers import AutoModelForSequenceClassification
import torch.nn.utils.prune as prune
import torch.quantization as quant

# Load the trained model
model_path = 'model_output/trained_model/'
model = AutoModelForSequenceClassification.from_pretrained(model_path)
model.eval()
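
Before changing the model, it can help to record a baseline to compare against later. The sketch below is a minimal example on a small stand-in network, not the SELM model itself; the helper names `count_parameters` and `serialized_size_mb` and the toy layer sizes are hypothetical choices for illustration.

```python
import io
import torch

def count_parameters(model: torch.nn.Module) -> int:
    """Total number of elements across all parameters."""
    return sum(p.numel() for p in model.parameters())

def serialized_size_mb(model: torch.nn.Module) -> float:
    """Size of the state dict when serialized with torch.save, in megabytes."""
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    return buffer.getbuffer().nbytes / 1e6

# Toy stand-in model (assumption: the real notebook would pass `model` instead)
toy = torch.nn.Sequential(
    torch.nn.Linear(16, 8),
    torch.nn.ReLU(),
    torch.nn.Linear(8, 2),
)

n_params = count_parameters(toy)
size_mb = serialized_size_mb(toy)
print(f'parameters: {n_params}, serialized size: {size_mb:.4f} MB')
```

Calling the same helpers on `model` after each later step would show how much pruning and quantization actually change the stored size.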


## 2. Prune the Model
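Apply L1 unstructured pruning to every linear layer, zeroing the 20% of weights with the smallest magnitude in each. Note that this masks weights rather than removing them, so the stored model does not shrink by itself; the size benefit only appears once the zeros are exploited by compression or a sparse format.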

In [2]:
# Apply pruning
def prune_model(model, amount=0.2):
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            prune.l1_unstructured(module, name='weight', amount=amount)

prune_model(model, amount=0.2)

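# Check the pruning status: report the fraction of weights zeroed out in each layer
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear) and hasattr(module, 'weight_mask'):
        sparsity = 1 - module.weight_mask.sum().item() / module.weight_mask.numel()
        print(f'{name} sparsity: {sparsity:.2%}')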
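
To make the mechanics of `prune.l1_unstructured` concrete, here is a small sketch on a single stand-alone layer (the 10x10 layer is an arbitrary choice for illustration, not part of SELM). It shows that pruning replaces `weight` with a `weight_orig` parameter plus a `weight_mask` buffer, and that `prune.remove` folds the mask back into a plain `weight` tensor.

```python
import torch
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = torch.nn.Linear(10, 10)  # toy layer, assumed for illustration

# Zero the 20% of weights with the smallest absolute value
prune.l1_unstructured(layer, name='weight', amount=0.2)

# While the reparametrization is active, the state dict holds the
# original weights and the binary mask instead of a single 'weight' entry
keys_while_pruned = sorted(layer.state_dict().keys())
sparsity = (layer.weight == 0).sum().item() / layer.weight.numel()

# Make the pruning permanent: 'weight' becomes a regular parameter again,
# with the pruned entries stored as zeros
prune.remove(layer, 'weight')
keys_after_remove = sorted(layer.state_dict().keys())
zeros_after_remove = (layer.weight == 0).sum().item()

print(keys_while_pruned, keys_after_remove, sparsity)
```

Because the mask is stored next to the original weights, a checkpoint saved while pruning is still active is larger than the unpruned one, which is one reason to call `prune.remove` before exporting.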


## 3. Quantize the Model
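Quantize the pruned model to reduce its size further and speed up CPU inference. The pruning masks are first made permanent, then dynamic quantization converts the weights of the linear layers to int8.

In [3]:
# Make pruning permanent so each layer has a plain 'weight' tensor again
for module in model.modules():
    if isinstance(module, torch.nn.Linear) and hasattr(module, 'weight_orig'):
        prune.remove(module, 'weight')

# Dynamically quantize the linear layers to int8
model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)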

# Test the quantized model
def test_model(model, dataloader):
    model.eval()
    total = 0
    correct = 0
    with torch.no_grad():
        for batch in dataloader:
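            inputs, labels = batch['input_ids'], batch['labels']
            # Pass the attention mask when the dataset provides one so padding is ignored
            outputs = model(input_ids=inputs, attention_mask=batch.get('attention_mask'))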
            preds = torch.argmax(outputs.logits, dim=1)
            total += labels.size(0)
            correct += (preds == labels).sum().item()
    return correct / total

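# Create dataloader for evaluation.
# tokenized_dataset is assumed to be defined earlier (e.g. in the training notebook)
# and formatted to return torch tensors.
from torch.utils.data import DataLoader
test_dataloader = DataLoader(tokenized_dataset['test'], batch_size=16)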

accuracy = test_model(model, test_dataloader)
print(f'Quantized Model Accuracy: {accuracy:.4f}')
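
One common way to quantize Transformer-style models for CPU inference is dynamic quantization, which stores the weights of selected layer types as int8 and quantizes activations on the fly. The sketch below is a minimal illustration on a toy network, assuming a CPU build of PyTorch with a quantized backend available; the layer sizes and the helper `state_dict_bytes` are hypothetical.

```python
import io
import torch

def state_dict_bytes(module: torch.nn.Module) -> int:
    """Number of bytes the state dict occupies when serialized with torch.save."""
    buffer = io.BytesIO()
    torch.save(module.state_dict(), buffer)
    return buffer.getbuffer().nbytes

torch.manual_seed(0)
# Toy float32 model standing in for the real network (assumption)
fp32_model = torch.nn.Sequential(
    torch.nn.Linear(256, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 4),
).eval()

# Replace every nn.Linear with a dynamically quantized int8 version
int8_model = torch.quantization.quantize_dynamic(
    fp32_model, {torch.nn.Linear}, dtype=torch.qint8
)

x = torch.randn(2, 256)
with torch.no_grad():
    out = int8_model(x)

fp32_bytes = state_dict_bytes(fp32_model)
int8_bytes = state_dict_bytes(int8_model)
print(f'fp32: {fp32_bytes} bytes, int8: {int8_bytes} bytes')
```

Dynamic quantization needs no calibration data, which makes it a convenient first experiment; static quantization or quantization-aware training may give different accuracy and latency trade-offs.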


## 4. Save the Pruned and Quantized Model
Save the pruned and quantized model for future use.

In [4]:
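import os
# save_pretrained() may not support packed int8 weights, so save the state dict directly (the file name is an arbitrary choice)
quantized_model_path = 'model_output/quantized_model/'
os.makedirs(quantized_model_path, exist_ok=True)
torch.save(model.state_dict(), os.path.join(quantized_model_path, 'quantized_state_dict.pt'))
print(f'Quantized model saved to {quantized_model_path}')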
