In [1]:
# some basic imports and setups
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.models as models
import matplotlib.pyplot as plt
%matplotlib inline
import urllib
from PIL import Image
from torchvision import transforms

  warn(f"Failed to load image Python extension: {e}")


In [54]:
#load model - AlexNet
from torch.utils.model_zoo import load_url as load_state_dict_from_url
model_url = 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth'

class AlexNet_custom(nn.Module):

    def __init__(self, num_classes=1000):
        super(AlexNet_custom, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6)) # in case input image is larger
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), 256 * 6 * 6)
        x = self.classifier(x)
        return x

# load pretrained weights
def alexnet_custom(path = "model_weights/pretrained_AlexNet.pt", pretrained=True, progress=True, **kwargs):
    model = AlexNet_custom(**kwargs)
    net = torch.jit.load(path)
    params = net.state_dict()
#     state_dict = load_state_dict_from_url(model_url, progress=progress)
#     model.load_state_dict(state_dict)
    model.load_state_dict(params)
    return model

net = alexnet_custom(path = "model_weights/finetuned_AlexNet.pt") #pretrained on ImageNet - 1000 classes


In [37]:
def get_min_max(net):
    state_dict = net.state_dict()
    weights = list(state_dict.values())
    mini=0; maxi=0
    for i in weights:
        if torch.min(i)<mini:
            mini=torch.min(i)
        if torch.max(i)>maxi:
            maxi=torch.max(i)
    return (mini, maxi)

In [45]:
def get_threshold(p,mini,maxi):
    k = np.arange(mini, maxi, 0.001)
    return np.quantile(k, p)
    

In [47]:
def magnitude_sparsity(net, thr):
    sd = net.state_dict()
    thr = -1.506125
    for k in sd.keys():
        w = sd[k]
        sd[k] = w * (w > thr)
    net.load_state_dict(sd)
    return net

In [62]:
def random_sparsity(net, prob):
    sd = net.state_dict()
    val = np.random.choice(2, 1, p=[prob, 1-prob])
    for k in sd.keys():
        w = sd[k]
        sd[k] = w * val
    net.load_state_dict(sd)
    return net

In [63]:
get_min_max(net)
get_threshold(0.15,mini,maxi)
net = magnitude_sparsity(net, thr)

In [66]:
net = random_sparsity(net, 0.15)

In [67]:
#Evaluate
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
test_data = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(test_data, batch_size=4, shuffle=True, num_workers=2)

classes = ('Airplane', 'Car', 'Bird', 'Cat', 'Deer', 'Dog', 'Frog', 'Horse', 'Ship', 'Truck')

correct = 0
total = 0
with torch.no_grad():
    for i, data in enumerate(testloader, 0):    #test on 1000 images
        outputs = net(data[0])
        _, predicted = torch.max(outputs.data, 1)
        total += data[1].size(0)
        correct += (predicted == data[1]).sum().item()
        if i==1000:
            break
print('Accuracy of the network on the 1000 test images: %.2f %%' % (100 * correct / total))

Files already downloaded and verified
Accuracy of the network on the 1000 test images: 58.22 %


In [None]:
model_scripted = torch.jit.script(net) # Export to TorchScript
model_scripted.save('./model_weights/sparse_AlexNet.pt')