In [1]:
from scipy import test
import torch
import torch.cuda
from torch import nn
from torch.nn import functional as F
import argparse
import gc
import itertools
import numpy as np
import os
import sys
import time
import pickle
from copy import deepcopy

from tqdm import tqdm
import warnings
import copy

import wandb

from datasets import get_dataset
from models.models import all_models

from client import Client
from utils import *

import fedsnip as fedsnip_obj

# Notebook-level NumPy random generator. No seed is passed to default_rng() here.
# NOTE(review): nothing else in this notebook reads `rng` — confirm whether it is still needed.
rng = np.random.default_rng()

def device_list(x):
    """Parse the ``--device`` command-line value into a list of devices.

    Args:
        x: Either the string ``'cpu'`` (matched case-insensitively, with
            surrounding whitespace ignored) or a comma-separated list of
            integer device indices such as ``'0'`` or ``'0,1,3'``.

    Returns:
        ``['cpu']`` for the CPU case, otherwise a list of ``int`` indices in
        the order they were given.

    Raises:
        ValueError: If any comma-separated entry is not an integer.
    """
    spec = x.strip()
    # Normalise the keyword so 'CPU' or ' cpu ' is not fed to int() below.
    if spec.lower() == 'cpu':
        return ['cpu']
    return [int(index) for index in spec.split(',')]

  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)


In [7]:
# Command-line interface for the federated pruning experiment. The parser is
# built here and fed an explicit argument list in the next cell.
parser = argparse.ArgumentParser()

# --- Federated training setup ---
parser.add_argument('--eta', type=float, help='learning rate', default=0.01)
parser.add_argument('--clients', type=int, help='number of clients per round', default=20)
parser.add_argument('--rounds', type=int, help='number of global rounds', default=400)
parser.add_argument('--epochs', type=int, help='number of local epochs', default=10)
parser.add_argument('--dataset', type=str, choices=('mnist', 'emnist', 'cifar10', 'cifar100'),
                    default='mnist', help='Dataset to use')
parser.add_argument('--distribution', type=str, choices=('dirichlet', 'lotteryfl', 'iid', 'classic_iid'), default='dirichlet',
                    help='how should the dataset be distributed?')
parser.add_argument('--beta', type=float, default=0.1, help='Beta parameter (unbalance rate) for Dirichlet distribution')
parser.add_argument('--total-clients', type=int, help='split the dataset between this many clients. Ignored for EMNIST.', default=400)
parser.add_argument('--min-samples', type=int, default=0, help='minimum number of samples required to allow a client to participate')
parser.add_argument('--samples-per-client', type=int, default=20, help='samples to allocate to each client (per class, for lotteryfl, or per client, for iid)')
parser.add_argument('--prox', type=float, default=0, help='coefficient to proximal term (i.e. in FedProx)')

# --- Local optimisation, evaluation and I/O ---
parser.add_argument('--batch-size', type=int, default=32,
                    help='local client batch size')
parser.add_argument('--l2', default=1e-5, type=float, help='L2 regularization strength')
parser.add_argument('--momentum', default=0.9, type=float, help='Local client SGD momentum parameter')
parser.add_argument('--cache-test-set', default=False, action='store_true', help='Load test sets into memory')
parser.add_argument('--cache-test-set-gpu', default=False, action='store_true', help='Load test sets into GPU memory')
parser.add_argument('--test-batches', default=0, type=int, help='Number of minibatches to test on, or 0 for all of them')
parser.add_argument('--eval-every', default=1, type=int, help='Evaluate on test set every N rounds')
parser.add_argument('--device', default='0', type=device_list, help='Device to use for compute. Use "cpu" to force CPU. Otherwise, separate with commas to allow multi-GPU.')
# store_false: args.eval is True unless --no-eval is passed.
parser.add_argument('--no-eval', default=True, action='store_false', dest='eval')
parser.add_argument('-o', '--outfile', default='output.log', type=argparse.FileType('a', encoding='ascii'))

parser.add_argument('--clip_grad', default=False, action='store_true', dest='clip_grad')

# --- Model and pruning ---
# The help strings below were copy-pasted from other options in the original
# ('Dataset to use' / 'local client batch size'); they now describe their own flag.
parser.add_argument('--model', type=str, choices=('VGG11_BN', 'VGG_SNIP', 'CNNNet', 'CIFAR10Net'),
                    default='VGG11_BN', help='Model architecture to use')

parser.add_argument('--prune_strategy', type=str, choices=('None', 'SNIP', 'SNAP', 'random_masks', 'Iter-SNIP'),
                    default='None', help='Pruning strategy to use')
