In [3]:
from load_data import df
from sklearn.model_selection import train_test_split
from mlmodel import SBertModel, device, sbert
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
from torch.optim.lr_scheduler import ReduceLROnPlateau

In [None]:
# Best model so far: SBertModel2 with lr=0.001, with the LR scheduler stepped on val_acc
# SBertModel3 with the same hyperparameters but 1000 units also performs well

In [4]:
# Raw inputs: one text per sample, and one label vector per sample
# (presumably multi-hot 0/1 entries, given the sigmoid/BCE setup below)
X, y = (df[column].to_list() for column in ('text', 'labels'))

# All label vectors are assumed to share one length; take it from the first (14 here)
num_of_labels = len(y[0])

# 80 / 10 / 10 split: hold out 20%, then halve the hold-out into val and test
X_train_text, X_temp_text, y_train, y_temp = train_test_split(
    X, y, test_size=0.2, random_state=42
)
X_val_text, X_test_text, y_val, y_test = train_test_split(
    X_temp_text, y_temp, test_size=0.5, random_state=42
)

In [None]:
# Embed each split (train, val, test, in that order) with the pre-trained
# sentence transformer, on the same device used for the classifier
x_train, x_val, x_test = (
    sbert.encode(split_texts, device=device)
    for split_texts in (X_train_text, X_val_text, X_test_text)
)

# Embedding width doubles as the classifier's input size
input_dimension = x_train.shape[1]

In [None]:
# Convert the encoded data into float32 tensors on the target device.
# Labels are float32 too because BCEWithLogitsLoss requires float targets.
x_train_tensor = torch.tensor(x_train, dtype=torch.float32).to(device)
y_train_tensor = torch.tensor(y_train, dtype=torch.float32).to(device)
x_val_tensor = torch.tensor(x_val, dtype=torch.float32).to(device)
y_val_tensor = torch.tensor(y_val, dtype=torch.float32).to(device)

# Training data is shuffled each epoch.
train_dataset = TensorDataset(x_train_tensor, y_train_tensor)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Validation data is NOT shuffled: shuffling adds nothing to evaluation and
# would make val_loss (an average of per-batch means with an uneven last
# batch) vary from run to run.
val_dataset = TensorDataset(x_val_tensor, y_val_tensor)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

In [15]:
# Classifier on top of the sentence embeddings, one logit per label
model = SBertModel(input_dimension, num_of_labels)
model = model.to(device)

# Multi-label loss (sigmoid + binary cross-entropy per label)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Stepped with validation accuracy, so mode='max'; the LR is multiplied by 0.1
# once val accuracy stalls for more than `patience` epochs
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=5)

In [None]:
# Train with early stopping on validation accuracy.
# "Accuracy" here is element-wise (Hamming) accuracy: every label of every
# sample counts as one prediction (total += labels.numel()).
best_val_acc = 0
num_epochs = 200
early_stopping = 15          # epochs without val_acc improvement before stopping
early_stopping_counter = 0   # initialised up front so the `else` branch can't hit a NameError
best_model_wts = None

