In [1]:
import math
import torch.nn as nn
import torch.nn.init as init

import argparse
import os
import shutil
import time
from tqdm import tqdm
import gc
import pickle, numpy as np

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets

from cifar10tools import make_model, get_model_filters_config, test

In [2]:
class Individual:
    """A candidate solution: a bit string together with its cached metrics.

    Attributes:
        bitstring: the sequence of bits describing this individual. It is
            XOR-ed in place by ``mutate`` (assumes an integer NumPy array —
            TODO confirm against the code that creates individuals).
        metrics: cached evaluation results, or ``None`` when they have not
            been computed or were invalidated by a mutation.
    """

    def __init__(self, bitstring) -> None:
        self.metrics = None
        self.bitstring = bitstring

    def mutate(self):
        """Randomly flip bits of ``bitstring`` in place and drop cached metrics.

        A per-bit flip probability is drawn from the list of candidate rates
        (currently a single value), then each bit is flipped independently
        with that probability.
        """
        candidate_rates = [0.0002]
        chosen = np.random.randint(len(candidate_rates))
        rate = candidate_rates[chosen]

        # The cached metrics describe the pre-mutation bit string.
        self.metrics = None

        n_bits = len(self.bitstring)
        flip_mask = np.random.binomial(1, rate, n_bits)
        self.bitstring ^= flip_mask
   
# Load the last saved generation of the pruned population and keep the
# individual with the highest 'score' metric.
# NOTE(security): pickle.load can execute arbitrary code contained in the file;
# only load checkpoints that were produced by a trusted run.
with open('../PruningAlgo/checkpoints/pruned_population_ckpt/last_gen.p', 'rb') as ckpt_file:
    last_generation = pickle.load(ckpt_file)
max_indiv = max(last_generation, key=lambda indiv: indiv.metrics['score'])

In [3]:
class VGG(nn.Module):
    """VGG-style classifier: convolutional features followed by a 3-layer MLP head.

    Args:
        features: module holding the convolutional feature layers. It must
            contain at least one ``nn.Conv2d``; the number of output channels
            of the last one determines the input width of the classifier.
            The flattened feature map is fed straight into the classifier, so
            the spatial size after ``features`` is assumed to be 1x1
            (e.g. 32x32 inputs with five 2x2 poolings) — TODO confirm for
            other input sizes.

    Raises:
        ValueError: if ``features`` contains no ``nn.Conv2d`` layer.
    """

    def __init__(self, features):
        super().__init__()
        self.features = features

        # Derive the classifier input width from the last convolution instead
        # of a hard-coded state-dict key, so layouts other than the plain
        # 16-layer one (different depth, batch norm) work as well.
        conv_layers = [m for m in self.features.modules() if isinstance(m, nn.Conv2d)]
        if not conv_layers:
            raise ValueError('features must contain at least one nn.Conv2d layer')
        inchannels = conv_layers[-1].weight.shape[0]

        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(inchannels, 512),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.Linear(512, 10),
        )
        # Initialize weights (kernels with normal randoms, bias with 0s)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                # Convolutions built with bias=False have no bias tensor.
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, x):
        """Return class logits of shape (batch, 10) for the input batch ``x``."""
        x = self.features(x)
        # Flatten everything except the batch dimension.
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

def make_layers(cfg, batch_norm=False):
    """Build the feature extractor described by ``cfg``.

    Args:
        cfg: iterable of layer specs. ``'M'`` adds a 2x2 max-pooling layer
            with stride 2; any other entry is used as the number of output
            channels of a 3x3 convolution with padding 1.
        batch_norm: when true, a ``BatchNorm2d`` layer is inserted between
            each convolution and its ReLU.

    Returns:
        An ``nn.Sequential`` of the generated layers. The first convolution
        expects 3 input channels.
    """
    modules = []
    prev_channels = 3
    for entry in cfg:
        if entry == 'M':
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue

        modules.append(nn.Conv2d(prev_channels, entry, kernel_size=3, padding=1))
        if batch_norm:
            modules.append(nn.BatchNorm2d(entry))
        modules.append(nn.ReLU(inplace=True))
        # The next convolution consumes what this one produced.
        prev_channels = entry
    return nn.Sequential(*modules)

# Layer layouts consumed by make_layers: an integer is the number of output
# channels of a 3x3 convolution, 'M' is a 2x2 max-pooling step.
# 'A' has 8 convolutions, 'B' 10, 'D' 13 and 'E' 16; 'D' is the layout used
# for the vgg16 base model in this notebook.
cfg = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M',
          512, 512, 512, 512, 'M'],
}

# Build the unpruned network with layout 'D' and restore its saved weights.
base_model = VGG(make_layers(
    cfg['D'] ## vgg16
))
# Load the checkpoint onto the CPU first so that loading does not depend on
# the device the weights were saved from; the model is moved to the GPU below.
base_state = torch.load('../PruningAlgo/models/vgg16_cifar10_base_model', map_location='cpu')
base_model.load_state_dict(base_state)
base_model = base_model.to('cuda')

