In [1]:
import os
import time
import torch
import collections

import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

from models import ShallowNN
from utils import load_file, get_all_possible_pairs
from evals import evaluate, pairwise_euclidean_distance , influence, layer_importance, layer_importance_bias, layerwise_full_accumulated_proximity
from evals import euclidean_distance, pairwise_euclidean_distance, accumulated_proximity

# Shared configuration for the cells below.
# `features` is the size argument passed to ShallowNN(...); `batch_size` feeds the
# DataLoader (checkpoint file names below are hard-coded as "batch64_...").
features, batch_size = 197, 64
# Loss handed to layer_importance_bias when scoring each client's model.
loss_fn = torch.nn.L1Loss()

In [2]:
# 24 isolated clients named "<group>_<index>": groups 0..3, six clients each,
# in group-major order ("0_0", "0_1", ..., "3_5").
client_ids = [f"{group}_{index}" for group in range(4) for index in range(6)]

In [15]:
# Load every client's isolated checkpoint exactly once and keep them in memory.
state_dicts = {
    key: torch.load("checkpt/isolated/batch64_client_" + str(key) + ".pth")
    for key in client_ids
}
# Every checkpoint is later loaded into the same ShallowNN(features), so the
# parameter names of the first client serve as the list of layer "criteria".
# (Previously client 0_0's checkpoint was read from disk a second time for this.)
dummy = state_dicts[client_ids[0]]
critarians = list(dummy)

In [4]:
def layerwise_proximity(x: collections.OrderedDict, y: collections.OrderedDict, critarian: str, distance_matrix):
    """Accumulated proximity between the same named parameter of two state dicts.

    Parameters:
    -------------
    x, y: OrderedDict; model state dicts (e.g. loaded checkpoints)
    critarian: str; state-dict key to compare, e.g. "layer_1.bias"
    distance_matrix: Callable; distance function forwarded to accumulated_proximity

    Returns:
    -------------
    Whatever accumulated_proximity returns for the two tensors.
    """
    lhs = x[critarian]
    rhs = y[critarian]
    # Bias entries (key suffix ".bias") are flattened into a single row before
    # comparison — presumably because accumulated_proximity expects 2-D input;
    # TODO confirm against evals.accumulated_proximity.
    if critarian.rsplit(".", 1)[-1] == "bias":
        lhs = lhs.view(1, -1)
        rhs = rhs.view(1, -1)
    return accumulated_proximity(lhs, rhs, distance_matrix)

In [16]:
def layerwise_full_accumulated_proximity(
    clients: list, criterian: str, distance_matrix, state_dicts: "dict | None" = None):
    """
    Calculate the layer-wise full accumulated proximity between all clients.

    Sums ``layerwise_proximity`` over every ordered pair (a, b) of clients,
    including a == b, for the single parameter ``criterian``. Each unordered
    pair is therefore counted twice, which matches the ``2 * client_total / total``
    normalisation used for the per-client eccentricities.

    Note: this notebook-local definition shadows
    ``evals.layerwise_full_accumulated_proximity`` imported at the top.

    Parameters:
    -------------
    clients: list; client ids, e.g. "0_0"
    criterian: str; state-dict key (layer parameter name) to evaluate
    distance_matrix: Callable; distance function forwarded to layerwise_proximity
    state_dicts: dict, optional; pre-loaded {client_id: state_dict}. When None
        (the default) each client's checkpoint is loaded from
        "checkpt/isolated/batch64_client_<id>.pth". Passing the already-loaded
        dicts avoids re-reading every checkpoint on each call.

    Returns:
    -------------
    total_proximity: a single scalar, the sum of all pairwise proximities. It is
        whatever type layerwise_proximity's results accumulate into (a float, or
        a tensor if accumulated_proximity returns tensors). Returns 0.0 for an
        empty client list.
    """
    if state_dicts is None:
        state_dicts = {
            key: torch.load("checkpt/isolated/batch64_client_" + str(key) + ".pth")
            for key in clients
        }

    total_proximity = 0.0
    for source in clients:
        for target in clients:
            total_proximity += layerwise_proximity(
                state_dicts[source], state_dicts[target], criterian, distance_matrix
            )

    return total_proximity

