In [1]:
methods_dict = {1:"Missing_Data", 2:"Outliers", 3:"Labeling_Errors", 4:"Feature_Noise"} 
def data_corruption(method, data, labels, corrupt_p):
    corruption_method = methods_dict[method]
    if corrupt_p <= 0 or corrupt_p >= 1:
        print("Please choose a valid value for the corruption parameter (positive, above 0 and less than 1.)")
        return None
    
    
    # Go through each image, and create a randomized mask which sets pixels to a value of 0.
    data_corrupted = []
    if corruption_method == "Missing_Data":
        for image in data:
            mask = np.zeros(image.shape[0], dtype=int)
            mask[int(image.shape[0]*corrupt_p):] = 1
            random.shuffle(mask)
            mask = mask.astype(bool)
            corrupted_image = np.where(mask == False, 0, image)
            data_corrupted.append(corrupted_image)
        
        return data_corrupted
    
    if corruption_method == "Outliers":
        print("not yet implemented")
        
    if corruption_method == "Labeling_Errors":
        if len(labels) > 0:
            label_names = np.unique(labels)
            mask = np.zeros(len(labels), dtype=int)
            mask[int(len(labels)*corrupt_p):] = 1
            random.shuffle(mask)
            mask = mask.astype(bool)
            corrupted_labels = np.where(mask == False, "needs_change", labels)
            for i in range(len(labels)):
                if corrupted_labels[i] == "needs_change":
                    options = np.delete(label_names, np.where(np.unique(labels) == labels[i])) ## Using np.unique to find the set of labels.
                    corrupted_labels[i] = random.choice(options)

            corrupted_labels = corrupted_labels.astype(type(labels[i]))
            return corrupted_labels
        
        
        else:
            print("Please use an actual list for the labels")
        
    if corruption_method == "Feature_Noise":
        print("not yet implemented")
    

In [2]:
import torch
import torch.nn as nn
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import accuracy_score
from torch.utils.data import SubsetRandomSampler
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import accuracy_score
import numpy as np
import random
import pickle

batch_size = 32
num_epochs = 3
corrupt = True

def FedNB(cp):
    accuracy_dict = dict()
    # Set random seed for reproducibility
    torch.manual_seed(42)

    # Load the CIFAR10 dataset
    train_dataset = CIFAR10(root='./data', train=True, transform=ToTensor(), download=True)
    test_dataset = CIFAR10(root='./data', train=False, transform=ToTensor())

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

    # Convert the image data to numpy arrays
    train_images = train_dataset.data.reshape(-1, 32*32*3)
    train_labels = np.array(train_dataset.targets)
    test_images = test_dataset.data.reshape(-1, 32*32*3)
    test_labels = np.array(test_dataset.targets)

    # Set the number of clients and local epochs
    num_clients = 10
    num_local_epochs = 3

    # Define the global model
    global_model = GaussianNB()

    # Create a list to store local models
    local_models = []

    # Define the train_loaders for each client
    train_loaders = []
    train_dataset_size = len(train_dataset)
    train_indices = list(range(train_dataset_size))
    subset_size = train_dataset_size // num_clients

    for client_idx in range(num_clients):
        # Create a random subset of the training dataset for each client
        subset_indices = train_indices[client_idx * subset_size: (client_idx + 1) * subset_size]
        train_sampler = SubsetRandomSampler(subset_indices)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            sampler=train_sampler,
            num_workers=2,
            pin_memory=True
        )
        train_loaders.append(train_loader)

    # Perform federated learning
    for epoch in range(num_epochs):
        # Perform local training for a few epochs
        local_models = []

        for client_idx in range(num_clients):
            local_model = GaussianNB()

            for local_epoch in range(num_local_epochs):
                # Get the local training data
                train_images, train_labels = next(iter(train_loaders[client_idx]))
                train_images = train_images.view(train_images.size(0), -1)
                if corrupt:
                    train_labels = data_corruption(3, train_images, train_labels.numpy(), cp)
                else:
                    train_labels = train_labels.numpy()

                # Fit the local model
                local_model.fit(train_images, train_labels)

            local_models.append(local_model)

        # Aggregate the parameters from local models
        global_model = GaussianNB()
        global_model.fit(train_images, train_labels)
        for key in global_model.get_params().keys():
            param_values = [getattr(model, key) for model in local_models if getattr(model, key) is not None]
            if param_values:
                param_values = [float(param_value) for param_value in param_values]
                mean_param_value = np.mean(param_values, axis=0)
                setattr(global_model, key, mean_param_value)

        # Evaluate the global model on the test set
        global_predictions = global_model.predict(test_images)
        accuracy = accuracy_score(test_labels, global_predictions)
        print(f"Epoch {epoch+1} - Global Model Accuracy: {accuracy}")
        accuracy_dict[epoch+1] = accuracy
    return accuracy_dict


In [3]:
corrupt_list = [0.95, 0.9375, 0.925, 0.9125, 0.90, 0.875, 0.85, 0.65, 0.50, 0.35, 0.20, 0.05]

for corrupt_par in corrupt_list:
    results = FedNB(corrupt_par)

    with open(f'NB_MNIST_3_C5_{corrupt_par}_Labels', 'wb') as handle:
        pickle.dump(results, handle, protocol=pickle.HIGHEST_PROTOCOL)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data\cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:08<00:00, 20615270.91it/s]


Extracting ./data\cifar-10-python.tar.gz to ./data
Epoch 1 - Global Model Accuracy: 0.1114


KeyboardInterrupt: 