In [50]:
import sys
import numpy as np
import random
import pickle
import os

import torch
from torch import nn
from torch.utils.data import DataLoader

import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR10

import models

sys.path.append("../../../FedLab/")

from fedlab.utils.dataset import functional as dataF
from fedlab.utils import SerializationTool

In [46]:
from fedlab.utils.dataset import CIFAR100Partitioner

In [6]:
# Walk a dict's keys in sorted order; sorting the dict itself sorts its keys.
x = dict(x=2, y=1)
for key in sorted(x):
    print(key)

x
y


42.66666666666667

In [11]:
# CIFAR-10 training split (train=True) stored under the relative data directory;
# download=True asks torchvision to fetch it if needed, and ToTensor converts
# each image to a tensor when it is accessed.
trainset = CIFAR10(root="../../../data/CIFAR10/", train=True, 
                   download=True, 
                   transform=transforms.ToTensor())

Files already downloaded and verified


In [12]:
# Serve the training set in mini-batches of 16 samples; no shuffling is requested.
train_loader = DataLoader(trainset,
            batch_size=16)

In [13]:
# Pull only the first mini-batch: keep its labels as a flat int64 tensor and
# record how many samples the batch holds.
for imgs, targets in train_loader:
    tmp1 = targets.reshape(-1).long()
    tmp2 = targets.shape[0]
    break

In [20]:
# Labels of the batch left bound by the loop cell, flattened to 1-D and cast to int64.
targets.reshape(-1).long()

tensor([6, 9, 9, 4, 1, 1, 2, 7, 8, 3, 4, 7, 7, 2, 9, 9])

In [51]:
def get_mdl_params(model_list, n_par=None):
    """Flatten the parameters of each model into one row of a float32 matrix.

    Parameters are written in the order produced by ``named_parameters()``,
    each one flattened to 1-D and placed directly after the previous one.

    Args:
        model_list: Sequence of models exposing ``named_parameters()``
            (e.g. ``torch.nn.Module`` instances).
        n_par (int, optional): Number of columns of the result. When ``None``
            it is computed as the total parameter count of ``model_list[0]``.

    Returns:
        np.ndarray: A newly allocated float32 array of shape
        ``(len(model_list), n_par)``. Columns beyond a model's own parameter
        count are left at zero.

    Raises:
        IndexError: If ``model_list`` is empty and ``n_par`` is ``None``.
    """
    if n_par is None:
        n_par = sum(param.numel() for _, param in model_list[0].named_parameters())

    # Allocate as float32 right away: building a float64 matrix and converting
    # it afterwards temporarily needs three times the memory of the result.
    param_mat = np.zeros((len(model_list), n_par), dtype=np.float32)
    for row, mdl in enumerate(model_list):
        offset = 0
        for _, param in mdl.named_parameters():
            flat = param.data.cpu().numpy().reshape(-1)
            param_mat[row, offset:offset + flat.size] = flat
            offset += flat.size
    # param_mat was created in this call, so it can be returned without a copy.
    return param_mat

In [52]:
# Build the network and flatten all of its parameters into a single 1-D tensor.
model = models.Cifar10Net('CIFAR10Net')
# Collect the flattened parameters first and concatenate once; concatenating
# inside the loop re-copies the ever-growing result on every iteration.
flat_params = [param.reshape(-1) for param in model.parameters()]
# A model without parameters leaves the result as None, as before.
local_par_list = torch.cat(flat_params, 0) if flat_params else None

In [55]:
n_clnt = 100
# Flatten the model once and reuse that vector for both its length and its
# values, instead of running get_mdl_params twice.
init_par_list = get_mdl_params([model])[0]
n_par = len(init_par_list)

# Allocate as float32 directly rather than converting a float64 array afterwards.
local_param_list = np.zeros((n_clnt, n_par), dtype=np.float32)
# Every client starts from the same initial parameters: one identical row per client.
clnt_params_list = np.tile(init_par_list, (n_clnt, 1))  # n_clnt X n_par
clnt_models = list(range(n_clnt))


In [58]:
# Expect (n_clnt, n_par): one zero-initialised row per client.
local_param_list.shape

(100, 797962)

In [60]:
# Expect (n_par,): the flattened parameters of the single model.
init_par_list.shape

