# Pose Estimation on Embedded GPUs

In [None]:
from BronchoTrack.BronchoTrack.models.offsetnet import OffsetNet
from torch import onnx as tonnx
import torch
import tensorrt as trt
from gpu.utils import to_GiB, return_pruning_params, DummyDataset, Calibrator
from pytorch_lightning.callbacks import ModelPruning
from torch.nn.utils.prune import is_pruned
import numpy as np

In [3]:
# Instantiate OffsetNet with its default constructor arguments. This instance is
# exported to ONNX below and is also the module handed to the pruning cells.
# NOTE(review): no checkpoint / load_state_dict is visible in these cells, so the
# exported and pruned weights are whatever OffsetNet() initialises — confirm.
model = OffsetNet()



In [4]:
# Trace `model` with a single random example input and write the graph to
# broncho.onnx (opset 16), which the TensorRT cells below parse.
# The example input is rank 5: (1, 2, 3, 256, 256).
tonnx.export(
    model,
    torch.randn(1, 2, 3, 256, 256),
    "broncho.onnx",
    verbose=True,
    opset_version=16,
)

In [None]:
# TensorRT objects shared by both the FP32 and the INT8 engine builds below.
TRT_LOGGER = trt.Logger(trt.Logger.INFO)

# Network-creation flag bitmask with only EXPLICIT_BATCH set.
EXPLICIT_BATCH = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

builder = trt.Builder(TRT_LOGGER)
network = builder.create_network(flags=EXPLICIT_BATCH)
config = builder.create_builder_config()
# The ONNX parser populates `network` in place when parse() is called.
parser = trt.OnnxParser(network, TRT_LOGGER)

In [None]:
# Parse the exported ONNX graph into the TensorRT `network` definition.
# The file handle is deliberately NOT called `model`: binding it to `model`
# would replace the OffsetNet instance with a closed file object, breaking the
# pruning cells further down that still use `model`.
with open("broncho.onnx", "rb") as onnx_file:
    ok = parser.parse(onnx_file.read())

# parse() returns False on failure; continuing would build an empty/partial
# network and fail much later with an unrelated error, so surface the parser
# diagnostics right here.
if not ok:
    parse_errors = "\n".join(
        str(parser.get_error(i)) for i in range(parser.num_errors)
    )
    raise RuntimeError(f"Failed to parse broncho.onnx:\n{parse_errors}")

# Builder scratch-memory limit; to_GiB (gpu.utils) presumably converts GiB to
# bytes — confirm.
config.max_workspace_size = to_GiB(1)

In [None]:
# Build the engine from the parsed network and current config, then write the
# serialized plan to broncho.trt.
plan = builder.build_serialized_network(network, config)
# build_serialized_network returns None on failure; without this check the
# write below would raise an unrelated TypeError.
if plan is None:
    raise RuntimeError(
        "TensorRT engine build failed (broncho.trt); see the TRT logger output above."
    )
with open("broncho.trt", "wb") as f:
    f.write(plan)

## INT8 Quantization

In [None]:
# Reconfigure the builder config / network for reduced precision via a helper
# that returns the updated (config, network) pair. It presumably sets the
# FP16/INT8 builder flags and input settings — see its definition.
# NOTE(review): `configure_quantization_and_inputs` is not imported in any
# visible cell (the gpu.utils import at the top does not list it) — confirm
# where it is defined before running this cell.
config, network = configure_quantization_and_inputs(
    config,
    network,
    fp16=True,
    int8=True,
)

In [None]:
# Calibration samples for the INT8 build; consumed by Calibrator in the next cell.
# NOTE(review): INT8 dynamic ranges are derived from these samples. If
# DummyDataset yields synthetic data, the engine's INT8 accuracy will not
# reflect real inputs — confirm before deploying.
int8_calib_set = DummyDataset()

In [None]:
# Attach the project Calibrator (gpu.utils), fed by int8_calib_set, to the
# builder config. The second argument (1) is presumably the calibration batch
# size — confirm against Calibrator's signature.
config.int8_calibrator = Calibrator(int8_calib_set, 1)

In [None]:
# Build the reduced-precision engine with the quantization config and attached
# calibrator, then write the serialized plan to broncho_int8.trt.
plan = builder.build_serialized_network(network, config)
# build_serialized_network returns None on failure (e.g. calibration errors);
# fail loudly instead of letting f.write(None) raise a misleading TypeError.
if plan is None:
    raise RuntimeError(
        "TensorRT INT8 engine build failed (broncho_int8.trt); "
        "see the TRT logger output above."
    )
with open("broncho_int8.trt", "wb") as f:
    f.write(plan)

## Pruning

In [None]:
# Configure PyTorch Lightning's ModelPruning callback:
# - structured Ln pruning with the L1 norm (pruning_norm=1) along dim 0 of each
#   selected `weight` (output channels for Conv/Linear weights);
# - 30 % per parameter, since use_global_unstructured=False keeps the amount per layer.
# The parameters to prune come from the project helper return_pruning_params.
pruner = ModelPruning(
    pruning_fn="ln_structured",
    parameters_to_prune=return_pruning_params(model),
    amount=0.3,
    use_global_unstructured=False,
    pruning_norm=1,
    pruning_dim=0,
    parameter_names=["weight"],
    use_lottery_ticket_hypothesis=False,
    prune_on_train_epoch_end=True,
    make_pruning_permanent=True,
    verbose=1,
)
# Pruning is applied once, in the next cell. Applying it here as well stacks a
# second pass: PyTorch computes the new amount relative to the still-unpruned
# slices, so two 0.3 passes prune roughly 51 % instead of the configured 30 %.

In [None]:
# Apply 30 % pruning to the configured parameters. torch.nn.utils.prune keeps
# masks plus forward pre-hooks, so is_pruned(model) should report True.
pruner.apply_pruning(0.3)
print(
    "Pruning has been applied as pre-hooks. The network appear as pruned -> Pruned?",
    is_pruned(model),
)

# Bake the masks into the weights and drop the pruning hooks. is_pruned looks
# for those hooks, so it now reports False even though weights are zeroed.
pruner.make_pruning_permanent(model)
print(
    "Now prune hooks are deleted, then the network appears as unpruned -> Pruned?",
    is_pruned(model),
)