In [1]:
from scipy import test
import torch
import torch.cuda
from torch import nn
from torch.nn import functional as F
import argparse
import gc
import itertools
import numpy as np
import os
import sys
import time
import pickle
from copy import deepcopy

from tqdm import tqdm
import warnings
import copy

import wandb

from datasets import get_dataset
from models.models import all_models

from client import Client
from utils import *

import fedsnip as fedsnip_obj

rng = np.random.default_rng()

def device_list(x):
    """Parse the value of ``--device`` into a list of compute devices.

    Args:
        x: Either ``'cpu'`` (any letter case, surrounding whitespace ignored)
            or a comma-separated list of integer device indices, e.g. ``'0,1'``.
            Blank entries such as the one produced by a trailing comma are skipped.

    Returns:
        ``['cpu']`` for the CPU case, otherwise a list of ints.

    Raises:
        ValueError: if an entry is not an integer or no device is given at all.
            argparse reports a ValueError from a ``type=`` callable as a usage error.
    """
    spec = x.strip()
    if spec.lower() == 'cpu':
        return ['cpu']
    devices = [int(token) for token in spec.split(',') if token.strip()]
    if not devices:
        # Fail loudly instead of handing an empty device list to the caller.
        raise ValueError('no device specified: {!r}'.format(x))
    return devices

  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)


In [2]:
# Command-line interface for the federated pruning experiment.
parser = argparse.ArgumentParser()

# --- Federated training setup ---
parser.add_argument('--eta', type=float, help='learning rate', default=0.01)
parser.add_argument('--clients', type=int, help='number of clients per round', default=20)
parser.add_argument('--rounds', type=int, help='number of global rounds', default=400)
parser.add_argument('--epochs', type=int, help='number of local epochs', default=10)
parser.add_argument('--dataset', type=str, choices=('mnist', 'emnist', 'cifar10', 'cifar100'),
                    default='mnist', help='Dataset to use')
parser.add_argument('--distribution', type=str, choices=('dirichlet', 'lotteryfl', 'iid', 'classic_iid'), default='dirichlet',
                    help='how should the dataset be distributed?')
parser.add_argument('--beta', type=float, default=0.1, help='Beta parameter (unbalance rate) for Dirichlet distribution')
parser.add_argument('--total-clients', type=int, help='split the dataset between this many clients. Ignored for EMNIST.', default=400)
parser.add_argument('--min-samples', type=int, default=0, help='minimum number of samples required to allow a client to participate')
parser.add_argument('--samples-per-client', type=int, default=20, help='samples to allocate to each client (per class, for lotteryfl, or per client, for iid)')
parser.add_argument('--prox', type=float, default=0, help='coefficient to proximal term (i.e. in FedProx)')

# --- Local optimisation, evaluation and I/O ---
parser.add_argument('--batch-size', type=int, default=32,
                    help='local client batch size')
parser.add_argument('--l2', default=1e-5, type=float, help='L2 regularization strength')
parser.add_argument('--momentum', default=0.9, type=float, help='Local client SGD momentum parameter')
parser.add_argument('--cache-test-set', default=False, action='store_true', help='Load test sets into memory')
parser.add_argument('--cache-test-set-gpu', default=False, action='store_true', help='Load test sets into GPU memory')
parser.add_argument('--test-batches', default=0, type=int, help='Number of minibatches to test on, or 0 for all of them')
parser.add_argument('--eval-every', default=1, type=int, help='Evaluate on test set every N rounds')
parser.add_argument('--device', default='0', type=device_list, help='Device to use for compute. Use "cpu" to force CPU. Otherwise, separate with commas to allow multi-GPU.')
# Stored as args.eval: True by default, False when --no-eval is passed.
parser.add_argument('--no-eval', default=True, action='store_false', dest='eval')
parser.add_argument('-o', '--outfile', default='output.log', type=argparse.FileType('a', encoding='ascii'))

parser.add_argument('--clip_grad', default=False, action='store_true', dest='clip_grad')

