In [None]:
!git clone http://github.com/wizard1203/FuseFL

In [3]:
# Clone the FedOV repository into the working directory, next to FuseFL.
!git clone https://github.com/Xtra-Computing/FedOV

Cloning into 'FedOV'...
remote: Enumerating objects: 175, done.[K
remote: Counting objects: 100% (158/158), done.[K
remote: Compressing objects: 100% (68/68), done.[K
remote: Total 175 (delta 90), reused 150 (delta 86), pack-reused 17 (from 1)[K
Receiving objects: 100% (175/175), 95.03 KiB | 5.94 MiB/s, done.
Resolving deltas: 100% (92/92), done.


In [4]:
# Sanity check: both cloned repositories should be listed in the working directory.
!ls

FedOV  FuseFL


In [5]:
%%writefile FuseFL/alg_train.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import argparse
import copy
import os
import shutil
import sys
import warnings
import torchvision.models as models
import numpy as np
from tqdm import tqdm
import pdb
import logging
import time


from helpers.datasets import partition_data, load_data, get_image_size, get_num_of_labels
from helpers.utils import get_dataset, average_weights, DatasetSplit, BackdoorDS, KLDiv, setup_seed, test, progressive_test
from helpers.exp_path import ExpTool


from models.generator import Generator
from models.nets import CNNCifar, CNNMnist, CNNCifar100
from models.pnn import PNN
from models.pnn_cnn import PNN_CNN, pnn_resnet18, pnn_resnet50

from models.fl_pnn import Federated_PNN
from models.fl_pnn_cnn import Federated_PNN_CNN, fl_pnn_resnet18, fl_pnn_resnet50
from models.mlp import MLP
from models.fl_exnn import (MLP_Block, CNN_Block,
    merge_layer, Federated_EXNN, Federated_EXNNLayer_global, Federated_EXNNLayer_local,
    fl_exnn_resnet18, fl_exnn_resnet50, 
)
from models.seq_model import Sequential_SplitNN, ReconMIEstimator


import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F

from models.resnet import resnet18, resnet50, get_res18_out_channels
from models.vit import deit_tiny_patch16_224
import wandb
from models.configs import Split_Configs, EXNN_Split_Configs

warnings.filterwarnings('ignore')
upsample = torch.nn.Upsample(mode='nearest', scale_factor=7)

from locals.fedavg import LocalUpdate
from locals.fl_progressive import FedPnnLocalUpdate
from locals.progressive import PnnLocalUpdate
from locals.fl_expandable import FedEXNNLocalUpdate
from locals.ccvr import (compute_classes_mean_cov, generate_virtual_representation,
    calibrate_classifier, get_means_covs_from_client)

from utils import seq_map_values, batch, accuracy, show_model_layers



def obtain_projection_head(before_cls_feature_num, contrastive_projection_dim):
    """Build the bias-free, two-layer MLP projection head used for contrastive training.

    Args:
        before_cls_feature_num: Width of the input features; also the width
            of the hidden layer.
        contrastive_projection_dim: Width of the projected output.

    Returns:
        An ``nn.Sequential`` of ``Linear -> ReLU -> Linear`` in which both
        linear layers are created with ``bias=False``.
    """
    # The hidden layer is created before the output layer so the parameter
    # initialisation order is the same as building them inline.
    hidden_layer = nn.Linear(before_cls_feature_num, before_cls_feature_num, bias=False)
    output_layer = nn.Linear(before_cls_feature_num, contrastive_projection_dim, bias=False)
    return nn.Sequential(hidden_layer, nn.ReLU(), output_layer)


class Ensemble(torch.nn.Module):
    """Ensemble that averages the outputs (logits) of several models.

    Args:
        model_list: Sequence of callables/modules that all accept the same
            input and return tensors of the same shape.

    NOTE(review): ``model_list`` is stored as a plain Python list, so the
    members are not registered as sub-modules. ``train()``/``eval()``,
    ``parameters()`` and ``state_dict()`` called on the ensemble therefore do
    not reach the members; only the explicit ``to`` below does.
    """

    def __init__(self, model_list):
        super(Ensemble, self).__init__()
        self.models = model_list

    def to(self, device):
        """Move every member model to ``device``.

        Returns:
            The ensemble itself, matching the ``nn.Module.to`` contract so
            that ``model = Ensemble(...).to(device)`` keeps working.
        """
        for model in self.models:
            model.to(device)
        return self

    def forward(self, x):
        """Return the element-wise mean of all member outputs for ``x``.

        Raises:
            ZeroDivisionError: If the ensemble holds no models.
        """
        logits_total = 0
        for model in self.models:
            logits_total += model(x)
        return logits_total / len(self.models)


def pretrain(args, device, logger, train_dataset, test_dataset, 
            train_user_groups, train_data_cls_counts, 
            test_user_groups, test_data_cls_counts,
            global_test_loader, global_model, out_channels):
    """Train all clients independently, then evaluate the averaged and ensembled models.

    One ``LocalUpdate`` is created per client from a deep copy of
    ``global_model``. The outer loop runs ``args.local_ep`` passes; in each
    pass every client's ``update_weights`` is called with an epoch argument of
    1 (presumably one local epoch -- the other trainers in this file receive
    ``args.local_ep`` in that position), the returned weights are averaged
    into ``global_model``, and an ``Ensemble`` of per-client models is rebuilt.
    Metrics are recorded through ``ExpTool`` after every pass.

    Args:
        args: Experiment configuration; reads ``num_users``, ``local_ep``,
            ``backdoor_train``, ``backdoor_size``, ``backdoor_n_clients``,
            ``contrastive_train``, ``contrastive_projection_dim`` and
            ``checkpoint``.
        device: Device passed to training and evaluation helpers.
        logger: Logger used for progress messages.
        train_dataset, test_dataset: Datasets handed to each ``LocalUpdate``.
        train_user_groups, test_user_groups: Per-client entries indexed by
            client id and handed to ``LocalUpdate``.
        train_data_cls_counts, test_data_cls_counts: Unused here.
        global_test_loader: Loader used for global/ensemble evaluation.
        global_model: Model that is copied per client and receives the
            averaged weights.
        out_channels: Sequence whose last entry sizes the contrastive
            projection head.

    Returns:
        Tuple ``(global_model, global_weights, local_weights, model_list)``
        from the final pass.

    Raises:
        ValueError: If ``args.local_ep`` is smaller than 1 (no pass would run
            and there would be nothing to return).
    """
    if args.local_ep < 1:
        raise ValueError("pretrain requires args.local_ep >= 1, got {}".format(args.local_ep))

    users = []
    client_updates = []  # one LocalUpdate per client

    before_cls_feature_num = out_channels[-1]
    backdoor_test_loader = None
    if args.backdoor_train: 
        backdoor_test_loader = DataLoader(BackdoorDS(test_dataset, args.backdoor_size, mode="random"),
                                    batch_size=256, shuffle=False, num_workers=4)

    # ===============================================
    for idx in range(args.num_users):
        logger.info("client {}".format(idx))
        users.append("client_{}".format(idx))
        # The first `backdoor_n_clients` clients are built with backdoor_train=True.
        if args.backdoor_train and idx < args.backdoor_n_clients:
            local_update = LocalUpdate(args, train_dataset, test_dataset, global_test_loader,
                train_user_groups[idx], test_user_groups[idx], copy.deepcopy(global_model), backdoor_train=True)
        else:
            local_update = LocalUpdate(args, train_dataset, test_dataset, global_test_loader,
                train_user_groups[idx], test_user_groups[idx], copy.deepcopy(global_model))
        client_updates.append(local_update)
        if args.contrastive_train:
            # We use a MLP with one hidden layer to obtain z_i = g(h_i) = W(2)σ(W(1)h_i) where σ is a ReLU non-linearity.
            projector = obtain_projection_head(before_cls_feature_num, args.contrastive_projection_dim)
            local_update.add_CL_head(projector)

    train_time = 0
    total_epoch = 0
    # Evaluation runs after every pass. (An earlier "every 10th epoch" schedule
    # was unconditionally overridden, so it has been dropped as dead code.)
    if_test = True
    for _ in range(args.local_ep):
        start_time = time.time()
        local_weights = []
        train_losses = []
        acc_list = []
        pfl_acc_list = []
        training_pfl_acc_list = []
        for idx in range(args.num_users):
            # not load global model, for one-shot communication...
            w, avg_train_loss, global_acc, pfl_acc, train_pfl_acc = client_updates[idx].update_weights(idx, 1, device, if_test=if_test)
            acc_list.append(global_acc)
            train_losses.append(avg_train_loss)
            pfl_acc_list.append(pfl_acc)
            training_pfl_acc_list.append(train_pfl_acc)
            # Weights are stored as returned (no deep copy).
            local_weights.append(w)

        # Each pass is called with an epoch argument of 1, so the running
        # counter advances by one (it previously advanced by args.local_ep per pass).
        total_epoch += 1

        avg_train_loss = np.mean(train_losses)
        train_time += time.time() - start_time

        global_weights = average_weights(local_weights)
        global_model.load_state_dict(global_weights)
        # Rebuild one standalone model per client for the ensemble.
        model_list = []
        for client_weights in local_weights:
            net = copy.deepcopy(global_model)
            net.load_state_dict(client_weights)
            model_list.append(net)
        ensemble_model = Ensemble(model_list)
        if if_test:
            result_dict = {}
            for idx in range(args.num_users):
                result_dict["client_{}_acc".format(users[idx])] = acc_list[idx]
                result_dict["pfl_acc_on_{}".format(users[idx])] = pfl_acc_list[idx]
                result_dict["pfl_training_acc_on_{}".format(users[idx])] = training_pfl_acc_list[idx]

            ExpTool.record(result_dict)
            test_acc, test_loss = test(global_model, global_test_loader, device)
            logger.info(f"avg acc: {test_acc}")

            ensemble_acc, ensemble_loss = test(ensemble_model, global_test_loader, device)
            if args.backdoor_train:
                ensemble_backdoor_acc, ensemble_backdoor_loss = test(ensemble_model, backdoor_test_loader, device)
                logger.info(f"ensemble_backdoor_acc: {ensemble_backdoor_acc}")
                ExpTool.record({"ensemble_backdoor_acc": ensemble_backdoor_acc,
                                "ensemble_backdoor_loss": ensemble_backdoor_loss})

            logger.info(f"ensemble acc: {ensemble_acc}")
            ExpTool.record({"comm_round": 0, "local_epoch": total_epoch, "train_loss": avg_train_loss,
                            "test_acc": test_acc, "ensemble_acc": ensemble_acc, "train_time": train_time})
            ExpTool.upload()

    # Total number of entries across all clients' weight dicts from the final pass.
    count_para = 0
    for local_weight in local_weights:
        for value in local_weight.values():
            count_para += value.numel()
    summary_dict = {"count_paras": count_para}
    logger.info(f"summary_dict: {summary_dict}")
    ExpTool.summary(summary_dict)
    # ===============================================
    if not args.checkpoint == "no":
        ExpTool.save_pickle(local_weights, args.checkpoint, exp_dir=True)
    return global_model, global_weights, local_weights, model_list






def progressive(args, device, logger, train_dataset, test_dataset, 
            train_user_groups, train_data_cls_counts, 
            test_user_groups, test_data_cls_counts,
            global_test_loader, global_model, out_channels):
    """Train one shared model client-by-client with ``PnnLocalUpdate`` and evaluate it.

    Every client trains ``global_model`` in turn for ``args.local_ep`` epochs.
    The model is then evaluated with ``progressive_test`` once per client id on
    that client's ``global_test_loader``; the overall accuracy is the pooled
    number of correct predictions over the pooled dataset sizes. Results are
    recorded through ``ExpTool`` and, unless ``args.checkpoint == "no"``, the
    CPU state dict is pickled. Nothing is returned.

    ``train_data_cls_counts``, ``test_data_cls_counts`` and ``out_channels``
    are accepted for signature parity but not used.
    """
    client_names = []
    trainers = []

    # ===============================================
    for client_id in range(args.num_users):
        logger.info("client {}".format(client_id))
        client_names.append("client_{}".format(client_id))
        trainers.append(PnnLocalUpdate(args, train_dataset, test_dataset, global_test_loader,
                            train_user_groups[client_id], test_user_groups[client_id]))

    global_model.train()
    global_model.to(device)

    # Sequential training: there are no per-client weights, the single model is
    # passed from client to client.
    train_accs = []
    losses = []
    train_time = 0
    for client_id, trainer in enumerate(trainers):
        tic = time.time()
        _, client_loss, _, client_train_acc = trainer.update_weights(
            client_id, args.local_ep, global_model, device, if_test=True)
        train_accs.append(client_train_acc)
        losses.append(client_loss)
        train_time += time.time() - tic

    avg_train_loss = np.mean(losses)

    # Evaluation: no weight averaging is involved.
    correct_total = 0
    sample_total = 0
    pfl_accs = []
    for client_id, trainer in enumerate(trainers):
        loader = trainer.global_test_loader
        sample_total += len(loader.dataset)
        client_acc, _, client_correct = progressive_test(global_model, loader, client_id, device)
        pfl_accs.append(client_acc)
        correct_total += client_correct
    test_acc = 100. * correct_total / sample_total

    result_dict = {}
    for name, client_acc, client_train_acc in zip(client_names, pfl_accs, train_accs):
        result_dict["pfl_acc_on_{}".format(name)] = client_acc
        result_dict["pfl_training_acc_on_{}".format(name)] = client_train_acc

    logger.info(f"pfl_accs: {pfl_accs}")
    logger.info(f"training_pfl_acc_list:{train_accs}")
    logger.info(f"test_acc:{test_acc}")

    ExpTool.record(result_dict)
    logger.info("avg acc:")
    ExpTool.record({"comm_round": 0, "local_epoch": args.local_ep, "train_loss": avg_train_loss,
                    "test_acc": test_acc, "train_time": train_time})
    ExpTool.upload()

    count_para = sum(param.numel() for _, param in global_model.named_parameters())
    summary_dict = {"count_paras": count_para}
    logger.info(f"summary_dict: {summary_dict}")
    ExpTool.summary(summary_dict)

    # ===============================================
    if args.checkpoint != "no":
        ExpTool.save_pickle(global_model.cpu().state_dict(), args.checkpoint, exp_dir=True)


def fed_progressive(args, device, logger, train_dataset, test_dataset, 
            train_user_groups, train_data_cls_counts, 
            test_user_groups, test_data_cls_counts,
            global_test_loader, global_model, out_channels):
    """Train one shared model client-by-client with ``FedPnnLocalUpdate`` and evaluate it.

    Every client trains ``global_model`` in turn for ``args.local_ep`` epochs.
    The model is then evaluated with ``test`` on ``global_test_loader`` and on
    each client's own ``test_loader``. Results are recorded through
    ``ExpTool`` and, unless ``args.checkpoint == "no"``, the CPU state dict is
    pickled. Nothing is returned.

    ``train_data_cls_counts``, ``test_data_cls_counts`` and ``out_channels``
    are accepted for signature parity but not used.
    """
    client_names = []
    trainers = []
    for client_id in range(args.num_users):
        logger.info("client {}".format(client_id))
        client_names.append("client_{}".format(client_id))
        trainers.append(FedPnnLocalUpdate(args, train_dataset, test_dataset, global_test_loader,
                            train_user_groups[client_id], test_user_groups[client_id]))

    global_model.train()
    global_model.to(device)

    # Sequential training: the single model is passed from client to client.
    train_accs = []
    losses = []
    train_time = 0
    for client_id, trainer in enumerate(trainers):
        tic = time.time()
        _, client_loss, _, client_train_acc = trainer.update_weights(
            client_id, args.local_ep, global_model, device, if_test=True)
        train_accs.append(client_train_acc)
        losses.append(client_loss)
        train_time += time.time() - tic

    avg_train_loss = np.mean(losses)

    # Evaluation: no weight averaging is involved.
    logger.info("avg acc:")
    test_acc, _ = test(global_model, global_test_loader, device)

    pfl_accs = []
    result_dict = {}
    for name, trainer, client_train_acc in zip(client_names, trainers, train_accs):
        result_dict["pfl_training_acc_on_{}".format(name)] = client_train_acc
        client_acc, _ = test(global_model, trainer.test_loader, device)
        result_dict["pfl_acc_on_{}".format(name)] = client_acc
        pfl_accs.append(client_acc)

    logger.info(f"pfl_accs: {pfl_accs}")
    logger.info(f"training_pfl_acc_list:{train_accs}")
    logger.info(f"test_acc:{test_acc}")

    ExpTool.record(result_dict)
    logger.info("avg acc:")
    ExpTool.record({"comm_round": 0, "local_epoch": args.local_ep, "train_loss": avg_train_loss,
                    "test_acc": test_acc, "train_time": train_time})
    ExpTool.upload()

    count_para = sum(param.numel() for _, param in global_model.named_parameters())
    summary_dict = {"count_paras": count_para}
    logger.info(f"summary_dict: {summary_dict}")
    ExpTool.summary(summary_dict)

    # ===============================================
    if args.checkpoint != "no":
        ExpTool.save_pickle(global_model.cpu().state_dict(), args.checkpoint, exp_dir=True)



def init_fedexnn_merged(args, split_modules, out_channels):
    """Build untrained per-client ``Federated_EXNN`` models and merge all their split layers.

    No training happens here: each client gets its own deep copies of
    ``split_modules`` wrapped in ``Federated_EXNNLayer_local``; then, for every
    split index, the clients' layers are merged with ``merge_layer``, frozen,
    and installed into every client model (adding a local adaptor for the
    following layer while the index is inside the split config).

    Args:
        args: Reads ``model``, ``fedexnn_split_num``, ``dataset``,
            ``num_users``, ``fedexnn_adapter``, ``fedexnn_self_dropout`` and
            ``fedexnn_classifer``.
        split_modules: Iterable of layers; deep-copied for each client.
        out_channels: Indexable by the entries of the split config; gives the
            adaptor's output width.

    Returns:
        The model of client 0 after all merges.
    """
    split_config = EXNN_Split_Configs[args.model][args.fedexnn_split_num]
    num_of_classes = get_num_of_labels(args.dataset)

    client_models = {}
    for client_idx in range(args.num_users):
        client_layers = [
            Federated_EXNNLayer_local(layer_idx=position,
                local_layer=copy.deepcopy(module),
                client_idx=client_idx,
                adapter=args.fedexnn_adapter,
                fedexnn_self_dropout=args.fedexnn_self_dropout)
            for position, module in enumerate(split_modules)
        ]
        client_models[client_idx] = Federated_EXNN(
            args,
            client_idx,
            split_local_layers=client_layers,
            num_of_classes=num_of_classes,
            fedexnn_classifer=args.fedexnn_classifer)

    for layer_idx in range(args.fedexnn_split_num):
        merged_layer = merge_layer(client_models, layer_idx)
        merged_layer.freeze()
        needs_adaptor = layer_idx < len(split_config)
        for client_idx in range(args.num_users):
            client_model = client_models[client_idx]
            client_model.adaptation(layer_idx, merged_layer)
            if not needs_adaptor:
                continue
            # The adaptor takes num_users times the width it outputs.
            width = out_channels[split_config[layer_idx]]
            client_model.add_local_layer_adaptor(layer_idx+1,
                in_channels=width*args.num_users,
                out_channels=width)
    return client_models[0]



def fed_expandable(args, device, logger, train_dataset, test_dataset, 
            train_user_groups, train_data_cls_counts, 
            test_user_groups, test_data_cls_counts,
            global_test_loader, split_modules, out_channels):
    """Train per-client ``Federated_EXNN`` models and fuse them layer by layer.

    Steps, in order:
      1. Build one ``Federated_EXNN`` and one ``FedEXNNLocalUpdate`` per client.
      2. For each of ``args.fedexnn_split_num`` rounds: train every client for
         ``args.local_ep`` epochs, merge the round's layer across clients with
         ``merge_layer``, freeze it and install it into every client model.
      3. Fuse the classifiers (``"avg"`` or ``"multihead"``); for ``"avg"``
         the classifier is additionally calibrated with ``calibrate_classifier``.
      4. Evaluate client 0's fused model, record metrics via ``ExpTool`` and
         optionally pickle its state dict.

    Args:
        args: Experiment configuration (reads the ``fedexnn_*``, ``backdoor_*``,
            ``contrastive_*``, ``debug_show_exnn_id``, ``num_users``,
            ``local_ep``, ``checkpoint``, ``num_classes``, ``sample_per_class``,
            ``lr``, ``local_bs``, ``dataset``, ``datadir`` and ``model`` fields).
        device: Device used for training and evaluation.
        logger: Logger for progress messages.
        train_dataset, test_dataset: Datasets handed to each local updater.
        train_user_groups, test_user_groups: Per-client entries indexed by
            client id.
        train_data_cls_counts, test_data_cls_counts: Unused here.
        global_test_loader: Loader used for the global evaluation.
        split_modules: Iterable of layers; deep-copied for each client.
        out_channels: Indexable widths; the last entry sizes the contrastive
            projection head.

    Returns:
        The fused global model (client 0's model after all merges).

    Raises:
        NotImplementedError: If ``args.fedexnn_classifer`` is neither ``"avg"``
            nor ``"multihead"``.
    """
    users = []
    client_updates = []  # one FedEXNNLocalUpdate per client
    split_config = EXNN_Split_Configs[args.model][args.fedexnn_split_num]
    # One output more than the dataset's label count.
    # NOTE(review): presumably an extra "unknown" class for open-set voting -- confirm.
    # init_fedexnn_merged builds its models without this +1.
    num_of_classes = get_num_of_labels(args.dataset) + 1 #over here

    before_cls_feature_num = out_channels[-1]
    backdoor_test_loader = None
    if args.backdoor_train: 
        backdoor_test_loader = DataLoader(BackdoorDS(test_dataset, args.backdoor_size, mode="random"),
                                    batch_size=256, shuffle=False, num_workers=4)

    local_FedEXNN_models = {}
    for idx in range(args.num_users):
        logger.info("client {}".format(idx))
        users.append("client_{}".format(idx))
        # Every client gets its own deep copies of the split layers.
        split_local_layers = []
        for layer_idx, layer in enumerate(split_modules):
            EXNNLayer_local = Federated_EXNNLayer_local(layer_idx=layer_idx,
                local_layer=copy.deepcopy(layer),
                client_idx=idx,
                adapter=args.fedexnn_adapter,
                fedexnn_self_dropout=args.fedexnn_self_dropout)
            split_local_layers.append(EXNNLayer_local)
        init_model = Federated_EXNN(
            args,
            idx,
            split_local_layers=split_local_layers,
            num_of_classes=num_of_classes,
            fedexnn_classifer=args.fedexnn_classifer)
        if args.debug_show_exnn_id:
            logging.info(f"==========Checking local layer IDs ================")
            logging.info(f"==========Client:{idx}, split_local_layers :{id(split_local_layers)} ================")
            for layer_idx, layer in enumerate(split_local_layers):
                logging.info(f"==========Client:{idx}, layer_idx{layer_idx} :{id(layer)} ================")
            logging.info(f"==========Client:{idx}, init_model :{id(init_model)} ================")

        local_FedEXNN_models[idx] = init_model
        # The first `backdoor_n_clients` clients are built with backdoor_train=True.
        if args.backdoor_train and idx < args.backdoor_n_clients:
            local_update = FedEXNNLocalUpdate(args, train_dataset, test_dataset, global_test_loader,
                                train_user_groups[idx], test_user_groups[idx], backdoor_train=True)
        else:
            local_update = FedEXNNLocalUpdate(args, train_dataset, test_dataset, global_test_loader,
                                train_user_groups[idx], test_user_groups[idx])
        client_updates.append(local_update)
        if args.contrastive_train:
            # We use a MLP with one hidden layer to obtain z_i = g(h_i) = W(2)σ(W(1)h_i) where σ is a ReLU non-linearity.
            projector = obtain_projection_head(before_cls_feature_num, args.contrastive_projection_dim)
            local_update.add_CL_head(projector)


    # Train and fuse split layers
    # Parameter count of the last client's freshly built model, for logging.
    count_para = 0
    for key, value in init_model.named_parameters():
        count_para += value.numel()
    logger.info(f"init_model has count_para: {count_para}")

    for idx in range(args.fedexnn_split_num):
        pfl_acc_list = []
        training_pfl_acc_list = []
        train_losses = []
        result_dict = {}

        for client_idx in range(args.num_users):
            # not load global model, for one-shot communication...
            _, train_loss, _, pfl_acc, train_pfl_acc = client_updates[client_idx].update_weights(
                        client_idx, args.local_ep, local_FedEXNN_models[client_idx], device, if_test=True)
            pfl_acc_list.append(pfl_acc)
            training_pfl_acc_list.append(train_pfl_acc)
            train_losses.append(train_loss)
            result_dict["pfl_acc_on_{}".format(users[client_idx])] = pfl_acc
            result_dict["pfl_training_acc_on_{}".format(users[client_idx])] = training_pfl_acc_list[client_idx]
        avg_train_loss = np.mean(train_losses)
        logger.info(f"test_pfl_acc_list:{pfl_acc_list}")
        logger.info(f"training_pfl_acc_list:{training_pfl_acc_list}")
        # The layer merged in this round has the same index as the round.
        layer_idx = idx
        logger.info(f"=====Merging Layer : {layer_idx} =====")
        federated_EXNNLayer_global = merge_layer(local_FedEXNN_models, layer_idx)
        federated_EXNNLayer_global.freeze()
        for client_idx in range(args.num_users):
            local_FedEXNN_models[client_idx].adaptation(
                layer_idx, federated_EXNNLayer_global)
            if args.debug_show_exnn_id:
                logging.info(f"==========Checking global layer IDs ================")
                model = local_FedEXNN_models[client_idx]
                logging.info(f"==========Client:{client_idx}, local_FedEXNN_models.layers[{layer_idx}] : \
                            \n ========== {id(model.layers[layer_idx])} ================")
                for sub_client_idx, local_layer, in federated_EXNNLayer_global.local_layers.items():
                    if hasattr(local_layer, "adapter_nn"):
                        logging.info(f"==========In global layer Client:{sub_client_idx},  \
                            \n ================In global layer  local_FedEXNN_models.layers[{layer_idx}].adapter_nn: {id(local_layer.adapter_nn)}")
            if layer_idx < len(split_config):
                # The next layer's adaptor takes num_users times the width it outputs.
                actual_layer_index = split_config[layer_idx]
                local_FedEXNN_models[client_idx].add_local_layer_adaptor(layer_idx+1,
                    in_channels=out_channels[actual_layer_index]*args.num_users,
                    out_channels=out_channels[actual_layer_index])
                if args.debug_show_exnn_id:
                    if hasattr(model.layers[layer_idx+1], "adapter_nn"):
                        logging.info(f"==========Client:{client_idx},  \
                            \n ================local_FedEXNN_models.layers[{layer_idx+1}].adapter_nn: {id(model.layers[layer_idx+1].adapter_nn)}")
            if args.debug_show_exnn_id:
                measure_model = local_FedEXNN_models[client_idx]
                try:
                    if getattr(measure_model.layers[0], "is_global", False):
                        logger.info(f'client_idx: {client_idx} layer0 - model weight: {measure_model.layers[0].local_layers["0"].local_layer._layers[0][0].weight.data.norm()}')
                    if getattr(measure_model.layers[1], "is_global", False):
                        logger.info(f'client_idx: {client_idx} layer1 - model (is_global)  has attr adapter_nn : {hasattr(measure_model.layers[1].local_layers[str(client_idx)], "adapter_nn")}')
                        logger.info(f'client_idx: {client_idx} layer1 - model (is_global)  weight: {measure_model.layers[1].local_layers[str(client_idx)].adapter_nn.weight.data.norm()}')
                    else:
                        logger.info(f'client_idx: {client_idx} layer1 - model (isnot_global) has attr adapter_nn : {hasattr(measure_model.layers[1], "adapter_nn")}')
                        if hasattr(measure_model.layers[1], "adapter_nn"):
                            logger.info(f'client_idx: {client_idx} layer1 - model (isnot_global) weight: {measure_model.layers[1].adapter_nn.weight.data.norm()}')
                    if not getattr(measure_model.layers[2], "is_global", False):
                        logger.info(f'client_idx: {client_idx} local layer2 - model weight: {measure_model.layers[2].local_layer._layers[1].conv1.weight.data.norm()}')
                        logger.info(f'client_idx: {client_idx} local layer2 - model weight: {measure_model.layers[2].local_layer._layers[1].conv2.weight.data.norm()}')
                except Exception as exc:
                    # Best-effort debug probe: the attribute paths above only
                    # exist for some architectures, so a failure must not abort
                    # training. KeyboardInterrupt/SystemExit are no longer swallowed.
                    logger.debug(f"debug weight-norm probe skipped for client {client_idx}: {exc!r}")
        logger.info(f"=====Finish Merging Layer : {layer_idx} =====")
        ExpTool.record(result_dict)
        logger.info(f"result_dict: {result_dict}")
        ExpTool.record({"comm_round": idx, "local_epoch": args.local_ep, "train_loss": avg_train_loss})
        ExpTool.upload()
    if args.debug_show_exnn_id:
        for client_idx in range(args.num_users):
            logging.info(f"==========Checking global layer IDs ================")
            model = local_FedEXNN_models[client_idx]
            for layer_index, layer in enumerate(model.layers):
                # Report the index of the layer being dumped (previously the
                # stale `layer_idx` left over from the merge loop was printed).
                logging.info(f"==========Client:{client_idx}, local_FedEXNN_models.layers[{layer_index}] : \
                                \n ========== {id(layer)} ================")

    for _, model in local_FedEXNN_models.items():
        model.to("cpu")
    # Client 0's model carries the fused layers from here on.
    global_model = local_FedEXNN_models[0]

    # Train and fuse classifier
    if args.fedexnn_classifer == "avg":
        new_classifier_weights = average_weights([
            local_FedEXNN_model.classifier.cpu().state_dict()  for local_FedEXNN_model in local_FedEXNN_models.values()])
        new_classifier = list(local_FedEXNN_models.values())[0].classifier
        new_classifier.load_state_dict(new_classifier_weights)

    elif args.fedexnn_classifer == "multihead":
        new_classifier = [
            local_FedEXNN_model.classifier  for local_FedEXNN_model in local_FedEXNN_models.values()]
    else:
        raise NotImplementedError

    global_model.adaptation_classifier(fedexnn_classifer=args.fedexnn_classifer, new_classifier=new_classifier)
    global_model.to(device)
    if args.fedexnn_classifer in ["avg"] :
        if args.contrastive_train:
            # Get the normal dataloader without n views.
            image_size = get_image_size(args.dataset)
            _, _, _, _, train_dataset, test_dataset = load_data(
                image_size, args.dataset, args.datadir)
            dataloaders = {}
            for i, local in enumerate(client_updates):
                dataloaders[i] = DataLoader(DatasetSplit(train_dataset, train_user_groups[i]),
                        batch_size=args.local_bs, shuffle=True, num_workers=4, drop_last=False)
        else:
            dataloaders = {}
            for i, local in enumerate(client_updates):
                dataloaders[i] = local.train_loader
        # NOTE(review): calibration uses args.num_classes while the models were
        # built with get_num_of_labels(...) + 1 outputs -- confirm these agree.
        calibrate_classifier(
            global_model, None, dataloaders, args.num_classes, args.sample_per_class, args.lr, device)
    elif args.fedexnn_classifer == "multihead":
        pass
    else:
        raise NotImplementedError
    # NOTE(review): the list is reset here, so the "training_pfl_acc_list" log
    # line below always prints an empty list.
    training_pfl_acc_list = []

    # Test global and ensemble model
    # NOTE: global weights need not to be averaged
    logger.info("avg acc:")
    test_acc, test_loss = test(global_model, global_test_loader, device)
    pfl_accs = []

    result_dict = {}
    for idx in range(args.num_users):
        local_test_acc, _ = test(global_model, client_updates[idx].test_loader, device)
        result_dict["pfl_acc_on_{}".format(users[idx])] = local_test_acc
        pfl_accs.append(local_test_acc)
    if args.backdoor_train:
        ensemble_backdoor_acc, ensemble_backdoor_loss = test(global_model, backdoor_test_loader, device)
        logger.info(f"ensemble_backdoor_acc: {ensemble_backdoor_acc}")
        ExpTool.record({"ensemble_backdoor_acc": ensemble_backdoor_acc,
                        "ensemble_backdoor_loss": ensemble_backdoor_loss})

    logger.info(f"pfl_accs: {pfl_accs}")
    logger.info(f"training_pfl_acc_list:{training_pfl_acc_list}")
    logger.info(f"test_acc:{test_acc}")

    ExpTool.record(result_dict)
    logger.info(f"result_dict: {result_dict}")
    # `avg_train_loss` is the value from the last merge round.
    ExpTool.record({"comm_round": args.fedexnn_split_num + 1, "local_epoch": args.local_ep, "train_loss": avg_train_loss, "test_acc": test_acc})
    ExpTool.upload()
    count_para = 0
    for key, value in global_model.named_parameters():
        count_para += value.numel()
    summary_dict = {"count_paras": count_para}
    logger.info(f"global_model's summary_dict: {summary_dict}")
    ExpTool.summary(summary_dict)

    # ===============================================
    if not args.checkpoint == "no":
        ExpTool.save_pickle(global_model.cpu().state_dict(), args.checkpoint, exp_dir=True)
    return global_model


Overwriting FuseFL/alg_train.py


In [6]:
# %%bash

# cluster_name=localhost
# dataset=cifar10

# source scripts/setup_env.sh
# source scripts/path.sh

# gpu=0

# debug=False
# enable_wandb=True



# num_users=5

# alpha=0.5
# checkpoint=weights
# res_base_width=64
# checkpoint=no
# fedexnn_adapter=cnn1x1
# res_base_width=20
# fedexnn_split_num=4
# local_ep=50
# wandb_entity=cabbagepatch-lahore-university-of-management-sciences
# model=resnet18


# type=fed-expandable
# source scripts/resetup_env.sh

# fedexnn_classifer=${fedexnn_classifer:-avg}

# python3 -u main.py --main_task=train --type=$type  --gpu $gpu  --debug $debug \
# --exp_name ${type}-${dataset}-${model}-nh${num_hidden_features}-c${num_users}-a${alpha}-ep${local_ep}-lr${lr}-clsf${fedexnn_classifer}-adp${fedexnn_adapter}-nxnn${fedexnn_split_num} \
# --checkpoint $checkpoint  \
# --split_measure_local_module_num 8 \
# --fedexnn_classifer  ${fedexnn_classifer} --fedexnn_adapter ${fedexnn_adapter}  --fedexnn_split_num ${fedexnn_split_num} \
# --fedexnn_self_dropout $fedexnn_self_dropout --fedexnn_adapter_constrain_beta $fedexnn_adapter_constrain_beta \
# --model=$model --mlp_hidden_features=$mlp_hidden_features --cnn_hidden_features $cnn_hidden_features --num_layers $num_layers --res_base_width $res_base_width \
# --iid=0 --lr=$lr \
# --dataset=${dataset} --datadir $datadir \
# --alpha=$alpha --seed=1 --num_users=${num_users} --local_ep=$local_ep \
# --wandb_entity ${wandb_entity} --project_name FuseFL --enable_wandb $enable_wandb --wandb_offline False \
# --wandb_key "${WANDB_API_KEY}"  # supply the key via the environment; never commit it (revoke the key previously pasted here)





In [7]:
%%writefile FuseFL/helpers/datasets.py
import os
import logging
import numpy as np
import torch
import torchvision.transforms as transforms
from torchvision import datasets, transforms
import random

from .cl_dataset import _get_simclr_pipeline_transform, _get_cl_transform


def get_image_size(dataset):
    """Return the image size value (28, 32 or 64) registered for a dataset name.

    Args:
        dataset: One of ``"mnist"``, ``"fmnist"``, ``"SVHN"``, ``"cifar10"``,
            ``"cifar100"`` or ``"Tiny-ImageNet-200"`` (case-sensitive).

    Raises:
        KeyError: If ``dataset`` is not one of the names above.
    """
    names_by_size = (
        (28, ("mnist", "fmnist")),
        (32, ("SVHN", "cifar10", "cifar100")),
        (64, ("Tiny-ImageNet-200",)),
    )
    size_lookup = {name: size for size, names in names_by_size for name in names}
    return size_lookup[dataset]

def get_num_of_labels(dataset):
    """Return the number of labels registered for a dataset name.

    Args:
        dataset: One of ``"mnist"``, ``"fmnist"``, ``"SVHN"``, ``"cifar10"``,
            ``"cifar100"`` or ``"Tiny-ImageNet-200"`` (case-sensitive).

    Raises:
        KeyError: If ``dataset`` is not one of the names above.
    """
    label_counts = dict.fromkeys(("mnist", "fmnist", "SVHN", "cifar10"), 10)
    label_counts["cifar100"] = 100
    label_counts["Tiny-ImageNet-200"] = 200
    return label_counts[dataset]


def load_data(image_size, dataset, datadir, contrastive_train=False, contrastive_n_views=2, **kwargs):
    """Build the torchvision train/test datasets for ``dataset`` and expose their raw arrays.

    Args:
        image_size: forwarded to ``_get_cl_transform`` as ``size`` when
            ``contrastive_train`` is truthy; not used otherwise.
        dataset: one of "mnist", "fmnist", "SVHN", "cifar10", "cifar100",
            "Tiny-ImageNet-200".
        datadir: root directory handed to the torchvision dataset classes.
            NOTE(review): the Tiny-ImageNet-200 branch ignores it and always
            reads from "data/tiny-imagenet-200/" -- confirm that is intended.
        contrastive_train: when truthy, the training split uses the transform
            returned by ``_get_cl_transform`` instead of the default one.
        contrastive_n_views: forwarded to ``_get_cl_transform`` as ``n_views``.
        **kwargs: accepted so callers can pass extra options through; ignored.

    Returns:
        ``(X_train, y_train, X_test, y_test, train_dataset, test_dataset)``
        where the ``X_*`` / ``y_*`` entries are numpy arrays. For
        Tiny-ImageNet-200 the ``X_*`` arrays hold the image file paths taken
        from ``ImageFolder.samples`` rather than pixel data.

    Raises:
        NotImplementedError: if ``dataset`` is not one of the supported names.
    """
    data_dir = datadir
    if contrastive_train:
        contrastive_transform = _get_cl_transform(size=image_size, n_views=contrastive_n_views)

    # The transforms are built through small factories (instead of eagerly) so
    # that an unsupported dataset name fails before anything is constructed.
    def plain_transform():
        return transforms.Compose([transforms.ToTensor()])

    def augmented_transform():
        return transforms.Compose(
            [
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
            ])

    def train_transform(default_factory):
        # ``contrastive_transform`` only exists when contrastive_train is set,
        # and is only read in that case.
        return contrastive_transform if contrastive_train else default_factory()

    if dataset == "mnist":
        train_dataset = datasets.MNIST(data_dir, train=True,
                                       transform=train_transform(plain_transform))
        test_dataset = datasets.MNIST(data_dir, train=False,
                                      transform=plain_transform())
    elif dataset == "fmnist":
        train_dataset = datasets.FashionMNIST(data_dir, train=True,
                                              transform=train_transform(plain_transform))
        test_dataset = datasets.FashionMNIST(data_dir, train=False,
                                             transform=plain_transform())
    elif dataset == "SVHN":
        train_dataset = datasets.SVHN(data_dir, split="train",
                                      transform=train_transform(plain_transform))
        test_dataset = datasets.SVHN(data_dir, split="test",
                                     transform=plain_transform())
    elif dataset == "cifar10":
        # Only the CIFAR-10 training split requests a download; the test split
        # is created afterwards from the same directory.
        train_dataset = datasets.CIFAR10(data_dir, train=True, download=True,
                                         transform=train_transform(augmented_transform))
        test_dataset = datasets.CIFAR10(data_dir, train=False,
                                        transform=plain_transform())
    elif dataset == "cifar100":
        train_dataset = datasets.CIFAR100(data_dir, train=True,
                                          transform=train_transform(augmented_transform))
        test_dataset = datasets.CIFAR100(data_dir, train=False,
                                         transform=plain_transform())
    elif dataset == "Tiny-ImageNet-200":
        data_transforms = {
            'train': transforms.Compose([
                transforms.RandomRotation(20),
                transforms.RandomHorizontalFlip(0.5),
                transforms.ToTensor(),
            ]),
            'val': plain_transform(),
            'test': plain_transform(),
        }
        if contrastive_train:
            data_transforms["train"] = contrastive_transform
        data_dir = "data/tiny-imagenet-200/"
        image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
                          for x in ['train', 'val', 'test']}
        train_dataset = image_datasets['train']
        # The 'val' folder is what gets used as the evaluation split.
        test_dataset = image_datasets['val']
    else:
        raise NotImplementedError(f"Unsupported dataset: {dataset!r}")

    if dataset == "Tiny-ImageNet-200":
        # ImageFolder has no ``.data`` attribute; it keeps (path, class_index)
        # pairs in ``.samples`` and the class indices in ``.targets``. Reading
        # ``.data`` here used to raise AttributeError for this dataset.
        X_train = np.array([path for path, _ in train_dataset.samples])
        y_train = np.array(train_dataset.targets)
        X_test = np.array([path for path, _ in test_dataset.samples])
        y_test = np.array(test_dataset.targets)
    elif dataset == "SVHN":
        # SVHN stores its labels under ``.labels`` instead of ``.targets``.
        X_train, y_train = np.array(train_dataset.data), np.array(train_dataset.labels)
        X_test, y_test = np.array(test_dataset.data), np.array(test_dataset.labels)
    elif "cifar10" in dataset:
        # Substring match on purpose: covers both "cifar10" and "cifar100".
        X_train, y_train = np.array(train_dataset.data), np.array(train_dataset.targets)
        X_test, y_test = np.array(test_dataset.data), np.array(test_dataset.targets)
    else:
        # Remaining datasets (mnist / fmnist) are converted via ``.numpy()``.
        X_train = train_dataset.data.data.numpy()
        y_train = train_dataset.targets.data.numpy()
        X_test = test_dataset.data.data.numpy()
        y_test = test_dataset.targets.data.numpy()

    return X_train, y_train, X_test, y_test, train_dataset, test_dataset




def record_net_data_stats(y_train, net_dataidx_map):
    """Count, per client, how many samples of each label it was assigned.

    Args:
        y_train: numpy array of labels, indexable by the index collections
            stored in ``net_dataidx_map``.
        net_dataidx_map: mapping ``client id -> sample indices into y_train``.

    Returns:
        ``{client id: {label: count}}`` covering every label that occurs in
        ``y_train``. Each count is a one-element numpy array; labels a client
        does not hold are stored as ``np.array([0])`` so that all entries
        share the same shape.
    """
    all_labels = np.unique(y_train, return_counts=False)
    stats_per_client = {}
    for client_id, sample_idx in net_dataidx_map.items():
        present, present_counts = np.unique(y_train[sample_idx], return_counts=True)
        stats_per_client[client_id] = {
            # np.where yields a one-element index here, so the slice below is
            # a shape-(1,) array, matching the np.array([0]) placeholder.
            lab: (present_counts[np.where(present == lab)] if lab in present else np.array([0]))
            for lab in all_labels
        }
    return stats_per_client


def partition_data(image_size, dataset, datadir, partition, alpha=0.4, num_users=5, **kwargs):
    """Load ``dataset`` and split its training and test indices across ``num_users`` clients.

    Args:
        image_size, dataset, datadir, **kwargs: forwarded to ``load_data``.
        partition: "iid" (random equal split) or "dirichlet" (per-label
            Dirichlet proportions).
        alpha: concentration parameter of the Dirichlet distribution; only
            used by the "dirichlet" scheme.
        num_users: number of clients to split the data across.

    Returns:
        ``(train_dataset, test_dataset, net_dataidx_map, train_data_cls_counts,
        test_net_dataidx_map, test_data_cls_counts)``; the two ``*_dataidx_map``
        entries map ``client id -> sample indices`` and the two ``*_cls_counts``
        entries are the matching outputs of ``record_net_data_stats``.

    Raises:
        NotImplementedError: if ``partition`` is not a supported scheme.
    """
    if partition not in ("iid", "dirichlet"):
        # An unknown scheme used to fall through both branches and only fail
        # later with UnboundLocalError on ``net_dataidx_map`` -- after the
        # dataset had already been loaded. Fail fast with a clear message.
        raise NotImplementedError(f"Unsupported partition scheme: {partition!r}")

    n_parties = num_users
    X_train, y_train, X_test, y_test, train_dataset, test_dataset = load_data(
        image_size, dataset, datadir, **kwargs)
    data_size = y_train.shape[0]

    if partition == "iid":
        idxs = np.random.permutation(data_size)
        batch_idxs = np.array_split(idxs, n_parties)
        net_dataidx_map = {i: batch_idxs[i] for i in range(n_parties)}
    else:  # "dirichlet"
        min_size = 0
        min_require_size = 10
        num_labels = np.unique(y_test).shape[0]
        net_dataidx_map = {}

        # Re-draw the whole partition until every client holds at least
        # ``min_require_size`` samples.
        while min_size < min_require_size:
            idx_batch = [[] for _ in range(n_parties)]
            for k in range(num_labels):
                idx_k = np.where(y_train == k)[0]
                np.random.shuffle(idx_k)
                proportions = np.random.dirichlet(np.repeat(alpha, n_parties))
                # Zero the share of any client that already holds at least an
                # even split (data_size / n_parties) of the data.
                proportions = np.array(
                    [p * (len(idx_j) < data_size / n_parties) for p, idx_j in zip(proportions, idx_batch)])
                proportions = proportions / proportions.sum()
                # Turn the shares into split points inside idx_k.
                proportions = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
                idx_batch = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch, np.split(idx_k, proportions))]
                min_size = min(len(idx_j) for idx_j in idx_batch)

        for j in range(n_parties):
            np.random.shuffle(idx_batch[j])
            net_dataidx_map[j] = idx_batch[j]

    train_data_cls_counts = record_net_data_stats(y_train, net_dataidx_map)

    # The test split is derived from each client's training label counts.
    test_net_dataidx_map = generate_personalized_data(X_test, y_test, train_data_cls_counts)
    test_data_cls_counts = record_net_data_stats(y_test, test_net_dataidx_map)

    return train_dataset, test_dataset, net_dataidx_map, train_data_cls_counts, test_net_dataidx_map, test_data_cls_counts


def generate_personalized_data(X_test, y_test, train_data_cls_counts):
    """Split the test indices across clients in proportion to their training label counts.

    For every label ``k`` in ``range(number of distinct test labels)``, the
    shuffled test indices of that label are cut so that each client receives
    a share matching its fraction of the training samples of label ``k``.

    Args:
        X_test: not used by this function.
        y_test: numpy array of test labels.
        train_data_cls_counts: ``{client id: {label: count}}`` as produced by
            ``record_net_data_stats`` for the training split.

    Returns:
        ``{client position: list of test indices}``, keyed ``0..n_clients-1``
        in the iteration order of ``train_data_cls_counts``; each list is
        shuffled in place before being stored.
    """
    n_classes = len(np.unique(y_test))
    n_clients = len(train_data_cls_counts)
    counts = np.array([list(per_client.values()) for per_client in train_data_cls_counts.values()])
    class_totals = np.sum(counts, axis=0, keepdims=False)

    print("num_classes: \n", class_totals)
    buckets = [[] for _ in range(n_clients)]

    for cls in range(n_classes):
        cls_idx = np.where(y_test == cls)[0]
        n_cls = len(cls_idx)
        np.random.shuffle(cls_idx)
        # Each client's share of this label in the training data; cumsum
        # flattens the result, so the cut points are a 1-D array.
        shares = counts[:, cls] * 1.0 / class_totals[cls]
        cut_points = (np.cumsum(shares) * n_cls).astype(int)[:-1]
        for bucket, chunk in zip(buckets, np.split(cls_idx, cut_points)):
            bucket.extend(chunk.tolist())

    test_idx_map = {}
    for client in range(n_clients):
        np.random.shuffle(buckets[client])
        test_idx_map[client] = buckets[client]
    return test_idx_map










Overwriting FuseFL/helpers/datasets.py


In [8]:
!cp FedOV/attack.py FuseFL/locals/
!cp FedOV/cutpaste.py FuseFL/locals/

# RUN FROM HERE

In [9]:
cd FuseFL

/kaggle/working/FuseFL


In [10]:
%%writefile main.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import argparse
import copy
from copy import deepcopy
import os
import math
import shutil
import sys
import warnings
import torchvision.models as models
import numpy as np
from tqdm import tqdm
import pdb
import logging
import time

import torch.nn as nn

sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../../")))
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../")))


from helpers.datasets import partition_data, get_image_size, get_num_of_labels
from helpers.utils import get_dataset, average_weights, DatasetSplit, KLDiv, setup_seed, test, progressive_test
from helpers.exp_path import ExpTool


from models.generator import Generator
from models.nets import (CNNCifar, CNNMnist, CNNCifar100, 
                        make_CNNCifar_seqs, make_CNNCifar_Head_seqs)
from models.pnn import PNN
from models.pnn_cnn import PNN_CNN, pnn_resnet18, pnn_resnet50

from models.fl_pnn import Federated_PNN
from models.fl_pnn_cnn import Federated_PNN_CNN, fl_pnn_resnet18, fl_pnn_resnet50
from models.mlp import MLP, make_MLP_seqs, make_MLP_Head_seqs, mlp2, mlp3
from models.fl_exnn import (MLP_Block, CNN_Block,
    merge_layer, Federated_EXNN, Federated_EXNNLayer_global, Federated_EXNNLayer_local,
    fl_exnn_resnet18, fl_exnn_resnet50, 
)
from models.seq_model import Sequential_SplitNN, ReconMIEstimator, LinearProbes
from models.configs import Split_Configs, EXNN_Split_Configs

import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F

from models.resnet import (resnet18, resnet50, 
            resnet18_layers, resnet50_layers, 
            resnet18_head, resnet50_head, make_ResNetMIEstimator, get_res18_out_channels)
from models.vit import deit_tiny_patch16_224
import wandb

from models.auxiliary_nets import Decoder, AuxClassifier

warnings.filterwarnings('ignore')
upsample = torch.nn.Upsample(mode='nearest', scale_factor=7)

from locals.fedavg import LocalUpdate
from locals.fl_progressive import FedPnnLocalUpdate
from locals.progressive import PnnLocalUpdate
from locals.fl_expandable import FedEXNNLocalUpdate
from locals.ccvr import (compute_classes_mean_cov, generate_virtual_representation,
    calibrate_classifier, get_means_covs_from_client)

from alg_train import Ensemble, pretrain, progressive, fed_progressive, fed_expandable, init_fedexnn_merged
from utils import seq_map_values, batch, accuracy, show_model_layers

from helpers.meter import AverageMeter

from tsne_draw import draw_tsne




def str2bool(v):
    """Convert the strings "true" / "false" (any casing) to booleans.

    Anything else -- including values that are already ``bool``, non-string
    values and other strings -- is returned unchanged rather than rejected.
    """
    if isinstance(v, str):
        lowered = v.lower()
        if lowered == 'true':
            return True
        if lowered == 'false':
            return False
    return v


def logging_config(args, process_id):
    """Reset the root logger to one INFO-level stream handler and log ``args``.

    Args:
        args: logged once at INFO level after the logger is configured.
        process_id: accepted for call-site compatibility; not used.

    Returns:
        The root logger.
    """
    root_logger = logging.getLogger()
    # Remove every handler installed earlier: ``logging.basicConfig`` does
    # nothing when the root logger already has handlers.
    for handler in list(root_logger.handlers):
        root_logger.removeHandler(handler)

    logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')
    # basicConfig has just attached a stream handler with the format above.
    # The extra WARNING-level StreamHandler that used to be added here wrote
    # to the same stream with the same format, so every WARNING-or-higher
    # record was printed twice; it is intentionally no longer created.

    root_logger.info(args)
    return root_logger



def args_parser():
    """Define and parse the command-line options of this script.

    Every parsed value is passed through ``str2bool`` afterwards, so any
    option whose value is the string "true" / "false" (any casing) comes back
    as a real ``bool``; this is how the string-typed flags such as
    ``--debug`` or ``--enable_wandb`` become booleans.

    Returns:
        The ``argparse.Namespace`` produced by ``parser.parse_args()`` with
        the conversion above applied.
    """
    parser = argparse.ArgumentParser()
    # federated arguments (Notation for the arguments followed from paper)
    parser.add_argument('--epochs', type=int, default=10,
                        help="number of rounds of training")
    parser.add_argument('--num_users', type=int, default=5,
                        help="number of users: K")
    parser.add_argument('--frac', type=float, default=1,
                        help='the fraction of clients: C')
    parser.add_argument('--local_ep', type=int, default=100,
                        help="the number of local epochs: E")
    parser.add_argument('--local_bs', type=int, default=128,
                        help="local batch size: B")
    parser.add_argument('--lr', type=float, default=0.01,
                        help='learning rate')
    # Help text used to claim "default: 0.5" although the default is 0.9.
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='SGD momentum (default: 0.9)')
    # other arguments
    parser.add_argument('--dataset', type=str, default='cifar10', help="name \
                        of dataset")
    parser.add_argument('--datadir', type=str, required=False, default="./data/", help="Data directory")
    parser.add_argument('--iid', type=int, default=1,
                        help='Default set to IID. Set to 0 for non-IID.')
    parser.add_argument('--gpu', type=int, required=False, default=0)
    parser.add_argument('--num_classes', type=int, default=10, help='.')
    parser.add_argument('--sample_per_class', type=int, default=5000, help='.')
    parser.add_argument('--num_layers', type=int, default=2, help='.')
    parser.add_argument('--mlp_hidden_features', type=int, default=100, help='.')
    parser.add_argument('--cnn_hidden_features', type=int, default=128, help='.')
    parser.add_argument('--res_base_width', type=int, default=64, help='.')
    parser.add_argument('--res_group_norm', type=int, default=0, help='.')


    # Data Free
    parser.add_argument('--adv', default=0, type=float, help='scaling factor for adv loss')

    parser.add_argument('--bn', default=0, type=float, help='scaling factor for BN regularization')
    parser.add_argument('--oh', default=0, type=float, help='scaling factor for one hot loss (cross entropy)')
    parser.add_argument('--act', default=0, type=float, help='scaling factor for activation loss used in DAFL')
    parser.add_argument('--save_dir', default='run/synthesis', type=str)
    parser.add_argument('--partition', default='dirichlet', type=str)
    parser.add_argument('--alpha', default=0.5, type=float,
                        help=' If alpha is set to a smaller value, '
                            'then the partition is more unbalanced')

    # Basic
    parser.add_argument('--lr_g', default=1e-3, type=float,
                        help='initial learning rate for generation')
    parser.add_argument('--T', default=1, type=float)
    parser.add_argument('--g_steps', default=20, type=int, metavar='N',
                        help='number of iterations for generation')
    # Help text used to be a copy of an unrelated "iterations" description.
    parser.add_argument('--batch_size', default=256, type=int, metavar='N',
                        help='batch size')
    # NOTE(review): this help text looks copy-pasted and probably does not
    # describe --nz; left unchanged because its real meaning is not visible here.
    parser.add_argument('--nz', default=256, type=int, metavar='N',
                        help='number of total iterations in each epoch')
    parser.add_argument('--synthesis_batch_size', default=256, type=int)

    # Misc
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training.')
    parser.add_argument('--type', default="pretrain", type=str,
                        help='.')
    parser.add_argument('--main_task', default="train", type=str,
                        help='.')   # train, MI,
    parser.add_argument('--model', default="", type=str,
                        help='.')
    parser.add_argument('--other', default="", type=str,
                        help='.')
    # 'INFO' or 'DEBUG'
    parser.add_argument('--logging_level', default="INFO", type=str,
                        help='.')
    parser.add_argument('--debug', default="False", type=str,
                        help='.')
    parser.add_argument('--debug_show_exnn_id', default="False", type=str,
                        help='.')

    # federated progressive
    parser.add_argument('--progressive_classifer', default="fixed", type=str,
                        help='.') # fixed, progressive

    # federated expandable NN
    parser.add_argument('--fedexnn_classifer', default="avg", type=str,
                        help='.') #   fixed   multihead
    parser.add_argument('--fedexnn_adapter', default="avg", type=str,
                        help='.')
    parser.add_argument('--fedexnn_split_num', default=2, type=int,
                        help='.')
    parser.add_argument('--fedexnn_hetero_layer_depth', default="False", type=str,
                        help='.')
    parser.add_argument('--fedexnn_self_dropout', default=0.0, type=float,
                        help='.')
    parser.add_argument('--fedexnn_adapter_constrain_beta', default=0.0, type=float,
                        help='.')

    # split related
    parser.add_argument('--split_train', default="False", type=str,
                        help='.')
    parser.add_argument('--split_local_module_num', default=2, type=int,
                        help='.')
    parser.add_argument('--split_measure_local_module_num', default=2, type=int,
                        help='.')
    parser.add_argument('--infopro', default=2, type=int,
                        help='.')
    parser.add_argument('--MI_cos_lr', default="False", type=str,
                        help='.')

    # contrastive train
    parser.add_argument('--contrastive_train', default="False", type=str,
                        help='.')
    parser.add_argument('--contrastive_n_views', default=2, type=int,
                        help='.')
    parser.add_argument('--contrastive_weight', default=1.0, type=float,
                        help='.')
    parser.add_argument('--contrastive_projection_dim', default=64, type=int,
                        help='.')

    # backdoor train
    parser.add_argument('--backdoor_train', default="False", type=str,
                        help='.')
    parser.add_argument('--backdoor_n_clients', default=1, type=int,
                        help='.')
    parser.add_argument('--backdoor_size', default=10, type=int,
                        help='.')


    # Help text used to claim "default: checkpoint" although the default is 'no'.
    parser.add_argument('--checkpoint', default='no', type=str, metavar='PATH',
                        help='path to save checkpoint (default: no)')
    parser.add_argument('--resume', default='', type=str,
                        help='path to latest checkpoint (default: none)')


    # spurious related
    parser.add_argument('--spufeat', default="", type=str,
                        help='.')
    parser.add_argument('--aux_net_config', default='1c2f', type=str,
                        help='architecture of auxiliary classifier / contrastive head '
                            '(default: 1c2f; 0c1f refers to greedy SL)'
                            '[0c1f|0c2f|1c1f|1c2f|1c3f|2c2f]')
    parser.add_argument('--local_loss_mode', default='contrast', type=str,
                        help='ways to estimate the task-relevant info I(x, y)'
                            '[contrast|cross_entropy]')
    parser.add_argument('--aux_net_widen', default=1.0, type=float,
                        help='widen factor of the two auxiliary nets (default: 1.0)')
    # NOTE(review): help says "default: 128" while the argparse default is 0;
    # presumably 0 is mapped to 128 by the consumer -- confirm before editing.
    parser.add_argument('--aux_net_feature_dim', default=0, type=int,
                        help='number of hidden features in auxiliary classifier / contrastive head '
                            '(default: 128)')
    parser.add_argument('--ixx_1', default=0.0, type=float,)   # \lambda_1 for 1st local module
    parser.add_argument('--ixy_1', default=0.0, type=float,)   # \lambda_2 for 1st local module

    parser.add_argument('--ixx_2', default=0.0, type=float,)   # \lambda_1 for (K-1)th local module
    parser.add_argument('--ixy_2', default=0.0, type=float,)   # \lambda_2 for (K-1)th local module

    # EstMI
    # NOTE(review): this help text looks copy-pasted from a "local modules"
    # option and probably does not describe --EstMI_method; left unchanged.
    parser.add_argument('--EstMI_method', default="infopro", type=str,
                        help='number of local modules (1 refers to end-to-end training)')
    parser.add_argument('--EstFeatNorm', default="no", type=str, help='')
    parser.add_argument('--SaveFeats', default="no", type=str, help='')
    parser.add_argument('--TSNE', default="no", type=str, help='')
    parser.add_argument('--TSNE_points', default=500, type=int, help='')


    # wandb, exp record related
    parser.add_argument("--wandb_offline", type=str, default="True")
    parser.add_argument("--wandb_console", type=str, default="False")
    parser.add_argument("--wandb_entity", type=str, default="your-wandb-entity")
    parser.add_argument("--wandb_key", type=str, default=None)

    parser.add_argument("--exp_abs_path", type=str, default=".")
    parser.add_argument("--project_name", type=str, default="your-wandb-project")
    parser.add_argument("--exp_name", type=str, default="OneShot-FL")
    parser.add_argument("--override_cmd_args", action="store_true")
    parser.add_argument("--tag", type=str, default="debug")
    parser.add_argument("--exp_tool_init_sub_dir", type=str, default="no")

    parser.add_argument("--enable_wandb", type=str, default="False")


    args = parser.parse_args()
    # Snapshot the items first so the namespace is not mutated while iterating.
    for key, value in list(vars(args).items()):
        setattr(args, key, str2bool(value))
    return args



