# Fine-tuning and quantization

In this example, you will fine-tune a small language model (GPT-2) and then quantize it from FP32 to INT8.



## Imports



In [None]:
%pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cpu
%pip install transformers==4.41.2
%pip install datasets==2.20.2
%pip install numpy==1.26.3
%pip install pandas==2.0.3

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from datasets import load_dataset
import time

## Prepare dataset

First, you will check whether a CUDA-capable (NVIDIA) GPU is available in the environment; if not, everything runs on the CPU.

In [None]:
# Pick the GPU when PyTorch can see a CUDA device; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Next, you will load and prepare the dataset.

In [None]:
# Load a small slice of WikiText-2 (raw variant): only the first 1,000 rows
# of the training split, to keep fine-tuning time manageable.
dataset = load_dataset(
    "wikitext",
    "wikitext-2-raw-v1",
    split="train[:1000]",
)

In [None]:
# GPT-2's byte-level BPE tokenizer. It has no padding token of its own, so
# the end-of-sequence token is reused for padding.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Tokenize every row of text into a fixed 128-token window: longer rows are
# truncated, shorter rows are padded, so all examples have the same length.
def tokenize_function(examples):
    texts = examples["text"]
    encoded = tokenizer(
        texts,
        truncation=True,
        padding="max_length",
        max_length=128,
    )
    return encoded

tokenized_dataset = dataset.map(tokenize_function, batched=True)

Next, we will create a custom `TextDataset` class (a subclass of PyTorch's `Dataset` class) and wrap it in a `DataLoader`.

The DataLoader is a crucial part of the PyTorch training pipeline. It:

* Batches the data, which allows for more efficient processing.
* Shuffles the data, which helps in reducing overfitting.
* Handles the conversion of your data into PyTorch tensors.
* Can load data in parallel worker processes (`num_workers`) for faster loading (this example keeps the default `num_workers=0`, so data is loaded in the main process).

When we use this `train_loader` in our training loop, it will yield batches of data, each containing 4 samples (except possibly the last batch if the dataset size isn't divisible by 4). Because `TextDataset` returns an `(input_ids, attention_mask)` tuple for each example, each batch is a pair of tensors `(input_ids, attention_mask)`, each of shape (4, 128), since every example was padded or truncated to 128 tokens.

This setup allows for efficient, batched processing of our dataset during training, which is crucial for handling larger datasets and speeding up the training process.

In [None]:
# Create a custom dataset
class TextDataset(Dataset):
    """Expose tokenized rows to PyTorch's ``DataLoader``.

    Every item is an ``(input_ids, attention_mask)`` pair of tensors:
    ``input_ids`` holds the token ids of one example and ``attention_mask``
    marks real tokens with 1 and padding with 0.
    """

    def __init__(self, tokenized_dataset):
        # Only a reference is kept; the rows are converted lazily per item.
        self.tokenized_dataset = tokenized_dataset

    def __len__(self):
        """Number of examples in the wrapped dataset."""
        return len(self.tokenized_dataset)

    def __getitem__(self, idx):
        """Return example ``idx`` as ``(input_ids, attention_mask)`` tensors."""
        row = self.tokenized_dataset[idx]
        ids = torch.tensor(row['input_ids'])
        mask = torch.tensor(row['attention_mask'])
        return ids, mask

# Wrap the tokenized rows in the custom TextDataset class.
train_dataset = TextDataset(tokenized_dataset)

# Serve the examples in batches of 4, reshuffled at the start of every epoch
# so the model cannot pick up anything from the order of the rows.
train_loader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
)

## Fine-tuning

In [None]:
# Pre-trained GPT-2 with its language-modeling head, moved to the selected
# device (GPU if available, otherwise CPU).
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.to(device)

In [None]:
# AdamW over all model weights with a small fine-tuning learning rate.
optimizer = optim.AdamW(params=model.parameters(), lr=5e-5)
# NOTE: the fine-tuning loop takes its loss from the model output
# (`outputs.loss`), so this criterion is currently not used by it.
criterion = nn.CrossEntropyLoss()

NOTE: The next step (the fine-tuning loop) can take up to 2 hours to complete.

In [None]:
# Fine-tuning loop
# Trains for `num_epochs` passes over `train_loader`, printing the mean loss
# and wall time of each epoch.
num_epochs = 3

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    num_batches = 0
    # perf_counter is a monotonic, high-resolution clock meant for intervals.
    start_time = time.perf_counter()
    for input_ids, attention_mask in train_loader:
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        # Padding positions (attention_mask == 0) get label -100, which the
        # model's cross-entropy ignores. Otherwise the model is also trained
        # to predict the EOS token that is used as padding.
        labels = input_ids.masked_fill(attention_mask == 0, -100)
        # GPT-2 shifts labels by one position internally. A batch with no real
        # target left after the shift (e.g. only blank WikiText lines) would
        # produce a NaN loss and corrupt the weights, so skip it.
        if not (labels[:, 1:] != -100).any():
            continue

        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    elapsed = time.perf_counter() - start_time
    # Average over the batches actually trained on (skipped ones excluded).
    avg_loss = total_loss / max(num_batches, 1)
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}, Time: {elapsed:.2f} seconds")

