In [1]:
import argparse
import json
import logging
import random
from collections import defaultdict, OrderedDict
from pathlib import Path

import numpy as np
import torch
import torch.utils.data
import experiments
import model
from experiments.pfedhn_pc.utils import get_average_model, weighted_aggregate_model
from tqdm import trange

from experiments.pfedhn_pc.models import CNNHyperPC, CNNTargetPC, CNNTargetPC_M, LocalLayer
from experiments.pfedhn_pc.node import BaseNodesForLocal, BaseNodesForLocals_M
from experiments.utils import get_device, set_logger, set_seed, str2bool

import copy
from model import *

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [3]:
import imp
imp.reload(experiments)
imp.reload(model)
from model import *
from experiments.pfedhn_pc.utils import get_average_model, weighted_aggregate_model
from experiments.pfedhn_pc.node import BaseNodesForLocal, BaseNodesForLocals_M
from experiments.utils import get_device, set_logger, set_seed, str2bool
from experiments.pfedhn_pc.models import CNNHyperPC, CNNTargetPC, CNNTargetPC_M, LocalLayer

In [4]:
data_name = 'cifar10'

if data_name == 'cifar10':
    classes_per_node = 2
else:
    classes_per_node = 10

data_path = 'data'
num_nodes = 50
num_steps = 5000
inner_steps = 50
optim = 'sgd'
lr = 5e-2
inner_lr = 5e-3 
embed_lr = None
wd = 1e-3
inner_wd = 5e-5
embed_dim = -1
batch_size = 64
eval_every = 2
save_path = "pfedhn_pc_cifar_res"


In [5]:

nodes = BaseNodesForLocals_M(
    data_name=data_name,
    data_path=data_path,
    n_nodes=num_nodes,
    base_model=CNNTargetPC_M,
    layer_config={'in_channels':3},
    base_optimizer=torch.optim.SGD, optimizer_config=dict(lr=inner_lr, momentum=.9, weight_decay=inner_wd),
    device=device,
    batch_size=batch_size,
    classes_per_node=classes_per_node,
)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
import logging
from tqdm import trange

In [7]:
embed_dim = embed_dim    
if embed_dim == -1:
    logging.info("auto embedding size")
    embed_dim = int(1 + num_nodes / 4)

inet = CNNTargetPC_M()
#net = CNNTargetPC(n_kernels=n_kernels)

inet = inet.to(device)
#net = net.to(device)#unnecessary

##################
# init optimizer #
##################
embed_lr = embed_lr if embed_lr is not None else lr
optimizers = {
    'sgd': torch.optim.SGD(
        [
            {'params': [p for n, p in inet.named_parameters() if 'embed' not in n]},
            {'params': [p for n, p in inet.named_parameters() if 'embed' in n], 'lr': embed_lr}
        ], lr=lr, momentum=0.9, weight_decay=wd
    ),
    'adam': torch.optim.Adam(params=inet.parameters(), lr=lr)
}
optimizer = optimizers[optim]
criteria = torch.nn.CrossEntropyLoss()

################
# init metrics #
################
last_eval = -1
best_step = -1
best_acc = -1
test_best_based_on_step, test_best_min_based_on_step = -1, -1
test_best_max_based_on_step, test_best_std_based_on_step = -1, -1
step_iter = trange(num_steps)

results = defaultdict(list)

net_keys = [*inet.state_dict().keys()]
base_layer_keys = net_keys[:-2]
per_layer_keys = net_keys[-2:]

net_values = [*inet.state_dict().values()]
base_values = net_values[:-2]
per_values = net_values[-2:]



  0%|          | 0/5000 [00:00<?, ?it/s]

In [8]:
server = Server(base_values, client_list=nodes.models, num_client=100, 
                pre_update_eps=100, per_layer=per_values, num_cluster=10)

