In [3]:
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
import os

import torch
from torch.utils.data import Dataset
import torch.optim as optim
import torch.nn as nn
from torchvision import transforms

from sklearn.metrics import classification_report, fbeta_score, matthews_corrcoef
from tqdm import tqdm
import torch.nn.functional as F

In [4]:
# Define hyperparameters
BATCH_SIZE = 86
LEARNING_RATE = 0.0001

In [5]:
# Set the device
device = "cpu"
if torch.cuda.is_available():
    device =  "cuda"
elif torch.backends.mps.is_available():
    device = "mps" # Use M1 Mac GPU
    device="cpu"
print(device)

cpu


In [6]:
class ExtractorClassifierDataset(Dataset):
    def __init__(self, data, resize_shape=(256, 256)):
        self.data = data
        self.to_tensor = transforms.ToTensor()
        self.resize_image = transforms.Resize(resize_shape, antialias=True)
        self.image_tensors = []
        self.label_tensors = []

        for image, label in tqdm(self.data):
            # Apply transformations to the image and convert to tensor
            image_tensor = self.to_tensor(image)
            image_tensor = self.resize_image(image_tensor)
            image_tensor = image_tensor/255
            self.image_tensors.append(image_tensor.to(device))

            # Create one-hot encoded tensor for the label
            label_tensor = torch.zeros(2)
            label_tensor[label] = 1
            self.label_tensors.append(label_tensor.to(device))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.image_tensors[idx], self.label_tensors[idx]

data_dir = "../../../datasets/extractor_classifier"

train_list = []
validation_list = []
test_list = []

for class_name in ("relevant", "not_relevant"):
    if class_name == "relevant":
        label = 0
    else:
        label = 1

    for image in os.listdir(f"{data_dir}/train/{class_name}"):
        if image != ".ipynb_checkpoints":
            src_path = f"{data_dir}/train/{class_name}/{image}"
            train_list.append((Image.open(src_path), label))

# Uncomment this part to add augmented data
    for image in os.listdir(f"{data_dir}/train_data_augmentation/{class_name}"):
        if image != ".ipynb_checkpoints":
            src_path = f"{data_dir}/train_data_augmentation/{class_name}/{image}"
            train_list.append((Image.open(src_path), label))

    for image in os.listdir(f"{data_dir}/validation/{class_name}"):
        if image != ".ipynb_checkpoints":
            src_path = f"{data_dir}/validation/{class_name}/{image}"
            validation_list.append((Image.open(src_path), label))

    for image in os.listdir(f"{data_dir}/test/{class_name}"):
        if image != ".ipynb_checkpoints":
            src_path = f"{data_dir}/test/{class_name}/{image}"
            test_list.append((Image.open(src_path), label))

# Define the train/validation/test datasets
train_data = ExtractorClassifierDataset(train_list)
validation_data = ExtractorClassifierDataset(validation_list)
test_data = ExtractorClassifierDataset(test_list)

100%|██████████| 7308/7308 [00:51<00:00, 141.56it/s]
100%|██████████| 348/348 [00:02<00:00, 154.84it/s]
100%|██████████| 698/698 [00:04<00:00, 146.72it/s]


In [7]:
import gc
del train_list
del validation_list
del test_list
gc.collect()

433

In [7]:
# Define the dataloaders for each dataset
train_loader = torch.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, pin_memory=False )
validation_loader = torch.utils.data.DataLoader(validation_data, batch_size=BATCH_SIZE, shuffle=True, pin_memory=False)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=True, pin_memory=False)

In [7]:
from torchvision.models import resnet50

class Resnet50Model(nn.Module):
    def __init__(self, pretrained=False):
        super(Resnet50Model, self).__init__()
        self.resnet_model = resnet50(pretrained=pretrained)
        num_features = self.resnet_model.fc.in_features
        self.resnet_model.fc = nn.Linear(num_features, 2)
        
    def forward(self, x):
        x = self.resnet_model(x)
        return x

In [9]:
from torchvision.models import convnext_base

class ConvNeXtBaseModel(nn.Module):
    def __init__(self, pretrained=False):
        super(ConvNeXtBaseModel, self).__init__()
        self.convnext_model = convnext_base(pretrained=pretrained)
        num_features = self.convnext_model.classifier[2].in_features
        new_classifier_layer = torch.nn.Linear(num_features, 2)
        self.convnext_model.classifier[2] = new_classifier_layer
        
    def forward(self, x):
        x = self.convnext_model(x)
        return x

In [8]:
from torchvision.models import efficientnet_v2_m

class EfficientNetModel(nn.Module):
    def __init__(self, pretrained=False):
        super(EfficientNetModel, self).__init__()
        self.efficientnet_model = efficientnet_v2_m(pretrained=pretrained)
        num_features = self.efficientnet_model.classifier[1].in_features
        new_classifier_layer = torch.nn.Linear(num_features, 2)
        self.efficientnet_model.classifier[1] = new_classifier_layer
        
    def forward(self, x):
        x = self.efficientnet_model(x)
        return x

In [9]:
def train(model, model_path, train_loader, validation_loader, test_loader, criterion, optimizer, epochs):
    # Move the model to the device
    model.to(device)

    # Define variables to track the best validation accuracy and the corresponding model state
    best_valid_f_beta = 0.0
    train_loss_values = []
    validation_loss_values = []
    train_f_beta_values = []
    validation_f_beta_values = []

    train_true_labels = []
    train_predicted_labels = []

    for epoch in range(epochs):
        # Train the model for one epoch
        train_loss = 0.0
        model.train()
        for i, (images, labels) in enumerate(tqdm(train_loader)):
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)
            train_loss += loss.item() * images.size(0)
            loss.backward()

            #if (i + 1) % 2 == 0 or (i + 1) == len(train_loader):
            optimizer.step()
            optimizer.zero_grad()

            outputs = F.softmax(outputs, dim=1)
            _, predicted = torch.max(outputs.data, 1)
            _, labels = torch.max(labels.data, 1)

            train_predicted_labels.extend(predicted.detach().cpu())
            train_true_labels.extend(labels.detach().cpu())

        train_loss = train_loss / len(train_loader.dataset)
        train_loss_values.append(train_loss)

        train_f_beta = fbeta_score(train_true_labels, train_predicted_labels, average="binary", beta=0.5)
        train_f_beta_values.append(train_f_beta)

        print(f"Epoch {epoch+1}/{epochs} - "
              f"Train Loss: {train_loss:.4f} - Train Fbeta: {train_f_beta:.4f}")

        # Evaluate the model on the validation set
        validation_loss, validation_f_beta = validate_model(model, validation_loader)
        validation_loss_values.append(validation_loss)
        validation_f_beta_values.append(validation_f_beta)

        # Save the model state if the current validation accuracy is better than the previous best
        if validation_f_beta > best_valid_f_beta:
            best_valid_f_beta = validation_f_beta
            print(f"Updated best model: epoch {epoch+1}")
            torch.save(model.state_dict(), model_path)

    # Load the best model state and evaluate on the test set
    model.load_state_dict(torch.load(model_path, map_location=torch.device(device)))
    print("Test model:")
    test_loss, test_f_beta = validate_model(model, test_loader)
    print(f"Test Loss: {test_loss:.4f} - Test Fbeta: {test_f_beta:.4f}")

    return model, train_loss_values, train_f_beta_values, validation_loss_values, validation_f_beta_values

def validate_model(model, data_loader):
    model.to(device)
    model.eval()

    valid_loss = 0.0
    valid_true_labels = []
    valid_predicted_labels = []

    with torch.no_grad():
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)

            loss = criterion(outputs, labels)
            valid_loss += loss.item() * images.size(0)

            outputs = F.softmax(outputs, dim=1)
            _, predicted = torch.max(outputs.data, 1)
            _, labels = torch.max(labels.data, 1)

            valid_predicted_labels.extend(predicted.detach().cpu())
            valid_true_labels.extend(labels.detach().cpu())

    f_beta_score = fbeta_score(valid_true_labels, valid_predicted_labels, average="binary", beta=0.5)
    mcc_score = matthews_corrcoef(valid_true_labels, valid_predicted_labels)
    valid_loss = valid_loss / len(data_loader.dataset)

    print(f"Valid Loss: {valid_loss:.4f} - Valid Fbeta: {f_beta_score:.4f} - Valid MCC: {mcc_score}")
    print(classification_report(valid_true_labels, valid_predicted_labels))
    return valid_loss, f_beta_score

In [25]:
import gc
del train_loader
del validation_loader
del test_loader
gc.collect()

25067

In [10]:
efficient_net_model_augmented_data = EfficientNetModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(efficient_net_model_augmented_data.parameters(), lr=LEARNING_RATE)
model_path = "model_results/efficient_net_augmented_model_50.pt"

if os.path.isfile(model_path):
    print("Evaluate model")
    efficient_net_model_augmented_data.load_state_dict(torch.load(model_path, map_location=torch.device(device)))
    print("Validation data:")
    validate_model(efficient_net_model_augmented_data, validation_loader)
    print("Test data:")
    validate_model(efficient_net_model_augmented_data, test_loader)
else:
    print("Train model")
    model, train_loss_values, train_f_beta_values, validation_loss_values, validation_f_beta_values = train(efficient_net_model_augmented_data, model_path, train_loader, validation_loader, test_loader, criterion, optimizer, 50)
    np.save("model_results/efficient_net_augmented_model_50_train_loss", train_loss_values)
    np.save("model_results/efficient_net_augmented_model_50_train_f_beta", train_f_beta_values)
    np.save("model_results/efficient_net_augmented_model_50_valid_loss", validation_loss_values)
    np.save("model_results/efficient_net_augmented_model_50_valid_f_beta", validation_f_beta_values)



Train model


  0%|          | 0/85 [03:45<?, ?it/s]


KeyboardInterrupt: 

In [11]:
resnet_model = Resnet50Model()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(resnet_model.parameters(), lr=LEARNING_RATE)
model_path = "model_results/resnet50_model_50.pt"

if os.path.isfile(model_path):
    print("Evaluate model")
    resnet_model.load_state_dict(torch.load(model_path, map_location=torch.device(device)))
    print("Validation data:")
    validate_model(resnet_model, validation_loader)
    print("Test data:")
    validate_model(resnet_model, test_loader)
else:
    print("Train model")
    model, train_loss_values, train_f_beta_values, validation_loss_values, validation_f_beta_values = train(resnet_model, model_path, train_loader, validation_loader, test_loader, criterion, optimizer, 50)
    np.save("model_results/resnet50_model_50_train_loss", train_loss_values)
    np.save("model_results/resnet50_model_50_train_f_beta", train_f_beta_values)
    np.save("model_results/resnet50_model_50_valid_loss", validation_loss_values)
    np.save("model_results/resnet50_model_50_valid_f_beta", validation_f_beta_values)



Train model


100%|██████████| 29/29 [00:13<00:00,  2.15it/s]


Epoch 1/50 - Train Loss: 0.5801 - Train Fbeta: 0.5963


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Valid Loss: 2.7899 - Valid Fbeta: 0.4271 - Valid MCC: 0.0
              precision    recall  f1-score   support

           0       0.00      0.00      0.00       218
           1       0.37      1.00      0.54       130

    accuracy                           0.37       348
   macro avg       0.19      0.50      0.27       348
weighted avg       0.14      0.37      0.20       348

Updated best model: epoch 1


100%|██████████| 29/29 [00:10<00:00,  2.89it/s]


Epoch 2/50 - Train Loss: 0.4604 - Train Fbeta: 0.6759


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Valid Loss: 3.7730 - Valid Fbeta: 0.4271 - Valid MCC: 0.0
              precision    recall  f1-score   support

           0       0.00      0.00      0.00       218
           1       0.37      1.00      0.54       130

    accuracy                           0.37       348
   macro avg       0.19      0.50      0.27       348
