# Importing requirements

In [1]:
import os
import numpy as np
import pandas as pd
import wfdb
import ast
import tensorflow as tf
from tensorflow.keras import layers, models
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import classification_report
from collections import Counter
import time
from tqdm import tqdm


2024-11-21 15:38:23.165229: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 AVX_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-11-21 15:38:23.214268: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


# Loading Data from dataset file

In [2]:
DATA_PATH = '/home/bmi-lab/Downloads/ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3'

ptbxl_df = pd.read_csv(os.path.join(DATA_PATH, 'ptbxl_database.csv'))
scp_statements = pd.read_csv(os.path.join(DATA_PATH, 'scp_statements.csv'), index_col=0)

diagnostic_scps = scp_statements[scp_statements['diagnostic'] == 1].index.values

scp_to_superclass = scp_statements['diagnostic_class'].to_dict()
scp_to_subclass = scp_statements['diagnostic_subclass'].to_dict()

In [3]:
ptbxl_df['scp_codes'] = ptbxl_df['scp_codes'].apply(lambda x: ast.literal_eval(x))

In [4]:
def aggregate_diagnostic_labels(df, scp_codes, scp_to_agg):
    df = df.copy()
    def aggregate_labels(scp_codes_dict):
        labels = set()
        for code in scp_codes_dict.keys():
            if code in scp_codes:
                label = scp_to_agg.get(code)
                if label:
                    labels.add(label)
        return list(labels)
    df['diagnostic_labels'] = df['scp_codes'].apply(aggregate_labels)
    return df

ptbxl_df = aggregate_diagnostic_labels(ptbxl_df, diagnostic_scps, scp_to_superclass)
ptbxl_df = ptbxl_df.rename(columns={'diagnostic_labels': 'superclass_labels'})

ptbxl_df = aggregate_diagnostic_labels(ptbxl_df, diagnostic_scps, scp_to_subclass)
ptbxl_df = ptbxl_df.rename(columns={'diagnostic_labels': 'subclass_labels'})

In [5]:
ptbxl_df = ptbxl_df[ptbxl_df['superclass_labels'].map(len) > 0]

In [6]:
train_df = ptbxl_df[ptbxl_df.strat_fold <= 8]
val_df = ptbxl_df[ptbxl_df.strat_fold == 9]
test_df = ptbxl_df[ptbxl_df.strat_fold == 10]

In [7]:
def load_data(df, sampling_rate, data_path):
    data = []
    if sampling_rate == 100:
        filenames = df['filename_lr'].values
    else:
        filenames = df['filename_hr'].values
    for filename in filenames:
        file_path = os.path.join(data_path, filename)
        signals, _ = wfdb.rdsamp(file_path)
        data.append(signals)
    return np.array(data)

X_train = load_data(train_df, sampling_rate=100, data_path=DATA_PATH)
X_val = load_data(val_df, sampling_rate=100, data_path=DATA_PATH)
X_test = load_data(test_df, sampling_rate=100, data_path=DATA_PATH)

In [8]:
train_labels_super = train_df['superclass_labels'].values
val_labels_super = val_df['superclass_labels'].values
test_labels_super = test_df['superclass_labels'].values

mlb_super = MultiLabelBinarizer()
y_train_super = mlb_super.fit_transform(train_labels_super)
y_val_super = mlb_super.transform(val_labels_super)
y_test_super = mlb_super.transform(test_labels_super)
classes_super = mlb_super.classes_

In [9]:
train_labels_sub = train_df['subclass_labels'].values
val_labels_sub = val_df['subclass_labels'].values
test_labels_sub = test_df['subclass_labels'].values

mlb_sub = MultiLabelBinarizer()
y_train_sub = mlb_sub.fit_transform(train_labels_sub)
y_val_sub = mlb_sub.transform(val_labels_sub)
y_test_sub = mlb_sub.transform(test_labels_sub)
classes_sub = mlb_sub.classes_

In [10]:
def normalize_data_per_channel(X):
    X = np.transpose(X, (0, 2, 1))
    mean = np.mean(X, axis=(0, 2), keepdims=True)
    std = np.std(X, axis=(0, 2), keepdims=True)
    X = (X - mean) / std
    X = np.transpose(X, (0, 2, 1))
    return X

X_train = normalize_data_per_channel(X_train)
X_val = normalize_data_per_channel(X_val)
X_test = normalize_data_per_channel(X_test)

In [11]:
class_counts_super = np.sum(y_train_super, axis=0)
total_samples_super = y_train_super.shape[0]

class_weight_super = {}
for i, count in enumerate(class_counts_super):
    class_weight_super[i] = total_samples_super / (len(class_counts_super) * count)

class_counts_sub = np.sum(y_train_sub, axis=0)
total_samples_sub = y_train_sub.shape[0]

class_weight_sub = {}
for i, count in enumerate(class_counts_sub):
    class_weight_sub[i] = total_samples_sub / (len(class_counts_sub) * count)

In [12]:
num_classes_super = y_train_super.shape[1]
class_totals = np.sum(y_train_super, axis=0)
class_weights = class_totals.max() / class_totals
weights_array = np.array(class_weights, dtype=np.float32)

In [13]:
num_classes_sub = y_train_sub.shape[1]
class_totals_sub = np.sum(y_train_sub, axis=0)
class_weights_sub = class_totals_sub.max() / class_totals_sub
weights_array_sub = np.array(class_weights_sub, dtype=np.float32)

In [14]:
y_train_super = y_train_super.astype(np.float32)
y_val_super = y_val_super.astype(np.float32)
y_test_super = y_test_super.astype(np.float32)

# Defining Entropy and Metrics

In [15]:
import tensorflow.keras.backend as K

def weighted_binary_crossentropy(weights):
    def loss(y_true, y_pred):
        weights_cast = K.cast(weights, y_pred.dtype)
        y_true = K.cast(y_true, y_pred.dtype)
        
        bce = K.binary_crossentropy(y_true, y_pred)
        weight_vector = y_true * weights_cast + (1 - y_true)
        weighted_bce = weight_vector * bce
        return K.mean(weighted_bce)
    return loss

def macro_f1(y_true, y_pred):
    y_true = K.cast(y_true, 'float32')
    y_pred = K.cast(y_pred, 'float32')
    y_pred = K.round(y_pred)
    
    tp = K.sum(y_true * y_pred, axis=0)
    fp = K.sum((1 - y_true) * y_pred, axis=0)
    fn = K.sum(y_true * (1 - y_pred), axis=0)

    precision = tp / (tp + fp + K.epsilon())
    recall = tp / (tp + fn + K.epsilon())

    f1 = 2 * precision * recall / (precision + recall + K.epsilon())
    f1 = tf.where(tf.math.is_nan(f1), tf.zeros_like(f1), f1)
    return K.mean(f1)

# Defining Models

