# 1.Library

In [1]:
import os
import shutil
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from PIL import Image

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

import copy
import concurrent.futures

# Seed the global NumPy and PyTorch random number generators so that a fresh
# kernel always starts from the same RNG state (PyTorch's generator drives the
# torch.randperm shuffles and the nn.Linear weight initialisation used below).
np.random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x29c51fecd50>

# 2.Data Loader

In [2]:

def get_client_data(data_dir, client_type, idx, img_size=28):
    """Load one client's images as flattened grayscale vectors with labels.

    The client folder is ``<data_dir>/<client_type>_<idx>``. Every
    sub-directory of it is treated as one class, and the sub-directory name is
    used as the label (it must therefore be parseable as a number).

    Args:
        data_dir: directory that contains the per-client folders.
        client_type: folder-name prefix, e.g. "honest" or "poison".
        idx: client index appended to the prefix.
        img_size: images are resized to ``img_size x img_size`` before
            flattening.

    Returns:
        ``(X, y)`` as float32 NumPy arrays. ``X`` has one row of
        ``img_size * img_size`` values scaled to [0, 1] per image and ``y``
        holds the matching labels. Both have shape ``(0,)`` when the client
        folder contains no ``.png``/``.jpg`` files.

    Raises:
        FileNotFoundError: if the client folder does not exist.
        ValueError: if a label folder that contains images has a name that
            cannot be converted to float.
    """
    X, y = [], []
    folder = os.path.join(data_dir, f"{client_type}_{idx}")
    # os.listdir order depends on the filesystem; sorting keeps the sample
    # order (and therefore the seeded shuffles applied later) reproducible.
    for label in sorted(os.listdir(folder)):
        label_folder = os.path.join(folder, label)
        if not os.path.isdir(label_folder):
            continue
        for fname in sorted(os.listdir(label_folder)):
            if not fname.endswith(('.png', '.jpg')):
                continue
            # Image.open keeps the file handle open until the image is closed;
            # the context manager releases it as soon as the pixels are copied.
            with Image.open(os.path.join(label_folder, fname)) as img:
                pixels = np.array(img.convert('L').resize((img_size, img_size)))
            X.append(pixels.flatten() / 255.0)
            y.append(label)
    return np.array(X, dtype=np.float32), np.array(y, dtype=np.float32)

def load_mnist_binary_test_data_flat(test_dir, img_size=28):
    """Load the binary (labels '0' and '1') test set as flattened vectors.

    Only the sub-folders ``<test_dir>/0`` and ``<test_dir>/1`` are read; a
    missing sub-folder is skipped.

    Args:
        test_dir: directory that contains the ``0`` and ``1`` label folders.
        img_size: images are resized to ``img_size x img_size`` before
            flattening.

    Returns:
        ``(X, y)`` as float32 NumPy arrays. ``X`` has one row of
        ``img_size * img_size`` values scaled to [0, 1] per image and ``y``
        holds 0.0 / 1.0 labels. When ``test_dir`` is missing or holds no
        ``.png``/``.jpg`` files both arrays are empty with shape ``(0,)``.
        The return type is always ``ndarray`` so callers can safely inspect
        ``X.shape[0]``.
    """
    X, y = [], []
    if os.path.isdir(test_dir):
        for label in ['0', '1']:
            label_folder = os.path.join(test_dir, label)
            if not os.path.isdir(label_folder):
                continue
            # Sorted so the sample order does not depend on the filesystem.
            for fname in sorted(os.listdir(label_folder)):
                if not fname.endswith(('.png', '.jpg')):
                    continue
                # Close the file handle as soon as the pixels are copied out.
                with Image.open(os.path.join(label_folder, fname)) as img:
                    pixels = np.array(img.convert('L').resize((img_size, img_size)))
                X.append(pixels.flatten() / 255.0)
                y.append(label)
    # Convert unconditionally: the missing-directory case used to return plain
    # lists, which broke the caller's `X_test.shape[0]` check.
    return np.array(X, dtype=np.float32), np.array(y, dtype=np.float32)

