<a href="https://colab.research.google.com/github/vs-152/FL-Contributions-Incentives-Project/blob/main/ISO_MNIST_lambda_MR.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import numpy as np

import copy
import time
from sklearn.model_selection import StratifiedShuffleSplit
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.sampler import SubsetRandomSampler
from itertools import chain, combinations
from tqdm import tqdm
from scipy.special import comb
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Single seed used for the NumPy and PyTorch global RNGs below.
SEED = 42

np.random.seed(SEED)
torch.manual_seed(SEED)
def seed_worker(worker_id):
    """Seed NumPy's and Python's global RNGs from the current torch seed.

    Written to be passed as ``worker_init_fn`` to a ``DataLoader``.

    :param worker_id: index handed in by the DataLoader; not used here.
    """
    # Local import: ``random`` is not imported at the top of this file.
    import random

    # ``np.random.seed`` only accepts values below 2**32, hence the modulo.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

In [None]:
def noisify_MNIST(noise_rate, noise_type, x, y, perm=None, **kwargs):
    '''Return a noisy copy of the labels ``y`` and the indices that were changed.

    A stratified subset of ``noise_rate`` of the samples is selected. With
    ``noise_type == 'symmetric'`` each selected label is replaced by a label
    drawn uniformly from the nine other digits. With
    ``noise_type == 'asymmetric'`` each selected label ``c`` is replaced by
    ``perm[c]``; samples for which ``perm[c] == c`` are left untouched and are
    dropped from the returned index array.

    :param noise_rate: fraction of samples to corrupt; ``0`` returns ``y``
        itself together with an empty list.
    :param noise_type: ``'symmetric'`` or ``'asymmetric'``.
    :param x: samples, only used to drive the stratified split.
    :param y: integer label array (must support ``.copy()`` and fancy indexing).
    :param perm: label permutation, required for asymmetric noise.
    :param kwargs: optional ``seed`` forwarded as ``random_state`` of the split.
    :Returns ``(y_noisy, noise_idx)``.
    :Raises ValueError for an unknown ``noise_type`` or a missing ``perm``.
    '''
    if (noise_rate == 0.):
        return y, []

    # Fail loudly instead of returning clean labels with a non-empty noise index.
    if noise_type not in ('symmetric', 'asymmetric'):
        raise ValueError(
            "noise_type must be 'symmetric' or 'asymmetric', got {!r}".format(noise_type))
    if noise_type == 'asymmetric':
        perm = np.asarray([] if perm is None else perm)
        if perm.size == 0:
            raise ValueError("asymmetric noise requires a non-empty `perm`")

    # random_state=None (no 'seed' given) is the splitter's unseeded default.
    splitter = StratifiedShuffleSplit(
        n_splits=1,
        test_size=noise_rate,
        random_state=kwargs.get('seed'))
    _, noise_idx = next(iter(splitter.split(x, y)))

    y_noisy = y.copy()
    if noise_type == 'symmetric':
        all_labels = np.arange(10)
        for i in noise_idx:
            candidates = np.delete(all_labels, y[i])
            # Take the scalar out of the size-1 draw; assigning the array
            # itself to a scalar slot is deprecated in recent NumPy.
            y_noisy[i] = np.random.choice(candidates, 1)[0]
    else:
        mapped = perm[y[noise_idx]]
        changed = mapped != y[noise_idx]
        y_noisy[noise_idx[changed]] = mapped[changed]
        # Only report indices whose label actually changed.
        noise_idx = noise_idx[changed]

    return y_noisy, noise_idx

def mnist_iid(dataset, num_users, SEED):
    """
    Randomly partition the sample indices of ``dataset`` into one shard per user.
    Each shard holds ``int(len(dataset)/num_users)`` indices drawn without
    replacement from the indices not yet assigned.
    :param dataset: any object supporting ``len()``
    :param num_users: number of shards to create
    :param SEED: seed applied to the global NumPy RNG before sampling
    :return: dict mapping user number (0-based) to a set of sample indices
    """
    np.random.seed(SEED)
    shard_size = int(len(dataset)/num_users)
    remaining = list(range(len(dataset)))
    shards = {}
    for user in range(num_users):
        picked = set(np.random.choice(remaining, shard_size, replace=False))
        shards[user] = picked
        # Rebuilt through a set difference on purpose: the resulting list
        # order determines what the next draw returns.
        remaining = list(set(remaining) - picked)

    return shards

