In [3]:
from scipy import test
import torch
import torch.cuda
from torch import nn
from torch.nn import functional as F
import argparse
import gc
import itertools
import numpy as np
import os
import sys
import time
import pickle
from copy import deepcopy

from tqdm import tqdm
import warnings
import copy

import wandb

from datasets import get_dataset
from models.models import all_models

from client import Client
from utils import *

import fedsnip_obj

rng = np.random.default_rng()

def device_list(x):
    if x == 'cpu':
        return [x]
    return [int(y) for y in x.split(',')]

In [4]:
parser = argparse.ArgumentParser()
parser.add_argument('--eta', type=float, help='learning rate', default=0.01)
parser.add_argument('--clients', type=int, help='number of clients per round', default=20)
parser.add_argument('--rounds', type=int, help='number of global rounds', default=400)
parser.add_argument('--epochs', type=int, help='number of local epochs', default=10)
parser.add_argument('--dataset', type=str, choices=('mnist', 'emnist', 'cifar10', 'cifar100'),
                    default='mnist', help='Dataset to use')
parser.add_argument('--distribution', type=str, choices=('dirichlet', 'lotteryfl', 'iid', 'classic_iid'), default='dirichlet',
                    help='how should the dataset be distributed?')
parser.add_argument('--beta', type=float, default=0.1, help='Beta parameter (unbalance rate) for Dirichlet distribution')
parser.add_argument('--total-clients', type=int, help='split the dataset between this many clients. Ignored for EMNIST.', default=400)
parser.add_argument('--min-samples', type=int, default=0, help='minimum number of samples required to allow a client to participate')
parser.add_argument('--samples-per-client', type=int, default=20, help='samples to allocate to each client (per class, for lotteryfl, or per client, for iid)')
parser.add_argument('--prox', type=float, default=0, help='coefficient to proximal term (i.e. in FedProx)')

parser.add_argument('--batch-size', type=int, default=32,
                    help='local client batch size')
parser.add_argument('--l2', default=0.001, type=float, help='L2 regularization strength')
parser.add_argument('--momentum', default=0.9, type=float, help='Local client SGD momentum parameter')
parser.add_argument('--cache-test-set', default=False, action='store_true', help='Load test sets into memory')
parser.add_argument('--cache-test-set-gpu', default=False, action='store_true', help='Load test sets into GPU memory')
parser.add_argument('--test-batches', default=0, type=int, help='Number of minibatches to test on, or 0 for all of them')
parser.add_argument('--eval-every', default=1, type=int, help='Evaluate on test set every N rounds')
parser.add_argument('--device', default='0', type=device_list, help='Device to use for compute. Use "cpu" to force CPU. Otherwise, separate with commas to allow multi-GPU.')
parser.add_argument('--no-eval', default=True, action='store_false', dest='eval')
parser.add_argument('-o', '--outfile', default='output.log', type=argparse.FileType('a', encoding='ascii'))


parser.add_argument('--model', type=str, choices=('VGG11_BN', 'VGG_SNIP', 'CNNNet'),
                    default='VGG11_BN', help='Dataset to use')

parser.add_argument('--prune_strategy', type=str, choices=('None', 'SNIP'),
                    default='None', help='Dataset to use')
parser.add_argument('--prune_at_first_round', default=False, action='store_true', dest='prune_at_first_round')
parser.add_argument('--keep_ratio', type=float, default=0.0,
                    help='local client batch size')    
parser.add_argument('--prune_vote', type=int, default=1,
                    help='local client batch size')

_StoreAction(option_strings=['--prune_vote'], dest='prune_vote', nargs=None, const=None, default=1, type=<class 'int'>, choices=None, help='local client batch size', metavar=None)

In [5]:
args = parser.parse_args(args=['--dataset', 'cifar10', 
                               '--eta', '0.01', 
                               '--device', '3', 
                               '--distribution', 'classic_iid', 
                               '--total-clients', '10', 
                               '--clients', '10', 
                               '--batch-size', '64', 
                               '--rounds', '100', 
                               '--model', 'VGG11_BN', 
                               '--prune_strategy', 'SNIP',
                               '--epochs', '2',
                               '--keep_ratio', '0.1',
                               '--prune_vote', '1',
                               '--prune_at_first_round'])

In [6]:
run = fedsnip_obj.main(args)

In [5]:
# import os
# print('...')
# name = input()
# print('++++')

In [7]:
# while True:
#     debug_info = next(run)
#     print(debug_info.msg)
    
    
    
#     input()
debug_info = next(run)
(model_list, server) = debug_info.obj
# global_model = debug_info.obj
# aggregated_masks = debug_info.obj[0]
# cl_mask_prarms = debug_info.obj[1]

