# 1.Library

In [11]:
import os
import shutil
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from PIL import Image

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

import copy
import concurrent.futures

# Set random seeds for reproducibility
np.random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x1273ac7cd50>

# 2.Data Loader

In [12]:

def get_client_data(data_dir, client_type, idx, img_size=28):
    X, y = [], []
    folder = os.path.join(data_dir, f"{client_type}_{idx}")
    for label in os.listdir(folder):
        label_folder = os.path.join(folder, label)
        if not os.path.isdir(label_folder):
            continue
        for fname in os.listdir(label_folder):
            if fname.endswith('.png') or fname.endswith('.jpg'):
                img = Image.open(os.path.join(label_folder, fname)).convert('L').resize((img_size, img_size))
                X.append(np.array(img).flatten() / 255.0)
                y.append(label)
    return np.array(X, dtype=np.float32), np.array(y, dtype=np.float32)

def load_mnist_binary_test_data_flat(test_dir, img_size=28):
    X = []
    y = []
    if not os.path.exists(test_dir):
        return X, y
    for label in ['0', '1']:
        label_folder = os.path.join(test_dir, label)
        if not os.path.isdir(label_folder):
            continue
        for fname in os.listdir(label_folder):
            if fname.endswith('.png') or fname.endswith('.jpg'):
                img = Image.open(os.path.join(label_folder, fname)).convert('L').resize((img_size, img_size))
                X.append(np.array(img).flatten() / 255.0)
                y.append(label)
    X = np.array(X, dtype=np.float32)
    y = np.array(y, dtype=np.float32)
    return X, y

# 3.Model

In [13]:
# class MNISTNet(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.fc1 = nn.Linear(28*28, 32)
#         self.fc2 = nn.Linear(32, 1)
#         self.relu = nn.ReLU()
#     def forward(self, x):
#         x = self.fc1(x)
#         x = self.relu(x)
#         x = self.fc2(x)
#         x = torch.sigmoid(x)
#         return x

In [14]:
# class MNISTNet(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.fc1 = nn.Linear(28*28, 1)
#     def forward(self, x):
#         x = self.fc1(x)
#         x = torch.sigmoid(x)
#         return x
    
# save state_dict + metadata (safe, portable)
def save_model(model, img_size, n_classes, learning_rate, experiment_bs, file_path):

    state = {
        "model_state_dict": model.state_dict(),                     # CPU/GPU tensors okay
        "arch": "MNISTNet",
        "img_size": img_size,
        "num_classes": n_classes,
        "training_args": {"lr": learning_rate, "batch_size": experiment_bs}
    }
    # ensure weights are on CPU to avoid GPU-only pickle issues
    state["model_state_dict"] = {k: v.cpu() for k, v in state["model_state_dict"].items()}
    torch.save(state, file_path)
    
def load_model(file_path, device):

    checkpoint = torch.load(file_path, map_location=device)
    model = MNISTNet().to(device)              # must have MNISTNet class defined/importable
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    return model


# 4.Loss

In [15]:
def binary_cross_entropy(pred, target):
    eps = 1e-7
    pred = torch.clamp(pred, eps, 1 - eps)
    return -(target * torch.log(pred) + (1 - target) * torch.log(1 - pred)).mean()


