In [1]:
import sys
sys.path.append('core')

import argparse
import os
import cv2
import glob
import numpy as np
import torch
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms
import torch.nn.functional as F
from torch.autograd import Variable

from raft.raft import RAFT
from raft.utils import flow_viz
from raft.utils.utils import InputPadder

In [6]:
# Grayscale pipeline: PIL image -> single-channel float tensor of shape (1, H, W).
# Per torchvision's ToTensor contract, 8-bit images are rescaled to [0, 1].
transform_gray = transforms.Compose([transforms.Grayscale(1), transforms.ToTensor()])

# RGB pipeline: PIL image -> float tensor of shape (C, H, W).
# NOTE(review): unused in the visible cells (RGB frames go through load_image) — confirm before removing.
transform_rgb = transforms.ToTensor()

In [2]:
# Prefer the GPU when CUDA is available; otherwise run everything on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# NOTE(review): not referenced by the visible cells — confirm before removing.
batch_size = 3

In [3]:
# Read an image file as an RGB tensor batch
def load_image(imfile):
    """Load an image file as a single-item float tensor batch for RAFT.

    Args:
        imfile: Path to an image file readable by PIL.

    Returns:
        Tensor of shape (1, 3, H, W), dtype float32, pixel values in [0, 255],
        placed on ``DEVICE``.
    """
    # The context manager releases the file handle PIL keeps open after
    # Image.open. convert('RGB') normalizes grayscale / palette / RGBA inputs to
    # exactly three uint8 channels, so the HWC -> CHW permute always sees a
    # 3-D array with C == 3.
    with Image.open(imfile) as pil_img:
        img = np.array(pil_img.convert('RGB'))
    img = torch.from_numpy(img).permute(2, 0, 1).float()
    # Add the leading batch axis expected by the model.
    return img[None].to(DEVICE)

In [4]:
# Optical-flow visualization
def viz(flo, out_path='res1.png'):
    """Render the first flow field of a batch as a colour image and save it.

    Args:
        flo: Batched flow tensor; ``flo[0]`` is taken and permuted (1, 2, 0),
            so it is assumed to be (2, H, W) channels-first — TODO confirm for
            callers other than RAFT's ``flow_up``.
        out_path: Destination image path. Defaults to ``'res1.png'`` in the
            current working directory.

    Returns:
        The RGB image produced by ``flow_viz.flow_to_image`` (indexed here as
        H x W x 3).

    Raises:
        OSError: If OpenCV reports that ``out_path`` could not be written.
    """
    # detach() so this also works on tensors that still require grad.
    flo = flo[0].detach().permute(1, 2, 0).cpu().numpy()

    # Map flow vectors to an RGB colour-wheel image.
    flo_rgb = flow_viz.flow_to_image(flo)

    # OpenCV writes BGR, so reverse the channel order before saving.
    if not cv2.imwrite(out_path, flo_rgb[:, :, [2, 1, 0]]):
        raise OSError(f'cv2.imwrite failed to write {out_path!r}')
    return flo_rgb

In [5]:
# Read a .flo optical-flow file into a tensor
def load_flow_to_numpy(path):
    """Load a Middlebury-style ``.flo`` file as a channels-first tensor.

    Layout read here: a float32 magic number (202021.25), an int32 width, an
    int32 height, then ``2 * width * height`` float32 values stored as
    interleaved pairs in row-major (height, width) order.

    Despite its name this returns a ``torch.Tensor``; the name is kept for
    compatibility with existing callers.

    Args:
        path: Path to the ``.flo`` file.

    Returns:
        float32 tensor of shape (2, height, width). Channel 0 holds the first
        value of each pair, channel 1 the second (u and v in the Middlebury
        convention).

    Raises:
        ValueError: If the magic number is wrong, the header is incomplete, or
            the flow payload is shorter than the header claims.
    """
    with open(path, 'rb') as f:
        magic = np.fromfile(f, np.float32, count=1)
        # Explicit check: an `assert` would be stripped under `python -O`, and
        # comparing against an empty array (empty file) is ill-defined.
        if magic.size != 1 or magic[0] != 202021.25:
            raise ValueError('Magic number incorrect. Invalid .flo file')
        # The header stores width first, then height.
        dims = np.fromfile(f, np.int32, count=2)
        if dims.size != 2:
            raise ValueError(f'Truncated .flo header in {path!r}')
        w, h = int(dims[0]), int(dims[1])
        if w < 0 or h < 0:
            raise ValueError(f'Invalid .flo dimensions {w}x{h} in {path!r}')
        data = np.fromfile(f, np.float32, count=2 * w * h)
    # reshape (unlike np.resize) cannot silently pad a short payload by
    # repeating values, so validate the size first and fail loudly.
    if data.size != 2 * w * h:
        raise ValueError(
            f'Truncated .flo payload in {path!r}: expected {2 * w * h} values, got {data.size}'
        )
    # (height, width, 2) -> (2, height, width)
    flow = data.reshape(h, w, 2).transpose(2, 0, 1)
    return torch.from_numpy(flow)