Fetching dataset...
INFO:root:*********partition data***************


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:traindata_cls_counts = {0: {0: 503, 1: 483, 2: 487, 3: 527, 4: 473, 5: 548, 6: 491, 7: 502, 8: 482, 9: 504}, 1: {0: 501, 1: 503, 2: 529, 3: 503, 4: 494, 5: 468, 6: 462, 7: 497, 8: 507, 9: 536}, 2: {0: 510, 1: 540, 2: 498, 3: 461, 4: 516, 5: 483, 6: 502, 7: 515, 8: 492, 9: 483}, 3: {0: 487, 1: 487, 2: 493, 3: 498, 4: 523, 5: 482, 6: 500, 7: 487, 8: 495, 9: 548}, 4: {0: 505, 1: 480, 2: 488, 3: 488, 4: 480, 5: 509, 6: 537, 7: 490, 8: 490, 9: 533}, 5: {0: 504, 1: 480, 2: 511, 3: 494, 4: 530, 5: 509, 6: 521, 7: 489, 8: 490, 9: 472}, 6: {0: 512, 1: 509, 2: 512, 3: 488, 4: 488, 5: 465, 6: 490, 7: 538, 8: 506, 9: 492}, 7: {0: 472, 1: 521, 2: 463, 3: 549, 4: 502, 5: 515, 6: 518, 7: 496, 8: 479, 9: 485}, 8: {0: 505, 1: 485, 2: 487, 3: 505, 4: 497, 5: 512, 6: 508, 7: 501, 8: 531, 9: 469}, 9: {0: 501, 1: 512, 2: 532, 3: 487, 4: 497, 5: 509, 6: 471, 7: 485, 8: 528, 9: 478}}