# 3.Model

In [3]:
class MNISTNet(nn.Module):
    """Logistic-regression classifier: one linear layer followed by a sigmoid.

    Args:
        in_features: length of each flattened input vector. Defaults to
            28 * 28, so ``MNISTNet()`` builds exactly the same model as before;
            pass ``img_size * img_size`` when using a different image size.
    """

    def __init__(self, in_features=28 * 28):
        super().__init__()
        self.fc1 = nn.Linear(in_features, 1)

    def forward(self, x):
        """Return ``sigmoid(fc1(x))``: one value in [0, 1] per input row."""
        return torch.sigmoid(self.fc1(x))

# 4.Loss

In [4]:
def binary_cross_entropy(pred, target):
    """Mean binary cross-entropy between probabilities and 0/1 targets.

    ``pred`` is clamped to [1e-7, 1 - 1e-7] first so that neither logarithm
    is ever evaluated at zero.
    """
    tiny = 1e-7
    prob = pred.clamp(min=tiny, max=1 - tiny)
    positive_term = target * prob.log()
    negative_term = (1 - target) * (1 - prob).log()
    return -(positive_term + negative_term).mean()


In [5]:

def train_local_worker(args):
    """Train a fresh local model on one client's data, starting from the global weights.

    Args:
        args: 7-tuple ``(X_c_tensor, y_c_tensor, global_weights, learning_rate,
            local_epoch, experiment_bs, device)``. It is packed into a single
            argument so the function can be used with ``executor.map``.
            ``X_c_tensor`` / ``y_c_tensor`` are the client's inputs and
            targets, ``global_weights`` is a state dict loaded into the local
            model, ``local_epoch`` is the number of passes over the data and
            ``experiment_bs`` the mini-batch size.

    Returns:
        ``(state_dict, n_samples, yb, outputs, preds)`` where

        * ``state_dict`` is an independent deep copy of the trained weights,
        * ``n_samples`` is ``X_c_tensor.size(0)``,
        * ``yb`` / ``outputs`` / ``preds`` are the targets, the model outputs
          (detached from autograd, computed before that batch's optimizer
          step) and the ``outputs >= 0.5`` decisions for the LAST mini-batch
          of the last epoch only. They are empty tensors when no batch was
          processed (no samples or ``local_epoch == 0``).
    """
    X_c_tensor, y_c_tensor, global_weights, learning_rate, local_epoch, experiment_bs, device = args
    model = MNISTNet().to(device)
    model.load_state_dict(global_weights)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    model.train()

    n_samples = X_c_tensor.size(0)
    # Defaults keep the return statement valid when the loops below never run.
    yb = y_c_tensor[:0]
    outputs = torch.zeros_like(yb)

    for _ in range(local_epoch):
        for start in range(0, n_samples, experiment_bs):
            end = start + experiment_bs
            xb = X_c_tensor[start:end]
            yb = y_c_tensor[start:end]
            outputs = model(xb)
            loss = binary_cross_entropy(outputs, yb)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # Only the final batch is reported, so threshold once after training and
    # detach so the caller does not keep the autograd graph alive.
    outputs = outputs.detach()
    preds = (outputs >= 0.5).float()

    # Return a deepcopy to avoid issues with state_dict references
    return copy.deepcopy(model.state_dict()), n_samples, yb, outputs, preds


In [6]:

def average_weights(w_list):
    """Return the unweighted element-wise mean of a list of state dicts.

    The keys (and their order) are taken from the first entry of ``w_list``;
    every other entry must provide the same keys.
    """
    n_models = len(w_list)
    return {
        name: sum(state[name] for state in w_list) / n_models
        for name in w_list[0]
    }

# 5.Configuration

