In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.cluster import KMeans
from scipy.sparse import csc_matrix, csr_matrix
import pdb
import datetime

import argparse
import datetime

import torch
import torch.optim as optim
from torchvision import datasets, transforms

from cifar_classifier import MaskedCifar
from classifier import Classifier
from mnist_classifier import MaskedMNist
from pruning.methods import weight_prune, prune_rate, get_all_weights
from pruning.utils import to_var
from resnet import MaskedResNet18, MaskedResNet34, MaskedResNet50, MaskedResNet101, MaskedResNet152
from classifier_utils import setup_default_args

from tensorboardX import SummaryWriter

from configurations import configurations


In [2]:
def get_all_weights(model):
    weights = []

    if len(list(model.children())) != 0:
        for l in model.children():
            weights += get_all_weights(l)
    else:
        for p in model.parameters():
            if len(p.data.size()) != 1: # Avoid bias parameters
                weights += list(p.cpu().data.abs().numpy().flatten())

    return weights

def gen_masks_for_layer(model, threshold):
    # generate mask
    for p in model.parameters():
        if len(p.data.size()) != 1:
            pruned_inds = p.data.abs() > threshold
            return pruned_inds.float()
    
def gen_masks_recursive(model, threshold):
    masks = []
    
    for module in model.children():
        if 'Masked' not in str(type(module)):
            print("Skipping masking of layer: ", module)
            continue
        if len(list(module.children())) != 0:
            masks.append(gen_masks_recursive(module, threshold))
        else:
            masks.append(gen_masks_for_layer(module, threshold))
    
    return masks

def weight_prune(model, pruning_perc):
    '''
    Prune pruning_perc% weights globally (not layer-wise)
    arXiv: 1606.09274
    '''    
    all_weights = get_all_weights(model)
    threshold = np.percentile(np.array(all_weights), pruning_perc)
    return gen_masks_recursive(model, threshold)

def prune_rate(model, verbose=True):
    """
    Print out prune rate for each layer and the whole network
    """
    total_nb_param = 0
    nb_zero_param = 0

    layer_id = 0

    for parameter in model.parameters():

        param_this_layer = 1
        for dim in parameter.data.size():
            param_this_layer *= dim
        total_nb_param += param_this_layer

        # only pruning linear and conv layers
        if len(parameter.data.size()) != 1:
            layer_id += 1
            zero_param_this_layer = \
                np.count_nonzero(parameter.cpu().data.numpy()==0)
            nb_zero_param += zero_param_this_layer

            if verbose:
                print("Layer {} | {} layer | {:.2f}% parameters pruned" \
                    .format(
                        layer_id,
                        'Conv' if len(parameter.data.size()) == 4 \
                            else 'Linear',
                        100.*zero_param_this_layer/param_this_layer,
                        ))
    pruning_perc = 100.*nb_zero_param/total_nb_param
    if verbose:
        print("Final pruning rate: {:.2f}%".format(pruning_perc))
    return pruning_perc


In [3]:
    config = [x for x in configurations if x['name'] == 'FCCifar10Classifier'][0]

    model = config['model']()

    device = 'cuda:2'

    train_data = test_data = config['dataset'](
        './data', train=True, download=True, transform=transforms.Compose(config['transforms'])
    )

    test_data = config['dataset'](
        './data', train=False, download=True, transform=transforms.Compose(config['transforms'])
    )

    train_loader = torch.utils.data.DataLoader(train_data, batch_size=8, shuffle=True, num_workers=1, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=8, shuffle=True, num_workers=1, pin_memory=True)
    optimizer = config['optimizer'](model.parameters(), lr=0.01, momentum=0.5)
    
    wrapper = Classifier(model, device, train_loader, test_loader)

    model.load_state_dict(torch.load('./models/cifar_classifier.pt'))


Files already downloaded and verified
Files already downloaded and verified


