# In [None]:
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
import random
import torch
from torch import nn, optim
from torchvision import datasets, transforms
import torch.nn.functional as F
import pickle
from torch.utils.data import DataLoader, Subset

# Test accuracy (in percent) keyed by 1-based epoch number; filled in by the
# evaluation loop and pickled at the end of the script.
accuracy_dict = {}

# Experiment settings: client count, per-client train/validation fractions,
# and number of passes of the train/evaluate loop.
num_clients = 5    # number of simulated clients
train_split = 0.8  # fraction of the training set each client trains on
test_split = 0.2   # NOTE(review): not referenced anywhere in the visible code
num_epochs = 10    # iterations of the train/evaluate loop

# Preprocessing pipeline attached to the MNIST datasets: tensor conversion
# followed by normalization with mean 0.1307 and std 0.3081 (presumably the
# global MNIST pixel statistics).
_mnist_mean, _mnist_std = (0.1307,), (0.3081,)
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(_mnist_mean, _mnist_std),
    ]
)

# Fetch MNIST into ../data (downloading when missing) and expose the images
# and labels of both splits as NumPy arrays.
# NOTE(review): the code below reads ``.data``/``.targets`` directly. As far as
# I know torchvision applies ``transform`` only in ``Dataset.__getitem__``, so
# x_train/x_test are raw, unnormalized pixels -- confirm. If so, Euclidean k-NN
# should be unaffected, since the normalization is presumably a single positive
# scale plus shift, which preserves neighbour ordering.
data_root = '../data'
train_dataset = datasets.MNIST(data_root, train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(data_root, train=False, download=True, transform=transform)

x_train = train_dataset.data.numpy()
y_train = train_dataset.targets.numpy()
x_test = test_dataset.data.numpy()
y_test = test_dataset.targets.numpy()

# Give each client its own random train/validation split of the full training
# set: the first ``train_split`` fraction of a fresh random permutation is the
# client's training data, the remainder its validation data.
#
# NOTE: this partition is IID, not non-IID. Every client shuffles *all* training
# indices independently, so each client's data is a uniform random sample of the
# same pool. With num_clients * train_split > 1, different clients' training
# subsets necessarily overlap. A genuinely non-IID setup would need per-client
# label skew or disjoint shards.
n_samples = len(x_train)
split = int(train_split * n_samples)  # loop-invariant: compute once
client_data = []
for _ in range(num_clients):
    # np.random.permutation(n) permutes a fresh np.arange(n) using the global
    # RNG, the same draw as the former arange + in-place shuffle.
    indices = np.random.permutation(n_samples)
    train_indices, val_indices = indices[:split], indices[split:]
    client_data.append({'x_train': x_train[train_indices], 'y_train': y_train[train_indices],
                        'x_val': x_train[val_indices], 'y_val': y_train[val_indices]})

# Fit one k-NN classifier per client and evaluate their majority-vote ensemble
# on the test set.
#
# k-NN has no iterative training: ``fit`` on the same data always yields the
# same model, so there is nothing to repeat per epoch. The previous version had
# three problems:
#   - it refit every model each epoch;
#   - it appended to ``models`` without clearing it, so epoch k voted with 5*k
#     duplicated models (no new information, linearly growing cost);
#   - it called ``predict`` once per test sample per model.
# Here the models are fit and predicted once, batched. The epoch loop at the
# end only reports the result, so the printed output and the ``accuracy_dict``
# keys keep their original shape.
n_neighbors = 5

models = []
for client in client_data:
    # Local name avoids clobbering the module-level x_train/y_train.
    client_x = client['x_train']
    model = KNeighborsClassifier(n_neighbors=n_neighbors)
    # Flatten each image into a single feature vector for scikit-learn.
    model.fit(client_x.reshape(len(client_x), -1), client['y_train'])
    models.append(model)

if not models:
    raise ValueError('num_clients must be at least 1 to form an ensemble')

# One batched predict call per model: rows = models, columns = test samples.
x_test_flat = x_test.reshape(len(x_test), -1)
votes = np.stack([model.predict(x_test_flat) for model in models])

# Majority vote per test sample. np.bincount(...).argmax() breaks ties in
# favour of the smallest label, making the outcome deterministic (the old
# ``max(set(...), key=list.count)`` depended on set iteration order).
# Assumes non-negative integer labels, as bincount requires (MNIST digits).
consensus = np.array([np.bincount(column).argmax() for column in votes.T])

total_correct = int(np.count_nonzero(consensus == y_test))
accuracy = total_correct / len(x_test)

for epoch in range(num_epochs):
    print('Epoch %d Test accuracy: %.2f%%' % (epoch+1, accuracy * 100))
    accuracy_dict[epoch+1] = accuracy*100

# Persist the per-epoch accuracy dict with the highest available pickle protocol.
# NOTE(review): the "10_5" in the filename presumably encodes num_epochs=10 and
# num_clients (or n_neighbors) = 5. It is hard-coded, so it will not follow
# config changes -- confirm the intent and consider deriving it from the config.
output_path = 'KNN_10_5.pickle'
with open(output_path, mode='wb') as handle:
    pickle.dump(accuracy_dict, handle, pickle.HIGHEST_PROTOCOL)