In [1]:
import torchvision
import numpy as np
from Torch_Pruning.examples.cifar_minimal.cifar_resnet import ResNet18
import Torch_Pruning.examples.cifar_minimal.cifar_resnet as resnet

import Torch_Pruning.torch_pruning as tp
import argparse
import torch
from torchvision.datasets import CIFAR10
from torchvision import transforms
import torch.nn.functional as F
import torch.nn as nn 
import numpy as np 
import os

In [2]:
def get_dataloader(data_dir='./data', batch_size=256, num_workers=2):
    """Build the CIFAR-10 training and test data loaders.

    The training images go through a random 32x32 crop (with 4 pixels of
    padding) and a random horizontal flip before being converted to tensors;
    the test images are only converted to tensors. Both datasets are
    requested with ``download=True``.

    Args:
        data_dir: directory the CIFAR-10 files are stored in / downloaded to.
        batch_size: number of samples per batch for both loaders.
        num_workers: number of loader worker processes for both loaders.

    Returns:
        ``(train_loader, test_loader)``.
    """
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    train_set = CIFAR10(data_dir, train=True, transform=train_transform, download=True)
    test_set = CIFAR10(data_dir, train=False, transform=test_transform, download=True)

    # Reshuffle the training set every epoch: without shuffle=True the
    # DataLoader walks the dataset in the same fixed order each time, so SGD
    # would see identical batch compositions in every epoch.
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    # Evaluation order does not affect accuracy, so keep it deterministic.
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_loader, test_loader

def eval(model, test_loader):
    """Return the fraction of samples in `test_loader` that `model` classifies correctly.

    The model is moved to CUDA when it is available (CPU otherwise) and is
    switched to evaluation mode; it is left on that device and in that mode
    when the function returns. Gradients are disabled for the whole pass.

    Note that this function shadows the built-in ``eval`` in this notebook.

    Args:
        model: module called as ``model(img)``; the predicted class is the
            index of the largest value along dim 1 of its output.
        test_loader: iterable yielding ``(img, target)`` tensor pairs.

    Returns:
        Number of correct predictions divided by the number of targets seen.

    Raises:
        ZeroDivisionError: if `test_loader` yields no batches.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    model.to(device)
    model.eval()

    n_hits, n_seen = 0, 0
    with torch.no_grad():
        for batch, labels in test_loader:
            logits = model(batch.to(device))
            # max(1) returns (values, indices); the indices are the predicted classes.
            predicted = logits.max(1)[1].detach().cpu().numpy()
            labels = labels.cpu().numpy()
            n_hits += (predicted == labels).sum()
            n_seen += len(labels)
    return n_hits / n_seen

def train_model(model, train_loader, test_loader, epochs, round_num,
                lr=0.1, momentum=0.9, weight_decay=1e-4,
                lr_step_size=70, lr_gamma=0.1):
    """Train `model` with SGD and checkpoint it whenever test accuracy improves.

    After every epoch the model is scored with `eval` on `test_loader`. Each
    time the score beats the best one seen so far, the whole model object is
    written with `torch.save` to ``resnet18-round<round_num>.pth`` in the
    current working directory, overwriting the earlier file for that round.

    The model is trained in place. When the function returns it holds the
    weights of the *last* epoch, which are not necessarily the checkpointed
    best ones.

    Args:
        model: network to train; moved to CUDA when available, otherwise CPU.
        train_loader: iterable of ``(img, target)`` batches used for the updates.
        test_loader: iterable of ``(img, target)`` batches handed to `eval`.
        epochs: number of passes over `train_loader`.
        round_num: integer embedded in the checkpoint file name.
        lr: initial SGD learning rate.
        momentum: SGD momentum.
        weight_decay: SGD weight decay.
        lr_step_size: number of epochs between StepLR decays. The scheduler is
            stepped once per epoch, so with the default of 70 the learning
            rate is never decayed in runs shorter than 70 epochs.
        lr_gamma: multiplicative StepLR decay factor.

    Returns:
        ``(best_acc, mean_acc)``: the highest per-epoch test accuracy and the
        mean of all per-epoch test accuracies.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Move the model first so the optimizer is built from the parameters as
    # they live on the training device.
    model.to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum,
                                weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, lr_step_size, lr_gamma)

    acc_s = []
    # Starts below any reachable accuracy so the first epoch always checkpoints.
    best_acc = -1
    for epoch in range(epochs):
        model.train()
        for img, target in train_loader:
            img, target = img.to(device), target.to(device)
            optimizer.zero_grad()
            out = model(img)
            loss = F.cross_entropy(out, target)
            loss.backward()
            optimizer.step()

        # eval() switches the model to evaluation mode itself; the next epoch
        # re-enables training mode at the top of the loop.
        acc = eval(model, test_loader)
        acc_s.append(acc)
        if best_acc<acc:
            # NOTE(review): the full module is pickled rather than a state_dict,
            # presumably because pruning changes layer shapes between rounds.
            torch.save( model, 'resnet18-round%d.pth'%(round_num) )
            best_acc=acc
        scheduler.step()
    return best_acc, np.mean(acc_s)