class MLP(nn.Module):
    """Two-layer perceptron that returns per-class log-probabilities.

    Inputs are flattened to ``dim_in`` features per sample, then passed
    through a linear layer, dropout, ReLU, a second linear layer and
    ``LogSoftmax`` over dimension 1.

    :param dim_in: number of input features per sample.
    :param dim_hidden: width of the hidden layer.
    :param dim_out: number of output classes.
    :param SEED: passed to ``torch.manual_seed`` before the layers are built,
        so models created with the same seed start from identical weights.
        This reseeds the global torch RNG as a side effect.
    """

    def __init__(self, dim_in, dim_hidden, dim_out, SEED):
        torch.manual_seed(SEED)
        super(MLP, self).__init__()
        self.layer_input = nn.Linear(dim_in, dim_hidden)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout()
        self.layer_hidden = nn.Linear(dim_hidden, dim_out)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        """Return log-probabilities of shape ``(num_samples, dim_out)``."""
        # Flatten to the width the first layer was built with instead of a
        # hard-coded 784, so ``dim_in`` is honoured for non-MNIST inputs too.
        x = x.view(-1, self.layer_input.in_features)
        x = self.layer_input(x)
        x = self.dropout(x)
        x = self.relu(x)
        x = self.layer_hidden(x)

        return self.softmax(x)

def average_weights(w, fraction):  # this can also be used to average gradients
    """
    :param w: list of weight dicts (one per user), all sharing the same keys
    :param fraction: list of per-user data fractions; normalised internally,
        so the entries do not have to sum to one
    :Returns the weighted average of the weights as a new dict; the inputs
        are not modified.
    :Raises ValueError if ``w`` and ``fraction`` differ in length.
    """
    if len(w) != len(fraction):
        # A longer ``fraction`` would otherwise skew the normalisation silently.
        raise ValueError(
            "got {} weight sets but {} fractions".format(len(w), len(fraction)))

    # Normalise once instead of re-summing ``fraction`` for every key and user.
    total = sum(fraction)
    coefficients = [f / total for f in fraction]

    w_avg = copy.deepcopy(w[0])  # start from a copy of the first user's weights
    for key in w_avg.keys():
        w_avg[key] *= coefficients[0]
        for user_weights, coefficient in zip(w[1:], coefficients[1:]):
            w_avg[key] += user_weights[key] * coefficient

    return w_avg

def calculate_gradients(new_weights, old_weights):
    """
    :param new_weights: list of weight dicts produced by the users
    :param old_weights: weight dict the users started from
    :Returns a list with one dict per user holding new_weights - old_weights
        for every key; the inputs are left untouched.
    """
    deltas = []
    for user_weights in new_weights:
        delta = copy.deepcopy(user_weights)
        for name in delta.keys():
            delta[name] -= old_weights[name]
        deltas.append(delta)

    return deltas

def update_weights_from_gradients(gradients, old_weights):
    """
    :param gradients: dict of per-key updates
    :param old_weights: weight dict the updates are applied to
    :Returns a new dict of the same type as old_weights whose entries are
        old_weights[key] + gradients[key]; old_weights is not modified.
    """
    # Deep-copy first so the result keeps the container type and key order.
    result = copy.deepcopy(old_weights)
    for name, base in old_weights.items():
        result[name] = base + gradients[name]

    return result
    


def powersettool(iterable):
    """Lazily yield every subset of ``iterable`` as a tuple, smallest first.

    powersettool([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
    """
    items = list(iterable)
    subsets_by_size = (combinations(items, size) for size in range(len(items)+1))
    return chain.from_iterable(subsets_by_size)

def shapley(utility, N):
    """Compute the Shapley value of players 1..N from coalition utilities.

    :param utility: dict mapping each coalition (a tuple of player numbers,
        including the empty tuple) to its utility. For every coalition
        present, the coalition with any one member removed must be present too.
    :param N: number of players.
    :Returns a dict mapping each player number to its Shapley value.
    """
    shapley_dict = dict.fromkeys(range(1, N+1), 0)
    for coalition, value in utility.items():
        if coalition == ():
            continue
        # Same denominator for every member of this coalition.
        denominator = (comb(N-1, len(coalition)-1))*N
        for member in coalition:
            without_member = tuple(p for p in coalition if p != member)
            marginal_contribution = value - utility[without_member]
            shapley_dict[member] += marginal_contribution / denominator

    return shapley_dict