In [4]:
def quantize_k_means(model, bits=5):
    for module in model.children():
        if 'weight' not in dir(module):
            continue

        dev = module.weight.device
        weight = module.weight.data.cpu().numpy()
        original_shape = weight.shape
        weight = np.reshape(weight, (2, -1))
        mat = csr_matrix(weight)
        min_ = min(mat.data)
        max_ = max(mat.data)
        space = np.linspace(min_, max_, num=2**bits)
        kmeans = KMeans(n_clusters=len(space), init=space.reshape(-1,1), n_init=1, precompute_distances=True, algorithm="full")
        kmeans.fit(mat.reshape(-1,1))

        weight = kmeans.cluster_centers_[kmeans.labels_].reshape(original_shape)
        weight_tensor = torch.tensor(weight, requires_grad=True).to(dev)

        # Use register_hooks to recalculate the gradients
        module.weight.data = weight_tensor
        module.weight.register_hook(gen_param_hook(torch.from_numpy(kmeans.labels_), dev))

def get_where_val_is_zero(array):
    for i, val in enumerate(array):
        if val == 0:
            return i
    
    return None

        
def gen_param_hook(c_labels, dev):
    
    def hook(grad):
        # print(f"Retraining start time {datetime.datetime.now()}")
        grad_original_shape = grad.shape
        reshape_start_time = datetime.datetime.now()
        grads = grad.reshape(-1, 1)
        reshape_end_time = datetime.datetime.now()

#         print(f"Reshape took: {reshape_end_time - reshape_start_time}")

        update_values = torch.tensor(np.zeros(shape=(len(c_labels), 1)), dtype=torch.float).to(dev)
        start_time = datetime.datetime.now()

        enumartion_start_time = datetime.datetime.now()

        for i, g in enumerate(grads):
            cluster_id = c_labels[i].item()
            update_values[cluster_id] += g
        
        
        enumeration_end_time = datetime.datetime.now()

        print(f"Enumeration time took: {enumeration_end_time - enumartion_start_time}")

        updated_grads = torch.tensor(np.zeros(len(grads)), dtype=torch.float).to(dev)
        
        # For each grad
        # Find the c_label
        # Find all the grads with the same c_
        
        for i in range(len(grads)):
            cluster_id = c_labels[i].item()
            grad = update_values[c_labels[i].item()]
            updated_grads[i] = grad
            
        updated_grads = updated_grads.reshape(grad_original_shape)

        # print(f"Retrain end time {datetime.datetime.now()}")
        end_time = datetime.datetime.now()
        print(f"Weight vector with {i} gradients took {end_time - start_time} to cluster gradient updates.")

        return updated_grads
    
    return hook

In [5]:
quantize_k_means(model)

In [6]:
optimizer = config['optimizer'](model.parameters(), lr=0.01, momentum=0.5)
    
wrapper.train(10, optimizer, 1, config['loss_fn'])








Test set: Average loss: 1.1928, Accuracy: 5764/10000 (58%)



(0.5764,
 OrderedDict([('conv1.weight',
               tensor([[[[-0.1792, -0.3007, -0.1634,  0.1001, -0.0307],
                         [-0.0394, -0.0261,  0.1543,  0.1128, -0.0070],
                         [ 0.0913,  0.1881,  0.1692, -0.0329, -0.0872],
                         [ 0.1587,  0.2333,  0.0790,  0.0762,  0.0477],
                         [ 0.1308,  0.0547, -0.0104, -0.0324,  0.1052]],
               
                        [[-0.0994, -0.1031, -0.0246, -0.0360,  0.1676],
                         [-0.0376,  0.0994,  0.1251, -0.0402, -0.1579],
                         [-0.0198,  0.2302,  0.0558, -0.0179, -0.2188],
                         [ 0.2220,  0.0784,  0.0726, -0.1091, -0.1700],
                         [ 0.0931, -0.0725, -0.0874, -0.1331, -0.0165]],
               
                        [[-0.1261, -0.0620,  0.1632,  0.1851,  0.2982],
                         [ 0.0472,  0.0953,  0.1070,  0.2150,  0.0227],
                         [-0.0247,  0.2783,  0.2085,  0.0753, 