In [3]:
import os
import time
import torch

import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

from FLTrack.models import ShallowNN
from FLTrack.utils import load_file, get_all_possible_pairs
from FLTrack.evals import evaluate, pairwise_euclidean_distance , influence, calculate_hessian_flattened
from FLTrack.evals import euclidean_distance, manhattan_distance, pairwise_euclidean_distance, accumulated_proximity, full_accumulated_proximity

# Shared configuration for every cell below.
# `features` is passed to ShallowNN(features) — presumably the input
# dimension of the model; confirm against FLTrack.models.
features, batch_size = 197, 64
# L1 loss with the default mean reduction (mean absolute error).
loss_fn = torch.nn.L1Loss(reduction="mean")

In [4]:
# Client IDs have the form "<outer>_<inner>" with outer in 0..3 and
# inner in 0..5, giving 24 clients ordered "0_0", "0_1", ..., "3_5".
client_ids = [
    f"{outer}_{inner}"
    for outer in range(4)
    for inner in range(6)
]

## 500 Global Rounds and 1 Local Round

In [None]:
# Load the global model checkpoint saved under checkpt/epoch_500/.
global_model = ShallowNN(features)
# map_location="cpu": the freshly built model presumably lives on CPU, so map
# the stored tensors there too; this also lets the checkpoint load on hosts
# without a GPU even if it was saved from CUDA tensors.
global_model.load_state_dict(
    torch.load('checkpt/epoch_500/_fedl_global_500.pth', map_location="cpu")
)

In [None]:
def hessian(client, model=None, seed=0):
    """Compute the flattened Hessian of `loss_fn` on one client's training data.

    Args:
        client: client identifier; its dataset is read from ``trainpt/<client>.pt``.
        model: model to differentiate. Defaults to the notebook-level
            ``global_model``, looked up at call time (the original behaviour).
        seed: seed for the DataLoader's shuffle so repeated runs see the same
            batch order.

    Returns:
        The Hessian produced by ``calculate_hessian_flattened``.
    """
    if model is None:
        model = global_model

    dataset = torch.load(f"trainpt/{client}.pt")
    # Shuffling was unseeded, so any batch-dependent result changed between
    # runs; a seeded generator keeps shuffling but makes it reproducible.
    data_loader = DataLoader(
        dataset,
        batch_size,
        shuffle=True,
        generator=torch.Generator().manual_seed(seed),
    )

    result = calculate_hessian_flattened(model, loss_fn, data_loader)
    # NOTE(review): this cell originally unpacked a 2-tuple, while other cells
    # in the notebook use the helper's return value directly. Accept both forms
    # and keep only the Hessian — confirm the helper's actual return type.
    return result[0] if isinstance(result, tuple) else result

In [None]:
# Compute and save one Hessian per client for the 500-round global model.
out_dir = "hessians/epoch_500/train_fed"
# torch.save does not create missing parent folders; create it once up front.
os.makedirs(out_dir, exist_ok=True)

for client in client_ids:
    fed_hessian_mat = hessian(client)
    torch.save(fed_hessian_mat, f"{out_dir}/{client}.pth")
    print("Client " + str(client) + " done.")

## 1 Global Round and 25 Local Rounds

In [None]:
# Load the global model checkpoint saved under
# checkpt/epoch_25/1_rounds_25_epochs_per_round/.
global_model = ShallowNN(features)
# map_location="cpu": the freshly built model presumably lives on CPU, so map
# the stored tensors there too; this also lets the checkpoint load on hosts
# without a GPU even if it was saved from CUDA tensors.
global_model.load_state_dict(
    torch.load('checkpt/epoch_25/1_rounds_25_epochs_per_round/_fedl_global_1_25.pth', map_location="cpu")
)

