In [None]:
#!/usr/bin/env python
# -*-coding:utf-8 -*-
'''
@File    :   train.py
@Time    :   2023/03/05 16:14:14
@Author  :   Bo
'''

import mnist_utils as mnist_utils
import numpy as np
import torch
import os
import torch.nn as nn
from tqdm import tqdm
import sys
import argparse
import time
import scipy
import math
import copy
import pandas as pd
from random import sample


# Every model and batch in this script is moved onto this (CPU) device.
device = torch.device("cpu")

def str2bool(v):
    """Parse a command-line flag value into a ``bool``.

    A real ``bool`` is returned unchanged; otherwise the lower-cased string is
    matched against common truthy / falsy spellings.

    Raises:
        argparse.ArgumentTypeError: if ``v`` is not a recognised spelling.
    """
    if isinstance(v, bool):
        return v
    truthy = ('yes', 'true', 't', 'y', '1')
    falsy = ('no', 'false', 'f', 'n', '0')
    if v.lower() in truthy:
        return True
    if v.lower() in falsy:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


def give_args(argv=None):
    """Build the experiment's argument parser and parse the arguments.

    Args:
        argv: optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace holding the experiment configuration.
    """
    parser = argparse.ArgumentParser(description='VAE-Reconstruction')
    parser.add_argument("--sigma", type=float, default=0)
    parser.add_argument("--n_clients", type=int, default=10)
    # Clients sampled per communication round. train_with_conf reads
    # conf.participation, so the parser must provide it.
    parser.add_argument("--participation", type=int, default=10)
    parser.add_argument("--split", type=str, default="by_cls")
    parser.add_argument("--shuffle_percentage", type=float, default=0.0)
    parser.add_argument("--seed_use", type=int, default=1024)
    parser.add_argument("--num_local_epochs", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=1024)
    parser.add_argument("--method", type=str, default="check_zeta")
    parser.add_argument("--version", type=int, default=0)
    parser.add_argument("--lr", type=float, default=0.1)
    parser.add_argument("--num_rounds", type=int, default=50)
    return parser.parse_args(argv)


def get_model_grads(model):
    """Return ``param.grad.data`` for every parameter that currently has a gradient.

    Parameters whose ``grad`` is ``None`` (e.g. before any backward pass) are
    skipped. Order follows ``model.named_parameters()``.
    """
    grads = []
    for _name, param in model.named_parameters():
        if getattr(param, 'grad', None) is None:
            continue
        grads.append(param.grad.data)
    return grads

def get_model_params(model):
    """Return ``param.data`` for every parameter that currently has a gradient.

    Parameters are filtered on gradient presence, not on ``requires_grad``.
    Before any backward pass the result is therefore empty.
    """
    has_grad = lambda prm: getattr(prm, 'grad', None) is not None
    return [prm.data for _name, prm in model.named_parameters() if has_grad(prm)]