def kd_train(synthesizer, model, criterion, optimizer):
    """Train the student for one pass over ``synthesizer.get_data()`` against the teacher.

    Args:
        synthesizer: object providing ``get_data()`` (iterated for batches of
            images) and ``data_loader.dataset`` (its length is the accuracy
            denominator).
        model: a ``(student, teacher)`` pair; the student is put in train
            mode and updated, the teacher is put in eval mode and only
            queried under ``torch.no_grad()``.
        criterion: called as ``criterion(student_output, teacher_output)``.
        optimizer: optimizer stepped once per batch.

    Returns:
        None. Running loss and student/teacher agreement are shown in the
        progress-bar description.
    """
    student, teacher = model
    student.train()
    teacher.eval()

    bar_template = "loss={:.4f} acc={:.2f}%"
    loss_sum = 0.0
    agreements = 0.0
    with tqdm(synthesizer.get_data()) as progress:
        for step, images in enumerate(progress, start=1):
            optimizer.zero_grad()

            # The teacher only supplies targets, so no graph is kept for it.
            with torch.no_grad():
                teacher_out = teacher(images)
            student_out = student(images.detach())

            distill_loss = criterion(student_out, teacher_out.detach())
            distill_loss.backward()
            optimizer.step()

            loss_sum += distill_loss.detach().item()
            mean_loss = loss_sum / step

            # "Accuracy" here is agreement between the student's and the
            # teacher's argmax, accumulated over the batches seen so far.
            student_pred = student_out.argmax(dim=1)
            teacher_pred = teacher_out.argmax(dim=1)
            agreements += student_pred.eq(teacher_pred.view_as(student_pred)).sum().item()
            agreement_pct = agreements / len(synthesizer.data_loader.dataset) * 100

            progress.set_description(bar_template.format(mean_loss, agreement_pct))