In [None]:
# NOTE(review): an earlier cell already defines `hessian`; this definition
# shadows it. Consider keeping a single definition near the top.
def hessian(client, model=None, seed=0):
    """Compute the flattened Hessian of `loss_fn` on one client's training data.

    Args:
        client: client identifier; its dataset is read from ``trainpt/<client>.pt``.
        model: model to differentiate. Defaults to the notebook-level
            ``global_model``, looked up at call time (the original behaviour).
        seed: seed for the DataLoader's shuffle so repeated runs see the same
            batch order.

    Returns:
        The Hessian produced by ``calculate_hessian_flattened``.
    """
    if model is None:
        model = global_model

    dataset = torch.load(f"trainpt/{client}.pt")
    # Shuffling was unseeded, so any batch-dependent result changed between
    # runs; a seeded generator keeps shuffling but makes it reproducible.
    data_loader = DataLoader(
        dataset,
        batch_size,
        shuffle=True,
        generator=torch.Generator().manual_seed(seed),
    )

    result = calculate_hessian_flattened(model, loss_fn, data_loader)
    # NOTE(review): this cell originally unpacked a 2-tuple, while other cells
    # in the notebook use the helper's return value directly. Accept both forms
    # and keep only the Hessian — confirm the helper's actual return type.
    return result[0] if isinstance(result, tuple) else result

In [None]:
# Compute and save one Hessian per client for the 1-round / 25-epoch global model.
out_dir = "hessians/epoch_25/train_fed"
# torch.save does not create missing parent folders; create it once up front.
os.makedirs(out_dir, exist_ok=True)

for client in client_ids:
    fed_hessian_mat = hessian(client)
    torch.save(fed_hessian_mat, f"{out_dir}/{client}.pth")
    print("Client " + str(client) + " done.")

### Hessian Matrices for Isolated Client Models

In [13]:
def hessian_iso(client, ckpt_dir="checkpt/saving/epoch_500/global_1/clients", seed=0):
    """Compute the flattened Hessian of `loss_fn` for a client's own (isolated) model.

    Args:
        client: client identifier; its dataset is read from ``trainpt/<client>.pt``
            and its model weights from ``<ckpt_dir>/client_model_<client>.pth``.
        ckpt_dir: folder holding the per-client checkpoints. The default is the
            path this cell always used.
        seed: seed for the DataLoader's shuffle so repeated runs see the same
            batch order.

    Returns:
        The Hessian produced by ``calculate_hessian_flattened``.
    """
    dataset = torch.load(f"trainpt/{client}.pt")
    # Seeded shuffling: same batches on every run, so saved Hessians are reproducible.
    data_loader = DataLoader(
        dataset,
        batch_size,
        shuffle=True,
        generator=torch.Generator().manual_seed(seed),
    )

    client_model = ShallowNN(features)
    # map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts.
    client_model.load_state_dict(
        torch.load(f"{ckpt_dir}/client_model_{client}.pth", map_location="cpu")
    )

    result = calculate_hessian_flattened(client_model, loss_fn, data_loader)
    # NOTE(review): another cell unpacks this helper's result as a 2-tuple.
    # Normalise so only the Hessian is returned and saved — confirm the helper's
    # actual return type.
    return result[0] if isinstance(result, tuple) else result

In [14]:
# Compute and save one Hessian per client from that client's isolated model.
# NOTE(review): `hessian_iso` loads client checkpoints from an epoch_500
# folder, but results are written under hessians/epoch_25/ — confirm the
# intended output folder.
out_dir = "hessians/epoch_25/iso"
# torch.save does not create missing parent folders; create it once up front.
os.makedirs(out_dir, exist_ok=True)

for client in client_ids:
    iso_hessian_mat = hessian_iso(client)
    torch.save(iso_hessian_mat, f"{out_dir}/{client}.pth")
    print("Client " + str(client) + " done.")

Calculation time of Gradients 0.01863384246826172
Calculation time of Hessian 9.779759017626445
Client 0_0 done.
Calculation time of Gradients 0.020355939865112305
Calculation time of Hessian 10.56281951268514
Client 0_1 done.
Calculation time of Gradients 0.01831197738647461
Calculation time of Hessian 10.428972816467285
Client 0_2 done.
Calculation time of Gradients 0.01817917823791504
Calculation time of Hessian 10.06779448588689
Client 0_3 done.
Calculation time of Gradients 0.0178830623626709
Calculation time of Hessian 9.845117966334024
Client 0_4 done.
Calculation time of Gradients 0.01825714111328125
Calculation time of Hessian 10.192091766993205
Client 0_5 done.
Calculation time of Gradients 0.02028799057006836
Calculation time of Hessian 10.894143664836884
Client 1_0 done.
Calculation time of Gradients 0.019181013107299805
Calculation time of Hessian 10.317794954776764
Client 1_1 done.
Calculation time of Gradients 0.01853799819946289
Calculation time of Hessian 10.3239126801

