In [1]:
import sys
import os
sys.path.insert(0, os.path.abspath(".."))

In [2]:
from dataclasses import dataclass
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.preprocessing import LabelEncoder
from utils import get_loader, EarlyStopper
from typing import Optional
import torch.nn.functional as F
from torch.utils.data import Subset
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
class LearnedPositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len):
        super(LearnedPositionalEncoding, self).__init__()
        self.positional_encoding = nn.Embedding(max_len, d_model)

    def forward(self, x):
        seq_len = x.size(1)
        positions = torch.arange(0, seq_len, dtype=torch.long, device=x.device).unsqueeze(0)
        
        return x + self.positional_encoding(positions)


In [4]:
class TransformerEncoder(nn.Module):
    def __init__(self, d_model, encoder_layer, num_layers, layerdrop=0.1):
        super(TransformerEncoder, self).__init__()
        self.layers = nn.ModuleList([encoder_layer for _ in range(num_layers)])
        self.norm = nn.LayerNorm(d_model)
        self.layerdrop = layerdrop

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        for layer in self.layers:
            if not self.training or torch.rand(1).item() > self.layerdrop:
                src = layer(src, src_mask=src_mask, src_key_padding_mask=src_key_padding_mask)
                
        return self.norm(src)

In [5]:
class TransformerModel(nn.Module):
    def __init__(self, input_size, d_model, nhead, num_layers, output_size, dropout=0.3, max_len=5000, layerdrop=0.1):
        super(TransformerModel, self).__init__()
        self.input_linear = nn.Linear(input_size, d_model)
        self.positional_encoding = LearnedPositionalEncoding(d_model=d_model, max_len=max_len)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=4 * d_model,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_encoder = TransformerEncoder(d_model, encoder_layer, num_layers=num_layers, layerdrop=layerdrop)
        self.output_linear = nn.Linear(d_model, output_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = x.unsqueeze(1)
        padding_mask = (x.sum(dim=-1) == 0)
        x = self.input_linear(x)
        x = self.positional_encoding(x)
        x = self.dropout(x)
        x = self.transformer_encoder(x, src_key_padding_mask=padding_mask)
        x = x.mean(dim=1)
        x = self.output_linear(x)
        
        return x

In [6]:
def mixup_data(x, y, alpha=1.0):
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1
    batch_size = x.size(0)
    index = torch.randperm(batch_size).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    
    return mixed_x, y_a, y_b, lam


In [7]:
def train_model(model, train_loader, valid_loader, optimizer, criterion, num_epochs, device, scheduler, stopper_args: Optional[dict]=None, mixup_alpha=0.2):
    if stopper_args:
        stopper = EarlyStopper(**stopper_args)
    num_batches = len(train_loader)
    num_items = len(train_loader.dataset)

    for epoch in range(num_epochs):
        correct_predictions_train = 0
        total_loss_train = 0.0
        model.train()
        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            X_batch, y_a, y_b, lam = mixup_data(X_batch, y_batch, alpha=mixup_alpha)
            
            optimizer.zero_grad()
            y_pred = model(X_batch)
            
            loss = lam * criterion(y_pred, y_a) + (1 - lam) * criterion(y_pred, y_b)
            loss.backward()
            
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            
            _, predicted = torch.max(y_pred, 1)
            correct_predictions_train += (predicted == y_batch).sum().item()
            total_loss_train += loss.item()


        train_loss = total_loss_train / num_batches
        train_accuracy = correct_predictions_train / num_items
        valid_loss, valid_accuracy, miss_indices, _ = test(model, valid_loader, criterion, device, verbose=0)
        
        if miss_indices: 
            neg_loader = negative_loader(valid_loader.dataset, miss_indices, batch_size=32)
            for x_batch, t_batch in neg_loader: 
                x_batch, t_batch = x_batch.to(device), t_batch.to(device)
                
                optimizer.zero_grad()
                y_pred = model(x_batch)
                loss = criterion(y_pred, t_batch)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                
        scheduler.step()
        
        if not (epoch + 1) % 10: 
            print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, Val Loss: {valid_loss:.4f}, Valid Accuracy: {valid_accuracy:.4f}")
        if stopper and stopper.early_stop(valid_loss): 
            print("Early stopping triggered. ")
            break

def test(model, test_loader, criterion, device, verbose):
    model.eval()
    num_batches = len(test_loader)
    num_items = len(test_loader.dataset)
    total_loss = 0.0
    total_correct = 0
    miss_indices = []
    
    all_preds = []
    all_targets = []
    all_indices = []
    
    with torch.no_grad():
        for batch_idx, (X_batch, y_batch) in enumerate(test_loader):
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            y_pred = model(X_batch)
            loss = criterion(y_pred, y_batch)
            total_loss += loss.item()
                
            _, predicted = torch.max(y_pred, 1)
            total_correct += (predicted == y_batch).sum().item()
            
            indices = torch.arange(batch_idx * test_loader.batch_size, batch_idx * test_loader.batch_size + y_batch.size(0))
            misclassified = indices[predicted.cpu() != y_batch.cpu()]
            miss_indices.extend(misclassified.tolist())
            
            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(y_batch.cpu().tolist())
            all_indices.extend(indices.tolist())
                
    test_loss = total_loss / num_batches
    test_accuracy = total_correct / num_items
    if verbose: 
        print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}')
        
    return test_loss, test_accuracy, miss_indices, all_preds

