In [None]:
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.autograd import Variable
from torch.autograd import grad as Grad
from torchvision import transforms
import skimage.io
import os
import copy
from collections import OrderedDict
from tqdm import tqdm, trange
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import argparse


In [7]:
from models import __models__
from datasets import __datasets__

import os
from dataclasses import dataclass, field
from utils import *
from torch.utils.data import DataLoader


In [3]:
# Expose only GPU 0 to this process. The variable is written before any CUDA
# call is made in the notebook.
os.environ.update({'CUDA_VISIBLE_DEVICES': '0'})

In [None]:
# Command-line style configuration for the CGI-Stereo evaluation cells.
# Every option has a default, so the parser yields a usable namespace even
# when no arguments are supplied.
parser = argparse.ArgumentParser(description='Accurate and Real-Time Stereo Matching via Context and Geometry Interaction (CGI-Stereo)')
parser.add_argument('--model', default='CGI_Stereo', help='select a model structure', choices=__models__.keys())
parser.add_argument('--maxdisp', type=int, default=192, help='maximum disparity')
parser.add_argument('--datapath', default="/data/KITTI/KITTI_2015/training/", help='data path')
parser.add_argument('--kitti', type=str, default='2015')
parser.add_argument('--loadckpt', default='./pretrained_models/CGI_Stereo/sceneflow.ckpt',help='load the weights from a specific checkpoint')

# Inside a Jupyter kernel sys.argv holds the kernel's own arguments
# (e.g. `-f <connection file>`). parse_args() rejects those and terminates
# the cell with SystemExit. parse_known_args() still honours real overrides
# and simply sets aside anything this parser does not define.
args, _unrecognized_argv = parser.parse_known_args()


In [6]:
from pathlib import Path


@dataclass
class Args:
    """Notebook-local run configuration; `args = Args()` replaces the argparse namespace.

    The first three fields keep their original order so positional
    construction stays compatible. The remaining fields are the attributes the
    later cells read from `args` (`maxdisp`, `datapath`, `test_batch_size`,
    `loadckpt`, `kitti`); without them those cells fail with AttributeError.
    Training-only options (learning rate, epochs, logging, seed, ...) are not
    needed by this notebook and are intentionally left out.
    """
    # Key into the `__models__` registry.
    model: str           = 'CGI_Stereo'
    # Key into the `__datasets__` registry.
    dataset: str         = 'scape_pipes'
    # List file handed to the dataset class together with `datapath`.
    testlist: str        = "full_aug_combined_rectified_scape_dataset_test.txt"
    # Passed to the model constructor and used as the upper bound of valid
    # ground-truth disparities during evaluation.
    maxdisp: int         = 192
    # Dataset root. NOTE(review): placeholder taken from the argparse default;
    # point this at the data actually being evaluated.
    datapath: str        = "/data/KITTI/KITTI_2015/training/"
    # Batch size of the test DataLoader; also used to map batch positions back
    # to file names when saving predictions.
    test_batch_size: int = 2
    # Checkpoint to load.
    loadckpt: str        = './pretrained_models/CGI_Stereo/sceneflow.ckpt'
    # Selects the KITTI file-list loader ('2015' -> kt2015, anything else -> kt2012).
    kitti: str           = '2015'


args = Args()


In [None]:

# The `List[...]` annotations below are evaluated when the cell runs, so the
# name has to be in scope; import it here in case `utils` does not export it.
from typing import List

# NOTE(review): BASE_PATH is not defined in any cell visible here; presumably
# it is provided by `from utils import *` or by a removed cell. Confirm before
# running on a fresh kernel.
scan_folder        = os.path.join(BASE_PATH, 'data_3/ainstec')
image_folder_left  = os.path.join(BASE_PATH, 'data_3/multicast_rect_reso/camera1')
image_folder_right = os.path.join(BASE_PATH, 'data_3/multicast_rect_reso/camera2')
output_path        = os.path.join(BASE_PATH, 'data_3/ground_truth_pfm_rect')

os.makedirs(output_path, exist_ok=True)

# Sorted listings keep the left/right image lists in the same order.
# macOS Finder metadata files are skipped.
list_of_scans_folders: List[str] = sorted(os.listdir(scan_folder))
list_of_left_images:   List[str] = sorted([os.path.join(image_folder_left , x) for x in os.listdir(image_folder_left ) if x != '.DS_Store'])
list_of_right_images:  List[str] = sorted([os.path.join(image_folder_right, x) for x in os.listdir(image_folder_right) if x != '.DS_Store'])

# Collect every .pcd file found one level below `scan_folder`.
list_of_scans_paths: List[str] = []
for subfolder in list_of_scans_folders:
    folder_path = os.path.join(scan_folder, subfolder)
    # Skip Finder metadata and any other stray non-directory entry, which
    # os.listdir() below would otherwise choke on.
    if str(subfolder) == ".DS_Store" or not os.path.isdir(folder_path):
        continue
    pcd_files = sorted(os.path.join(folder_path, x) for x in os.listdir(folder_path) if x.endswith('.pcd'))
    # Previously `pcd_files` was computed and thrown away, leaving the result
    # list empty; keep the paths, in folder order then file-name order.
    list_of_scans_paths.extend(pcd_files)

In [None]:
# Look up the dataset class registered under the configured name and build
# the evaluation split from the data root and list file.
# NOTE(review): the third positional argument is presumably a "training" flag;
# confirm against the dataset class.
StereoDataset = __datasets__[args.dataset]
test_dataset = StereoDataset(args.datapath, args.testlist, False)

# Deterministic order and no dropped tail batch, so every sample is visited
# exactly once and batch positions can be mapped back to list positions.
TestImgLoader = DataLoader(
    test_dataset,
    batch_size=args.test_batch_size,
    shuffle=False,
    num_workers=2,
    drop_last=False,
)

