In [1]:
from scipy import test
import torch
import torch.cuda
from torch import nn
from torch.nn import functional as F
import argparse
import gc
import itertools
import numpy as np
import os
import sys
import time
import pickle
from copy import deepcopy

from tqdm import tqdm
import warnings
import copy

import wandb

from datasets import get_dataset
from models.models import all_models

from client import Client
from utils import *

import fedsnip_obj

# Module-level NumPy random generator. No seed is passed, so NumPy seeds it
# from fresh OS entropy and the values drawn differ from run to run.
rng = np.random.default_rng()

def device_list(x):
    """Parse the ``--device`` command-line value into a list of devices.

    Args:
        x: ``'cpu'`` (any case, surrounding whitespace ignored) or a
            comma-separated list of integer device indices such as ``'0,1'``.

    Returns:
        ``['cpu']`` for the CPU case, otherwise a list of ints.

    Raises:
        ValueError: if an entry is not an integer or no device is given.
            argparse reports a ValueError from a ``type=`` callable as an
            invalid argument value.
    """
    spec = x.strip()
    if spec.lower() == 'cpu':
        return ['cpu']
    # Skip empty fields so a stray or trailing comma ('0,1,') does not reach int('').
    devices = [int(part) for part in spec.split(',') if part.strip()]
    if not devices:
        raise ValueError(f'no device specified in {x!r}')
    return devices

  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)


In [2]:
# Command-line interface of the experiment. The notebook feeds this parser an
# explicit argument list instead of reading sys.argv.

parser = argparse.ArgumentParser()

# --- Federated training and data partitioning ---
parser.add_argument('--eta', type=float, help='learning rate', default=0.01)
parser.add_argument('--clients', type=int, help='number of clients per round', default=20)
parser.add_argument('--rounds', type=int, help='number of global rounds', default=400)
parser.add_argument('--epochs', type=int, help='number of local epochs', default=10)
parser.add_argument('--dataset', type=str, choices=('mnist', 'emnist', 'cifar10', 'cifar100'),
                    default='mnist', help='Dataset to use')
parser.add_argument('--distribution', type=str, choices=('dirichlet', 'lotteryfl', 'iid', 'classic_iid'), default='dirichlet',
                    help='how should the dataset be distributed?')
parser.add_argument('--beta', type=float, default=0.1, help='Beta parameter (unbalance rate) for Dirichlet distribution')
parser.add_argument('--total-clients', type=int, help='split the dataset between this many clients. Ignored for EMNIST.', default=400)
parser.add_argument('--min-samples', type=int, default=0, help='minimum number of samples required to allow a client to participate')
parser.add_argument('--samples-per-client', type=int, default=20, help='samples to allocate to each client (per class, for lotteryfl, or per client, for iid)')
parser.add_argument('--prox', type=float, default=0, help='coefficient to proximal term (i.e. in FedProx)')

# --- Local optimisation, evaluation, device and logging ---
parser.add_argument('--batch-size', type=int, default=32,
                    help='local client batch size')
parser.add_argument('--l2', default=0.001, type=float, help='L2 regularization strength')
parser.add_argument('--momentum', default=0.9, type=float, help='Local client SGD momentum parameter')
parser.add_argument('--cache-test-set', default=False, action='store_true', help='Load test sets into memory')
parser.add_argument('--cache-test-set-gpu', default=False, action='store_true', help='Load test sets into GPU memory')
parser.add_argument('--test-batches', default=0, type=int, help='Number of minibatches to test on, or 0 for all of them')
parser.add_argument('--eval-every', default=1, type=int, help='Evaluate on test set every N rounds')
parser.add_argument('--device', default='0', type=device_list, help='Device to use for compute. Use "cpu" to force CPU. Otherwise, separate with commas to allow multi-GPU.')
# args.eval defaults to True; passing --no-eval stores False.
parser.add_argument('--no-eval', default=True, action='store_false', dest='eval')
parser.add_argument('-o', '--outfile', default='output.log', type=argparse.FileType('a', encoding='ascii'))


# --- Model and pruning ---
# The help strings below used to be copy-pasted from --dataset / --batch-size.
parser.add_argument('--model', type=str, choices=('VGG11_BN', 'VGG_SNIP', 'CNNNet'),
                    default='VGG11_BN', help='Model architecture to use')

parser.add_argument('--prune_strategy', type=str, choices=('None', 'SNIP'),
                    default='None', help='Pruning strategy to use')
parser.add_argument('--prune_at_first_round', default=False, action='store_true', dest='prune_at_first_round')
# NOTE(review): the exact meaning of keep_ratio and prune_vote is defined by
# the pruning code outside this notebook; confirm the wording against it.
parser.add_argument('--keep_ratio', type=float, default=0.0,
                    help='Keep ratio used by the pruning strategy')
