In [15]:
import sys
sys.path.append('core')

import argparse
import os
import cv2
import glob
import numpy as np
import torch
from PIL import Image

from raft import RAFT
from utils import flow_viz
from utils.utils import InputPadder


%matplotlib inline
# The line above is necessary to show Matplotlib's plots inside a Jupyter Notebook

import cv2
from matplotlib import pyplot as plt

# Import image
# NOTE(review): "input_path" is a literal placeholder string rather than a real
# file path, and `image` is not referenced again in the visible cells -- this
# looks like leftover tutorial code; confirm before relying on it.
image = cv2.imread("input_path")

# Show the image with matplotlib
# NOTE(review): no display call follows this comment in the visible notebook.


# Device string that load_image() and the model-loading code move tensors/modules to.
# Hard-coded to CUDA device index 1 -- presumably the second GPU of the authoring machine.
DEVICE = 'cuda:1'

def load_image(imfile):
    """Load an image file as a batched float tensor on ``DEVICE``.

    The image is decoded with PIL and converted to 3-channel RGB before it is
    turned into a tensor. Without that conversion a grayscale file produces a
    2-D array (``permute(2, 0, 1)`` then fails) and an RGBA file produces a
    4-channel tensor.

    Args:
        imfile: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        ``torch.float32`` tensor of shape ``(1, 3, H, W)`` holding the 8-bit
        pixel values (0..255, not normalised), moved to ``DEVICE``.
    """
    # The context manager releases the underlying file handle once decoded.
    with Image.open(imfile) as pil_img:
        img = np.array(pil_img.convert('RGB')).astype(np.uint8)
    # (H, W, C) -> (C, H, W)
    img = torch.from_numpy(img).permute(2, 0, 1).float()
    # img[None] prepends the batch dimension.
    return img[None].to(DEVICE)


def demo(args):
    """Run RAFT on every pair of consecutive images found in ``args.path``.

    Loads the checkpoint ``args.model`` into a RAFT network, moves it to
    ``DEVICE`` in eval mode, then iterates over the sorted ``*.png`` and
    ``*.jpg`` files in ``args.path`` and runs the model on each adjacent pair.
    The predicted flows are currently discarded (the visualisation call is
    commented out), and the function returns ``None``.

    Args:
        args: Options object exposing ``model`` (checkpoint path), ``path``
            (image directory) and whatever ``RAFT(args)`` reads.
            NOTE(review): the ``di`` dict defined in this notebook has no
            ``path`` entry -- set it before calling ``demo(args)``.
    """
    # Imported locally under an alias: a later notebook cell executes
    # ``from glob import glob``, which rebinds the global name ``glob`` to the
    # function and would make ``glob.glob(...)`` raise AttributeError here.
    import glob as _glob

    # Wrapped in DataParallel before loading -- presumably because the
    # checkpoint keys carry the 'module.' prefix; confirm against the file.
    model = torch.nn.DataParallel(RAFT(args))
    # The freshly built model still lives on the CPU, so deserialize the
    # checkpoint there instead of on whichever GPU it was saved from.
    model.load_state_dict(torch.load(args.model, map_location='cpu'))

    # Unwrap the plain network, then move it to the target device.
    model = model.module
    model.to(DEVICE)
    model.eval()

    with torch.no_grad():
        images = _glob.glob(os.path.join(args.path, '*.png')) + \
                 _glob.glob(os.path.join(args.path, '*.jpg'))

        images = sorted(images)

        # Pair every image with its successor; fewer than two images -> no iterations.
        for imfile1, imfile2 in zip(images[:-1], images[1:]):
            image1 = load_image(imfile1)
            image2 = load_image(imfile2)

            padder = InputPadder(image1.shape)
            image1, image2 = padder.pad(image1, image2)

            flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)

            # viz(image1, flow_up)




# RAFT options: checkpoint path plus the model-variant switches.
di = dict(
    model='models/raft-kitti.pth',
    small=False,
    mixed_precision=True,
    alternate_corr=False,
)
    
