In [1]:
import torch
import torch.nn.functional as F
from torchvision.io.video import read_video

# NOTE(review): not read by any cell in this notebook; get_occlusion_mask
# uses its own default threshold (1.5) instead of this value.
occ_mask_thres = 5e-2

def flow_warp(x,
              flow,
              interpolation='bilinear',
              padding_mode='zeros',
              align_corners=True):
    """Warp an image or a feature map with optical flow.

    Output pixel (i, j) is ``x`` sampled at pixel position
    ``(j + flow[:, i, j, 0], i + flow[:, i, j, 1])`` via ``F.grid_sample``.

    Args:
        x (Tensor): Tensor with size (n, c, h, w).
        flow (Tensor): Tensor with size (n, h, w, 2). The last dimension is
            a two-channel, denoting the width and height relative offsets.
            Note that the values are not normalized to [-1, 1].
        interpolation (str): Interpolation mode: 'nearest' or 'bilinear'.
            Default: 'bilinear'.
        padding_mode (str): Padding mode: 'zeros' or 'border' or 'reflection'.
            Default: 'zeros'.
        align_corners (bool): Whether align corners. Default: True.

    Returns:
        Tensor: Warped image or feature map, size (n, c, h, w).

    Raises:
        ValueError: If the spatial sizes of ``x`` and ``flow`` differ.
    """
    if x.size()[-2:] != flow.size()[1:3]:
        raise ValueError(f'The spatial sizes of input ({x.size()[-2:]}) and '
                         f'flow ({flow.size()[1:3]}) are not the same.')
    _, _, h, w = x.size()
    # Pixel-coordinate mesh built directly on x's device (no host->device
    # copy). indexing='ij' (rows first) is what the two-argument call did
    # implicitly; passing it explicitly silences torch.meshgrid's
    # "indexing argument" UserWarning.
    grid_y, grid_x = torch.meshgrid(
        torch.arange(h, device=x.device),
        torch.arange(w, device=x.device),
        indexing='ij')
    # (h, w, 2) with (x, y) in the last dim, matching flow's layout.
    grid = torch.stack((grid_x, grid_y), 2).type_as(x)

    grid_flow = grid + flow  # broadcasts over the batch dim -> (n, h, w, 2)
    # Map absolute pixel positions to grid_sample's [-1, 1] range
    # (0 -> -1, size-1 -> +1); max(..., 1) avoids division by zero for
    # size-1 spatial dims.
    # NOTE(review): this mapping is only pixel-exact for align_corners=True.
    grid_flow_x = 2.0 * grid_flow[:, :, :, 0] / max(w - 1, 1) - 1.0
    grid_flow_y = 2.0 * grid_flow[:, :, :, 1] / max(h - 1, 1) - 1.0
    grid_flow = torch.stack((grid_flow_x, grid_flow_y), dim=3)
    return F.grid_sample(
        x,
        grid_flow,
        mode=interpolation,
        padding_mode=padding_mode,
        align_corners=align_corners)

def get_occlusion_mask(flow, threshold=1.5):
    """Binary mask of pixels whose mean absolute flow exceeds ``threshold``.

    Args:
        flow (Tensor): size (n, h, w, 2).
        threshold (float): cut-off on mean(|flow_x|, |flow_y|). Default 1.5.

    Returns:
        Tensor: size (n, 1, h, w) (a permuted view). 1 where the mean
        absolute flow is strictly greater than ``threshold``, 0 otherwise;
        NaN flow propagates as NaN.

    NOTE(review): despite the name, the 1-entries mark *large-motion*
    pixels, and those are the pixels pair_error scores — confirm intent.
    """
    # Subtract the threshold, clamp to [0, 1], then take the sign so any
    # positive margin becomes exactly 1 and everything else 0.
    return (flow.abs()
                .mean(dim=-1, keepdim=True)
                .sub(threshold)
                .clamp(0, 1)
                .sign()
                .permute(0, 3, 1, 2))

def load_flow_tensor(path):
    """Load a saved flow tensor and move its channel axis to the end.

    Args:
        path: anything ``torch.load`` accepts (file path or file-like).

    Returns:
        Tensor: the loaded tensor permuted by (0, 2, 3, 1), so an
        (n, c, h, w) tensor becomes (n, h, w, c) — the layout flow_warp
        expects when c == 2 (assumes the file stores a 4-D tensor).
    """
    # NOTE(review): torch.load unpickles its input; only load trusted files.
    return torch.load(path).permute(0, 2, 3, 1)

