In [2]:
import glob
import os

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

import model_io
import utils
from models import UnetAdaptiveBins


def _is_pil_image(img):
    return isinstance(img, Image.Image)


def _is_numpy_image(img):
    return isinstance(img, np.ndarray) and (img.ndim in {2, 3})


class ToTensor(object):
    def __init__(self):
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def __call__(self, image, target_size=(640, 480)):
        # image = image.resize(target_size)
        image = self.to_tensor(image)
        image = self.normalize(image)
        return image

    def to_tensor(self, pic):
        if not (_is_pil_image(pic) or _is_numpy_image(pic)):
            raise TypeError(
                'pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

        if isinstance(pic, np.ndarray):
            img = torch.from_numpy(pic.transpose((2, 0, 1)))
            return img

        # handle PIL Image
        if pic.mode == 'I':
            img = torch.from_numpy(np.array(pic, np.int32, copy=False))
        elif pic.mode == 'I;16':
            img = torch.from_numpy(np.array(pic, np.int16, copy=False))
        else:
            img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
        if pic.mode == 'YCbCr':
            nchannel = 3
        elif pic.mode == 'I;16':
            nchannel = 1
        else:
            nchannel = len(pic.mode)
        img = img.view(pic.size[1], pic.size[0], nchannel)

        img = img.transpose(0, 1).transpose(0, 2).contiguous()
        if isinstance(img, torch.ByteTensor):
            return img.float()
        else:
            return img


class InferenceHelper:
    def __init__(self, dataset='nyu', device='cuda:0'):
        self.toTensor = ToTensor()
        self.device = device
        if dataset == 'nyu':
            self.min_depth = 1e-3
            self.max_depth = 10
            self.saving_factor = 1000  # used to save in 16 bit
            model = UnetAdaptiveBins.build(n_bins=256, min_val=self.min_depth, max_val=self.max_depth)
            pretrained_path = "./pretrained/AdaBins_nyu.pt"
        elif dataset == 'kitti':
            self.min_depth = 1e-3
            self.max_depth = 80
            self.saving_factor = 256
            model = UnetAdaptiveBins.build(n_bins=256, min_val=self.min_depth, max_val=self.max_depth)
            pretrained_path = "./pretrained/AdaBins_kitti.pt"
        else:
            raise ValueError("dataset can be either 'nyu' or 'kitti' but got {}".format(dataset))

        model, _, _ = model_io.load_checkpoint(pretrained_path, model)
        model.eval()
        self.model = model.to(self.device)

    @torch.no_grad()
    def predict_pil(self, pil_image, visualized=False):
        # pil_image = pil_image.resize((640, 480))
        img = np.asarray(pil_image) / 255.

        img = self.toTensor(img).unsqueeze(0).float().to(self.device)
        bin_centers, pred = self.predict(img)

        if visualized:
            viz = utils.colorize(torch.from_numpy(pred).unsqueeze(0), vmin=None, vmax=None, cmap='magma')
            # pred = np.asarray(pred*1000, dtype='uint16')
            viz = Image.fromarray(viz)
            return bin_centers, pred, viz
        return bin_centers, pred

    @torch.no_grad()
    def predict(self, image):
        bins, pred = self.model(image)
        pred = np.clip(pred.cpu().numpy(), self.min_depth, self.max_depth)

        # Flip
        image = torch.Tensor(np.array(image.cpu().numpy())[..., ::-1].copy()).to(self.device)
        pred_lr = self.model(image)[-1]
        pred_lr = np.clip(pred_lr.cpu().numpy()[..., ::-1], self.min_depth, self.max_depth)

        # Take average of original and mirror
        final = 0.5 * (pred + pred_lr)
        final = nn.functional.interpolate(torch.Tensor(final), image.shape[-2:],
                                          mode='bilinear', align_corners=True).cpu().numpy()

        final[final < self.min_depth] = self.min_depth
        final[final > self.max_depth] = self.max_depth
        final[np.isinf(final)] = self.max_depth
        final[np.isnan(final)] = self.min_depth

        centers = 0.5 * (bins[:, 1:] + bins[:, :-1])
        centers = centers.cpu().squeeze().numpy()
        centers = centers[centers > self.min_depth]
        centers = centers[centers < self.max_depth]

        return centers, final

    @torch.no_grad()
    def predict_dir(self, test_dir, out_dir):
        os.makedirs(out_dir, exist_ok=True)
        transform = ToTensor()
        all_files = glob.glob(os.path.join(test_dir, "*"))
        self.model.eval()
        for f in tqdm(all_files):
            image = np.asarray(Image.open(f), dtype='float32') / 255.
            image = transform(image).unsqueeze(0).to(self.device)

            centers, final = self.predict(image)
            # final = final.squeeze().cpu().numpy()

            final = (final * self.saving_factor).astype('uint16')
            basename = os.path.basename(f).split('.')[0]
            save_path = os.path.join(out_dir, basename + ".png")

            Image.fromarray(final).save(save_path)



In [26]:
import matplotlib.pyplot as plt
from time import time
from pathlib import Path
from glob import glob
import matplotlib


base = "/home/jonfrey/datasets/scannet"
image_pths = [str(p) for p in glob( base+'/**/*.jpg', recursive=True ) if str(p).find('color') != -1]
fun = lambda x : x.split('/')[-3][-7:] + '_'+ str( "0"*(6-len( x.split('/')[-1][:-4]))) + x.split('/')[-1][:-4]  
image_pths.sort(key=fun)
inferHelper = InferenceHelper(  dataset='nyu', device='cuda:1' )

from torchvision import transforms as tf 
import torch
tra = torch.nn.Sequential(
    tf.Resize((480,640))
)

from PIL import PIL.Image.NEAREST
tra_up = torch.nn.Sequential(
    tf.Resize((968, 1296))
)

Loading base model ()...

Using cache found in /home/jonfrey/.cache/torch/hub/rwightman_gen-efficientnet-pytorch_master


Done.
Removing last two layers (global_pool & classifier).
Building Encoder-Decoder model..Done.


In [40]:
done_pths = [str(p) for p in glob( base+'/**/*.png', recursive=True ) if str(p).find('png') != -1 and str(p).find('preview') != -1]

In [52]:
done_pths[0],image_pths[0]

done_pths_idx = [ p.split('/')[-3]+'__'+p.split('/')[-1][:-12] for p in done_pths]
done_pths_idx

['scene0089_02__2',
 'scene0089_02__11',
 'scene0089_02__0',
 'scene0089_02__5',
 'scene0089_02__8',
 'scene0089_02__7',
 'scene0089_02__4',
 'scene0089_02__3',
 'scene0089_02__10',
 'scene0089_02__1',
 'scene0089_02__6',
 'scene0089_02__9',
 'scene0009_02__2',
 'scene0009_02__11',
 'scene0009_02__0',
 'scene0009_02__5',
 'scene0009_02__8',
 'scene0009_02__7',
 'scene0009_02__4',
 'scene0009_02__3',
 'scene0009_02__10',
 'scene0009_02__1',
 'scene0009_02__6',
 'scene0009_02__9',
 'scene0025_01__2',
 'scene0025_01__11',
 'scene0025_01__0',
 'scene0025_01__5',
 'scene0025_01__8',
 'scene0025_01__7',
 'scene0025_01__4',
 'scene0025_01__3',
 'scene0025_01__10',
 'scene0025_01__1',
 'scene0025_01__6',
 'scene0025_01__9',
 'scene0091_00__2',
 'scene0091_00__11',
 'scene0091_00__0',
 'scene0091_00__5',
 'scene0091_00__8',
 'scene0091_00__7',
 'scene0091_00__4',
 'scene0091_00__3',
 'scene0091_00__10',
 'scene0091_00__1',
 'scene0091_00__6',
 'scene0091_00__9',
 'scene0074_00__2',
 'scene0074_

In [None]:
start = time()
import imageio

for j, i in enumerate( image_pths ):
    idx = i.split('/')[-3]+'__'+i.split('/')[-1][:-4]
    if idx in done_pths_idx:
        print("Already done ", idx)
#          continue
    else:
        print(idx)
    
    img = tra( Image.open( i ) )
    
    
    centers, pred = inferHelper.predict_pil(img)
    
    pred = torch.from_numpy(pred).numpy()
    
    Path( os.path.join( str(Path(i).parent.parent),'depth_estimate') ).mkdir(exist_ok=True)
    save_path = os.path.join( str(Path(i).parent.parent),'depth_estimate/'+ i.split('/')[-1] )
    store = (pred[0,0,:,:]) * 1000
    store = store.astype(np.uint16)
    save_path = save_path[:-4]+'.png'
    
    imageio.imwrite( save_path ,store) 
    
#     save_path = save_path[:-4]+'_preview.png'
#     plt.imshow(pred.squeeze(), cmap='magma_r')
#     plt.axis('off')
#     plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
#     plt.close()
    
    if j % 100 == 0:
        print(j, '/',len(image_pths), '  ',time()-start,'s')
        start = time()

Already done  scene0000_00__0
0 / 362886    1.181058406829834 s
Already done  scene0000_00__1
Already done  scene0000_00__2
Already done  scene0000_00__3
Already done  scene0000_00__4
Already done  scene0000_00__5
Already done  scene0000_00__6
Already done  scene0000_00__7
Already done  scene0000_00__8
Already done  scene0000_00__9
Already done  scene0000_00__10
Already done  scene0000_00__11
scene0000_00__12
scene0000_00__13
scene0000_00__14
scene0000_00__15
scene0000_00__16
scene0000_00__17
scene0000_00__18
scene0000_00__19
scene0000_00__20
scene0000_00__21
scene0000_00__22
scene0000_00__23
scene0000_00__24
scene0000_00__25
scene0000_00__26
scene0000_00__27
scene0000_00__28
scene0000_00__29
scene0000_00__30
scene0000_00__31
scene0000_00__32
scene0000_00__33
scene0000_00__34
scene0000_00__35
scene0000_00__36
scene0000_00__37
scene0000_00__38
scene0000_00__39
scene0000_00__40
scene0000_00__41
scene0000_00__42
scene0000_00__43
scene0000_00__44
scene0000_00__45
scene0000_00__46
scene0000