In [None]:
import visualize
from run import estimate
import torch
import matplotlib.pyplot as plt
import numpy as np
import os
%matplotlib inline


def compute_mask(fwd_flow, bwd_flow, *, min_flow=10.0, rel_tol=0.05, abs_tol=1.0):
    """Return a boolean mask of pixels with large, forward-backward-consistent flow.

    A pixel passes when all of the following hold:

    1. its forward flow magnitude exceeds ``min_flow``;
    2. the backward flow sampled (bilinearly) at the pixel's forward-warped
       location is not exactly zero in both channels -- with zero padding,
       an all-zero sample is what ``grid_sample`` yields for locations that
       fall outside the image; and
    3. the round-trip error ``|fwd + bwd(warped)|`` is below
       ``max(rel_tol * |fwd|, abs_tol)``.

    Args:
        fwd_flow: Channels-first flow tensor of shape ``(2, H, W)``. Channel 0
            is added to the pixel x coordinate and channel 1 to the pixel
            y coordinate, i.e. displacements in pixels.
        bwd_flow: Flow in the opposite direction, same layout and dtype as
            ``fwd_flow``.
        min_flow: Forward-flow magnitude a pixel must exceed.
        rel_tol: Round-trip tolerance relative to the forward-flow magnitude.
        abs_tol: Lower bound on the round-trip tolerance.

    Returns:
        Bool tensor of shape ``(H, W)`` on ``fwd_flow``'s device.
    """
    h, w = fwd_flow.shape[1:]
    device, dtype = fwd_flow.device, fwd_flow.dtype

    fwd_norm = torch.linalg.vector_norm(fwd_flow, dim=0)
    mask = fwd_norm > min_flow

    # Absolute (x, y) pixel coordinates, shape (H, W, 2), displaced by the
    # forward flow to get where each pixel lands in the other frame.
    xs = torch.arange(w, device=device, dtype=dtype)
    ys = torch.arange(h, device=device, dtype=dtype)
    grid_x, grid_y = torch.meshgrid(xs, ys, indexing='xy')
    target = torch.stack((grid_x, grid_y), dim=-1) + fwd_flow.permute(1, 2, 0)

    # Normalize so that -1/+1 are the centres of the first/last pixels. This
    # is the align_corners=True convention, so grid_sample must be told so;
    # with the default (False) samples are shifted by up to half a pixel.
    scale = torch.tensor([w - 1, h - 1], device=device, dtype=dtype)
    target = target * 2 / scale - 1
    bwd_at_target = torch.nn.functional.grid_sample(
        bwd_flow.unsqueeze(0),
        target.unsqueeze(0),
        mode='bilinear',
        padding_mode='zeros',
        align_corners=True,
    ).squeeze(0)

    # Drop pixels whose target sampled an all-zero backward flow
    # (out of the image, given zero padding).
    mask &= torch.any(bwd_at_target, dim=0)

    # A consistent pair satisfies fwd(p) + bwd(p + fwd(p)) ~= 0.
    roundtrip_err = torch.linalg.vector_norm(fwd_flow + bwd_at_target, dim=0)
    tol = torch.clamp(rel_tol * fwd_norm, min=abs_tol)
    mask &= roundtrip_err < tol
    return mask


# i is shared by both frame filenames (presumably a sequence id -- confirm);
# j and k select the two frames whose optical flow is compared.
i = 3
j = 23
k = 26
fname = f'/home/slin/data/al/camera_motion/microwave/{i}_{j}_{k}.npz'
# Toggle to load cached results from `fname` instead of recomputing.
# Currently disabled. NOTE(review): the compute branch below does not define
# pt1/pt2/F, so the two branches leave different names behind.
use_cache = False
if use_cache and os.path.isfile(fname):
    data = np.load(fname)
    pt1 = data['pt1']
    pt2 = data['pt2']
    F = data['F']
else:
    fname1 = f'/home/slin/data/al/articulated_motion/microwave/img{i:02}_{j:04}.png'
    fname2 = f'/home/slin/data/al/articulated_motion/microwave/img{i:02}_{k:04}.png'
    # Keep the first three channels in reversed order (RGB(A) -> BGR) for
    # `estimate`; presumably the estimator expects BGR -- confirm in run.py.
    img1 = plt.imread(fname1)[..., 2::-1]  # shape=(1080, 1920, 3), dtype=float32
    img2 = plt.imread(fname2)[..., 2::-1]
    # .copy() is required: torch.from_numpy rejects the negative-stride view.
    ten1 = torch.from_numpy(img1.copy()).permute(2, 0, 1)
    ten2 = torch.from_numpy(img2.copy()).permute(2, 0, 1)
    flow12 = estimate(ten1, ten2)
    flow21 = estimate(ten2, ten1)

    # Forward-backward consistency mask for each direction.
    mask1 = compute_mask(flow12, flow21).cpu().numpy()
    mask2 = compute_mask(flow21, flow12).cpu().numpy()
    flow12 = flow12.permute(1, 2, 0).cpu().numpy()
    flow21 = flow21.permute(1, 2, 0).cpu().numpy()
    viz12 = visualize.flow_to_color(flow12)
    viz21 = visualize.flow_to_color(flow21)

    # img1/img2 hold BGR data; flip back to RGB so imshow shows true colours.
    panels = [
        (img1[..., ::-1], 'img1'),
        (img2[..., ::-1], 'img2'),
        (viz12, 'flow12'),
        (viz21, 'flow21'),
        (mask1, 'mask1'),
        (mask2, 'mask2'),
    ]
    plt.figure(figsize=(10, 10))
    for panel_idx, (panel, title) in enumerate(panels, start=1):
        plt.subplot(3, 2, panel_idx)
        plt.imshow(panel)
        plt.axis('off')
        plt.title(title)
