In [2]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [1]:
import torch
from torchvision.models import MobileNet_V2_Weights
from torchvision import datasets


In [None]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [19]:
preprocess = MobileNet_V2_Weights.IMAGENET1K_V1.transforms()

test_dataset = datasets.ImageNet(root='../data/imagenet',
                                 split='val',
                                transform=preprocess)

In [57]:
import os
from torch.utils.data import Dataset
from PIL import Image
import json
syn_to_class = {}
with open(os.path.join("../data/imagenet", "imagenet_class_index.json"), "rb") as f:
    json_file = json.load(f)
    for class_id, v in json_file.items():
        syn_to_class[v[0]] = class_id
                
def get_class_name(entry):        
    target = syn_to_class[int(entry)]
    return target
        
class ImageNetKaggle(Dataset):
    def __init__(self, root, transform=None, test_size=0.0, num_imgs=50):
        with open(os.path.join("../data/imagenet", "imagenet_class_index.json"), "rb") as f:
            json_file = json.load(f)
            for class_id, v in json_file.items():
                syn_to_class[v[0]] = class_id
                
        self.samples = []
        self.targets = []
        self.transform = transform
        samples_dir = os.path.join(root, "val")
        
        test_split_size = int(test_size * num_imgs)
        train_split_size = num_imgs - test_split_size
        
        for entry in os.listdir(samples_dir):
                sample_path = os.path.join(samples_dir, entry)
                for file in os.listdir(sample_path):                    
                    self.samples.append(os.path.join(sample_path, file))
                    self.targets.append(int(syn_to_class[entry]))
        
    def __len__(self):
            return len(self.samples)
        
    def __getitem__(self, idx):
            x = Image.open(self.samples[idx]).convert("RGB")
            if self.transform:
                x = self.transform(x)
            return x, self.targets[idx]

In [59]:
root = '../data/imagenet'
val_transform = None
dataset = ImageNetKaggle(root, val_transform)


In [53]:
num_imgs = 50
test_size = 0.2
seed = 42  # For reproducibility
transform = None
shuffle = True
random_state = np.random.RandomState(seed)
root = '../data/imagenet'
with open(os.path.join("../data/imagenet", "imagenet_class_index.json"), "rb") as f:
    json_file = json.load(f)
    for class_id, v in json_file.items():
        syn_to_class[v[0]] = class_id
        
train_samples = []
train_targets = []
test_samples = []
test_targets = []

transform = transform
samples_dir = os.path.join(root, "val")

test_split_size = int(test_size * num_imgs)
train_split_size = num_imgs - test_split_size

for entry in os.listdir(samples_dir):
    sample_path = os.path.join(samples_dir, entry)
    samples = []
    targets = []
    for file in os.listdir(sample_path):                    
        samples.append(os.path.join(sample_path, file))
        targets.append(int(syn_to_class[entry]))
        
    if shuffle:
        pairs = list(zip(samples, targets))
        random_state.shuffle(pairs)
        samples = [p[0] for p in pairs]
        targets = [p[1] for p in pairs]
    
    train_samples.extend(samples[:train_split_size])
    train_targets.extend(targets[:train_split_size])
    test_samples.extend(samples[train_split_size:])
    test_targets.extend(targets[train_split_size:])
    


In [65]:
class ImageNet(Dataset):
    def __init__(self, samples, targets, transform=None):
        self.transform = transform
        self.samples = samples
        self.targets = targets
        
    def __len__(self):
            return len(self.samples)
        
    def __getitem__(self, idx):
            x = Image.open(self.samples[idx]).convert("RGB")
            if self.transform:
                x = self.transform(x)
            return x, self.targets[idx]

In [67]:
train_dataset = ImageNet(test_samples, test_targets)
len(train_dataset)

10