def save_checkpoint(state, is_best, filename='checkpoint.pth'):
    """Write ``state`` to ``filename`` with ``torch.save``, but only when ``is_best`` is truthy.

    Nothing is written (and nothing is returned) when ``is_best`` is falsy.
    """
    if not is_best:
        return
    torch.save(state, filename)


def get_data_info(args):
    """Return the input geometry for ``args.dataset``.

    Args:
        args: object with a ``dataset`` attribute naming one of "mnist",
            "fmnist", "SVHN", "cifar10", "cifar100", "Tiny-ImageNet-200".

    Returns:
        ``(image_size, linear_in_feautres, channels)`` where
        ``linear_in_feautres == image_size * image_size * channels``.

    Raises:
        NotImplementedError: if ``args.dataset`` is not a known dataset.
            (An unknown name used to fall through a silent ``else: pass`` and
            fail on the ``return`` with an UnboundLocalError instead.)
    """
    # dataset name -> (image_size, channels)
    geometry = {
        "mnist": (28, 1),
        "fmnist": (28, 1),
        "SVHN": (32, 3),
        "cifar10": (32, 3),
        "cifar100": (32, 3),
        "Tiny-ImageNet-200": (64, 3),
    }
    if args.dataset not in geometry:
        raise NotImplementedError(f"Unsupported dataset: {args.dataset!r}")

    image_size, channels = geometry[args.dataset]
    linear_in_feautres = image_size * image_size * channels
    return image_size, linear_in_feautres, channels


