In [None]:
# ! rm -rf /opt/conda/lib/python3.10/site-packages/aiohttp-3.9.1.dist-info
# (Disabled) The line above deletes an aiohttp 3.9.1 dist-info directory from a
# /opt/conda Python 3.10 environment — presumably a workaround for a dependency
# conflict on a hosted notebook image; re-enable only if that conflict appears.
# Install Flower (simulation extra), Flower Datasets (vision extra), PyTorch,
# torchvision and matplotlib into the running kernel's environment (`%pip`).
# NOTE(review): versions are unpinned; pin them for reproducible runs.
%pip install -q 'flwr[simulation]' 'flwr_datasets[vision]' torch torchvision matplotlib

In [None]:
from collections import OrderedDict
from typing import List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.models as models
from datasets.utils.logging import disable_progress_bar
from torch.utils.data import DataLoader

import flwr as fl
from flwr.common import Metrics

from typing import Optional, Dict
import os


In [None]:
# Experiment configuration shared by every cell below.
config = {
    "dataset": "cifar10",
    "num_clients": 5,
    "batch_size": 32,
    "poison_clients": [],  # integer client ids that behave as attackers

    # DO NOT SET ATTACK HERE !!!
    # "poison_clients": [],
    # "poison_type": "one-label",
    # "poison_label": 0,
    # "poison_type": "label-flipping",
    # "poison_type": "label-random",
    # "poison_type": "DBA",
    # "poison_type": "CBA",
    "poison_type": "No-Attacks-For-Non-IID-Performance-Testing",

    "gamma": 5,  # scale factor applied to an attacker's update in FlowerClient.fit
    "epochs": 20,  # used as the number of federated rounds (ServerConfig num_rounds)
    # NOTE(review): lr / momentum / weight_decay are not read by `train`, which
    # builds Adam with its defaults — confirm whether they are meant to be used.
    "lr": 0.05,
    "momentum": 0.9,
    "weight_decay": 0.0001,
    "num_classes": 10,
    "input_size": (3, 32, 32),
    "seed": 42,
}
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(
    f"Training on {DEVICE} using PyTorch {torch.__version__} and Flower {fl.__version__}"
)
disable_progress_bar()

# Apply the configured seed. Previously `config["seed"]` only fed the dataset
# split, leaving the NumPy draws used by the poisoning code and PyTorch's
# weight initialisation unseeded, so runs were not reproducible.
np.random.seed(config["seed"])
torch.manual_seed(config["seed"])

In [None]:
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import ShardPartitioner

# Which DBA trigger part the next poisoned client embeds; `load_datasets`
# advances it once per poisoned client.
backdoor_idx = 0

# Non-IID split: every client receives 2 shards built over the "label" column.
partitioner = ShardPartitioner(
    num_partitions=config["num_clients"],
    partition_by="label",
    num_shards_per_partition=2,
)