class Train(object):
    """Client-side trainer used for one client in one federated round.

    Creates a fresh ``mnist_utils.CLSMultiLayerModel`` and loads the given
    server weights when ``conf.round > 0``. It then trains for
    ``num_local_epochs`` passes over the training loader and evaluates once on
    the test loader.
    """

    def __init__(self, conf, data_group, num_local_epochs, sigma, exist_model, version=0):
        """Set up the model, optimizer and loss.

        Args:
            conf: experiment namespace; ``conf.round`` and ``conf.lr`` are read.
            data_group: ``(train_loader, test_loader)`` pair.
            num_local_epochs: number of passes over the training loader.
            sigma: stored on the instance; not used by this class.
            exist_model: server ``state_dict``, loaded when ``conf.round > 0``.
            version: stored on the instance.
        """
        self.conf = conf
        self.data_group = data_group
        self.num_local_epochs = num_local_epochs
        self.sigma = sigma
        self.model_use = mnist_utils.CLSMultiLayerModel(28*28).to(device)

        # Round 0 starts from a fresh initialization; later rounds continue
        # from the aggregated server weights.
        if conf.round > 0:
            self.model_use.load_state_dict(exist_model)

        self.optimizer = mnist_utils.define_optimizer(self.model_use, lr=conf.lr)
        # Summed over the batch; divide by the batch size where a mean is needed.
        self.loss_fn = nn.CrossEntropyLoss(reduction='sum')

        self.tr_data_loader, self.tt_data_loader = data_group
        self.version = version

        # NOTE(review): assuming get_dataloader returns a torch DataLoader,
        # len() here counts batches, not samples.
        print("Training data size", len(self.tr_data_loader))
        print("Testing data size", len(self.tt_data_loader))

        # Every parameter is expected to be trainable.
        frozen = [p for p in self.model_use.parameters() if not p.requires_grad]
        assert len(frozen) == 0

    def get_grad(self):
        """Return the flattened gradient of the first trainable non-bias parameter.

        Requires a prior ``backward()`` call.

        Raises:
            IndexError: if the model has no trainable non-bias parameter.
        """
        for name, p in self.model_use.named_parameters():
            if p.requires_grad and "bias" not in name:
                # Only the first weight gradient was ever used, so stop here
                # instead of converting every gradient to numpy.
                return np.reshape(p.grad.data.detach().cpu().numpy(), [-1])
        raise IndexError("list index out of range")

    def _update_batch_tr(self, _image, _label, global_step):
        """Run one optimizer step on a batch and print its loss and accuracy.

        ``global_step`` is accepted for interface compatibility but not used.

        Returns:
            The flattened first weight gradient (see ``get_grad``).
        """
        _image, _label = _image.to(device), _label.to(device)
        self.optimizer.zero_grad()
        _pred = self.model_use(_image)
        # Mean cross-entropy over the batch.
        _loss = self.loss_fn(_pred, _label) / len(_image)
        _loss.backward()

        self.optimizer.step()

        accu = (_pred.argmax(axis=-1) == _label).sum().div(len(_image))
        print("Training loss: {:.4f} and Training accuracy {:.2f}".format(_loss.item(), accu.item()))
        return self.get_grad()

    def _eval(self, global_step, data_use, str_use):
        """Evaluate the client model on ``data_use`` and print per-sample metrics.

        Args:
            global_step: unused; kept for interface compatibility.
            data_use: iterable of ``(images, labels)`` batches.
            str_use: label used in the printed message (e.g. ``"test"``).

        Returns:
            ``(summed loss over all samples, accuracy as a fraction)``.
        """
        was_training = self.model_use.training
        self.model_use.eval()
        val_loss, val_correct, n_seen = 0.0, 0.0, 0
        # Pure inference, so no autograd graph is needed.
        with torch.no_grad():
            for _image, _label in data_use:
                _image, _label = _image.to(device), _label.to(device)
                _pred = self.model_use(_image)
                val_loss += self.loss_fn(_pred, _label).item()
                val_correct += (_pred.argmax(axis=-1) == _label).sum().item()
                n_seen += len(_image)
        self.model_use.train(was_training)
        # Average over the real number of samples, so a smaller final batch
        # (or an empty loader) is handled correctly.
        denom = max(n_seen, 1)
        print("{} loss: {:.4f} and {} accuracy {:.2f}".format(str_use, val_loss / denom,
                                                              str_use, val_correct / denom))
        return val_loss, val_correct / denom

    def run(self):
        """Train for ``num_local_epochs`` epochs, evaluate once, and return weights.

        Returns:
            The trained model's ``state_dict``.
        """
        global_step = 0
        for _epoch in range(self.num_local_epochs):
            for _im, _la in self.tr_data_loader:
                self._update_batch_tr(_im, _la, global_step)
                global_step += 1
        # Client model evaluation after the last local step. Unlike the
        # in-loop check, this also runs (and returns weights) when there
        # were zero training steps.
        self._eval(global_step, self.tt_data_loader, "test")
        return self.model_use.state_dict()


def run_train(conf, tr_im, tr_la, local_id, exist_model, version=0):
    """Train one client's model for the current round and return its ``state_dict``.

    Local training is full-batch: ``conf.batch_size`` is overwritten on the
    shared ``conf`` object with the number of local training images.

    Args:
        conf: experiment namespace, passed on to ``Train``.
        tr_im, tr_la: this client's training images and labels.
        local_id: client index, used only for logging.
        exist_model: current server ``state_dict`` (``None`` in round 0).
        version: forwarded to ``Train``.
    """
    banner = "==========================================================="
    print(banner)
    print("                    Local ID %02d " % local_id)
    print(banner)

    tt_im, tt_la = mnist_utils.load_tt_im()

    # One batch holds every local image.
    conf.batch_size = len(tr_im)
    print("The used batch size", conf.batch_size)
    loaders = [
        mnist_utils.get_dataloader(tr_im, tr_la, True, conf.batch_size),
        mnist_utils.get_dataloader(tt_im, tt_la),
    ]

    trainer = Train(conf, loaders, conf.num_local_epochs, conf.sigma, exist_model, version)
    client_state = trainer.run()
    print("Done Local ID %02d" % local_id )
    return client_state

