In [1]:
# == Hyperparameter configuration ==

# Official scored labels Physionet 2021: https://github.com/physionetchallenges/evaluation-2021/blob/main/dx_mapping_scored.csv

# 0 = 426783006 -> sinus rhythm (SR)
# 1 = 164889003 -> atrial fibrillation (AF)
# 2 = 164890007 -> atrial flutter (AFL)
# 3 = 284470004 or 63593006 -> premature atrial contraction (PAC) or supraventricular premature beats (SVPB)
# 4 = 427172004 or 17338001 -> premature ventricular contractions (PVC), ventricular premature beats (VPB)
# 5 = 6374002 -> bundle branch block (BBB)
# 6 = 426627000 -> bradycardia (Brady)
# 7 = 733534002 or 164909002 -> complete left bundle branch block (CLBBB), left bundle branch block (LBBB)
# 8 = 713427006 or 59118001 -> complete right bundle branch block (CRBBB), right bundle branch block (RBBB)
# 9 = 270492004 -> 1st degree av block (IAVB)
# 10 = 713426002 -> incomplete right bundle branch block (IRBBB)
# 11 = 39732003 -> left axis deviation (LAD)
# 12 = 445118002 -> left anterior fascicular block (LAnFB)
# 13 = 251146004 -> low qrs voltages (LQRSV)
# 14 = 698252002 -> nonspecific intraventricular conduction disorder (NSIVCB)
# 15 = 10370003 -> pacing rhythm (PR)
# 16 = 365413008 -> poor R wave Progression (PRWP)
# 17 = 164947007 -> prolonged pr interval (LPR)
# 18 = 111975006 -> prolonged qt interval (LQT)
# 19 = 164917005 -> qwave abnormal (QAb)
# 20 = 47665007 -> right axis deviation (RAD)
# 21 = 427393009 -> sinus arrhythmia (SA)
# 22 = 426177001 -> sinus bradycardia (SB)
# 23 = 427084000 -> sinus tachycardia (STach)
# 24 = 164934002 -> t wave abnormal (TAb)
# 25 = 59931005 -> t wave inversion (TInv)

# SNOMED CT codes accepted as labels, listed in class-index order; codes that
# are mapped to the same class index share a line.
VALID_LABELS = {
    "426783006",  # 0: SR
    "164889003",  # 1: AF
    "164890007",  # 2: AFL
    "284470004", "63593006",  # 3: PAC / SVPB
    "427172004", "17338001",  # 4: PVC / VPB
    "6374002",  # 5: BBB
    "426627000",  # 6: Brady
    "733534002", "164909002",  # 7: CLBBB / LBBB
    "713427006", "59118001",  # 8: CRBBB / RBBB
    "270492004",  # 9: IAVB
    "713426002",  # 10: IRBBB
    "39732003",  # 11: LAD
    "445118002",  # 12: LAnFB
    "251146004",  # 13: LQRSV
    "698252002",  # 14: NSIVCB
    "10370003",  # 15: PR
    "365413008",  # 16: PRWP
    "164947007",  # 17: LPR
    "111975006",  # 18: LQT
    "164917005",  # 19: QAb
    "47665007",  # 20: RAD
    "427393009",  # 21: SA
    "426177001",  # 22: SB
    "427084000",  # 23: STach
    "164934002",  # 24: TAb
    "59931005",  # 25: TInv
}
# Reduced label set kept for reference:
# VALID_LABELS = set(["426783006", "164889003", "164890007", "284470004", "427172004"]) # SR, AF, AFL, PAC, PVC

# Number of class indices (0-25) that the 30 codes above collapse into.
NUM_CLASSES = 26
EPOCHS = 50
LEARNING_RATE = 0.001
BATCH_SIZE = 32

In [2]:
# == Check if GPU is available ==

!nvidia-smi

Mon Jul  8 17:46:17 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   35C    P8               9W /  70W |      0MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [35]:
# == Install requirements ==

!pip install google-colab
!pip install numpy
!pip install h5py joblib tqdm
!pip install pandas scipy imblearn
!pip install matplotlib



In [36]:
# == Import requirements ==

import warnings
warnings.filterwarnings("ignore")
import logging

from google.colab import drive, files
import os
import h5py
import joblib
from tqdm import tqdm
import pandas as pd
from collections import Counter