# --- Model and pruning ---
# The help texts in this section used to be copy-pasted from unrelated options.
parser.add_argument('--model', type=str, choices=('VGG11_BN', 'VGG_SNIP', 'CNNNet', 'CIFAR10Net'),
                    default='VGG11_BN', help='Model architecture to use')

parser.add_argument('--prune_strategy', type=str, choices=('None', 'SNIP', 'SNAP', 'random_masks', 'Iter-SNIP'),
                    default='None', help='Pruning strategy to use')
parser.add_argument('--prune_at_first_round', default=False, action='store_true', dest='prune_at_first_round')
# NOTE(review): wording of the two help texts below is inferred from the option names — confirm against fedsnip.
parser.add_argument('--keep_ratio', type=float, default=0.0,
                    help='keep ratio used when pruning (fraction of parameters to keep)')
parser.add_argument('--prune_vote', type=int, default=1,
                    help='vote threshold used when merging pruning masks')

parser.add_argument('--single_shot_pruning', default=False, action='store_true', dest='single_shot_pruning')

parser.add_argument('--partition_method', type=str, default='homo', metavar='N',
                        help='how to partition the dataset on local workers')

parser.add_argument('--partition_alpha', type=float, default=0.5, metavar='PA',
                    help='partition alpha (default: 0.5)')

parser.add_argument('--target_keep_ratio', default=0.1, type=float, help='server target keep ratio')

# No default is given, so args.num_pruning_steps is None unless the flag is passed.
parser.add_argument('--num_pruning_steps', type=int, help='total number of pruning steps')
parser.add_argument('--pruning_steps_decay_mode', type=str, default='linear', choices=('linear', 'exp'), help='pruning steps decay mode')

_StoreAction(option_strings=['--pruning_steps_decay_mode'], dest='pruning_steps_decay_mode', nargs=None, const=None, default='linear', type=<class 'str'>, choices=('linear', 'exp'), help='pruning steps decay mode', metavar=None)

In [3]:
# Argument vector for this experiment, one option per line, in the same order
# as it would be typed on the command line.
cli_args = [
    # data and federation
    '--dataset', 'cifar10',
    '--eta', '0.01',
    '--device', '3',
    '--distribution', 'classic_iid',
    '--total-clients', '10',
    '--clients', '10',
    '--batch-size', '64',
    '--rounds', '100',
    # model and pruning
    '--model', 'CNNNet',
    '--prune_strategy', 'Iter-SNIP',
    '--epochs', '2',
    '--keep_ratio', '0.1',
    '--prune_vote', '1',
    '--prune_at_first_round',
    '--single_shot_pruning',
    # partitioning and pruning schedule
    '--partition_method', 'hetero',
    '--partition_alpha', '0.5',
    '--target_keep_ratio', '0.2',
    '--num_pruning_steps', '2',
    '--pruning_steps_decay_mode', 'linear',
]
args = parser.parse_args(args=cli_args)

In [4]:
# Start the experiment with the parsed arguments and keep the returned handle.
# NOTE(review): a later cell calls next(run), so main() presumably returns an iterator/generator — TODO confirm.
run = fedsnip_obj.main(args)

Fetching dataset...
INFO:root:*********partition data***************


Files already downloaded and verified
Files already downloaded and verified


INFO:root:N = 50000
INFO:root:traindata_cls_counts = {0: {0: 953, 1: 142, 2: 141, 3: 75, 4: 695, 5: 819, 7: 2482}, 1: {0: 16, 1: 43, 2: 902, 3: 1650, 4: 86, 5: 182, 7: 693, 8: 110, 9: 6}, 2: {0: 9, 1: 8, 2: 290, 3: 769, 4: 841, 5: 283, 6: 119, 7: 1044, 8: 1014, 9: 58}, 3: {0: 395, 1: 1200, 2: 48, 3: 68, 4: 896, 5: 681, 6: 90, 7: 17, 8: 1351, 9: 301}, 4: {0: 504, 1: 2917, 2: 570, 3: 721, 4: 121, 5: 356}, 5: {0: 1262, 1: 71, 2: 325, 3: 119, 4: 1560, 5: 14, 6: 1, 7: 85, 8: 366}, 6: {0: 9, 1: 273, 2: 1657, 3: 40, 4: 1, 5: 130, 6: 1911, 7: 21, 8: 1160}, 7: {0: 722, 1: 3, 2: 281, 3: 738, 4: 22, 5: 974, 6: 624, 7: 1, 8: 82, 9: 878}, 8: {0: 1127, 1: 153, 2: 680, 3: 500, 4: 698, 5: 1139, 6: 92, 7: 23, 8: 1, 9: 3665}, 9: {0: 3, 1: 190, 2: 106, 3: 320, 4: 80, 5: 422, 6: 2163, 7: 634, 8: 916, 9: 92}}