print("Fine-tuning complete!")

## Quantization

In [None]:
# Static quantization example
# NOTE(review): this eager-mode static flow is left disabled for reference. It
# would also need QuantStub/DeQuantStub and embedding-specific qconfigs before
# it could run on GPT-2; confirm before enabling.
# def calibrate(model, loader):
#     model.eval()
#     with torch.no_grad():
#         for batch in loader:
#             input_ids, attention_mask = batch
#             input_ids, attention_mask = input_ids.to(device), attention_mask.to(device)
#             _ = model(input_ids, attention_mask=attention_mask)

# # Prepare the model for quantization
# model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
# torch.quantization.prepare(model, inplace=True)

# # Calibrate the model
# calibrate(model, train_loader)

# # Convert the model to quantized version
# quantized_model = torch.quantization.convert(model, inplace=False)

# Dynamic quantization example
# NOTE: nn.Linear layers are dynamically quantized with dtype=torch.qint8
# (torch.quint8 selects a weight-only config intended for embeddings). On
# GPT-2 this would only reach `lm_head`, because the transformer blocks use
# transformers' Conv1D layers rather than nn.Linear.
# quantized_model = torch.quantization.quantize_dynamic(
#     model, {torch.nn.Linear}, dtype=torch.qint8
# )

# Custom quantization example
# This approach quantizes only the transformer blocks. The embedding layers,
# which were causing errors, stay in FP32, and so does the lm_head.
def _conv1d_to_linear(module):
    """Recursively replace GPT-2 ``Conv1D`` children of *module* with ``nn.Linear``.

    transformers' ``Conv1D`` computes ``x @ weight + bias`` with ``weight`` of
    shape ``(in_features, out_features)``. ``nn.Linear`` stores the transposed
    matrix, so the weight is copied transposed. The class is matched by name
    so this helper does not depend on where transformers defines it.
    """
    for child_name, child in list(module.named_children()):
        if type(child).__name__ == "Conv1D":
            in_features, out_features = child.weight.shape
            has_bias = getattr(child, "bias", None) is not None
            linear = nn.Linear(
                in_features,
                out_features,
                bias=has_bias,
                device=child.weight.device,
                dtype=child.weight.dtype,
            )
            with torch.no_grad():
                linear.weight.copy_(child.weight.t())
                if has_bias:
                    linear.bias.copy_(child.bias)
            setattr(module, child_name, linear)
        else:
            _conv1d_to_linear(child)


def quantize_model(model, inplace=False):
    """Return *model* with its transformer blocks dynamically quantized to INT8.

    Only submodules named ``h`` are touched. In GPT-2 that is the stack of
    transformer blocks (``model.transformer.h``). Their projection layers are
    converted to ``nn.Linear`` and then dynamically quantized to ``qint8``.
    Embeddings, layer norms and ``lm_head`` stay in FP32.

    Args:
        model: A GPT-2 style model (e.g. ``GPT2LMHeadModel``).
        inplace: If True, quantize (and move to the CPU) *model* itself. The
            default works on a deep copy, so the FP32 model stays intact and
            the two can be compared.

    Returns:
        The quantized model, on the CPU and in eval mode. PyTorch's dynamic
        quantization kernels run on the CPU only.
    """
    import copy

    target = model if inplace else copy.deepcopy(model)
    target = target.to("cpu").eval()
    # Collect first: the loop below replaces modules inside these containers.
    blocks = [
        module
        for name, module in target.named_modules()
        if name.rsplit(".", 1)[-1] == "h"
    ]
    for block in blocks:
        _conv1d_to_linear(block)
        torch.quantization.quantize_dynamic(
            block, {nn.Linear}, dtype=torch.qint8, inplace=True
        )
    return target

