# Importing requirements

In [1]:
import os
import numpy as np
import pandas as pd
import wfdb
import ast
import tensorflow as tf
from tensorflow.keras import layers, models
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import classification_report
from collections import Counter
import time
from tqdm import tqdm


2024-12-03 03:25:37.447654: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1733176537.468291 3993000 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1733176537.474697 3993000 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-12-03 03:25:37.497775: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
import warnings
warnings.filterwarnings('ignore', category=UserWarning)

# Loading Data from dataset file

In [3]:
DATA_PATH = 'ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3'

ptbxl_df = pd.read_csv(os.path.join(DATA_PATH, 'ptbxl_database.csv'))
scp_statements = pd.read_csv(os.path.join(DATA_PATH, 'scp_statements.csv'), index_col=0)

diagnostic_scps = scp_statements[scp_statements['diagnostic'] == 1].index.values

scp_to_superclass = scp_statements['diagnostic_class'].to_dict()
scp_to_subclass = scp_statements['diagnostic_subclass'].to_dict()

In [4]:
ptbxl_df['scp_codes'] = ptbxl_df['scp_codes'].apply(lambda x: ast.literal_eval(x))

In [5]:
def aggregate_diagnostic_labels(df, scp_codes, scp_to_agg):
    df = df.copy()
    def aggregate_labels(scp_codes_dict):
        labels = set()
        for code in scp_codes_dict.keys():
            if code in scp_codes:
                label = scp_to_agg.get(code)
                if label:
                    labels.add(label)
        return list(labels)
    df['diagnostic_labels'] = df['scp_codes'].apply(aggregate_labels)
    return df

ptbxl_df = aggregate_diagnostic_labels(ptbxl_df, diagnostic_scps, scp_to_superclass)
ptbxl_df = ptbxl_df.rename(columns={'diagnostic_labels': 'superclass_labels'})

ptbxl_df = aggregate_diagnostic_labels(ptbxl_df, diagnostic_scps, scp_to_subclass)
ptbxl_df = ptbxl_df.rename(columns={'diagnostic_labels': 'subclass_labels'})

In [6]:
ptbxl_df = ptbxl_df[ptbxl_df['superclass_labels'].map(len) > 0]

In [7]:
train_df = ptbxl_df[ptbxl_df.strat_fold <= 8]
val_df = ptbxl_df[ptbxl_df.strat_fold == 9]
test_df = ptbxl_df[ptbxl_df.strat_fold == 10]

In [8]:
def load_data(df, sampling_rate, data_path):
    data = []
    i = 0
    if sampling_rate == 100:
        filenames = df['filename_lr'].values
    else:
        filenames = df['filename_hr'].values
    for filename in filenames:
        file_path = os.path.join(data_path, filename)
        signals, _ = wfdb.rdsamp(file_path)
        data.append(signals)
    return np.array(data)

X_train = load_data(train_df, sampling_rate=100, data_path=DATA_PATH)
X_val = load_data(val_df, sampling_rate=100, data_path=DATA_PATH)
X_test = load_data(test_df, sampling_rate=100, data_path=DATA_PATH)

In [9]:
train_labels_super = train_df['superclass_labels'].values
val_labels_super = val_df['superclass_labels'].values
test_labels_super = test_df['superclass_labels'].values

mlb_super = MultiLabelBinarizer()
y_train_super = mlb_super.fit_transform(train_labels_super)
y_val_super = mlb_super.transform(val_labels_super)
y_test_super = mlb_super.transform(test_labels_super)
classes_super = mlb_super.classes_

In [10]:
train_labels_sub = train_df['subclass_labels'].values
val_labels_sub = val_df['subclass_labels'].values
test_labels_sub = test_df['subclass_labels'].values

mlb_sub = MultiLabelBinarizer()
y_train_sub = mlb_sub.fit_transform(train_labels_sub)
y_val_sub = mlb_sub.transform(val_labels_sub)
y_test_sub = mlb_sub.transform(test_labels_sub)
classes_sub = mlb_sub.classes_

In [11]:
def normalize_data_per_channel(X):
    X = np.transpose(X, (0, 2, 1))
    mean = np.mean(X, axis=(0, 2), keepdims=True)
    std = np.std(X, axis=(0, 2), keepdims=True)
    X = (X - mean) / std
    X = np.transpose(X, (0, 2, 1))
    return X

X_train = normalize_data_per_channel(X_train)
X_val = normalize_data_per_channel(X_val)
X_test = normalize_data_per_channel(X_test)

In [12]:
class_counts_super = np.sum(y_train_super, axis=0)
total_samples_super = y_train_super.shape[0]

class_weight_super = {}
for i, count in enumerate(class_counts_super):
    class_weight_super[i] = total_samples_super / (len(class_counts_super) * count)

class_counts_sub = np.sum(y_train_sub, axis=0)
total_samples_sub = y_train_sub.shape[0]

class_weight_sub = {}
for i, count in enumerate(class_counts_sub):
    class_weight_sub[i] = total_samples_sub / (len(class_counts_sub) * count)

In [13]:
num_classes_super = y_train_super.shape[1]
class_totals = np.sum(y_train_super, axis=0)
class_weights = class_totals.max() / class_totals
weights_array = np.array(class_weights, dtype=np.float32)

In [14]:
num_classes_sub = y_train_sub.shape[1]
class_totals_sub = np.sum(y_train_sub, axis=0)
class_weights_sub = class_totals_sub.max() / class_totals_sub
weights_array_sub = np.array(class_weights_sub, dtype=np.float32)

In [15]:
y_train_super = y_train_super.astype(np.float32)
y_val_super = y_val_super.astype(np.float32)
y_test_super = y_test_super.astype(np.float32)

# Defining Entropy and Metrics

In [16]:
import tensorflow.keras.backend as K

def weighted_binary_crossentropy(weights):
    def loss(y_true, y_pred):
        weights_cast = K.cast(weights, y_pred.dtype)
        y_true = K.cast(y_true, y_pred.dtype)
        
        bce = K.binary_crossentropy(y_true, y_pred)
        weight_vector = y_true * weights_cast + (1 - y_true)
        weighted_bce = weight_vector * bce
        return K.mean(weighted_bce)
    return loss

def macro_f1(y_true, y_pred):
    y_true = K.cast(y_true, 'float32')
    y_pred = K.cast(y_pred, 'float32')
    y_pred = K.round(y_pred)
    
    tp = K.sum(y_true * y_pred, axis=0)
    fp = K.sum((1 - y_true) * y_pred, axis=0)
    fn = K.sum(y_true * (1 - y_pred), axis=0)

    precision = tp / (tp + fp + K.epsilon())
    recall = tp / (tp + fn + K.epsilon())

    f1 = 2 * precision * recall / (precision + recall + K.epsilon())
    f1 = tf.where(tf.math.is_nan(f1), tf.zeros_like(f1), f1)
    return K.mean(f1)

def weighted_f1(y_true, y_pred):
    y_true = K.cast(y_true, 'float32')
    y_pred = K.cast(y_pred, 'float32')
    y_pred = K.round(y_pred)
    tp = K.sum(y_true * y_pred, axis=0)
    fp = K.sum((1 - y_true) * y_pred, axis=0)
    fn = K.sum(y_true * (1 - y_pred), axis=0)
    support = K.sum(y_true, axis=0)
    precision = tp / (tp + fp + K.epsilon())
    recall = tp / (tp + fn + K.epsilon())
    f1 = 2 * precision * recall / (precision + recall + K.epsilon())
    weighted_f1 = K.sum(f1 * support) / K.sum(support)
    weighted_f1 = tf.where(tf.math.is_nan(weighted_f1), 0.0, weighted_f1)
    
    return weighted_f1

# Defining Models

