In [1]:
from scipy import test
import torch
import torch.cuda
from torch import nn
from torch.nn import functional as F
import argparse
import gc
import itertools
import numpy as np
import os
import sys
import time
import pickle
from copy import deepcopy

from tqdm import tqdm
import warnings
import copy

import wandb

from datasets import get_dataset
from models.models import all_models

from client import Client
from utils import *

import fedsnip as fedsnip_obj

rng = np.random.default_rng()

def device_list(x):
    if x == 'cpu':
        return [x]
    return [int(y) for y in x.split(',')]

  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)


In [2]:
parser = argparse.ArgumentParser()
parser.add_argument('--eta', type=float, help='learning rate', default=0.01)
parser.add_argument('--clients', type=int, help='number of clients per round', default=20)
parser.add_argument('--rounds', type=int, help='number of global rounds', default=400)
parser.add_argument('--epochs', type=int, help='number of local epochs', default=10)
parser.add_argument('--dataset', type=str, choices=('mnist', 'emnist', 'cifar10', 'cifar100'),
                    default='mnist', help='Dataset to use')
parser.add_argument('--distribution', type=str, choices=('dirichlet', 'lotteryfl', 'iid', 'classic_iid'), default='dirichlet',
                    help='how should the dataset be distributed?')
parser.add_argument('--beta', type=float, default=0.1, help='Beta parameter (unbalance rate) for Dirichlet distribution')
parser.add_argument('--total-clients', type=int, help='split the dataset between this many clients. Ignored for EMNIST.', default=400)
parser.add_argument('--min-samples', type=int, default=0, help='minimum number of samples required to allow a client to participate')
parser.add_argument('--samples-per-client', type=int, default=20, help='samples to allocate to each client (per class, for lotteryfl, or per client, for iid)')
parser.add_argument('--prox', type=float, default=0, help='coefficient to proximal term (i.e. in FedProx)')

parser.add_argument('--batch-size', type=int, default=32,
                    help='local client batch size')
parser.add_argument('--l2', default=1e-5, type=float, help='L2 regularization strength')
parser.add_argument('--momentum', default=0.9, type=float, help='Local client SGD momentum parameter')
parser.add_argument('--cache-test-set', default=False, action='store_true', help='Load test sets into memory')
parser.add_argument('--cache-test-set-gpu', default=False, action='store_true', help='Load test sets into GPU memory')
parser.add_argument('--test-batches', default=0, type=int, help='Number of minibatches to test on, or 0 for all of them')
parser.add_argument('--eval-every', default=1, type=int, help='Evaluate on test set every N rounds')
parser.add_argument('--device', default='0', type=device_list, help='Device to use for compute. Use "cpu" to force CPU. Otherwise, separate with commas to allow multi-GPU.')
parser.add_argument('--no-eval', default=True, action='store_false', dest='eval')
parser.add_argument('-o', '--outfile', default='output.log', type=argparse.FileType('a', encoding='ascii'))

parser.add_argument('--clip_grad', default=False, action='store_true', dest='clip_grad')

parser.add_argument('--model', type=str, choices=('VGG11_BN', 'VGG_SNIP', 'CNNNet', 'CIFAR10Net'),
                    default='VGG11_BN', help='Dataset to use')

parser.add_argument('--prune_strategy', type=str, choices=('None', 'SNIP', 'SNAP', 'random_masks', 'Iter-SNIP', 'Grasp', 'PreCrop'),
                    default='None', help='Dataset to use')
parser.add_argument('--prune_at_first_round', default=False, action='store_true', dest='prune_at_first_round')
parser.add_argument('--keep_ratio', type=float, default=0.0,
                    help='local client batch size')         
parser.add_argument('--prune_vote', type=int, default=1,
                    help='local client batch size')

parser.add_argument('--single_shot_pruning', default=False, action='store_true', dest='single_shot_pruning')

parser.add_argument('--partition_method', type=str, default='homo', metavar='N',
                        help='how to partition the dataset on local workers')

parser.add_argument('--partition_alpha', type=float, default=0.5, metavar='PA',
                    help='partition alpha (default: 0.5)')

parser.add_argument('--target_keep_ratio', default=0.1, type=float, help='server target keep ratio')

