In [None]:
from time import time 
from copy import deepcopy

import torch
import numpy as np 

from torch import nn

from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.nn.utils import prune
from torch.utils.data import DataLoader

from torchvision import models
from torchvision.datasets.cifar import CIFAR10
from torchvision.transforms import ToTensor, Lambda

from tqdm import tqdm
from scipy.stats import ttest_ind

# Device used for every tensor/model transfer in this notebook: "cuda" when
# PyTorch reports an available GPU, otherwise "cpu".
# Note: I ran this on a GPU instance
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Loading the data

In [None]:
def load_CIFAR10(batch_size: int = 128, shuffle: bool = True):
    """Load the CIFAR10 train and test splits wrapped in ``DataLoader`` objects.

    The data is read from the local ``data`` directory. If it is not there
    yet, it is downloaded once and then loaded.

    Images are converted with ``ToTensor()``; integer labels are converted to
    one-hot float vectors of length 10.

    Args:
        batch_size: Number of samples per batch for both loaders.
        shuffle: Whether the loaders shuffle their samples. Note that this flag
            is applied to the test loader as well as the training loader.

    Returns:
        A ``(train_loader, test_loader)`` tuple.
    """

    def one_hot(y):
        # Write a 1 at index ``y`` of a zero vector with 10 entries.
        return torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1)

    def build(is_train: bool, download: bool):
        # Single place that defines how a CIFAR10 split is constructed.
        return CIFAR10(
            root="data",
            train=is_train,
            download=download,
            transform=ToTensor(),
            target_transform=Lambda(one_hot),
        )

    try:
        train_set = build(is_train=True, download=False)
        test_set = build(is_train=False, download=False)
    except RuntimeError:
        # torchvision raises RuntimeError when the dataset is missing or
        # corrupted and download=False. Only that case triggers a download;
        # any other error (or a Ctrl-C) propagates instead of being swallowed.
        train_set = build(is_train=True, download=True)
        test_set = build(is_train=False, download=True)

    train_data = DataLoader(train_set, batch_size=batch_size, shuffle=shuffle)
    test_data = DataLoader(test_set, batch_size=batch_size, shuffle=shuffle)

    return train_data, test_data

# Build the module-level loaders with the default batch size and shuffling.
# eval_training and eval_inference read these globals directly.
train_data, test_data = load_CIFAR10()

# Defining Train and Test Loop

In [None]:
def train(train_data, model, loss_func, optim):
    """Run one full pass (epoch) over ``train_data``, updating ``model`` in place.

    Args:
        train_data: Iterable yielding ``(X, y)`` batches.
        model: The module to optimise; it is switched to training mode.
        loss_func: Callable ``loss_func(pred, y)`` returning a scalar loss tensor.
        optim: Optimizer that holds the parameters of ``model``.
    """
    # Make sure layers such as BatchNorm/Dropout behave in training mode, even
    # if the model was put into eval mode by an earlier evaluation.
    model.train()

    for X, y in train_data:
        X, y = X.to(DEVICE), y.to(DEVICE)

        pred = model(X)
        loss = loss_func(pred, y)

        # Optimization: clear stale gradients, backpropagate, apply the update.
        optim.zero_grad()
        loss.backward()
        optim.step()

def test(test_data, model, loss_func, optim):
    """Evaluate the model on the test set and print accuracy and average loss.

    Args:
        test_data: ``DataLoader`` yielding ``(X, y)`` batches, where ``y`` is
            compared to the prediction via ``argmax(1)`` (one-hot style targets).
        model: The module to evaluate.
        loss_func: Callable ``loss_func(y_hat, y)`` returning a scalar loss tensor.
        optim: Unused; kept so the signature mirrors ``train``.
    """
    size = len(test_data.dataset)
    n_batch = len(test_data)
    test_loss, correct = 0, 0

    # Evaluate in eval mode so BatchNorm uses its running statistics (and does
    # not update them from test data) and Dropout is disabled. The previous
    # mode is restored afterwards so callers see no change in model state.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for X, y in test_data:
                X, y = X.to(DEVICE), y.to(DEVICE)
                y_hat = model(X)
                test_loss += loss_func(y_hat, y).item()

                # Count samples whose predicted class matches the target class.
                correct += (y_hat.argmax(1) == y.argmax(1)).type(torch.float).sum().item()
    finally:
        model.train(was_training)

    test_loss /= n_batch  # mean of the per-batch losses
    correct /= size       # fraction of correctly classified samples

    print(f"Accuracy: {100*correct:>0.1f}%, Avg loss: {test_loss:>8f}\n")

