### Fine-tuning for Classification

We will learn to fine-tune the LLM on a specific task: classifying text as "spam" or "not spam". This kind of fine-tuning requires less data and compute power than instruction fine-tuning. However, the resulting model is confined to the specific classes on which it has been trained.

We start by preparing the dataset.

In [1]:
import urllib.request
import zipfile
import os
from pathlib import Path

url = "https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip"
zip_path = "sms_spam_collection.zip"
extracted_path = "sms_spam_collection"
data_file_path = Path(extracted_path) / "SMSSpamCollection.tsv"

def download_and_unzip_spam_data(
    url, zip_path, extracted_path, data_file_path):
    if data_file_path.exists():
        print(f"{data_file_path} already exists. Skipping download "
              f"and extraction.")
        return

    with urllib.request.urlopen(url) as response:
        with open(zip_path, "wb") as out_file:
            out_file.write(response.read())

    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(extracted_path)

    original_file_path = Path(extracted_path) / "SMSSpamCollection"
    os.rename(original_file_path, data_file_path)
    print(f"File downloaded and saved as {data_file_path}")

download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path)

File downloaded and saved as sms_spam_collection\SMSSpamCollection.tsv


In [2]:
import pandas as pd

df = pd.read_csv(data_file_path, sep="\t", header=None, names=["Label", "Text"])
df.head()

Unnamed: 0,Label,Text
0,ham,"Go until jurong point, crazy.. Available only ..."
1,ham,Ok lar... Joking wif u oni...
2,spam,Free entry in 2 a wkly comp to win FA Cup fina...
3,ham,U dun say so early hor... U c already then say...
4,ham,"Nah I don't think he goes to usf, he lives aro..."


In [3]:
# Class label distribution
print(df["Label"].value_counts())

Label
ham     4825
spam     747
Name: count, dtype: int64


For simplicity, and because a smaller dataset speeds up fine-tuning of the LLM, we undersample the majority class so that the dataset contains 747 instances of each class.

In [4]:
def create_balanced_dataset(df):
    num_spam = df[df["Label"] == "spam"].shape[0]
    ham_subset = df[df["Label"] == "ham"].sample(
        num_spam, random_state=42
    )
    balanced_df = pd.concat([
        ham_subset, df[df["Label"] == "spam"]
    ])
    return balanced_df

balanced_df = create_balanced_dataset(df)
print(balanced_df["Label"].value_counts())

Label
ham     747
spam    747
Name: count, dtype: int64


In [5]:
# Convert labels to integers
balanced_df["Label"] = balanced_df["Label"].map({"ham": 0, "spam": 1})

In [7]:
# Split the dataset into three parts - 70% train, 10% validation, 20% testing
def random_split(df, train_frac, valid_frac):

    df = df.sample(
        frac=1, random_state=42
    ).reset_index(drop=True)
    train_end = int(len(df) * train_frac)
    valid_end = train_end + int(len(df) * valid_frac)

    train_df = df[:train_end]
    valid_df = df[train_end:valid_end]
    test_df = df[valid_end:]

    return train_df, valid_df, test_df

train_df, valid_df, test_df = random_split(balanced_df, 0.7, 0.1)
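
As a quick sanity check on the split, the boundary arithmetic in `random_split` can be reproduced without pandas. This is only a sketch: it assumes the balanced dataset has 1,494 rows (747 per class, as printed earlier), and the variable names `n_rows`, `n_train`, `n_valid` and `n_test` are introduced here for illustration.

```python
# Assumed size of the balanced dataset: 747 ham + 747 spam
n_rows = 747 * 2

# Same boundary computation as in random_split(df, 0.7, 0.1)
train_end = int(n_rows * 0.7)
valid_end = train_end + int(n_rows * 0.1)

n_train = train_end
n_valid = valid_end - train_end
n_test = n_rows - valid_end  # the test set receives the remainder

print(n_train, n_valid, n_test)
```

Because `int()` rounds down, the test set absorbs the leftover rows, so it ends up slightly larger than an exact 20% share.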

In [9]:
# Save the datasets to re-use later
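# The loading cells below read from the "Datasets" folder, so write there
Path("Datasets").mkdir(exist_ok=True)
train_df.to_parquet("Datasets/train.parquet", index=None)
valid_df.to_parquet("Datasets/valid.parquet", index=None)
test_df.to_parquet("Datasets/test.parquet", index=None)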

#### Creating data loaders

Previously, we utilised a sliding window technique to generate uniformly sized text chunks, which we then grouped into batches for more efficient model training. Each chunk functioned as an individual training instance. However, we are now working with a spam dataset that contains messages of varying lengths. To batch these messages, we have two primary options:
- Truncate all messages to the length of the shortest message in the dataset/batch.
- Pad all messages to the length of the longest message in the dataset/batch.

The first option is computationally cheaper but may cause significant information loss if the shortest message is much shorter than the average message, potentially reducing model performance. So we opt for the second option, which preserves the entire content of all messages. We will use "<|endoftext|>" as the padding token. Instead of appending this string to each text message directly, we append its token ID to the encoded messages.
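
To make the difference between the two options concrete, here is a minimal sketch on toy data. The token IDs in `encoded` are made up for illustration and do not come from the spam dataset; only the padding ID 50256 corresponds to "<|endoftext|>" in the GPT-2 vocabulary.

```python
# Toy token-ID sequences standing in for encoded messages (made-up values)
encoded = [[11, 12, 13, 14, 15], [21, 22], [31, 32, 33]]
pad_token_id = 50256  # ID of "<|endoftext|>" in the GPT-2 vocabulary

# Option 1: truncate every sequence to the length of the shortest one
min_len = min(len(seq) for seq in encoded)
truncated = [seq[:min_len] for seq in encoded]

# Option 2: pad every sequence to the length of the longest one
max_len = max(len(seq) for seq in encoded)
padded = [seq + [pad_token_id] * (max_len - len(seq)) for seq in encoded]

print(truncated)  # every sequence now has min_len tokens
print(padded)     # every sequence now has max_len tokens
```

In this toy case truncation discards three of the five tokens in the first message, while padding keeps every original token at the cost of longer inputs.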

In [10]:
import tiktoken
tokeniser = tiktoken.get_encoding("gpt2")
print(tokeniser.encode("<|endoftext|>", allowed_special={"<|endoftext|>"}))

[50256]


Before we can instantiate the data loaders, we need to implement a PyTorch Dataset, which specifies how the data is loaded and processed. This "SpamDataset" class handles several key tasks: it encodes the text messages, identifies the longest sequence in the training dataset, and pads all other sequences with a <i>padding token</i> to match the length of the longest sequence.

In [11]:
import torch
from torch.utils.data import Dataset

class SpamDataset(Dataset):
    def __init__(self, parquet_file, tokeniser, max_length=None,
                 pad_token_id=50_256):
        self.data = pd.read_parquet(parquet_file)

        # Tokenise every message once, up front
        self.encoded_texts = [
            tokeniser.encode(text) for text in self.data["Text"]
        ]

        if max_length is None:
            # Default to the longest message in this dataset
            self.max_length = self._longest_encoded_length()
        else:
            self.max_length = max_length

            # Truncate messages that exceed the requested length
            self.encoded_texts = [
                encoded_text[:self.max_length]
                for encoded_text in self.encoded_texts
            ]

        # Right-pad every message to max_length with the padding token
        self.encoded_texts = [
            encoded_text + [pad_token_id] *
            (self.max_length - len(encoded_text))
            for encoded_text in self.encoded_texts
        ]

    def __getitem__(self, index):
        encoded = self.encoded_texts[index]
        label = self.data.iloc[index]["Label"]
        return (
            torch.tensor(encoded, dtype=torch.long),
            torch.tensor(label, dtype=torch.long)
        )
    
    def __len__(self):
        return len(self.data)
    
    def _longest_encoded_length(self):
        # Length of the longest tokenised message (0 for an empty dataset)
        return max((len(t) for t in self.encoded_texts), default=0)

In [12]:
train_dataset = SpamDataset(
    parquet_file="Datasets/train.parquet",
    max_length=None,
    tokeniser=tokeniser
)
print(train_dataset.max_length)

109


The longest sequence in the training data is 109 tokens. The model can handle sequences of up to 1,024 tokens, given its context length limit. If the dataset includes longer texts, we can pass max_length=1024 to ensure that the data doesn't exceed the model's supported context length.
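
One way to choose the padding length so that it never exceeds the context limit is sketched below. The helper `resolve_max_length` is hypothetical (it is not part of the notebook or of any library), and the toy sequences use made-up token IDs.

```python
def resolve_max_length(encoded_texts, context_length=1024):
    """Return the longest sequence length, capped at the context length."""
    longest = max((len(tokens) for tokens in encoded_texts), default=0)
    return min(longest, context_length)

# Toy sequences of lengths 3, 109 and 1,500 (made-up token IDs)
toy_encoded = [[1] * 3, [2] * 109, [3] * 1500]

capped = resolve_max_length(toy_encoded)        # longest is 1,500, so the cap applies
uncapped = resolve_max_length(toy_encoded[:2])  # longest is 109, below the cap

print(capped, uncapped)
```

The returned value could then be passed as `max_length` when constructing `SpamDataset`.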

Next, we pad the validation and test sets to match the length of the longest training sequence. Any validation or test sample longer than that is truncated via <i>encoded_text[:self.max_length]</i>. This truncation is optional; you can also set max_length=None for these sets, in which case each set is padded to the length of its own longest sequence.

In [13]:
val_dataset = SpamDataset(
    parquet_file="Datasets/valid.parquet",
    max_length=train_dataset.max_length,
    tokeniser=tokeniser
)

test_dataset = SpamDataset(
    parquet_file="Datasets/test.parquet",
    max_length=train_dataset.max_length,
    tokeniser=tokeniser
)

We can instantiate the data loaders as before. In this case, however, the targets are class labels rather than the next tokens in the text. With a batch size of 8, each batch consists of eight training examples of length 109 together with the class label of each example.

In [14]:
from torch.utils.data import DataLoader

num_workers = 0
batch_size = 8
torch.manual_seed(42)

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    drop_last=True
)
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    drop_last=False
)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    drop_last=False
)

In [15]:
# Ensure data loaders are working
for input_batch, target_batch in train_loader:
    pass
print("Input batch dimensions:", input_batch.shape)
print("Target batch dimensions:", target_batch.shape)

Input batch dimensions: torch.Size([8, 109])
Target batch dimensions: torch.Size([8])


In [None]:
# Get an idea of the dataset size
print(f"{len(train_loader)} training batches")
print(f"{len(val_loader)} validation batches")
print(f"{len(test_loader)} test batches")
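
The cell above has no recorded output. The batch counts can be estimated by hand, as in the sketch below. It assumes split sizes of 1,045, 149 and 300 rows, which is what a 70/10/20 split of 1,494 rows gives when the boundaries are rounded down with `int()`. The names `train_batches`, `val_batches` and `test_batches` are introduced here for illustration.

```python
import math

batch_size = 8
n_train, n_valid, n_test = 1045, 149, 300  # assumed split sizes

# drop_last=True discards the final incomplete training batch
train_batches = n_train // batch_size

# drop_last=False keeps the final, smaller batch
val_batches = math.ceil(n_valid / batch_size)
test_batches = math.ceil(n_test / batch_size)

print(train_batches, val_batches, test_batches)
```

Under these assumptions the loaders would yield 130 training, 19 validation and 38 test batches.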