# Bio_ClinicalBERT

**Description**
> The Bio_ClinicalBERT model was trained on all notes from MIMIC III, a database containing electronic health records from ICU patients at the Beth Israel Hospital in Boston, MA. For more details on MIMIC, see here. All notes from the NOTEEVENTS table were included (~880M words).

**Link**
> https://huggingface.co/emilyalsentzer/Bio_ClinicalBERT

# 1. Setup

In [1]:
!pip install transformers --quiet

In [2]:
import re
import numpy as np
import pandas as pd
import torch
import json

from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score, accuracy_score, precision_score, recall_score, f1_score, hamming_loss, multilabel_confusion_matrix
from sklearn.utils.class_weight import compute_class_weight
import matplotlib.pyplot as plt

from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from imblearn.under_sampling import RandomUnderSampler
from torch.optim import AdamW
from torch import nn
from tqdm import tqdm
from datetime import date

from transformers import logging
logging.set_verbosity_error()

In [4]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [5]:
pd.options.display.max_columns = 999

In [6]:
def return_column_values_sum_and_percentage(dataframe_input, column_input):
    total_sum = dataframe_input[column_input].sum()
    percentages = dataframe_input[column_input] / total_sum
    sums_percentages = pd.DataFrame({
        'sum': dataframe_input[column_input],
        'percentage': percentages
    })
    sums_percentages['cumsum_percentage'] = sums_percentages['percentage'].cumsum()
    sums_percentages['sum'] = sums_percentages['sum'].apply(lambda x: "{:,}".format(x))
    sums_percentages['percentage'] = sums_percentages['percentage'].mul(100).round(1).astype(str) + '%'
    sums_percentages['cumsum_percentage'] = sums_percentages['cumsum_percentage'].mul(100).round(1).astype(str) + '%'
    return sums_percentages

# 2. Data Acquisition

In [7]:
# data directory (GDrive)
data_dir_processed = '/content/drive/MyDrive/DATASCI_210/data/processed/'

In [8]:
# Define the names of the JSON files
train_df_filename = "train_set__chexpert__6_findings__unbalanced.json"
test_df_filename = "test_set__chexpert__6_findings__unbalanced.json"
valid_df_filename = "validation_set__chexpert__6_findings__unbalanced.json"

train_balanced_df_filename = "train_set__chexpert__6_findings__balanced.json"

# Load the JSON files into a DataFrame
with open(data_dir_processed + train_df_filename) as train_chexpert_json_file:
    train_chexpert_dict_file = json.load(train_chexpert_json_file)

with open(data_dir_processed + test_df_filename) as test_chexpert_json_file:
    test_chexpert_dict_file = json.load(test_chexpert_json_file)

with open(data_dir_processed + valid_df_filename) as valid_chexpert_json_file:
    valid_chexpert_dict_file = json.load(valid_chexpert_json_file)

with open(data_dir_processed + train_balanced_df_filename) as train_balanced_chexpert_json_file:
    train_balanced_chexpert_dict_file = json.load(train_balanced_chexpert_json_file)

# Converting json dataset from dictionary to dataframe
train_df = pd.DataFrame.from_dict(train_chexpert_dict_file)
train_df.reset_index(drop=True, inplace=True)

test_df = pd.DataFrame.from_dict(test_chexpert_dict_file)
test_df.reset_index(drop=True, inplace=True)

valid_df = pd.DataFrame.from_dict(valid_chexpert_dict_file)
valid_df.reset_index(drop=True, inplace=True)

train_balanced_df = pd.DataFrame.from_dict(train_balanced_chexpert_dict_file)
train_balanced_df.reset_index(drop=True, inplace=True)

In [9]:
return_column_values_sum_and_percentage(train_df.groupby("finding_names").agg({"study_id": "count"}).sort_values("study_id", ascending=False), "study_id")

Unnamed: 0_level_0,sum,percentage,cumsum_percentage
finding_names,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
no_finding,5244,59.4%,59.4%
lung_opacity,792,9.0%,68.4%
cardiomegaly,438,5.0%,73.4%
atelectasis,425,4.8%,78.2%
pleural_effusion,288,3.3%,81.5%
pneumonia,246,2.8%,84.2%
"atelectasis, lung_opacity",219,2.5%,86.7%
edema,195,2.2%,88.9%
"atelectasis, pleural_effusion",181,2.1%,91.0%
"lung_opacity, pneumonia",154,1.7%,92.7%


In [10]:
return_column_values_sum_and_percentage(train_balanced_df.groupby("finding_names").agg({"study_id": "count"}).sort_values("study_id", ascending=False), "study_id")

Unnamed: 0_level_0,sum,percentage,cumsum_percentage
finding_names,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
no_finding,960,28.3%,28.3%
cardiomegaly,294,8.7%,37.0%
lung_opacity,252,7.4%,44.4%
pneumonia,233,6.9%,51.3%
atelectasis,231,6.8%,58.1%
edema,194,5.7%,63.9%
pleural_effusion,182,5.4%,69.2%
"lung_opacity, pneumonia",162,4.8%,74.0%
"atelectasis, lung_opacity",153,4.5%,78.5%
"atelectasis, pleural_effusion",136,4.0%,82.5%


In [11]:
# Reset the index
train_df.reset_index(drop=True, inplace=True)
valid_df.reset_index(drop=True, inplace=True)
test_df.reset_index(drop=True, inplace=True)

train_balanced_df.reset_index(drop=True, inplace=True)

In [12]:
train_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 8823 entries, 0 to 8822
Data columns (total 26 columns):
 #   Column                                Non-Null Count  Dtype  
---  ------                                --------------  -----  
 0   patient_id                            8823 non-null   float64
 1   visit_id                              8823 non-null   int64  
 2   study_id                              8823 non-null   float64
 3   temperature                           8823 non-null   float64
 4   heartrate                             8823 non-null   float64
 5   resprate                              8823 non-null   float64
 6   o2sat                                 8823 non-null   float64
 7   sbp                                   8823 non-null   float64
 8   dbp                                   8823 non-null   float64
 9   pain                                  8823 non-null   int64  
 10  acuity                                8823 non-null   float64
 11  positive_label_to

