In [1]:
from mpl.models import *

In [2]:
# Build the model that the rest of the notebook prunes and inspects.
# vgg16 is not defined in this notebook; presumably it comes from the
# star import of mpl.models above -- TODO confirm.
model = vgg16()



In [3]:
# Baseline report: parameter counts per layer before any pruning is applied.
print("Before pruning:")
# With display=True the call prints one "remaining/all" line per layer; the
# value echoed as the cell result appears to be a (remaining, total) pair.
model.calc_num_prunable_params(display=True)

Before pruning:
Layer name: features.0. remaining/all: 1792/1792 = 1.0
Layer name: features.2. remaining/all: 36928/36928 = 1.0
Layer name: features.5. remaining/all: 73856/73856 = 1.0
Layer name: features.7. remaining/all: 147584/147584 = 1.0
Layer name: features.10. remaining/all: 295168/295168 = 1.0
Layer name: features.12. remaining/all: 590080/590080 = 1.0
Layer name: features.14. remaining/all: 590080/590080 = 1.0
Layer name: features.17. remaining/all: 1180160/1180160 = 1.0
Layer name: features.19. remaining/all: 2359808/2359808 = 1.0
Layer name: features.21. remaining/all: 2359808/2359808 = 1.0
Layer name: features.24. remaining/all: 2359808/2359808 = 1.0
Layer name: features.26. remaining/all: 2359808/2359808 = 1.0
Layer name: features.28. remaining/all: 2359808/2359808 = 1.0
Layer name: classifier.0. remaining/all: 102764544/102764544 = 1.0
Layer name: classifier.3. remaining/all: 16781312/16781312 = 1.0
Layer name: classifier.6. remaining/all: 4097000/4097000 = 1.0
Total: re

(138357544, 138357544)

In [4]:
def calcScaleZeroPoint(min_val, max_val, num_bits=8):
    """Compute affine quantization parameters for the value range [min_val, max_val].

    The quantized grid is the unsigned range [0, 2**num_bits - 1]. A real value
    ``v`` is meant to be encoded as ``zero_point + v / scale``.

    The range is first extended so that it always contains 0. Without this, a
    range that excludes zero (for example [1, 2] or [-2, -1]) produces a zero
    point outside the grid; clamping it back then makes every input saturate to
    the same code. Ranges that already contain zero are unaffected.

    Args:
        min_val: Smallest value to represent (number or 0-dim tensor).
        max_val: Largest value to represent (number or 0-dim tensor).
        num_bits: Width of the quantized representation in bits.

    Returns:
        A ``(scale, zero_point)`` tuple. ``scale`` is the real-valued step
        between adjacent codes; ``zero_point`` is the integer code for 0.
    """
    qmin = 0.
    qmax = 2. ** num_bits - 1.

    # Make sure 0 lies inside the range so the zero point lands on the grid.
    min_val = min(min_val, 0.)
    max_val = max(max_val, 0.)
    span = max_val - min_val

    if span == 0:
        # Only reachable when min_val == max_val == 0: use a tiny positive
        # scale so the divisions below (and in callers) stay finite.
        scale = span + (0.001 / (qmax - qmin))
    else:
        scale = span / (qmax - qmin)

    initial_zero_point = (qmin - min_val) / scale

    # With 0 inside the range this is already within [qmin, qmax]; the clamp
    # only guards against floating-point drift at the edges.
    if initial_zero_point < qmin:
        zero_point = qmin
    elif initial_zero_point > qmax:
        zero_point = qmax
    else:
        zero_point = initial_zero_point

    # Truncation (not rounding) is kept deliberately so that zero points for
    # ranges that already spanned zero stay identical to earlier results.
    zero_point = int(zero_point)

    return scale, zero_point