# Quantize the fine-tuned model; the result is used by all comparisons below.
quantized_model = quantize_model(model=model)
print("Quantization complete!")

## Evaluation

In [None]:
def _weight_dtypes(m):
    """Return the distinct dtypes of all module weights in *m*, sorted, as text."""
    found = set()
    for module in m.modules():
        weight = getattr(module, "weight", None)
        # Dynamically quantized Linear layers expose their packed INT8 weight
        # through a `weight()` method rather than a tensor attribute.
        if callable(weight):
            weight = weight()
        if isinstance(weight, torch.Tensor):
            found.add(str(weight.dtype))
    return ", ".join(sorted(found))


# NOTE: num_parameters() counts nn.Parameter tensors only. The packed INT8
# weights of dynamically quantized layers are not parameters, so they are
# not included in the quantized total.
print(f"GPT2 model's total parameters: {model.num_parameters()}")
print(f"GPT2 model precision (weight data types): {_weight_dtypes(model)}")
print(f"Quantized GPT2 model's total parameters: {quantized_model.num_parameters()}")
print(f"Quantized GPT2 model precision (weight data types): {_weight_dtypes(quantized_model)}")

In [None]:
# Function to calculate model size
def get_model_size(model):
    """Return the serialized size of *model*'s ``state_dict`` in MiB.

    Summing ``model.parameters()`` and ``model.buffers()`` misses dynamically
    quantized layers: their packed INT8 weights are neither parameters nor
    buffers, so they would count as zero bytes and inflate the apparent size
    reduction. Serializing the state dict with ``torch.save`` measures every
    persisted tensor, quantized or not, plus a small serialization overhead.
    """
    import io

    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    return buffer.getbuffer().nbytes / 1024**2

# Measure both models, then report absolute sizes and the relative saving.
original_size = get_model_size(model)
quantized_size = get_model_size(quantized_model)

print(f"Original model size: {original_size:.2f} MB")
print(f"Quantized model size: {quantized_size:.2f} MB")
size_reduction_pct = (1 - quantized_size/original_size)*100
print(f"Size reduction: {size_reduction_pct:.2f}%")

# Inference comparison (example)
# Greedy-decode one prompt with both models and print the two continuations.
input_text = "The quick brown fox"
input_ids = tokenizer.encode(input_text, return_tensors='pt')
# GPT-2 here reuses EOS as its pad token, so a mask inferred from the pad id
# would be unreliable; pass an explicit all-ones mask for this unpadded prompt.
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    # Send inputs to the device holding each model's weights. PyTorch's
    # dynamically quantized INT8 layers run on the CPU only, so the two models
    # may live on different devices.
    original_device = next(model.parameters()).device
    original_output = model.generate(
        input_ids.to(original_device),
        attention_mask=attention_mask.to(original_device),
        max_length=50,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,
    )
    quantized_device = next(quantized_model.parameters()).device
    quantized_output = quantized_model.generate(
        input_ids.to(quantized_device),
        attention_mask=attention_mask.to(quantized_device),
        max_length=50,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,
    )

print("Original model output:")
print(tokenizer.decode(original_output[0], skip_special_tokens=True))
print("\nQuantized model output:")
print(tokenizer.decode(quantized_output[0], skip_special_tokens=True))

