In [1]:
import pickle
import spacy
from typing import List, Tuple
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
from tqdm import tqdm
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, f1_score
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import CosineAnnealingLR
from matplotlib import pyplot as plt
import seaborn as sns

In [None]:

# Question answer dataset
class QWPDataset(Dataset):
    """Dataset of (masked question, answer, question word) triples.

    Each item is encoded as ``"<masked_question> [SEP] <answer>"`` and paired
    with the integer class index of its question word.

    Args:
        qa_pairs: Sequence of ``(masked_question, answer, question_word)`` triples.
        tokenizer: Callable invoked with the text plus ``max_length``,
            ``padding``, ``truncation`` and ``return_tensors="pt"``; its result
            must provide ``"input_ids"`` and ``"attention_mask"`` tensors.
        max_length: Length every encoded sequence is padded/truncated to.
        qw_to_idx: Optional mapping from question word to class index. When
            omitted, the module-level ``ALL_QW`` mapping is looked up at
            item-access time (the original behaviour).
    """

    def __init__(self, qa_pairs, tokenizer, max_length=128, qw_to_idx=None):
        self.qa_pairs = qa_pairs
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.qw_to_idx = qw_to_idx

    def __len__(self):
        return len(self.qa_pairs)

    def __getitem__(self, idx):
        """Return the ``input_ids``, ``attention_mask`` and ``label`` tensors for one example.

        Raises:
            KeyError: If the example's question word is missing from the label mapping.
        """
        masked_question, answer, question_word = self.qa_pairs[idx]

        input_text = f"{masked_question} [SEP] {answer}"  # SEP token separates question and answer

        # input_ids are the token ids; attention_mask marks the non-padding positions
        encoding = self.tokenizer(
            input_text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        # Prefer the mapping handed to the constructor; fall back to the global for old callers.
        label_map = self.qw_to_idx if self.qw_to_idx is not None else ALL_QW
        label = label_map[question_word]

        return {
            # squeeze(0) drops only the batch axis added by return_tensors="pt";
            # a bare squeeze() would also drop the sequence axis when max_length == 1.
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "label": torch.tensor(label, dtype=torch.long)
        }


In [None]:

# Model
class QWPModel(nn.Module):
    """BERT encoder followed by dropout and a linear classification head.

    Args:
        num_classes: Number of output classes (size of the logit vector).
        tokenizer: Tokenizer whose length sets the size of the embedding table.
    """

    def __init__(self, num_classes: int, tokenizer):
        super().__init__()
        encoder = BertModel.from_pretrained("bert-base-uncased")
        # The tokenizer carries extra tokens (<qw>), so the embedding table must grow to match.
        encoder.resize_token_embeddings(len(tokenizer))
        self.bert = encoder
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(encoder.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return class logits computed from the encoder's pooled output."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # pooler_output is the encoder's pooled representation of the [CLS] position
        pooled = self.dropout(encoded.pooler_output)
        return self.classifier(pooled)


In [None]:

# Training function with best-model checkpointing (no early stopping: every epoch always runs)
def train_model(model, train_loader, val_loader, device,  epochs=20, lr=5e-5,
                checkpoint_path="best_qwp_model.pth"):
    """
    Trains model using the AdamW optimizer and cross-entropy loss.
    It evaluates performance on a validation set after each epoch and writes the
    weights of the epoch with the lowest validation loss to ``checkpoint_path``.
    All epochs are always run; there is no early stopping.

    Args:
        model: The PyTorch model to train (e.g., a BERT-based classifier); called
            as ``model(input_ids, attention_mask)`` and expected to return logits.
        train_loader: DataLoader for the training dataset,
            yielding dict batches with keys 'input_ids', 'attention_mask', and 'label'.
        val_loader : DataLoader for the validation dataset, with the same
            structure as train_loader.
        device : The device (e.g., 'cuda' or 'cpu') to perform computations on.
        epochs : Number of training epochs.
        lr : Learning rate passed to AdamW.
        checkpoint_path : File the best state dict is saved to.

    Returns:
        torch.nn.Module: The model on ``device`` with the weights of the LAST
        epoch. The best-validation weights are only on disk at ``checkpoint_path``.

    Raises:
        ValueError: If either loader contains no batches.
    """
    # Fail before any training happens; otherwise an empty loader only surfaces
    # as a ZeroDivisionError after a full epoch has already been spent.
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must each contain at least one batch")

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    model.to(device)

    best_val_loss = float("inf")

    for epoch in range(epochs):
        # Training
        model.train()
        train_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            optimizer.zero_grad()
            logits = model(input_ids, attention_mask)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # Mean of the per-batch mean losses
        avg_train_loss = train_loss / len(train_loader)

        # Validation
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for batch in val_loader:
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["label"].to(device)

                logits = model(input_ids, attention_mask)
                val_loss += criterion(logits, labels).item()

                preds = torch.argmax(logits, dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)

        avg_val_loss = val_loss / len(val_loader)
        val_accuracy = correct / total

        print(f"Epoch {epoch+1}/{epochs} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")

        # keep best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), checkpoint_path)
            print(f"Saved best model with Val Loss: {best_val_loss:.4f}")
    return model


In [None]:
def compute_metrics(pred, labels):
    """
    Score predictions against the ground truth.

    Args:
        pred: Predicted labels
        labels: True labels

    Returns:
        Tuple ``(accuracy, report)``: the accuracy score and the classification
        report in dictionary form.

    """
    overall_accuracy = accuracy_score(y_true=labels, y_pred=pred)
    per_class_report = classification_report(y_true=labels, y_pred=pred, output_dict=True)
    return overall_accuracy, per_class_report

def plot_confusion_matrix(y_true, y_pred, labels, output_path):
    """
    Draw the confusion matrix as an annotated heatmap, save it, and show it.

    Args:
        y_true: True labels
        y_pred: Predicted labels
        labels: Names used as tick labels on both axes
        output_path: Path the figure is saved to

    Returns:
        None

    """
    counts = confusion_matrix(y_true, y_pred)

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(
        counts,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=labels,
        yticklabels=labels,
        ax=ax,
    )
    ax.set_title('Confusion Matrix')
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')

    fig.tight_layout()
    fig.savefig(output_path)
    plt.show()

In [None]:
def test_model(model, test_loader, device):
    """
    Test the model on the test set.

    Args:
        model: The PyTorch model to train (e.g., a BERT-based classifier).
        test_loader: DataLoader for the test dataset,
            batches tuples with keys 'input_ids', 'attention_mask', and 'label'.
        device : The device (e.g., 'cuda' or 'cpu') to perform computations on.

    Returns:
        None
    
    """

    # Load best model
    model.load_state_dict(torch.load("/content/best_qwp_model.pth"))
    criterion = nn.CrossEntropyLoss()

    # Test evaluation
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    pred_array = []
    label_array = []
    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            logits = model(input_ids, attention_mask)

            preds = torch.argmax(logits, dim=1)
            pred_array.append(preds.cpu())
            label_array.append(labels.cpu())
    test_accuracy, test_report= compute_metrics(pred_array, label_array)
    print("\nTest Set Results:")
    print(f"Accuracy: {test_accuracy:.4f}")
    print("\nClassification Report:")
    for emotion, metrics in test_report.items():
        if isinstance(metrics, dict):
            print(f"{emotion}:")
            print(f"  Precision: {metrics['precision']:.4f}")
            print(f"  Recall: {metrics['recall']:.4f}")
            print(f"  F1-score: {metrics['f1-score']:.4f}")
    plot_confusion_matrix(label_array, pred_array, ALL_QW, 'confusion_matrix.png')




In [None]:

# Main
if __name__ == "__main__":
    # setup
    SEED = 42
    # Seed torch so weight init and DataLoader shuffling repeat between runs.
    torch.manual_seed(SEED)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    tokenizer.add_tokens(["<qw>"])

    # file
    # NOTE(security): pickle.load can execute arbitrary code; only load trusted files.
    with open("data/train_processed.pkl", "rb") as f:
        qa_pairs = pickle.load(f)

    # Sort the question words before numbering them: set iteration order depends on
    # string hashing and can differ between interpreter runs, which would silently
    # change the class indices a saved checkpoint was trained with.
    qw_to_idx = {qw: idx for idx, qw in enumerate(sorted({qa[2] for qa in qa_pairs}))}
    # QWPDataset.__getitem__ and test_model read the label mapping from this
    # module-level name when no mapping is passed to them.
    ALL_QW = qw_to_idx
    print(qw_to_idx)


    #Split train val test
    qa_trainval, qa_test = train_test_split(
        qa_pairs, test_size=0.15, random_state=SEED
    )
    qa_train, qa_val = train_test_split(
        qa_trainval, test_size=0.1765, random_state=SEED  # 0.15/(1-0.15) to get 15% of original
    )

    print(f"Train size: {len(qa_train)}, Val size: {len(qa_val)}, Test size: {len(qa_test)}")

    # load data
    train_dataset = QWPDataset(qa_train, tokenizer)
    val_dataset = QWPDataset(qa_val, tokenizer)
    test_dataset = QWPDataset(qa_test, tokenizer)

    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=64)
    test_loader = DataLoader(test_dataset, batch_size=1)  # one example per batch


    # Train
    model = QWPModel(num_classes=len(qw_to_idx), tokenizer=tokenizer)

    model = train_model(model, train_loader, val_loader, device, epochs=10)



In [None]:
    # TEST: evaluate on the held-out test split. Relies on `model`, `test_loader`
    # and `device` created by the main cell above, so that cell must run first.
    test_model(model, test_loader, device)