10000******************
download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:train_dl_global number = 782
INFO:root:test_dl_global number = 157
INFO:root:client_idx = 0, local_sample_number = 5000


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:client_idx = 0, batch_num_train_local = 79, batch_num_test_local = 16
INFO:root:client_idx = 1, local_sample_number = 5000


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:client_idx = 1, batch_num_train_local = 79, batch_num_test_local = 16
INFO:root:client_idx = 2, local_sample_number = 5000


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:client_idx = 2, batch_num_train_local = 79, batch_num_test_local = 16
INFO:root:client_idx = 3, local_sample_number = 5000


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:client_idx = 3, batch_num_train_local = 79, batch_num_test_local = 16
INFO:root:client_idx = 4, local_sample_number = 5000


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:client_idx = 4, batch_num_train_local = 79, batch_num_test_local = 16
INFO:root:client_idx = 5, local_sample_number = 5000


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:client_idx = 5, batch_num_train_local = 79, batch_num_test_local = 16
INFO:root:client_idx = 6, local_sample_number = 5000


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:client_idx = 6, batch_num_train_local = 79, batch_num_test_local = 16
INFO:root:client_idx = 7, local_sample_number = 5000


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:client_idx = 7, batch_num_train_local = 79, batch_num_test_local = 16
INFO:root:client_idx = 8, local_sample_number = 5000


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:client_idx = 8, batch_num_train_local = 79, batch_num_test_local = 16
INFO:root:client_idx = 9, local_sample_number = 5000


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:client_idx = 9, batch_num_train_local = 79, batch_num_test_local = 16
Initializing clients...
ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mslimfun[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.11 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


client: 3 **************
all params num: 9747136; num_params_to_keep: 8772422


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


tensor(8772422, device='cuda:3')
client.net: 0.10096911377349008
**********test before train*************
test global model: 0.10096911377349008
Test client 3: Accuracy: 0.09779999405145645; Loss: 2.30259370803833; Total: 5000.0;
accuracy: 0.09779999405145645; loss: 2.30259370803833
**********test before train*************
running loss: 2.0603887551947486
running loss: 1.8130085649369638
total:  10000.0
client: 6 **************
all params num: 9747136; num_params_to_keep: 8772422
tensor(8772422, device='cuda:3')
client.net: 0.10096921616918061
**********test before train*************
test global model: 0.10096921616918061
Test client 6: Accuracy: 0.09299999475479126; Loss: 2.3026037216186523; Total: 5000.0;
accuracy: 0.09299999475479126; loss: 2.3026037216186523
**********test before train*************
running loss: 2.0807727424404288
running loss: 1.817297027080874
total:  10000.0
client: 4 **************
all params num: 9747136; num_params_to_keep: 8772422
tensor(8772422, device='cud

In [8]:
for name in model_list[0][1]:
    print(name)

features.0.weight
features.0.bias
features.1.weight
features.1.bias
features.1.running_mean
features.1.running_var
features.1.num_batches_tracked
features.4.weight
features.4.bias
features.5.weight
features.5.bias
features.5.running_mean
features.5.running_var
features.5.num_batches_tracked
features.8.weight
features.8.bias
features.9.weight
features.9.bias
features.9.running_mean
features.9.running_var
features.9.num_batches_tracked
features.11.weight
features.11.bias
features.12.weight
features.12.bias
features.12.running_mean
features.12.running_var
features.12.num_batches_tracked
features.15.weight
features.15.bias
features.16.weight
features.16.bias
features.16.running_mean
features.16.running_var
features.16.num_batches_tracked
features.18.weight
features.18.bias
features.19.weight
features.19.bias
features.19.running_mean
features.19.running_var
features.19.num_batches_tracked
features.22.weight
features.22.bias
features.23.weight
features.23.bias
features.23.running_mean
featur

In [20]:
# for name in server.model.state_dict():
#     print(name)
server.model = server.model.to('cuda:3')
params = server.model.cpu().state_dict()
print(params['features.0.weight'])

print('***********'*3)

# params['features.0.weight'] = torch.zeros_like(params['features.0.weight'])
params['features.0.weight'][0][0][0] = 2.

print(server.model.state_dict()['features.0.weight'])


print(params['features.0.weight'])

tensor([[[[ 1.0000,  1.0000,  1.0000],
          [ 0.1049, -0.1265, -0.1936],
          [-0.0625, -0.0272, -0.0316]],

         [[ 0.0058,  0.1236, -0.0714],
          [-0.0846, -0.1164, -0.0325],
          [ 0.0175,  0.0043, -0.0027]],

         [[ 0.0998,  0.0625,  0.1206],
          [-0.0531,  0.0409,  0.0927],
          [ 0.0288, -0.0765, -0.0134]]],


        [[[ 0.0344, -0.0587,  0.0657],
          [ 0.0144,  0.0120, -0.0016],
          [-0.0919,  0.0136, -0.1024]],

         [[-0.0150, -0.0258,  0.0599],
          [ 0.0326,  0.0198, -0.1006],
          [-0.0442, -0.0429,  0.0281]],

         [[ 0.0212,  0.0103,  0.0838],
          [ 0.0230,  0.0054,  0.0017],
          [ 0.0207, -0.0553,  0.0938]]],


        [[[ 0.0452, -0.0017,  0.0452],
          [-0.1510, -0.0091,  0.0583],
          [-0.0034, -0.0743,  0.0394]],

         [[ 0.0248, -0.0313,  0.0188],
          [-0.0356,  0.0159,  0.0717],
          [ 0.0557, -0.0710, -0.0214]],

         [[ 0.0652,  0.0156,  0.1431],
     

In [7]:
pruned_c = 0.0
total = 0.0
for name, param in global_model.state_dict().items():
    a = param.view(-1).to(device='cpu', copy=True).numpy()
    pruned_c +=sum(np.where(a, 0, 1))
    total += param.numel()
print(f'global model zero params: {pruned_c / total}')

global model zero params: 0.6850283984208128


In [8]:
prune_c = 0.
total = 0.
for name,mask in global_model.mask.items():
#     print(mask)
    prune_c += sum(np.where(mask.to('cpu', copy=True).view(-1).numpy(), 0, 1))
    total += mask.numel()
    
print(prune_c / total)

0.6863566898009836


In [9]:
import copy

pruned_c = 0.0
total = 0.0
for name, param in global_model.state_dict().items():
    a = param.view(-1).to(device='cpu', copy=True).numpy()
    pruned_c +=sum(np.where(a, 0, 1))
    total += param.numel()
print(f'global model zero params: {pruned_c / total}')

gcp_model = copy.deepcopy(global_model)



prune_c = 0.
total = 0.
for name, params in gcp_model.state_dict().items():
#     print(name)
    prune_c += sum(np.where(params.to('cpu', copy=True).view(-1).numpy(), 0, 1))
    total += params.numel()
    
print(prune_c / total)
    
# print('*'*10)
# for name, mask in global_model.mask.items():
#     print(name)

global model zero params: 0.6850283984208128
0.6850283984208128


In [10]:
# for name in aggregated_masks:
#     print(name)
#     print(aggregated_masks[name].dtype)

def count_mask(state_dict):
    non_zero = 0.
    total = 0.
    for name in state_dict:
        non_zero += torch.count_nonzero(state_dict[name])
        total += state_dict[name].numel()
    return 1 - non_zero / total
# a = aggregated_masks['features.4.weight']
print(a.shape)
# print(torch.count_nonzero(a))
print(count_mask(aggregated_masks))
mask = torch.where(a>=1, 1, 0)
# print(torch.count_nonzero(mask))
print(count_mask(cl_mask_prarms))

(10,)


NameError: name 'aggregated_masks' is not defined