In [None]:
def _load_gray(path):
    """Load an image file through ``transform_gray`` onto ``DEVICE``.

    The PIL file handle is closed before returning.
    """
    with Image.open(path) as pil_img:
        return transform_gray(pil_img).to(DEVICE)


def dataload(args, out_dir='/home/panding/code/UR/data-chair',
             flow_dir='/home/panding/code/UR/chair'):
    """Run RAFT on consecutive image pairs and save one stacked array per pair.

    For each consecutive pair ``(imfile1, imfile2)`` of the sorted ``.png`` /
    ``.jpg`` / ``.ppm`` files in ``args.path`` this:

    1. estimates optical flow with RAFT (``iters=20``, ``test_mode=True``),
    2. calls ``viz`` on the flow (which writes ``res1.png``, overwritten per pair),
    3. loads the ground-truth ``.flo`` for the pair from ``flow_dir``,
    4. saves ``[gray(img2), gray(img1), flow, flow_gt]`` concatenated along
       axis 0 to ``<out_dir>/<file name of imfile1 without extension>.npy``.

    Args:
        args: Parsed arguments. ``args.model`` is the checkpoint path and
            ``args.path`` the image directory; ``RAFT(args)`` may read more.
        out_dir: Directory receiving the ``.npy`` files; created if missing.
        flow_dir: Directory holding ground-truth files named
            ``<imfile1 file name minus its last 8 chars>flow.flo``.
    """
    model = torch.nn.DataParallel(RAFT(args))
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    model.load_state_dict(torch.load(args.model, map_location=DEVICE))

    model = model.module
    model.to(DEVICE)
    model.eval()

    # A single output directory for every pair (previously the first pair went
    # to a different, machine-specific path than all later pairs).
    os.makedirs(out_dir, exist_ok=True)

    with torch.no_grad():
        # Collect the input frames; ground truth is located per pair below.
        images = sorted(
            glob.glob(os.path.join(args.path, '*.png'))
            + glob.glob(os.path.join(args.path, '*.jpg'))
            + glob.glob(os.path.join(args.path, '*.ppm'))
        )
        images_num = len(images)
        print('\n', '--------------images loading...-------------', '\n')
        # images_loaded counts frames touched so far: pair k covers frames 1..k+1.
        for images_loaded, (imfile1, imfile2) in enumerate(
                zip(images[:-1], images[1:]), start=2):
            # RGB frames as (1, 3, H, W) float tensors on DEVICE.
            image1_rgb_tensor = load_image(imfile1)
            image2_rgb_tensor = load_image(imfile2)

            # Padding may change the spatial size (e.g. 436 -> 440), so the
            # grayscale tensors below must go through the same padder before
            # they can be concatenated with the flow.
            padder = InputPadder(image1_rgb_tensor.shape)
            image1, image2 = padder.pad(image1_rgb_tensor, image2_rgb_tensor)

            # flow_up: (1, 2, H_pad, W_pad)
            _, flow_up = model(image1, image2, iters=20, test_mode=True)
            viz(flow_up)
            # Drop only the batch axis -> (2, H_pad, W_pad); squeeze() would also
            # collapse a spatial axis of size 1.
            flow_up = flow_up[0]

            # Grayscale frames: (1, H, W), then padded to (1, H_pad, W_pad).
            image1_gray_tensor = _load_gray(imfile1)
            image2_gray_tensor = _load_gray(imfile2)
            image1_gray_tensor, image2_gray_tensor = padder.pad(image1_gray_tensor, image2_gray_tensor)

            # TODO: optionally warp image1 with the predicted flow (remap) here.

            # Ground-truth flow: the last 8 chars of the first frame's name
            # (presumably 'img1.ppm', FlyingChairs naming — TODO confirm) are
            # replaced by 'flow.flo'.
            base_name = os.path.basename(imfile1)
            flow_path = os.path.join(flow_dir, base_name[:-8] + 'flow.flo')
            flow_truth = load_flow_to_numpy(flow_path).to(DEVICE)

            # (6, H_pad, W_pad) in the order gray(img2), gray(img1), u, v, u_gt, v_gt.
            # NOTE(review): the original comment listed gray(img1) before
            # gray(img2); confirm which order downstream consumers expect.
            # NOTE(review): flow_truth is not padded, so this concatenation only
            # succeeds when InputPadder added no padding.
            result = torch.cat((image2_gray_tensor, image1_gray_tensor, flow_up, flow_truth), 0)
            result_np = result.cpu().numpy()
            np.save(os.path.join(out_dir, os.path.splitext(base_name)[0]), result_np)
            if images_loaded % 5 == 0:
                print('\n', '--------------images loaded: ', images_loaded, ' / ', images_num, '-------------', '\n')
