# compression.pruning

> Structured pruning helpers for N-BEATS-style time-series models (trend and seasonality blocks), built on `torch_pruning`.

In [None]:
#| default_exp compression.pruning

In [None]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_pruning as tp
from torch_pruning.pruner import function
from fasterai.core.all import *

import onnx
import onnxruntime as ort
from onnxruntime import quantization

import numpy as np
import pickle
from itertools import cycle
from fastcore.basics import store_attr, listify, true
from fasterbench.benchmark import *

In [None]:
#| export
def get_ignored_layers(model):
    """Return the head layers that torch_pruning must not prune.

    Walks ``model.ts_model.trend_blocks`` and then
    ``model.ts_model.seasonality_blocks`` (each group only if present). For
    every block it collects ``backcast_fc`` followed by ``forecast_fc``,
    including only the attributes the block actually has.

    Args:
        model: object exposing the network as ``model.ts_model``.

    Returns:
        list of the collected layer objects, in traversal order (possibly empty).
    """
    net = model.ts_model
    found = []
    for group_attr in ('trend_blocks', 'seasonality_blocks'):
        if not hasattr(net, group_attr):
            continue
        for blk in getattr(net, group_attr):
            found.extend(
                getattr(blk, head_attr)
                for head_attr in ('backcast_fc', 'forecast_fc')
                if hasattr(blk, head_attr)
            )
    return found

In [None]:
#| export
def adjust_layer_features(layer, pruning_ratio, score_fn=None):
    """Drop the lowest-scoring input columns of a linear-like ``layer`` in place.

    Exactly ``int(in_features * (1 - pruning_ratio))`` columns are kept (at
    least one), chosen as the highest-scoring ones, in their original order.
    ``layer.in_features`` is then set from the shape of the sliced weight, so
    the metadata always agrees with the weight tensor. The bias is not modified.
    Layers lacking either ``in_features`` or ``weight`` are left untouched.

    Args:
        layer: module whose ``weight`` is indexed as ``weight[:, cols]``, i.e.
            columns correspond to input features.
        pruning_ratio: fraction of input columns to remove.
        score_fn: optional callable ``(layer, granularity) -> Tensor`` whose
            first row holds one score per input column. Defaults to
            ``large_final(layer, 'column')`` (presumably provided by the
            fasterai star import).
    """
    if not (hasattr(layer, 'in_features') and hasattr(layer, 'weight')):
        return
    if score_fn is None:
        score_fn = large_final

    # Row 0 of the score tensor carries one score per input column.
    col_scores = score_fn(layer, 'column')[0]

    # Target width; clamp to [1, #columns] so at least one column survives.
    n_keep = max(1, int(layer.in_features * (1 - pruning_ratio)))
    n_keep = min(n_keep, col_scores.numel())

    # top-k guarantees the exact target count even with tied scores; sort the
    # indices so the surviving columns keep their original relative order.
    keep = torch.topk(col_scores, n_keep).indices.sort().values

    # Slice .data so the copy is not tracked by autograd.
    layer.weight.data = layer.weight.data[:, keep]
    layer.in_features = layer.weight.shape[1]

In [None]:
#| export
def prune_model(model, pruning_ratio, dummy_input, num_hidden=None):
    """Structurally prune ``model.ts_model`` in place with torch_pruning.

    Channels are ranked with an L2 ``GroupNormImportance`` and removed by a
    ``MetaPruner``. The ``backcast_fc`` / ``forecast_fc`` heads of trend and
    seasonality blocks are passed as ``ignored_layers`` (via
    ``get_ignored_layers``). After the pruning step, any head whose
    ``in_features`` differs from the expected pruned width is shrunk with
    ``adjust_layer_features``.

    Args:
        model: object exposing the network to prune as ``model.ts_model``.
        pruning_ratio: fraction of channels to remove, forwarded to ``MetaPruner``.
        dummy_input: example input torch_pruning uses to trace ``ts_model``.
        num_hidden: hidden width feeding the heads before pruning. If ``None``
            (default), each head's own ``in_features`` measured before pruning
            is used instead of relying on a global variable.

    Returns:
        None; ``model`` is modified in place.

    NOTE(review): ``adjust_layer_features`` chooses columns from the head's own
    weights, not from the channels torch_pruning kept upstream -- confirm the
    two selections line up for this architecture.
    """
    ts_model = model.ts_model

    # Record the heads and their pre-pruning input widths up front, so the
    # expected post-pruning width can be derived without a global num_hidden.
    heads = []
    for group in ('trend_blocks', 'seasonality_blocks'):
        for block in getattr(ts_model, group, ()):
            for head_name in ('backcast_fc', 'forecast_fc'):
                layer = getattr(block, head_name, None)
                if layer is not None and hasattr(layer, 'in_features'):
                    heads.append((layer, layer.in_features))

    pruner = tp.pruner.MetaPruner(
        ts_model,
        dummy_input,
        importance=tp.importance.GroupNormImportance(p=2),
        pruning_ratio=pruning_ratio,
        ignored_layers=get_ignored_layers(model),
    )
    pruner.step()

    for layer, width_before in heads:
        width = num_hidden if num_hidden is not None else width_before
        # Only touch heads whose input width was not already reduced.
        if layer.in_features != int(width * (1 - pruning_ratio)):
            adjust_layer_features(layer, pruning_ratio)

In [None]:
from TRAIL24.models.nn import *

In [None]:
# Example usage: hyper-parameters for an N-BEATS backbone.
input_size = 40  # Length of input time series
output_size = 1  # Length of output time series (forecast)
num_blocks = 12
num_hidden = 512
num_layers = 8
embedding_dim = 10
final_hidden = 256  # must match the value passed to create_model below

nbeats_params = {
    'input_size': input_size,
    'output_size': output_size,
    'num_blocks': num_blocks,
    'num_hidden': num_hidden,
    'num_layers': num_layers,
}

# Reference the variables instead of repeating literals so the two cannot drift.
model_cfg = {
    'model_type': 'nbeats',
    'model_params': nbeats_params,
    'embedding_dim': embedding_dim,
    'final_hidden': final_hidden,
}

In [None]:
# Instantiate the network from `model_cfg`.
# NOTE(review): `create_model` presumably comes from the `TRAIL24.models.nn` star import — confirm.
net = create_model(**model_cfg)

In [None]:
# Report the model footprint before pruning.
# NOTE(review): get_num_parameters / get_model_size presumably come from the
# fasterbench.benchmark star import; get_model_size appears to return bytes
# (it is divided by 1e6 to print MB) — confirm.
num_parameters = get_num_parameters(net)
disk_size = get_model_size(net)
print(f"Model Size: {disk_size / 1e6:.2f} MB (disk), {num_parameters} parameters")

Model Size: 90.39 MB (disk), 22576523 parameters


In [None]:
# Random (batch_size, num_features) batch, later passed to prune_model as dummy_input.
batch_size, num_features = 5, 40
features = torch.randn(batch_size, num_features)

In [None]:
# Prune 30% of channels of `net` in place; `features` is the example input
# torch_pruning uses to trace the model.
prune_model(net, 0.3, features)

torch.Size([5, 40])


In [None]:
# Report the model footprint after pruning, for comparison with the pre-pruning figures.
num_parameters = get_num_parameters(net)
disk_size = get_model_size(net)
print(f"Model Size: {disk_size / 1e6:.2f} MB (disk), {num_parameters} parameters")

Model Size: 44.71 MB (disk), 11157731 parameters
