# Importing requirements

In [1]:
import os
import numpy as np
import pandas as pd
import wfdb
import ast
import tensorflow as tf
from tensorflow.keras import layers, models
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import classification_report
from collections import Counter
import time
from tqdm import tqdm


2024-12-03 17:20:13.482069: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1733226613.502750  229141 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1733226613.509061  229141 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-12-03 17:20:13.532552: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
import warnings
warnings.filterwarnings('ignore', category=UserWarning)

# Loading Data from dataset file

In [3]:
DATA_PATH = 'ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3'

ptbxl_df = pd.read_csv(os.path.join(DATA_PATH, 'ptbxl_database.csv'))
scp_statements = pd.read_csv(os.path.join(DATA_PATH, 'scp_statements.csv'), index_col=0)

diagnostic_scps = scp_statements[scp_statements['diagnostic'] == 1].index.values

scp_to_superclass = scp_statements['diagnostic_class'].to_dict()
scp_to_subclass = scp_statements['diagnostic_subclass'].to_dict()

In [4]:
ptbxl_df['scp_codes'] = ptbxl_df['scp_codes'].apply(lambda x: ast.literal_eval(x))

In [5]:
def aggregate_diagnostic_labels(df, scp_codes, scp_to_agg):
    df = df.copy()
    def aggregate_labels(scp_codes_dict):
        labels = set()
        for code in scp_codes_dict.keys():
            if code in scp_codes:
                label = scp_to_agg.get(code)
                if label:
                    labels.add(label)
        return list(labels)
    df['diagnostic_labels'] = df['scp_codes'].apply(aggregate_labels)
    return df

ptbxl_df = aggregate_diagnostic_labels(ptbxl_df, diagnostic_scps, scp_to_superclass)
ptbxl_df = ptbxl_df.rename(columns={'diagnostic_labels': 'superclass_labels'})

ptbxl_df = aggregate_diagnostic_labels(ptbxl_df, diagnostic_scps, scp_to_subclass)
ptbxl_df = ptbxl_df.rename(columns={'diagnostic_labels': 'subclass_labels'})

In [6]:
ptbxl_df = ptbxl_df[ptbxl_df['superclass_labels'].map(len) > 0]

In [7]:
train_df = ptbxl_df[ptbxl_df.strat_fold <= 8]
val_df = ptbxl_df[ptbxl_df.strat_fold == 9]
test_df = ptbxl_df[ptbxl_df.strat_fold == 10]

In [8]:
def load_data(df, sampling_rate, data_path):
    data = []
    i = 0
    if sampling_rate == 100:
        filenames = df['filename_lr'].values
    else:
        filenames = df['filename_hr'].values
    for filename in filenames:
        file_path = os.path.join(data_path, filename)
        signals, _ = wfdb.rdsamp(file_path)
        data.append(signals)
    return np.array(data)

X_train = load_data(train_df, sampling_rate=100, data_path=DATA_PATH)
X_val = load_data(val_df, sampling_rate=100, data_path=DATA_PATH)
X_test = load_data(test_df, sampling_rate=100, data_path=DATA_PATH)

In [9]:
train_labels_super = train_df['superclass_labels'].values
val_labels_super = val_df['superclass_labels'].values
test_labels_super = test_df['superclass_labels'].values

mlb_super = MultiLabelBinarizer()
y_train_super = mlb_super.fit_transform(train_labels_super)
y_val_super = mlb_super.transform(val_labels_super)
y_test_super = mlb_super.transform(test_labels_super)
classes_super = mlb_super.classes_

In [10]:
train_labels_sub = train_df['subclass_labels'].values
val_labels_sub = val_df['subclass_labels'].values
test_labels_sub = test_df['subclass_labels'].values

mlb_sub = MultiLabelBinarizer()
y_train_sub = mlb_sub.fit_transform(train_labels_sub)
y_val_sub = mlb_sub.transform(val_labels_sub)
y_test_sub = mlb_sub.transform(test_labels_sub)
classes_sub = mlb_sub.classes_

In [11]:
def normalize_data_per_channel(X):
    X = np.transpose(X, (0, 2, 1))
    mean = np.mean(X, axis=(0, 2), keepdims=True)
    std = np.std(X, axis=(0, 2), keepdims=True)
    X = (X - mean) / std
    X = np.transpose(X, (0, 2, 1))
    return X

X_train = normalize_data_per_channel(X_train)
X_val = normalize_data_per_channel(X_val)
X_test = normalize_data_per_channel(X_test)

In [12]:
class_counts_super = np.sum(y_train_super, axis=0)
total_samples_super = y_train_super.shape[0]

class_weight_super = {}
for i, count in enumerate(class_counts_super):
    class_weight_super[i] = total_samples_super / (len(class_counts_super) * count)

class_counts_sub = np.sum(y_train_sub, axis=0)
total_samples_sub = y_train_sub.shape[0]

class_weight_sub = {}
for i, count in enumerate(class_counts_sub):
    class_weight_sub[i] = total_samples_sub / (len(class_counts_sub) * count)

In [13]:
num_classes_super = y_train_super.shape[1]
class_totals = np.sum(y_train_super, axis=0)
class_weights = class_totals.max() / class_totals
weights_array = np.array(class_weights, dtype=np.float32)

In [14]:
num_classes_sub = y_train_sub.shape[1]
class_totals_sub = np.sum(y_train_sub, axis=0)
class_weights_sub = class_totals_sub.max() / class_totals_sub
weights_array_sub = np.array(class_weights_sub, dtype=np.float32)

In [15]:
y_train_super = y_train_super.astype(np.float32)
y_val_super = y_val_super.astype(np.float32)
y_test_super = y_test_super.astype(np.float32)

# Defining Entropy and Metrics

In [16]:
import tensorflow.keras.backend as K

def weighted_binary_crossentropy(weights):
    def loss(y_true, y_pred):
        weights_cast = K.cast(weights, y_pred.dtype)
        y_true = K.cast(y_true, y_pred.dtype)
        
        bce = K.binary_crossentropy(y_true, y_pred)
        weight_vector = y_true * weights_cast + (1 - y_true)
        weighted_bce = weight_vector * bce
        return K.mean(weighted_bce)
    return loss

def macro_f1(y_true, y_pred):
    y_true = K.cast(y_true, 'float32')
    y_pred = K.cast(y_pred, 'float32')
    y_pred = K.round(y_pred)
    
    tp = K.sum(y_true * y_pred, axis=0)
    fp = K.sum((1 - y_true) * y_pred, axis=0)
    fn = K.sum(y_true * (1 - y_pred), axis=0)

    precision = tp / (tp + fp + K.epsilon())
    recall = tp / (tp + fn + K.epsilon())

    f1 = 2 * precision * recall / (precision + recall + K.epsilon())
    f1 = tf.where(tf.math.is_nan(f1), tf.zeros_like(f1), f1)
    return K.mean(f1)

def weighted_f1(y_true, y_pred):
    y_true = K.cast(y_true, 'float32')
    y_pred = K.cast(y_pred, 'float32')
    y_pred = K.round(y_pred)
    tp = K.sum(y_true * y_pred, axis=0)
    fp = K.sum((1 - y_true) * y_pred, axis=0)
    fn = K.sum(y_true * (1 - y_pred), axis=0)
    support = K.sum(y_true, axis=0)
    precision = tp / (tp + fp + K.epsilon())
    recall = tp / (tp + fn + K.epsilon())
    f1 = 2 * precision * recall / (precision + recall + K.epsilon())
    weighted_f1 = K.sum(f1 * support) / K.sum(support)
    weighted_f1 = tf.where(tf.math.is_nan(weighted_f1), 0.0, weighted_f1)
    
    return weighted_f1

# Defining Models