In [17]:
def create_cnn_model(input_shape, num_classes):
    inputs = layers.Input(shape=input_shape)

    x = layers.Conv1D(64, kernel_size=7, padding='same', activation='relu')(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.MaxPooling1D(pool_size=2)(x)
    
    x = layers.Conv1D(128, kernel_size=5, padding='same', activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.MaxPooling1D(pool_size=2)(x)
    
    x = layers.Conv1D(256, kernel_size=3, padding='same', activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.MaxPooling1D(pool_size=2)(x)
    
    x = layers.Conv1D(512, kernel_size=3, padding='same', activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.GlobalAveragePooling1D()(x)
    
    x = layers.Dense(1024, activation='relu')(x)
    x = layers.Dropout(0.1)(x)
    outputs = layers.Dense(num_classes, activation='sigmoid')(x)
    
    model = models.Model(inputs, outputs)
    return model


In [18]:
# def create_resnet_model(input_shape, num_classes):
#     inputs = layers.Input(shape=input_shape)
#     x = layers.Conv1D(64, kernel_size=7, strides=2, padding='same')(inputs)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     x = layers.MaxPooling1D(pool_size=3, strides=2, padding='same')(x)
    
#     previous_filters = x.shape[-1]
#     for filters in [64, 128, 256]:
#         x_shortcut = x
#         strides = 1
#         if previous_filters != filters:
#             strides = 2

#         x = layers.Conv1D(filters, kernel_size=3, strides=strides, padding='same')(x)
#         x = layers.BatchNormalization()(x)
#         x = layers.Activation('relu')(x)
#         x = layers.Conv1D(filters, kernel_size=3, padding='same')(x)
#         x = layers.BatchNormalization()(x)
        
#         if previous_filters != filters or strides != 1:
#             x_shortcut = layers.Conv1D(filters, kernel_size=1, strides=strides, padding='same')(x_shortcut)
#             x_shortcut = layers.BatchNormalization()(x_shortcut)
        
#         x = layers.Add()([x, x_shortcut])
#         x = layers.Activation('relu')(x)
#         previous_filters = filters
#     x = layers.GlobalAveragePooling1D()(x)
#     outputs = layers.Dense(num_classes, activation='sigmoid')(x)
#     model = models.Model(inputs, outputs)
#     return model

In [19]:
def residual_block_1d(x, filters, kernel_size=3, strides=1, downsample=False):
    shortcut = x
    
    x = layers.Conv1D(filters, kernel_size=kernel_size, strides=strides, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)

    x = layers.Conv1D(filters, kernel_size=kernel_size, padding='same')(x)
    x = layers.BatchNormalization()(x)
    
    if downsample or shortcut.shape[-1] != filters:
        shortcut = layers.Conv1D(filters, kernel_size=1, strides=strides, padding='same')(shortcut)
        shortcut = layers.BatchNormalization()(shortcut)

    x = layers.Add()([x, shortcut])
    x = layers.Activation('relu')(x)
    return x

def create_resnet_model(input_shape, num_classes):
    inputs = layers.Input(shape=input_shape)
    x = layers.Conv1D(64, kernel_size=7, strides=2, padding='same')(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.MaxPooling1D(pool_size=3, strides=2, padding='same')(x)
    layers_filters = [64, 128, 256, 512]
    layers_blocks = [3, 4, 6, 3]

    for filters, num_blocks in zip(layers_filters, layers_blocks):
        for i in range(num_blocks):
            if i == 0 and filters != x.shape[-1]:
                x = residual_block_1d(x, filters, strides=2, downsample=True)
            else:
                x = residual_block_1d(x, filters)

    x = layers.GlobalAveragePooling1D()(x)
    outputs = layers.Dense(num_classes, activation='sigmoid')(x)
    model = models.Model(inputs, outputs)
    return model

In [20]:
def mlp(x, hidden_units, dropout_rate):
    for units in hidden_units:
        x = layers.Dense(units, activation=tf.nn.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x

def create_vit_model(input_shape, num_classes):
    patch_size = 10 
    num_patches = input_shape[0] // patch_size
    projection_dim = 64
    num_heads = 4
    transformer_layers = 8
    mlp_head_units = [256, 128]
    dropout_rate = 0.1

    inputs = layers.Input(shape=input_shape)
    x = layers.Reshape((num_patches, patch_size * input_shape[1]))(inputs)
    x = layers.Dense(units=projection_dim)(x)
    positions = tf.range(start=0, limit=num_patches, delta=1)
    position_embedding = layers.Embedding(input_dim=num_patches, output_dim=projection_dim)
    x = x + position_embedding(positions)
    for _ in range(transformer_layers):
        x1 = layers.LayerNormalization(epsilon=1e-6)(x)
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=dropout_rate
        )(x1, x1)
        x2 = layers.Add()([attention_output, x])
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        x3 = mlp(x3, hidden_units=[projection_dim * 2, projection_dim], dropout_rate=dropout_rate)
        x = layers.Add()([x3, x2])
    x = layers.LayerNormalization(epsilon=1e-6)(x)
    x = layers.Flatten()(x)
    x = layers.Dropout(dropout_rate)(x)
    outputs = layers.Dense(num_classes, activation='sigmoid')(x)
    model = models.Model(inputs=inputs, outputs=outputs)
    return model

# Defining the training loop

In [21]:
def train_model(model, X_train, y_train, X_val, y_val, class_weight, batch_size=64, epochs=25):
    model.compile(
        optimizer='adam',
        loss='binary_crossentropy',
        metrics=['accuracy', macro_f1, weighted_f1]
    )
    callbacks = [
        tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5),
        tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
    ]
    history = model.fit(
        X_train, y_train,
        validation_data=(X_val, y_val),
        epochs=epochs,
        batch_size=batch_size,
        callbacks=callbacks,
        class_weight=class_weight
    )
    return history

# Training and Evaluating Models without CL

In [22]:
input_shape = X_train.shape[1:]
num_classes_super = y_train_super.shape[1]

cnn_super_model = create_cnn_model(input_shape, num_classes_super)
train_model(cnn_super_model, X_train, y_train_super, X_val, y_val_super, class_weight_super)

I0000 00:00:1733176603.616395 3993000 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 31141 MB memory:  -> device: 0, name: Tesla V100-SXM2-32GB, pci bus id: 0000:06:00.0, compute capability: 7.0
I0000 00:00:1733176603.617574 3993000 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 31141 MB memory:  -> device: 1, name: Tesla V100-SXM2-32GB, pci bus id: 0000:07:00.0, compute capability: 7.0
I0000 00:00:1733176603.618481 3993000 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:2 with 31141 MB memory:  -> device: 2, name: Tesla V100-SXM2-32GB, pci bus id: 0000:0a:00.0, compute capability: 7.0
I0000 00:00:1733176603.619354 3993000 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:3 with 31141 MB memory:  -> device: 3, name: Tesla V100-SXM2-32GB, pci bus id: 0000:0b:00.0, compute capability: 7.0
I0000 00:00:1733176603.620257 3993000 gpu_device.cc:2022] Created de

Epoch 1/25


I0000 00:00:1733176610.181785 3993651 service.cc:148] XLA service 0x7fe1fc03ded0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1733176610.181822 3993651 service.cc:156]   StreamExecutor device (0): Tesla V100-SXM2-32GB, Compute Capability 7.0
I0000 00:00:1733176610.181843 3993651 service.cc:156]   StreamExecutor device (1): Tesla V100-SXM2-32GB, Compute Capability 7.0
I0000 00:00:1733176610.181848 3993651 service.cc:156]   StreamExecutor device (2): Tesla V100-SXM2-32GB, Compute Capability 7.0
I0000 00:00:1733176610.181850 3993651 service.cc:156]   StreamExecutor device (3): Tesla V100-SXM2-32GB, Compute Capability 7.0
I0000 00:00:1733176610.181869 3993651 service.cc:156]   StreamExecutor device (4): Tesla V100-SXM2-32GB, Compute Capability 7.0
I0000 00:00:1733176610.181872 3993651 service.cc:156]   StreamExecutor device (5): Tesla V100-SXM2-32GB, Compute Capability 7.0
I0000 00:00:1733176610.181876 3993651 service.cc:156]   StreamE

