# prune_resnet18_cifar10.py (original listing: 129 lines, 116 loc, 4.93 KB)
import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
from cifar_resnet import ResNet18
import cifar_resnet as resnet
import torch_pruning as tp
import argparse
import torch
from torchvision.datasets import CIFAR10
from torchvision import transforms
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
# Command-line options. They are parsed once, at import time, into the
# module-level ``args`` that the functions below read directly.
parser = argparse.ArgumentParser()
# Which pipeline stage to run.
parser.add_argument('--mode', choices=['train', 'prune', 'test'],
                    required=True, type=str)
# DataLoader batch size (used for both the train and the test split).
parser.add_argument('--batch_size', default=256, type=int)
# When set, print the training loss every 10 iterations.
parser.add_argument('--verbose', default=False, action='store_true')
# Number of epochs per training / fine-tuning run.
parser.add_argument('--total_epochs', default=100, type=int)
# StepLR period: the learning rate is multiplied by 0.1 every this many epochs.
parser.add_argument('--step_size', default=70, type=int)
# Pruning round index; selects the checkpoint file resnet18-round<N>.pth.
parser.add_argument('--round', default=1, type=int)
args = parser.parse_args()
def get_dataloader(batch_size=None, data_root='./data', num_workers=2):
    """Build the CIFAR-10 train and test DataLoaders.

    Args:
        batch_size: samples per batch; defaults to ``args.batch_size`` from
            the command line when ``None``.
        data_root: directory where CIFAR-10 is stored / downloaded.
        num_workers: DataLoader worker processes for each loader.

    Returns:
        ``(train_loader, test_loader)``. The training split uses random-crop
        and horizontal-flip augmentation and is reshuffled every epoch; the
        test split is only converted to tensors and kept in a fixed order.
    """
    if batch_size is None:
        batch_size = args.batch_size
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    train_loader = torch.utils.data.DataLoader(
        CIFAR10(data_root, train=True, transform=train_transform, download=True),
        batch_size=batch_size,
        # Without shuffling, SGD sees the same sample order every epoch.
        shuffle=True,
        num_workers=num_workers,
    )
    test_loader = torch.utils.data.DataLoader(
        CIFAR10(data_root, train=False, transform=test_transform, download=True),
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )
    return train_loader, test_loader
def eval(model, test_loader):
    """Return the top-1 accuracy of ``model`` over every batch of ``test_loader``.

    The model is moved to CUDA when available (CPU otherwise) and switched to
    eval mode. Predictions are the argmax of the model output along dim 1.

    NOTE: the name shadows the builtin ``eval``; it is kept because the rest
    of this script calls the function by that name.

    Args:
        model: classifier whose output row-wise argmax is the predicted label.
        test_loader: iterable of ``(inputs, targets)`` batches.

    Returns:
        Fraction of correctly classified samples, as a Python float.

    Raises:
        ValueError: if ``test_loader`` yields no samples (accuracy undefined).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    # Accumulate on the device; reading it back once at the end avoids a
    # host/device synchronisation on every batch.
    correct = torch.zeros((), dtype=torch.long, device=device)
    total = 0
    with torch.no_grad():
        for img, target in test_loader:
            img = img.to(device)
            target = target.to(device)
            pred = model(img).argmax(dim=1)
            correct += (pred == target).sum()
            total += target.size(0)
    if total == 0:
        raise ValueError("eval(): test_loader yielded no samples")
    return correct.item() / total
def train_model(model, train_loader, test_loader):
    """Train ``model`` and checkpoint it whenever test accuracy improves.

    Optimisation: SGD (lr 0.1, momentum 0.9, weight decay 1e-4) for
    ``args.total_epochs`` epochs, with the learning rate multiplied by 0.1
    every ``args.step_size`` epochs (StepLR). After each epoch the model is
    evaluated on ``test_loader``; on a new best accuracy the whole module
    object (not just its state_dict) is written to
    ``resnet18-round<args.round>.pth``. The best accuracy is printed at the end.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=0.1,
        momentum=0.9,
        weight_decay=1e-4,
    )
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=args.step_size, gamma=0.1
    )
    model.to(device)

    best_acc = -1
    for epoch in range(args.total_epochs):
        model.train()
        for step, (images, labels) in enumerate(train_loader):
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            loss = F.cross_entropy(model(images), labels)
            loss.backward()
            optimizer.step()
            if args.verbose and step % 10 == 0:
                print("Epoch %d/%d, iter %d/%d, loss=%.4f" % (
                    epoch, args.total_epochs, step, len(train_loader), loss.item()))

        model.eval()
        acc = eval(model, test_loader)
        print("Epoch %d/%d, Acc=%.4f" % (epoch, args.total_epochs, acc))
        if acc > best_acc:
            # Save the full module: after pruning, layer shapes no longer
            # match a freshly constructed ResNet18.
            ckpt_path = 'resnet18-round%d.pth' % args.round
            torch.save(model, ckpt_path)
            best_acc = acc
        scheduler.step()

    print("Best Acc=%.4f" % best_acc)