['../data/imagenet\\val\\n01440764\\ILSVRC2012_val_00000293.JPEG',
 '../data/imagenet\\val\\n01440764\\ILSVRC2012_val_00002138.JPEG',
 '../data/imagenet\\val\\n01440764\\ILSVRC2012_val_00003014.JPEG',
 '../data/imagenet\\val\\n01440764\\ILSVRC2012_val_00006697.JPEG',
 '../data/imagenet\\val\\n01440764\\ILSVRC2012_val_00007197.JPEG',
 '../data/imagenet\\val\\n01440764\\ILSVRC2012_val_00009111.JPEG',
 '../data/imagenet\\val\\n01440764\\ILSVRC2012_val_00009191.JPEG',
 '../data/imagenet\\val\\n01440764\\ILSVRC2012_val_00009346.JPEG',
 '../data/imagenet\\val\\n01440764\\ILSVRC2012_val_00009379.JPEG',
 '../data/imagenet\\val\\n01440764\\ILSVRC2012_val_00009396.JPEG',
 '../data/imagenet\\val\\n01440764\\ILSVRC2012_val_00010306.JPEG',
 '../data/imagenet\\val\\n01440764\\ILSVRC2012_val_00011233.JPEG',
 '../data/imagenet\\val\\n01440764\\ILSVRC2012_val_00011993.JPEG',
 '../data/imagenet\\val\\n01440764\\ILSVRC2012_val_00012503.JPEG',
 '../data/imagenet\\val\\n01440764\\ILSVRC2012_val_00013716.JP

In [36]:
t

0

In [26]:
import numpy as np
seed = 42  # For reproducibility
random_state = np.random.RandomState(seed)
random_state.shuffle(pairs)

In [28]:
pairs

[('../data/imagenet\\val\\n03792782\\ILSVRC2012_val_00004147.JPEG', 671),
 ('../data/imagenet\\val\\n02095314\\ILSVRC2012_val_00033834.JPEG', 188),
 ('../data/imagenet\\val\\n01491361\\ILSVRC2012_val_00049000.JPEG', 3),
 ('../data/imagenet\\val\\n02109961\\ILSVRC2012_val_00047924.JPEG', 248),
 ('../data/imagenet\\val\\n04201297\\ILSVRC2012_val_00043448.JPEG', 789),
 ('../data/imagenet\\val\\n04418357\\ILSVRC2012_val_00029189.JPEG', 854),
 ('../data/imagenet\\val\\n02101556\\ILSVRC2012_val_00025471.JPEG', 216),
 ('../data/imagenet\\val\\n12620546\\ILSVRC2012_val_00047816.JPEG', 989),
 ('../data/imagenet\\val\\n01797886\\ILSVRC2012_val_00041909.JPEG', 82),
 ('../data/imagenet\\val\\n03992509\\ILSVRC2012_val_00006589.JPEG', 739),
 ('../data/imagenet\\val\\n04456115\\ILSVRC2012_val_00003713.JPEG', 862),
 ('../data/imagenet\\val\\n04131690\\ILSVRC2012_val_00043799.JPEG', 773),
 ('../data/imagenet\\val\\n01984695\\ILSVRC2012_val_00039918.JPEG', 123),
 ('../data/imagenet\\val\\n01632458\\ILSV

In [14]:
from torch.utils.data import DataLoader
from torchvision import transforms
import torch
import torchvision

mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
val_transform = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]
        )
dataset = ImageNetKaggle("../data/imagenet", val_transform)
dataloader = DataLoader(
            dataset,
            batch_size=64, # may need to reduce this depending on your GPU 
            shuffle=False,
        )

In [15]:
def evaluate(model, device, test_loader):
    model.eval()
    
    losses = 0.0
    total_predictions = 0
    true_predictions_top1 = 0
    true_predictions_top5 = 0
    criterion = torch.nn.CrossEntropyLoss()
    
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            
            # Compute loss
            loss = criterion(outputs, targets) / inputs.size(0)
            losses += loss.item()
            
            # Top-1 predictions
            _, predicted_top1 = torch.max(outputs, 1)
            batch_true_predictions_top1 = (predicted_top1 == targets).sum().item()
            true_predictions_top1 += batch_true_predictions_top1
            
            # Top-5 predictions
            _, predicted_top5 = torch.topk(outputs, 5, dim=1)
            batch_true_predictions_top5 = sum(
                [targets[i].item() in predicted_top5[i].tolist() for i in range(targets.size(0))]
            )
            true_predictions_top5 += batch_true_predictions_top5
            
            # Update total predictions
            batch_total_predictions = outputs.size(0)
            total_predictions += batch_total_predictions
            
            # Print batch metrics
            # print(
            #     f'Batch {batch_idx}, Loss: {loss:.4f}, '
            #     f'Accuracy@1: {batch_true_predictions_top1/batch_total_predictions*100:.2f}%, '
            #     f'Accuracy@5: {batch_true_predictions_top5/batch_total_predictions*100:.2f}%'
            # )
    
    # Compute overall accuracies
    accuracy_top1 = true_predictions_top1 / total_predictions
    accuracy_top5 = true_predictions_top5 / total_predictions
    
    return accuracy_top1, accuracy_top5, losses


In [16]:
model = torchvision.models.mobilenet_v2(weights='MobileNet_V2_Weights.IMAGENET1K_V1')
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

device(type='cuda')

In [17]:
model.eval().to(DEVICE)
accuracy_top1, accuracy_top5, losses = evaluate(model, DEVICE, dataloader)
print(f"acc@1: {accuracy_top1*100}%, acc@5: {accuracy_top5*100}%, loss: {losses}")

acc@1: 71.87%, acc@5: 90.31%, loss: 14.188950040435884


# Evaluation of Global Pruning w/o Retraining

In [3]:
from torchvision.models import MobileNet_V2_Weights
from utils.datasets import train_test_split, ImageNet
from models.pruning import Pruning
from torch.utils.data import DataLoader
from torchvision import models
import torch

In [4]:
SEED = 42
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


### Load Data

In [5]:
preprocess = MobileNet_V2_Weights.IMAGENET1K_V1.transforms()


In [6]:
_, _, test_X, test_Y = train_test_split(test_size=0.2, shuffle=False, num_imgs=50, root = "../data/imagenet")
test_dataset = ImageNet(test_X, test_Y)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False)

