# Imports 

In [None]:
!pip install wandb
!pip install kornia
!pip install optuna

In [1]:
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.autograd import Variable
from torch import optim
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import math
import copy
from torch.optim.lr_scheduler import MultiStepLR
from torchvision.models import resnet50, resnet34, resnet18, wide_resnet50_2, ResNet50_Weights, alexnet
import gc
import os
import pandas as pd
from torchvision.io import read_image
from flax.training import checkpoints
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor

#Import from MTT code
from networks import ConvNet, AlexNet
from distill import ParamDiffAug
from utils import evaluate_synset, get_network
import argparse

import optuna
import kornia
import wandb

[0m

# Load Data

In [2]:
#Load distilled data from MTT, if not in downloads directory, download from MTT repository: https://github.com/GeorgeCazenavette/mtt-distillation/
labels_train = torch.load('./data/cifar100_50ipc_labels.pt')
images_train = torch.load('./data/cifar100_50ipc_images.pt')

In [3]:
#Load in real training data from pytorch
batch_size = 256
train_dataset = torchvision.datasets.CIFAR100(root = './data',
                                                    train = True,
                                                    transform = transforms.Compose([
                                                            transforms.ToTensor(),
                                                            transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]),]),
                                                    download=True)

train_loader = torch.utils.data.DataLoader(dataset = train_dataset,
                                                    batch_size = batch_size,
                                                    shuffle = True)


test_dataset = torchvision.datasets.CIFAR100(root = './data',
                                                    train = False,
                                                    transform = transforms.Compose([
                                                            transforms.ToTensor(),
                                                            transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]),]),
                                                    download=True)

test_loader = torch.utils.data.DataLoader(dataset = test_dataset,
                                                    batch_size = batch_size,
                                                    shuffle = True)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


# Methods used for Distilled Pruning

In [4]:
#Standard train function with hyperparameters used in paper set as default
def train(model,train_loader, num_epochs, lr = .0008, weight_decay = .0008, gamma = .15, milestones = [50,65,80]):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay = weight_decay)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 
    cost = nn.CrossEntropyLoss()
    scheduler = MultiStepLR(optimizer, milestones=milestones, gamma= gamma)
    total_step = len(train_loader)
    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(train_loader):  
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            
            loss = cost(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        scheduler.step()
    pass

#Standard test function, prints & returns test accuracy
def test(model, test_loader):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 
    # Test the model
    model.eval()
    model.to(device)
    with torch.no_grad():
        correct = 0
        total = 0
        for i, (images, labels) in enumerate(test_loader): 
            images, labels = images.to(device), labels.to(device)
            test_output = model(images)
            pred_y = torch.max(test_output, 1)[1].data.squeeze()
            correct += (pred_y == labels).sum().item()
            total += labels.size(0)
        accuracy = correct / total

    print('Test Accuracy:', accuracy)
    return accuracy

#Helper function for prunable all pruning modules to work pytorch global pruning. 
#See global pruning section of this: https://pytorch.org/tutorials/intermediate/pruning_tutorial.html
def get_parameters_to_prune(model):
    parameters_to_prune = []
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear):
            parameters_to_prune.append((module, 'weight'))
    return tuple(parameters_to_prune)

#Returns number of zeros and total number of prunable parameters of a model. Global Sparsity measured as: zero / total
def sparsity_print(model):
    prune.global_unstructured(get_parameters_to_prune(model),pruning_method=prune.L1Unstructured,amount=0)
    zero = total = 0
    for module, _ in get_parameters_to_prune(model):
        zero += float(torch.sum(module.weight == 0))
        total += float(module.weight.nelement())
    print('Number of Zero Weights:', zero)
    print('Total Number of Weights:', total)
    print('Sparsity', zero/total)
    #TODO: Implement Node Sparsity
    return zero, total

