In [None]:
import os
os.environ["VRE_COLORIZE_SEMSEG_FAST"] = "1"
import torch as tr
import numpy as np
import matplotlib.pyplot as plt
from omegaconf import OmegaConf
from pathlib import Path
from torch.nn import functional as F

from vre.utils import get_project_root, lo, colorize_semantic_segmentation, image_resize_batch
from vre import FFmpegVideo
from vre.representations import Representation, IORepresentationMixin
from vre_repository.optical_flow.raft import FlowRaft
from vre_repository.semantic_segmentation.safeuav import SafeUAV

# Device used by warp_image and assigned to every model in the setup cell. Hard-coded to CPU;
# the trailing comment keeps the CUDA-if-available expression for switching back.
device = "cpu"#"cuda" if tr.cuda.is_available() else "cpu"

%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
def warp_torch(image: tr.Tensor, flow: tr.Tensor) -> tr.Tensor:
    """Backward-warp ``image`` by ``flow`` with bilinear ``grid_sample``.

    Parameters
    ----------
    image : (B, C, H, W) floating-point tensor to sample from.
    flow : (B, H, W, 2) tensor (or anything broadcastable to it) of (x, y) displacements expressed in
        grid_sample's normalized coordinates (-1 / +1 are the centres of the corner pixels).
        It is subtracted from the identity grid, i.e. output[p] = image[p - flow[p]].

    Returns
    -------
    (B, C, H, W) tensor with the dtype and device of ``image``. Locations sampled outside the image are
    filled with zeros (grid_sample's default ``padding_mode="zeros"``).
    """
    H, W = image.shape[-2:]
    # Identity grid in normalized [-1, 1] coordinates. With align_corners=True the endpoints land exactly
    # on the corner pixel centres, so a zero flow reproduces the input.
    xs = tr.linspace(-1, 1, W, device=image.device)
    ys = tr.linspace(-1, 1, H, device=image.device)
    grid_x, grid_y = tr.meshgrid(xs, ys, indexing="xy")  # each (H, W)
    identity = tr.stack((grid_x, grid_y), dim=-1)  # (H, W, 2), last dim is (x, y)
    # grid_sample requires input and grid to share dtype/device. Align the grid with the image so that a
    # flow of a different precision (e.g. float64 or half) does not raise.
    sample_grid = (identity - flow.to(image.device)).to(image.dtype)
    return F.grid_sample(image, sample_grid, mode="bilinear", align_corners=True)

def warp_image(rgb_t: np.ndarray, flow: np.ndarray) -> np.ndarray:
    """Warp a batch of RGB frames with a normalized flow field via ``warp_torch``.

    rgb_t :: (B, H, W, 3) uint8 [0:255]
    flow :: (B, H, W, 2) float32 [-1:1], in grid_sample normalized coordinates
    returns :: (B, H, W, 3) uint8 [0:255]; pixels sampled from outside the frame come out as 0 (black).
    """
    image = (tr.tensor(rgb_t).permute(0, 3, 1, 2).float() / 255).to(device)  # (B, 3, H, W) in [0, 1]
    flow_t = tr.tensor(flow).float().to(device)
    warped = warp_torch(image, flow_t)  # (B, 3, H, W) in [0, 1]
    # Round (not truncate) back to uint8. x / 255 * 255 can land just below the integer in float32,
    # and a bare astype(np.uint8) would floor it one intensity level too low.
    warped_uint8 = (warped.permute(0, 2, 3, 1) * 255).round().clamp(0, 255).to(tr.uint8)
    return warped_uint8.cpu().numpy()

def mm(x):
    """Min-max normalize ``x`` to the [0, 1] range (used to display diff maps with ``imshow``).

    Works on anything exposing ``.min()`` / ``.max()`` and arithmetic (NumPy arrays, torch tensors).
    A constant input (max == min) returns all zeros instead of the NaNs a plain 0 / 0 would give.
    """
    x_min, x_max = x.min(), x.max()
    shifted = x - x_min
    span = x_max - x_min
    if span == 0:
        # Nothing to stretch (e.g. an all-zero diff map); avoid 0 / 0.
        return shifted * 0.0
    return shifted / span

In [None]:
# Test clip shipped with the project; printed so the notebook output shows what was loaded.
video_path = get_project_root() / "resources" / "test_video.mp4"
video = FFmpegVideo(video_path)
print(video)