parser.add_argument('--num_pruning_steps', type=int, help='total number of pruning steps')
parser.add_argument('--pruning_steps_decay_mode', type=str, default='linear', choices=('linear', 'exp'), help='pruning steps decay mode')
parser.add_argument('--saliency_mode', type=str, choices=('saliency', 'mask'))
parser.add_argument('--structure', type=bool, help='learning rate', default=False)

_StoreAction(option_strings=['--structure'], dest='structure', nargs=None, const=None, default=False, type=<class 'bool'>, choices=None, help='learning rate', metavar=None)

In [3]:
args = parser.parse_args(args=['--dataset', 'cifar10', 
                               '--eta', '0.01', 
                               '--device', '2', 
                               '--distribution', 'classic_iid', 
                               '--total-clients', '3', 
                               '--clients', '3', 
                               '--batch-size', '64', 
                               '--rounds', '100', 
                               '--model', 'VGG11_BN', 
                               '--prune_strategy', 'PreCrop',
                               '--epochs', '10',
                               '--keep_ratio', '0.1',
                               '--prune_vote', '1',
                               '--prune_at_first_round',
                               '--single_shot_pruning',
                               '--partition_method', 'hetero',
                               '--partition_alpha', '0.5',
                               '--target_keep_ratio', '0.1',
                               '--num_pruning_steps', '1',
                               '--pruning_steps_decay_mode', 'linear',
                               '--structure', 'True',
                               '--saliency_mode', 'mask'])

In [4]:
run = fedsnip_obj.main(args)

In [5]:
# import os
# print('...')
# name = input()
# print('++++')

In [6]:
# while True:
#     debug_info = next(run)
#     print(debug_info.msg)

    
#     input()
debug_info = next(run)
# masks = debug_info.obj
# global_model = debug_info.obj
# aggregated_masks = debug_info.obj[0]
# cl_mask_prarms = debug_info.obj[1]

Fetching dataset...
INFO:root:*********partition data***************


Files already downloaded and verified
Files already downloaded and verified


INFO:root:N = 50000
  test_idx = np.array(np.array_split(test_idx, n_nets))
INFO:root:traindata_cls_counts = {0: {0: 4873, 1: 261, 2: 1161, 3: 4184, 4: 1683, 5: 230, 6: 106, 7: 1919, 8: 3849}, 1: {0: 82, 1: 906, 2: 423, 3: 672, 4: 2128, 5: 313, 6: 4072, 7: 370, 8: 213, 9: 5000}, 2: {0: 45, 1: 3833, 2: 3416, 3: 144, 4: 1189, 5: 4457, 6: 822, 7: 2711, 8: 938}}


10000******************
Files already downloaded and verified
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:train_dl_global number = 782
INFO:root:test_dl_global number = 157
INFO:root:client_idx = 0, local_sample_number = 18266


Files already downloaded and verified
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:client_idx = 0, batch_num_train_local = 286, batch_num_test_local = 53
INFO:root:client_idx = 1, local_sample_number = 14179


Files already downloaded and verified
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:client_idx = 1, batch_num_train_local = 222, batch_num_test_local = 53
INFO:root:client_idx = 2, local_sample_number = 17555


Files already downloaded and verified
Files already downloaded and verified


