In [2]:
import sys
import os
sys.path.insert(0, os.path.abspath(".."))

In [3]:
from dataclasses import dataclass
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.preprocessing import LabelEncoder
from utils import get_loader, EarlyStopper
from typing import Optional
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [9]:
class TransformerModel(nn.Module):
    def __init__(self, input_size, d_model, nhead, num_layers, output_size, dropout=0.3):
        super(TransformerModel, self).__init__()
        self.input_linear = nn.Linear(input_size, d_model)
        self.positional_encoding = nn.Parameter(torch.randn(1, d_model))
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=4 * d_model,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.norm = nn.LayerNorm(d_model)
        self.output_linear = nn.Linear(d_model, output_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = x.unsqueeze(1)
        x = self.input_linear(x) + self.positional_encoding
        x = self.norm(x)
        x = self.dropout(x)
        x = self.transformer_encoder(x)
        x = x.flatten(start_dim=1)
        x = self.output_linear(x)
        return x

In [10]:
def train_model(model, train_loader, valid_loader, optimizer, criterion, num_epochs, device, scheduler, stopper_args: Optional[dict]=None):
    if stopper_args:
        stopper = EarlyStopper(**stopper_args)
    num_batches = len(train_loader)
    num_items = len(train_loader.dataset)

    for epoch in range(num_epochs):
        correct_predictions_train = 0
        total_loss_train = 0.0
        model.train()
        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            y_pred = model(X_batch)
            loss = criterion(y_pred, y_batch)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            
            _, predicted = torch.max(y_pred, 1)
            correct_predictions_train += (predicted == y_batch).sum().item()
            total_loss_train += loss.item()


        train_loss = total_loss_train / num_batches
        train_accuracy = correct_predictions_train / num_items
        valid_loss, valid_accuracy, _ = test(model, valid_loader, criterion, device, verbose=0)
        scheduler.step()
        
        if not (epoch + 1) % 10: 
            print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, Val Loss: {valid_loss:.4f}, Valid Accuracy: {valid_accuracy:.4f}")
        if stopper and stopper.early_stop(valid_loss): 
            print("Early stopping triggered. ")
            break

def test(model, test_loader, criterion, device, verbose):
    model.eval()
    num_batches = len(test_loader)
    num_items = len(test_loader.dataset)
    total_loss = 0.0
    total_correct = 0
    
    all_preds = []
    
    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            y_pred = model(X_batch)
            loss = criterion(y_pred, y_batch)
            total_loss += loss.item()
                
            _, predicted = torch.max(y_pred, 1)
            total_correct += (predicted == y_batch).sum().item()
                
            all_preds.extend(predicted.cpu().numpy())
                
    test_loss = total_loss / num_batches
    test_accuracy = total_correct / num_items
    if verbose: 
        print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}')
    return test_loss, test_accuracy, all_preds

In [13]:
train_feature = "../features/feature_aug_train.npy"
valid_feature = "../features/feature_aug_validation.npy"
test_feature = "../features/feature_aug_test.npy"
train_label = "../features/label_train.csv"
valid_label = "../features/label_validation.csv"
test_label = "../features/label_test.csv"

batch_size = 256
# valid_size = 0.2

train_loader, valid_loader, test_loader, encoder = get_loader(train_feature, train_label, valid_feature, valid_label, test_feature, test_label, batch_size)

input_size = train_loader.dataset[0][0].shape[0]
d_model = 768
nhead = 16
num_layers = 4
output_size = 4
dropout = 0.1

model = TransformerModel(input_size, d_model, nhead, num_layers, output_size, dropout).to(device)

epochs = 500

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.00005, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.0001, steps_per_epoch=len(train_loader), epochs=epochs)

train_model(model, train_loader, valid_loader, optimizer, criterion, num_epochs=500, device=device, scheduler=scheduler, stopper_args={'threshold': 20, 'epsilon': 1e-4})

test_loss, test_accuracy, pred = test(model, test_loader, criterion, device, verbose=1)
pred_labels = encoder.inverse_transform(pred)

Epoch 10/500, Train Loss: 0.5271, Train Accuracy: 0.8051, Val Loss: 0.4801, Valid Accuracy: 0.8270
Epoch 20/500, Train Loss: 0.3657, Train Accuracy: 0.8642, Val Loss: 0.3632, Valid Accuracy: 0.8688
Epoch 30/500, Train Loss: 0.2748, Train Accuracy: 0.8987, Val Loss: 0.2921, Valid Accuracy: 0.8971
Epoch 40/500, Train Loss: 0.2090, Train Accuracy: 0.9234, Val Loss: 0.2457, Valid Accuracy: 0.9114
Epoch 50/500, Train Loss: 0.1485, Train Accuracy: 0.9449, Val Loss: 0.2029, Valid Accuracy: 0.9280
Epoch 60/500, Train Loss: 0.1080, Train Accuracy: 0.9599, Val Loss: 0.1604, Valid Accuracy: 0.9460
Epoch 70/500, Train Loss: 0.0779, Train Accuracy: 0.9708, Val Loss: 0.1369, Valid Accuracy: 0.9537
Epoch 80/500, Train Loss: 0.0587, Train Accuracy: 0.9785, Val Loss: 0.1255, Valid Accuracy: 0.9581
Epoch 90/500, Train Loss: 0.0458, Train Accuracy: 0.9833, Val Loss: 0.1239, Valid Accuracy: 0.9631
Epoch 100/500, Train Loss: 0.0341, Train Accuracy: 0.9875, Val Loss: 0.1213, Valid Accuracy: 0.9667
Epoch 110

In [14]:
pd.DataFrame(pred_labels, columns=['Stance']).to_csv('../output/preds_trans.csv', index=False)