In [1]:
import os
import time
import torch
import collections

import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

from models import ShallowNN
from utils import load_file, get_all_possible_pairs
from evals import evaluate, pairwise_euclidean_distance , influence, layer_importance, layer_importance_bias, layerwise_full_accumulated_proximity
from evals import euclidean_distance, pairwise_euclidean_distance, accumulated_proximity

# Notebook-wide settings used by the cells below:
#   features   -> input width given to ShallowNN(...)
#   batch_size -> batch size of the importance-evaluation DataLoader
#   loss_fn    -> loss passed to layer_importance_bias(...)
features, batch_size = 197, 64
loss_fn = torch.nn.L1Loss()

In [2]:
# 24 clients named "<group>_<index>": groups 0-3, indices 0-5, in row-major order.
client_ids = [f"{group}_{idx}" for group in range(4) for idx in range(6)]

In [3]:
#euclidean_distance(global_model.track_layers["layer_1"].weight.data.view(1, -1),global_model.track_layers["layer_1"].weight.data.view(1, -1))

In [4]:
# Parameter names (state-dict keys) are read off client 0_0's checkpoint; later cells
# assume every client checkpoint has the same keys.
dummy = torch.load("checkpt/isolated/batch64_client_0_0.pth")
critarians = list(dummy)

# Every client's state dict from checkpt/isolated/, keyed by client id.
state_dicts = dict(
    (client_id, torch.load("checkpt/isolated/batch64_client_" + str(client_id) + ".pth"))
    for client_id in client_ids
)

In [5]:
def layerwise_proximity(x: collections.OrderedDict, y: collections.OrderedDict, critarian: str, distance_matrix):
    """Proximity between two state dicts for a single parameter.

    Looks up `critarian` in both state dicts and forwards the two tensors to
    `accumulated_proximity` together with `distance_matrix`. Parameters whose
    last dotted component is "bias" are reshaped to a single row (1, N) first;
    all other parameters are passed through unchanged.
    """
    lhs, rhs = x[critarian], y[critarian]
    if critarian.rsplit(".", 1)[-1] == "bias":
        lhs, rhs = lhs.view(1, -1), rhs.view(1, -1)
    return accumulated_proximity(lhs, rhs, distance_matrix)

In [7]:
def layerwise_full_accumulated_proximity(
    clients: list, criterian: str, distance_matrix, state_dicts=None):
    """
    Sum the layer-wise proximity over every ordered pair of clients.

    Note: this definition shadows the `layerwise_full_accumulated_proximity`
    imported from `evals` in the first cell.

    Parameters:
    -------------
    clients: list; client ids (e.g. "0_0").
    criterian: str; state-dict key to compare, i.e. the layer parameter name
        (e.g. "layer_1.weight").
    distance_matrix: Callable; distance function forwarded to `layerwise_proximity`.
    state_dicts: dict, optional; client id -> state dict. When omitted, each
        client's checkpoint is loaded from
        "checkpt/isolated/batch64_client_<id>.pth". Pass an already-loaded
        mapping to avoid re-reading every checkpoint on each call.

    Returns:
    -------------
    total_proximity: the sum of `layerwise_proximity` over all ordered pairs
        (a, b) with a and b in `clients`, including a == b (a tensor when the
        proximities are tensors; 0.0 when `clients` is empty). For a symmetric
        distance, each unordered pair is therefore counted twice.
    """
    if state_dicts is None:
        state_dicts = {
            key: torch.load("checkpt/isolated/batch64_client_" + str(key) + ".pth")
            for key in clients
        }

    total_proximity = 0.0
    # Ordered pairs on purpose: per-client eccentricity later uses
    # 2 * (sum of one client's proximities) / total_proximity.
    for client_a in clients:
        for client_b in clients:
            total_proximity += layerwise_proximity(
                state_dicts[client_a], state_dicts[client_b], criterian, distance_matrix
            )

    return total_proximity

In [8]:
# Total (all ordered client pairs) proximity for each layer parameter, using Euclidean distance.
full_prox_dict = {}
for item in critarians:
    full_prox_dict[item] = layerwise_full_accumulated_proximity(
        clients=client_ids,
        criterian=item,
        distance_matrix=euclidean_distance,
    )

In [9]:
# Display the total proximity over all ordered client pairs for each layer parameter.
full_prox_dict

{'layer_1.weight': tensor(55107.3438),
 'layer_1.bias': tensor(588.5198),
 'layer_2.weight': tensor(25738.9277),
 'layer_2.bias': tensor(395.0583),
 'layer_3.weight': tensor(10944.6992),
 'layer_3.bias': tensor(23.8015)}

### Per-client layer-wise eccentricities (all layers)

In [10]:
# Eccentricity of each client for each layer parameter:
#   2 * (sum of the client's proximity to every client) / full_prox_dict[layer]
# Because full_prox_dict sums over all ordered pairs of the same clients, each
# layer's eccentricities add up to 2 (checked in a later cell).
eccentricities = {}
for c in critarians:
    client_ecc = {}
    for client in client_ids:
        # Reuse the state dict loaded in the checkpoint cell instead of re-reading
        # the same file from disk for every (layer, client) combination.
        client_matrix = state_dicts[client]
        acc_proximity = 0.0
        for key in state_dicts:
            distance = layerwise_proximity(client_matrix, state_dicts[key], c, euclidean_distance)
            acc_proximity += distance
        eccentricity = 2 * acc_proximity / full_prox_dict[c]
        client_ecc[client] = eccentricity.item()
    eccentricities[c] = client_ecc

In [11]:
# Nested dict {layer: {client: ecc}} -> one row per client, one column per layer parameter.
lrp_eccetricities = pd.DataFrame(eccentricities)
lrp_eccetricities.head(n=5)