def negative_loader(dataset, miss_indices, batch_size):
    negative_data = Subset(dataset, miss_indices)
    negative_loader = DataLoader(negative_data, batch_size=batch_size, shuffle=True)
    
    return negative_loader


In [8]:
class FocalLoss(nn.Module):
    """
    Focal Loss with Label Smoothing.
    Args:
        alpha: Weighting factor for classes.
        gamma: Focusing parameter for Focal Loss.
        smoothing: Label smoothing factor.
    """
    def __init__(self, alpha=1.0, gamma=2.0, smoothing=0.1, weight=None):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.smoothing = smoothing
        self.weight = weight

    def forward(self, inputs, targets):
        # Apply label smoothing
        num_classes = inputs.size(1)
        smoothed_labels = torch.zeros_like(inputs).scatter_(1, targets.unsqueeze(1), 1)  # One-hot encoding
        smoothed_labels = (1 - self.smoothing) * smoothed_labels + self.smoothing / num_classes

        # Compute cross entropy
        log_probs = F.log_softmax(inputs, dim=1)
        ce_loss = -(smoothed_labels * log_probs).sum(dim=1)  # Smoothed cross-entropy loss

        # Compute Focal Loss
        pt = torch.exp(-ce_loss)  # Probability of the true class
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss

        # Return weighted loss
        if self.weight is not None:
            focal_loss = focal_loss * self.weight[targets]

        return focal_loss.mean()

In [30]:
train_feature = "../features/feature_aug_train.npy"
valid_feature = "../features/feature_aug_validation.npy"
test_feature = "../features/feature_aug_test.npy"
train_label = "../features/label_train.csv"
valid_label = "../features/label_validation.csv"
test_label = "../features/label_test.csv"

batch_size = 256
# valid_size = 0.2

train_loader, valid_loader, test_loader, encoder = get_loader(train_feature, train_label, valid_feature, valid_label, test_feature, test_label, batch_size)

input_size = train_loader.dataset[0][0].shape[0]
d_model = 512
nhead = 8
num_layers = 4
output_size = 4
dropout = 0.2
max_len = 512
layerdrop = 0.1

model = TransformerModel(input_size, d_model, nhead, num_layers, output_size, dropout, max_len, layerdrop).to(device)

epochs = 500

def class_weights(label_path):
    labels = pd.read_csv(label_path)["Stance"]
    _, counts = np.unique(labels, return_counts = True)
    return labels.shape[0] / counts

weights = torch.tensor(class_weights(train_label), dtype = torch.float32).to(device)
# class_weights = torch.tensor([13, 59, 10, 2]).to(device)
criterion = FocalLoss(alpha=1.0, gamma=2.0, smoothing=0.2, weight=weights)
# criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

train_model(model, train_loader, valid_loader, optimizer, criterion, num_epochs=500, device=device, scheduler=scheduler, stopper_args={'threshold': 20, 'epsilon': 1e-4}, mixup_alpha=0.2)

# test_loss, test_accuracy, pred, _ = test(model, test_loader, criterion, device, verbose=1)
# pred_labels = encoder.inverse_transform(pred)

Epoch 10/500, Train Loss: 2.4675, Train Accuracy: 0.4455, Val Loss: 2.0202, Valid Accuracy: 0.6521
Epoch 20/500, Train Loss: 1.9580, Train Accuracy: 0.6039, Val Loss: 1.4953, Valid Accuracy: 0.7031
Epoch 30/500, Train Loss: 1.8695, Train Accuracy: 0.6643, Val Loss: 1.2537, Valid Accuracy: 0.8188
Epoch 40/500, Train Loss: 1.5716, Train Accuracy: 0.6626, Val Loss: 1.0378, Valid Accuracy: 0.8802
Epoch 50/500, Train Loss: 1.4222, Train Accuracy: 0.7051, Val Loss: 0.8903, Valid Accuracy: 0.9112
Epoch 60/500, Train Loss: 1.3285, Train Accuracy: 0.7233, Val Loss: 0.8663, Valid Accuracy: 0.9095
Epoch 70/500, Train Loss: 1.3084, Train Accuracy: 0.7091, Val Loss: 0.8302, Valid Accuracy: 0.9220
Epoch 80/500, Train Loss: 1.2591, Train Accuracy: 0.7073, Val Loss: 0.7756, Valid Accuracy: 0.9394
Epoch 90/500, Train Loss: 1.2207, Train Accuracy: 0.7069, Val Loss: 0.7315, Valid Accuracy: 0.9420
Epoch 100/500, Train Loss: 1.2276, Train Accuracy: 0.7468, Val Loss: 0.6652, Valid Accuracy: 0.9707
Epoch 110

In [31]:
test_loss, test_accuracy, _, pred = test(model, test_loader, criterion, device, verbose=1)
pred_labels = encoder.inverse_transform(pred)
pd.DataFrame(pred_labels, columns=['Stance']).to_csv('../output/preds2_trans.csv', index=False)

Test Loss: 4.8314, Test Accuracy: 0.7870
