In [4]:
import torch
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
import time

import warnings
warnings.filterwarnings("ignore")

# CIFAR10

In [5]:
batch_size = 50
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4),
            transforms.ToTensor(),
            normalize,
        ])

In [16]:
training_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, transform=transform)

train_loader = torch.utils.data.DataLoader(dataset=training_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

Files already downloaded and verified


In [17]:
batch = next(iter(train_loader))
print(f'Images shape: {batch[0].shape}')
print(f'Labels shape: {batch[1].shape}')

Images shape: torch.Size([50, 3, 32, 32])
Labels shape: torch.Size([50])


In [18]:
#testing
def test_model(model, test_loader):
    t = time.time()
    batch_size = 50
    with torch.no_grad():
        n_correct = 0
        n_samples = 0
        for images, labels in tqdm(test_loader):
            images = images
            labels = labels
            outputs = model(images).view(batch_size,10)

            _, predictions = torch.max(outputs, 1)
            n_samples += labels.shape[0]
            n_correct += (predictions==labels).sum().item()

    acc = 100.0*n_correct/n_samples
    print(f'accuracy: {acc}')
    print(f'testing time: {time.time()-t:.2f}.sec')

In [19]:
def get_n_params(model):
    pp=0
    for p in list(model.parameters()):
        nn=1
        for s in list(p.size()):
            nn = nn*s
        pp += nn
    print(f'number of model params: {pp}')
    return pp

# ResNet20

In [20]:
from ResNet import ResNet20
model = ResNet20()
model.load_state_dict(torch.load('model.pth', map_location='cpu'))

<All keys matched successfully>

In [21]:
test_model(model, test_loader)

100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [00:23<00:00,  8.63it/s]

accuracy: 85.81
testing time: 23.17.sec





# First method

In [22]:
from first import FirstPruning

In [23]:
compression_ratio = 0.8
first_model = FirstPruning(model, compression_ratio)

In [13]:
TestModel(first_model, test_loader)

100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [01:08<00:00,  2.91it/s]

accuracy: 10.64
testing time: 68.65.sec





# Second method

In [24]:
from second import SecondPruning

In [25]:
compression_ratio = 0.8
second_model = SecondPruning(model, compression_ratio)

In [27]:
test_model(second_model, test_loader)

100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [00:15<00:00, 13.24it/s]

accuracy: 13.53
testing time: 15.11.sec