# Defining the pruning methods

In [None]:
# Kept for backward compatibility with code that may still read it.
# prune_resnet now selects layers by module type instead of by name, because
# name matching misses layers such as ResNet's "fc" and "downsample.0".
LAYERS = ["conv", "linear"]

def prune_resnet(model, ratio):
  """
  Iterate through layers of the model and prune linear and conv layers. This is
  a very basic approach corresponding to magnitude based pruning at initialization,
  which is used as a baseline in many papers.

  Layers are selected by type (``nn.Conv1d/2d/3d`` and ``nn.Linear``), so they
  are pruned regardless of the attribute name they are registered under.

  Args:
    model: The module whose layers are pruned in place.
    ratio: Passed as ``amount`` to ``prune.l1_unstructured`` for every pruned
      weight and bias tensor.

  Returns:
    The same ``model`` object, with pruning applied.
  """
  prunable_types = (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)

  for layer in model.modules():
    if not isinstance(layer, prunable_types):
      continue

    prune.l1_unstructured(layer, name = "weight", amount = ratio)
    # Layers built with bias=False carry ``bias = None`` and have nothing to prune.
    if getattr(layer, "bias", None) is not None:
      prune.l1_unstructured(layer, name = "bias", amount = ratio)

  return model

# Evaluate pruning training efficiency

In [None]:
def eval_training(model, niter):
  """Train a model for a fixed number of iterations.

  Each iteration is one call to ``train`` on the module-level ``train_data``.
  After the last iteration the model is evaluated once on the module-level
  ``test_data`` via ``test``.

  Args:
    model: The module to train (updated in place).
    niter: Number of training passes to run and time.

  Returns:
    ``(exec_time, model)`` where ``exec_time`` is a list with the wall-clock
    duration in seconds of each pass, and ``model`` is the trained module.
  """
  loss = CrossEntropyLoss()
  optim = Adam(model.parameters(), lr = 1e-3)

  def wait_for_gpu():
    # CUDA kernels are launched asynchronously. Without a synchronisation
    # point, the timer can stop before the queued GPU work has finished.
    if torch.cuda.is_available():
      torch.cuda.synchronize()

  exec_time = []

  for _ in tqdm(range(niter)):
    wait_for_gpu()
    start = time()
    train(train_data, model, loss, optim)
    wait_for_gpu()
    end = time()
    exec_time.append(end-start)

  test(test_data, model, loss, optim)

  return exec_time, model

def compare_training_time(model, niter = 10, ratio = 0.6):
  """
  Compare per-iteration training time of a model and a pruned copy of it.

  A deep copy of ``model`` is pruned with ``prune_resnet`` before any training
  happens. Both the original and the pruned copy are then trained for
  ``niter`` iterations via ``eval_training``. The mean and standard deviation
  of the per-iteration times are printed, together with the outcome of an
  independent two-sample t-test (threshold p < .05).

  Returns the trained unpruned model and the trained pruned model, in that
  order.
  """
  pruned_copy = prune_resnet(deepcopy(model), ratio = ratio)

  print("### Unpruned training ###")
  times_dense, trained_dense = eval_training(model, niter)

  print("### Pruned training ###")
  times_pruned, trained_pruned = eval_training(pruned_copy, niter)

  t_stat, p_value = ttest_ind(times_dense, times_pruned)

  # Summary statistics, rounded to four decimals for display.
  mean_dense = np.round(np.mean(times_dense), 4)
  sd_dense = np.round(np.std(times_dense), 4)
  mean_pruned = np.round(np.mean(times_pruned), 4)
  sd_pruned = np.round(np.std(times_pruned), 4)

  print(f"Unpruned: M = {mean_dense}, SD = {sd_dense}")
  print(f"Pruned: M = {mean_pruned}, SD = {sd_pruned}")

  verdict = (
    f"Pruning leads to a statistical significance in training time: t = {np.round(t_stat, 4)}, p = {np.round(p_value, 4)}"
    if p_value < .05 else
    f"Pruning does not lead to a statistical significance in training time: t = {np.round(t_stat, 4)}, p = {np.round(p_value, 4)}"
  )
  print(verdict)

  return trained_dense, trained_pruned

