# In [None]:
import sys
sys.path.append('core')

import argparse
import os
import cv2
import glob
import numpy as np
import torch
from PIL import Image

from raft import RAFT
from utils import flow_viz
from utils.utils import InputPadder
from torchvision import transforms as tf 
import torch
import time 
from pathlib import Path
import imageio

import cv2
from matplotlib import pyplot as plt

# Device on which the model and all input tensors are placed.
DEVICE = 'cuda:0'

# Load an example image. "input_path" is a placeholder to replace with a real
# file; cv2.imread returns None (it does not raise) when the file cannot be
# read. The image is not displayed anywhere in this cell.
image = cv2.imread("input_path")

def demo(args):
    """Run RAFT on every pair of consecutive images found in ``args.path``.

    Builds ``RAFT(args)``, loads the checkpoint at ``args.model`` and moves the
    model to ``DEVICE`` in eval mode. Every ``*.png`` and ``*.jpg`` file in
    ``args.path`` is sorted by path. Each neighbouring pair is then:

    * loaded with ``load_image``,
    * padded with ``InputPadder``,
    * passed through the model (20 iterations, test mode) under
      ``torch.no_grad()``.

    The resulting flows are neither stored nor returned; the visualization
    call is left disabled.

    Args:
        args: attribute-style config with at least ``model`` (checkpoint path)
            and ``path`` (image directory), plus whatever fields ``RAFT`` reads.
    """
    # A later notebook cell runs `from glob import glob`. That rebinds the
    # module-level name `glob` to the function, so `glob.glob` would raise
    # AttributeError here. Import the module under a local name instead.
    import glob as glob_module

    # Load through DataParallel so the checkpoint keys match (presumably it was
    # saved from a DataParallel-wrapped model), then unwrap for single-device use.
    model = torch.nn.DataParallel(RAFT(args))
    model.load_state_dict(torch.load(args.model))

    model = model.module
    model.to(DEVICE)
    model.eval()

    with torch.no_grad():
        images = sorted(
            glob_module.glob(os.path.join(args.path, '*.png'))
            + glob_module.glob(os.path.join(args.path, '*.jpg'))
        )

        for imfile1, imfile2 in zip(images[:-1], images[1:]):
            image1 = load_image(imfile1)
            image2 = load_image(imfile2)

            # Pad both images to a common shape via InputPadder (the padding
            # rule itself lives in utils.utils).
            padder = InputPadder(image1.shape)
            image1, image2 = padder.pad(image1, image2)

            flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
            # To inspect a result, call: viz(image1, flow_up)

# RAFT settings: checkpoint path plus architecture/runtime flags.
di = dict(
    model='models/raft-kitti.pth',
    small=False,
    mixed_precision=True,
    alternate_corr=False,
)
    
class DotDict(dict):
    """A ``dict`` whose keys can also be read, set and deleted as attributes.

    Reading a missing attribute returns ``None`` (like ``dict.get``) rather
    than raising ``AttributeError``. Deleting a missing attribute raises
    ``KeyError``.
    """

    def __getattr__(self, key):
        # Only reached when normal attribute lookup fails.
        return self.get(key)

    def __setattr__(self, key, value):
        self[key] = value

    def __delattr__(self, key):
        del self[key]
    
args = DotDict(**di)  # attribute-style access to the RAFT settings above

# In [None]:
import matplotlib.pyplot as plt
from time import time
from pathlib import Path
from glob import glob
import matplotlib

base = "/home/jonfrey/datasets/scannet"


def fun(path):
    """Sort key ``<last 7 chars of scene dir>_<frame number padded to 6>``.

    The frame number is the file name minus its last 4 characters, left-padded
    with zeros to width 6. For example, ``.../scene0000_00/color/12.jpg`` maps
    to ``'0000_00_000012'``.
    """
    parts = path.split('/')
    frame = parts[-1][:-4]
    return parts[-3][-7:] + '_' + frame.rjust(6, '0')