10000******************
Files already downloaded and verified
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:train_dl_global number = 782
INFO:root:test_dl_global number = 157
INFO:root:client_idx = 0, local_sample_number = 5307


Files already downloaded and verified
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:client_idx = 0, batch_num_train_local = 83, batch_num_test_local = 16
INFO:root:client_idx = 1, local_sample_number = 3688


Files already downloaded and verified
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:client_idx = 1, batch_num_train_local = 58, batch_num_test_local = 16
INFO:root:client_idx = 2, local_sample_number = 4435


Files already downloaded and verified
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:client_idx = 2, batch_num_train_local = 70, batch_num_test_local = 16
INFO:root:client_idx = 3, local_sample_number = 5047


Files already downloaded and verified
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:client_idx = 3, batch_num_train_local = 79, batch_num_test_local = 16
INFO:root:client_idx = 4, local_sample_number = 5189


Files already downloaded and verified
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:client_idx = 4, batch_num_train_local = 82, batch_num_test_local = 16
INFO:root:client_idx = 5, local_sample_number = 3803


Files already downloaded and verified
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:client_idx = 5, batch_num_train_local = 60, batch_num_test_local = 16
INFO:root:client_idx = 6, local_sample_number = 5202


Files already downloaded and verified
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:client_idx = 6, batch_num_train_local = 82, batch_num_test_local = 16
INFO:root:client_idx = 7, local_sample_number = 4325


Files already downloaded and verified
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:client_idx = 7, batch_num_train_local = 68, batch_num_test_local = 16
INFO:root:client_idx = 8, local_sample_number = 8078


Files already downloaded and verified
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:client_idx = 8, batch_num_train_local = 127, batch_num_test_local = 16
INFO:root:client_idx = 9, local_sample_number = 4926