class DotDict(dict):
    """Dictionary whose items can also be read, set and deleted as attributes."""

    def __getattr__(self, key, default=None):
        # Only reached when regular attribute lookup fails; an absent key
        # yields ``default`` (None) rather than raising AttributeError.
        return dict.get(self, key, default)

    def __setattr__(self, key, value):
        # Attribute assignment stores an ordinary dictionary item.
        dict.__setitem__(self, key, value)

    def __delattr__(self, key):
        # Deleting an absent key raises KeyError, not AttributeError.
        dict.__delitem__(self, key)
    
# Wrap the options dict so entries can be read as attributes (args.model, args.small, ...);
# keys that were never set read back as None because DotDict.__getattr__ is dict.get.
args = DotDict(di)

In [2]:
import matplotlib.pyplot as plt
from time import time
from pathlib import Path
from glob import glob
import matplotlib

base = "/home/jonfrey/datasets/scannet"


def fun(x):
    """Sort key: last 7 chars of the third-from-last path component, '_', then the frame name left-padded with zeros to 6 chars."""
    parts = x.split('/')
    frame = parts[-1][:-4]  # file name minus its last 4 characters (".jpg")
    return parts[-3][-7:] + '_' + frame.rjust(6, '0')


# Recursively collect every .jpg below `base` whose path contains 'color'.
# The zero-padded key makes string order match numeric frame order (up to 6 digits).
image_pths = sorted(
    (str(p) for p in glob(base + '/**/*.jpg', recursive=True) if 'color' in str(p)),
    key=fun,
)
image_pths

