# Load Homography Predictor

In [None]:
import os
import sys
sys.path.append("../..")
import torch
from lightning_modules.homography_imitation import ConvHomographyPredictorModule
from utils.io import load_yaml

def heichole(resnet: int = 18):
    """Return the checkpoint location under the ``heichole`` log directory.

    Args:
        resnet: Depth of the ResNet variant; one of 18, 34 or 50.

    Returns:
        A tuple ``(checkpoint_prefix, checkpoint)`` where ``checkpoint_prefix``
        is the absolute directory of the logged run and ``checkpoint`` is the
        checkpoint file path relative to that directory.

    Raises:
        ValueError: If ``resnet`` is not one of the supported depths.
    """
    checkpoints = {
        18: (
            "/media/martin/Samsung_T5/logs/miccai/final/heichole/resnet18/version_1",
            "checkpoints/epoch=32-step=11715.ckpt",
        ),
        34: (
            "/media/martin/Samsung_T5/logs/miccai/final/heichole/resnet34/version_3",
            "checkpoints/epoch=48-step=17395.ckpt",
        ),
        50: (
            "/media/martin/Samsung_T5/logs/miccai/final/heichole/resnet50/version_1",
            "checkpoints/epoch=36-step=13135.ckpt",
        ),
    }
    # An unknown depth previously fell through every branch and failed with an
    # UnboundLocalError at the return; report the actual problem instead.
    if resnet not in checkpoints:
        raise ValueError(
            f"Unsupported resnet depth {resnet!r}; expected one of {sorted(checkpoints)}."
        )
    return checkpoints[resnet]

def cholec80(resnet: int = 18):
    """Return the checkpoint location under the ``cholec80`` log directory.

    Args:
        resnet: Depth of the ResNet variant; one of 18, 34 or 50.

    Returns:
        A tuple ``(checkpoint_prefix, checkpoint)`` where ``checkpoint_prefix``
        is the absolute directory of the logged run and ``checkpoint`` is the
        checkpoint file path relative to that directory.

    Raises:
        ValueError: If ``resnet`` is not one of the supported depths.
    """
    checkpoints = {
        18: (
            "/media/martin/Samsung_T5/logs/miccai/final/cholec80/resnet18/version_2",
            "checkpoints/epoch=39-step=78400.ckpt",
        ),
        34: (
            "/media/martin/Samsung_T5/logs/miccai/final/cholec80/resnet34/version_2",
            "checkpoints/epoch=40-step=80360.ckpt",
        ),
        50: (
            "/media/martin/Samsung_T5/logs/miccai/final/cholec80/resnet50/version_2",
            "checkpoints/epoch=52-step=103880.ckpt",
        ),
    }
    # An unknown depth previously fell through every branch and failed with an
    # UnboundLocalError at the return; report the actual problem instead.
    if resnet not in checkpoints:
        raise ValueError(
            f"Unsupported resnet depth {resnet!r}; expected one of {sorted(checkpoints)}."
        )
    return checkpoints[resnet]


# Pick the cholec80 / ResNet-50 run and read the config stored next to it.
checkpoint_prefix_predictor, checkpoint_predictor = cholec80(50)
predictor_config = load_yaml(
    os.path.join(checkpoint_prefix_predictor, "config.yml")
)

# Use the GPU when CUDA is available, otherwise stay on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Restore the predictor using the "model" section of the config as keyword
# arguments, then move it to the device and put it into frozen eval mode.
predictor_checkpoint_path = os.path.join(
    checkpoint_prefix_predictor, checkpoint_predictor
)
predictor = ConvHomographyPredictorModule.load_from_checkpoint(
    predictor_checkpoint_path,
    **predictor_config["model"],
)
predictor.to(device)
predictor = predictor.eval()
predictor.freeze()

# Load Homography Estimator and Taylor

In [None]:
from utils.processing import TaylorHomographyPrediction
from lightning_modules.homography_regression import DeepImageHomographyEstimationModuleBackbone

def ae_cai():
    """Return ``(checkpoint_prefix, checkpoint)`` for the ``ae_cai`` run.

    ``checkpoint_prefix`` is the absolute directory of the logged run and
    ``checkpoint`` is the checkpoint file path relative to that directory.
    """
    return (
        "/media/martin/Samsung_T5/logs/ae_cai/resnet/48/25/34/version_0",
        "checkpoints/epoch=99-step=47199.ckpt",
    )


# Locate the estimator run and read the config stored next to it.
checkpoint_prefix_estimator, checkpoint_estimator = ae_cai()
estimator_config = load_yaml(
    os.path.join(checkpoint_prefix_estimator, "config.yml")
)

# Restore the estimator using the "model" section of the config as keyword
# arguments, then move it to the device and put it into frozen eval mode.
estimator_checkpoint_path = os.path.join(
    checkpoint_prefix_estimator, checkpoint_estimator
)
estimator = DeepImageHomographyEstimationModuleBackbone.load_from_checkpoint(
    estimator_checkpoint_path,
    **estimator_config["model"],
)
estimator.to(device)
estimator = estimator.eval()
estimator.freeze()

# First-order Taylor predictor, fed with the estimator's output further below.
taylor_predictor = TaylorHomographyPrediction(order=1)

# Load Dataset

In [None]:
import os
import sys
sys.path.append("../..")

import cv2
from torch.utils.data import DataLoader
from kornia import tensor_to_image
from kornia.geometry import resize
import numpy as np
import pandas as pd
import tqdm