In [None]:
# Training split is loaded without a transform; its raw arrays are used below.
trainset = MNIST(root='./data', train=True, download=True)
# Test split yields tensors through ToTensor and is consumed via a DataLoader.
test_dataset = MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())
# Pixel values are divided by 255 (presumably mapping uint8 intensities to [0, 1]).
x_train = trainset.data.numpy().astype("float32") / 255.
y_train = trainset.targets.numpy()

In [None]:
class LocalUpdate(object):
    """Local SGD training of a model on one client's data loader."""

    def __init__(self, lr, local_ep, trainloader):
        # lr: SGD learning rate; local_ep: passes over trainloader per call.
        self.trainloader = trainloader
        self.local_ep = local_ep
        self.lr = lr

    def update_weights(self, model):
        """Train ``model`` in place for ``local_ep`` passes over the loader.

        Uses SGD with momentum 0.5 and an NLL loss, so ``model`` is expected
        to output log-probabilities. Batches are moved to the file-level
        ``device``.

        :param model: the module to train; it is modified in place.
        :Returns ``(model.state_dict(), mean of the per-epoch average losses)``.
        """
        model.train()
        sgd = torch.optim.SGD(model.parameters(), lr=self.lr, momentum=0.5)
        nll = nn.NLLLoss().to(device)
        per_epoch_loss = []
        for _ in range(self.local_ep):
            per_batch_loss = []
            for images, labels in self.trainloader:
                images = images.to(device)
                labels = labels.to(device)
                model.zero_grad()
                loss = nll(model(images), labels)
                loss.backward()
                sgd.step()
                per_batch_loss.append(loss.item())
            per_epoch_loss.append(sum(per_batch_loss)/len(per_batch_loss))

        mean_loss = sum(per_epoch_loss) / len(per_epoch_loss)
        return model.state_dict(), mean_loss

def test_inference(model, test_dataset):
    """Evaluate ``model`` on ``test_dataset``.

    The model is switched to eval mode and stays in it afterwards. Data is
    read in batches of 1000 and moved to the file-level ``device``.

    :param model: module returning per-class scores of shape (batch, classes);
        an NLL loss is applied, so log-probabilities are expected.
    :param test_dataset: dataset yielding ``(images, labels)`` pairs.
    :Returns ``(accuracy, loss)`` where accuracy is correct / total and loss
        is the sum over batches of the criterion value of each batch.
    """
    model.eval()
    loss, total, correct = 0.0, 0.0, 0.0
    criterion = nn.NLLLoss().to(device)
    testloader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

    # Evaluation never calls backward(), so skip building the autograd graph.
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            batch_loss = criterion(outputs, labels)
            loss += batch_loss.item()
            _, pred_labels = torch.max(outputs, 1)
            pred_labels = pred_labels.view(-1)
            correct += torch.sum(torch.eq(pred_labels, labels)).item()
            total += len(labels)
    accuracy = correct / total

    return accuracy, loss

In [None]:
# Experiment configuration.
N = 8 #srch
local_bs = 64   # batch size of each client's DataLoader
lr = 0.01       # learning rate handed to LocalUpdate
local_ep = 10   # local passes per client in every global round
EPOCHS = 5      # number of global rounds

np.random.seed(SEED)
torch.manual_seed(SEED)


# Client n+1 gets symmetric label noise at rate n/N: 0, 1/N, ..., (N-1)/N.
noise_rates = np.linspace(0, 1, N, endpoint=False)
split_dset = mnist_iid(trainset, N, SEED)
# Clients are keyed 1..N; the zeros are placeholders overwritten in the loop.
user_groups = {i: 0 for i in range(1, N+1)}
noise_idx = {i: 0 for i in range(1, N+1)}
train_datasets = {i: 0 for i in range(1, N+1)}
for n in range(N):
    # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is the
    # dtype it used to alias.
    user_groups[n+1] = np.array(list(split_dset[n]), dtype=int)
    user_train_x, user_train_y = x_train[user_groups[n+1]], y_train[user_groups[n+1]]
    user_noisy_y, noise_idx[n+1] = noisify_MNIST(noise_rates[n], 'symmetric', user_train_x, user_train_y, seed=SEED)
    train_datasets[n+1] = TensorDataset(torch.Tensor(user_train_x),
                                        torch.as_tensor(user_noisy_y, dtype=torch.long))