In [7]:
# Number of honest clients; folders honest_0 .. honest_{N_HONEST-1} are loaded.
N_HONEST = 100
# Largest number of poisoned clients considered when the poison levels are built.
N_POISONED = 100
# Images are resized to IMG_SIZE x IMG_SIZE before flattening.
# NOTE: MNISTNet() as constructed in this notebook expects 28*28 = 784 inputs,
# so changing this value alone will break the forward pass.
IMG_SIZE = 28
# Learning rate of the Adam optimizer each client creates for its local training.
LEARNING_RATE = 0.001
# Number of federated rounds (local training + aggregation) per experiment.
ROUNDS = 100
# Passes over its own data that every client makes in one round.
LOCAL_EPOCHS = 1
# Mini-batch size used for local training.
BATCH_SIZE = 32
# NOTE(review): not referenced anywhere in the visible part of the notebook.
N_CLASSES = 2

# Set device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Training data root: one folder per client named "<honest|poison>_<idx>",
# each containing one sub-folder per label.
DATA_PATH = "../data/mnist_binary_poison/train"
# Test data root: contains the label sub-folders "0" and "1".
TEST_PATH = "../data/mnist_binary_poison/test"

# Root directory for saved models, CSV logs and plots.
RESULT_PATH = "./results"


# 6.Main

In [8]:
# For every target percentage, keep the smallest number of poisoned clients p
# whose share p / (p + N_HONEST), truncated to a whole percent, equals it.
# A target is removed from `poison_percent` once matched, so it is used once.
poison_percent = [0, 10, 20, 30, 40, 50]
poison_n_list = []
for p in range(N_POISONED + 1):
    percent = int((p / (p + N_HONEST)) * 100)
    if percent not in poison_percent:
        continue
    poison_percent.remove(percent)
    print(f"Adding poison n: {p} for percent: {percent}")
    poison_n_list.append(p)
print("Poison n list:", poison_n_list)

Adding poison n: 0 for percent: 0
Adding poison n: 12 for percent: 10
Adding poison n: 25 for percent: 20
Adding poison n: 43 for percent: 30
Adding poison n: 67 for percent: 40
Adding poison n: 100 for percent: 50
Poison n list: [0, 12, 25, 43, 67, 100]


In [9]:

# Batch sizes to sweep. Extend the list (e.g. BATCH_SIZE, BATCH_SIZE // 2, ..., 1)
# to repeat the whole experiment grid for several batch sizes.
BATCH_SIZE_list = [BATCH_SIZE]

# The test set is the same for every repetition / batch size / poison level,
# so it is read from disk once instead of once per experiment.
X_test, y_test = load_mnist_binary_test_data_flat(TEST_PATH, img_size=IMG_SIZE)
if len(X_test) == 0:
    raise ValueError("No valid test images found in the test directory!")
X_test_tensor = torch.tensor(X_test, dtype=torch.float32).to(device)
y_test_tensor = torch.tensor(y_test, dtype=torch.float32).unsqueeze(1).to(device)

# Client data is also identical across experiments: cache the tensors by
# client key so every client folder is decoded only once.
client_tensor_cache = {}


def load_client_tensors(client_type, idx):
    """Return the cached ``(X, y)`` device tensors of one client.

    Returns ``None`` when the client folder holds no images. The cached
    tensors are never modified in place (the per-round shuffle below creates
    new tensors), so every experiment starts from the same sample order.
    """
    key = f"{client_type}_{idx}"
    if key not in client_tensor_cache:
        X_c, y_c = get_client_data(DATA_PATH, client_type, idx, img_size=IMG_SIZE)
        if len(X_c) == 0:
            client_tensor_cache[key] = None
        else:
            client_tensor_cache[key] = (
                torch.tensor(X_c, dtype=torch.float32).to(device),
                torch.tensor(y_c, dtype=torch.float32).unsqueeze(1).to(device),
            )
    return client_tensor_cache[key]