In [5]:
def quantize_tensor(x, num_bits=8, min_val=None, max_val=None):
    """Quantize ``x`` onto the unsigned integer grid [0, 2**num_bits - 1].

    Args:
        x: Tensor to quantize. It must support ``min()``, ``max()``, division,
            and the in-place ``clamp_`` / ``round_`` methods (presumably a
            ``torch.Tensor`` -- confirm against callers).
        num_bits: Width of the quantized representation in bits.
        min_val: Lower bound of the range to represent. When ``None`` it is
            taken from ``x.min()``.
        max_val: Upper bound of the range to represent. When ``None`` it is
            taken from ``x.max()``.

    Returns:
        A ``(q_x, scale, zero_point)`` tuple: the quantized values converted
        with ``.int()``, plus the parameters returned by ``calcScaleZeroPoint``.
    """
    # Test each bound against None on its own. A truthiness test would treat
    # an explicit bound of 0 as "not given", and supplying only one bound would
    # leave the other as None.
    if min_val is None:
        min_val = x.min()
    if max_val is None:
        max_val = x.max()

    qmin = 0.
    qmax = 2. ** num_bits - 1.

    scale, zero_point = calcScaleZeroPoint(min_val, max_val, num_bits)
    # x / scale allocates a new tensor, so the in-place ops below do not touch x.
    q_x = zero_point + x / scale
    q_x.clamp_(qmin, qmax).round_()
    return q_x.int(), scale, zero_point

In [6]:
def dequantize_tensor(scale,x,zero_point):
    """Convert quantized values back to floating point.

    Computes ``float(scale) * (x.float() - zero_point)``.

    Args:
        scale: Quantization step; converted with ``float()`` before use.
        x: Quantized tensor; converted with ``.float()`` before the shift.
        zero_point: Offset subtracted from the quantized values.

    Returns:
        The dequantized floating-point tensor.
    """
    step = float(scale)
    centered = x.float() - zero_point
    return step * centered

In [7]:
print("After pruning:")
# Ask for a 0.5 pruning fraction on every prunable layer (one list entry per
# layer). The return value is discarded, so the model is presumably modified
# in place -- confirm against prune_by_pct.
model.prune_by_pct([0.5 for _ in model.prunable_layers])
# Same report as before pruning, for a side-by-side comparison of the counts.
model.calc_num_prunable_params(display=True)

After pruning:
Layer name: features.0. remaining/all: 928/1792 = 0.5178571428571429
Layer name: features.2. remaining/all: 18496/36928 = 0.5008665511265165
Layer name: features.5. remaining/all: 36992/73856 = 0.5008665511265165
Layer name: features.7. remaining/all: 73856/147584 = 0.5004336513443192
Layer name: features.10. remaining/all: 147712/295168 = 0.5004336513443192
Layer name: features.12. remaining/all: 295168/590080 = 0.5002169197396963
Layer name: features.14. remaining/all: 295168/590080 = 0.5002169197396963
Layer name: features.17. remaining/all: 590336/1180160 = 0.5002169197396963
Layer name: features.19. remaining/all: 1180160/2359808 = 0.5001084834020395
Layer name: features.21. remaining/all: 1180160/2359808 = 0.5001084834020395
Layer name: features.24. remaining/all: 1180160/2359808 = 0.5001084834020395
Layer name: features.26. remaining/all: 1180160/2359808 = 0.5001084834020395
Layer name: features.28. remaining/all: 1180160/2359808 = 0.5001084834020395
Layer name: c

(69185487, 138357544)

In [8]:
# Inspect the model's parameters after pruning.
# NOTE(review): this echoes every tensor in full, which makes for a very long
# cell output; consider showing only parameter names and shapes instead.
model.state_dict()

OrderedDict([('features.0.weight',
              tensor([[[[ 0.0321,  0.0757,  0.0288],
                        [-0.0592, -0.0297, -0.1027],
                        [-0.0511,  0.0442,  0.0798]],
              
                       [[-0.0335,  0.0814,  0.0667],
                        [-0.0398,  0.0456,  0.0231],
                        [-0.0987,  0.1050, -0.0408]],
              
                       [[-0.0067,  0.0493,  0.0285],
                        [-0.0273, -0.0525, -0.0165],
                        [-0.0656, -0.0154, -0.0588]]],
              
              
                      [[[ 0.0885,  0.0688, -0.0723],
                        [ 0.0110, -0.0409, -0.0338],
                        [ 0.0590, -0.1153, -0.0557]],
              
                       [[-0.0038, -0.0066, -0.0516],
                        [-0.1402, -0.0308, -0.0270],
                        [-0.0018, -0.0148, -0.0532]],
              
                       [[-0.0641, -0.0330,  0.0097],
                     