In [1]:
import os
import time
import torch
import collections

import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

from models import ShallowNN
from utils import load_file, get_all_possible_pairs
from evals import evaluate, pairwise_euclidean_distance , influence, layer_importance, layer_importance_bias, layerwise_full_accumulated_proximity
from evals import euclidean_distance, pairwise_euclidean_distance, accumulated_proximity

# Experiment configuration shared by the cells below.
# NOTE(review): checkpoints are read from "batch64_*" paths; presumably they were
# trained with this batch_size — confirm before changing it.
features, batch_size = 197, 64   # ShallowNN input width; DataLoader batch size
loss_fn = torch.nn.L1Loss()      # loss handed to layer_importance_bias

In [2]:
# 24 client ids of the form "<i>_<j>" with i in 0..3 and j in 0..5, ordered by i then j.
client_ids = [f"{i}_{j}" for i in range(4) for j in range(6)]

In [3]:
# Load every client's isolated checkpoint exactly once; later cells reuse these dicts.
state_dicts = {
    key: torch.load("checkpt/isolated/batch64_client_" + str(key) + ".pth")
    for key in client_ids
}

# The parameter names ("criteria", e.g. "layer_1.weight") are taken from the first
# client's checkpoint instead of loading that same file a second time.
critarians = list(state_dicts[client_ids[0]])

# Every later comparison indexes each client's dict by these names, so fail early
# and clearly if any checkpoint has a different set of parameters.
assert all(set(sd) == set(critarians) for sd in state_dicts.values()), (
    "client checkpoints do not share the same parameter names"
)

In [4]:
def layerwise_proximity(
    x: collections.OrderedDict,
    y: collections.OrderedDict,
    critarian: str,
    distance_matrix,
):
    """
    Proximity between the tensor stored under the same key in two state dicts.

    Parameters:
    -------------
    x, y: collections.OrderedDict; state dicts to compare.
    critarian: str; key of the tensor to compare (e.g. "layer_1.bias").
    distance_matrix: Callable; forwarded unchanged to ``accumulated_proximity``.

    Returns:
    -------------
    Whatever ``accumulated_proximity`` returns for the two tensors.

    Keys whose last dotted component is ``"bias"`` are reshaped with
    ``view(1, -1)`` into a single row first; all other tensors are passed as stored.
    """
    left, right = x[critarian], y[critarian]
    if critarian.rpartition(".")[2] == "bias":
        # Flatten bias entries into one row before measuring proximity.
        left, right = left.view(1, -1), right.view(1, -1)
    return accumulated_proximity(left, right, distance_matrix)

In [5]:
def layerwise_full_accumulated_proximity(
    clients: list, criterian: str, distance_matrix, state_dicts: "dict | None" = None
):
    """
    Sum the layer-wise proximity over every ordered pair of clients.

    For each ``(a, b)`` in ``clients x clients`` (including ``a == b`` and both
    orientations of every pair) ``layerwise_proximity`` is evaluated on the
    parameter named ``criterian``, and the results are added up. This double sum
    is the denominator of the eccentricity ``2 * row_sum / total``.

    Note: this definition shadows the ``layerwise_full_accumulated_proximity``
    imported from ``evals``.

    Parameters:
    -------------
    clients: list; client ids, used as keys into ``state_dicts``.
    criterian: str; parameter (layer) name to evaluate, e.g. "layer_1.weight".
    distance_matrix: Callable; forwarded unchanged to ``layerwise_proximity``.
    state_dicts: dict, optional; mapping client id -> already-loaded state dict.
        When omitted, each client's checkpoint is loaded from
        "checkpt/isolated/batch64_client_<id>.pth". Passing the dicts avoids
        re-reading every checkpoint on each call.

    Returns:
    -------------
    The accumulated proximity: a single total (not a weight/bias tuple). It
    starts from 0.0, so an empty ``clients`` list yields 0.0.
    """
    if state_dicts is None:
        state_dicts = {
            key: torch.load("checkpt/isolated/batch64_client_" + str(key) + ".pth")
            for key in clients
        }

    total_proximity = 0.0
    for first in clients:
        for second in clients:
            total_proximity += layerwise_proximity(
                state_dicts[first], state_dicts[second], criterian, distance_matrix
            )

    return total_proximity

In [None]:
# Total proximity over all ordered client pairs, one entry per parameter tensor;
# used as the normaliser of the eccentricity computation.
full_prox_dict = {
    criterion: layerwise_full_accumulated_proximity(client_ids, criterion, euclidean_distance)
    for criterion in critarians
}