parser.add_argument('--prune_at_first_round', default=False, action='store_true', dest='prune_at_first_round')
parser.add_argument('--keep_ratio', type=float, default=0.0,
                    help='fraction of parameters to keep when pruning')
parser.add_argument('--prune_vote', type=int, default=1,
                    help='pruning vote threshold')

parser.add_argument('--single_shot_pruning', default=False, action='store_true', dest='single_shot_pruning')

parser.add_argument('--partition_method', type=str, default='homo', metavar='N',
                    help='how to partition the dataset on local workers')

parser.add_argument('--partition_alpha', type=float, default=0.5, metavar='PA',
                    help='partition alpha (default: 0.5)')

parser.add_argument('--target_keep_ratio', default=0.1, type=float, help='server target keep ratio')

# No default is given, so args.num_pruning_steps is None unless the flag is passed.
parser.add_argument('--num_pruning_steps', type=int, help='total number of pruning steps')
parser.add_argument('--pruning_steps_decay_mode', type=str, default='linear', choices=('linear', 'exp'), help='pruning steps decay mode')

_StoreAction(option_strings=['--pruning_steps_decay_mode'], dest='pruning_steps_decay_mode', nargs=None, const=None, default='linear', type=<class 'str'>, choices=('linear', 'exp'), help='pruning steps decay mode', metavar=None)

In [8]:
# Notebook stand-in for the command line: one option per line, split on
# whitespace into the token list that parse_args() expects.
args = parser.parse_args(args="""
    --dataset cifar10
    --eta 0.01
    --device 3
    --distribution classic_iid
    --total-clients 10
    --clients 10
    --batch-size 64
    --rounds 100
    --model CNNNet
    --prune_strategy Iter-SNIP
    --epochs 2
    --keep_ratio 0.1
    --prune_vote 1
    --prune_at_first_round
    --single_shot_pruning
    --partition_method hetero
    --partition_alpha 0.5
    --target_keep_ratio 0.2
    --num_pruning_steps 2
    --pruning_steps_decay_mode linear
""".split())

In [9]:
# Hand the parsed arguments to fedsnip's main().
# NOTE(review): a later cell calls next(run), so main() is presumably a
# generator/iterator that yields debug snapshots — confirm in fedsnip.
run = fedsnip_obj.main(args)

Fetching dataset...
INFO:root:*********partition data***************


Files already downloaded and verified
Files already downloaded and verified


INFO:root:N = 50000
INFO:root:traindata_cls_counts = {0: {1: 1088, 2: 6, 3: 250, 4: 298, 5: 2944, 6: 257, 7: 1, 8: 94, 9: 1033}, 1: {0: 314, 1: 128, 2: 594, 3: 44, 4: 977, 5: 650, 6: 16, 7: 1205, 8: 19, 9: 222}, 2: {0: 1059, 1: 1042, 2: 3, 3: 7, 4: 10, 5: 142, 6: 1428, 7: 958, 8: 432}, 3: {0: 2, 1: 683, 2: 423, 3: 111, 4: 61, 7: 43, 8: 849, 9: 2277}, 4: {0: 1062, 1: 34, 2: 1832, 3: 1, 4: 277, 5: 460, 6: 69, 7: 90, 8: 990, 9: 433}, 5: {0: 397, 1: 5, 2: 1416, 3: 1176, 4: 906, 5: 35, 6: 220, 7: 215, 8: 175, 9: 593}, 6: {0: 210, 1: 434, 2: 10, 3: 818, 4: 259, 6: 1312, 7: 718, 8: 2183}, 7: {0: 265, 1: 1329, 2: 11, 3: 18, 4: 521, 5: 426, 6: 22, 7: 1624, 8: 156, 9: 407}, 8: {0: 1380, 1: 64, 2: 206, 3: 213, 4: 1683, 5: 1, 6: 47, 7: 145, 8: 102, 9: 35}, 9: {0: 311, 1: 193, 2: 499, 3: 2362, 4: 8, 5: 342, 6: 1629, 7: 1}}


10000******************
Files already downloaded and verified
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:train_dl_global number = 782
INFO:root:test_dl_global number = 157
INFO:root:client_idx = 0, local_sample_number = 5971


Files already downloaded and verified
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:client_idx = 0, batch_num_train_local = 94, batch_num_test_local = 16
INFO:root:client_idx = 1, local_sample_number = 4169


Files already downloaded and verified
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:client_idx = 1, batch_num_train_local = 66, batch_num_test_local = 16
INFO:root:client_idx = 2, local_sample_number = 5081


