In [1]:
import os
import time
import torch
import collections

import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

from models import ShallowNN
from utils import load_file, get_all_possible_pairs
from evals import evaluate, pairwise_euclidean_distance , influence, layer_importance, layer_importance_bias, layerwise_full_accumulated_proximity
from evals import euclidean_distance, pairwise_euclidean_distance, accumulated_proximity, layerwise_proximity

# Input width handed to ShallowNN(...) when the per-client models are rebuilt.
features = 197
# Batch size used by the DataLoader that feeds the layer-importance evaluation.
batch_size = 64
# L1 (absolute error) loss passed to layer_importance_bias.
loss_fn = torch.nn.L1Loss()

In [2]:
# Client ids have the form "<a>_<b>" with a in 0..3 and b in 0..5, ordered by a, then b.
client_ids = [f"{first}_{second}" for first in range(4) for second in range(6)]

In [3]:
#euclidean_distance(global_model.track_layers["layer_1"].weight.data.view(1, -1),global_model.track_layers["layer_1"].weight.data.view(1, -1))

In [4]:
# Global (federated) checkpoint that every client is compared against.
global_dict = torch.load('checkpt/epoch_25/1_rounds_25_epochs_per_round/_fedl_global_1_25.pth')

# The keys of the global checkpoint are the criteria evaluated layer by layer.
critarians = list(global_dict)

# Isolated per-client checkpoints, keyed by client id.
state_dicts = {}
for key in client_ids:
    state_dicts[key] = torch.load("checkpt/epoch_25/isolated/batch64_client_" + str(key) + ".pth")

In [5]:
def layerwise_proximity(x:collections.OrderedDict,y:collections.OrderedDict, critarian:str , distance_matrix):
    """
    Proximity between one entry of two state dicts.

    Parameters:
    -------------
    x: OrderedDict; first state dict
    y: OrderedDict; second state dict
    critarian: str; key looked up in both x and y
    distance_matrix: forwarded unchanged to accumulated_proximity

    Returns:
    -------------
    proximity: whatever accumulated_proximity returns for the two entries
    """
    first, second = x[critarian], y[critarian]

    # Entries whose last dotted component is "bias" are flattened to a single row first.
    if critarian.rsplit(".", 1)[-1] == "bias":
        first, second = first.view(1, -1), second.view(1, -1)

    return accumulated_proximity(first, second, distance_matrix)

In [6]:
def layerwise_full_accumulated_proximity(
    clients: list, criterian: str, distance_matrix, state_dicts=None, global_dict=None):
    """
    Accumulate the layer-wise proximity between the global model and every client.

    Parameters:
    -------------
    clients: list; client ids to include in the sum
    criterian: str; criterian to be evaluated, basically the layer name
        (a key of the state dicts).
    distance_matrix: Callable; distance function forwarded to layerwise_proximity
    state_dicts: dict, optional; already-loaded client state dicts keyed by
        client id. When None, each client's isolated checkpoint is loaded from disk.
    global_dict: dict, optional; already-loaded global state dict. When None,
        the federated global checkpoint is loaded from disk.

    Returns:
    -------------
    total_proximity: sum over `clients` of the proximity between the global
        entry and the client's entry for `criterian` (0.0 when `clients` is empty).
    """
    # Loading checkpoints dominates the cost; callers that evaluate several
    # layers can load once and pass the dicts in instead of re-reading per call.
    if state_dicts is None:
        state_dicts = {
            key: torch.load("checkpt/epoch_25/isolated/batch64_client_" + str(key) + ".pth")
            for key in clients
        }

    if global_dict is None:
        global_dict = torch.load('checkpt/epoch_25/1_rounds_25_epochs_per_round/_fedl_global_1_25.pth')

    total_proximity = 0.0

    for client in clients:
        total_proximity += layerwise_proximity(
            global_dict, state_dicts[client], criterian, distance_matrix
        )

    return total_proximity

In [7]:
# Per-layer normaliser: the accumulated global-vs-client proximity, halved.
full_prox_dict = {
    layer_key: layerwise_full_accumulated_proximity(client_ids, layer_key, euclidean_distance) / 2
    for layer_key in critarians
}

In [8]:
# eccentricities[layer][client]: the client's proximity to the global model for
# that layer, divided by the layer's normaliser from full_prox_dict, stored as a
# plain Python scalar via .item().
eccentricities = {
    layer_key: {
        client: (
            layerwise_proximity(global_dict, state_dicts[client], layer_key, euclidean_distance)
            / full_prox_dict[layer_key]
        ).item()
        for client in client_ids
    }
    for layer_key in critarians
}

In [9]:
# One row per client, one column per layer; the client id is moved out of the
# index into an explicit "client_id" column.
lrp_eccetricities = (
    pd.DataFrame.from_dict(eccentricities)
    .reset_index()
    .rename(columns={"index": "client_id"})
)
lrp_eccetricities.head(5)