def get_model(args, num_of_classes=10):
    """Build the network(s) selected by ``args.type`` and ``args.model``.

    The shape of the return value depends on ``args.type``:

    * ``"fed-expandable"``: ``(layers, split_modules)`` where ``split_modules``
      is a list of ``args.fedexnn_split_num`` ``Sequential_SplitNN`` modules
      wrapping consecutive slices of ``layers``.
    * ``"progressive"`` / ``"fed-progressive"``: the model alone.
    * anything else: ``(layers, model)``; ``layers`` stays ``None`` for models
      that are not assembled from a layer list.

    Args:
        args: parsed command-line namespace (see ``args_parser``).
        num_of_classes: number of output classes passed to the constructors.

    Raises:
        NotImplementedError: if ``args.model`` is not supported for the
            selected ``args.type`` (or the dataset is unknown to
            ``get_data_info``).
        KeyError: if the split-config tables have no entry for the requested
            model / module count.
    """
    layers = None
    image_size, linear_in_feautres, channels = get_data_info(args)

    def build_split_model(model_layers):
        # The two split configs are only needed by the Sequential_SplitNN
        # models of the generic path, so they are looked up here instead of at
        # the top of the function. Their assignments had been commented out
        # (to avoid a lookup for models missing from the tables), which left
        # these names undefined and made every caller below raise NameError.
        split_config = Split_Configs[args.model][args.split_local_module_num]
        # NOTE(review): the commented-out code also had an alternative that
        # read this one from EXNN_Split_Configs -- confirm which table the
        # measurement config is supposed to come from.
        split_measure_config = Split_Configs[args.model][args.split_measure_local_module_num]
        return Sequential_SplitNN(args.split_train, split_config,
                            split_measure_config, args.split_local_module_num,
                            model_layers)

    if args.type == "fed-expandable":
        # small_layers / large_layers are built but not returned from here.
        small_layers = None
        large_layers = None
        if args.model == "mlp3":
            hidden_features = args.mlp_hidden_features
            layers = mlp3(linear_in_feautres, hidden_features, num_of_classes, init_classifier=False)
            if args.fedexnn_hetero_layer_depth:
                # NOTE(review): unused afterwards; kept because dropping the
                # constructions could change random initialisation order.
                small_layers = mlp3(linear_in_feautres, hidden_features // 2, num_of_classes, init_classifier=False)
                large_layers = mlp3(linear_in_feautres, int(hidden_features * 1.5), num_of_classes, init_classifier=False)
        elif args.model == "cnn":
            hidden_features = args.cnn_hidden_features
            layers = make_CNNCifar_seqs(3, hidden_features, num_of_classes, init_classifier=False)
        elif args.model == "resnet18":
            split_local_layers = fl_exnn_resnet18(group_norm=args.res_group_norm,
                                            res_base_width=args.res_base_width, in_channels=channels,
                                            hetero_layer_depth=args.fedexnn_hetero_layer_depth)
            layers, small_layers, large_layers = split_local_layers
        elif args.model == "resnet50":
            split_local_layers = fl_exnn_resnet50(group_norm=args.res_group_norm,
                                            res_base_width=args.res_base_width, in_channels=channels,
                                            hetero_layer_depth=args.fedexnn_hetero_layer_depth)
            layers, small_layers, large_layers = split_local_layers
        else:
            raise NotImplementedError

        # Each entry of the config is the index of the last layer of a module;
        # the final module takes whatever layers remain.
        split_config = EXNN_Split_Configs[args.model][args.fedexnn_split_num]

        begin_index = 0
        split_modules = []
        for layer_index in split_config:
            split_module = Sequential_SplitNN(None, None,
                                None, None,
                                layers[begin_index: layer_index+1])
            begin_index = layer_index + 1
            split_modules.append(split_module)
        split_module = Sequential_SplitNN(None, None,
                            None, None,
                            layers[begin_index:])
        split_modules.append(split_module)
        assert len(split_modules) == args.fedexnn_split_num

        return layers, split_modules


    if args.type == "progressive":
        if args.model == "mlp3":
            hidden_features = args.mlp_hidden_features
            model = PNN(num_layers=args.num_layers,
                            in_features=linear_in_feautres,
                            hidden_features_per_column=hidden_features,
                            num_of_classes=num_of_classes)
        elif args.model == "cnn":
            hidden_features = args.cnn_hidden_features
            model = PNN_CNN(num_layers=args.num_layers,
                        in_features=channels,
                        hidden_features_per_column=hidden_features,
                        num_of_classes=num_of_classes,
                        adapter="cnn",
                        )
        elif args.model == "resnet18":
            model = pnn_resnet18(num_classes=num_of_classes, group_norm=args.res_group_norm,
                                res_base_width=args.res_base_width, in_channels=channels, adapter="cnn")
        elif args.model == "resnet50":
            model = pnn_resnet50(num_classes=num_of_classes, group_norm=args.res_group_norm,
                                res_base_width=args.res_base_width, in_channels=channels, adapter="cnn")
        else:
            # Used to surface as UnboundLocalError on the return below.
            raise NotImplementedError(
                f"Model {args.model!r} is not supported for type 'progressive'")
        return model


    if args.type == "fed-progressive":
        if args.model == "mlp3":
            hidden_features = args.mlp_hidden_features
            # NOTE(review): in_features is the literal 3 here, whereas the
            # "progressive" mlp3 branch passes linear_in_feautres -- confirm.
            model = Federated_PNN(num_layers=args.num_layers,
                            in_features=3,
                            hidden_features_per_column=hidden_features,
                            num_of_classes=num_of_classes,
                            classifier_name=args.progressive_classifer
                            )
        elif args.model == "cnn":
            hidden_features = args.cnn_hidden_features
            model = Federated_PNN_CNN(num_layers=args.num_layers,
                        in_features=channels,
                        hidden_features_per_column=hidden_features,
                        num_of_classes=num_of_classes,
                        adapter="cnn",
                        classifier_name=args.progressive_classifer
                    )
        elif args.model == "resnet18":
            model = fl_pnn_resnet18(num_classes=num_of_classes, group_norm=args.res_group_norm,
                                res_base_width=args.res_base_width, in_channels=channels,
                                adapter="cnn", classifier_name=args.progressive_classifer)
        elif args.model == "resnet50":
            model = fl_pnn_resnet50(num_classes=num_of_classes, group_norm=args.res_group_norm,
                                res_base_width=args.res_base_width, in_channels=channels,
                                adapter="cnn", classifier_name=args.progressive_classifer)
        else:
            # Used to surface as UnboundLocalError on the return below.
            raise NotImplementedError(
                f"Model {args.model!r} is not supported for type 'fed-progressive'")
        return model

    # Generic path: every other value of args.type.
    if args.model == "mnist_cnn":
        model = CNNMnist()
    elif args.model == "fmnist_cnn":
        model = CNNMnist()
    elif args.model == "cnn":
        hidden_features = args.cnn_hidden_features
        layers = make_CNNCifar_seqs(3, hidden_features, num_of_classes, init_classifier=False)
        model = build_split_model(layers)
    elif args.model == "mlp2":
        hidden_features = args.mlp_hidden_features
        layers = mlp2(linear_in_feautres, hidden_features, num_of_classes, init_classifier=True)
        model = build_split_model(layers)
    elif args.model == "mlp3":
        hidden_features = args.mlp_hidden_features
        layers = mlp3(linear_in_feautres, hidden_features, num_of_classes, init_classifier=True)
        model = build_split_model(layers)
    elif args.model == "svhn_cnn":
        hidden_features = args.cnn_hidden_features
        model = CNNCifar(hidden_features, num_of_classes)
    elif args.model == "cifar100_cnn":
        model = CNNCifar100()
    elif args.model == "resnet18":
        layers = resnet18_layers(init_classifier=True,
            num_classes=num_of_classes, group_norm=args.res_group_norm, res_base_width=args.res_base_width, in_channels=channels)
        model = build_split_model(layers)
    elif args.model == "resnet50":
        layers = resnet50_layers(init_classifier=True,
            num_classes=num_of_classes, group_norm=args.res_group_norm, res_base_width=args.res_base_width, in_channels=channels)
        model = build_split_model(layers)
    elif args.model == "vit":
        model = deit_tiny_patch16_224(num_classes=num_of_classes,
                                             drop_rate=0.,
                                             drop_path_rate=0.1)
        model.head = torch.nn.Linear(model.head.in_features, num_of_classes)
        model = torch.nn.DataParallel(model)
    else:
        # Used to surface as UnboundLocalError on the return below.
        raise NotImplementedError(f"Unsupported model: {args.model!r}")

    return layers, model


def adjust_learning_rate(optimizer, epoch, training_configurations, args):
    """Update the learning rate of every param group of ``optimizer`` in place.

    Two schedules are supported, selected by ``args.MI_cos_lr``:

    * falsy  -> step decay: when ``epoch`` is listed in ``changing_lr`` each
      group's current lr is multiplied by ``lr_decay_rate``.
    * truthy -> cosine schedule over ``epochs`` starting from
      ``initial_learning_rate``; for ``epoch <= 10`` the cosine value is
      blended with a constant 0.01 term (warm-up).

    Args:
        optimizer: object exposing ``param_groups`` (list of dicts with 'lr').
        epoch: current epoch index.
        training_configurations: dict keyed by model name; the entry for
            ``args.model`` supplies the schedule hyper-parameters.
        args: namespace providing ``MI_cos_lr`` and ``model``.

    The resulting learning rates are printed after a 'lr:' header line.
    """
    # Imported locally so the cosine branch does not depend on a module-level
    # ``import math`` being present elsewhere in the file.
    import math

    config = training_configurations[args.model]

    if not args.MI_cos_lr:
        if epoch in config['changing_lr']:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= config['lr_decay_rate']
    else:
        cosine_lr = 0.5 * config['initial_learning_rate'] \
            * (1 + math.cos(math.pi * epoch / config['epochs']))
        if epoch <= 10:
            # Warm-up: weight the cosine value by (epoch - 1) / 10 and add a
            # shrinking share of the constant 0.01.
            new_lr = cosine_lr * (epoch - 1) / 10 + 0.01 * (11 - epoch) / 10
        else:
            new_lr = cosine_lr
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lr

    print('lr:')
    for param_group in optimizer.param_groups:
        print(param_group['lr'])




def measure_feautre(device, data_loader, model):
    """Average per-channel feature norms of every measured layer over a loader.

    For each batch the hidden features returned by ``model.forward_measure``
    are reduced with an L2 norm over dims 2 and 3 and then averaged over the
    batch dimension; these per-batch means are summed and finally divided by
    the number of batches.

    Returns:
        (layer_channel_norms, layer_total_norm): dicts keyed by layer index.
        The first holds the averaged per-channel norm tensor, the second the
        L2 norm of that tensor as a Python float.
    """
    model.eval()
    channel_norm_sums = {}
    num_batches = 0

    with torch.no_grad():
        for x, target in data_loader:
            target = target.to(device)
            x = x.to(device)
            _, hidden_xs = model.forward_measure(x)
            if args.type == "fed-expandable":
                hidden_xs = to_exnn_hidden_xs(hidden_xs)

            for layer_idx, features in hidden_xs.items():
                # Norm over dims 2 and 3 -> [batch_size, num_channels];
                # mean over the batch -> [num_channels].
                batch_mean_norms = torch.norm(features, p=2, dim=[2, 3]).mean(dim=0)
                if layer_idx in channel_norm_sums:
                    channel_norm_sums[layer_idx] += batch_mean_norms
                else:
                    channel_norm_sums[layer_idx] = batch_mean_norms
            num_batches += 1

    layer_channel_norms = {
        layer_idx: summed / num_batches
        for layer_idx, summed in channel_norm_sums.items()
    }
    layer_total_norm = {
        layer_idx: torch.norm(channel_norms, p=2).item()
        for layer_idx, channel_norms in layer_channel_norms.items()
    }
    return layer_channel_norms, layer_total_norm


def get_all_feature(device, data_loader, model, num_points=1000):
    """Collect hidden features and labels for the first ``num_points`` samples.

    Batches are read from ``data_loader`` until at least ``num_points``
    samples have been seen; the hidden features returned by
    ``model.forward_measure`` are gathered per layer.

    Args:
        device: device the inputs are moved to before the forward pass.
        data_loader: iterable yielding ``(x, target)`` batches.
        model: module exposing ``forward_measure(x) -> (output, hidden_xs)``.
        num_points: maximum number of samples to keep.

    Returns:
        (layer_feats, labels): ``layer_feats`` maps layer index to a CPU
        tensor holding at most ``num_points`` rows; ``labels`` is the
        concatenation of the loader targets truncated to the same length.
    """
    model.eval()

    layer_feats = {}
    labels = []
    loaded_num_points = 0

    with torch.no_grad():
        for x, target in data_loader:
            x = x.to(device)
            loaded_num_points += x.shape[0]
            _, hidden_xs = model.forward_measure(x)
            if args.type == "fed-expandable":
                hidden_xs = to_exnn_hidden_xs(hidden_xs)
            labels.append(target)
            for layer_idx, features in hidden_xs.items():
                # Offload each batch right away so the collected features do
                # not pile up in device memory until the final concatenation.
                layer_feats.setdefault(layer_idx, []).append(features.to('cpu'))
            # ">=": stop as soon as enough samples are available instead of
            # reading one extra batch that would be truncated away anyway.
            if loaded_num_points >= num_points:
                break

        layer_feats = {
            layer_idx: torch.cat(chunks, dim=0)[:num_points]
            for layer_idx, chunks in layer_feats.items()
        }
        labels = torch.cat(labels, dim=0)[:num_points]

    return layer_feats, labels






def estMI(device, train_loader, model, estimator, optimizer, epoch, num_layers):
    """Run one epoch of estimator updates on features of a frozen ``model``.

    For every batch: the optimizer gradients are cleared, the hidden features
    are obtained from ``model.forward_measure``, ``estimator`` is called on
    ``(x, hidden_xs, target)`` and ``optimizer.step()`` is applied. Per-layer
    top-1 accuracy of the estimator logits is tracked, and progress is logged
    and appended to "EstiMI.txt" every 10 iterations and at the end.

    Args:
        device: device the batches are moved to.
        train_loader: iterable yielding ``(x, target)`` batches.
        model: measured network (put in eval mode here).
        estimator: callable returning
            ``(h_logits, loss_ixx_modules, loss_ixy_modules)``.
        optimizer: optimizer stepping the estimator parameters.
        epoch: epoch index, used only in log messages.
        num_layers: number of per-layer accuracy meters to create.

    Returns:
        (loss_ixxs, top1s_avg): per-layer epoch means, rounded to 3 decimals.
    """
    layer_top1s = [AverageMeter() for _ in range(num_layers)]

    record_file = ExpTool.get_file_name("EstiMI.txt", exp_dir=True)
    model.eval()

    loss_ixx_modules_iters = []
    loss_ixy_modules_iters = []

    local_iters = len(train_loader)

    for i, (x, target) in enumerate(train_loader):
        target = target.to(device)
        x = x.to(device)

        optimizer.zero_grad()
        output, hidden_xs = model.forward_measure(x)
        if args.type == "fed-expandable":
            hidden_xs = to_exnn_hidden_xs(hidden_xs)

        # NOTE(review): no backward() is issued in this function; presumably
        # the estimator back-propagates its own losses internally — confirm.
        h_logits, loss_ixx_modules, loss_ixy_modules = estimator(x, hidden_xs, target)

        loss_ixx_modules_iters.append(loss_ixx_modules)
        loss_ixy_modules_iters.append(loss_ixy_modules)
        optimizer.step()

        for layer_i, logits in enumerate(h_logits):
            prec1 = accuracy(logits.data, target, topk=(1,))[0]
            layer_top1s[layer_i].update(prec1.item(), x.size(0))

        if (i+1) % 10 == 0:
            string = f"Training Epoch: [{epoch}][{i}/{local_iters}], loss_ixx: {[round(loss_ixx, 3) for loss_ixx in loss_ixx_modules]} " + \
                f"loss_ixy: {[round(loss_ixy, 3) for loss_ixy in loss_ixy_modules]} " + \
                f"top1s: {[round(top1s.val, 3) for top1s in layer_top1s]} "

            logging.info(string)
            # Context manager guarantees the handle is released even if the
            # write fails.
            with open(record_file, 'a+') as fd:
                fd.write(string + '\n')

    # Average the per-iteration losses over the epoch (axis 0 = iterations).
    loss_ixx_modules_iters = np.mean(np.array(loss_ixx_modules_iters), axis=0)
    loss_ixy_modules_iters = np.mean(np.array(loss_ixy_modules_iters), axis=0)

    string = f"Training Epoch: [{epoch}], loss_ixx avg: {[round(loss_ixx, 3) for loss_ixx in loss_ixx_modules_iters]} " + \
            f"loss_ixy avg: {[round(loss_ixy, 3) for loss_ixy in loss_ixy_modules_iters]} " + \
            f"top1s avg: {[round(top1s.avg, 3) for top1s in layer_top1s]} "
    logging.info(string)
    with open(record_file, 'a+') as fd:
        fd.write(string + '\n')

    loss_ixxs = [round(loss_ixx, 3) for loss_ixx in loss_ixx_modules_iters]
    top1s_avg = [round(top1s.avg, 3) for top1s in layer_top1s]

    return loss_ixxs, top1s_avg



def train_linear_probe(device, train_loader, model, linear_probes, optimizer, epoch, num_layers):
    """Run one epoch of linear-probe updates on features of a frozen ``model``.

    For every batch: the optimizer gradients are cleared, the hidden features
    are obtained from ``model.forward_measure``, ``linear_probes`` is called
    on ``(x, hidden_xs, target)`` and ``optimizer.step()`` is applied.
    Per-layer top-1 accuracy of the probe logits is tracked, and progress is
    logged and appended to the record file every 10 iterations and at the end.

    Args:
        device: device the batches are moved to.
        train_loader: iterable yielding ``(x, target)`` batches.
        model: measured network (put in eval mode here).
        linear_probes: callable returning ``(h_logits, loss_ixys)``.
        optimizer: optimizer stepping the probe parameters.
        epoch: epoch index, used only in log messages.
        num_layers: number of per-layer accuracy meters to create.

    Returns:
        List of per-layer average top-1 accuracies, rounded to 3 decimals.
    """
    layer_top1s = [AverageMeter() for _ in range(num_layers)]

    # NOTE(review): training logs go to "EstiMI.txt", the same file estMI
    # writes to — possibly a copy-paste leftover; kept unchanged, confirm.
    record_file = ExpTool.get_file_name("EstiMI.txt", exp_dir=True)
    model.eval()

    loss_ixys_iters = []
    local_iters = len(train_loader)
    for i, (x, target) in enumerate(train_loader):
        target = target.to(device)
        x = x.to(device)

        optimizer.zero_grad()
        output, hidden_xs = model.forward_measure(x)
        if args.type == "fed-expandable":
            hidden_xs = to_exnn_hidden_xs(hidden_xs)
        # NOTE(review): no backward() is issued in this function; presumably
        # linear_probes back-propagates its own losses internally — confirm.
        h_logits, loss_ixys = linear_probes(x, hidden_xs, target)
        loss_ixys_iters.append(loss_ixys)
        optimizer.step()

        for layer_i, logits in enumerate(h_logits):
            prec1 = accuracy(logits.data, target, topk=(1,))[0]
            layer_top1s[layer_i].update(prec1.item(), x.size(0))

        if (i+1) % 10 == 0:
            string = f"Training Epoch: [{epoch}][{i}/{local_iters}], " + \
                f"loss_ixy: {[round(loss_ixy, 3) for loss_ixy in loss_ixys]} " + \
                f"top1s: {[round(top1s.val, 3) for top1s in layer_top1s]} "

            logging.info(string)
            # Context manager guarantees the handle is released even if the
            # write fails.
            with open(record_file, 'a+') as fd:
                fd.write(string + '\n')

    # Average the per-iteration losses over the epoch (axis 0 = iterations).
    loss_ixys_iters = np.mean(np.array(loss_ixys_iters), axis=0)

    string = f"Training Epoch: [{epoch}]," + \
            f"loss_ixy avg: {[round(loss_ixy, 3) for loss_ixy in loss_ixys_iters]} " + \
            f"top1s avg: {[round(top1s.avg, 3) for top1s in layer_top1s]} "
    logging.info(string)
    with open(record_file, 'a+') as fd:
        fd.write(string + '\n')

    top1s_avg = [round(top1s.avg, 3) for top1s in layer_top1s]

    return top1s_avg



def get_res_MIEstimator(split_measure_config, num_of_classes, group_norm, res_base_width, channels):
    """Assemble a ``ReconMIEstimator`` with per-layer decoders and aux classifiers.

    A ResNet-18 layer list is built from the arguments and handed to
    ``make_ResNetMIEstimator`` together with the module-level globals
    ``hidden_x_channels`` and ``image_size`` (both must be defined before
    this function is called). Every returned decoder and auxiliary classifier
    is then registered on the estimator under its layer index.

    Returns:
        The populated ``ReconMIEstimator`` instance.
    """
    resnet_layers = resnet18_layers(
        init_classifier=True,
        num_classes=num_of_classes,
        group_norm=group_norm,
        res_base_width=res_base_width,
        in_channels=channels,
    )
    decoders, aux_classifiers = make_ResNetMIEstimator(
        resnet_layers, hidden_x_channels, image_size, aux_net_widen=1)

    mi_estimator = ReconMIEstimator(split_measure_config)

    # Register all decoders first, then all auxiliary classifiers.
    for layer_index in decoders:
        mi_estimator.add_decoder(decoders[layer_index], layer_index)
    for layer_index in aux_classifiers:
        mi_estimator.add_aux_classifier(aux_classifiers[layer_index], layer_index)

    return mi_estimator



if __name__ == '__main__':
    # Script entry point: parse arguments, build data loaders and the model,
    # then run one of the main tasks: "train", "MI" or "LinearProbe".

    args = args_parser()

    def to_exnn_hidden_xs(hidden_xs):
        """Re-key fed-expandable hidden features from module position to layer index.

        For ``args.type == "fed-expandable"`` the entry stored under position
        ``module_idx`` is returned under the layer index found at that
        position in ``EXNN_Split_Configs[args.model][args.fedexnn_split_num]``.
        For every other type the input is returned unchanged (the previous
        implementation raised UnboundLocalError in that case).

        Defined once at module scope because the measurement helpers above
        (measure_feautre, get_all_feature, estMI, train_linear_probe) look it
        up as a global.
        """
        if args.type != "fed-expandable":
            return hidden_xs
        # map to normal layer index
        split_config = EXNN_Split_Configs[args.model][args.fedexnn_split_num]
        return {layer_idx: hidden_xs[module_idx]
                for module_idx, layer_idx in enumerate(split_config)}

    if args.main_task == "train":
        ExpTool.init(args)
    elif args.main_task in ["MI", "LinearProbe"]:
        if not args.exp_tool_init_sub_dir == "no":
            ExpTool.init_with_sub_dir(args, args.exp_tool_init_sub_dir)
        else:
            ExpTool.init(args)
    else:
        raise NotImplementedError

    logger = logging_config(args, 0)
    # wandb.init(config=args,
    #            project="ont-shot FL")

    device = torch.device(f"cuda:{args.gpu}")
    setup_seed(args.seed)
    image_size = get_image_size(args.dataset)
    num_of_classes = get_num_of_labels(args.dataset)
    train_dataset, test_dataset, train_user_groups, train_data_cls_counts, test_user_groups, test_data_cls_counts = partition_data(
        image_size, args.dataset, args.datadir, args.partition, alpha=args.alpha, num_users=args.num_users,
        contrastive_train=args.contrastive_train, contrastive_n_views=args.contrastive_n_views)

    logger.info(f"train_data_cls_counts: {train_data_cls_counts}")
    logger.info(f"test_data_cls_counts: {test_data_cls_counts}")

    global_test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=256,
                                              shuffle=False, num_workers=4)
    # BUILD MODEL

    # Hyper-parameters for training the MI estimator, keyed by model name.
    mi_estimator_configurations = {
        'resnet18': {
            'epochs': 160,
            'batch_size': 128,
            'initial_learning_rate': 0.01,
            # 'batch_size': 1024 if args.dataset in ['cifar10', 'svhn'] else 128,
            # 'initial_learning_rate': 0.8 if args.dataset in ['cifar10', 'svhn'] else 0.1,
            'changing_lr': [80, 120],
            'lr_decay_rate': 0.1,
            'momentum': 0.9,
            'nesterov': True,
            'weight_decay': 1e-4,
        },
        'resnet50': {
            'epochs': 160,
            'batch_size': 1024 if args.dataset in ['cifar10', 'svhn'] else 128,
            'initial_learning_rate': 0.8 if args.dataset in ['cifar10', 'svhn'] else 0.1,
            'changing_lr': [80, 120],
            'lr_decay_rate': 0.1,
            'momentum': 0.9,
            'nesterov': True,
            'weight_decay': 1e-4,
        },
    }

    # Hyper-parameters for training the linear probes, keyed by model name.
    linear_probe_configurations = {
        'resnet18': {
            'epochs': 10,
            'batch_size': 128,
            'initial_learning_rate': 0.01,
            'momentum': 0.9,
            'nesterov': True,
            'weight_decay': 1e-4,
        },
    }

    layers, global_model = get_model(args, num_of_classes)
    # NOTE(review): both assignments of split_measure_config below are
    # commented out and nothing else in this block assigns it, yet the MI task
    # passes it to get_res_MIEstimator — confirm it is defined at module scope
    # elsewhere, otherwise that call raises NameError.
    # split_measure_config = EXNN_Split_Configs[args.model][args.split_measure_local_module_num]
    # split_measure_config = Split_Configs[args.model][args.split_measure_local_module_num]
    image_size, linear_in_feautres, channels = get_data_info(args)

    if args.model == "resnet18":
        out_channels = get_res18_out_channels(args.res_base_width)
    elif args.model == "mlp2":
        out_channels = [args.mlp_hidden_features for _ in range(2)]
    elif args.model == "mlp3":
        out_channels = [args.mlp_hidden_features for _ in range(3)]
    elif args.model == "cnn":
        out_channels = [args.cnn_hidden_features for _ in range(3)]
    else:
        raise NotImplementedError


    if args.main_task == "train":
        if args.type == "pretrain":
            global_model, global_weights, local_weights, model_list = pretrain(
                args, device, logger, train_dataset, test_dataset, 
                train_user_groups, train_data_cls_counts, 
                test_user_groups, test_data_cls_counts,
                global_test_loader, global_model, out_channels)
        elif args.type == "progressive":
            progressive(args, device, logger, train_dataset, test_dataset, 
            train_user_groups, train_data_cls_counts, 
            test_user_groups, test_data_cls_counts,
            global_test_loader, global_model, out_channels)
        elif args.type == "fed-progressive":
            fed_progressive(args, device, logger, train_dataset, test_dataset, 
            train_user_groups, train_data_cls_counts, 
            test_user_groups, test_data_cls_counts,
            global_test_loader, global_model, out_channels)
        elif args.type == "fed-expandable":
            fed_expandable(args, device, logger, train_dataset, test_dataset, 
            train_user_groups, train_data_cls_counts, 
            test_user_groups, test_data_cls_counts,
            global_test_loader, global_model, out_channels)
        else:
            raise RuntimeError

    elif args.main_task == "MI":
        if args.type == "pretrain":
            assert args.resume
            # Rebuild one model per client from the pickled local weights.
            local_weights = ExpTool.load_pickle(args.resume, exp_dir=False)
            model_list = []
            for i in range(len(local_weights)):
                net = copy.deepcopy(global_model)
                net.load_state_dict(local_weights[i])
                model_list.append(net)
            ensemble_model = Ensemble(model_list)
            # global_model_test_acc, test_loss = test(global_model, global_test_loader, device)
            # logger.info(f"global_model acc: {global_model_test_acc}")

            local_model = model_list[0]
            # local_model_test_acc, test_loss = test(local_model, global_test_loader, device)
            # logger.info(f"local_model acc: {local_model_test_acc}")

            # ensemble_acc, ensemble_loss = test(ensemble_model, global_test_loader, device)
            # logger.info(f"ensemble acc: {ensemble_acc}")
            measure_model = local_model
            if not args.EstFeatNorm == "no":
                # Feature-norm measurement: every client model is evaluated on
                # client 0's training split, then the script exits.
                idx = 0
                local_train_loader = DataLoader(DatasetSplit(train_dataset, train_user_groups[idx]),
                                            batch_size=args.local_bs, shuffle=True, num_workers=4, drop_last=False)
                local_test_loader = DataLoader(DatasetSplit(test_dataset, test_user_groups[idx]),
                                            batch_size=args.local_bs, shuffle=False, num_workers=4, drop_last=False)
                EstFeatNorm_results = {}
                record_file = ExpTool.get_file_name("EstFeatNorm.txt", exp_dir=True)
                with open(record_file, 'a+') as fd:
                    for client_idx, model in enumerate(model_list):
                        EstFeatNorm_results[client_idx] = {}
                        model.to(device)
                        layer_channel_norms, layer_total_norm = measure_feautre(device, local_train_loader, model)
                        model.to("cpu")
                        EstFeatNorm_results[client_idx]["layer_channel_norms"] = layer_channel_norms
                        EstFeatNorm_results[client_idx]["layer_total_norm"] = layer_total_norm
                        for layer_idx, channel_norms in layer_channel_norms.items():
                            ExpTool.logging_write(f"client_idx:{client_idx}, layer_idx:{layer_idx}: layer_total_norm = {layer_total_norm[layer_idx]}", fd)
                            # ExpTool.logging_write(f"channel_norms:{channel_norms} =============", fd)
                ExpTool.save_pickle(EstFeatNorm_results, "EstFeatNorm_results", exp_dir=True)
                ExpTool.finish(args)
                # sys.exit instead of the site-provided exit() helper, which
                # is not guaranteed to exist in every interpreter setup.
                sys.exit()

            if not args.SaveFeats == "no":
                # Feature dump + t-SNE plots for the first two client models,
                # then the script exits.
                local_FeatLabels_results = {}
                for client_idx, model in enumerate(model_list):
                    if client_idx > 1:
                        break
                    logging.info(f"get client {client_idx} features")
                    model.to(device)
                    layer_feats, labels = get_all_feature(device, global_test_loader, model, num_points=1000)
                    model.to("cpu")
                    local_FeatLabels_results[client_idx] = {
                        "layer_feats": layer_feats,
                        "labels": labels}
                ExpTool.save_pickle(local_FeatLabels_results, "local_FeatLabels_results", exp_dir=True)
                # global_FeatLabels_results = {}
                # for client_idx, model in enumerate(model_list):
                #     logging.info(f"get client {client_idx} features")
                #     model.to(device)
                #     layer_feats, labels = get_all_feature(device, global_test_loader, model, num_points=1000)
                #     model.to("cpu")
                #     global_FeatLabels_results[client_idx] = {
                #         "layer_feats": layer_feats,
                #         "labels": labels}
                # ExpTool.save_pickle(global_FeatLabels_results, "global_FeatLabels_results", exp_dir=True)
                # if not args.TSNE == "no":
                # NOTE(review): the loaded object is discarded; the in-memory
                # dict built above is what the plotting loop uses.
                ExpTool.load_pickle("local_FeatLabels_results", exp_dir=True)
                avg_pool = nn.AdaptiveAvgPool2d((1, 1))
                # avg_pool.to(device)
                for client_idx in local_FeatLabels_results.keys():
                    layer_feats = local_FeatLabels_results[client_idx]["layer_feats"]
                    labels = local_FeatLabels_results[client_idx]["labels"]
                    for layer_index, features in layer_feats.items():
                        logging.info(f"T-SNE on client {client_idx}, layer {layer_index} ...... ")
                        tSNE_save_path = ExpTool.get_file_name(f"local_c{client_idx}_l{layer_index}_TSNE.pdf", exp_dir=True)
                        if len(features.shape) > 2:
                            features = avg_pool(features[:args.TSNE_points])
                        features = features.view(features.shape[0], -1)
                        draw_tsne(device, num_of_classes, features, labels[:args.TSNE_points],
                            tSNE_save_path=tSNE_save_path)
                # ExpTool.load_pickle("global_FeatLabels_results", exp_dir=True)
                # for client_idx in global_FeatLabels_results.keys():
                #     layer_feats = global_FeatLabels_results[client_idx]["layer_feats"]
                #     labels = global_FeatLabels_results[client_idx]["labels"]
                #     for layer_index, features in layer_feats.items():
                #         tSNE_save_path = ExpTool.get_file_name(f"global_c{client_idx}_l{layer_index}_TSNE.pdf", exp_dir=True)
                #         draw_tsne(device, num_of_classes, layer_feats, labels,
                #             tSNE_save_path=tSNE_save_path)
                ExpTool.finish(args)
                sys.exit()


        elif args.type == "fed-expandable":
            assert args.resume
            global_model = init_fedexnn_merged(args, global_model, out_channels)
            weights = ExpTool.load_pickle(args.resume, exp_dir=False)
            # show_model_layers(global_model, logger=None)
            # logger.info(f"================================")
            # for k, v in weights.items():
            #     logger.info(f"layer: {k}, Shape:{v.shape} No. Params: {v.numel()}")
            global_model.load_state_dict(weights)
            # global_model.load(weights)
            measure_model = global_model
        else:
            raise RuntimeError
        in_channels = []
        measure_model.eval()
        measure_model.to(device)

        # One forward pass on the first test batch, only to discover the
        # hidden feature shapes of each measured layer.
        for i, (x, target) in enumerate(global_test_loader):
            target = target.to(device)
            x = x.to(device)
            output, hidden_xs = measure_model.forward_measure(x)
            break

        if args.type == "fed-expandable":
            hidden_xs = to_exnn_hidden_xs(hidden_xs)

        # Read as a module global by get_res_MIEstimator.
        hidden_x_channels = dict([(k, h.shape[1]) for k, h in hidden_xs.items()])
        logging.info(f"========== hidden_x_channels: {hidden_x_channels}")

        if args.model in ["resnet18"]:
            mi_estimator = get_res_MIEstimator(split_measure_config, num_of_classes, args.res_group_norm, args.res_base_width, channels)
            global_train_loader = torch.utils.data.DataLoader(
                train_dataset, batch_size=mi_estimator_configurations[args.model]['batch_size'],
                shuffle=False, num_workers=4)
            num_layers = args.split_measure_local_module_num

            optimizer = torch.optim.SGD(
                mi_estimator.parameters(),
                lr=mi_estimator_configurations[args.model]['initial_learning_rate'],
                momentum=mi_estimator_configurations[args.model]['momentum'],
                nesterov=mi_estimator_configurations[args.model]['nesterov'],
                weight_decay=mi_estimator_configurations[args.model]['weight_decay'])

            # show_model_layers(global_model, logger)
            # for k, decode in decoders.items():
            #     logger.info(f"====decoder {k}==============================")
            #     show_model_layers(decode, logger)
            #     logger.info(f"====aux_classifier {k}==============================")
            #     show_model_layers(aux_classifiers[k], logger)


            mi_estimator.to(device)
            MI_results = {}

            for epoch in range(0, mi_estimator_configurations[args.model]['epochs']):
                # adjust_learning_rate(optimizer, epoch + 1)
                if args.debug and epoch == 1:
                    break
                adjust_learning_rate(optimizer, epoch, mi_estimator_configurations, args)
                train_loss_ixxs, train_top1s_avg = estMI(device, global_train_loader, measure_model, mi_estimator, optimizer, epoch, num_layers)
                # Evaluate on the test set every 10 epochs and after the last one.
                if epoch % 10 == 0 or epoch == mi_estimator_configurations[args.model]['epochs'] - 1:
                    MI_results[epoch] = {}
                    loss_ixx_modules_iters = []
                    loss_ixy_modules_iters = []
                    layer_test_top1s = [AverageMeter() for _ in range(len(hidden_xs))]
                    for i, (x, target) in enumerate(global_test_loader):
                        target = target.to(device)
                        x = x.to(device)
                        output, hidden_xs = measure_model.forward_measure(x)
                        if args.type == "fed-expandable":
                            hidden_xs = to_exnn_hidden_xs(hidden_xs)
                        h_logits, loss_ixx_modules, loss_ixy_modules = mi_estimator(x, hidden_xs, target)
                        loss_ixx_modules_iters.append(loss_ixx_modules)
                        loss_ixy_modules_iters.append(loss_ixy_modules)
                        for layer_i, logits in enumerate(h_logits):
                            prec1 = accuracy(logits.data, target, topk=(1,))[0]
                            layer_test_top1s[layer_i].update(prec1.item(), x.size(0))

                    record_file = ExpTool.get_file_name("EstiMI.txt", exp_dir=True)
                    loss_ixx_modules_iters = np.array(loss_ixx_modules_iters)
                    loss_ixy_modules_iters = np.array(loss_ixy_modules_iters)
                    loss_ixx_modules_iters = np.mean(loss_ixx_modules_iters, axis=0)
                    loss_ixy_modules_iters = np.mean(loss_ixy_modules_iters, axis=0)

                    string = f"Testing Epoch: [{epoch}], loss_ixx avg: {[round(loss_ixx, 3) for loss_ixx in loss_ixx_modules_iters]} " + \
                            f"loss_ixy avg: {[round(loss_ixy, 3) for loss_ixy in loss_ixy_modules_iters]} " + \
                            f"top1s avg: {[round(top1s.avg, 3) for top1s in layer_test_top1s]} "
                    test_loss_ixxs = [round(loss_ixx, 3) for loss_ixx in loss_ixx_modules_iters]
                    test_top1s_avg = [round(top1s.avg, 3) for top1s in layer_test_top1s]
                    print(string)
                    with open(record_file, 'a+') as fd:
                        fd.write(string + '\n')
                    MI_results[epoch]["train_loss_ixxs"] = train_loss_ixxs
                    MI_results[epoch]["train_top1s_avg"] = train_top1s_avg
                    MI_results[epoch]["test_loss_ixxs"] = test_loss_ixxs
                    MI_results[epoch]["test_top1s_avg"] = test_top1s_avg
            ExpTool.save_pickle(MI_results, "MI_results", exp_dir=True)
        else:
            raise NotImplementedError

        ExpTool.finish(args)
    elif args.main_task == "LinearProbe":
        if args.type == "pretrain":
            assert args.resume
            # Rebuild one model per client from the pickled local weights;
            # the first client's model is the one that gets probed.
            local_weights = ExpTool.load_pickle(args.resume, exp_dir=False)
            model_list = []
            for i in range(len(local_weights)):
                net = copy.deepcopy(global_model)
                net.load_state_dict(local_weights[i])
                model_list.append(net)
            ensemble_model = Ensemble(model_list)
            local_model = model_list[0]
            measure_model = local_model
        elif args.type == "fed-expandable":
            assert args.resume
            global_model = init_fedexnn_merged(args, global_model, out_channels)
            weights = ExpTool.load_pickle(args.resume, exp_dir=False)
            global_model.load_state_dict(weights)
            measure_model = global_model
        else:
            raise RuntimeError
        in_channels = []
        measure_model.eval()
        measure_model.to(device)

        # One forward pass on the first test batch, only to discover the
        # hidden feature shapes of each measured layer.
        for i, (x, target) in enumerate(global_test_loader):
            target = target.to(device)
            x = x.to(device)
            output, hidden_xs = measure_model.forward_measure(x)
            break

        if args.type == "fed-expandable":
            hidden_xs = to_exnn_hidden_xs(hidden_xs)

        hidden_x_channels = dict([(k, h.shape[1]) for k, h in hidden_xs.items()])
        logging.info(f"========== hidden_x_channels: {hidden_x_channels}")
        if args.model in ["resnet18"]:
            linear_probes = LinearProbes()
            for layer_index, h in hidden_xs.items():
                # Probe input size = number of elements of one sample's feature.
                linear_probes.add(layer_index, h[0].numel(), num_of_classes)
            global_train_loader = torch.utils.data.DataLoader(
                train_dataset, batch_size=linear_probe_configurations[args.model]['batch_size'],
                shuffle=False, num_workers=4)
            num_layers = args.split_measure_local_module_num

            optimizer = torch.optim.SGD(
                linear_probes.parameters(),
                lr=linear_probe_configurations[args.model]['initial_learning_rate'],
                momentum=linear_probe_configurations[args.model]['momentum'],
                nesterov=linear_probe_configurations[args.model]['nesterov'],
                weight_decay=linear_probe_configurations[args.model]['weight_decay'])
            linear_probes.to(device)
            # NOTE(review): unlike MI_results in the MI task, this dict is
            # never pickled and ExpTool.finish is not called for this task —
            # confirm whether that is intentional.
            linear_probe_results = {}
            for epoch in range(0, linear_probe_configurations[args.model]['epochs']):
                # adjust_learning_rate(optimizer, epoch + 1)
                if args.debug and epoch == 1:
                    break
                top1s_avg = train_linear_probe(device, global_train_loader, measure_model, linear_probes, optimizer, epoch, num_layers)
                # Evaluate on the test set every 10 epochs and after the last one.
                if epoch % 10 == 0 or epoch == linear_probe_configurations[args.model]['epochs'] - 1:
                    linear_probe_results[epoch] = {}
                    loss_ixy_modules_iters = []
                    layer_test_top1s = [AverageMeter() for _ in range(len(hidden_xs))]
                    for i, (x, target) in enumerate(global_test_loader):
                        target = target.to(device)
                        x = x.to(device)
                        output, hidden_xs = measure_model.forward_measure(x)
                        if args.type == "fed-expandable":
                            hidden_xs = to_exnn_hidden_xs(hidden_xs)
                        h_logits, loss_ixy_modules = linear_probes(x, hidden_xs, target)
                        loss_ixy_modules_iters.append(loss_ixy_modules)
                        for layer_i, logits in enumerate(h_logits):
                            prec1 = accuracy(logits.data, target, topk=(1,))[0]
                            layer_test_top1s[layer_i].update(prec1.item(), x.size(0))

                    record_file = ExpTool.get_file_name("LinearProbeResults.txt", exp_dir=True)
                    loss_ixy_modules_iters = np.array(loss_ixy_modules_iters)
                    loss_ixy_modules_iters = np.mean(loss_ixy_modules_iters, axis=0)

                    string = f"Testing Epoch: [{epoch}], " + \
                            f"loss_ixy avg: {[round(loss_ixy, 3) for loss_ixy in loss_ixy_modules_iters]} " + \
                            f"top1s avg: {[round(top1s.avg, 3) for top1s in layer_test_top1s]} "
                    test_top1s_avg = [round(top1s.avg, 3) for top1s in layer_test_top1s]
                    print(string)
                    with open(record_file, 'a+') as fd:
                        fd.write(string + '\n')
                    linear_probe_results[epoch]["test_top1s_avg"] = test_top1s_avg

    else:
        raise NotImplementedError



































