In [2]:
import os
import time
import torch

import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

from FLTrack.models import ShallowNN
from FLTrack.utils import load_file, get_all_possible_pairs
from FLTrack.evals import influence, calculate_hessian_flattened, calculate_hessian_gnewtons

# Shared configuration for every cell below.
features, batch_size = 169, 64  # ShallowNN input width, DataLoader batch size
loss_fn = torch.nn.HuberLoss()  # criterion whose Hessian is computed

In [3]:
client_ids = ["c" + str(n) for n in range(1, 25)]  # NOTE(review): c1..c24 (24 clients) — confirm there is no c25

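## 500 Global Rounds and 1 Local Round

> **Note:** the checkpoint loaded below comes from `epoch_250/25_rounds_10_epochs_per_round`, which does not match this title; confirm which training configuration is intended.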

In [None]:
# Federated global model read by `hessian` in this section.
global_model = ShallowNN(features)
global_model.load_state_dict(
    torch.load('FLTrack/checkpt/fedl/epoch_250/25_rounds_10_epochs_per_round/global_model.pth')
)

In [None]:
def hessian(client, model=None):
    """Compute the flattened Hessian of ``loss_fn`` on one client's training data.

    Parameters
    ----------
    client : str
        Client identifier (e.g. ``"c1"``); the dataset is read from
        ``trainpt/<client>.pt``.
    model : torch.nn.Module, optional
        Model to differentiate. Defaults to the notebook-level ``global_model``,
        looked up at call time so a reloaded ``global_model`` is picked up.

    Returns
    -------
    The Hessian produced by ``calculate_hessian_flattened``. If that helper
    returns a tuple, only its first element is returned.
    """
    if model is None:
        model = global_model

    dataset = torch.load(f"trainpt/{client}.pt")
    # Use the configured batch size instead of a hard-coded 64 (same value today).
    data_loader = DataLoader(dataset, batch_size, shuffle=True)

    result = calculate_hessian_flattened(model, loss_fn, data_loader)
    # NOTE(review): another cell in this notebook unpacks two values from
    # calculate_hessian_flattened, while this call used the return value directly.
    # Accept both conventions so callers (and saved files) always get the Hessian.
    if isinstance(result, tuple):
        result = result[0]
    return result

In [None]:
h = hessian(client="c1")  # one-off Hessian for a single client

In [None]:
h1 = hessian(client="c1")  # NOTE(review): same client as `h`; presumably a repeatability check (the loader shuffles)

In [None]:
# Display `h1`, the second Hessian computed for client c1.
h1

In [None]:
# Save one federated-model Hessian per client. Create the output folder first so
# torch.save does not fail on a fresh checkout where it does not exist yet.
os.makedirs("hessians/epoch_500/train_fed", exist_ok=True)
for client in client_ids:
    fed_hessian_mat = hessian(client)
    torch.save(fed_hessian_mat, f"hessians/epoch_500/train_fed/{client}.pth")
    print(f"Client {client} done.")

## 1 Global Round and 25 Local Rounds

In [None]:
# Global model after a single federated round of 25 local epochs.
global_model = ShallowNN(features)
global_model.load_state_dict(
    torch.load('checkpt/epoch_25/1_rounds_25_epochs_per_round/_fedl_global_1_25.pth')
)

In [None]:
def hessian(client, model=None):
    """Compute the flattened Hessian of ``loss_fn`` on one client's training data.

    Parameters
    ----------
    client : str
        Client identifier (e.g. ``"c1"``); the dataset is read from
        ``trainpt/<client>.pt``.
    model : torch.nn.Module, optional
        Model to differentiate. Defaults to the notebook-level ``global_model``,
        looked up at call time.

    Returns
    -------
    The Hessian produced by ``calculate_hessian_flattened`` (the first element
    when that helper returns a tuple).
    """
    if model is None:
        model = global_model

    dataset = torch.load(f"trainpt/{client}.pt")
    data_loader = DataLoader(dataset, batch_size, shuffle=True)

    result = calculate_hessian_flattened(model, loss_fn, data_loader)
    # NOTE(review): other cells use the helper's return value directly. Unpacking
    # a bare Hessian tensor would fail, so accept either a tensor or a
    # (hessian, extra) tuple.
    if isinstance(result, tuple):
        result = result[0]
    return result

In [None]:
# Save one federated-model Hessian per client. Create the output folder first so
# torch.save does not fail when it is missing.
os.makedirs("hessians/epoch_25/train_fed", exist_ok=True)
for client in client_ids:
    fed_hessian_mat = hessian(client)
    torch.save(fed_hessian_mat, f"hessians/epoch_25/train_fed/{client}.pth")
    print(f"Client {client} done.")

### Hessian Matrices for the Isolated (Per-Client) Models

In [None]:
def hessian_iso(client):
    """Compute the flattened Hessian of ``loss_fn`` for a client's own (isolated) model.

    Parameters
    ----------
    client : str
        Client identifier (e.g. ``"c1"``). Training data is read from
        ``trainpt/<client>.pt`` and model weights from
        ``checkpt/saving/epoch_500/global_1/clients/client_model_<client>.pth``.

    Returns
    -------
    The Hessian produced by ``calculate_hessian_flattened`` (the first element
    when that helper returns a tuple).
    """
    dataset = torch.load(f"trainpt/{client}.pt")
    data_loader = DataLoader(dataset, batch_size, shuffle=True)

    client_model = ShallowNN(features)
    client_model.load_state_dict(
        torch.load(f"checkpt/saving/epoch_500/global_1/clients/client_model_{client}.pth")
    )

    result = calculate_hessian_flattened(client_model, loss_fn, data_loader)
    # NOTE(review): another cell in this notebook unpacks two values from
    # calculate_hessian_flattened. Normalize so the saved file holds the Hessian itself.
    if isinstance(result, tuple):
        result = result[0]
    return result

In [None]:
# Save each client's isolated-model Hessian. Create the folder first so
# torch.save does not fail when it is missing.
# NOTE(review): hessian_iso loads client checkpoints from an `epoch_500` directory,
# but results are written under `epoch_25`. Confirm the intended output folder.
os.makedirs("hessians/epoch_25/iso", exist_ok=True)
for client in client_ids:
    # Holds the isolated-model Hessian despite the `fed_` prefix. The name is kept
    # in case later cells read it.
    fed_hessian_mat = hessian_iso(client)
    torch.save(fed_hessian_mat, f"hessians/epoch_25/iso/{client}.pth")
    print(f"Client {client} done.")

## Hessians at Each Global Round

In [2]:
def hessian(client, model=None):
    """Compute a Hessian via ``calculate_hessian_gnewtons`` on one client's training data.

    Parameters
    ----------
    client : str
        Client identifier (e.g. ``"c1"``); the dataset is read from
        ``trainpt/<client>.pt``.
    model : torch.nn.Module, optional
        Model to differentiate. Defaults to the notebook-level ``fl_model_ckpt``.
        It is looked up at call time, not definition time, because that global
        is rebound for every global round before this function is called.

    Returns
    -------
    Whatever ``calculate_hessian_gnewtons`` returns for this client.
    """
    if model is None:
        model = fl_model_ckpt

    dataset = torch.load(f"trainpt/{client}.pt")
    data_loader = DataLoader(dataset, batch_size, shuffle=True)
    return calculate_hessian_gnewtons(model, loss_fn, data_loader)

In [3]:
client_ids = list(map("c{}".format, range(1, 25)))  # c1 ... c24

In [4]:
# For each saved global round (1..25, matching the `25_rounds_...` checkpoint
# folder), compute and save every client's Hessian against that round's global model.
for global_round in range(1, 26):
    # `hessian` reads the notebook-level `fl_model_ckpt`, so rebind it here
    # before the inner loop runs.
    fl_model_ckpt = ShallowNN(features)
    fl_model_ckpt.load_state_dict(torch.load(
        f'FLTrack/checkpt/fedl/epoch_250/25_rounds_10_epochs_per_round/global_{global_round}/global_model.pth'
    ))

    # One output directory per round. exist_ok=True makes re-runs safe and
    # replaces the previous exists/else branches, which both saved the same way.
    round_dir = f"hessians/saving/gn/{global_round}"
    os.makedirs(round_dir, exist_ok=True)

    for client in client_ids:
        fed_hessian_mat = hessian(client)
        path = f"{round_dir}/{client}.pth"
        torch.save(fed_hessian_mat, path)
        print(f"Client {client} done.")

Calculation time of Gradients 0.06399679183959961
Calculation time of Hessian 0.001777644952138265
Client c1 done.
Calculation time of Gradients 0.03069901466369629
Calculation time of Hessian 0.001176166534423828
Client c2 done.
Calculation time of Gradients 0.021516084671020508
Calculation time of Hessian 0.000678865114847819
Client c3 done.
Calculation time of Gradients 0.01970672607421875
Calculation time of Hessian 0.0006606340408325196
Client c4 done.
Calculation time of Gradients 0.019054174423217773
Calculation time of Hessian 0.0006354848543802897
Client c5 done.
Calculation time of Gradients 0.020772218704223633
Calculation time of Hessian 0.0006634354591369629
Client c6 done.
Calculation time of Gradients 0.021872758865356445
Calculation time of Hessian 0.0006889859835306803
Client c7 done.
Calculation time of Gradients 0.01848888397216797
Calculation time of Hessian 0.0006191492080688477
Client c8 done.
Calculation time of Gradients 0.008349895477294922
Calculation time of 