# Checking the server (aggregated) model's accuracy on the test set
def check_test_accuracy(model_checkpoints):
    """Evaluate a server ``state_dict`` on the MNIST test split.

    Args:
        model_checkpoints: ``state_dict`` for ``mnist_utils.CLSMultiLayerModel``.

    Returns:
        ``(loss, accu, pred_indices, labels, pred_ls)``:
        - ``loss``: mean cross-entropy per test image.
        - ``accu``: fraction of images classified correctly.
        - ``pred_indices``: argmax prediction per image, in loader order (no shuffling).
        - ``labels``: ground-truth label per image, in loader order (no shuffling).
        - ``pred_ls``: list of per-image softmax vectors.
    """
    tt_im, tt_la = mnist_utils.load_tt_im()
    tt_loader = mnist_utils.get_dataloader(tt_im, tt_la, shuffle=False, batch_size=100)
    model_use = mnist_utils.CLSMultiLayerModel(28*28).to(device)
    model_use.load_state_dict(model_checkpoints)
    model_use.eval()
    loss_fn = nn.CrossEntropyLoss(reduction='sum')

    total_loss, n_correct, n_seen = 0.0, 0, 0
    indices, labels, pred_ls = [], [], []
    # Pure inference, so no autograd graph is needed.
    with torch.no_grad():
        for _im, _la in tt_loader:
            _im, _la = _im.to(device), _la.to(device)
            _pred = model_use(_im)
            _index = _pred.argmax(axis=-1)
            total_loss += loss_fn(_pred, _la).item()
            n_correct += (_index == _la).sum().item()
            n_seen += len(_im)
            labels.append(_la.cpu().numpy())
            indices.append(_index.cpu().numpy())
            pred_ls.extend(softmax(row) for row in _pred.cpu().numpy())

    # Normalise by the real sample count instead of assuming every batch
    # holds exactly 100 images.
    denom = max(n_seen, 1)
    loss = total_loss / denom
    accu = n_correct / denom
    flat_index = np.concatenate(indices) if indices else np.array([])
    flat_labels = np.concatenate(labels) if labels else np.array([])
    print("Server model loss: %.4f and accuracy: %.4f" % (loss, accu))
    return loss, accu, flat_index, flat_labels, pred_ls

def check_train_accuracy(model_checkpoints):
    """Evaluate a server ``state_dict`` on the training data.

    The data is whatever ``mnist_utils.split_dataset_to_workers(1, 'iid')``
    assigns to its single worker (presumably the full training set — confirm).

    Args:
        model_checkpoints: ``state_dict`` for ``mnist_utils.CLSMultiLayerModel``.

    Returns:
        ``(loss, accu, pred_indices, labels, pred_ls)``:
        - ``loss``: mean cross-entropy per image.
        - ``accu``: fraction of images classified correctly.
        - ``pred_indices``: argmax prediction per image, in loader order (no shuffling).
        - ``labels``: ground-truth label per image, in loader order (no shuffling).
        - ``pred_ls``: list of per-image softmax vectors.
    """
    tr_im, tr_la = mnist_utils.split_dataset_to_workers(1, 'iid')
    # With one worker the only key is "worker_00".
    tr_loader = mnist_utils.get_dataloader(tr_im["worker_00"], tr_la["worker_00"],
                                           shuffle=False, batch_size=100)
    model_use = mnist_utils.CLSMultiLayerModel(28*28).to(device)
    model_use.load_state_dict(model_checkpoints)
    model_use.eval()
    loss_fn = nn.CrossEntropyLoss(reduction='sum')

    total_loss, n_correct, n_seen = 0.0, 0, 0
    indices, labels, pred_ls = [], [], []
    # Pure inference, so no autograd graph is needed.
    with torch.no_grad():
        for _im, _la in tr_loader:
            _im, _la = _im.to(device), _la.to(device)
            _pred = model_use(_im)
            _index = _pred.argmax(axis=-1)
            total_loss += loss_fn(_pred, _la).item()
            n_correct += (_index == _la).sum().item()
            n_seen += len(_im)
            labels.append(_la.cpu().numpy())
            indices.append(_index.cpu().numpy())
            pred_ls.extend(softmax(row) for row in _pred.cpu().numpy())

    # Normalise by the real sample count instead of assuming every batch
    # holds exactly 100 images.
    denom = max(n_seen, 1)
    loss = total_loss / denom
    accu = n_correct / denom
    flat_index = np.concatenate(indices) if indices else np.array([])
    flat_labels = np.concatenate(labels) if labels else np.array([])
    print("Server model loss: %.4f and accuracy: %.4f" % (loss, accu))
    return loss, accu, flat_index, flat_labels, pred_ls