In [None]:
h, w = video.shape[1:3]
safeuav = SafeUAV(name="safeuav", dependencies=[], disk_data_argmax=True, variant="model_4M")
# Both RAFT instances run at the native video resolution and differ only in temporal direction.
# A fresh `dependencies` list is passed to each instance so they do not share one list object.
raft_kwargs = dict(inference_width=w, inference_height=h, iters=5, small=False)
raft_r = FlowRaft(name="flow_raft", dependencies=[], delta=1, **raft_kwargs)
raft_l = FlowRaft(name="flow_raft", dependencies=[], delta=-1, **raft_kwargs)
# Assign the device to every model first, then run setup only for models not yet set up
# (re-running the cell skips models that are already initialised).
for vre_repr in (raft_r, raft_l, safeuav):
    vre_repr.device = device
for vre_repr in (raft_r, raft_l, safeuav):
    if vre_repr.setup_called is False:
        vre_repr.vre_setup()


In [None]:
mb = 1  # number of centre frames to sample; one figure is drawn per frame
delta = 10  # temporal offset (in frames) between the centre frame and its left/right neighbours
raft_r.delta = delta
raft_l.delta = -delta

# Centre frames are drawn from [delta, len(video) - delta - 1), so ix - delta and ix + delta are valid frames.
ixs = sorted([np.random.randint(delta, len(video) - delta - 1) for _ in range(mb)])
# ixs = [4000]
ixs_l = [ix + raft_l.delta for ix in ixs]
ixs_r = [ix + raft_r.delta for ix in ixs]
print(f"{ixs=}, {ixs_l=}, {ixs_r=}")

rgbs, rgbs_l, rgbs_r = video[ixs], video[ixs_l], video[ixs_r] # (B, H, W, 3)

# NOTE(review): the flows are computed at ixs_l / ixs_r (not ixs), while the panels label them
# "flow ixs -> ixs_l / ixs_r". Confirm this matches FlowRaft.compute's frame-pairing convention.
flow_l_out = raft_l.resize(raft_l.compute(video, ixs_l), video.shape[1:3]) # (B, H, W, 2)
flow_l_img = raft_l.make_images(flow_l_out) # (B, H, W, 3)
flow_r_out = raft_r.resize(raft_r.compute(video, ixs_r), video.shape[1:3]) # (B, H, W, 2)
flow_r_img = raft_r.make_images(flow_r_out) # (B, H, W, 3)

# Warp the centre frame towards each neighbour. Pixels sampled from outside the frame come back as
# black, so the masks keep only pixels with a non-zero RGB sum.
rgb_warp_l = warp_image(rgbs, flow_l_out.output) # (B, H, W, 3)
rgb_warp_r = warp_image(rgbs, flow_r_out.output) # (B, H, W, 3)
mask_l = rgb_warp_l.sum(-1, keepdims=True) != 0 # (B, H, W, 1)
mask_r = rgb_warp_r.sum(-1, keepdims=True) != 0 # (B, H, W, 1)

diffs_l = ((rgbs_l.astype(np.float32) - rgb_warp_l).__abs__() * mask_l).sum(-1) # (B, H, W)
diffs_r = ((rgbs_r.astype(np.float32) - rgb_warp_r).__abs__() * mask_r).sum(-1) # (B, H, W)

sema_out = safeuav.resize(safeuav.compute(video, ixs), video.shape[1:3]) # (B, H, W, C)
sema_img = safeuav.make_images(sema_out) # (B, H, W, 3)
sema_out_l = safeuav.resize(safeuav.compute(video, ixs_l), video.shape[1:3]) # (B, H, W, C)
sema_img_l = safeuav.make_images(sema_out_l) # (B, H, W, 3)
sema_out_r = safeuav.resize(safeuav.compute(video, ixs_r), video.shape[1:3]) # (B, H, W, C)
sema_img_r = safeuav.make_images(sema_out_r) # (B, H, W, 3)

def _warp_sema_argmax(sema_src, flow_out) -> np.ndarray:
    """Warp the per-class semantic output of `sema_src` by `flow_out` and return argmax labels (B, H, W)."""
    # Cast both to float32 so grid_sample sees matching dtypes whatever precision the representations use.
    sema = tr.from_numpy(sema_src.output).float().permute(0, 3, 1, 2)  # (B, C, H, W)
    flow = tr.from_numpy(flow_out.output).float()  # (B, H, W, 2)
    return warp_torch(sema, flow).permute(0, 2, 3, 1).numpy().argmax(-1)

