In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR10

from functools import reduce
import numpy as np
import argparse
import math

# Preprocessing: convert to tensor, upscale the 32x32 CIFAR-10 images to 64x64,
# then normalize per channel. The mean/std values appear to be the usual CIFAR-10
# training-set statistics.
# NOTE(review): this is assumed to match the preprocessing used when
# ./cifar_net.pth was trained — confirm against the training code.
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Resize((64, 64)),
     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])

batch_size = 1024

trainset = CIFAR10(root='./data', train=True, download=True, transform=transform)
testset = CIFAR10(root='./data', train=False, download=True, transform=transform)

trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4)
# The test split is only used to measure accuracy, which does not depend on sample
# order; a fixed order keeps evaluation runs reproducible and avoids needless shuffling.
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4)


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 32, 3),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Conv2d(32, 64, 3),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.MaxPool2d(2, 2),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(64, 128, 3),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Conv2d(128, 256, 3),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.MaxPool2d(2, 2),
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(256, 512, 3),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Conv2d(512, 1024, 3),
            nn.ReLU(),
            nn.Dropout(0.4),
        )

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, 64)
        self.fc4 = nn.Linear(64, 10)
        self.dropout = nn.Dropout(0.5)


    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.avgpool(x)

        x = torch.flatten(x, 1)
        x = self.dropout(F.relu(self.fc1(x)))
        x = self.dropout(F.relu(self.fc2(x)))
        x = self.dropout(F.relu(self.fc3(x)))
        x = self.fc4(x)
        return x

Files already downloaded and verified
Files already downloaded and verified


In [2]:
net = Net()
PATH = './cifar_net.pth'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# map_location lets a checkpoint saved from a GPU run load on a CPU-only machine,
# which the CPU fallback above otherwise cannot do.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
net.load_state_dict(torch.load(PATH, map_location=device))
# eval() turns off the Dropout layers so inference in later cells is deterministic.
# It returns the module, so the cell still displays the model summary.
net.to(device).eval()