['/home/jonfrey/datasets/scannet/scans/scene0000_00/color/0.jpg',
 '/home/jonfrey/datasets/scannet/scans/scene0000_00/color/1.jpg',
 '/home/jonfrey/datasets/scannet/scans/scene0000_00/color/2.jpg',
 '/home/jonfrey/datasets/scannet/scans/scene0000_00/color/3.jpg',
 '/home/jonfrey/datasets/scannet/scans/scene0000_00/color/4.jpg',
 '/home/jonfrey/datasets/scannet/scans/scene0000_00/color/5.jpg',
 '/home/jonfrey/datasets/scannet/scans/scene0000_00/color/6.jpg',
 '/home/jonfrey/datasets/scannet/scans/scene0000_00/color/7.jpg',
 '/home/jonfrey/datasets/scannet/scans/scene0000_00/color/8.jpg',
 '/home/jonfrey/datasets/scannet/scans/scene0000_00/color/9.jpg',
 '/home/jonfrey/datasets/scannet/scans/scene0000_00/color/10.jpg',
 '/home/jonfrey/datasets/scannet/scans/scene0000_00/color/11.jpg',
 '/home/jonfrey/datasets/scannet/scans/scene0000_00/color/12.jpg',
 '/home/jonfrey/datasets/scannet/scans/scene0000_00/color/13.jpg',
 '/home/jonfrey/datasets/scannet/scans/scene0000_00/color/14.jpg',
 '/ho

In [3]:
# Build RAFT inside DataParallel before loading the weights -- presumably because
# the checkpoint keys carry the 'module.' prefix; confirm against the file.
model = torch.nn.DataParallel(RAFT(args))
# The freshly built model is still on the CPU, so deserialize the checkpoint
# there rather than on whichever GPU it was saved from.
model.load_state_dict(torch.load(args.model, map_location='cpu'))

# Unwrap the plain network, move it to the target device and switch to eval mode.
model = model.module
model.to(DEVICE)
model.eval()

RAFT(
  (fnet): BasicEncoder(
    (norm1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
    (relu1): ReLU(inplace=True)
    (layer1): Sequential(
      (0): ResidualBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (relu): ReLU(inplace=True)
        (norm1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (norm2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      )
      (1): ResidualBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (relu): ReLU(inplace=True)
        (norm1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=Fa

In [79]:
from torchvision import transforms as tf 
import torch
import time 
from pathlib import Path
# PIL.Image.NEAREST
import imageio
# Resize applied to every frame in estimate_flow() before it is fed to the model.
_low_res_hw = (484, 648)
tra = torch.nn.Sequential(tf.Resize(_low_res_hw))

# Resize to exactly twice that size in each dimension (not used in the visible cells).
_up_res_hw = (968, 1296)
tra_up = torch.nn.Sequential(tf.Resize(_up_res_hw))

def writeFlowKITTI(filename, uv):
    """Write a 2-channel flow field as a 16-bit, 3-channel image.

    Each flow component is encoded as ``64 * value + 2**15`` and stored in an
    unsigned 16-bit channel; a third channel filled with ones is appended.
    The channel order is reversed before the array is handed to
    ``cv2.imwrite``. ``readFlowKITTI`` applies the inverse mapping.

    Values that fall outside the encodable range (roughly -512 .. +511.98)
    are saturated to the nearest representable value instead of wrapping
    around during the ``uint16`` cast.

    Args:
        filename: Destination path passed to ``cv2.imwrite``.
        uv: ``(H, W, 2)`` flow field (NumPy array; the notebook also passes
            CPU torch tensors).

    Raises:
        OSError: If ``cv2.imwrite`` reports that nothing was written (it
            signals failure through its return value, not an exception).
    """
    uv = 64.0 * uv + 2**15
    valid = np.ones([uv.shape[0], uv.shape[1], 1])
    encoded = np.concatenate([uv, valid], axis=-1)
    # Saturate before casting: out-of-range floats would otherwise wrap.
    encoded = np.clip(encoded, 0, 2**16 - 1).astype(np.uint16)
    if not cv2.imwrite(filename, encoded[..., ::-1]):
        raise OSError(f"cv2.imwrite failed to write flow file: {filename}")

def viz(img, flo):
    """Show the first image of a batch stacked vertically above its colour-coded flow.

    Args:
        img: Image batch tensor; only element 0 is displayed
            (assumes (N, C, H, W) layout with 0..255 values -- TODO confirm).
        flo: Flow batch tensor; only element 0 is displayed.
    """
    # Select batch element 0 and move the first remaining axis to the end.
    frame_np = img[0].permute(1, 2, 0).cpu().numpy()
    flow_np = flo[0].permute(1, 2, 0).cpu().numpy()

    # map flow to rgb image
    flow_rgb = flow_viz.flow_to_image(flow_np)

    # Image on top, flow rendering underneath; divide by 255 for imshow.
    stacked = np.concatenate([frame_np, flow_rgb], axis=0)
    plt.imshow(stacked / 255)
    plt.show()


def estimate_flow(f1, f2, name, model):
    """Estimate the flow between two image files and save both model outputs.

    Both images are loaded with ``load_image``, resized with ``tra``, padded
    with ``InputPadder`` and passed to ``model``. The two returned flow
    tensors are written with ``writeFlowKITTI`` to::

        <directory two levels above f1>/<name>/flow_low_<stem of f1>.png
        <directory two levels above f1>/<name>/flow_up_<stem of f1>.png

    The output directory is created if missing (its parent must exist).
    Paths are derived with ``pathlib``, so relative ``f1`` paths and file
    extensions of any length work; for absolute ``*.jpg``/``*.png`` inputs
    the resulting paths are the same as before.

    Args:
        f1: Path of the first frame; also determines the output location/name.
        f2: Path of the second frame.
        name: Name of the output sub-directory.
        model: Callable invoked as
            ``model(image1, image2, iters=20, test_mode=True)`` and returning
            ``(flow_low, flow_up)``.
    """
    image1 = tra(load_image(f1))
    image2 = tra(load_image(f2))
    padder = InputPadder(image1.shape)
    image1, image2 = padder.pad(image1, image2)

    flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
    # NOTE(review): the flows are saved at the padded size; the padding added by
    # `padder.pad` is never removed. The notebook's read-back shows 488 rows for
    # a 484-row resize -- confirm whether the flow should be un-padded first.

#     viz(image1, flow_up)

    src = Path(f1)
    # Output folder sits beside the folder that contains f1.
    out_dir = src.parent.parent / name
    out_dir.mkdir(parents=False, exist_ok=True)

    p1 = str(out_dir / f"flow_low_{src.stem}.png")
    p2 = str(out_dir / f"flow_up_{src.stem}.png")

    # Drop the batch dimension and move channels last before encoding.
    writeFlowKITTI(p1, flow_low.detach().cpu()[0].permute(1,2,0) )
    writeFlowKITTI(p2, flow_up.detach().cpu()[0].permute(1,2,0) )
    
    
# Flow is estimated only for frames whose numeric file name is a multiple of `sub`.
sub = 10
# Initial scene label; presumably chosen so that it matches no real scene folder.
scene = "xxx"
old_frame = None
base_name = f"flow_sub_{sub}"
# NOTE: `time` must be the module here (another cell runs `from time import time`).
st = time.time()
with torch.no_grad():
    for j, i in enumerate(image_pths):
        path_parts = i.split('/')
        if path_parts[-3] != scene:
            # First path seen for a new scene: it becomes the reference frame, no flow yet.
            scene = path_parts[-3]
            old_frame = i
        elif int(path_parts[-1][:-4]) % sub == 0:
            #do flow
            estimate_flow(old_frame, i, base_name, model)
            old_frame = i
        # Stop after the first 15001 paths.
        if j > 15000:
            break
        # Progress: index and elapsed seconds every 100 paths.
        if j % 100 == 0:
            print(j, '   ', time.time()-st)

0     0.00011038780212402344
100     8.984097242355347
200     18.05113124847412
300     27.127877473831177
400     36.171480655670166
500     45.207042932510376
600     54.221131563186646
700     63.2520751953125
800     72.24372053146362
900     81.31683683395386
1000     90.36558198928833
1100     99.36629748344421
1200     108.36039566993713
1300     117.33083009719849
1400     126.28037428855896
1500     135.25484585762024
1600     144.2170271873474
1700     153.19317841529846
1800     162.25288367271423
1900     171.31200242042542
2000     180.368825674057
2100     189.38590002059937
2200     198.36359214782715
2300     207.36214542388916
2400     216.4364733695984
2500     225.50802564620972
2600     234.57038927078247
2700     243.5932171344757
2800     252.67933797836304
2900     261.6879687309265
3000     270.6608304977417
3100     279.6450300216675
3200     288.61990547180176
3300     297.64190697669983
3400     306.6199269294739
3500     315.6156539916992
3600     324.59346

In [78]:
def readFlowKITTI(filename):
    """Read a flow file written in the 16-bit, 3-channel image encoding.

    The image is loaded unchanged in depth, its channel order is reversed,
    and the first two channels are decoded with ``(value - 2**15) / 64``
    (the inverse of the mapping used by ``writeFlowKITTI``). The third
    channel is returned as-is.

    Args:
        filename: Path of the encoded flow image.

    Returns:
        Tuple ``(flow, valid)``: ``flow`` is a ``float32`` array of shape
        ``(H, W, 2)``; ``valid`` is the ``float32`` ``(H, W)`` third channel.

    Raises:
        OSError: If ``cv2.imread`` returns ``None`` (it signals a missing or
            unreadable file that way instead of raising).
    """
    flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)
    if flow is None:
        # Without this check the slicing below fails with an opaque TypeError.
        raise OSError(f"cv2.imread could not read flow file: {filename}")
    flow = flow[:,:,::-1].astype(np.float32)
    flow, valid = flow[:, :, :2], flow[:, :, 2]
    flow = (flow - 2**15) / 64.0
    return flow, valid

# Sanity check: read one of the exported flow files back in and inspect its shape.
# NOTE(review): the recorded output is (488, 648, 2) although `tra` resizes to 484 rows --
# presumably the 4 extra rows are InputPadder padding that was never removed; confirm.
f,v = readFlowKITTI("/home/jonfrey/datasets/scannet/scans/scene0000_00/flow_sub_10/flow_up_0.png")
f.shape

(488, 648, 2)