weighted avg       0.14      0.37      0.20       348



100%|██████████| 29/29 [00:10<00:00,  2.89it/s]


Epoch 3/50 - Train Loss: 0.4045 - Train Fbeta: 0.7127


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Valid Loss: 2.2905 - Valid Fbeta: 0.4271 - Valid MCC: 0.0
              precision    recall  f1-score   support

           0       0.00      0.00      0.00       218
           1       0.37      1.00      0.54       130

    accuracy                           0.37       348
   macro avg       0.19      0.50      0.27       348
weighted avg       0.14      0.37      0.20       348



100%|██████████| 29/29 [00:10<00:00,  2.89it/s]


Epoch 4/50 - Train Loss: 0.3529 - Train Fbeta: 0.7424
Valid Loss: 1.0820 - Valid Fbeta: 0.4955 - Valid MCC: 0.32058205832704334
              precision    recall  f1-score   support

           0       0.67      1.00      0.80       218
           1       0.96      0.17      0.29       130

    accuracy                           0.69       348
   macro avg       0.81      0.58      0.54       348
weighted avg       0.78      0.69      0.61       348

Updated best model: epoch 4


100%|██████████| 29/29 [00:10<00:00,  2.89it/s]


Epoch 5/50 - Train Loss: 0.2965 - Train Fbeta: 0.7639
Valid Loss: 0.7904 - Valid Fbeta: 0.5645 - Valid MCC: 0.3220150842223363
              precision    recall  f1-score   support

           0       0.69      0.95      0.80       218
           1       0.78      0.27      0.40       130

    accuracy                           0.70       348
   macro avg       0.73      0.61      0.60       348
weighted avg       0.72      0.70      0.65       348

Updated best model: epoch 5


100%|██████████| 29/29 [00:10<00:00,  2.89it/s]


Epoch 6/50 - Train Loss: 0.2246 - Train Fbeta: 0.7845
Valid Loss: 0.8429 - Valid Fbeta: 0.6470 - Valid MCC: 0.4301307851302884
              precision    recall  f1-score   support

           0       0.78      0.80      0.79       218
           1       0.65      0.62      0.64       130

    accuracy                           0.74       348
   macro avg       0.72      0.71      0.71       348
weighted avg       0.73      0.74      0.73       348

Updated best model: epoch 6


100%|██████████| 29/29 [00:10<00:00,  2.88it/s]


Epoch 7/50 - Train Loss: 0.2282 - Train Fbeta: 0.8005
Valid Loss: 1.0223 - Valid Fbeta: 0.5287 - Valid MCC: 0.29396868758630984
              precision    recall  f1-score   support

           0       0.80      0.50      0.62       218
           1       0.49      0.79      0.60       130

    accuracy                           0.61       348
   macro avg       0.65      0.65      0.61       348
weighted avg       0.69      0.61      0.61       348



100%|██████████| 29/29 [00:10<00:00,  2.88it/s]


Epoch 8/50 - Train Loss: 0.2045 - Train Fbeta: 0.8140
Valid Loss: 1.8174 - Valid Fbeta: 0.3738 - Valid MCC: 0.20343653232468653
              precision    recall  f1-score   support

           0       0.65      0.98      0.78       218
           1       0.76      0.12      0.21       130

    accuracy                           0.66       348
   macro avg       0.71      0.55      0.50       348
weighted avg       0.69      0.66      0.57       348



100%|██████████| 29/29 [00:10<00:00,  2.89it/s]


Epoch 9/50 - Train Loss: 0.1836 - Train Fbeta: 0.8244
Valid Loss: 0.7986 - Valid Fbeta: 0.7741 - Valid MCC: 0.5507131059890218
              precision    recall  f1-score   support

           0       0.76      0.98      0.85       218
           1       0.92      0.47      0.62       130

    accuracy                           0.79       348
   macro avg       0.84      0.72      0.74       348
weighted avg       0.82      0.79      0.77       348

Updated best model: epoch 9


100%|██████████| 29/29 [00:10<00:00,  2.88it/s]


Epoch 10/50 - Train Loss: 0.1077 - Train Fbeta: 0.8374
Valid Loss: 0.9664 - Valid Fbeta: 0.7667 - Valid MCC: 0.5522168152386769
              precision    recall  f1-score   support

           0       0.77      0.95      0.85       218
           1       0.86      0.53      0.66       130

    accuracy                           0.79       348
   macro avg       0.82      0.74      0.75       348
weighted avg       0.81      0.79      0.78       348



100%|██████████| 29/29 [00:10<00:00,  2.89it/s]


Epoch 11/50 - Train Loss: 0.1826 - Train Fbeta: 0.8435
Valid Loss: 0.7622 - Valid Fbeta: 0.7326 - Valid MCC: 0.5053839194886307
              precision    recall  f1-score   support

           0       0.75      0.94      0.84       218
           1       0.84      0.48      0.61       130

    accuracy                           0.77       348
   macro avg       0.80      0.71      0.73       348
weighted avg       0.79      0.77      0.76       348



100%|██████████| 29/29 [00:10<00:00,  2.88it/s]


Epoch 12/50 - Train Loss: 0.1277 - Train Fbeta: 0.8520
Valid Loss: 1.6353 - Valid Fbeta: 0.6522 - Valid MCC: 0.41462501445420574
              precision    recall  f1-score   support

           0       0.71      0.97      0.82       218
           1       0.88      0.32      0.47       130

    accuracy                           0.73       348
   macro avg       0.79      0.65      0.65       348
weighted avg       0.77      0.73      0.69       348



100%|██████████| 29/29 [00:10<00:00,  2.88it/s]


Epoch 13/50 - Train Loss: 0.1083 - Train Fbeta: 0.8598
Valid Loss: 1.0422 - Valid Fbeta: 0.7277 - Valid MCC: 0.4987637047228007
              precision    recall  f1-score   support

           0       0.75      0.94      0.84       218
           1       0.84      0.48      0.61       130

    accuracy                           0.77       348
   macro avg       0.79      0.71      0.72       348
weighted avg       0.78      0.77      0.75       348



100%|██████████| 29/29 [00:10<00:00,  2.89it/s]


Epoch 14/50 - Train Loss: 0.0850 - Train Fbeta: 0.8675
Valid Loss: 0.9383 - Valid Fbeta: 0.7759 - Valid MCC: 0.5862586932381382
              precision    recall  f1-score   support

           0       0.80      0.92      0.86       218
           1       0.83      0.62      0.71       130

    accuracy                           0.81       348
   macro avg       0.82      0.77      0.78       348
weighted avg       0.81      0.81      0.80       348

Updated best model: epoch 14


100%|██████████| 29/29 [00:10<00:00,  2.88it/s]


Epoch 15/50 - Train Loss: 0.0633 - Train Fbeta: 0.8743
Valid Loss: 0.7434 - Valid Fbeta: 0.7551 - Valid MCC: 0.5477811841001434
              precision    recall  f1-score   support

           0       0.78      0.93      0.85       218
           1       0.82      0.57      0.67       130

    accuracy                           0.79       348
   macro avg       0.80      0.75      0.76       348
weighted avg       0.80      0.79      0.78       348



100%|██████████| 29/29 [00:10<00:00,  2.89it/s]


Epoch 16/50 - Train Loss: 0.0959 - Train Fbeta: 0.8800
Valid Loss: 0.8380 - Valid Fbeta: 0.6848 - Valid MCC: 0.5463171026846079
              precision    recall  f1-score   support

           0       0.87      0.75      0.80       218
           1       0.66      0.82      0.73       130

    accuracy                           0.77       348
   macro avg       0.77      0.78      0.77       348
weighted avg       0.79      0.77      0.78       348



100%|██████████| 29/29 [00:10<00:00,  2.89it/s]


Epoch 17/50 - Train Loss: 0.0599 - Train Fbeta: 0.8856
Valid Loss: 1.7824 - Valid Fbeta: 0.6902 - Valid MCC: 0.45591318992380025
              precision    recall  f1-score   support

           0       0.72      0.98      0.83       218
           1       0.92      0.35      0.50       130

    accuracy                           0.74       348
   macro avg       0.82      0.66      0.67       348
weighted avg       0.79      0.74      0.71       348



100%|██████████| 29/29 [00:10<00:00,  2.88it/s]


Epoch 18/50 - Train Loss: 0.0718 - Train Fbeta: 0.8907
Valid Loss: 0.7730 - Valid Fbeta: 0.7163 - Valid MCC: 0.5675110823746576
              precision    recall  f1-score   support

           0       0.85      0.81      0.83       218
           1       0.70      0.77      0.74       130

    accuracy                           0.79       348
   macro avg       0.78      0.79      0.78       348
weighted avg       0.80      0.79      0.79       348



100%|██████████| 29/29 [00:10<00:00,  2.89it/s]


Epoch 19/50 - Train Loss: 0.0512 - Train Fbeta: 0.8955
Valid Loss: 0.7831 - Valid Fbeta: 0.7445 - Valid MCC: 0.5868134747891307
              precision    recall  f1-score   support

           0       0.84      0.85      0.85       218
           1       0.75      0.73      0.74       130

    accuracy                           0.81       348
   macro avg       0.79      0.79      0.79       348
weighted avg       0.81      0.81      0.81       348



100%|██████████| 29/29 [00:10<00:00,  2.88it/s]


Epoch 20/50 - Train Loss: 0.0397 - Train Fbeta: 0.9002
Valid Loss: 0.9138 - Valid Fbeta: 0.7397 - Valid MCC: 0.5401393729408951
              precision    recall  f1-score   support

           0       0.79      0.90      0.84       218
           1       0.78      0.61      0.68       130

    accuracy                           0.79       348
   macro avg       0.79      0.75      0.76       348
weighted avg       0.79      0.79      0.78       348



100%|██████████| 29/29 [00:10<00:00,  2.89it/s]


Epoch 21/50 - Train Loss: 0.0381 - Train Fbeta: 0.9041
Valid Loss: 0.8902 - Valid Fbeta: 0.7254 - Valid MCC: 0.5883878158124556
              precision    recall  f1-score   support

           0       0.87      0.81      0.84       218
           1       0.71      0.79      0.75       130

    accuracy                           0.80       348
   macro avg       0.79      0.80      0.79       348
weighted avg       0.81      0.80      0.80       348



100%|██████████| 29/29 [00:10<00:00,  2.88it/s]


Epoch 22/50 - Train Loss: 0.0362 - Train Fbeta: 0.9078
Valid Loss: 1.1461 - Valid Fbeta: 0.6015 - Valid MCC: 0.43687051630926477
              precision    recall  f1-score   support

           0       0.86      0.61      0.72       218
           1       0.56      0.84      0.67       130

    accuracy                           0.70       348
   macro avg       0.71      0.72      0.69       348
weighted avg       0.75      0.70      0.70       348



100%|██████████| 29/29 [00:10<00:00,  2.88it/s]


Epoch 23/50 - Train Loss: 0.0373 - Train Fbeta: 0.9112
Valid Loss: 0.8496 - Valid Fbeta: 0.7008 - Valid MCC: 0.5606320797422509
              precision    recall  f1-score   support

           0       0.87      0.78      0.82       218
           1       0.68      0.80      0.73       130

    accuracy                           0.78       348
   macro avg       0.77      0.79      0.78       348
weighted avg       0.80      0.78      0.79       348



100%|██████████| 29/29 [00:10<00:00,  2.89it/s]


Epoch 24/50 - Train Loss: 0.0269 - Train Fbeta: 0.9144
Valid Loss: 1.0660 - Valid Fbeta: 0.7762 - Valid MCC: 0.5989053631302961
              precision    recall  f1-score   support

           0       0.82      0.91      0.86       218
           1       0.81      0.66      0.73       130

    accuracy                           0.82       348
   macro avg       0.81      0.78      0.79       348