In [16]:
def create_cnn_model(input_shape, num_classes):
    inputs = layers.Input(shape=input_shape)

    x = layers.Conv1D(64, kernel_size=7, padding='same', activation='relu')(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.MaxPooling1D(pool_size=2)(x)
    
    x = layers.Conv1D(128, kernel_size=5, padding='same', activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.MaxPooling1D(pool_size=2)(x)
    
    x = layers.Conv1D(256, kernel_size=3, padding='same', activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.MaxPooling1D(pool_size=2)(x)
    
    x = layers.Conv1D(512, kernel_size=3, padding='same', activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.GlobalAveragePooling1D()(x)
    
    x = layers.Dense(1024, activation='relu')(x)
    x = layers.Dropout(0.1)(x)
    outputs = layers.Dense(num_classes, activation='sigmoid')(x)
    
    model = models.Model(inputs, outputs)
    return model


In [17]:
# def create_resnet_model(input_shape, num_classes):
#     inputs = layers.Input(shape=input_shape)
#     x = layers.Conv1D(64, kernel_size=7, strides=2, padding='same')(inputs)
#     x = layers.BatchNormalization()(x)
#     x = layers.Activation('relu')(x)
#     x = layers.MaxPooling1D(pool_size=3, strides=2, padding='same')(x)
    
#     previous_filters = x.shape[-1]
#     for filters in [64, 128, 256]:
#         x_shortcut = x
#         strides = 1
#         if previous_filters != filters:
#             strides = 2

#         x = layers.Conv1D(filters, kernel_size=3, strides=strides, padding='same')(x)
#         x = layers.BatchNormalization()(x)
#         x = layers.Activation('relu')(x)
#         x = layers.Conv1D(filters, kernel_size=3, padding='same')(x)
#         x = layers.BatchNormalization()(x)
        
#         if previous_filters != filters or strides != 1:
#             x_shortcut = layers.Conv1D(filters, kernel_size=1, strides=strides, padding='same')(x_shortcut)
#             x_shortcut = layers.BatchNormalization()(x_shortcut)
        
#         x = layers.Add()([x, x_shortcut])
#         x = layers.Activation('relu')(x)
#         previous_filters = filters
#     x = layers.GlobalAveragePooling1D()(x)
#     outputs = layers.Dense(num_classes, activation='sigmoid')(x)
#     model = models.Model(inputs, outputs)
#     return model

In [18]:
def residual_block_1d(x, filters, kernel_size=3, strides=1, downsample=False):
    shortcut = x
    
    x = layers.Conv1D(filters, kernel_size=kernel_size, strides=strides, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)

    x = layers.Conv1D(filters, kernel_size=kernel_size, padding='same')(x)
    x = layers.BatchNormalization()(x)
    
    if downsample or shortcut.shape[-1] != filters:
        shortcut = layers.Conv1D(filters, kernel_size=1, strides=strides, padding='same')(shortcut)
        shortcut = layers.BatchNormalization()(shortcut)

    x = layers.Add()([x, shortcut])
    x = layers.Activation('relu')(x)
    return x

def create_resnet_model(input_shape, num_classes):
    inputs = layers.Input(shape=input_shape)
    x = layers.Conv1D(64, kernel_size=7, strides=2, padding='same')(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.MaxPooling1D(pool_size=3, strides=2, padding='same')(x)
    layers_filters = [64, 128, 256, 512]
    layers_blocks = [3, 4, 6, 3]

    for filters, num_blocks in zip(layers_filters, layers_blocks):
        for i in range(num_blocks):
            if i == 0 and filters != x.shape[-1]:
                x = residual_block_1d(x, filters, strides=2, downsample=True)
            else:
                x = residual_block_1d(x, filters)

    x = layers.GlobalAveragePooling1D()(x)
    outputs = layers.Dense(num_classes, activation='sigmoid')(x)
    model = models.Model(inputs, outputs)
    return model

In [19]:
def mlp(x, hidden_units, dropout_rate):
    for units in hidden_units:
        x = layers.Dense(units, activation=tf.nn.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x

def create_vit_model(input_shape, num_classes):
    patch_size = 10 
    num_patches = input_shape[0] // patch_size
    projection_dim = 64
    num_heads = 4
    transformer_layers = 8
    mlp_head_units = [256, 128]
    dropout_rate = 0.1

    inputs = layers.Input(shape=input_shape)
    x = layers.Reshape((num_patches, patch_size * input_shape[1]))(inputs)
    x = layers.Dense(units=projection_dim)(x)
    positions = tf.range(start=0, limit=num_patches, delta=1)
    position_embedding = layers.Embedding(input_dim=num_patches, output_dim=projection_dim)
    x = x + position_embedding(positions)
    for _ in range(transformer_layers):
        x1 = layers.LayerNormalization(epsilon=1e-6)(x)
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=dropout_rate
        )(x1, x1)
        x2 = layers.Add()([attention_output, x])
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        x3 = mlp(x3, hidden_units=[projection_dim * 2, projection_dim], dropout_rate=dropout_rate)
        x = layers.Add()([x3, x2])
    x = layers.LayerNormalization(epsilon=1e-6)(x)
    x = layers.Flatten()(x)
    x = layers.Dropout(dropout_rate)(x)
    outputs = layers.Dense(num_classes, activation='sigmoid')(x)
    model = models.Model(inputs=inputs, outputs=outputs)
    return model

# Defining the training loop

In [20]:
def train_model(model, X_train, y_train, X_val, y_val, class_weight, batch_size=64, epochs=50):
    optimizer = tf.keras.optimizers.Adam()
    model.compile(
        optimizer=optimizer,
        loss='binary_crossentropy',
        metrics=['accuracy', macro_f1]
    )
    callbacks = [
        tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5),
        tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
    ]
    history = model.fit(
        X_train, y_train,
        validation_data=(X_val, y_val),
        epochs=epochs,
        batch_size=batch_size,
        callbacks=callbacks,
        class_weight=class_weight
    )
    return history

# Training and Evaluating Models without CL

In [21]:
input_shape = X_train.shape[1:]
num_classes_super = y_train_super.shape[1]

cnn_super_model = create_cnn_model(input_shape, num_classes_super)
train_model(cnn_super_model, X_train, y_train_super, X_val, y_val_super, class_weight_super)

2024-11-21 15:38:44.219509: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:981] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2024-11-21 15:38:44.236045: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:981] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2024-11-21 15:38:44.238077: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:981] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2024-11-21 15:38:44.240180: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 AVX_VNNI FMA
To enable them in other 

Epoch 1/50


2024-11-21 15:38:45.723645: I tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:428] Loaded cuDNN version 8907
2024-11-21 15:38:45.769927: I tensorflow/tsl/platform/default/subprocess.cc:304] Start cannot spawn child process: No such file or directory
2024-11-21 15:38:45.853407: I tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:630] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
2024-11-21 15:38:45.865045: I tensorflow/compiler/xla/service/service.cc:173] XLA service 0x781be003c160 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2024-11-21 15:38:45.865074: I tensorflow/compiler/xla/service/service.cc:181]   StreamExecutor device (0): NVIDIA GeForce RTX 4080, Compute Capability 8.9
2024-11-21 15:38:45.868120: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2024-11-21 15:38:45.90336

Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50


<keras.callbacks.History at 0x781e9323b400>

In [22]:
resnet_super_model = create_resnet_model(input_shape, num_classes_super)
train_model(resnet_super_model, X_train, y_train_super, X_val, y_val_super, class_weight_super)

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50


<keras.callbacks.History at 0x781e7dd44a00>

In [23]:
vit_super_model = create_vit_model(input_shape, num_classes_super)
train_model(vit_super_model, X_train, y_train_super, X_val, y_val_super, class_weight_super)

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50


<keras.callbacks.History at 0x781e20affbb0>

In [24]:
num_classes_sub = y_train_sub.shape[1]
cnn_sub_model = create_cnn_model(input_shape, num_classes_sub)
train_model(cnn_sub_model, X_train, y_train_sub, X_val, y_val_sub, class_weight_sub)

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50


<keras.callbacks.History at 0x781de219ab30>

In [25]:
resnet_sub_model = create_resnet_model(input_shape, num_classes_sub)
train_model(resnet_sub_model, X_train, y_train_sub, X_val, y_val_sub, class_weight_sub)

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50


<keras.callbacks.History at 0x781dd6b226b0>

In [26]:
vit_sub_model = create_vit_model(input_shape, num_classes_sub)
train_model(vit_sub_model, X_train, y_train_sub, X_val, y_val_sub, class_weight_sub)

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50


<keras.callbacks.History at 0x781d52123700>

In [27]:
def evaluate_model(model, X_test, y_test, classes):
    y_pred = model.predict(X_test)
    y_pred_threshold = (y_pred >= 0.5).astype(int)
    report = classification_report(y_test, y_pred_threshold, target_names=classes, zero_division=0, output_dict=True)
    print(classification_report(y_test, y_pred_threshold, target_names=classes, zero_division=0))
    return report


In [28]:
print("CNN Superdiagnostic Classification Report:")
cnn_super_report = evaluate_model(cnn_super_model, X_test, y_test_super, classes_super)

print("ResNet Superdiagnostic Classification Report:")
resnet_super_report = evaluate_model(resnet_super_model, X_test, y_test_super, classes_super)

print("ViT Superdiagnostic Classification Report:")
vit_super_report = evaluate_model(vit_super_model, X_test, y_test_super, classes_super)


CNN Superdiagnostic Classification Report:
              precision    recall  f1-score   support

          CD       0.79      0.73      0.76       496
         HYP       0.66      0.56      0.61       262
          MI       0.78      0.73      0.76       550
        NORM       0.86      0.87      0.86       963
        STTC       0.74      0.77      0.75       521

   micro avg       0.79      0.77      0.78      2792
   macro avg       0.77      0.73      0.75      2792
weighted avg       0.79      0.77      0.78      2792
 samples avg       0.78      0.79      0.77      2792

ResNet Superdiagnostic Classification Report:
              precision    recall  f1-score   support

          CD       0.79      0.71      0.75       496
         HYP       0.75      0.45      0.56       262
          MI       0.75      0.74      0.74       550
        NORM       0.83      0.88      0.86       963
        STTC       0.74      0.75      0.75       521

   micro avg       0.78      0.76      0.7

In [29]:
print("CNN Subdiagnostic Classification Report:")
cnn_sub_report = evaluate_model(cnn_sub_model, X_test, y_test_sub, classes_sub)

print("ResNet Subdiagnostic Classification Report:")
resnet_sub_report = evaluate_model(resnet_sub_model, X_test, y_test_sub, classes_sub)

print("ViT Subdiagnostic Classification Report:")
vit_sub_report = evaluate_model(vit_sub_model, X_test, y_test_sub, classes_sub)


CNN Subdiagnostic Classification Report:
              precision    recall  f1-score   support

         AMI       0.85      0.63      0.72       306
       CLBBB       0.89      0.89      0.89        54
       CRBBB       0.81      0.89      0.85        54
       ILBBB       0.08      0.12      0.10         8
         IMI       0.73      0.53      0.61       327
       IRBBB       0.60      0.61      0.60       112
        ISCA       0.49      0.27      0.35        93
        ISCI       0.41      0.28      0.33        40
        ISC_       0.73      0.45      0.55       128
        IVCD       0.17      0.11      0.14        79
   LAFB/LPFB       0.81      0.68      0.74       179
     LAO/LAE       0.00      0.00      0.00        42
         LMI       0.20      0.15      0.17        20
         LVH       0.76      0.50      0.61       214
        NORM       0.87      0.78      0.82       963
        NST_       0.21      0.13      0.16        77
         PMI       0.00      0.00      0

