In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

from load_data import df
from mlmodel import SBertModel, device, sbert

In [None]:
# Best model so far: SBertModel2 with lr=0.001 and the scheduler stepped on val_acc.
# SBertModel3 (same parameters, but with 1000 units) also performs well.

In [2]:
# Pull the raw texts and their label vectors out of the dataframe as plain lists
X = df['text'].tolist()
y = df['labels'].tolist()

# Length of a single label vector (14 for this dataset)
num_of_labels = len(y[0])

# 80/10/10 split: hold out 20% of the data first, then cut that hold-out
# in half to obtain the validation and test sets. Both splits are seeded.
X_train_text, X_temp_text, y_train, y_temp = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)
X_val_text, X_test_text, y_val, y_test = train_test_split(
    X_temp_text,
    y_temp,
    test_size=0.5,
    random_state=42,
)

In [4]:
# Embed every split with the pre-trained sentence transformer.
# The generator is consumed in order, so the splits are encoded train -> val -> test.
x_train, x_val, x_test = (
    sbert.encode(split_texts, device=device)
    for split_texts in (X_train_text, X_val_text, X_test_text)
)

# Second axis of the encoded training matrix; passed to the model constructor later on
input_dimension = x_train.shape[1]

  attn_output = torch.nn.functional.scaled_dot_product_attention(


In [5]:
# Convert the encoded data into tensors placed on the training device.
# The label lists are stacked into one array before being handed to torch:
# building a tensor directly from a Python list of per-row arrays is very slow
# (NOTE(review): assumes every entry of y_train / y_val is an equal-length
# label vector, as implied by num_of_labels = len(y[0]) — confirm).
x_train_tensor = torch.tensor(x_train, dtype=torch.float32).to(device)
y_train_tensor = torch.tensor(np.asarray(y_train), dtype=torch.float32).to(device)
x_val_tensor = torch.tensor(x_val, dtype=torch.float32).to(device)
y_val_tensor = torch.tensor(np.asarray(y_val), dtype=torch.float32).to(device)

# Create DataLoader for train and dev sets
# Training batches are reshuffled every epoch.
train_dataset = TensorDataset(x_train_tensor, y_train_tensor)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Validation is evaluation only, so the order is kept fixed. The training loop
# averages per-batch losses, and a shuffled loader would change which samples
# fall into the (smaller) final batch, adding noise to the reported val loss.
val_dataset = TensorDataset(x_val_tensor, y_val_tensor)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

  y_train_tensor = torch.tensor(y_train, dtype=torch.float32).to(device)


In [8]:
# Build the classifier and move it onto the target device
model = SBertModel(input_dimension, num_of_labels)
model = model.to(device)

# Binary cross-entropy applied directly to raw logits
criterion = nn.BCEWithLogitsLoss()

# Adam over all model parameters
optimizer = optim.Adam(model.parameters(), lr=0.001)

# mode='max' because the scheduler is stepped with validation accuracy:
# after 5 epochs without improvement the learning rate is multiplied by 0.1
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='max',
    factor=0.1,
    patience=5,
)

In [15]:
# Start the training
# Trains for up to `num_epochs` epochs, tracking element-wise (per-label)
# accuracy on the train and validation loaders. The scheduler is stepped on
# validation accuracy, the best-scoring weights are checkpointed to
# 'best_model_latest.pth', and training stops once validation accuracy has not
# improved for `early_stopping` consecutive epochs.
best_val_acc = 0
num_epochs = 200
early_stopping = 15