weighted avg       0.82      0.82      0.81       348

Updated best model: epoch 24


100%|██████████| 29/29 [00:10<00:00,  2.89it/s]


Epoch 25/50 - Train Loss: 0.0384 - Train Fbeta: 0.9173
Valid Loss: 1.4421 - Valid Fbeta: 0.7379 - Valid MCC: 0.5164599359208939
              precision    recall  f1-score   support

           0       0.76      0.94      0.84       218
           1       0.83      0.52      0.64       130

    accuracy                           0.78       348
   macro avg       0.80      0.73      0.74       348
weighted avg       0.79      0.78      0.76       348



100%|██████████| 29/29 [00:10<00:00,  2.89it/s]


Epoch 26/50 - Train Loss: 0.0653 - Train Fbeta: 0.9194
Valid Loss: 1.0176 - Valid Fbeta: 0.6264 - Valid MCC: 0.47825854897255116
              precision    recall  f1-score   support

           0       0.88      0.65      0.74       218
           1       0.59      0.85      0.69       130

    accuracy                           0.72       348
   macro avg       0.73      0.75      0.72       348
weighted avg       0.77      0.72      0.73       348



100%|██████████| 29/29 [00:10<00:00,  2.88it/s]


Epoch 27/50 - Train Loss: 0.0680 - Train Fbeta: 0.9213
Valid Loss: 0.8427 - Valid Fbeta: 0.7558 - Valid MCC: 0.5887462407398817
              precision    recall  f1-score   support

           0       0.83      0.88      0.85       218
           1       0.77      0.70      0.73       130

    accuracy                           0.81       348
   macro avg       0.80      0.79      0.79       348
weighted avg       0.81      0.81      0.81       348



100%|██████████| 29/29 [00:10<00:00,  2.89it/s]


Epoch 28/50 - Train Loss: 0.0381 - Train Fbeta: 0.9235
Valid Loss: 0.8450 - Valid Fbeta: 0.7378 - Valid MCC: 0.603770664381949
              precision    recall  f1-score   support

           0       0.87      0.82      0.84       218
           1       0.73      0.79      0.76       130

    accuracy                           0.81       348
   macro avg       0.80      0.81      0.80       348
weighted avg       0.82      0.81      0.81       348



100%|██████████| 29/29 [00:10<00:00,  2.88it/s]


Epoch 29/50 - Train Loss: 0.0491 - Train Fbeta: 0.9254
Valid Loss: 0.9402 - Valid Fbeta: 0.7932 - Valid MCC: 0.6011925118410592
              precision    recall  f1-score   support

           0       0.80      0.94      0.86       218
           1       0.86      0.61      0.71       130

    accuracy                           0.82       348
   macro avg       0.83      0.77      0.79       348
weighted avg       0.82      0.82      0.81       348

Updated best model: epoch 29


100%|██████████| 29/29 [00:10<00:00,  2.89it/s]


Epoch 30/50 - Train Loss: 0.0724 - Train Fbeta: 0.9267
Valid Loss: 0.9925 - Valid Fbeta: 0.7925 - Valid MCC: 0.6122746386337018
              precision    recall  f1-score   support

           0       0.81      0.93      0.87       218
           1       0.84      0.65      0.73       130

    accuracy                           0.82       348
   macro avg       0.83      0.79      0.80       348
weighted avg       0.82      0.82      0.82       348



100%|██████████| 29/29 [00:10<00:00,  2.88it/s]


Epoch 31/50 - Train Loss: 0.0472 - Train Fbeta: 0.9284
Valid Loss: 1.0095 - Valid Fbeta: 0.7053 - Valid MCC: 0.5251182816693067
              precision    recall  f1-score   support

           0       0.82      0.83      0.82       218
           1       0.71      0.69      0.70       130

    accuracy                           0.78       348
   macro avg       0.76      0.76      0.76       348
weighted avg       0.78      0.78      0.78       348



100%|██████████| 29/29 [00:10<00:00,  2.88it/s]


Epoch 32/50 - Train Loss: 0.0247 - Train Fbeta: 0.9302
Valid Loss: 0.9540 - Valid Fbeta: 0.6323 - Valid MCC: 0.4801782707134225
              precision    recall  f1-score   support

           0       0.87      0.67      0.75       218
           1       0.60      0.83      0.69       130

    accuracy                           0.73       348
   macro avg       0.73      0.75      0.72       348
weighted avg       0.77      0.73      0.73       348



100%|██████████| 29/29 [00:10<00:00,  2.88it/s]


Epoch 33/50 - Train Loss: 0.0269 - Train Fbeta: 0.9320
Valid Loss: 1.1003 - Valid Fbeta: 0.7621 - Valid MCC: 0.5728515431209227
              precision    recall  f1-score   support

           0       0.80      0.91      0.85       218
           1       0.80      0.63      0.71       130

    accuracy                           0.80       348
   macro avg       0.80      0.77      0.78       348
weighted avg       0.80      0.80      0.80       348



100%|██████████| 29/29 [00:10<00:00,  2.89it/s]


Epoch 34/50 - Train Loss: 0.0243 - Train Fbeta: 0.9337
Valid Loss: 0.8719 - Valid Fbeta: 0.7712 - Valid MCC: 0.606676981453716
              precision    recall  f1-score   support

           0       0.83      0.89      0.86       218
           1       0.79      0.70      0.74       130

    accuracy                           0.82       348
   macro avg       0.81      0.79      0.80       348
weighted avg       0.82      0.82      0.82       348



100%|██████████| 29/29 [00:10<00:00,  2.89it/s]


Epoch 35/50 - Train Loss: 0.0153 - Train Fbeta: 0.9355
Valid Loss: 1.0003 - Valid Fbeta: 0.7692 - Valid MCC: 0.6316160903316866
              precision    recall  f1-score   support

           0       0.86      0.86      0.86       218
           1       0.77      0.77      0.77       130

    accuracy                           0.83       348
   macro avg       0.82      0.82      0.82       348
weighted avg       0.83      0.83      0.83       348



100%|██████████| 29/29 [00:10<00:00,  2.88it/s]


Epoch 36/50 - Train Loss: 0.0091 - Train Fbeta: 0.9373
Valid Loss: 1.0235 - Valid Fbeta: 0.7478 - Valid MCC: 0.6125992470120936
              precision    recall  f1-score   support

           0       0.87      0.83      0.85       218
           1       0.74      0.78      0.76       130

    accuracy                           0.82       348
   macro avg       0.80      0.81      0.81       348
weighted avg       0.82      0.82      0.82       348



100%|██████████| 29/29 [00:10<00:00,  2.89it/s]


Epoch 37/50 - Train Loss: 0.0080 - Train Fbeta: 0.9389
Valid Loss: 1.2156 - Valid Fbeta: 0.7795 - Valid MCC: 0.5927081900998996
              precision    recall  f1-score   support

           0       0.81      0.92      0.86       218
           1       0.83      0.63      0.72       130

    accuracy                           0.81       348
   macro avg       0.82      0.78      0.79       348
weighted avg       0.82      0.81      0.81       348



100%|██████████| 29/29 [00:10<00:00,  2.89it/s]


Epoch 38/50 - Train Loss: 0.0165 - Train Fbeta: 0.9403
Valid Loss: 1.0233 - Valid Fbeta: 0.7492 - Valid MCC: 0.5924263454828674
              precision    recall  f1-score   support

           0       0.84      0.86      0.85       218
           1       0.75      0.73      0.74       130

    accuracy                           0.81       348
   macro avg       0.80      0.79      0.80       348
weighted avg       0.81      0.81      0.81       348



100%|██████████| 29/29 [00:10<00:00,  2.89it/s]


Epoch 39/50 - Train Loss: 0.0104 - Train Fbeta: 0.9418
Valid Loss: 1.4285 - Valid Fbeta: 0.8233 - Valid MCC: 0.6416023098298607
              precision    recall  f1-score   support

           0       0.81      0.95      0.88       218
           1       0.89      0.63      0.74       130

    accuracy                           0.83       348
   macro avg       0.85      0.79      0.81       348
weighted avg       0.84      0.83      0.83       348

Updated best model: epoch 39


100%|██████████| 29/29 [00:10<00:00,  2.88it/s]


Epoch 40/50 - Train Loss: 0.0150 - Train Fbeta: 0.9431
Valid Loss: 1.6215 - Valid Fbeta: 0.8260 - Valid MCC: 0.6289129674869637
              precision    recall  f1-score   support

           0       0.79      0.97      0.87       218
           1       0.93      0.58      0.71       130

    accuracy                           0.82       348
   macro avg       0.86      0.77      0.79       348
weighted avg       0.84      0.82      0.81       348

Updated best model: epoch 40


100%|██████████| 29/29 [00:10<00:00,  2.89it/s]


Epoch 41/50 - Train Loss: 0.0128 - Train Fbeta: 0.9443
Valid Loss: 1.0451 - Valid Fbeta: 0.7975 - Valid MCC: 0.6311456663201523
              precision    recall  f1-score   support

           0       0.83      0.92      0.87       218
           1       0.83      0.68      0.75       130

    accuracy                           0.83       348
   macro avg       0.83      0.80      0.81       348
weighted avg       0.83      0.83      0.83       348



100%|██████████| 29/29 [00:10<00:00,  2.89it/s]


Epoch 42/50 - Train Loss: 0.0084 - Train Fbeta: 0.9456
Valid Loss: 1.4865 - Valid Fbeta: 0.8097 - Valid MCC: 0.6218765078818351
              precision    recall  f1-score   support

           0       0.81      0.95      0.87       218
           1       0.88      0.62      0.72       130

    accuracy                           0.82       348
   macro avg       0.84      0.78      0.80       348
weighted avg       0.83      0.82      0.82       348



100%|██████████| 29/29 [00:10<00:00,  2.89it/s]


Epoch 43/50 - Train Loss: 0.0101 - Train Fbeta: 0.9468
Valid Loss: 1.0336 - Valid Fbeta: 0.7492 - Valid MCC: 0.5924263454828674
              precision    recall  f1-score   support

           0       0.84      0.86      0.85       218
           1       0.75      0.73      0.74       130

    accuracy                           0.81       348
   macro avg       0.80      0.79      0.80       348
weighted avg       0.81      0.81      0.81       348



100%|██████████| 29/29 [00:10<00:00,  2.89it/s]


Epoch 44/50 - Train Loss: 0.0081 - Train Fbeta: 0.9479
Valid Loss: 1.1391 - Valid Fbeta: 0.7571 - Valid MCC: 0.6047863195972583
              precision    recall  f1-score   support

           0       0.85      0.86      0.85       218
           1       0.76      0.74      0.75       130

    accuracy                           0.82       348
   macro avg       0.80      0.80      0.80       348
weighted avg       0.82      0.82      0.82       348



100%|██████████| 29/29 [00:10<00:00,  2.88it/s]


Epoch 45/50 - Train Loss: 0.0222 - Train Fbeta: 0.9489
Valid Loss: 1.7613 - Valid Fbeta: 0.7442 - Valid MCC: 0.5198305766253191
              precision    recall  f1-score   support

           0       0.76      0.95      0.84       218
           1       0.85      0.49      0.62       130

    accuracy                           0.78       348
   macro avg       0.81      0.72      0.73       348
weighted avg       0.79      0.78      0.76       348



100%|██████████| 29/29 [00:10<00:00,  2.89it/s]


Epoch 46/50 - Train Loss: 0.0077 - Train Fbeta: 0.9500
Valid Loss: 1.3056 - Valid Fbeta: 0.7915 - Valid MCC: 0.6062951031056054
              precision    recall  f1-score   support

           0       0.81      0.93      0.87       218
           1       0.85      0.63      0.72       130

    accuracy                           0.82       348
   macro avg       0.83      0.78      0.79       348
weighted avg       0.82      0.82      0.81       348