for rev in range(3):

    # One random initialisation per repetition. A snapshot is kept because
    # global_model is trained in place: without restoring it, every poison
    # level would start from the previous level's trained weights.
    global_model = MNISTNet().to(device)
    initial_weights = copy.deepcopy(global_model.state_dict())

    for experiment_bs in BATCH_SIZE_list:

        name_save_path = f"FL_ModelN1L1_Batchsize{experiment_bs}_rev{rev}"

        for i_poisoned in poison_n_list:
            percent_poisoned = int((i_poisoned / (i_poisoned + N_HONEST)) * 100)

            # Create a clean output directory for this experiment
            NAME_SAVE_update_PATH = f"poisoned_{percent_poisoned}percent"
            save_path = os.path.join(RESULT_PATH, name_save_path, NAME_SAVE_update_PATH)
            if os.path.exists(save_path):
                shutil.rmtree(save_path)
            os.makedirs(save_path, exist_ok=True)
            save_all_path = os.path.join(RESULT_PATH, name_save_path, 'all')
            os.makedirs(save_all_path, exist_ok=True)

            # Restart from this repetition's initial weights so that
            # 'model_init.pt' really is the untrained model.
            global_model.load_state_dict(initial_weights)
            # NOTE: torch.save on the module pickles the whole object, so
            # loading it later requires the MNISTNet class to be importable.
            torch.save(global_model, os.path.join(save_path, 'model_init.pt'))
            global_weights = global_model.state_dict()

            records = []

            # Assemble this experiment's clients: all honest ones followed by
            # the first i_poisoned poisoned ones. Clients without data are skipped.
            client_specs = [("honest", i) for i in range(N_HONEST)]
            client_specs += [("poison", i) for i in range(i_poisoned)]
            X_train_tensor = {}
            y_train_tensor = {}
            for client_type, client_idx in client_specs:
                client_tensors = load_client_tensors(client_type, client_idx)
                if client_tensors is None:
                    continue
                client_key = f"{client_type}_{client_idx}"
                X_train_tensor[client_key], y_train_tensor[client_key] = client_tensors

            # Iterate over the clients that were actually loaded; indexing by
            # range(N_HONEST) would raise KeyError for a skipped client.
            client_keys = list(X_train_tensor)
            if not client_keys:
                raise ValueError("No training data found for any client!")

            # Training loop (`round_idx` avoids shadowing the builtin `round`)
            for round_idx in range(ROUNDS):

                # Reshuffle every client's samples for this round
                for client_key in client_keys:
                    perm = torch.randperm(X_train_tensor[client_key].size(0))
                    X_train_tensor[client_key] = X_train_tensor[client_key][perm]
                    y_train_tensor[client_key] = y_train_tensor[client_key][perm]

                worker_args = [
                    (
                        X_train_tensor[client_key],
                        y_train_tensor[client_key],
                        global_weights,
                        LEARNING_RATE,
                        LOCAL_EPOCHS,
                        experiment_bs,
                        device
                    )
                    for client_key in client_keys
                ]

                # Run the local trainings in parallel; executor.map preserves
                # the input order (honest clients first, then poisoned ones).
                with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
                    client_results = list(executor.map(train_local_worker, worker_args))

                local_weights = [result[0] for result in client_results]
                local_sizes = [result[1] for result in client_results]

                # Training metrics are computed on what the workers return,
                # i.e. the last mini-batch of every client. Detach so the
                # autograd graphs of the local models can be freed.
                train_y = torch.cat([result[2] for result in client_results])
                train_outputs = torch.cat([result[3] for result in client_results]).detach()
                train_preds = torch.cat([result[4] for result in client_results])
                loss_train = binary_cross_entropy(train_outputs, train_y).item()
                acc_train = (train_preds == train_y).float().mean().item()

                # Federated averaging (weighted by client data size)
                total_size = sum(local_sizes)
                new_weights = {}
                for key in global_weights.keys():
                    new_weights[key] = sum(w[key] * sz for w, sz in zip(local_weights, local_sizes)) / total_size
                global_weights = new_weights
                global_model.load_state_dict(global_weights)

                # Evaluate the aggregated model on the test set
                global_model.eval()
                with torch.no_grad():
                    outputs = global_model(X_test_tensor)
                    preds = (outputs > 0.5).float()
                    loss_test = binary_cross_entropy(outputs, y_test_tensor).item()
                    acc_test = (preds == y_test_tensor).float().mean().item()

                records.append({'poison_percent': percent_poisoned, 'round': round_idx,
                                'loss_train': loss_train, 'acc_train': acc_train,
                                'loss_test': loss_test, 'acc_test': acc_test})
                report_txt = f"poison_percent {percent_poisoned} "
                report_txt += f"round {round_idx}: "
                report_txt += f"loss_train={loss_train:.4f}, "
                report_txt += f"acc_train={acc_train:.4f}, "
                report_txt += f"loss_test={loss_test:.4f}, "
                report_txt += f"acc_test={acc_test:.4f}"
                print(report_txt)

            torch.save(global_model, os.path.join(save_path, 'model_last.pt'))

            # Save training log
            df = pd.DataFrame(records)
            save_name_path = os.path.join(save_path, f'{NAME_SAVE_update_PATH}.csv')
            df.to_csv(save_name_path, index=False)
            print(f"Training log saved to {save_name_path}")

            # Plot test loss and test accuracy per round, side by side
            fig, axes = plt.subplots(1, 2, figsize=(10, 4))
            panels = (
                (axes[0], 'loss_test', 'Test Loss', 'Loss'),
                (axes[1], 'acc_test', 'Test Accuracy', 'Accuracy'),
            )
            for ax, column, title, ylabel in panels:
                ax.plot(df['round'], df[column], marker='o')
                ax.set_title(title)
                ax.set_ylabel(ylabel)
                ax.set_xlabel('Epoch')
                ax.grid(True)
                ax.set_ylim(0, 1.1)

            fig.tight_layout()
            save_name_path = os.path.join(save_path, 'loss_accuracy.jpg')
            fig.savefig(save_name_path)

            # Second copy in the shared 'all' folder for quick comparison
            save_name_path = os.path.join(save_all_path, f'{NAME_SAVE_update_PATH}_latest.jpg')
            fig.savefig(save_name_path)
            # Close explicitly: many figures are created across experiments.
            plt.close(fig)


