# Export Models to TorchScript

In [None]:
import sys

sys.path.append("/home/martin/Dev/homography_imitation_learning")

import torch
import os
import importlib
from typing import List
from utils import load_yaml, generate_path


def export_pipeline(
    devices: List[str],
    name: str,
    package: str,
    prefix: str,
    checkpoint_prefix: str,
    checkpoint: str,
    output_path: str,
    output_name: str,
    example_inputs: List[torch.Tensor],
) -> None:
    """Restore a model from a checkpoint and save a frozen TorchScript trace per device.

    The class ``name`` is looked up in module ``package`` and restored through its
    ``load_from_checkpoint`` from ``<prefix>/<checkpoint_prefix>/<checkpoint>``.
    The constructor keyword arguments come from the ``model`` section of
    ``<prefix>/<checkpoint_prefix>/config.yml``.

    For every entry in ``devices`` the model is moved to that device, traced with
    ``example_inputs``, passed through ``torch.jit.freeze`` and written to
    ``<output_path>/<output_name>_<device>.pt``.

    Args:
        devices: Torch device strings to export for, e.g. ``"cpu"``, ``"cuda"``.
        name: Class name of the module to restore, defined in ``package``.
        package: Importable module that defines ``name``.
        prefix: Root directory of the training logs.
        checkpoint_prefix: Run sub-directory below ``prefix``; holds ``config.yml``.
        checkpoint: Checkpoint path relative to ``<prefix>/<checkpoint_prefix>``.
        output_path: Directory for the exported ``.pt`` files; ``generate_path``
            is called on it before anything is loaded.
        output_name: File-name stem; ``_<device>.pt`` is appended per device.
        example_inputs: Tracing inputs. Callers also pass nested tuples of
            tensors here (e.g. a recurrent state).
    """
    generate_path(output_path)

    run_dir = os.path.join(prefix, checkpoint_prefix)
    config = load_yaml(os.path.join(run_dir, "config.yml"))

    # Restore onto the CPU first. A checkpoint written by a GPU run can then
    # still be loaded (and exported for "cpu") on a machine without CUDA.
    # The loop below moves the model to every requested device anyway.
    model = getattr(importlib.import_module(package), name).load_from_checkpoint(
        os.path.join(run_dir, checkpoint), map_location="cpu", **config["model"]
    )

    # freeze() turns off gradients and switches to eval mode. Both settings
    # survive .to(device), so this is done once instead of once per device.
    model.freeze()

    for device in devices:
        model = model.to(device)
        # NOTE(review): Lightning's to_torchscript(method="trace") transfers
        # example_inputs to the model's device itself, so they stay on the CPU here.
        script = model.to_torchscript(method="trace", example_inputs=example_inputs)
        script = torch.jit.freeze(script)
        torch.jit.save(script, os.path.join(output_path, f"{output_name}_{device}.pt"))

## Export Bounding Circle Detection

### Big Model

In [None]:
# Boundary segmentation, large variant (UNet / ResNet-34 according to the log
# path), traced with a single random input of shape [1, 3, 240, 320].
export_pipeline(
    name="ImageSegmentationModule",
    package="lightning_modules",
    prefix="/home/martin/Tresors/homography_imitation_learning_logs/boundary_image_segmentation/unet/resnet/34",
    checkpoint_prefix="version_4",
    checkpoint="checkpoints/epoch=288-step=2311.ckpt",
    example_inputs=[torch.rand(1, 3, 240, 320)],
    output_path="/tmp/models",
    output_name="seg_unet_resnet_34",
    devices=["cpu", "cuda"],
)

### Tiny Model

In [None]:
# Boundary segmentation, tiny variant (logged under .../resnet/34/tiny), traced
# with a single random input of shape [1, 3, 240, 320].
export_pipeline(
    name="ImageSegmentationModule",
    package="lightning_modules",
    prefix="/home/martin/Tresors/homography_imitation_learning_logs/boundary_image_segmentation/unet/resnet/34/tiny",
    checkpoint_prefix="version_5",
    checkpoint="checkpoints/epoch=374-step=1499.ckpt",
    example_inputs=[torch.rand(1, 3, 240, 320)],
    output_path="/tmp/models",
    output_name="seg_unet_resnet_34_tiny",
    devices=["cpu", "cuda"],
)

## Export Deep Homography Estimation

### 48 Pixel Augmentation

In [None]:
# Deep homography estimation (48-pixel augmentation run). The traced forward
# takes two random inputs, each of shape [1, 3, 240, 320].
export_pipeline(
    name="DeepImageHomographyEstimationModuleBackbone",
    package="lightning_modules",
    prefix="/home/martin/Tresors/homography_imitation_learning_logs/ae_cai/resnet/48/25/34",
    checkpoint_prefix="version_0",
    checkpoint="checkpoints/epoch=99-step=47199.ckpt",
    example_inputs=[torch.rand(1, 3, 240, 320) for _ in range(2)],
    output_path="/tmp/models",
    output_name="h_est_48_resnet_34",
    devices=["cpu", "cuda"],
)

### 64 Pixel Augmentation