In [17]:
# Total proximity over all ordered client pairs, one entry per layer parameter;
# used below as the normaliser for each client's eccentricity.
full_prox_dict = {
    crit: layerwise_full_accumulated_proximity(client_ids, crit, euclidean_distance)
    for crit in critarians
}

In [18]:
# Per-layer eccentricity of each client: twice its summed proximity to all
# clients (itself included) divided by that layer's total proximity, so the
# values of one layer sum to 2 across clients.
# The checkpoints already held in `state_dicts` are reused rather than
# re-reading each client's file from disk.
eccentricities = {}
for c in critarians:
    layer_total = full_prox_dict[c]
    client_ecc = {}
    for client in client_ids:
        client_state = state_dicts[client]
        acc_proximity = sum(
            (
                layerwise_proximity(client_state, state_dicts[key], c, euclidean_distance)
                for key in state_dicts
            ),
            0.0,
        )
        # float() accepts both a plain number and a one-element tensor.
        client_ecc[client] = float(2 * acc_proximity / layer_total)
    eccentricities[c] = client_ecc

In [19]:
# Outer dict keys (layer parameters) become columns, inner keys (client ids) the rows.
lrp_eccetricities = pd.DataFrame(eccentricities)
lrp_eccetricities.iloc[:5]

Unnamed: 0,layer_1.weight,layer_1.bias,layer_2.weight,layer_2.bias,layer_3.weight,layer_3.bias
0_0,0.083633,0.079671,0.081799,0.083995,0.084671,0.059368
0_1,0.084282,0.078892,0.08198,0.082882,0.085825,0.06547
0_2,0.080969,0.076763,0.081529,0.079244,0.080579,0.095041
0_3,0.081178,0.083053,0.082016,0.076486,0.080884,0.100829
0_4,0.079436,0.073525,0.081037,0.081065,0.079129,0.065578


In [20]:
# Per-client layer importance scores, computed on each client's own data.
# NOTE(review): the data is read from "trainpt/" (the training files) even though
# it is used as validation data — confirm this is intended.
# NOTE: `layer_importance` below rebinds the name of the function imported from
# `evals`; later cells rely on it being this DataFrame.
layer_importance_scores = {}

for client in client_ids:
    iso_model = ShallowNN(features)
    # Reuse the checkpoint already loaded into `state_dicts` (load_state_dict
    # copies the values, so the cached tensors are not modified).
    iso_model.load_state_dict(state_dicts[client])
    validation_data = torch.load("trainpt/" + str(client) + ".pt")
    # Evaluation only: shuffling is unnecessary, and a fixed order keeps the
    # scores reproducible across runs (no seed is set in this notebook).
    validation_data_loader = DataLoader(validation_data, batch_size, shuffle=False)
    layer_importance_scores[client] = layer_importance_bias(iso_model, loss_fn, validation_data_loader)

# Rows: clients; columns: layer parameter names.
layer_importance = pd.DataFrame.from_dict(layer_importance_scores).T
layer_importance.head(5)

Unnamed: 0,layer_1.weight,layer_1.bias,layer_2.weight,layer_2.bias,layer_3.weight,layer_3.bias
0_0,84.786891,0.735403,13.68657,0.387663,0.362551,0.040923
0_1,80.435967,0.867103,17.431104,0.598819,0.581139,0.085869
0_2,83.258224,0.594982,15.057051,0.490581,0.530299,0.068863
0_3,85.642939,0.614633,12.952055,0.352219,0.395308,0.042847
0_4,83.370004,0.593724,15.02436,0.460909,0.485463,0.065541


In [21]:
# Per layer: the eccentricity total over all clients (expected to be 2) and the
# mean importance of that layer across clients.
for crit in critarians:
    ecc_total = round(lrp_eccetricities[crit].sum(), 4)
    print(crit, ecc_total)
    mean_importance = layer_importance[crit].mean()
    print("layer_im", mean_importance)

