<a href="https://colab.research.google.com/github/Leotzu/transformer-arxiv-classification/blob/main/finetuned_classifier.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Note:** This notebook assumes that you have already trained a base model by running arxiv_transformer.ipynb

**Step 1)** Set up the environment and give your notebook access to Google Drive (which is where your saved model from before should be).

In [None]:
!pip install torch pandas numpy tqdm

In [None]:
# Mount Google Drive at /content/drive so the notebook can read the JSONL data
# and load/save models and the vocab file from there.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
import json
from tqdm import tqdm
import pickle
import matplotlib.pyplot as plt


**Step 2)** Define your model and functions for data loading and preprocessing

- In Config, make sure to change *project_dir*, *data_path*, *vocab_file_path*, and *pretrained_model_file_path* so that they point to, respectively: where this project lives in your drive, where you put the *finetune_train.jsonl* data, where you saved the vocab file when training the base model, and where the base model you wish to finetune is saved.

- In Config, change *prefix* to differentiate this training run from any others you do (it will be added to the beginning of every model and checkpoint saved during preprocessing and training)

- Also be sure to have the same model hyperparameters in Config as you had for the original model you're finetuning.

In [None]:
class Config:
    """Paths and hyperparameters shared by every cell in this notebook."""

    # project directories
    project_dir = '/content/drive/MyDrive/your_project_path'
    data_path = project_dir + '/data/finetune_train.jsonl'
    models_path = project_dir + '/models'
    vocab_file_path = project_dir + '/vocab/50k_vocab.pkl'
    # models_path has no trailing slash, so the separator must be part of the
    # file name; without it the path points at ".../models50k_model_epoch_30.pth".
    pretrained_model_file_path = models_path + '/50k_model_epoch_30.pth'
    # prepended to the file name of every checkpoint saved by this notebook
    prefix = 'finetune'

    # hyperparameters (must be same as original model)
    d_model = 256
    nhead = 8
    num_encoder_layers = 3
    num_decoder_layers = 3
    dim_feedforward = 1024
    max_seq_length = 256
    dropout_rate = 0.3

    # hyperparameters for finetuning (these you can experiment with)
    num_epochs = 10
    learning_rate = 0.001
    batch_size = 32

**Note:** Be sure to use the exact Vocabulary class from arXiv_transformer.ipynb.

In [None]:
def load_vocab(vocab_path):
    """Load a pickled vocabulary object from disk.

    Args:
        vocab_path: Path of the pickle file to read.

    Returns:
        Whatever object was pickled into the file (the rest of the notebook
        expects it to expose a ``stoi`` mapping).
    """
    # NOTE(security): pickle.load can execute arbitrary code while loading.
    # Only open vocab files that you produced yourself.
    with open(vocab_path, 'rb') as f:  # honour the argument, not a global path
        return pickle.load(f)

# Use EXACT Vocabulary class from original model notebook:
class Vocabulary:
    """Word-level vocabulary holding token->index and index->token mappings.

    Indices 0, 1 and 2 are reserved for "<pad>", "<unk>" and "<eos>".
    """

    def __init__(self):
        specials = ("<pad>", "<unk>", "<eos>")
        self.stoi = {token: index for index, token in enumerate(specials)}
        self.itos = {index: token for index, token in enumerate(specials)}

    def build_vocab(self, texts, min_freq=2):
        """Add every whitespace-separated word seen at least ``min_freq`` times.

        Words receive consecutive indices, in order of first appearance,
        starting from the current size of ``stoi``.
        """
        frequencies = {}
        for document in texts:
            for token in document.split():
                frequencies[token] = frequencies.get(token, 0) + 1

        next_index = len(self.stoi)
        for token, seen in frequencies.items():
            if seen < min_freq:
                continue
            self.stoi[token] = next_index
            self.itos[next_index] = token
            next_index += 1

