<a href="https://colab.research.google.com/github/artsasse/fedkan/blob/main/PyTorch_Federated_Spline_KAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Instalações

In [None]:
pip install pykan

Collecting pykan
  Downloading pykan-0.2.4-py3-none-any.whl.metadata (14 kB)
Downloading pykan-0.2.4-py3-none-any.whl (95 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m95.2/95.2 kB[0m [31m997.3 kB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pykan
Successfully installed pykan-0.2.4


# Bibliotecas

In [None]:
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
from sklearn.metrics import accuracy_score
import tensorflow as tf
from tensorflow.keras.utils import to_categorical
from kan import *

# Funções

In [None]:
def create_clients(image_list, label_list, num_clients=10, initial='client'):
    '''Randomly shard a labelled dataset across simulated federated clients.

    return: a dictionary mapping each client's name to its data shard, a list
            of (image, label) tuples.
    args:
        image_list: a sequence of training images (e.g. numpy arrays)
        label_list: a sequence of labels aligned index-by-index with image_list
        num_clients: number of federated members (clients); must be >= 1
        initial: the clients' name prefix, e.g. 'client' -> client_1, client_2, ...
    raises:
        ValueError: if the image/label lengths differ, num_clients < 1, or
                    there are fewer samples than clients.

    All shards have the same size; the len(data) % num_clients samples left
    over after shuffling are dropped. Shuffling uses the global `random`
    module state, so seed `random` for reproducible shards.
    '''
    if len(image_list) != len(label_list):
        raise ValueError('image_list and label_list must have the same length, '
                         'got {} and {}'.format(len(image_list), len(label_list)))
    if num_clients < 1:
        raise ValueError('num_clients must be >= 1, got {}'.format(num_clients))

    # Pair each image with its label, then randomize the order
    data = list(zip(image_list, label_list))
    if len(data) < num_clients:
        raise ValueError('cannot split {} samples across {} clients'.format(len(data), num_clients))
    random.shuffle(data)

    # Consecutive equal-sized slices, one per client (remainder is dropped)
    shard_size = len(data) // num_clients
    return {
        '{}_{}'.format(initial, i + 1): data[i * shard_size:(i + 1) * shard_size]
        for i in range(num_clients)
    }


def batch_data(data_shard, bs=32):
    '''Wrap a client's data shard in a shuffling DataLoader.

    args:
        data_shard: a non-empty list of (image, label) pairs, as produced by create_clients
        bs: batch size
    return: a DataLoader over float32 (images, labels) tensors that reshuffles every epoch
    raises:
        ValueError: if data_shard is empty
    '''
    if len(data_shard) == 0:
        raise ValueError('data_shard must not be empty')
    data, label = zip(*data_shard)
    # Stack into single ndarrays first: torch.tensor() on a sequence of ndarrays
    # is extremely slow (PyTorch emits a warning about it).
    data = torch.tensor(np.asarray(data), dtype=torch.float32)
    label = torch.tensor(np.asarray(label), dtype=torch.float32)
    return DataLoader(TensorDataset(data, label), batch_size=bs, shuffle=True)


class SimpleMLP(nn.Module):
    '''Three-layer fully connected classifier: input -> hidden -> hidden -> output.

    forward() returns raw logits by default, which is what nn.CrossEntropyLoss
    expects (it applies log-softmax internally). Pass apply_softmax=True to get
    class probabilities instead.

    args:
        input_dim: number of input features
        output_dim: number of classes
        hidden_dim: width of both hidden layers (default 200)
        apply_softmax: if True, forward() returns softmax probabilities over dim 1
    '''

    def __init__(self, input_dim, output_dim, hidden_dim=200, apply_softmax=False):
        super().__init__()
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, hidden_dim)
        self.layer3 = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)
        self.apply_softmax = apply_softmax

    def forward(self, x):
        '''Map x (assumed 2-D: (batch, input_dim)) to (batch, output_dim) logits or probabilities.'''
        x = self.relu(self.layer1(x))
        x = self.relu(self.layer2(x))
        logits = self.layer3(x)
        # Applying softmax here and again inside CrossEntropyLoss would squash
        # the gradients, so only normalize when explicitly requested.
        return self.softmax(logits) if self.apply_softmax else logits


def weight_scaling_factor(clients_trn_data, client_name):
    '''Return a client's FedAvg weight: its share of all training samples.

    args:
        clients_trn_data: dict mapping client name -> DataLoader (anything whose
                          `.dataset` supports len())
        client_name: key of the client whose weight is requested
    return: len(client's dataset) / total samples across all clients, a float in [0, 1]
    raises:
        ValueError: if the clients hold no samples at all
        KeyError: if client_name is not a key of clients_trn_data
    '''
    global_count = sum(len(loader.dataset) for loader in clients_trn_data.values())
    if global_count == 0:
        raise ValueError('clients_trn_data contains no samples')
    local_count = len(clients_trn_data[client_name].dataset)
    return local_count / global_count


def scale_model_weights(weight, scalar):
    '''Multiply every tensor in a model's weight list by `scalar`.

    args:
        weight: iterable of tensors, e.g. list(model.parameters())
        scalar: scaling factor, typically the client's share from weight_scaling_factor
    return: a new list of scaled tensors, detached from the autograd graph
    '''
    # The scaled copies are only averaged as values. Without detach(), scaling an
    # nn.Parameter records autograd history that keeps referencing the local model.
    return [scalar * w.detach() for w in weight]


def sum_scaled_weights(scaled_weight_list):
    '''Sum per-client weight lists element-wise, layer by layer.

    Each client's weights are expected to be pre-multiplied by its sample
    share, so the layer-wise sum equals the sample-weighted average (FedAvg).

    args:
        scaled_weight_list: one list of tensors per client, all in the same layer order
    return: a list holding one summed tensor per layer
    '''
    return [
        torch.stack(per_client_tensors).sum(dim=0)
        for per_client_tensors in zip(*scaled_weight_list)
    ]


def test_model(X_test, Y_test, model, comm_round):
    criterion = nn.CrossEntropyLoss()
    with torch.no_grad():
        logits = model(X_test)
        loss = criterion(logits, torch.argmax(Y_test, dim=1))
        acc = accuracy_score(torch.argmax(logits, axis=1).numpy(), torch.argmax(Y_test, axis=1).numpy())
    print('comm_round: {} | global_acc: {:.3%} | global_loss: {}'.format(comm_round, acc, loss.item()))
    return acc, loss.item()

# Execução

In [None]:
# Reproducibility: seed every RNG used below (client sharding and client order
# use `random`; weight init and DataLoader shuffling use torch).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Load MNIST dataset
# Temporarily using the dataset from Keras, to maintain uniformity with the baseline
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()

# Preprocess the data: flatten 28x28 images to 784 features in [0, 1]; one-hot labels
X_train = X_train.reshape(-1, 28 * 28) / 255.0
X_test = X_test.reshape(-1, 28 * 28) / 255.0
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

# Create clients
clients = create_clients(X_train, y_train, num_clients=10, initial='client')

# Process and batch the training data for each client
clients_batched = {client_name: batch_data(data) for client_name, data in clients.items()}

# Convert the full splits to PyTorch tensors (done after client batching, which
# consumes the NumPy arrays)
X_train = torch.tensor(X_train, dtype=torch.float32)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.float32)
y_test = torch.tensor(y_test, dtype=torch.float32)