import random
from sklearn.model_selection import KFold
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

import numpy as np
import matplotlib.pyplot as plt

import shutil
from itertools import zip_longest

In [6]:
# == Map labels to numerical values functions ==

# Official scored labels Physionet 2021: https://github.com/physionetchallenges/evaluation-2021/blob/main/dx_mapping_scored.csv

# One tuple per class, in class-index order: the position of a tuple is the
# class index, and every SNOMED CT code inside it maps to that index.
_SCORED_CODE_GROUPS = (
    ("426783006",),  # 0: sinus rhythm (SR)
    ("164889003",),  # 1: atrial fibrillation (AF)
    ("164890007",),  # 2: atrial flutter (AFL)
    ("284470004", "63593006"),  # 3: premature atrial contraction (PAC), supraventricular premature beats (SVPB)
    ("427172004", "17338001"),  # 4: premature ventricular contractions (PVC), ventricular premature beats (VPB)
    ("6374002",),  # 5: bundle branch block (BBB)
    ("426627000",),  # 6: bradycardia (Brady)
    ("733534002", "164909002"),  # 7: complete left bundle branch block (CLBBB), left bundle branch block (LBBB)
    ("713427006", "59118001"),  # 8: complete right bundle branch block (CRBBB), right bundle branch block (RBBB)
    ("270492004",),  # 9: 1st degree av block (IAVB)
    ("713426002",),  # 10: incomplete right bundle branch block (IRBBB)
    ("39732003",),  # 11: left axis deviation (LAD)
    ("445118002",),  # 12: left anterior fascicular block (LAnFB)
    ("251146004",),  # 13: low qrs voltages (LQRSV)
    ("698252002",),  # 14: nonspecific intraventricular conduction disorder (NSIVCB)
    ("10370003",),  # 15: pacing rhythm (PR)
    ("365413008",),  # 16: poor R wave Progression (PRWP)
    ("164947007",),  # 17: prolonged pr interval (LPR)
    ("111975006",),  # 18: prolonged qt interval (LQT)
    ("164917005",),  # 19: qwave abnormal (QAb)
    ("47665007",),  # 20: right axis deviation (RAD)
    ("427393009",),  # 21: sinus arrhythmia (SA)
    ("426177001",),  # 22: sinus bradycardia (SB)
    ("427084000",),  # 23: sinus tachycardia (STach)
    ("164934002",),  # 24: t wave abnormal (TAb)
    ("59931005",),  # 25: t wave inversion (TInv)
)

# Flat lookup table: SNOMED CT code -> class index.
arrhyhtmia_mapping_id_to_index = {
    snomed_code: class_index
    for class_index, code_group in enumerate(_SCORED_CODE_GROUPS)
    for snomed_code in code_group
}


def map_arrhyhtmia_id_to_index(x: str) -> int:
    """Return the class index of SNOMED CT code ``x``.

    Raises:
        KeyError: if ``x`` is not a key of ``arrhyhtmia_mapping_id_to_index``.
    """
    return arrhyhtmia_mapping_id_to_index[x]

# Label ID string for every class, stored at the position of its class index.
# Classes that merge two SNOMED CT codes use both codes joined with "|".
_IDS_BY_CLASS_INDEX = [
    "426783006",  # 0: sinus rhythm (SR)
    "164889003",  # 1: atrial fibrillation (AF)
    "164890007",  # 2: atrial flutter (AFL)
    "284470004|63593006",  # 3: premature atrial contraction (PAC) | supraventricular premature beats (SVPB)
    "427172004|17338001",  # 4: premature ventricular contractions (PVC) | ventricular premature beats (VPB)
    "6374002",  # 5: bundle branch block (BBB)
    "426627000",  # 6: bradycardia (Brady)
    "733534002|164909002",  # 7: complete left bundle branch block (CLBBB) | left bundle branch block (LBBB)
    "713427006|59118001",  # 8: complete right bundle branch block (CRBBB) | right bundle branch block (RBBB)
    "270492004",  # 9: 1st degree av block (IAVB)
    "713426002",  # 10: incomplete right bundle branch block (IRBBB)
    "39732003",  # 11: left axis deviation (LAD)
    "445118002",  # 12: left anterior fascicular block (LAnFB)
    "251146004",  # 13: low qrs voltages (LQRSV)
    "698252002",  # 14: nonspecific intraventricular conduction disorder (NSIVCB)
    "10370003",  # 15: pacing rhythm (PR)
    "365413008",  # 16: poor R wave Progression (PRWP)
    "164947007",  # 17: prolonged pr interval (LPR)
    "111975006",  # 18: prolonged qt interval (LQT)
    "164917005",  # 19: qwave abnormal (QAb)
    "47665007",  # 20: right axis deviation (RAD)
    "427393009",  # 21: sinus arrhythmia (SA)
    "426177001",  # 22: sinus bradycardia (SB)
    "427084000",  # 23: sinus tachycardia (STach)
    "164934002",  # 24: t wave abnormal (TAb)
    "59931005",  # 25: t wave inversion (TInv)
]