class ArxivDataset(Dataset):
    """Fixed-length token-id tensors paired with binary labels.

    Args:
        data: Iterable of records; each record's 'text' entry is the abstract
            and its 'label' entry is the class.
        vocab: Object exposing a ``stoi`` mapping that contains the '<pad>',
            '<unk>' and '<eos>' entries.
        max_seq_length: Length every sequence is truncated/padded to.
            Defaults to ``Config.max_seq_length`` when omitted.
    """

    def __init__(self, data, vocab, max_seq_length=None):
        self.vocab = vocab
        # must be set before vectorize() runs below
        self.max_seq_length = (
            Config.max_seq_length if max_seq_length is None else max_seq_length
        )
        # note: in this jsonl, the abstracts are called 'text' and labels are called 'label'
        self.data = [self.vectorize(record['text']) for record in data]
        # Explicit dtype keeps the labels int64 even when `data` is empty
        # (torch.tensor([]) would otherwise be float32).
        self.labels = torch.tensor(
            [self._label_to_int(record['label']) for record in data],
            dtype=torch.long,
        )

    @staticmethod
    def _label_to_int(raw_label):
        """Map a label to 1 (positive) or 0.

        Accepts both real JSON booleans and their string spellings, so a file
        written with `true`/`false` is not silently read as all-negative.
        """
        if isinstance(raw_label, bool):
            return int(raw_label)
        return 1 if str(raw_label).strip().lower() == 'true' else 0

    def vectorize(self, text):
        """Turn ``text`` into a LongTensor of exactly ``max_seq_length`` ids."""
        stoi = self.vocab.stoi
        tokens = [stoi.get(word, stoi['<unk>']) for word in text.split()]
        tokens.append(stoi['<eos>'])
        # Truncate first (this may drop '<eos>' for long texts), then pad the
        # remainder; the multiplier is 0 when the sequence is already full.
        tokens = tokens[:self.max_seq_length]
        tokens += [stoi['<pad>']] * (self.max_seq_length - len(tokens))
        return torch.tensor(tokens, dtype=torch.long)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

    def __len__(self):
        return len(self.data)