# All color JPEGs of scene0000_00, ordered by scene and then frame number.
image_pths = sorted(
    (p for p in map(str, glob(base + '/**/*.jpg', recursive=True))
     if 'color' in p and 'scene0000_00' in p),
    key=fun,
)
image_pths

# In [None]:
DEVICE = 'cuda:0'
# Point the config at the raft-sintel.pth checkpoint.
args.model = '/media/scratch1/jonfrey/results/rpose/models/raft-sintel.pth'
# Notebook display: the last path in the list (IndexError if it is empty).
image_pths[-1]

# In [None]:
# Build RAFT from `args` and load the checkpoint through a DataParallel wrapper
# (presumably the checkpoint keys carry its `module.` prefix). Then unwrap the
# model, move it to DEVICE and switch it to inference mode.
model = torch.nn.DataParallel(RAFT(args))
model.load_state_dict(torch.load(args.model))
model = model.module.to(DEVICE)
model.eval()

# In [None]:
# Resize pipelines (not used elsewhere in the visible code): `tra` resizes to
# 484x648, and `tra_up` to 968x1296, exactly twice that size.
tra = torch.nn.Sequential(tf.Resize((484, 648)))
tra_up = torch.nn.Sequential(tf.Resize((968, 1296)))

def writeFlowKITTI(filename, uv):
    """Write a dense flow field as a 16-bit, 3-channel image (KITTI-style encoding).

    Each flow component is stored as ``64 * flow + 2**15``, clipped to the
    uint16 range. A third channel set to 1 marks every pixel as valid. The
    channel order is reversed before ``cv2.imwrite`` (OpenCV uses BGR order).

    Args:
        filename: output path; use ``.png`` so the 16-bit values are kept.
        uv: ``(H, W, 2)`` flow field. Anything ``np.asarray`` accepts works,
            e.g. a CPU torch tensor.
    """
    # Explicit conversion: callers pass CPU torch tensors.
    uv = 64.0 * np.asarray(uv) + 2**15
    # Values outside [0, 65535] (|flow| >= 512 px) would wrap around or be
    # undefined when cast to uint16; saturate them instead.
    uv = np.clip(uv, 0, 2**16 - 1)
    valid = np.ones([uv.shape[0], uv.shape[1], 1])
    encoded = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
    cv2.imwrite(filename, encoded[..., ::-1])

def viz(img, flo):
    """Show the first image of a batch stacked above its flow visualization.

    Args:
        img: batched channel-first image tensor; only ``img[0]`` is shown. The
            values are divided by 255 for display, so presumably in 0-255.
        flo: batched channel-first flow tensor; ``flo[0]`` is rendered to RGB
            with ``flow_viz.flow_to_image``.
    """
    frame = img[0].permute(1, 2, 0).cpu().numpy()
    flow_hwc = flo[0].permute(1, 2, 0).cpu().numpy()
    flow_rgb = flow_viz.flow_to_image(flow_hwc)
    # axis=0 stacks along the height axis: image on top, flow below.
    stacked = np.concatenate((frame, flow_rgb), axis=0)
    plt.imshow(stacked / 255)
    plt.show()

