In [None]:
© 2021-present Neuralmagic, Inc. // Neural Magic Legal

Torchvision Classification Model Pruning using SparseML
This notebook provides a step-by-step walkthrough for pruning a torchvision model using SparseML. You will:

Download a pre-trained torchvision model and generic dataset
Define a generic torchvision finetuning flow
Integrate the torchvision flow with SparseML
Prune the model using the torchvision+SparseML flow
Save the model and export to ONNX
Reading through this notebook is a reasonably quick way to gain an intuition for how to integrate SparseML with torchvision or, more generally, any PyTorch training flow. Rough time estimates for fully pruning the default model are given below. Note that training with the PyTorch CPU implementation will be much slower than training on a GPU:

15 minutes on a GPU
45 minutes on a laptop CPU

In [None]:
Step 1 - Requirements
To run this notebook, you will need the following packages already installed:

SparseML and SparseZoo
PyTorch and torchvision
You can install any package that is not already present via pip.

In [None]:
"""
Creating Config for BERT-QA GMP training/pruning with Neural Magic
"""

notebook_name = "BERTQA-NM"
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

import collections
import json
import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Optional, Union, Tuple, Iterable, Dict, Any, List
import re
import random
import time
import math
from dataclasses import dataclass, field
import wandb
from glob import glob
from tqdm import auto
from tqdm.auto import tqdm
import numpy as np
from onnx import ModelProto

import torch
from torch import nn
from torch.nn import Module, Linear, Parameter

import transformers
from transformers import (AutoConfig, AutoModelForQuestionAnswering,
                          AutoTokenizer, DataCollatorWithPadding,
                          PreTrainedTokenizerFast)

In [None]:
from neuralmagicML.pytorch.recal import approx_ks_loss_sensitivity
from neuralmagicML.utilsnb import check_pytorch_notebook_setup
from neuralmagicML.pytorch.utils import CrossEntropyLossWrapper
from neuralmagicML.utils import create_unique_dir, clean_path
from neuralmagicML.pytorch.utils import (
    CrossEntropyLossWrapper,
    TopKAccuracy,
    ModuleTrainer,
    ModuleTester,
    TensorBoardLogger,
)

from neuralmagicML.pytorch.utils import ModuleExporter
from neuralmagicML.utilsnb import (
    KSWidgetContainer,
    PruningEpochWidget,
    PruningParamsWidget,
)
from neuralmagicML.pytorch.utils import get_named_layers_and_params_by_regex

from neuralmagicML.pytorch.recal import ScheduledModifierManager, ScheduledOptimizer
from neuralmagicML.recal import (
    default_check_sparsities_loss,
    default_check_sparsities_perf,
    KSLossSensitivityAnalysis,
    KSPerfSensitivityAnalysis,
    KSSensitivityResult,
)
# Logger named after the current module (`__name__`).
logger = logging.getLogger(__name__)
# NOTE(review): presumably verifies that the notebook environment is set up
# for the PyTorch flow — confirm against neuralmagicML.utilsnb.
check_pytorch_notebook_setup()

In [None]:
# Local cache directory handed to every `from_pretrained` call, and the
# checkpoint identifier that all three objects are loaded from.
cache_dir = 'cache'
model_name = 'bert-base-uncased'

# Model configuration for the chosen checkpoint.
config = AutoConfig.from_pretrained(
    model_name,
    cache_dir=cache_dir,
)

# Tokenizer for the same checkpoint, requesting the fast implementation.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    use_fast=True,
    cache_dir=cache_dir,
)

# Question-answering model built from the checkpoint and the config above.
model = AutoModelForQuestionAnswering.from_pretrained(
    model_name,
    cache_dir=cache_dir,
    config=config,
)

In [None]:
# Report which device PyTorch can see.
# NOTE(review): `device` is only printed in this cell; `model` is not moved
# onto it here.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("running approximate ks loss sensitivity analysis for model on {}".format(device))

# Run the approximate KS loss sensitivity analysis on the loaded model.
loss_analysis = approx_ks_loss_sensitivity(model)

# Persist the analysis as JSON under ./<notebook_name>/<model_name>/.
save_path = os.path.join(".", notebook_name, model_name, "ks-loss-sensitivity.json")
save_path = clean_path(save_path)
loss_analysis.save_json(save_path)
print("saved analysis to {}".format(save_path))

# Plot the analysis; `path=None` is passed so no output path is given.
print("plotting...")
fig, axes = loss_analysis.plot(path=None, plot_integral=True, normalize=False)

In [None]:
# Build the interactive pruning-configuration UI for the loaded model.

# If no earlier cell produced `loss_analysis`, fall back to None so the
# params widget is still constructed (with `loss_sens_analysis=None`).
if "loss_analysis" not in globals():
    loss_analysis = None

# Patterns selecting the parameters that are in scope for pruning.
# Raw strings keep the `\.` escape literal without relying on Python's
# handling of unrecognized escape sequences; the string values are unchanged.
param_in_scope_regex = [
    r"re:.*key\.weight",
    r"re:.*value\.weight",
    r"re:.*query\.weight",
    r"re:.*dense\.weight",
]
# match all key, value, query, and dense
prune_layers_and_params = get_named_layers_and_params_by_regex(
    model, param_in_scope_regex
)
# format to full parameter names
param_names = [
    "{}.{}".format(param.layer_name, param.param_name)
    for param in prune_layers_and_params
]

# Report how many parameter tensors matched the patterns, counted directly
# from the matches collected above.
print(
    "There are {} prunable parameters in the current selection".format(
        len(param_names)
    )
)

# Epoch settings are fixed to a single epoch (start 0, end 1); the params
# widget gets one entry per matched parameter, described by its layer's repr.
widget_container = KSWidgetContainer(
    PruningEpochWidget(start_epoch=0, end_epoch=1, total_epochs=1, max_epochs=1),
    PruningParamsWidget(
        param_names=param_names,
        param_descs=[str(param.layer) for param in prune_layers_and_params],
        param_enables=None,
        param_sparsities=None,
        loss_sens_analysis=loss_analysis,
    ),
)
print("Creating ui...")
display(widget_container.create())

In [None]:
# Save the manager built from the widget settings to a config file.
# NOTE(review): presumably this writes the pruning schedule chosen in the UI
# as YAML — confirm against the widget container's `get_manager` API.
config_path = clean_path("prune-config.yaml")
print("Saving config to {}".format(config_path))
widget_container.get_manager("pytorch").save(config_path)