In [None]:
# Deep homography estimation (64-pixel augmentation run). The traced forward
# takes two random inputs, each of shape [1, 3, 240, 320].
export_pipeline(
    name="DeepImageHomographyEstimationModuleBackbone",
    package="lightning_modules",
    prefix="/home/martin/Tresors/homography_imitation_learning_logs/ae_cai/resnet/64/25/34",
    checkpoint_prefix="version_0",
    checkpoint="checkpoints/epoch=99-step=47199.ckpt",
    example_inputs=[torch.rand(1, 3, 240, 320) for _ in range(2)],
    output_path="/tmp/models",
    output_name="h_est_64_resnet_34",
    devices=["cpu", "cuda"],
)

## Export Homography Imitation

### Incremental Feature LSTM

In [None]:
prefix = "/home/martin/Tresors/homography_imitation_learning_logs/miccai/feature_lstm/phantom/resnet34/pairwise_distance"
checkpoint_prefix = "version_0"
checkpoint = "checkpoints/epoch=25-step=104.ckpt"

# Read the LSTM hidden width from the run config. It sizes the zero recurrent
# state fed to the tracer and is embedded in the exported file name.
config = load_yaml(os.path.join(prefix, checkpoint_prefix, "config.yml"))
hidden_features = config["model"]["lstm"]["kwargs"]["hidden_size"]

export_pipeline(
    name="FeatureLSTMIncrementalModule",
    package="lightning_modules",
    prefix=prefix,
    checkpoint_prefix=checkpoint_prefix,
    checkpoint=checkpoint,
    example_inputs=[
        torch.rand(1, 1, 3, 240, 320),
        *(torch.rand(1, 1, 4, 2) for _ in range(2)),
        # Presumably the LSTM (h_0, c_0) pair. A leading dim of 1 would imply a
        # single-layer, unidirectional LSTM. TODO: confirm against the config.
        tuple(torch.zeros(1, 1, hidden_features) for _ in range(2)),
    ],
    output_path="/tmp/models",
    output_name=f"h_pred_{hidden_features}_feature_lstm_incremental",
    devices=["cpu", "cuda"],
)

### Feature LSTM

In [None]:
prefix = "/home/martin/Tresors/homography_imitation_learning_logs/miccai/feature_lstm/cholec80/resnet34/no_motion_prior/pairwise_distance"
checkpoint_prefix = "version_0"
checkpoint = "checkpoints/epoch=41-step=12264.ckpt"

# Read the LSTM hidden width from the run config. It sizes the zero recurrent
# state fed to the tracer and is embedded in the exported file name.
config = load_yaml(os.path.join(prefix, checkpoint_prefix, "config.yml"))
hidden_features = config["model"]["lstm"]["kwargs"]["hidden_size"]

export_pipeline(
    name="FeatureLSTMModule",
    package="lightning_modules",
    prefix=prefix,
    checkpoint_prefix=checkpoint_prefix,
    checkpoint=checkpoint,
    example_inputs=[
        torch.rand(1, 1, 3, 240, 320),
        # Presumably the LSTM (h_0, c_0) pair. A leading dim of 1 would imply a
        # single-layer, unidirectional LSTM. TODO: confirm against the config.
        tuple(torch.zeros(1, 1, hidden_features) for _ in range(2)),
    ],
    output_path="/tmp/models",
    output_name=f"h_pred_{hidden_features}_feature_lstm",
    devices=["cpu", "cuda"],
)

### Convolutional Homography Predictor (ResNet 34)

In [None]:
prefix = "/home/martin/Tresors/homography_imitation_learning_logs/miccai/conv_homography_predictor/phantom/resnet34"
checkpoint_prefix = "version_0"
checkpoint = "checkpoints/epoch=497-step=3486.ckpt"

# Read the predictor's input channel count from the run config. It sets the
# channel dim of the tracing input and is part of the exported file name.
config = load_yaml(os.path.join(prefix, checkpoint_prefix, "config.yml"))
in_channels = config["model"]["predictor"]["kwargs"]["in_channels"]

export_pipeline(
    name="ConvHomographyPredictorModule",
    package="lightning_modules",
    prefix=prefix,
    checkpoint_prefix=checkpoint_prefix,
    checkpoint=checkpoint,
    example_inputs=[torch.rand(1, in_channels, 240, 320)],
    output_path="/tmp/models",
    output_name=f"h_pred_resnet_34_in_channels_{in_channels}",
    devices=["cpu", "cuda"],
)

### Convolutional Homography Predictor (ResNet 50)

In [None]:
prefix = "/home/martin/Tresors/homography_imitation_learning_logs/miccai/conv_homography_predictor/phantom/resnet50"
checkpoint_prefix = "version_0"
checkpoint = "checkpoints/epoch=492-step=3451.ckpt"

# Read the predictor's input channel count from the run config. It sets the
# channel dim of the tracing input and is part of the exported file name.
config = load_yaml(os.path.join(prefix, checkpoint_prefix, "config.yml"))
in_channels = config["model"]["predictor"]["kwargs"]["in_channels"]

export_pipeline(
    name="ConvHomographyPredictorModule",
    package="lightning_modules",
    prefix=prefix,
    checkpoint_prefix=checkpoint_prefix,
    checkpoint=checkpoint,
    example_inputs=[torch.rand(1, in_channels, 240, 320)],
    output_path="/tmp/models",
    output_name=f"h_pred_resnet_50_in_channels_{in_channels}",
    devices=["cpu", "cuda"],
)