# Early-stopping state is initialised up front so the loop never reads an
# undefined name when the very first epoch fails to beat best_val_acc.
early_stopping_counter = 0
best_model_wts = None

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct_train = 0
    total_train = 0

    train_loader_tqdm = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}', unit='batch')

    for batch_idx, (inputs, labels) in enumerate(train_loader_tqdm, start=1):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        # Threshold the sigmoid probabilities at 0.5 to get per-label predictions
        predicted = (torch.sigmoid(outputs) > 0.5).float()
        total_train += labels.numel()
        correct_train += (predicted == labels).sum().item()
        # Show the mean over the batches seen so far; dividing by the total
        # number of batches would make the bar ramp up from ~0 every epoch.
        train_loader_tqdm.set_postfix(training_loss=running_loss / batch_idx)

    # Calculate loss and accuracy
    train_loss = running_loss / len(train_loader)
    train_acc = correct_train / total_train

    val_loss = 0
    correct = 0
    total = 0

    print(f'Epoch {epoch+1}/{num_epochs}, Training Loss: {train_loss}, Training Accuracy: {train_acc:.4f}')

    model.eval()
    val_loader_tqdm = tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs}', unit='batch')

    with torch.no_grad():
        for inputs, labels in val_loader_tqdm:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()

            predicted = (torch.sigmoid(outputs) > 0.5).float()
            correct += (predicted == labels).sum().item()
            total += labels.numel()
            val_loader_tqdm.set_postfix(accuracy=correct / total)

    # Calculate validation loss and validation accuracy and update the scheduler
    val_loss /= len(val_loader)
    val_acc = correct / total
    scheduler.step(val_acc)

    print(f'Epoch {epoch+1}/{num_epochs}, Val Loss: {val_loss}, Val Accuracy: {val_acc:.4f}')
    print("Current learning rate: ", scheduler.get_last_lr())

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        # state_dict() hands back references to the live parameter tensors and
        # dict.copy() is shallow, so a plain copy would keep changing as
        # training continues. Clone every tensor to freeze a real snapshot.
        best_model_wts = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }
        torch.save(best_model_wts, 'best_model_latest.pth')
        early_stopping_counter = 0
    else:
        early_stopping_counter += 1

    if early_stopping_counter >= early_stopping:
        print('Early stopping')
        break

# Leave the model holding the best weights seen, whether training stopped
# early or ran through every epoch. Nothing to restore if no epoch improved.
if best_model_wts is not None:
    model.load_state_dict(best_model_wts)

Epoch 1/200: 100%|██████████| 625/625 [00:01<00:00, 398.09batch/s, training_loss=0.55] 


Epoch 1/200, Training Loss: 0.550006534576416, Training Accuracy: 0.8690


Epoch 1/200: 100%|██████████| 157/157 [00:00<00:00, 747.62batch/s, accuracy=0.864]


Epoch 1/200, Val Loss: 0.5531505463988917, Val Accuracy: 0.8645
Current learning rate:  [0.001]


Epoch 2/200: 100%|██████████| 625/625 [00:01<00:00, 484.73batch/s, training_loss=0.55] 


Epoch 2/200, Training Loss: 0.5497682942390442, Training Accuracy: 0.8694


Epoch 2/200: 100%|██████████| 157/157 [00:00<00:00, 659.66batch/s, accuracy=0.865]


Epoch 2/200, Val Loss: 0.553176891272235, Val Accuracy: 0.8647
Current learning rate:  [0.001]


Epoch 3/200: 100%|██████████| 625/625 [00:01<00:00, 450.06batch/s, training_loss=0.55] 


Epoch 3/200, Training Loss: 0.5495505290985108, Training Accuracy: 0.8695


Epoch 3/200: 100%|██████████| 157/157 [00:00<00:00, 673.82batch/s, accuracy=0.865]


Epoch 3/200, Val Loss: 0.5527119520743182, Val Accuracy: 0.8652
Current learning rate:  [0.001]


Epoch 4/200: 100%|██████████| 625/625 [00:01<00:00, 449.42batch/s, training_loss=0.549]


Epoch 4/200, Training Loss: 0.5493534889221191, Training Accuracy: 0.8698


Epoch 4/200: 100%|██████████| 157/157 [00:00<00:00, 673.83batch/s, accuracy=0.866]


Epoch 4/200, Val Loss: 0.5528647736379295, Val Accuracy: 0.8655
Current learning rate:  [0.001]


Epoch 5/200: 100%|██████████| 625/625 [00:01<00:00, 441.70batch/s, training_loss=0.549]


Epoch 5/200, Training Loss: 0.5491789751052857, Training Accuracy: 0.8701


Epoch 5/200: 100%|██████████| 157/157 [00:00<00:00, 603.81batch/s, accuracy=0.866]


Epoch 5/200, Val Loss: 0.5526774164977347, Val Accuracy: 0.8655
Current learning rate:  [0.001]


Epoch 6/200: 100%|██████████| 625/625 [00:01<00:00, 459.22batch/s, training_loss=0.549]


Epoch 6/200, Training Loss: 0.5490019469261169, Training Accuracy: 0.8703


Epoch 6/200: 100%|██████████| 157/157 [00:00<00:00, 691.63batch/s, accuracy=0.866]