100%|██████████| 29/29 [00:10<00:00,  2.88it/s]


Epoch 47/50 - Train Loss: 0.0092 - Train Fbeta: 0.9509
Valid Loss: 1.1107 - Valid Fbeta: 0.7655 - Valid MCC: 0.6086719662854201
              precision    recall  f1-score   support

           0       0.84      0.88      0.86       218
           1       0.78      0.72      0.75       130

    accuracy                           0.82       348
   macro avg       0.81      0.80      0.80       348
weighted avg       0.82      0.82      0.82       348



100%|██████████| 29/29 [00:10<00:00,  2.89it/s]


Epoch 48/50 - Train Loss: 0.0168 - Train Fbeta: 0.9518
Valid Loss: 1.4202 - Valid Fbeta: 0.8039 - Valid MCC: 0.6201797046499385
              precision    recall  f1-score   support

           0       0.81      0.94      0.87       218
           1       0.86      0.63      0.73       130

    accuracy                           0.82       348
   macro avg       0.84      0.79      0.80       348
weighted avg       0.83      0.82      0.82       348



100%|██████████| 29/29 [00:10<00:00,  2.88it/s]


Epoch 49/50 - Train Loss: 0.0132 - Train Fbeta: 0.9528
Valid Loss: 1.1211 - Valid Fbeta: 0.7235 - Valid MCC: 0.5473919635868973
              precision    recall  f1-score   support

           0       0.82      0.85      0.84       218
           1       0.73      0.69      0.71       130

    accuracy                           0.79       348
   macro avg       0.78      0.77      0.77       348
weighted avg       0.79      0.79      0.79       348



100%|██████████| 29/29 [00:10<00:00,  2.89it/s]


Epoch 50/50 - Train Loss: 0.0333 - Train Fbeta: 0.9535
Valid Loss: 1.9468 - Valid Fbeta: 0.5648 - Valid MCC: 0.39637158366839337
              precision    recall  f1-score   support

           0       0.88      0.51      0.65       218
           1       0.52      0.88      0.65       130

    accuracy                           0.65       348
   macro avg       0.70      0.70      0.65       348
weighted avg       0.75      0.65      0.65       348

Test model:
Valid Loss: 1.5313 - Valid Fbeta: 0.8054 - Valid MCC: 0.6043441825539875
              precision    recall  f1-score   support

           0       0.79      0.96      0.87       436
           1       0.90      0.57      0.70       261

    accuracy                           0.81       697
   macro avg       0.84      0.77      0.78       697
weighted avg       0.83      0.81      0.80       697

Test Loss: 1.5313 - Test Fbeta: 0.8054


In [15]:
efficient_net_model = EfficientNetModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(efficient_net_model.parameters(), lr=LEARNING_RATE)
model_path = "model_results/efficient_net50_model_50.pt"

if os.path.isfile(model_path):
    print("Evaluate model")
    efficient_net_model.load_state_dict(torch.load(model_path, map_location=torch.device(device)))
    print("Validation data:")
    validate_model(efficient_net_model, validation_loader)
    print("Test data:")
    validate_model(efficient_net_model, test_loader)
else:
    print("Train model")
    model, train_loss_values, train_f_beta_values, validation_loss_values, validation_f_beta_values = train(efficient_net_model, model_path, train_loader, validation_loader, test_loader, criterion, optimizer, 50)
    np.save("model_results/efficient_net50_model_50_train_loss", train_loss_values)
    np.save("model_results/efficient_net50_model_50_train_f_beta", train_f_beta_values)
    np.save("model_results/efficient_net50_model_50_valid_loss", validation_loss_values)
    np.save("model_results/efficient_net50_model_50_valid_f_beta", validation_f_beta_values)

Train model


100%|██████████| 29/29 [00:38<00:00,  1.32s/it]


Epoch 1/50 - Train Loss: 0.6851 - Train Fbeta: 0.3632


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Valid Loss: 0.6808 - Valid Fbeta: 0.0000 - Valid MCC: 0.0
              precision    recall  f1-score   support

           0       0.63      1.00      0.77       218
           1       0.00      0.00      0.00       130

    accuracy                           0.63       348
   macro avg       0.31      0.50      0.39       348
weighted avg       0.39      0.63      0.48       348



100%|██████████| 29/29 [00:18<00:00,  1.57it/s]


Epoch 2/50 - Train Loss: 0.6712 - Train Fbeta: 0.4317


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Valid Loss: 0.6678 - Valid Fbeta: 0.0000 - Valid MCC: 0.0
              precision    recall  f1-score   support

           0       0.63      1.00      0.77       218
           1       0.00      0.00      0.00       130

    accuracy                           0.63       348
   macro avg       0.31      0.50      0.39       348
weighted avg       0.39      0.63      0.48       348



100%|██████████| 29/29 [00:18<00:00,  1.57it/s]


Epoch 3/50 - Train Loss: 0.6369 - Train Fbeta: 0.4578
Valid Loss: 0.6106 - Valid Fbeta: 0.3960 - Valid MCC: 0.24879427386081487
              precision    recall  f1-score   support

           0       0.65      0.99      0.79       218
           1       0.89      0.12      0.22       130

    accuracy                           0.67       348
   macro avg       0.77      0.56      0.50       348
weighted avg       0.74      0.67      0.57       348

Updated best model: epoch 3


100%|██████████| 29/29 [00:18<00:00,  1.57it/s]


Epoch 4/50 - Train Loss: 0.6163 - Train Fbeta: 0.4707
Valid Loss: 0.6904 - Valid Fbeta: 0.5590 - Valid MCC: 0.31126577016332924
              precision    recall  f1-score   support

           0       0.69      0.94      0.80       218
           1       0.75      0.28      0.40       130

    accuracy                           0.70       348
   macro avg       0.72      0.61      0.60       348
weighted avg       0.71      0.70      0.65       348

Updated best model: epoch 4


100%|██████████| 29/29 [00:18<00:00,  1.57it/s]


Epoch 5/50 - Train Loss: 0.5979 - Train Fbeta: 0.4990
Valid Loss: 0.8916 - Valid Fbeta: 0.5994 - Valid MCC: 0.3505014227159677
              precision    recall  f1-score   support

           0       0.70      0.94      0.80       218
           1       0.77      0.32      0.45       130

    accuracy                           0.71       348
   macro avg       0.74      0.63      0.63       348
weighted avg       0.73      0.71      0.67       348

Updated best model: epoch 5


100%|██████████| 29/29 [00:18<00:00,  1.56it/s]


Epoch 6/50 - Train Loss: 0.5334 - Train Fbeta: 0.5407
Valid Loss: 0.6631 - Valid Fbeta: 0.6143 - Valid MCC: 0.36564389703544564
              precision    recall  f1-score   support

           0       0.70      0.94      0.81       218
           1       0.78      0.33      0.46       130

    accuracy                           0.72       348
   macro avg       0.74      0.64      0.64       348
weighted avg       0.73      0.72      0.68       348

Updated best model: epoch 6


100%|██████████| 29/29 [00:18<00:00,  1.56it/s]


Epoch 7/50 - Train Loss: 0.5210 - Train Fbeta: 0.5731
Valid Loss: 0.5098 - Valid Fbeta: 0.7296 - Valid MCC: 0.5083516451420639
              precision    recall  f1-score   support

           0       0.77      0.93      0.84       218
           1       0.81      0.52      0.64       130

    accuracy                           0.78       348
   macro avg       0.79      0.72      0.74       348
weighted avg       0.78      0.78      0.76       348

Updated best model: epoch 7


100%|██████████| 29/29 [00:18<00:00,  1.56it/s]


Epoch 8/50 - Train Loss: 0.4474 - Train Fbeta: 0.6071
Valid Loss: 0.4878 - Valid Fbeta: 0.7451 - Valid MCC: 0.540176292855235
              precision    recall  f1-score   support

           0       0.79      0.91      0.85       218
           1       0.80      0.58      0.68       130

    accuracy                           0.79       348
   macro avg       0.79      0.75      0.76       348
weighted avg       0.79      0.79      0.78       348

Updated best model: epoch 8


100%|██████████| 29/29 [00:18<00:00,  1.56it/s]


Epoch 9/50 - Train Loss: 0.4020 - Train Fbeta: 0.6356
Valid Loss: 0.4448 - Valid Fbeta: 0.7797 - Valid MCC: 0.6193055009863092
              precision    recall  f1-score   support

           0       0.84      0.89      0.86       218
           1       0.80      0.71      0.75       130

    accuracy                           0.82       348
   macro avg       0.82      0.80      0.81       348
weighted avg       0.82      0.82      0.82       348

Updated best model: epoch 9


100%|██████████| 29/29 [00:18<00:00,  1.56it/s]


Epoch 10/50 - Train Loss: 0.3942 - Train Fbeta: 0.6573
Valid Loss: 0.4930 - Valid Fbeta: 0.7015 - Valid MCC: 0.5311288392334897
              precision    recall  f1-score   support

           0       0.83      0.81      0.82       218
           1       0.70      0.72      0.71       130

    accuracy                           0.78       348
   macro avg       0.76      0.77      0.77       348
weighted avg       0.78      0.78      0.78       348



100%|██████████| 29/29 [00:18<00:00,  1.56it/s]


Epoch 11/50 - Train Loss: 0.3589 - Train Fbeta: 0.6776
Valid Loss: 0.4732 - Valid Fbeta: 0.7565 - Valid MCC: 0.5663649580671883
              precision    recall  f1-score   support

           0       0.80      0.90      0.85       218
           1       0.80      0.63      0.70       130

    accuracy                           0.80       348
   macro avg       0.80      0.77      0.78       348
weighted avg       0.80      0.80      0.80       348



100%|██████████| 29/29 [00:18<00:00,  1.56it/s]


Epoch 12/50 - Train Loss: 0.3166 - Train Fbeta: 0.6971
Valid Loss: 0.4582 - Valid Fbeta: 0.7714 - Valid MCC: 0.5859015782745022
              precision    recall  f1-score   support

           0       0.81      0.91      0.86       218
           1       0.81      0.64      0.72       130

    accuracy                           0.81       348
   macro avg       0.81      0.78      0.79       348
weighted avg       0.81      0.81      0.80       348



100%|██████████| 29/29 [00:18<00:00,  1.56it/s]


Epoch 13/50 - Train Loss: 0.2839 - Train Fbeta: 0.7136
Valid Loss: 0.4708 - Valid Fbeta: 0.7317 - Valid MCC: 0.5421067495270164
              precision    recall  f1-score   support

           0       0.81      0.88      0.84       218
           1       0.76      0.65      0.70       130

    accuracy                           0.79       348
   macro avg       0.78      0.76      0.77       348
weighted avg       0.79      0.79      0.79       348



100%|██████████| 29/29 [00:18<00:00,  1.55it/s]


Epoch 14/50 - Train Loss: 0.2466 - Train Fbeta: 0.7297
Valid Loss: 0.5669 - Valid Fbeta: 0.7288 - Valid MCC: 0.5621353975412011
              precision    recall  f1-score   support

           0       0.83      0.84      0.84       218
           1       0.73      0.72      0.72       130

    accuracy                           0.80       348
   macro avg       0.78      0.78      0.78       348
weighted avg       0.80      0.80      0.80       348



100%|██████████| 29/29 [00:18<00:00,  1.56it/s]


Epoch 15/50 - Train Loss: 0.2175 - Train Fbeta: 0.7441
Valid Loss: 0.5394 - Valid Fbeta: 0.7576 - Valid MCC: 0.5880484571893909
              precision    recall  f1-score   support

           0       0.83      0.88      0.85       218
           1       0.78      0.69      0.73       130

    accuracy                           0.81       348
   macro avg       0.80      0.79      0.79       348
weighted avg       0.81      0.81      0.81       348