def load_datasets():
    """Build per-client train/validation loaders and a shared test loader.

    Each of the ``config["num_clients"]`` partitions produced by the module-level
    ``partitioner`` is split 80/20 (seeded with ``config["seed"]``) into a local
    train and validation set. Partitions whose id appears in
    ``config["poison_clients"]`` have their training split rewritten by
    ``poisonTrainData`` according to ``config["poison_type"]``. When a backdoor
    attack ("DBA" or "CBA") is configured with at least one poisoned client, the
    full trigger is stamped onto every image of the global test split.

    Side effect: resets the module-level ``backdoor_idx`` to 0, then increments
    it once per poisoned client, so successive DBA attackers embed successive
    trigger parts.

    Returns:
        (trainloaders, valloaders, testloader): two lists holding one DataLoader
        per client, plus one DataLoader over the test split.
    """
    global backdoor_idx
    backdoor_idx = 0  # re-calling the function must not continue the old count

    fds = FederatedDataset(dataset=config["dataset"], partitioners={"train": partitioner})
    print(len(fds.load_split("train")))

    # Build each pipeline once instead of on every batch. The pipelines are applied
    # through dataset.with_transform(...) rather than CIFAR10(..., transform=...).
    train_transform = transforms.Compose(
        [
            # transforms.RandomHorizontalFlip(),
            # transforms.RandomAffine(15, translate=(0.15, 0.15), scale=(0.9, 1.1), shear=10),
            # transforms.ColorJitter(brightness=0.15, contrast=0.15, saturation=0.15, hue=0.05),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )
    val_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    def apply_train_transforms(batch):
        """Convert a batch of training images to normalised tensors."""
        batch["img"] = [train_transform(img) for img in batch["img"]]
        return batch

    def apply_val_transforms(batch):
        """Convert a batch of validation/test images to normalised tensors."""
        batch["img"] = [val_transform(img) for img in batch["img"]]
        return batch

    # The backdoor trigger is four 2x2 patches inside the top-left 5x5 corner.
    # Each part is (rows, cols, channel): the patch is zeroed, then `channel`
    # (when not None) is set to 255 — assuming HWC RGB arrays, that gives a
    # black, red, green and blue patch respectively.
    trigger_parts = (
        (slice(0, 2), slice(0, 2), None),
        (slice(0, 2), slice(3, 5), 0),
        (slice(3, 5), slice(0, 2), 1),
        (slice(3, 5), slice(3, 5), 2),
    )
    all_trigger_parts = tuple(range(len(trigger_parts)))

    def stamp_trigger(img, part_ids):
        """Write the selected trigger parts into the HWC array ``img`` in place and return it."""
        for part_id in part_ids:
            rows, cols, channel = trigger_parts[part_id]
            img[rows, cols, :] = 0
            if channel is not None:
                img[rows, cols, channel] = 255
        return img

    def poisonTrainData(data):
        """Apply the configured poisoning to one training example."""
        poison_type = config["poison_type"]
        if poison_type == "one-label":
            data["label"] = config["poison_label"]
        elif poison_type == "label-flipping":
            data["label"] = (data["label"] + 1) % config["num_classes"]
        elif poison_type == "label-random":
            # Integer class id. The former np.floor(rand * n) produced a float
            # label instead of the integer class id the loss expects.
            data["label"] = int(np.random.randint(config["num_classes"]))
        elif poison_type == "CBA":
            # Centralised backdoor: full trigger on every example, target label 0.
            data["img"] = torch.tensor(stamp_trigger(np.array(data["img"]), all_trigger_parts))
            data["label"] = 0
        elif poison_type == "DBA":
            # Distributed backdoor: this client embeds only trigger part
            # `backdoor_idx` (read from module scope) on a random subset.
            img = np.array(data["img"])
            data["img"] = img
            if np.random.rand() < config.get("dba_poison_rate", 0.4):
                if backdoor_idx < len(trigger_parts):
                    stamp_trigger(img, (backdoor_idx,))
                data["img"] = torch.tensor(img)
                data["label"] = 0  # relabel the selected (default 40%) examples to target class 0
        return data

    def poisonTestData(data):
        """Stamp the full trigger onto one test image; the label is left unchanged.

        NOTE(review): labels are not switched to the target class, so accuracy on
        this set measures how often the model ignores the trigger — confirm intended.
        """
        data["img"] = torch.tensor(stamp_trigger(np.array(data["img"]), all_trigger_parts))
        return data

    # Create train/val for each partition and wrap it into DataLoader
    trainloaders = []
    valloaders = []
    for partition_id in range(config["num_clients"]):
        print(partition_id)
        partition = fds.load_partition(partition_id, "train")
        partition = partition.train_test_split(train_size=0.8, seed=config["seed"])
        if partition_id in config["poison_clients"]:
            partition["train"] = partition["train"].map(poisonTrainData)
            backdoor_idx += 1
        partition["train"] = partition["train"].with_transform(apply_train_transforms)
        partition["test"] = partition["test"].with_transform(apply_val_transforms)
        # Reshuffle local training data every epoch; the split is shuffled only once.
        trainloaders.append(
            DataLoader(partition["train"], batch_size=config["batch_size"], shuffle=True)
        )
        valloaders.append(DataLoader(partition["test"], batch_size=config["batch_size"]))

    testData = fds.load_split("test")
    backdoor_attack = config["poison_type"] in ("DBA", "CBA")
    if backdoor_attack and len(config["poison_clients"]) != 0:
        testData = testData.map(poisonTestData)
    print("\n------------------\n")
    testset = testData.with_transform(apply_val_transforms)
    testloader = DataLoader(testset, batch_size=config["batch_size"])

    return trainloaders, valloaders, testloader

# Materialise every client's loaders and the shared test loader once; client_fn
# and evaluate read these module-level names later.
trainloaders, valloaders, testloader = load_datasets()

In [None]:
# Preview the first batch of every client's training loader and of the global
# test loader (which carries the trigger when a backdoor attack is active).
label_feature = trainloaders[0].dataset.features["label"]
preview_sources = [(f"Client {cid} train", loader) for cid, loader in enumerate(trainloaders)]
preview_sources.append(("Test", testloader))

for source_name, loader in preview_sources:
    batch = next(iter(loader))
    images, labels = batch["img"], batch["label"]
    # matplotlib expects (height, width, 3); undo Normalize(mean=0.5, std=0.5)
    # so pixel values are back in [0, 1].
    images = images.permute(0, 2, 3, 1).numpy() / 2 + 0.5

    fig, axs = plt.subplots(4, 8, figsize=(12, 6))
    fig.suptitle(source_name)
    for image_idx, ax in enumerate(axs.flat):
        ax.axis("off")
        if image_idx >= len(images):
            continue  # batch smaller than the 4x8 grid: leave the cell empty
        ax.imshow(images[image_idx])
        ax.set_title(label_feature.int2str(int(labels[image_idx])))

    fig.tight_layout()
    plt.show()

In [None]:
# class Net(nn.Module):
#     def __init__(self) -> None:
#         super(Net, self).__init__()
#         self.conv1 = nn.Conv2d(3, 8, 3)
#         self.pool = nn.MaxPool2d(2, 2)
#         self.conv2 = nn.Conv2d(8, 20, 5)
#         self.fc1 = nn.Linear(20 * 11 * 11, 256)
#         self.fc2 = nn.Linear(256, 128)
#         self.fc3 = nn.Linear(128, 32)
#         self.fc4 = nn.Linear(32, 10)

#     def forward(self, x: torch.Tensor) -> torch.Tensor:
#         x = self.pool(F.relu(self.conv1(x)))
#         x = F.relu(self.conv2(x))
#         x = x.view(-1, 20 * 11 * 11)
#         x = F.relu(self.fc1(x))
#         x = F.relu(self.fc2(x))
#         x = F.relu(self.fc3(x))
#         x = self.fc4(x)
#         return x

class Net(nn.Module):
    """ResNet-18 backbone with a small MLP head, adapted to 32x32 RGB inputs.

    The stem is replaced by a stride-1 convolution and the stem max-pool is
    removed, and the final ``fc`` layer is replaced by a
    512 -> 128 -> 64 -> 32 -> ``num_classes`` MLP with BatchNorm, Dropout and
    ReLU between the linear layers.

    Args:
        p: dropout probability used in every hidden layer of the head.
        num_classes: number of output logits. Defaults to 10, the value that was
            previously hard-coded (CIFAR-10).
    """

    def __init__(self, p=0.1, num_classes=10):
        super(Net, self).__init__()

        ############################################
        # NOTE:                                    #
        # Pretrain weights on ResNet18 is allowed. #
        ############################################

        # input: (batch_size, 3, 32, 32)
        self.p = p
        self.resnet = models.resnet18(weights=None)
        # Stride-1 stem for small images. NOTE(review): kernel_size=2 with padding=1
        # produces 33x33 feature maps; CIFAR-style ResNets usually use kernel_size=3
        # here. Left unchanged so parameter shapes (and saved checkpoints) still match.
        self.resnet.conv1 = nn.Conv2d(3, 64, kernel_size=2, stride=1, padding=1, bias=False)
        # Skip the stem max-pool. nn.Identity has no parameters, so the state_dict
        # keys and shapes exchanged with Flower are unaffected.
        self.resnet.maxpool = nn.Identity()
        # backbone features entering fc: (batch_size, 512)
        self.resnet.fc = nn.Sequential(
            nn.Linear(self.resnet.fc.in_features, 128),
            nn.BatchNorm1d(128),
            nn.Dropout(self.p),
            nn.ReLU(True),

            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.Dropout(self.p),
            nn.ReLU(True),

            nn.Linear(64, 32),
            nn.BatchNorm1d(32),
            nn.Dropout(self.p),
            nn.ReLU(True),

            nn.Linear(32, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape (batch_size, num_classes)."""
        return self.resnet(x)

class Identity(nn.Module):
    """Parameter-free pass-through module: ``forward`` hands back its input object."""

    # No __init__ override needed: nn.Module.__init__ runs by default.

    def forward(self, x):
        return x

In [None]:
def train(net, trainloader, epochs: int, verbose=False, device=None):
    """Train ``net`` in place on ``trainloader`` for ``epochs`` passes.

    A fresh Adam optimizer with PyTorch's default hyper-parameters is created on
    every call, so no optimizer state carries over between calls.

    Args:
        net: model to train; it is switched to training mode.
        trainloader: iterable of batches, each a dict with "img" and "label" tensors.
        epochs: number of passes over ``trainloader``.
        verbose: when True, print each epoch's per-example mean loss and accuracy.
        device: where batches are moved; defaults to the notebook-wide ``DEVICE``.
    """
    if device is None:
        device = DEVICE
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    net.train()
    print("Training...")
    for epoch in range(epochs):
        correct, total, loss_sum = 0, 0, 0.0
        for batch in trainloader:
            images, labels = batch["img"].to(device), batch["label"].to(device)
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Metrics. Use .item(): accumulating the loss tensor itself chained
            # every batch into one growing autograd graph for the whole epoch.
            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size  # criterion returns the batch mean
            total += batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
        if verbose:
            # Guard against an empty loader instead of dividing by zero.
            epoch_loss = loss_sum / total if total else 0.0
            epoch_acc = correct / total if total else 0.0
            print(f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}")


def test(net, testloader):
    """Evaluate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for batch in testloader:
            images, labels = batch["img"].to(DEVICE), batch["label"].to(DEVICE)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    loss /= len(testloader.dataset)
    accuracy = correct / total
    return loss, accuracy

In [None]:
# trainloader = trainloaders[0]
# valloader = valloaders[0]
# net = Net().to(DEVICE)

# for epoch in range(config["epochs"]):
#     train(net, trainloader, 1)
#     loss, accuracy = test(net, valloader)
#     print(f"Epoch {epoch+1}: validation loss {loss}, accuracy {accuracy}")

# loss, accuracy = test(net, testloader)
# print(f"Final test set performance:\n\tloss {loss}\n\taccuracy {accuracy}")

In [None]:
def set_parameters(net, parameters):
    """Load ``parameters`` into ``net`` positionally, following state_dict key order.

    Each value is paired with the state_dict entry at the same position (extra
    values are ignored; missing trailing entries are left untouched because the
    load is non-strict). Every pair's shape is validated before anything is
    loaded.

    Raises:
        ValueError: if any value's shape differs from its state_dict entry.
    """
    current_state = net.state_dict()

    for key, value in zip(current_state.keys(), parameters):
        wanted = current_state[key].shape
        received = torch.tensor(value).shape
        if wanted != received:
            raise ValueError(f"Shape mismatch for layer {key}: expected {wanted}, got {received}")

    # A fresh zip: the validation pass above consumed the previous one.
    replacement = OrderedDict(
        (key, torch.tensor(value))
        for key, value in zip(current_state.keys(), parameters)
    )
    net.load_state_dict(replacement, strict=False)


def get_parameters(net) -> List[np.ndarray]:
    """Snapshot every state_dict entry (parameters and buffers) as a NumPy array, in key order."""
    state = net.state_dict()
    return list(map(lambda tensor: tensor.cpu().numpy(), state.values()))

In [None]:
class FlowerClient(fl.client.NumPyClient):
    """Flower NumPy client wrapping one local model and its data loaders.

    Clients whose ``cid`` is listed in ``config["poison_clients"]`` act as
    attackers: they train for more local epochs and then scale their model
    update by ``config["gamma"]`` before returning it.
    """

    def __init__(self, net, cid, trainloader, valloader):
        self.net = net
        self.cid = cid  # integer partition id
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        """Return the local model's state as a list of NumPy arrays."""
        return get_parameters(self.net)

    def fit(self, parameters, local_config):
        """Train from the global ``parameters`` and return the (possibly scaled) update.

        Returns:
            (parameters, num_examples, metrics), where num_examples is the number
            of local training examples (not batches) used as the FedAvg weight.
        """
        # Read values from config
        server_round = local_config["server_round"]
        is_malicious = self.cid in config["poison_clients"]
        local_epochs = (
            config.get("poison_local_epochs", 3) if is_malicious else local_config["local_epochs"]
        )

        # Use values provided by the config
        print(f"[Client {self.cid}, round {server_round}] fit, server_round: {server_round}, local_epochs: {local_epochs}")
        set_parameters(self.net, parameters)
        # set_parameters copies into the model, so these arrays stay equal to the
        # incoming global weights G_t while training runs.
        global_params = list(parameters)
        train(self.net, self.trainloader, epochs=local_epochs)
        if is_malicious:
            # Model-replacement scaling: L_{t+1} = gamma * (X - G_t) + G_t.
            # Built as a new list: the old code appended to the list it was
            # zipping over, so set_parameters loaded the first (unscaled) half and
            # the scaling silently had no effect.
            print("updating......")
            gamma = config["gamma"]
            scaled_params = [
                (gamma * (local_param - global_param) + global_param).astype(local_param.dtype, copy=False)
                for local_param, global_param in zip(get_parameters(self.net), global_params)
            ]
            set_parameters(self.net, scaled_params)
        return get_parameters(self.net), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
        """Evaluate ``parameters`` on the local validation split.

        Returns:
            (loss, num_examples, {"accuracy": accuracy}) with num_examples the
            number of local validation examples.
        """
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader.dataset), {"accuracy": float(accuracy)}

In [None]:
def client_fn(cid: str) -> FlowerClient:
    """Build the Flower client for the partition whose id is the numeric string ``cid``.

    Each client gets a freshly constructed ``Net`` on ``DEVICE`` plus its own
    CIFAR-10 train/validation loaders, so every simulated organisation trains and
    evaluates on its own data.
    """
    model = Net().to(DEVICE)
    client_index = int(cid)
    numpy_client = FlowerClient(
        model,
        client_index,
        trainloaders[client_index],
        valloaders[client_index],
    )
    return numpy_client.to_client()

In [None]:
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    """Aggregate client "accuracy" values into an example-weighted mean.

    Args:
        metrics: (num_examples, metrics_dict) pairs, one per client; each dict
            must contain an "accuracy" entry.

    Returns:
        {"accuracy": weighted mean}, or {"accuracy": 0.0} when no examples were
        reported (an empty list previously raised ZeroDivisionError).
    """
    print(metrics)
    total_examples = sum(num_examples for num_examples, _ in metrics)
    if total_examples == 0:
        return {"accuracy": 0.0}
    # Multiply accuracy of each client by number of examples used
    weighted_sum = sum(num_examples * m["accuracy"] for num_examples, m in metrics)
    return {"accuracy": weighted_sum / total_examples}

In [None]:
def fit_config(server_round: int):
    """Return the training configuration sent to clients for ``server_round``.

    Every round asks for a single local epoch. Honest clients use that value
    as-is; clients listed as attackers override it in their ``fit``.
    (The previous docstring claimed two epochs after round two, which the code
    never did.)
    """
    # Named `fit_settings` so it does not shadow the global experiment `config`.
    fit_settings = {
        "server_round": server_round,  # The current round of federated learning
        "local_epochs": 1,
    }
    return fit_settings

In [None]:
# `evaluate` is passed to the strategy as `evaluate_fn`, so Flower calls it on the
# server after every round (and, in Flower's default server loop, once for the
# initial parameters as round 0). Each call appends its test accuracy to `results`.
results = []
def evaluate(
    server_round: int,
    parameters: fl.common.NDArrays,
    conf: Dict[str, fl.common.Scalar],
) -> Optional[Tuple[float, Dict[str, fl.common.Scalar]]]:
    """Centralised evaluation of the global model on the shared test loader.

    Saves the model to "final_model.pth" on the last round (``config["epochs"]``).

    Returns:
        (loss, {"accuracy": accuracy}) so Flower records centralised loss and
        metrics in its History (returning None previously discarded them).
    """
    net = Net().to(DEVICE)
    set_parameters(net, parameters)  # Update model with the latest parameters
    loss, accuracy = test(net, testloader)
    results.append(accuracy)
    if server_round == config["epochs"]:
        torch.save(net.state_dict(), "final_model.pth")
    print(f"Server-side evaluation round {server_round}")
    print(f"Test set performance:\n\tloss {loss}\n\taccuracy {accuracy}")
    return loss, {"accuracy": accuracy}

In [None]:
from defense_method.krum import KrumServer
from defense_method.trim import TrimServer
from defense_method.bulyan import BulyanServer
from defense_method.qffl import QfflServer

In [None]:
# Create an instance of the model and get the parameters.
# Set to True to start from the checkpoint `evaluate` writes on the final round
# (replaces the former always-false `os.path.exists(...) and False` guard).
RESUME_FROM_CHECKPOINT = False
CHECKPOINT_PATH = "final_model.pth"

net = Net().to(DEVICE)
if RESUME_FROM_CHECKPOINT and os.path.exists(CHECKPOINT_PATH):
    print(f"Loading model from {CHECKPOINT_PATH}")
    # weights_only avoids unpickling arbitrary objects; map_location lets a
    # checkpoint saved on GPU load on a CPU-only machine.
    net.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=True))
params = get_parameters(net)


# Create FedAvg strategy
strategy = fl.server.strategy.FedAvg(
    fraction_fit=1.0,
    fraction_evaluate=0.5,
    min_fit_clients=config["num_clients"],
    min_evaluate_clients=config["num_clients"] // 2,
    min_available_clients=config["num_clients"],
    evaluate_metrics_aggregation_fn=weighted_average,  # <-- pass the metric aggregation function
    evaluate_fn=evaluate,  # <-- pass the evaluation function
    on_fit_config_fn=fit_config,  # Pass the fit_config function
    initial_parameters=fl.common.ndarrays_to_parameters(params),
)

# Specify the resources each of your clients need. By default, each
# client will be allocated 1x CPU and 0x GPUs
client_resources = {"num_cpus": 1, "num_gpus": 0.0}
if DEVICE.type == "cuda":
    # here we are assigning an entire GPU for each client.
    client_resources = {"num_cpus": 1, "num_gpus": 1.0}
    # Refer to our documentation for more details about Flower Simulations
    # and how to setup these `client_resources`.

# Start simulation; keep the returned History (per-round losses and metrics)
# instead of discarding it.
history = fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=config["num_clients"],
    config=fl.server.ServerConfig(num_rounds=config["epochs"],),
    strategy=strategy,
    client_resources=client_resources,
)