#Standard IMP with Weight Rewinding to the kth epoch in training, 
#name: a string that allows us to save models/logs appropriately, 
#path: the location of folder we save to,
#start_iter: should normally be 0 but if a experiment stops halfway through it allows us to begin there,
#amount = % params pruned each pruning iteration, 
#save_model: boolean to decide if we download the model at every iter, 
#reinit: is boolean value if we want to test results on reinitialized weights
#reinit_model: is the specific model that holds the reinitialized weights, must be same type as model
#seed: seed for pruning sequence, SEED IS NOT FOR MODEL INITIALIZATION FROM PAPER. We instead set seed outside the function for that.
#num_epochs: number of epochs used for training
def LotteryTicketRewinding(model, name, path, train_loader, test_loader, start_iter = 0, end_iter = 30, num_epochs = 60, k = 1, amount = .2, save_model = True, seed = 0, reinit = False, reinit_model = None):
    torch.manual_seed(seed)
    zeros = [] #keeps track of zeros at each iteration
    totals = [] #keeps track of total parameters
    acc = [] #keeps track of model accuracy at each pruning iteration
    reinit_acc = [] #same as above but for reinitialized model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 
    
    #Create Rewind Weights after training K epochs
    train(model, train_loader,num_epochs = k)
    torch.save(model.state_dict(), path + name + '_RewindWeights' + '_' + str(k))
    model_rewind = copy.deepcopy(model).to(device)
    
    #Finish off the pretraining
    train(model, train_loader,num_epochs = num_epochs - k)

    #Lottery Ticket Rewinding: Prune, Rewind, Train
    for i in range(start_iter,end_iter):
        print('LTR Iteration:', i+1)
        #Prune
        prune.global_unstructured(get_parameters_to_prune(model),pruning_method=prune.L1Unstructured,amount=amount)
        #Rewind Weights
        for idx, (module, _) in enumerate(get_parameters_to_prune(model)):
            with torch.no_grad():
                module_rewind = get_parameters_to_prune(model_rewind)[idx][0]
                module.weight_orig.copy_(module_rewind.weight)
        #Train Weights
        train(model, train_loader,num_epochs = num_epochs)
        
        #Log Results
        zero, total = sparsity_print(model)
        zeros.append(zero)
        totals.append(total)
        acc.append(test(model, test_loader))
        if save_model:
            torch.save(model.state_dict(), path + name + '_iter' + str(i+1))
            
        #Reinitialize the weights, train and validate on those new weights
        if reinit:
            #Rewind Weights
            for idx, (module, _) in enumerate(get_parameters_to_prune(model)):
                with torch.no_grad():
                    module_reinit = get_parameters_to_prune(reinit_model)[idx][0]
                    module.weight_orig.copy_(module_reinit.weight)
                    
            train(model, train_loader,num_epochs = num_epochs)
            reinit_acc.append(test(model, test_loader))
            
            for idx, (module, _) in enumerate(get_parameters_to_prune(model)):
                with torch.no_grad():
                    module_rewind = get_parameters_to_prune(model_rewind)[idx][0]
                    module.weight_orig.copy_(module_rewind.weight)
        else:
            reinit_acc.append(0)
            
        np.save(path + name + '_log', np.array([acc,zeros,totals,reinit_acc]))
    
    pass
  
#Generate full sparsity curve with retraining for random pruning, this is not a method to compute a single, high-sparsity model with random pruning. 
#This generates and trains models at all sparsities for comparison. 
#If you want to perform a single random pruning, then just use: prune.global_unstructured(get_parameters_to_prune(model),pruning_method=prune.RandomUnstructured,amount=amount)
def RandomPruning(model, name, path, train_loader, test_loader, start_iter = 0, end_iter = 30, num_epochs = 60, amount = .2, save_model = True, seed = 0):
    torch.manual_seed(seed)
    zeros = []
    totals = []
    acc = []
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 
    
    model_rewind = copy.deepcopy(model).to(device)
    
    for i in range(start_iter,end_iter):
        print('Random Pruning Iteration:', i+1)
        #Prune
        prune.global_unstructured(get_parameters_to_prune(model),pruning_method=prune.RandomUnstructured,amount=amount)
        #Train Weights
        train(model, train_loader,num_epochs = num_epochs)
        
        #Log Results
        zero, total = sparsity_print(model)
        zeros.append(zero)
        totals.append(total)
        acc.append(test(model, test_loader))

            
        #Rewind Weights to save them
        for idx, (module, _) in enumerate(get_parameters_to_prune(model)):
            with torch.no_grad():
                module_rewind = get_parameters_to_prune(model_rewind)[idx][0]
                module.weight_orig.copy_(module_rewind.weight)
                
        if save_model:
            torch.save(model.state_dict(), path + name + '_iter' + str(i+1))
            
        np.save(path + name + '_log', np.array([acc,zeros,totals]))
    