100%|██████████| 29/29 [00:18<00:00,  1.56it/s]


Epoch 16/50 - Train Loss: 0.1987 - Train Fbeta: 0.7572
Valid Loss: 0.6541 - Valid Fbeta: 0.7609 - Valid MCC: 0.594672991214502
              precision    recall  f1-score   support

           0       0.83      0.88      0.86       218
           1       0.78      0.70      0.74       130

    accuracy                           0.81       348
   macro avg       0.80      0.79      0.80       348
weighted avg       0.81      0.81      0.81       348



100%|██████████| 29/29 [00:18<00:00,  1.56it/s]


Epoch 17/50 - Train Loss: 0.1882 - Train Fbeta: 0.7686
Valid Loss: 0.5631 - Valid Fbeta: 0.7295 - Valid MCC: 0.5715920059035414
              precision    recall  f1-score   support

           0       0.84      0.83      0.84       218
           1       0.73      0.74      0.73       130

    accuracy                           0.80       348
   macro avg       0.78      0.79      0.79       348
weighted avg       0.80      0.80      0.80       348



100%|██████████| 29/29 [00:18<00:00,  1.56it/s]


Epoch 18/50 - Train Loss: 0.1287 - Train Fbeta: 0.7802
Valid Loss: 0.6270 - Valid Fbeta: 0.7776 - Valid MCC: 0.6198213933220056
              precision    recall  f1-score   support

           0       0.84      0.89      0.86       218
           1       0.79      0.72      0.75       130

    accuracy                           0.82       348
   macro avg       0.82      0.80      0.81       348
weighted avg       0.82      0.82      0.82       348



100%|██████████| 29/29 [00:18<00:00,  1.56it/s]


Epoch 19/50 - Train Loss: 0.1514 - Train Fbeta: 0.7896
Valid Loss: 0.6145 - Valid Fbeta: 0.7951 - Valid MCC: 0.6520219324369835
              precision    recall  f1-score   support

           0       0.86      0.89      0.87       218
           1       0.81      0.75      0.78       130

    accuracy                           0.84       348
   macro avg       0.83      0.82      0.83       348
weighted avg       0.84      0.84      0.84       348

Updated best model: epoch 19


100%|██████████| 29/29 [00:18<00:00,  1.56it/s]


Epoch 20/50 - Train Loss: 0.1169 - Train Fbeta: 0.7993
Valid Loss: 0.6980 - Valid Fbeta: 0.7632 - Valid MCC: 0.586520650811533
              precision    recall  f1-score   support

           0       0.82      0.89      0.86       218
           1       0.79      0.67      0.72       130

    accuracy                           0.81       348
   macro avg       0.81      0.78      0.79       348
weighted avg       0.81      0.81      0.81       348



100%|██████████| 29/29 [00:18<00:00,  1.56it/s]


Epoch 21/50 - Train Loss: 0.1068 - Train Fbeta: 0.8083
Valid Loss: 0.5690 - Valid Fbeta: 0.7524 - Valid MCC: 0.5904119745090552
              precision    recall  f1-score   support

           0       0.84      0.87      0.85       218
           1       0.76      0.72      0.74       130

    accuracy                           0.81       348
   macro avg       0.80      0.79      0.79       348
weighted avg       0.81      0.81      0.81       348



100%|██████████| 29/29 [00:18<00:00,  1.56it/s]


Epoch 22/50 - Train Loss: 0.0946 - Train Fbeta: 0.8164
Valid Loss: 0.6021 - Valid Fbeta: 0.7680 - Valid MCC: 0.6238305906610252
              precision    recall  f1-score   support

           0       0.86      0.87      0.86       218
           1       0.77      0.75      0.76       130

    accuracy                           0.82       348
   macro avg       0.81      0.81      0.81       348
weighted avg       0.82      0.82      0.82       348



100%|██████████| 29/29 [00:18<00:00,  1.56it/s]


Epoch 23/50 - Train Loss: 0.0905 - Train Fbeta: 0.8239
Valid Loss: 0.7211 - Valid Fbeta: 0.7623 - Valid MCC: 0.6020321477954169
              precision    recall  f1-score   support

           0       0.84      0.88      0.86       218
           1       0.78      0.72      0.74       130

    accuracy                           0.82       348
   macro avg       0.81      0.80      0.80       348
weighted avg       0.81      0.82      0.81       348



100%|██████████| 29/29 [00:18<00:00,  1.56it/s]


Epoch 24/50 - Train Loss: 0.0839 - Train Fbeta: 0.8310
Valid Loss: 0.6271 - Valid Fbeta: 0.7477 - Valid MCC: 0.5935605343269026
              precision    recall  f1-score   support

           0       0.85      0.85      0.85       218
           1       0.75      0.74      0.74       130

    accuracy                           0.81       348
   macro avg       0.80      0.80      0.80       348
weighted avg       0.81      0.81      0.81       348



100%|██████████| 29/29 [00:18<00:00,  1.56it/s]


Epoch 25/50 - Train Loss: 0.0622 - Train Fbeta: 0.8376
Valid Loss: 0.8093 - Valid Fbeta: 0.7538 - Valid MCC: 0.6070571630204658
              precision    recall  f1-score   support

           0       0.85      0.85      0.85       218
           1       0.75      0.75      0.75       130

    accuracy                           0.82       348
   macro avg       0.80      0.80      0.80       348
weighted avg       0.82      0.82      0.82       348



100%|██████████| 29/29 [00:18<00:00,  1.56it/s]


Epoch 26/50 - Train Loss: 0.0652 - Train Fbeta: 0.8437
Valid Loss: 0.7449 - Valid Fbeta: 0.7807 - Valid MCC: 0.6120726270862468
              precision    recall  f1-score   support

           0       0.83      0.90      0.86       218
           1       0.81      0.68      0.74       130

    accuracy                           0.82       348
   macro avg       0.82      0.79      0.80       348
weighted avg       0.82      0.82      0.82       348



100%|██████████| 29/29 [00:18<00:00,  1.56it/s]


Epoch 27/50 - Train Loss: 0.0756 - Train Fbeta: 0.8491
Valid Loss: 0.7047 - Valid Fbeta: 0.7462 - Valid MCC: 0.5947776993648554
              precision    recall  f1-score   support

           0       0.85      0.85      0.85       218
           1       0.75      0.75      0.75       130

    accuracy                           0.81       348
   macro avg       0.80      0.80      0.80       348
weighted avg       0.81      0.81      0.81       348



100%|██████████| 29/29 [00:18<00:00,  1.56it/s]


Epoch 28/50 - Train Loss: 0.0663 - Train Fbeta: 0.8544
Valid Loss: 0.7202 - Valid Fbeta: 0.7289 - Valid MCC: 0.5830629892314716
              precision    recall  f1-score   support

           0       0.86      0.82      0.84       218
           1       0.72      0.77      0.74       130

    accuracy                           0.80       348
   macro avg       0.79      0.80      0.79       348
weighted avg       0.81      0.80      0.80       348



100%|██████████| 29/29 [00:18<00:00,  1.56it/s]


Epoch 29/50 - Train Loss: 0.0655 - Train Fbeta: 0.8592
Valid Loss: 0.7032 - Valid Fbeta: 0.7807 - Valid MCC: 0.626390231144857
              precision    recall  f1-score   support

           0       0.84      0.89      0.87       218
           1       0.80      0.72      0.76       130

    accuracy                           0.83       348
   macro avg       0.82      0.81      0.81       348
weighted avg       0.83      0.83      0.83       348



100%|██████████| 29/29 [00:18<00:00,  1.56it/s]


Epoch 30/50 - Train Loss: 0.0542 - Train Fbeta: 0.8637
Valid Loss: 0.7522 - Valid Fbeta: 0.7907 - Valid MCC: 0.6533938973324865
              precision    recall  f1-score   support

           0       0.86      0.89      0.87       218
           1       0.80      0.76      0.78       130

    accuracy                           0.84       348
   macro avg       0.83      0.82      0.83       348
weighted avg       0.84      0.84      0.84       348



100%|██████████| 29/29 [00:18<00:00,  1.56it/s]


Epoch 31/50 - Train Loss: 0.0459 - Train Fbeta: 0.8680
Valid Loss: 0.7346 - Valid Fbeta: 0.7907 - Valid MCC: 0.6533938973324865
              precision    recall  f1-score   support

           0       0.86      0.89      0.87       218
           1       0.80      0.76      0.78       130

    accuracy                           0.84       348
   macro avg       0.83      0.82      0.83       348
weighted avg       0.84      0.84      0.84       348



100%|██████████| 29/29 [00:18<00:00,  1.56it/s]


Epoch 32/50 - Train Loss: 0.0500 - Train Fbeta: 0.8719
Valid Loss: 0.7218 - Valid Fbeta: 0.7872 - Valid MCC: 0.6250781322644678
              precision    recall  f1-score   support

           0       0.83      0.90      0.87       218
           1       0.81      0.70      0.75       130

    accuracy                           0.83       348
   macro avg       0.82      0.80      0.81       348
weighted avg       0.83      0.83      0.82       348



100%|██████████| 29/29 [00:18<00:00,  1.56it/s]


Epoch 33/50 - Train Loss: 0.0458 - Train Fbeta: 0.8756
Valid Loss: 0.7071 - Valid Fbeta: 0.7776 - Valid MCC: 0.6198213933220056
              precision    recall  f1-score   support

           0       0.84      0.89      0.86       218
           1       0.79      0.72      0.75       130

    accuracy                           0.82       348
   macro avg       0.82      0.80      0.81       348
weighted avg       0.82      0.82      0.82       348



100%|██████████| 29/29 [00:18<00:00,  1.56it/s]


Epoch 34/50 - Train Loss: 0.0788 - Train Fbeta: 0.8788
Valid Loss: 0.8252 - Valid Fbeta: 0.7890 - Valid MCC: 0.6058744386735473
              precision    recall  f1-score   support

           0       0.81      0.93      0.87       218
           1       0.84      0.64      0.72       130

    accuracy                           0.82       348
   macro avg       0.82      0.78      0.79       348
weighted avg       0.82      0.82      0.81       348



100%|██████████| 29/29 [00:18<00:00,  1.56it/s]


Epoch 35/50 - Train Loss: 0.0608 - Train Fbeta: 0.8820
Valid Loss: 0.8156 - Valid Fbeta: 0.8302 - Valid MCC: 0.6647811872400419
              precision    recall  f1-score   support

           0       0.83      0.94      0.88       218
           1       0.88      0.68      0.77       130

    accuracy                           0.84       348
   macro avg       0.86      0.81      0.82       348
weighted avg       0.85      0.84      0.84       348

Updated best model: epoch 35


100%|██████████| 29/29 [00:18<00:00,  1.56it/s]


Epoch 36/50 - Train Loss: 0.0502 - Train Fbeta: 0.8850
Valid Loss: 0.7428 - Valid Fbeta: 0.7797 - Valid MCC: 0.6343747961568701
              precision    recall  f1-score   support

           0       0.85      0.88      0.87       218
           1       0.79      0.75      0.77       130

    accuracy                           0.83       348
   macro avg       0.82      0.81      0.82       348
weighted avg       0.83      0.83      0.83       348



100%|██████████| 29/29 [00:18<00:00,  1.56it/s]


Epoch 37/50 - Train Loss: 0.0559 - Train Fbeta: 0.8877
Valid Loss: 0.6690 - Valid Fbeta: 0.7778 - Valid MCC: 0.6352258593817761
              precision    recall  f1-score   support

           0       0.86      0.88      0.87       218
           1       0.78      0.75      0.77       130

    accuracy                           0.83       348
   macro avg       0.82      0.81      0.82       348
weighted avg       0.83      0.83      0.83       348



100%|██████████| 29/29 [00:18<00:00,  1.56it/s]