parser.add_argument('--prune_vote', type=int, default=1,
                    help='Vote value used by the pruning strategy')

parser.add_argument('--single_shot_pruning', default=False, action='store_true', dest='single_shot_pruning')

_StoreTrueAction(option_strings=['--single_shot_pruning'], dest='single_shot_pruning', nargs=0, const=True, default=False, type=None, choices=None, help=None, metavar=None)

In [3]:
# Notebook stand-in for the command line: options that take a value are listed
# in a table, boolean switches separately, and both are flattened into the
# argv list handed to the parser (same tokens, same order).
option_values = {
    '--dataset': 'cifar10',
    '--eta': '0.01',
    '--device': '3',
    '--distribution': 'classic_iid',
    '--total-clients': '10',
    '--clients': '10',
    '--batch-size': '64',
    '--rounds': '100',
    '--model': 'VGG11_BN',
    '--prune_strategy': 'SNIP',
    '--epochs': '2',
    '--keep_ratio': '0.9',
    '--prune_vote': '1',
}
switches = ['--prune_at_first_round', '--single_shot_pruning']

notebook_argv = []
for option, value in option_values.items():
    notebook_argv.extend((option, value))
notebook_argv.extend(switches)

args = parser.parse_args(args=notebook_argv)

In [4]:
# Start the experiment with the parsed arguments. A later cell advances `run`
# with next(run), so main() is expected to return an iterator/generator.
run = fedsnip_obj.main(args)

In [5]:
# import os
# print('...')
# name = input()
# print('++++')

In [6]:
# while True:
#     debug_info = next(run)
#     print(debug_info.msg)
    
    
    
#     input()
# Advance the run by one step; the yielded object carries its payload in `.obj`.
debug_info = next(run)
# Here the payload is unpacked as a pair: the model list and the server object.
(model_list, server) = debug_info.obj
# Other unpackings, currently disabled. NOTE(review): later cells read
# `global_model`, `aggregated_masks` and `cl_mask_prarms`, which are only
# assigned by these commented-out lines.
# global_model = debug_info.obj
# aggregated_masks = debug_info.obj[0]
# cl_mask_prarms = debug_info.obj[1]

Fetching dataset...
INFO:root:*********partition data***************


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:traindata_cls_counts = {0: {0: 503, 1: 483, 2: 487, 3: 527, 4: 473, 5: 548, 6: 491, 7: 502, 8: 482, 9: 504}, 1: {0: 501, 1: 503, 2: 529, 3: 503, 4: 494, 5: 468, 6: 462, 7: 497, 8: 507, 9: 536}, 2: {0: 510, 1: 540, 2: 498, 3: 461, 4: 516, 5: 483, 6: 502, 7: 515, 8: 492, 9: 483}, 3: {0: 487, 1: 487, 2: 493, 3: 498, 4: 523, 5: 482, 6: 500, 7: 487, 8: 495, 9: 548}, 4: {0: 505, 1: 480, 2: 488, 3: 488, 4: 480, 5: 509, 6: 537, 7: 490, 8: 490, 9: 533}, 5: {0: 504, 1: 480, 2: 511, 3: 494, 4: 530, 5: 509, 6: 521, 7: 489, 8: 490, 9: 472}, 6: {0: 512, 1: 509, 2: 512, 3: 488, 4: 488, 5: 465, 6: 490, 7: 538, 8: 506, 9: 492}, 7: {0: 472, 1: 521, 2: 463, 3: 549, 4: 502, 5: 515, 6: 518, 7: 496, 8: 479, 9: 485}, 8: {0: 505, 1: 485, 2: 487, 3: 505, 4: 497, 5: 512, 6: 508, 7: 501, 8: 531, 9: 469}, 9: {0: 501, 1: 512, 2: 532, 3: 487, 4: 497, 5: 509, 6: 471, 7: 485, 8: 528, 9: 478}}