#Distilled Pruning 
#See LotteryTicketRewinding for parameter description. 
#num_epochs_distilled: number of epochs used for distilled training
#distilled_lr = learning rate for distilled training
#validate = Retrains pruned model at each iteration on real data. This is used to generate a full sparsity curve, but is not efficient for finding a single distilled mask.
#use validate false, and set end_iter to desired final iteration of the mask. Remember sparsity = 1-(1-amount)^end_iter
def DistilledPruning(model, name, path, images_train, labels_train, train_loader, test_loader, start_iter = 0, end_iter = 30, num_epochs_distilled = 1000, num_epochs_real = 60, k = 0, amount = .2, save_model = True, validate = False, seed = 0, reinit = False, reinit_model = None, distilled_lr = .01):
    torch.manual_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 
    accs = []
    zeros = []
    totals = []
    reinit_acc = []
    
    #Create rewind weights at initailization
    model_rewind = copy.deepcopy(model).to(device)
    torch.save(model.state_dict(), path + name + '_RewindWeights' + '_' + str(k))
    
    #Use if you want to try rewinding to an early point in training, this does not work well, so we suggest k=0 always.
    if k != 0:
        args = argparse.Namespace(lr_net=str(distilled_lr), device='cuda', epoch_eval_train=str(k),batch_train=512,dataset='cifar10',dsa=True,dsa_strategy='color_crop_cutout_flip_scale_rotate',dsa_param = ParamDiffAug(), dc_aug_param=None, zca_trans=kornia.enhance.ZCAWhitening(eps=0.1, compute_inv=True)) #, zca_trans=kornia.enhance.ZCAWhitening(eps=0.1, compute_inv=True)
        model_rewind, acc_train_list, acc_test = evaluate_synset(0, model_rewind,images_train,labels_train,test_loader,args)
        
    
    for i in range(start_iter,end_iter):
        print('Distilled Pruning Iteration ', i)
        #Set distilled pruning training args for MTT eval
        args = argparse.Namespace(lr_net='.01', device='cuda', epoch_eval_train=str(num_epochs_distilled),batch_train=512,dataset='cifar10',dsa=True,dsa_strategy='color_crop_cutout_flip_scale_rotate',dsa_param = ParamDiffAug(), dc_aug_param=None, zca_trans=kornia.enhance.ZCAWhitening(eps=0.1, compute_inv=True)) #, zca_trans=kornia.enhance.ZCAWhitening(eps=0.1, compute_inv=True)
        #MTT Training on Distilled Data
        model, acc_train_list, acc_test = evaluate_synset(i+1, model,images_train,labels_train,test_loader,args)
        prune.global_unstructured(get_parameters_to_prune(model),pruning_method=prune.L1Unstructured,amount=amount)
        #Rewind Weights
        for idx, (module, _) in enumerate(get_parameters_to_prune(model)):
            with torch.no_grad():
                module_rewind = get_parameters_to_prune(model_rewind)[idx][0]
                module.weight_orig.copy_(module_rewind.weight)
    
        if save_model:
            torch.save(model.state_dict(), path + name + '_iter' + str(i+1))
            
        #Rewind weights back to initialization and train on real data to validate this sparsity mask
        if validate:
            train(model, train_loader,num_epochs = num_epochs_real)
            accs.append(test(model, test_loader))
            zero, total = sparsity_print(model)
            zeros.append(zero)
            totals.append(total)
            #Rewind Weights
            for idx, (module, _) in enumerate(get_parameters_to_prune(model)):
                with torch.no_grad():
                    module_rewind = get_parameters_to_prune(model_rewind)[idx][0]
                    module.weight_orig.copy_(module_rewind.weight)
                    
            np.save(path + name + '_log', np.array([accs, zeros, totals, reinit_acc]))
        
        if reinit:
            #Rewind Weights to Reinit Model
            for idx, (module, _) in enumerate(get_parameters_to_prune(model)):
                with torch.no_grad():
                    module_reinit = get_parameters_to_prune(reinit_model)[idx][0]
                    module.weight_orig.copy_(module_reinit.weight)
                    
            train(model, train_loader,num_epochs = num_epochs_real)
            reinit_acc.append(test(model, test_loader))
            
            for idx, (module, _) in enumerate(get_parameters_to_prune(model)):
                with torch.no_grad():
                    module_rewind = get_parameters_to_prune(model_rewind)[idx][0]
                    module.weight_orig.copy_(module_rewind.weight)
            np.save(path + name + '_log', np.array([accs, zeros, totals, reinit_acc]))
        else:
            reinit_acc.append(0)
    #If validate = False, then we still want to validate the final sparsity mask. just not all the masks.
    if not validate:
        train(model, train_loader,num_epochs = num_epochs_real)
        acc = (test(model, test_loader))
        zero, total = sparsity_print(model)
        np.save(path + name + '_log', np.array([acc, zero, total, reinit]))