In [13]:
valid_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2206 entries, 0 to 2205
Data columns (total 26 columns):
 #   Column                                Non-Null Count  Dtype  
---  ------                                --------------  -----  
 0   patient_id                            2206 non-null   float64
 1   visit_id                              2206 non-null   int64  
 2   study_id                              2206 non-null   float64
 3   temperature                           2206 non-null   float64
 4   heartrate                             2206 non-null   float64
 5   resprate                              2206 non-null   float64
 6   o2sat                                 2206 non-null   float64
 7   sbp                                   2206 non-null   float64
 8   dbp                                   2206 non-null   float64
 9   pain                                  2206 non-null   int64  
 10  acuity                                2206 non-null   float64
 11  positive_label_to

In [14]:
test_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2206 entries, 0 to 2205
Data columns (total 26 columns):
 #   Column                                Non-Null Count  Dtype  
---  ------                                --------------  -----  
 0   patient_id                            2206 non-null   float64
 1   visit_id                              2206 non-null   int64  
 2   study_id                              2206 non-null   float64
 3   temperature                           2206 non-null   float64
 4   heartrate                             2206 non-null   float64
 5   resprate                              2206 non-null   float64
 6   o2sat                                 2206 non-null   float64
 7   sbp                                   2206 non-null   float64
 8   dbp                                   2206 non-null   float64
 9   pain                                  2206 non-null   int64  
 10  acuity                                2206 non-null   float64
 11  positive_label_to

In [15]:
train_balanced_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 3389 entries, 0 to 3388
Data columns (total 26 columns):
 #   Column                                Non-Null Count  Dtype  
---  ------                                --------------  -----  
 0   patient_id                            3389 non-null   float64
 1   visit_id                              3389 non-null   int64  
 2   study_id                              3389 non-null   float64
 3   temperature                           3389 non-null   float64
 4   heartrate                             3389 non-null   float64
 5   resprate                              3389 non-null   float64
 6   o2sat                                 3389 non-null   float64
 7   sbp                                   3389 non-null   float64
 8   dbp                                   3389 non-null   float64
 9   pain                                  3389 non-null   int64  
 10  acuity                                3389 non-null   float64
 11  positive_label_to

# 3. Data Preprocessing

In [16]:
# Target variables
target_variables = [
    'atelectasis',
    'cardiomegaly',
    'edema',
    'lung_opacity',
    'pleural_effusion',
    'pneumonia',
]

In [17]:
train_df.loc[:, "history_of_present_illness"] = train_df["history_of_present_illness"].str.replace("___", "[UNK]")
valid_df.loc[:, "history_of_present_illness"] = valid_df["history_of_present_illness"].str.replace("___", "[UNK]")
test_df.loc[:, "history_of_present_illness"] = test_df["history_of_present_illness"].str.replace("___", "[UNK]")

In [18]:
X_train_data, y_train_data = train_df.loc[:, "history_of_present_illness"].values, train_df.loc[:, target_variables].values
X_valid_data, y_valid_data = valid_df.loc[:, "history_of_present_illness"].values, valid_df.loc[:, target_variables].values
X_test_data, y_test_data = test_df.loc[:, "history_of_present_illness"].values, test_df.loc[:, target_variables].values

In [19]:
train_balanced_df.loc[:, "history_of_present_illness"] = train_balanced_df["history_of_present_illness"].str.replace("___", "[UNK]")

In [20]:
X_train_balanced_data, y_train_balanced_data = train_balanced_df.loc[:, "history_of_present_illness"].values, train_balanced_df.loc[:, target_variables].values

In [21]:
X_train_data