## Hessians per Global Round

In [3]:
# NOTE(review): `hessian` is defined more than once in this notebook; this
# definition shadows the earlier ones.
def hessian(client, model=None, seed=0):
    """Compute the flattened Hessian of `loss_fn` on one client's training data.

    Args:
        client: client identifier; its dataset is read from ``trainpt/<client>.pt``.
        model: model to differentiate. Defaults to the notebook-level
            ``fl_model_ckpt``. That name is assigned in a later cell and is
            looked up at call time, so call this only after that cell has run.
        seed: seed for the DataLoader's shuffle so repeated runs see the same
            batch order.

    Returns:
        The Hessian produced by ``calculate_hessian_flattened``.
    """
    if model is None:
        model = fl_model_ckpt

    dataset = torch.load(f"trainpt/{client}.pt")
    # Seeded shuffling: same batches on every run, so saved Hessians are reproducible.
    data_loader = DataLoader(
        dataset,
        batch_size,
        shuffle=True,
        generator=torch.Generator().manual_seed(seed),
    )

    result = calculate_hessian_flattened(model, loss_fn, data_loader)
    # NOTE(review): another cell unpacks this helper's result as a 2-tuple.
    # Normalise so only the Hessian is returned — confirm the helper's actual
    # return type.
    return result[0] if isinstance(result, tuple) else result

In [4]:
# NOTE(review): commented-out override of `client_ids` (it omits 0_0 and 0_1).
# Re-enabling it would silently restrict every later loop; delete this cell or
# turn it into an explicit option in the config cell.
#client_ids = ['0_2','0_3','0_4','0_5','1_0', '1_1', '1_2', '1_3', '1_4', '1_5', '2_0','2_1', '2_2','2_3','2_4','2_5','3_0','3_1','3_2','3_3','3_4','3_5']

In [5]:
# Which saved global round's aggregated model to analyse.
global_round = 20

# `hessian` (defined above) reads this notebook-level name.
fl_model_ckpt = ShallowNN(features)
# map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts.
fl_model_ckpt.load_state_dict(
    torch.load(f"checkpt/saving/epoch_500/global_{global_round}/global_model.pth", map_location="cpu")
)

# Create the output folder once. exist_ok=True covers the already-exists case,
# so the previous exists/else branches (which both saved) are unnecessary.
out_dir = f"hessians/saving/{global_round}"
os.makedirs(out_dir, exist_ok=True)

for client in client_ids:
    fed_hessian_mat = hessian(client)
    path = f"{out_dir}/{client}.pth"
    torch.save(fed_hessian_mat, path)
    print("Client " + str(client) + " done.")

Calculation time of Gradients 0.04049396514892578
Calculation time of Hessian 9.078031420707703
Client 0_0 done.
Calculation time of Gradients 0.017895221710205078
Calculation time of Hessian 9.412500635782878
Client 0_1 done.
Calculation time of Gradients 0.017027854919433594
Calculation time of Hessian 9.453154913584392
Client 0_2 done.
Calculation time of Gradients 0.016679763793945312
Calculation time of Hessian 37.06908006668091
Client 0_3 done.
Calculation time of Gradients 0.016553878784179688
Calculation time of Hessian 9.00583846171697
Client 0_4 done.
Calculation time of Gradients 0.016477108001708984
Calculation time of Hessian 8.999890299638112
Client 0_5 done.
Calculation time of Gradients 0.018182039260864258
Calculation time of Hessian 9.87904949982961
Client 1_0 done.
Calculation time of Gradients 0.017878055572509766
Calculation time of Hessian 10.864837551116944
Client 1_1 done.
Calculation time of Gradients 0.018316984176635742
Calculation time of Hessian 9.985280315