In [None]:
# ResNet-50 with a 10-class output head (no pretrained weights requested),
# moved onto the benchmark device.
testnet = models.resnet50(num_classes=10)
testnet = testnet.to(DEVICE)

# Train the unpruned model and its pruned copy for 20 iterations each.
model_t, model_pt = compare_training_time(model=testnet, niter=20)

### Unpruned training ###


100%|██████████| 20/20 [09:18<00:00, 27.90s/it]


Accuracy: 61.6%, Avg loss: 1.168145

### Pruned training ###


100%|██████████| 20/20 [09:34<00:00, 28.71s/it]


Accuracy: 71.3%, Avg loss: 0.871094

Unpruned: M = 27.8995, SD = 1.1458
Pruned: M = 28.7073, SD = 0.1553
Pruning leads to a statistical significance in training time: t = -3.0449, p = 0.0042


# Evaluate pruning inference efficiency

In [None]:
def eval_inference(model, niter = 10):
  """
  Get the time required for a forward pass over the full module-level
  ``train_data`` loader, repeated for a fixed number of iterations.

  The model is run in eval mode with gradients disabled, so BatchNorm running
  statistics are not modified by the benchmark. The model's previous
  train/eval mode is restored before returning.

  Args:
    model: The module to benchmark.
    niter: Number of timed passes over ``train_data``.

  Returns:
    A list with the wall-clock duration in seconds of each pass.
  """
  def wait_for_gpu():
    # CUDA kernels are launched asynchronously. Without a synchronisation
    # point, the timer can stop before the queued GPU work has finished.
    if torch.cuda.is_available():
      torch.cuda.synchronize()

  duration = []
  was_training = model.training
  model.eval()
  try:
    with torch.no_grad():
      for _ in tqdm(range(niter)):
        wait_for_gpu()
        start = time()
        for batch in train_data:
          data = batch[0].to(DEVICE)  # inputs only; labels are not needed
          model(data)
        wait_for_gpu()
        stop = time()
        duration.append(stop - start)
  finally:
    model.train(was_training)
  return duration

def compare_inference_time(model_t, model_pt, niter = 20):
  """
  Use the trained models without/with pruning for inference by running
  the full training set for a fixed number of iterations.

  Prints the mean and standard deviation of the per-iteration inference times
  of both models, together with the outcome of an independent two-sample
  t-test (threshold p < .05).

  Args:
    model_t: The trained, unpruned model.
    model_pt: The trained, pruned model.
    niter: Number of timed inference passes per model.
  """

  print("### Unpruned inference ###")
  exec_time_t = eval_inference(model_t, niter)

  print("### Pruned inference ###")
  exec_time_pt = eval_inference(model_pt, niter)

  tval, pval = ttest_ind(exec_time_t, exec_time_pt)

  print(f"Unpruned: M = {np.round(np.mean(exec_time_t), 4)}, SD = {np.round(np.std(exec_time_t), 4)}")
  print(f"Pruned: M = {np.round(np.mean(exec_time_pt), 4)}, SD = {np.round(np.std(exec_time_pt), 4)}")

  if pval < .05:
    print(f"Pruning leads to a statistical significance in inference time: t = {np.round(tval, 4)}, p = {np.round(pval, 4)}")
  else:
    # Wording now matches the significant branch ("inference time").
    print(f"Pruning does not lead to a statistical significance in inference time: t = {np.round(tval, 4)}, p = {np.round(pval, 4)}")

In [None]:
# Benchmark inference for the unpruned and pruned models returned by
# compare_training_time, using the default of 20 timed iterations per model.
compare_inference_time(model_t, model_pt)

### Unpruned inference ###


100%|██████████| 20/20 [03:45<00:00, 11.27s/it]


### Pruned inference ###


100%|██████████| 20/20 [04:03<00:00, 12.18s/it]

Unpruned: M = 11.2734, SD = 0.374
Pruned: M = 12.178, SD = 0.1735
Pruning leads to a statistical significance in inference time: t = -9.5652, p = 0.0