(797962,)

In [62]:
# Expect (n_clnt, n_par): the initial parameter vector repeated for every client.
clnt_params_list.shape

(100, 797962)

In [64]:
# Column vector of ones, shape (n_clnt, 1); multiplied by a (1, n_par) row it
# broadcasts to the (n_clnt, n_par) matrix used for clnt_params_list.
np.ones(n_clnt).astype('float32').reshape(-1, 1).shape

(100, 1)

In [65]:
# Copy the first client's row of the NumPy matrix into a new torch tensor.
local_param_list_curr = torch.tensor(local_param_list[0]) 

In [67]:
# Check whether the tensor built from NumPy data tracks gradients (recorded output: False).
local_param_list_curr.requires_grad

False

In [72]:
# Per-client file name template; {cid:03d} zero-pads the client id to three digits.
file_pattern = "client_{cid:03d}_local_params"
# Example: cid=3 renders as 'client_003_local_params'.
file_pattern.format(cid=3)

'client_003_local_params'

In [73]:
class Coordinator(object):
    """Deal with the mapping relation between client id in FL system and process rank in communication.

    Note:
        Server Manager creates a Coordinator following:
        1. init network connection.
        2. client send local group info (the number of client simulating in local) to server.
        3. server receive all info and init a server Coordinator.

    Args:
        setup_dict (dict): A dict like {rank: client_num, ...}, giving the number of clients
            hosted by each process rank. Client ids are assigned to ranks in the dict's
            iteration order: the first rank owns ids ``0 .. client_num - 1``, the next rank
            owns the following block, and so on.
        mode (str, optional): "GLOBAL" or "LOCAL". Coordinator will map client id to
            (rank, global id) or (rank, local id) according to mode. For example, client id 51
            is in a machine which has 1 manager and serial trainer simulating 10 clients.
            LOCAL id means the index of its 10 clients. Therefore, global id 51 will be mapped
            into local id 1 (depending on setting).
    """
    def __init__(self, setup_dict, mode='LOCAL') -> None:
        self.map = setup_dict
        self.mode = mode

    def map_id(self, id):
        """Map a client id to (rank, local id) or (rank, global id), depending on ``self.mode``.

        Args:
            id (int): client id, expected in ``[0, self.total)``.

        Returns:
            rank, id : rank in distributed group, and the local id when the mode is
            'LOCAL' (the global id otherwise).

        Raises:
            ValueError: If ``id`` is negative or not smaller than ``self.total``.
                Previously an id that was too large fell through the loop and returned
                ``None``, and a negative id was mapped to the first rank with a negative
                local id.
        """
        if id < 0:
            raise ValueError(
                "Client id {} is out of range: ids must be in [0, {})".format(id, self.total))

        # Walk the ranks in order, subtracting each rank's client count until
        # the remainder falls inside the current rank's block.
        remaining = id
        for rank, num in self.map.items():
            if remaining < num:
                ret_id = remaining if self.mode == 'LOCAL' else id
                return rank, ret_id
            remaining -= num

        raise ValueError(
            "Client id {} is out of range: ids must be in [0, {})".format(id, self.total))

    def map_id_list(self, id_list):
        """Group a list of client ids by the rank that hosts them.

        This can be very useful in Scale modules.

        Args:
            id_list (list(int)): a list of client id.

        Returns:
            map_dict (dict): maps each process rank to the list of mapped client ids
            (local or global, see :meth:`map_id`) in the order they appear in ``id_list``.

        Raises:
            ValueError: If any id in ``id_list`` is out of range.
        """
        map_dict = {}
        for client_id in id_list:
            rank, mapped_id = self.map_id(client_id)
            map_dict.setdefault(rank, []).append(mapped_id)
        return map_dict

    def switch(self):
        """Toggle the mapping mode between 'GLOBAL' and 'LOCAL'.

        Raises:
            ValueError: If the current mode is neither 'GLOBAL' nor 'LOCAL'.
        """
        if self.mode == 'GLOBAL':
            self.mode = 'LOCAL'
        elif self.mode == 'LOCAL':
            self.mode = 'GLOBAL'
        else:
            raise ValueError("Invalid Map Mode {}".format(self.mode))

    @property
    def total(self):
        """Total number of clients across all ranks."""
        return int(sum(self.map.values()))

    def __str__(self) -> str:
        return "Coordinator map information: {} \nMap mode: {} \nTotal: {}".format(
            self.map, self.mode, self.total)

    def __call__(self, info, *args, **kwds):
        """Dispatch to :meth:`map_id` for an int and to :meth:`map_id_list` for a list.

        Raises:
            TypeError: If ``info`` is neither an int nor a list (previously this
                silently returned ``None``).
        """
        if isinstance(info, int):
            return self.map_id(info)
        if isinstance(info, list):
            return self.map_id_list(info)
        raise TypeError(
            "Coordinator expects an int or a list of ints, got {}".format(type(info).__name__))