# Lookup table: class index -> label ID string.
arrhyhtmia_mapping_index_to_id = dict(enumerate(_IDS_BY_CLASS_INDEX))

# Not used: a plain reversal of the id -> index dict would keep only one code
# for the classes that merge two codes.
# arrhythmia_mapping_index_to_id = dict(map(reversed, arrhyhtmia_mapping_id_to_index.items()))


def map_arrhyhtmia_index_to_id(x: int) -> str:
    """Return the label ID string stored for class index ``x``.

    Raises:
        KeyError: if ``x`` is not a key of ``arrhyhtmia_mapping_index_to_id``.
    """
    return arrhyhtmia_mapping_index_to_id[x]


In [7]:
# == Mount drive ==

# https://drive.google.com/drive/folders/1L_gOMrkygu2N0k97COYuVrmE-AwEEMoQ

# Mount Google Drive at /content/drive so the dataset files become readable.
drive.mount('/content/drive')
# Base folder of the dataset files; the loading cells join file names onto it.
path = "/content/drive/My Drive/Master Thesis/Datasets"
!ls "/content/drive/My Drive/Master Thesis/Datasets"

Mounted at /content/drive
codes_SNOMED.csv  physionet2017_references.csv	physionet2021_references.csv
physionet2017.h5  physionet2021.h5		prepared


In [28]:
# == Load the feature vectors of all Physionet2021 ECGs (biobss_features_imputated.csv) and their IDs into a dictionary X_dict ==

# Each line of the feature file is "<record id>;<text>: <number>;<text>: <number>;...".
# Only the number after ": " is kept for every field, in file order.
id_features = []
with open(os.path.join(path, "prepared/biobss_features_imputated.csv"), "r", encoding="utf-8") as feature_file:
    # The context manager closes the file even if a line fails to parse.
    for raw_line in feature_file:
        if not raw_line.strip():
            # Skip blank lines (e.g. a trailing newline at the end of the file);
            # they would otherwise create a record with ID "" and no features.
            continue
        record_id, *feature_fields = raw_line.replace("\n", "").split(";")
        id_features.append(
            [record_id] + [float(field.split(": ")[1]) for field in feature_fields]
        )

# Record ID -> list of feature values.
X_dict = {record[0]: record[1:] for record in id_features}

# == Load all labels and their IDs to a dictionary Y_dict (some ECGs can have multiple labels) ==
# Record ID -> multi-hot label vector of length NUM_CLASSES. Only records that
# also have features in X_dict are kept; codes outside VALID_LABELS are ignored,
# so a kept record may end up with an all-zero vector.
Y_dict = {}
labels_df = pd.read_csv(os.path.join(path, "physionet2021_references.csv"), sep=";")

# Iterating over the two needed columns avoids building a Series per row
# (iterrows). Wrapping the iterable in tqdm closes the progress bar when the
# loop ends.
label_rows = zip(labels_df["id"], labels_df["labels"])
for record_id, raw_labels in tqdm(
    label_rows, total=len(labels_df), desc="Load ECG labels", position=0, leave=True
):
    if record_id not in X_dict:
        # No features for this record: nothing to train on, skip it early.
        continue
    multi_hot_labels = [0] * NUM_CLASSES
    for label in raw_labels.strip().split(","):
        if label in VALID_LABELS:
            multi_hot_labels[map_arrhyhtmia_id_to_index(label)] = 1
    Y_dict[record_id] = multi_hot_labels