Overwriting main.py


In [27]:
%%writefile locals/fl_expandable.py

from tqdm import tqdm
import numpy as np
from itertools import chain
import logging

import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F

from helpers.datasets import partition_data
from helpers.utils import get_dataset, average_weights, DatasetSplit, BackdoorDS, KLDiv, setup_seed, test, progressive_test
from helpers.exp_path import ExpTool
from locals.cl_loss.info_nce import INFONCE
from models.losses import conv_balance_regularization

from .ccvr import compute_classes_mean_cov, generate_virtual_representation, calibrate_classifier, get_means_covs_from_client
from locals.attack import *
from locals.cutpaste import *
import copy

def cut(x):
    """Return a copy of ``x`` with one randomly chosen cut/paste corruption.

    ``x`` is converted to a NumPy array and indexed on three axes; the
    half-size used for every slice is derived from axis 2.
    # assumes x is a (C, H, W) image with even H == W - TODO confirm

    A draw of ``random.randint(0, 5)`` selects the corruption:
      * 0-1: paste a half-width window of axis 2 over another window,
      * 2:   swap the two halves along axis 2,
      * 3-4: paste a half-width window of axis 1 over another window,
      * 5:   swap the two halves along axis 1.

    Returns a new ``torch.Tensor`` (default float dtype, CPU); ``x`` itself
    is left untouched.
    """
    pristine = x.cpu().numpy()
    canvas = pristine.copy()
    half = canvas.shape[2] // 2

    mode = random.randint(0, 5)
    dst = random.randint(0, half - 1)
    src = random.randint(0, half - 1)
    # Re-draw the source offset until the two windows are at least half/2 apart.
    while abs(dst - src) < half / 2:
        src = random.randint(0, half - 1)

    if mode <= 1:
        canvas[:, :, dst:dst + half] = pristine[:, :, src:src + half]
    elif mode == 2:
        canvas[:, :, half:] = pristine[:, :, :half]
        canvas[:, :, :half] = pristine[:, :, half:]
    elif mode <= 4:
        canvas[:, dst:dst + half, :] = pristine[:, src:src + half, :]
    else:
        canvas[:, half:, :] = pristine[:, :half, :]
        canvas[:, :half, :] = pristine[:, half:, :]

    return torch.Tensor(canvas)