In [None]:
# Eccentricity of each client for each parameter tensor c:
#   ecc(client) = 2 * sum_j proximity(client, j) / full_prox_dict[c]
# Because full_prox_dict[c] is the same sum taken over all clients, the
# eccentricities for one criterion should add up to (approximately) 2.
eccentricities = {}
for c in critarians:
    client_ecc = {}
    for client in client_ids:
        # Reuse the checkpoint already cached in `state_dicts` (same file path)
        # instead of re-reading it from disk for every criterion.
        client_matrix = state_dicts[client]
        acc_proximity = 0.0
        for key in state_dicts:
            acc_proximity += layerwise_proximity(client_matrix, state_dicts[key], c, euclidean_distance)
        eccentricity = 2 * acc_proximity / full_prox_dict[c]
        client_ecc[client] = eccentricity.item()
    eccentricities[c] = client_ecc

In [None]:
# Dict of dicts -> one column per parameter tensor, one row per client id.
lrp_eccetricities = pd.DataFrame(eccentricities)
lrp_eccetricities.head()

In [None]:
# Per-client layer importance of each isolated model on that client's data.
# NOTE: the `layer_importance` DataFrame built below rebinds the name of the
# function imported from `evals`; later cells use it as this DataFrame.
layer_importance_scores = {}

for client in client_ids:
    iso_model = ShallowNN(features)
    # Same checkpoint path as the cached `state_dicts`. load_state_dict copies
    # the values into the model's own parameters, so the cache is not modified.
    iso_model.load_state_dict(state_dicts[client])
    # NOTE(review): this reads from "trainpt/", presumably the client's training
    # split rather than a held-out validation split — confirm.
    validation_data = torch.load("trainpt/" + str(client) + ".pt")
    # Evaluation only: a fixed batch order keeps the scores reproducible
    # (shuffle=True with no seed made every run differ).
    validation_data_loader = DataLoader(validation_data, batch_size, shuffle=False)
    layer_importance_scores[client] = layer_importance_bias(iso_model, loss_fn, validation_data_loader)

layer_importance = pd.DataFrame.from_dict(layer_importance_scores).T
layer_importance.head(5)

In [None]:
# Side-by-side view per parameter tensor: the sum of the client eccentricities
# (rounded to 4 places) and the mean layer importance across clients.
# Selecting `critarians` ignores any aggregate columns added to lrp_eccetricities later.
layer_summary = pd.DataFrame(
    {
        "eccentricity_sum": lrp_eccetricities[critarians].sum().round(4),
        "mean_layer_importance": layer_importance[critarians].mean(),
    }
)
layer_summary

In [None]:
# Hard-coded layer-importance weights (percent; the six values sum to 100, matching
# the /100 in the weighted average below).
# TODO: presumably transcribed from the mean layer-importance output above —
# derive them from `layer_importance` instead of copying numbers by hand.
(
    layer_1_weight_ls, layer_1_bias_ls,
    layer_2_weight_ls, layer_2_bias_ls,
    layer_3_weight_ls, layer_3_bias_ls,
) = (81.60, 0.76, 16.57, 0.54, 0.47, 0.06)

In [None]:
# Per-client aggregates of the six per-tensor eccentricities, computed column-wise
# (same additions in the same order and same numpy rounding as a row loop, without iterrows):
#   average      - unweighted mean of the six columns
#   weighted_avg - weighted by the layer_*_ls constants; dividing by 100 assumes
#                  those weights are percentages that sum to 100.
lrp_eccetricities["average"] = (
    (
        lrp_eccetricities["layer_1.weight"]
        + lrp_eccetricities["layer_1.bias"]
        + lrp_eccetricities["layer_2.weight"]
        + lrp_eccetricities["layer_2.bias"]
        + lrp_eccetricities["layer_3.weight"]
        + lrp_eccetricities["layer_3.bias"]
    )
    / 6
).round(4)
lrp_eccetricities["weighted_avg"] = (
    (
        lrp_eccetricities["layer_1.weight"] * layer_1_weight_ls
        + lrp_eccetricities["layer_1.bias"] * layer_1_bias_ls
        + lrp_eccetricities["layer_2.weight"] * layer_2_weight_ls
        + lrp_eccetricities["layer_2.bias"] * layer_2_bias_ls
        + lrp_eccetricities["layer_3.weight"] * layer_3_weight_ls
        + lrp_eccetricities["layer_3.bias"] * layer_3_bias_ls
    )
    / 100
).round(4)

In [None]:
# Flip to True to persist the table. The index holds the client ids, so it is
# written out (labelled "client") instead of being dropped with index=False.
SAVE_ECCENTRICITY_CSV = False
if SAVE_ECCENTRICITY_CSV:
    os.makedirs("insights", exist_ok=True)
    lrp_eccetricities.to_csv("insights/eccentricity_with_lrp.csv", index_label="client")
lrp_eccetricities