In [None]:
# Plot the accuracy collected by `evaluate`, which scores the global model on the
# shared *test* loader (previously mislabelled as validation accuracy).
fig, ax = plt.subplots()
ax.plot(range(len(results)), results)
ax.set_xlabel("Round")
ax.set_ylabel("Test Accuracy")
ax.set_title("Server-side Test Accuracy vs Rounds")
fig.savefig("val_accuracy.png")
plt.show()

# Reset so the next experiment starts with an empty history.
results = []

In [None]:
# Example usage of defense methods
# strategy = KrumServer(
#     num_malicious=config["poison_clients"],
#     to_keep=config["num_clients"] - len(config["poison_clients"]),
#     fraction_fit=1.0,
#     fraction_evaluate=0.5,
#     min_fit_clients=config["num_clients"],
#     min_evaluate_clients=config["num_clients"] // 2,
#     min_available_clients=config["num_clients"],
#     evaluate_metrics_aggregation_fn=weighted_average,  # <-- pass the metric aggregation function
#     evaluate_fn=evaluate,  # <-- pass the evaluation function
#     on_fit_config_fn=fit_config,  # Pass the fit_config function
#     initial_parameters=fl.common.ndarrays_to_parameters(params),
# )

# strategy = Robust_Server(
#     fraction_fit=1.0,
#     fraction_evaluate=0.5,
#     min_fit_clients=config["num_clients"],
#     min_evaluate_clients=config["num_clients"] // 2,
#     min_available_clients=config["num_clients"],
#     evaluate_metrics_aggregation_fn=weighted_average,  # <-- pass the metric aggregation function
#     evaluate_fn=evaluate,  # <-- pass the evaluation function
#     on_fit_config_fn=fit_config,  # Pass the fit_config function
#     initial_parameters=fl.common.ndarrays_to_parameters(params),
# )

# # Start simulation
# fl.simulation.start_simulation(
#     client_fn=client_fn,
#     num_clients=config["num_clients"],
#     config=fl.server.ServerConfig(num_rounds=config["epochs"],),
#     strategy=strategy,
#     client_resources=client_resources,
# )

In [None]:
# # plot the results
# plt.plot(results)
# plt.xlabel("Round")
# plt.ylabel("Val Accuracy")
# plt.title("Validation Accuracy vs Rounds")
# plt.savefig("val_accuracy.png")

# results = []