Epoch 38/50 - Train Loss: 0.0471 - Train Fbeta: 0.8905
Valid Loss: 0.6277 - Valid Fbeta: 0.7778 - Valid MCC: 0.6352258593817761
              precision    recall  f1-score   support

           0       0.86      0.88      0.87       218
           1       0.78      0.75      0.77       130

    accuracy                           0.83       348
   macro avg       0.82      0.81      0.82       348
weighted avg       0.83      0.83      0.83       348



100%|██████████| 29/29 [00:18<00:00,  1.56it/s]


Epoch 39/50 - Train Loss: 0.0295 - Train Fbeta: 0.8933
Valid Loss: 0.7008 - Valid Fbeta: 0.7875 - Valid MCC: 0.6628506736389114
              precision    recall  f1-score   support

           0       0.88      0.87      0.87       218
           1       0.79      0.79      0.79       130

    accuracy                           0.84       348
   macro avg       0.83      0.83      0.83       348
weighted avg       0.84      0.84      0.84       348



100%|██████████| 29/29 [00:18<00:00,  1.56it/s]


Epoch 40/50 - Train Loss: 0.0353 - Train Fbeta: 0.8957
Valid Loss: 0.7914 - Valid Fbeta: 0.7776 - Valid MCC: 0.6198213933220056
              precision    recall  f1-score   support

           0       0.84      0.89      0.86       218
           1       0.79      0.72      0.75       130

    accuracy                           0.82       348
   macro avg       0.82      0.80      0.81       348
weighted avg       0.82      0.82      0.82       348



100%|██████████| 29/29 [00:18<00:00,  1.56it/s]


Epoch 41/50 - Train Loss: 0.0494 - Train Fbeta: 0.8980
Valid Loss: 0.8156 - Valid Fbeta: 0.8065 - Valid MCC: 0.6440186408764419
              precision    recall  f1-score   support

           0       0.83      0.92      0.88       218
           1       0.84      0.69      0.76       130

    accuracy                           0.84       348
   macro avg       0.84      0.81      0.82       348
weighted avg       0.84      0.84      0.83       348



100%|██████████| 29/29 [00:18<00:00,  1.56it/s]


Epoch 42/50 - Train Loss: 0.0336 - Train Fbeta: 0.9002
Valid Loss: 0.7186 - Valid Fbeta: 0.7899 - Valid MCC: 0.646091206233717
              precision    recall  f1-score   support

           0       0.85      0.89      0.87       218
           1       0.80      0.75      0.77       130

    accuracy                           0.84       348
   macro avg       0.83      0.82      0.82       348
weighted avg       0.83      0.84      0.83       348



100%|██████████| 29/29 [00:18<00:00,  1.56it/s]


Epoch 43/50 - Train Loss: 0.0403 - Train Fbeta: 0.9024
Valid Loss: 0.7618 - Valid Fbeta: 0.7907 - Valid MCC: 0.6533938973324865
              precision    recall  f1-score   support

           0       0.86      0.89      0.87       218
           1       0.80      0.76      0.78       130

    accuracy                           0.84       348
   macro avg       0.83      0.82      0.83       348
weighted avg       0.84      0.84      0.84       348



100%|██████████| 29/29 [00:18<00:00,  1.56it/s]


Epoch 44/50 - Train Loss: 0.0457 - Train Fbeta: 0.9043
Valid Loss: 0.6475 - Valid Fbeta: 0.7912 - Valid MCC: 0.6384526106627673
              precision    recall  f1-score   support

           0       0.84      0.90      0.87       218
           1       0.81      0.72      0.76       130

    accuracy                           0.83       348
   macro avg       0.83      0.81      0.82       348
weighted avg       0.83      0.83      0.83       348



100%|██████████| 29/29 [00:18<00:00,  1.56it/s]


Epoch 45/50 - Train Loss: 0.0256 - Train Fbeta: 0.9064
Valid Loss: 0.8115 - Valid Fbeta: 0.8158 - Valid MCC: 0.6631765796356743
              precision    recall  f1-score   support

           0       0.84      0.92      0.88       218
           1       0.85      0.72      0.78       130

    accuracy                           0.84       348
   macro avg       0.84      0.82      0.83       348
weighted avg       0.84      0.84      0.84       348



100%|██████████| 29/29 [00:18<00:00,  1.56it/s]


Epoch 46/50 - Train Loss: 0.0253 - Train Fbeta: 0.9083
Valid Loss: 0.7739 - Valid Fbeta: 0.7860 - Valid MCC: 0.6323955943757573
              precision    recall  f1-score   support

           0       0.84      0.89      0.87       218
           1       0.80      0.72      0.76       130

    accuracy                           0.83       348
   macro avg       0.82      0.81      0.81       348
weighted avg       0.83      0.83      0.83       348



100%|██████████| 29/29 [00:18<00:00,  1.56it/s]


Epoch 47/50 - Train Loss: 0.0209 - Train Fbeta: 0.9102
Valid Loss: 0.8711 - Valid Fbeta: 0.7787 - Valid MCC: 0.6270270401162003
              precision    recall  f1-score   support

           0       0.85      0.89      0.87       218
           1       0.79      0.73      0.76       130

    accuracy                           0.83       348
   macro avg       0.82      0.81      0.81       348
weighted avg       0.83      0.83      0.83       348



100%|██████████| 29/29 [00:18<00:00,  1.56it/s]


Epoch 48/50 - Train Loss: 0.0354 - Train Fbeta: 0.9119
Valid Loss: 0.7657 - Valid Fbeta: 0.7778 - Valid MCC: 0.6352258593817761
              precision    recall  f1-score   support

           0       0.86      0.88      0.87       218
           1       0.78      0.75      0.77       130

    accuracy                           0.83       348
   macro avg       0.82      0.81      0.82       348
weighted avg       0.83      0.83      0.83       348



100%|██████████| 29/29 [00:18<00:00,  1.56it/s]


Epoch 49/50 - Train Loss: 0.0168 - Train Fbeta: 0.9137
Valid Loss: 0.9510 - Valid Fbeta: 0.7673 - Valid MCC: 0.6079128800539787
              precision    recall  f1-score   support

           0       0.84      0.88      0.86       218
           1       0.78      0.72      0.75       130

    accuracy                           0.82       348
   macro avg       0.81      0.80      0.80       348
weighted avg       0.82      0.82      0.82       348



100%|██████████| 29/29 [00:18<00:00,  1.56it/s]


Epoch 50/50 - Train Loss: 0.0232 - Train Fbeta: 0.9154
Valid Loss: 0.7513 - Valid Fbeta: 0.7980 - Valid MCC: 0.6585642862164827
              precision    recall  f1-score   support

           0       0.86      0.89      0.88       218
           1       0.81      0.75      0.78       130

    accuracy                           0.84       348
   macro avg       0.83      0.82      0.83       348
weighted avg       0.84      0.84      0.84       348

Test model:
Valid Loss: 0.7006 - Valid Fbeta: 0.8163 - Valid MCC: 0.6376924519454447
              precision    recall  f1-score   support

           0       0.82      0.94      0.88       436
           1       0.88      0.64      0.74       261

    accuracy                           0.83       697
   macro avg       0.85      0.79      0.81       697
weighted avg       0.84      0.83      0.83       697

Test Loss: 0.7006 - Test Fbeta: 0.8163


In [8]:
convnext_model = ConvNeXtBaseModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(convnext_model.parameters(), lr=LEARNING_RATE)
model_path = "model_results/convnext_model_50.pt"

if os.path.isfile(model_path):
    print("Evaluate model")
    convnext_model.load_state_dict(torch.load(model_path, map_location=torch.device(device)))
    print("Validation data:")
    validate_model(convnext_model, validation_loader)
    print("Test data:")
    validate_model(convnext_model, test_loader)
else:
    print("Train model")
    model, train_loss_values, train_f_beta_values, validation_loss_values, validation_f_beta_values = train(convnext_model, model_path, train_loader, validation_loader, test_loader, criterion, optimizer, 50)
    np.save("model_results/convnext_model_50_train_loss", train_loss_values)
    np.save("model_results/convnext_model_50_train_f_beta", train_f_beta_values)
    np.save("model_results/convnext_model_50_valid_loss", validation_loss_values)
    np.save("model_results/convnext_model_50_valid_f_beta", validation_f_beta_values)



Train model


100%|██████████| 29/29 [00:56<00:00,  1.94s/it]


Epoch 1/50 - Train Loss: 0.8413 - Train Fbeta: 0.3498
Valid Loss: 0.6880 - Valid Fbeta: 0.5025 - Valid MCC: 0.22555228909385755
              precision    recall  f1-score   support

           0       0.68      0.88      0.77       218
           1       0.60      0.31      0.41       130

    accuracy                           0.66       348
   macro avg       0.64      0.59      0.59       348
weighted avg       0.65      0.66      0.63       348

Updated best model: epoch 1


100%|██████████| 29/29 [00:54<00:00,  1.89s/it]


Epoch 2/50 - Train Loss: 0.6696 - Train Fbeta: 0.3059


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Valid Loss: 0.6558 - Valid Fbeta: 0.0000 - Valid MCC: 0.0
              precision    recall  f1-score   support

           0       0.63      1.00      0.77       218
           1       0.00      0.00      0.00       130

    accuracy                           0.63       348
   macro avg       0.31      0.50      0.39       348
weighted avg       0.39      0.63      0.48       348



100%|██████████| 29/29 [00:54<00:00,  1.89s/it]


Epoch 3/50 - Train Loss: 0.6594 - Train Fbeta: 0.2581


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Valid Loss: 0.6644 - Valid Fbeta: 0.0000 - Valid MCC: 0.0
              precision    recall  f1-score   support

           0       0.63      1.00      0.77       218
           1       0.00      0.00      0.00       130

    accuracy                           0.63       348
   macro avg       0.31      0.50      0.39       348
weighted avg       0.39      0.63      0.48       348



100%|██████████| 29/29 [00:54<00:00,  1.89s/it]


Epoch 4/50 - Train Loss: 0.6569 - Train Fbeta: 0.2563
Valid Loss: 0.6507 - Valid Fbeta: 0.5442 - Valid MCC: 0.3074085854004641
              precision    recall  f1-score   support

           0       0.68      0.96      0.80       218
           1       0.78      0.25      0.37       130

    accuracy                           0.69       348
   macro avg       0.73      0.60      0.59       348
weighted avg       0.72      0.69      0.64       348

Updated best model: epoch 4


100%|██████████| 29/29 [00:54<00:00,  1.90s/it]


Epoch 5/50 - Train Loss: 0.6508 - Train Fbeta: 0.2803
Valid Loss: 0.6424 - Valid Fbeta: 0.5164 - Valid MCC: 0.23745090392122628
              precision    recall  f1-score   support

           0       0.69      0.86      0.76       218
           1       0.59      0.34      0.43       130

    accuracy                           0.67       348
   macro avg       0.64      0.60      0.60       348
weighted avg       0.65      0.67      0.64       348



100%|██████████| 29/29 [00:55<00:00,  1.90s/it]


Epoch 6/50 - Train Loss: 0.6364 - Train Fbeta: 0.3192
Valid Loss: 0.6182 - Valid Fbeta: 0.4972 - Valid MCC: 0.2284617641955976
              precision    recall  f1-score   support

           0       0.68      0.90      0.77       218
           1       0.62      0.28      0.38       130

    accuracy                           0.67       348
   macro avg       0.65      0.59      0.58       348
weighted avg       0.66      0.67      0.63       348



100%|██████████| 29/29 [00:55<00:00,  1.90s/it]


Epoch 7/50 - Train Loss: 0.6298 - Train Fbeta: 0.3493
Valid Loss: 0.6497 - Valid Fbeta: 0.5150 - Valid MCC: 0.2260296359403133
              precision    recall  f1-score   support

           0       0.70      0.79      0.74       218
           1       0.54      0.42      0.48       130

    accuracy                           0.65       348
   macro avg       0.62      0.61      0.61       348