global_model = MLP(dim_in=784, dim_hidden=64, dim_out=10, SEED=SEED)    
global_model.to(device)
global_model.train()
#print(global_model)
global_weights = global_model.state_dict()
# All 2**N coalitions of clients: the empty tuple first, the full set last.
powerset = list(powersettool(range(1, N+1)))
submodel_dict = {}  
# The empty coalition keeps a copy of the untrained model.
submodel_dict[()] = copy.deepcopy(global_model)
accuracy_dict = {}
shapley_dict = {}

In [None]:
# Federated training. Each round trains every client from the current global
# model, updates one sub-model per coalition from the clients' weight deltas,
# evaluates all models on the test set and records per-round Shapley values.
start_time = time.time()

np.random.seed(SEED)
torch.manual_seed(SEED)

for subset in powerset[:-1]: #exclude only the global set. the null set still has a random initialized model
    submodel_dict[subset] = copy.deepcopy(global_model)
    submodel_dict[subset].to(device)
    submodel_dict[subset].train()

# train_accuracy, val_acc_list and net_list are not used in this cell.
train_loss, train_accuracy = [], []
val_acc_list, net_list = [], []
print_every = 2

idxs_users = np.arange(1, N+1)
total_data = sum(len(user_groups[i]) for i in range(1, N+1))
# Share of the training data held by each client, in client order 1..N.
fraction = [len(user_groups[i])/total_data for i in range(1, N+1)]

for epoch in tqdm(range(EPOCHS)):
    local_weights, local_losses = [], []
    print(f'\n | Global Training Round : {epoch+1} |\n')
    global_model.train()
    for idx in idxs_users:
        trainloader = DataLoader(train_datasets[idx], batch_size=local_bs, shuffle=True, worker_init_fn=seed_worker)
        local_model = LocalUpdate(lr, local_ep, trainloader)
        # Every client starts from its own copy of the current global model.
        w, loss = local_model.update_weights(model=copy.deepcopy(global_model))
        local_weights.append(copy.deepcopy(w))
        local_losses.append(copy.deepcopy(loss))
        
    global_weights = average_weights(local_weights, fraction) # global_new
    loss_avg = sum(local_losses) / len(local_losses)
    train_loss.append(loss_avg)

    # Per-client deltas relative to the global model of the PREVIOUS round;
    # global_model is only overwritten further down.
    gradients = calculate_gradients(local_weights, global_model.state_dict())

    # Proper, non-empty coalitions: apply the weighted mean of their members'
    # deltas on top of the coalition's own sub-model.
    for subset in powerset[1: -1]: 
        subset_gradient = average_weights([gradients[i-1] for i in subset], [fraction[i-1] for i in subset])
        subset_weights = update_weights_from_gradients(subset_gradient, submodel_dict[subset].state_dict())
        submodel_dict[subset].load_state_dict(subset_weights)

    global_model.load_state_dict(global_weights)
    global_model.eval()

    if (epoch+1) % print_every == 0:
        print(f' \nAvg Training Stats after {epoch+1} global rounds:')
        print(f'Training Loss : {np.mean(np.array(train_loss))}')
        # print('Train Accuracy: {:.2f}% \n'.format(100*train_accuracy[-1]))

    # The full coalition is the global model itself.
    accuracy_dict[powerset[-1]] = test_inference(global_model, test_dataset)[0]

    # Test inference for the sub-models in submodel_dict
    for subset in powerset[:-1]: 
        test_acc, test_loss = test_inference(submodel_dict[subset], test_dataset)
        # Report completed rounds as epoch+1, matching the other messages.
        print(f' \n Results after {epoch+1} global rounds of training:')
        print("|---- Test Accuracy for {}: {:.2f}%".format(subset, 100*test_acc))
            
        accuracy_dict[subset] = test_acc

    # One Shapley value per client and round, appended to that client's history.
    shapley_dict_add = shapley(accuracy_dict, N)
    for key in shapley_dict_add:
        shapley_dict.setdefault(key, []).append(shapley_dict_add[key])