Epoch 6/200, Val Loss: 0.5523427280650777, Val Accuracy: 0.8656
Current learning rate:  [0.001]


Epoch 7/200: 100%|██████████| 625/625 [00:01<00:00, 453.53batch/s, training_loss=0.549]


Epoch 7/200, Training Loss: 0.5488450147628784, Training Accuracy: 0.8707


Epoch 7/200: 100%|██████████| 157/157 [00:00<00:00, 665.25batch/s, accuracy=0.866]


Epoch 7/200, Val Loss: 0.5524684661512922, Val Accuracy: 0.8657
Current learning rate:  [0.001]


Epoch 8/200: 100%|██████████| 625/625 [00:01<00:00, 443.79batch/s, training_loss=0.549]


Epoch 8/200, Training Loss: 0.5486966629981994, Training Accuracy: 0.8704


Epoch 8/200: 100%|██████████| 157/157 [00:00<00:00, 615.68batch/s, accuracy=0.866]


Epoch 8/200, Val Loss: 0.5530143702865407, Val Accuracy: 0.8663
Current learning rate:  [0.001]


Epoch 9/200: 100%|██████████| 625/625 [00:01<00:00, 447.33batch/s, training_loss=0.549]


Epoch 9/200, Training Loss: 0.5485496481895447, Training Accuracy: 0.8708


Epoch 9/200: 100%|██████████| 157/157 [00:00<00:00, 682.61batch/s, accuracy=0.866]


Epoch 9/200, Val Loss: 0.5524069738995497, Val Accuracy: 0.8662
Current learning rate:  [0.001]


Epoch 10/200: 100%|██████████| 625/625 [00:01<00:00, 444.84batch/s, training_loss=0.548]


Epoch 10/200, Training Loss: 0.5484203019142151, Training Accuracy: 0.8709


Epoch 10/200: 100%|██████████| 157/157 [00:00<00:00, 685.60batch/s, accuracy=0.867]


Epoch 10/200, Val Loss: 0.55282061115192, Val Accuracy: 0.8666
Current learning rate:  [0.001]


Epoch 11/200: 100%|██████████| 625/625 [00:01<00:00, 453.77batch/s, training_loss=0.548]


Epoch 11/200, Training Loss: 0.5482916872024536, Training Accuracy: 0.8711


Epoch 11/200: 100%|██████████| 157/157 [00:00<00:00, 682.61batch/s, accuracy=0.866]


Epoch 11/200, Val Loss: 0.5523122757863087, Val Accuracy: 0.8664
Current learning rate:  [0.001]


Epoch 12/200: 100%|██████████| 625/625 [00:01<00:00, 431.16batch/s, training_loss=0.548]


Epoch 12/200, Training Loss: 0.5481707600593567, Training Accuracy: 0.8711


Epoch 12/200: 100%|██████████| 157/157 [00:00<00:00, 656.91batch/s, accuracy=0.867]


Epoch 12/200, Val Loss: 0.5519221309263995, Val Accuracy: 0.8666
Current learning rate:  [0.001]


Epoch 13/200: 100%|██████████| 625/625 [00:01<00:00, 446.16batch/s, training_loss=0.548]


Epoch 13/200, Training Loss: 0.5480640691757203, Training Accuracy: 0.8713


Epoch 13/200: 100%|██████████| 157/157 [00:00<00:00, 633.06batch/s, accuracy=0.867]


Epoch 13/200, Val Loss: 0.5521229664990857, Val Accuracy: 0.8666
Current learning rate:  [0.001]


Epoch 14/200: 100%|██████████| 625/625 [00:01<00:00, 450.61batch/s, training_loss=0.548]


Epoch 14/200, Training Loss: 0.5479498562812806, Training Accuracy: 0.8715


Epoch 14/200: 100%|██████████| 157/157 [00:00<00:00, 659.05batch/s, accuracy=0.866]


Epoch 14/200, Val Loss: 0.5518660756053438, Val Accuracy: 0.8664
Current learning rate:  [0.001]


Epoch 15/200: 100%|██████████| 625/625 [00:01<00:00, 473.92batch/s, training_loss=0.548]


Epoch 15/200, Training Loss: 0.5478456285476685, Training Accuracy: 0.8715


Epoch 15/200: 100%|██████████| 157/157 [00:00<00:00, 646.09batch/s, accuracy=0.867]


Epoch 15/200, Val Loss: 0.551986876946346, Val Accuracy: 0.8667
Current learning rate:  [0.001]