def softmax(x, axis=0):
    """Compute softmax values for each set of scores in ``x`` along ``axis``.

    The maximum along ``axis`` is subtracted before exponentiating. This does
    not change the result mathematically, but it stops ``np.exp`` from
    overflowing to ``inf`` (which made the output ``nan``) for large logits.

    Args:
        x: array-like of scores.
        axis: axis along which the probabilities sum to 1 (default 0, as before).

    Returns:
        ``np.ndarray`` with the same shape as ``x``.
    """
    x = np.asarray(x)
    if x.size == 0:
        # Nothing to normalise; np.max would raise on an empty reduction.
        return np.exp(x)
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=axis, keepdims=True)


# Calculating the variance of the (client) model
def client_predictions(model_checkpoints):
    tt_im, tt_la = mnist_utils.load_tt_im()
    tt_loader = mnist_utils.get_dataloader(tt_im, tt_la, shuffle=False, batch_size=100)
    pred_ls = []
    # Uses a linear output layer to test model_checkpoint
    model_use = mnist_utils.CLSMultiLayerModel(28 * 28).to(device)
    model_use.load_state_dict(model_checkpoints)
    for i, (_im, _la) in enumerate(tt_loader):
        _im, _la = _im.to(device), _la.to(device)
        _pred = model_use(_im).detach().cpu().numpy()
        # VERY un-optimal way of calculating the softmax, replace later
        for j in range(_pred.shape[0]):
            pred_ls.append(softmax(_pred[j,:]))
        #pred_ls[i,:] = np.exp(_pred) / np.sum(np.exp(_pred), axis=1)
    return pred_ls

def var_per_image(pred_matrix, pred_indices):
    """Per-image variance, across clients, of the probability of one chosen class.

    For image ``i`` the class is ``pred_indices[i]``. ``train_with_conf``
    passes the server model's predicted labels here. The variance is taken
    over the client axis.

    Args:
        pred_matrix: array of shape ``(clients, images, classes)``.
        pred_indices: one class index per image.

    Returns:
        1-D ``np.ndarray`` with one variance per image.
    """
    n_images = pred_matrix.shape[1]
    return np.array([
        np.var(pred_matrix[:, img, pred_indices[img]])
        for img in range(n_images)
    ])

def _fedavg_accumulate(model_group, client_state, weight):
    """Add ``weight * client_state`` into the running FedAvg sum.

    Args:
        model_group: accumulator dict, or ``None`` for the first client of a round.
        client_state: a client's ``state_dict``.
        weight: scalar multiplier for this client.

    Returns:
        The accumulator dict (newly created when ``model_group`` is ``None``).
    """
    if model_group is None:
        return {k: v * weight for k, v in client_state.items()}
    for k, v in client_state.items():
        model_group[k] += v * weight
    return model_group


