# In [None]:
import os
import shutil
import tensorflow as tf
from tensorflow.keras.applications import VGG16
from tensorflow.keras.layers import Flatten, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import flwr as fl
from flwr.common import Context
from collections import OrderedDict

# Class folder names used when reorganizing the dataset directories.
classes = [
    "glioma",
    "meningioma",
    "pituitary",
    "notumor",
]

# Root folders of the training and testing splits.
base_train_dir, base_test_dir = "/content/Training", "/content/Testing"

# Organize the directory structure for training and testing data
def correct_directory_structure(base_dir, class_names=None):
    """Flatten a duplicated ``<base_dir>/<cls>/<cls>/`` nesting in place.

    For every class whose folder contains an inner folder with the same
    name, move the inner folder's entries up into ``<base_dir>/<cls>`` and
    delete the then-empty inner folder. Classes without such nesting (or
    without a class folder at all) are left untouched.

    Args:
        base_dir: Root directory expected to contain one sub-folder per class.
        class_names: Class folder names to process. ``None`` (the default)
            uses the module-level ``classes`` list.

    Raises:
        shutil.Error: Presumably raised by ``shutil.move`` if an entry with
            the same name already exists in the class folder; entries moved
            before the failure stay moved.
    """
    if class_names is None:
        class_names = classes
    for class_name in class_names:
        class_dir = os.path.join(base_dir, class_name)
        nested_dir = os.path.join(class_dir, class_name)
        # isdir rather than exists: a stray *file* carrying the class name
        # must not reach os.listdir, which would raise NotADirectoryError.
        if not os.path.isdir(nested_dir):
            continue
        for entry in os.listdir(nested_dir):
            shutil.move(os.path.join(nested_dir, entry), class_dir)
        shutil.rmtree(nested_dir)

# Flatten any duplicated class sub-folders in both dataset splits.
for split_dir in (base_train_dir, base_test_dir):
    correct_directory_structure(split_dir)

# Dataset roots handed to the data generators. Each entry is loaded by its
# own flow_from_directory call, so it must be the *parent* of the class
# folders (not a class folder itself). client_fn picks entries by partition id.
train_dirs = ["/content/Training"]
test_dirs = ["/content/Testing"]

# Data generators for training and testing
def get_data_generators(train_dirs, test_dirs, target_size=(224, 224), batch_size=32,
                        class_names=None):
    """Build one image iterator per training directory and per testing directory.

    Each directory is read with ``ImageDataGenerator.flow_from_directory``
    using ``'categorical'`` labels and pixels rescaled by 1/255. Training
    iterators shuffle; testing iterators keep file order.

    Args:
        train_dirs: Iterable of training root directories, each holding one
            sub-folder per class.
        test_dirs: Iterable of testing root directories with the same layout.
        target_size: ``(height, width)`` images are resized to.
        batch_size: Number of images per batch.
        class_names: Optional explicit list of class sub-folder names,
            forwarded as ``classes=``. Pinning it stops stray sub-folders from
            being counted as extra classes and fixes the label order. ``None``
            (the default) lets Keras infer the classes from the sub-folders.

    Returns:
        ``(train_generators, test_generators)``: two lists with one iterator
        per input directory, in input order.
    """
    train_datagen = ImageDataGenerator(rescale=1.0 / 255.0)
    test_datagen = ImageDataGenerator(rescale=1.0 / 255.0)

    train_generators = _flow_from_dirs(
        train_datagen, train_dirs, "training", shuffle=True,
        target_size=target_size, batch_size=batch_size, class_names=class_names,
    )
    test_generators = _flow_from_dirs(
        test_datagen, test_dirs, "testing", shuffle=False,
        target_size=target_size, batch_size=batch_size, class_names=class_names,
    )
    return train_generators, test_generators