In [None]:
#Save logs and models here
path = './model_results_cifar100/' 

For CIFAR 100, Distilled Pruning Hyperparameters = .085 distilled_lr and 1300 num_epochs_distilled. All else is in code and the same across experiments/datasets


# Example Distilled Training Snippet for MTT

In [14]:
model = get_network('ConvNetW128', 3, 100)
args = argparse.Namespace(lr_net='.09', device='cuda', epoch_eval_train=str(1300),batch_train=512,dataset='cifar100',dsa=True,dsa_strategy='color_crop_cutout_flip_scale_rotate',dsa_param = ParamDiffAug(), dc_aug_param=None, zca_trans=kornia.enhance.ZCAWhitening(eps=0.1, compute_inv=True)) #, zca_trans=kornia.enhance.ZCAWhitening(eps=0.1, compute_inv=True)
model, acc_train_list, acc_test = evaluate_synset(1, model,images_train,labels_train,test_loader,args)

100%|██████████| 1301/1301 [01:33<00:00, 13.92it/s]

[2023-05-30 16:31:14] Evaluate_01: epoch = 1300 train time = 93 s train loss = 0.005592 train acc = 1.0000, test acc = 0.3993





# Example Distilled Pruning for computing a single sparsity mask

In [None]:
torch.manual_seed(0)
model = get_network('ConvNetW128', 3, 100)
DistilledPruning(model, 'DistilledPruning_CIFAR100_ipc50_seed0_iter5', path, images_train, labels_train, train_loader, test_loader, start_iter = 0, end_iter = 5, num_epochs_distilled = 1300, num_epochs_real = 120, k = 0, amount = .2, save_model = False, validate = True, seed = 0, reinit = False, reinit_model = None, distilled_lr = .09)

# Distilled Pruning Experiment for generating plots / validating all sparsity mask

In [None]:
for i in range(5):
    torch.manual_seed(i)
    model = get_network('ConvNetW128', 3, 100)
    DistilledPruning(model, 'DistilledPruning_CIFAR100_ipc50_seed' + str(i), path, images_train, labels_train, train_loader, test_loader, start_iter = 0, end_iter = 20, num_epochs_distilled = 1300, num_epochs_real = 120, k = 0, amount = .2, save_model = False, validate = True, seed = 0, reinit = False, reinit_model = None, distilled_lr = .09)

Distilled Pruning Iteration  0


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
100%|██████████| 1251/1251 [00:55<00:00, 22.35it/s]


[2023-05-31 18:21:30] Evaluate_01: epoch = 1250 train time = 55 s train loss = 0.029523 train acc = 1.0000, test acc = 0.3745


  np.save(path + name + '_log', np.array([accs, zeros, totals, reinit_acc]))


Test Accuracy: 0.5176
Number of Zero Weights: 100634.0
Total Number of Weights: 503168.0
Sparsity 0.2000007949631137
Distilled Pruning Iteration  1


100%|██████████| 1251/1251 [00:54<00:00, 22.77it/s]


[2023-05-31 18:36:19] Evaluate_02: epoch = 1250 train time = 54 s train loss = 0.017830 train acc = 1.0000, test acc = 0.3730
Test Accuracy: 0.5025
Number of Zero Weights: 181141.0
Total Number of Weights: 503168.0
Sparsity 0.3600010334520478
Distilled Pruning Iteration  2


100%|██████████| 1251/1251 [00:54<00:00, 22.88it/s]


[2023-05-31 18:51:10] Evaluate_03: epoch = 1250 train time = 54 s train loss = 0.004405 train acc = 1.0000, test acc = 0.3734
Test Accuracy: 0.495
Number of Zero Weights: 245546.0
Total Number of Weights: 503168.0
Sparsity 0.4880000317985245
Distilled Pruning Iteration  3


100%|██████████| 1251/1251 [00:53<00:00, 23.30it/s]


