In [2]:
# Import necessary package
import torch
import torch.nn as nn
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Implement a naive multilayer perceptron
class MLP(nn.Module):
    def __init__(self):
        super().__init__()  
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(3*32*32, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32,10) # the second value must be 10 because cifar10 dataset has 10 classes.
        )

    def forward(self, x):
        return self.linear_relu_stack(self.flatten(x))

In [4]:
# Prepare CIFAR-10 dataset
trainset = CIFAR10(root='cifar10', train="True", download=True, transform=transforms.ToTensor())
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=1)
testset = CIFAR10(root='cifar10', train="False", download=True, transform=transforms.ToTensor())
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=1)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
# Create OTO instance
import sys
sys.path.append('..')
from only_train_once import OTO

model = MLP() # Instantiate the model
dummy_input = torch.rand(1, 3, 32, 32)
oto = OTO(model=model.cuda(), dummy_input=dummy_input.cuda())

OTO graph constructor
graph build


In [6]:
# Set up the Hesso optimizer
optimizer = oto.hesso(
    variant='sgd', 
    lr=0.1, 
    weight_decay=1e-4,
    target_group_sparsity=0.5,
    start_pruning_step=10 * len(trainloader), 
    pruning_periods=10,
    pruning_steps=10 * len(trainloader)
)

Setup HESSO
Target redundant groups per period:  [4, 4, 4, 4, 4, 4, 4, 4, 4, 12]


In [7]:
from utils.utils import check_accuracy

max_epoch = 100
model.cuda()
criterion = torch.nn.CrossEntropyLoss()

# Every 50 epochs, decay lr by 10.0
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1) 

for epoch in range(max_epoch):
    f_avg_val = 0.0
    lr_scheduler.step()
    for X, y in trainloader:
        X = X.cuda()
        y = y.cuda()
        y_pred = model.forward(X)
        f = criterion(y_pred, y)
        optimizer.zero_grad()
        f.backward()
        f_avg_val += f
        optimizer.step()
    opt_metrics = optimizer.compute_metrics()
    
    accuracy1, accuracy5 = check_accuracy(model, testloader)
    # accuracy1, accuracy5 = check_accuracy(model, trainloader)
    f_avg_val = f_avg_val.cpu().item() / len(trainloader)
    
    print("Ep: {ep}, loss: {f:.2f}, norm_all:{param_norm:.2f}, grp_sparsity: {gs:.2f}, acc1: {acc1:.4f}, norm_import: {norm_import:.2f}, norm_redund: {norm_redund:.2f}, num_grp_import: {num_grps_import}, num_grp_redund: {num_grps_redund}"\
         .format(ep=epoch, f=f_avg_val, param_norm=opt_metrics.norm_params, gs=opt_metrics.group_sparsity, acc1=accuracy1,\
         norm_import=opt_metrics.norm_important_groups, norm_redund=opt_metrics.norm_redundant_groups, \
         num_grps_import=opt_metrics.num_important_groups, num_grps_redund=opt_metrics.num_redundant_groups
        ))



Ep: 0, loss: 1.97, norm_all:68.85, grp_sparsity: 0.00, acc1: 0.2846, norm_import: 68.85, norm_redund: 0.00, num_grp_import: 96, num_grp_redund: 0
Ep: 1, loss: 1.78, norm_all:79.18, grp_sparsity: 0.00, acc1: 0.2817, norm_import: 79.18, norm_redund: 0.00, num_grp_import: 96, num_grp_redund: 0
Ep: 2, loss: 1.70, norm_all:88.29, grp_sparsity: 0.00, acc1: 0.3699, norm_import: 88.29, norm_redund: 0.00, num_grp_import: 96, num_grp_redund: 0
Ep: 3, loss: 1.66, norm_all:96.56, grp_sparsity: 0.00, acc1: 0.3959, norm_import: 96.56, norm_redund: 0.00, num_grp_import: 96, num_grp_redund: 0
Ep: 4, loss: 1.61, norm_all:103.96, grp_sparsity: 0.00, acc1: 0.3466, norm_import: 103.96, norm_redund: 0.00, num_grp_import: 96, num_grp_redund: 0
Ep: 5, loss: 1.59, norm_all:110.85, grp_sparsity: 0.00, acc1: 0.4124, norm_import: 110.85, norm_redund: 0.00, num_grp_import: 96, num_grp_redund: 0
Ep: 6, loss: 1.56, norm_all:117.09, grp_sparsity: 0.00, acc1: 0.3604, norm_import: 117.09, norm_redund: 0.00, num_grp_im

In [8]:
import os

oto.construct_subnet(out_dir='./cache')

# Compare the full model size and compressed model size
full_model_size = os.stat(oto.full_group_sparse_model_path)
compressed_model_size = os.stat(oto.compressed_model_path)
print("Size of full model     : ", full_model_size.st_size / (1024 ** 3), "GBs")
print("Size of compress model : ", compressed_model_size.st_size / (1024 ** 3), "GBs")

# Both full and compressed model should return the exact same accuracy.
full_model = torch.load(oto.full_group_sparse_model_path)
compressed_model = torch.load(oto.compressed_model_path)

acc1_full, acc5_full = check_accuracy(full_model, testloader)
print("Full model: Acc 1: {acc1}, Acc 5: {acc5}".format(acc1=acc1_full, acc5=acc5_full))

acc1_compressed, acc5_compressed = check_accuracy(compressed_model, testloader)
print("Compressed model: Acc 1: {acc1}, Acc 5: {acc5}".format(acc1=acc1_compressed, acc5=acc5_compressed))

Size of full model     :  0.000745970755815506 GBs
Size of compress model :  0.00021345727145671844 GBs


Full model: Acc 1: 0.55028, Acc 5: 0.94488
Compressed model: Acc 1: 0.55028, Acc 5: 0.94488