array(['[UNK] HCV cirrhosis c/b ascites, hiv on ART, h/o IVDU, COPD, \nbioplar, PTSD, presented from OSH ED with worsening abd \ndistension over past week.  \nPt reports self-discontinuing lasix and spirnolactone [UNK] weeks \nago, because she feels like "they don\'t do anything" and that \nshe "doesn\'t want to put more chemicals in her." She does not \nfollow Na-restricted diets. In the past week, she notes that she \nhas been having worsening abd distension and discomfort. She \ndenies [UNK] edema, or SOB, or orthopnea. She denies f/c/n/v, d/c, \ndysuria. She had food poisoning a week ago from eating stale \ncake (n/v 20 min after food ingestion), which resolved the same \nday. She denies other recent illness or sick contacts. She notes \nthat she has been noticing gum bleeding while brushing her teeth \nin recent weeks. she denies easy bruising, melena, BRBPR, \nhemetesis, hemoptysis, or hematuria.  \nBecause of her abd pain, she went to OSH ED and was transferred \nto [UNK] for fu

In [22]:
y_train_data

array([[0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0.],
       [1., 0., 0., 0., 0., 0.],
       ...,
       [0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0.]])

# 4. Model Setup

In [23]:
# Select the Bio_ClinicalBERT model
MODEL_CHECKPOINT = 'emilyalsentzer/Bio_ClinicalBERT'

In [24]:
# Select parameters
NUM_CLASSES = 6
MAX_SEQUENCE_LENGTH = 512
NUM_EPOCHS = 5
BATCH_SIZE = 16
LEARNING_RATE = 0.00005

## 4.1 Classes/Functions

In [25]:
class TokenizerDataset(Dataset):

    def __init__(self, X_data, y_data, tokenizer, max_seq_length):
        self.X_data = X_data
        self.y_data = y_data
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length

    def __len__ (self):
        return len(self.X_data)

    def __getitem__(self, index):
        inputs = self.tokenizer.batch_encode_plus(
            [self.X_data[index]],
            add_special_tokens=True,
            max_length=self.max_seq_length,
            padding='max_length',
            return_tensors='pt',
            truncation=True
        )

        input_ids = inputs['input_ids'].squeeze()
        token_type_ids = inputs['token_type_ids'].squeeze()
        attention_mask = inputs['attention_mask'].squeeze()
        labels = torch.tensor(self.y_data[index]).long()

        return {
            'index': index,
            'input_ids': input_ids,
            'token_type_ids': token_type_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }

In [26]:
class MulticlassClassification(nn.Module):

    def __init__(self, checkpoint, num_classes, hidden_size=201, dropout_prob=0.3, freeze_bert=True):
        super(MulticlassClassification, self).__init__()

        self.model = AutoModel.from_pretrained(checkpoint)
        self.hidden_size = hidden_size
        self.dropout_prob = dropout_prob
        self.num_classes = num_classes
        self.freeze_bert = freeze_bert

        for param in self.model.parameters():
            param.requires_grad = not self.freeze_bert

        self.pooler_layer = nn.Linear(self.model.config.hidden_size, hidden_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_prob)
        self.classification_layer = nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids, token_type_ids = None, attention_mask = None):
        outputs = self.model(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask)

        pooler_output = outputs.pooler_output
        hidden = self.pooler_layer(pooler_output)
        hidden = self.relu(hidden)
        hidden = self.dropout(hidden)
        classification = self.classification_layer(hidden)

        return classification

In [27]:
def train_multiclass_classification_model(model, train_dataloader, val_dataloader, learning_rate, num_epochs, checkpoint_folder, model_name_folder, model_variation, class_weights = None):


    optimizer = AdamW(model.parameters(), lr=learning_rate)
    total_steps = len(train_dataloader) * num_epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)

    if class_weights is not None:
        class_weights = class_weights
        criterion = nn.BCEWithLogitsLoss(pos_weight=class_weights)
    else:
        criterion = nn.BCEWithLogitsLoss()

    preds_threshold = 0.5

    epochs_list = []

    train_indexes_list = []
    train_loss_list = []
    train_prob_list = []
    train_preds_list = []
    train_labels_list = []
    train_auc_list = []
    train_accuracy_list = []
    train_precision_list = []
    train_recall_list = []
    train_f1_average_list = []

    val_indexes_list = []
    val_loss_list = []
    val_prob_list = []
    val_preds_list = []
    val_labels_list = []
    val_auc_list = []
    val_accuracy_list = []
    val_precision_list = []
    val_recall_list = []
    val_f1_average_list = []

    for epoch in range(num_epochs):

        epochs_list.append(epoch + 1)

        model.train()
        total_train_loss = 0

        all_train_indexes = []
        all_train_prob = []
        all_train_preds = []
        all_train_labels = []

        for train_batch in tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}, Train"):
            indexes = train_batch['index']
            input_ids = train_batch['input_ids']
            token_type_ids = train_batch['token_type_ids']
            attention_mask = train_batch['attention_mask']
            labels = train_batch['labels']

            optimizer.zero_grad()
            outputs = model(input_ids, token_type_ids, attention_mask)
            loss = criterion(outputs, labels.float())
            loss.backward()
            optimizer.step()
            scheduler.step()

            probabilities = torch.sigmoid(outputs)
            predictions = torch.round(probabilities)

            total_train_loss += loss.item()

            all_train_indexes.append(indexes.detach().cpu().numpy().tolist())
            all_train_prob.append(probabilities.detach().cpu().numpy())
            all_train_preds.append(predictions.detach().cpu().numpy())
            all_train_labels.append(labels.detach().cpu().numpy())

        train_loss = total_train_loss / len(train_dataloader)

        train_prob = np.vstack(all_train_prob)
        train_preds = (np.vstack(all_train_prob) > preds_threshold).astype(int)
        train_labels = np.vstack(all_train_labels)

        # Calculate metrics for training
        train_auc = roc_auc_score(train_labels, train_prob)
        train_accuracy = accuracy_score(train_labels, train_preds)
        train_precision = precision_score(train_labels, train_preds, average=None)
        train_recall = recall_score(train_labels, train_preds, average=None)
        train_f1_average = f1_score(train_labels, train_preds, average='macro')

        train_indexes_list.append(all_train_indexes)
        train_loss_list.append(train_loss)
        train_prob_list.append(train_prob)
        train_preds_list.append(train_preds)
        train_labels_list.append(train_labels)

        train_auc_list.append(train_auc)
        train_accuracy_list.append(train_accuracy)
        train_precision_list.append(np.mean(train_precision))
        train_recall_list.append(np.mean(train_recall))
        train_f1_average_list.append(train_f1_average)

        model.eval()
        total_val_loss = 0

        all_val_indexes = []
        all_val_prob = []
        all_val_preds = []
        all_val_labels = []

        with torch.no_grad():
            for val_batch in tqdm(val_dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}, Validation"):
                indexes = val_batch['index']
                input_ids = val_batch['input_ids']
                token_type_ids = val_batch['token_type_ids']
                attention_mask = val_batch['attention_mask']
                labels = val_batch['labels']

                outputs = model(input_ids, token_type_ids, attention_mask)
                loss = criterion(outputs, labels.float())

                probabilities = torch.sigmoid(outputs)
                predictions = torch.round(probabilities)

                total_val_loss += loss.item()

                all_val_indexes.append(indexes.detach().cpu().numpy().tolist())
                all_val_prob.append(probabilities.detach().cpu().numpy())
                all_val_preds.append(predictions.detach().cpu().numpy())
                all_val_labels.append(labels.detach().cpu().numpy())

        val_loss = total_val_loss / len(val_dataloader)

        val_prob = np.vstack(all_val_prob)
        val_preds = (np.vstack(all_val_prob) > preds_threshold).astype(int)
        val_labels = np.vstack(all_val_labels)

        # Calculate metrics for validation
        val_auc = roc_auc_score(val_labels, val_prob)
        val_accuracy = accuracy_score(val_labels, val_preds)
        val_precision = precision_score(val_labels, val_preds, average=None)
        val_recall = recall_score(val_labels, val_preds, average=None)
        val_f1_average = f1_score(val_labels, val_preds, average='macro')

        val_indexes_list.append(all_val_indexes)
        val_loss_list.append(val_loss)
        val_prob_list.append(val_prob)
        val_preds_list.append(val_preds)
        val_labels_list.append(val_labels)

        val_auc_list.append(val_auc)
        val_accuracy_list.append(val_accuracy)
        val_precision_list.append(np.mean(val_precision))
        val_recall_list.append(np.mean(val_recall))
        val_f1_average_list.append(val_f1_average)

        print(f"Epoch {epoch + 1}/{num_epochs} - "
              f"Train AUC: {train_auc:.4f} | "
              f"Val AUC: {val_auc:.4f} | "
              f"Train Loss: {train_loss:.4f} | "
              f"Val Loss: {val_loss:.4f} | "
              f"Train Accuracy: {train_accuracy:.4f} | "
              f"Val Accuracy: {val_accuracy:.4f} | "
              f"Train Precision: {np.mean(train_precision):.4f} | "
              f"Val Precision: {np.mean(val_precision):.4f} | "
              f"Train Recall: {np.mean(train_recall):.4f} | "
              f"Val Recall: {np.mean(val_recall):.4f} | "
              f"Train F1 (average): {train_f1_average:.4f} | "
              f"Val F1 (average): {val_f1_average:.4f}")

        checkpoint_path = f"{checkpoint_folder}/{model_name_folder}/{model_variation}__epoch_{epoch + 1}__auc_{val_auc:.4f}.pt"
        torch.save(model.state_dict(), checkpoint_path)

    results = {
      "epochs": epochs_list,
      "train_indexes": train_indexes_list,
      "train_prob": train_prob_list,
      "train_preds": train_preds_list,
      "train_labels": train_labels_list,
      "train_loss": train_loss_list,
      "train_auc": train_auc_list,
      "train_accuracy": train_accuracy_list,
      "train_precision": train_precision_list,
      "train_recall": train_recall_list,
      "train_f1_average": train_f1_average_list,
      "val_indexes": val_indexes_list,
      "val_prob": val_prob_list,
      "val_preds": val_preds_list,
      "val_labels": val_labels_list,
      "val_loss": val_loss_list,
      "val_auc": val_auc_list,
      "val_accuracy": val_accuracy_list,
      "val_precision": val_precision_list,
      "val_recall": val_recall_list,
      "val_f1_average": val_f1_average_list,
    }

    results_path = f"{checkpoint_folder}/{model_name_folder}/{model_variation}__train_results.pt"
    torch.save(results, results_path)

    return model