# Defining and Training on LwF

In [30]:
cnn_soft_targets_super = cnn_super_model.predict(X_train)

def lwf_loss(y_true, y_pred, old_predictions, T=2):
    task_loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
    dist_loss = tf.keras.losses.KLDivergence()(tf.nn.softmax(old_predictions / T),
                                               tf.nn.softmax(y_pred / T))
    total_loss = task_loss + dist_loss
    return total_loss

print("Working on CNN for LwF Now:")
cnn_model_lwf = create_cnn_model(input_shape, num_classes_sub)
cnn_model_lwf.compile(
    optimizer='adam',
    loss=lambda y_true, y_pred: lwf_loss(y_true, y_pred, old_predictions=cnn_soft_targets_super),
    metrics=[macro_f1]
)
train_model(cnn_model_lwf, X_train, y_train_sub, X_val, y_val_sub, class_weight_sub)

print("Working on ResNet for LwF Now:")
resnet_soft_targets_super = resnet_super_model.predict(X_train)
resnet_model_lwf = create_resnet_model(input_shape, num_classes_sub)
resnet_model_lwf.compile(
    optimizer='adam',
    loss=lambda y_true, y_pred: lwf_loss(y_true, y_pred, old_predictions=resnet_soft_targets_super),
    metrics=[macro_f1]
)
train_model(resnet_model_lwf, X_train, y_train_sub, X_val, y_val_sub, class_weight_sub)

print("Working on ViT for LwF Now:")
vit_soft_targets_super = vit_super_model.predict(X_train)
vit_model_lwf = create_vit_model(input_shape, num_classes_sub)
vit_model_lwf.compile(
    optimizer='adam',
    loss=lambda y_true, y_pred: lwf_loss(y_true, y_pred, old_predictions=vit_soft_targets_super),
    metrics=[macro_f1]
)
train_model(vit_model_lwf, X_train, y_train_sub, X_val, y_val_sub, class_weight_sub)


Working on CNN for LwF Now:
Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Working on ResNet for LwF Now:
Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Working on ViT for LwF Now:
Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50


<keras.callbacks.History at 0x781d14b8cd00>

# Defining and Training on EwC

In [31]:
class EWC:
    def __init__(self, model, X, y, batch_size=32, exclude_params=[]):
        self.model = model
        self.params = {}
        for p in model.trainable_variables:
            if id(p) not in exclude_params:
                self.params[id(p)] = p.numpy()
        self.fisher = self.compute_fisher(X, y, batch_size, exclude_params)

    def compute_fisher(self, X, y, batch_size, exclude_params):
        fisher = {}
        num_samples = X.shape[0]
        num_batches = int(np.ceil(num_samples / batch_size))

        for batch_idx in range(num_batches):
            X_batch = X[batch_idx*batch_size:(batch_idx+1)*batch_size]
            y_batch = y[batch_idx*batch_size:(batch_idx+1)*batch_size]
            with tf.GradientTape() as tape:
                preds = self.model(X_batch)
                loss = tf.reduce_mean(tf.keras.losses.binary_crossentropy(y_batch, preds))
            grads = tape.gradient(loss, self.model.trainable_variables)
            for p, g in zip(self.model.trainable_variables, grads):
                if g is not None and id(p) not in exclude_params:
                    param_id = id(p)
                    if param_id not in fisher:
                        fisher[param_id] = np.square(g.numpy())
                    else:
                        fisher[param_id] += np.square(g.numpy())
        for k in fisher.keys():
            fisher[k] /= num_batches
        return fisher

    def penalty(self, model):
        loss = 0
        for p in model.trainable_variables:
            param_id = id(p)
            if param_id in self.fisher:
                fisher = tf.convert_to_tensor(self.fisher[param_id])
                loss += tf.reduce_sum(fisher * tf.square(p - self.params[param_id]))
        return loss

In [32]:
def modify_model_for_subdiagnostic(base_model, num_classes_sub):
    x = base_model.layers[-2].output
    outputs = layers.Dense(num_classes_sub, activation='sigmoid', name='output_sub')(x)
    new_model = models.Model(inputs=base_model.input, outputs=outputs)
    return new_model


In [33]:
lambda_ewc = 1000
cnn_sub_model = modify_model_for_subdiagnostic(cnn_super_model, num_classes_sub)
exclude_params_cnn = [id(w) for w in cnn_sub_model.layers[-1].trainable_weights]
ewc_cnn = EWC(cnn_super_model, X_train, y_train_super, exclude_params=exclude_params_cnn)

def ewc_loss_cnn(y_true, y_pred):
    task_loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
    ewc_penalty = ewc_cnn.penalty(cnn_sub_model)
    total_loss = task_loss + (lambda_ewc / 2) * ewc_penalty
    return total_loss

cnn_sub_model.compile(
    optimizer='adam',
    loss=ewc_loss_cnn,
    metrics=[macro_f1]
)

train_model(cnn_sub_model, X_train, y_train_sub, X_val, y_val_sub, class_weight_sub)


Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50


<keras.callbacks.History at 0x781d2c14beb0>

In [34]:
resnet_sub_model = modify_model_for_subdiagnostic(resnet_super_model, num_classes_sub)
exclude_params_resnet = [id(w) for w in resnet_sub_model.layers[-1].trainable_weights]
ewc_resnet = EWC(resnet_super_model, X_train, y_train_super, exclude_params=exclude_params_resnet)
def ewc_loss_resnet(y_true, y_pred):
    task_loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
    ewc_penalty = ewc_resnet.penalty(resnet_sub_model)
    total_loss = task_loss + (lambda_ewc / 2) * ewc_penalty
    return total_loss

resnet_sub_model.compile(
    optimizer='adam',
    loss=ewc_loss_resnet,
    metrics=[macro_f1]
)

train_model(resnet_sub_model, X_train, y_train_sub, X_val, y_val_sub, class_weight_sub)


Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50


<keras.callbacks.History at 0x781d2c14b430>

In [35]:
vit_sub_model = modify_model_for_subdiagnostic(vit_super_model, num_classes_sub)
exclude_params_vit = [id(w) for w in vit_sub_model.layers[-1].trainable_weights]
ewc_vit = EWC(vit_super_model, X_train, y_train_super, exclude_params=exclude_params_vit)

def ewc_loss_vit(y_true, y_pred):
    task_loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
    ewc_penalty = ewc_vit.penalty(vit_sub_model)
    total_loss = task_loss + (lambda_ewc / 2) * ewc_penalty
    return total_loss

vit_sub_model.compile(
    optimizer='adam',
    loss=ewc_loss_vit,
    metrics=[macro_f1]
)

train_model(vit_sub_model, X_train, y_train_sub, X_val, y_val_sub, class_weight_sub)

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50


<keras.callbacks.History at 0x781d0c4eaec0>

# Defining and Training on SI

In [36]:
class SI:
    def __init__(self, prev_model, damping_factor=0.1, exclude_params=[]):
        self.prev_params = {}
        self.omega = {}
        self.damping_factor = damping_factor
        self.exclude_params = exclude_params
        
        for var in prev_model.trainable_variables:
            if var.name not in self.exclude_params:
                self.prev_params[var.name] = var.numpy()
                self.omega[var.name] = np.zeros_like(var.numpy())

    def accumulate_importance(self, model, grads):
        for var, grad in zip(model.trainable_variables, grads):
            if grad is not None and var.name in self.omega:
                delta_theta = var.numpy() - self.prev_params[var.name]
                self.omega[var.name] += np.abs(grad.numpy() * delta_theta)

    def update_omega(self, model):
        for var in model.trainable_variables:
            if var.name in self.omega:
                delta_theta = var.numpy() - self.prev_params[var.name]
                denom = np.square(delta_theta) + self.damping_factor
                # Avoid division by zero
                epsilon = 1e-6
                denom = np.where(denom <epsilon, epsilon, denom)
                self.omega[var.name] = self.omega[var.name] / denom

    def penalty(self, model):
        loss = 0
        for var in model.trainable_variables:
            if var.name in self.omega:
                omega = tf.convert_to_tensor(self.omega[var.name], dtype=var.dtype)
                prev_param = tf.convert_to_tensor(self.prev_params[var.name], dtype=var.dtype)
                omega = tf.nn.relu(omega)
                loss += tf.reduce_sum(omega * tf.square(var - prev_param))
        return loss


In [37]:
num_classes_sub = y_train_sub.shape[1]
cnn_sub_model = modify_model_for_subdiagnostic(cnn_super_model, num_classes_sub)

exclude_params_cnn = [w.name for w in cnn_sub_model.layers[-1].trainable_weights]
si_cnn = SI(cnn_sub_model, exclude_params=exclude_params_cnn)