def _flow_from_dirs(datagen, dirs, split_name, *, shuffle, target_size, batch_size,
                    class_names):
    """Return one ``flow_from_directory`` iterator per directory in *dirs*."""
    generators = []
    for directory in dirs:
        print(f"Loading {split_name} data from: {directory}")
        generators.append(
            datagen.flow_from_directory(
                directory,
                target_size=target_size,
                batch_size=batch_size,
                class_mode='categorical',
                classes=class_names,
                shuffle=shuffle,
            )
        )
    return generators

# Build the training/testing iterators once at module level.
train_generators, test_generators = get_data_generators(train_dirs=train_dirs, test_dirs=test_dirs)

# In [None]:
# Sanity-check the dataset layout: list each training root, then let Keras
# scan the parent directory for class folders.
for train_dir in train_dirs:
    print(f"Checking contents of {train_dir}:")
    print(os.listdir(train_dir))

# The generator inside get_data_generators() is local to that function and
# not visible here, so build an equivalent one for this check.
train_datagen = ImageDataGenerator(rescale=1.0 / 255.0)
train_gen = train_datagen.flow_from_directory(
    base_train_dir,  # Parent directory containing the class folders
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical'
)
for train_dir in train_dirs:
    print(os.listdir(train_dir))

# In [None]:


# VGG16 model setup
def build_model(num_classes=4, input_shape=(224, 224, 3)):
    """Build and compile a VGG16 transfer-learning classifier.

    The ImageNet-pretrained VGG16 convolutional base is frozen; only the new
    head (Flatten -> Dense(512, relu) -> Dense(num_classes, softmax)) is
    trainable. The model is compiled with Adam, categorical cross-entropy
    and an accuracy metric.

    Args:
        num_classes: Width of the softmax output; must equal the number of
            classes the data generators yield (default 4, matching ``classes``).
        input_shape: ``(height, width, channels)`` of the input images; should
            agree with the generators' ``target_size`` and colour mode.

    Returns:
        A compiled Keras ``Model``.
    """
    vgg_base = VGG16(weights='imagenet', include_top=False, input_shape=input_shape)
    # Freeze the pretrained features so only the classification head trains.
    vgg_base.trainable = False

    features = Flatten()(vgg_base.output)
    hidden = Dense(512, activation='relu')(features)
    outputs = Dense(num_classes, activation='softmax')(hidden)

    model = Model(inputs=vgg_base.input, outputs=outputs)
    model.compile(optimizer=Adam(), loss='categorical_crossentropy', metrics=['accuracy'])
    return model

# Flower Client for federated learning
class FlowerClient(fl.client.NumPyClient):
    """Flower NumPy client that trains and evaluates a Keras model locally.

    Args:
        model: Compiled Keras model.
        train_generator: Image iterator used by ``fit``.
        test_generator: Image iterator used by ``evaluate``.
    """

    def __init__(self, model, train_generator, test_generator):
        self.model = model
        self.train_generator = train_generator
        self.test_generator = test_generator

    def get_parameters(self, config):
        """Return the current model weights as a list of NumPy arrays.

        Flower calls this to obtain the initial global parameters when the
        strategy is created with ``initial_parameters=None``; the base
        ``NumPyClient`` implementation cannot know about ``self.model``.
        """
        return self.get_weights()

    def get_weights(self):
        """Return the model weights (kept for existing callers)."""
        return self.model.get_weights()

    def set_weights(self, weights):
        """Load *weights* into the model if the number of arrays matches.

        On ``None`` or a count mismatch the model is left unchanged and a
        warning is printed instead of raising.
        """
        if weights is not None and len(weights) == len(self.model.get_weights()):
            self.model.set_weights(weights)
        else:
            print("Warning: Mismatch in the number of model weights.")

    @staticmethod
    def _num_examples(generator):
        """Return the number of images behind *generator*.

        Keras directory iterators expose ``.samples``; ``len()`` is the number
        of *batches*, which would mis-weight FedAvg's per-client averaging.
        Falls back to ``len()`` for iterators without ``.samples``.
        """
        samples = getattr(generator, "samples", None)
        return samples if samples is not None else len(generator)

    def fit(self, parameters, config):
        """Train one local epoch starting from the server's *parameters*.

        Returns ``(updated_weights, num_train_examples, {})``.
        """
        self.set_weights(parameters)
        self.model.fit(self.train_generator, epochs=1, verbose=2)
        return self.get_weights(), self._num_examples(self.train_generator), {}

    def evaluate(self, parameters, config):
        """Evaluate *parameters* on the local test iterator.

        Returns ``(loss, num_test_examples, {"accuracy": accuracy})``.
        """
        self.set_weights(parameters)
        loss, accuracy = self.model.evaluate(self.test_generator, verbose=2)
        return loss, self._num_examples(self.test_generator), {"accuracy": accuracy}