Unnamed: 0,client_id,layer_1.weight,layer_1.bias,layer_2.weight,layer_2.bias,layer_3.weight,layer_3.bias
0,0_0,0.082396,0.08743,0.081371,0.085921,0.07698,0.007107
1,0_1,0.082933,0.083962,0.082223,0.089027,0.082549,0.068378
2,0_2,0.082773,0.080006,0.083404,0.088593,0.072245,0.006625
3,0_3,0.082982,0.080895,0.082137,0.071678,0.087922,0.048849
4,0_4,0.083531,0.093964,0.082629,0.092593,0.094799,0.000148


In [10]:
# Maps the raw parameter names (and the reset index) to the importance column names.
namings = {"layer_1.weight": "layer1_weight_imp",
           "layer_1.bias": "layer1_bias_imp",
           "layer_2.weight": "layer2_weight_imp",
           "layer_2.bias": "layer2_bias_imp",
           "layer_3.weight": "layer3_weight_imp",
           "layer_3.bias": "layer3_bias_imp",
           "index": "client_id"
           }

layer_importance_scores = {}

for client in client_ids:
    # Rebuild this client's isolated model from its checkpoint.
    iso_model = ShallowNN(features)
    iso_model.load_state_dict(torch.load("checkpt/epoch_25/isolated/batch64_client_" + str(client) + ".pth"))

    # NOTE(review): the data is read from "trainpt/" although it is named
    # validation data — confirm which split is intended.
    validation_data = torch.load("trainpt/" + str(client) + ".pt")
    validation_data_loader = DataLoader(validation_data, batch_size, shuffle=True)

    layer_importance_scores[client] = layer_importance_bias(iso_model, loss_fn, validation_data_loader)

# Clients become rows after the transpose; the index is then exposed as "client_id".
# NOTE: this rebinds `layer_importance`, shadowing the function imported from evals.
layer_importance = (
    pd.DataFrame.from_dict(layer_importance_scores)
    .T
    .reset_index()
    .rename(columns=namings)
)
layer_importance.head(5)

Unnamed: 0,client_id,layer1_weight_imp,layer1_bias_imp,layer2_weight_imp,layer2_bias_imp,layer3_weight_imp,layer3_bias_imp
0,0_0,77.518744,0.571061,19.311343,0.755723,1.616715,0.226413
1,0_1,77.409859,0.572669,19.636817,0.772504,1.367931,0.240221
2,0_2,81.564528,0.545362,16.027508,0.559516,1.14818,0.154906
3,0_3,82.1122,0.575041,15.323519,0.54234,1.256588,0.190313
4,0_4,81.718808,0.5711,15.870728,0.595554,1.075745,0.168065


In [11]:
# Join eccentricities and importances on the client id instead of on row
# position: a positional concat duplicates the "client_id" column and silently
# pairs the wrong rows if the two frames are ever ordered differently.
# lrp_eccetricities is the left frame so the result keeps its row order;
# validate= raises if a client id is duplicated on either side.
full_df = lrp_eccetricities.merge(layer_importance, on="client_id", validate="one_to_one")

In [12]:
# Eccentricity columns in the order the plain mean adds them.
mean_terms = [
    "layer_1.weight", "layer_1.bias",
    "layer_2.weight", "layer_2.bias",
    "layer_3.weight", "layer_3.bias",
]
# (eccentricity column, importance column) pairs in the order the weighted sum adds them.
weighted_terms = [
    ("layer_1.weight", "layer1_weight_imp"),
    ("layer_2.weight", "layer2_weight_imp"),
    ("layer_3.weight", "layer3_weight_imp"),
    ("layer_1.bias", "layer1_bias_imp"),
    ("layer_2.bias", "layer2_bias_imp"),
    ("layer_3.bias", "layer3_bias_imp"),
]

averages = []
weighted_avg = []
for _, row in full_df.iterrows():
    # Unweighted mean over the six parameter groups.
    avg = sum(row[col] for col in mean_terms) / 6
    # Importance-weighted mean; the importances shown above sum to ~100 per client.
    weighted = sum(row[ecc_col] * row[imp_col] for ecc_col, imp_col in weighted_terms) / 100

    averages.append(round(avg, 4))
    weighted_avg.append(round(weighted, 4))

# Assigned by position: relies on full_df having one row per row of lrp_eccetricities.
lrp_eccetricities["average"] = averages
lrp_eccetricities["weighted_avg"] = weighted_avg


In [13]:
# Create the output directory on first use; to_csv does not create missing
# parent directories, so a fresh checkout would otherwise fail here.
os.makedirs("insights", exist_ok=True)
lrp_eccetricities.to_csv("insights/eccentricity_with_lrp_25_epoch.csv" , index=False)

In [17]:
# Sanity check: halve each client's average, add them up, then halve the total.
sum(value / 2 for value in lrp_eccetricities.average) / 2

0.5000249999999999