def load_image(imfile, crop=12):
    """Load an image as a batched, channel-first float tensor on ``DEVICE``.

    Pixel values are converted to ``float`` without rescaling. The tensor
    shape after the channel-first permute is printed. Then ``crop`` pixels
    are trimmed from every border and a leading batch axis is added.

    Args:
        imfile: path of the image file, read with PIL.
        crop: pixels removed from each of the four borders. The default of 12
            is the previously hard-coded value; ``0`` keeps the full image.

    Returns:
        Tensor of shape ``(1, 3, H - 2*crop, W - 2*crop)`` on ``DEVICE``.

    Raises:
        ValueError: if ``crop`` is negative.
    """
    if crop < 0:
        raise ValueError(f"crop must be non-negative, got {crop}")
    # The context manager closes the file handle. np.array() forces PIL to
    # decode the pixels while the file is still open.
    with Image.open(imfile) as pil_img:
        # permute(2, 0, 1) below needs an H x W x 3 array. Grayscale and
        # palette images would decode to 2-D arrays, and RGBA to 4 channels.
        rgb = pil_img if pil_img.mode == 'RGB' else pil_img.convert('RGB')
        arr = np.array(rgb).astype(np.uint8)
    img = torch.from_numpy(arr).permute(2, 0, 1).float()
    print(img.shape)
    # Use explicit end indices, because `img[:, c:-c]` is empty when c == 0.
    _, height, width = img.shape
    img = img[:, crop:height - crop, crop:width - crop]
    return img[None].to(DEVICE)


def estimate_flow(f1, f2, name, model, iters=12):
    """Compute RAFT flow from image ``f1`` to ``f2`` and save both outputs.

    Both images are loaded with ``load_image``. The model runs in test mode
    with ``iters`` refinement iterations, without gradient tracking. Two files
    are written with ``writeFlowKITTI`` into ``<grandparent of f1>/<name>/``:

    * ``flow_low_<stem>.png`` — the model's first (low-resolution) output,
    * ``flow_up_<stem>.png``  — the model's second (upsampled) output.

    Here ``<stem>`` is the file name of ``f1`` without its last 4 characters
    (the extension, e.g. ``.jpg``). The output directory is created if
    missing, but its parent must already exist. Relative ``f1`` paths stay
    relative; previously a leading ``/`` was always prepended.

    Args:
        f1: path of the source frame.
        f2: path of the target frame.
        name: name of the output sub-directory.
        model: RAFT model returning ``(flow_low, flow_up)`` in test mode.
        iters: number of refinement iterations (default 12).

    Raises:
        FileNotFoundError: if the grandparent directory of ``f1`` does not exist.
    """
    with torch.no_grad():
        image1 = load_image(f1)
        image2 = load_image(f2)

        flow_low, flow_up = model(image1, image2, iters=iters, test_mode=True)

    src = Path(f1)
    stem = src.name[:-4]
    out_dir = src.parent.parent / name
    out_dir.mkdir(parents=False, exist_ok=True)
    p1 = str(out_dir / f"flow_low_{stem}.png")
    p2 = str(out_dir / f"flow_up_{stem}.png")
    print(p1)
    # Drop the batch axis and go to H x W x 2 for writing.
    writeFlowKITTI(p1, flow_low.detach().cpu()[0].permute(1, 2, 0))
    writeFlowKITTI(p2, flow_up.detach().cpu()[0].permute(1, 2, 0))
    
    
# Compute optical flow for every `sub`-th frame of each scene in `image_pths`.
# Each flow runs from the previously used frame of the same scene to the
# current one, and is written to a `flow_sub_<sub>` folder by `estimate_flow`.
# An earlier cell's `from time import time` rebinds the module-level name
# `time` to the function, so `time.time()` would raise AttributeError here.
# Import the module under a private name instead.
import time as _time_module

sub = 10                        # subsampling step in frame numbers
scene = "xxx"                   # sentinel: never equals a real scene folder
old_frame = None                # source frame for the next flow estimate
base_name = f"flow_sub_{sub}"   # output sub-directory name
st = _time_module.time()
with torch.no_grad():
    for j, i in enumerate(image_pths):
        parts = i.split('/')
        if parts[-3] == scene:
            # Only frames whose number is a multiple of `sub` are processed.
            if int(parts[-1][:-4]) % sub == 0:
                estimate_flow(old_frame, i, base_name, model)
                old_frame = i
        else:
            # First frame of a new scene: remember it as the flow source.
            scene = parts[-3]
            old_frame = i
        if j > 1000000:
            break
        if j % 100 == 0:
            print(j, '   ', _time_module.time() - st)