# The whole test set as a single batch; used for evaluation below and by later cells
test_batched = DataLoader(TensorDataset(X_test, y_test), batch_size=len(y_test), shuffle=False)

# Training configuration
comms_round = 3  # number of communication rounds
lr = 0.01        # local SGD learning rate
criterion = nn.CrossEntropyLoss()

# Initialize global model
global_model = SimpleMLP(784, 10)

# Start global training loop
print("Federated Model Results:")
for comm_round in range(1, comms_round + 1):

    # Each client's local weights, pre-scaled by that client's sample share
    scaled_local_weight_list = []

    # Visit the clients in a random order each round
    client_names = list(clients_batched.keys())
    random.shuffle(client_names)

    for client in client_names:
        # Every client starts the round from the current global weights
        local_model = SimpleMLP(784, 10)
        local_model.load_state_dict(global_model.state_dict())

        # Fresh optimizer per local model, so momentum state is not shared between clients
        optimizer = optim.SGD(local_model.parameters(), lr=lr, momentum=0.9)

        # One local epoch over the client's shard
        local_model.train()
        for X_batch, y_batch in clients_batched[client]:
            optimizer.zero_grad()
            loss = criterion(local_model(X_batch), torch.argmax(y_batch, dim=1))
            loss.backward()
            optimizer.step()

        # Scale detached copies of the weights (no autograd history) and collect them
        scaling_factor = weight_scaling_factor(clients_batched, client)
        scaled_weights = scale_model_weights(
            [param.detach() for param in local_model.parameters()], scaling_factor)
        scaled_local_weight_list.append(scaled_weights)

    # Summing the pre-scaled weights yields the sample-weighted average (FedAvg)
    average_weights = sum_scaled_weights(scaled_local_weight_list)

    # Update the global model in place
    with torch.no_grad():
        for global_param, avg_weight in zip(global_model.parameters(), average_weights):
            global_param.copy_(avg_weight)

    # Test global model and print out metrics after each communication round
    for X_test_batch, Y_test_batch in test_batched:
        global_acc, global_loss = test_model(X_test_batch, Y_test_batch, global_model, comm_round)




# Centralized baseline: the same architecture trained on the pooled training set.
# Prepare the SGD dataset
SGD_dataset = DataLoader(TensorDataset(X_train, y_train), shuffle=True, batch_size=320)