def pair_error(occ_mask_frm, frm1, frm2, flow_frm):
    """Masked squared error between ``frm1`` and ``frm2`` warped by ``flow_frm``.

    ``frm2`` is warped with ``flow_warp`` and compared with ``frm1`` only
    where ``occ_mask_frm`` is non-zero.

    Args:
        occ_mask_frm (Tensor): mask of size (n, 1, h, w), as produced by
            ``get_occlusion_mask``.
        frm1 (Tensor): reference frames, size (n, c, h, w).
        frm2 (Tensor): frames to warp, size (n, c, h, w).
        flow_frm (Tensor): flow in pixels, size (n, h, w, 2).

    Returns:
        Tensor: size (n,). For each sample, the masked squared differences
        summed over channels and pixels, divided by the number of mask
        pixels. A sample whose mask is all zeros yields NaN (0 / 0).
    """
    warped = flow_warp(frm2, flow_frm)
    sq_diff = (occ_mask_frm * frm1 - occ_mask_frm * warped) ** 2
    # Reduce per sample (dims 1-3), not over the batch dim, so each of the
    # n outputs is that sample's own error normalized by its own mask.
    return sq_diff.sum(dim=(1, 2, 3)) / occ_mask_frm.sum(dim=(1, 2, 3))

def warping_error_frames(frames, flow):
    """Per-frame warping errors of a video against its flow.

    For each frame i >= 1, both frame 0 and frame i-1 are warped with
    ``flow[i-1]`` and compared with frame i via ``pair_error``, weighted by
    the matching slice of ``get_occlusion_mask(flow)``. The two errors are
    summed.

    Args:
        frames (Tensor): video frames indexed along dim 0; each slice
            ``frames[i:i+1]`` has size (1, c, h, w).
        flow (Tensor): size (>= T-1, h, w, 2). ``flow[i-1]`` is paired with
            frame i (the notebook passes ``*_back.pt`` flows).

    Returns:
        Tensor: size (T-1,) for a T-frame video; empty when T < 2.

    Raises:
        ValueError: If ``flow`` has fewer than T-1 entries.

    NOTE(review): frame 0 is warped with the adjacent-frame flow
    ``flow[i-1]``, not a flow from frame i to frame 0. At i == 1 both terms
    compare the same pair, so that frame's error is doubled. Confirm this
    is the intended metric.
    """
    num_frames = len(frames)
    if num_frames < 2:
        # No frame pairs to score; torch.cat([]) would otherwise raise.
        return frames.new_zeros(0)
    if len(flow) < num_frames - 1:
        raise ValueError(f'Expected at least {num_frames - 1} flow fields for '
                         f'{num_frames} frames, got {len(flow)}.')
    first_frm = frames[:1]
    occ_masks = get_occlusion_mask(flow)
    errors = []
    pre_frm = first_frm
    for i in range(1, num_frames):
        frm = frames[i:i + 1]
        occ_mask = occ_masks[i - 1:i]
        flow_frm = flow[i - 1:i]
        err_to_first = pair_error(occ_mask, frm, first_frm, flow_frm)
        err_to_prev = pair_error(occ_mask, frm, pre_frm, flow_frm)
        errors.append(err_to_first + err_to_prev)
        pre_frm = frm
    return torch.cat(errors, dim=0)

def avg_warping_errors(errors):
    """Mean of ``errors`` ignoring NaN entries, returned as a Python float.

    NaN entries (e.g. frames whose mask is empty, giving 0/0 in
    pair_error) are excluded from both the sum and the count. Infinite
    entries are kept as-is rather than being clipped to the dtype's max
    finite value, so an infinite error yields an infinite average.

    Args:
        errors (Tensor): error values of any shape.

    Returns:
        float: the NaN-ignoring mean; NaN if ``errors`` is empty or all NaN.
    """
    if not errors.is_floating_point():
        # Integer tensors have no NaNs and .mean() needs a float dtype.
        errors = errors.float()
    kept = errors[~errors.isnan()]
    # mean() of an empty tensor is NaN, matching the previous 0/0 result.
    return kept.mean().item()

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Flow fields for each method's edited video, moved to channel-last
# layout by load_flow_tensor.
back_flow_stem = load_flow_tensor(
    './out_flow/woman_running2_flow_stem_512_back.pt')
back_flow_glfm = load_flow_tensor(
    './out_flow/woman_running2_flow_glfm_1.0_1.0_64_448_back.pt')
back_flow_glfm2 = load_flow_tensor(
    './out_flow/woman_running2_flow_glfm_0.6_1.0_64_448_back.pt')
back_flow_glfm3 = load_flow_tensor(
    './out_flow/woman_running2_flow_glfm_0.6_1.0_128_384_back.pt')