def rot(x):
    """Return a copy of ``x`` with two adjacent patches rotated in place.

    A band of ``half`` rows starting at a random offset is chosen
    (``half`` is derived from axis 2). Within that band, the right part
    (columns ``half:``) and then the left part (columns ``:half``) are each
    rotated by the same random multiple (1-3) of 90 degrees in the
    (axis 1, axis 2) plane.
    # assumes x is a (C, H, W) image with even H == W so both patches are
    # square - TODO confirm

    Returns a new ``torch.Tensor`` (default float dtype, CPU).
    """
    canvas = x.cpu().numpy().copy()
    half = canvas.shape[2] // 2
    top = random.randint(0, half - 1)
    quarter_turns = random.randint(1, 3)

    band = slice(top, top + half)
    # Right patch first, then left patch - same order as the draws imply.
    for cols in (slice(half, None), slice(None, half)):
        canvas[:, band, cols] = np.rot90(canvas[:, band, cols], k=quarter_turns, axes=(1, 2))

    return torch.Tensor(canvas)

def paint(x):
    """Return a copy of ``x`` with a smeared stripe painted over it.

    A 4-pixel-wide band is picked at a random offset, and a random source
    line inside the interior is chosen. With equal probability either:
      * direction 0: the band's pixels on source row ``line`` are copied to
        every interior row (axis 1 index from 4 to size-5), or
      * direction 1: the band's pixels on source column ``line`` are copied
        to every interior column (axis 2 index from 4 to size-5).
    ``size`` is taken from axis 2.
    # assumes x is a square (C, H, W) image with W >= 12 - TODO confirm

    Returns a new ``torch.Tensor`` (default float dtype, CPU).
    """
    canvas = x.cpu().numpy().copy()
    size = int(canvas.shape[2])
    margin = 4

    band_start = random.randint(margin, size - margin * 2)
    line = random.randint(margin, size - margin - 1)
    direction = random.randint(0, 1)

    band = slice(band_start, band_start + margin)
    interior = range(margin, size - margin)
    if direction == 0:
        for row in interior:
            canvas[:, row, band] = canvas[:, line, band]
    elif direction == 1:
        for col in interior:
            canvas[:, band, col] = canvas[:, band, line]

    return torch.Tensor(canvas)

def blur(x):
    """Apply a strong anisotropic Gaussian blur to ``x``.

    Two odd kernel sizes are drawn - a longer one in {3, 5, 7, 9} and a
    shorter one in {1, 3, 5} - and a coin flip decides which of them goes
    first in the ``kernel_size`` pair. Sigma is sampled by
    ``transforms.GaussianBlur`` from the range (10, 100).

    Returns whatever ``transforms.GaussianBlur`` returns for ``x``.
    """
    orientation = random.randint(0, 1)
    long_side = random.randint(1, 4) * 2 + 1
    short_side = random.randint(0, 2) * 2 + 1

    kernel = (long_side, short_side) if orientation == 0 else (short_side, long_side)
    return transforms.GaussianBlur(kernel_size=kernel, sigma=(10, 100))(x)



class FedEXNNLocalUpdate(object):
    """Local (per-client) trainer used by the expandable-network FL algorithm.

    Holds the client's train/test loaders plus a shared global test loader and
    runs plain SGD cross-entropy training in :meth:`update_weights`.
    """

    def __init__(self, args, train_dataset, test_dataset, global_test_loader, train_idxs, test_idxs,
                 backdoor_train=False):
        """Build the client's data loaders.

        Args:
            args: run configuration; this class reads ``local_bs``, ``lr``,
                ``debug``, ``contrastive_n_views`` and (when
                ``backdoor_train``) ``backdoor_size``.
            train_dataset / test_dataset: full datasets, subset through
                ``DatasetSplit`` with ``train_idxs`` / ``test_idxs``.
            global_test_loader: loader evaluated as the "global" test set.
            backdoor_train: wrap the training split in ``BackdoorDS`` when True.
        """
        self.args = args
        self.backdoor_train = backdoor_train
        if backdoor_train:
            self.train_loader = DataLoader(BackdoorDS(DatasetSplit(train_dataset, train_idxs), args.backdoor_size),
                                        batch_size=self.args.local_bs, shuffle=True, num_workers=4, drop_last=False)
        else:
            self.train_loader = DataLoader(DatasetSplit(train_dataset, train_idxs),
                                        batch_size=self.args.local_bs, shuffle=True, num_workers=4, drop_last=False)
        self.test_loader = DataLoader(DatasetSplit(test_dataset, test_idxs),
                                       batch_size=self.args.local_bs, shuffle=False, num_workers=4, drop_last=False)

        self.global_test_loader = global_test_loader
        self.info_nce = INFONCE(args.contrastive_n_views)

    def add_CL_head(self, CL_head):
        """Attach a contrastive-learning head (stored as ``self.CL_head``)."""
        self.CL_head = CL_head

    def load_global_model(self):
        """Placeholder; intentionally does nothing."""
        pass

    def update_weights(self, client_id, epochs, model, device, if_test):
        """Train ``model`` on this client's data for ``epochs`` epochs.

        The active objective is plain cross-entropy on the clean batch, with
        gradient-norm clipping at 1.0. The augmentation / adversarial objects
        built below are only referenced by the disabled (commented-out)
        extended objective and are kept so it can be switched back on.

        Args:
            client_id: identifier used in log messages only.
            epochs: number of local epochs; 0 is allowed and trains nothing.
            model: module called as ``model(images, True)`` and expected to
                return ``(logits, mid_features)``.
            device: device the model and batches are moved to.
            if_test: evaluate on the global and local test loaders after each
                epoch when True.

        Returns:
            ``(state_dict, avg_train_loss, acc, pfl_acc, train_pfl_acc)`` for
            the last epoch. With ``epochs == 0`` (or no batches) the loss is
            NaN and the accuracies are 0.0 instead of raising.
        """
        num_class = 11
        sz = 32
        model.train()
        model.to(device)

        optimizer = torch.optim.SGD(model.parameters(), lr=self.args.lr,
                                    momentum=0.9)
        criterion = torch.nn.CrossEntropyLoss()

        # ---- scaffolding for the disabled extended objective (currently unused) ----
        attack = FastGradientSignUntargeted(model,
                                            epsilon=0.5,
                                            alpha=0.002,
                                            min_val=0,
                                            max_val=1,
                                            max_iters=5,
                                            device=device)
        toImg = transforms.ToPILImage()
        toTensor = transforms.ToTensor()

        op = transforms.RandomChoice( [
            transforms.RandomRotation(degrees=(15,75)),
            transforms.RandomRotation(degrees=(-75,-15)),
            transforms.RandomRotation(degrees=(85,90)),
            transforms.RandomRotation(degrees=(-90,-85)),
            transforms.RandomRotation(degrees=(175,180)),
        ])

        aug = transforms.Compose([
            toImg,
            op,
            toTensor
        ])

        aug_crop =  transforms.RandomChoice( [
            transforms.RandomResizedCrop(sz, scale=(0.1, 0.33)),
            transforms.Lambda(lambda img: blur(img)),
            transforms.RandomErasing(p=1, scale=(0.33, 0.5)),
            transforms.Lambda(lambda img: cut(img)),
            transforms.Lambda(lambda img: rot(img)),
            transforms.Lambda(lambda img: cut(img)),
            transforms.Lambda(lambda img: rot(img)),
        ])

        cp = CutPasteUnion()

        aug_final = transforms.RandomChoice( [
            transforms.Lambda(lambda img: aug_crop(img)),
        ])

        def anneal(start, end, epoch, total):
            # Linear interpolation from `start` to `end` over `total` epochs.
            return start - (start - end) * (min(epoch, total) / total)
        # ---- end of scaffolding ----

        # Defaults so that epochs == 0 returns cleanly instead of hitting
        # unbound locals at the return statement.
        avg_train_loss = float("nan")
        acc = 0.0
        pfl_acc = 0.0
        train_pfl_acc = 0.0

        for epoch in tqdm(range(epochs)):
            train_losses = []
            correct = 0
            last_loss = float("nan")

            # Loss weights of the disabled extended objective; they depend on
            # the epoch only, so they are computed once per epoch.
            alpha = anneal(1.0, 0.1, epoch, epochs)
            # delta = anneal(1.0, 0.1, epoch, epochs)
            delta = 0.0
            beta = 1
            gamma = 0.01

            for batch_idx, (images, targets) in enumerate(self.train_loader):
                if self.args.debug and batch_idx > 3:
                    break

                images, targets = images.to(device), targets.to(device)
                optimizer.zero_grad()

                # Disabled extended objective, kept for reference (B = images.shape[0]):
                # y_gen = torch.full((B,), num_class-1, dtype=torch.long, device=device)
                # x_gen11 = [aug_final(img.cpu()) for img in images]
                # x_gen11 = torch.stack(x_gen11).to(device)
                # adv_data = attack.perturb(x_gen11, y_gen)
                # combined_batch = torch.cat([images, x_gen11, adv_data], dim=0)
                # combined_batch = torch.cat([images, x_gen11], dim=0)
                # combined_out, combined_mid = model(combined_batch, True)
                # out, mid = combined_out[:B], combined_mid[:B]
                # out_gen11 = combined_out[B:]
                # out_gen11 = combined_out[B:2*B]
                # out_adv = combined_out[2*B:]
                out, mid = model(images, True)

                # one_hot = torch.zeros(B, num_class, device=device)
                # one_hot.scatter_(1, targets.reshape(-1, 1), 1)
                # out_second = out - one_hot * 10000

                # ind = torch.randperm(targets.size(0), device=device)
                # y_mask = torch.where(targets == targets[ind], targets, torch.tensor(10, device=device))

                # phi=torch.distributions.beta.Beta(1, 1).sample([]).item()
                # mixed_embeddings = phi * mid + (1-phi) * mid[ind]
                # mixed_out = model.classifier(mixed_embeddings.to(device))

                loss = criterion(out, targets) #+ alpha*criterion(out_gen11, y_gen) + beta*criterion(out_second, y_gen) + gamma*criterion(mixed_out, y_mask)# + delta*criterion(out_adv, y_gen)

                pred = torch.max(out, 1)[1]
                correct += pred.eq(targets.view_as(pred)).sum().item()

                # if self.args.fedexnn_adapter_constrain_beta > 0:
                #     adapter = model.get_last_training_adapter()
                #     if adapter is not None:
                #         balance_loss = conv_balance_regularization(adapter.weight, self.args.fedexnn_adapter_constrain_beta)
                #         loss += balance_loss

                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                last_loss = loss.item()
                train_losses.append(last_loss)

            # Accuracy is measured against the full split size, as before;
            # max(..., 1) only protects against an empty split.
            train_pfl_acc = 100. * correct / max(len(self.train_loader.dataset), 1)
            avg_train_loss = np.mean(train_losses) if train_losses else float("nan")
            logging.info(f"client_id:{client_id} epoch:[{epoch}/{epochs}] loss:{last_loss}, train_loss:{avg_train_loss} train_pfl_acc: {train_pfl_acc}")
            if if_test:
                acc, test_loss = test(model, self.global_test_loader, device)
                pfl_acc, pfl_test_loss = test(model, self.test_loader, device)
            else:
                acc = 0.0
                pfl_acc = 0.0
        model.to("cpu")
        return model.state_dict(), avg_train_loss, acc, pfl_acc, train_pfl_acc


Overwriting locals/fl_expandable.py


In [28]:
%%writefile locals/attack.py

import os
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F

#from utils import tensor2cuda

def project(x, original_x, epsilon, _type='linf'):
    """Project ``x`` back into the epsilon-ball around ``original_x``.

    Args:
        x: perturbed batch; the first axis is the batch axis.
        original_x: unperturbed batch with the same shape as ``x``.
        epsilon: ball radius.
        _type: ``'linf'`` clamps every element to
            ``[original_x - epsilon, original_x + epsilon]``; ``'l2'`` rescales
            each sample's perturbation to norm ``epsilon`` when it is longer
            than that and leaves the sample untouched otherwise.

    Returns:
        The projected tensor (a new tensor; inputs are not modified).

    Raises:
        NotImplementedError: for any other ``_type``.
    """
    if _type == 'linf':
        max_x = original_x + epsilon
        min_x = original_x - epsilon

        x = torch.max(torch.min(x, max_x), min_x)

    elif _type == 'l2':
        dist = (x - original_x).reshape(x.shape[0], -1)
        dist_norm = torch.norm(dist, dim=1, keepdim=True)

        # One flag per sample, shaped to broadcast over every non-batch axis
        # (works for any input rank, not only 4-D images).
        outside = (dist_norm > epsilon).view(-1, *([1] * (x.dim() - 1)))

        # Guard the division: a sample with zero perturbation has norm 0 and
        # would otherwise turn into NaN. Such samples are never "outside",
        # so the guarded value is discarded by torch.where below.
        safe_norm = dist_norm.clamp_min(torch.finfo(dist_norm.dtype).tiny)
        rescaled = (dist / safe_norm * epsilon).reshape(x.shape)

        # Select instead of blending with a 0/1 mask so that values from the
        # unused branch can never leak into the result.
        x = torch.where(outside, original_x + rescaled, x)

    else:
        raise NotImplementedError

    return x

class FastGradientSignUntargeted():
    """
        Fast gradient sign untargeted adversarial attack, minimizes the initial class activation
        with iterative grad sign updates
    """
    def __init__(self, model, epsilon, alpha, min_val, max_val, max_iters, device='cpu', _type='linf'):
        """Store the attack configuration.

        Args:
            model: callable returning class logits for a batch.
            epsilon: radius of the allowed perturbation ball.
            alpha: step size of each signed-gradient update.
            min_val / max_val: stored only; the pixel-range clamp in
                :meth:`perturb` is disabled.
            max_iters: number of signed-gradient steps.
            device: device the attack runs on.
            _type: norm passed to ``project`` (``'linf'`` or ``'l2'``).
        """
        self.model = model
        self.epsilon = epsilon
        self.alpha = alpha
        self.min_val = min_val
        self.max_val = max_val
        self.max_iters = max_iters
        self._type = _type
        self.device = device

    def perturb(self, original_images, labels, reduction4loss='mean', random_start=False):
        """Return adversarial versions of ``original_images``.

        Runs ``max_iters`` steps of ``x <- project(x + alpha * sign(grad))``
        where ``grad`` is the cross-entropy gradient w.r.t. the input.

        The input batch is never modified: the attack works on a detached
        copy, so the epsilon-ball is always measured against the true
        originals. The result is detached from autograd.

        Args:
            original_images: clean batch.
            labels: class indices for the batch.
            reduction4loss: accepted for interface compatibility; unused.
            random_start: accepted for interface compatibility; unused.

        Returns:
            Tensor with the same shape as ``original_images`` on ``self.device``.
        """
        # Detached reference kept fixed for the projection step.
        original = original_images.detach().to(self.device)
        labels = labels.to(self.device)
        # Work on a private copy; `.to()` alone can alias the caller's tensor.
        x = original.clone()

        with torch.enable_grad():
            for _iter in range(self.max_iters):
                x.requires_grad_(True)
                outputs = self.model(x)
                loss = F.cross_entropy(outputs, labels)

                grads = torch.autograd.grad(loss, x)[0]

                # Ascend the loss, then drop the graph so it does not grow
                # from one iteration to the next.
                x = x.detach() + self.alpha * torch.sign(grads.detach())

                # the adversaries' pixel value should within max_x and min_x due
                # to the l_infinity / l2 restriction
                x = project(x, original, self.epsilon, self._type)
                # the adversaries' value should be valid pixel value
                # x.clamp_(self.min_val, self.max_val)

        return x

Overwriting locals/attack.py


In [29]:
%%writefile locals/ccvr.py
import logging

import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader


def compute_classes_mean_cov(global_model, models, dataloaders, num_classes, device):
    """Aggregate per-client class statistics into global class means/covariances.

    For every client, ``get_means_covs_from_client`` supplies per-class
    feature means, covariances and sample counts. They are combined as

        mu_c    = sum_k (N_ck / N_c) * mu_ck
        Sigma_c = sum_k (N_ck - 1)/(N_c - 1) * Sigma_ck
                + sum_k  N_ck     /(N_c - 1) * mu_ck mu_ck^T
                -        N_c      /(N_c - 1) * mu_c  mu_c^T

    Args:
        global_model: model used for every client; when ``None`` the
            per-client model ``models[client_idx]`` is used instead.
        models: mapping client key -> model (only read when
            ``global_model`` is ``None``).
        dataloaders: non-empty mapping client key -> dataloader. Keys may be
            arbitrary; clients are processed in iteration order.
        num_classes: number of classes passed to the per-client extraction.
        device: device used for the extraction and the weights.

    Returns:
        ``(classes_mean, classes_cov)``. ``classes_mean`` has one entry per
        extracted class; ``classes_cov`` omits the last class (the extra
        "unknown" class), so it is one entry shorter.

    Notes:
        * Classes with 0 or 1 samples overall use a denominator of 1, which
          yields a zero mean/covariance instead of NaN or ZeroDivisionError.
        * NOTE(review): the (N_ck - 1) weighting matches unbiased per-client
          covariances, while ``get_means_covs_from_client`` in this file
          computes them with ``correction=0`` - confirm which is intended.
    """
    features_means, features_covs, features_count = [], [], []
    for client_idx, dataloader in dataloaders.items():
        model = models[client_idx] if global_model is None else global_model
        means, covs, sizes = get_means_covs_from_client(
            model, dataloader, device, num_classes
        )
        features_means.append(means)
        features_covs.append(covs)
        features_count.append(sizes)

    num_clients = len(features_count)
    num_classes = len(features_count[0])-1#to account for unknown class only change
    labels_count = [sum(cnts) for cnts in zip(*features_count)]

    classes_mean = []
    for c, (means, sizes) in enumerate(
        zip(zip(*features_means), zip(*features_count))
    ):
        weights = torch.tensor(sizes, dtype=torch.float, device=device) / max(labels_count[c], 1)
        means_ = torch.stack(means, dim=-1)
        classes_mean.append(torch.sum(means_ * weights, dim=-1))

    classes_cov = []
    for c in range(num_classes):
        total = labels_count[c]
        denom = max(total - 1, 1)
        cov = torch.zeros_like(features_covs[0][c])
        # Per-client lists are positional, so index them by position rather
        # than by dataloader key.
        for k in range(num_clients):
            n_k = features_count[k][c]
            mu_k = features_means[k][c]
            cov += ((n_k - 1) / denom) * features_covs[k][c] + (n_k / denom) * (
                mu_k.unsqueeze(1) @ mu_k.unsqueeze(0)
            )

        # N_c / (N_c - 1): the parentheses matter - without them the factor
        # evaluates to 1 - 1 = 0 and the global-mean term is never removed.
        cov -= (total / denom) * (
            classes_mean[c].unsqueeze(1) @ classes_mean[c].unsqueeze(0)
        )
        classes_cov.append(cov)

    return classes_mean, classes_cov

def generate_virtual_representation(
    classes_mean, classes_cov, sample_per_class, device
):
    """Sample synthetic features from one Gaussian per class.

    Classes are taken pairwise from ``classes_mean`` and ``classes_cov``
    (iteration stops at the shorter of the two). For class index ``c``,
    ``sample_per_class`` vectors are drawn with
    ``np.random.multivariate_normal`` and labelled ``c``.

    Returns:
        ``(data, targets)``: float features and long labels on ``device``,
        concatenated class after class.
    """
    feature_chunks = []
    label_chunks = []
    for label, (mean, cov) in enumerate(zip(classes_mean, classes_cov)):
        drawn = np.random.multivariate_normal(
            mean.cpu().numpy(), cov.cpu().numpy(), sample_per_class
        )
        feature_chunks.append(torch.tensor(drawn, dtype=torch.float, device=device))
        label_chunks.append(
            torch.full((sample_per_class,), label, dtype=torch.long, device=device)
        )

    return torch.cat(feature_chunks), torch.cat(label_chunks)