test_acc, test_loss = test_inference(global_model, test_dataset)
print(f' \n Results after {EPOCHS} global rounds of training:')
print("|---- Test Accuracy: {:.2f}%".format(100*test_acc))

accuracy_dict[powerset[-1]] = test_acc


trainTime = time.time() - start_time
print('\n Total Time: {0:0.4f}'.format(trainTime))

  0%|          | 0/5 [00:00<?, ?it/s]


 | Global Training Round : 1 |

 
 Results after 0 global rounds of training:
|---- Test Accuracy for (): 10.88%
 
 Results after 0 global rounds of training:
|---- Test Accuracy for (1,): 89.55%
 
 Results after 0 global rounds of training:
|---- Test Accuracy for (2,): 88.44%
 
 Results after 0 global rounds of training:
|---- Test Accuracy for (3,): 87.55%
 
 Results after 0 global rounds of training:
|---- Test Accuracy for (4,): 86.07%
 
 Results after 0 global rounds of training:
|---- Test Accuracy for (5,): 84.74%
 
 Results after 0 global rounds of training:
|---- Test Accuracy for (6,): 81.48%
 
 Results after 0 global rounds of training:
|---- Test Accuracy for (7,): 73.93%
 
 Results after 0 global rounds of training:
|---- Test Accuracy for (8,): 19.77%
 
 Results after 0 global rounds of training:
|---- Test Accuracy for (1, 2): 89.17%
 
 Results after 0 global rounds of training:
|---- Test Accuracy for (1, 3): 88.89%
 
 Results after 0 global rounds of training:
|---- 

 20%|██        | 1/5 [03:17<13:08, 197.12s/it]

 
 Results after 0 global rounds of training:
|---- Test Accuracy for (2, 3, 4, 5, 6, 7, 8): 86.62%

 | Global Training Round : 2 |

 
Avg Training Stats after 2 global rounds:
Training Loss : 1.8050845969247362
 
 Results after 1 global rounds of training:
|---- Test Accuracy for (): 10.88%
 
 Results after 1 global rounds of training:
|---- Test Accuracy for (1,): 90.32%
 
 Results after 1 global rounds of training:
|---- Test Accuracy for (2,): 89.64%
 
 Results after 1 global rounds of training:
|---- Test Accuracy for (3,): 88.78%
 
 Results after 1 global rounds of training:
|---- Test Accuracy for (4,): 87.04%
 
 Results after 1 global rounds of training:
|---- Test Accuracy for (5,): 85.08%
 
 Results after 1 global rounds of training:
|---- Test Accuracy for (6,): 75.03%
 
 Results after 1 global rounds of training:
|---- Test Accuracy for (7,): 38.64%
 
 Results after 1 global rounds of training:
|---- Test Accuracy for (8,): 9.29%
 
 Results after 1 global rounds of training

 40%|████      | 2/5 [06:35<09:52, 197.53s/it]

 
 Results after 1 global rounds of training:
|---- Test Accuracy for (2, 3, 4, 5, 6, 7, 8): 88.83%

 | Global Training Round : 3 |

 
 Results after 2 global rounds of training:
|---- Test Accuracy for (): 10.88%
 
 Results after 2 global rounds of training:
|---- Test Accuracy for (1,): 90.53%
 
 Results after 2 global rounds of training:
|---- Test Accuracy for (2,): 90.00%
 
 Results after 2 global rounds of training:
|---- Test Accuracy for (3,): 89.18%
 
 Results after 2 global rounds of training:
|---- Test Accuracy for (4,): 85.78%
 
 Results after 2 global rounds of training:
|---- Test Accuracy for (5,): 80.95%
 
 Results after 2 global rounds of training:
|---- Test Accuracy for (6,): 51.03%
 
 Results after 2 global rounds of training:
|---- Test Accuracy for (7,): 11.12%
 
 Results after 2 global rounds of training:
|---- Test Accuracy for (8,): 9.40%
 
 Results after 2 global rounds of training:
|---- Test Accuracy for (1, 2): 90.67%
 
 Results after 2 global rounds of tr

 60%|██████    | 3/5 [09:54<06:35, 197.81s/it]

 
 Results after 2 global rounds of training:
|---- Test Accuracy for (2, 3, 4, 5, 6, 7, 8): 89.04%

 | Global Training Round : 4 |

 
Avg Training Stats after 4 global rounds:
Training Loss : 1.7487987597397985
 
 Results after 3 global rounds of training:
|---- Test Accuracy for (): 10.88%
 
 Results after 3 global rounds of training:
|---- Test Accuracy for (1,): 90.54%
 
 Results after 3 global rounds of training:
|---- Test Accuracy for (2,): 89.74%
 
 Results after 3 global rounds of training:
|---- Test Accuracy for (3,): 88.81%
 
 Results after 3 global rounds of training:
|---- Test Accuracy for (4,): 82.64%
 
 Results after 3 global rounds of training:
|---- Test Accuracy for (5,): 68.82%
 
 Results after 3 global rounds of training:
|---- Test Accuracy for (6,): 23.02%
 
 Results after 3 global rounds of training:
|---- Test Accuracy for (7,): 11.63%
 
 Results after 3 global rounds of training:
|---- Test Accuracy for (8,): 9.80%
 
 Results after 3 global rounds of training

 80%|████████  | 4/5 [13:12<03:17, 197.88s/it]

 
 Results after 3 global rounds of training:
|---- Test Accuracy for (2, 3, 4, 5, 6, 7, 8): 87.95%

 | Global Training Round : 5 |

 
 Results after 4 global rounds of training:
|---- Test Accuracy for (): 10.88%
 
 Results after 4 global rounds of training:
|---- Test Accuracy for (1,): 90.40%
 
 Results after 4 global rounds of training:
|---- Test Accuracy for (2,): 89.13%
 
 Results after 4 global rounds of training:
|---- Test Accuracy for (3,): 87.50%
 
 Results after 4 global rounds of training:
|---- Test Accuracy for (4,): 76.58%
 
 Results after 4 global rounds of training:
|---- Test Accuracy for (5,): 51.13%
 
 Results after 4 global rounds of training:
|---- Test Accuracy for (6,): 15.34%
 
 Results after 4 global rounds of training:
|---- Test Accuracy for (7,): 10.63%
 
 Results after 4 global rounds of training:
|---- Test Accuracy for (8,): 7.70%
 
 Results after 4 global rounds of training:
|---- Test Accuracy for (1, 2): 90.75%
 
 Results after 4 global rounds of tr

100%|██████████| 5/5 [16:30<00:00, 198.08s/it]

 
 Results after 4 global rounds of training:
|---- Test Accuracy for (2, 3, 4, 5, 6, 7, 8): 84.28%





 
 Results after 5 global rounds of training:
|---- Test Accuracy: 91.58%

 Total Time: 991.3132


In [None]:
# Aggregate the per-round Shapley values of every client into two scores:
#   fedshap  - plain sum over rounds
#   lambdamr - sum over rounds of the client's share of that round's total,
#              weighted by 0.8**round
fedshap = dict.fromkeys(range(1, N+1), 0)
lambdamr = dict.fromkeys(range(1, N+1), 0)
decay = [0.8**t for t in range(EPOCHS)]

# Total Shapley value over all clients, one entry per round.
epoch_sums = [
    sum([per_round[r] for per_round in shapley_dict.values()])
    for r in range(EPOCHS)
]

for key, values in shapley_dict.items():
    weighted_total = lambdamr[key]
    for r, round_value in enumerate(values):
        weighted_total += round_value * decay[r]/epoch_sums[r]
    lambdamr[key] = weighted_total
    fedshap[key] = sum(values)

In [None]:
# Show each client's decay-weighted, per-round-normalised Shapley score.
lambdamr

{1: 0.8458860742786513,
 2: 0.7416153345058814,
 3: 0.6516552212528229,
 4: 0.5598811154560299,
 5: 0.45381754437860766,
 6: 0.258647098544022,
 7: 0.07863703422999846,
 8: -0.228539422646013}

In [None]:
# Show each client's Shapley value summed over all rounds.
fedshap

{1: 1.1130475000000002,
 2: 0.9492298809523806,
 3: 0.8099644047619049,
 4: 0.6731346428571428,
 5: 0.5127720238095238,
 6: 0.2336884523809523,
 7: 0.015979404761904722,
 8: -0.3473163095238096}