In [28]:
def test_multiclass_classification_model(model, test_dataloader, checkpoint_folder, model_name_folder, model_variation, class_weights=None):

    if class_weights is not None:
        class_weights = class_weights
        criterion = nn.BCEWithLogitsLoss(pos_weight=class_weights)
    else:
        criterion = nn.BCEWithLogitsLoss()

    preds_threshold = 0.5

    model.eval()

    total_test_loss = 0

    all_test_indexes = []
    all_test_prob = []
    all_test_preds = []
    all_test_labels = []

    test_indexes_list = []
    test_loss_list = []
    test_prob_list = []
    test_preds_list = []
    test_labels_list = []
    test_auc_list = []
    test_accuracy_list = []
    test_precision_list = []
    test_recall_list = []
    test_f1_average_list = []

    with torch.no_grad():

        for test_batch in tqdm(test_dataloader, desc=f"Test"):
            indexes = test_batch['index']
            input_ids = test_batch['input_ids']
            token_type_ids = test_batch['token_type_ids']
            attention_mask = test_batch['attention_mask']
            labels = test_batch['labels']

            outputs = model(input_ids, token_type_ids, attention_mask)
            loss = criterion(outputs, labels.float())

            probabilities = torch.sigmoid(outputs)
            predictions = torch.round(probabilities)

            total_test_loss += loss.item()
            all_test_indexes.append(indexes.detach().cpu().numpy().tolist())
            all_test_prob.append(probabilities.detach().cpu().numpy())
            all_test_preds.append(predictions.detach().cpu().numpy())
            all_test_labels.append(labels.detach().cpu().numpy())

    # Calculate overall metrics after collecting all predictions and labels
    test_loss = total_test_loss / len(test_dataloader)

    test_prob = np.vstack(all_test_prob)
    test_preds = (np.vstack(all_test_prob) > preds_threshold).astype(int)
    test_labels = np.vstack(all_test_labels)

    # Calculate metrics for testing
    test_auc = roc_auc_score(test_labels, test_prob)
    test_accuracy = accuracy_score(test_labels, test_preds)
    test_precision = precision_score(test_labels, test_preds, average=None)
    test_recall = recall_score(test_labels, test_preds, average=None)
    test_f1_average = f1_score(test_labels, test_preds, average='macro')

    test_indexes_list.append(all_test_indexes)
    test_loss_list.append(test_loss)
    test_prob_list.append(test_prob)
    test_preds_list.append(test_preds)
    test_labels_list.append(test_labels)

    test_auc_list.append(test_auc)
    test_accuracy_list.append(test_accuracy)
    test_precision_list.append(np.mean(test_precision))
    test_recall_list.append(np.mean(test_recall))
    test_f1_average_list.append(test_f1_average)

    print(f"Test AUC: {test_auc:.4f} | "
          f"Test Loss: {test_loss:.4f} | "
          f"Test Accuracy: {test_accuracy:.4f} | "
          f"Test Precision: {np.mean(test_precision):.4f} | "
          f"Test Recall: {np.mean(test_recall):.4f} | "
          f"Test F1 (average): {test_f1_average:.4f}")

    results = {
      "test_indexes": test_indexes_list,
      "test_loss": test_loss_list,
      "test_prob": test_prob_list,
      "test_preds": test_preds_list,
      "test_labels": test_labels_list,
      "test_auc": test_auc_list,
      "test_accuracy": test_accuracy_list,
      "test_precision": test_precision_list,
      "test_recall": test_recall_list,
      "test_f1_average": test_f1_average_list,
    }

    results_path = f"{checkpoint_folder}/{model_name_folder}/{model_variation}__test_results.pt"
    torch.save(results, results_path)

    return results