Files already downloaded and verified
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:client_idx = 2, batch_num_train_local = 80, batch_num_test_local = 16
INFO:root:client_idx = 3, local_sample_number = 4449


Files already downloaded and verified
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:client_idx = 3, batch_num_train_local = 70, batch_num_test_local = 16
INFO:root:client_idx = 4, local_sample_number = 5248


Files already downloaded and verified
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:client_idx = 4, batch_num_train_local = 82, batch_num_test_local = 16
INFO:root:client_idx = 5, local_sample_number = 5138


Files already downloaded and verified
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:client_idx = 5, batch_num_train_local = 81, batch_num_test_local = 16
INFO:root:client_idx = 6, local_sample_number = 5944


Files already downloaded and verified
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:client_idx = 6, batch_num_train_local = 93, batch_num_test_local = 16
INFO:root:client_idx = 7, local_sample_number = 4779


Files already downloaded and verified
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:client_idx = 7, batch_num_train_local = 75, batch_num_test_local = 16
INFO:root:client_idx = 8, local_sample_number = 3876


Files already downloaded and verified
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:client_idx = 8, batch_num_train_local = 61, batch_num_test_local = 16
INFO:root:client_idx = 9, local_sample_number = 5345


Files already downloaded and verified
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:client_idx = 9, batch_num_train_local = 84, batch_num_test_local = 16
Initializing clients...


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)