10000******************
download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:train_dl_global number = 782
INFO:root:test_dl_global number = 157
INFO:root:client_idx = 0, local_sample_number = 5000


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:client_idx = 0, batch_num_train_local = 79, batch_num_test_local = 16
INFO:root:client_idx = 1, local_sample_number = 5000


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:client_idx = 1, batch_num_train_local = 79, batch_num_test_local = 16
INFO:root:client_idx = 2, local_sample_number = 5000


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:client_idx = 2, batch_num_train_local = 79, batch_num_test_local = 16
INFO:root:client_idx = 3, local_sample_number = 5000


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:client_idx = 3, batch_num_train_local = 79, batch_num_test_local = 16
INFO:root:client_idx = 4, local_sample_number = 5000


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:client_idx = 4, batch_num_train_local = 79, batch_num_test_local = 16
INFO:root:client_idx = 5, local_sample_number = 5000


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:client_idx = 5, batch_num_train_local = 79, batch_num_test_local = 16
INFO:root:client_idx = 6, local_sample_number = 5000


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:client_idx = 6, batch_num_train_local = 79, batch_num_test_local = 16
INFO:root:client_idx = 7, local_sample_number = 5000


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:client_idx = 7, batch_num_train_local = 79, batch_num_test_local = 16
INFO:root:client_idx = 8, local_sample_number = 5000


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:client_idx = 8, batch_num_train_local = 79, batch_num_test_local = 16
INFO:root:client_idx = 9, local_sample_number = 5000