[2023-05-31 19:06:01] Evaluate_04: epoch = 1250 train time = 53 s train loss = 0.003360 train acc = 1.0000, test acc = 0.3712
Test Accuracy: 0.4791
Number of Zero Weights: 297070.0
Total Number of Weights: 503168.0
Sparsity 0.590399230475706
Distilled Pruning Iteration  4


100%|██████████| 1251/1251 [00:54<00:00, 22.94it/s]


[2023-05-31 19:20:54] Evaluate_05: epoch = 1250 train time = 54 s train loss = 0.003594 train acc = 1.0000, test acc = 0.3687
Test Accuracy: 0.4738
Number of Zero Weights: 338290.0
Total Number of Weights: 503168.0
Sparsity 0.6723201793436785
Distilled Pruning Iteration  5


100%|██████████| 1251/1251 [00:54<00:00, 22.92it/s]


[2023-05-31 19:35:49] Evaluate_06: epoch = 1250 train time = 54 s train loss = 0.014486 train acc = 1.0000, test acc = 0.3672


# IMP Pruning Experiment for generating plots / validating all sparsity mask

In [9]:
for i in range(5):
    torch.manual_seed(i)
    model = get_network('ConvNetW128', 3,10)
    LotteryTicketRewinding(model, 'IMP_CIFAR100_seed' + str(i), path, train_loader, test_loader, start_iter = 0, end_iter = 30, num_epochs = 120, k = 1, amount = .2, save_model = False, seed = i, reinit = False)

LTR Iteration: 1
Number of Zero Weights: 63770.0
Total Number of Weights: 318848.0
Sparsity 0.20000125451625853
Test Accuracy: 0.8143
LTR Iteration: 2
Number of Zero Weights: 114786.0
Total Number of Weights: 318848.0
Sparsity 0.36000225812926534
Test Accuracy: 0.8132
LTR Iteration: 3
Number of Zero Weights: 155598.0
Total Number of Weights: 318848.0
Sparsity 0.48800055198715375
Test Accuracy: 0.8118
LTR Iteration: 4
Number of Zero Weights: 188248.0
Total Number of Weights: 318848.0
Sparsity 0.590400441589723
Test Accuracy: 0.8094
LTR Iteration: 5
Number of Zero Weights: 214368.0
Total Number of Weights: 318848.0
Sparsity 0.6723203532717784
Test Accuracy: 0.8101
LTR Iteration: 6
Number of Zero Weights: 235264.0
Total Number of Weights: 318848.0
Sparsity 0.7378562826174228
Test Accuracy: 0.806
LTR Iteration: 7
Number of Zero Weights: 251981.0
Total Number of Weights: 318848.0
Sparsity 0.7902856533520675
Test Accuracy: 0.8042
LTR Iteration: 8
Number of Zero Weights: 265354.0
Total Number

# Random Pruning Experiment

In [6]:
for i in range(5):
    torch.manual_seed(i)
    model = get_network('ConvNetW128', 3,100)
    RandomPruning(model, 'RandomPruning_CIFAR100_seed' + str(i), path, train_loader, test_loader, start_iter = 0, end_iter = 30, num_epochs = 120, amount = .2, save_model = False, seed = i)

Random Pruning Iteration: 1
Number of Zero Weights: 100634.0
Total Number of Weights: 503168.0
Sparsity 0.2000007949631137
Test Accuracy: 0.8126
Random Pruning Iteration: 2
Number of Zero Weights: 181141.0
Total Number of Weights: 503168.0
Sparsity 0.3600010334520478
Test Accuracy: 0.7985
Random Pruning Iteration: 3
Number of Zero Weights: 245547.0
Total Number of Weights: 503168.0
Sparsity 0.4880020192063088
Test Accuracy: 0.7913
Random Pruning Iteration: 4
Number of Zero Weights: 297070.0
Total Number of Weights: 503168.0
Sparsity 0.590399230475706
Test Accuracy: 0.781
Random Pruning Iteration: 5
Number of Zero Weights: 338290.0
Total Number of Weights: 503168.0
Sparsity 0.6723201793436785
Test Accuracy: 0.7751
Random Pruning Iteration: 6
Number of Zero Weights: 371266.0
Total Number of Weights: 503168.0
Sparsity 0.7378569384380564
Test Accuracy: 0.7669
Random Pruning Iteration: 7
Number of Zero Weights: 397646.0
Total Number of Weights: 503168.0
Sparsity 0.7902847557873315
Test Accu