sema_warp_l = _warp_sema_argmax(sema_out, flow_l_out) # (B, H, W) argmax
sema_warp_r = _warp_sema_argmax(sema_out, flow_r_out) # (B, H, W) argmax
sema_warp_l_img = colorize_semantic_segmentation(sema_warp_l, safeuav.classes, safeuav.color_map) * mask_l # (B, H, W, 3)
sema_warp_r_img = colorize_semantic_segmentation(sema_warp_r, safeuav.classes, safeuav.color_map) * mask_r # (B, H, W, 3)
red_cm = np.array([[0, 0, 0], [255, 0, 0]])  # 0 -> black (agree), 1 -> red (disagree)
black = rgbs[0] * 0
diff_sema_l = ((sema_out_l.output.argmax(-1) != sema_warp_l) * mask_l[..., 0]).astype(int) # (B, H, W)
diff_sema_r = ((sema_out_r.output.argmax(-1) != sema_warp_r) * mask_r[..., 0]).astype(int) # (B, H, W)
# Per-pixel agreement averaged over both directions. Pixels outside a warp mask have diff 0 and
# therefore count as consistent.
score = 1 - (diff_sema_l + diff_sema_r) / 2 # (B, H, W)

title_kw = dict(fontsize=14, fontweight="bold")
for i in range(mb):
    fig, ax = plt.subplots(3, 6, figsize=(20, 8))
    fig.suptitle(f"Consistency score: {score[i].mean() * 100:.2f}%", **title_kw)

    # Row 0: left neighbour (ixs_l). RGB, warped RGB, RGB diff, semantics, warped semantics, semantic diff.
    ax[0, 0].set_title(f"T={ixs_l[i]}", **title_kw)
    ax[0, 0].imshow(rgbs_l[i])
    ax[0, 1].set_title(f"warp {ixs[i]}->{ixs_l[i]}", **title_kw)
    ax[0, 1].imshow(rgb_warp_l[i])
    ax[0, 2].set_title(f"Diff: {diffs_l[i].mean().item():.2f}", **title_kw)
    ax[0, 2].imshow(mm(diffs_l[i]))
    ax[0, 3].set_title(f"Sema T={ixs_l[i]}", **title_kw)
    ax[0, 3].imshow(sema_img_l[i])
    ax[0, 4].set_title(f"warp {ixs[i]}->{ixs_l[i]}", **title_kw)
    ax[0, 4].imshow(sema_warp_l_img[i])
    ax[0, 5].set_title(f"Diff: {diff_sema_l[i].mean()*100:.2f}%", **title_kw)
    ax[0, 5].imshow(red_cm[diff_sema_l[i]])

    # Row 1: centre frame (ixs). RGB, both flows, centre semantics; black panels are spacers.
    ax[1, 0].set_title(f"T={ixs[i]}", **title_kw)
    ax[1, 0].imshow(video[ixs[i]])
    ax[1, 1].set_title(f"flow {ixs[i]}->{ixs_l[i]}", **title_kw)
    ax[1, 1].imshow(flow_l_img[i])
    ax[1, 2].set_title(f"flow {ixs[i]}->{ixs_r[i]}", **title_kw)
    ax[1, 2].imshow(flow_r_img[i])
    ax[1, 3].imshow(black)
    ax[1, 4].set_title(f"sema {ixs[i]}", **title_kw)
    ax[1, 4].imshow(sema_img[i])
    ax[1, 5].imshow(black)

    # Row 2: right neighbour (ixs_r), same layout as row 0.
    ax[2, 0].set_title(f"T={ixs_r[i]}", **title_kw)
    ax[2, 0].imshow(rgbs_r[i])
    ax[2, 1].set_title(f"warp {ixs[i]}->{ixs_r[i]}", **title_kw)
    ax[2, 1].imshow(rgb_warp_r[i])
    ax[2, 2].set_title(f"Diff: {diffs_r[i].mean().item():.2f}", **title_kw)
    ax[2, 2].imshow(mm(diffs_r[i]))
    ax[2, 3].set_title(f"Sema T={ixs_r[i]}", **title_kw)
    ax[2, 3].imshow(sema_img_r[i])
    ax[2, 4].set_title(f"warp {ixs[i]}->{ixs_r[i]}", **title_kw)
    ax[2, 4].imshow(sema_warp_r_img[i])
    ax[2, 5].set_title(f"Diff: {diff_sema_r[i].mean() * 100:.2f}%", **title_kw)
    ax[2, 5].imshow(red_cm[diff_sema_r[i]])

    # Lay out and show every figure, not just the last one, when mb > 1.
    fig.tight_layout()
    plt.show()