In [9]:
import logging
logger = logging.getLogger()
logger.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(lineno)s - %(levelname)s - %(message)s')
fhandler = logging.FileHandler(filename='test_log.log', mode='a')
fhandler.setFormatter(formatter)
logger.addHandler(fhandler)
import sys
consoleHandler = logging.StreamHandler(sys.stdout)
consoleHandler.setFormatter(formatter)
logger.addHandler(consoleHandler)

In [10]:
def eval_model(nodes, num_nodes, net, criteria, device, split):
    curr_results = evaluate(nodes, num_nodes, net, criteria, device, split=split)
    total_correct = sum([val['correct'] for val in curr_results.values()])
    total_samples = sum([val['total'] for val in curr_results.values()])
    avg_loss = np.mean([val['loss'] for val in curr_results.values()])
    avg_acc = total_correct / total_samples

    all_acc = [val['correct'] / val['total'] for val in curr_results.values()]

    return curr_results, avg_loss, avg_acc, all_acc


@torch.no_grad()
def evaluate(nodes: BaseNodesForLocals_M, num_nodes, net, criteria, device, split='test'):
    net.eval()
    results = defaultdict(lambda: defaultdict(list))

    for node_id in range(num_nodes):  # iterating over nodes

        running_loss, running_correct, running_samples = 0., 0., 0.
        if split == 'test':
            curr_data = nodes.test_loaders[node_id]
        elif split == 'val':
            curr_data = nodes.val_loaders[node_id]
        else:
            curr_data = nodes.train_loaders[node_id]

        for batch_count, batch in enumerate(curr_data):
            img, label = tuple(t.to(device) for t in batch)
            pred = nodes.models[node_id](img)
            running_loss += criteria(pred, label).item()
            running_correct += pred.argmax(1).eq(label).sum().item()
            running_samples += len(label)

        results[node_id]['loss'] = running_loss / (batch_count + 1)
        results[node_id]['correct'] = running_correct
        results[node_id]['total'] = running_samples

    return results

In [11]:

list_node_ids = random.sample(range(num_nodes), 20)
print(list_node_ids)

[14, 45, 15, 33, 10, 18, 40, 42, 17, 11, 19, 35, 32, 8, 46, 2, 38, 47, 48, 5]


In [12]:
logger.debug('This is a debug message')
logger.info('This is an info message')
logger.warning('This is a warning message')
logger.error('This is an error message')
logger.critical('This is a critical message')


INFO:root:This is an info message


2022-05-13 14:26:39,685 - root - 2 - INFO - This is an info message






ERROR:root:This is an error message


2022-05-13 14:26:39,687 - root - 4 - ERROR - This is an error message


CRITICAL:root:This is a critical message


2022-05-13 14:26:39,688 - root - 5 - CRITICAL - This is a critical message


In [13]:


for step in range(10):
    inet.train()

    # each client load global weights
    nodes.client_load_weights(inet.state_dict())
    # select client at random
    list_node_ids = random.sample(range(num_nodes), 20)

    # NOTE: evaluation on sent model
    with torch.no_grad():
        pred_list, prvs_acc_list, prvs_loss_list = [],[],[]

        for node_id in list_node_ids:
            nodes.models[node_id].eval()
            batch = next(iter(nodes.test_loaders[node_id]))
            img, label = tuple(t.to(device) for t in batch)

            pred = nodes.models[node_id](img)
            pred_list.append(pred)

            prvs_loss = criteria(pred, label)
            prvs_loss_list.append(prvs_loss)
            prvs_acc = pred.argmax(1).eq(label).sum().item() / len(label)
            prvs_acc_list.append(prvs_acc)

    # inner updates -> obtaining theta_tilda
    for i in range(inner_steps):
        for node_id in list_node_ids:
            nodes.models[node_id].train()
            nodes.local_optimizers[node_id].zero_grad()

            batch = next(iter(nodes.train_loaders[node_id]))
            img, label = tuple(t.to(device) for t in batch)

            pred = nodes.models[node_id](img)

            loss = criteria(pred, label)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(nodes.models[node_id].parameters(), 50)
            nodes.local_optimizers[node_id].step()

    inet.load_state_dict(get_average_model(nodes.models))


    for i in range(20):
        node_id = list_node_ids[i]
        prvs_loss = prvs_loss_list[i]
        prvs_acc = prvs_acc_list[i]
        step_iter.set_description(
            f"Step: {step+1}, Node ID: {node_id}, Loss: {prvs_loss:.4f},  Acc: {prvs_acc:.4f}"
        )

    if step % eval_every == 0:
        last_eval = step
        step_results, avg_loss, avg_acc, all_acc = eval_model(
            nodes, num_nodes, inet, criteria, device, split="test"
        )
        logging.info(f"\nStep: {step+1}, AVG Loss: {avg_loss:.4f},  AVG Acc: {avg_acc:.4f}")

        results['test_avg_loss'].append(avg_loss)
        results['test_avg_acc'].append(avg_acc)

        _, val_avg_loss, val_avg_acc, _ = eval_model(nodes, num_nodes, inet, criteria, device, split="val")
        if best_acc < val_avg_acc:
            best_acc = val_avg_acc
            best_step = step
            test_best_based_on_step = avg_acc
            test_best_min_based_on_step = np.min(all_acc)
            test_best_max_based_on_step = np.max(all_acc)
            test_best_std_based_on_step = np.std(all_acc)

        results['val_avg_loss'].append(val_avg_loss)
        results['val_avg_acc'].append(val_avg_acc)
        results['best_step'].append(best_step)
        results['best_val_acc'].append(best_acc)
        results['best_test_acc_based_on_val_beststep'].append(test_best_based_on_step)
        results['test_best_min_based_on_step'].append(test_best_min_based_on_step)
        results['test_best_max_based_on_step'].append(test_best_max_based_on_step)
        results['test_best_std_based_on_step'].append(test_best_std_based_on_step)



Step: 1, Node ID: 20, Loss: 2.3901,  Acc: 0.0000:   0%|          | 0/5000 [00:16<?, ?it/s]INFO:root:
Step: 1, AVG Loss: 1033224.3371,  AVG Acc: 0.3015


2022-05-13 14:26:58,012 - root - 58 - INFO - 
Step: 1, AVG Loss: 1033224.3371,  AVG Acc: 0.3015


Step: 2, Node ID: 37, Loss: 92277080.0000,  Acc: 0.0000:   0%|          | 0/5000 [00:33<?, ?it/s]

KeyboardInterrupt: 

In [None]:
print(nodes.models[])
logging.info('Python info')

In [None]:

if step != last_eval:
    _, val_avg_loss, val_avg_acc, _ = eval_model(nodes, num_nodes, inet, net, criteria, device, split="val")
    step_results, avg_loss, avg_acc, all_acc = eval_model(nodes, num_nodes, inet, net, criteria, device, split="test")
    logging.info(f"\nStep: {step + 1}, AVG Loss: {avg_loss:.4f},  AVG Acc: {avg_acc:.4f}")

    results['test_avg_loss'].append(avg_loss)
    results['test_avg_acc'].append(avg_acc)

    if best_acc < val_avg_acc:
        best_acc = val_avg_acc
        best_step = step
        test_best_based_on_step = avg_acc
        test_best_min_based_on_step = np.min(all_acc)
        test_best_max_based_on_step = np.max(all_acc)
        test_best_std_based_on_step = np.std(all_acc)

    results['val_avg_loss'].append(val_avg_loss)
    results['val_avg_acc'].append(val_avg_acc)
    results['best_step'].append(best_step)
    results['best_val_acc'].append(best_acc)
    results['best_test_acc_based_on_val_beststep'].append(test_best_based_on_step)
    results['test_best_min_based_on_step'].append(test_best_min_based_on_step)
    results['test_best_max_based_on_step'].append(test_best_max_based_on_step)
    results['test_best_std_based_on_step'].append(test_best_std_based_on_step)