from datasets import ImageSequenceDataset
from utils.processing import frame_pairs
from utils.io import load_yaml
from utils.viz import create_blend_from_four_point_homography

# Look up the "local" entry of the server config and take the database
# location from it.
server = load_yaml("../../config/servers.yml")["local"]
database = server["database"]["location"]

def pickle_path(name: str, window: int):
    """Return the folder prefix and motion-label pickle name for a dataset.

    Args:
        name: Dataset name; only ``"heichole"`` is recognised.
        window: Window value inserted into the pickle file name.

    Returns:
        A tuple ``(prefix, motion_pickle)``. Both entries are empty strings
        when ``name`` is not recognised.
    """
    if name != "heichole":
        return "", ""
    return (
        "heichole_single_frames_cropped",
        f"23_03_07_motion_label_window_{window}_frame_increment_5_frames_between_clips_1_log_test_train.pkl",
    )

# Read the motion-label table for heichole (window 1) from the database folder.
prefix, motion_pickle = pickle_path("heichole", 1)
df = pd.read_pickle(os.path.join(database, prefix, motion_pickle))

# Keep a single test video (see also 23_03_01_dataset_sizes.ipynb): take the
# third id among the non-training videos and slice [20000:30000] of its rows.
is_test = df.train == False
test_vid_idcs = df[is_test].vid.unique().tolist()
df = df[is_test].groupby("vid").get_group(test_vid_idcs[2])[20000:30000]
# print(df)

seq_len = 15
frame_increment = 5

# Clips are spaced by the same increment that is used between frames.
ds = ImageSequenceDataset(
    df=df,
    prefix=os.path.join(database, prefix),
    seq_len=seq_len,
    frame_increment=frame_increment,
    frames_between_clips=frame_increment,
)

# One clip per batch, in dataset order, loaded in the main process.
dl = DataLoader(ds, batch_size=1, shuffle=False, num_workers=0)

preview_horizon = 1

# cv2.imwrite does not raise when the target directory is missing; it only
# returns False. Create the output folders up front and check every write so
# a long run cannot silently produce no images.
blend_dir = "/media/martin/Samsung_T5/23_02_20_miccai_measurements/eval/23_03_08_blends"
deep_blend_dir = os.path.join(blend_dir, "deep")
taylor_blend_dir = os.path.join(blend_dir, "taylor")
for out_dir in (deep_blend_dir, taylor_blend_dir):
    os.makedirs(out_dir, exist_ok=True)


def _save_blend(blends, path):
    """Write the first blend of a batch to ``path`` as a 480x640 uint8 image.

    Raises:
        IOError: If OpenCV reports that the image could not be written.
    """
    blend = resize(blends[0], [480, 640])  # index 0: first element of the batch
    blend = (tensor_to_image(blend, keepdim=False) * 255.).astype(np.uint8)
    if not cv2.imwrite(path, blend):
        raise IOError(f"Could not write blend image to {path}")


cnt = 0
for batch in tqdm.tqdm(dl):
    imgs, imgs_tf, frame_idcs, vid_idcs = batch
    B, T, C, H, W = imgs.shape
    imgs = imgs.to(device).float() / 255.

    # Drop the last preview_horizon frames and stack the remaining frames
    # along the channel axis for the predictor.
    recall_horizon_imgs = imgs[:, :-preview_horizon].reshape(B, -1, H, W)

    # predict camera motion
    with torch.no_grad():
        duvs_predicted = predictor(recall_horizon_imgs)

    # estimate camera motion between consecutive recall frames for taylor
    imgs_i, imgs_ip1 = frame_pairs(recall_horizon_imgs.view(B, -1, C, H, W))
    with torch.no_grad():
        duvs_estimated = estimator(imgs_i.view(-1, C, H, W), imgs_ip1.view(-1, C, H, W))
        duvs_estimated = duvs_estimated.view(B, T-1-preview_horizon, 4, 2)

    # taylor predict camera motion (runs on CPU), keep the last
    # preview_horizon entries along the second axis
    duvs_taylor_predicted = taylor_predictor(duvs_estimated.cpu())
    duvs_taylor_predicted = duvs_taylor_predicted[:, -preview_horizon:].to(device)

    # Blend the last recall frame with the last frame of the clip.
    recall_horizon_imgs = recall_horizon_imgs.reshape(B, -1, C, H, W)
    last_recall_img = recall_horizon_imgs[:, -1]
    last_img = imgs[:, -1]

    # visualize predictor
    blends = create_blend_from_four_point_homography(last_recall_img, last_img, duvs_predicted)
    _save_blend(blends, os.path.join(deep_blend_dir, f"img_{cnt}.png"))

    # visualize taylor 1st order
    # NOTE(review): indexing [0] takes the first batch element; this presumably
    # relies on batch_size=1 and preview_horizon=1 -- confirm before changing either.
    blends = create_blend_from_four_point_homography(last_recall_img, last_img, duvs_taylor_predicted[0])
    _save_blend(blends, os.path.join(taylor_blend_dir, f"img_{cnt}.png"))

    # # visualize identity (disabled; create the "identity" folder before enabling)
    # blends = create_blend_from_four_point_homography(last_recall_img, last_img, torch.zeros_like(duvs_predicted))
    # _save_blend(blends, os.path.join(blend_dir, "identity", f"img_{cnt}.png"))

    cnt += 1
    # break



# Run Homography Prediction and Save Images