In [None]:
import os
import torch as tr
import numpy as np
import matplotlib.pyplot as plt
from omegaconf import OmegaConf
from pathlib import Path

from vre.representations import build_representations_from_cfg
from vre.utils import get_project_root, vre_yaml_load, collage_fn, image_write
from vre_video import VREVideo
from vre_repository import get_vre_repository
from vre_repository.optical_flow.raft import FlowRaft

device = "cuda" if tr.cuda.is_available() else "cpu"


%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
class Ctx:
    """Context manager that sets up a VRE representation on entry and frees it on exit.

    Usage::

        with Ctx(representation) as r:
            out = r.compute(...)

    On entry, ``vre_setup()`` is called only when ``setup_called is False``. On exit, ``vre_free()``
    is ALWAYS called, including when the representation had already been set up before entering
    and when the body raised. Exceptions from the body are never suppressed.
    """

    def __init__(self, _repr):
        self.repr = _repr

    def __enter__(self):
        # Explicit `is False` check (not `not ...`) keeps the original semantics for non-bool values.
        if self.repr.setup_called is False:
            self.repr.vre_setup()
        # Return the wrapped representation so `with Ctx(r) as x:` binds something useful.
        return self.repr

    def __exit__(self, exc_type, exc_value, traceback):
        self.repr.vre_free()
        # Falsy return value: any exception raised inside the `with` body propagates to the caller.
        return False

In [None]:
# Open the test clip shipped with the project and report its shape and frame rate.
video = VREVideo(get_project_root() / "resources/test_video.mp4")
print(video.shape, video.fps)
# Frame height/width (assumes video.shape is (frames, height, width, ...) — confirm); later cells resize outputs to (h, w).
h, w = video.shape[1], video.shape[2]

In [None]:
# Export the device chosen in the first cell to VRE (presumably read when representations are built — confirm).
# Reusing `device` instead of recomputing it keeps a single source of truth for the device choice.
os.environ["VRE_DEVICE"] = device
# cfg.yaml is resolved relative to the notebook's current working directory.
all_representations_dict = vre_yaml_load(Path.cwd() / "cfg.yaml")
representations = build_representations_from_cfg(all_representations_dict, representation_types=get_vre_repository())
# Index representations by their configured name for the lookups in the cells below.
name_to_repr = {r.name: r for r in representations}
print(representations)

## Run the depth and normals representations on randomly sampled frames

In [None]:
# Inference setup (VRE's main loop in run() does the same thing). The depth representations are set up and
# freed via Ctx; each normals_svd representation is then computed from its depth output passed as a dependency.
depth, normals = name_to_repr["depth_marigold"], name_to_repr["normals_svd(depth_marigold)"]
depth2, normals2 = name_to_repr["depth_dpt"], name_to_repr["normals_svd(depth_dpt)"]

np.random.seed(43)
mb = 1  # number of random frames to visualise
# NOTE(review): randint's upper bound is exclusive, so the last frame (len(video) - 1) is never sampled — confirm intended.
ixs = sorted([np.random.randint(0, len(video) - 1) for _ in range(mb)])
print(ixs)

# Marigold and DPT results are kept under distinct names so the second model no longer overwrites the first.
with Ctx(depth):
    out_depth = depth.resize(depth.compute(video, ixs), (h, w))
    y_depth_img = depth.make_images(out_depth)
out_normals = normals.compute(video, ixs, [out_depth])
y_normals_img = normals.make_images(out_normals)

with Ctx(depth2):
    out_depth2 = depth2.resize(depth2.compute(video, ixs), (h, w))
    y_depth2_img = depth2.make_images(out_depth2)
out_normals2 = normals2.compute(video, ixs, [out_depth2])
y_normals2_img = normals2.make_images(out_normals2)

# 2x3 grid: row 1 = RGB / Marigold depth / Marigold normals, row 2 = blank / DPT depth / DPT normals.
titles = ["RGB", "Depth Marigold", "Normals SVD (Depth Marigold)", "", "Depth DPT", "Normals SVD (Depth DPT)"]
for i in range(mb):
    # The RGB tile is taken from the DPT depth output's frames, as before; both depth outputs were computed
    # on the same (video, ixs) and resized to (h, w).
    blank = np.zeros_like(y_depth_img[i])  # placeholder tile keeping the grid aligned
    imgs = [out_depth2.frames[i], y_depth_img[i], y_normals_img[i], blank, y_depth2_img[i], y_normals2_img[i]]
    img = collage_fn(imgs, titles=titles, rows_cols=(2, 3))
    plt.imshow(img)
    plt.show()


### Optical flow, forward and backward (delta = +5 / -5)

In [None]:
# Forward (+FLOW_DELTA) and backward (-FLOW_DELTA) RAFT flow, inferred at the video's native resolution (h, w).
FLOW_DELTA = 5  # value passed as FlowRaft's `delta` (presumably a frame offset — confirm)
print(h, w)


def make_flow_raft(delta: int) -> FlowRaft:
    """Build a RAFT flow representation at (h, w) with this notebook's settings for the given `delta`.

    Each call creates its own `dependencies` list, so the two instances never share mutable state.
    """
    return FlowRaft(name="flow_raft", dependencies=[], inference_width=w, inference_height=h, iters=5,
                    small=False, delta=delta)


flow = make_flow_raft(FLOW_DELTA)
flow_l = make_flow_raft(-FLOW_DELTA)
flow.device = flow_l.device = device

# Seed here so this cell samples the same frames whether or not the depth cell ran before it.
np.random.seed(43)
mb = 2
# Keep ix ± FLOW_DELTA inside [0, len(video)) so neither direction needs a frame outside the video
# (a negative index could otherwise wrap around). NOTE(review): assumes FlowRaft reads frame ix + delta —
# confirm whether it already clamps out-of-range neighbours.
assert len(video) > 2 * FLOW_DELTA, "video is too short for the requested flow delta"
ixs = sorted(np.random.randint(FLOW_DELTA, len(video) - FLOW_DELTA, size=mb).tolist())
print(ixs)

with Ctx(flow):
    y_flow = flow.compute(video, ixs)
with Ctx(flow_l):
    y_flow_l = flow_l.compute(video, ixs)

# Per-direction summary of the flow vectors (last axis assumed to hold the 2 flow components — confirm).
# NOTE(review): the mean is rescaled by [h, w] but the std is not; if channel 0 is horizontal flow the scale
# should probably be [w, h]. Confirm FlowRaft's output normalisation before relying on these numbers.
for sign, y in (("+", y_flow), ("-", y_flow_l)):
    vectors = y.output.reshape(-1, 2)
    print(f"delta={sign}{FLOW_DELTA}:", vectors.mean(0) * [h, w], vectors.std(0))

flow_img = flow.make_images(y_flow)
flow_l_img = flow_l.make_images(y_flow_l)
for i, ix in enumerate(ixs):
    fig, axes = plt.subplots(1, 3, figsize=(20, 10))
    panels = [
        (video[ix], f"RGB frame {ix}"),
        (flow_img[i], f"Flow (delta=+{FLOW_DELTA})"),
        (flow_l_img[i], f"Flow (delta=-{FLOW_DELTA})"),
    ]
    for ax, (image, title) in zip(axes, panels):
        ax.imshow(image)
        ax.set_title(title)
        ax.axis("off")
    # Show (and release) each figure per frame instead of accumulating all of them until the end.
    plt.show()
