In [1]:
import pandas as pd
import numpy as np
import os
import tensorflow as tf
from sklearn.utils import shuffle
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from sklearn.preprocessing import MultiLabelBinarizer
from PIL import Image
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from sklearn.metrics import f1_score, precision_score, recall_score
tf.keras.backend.clear_session()
from tensorflow.keras.metrics import Precision, Recall


In [2]:
text_data_path = 'modified_text.csv'
image_metadata_path = 'indiana_projections.csv'
image_folder_path = 'images/images_normalized'

text_data = pd.read_csv(text_data_path)

image_metadata = pd.read_csv(image_metadata_path)
image_metadata['filename'] = image_metadata['filename'].apply(lambda x: os.path.join(image_folder_path, x))

merged_data = pd.merge(text_data, image_metadata[['uid', 'filename']], on='uid')

tokenizer = Tokenizer(num_words=10000, oov_token="<OOV>")
tokenizer.fit_on_texts(merged_data['notes'])
sequences = tokenizer.texts_to_sequences(merged_data['notes'])
padded_sequences = pad_sequences(sequences, maxlen=100, padding='post', truncating='post')



In [3]:
# Image data loading and preprocessing
def load_and_preprocess_image(image_path, target_size=(224, 224)):
    image = Image.open(image_path).convert('RGB')
    image = image.resize(target_size)
    image_array = np.array(image)
    from tensorflow.keras.applications.resnet50 import preprocess_input
    image_array = preprocess_input(image_array)
    return image_array

images = np.array([load_and_preprocess_image(img_path) for img_path in merged_data['filename']])

images_file_path = 'images.npy'
np.save(images_file_path, images)

In [4]:
images_file_path = 'images.npy'
images = np.load(images_file_path)
images.shape

(7466, 224, 224, 3)

In [5]:
import tensorflow as tf

class CustomAccuracy(tf.keras.metrics.Metric):
    def __init__(self, name='custom_accuracy', **kwargs):
        super(CustomAccuracy, self).__init__(name=name, **kwargs)
        self.total_correct = self.add_weight(name='total_correct', initializer='zeros', dtype=tf.float32)
        self.total_labels = self.add_weight(name='total_labels', initializer='zeros', dtype=tf.float32)

    def update_state(self, y_true, y_pred, sample_weight=None):

        # The prediction results are converted to binarization
        y_pred_binary = tf.cast(y_pred > 0.5, tf.float32)

        # Condition 1: Label 0 is predicted to be positive
        condition1 = y_pred_binary[:, 0] > 0.5  # shape: (batch_size,)

        # Condition 2: Any one of the other labels is predicted to be positive
        condition2 = tf.reduce_any(y_pred_binary[:, 1:] > 0.5, axis=1)  # shape: (batch_size,)

        # Step 1: If condition 1 is true, set the other labels to 0
        condition1_expanded = tf.expand_dims(condition1, axis=1)  # shape: (batch_size, 1)
        y_pred_modified = tf.where(
            condition1_expanded,
            tf.concat([y_pred_binary[:, 0:1], tf.zeros_like(y_pred_binary[:, 1:])], axis=1),
            y_pred_binary
        )

        # Step 2: Set label 0 to 0 if condition 1 is false and Condition 2 is true
        condition2_only = tf.logical_and(tf.logical_not(condition1), condition2)  # shape: (batch_size,)
        condition2_only_expanded = tf.expand_dims(condition2_only, axis=1)  # shape: (batch_size, 1)
        y_pred_modified = tf.where(
            condition2_only_expanded,
            tf.concat([tf.zeros_like(y_pred_binary[:, 0:1]), y_pred_modified[:, 1:]], axis=1),
            y_pred_modified
        )

        # The correct prediction is computed for each label
        correct_predictions = tf.cast(tf.equal(y_pred_modified, y_true), tf.float32)

        # Calculate the total number of correct predictions for this batch
        correct_sum = tf.reduce_sum(correct_predictions)

        # Calculate the total number of tags in this batch
        labels_sum = tf.cast(tf.size(y_true), tf.float32)

        # Update the cumulative number of correct predictions and the total number of labels
        self.total_correct.assign_add(correct_sum)
        self.total_labels.assign_add(labels_sum)

    def result(self):
        """
        Returns the accuracy, making sure it's between 0 and 1.
        """
        return self.total_correct / self.total_labels

    def reset_state(self):
        """
        Reset the accumulated number of correct predictions and the total number of labels.
        """
        self.total_correct.assign(0.0)
        self.total_labels.assign(0.0)


In [6]:

merged_data['labels'] = merged_data['labels'].fillna('')

def process_labels(label):
    if isinstance(label, list):
        return label
    if isinstance(label, str) and label != '':
        try:
            return list(map(int, label.split(';')))
        except ValueError:
            return []
    return []