Epoch 16/200: 100%|██████████| 625/625 [00:01<00:00, 437.47batch/s, training_loss=0.548]


Epoch 16/200, Training Loss: 0.5477482711791992, Training Accuracy: 0.8716


Epoch 16/200: 100%|██████████| 157/157 [00:00<00:00, 688.60batch/s, accuracy=0.867]


Epoch 16/200, Val Loss: 0.552221625853496, Val Accuracy: 0.8667
Current learning rate:  [0.001]


Epoch 17/200: 100%|██████████| 625/625 [00:01<00:00, 460.46batch/s, training_loss=0.548]


Epoch 17/200, Training Loss: 0.5476601887702942, Training Accuracy: 0.8716


Epoch 17/200: 100%|██████████| 157/157 [00:00<00:00, 682.61batch/s, accuracy=0.867]


Epoch 17/200, Val Loss: 0.5517365211134504, Val Accuracy: 0.8668
Current learning rate:  [0.001]


Epoch 18/200: 100%|██████████| 625/625 [00:01<00:00, 451.59batch/s, training_loss=0.548]


Epoch 18/200, Training Loss: 0.5475606620788575, Training Accuracy: 0.8716


Epoch 18/200: 100%|██████████| 157/157 [00:00<00:00, 665.27batch/s, accuracy=0.867]


Epoch 18/200, Val Loss: 0.5520315944768821, Val Accuracy: 0.8668
Current learning rate:  [0.001]


Epoch 19/200: 100%|██████████| 625/625 [00:01<00:00, 457.21batch/s, training_loss=0.547]


Epoch 19/200, Training Loss: 0.5474781869888306, Training Accuracy: 0.8719


Epoch 19/200: 100%|██████████| 157/157 [00:00<00:00, 682.60batch/s, accuracy=0.867]


Epoch 19/200, Val Loss: 0.5518826307005184, Val Accuracy: 0.8668
Current learning rate:  [0.001]


Epoch 20/200: 100%|██████████| 625/625 [00:01<00:00, 447.29batch/s, training_loss=0.547]


Epoch 20/200, Training Loss: 0.547394228553772, Training Accuracy: 0.8720


Epoch 20/200: 100%|██████████| 157/157 [00:00<00:00, 694.69batch/s, accuracy=0.867]


Epoch 20/200, Val Loss: 0.5522751589869238, Val Accuracy: 0.8667
Current learning rate:  [0.001]


Epoch 21/200: 100%|██████████| 625/625 [00:01<00:00, 454.45batch/s, training_loss=0.547]


Epoch 21/200, Training Loss: 0.5473162422180176, Training Accuracy: 0.8720


Epoch 21/200: 100%|██████████| 157/157 [00:00<00:00, 691.63batch/s, accuracy=0.867]


Epoch 21/200, Val Loss: 0.5516012665952087, Val Accuracy: 0.8670
Current learning rate:  [0.001]


Epoch 22/200: 100%|██████████| 625/625 [00:01<00:00, 462.70batch/s, training_loss=0.547]


Epoch 22/200, Training Loss: 0.5472303025245666, Training Accuracy: 0.8722


Epoch 22/200: 100%|██████████| 157/157 [00:00<00:00, 639.94batch/s, accuracy=0.867]


Epoch 22/200, Val Loss: 0.55241490055801, Val Accuracy: 0.8669
Current learning rate:  [0.001]


Epoch 23/200: 100%|██████████| 625/625 [00:01<00:00, 447.71batch/s, training_loss=0.547]


Epoch 23/200, Training Loss: 0.547164420413971, Training Accuracy: 0.8721


Epoch 23/200: 100%|██████████| 157/157 [00:00<00:00, 592.46batch/s, accuracy=0.867]


Epoch 23/200, Val Loss: 0.5517693123999675, Val Accuracy: 0.8671
Current learning rate:  [0.001]


Epoch 24/200: 100%|██████████| 625/625 [00:01<00:00, 470.18batch/s, training_loss=0.547] 


Epoch 24/200, Training Loss: 0.5470810894966125, Training Accuracy: 0.8722


Epoch 24/200: 100%|██████████| 157/157 [00:00<00:00, 805.14batch/s, accuracy=0.867]


Epoch 24/200, Val Loss: 0.5519174060244469, Val Accuracy: 0.8675
Current learning rate:  [0.001]