In [38]:
lambda_si = 1
epochs = 25
batch_size = 64
optimizer = tf.keras.optimizers.Adam(learning_rate = 0.001)

train_macro_f1 = tf.keras.metrics.Mean(name='train_macro_f1')
train_loss = tf.keras.metrics.Mean(name='train_loss')
val_macro_f1 = tf.keras.metrics.Mean(name='val_macro_f1')
val_loss = tf.keras.metrics.Mean(name='val_loss')

for epoch in range(epochs):
    start_time = time.time()
    print(f'\nCNN Epoch {epoch+1}/{epochs}')
    train_macro_f1.reset_state()
    train_loss.reset_state()
    
    for step in tqdm(range(len(X_train) // batch_size)):
        X_batch = X_train[step*batch_size:(step+1)*batch_size]
        y_batch = y_train_sub[step*batch_size:(step+1)*batch_size]
        
        with tf.GradientTape() as tape:
            preds = cnn_sub_model(X_batch, training=True)
            task_loss = tf.keras.losses.binary_crossentropy(y_batch, preds)
            si_penalty = si_cnn.penalty(cnn_sub_model)
            total_loss = tf.reduce_mean(task_loss + lambda_si * si_penalty)
        
        grads = tape.gradient(total_loss, cnn_sub_model.trainable_variables)
        optimizer.apply_gradients(zip(grads, cnn_sub_model.trainable_variables))
        si_cnn.accumulate_importance(cnn_sub_model, grads)
        
        batch_macro_f1 = macro_f1(y_batch, preds)
        train_macro_f1.update_state(batch_macro_f1)
        train_loss.update_state(total_loss)
        
    si_cnn.update_omega(cnn_sub_model)
    epoch_time = time.time() - start_time
    
    val_macro_f1.reset_state()
    val_loss.reset_state()
    for step in range(len(X_val) // batch_size):
        X_batch = X_val[step*batch_size:(step+1)*batch_size]
        y_batch = y_val_sub[step*batch_size:(step+1)*batch_size]
        preds = cnn_sub_model(X_batch, training=False)
        task_loss = tf.keras.losses.binary_crossentropy(y_batch, preds)
        total_loss = tf.reduce_mean(task_loss)

        batch_macro_f1 = macro_f1(y_batch, preds)
        val_macro_f1.update_state(batch_macro_f1)
        val_loss.update_state(total_loss)
    print(f'Epoch {epoch+1}/{epochs}, '
          f'Time: {epoch_time:.2f}s, '
          f'Loss: {train_loss.result():.4f}, '
          f'Macro F1: {train_macro_f1.result():.4f}, '
          f'Val Loss: {val_loss.result():.4f}, '
          f'Val Macro F1: {val_macro_f1.result():.4f}')



CNN Epoch 1/25


  0%|          | 0/266 [00:00<?, ?it/s]



100%|██████████| 266/266 [00:07<00:00, 33.71it/s]


Epoch 1/25, Time: 7.90s, Loss: 0.1089, Macro F1: 0.3072, Val Loss: 0.1002, Val Macro F1: 0.3011

CNN Epoch 2/25


100%|██████████| 266/266 [00:07<00:00, 36.70it/s]


Epoch 2/25, Time: 7.25s, Loss: 0.0941, Macro F1: 0.3501, Val Loss: 0.1024, Val Macro F1: 0.3097

CNN Epoch 3/25


100%|██████████| 266/266 [00:07<00:00, 36.73it/s]


Epoch 3/25, Time: 7.25s, Loss: 0.0838, Macro F1: 0.3961, Val Loss: 0.0978, Val Macro F1: 0.3290

CNN Epoch 4/25


100%|██████████| 266/266 [00:07<00:00, 36.60it/s]


Epoch 4/25, Time: 7.27s, Loss: 0.0799, Macro F1: 0.4255, Val Loss: 0.0969, Val Macro F1: 0.3289

CNN Epoch 5/25


100%|██████████| 266/266 [00:07<00:00, 37.22it/s]


Epoch 5/25, Time: 7.15s, Loss: 0.0802, Macro F1: 0.4386, Val Loss: 0.0979, Val Macro F1: 0.3355

CNN Epoch 6/25


100%|██████████| 266/266 [00:07<00:00, 36.81it/s]


Epoch 6/25, Time: 7.23s, Loss: 0.4300, Macro F1: 0.4474, Val Loss: 0.0987, Val Macro F1: 0.3370

CNN Epoch 7/25


100%|██████████| 266/266 [00:07<00:00, 36.70it/s]


Epoch 7/25, Time: 7.25s, Loss: 3.2578, Macro F1: 0.4558, Val Loss: 0.0993, Val Macro F1: 0.3377

CNN Epoch 8/25


100%|██████████| 266/266 [00:07<00:00, 36.80it/s]


Epoch 8/25, Time: 7.23s, Loss: 37.2523, Macro F1: 0.4622, Val Loss: 0.0999, Val Macro F1: 0.3390

CNN Epoch 9/25


100%|██████████| 266/266 [00:07<00:00, 36.64it/s]


Epoch 9/25, Time: 7.26s, Loss: 364.9479, Macro F1: 0.4656, Val Loss: 0.1004, Val Macro F1: 0.3381

CNN Epoch 10/25


100%|██████████| 266/266 [00:07<00:00, 36.62it/s]


Epoch 10/25, Time: 7.27s, Loss: 3873.6130, Macro F1: 0.4700, Val Loss: 0.1011, Val Macro F1: 0.3388

CNN Epoch 11/25


100%|██████████| 266/266 [00:07<00:00, 36.78it/s]


Epoch 11/25, Time: 7.24s, Loss: 36919.3984, Macro F1: 0.4732, Val Loss: 0.1015, Val Macro F1: 0.3408

CNN Epoch 12/25


100%|██████████| 266/266 [00:07<00:00, 36.63it/s]


Epoch 12/25, Time: 7.27s, Loss: 359300.4062, Macro F1: 0.4750, Val Loss: 0.1020, Val Macro F1: 0.3386

CNN Epoch 13/25


100%|██████████| 266/266 [00:07<00:00, 36.76it/s]


Epoch 13/25, Time: 7.24s, Loss: 3355676.7500, Macro F1: 0.4803, Val Loss: 0.1025, Val Macro F1: 0.3425

CNN Epoch 14/25


100%|██████████| 266/266 [00:07<00:00, 36.65it/s]


Epoch 14/25, Time: 7.26s, Loss: 31651508.0000, Macro F1: 0.4822, Val Loss: 0.1027, Val Macro F1: 0.3424

CNN Epoch 15/25


100%|██████████| 266/266 [00:07<00:00, 36.74it/s]


Epoch 15/25, Time: 7.24s, Loss: 292165504.0000, Macro F1: 0.4880, Val Loss: 0.1031, Val Macro F1: 0.3411

CNN Epoch 16/25


100%|██████████| 266/266 [00:07<00:00, 36.80it/s]


Epoch 16/25, Time: 7.23s, Loss: 2786338816.0000, Macro F1: 0.4863, Val Loss: 0.1035, Val Macro F1: 0.3426

CNN Epoch 17/25


100%|██████████| 266/266 [00:07<00:00, 37.12it/s]


Epoch 17/25, Time: 7.17s, Loss: 25904293888.0000, Macro F1: 0.4873, Val Loss: 0.1039, Val Macro F1: 0.3410

CNN Epoch 18/25


100%|██████████| 266/266 [00:07<00:00, 36.85it/s]


Epoch 18/25, Time: 7.22s, Loss: 245290090496.0000, Macro F1: 0.4940, Val Loss: 0.1041, Val Macro F1: 0.3404

CNN Epoch 19/25


100%|██████████| 266/266 [00:07<00:00, 36.81it/s]


Epoch 19/25, Time: 7.23s, Loss: 2304667287552.0000, Macro F1: 0.4942, Val Loss: 0.1046, Val Macro F1: 0.3413

CNN Epoch 20/25


100%|██████████| 266/266 [00:07<00:00, 37.05it/s]


Epoch 20/25, Time: 7.18s, Loss: 21753579438080.0000, Macro F1: 0.4942, Val Loss: 0.1048, Val Macro F1: 0.3417

CNN Epoch 21/25


100%|██████████| 266/266 [00:07<00:00, 36.94it/s]


Epoch 21/25, Time: 7.21s, Loss: 205460581908480.0000, Macro F1: 0.4970, Val Loss: 0.1053, Val Macro F1: 0.3426

CNN Epoch 22/25


100%|██████████| 266/266 [00:07<00:00, 36.94it/s]


Epoch 22/25, Time: 7.21s, Loss: 1935975433371648.0000, Macro F1: 0.4975, Val Loss: 0.1058, Val Macro F1: 0.3415

CNN Epoch 23/25


100%|██████████| 266/266 [00:07<00:00, 37.07it/s]


Epoch 23/25, Time: 7.18s, Loss: 18340886890938368.0000, Macro F1: 0.4975, Val Loss: 0.1062, Val Macro F1: 0.3411

CNN Epoch 24/25


100%|██████████| 266/266 [00:07<00:00, 36.91it/s]


Epoch 24/25, Time: 7.21s, Loss: 173829025960034304.0000, Macro F1: 0.4982, Val Loss: 0.1063, Val Macro F1: 0.3426

CNN Epoch 25/25


100%|██████████| 266/266 [00:07<00:00, 36.96it/s]


Epoch 25/25, Time: 7.20s, Loss: 1645146334644142080.0000, Macro F1: 0.5036, Val Loss: 0.1067, Val Macro F1: 0.3415


In [39]:
resnet_sub_model = modify_model_for_subdiagnostic(resnet_super_model, num_classes_sub)
exclude_params_resnet = [w.name for w in resnet_sub_model.layers[-1].trainable_weights]
si_resnet = SI(resnet_sub_model, exclude_params=exclude_params_resnet)

In [40]:
lambda_si = 1.0  
epochs = 25
batch_size = 64
optimizer = tf.keras.optimizers.Adam()

train_macro_f1 = tf.keras.metrics.Mean(name='train_macro_f1')
train_loss = tf.keras.metrics.Mean(name='train_loss')
val_macro_f1 = tf.keras.metrics.Mean(name='val_macro_f1')
val_loss = tf.keras.metrics.Mean(name='val_loss')

for epoch in range(epochs):
    start_time = time.time()
    print(f'\nResNet Epoch {epoch+1}/{epochs}')
    train_macro_f1.reset_state()
    train_loss.reset_state()
    
    for step in tqdm(range(len(X_train) // batch_size)):
        X_batch = X_train[step*batch_size:(step+1)*batch_size]
        y_batch = y_train_sub[step*batch_size:(step+1)*batch_size]
        
        with tf.GradientTape() as tape:
            preds = resnet_sub_model(X_batch, training=True)
            task_loss = tf.keras.losses.binary_crossentropy(y_batch, preds)
            si_penalty = si_resnet.penalty(resnet_sub_model)
            total_loss = tf.reduce_mean(task_loss + lambda_si * si_penalty)
        
        grads = tape.gradient(total_loss, resnet_sub_model.trainable_variables)
        optimizer.apply_gradients(zip(grads, resnet_sub_model.trainable_variables))
        si_resnet.accumulate_importance(resnet_sub_model, grads)
    
        batch_macro_f1 = macro_f1(y_batch, preds)
        train_macro_f1.update_state(batch_macro_f1)
        train_loss.update_state(total_loss)
    
    si_resnet.update_omega(resnet_sub_model)
    epoch_time = time.time() - start_time

    val_macro_f1.reset_state()
    val_loss.reset_state()
    for step in range(len(X_val) // batch_size):
        X_batch = X_val[step*batch_size:(step+1)*batch_size]
        y_batch = y_val_sub[step*batch_size:(step+1)*batch_size]
        preds = resnet_sub_model(X_batch, training=False)
        task_loss = tf.keras.losses.binary_crossentropy(y_batch, preds)
        total_loss = tf.reduce_mean(task_loss)

        batch_macro_f1 = macro_f1(y_batch, preds)
        val_macro_f1.update_state(batch_macro_f1)
        val_loss.update_state(total_loss)

    print(f'Epoch {epoch+1}/{epochs}, '
          f'Time: {epoch_time:.2f}s, '
          f'Loss: {train_loss.result():.4f}, '
          f'Macro F1: {train_macro_f1.result():.4f}, '
          f'Val Loss: {val_loss.result():.4f}, '
          f'Val Macro F1: {val_macro_f1.result():.4f}')



ResNet Epoch 1/25


100%|██████████| 266/266 [00:50<00:00,  5.30it/s]


Epoch 1/25, Time: 50.24s, Loss: 0.1104, Macro F1: 0.2991, Val Loss: 0.1084, Val Macro F1: 0.2724

ResNet Epoch 2/25


100%|██████████| 266/266 [00:45<00:00,  5.80it/s]


Epoch 2/25, Time: 45.91s, Loss: 0.1091, Macro F1: 0.3254, Val Loss: 0.1107, Val Macro F1: 0.2569

ResNet Epoch 3/25


100%|██████████| 266/266 [00:46<00:00,  5.78it/s]


Epoch 3/25, Time: 46.03s, Loss: 0.0947, Macro F1: 0.3228, Val Loss: 0.1012, Val Macro F1: 0.2818

ResNet Epoch 4/25


100%|██████████| 266/266 [00:45<00:00,  5.84it/s]


Epoch 4/25, Time: 45.55s, Loss: 0.0863, Macro F1: 0.3775, Val Loss: 0.1011, Val Macro F1: 0.3037

ResNet Epoch 5/25


100%|██████████| 266/266 [00:45<00:00,  5.79it/s]


Epoch 5/25, Time: 45.96s, Loss: 0.0902, Macro F1: 0.4014, Val Loss: 0.1030, Val Macro F1: 0.3110

ResNet Epoch 6/25


100%|██████████| 266/266 [00:45<00:00,  5.79it/s]


Epoch 6/25, Time: 45.99s, Loss: 0.5732, Macro F1: 0.4161, Val Loss: 0.1043, Val Macro F1: 0.3127

ResNet Epoch 7/25


100%|██████████| 266/266 [00:45<00:00,  5.84it/s]


Epoch 7/25, Time: 45.60s, Loss: 4.2793, Macro F1: 0.4251, Val Loss: 0.1053, Val Macro F1: 0.3166

ResNet Epoch 8/25


100%|██████████| 266/266 [00:45<00:00,  5.80it/s]


Epoch 8/25, Time: 45.91s, Loss: 50.1227, Macro F1: 0.4311, Val Loss: 0.1060, Val Macro F1: 0.3220

ResNet Epoch 9/25


100%|██████████| 266/266 [00:45<00:00,  5.81it/s]


Epoch 9/25, Time: 45.83s, Loss: 486.2188, Macro F1: 0.4353, Val Loss: 0.1066, Val Macro F1: 0.3236

ResNet Epoch 10/25


100%|██████████| 266/266 [00:45<00:00,  5.86it/s]


Epoch 10/25, Time: 45.46s, Loss: 5173.1504, Macro F1: 0.4397, Val Loss: 0.1070, Val Macro F1: 0.3254

ResNet Epoch 11/25


100%|██████████| 266/266 [00:45<00:00,  5.82it/s]


Epoch 11/25, Time: 45.72s, Loss: 48880.5820, Macro F1: 0.4425, Val Loss: 0.1074, Val Macro F1: 0.3261

ResNet Epoch 12/25


100%|██████████| 266/266 [00:46<00:00,  5.78it/s]


Epoch 12/25, Time: 46.06s, Loss: 474433.5625, Macro F1: 0.4460, Val Loss: 0.1078, Val Macro F1: 0.3264

ResNet Epoch 13/25


100%|██████████| 266/266 [00:45<00:00,  5.87it/s]


Epoch 13/25, Time: 45.35s, Loss: 4423563.0000, Macro F1: 0.4484, Val Loss: 0.1080, Val Macro F1: 0.3269

ResNet Epoch 14/25


100%|██████████| 266/266 [00:45<00:00,  5.80it/s]


Epoch 14/25, Time: 45.91s, Loss: 41674904.0000, Macro F1: 0.4496, Val Loss: 0.1083, Val Macro F1: 0.3280

ResNet Epoch 15/25


100%|██████████| 266/266 [00:46<00:00,  5.77it/s]


Epoch 15/25, Time: 46.09s, Loss: 384964864.0000, Macro F1: 0.4507, Val Loss: 0.1085, Val Macro F1: 0.3277

ResNet Epoch 16/25


100%|██████████| 266/266 [00:45<00:00,  5.82it/s]


Epoch 16/25, Time: 45.70s, Loss: 3667652096.0000, Macro F1: 0.4517, Val Loss: 0.1086, Val Macro F1: 0.3284

ResNet Epoch 17/25


100%|██████████| 266/266 [00:45<00:00,  5.82it/s]


Epoch 17/25, Time: 45.75s, Loss: 34171596800.0000, Macro F1: 0.4527, Val Loss: 0.1088, Val Macro F1: 0.3273

ResNet Epoch 18/25


100%|██████████| 266/266 [00:46<00:00,  5.78it/s]


Epoch 18/25, Time: 46.06s, Loss: 323459645440.0000, Macro F1: 0.4537, Val Loss: 0.1090, Val Macro F1: 0.3270

ResNet Epoch 19/25


100%|██████████| 266/266 [00:45<00:00,  5.87it/s]


Epoch 19/25, Time: 45.33s, Loss: 3035122106368.0000, Macro F1: 0.4544, Val Loss: 0.1091, Val Macro F1: 0.3274

ResNet Epoch 20/25


100%|██████████| 266/266 [00:46<00:00,  5.77it/s]


Epoch 20/25, Time: 46.11s, Loss: 28665475760128.0000, Macro F1: 0.4554, Val Loss: 0.1093, Val Macro F1: 0.3285

ResNet Epoch 21/25


100%|██████████| 266/266 [00:46<00:00,  5.77it/s]


Epoch 21/25, Time: 46.15s, Loss: 270114503000064.0000, Macro F1: 0.4563, Val Loss: 0.1094, Val Macro F1: 0.3296

ResNet Epoch 22/25


100%|██████████| 266/266 [00:45<00:00,  5.85it/s]


Epoch 22/25, Time: 45.52s, Loss: 2553551331000320.0000, Macro F1: 0.4573, Val Loss: 0.1095, Val Macro F1: 0.3295

ResNet Epoch 23/25


100%|██████████| 266/266 [00:45<00:00,  5.80it/s]


Epoch 23/25, Time: 45.85s, Loss: 24163108050173952.0000, Macro F1: 0.4579, Val Loss: 0.1096, Val Macro F1: 0.3298

ResNet Epoch 24/25


100%|██████████| 266/266 [00:45<00:00,  5.78it/s]


Epoch 24/25, Time: 46.00s, Loss: 227764366269743104.0000, Macro F1: 0.4593, Val Loss: 0.1097, Val Macro F1: 0.3278

ResNet Epoch 25/25


100%|██████████| 266/266 [00:45<00:00,  5.80it/s]


Epoch 25/25, Time: 45.88s, Loss: nan, Macro F1: 0.0112, Val Loss: nan, Val Macro F1: 0.0000


In [41]:
vit_sub_model = modify_model_for_subdiagnostic(vit_super_model, num_classes_sub)
exclude_params_vit = [w.name for w in vit_sub_model.layers[-1].trainable_weights]
si_vit = SI(vit_sub_model, exclude_params=exclude_params_vit)


In [42]:
lambda_si = 1.0
epochs = 25
batch_size = 64
optimizer = tf.keras.optimizers.Adam()

train_macro_f1 = tf.keras.metrics.Mean(name='train_macro_f1')
train_loss = tf.keras.metrics.Mean(name='train_loss')
val_macro_f1 = tf.keras.metrics.Mean(name='val_macro_f1')
val_loss = tf.keras.metrics.Mean(name='val_loss')

for epoch in range(epochs):
    start_time = time.time()
    print(f'\nViT Epoch {epoch+1}/{epochs}')
    
    train_macro_f1.reset_state()
    train_loss.reset_state()

    for step in tqdm(range(len(X_train) // batch_size)):
        X_batch = X_train[step*batch_size:(step+1)*batch_size]
        y_batch = y_train_sub[step*batch_size:(step+1)*batch_size]
        
        with tf.GradientTape() as tape:
            preds = vit_sub_model(X_batch, training=True)
            task_loss = tf.keras.losses.binary_crossentropy(y_batch, preds)
            si_penalty = si_vit.penalty(vit_sub_model)
            total_loss = tf.reduce_mean(task_loss + lambda_si * si_penalty)
        
        grads = tape.gradient(total_loss, vit_sub_model.trainable_variables)
        optimizer.apply_gradients(zip(grads, vit_sub_model.trainable_variables))
        
        si_vit.accumulate_importance(vit_sub_model, grads)
        batch_macro_f1 = macro_f1(y_batch, preds)
        train_macro_f1.update_state(batch_macro_f1)
        train_loss.update_state(total_loss)
        
    si_vit.update_omega(vit_sub_model)
    epoch_time = time.time() - start_time
    val_macro_f1.reset_state()
    val_loss.reset_state()
    for step in range(len(X_val) // batch_size):
        X_batch = X_val[step*batch_size:(step+1)*batch_size]
        y_batch = y_val_sub[step*batch_size:(step+1)*batch_size]
        preds = vit_sub_model(X_batch, training=False)
        task_loss = tf.keras.losses.binary_crossentropy(y_batch, preds)
        total_loss = tf.reduce_mean(task_loss)
        
        batch_macro_f1 = macro_f1(y_batch, preds)
        val_macro_f1.update_state(batch_macro_f1)
        val_loss.update_state(total_loss)
        
    print(f'Epoch {epoch+1}/{epochs}, '
          f'Time: {epoch_time:.2f}s, '
          f'Loss: {train_loss.result():.4f}, '
          f'Macro F1: {train_macro_f1.result():.4f}, '
          f'Val Loss: {val_loss.result():.4f}, '
          f'Val Macro F1: {val_macro_f1.result():.4f}')



ViT Epoch 1/25


100%|██████████| 266/266 [00:41<00:00,  6.40it/s]


Epoch 1/25, Time: 41.57s, Loss: 0.1121, Macro F1: 0.2549, Val Loss: 0.1105, Val Macro F1: 0.2467

ViT Epoch 2/25


100%|██████████| 266/266 [00:36<00:00,  7.37it/s]


Epoch 2/25, Time: 36.12s, Loss: 0.0991, Macro F1: 0.3213, Val Loss: 0.1211, Val Macro F1: 0.2437

ViT Epoch 3/25


100%|██████████| 266/266 [00:36<00:00,  7.38it/s]


Epoch 3/25, Time: 36.06s, Loss: 0.0879, Macro F1: 0.3917, Val Loss: 0.1287, Val Macro F1: 0.2480

ViT Epoch 4/25


100%|██████████| 266/266 [00:36<00:00,  7.38it/s]


Epoch 4/25, Time: 36.05s, Loss: 0.0796, Macro F1: 0.4500, Val Loss: 0.1299, Val Macro F1: 0.2526

ViT Epoch 5/25


100%|██████████| 266/266 [00:36<00:00,  7.37it/s]


Epoch 5/25, Time: 36.08s, Loss: 0.0757, Macro F1: 0.4816, Val Loss: 0.1310, Val Macro F1: 0.2603

ViT Epoch 6/25


100%|██████████| 266/266 [00:36<00:00,  7.37it/s]


Epoch 6/25, Time: 36.08s, Loss: 0.3189, Macro F1: 0.5064, Val Loss: 0.1351, Val Macro F1: 0.2620

ViT Epoch 7/25


100%|██████████| 266/266 [00:36<00:00,  7.33it/s]


Epoch 7/25, Time: 36.30s, Loss: 2.7850, Macro F1: 0.5176, Val Loss: 0.1399, Val Macro F1: 0.2626

ViT Epoch 8/25


100%|██████████| 266/266 [00:36<00:00,  7.38it/s]


Epoch 8/25, Time: 36.06s, Loss: 29.1105, Macro F1: 0.5313, Val Loss: 0.1434, Val Macro F1: 0.2653

ViT Epoch 9/25


100%|██████████| 266/266 [00:36<00:00,  7.38it/s]


Epoch 9/25, Time: 36.07s, Loss: 298.0184, Macro F1: 0.5377, Val Loss: 0.1472, Val Macro F1: 0.2651

ViT Epoch 10/25


100%|██████████| 266/266 [00:36<00:00,  7.38it/s]


Epoch 10/25, Time: 36.03s, Loss: 3092.7156, Macro F1: 0.5481, Val Loss: 0.1508, Val Macro F1: 0.2640

ViT Epoch 11/25


100%|██████████| 266/266 [00:36<00:00,  7.35it/s]


Epoch 11/25, Time: 36.21s, Loss: 30043.7070, Macro F1: 0.5545, Val Loss: 0.1537, Val Macro F1: 0.2660

ViT Epoch 12/25


100%|██████████| 266/266 [00:35<00:00,  7.40it/s]


Epoch 12/25, Time: 35.96s, Loss: 292215.6562, Macro F1: 0.5624, Val Loss: 0.1585, Val Macro F1: 0.2599

ViT Epoch 13/25


100%|██████████| 266/266 [00:35<00:00,  7.43it/s]


Epoch 13/25, Time: 35.81s, Loss: 2735755.7500, Macro F1: 0.5686, Val Loss: 0.1602, Val Macro F1: 0.2661

ViT Epoch 14/25


100%|██████████| 266/266 [00:36<00:00,  7.25it/s]


Epoch 14/25, Time: 36.69s, Loss: 25809170.0000, Macro F1: 0.5717, Val Loss: 0.1640, Val Macro F1: 0.2597

ViT Epoch 15/25


100%|██████████| 266/266 [00:36<00:00,  7.38it/s]


Epoch 15/25, Time: 36.04s, Loss: 239026096.0000, Macro F1: 0.5749, Val Loss: 0.1673, Val Macro F1: 0.2577

ViT Epoch 16/25


100%|██████████| 266/266 [00:35<00:00,  7.39it/s]


Epoch 16/25, Time: 36.01s, Loss: 2280137472.0000, Macro F1: 0.5756, Val Loss: 0.1700, Val Macro F1: 0.2609

ViT Epoch 17/25


100%|██████████| 266/266 [00:36<00:00,  7.39it/s]


Epoch 17/25, Time: 36.03s, Loss: 21219457024.0000, Macro F1: 0.5808, Val Loss: 0.1722, Val Macro F1: 0.2588

ViT Epoch 18/25


100%|██████████| 266/266 [00:35<00:00,  7.42it/s]


Epoch 18/25, Time: 35.84s, Loss: 201637183488.0000, Macro F1: 0.5843, Val Loss: 0.1751, Val Macro F1: 0.2633

ViT Epoch 19/25


100%|██████████| 266/266 [00:36<00:00,  7.36it/s]


Epoch 19/25, Time: 36.17s, Loss: 1898964844544.0000, Macro F1: 0.5835, Val Loss: 0.1779, Val Macro F1: 0.2610

ViT Epoch 20/25


100%|██████████| 266/266 [00:36<00:00,  7.31it/s]


Epoch 20/25, Time: 36.39s, Loss: 17936200761344.0000, Macro F1: 0.5874, Val Loss: 0.1813, Val Macro F1: 0.2642

ViT Epoch 21/25


100%|██████████| 266/266 [00:36<00:00,  7.38it/s]


Epoch 21/25, Time: 36.07s, Loss: 169584434872320.0000, Macro F1: 0.5893, Val Loss: 0.1843, Val Macro F1: 0.2637

ViT Epoch 22/25


100%|██████████| 266/266 [00:36<00:00,  7.36it/s]


Epoch 22/25, Time: 36.13s, Loss: 1608641748664320.0000, Macro F1: 0.5896, Val Loss: 0.1863, Val Macro F1: 0.2621

ViT Epoch 23/25


100%|██████████| 266/266 [00:36<00:00,  7.31it/s]


Epoch 23/25, Time: 36.41s, Loss: 15248848666492928.0000, Macro F1: 0.5923, Val Loss: 0.1874, Val Macro F1: 0.2583

ViT Epoch 24/25


100%|██████████| 266/266 [00:35<00:00,  7.39it/s]


Epoch 24/25, Time: 35.99s, Loss: 144423755706269696.0000, Macro F1: 0.5932, Val Loss: 0.1902, Val Macro F1: 0.2586

ViT Epoch 25/25


100%|██████████| 266/266 [00:36<00:00,  7.38it/s]


Epoch 25/25, Time: 36.06s, Loss: nan, Macro F1: 0.0180, Val Loss: nan, Val Macro F1: 0.0000


# Compiling and Publishing Results

In [43]:
print("CNN EWC Subdiagnostic Classification Report:")
cnn_ewc_sub_report = evaluate_model(cnn_sub_model, X_test, y_test_sub, classes_sub)

print("CNN EWC Superdiagnostic Classification Report:")
cnn_ewc_super_report = evaluate_model(cnn_super_model, X_test, y_test_super, classes_super)

print("ResNet EWC Subdiagnostic Classification Report:")
resnet_ewc_sub_report = evaluate_model(resnet_sub_model, X_test, y_test_sub, classes_sub)

print("ResNet EWC Superdiagnostic Classification Report:")
resnet_ewc_super_report = evaluate_model(resnet_super_model, X_test, y_test_super, classes_super)

print("ViT EWC Subdiagnostic Classification Report:")
vit_ewc_sub_report = evaluate_model(vit_sub_model, X_test, y_test_sub, classes_sub)

print("ViT EWC Superdiagnostic Classification Report:")
vit_ewc_super_report = evaluate_model(vit_super_model, X_test, y_test_super, classes_super)


CNN EWC Subdiagnostic Classification Report:
              precision    recall  f1-score   support

         AMI       0.79      0.76      0.78       306
       CLBBB       0.94      0.87      0.90        54
       CRBBB       0.80      0.89      0.84        54
       ILBBB       0.00      0.00      0.00         8
         IMI       0.64      0.74      0.68       327
       IRBBB       0.73      0.55      0.63       112
        ISCA       0.44      0.26      0.32        93
        ISCI       0.42      0.28      0.33        40
        ISC_       0.70      0.54      0.61       128
        IVCD       0.25      0.06      0.10        79
   LAFB/LPFB       0.81      0.67      0.73       179
     LAO/LAE       0.22      0.05      0.08        42
         LMI       0.33      0.05      0.09        20
         LVH       0.77      0.59      0.67       214
        NORM       0.85      0.87      0.86       963
        NST_       0.25      0.04      0.07        77
         PMI       0.00      0.00   

In [44]:
print("CNN SI Subdiagnostic Classification Report:")
cnn_si_sub_report = evaluate_model(cnn_sub_model, X_test, y_test_sub, classes_sub)

print("CNN SI Superdiagnostic Classification Report:")
cnn_si_super_report = evaluate_model(cnn_super_model, X_test, y_test_super, classes_super)

print("ResNet SI Subdiagnostic Classification Report:")
resnet_si_sub_report = evaluate_model(resnet_sub_model, X_test, y_test_sub, classes_sub)

print("ResNet SI Superdiagnostic Classification Report:")
resnet_si_super_report = evaluate_model(resnet_super_model, X_test, y_test_super, classes_super)

print("ViT SI Subdiagnostic Classification Report:")
vit_si_sub_report = evaluate_model(vit_sub_model, X_test, y_test_sub, classes_sub)

print("ViT SI Superdiagnostic Classification Report:")
vit_si_super_report = evaluate_model(vit_super_model, X_test, y_test_super, classes_super)


CNN SI Subdiagnostic Classification Report:
              precision    recall  f1-score   support

         AMI       0.79      0.76      0.78       306
       CLBBB       0.94      0.87      0.90        54
       CRBBB       0.80      0.89      0.84        54
       ILBBB       0.00      0.00      0.00         8
         IMI       0.64      0.74      0.68       327
       IRBBB       0.73      0.55      0.63       112
        ISCA       0.44      0.26      0.32        93
        ISCI       0.42      0.28      0.33        40
        ISC_       0.70      0.54      0.61       128
        IVCD       0.25      0.06      0.10        79
   LAFB/LPFB       0.81      0.67      0.73       179
     LAO/LAE       0.22      0.05      0.08        42
         LMI       0.33      0.05      0.09        20
         LVH       0.77      0.59      0.67       214
        NORM       0.85      0.87      0.86       963
        NST_       0.25      0.04      0.07        77
         PMI       0.00      0.00    

In [45]:
def get_macro_f1(report_dict):
    return report_dict['macro avg']['f1-score']

results = {
    'Model': [],
    'Task': [],
    'Macro F1-score': []
}

results['Model'].extend(['CNN', 'ResNet', 'ViT'])
results['Task'].extend(['Superdiagnostic'] * 3)
results['Macro F1-score'].extend([
    get_macro_f1(cnn_super_report),
    get_macro_f1(resnet_super_report),
    get_macro_f1(vit_super_report)
])

results['Model'].extend(['CNN', 'ResNet', 'ViT'])
results['Task'].extend(['Subdiagnostic'] * 3)
results['Macro F1-score'].extend([
    get_macro_f1(cnn_sub_report),
    get_macro_f1(resnet_sub_report),
    get_macro_f1(vit_sub_report)
])

results['Model'].extend(['CNN', 'ResNet', 'ViT'])
results['Task'].extend(['EWC Subdiagnostic'] * 3)
results['Macro F1-score'].extend([
    get_macro_f1(cnn_ewc_sub_report),
    get_macro_f1(resnet_ewc_sub_report),
    get_macro_f1(vit_ewc_sub_report)
])

results['Model'].extend(['CNN', 'ResNet', 'ViT'])
results['Task'].extend(['EWC Superdiagnostic'] * 3)
results['Macro F1-score'].extend([
    get_macro_f1(cnn_ewc_super_report),
    get_macro_f1(resnet_ewc_super_report),
    get_macro_f1(vit_ewc_super_report)
])

results['Model'].extend(['CNN', 'ResNet', 'ViT'])
results['Task'].extend(['SI Subdiagnostic'] * 3)
results['Macro F1-score'].extend([
    get_macro_f1(cnn_si_sub_report),
    get_macro_f1(resnet_si_sub_report),
    get_macro_f1(vit_si_sub_report)
])

results['Model'].extend(['CNN', 'ResNet', 'ViT'])
results['Task'].extend(['SI Superdiagnostic'] * 3)
results['Macro F1-score'].extend([
    get_macro_f1(cnn_si_super_report),
    get_macro_f1(resnet_si_super_report),
    get_macro_f1(vit_si_super_report)
])

results_df = pd.DataFrame(results)
print("\nSummary of Classification Performance:")
print(results_df)



Summary of Classification Performance:
     Model                 Task  Macro F1-score
0      CNN      Superdiagnostic        0.748892
1   ResNet      Superdiagnostic        0.732118
2      ViT      Superdiagnostic        0.651557
3      CNN        Subdiagnostic        0.433339
4   ResNet        Subdiagnostic        0.410882
5      ViT        Subdiagnostic        0.248612
6      CNN    EWC Subdiagnostic        0.451675
7   ResNet    EWC Subdiagnostic        0.000000
8      ViT    EWC Subdiagnostic        0.000000
9      CNN  EWC Superdiagnostic        0.691851
10  ResNet  EWC Superdiagnostic        0.000000
11     ViT  EWC Superdiagnostic        0.000000
12     CNN     SI Subdiagnostic        0.451675
13  ResNet     SI Subdiagnostic        0.000000
14     ViT     SI Subdiagnostic        0.000000
15     CNN   SI Superdiagnostic        0.691851
16  ResNet   SI Superdiagnostic        0.000000
17     ViT   SI Superdiagnostic        0.000000


# GPT Driven code for Federated Learning

In [46]:
# ----------------------------------------
# Federated Learning with Continual Learning (EWC & SI)
# ----------------------------------------

print("Starting Federated Learning with Continual Learning (EWC & SI)")

# Number of clients to simulate
num_clients = 5

# Communication rounds
communication_rounds = 5

# Local training epochs per client
local_epochs = 1

# Batch size for local training
local_batch_size = 64

# Lambda values for EWC and SI
lambda_ewc = 1000
lambda_si = 1.0

# Function to split data among clients
def split_data(X, y_super, y_sub, num_clients):
    client_data = []
    data_per_client = len(X) // num_clients
    for i in range(num_clients):
        start_idx = i * data_per_client
        end_idx = (i + 1) * data_per_client if i < num_clients - 1 else len(X)
        X_client = X[start_idx:end_idx]
        y_client_super = y_super[start_idx:end_idx]
        y_client_sub = y_sub[start_idx:end_idx]
        client_data.append((X_client, y_client_super, y_client_sub))
    return client_data

# Split the training data among clients
client_data = split_data(X_train, y_train_super, y_train_sub, num_clients)

# Function to clone a model and set weights
def clone_model_weights(model):
    cloned_model = tf.keras.models.clone_model(model)
    cloned_model.set_weights(model.get_weights())
    return cloned_model

# Function to average client weights
def average_weights(client_weights):
    avg_weights = []
    for weights in zip(*client_weights):
        avg = np.mean(weights, axis=0)
        avg_weights.append(avg)
    return avg_weights

# Define Federated Learning with EWC and SI
def federated_training(model_type, create_model_fn, input_shape, num_classes_super, num_classes_sub, classes_super, classes_sub):
    print(f"--- Federated Training for {model_type} ---")
    
    # Initialize the global model
    global_model_super = create_model_fn(input_shape, num_classes_super)
    global_weights_super = global_model_super.get_weights()
    
    # Initialize the global models for subdiagnostic tasks
    global_model_sub = modify_model_for_subdiagnostic(global_model_super, num_classes_sub)
    global_weights_sub = global_model_sub.get_weights()
    
    for round_num in range(communication_rounds):
        print(f"Communication Round {round_num+1}/{communication_rounds}")
        client_weights_super = []
        client_weights_sub = []
        
        for client_idx, (X_client, y_client_super, y_client_sub) in enumerate(client_data):
            print(f" - Client {client_idx+1}/{num_clients} local training")
            
            # Clone global models
            client_model_super = clone_model_weights(global_model_super)
            client_model_sub = clone_model_weights(global_model_sub)
            
            # --------------------
            # Local Training on Superdiagnostic Task
            # --------------------
            client_model_super.compile(
                optimizer='adam',
                loss='binary_crossentropy',
                metrics=[macro_f1]
            )
            client_model_super.fit(
                X_client, y_client_super,
                epochs=local_epochs,
                batch_size=local_batch_size,
                verbose=0
            )
            
            # Initialize EWC for the client
            exclude_params_super = [id(w) for w in client_model_super.layers[-1].trainable_weights]
            ewc_client = EWC(client_model_super, X_client, y_client_super, exclude_params=exclude_params_super)
            
            # --------------------
            # Local Training on Subdiagnostic Task with EWC
            # --------------------
            client_model_sub = modify_model_for_subdiagnostic(client_model_super, num_classes_sub)
            
            # Define EWC loss for subdiagnostic task
            def ewc_loss_sub(y_true, y_pred):
                task_loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
                ewc_penalty = ewc_client.penalty(client_model_sub)
                total_loss = task_loss + (lambda_ewc / 2) * ewc_penalty
                return total_loss
            
            client_model_sub.compile(
                optimizer='adam',
                loss=ewc_loss_sub,
                metrics=[macro_f1]
            )
            client_model_sub.fit(
                X_client, y_client_sub,
                epochs=local_epochs,
                batch_size=local_batch_size,
                verbose=0
            )
            
            # Initialize SI for the client
            exclude_params_sub = [id(w) for w in client_model_sub.layers[-1].trainable_weights]
            si_client = SI(client_model_sub, damping_factor=0.1, exclude_params=exclude_params_sub)
            
            # --------------------
            # Local Training on Subdiagnostic Task with SI
            # --------------------
            # Define SI loss for subdiagnostic task
            def si_loss_sub(y_true, y_pred):
                task_loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
                si_penalty = si_client.penalty(client_model_sub)
                total_loss = task_loss + (lambda_si / 2) * si_penalty
                return total_loss
            
            client_model_sub.compile(
                optimizer='adam',
                loss=si_loss_sub,
                metrics=[macro_f1]
            )
            client_model_sub.fit(
                X_client, y_client_sub,
                epochs=local_epochs,
                batch_size=local_batch_size,
                verbose=0
            )
            
            # Update EWC and SI
            # For EWC, we've already computed the penalty during training
            # For SI, update omega after training
            si_client.update_omega(client_model_sub)
            
            # Collect client weights
            client_weights_super.append(client_model_super.get_weights())
            client_weights_sub.append(client_model_sub.get_weights())
        
        # Aggregate client weights to update global models
        global_weights_super = average_weights(client_weights_super)
        global_weights_sub = average_weights(client_weights_sub)
        
        # Set the aggregated weights to global models
        global_model_super.set_weights(global_weights_super)
        global_model_sub.set_weights(global_weights_sub)
    
    # After federated training, evaluate the global models
    print(f"--- Evaluation for {model_type} after Federated Learning ---")
    
    # Evaluate on Superdiagnostic Task
    print(f"Evaluating Global {model_type} Model on Superdiagnostic Task:")
    super_report = evaluate_model(global_model_super, X_test, y_test_super, classes_super)
    
    # Evaluate on Subdiagnostic Task
    print(f"Evaluating Global {model_type} Model on Subdiagnostic Task:")
    global_model_sub = modify_model_for_subdiagnostic(global_model_super, num_classes_sub)
    sub_report = evaluate_model(global_model_sub, X_test, y_test_sub, classes_sub)
    
    return super_report, sub_report

# ----------------------------------------
# Start Federated Training for CNN, ResNet, ViT
# ----------------------------------------

# Initialize a dictionary to store federated reports
federated_reports = {
    'Model': [],
    'Task': [],
    'Macro F1-score': []
}

# List of models to federate
models_to_federate = [
    ('CNN', create_cnn_model),
    ('ResNet', create_resnet_model),
    ('ViT', create_vit_model)
]

for model_name, create_model_fn in models_to_federate:
    # Determine number of classes based on model type
    if model_name in ['CNN', 'ResNet', 'ViT']:
        num_classes_sub = y_train_sub.shape[1]
    
    # Federated training
    super_report, sub_report = federated_training(
        model_type=model_name,
        create_model_fn=create_model_fn,
        input_shape=input_shape,
        num_classes_super=num_classes_super,
        num_classes_sub=num_classes_sub,
        classes_super=classes_super,
        classes_sub=classes_sub
    )
    
    # Store the reports
    federated_reports['Model'].extend([model_name, model_name])
    federated_reports['Task'].extend(['Federated Superdiagnostic', 'Federated Subdiagnostic'])
    federated_reports['Macro F1-score'].extend([
        get_macro_f1(super_report),
        get_macro_f1(sub_report)
    ])

# ----------------------------------------
# Update and Publish the Summary Table with Federated Results
# ----------------------------------------

# Convert federated reports to DataFrame
federated_df = pd.DataFrame(federated_reports)

# Append federated results to the existing summary table
results = results.append(federated_df, ignore_index=True)

print("Updated Summary of Classification Performance with Federated Learning:")
print(results)

Starting Federated Learning with Continual Learning (EWC & SI)
--- Federated Training for CNN ---
Communication Round 1/5
 - Client 1/5 local training
 - Client 2/5 local training
 - Client 3/5 local training
 - Client 4/5 local training
 - Client 5/5 local training
Communication Round 2/5
 - Client 1/5 local training
 - Client 2/5 local training
 - Client 3/5 local training
 - Client 4/5 local training
 - Client 5/5 local training
Communication Round 3/5
 - Client 1/5 local training
 - Client 2/5 local training
 - Client 3/5 local training
 - Client 4/5 local training
 - Client 5/5 local training
Communication Round 4/5
 - Client 1/5 local training
 - Client 2/5 local training
 - Client 3/5 local training
 - Client 4/5 local training
 - Client 5/5 local training
Communication Round 5/5
 - Client 1/5 local training
 - Client 2/5 local training
 - Client 3/5 local training
 - Client 4/5 local training
 - Client 5/5 local training
--- Evaluation for CNN after Federated Learning ---
Evalu

AttributeError: 'dict' object has no attribute 'append'