def calibrate_classifier(global_model, models, dataloaders, num_classes, sample_per_class, local_lr, device):
    """Fine-tune ``global_model.classifier`` on sampled virtual features.

    Global per-class Gaussians are estimated with
    ``compute_classes_mean_cov``, ``sample_per_class`` features per class are
    drawn from them, and one shuffled pass of SGD (batch size 128, learning
    rate ``local_lr``, cross-entropy loss) is run through
    ``global_model.classifier``. The optimizer is built over all of
    ``global_model.parameters()``.

    The model is updated in place; nothing is returned.
    """

    class RepresentationDataset(Dataset):
        """Pairs each virtual feature vector with its class label."""

        def __init__(self, data, targets):
            self.data = data
            self.targets = targets

        def __getitem__(self, idx):
            return self.data[idx], self.targets[idx]

        def __len__(self):
            return len(self.targets)

    classes_mean, classes_cov = compute_classes_mean_cov(global_model, models, dataloaders, num_classes, device)
    virtual_x, virtual_y = generate_virtual_representation(classes_mean, classes_cov, sample_per_class, device)

    loader = DataLoader(RepresentationDataset(virtual_x, virtual_y), batch_size=128, shuffle=True)
    optimizer = torch.optim.SGD(global_model.parameters(), lr=local_lr)

    for batch_x, batch_y in loader:
        loss = torch.nn.functional.cross_entropy(global_model.classifier(batch_x), batch_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()




def get_means_covs_from_client(
    model, dataloader, device, num_classes
):
    """Compute per-class feature means and covariances for one client.

    Every batch of ``dataloader`` is passed through
    ``model.get_final_features`` and the resulting features are grouped by
    label ``0 .. num_classes - 1``.

    Feature extraction runs under ``torch.no_grad()``: these are statistics
    only, so no autograd graph is kept for the dataset and the returned
    tensors never require grad.

    Args:
        model: object exposing ``get_final_features(x)``.
        dataloader: iterable of ``(x, y)`` batches; must yield at least one.
        device: device batches are moved to.
        num_classes: number of label values to collect statistics for.

    Returns:
        ``(classes_means, classes_covs, counts)``, each a list of length
        ``num_classes``. Covariances use ``correction=0``. A class without
        samples gets an all-zero mean and covariance and a count of 0.
    """
    features = []
    targets = []
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            features.append(model.get_final_features(x))
            targets.append(y)

    targets = torch.cat(targets)
    features = torch.cat(features)
    feature_length = features.shape[-1]

    indices = [
        torch.where(targets == i)[0]
        for i in range(num_classes)
    ]
    classes_features = [features[idxs] for idxs in indices]
    classes_means, classes_covs = [], []
    for fea in classes_features:
        if fea.shape[0] > 0:
            classes_means.append(fea.mean(dim=0))
            # torch.cov expects variables in rows, hence the transpose.
            classes_covs.append(torch.cov(fea.t(), correction=0, fweights=None, aweights=None))
        else:
            classes_means.append(torch.zeros(feature_length, device=device))
            classes_covs.append(
                torch.zeros(feature_length, feature_length, device=device)
            )
    return classes_means, classes_covs, [len(idxs) for idxs in indices]


Overwriting locals/ccvr.py


In [None]:
%%writefile models/fl_exnn.py
from copy import deepcopy
import logging
logger = logging.getLogger()

import torch
import torch.nn.functional as F
from torch import nn

from .resnet import norm2d, BasicBlock, Bottleneck


class MLP_Block(nn.Module):
    """Stack of ``Linear`` + ReLU layers of constant hidden width.

    The first layer maps ``in_features -> hidden_features``; the remaining
    ``num_layers - 1`` layers map ``hidden_features -> hidden_features``.
    ``out_features`` and ``actv`` are accepted and stored/ignored
    respectively - no output projection is built and ReLU is always used.
    """

    def __init__(self, in_features, hidden_features, out_features, actv="relu", num_layers=2):
        super().__init__()
        self.in_features = in_features
        self.hidden_features = hidden_features
        self.out_features = out_features
        self.num_layers = num_layers

        # Input width of every layer, in order; at least one layer is built.
        fan_ins = [in_features] + [hidden_features] * (self.num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(fan_in, hidden_features) for fan_in in fan_ins
        )

    def forward(self, x):
        """Apply every layer followed by ReLU and return the final activation."""
        hidden = x
        for linear in self.layers:
            hidden = F.relu(linear(hidden))
        return hidden


class CNN_Block(nn.Module):
    """3x3 convolution -> BatchNorm -> ReLU, optionally followed by 2x2 max-pooling.

    The convolution uses no padding, so each spatial side shrinks by 2
    before the optional pooling halves it. ``actv`` is accepted but ignored.
    """

    def __init__(self, in_features, hidden_features, is_pool, actv="relu"):
        super().__init__()
        self.in_features = in_features
        self.hidden_features = hidden_features
        self.is_pool = is_pool
        # Kept empty; the actual computation lives in `self.layer`.
        self.layers = nn.ModuleList()

        stages = [
            nn.Conv2d(in_features, hidden_features, 3),
            nn.BatchNorm2d(hidden_features),
            nn.ReLU(),
        ]
        if self.is_pool:
            stages.append(nn.MaxPool2d(2, 2))
        self.layer = nn.Sequential(*stages)

    def forward(self, x):
        """Run the block on ``x``."""
        return self.layer(x)



def define_fl_exnn_res_layers(
    fedexnn_split_num, block, num_blocks, group_norm=0, res_base_width=64, in_channels=3,
):
    """Build a ResNet body grouped into ``fedexnn_split_num`` sequential chunks.

    NOTE: a second function with the same name is defined further down in
    this module and replaces this one at import time; this variant is only
    reachable if that later definition is renamed or removed.

    The network consists of a stem (3x3 conv, norm, ReLU), four residual
    stages of widths ``res_base_width * (1, 2, 4, 8)`` and a global average
    pool. The chunks are:

        2   : [stem + stage0 + stage1] [stage2 + stage3 + pool]
        3   : [stem] [stage0 + stage1] [stage2 + stage3 + pool]
        4   : [stem] [stage0 + stage1] [stage2] [stage3 + pool]
        >=5 : [stem] [stage0] [stage1] [stage2] [stage3 + pool]

    Args:
        fedexnn_split_num: requested number of chunks (at most five chunks
            are produced; every value >= 5 yields the five-chunk split).
        block: residual block class, constructed with keyword arguments
            ``in_planes``, ``planes``, ``stride``, ``group_norm`` and exposing
            an ``expansion`` attribute.
        num_blocks: number of blocks in each of the four stages.
        group_norm: forwarded to ``norm2d`` and to every block.
        res_base_width: channel width of the stem and the first stage.
        in_channels: number of input image channels.

    Returns:
        List of ``nn.Sequential`` chunks in forward order.

    Raises:
        NotImplementedError: if ``fedexnn_split_num`` is below 2.
    """
    if fedexnn_split_num < 2:
        raise NotImplementedError

    in_planes = res_base_width

    def _make_stage(planes, depth, stride):
        # Only the first block of a stage uses `stride`; the rest use 1.
        nonlocal in_planes
        blocks = []
        for block_stride in [stride] + [1] * (depth - 1):
            blocks.append(block(in_planes=in_planes, planes=planes,
                stride=block_stride, group_norm=group_norm,
            ))
            in_planes = planes * block.expansion
        return nn.Sequential(*blocks)

    stem = nn.Sequential(
        nn.Conv2d(in_channels, res_base_width, kernel_size=3, stride=1, padding=1, bias=False),
        # norm2d (as in the sibling definition) instead of
        # BatchNorm2d(width, group_norm), whose second positional argument is eps.
        norm2d(res_base_width, group_norm),
        nn.ReLU(),
    )
    stage0 = _make_stage(res_base_width, num_blocks[0], stride=1)
    stage1 = _make_stage(res_base_width*2, num_blocks[1], stride=2)
    stage2 = _make_stage(res_base_width*4, num_blocks[2], stride=2)
    stage3 = _make_stage(res_base_width*8, num_blocks[3], stride=2)
    pool = nn.AdaptiveAvgPool2d((1, 1))

    if fedexnn_split_num == 2:
        return [
            nn.Sequential(stem, stage0, stage1),
            nn.Sequential(stage2, stage3, pool),
        ]
    if fedexnn_split_num == 3:
        return [
            stem,
            nn.Sequential(stage0, stage1),
            nn.Sequential(stage2, stage3, pool),
        ]
    if fedexnn_split_num == 4:
        return [
            stem,
            nn.Sequential(stage0, stage1),
            stage2,
            nn.Sequential(stage3, pool),
        ]
    # 5 or more: finest available split (previously 6 was rejected although
    # 5 and >6 produced exactly this layout).
    return [stem, stage0, stage1, stage2, nn.Sequential(stage3, pool)]



def define_fl_exnn_res_layers(
    block, num_blocks, group_norm=0, res_base_width=64, in_channels=3,
    hetero_layer_depth=False,
):
    """Build the flat layer list of a ResNet body, optionally in three depths.

    The list is ``[stem, block, block, ..., pool]``: a stem
    (3x3 conv, ``norm2d``, ReLU), the residual blocks of four stages with
    widths ``res_base_width * (1, 2, 4, 8)`` and strides ``(1, 2, 2, 2)``,
    and a global average pool.

    Args:
        block: residual block class, constructed positionally as
            ``block(in_planes, planes, stride)`` and exposing ``expansion``.
        num_blocks: number of blocks in each of the four stages.
        group_norm: forwarded to ``norm2d`` for the stem.
            # NOTE(review): not forwarded to `block` - confirm whether the
            # blocks should receive it as well.
        res_base_width: channel width of the stem and the first stage.
        in_channels: number of input image channels.
        hetero_layer_depth: additionally build a shallower variant (one block
            fewer per stage) and a deeper variant (one block more per stage).

    Returns:
        ``(layers, small_layers, large_layers)``. ``layers`` is always the
        normal-depth list; the other two are ``None`` unless
        ``hetero_layer_depth`` is set. Every variant is built from scratch
        with its own channel bookkeeping, so each starts at
        ``res_base_width`` input channels.
    """
    stage_specs = ((1, 1), (2, 2), (4, 2), (8, 2))  # (width multiplier, stride)

    def _build(depth_delta):
        # Channel tracking is local to one variant so that variants never
        # inherit the final width of the previously built one.
        in_planes = res_base_width
        built = [
            nn.Sequential(nn.Conv2d(in_channels, res_base_width, kernel_size=3, stride=1, padding=1, bias=False),
                        norm2d(res_base_width, group_norm),
                        nn.ReLU())
        ]
        for stage_idx, (width_mult, stride) in enumerate(stage_specs):
            planes = res_base_width * width_mult
            depth = num_blocks[stage_idx] + depth_delta
            # A stage always keeps its first (strided) block, even when
            # `depth` drops to 0, so the channel progression stays intact.
            for block_stride in [stride] + [1] * (depth - 1):
                built.append(block(in_planes, planes, block_stride))
                in_planes = planes * block.expansion
        built.append(nn.AdaptiveAvgPool2d((1, 1)))
        return built

    layers = _build(0)
    if hetero_layer_depth:
        small_layers = _build(-1)
        large_layers = _build(+1)
    else:
        small_layers, large_layers = None, None
    return layers, small_layers, large_layers



def fl_exnn_resnet18(**kwargs):
    """Build the ResNet-18 layer lists: ``BasicBlock`` with stage depths [2, 2, 2, 2].

    All keyword arguments are forwarded unchanged to
    ``define_fl_exnn_res_layers``, whose return value is returned as-is.
    """
    return define_fl_exnn_res_layers( 
                BasicBlock, [2, 2, 2, 2], **kwargs)


def fl_exnn_resnet34(**kwargs):
    """Build the ResNet-34 layer lists: ``BasicBlock`` with stage depths [3, 4, 6, 3].

    All keyword arguments are forwarded unchanged to
    ``define_fl_exnn_res_layers``, whose return value is returned as-is.
    """
    return define_fl_exnn_res_layers(
                BasicBlock, [3, 4, 6, 3], **kwargs)


def fl_exnn_resnet50(**kwargs):
    """Build the ResNet-50 layer lists: ``Bottleneck`` with stage depths [3, 4, 6, 3].

    All keyword arguments are forwarded unchanged to
    ``define_fl_exnn_res_layers``, whose return value is returned as-is.
    """
    return define_fl_exnn_res_layers(
                Bottleneck, [3, 4, 6, 3], **kwargs)


def fl_exnn_resnet101(**kwargs):
    """Build the ResNet-101 layer lists: ``Bottleneck`` with stage depths [3, 4, 23, 3].

    All keyword arguments are forwarded unchanged to
    ``define_fl_exnn_res_layers``, whose return value is returned as-is.
    """
    return define_fl_exnn_res_layers(
                Bottleneck, [3, 4, 23, 3], **kwargs)


def fl_exnn_resnet152(**kwargs):
    """Build the ResNet-152 layer lists: ``Bottleneck`` with stage depths [3, 8, 36, 3].

    All keyword arguments are forwarded unchanged to
    ``define_fl_exnn_res_layers``, whose return value is returned as-is.
    """
    return define_fl_exnn_res_layers(
                Bottleneck, [3, 8, 36, 3], **kwargs)


class AttentionAdapter(nn.Module):
    """1x1-conv channel adapter followed by squeeze-and-excitation style gating.

    The input is first aligned to ``out_channels`` with a bias-free 1x1
    convolution and batch norm. A per-channel gate in (0, 1) is then computed
    from the globally average-pooled features through a two-layer bottleneck
    MLP and multiplied onto the aligned feature map.

    Args:
        in_channels: channels of the incoming (concatenated) feature map.
        out_channels: channels expected by the following layer.
        reduction: bottleneck ratio of the gating MLP. The bottleneck width is
            ``out_channels // reduction`` but never less than 1, so small
            ``out_channels`` still produce a working gate.
    """

    def __init__(self, in_channels, out_channels, reduction=16):
        super(AttentionAdapter, self).__init__()
        self.align = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Without the floor, out_channels < reduction gives a zero-width
        # bottleneck and the gate collapses to a constant 0.5.
        hidden_channels = max(1, out_channels // reduction)
        self.channel_fc = nn.Sequential(
            nn.Linear(out_channels, hidden_channels, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_channels, out_channels, bias=False),
            nn.Sigmoid()
        )  # a more complex adapter than a plain 1x1 convolution

    def forward(self, x):
        """Align ``x`` to ``out_channels`` and rescale each channel by its gate."""
        x = self.bn(self.align(x))
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.channel_fc(y).view(b, c, 1, 1)

        return x * y.expand_as(x)


class Federated_EXNNLayer_local(nn.Module):
    """One client's layer at a given depth, with an optional fusion adapter.

    In local mode the layer simply runs ``local_layer`` on a tensor. In
    global mode it receives a dict of per-client features from the previous
    depth, fuses them according to ``adapter`` and runs ``local_layer`` on
    the fused tensor.

    Args:
        layer_idx: depth index of this layer.
        local_layer: the wrapped module.
        client_idx: index of the owning client; ``str(client_idx)`` is the
            key of this client's own features in the global-mode input dict.
        adapter: ``"avg"``, ``"sum"`` or ``"cnn1x1"``.
        fedexnn_self_dropout: dropout probability applied to this client's
            own features (global mode) or to the input (local mode);
            0 disables dropout.
    """

    def __init__(self, layer_idx,
                local_layer,
                client_idx=0,
                adapter="avg", 
                fedexnn_self_dropout=0.0
            ):
        super().__init__()
        self.layer_idx = layer_idx
        self.is_global = False
        self.local_layer = local_layer
        self.client_idx = client_idx
        # self.in_features = local_layer.in_features
        self.adapter = adapter
        self.fedexnn_self_dropout = fedexnn_self_dropout
        if self.fedexnn_self_dropout > 0:
            self.dropout = nn.Dropout(p=self.fedexnn_self_dropout)


    def adaptation(self, in_channels=1, out_channels=1):
        """Create the learnable adapter, if the adapter type needs one.

        ``"avg"`` and ``"sum"`` need no parameters. ``"cnn1x1"`` creates an
        ``AttentionAdapter(in_channels, out_channels)`` stored as
        ``self.adapter_nn``.

        Raises:
            NotImplementedError: for ``"mlp"`` or any unknown adapter type.
        """
        if self.adapter in ["avg", "sum"]:
            pass
        elif self.adapter == "cnn1x1":
            # self.adapter_nn = nn.Conv2d(
            #     in_channels=in_channels,
            #     out_channels=out_channels, kernel_size=1, stride=1, padding=0)
            self.adapter_nn = AttentionAdapter(in_channels, out_channels)
        elif self.adapter == "mlp":
            raise NotImplementedError
        else:
            raise NotImplementedError

    def get_module(self):
        """Return the wrapped local module."""
        return self.local_layer

    def forward(self, x, is_global):
        """Run the layer.

        Args:
            x: a tensor in local mode; in global mode a dict mapping client
                key -> feature tensor. The dict passed in is never modified.
            is_global: selects global (fusing) or local mode.

        Returns:
            Output of ``local_layer``.

        Raises:
            NotImplementedError: for ``"mlp"`` or any unknown adapter type in
                global mode.
        """
        if is_global:
            if self.fedexnn_self_dropout > 0:
                # Work on a shallow copy: the caller hands the same dict to
                # every client's layer, so writing the dropped-out features
                # back would leak this client's dropout into the others.
                own_key = str(self.client_idx)
                x = dict(x)
                x[own_key] = self.dropout(x[own_key])
            if self.adapter == "avg":
                xs = list(x.values())
                x = sum(xs) / len(xs)
            elif self.adapter == "sum":
                x = sum(list(x.values()))
            elif self.adapter == "cnn1x1":
                # x = self.adapter_layers[str(i)](  torch.concat(list(x.values()), dim=1) )
                x = self.adapter_nn(torch.concat(list(x.values()), dim=1))

            elif self.adapter == "mlp":
                raise NotImplementedError
            else:
                raise NotImplementedError
        else:
            if self.fedexnn_self_dropout > 0:
                x = self.dropout(x)
        return self.local_layer(x)


class Federated_EXNNLayer_global(nn.Module):
    """One fused ("global") layer: the same-depth layers of all clients side by side.

    The per-client layers are stored in a ``ModuleDict`` and each of them is
    applied to the same input; the result is a dict of per-client outputs.

    Args:
        layer_idx: index of this layer; stored as ``self.layer_idx``.
        local_layers: mapping ``client_idx -> layer``. Each layer is called as
            ``layer(x, is_global)``. Keys are converted with ``str()`` because
            ``nn.ModuleDict`` only accepts string keys.
        fedexnn_self_dropout: stored on the instance; not used by this class.
    """

    def __init__(self, layer_idx,
                local_layers, fedexnn_self_dropout=0.0
            ):
        super().__init__()
        self.layer_idx = layer_idx
        self.is_global = True
        self.fedexnn_self_dropout = fedexnn_self_dropout

        self.local_layers = torch.nn.ModuleDict()
        for client_idx, local_layer in local_layers.items():
            # str() makes integer client ids usable; string ids are unchanged.
            self.local_layers[str(client_idx)] = local_layer

    def get_module(self):
        """Return the ``ModuleDict`` holding the per-client layers."""
        return self.local_layers

    def freeze(self):
        """Disable gradients for every parameter of every per-client layer."""
        for local_layer in self.local_layers.values():
            for param in local_layer.parameters():
                param.requires_grad = False

    def forward(self, x, is_global):
        """Apply every per-client layer to ``x``.

        Args:
            x: input passed unchanged to each per-client layer.
            is_global: flag forwarded to each per-client layer.

        Returns:
            dict mapping each (string) client key to that layer's output.
        """
        return {
            client_idx: local_layer(x, is_global)
            for client_idx, local_layer in self.local_layers.items()
        }


# def merge_layer(Federated_EXNNs, layer_idx):
#     horizon_layers = {}
#     # for client_idx, Federated_EXNN in enumerate(Federated_EXNNs):
#     for client_idx, Federated_EXNN in Federated_EXNNs.items():
#         logger.info(f"Merging layer {layer_idx}, model has num of layers:{len(Federated_EXNN.layers)}")
#         horizon_layers[client_idx] = Federated_EXNN.layers[layer_idx].get_module()
#     federated_EXNNLayer_global = Federated_EXNNLayer_global(layer_idx, horizon_layers)
#     return federated_EXNNLayer_global



def merge_layer(Federated_EXNNs, layer_idx):
    """Collect layer ``layer_idx`` of every client model into one global layer.

    Args:
        Federated_EXNNs: mapping ``client_idx -> model``; each model exposes an
            indexable ``layers`` attribute.
        layer_idx: position of the layer to take from each model.

    Returns:
        A ``Federated_EXNNLayer_global`` wrapping the collected layers, keyed
        by ``str(client_idx)``.
    """
    per_client_layers = {}
    for client_idx, client_model in Federated_EXNNs.items():
        logger.info(f"Merging layer {layer_idx}, model has num of layers:{len(client_model.layers)}")
        key = str(client_idx)
        per_client_layers[key] = client_model.layers[layer_idx]
    return Federated_EXNNLayer_global(layer_idx, per_client_layers)



class Federated_EXNN(nn.Module):
    """Stack of local/global layers followed by a classifier.

    Every entry of ``self.layers`` is called as ``layer(x, prev_is_global)``
    and must expose an ``is_global`` attribute. A global layer returns a dict
    of per-client tensors; a local layer returns a single tensor.

    Args:
        args: namespace providing ``model`` and, depending on it, one of
            ``cnn_hidden_features``, ``mlp_hidden_features`` or
            ``res_base_width``.
        client_idx: identifier of the owning client; stored as-is.
        split_local_layers: the layers of this model, in order.
        num_of_classes: output size of the default linear classifier.
        fedexnn_classifer: how per-client features are turned into outputs
            after a global last layer: ``"avg"`` or ``"multihead"``.

    Raises:
        NotImplementedError: if ``args.model`` is not one of ``"cnn"``,
            ``"mlp3"``, ``"resnet18"`` or ``"resnet50"``.
    """

    def __init__(
        self,
        args,
        client_idx,
        split_local_layers=[],
        num_of_classes=10,
        fedexnn_classifer="avg",
    ):
        super().__init__()
        self.args = args
        self.client_idx = client_idx
        self.num_layers = len(split_local_layers)
        if args.model == "cnn":
            self.hidden_features = args.cnn_hidden_features * 4 * 4
            self.flatten_at_classifier = True
        elif args.model == "mlp3":
            self.hidden_features = args.mlp_hidden_features
            self.flatten_at_classifier = False
        elif args.model in ["resnet18",]:
            self.hidden_features = args.res_base_width * 8 * 1
            self.flatten_at_classifier = True
        elif args.model in ["resnet50", ]:
            self.hidden_features = args.res_base_width * 8 * 4
            self.flatten_at_classifier = True
        else:
            raise NotImplementedError
        self.num_of_classes = num_of_classes

        # The caller's list is copied into a ModuleList, never mutated.
        self.layers = nn.ModuleList(split_local_layers)

        self.classifier = torch.nn.Linear(self.hidden_features, num_of_classes)
        self.fedexnn_classifer = fedexnn_classifer

    def adaptation(self, layer_idx, federated_EXNNLayer_global):
        """Freeze ``federated_EXNNLayer_global`` and put it at ``layer_idx``.

        The layer previously stored at that position is removed.
        """
        federated_EXNNLayer_global.freeze()
        del self.layers[layer_idx]
        self.layers.insert(layer_idx, federated_EXNNLayer_global)

    def add_local_layer_adaptor(self, layer_idx, **kwargs):
        """Call ``adaptation(**kwargs)`` on the (non-global) layer at ``layer_idx``."""
        assert not self.layers[layer_idx].is_global
        self.layers[layer_idx].adaptation(**kwargs)

    def get_last_training_adapter(self):
        """Return ``adapter_nn`` of the first non-global layer that has one, else ``None``."""
        for layer in self.layers:
            if not layer.is_global and hasattr(layer, "adapter_nn"):
                return layer.adapter_nn
        return None

    def adaptation_classifier(self, fedexnn_classifer, new_classifier=None):
        """Replace the classifier mode and the classifier module."""
        self.fedexnn_classifer = fedexnn_classifer
        self.classifier = new_classifier

    # ------------------------------------------------------------------ #
    # private helpers shared by forward / forward_measure / get_final_features
    # ------------------------------------------------------------------ #

    def _run_layers(self, x, hidden_xs=None):
        """Run ``x`` through all layers; optionally record detached activations.

        When ``hidden_xs`` is a dict, the output of layer ``i`` is stored under
        key ``i``: the mean over clients for a dict output, the tensor itself
        otherwise.
        """
        prev_is_global = False  # the raw input is never a per-client dict
        for i, lay in enumerate(self.layers):
            x = lay(x, prev_is_global)
            prev_is_global = lay.is_global
            if hidden_xs is not None:
                if isinstance(x, dict):
                    xs = list(x.values())
                    hidden_xs[i] = (sum(xs) / len(xs)).detach()
                else:
                    hidden_xs[i] = x.detach()
        return x

    def _last_layer_is_global(self):
        """True when there is a last layer and it is flagged ``is_global``."""
        return len(self.layers) > 0 and bool(getattr(self.layers[-1], "is_global", False))

    def _flatten_each(self, feats):
        """Return a new dict with every tensor reshaped to ``(batch, -1)``."""
        return {k: v.view(v.size(0), -1) for k, v in feats.items()}

    def _classify(self, x):
        """Apply the classifier head; return ``(outputs, logits)``.

        ``logits`` is the tensor fed to the classifier, or ``None`` in
        "multihead" mode where the classifier consumes the whole dict.
        """
        logits = None
        if self._last_layer_is_global():
            if self.flatten_at_classifier:
                x = self._flatten_each(x)
            if self.fedexnn_classifer in ["avg"]:
                # Despite the mode name, the client features are summed.
                x = sum(list(x.values()))
                logits = x
                outputs = self.classifier(x)
            elif self.fedexnn_classifer == "multihead":
                head_outputs = self.classifier(x)
                # NOTE(review): this used to be `sum(x)`, which iterates the
                # string keys of the input dict and raises TypeError.
                # Presumably the per-head outputs were meant to be summed --
                # confirm against the multihead classifier's return type.
                if isinstance(head_outputs, dict):
                    head_outputs = list(head_outputs.values())
                if torch.is_tensor(head_outputs):
                    outputs = head_outputs
                else:
                    outputs = sum(head_outputs)
            else:
                raise NotImplementedError
        else:
            if self.flatten_at_classifier:
                x = x.view(x.size(0), -1)
            # Exposed for both flattening modes so that get_logits=True never
            # yields None on the local path.
            logits = x
            outputs = self.classifier(x)
        return outputs, logits

    # ------------------------------------------------------------------ #
    # public API
    # ------------------------------------------------------------------ #

    def forward(self, x, get_logits=False):
        """Run the full model.

        Args:
            x: model input.
            get_logits: when true, also return the tensor that was fed to the
                classifier (``None`` in "multihead" mode).

        Returns:
            ``outputs`` or ``(outputs, logits)``.
        """
        x = self._run_layers(x)
        outputs, logits = self._classify(x)
        if get_logits:
            return outputs, logits
        return outputs

    def forward_measure(self, x):
        """Run the full model and also return detached per-layer activations.

        Returns:
            ``(outputs, hidden_xs)`` where ``hidden_xs[i]`` is the detached
            output of layer ``i`` (averaged over clients for global layers).
        """
        hidden_xs = {}
        x = self._run_layers(x, hidden_xs)
        outputs, _ = self._classify(x)
        return outputs, hidden_xs

    def get_final_features(self, x):
        """Return the features produced by the layer stack, before the classifier.

        After a global last layer the per-client features are flattened (when
        ``flatten_at_classifier``) and then summed in "avg" mode or returned
        as a dict in "multihead" mode. After a local last layer the raw layer
        output is returned without flattening.
        """
        x = self._run_layers(x)
        if self._last_layer_is_global():
            if self.flatten_at_classifier:
                x = self._flatten_each(x)
            if self.fedexnn_classifer in ["avg"]:
                x = sum(list(x.values()))
            elif self.fedexnn_classifer == "multihead":
                pass
            else:
                raise NotImplementedError
        return x



In [31]:
%%writefile models/configs.py

# Lookup table: model name -> {integer key -> list of layer indices}.
# The key 1 always maps to an empty list ("End-to-end"); for the keys
# 2, 3, 4 and 8 the list holds (key - 1) indices.
# Presumably the key is the number of modules the network is split into and
# the list gives the layer indices at which it is cut -- confirm against the
# code that consumes this table.
# The per-model comments describe how layer indices map onto the network.
Split_Configs = {
    'mlp2': {
        1: [],  # End-to-end
        2: [0],
        3: [0, 1],
    },
    'mlp3': {
        1: [],  # End-to-end
        2: [1],
        3: [0, 1],
        4: [0, 1, 2],
    },
    'resnet18': { # 0-base conv, 1-2, 3-4, 5-6, 7-8, 9: Avg-linear
        1: [],  # End-to-end
        2: [4],
        3: [2, 6],
        4: [2, 4, 6],
        8: [2, 3, 4, 5, 6, 7, 8],
    },
    'resnet34': { # 0-base conv, 1-3, 4-7, 8-13, 14-16, 17: Avg-linear
        1: [],  # End-to-end
        2: [8],
        3: [5, 12],
        4: [4, 8, 12],
        8: [2, 4, 6, 8, 10, 12, 14],
        # NOTE(review): 17 indices here, while every other row has (key - 1) -- confirm.
        16: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
    },
    'resnet50': { # 0-base conv, 1-3, 4-7, 8-13, 14-16, 17: Avg-linear
        1: [],  # End-to-end
        2: [8],
        3: [5, 12],
        4: [4, 8, 12],
        8: [2, 4, 6, 8, 10, 12, 14],
        # NOTE(review): 17 indices here, while every other row has (key - 1) -- confirm.
        16: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
    },
}

# Same layout as Split_Configs (model name -> {integer key -> list of layer
# indices}). It additionally has a 'cnn' entry, has no 'resnet34'/'resnet50'
# entries, and uses different 'resnet18' indices for the keys 3, 4 and 8.
# Presumably these are the cut points used by the fed-expandable (EXNN)
# training path -- confirm against the consumer.
EXNN_Split_Configs = {
    'mlp2': {
        1: [],  # End-to-end
        2: [0],
        3: [0, 1],
    },
    'mlp3': {
        1: [],  # End-to-end
        2: [1],
        3: [0, 1],
        4: [0, 1, 2],
    },
    'cnn': {
        1: [],  # End-to-end
        2: [1],
        3: [0, 1],
        4: [0, 1, 2],#cnn config
    },
    'resnet18': { # 0-base conv, 1-2, 3-4, 5-6, 7-8, 9: Avg-linear
        1: [],  # End-to-end
        2: [4],
        3: [0, 4],
        4: [0, 4, 6],
        8: [0, 1, 2, 3, 5, 7, 8],
    },
}


# Same layout as Split_Configs (model name -> {integer key -> list of layer
# indices}). As written, its contents are entry-for-entry equal to
# Split_Configs.
InfoPro = {
    'mlp2': {
        1: [],  # End-to-end
        2: [0],
        3: [0, 1],
    },
    'mlp3': {
        1: [],  # End-to-end
        2: [1],
        3: [0, 1],
        4: [0, 1, 2],
    },
    'resnet18': { # 0-base conv, 1-2, 3-4, 5-6, 7-8, 9: Avg-linear
        1: [],  # End-to-end
        2: [4],
        3: [2, 6],
        4: [2, 4, 6],
        8: [2, 3, 4, 5, 6, 7, 8],
    },
    'resnet34': { # 0-base conv, 1-3, 4-7, 8-13, 14-16, 17: Avg-linear
        1: [],  # End-to-end
        2: [8],
        3: [5, 12],
        4: [4, 8, 12],
        8: [2, 4, 6, 8, 10, 12, 14],
        # NOTE(review): 17 indices here, while every other row has (key - 1) -- confirm.
        16: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
    },
    'resnet50': { # 0-base conv, 1-3, 4-7, 8-13, 14-16, 17: Avg-linear
        1: [],  # End-to-end
        2: [8],
        3: [5, 12],
        4: [4, 8, 12],
        8: [2, 4, 6, 8, 10, 12, 14],
        # NOTE(review): 17 indices here, while every other row has (key - 1) -- confirm.
        16: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
    },
}


# Same layout as Split_Configs (model name -> {integer key -> list of layer
# indices}). As written, its contents are entry-for-entry equal to InfoPro
# and Split_Configs.
# NOTE(review): the name suggests a memory-balanced variant, but no value
# differs from InfoPro -- confirm whether that is intended.
InfoPro_balanced_memory = {
    'mlp2': {
        1: [],  # End-to-end
        2: [0],
        3: [0, 1],
    },
    'mlp3': {
        1: [],  # End-to-end
        2: [1],
        3: [0, 1],
        4: [0, 1, 2],
    },
    'resnet18': { # 0-base conv, 1-2, 3-4, 5-6, 7-8, 9: Avg-linear
        1: [],  # End-to-end
        2: [4],
        3: [2, 6],
        4: [2, 4, 6],
        8: [2, 3, 4, 5, 6, 7, 8],
    },
    'resnet34': { # 0-base conv, 1-3, 4-7, 8-13, 14-16, 17: Avg-linear
        1: [],  # End-to-end
        2: [8],
        3: [5, 12],
        4: [4, 8, 12],
        8: [2, 4, 6, 8, 10, 12, 14],
        # NOTE(review): 17 indices here, while every other row has (key - 1) -- confirm.
        16: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
    },
    'resnet50': { # 0-base conv, 1-3, 4-7, 8-13, 14-16, 17: Avg-linear
        1: [],  # End-to-end
        2: [8],
        3: [5, 12],
        4: [4, 8, 12],
        8: [2, 4, 6, 8, 10, 12, 14],
        # NOTE(review): 17 indices here, while every other row has (key - 1) -- confirm.
        16: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
    },
}

Overwriting models/configs.py


In [32]:
%%writefile models/nets.py

import torch
from torch import nn
import torch.nn.functional as F

from .basics import View


class CNNMnist(nn.Module):
    """Two conv stages (conv-BN-ReLU-maxpool) followed by a 10-way linear layer.

    The linear layer takes ``7 * 7 * 32`` features, which is what the two
    size-preserving 5x5 convolutions and two 2x poolings produce for a
    1-channel 28x28 input.
    """

    def __init__(self):
        super().__init__()
        stages = []
        for c_in, c_out in ((1, 16), (16, 32)):
            stages.append(nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=5, padding=2),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ))
        self.layer1, self.layer2 = stages
        self.fc = nn.Linear(7 * 7 * 32, 10)

    def forward(self, x):
        """Return class scores of shape ``(batch, 10)``."""
        feats = self.layer2(self.layer1(x))
        flat = feats.view(feats.size(0), -1)
        return self.fc(flat)


