In [2]:
#imports
import torch
import numpy as np
import pandas as pd
import torchvision.models as models
from torch.utils.tensorboard import SummaryWriter
import seaborn as sns
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset
import time
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import datetime
import copy
import os
import pickle
from tqdm import tqdm
from models import vgg

# setting

In [3]:
#mean and std of cifar100 dataset
CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)

BATCH_SIZE = 16
WARM = False # typically used in new training

DATE_FORMAT = '%A_%d_%B_%Y_%Hh_%Mm_%Ss'
#time of we run the script
TIME_NOW = datetime.datetime.now().strftime(DATE_FORMAT)
#data settings
subset = False #for local running
k = 10 #number of samples needed to each class in validation set, because we need to split train and validation

#model settings
USE_TENSORBOARD = False
if USE_TENSORBOARD:
    foo = SummaryWriter()
use_gpu = True

#lr scheduler
BASE_LR = 0.001
EPOCH_DECAY = 4
DECAY_WEIGHT = 0.5

DEVICE = 'cpu'
if use_gpu and torch.cuda.is_available():
    DEVICE = 'cuda'

In [4]:
#read files
def unpickle(file):
    
    with open(file, 'rb') as fo:
        dictionary = pickle.load(fo, encoding='bytes')
    return dictionary

In [5]:
def compute_mean_std(cifar100_dataset):
    """compute the mean and std of cifar100 dataset
    Args:
        cifar100_training_dataset or cifar100_test_dataset
        witch derived from class torch.utils.data

    Returns:
        a tuple contains mean, std value of entire dataset
    """

    data_r = numpy.dstack([cifar100_dataset[i][1][:, :, 0] for i in range(len(cifar100_dataset))])
    data_g = numpy.dstack([cifar100_dataset[i][1][:, :, 1] for i in range(len(cifar100_dataset))])
    data_b = numpy.dstack([cifar100_dataset[i][1][:, :, 2] for i in range(len(cifar100_dataset))])
    mean = numpy.mean(data_r), numpy.mean(data_g), numpy.mean(data_b)
    std = numpy.std(data_r), numpy.std(data_g), numpy.std(data_b)

    return mean, std

In [6]:
#data processing
def reshape_images(data_dict):
    reshaped = data_dict.numpy().reshape(len(data_dict), 1024, 3, order = 'F').reshape(len(data_dict), 32,32,3)
    reshaped_processed = torch.from_numpy(reshaped).float().permute(0, 3, 1, 2)
    return reshaped_processed

