# Parameter Histograms

This notebook loads a model and draws histograms of its parameter tensors.

In [None]:
import torch
import torch.nn as nn
import scipy.stats as ss
import numpy as np
import matplotlib.pyplot as plt
import distiller
import distiller.models as models

# Pretty matplotlib plots. The 'seaborn' style was renamed 'seaborn-v0_8' in
# matplotlib 3.6 and the old name was removed in 3.8; fall back for older versions.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')

## Load your model

In [None]:
# Worth comparing: the weights of a non-pretrained model are Normally distributed,
# while the pretrained model's weights follow a different distribution.
# device_ids=-1 loads the model onto the CPU.
model = models.create_model(
    arch='resnet50',
    dataset='imagenet',
    pretrained=True,
    device_ids=-1,
)

# Optionally load your compressed model instead:
# distiller.apputils.load_checkpoint(model, <path-to-your-checkpoint-file>)

## Plot the distributions

We plot the weight distribution of each convolution layer, overlaid with the Gaussian and Laplacian distributions fitted to those weights.

In [None]:
def getSparsity(x):
    """Return the fraction of elements of `x` that are exactly zero.

    Args:
        x: a NumPy array (any shape) or anything `np.asarray` accepts.

    Returns:
        A float in [0, 1]. An empty input has no zeros, so it returns 0.0
        instead of dividing by zero.
    """
    x = np.asarray(x)
    if x.size == 0:
        return 0.0
    # count_nonzero counts in place instead of copying every nonzero element.
    return 1 - np.count_nonzero(x) / x.size

REMOVE_ZEROS = False  # if True, drop exact zeros so they don't dominate the histogram
nbins = 500           # histogram bins; also the number of points on the fitted curves

for name, param in model.named_parameters():
    if param.dim() != 4:
        # Only 4-D parameters (convolution kernels) are plotted.
        continue

    shape_str = "x".join(map(str, param.shape))
    weights = param.detach().cpu().numpy().flatten()
    # Computed before any zeros are removed, so it describes the whole tensor.
    sparsity = getSparsity(weights)

    if REMOVE_ZEROS:
        # Optionally remove zeros (lots of zeros will dominate the histogram and the
        # other data will be hard to see).
        weights = weights[weights != 0]
        if weights.size == 0:
            # Nothing left to fit or plot (e.g. a fully pruned layer).
            print("{} - {}: all weights are zero, skipped".format(name, shape_str))
            continue

    # ndarray.min/max are vectorized; the builtins iterate element by element in Python.
    x = np.linspace(weights.min(), weights.max(), nbins)

    # Fit the data to the Normal distribution.
    gauss_loc, gauss_scale = ss.norm.fit(weights)
    weights_gauss_fitted = ss.norm.pdf(x, loc=gauss_loc, scale=gauss_scale)

    # Fit the data to the Laplacian distribution. Its `scale` is the diversity b
    # (std = sqrt(2) * b), not a standard deviation, hence the separate names.
    laplace_loc, laplace_scale = ss.laplace.fit(weights)
    weights_laplace_fitted = ss.laplace.pdf(x, loc=laplace_loc, scale=laplace_scale)

    # Create the figure *before* drawing so the size applies to this plot.
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.hist(weights, histtype='stepfilled', cumulative=False, bins=nbins, density=True)
    ax.plot(x, weights_gauss_fitted, label='gauss')
    ax.plot(x, weights_laplace_fitted, label='laplace')

    ax.set_title(name + " - " + shape_str + (
        ' - sparsity: {:.0%}'.format(sparsity) if REMOVE_ZEROS else ''))
    ax.set_xlabel('weight value')
    ax.set_ylabel('density')
    ax.legend()
    plt.show()
    # Release the figure; many conv layers would otherwise accumulate open figures.
    plt.close(fig)