In [2]:
import sys
import logging
import os
from pathlib import Path
from pprint import pprint as pp

import torch
import torchvision
import torchvision.transforms as transforms
import torch_tensorrt

machop_path = Path(".").resolve().parent.parent /"machop"
assert machop_path.exists(), "Failed to find machop at: {}".format(machop_path)
sys.path.append(str(machop_path))

from chop.dataset import MaseDataModule, get_dataset_info
from chop.tools.logger import set_logging_verbosity

from chop.passes.graph import (
    save_node_meta_param_interface_pass,
    report_node_meta_param_analysis_pass,
    profile_statistics_analysis_pass,
    add_common_metadata_analysis_pass,
    init_metadata_analysis_pass,
    add_software_metadata_analysis_pass,
)
from chop.tools.get_input import InputGenerator
from chop.tools.checkpoint_load import load_model
from chop.ir import MaseGraph

from chop.models import get_model_info, get_model

set_logging_verbosity("info")

# load dataset
batch_size = 8
model_name = "vgg7"
dataset_name = "cifar10"


data_module = MaseDataModule(
    name=dataset_name,
    batch_size=batch_size,
    model_name=model_name,
    num_workers=0,
)
data_module.prepare_data()
data_module.setup()

# load model
CHECKPOINT_PATH = "/home/qizhu/Desktop/Work/mase/mase_output/vgg7_classification_cifar10_2024-02-21/software/training_ckpts/best.ckpt"
model_info = get_model_info(model_name)
model = get_model(
    model_name,
    task="cls",
    dataset_info=data_module.dataset_info,
    pretrained=False)

model = load_model(load_name=CHECKPOINT_PATH, load_type="pl", model=model)
model.eval()

# create input generator
input_generator = InputGenerator(
    data_module=data_module,
    model_info=model_info,
    task="cls",
    which_dataloader="train",
)

# a demonstration of how to feed an input value to the model
dummy_in = next(iter(input_generator))
_ = model(**dummy_in)

# generate the mase graph and initialize node metadata
mg = MaseGraph(model=model)
mg, _ = init_metadata_analysis_pass(mg, None)
mg, _ = add_common_metadata_analysis_pass(mg, {"dummy_in": dummy_in})
mg, _ = add_software_metadata_analysis_pass(mg, None)

  from .autonotebook import tqdm as notebook_tqdm
INFO:datasets:PyTorch version 2.0.1+cu118 available.
[32mINFO    [0m [34mSet logging level to info[0m
INFO:chop:Set logging level to info


Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


[32mINFO    [0m [34mLoaded pytorch lightning checkpoint from /home/qizhu/Desktop/Work/mase/mase_output/vgg7_classification_cifar10_2024-02-21/software/training_ckpts/best.ckpt[0m
INFO:chop.tools.checkpoint_load:Loaded pytorch lightning checkpoint from /home/qizhu/Desktop/Work/mase/mase_output/vgg7_classification_cifar10_2024-02-21/software/training_ckpts/best.ckpt


In [2]:
dummy_in['x'].size()

torch.Size([8, 3, 32, 32])

In [3]:
testing_dataloader = data_module.test_dataloader()

calibrator = torch_tensorrt.ptq.DataLoaderCalibrator(
    testing_dataloader,
    cache_file="./calibration.cache",
    use_cache=False,
    algo_type=torch_tensorrt.ptq.CalibrationAlgo.ENTROPY_CALIBRATION_2,
    device=torch.device("cuda:0"),
)

trt_mod = torch_tensorrt.compile(model, inputs=[torch_tensorrt.Input((8, 3, 32, 32))],
                                    enabled_precisions={torch.float, torch.half, torch.int8},
                                    calibrator=calibrator,
                                    device={
                                         "device_type": torch_tensorrt.DeviceType.GPU,
                                         "gpu_id": 0,
                                         "dla_core": 0,
                                         "allow_gpu_fallback": False,
                                         "disable_tf32": False
                                     })

# optimized_model = torch.compile(model, backend="torch_tensorrt")


# trt_mod = torch_tensorrt.compile(model, inputs = [torch_tensorrt.Input((1, 16), dtype=torch.float32)],
#     enabled_precisions = {torch.half}, # Run with FP32
#     workspace_size = 1 << 22
# )




In [82]:
inputs = [
    torch_tensorrt.Input(
        min_shape=[1, 16],
        opt_shape=[50, 16],
        max_shape=[100, 16],
        dtype=torch.half,
    )
]
enabled_precisions = {torch.float, torch.half}  # Run with fp16

trt_ts_module = torch_tensorrt.compile(
    model, inputs=inputs, enabled_precisions=enabled_precisions
)

input_data = torch.randn(100, 16).to('cuda')
input_data = input_data.to("cuda").half()
result = trt_ts_module(input_data)
torch.jit.save(trt_ts_module, "trt_ts_module.ts")



In [None]:
half_model = model.half()
half_model(input_data)

In [10]:
import time

inputs = torch.randn(8, 3, 32, 32).to('cuda')
timings = []

start_time = time.time()
features = model.to('cuda')(inputs)
torch.cuda.synchronize()
end_time = time.time()
timings.append(end_time - start_time)


start_time = time.time()
features = trt_mod(inputs)
torch.cuda.synchronize()
end_time = time.time()
timings.append(end_time - start_time)

timings


[0.002302885055541992, 0.0005021095275878906]

In [12]:
mg.model = trt_mod

In [3]:
mg.fx_graph

AttributeError: 'MaseGraph' object has no attribute 'onnx_model'

In [19]:
mg.fx_graph

'torch.fx.graph'