download = True
Files already downloaded and verified
download = True
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:client_idx = 9, batch_num_train_local = 79, batch_num_test_local = 16
Initializing clients...
ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mslimfun[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.11 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


It is me!!!
It is me!!!
It is me!!!
It is me!!!
It is me!!!
It is me!!!
It is me!!!
It is me!!!
It is me!!!
It is me!!!
It is me!!!
client: 6 **************
all params num: 9747136; num_params_to_keep: 974713


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


tensor(974713, device='cuda:3')
client: 9 **************
all params num: 9747136; num_params_to_keep: 974713
tensor(974713, device='cuda:3')
client: 5 **************
all params num: 9747136; num_params_to_keep: 974713
tensor(974714, device='cuda:3')
client: 8 **************
all params num: 9747136; num_params_to_keep: 974713
tensor(974713, device='cuda:3')
client: 2 **************
all params num: 9747136; num_params_to_keep: 974713
tensor(974714, device='cuda:3')
client: 7 **************
all params num: 9747136; num_params_to_keep: 974713
tensor(974713, device='cuda:3')
client: 3 **************
all params num: 9747136; num_params_to_keep: 974713
tensor(974713, device='cuda:3')
client: 0 **************
all params num: 9747136; num_params_to_keep: 974713
tensor(974713, device='cuda:3')
client: 1 **************
all params num: 9747136; num_params_to_keep: 974713
tensor(974713, device='cuda:3')
client: 4 **************
all params num: 9747136; num_params_to_keep: 974713
tensor(974714, devi

In [10]:
import json

# Only the first line of the file is read; it is parsed as a JSON list and
# every entry is converted to a torch tensor through a NumPy array.
with open('applyed_masks.txt', 'r') as mask_file:
    masks_str = mask_file.readline()
#     print(masks_str)

masks = [torch.from_numpy(np.asarray(entry)) for entry in json.loads(masks_str)]

print(masks[0])

tensor([[[[10., 10., 10.],
          [ 7.,  9.,  9.],
          [ 9., 10., 10.]],

         [[10., 10., 10.],
          [ 9.,  9.,  9.],
          [10., 10., 10.]],

         [[10., 10.,  9.],
          [10., 10., 10.],
          [ 9.,  0., 10.]]],


        [[[10.,  9., 10.],
          [10., 10., 10.],
          [10., 10., 10.]],

         [[10.,  9., 10.],
          [10.,  5., 10.],
          [10., 10., 10.]],

         [[10., 10., 10.],
          [ 9., 10., 10.],
          [ 9., 10.,  9.]]],


        [[[ 6., 10., 10.],
          [10.,  9., 10.],
          [10.,  8., 10.]],

         [[10.,  1., 10.],
          [10., 10., 10.],
          [ 7., 10., 10.]],

         [[10., 10., 10.],
          [10.,  9.,  5.],
          [10., 10.,  9.]]],


        ...,


        [[[ 7., 10.,  8.],
          [ 7.,  5.,  9.],
          [ 8.,  8., 10.]],

         [[10.,  5., 10.],
          [ 8., 10., 10.],
          [10., 10.,  9.]],

         [[10.,  9., 10.],
          [ 0.,  8.,  9.],
          [ 

  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
wandb: Network error (ConnectionError), entering retry loop.
wandb: Network error (ConnectionError), entering retry loop.
wandb: Network error (ConnectionError), entering retry loop.


In [None]:
# Print every key of the second item of the first model_list entry, followed
# by the dtype of the value stored under that key.
first_entry_tensors = model_list[0][1]
for name in first_entry_tensors:
    print(name)
    print(first_entry_tensors[name].dtype)
#     print(model_list[0][1][name])

In [None]:
# Fraction of exactly-zero entries over every tensor in the server model's
# state_dict.
pruned_c = 0.0
total = 0.0

for name, param in server.model.state_dict().items():
    a = param.view(-1).to(device='cpu', copy=True).numpy()
    # np.count_nonzero is vectorized; the builtin sum() would walk the NumPy
    # array element by element in Python.
    pruned_c += a.size - np.count_nonzero(a)
    total += param.numel()
print(f'global model zero params: {pruned_c / total}')

In [None]:
# Experiment: write into a tensor obtained from state_dict() and print the
# model's own state_dict afterwards to see whether the write is visible there.
# NOTE(review): PyTorch documents state_dict() as returning references to the
# module's tensors, so this cell presumably modifies server.model's weights —
# confirm that is acceptable before running it on a model you want to keep.
# for name in server.model.state_dict():
#     print(name)
# The device string is hard-coded here.
server.model = server.model.to('cuda:3')
# NOTE(review): nn.Module.cpu() is documented to move the module in place, so
# the model presumably ends up on the CPU again after this line.
params = server.model.cpu().state_dict()
print(params['features.0.weight'])

print('***********'*3)

# params['features.0.weight'] = torch.zeros_like(params['features.0.weight'])
# In-place write through the tensor taken from the state_dict.
params['features.0.weight'][0][0][0] = 2.

# Fresh state_dict() call on the model, printed for comparison.
print(server.model.state_dict()['features.0.weight'])


print(params['features.0.weight'])

In [None]:
# Fraction of exactly-zero entries over every tensor in global_model's
# state_dict.
# NOTE(review): `global_model` is not assigned by any active code in the cells
# visible here (only by a commented-out line); define it before running this.
pruned_c = 0.0
total = 0.0
for name, param in global_model.state_dict().items():
    a = param.view(-1).to(device='cpu', copy=True).numpy()
    # np.count_nonzero is vectorized; the builtin sum() would walk the NumPy
    # array element by element in Python.
    pruned_c += a.size - np.count_nonzero(a)
    total += param.numel()
print(f'global model zero params: {pruned_c / total}')

In [None]:
# Fraction of zero entries over all tensors stored in global_model.mask.
prune_c = 0.
total = 0.
for name, mask in global_model.mask.items():
#     print(mask)
    flat_mask = mask.to('cpu', copy=True).view(-1).numpy()
    # Vectorized zero count instead of the builtin sum() over a NumPy array,
    # which iterates element by element in Python.
    prune_c += flat_mask.size - np.count_nonzero(flat_mask)
    total += mask.numel()

print(prune_c / total)

In [None]:
import copy

# Zero fraction of global_model's state_dict ...
pruned_c = 0.0
total = 0.0
for name, param in global_model.state_dict().items():
    a = param.view(-1).to(device='cpu', copy=True).numpy()
    # np.count_nonzero is vectorized; the builtin sum() would walk the NumPy
    # array element by element in Python.
    pruned_c += a.size - np.count_nonzero(a)
    total += param.numel()
print(f'global model zero params: {pruned_c / total}')

# ... and the same measurement on a deep copy of the model (presumably to
# check that copying preserves the zeroed entries).
gcp_model = copy.deepcopy(global_model)

prune_c = 0.
total = 0.
for name, params in gcp_model.state_dict().items():
#     print(name)
    flat_params = params.to('cpu', copy=True).view(-1).numpy()
    prune_c += flat_params.size - np.count_nonzero(flat_params)
    total += params.numel()

print(prune_c / total)

# print('*'*10)
# for name, mask in global_model.mask.items():
#     print(name)

In [None]:
# for name in aggregated_masks:
#     print(name)
#     print(aggregated_masks[name].dtype)

def count_mask(state_dict):
    """Return the fraction of zero-valued entries over all tensors in a mapping.

    Args:
        state_dict: mapping from a name to a torch tensor.

    Returns:
        ``1 - nonzero / total``. The value is a tensor because the non-zero
        count is accumulated from ``torch.count_nonzero`` results.
    """
    kept = 0.
    n_entries = 0.
    for tensor in state_dict.values():
        kept = kept + torch.count_nonzero(tensor)
        n_entries = n_entries + tensor.numel()
    zero_fraction = 1 - kept / n_entries
    return zero_fraction
# NOTE(review): `a`, `aggregated_masks` and `cl_mask_prarms` are not assigned
# by any active code in this cell. The commented-out line below looks like the
# intended definition of `a`; confirm where the other two come from before
# running.
# a = aggregated_masks['features.4.weight']
print(a.shape)
# print(torch.count_nonzero(a))
# Zero fraction over all tensors in aggregated_masks.
print(count_mask(aggregated_masks))
# 1 where a >= 1, else 0. `mask` is only used by the commented-out print below.
mask = torch.where(a>=1, 1, 0)
# print(torch.count_nonzero(mask))
# Zero fraction over all tensors in cl_mask_prarms.
print(count_mask(cl_mask_prarms))

In [None]:
import copy
import torch

# Toy check of weighted mask normalization with two clients and one "layer".
# keep_masks maps client id -> list of per-layer 0/1 masks; model_list holds
# [weight, parameters] per client.
# NOTE: this cell rebinds `model_list` and `masks`, names that other cells of
# this notebook also assign.
keep_masks = {0: [torch.Tensor([0, 1, 0, 1, 1])], 1: [torch.Tensor([1, 1, 0, 0, 1])]}

model_list = [[2, torch.Tensor([1,2,3,4,5])], [3, torch.Tensor([1,2,3,4,5])]]

# Scale each client's masks by that client's weight, model_list[c][0].
masks = copy.deepcopy(keep_masks)
for c, m in masks.items():
    for i in range(len(m)):
        m[i] *= model_list[c][0]

print(masks)

# Add every other client's weighted masks into client 0's slot, so that
# total_masks[0] holds the per-position sum over all clients.
total_masks = copy.deepcopy(masks)
for c, m in total_masks.items():
    if c != 0:
        for i in range(len(m)):
            total_masks[0][i] += m[i]

print(total_masks)

# Normalize by the per-position total. Positions where the total is 0 give
# 0/0 = NaN, which is replaced by 0.
for c, m in masks.items():
    for i in range(len(m)):
        m[i] /= total_masks[0][i]
        masks[c][i] = torch.where(torch.isnan(m[i]), torch.full_like(m[i], 0), m[i])
print(masks)

# masks[0] is a Python list with one tensor per layer; multiplying a tensor by
# that list raises TypeError, so the layer has to be indexed explicitly.
model_list[0][1] *= masks[0][0]
# print(model_list[0][1])

In [None]:
# Scratch experiment: what does `*=` do to a Python list of tensors?
import numpy as np

list_tensor = [torch.Tensor([1,2]), torch.Tensor([2,3])]
# Stack the two tensors into a single new tensor.
t = torch.stack(list_tensor)
# `*=` on a Python list repeats the list (it now holds four references); it
# does not multiply the tensors by 2.
list_tensor *= 2
# Bare expression: shown as the cell's output.
list_tensor
# print(t)
# a = torch.from_numpy(np.asarray(list_tensor))
# print(a)
# torch.cat(list_tensor, 1)

In [None]:
# Build two tensors, cast both to float64, then compare and display the dtype.
a = torch.Tensor([2,3]).double()
b = torch.Tensor([1,2]).double()
same_dtype = a.dtype == b.dtype
print(same_dtype)
b.dtype

In [26]:

# Scratch experiment: take one SGD step after the gradients have been zeroed,
# so any change in the weights comes from weight_decay alone.
model = nn.Sequential(
          nn.Linear(6, 2, bias=False),
          nn.Sigmoid(),
        )
# Named `sample_input` rather than `input` so the builtin input() is not
# shadowed for the rest of the notebook session.
sample_input = torch.randn(6)
target = torch.randn(2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, weight_decay=0.1)
# optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()



output = model(sample_input)
loss = criterion(output, target)
loss.backward()

print(list(model.parameters())[0].grad)
# Keep the gradient tensors and fill them with zeros. With set_to_none=True
# (the default in recent PyTorch releases) .grad would become None and step()
# would skip the parameter, so the weight-decay-only update would not happen.
optimizer.zero_grad(set_to_none=False)
print(list(model.parameters())[0].grad)


print("before step: ", list(model.parameters()))
optimizer.step()
print("after step: ", list(model.parameters()))

tensor([[ 0.0406,  0.3143, -0.1377,  0.1393, -0.3008,  0.2192],
        [ 0.0291,  0.2250, -0.0986,  0.0997, -0.2153,  0.1569]])
tensor([[0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.]])
before step:  [Parameter containing:
tensor([[-0.2029, -0.3128, -0.3821, -0.3446, -0.0828,  0.2239],
        [ 0.2207, -0.3937,  0.2547, -0.3195, -0.0863, -0.1655]],
       requires_grad=True)]
after step:  [Parameter containing:
tensor([[-0.2027, -0.3125, -0.3817, -0.3442, -0.0827,  0.2237],
        [ 0.2205, -0.3933,  0.2544, -0.3191, -0.0862, -0.1654]],
       requires_grad=True)]
