In [1]:
# Minimalistic inference demo

import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import hflip
from dataloader.stereo import transforms
from glob import glob
import os
from evaluate_stereo import _log_time_usage
from PIL import Image
import cv2
from utils.file_io import write_pfm
from utils.visualization import vis_disparity
from unimatch.unimatch import UniMatch


# Per-channel (RGB) mean / std handed to transforms.Normalize in inference_stereo;
# these are the standard ImageNet statistics.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def init_model(weights_path):
    """Build the stereo UniMatch network and load its weights from a checkpoint.

    Args:
        weights_path: Path to a checkpoint readable by ``torch.load`` whose
            ``'model'`` entry holds the state dict.

    Returns:
        The ``UniMatch`` model, moved to the module-level ``device``, with the
        checkpoint weights loaded (non-strictly).
    """
    model = UniMatch(feature_channels=128,
                     num_scales=2,
                     upsample_factor=4,
                     num_head=1,
                     ffn_dim_expansion=4,
                     num_transformer_layers=6,
                     reg_refine=True,
                     task='stereo').to(device)
    loc = 'cuda:{}'.format(0) if torch.cuda.is_available() else 'cpu'
    checkpoint = torch.load(weights_path, map_location=loc)

    # strict=False tolerates key mismatches between checkpoint and model; report
    # them so a wrong checkpoint does not silently leave layers uninitialised.
    load_result = model.load_state_dict(checkpoint['model'], strict=False)
    if load_result.missing_keys:
        print('warning: %d model keys missing from checkpoint: %s'
              % (len(load_result.missing_keys), load_result.missing_keys))
    if load_result.unexpected_keys:
        print('warning: %d unexpected keys in checkpoint: %s'
              % (len(load_result.unexpected_keys), load_result.unexpected_keys))
    return model


def _list_images(directory):
    """Return the sorted '*.png' and '*.jpg' paths found directly inside ``directory``."""
    return sorted(glob(directory + '/*.png') + glob(directory + '/*.jpg'))


def _save_disparity(disp, save_name, save_pfm_disp):
    """Write the visualised ``disp`` to ``save_name`` and, optionally, the raw values to a sibling '.pfm'."""
    if save_pfm_disp:
        write_pfm(os.path.splitext(save_name)[0] + '.pfm', disp)

    # cv2.imwrite reports failure through its return value instead of raising.
    if not cv2.imwrite(save_name, vis_disparity(disp)):
        print('warning: failed to write %s' % save_name)