INFO:root:seed: 0!!!!!!!
INFO:root:client_idx = 2, batch_num_train_local = 275, batch_num_test_local = 53
Initializing clients...
ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mslimfun[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.13.5 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


None
None
None
None
params.data.shape: torch.Size([64, 3, 3, 3]); mask.shape: torch.Size([64, 3, 3, 3])
params.data.shape: torch.Size([128, 64, 3, 3]); mask.shape: torch.Size([128, 64, 3, 3])
params.data.shape: torch.Size([256, 128, 3, 3]); mask.shape: torch.Size([256, 128, 3, 3])
params.data.shape: torch.Size([256, 256, 3, 3]); mask.shape: torch.Size([256, 256, 3, 3])
params.data.shape: torch.Size([512, 256, 3, 3]); mask.shape: torch.Size([512, 256, 3, 3])
params.data.shape: torch.Size([512, 512, 3, 3]); mask.shape: torch.Size([512, 512, 3, 3])
params.data.shape: torch.Size([512, 512, 3, 3]); mask.shape: torch.Size([512, 512, 3, 3])
params.data.shape: torch.Size([512, 512, 3, 3]); mask.shape: torch.Size([512, 512, 3, 3])
params.data.shape: torch.Size([10, 512]); mask.shape: torch.Size([10, 512])
server model param size: 295395648
No post init specified in PreCrop
client 0 output_channels: [64, 127, 94, 203, 59, 322, 37, 512]
client 0 layer_ratio [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0

In [8]:
# ['features.0.weight', 'features.0.bias', 'features.1.weight', 'features.1.bias', 'features.1.running_mean', 'features.1.running_var', 'features.4.weight', 'features.4.bias', 'features.5.weight', 'features.5.bias', 'features.5.running_mean', 'features.5.running_var', 'features.8.weight', 'features.8.bias', 'features.9.weight', 'features.9.bias', 'features.9.running_mean', 'features.9.running_var', 'features.11.weight', 'features.11.bias', 'features.12.weight', 'features.12.bias', 'features.12.running_mean', 'features.12.running_var', 'features.15.weight', 'features.15.bias', 'features.16.weight', 'features.16.bias', 'features.16.running_mean', 'features.16.running_var', 'features.18.weight', 'features.18.bias', 'features.19.weight', 'features.19.bias', 'features.19.running_mean', 'features.19.running_var', 'features.22.weight', 'features.22.bias', 'features.23.weight', 'features.23.bias', 'features.23.running_mean', 'features.23.running_var', 'features.25.weight', 'features.25.bias', 'features.26.weight', 'features.26.bias', 'features.26.running_mean', 'features.26.running_var', 'classifier.fc1.weight', 'classifier.fc1.bias']
# 28211 50000
# debug_info['features.11.weight'][:204,:95,:,:]
# model_list, client_idx, server_params = debug_info
training_nums = debug_info



In [11]:
for k, v in training_nums.items():
    print(f'{k}: {v.shape}')

features.0.weight: torch.Size([64, 3, 3, 3])
features.0.bias: torch.Size([64])
features.1.weight: torch.Size([64])
features.1.bias: torch.Size([64])
features.1.running_mean: torch.Size([64])
features.1.running_var: torch.Size([64])
features.4.weight: torch.Size([128, 64, 3, 3])
features.4.bias: torch.Size([128])
features.5.weight: torch.Size([128])
features.5.bias: torch.Size([128])
features.5.running_mean: torch.Size([128])
features.5.running_var: torch.Size([128])
features.8.weight: torch.Size([205, 128, 3, 3])
features.8.bias: torch.Size([205])
features.9.weight: torch.Size([205])
features.9.bias: torch.Size([205])
features.9.running_mean: torch.Size([205])
features.9.running_var: torch.Size([205])
features.11.weight: torch.Size([203, 205, 3, 3])
features.11.bias: torch.Size([203])
features.12.weight: torch.Size([203])
features.12.bias: torch.Size([203])
features.12.running_mean: torch.Size([203])
features.12.running_var: torch.Size([203])
features.15.weight: torch.Size([304, 203, 3

In [22]:
training_nums['features.22.weight'][37,0,3,3]

Error in callback <function _WandbInit._resume_backend at 0x7f57b3883050> (for pre_run_cell):


Exception: The wandb backend process has shutdown

IndexError: index 3 is out of bounds for dimension 2 with size 3

Error in callback <function _WandbInit._pause_backend at 0x7f57c6ab3c20> (for post_run_cell):


Exception: The wandb backend process has shutdown

In [None]:
# model_list[3][2]
client_idx

# 1/0

In [None]:
def weight_scale(model_params):
    for k, p in model_params.items():
        if len(p.shape) == 0:
            continue
        print(f'{k}: {p.shape}, {p.mean()}')
        

model = torch.load('client_model_9.pt')

weight_scale(model.state_dict())
# weight_scale(model_list[3])

In [None]:
def flatten_model(model_params):
    return torch.cat([p.flatten() for p in model_params.values()])
i = 1
for i in range(len(model_list)):
    print(torch.cosine_similarity(flatten_model(model_list[i-1]), flatten_model(model_list[i])))

In [None]:
def exchange_dim(model_params):
    prio_exchange_idx = [0,1,2]
    for k, p in model_params.items():
        if len(p.shape) == 4:
#             exchange_idx = []
#             for i in range(p.shape[0])
            exchange_idx = np.argsort([p[i,:,:,:].mean().cpu() for i in range(p.shape[0])])
#             print(exchange_idx)
            model_params[k] = model_params[k][:,prio_exchange_idx,:,:]
            model_params[k] = model_params[k][exchange_idx,:,:,:]
            prio_exchange_idx = exchange_idx
        elif len(p.shape) == 1:
#             print(p.shape)
#             print(prio_exchange_idx)
            model_params[k] = model_params[k][prio_exchange_idx]
        elif len(p.shape) == 2:
            model_params[k] = model_params[k][:,prio_exchange_idx]
            prio_exchange_idx = [i for i in range(10)]
            
exchange_dim(vgg11_pruned.state_dict())

In [None]:
vgg11_pruned = all_models['VGG11_BN']('cuda:3', output_channels=[64, 128, 95, 204, 60, 322, 38, 512]).to('cuda:3')
# vgg11_pruned.load_state_dict(model_list[0])

In [None]:
weight_scale(model_list[-1])

In [None]:
weight_scale(model_list[5])

In [None]:
vgg11_full = all_models['VGG11_BN']('cuda:3', output_channels=[64, 128, 256, 256, 512, 512, 512, 512]).to('cuda:3')
vgg11_full.load_state_dict(model_list[3])

In [None]:
server_model = all_models['VGG11_BN']('cuda:3', output_channels=[64, 128, 256, 256, 512, 512, 512, 512]).to('cuda:3')
server_model.load_state_dict(server_params)


In [None]:
from data_loader import load_partition_data_cifar10, get_dataloader_test_CIFAR10, get_dataloader
import os

path = os.path.join('..', 'data', args.dataset)
_, test_data_global = get_dataloader(args.dataset, path, 64, 128)

In [None]:
path = os.path.join('..', 'data', args.dataset)
train_data_num, test_data_num, train_data_global, test_data_global, \
    train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
    class_num = load_partition_data_cifar10(args.dataset, path, args.partition_method, args.partition_alpha, args.total_clients, args.batch_size)
    

In [None]:
sum([5971,5081,5248,5944,3876])

In [None]:
from torch import nn

def test(model=None, n_batches=0, test_loader=None):
        '''Evaluate the local model on the local test set.

        model - model to evaluate, or this client's model if None
        n_batches - number of minibatches to test on, or 0 for all of them
        '''

        correct = 0.
        total = 0.
        loss = 0.

        criterion = nn.CrossEntropyLoss().to(model.device)
        model.eval()
        
        with torch.no_grad():
            for i, (inputs, labels) in enumerate(test_loader):
                if i > n_batches and n_batches > 0:
                    break
                inputs = inputs.to(model.device)
                labels = labels.to(model.device)
                outputs = model(inputs)
#                 print(outputs)
                loss += criterion(outputs, labels) * len(labels)
                outputs = torch.argmax(outputs, dim=-1)
                correct += sum(labels == outputs)
                total += len(labels)

        # remove copies if needed
        # if model is not _model:
        #     del _model

        print(f'Test : Accuracy: {correct / total}; Loss: {loss / total}; Total: {total};')

        return correct / total, loss / total

# Test : Accuracy: 0.42100003361701965; Loss: 2.1450095176696777; Total: 1000.0;
# Test : Accuracy: 0.44600000977516174; Loss: 2.139613389968872; Total: 1000.0;
# Test : Accuracy: 0.39000001549720764; Loss: 2.363382577896118; Total: 1000.0;
# Test : Accuracy: 0.40800002217292786; Loss: 2.3720834255218506; Total: 1000.0;
# Test : Accuracy: 0.3930000066757202; Loss: 2.366488456726074; Total: 1000.0;
# Test : Accuracy: 0.4000000059604645; Loss: 2.2676665782928467; Total: 1000.0;
# Test : Accuracy: 0.4180000126361847; Loss: 2.3040504455566406; Total: 1000.0;
# Test : Accuracy: 0.44700002670288086; Loss: 2.065028667449951; Total: 1000.0;
# Test : Accuracy: 0.3760000169277191; Loss: 2.3444175720214844; Total: 1000.0;
# Test : Accuracy: 0.3890000283718109; Loss: 2.352550506591797; Total: 1000.0;
for i in range(10):
    test(vgg11_pruned, test_loader=test_data_local_dict[i])

In [None]:
recons_model = all_models['VGG11_BN']('cuda:3', output_channels=[64, 128, 95, 204, 60, 322, 38, 512]).to('cuda:3')
def reset_weights(net, global_state_dict, output_channels):
    print('===================')
    prio_channel = 3
    local_params = net.state_dict()
    idx = -1
    for k in local_params.keys():
        shape_dim = len(local_params[k].shape)
        # print(local_params[k].shape)
        if shape_dim == 4:
            idx += 1
            # print(f'{local_params[k].shape} == {self.output_channels[idx]}')
            local_params[k].copy_(global_state_dict[k][:output_channels[idx],:prio_channel,:,:])
            prio_channel = output_channels[idx]
        elif shape_dim == 1:
            # print(local_params[k].shape)
            local_params[k].copy_(global_state_dict[k][:output_channels[idx]])
        elif shape_dim == 2:
            # print(f'{local_params[k].shape} == {self.output_channels[idx]}')
            local_params[k].copy_(global_state_dict[k][:,:output_channels[idx]])

reset_weights(recons_model, server_params, output_channels=[64, 128, 95, 204, 60, 322, 38, 512])

for i in range(10):
    test(recons_model, test_loader=test_data_local_dict[i])
    

In [None]:
1/0
def check_masks(masks):
    for m in masks:
        print(m.mean())
def statistic_masks(masks):
    M = torch.cat([m.flatten() for m in masks]).to(torch.int)
    print(M.shape)
    print(torch.bincount(M))
for k, v in debug_info.items():
    print('*' * 10)
    print(k)
    check_masks(v)

In [None]:
masks = server.masks
torch.cat(masks)

In [None]:
server = masks
def count_vote(masks, vote):
    tc = 0.
    keeped = 0.
    for i in range(len(masks)):
        tc += masks[i].numel()
        m = masks[i].clone().detach()
        keeped += torch.sum(torch.where(m >= vote, 1, 0))

    print("keeped: {}, total: {}, pct: {}".format(keeped, tc, keeped/tc))

In [None]:
for i in range(10):
    print('vote {}'.format(i))
    count_vote(server.masks, i)

In [None]:
masks = server.masks
for i in range(len(masks)):
    print(masks[i].shape)
flat_masks = torch.cat([m.flatten() for m in masks])
keep_num = len(flat_masks) * 0.1
threshold, indices = flat_masks.topk(int(keep_num))
global_masks = torch.zeros_like(flat_masks)
global_masks[indices] = 1
print(indices.sort())
# len(indices)
print(global_masks[:25])
print(global_masks.sum())

idx = 0
ms = []
for i in range(len(masks)):
    m = global_masks[idx:idx+masks[i].numel()].reshape(masks[i].size())
    idx += masks[i].numel()
    print(m.shape)
    print(m)


In [None]:
a = torch.tensor([1,2,3])
b = torch.tensor([4,5])
c = torch.tensor([6,7,8,9])
ts = torch.cat([e.flatten() for e in [a,b,c]])
print(ts)


In [None]:
pruned_c = 0.0
total = 0.0

for name, param in server.model.state_dict().items():
    a = param.view(-1).to(device='cpu', copy=True).numpy()
    pruned_c +=sum(np.where(a, 0, 1))
    total += param.numel()
print(f'global model zero params: {pruned_c / total}')

In [None]:
# for name in server.model.state_dict():
#     print(name)
server.model = server.model.to('cuda:3')
params = server.model.cpu().state_dict()
print(params['features.0.weight'])

print('***********'*3)

# params['features.0.weight'] = torch.zeros_like(params['features.0.weight'])
params['features.0.weight'][0][0][0] = 2.

print(server.model.state_dict()['features.0.weight'])


print(params['features.0.weight'])

In [None]:
pruned_c = 0.0
total = 0.0
for name, param in global_model.state_dict().items():
    a = param.view(-1).to(device='cpu', copy=True).numpy()
    pruned_c +=sum(np.where(a, 0, 1))
    total += param.numel()
print(f'global model zero params: {pruned_c / total}')

In [None]:
prune_c = 0.
total = 0.
for name,mask in global_model.mask.items():
#     print(mask)
    prune_c += sum(np.where(mask.to('cpu', copy=True).view(-1).numpy(), 0, 1))
    total += mask.numel()
    
print(prune_c / total)

In [None]:
1 / 0

In [None]:
import copy

pruned_c = 0.0
total = 0.0
for name, param in global_model.state_dict().items():
    a = param.view(-1).to(device='cpu', copy=True).numpy()
    pruned_c +=sum(np.where(a, 0, 1))
    total += param.numel()
print(f'global model zero params: {pruned_c / total}')

gcp_model = copy.deepcopy(global_model)



prune_c = 0.
total = 0.
for name, params in gcp_model.state_dict().items():
#     print(name)
    prune_c += sum(np.where(params.to('cpu', copy=True).view(-1).numpy(), 0, 1))
    total += params.numel()
    
print(prune_c / total)
    
# print('*'*10)
# for name, mask in global_model.mask.items():
#     print(name)

In [None]:
# for name in aggregated_masks:
#     print(name)
#     print(aggregated_masks[name].dtype)

def count_mask(state_dict):
    non_zero = 0.
    total = 0.
    for name in state_dict:
        non_zero += torch.count_nonzero(state_dict[name])
        total += state_dict[name].numel()
    return 1 - non_zero / total
# a = aggregated_masks['features.4.weight']
print(a.shape)
# print(torch.count_nonzero(a))
print(count_mask(aggregated_masks))
mask = torch.where(a>=1, 1, 0)
# print(torch.count_nonzero(mask))
print(count_mask(cl_mask_prarms))

In [None]:
import copy
import torch
keep_masks = {0: [torch.Tensor([0, 1, 0, 1, 1])], 1: [torch.Tensor([1, 1, 0, 0, 1])]}

model_list = [[2, torch.Tensor([1,2,3,4,5])], [3, torch.Tensor([1,2,3,4,5])]]

masks = copy.deepcopy(keep_masks)
for c, m in masks.items():
    for i in range(len(m)):
        m[i] *= model_list[c][0]
    
print(masks)

total_masks = copy.deepcopy(masks)
for c,m in total_masks.items():
    if c!= 0:
        for i in range(len(m)):
            total_masks[0][i] += m[i]
        
print(total_masks)

for c, m in masks.items():
    for i in range(len(m)):
        m[i] /= total_masks[0][i]
        masks[c][i] = torch.where(torch.isnan(m[i]), torch.full_like(m[i], 0), m[i])
#     print(torch.isnan(m))
print(masks)
#     m /= sum(model_list[i][0] for i in range(len(model_list)))

# print(masks[0] * model_list[0][1])
model_list[0][1] *= masks[0]
# print(model_list[0][1])

In [None]:
import numpy as np

list_tensor = [torch.Tensor([1,2]), torch.Tensor([2,3])]
t = torch.stack(list_tensor)
list_tensor *= 2
list_tensor
# print(t)
# a = torch.from_numpy(np.asarray(list_tensor))
# print(a)
# torch.cat(list_tensor, 1)

In [None]:
a = torch.Tensor([2,3]).type(torch.float64)
b = torch.Tensor([1,2]).type(torch.float64)
print(a.dtype == b.dtype)
b.dtype

In [None]:

model = nn.Sequential(
          nn.Linear(6, 2, bias=False),
          nn.Sigmoid(),
        )
input = torch.randn(6)
target = torch.randn(2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, weight_decay=0.1)
# optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()



output = model(input)
loss = criterion(output, target)
loss.backward()

print(list(model.parameters())[0].grad)
optimizer.zero_grad() 
print(list(model.parameters())[0].grad)


print("before step: ", list(model.parameters()))
optimizer.step()
print("after step: ", list(model.parameters()))

In [None]:
target_keep_ratio = 0.2
num_pruning_steps = 2

keep_ratio_steps = [1 - ((x + 1) * (1 - target_keep_ratio) / num_pruning_steps) for x in range(num_pruning_steps)]
keep_ratio_steps

In [None]:
learning_rate = 0.01
for r in range(100):
    lr = learning_rate * 1/(1 + 0.1 * r)
    print(lr)

In [None]:
import random
bandwidths = []
for i in range(10):
    bandwidths.append(random.uniform(1.5, 10))
print(bandwidths)