Epoch 25/200: 100%|██████████| 625/625 [00:01<00:00, 475.52batch/s, training_loss=0.547]


Epoch 25/200, Training Loss: 0.5470233777046204, Training Accuracy: 0.8723


Epoch 25/200: 100%|██████████| 157/157 [00:00<00:00, 788.94batch/s, accuracy=0.867]


Epoch 25/200, Val Loss: 0.5516564986508363, Val Accuracy: 0.8674
Current learning rate:  [0.001]


Epoch 26/200: 100%|██████████| 625/625 [00:01<00:00, 512.72batch/s, training_loss=0.547]


Epoch 26/200, Training Loss: 0.5469589196205139, Training Accuracy: 0.8724


Epoch 26/200: 100%|██████████| 157/157 [00:00<00:00, 694.90batch/s, accuracy=0.868]


Epoch 26/200, Val Loss: 0.5516441931390459, Val Accuracy: 0.8676
Current learning rate:  [0.001]


Epoch 27/200: 100%|██████████| 625/625 [00:01<00:00, 474.20batch/s, training_loss=0.547]


Epoch 27/200, Training Loss: 0.546894122505188, Training Accuracy: 0.8724


Epoch 27/200: 100%|██████████| 157/157 [00:00<00:00, 700.89batch/s, accuracy=0.868]


Epoch 27/200, Val Loss: 0.5514391170945138, Val Accuracy: 0.8678
Current learning rate:  [0.001]


Epoch 28/200: 100%|██████████| 625/625 [00:01<00:00, 496.12batch/s, training_loss=0.547]


Epoch 28/200, Training Loss: 0.5468292862892151, Training Accuracy: 0.8726


Epoch 28/200: 100%|██████████| 157/157 [00:00<00:00, 740.57batch/s, accuracy=0.867]


Epoch 28/200, Val Loss: 0.5515968727458055, Val Accuracy: 0.8675
Current learning rate:  [0.001]


Epoch 29/200: 100%|██████████| 625/625 [00:01<00:00, 518.24batch/s, training_loss=0.547]


Epoch 29/200, Training Loss: 0.5467661448478699, Training Accuracy: 0.8725


Epoch 29/200: 100%|██████████| 157/157 [00:00<00:00, 773.41batch/s, accuracy=0.868]


Epoch 29/200, Val Loss: 0.5514579350781289, Val Accuracy: 0.8675
Current learning rate:  [0.001]


Epoch 30/200: 100%|██████████| 625/625 [00:01<00:00, 532.37batch/s, training_loss=0.547]


Epoch 30/200, Training Loss: 0.5467084618568421, Training Accuracy: 0.8726


Epoch 30/200: 100%|██████████| 157/157 [00:00<00:00, 720.18batch/s, accuracy=0.868]


Epoch 30/200, Val Loss: 0.551491889414514, Val Accuracy: 0.8680
Current learning rate:  [0.001]


Epoch 31/200: 100%|██████████| 625/625 [00:01<00:00, 479.43batch/s, training_loss=0.547]


Epoch 31/200, Training Loss: 0.5466545295715332, Training Accuracy: 0.8727


Epoch 31/200: 100%|██████████| 157/157 [00:00<00:00, 665.26batch/s, accuracy=0.868]


Epoch 31/200, Val Loss: 0.5513984488833482, Val Accuracy: 0.8677
Current learning rate:  [0.001]


Epoch 32/200: 100%|██████████| 625/625 [00:01<00:00, 442.34batch/s, training_loss=0.547]


Epoch 32/200, Training Loss: 0.5465954046249389, Training Accuracy: 0.8728


Epoch 32/200: 100%|██████████| 157/157 [00:00<00:00, 643.44batch/s, accuracy=0.868]


Epoch 32/200, Val Loss: 0.5516403127627768, Val Accuracy: 0.8676
Current learning rate:  [0.001]


Epoch 33/200: 100%|██████████| 625/625 [00:01<00:00, 513.56batch/s, training_loss=0.547]


Epoch 33/200, Training Loss: 0.546543493270874, Training Accuracy: 0.8728


Epoch 33/200: 100%|██████████| 157/157 [00:00<00:00, 633.82batch/s, accuracy=0.868]


Epoch 33/200, Val Loss: 0.551550802341692, Val Accuracy: 0.8676
Current learning rate:  [0.001]


Epoch 34/200: 100%|██████████| 625/625 [00:01<00:00, 473.13batch/s, training_loss=0.546]