labels = merged_data['labels'].apply(process_labels)
mlb = MultiLabelBinarizer(classes=list(range(14)))  
labels = mlb.fit_transform(labels)


assert len(padded_sequences) == len(images) == len(labels), "Data lengths are inconsistent!"
padded_sequences = np.array(padded_sequences)
images = np.array(images)
labels = np.array(labels)

# Shuffle the data
padded_sequences, images, labels = shuffle(padded_sequences, images, labels, random_state=42)

# Split the data into multiple clients
num_clients = 3  
client_data_size = len(padded_sequences) // num_clients

client_datasets = []
for i in range(num_clients):
    start_index = i * client_data_size
    if i == num_clients - 1:
        end_index = len(padded_sequences)
    else:
        end_index = (i + 1) * client_data_size
    
    client_padded_sequences = padded_sequences[start_index:end_index]
    client_images = images[start_index:end_index]
    client_labels = labels[start_index:end_index]
    
    client_datasets.append((client_padded_sequences, client_images, client_labels))

# Defining the model architecture
def create_model():

    text_input = tf.keras.Input(shape=(100,), name='text_input')
    embedding_layer = tf.keras.layers.Embedding(input_dim=10000, output_dim=128)(text_input)
    lstm_layer = tf.keras.layers.LSTM(128)(embedding_layer)
    text_dense = tf.keras.layers.Dense(128, activation='relu')(lstm_layer)
    text_dropout = tf.keras.layers.Dropout(0.5)(text_dense)
    text_output = tf.keras.layers.Dense(64, activation='relu')(text_dropout)

    image_input = tf.keras.Input(shape=(224, 224, 3), name='image_input')
    base_model = tf.keras.applications.ResNet50(weights='imagenet', include_top=False, input_tensor=image_input)
    base_model.trainable = False  
    image_pooling = tf.keras.layers.GlobalAveragePooling2D()(base_model.output)
    image_dense = tf.keras.layers.Dense(128, activation='relu')(image_pooling)
    image_dropout = tf.keras.layers.Dropout(0.5)(image_dense)
    image_output = tf.keras.layers.Dense(64, activation='relu')(image_dropout)
    
    combined = tf.keras.layers.Concatenate()([text_output, image_output])
    combined_dense = tf.keras.layers.Dense(128, activation='relu')(combined)
    combined_dropout = tf.keras.layers.Dropout(0.5)(combined_dense)
    final_output = tf.keras.layers.Dense(14, activation='sigmoid')(combined_dropout)
    
    model = tf.keras.Model(inputs=[text_input, image_input], outputs=final_output)
    
    return model

# FedProx
def fedprox_loss(global_weights, mu):
    def loss_fn(y_true, y_pred):
        bce = tf.keras.losses.binary_crossentropy(y_true, y_pred)
        prox_term = 0.0
        for lw, gw in zip(model.trainable_weights, global_weights):
            prox_term += tf.nn.l2_loss(lw - gw)
        return bce + (mu / 2) * prox_term
    return loss_fn

# Initialize the global model
global_model = create_model()

num_rounds = 20  #
mu = 0.01      



In [7]:
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
global_model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=[CustomAccuracy(), Precision(), Recall()])

global_metrics = []