weighted avg       0.64      0.65      0.64       348



100%|██████████| 29/29 [00:55<00:00,  1.90s/it]


Epoch 8/50 - Train Loss: 0.6248 - Train Fbeta: 0.3725
Valid Loss: 0.6119 - Valid Fbeta: 0.5000 - Valid MCC: 0.23537097544001684
              precision    recall  f1-score   support

           0       0.68      0.91      0.77       218
           1       0.64      0.27      0.38       130

    accuracy                           0.67       348
   macro avg       0.66      0.59      0.58       348
weighted avg       0.66      0.67      0.63       348



100%|██████████| 29/29 [00:55<00:00,  1.91s/it]


Epoch 9/50 - Train Loss: 0.6377 - Train Fbeta: 0.3864
Valid Loss: 0.6276 - Valid Fbeta: 0.5105 - Valid MCC: 0.2385893564471477
              precision    recall  f1-score   support

           0       0.68      0.89      0.77       218
           1       0.62      0.30      0.40       130

    accuracy                           0.67       348
   macro avg       0.65      0.59      0.59       348
weighted avg       0.66      0.67      0.63       348



100%|██████████| 29/29 [00:55<00:00,  1.90s/it]


Epoch 10/50 - Train Loss: 0.6217 - Train Fbeta: 0.3995
Valid Loss: 0.6320 - Valid Fbeta: 0.4124 - Valid MCC: 0.28428079423731123
              precision    recall  f1-score   support

           0       0.66      1.00      0.79       218
           1       1.00      0.12      0.22       130

    accuracy                           0.67       348
   macro avg       0.83      0.56      0.51       348
weighted avg       0.78      0.67      0.58       348



100%|██████████| 29/29 [00:55<00:00,  1.90s/it]


Epoch 11/50 - Train Loss: 0.6140 - Train Fbeta: 0.4093
Valid Loss: 0.5927 - Valid Fbeta: 0.5117 - Valid MCC: 0.2513082849237556
              precision    recall  f1-score   support

           0       0.68      0.92      0.78       218
           1       0.66      0.27      0.38       130

    accuracy                           0.68       348
   macro avg       0.67      0.59      0.58       348
weighted avg       0.67      0.68      0.63       348



100%|██████████| 29/29 [00:55<00:00,  1.90s/it]


Epoch 12/50 - Train Loss: 0.6019 - Train Fbeta: 0.4183
Valid Loss: 0.5928 - Valid Fbeta: 0.5272 - Valid MCC: 0.2889831310912389
              precision    recall  f1-score   support

           0       0.68      0.95      0.79       218
           1       0.76      0.24      0.36       130

    accuracy                           0.69       348
   macro avg       0.72      0.60      0.58       348
weighted avg       0.71      0.69      0.63       348



100%|██████████| 29/29 [00:55<00:00,  1.90s/it]


Epoch 13/50 - Train Loss: 0.5867 - Train Fbeta: 0.4308
Valid Loss: 0.7335 - Valid Fbeta: 0.4899 - Valid MCC: 0.18546346069880948
              precision    recall  f1-score   support

           0       0.70      0.67      0.68       218
           1       0.48      0.52      0.50       130

    accuracy                           0.61       348
   macro avg       0.59      0.59      0.59       348
weighted avg       0.62      0.61      0.61       348



100%|██████████| 29/29 [00:55<00:00,  1.90s/it]


Epoch 14/50 - Train Loss: 0.6305 - Train Fbeta: 0.4354
Valid Loss: 0.6238 - Valid Fbeta: 0.5105 - Valid MCC: 0.2385893564471477
              precision    recall  f1-score   support

           0       0.68      0.89      0.77       218
           1       0.62      0.30      0.40       130

    accuracy                           0.67       348
   macro avg       0.65      0.59      0.59       348
weighted avg       0.66      0.67      0.63       348



100%|██████████| 29/29 [00:55<00:00,  1.90s/it]


Epoch 15/50 - Train Loss: 0.6055 - Train Fbeta: 0.4427
Valid Loss: 0.5963 - Valid Fbeta: 0.5000 - Valid MCC: 0.22433123905083396
              precision    recall  f1-score   support

           0       0.68      0.88      0.77       218
           1       0.60      0.30      0.40       130

    accuracy                           0.66       348
   macro avg       0.64      0.59      0.58       348
weighted avg       0.65      0.66      0.63       348



100%|██████████| 29/29 [00:55<00:00,  1.90s/it]


Epoch 16/50 - Train Loss: 0.5625 - Train Fbeta: 0.4568
Valid Loss: 0.5837 - Valid Fbeta: 0.5658 - Valid MCC: 0.29617682362458203
              precision    recall  f1-score   support

           0       0.71      0.84      0.77       218
           1       0.62      0.42      0.50       130

    accuracy                           0.69       348
   macro avg       0.66      0.63      0.64       348
weighted avg       0.68      0.69      0.67       348

Updated best model: epoch 16


100%|██████████| 29/29 [00:55<00:00,  1.90s/it]


Epoch 17/50 - Train Loss: 0.5603 - Train Fbeta: 0.4687
Valid Loss: 0.5772 - Valid Fbeta: 0.5645 - Valid MCC: 0.2963119112665602
              precision    recall  f1-score   support

           0       0.70      0.88      0.78       218
           1       0.64      0.38      0.48       130

    accuracy                           0.69       348
   macro avg       0.67      0.63      0.63       348
weighted avg       0.68      0.69      0.67       348



100%|██████████| 29/29 [00:55<00:00,  1.90s/it]


Epoch 18/50 - Train Loss: 0.5273 - Train Fbeta: 0.4852
Valid Loss: 0.5572 - Valid Fbeta: 0.6643 - Valid MCC: 0.41976350026354964
              precision    recall  f1-score   support

           0       0.73      0.93      0.82       218
           1       0.77      0.42      0.55       130

    accuracy                           0.74       348
   macro avg       0.75      0.67      0.68       348
weighted avg       0.75      0.74      0.72       348

Updated best model: epoch 18


100%|██████████| 29/29 [00:55<00:00,  1.90s/it]


Epoch 19/50 - Train Loss: 0.4998 - Train Fbeta: 0.5029
Valid Loss: 0.5191 - Valid Fbeta: 0.6702 - Valid MCC: 0.4320182253790459
              precision    recall  f1-score   support

           0       0.75      0.90      0.81       218
           1       0.74      0.48      0.59       130

    accuracy                           0.74       348
   macro avg       0.74      0.69      0.70       348
weighted avg       0.74      0.74      0.73       348

Updated best model: epoch 19


100%|██████████| 29/29 [00:55<00:00,  1.90s/it]


Epoch 20/50 - Train Loss: 0.4692 - Train Fbeta: 0.5214
Valid Loss: 0.5278 - Valid Fbeta: 0.6627 - Valid MCC: 0.4260833872229191
              precision    recall  f1-score   support

           0       0.75      0.88      0.81       218
           1       0.72      0.51      0.59       130

    accuracy                           0.74       348
   macro avg       0.73      0.69      0.70       348
weighted avg       0.74      0.74      0.73       348



100%|██████████| 29/29 [00:54<00:00,  1.90s/it]


Epoch 21/50 - Train Loss: 0.4713 - Train Fbeta: 0.5378
Valid Loss: 0.5337 - Valid Fbeta: 0.6268 - Valid MCC: 0.4112207574151256
              precision    recall  f1-score   support

           0       0.79      0.76      0.77       218
           1       0.62      0.65      0.64       130

    accuracy                           0.72       348
   macro avg       0.70      0.71      0.71       348
weighted avg       0.72      0.72      0.72       348



100%|██████████| 29/29 [00:55<00:00,  1.90s/it]


Epoch 22/50 - Train Loss: 0.4528 - Train Fbeta: 0.5521
Valid Loss: 0.5689 - Valid Fbeta: 0.7277 - Valid MCC: 0.4987637047228007
              precision    recall  f1-score   support

           0       0.75      0.94      0.84       218
           1       0.84      0.48      0.61       130

    accuracy                           0.77       348
   macro avg       0.79      0.71      0.72       348
weighted avg       0.78      0.77      0.75       348

Updated best model: epoch 22


100%|██████████| 29/29 [00:55<00:00,  1.90s/it]


Epoch 23/50 - Train Loss: 0.4427 - Train Fbeta: 0.5651
Valid Loss: 0.5137 - Valid Fbeta: 0.7034 - Valid MCC: 0.48737820151071787
              precision    recall  f1-score   support

           0       0.78      0.89      0.83       218
           1       0.75      0.57      0.65       130

    accuracy                           0.77       348
   macro avg       0.76      0.73      0.74       348
weighted avg       0.76      0.77      0.76       348



100%|██████████| 29/29 [00:55<00:00,  1.90s/it]


Epoch 24/50 - Train Loss: 0.4237 - Train Fbeta: 0.5769
Valid Loss: 0.5066 - Valid Fbeta: 0.7468 - Valid MCC: 0.5295844769837367
              precision    recall  f1-score   support

           0       0.77      0.94      0.84       218
           1       0.83      0.53      0.65       130

    accuracy                           0.78       348
   macro avg       0.80      0.73      0.75       348
weighted avg       0.79      0.78      0.77       348

Updated best model: epoch 24


100%|██████████| 29/29 [00:55<00:00,  1.90s/it]


Epoch 25/50 - Train Loss: 0.4061 - Train Fbeta: 0.5887
Valid Loss: 0.4845 - Valid Fbeta: 0.7167 - Valid MCC: 0.5241601211475616
              precision    recall  f1-score   support

           0       0.80      0.86      0.83       218
           1       0.74      0.65      0.69       130

    accuracy                           0.78       348
   macro avg       0.77      0.75      0.76       348
weighted avg       0.78      0.78      0.78       348



100%|██████████| 29/29 [00:55<00:00,  1.90s/it]


Epoch 26/50 - Train Loss: 0.3853 - Train Fbeta: 0.5998
Valid Loss: 0.4700 - Valid Fbeta: 0.7380 - Valid MCC: 0.5403389438842514
              precision    recall  f1-score   support

           0       0.80      0.89      0.84       218
           1       0.78      0.62      0.69       130

    accuracy                           0.79       348
   macro avg       0.79      0.75      0.76       348
weighted avg       0.79      0.79      0.78       348



100%|██████████| 29/29 [00:54<00:00,  1.90s/it]


Epoch 27/50 - Train Loss: 0.3679 - Train Fbeta: 0.6107
Valid Loss: 0.4871 - Valid Fbeta: 0.7186 - Valid MCC: 0.5002202166459883
              precision    recall  f1-score   support

           0       0.77      0.91      0.83       218
           1       0.78      0.55      0.64       130

    accuracy                           0.77       348
   macro avg       0.78      0.73      0.74       348
weighted avg       0.77      0.77      0.76       348



100%|██████████| 29/29 [00:55<00:00,  1.90s/it]


Epoch 28/50 - Train Loss: 0.3576 - Train Fbeta: 0.6214
Valid Loss: 0.4874 - Valid Fbeta: 0.8000 - Valid MCC: 0.5945702060048522
              precision    recall  f1-score   support

           0       0.78      0.96      0.86       218
           1       0.90      0.55      0.69       130

    accuracy                           0.81       348
   macro avg       0.84      0.76      0.77       348
weighted avg       0.83      0.81      0.80       348

Updated best model: epoch 28


100%|██████████| 29/29 [00:55<00:00,  1.90s/it]


Epoch 29/50 - Train Loss: 0.3571 - Train Fbeta: 0.6309
Valid Loss: 0.4726 - Valid Fbeta: 0.7117 - Valid MCC: 0.5091686605134443
              precision    recall  f1-score   support

           0       0.79      0.87      0.83       218
           1       0.74      0.62      0.67       130

    accuracy                           0.78       348
   macro avg       0.77      0.74      0.75       348
