In [12]:
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import MinMaxScaler, LabelEncoder
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import flwr as fl
import torch.nn as nn
import torch.optim as optim

# Load dataset
data = pd.read_csv('/kaggle/input/deep-slice/deepslice_data.csv')

# Encoding categorical variables
encoders = {}
for col in ['Use Case', 'Technology Supported', 'Day', 'GBR', 'slice Type']:
    encoders[col] = LabelEncoder()
    data[col] = encoders[col].fit_transform(data[col])

# Separate features and target
X = data.drop(columns=['slice Type'])
y = data['slice Type']

# Normalize features
scaler = MinMaxScaler()
X_normalized = scaler.fit_transform(X)

# Add Gaussian noise
X_noisy = X_normalized + 0.27 * np.random.normal(loc=0.0, scale=1.0, size=X_normalized.shape)

# Reshape for RNN
X_reshaped = X_noisy.reshape(X_noisy.shape[0], 1, X_noisy.shape[1])

# Convert to PyTorch tensors
X_tensor = torch.tensor(X_reshaped, dtype=torch.float32)
y_tensor = torch.tensor(y.values, dtype=torch.long)

# Split dataset equally for 3 clients
num_samples = len(y) // 3
indices = np.arange(len(y))
np.random.shuffle(indices)
client_datasets = [indices[i * num_samples:(i + 1) * num_samples] for i in range(3)]