# Initialize the SGD model
SGD_model = SimpleMLP(784, 10)
lr = 0.01
optimizer = optim.SGD(SGD_model.parameters(), lr=lr, momentum=0.9)
criterion = nn.CrossEntropyLoss()

# Train the SGD model for 3 epochs
SGD_model.train()
for epoch in range(3):
    for X_batch, y_batch in SGD_dataset:
        optimizer.zero_grad()
        loss = criterion(SGD_model(X_batch), torch.argmax(y_batch, dim=1))
        loss.backward()
        optimizer.step()

# Test the centralized model on the full test set and print out metrics.
# Evaluate the test tensors directly so this section does not depend on a
# separately built test DataLoader. test_model already disables autograd.
SGD_model.eval()
print("Centralized Model Results:")
SGD_acc, SGD_loss = test_model(X_test, y_test, SGD_model, 1)

checkpoint directory created: ./model
saving model version 0.0


 train_loss: 8.14e+01 | reg: 3.23e+05 | train_acc: 9.01e-01 | test_acc: 9.01e-01 |: 100%|█| 3/3 [03:

saving model version 0.1





'\n# Prepare the SGD dataset\nSGD_dataset = DataLoader(TensorDataset(X_train, y_train), shuffle=True, batch_size=320)\n\n# Initialize the SGD model\nSGD_model = SimpleMLP(784, 10)\nlr = 0.01\noptimizer = optim.SGD(SGD_model.parameters(), lr=lr, momentum=0.9)\n\n# Train the SGD model\nSGD_model.train()\nfor epoch in range(3):\n    for X_batch, y_batch in SGD_dataset:\n        optimizer.zero_grad()\n        y_pred = SGD_model(X_batch)\n        loss = nn.CrossEntropyLoss()(y_pred, torch.argmax(y_batch, dim=1))\n        loss.backward()\n        optimizer.step()\n\n# Test the SGD global model and print out metrics\nSGD_model.eval()\nwith torch.no_grad():\n    print("Centralized Model Results:")\n    for X_test_batch, Y_test_batch in test_batched:\n        SGD_acc, SGD_loss = test_model(X_test_batch, Y_test_batch, SGD_model, 1)\n'

# Toy Example com PyKAN

In [None]:
# Load MNIST dataset
# Temporarily using the dataset from Keras, to maintain uniformity with the baseline
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()

# Preprocess the data: flatten 28x28 images to 784 features in [0, 1]; one-hot labels
X_train = X_train.reshape(-1, 28 * 28) / 255.0
X_test = X_test.reshape(-1, 28 * 28) / 255.0
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

# Convert data to float32 PyTorch tensors, as consumed by KAN.fit
X_train = torch.tensor(X_train, dtype=torch.float32)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.float32)
y_test = torch.tensor(y_test, dtype=torch.float32)

dataset = {
    'train_input': X_train,
    'test_input': X_test,
    'train_label': y_train,
    'test_label': y_test,
}

n_features = X_train.shape[1]
n_classes = y_train.shape[1]

model = KAN(width=[n_features, 3, n_classes], grid=3, k=3, seed=42)


def _split_accuracy(split):
    '''Fraction of samples in `split` ('train' or 'test') whose argmax prediction
    matches the argmax of the one-hot label. Returns a 0-dim float32 tensor.'''
    with torch.no_grad():
        predicted = torch.argmax(model(dataset[split + '_input']), dim=1)
        # dim=1 is essential: without it argmax flattens the whole label matrix
        # into one index, and accuracy collapses to ~0.
        actual = torch.argmax(dataset[split + '_label'], dim=1)
        return torch.mean((predicted == actual).type(torch.float32))


# The function names are used as metric keys in the fit results and display_metrics
def train_acc():
    return _split_accuracy('train')

def test_acc():
    return _split_accuracy('test')

model.fit(dataset, steps=3, lamb=0.005, batch=1024, loss_fn = nn.CrossEntropyLoss(), metrics=[train_acc, test_acc], display_metrics=['train_loss', 'reg', 'train_acc', 'test_acc'])


checkpoint directory created: ./model
saving model version 0.0


 train_loss: 8.14e+01 | reg: 3.23e+05 | train_acc: 0.00e+00 | test_acc: 4.00e-04 |: 100%|█| 3/3 [03:

saving model version 0.1





{'train_loss': [array(1.5174258, dtype=float32),
  array(1.5174273, dtype=float32),
  array(81.4046, dtype=float32)],
 'test_loss': [array(79.32478, dtype=float32),
  array(1.5166861, dtype=float32),
  array(87.555984, dtype=float32)],
 'reg': [array(2597.382, dtype=float32),
  array(291.35287, dtype=float32),
  array(322988.97, dtype=float32)],
 'train_acc': [0.000733333348762244, 0.0, 0.0],
 'test_acc': [0.00039999998989515007, 0.0, 0.00039999998989515007]}