Epoch 34/200, Training Loss: 0.5464864980697632, Training Accuracy: 0.8727


Epoch 34/200: 100%|██████████| 157/157 [00:00<00:00, 659.66batch/s, accuracy=0.868]


Epoch 34/200, Val Loss: 0.5511754540501127, Val Accuracy: 0.8678
Current learning rate:  [0.001]


Epoch 35/200: 100%|██████████| 625/625 [00:01<00:00, 448.89batch/s, training_loss=0.546]


Epoch 35/200, Training Loss: 0.5464460982322693, Training Accuracy: 0.8729


Epoch 35/200: 100%|██████████| 157/157 [00:00<00:00, 694.69batch/s, accuracy=0.868]


Epoch 35/200, Val Loss: 0.5517355594665382, Val Accuracy: 0.8678
Current learning rate:  [0.001]


Epoch 36/200: 100%|██████████| 625/625 [00:01<00:00, 439.36batch/s, training_loss=0.546]


Epoch 36/200, Training Loss: 0.5463971545219422, Training Accuracy: 0.8730


Epoch 36/200: 100%|██████████| 157/157 [00:00<00:00, 673.82batch/s, accuracy=0.868]


Epoch 36/200, Val Loss: 0.5514509651311643, Val Accuracy: 0.8677
Current learning rate:  [0.0001]


Epoch 37/200: 100%|██████████| 625/625 [00:01<00:00, 473.13batch/s, training_loss=0.546]


Epoch 37/200, Training Loss: 0.5461846581459046, Training Accuracy: 0.8734


Epoch 37/200: 100%|██████████| 157/157 [00:00<00:00, 692.65batch/s, accuracy=0.868]


Epoch 37/200, Val Loss: 0.551461317546808, Val Accuracy: 0.8676
Current learning rate:  [0.0001]


Epoch 38/200: 100%|██████████| 625/625 [00:01<00:00, 437.68batch/s, training_loss=0.546]


Epoch 38/200, Training Loss: 0.5461733509063721, Training Accuracy: 0.8733


Epoch 38/200: 100%|██████████| 157/157 [00:00<00:00, 682.61batch/s, accuracy=0.868]


Epoch 38/200, Val Loss: 0.5513446737246909, Val Accuracy: 0.8676
Current learning rate:  [0.0001]


Epoch 39/200: 100%|██████████| 625/625 [00:01<00:00, 464.32batch/s, training_loss=0.546]


Epoch 39/200, Training Loss: 0.546167018032074, Training Accuracy: 0.8731


Epoch 39/200: 100%|██████████| 157/157 [00:00<00:00, 676.73batch/s, accuracy=0.868]


Epoch 39/200, Val Loss: 0.5515421370791781, Val Accuracy: 0.8676
Current learning rate:  [0.0001]


Epoch 40/200: 100%|██████████| 625/625 [00:01<00:00, 459.68batch/s, training_loss=0.546]


Epoch 40/200, Training Loss: 0.5461609352111816, Training Accuracy: 0.8731


Epoch 40/200: 100%|██████████| 157/157 [00:00<00:00, 707.21batch/s, accuracy=0.868]


Epoch 40/200, Val Loss: 0.551142135243507, Val Accuracy: 0.8676
Current learning rate:  [0.0001]


Epoch 41/200: 100%|██████████| 625/625 [00:01<00:00, 465.38batch/s, training_loss=0.546]


Epoch 41/200, Training Loss: 0.546156844329834, Training Accuracy: 0.8731


Epoch 41/200: 100%|██████████| 157/157 [00:00<00:00, 626.31batch/s, accuracy=0.868]


Epoch 41/200, Val Loss: 0.5512921692459447, Val Accuracy: 0.8676
Current learning rate:  [0.0001]


Epoch 42/200: 100%|██████████| 625/625 [00:01<00:00, 478.56batch/s, training_loss=0.546]


Epoch 42/200, Training Loss: 0.5461500387191772, Training Accuracy: 0.8731


Epoch 42/200: 100%|██████████| 157/157 [00:00<00:00, 668.08batch/s, accuracy=0.868]


Epoch 42/200, Val Loss: 0.5517851045936536, Val Accuracy: 0.8675
Current learning rate:  [1e-05]


Epoch 43/200: 100%|██████████| 625/625 [00:01<00:00, 459.78batch/s, training_loss=0.546]