class CNNCifar(nn.Module):
    """Three unpadded 3x3 conv stages followed by a single linear classifier.

    Args:
        hidden_features: channel width of all three conv stages.
        num_of_classes: output size of the linear classifier.

    The classifier takes ``hidden_features * 4 * 4`` features, which matches
    a 3-channel 32x32 input (32 -> 30 -> 15 -> 13 -> 6 -> 4).
    """

    def __init__(self, hidden_features, num_of_classes):
        super().__init__()
        self.hidden_features = hidden_features
        width = self.hidden_features

        def stage(c_in, with_pool):
            # conv -> batch-norm -> ReLU, optionally followed by a 2x2 max-pool
            mods = [nn.Conv2d(c_in, width, 3), nn.BatchNorm2d(width), nn.ReLU()]
            if with_pool:
                mods.append(nn.MaxPool2d(2, 2))
            return nn.Sequential(*mods)

        self.layer1 = stage(3, True)
        self.layer2 = stage(width, True)
        self.layer3 = stage(width, False)

        self.fc1 = nn.Linear(width * 4 * 4, num_of_classes)

    def forward(self, x):
        """Return class scores; features are reshaped to ``(-1, hidden_features * 16)``."""
        for block in (self.layer1, self.layer2, self.layer3):
            x = block(x)
        flat = x.view(-1, self.hidden_features * 4 * 4)
        return self.fc1(flat)



import torch
import torch.nn as nn


def make_CNNCifar_seqs(in_features, hidden_features, out_features, init_classifier):
    """Build the CNN as a list of sequential stages.

    Args:
        in_features: number of input channels of the first convolution.
        hidden_features: channel width of the first conv stage; the second and
            third stages use 2x and 4x this width.
        out_features: output size of the optional final classifier.
        init_classifier: when true, append ``nn.Linear(2048, out_features)``.

    Returns:
        list of modules: three conv stages, one flatten + linear stage with
        2048 outputs, and optionally the classifier.
    """

    def double_conv_stage(c_in, c_out):
        # Two padded 3x3 conv-BN-ReLU units, then a 2x2 max-pool that halves
        # the spatial size.
        return nn.Sequential(
            nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
            nn.BatchNorm2d(c_out),
            nn.ReLU(),
            nn.Conv2d(c_out, c_out, kernel_size=3, padding=1),
            nn.BatchNorm2d(c_out),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

    # Channel progression across the three conv stages. With a 32x32 input the
    # spatial size goes 32 -> 16 -> 8 -> 4, which is what the "* 4 * 4" below
    # relies on.
    widths = [in_features, hidden_features, hidden_features * 2, hidden_features * 4]
    layers = [double_conv_stage(c_in, c_out) for c_in, c_out in zip(widths[:-1], widths[1:])]

    # Stage 4: flatten first, then project to 2048 features.
    flattened_input_dim = widths[-1] * 4 * 4
    layers.append(nn.Sequential(
        View([flattened_input_dim]),
        nn.Linear(flattened_input_dim, 2048),
        nn.ReLU(),
        nn.Dropout(0.5),
    ))

    if init_classifier:
        # Input size matches the 2048 outputs of the stage above.
        layers.append(nn.Linear(2048, out_features))

    return layers


def make_CNNCifar_Head_seqs(in_features, hidden_features, out_features, init_classifier, split_layer_index):
    """Build the stages of the CNN that come after ``split_layer_index``.

    The candidate stages are indexed 0..3: three unpadded 3x3 conv stages (the
    third one ends by flattening to ``hidden_features * 4 * 4``) and, when
    ``init_classifier`` is true, a linear classifier as stage 3. Only stages
    whose index is strictly greater than ``split_layer_index`` are created.

    Args:
        in_features: input channels of stage 0.
        hidden_features: channel width of all conv stages.
        out_features: output size of the classifier.
        init_classifier: whether the classifier stage is a candidate at all.
        split_layer_index: index of the last stage that is NOT part of the head.

    Returns:
        list of the created modules, in order (possibly empty).
    """
    flat_features = hidden_features * 4 * 4

    # Stage factories are kept lazy so that skipped stages are never built.
    stage_factories = [
        lambda: nn.Sequential(
            nn.Conv2d(in_features, hidden_features, 3),
            nn.BatchNorm2d(hidden_features),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)),
        lambda: nn.Sequential(
            nn.Conv2d(hidden_features, hidden_features, 3),
            nn.BatchNorm2d(hidden_features),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)),
        lambda: nn.Sequential(
            nn.Conv2d(hidden_features, hidden_features, 3),
            nn.BatchNorm2d(hidden_features),
            nn.ReLU(),
            View([flat_features])),
    ]
    if init_classifier:
        # The stage before the classifier flattens to hidden_features * 4 * 4,
        # so the classifier must accept that many inputs (it was previously
        # sized with hidden_features only, which cannot follow that stage).
        stage_factories.append(lambda: torch.nn.Linear(flat_features, out_features))

    return [
        build()
        for stage_index, build in enumerate(stage_factories)
        if stage_index > split_layer_index
    ]





class CNNCifar100(nn.Module):
    """Three unpadded 3x3 convolutions (256, 256, 128 channels) and a 100-way linear layer.

    The linear layer takes ``128 * 4 * 4`` features, matching a 3-channel
    32x32 input.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 256, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(256, 256, 3)
        self.conv3 = nn.Conv2d(256, 128, 3)
        self.fc1 = nn.Linear(128 * 4 * 4, 100)

    def forward(self, x):
        """Return class scores; features are reshaped to ``(-1, 128 * 4 * 4)``."""
        # The first two convolutions are each followed by ReLU and pooling.
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        # The last convolution is followed by ReLU only.
        feats = F.relu(self.conv3(x))
        return self.fc1(feats.view(-1, 128 * 4 * 4))


class CNNCifar2(nn.Module):  # rebuilt CNN
    """Three unpadded 3x3 convolutions (32, 64, 64 channels) and two linear layers.

    The first linear layer takes ``64 * 4 * 4`` features, matching a
    3-channel 32x32 input; the second one produces 10 outputs.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.conv3 = nn.Conv2d(64, 64, 3)
        self.fc1 = nn.Linear(64 * 4 * 4, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        """Return class scores; features are reshaped to ``(-1, 64 * 4 * 4)``."""
        # The first two convolutions are each followed by ReLU and pooling.
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        feats = F.relu(self.conv3(x)).view(-1, 64 * 4 * 4)
        hidden = F.relu(self.fc1(feats))
        return self.fc2(hidden)

Overwriting models/nets.py


In [33]:
%%bash

# Launch one fed-expandable training run through main.py.
# Variables that are used below but not assigned here (num_hidden_features,
# mlp_hidden_features, cnn_hidden_features, num_layers, datadir,
# fedexnn_self_dropout, fedexnn_adapter_constrain_beta) are expected to come
# from the sourced scripts/*.sh files.

cluster_name=localhost
dataset=cifar10

source scripts/setup_env.sh
source scripts/path.sh

gpu=0

debug=False
enable_wandb=True

num_users=5

alpha=0.5
res_base_width=64
checkpoint=no
fedexnn_adapter=cnn1x1
fedexnn_split_num=4
local_ep=50
wandb_entity=cabbagepatch-lahore-university-of-management-sciences
model=resnet18
num_classes=11
lr=0.01

type=fed-expandable
source scripts/resetup_env.sh

fedexnn_classifer=${fedexnn_classifer:-avg}

# The W&B API key is read from the environment instead of being written into
# the notebook. The key that was previously committed here should be treated
# as leaked and revoked.
wandb_key=${WANDB_API_KEY:-}
if [ "$enable_wandb" = "True" ] && [ -z "$wandb_key" ]; then
    echo "WANDB_API_KEY is not set; export it before running with enable_wandb=True" >&2
    exit 1
fi

python3 -u main.py --main_task=train --type=$type  --gpu $gpu  --debug $debug \
--exp_name ${type}-${dataset}-${model}-nh${num_hidden_features}-c${num_users}-a${alpha}-ep${local_ep}-lr${lr}-clsf${fedexnn_classifer}-adp${fedexnn_adapter}-nxnn${fedexnn_split_num} \
--checkpoint $checkpoint  \
--split_measure_local_module_num 8 \
--fedexnn_classifer  ${fedexnn_classifer} --fedexnn_adapter ${fedexnn_adapter}  --fedexnn_split_num ${fedexnn_split_num} \
--fedexnn_self_dropout $fedexnn_self_dropout --fedexnn_adapter_constrain_beta $fedexnn_adapter_constrain_beta \
--model=$model --mlp_hidden_features=$mlp_hidden_features --cnn_hidden_features $cnn_hidden_features --num_layers $num_layers --res_base_width $res_base_width \
--iid=0 --lr=$lr \
--dataset=${dataset} --datadir $datadir \
--alpha=$alpha --seed=1 --num_users=${num_users} --local_ep=$local_ep \
--wandb_entity ${wandb_entity} --project_name FuseFL --enable_wandb $enable_wandb --wandb_offline False \
--wandb_key "$wandb_key" --num_classes=${num_classes}




Process is terminated.