for round_num in range(1, num_rounds + 1):
    print(f"\n===== Starting round {round_num} of federated learning =====")
    
    print("\n--- Training with FedProx aggregation ---")
    client_weights = []
    client_metrics = []
    
    # Store validation data for global evaluation
    val_text_list = []
    val_images_list = []
    val_labels_list = []

    for client_num, (client_padded_sequences, client_images, client_labels) in enumerate(client_datasets, 1):
        print(f"\n--- Training on client {client_num} ---")

        local_model = create_model()
        local_model.set_weights(global_model.get_weights())

        # Obtain global weights tensors for Proximal term
        global_weights = global_model.get_weights()
        global_weights_tensors = [tf.constant(w) for w in global_weights]

        # Separate weights for text and image layers
        text_weights_local = local_model.trainable_weights[:len(global_model.layers[1].trainable_weights)]
        image_weights_local = local_model.trainable_weights[len(global_model.layers[1].trainable_weights):]
        text_weights_global = global_weights_tensors[:len(global_model.layers[1].trainable_weights)]
        image_weights_global = global_weights_tensors[len(global_model.layers[1].trainable_weights):]

        # Proximal term for text layers
        for lw, gw in zip(text_weights_local, text_weights_global):
            if lw.shape != gw.shape:
                raise ValueError(f"Shape mismatch in text layers: local weight shape {lw.shape} and global weight shape {gw.shape}")
            prox = tf.nn.l2_loss(lw - gw)
            local_model.add_loss(lambda: (mu / 2) * prox)

        # Proximal term for image layers
        for lw, gw in zip(image_weights_local, image_weights_global):
            if lw.shape != gw.shape:
                continue  # Skip layers with mismatched shapes
            prox = tf.nn.l2_loss(lw - gw)
            local_model.add_loss(lambda: (mu / 2) * prox)

        # Compile the local model with standard loss (Proximal term added)
        local_model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=[CustomAccuracy()])

        # Split client data into training and validation sets
        client_X_train_text, client_X_val_text, client_X_train_images, client_X_val_images, client_y_train, client_y_val = train_test_split(
            client_padded_sequences, client_images, client_labels, test_size=0.2, random_state=client_num + 100
        )

        # Ensure data types are correct
        client_X_train_text = client_X_train_text.astype(np.int32)
        client_X_val_text = client_X_val_text.astype(np.int32)
        client_X_train_images = client_X_train_images.astype(np.float32)
        client_X_val_images = client_X_val_images.astype(np.float32)
        client_y_train = client_y_train.astype(np.float32)
        client_y_val = client_y_val.astype(np.float32)

        # Callback function
        client_model_filepath = f"client_{client_num}_round_{round_num}_fedprox_model.h5"
        model_checkpoint = ModelCheckpoint(client_model_filepath, save_best_only=True, monitor='val_loss')

        # Train the local model
        history = local_model.fit(
            {'text_input': client_X_train_text, 'image_input': client_X_train_images},
            client_y_train,
            epochs=10,  
            batch_size=8,
            validation_data=({'text_input': client_X_val_text, 'image_input': client_X_val_images}, client_y_val),
            callbacks=[model_checkpoint],
            verbose=1
        )

        # Save local model weights
        client_weights.append(local_model.get_weights())

        # Store validation data
        val_text_list.append(client_X_val_text)
        val_images_list.append(client_X_val_images)
        val_labels_list.append(client_y_val)

        client_evaluation = local_model.evaluate(
            {'text_input': client_X_val_text, 'image_input': client_X_val_images},
            client_y_val,
            verbose=0
        )
        client_metrics.append(client_evaluation)
        # Get the prediction results of the validation set
        val_predictions = local_model.predict({'text_input': client_X_val_text, 'image_input': client_X_val_images})
        val_predictions_binary = np.where(val_predictions > 0.5, 1, 0)

        # Calculate and print F1, Precision, Recall
        f1 = f1_score(client_y_val, val_predictions_binary, average='weighted', zero_division=1)
        recall = recall_score(client_y_val, val_predictions_binary, average='weighted', zero_division=1)
        precision = precision_score(client_y_val, val_predictions_binary, average='weighted', zero_division=1)

        print(f"Client {client_num} - Last epoch validation: Loss: {client_evaluation[0]}, Accuracy: {client_evaluation[1]}, F1: {f1}, Precision: {precision}, Recall: {recall}")
        print(f"--- Training on client {client_num} completed ---")

    # Aggregate client weights
    print("\n*** Aggregating client model weights (FedProx) ***")
    new_weights = []
    for weights_list in zip(*client_weights):
        new_weights.append(np.mean(weights_list, axis=0))

    # Update global model weights
    global_model.set_weights(new_weights)
    print("*** Global model weights updated (FedProx) ***")

    # Combine validation data
    combined_val_text = np.concatenate(val_text_list)
    combined_val_images = np.concatenate(val_images_list)
    combined_val_labels = np.concatenate(val_labels_list)

    combined_val_text = combined_val_text.astype(np.int32)
    combined_val_images = combined_val_images.astype(np.float32)
    combined_val_labels = combined_val_labels.astype(np.float32)

    # Evaluate global model on combined validation data
    global_evaluation = global_model.evaluate(
        {'text_input': combined_val_text, 'image_input': combined_val_images},
        combined_val_labels,
        verbose=0
    )
    global_metrics.append(global_evaluation)
    print(f"\n===== Post-round {round_num} global model evaluation (FedProx): {global_model.metrics_names} = {global_evaluation} =====")




===== Starting round 1 of federated learning =====

--- Training with FedProx aggregation ---

--- Training on client 1 ---
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Client 1 - Last epoch validation: Loss: 1.563597321510315, Accuracy: 0.8897016644477844, F1: 0.09886045208294415, Precision: 0.9472664657770126, Recall: 0.06683168316831682
--- Training on client 1 completed ---

--- Training on client 2 ---
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Client 2 - Last epoch validation: Loss: 1.5269076824188232, Accuracy: 0.8942914605140686, F1: 0.2872064143279289, Precision: 0.8238320449957663, Recall: 0.2751196172248804
--- Training on client 2 completed ---

--- Training on client 3 ---
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Client 3 - Last epoch validation: Loss: 1.4764574766159058,