In [None]:
# Inference comparison (example)
def generate_text(model, prompt, max_length=50, *, do_sample=False, temperature=0.7):
    """Generate a continuation of *prompt* with *model* and return it as text.

    Decoding is greedy by default, which keeps the FP32-vs-INT8 comparison
    deterministic. ``no_repeat_ngram_size=2`` prevents any bigram from
    repeating. ``temperature`` only takes effect when ``do_sample=True``;
    transformers ignores it, with a warning, under greedy decoding, so it is
    not passed in that case.

    Args:
        model: A causal LM exposing ``generate`` (e.g. ``GPT2LMHeadModel``).
        prompt: Text to continue.
        max_length: Maximum total length in tokens, prompt included.
        do_sample: Sample instead of decoding greedily.
        temperature: Sampling temperature, used only when sampling.

    Returns:
        The decoded prompt plus continuation, with special tokens removed.
    """
    # Use the device holding the model's weights: PyTorch's dynamically
    # quantized INT8 layers run on the CPU only.
    model_device = next(model.parameters()).device
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(model_device)
    attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=model_device)

    sampling_kwargs = {"do_sample": True, "temperature": temperature} if do_sample else {}
    with torch.no_grad():
        output = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            # GPT-2 has no pad token of its own; use EOS explicitly.
            pad_token_id=tokenizer.eos_token_id,
            **sampling_kwargs,
        )

    return tokenizer.decode(output[0], skip_special_tokens=True)

# Test prompts
test_prompts = [
    "The quick brown fox",
    "In a world where technology",
    "Climate change is a pressing issue because",
    "The future of artificial intelligence"
]


In [None]:
print("\nInference Comparison:")
# For each prompt, generate with both models and compare output and latency.
for prompt in test_prompts:
    print(f"\nPrompt: {prompt}")

    # Original model
    # time.perf_counter is a monotonic, high-resolution clock meant for
    # measuring intervals; time.time is wall-clock time and can jump.
    # NOTE: the first prompt may also include one-off warm-up cost.
    start_time = time.perf_counter()
    original_output = generate_text(model, prompt)
    original_time = time.perf_counter() - start_time

    print(f"Original model output: {original_output}")
    print(f"Original model inference time: {original_time:.4f} seconds")

    # Quantized model
    start_time = time.perf_counter()
    quantized_output = generate_text(quantized_model, prompt)
    quantized_time = time.perf_counter() - start_time

    print(f"Quantized model output: {quantized_output}")
    print(f"Quantized model inference time: {quantized_time:.4f} seconds")

    print(f"Speedup: {original_time/quantized_time:.2f}x")

NOTE: The perplexity comparison test takes about 30 minutes to complete.

In [None]:
# Perplexity comparison
def calculate_perplexity(model, data_loader):
    """Return the token-level perplexity of *model* over *data_loader*.

    Perplexity is ``exp(mean negative log-likelihood)`` over every real
    (non-padding) target token. Each batch is an ``(input_ids,
    attention_mask)`` pair. The prediction at position ``t`` is scored
    against the token at ``t + 1``, and a target counts only when its
    attention-mask entry is 1.

    The loss is computed from the logits rather than taken from
    ``outputs.loss``. That value is a per-token mean that also covers padding
    positions, so it cannot be turned into a correct token-weighted sum.

    Raises:
        ValueError: if the loader contains no real target tokens.
    """
    model.eval()
    # Run on the device holding the model's weights (the dynamically
    # quantized model is CPU-only).
    model_device = next(model.parameters()).device
    total_nll = 0.0
    total_tokens = 0

    with torch.no_grad():
        for input_ids, attention_mask in data_loader:
            input_ids = input_ids.to(model_device)
            attention_mask = attention_mask.to(model_device)

            logits = model(input_ids, attention_mask=attention_mask).logits
            # Align the prediction at position t with the token at t + 1.
            shift_logits = logits[:, :-1, :].float()
            shift_labels = input_ids[:, 1:]
            target_mask = attention_mask[:, 1:].bool()

            token_nll = torch.nn.functional.cross_entropy(
                shift_logits.reshape(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
                reduction="none",
            )
            total_nll += token_nll[target_mask.reshape(-1)].sum().item()
            total_tokens += int(target_mask.sum().item())

    if total_tokens == 0:
        raise ValueError("data_loader contains no non-padding target tokens")
    avg_nll = total_nll / total_tokens
    return torch.exp(torch.tensor(avg_nll)).item()

print("\nPerplexity Comparison:")
# NOTE: both models are scored on the fine-tuning data (train_loader), so
# this measures the quantization gap rather than generalization.
original_perplexity = calculate_perplexity(model, train_loader)
quantized_perplexity = calculate_perplexity(quantized_model, train_loader)

for label, value in (("Original", original_perplexity), ("Quantized", quantized_perplexity)):
    print(f"{label} model perplexity: {value:.2f}")
relative_change = quantized_perplexity/original_perplexity - 1
print(f"Perplexity increase: {relative_change*100:.2f}%")