def prune_model(model):
    """Remove low-L1-norm filters from every ``resnet.BasicBlock`` of ``model``.

    The model is moved to CPU and traced once by torch_pruning's
    ``DependencyGraph``. Then, for each BasicBlock in ``model.modules()``
    order, a fraction of the output filters of ``conv1`` and ``conv2`` is
    pruned. The fraction is taken from a fixed per-block schedule. Filters
    with the smallest L1 norm go first. Pruning plans are executed
    immediately, so the model is modified in place; the same object is
    returned.

    NOTE(review): the schedule has 8 entries, presumably one per BasicBlock
    of ResNet18; a model with more BasicBlocks raises IndexError.
    """
    model.cpu()
    # Dummy CIFAR-sized input used to trace layer dependencies, so that a
    # conv's pruning plan can also cover coupled layers (per torch_pruning).
    dep_graph = tp.DependencyGraph().build_dependency(model, torch.randn(1, 3, 32, 32))

    def prune_by_l1(conv, pruned_prob):
        # One L1 norm per output filter: sum of |w| over axes 1..3.
        filters = conv.weight.detach().cpu().numpy()
        filter_l1 = np.sum(np.abs(filters), axis=(1, 2, 3))
        n_drop = int(filters.shape[0] * pruned_prob)
        # Indices of the n_drop filters with the smallest L1 norm.
        drop_idxs = np.argsort(filter_l1)[:n_drop].tolist()
        dep_graph.get_pruning_plan(conv, tp.prune_conv, drop_idxs).exec()

    block_prune_probs = [0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.3]
    basic_blocks = (m for m in model.modules() if isinstance(m, resnet.BasicBlock))
    for blk_id, block in enumerate(basic_blocks):
        prob = block_prune_probs[blk_id]
        prune_by_l1(block.conv1, prob)
        prune_by_l1(block.conv2, prob)
    return model
def _num_params(model):
    """Return the total number of scalar parameters in ``model``."""
    return sum(p.numel() for p in model.parameters())


def main():
    """Run the stage selected by ``--mode``.

    - ``train``: train a fresh ResNet18 (10 classes) as round 0, so its best
      checkpoint is written to ``resnet18-round0.pth``.
    - ``prune``: load ``resnet18-round<round-1>.pth``, prune it, print the
      architecture and parameter count, then fine-tune it as round ``round``.
    - ``test``: load ``resnet18-round<round>.pth``, print its parameter count
      and test accuracy.
    """
    train_loader, test_loader = get_dataloader()
    if args.mode == 'train':
        args.round = 0
        model = ResNet18(num_classes=10)
        train_model(model, train_loader, test_loader)
    elif args.mode == 'prune':
        previous_ckpt = 'resnet18-round%d.pth' % (args.round - 1)
        print("Pruning round %d, load model from %s" % (args.round, previous_ckpt))
        # Checkpoints are whole pickled modules: torch.load unpickles arbitrary
        # objects, so only load trusted files. map_location='cpu' lets
        # GPU-saved checkpoints load on CPU-only hosts; later steps move the
        # model to the right device.
        # NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which
        # rejects full-module pickles — pass weights_only=False there.
        model = torch.load(previous_ckpt, map_location='cpu')
        prune_model(model)
        print(model)
        print("Number of Parameters: %.1fM" % (_num_params(model) / 1e6))
        train_model(model, train_loader, test_loader)
    elif args.mode == 'test':
        ckpt = 'resnet18-round%d.pth' % (args.round)
        print("Load model from %s" % (ckpt))
        # Load the checkpoint named above (previously referenced an
        # undefined name and always raised NameError).
        model = torch.load(ckpt, map_location='cpu')
        print("Number of Parameters: %.1fM" % (_num_params(model) / 1e6))
        acc = eval(model, test_loader)
        print("Acc=%.4f\n" % (acc))


if __name__ == '__main__':
    main()