def _save_round_results(saves_dir, participants, pred_matrix, indexx, labels,
                        pred_soft_max, indexx_tr, labels_tr, pred_soft_max_tr,
                        accuk_tr):
    """Write one round's client/server predictions and train accuracy to ``saves_dir``.

    Row ``e`` of ``pred_matrix`` is written under client id ``participants[e]``.
    """
    if not os.path.exists(saves_dir):
        os.makedirs(saves_dir)
    for client_id, client_pred in zip(participants, pred_matrix):
        pd.DataFrame(client_pred).to_csv(saves_dir + ('_client_%02d.csv' % client_id), index=False)

    pd.DataFrame(indexx).to_csv(saves_dir + '_Index_Matrix.csv', index=False)
    pd.DataFrame(labels).to_csv(saves_dir + '_Ground_Truth_Indices.csv', index=False)
    pd.DataFrame(pred_soft_max).to_csv(saves_dir + '_server_.csv', index=False)

    pd.DataFrame(indexx_tr).to_csv(saves_dir + '_Index_Matrix_train.csv', index=False)
    pd.DataFrame(labels_tr).to_csv(saves_dir + '_Ground_Truth_train_Indices.csv', index=False)
    pd.DataFrame(pred_soft_max_tr).to_csv(saves_dir + '_server_train_acc_.csv', index=False)

    with open(saves_dir + 'accuracy.txt', 'w') as f:
        f.write('%f' % accuk_tr)


def train_with_conf(conf):
    """Run the federated-averaging experiment described by ``conf``.

    Each of ``conf.num_rounds`` rounds proceeds as follows:
    - Samples ``conf.participation`` clients; round 0 trains every client.
    - Trains each selected client locally via ``run_train``, starting from
      the current server weights (none in round 0).
    - Averages the client weights uniformly into the new server model.
    - Evaluates the server model on the test and train data.
    - Writes per-round CSV files.

    Reads ``conf.version``, ``conf.n_clients``, ``conf.split``,
    ``conf.shuffle_percentage``, ``conf.num_local_epochs``, ``conf.lr``,
    ``conf.participation`` and ``conf.num_rounds``. Sets ``conf.model_dir``,
    ``conf.random_state`` and ``conf.round``.

    Raises:
        ValueError: if a round ends up with no participating clients.
    """
    model_dir = "/content/drive/MyDrive/DL_Proj/Model_/"
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    model_dir += "version_%02d" % conf.version
    conf.model_dir = model_dir

    p_write = conf.model_dir + ".txt"
    print("The directory", p_write)

    # Per-metric log files (loss/acc/indices/pred/variance). Their writes are
    # currently disabled, so the files are only created/truncated here. They
    # are closed at once instead of being leaked as open handles.
    for suffix in ("_loss.txt", "_acc.txt", "_ind.txt", "_pred.txt", "_var.txt"):
        with open(conf.model_dir + suffix, 'w'):
            pass

    tr_im, tr_la = mnist_utils.split_dataset_to_workers(conf.n_clients, conf.split)
    n_cl = conf.n_clients
    if conf.split != "iid" and conf.shuffle_percentage != 0:
        _, _, combine_im, combine_la = mnist_utils.shuffle_dataset(conf.n_clients,
                                                                   conf.shuffle_percentage)
    else:
        combine_im = [tr_im["worker_%02d" % c] for c in range(n_cl)]
        combine_la = [tr_la["worker_%02d" % c] for c in range(n_cl)]

    print("The updated number of local epochs", conf.num_local_epochs, conf.shuffle_percentage)
    print("The class frequency per client")
    for c in range(conf.n_clients):
        print("worker-%02d" % c, np.unique(combine_la[c], return_counts=True))
    init_time = time.time()
    print("Initial time: ", init_time)

    # NOTE(review): conf.seed_use is parsed but not used; every run draws a fresh seed.
    seed_use = np.random.randint(0, 100000, 1)[0]
    print("The used learning rate", conf.lr)
    print("The seed", seed_use)

    # No server model exists before round 0.
    exist_model = None
    # Client-server loop: one iteration per communication round.
    for i in range(conf.num_rounds):
        round_clients = sample(range(0, conf.n_clients), conf.participation)
        # Round 0 seeds all RNGs once, so runs with the same seed are reproducible.
        if i == 0:
            conf.random_state = np.random.RandomState(seed_use)
            mnist_utils.seed_everything(seed_use)
        conf.round = i

        # Round 0 trains every client; later rounds only the sampled subset.
        # The FedAvg weight must match the number of clients actually summed;
        # otherwise round 0 weights would sum to n_clients / participation.
        participants = [j for j in range(conf.n_clients) if i == 0 or j in round_clients]
        if not participants:
            raise ValueError("No clients selected in round %d; check --participation" % i)
        weight = 1.0 / len(participants)

        pred_matrix = []
        model_group = None
        for j in participants:
            c_tr_im, c_tr_la = combine_im[j], combine_la[j]
            _model = run_train(conf, c_tr_im, c_tr_la, j, exist_model, conf.version)
            pred_matrix.append(client_predictions(_model))
            # FedAvg (McMahan et al.): the server weights are the (here uniform)
            # average of the participating clients' weights.
            model_group = _fedavg_accumulate(model_group, _model, weight)

        pred_matrix = np.array(pred_matrix)
        print('Prediction Matrix Shape: ', pred_matrix.shape)

        # Evaluate the new server model.
        exist_model = model_group
        _lossk, _accuk, indexx, labels, pred_soft_max = check_test_accuracy(exist_model)
        _lossk_tr, accuk_tr, indexx_tr, labels_tr, pred_soft_max_tr = check_train_accuracy(exist_model)
        print('Indices Matrix shape: ', indexx.shape)

        var = var_per_image(pred_matrix, indexx)
        print(var.shape)

        saves_dir = "/content/drive/MyDrive/DL_Proj_temp/Results_/Version_%02d/Round_%02d/" % (conf.version, i)
        _save_round_results(saves_dir, participants, pred_matrix, indexx, labels,
                            pred_soft_max, indexx_tr, labels_tr, pred_soft_max_tr,
                            accuk_tr)

    end_time = time.time()
    print("End time", end_time - init_time)