In [5]:
video_glfm, _, _ = read_video('./src_videos/tokenflow_PnP_fps_10_k1_64_k2_448_bs100_400frms_all_scale_1.0_type7.mp4', output_format='TCHW')
video_stem, _, _ = read_video('./src_videos/tokenflow_PnP_fps_10_k512_bs100_400frms_type0.mp4', output_format='TCHW')
video_glfm2, _, _ = read_video('./src_videos/tokenflow_PnP_fps_10_k1_64_k2_448_bs100_400frms_all_scale_0.6_type7.mp4', output_format='TCHW')
video_glfm3, _, _ = read_video('./src_videos/tokenflow_PnP_fps_10_k1_128_k2_384_bs100_400frms_all_scale_0.6_type7.mp4', output_format='TCHW')

video_glfm = video_glfm.float()/255.
video_glfm2 = video_glfm2.float()/255.
video_glfm2 = video_glfm3.float()/255.
video_stem = video_stem.float()/255.



In [6]:
stem_warping_errors = warping_error_frames(video_stem, back_flow_stem)
glfm_warping_errors = warping_error_frames(video_glfm, back_flow_glfm)
glfm_warping_errors2 = warping_error_frames(video_glfm2, back_flow_glfm2)
glfm_warping_errors3 = warping_error_frames(video_glfm2, back_flow_glfm3)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [7]:
# Flatten each per-frame error tensor to 1-D (one entry per frame pair).
stem_warping_errors = torch.reshape(stem_warping_errors, (len(stem_warping_errors),))
glfm_warping_errors = torch.reshape(glfm_warping_errors, (len(glfm_warping_errors),))
glfm_warping_errors2 = torch.reshape(glfm_warping_errors2, (len(glfm_warping_errors2),))
glfm_warping_errors3 = torch.reshape(glfm_warping_errors3, (len(glfm_warping_errors3),))

In [11]:
# Report the NaN-ignoring average warping error of each configuration.
for label, errs in [
        ('glfm << k1=64,k2=448,α=1.0,β=1.0, average warping errors= ', glfm_warping_errors),
        ('glfm << k1=64,k2=448,α=0.6,β=1.0, average warping errors= ', glfm_warping_errors2),
        ('glfm << k1=128,k2=384,α=0.6,β=1.0, average warping errors= ', glfm_warping_errors3),
        ('stem << k=512, average warping errors= ', stem_warping_errors),
]:
    print(label, avg_warping_errors(errs))

# Free the flow tensors (and the loop temporaries) once scoring is done.
del label, errs, back_flow_stem, back_flow_glfm, back_flow_glfm2, back_flow_glfm3

glfm << k1=64,k2=448,α=1.0,β=1.0, average warping errors=  0.4913294017314911
glfm << k1=64,k2=448,α=0.6,β=1.0, average warping errors=  0.4922697842121124
glfm << k1=128,k2=384,α=0.6,β=1.0, average warping errors=  0.48882412910461426
stem << k=512, average warping errors=  0.5161299109458923


In [3]:
back_flow_glfm5 = load_flow_tensor('./out_flow/woman_running2_flow_glfm_0.6_1.0_64_192_back.pt')
back_flow_glfm6 = load_flow_tensor('./out_flow/woman_running2_flow_glfm_1.0_0.6_64_448_back.pt')

# Decode as (T, C, H, W) and scale to [0, 1] floats.
video_glfm5, _, _ = read_video('./src_videos/tokenflow_PnP_fps_10_k1_64_k2_192_bs100_400frms_all_scale_0.6_type7.mp4', output_format='TCHW')
video_glfm6, _, _ = read_video('./src_videos/tokenflow_PnP_fps_10_k1_64_scale_1.0_k2_448_scale_0.6_bs100_400frms_type7.mp4', output_format='TCHW')
video_glfm5 = video_glfm5.float()/255.
video_glfm6 = video_glfm6.float()/255.

glfm_warping_errors5 = warping_error_frames(video_glfm5, back_flow_glfm5)
glfm_warping_errors6 = warping_error_frames(video_glfm6, back_flow_glfm6)
del back_flow_glfm5
del back_flow_glfm6
# reshape is not in-place: assign the result so these are 1-D like the
# other runs' error tensors.
glfm_warping_errors5 = glfm_warping_errors5.reshape(len(glfm_warping_errors5))
glfm_warping_errors6 = glfm_warping_errors6.reshape(len(glfm_warping_errors6))

print('glfm << k1=64,k2=192,α=0.6,β=1.0, average warping errors= ', avg_warping_errors(glfm_warping_errors5))
print('glfm << k1=64,k2=448,α=1.0,β=0.6, average warping errors= ', avg_warping_errors(glfm_warping_errors6))

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


glfm << k1=64,k2=192,α=0.6,β=1.0, average warping errors=  0.5476954579353333
glfm << k1=64,k2=448,α=1.0,β=0.6, average warping errors=  0.6425002813339233