In [26]:
def get_training_val_dataloader(mean, std, batch_size=16, num_workers=2, shuffle=True):
    """ return training dataloader
    Args:
        mean: mean of cifar100 training dataset
        std: std of cifar100 training dataset
        path: path to cifar100 training python dataset
        batch_size: dataloader batchsize
        num_workers: dataloader num_works
        shuffle: whether to shuffle
    Returns: train_data_loader:torch dataloader object
    """

    transform_train = transforms.Compose([
        #transforms.ToPILImage(),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    #cifar100_training = CIFAR100Train(path, transform=transform_train)
    cifar100_training = torchvision.datasets.CIFAR100(root='', train=True, download=True, transform=transform_train)
    
    try:
        random_index = pickle.load(open("random_index.pkl", 'rb'))
    except:
        random_index = np.random.permutation([i for i in range(50000)])
        pickle.dump(random_index, open("random_index.pkl", 'wb'))
    
    train_index = random_index[:45000]
    validation_index = random_index[45000:]
    train_dataset = torch.utils.data.Subset(cifar100_training, train_index)
    validation_dataset = torch.utils.data.Subset(cifar100_training, validation_index)
    
    cifar100_training_loader = DataLoader(
        train_dataset, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)
    
    cifar100_validation_loader = DataLoader(
        validation_dataset, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)

    return cifar100_training_loader, cifar100_validation_loader

In [27]:
def get_training_dataloader(mean, std, batch_size=16, num_workers=2, shuffle=True):
    """ return training dataloader
    Args:
        mean: mean of cifar100 training dataset
        std: std of cifar100 training dataset
        path: path to cifar100 training python dataset
        batch_size: dataloader batchsize
        num_workers: dataloader num_works
        shuffle: whether to shuffle
    Returns: train_data_loader:torch dataloader object
    """

    transform_train = transforms.Compose([
        #transforms.ToPILImage(),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    #cifar100_training = CIFAR100Train(path, transform=transform_train)
    cifar100_training = torchvision.datasets.CIFAR100(root='', train=True, download=True, transform=transform_train)
    cifar100_training_loader = DataLoader(
        cifar100_training, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)

    return cifar100_training_loader

In [8]:
def get_test_dataloader(mean, std, batch_size=16, num_workers=2, shuffle=True):
    """ return training dataloader
    Args:
        mean: mean of cifar100 test dataset
        std: std of cifar100 test dataset
        path: path to cifar100 test python dataset
        batch_size: dataloader batchsize
        num_workers: dataloader num_works
        shuffle: whether to shuffle
    Returns: cifar100_test_loader:torch dataloader object
    """

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    #cifar100_test = CIFAR100Test(path, transform=transform_test)
    cifar100_test = torchvision.datasets.CIFAR100(root='', train=False, download=True, transform=transform_test)
    cifar100_test_loader = DataLoader(
        cifar100_test, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)

    return cifar100_test_loader

In [9]:
class ClipWeightCallBack():
    
    def __init__(self):
        self.zeros_mask = None
    
    # on batch begin
    def get_zeros_mask(self, model):
        
        self.zeros_mask= []

        for weights_matrix in model.parameters():
            self.zeros_mask.append(torch.where(weights_matrix == 0, \
                                     torch.zeros(weights_matrix.data.shape).to(DEVICE), \
                                     torch.ones(weights_matrix.data.shape).to(DEVICE)))
    # on batch end
    def apply_zeros_mask(self, model):
        
        for index, weights_matrix in enumerate(model.parameters()):
            weights_matrix.data = weights_matrix.data * self.zeros_mask[index].to(DEVICE)

In [38]:
def train(model, epoch, train_dataloader, optimizer, loss_function, callbacks = None):

    start = time.time()
    model.to(DEVICE)
    model.train()
    # keep track of the zero mask
    if callbacks != None:
        callbacks.get_zeros_mask(model)
    
    for batch_index, (images, labels) in enumerate(train_dataloader):
        '''
        if epoch <= WARM:
            warmup_scheduler.step()
        '''
            
        if use_gpu:
            labels = labels.to(DEVICE)
            images = images.to(DEVICE)

        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        if callbacks != None:
            callbacks.apply_zeros_mask(model)
            
        n_iter = (epoch - 1) * len(train_dataloader) + batch_index + 1

        last_layer = list(model.children())[-1]
        for name, para in last_layer.named_parameters():
            if 'weight' in name:
                writer.add_scalar('LastLayerGradients/grad_norm2_weights', para.grad.norm(), n_iter)
            if 'bias' in name:
                writer.add_scalar('LastLayerGradients/grad_norm2_bias', para.grad.norm(), n_iter)
                
        if batch_index % 100 == 0:
            print('Training Epoch: {epoch} [{trained_samples}/{total_samples}]\tLoss: {:0.4f}\tLR: {:0.6f}'.format(
                loss.item(),
                optimizer.param_groups[0]['lr'],
                epoch=epoch,
                trained_samples=batch_index * BATCH_SIZE + len(images),
                total_samples=len(train_dataloader.dataset)
            ))
        

        #update training loss for each iteration
        writer.add_scalar('Train/loss', loss.item(), n_iter)

    for name, param in model.named_parameters():
        layer, attr = os.path.splitext(name)
        attr = attr[1:]
        writer.add_histogram("{}/{}".format(layer, attr), param, epoch)

    finish = time.time()

    print('epoch {} training time consumed: {:.2f}s'.format(epoch, finish - start))

In [11]:
def evaluate_model(model, val_dataloader):
    # for validation set or testing set
    start = time.time()
    model.to(DEVICE)
    model.eval()
    
    total_preds = 0
    total_corrects = 0
    
    for batch_index, (images, labels) in enumerate(val_dataloader):
        if use_gpu:
            # labels = labels.to(DEVICE)
            images = images.to(DEVICE)
            
        outputs = model(images)
        _, preds = torch.max(outputs.data, 1)
        
        
        total_preds += len(labels)
        total_corrects += np.sum(preds.cpu().numpy() == labels.numpy())
    
    print("Accuracy is {:.5f}".format(total_corrects/total_preds))
    
    return total_corrects/total_preds

In [28]:
'''
cifar100_training_loader, cifar100_validation_loader = get_training_dataloader(
    CIFAR100_TRAIN_MEAN,
    CIFAR100_TRAIN_STD,
    num_workers = 4,
    batch_size = BATCH_SIZE,
    shuffle = True
)
'''
cifar100_training_loader = get_training_dataloader(
    CIFAR100_TRAIN_MEAN,
    CIFAR100_TRAIN_STD,
    num_workers = 4,
    batch_size = BATCH_SIZE,
    shuffle = True
)


cifar100_test_loader = get_test_dataloader(
    CIFAR100_TRAIN_MEAN,
    CIFAR100_TRAIN_STD,
    num_workers = 4,
    batch_size = BATCH_SIZE,
    shuffle = True
)

Files already downloaded and verified
Files already downloaded and verified


In [16]:
def load_vgg(path):
    model = vgg.vgg16_bn()
    weights = torch.load(path)
    model.load_state_dict(weights)
    model.to(DEVICE)
    
    return model

In [18]:
writer = SummaryWriter(log_dir=os.path.join(
            'logs', 'vgg', TIME_NOW))

INFO:tensorflow:Using local port 19268
INFO:tensorflow:Using local port 19410
INFO:tensorflow:Using local port 23955
INFO:tensorflow:Using local port 20259
INFO:tensorflow:Using local port 20492
INFO:tensorflow:Using local port 24257
INFO:tensorflow:Using local port 16502
INFO:tensorflow:Using local port 23239
INFO:tensorflow:Using local port 15395
INFO:tensorflow:Using local port 20986


In [20]:
# train block #
'''
sgd_optimizer = optim.SGD(model.parameters(), lr= 0.0001, momentum=0.9, weight_decay=5e-4)
crossEntropyLoss_function = nn.CrossEntropyLoss()
train(model, train_dataloader = cifar100_training_loader, epoch = 10, optimizer = sgd_optimizer, \
                                                          loss_function = crossEntropyLoss_function)
'''

'\nsgd_optimizer = optim.SGD(model.parameters(), lr= 0.0001, momentum=0.9, weight_decay=5e-4)\ncrossEntropyLoss_function = nn.CrossEntropyLoss()\ntrain(model, train_dataloader = cifar100_training_loader, epoch = 10, optimizer = sgd_optimizer,                                                           loss_function = crossEntropyLoss_function)\n'

# Pruning

In [22]:
def model_report(model, dataloader):
    # local final score on validation data
    num_zeros = sum([(i.detach().cpu().numpy() == 0).sum() for i in model.parameters()])
    total_parameters = sum([np.prod(i.shape) for i in model.parameters()])
    accuracy = evaluate_model(model, dataloader)
    result = (accuracy + num_zeros/total_parameters)/2
    print("num_zeros / total_parameters ratio is ", num_zeros/total_parameters)
    print("accuracy is ", accuracy)
    print("overall score is ", result)
    return num_zeros/total_parameters, accuracy

In [23]:
def prune_network(model, threshold = 0.01):
    
    # vgg has classifier and features
    for weights_matrix in model.parameters():
        weights_matrix.data = torch.where(torch.abs(weights_matrix.data) >= threshold, \
                                          weights_matrix.data, torch.zeros(weights_matrix.data.shape).to(DEVICE))

In [43]:
def finetune_prune(model, rounds, epoches, train_dataloader, test_dataloader, lr = 0.00001, threshold = 0.001):
    print("Model performance at the beginning...")
    model_report(model, test_dataloader)
    sgd_optimizer = optim.SGD(model.parameters(), lr = lr, momentum = 0.9, weight_decay = 5e-4)
    crossEntropyLoss_function = nn.CrossEntropyLoss()
    prune_callback = ClipWeightCallBack()
    
    sparsity = []
    accuracy = []
    
    print("Start pruning..")
    for i in range(rounds):
        print("Round {}/{}:".format(i + 1, rounds))
        for epoch in range(1 + epoches):
            prune_network(model)
            train(model, epoch, train_dataloader, sgd_optimizer, crossEntropyLoss_function, callbacks = prune_callback)
            zeros_percentage, test_acc = model_report(model, test_dataloader)
        sparsity.append(zeros_percentage)
        accuracy.append(test_acc)
        pickle.dump([sparsity, accuracy], open("prune_hist/sparsity_accuracy.pkl", "wb"))
    
    print("Done pruning.")
    
    return sparsity, accuracy

In [45]:
model = load_vgg("checkpoints/vgg72.pth")

In [52]:
finetune_prune(model, 10, 4, cifar100_training_loader, cifar100_test_loader, lr = 0.00001, threshold = 0.01)

Pruning threshold:

Sparsity:[0.8850553143641191,
  0.8854645702199087,
  0.8858486316019958,
  0.8862241968313407,
  0.8865842396778212,
  0.8869207049654809,
  0.8872260078936021,
  0.8874967382417068,
  0.8877426562959903,
  0.8879710528726462]
  
Accuracy: [0.7039,
  0.7166,
  0.7147,
  0.7208,
  0.7244,
  0.7241,
  0.7257,
  0.7267,
  0.7249,
  0.7296]

In [49]:
torch.save(model, open("checkpoints/vgg_sparsity8897_acc7296.pth", "wb"))

  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


In [None]:
finetune_prune(model, 10, 4, cifar100_training_loader, cifar100_test_loader, lr = 0.00001, threshold = 0.01)