In [2]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import numpy as np
import copy
import random
import syft as sy
import statistics
import time

from sklearn.neighbors import NearestNeighbors
import numpy as np
from collections import Counter

import sys
sys.path.append('../')
from FLDataset.FLDataset import load_dataset
from FLDataset.FLDataset import getActualImgs

import warnings
warnings.filterwarnings('ignore')


# Define Arguments class
class Arguments:
    def __init__(self):
        self.images = 60000
        self.clients = 10
        self.rounds = 10
        self.local_batches = 64
        self.k_neighbors = 5
        self.C = 0.8
        self.drop_rate = 0.2
        self.torch_seed = 0
        self.log_interval = 10
        self.iid = 'iid'
        self.split_size = int(self.images / self.clients)
        self.samples = self.split_size / self.images
        self.use_cuda = torch.cuda.is_available()
        self.save_model = False

args = Arguments()
device = torch.device("cuda" if args.use_cuda else "cpu")

# Set seed for reproducibility
torch.manual_seed(args.torch_seed)
np.random.seed(args.torch_seed)
random.seed(args.torch_seed)

# Initialize hook and clients
hook = sy.TorchHook(torch)
clients = [sy.VirtualWorker(hook, id=f"client{i + 1}") for i in range(args.clients)]

# Load dataset and split into clients
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
global_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
global_test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)

# Split global dataset into client datasets
client_datasets = random_split(global_dataset, [args.split_size] * args.clients)

# Define the KNN class
class KNN:
    def __init__(self, k=5):
        self.k = k
        self.X_train = None
        self.y_train = None
        self.nn_model = None

    def fit(self, X_train, y_train):
        # Convert data to numpy arrays
        self.X_train = np.array(X_train)
        self.y_train = np.array(y_train)
        # Fit nearest neighbors model
        self.nn_model = NearestNeighbors(n_neighbors=self.k, metric='euclidean')
        self.nn_model.fit(self.X_train)

    def predict(self, X_test, batch_size=100):
        # Convert data to numpy arrays
        if isinstance(X_test, list):
            X_test = np.array(X_test)
        else:
            X_test = np.array(X_test)

        num_samples = X_test.shape[0]
        predictions = []

        # Process predictions in smaller batches
        for i in range(0, num_samples, batch_size):
            batch_end = min(i + batch_size, num_samples)
            X_test_batch = X_test[i:batch_end]

            # Find k-nearest neighbors for the batch
            distances, indices = self.nn_model.kneighbors(X_test_batch)
            
            # Predict the mode (most common label) among the k-nearest neighbors
            for index_list in indices:
                k_labels = self.y_train[index_list]
                counter = Counter(k_labels)
                most_common_label = counter.most_common(1)[0][0]
                predictions.append(most_common_label)

        return np.array(predictions)


# Initialize clients with datasets and models
for i, client in enumerate(clients):
    client.trainset = DataLoader(client_datasets[i], batch_size=args.local_batches, shuffle=True)
    client.model = KNN(k=args.k_neighbors)

# Create global test loader
global_test_loader = DataLoader(global_test_dataset, batch_size=args.local_batches, shuffle=False)

# Client update function
# def client_update(client):
#     train_loader = client.trainset
#     model = client.model
#     X_train, y_train = [], []
#     for data, target in train_loader:
#         data = data.view(data.shape[0], -1).numpy()
#         X_train.extend(data)
#         y_train.extend(target.numpy())
#     model.fit(X_train, y_train)
#     print(f"Client {client.id} trained local KNN model")

def client_update(args, device, client):
    train_loader = client.trainset
    model = client.model
    X_train, y_train = [], []
    
    # Collect training data and labels from the DataLoader
    for data, target in train_loader:
        # Convert data to a flat array format and append it to X_train
        data = data.view(data.shape[0], -1).numpy()
        X_train.extend(data)
        # Append labels to y_train
        y_train.extend(target.numpy())
    
    # Fit the local KNN model with the collected training data and labels
    model.fit(X_train, y_train)
    print(f"Client {client.id} trained local KNN model")
    
    # Iterate over rounds (epochs) and batches for logging purposes
    for epoch in range(1, args.rounds + 1):
        for batch_idx, (data, target) in enumerate(train_loader):
            # Move data and target to the appropriate device
            data, target = data.to(device), target.to(device)
            
            # Convert data to a flat array format suitable for model predictions
            X_test = data.view(data.shape[0], -1).cpu().numpy()
            y_test = target.cpu().numpy()
            
            # Predict using the client's KNN model
            pred = model.predict(X_test)
            
            # Calculate accuracy by comparing predicted labels with the actual labels
            accuracy = np.mean(pred == y_test) * 100
            
            # Log training progress at the specified interval
            if batch_idx % args.log_interval == 0:
                print('Client {} Train round: {} [{}/{} ({:.0f}%)] - Accuracy: {:.2f}%'.format(
                    client.id,
                    epoch, batch_idx * args.local_batches, len(train_loader) * args.local_batches, 
                    100. * batch_idx / len(train_loader),
                    accuracy))
    
    # Return the collected training data and labels for potential aggregation
    return np.array(X_train), np.array(y_train)


def test(model, test_loader, name):
    correct = 0
    total = 0
    X_test, y_test = [], []
    for data, target in test_loader:
        data = data.view(data.shape[0], -1).numpy()
        X_test.extend(data)
        y_test.extend(target.numpy())
    
    print(f"Testing {name} model with {len(X_test)} samples")
    
    # Measure prediction time
    start_time = time.time()
    y_pred = model.predict(X_test)
    prediction_time = time.time() - start_time
    print(f"Prediction time: {prediction_time:.2f} seconds")
    
    accuracy = (np.sum(y_pred == y_test) / len(y_test)) * 100
    print(f'{name} Accuracy: {accuracy:.2f}%')


def aggregate_models(clients):
    print("Aggregating models from clients")
    global_X_train, global_y_train = [], []
    for client in clients:
        train_loader = client.trainset
        model = client.model
        X_train, y_train = [], []
        for data, target in train_loader:
            data = data.view(data.shape[0], -1).numpy()
            X_train.extend(data)
            y_train.extend(target.numpy())
        global_X_train.extend(X_train)
        global_y_train.extend(y_train)
    print(f"Aggregated data size: {len(global_X_train)} samples")
    # Create a new KNN model and fit the aggregated data
    global_model = KNN(k=args.k_neighbors)
    global_model.fit(global_X_train, global_y_train)
    print("Global model updated with aggregated data")
    return global_model


# Main training loop
global_model = KNN(k=args.k_neighbors)
for fed_round in range(args.rounds):
    # Select a subset of clients for the current round
    num_selected_clients = int(args.C * args.clients)
    selected_clients = random.sample(clients, num_selected_clients)
    for client in selected_clients:
        client_update(client)
    # Aggregate client models and update the global model
    global_model = aggregate_models(selected_clients)
    # Test the global model
    test(global_model, global_test_loader, "Global")

    # Share the updated global model with all clients
    for client in clients:
        client.model = copy.deepcopy(global_model)

if args.save_model:
    torch.save(global_model, "KNN.pt")




TypeError: client_update() missing 2 required positional arguments: 'device' and 'client'