In [16]:
def train_local_worker(args):
    X_c_tensor, y_c_tensor, global_weights, learning_rate, local_epoch, experiment_bs, device = args
    # create model on target device
    model = MNISTNet().to(device)
    # ensure weights are mapped to the worker device
    state_on_device = {k: v.to(device) for k, v in global_weights.items()}
    model.load_state_dict(state_on_device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    model.train()
    for epoch in range(local_epoch):
        for start in range(0, X_c_tensor.size(0), experiment_bs):
            end = start + experiment_bs
            xb = X_c_tensor[start:end].to(device)
            yb = y_c_tensor[start:end].to(device)
            outputs = model(xb)
            preds = (outputs >= 0.5).float()
            loss = binary_cross_entropy(outputs, yb)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    # return CPU copies for safe aggregation in main thread
    returned_state = {k: v.detach().cpu() for k, v in model.state_dict().items()}
    return returned_state, X_c_tensor.size(0), yb.detach().cpu(), outputs.detach().cpu(), preds.detach().cpu()

In [17]:

def average_weights(w_list):
    avg = {}
    for k in w_list[0].keys():
        avg[k] = sum([w[k] for w in w_list]) / len(w_list)
    return avg

# 5.Configuration

In [18]:
N_HONEST = 100
N_POISONED = 100
IMG_SIZE = 28
LEARNING_RATE = 0.001
ROUNDS = 100
LOCAL_EPOCHS = 1
BATCH_SIZE = 32
N_CLASSES = 2

# Set device
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

DATA_PATH = "../data/mnist_binary_poison/train"
TEST_PATH = "../data/mnist_binary_poison/test"

RESULT_PATH = "./results"


# 6.Main

In [19]:
poison_percent = [0, 10, 20, 30, 40, 50]
poison_n_list = []
for p in range(0,N_POISONED+1):
    percent = int((p / (p+N_HONEST)) * 100)
    if percent in poison_percent:
        poison_percent.remove(percent)
        print(f"Adding poison n: {p} for percent: {percent}")
        poison_n_list.append(p)
print("Poison n list:", poison_n_list)

Adding poison n: 0 for percent: 0
Adding poison n: 12 for percent: 10
Adding poison n: 25 for percent: 20
Adding poison n: 43 for percent: 30
Adding poison n: 67 for percent: 40
Adding poison n: 100 for percent: 50
Poison n list: [0, 12, 25, 43, 67, 100]


In [None]:
NODE = 1
for LAYER in [3,4,5,6]: # LAYER - (first,last)
        
    class MNISTNet(nn.Module):
        def __init__(self):
            super().__init__()
            
            # 1. Define the stack of hidden layers (Linear + ReLU)
            # Use a list comprehension to create the (Linear, ReLU) pairs
            hidden_layers = []
            for _ in range(LAYER-2):
                hidden_layers.append(nn.Conv2d(NODE, NODE, kernel_size=3, padding=1))
                hidden_layers.append(nn.ReLU())
                
            # 2. Use nn.Sequential to wrap the entire dynamic stack
            self.dynamic_hidden = nn.Sequential(*hidden_layers)
            
            # Other layers
            self.conv1 = nn.Conv2d(1, NODE, kernel_size=3, padding=1) # 28x28
            self.pool = nn.MaxPool2d(2, 2) # 14x14
            self.fc1 = nn.Linear(NODE * 14 * 14, 1)
            self.relu = nn.ReLU()

        def forward(self, x):
            x = x.view(-1, 1, 28, 28)
            x = self.pool(self.relu(self.conv1(x)))
            x = nn.ReLU()(x) # We need to manually apply ReLU here
            
            # Dynamic Hidden Layers (Executed sequentially without a loop)
            x = self.dynamic_hidden(x)
                
            # Output Layer
            x = x.view(x.size(0), -1)
            x = self.fc1(x)
            x = torch.sigmoid(x)
            
            return x

    BATCH_SIZE_list = []
    temp_bs = BATCH_SIZE
    while temp_bs >= 1:
        BATCH_SIZE_list.append(temp_bs)
        temp_bs = temp_bs // 2

    for rev in range(3):

        global_model = MNISTNet().to(device)
        for experiment_bs in BATCH_SIZE_list:

            name_save_path = f"FL_ModelN{NODE}L{LAYER}_Batchsize{experiment_bs}_rev{rev}"
            
            for i_poisoned in poison_n_list:
                # local_epoch = # 29 image x 2 class x 100+i_poisoned clents / 32 batch_size / 100+i_poisoned clents
                LOCAL_EPOCHS = max(1, (29 * 2 * (N_HONEST + i_poisoned)) // (experiment_bs * (N_HONEST + i_poisoned)))
                
                percent_poisoned = int((i_poisoned / (i_poisoned + N_HONEST)) * 100)

                # create directory
                NAME_SAVE_update_PATH = f"poisoned_{percent_poisoned}percent"
                save_path = os.path.join(RESULT_PATH,name_save_path,NAME_SAVE_update_PATH)
                # Remove existing directory if it exists
                if os.path.exists(save_path):
                    shutil.rmtree(save_path)
                if not os.path.exists(save_path):
                    os.makedirs(save_path)
                save_all_path = os.path.join(RESULT_PATH,name_save_path,'all')
                os.makedirs(save_all_path, exist_ok=True)
                
                # Load data
                X_test, y_test = load_mnist_binary_test_data_flat(TEST_PATH, img_size=IMG_SIZE)
                if X_test.shape[0] == 0:
                    raise ValueError("No valid test images found in the test directory!")
                X_test_tensor = torch.tensor(X_test, dtype=torch.float32).to(device)
                y_test_tensor = torch.tensor(y_test, dtype=torch.float32).unsqueeze(1).to(device)

                # torch.save(model, os.path.join(save_path,'model_init.pt'))
                save_path_name = os.path.join(save_path,'model_init.pt')
                save_model(global_model, IMG_SIZE, N_CLASSES, LEARNING_RATE, experiment_bs, save_path_name)
                
                global_weights = global_model.state_dict()

                records = []

                X_train_tensor = {}
                y_train_tensor = {}
                for i in range(N_HONEST):
                    X_c, y_c = get_client_data(DATA_PATH, "honest", i, img_size=IMG_SIZE)
                    if len(X_c) == 0:
                        continue
                    X_c_tensor = torch.tensor(X_c, dtype=torch.float32).to(device)
                    y_c_tensor = torch.tensor(y_c, dtype=torch.float32).unsqueeze(1).to(device)
                    X_train_tensor[f"honest_{i}"] = X_c_tensor
                    y_train_tensor[f"honest_{i}"] = y_c_tensor
                for i in range(i_poisoned):
                    X_c, y_c = get_client_data(DATA_PATH, "poison", i, img_size=IMG_SIZE)
                    if len(X_c) == 0:
                        continue
                    X_c_tensor = torch.tensor(X_c, dtype=torch.float32).to(device)
                    y_c_tensor = torch.tensor(y_c, dtype=torch.float32).unsqueeze(1).to(device)
                    X_train_tensor[f"poison_{i}"] = X_c_tensor
                    y_train_tensor[f"poison_{i}"] = y_c_tensor

                # Training loop
                for round in range(ROUNDS):
                    # Prepare arguments for honest clients
                    
                    #randomize the training data
                    for i in range(N_HONEST):
                        perm = torch.randperm(X_train_tensor[f"honest_{i}"].size(0))
                        X_train_tensor[f"honest_{i}"] = X_train_tensor[f"honest_{i}"][perm]
                        y_train_tensor[f"honest_{i}"] = y_train_tensor[f"honest_{i}"][perm]

                    honest_args = [
                        (
                            X_train_tensor[f"honest_{i}"],
                            y_train_tensor[f"honest_{i}"],
                            global_weights,
                            LEARNING_RATE,
                            LOCAL_EPOCHS,
                            experiment_bs,
                            device
                        )
                        for i in range(N_HONEST)
                    ]

                    # Prepare arguments for poisoned clients
                    
                    #randomize the training data
                    for i in range(i_poisoned):
                        perm = torch.randperm(X_train_tensor[f"poison_{i}"].size(0))
                        X_train_tensor[f"poison_{i}"] = X_train_tensor[f"poison_{i}"][perm]
                        y_train_tensor[f"poison_{i}"] = y_train_tensor[f"poison_{i}"][perm]

                    poison_args = [
                        (
                            X_train_tensor[f"poison_{i}"],
                            y_train_tensor[f"poison_{i}"],
                            global_weights,
                            LEARNING_RATE,
                            LOCAL_EPOCHS,
                            experiment_bs,
                            device
                        )
                        for i in range(i_poisoned)
                    ]

                    # Run in parallel
                    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
                        honest_results = list(executor.map(train_local_worker, honest_args))
                        poison_results = list(executor.map(train_local_worker, poison_args))

                    local_weights = []
                    local_sizes = []
                    local_y = []
                    local_outputs = []
                    local_preds = []
                    for l_w, l_sz, l_y, l_out, l_pred in honest_results + poison_results:
                        local_weights.append(l_w)
                        local_sizes.append(l_sz)
                        local_y.extend(l_y)
                        local_outputs.extend(l_out)
                        local_preds.extend(l_pred)
                    loss_train = binary_cross_entropy(torch.stack(local_outputs), torch.stack(local_y))
                    acc_train = (torch.stack(local_preds) == torch.stack(local_y)).float().mean().item()


                    # Federated averaging (weighted by client data size)
                    new_weights = {}
                    for key in global_weights.keys():
                        new_weights[key] = sum([w[key]*sz for w, sz in zip(local_weights, local_sizes)]) / sum(local_sizes)
                    global_weights = new_weights
                    global_model.load_state_dict(global_weights)

                    # Evaluate
                    global_model.eval()
                    with torch.no_grad():
                        outputs = global_model(X_test_tensor)
                        preds = (outputs > 0.5).float()
                        loss_test = binary_cross_entropy(outputs, y_test_tensor)
                        acc_test = (preds == y_test_tensor).float().mean().item()
                    records.append({'poison_percent': percent_poisoned, 'round': round, 
                                    'loss_train': loss_train.item(), 'acc_train': acc_train, 
                                    'loss_test': loss_test.item(), 'acc_test': acc_test})
                    report_txt = f"poison_percent {percent_poisoned} "
                    report_txt += f"round {round}: "
                    report_txt += f"loss_train={loss_train.item():.4f}, "
                    report_txt += f"acc_train={acc_train:.4f}, "
                    report_txt += f"loss_test={loss_test.item():.4f}, "
                    report_txt += f"acc_test={acc_test:.4f}"
                    print(report_txt)

                # torch.save(model, os.path.join(save_path,'model_last.pt'))
                save_path_name = os.path.join(save_path,'model_last.pt')
                save_model(global_model, IMG_SIZE, N_CLASSES, LEARNING_RATE, experiment_bs, save_path_name)

                # Save training log
                df = pd.DataFrame(records)
                save_name_path = os.path.join(save_path, f'{NAME_SAVE_update_PATH}.csv')
                df.to_csv(save_name_path, index=False)
                print(f"Training log saved to {save_name_path}")

                # Plot loss and accuracy
                plt.figure(figsize=(10,4))

                plt.subplot(1,2,1)
                plt.plot(df['round'], df['loss_test'], marker='o')
                plt.title('Test Loss')
                plt.ylabel('Loss')
                plt.xlabel('Epoch')
                plt.grid(True)
                plt.ylim(0, 1.1)
                # plt.legend()

                plt.subplot(1,2,2)
                plt.plot(df['round'], df['acc_test'], marker='o')
                plt.title('Test Accuracy')
                plt.ylabel('Accuracy')
                plt.xlabel('Epoch')
                plt.grid(True)
                plt.ylim(0, 1.1)
                # plt.legend()

                plt.tight_layout()
                save_name_path = os.path.join(save_path, f'loss_accuracy.jpg')
                plt.savefig(save_name_path)

                save_name_path = os.path.join(save_all_path, f'{NAME_SAVE_update_PATH}_latest.jpg')
                plt.savefig(save_name_path)
                plt.close()


poison_percent 0 round 0: loss_train=0.6963, acc_train=0.4927, loss_test=0.6942, acc_test=0.5000
poison_percent 0 round 1: loss_train=0.6967, acc_train=0.4835, loss_test=0.6939, acc_test=0.5000
poison_percent 0 round 2: loss_train=0.6942, acc_train=0.5092, loss_test=0.6937, acc_test=0.5000
poison_percent 0 round 3: loss_train=0.6951, acc_train=0.4927, loss_test=0.6935, acc_test=0.5000
poison_percent 0 round 4: loss_train=0.6941, acc_train=0.5058, loss_test=0.6934, acc_test=0.5000
poison_percent 0 round 5: loss_train=0.6940, acc_train=0.5062, loss_test=0.6932, acc_test=0.5000
poison_percent 0 round 6: loss_train=0.6942, acc_train=0.5023, loss_test=0.6932, acc_test=0.5000
poison_percent 0 round 7: loss_train=0.6944, acc_train=0.4854, loss_test=0.6931, acc_test=0.5000
poison_percent 0 round 8: loss_train=0.6940, acc_train=0.4473, loss_test=0.6931, acc_test=0.5015
poison_percent 0 round 9: loss_train=0.6942, acc_train=0.4438, loss_test=0.6931, acc_test=0.4934
poison_percent 0 round 10: los