@torch.no_grad()
def inference_stereo(model,
                     inference_dir=None,
                     inference_dir_left=None,
                     inference_dir_right=None,
                     output_path='output',
                     padding_factor=32,
                     inference_size=[1024, 1280],
                     attn_type='self_swin2d_cross_swin1d',
                     attn_splits_list=[2, 8],
                     corr_radius_list=[-1, 4],
                     prop_radius_list=[-1, 1],
                     num_reg_refine=3,
                     pred_bidir_disp=True,
                     pred_right_disp=False,
                     save_pfm_disp=False,
                     ):
    """Predict disparity for every stereo pair in a folder and save the results.

    Images are taken either from ``inference_dir`` (sorted '*.png'/'*.jpg'
    files, even positions are left images and odd positions right images) or
    from ``inference_dir_left`` / ``inference_dir_right`` (sorted lists paired
    by position).

    Args:
        model: Network called as ``model(left, right, ...)``; the last entry of
            its ``'flow_preds'`` output is used as the disparity.
        inference_dir: Folder holding interleaved left/right images.
        inference_dir_left: Folder with left images (used when
            ``inference_dir`` is not given).
        inference_dir_right: Folder with right images (used when
            ``inference_dir`` is not given).
        output_path: Output folder; created if it does not exist.
        padding_factor: When ``inference_size`` is None, inputs are resized so
            height and width are rounded up to a multiple of this value.
        inference_size: ``[height, width]`` the inputs are resized to before
            the forward pass, or None to use the rounded-up original size.
        attn_type, attn_splits_list, corr_radius_list, prop_radius_list,
        num_reg_refine: Forwarded unchanged to ``model``.
        pred_bidir_disp: Also feed the flipped, swapped pair as a second batch
            item and save its prediction (flipped back) as
            '<name>_disp_right.png'.
        pred_right_disp: Feed the flipped, swapped pair instead of the original
            one and flip the prediction back.
        save_pfm_disp: Additionally write each saved disparity as '.pfm'.

    Raises:
        AssertionError: If no input folder is given, or the numbers of left
            and right images differ.
    """
    model.eval()

    val_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])

    # Validate the input mode before touching the filesystem.
    assert inference_dir or (inference_dir_left and inference_dir_right), \
        'provide inference_dir, or both inference_dir_left and inference_dir_right'

    os.makedirs(output_path, exist_ok=True)

    # Same truthiness test as the assertion above, so an empty inference_dir
    # falls through to the left/right folders instead of globbing '/*.png'.
    if inference_dir:
        filenames = _list_images(inference_dir)
        left_filenames = filenames[::2]
        right_filenames = filenames[1::2]
    else:
        left_filenames = _list_images(inference_dir_left)
        right_filenames = _list_images(inference_dir_right)

    assert len(left_filenames) == len(right_filenames), \
        'number of left (%d) and right (%d) images differ' % (len(left_filenames),
                                                              len(right_filenames))

    num_samples = len(left_filenames)
    print('%d test samples found' % num_samples)

    for i, (left_name, right_name) in enumerate(zip(left_filenames, right_filenames)):

        if (i + 1) % 50 == 0:
            print('predicting %d/%d' % (i + 1, num_samples))

        # Output files are named after the left image, without its extension.
        base_name = os.path.splitext(os.path.basename(left_name))[0]

        # The timed region covers image loading, preprocessing, the forward
        # pass and the copy of the prediction back to host memory.
        with _log_time_usage('Unimatch Inference Time in seconds: '):
            left = np.array(Image.open(left_name).convert('RGB')).astype(np.float32)
            right = np.array(Image.open(right_name).convert('RGB')).astype(np.float32)
            sample = val_transform({'left': left, 'right': right})

            left = sample['left'].to(device).unsqueeze(0)  # [1, 3, H, W]
            right = sample['right'].to(device).unsqueeze(0)  # [1, 3, H, W]

            ori_size = left.shape[-2:]

            # Resize to the requested size, or to the next multiple of
            # padding_factor when no size was requested.
            if inference_size is None:
                target_size = [int(np.ceil(left.size(-2) / padding_factor)) * padding_factor,
                               int(np.ceil(left.size(-1) / padding_factor)) * padding_factor]
            else:
                target_size = inference_size

            needs_resize = target_size[0] != ori_size[0] or target_size[1] != ori_size[1]
            if needs_resize:
                left = F.interpolate(left, size=target_size,
                                     mode='bilinear',
                                     align_corners=True)
                right = F.interpolate(right, size=target_size,
                                      mode='bilinear',
                                      align_corners=True)

            # Gradients are already disabled by the @torch.no_grad() decorator.
            if pred_bidir_disp:
                new_left, new_right = hflip(right), hflip(left)
                left = torch.cat((left, new_left), dim=0)
                right = torch.cat((right, new_right), dim=0)

            if pred_right_disp:
                left, right = hflip(right), hflip(left)

            pred_disp = model(left, right,
                              attn_type=attn_type,
                              attn_splits_list=attn_splits_list,
                              corr_radius_list=corr_radius_list,
                              prop_radius_list=prop_radius_list,
                              num_reg_refine=num_reg_refine,
                              task='stereo',
                              )['flow_preds'][-1]  # [B, H, W]; B is 2 with pred_bidir_disp

            if needs_resize:
                # Resize back, and rescale the disparity values by the width
                # ratio since they are measured in pixels.
                pred_disp = F.interpolate(pred_disp.unsqueeze(1), size=ori_size,
                                          mode='bilinear',
                                          align_corners=True).squeeze(1)  # [B, H, W]
                pred_disp = pred_disp * ori_size[-1] / float(target_size[-1])

            if pred_right_disp:
                pred_disp = hflip(pred_disp)

            disp = pred_disp[0].cpu().numpy()

        _save_disparity(disp, os.path.join(output_path, base_name + '_disp.png'), save_pfm_disp)

        if pred_bidir_disp:
            assert pred_disp.size(0) == 2, \
                'expected 2 predictions for bidirectional mode, got %d' % pred_disp.size(0)

            # The second batch item was predicted on flipped inputs; flip back.
            disp = hflip(pred_disp[1]).cpu().numpy()
            _save_disparity(disp, os.path.join(output_path, base_name + '_disp_right.png'),
                            save_pfm_disp)

    print('Done!')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Test the inference defined above.
# The defaults are the original machine-specific locations; set the environment
# variables below to run the notebook elsewhere without editing this cell.
weights_path = os.environ.get(
    'UNIMATCH_WEIGHTS_PATH',
    '/home/aboggaram/projects/unimatch/pretrained/gmstereo-scale2-regrefine3-resumeflowthings-mixdata-train320x640-ft640x960-e4e291fd.pth')
input_folder = os.environ.get(
    'UNIMATCH_INPUT_DIR',
    '/opt/iunu_edge_vision/data/test_data_octiva')
output_folder = os.environ.get(
    'UNIMATCH_OUTPUT_DIR',
    '/opt/iunu_edge_vision/data/test_data_octiva_gmstereo_output')
model = init_model(weights_path=weights_path)

In [3]:
# Run inference on the interleaved left/right images in input_folder and write
# the results to output_folder. pred_bidir_disp=False overrides the function's
# default of True, so only the '<name>_disp.png' output is produced per pair.
inference_stereo(
    model=model,
    inference_dir=input_folder,
    output_path=output_folder,
    pred_bidir_disp=False,
)

1 test samples found


To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at  /pytorch/aten/src/ATen/native/BinaryOps.cpp:467.)
  return torch.floor_divide(self, other)


elapsed seconds Unimatch Inference Time in seconds:  1.36
Done!