Load ECG labels: 100%|██████████| 88252/88252 [00:14<00:00, 6006.48it/s] 
Load ECG labels: 100%|█████████▉| 88174/88252 [00:04<00:00, 18426.03it/s]

In [38]:
# == Map scored labels to ECGs and create three lists (X: feature vectors, Y: labels, Z: IDs) ==

# Three parallel lists: position i of X, Y and Z belongs to the same record.
X, Y, Z = [], [], []

for patient_id, label_vector in tqdm(Y_dict.items(), desc="Map labels to ECGs", position=0, leave=True):
    X.append(X_dict[patient_id])  # feature values of the record
    Y.append(label_vector)  # multi-hot label vector of the record
    Z.append(str(patient_id))  # record ID, always stored as a string

Map labels to ECGs: 100%|██████████| 81030/81030 [00:00<00:00, 868394.57it/s]


In [39]:
# == Shuffle data and convert to numpy arrays ==

# Shuffle data
# Draw one seeded permutation and apply it to all three containers, so that
# features, labels and IDs stay aligned and the order is the same on every
# fresh run of the notebook.
SHUFFLE_SEED = 42
shuffled_order = np.random.default_rng(SHUFFLE_SEED).permutation(len(Z))

# np.array converts the nested Python lists directly; indexing with the
# permutation reorders the rows. This also works when the lists are empty.
X = np.array(X)[shuffled_order]
Y = np.array(Y)[shuffled_order]
Z = np.array(Z)[shuffled_order]

In [40]:
# Start every run with empty output folders: delete a leftover folder from a
# previous run, then create it again. A failure is reported, not raised.
for output_dir in ("models", "test_outputs"):
    try:
        if os.path.exists(output_dir):
            shutil.rmtree(output_dir)
        os.makedirs(output_dir)
    except OSError as error:
        print(f"Error: {error.strerror}")

In [32]:
# == Save predictions util function ==

def save_predictions(pred, pred_prob, z_test):
    """Write one file per test record to ``test_outputs/<record id>.csv``.

    Every file consists of four lines without a trailing newline:

    1. ``#<record id>``
    2. the label IDs of all classes (``map_arrhyhtmia_index_to_id``), comma-separated
    3. ``True``/``False`` for every class, comma-separated
    4. the score of every class (``str`` of the value), comma-separated

    Args:
        pred: one sequence of binary decisions per record; a value equal to 1
            is written as ``True``, every other value as ``False``.
        pred_prob: one sequence of scores per record.
        z_test: record IDs, aligned by position with ``pred`` and ``pred_prob``.
    """
    records = zip(z_test, pred, pred_prob)
    # A single progress bar; wrapping the iterable closes it when the loop ends.
    for record_id, decisions, scores in tqdm(
        records, total=len(pred), desc="Convert test_outputs", position=0, leave=True
    ):
        lines = [
            f"#{record_id}",
            ",".join(map_arrhyhtmia_index_to_id(class_index) for class_index in range(len(decisions))),
            # Every value gets its own explicit string, so no value can reuse
            # the string computed for the previous class.
            ",".join("True" if decision == 1 else "False" for decision in decisions),
            ",".join(str(score) for score in scores),
        ]
        with open(f"test_outputs/{record_id}.csv", "w", encoding="utf-8") as file:
            file.write("\n".join(lines))

In [33]:
# == Plot distribution util function ==

