# Load Model

In [None]:
import os
import sys
sys.path.append("../..")
import torch
from lightning_modules.homography_imitation import ConvHomographyPredictorModule
from utils.io import load_yaml

def heichole(resnet: int = 18):
    """Return the log directory and checkpoint file of a trained HeiChole model.

    Args:
        resnet: Depth of the ResNet backbone of the run to load; one of 18, 34, 50.

    Returns:
        Tuple ``(checkpoint_prefix, checkpoint)``. ``checkpoint`` is a path
        relative to ``checkpoint_prefix``.

    Raises:
        ValueError: If no run is registered for ``resnet`` (previously this
            surfaced as an obscure ``UnboundLocalError``).
    """
    # One hand-picked run per backbone depth: (log directory, checkpoint in it).
    runs = {
        18: (
            "/media/martin/Samsung_T5/logs/miccai/final/heichole/resnet18/version_1",
            "checkpoints/epoch=32-step=11715.ckpt",
        ),
        34: (
            "/media/martin/Samsung_T5/logs/miccai/final/heichole/resnet34/version_3",
            "checkpoints/epoch=48-step=17395.ckpt",
        ),
        50: (
            "/media/martin/Samsung_T5/logs/miccai/final/heichole/resnet50/version_1",
            "checkpoints/epoch=36-step=13135.ckpt",
        ),
    }
    try:
        return runs[resnet]
    except KeyError:
        raise ValueError(
            f"No HeiChole checkpoint registered for resnet={resnet!r}; "
            f"expected one of {sorted(runs)}."
        ) from None

# Select the ResNet-34 HeiChole run and read the config stored next to it.
checkpoint_prefix, checkpoint = heichole(34)
config_path = os.path.join(checkpoint_prefix, "config.yml")
config = load_yaml(config_path)

# Run on the GPU whenever torch can see one.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Restore the predictor with the model kwargs recorded in the run config,
# then move it to the device and put it into frozen evaluation mode.
checkpoint_path = os.path.join(checkpoint_prefix, checkpoint)
module = ConvHomographyPredictorModule.load_from_checkpoint(checkpoint_path, **config["model"])
module = module.to(device).eval()
module.freeze()

# Load Dataset

In [None]:
import os
import sys
sys.path.append("../..")

import cv2
from torch.utils.data import DataLoader
from kornia import tensor_to_image
from kornia.geometry import resize
import numpy as np
import pandas as pd
import tqdm

from datasets import ImageSequenceDataset
from utils.io import load_yaml
from utils.viz import create_blend_from_four_point_homography

# Resolve data locations for the chosen server profile from the shared config.
server_name = "local"
servers = load_yaml("../../config/servers.yml")
server = servers[server_name]
database = server["database"]["location"]

def pickle_path(name: str, window: int):
    """Return the dataset sub-directory and motion-label pickle name for a dataset.

    Args:
        name: Dataset identifier; only ``"heichole"`` is currently known.
        window: Motion-label window size, embedded in the pickle file name.

    Returns:
        Tuple ``(prefix, motion_pickle)`` of relative paths: the dataset
        directory and the pickle file name inside it.

    Raises:
        ValueError: For an unknown ``name``. Previously empty strings were
            returned, which only failed later and confusingly, when the
            joined path was read.
    """
    if name == "heichole":
        prefix = "heichole_single_frames_cropped"
        motion_pickle = f"23_03_07_motion_label_window_{window}_frame_increment_5_frames_between_clips_1_log_test_train.pkl"
        return prefix, motion_pickle
    raise ValueError(f"Unknown dataset name {name!r}; expected 'heichole'.")

prefix, motion_pickle = pickle_path("heichole", 1)
pickle_file = os.path.join(database, prefix, motion_pickle)
df = pd.read_pickle(pickle_file)

# Keep only held-out rows, then a single video from them: the eighth distinct
# test video id in order of appearance (also check 23_03_01_dataset_sizes.ipynb).
# `== False` is kept on purpose rather than `~`, in case the column is not bool.
test_df = df[df.train == False]  # noqa: E712
test_vid_idcs = test_df.vid.unique().tolist()
df = test_df.groupby("vid").get_group(test_vid_idcs[7])

# Clip sampling parameters.
seq_len = 15
frame_increment = 5

dataset_kwargs = {
    "df": df,
    "prefix": os.path.join(database, prefix),
    "seq_len": seq_len,
    "frame_increment": frame_increment,
    "frames_between_clips": frame_increment,
}
ds = ImageSequenceDataset(**dataset_kwargs)

# Deterministic, single-sample, in-process loading for visualisation.
dl = DataLoader(ds, batch_size=1, shuffle=False, num_workers=0)

# Number of trailing frames per clip that are not fed to the model.
preview_horizon = 1

# NOTE(review): hard-coded output location; one img_<n>.png is written per batch.
output_dir = "/media/martin/Samsung_T5/23_02_20_miccai_measurements/eval/23_03_08_blends"
# cv2.imwrite does not create directories and returns False instead of raising,
# so without this every image would be silently dropped.
os.makedirs(output_dir, exist_ok=True)

cnt = 0
for batch in tqdm.tqdm(dl):
    imgs, imgs_tf, frame_idcs, vid_idcs = batch
    B, T, C, H, W = imgs.shape
    imgs = imgs.to(device).float() / 255.

    # Model input: all frames before the preview horizon, with the time and
    # channel axes merged into one channel axis.
    recall_horizon_imgs = imgs[:, :-preview_horizon]
    recall_horizon_imgs = recall_horizon_imgs.reshape(B, -1, H, W)

    # inference
    duvs = module(recall_horizon_imgs)

    # visualize: un-merge the frames, then blend the last input frame with the
    # final frame of the clip using the predicted four-point offsets.
    recall_horizon_imgs = recall_horizon_imgs.reshape(B, -1, C, H, W)
    blends = create_blend_from_four_point_homography(recall_horizon_imgs[:, -1], imgs[:, -1], duvs)
    blend = resize(blends[0], [480, 640])  # first sample of the batch
    # Clip before the uint8 cast so out-of-range values saturate instead of wrapping.
    blend = (np.clip(tensor_to_image(blend, keepdim=False), 0., 1.) * 255.).astype(np.uint8)

    out_path = os.path.join(output_dir, f"img_{cnt}.png")
    if not cv2.imwrite(out_path, blend):
        raise RuntimeError(f"cv2.imwrite failed to write {out_path}")
    cnt += 1



# Run Homography Prediction and Save Images