for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    running_loss = 0.0
    correct_train = 0
    total_train = 0

    train_loader_tqdm = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}', unit='batch')

    for batch_idx, (inputs, labels) in enumerate(train_loader_tqdm, start=1):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        predicted = (torch.sigmoid(outputs) > 0.5).float()
        total_train += labels.numel()
        correct_train += (predicted == labels).sum().item()
        # Running mean over the batches seen so far (not over the whole epoch)
        train_loader_tqdm.set_postfix(training_loss=running_loss / batch_idx)

    # Calculate epoch training loss and accuracy
    train_loss = running_loss / len(train_loader)
    train_acc = correct_train / total_train

    print(f'Epoch {epoch+1}/{num_epochs}, Training Loss: {train_loss}, Training Accuracy: {train_acc:.4f}')

    # ---- validation pass ----
    model.eval()
    val_loss = 0
    correct = 0
    total = 0
    val_loader_tqdm = tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs}', unit='batch')

    with torch.no_grad():
        for inputs, labels in val_loader_tqdm:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()

            predicted = (torch.sigmoid(outputs) > 0.5).float()
            correct += (predicted == labels).sum().item()
            total += labels.numel()
            val_loader_tqdm.set_postfix(accuracy=correct / total)

    # Calculate validation loss and accuracy, then update the scheduler (mode='max')
    val_loss /= len(val_loader)
    val_acc = correct / total
    scheduler.step(val_acc)

    print(f'Epoch {epoch+1}/{num_epochs}, Val Loss: {val_loss}, Val Accuracy: {val_acc:.4f}')
    # Read the LR from the optimizer: ReduceLROnPlateau lacks get_last_lr()
    # in older PyTorch releases.
    print("Current learning rate: ", [group['lr'] for group in optimizer.param_groups])

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        # state_dict() holds references to the live parameter tensors and
        # dict.copy() is shallow, so clone each tensor to get a real snapshot
        # that later optimizer steps cannot overwrite.
        best_model_wts = {k: v.detach().clone() for k, v in model.state_dict().items()}
        torch.save(best_model_wts, 'best_model.pth')
        early_stopping_counter = 0
    else:
        early_stopping_counter += 1

    if early_stopping_counter >= early_stopping:
        print('Early stopping')
        break

# Restore the best-validation weights whether training stopped early or ran
# all epochs, so the evaluation cells below use the best model.
if best_model_wts is not None:
    model.load_state_dict(best_model_wts)

In [11]:
# Convert the encoded test split to tensors on the model's device.
# Labels are float32, matching y_train_tensor / y_val_tensor, so the same
# BCEWithLogitsLoss criterion can also be applied to the test set.
x_test_tensor = torch.tensor(x_test, dtype=torch.float32).to(device)
y_test_tensor = torch.tensor(y_test, dtype=torch.float32).to(device)

# Fixed order (no shuffling) so collected predictions line up with y_test
test_dataset = TensorDataset(x_test_tensor, y_test_tensor)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

In [22]:
from sklearn.metrics import classification_report, f1_score
import numpy as np

# Put the model in inference mode explicitly (affects dropout/batch-norm if
# the model has any) instead of relying on state left by the training cell.
model.eval()

pred_batches = []
label_batches = []

with torch.no_grad():
    for inputs, labels in tqdm(test_loader, desc='Evaluating', unit='batch'):
        outputs = model(inputs)
        # Independent sigmoid per label, thresholded at 0.5 (multi-label)
        pred_batches.append((torch.sigmoid(outputs) > 0.5).float().cpu())
        label_batches.append(labels.cpu())

# Concatenate per-batch results into 2-D (samples x labels) arrays
all_preds = torch.cat(pred_batches).numpy()
all_labels = torch.cat(label_batches).numpy()

print(classification_report(all_labels, all_preds, target_names=[str(i) for i in range(num_of_labels)], zero_division=0))

f1 = f1_score(all_labels, all_preds, average='micro')
print(f'Test F1 Score (micro average): {f1:.4f}')

Evaluating: 100%|██████████| 313/313 [00:00<00:00, 1021.99batch/s]

Test Accuracy: 0.1388
              precision    recall  f1-score   support

           0       0.77      0.75      0.76      2333
           1       0.96      0.99      0.98      4670
           2       0.84      0.86      0.85      2715
           3       0.90      0.92      0.91      3107
           4       0.81      0.94      0.87      3917
           5       0.79      0.70      0.74       859
           6       0.82      0.89      0.85      3312
           7       0.32      0.03      0.05       572
           8       0.67      0.55      0.60       564
           9       0.65      0.43      0.51       554
          10       0.62      0.41      0.50       738
          11       0.84      0.91      0.87      2112
          12       0.80      0.76      0.78      2251
          13       0.72      0.73      0.72       797

   micro avg       0.84      0.84      0.84     28501
   macro avg       0.75      0.70      0.71     28501
weighted avg       0.82      0.84      0.83     28501
 sam