# Define Custom PyTorch Dataset
class CustomDataset(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

# Create client datasets
client_data = [CustomDataset(X_tensor[idx], y_tensor[idx]) for idx in client_datasets]

# Define RNN Model
class RNNModel(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(RNNModel, self).__init__()
        self.rnn = nn.RNN(input_dim, 64, batch_first=True, nonlinearity='relu')
        self.lstm = nn.LSTM(64, 64, batch_first=True)
        self.fc = nn.Linear(64, num_classes)

    def forward(self, x):
        x, _ = self.rnn(x)
        x, _ = self.lstm(x)
        x = self.fc(x[:, -1, :])
        return x

# Flower Client Class
class FlowerClient(fl.client.NumPyClient):
    def __init__(self, model, train_loader, test_loader):
        self.model = model
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.AdamW(self.model.parameters())

    def get_parameters(self, config=None):
        return [val.cpu().numpy() for val in self.model.state_dict().values()]

    def set_parameters(self, parameters):
        parameters_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = {k: torch.tensor(v) for k, v in parameters_dict}
        self.model.load_state_dict(state_dict)

    def train(self):
        self.model.train()
        for epoch in range(10):  # Increased epochs to 10
            epoch_loss = 0.0
            for X_batch, y_batch in self.train_loader:
                self.optimizer.zero_grad()
                outputs = self.model(X_batch)
                loss = self.criterion(outputs, y_batch)
                loss.backward()
                self.optimizer.step()
                epoch_loss += loss.item()
            epoch_loss /= len(self.train_loader)
            print(f"Epoch {epoch + 1}: Train Loss = {epoch_loss:.4f}")

    def test(self):
        self.model.eval()
        all_targets = []
        all_predictions = []
        total_loss = 0.0
        with torch.no_grad():
            for X_batch, y_batch in self.test_loader:
                outputs = self.model(X_batch)
                loss = self.criterion(outputs, y_batch)
                total_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                all_targets.extend(y_batch.cpu().numpy())
                all_predictions.extend(predicted.cpu().numpy())

        accuracy = accuracy_score(all_targets, all_predictions)
        precision = precision_score(all_targets, all_predictions, average='macro', zero_division=0)
        recall = recall_score(all_targets, all_predictions, average='macro', zero_division=0)
        f1 = f1_score(all_targets, all_predictions, average='macro', zero_division=0)
        avg_loss = total_loss / len(self.test_loader)

        return avg_loss, accuracy, precision, recall, f1

    def evaluate(self, parameters, config):
        self.set_parameters(parameters)
        loss, accuracy, precision, recall, f1 = self.test()
        print(f"Client {self.test_loader.dataset.X.shape[0]} Evaluation - Loss: {loss:.4f}, Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")
        return float(loss), len(self.test_loader.dataset), {
            "accuracy": float(accuracy),
            "precision": float(precision),
            "recall": float(recall),
            "f1": float(f1),
        }

    def fit(self, parameters, config):
        self.set_parameters(parameters)
        self.train()
        return self.get_parameters(), len(self.train_loader.dataset), {}

# Function to create clients
def get_client_fn(cid: str) -> fl.client.Client:
    train_loader = DataLoader(client_data[int(cid)], batch_size=32, shuffle=True)
    test_loader = DataLoader(client_data[int(cid)], batch_size=32, shuffle=False)
    model = RNNModel(input_dim=X_tensor.shape[2], num_classes=len(y.unique()))
    numpy_client = FlowerClient(model, train_loader, test_loader)
    return numpy_client.to_client()

# Custom evaluation function to print aggregated model metrics
def evaluate_aggregated_model(server_round, parameters, config):
    model = RNNModel(input_dim=X_tensor.shape[2], num_classes=len(y.unique()))
    params_dict = zip(model.state_dict().keys(), parameters)
    state_dict = {k: torch.tensor(v) for k, v in params_dict}
    model.load_state_dict(state_dict)
    
    test_loader = DataLoader(CustomDataset(X_tensor, y_tensor), batch_size=32, shuffle=False)
    criterion = nn.CrossEntropyLoss()
    
    model.eval()
    all_targets = []
    all_predictions = []
    total_loss = 0.0
    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            total_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            all_targets.extend(y_batch.cpu().numpy())
            all_predictions.extend(predicted.cpu().numpy())

    accuracy = accuracy_score(all_targets, all_predictions)
    precision = precision_score(all_targets, all_predictions, average='macro', zero_division=0)
    recall = recall_score(all_targets, all_predictions, average='macro', zero_division=0)
    f1 = f1_score(all_targets, all_predictions, average='macro', zero_division=0)
    avg_loss = total_loss / len(test_loader)

    print(f"Aggregated Model Evaluation - Round {server_round}: Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")
    return avg_loss, {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}

# Federated Learning Strategy with custom evaluate function
strategy = fl.server.strategy.FedAvg(
    evaluate_fn=evaluate_aggregated_model  # Add the custom evaluation function
)

# Start Flower Simulation
fl.simulation.start_simulation(
    client_fn=get_client_fn,
    num_clients=3,
    config=fl.server.ServerConfig(num_rounds=5),
    strategy=strategy
)


	Instead, use the `flwr run` CLI command to start a local simulation in your Flower app, as shown for example below:

		$ flwr new  # Create a new Flower app from a template

		$ flwr run  # Run the Flower app in Simulation Mode

	Using `start_simulation()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
[92mINFO [0m:      Starting Flower simulation, config: num_rounds=5, no round_timeout
2025-02-07 10:21:48,261	INFO worker.py:1752 -- Started a local Ray instance.
[92mINFO [0m:      Flower VCE: Ray initialized with resources: {'node:172.19.2.2': 1.0, 'node:__internal_head__': 1.0, 'memory': 18043713947.0, 'accelerator_type:T4': 1.0, 'CPU': 4.0, 'object_store_memory': 9021856972.0, 'GPU': 2.0}
[92mINFO [0m:      Optimize your simulation with Flower VCE: https://flower.ai/docs/framework/how-to-run-simulations.html
[92mINFO [0m:      No `client_resources` specified. Using minimal resources for

Aggregated Model Evaluation - Round 0: Loss: 1.1446, Accuracy: 0.2288, Precision: 0.0829, Recall: 0.3259, F1: 0.1304


[36m(ClientAppActor pid=1866)[0m 
[36m(ClientAppActor pid=1866)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=1866)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=1866)[0m         
[36m(ClientAppActor pid=1864)[0m         [32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=1864)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=1864)[0m             entirely in future versions of Flower.


[36m(ClientAppActor pid=1866)[0m Epoch 1: Train Loss = 0.2198


[36m(ClientAppActor pid=1865)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=1865)[0m             entirely in future versions of Flower.


[36m(ClientAppActor pid=1864)[0m Epoch 2: Train Loss = 0.1009[32m [repeated 4x across cluster][0m
[36m(ClientAppActor pid=1864)[0m Epoch 4: Train Loss = 0.0838[32m [repeated 7x across cluster][0m
[36m(ClientAppActor pid=1866)[0m Epoch 7: Train Loss = 0.0758[32m [repeated 7x across cluster][0m
[36m(ClientAppActor pid=1864)[0m Epoch 9: Train Loss = 0.0746[32m [repeated 6x across cluster][0m


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      fit progress: (1, 0.07006349834901104, {'accuracy': 0.9750977567400699, 'precision': 0.9702099596306857, 'recall': 0.968864436963376, 'f1': 0.9693834500315214}, 32.79408199100021)
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)


Aggregated Model Evaluation - Round 1: Loss: 0.0701, Accuracy: 0.9751, Precision: 0.9702, Recall: 0.9689, F1: 0.9694


[36m(ClientAppActor pid=1866)[0m         [32m [repeated 4x across cluster][0m
[36m(ClientAppActor pid=1866)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=1866)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=1865)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=1865)[0m             entirely in future versions of Flower.
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=1864)[0m Client 21055 Evaluation - Loss: 0.0709, Accuracy: 0.9741, Precision: 0.9691, Recall: 0.9672, F1: 0.9680
[36m(ClientAppActor pid=1865)[0m Epoch 10: Train Loss = 0.0660[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=1866)[0m Client 21055 Evaluation - Loss: 0.0729, Accuracy: 0.9748, Precision: 0.9697, Recall: 0.9687, F1: 0.9691[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=1866)[0m Epoch 3: Train Loss = 0.0656[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=1865)[0m Epoch 4: Train Loss = 0.0712[32m [repeated 7x across cluster][0m
[36m(ClientAppActor pid=1865)[0m Epoch 7: Train Loss = 0.0696[32m [repeated 7x across cluster][0m
[36m(ClientAppActor pid=1864)[0m Epoch 8: Train Loss = 0.0667[32m [repeated 7x across cluster][0m


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      fit progress: (2, 0.06411021255557037, {'accuracy': 0.9766650307913942, 'precision': 0.9710032293758516, 'recall': 0.9717278781885431, 'f1': 0.9713599021259088}, 64.67867904200011)
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)


Aggregated Model Evaluation - Round 2: Loss: 0.0641, Accuracy: 0.9767, Precision: 0.9710, Recall: 0.9717, F1: 0.9714


[36m(ClientAppActor pid=1864)[0m         [32m [repeated 12x across cluster][0m
[36m(ClientAppActor pid=1864)[0m             This is a deprecated feature. It will be removed[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=1864)[0m             entirely in future versions of Flower.[32m [repeated 5x across cluster][0m
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=1864)[0m Client 21055 Evaluation - Loss: 0.0597, Accuracy: 0.9786, Precision: 0.9734, Recall: 0.9746, F1: 0.9740
[36m(ClientAppActor pid=1864)[0m Epoch 10: Train Loss = 0.0665[32m [repeated 4x across cluster][0m
[36m(ClientAppActor pid=1866)[0m Client 21055 Evaluation - Loss: 0.0673, Accuracy: 0.9754, Precision: 0.9691, Recall: 0.9699, F1: 0.9695[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=1866)[0m Epoch 3: Train Loss = 0.0690[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=1866)[0m Epoch 6: Train Loss = 0.0676[32m [repeated 7x across cluster][0m
[36m(ClientAppActor pid=1866)[0m Epoch 9: Train Loss = 0.0654[32m [repeated 7x across cluster][0m
[36m(ClientAppActor pid=1865)[0m Epoch 8: Train Loss = 0.0611[32m [repeated 6x across cluster][0m


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      fit progress: (3, 0.06221808410385321, {'accuracy': 0.9774565833425681, 'precision': 0.9716140816221941, 'recall': 0.9732340020867024, 'f1': 0.9724173080414739}, 96.48799875899999)
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)


Aggregated Model Evaluation - Round 3: Loss: 0.0622, Accuracy: 0.9775, Precision: 0.9716, Recall: 0.9732, F1: 0.9724


[36m(ClientAppActor pid=1866)[0m         [32m [repeated 12x across cluster][0m
[36m(ClientAppActor pid=1866)[0m             This is a deprecated feature. It will be removed[32m [repeated 6x across cluster][0m
[36m(ClientAppActor pid=1866)[0m             entirely in future versions of Flower.[32m [repeated 6x across cluster][0m
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 4]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=1866)[0m Client 21055 Evaluation - Loss: 0.0588, Accuracy: 0.9789, Precision: 0.9735, Recall: 0.9755, F1: 0.9745
[36m(ClientAppActor pid=1864)[0m Epoch 10: Train Loss = 0.0641[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=1865)[0m Client 21055 Evaluation - Loss: 0.0651, Accuracy: 0.9765, Precision: 0.9704, Recall: 0.9719, F1: 0.9711[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=1866)[0m Epoch 3: Train Loss = 0.0646[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=1866)[0m Epoch 6: Train Loss = 0.0643[32m [repeated 7x across cluster][0m
[36m(ClientAppActor pid=1865)[0m Epoch 6: Train Loss = 0.0658[32m [repeated 7x across cluster][0m
[36m(ClientAppActor pid=1864)[0m Epoch 8: Train Loss = 0.0589[32m [repeated 7x across cluster][0m


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      fit progress: (4, 0.06240133009181528, {'accuracy': 0.9771874554751689, 'precision': 0.9737520527936688, 'recall': 0.9705149008667157, 'f1': 0.9721113667862964}, 128.97083256500014)
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)


Aggregated Model Evaluation - Round 4: Loss: 0.0624, Accuracy: 0.9772, Precision: 0.9738, Recall: 0.9705, F1: 0.9721


[36m(ClientAppActor pid=1865)[0m         [32m [repeated 12x across cluster][0m
[36m(ClientAppActor pid=1865)[0m             This is a deprecated feature. It will be removed[32m [repeated 6x across cluster][0m
[36m(ClientAppActor pid=1865)[0m             entirely in future versions of Flower.[32m [repeated 6x across cluster][0m
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 5]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=1865)[0m Client 21055 Evaluation - Loss: 0.0592, Accuracy: 0.9782, Precision: 0.9753, Recall: 0.9723, F1: 0.9738
[36m(ClientAppActor pid=1864)[0m Epoch 10: Train Loss = 0.0590[32m [repeated 4x across cluster][0m
[36m(ClientAppActor pid=1866)[0m Client 21055 Evaluation - Loss: 0.0630, Accuracy: 0.9766, Precision: 0.9731, Recall: 0.9694, F1: 0.9712[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=1864)[0m Epoch 2: Train Loss = 0.0602[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=1864)[0m Epoch 5: Train Loss = 0.0585[32m [repeated 7x across cluster][0m
[36m(ClientAppActor pid=1864)[0m Epoch 8: Train Loss = 0.0576[32m [repeated 7x across cluster][0m
[36m(ClientAppActor pid=1866)[0m Epoch 8: Train Loss = 0.0615[32m [repeated 7x across cluster][0m


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      fit progress: (5, 0.06114064356780603, {'accuracy': 0.9778206975161081, 'precision': 0.9738072829054154, 'recall': 0.9719091894956504, 'f1': 0.9728450752164716}, 161.11266567000007)
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)


Aggregated Model Evaluation - Round 5: Loss: 0.0611, Accuracy: 0.9778, Precision: 0.9738, Recall: 0.9719, F1: 0.9728


[36m(ClientAppActor pid=1866)[0m         [32m [repeated 12x across cluster][0m
[36m(ClientAppActor pid=1866)[0m             This is a deprecated feature. It will be removed[32m [repeated 6x across cluster][0m
[36m(ClientAppActor pid=1866)[0m             entirely in future versions of Flower.[32m [repeated 6x across cluster][0m


[36m(ClientAppActor pid=1864)[0m Client 21055 Evaluation - Loss: 0.0617, Accuracy: 0.9773, Precision: 0.9735, Recall: 0.9709, F1: 0.9722
[36m(ClientAppActor pid=1866)[0m Epoch 10: Train Loss = 0.0618[32m [repeated 4x across cluster][0m


[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 5 round(s) in 162.18s
[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: 0.07006687837100392
[92mINFO [0m:      		round 2: 0.06411206122838563
[92mINFO [0m:      		round 3: 0.06221993791814982
[92mINFO [0m:      		round 4: 0.06240453636389331
[92mINFO [0m:      		round 5: 0.06114359529982764
[92mINFO [0m:      	History (loss, centralized):
[92mINFO [0m:      		round 0: 1.1445508008669938
[92mINFO [0m:      		round 1: 0.07006349834901104
[92mINFO [0m:      		round 2: 0.06411021255557037
[92mINFO [0m:      		round 3: 0.06221808410385321
[92mINFO [0m:      		round 4: 0.06240133009181528
[92mINFO [0m:      		round 5: 0.06114064356780603
[92mINFO [0m:      	History (metrics, centralized):
[92mINFO [0m:      	{'accuracy': [(0, 0.2288378425443665),
[92mINFO [0m:      	      

History (loss, distributed):
	round 1: 0.07006687837100392
	round 2: 0.06411206122838563
	round 3: 0.06221993791814982
	round 4: 0.06240453636389331
	round 5: 0.06114359529982764
History (loss, centralized):
	round 0: 1.1445508008669938
	round 1: 0.07006349834901104
	round 2: 0.06411021255557037
	round 3: 0.06221808410385321
	round 4: 0.06240133009181528
	round 5: 0.06114064356780603
History (metrics, centralized):
{'accuracy': [(0, 0.2288378425443665),
              (1, 0.9750977567400699),
              (2, 0.9766650307913942),
              (3, 0.9774565833425681),
              (4, 0.9771874554751689),
              (5, 0.9778206975161081)],
 'f1': [(0, 0.1304082033134367),
        (1, 0.9693834500315214),
        (2, 0.9713599021259088),
        (3, 0.9724173080414739),
        (4, 0.9721113667862964),
        (5, 0.9728450752164716)],
 'precision': [(0, 0.08289723412713954),
               (1, 0.9702099596306857),
               (2, 0.9710032293758516),
               (3, 0.97161

[36m(ClientAppActor pid=1865)[0m Client 21055 Evaluation - Loss: 0.0635, Accuracy: 0.9769, Precision: 0.9724, Recall: 0.9709, F1: 0.9716