weighted avg       0.77      0.78      0.77       348



100%|██████████| 29/29 [00:54<00:00,  1.90s/it]


Epoch 30/50 - Train Loss: 0.3348 - Train Fbeta: 0.6403
Valid Loss: 0.4858 - Valid Fbeta: 0.7029 - Valid MCC: 0.5169553287644766
              precision    recall  f1-score   support

           0       0.81      0.83      0.82       218
           1       0.71      0.68      0.69       130

    accuracy                           0.78       348
   macro avg       0.76      0.76      0.76       348
weighted avg       0.77      0.78      0.77       348



100%|██████████| 29/29 [00:55<00:00,  1.90s/it]


Epoch 31/50 - Train Loss: 0.3154 - Train Fbeta: 0.6496
Valid Loss: 0.4404 - Valid Fbeta: 0.7576 - Valid MCC: 0.5880484571893909
              precision    recall  f1-score   support

           0       0.83      0.88      0.85       218
           1       0.78      0.69      0.73       130

    accuracy                           0.81       348
   macro avg       0.80      0.79      0.79       348
weighted avg       0.81      0.81      0.81       348



100%|██████████| 29/29 [00:55<00:00,  1.90s/it]


Epoch 32/50 - Train Loss: 0.3052 - Train Fbeta: 0.6580
Valid Loss: 0.4525 - Valid Fbeta: 0.7509 - Valid MCC: 0.5747867440061105
              precision    recall  f1-score   support

           0       0.82      0.88      0.85       218
           1       0.77      0.68      0.72       130

    accuracy                           0.80       348
   macro avg       0.80      0.78      0.79       348
weighted avg       0.80      0.80      0.80       348



100%|██████████| 29/29 [00:54<00:00,  1.90s/it]


Epoch 33/50 - Train Loss: 0.2969 - Train Fbeta: 0.6663
Valid Loss: 0.4395 - Valid Fbeta: 0.7854 - Valid MCC: 0.5994654530288472
              precision    recall  f1-score   support

           0       0.81      0.93      0.86       218
           1       0.84      0.63      0.72       130

    accuracy                           0.82       348
   macro avg       0.82      0.78      0.79       348
weighted avg       0.82      0.82      0.81       348



100%|██████████| 29/29 [00:55<00:00,  1.90s/it]


Epoch 34/50 - Train Loss: 0.2839 - Train Fbeta: 0.6739
Valid Loss: 0.5949 - Valid Fbeta: 0.6693 - Valid MCC: 0.5127553371966818
              precision    recall  f1-score   support

           0       0.85      0.74      0.79       218
           1       0.65      0.78      0.71       130

    accuracy                           0.76       348
   macro avg       0.75      0.76      0.75       348
weighted avg       0.78      0.76      0.76       348



100%|██████████| 29/29 [00:55<00:00,  1.90s/it]


Epoch 35/50 - Train Loss: 0.2998 - Train Fbeta: 0.6808
Valid Loss: 0.4702 - Valid Fbeta: 0.8039 - Valid MCC: 0.6201797046499385
              precision    recall  f1-score   support

           0       0.81      0.94      0.87       218
           1       0.86      0.63      0.73       130

    accuracy                           0.82       348
   macro avg       0.84      0.79      0.80       348
weighted avg       0.83      0.82      0.82       348

Updated best model: epoch 35


100%|██████████| 29/29 [00:55<00:00,  1.90s/it]


Epoch 36/50 - Train Loss: 0.2612 - Train Fbeta: 0.6879
Valid Loss: 0.4853 - Valid Fbeta: 0.7686 - Valid MCC: 0.6153105890423989
              precision    recall  f1-score   support

           0       0.85      0.88      0.86       218
           1       0.78      0.73      0.75       130

    accuracy                           0.82       348
   macro avg       0.81      0.80      0.81       348
weighted avg       0.82      0.82      0.82       348



100%|██████████| 29/29 [00:55<00:00,  1.90s/it]


Epoch 37/50 - Train Loss: 0.2418 - Train Fbeta: 0.6949
Valid Loss: 0.4797 - Valid Fbeta: 0.7692 - Valid MCC: 0.5858952332459788
              precision    recall  f1-score   support

           0       0.81      0.91      0.86       218
           1       0.81      0.65      0.72       130

    accuracy                           0.81       348
   macro avg       0.81      0.78      0.79       348
weighted avg       0.81      0.81      0.81       348



100%|██████████| 29/29 [00:55<00:00,  1.90s/it]


Epoch 38/50 - Train Loss: 0.2833 - Train Fbeta: 0.7006
Valid Loss: 0.4781 - Valid Fbeta: 0.7977 - Valid MCC: 0.6131991576775931
              precision    recall  f1-score   support

           0       0.81      0.94      0.87       218
           1       0.85      0.63      0.73       130

    accuracy                           0.82       348
   macro avg       0.83      0.78      0.80       348
weighted avg       0.83      0.82      0.81       348



100%|██████████| 29/29 [00:55<00:00,  1.90s/it]


Epoch 39/50 - Train Loss: 0.2324 - Train Fbeta: 0.7072
Valid Loss: 0.5030 - Valid Fbeta: 0.7904 - Valid MCC: 0.6315754413431673
              precision    recall  f1-score   support

           0       0.84      0.90      0.87       218
           1       0.81      0.71      0.76       130

    accuracy                           0.83       348
   macro avg       0.83      0.81      0.81       348
weighted avg       0.83      0.83      0.83       348



100%|██████████| 29/29 [00:55<00:00,  1.90s/it]


Epoch 40/50 - Train Loss: 0.2080 - Train Fbeta: 0.7137
Valid Loss: 0.5247 - Valid Fbeta: 0.7367 - Valid MCC: 0.5744744361651659
              precision    recall  f1-score   support

           0       0.84      0.85      0.84       218
           1       0.74      0.72      0.73       130

    accuracy                           0.80       348
   macro avg       0.79      0.79      0.79       348
weighted avg       0.80      0.80      0.80       348



100%|██████████| 29/29 [00:55<00:00,  1.90s/it]


Epoch 41/50 - Train Loss: 0.1863 - Train Fbeta: 0.7200
Valid Loss: 0.5191 - Valid Fbeta: 0.7258 - Valid MCC: 0.576169900128941
              precision    recall  f1-score   support

           0       0.85      0.82      0.84       218
           1       0.72      0.76      0.74       130

    accuracy                           0.80       348
   macro avg       0.78      0.79      0.79       348
weighted avg       0.80      0.80      0.80       348



100%|██████████| 29/29 [00:54<00:00,  1.90s/it]


Epoch 42/50 - Train Loss: 0.1560 - Train Fbeta: 0.7265
Valid Loss: 0.5244 - Valid Fbeta: 0.7692 - Valid MCC: 0.6072471922682537
              precision    recall  f1-score   support

           0       0.84      0.89      0.86       218
           1       0.79      0.71      0.74       130

    accuracy                           0.82       348
   macro avg       0.81      0.80      0.80       348
weighted avg       0.82      0.82      0.82       348



100%|██████████| 29/29 [00:55<00:00,  1.90s/it]


Epoch 43/50 - Train Loss: 0.1625 - Train Fbeta: 0.7328
Valid Loss: 0.5613 - Valid Fbeta: 0.7850 - Valid MCC: 0.6254133668646593
              precision    recall  f1-score   support

           0       0.84      0.90      0.87       218
           1       0.81      0.71      0.75       130

    accuracy                           0.83       348
   macro avg       0.82      0.80      0.81       348
weighted avg       0.83      0.83      0.82       348



100%|██████████| 29/29 [00:55<00:00,  1.90s/it]


Epoch 44/50 - Train Loss: 0.1362 - Train Fbeta: 0.7390
Valid Loss: 0.5701 - Valid Fbeta: 0.7541 - Valid MCC: 0.5895347016350253
              precision    recall  f1-score   support

           0       0.83      0.87      0.85       218
           1       0.77      0.71      0.74       130

    accuracy                           0.81       348
   macro avg       0.80      0.79      0.79       348
weighted avg       0.81      0.81      0.81       348



100%|██████████| 29/29 [00:55<00:00,  1.90s/it]


Epoch 45/50 - Train Loss: 0.1238 - Train Fbeta: 0.7450
Valid Loss: 0.6738 - Valid Fbeta: 0.7196 - Valid MCC: 0.5623940145371064
              precision    recall  f1-score   support

           0       0.84      0.82      0.83       218
           1       0.71      0.75      0.73       130

    accuracy                           0.79       348
   macro avg       0.78      0.78      0.78       348
weighted avg       0.80      0.79      0.79       348



100%|██████████| 29/29 [00:55<00:00,  1.90s/it]


Epoch 46/50 - Train Loss: 0.0899 - Train Fbeta: 0.7511
Valid Loss: 0.7243 - Valid Fbeta: 0.7302 - Valid MCC: 0.5609370952735007
              precision    recall  f1-score   support

           0       0.83      0.85      0.84       218
           1       0.74      0.71      0.72       130

    accuracy                           0.80       348
   macro avg       0.78      0.78      0.78       348
weighted avg       0.79      0.80      0.80       348



100%|██████████| 29/29 [00:55<00:00,  1.90s/it]


Epoch 47/50 - Train Loss: 0.1215 - Train Fbeta: 0.7563
Valid Loss: 0.6685 - Valid Fbeta: 0.7295 - Valid MCC: 0.5934839334987331
              precision    recall  f1-score   support

           0       0.87      0.81      0.84       218
           1       0.72      0.79      0.75       130

    accuracy                           0.80       348
   macro avg       0.79      0.80      0.80       348
weighted avg       0.81      0.80      0.81       348



100%|██████████| 29/29 [00:55<00:00,  1.90s/it]


Epoch 48/50 - Train Loss: 0.1125 - Train Fbeta: 0.7615
Valid Loss: 0.7604 - Valid Fbeta: 0.7632 - Valid MCC: 0.586520650811533
              precision    recall  f1-score   support

           0       0.82      0.89      0.86       218
           1       0.79      0.67      0.72       130

    accuracy                           0.81       348
   macro avg       0.81      0.78      0.79       348
weighted avg       0.81      0.81      0.81       348



100%|██████████| 29/29 [00:55<00:00,  1.90s/it]


Epoch 49/50 - Train Loss: 0.1112 - Train Fbeta: 0.7664
Valid Loss: 0.7442 - Valid Fbeta: 0.7840 - Valid MCC: 0.6185774219369887
              precision    recall  f1-score   support

           0       0.83      0.90      0.87       218
           1       0.81      0.69      0.75       130

    accuracy                           0.82       348
   macro avg       0.82      0.80      0.81       348
weighted avg       0.82      0.82      0.82       348



100%|██████████| 29/29 [00:55<00:00,  1.90s/it]


Epoch 50/50 - Train Loss: 0.0870 - Train Fbeta: 0.7715
Valid Loss: 0.8488 - Valid Fbeta: 0.6972 - Valid MCC: 0.5401922268039268
              precision    recall  f1-score   support

           0       0.85      0.79      0.82       218
           1       0.68      0.76      0.72       130

    accuracy                           0.78       348
   macro avg       0.77      0.78      0.77       348
weighted avg       0.79      0.78      0.78       348

Test model:
Valid Loss: 0.5655 - Valid Fbeta: 0.7274 - Valid MCC: 0.5333354092181162
              precision    recall  f1-score   support

           0       0.80      0.88      0.84       436
           1       0.75      0.64      0.69       261

    accuracy                           0.79       697
   macro avg       0.78      0.76      0.76       697
weighted avg       0.78      0.79      0.78       697

Test Loss: 0.5655 - Test Fbeta: 0.7274
