# In [7]:
import numpy as np
import torch
import softsplat
import moviepy.editor
import cv2

def _flow_16bit_to_float(flow_16bit: np.ndarray):
    assert flow_16bit.dtype == np.uint16, flow_16bit.dtype
    assert flow_16bit.ndim == 3
    h, w, c = flow_16bit.shape
    assert c == 3
    # BGR 转 RGB
    flow_16bit = flow_16bit[..., ::-1]

    valid2D = flow_16bit[..., 2] == 1
    assert valid2D.shape == (h, w)

    assert np.all(flow_16bit[~valid2D, -1] == 0)
    valid_map = np.where(valid2D)
    flow_16bit = flow_16bit.astype('float32')
    flow_map = np.zeros((h, w, 2), dtype='float32')
    flow_map[valid_map[0], valid_map[1], 0] = (flow_16bit[valid_map[0], valid_map[1], 0] - 2 ** 15) / 128
    flow_map[valid_map[0], valid_map[1], 1] = (flow_16bit[valid_map[0], valid_map[1], 1] - 2 ** 15) / 128
    return flow_map, valid2D
##########################################################

def read_flo(strFile):
    """Load a 16-bit flow PNG and return its decoded flow field.

    The file is read with OpenCV preserving bit depth and channels, then
    decoded by ``_flow_16bit_to_float``. The validity mask is discarded, so
    invalid pixels appear as zero flow.

    Args:
        strFile: path to the flow PNG (anything ``str()`` turns into a path).

    Returns:
        float32 array of shape (H, W, 2).

    Raises:
        OSError: if OpenCV cannot read the file. ``cv2.imread`` signals failure
            by returning None rather than raising, so this is checked explicitly.
    """
    flow_16bit = cv2.imread(str(strFile), cv2.IMREAD_ANYDEPTH | cv2.IMREAD_ANYCOLOR)
    if flow_16bit is None:
        raise OSError(f'could not read flow file: {strFile}')
    flow_map, _ = _flow_16bit_to_float(flow_16bit)
    return flow_map
# end

##########################################################

# Cache of identity sampling grids in [-1, 1], keyed by (flow shape, dtype, device).
# NOTE(review): grows by one entry per distinct key and is never evicted.
backwarp_tenGrid = {}