In [17]:
def create_cnn_model(input_shape, num_classes):
    inputs = layers.Input(shape=input_shape)

    x = layers.Conv1D(64, kernel_size=7, padding='same', activation='relu')(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.MaxPooling1D(pool_size=2)(x)
    
    x = layers.Conv1D(128, kernel_size=5, padding='same', activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.MaxPooling1D(pool_size=2)(x)
    
    x = layers.Conv1D(256, kernel_size=3, padding='same', activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.MaxPooling1D(pool_size=2)(x)
    
    x = layers.Conv1D(512, kernel_size=3, padding='same', activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.GlobalAveragePooling1D()(x)
    
    x = layers.Dense(1024, activation='relu')(x)
    x = layers.Dropout(0.1)(x)
    outputs = layers.Dense(num_classes, activation='sigmoid')(x)
    
    model = models.Model(inputs, outputs)
    return model


In [18]:
# def create_resnet_model(input_shape, num_classes):
#     inputs = layers.Input(shape=input_shape)
#     x = layers.Conv1D(64, kernel_size=7, strides=2, padding='same')(inputs)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     x = layers.MaxPooling1D(pool_size=3, strides=2, padding='same')(x)
    
#     previous_filters = x.shape[-1]
#     for filters in [64, 128, 256]:
#         x_shortcut = x
#         strides = 1
#         if previous_filters != filters:
#             strides = 2

#         x = layers.Conv1D(filters, kernel_size=3, strides=strides, padding='same')(x)
#         x = layers.BatchNormalization()(x)
#         x = layers.Activation('relu')(x)
#         x = layers.Conv1D(filters, kernel_size=3, padding='same')(x)
#         x = layers.BatchNormalization()(x)
        
#         if previous_filters != filters or strides != 1:
#             x_shortcut = layers.Conv1D(filters, kernel_size=1, strides=strides, padding='same')(x_shortcut)
#             x_shortcut = layers.BatchNormalization()(x_shortcut)
        
#         x = layers.Add()([x, x_shortcut])
#         x = layers.Activation('relu')(x)
#         previous_filters = filters
#     x = layers.GlobalAveragePooling1D()(x)
#     outputs = layers.Dense(num_classes, activation='sigmoid')(x)
#     model = models.Model(inputs, outputs)
#     return model

In [19]:
def residual_block_1d(x, filters, kernel_size=3, strides=1, downsample=False):
    shortcut = x
    
    x = layers.Conv1D(filters, kernel_size=kernel_size, strides=strides, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)

    x = layers.Conv1D(filters, kernel_size=kernel_size, padding='same')(x)
    x = layers.BatchNormalization()(x)
    
    if downsample or shortcut.shape[-1] != filters:
        shortcut = layers.Conv1D(filters, kernel_size=1, strides=strides, padding='same')(shortcut)
        shortcut = layers.BatchNormalization()(shortcut)

    x = layers.Add()([x, shortcut])
    x = layers.Activation('relu')(x)
    return x

def create_resnet_model(input_shape, num_classes):
    inputs = layers.Input(shape=input_shape)
    x = layers.Conv1D(64, kernel_size=7, strides=2, padding='same')(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.MaxPooling1D(pool_size=3, strides=2, padding='same')(x)
    layers_filters = [64, 128, 256, 512]
    layers_blocks = [3, 4, 6, 3]

    for filters, num_blocks in zip(layers_filters, layers_blocks):
        for i in range(num_blocks):
            if i == 0 and filters != x.shape[-1]:
                x = residual_block_1d(x, filters, strides=2, downsample=True)
            else:
                x = residual_block_1d(x, filters)

    x = layers.GlobalAveragePooling1D()(x)
    outputs = layers.Dense(num_classes, activation='sigmoid')(x)
    model = models.Model(inputs, outputs)
    return model

In [20]:
def mlp(x, hidden_units, dropout_rate):
    for units in hidden_units:
        x = layers.Dense(units, activation=tf.nn.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x

def create_vit_model(input_shape, num_classes):
    patch_size = 10 
    num_patches = input_shape[0] // patch_size
    projection_dim = 64
    num_heads = 4
    transformer_layers = 8
    mlp_head_units = [256, 128]
    dropout_rate = 0.1

    inputs = layers.Input(shape=input_shape)
    x = layers.Reshape((num_patches, patch_size * input_shape[1]))(inputs)
    x = layers.Dense(units=projection_dim)(x)
    positions = tf.range(start=0, limit=num_patches, delta=1)
    position_embedding = layers.Embedding(input_dim=num_patches, output_dim=projection_dim)
    x = x + position_embedding(positions)
    for _ in range(transformer_layers):
        x1 = layers.LayerNormalization(epsilon=1e-6)(x)
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=dropout_rate
        )(x1, x1)
        x2 = layers.Add()([attention_output, x])
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        x3 = mlp(x3, hidden_units=[projection_dim * 2, projection_dim], dropout_rate=dropout_rate)
        x = layers.Add()([x3, x2])
    x = layers.LayerNormalization(epsilon=1e-6)(x)
    x = layers.Flatten()(x)
    x = layers.Dropout(dropout_rate)(x)
    outputs = layers.Dense(num_classes, activation='sigmoid')(x)
    model = models.Model(inputs=inputs, outputs=outputs)
    return model

# Defining the training loop

In [21]:
def train_model(model, X_train, y_train, X_val, y_val, class_weight, batch_size=64, epochs=25):
    model.compile(
        optimizer='adam',
        loss='binary_crossentropy',
        metrics=['accuracy', macro_f1, weighted_f1]
    )
    callbacks = [
        tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5),
        tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
    ]
    history = model.fit(
        X_train, y_train,
        validation_data=(X_val, y_val),
        epochs=epochs,
        batch_size=batch_size,
        callbacks=callbacks,
        class_weight=class_weight
    )
    return history

# Training and Evaluating Models without CL

In [22]:
input_shape = X_train.shape[1:]
num_classes_super = y_train_super.shape[1]

cnn_super_model = create_cnn_model(input_shape, num_classes_super)
train_model(cnn_super_model, X_train, y_train_super, X_val, y_val_super, class_weight_super)

I0000 00:00:1733226681.640433  229141 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 31139 MB memory:  -> device: 0, name: Tesla V100-SXM2-32GB, pci bus id: 0000:06:00.0, compute capability: 7.0
I0000 00:00:1733226681.665186  229141 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 31139 MB memory:  -> device: 1, name: Tesla V100-SXM2-32GB, pci bus id: 0000:07:00.0, compute capability: 7.0
I0000 00:00:1733226681.666585  229141 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:2 with 31139 MB memory:  -> device: 2, name: Tesla V100-SXM2-32GB, pci bus id: 0000:0a:00.0, compute capability: 7.0
I0000 00:00:1733226681.667870  229141 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:3 with 31139 MB memory:  -> device: 3, name: Tesla V100-SXM2-32GB, pci bus id: 0000:0b:00.0, compute capability: 7.0
I0000 00:00:1733226681.669243  229141 gpu_device.cc:2022] Created de

Epoch 1/25


I0000 00:00:1733226688.318163  229912 service.cc:148] XLA service 0x7fbb70005b90 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1733226688.318211  229912 service.cc:156]   StreamExecutor device (0): Tesla V100-SXM2-32GB, Compute Capability 7.0
I0000 00:00:1733226688.318217  229912 service.cc:156]   StreamExecutor device (1): Tesla V100-SXM2-32GB, Compute Capability 7.0
I0000 00:00:1733226688.318222  229912 service.cc:156]   StreamExecutor device (2): Tesla V100-SXM2-32GB, Compute Capability 7.0
I0000 00:00:1733226688.318226  229912 service.cc:156]   StreamExecutor device (3): Tesla V100-SXM2-32GB, Compute Capability 7.0
I0000 00:00:1733226688.318230  229912 service.cc:156]   StreamExecutor device (4): Tesla V100-SXM2-32GB, Compute Capability 7.0
I0000 00:00:1733226688.318233  229912 service.cc:156]   StreamExecutor device (5): Tesla V100-SXM2-32GB, Compute Capability 7.0
I0000 00:00:1733226688.318238  229912 service.cc:156]   StreamE