Files already downloaded and verified
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:client_idx = 9, batch_num_train_local = 77, batch_num_test_local = 16
Initializing clients...
ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mslimfun[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.13.3 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


server model param size: 69007680
keep_ratio_steps: [0.6, 0.19999999999999996]
client: 9 **************
No post init specified in SNIP
client: 0 **************
No post init specified in SNIP
client: 2 **************
No post init specified in SNIP
client: 8 **************
No post init specified in SNIP
client: 6 **************
No post init specified in SNIP
client: 5 **************
No post init specified in SNIP
client: 4 **************
No post init specified in SNIP
client: 7 **************
No post init specified in SNIP
client: 3 **************
No post init specified in SNIP
client: 1 **************
No post init specified in SNIP
merge local masks
server masked 40.000009276988614% params
client: 7 **************
No post init specified in SNIP
client: 6 **************
No post init specified in SNIP
client: 0 **************
No post init specified in SNIP
client: 3 **************
No post init specified in SNIP
client: 4 **************
No post init specified in SNIP
client: 1 ************

dl_cost: 13801523.196330769; ul_cost: 13801523.196330769
torch.Size([32, 3, 5, 5])
torch.Size([64, 32, 5, 5])
torch.Size([512, 4096])
torch.Size([10, 512])
layer.shape: torch.Size([32, 3, 5, 5]); mask.shape: torch.Size([32, 3, 5, 5])
layer.shape: torch.Size([64, 32, 5, 5]); mask.shape: torch.Size([64, 32, 5, 5])
layer.shape: torch.Size([512, 4096]); mask.shape: torch.Size([512, 4096])
layer.shape: torch.Size([10, 512]); mask.shape: torch.Size([10, 512])
**********test before train*************
global model zero params: 0.7997709240478741
Test client 1: Accuracy: 0.029826464131474495; Loss: 2.3090176582336426; Total: 3688.0;
accuracy: 0.029826464131474495; loss: 2.3090176582336426
**********test before train*************
running loss: 1.9983615443624299
running loss: 1.6217773001769493
total: 0.0; time: 1.9632079601287842
client 1 finish params zeros: 0.7997709240478741
dl_cost: 13801523.196330769; ul_cost: 13801523.196330769
torch.Size([32, 3, 5, 5])
torch.Size([64, 32, 5, 5])
torch.Si

  0%|          | 0/10 [00:00<?, ?it/s]

Test client 0: Accuracy: 0.1543244868516922; Loss: 2.307075262069702; Total: 5307.0;


 10%|█         | 1/10 [00:01<00:15,  1.69s/it]

Test client 0: Accuracy: 0.11300000548362732; Loss: 2.3043153285980225; Total: 1000.0;
Test client 1: Accuracy: 0.04934924095869064; Loss: 2.2871999740600586; Total: 3688.0;


 20%|██        | 2/10 [00:03<00:13,  1.67s/it]

Test client 1: Accuracy: 0.09100000560283661; Loss: 2.3086414337158203; Total: 1000.0;
Test client 2: Accuracy: 0.06381060183048248; Loss: 2.307614326477051; Total: 4435.0;


 30%|███       | 3/10 [00:05<00:11,  1.68s/it]

Test client 2: Accuracy: 0.09100000560283661; Loss: 2.3111298084259033; Total: 1000.0;
Test client 3: Accuracy: 0.13493163883686066; Loss: 2.301328182220459; Total: 5047.0;


 40%|████      | 4/10 [00:06<00:10,  1.72s/it]

Test client 3: Accuracy: 0.11100000888109207; Loss: 2.311053514480591; Total: 1000.0;
Test client 4: Accuracy: 0.0686066672205925; Loss: 2.337390184402466; Total: 5189.0;


 50%|█████     | 5/10 [00:08<00:08,  1.72s/it]

Test client 4: Accuracy: 0.0990000069141388; Loss: 2.305586099624634; Total: 1000.0;
Test client 5: Accuracy: 0.0036813043989241123; Loss: 2.238199472427368; Total: 3803.0;


 60%|██████    | 6/10 [00:10<00:06,  1.69s/it]

Test client 5: Accuracy: 0.10700000822544098; Loss: 2.3068132400512695; Total: 1000.0;
Test client 6: Accuracy: 0.02499038726091385; Loss: 2.3249242305755615; Total: 5202.0;


 70%|███████   | 7/10 [00:11<00:05,  1.69s/it]

Test client 6: Accuracy: 0.10100000351667404; Loss: 2.305304527282715; Total: 1000.0;
Test client 7: Accuracy: 0.2252023071050644; Loss: 2.278991460800171; Total: 4325.0;


 80%|████████  | 8/10 [00:13<00:03,  1.70s/it]

Test client 7: Accuracy: 0.10100000351667404; Loss: 2.3132882118225098; Total: 1000.0;
Test client 8: Accuracy: 0.14100024104118347; Loss: 2.310753107070923; Total: 8078.0;


 90%|█████████ | 9/10 [00:15<00:01,  1.77s/it]

Test client 8: Accuracy: 0.09200000762939453; Loss: 2.3092803955078125; Total: 1000.0;
Test client 9: Accuracy: 0.08566788583993912; Loss: 2.363237142562866; Total: 4926.0;


100%|██████████| 10/10 [00:17<00:00,  1.73s/it]

Test client 9: Accuracy: 0.09400000423192978; Loss: 2.309983491897583; Total: 1000.0;
round: 0
Train/Acc : 0.0951564759016037; Train/Loss: 2.305671453475952;
Test/Acc : 0.10000000149011612; Test/Loss: 2.308539628982544;
download cost: [13801523.19633077 13801523.19633077 13801523.19633077 13801523.19633077
 13801523.19633077 13801523.19633077 13801523.19633077 13801523.19633077
 13801523.19633077 13801523.19633077]
upload cost: [13801523.19633077 13801523.19633077 13801523.19633077 13801523.19633077
 13801523.19633077 13801523.19633077 13801523.19633077 13801523.19633077
 13801523.19633077 13801523.19633077]





torch.Size([32, 3, 5, 5])
torch.Size([64, 32, 5, 5])
torch.Size([512, 4096])
torch.Size([10, 512])
layer.shape: torch.Size([32, 3, 5, 5]); mask.shape: torch.Size([32, 3, 5, 5])
layer.shape: torch.Size([64, 32, 5, 5]); mask.shape: torch.Size([64, 32, 5, 5])
layer.shape: torch.Size([512, 4096]); mask.shape: torch.Size([512, 4096])
layer.shape: torch.Size([10, 512]); mask.shape: torch.Size([10, 512])
**********test before train*************
global model zero params: 0.7997709240478741
Test client 7: Accuracy: 0.2252023071050644; Loss: 2.278991460800171; Total: 4325.0;
accuracy: 0.2252023071050644; loss: 2.278991460800171
**********test before train*************
running loss: 2.147693043245989
running loss: 1.9724745242034687
total: 0.0; time: 2.0595459938049316
client 7 finish params zeros: 0.7997709240478741
dl_cost: 13801523.196330769; ul_cost: 13801523.196330769
torch.Size([32, 3, 5, 5])
torch.Size([64, 32, 5, 5])
torch.Size([512, 4096])
torch.Size([10, 512])
layer.shape: torch.Size([3

global model zero params: 0.7997709240478741
Test client 3: Accuracy: 0.13493163883686066; Loss: 2.301328182220459; Total: 5047.0;
accuracy: 0.13493163883686066; loss: 2.301328182220459
**********test before train*************
running loss: 2.1147086092188387
running loss: 1.9188398122787476
total: 0.0; time: 2.1696622371673584
client 3 finish params zeros: 0.7997709240478741
dl_cost: 13801523.196330769; ul_cost: 13801523.196330769
server masked 80.00001855397723% params


  0%|          | 0/10 [00:00<?, ?it/s]

Test client 0: Accuracy: 0.1543244868516922; Loss: 2.3113179206848145; Total: 5307.0;


 10%|█         | 1/10 [00:01<00:14,  1.62s/it]

Test client 0: Accuracy: 0.11300000548362732; Loss: 2.3070390224456787; Total: 1000.0;
Test client 1: Accuracy: 0.04934924095869064; Loss: 2.2817492485046387; Total: 3688.0;


 20%|██        | 2/10 [00:03<00:13,  1.63s/it]

Test client 1: Accuracy: 0.09100000560283661; Loss: 2.312185287475586; Total: 1000.0;
Test client 2: Accuracy: 0.06381060183048248; Loss: 2.311671733856201; Total: 4435.0;


 30%|███       | 3/10 [00:04<00:11,  1.65s/it]

Test client 2: Accuracy: 0.09100000560283661; Loss: 2.315244436264038; Total: 1000.0;
Test client 3: Accuracy: 0.13493163883686066; Loss: 2.3036937713623047; Total: 5047.0;


 40%|████      | 4/10 [00:06<00:10,  1.69s/it]

Test client 3: Accuracy: 0.11100000888109207; Loss: 2.3154802322387695; Total: 1000.0;
Test client 4: Accuracy: 0.0686066672205925; Loss: 2.3443939685821533; Total: 5189.0;


 50%|█████     | 5/10 [00:08<00:08,  1.73s/it]

Test client 4: Accuracy: 0.0990000069141388; Loss: 2.308368682861328; Total: 1000.0;
Test client 5: Accuracy: 0.0036813043989241123; Loss: 2.2221508026123047; Total: 3803.0;


 60%|██████    | 6/10 [00:10<00:06,  1.72s/it]

Test client 5: Accuracy: 0.10700000822544098; Loss: 2.309918165206909; Total: 1000.0;


 60%|██████    | 6/10 [00:10<00:06,  1.73s/it]


KeyboardInterrupt: 

In [None]:
# import os
# print('...')
# name = input()
# print('++++')

In [None]:
# while True:
#     debug_info = next(run)
#     print(debug_info.msg)
    
    
    
#     input()
# Advance `run` by one item and keep that item's `.obj` payload for inspection.
# NOTE(review): assumes `run` is an iterator whose items expose `.obj`; what the payload holds depends on fedsnip — TODO confirm.
debug_info = next(run)
masks = debug_info.obj
# global_model = debug_info.obj
# aggregated_masks = debug_info.obj[0]
# cl_mask_prarms = debug_info.obj[1]

In [None]:
# Join every server mask into a single 1-D tensor.
# Each mask is flattened first: torch.cat needs all inputs to agree on every
# dimension except the concatenated one, and the per-layer masks logged above
# have different ranks and shapes. The cell that builds `flat_masks` does the same.
# NOTE(review): `server` is not assigned in any visible cell; it must already exist in the kernel.
masks = server.masks
torch.cat([m.flatten() for m in masks])

In [None]:
# NOTE(review): this rebinds `server` to whatever `masks` currently holds, while other cells
# read `server.masks` and `server.model` — confirm the rebinding is intended.
server = masks
def count_vote(masks, vote):
    """Print how many mask entries are at least ``vote``.

    Walks ``masks`` by index, counts the entries with value ``>= vote`` and the
    total number of entries, then prints both counts and their ratio.
    Nothing is returned.
    """
    total_entries = 0.
    kept_entries = 0.
    for position in range(len(masks)):
        current = masks[position]
        total_entries += current.numel()
        # Work on a detached copy so the caller's tensor is never touched.
        snapshot = current.clone().detach()
        reached = torch.where(snapshot >= vote, 1, 0)
        kept_entries += torch.sum(reached)

    print("keeped: {}, total: {}, pct: {}".format(kept_entries, total_entries, kept_entries/total_entries))

In [None]:
# For each threshold 0..9, report how many entries of the server masks are >= that threshold.
# NOTE(review): needs `server` to expose `.masks`; an earlier cell rebinds `server = masks` — confirm which object is meant.
for i in range(10):
    print('vote {}'.format(i))
    count_vote(server.masks, i)

In [None]:
# Build one global 0/1 mask that keeps the highest-valued entries across all
# server masks, then cut it back into per-layer masks.
# NOTE(review): `server` is not assigned in any visible cell; it must already exist in the kernel.
masks = server.masks
for i in range(len(masks)):
    print(masks[i].shape)
flat_masks = torch.cat([m.flatten() for m in masks])
keep_ratio = 0.1  # fraction of all mask entries to keep
keep_num = len(flat_masks) * keep_ratio
# topk returns (values, indices); `threshold` therefore holds all selected values.
threshold, indices = flat_masks.topk(int(keep_num))
global_masks = torch.zeros_like(flat_masks)
global_masks[indices] = 1
print(indices.sort())
# len(indices)
print(global_masks[:25])
print(global_masks.sum())

idx = 0
ms = []
for i in range(len(masks)):
    m = global_masks[idx:idx+masks[i].numel()].reshape(masks[i].size())
    idx += masks[i].numel()
    # Collect the per-layer mask; `ms` was previously created but never filled.
    ms.append(m)
    print(m.shape)
    print(m)

# Every entry of the flat mask must have been handed to exactly one layer.
assert idx == global_masks.numel()


In [None]:
# Sanity check: 1-D tensors of different lengths concatenate into one vector.
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5])
c = torch.tensor([6, 7, 8, 9])
flattened_parts = tuple(part.flatten() for part in (a, b, c))
ts = torch.cat(flattened_parts)
print(ts)


In [None]:
# Fraction of exactly-zero entries over every tensor in the server model's state_dict.
# NOTE(review): `server` is not assigned in any visible cell; it must already exist in the kernel.
pruned_c = 0.0
total = 0.0

for name, param in server.model.state_dict().items():
    a = param.view(-1).to(device='cpu', copy=True).numpy()
    # Vectorised zero count; builtin sum() over np.where(...) walked the array
    # one element at a time in Python.
    pruned_c += a.size - np.count_nonzero(a)
    total += param.numel()
print(f'global model zero params: {pruned_c / total}')

In [None]:
# Experiment: does writing into a tensor obtained from state_dict() also change the model's weights?
# NOTE(review): `server` must already exist in the kernel, and 'cuda:3' is a hard-coded device.
# for name in server.model.state_dict():
#     print(name)
server.model = server.model.to('cuda:3')
# .cpu() is called on the module itself before its state_dict is taken.
params = server.model.cpu().state_dict()
print(params['features.0.weight'])

print('***********'*3)

# params['features.0.weight'] = torch.zeros_like(params['features.0.weight'])
# In-place write through the state_dict entry; the two prints below show whether
# the model's own weights and `params` both reflect it.
params['features.0.weight'][0][0][0] = 2.

print(server.model.state_dict()['features.0.weight'])


print(params['features.0.weight'])

In [None]:
# Fraction of exactly-zero entries over every tensor in global_model's state_dict.
# NOTE(review): `global_model` is not assigned in any visible cell; it must already exist in the kernel.
pruned_c = 0.0
total = 0.0
for name, param in global_model.state_dict().items():
    a = param.view(-1).to(device='cpu', copy=True).numpy()
    # Vectorised zero count; builtin sum() over np.where(...) walked the array
    # one element at a time in Python.
    pruned_c += a.size - np.count_nonzero(a)
    total += param.numel()
print(f'global model zero params: {pruned_c / total}')

In [None]:
# Fraction of zero entries over all masks stored in global_model.mask.
# NOTE(review): `global_model` is not assigned in any visible cell; it must already exist in the kernel.
prune_c = 0.
total = 0.
for name,mask in global_model.mask.items():
#     print(mask)
    # Vectorised zero count instead of builtin sum() over np.where(...).
    prune_c += mask.numel() - np.count_nonzero(mask.to('cpu', copy=True).view(-1).numpy())
    total += mask.numel()

print(prune_c / total)

In [None]:
# Always raises ZeroDivisionError.
# NOTE(review): presumably a deliberate barrier so "Run All" stops here — confirm, or remove.
1 / 0

In [None]:
import copy

# Compare the zero-entry fraction of global_model with that of a deep copy of it.
# NOTE(review): `global_model` is not assigned in any visible cell; it must already exist in the kernel.
pruned_c = 0.0
total = 0.0
for name, param in global_model.state_dict().items():
    a = param.view(-1).to(device='cpu', copy=True).numpy()
    # Vectorised zero count; builtin sum() over np.where(...) walked the array
    # one element at a time in Python.
    pruned_c += a.size - np.count_nonzero(a)
    total += param.numel()
print(f'global model zero params: {pruned_c / total}')

gcp_model = copy.deepcopy(global_model)



prune_c = 0.
total = 0.
for name, params in gcp_model.state_dict().items():
#     print(name)
    # Same vectorised count, applied to the copied model.
    prune_c += params.numel() - np.count_nonzero(params.to('cpu', copy=True).view(-1).numpy())
    total += params.numel()

print(prune_c / total)
    
# print('*'*10)
# for name, mask in global_model.mask.items():
#     print(name)

In [None]:
# for name in aggregated_masks:
#     print(name)
#     print(aggregated_masks[name].dtype)

def count_mask(state_dict):
    """Return the fraction of zero entries across all tensors in ``state_dict``.

    ``state_dict`` is a mapping whose values are tensors. The result is
    ``1 - nonzero / total`` where both counts are accumulated over every value.
    """
    nonzero_entries = 0.
    all_entries = 0.
    for tensor in state_dict.values():
        nonzero_entries += torch.count_nonzero(tensor)
        all_entries += tensor.numel()
    return 1 - nonzero_entries / all_entries
# a = aggregated_masks['features.4.weight']
# NOTE(review): `a`, `aggregated_masks` and `cl_mask_prarms` are not assigned in this cell, and the
# assignment of `a` above is commented out, so this relies on values left in the kernel — confirm which `a` is meant.
print(a.shape)
# print(torch.count_nonzero(a))
print(count_mask(aggregated_masks))
# 1 where a >= 1, else 0; the result is not used again in this cell.
mask = torch.where(a>=1, 1, 0)
# print(torch.count_nonzero(mask))
print(count_mask(cl_mask_prarms))

In [None]:
import copy
import torch

# Toy check of weighted mask averaging for two clients.
# keep_masks maps client id -> list of 0/1 masks (one tensor per layer).
keep_masks = {0: [torch.Tensor([0, 1, 0, 1, 1])], 1: [torch.Tensor([1, 1, 0, 0, 1])]}

# One [weight, parameters] pair per client; the weight scales that client's masks.
model_list = [[2, torch.Tensor([1,2,3,4,5])], [3, torch.Tensor([1,2,3,4,5])]]

# Step 1: scale every client's masks by that client's weight.
masks = copy.deepcopy(keep_masks)
for c, m in masks.items():
    for i in range(len(m)):
        m[i] *= model_list[c][0]

print(masks)

# Step 2: accumulate all clients' weighted masks into client 0's slot.
total_masks = copy.deepcopy(masks)
for c,m in total_masks.items():
    if c!= 0:
        for i in range(len(m)):
            total_masks[0][i] += m[i]

print(total_masks)

# Step 3: normalise each weighted mask by the accumulated total.
# Positions no client keeps give 0/0 = NaN, which is replaced by 0.
for c, m in masks.items():
    for i in range(len(m)):
        m[i] /= total_masks[0][i]
        masks[c][i] = torch.where(torch.isnan(m[i]), torch.full_like(m[i], 0), m[i])
#     print(torch.isnan(m))
print(masks)
#     m /= sum(model_list[i][0] for i in range(len(model_list)))

# print(masks[0] * model_list[0][1])
# masks[0] is a list of per-layer tensors, so index the layer before multiplying;
# multiplying the parameter tensor by the list itself is not an element-wise product.
model_list[0][1] *= masks[0][0]
# print(model_list[0][1])
# print(model_list[0][1])

In [None]:
import numpy as np

# Experiment: what `*=` does to a Python list of tensors.
list_tensor = [torch.Tensor([1,2]), torch.Tensor([2,3])]
# Stacked before the list is changed below, so `t` is built from the two original tensors.
t = torch.stack(list_tensor)
# On a list, `*= 2` repeats the list (it now has four items); it does not scale the tensors.
list_tensor *= 2
list_tensor
# print(t)
# a = torch.from_numpy(np.asarray(list_tensor))
# print(a)
# torch.cat(list_tensor, 1)

In [None]:
# Two tensors converted to float64 report equal dtypes.
a = torch.Tensor([2, 3]).double()
b = torch.Tensor([1, 2]).double()
same_dtype = a.dtype == b.dtype
print(same_dtype)
b.dtype

In [None]:

# Experiment: with weight_decay set, what does optimizer.step() do to the weights
# when the gradients were cleared by zero_grad() just before the step?
model = nn.Sequential(
          nn.Linear(6, 2, bias=False),
          nn.Sigmoid(),
        )
# NOTE(review): `input` shadows the Python builtin for the rest of the kernel session.
input = torch.randn(6)
target = torch.randn(2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, weight_decay=0.1)
# optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()



# One forward/backward pass to populate the gradient of the single weight matrix.
output = model(input)
loss = criterion(output, target)
loss.backward()

# Gradient right after backward(), then again after zero_grad() has cleared it.
print(list(model.parameters())[0].grad)
optimizer.zero_grad() 
print(list(model.parameters())[0].grad)


# step() runs with the cleared gradient; compare the two printouts to see whether the weights moved.
# NOTE(review): the outcome presumably depends on whether this torch version zeroes grads or sets them to None — TODO confirm.
print("before step: ", list(model.parameters()))
optimizer.step()
print("after step: ", list(model.parameters()))

In [None]:
target_keep_ratio = 0.2
num_pruning_steps = 2

# Keep ratio after each pruning step: step k (1-based) has removed
# k / num_pruning_steps of the total amount to prune (1 - target_keep_ratio).
prune_budget = 1 - target_keep_ratio
keep_ratio_steps = [
    1 - (step * prune_budget / num_pruning_steps)
    for step in range(1, num_pruning_steps + 1)
]
keep_ratio_steps