## 4.2 Tokenization

In [30]:
model_tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/385 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

In [31]:
train_dataset = TokenizerDataset(
    X_train_data,
    y_train_data,
    model_tokenizer,
    MAX_SEQUENCE_LENGTH
    )

val_dataset = TokenizerDataset(
    X_valid_data,
    y_valid_data,
    model_tokenizer,
    MAX_SEQUENCE_LENGTH
    )

test_dataset = TokenizerDataset(
    X_test_data,
    y_test_data,
    model_tokenizer,
    MAX_SEQUENCE_LENGTH
    )

train_balanced_dataset = TokenizerDataset(
    X_train_balanced_data,
    y_train_balanced_data,
    model_tokenizer,
    MAX_SEQUENCE_LENGTH
    )

In [32]:
print(f"Total number of items in the train_dataset: {len(train_dataset)}")
print(f"Total number of items in the X_train_data: {len(X_train_data)}")
print(f"Total number of items in the y_train_data: {len(y_train_data)}")
print("\n")
print(f"Total number of items in the val_dataset: {len(val_dataset)}")
print(f"Total number of items in the X_valid_data: {len(X_valid_data)}")
print(f"Total number of items in the y_valid_data: {len(y_valid_data)}")
print("\n")
print(f"Total number of items in the test_dataset: {len(test_dataset)}")
print(f"Total number of items in the X_test_data: {len(X_test_data)}")
print(f"Total number of items in the y_test_data: {len(y_test_data)}")
print("\n")
print(f"Total number of items in the test_dataset: {len(train_balanced_dataset)}")
print(f"Total number of items in the X_test_data: {len(X_train_balanced_data)}")
print(f"Total number of items in the y_test_data: {len(y_train_balanced_data)}")

Total number of items in the train_dataset: 8823
Total number of items in the X_train_data: 8823
Total number of items in the y_train_data: 8823


Total number of items in the val_dataset: 2206
Total number of items in the X_valid_data: 2206
Total number of items in the y_valid_data: 2206


Total number of items in the test_dataset: 2206
Total number of items in the X_test_data: 2206
Total number of items in the y_test_data: 2206


Total number of items in the test_dataset: 3389
Total number of items in the X_test_data: 3389
Total number of items in the y_test_data: 3389


## 4.4 Class Weights

In [33]:
def compute_multioutput_class_weights(y):
    """
    Compute class weights for each label in a multi-label setting.

    Args:
    - y: A numpy array or a DataFrame with shape (n_samples, n_classes), where each column represents a label.

    Returns:
    - A list of torch tensors containing class weights for each label.
    """
    n_classes = y.shape[1]
    class_weights = []

    for i in range(n_classes):
        # Extract the current class/label
        labels = y[:, i]

        # Compute the class weight for this label
        weight = compute_class_weight(class_weight='balanced', classes=np.unique(labels), y=labels)

        # Convert the class weight to a torch tensor and append to the list
        class_weights.append(torch.tensor(weight, dtype=torch.float))

    # Extracting only the positive class weights from your calculated class_weights
    pos_weights = torch.tensor([cw[0] for cw in class_weights])

    return pos_weights

In [34]:
# Convert target variables to a numpy array if it's a DataFrame
y_train = train_df[target_variables].to_numpy()
y_train_balanced = train_balanced_df[target_variables].to_numpy()

# Calculate class weights
POS_WEIGHTS = compute_multioutput_class_weights(y_train)
POS_WEIGHTS_BALANCED = compute_multioutput_class_weights(y_train_balanced)

In [35]:
POS_WEIGHTS

tensor([0.5578, 0.5456, 0.5297, 0.5938, 0.5467, 0.5288])

In [36]:
POS_WEIGHTS_BALANCED

tensor([0.6091, 0.5971, 0.5835, 0.6475, 0.6026, 0.5825])

## 4.5 Data Loaders

In [37]:
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True
    )

val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True
    )

test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True
    )

train_balanced_loader = DataLoader(
    dataset=train_balanced_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True
    )

# 5. Training/Testing

In [43]:
CHECKPOINT_FOLDER = '/content/drive/MyDrive/DATASCI_210/checkpoints/'
MODEL_NAME_FOLDER = "bio_clinical_bert/multiple_labels__single_input/six_findings"

## 5.1 TRAIN UNBALANCED: `unfrozen_layers` / `without_class_weights`

In [None]:
MODEL_VARIATION = "unbalanced__unfrozen_layers__without_class_weights"

In [None]:
classification_model = MulticlassClassification(
    checkpoint=MODEL_CHECKPOINT,
    num_classes=NUM_CLASSES,
    freeze_bert=False,
    )

pytorch_model.bin:   0%|          | 0.00/436M [00:00<?, ?B/s]

  return self.fget.__get__(instance, owner)()


In [None]:
classification_model = train_multiclass_classification_model(
                                  model=classification_model,
                                  train_dataloader=train_loader,
                                  val_dataloader=val_loader,
                                  class_weights=None,
                                  learning_rate=LEARNING_RATE,
                                  num_epochs=NUM_EPOCHS,
                                  checkpoint_folder=CHECKPOINT_FOLDER,
                                  model_name_folder=MODEL_NAME_FOLDER,
                                  model_variation=MODEL_VARIATION,
                            )