def plot_distribution(Y):
    """Print and plot how often each class index is set in ``Y``.

    Args:
        Y: iterable of label vectors; position ``i`` of a vector is counted
            when its value is truthy.

    The counts are printed as a ``Counter`` and shown as a bar chart sorted by
    descending count, with the class index as tick label and the count written
    above each bar. If no label is set at all, only the (empty) counter and a
    short notice are printed and no figure is created.
    """
    label_counts = Counter(
        class_index
        for label in Y
        for class_index, is_set in enumerate(label)
        if is_set
    )
    print(label_counts)
    if not label_counts:
        # Without any bar there is nothing to plot.
        print("No labels set - nothing to plot.")
        return

    # most_common() sorts by descending count and keeps first-seen order for ties.
    ranked_counts = label_counts.most_common()
    label_keys = [str(class_index) for class_index, _ in ranked_counts]
    label_values = [count for _, count in ranked_counts]

    # for index, key in enumerate(label_keys):
    #    label_keys[index] = str(index + 1) + ". " + key[0].upper() + key[1:]

    plt.figure(figsize=(30, 10))
    # Draw the bars exactly once and keep the handles for the count labels.
    bars = plt.bar(label_keys, label_values, color="#1f77b4")
    plt.title("Physionet 2021 labels")
    plt.xlabel("Arrhythmia type", labelpad=7)
    plt.ylabel("Occurence")
    plt.xticks(rotation=45, ha="right", fontsize=16)  # diagonal tick labels
    # Adding the counts on top of the bars
    for bar in bars:
        yval = bar.get_height()
        plt.text(
            bar.get_x() + bar.get_width() / 2,
            yval + 5,
            yval,
            ha="center",
            va="bottom",
            fontsize=16,
        )

    plt.show()
    plt.close()

In [41]:
# == Train model ==

# Initialize KFold
# Initialize KFold
n_folds = 10
kf = KFold(n_splits=n_folds, shuffle=True, random_state=42)
threshold = 0.5  # a label is predicted when its probability exceeds this value
metrics = {
    "accuracy": [],
    "precision": [],
    "recall": [],
    "f1": []
}


def _positive_label_probabilities(fitted_model, features):
    """Return an (n_samples, n_labels) array with P(label == 1) per label.

    For a multi-output forest ``predict_proba`` returns one
    (n_samples, n_classes) array per label and ``classes_`` one class array per
    label. A label that had no positive example in the training fold has no
    column for class 1; its probability is left at 0.

    Assumes the model was fitted on a 2-D target with more than one column, so
    that both attributes are lists - true for the 26-column Y used here.
    """
    per_label_proba = fitted_model.predict_proba(features)
    positive_proba = np.zeros((len(features), len(per_label_proba)))
    for label_index, (label_proba, label_classes) in enumerate(
        zip(per_label_proba, fitted_model.classes_)
    ):
        positive_column = np.flatnonzero(np.asarray(label_classes) == 1)
        if positive_column.size:
            positive_proba[:, label_index] = label_proba[:, positive_column[0]]
    return positive_proba


# Cross-validation
for fold_counter, (train_index, test_index) in enumerate(kf.split(X), start=1):
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = Y[train_index], Y[test_index]
    z_test = Z[test_index]

    # Seeded so that the trained forests are the same on every run.
    model = RandomForestClassifier(random_state=42) # RandomForestClassifier(n_estimators=100, random_state=42, max_features=30)
    model.fit(X_train, y_train)

    joblib.dump(model, f"models/{fold_counter}.pkl")

    # Real per-label probabilities (predict() would only give hard 0/1 labels),
    # thresholded afterwards to obtain the binary decisions.
    pred_prob = _positive_label_probabilities(model, X_test)
    pred = (pred_prob > threshold).astype(int)

    # For multi-label targets accuracy_score is the subset accuracy: a sample
    # only counts as correct when all of its labels match.
    accuracy = accuracy_score(y_test, pred)
    precision = precision_score(y_test, pred, average='weighted')
    recall = recall_score(y_test, pred, average='weighted')
    f1 = f1_score(y_test, pred, average='weighted')

    print(f"Fold {fold_counter}/{n_folds}")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1-score: {f1:.4f}")

    print("Save predictions...")
    save_predictions(pred, pred_prob, z_test)
    print("...finished saving.")

    metrics["accuracy"].append(accuracy)
    metrics["precision"].append(precision)
    metrics["recall"].append(recall)
    metrics["f1"].append(f1)

# train_accuracy = history.history["accuracy"] # list
# val_accuracy = history.history["val_accuracy"] # list
# train_loss = history.history["loss"] # list
# val_loss = history.history["val_loss"] # list

Accuracy: 0.3984
Precision: 0.7633
Recall: 0.3878
F1-score: 0.4122
Save predictions...