Epoch 43/200, Training Loss: 0.5461254837989807, Training Accuracy: 0.8732


Epoch 43/200: 100%|██████████| 157/157 [00:00<00:00, 594.70batch/s, accuracy=0.868]


Epoch 43/200, Val Loss: 0.5514047791244118, Val Accuracy: 0.8676
Current learning rate:  [1e-05]


Epoch 44/200: 100%|██████████| 625/625 [00:01<00:00, 468.87batch/s, training_loss=0.546]


Epoch 44/200, Training Loss: 0.5461260559082032, Training Accuracy: 0.8732


Epoch 44/200: 100%|██████████| 157/157 [00:00<00:00, 700.89batch/s, accuracy=0.868]


Epoch 44/200, Val Loss: 0.5513327611479789, Val Accuracy: 0.8675
Current learning rate:  [1e-05]


Epoch 45/200: 100%|██████████| 625/625 [00:01<00:00, 442.08batch/s, training_loss=0.546]


Epoch 45/200, Training Loss: 0.5461251915931702, Training Accuracy: 0.8732


Epoch 45/200: 100%|██████████| 157/157 [00:00<00:00, 677.89batch/s, accuracy=0.868]

Epoch 45/200, Val Loss: 0.5514262796966893, Val Accuracy: 0.8675
Current learning rate:  [1e-05]
Early stopping





In [6]:
# Convert the test data to tensors
# As with the train/val labels, the label list is stacked into a single array
# before conversion to avoid the slow list-of-arrays path
# (NOTE(review): assumes every entry of y_test is an equal-length label vector — confirm).
# Labels are stored as integers here; they are only compared against
# thresholded predictions by the sklearn metrics, never fed to the loss.
x_test_tensor = torch.tensor(x_test, dtype=torch.float32).to(device)
y_test_tensor = torch.tensor(np.asarray(y_test), dtype=torch.long).to(device)

# Create DataLoader for test dataset
# Order is kept fixed so predictions line up with the original test rows.
test_dataset = TensorDataset(x_test_tensor, y_test_tensor)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

In [16]:
# Evaluate the best checkpoint on the held-out test set.
# weights_only=True restricts unpickling to tensors/plain containers (the file
# holds nothing but a state dict) and silences torch's warning about the
# unsafe default; map_location puts the weights on the active device even if
# the checkpoint was written from a different one.
model.load_state_dict(
    torch.load('best_model_latest.pth', map_location=device, weights_only=True)
)
model.eval()

from sklearn.metrics import classification_report, f1_score
import numpy as np

all_preds = []
all_labels = []

with torch.no_grad():
    for inputs, labels in tqdm(test_loader, desc='Evaluating', unit='batch'):
        outputs = model(inputs)
        # Same 0.5 decision threshold on the sigmoid outputs as used in training
        predicted = torch.sigmoid(outputs) > 0.5
        predicted = predicted.float()
        # extend() adds one row per sample, so the lists grow to n_samples rows
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# Stack the per-sample rows into 2-D indicator matrices for sklearn
all_preds = np.array(all_preds)
all_labels = np.array(all_labels)

# Per-label precision / recall / F1; labels are named by their column index
print(classification_report(all_labels, all_preds, target_names=[str(i) for i in range(num_of_labels)], zero_division=0))

f1 = f1_score(all_labels, all_preds, average='micro')
print(f'Test F1 Score (micro average): {f1:.4f}')

  model.load_state_dict(torch.load('best_model_latest.pth'))
Evaluating: 100%|██████████| 313/313 [00:00<00:00, 969.04batch/s] 

              precision    recall  f1-score   support

           0       0.77      0.75      0.76      2333
           1       0.96      0.99      0.98      4670
           2       0.85      0.86      0.85      2715
           3       0.90      0.92      0.91      3107
           4       0.81      0.94      0.87      3917
           5       0.79      0.70      0.74       859
           6       0.82      0.89      0.85      3312
           7       0.52      0.15      0.24       572
           8       0.67      0.55      0.60       564
           9       0.65      0.43      0.52       554
          10       0.60      0.41      0.49       738
          11       0.83      0.91      0.87      2112
          12       0.80      0.76      0.78      2251
          13       0.72      0.72      0.72       797

   micro avg       0.84      0.84      0.84     28501
   macro avg       0.76      0.71      0.73     28501
weighted avg       0.83      0.84      0.83     28501
 samples avg       0.84   