def prune_model(model, strategy_name='l1', block_prune_probs=None):
    """Prune the output channels of both convolutions in every BasicBlock.

    The model is moved to the CPU and a Torch-Pruning dependency graph is
    built from a dummy ``1x3x32x32`` input. For the k-th ``resnet.BasicBlock``
    found by ``model.modules()``, ``conv1`` and ``conv2`` are each pruned by
    the k-th fraction in `block_prune_probs`. The model is modified in place,
    printed, and returned.

    Args:
        model: network containing ``resnet.BasicBlock`` modules.
        strategy_name: ``'l1'`` selects ``tp.strategy.L1Strategy``; any other
            value selects ``tp.strategy.RandomStrategy``.
        block_prune_probs: per-block pruning fractions, one per BasicBlock in
            traversal order. Defaults to
            ``[0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.3]``. Extra entries beyond
            the number of blocks are ignored.

    Returns:
        The same `model` object, after pruning.

    Raises:
        IndexError: if the model has more BasicBlocks than there are entries
            in `block_prune_probs`. This is raised before anything is pruned,
            so the model is never left half-pruned.
    """
    if block_prune_probs is None:
        block_prune_probs = [0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.3]

    model.cpu()

    # Collect the blocks up front so a ratio/block mismatch is detected before
    # any layer has been modified.
    blocks = [m for m in model.modules() if isinstance(m, resnet.BasicBlock)]
    if len(blocks) > len(block_prune_probs):
        raise IndexError(
            "model has %d BasicBlocks but only %d pruning ratios were given"
            % (len(blocks), len(block_prune_probs)))

    DG = tp.DependencyGraph().build_dependency( model, torch.randn(1, 3, 32, 32) )

    def prune_conv(conv, amount):
        # Pick the output channels to drop, then let the dependency graph
        # produce and execute the plan for this convolution.
        strategy = tp.strategy.L1Strategy() if strategy_name == 'l1' else tp.strategy.RandomStrategy()
        pruning_index = strategy(conv.weight, amount=amount)
        plan = DG.get_pruning_plan(conv, tp.prune_conv_out_channel, pruning_index)
        plan.exec()

    for block, amount in zip(blocks, block_prune_probs):
        prune_conv(block.conv1, amount)
        prune_conv(block.conv2, amount)
    print(model)
    return model

In [3]:
def count_params(model):
    """Print and return the total number of elements in `model.parameters()`.

    Args:
        model: module whose parameters are counted.

    Returns:
        The element count as a plain Python ``int`` (0 for a module without
        parameters). The count is also printed in millions with one decimal.
    """
    # numel() is the tensor's own element count; it avoids a numpy round-trip
    # over each size tuple and yields a plain int instead of a numpy scalar.
    params = sum(p.numel() for p in model.parameters())
    print("Number of Parameters: %.1fM"%(params/1e6))
    return params

In [4]:
# Build the CIFAR-10 loaders once; the training rounds below reuse them.
train_loader, test_loader = get_dataloader()

Files already downloaded and verified
Files already downloaded and verified


In [5]:
# ResNet18 constructed with num_classes=10; the rounds below prune and
# retrain this same object in place.
model = ResNet18(num_classes=10)

In [6]:
# Per-round results, one list entry appended per round:
#   'best'   - best test accuracy returned by train_model
#   'mean'   - mean test accuracy returned by train_model
#   'params' - parameter count returned by count_params
#   'size'   - size of the round's checkpoint file, in bytes / 1024 / 1024
#   'out_ch' - list of weight.shape[0] for every Conv2d in the model
model_dict = {metric: [] for metric in ('best', 'mean', 'params', 'size', 'out_ch')}

In [None]:
# Five rounds: round 0 trains the unpruned model, each later round prunes the
# current model first and then retrains it. Results go into model_dict.
for i in range(5):
    print(f'------ Round {i} ------')
    if i > 0:
        prune_model(model, 'l1')
    param = count_params(model)
    # Every round trains for 30 epochs (the original expression
    # `30 if i == 0 else 30` had identical branches).
    best, mean = train_model(model, train_loader, test_loader, 30, i)
    model_dict['best'].append(best)
    model_dict['mean'].append(mean)
    model_dict['params'].append(param)
    # Size of the checkpoint train_model wrote for this round, in bytes / 1024 / 1024.
    model_dict['size'].append(os.path.getsize(f'resnet18-round{i}.pth') / 1024 / 1024)

    # Output-channel count of every Conv2d, in model.modules() order.
    ch = [m.weight.shape[0] for m in model.modules() if isinstance(m, nn.Conv2d)]
    model_dict['out_ch'].append(ch)

------ Round 0 ------
Number of Parameters: 11.2M
------ Round 1 ------
ResNet(
  (conv1): Conv2d(3, 53, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(53, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(53, 58, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(58, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(58, 53, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(53, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(53, 58, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(58, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(58, 53, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), b

In [None]:
# Display the per-round results collected by the loop above.
model_dict