Convert test_outputs:   0%|          | 0/8103 [00:00<?, ?it/s]
Convert test_outputs:  13%|█▎        | 1033/8103 [00:00<00:00, 10324.52it/s]
Convert test_outputs:  26%|██▌       | 2115/8103 [00:00<00:00, 10612.93it/s]
Convert test_outputs:  39%|███▉      | 3177/8103 [00:00<00:00, 10513.09it/s]
Convert test_outputs:  52%|█████▏    | 4229/8103 [00:00<00:00, 10410.87it/s]
Convert test_outputs:  65%|██████▌   | 5271/8103 [00:00<00:00, 10187.44it/s]
Convert test_outputs:  78%|███████▊  | 6323/8103 [00:00<00:00, 10297.81it/s]
Convert test_outputs:  91%|█████████ | 7383/8103 [00:00<00:00, 10393.46it/s]
8103it [00:00, 10403.23it/s]
Convert test_outputs: 100%|██████████| 8103/8103 [00:00<00:00, 10339.11it/s]


...finished saving.
Accuracy: 0.3883
Precision: 0.7096
Recall: 0.3777
F1-score: 0.4005
Save predictions...


Convert test_outputs:   0%|          | 0/8103 [00:00<?, ?it/s]
Convert test_outputs:  12%|█▏        | 994/8103 [00:00<00:00, 9935.70it/s]
Convert test_outputs:  25%|██▍       | 1988/8103 [00:00<00:00, 9937.69it/s]
Convert test_outputs:  37%|███▋      | 2991/8103 [00:00<00:00, 9979.28it/s]
Convert test_outputs:  50%|████▉     | 4048/8103 [00:00<00:00, 10211.04it/s]
Convert test_outputs:  63%|██████▎   | 5115/8103 [00:00<00:00, 10374.19it/s]
Convert test_outputs:  76%|███████▋  | 6183/8103 [00:00<00:00, 10475.77it/s]
Convert test_outputs:  90%|████████▉ | 7270/8103 [00:00<00:00, 10602.67it/s]
8103it [00:00, 10520.48it/s]
Convert test_outputs: 100%|██████████| 8103/8103 [00:00<00:00, 10441.54it/s]


...finished saving.
Accuracy: 0.3939
Precision: 0.7360
Recall: 0.3857
F1-score: 0.4110
Save predictions...


Convert test_outputs:   0%|          | 0/8103 [00:00<?, ?it/s]
Convert test_outputs:  12%|█▏        | 996/8103 [00:00<00:00, 9953.20it/s]
Convert test_outputs:  25%|██▍       | 1992/8103 [00:00<00:00, 9493.31it/s]
Convert test_outputs:  36%|███▋      | 2943/8103 [00:00<00:00, 9378.55it/s]
Convert test_outputs:  49%|████▉     | 4005/8103 [00:00<00:00, 9856.86it/s]
Convert test_outputs:  62%|██████▏   | 4993/8103 [00:00<00:00, 9337.57it/s]
Convert test_outputs:  75%|███████▍  | 6052/8103 [00:00<00:00, 9740.68it/s]
Convert test_outputs:  87%|████████▋ | 7032/8103 [00:00<00:00, 9501.97it/s]
Convert test_outputs:  99%|█████████▊| 7987/8103 [00:00<00:00, 9440.40it/s]
8103it [00:00, 9525.61it/s]
Convert test_outputs: 100%|██████████| 8103/8103 [00:00<00:00, 9467.03it/s]


...finished saving.
Accuracy: 0.3792
Precision: 0.7503
Recall: 0.3729
F1-score: 0.3963
Save predictions...


Convert test_outputs:   0%|          | 0/8103 [00:00<?, ?it/s]
Convert test_outputs:  13%|█▎        | 1037/8103 [00:00<00:00, 10365.02it/s]
Convert test_outputs:  26%|██▌       | 2075/8103 [00:00<00:00, 10371.32it/s]
Convert test_outputs:  38%|███▊      | 3113/8103 [00:00<00:00, 10287.02it/s]
Convert test_outputs:  51%|█████▏    | 4162/8103 [00:00<00:00, 10363.93it/s]
Convert test_outputs:  64%|██████▍   | 5218/8103 [00:00<00:00, 10434.35it/s]
Convert test_outputs:  78%|███████▊  | 6305/8103 [00:00<00:00, 10581.11it/s]
Convert test_outputs:  91%|█████████▏| 7397/8103 [00:00<00:00, 10691.16it/s]
8103it [00:00, 10573.62it/s]
Convert test_outputs: 100%|██████████| 8103/8103 [00:00<00:00, 10503.54it/s]