layer_1.weight 2.0
layer_im 81.56235189417863
layer_1.bias 2.0
layer_im 0.757794185471357
layer_2.weight 2.0
layer_im 16.60970214029278
layer_2.bias 2.0
layer_im 0.5456082225863877
layer_3.weight 2.0
layer_im 0.4648452783645767
layer_3.bias 2.0
layer_im 0.0596982791062732


In [11]:
# Mean layer importance across clients, used as weights for the weighted
# eccentricity average. These are derived from `layer_importance` instead of
# hand-copied constants; the old values (e.g. 81.60, 16.57) had drifted from
# the computed means (~81.56, ~16.61).
mean_layer_importance = layer_importance.mean()
layer_1_weight_ls = mean_layer_importance["layer_1.weight"]
layer_1_bias_ls = mean_layer_importance["layer_1.bias"]
layer_2_weight_ls = mean_layer_importance["layer_2.weight"]
layer_2_bias_ls = mean_layer_importance["layer_2.bias"]
layer_3_weight_ls = mean_layer_importance["layer_3.weight"]
layer_3_bias_ls = mean_layer_importance["layer_3.bias"]

In [23]:
# Per-client summaries of the six per-layer eccentricities (rounded to 4 d.p.):
#  - "average": unweighted mean across the layer parameters;
#  - "weighted_avg": mean weighted by each layer's importance (layer_*_ls),
#    normalised by the total weight instead of a hard-coded 100.
layer_weights = pd.Series({
    "layer_1.weight": layer_1_weight_ls,
    "layer_1.bias": layer_1_bias_ls,
    "layer_2.weight": layer_2_weight_ls,
    "layer_2.bias": layer_2_bias_ls,
    "layer_3.weight": layer_3_weight_ls,
    "layer_3.bias": layer_3_bias_ls,
})
layer_ecc = lrp_eccetricities[list(layer_weights.index)]

averages = layer_ecc.mean(axis="columns").round(4)
weighted_avg = (
    layer_ecc.mul(layer_weights, axis="columns").sum(axis="columns") / layer_weights.sum()
).round(4)

lrp_eccetricities["average"] = averages
lrp_eccetricities["weighted_avg"] = weighted_avg

In [24]:
# Persist the eccentricity table. The row index holds the client ids, so it is
# written out as a "client" column; with index=False the rows could no longer be
# mapped back to clients.
os.makedirs("insights", exist_ok=True)
lrp_eccetricities.to_csv("insights/eccentricity_with_lrp.csv", index_label="client")

In [25]:
# Final table: per-layer eccentricities plus the two aggregate columns, one row per client.
lrp_eccetricities.loc[:, :]

Unnamed: 0,layer_1.weight,layer_1.bias,layer_2.weight,layer_2.bias,layer_3.weight,layer_3.bias,average,weighted_avg
0_0,0.083633,0.079671,0.081799,0.083995,0.084671,0.059368,0.0789,0.0833
0_1,0.084282,0.078892,0.08198,0.082882,0.085825,0.06547,0.0799,0.0838
0_2,0.080969,0.076763,0.081529,0.079244,0.080579,0.095041,0.0824,0.081
0_3,0.081178,0.083053,0.082016,0.076486,0.080884,0.100829,0.0841,0.0813
0_4,0.079436,0.073525,0.081037,0.081065,0.079129,0.065578,0.0766,0.0797
0_5,0.08452,0.088152,0.082386,0.100017,0.082506,0.086066,0.0873,0.0843
1_0,0.084846,0.08574,0.082735,0.087511,0.085409,0.061738,0.0813,0.0845
1_1,0.084173,0.078355,0.083445,0.085079,0.084149,0.071878,0.0812,0.084
1_2,0.085419,0.092366,0.080363,0.086349,0.080833,0.065815,0.0819,0.0846
1_3,0.083477,0.082536,0.083019,0.074219,0.083659,0.092411,0.0832,0.0833