### Prune Model

In [14]:
model = models.mobilenet_v2(weights='MobileNet_V2_Weights.IMAGENET1K_V1')
model.to(DEVICE)
model.eval()

MobileNetV2(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=

In [15]:
batch_norms = []
for i in range(1, 18):
    if i == 1:
        batch_norms.append(f'features.{i}.conv.0.1')
        continue    
    batch_norms.append(f'features.{i}.conv.1.1')

In [16]:
pruning = Pruning(model, DEVICE)
model = pruning.scaling_based_pruning(batch_norms=batch_norms, pruning_ratio=0.1, level='global', scale_threshold=False)

features.1.conv.0.1 features.2.conv.1.1 features.3.conv.1.1 features.4.conv.1.1 features.5.conv.1.1 features.6.conv.1.1 features.7.conv.1.1 features.8.conv.1.1 features.9.conv.1.1 features.10.conv.1.1 features.11.conv.1.1 features.12.conv.1.1 features.13.conv.1.1 features.14.conv.1.1 features.15.conv.1.1 features.16.conv.1.1 features.17.conv.1.1
32    96    144   144   192   192   192   384   384   384   384   576   576   576   960   960   960  
23    94    137   144   181   178   191   344   317   313   378   514   501   519   896   875   818  
----------------------------------------------------------------------------------------------------
9     2     7     0     11    14    1     40    67    71    6     62    75    57    64    85    142  


In [19]:
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--foo', nargs='+', default=[1,2,3,4,5,6,7,8,9,10])
parser.add_argument('--bar', nargs='*')
parser.add_argument('baz', nargs='*')
parser.parse_args('a b --bar 1 2'.split())

Namespace(foo=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], bar=['1', '2'], baz=['a', 'b'])