...finished saving.
Accuracy: 0.3900
Precision: 0.7437
Recall: 0.3830
F1-score: 0.4086
Save predictions...


Convert test_outputs:   0%|          | 0/8103 [00:00<?, ?it/s]
Convert test_outputs:  14%|█▎        | 1101/8103 [00:00<00:00, 11006.55it/s]
Convert test_outputs:  27%|██▋       | 2214/8103 [00:00<00:00, 11077.50it/s]
Convert test_outputs:  41%|████      | 3325/8103 [00:00<00:00, 11089.69it/s]
Convert test_outputs:  55%|█████▍    | 4434/8103 [00:00<00:00, 10651.21it/s]
Convert test_outputs:  68%|██████▊   | 5526/8103 [00:00<00:00, 10743.61it/s]
Convert test_outputs:  82%|████████▏ | 6612/8103 [00:00<00:00, 10780.56it/s]
Convert test_outputs:  95%|█████████▌| 7718/8103 [00:00<00:00, 10869.01it/s]
8103it [00:00, 10861.15it/s]
Convert test_outputs: 100%|██████████| 8103/8103 [00:00<00:00, 10784.82it/s]


...finished saving.
Accuracy: 0.3917
Precision: 0.7664
Recall: 0.3868
F1-score: 0.4097
Save predictions...


Convert test_outputs:   0%|          | 0/8103 [00:00<?, ?it/s]
Convert test_outputs:  14%|█▎        | 1094/8103 [00:00<00:00, 10933.49it/s]
Convert test_outputs:  27%|██▋       | 2188/8103 [00:00<00:00, 10911.56it/s]
Convert test_outputs:  40%|████      | 3280/8103 [00:00<00:00, 10587.16it/s]
Convert test_outputs:  54%|█████▎    | 4354/8103 [00:00<00:00, 10645.92it/s]
Convert test_outputs:  67%|██████▋   | 5443/8103 [00:00<00:00, 10729.85it/s]
Convert test_outputs:  81%|████████  | 6529/8103 [00:00<00:00, 10772.09it/s]
Convert test_outputs:  94%|█████████▍| 7607/8103 [00:00<00:00, 10771.60it/s]
8103it [00:00, 10771.63it/s]
Convert test_outputs: 100%|██████████| 8103/8103 [00:00<00:00, 10698.62it/s]


...finished saving.
Accuracy: 0.3844
Precision: 0.6875
Recall: 0.3760
F1-score: 0.4002
Save predictions...


Convert test_outputs:   0%|          | 0/8103 [00:00<?, ?it/s]
Convert test_outputs:  13%|█▎        | 1020/8103 [00:00<00:00, 10192.67it/s]
Convert test_outputs:  25%|██▌       | 2040/8103 [00:00<00:00, 9773.93it/s] 
Convert test_outputs:  38%|███▊      | 3114/8103 [00:00<00:00, 10204.79it/s]
Convert test_outputs:  52%|█████▏    | 4201/8103 [00:00<00:00, 10460.96it/s]
Convert test_outputs:  65%|██████▌   | 5288/8103 [00:00<00:00, 10606.07it/s]
Convert test_outputs:  78%|███████▊  | 6350/8103 [00:00<00:00, 10065.83it/s]
Convert test_outputs:  91%|█████████ | 7388/8103 [00:00<00:00, 10161.48it/s]
8103it [00:00, 10172.11it/s]
Convert test_outputs: 100%|██████████| 8103/8103 [00:00<00:00, 10102.70it/s]


...finished saving.
Accuracy: 0.3827
Precision: 0.6647
Recall: 0.3818
F1-score: 0.4063
Save predictions...


Convert test_outputs:   0%|          | 0/8103 [00:00<?, ?it/s]
Convert test_outputs:  11%|█▏        | 927/8103 [00:00<00:00, 9266.98it/s]
Convert test_outputs:  24%|██▎       | 1921/8103 [00:00<00:00, 9662.38it/s]
Convert test_outputs:  37%|███▋      | 2958/8103 [00:00<00:00, 9984.71it/s]
Convert test_outputs:  49%|████▉     | 3957/8103 [00:00<00:00, 9160.07it/s]
Convert test_outputs:  60%|██████    | 4883/8103 [00:00<00:00, 8333.75it/s]
Convert test_outputs:  71%|███████   | 5738/8103 [00:00<00:00, 8399.83it/s]
Convert test_outputs:  81%|████████▏ | 6593/8103 [00:00<00:00, 8436.54it/s]
Convert test_outputs:  92%|█████████▏| 7444/8103 [00:00<00:00, 7728.48it/s]
8103it [00:00, 8283.75it/s]
Convert test_outputs: 100%|██████████| 8103/8103 [00:00<00:00, 8224.24it/s]