# if __name__ == "__main__":
#     conf = give_args()
#     train_with_conf(conf)


In [None]:
import sys
import argparse

# Module-level parser with the experiment's command-line options; later cells
# call parser.parse_args([...]) on it directly.
parser = argparse.ArgumentParser(description='VAE-Reconstruction')
_ARG_SPECS = (
    ("--sigma", float, 0),
    ("--n_clients", int, 10),
    ("--participation", int, 10),
    ("--split", str, "by_cls"),
    ("--shuffle_percentage", float, 0.0),
    ("--seed_use", int, 1024),
    ("--num_local_epochs", int, 100),
    ("--batch_size", int, 1024),
    ("--method", str, "check_zeta"),
    ("--version", int, 0),
    ("--lr", float, 0.1),
    ("--num_rounds", int, 50),
)
for _flag, _type, _default in _ARG_SPECS:
    parser.add_argument(_flag, type=_type, default=_default)

_StoreAction(option_strings=['--num_rounds'], dest='num_rounds', nargs=None, const=None, default=50, type=<class 'int'>, choices=None, required=False, help=None, metavar=None)

In [None]:
import warnings
warnings.filterwarnings("ignore")

# Experiment setup: 10 clients, 9 sampled per round (round 0 trains all of
# them), 80 communication rounds of 10 local epochs each.
# NOTE(review): 'by_class' differs from the parser default 'by_cls'; confirm
# mnist_utils.split_dataset_to_workers accepts this spelling.
_cli_args = [
    '--sigma', '0', '--n_clients', '10', '--participation', '9',
    '--split', 'by_class', '--shuffle_percentage', '0',
    '--method', 'check_zeta', '--version', '0', '--lr', '0.1',
    '--num_rounds', '80', '--num_local_epochs', '10',
]
conf = parser.parse_args(_cli_args)

train_with_conf(conf)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Training data size 1
Testing data size 1
Training loss: 1.6055 and Training accuracy 0.26
Training loss: 0.0000 and Training accuracy 1.00
Training loss: 0.0000 and Training accuracy 1.00
Training loss: 0.0000 and Training accuracy 1.00
Training loss: 0.0000 and Training accuracy 1.00
Training loss: 0.0000 and Training accuracy 1.00
Training loss: 0.0000 and Training accuracy 1.00
Training loss: 0.0000 and Training accuracy 1.00
Training loss: 0.0000 and Training accuracy 1.00
Training loss: 0.0000 and Training accuracy 1.00
test loss: 15.6090 and test accuracy 0.10
Done Local ID 08
The shape of the dataset (5120, 28, 28, 1) (5120,)
The length of the dataset 5120
                    Local ID 09 
The used batch size 1024
The shape of the dataset (1024, 28, 28, 1) (1024,)
The length of the dataset 1024
The shape of the dataset (5120, 28, 28, 1) (5120,)
The length of the dataset 5120
Training data size 1
Testing data size 1