def backwarp(tenIn, tenFlow):
    """Backward-warp ``tenIn`` by ``tenFlow`` with bilinear sampling.

    Each output pixel p samples ``tenIn`` at ``p + tenFlow[p]``. Channel 0 of
    the flow is the horizontal displacement and channel 1 the vertical one, both
    in pixels of ``tenIn``. Samples falling outside ``tenIn`` are zero.

    Args:
        tenIn: N x C x H x W tensor to sample from.
        tenFlow: N x 2 x H x W flow tensor (batch may be broadcast against tenIn's).

    Returns:
        The warped tensor, with tenIn's channels and tenFlow's spatial size.
    """
    objKey = (tuple(tenFlow.shape), tenFlow.dtype, tenFlow.device)
    if objKey not in backwarp_tenGrid:
        tenHor = torch.linspace(start=-1.0, end=1.0, steps=tenFlow.shape[3], dtype=tenFlow.dtype, device=tenFlow.device).view(1, 1, 1, -1).repeat(1, 1, tenFlow.shape[2], 1)
        tenVer = torch.linspace(start=-1.0, end=1.0, steps=tenFlow.shape[2], dtype=tenFlow.dtype, device=tenFlow.device).view(1, 1, -1, 1).repeat(1, 1, 1, tenFlow.shape[3])

        # The grid is already on tenFlow's device and dtype, so CPU inputs work too.
        backwarp_tenGrid[objKey] = torch.cat([tenHor, tenVer], 1)
    # end

    # Convert pixel displacements into grid_sample's normalized coordinates.
    # With align_corners=True, -1 and +1 are the centres of the first and last pixels.
    tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenIn.shape[3] - 1.0) / 2.0), tenFlow[:, 1:2, :, :] / ((tenIn.shape[2] - 1.0) / 2.0)], 1)

    return torch.nn.functional.grid_sample(input=tenIn, grid=(backwarp_tenGrid[objKey] + tenFlow).permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros', align_corners=True)
# end

##########################################################

# Inputs: two consecutive frames and the forward flow from the first to the second.
# NOTE(review): absolute, machine-specific paths; adjust before running elsewhere.
filename_first = '/home/xiaoshan/work/adap_v/my_proj/DSEC_remapped_images/train/zurich_city_01_a/images/left/distorted/000134.png'
filename_second = '/home/xiaoshan/work/adap_v/my_proj/DSEC_remapped_images/train/zurich_city_01_a/images/left/distorted/000135.png'
filename_flow = '/home/xiaoshan/work/adap_v/my_proj/DSEC/train/zurich_city_01_a/flow/forward/000134.png'


def _read_image_tensor(strFile):
    """Read an image with OpenCV and return a 1 x C x H x W float32 CUDA tensor.

    Pixel values are multiplied by 1/255 (this assumes 8-bit images). Channels
    stay in OpenCV's BGR order. Raises OSError when OpenCV cannot read the file,
    because ``cv2.imread`` returns None instead of raising.
    """
    npyImage = cv2.imread(filename=strFile, flags=-1)
    if npyImage is None:
        raise OSError(f'could not read image: {strFile}')
    # H x W x C -> 1 x C x H x W
    return torch.from_numpy(np.ascontiguousarray(npyImage.transpose(2, 0, 1)[None, :, :, :].astype(np.float32) * (1.0 / 255.0))).cuda()


# 1 3 H W
tenFirst = _read_image_tensor(filename_first)
tenSecond = _read_image_tensor(filename_second)

# 1 2 H W
# NOTE(review): read_flo drops the validity mask, so pixels without flow are splatted with zero flow.
tenFlow = torch.from_numpy(np.ascontiguousarray(read_flo(filename_flow).transpose(2, 0, 1)[None, :, :, :])).cuda()

# 1 1 H W -- photometric error between tenFirst and tenSecond warped back by the flow (mean over channels).
tenMetric_L1 = torch.nn.functional.l1_loss(input=tenFirst, target=backwarp(tenIn=tenSecond, tenFlow=tenFlow), reduction='none').mean(1, True)
# 1 1 H W -- Euclidean flow magnitude sqrt(u^2 + v^2).
tenMetric_flow_mag = torch.sqrt(torch.square(tenFlow[:, 0:1, :, :]) + torch.square(tenFlow[:, 1:2, :, :]))

# 11 evenly spaced times t in [0, 1]; frame t splats tenFirst along t * tenFlow.
fltTimes = np.linspace(0.0, 1.0, 11).tolist()

# Presumably softsplat's 'soft' mode gives pixels with a larger metric more weight where splats collide -- confirm.
tenOutputs_L1 = [softsplat.softsplat(tenIn=tenFirst, tenFlow=tenFlow * fltTime, tenMetric=-20.0 * tenMetric_L1, strMode='soft') for fltTime in fltTimes]
# Ping-pong order: forward frames, then backward without repeating either endpoint, so the GIF loops smoothly.
npyOutputs_L1 = [(tenOutput_L1[0, :, :, :].cpu().numpy().transpose(1, 2, 0) * 255.0).clip(0.0, 255.0).astype(np.uint8) for tenOutput_L1 in tenOutputs_L1 + list(reversed(tenOutputs_L1[1:-1]))]

# NOTE(review): with the magnitude negated, faster-moving pixels get a lower metric; confirm that is intended.
tenOutputs_flow_mag = [softsplat.softsplat(tenIn=tenFirst, tenFlow=tenFlow * fltTime, tenMetric=-tenMetric_flow_mag, strMode='soft') for fltTime in fltTimes]
npyOutputs_flow_mag = [(tenOutput_flow_mag[0, :, :, :].cpu().numpy().transpose(1, 2, 0) * 255.0).clip(0.0, 255.0).astype(np.uint8) for tenOutput_flow_mag in tenOutputs_flow_mag + list(reversed(tenOutputs_flow_mag[1:-1]))]

# Show the two results side by side (concatenated along the width).
# [:, :, ::-1] converts OpenCV's BGR order to the RGB order moviepy expects.
preds = [np.concatenate([npyOutput_L1[:, :, ::-1], npyOutput_flow_mag[:, :, ::-1]], axis=1) for npyOutput_L1, npyOutput_flow_mag in zip(npyOutputs_L1, npyOutputs_flow_mag)]
video = moviepy.editor.ImageSequenceClip(sequence=preds, fps=15)
video.write_gif('./out.gif')
# end



# Cell output: MoviePy - Building file ./out.gif with imageio.


                                                            