In [4]:
# Build the pruned network from the best individual's bit string and the base
# model (make_model comes from cifar10tools), then move it to the GPU.
pruned_model = make_model(max_indiv.bitstring, base_model)
pruned_model = pruned_model.to('cuda')

batch_size = 512
workers = 4

# NOTE(review): these look like the usual ImageNet channel statistics rather
# than CIFAR-10 ones — presumably they match how the base model was trained;
# confirm before changing.
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Training data: random flip + padded random crop as augmentation.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, 4),
    transforms.ToTensor(),
    normalize,
])
train_set = datasets.CIFAR10(
    root='./data', train=True, transform=train_transform, download=True)
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers,
    pin_memory=True,
)

# Validation data: no augmentation, fixed order.
val_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])
val_set = datasets.CIFAR10(root='./data', train=False, transform=val_transform)
val_loader = torch.utils.data.DataLoader(
    val_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=workers,
    pin_memory=True,
)

# Loss function (criterion) and optimizer used for fine-tuning.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(pruned_model.parameters(), lr=0.001, momentum=0.9)
# Alternative that was tried:
# optimizer = torch.optim.Adagrad(pruned_model.parameters(), lr=0.001)

Files already downloaded and verified


In [5]:
num_epochs = 400

start_time = time.time()
for epoch in range(num_epochs):

    # Switch to training mode at the start of every epoch.
    pruned_model.train()
    for batch_idx, (features, targets) in enumerate(train_loader):
        features = features.to('cuda')
        targets = targets.to('cuda')

        # Forward pass and loss.
        logits = pruned_model(features)
        cost = criterion(logits, targets)

        # Clear old gradients, backpropagate, update the parameters.
        optimizer.zero_grad()
        cost.backward()
        optimizer.step()

        # Log every 50th batch (including the first one of each epoch).
        if batch_idx % 50 == 0:
            progress = (epoch+1, num_epochs, batch_idx, len(train_loader), cost.item())
            print('Epoch: %03d/%03d | Batch %04d/%04d | Cost: %.4f' % progress)

    # Per-epoch evaluation is currently disabled: the snippet that used to be
    # here referenced `model`, `compute_accuracy` and `compute_epoch_loss`,
    # none of which are defined in the visible part of this notebook.

    elapsed_min = (time.time() - start_time)/60
    print('Time elapsed: %.2f min' % elapsed_min)

print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))

Epoch: 001/400 | Batch 0000/0098 | Cost: 0.7450
Epoch: 001/400 | Batch 0050/0098 | Cost: 0.3057
Time elapsed: 0.11 min
Epoch: 002/400 | Batch 0000/0098 | Cost: 0.2658
Epoch: 002/400 | Batch 0050/0098 | Cost: 0.2369
Time elapsed: 0.17 min
Epoch: 003/400 | Batch 0000/0098 | Cost: 0.2216
Epoch: 003/400 | Batch 0050/0098 | Cost: 0.1745
Time elapsed: 0.24 min
Epoch: 004/400 | Batch 0000/0098 | Cost: 0.2016
Epoch: 004/400 | Batch 0050/0098 | Cost: 0.1681
Time elapsed: 0.31 min
Epoch: 005/400 | Batch 0000/0098 | Cost: 0.1314
Epoch: 005/400 | Batch 0050/0098 | Cost: 0.1925
Time elapsed: 0.38 min
Epoch: 006/400 | Batch 0000/0098 | Cost: 0.1509
Epoch: 006/400 | Batch 0050/0098 | Cost: 0.2230
Time elapsed: 0.44 min
Epoch: 007/400 | Batch 0000/0098 | Cost: 0.1838
Epoch: 007/400 | Batch 0050/0098 | Cost: 0.1932
Time elapsed: 0.51 min
Epoch: 008/400 | Batch 0000/0098 | Cost: 0.1225
Epoch: 008/400 | Batch 0050/0098 | Cost: 0.1488
Time elapsed: 0.58 min
Epoch: 009/400 | Batch 0000/0098 | Cost: 0.1356


In [8]:
# Run `test` (imported from cifar10tools) on the validation loader for both
# the pruned model and the base model.
# NOTE(review): the recorded output holds one number per model, presumably
# accuracy in percent in the order the models are listed — confirm against
# cifar10tools.
test(val_loader, [pruned_model, base_model])

[87.77000000000001, 86.87]

In [11]:
# Persist the pruned model together with the individual it was built from.
# The context manager makes sure the file is flushed and closed even if
# pickling fails part-way.
# NOTE: pickling a whole model object stores references to its classes, so the
# same class definitions must be importable when the file is loaded again.
with open('../PruningAlgo/models/pruned_models/pruned_vgg16_cifar10_model', 'wb') as model_file:
    pickle.dump((pruned_model, max_indiv), model_file)

In [7]:
# Inspect the latest plotting checkpoint, which unpacks into four values.
# pickle is already imported in the notebook's first cell, so it is not
# re-imported here.
# NOTE(security): pickle.load can execute arbitrary code contained in the file;
# only load checkpoints that were produced by a trusted run.
with open('../PruningAlgo/checkpoints/pruned_plot_ckpt/last_plot.p', 'rb') as plot_file:
    a, b, c, d = pickle.load(plot_file)
print(len(c*2))

830
