# compression.onnx

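> Utilities to trace PyTorch models with TorchScript and to export them to ONNX with dynamic (uint8 weight) quantization via ONNX Runtime.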

In [None]:
#| default_exp compression.onnx

In [None]:
#| export
import torch
import torch.nn as nn
from onnxruntime.quantization import quantize_dynamic, QuantType

In [None]:
#| export
def script_model(model, dummy_input, path='scripted_model.pt'):
    """Trace ``model`` with TorchScript, save it to ``path`` and return it.

    Despite its name, this uses ``torch.jit.trace`` rather than
    ``torch.jit.script``. Only the operations executed for ``dummy_input`` are
    recorded, so any data-dependent control flow is frozen to the branch that
    input takes.

    Args:
        model: module (or callable) to trace.
        dummy_input: example input fed to ``model`` while tracing; a single
            tensor or a tuple of tensors for multi-argument forwards.
        path: file the serialized TorchScript module is written to.

    Returns:
        The traced TorchScript module.
    """
    traced = torch.jit.trace(model, dummy_input)
    torch.jit.save(traced, path)
    return traced

In [None]:
#| export
def quantize_onnx(model, dummy_input, onnx_path="model.onnx", quant_onnx_path="model_quantized.onnx",
                  input_names=("features", "year", "month", "day", "hour"),
                  output_names=("output",), opset_version=11):
    """Export ``model`` to ONNX, then write a dynamically quantized copy.

    Args:
        model: PyTorch module to export.
        dummy_input: example input used to trace the export; a single tensor
            or a tuple of tensors passed to the model as positional arguments.
        onnx_path: where the float ONNX model is written.
        quant_onnx_path: where the quantized ONNX model is written.
        input_names: ONNX names for the graph inputs, in the same order as
            ``dummy_input``.
        output_names: ONNX names for the graph outputs.
        opset_version: ONNX opset version to target.

    Returns:
        ``quant_onnx_path``, the path of the quantized model.
    """
    import warnings

    input_names = list(input_names)
    output_names = list(output_names)

    # A tuple is unpacked into separate positional inputs; anything else is a
    # single input.
    n_inputs = len(dummy_input) if isinstance(dummy_input, tuple) else 1
    if len(input_names) != n_inputs:
        warnings.warn(
            f"quantize_onnx: got {len(input_names)} input_names {input_names} "
            f"for {n_inputs} dummy input(s); ONNX input names may not line up "
            "with the model's forward arguments.",
            stacklevel=2,
        )

    # Make dim 0 (batch) dynamic on every real input and every output.
    # dynamic_axes keys must be names that also appear in input_names /
    # output_names, otherwise the exporter ignores them. Names beyond the
    # number of actual inputs are skipped so no axis is marked on something
    # that is not a user input.
    dynamic_axes = {
        name: {0: "batch_size"} for name in input_names[:n_inputs] + output_names
    }

    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        opset_version=opset_version,
    )

    # ONNX Runtime dynamic quantization with unsigned 8-bit weights.
    quantize_dynamic(
        onnx_path,
        quant_onnx_path,
        weight_type=QuantType.QUInt8,
    )
    return quant_onnx_path

In [None]:
from TRAIL24.models.nn import *

In [None]:
# Example usage:
input_size = 40  # Length of input time series
output_size = 1  # Length of output time series (forecast)
num_blocks = 12
num_hidden = 512
num_layers = 8
embedding_dim = 10
# NOTE(review): the config below was built with final_hidden=256 although this
# variable used to say 512 (and was unused); 256 is kept — confirm intent.
final_hidden = 256

# N-BEATS backbone hyperparameters, forwarded as model_cfg['model_params'].
nbeats_params = {
    'input_size': input_size,
    'output_size': output_size,
    'num_blocks': num_blocks,
    'num_hidden': num_hidden,
    'num_layers': num_layers,
}

# Keyword arguments for create_model; reference the variables above so the
# two can't silently drift apart.
model_cfg = {
    'model_type': 'nbeats',
    'model_params': nbeats_params,
    'embedding_dim': embedding_dim,
    'final_hidden': final_hidden,
}

In [None]:
# Build the network from model_cfg. create_model is presumably provided by the
# `from TRAIL24.models.nn import *` cell — confirm its signature there.
net = create_model(**model_cfg)

In [None]:
batch_size = 5
num_features = 40  # presumably must equal input_size from the config cell — confirm

features = torch.randn(batch_size, num_features)
# torch.randint's upper bound is exclusive, so the calendar fields below are
# zero-based (presumably indices for the model's embeddings — confirm).
month = torch.randint(0, 12, (batch_size, 1))      # Random month indices in [0, 11]
day = torch.randint(0, 31, (batch_size, 1))        # Random day indices in [0, 30]
hour = torch.randint(0, 24, (batch_size, 1))       # Random hours in [0, 23]

# NOTE(review): no `year` tensor is built here, while quantize_onnx's default
# input_names include "year" — confirm against the model's forward signature.
example_input = features, month, day, hour

In [None]:
# Trace `net` on the example inputs and save it to script_model's default path
# ('scripted_model.pt'). The printed torch.Size lines below presumably come from
# the model's forward pass while tracing, not from script_model itself.
scripted_model = script_model(net, example_input)

torch.Size([5, 40])
torch.Size([5, 40])
torch.Size([5, 40])


In [None]:
# Export `net` to ONNX and write the dynamically quantized model, using
# quantize_onnx's default paths ("model.onnx" / "model_quantized.onnx").
# NOTE(review): example_input holds 4 tensors (features, month, day, hour) but
# quantize_onnx's default input_names list 5 ("features", "year", "month",
# "day", "hour"), so ONNX input names may not line up with the real inputs —
# confirm against the model's forward signature.
quantize_onnx(net, example_input)

torch.Size([5, 40])