[1m 10/267[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m4s[0m 18ms/step - accuracy: 0.3824 - loss: 0.4175 - macro_f1: 0.3947 - weighted_f1: 0.4377

I0000 00:00:1733226693.052912  229912 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m265/267[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 18ms/step - accuracy: 0.6017 - loss: 0.3053 - macro_f1: 0.6144 - weighted_f1: 0.6526

E0000 00:00:1733226698.917344  229914 gpu_timer.cc:82] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
E0000 00:00:1733226699.210659  229914 gpu_timer.cc:82] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.


[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m20s[0m 45ms/step - accuracy: 0.6022 - loss: 0.3050 - macro_f1: 0.6149 - weighted_f1: 0.6531 - val_accuracy: 0.6622 - val_loss: 0.3488 - val_macro_f1: 0.6394 - val_weighted_f1: 0.6855 - learning_rate: 0.0010
Epoch 2/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 20ms/step - accuracy: 0.6804 - loss: 0.2411 - macro_f1: 0.7122 - weighted_f1: 0.7431 - val_accuracy: 0.6962 - val_loss: 0.3128 - val_macro_f1: 0.6999 - val_weighted_f1: 0.7385 - learning_rate: 0.0010
Epoch 3/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 19ms/step - accuracy: 0.7018 - loss: 0.2242 - macro_f1: 0.7296 - weighted_f1: 0.7600 - val_accuracy: 0.6580 - val_loss: 0.3201 - val_macro_f1: 0.6989 - val_weighted_f1: 0.7397 - learning_rate: 0.0010
Epoch 4/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 19ms/step - accuracy: 0.7111 - loss: 0.2140 - macro_f1: 0.7433 - weighted_f1: 0.7741 - val_accuracy: 0

<keras.src.callbacks.history.History at 0x7fc02fe18b80>

In [23]:
resnet_super_model = create_resnet_model(input_shape, num_classes_super)
train_model(resnet_super_model, X_train, y_train_super, X_val, y_val_super, class_weight_super)

Epoch 1/25
[1m265/267[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 32ms/step - accuracy: 0.5555 - loss: 0.3932 - macro_f1: 0.5565 - weighted_f1: 0.5875

E0000 00:00:1733226865.964603  229914 gpu_timer.cc:82] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
E0000 00:00:1733226866.192433  229914 gpu_timer.cc:82] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.


[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m61s[0m 106ms/step - accuracy: 0.5563 - loss: 0.3922 - macro_f1: 0.5573 - weighted_f1: 0.5884 - val_accuracy: 0.5955 - val_loss: 1.0710 - val_macro_f1: 0.5333 - val_weighted_f1: 0.5829 - learning_rate: 0.0010
Epoch 2/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 34ms/step - accuracy: 0.6751 - loss: 0.2557 - macro_f1: 0.6908 - weighted_f1: 0.7241 - val_accuracy: 0.5144 - val_loss: 0.5218 - val_macro_f1: 0.5693 - val_weighted_f1: 0.6029 - learning_rate: 0.0010
Epoch 3/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 34ms/step - accuracy: 0.7004 - loss: 0.2381 - macro_f1: 0.7135 - weighted_f1: 0.7466 - val_accuracy: 0.6482 - val_loss: 0.4249 - val_macro_f1: 0.6234 - val_weighted_f1: 0.6697 - learning_rate: 0.0010
Epoch 4/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 35ms/step - accuracy: 0.7054 - loss: 0.2240 - macro_f1: 0.7302 - weighted_f1: 0.7649 - val_accuracy: 

<keras.src.callbacks.history.History at 0x7fbe7860a940>

In [24]:
vit_super_model = create_vit_model(input_shape, num_classes_super)
train_model(vit_super_model, X_train, y_train_super, X_val, y_val_super, class_weight_super)

Epoch 1/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m73s[0m 118ms/step - accuracy: 0.4420 - loss: 0.4268 - macro_f1: 0.4150 - weighted_f1: 0.4594 - val_accuracy: 0.5107 - val_loss: 0.4518 - val_macro_f1: 0.5193 - val_weighted_f1: 0.5729 - learning_rate: 0.0010
Epoch 2/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 24ms/step - accuracy: 0.6263 - loss: 0.2935 - macro_f1: 0.6247 - weighted_f1: 0.6632 - val_accuracy: 0.5564 - val_loss: 0.3822 - val_macro_f1: 0.6193 - val_weighted_f1: 0.6567 - learning_rate: 0.0010
Epoch 3/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 24ms/step - accuracy: 0.6677 - loss: 0.2576 - macro_f1: 0.6801 - weighted_f1: 0.7169 - val_accuracy: 0.6440 - val_loss: 0.3314 - val_macro_f1: 0.6637 - val_weighted_f1: 0.7059 - learning_rate: 0.0010
Epoch 4/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 26ms/step - accuracy: 0.6913 - loss: 0.2321 - macro_f1: 0.7136 - weighted_f1: 0.7476 - val

<keras.src.callbacks.history.History at 0x7fbd8c253e50>

In [25]:
num_classes_sub = y_train_sub.shape[1]
cnn_sub_model = create_cnn_model(input_shape, num_classes_sub)
train_model(cnn_sub_model, X_train, y_train_sub, X_val, y_val_sub, class_weight_sub)

Epoch 1/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 38ms/step - accuracy: 0.3823 - loss: 0.1309 - macro_f1: 0.1113 - weighted_f1: 0.2219 - val_accuracy: 0.3621 - val_loss: 0.1663 - val_macro_f1: 0.1498 - val_weighted_f1: 0.2451 - learning_rate: 0.0010
Epoch 2/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 18ms/step - accuracy: 0.4715 - loss: 0.0854 - macro_f1: 0.2041 - weighted_f1: 0.3693 - val_accuracy: 0.4627 - val_loss: 0.1446 - val_macro_f1: 0.2073 - val_weighted_f1: 0.4047 - learning_rate: 0.0010
Epoch 3/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 19ms/step - accuracy: 0.5154 - loss: 0.0739 - macro_f1: 0.2442 - weighted_f1: 0.4454 - val_accuracy: 0.2973 - val_loss: 0.1852 - val_macro_f1: 0.1557 - val_weighted_f1: 0.2704 - learning_rate: 0.0010
Epoch 4/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 19ms/step - accuracy: 0.5374 - loss: 0.0690 - macro_f1: 0.2696 - weighted_f1: 0.4851 - val_

<keras.src.callbacks.history.History at 0x7fbc5808e9a0>

In [26]:
resnet_sub_model = create_resnet_model(input_shape, num_classes_sub)
train_model(resnet_sub_model, X_train, y_train_sub, X_val, y_val_sub, class_weight_sub)

Epoch 1/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m56s[0m 91ms/step - accuracy: 0.2740 - loss: 0.1465 - macro_f1: 0.0484 - weighted_f1: 0.0839 - val_accuracy: 0.2050 - val_loss: 0.1911 - val_macro_f1: 0.0555 - val_weighted_f1: 0.0781 - learning_rate: 0.0010
Epoch 2/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 34ms/step - accuracy: 0.4464 - loss: 0.0970 - macro_f1: 0.1237 - weighted_f1: 0.2369 - val_accuracy: 0.2768 - val_loss: 0.2235 - val_macro_f1: 0.1136 - val_weighted_f1: 0.1117 - learning_rate: 0.0010
Epoch 3/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 34ms/step - accuracy: 0.5034 - loss: 0.0870 - macro_f1: 0.1644 - weighted_f1: 0.3132 - val_accuracy: 0.4581 - val_loss: 0.1786 - val_macro_f1: 0.1364 - val_weighted_f1: 0.1981 - learning_rate: 0.0010
Epoch 4/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 35ms/step - accuracy: 0.5026 - loss: 0.0811 - macro_f1: 0.1996 - weighted_f1: 0.3679 - val_

<keras.src.callbacks.history.History at 0x7fbc145fe640>

In [27]:
vit_sub_model = create_vit_model(input_shape, num_classes_sub)
train_model(vit_sub_model, X_train, y_train_sub, X_val, y_val_sub, class_weight_sub)

Epoch 1/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m71s[0m 115ms/step - accuracy: 0.1579 - loss: 0.1527 - macro_f1: 0.0517 - weighted_f1: 0.0853 - val_accuracy: 0.4245 - val_loss: 0.1650 - val_macro_f1: 0.0595 - val_weighted_f1: 0.2446 - learning_rate: 0.0010
Epoch 2/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 24ms/step - accuracy: 0.3711 - loss: 0.0892 - macro_f1: 0.1295 - weighted_f1: 0.2427 - val_accuracy: 0.4441 - val_loss: 0.1643 - val_macro_f1: 0.1521 - val_weighted_f1: 0.3687 - learning_rate: 0.0010
Epoch 3/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 24ms/step - accuracy: 0.4623 - loss: 0.0721 - macro_f1: 0.2234 - weighted_f1: 0.3966 - val_accuracy: 0.4399 - val_loss: 0.1532 - val_macro_f1: 0.1934 - val_weighted_f1: 0.4167 - learning_rate: 0.0010
Epoch 4/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 24ms/step - accuracy: 0.5217 - loss: 0.0584 - macro_f1: 0.2868 - weighted_f1: 0.4818 - val

<keras.src.callbacks.history.History at 0x7fbb7f5c9730>

In [28]:
def evaluate_model(model, X_test, y_test, classes):
    y_pred = model.predict(X_test)
    y_pred_threshold = (y_pred >= 0.5).astype(int)
    report = classification_report(y_test, y_pred_threshold, target_names=classes, zero_division=0, output_dict=True)
    print(classification_report(y_test, y_pred_threshold, target_names=classes, zero_division=0))
    return report


In [29]:
print("CNN Superdiagnostic Classification Report:")
cnn_super_report = evaluate_model(cnn_super_model, X_test, y_test_super, classes_super)

print("ResNet Superdiagnostic Classification Report:")
resnet_super_report = evaluate_model(resnet_super_model, X_test, y_test_super, classes_super)

print("ViT Superdiagnostic Classification Report:")
vit_super_report = evaluate_model(vit_super_model, X_test, y_test_super, classes_super)


CNN Superdiagnostic Classification Report:
[1m68/68[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 11ms/step
              precision    recall  f1-score   support

          CD       0.82      0.72      0.77       496
         HYP       0.69      0.52      0.59       262
          MI       0.82      0.69      0.75       550
        NORM       0.84      0.89      0.87       963
        STTC       0.76      0.75      0.75       521

   micro avg       0.81      0.76      0.78      2792
   macro avg       0.79      0.71      0.75      2792
weighted avg       0.80      0.76      0.78      2792
 samples avg       0.79      0.78      0.77      2792

ResNet Superdiagnostic Classification Report:
[1m68/68[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 45ms/step
              precision    recall  f1-score   support

          CD       0.81      0.69      0.75       496
         HYP       0.76      0.48      0.59       262
          MI       0.78      0.67      0.72       550
   

In [30]:
print("CNN Subdiagnostic Classification Report:")
cnn_sub_report = evaluate_model(cnn_sub_model, X_test, y_test_sub, classes_sub)

print("ResNet Subdiagnostic Classification Report:")
resnet_sub_report = evaluate_model(resnet_sub_model, X_test, y_test_sub, classes_sub)

print("ViT Subdiagnostic Classification Report:")
vit_sub_report = evaluate_model(vit_sub_model, X_test, y_test_sub, classes_sub)


CNN Subdiagnostic Classification Report:


[1m68/68[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 10ms/step
              precision    recall  f1-score   support

         AMI       0.91      0.43      0.59       306
       CLBBB       0.88      0.96      0.92        54
       CRBBB       0.80      0.91      0.85        54
       ILBBB       0.00      0.00      0.00         8
         IMI       0.76      0.47      0.58       327
       IRBBB       0.61      0.67      0.64       112
        ISCA       0.41      0.26      0.32        93
        ISCI       0.34      0.28      0.31        40
        ISC_       0.73      0.48      0.58       128
        IVCD       0.12      0.03      0.04        79
   LAFB/LPFB       0.66      0.82      0.73       179
     LAO/LAE       0.20      0.07      0.11        42
         LMI       0.10      0.05      0.07        20
         LVH       0.72      0.57      0.64       214
        NORM       0.87      0.77      0.82       963
        NST_       0.22      0.17      0.19        77
       

# Defining and Training on LwF

In [31]:
cnn_soft_targets_super = cnn_super_model.predict(X_train)

def lwf_loss(y_true, y_pred, old_predictions, T=2):
    task_loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
    dist_loss = tf.keras.losses.KLDivergence()(tf.nn.softmax(old_predictions / T),
                                               tf.nn.softmax(y_pred / T))
    total_loss = task_loss + dist_loss
    return total_loss

print("Working on CNN for LwF Now:")
cnn_model_lwf = create_cnn_model(input_shape, num_classes_sub)
cnn_model_lwf.compile(
    optimizer='adam',
    loss=lambda y_true, y_pred: lwf_loss(y_true, y_pred, old_predictions=cnn_soft_targets_super),
    metrics=[macro_f1, weighted_f1]
)
train_model(cnn_model_lwf, X_train, y_train_sub, X_val, y_val_sub, class_weight_sub)

print("Working on ResNet for LwF Now:")
resnet_soft_targets_super = resnet_super_model.predict(X_train)
resnet_model_lwf = create_resnet_model(input_shape, num_classes_sub)
resnet_model_lwf.compile(
    optimizer='adam',
    loss=lambda y_true, y_pred: lwf_loss(y_true, y_pred, old_predictions=resnet_soft_targets_super),
    metrics=[macro_f1, weighted_f1]
)
train_model(resnet_model_lwf, X_train, y_train_sub, X_val, y_val_sub, class_weight_sub)

print("Working on ViT for LwF Now:")
vit_soft_targets_super = vit_super_model.predict(X_train)
vit_model_lwf = create_vit_model(input_shape, num_classes_sub)
vit_model_lwf.compile(
    optimizer='adam',
    loss=lambda y_true, y_pred: lwf_loss(y_true, y_pred, old_predictions=vit_soft_targets_super),
    metrics=[macro_f1, weighted_f1]
)
train_model(vit_model_lwf, X_train, y_train_sub, X_val, y_val_sub, class_weight_sub)


[1m534/534[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 3ms/step
Working on CNN for LwF Now:
Epoch 1/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 37ms/step - accuracy: 0.3448 - loss: 0.1362 - macro_f1: 0.1074 - weighted_f1: 0.2047 - val_accuracy: 0.4301 - val_loss: 0.1600 - val_macro_f1: 0.1428 - val_weighted_f1: 0.3176 - learning_rate: 0.0010
Epoch 2/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 18ms/step - accuracy: 0.4804 - loss: 0.0869 - macro_f1: 0.2016 - weighted_f1: 0.3819 - val_accuracy: 0.4548 - val_loss: 0.1532 - val_macro_f1: 0.1726 - val_weighted_f1: 0.2904 - learning_rate: 0.0010
Epoch 3/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 18ms/step - accuracy: 0.5148 - loss: 0.0732 - macro_f1: 0.2403 - weighted_f1: 0.4333 - val_accuracy: 0.4804 - val_loss: 0.1404 - val_macro_f1: 0.2085 - val_weighted_f1: 0.3560 - learning_rate: 0.0010
Epoch 4/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m

<keras.src.callbacks.history.History at 0x7fbb7f13be50>

# Defining and Training on EwC

In [32]:
class EWC:
    def __init__(self, model, X, y, batch_size=32, exclude_params=[]):
        self.model = model
        self.params = {}
        for p in model.trainable_variables:
            if id(p) not in exclude_params:
                self.params[id(p)] = p.numpy()
        self.fisher = self.compute_fisher(X, y, batch_size, exclude_params)

    def compute_fisher(self, X, y, batch_size, exclude_params):
        fisher = {}
        num_samples = X.shape[0]
        num_batches = int(np.ceil(num_samples / batch_size))

        for batch_idx in range(num_batches):
            X_batch = X[batch_idx*batch_size:(batch_idx+1)*batch_size]
            y_batch = y[batch_idx*batch_size:(batch_idx+1)*batch_size]
            with tf.GradientTape() as tape:
                preds = self.model(X_batch)
                loss = tf.reduce_mean(tf.keras.losses.binary_crossentropy(y_batch, preds))
            grads = tape.gradient(loss, self.model.trainable_variables)
            for p, g in zip(self.model.trainable_variables, grads):
                if g is not None and id(p) not in exclude_params:
                    param_id = id(p)
                    if param_id not in fisher:
                        fisher[param_id] = np.square(g.numpy())
                    else:
                        fisher[param_id] += np.square(g.numpy())
        for k in fisher.keys():
            fisher[k] /= num_batches
        return fisher

    def penalty(self, model):
        loss = 0
        for p in model.trainable_variables:
            param_id = id(p)
            if param_id in self.fisher:
                fisher = tf.convert_to_tensor(self.fisher[param_id])
                loss += tf.reduce_sum(fisher * tf.square(p - self.params[param_id]))
        return loss

In [33]:
def modify_model_for_subdiagnostic(base_model, num_classes_sub):
    inputs = base_model.input
    x = inputs
    for layer in base_model.layers[1:-1]:
        x = layer(x)
    outputs = layers.Dense(num_classes_sub, activation='sigmoid', name='output_sub')(x)
    new_model = models.Model(inputs=inputs, outputs=outputs)
    return new_model

In [34]:
lambda_ewc = 1000
cnn_sub_model = modify_model_for_subdiagnostic(cnn_super_model, num_classes_sub)
exclude_params_cnn = [id(w) for w in cnn_sub_model.layers[-1].trainable_weights]
ewc_cnn = EWC(cnn_super_model, X_train, y_train_super, exclude_params=exclude_params_cnn)

def ewc_loss_cnn(y_true, y_pred):
    task_loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
    ewc_penalty = ewc_cnn.penalty(cnn_sub_model)
    total_loss = task_loss + (lambda_ewc / 2) * ewc_penalty
    return total_loss

cnn_sub_model.compile(
    optimizer='adam',
    loss=ewc_loss_cnn,
    metrics=[macro_f1, weighted_f1]
)

train_model(cnn_sub_model, X_train, y_train_sub, X_val, y_val_sub, class_weight_sub)

Epoch 1/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 37ms/step - accuracy: 0.5679 - loss: 0.1121 - macro_f1: 0.2046 - weighted_f1: 0.4335 - val_accuracy: 0.5746 - val_loss: 0.1171 - val_macro_f1: 0.2678 - val_weighted_f1: 0.5360 - learning_rate: 0.0010
Epoch 2/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 18ms/step - accuracy: 0.6148 - loss: 0.0629 - macro_f1: 0.3285 - weighted_f1: 0.6024 - val_accuracy: 0.5177 - val_loss: 0.1266 - val_macro_f1: 0.2672 - val_weighted_f1: 0.5130 - learning_rate: 0.0010
Epoch 3/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 18ms/step - accuracy: 0.6330 - loss: 0.0510 - macro_f1: 0.3561 - weighted_f1: 0.6332 - val_accuracy: 0.6109 - val_loss: 0.1137 - val_macro_f1: 0.2776 - val_weighted_f1: 0.5691 - learning_rate: 0.0010
Epoch 4/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 18ms/step - accuracy: 0.6496 - loss: 0.0440 - macro_f1: 0.3821 - weighted_f1: 0.6697 - val_

<keras.src.callbacks.history.History at 0x7fb2c4312f40>

In [35]:
def modify_model_for_subdiagnostic_resnet(base_model, num_classes_sub):
    x = base_model.layers[-2].output
    outputs = layers.Dense(num_classes_sub, activation='sigmoid', name='output_sub')(x)
    new_model = tf.keras.Model(inputs=base_model.input, outputs=outputs)
    return new_model

In [36]:
num_classes_sub = y_train_sub.shape[1]
resnet_sub_model = modify_model_for_subdiagnostic_resnet(resnet_super_model, num_classes_sub)
exclude_params_resnet = [w.name for w in resnet_sub_model.layers[-1].trainable_weights]
def ewc_loss_resnet(y_true, y_pred):
    task_loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
    ewc_penalty = ewc_resnet.penalty(resnet_sub_model)
    total_loss = task_loss + (lambda_ewc / 2) * ewc_penalty
    return total_loss

resnet_sub_model.compile(
    optimizer='adam',
    loss=ewc_loss_resnet,
    metrics=[macro_f1, weighted_f1]
)

train_model(resnet_sub_model, X_train, y_train_sub, X_val, y_val_sub, class_weight_sub)


Epoch 1/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m54s[0m 90ms/step - accuracy: 0.4394 - loss: 0.1169 - macro_f1: 0.1501 - weighted_f1: 0.3199 - val_accuracy: 0.3164 - val_loss: 0.1830 - val_macro_f1: 0.1394 - val_weighted_f1: 0.2305 - learning_rate: 0.0010
Epoch 2/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 34ms/step - accuracy: 0.5250 - loss: 0.0829 - macro_f1: 0.2184 - weighted_f1: 0.4229 - val_accuracy: 0.2968 - val_loss: 0.2285 - val_macro_f1: 0.1674 - val_weighted_f1: 0.3298 - learning_rate: 0.0010
Epoch 3/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 34ms/step - accuracy: 0.5465 - loss: 0.0694 - macro_f1: 0.2528 - weighted_f1: 0.4821 - val_accuracy: 0.4082 - val_loss: 0.1748 - val_macro_f1: 0.1672 - val_weighted_f1: 0.3434 - learning_rate: 0.0010
Epoch 4/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 34ms/step - accuracy: 0.5612 - loss: 0.0662 - macro_f1: 0.2729 - weighted_f1: 0.5182 - val_

<keras.src.callbacks.history.History at 0x7fbb77744400>

In [37]:
def modify_model_for_subdiagnostic_vit(base_model, num_classes_sub):
    # Get the output of the layer before the last (excluding the superdiagnostic output layer)
    x = base_model.layers[-2].output
    # Add new output layer for subdiagnostic task
    outputs = tf.keras.layers.Dense(num_classes_sub, activation='sigmoid', name='output_sub')(x)
    # Create new model
    new_model = tf.keras.Model(inputs=base_model.input, outputs=outputs)
    return new_model


In [38]:
# Modify ViT model for subdiagnostic task
num_classes_sub = y_train_sub.shape[1]
vit_sub_model = modify_model_for_subdiagnostic_vit(vit_super_model, num_classes_sub)

# Exclude the new output layer's parameters from EWC or SI calculations
exclude_params_vit = [w.name for w in vit_sub_model.layers[-1].trainable_weights]

def ewc_loss_vit(y_true, y_pred):
    task_loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
    ewc_penalty = ewc_vit.penalty(vit_sub_model)
    total_loss = task_loss + (lambda_ewc / 2) * ewc_penalty
    return total_loss

vit_sub_model.compile(
    optimizer='adam',
    loss=ewc_loss_vit,
    metrics=[macro_f1, weighted_f1]
)

train_model(vit_sub_model, X_train, y_train_sub, X_val, y_val_sub, class_weight_sub)

Epoch 1/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m72s[0m 113ms/step - accuracy: 0.4333 - loss: 0.1121 - macro_f1: 0.1332 - weighted_f1: 0.2994 - val_accuracy: 0.5103 - val_loss: 0.1520 - val_macro_f1: 0.1501 - val_weighted_f1: 0.3355 - learning_rate: 0.0010
Epoch 2/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 24ms/step - accuracy: 0.5326 - loss: 0.0738 - macro_f1: 0.2166 - weighted_f1: 0.4266 - val_accuracy: 0.5666 - val_loss: 0.1283 - val_macro_f1: 0.2160 - val_weighted_f1: 0.5024 - learning_rate: 0.0010
Epoch 3/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 24ms/step - accuracy: 0.5640 - loss: 0.0587 - macro_f1: 0.2874 - weighted_f1: 0.5220 - val_accuracy: 0.5363 - val_loss: 0.1327 - val_macro_f1: 0.2443 - val_weighted_f1: 0.5239 - learning_rate: 0.0010
Epoch 4/25
[1m267/267[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 24ms/step - accuracy: 0.5951 - loss: 0.0485 - macro_f1: 0.3420 - weighted_f1: 0.5785 - val

<keras.src.callbacks.history.History at 0x7fb2547aa670>

# Defining and Training on SI

In [39]:
class SI:
    def __init__(self, prev_model, damping_factor=0.1, exclude_params=[]):
        self.prev_params = {}
        self.omega = {}
        self.damping_factor = damping_factor
        self.exclude_params = exclude_params

        self.delta_params = {}

        # Store parameters from the previous model (superdiagnostic task)
        for var in prev_model.trainable_variables:
            if var.name not in self.exclude_params:
                self.prev_params[var.name] = var.numpy().copy()
                self.omega[var.name] = np.zeros_like(var.numpy())
                self.delta_params[var.name] = np.zeros_like(var.numpy())

    def accumulate_importance(self, model, grads):
        for var, grad in zip(model.trainable_variables, grads):
            if grad is not None and var.name in self.prev_params:
                if var.shape == self.prev_params[var.name].shape:
                    delta_theta = var.numpy() - self.prev_params[var.name]
                    self.delta_params[var.name] += delta_theta
                    # Update omega with absolute value to prevent negative importance
                    self.omega[var.name] += np.abs(grad.numpy() * delta_theta)
                else:
                    # Skip variables with mismatched shapes
                    pass

    def update_omega(self):
        # Normalize omega after training
        epsilon = 1e-8  # Small value to prevent division by zero
        for var_name in self.omega.keys():
            delta_param = self.delta_params[var_name]
            denom = np.square(delta_param) + self.damping_factor + epsilon
            self.omega[var_name] = np.divide(self.omega[var_name], denom)
            # Ensure omega is non-negative
            self.omega[var_name] = np.abs(self.omega[var_name])
            # Reset delta_params for the next task
            self.delta_params[var_name] = np.zeros_like(delta_param)

    def penalty(self, model):
        loss = 0
        for var in model.trainable_variables:
            if var.name in self.prev_params:
                prev_param = self.prev_params[var.name]
                if var.shape == prev_param.shape:
                    omega = tf.convert_to_tensor(self.omega[var.name], dtype=var.dtype)
                    prev_param = tf.convert_to_tensor(prev_param, dtype=var.dtype)
                    # Ensure omega is non-negative
                    loss += tf.reduce_sum(omega * tf.square(var - prev_param))
                else:
                    # Skip variables with mismatched shapes
                    pass
        return loss


In [40]:
num_classes_sub = y_train_sub.shape[1]
cnn_sub_model = modify_model_for_subdiagnostic(cnn_super_model, num_classes_sub)
exclude_params_cnn = [w.name for w in cnn_sub_model.layers[-1].trainable_weights]
si_cnn = SI(cnn_super_model, exclude_params=exclude_params_cnn)


In [41]:
lambda_si = 1.0  # Adjust as needed
epochs = 25
batch_size = 64
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

train_macro_f1 = tf.keras.metrics.Mean(name='train_macro_f1')
train_loss = tf.keras.metrics.Mean(name='train_loss')
val_macro_f1 = tf.keras.metrics.Mean(name='val_macro_f1')
val_loss = tf.keras.metrics.Mean(name='val_loss')

for epoch in range(epochs):
    start_time = time.time()
    print(f'\nCNN Epoch {epoch+1}/{epochs}')
    train_macro_f1.reset_state()
    train_loss.reset_state()

    num_batches = len(X_train) // batch_size
    progress_bar = tqdm(range(num_batches), desc='Training', leave=False)

    for step in progress_bar:
        X_batch = X_train[step*batch_size:(step+1)*batch_size]
        y_batch = y_train_sub[step*batch_size:(step+1)*batch_size]

        with tf.GradientTape() as tape:
            preds = cnn_sub_model(X_batch, training=True)
            task_loss = tf.reduce_mean(tf.keras.losses.binary_crossentropy(y_batch, preds))
            si_penalty = si_cnn.penalty(cnn_sub_model)
            total_loss = task_loss + (lambda_si / 2) * si_penalty

        grads = tape.gradient(total_loss, cnn_sub_model.trainable_variables)
        optimizer.apply_gradients(zip(grads, cnn_sub_model.trainable_variables))
        si_cnn.accumulate_importance(cnn_sub_model, grads)

        batch_macro_f1 = macro_f1(y_batch, preds)
        train_macro_f1.update_state(batch_macro_f1)
        train_loss.update_state(total_loss)

        progress_bar.set_postfix({'loss': train_loss.result().numpy(), 'macro_f1': train_macro_f1.result().numpy()})

    epoch_time = time.time() - start_time

    # Validation
    val_macro_f1.reset_state()
    val_loss.reset_state()
    val_batches = len(X_val) // batch_size
    val_progress_bar = tqdm(range(val_batches), desc='Validation', leave=False)
    for step in val_progress_bar:
        X_batch = X_val[step*batch_size:(step+1)*batch_size]
        y_batch = y_val_sub[step*batch_size:(step+1)*batch_size]
        preds = cnn_sub_model(X_batch, training=False)
        task_loss = tf.reduce_mean(tf.keras.losses.binary_crossentropy(y_batch, preds))
        total_loss = task_loss

        batch_macro_f1 = macro_f1(y_batch, preds)
        val_macro_f1.update_state(batch_macro_f1)
        val_loss.update_state(total_loss)

        val_progress_bar.set_postfix({'val_loss': val_loss.result().numpy(), 'val_macro_f1': val_macro_f1.result().numpy()})

    print(f'Epoch {epoch+1}/{epochs}, '
          f'Time: {epoch_time:.2f}s, '
          f'Loss: {train_loss.result():.4f}, '
          f'Macro F1: {train_macro_f1.result():.4f}, '
          f'Val Loss: {val_loss.result():.4f}, '
          f'Val Macro F1: {val_macro_f1.result():.4f}')

# After training, update omega
si_cnn.update_omega()



CNN Epoch 1/25


Training:   0%|          | 0/266 [00:00<?, ?it/s]

                                                                                               

Epoch 1/25, Time: 45.29s, Loss: 0.1025, Macro F1: 0.3055, Val Loss: 0.1024, Val Macro F1: 0.3124

CNN Epoch 2/25


                                                                                               

Epoch 2/25, Time: 44.16s, Loss: 0.0770, Macro F1: 0.3730, Val Loss: 0.1024, Val Macro F1: 0.3233

CNN Epoch 3/25


                                                                                               

Epoch 3/25, Time: 43.35s, Loss: 0.0686, Macro F1: 0.4091, Val Loss: 0.1055, Val Macro F1: 0.3262

CNN Epoch 4/25


                                                                                               

Epoch 4/25, Time: 43.37s, Loss: 0.0605, Macro F1: 0.4425, Val Loss: 0.1109, Val Macro F1: 0.3247

CNN Epoch 5/25


                                                                                               

Epoch 5/25, Time: 42.83s, Loss: 0.0524, Macro F1: 0.4803, Val Loss: 0.1174, Val Macro F1: 0.3372

CNN Epoch 6/25


                                                                                               

Epoch 6/25, Time: 42.62s, Loss: 0.0442, Macro F1: 0.5216, Val Loss: 0.1284, Val Macro F1: 0.3261

CNN Epoch 7/25


                                                                                               

Epoch 7/25, Time: 44.37s, Loss: 0.0372, Macro F1: 0.5504, Val Loss: 0.1409, Val Macro F1: 0.3155

CNN Epoch 8/25


                                                                                               

Epoch 8/25, Time: 43.53s, Loss: 0.0311, Macro F1: 0.5802, Val Loss: 0.1576, Val Macro F1: 0.3137

CNN Epoch 9/25


                                                                                               

Epoch 9/25, Time: 44.57s, Loss: 0.0265, Macro F1: 0.5965, Val Loss: 0.1719, Val Macro F1: 0.3315

CNN Epoch 10/25


                                                                                               

Epoch 10/25, Time: 44.70s, Loss: 0.0244, Macro F1: 0.6056, Val Loss: 0.1836, Val Macro F1: 0.3153

CNN Epoch 11/25


                                                                                               

Epoch 11/25, Time: 44.02s, Loss: 0.0225, Macro F1: 0.6163, Val Loss: 0.1845, Val Macro F1: 0.3157

CNN Epoch 12/25


                                                                                               

Epoch 12/25, Time: 44.42s, Loss: 0.0186, Macro F1: 0.6331, Val Loss: 0.2024, Val Macro F1: 0.3231

CNN Epoch 13/25


                                                                                               

Epoch 13/25, Time: 42.91s, Loss: 0.0158, Macro F1: 0.6415, Val Loss: 0.2137, Val Macro F1: 0.3103

CNN Epoch 14/25


                                                                                               

Epoch 14/25, Time: 42.68s, Loss: 0.0134, Macro F1: 0.6534, Val Loss: 0.2315, Val Macro F1: 0.3162

CNN Epoch 15/25


                                                                                               

Epoch 15/25, Time: 42.31s, Loss: 0.0119, Macro F1: 0.6575, Val Loss: 0.2281, Val Macro F1: 0.3393

CNN Epoch 16/25


                                                                                               

Epoch 16/25, Time: 44.08s, Loss: 0.0112, Macro F1: 0.6612, Val Loss: 0.2283, Val Macro F1: 0.3305

CNN Epoch 17/25


                                                                                               

Epoch 17/25, Time: 44.49s, Loss: 0.0100, Macro F1: 0.6659, Val Loss: 0.2254, Val Macro F1: 0.3274

CNN Epoch 18/25


                                                                                               

Epoch 18/25, Time: 43.18s, Loss: 0.0086, Macro F1: 0.6672, Val Loss: 0.2291, Val Macro F1: 0.3200

CNN Epoch 19/25


                                                                                               

Epoch 19/25, Time: 44.13s, Loss: 0.0076, Macro F1: 0.6717, Val Loss: 0.2433, Val Macro F1: 0.3177

CNN Epoch 20/25


                                                                                               

Epoch 20/25, Time: 44.14s, Loss: 0.0074, Macro F1: 0.6714, Val Loss: 0.2498, Val Macro F1: 0.3243

CNN Epoch 21/25


                                                                                               

Epoch 21/25, Time: 44.65s, Loss: 0.0077, Macro F1: 0.6727, Val Loss: 0.2558, Val Macro F1: 0.3141

CNN Epoch 22/25


                                                                                               

Epoch 22/25, Time: 43.41s, Loss: 0.0076, Macro F1: 0.6695, Val Loss: 0.2673, Val Macro F1: 0.3185

CNN Epoch 23/25


                                                                                               

Epoch 23/25, Time: 43.88s, Loss: 0.0067, Macro F1: 0.6720, Val Loss: 0.2693, Val Macro F1: 0.3285

CNN Epoch 24/25


                                                                                               

Epoch 24/25, Time: 43.85s, Loss: 0.0060, Macro F1: 0.6758, Val Loss: 0.2712, Val Macro F1: 0.3235

CNN Epoch 25/25


                                                                                               

Epoch 25/25, Time: 43.76s, Loss: 0.0052, Macro F1: 0.6784, Val Loss: 0.2743, Val Macro F1: 0.3125




In [42]:
resnet_sub_model = modify_model_for_subdiagnostic_resnet(resnet_super_model, num_classes_sub)
exclude_params_resnet = [w.name for w in resnet_sub_model.layers[-1].trainable_weights]
si_resnet = SI(resnet_sub_model, exclude_params=exclude_params_resnet)

In [43]:
lambda_si = 1.0  # Adjust as needed
epochs = 25
batch_size = 64
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

train_macro_f1 = tf.keras.metrics.Mean(name='train_macro_f1')
train_loss = tf.keras.metrics.Mean(name='train_loss')
val_macro_f1 = tf.keras.metrics.Mean(name='val_macro_f1')
val_loss = tf.keras.metrics.Mean(name='val_loss')

for epoch in range(epochs):
    start_time = time.time()
    print(f'\nResNet Epoch {epoch+1}/{epochs}')
    train_macro_f1.reset_state()
    train_loss.reset_state()

    num_batches = len(X_train) // batch_size
    progress_bar = tqdm(range(num_batches), desc='Training', leave=False)

    for step in progress_bar:
        X_batch = X_train[step*batch_size:(step+1)*batch_size]
        y_batch = y_train_sub[step*batch_size:(step+1)*batch_size]

        with tf.GradientTape() as tape:
            preds = resnet_sub_model(X_batch, training=True)
            task_loss = tf.reduce_mean(tf.keras.losses.binary_crossentropy(y_batch, preds))
            si_penalty = si_resnet.penalty(resnet_sub_model)
            total_loss = task_loss + (lambda_si / 2) * si_penalty

        grads = tape.gradient(total_loss, resnet_sub_model.trainable_variables)
        optimizer.apply_gradients(zip(grads, resnet_sub_model.trainable_variables))
        si_resnet.accumulate_importance(resnet_sub_model, grads)

        batch_macro_f1 = macro_f1(y_batch, preds)
        train_macro_f1.update_state(batch_macro_f1)
        train_loss.update_state(total_loss)

        progress_bar.set_postfix({'loss': train_loss.result().numpy(), 'macro_f1': train_macro_f1.result().numpy()})

    epoch_time = time.time() - start_time

    # Validation
    val_macro_f1.reset_state()
    val_loss.reset_state()
    val_batches = len(X_val) // batch_size
    val_progress_bar = tqdm(range(val_batches), desc='Validation', leave=False)
    for step in val_progress_bar:
        X_batch = X_val[step*batch_size:(step+1)*batch_size]
        y_batch = y_val_sub[step*batch_size:(step+1)*batch_size]
        preds = resnet_sub_model(X_batch, training=False)
        task_loss = tf.reduce_mean(tf.keras.losses.binary_crossentropy(y_batch, preds))
        total_loss = task_loss

        batch_macro_f1 = macro_f1(y_batch, preds)
        val_macro_f1.update_state(batch_macro_f1)
        val_loss.update_state(total_loss)

        val_progress_bar.set_postfix({'val_loss': val_loss.result().numpy(), 'val_macro_f1': val_macro_f1.result().numpy()})

    print(f'Epoch {epoch+1}/{epochs}, '
          f'Time: {epoch_time:.2f}s, '
          f'Loss: {train_loss.result():.4f}, '
          f'Macro F1: {train_macro_f1.result():.4f}, '
          f'Val Loss: {val_loss.result():.4f}, '
          f'Val Macro F1: {val_macro_f1.result():.4f}')

# After training, update omega
si_resnet.update_omega()



ResNet Epoch 1/25


Training:   0%|          | 0/266 [00:00<?, ?it/s]

                                                                                               

Epoch 1/25, Time: 283.25s, Loss: 4.3105, Macro F1: 0.2858, Val Loss: 0.1093, Val Macro F1: 0.2856

ResNet Epoch 2/25


                                                                                               

Epoch 2/25, Time: 278.60s, Loss: 0.0813, Macro F1: 0.3610, Val Loss: 0.1118, Val Macro F1: 0.2841

ResNet Epoch 3/25


                                                                                               

Epoch 3/25, Time: 270.54s, Loss: 0.0739, Macro F1: 0.3958, Val Loss: 0.1174, Val Macro F1: 0.2779

ResNet Epoch 4/25


                                                                                               

Epoch 4/25, Time: 273.52s, Loss: 0.0670, Macro F1: 0.4308, Val Loss: 0.1266, Val Macro F1: 0.2795

ResNet Epoch 5/25


                                                                                               

Epoch 5/25, Time: 279.14s, Loss: 0.0598, Macro F1: 0.4684, Val Loss: 0.1411, Val Macro F1: 0.2667

ResNet Epoch 6/25


                                                                                               

Epoch 6/25, Time: 278.15s, Loss: 0.0522, Macro F1: 0.4969, Val Loss: 0.1623, Val Macro F1: 0.2815

ResNet Epoch 7/25


                                                                                               

Epoch 7/25, Time: 275.27s, Loss: 0.0481, Macro F1: 0.5152, Val Loss: 0.1821, Val Macro F1: 0.2813

ResNet Epoch 8/25


                                                                                               

Epoch 8/25, Time: 278.12s, Loss: 0.0450, Macro F1: 0.5254, Val Loss: 0.1820, Val Macro F1: 0.2791

ResNet Epoch 9/25


                                                                                               

Epoch 9/25, Time: 281.70s, Loss: 0.0396, Macro F1: 0.5518, Val Loss: 0.1871, Val Macro F1: 0.2870

ResNet Epoch 10/25


                                                                                               

Epoch 10/25, Time: 282.47s, Loss: 0.0361, Macro F1: 0.5648, Val Loss: 0.1907, Val Macro F1: 0.2967

ResNet Epoch 11/25


                                                                                               

Epoch 11/25, Time: 274.01s, Loss: 0.0314, Macro F1: 0.5818, Val Loss: 0.1931, Val Macro F1: 0.2920

ResNet Epoch 12/25


                                                                                               

Epoch 12/25, Time: 273.64s, Loss: 0.0272, Macro F1: 0.5979, Val Loss: 0.2154, Val Macro F1: 0.2910

ResNet Epoch 13/25


                                                                                               

Epoch 13/25, Time: 276.99s, Loss: 0.0243, Macro F1: 0.6129, Val Loss: 0.2171, Val Macro F1: 0.2870

ResNet Epoch 14/25


                                                                                               

Epoch 14/25, Time: 282.88s, Loss: 0.0208, Macro F1: 0.6255, Val Loss: 0.2296, Val Macro F1: 0.2685

ResNet Epoch 15/25


                                                                                               

Epoch 15/25, Time: 283.31s, Loss: 0.0178, Macro F1: 0.6342, Val Loss: 0.2359, Val Macro F1: 0.2783

ResNet Epoch 16/25


                                                                                               

Epoch 16/25, Time: 284.08s, Loss: 0.0157, Macro F1: 0.6394, Val Loss: 0.2501, Val Macro F1: 0.2836

ResNet Epoch 17/25


                                                                                               

Epoch 17/25, Time: 282.36s, Loss: 0.0142, Macro F1: 0.6500, Val Loss: 0.2525, Val Macro F1: 0.2992

ResNet Epoch 18/25


                                                                                               

Epoch 18/25, Time: 282.05s, Loss: 0.0122, Macro F1: 0.6554, Val Loss: 0.2525, Val Macro F1: 0.2999

ResNet Epoch 19/25


                                                                                               

Epoch 19/25, Time: 276.58s, Loss: 0.0119, Macro F1: 0.6537, Val Loss: 0.2505, Val Macro F1: 0.3067

ResNet Epoch 20/25


                                                                                               

Epoch 20/25, Time: 276.24s, Loss: 0.0103, Macro F1: 0.6607, Val Loss: 0.2809, Val Macro F1: 0.3029

ResNet Epoch 21/25


                                                                                               

Epoch 21/25, Time: 279.39s, Loss: 0.0094, Macro F1: 0.6638, Val Loss: 0.2875, Val Macro F1: 0.3005

ResNet Epoch 22/25


                                                                                               

Epoch 22/25, Time: 280.00s, Loss: 0.0090, Macro F1: 0.6645, Val Loss: 0.2935, Val Macro F1: 0.3106

ResNet Epoch 23/25


                                                                                               

Epoch 23/25, Time: 280.90s, Loss: 0.0083, Macro F1: 0.6672, Val Loss: 0.2733, Val Macro F1: 0.3048

ResNet Epoch 24/25


                                                                                               

Epoch 24/25, Time: 280.47s, Loss: 0.0075, Macro F1: 0.6715, Val Loss: 0.3031, Val Macro F1: 0.2989

ResNet Epoch 25/25


                                                                                               

Epoch 25/25, Time: 283.19s, Loss: 0.0072, Macro F1: 0.6721, Val Loss: 0.3232, Val Macro F1: 0.2863




In [44]:
vit_sub_model = modify_model_for_subdiagnostic_vit(vit_super_model, num_classes_sub)
exclude_params_vit = [w.name for w in vit_sub_model.layers[-1].trainable_weights]
si_vit = SI(vit_super_model, exclude_params=exclude_params_vit)

In [45]:
lambda_si = 1
epochs = 25
batch_size = 64
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

train_macro_f1 = tf.keras.metrics.Mean(name='train_macro_f1')
train_loss = tf.keras.metrics.Mean(name='train_loss')
val_macro_f1 = tf.keras.metrics.Mean(name='val_macro_f1')
val_loss = tf.keras.metrics.Mean(name='val_loss')

for epoch in range(epochs):
    start_time = time.time()
    print(f'\nViT Epoch {epoch+1}/{epochs}')
    train_macro_f1.reset_state()
    train_loss.reset_state()

    num_batches = len(X_train) // batch_size
    progress_bar = tqdm(range(num_batches), desc='Training', leave=False)

    for step in progress_bar:
        X_batch = X_train[step*batch_size:(step+1)*batch_size]
        y_batch = y_train_sub[step*batch_size:(step+1)*batch_size]

        with tf.GradientTape() as tape:
            preds = vit_sub_model(X_batch, training=True)
            task_loss = tf.reduce_mean(tf.keras.losses.binary_crossentropy(y_batch, preds))
            si_penalty = si_vit.penalty(vit_sub_model)
            total_loss = task_loss + (lambda_si / 2) * si_penalty

        # Check for NaN in total_loss
        if tf.math.is_nan(total_loss):
            print(f"NaN detected in total_loss at epoch {epoch+1}, step {step+1}")
            break

        grads = tape.gradient(total_loss, vit_sub_model.trainable_variables)
        # Clip gradients to prevent exploding gradients
        grads = [tf.clip_by_norm(g, 1.0) if g is not None else None for g in grads]

        # Check for NaN in gradients
        if any([tf.reduce_any(tf.math.is_nan(g)) for g in grads if g is not None]):
            print(f"NaN detected in gradients at epoch {epoch+1}, step {step+1}")
            break

        optimizer.apply_gradients(zip(grads, vit_sub_model.trainable_variables))
        si_vit.accumulate_importance(vit_sub_model, grads)

        batch_macro_f1 = macro_f1(y_batch, preds)
        train_macro_f1.update_state(batch_macro_f1)
        train_loss.update_state(total_loss)

        progress_bar.set_postfix({'loss': train_loss.result().numpy(), 'macro_f1': train_macro_f1.result().numpy()})

    epoch_time = time.time() - start_time

    # Validation
    val_macro_f1.reset_state()
    val_loss.reset_state()
    val_batches = len(X_val) // batch_size
    val_progress_bar = tqdm(range(val_batches), desc='Validation', leave=False)
    for step in val_progress_bar:
        X_batch = X_val[step*batch_size:(step+1)*batch_size]
        y_batch = y_val_sub[step*batch_size:(step+1)*batch_size]
        preds = vit_sub_model(X_batch, training=False)
        task_loss = tf.reduce_mean(tf.keras.losses.binary_crossentropy(y_batch, preds))
        total_loss = task_loss

        batch_macro_f1 = macro_f1(y_batch, preds)
        val_macro_f1.update_state(batch_macro_f1)
        val_loss.update_state(total_loss)

        val_progress_bar.set_postfix({'val_loss': val_loss.result().numpy(), 'val_macro_f1': val_macro_f1.result().numpy()})

    print(f'Epoch {epoch+1}/{epochs}, '
          f'Time: {epoch_time:.2f}s, '
          f'Loss: {train_loss.result():.4f}, '
          f'Macro F1: {train_macro_f1.result():.4f}, '
          f'Val Loss: {val_loss.result():.4f}, '
          f'Val Macro F1: {val_macro_f1.result():.4f}')

    # Check for NaN in training loss
    if tf.math.is_nan(train_loss.result()):
        print("NaN detected in training loss. Stopping training.")
        break

# After training, update omega
si_vit.update_omega()



ViT Epoch 1/25


Training:  10%|▉         | 26/266 [00:34<05:23,  1.35s/it, loss=376, macro_f1=0.112]  

                                                                                               

Epoch 1/25, Time: 341.11s, Loss: 2375.8347, Macro F1: 0.2380, Val Loss: 0.1110, Val Macro F1: 0.2482

ViT Epoch 2/25


                                                                                               

Epoch 2/25, Time: 349.26s, Loss: 1198.9891, Macro F1: 0.2857, Val Loss: 0.1145, Val Macro F1: 0.2572

ViT Epoch 3/25


                                                                                               

Epoch 3/25, Time: 342.75s, Loss: 34.4861, Macro F1: 0.3366, Val Loss: 0.1232, Val Macro F1: 0.2580

ViT Epoch 4/25


                                                                                               

Epoch 4/25, Time: 341.61s, Loss: 0.0766, Macro F1: 0.3937, Val Loss: 0.1315, Val Macro F1: 0.2636

ViT Epoch 5/25


                                                                                               

Epoch 5/25, Time: 344.07s, Loss: 0.0682, Macro F1: 0.4422, Val Loss: 0.1394, Val Macro F1: 0.2533

ViT Epoch 6/25


                                                                                               

Epoch 6/25, Time: 339.96s, Loss: 0.0604, Macro F1: 0.4810, Val Loss: 0.1511, Val Macro F1: 0.2336

ViT Epoch 7/25


                                                                                               

Epoch 7/25, Time: 344.08s, Loss: 0.0541, Macro F1: 0.5103, Val Loss: 0.1601, Val Macro F1: 0.2417

ViT Epoch 8/25


                                                                                               

Epoch 8/25, Time: 350.38s, Loss: 0.0480, Macro F1: 0.5356, Val Loss: 0.1680, Val Macro F1: 0.2421

ViT Epoch 9/25


                                                                                               

Epoch 9/25, Time: 343.95s, Loss: 0.0438, Macro F1: 0.5576, Val Loss: 0.1765, Val Macro F1: 0.2426

ViT Epoch 10/25


                                                                                               

Epoch 10/25, Time: 346.55s, Loss: 0.0393, Macro F1: 0.5702, Val Loss: 0.1843, Val Macro F1: 0.2438

ViT Epoch 11/25


                                                                                               

Epoch 11/25, Time: 345.75s, Loss: 0.0364, Macro F1: 0.5861, Val Loss: 0.1903, Val Macro F1: 0.2520

ViT Epoch 12/25


                                                                                               

Epoch 12/25, Time: 346.57s, Loss: 0.0336, Macro F1: 0.5943, Val Loss: 0.1981, Val Macro F1: 0.2569

ViT Epoch 13/25


                                                                                               

Epoch 13/25, Time: 350.78s, Loss: 0.0308, Macro F1: 0.6055, Val Loss: 0.2089, Val Macro F1: 0.2382

ViT Epoch 14/25


                                                                                               

Epoch 14/25, Time: 345.07s, Loss: 0.0289, Macro F1: 0.6132, Val Loss: 0.2052, Val Macro F1: 0.2475

ViT Epoch 15/25


                                                                                               

Epoch 15/25, Time: 350.75s, Loss: 0.0271, Macro F1: 0.6166, Val Loss: 0.2174, Val Macro F1: 0.2384

ViT Epoch 16/25


                                                                                               

Epoch 16/25, Time: 348.36s, Loss: 0.0243, Macro F1: 0.6269, Val Loss: 0.2185, Val Macro F1: 0.2546

ViT Epoch 17/25


                                                                                               

Epoch 17/25, Time: 346.36s, Loss: 0.0232, Macro F1: 0.6309, Val Loss: 0.2205, Val Macro F1: 0.2470

ViT Epoch 18/25


                                                                                               

Epoch 18/25, Time: 340.19s, Loss: 0.0221, Macro F1: 0.6335, Val Loss: 0.2221, Val Macro F1: 0.2416

ViT Epoch 19/25


                                                                                               

Epoch 19/25, Time: 351.44s, Loss: 0.0207, Macro F1: 0.6395, Val Loss: 0.2284, Val Macro F1: 0.2415

ViT Epoch 20/25


                                                                                               

Epoch 20/25, Time: 348.88s, Loss: 0.0196, Macro F1: 0.6414, Val Loss: 0.2295, Val Macro F1: 0.2570

ViT Epoch 21/25


                                                                                               

Epoch 21/25, Time: 351.37s, Loss: 0.0185, Macro F1: 0.6427, Val Loss: 0.2325, Val Macro F1: 0.2639

ViT Epoch 22/25


                                                                                               

Epoch 22/25, Time: 350.20s, Loss: 0.0182, Macro F1: 0.6437, Val Loss: 0.2388, Val Macro F1: 0.2542

ViT Epoch 23/25


                                                                                               

Epoch 23/25, Time: 347.66s, Loss: 0.0164, Macro F1: 0.6499, Val Loss: 0.2449, Val Macro F1: 0.2620

ViT Epoch 24/25


                                                                                               

Epoch 24/25, Time: 347.04s, Loss: 0.0157, Macro F1: 0.6521, Val Loss: 0.2529, Val Macro F1: 0.2696

ViT Epoch 25/25


                                                                                               

Epoch 25/25, Time: 349.71s, Loss: 0.0151, Macro F1: 0.6534, Val Loss: 0.2481, Val Macro F1: 0.2589




# Compiling and Publishing Results

In [46]:
print("CNN EWC Subdiagnostic Classification Report:")
cnn_ewc_sub_report = evaluate_model(cnn_sub_model, X_test, y_test_sub, classes_sub)

print("CNN EWC Superdiagnostic Classification Report:")
cnn_ewc_super_report = evaluate_model(cnn_super_model, X_test, y_test_super, classes_super)

print("ResNet EWC Subdiagnostic Classification Report:")
resnet_ewc_sub_report = evaluate_model(resnet_sub_model, X_test, y_test_sub, classes_sub)

print("ResNet EWC Superdiagnostic Classification Report:")
resnet_ewc_super_report = evaluate_model(resnet_super_model, X_test, y_test_super, classes_super)

print("ViT EWC Subdiagnostic Classification Report:")
vit_ewc_sub_report = evaluate_model(vit_sub_model, X_test, y_test_sub, classes_sub)

print("ViT EWC Superdiagnostic Classification Report:")
vit_ewc_super_report = evaluate_model(vit_super_model, X_test, y_test_super, classes_super)


CNN EWC Subdiagnostic Classification Report:
[1m68/68[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 9ms/step
              precision    recall  f1-score   support

         AMI       0.73      0.78      0.75       306
       CLBBB       0.95      0.76      0.85        54
       CRBBB       0.78      0.91      0.84        54
       ILBBB       0.20      0.12      0.15         8
         IMI       0.67      0.58      0.63       327
       IRBBB       0.62      0.69      0.65       112
        ISCA       0.43      0.32      0.37        93
        ISCI       1.00      0.15      0.26        40
        ISC_       0.73      0.38      0.50       128
        IVCD       0.19      0.11      0.14        79
   LAFB/LPFB       0.76      0.60      0.67       179
     LAO/LAE       0.20      0.02      0.04        42
         LMI       0.20      0.05      0.08        20
         LVH       0.71      0.44      0.55       214
        NORM       0.80      0.90      0.85       963
        NST_     

[1m68/68[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 38ms/step
              precision    recall  f1-score   support

         AMI       0.60      0.64      0.62       306
       CLBBB       0.90      0.69      0.78        54
       CRBBB       0.80      0.65      0.71        54
       ILBBB       0.50      0.12      0.20         8
         IMI       0.59      0.53      0.56       327
       IRBBB       0.43      0.46      0.44       112
        ISCA       0.24      0.17      0.20        93
        ISCI       0.29      0.20      0.24        40
        ISC_       0.56      0.44      0.49       128
        IVCD       0.17      0.09      0.12        79
   LAFB/LPFB       0.68      0.66      0.67       179
     LAO/LAE       0.00      0.00      0.00        42
         LMI       0.17      0.05      0.08        20
         LVH       0.66      0.59      0.62       214
        NORM       0.81      0.80      0.80       963
        NST_       0.10      0.03      0.04        77
       

In [47]:
print("CNN SI Subdiagnostic Classification Report:")
cnn_si_sub_report = evaluate_model(cnn_sub_model, X_test, y_test_sub, classes_sub)

print("CNN SI Superdiagnostic Classification Report:")
cnn_si_super_report = evaluate_model(cnn_super_model, X_test, y_test_super, classes_super)

print("ResNet SI Subdiagnostic Classification Report:")
resnet_si_sub_report = evaluate_model(resnet_sub_model, X_test, y_test_sub, classes_sub)

print("ResNet SI Superdiagnostic Classification Report:")
resnet_si_super_report = evaluate_model(resnet_super_model, X_test, y_test_super, classes_super)

print("ViT SI Subdiagnostic Classification Report:")
vit_si_sub_report = evaluate_model(vit_sub_model, X_test, y_test_sub, classes_sub)

print("ViT SI Superdiagnostic Classification Report:")
vit_si_super_report = evaluate_model(vit_super_model, X_test, y_test_super, classes_super)


CNN SI Subdiagnostic Classification Report:
[1m68/68[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step
              precision    recall  f1-score   support

         AMI       0.73      0.78      0.75       306
       CLBBB       0.95      0.76      0.85        54
       CRBBB       0.78      0.91      0.84        54
       ILBBB       0.20      0.12      0.15         8
         IMI       0.67      0.58      0.63       327
       IRBBB       0.62      0.69      0.65       112
        ISCA       0.43      0.32      0.37        93
        ISCI       1.00      0.15      0.26        40
        ISC_       0.73      0.38      0.50       128
        IVCD       0.19      0.11      0.14        79
   LAFB/LPFB       0.76      0.60      0.67       179
     LAO/LAE       0.20      0.02      0.04        42
         LMI       0.20      0.05      0.08        20
         LVH       0.71      0.44      0.55       214
        NORM       0.80      0.90      0.85       963
        NST_      

In [48]:
def get_macro_f1(report_dict):
    return report_dict['macro avg']['f1-score']

results = {
    'Model': [],
    'Task': [],
    'Macro F1-score': []
}

results['Model'].extend(['CNN', 'ResNet', 'ViT'])
results['Task'].extend(['Superdiagnostic'] * 3)
results['Macro F1-score'].extend([
    get_macro_f1(cnn_super_report),
    get_macro_f1(resnet_super_report),
    get_macro_f1(vit_super_report)
])

results['Model'].extend(['CNN', 'ResNet', 'ViT'])
results['Task'].extend(['Subdiagnostic'] * 3)
results['Macro F1-score'].extend([
    get_macro_f1(cnn_sub_report),
    get_macro_f1(resnet_sub_report),
    get_macro_f1(vit_sub_report)
])

results['Model'].extend(['CNN', 'ResNet', 'ViT'])
results['Task'].extend(['EWC Subdiagnostic'] * 3)
results['Macro F1-score'].extend([
    get_macro_f1(cnn_ewc_sub_report),
    get_macro_f1(resnet_ewc_sub_report),
    get_macro_f1(vit_ewc_sub_report)
])

results['Model'].extend(['CNN', 'ResNet', 'ViT'])
results['Task'].extend(['EWC Superdiagnostic'] * 3)
results['Macro F1-score'].extend([
    get_macro_f1(cnn_ewc_super_report),
    get_macro_f1(resnet_ewc_super_report),
    get_macro_f1(vit_ewc_super_report)
])

results['Model'].extend(['CNN', 'ResNet', 'ViT'])
results['Task'].extend(['SI Subdiagnostic'] * 3)
results['Macro F1-score'].extend([
    get_macro_f1(cnn_si_sub_report),
    get_macro_f1(resnet_si_sub_report),
    get_macro_f1(vit_si_sub_report)
])

results['Model'].extend(['CNN', 'ResNet', 'ViT'])
results['Task'].extend(['SI Superdiagnostic'] * 3)
results['Macro F1-score'].extend([
    get_macro_f1(cnn_si_super_report),
    get_macro_f1(resnet_si_super_report),
    get_macro_f1(vit_si_super_report)
])

results_df = pd.DataFrame(results)
print("\nSummary of Classification Performance:")
print(results_df)



Summary of Classification Performance:
     Model                 Task  Macro F1-score
0      CNN      Superdiagnostic        0.746157
1   ResNet      Superdiagnostic        0.732267
2      ViT      Superdiagnostic        0.667310
3      CNN        Subdiagnostic        0.401320
4   ResNet        Subdiagnostic        0.407258
5      ViT        Subdiagnostic        0.282383
6      CNN    EWC Subdiagnostic        0.397926
7   ResNet    EWC Subdiagnostic        0.370305
8      ViT    EWC Subdiagnostic        0.304643
9      CNN  EWC Superdiagnostic        0.587815
10  ResNet  EWC Superdiagnostic        0.334750
11     ViT  EWC Superdiagnostic        0.591112
12     CNN     SI Subdiagnostic        0.397926
13  ResNet     SI Subdiagnostic        0.370305
14     ViT     SI Subdiagnostic        0.304643
15     CNN   SI Superdiagnostic        0.587815
16  ResNet   SI Superdiagnostic        0.334750
17     ViT   SI Superdiagnostic        0.591112