...finished saving.
Accuracy: 0.3892
Precision: 0.7495
Recall: 0.3792
F1-score: 0.4037
Save predictions...


Convert test_outputs:   0%|          | 0/8103 [00:00<?, ?it/s]
Convert test_outputs:  14%|█▎        | 1101/8103 [00:00<00:00, 11005.55it/s]
Convert test_outputs:  27%|██▋       | 2209/8103 [00:00<00:00, 11044.67it/s]
Convert test_outputs:  41%|████      | 3314/8103 [00:00<00:00, 10958.16it/s]
Convert test_outputs:  54%|█████▍    | 4410/8103 [00:00<00:00, 10933.18it/s]
Convert test_outputs:  68%|██████▊   | 5504/8103 [00:00<00:00, 10896.20it/s]
Convert test_outputs:  81%|████████▏ | 6595/8103 [00:00<00:00, 10897.58it/s]
Convert test_outputs:  95%|█████████▍| 7685/8103 [00:00<00:00, 10831.63it/s]
8103it [00:00, 10800.85it/s]
Convert test_outputs: 100%|██████████| 8103/8103 [00:00<00:00, 10729.10it/s]


...finished saving.
Accuracy: 0.3941
Precision: 0.7094
Recall: 0.3879
F1-score: 0.4124
Save predictions...


Convert test_outputs:   0%|          | 0/8103 [00:00<?, ?it/s]
Convert test_outputs:  13%|█▎        | 1047/8103 [00:00<00:00, 10469.59it/s]
Convert test_outputs:  26%|██▌       | 2094/8103 [00:00<00:00, 10343.26it/s]
Convert test_outputs:  39%|███▊      | 3129/8103 [00:00<00:00, 10316.59it/s]
Convert test_outputs:  51%|█████▏    | 4161/8103 [00:00<00:00, 10305.85it/s]
Convert test_outputs:  64%|██████▍   | 5200/8103 [00:00<00:00, 10332.51it/s]
Convert test_outputs:  77%|███████▋  | 6234/8103 [00:00<00:00, 10290.98it/s]
Convert test_outputs:  90%|████████▉ | 7291/8103 [00:00<00:00, 10381.36it/s]
8103it [00:00, 10377.71it/s]
Convert test_outputs: 100%|██████████| 8103/8103 [00:00<00:00, 10310.20it/s]

...finished saving.





In [42]:
# Print the metrics for each fold
# for i in range(n_folds):
#    print(f"Fold {i+1} - Accuracy: {metrics['accuracy'][i]:.4f}, Precision: {metrics['precision'][i]:.4f}, Recall: {metrics['recall'][i]:.4f}, F1-score: {metrics['f1'][i]:.4f}")

# Calculate and print average metrics
# Mean of every metric over all folds, unpacked in the order of the name tuple.
avg_accuracy, avg_precision, avg_recall, avg_f1 = (
    np.mean(metrics[metric_name])
    for metric_name in ("accuracy", "precision", "recall", "f1")
)

print(
    f"Average - Accuracy: {avg_accuracy:.4f}, Precision: {avg_precision:.4f}, "
    f"Recall: {avg_recall:.4f}, F1-score: {avg_f1:.4f}"
)

Average - Accuracy: 0.3892, Precision: 0.7281, Recall: 0.3819, F1-score: 0.4061


In [43]:
# Pack each output folder into "<folder>.zip" next to it and download the
# archive through the Colab browser session.
for folder_to_zip in ("models", "test_outputs"):
    output_filename = f"{folder_to_zip}.zip"
    # make_archive takes the archive name without extension and appends ".zip".
    shutil.make_archive(folder_to_zip, "zip", folder_to_zip)
    files.download(output_filename)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>