In [75]:
# cid: 0 to num_clients-1
# Client ids (cid) run from 0 to num_clients - 1.
num_clients = 10
# Process ranks, by contrast, are enumerated here starting at 1 (1..10).
for rank in range(1, 11):
    print(rank)

1
2
3
4
5
6
7
8
9
10


In [78]:
# Ranks 1..10 each host 10 simulated clients.
rank_client_id_map = dict.fromkeys(range(1, 11), 10)
coordinator = Coordinator(rank_client_id_map)

In [82]:
# Client ids to route through the coordinator.
id_list = [0, 1, 11, 20, 21, 22]
num_clients_per_rank = 10
# Group the ids by hosting rank; the printed result shows per-rank local ids.
res = coordinator.map_id_list(id_list)
print(res)

{1: [0, 1], 2: [1], 3: [0, 1, 2]}


In [81]:
def local_to_global(local_client_id, rank, num_clients_per_rank=10):
    """Convert a per-rank local client id back to its global client id.

    Ranks are treated as 1-based and every rank is taken to host the same
    number of clients, so rank ``r`` owns the global ids starting at
    ``(r - 1) * num_clients_per_rank``.

    Args:
        local_client_id (int): Index of the client within its rank.
        rank (int): 1-based process rank hosting the client.
        num_clients_per_rank (int, optional): Clients hosted by each rank.

    Returns:
        int: The global client id.
    """
    first_id_of_rank = (rank - 1) * num_clients_per_rank
    return first_id_of_rank + local_client_id

In [83]:
# Round trip: turn every (rank, local id) pair back into its global client id.
for rank, local_ids in res.items():
    for local_client_id in local_ids:
        print(local_to_global(local_client_id, rank))

0
1
11
20
21
22


In [92]:
# Serialize the same model twice to compare the two results with each other.
params1 = SerializationTool.serialize_model(model)
params2 = SerializationTool.serialize_model(model)
# `.data` of the first result, kept under its own name for the identity checks.
params3 = params1.data
# Plain alias: a second name bound to the very same tensor object as params1.
params4 = params1

In [88]:
# Show the three serialized parameter tensors one after another.
for serialized in (params1, params2, params3):
    print(serialized)


tensor([-0.0535,  0.0145,  0.0168,  ..., -0.0213,  0.0006,  0.0269])
tensor([-0.0535,  0.0145,  0.0168,  ..., -0.0213,  0.0006,  0.0269])
tensor([-0.0535,  0.0145,  0.0168,  ..., -0.0213,  0.0006,  0.0269])


In [102]:
# Device on which the serialized parameter tensor lives.
params1.device

device(type='cpu')

In [98]:
# Object-identity checks: two separate serializations, `.data` versus its source tensor,
# a plain alias, and a second `.data` access versus the stored one.
params1 is params2, params1 is params3, params1 is params4, params1.data is params3

(False, False, True, False)

In [97]:
# Report the type and gradient-tracking flag of each of the four tensors.
for tensor in (params1, params2, params3, params4):
    print(type(tensor), tensor.requires_grad)

<class 'torch.Tensor'> False
<class 'torch.Tensor'> False
<class 'torch.Tensor'> False
<class 'torch.Tensor'> False


In [100]:
# 1-D tensor of zeros with as many entries as the serialized parameter vector.
zero_tensor = torch.zeros(params1.shape[0])

In [101]:
# Check whether a freshly created zeros tensor tracks gradients (recorded output: False).
zero_tensor.requires_grad

False