[1m 10/267[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m4s[0m 17ms/step - accuracy: 0.3684 - loss: 0.4607 - macro_f1: 0.4481 - weighted_f1: 0.4830

I0000 00:00:1733176614.954350 3993651 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m265/267[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 17ms/step - accuracy: 0.5982 - loss: 0.3098 - macro_f1: 0.6185 - weighted_f1: 0.6538

E0000 00:00:1733176620.712555 3993649 gpu_timer.cc:82] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
E0000 00:00:1733176621.006032 3993649 gpu_timer.cc:82] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.


[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m19s[0m 44ms/step - accuracy: 0.5987 - loss: 0.3094 - macro_f1: 0.6190 - weighted_f1: 0.6543 - val_accuracy: 0.6654 - val_loss: 0.3301 - val_macro_f1: 0.6820 - val_weighted_f1: 0.7177 - learning_rate: 0.0010
Epoch 2/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 19ms/step - accuracy: 0.6832 - loss: 0.2411 - macro_f1: 0.7032 - weighted_f1: 0.7390 - val_accuracy: 0.6566 - val_loss: 0.3517 - val_macro_f1: 0.6529 - val_weighted_f1: 0.6995 - learning_rate: 0.0010
Epoch 3/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 18ms/step - accuracy: 0.7026 - loss: 0.2278 - macro_f1: 0.7340 - weighted_f1: 0.7645 - val_accuracy: 0.6719 - val_loss: 0.3504 - val_macro_f1: 0.6656 - val_weighted_f1: 0.7078 - learning_rate: 0.0010
Epoch 4/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 18ms/step - accuracy: 0.7086 - loss: 0.2125 - macro_f1: 0.7474 - weighted_f1: 0.7775 - val_accuracy: 0

<keras.src.callbacks.history.History at 0x7fe6bba34910>

In [23]:
resnet_super_model = create_resnet_model(input_shape, num_classes_super)
train_model(resnet_super_model, X_train, y_train_super, X_val, y_val_super, class_weight_super)

Epoch 1/25
[1m265/267[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 31ms/step - accuracy: 0.5538 - loss: 0.3813 - macro_f1: 0.5263 - weighted_f1: 0.5638

E0000 00:00:1733176763.685155 3993646 gpu_timer.cc:82] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
E0000 00:00:1733176763.913225 3993646 gpu_timer.cc:82] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.


[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m58s[0m 98ms/step - accuracy: 0.5545 - loss: 0.3805 - macro_f1: 0.5273 - weighted_f1: 0.5647 - val_accuracy: 0.4744 - val_loss: 0.5546 - val_macro_f1: 0.5420 - val_weighted_f1: 0.5600 - learning_rate: 0.0010
Epoch 2/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 32ms/step - accuracy: 0.6738 - loss: 0.2569 - macro_f1: 0.6835 - weighted_f1: 0.7178 - val_accuracy: 0.5955 - val_loss: 0.4619 - val_macro_f1: 0.5937 - val_weighted_f1: 0.6360 - learning_rate: 0.0010
Epoch 3/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 33ms/step - accuracy: 0.6900 - loss: 0.2465 - macro_f1: 0.7079 - weighted_f1: 0.7365 - val_accuracy: 0.6314 - val_loss: 0.3758 - val_macro_f1: 0.6524 - val_weighted_f1: 0.6813 - learning_rate: 0.0010
Epoch 4/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 34ms/step - accuracy: 0.7050 - loss: 0.2326 - macro_f1: 0.7208 - weighted_f1: 0.7570 - val_accuracy: 0

<keras.src.callbacks.history.History at 0x7fe54033e640>

In [24]:
vit_super_model = create_vit_model(input_shape, num_classes_super)
train_model(vit_super_model, X_train, y_train_super, X_val, y_val_super, class_weight_super)

Epoch 1/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m73s[0m 118ms/step - accuracy: 0.4160 - loss: 0.4427 - macro_f1: 0.4026 - weighted_f1: 0.4400 - val_accuracy: 0.5857 - val_loss: 0.3888 - val_macro_f1: 0.5262 - val_weighted_f1: 0.6025 - learning_rate: 0.0010
Epoch 2/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 24ms/step - accuracy: 0.6201 - loss: 0.2971 - macro_f1: 0.6167 - weighted_f1: 0.6531 - val_accuracy: 0.6617 - val_loss: 0.3411 - val_macro_f1: 0.6518 - val_weighted_f1: 0.6993 - learning_rate: 0.0010
Epoch 3/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 24ms/step - accuracy: 0.6700 - loss: 0.2532 - macro_f1: 0.6917 - weighted_f1: 0.7188 - val_accuracy: 0.6193 - val_loss: 0.3404 - val_macro_f1: 0.6591 - val_weighted_f1: 0.7015 - learning_rate: 0.0010
Epoch 4/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 24ms/step - accuracy: 0.6949 - loss: 0.2336 - macro_f1: 0.7176 - weighted_f1: 0.7494 - val

<keras.src.callbacks.history.History at 0x7fe4900da280>

In [25]:
num_classes_sub = y_train_sub.shape[1]
cnn_sub_model = create_cnn_model(input_shape, num_classes_sub)
train_model(cnn_sub_model, X_train, y_train_sub, X_val, y_val_sub, class_weight_sub)

Epoch 1/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 38ms/step - accuracy: 0.3719 - loss: 0.1269 - macro_f1: 0.0936 - weighted_f1: 0.2058 - val_accuracy: 0.4804 - val_loss: 0.1469 - val_macro_f1: 0.1673 - val_weighted_f1: 0.3670 - learning_rate: 0.0010
Epoch 2/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 18ms/step - accuracy: 0.4881 - loss: 0.0759 - macro_f1: 0.2070 - weighted_f1: 0.3929 - val_accuracy: 0.4832 - val_loss: 0.1427 - val_macro_f1: 0.1630 - val_weighted_f1: 0.3839 - learning_rate: 0.0010
Epoch 3/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 18ms/step - accuracy: 0.5141 - loss: 0.0717 - macro_f1: 0.2357 - weighted_f1: 0.4345 - val_accuracy: 0.4338 - val_loss: 0.1518 - val_macro_f1: 0.1888 - val_weighted_f1: 0.3482 - learning_rate: 0.0010
Epoch 4/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 19ms/step - accuracy: 0.5146 - loss: 0.0729 - macro_f1: 0.2523 - weighted_f1: 0.4636 - val_

<keras.src.callbacks.history.History at 0x7fe3482446a0>

In [26]:
resnet_sub_model = create_resnet_model(input_shape, num_classes_sub)
train_model(resnet_sub_model, X_train, y_train_sub, X_val, y_val_sub, class_weight_sub)

Epoch 1/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m55s[0m 92ms/step - accuracy: 0.2457 - loss: 0.1394 - macro_f1: 0.0542 - weighted_f1: 0.0637 - val_accuracy: 0.1058 - val_loss: 0.2291 - val_macro_f1: 0.1014 - val_weighted_f1: 0.1006 - learning_rate: 0.0010
Epoch 2/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 35ms/step - accuracy: 0.3928 - loss: 0.1108 - macro_f1: 0.0942 - weighted_f1: 0.1543 - val_accuracy: 0.3453 - val_loss: 0.2005 - val_macro_f1: 0.0238 - val_weighted_f1: 0.0315 - learning_rate: 0.0010
Epoch 3/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 34ms/step - accuracy: 0.4294 - loss: 0.0975 - macro_f1: 0.1229 - weighted_f1: 0.2275 - val_accuracy: 0.4497 - val_loss: 0.1600 - val_macro_f1: 0.1370 - val_weighted_f1: 0.3259 - learning_rate: 0.0010
Epoch 4/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 34ms/step - accuracy: 0.4540 - loss: 0.0895 - macro_f1: 0.1589 - weighted_f1: 0.2893 - val_

<keras.src.callbacks.history.History at 0x7fe32006fc40>

In [27]:
vit_sub_model = create_vit_model(input_shape, num_classes_sub)
train_model(vit_sub_model, X_train, y_train_sub, X_val, y_val_sub, class_weight_sub)

Epoch 1/25


[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m73s[0m 118ms/step - accuracy: 0.1426 - loss: 0.1779 - macro_f1: 0.0390 - weighted_f1: 0.0660 - val_accuracy: 0.2861 - val_loss: 0.1798 - val_macro_f1: 0.0691 - val_weighted_f1: 0.0771 - learning_rate: 0.0010
Epoch 2/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 24ms/step - accuracy: 0.3543 - loss: 0.0956 - macro_f1: 0.1190 - weighted_f1: 0.2114 - val_accuracy: 0.2074 - val_loss: 0.1901 - val_macro_f1: 0.1105 - val_weighted_f1: 0.1068 - learning_rate: 0.0010
Epoch 3/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 24ms/step - accuracy: 0.4358 - loss: 0.0729 - macro_f1: 0.1987 - weighted_f1: 0.3284 - val_accuracy: 0.4432 - val_loss: 0.1587 - val_macro_f1: 0.1704 - val_weighted_f1: 0.3453 - learning_rate: 0.0010
Epoch 4/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 23ms/step - accuracy: 0.5214 - loss: 0.0601 - macro_f1: 0.2738 - weighted_f1: 0.4634 - val_accuracy: 

<keras.src.callbacks.history.History at 0x7fe230434d30>

In [28]:
def evaluate_model(model, X_test, y_test, classes):
    y_pred = model.predict(X_test)
    y_pred_threshold = (y_pred >= 0.5).astype(int)
    report = classification_report(y_test, y_pred_threshold, target_names=classes, zero_division=0, output_dict=True)
    print(classification_report(y_test, y_pred_threshold, target_names=classes, zero_division=0))
    return report


In [29]:
print("CNN Superdiagnostic Classification Report:")
cnn_super_report = evaluate_model(cnn_super_model, X_test, y_test_super, classes_super)

print("ResNet Superdiagnostic Classification Report:")
resnet_super_report = evaluate_model(resnet_super_model, X_test, y_test_super, classes_super)

print("ViT Superdiagnostic Classification Report:")
vit_super_report = evaluate_model(vit_super_model, X_test, y_test_super, classes_super)


CNN Superdiagnostic Classification Report:
[1m68/68[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 12ms/step
              precision    recall  f1-score   support

          CD       0.81      0.66      0.73       496
         HYP       0.64      0.60      0.62       262
          MI       0.82      0.63      0.71       550
        NORM       0.82      0.91      0.86       963
        STTC       0.79      0.70      0.74       521

   micro avg       0.79      0.75      0.77      2792
   macro avg       0.77      0.70      0.73      2792
weighted avg       0.79      0.75      0.76      2792
 samples avg       0.78      0.77      0.76      2792

ResNet Superdiagnostic Classification Report:
[1m68/68[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 42ms/step
              precision    recall  f1-score   support

          CD       0.84      0.68      0.75       496
         HYP       0.76      0.43      0.55       262
          MI       0.78      0.72      0.75       550
   

In [30]:
print("CNN Subdiagnostic Classification Report:")
cnn_sub_report = evaluate_model(cnn_sub_model, X_test, y_test_sub, classes_sub)

print("ResNet Subdiagnostic Classification Report:")
resnet_sub_report = evaluate_model(resnet_sub_model, X_test, y_test_sub, classes_sub)

print("ViT Subdiagnostic Classification Report:")
vit_sub_report = evaluate_model(vit_sub_model, X_test, y_test_sub, classes_sub)


CNN Subdiagnostic Classification Report:
[1m68/68[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 10ms/step
              precision    recall  f1-score   support

         AMI       0.82      0.57      0.67       306
       CLBBB       0.86      0.89      0.87        54
       CRBBB       0.81      0.85      0.83        54
       ILBBB       0.08      0.12      0.10         8
         IMI       0.75      0.52      0.61       327
       IRBBB       0.58      0.71      0.64       112
        ISCA       0.63      0.20      0.31        93
        ISCI       0.45      0.25      0.32        40
        ISC_       0.76      0.41      0.53       128
        IVCD       0.11      0.03      0.04        79
   LAFB/LPFB       0.82      0.63      0.71       179
     LAO/LAE       0.20      0.02      0.04        42
         LMI       0.18      0.10      0.13        20
         LVH       0.74      0.56      0.64       214
        NORM       0.87      0.74      0.80       963
        NST_       0

# Defining and Training on LwF

In [31]:
cnn_soft_targets_super = cnn_super_model.predict(X_train)

def lwf_loss(y_true, y_pred, old_predictions, T=2):
    task_loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
    dist_loss = tf.keras.losses.KLDivergence()(tf.nn.softmax(old_predictions / T),
                                               tf.nn.softmax(y_pred / T))
    total_loss = task_loss + dist_loss
    return total_loss

print("Working on CNN for LwF Now:")
cnn_model_lwf = create_cnn_model(input_shape, num_classes_sub)
cnn_model_lwf.compile(
    optimizer='adam',
    loss=lambda y_true, y_pred: lwf_loss(y_true, y_pred, old_predictions=cnn_soft_targets_super),
    metrics=[macro_f1, weighted_f1]
)
train_model(cnn_model_lwf, X_train, y_train_sub, X_val, y_val_sub, class_weight_sub)

print("Working on ResNet for LwF Now:")
resnet_soft_targets_super = resnet_super_model.predict(X_train)
resnet_model_lwf = create_resnet_model(input_shape, num_classes_sub)
resnet_model_lwf.compile(
    optimizer='adam',
    loss=lambda y_true, y_pred: lwf_loss(y_true, y_pred, old_predictions=resnet_soft_targets_super),
    metrics=[macro_f1, weighted_f1]
)
train_model(resnet_model_lwf, X_train, y_train_sub, X_val, y_val_sub, class_weight_sub)

print("Working on ViT for LwF Now:")
vit_soft_targets_super = vit_super_model.predict(X_train)
vit_model_lwf = create_vit_model(input_shape, num_classes_sub)
vit_model_lwf.compile(
    optimizer='adam',
    loss=lambda y_true, y_pred: lwf_loss(y_true, y_pred, old_predictions=vit_soft_targets_super),
    metrics=[macro_f1, weighted_f1]
)
train_model(vit_model_lwf, X_train, y_train_sub, X_val, y_val_sub, class_weight_sub)


[1m534/534[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 3ms/step
Working on CNN for LwF Now:
Epoch 1/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 38ms/step - accuracy: 0.3383 - loss: 0.1227 - macro_f1: 0.1120 - weighted_f1: 0.2137 - val_accuracy: 0.4776 - val_loss: 0.1444 - val_macro_f1: 0.1634 - val_weighted_f1: 0.3630 - learning_rate: 0.0010
Epoch 2/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 18ms/step - accuracy: 0.4732 - loss: 0.0866 - macro_f1: 0.1967 - weighted_f1: 0.3642 - val_accuracy: 0.4786 - val_loss: 0.1414 - val_macro_f1: 0.2297 - val_weighted_f1: 0.4583 - learning_rate: 0.0010
Epoch 3/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 18ms/step - accuracy: 0.5161 - loss: 0.0780 - macro_f1: 0.2329 - weighted_f1: 0.4590 - val_accuracy: 0.4888 - val_loss: 0.1333 - val_macro_f1: 0.2242 - val_weighted_f1: 0.4177 - learning_rate: 0.0010
Epoch 4/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m

<keras.src.callbacks.history.History at 0x7fda23178280>

# Defining and Training on EwC

In [32]:
class EWC:
    def __init__(self, model, X, y, batch_size=32, exclude_params=[]):
        self.model = model
        self.params = {}
        for p in model.trainable_variables:
            if id(p) not in exclude_params:
                self.params[id(p)] = p.numpy()
        self.fisher = self.compute_fisher(X, y, batch_size, exclude_params)

    def compute_fisher(self, X, y, batch_size, exclude_params):
        fisher = {}
        num_samples = X.shape[0]
        num_batches = int(np.ceil(num_samples / batch_size))

        for batch_idx in range(num_batches):
            X_batch = X[batch_idx*batch_size:(batch_idx+1)*batch_size]
            y_batch = y[batch_idx*batch_size:(batch_idx+1)*batch_size]
            with tf.GradientTape() as tape:
                preds = self.model(X_batch)
                loss = tf.reduce_mean(tf.keras.losses.binary_crossentropy(y_batch, preds))
            grads = tape.gradient(loss, self.model.trainable_variables)
            for p, g in zip(self.model.trainable_variables, grads):
                if g is not None and id(p) not in exclude_params:
                    param_id = id(p)
                    if param_id not in fisher:
                        fisher[param_id] = np.square(g.numpy())
                    else:
                        fisher[param_id] += np.square(g.numpy())
        for k in fisher.keys():
            fisher[k] /= num_batches
        return fisher

    def penalty(self, model):
        loss = 0
        for p in model.trainable_variables:
            param_id = id(p)
            if param_id in self.fisher:
                fisher = tf.convert_to_tensor(self.fisher[param_id])
                loss += tf.reduce_sum(fisher * tf.square(p - self.params[param_id]))
        return loss

In [33]:
def modify_model_for_subdiagnostic(base_model, num_classes_sub):
    inputs = base_model.input
    x = inputs
    for layer in base_model.layers[1:-1]:
        x = layer(x)
    outputs = layers.Dense(num_classes_sub, activation='sigmoid', name='output_sub')(x)
    new_model = models.Model(inputs=inputs, outputs=outputs)
    return new_model

In [34]:
lambda_ewc = 1000
cnn_sub_model = modify_model_for_subdiagnostic(cnn_super_model, num_classes_sub)
exclude_params_cnn = [id(w) for w in cnn_sub_model.layers[-1].trainable_weights]
ewc_cnn = EWC(cnn_super_model, X_train, y_train_super, exclude_params=exclude_params_cnn)

def ewc_loss_cnn(y_true, y_pred):
    task_loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
    ewc_penalty = ewc_cnn.penalty(cnn_sub_model)
    total_loss = task_loss + (lambda_ewc / 2) * ewc_penalty
    return total_loss

cnn_sub_model.compile(
    optimizer='adam',
    loss=ewc_loss_cnn,
    metrics=[macro_f1, weighted_f1]
)

train_model(cnn_sub_model, X_train, y_train_sub, X_val, y_val_sub, class_weight_sub)

Epoch 1/25


[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 38ms/step - accuracy: 0.5276 - loss: 0.1146 - macro_f1: 0.2026 - weighted_f1: 0.4466 - val_accuracy: 0.5718 - val_loss: 0.1227 - val_macro_f1: 0.2635 - val_weighted_f1: 0.5418 - learning_rate: 0.0010
Epoch 2/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 18ms/step - accuracy: 0.6050 - loss: 0.0623 - macro_f1: 0.3203 - weighted_f1: 0.5937 - val_accuracy: 0.5801 - val_loss: 0.1207 - val_macro_f1: 0.2876 - val_weighted_f1: 0.5552 - learning_rate: 0.0010
Epoch 3/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 18ms/step - accuracy: 0.6120 - loss: 0.0575 - macro_f1: 0.3417 - weighted_f1: 0.6026 - val_accuracy: 0.6109 - val_loss: 0.1070 - val_macro_f1: 0.3139 - val_weighted_f1: 0.6064 - learning_rate: 0.0010
Epoch 4/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 18ms/step - accuracy: 0.6372 - loss: 0.0514 - macro_f1: 0.3578 - weighted_f1: 0.6384 - val_accuracy: 0

<keras.src.callbacks.history.History at 0x7fe210a03040>

In [35]:
def modify_model_for_subdiagnostic_resnet(base_model, num_classes_sub):
    x = base_model.layers[-2].output
    outputs = layers.Dense(num_classes_sub, activation='sigmoid', name='output_sub')(x)
    new_model = tf.keras.Model(inputs=base_model.input, outputs=outputs)
    return new_model

In [36]:
num_classes_sub = y_train_sub.shape[1]
resnet_sub_model = modify_model_for_subdiagnostic_resnet(resnet_super_model, num_classes_sub)
exclude_params_resnet = [w.name for w in resnet_sub_model.layers[-1].trainable_weights]
def ewc_loss_resnet(y_true, y_pred):
    task_loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
    ewc_penalty = ewc_resnet.penalty(resnet_sub_model)
    total_loss = task_loss + (lambda_ewc / 2) * ewc_penalty
    return total_loss

resnet_sub_model.compile(
    optimizer='adam',
    loss=ewc_loss_resnet,
    metrics=[macro_f1, weighted_f1]
)

train_model(resnet_sub_model, X_train, y_train_sub, X_val, y_val_sub, class_weight_sub)


Epoch 1/25


[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m55s[0m 93ms/step - accuracy: 0.4240 - loss: 0.1152 - macro_f1: 0.1485 - weighted_f1: 0.3050 - val_accuracy: 0.4744 - val_loss: 0.1427 - val_macro_f1: 0.2107 - val_weighted_f1: 0.4426 - learning_rate: 0.0010
Epoch 2/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 34ms/step - accuracy: 0.5310 - loss: 0.0766 - macro_f1: 0.2384 - weighted_f1: 0.4613 - val_accuracy: 0.4264 - val_loss: 0.1474 - val_macro_f1: 0.1707 - val_weighted_f1: 0.3584 - learning_rate: 0.0010
Epoch 3/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 34ms/step - accuracy: 0.5423 - loss: 0.0733 - macro_f1: 0.2413 - weighted_f1: 0.4751 - val_accuracy: 0.5550 - val_loss: 0.1356 - val_macro_f1: 0.2428 - val_weighted_f1: 0.5199 - learning_rate: 0.0010
Epoch 4/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 34ms/step - accuracy: 0.5764 - loss: 0.0689 - macro_f1: 0.2624 - weighted_f1: 0.5109 - val_accuracy: 0

<keras.src.callbacks.history.History at 0x7fd8b6161e50>

In [37]:
def modify_model_for_subdiagnostic_vit(base_model, num_classes_sub):
    # Get the output of the layer before the last
    x = base_model.layers[-2].output  # Exclude the last layer
    # Add new output layer for subdiagnostic task
    outputs = layers.Dense(num_classes_sub, activation='sigmoid', name='output_sub')(x)
    # Create new model
    new_model = tf.keras.Model(inputs=base_model.input, outputs=outputs)
    return new_model


In [38]:
# Modify ViT model for subdiagnostic task
num_classes_sub = y_train_sub.shape[1]
vit_sub_model = modify_model_for_subdiagnostic_vit(vit_super_model, num_classes_sub)

# Exclude the new output layer's parameters from EWC or SI calculations
exclude_params_vit = [w.name for w in vit_sub_model.layers[-1].trainable_weights]

def ewc_loss_vit(y_true, y_pred):
    task_loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
    ewc_penalty = ewc_vit.penalty(vit_sub_model)
    total_loss = task_loss + (lambda_ewc / 2) * ewc_penalty
    return total_loss

vit_sub_model.compile(
    optimizer='adam',
    loss=ewc_loss_vit,
    metrics=[macro_f1, weighted_f1]
)

train_model(vit_sub_model, X_train, y_train_sub, X_val, y_val_sub, class_weight_sub)

Epoch 1/25
[1m264/267[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 22ms/step - accuracy: 0.4485 - loss: 0.1181 - macro_f1: 0.1404 - weighted_f1: 0.3141

[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m71s[0m 114ms/step - accuracy: 0.4491 - loss: 0.1178 - macro_f1: 0.1409 - weighted_f1: 0.3149 - val_accuracy: 0.5191 - val_loss: 0.1394 - val_macro_f1: 0.2031 - val_weighted_f1: 0.4773 - learning_rate: 0.0010
Epoch 2/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 24ms/step - accuracy: 0.5377 - loss: 0.0708 - macro_f1: 0.2530 - weighted_f1: 0.4652 - val_accuracy: 0.5214 - val_loss: 0.1376 - val_macro_f1: 0.2281 - val_weighted_f1: 0.4723 - learning_rate: 0.0010
Epoch 3/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 24ms/step - accuracy: 0.6029 - loss: 0.0539 - macro_f1: 0.3121 - weighted_f1: 0.5650 - val_accuracy: 0.5769 - val_loss: 0.1262 - val_macro_f1: 0.2494 - val_weighted_f1: 0.5596 - learning_rate: 0.0010
Epoch 4/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 24ms/step - accuracy: 0.6274 - loss: 0.0422 - macro_f1: 0.3718 - weighted_f1: 0.6300 - val_accuracy: 

<keras.src.callbacks.history.History at 0x7fd8747d93d0>

# Defining and Training on SI

In [39]:
class SI:
    def __init__(self, prev_model, damping_factor=0.1, exclude_params=[]):
        self.prev_params = {}
        self.omega = {}
        self.damping_factor = damping_factor
        self.exclude_params = exclude_params

        self.delta_params = {}

        # Store parameters from the previous model (superdiagnostic task)
        for var in prev_model.trainable_variables:
            if var.name not in self.exclude_params:
                self.prev_params[var.name] = var.numpy().copy()
                self.omega[var.name] = np.zeros_like(var.numpy())
                self.delta_params[var.name] = np.zeros_like(var.numpy())

    def accumulate_importance(self, model, grads):
        for var, grad in zip(model.trainable_variables, grads):
            if grad is not None and var.name in self.prev_params:
                if var.shape == self.prev_params[var.name].shape:
                    delta_theta = var.numpy() - self.prev_params[var.name]
                    self.delta_params[var.name] += delta_theta
                    # Update omega with absolute value to prevent negative importance
                    self.omega[var.name] += np.abs(grad.numpy() * delta_theta)
                else:
                    # Skip variables with mismatched shapes
                    pass

    def update_omega(self):
        # Normalize omega after training
        for var_name in self.omega.keys():
            delta_param = self.delta_params[var_name]
            denom = np.square(delta_param) + self.damping_factor
            self.omega[var_name] = np.divide(self.omega[var_name], denom)
            # Ensure omega is non-negative
            self.omega[var_name] = np.abs(self.omega[var_name])
            # Reset delta_params for the next task
            self.delta_params[var_name] = np.zeros_like(delta_param)

    def penalty(self, model):
        loss = 0
        for var in model.trainable_variables:
            if var.name in self.prev_params:
                prev_param = self.prev_params[var.name]
                if var.shape == prev_param.shape:
                    omega = tf.convert_to_tensor(self.omega[var.name], dtype=var.dtype)
                    prev_param = tf.convert_to_tensor(prev_param, dtype=var.dtype)
                    # Ensure omega is non-negative
                    loss += tf.reduce_sum(omega * tf.square(var - prev_param))
                else:
                    # Skip variables with mismatched shapes
                    pass
        return loss


In [40]:
num_classes_sub = y_train_sub.shape[1]
cnn_sub_model = modify_model_for_subdiagnostic(cnn_super_model, num_classes_sub)
exclude_params_cnn = [w.name for w in cnn_sub_model.layers[-1].trainable_weights]
si_cnn = SI(cnn_super_model, exclude_params=exclude_params_cnn)


In [41]:
lambda_si = 1.0  # Adjust as needed
epochs = 25
batch_size = 64
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

train_macro_f1 = tf.keras.metrics.Mean(name='train_macro_f1')
train_loss = tf.keras.metrics.Mean(name='train_loss')
val_macro_f1 = tf.keras.metrics.Mean(name='val_macro_f1')
val_loss = tf.keras.metrics.Mean(name='val_loss')

for epoch in range(epochs):
    start_time = time.time()
    print(f'\nCNN Epoch {epoch+1}/{epochs}')
    train_macro_f1.reset_state()
    train_loss.reset_state()

    num_batches = len(X_train) // batch_size
    progress_bar = tqdm(range(num_batches), desc='Training', leave=False)

    for step in progress_bar:
        X_batch = X_train[step*batch_size:(step+1)*batch_size]
        y_batch = y_train_sub[step*batch_size:(step+1)*batch_size]

        with tf.GradientTape() as tape:
            preds = cnn_sub_model(X_batch, training=True)
            task_loss = tf.reduce_mean(tf.keras.losses.binary_crossentropy(y_batch, preds))
            si_penalty = si_cnn.penalty(cnn_sub_model)
            total_loss = task_loss + (lambda_si / 2) * si_penalty

        grads = tape.gradient(total_loss, cnn_sub_model.trainable_variables)
        optimizer.apply_gradients(zip(grads, cnn_sub_model.trainable_variables))
        si_cnn.accumulate_importance(cnn_sub_model, grads)

        batch_macro_f1 = macro_f1(y_batch, preds)
        train_macro_f1.update_state(batch_macro_f1)
        train_loss.update_state(total_loss)

        progress_bar.set_postfix({'loss': train_loss.result().numpy(), 'macro_f1': train_macro_f1.result().numpy()})

    epoch_time = time.time() - start_time

    # Validation
    val_macro_f1.reset_state()
    val_loss.reset_state()
    val_batches = len(X_val) // batch_size
    val_progress_bar = tqdm(range(val_batches), desc='Validation', leave=False)
    for step in val_progress_bar:
        X_batch = X_val[step*batch_size:(step+1)*batch_size]
        y_batch = y_val_sub[step*batch_size:(step+1)*batch_size]
        preds = cnn_sub_model(X_batch, training=False)
        task_loss = tf.reduce_mean(tf.keras.losses.binary_crossentropy(y_batch, preds))
        total_loss = task_loss

        batch_macro_f1 = macro_f1(y_batch, preds)
        val_macro_f1.update_state(batch_macro_f1)
        val_loss.update_state(total_loss)

        val_progress_bar.set_postfix({'val_loss': val_loss.result().numpy(), 'val_macro_f1': val_macro_f1.result().numpy()})

    print(f'Epoch {epoch+1}/{epochs}, '
          f'Time: {epoch_time:.2f}s, '
          f'Loss: {train_loss.result():.4f}, '
          f'Macro F1: {train_macro_f1.result():.4f}, '
          f'Val Loss: {val_loss.result():.4f}, '
          f'Val Macro F1: {val_macro_f1.result():.4f}')

# After training, update omega
si_cnn.update_omega()



CNN Epoch 1/25


Training:  50%|█████     | 133/266 [00:23<00:22,  6.02it/s, loss=0.118, macro_f1=0.26] 

                                                                                                

Epoch 1/25, Time: 45.33s, Loss: 0.1034, Macro F1: 0.3007, Val Loss: 0.0991, Val Macro F1: 0.3018

CNN Epoch 2/25


                                                                                                

Epoch 2/25, Time: 42.52s, Loss: 0.0801, Macro F1: 0.3638, Val Loss: 0.0985, Val Macro F1: 0.3144

CNN Epoch 3/25


                                                                                                

Epoch 3/25, Time: 42.80s, Loss: 0.0727, Macro F1: 0.3947, Val Loss: 0.1003, Val Macro F1: 0.3161

CNN Epoch 4/25


                                                                                               

Epoch 4/25, Time: 43.61s, Loss: 0.0659, Macro F1: 0.4205, Val Loss: 0.1042, Val Macro F1: 0.3126

CNN Epoch 5/25


                                                                                               

Epoch 5/25, Time: 43.50s, Loss: 0.0591, Macro F1: 0.4541, Val Loss: 0.1095, Val Macro F1: 0.3195

CNN Epoch 6/25


                                                                                               

Epoch 6/25, Time: 43.88s, Loss: 0.0518, Macro F1: 0.4919, Val Loss: 0.1178, Val Macro F1: 0.3075

CNN Epoch 7/25


                                                                                               

Epoch 7/25, Time: 43.54s, Loss: 0.0445, Macro F1: 0.5265, Val Loss: 0.1277, Val Macro F1: 0.3093

CNN Epoch 8/25


                                                                                               

Epoch 8/25, Time: 42.67s, Loss: 0.0379, Macro F1: 0.5556, Val Loss: 0.1361, Val Macro F1: 0.3146

CNN Epoch 9/25


                                                                                               

Epoch 9/25, Time: 43.02s, Loss: 0.0322, Macro F1: 0.5781, Val Loss: 0.1476, Val Macro F1: 0.3201

CNN Epoch 10/25


                                                                                               

Epoch 10/25, Time: 43.81s, Loss: 0.0288, Macro F1: 0.5936, Val Loss: 0.1589, Val Macro F1: 0.3214

CNN Epoch 11/25


                                                                                               

Epoch 11/25, Time: 44.23s, Loss: 0.0257, Macro F1: 0.6034, Val Loss: 0.1705, Val Macro F1: 0.3150

CNN Epoch 12/25


                                                                                               

Epoch 12/25, Time: 43.91s, Loss: 0.0223, Macro F1: 0.6178, Val Loss: 0.1821, Val Macro F1: 0.3142

CNN Epoch 13/25


                                                                                               

Epoch 13/25, Time: 42.51s, Loss: 0.0194, Macro F1: 0.6292, Val Loss: 0.1887, Val Macro F1: 0.3105

CNN Epoch 14/25


                                                                                               

Epoch 14/25, Time: 42.09s, Loss: 0.0171, Macro F1: 0.6389, Val Loss: 0.2029, Val Macro F1: 0.3135

CNN Epoch 15/25


                                                                                               

Epoch 15/25, Time: 43.45s, Loss: 0.0147, Macro F1: 0.6496, Val Loss: 0.2021, Val Macro F1: 0.3191

CNN Epoch 16/25


                                                                                               

Epoch 16/25, Time: 44.50s, Loss: 0.0132, Macro F1: 0.6521, Val Loss: 0.2159, Val Macro F1: 0.3147

CNN Epoch 17/25


                                                                                               

Epoch 17/25, Time: 43.81s, Loss: 0.0124, Macro F1: 0.6572, Val Loss: 0.2341, Val Macro F1: 0.3144

CNN Epoch 18/25


                                                                                               

Epoch 18/25, Time: 44.12s, Loss: 0.0113, Macro F1: 0.6580, Val Loss: 0.2340, Val Macro F1: 0.3235

CNN Epoch 19/25


                                                                                               

Epoch 19/25, Time: 44.15s, Loss: 0.0093, Macro F1: 0.6670, Val Loss: 0.2421, Val Macro F1: 0.3082

CNN Epoch 20/25


                                                                                               

Epoch 20/25, Time: 42.79s, Loss: 0.0087, Macro F1: 0.6674, Val Loss: 0.2393, Val Macro F1: 0.3167

CNN Epoch 21/25


                                                                                               

Epoch 21/25, Time: 42.50s, Loss: 0.0083, Macro F1: 0.6688, Val Loss: 0.2536, Val Macro F1: 0.3102

CNN Epoch 22/25


                                                                                               

Epoch 22/25, Time: 44.38s, Loss: 0.0081, Macro F1: 0.6702, Val Loss: 0.2625, Val Macro F1: 0.3043

CNN Epoch 23/25


                                                                                               

Epoch 23/25, Time: 44.74s, Loss: 0.0081, Macro F1: 0.6710, Val Loss: 0.2563, Val Macro F1: 0.3087

CNN Epoch 24/25


                                                                                               

Epoch 24/25, Time: 44.27s, Loss: 0.0079, Macro F1: 0.6726, Val Loss: 0.2486, Val Macro F1: 0.3270

CNN Epoch 25/25


                                                                                               

Epoch 25/25, Time: 43.60s, Loss: 0.0066, Macro F1: 0.6764, Val Loss: 0.2587, Val Macro F1: 0.3030




In [42]:
resnet_sub_model = modify_model_for_subdiagnostic_resnet(resnet_super_model, num_classes_sub)
exclude_params_resnet = [w.name for w in resnet_sub_model.layers[-1].trainable_weights]
si_resnet = SI(resnet_sub_model, exclude_params=exclude_params_resnet)

In [43]:
lambda_si = 1.0  # Adjust as needed
epochs = 25
batch_size = 64
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

train_macro_f1 = tf.keras.metrics.Mean(name='train_macro_f1')
train_loss = tf.keras.metrics.Mean(name='train_loss')
val_macro_f1 = tf.keras.metrics.Mean(name='val_macro_f1')
val_loss = tf.keras.metrics.Mean(name='val_loss')

for epoch in range(epochs):
    start_time = time.time()
    print(f'\nResNet Epoch {epoch+1}/{epochs}')
    train_macro_f1.reset_state()
    train_loss.reset_state()

    num_batches = len(X_train) // batch_size
    progress_bar = tqdm(range(num_batches), desc='Training', leave=False)

    for step in progress_bar:
        X_batch = X_train[step*batch_size:(step+1)*batch_size]
        y_batch = y_train_sub[step*batch_size:(step+1)*batch_size]

        with tf.GradientTape() as tape:
            preds = resnet_sub_model(X_batch, training=True)
            task_loss = tf.reduce_mean(tf.keras.losses.binary_crossentropy(y_batch, preds))
            si_penalty = si_resnet.penalty(resnet_sub_model)
            total_loss = task_loss + (lambda_si / 2) * si_penalty

        grads = tape.gradient(total_loss, resnet_sub_model.trainable_variables)
        optimizer.apply_gradients(zip(grads, resnet_sub_model.trainable_variables))
        si_resnet.accumulate_importance(resnet_sub_model, grads)

        batch_macro_f1 = macro_f1(y_batch, preds)
        train_macro_f1.update_state(batch_macro_f1)
        train_loss.update_state(total_loss)

        progress_bar.set_postfix({'loss': train_loss.result().numpy(), 'macro_f1': train_macro_f1.result().numpy()})

    epoch_time = time.time() - start_time

    # Validation
    val_macro_f1.reset_state()
    val_loss.reset_state()
    val_batches = len(X_val) // batch_size
    val_progress_bar = tqdm(range(val_batches), desc='Validation', leave=False)
    for step in val_progress_bar:
        X_batch = X_val[step*batch_size:(step+1)*batch_size]
        y_batch = y_val_sub[step*batch_size:(step+1)*batch_size]
        preds = resnet_sub_model(X_batch, training=False)
        task_loss = tf.reduce_mean(tf.keras.losses.binary_crossentropy(y_batch, preds))
        total_loss = task_loss

        batch_macro_f1 = macro_f1(y_batch, preds)
        val_macro_f1.update_state(batch_macro_f1)
        val_loss.update_state(total_loss)

        val_progress_bar.set_postfix({'val_loss': val_loss.result().numpy(), 'val_macro_f1': val_macro_f1.result().numpy()})

    print(f'Epoch {epoch+1}/{epochs}, '
          f'Time: {epoch_time:.2f}s, '
          f'Loss: {train_loss.result():.4f}, '
          f'Macro F1: {train_macro_f1.result():.4f}, '
          f'Val Loss: {val_loss.result():.4f}, '
          f'Val Macro F1: {val_macro_f1.result():.4f}')

# After training, update omega
si_resnet.update_omega()



ResNet Epoch 1/25


                                                                                               

Epoch 1/25, Time: 279.33s, Loss: 7193.5806, Macro F1: 0.2916, Val Loss: 0.1096, Val Macro F1: 0.2913

ResNet Epoch 2/25


                                                                                               

Epoch 2/25, Time: 274.20s, Loss: 13.3913, Macro F1: 0.3589, Val Loss: 0.1095, Val Macro F1: 0.3044

ResNet Epoch 3/25


                                                                                               

Epoch 3/25, Time: 273.91s, Loss: 0.0810, Macro F1: 0.3878, Val Loss: 0.1154, Val Macro F1: 0.2950

ResNet Epoch 4/25


                                                                                               

Epoch 4/25, Time: 283.23s, Loss: 0.0731, Macro F1: 0.4198, Val Loss: 0.1261, Val Macro F1: 0.2878

ResNet Epoch 5/25


                                                                                               

Epoch 5/25, Time: 277.55s, Loss: 0.0688, Macro F1: 0.4509, Val Loss: 0.1499, Val Macro F1: 0.2708

ResNet Epoch 6/25


                                                                                               

Epoch 6/25, Time: 275.62s, Loss: 0.0635, Macro F1: 0.4745, Val Loss: 0.1609, Val Macro F1: 0.2913

ResNet Epoch 7/25


                                                                                               

Epoch 7/25, Time: 278.39s, Loss: 0.0581, Macro F1: 0.4884, Val Loss: 0.1640, Val Macro F1: 0.3019

ResNet Epoch 8/25


                                                                                               

Epoch 8/25, Time: 281.06s, Loss: 0.0533, Macro F1: 0.5099, Val Loss: 0.1720, Val Macro F1: 0.2923

ResNet Epoch 9/25


                                                                                               

Epoch 9/25, Time: 278.81s, Loss: 0.0477, Macro F1: 0.5325, Val Loss: 0.1768, Val Macro F1: 0.2930

ResNet Epoch 10/25


                                                                                               

Epoch 10/25, Time: 278.85s, Loss: 0.0444, Macro F1: 0.5519, Val Loss: 0.1874, Val Macro F1: 0.2964

ResNet Epoch 11/25


                                                                                               

Epoch 11/25, Time: 278.03s, Loss: 0.0367, Macro F1: 0.5704, Val Loss: 0.2016, Val Macro F1: 0.3108

ResNet Epoch 12/25


                                                                                               

Epoch 12/25, Time: 279.32s, Loss: 0.0345, Macro F1: 0.5914, Val Loss: 0.2183, Val Macro F1: 0.3104

ResNet Epoch 13/25


                                                                                               

Epoch 13/25, Time: 277.97s, Loss: 0.0295, Macro F1: 0.6061, Val Loss: 0.2140, Val Macro F1: 0.3348

ResNet Epoch 14/25


                                                                                               

Epoch 14/25, Time: 281.30s, Loss: 0.0263, Macro F1: 0.6170, Val Loss: 0.2223, Val Macro F1: 0.3100

ResNet Epoch 15/25


                                                                                               

Epoch 15/25, Time: 281.58s, Loss: 0.0285, Macro F1: 0.6191, Val Loss: 0.2552, Val Macro F1: 0.3060

ResNet Epoch 16/25


                                                                                               

Epoch 16/25, Time: 282.07s, Loss: 0.0236, Macro F1: 0.6303, Val Loss: 0.2271, Val Macro F1: 0.3129

ResNet Epoch 17/25


                                                                                               

Epoch 17/25, Time: 280.43s, Loss: 0.0220, Macro F1: 0.6428, Val Loss: 0.2589, Val Macro F1: 0.3120

ResNet Epoch 18/25


                                                                                               

Epoch 18/25, Time: 274.95s, Loss: 0.0204, Macro F1: 0.6482, Val Loss: 0.2492, Val Macro F1: 0.3086

ResNet Epoch 19/25


                                                                                               

Epoch 19/25, Time: 282.40s, Loss: 0.0165, Macro F1: 0.6526, Val Loss: 0.2838, Val Macro F1: 0.2924

ResNet Epoch 20/25


                                                                                               

Epoch 20/25, Time: 281.15s, Loss: 0.0161, Macro F1: 0.6581, Val Loss: 0.2404, Val Macro F1: 0.3147

ResNet Epoch 21/25


                                                                                               

Epoch 21/25, Time: 278.10s, Loss: 0.0168, Macro F1: 0.6604, Val Loss: 0.2666, Val Macro F1: 0.2995

ResNet Epoch 22/25


                                                                                               

Epoch 22/25, Time: 283.67s, Loss: 0.0121, Macro F1: 0.6645, Val Loss: 0.2697, Val Macro F1: 0.3116

ResNet Epoch 23/25


                                                                                               

Epoch 23/25, Time: 284.99s, Loss: 0.0139, Macro F1: 0.6699, Val Loss: 0.2679, Val Macro F1: 0.3224

ResNet Epoch 24/25


                                                                                               

Epoch 24/25, Time: 283.47s, Loss: 0.0129, Macro F1: 0.6671, Val Loss: 0.2934, Val Macro F1: 0.3125

ResNet Epoch 25/25


                                                                                               

Epoch 25/25, Time: 282.18s, Loss: 0.0138, Macro F1: 0.6695, Val Loss: 0.2792, Val Macro F1: 0.3138




In [1]:
vit_sub_model = modify_model_for_subdiagnostic_vit(vit_super_model, num_classes_sub)
exclude_params_vit = [w.name for w in vit_sub_model.layers[-1].trainable_weights]
si_vit = SI(vit_sub_model, exclude_params=exclude_params_vit)


NameError: name 'modify_model_for_subdiagnostic_vit' is not defined

In [45]:
lambda_si = 1.0  # Adjust as needed
epochs = 25
batch_size = 64
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

train_macro_f1 = tf.keras.metrics.Mean(name='train_macro_f1')
train_loss = tf.keras.metrics.Mean(name='train_loss')
val_macro_f1 = tf.keras.metrics.Mean(name='val_macro_f1')
val_loss = tf.keras.metrics.Mean(name='val_loss')

for epoch in range(epochs):
    start_time = time.time()
    print(f'\nViT Epoch {epoch+1}/{epochs}')
    train_macro_f1.reset_state()
    train_loss.reset_state()

    num_batches = len(X_train) // batch_size
    progress_bar = tqdm(range(num_batches), desc='Training', leave=False)

    for step in progress_bar:
        X_batch = X_train[step*batch_size:(step+1)*batch_size]
        y_batch = y_train_sub[step*batch_size:(step+1)*batch_size]

        with tf.GradientTape() as tape:
            preds = vit_sub_model(X_batch, training=True)
            task_loss = tf.reduce_mean(tf.keras.losses.binary_crossentropy(y_batch, preds))
            si_penalty = si_vit.penalty(vit_sub_model)
            total_loss = task_loss + (lambda_si / 2) * si_penalty

        grads = tape.gradient(total_loss, vit_sub_model.trainable_variables)
        optimizer.apply_gradients(zip(grads, vit_sub_model.trainable_variables))
        si_vit.accumulate_importance(vit_sub_model, grads)

        batch_macro_f1 = macro_f1(y_batch, preds)
        train_macro_f1.update_state(batch_macro_f1)
        train_loss.update_state(total_loss)

        progress_bar.set_postfix({'loss': train_loss.result().numpy(), 'macro_f1': train_macro_f1.result().numpy()})

    epoch_time = time.time() - start_time

    # Validation
    val_macro_f1.reset_state()
    val_loss.reset_state()
    val_batches = len(X_val) // batch_size
    val_progress_bar = tqdm(range(val_batches), desc='Validation', leave=False)
    for step in val_progress_bar:
        X_batch = X_val[step*batch_size:(step+1)*batch_size]
        y_batch = y_val_sub[step*batch_size:(step+1)*batch_size]
        preds = vit_sub_model(X_batch, training=False)
        task_loss = tf.reduce_mean(tf.keras.losses.binary_crossentropy(y_batch, preds))
        total_loss = task_loss

        batch_macro_f1 = macro_f1(y_batch, preds)
        val_macro_f1.update_state(batch_macro_f1)
        val_loss.update_state(total_loss)

        val_progress_bar.set_postfix({'val_loss': val_loss.result().numpy(), 'val_macro_f1': val_macro_f1.result().numpy()})

    print(f'Epoch {epoch+1}/{epochs}, '
          f'Time: {epoch_time:.2f}s, '
          f'Loss: {train_loss.result():.4f}, '
          f'Macro F1: {train_macro_f1.result():.4f}, '
          f'Val Loss: {val_loss.result():.4f}, '
          f'Val Macro F1: {val_macro_f1.result():.4f}')

# After training, update omega
si_vit.update_omega()



ViT Epoch 1/25


                                                                                          

Epoch 1/25, Time: 302.39s, Loss: nan, Macro F1: 0.0103, Val Loss: nan, Val Macro F1: 0.0000

ViT Epoch 2/25


                                                                                         

Epoch 2/25, Time: 302.16s, Loss: nan, Macro F1: 0.0000, Val Loss: nan, Val Macro F1: 0.0000

ViT Epoch 3/25


                                                                                         

Epoch 3/25, Time: 299.35s, Loss: nan, Macro F1: 0.0000, Val Loss: nan, Val Macro F1: 0.0000

ViT Epoch 4/25


                                                                                         

Epoch 4/25, Time: 301.37s, Loss: nan, Macro F1: 0.0000, Val Loss: nan, Val Macro F1: 0.0000

ViT Epoch 5/25


                                                                                         

Epoch 5/25, Time: 302.89s, Loss: nan, Macro F1: 0.0000, Val Loss: nan, Val Macro F1: 0.0000

ViT Epoch 6/25


                                                                                         

Epoch 6/25, Time: 305.56s, Loss: nan, Macro F1: 0.0000, Val Loss: nan, Val Macro F1: 0.0000

ViT Epoch 7/25


                                                                                         

Epoch 7/25, Time: 304.95s, Loss: nan, Macro F1: 0.0000, Val Loss: nan, Val Macro F1: 0.0000

ViT Epoch 8/25


                                                                                         

Epoch 8/25, Time: 303.86s, Loss: nan, Macro F1: 0.0000, Val Loss: nan, Val Macro F1: 0.0000

ViT Epoch 9/25


                                                                                         

Epoch 9/25, Time: 301.60s, Loss: nan, Macro F1: 0.0000, Val Loss: nan, Val Macro F1: 0.0000

ViT Epoch 10/25


                                                                                         

Epoch 10/25, Time: 307.68s, Loss: nan, Macro F1: 0.0000, Val Loss: nan, Val Macro F1: 0.0000

ViT Epoch 11/25


                                                                                         

Epoch 11/25, Time: 304.45s, Loss: nan, Macro F1: 0.0000, Val Loss: nan, Val Macro F1: 0.0000

ViT Epoch 12/25


                                                                                         

Epoch 12/25, Time: 302.50s, Loss: nan, Macro F1: 0.0000, Val Loss: nan, Val Macro F1: 0.0000

ViT Epoch 13/25


                                                                                         

Epoch 13/25, Time: 299.03s, Loss: nan, Macro F1: 0.0000, Val Loss: nan, Val Macro F1: 0.0000

ViT Epoch 14/25


                                                                                         

Epoch 14/25, Time: 305.64s, Loss: nan, Macro F1: 0.0000, Val Loss: nan, Val Macro F1: 0.0000

ViT Epoch 15/25


                                                                                         

Epoch 15/25, Time: 306.61s, Loss: nan, Macro F1: 0.0000, Val Loss: nan, Val Macro F1: 0.0000

ViT Epoch 16/25


                                                                                         

Epoch 16/25, Time: 308.15s, Loss: nan, Macro F1: 0.0000, Val Loss: nan, Val Macro F1: 0.0000

ViT Epoch 17/25


                                                                                         

Epoch 17/25, Time: 302.83s, Loss: nan, Macro F1: 0.0000, Val Loss: nan, Val Macro F1: 0.0000

ViT Epoch 18/25


                                                                                         

Epoch 18/25, Time: 305.75s, Loss: nan, Macro F1: 0.0000, Val Loss: nan, Val Macro F1: 0.0000

ViT Epoch 19/25


                                                                                         

Epoch 19/25, Time: 301.05s, Loss: nan, Macro F1: 0.0000, Val Loss: nan, Val Macro F1: 0.0000

ViT Epoch 20/25


                                                                                         

Epoch 20/25, Time: 304.84s, Loss: nan, Macro F1: 0.0000, Val Loss: nan, Val Macro F1: 0.0000

ViT Epoch 21/25


                                                                                         

Epoch 21/25, Time: 299.17s, Loss: nan, Macro F1: 0.0000, Val Loss: nan, Val Macro F1: 0.0000

ViT Epoch 22/25


                                                                                         

Epoch 22/25, Time: 309.30s, Loss: nan, Macro F1: 0.0000, Val Loss: nan, Val Macro F1: 0.0000

ViT Epoch 23/25


                                                                                         

Epoch 23/25, Time: 308.20s, Loss: nan, Macro F1: 0.0000, Val Loss: nan, Val Macro F1: 0.0000

ViT Epoch 24/25


                                                                                         

Epoch 24/25, Time: 302.33s, Loss: nan, Macro F1: 0.0000, Val Loss: nan, Val Macro F1: 0.0000

ViT Epoch 25/25


                                                                                         

Epoch 25/25, Time: 306.48s, Loss: nan, Macro F1: 0.0000, Val Loss: nan, Val Macro F1: 0.0000




# Compiling and Publishing Results

In [46]:
print("CNN EWC Subdiagnostic Classification Report:")
cnn_ewc_sub_report = evaluate_model(cnn_sub_model, X_test, y_test_sub, classes_sub)

print("CNN EWC Superdiagnostic Classification Report:")
cnn_ewc_super_report = evaluate_model(cnn_super_model, X_test, y_test_super, classes_super)

print("ResNet EWC Subdiagnostic Classification Report:")
resnet_ewc_sub_report = evaluate_model(resnet_sub_model, X_test, y_test_sub, classes_sub)

print("ResNet EWC Superdiagnostic Classification Report:")
resnet_ewc_super_report = evaluate_model(resnet_super_model, X_test, y_test_super, classes_super)

print("ViT EWC Subdiagnostic Classification Report:")
vit_ewc_sub_report = evaluate_model(vit_sub_model, X_test, y_test_sub, classes_sub)

print("ViT EWC Superdiagnostic Classification Report:")
vit_ewc_super_report = evaluate_model(vit_super_model, X_test, y_test_super, classes_super)


CNN EWC Subdiagnostic Classification Report:
[1m68/68[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 9ms/step
              precision    recall  f1-score   support

         AMI       0.71      0.74      0.73       306
       CLBBB       0.93      0.78      0.85        54
       CRBBB       0.76      0.89      0.82        54
       ILBBB       0.25      0.12      0.17         8
         IMI       0.66      0.57      0.61       327
       IRBBB       0.72      0.44      0.54       112
        ISCA       0.44      0.31      0.36        93
        ISCI       0.67      0.10      0.17        40
        ISC_       0.67      0.55      0.61       128
        IVCD       0.19      0.10      0.13        79
   LAFB/LPFB       0.78      0.54      0.64       179
     LAO/LAE       0.29      0.05      0.08        42
         LMI       0.14      0.05      0.07        20
         LVH       0.72      0.51      0.60       214
        NORM       0.82      0.88      0.85       963
        NST_     

In [47]:
print("CNN SI Subdiagnostic Classification Report:")
cnn_si_sub_report = evaluate_model(cnn_sub_model, X_test, y_test_sub, classes_sub)

print("CNN SI Superdiagnostic Classification Report:")
cnn_si_super_report = evaluate_model(cnn_super_model, X_test, y_test_super, classes_super)

print("ResNet SI Subdiagnostic Classification Report:")
resnet_si_sub_report = evaluate_model(resnet_sub_model, X_test, y_test_sub, classes_sub)

print("ResNet SI Superdiagnostic Classification Report:")
resnet_si_super_report = evaluate_model(resnet_super_model, X_test, y_test_super, classes_super)

print("ViT SI Subdiagnostic Classification Report:")
vit_si_sub_report = evaluate_model(vit_sub_model, X_test, y_test_sub, classes_sub)

print("ViT SI Superdiagnostic Classification Report:")
vit_si_super_report = evaluate_model(vit_super_model, X_test, y_test_super, classes_super)


CNN SI Subdiagnostic Classification Report:
[1m68/68[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step
              precision    recall  f1-score   support

         AMI       0.71      0.74      0.73       306
       CLBBB       0.93      0.78      0.85        54
       CRBBB       0.76      0.89      0.82        54
       ILBBB       0.25      0.12      0.17         8
         IMI       0.66      0.57      0.61       327
       IRBBB       0.72      0.44      0.54       112
        ISCA       0.44      0.31      0.36        93
        ISCI       0.67      0.10      0.17        40
        ISC_       0.67      0.55      0.61       128
        IVCD       0.19      0.10      0.13        79
   LAFB/LPFB       0.78      0.54      0.64       179
     LAO/LAE       0.29      0.05      0.08        42
         LMI       0.14      0.05      0.07        20
         LVH       0.72      0.51      0.60       214
        NORM       0.82      0.88      0.85       963
        NST_      

In [48]:
def get_macro_f1(report_dict):
    return report_dict['macro avg']['f1-score']

results = {
    'Model': [],
    'Task': [],
    'Macro F1-score': []
}

results['Model'].extend(['CNN', 'ResNet', 'ViT'])
results['Task'].extend(['Superdiagnostic'] * 3)
results['Macro F1-score'].extend([
    get_macro_f1(cnn_super_report),
    get_macro_f1(resnet_super_report),
    get_macro_f1(vit_super_report)
])

results['Model'].extend(['CNN', 'ResNet', 'ViT'])
results['Task'].extend(['Subdiagnostic'] * 3)
results['Macro F1-score'].extend([
    get_macro_f1(cnn_sub_report),
    get_macro_f1(resnet_sub_report),
    get_macro_f1(vit_sub_report)
])

results['Model'].extend(['CNN', 'ResNet', 'ViT'])
results['Task'].extend(['EWC Subdiagnostic'] * 3)
results['Macro F1-score'].extend([
    get_macro_f1(cnn_ewc_sub_report),
    get_macro_f1(resnet_ewc_sub_report),
    get_macro_f1(vit_ewc_sub_report)
])

results['Model'].extend(['CNN', 'ResNet', 'ViT'])
results['Task'].extend(['EWC Superdiagnostic'] * 3)
results['Macro F1-score'].extend([
    get_macro_f1(cnn_ewc_super_report),
    get_macro_f1(resnet_ewc_super_report),
    get_macro_f1(vit_ewc_super_report)
])

results['Model'].extend(['CNN', 'ResNet', 'ViT'])
results['Task'].extend(['SI Subdiagnostic'] * 3)
results['Macro F1-score'].extend([
    get_macro_f1(cnn_si_sub_report),
    get_macro_f1(resnet_si_sub_report),
    get_macro_f1(vit_si_sub_report)
])

results['Model'].extend(['CNN', 'ResNet', 'ViT'])
results['Task'].extend(['SI Superdiagnostic'] * 3)
results['Macro F1-score'].extend([
    get_macro_f1(cnn_si_super_report),
    get_macro_f1(resnet_si_super_report),
    get_macro_f1(vit_si_super_report)
])

results_df = pd.DataFrame(results)
print("\nSummary of Classification Performance:")
print(results_df)



Summary of Classification Performance:
     Model                 Task  Macro F1-score
0      CNN      Superdiagnostic        0.734213
1   ResNet      Superdiagnostic        0.732228
2      ViT      Superdiagnostic        0.679261
3      CNN        Subdiagnostic        0.408998
4   ResNet        Subdiagnostic        0.392051
5      ViT        Subdiagnostic        0.275582
6      CNN    EWC Subdiagnostic        0.393380
7   ResNet    EWC Subdiagnostic        0.399722
8      ViT    EWC Subdiagnostic        0.000000
9      CNN  EWC Superdiagnostic        0.493559
10  ResNet  EWC Superdiagnostic        0.469332
11     ViT  EWC Superdiagnostic        0.000000
12     CNN     SI Subdiagnostic        0.393380
13  ResNet     SI Subdiagnostic        0.399722
14     ViT     SI Subdiagnostic        0.000000
15     CNN   SI Superdiagnostic        0.493559
16  ResNet   SI Superdiagnostic        0.469332
17     ViT   SI Superdiagnostic        0.000000
