In [13]:
import torch

# Default to the CPU and switch to the GPU only when CUDA reports as available.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using {device}")

Using cpu


In [14]:
from flex.datasets import load
from torchvision import transforms

# Load the federated EMNIST "digits" split together with its held-out test set.
flex_dataset, test_data = load(
    "federated_emnist",
    return_test=True,
    split="digits",
)

# Per-sample preprocessing shared by training and evaluation: convert to a
# tensor, then normalize the single channel with mean 0.5 and std 0.5.
mnist_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

[36m[sultan]: md5 -q ./emnist-digits.mat;[0m
[01;31m[sultan]: Unable to run 'md5 -q ./emnist-digits.mat;'[0m
[01;31m[sultan]: --{ TRACEBACK }----------------------------------------------------------------------------------------------------[0m
[01;31m[sultan]: | NoneType: None[0m
[01;31m[sultan]: | [0m
[01;31m[sultan]: -------------------------------------------------------------------------------------------------------------------[0m
[01;31m[sultan]: --{ STDERR }-------------------------------------------------------------------------------------------------------[0m
[01;31m[sultan]: | /bin/sh: 1: md5: not found[0m
[01;31m[sultan]: -------------------------------------------------------------------------------------------------------------------[0m
[33m[sultan]: The following are additional information that can be used to debug this exception.[0m
[33m[sultan]: The following is the context used to run:[0m
[33m[sultan]: 	 - cwd: None[0m
[33m[sultan]: 	 - sudo:

In [15]:
import torch.nn as nn
import torch.nn.functional as F

from flex.pool import init_server_model
from flex.model import FlexModel

from flexBlock.pool import PoFLBlockchainPool


class SimpleNet(nn.Module):
    """Small fully connected classifier: flatten -> 784x128 linear -> ReLU -> linear.

    Args:
        num_classes: Width of the output layer (defaults to 10).

    ``forward`` returns log-probabilities (``log_softmax`` over dim 1), not
    raw logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Submodule names and registration order determine the state_dict
        # keys and their order, so they are kept stable.
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(in_features=28 * 28, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=num_classes)

    def forward(self, x):
        """Return per-class log-probabilities for the batch ``x``."""
        hidden = F.relu(self.fc1(self.flatten(x)))
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)


@init_server_model
def build_server_model():
    """Create the FlexModel that holds the server's network and training setup.

    Returns:
        A ``FlexModel`` with these entries:
        ``"model"`` (a fresh ``SimpleNet``), ``"criterion"`` (the loss module),
        ``"optimizer_func"`` (optimizer class, instantiated later by the
        training step) and ``"optimizer_kwargs"`` (extra optimizer arguments).
    """
    server_flex_model = FlexModel()

    server_flex_model["model"] = SimpleNet()
    # Required to store this for later stages of the FL training process.
    # SimpleNet.forward already ends in log_softmax, so the matching loss is
    # NLLLoss. CrossEntropyLoss expects raw logits and would apply log-softmax
    # a second time (same loss value, redundant work, misleading pairing).
    server_flex_model["criterion"] = torch.nn.NLLLoss()
    server_flex_model["optimizer_func"] = torch.optim.Adam
    server_flex_model["optimizer_kwargs"] = {}
    return server_flex_model

# Build the pool from the federated dataset, initialising it with
# build_server_model and requesting 10 miners.
pool = PoFLBlockchainPool(
    fed_dataset=flex_dataset,
    init_func=build_server_model,
    number_of_miners=10,
)

# Keep handles to the role-specific sub-pools used by the following cells.
clients = pool.clients
servers = pool.servers
aggregators = pool.aggregators

print(
    f"Number of nodes in the pool {len(pool)}: {len(servers)} miners plus {len(clients)} clients. The server is also an aggregator"
)

Number of nodes in the pool 3589: 10 miners plus 3579 clients. The server is also an aggregator


In [16]:
import copy

from flexBlock.pool import deploy_miner_model

@deploy_miner_model
def copy_server_model_to_clients(server_flex_model: FlexModel):
    """Return an independent deep copy of the given server FlexModel.

    The copy shares no objects with ``server_flex_model``, so changes made to
    the returned model do not touch the original.
    """
    replica = copy.deepcopy(server_flex_model)
    return replica


# Initial deployment: apply copy_server_model_to_clients from the servers pool
# towards the clients pool before any training round runs.
servers.map(copy_server_model_to_clients, clients)

In [17]:
from flex.data import Dataset
from torch.utils.data import DataLoader


def train(
    client_flex_model: FlexModel,
    client_data: Dataset,
    epochs: int = 1,
    batch_size: int = 20,
):
    """Train a client's local model in place on that client's own data.

    Args:
        client_flex_model: FlexModel providing ``"model"``, ``"criterion"``,
            ``"optimizer_func"`` and ``"optimizer_kwargs"``.
        client_data: The client's dataset; it is converted with
            ``to_torchvision_dataset`` using ``mnist_transforms``.
        epochs: Number of full passes over ``client_data`` (default 1, the
            value that used to be hard-coded).
        batch_size: Mini-batch size for the DataLoader (default 20, the value
            that used to be hard-coded).

    Returns:
        None. The model stored in ``client_flex_model["model"]`` is updated
        in place and is left on ``device``.
    """
    train_dataset = client_data.to_torchvision_dataset(transform=mnist_transforms)
    client_dataloader = DataLoader(train_dataset, batch_size=batch_size)
    criterion = client_flex_model["criterion"]

    # Place the model on the target device *before* building the optimizer,
    # so the optimizer is constructed over parameters that already live where
    # they will be used (as recommended by the torch.optim documentation).
    model = client_flex_model["model"].to(device)
    model.train()
    optimizer = client_flex_model["optimizer_func"](
        model.parameters(), **client_flex_model["optimizer_kwargs"]
    )

    for _ in range(epochs):
        for imgs, labels in client_dataloader:
            imgs, labels = imgs.to(device), labels.to(device)
            optimizer.zero_grad()
            pred = model(imgs)
            loss = criterion(pred, labels)
            loss.backward()
            optimizer.step()

In [18]:
from flexBlock.pool import send_weights_to_miner

@send_weights_to_miner
def get_clients_weights(client_flex_model: FlexModel):
    """Return the client model's state_dict tensors as a list, in key order."""
    state = client_flex_model["model"].state_dict()
    return list(state.values())

In [19]:
from flex.pool import set_aggregated_weights


@set_aggregated_weights
def set_agreggated_weights_to_server(server_flex_model: FlexModel, aggregated_weights):
    """Overwrite the server model's tensors with the aggregated ones.

    ``aggregated_weights`` is paired positionally with the entries of the
    model's state_dict, so it must follow the same order. Pairing stops at
    the shorter of the two sequences.
    """
    with torch.no_grad():
        state = server_flex_model["model"].state_dict()
        for current, new in zip(state.values(), aggregated_weights):
            # In-place copy keeps the model's existing tensors and only
            # replaces their contents.
            current.copy_(new)

In [20]:
def evaluate_global_model(server_flex_model: FlexModel, test_data: Dataset):
    """Measure the accuracy of the server model on ``test_data``.

    Args:
        server_flex_model: FlexModel whose ``"model"`` entry is evaluated.
        test_data: Dataset converted with ``to_torchvision_dataset`` using
            ``mnist_transforms``.

    Returns:
        Fraction of samples whose highest-scoring class equals the label.

    Raises:
        ZeroDivisionError: If ``test_data`` yields no samples.
    """
    model = server_flex_model["model"]
    model.eval()
    model = model.to(device)

    # Get the test data as a torchvision dataset and batch it.
    loader = DataLoader(
        test_data.to_torchvision_dataset(transform=mnist_transforms),
        batch_size=256,
        shuffle=True,
        pin_memory=False,
    )

    hits = 0
    seen = 0
    with torch.no_grad():
        for data, target in loader:
            seen += target.size(0)
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Column index of the largest score in each row is the prediction.
            pred = output.data.max(1, keepdim=True)[1]
            matches = pred.eq(target.data.view_as(pred))
            hits += matches.long().cpu().sum().item()

    return hits / seen

In [21]:
from flex.pool import fed_avg

def train_n_rounds(n_rounds):
    """Run ``n_rounds`` federated rounds and print the test accuracy of each.

    Every round deploys the server model to the clients, trains the clients,
    collects and aggregates their weights, writes the aggregate back to the
    servers and evaluates the result.
    """
    for round_number in range(1, n_rounds + 1):
        print(f"\nRunning round: {round_number} of {n_rounds}")

        # 1. Deploy the server model to the clients.
        pool.servers.map(copy_server_model_to_clients, clients)

        # 2. Each client trains its local model.
        clients.map(train)

        # 3. The aggregators collect the clients' weights and aggregate them.
        pool.aggregators.map(get_clients_weights, clients)
        pool.aggregate(fed_avg)

        # 4. The aggregators send the aggregated weights to the servers.
        pool.aggregators.map(set_agreggated_weights_to_server, pool.servers)

        # 5. Evaluate; only the first returned metric is reported.
        # NOTE(review): the module-level `test_data` loaded earlier is never
        # passed here; confirm what `map` supplies as the second argument of
        # `evaluate_global_model`.
        metrics = pool.servers.map(evaluate_global_model)
        acc = metrics[0]
        print(f"Server: Test acc: {acc:.4f}")

In [22]:
# Run a single federated round.
train_n_rounds(n_rounds=1)


Running round: 1 of 1


KeyboardInterrupt: 