Net(
  (layer1): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): Dropout(p=0.4, inplace=False)
    (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
    (4): ReLU()
    (5): Dropout(p=0.4, inplace=False)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (layer2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): Dropout(p=0.4, inplace=False)
    (3): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
    (4): ReLU()
    (5): Dropout(p=0.4, inplace=False)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (layer3): Sequential(
    (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): Dropout(p=0.4, inplace=False)
    (3): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1))
    (4): ReLU()
    (5): Dropout(p=0.4, inplace=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=1)
  (fc1

In [3]:
# Magnitude pruning: in each listed conv/linear weight tensor, zero the p% of
# entries with the smallest absolute value. Biases are not touched.
p = 1  # percent of each layer's weights to prune (0-100)
total_W = 0  # weights examined across the listed layers
total_Z = 0  # of those, weights that are exactly zero after pruning

layers = ['layer1.0.weight', 'layer1.3.weight',
          'layer2.0.weight', 'layer2.3.weight',
          'layer3.0.weight', 'layer3.3.weight',
          'fc1.weight', 'fc2.weight', 'fc3.weight', 'fc4.weight']

# state_dict() tensors share storage with the module's parameters, so the in-place
# edits below write straight into `net`. Build the dict once rather than per lookup.
state = net.state_dict()

with torch.no_grad():
    for layer in layers:
        weight = state[layer]
        magnitude = weight.abs()
        n_weights = weight.numel()
        # Number of entries to prune. It is clamped so p <= 0 prunes nothing and
        # p >= 100 prunes everything; indexing a sorted list at k == n_weights
        # would raise IndexError.
        n_prune = min(max(math.ceil(n_weights * p / 100), 0), n_weights)

        if n_prune == n_weights:
            prune_mask = torch.ones_like(weight, dtype=torch.bool)
        else:
            # The (n_prune + 1)-th smallest magnitude (the old sorted(...)[n_prune]).
            # Everything strictly below it is zeroed. kthvalue runs on the tensor's
            # own device instead of Python-sorting millions of 0-d tensors.
            boundary = torch.kthvalue(magnitude.flatten(), n_prune + 1).values
            prune_mask = magnitude < boundary

        # masked_fill_ works on CPU and GPU alike; a torch.cuda.FloatTensor fill
        # value only works when the weights live on CUDA.
        weight.masked_fill_(prune_mask, 0.0)

        total_Z += (weight == 0).sum().item()
        total_W += n_weights

layer1.0.weight
0.0029217947740107775
layer1.3.weight
0.0006999003235250711
layer2.0.weight
0.0005954368971288204
layer2.3.weight
0.0005173255340196192
layer3.0.weight
0.00035453165764920413
layer3.3.weight
0.00021319415827747434
fc1.weight
0.00040379914571531117
fc2.weight
0.0005659299204126
fc3.weight
0.0008856855565682054
fc4.weight
0.004978759214282036


In [4]:
def count_correct(model, loader):
    """Count correct top-1 predictions of ``model`` over every batch in ``loader``.

    Each batch is moved to the device that holds the model's parameters.

    Returns:
        (n_correct, n_seen) as Python ints.
    """
    model_device = next(model.parameters()).device
    n_correct, n_seen = 0, 0
    for images, labels in loader:
        images, labels = images.to(model_device), labels.to(model_device)
        predicted = model(images).argmax(dim=1)
        n_seen += labels.size(0)
        n_correct += (predicted == labels).sum().item()
    return n_correct, n_seen


# A freshly constructed module is in training mode, where its Dropout layers
# randomly zero activations. That makes accuracy noisy and too low, so switch
# to inference mode before evaluating.
net.eval()
with torch.no_grad():
    train_correct, train_total = count_correct(net, trainloader)
    test_correct, test_total = count_correct(net, testloader)

print(f'p = {p / 100}')
print(f'Trainset Acc: {100 * train_correct / train_total: .1f}%')
print(f'Testset Acc: {100 * test_correct / test_total: .1f}%')
print(f'Number of Zeros: {total_Z}')
print(f'Number of Weights: {total_W}')
print(f'Pruned Ratio: {total_Z / total_W * 100: .1f}%')

p = 0.01
Trainset Acc:  10.0%
Testset Acc:  10.0%
Number of Zeros: 6815969
Number of Weights: 6884832
Pruned Ratio:  99.0%


In [9]:
# Scratch: torch.where keeps entries where the condition holds and substitutes 0 elsewhere.
x = torch.tensor([[-1, 2], [3, -4]])
keep = x != 2
print(keep)
x = torch.where(keep, x, 0)
print(x)
print(x)

tensor([[ True, False],
        [ True,  True]])
tensor([[-1,  0],
        [ 3, -4]])
tensor([[-1,  0],
        [ 3, -4]])


In [15]:
# Scratch: zero every entry whose magnitude is within the 2nd-smallest |x|,
# mirroring the pruning mask logic on a tiny tensor.
x = x.view(-1)
print(x)
by_magnitude = sorted(x, key=torch.abs)
y = by_magnitude[1].item()
bound = abs(y)
print(bound)
print(-1 <= x)
print(x <= 1)
outside = torch.logical_not(torch.logical_and(-bound <= x, x <= bound))
print(outside)
torch.where(outside, x, 0)

tensor([-1,  0,  3, -4])
1
tensor([ True,  True,  True, False])
tensor([ True,  True, False,  True])
tensor([False, False,  True,  True])


tensor([ 0,  0,  3, -4])

In [52]:
# Scratch check: magnitude of a negative integer (displays 3).
abs(0 - 3)

3

In [None]:
# Save the pruned weights under a file name tagged with the prune percentage p.
PATH = f'./cifar_net_{p}.pth'
torch.save(net.state_dict(), PATH)