In [1]:
# == Hyperparameter configuration ==

# Official scored labels Physionet 2021: https://github.com/physionetchallenges/evaluation-2021/blob/main/dx_mapping_scored.csv

# 0 = 426783006 -> sinus rhythm (SR)
# 1 = 164889003 -> atrial fibrillation (AF)
# 2 = 164890007 -> atrial flutter (AFL)
# 3 = 284470004 or 63593006 -> premature atrial contraction (PAC) or supraventricular premature beats (SVPB)
# 4 = 427172004 or 17338001 -> premature ventricular contractions (PVC), ventricular premature beats (VPB)
# 5 = 6374002 -> bundle branch block (BBB)
# 6 = 426627000 -> bradycardia (Brady)
# 7 = 733534002 or 164909002 -> complete left bundle branch block (CLBBB), left bundle branch block (LBBB)
# 8 = 713427006 or 59118001 -> complete right bundle branch block (CRBBB), right bundle branch block (RBBB)
# 9 = 270492004 -> 1st degree av block (IAVB)
# 10 = 713426002 -> incomplete right bundle branch block (IRBBB)
# 11 = 39732003 -> left axis deviation (LAD)
# 12 = 445118002 -> left anterior fascicular block (LAnFB)
# 13 = 251146004 -> low qrs voltages (LQRSV)
# 14 = 698252002 -> nonspecific intraventricular conduction disorder (NSIVCB)
# 15 = 10370003 -> pacing rhythm (PR)
# 16 = 365413008 -> poor R wave Progression (PRWP)
# 17 = 164947007 -> prolonged pr interval (LPR)
# 18 = 111975006 -> prolonged qt interval (LQT)
# 19 = 164917005 -> qwave abnormal (QAb)
# 20 = 47665007 -> right axis deviation (RAD)
# 21 = 427393009 -> sinus arrhythmia (SA)
# 22 = 426177001 -> sinus bradycardia (SB)
# 23 = 427084000 -> sinus tachycardia (STach)
# 24 = 164934002 -> t wave abnormal (TAb)
# 25 = 59931005 -> t wave inversion (TInv)

# SNOMED CT codes accepted as labels, grouped by the class index listed above.
# Codes sharing an index are treated as the same class by the label mapping.
VALID_LABELS = {
    "426783006",  # 0  SR
    "164889003",  # 1  AF
    "164890007",  # 2  AFL
    "284470004", "63593006",  # 3  PAC / SVPB
    "427172004", "17338001",  # 4  PVC / VPB
    "6374002",  # 5  BBB
    "426627000",  # 6  Brady
    "733534002", "164909002",  # 7  CLBBB / LBBB
    "713427006", "59118001",  # 8  CRBBB / RBBB
    "270492004",  # 9  IAVB
    "713426002",  # 10 IRBBB
    "39732003",  # 11 LAD
    "445118002",  # 12 LAnFB
    "251146004",  # 13 LQRSV
    "698252002",  # 14 NSIVCB
    "10370003",  # 15 PR
    "365413008",  # 16 PRWP
    "164947007",  # 17 LPR
    "111975006",  # 18 LQT
    "164917005",  # 19 QAb
    "47665007",  # 20 RAD
    "427393009",  # 21 SA
    "426177001",  # 22 SB
    "427084000",  # 23 STach
    "164934002",  # 24 TAb
    "59931005",  # 25 TInv
}
# Five-class alternative (SR, AF, AFL, PAC, PVC):
# VALID_LABELS = {"426783006", "164889003", "164890007", "284470004", "427172004"}

# Width of the multi-hot label vector and of every model's output layer.
NUM_CLASSES = 26
# Training-loop settings shared by all models.
EPOCHS = 50
LEARNING_RATE = 0.001
BATCH_SIZE = 32

In [2]:
# == Check if GPU is available ==

!nvidia-smi

Mon Jul  8 17:46:17 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   35C    P8               9W /  70W |      0MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [3]:
# == Install requirements ==

# NOTE: none of these installs pin a version, so a later re-run may resolve to
# different releases than the ones this notebook was developed with.
# google-colab provides the `google.colab` module imported in the import cell.
!pip install google-colab
!pip install tensorflow keras numpy
!pip install h5py tqdm
# NOTE(review): the import cell does not import imblearn — confirm it is still needed.
# sklearn is imported later but not installed here; presumably preinstalled in the runtime.
!pip install pandas scipy imblearn
!pip install matplotlib