poison_percent 0 round 0: loss_train=0.6310, acc_train=0.8462, loss_test=0.6052, acc_test=0.9020
poison_percent 0 round 1: loss_train=0.5829, acc_train=0.9204, loss_test=0.5593, acc_test=0.9429
poison_percent 0 round 2: loss_train=0.5412, acc_train=0.9488, loss_test=0.5185, acc_test=0.9612
poison_percent 0 round 3: loss_train=0.5064, acc_train=0.9673, loss_test=0.4822, acc_test=0.9770
poison_percent 0 round 4: loss_train=0.4702, acc_train=0.9777, loss_test=0.4495, acc_test=0.9806
poison_percent 0 round 5: loss_train=0.4379, acc_train=0.9796, loss_test=0.4197, acc_test=0.9847
poison_percent 0 round 6: loss_train=0.4129, acc_train=0.9850, loss_test=0.3923, acc_test=0.9867
poison_percent 0 round 7: loss_train=0.3818, acc_train=0.9842, loss_test=0.3670, acc_test=0.9872
poison_percent 0 round 8: loss_train=0.3635, acc_train=0.9881, loss_test=0.3435, acc_test=0.9903
poison_percent 0 round 9: loss_train=0.3394, acc_train=0.9900, loss_test=0.3216, acc_test=0.9913
poison_percent 0 round 10: los