[34m[1mwandb[0m: wandb version 0.13.3 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


server model param size: 69007680
keep_ratio_steps: [0.55, 0.09999999999999998]


ZeroDivisionError: division by zero

In [None]:
# import os
# print('...')
# name = input()
# print('++++')

In [None]:
# while True:
#     debug_info = next(run)
#     print(debug_info.msg)
    
    
    
#     input()
# Advance the run by one step and keep the object attached to the yielded item.
debug_info = next(run)
# NOTE(review): a later cell does `server = masks` and then reads `server.masks`,
# so `.obj` is presumably the server object at this step — confirm.
masks = debug_info.obj
# global_model = debug_info.obj
# aggregated_masks = debug_info.obj[0]
# cl_mask_prarms = debug_info.obj[1]

In [None]:
# NOTE(review): `server` is only assigned in the following cell (`server = masks`),
# so on a fresh kernel this cell depends on out-of-order execution.
masks = server.masks
# NOTE(review): a later cell flattens every mask before torch.cat; presumably the
# masks have different shapes and this un-flattened call fails — confirm.
torch.cat(masks)

In [None]:
# Re-bind the object fetched from the run under the name `server`; the cells
# below read `server.masks` and `server.model`.
server = masks
def count_vote(masks, vote):
    """Print how many mask entries are at or above a vote threshold.

    Args:
        masks: Iterable of tensors to inspect.
        vote: Threshold; an entry counts as kept when its value is ``>= vote``.

    Returns:
        None. Prints one line with the kept count, the total number of
        entries and their ratio (``nan`` when there are no entries).
    """
    total = 0
    kept = 0
    for mask in masks:
        total += mask.numel()
        # Accumulate as Python ints: a float32 tensor accumulator stops being
        # exact above 2**24 entries.
        kept += int((mask >= vote).sum().item())

    # Guard the ratio so an empty mask list does not raise ZeroDivisionError.
    pct = kept / total if total else float('nan')
    print("keeped: {}, total: {}, pct: {}".format(kept, total, pct))

In [None]:
# Sweep the vote threshold from 0 through 9 and report the kept share for each.
for i in range(10):
    print(f'vote {i}')
    count_vote(server.masks, i)

In [None]:
masks = server.masks
for i in range(len(masks)):
    print(masks[i].shape)
# Flatten every mask into one vector so the top-k below is taken across all of
# them at once rather than per tensor.
flat_masks = torch.cat([m.flatten() for m in masks])
# Keep the top 10% of entries by value.
keep_num = len(flat_masks) * 0.1
threshold, indices = flat_masks.topk(int(keep_num))
# Binary mask over the flattened vector: 1 at the selected positions, 0 elsewhere.
global_masks = torch.zeros_like(flat_masks)
global_masks[indices] = 1
print(indices.sort())
# len(indices)
print(global_masks[:25])
print(global_masks.sum())

# Cut the flat binary mask back into pieces shaped like the original masks.
idx = 0
ms = []
for i in range(len(masks)):
    m = global_masks[idx:idx+masks[i].numel()].reshape(masks[i].size())
    idx += masks[i].numel()
    # The original created `ms` but never filled it; collect the reshaped masks.
    ms.append(m)
    print(m.shape)
    print(m)


In [None]:
# Three 1-D tensors of different lengths, joined end to end into one vector.
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5])
c = torch.tensor([6, 7, 8, 9])
ts = torch.cat((a.flatten(), b.flatten(), c.flatten()))
print(ts)


In [None]:
# Fraction of exactly-zero entries over every tensor in the server model's state_dict.
pruned_c = 0.0
total = 0.0

for name, param in server.model.state_dict().items():
    a = param.view(-1).to(device='cpu', copy=True).numpy()
    # Vectorised count of zeros; the builtin sum() over a NumPy array walks the
    # elements one by one in Python.
    pruned_c += np.count_nonzero(a == 0)
    total += param.numel()
print(f'global model zero params: {pruned_c / total}')

In [None]:
# for name in server.model.state_dict():
#     print(name)
# Experiment: does writing into a tensor taken from state_dict() change the model?
# The device string is hard-coded; it matches the '--device', '3' argument used above.
server.model = server.model.to('cuda:3')
# NOTE(review): .cpu() is called on the model itself before state_dict(), so the
# model presumably no longer sits on cuda:3 after this line — confirm that is intended.
params = server.model.cpu().state_dict()
print(params['features.0.weight'])

print('***********'*3)

# params['features.0.weight'] = torch.zeros_like(params['features.0.weight'])
# Overwrite one slice of the first weight tensor through the state_dict handle.
params['features.0.weight'][0][0][0] = 2.

# Print the model's own view first, then the handle, to compare the two.
print(server.model.state_dict()['features.0.weight'])


print(params['features.0.weight'])

In [None]:
# Fraction of exactly-zero entries over every tensor in global_model's state_dict.
# NOTE(review): the only assignment of `global_model` visible in this notebook is
# commented out (`# global_model = debug_info.obj`) — confirm where it comes from.
pruned_c = 0.0
total = 0.0
for name, param in global_model.state_dict().items():
    a = param.view(-1).to(device='cpu', copy=True).numpy()
    # Vectorised count of zeros; the builtin sum() over a NumPy array walks the
    # elements one by one in Python.
    pruned_c += np.count_nonzero(a == 0)
    total += param.numel()
print(f'global model zero params: {pruned_c / total}')

In [None]:
# Fraction of zero entries across the tensors stored in global_model.mask.
prune_c = 0.
total = 0.
for name,mask in global_model.mask.items():
#     print(mask)
    # Vectorised count of zeros instead of the builtin sum() over a NumPy array.
    prune_c += np.count_nonzero(mask.to('cpu', copy=True).view(-1).numpy() == 0)
    total += mask.numel()

print(prune_c / total)

In [None]:
# NOTE(review): this always raises ZeroDivisionError. Presumably a deliberate
# tripwire to halt "Run All" before the scratch cells below — confirm, or
# replace it with an explicit raise carrying a message.
1 / 0

In [None]:
import copy

# Zero fraction of the original model's state_dict ...
pruned_c = 0.0
total = 0.0
for name, param in global_model.state_dict().items():
    a = param.view(-1).to(device='cpu', copy=True).numpy()
    # Vectorised count of zeros; the builtin sum() over a NumPy array walks the
    # elements one by one in Python.
    pruned_c += np.count_nonzero(a == 0)
    total += param.numel()
print(f'global model zero params: {pruned_c / total}')

# ... compared with the same measurement on a deep copy of the model.
gcp_model = copy.deepcopy(global_model)



prune_c = 0.
total = 0.
for name, params in gcp_model.state_dict().items():
#     print(name)
    prune_c += np.count_nonzero(params.to('cpu', copy=True).view(-1).numpy() == 0)
    total += params.numel()

print(prune_c / total)
    
# print('*'*10)
# for name, mask in global_model.mask.items():
#     print(name)

In [None]:
# for name in aggregated_masks:
#     print(name)
#     print(aggregated_masks[name].dtype)

def count_mask(state_dict):
    """Return the fraction of zero entries across all tensors in a mapping.

    Args:
        state_dict: Mapping from names to tensors.

    Returns:
        A 0-dim CPU tensor holding ``1 - nonzero / total`` over every tensor
        in the mapping, or ``nan`` when the mapping holds no elements.
    """
    non_zero = 0
    total = 0
    for tensor in state_dict.values():
        # Accumulate as Python ints: a float32 tensor accumulator stops being
        # exact above 2**24 entries.
        non_zero += int(torch.count_nonzero(tensor).item())
        total += tensor.numel()
    # Guard the ratio so an empty mapping does not raise ZeroDivisionError.
    if total == 0:
        return torch.tensor(float('nan'))
    return torch.tensor(1 - non_zero / total)
# a = aggregated_masks['features.4.weight']
# NOTE(review): with the assignment above commented out, `a` is whatever an
# earlier cell left in the kernel — confirm which value is intended here.
print(a.shape)
# print(torch.count_nonzero(a))
# NOTE(review): `aggregated_masks` and `cl_mask_prarms` are only assigned in
# commented-out lines of an earlier cell, so this cell presumably relies on
# leftover kernel state — confirm.
print(count_mask(aggregated_masks))
# Binarise `a`: 1 where the value is at least 1, else 0.
mask = torch.where(a>=1, 1, 0)
# print(torch.count_nonzero(mask))
print(count_mask(cl_mask_prarms))

In [None]:
import copy
import torch

# Toy example: two clients (0 and 1), each with a single 5-entry keep mask.
keep_masks = {0: [torch.Tensor([0, 1, 0, 1, 1])], 1: [torch.Tensor([1, 1, 0, 0, 1])]}

# One [weight, parameter tensor] pair per client.
model_list = [[2, torch.Tensor([1,2,3,4,5])], [3, torch.Tensor([1,2,3,4,5])]]

# Scale every client's mask by that client's weight.
masks = copy.deepcopy(keep_masks)
for c, m in masks.items():
    for i in range(len(m)):
        m[i] *= model_list[c][0]

print(masks)

# Accumulate all weighted masks into client 0's slot of a separate copy.
total_masks = copy.deepcopy(masks)
for c,m in total_masks.items():
    if c!= 0:
        for i in range(len(m)):
            total_masks[0][i] += m[i]

print(total_masks)

# Normalise each weighted mask by the accumulated total. Positions that no
# client keeps divide 0 by 0 and produce NaN, which is replaced by 0.
for c, m in masks.items():
    for i in range(len(m)):
        m[i] /= total_masks[0][i]
        masks[c][i] = torch.where(torch.isnan(m[i]), torch.full_like(m[i], 0), m[i])
#     print(torch.isnan(m))
print(masks)
#     m /= sum(model_list[i][0] for i in range(len(model_list)))

# print(masks[0] * model_list[0][1])
# masks[0] is a Python list of tensors; multiply by the tensor inside it. The
# original multiplied the parameter tensor by the list itself.
model_list[0][1] *= masks[0][0]
# print(model_list[0][1])
# print(model_list[0][1])

In [None]:
import numpy as np

# Two small float tensors; `t` stacks them before the list itself is changed.
list_tensor = [torch.Tensor(row) for row in ([1, 2], [2, 3])]
t = torch.stack(list_tensor)
# Repeat the list's items in place: this doubles the list, it does not scale the tensors.
list_tensor.extend(list_tensor[:])
list_tensor
# print(t)
# a = torch.from_numpy(np.asarray(list_tensor))
# print(a)
# torch.cat(list_tensor, 1)

In [None]:
# Two float tensors converted to double precision, then a dtype comparison.
a = torch.Tensor([2, 3]).double()
b = torch.Tensor([1, 2]).double()
same_dtype = a.dtype == b.dtype
print(same_dtype)
b.dtype

In [None]:

# Experiment: what does optimizer.step() do when the gradients were cleared first?
# Model: a bias-free 6 -> 2 linear layer followed by a sigmoid.
model = nn.Sequential(
          nn.Linear(6, 2, bias=False),
          nn.Sigmoid(),
        )
# Random, unseeded data, so the printed numbers differ between runs.
# Note: `input` shadows the Python builtin of the same name for the rest of the kernel session.
input = torch.randn(6)
target = torch.randn(2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, weight_decay=0.1)
# optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()



# One forward/backward pass to populate the gradient of the linear layer's weight.
output = model(input)
loss = criterion(output, target)
loss.backward()

# Print the weight gradient before and after clearing it.
print(list(model.parameters())[0].grad)
optimizer.zero_grad() 
print(list(model.parameters())[0].grad)


# NOTE(review): step() runs after the gradients were cleared just above. With
# weight_decay=0.1, whether the parameters still change presumably depends on
# whether zero_grad() left zeros or None — confirm against the installed PyTorch version.
print("before step: ", list(model.parameters()))
optimizer.step()
print("after step: ", list(model.parameters()))

In [None]:
# Linear schedule: after step k (1-based) the keep ratio has dropped by
# k / num_pruning_steps of the way from 1 down to target_keep_ratio.
target_keep_ratio = 0.2
num_pruning_steps = 2

keep_ratio_steps = list(map(
    lambda step: 1 - (step * (1 - target_keep_ratio) / num_pruning_steps),
    range(1, num_pruning_steps + 1),
))
keep_ratio_steps