# Client function for Flower simulation
def client_fn(context: Context):
    """Create the FlowerClient for the simulated node described by *context*.

    The partition id selects which entry of ``train_dirs`` / ``test_dirs``
    the client loads. Flower's simulation engine stores it in
    ``node_config`` under ``"partition-id"``; the underscore spelling is
    still accepted as a fallback. The id is wrapped modulo the number of
    directories, so running more simulated clients than there are
    directories shares data instead of failing with ``IndexError``.
    """
    node_config = context.node_config
    partition_id = int(node_config.get("partition-id", node_config.get("partition_id", 0)))
    model = build_model()

    # Assign directories based on the partition ID
    train_dir = train_dirs[partition_id % len(train_dirs)]
    test_dir = test_dirs[partition_id % len(test_dirs)]
    train_generators, test_generators = get_data_generators([train_dir], [test_dir])
    return FlowerClient(model, train_generators[0], test_generators[0])

# Global evaluation function for federated learning
def evaluate(server_round, parameters, config):
    """Centralized evaluation of the aggregated global model.

    Used as FedAvg's ``evaluate_fn``. Builds a fresh model, loads
    *parameters* when non-empty (otherwise the freshly built weights are
    evaluated), and evaluates it on every directory in ``test_dirs``.

    Returns:
        ``(loss, {"accuracy": accuracy})``, both averaged over the test
        iterators and weighted by the number of images in each.

    Raises:
        ValueError: If the test iterators contain no images.
    """
    model = build_model()
    if parameters:
        model.set_weights(parameters)

    # Only the test iterators are needed; skip rescanning the training data.
    _, test_generators = get_data_generators([], test_dirs)

    total_loss = 0.0
    total_accuracy = 0.0
    total_examples = 0
    for test_generator in test_generators:
        current_loss, current_accuracy = model.evaluate(test_generator, verbose=2)
        samples = getattr(test_generator, "samples", None)
        weight = samples if samples is not None else len(test_generator)
        total_loss += current_loss * weight
        total_accuracy += current_accuracy * weight
        total_examples += weight

    if total_examples == 0:
        raise ValueError("No test images found in test_dirs")

    loss = total_loss / total_examples
    accuracy = total_accuracy / total_examples

    print(f"Round {server_round} accuracy: {accuracy}")
    return loss, {"accuracy": accuracy}

def fit_config(server_round):
    """Per-round config sent to each client's fit(): the current round number."""
    return {"epoch_global": server_round}


# Federated Learning Strategy: FedAvg sampling all clients for training and
# half of them for client-side evaluation, plus centralized `evaluate`.
strategy = fl.server.strategy.FedAvg(
    fraction_fit=1.0,
    fraction_evaluate=0.5,
    initial_parameters=None,
    evaluate_fn=evaluate,
    on_fit_config_fn=fit_config,
)

# Start the federated learning simulation (4 simulated clients, 3 rounds).
simulation_config = fl.server.ServerConfig(num_rounds=3)
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=4,
    config=simulation_config,
    strategy=strategy,
    client_resources={"num_cpus": 2},
)