In [None]:
# Choose the KITTI file-list loader for the configured benchmark year; any
# value other than '2015' falls back to the 2012 loader.
# NOTE(review): `kt2015` / `kt2012` are not imported in the cells visible
# here; presumably they come from `from utils import *` or a removed cell.
# Confirm before running on a fresh kernel.
kitti_loader = kt2015.kt2015_loader if args.kitti == '2015' else kt2012.kt2012_loader
(all_limg, all_rimg, all_ldisp,
 test_limg, test_rimg, test_ldisp) = kitti_loader(args.datapath)

In [None]:
# Debug pass over the whole loader. `sample` is indexed by string keys in the
# inference cell, so this reports the number of entries per batch container,
# not the batch size.
for batch_idx, sample in enumerate(TestImgLoader):
    entry_count = len(sample)
    print(entry_count)

In [None]:
# Build the network and load its weights once. Doing this inside the batch
# loop re-created the model and re-read the checkpoint for every batch without
# changing the result.
model = __models__[args.model](args.maxdisp)
model = nn.DataParallel(model)
model.cuda()
model.eval()

if args.loadckpt is not None:
    state_dict = torch.load(args.loadckpt)
    # Checkpoints are indexed with either 'state_dict' or 'model' in this
    # notebook; accept both, preferring the key this cell always used.
    weights_key = 'state_dict' if 'state_dict' in state_dict else 'model'
    model.load_state_dict(state_dict[weights_key], strict=False)

# Predict a disparity map for every sample and store it as a 16-bit PNG
# (disparity * 256).
for batch_idx, sample in enumerate(TestImgLoader):
    imgL = sample['left'].cuda()
    imgR = sample['right'].cuda()

    with torch.no_grad():
        disp = model(imgL, imgR)

    # The evaluation cell indexes the model output with [-1], i.e. the model
    # may return a sequence of predictions; take the last one in that case.
    if isinstance(disp, (list, tuple)):
        disp = disp[-1]

    # Collapse to (batch, H, W) while always keeping the batch axis. A bare
    # torch.squeeze() also removes a batch axis of size 1 (single-sample or
    # trailing batch), after which the loop below would iterate image rows.
    disp = disp.reshape(-1, disp.shape[-2], disp.shape[-1])
    disp = disp.detach().cpu().numpy()

    for i in range(disp.shape[0]):
        # Clip before the cast so out-of-range values saturate, not wrap.
        disp_map = np.clip(disp[i] * 256, 0, np.iinfo(np.uint16).max).astype(np.uint16)

        # NOTE(review): file names come from `test_limg`, which is built by a
        # different cell than `TestImgLoader`. Confirm both enumerate the same
        # images in the same order.
        img_name = os.path.basename(test_limg[batch_idx * args.test_batch_size + i])
        skimage.io.imsave(os.path.join(output_path, os.path.splitext(img_name)[0] + '.png'), disp_map)

    if batch_idx % 10 == 0:
        print('Iter %d' % batch_idx)


In [None]:

# Evaluate on the union of both splits returned by the KITTI loader.
# NOTE: these rebindings are not idempotent. Re-running this cell without
# re-running the loader cell grows the lists again.
test_limg = all_limg + test_limg
test_rimg = all_rimg + test_rimg
test_ldisp = all_ldisp + test_ldisp

model = __models__[args.model](args.maxdisp)
model = nn.DataParallel(model)
model.cuda()
model.eval()

state_dict = torch.load(args.loadckpt)
# Checkpoints are indexed with either 'model' or 'state_dict' in this
# notebook; accept both, preferring the key this cell always used.
weights_key = 'model' if 'model' in state_dict else 'state_dict'
model.load_state_dict(state_dict[weights_key])

# Loop invariants, built once.
m = 32          # inputs are padded up to the next multiple of this size
op_thresh = 3   # outlier threshold on the absolute disparity error
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

pred_mae = 0
pred_op = 0
n_scored = 0    # images that contributed to the running sums
for i in trange(len(test_limg)):
    limg = Image.open(test_limg[i]).convert('RGB')
    rimg = Image.open(test_rimg[i]).convert('RGB')

    # Pad on the top/left by cropping with negative offsets; the padding is
    # cut off the prediction again further down.
    w, h = limg.size
    wi, hi = (w // m + 1) * m, (h // m + 1) * m
    limg = limg.crop((w - wi, h - hi, w, h))
    rimg = rimg.crop((w - wi, h - hi, w, h))

    limg_tensor = transform(limg).unsqueeze(0).cuda()
    rimg_tensor = transform(rimg).unsqueeze(0).cuda()

    # Ground truth is stored as disparity * 256.
    disp_gt = Image.open(test_ldisp[i])
    disp_gt = np.ascontiguousarray(disp_gt, dtype=np.float32) / 256

    with torch.no_grad():
        pred_disp = model(limg_tensor, rimg_tensor)[-1]
        pred_disp = pred_disp[:, hi - h:, wi - w:]

    predict_np = pred_disp.squeeze().cpu().numpy()

    # Score only pixels with a valid ground-truth disparity.
    mask = (disp_gt > 0) & (disp_gt < args.maxdisp)
    if not mask.any():
        # No valid pixel: the ratios below would be 0/0 and turn both
        # running averages into NaN, so leave this image out.
        continue

    pred_error = np.abs(predict_np - disp_gt)[mask]
    pred_op += np.sum(pred_error > op_thresh) / pred_error.size
    pred_mae += np.mean(pred_error)
    n_scored += 1

if n_scored == 0:
    print("#### no image with valid ground-truth disparity was evaluated")
else:
    print("#### EPE", pred_mae / n_scored)
    print("#### >3.0", pred_op / n_scored)