Epoch 1/5, Train: 100%|██████████| 552/552 [04:24<00:00,  2.09it/s]
Epoch 1/5, Validation: 100%|██████████| 138/138 [00:24<00:00,  5.67it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 1/5 - Train AUC: 0.4950 | Val AUC: 0.5693 | Train Loss: 0.3142 | Val Loss: 0.2963 | Train Accuracy: 0.5913 | Val Accuracy: 0.5938 | Train Precision: 0.0764 | Val Precision: 0.0000 | Train Recall: 0.0016 | Val Recall: 0.0000 | Train F1 (average): 0.0031 | Val F1 (average): 0.0000


Epoch 2/5, Train: 100%|██████████| 552/552 [04:22<00:00,  2.11it/s]
  _warn_prf(average, modifier, msg_start, len(result))
Epoch 2/5, Validation: 100%|██████████| 138/138 [00:24<00:00,  5.66it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 2/5 - Train AUC: 0.4951 | Val AUC: 0.5181 | Train Loss: 0.3036 | Val Loss: 0.2962 | Train Accuracy: 0.5944 | Val Accuracy: 0.5938 | Train Precision: 0.0000 | Val Precision: 0.0000 | Train Recall: 0.0000 | Val Recall: 0.0000 | Train F1 (average): 0.0000 | Val F1 (average): 0.0000


Epoch 3/5, Train: 100%|██████████| 552/552 [04:22<00:00,  2.10it/s]
  _warn_prf(average, modifier, msg_start, len(result))
Epoch 3/5, Validation: 100%|██████████| 138/138 [00:24<00:00,  5.70it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 3/5 - Train AUC: 0.4991 | Val AUC: 0.5928 | Train Loss: 0.3021 | Val Loss: 0.2960 | Train Accuracy: 0.5944 | Val Accuracy: 0.5938 | Train Precision: 0.0000 | Val Precision: 0.0000 | Train Recall: 0.0000 | Val Recall: 0.0000 | Train F1 (average): 0.0000 | Val F1 (average): 0.0000


Epoch 4/5, Train: 100%|██████████| 552/552 [04:22<00:00,  2.10it/s]
  _warn_prf(average, modifier, msg_start, len(result))
Epoch 4/5, Validation: 100%|██████████| 138/138 [00:24<00:00,  5.65it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 4/5 - Train AUC: 0.4967 | Val AUC: 0.6060 | Train Loss: 0.3018 | Val Loss: 0.2958 | Train Accuracy: 0.5944 | Val Accuracy: 0.5938 | Train Precision: 0.0000 | Val Precision: 0.0000 | Train Recall: 0.0000 | Val Recall: 0.0000 | Train F1 (average): 0.0000 | Val F1 (average): 0.0000


Epoch 5/5, Train: 100%|██████████| 552/552 [04:22<00:00,  2.10it/s]
  _warn_prf(average, modifier, msg_start, len(result))
Epoch 5/5, Validation: 100%|██████████| 138/138 [00:24<00:00,  5.65it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 5/5 - Train AUC: 0.4933 | Val AUC: 0.6080 | Train Loss: 0.3017 | Val Loss: 0.2956 | Train Accuracy: 0.5944 | Val Accuracy: 0.5938 | Train Precision: 0.0000 | Val Precision: 0.0000 | Train Recall: 0.0000 | Val Recall: 0.0000 | Train F1 (average): 0.0000 | Val F1 (average): 0.0000


In [None]:
test_results = test_multiclass_classification_model(
    model=classification_model,
    test_dataloader=test_loader,
    checkpoint_folder=CHECKPOINT_FOLDER,
    model_name_folder=MODEL_NAME_FOLDER,
    model_variation=MODEL_VARIATION,
    )

Test: 100%|██████████| 138/138 [00:24<00:00,  5.59it/s]

Test AUC: 0.6008 | Test Loss: 0.2964 | Test Accuracy: 0.5934 | Test Precision: 0.0000 | Test Recall: 0.0000 | Test F1 (average): 0.0000



  _warn_prf(average, modifier, msg_start, len(result))


In [None]:
del classification_model, test_results

## 5.2 TRAIN BALANCED: `unfrozen_layers` / `without_class_weights`

In [None]:
MODEL_VARIATION = "balanced__unfrozen_layers__without_class_weights"

In [None]:
classification_model = MulticlassClassification(
    checkpoint=MODEL_CHECKPOINT,
    num_classes=NUM_CLASSES,
    freeze_bert=False,
    )

In [None]:
classification_model = train_multiclass_classification_model(
                                  model=classification_model,
                                  train_dataloader=train_balanced_loader,
                                  val_dataloader=val_loader,
                                  class_weights=None,
                                  learning_rate=LEARNING_RATE,
                                  num_epochs=NUM_EPOCHS,
                                  checkpoint_folder=CHECKPOINT_FOLDER,
                                  model_name_folder=MODEL_NAME_FOLDER,
                                  model_variation=MODEL_VARIATION,
                            )

Epoch 1/5, Train: 100%|██████████| 212/212 [01:40<00:00,  2.10it/s]
  _warn_prf(average, modifier, msg_start, len(result))
Epoch 1/5, Validation: 100%|██████████| 138/138 [00:28<00:00,  4.87it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 1/5 - Train AUC: 0.5110 | Val AUC: 0.6709 | Train Loss: 0.4682 | Val Loss: 0.3224 | Train Accuracy: 0.2809 | Val Accuracy: 0.5938 | Train Precision: 0.1628 | Val Precision: 0.0000 | Train Recall: 0.0063 | Val Recall: 0.0000 | Train F1 (average): 0.0119 | Val F1 (average): 0.0000


Epoch 2/5, Train: 100%|██████████| 212/212 [01:41<00:00,  2.09it/s]
  _warn_prf(average, modifier, msg_start, len(result))
Epoch 2/5, Validation: 100%|██████████| 138/138 [00:24<00:00,  5.54it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 2/5 - Train AUC: 0.6273 | Val AUC: 0.7370 | Train Loss: 0.4396 | Val Loss: 0.3083 | Train Accuracy: 0.2824 | Val Accuracy: 0.5943 | Train Precision: 0.2880 | Val Precision: 0.1364 | Train Recall: 0.0059 | Val Recall: 0.0160 | Train F1 (average): 0.0115 | Val F1 (average): 0.0286


Epoch 3/5, Train: 100%|██████████| 212/212 [01:41<00:00,  2.09it/s]
Epoch 3/5, Validation: 100%|██████████| 138/138 [00:24<00:00,  5.58it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 3/5 - Train AUC: 0.6979 | Val AUC: 0.7536 | Train Loss: 0.4145 | Val Loss: 0.2840 | Train Accuracy: 0.2927 | Val Accuracy: 0.6056 | Train Precision: 0.4137 | Val Precision: 0.2228 | Train Recall: 0.0677 | Val Recall: 0.0578 | Train F1 (average): 0.1112 | Val F1 (average): 0.0917


Epoch 4/5, Train: 100%|██████████| 212/212 [01:42<00:00,  2.07it/s]
Epoch 4/5, Validation: 100%|██████████| 138/138 [00:24<00:00,  5.58it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 4/5 - Train AUC: 0.7537 | Val AUC: 0.7655 | Train Loss: 0.3875 | Val Loss: 0.2709 | Train Accuracy: 0.3196 | Val Accuracy: 0.6047 | Train Precision: 0.5223 | Val Precision: 0.4003 | Train Recall: 0.1657 | Val Recall: 0.1339 | Train F1 (average): 0.2384 | Val F1 (average): 0.1837


Epoch 5/5, Train: 100%|██████████| 212/212 [01:41<00:00,  2.09it/s]
Epoch 5/5, Validation: 100%|██████████| 138/138 [00:24<00:00,  5.58it/s]


Epoch 5/5 - Train AUC: 0.7961 | Val AUC: 0.7857 | Train Loss: 0.3585 | Val Loss: 0.2654 | Train Accuracy: 0.3559 | Val Accuracy: 0.5757 | Train Precision: 0.6033 | Val Precision: 0.3674 | Train Recall: 0.2779 | Val Recall: 0.2926 | Train F1 (average): 0.3577 | Val F1 (average): 0.2940


In [None]:
test_results = test_multiclass_classification_model(
    model=classification_model,
    test_dataloader=test_loader,
    checkpoint_folder=CHECKPOINT_FOLDER,
    model_name_folder=MODEL_NAME_FOLDER,
    model_variation=MODEL_VARIATION,
    )

Test: 100%|██████████| 138/138 [00:24<00:00,  5.53it/s]

Test AUC: 0.7654 | Test Loss: 0.2739 | Test Accuracy: 0.5730 | Test Precision: 0.3604 | Test Recall: 0.2589 | Test F1 (average): 0.2609





In [None]:
del classification_model, test_results

## 5.3 TRAIN UNBALANCED: `unfrozen_layers` / `with_class_weights`

In [44]:
MODEL_VARIATION = "unbalanced__unfrozen_layers__with_class_weights"

In [45]:
classification_model = MulticlassClassification(
    checkpoint=MODEL_CHECKPOINT,
    num_classes=NUM_CLASSES,
    freeze_bert=False,
    )

In [46]:
classification_model = train_multiclass_classification_model(
                                  model=classification_model,
                                  train_dataloader=train_loader,
                                  val_dataloader=val_loader,
                                  class_weights=POS_WEIGHTS,
                                  learning_rate=LEARNING_RATE,
                                  num_epochs=NUM_EPOCHS,
                                  checkpoint_folder=CHECKPOINT_FOLDER,
                                  model_name_folder=MODEL_NAME_FOLDER,
                                  model_variation=MODEL_VARIATION,
                            )

Epoch 1/5, Train: 100%|██████████| 552/552 [04:24<00:00,  2.09it/s]
Epoch 1/5, Validation: 100%|██████████| 138/138 [00:24<00:00,  5.73it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 1/5 - Train AUC: 0.5114 | Val AUC: 0.5746 | Train Loss: 0.2145 | Val Loss: 0.1930 | Train Accuracy: 0.5905 | Val Accuracy: 0.5938 | Train Precision: 0.0716 | Val Precision: 0.0000 | Train Recall: 0.0027 | Val Recall: 0.0000 | Train F1 (average): 0.0051 | Val F1 (average): 0.0000


Epoch 2/5, Train: 100%|██████████| 552/552 [04:20<00:00,  2.12it/s]
  _warn_prf(average, modifier, msg_start, len(result))
Epoch 2/5, Validation: 100%|██████████| 138/138 [00:23<00:00,  5.76it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 2/5 - Train AUC: 0.5061 | Val AUC: 0.5826 | Train Loss: 0.1993 | Val Loss: 0.1926 | Train Accuracy: 0.5944 | Val Accuracy: 0.5938 | Train Precision: 0.0000 | Val Precision: 0.0000 | Train Recall: 0.0000 | Val Recall: 0.0000 | Train F1 (average): 0.0000 | Val F1 (average): 0.0000


Epoch 3/5, Train: 100%|██████████| 552/552 [04:20<00:00,  2.12it/s]
  _warn_prf(average, modifier, msg_start, len(result))
Epoch 3/5, Validation: 100%|██████████| 138/138 [00:24<00:00,  5.71it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 3/5 - Train AUC: 0.5034 | Val AUC: 0.5990 | Train Loss: 0.1984 | Val Loss: 0.1927 | Train Accuracy: 0.5944 | Val Accuracy: 0.5938 | Train Precision: 0.0000 | Val Precision: 0.0000 | Train Recall: 0.0000 | Val Recall: 0.0000 | Train F1 (average): 0.0000 | Val F1 (average): 0.0000


Epoch 4/5, Train: 100%|██████████| 552/552 [04:20<00:00,  2.12it/s]
  _warn_prf(average, modifier, msg_start, len(result))
Epoch 4/5, Validation: 100%|██████████| 138/138 [00:23<00:00,  5.76it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 4/5 - Train AUC: 0.5029 | Val AUC: 0.6012 | Train Loss: 0.1981 | Val Loss: 0.1927 | Train Accuracy: 0.5944 | Val Accuracy: 0.5938 | Train Precision: 0.0000 | Val Precision: 0.0000 | Train Recall: 0.0000 | Val Recall: 0.0000 | Train F1 (average): 0.0000 | Val F1 (average): 0.0000


Epoch 5/5, Train: 100%|██████████| 552/552 [04:20<00:00,  2.12it/s]
  _warn_prf(average, modifier, msg_start, len(result))
Epoch 5/5, Validation: 100%|██████████| 138/138 [00:24<00:00,  5.74it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 5/5 - Train AUC: 0.5036 | Val AUC: 0.6055 | Train Loss: 0.1974 | Val Loss: 0.1924 | Train Accuracy: 0.5944 | Val Accuracy: 0.5938 | Train Precision: 0.0000 | Val Precision: 0.0000 | Train Recall: 0.0000 | Val Recall: 0.0000 | Train F1 (average): 0.0000 | Val F1 (average): 0.0000


In [47]:
test_results = test_multiclass_classification_model(
    model=classification_model,
    test_dataloader=test_loader,
    checkpoint_folder=CHECKPOINT_FOLDER,
    model_name_folder=MODEL_NAME_FOLDER,
    model_variation=MODEL_VARIATION,
    )

Test: 100%|██████████| 138/138 [00:24<00:00,  5.70it/s]

Test AUC: 0.5880 | Test Loss: 0.3099 | Test Accuracy: 0.5934 | Test Precision: 0.0000 | Test Recall: 0.0000 | Test F1 (average): 0.0000



  _warn_prf(average, modifier, msg_start, len(result))


In [48]:
del classification_model, test_results

## 5.4 TRAIN BALANCED: `unfrozen_layers` / `with_class_weights`

In [49]:
MODEL_VARIATION = "balanced__unfrozen_layers__with_class_weights"

In [50]:
classification_model = MulticlassClassification(
    checkpoint=MODEL_CHECKPOINT,
    num_classes=NUM_CLASSES,
    freeze_bert=False,
    )

In [51]:
classification_model = train_multiclass_classification_model(
                                  model=classification_model,
                                  train_dataloader=train_balanced_loader,
                                  val_dataloader=val_loader,
                                  class_weights=POS_WEIGHTS_BALANCED,
                                  learning_rate=LEARNING_RATE,
                                  num_epochs=NUM_EPOCHS,
                                  checkpoint_folder=CHECKPOINT_FOLDER,
                                  model_name_folder=MODEL_NAME_FOLDER,
                                  model_variation=MODEL_VARIATION,
                            )

Epoch 1/5, Train: 100%|██████████| 212/212 [01:39<00:00,  2.12it/s]
Epoch 1/5, Validation: 100%|██████████| 138/138 [00:25<00:00,  5.35it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 1/5 - Train AUC: 0.5025 | Val AUC: 0.6046 | Train Loss: 0.3520 | Val Loss: 0.2229 | Train Accuracy: 0.2762 | Val Accuracy: 0.5938 | Train Precision: 0.1044 | Val Precision: 0.0000 | Train Recall: 0.0054 | Val Recall: 0.0000 | Train F1 (average): 0.0101 | Val F1 (average): 0.0000


Epoch 2/5, Train: 100%|██████████| 212/212 [01:40<00:00,  2.12it/s]
  _warn_prf(average, modifier, msg_start, len(result))
Epoch 2/5, Validation: 100%|██████████| 138/138 [00:24<00:00,  5.65it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 2/5 - Train AUC: 0.5752 | Val AUC: 0.7084 | Train Loss: 0.3229 | Val Loss: 0.2060 | Train Accuracy: 0.2833 | Val Accuracy: 0.5938 | Train Precision: 0.1667 | Val Precision: 0.0000 | Train Recall: 0.0003 | Val Recall: 0.0000 | Train F1 (average): 0.0006 | Val F1 (average): 0.0000


Epoch 3/5, Train: 100%|██████████| 212/212 [01:40<00:00,  2.12it/s]
  _warn_prf(average, modifier, msg_start, len(result))
Epoch 3/5, Validation: 100%|██████████| 138/138 [00:24<00:00,  5.58it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 3/5 - Train AUC: 0.6518 | Val AUC: 0.7467 | Train Loss: 0.3095 | Val Loss: 0.1989 | Train Accuracy: 0.2836 | Val Accuracy: 0.5938 | Train Precision: 0.2381 | Val Precision: 0.0000 | Train Recall: 0.0016 | Val Recall: 0.0000 | Train F1 (average): 0.0032 | Val F1 (average): 0.0000


Epoch 4/5, Train: 100%|██████████| 212/212 [01:40<00:00,  2.12it/s]
Epoch 4/5, Validation: 100%|██████████| 138/138 [00:24<00:00,  5.68it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 4/5 - Train AUC: 0.7030 | Val AUC: 0.7675 | Train Loss: 0.2964 | Val Loss: 0.1941 | Train Accuracy: 0.2871 | Val Accuracy: 0.5943 | Train Precision: 0.3288 | Val Precision: 0.1167 | Train Recall: 0.0143 | Val Recall: 0.0062 | Train F1 (average): 0.0269 | Val F1 (average): 0.0118


Epoch 5/5, Train: 100%|██████████| 212/212 [01:40<00:00,  2.12it/s]
Epoch 5/5, Validation: 100%|██████████| 138/138 [00:24<00:00,  5.65it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 5/5 - Train AUC: 0.7567 | Val AUC: 0.7923 | Train Loss: 0.2782 | Val Loss: 0.1868 | Train Accuracy: 0.3001 | Val Accuracy: 0.5993 | Train Precision: 0.6184 | Val Precision: 0.1951 | Train Recall: 0.0612 | Val Recall: 0.0613 | Train F1 (average): 0.1054 | Val F1 (average): 0.0902


In [52]:
test_results = test_multiclass_classification_model(
    model=classification_model,
    test_dataloader=test_loader,
    checkpoint_folder=CHECKPOINT_FOLDER,
    model_name_folder=MODEL_NAME_FOLDER,
    model_variation=MODEL_VARIATION,
    )

Test: 100%|██████████| 138/138 [00:24<00:00,  5.58it/s]

Test AUC: 0.7658 | Test Loss: 0.2558 | Test Accuracy: 0.6038 | Test Precision: 0.2239 | Test Recall: 0.0645 | Test F1 (average): 0.0973



  _warn_prf(average, modifier, msg_start, len(result))


In [53]:
del classification_model, test_results