Collecting jedi>=0.16 (from ipython==7.34.0->google-colab)
  Downloading jedi-0.19.1-py2.py3-none-any.whl (1.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m18.3 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: jedi
Successfully installed jedi-0.19.1
Collecting imblearn
  Downloading imblearn-0.0-py2.py3-none-any.whl (1.9 kB)
Installing collected packages: imblearn
Successfully installed imblearn-0.0


In [4]:
# == Import requirements ==

import warnings
warnings.filterwarnings("ignore")
import logging

from google.colab import drive, files
import os
import h5py
from tqdm import tqdm
import pandas as pd
from collections import Counter

import random
import scipy
from scipy.signal import butter, lfilter
from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import regularizers

import matplotlib.pyplot as plt

import shutil
from itertools import zip_longest

In [5]:
# == Preprocess functions ==

def pad_or_truncate_ecg(ecg: list, max_samples: int) -> list:
    """Return ``ecg`` as a list of exactly ``max_samples`` values.

    Signals longer than ``max_samples`` are cut at the end; shorter signals are
    right-padded with zeros.

    Args:
        ecg: Sequence of samples. Any sequence is accepted (list, tuple,
            1-D array); it is copied into a new list and not modified.
        max_samples: Target length of the returned list.

    Returns:
        A new list with ``max_samples`` entries.
    """
    # The previous version wrapped this in try/except, printed the error and
    # then failed with UnboundLocalError on return, hiding the real cause.
    # Errors now propagate unchanged.
    samples = list(ecg)[:max_samples]
    return samples + [0] * (max_samples - len(samples))

def resample_ecg(ecg: list, resample: int):
    """Resample ``ecg`` to ``resample`` points with ``scipy.signal.resample``.

    Args:
        ecg: Input samples.
        resample: Number of samples in the output.

    Returns:
        A list holding the resampled values.
    """
    # Only axis is spelled out; t, window and domain keep their defaults
    # (None, None, "time"), which is what the call used explicitly before.
    resampled = scipy.signal.resample(ecg, resample, axis=0)
    return [value for value in resampled]

def normalize_to_minus11(ecg: list):
    """Linearly rescale ``ecg`` so its minimum maps to -1 and its maximum to 1.

    Args:
        ecg: Input samples (must be non-empty; ``max``/``min`` raise otherwise).

    Returns:
        A list of rescaled values. A constant signal yields a list of zeros of
        the same length, which avoids dividing by a zero range.
    """
    high = max(ecg)
    low = min(ecg)
    if high == low:
        return [0] * len(ecg)
    span = high - low
    return [2 * (sample - low) / span - 1 for sample in ecg]

def butter_bandpass(lowcut, highcut, fs, order=4):
    """Design a Butterworth band-pass filter.

    Args:
        lowcut: Lower cut-off frequency, in the same unit as ``fs``.
        highcut: Upper cut-off frequency, in the same unit as ``fs``.
        fs: Sampling rate of the signal the filter will be applied to.
        order: Order passed to ``scipy.signal.butter``.

    Returns:
        The ``(b, a)`` coefficient pair returned by ``scipy.signal.butter``.
    """
    # butter() expects the band edges as fractions of the Nyquist frequency.
    nyquist = 0.5 * fs
    band = [lowcut / nyquist, highcut / nyquist]
    return butter(order, band, btype="band")

def butter_bandpass_filter(ecg: list, lowcut: float, highcut: float, sampling_rate: int, order: int = 4):
    """Band-pass filter ``ecg`` with a Butterworth filter.

    The coefficients come from ``butter_bandpass`` and are applied with a
    single forward pass of ``scipy.signal.lfilter`` (not a zero-phase
    forward-backward pass).

    Args:
        ecg: Input samples.
        lowcut: Lower cut-off frequency.
        highcut: Upper cut-off frequency.
        sampling_rate: Sampling rate of ``ecg``.
        order: Butterworth order forwarded to ``butter_bandpass``.

    Returns:
        The filtered signal as returned by ``lfilter``.
    """
    numerator, denominator = butter_bandpass(lowcut, highcut, sampling_rate, order=order)
    return lfilter(numerator, denominator, ecg)

def split_list_into_n_sublists(lst, n):
    """Split ``lst`` into ``n`` consecutive slices of near-equal length.

    The first ``len(lst) % n`` slices receive one extra element, so lengths
    differ by at most one. Slices may be empty when ``n > len(lst)``.

    Args:
        lst: Sliceable sequence to split.
        n: Number of slices to produce.

    Returns:
        A list with ``n`` slices of ``lst``, in order.
    """
    base, extra = divmod(len(lst), n)
    chunks = []
    start = 0
    for position in range(n):
        stop = start + base + (1 if position < extra else 0)
        chunks.append(lst[start:stop])
        start = stop
    return chunks

In [6]:
# == Map labels to numerical values functions ==

# Official scored labels Physionet 2021: https://github.com/physionetchallenges/evaluation-2021/blob/main/dx_mapping_scored.csv

# SNOMED CT code -> class index. Codes that denote the same scored class
# (e.g. PAC / SVPB) deliberately share one index.
arrhyhtmia_mapping_id_to_index = dict(
    [
        ("426783006", 0),  # sinus rhythm (SR)
        ("164889003", 1),  # atrial fibrillation (AF)
        ("164890007", 2),  # atrial flutter (AFL)
        ("284470004", 3),  # premature atrial contraction (PAC)
        ("63593006", 3),  # supraventricular premature beats (SVPB)
        ("427172004", 4),  # premature ventricular contractions (PVC)
        ("17338001", 4),  # ventricular premature beats (VPB)
        ("6374002", 5),  # bundle branch block (BBB)
        ("426627000", 6),  # bradycardia (Brady)
        ("733534002", 7),  # complete left bundle branch block (CLBBB)
        ("164909002", 7),  # left bundle branch block (LBBB)
        ("713427006", 8),  # complete right bundle branch block (CRBBB)
        ("59118001", 8),  # right bundle branch block (RBBB)
        ("270492004", 9),  # 1st degree av block (IAVB)
        ("713426002", 10),  # incomplete right bundle branch block (IRBBB)
        ("39732003", 11),  # left axis deviation (LAD)
        ("445118002", 12),  # left anterior fascicular block (LAnFB)
        ("251146004", 13),  # low qrs voltages (LQRSV)
        ("698252002", 14),  # nonspecific intraventricular conduction disorder (NSIVCB)
        ("10370003", 15),  # pacing rhythm (PR)
        ("365413008", 16),  # poor R wave Progression (PRWP)
        ("164947007", 17),  # prolonged pr interval (LPR)
        ("111975006", 18),  # prolonged qt interval (LQT)
        ("164917005", 19),  # qwave abnormal (QAb)
        ("47665007", 20),  # right axis deviation (RAD)
        ("427393009", 21),  # sinus arrhythmia (SA)
        ("426177001", 22),  # sinus bradycardia (SB)
        ("427084000", 23),  # sinus tachycardia (STach)
        ("164934002", 24),  # t wave abnormal (TAb)
        ("59931005", 25),  # t wave inversion (TInv)
    ]
)


def map_arrhyhtmia_id_to_index(x: str) -> int:
    """Return the class index for SNOMED CT code ``x``.

    Raises:
        KeyError: If ``x`` is not one of the codes in the mapping table.
    """
    class_index = arrhyhtmia_mapping_id_to_index[x]
    return class_index

# Class index -> SNOMED CT code(s). The list position is the class index; classes
# covering several equivalent codes join them with "|".
# This table is written out by hand rather than inverted from the id -> index
# table, because several codes share one index there.
arrhyhtmia_mapping_index_to_id = dict(
    enumerate(
        [
            "426783006",  # 0: sinus rhythm (SR)
            "164889003",  # 1: atrial fibrillation (AF)
            "164890007",  # 2: atrial flutter (AFL)
            "284470004|63593006",  # 3: premature atrial contraction (PAC) | supraventricular premature beats (SVPB)
            "427172004|17338001",  # 4: premature ventricular contractions (PVC) | ventricular premature beats (VPB)
            "6374002",  # 5: bundle branch block (BBB)
            "426627000",  # 6: bradycardia (Brady)
            "733534002|164909002",  # 7: complete left bundle branch block (CLBBB) | left bundle branch block (LBBB)
            "713427006|59118001",  # 8: complete right bundle branch block (CRBBB) | right bundle branch block (RBBB)
            "270492004",  # 9: 1st degree av block (IAVB)
            "713426002",  # 10: incomplete right bundle branch block (IRBBB)
            "39732003",  # 11: left axis deviation (LAD)
            "445118002",  # 12: left anterior fascicular block (LAnFB)
            "251146004",  # 13: low qrs voltages (LQRSV)
            "698252002",  # 14: nonspecific intraventricular conduction disorder (NSIVCB)
            "10370003",  # 15: pacing rhythm (PR)
            "365413008",  # 16: poor R wave Progression (PRWP)
            "164947007",  # 17: prolonged pr interval (LPR)
            "111975006",  # 18: prolonged qt interval (LQT)
            "164917005",  # 19: qwave abnormal (QAb)
            "47665007",  # 20: right axis deviation (RAD)
            "427393009",  # 21: sinus arrhythmia (SA)
            "426177001",  # 22: sinus bradycardia (SB)
            "427084000",  # 23: sinus tachycardia (STach)
            "164934002",  # 24: t wave abnormal (TAb)
            "59931005",  # 25: t wave inversion (TInv)
        ]
    )
)


def map_arrhyhtmia_index_to_id(x: int) -> str:
    """Return the SNOMED CT code string for class index ``x``.

    Raises:
        KeyError: If ``x`` is not an index in the mapping table.
    """
    code = arrhyhtmia_mapping_index_to_id[x]
    return code



In [7]:
# == Mount drive ==

# https://drive.google.com/drive/folders/1L_gOMrkygu2N0k97COYuVrmE-AwEEMoQ

# Mount Google Drive at /content/drive so the dataset files become regular paths.
drive.mount('/content/drive')
# Dataset root; later cells build file paths with os.path.join(path, ...).
path = "/content/drive/My Drive/Master Thesis/Datasets"
# List the dataset folder as a quick check that the mount worked.
!ls "/content/drive/My Drive/Master Thesis/Datasets"

Mounted at /content/drive
codes_SNOMED.csv  physionet2017_references.csv	physionet2021_references.csv
physionet2017.h5  physionet2021.h5		prepared


In [None]:
# == Load all Physionet2021 ECGs and their IDs to a dictionary X_dict ==

X_dict = {}
Y_dict = {}

# Open the HDF5 file in a context manager so the handle is released even when
# loading fails part-way (it was previously left open for the whole session).
with h5py.File(os.path.join(path, "prepared/physionet2021_scoredLabels.h5"), "r") as h5file:
    IDs = list(h5file.keys())
    for key in tqdm(IDs, desc="Load ECG data", position=0, leave=True):
        # Only row 0 of each dataset is kept (presumably a single lead — TODO confirm).
        X_dict[key] = list(h5file[key][0])

# == Load all labels and their IDs to a dictionary Y_dict (some ECGs can have multiple labels) ==

labels_df = pd.read_csv(os.path.join(path, "physionet2021_references.csv"), sep=";")
for _, row in tqdm(labels_df.iterrows(), total=len(labels_df), desc="Load ECG labels", position=0, leave=True):
    # Rows without a matching ECG are ignored; only loaded recordings get a label vector.
    if row["id"] not in X_dict:
        continue
    labels = row["labels"].strip().split(",")
    # Multi-hot target: one slot per class, sized from NUM_CLASSES instead of a
    # hand-written list of 26 zeros so it cannot drift from the configuration.
    binary_crossentropy_labels = [0] * NUM_CLASSES
    for label in labels:
        label = label.strip()  # tolerate "a, b" style separators
        if label in VALID_LABELS:
            binary_crossentropy_labels[map_arrhyhtmia_id_to_index(label)] = 1
    Y_dict[row["id"]] = binary_crossentropy_labels

# del IDs, h5file

In [None]:
# == Preprocess ECGs ==

# Preprocessing parameters, named so the pipeline below documents itself.
ECG_INPUT_SAMPLES = 5000  # every raw ECG is zero-padded / truncated to this length
ECG_RESAMPLED_SAMPLES = 2000  # length after resampling (the model input length)
# Sampling rate handed to the band-pass design. NOTE(review): assumes the
# resampled signal corresponds to 200 Hz — confirm against the source data.
ECG_FILTER_SAMPLING_RATE = 200
ECG_BANDPASS_LOWCUT = 0.3
ECG_BANDPASS_HIGHCUT = 21.0

# NOTE: this cell overwrites X_dict in place, so it must run exactly once after
# loading; re-running it would re-process already filtered signals.
# Iterating through tqdm (as the label-mapping cell does) closes the progress
# bar automatically; the hand-updated bar used before was never closed.
for key in tqdm(X_dict, desc="Preprocess ECGs", position=0, leave=True):
    ecg = pad_or_truncate_ecg(ecg=X_dict[key], max_samples=ECG_INPUT_SAMPLES)
    ecg = resample_ecg(ecg=ecg, resample=ECG_RESAMPLED_SAMPLES)
    ecg = normalize_to_minus11(ecg=ecg)
    X_dict[key] = butter_bandpass_filter(
        ecg=ecg,
        lowcut=ECG_BANDPASS_LOWCUT,
        highcut=ECG_BANDPASS_HIGHCUT,
        sampling_rate=ECG_FILTER_SAMPLING_RATE,
    )

In [None]:
# == Map scored labels to ECGs and create three lists (X: ECGs, Y: labels, Z: IDs) ==

# Parallel lists: X[i] is the signal, Y[i] its multi-hot label, Z[i] its record ID.
X, Y, Z = [], [], []

# Y_dict drives the loop, so only recordings that received a label vector are kept.
for patient_id, patient_labels in tqdm(Y_dict.items(), desc="Map labels to ECGs", position=0, leave=True):
    X.append(X_dict[patient_id])
    Y.append(patient_labels)
    Z.append(str(patient_id))

In [None]:
# == Shuffle data, convert to numpy lists and reshape ==

# Shuffle data
# Seed for the shuffle below. Without it the sample order changes on every run,
# so KFold(random_state=42) would still produce different folds each time.
SHUFFLE_SEED = 42

combined = list(zip(X, Y, Z))
if not combined:
    raise ValueError("No labelled ECGs available: X, Y and Z are empty.")
# A private Random instance keeps the global random state untouched.
random.Random(SHUFFLE_SEED).shuffle(combined)
X, Y, Z = (list(column) for column in zip(*combined))

# Convert to numpy arrays (np.array stacks the per-sample sequences directly,
# so the previous element-by-element conversion loops are not needed).
X = np.array(X)
Y = np.array(Y)
Z = np.array(Z)

# Reshape to (samples, timesteps, channels) as expected by the 1-D models.
X = X.reshape((-1, 2000, 1))

In [None]:
# == A-B-testing models ==

# Model A: Residual_CNN_1lead
def conv(i, filters=16, kernel_size=9, strides=1):
    """Apply Conv1D -> BatchNormalization -> LeakyReLU -> SpatialDropout1D(0.1).

    Args:
        i: Input Keras tensor.
        filters: Number of convolution filters.
        kernel_size: Convolution kernel width.
        strides: Convolution stride.

    Returns:
        The tensor produced by the last layer of the stack.
    """
    # Layers are instantiated in the order they are applied.
    stages = (
        keras.layers.Conv1D(
            filters=filters, kernel_size=kernel_size, strides=strides, padding="same"
        ),
        keras.layers.BatchNormalization(),
        keras.layers.LeakyReLU(),
        keras.layers.SpatialDropout1D(0.1),
    )
    for stage in stages:
        i = stage(i)
    return i
def residual_unit(x, filters, layers=3):
    """Run ``layers`` stacked ``conv`` blocks and add the input back on.

    Args:
        x: Input Keras tensor. It is added to the branch output, so its shape
            must match what ``conv(..., filters)`` produces.
        filters: Filter count for every ``conv`` block in the branch.
        layers: Number of ``conv`` blocks in the branch.

    Returns:
        Element-wise sum of the branch output and the original input.
    """
    shortcut = x
    branch = x
    for _ in range(layers):
        branch = conv(branch, filters)
    return keras.layers.add([branch, shortcut])
def conv_block(x, filters, strides):
    """One encoder stage: ``conv`` -> ``residual_unit`` -> optional average pooling.

    Args:
        x: Input Keras tensor.
        filters: Filter count used by the convolution and the residual unit.
        strides: Pool size and stride of the trailing ``AveragePooling1D``;
            pooling is skipped when this is not greater than 1.

    Returns:
        The stage output tensor.
    """
    features = residual_unit(conv(x, filters), filters)
    if strides <= 1:
        return features
    return keras.layers.AveragePooling1D(pool_size=strides, strides=strides)(features)
def build_model_A(input_shape, num_classes):
    """Build model A: four residual conv stages followed by a GRU and a sigmoid head.

    Args:
        input_shape: Shape of one sample, without the batch dimension.
        num_classes: Number of output units; each uses a sigmoid, so outputs
            are independent per-class scores.

    Returns:
        An uncompiled ``keras.models.Model``.
    """
    inp = keras.layers.Input(input_shape)
    x = inp
    # (filters, pooling factor) for each convolutional stage.
    for stage_filters, stage_pool in ((16, 4), (16, 4), (32, 4), (32, 4)):
        x = conv_block(x, stage_filters, stage_pool)
    x = keras.layers.Masking(mask_value=0)(x)
    x = keras.layers.GRU(128, recurrent_dropout=0.1)(x)
    x = keras.layers.Dense(num_classes, activation="sigmoid")(x)
    return keras.models.Model(inp, x)

# Model B: CNN_Transformer_1lead
def build_model_B(num_classes, input_shape):
    """Build model B: three conv stages, then three transformer encoder blocks.

    Args:
        num_classes: Number of sigmoid output units.
        input_shape: Shape of one sample, without the batch dimension. The
            fixed ``Reshape((5, 50))`` needs 250 values after pooling, which
            matches three halvings of a 2000-step input.

    Returns:
        An uncompiled ``keras.models.Model``.
    """
    input_layer = keras.layers.Input(input_shape)

    # Masking for padded/truncated data
    i = keras.layers.Masking(mask_value=0)(input_layer)

    # Conv1..Conv3: identical stacks that differ only in filter count; each
    # halves the time axis with average pooling.
    for stage_filters in (16, 32, 64):
        i = keras.layers.Conv1D(filters=stage_filters, kernel_size=9, strides=1, padding="same")(i)
        i = keras.layers.BatchNormalization()(i)
        i = keras.layers.ReLU()(i)
        i = keras.layers.SpatialDropout1D(0.1)(i)
        i = keras.layers.AveragePooling1D(2)(i)

    # Channel average pooling, then regroup the pooled values into 5 tokens of width 50.
    token_shape = (5, 50)
    i = keras.layers.GlobalAveragePooling1D(data_format="channels_first")(i)
    i = keras.layers.Reshape(token_shape)(i)

    # Encoder blocks / attention mechanism
    for _ in range(3):
        i = transformer_encoder(i, input_shape=token_shape, key_dim=50, num_heads=1, ff_dim=24, dropout=0.1)

    # Flatten and classify; sigmoid gives one independent score per class.
    i = keras.layers.Flatten()(i)
    i = keras.layers.Dense(num_classes, activation="sigmoid")(i)
    return keras.models.Model(inputs=input_layer, outputs=i)

# Model C: vanilla_Transformer_1lead
def transformer_encoder(input, input_shape, num_heads, key_dim, ff_dim, dropout):
    """One encoder block: self-attention and a dense feed-forward part, each with a residual.

    The feed-forward part flattens the whole sequence, applies two Dense
    layers, and reshapes back to ``input_shape``.

    Args:
        input: Input Keras tensor, used as both query and value.
        input_shape: ``(steps, features)`` of ``input`` without the batch axis.
        num_heads: Number of attention heads.
        key_dim: Size of each attention head.
        ff_dim: Units of the hidden feed-forward Dense layer.
        dropout: Dropout rate for the attention layer and after the feed-forward part.

    Returns:
        Tensor with shape ``input_shape`` (plus batch axis).
    """
    l2_strength = 0.001

    # Multi-head self-attention plus the first residual connection.
    attention = keras.layers.MultiHeadAttention(
        num_heads=num_heads,
        key_dim=key_dim,
        dropout=dropout,
        kernel_regularizer=regularizers.l2(l2_strength),
    )
    res = attention(input, input) + input
    x = keras.layers.LayerNormalization(epsilon=1e-6)(res)

    # Feed-forward part on the flattened sequence.
    flat_units = input_shape[0] * input_shape[1]
    x = keras.layers.Flatten(input_shape=input_shape)(x)
    x = keras.layers.Dense(units=ff_dim, activation='relu', kernel_regularizer=regularizers.l2(l2_strength))(x)
    x = keras.layers.Dense(flat_units, kernel_regularizer=regularizers.l2(l2_strength))(x)
    x = keras.layers.Reshape(input_shape)(x)
    x = keras.layers.Dropout(rate=dropout)(x)

    # Second residual: added to the pre-normalisation sum `res`, then normalised.
    return keras.layers.LayerNormalization(epsilon=1e-6)(x + res)

def get_positional_encoding(seq_length, d_model):
    """Build a sinusoidal positional-encoding table.

    Even feature columns hold sines and odd columns hold cosines of the
    position scaled by a geometric frequency series.

    Args:
        seq_length: Number of positions (rows).
        d_model: Number of feature columns. Odd values are supported.

    Returns:
        A ``tf.float32`` constant of shape ``(1, seq_length, d_model)``.
    """
    position = np.arange(seq_length)[:, np.newaxis]
    div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
    angles = position * div_term

    pe = np.zeros((seq_length, d_model))
    pe[:, 0::2] = np.sin(angles)
    # For odd d_model there is one cosine column fewer than sine columns;
    # assigning the full-width array raised a shape error for d_model >= 3.
    pe[:, 1::2] = np.cos(angles[:, : d_model // 2])
    return tf.constant(pe[np.newaxis, :], dtype=tf.float32)

def build_model_C(num_classes, input_shape, positional_encoding, num_encoder_blocks, num_heads, key_dim, ff_dim, dropout):
    """Build model C: a stack of transformer encoder blocks with a sigmoid head.

    Args:
        num_classes: Number of sigmoid output units.
        input_shape: ``(steps, features)`` of one sample, without the batch axis.
        positional_encoding: When truthy, a sinusoidal positional encoding is
            added to the input before the first encoder block.
        num_encoder_blocks: Number of ``transformer_encoder`` blocks.
        num_heads: Attention heads per block.
        key_dim: Size of each attention head.
        ff_dim: Hidden units of each block's feed-forward layer.
        dropout: Dropout rate used inside each block.

    Returns:
        An uncompiled ``keras.Model``.
    """
    inputs = keras.Input(shape=input_shape)
    x = inputs
    if positional_encoding:
        x = x + get_positional_encoding(input_shape[0], input_shape[1])
    for _ in range(num_encoder_blocks):
        # Keyword arguments on purpose: transformer_encoder takes num_heads
        # before key_dim, and the earlier positional call passed them swapped.
        x = transformer_encoder(
            x,
            input_shape=input_shape,
            num_heads=num_heads,
            key_dim=key_dim,
            ff_dim=ff_dim,
            dropout=dropout,
        )
    x = keras.layers.Flatten()(x)
    outputs = keras.layers.Dense(num_classes, activation="sigmoid")(x)
    return keras.Model(inputs, outputs)

# Model D: channel_attention_1lead
from tensorflow.keras.layers import Input, Conv1D, GlobalAveragePooling1D, GlobalMaxPooling1D, Dense, Multiply, Add, Activation
from tensorflow.keras.models import Model

def channel_attention(input_feature, ratio=8):
    """Re-weight the channels of ``input_feature`` with a learned gate.

    Global average- and max-pooled descriptors go through the same two Dense
    layers, are summed, squashed with a sigmoid and multiplied onto the input.

    Args:
        input_feature: Keras tensor whose last axis is the channel axis.
        ratio: Reduction factor for the hidden Dense layer.

    Returns:
        ``input_feature`` multiplied by the per-channel gate.
    """
    channel = input_feature.shape[-1]
    # Keep at least one hidden unit: channel // ratio is 0 when channel < ratio,
    # which would create an empty Dense layer.
    hidden_units = max(channel // ratio, 1)

    # The two Dense layers are shared by the average- and max-pooled branches.
    shared_layer_one = Dense(hidden_units, activation='relu', kernel_initializer='he_normal', use_bias=True, bias_initializer='zeros')
    shared_layer_two = Dense(channel, kernel_initializer='he_normal', use_bias=True, bias_initializer='zeros')

    avg_pool = GlobalAveragePooling1D()(input_feature)
    avg_pool = shared_layer_one(avg_pool)
    avg_pool = shared_layer_two(avg_pool)

    max_pool = GlobalMaxPooling1D()(input_feature)
    max_pool = shared_layer_one(max_pool)
    max_pool = shared_layer_two(max_pool)

    cbam_feature = Add()([avg_pool, max_pool])
    cbam_feature = Activation('sigmoid')(cbam_feature)

    return Multiply()([input_feature, cbam_feature])

def build_model_D(input_shape, num_classes):
    """Build model D: three Conv1D + channel-attention stages with a sigmoid head.

    Args:
        input_shape: Shape of one sample, without the batch dimension.
        num_classes: Number of output units.

    Returns:
        An uncompiled Keras ``Model``.
    """
    inputs = Input(shape=input_shape)

    x = inputs
    for stage_filters in (64, 128, 256):
        x = Conv1D(stage_filters, kernel_size=3, padding='same', activation='relu')(x)
        x = channel_attention(x)

    x = GlobalAveragePooling1D()(x)
    # Sigmoid, not softmax: the targets are multi-hot vectors trained with
    # binary cross-entropy and thresholded per class, as in models A, B and C.
    # Softmax would force the class scores to sum to 1.
    outputs = Dense(num_classes, activation="sigmoid")(x)

    model = Model(inputs, outputs)
    return model

In [None]:
# == Initialize model ==

# Explicitly specify the GPU device

physical_devices = tf.config.experimental.list_physical_devices("GPU")
if physical_devices:
    # Enable on-demand memory growth on the first GPU.
    tf.config.experimental.set_memory_growth(physical_devices[0], True)

# Settings consumed by build_model_C; trailing comments list the values tried.
num_encoder_blocks = 8  # 1 8
positional_encoding = False  # True False
num_heads = 1  # 1 8
key_dim = 25  # 25
ff_dim = 24  # 24 2048
dropout = 0.1  # 0.1 0.4

# Full array shape including the sample axis; model builders receive input_shape[1:].
input_shape = X.shape

# Report the GPU count and the run configuration.
print("Number of GPUs available:", len(physical_devices))
summary_lines = [
    f"Number of training data examples: {len(X)}",
    f"Number of classes: {NUM_CLASSES}",
    f"Shape of training data: {input_shape}",
    f"Epochs: {EPOCHS}",
    f"Learning rate: {LEARNING_RATE}",
    f"Batch size: {BATCH_SIZE}",
]
print("\n".join(summary_lines))

In [None]:
# Start every run with empty output folders: delete any existing folder, then
# recreate it. Failures are reported but do not stop the notebook.
for output_dir in ("models", "test_outputs"):
    try:
        if os.path.exists(output_dir):
            shutil.rmtree(output_dir)
        os.makedirs(output_dir)
    except OSError as e:
        print(f"Error: {e.strerror}")

In [None]:
# == Save predictions util function ==

def save_predictions(pred, pred_prob, z_test):
    """Write one ``test_outputs/<id>.csv`` file per test recording.

    Each file has four lines: ``#<id>``, the comma-separated class codes, the
    comma-separated ``True``/``False`` decisions, and the comma-separated
    probabilities (no trailing newline).

    Args:
        pred: Per-recording sequences of binary decisions (1 = positive).
        pred_prob: Per-recording sequences of class probabilities.
        z_test: Record IDs, indexable in the same order as ``pred``.
    """
    # Create the target folder if needed so the function also works on its own.
    os.makedirs("test_outputs", exist_ok=True)

    # One progress bar only; the previous version opened two and never closed one.
    records = tqdm(
        zip(pred, pred_prob),
        total=len(pred),
        desc="Convert test_outputs",
        position=0,
        leave=True,
    )
    for index, (decisions, probabilities) in enumerate(records):
        record_id = z_test[index]
        # Header: the code string of every class index, in output order.
        code_line = ",".join(
            map_arrhyhtmia_index_to_id(class_index) for class_index in range(len(decisions))
        )
        decision_line = ",".join("True" if flag == 1 else "False" for flag in decisions)
        probability_line = ",".join(str(probability) for probability in probabilities)
        content = f"#{record_id}\n{code_line}\n{decision_line}\n{probability_line}"
        with open(f"test_outputs/{record_id}.csv", "w", encoding="utf-8") as file:
            file.write(content)

In [None]:
# == Plot distribution util function ==

def plot_distribution(Y):
    """Print and plot how often each class index is set in ``Y``.

    Args:
        Y: Iterable of multi-hot label vectors; every truthy position counts
            once for its index.

    The counts are printed, then drawn as a bar chart sorted by descending
    count with the count written above each bar. Nothing is plotted when no
    label is set.
    """
    label_counts = Counter(
        index for label in Y for index, flag in enumerate(label) if flag
    )
    print(label_counts)
    if not label_counts:
        # Empty input used to crash while unpacking zip(*combined).
        return

    # most_common() sorts by count, descending, keeping first-seen order for ties.
    ranked = label_counts.most_common()
    label_keys = [str(index) for index, _ in ranked]
    label_values = [count for _, count in ranked]

    plt.figure(figsize=(30, 10))
    # Draw the bars once (they were previously drawn twice on top of each other).
    bars = plt.bar(label_keys, label_values, color="#1f77b4")
    plt.title("Physionet 2021 labels")
    plt.xlabel("Arrhythmia type", labelpad=7)
    plt.ylabel("Occurence")
    plt.xticks(rotation=45, ha="right", fontsize=16)
    # Adding the counts on top of the bars
    for bar in bars:
        yval = bar.get_height()
        plt.text(
            bar.get_x() + bar.get_width() / 2,
            yval + 5,
            yval,
            ha="center",
            va="bottom",
            fontsize=16,
        )

    plt.show()
    plt.close()

In [None]:
# == Train model ==

# Suppress TensorFlow logging
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
# NOTE(review): TensorFlow is already imported by the time this cell runs, so the
# environment variable above may have no effect here — confirm, or set it before the import.
logging.getLogger('tensorflow').setLevel(logging.FATAL)

# Initialize KFold
n_folds = 10
kf = KFold(n_splits=n_folds, shuffle=True, random_state=42)
# Per-fold test metrics, appended in fold order.
metrics = {
    "accuracy": [],
    "precision": [],
    "recall": [],
    "f1": []
}

# Fraction of each training fold that Keras holds out for validation.
# This was 0.8, which left only 20% of the training fold for fitting.
VALIDATION_SPLIT = 0.2
# Probability above which a class counts as predicted.
threshold = 0.5

# EPOCHS comes from the hyperparameter cell; it is no longer redefined here, so
# a change in the configuration is not silently overridden.

# Cross-validation
for fold_counter, (train_index, test_index) in enumerate(kf.split(X), start=1):
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = Y[train_index], Y[test_index]
    z_train, z_test = Z[train_index], Z[test_index]

    callbacks = [
        # One checkpoint file per fold; a single shared file was overwritten by every fold.
        # NOTE(review): all three callbacks monitor the training loss although a
        # validation split is available — consider "val_loss".
        keras.callbacks.ModelCheckpoint(
            f"Model_{NUM_CLASSES}classes_fold{fold_counter}.h5", save_best_only=True, monitor="loss"
        ),
        keras.callbacks.ReduceLROnPlateau(monitor="loss", factor=0.5, patience=5, min_lr=0.000005),
        keras.callbacks.EarlyStopping(monitor="loss", patience=5, verbose=1),
    ]
    # A fresh optimizer and model per fold so no state leaks between folds.
    optimizer = Adam(learning_rate=LEARNING_RATE)
    model = build_model_A(num_classes=NUM_CLASSES, input_shape=input_shape[1:])
    # model = build_model_C(num_classes=NUM_CLASSES, input_shape=input_shape[1:], positional_encoding=positional_encoding, num_encoder_blocks=num_encoder_blocks, num_heads=num_heads, key_dim=key_dim, ff_dim=ff_dim, dropout=dropout)
    # model.summary()
    model.compile(optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"])
    history = model.fit(
        X_train,
        y_train,
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
        validation_split=VALIDATION_SPLIT,
        callbacks=callbacks,
        verbose=1,
    )

    model.save(f"models/{fold_counter}")

    # Predict probabilities, then binarize them with the threshold.
    pred_prob = model.predict(X_test)
    pred = (pred_prob > threshold).astype(int)

    # accuracy_score on multi-label arrays counts a sample as correct only when
    # every class matches; the other metrics are support-weighted class averages.
    accuracy = accuracy_score(y_test, pred)
    precision = precision_score(y_test, pred, average='weighted')
    recall = recall_score(y_test, pred, average='weighted')
    f1 = f1_score(y_test, pred, average='weighted')

    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1-score: {f1:.4f}")

    print("Save predictions...")
    save_predictions(pred, pred_prob, z_test)
    print("...finished saving.")

    metrics["accuracy"].append(accuracy)
    metrics["precision"].append(precision)
    metrics["recall"].append(recall)
    metrics["f1"].append(f1)

# train_accuracy = history.history["accuracy"] # list
# val_accuracy = history.history["val_accuracy"] # list
# train_loss = history.history["loss"] # list
# val_loss = history.history["val_loss"] # list

In [None]:
# Print the metrics for each fold
# for i in range(n_folds):
#    print(f"Fold {i+1} - Accuracy: {metrics['accuracy'][i]:.4f}, Precision: {metrics['precision'][i]:.4f}, Recall: {metrics['recall'][i]:.4f}, F1-score: {metrics['f1'][i]:.4f}")

# Calculate and print average metrics
# Mean of each metric over all folds collected in `metrics`.
avg_accuracy, avg_precision, avg_recall, avg_f1 = (
    np.mean(metrics[metric_name])
    for metric_name in ("accuracy", "precision", "recall", "f1")
)

print(f"Average - Accuracy: {avg_accuracy:.4f}, Precision: {avg_precision:.4f}, Recall: {avg_recall:.4f}, F1-score: {avg_f1:.4f}")

In [None]:
# Zip each output folder next to the notebook and hand the archive to the
# browser through the Colab `files` helper.
for folder_to_zip, output_filename in (
    ("models", "models.zip"),
    ("test_outputs", "test_outputs.zip"),
):
    # make_archive appends ".zip" itself, so it receives the name without the extension.
    shutil.make_archive(output_filename.replace('.zip', ''), 'zip', folder_to_zip)
    files.download(output_filename)