Unnamed: 0,layer_1.weight,layer_1.bias,layer_2.weight,layer_2.bias,layer_3.weight,layer_3.bias
0_0,0.08401,0.07593,0.083319,0.083742,0.085984,0.043081
0_1,0.085182,0.074758,0.083903,0.08209,0.08841,0.054493
0_2,0.078688,0.070743,0.080501,0.074649,0.07779,0.100349
0_3,0.079251,0.082931,0.081003,0.069608,0.07864,0.110012
0_4,0.075787,0.064629,0.079339,0.079195,0.07504,0.054656


In [14]:
# Per-client layer-importance scores from `layer_importance_bias`, one row per client.
layer_importance_scores = {}

for client in client_ids:
    iso_model = ShallowNN(features)
    # Reuse the checkpoint already loaded into `state_dicts` (load_state_dict copies the values).
    iso_model.load_state_dict(state_dicts[client])
    # NOTE(review): this reads from "trainpt/", presumably the client's training split
    # rather than a held-out validation split — confirm this is intended.
    validation_data = torch.load("trainpt/" + str(client) + ".pt")
    # No shuffling: the evaluation should not depend on an unseeded RNG, so the
    # scores are reproducible across Restart & Run All.
    validation_data_loader = DataLoader(validation_data, batch_size, shuffle=False)
    iso_layer_importance = layer_importance_bias(iso_model, loss_fn, validation_data_loader)

    layer_importance_scores[client] = iso_layer_importance

# NOTE: this rebinds `layer_importance`, shadowing the function of the same name
# imported from `evals` in the first cell.
layer_importance = pd.DataFrame.from_dict(layer_importance_scores).T
layer_importance.head(5)

Unnamed: 0,layer_1.weight,layer_1.bias,layer_2.weight,layer_2.bias,layer_3.weight,layer_3.bias
0_0,84.556248,0.743609,13.919749,0.376496,0.361039,0.04286
0_1,80.643772,0.85043,17.243428,0.596817,0.580415,0.085138
0_2,83.195202,0.590558,15.128975,0.487478,0.531331,0.066456
0_3,85.542781,0.613843,13.052207,0.351687,0.394934,0.044548
0_4,84.022049,0.575856,14.383083,0.464949,0.489019,0.065044


In [16]:
# Sanity check per layer: the eccentricities should sum to 2; also show the mean importance score.
for item in critarians:
    ecc_column = lrp_eccetricities[item]
    print(item, round(ecc_column.sum(), 4))
    print("layer_im", layer_importance[item].mean())

layer_1.weight 2.0
layer_im 81.5915705676305
layer_1.bias 2.0
layer_im 0.7569058531910482
layer_2.weight 2.0
layer_im 16.584087878674254
layer_2.bias 2.0
layer_im 0.5440387869220099
layer_3.weight 2.0
layer_im 0.46381884583326927
layer_3.bias 2.0
layer_im 0.05957806774891431


In [17]:
# Mean layer-importance score of each layer parameter across clients, read from the
# `layer_importance` table instead of being hand-typed, so these weights always match
# the means printed above (and stay correct if the checkpoints or data change).
layer_1_weight_ls = layer_importance["layer_1.weight"].mean()
layer_1_bias_ls = layer_importance["layer_1.bias"].mean()
layer_2_weight_ls = layer_importance["layer_2.weight"].mean()
layer_2_bias_ls = layer_importance["layer_2.bias"].mean()
layer_3_weight_ls = layer_importance["layer_3.weight"].mean()
layer_3_bias_ls = layer_importance["layer_3.bias"].mean()

In [18]:
# Per-client plain average and importance-weighted average of the six layer eccentricities.
layer_weights = pd.Series({
    "layer_1.weight": layer_1_weight_ls,
    "layer_1.bias": layer_1_bias_ls,
    "layer_2.weight": layer_2_weight_ls,
    "layer_2.bias": layer_2_bias_ls,
    "layer_3.weight": layer_3_weight_ls,
    "layer_3.bias": layer_3_bias_ls,
})
layer_ecc = lrp_eccetricities[list(layer_weights.index)]

averages = layer_ecc.mean(axis=1).round(4).tolist()
# Normalise by the actual total weight rather than a hard-coded 100, so the result
# is a true weighted average even if the importance weights do not sum to exactly 100.
weighted_avg = (
    layer_ecc.mul(layer_weights, axis=1).sum(axis=1) / layer_weights.sum()
).round(4).tolist()

lrp_eccetricities["average"] = averages
lrp_eccetricities["weighted_avg"] = weighted_avg

In [29]:
# Importance-weighted eccentricity per client, as computed in the previous cell.
lrp_eccetricities["weighted_avg"]

0_0    0.0838
0_1    0.0849
0_2    0.0789
0_3    0.0795
0_4    0.0763
0_5    0.0855
1_0    0.0859
1_1    0.0848
1_2    0.0865
1_3    0.0837
1_4    0.0875
1_5    0.0923
2_0    0.0772
2_1    0.0783
2_2    0.0788
2_3    0.0789
2_4    0.0780
2_5    0.0781
3_0    0.0873
3_1    0.0873
3_2    0.0851
3_3    0.0868
3_4    0.0867
3_5    0.0879
Name: weighted_avg, dtype: float64

In [34]:
# NOTE(review): scratch arithmetic. 0.0837 matches a weighted_avg value shown above
# (client 1_3); why it is halved is not documented here — TODO confirm or remove.
0.0837/2

0.04185

In [28]:
# NOTE(review): 0.08333... == 2 / 24, the eccentricity each of the 24 clients would get if
# all clients were equidistant (each layer's eccentricities sum to 2). Why it is halved is
# not documented here — TODO confirm or remove.
0.08333333333333333/2

0.041666666666666664