In [15]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
import numpy as np
import argparse
import math

test_transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Resize((64, 64)),
     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])

batch_size = 1024

trainset = CIFAR10(root='./data', train=True, download=True, transform=test_transform)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4)

testset = CIFAR10(root='./data', train=False, download=True, transform=test_transform)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=True, num_workers=4)

# classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [16]:
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 32, 3),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Conv2d(32, 64, 3),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.MaxPool2d(2, 2),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(64, 128, 3),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Conv2d(128, 256, 3),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.MaxPool2d(2, 2),
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(256, 512, 3),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Conv2d(512, 1024, 3),
            nn.ReLU(),
            nn.Dropout(0.4),
        )

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, 64)
        self.fc4 = nn.Linear(64, 10)
        self.dropout = nn.Dropout(0.5)


    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.avgpool(x)

        x = torch.flatten(x, 1)
        x = self.dropout(F.relu(self.fc1(x)))
        x = self.dropout(F.relu(self.fc2(x)))
        x = self.dropout(F.relu(self.fc3(x)))
        x = self.fc4(x)
        return x

net = Net()
PATH = './cifar_net.pth'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.load_state_dict(torch.load(PATH))
net.to(device)

Net(
  (layer1): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): Dropout(p=0.4, inplace=False)
    (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
    (4): ReLU()
    (5): Dropout(p=0.4, inplace=False)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (layer2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): Dropout(p=0.4, inplace=False)
    (3): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
    (4): ReLU()
    (5): Dropout(p=0.4, inplace=False)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (layer3): Sequential(
    (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): Dropout(p=0.4, inplace=False)
    (3): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1))
    (4): ReLU()
    (5): Dropout(p=0.4, inplace=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=1)
  (fc1

In [17]:
import math

# list(net.layer1.parameters())[1].view(-1).sort()
# p = 10

# flatten = list(net.layer1.parameters())[0].view(-1)


# n = list(net.layer1.parameters())[0]
# sorted_flatten = flatten.sort()
# # print(sorted_flatten)

# flatten_length = len(flatten)
# num_of_prune = math.ceil(flatten_length * p / 100)
# boundary = sorted_flatten[0][num_of_prune].item()
# print(math.ceil(flatten_length * p / 100))
# list(net.named_parameters())

In [18]:
x = torch.tensor([[1,2],
                [3,4]])
print(x != 2)
print(x := torch.where(x != 2, x, 0))
print(x)

tensor([[ True, False],
        [ True,  True]])
tensor([[1, 0],
        [3, 4]])
tensor([[1, 0],
        [3, 4]])


In [19]:
# sorted_flatten[0][num_of_prune].item()

In [20]:
# print(torch.where(n >= boundary, n, torch.FloatTensor([0])))

In [21]:
# n

In [22]:
# list(net.layer1.parameters())[0]

In [23]:
p = 10

target = list(net.layer1.parameters())[0]
flatten = target.view(-1)
flatten_length = len(flatten)
# sorted_flatten = flatten.sort()
num_of_prune = math.ceil(flatten_length * p / 100)
boundary = flatten.sort()[0][num_of_prune].item()

# list(net.layer1.parameters())[0] = torch.where(list(net.layer1.parameters())[0] >= boundary, list(net.layer1.parameters())[0], torch.FloatTensor([0]))
# print(list(net.layer1.parameters())[0])

In [24]:
# print(sorted_flatten[0])
print(target.shape)
# print(len(sorted_flatten))
print(boundary)
target = torch.where(target >= boundary, target, torch.cuda.FloatTensor([0]))
# print(target)

torch.Size([32, 3, 3, 3])
-0.1635768711566925


In [25]:
net.state_dict()['layer1.0.weight'].data.copy_(target)

tensor([[[[ 0.0000e+00, -1.3643e-01,  0.0000e+00],
          [ 1.3870e-02,  2.6924e-02,  0.0000e+00],
          [ 2.3314e-01,  1.4097e-01,  1.9443e-01]],

         [[-5.2426e-02,  2.3522e-02, -3.7110e-02],
          [-1.8182e-02,  2.3792e-02,  5.1683e-02],
          [ 3.6065e-02, -1.3359e-02,  1.7108e-01]],

         [[ 3.6860e-02,  1.2776e-01, -5.0086e-02],
          [ 6.3916e-02,  2.1190e-02, -3.3019e-02],
          [-1.1552e-01,  1.3519e-01, -1.1535e-01]]],


        [[[-9.3218e-02,  4.8015e-02, -2.3696e-02],
          [ 8.6984e-03, -1.4886e-01, -1.4329e-01],
          [ 2.6592e-01,  2.2437e-01, -1.2396e-01]],

         [[ 6.8770e-02, -1.2386e-02,  0.0000e+00],
          [ 1.7918e-01, -7.4675e-02,  1.8475e-02],
          [ 1.7481e-01, -9.4195e-02, -7.6542e-03]],

         [[-1.1610e-01,  1.1347e-01,  4.0452e-02],
          [ 2.2952e-01, -1.3286e-01, -7.3434e-02],
          [ 6.4991e-02, -1.1266e-01, -6.8987e-02]]],


        [[[-1.3823e-02,  9.3859e-02,  1.0216e-01],
          [ 8.6

In [26]:
# print(net.state_dict()['layer1.0.weight'].data)
# target
# net.state_dict()

net.state_dict()

OrderedDict([('layer1.0.weight',
              tensor([[[[ 0.0000e+00, -1.3643e-01,  0.0000e+00],
                        [ 1.3870e-02,  2.6924e-02,  0.0000e+00],
                        [ 2.3314e-01,  1.4097e-01,  1.9443e-01]],
              
                       [[-5.2426e-02,  2.3522e-02, -3.7110e-02],
                        [-1.8182e-02,  2.3792e-02,  5.1683e-02],
                        [ 3.6065e-02, -1.3359e-02,  1.7108e-01]],
              
                       [[ 3.6860e-02,  1.2776e-01, -5.0086e-02],
                        [ 6.3916e-02,  2.1190e-02, -3.3019e-02],
                        [-1.1552e-01,  1.3519e-01, -1.1535e-01]]],
              
              
                      [[[-9.3218e-02,  4.8015e-02, -2.3696e-02],
                        [ 8.6984e-03, -1.4886e-01, -1.4329e-01],
                        [ 2.6592e-01,  2.2437e-01, -1.2396e-01]],
              
                       [[ 6.8770e-02, -1.2386e-02,  0.0000e+00],
                        [ 1.7918e-01, -7.4

In [311]:
correct = 0
total = 0

with torch.no_grad():
    for images, labels in trainloader:
        images, labels = images.to(device), labels.to(device)
        
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Trainset Acc: %d %%' % (100 * correct / total))

Trainset Acc: 82 %


In [312]:
correct = 0
total = 0

with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Testset Acc: %d %%' % (100 * correct / total))

Testset Acc: 78 %


In [31]:
layer_list = ['layer1.0.weight','layer1.3.weight',
              'layer2.0.weight','layer2.3.weight',
              'layer3.0.weight','layer3.3.weight',
              'fc1.weight','fc2.weight','fc3.weight','fc4.weight']

print(type(net.state_dict()[layer_list[0]].data))
print(type(net.state_dict()[layer_list[0]]))
# list(net.layer1.parameters())[0] == net.state_dict()[layer_list[0]]

<class 'torch.Tensor'>
<class 'torch.Tensor'>


tensor([[[[True, True, True],
          [True, True, True],
          [True, True, True]],

         [[True, True, True],
          [True, True, True],
          [True, True, True]],

         [[True, True, True],
          [True, True, True],
          [True, True, True]]],


        [[[True, True, True],
          [True, True, True],
          [True, True, True]],

         [[True, True, True],
          [True, True, True],
          [True, True, True]],

         [[True, True, True],
          [True, True, True],
          [True, True, True]]],


        [[[True, True, True],
          [True, True, True],
          [True, True, True]],

         [[True, True, True],
          [True, True, True],
          [True, True, True]],

         [[True, True, True],
          [True, True, True],
          [True, True, True]]],


        [[[True, True, True],
          [True, True, True],
          [True, True, True]],

         [[True, True, True],
          [True, True, True],
          [Tru