def load_data(file_path, split_ratio=0.8, seed=42):
    """Read a JSONL file and split its records into train and test lists.

    Args:
        file_path: Path to a file with one JSON object per line.
        split_ratio: Fraction of the records placed in the training split.
        seed: Seed for the shuffle. get_data() is called again by evaluate(),
            so an unseeded shuffle would hand evaluate() a different split
            that overlaps the examples train() learned from. Pass ``None``
            for a fresh random split on every call.

    Returns:
        ``(train_records, test_records)`` as two lists.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        # skip blank lines (e.g. a trailing newline) instead of crashing in json.loads
        data = [json.loads(line) for line in file if line.strip()]
    np.random.default_rng(seed).shuffle(data)
    split_idx = int(len(data) * split_ratio)
    return data[:split_idx], data[split_idx:]

def get_data():
    """Build the train/test DataLoaders and return them with the vocabulary.

    Records come from Config.data_path and the vocabulary from
    Config.vocab_file_path. Only the training loader shuffles its batches.

    Returns:
        ``(train_loader, test_loader, vocab)``
    """
    train_rows, held_out_rows = load_data(Config.data_path)
    # Load saved vocab from original model training
    vocab = load_vocab(Config.vocab_file_path)

    def make_loader(rows, reshuffle):
        return DataLoader(
            ArxivDataset(rows, vocab),
            batch_size=Config.batch_size,
            shuffle=reshuffle,
        )

    return make_loader(train_rows, True), make_loader(held_out_rows, False), vocab

In [None]:
class TransformerClassifier(nn.Module):
    """Embedding -> Transformer encoder stack -> mean pooling -> 2-way linear head.

    All sizes are read from Config; only the vocabulary size is passed in.
    """

    def __init__(self, vocab_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, Config.d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=Config.d_model,
            nhead=Config.nhead,
            dim_feedforward=Config.dim_feedforward,
            dropout=Config.dropout_rate,
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer, num_layers=Config.num_encoder_layers
        )
        self.fc = nn.Linear(Config.d_model, 2)

    def forward(self, src):
        """Return one row of two class logits per input sequence."""
        scaled = self.embedding(src) * np.sqrt(Config.d_model)
        # Transformer expects [seq_len, batch_size, d_model]
        sequence_first = scaled.permute(1, 0, 2)
        encoded = self.transformer(sequence_first)
        # average over the sequence axis (dim 0 after the permute)
        pooled = encoded.mean(dim=0)
        return self.fc(pooled)

**Step 3)** Train the model

- This function saves your finetuned model to *models_path* (set in Config) after each epoch.

In [None]:
def _load_pretrained_weights(model, checkpoint_path, device):
    """Copy every name- and shape-compatible tensor from a checkpoint into ``model``.

    Tensors whose name or shape does not match (for example an output layer
    with a different number of classes) are skipped and keep their fresh
    initialisation. Assumes the checkpoint is a plain ``state_dict`` as
    written by ``torch.save(model.state_dict(), ...)`` - TODO confirm against
    the base-model notebook.

    Returns:
        Number of tensors copied; 0 when the checkpoint file does not exist.
    """
    import os  # local import: the notebook's import cell does not bring in os

    if not os.path.isfile(checkpoint_path):
        print(f'WARNING: no pretrained checkpoint found at {checkpoint_path}; '
              'training from randomly initialised weights')
        return 0

    checkpoint = torch.load(checkpoint_path, map_location=device)
    own_state = model.state_dict()
    compatible = {
        name: tensor
        for name, tensor in checkpoint.items()
        if name in own_state
        and getattr(tensor, 'shape', None) == own_state[name].shape
    }
    # strict=False because the skipped tensors are intentionally absent
    model.load_state_dict(compatible, strict=False)
    print(f'loaded {len(compatible)}/{len(own_state)} tensors from {checkpoint_path}')
    return len(compatible)


def train():
    """Finetune the classifier and save a checkpoint after every epoch.

    The model starts from the weights in Config.pretrained_model_file_path
    (when that file exists), is trained for Config.num_epochs epochs with
    Adam and cross-entropy loss, and is scored on the test loader after each
    epoch. Checkpoints are written to
    ``{Config.models_path}/{Config.prefix}_model_epoch_{n}.pth``.
    """
    import os  # local import: the notebook's import cell does not bring in os

    # fall back to CPU so the cell does not crash on a runtime without a GPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    train_loader, test_loader, vocab = get_data()
    model = TransformerClassifier(len(vocab.stoi)).to(device)
    # without this the run would start from random weights, not the base model
    _load_pretrained_weights(model, Config.pretrained_model_file_path, device)
    optimizer = optim.Adam(model.parameters(), lr=Config.learning_rate)
    criterion = nn.CrossEntropyLoss()

    # torch.save does not create missing directories
    os.makedirs(Config.models_path, exist_ok=True)

    print('finetuning started...')
    for epoch in range(Config.num_epochs):
        model.train()
        total_loss = 0
        for inputs, labels in tqdm(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_train_loss = total_loss / len(train_loader)
        print(f'Epoch {epoch + 1} Training Loss: {avg_train_loss:.4f}')

        model.eval()
        total_loss = 0
        correct_predictions = 0
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                total_loss += loss.item()
                predictions = outputs.argmax(dim=1)
                correct_predictions += (predictions == labels).sum().item()

        # loss is averaged per batch, accuracy per example
        avg_test_loss = total_loss / len(test_loader)
        accuracy = correct_predictions / len(test_loader.dataset)
        print(f'Epoch {epoch + 1} Test Loss: {avg_test_loss:.4f}, Accuracy: {accuracy * 100:.2f}%')

        # Save the model after every epoch
        torch.save(model.state_dict(), f'{Config.models_path}/{Config.prefix}_model_epoch_{epoch + 1}.pth')

# Run finetuning now; a checkpoint is saved under Config.models_path after every epoch.
train()

**Step 4)** Evaluate your finetuned model, then run inference with predict_ai_relevance() to determine whether a custom text prompt is AI-relevant or not.

In [None]:
def evaluate():
    """Report loss and accuracy of the final-epoch checkpoint on the test loader.

    Loads ``{Config.models_path}/{Config.prefix}_model_epoch_{Config.num_epochs}.pth``
    and scores it on the test loader returned by get_data(), printing the
    mean per-batch cross-entropy loss and the per-example accuracy.
    """
    # fall back to CPU so the cell does not crash on a runtime without a GPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # load the vocab and test data (the training loader is not needed here)
    _, test_loader, vocab = get_data()

    # load the model and evaluate on test data
    model = TransformerClassifier(len(vocab.stoi)).to(device)
    checkpoint_path = f'{Config.models_path}/{Config.prefix}_model_epoch_{Config.num_epochs}.pth'
    # map_location lets a checkpoint saved on GPU load on a CPU-only runtime
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()

    # use cross entropy loss criterion
    criterion = nn.CrossEntropyLoss()

    total_loss = 0
    correct_predictions = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            predictions = outputs.argmax(dim=1)
            correct_predictions += (predictions == labels).sum().item()

    avg_test_loss = total_loss / len(test_loader)
    accuracy = correct_predictions / len(test_loader.dataset)
    print(f'Final Test Loss: {avg_test_loss:.4f}, Accuracy: {accuracy * 100:.2f}%')

# Score the checkpoint saved for the final epoch (Config.num_epochs) on the test loader.
evaluate()

In [None]:
def predict_ai_relevance(prompt):
    """Classify a free-text prompt with the final-epoch finetuned checkpoint.

    Args:
        prompt: Text to classify; it is split on whitespace and encoded the
            same way ArxivDataset.vectorize encodes an abstract.

    Returns:
        A two-line string: whether class 1 ("AI-related") was predicted, and
        the softmax probability of the predicted class.
    """
    # fall back to CPU so the cell does not crash on a runtime without a GPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the vocabulary
    # NOTE(security): pickle.load can execute arbitrary code while loading.
    # Only open vocab files that you produced yourself.
    with open(Config.vocab_file_path, 'rb') as f:
        vocab = pickle.load(f)

    # load the model
    model = TransformerClassifier(len(vocab.stoi)).to(device)
    model_path = f'{Config.models_path}/{Config.prefix}_model_epoch_{Config.num_epochs}.pth'
    # map_location lets a checkpoint saved on GPU load on a CPU-only runtime
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    # process the prompt: ids + <eos>, truncated then padded to max_seq_length
    unk_id = vocab.stoi['<unk>']
    tokens = [vocab.stoi.get(word, unk_id) for word in prompt.split()]
    tokens.append(vocab.stoi['<eos>'])
    tokens = tokens[:Config.max_seq_length]
    tokens += [vocab.stoi['<pad>']] * (Config.max_seq_length - len(tokens))
    # wrap in a list to add the batch dimension the model expects
    input_tensor = torch.tensor([tokens], dtype=torch.long).to(device)

    # make prediction
    with torch.no_grad():
        outputs = model(input_tensor)
        predictions = torch.softmax(outputs, dim=1)
        predicted_class = predictions.argmax(dim=1).item()

    # interpret prediction
    ai_related = "Yes" if predicted_class == 1 else "No"
    confidence = predictions[0, predicted_class].item()
    return f"AI-related: {ai_related}\nPrediction Confidence: {confidence:.4f}"


# Example inference: classify one hand-written prompt and print the verdict
# together with the model's confidence in it.
prompt = "In this study we study the various morphology of mitochondria in axons, using shape descriptors as a" # Change this to any prompt of your choosing
result = predict_ai_relevance(prompt)
print(result)