In [8]:

from __future__ import absolute_import, division, print_function

import os
import sys
import glob
import argparse
import numpy as np
import PIL.Image as pil
import matplotlib as mpl
import matplotlib.cm as cm
import cv2

import torch
from torchvision import transforms, datasets
from haze_networks_niantic.depth_decoder import DepthDecoder
from haze_networks_niantic.resnet_encoder import ResnetEncoder

In [16]:
def gen_haze(clean_img, depth_img, beta=1.0, A = 150):
    """Synthesize a hazy image from a clean image and a depth map.

    Uses the atmospheric scattering model ``hazy = clean * t + A * (1 - t)``
    with transmission ``t = exp(-beta * depth_img / 255)``.

    Args:
        clean_img: Haze-free image as an array-like of shape (H, W) or
            (H, W, C) (anything ``np.asarray`` accepts).
        depth_img: Array-like of shape (H, W). It is divided by 255 before
            use, so values in 0-255 map to a normalized depth in 0-1.
        beta: Scattering coefficient; larger values give thicker haze.
        A: Airlight intensity blended in where transmission is low.

    Returns:
        ``np.ndarray`` of dtype ``uint8`` with the same shape as ``clean_img``.
        For multi-channel input only the first three channels are hazed; any
        further channel (e.g. alpha) is passed through unchanged.
    """
    # Work in float64 throughout so a float depth map is not truncated to the
    # (usually uint8) dtype of the clean image.
    clean = np.asarray(clean_img, dtype=np.float64)
    norm_depth = np.asarray(depth_img, dtype=np.float64) / 255
    trans_2d = np.exp(-norm_depth * beta)

    if clean.ndim == 3:
        # Transmission of 1 leaves a channel untouched; only the first three
        # channels receive the depth-dependent transmission.
        trans = np.ones_like(clean)
        trans[..., :3] = trans_2d[..., np.newaxis]
    else:
        trans = trans_2d

    hazy = clean * trans + A * (1 - trans)

    # Clip before the cast: out-of-range values (e.g. A > 255) would otherwise
    # wrap around when converted to uint8.
    return np.clip(hazy, 0, 255).astype(np.uint8)


def test_simple(model_path = "./models", image_path = "./images", output_image_path = "./depth_images", beta=0.5, airlight=150, no_cuda = False):
    """Estimate depth for one image and save a synthetically hazed copy.

    Loads the encoder/decoder checkpoints from ``model_path``, predicts a
    disparity map for the image at ``image_path``, turns it into a 0-255
    depth proxy, applies ``gen_haze`` and writes the result to
    ``<output_image_path>/<image name>_synt.jpg``.

    Args:
        model_path: Directory containing ``encoder.pth`` and ``depth.pth``.
        image_path: Path of the single input image file.
        output_image_path: Output directory; created if it does not exist.
        beta: Haze density passed to ``gen_haze`` (recommended 0.5 - 3.0;
            high beta -> thick haze, low beta -> sparse haze).
        airlight: Airlight value passed to ``gen_haze`` as ``A``.
        no_cuda: If True, run on the CPU even when CUDA is available.

    Raises:
        IOError: If OpenCV reports that the output image was not written.
    """
    use_cuda = torch.cuda.is_available() and not no_cuda
    device = torch.device("cuda" if use_cuda else "cpu")

    print("-> Loading model from ", model_path)
    encoder_path = os.path.join(model_path, "encoder.pth")
    depth_decoder_path = os.path.join(model_path, "depth.pth")

    # LOADING PRETRAINED MODEL
    # NOTE(review): torch.load unpickles the checkpoint, which can execute
    # arbitrary code. Only load checkpoints from a trusted source, and consider
    # passing weights_only=True if the installed torch version supports it.
    print("   Loading pretrained encoder")
    encoder = ResnetEncoder(18, False)
    loaded_dict_enc = torch.load(encoder_path, map_location=device)

    # EXTRACT THE HEIGHT AND WIDTH OF IMAGE THAT THIS MODEL WAS TRAINED WITH
    feed_height = loaded_dict_enc['height']
    feed_width = loaded_dict_enc['width']

    # The checkpoint carries extra entries (at least 'height' and 'width');
    # keep only the keys the encoder actually owns. The state dict is fetched
    # once instead of once per checkpoint key.
    encoder_keys = encoder.state_dict()
    filtered_dict_enc = {k: v for k, v in loaded_dict_enc.items() if k in encoder_keys}
    encoder.load_state_dict(filtered_dict_enc)
    encoder.to(device)
    encoder.eval()

    print("   Loading pretrained decoder")
    depth_decoder = DepthDecoder(
        num_ch_enc=encoder.num_ch_enc, scales=range(4))

    loaded_dict = torch.load(depth_decoder_path, map_location=device)
    depth_decoder.load_state_dict(loaded_dict)

    depth_decoder.to(device)
    depth_decoder.eval()

    # MAKE SURE THE OUTPUT FOLDER EXISTS
    os.makedirs(output_image_path, exist_ok=True)

    # LOAD IMAGE AND PREPROCESS
    input_image = pil.open(image_path).convert('RGB')
    clean_img = np.array(input_image)  # full-resolution RGB copy used for hazing
    original_width, original_height = input_image.size
    input_image = input_image.resize((feed_width, feed_height), pil.LANCZOS)
    input_image = transforms.ToTensor()(input_image).unsqueeze(0)

    # PREDICTION (single image; no gradients needed for inference)
    with torch.no_grad():
        input_image = input_image.to(device)
        features = encoder(input_image)
        outputs = depth_decoder(features)

        disp = outputs[("disp", 0)]
        # Bring the prediction back to the resolution of the original image.
        disp_resized = torch.nn.functional.interpolate(
            disp, (original_height, original_width), mode="bilinear", align_corners=False)

    # EXTRACT DEPTH IMAGE
    # The disparity is colour-mapped with 'magma' (upper bound at the 95th
    # percentile), reduced to grayscale and inverted; the inverted gray image
    # serves as the 0-255 depth proxy handed to gen_haze.
    disp_resized_np = disp_resized.squeeze().cpu().numpy()
    vmax = np.percentile(disp_resized_np, 95)
    normalizer = mpl.colors.Normalize(vmin=disp_resized_np.min(), vmax=vmax)
    mapper = cm.ScalarMappable(norm=normalizer, cmap='magma')
    colormapped_im = (mapper.to_rgba(disp_resized_np)[:, :, :3] * 255).astype(np.uint8)
    gray_colormapped_im = cv2.cvtColor(colormapped_im, cv2.COLOR_RGB2GRAY)
    inv_gray_colormapped_im = 255 - gray_colormapped_im

    # MAKE HAZY IMAGE:
    # Change degree of haze by changing 'beta' (recommended value of beta: 0.5 - 3.0)
    # High beta -> Thick haze
    # Low beta -> Sparse haze
    hazy = gen_haze(clean_img, inv_gray_colormapped_im, beta=beta, A=airlight)

    # SAVE FILES
    output_name = os.path.splitext(os.path.basename(image_path))[0]
    output_file = os.path.join(output_image_path, f"{output_name}_synt.jpg")
    # cv2.imwrite signals failure through its return value rather than raising,
    # so check it instead of reporting success unconditionally.
    if not cv2.imwrite(output_file, cv2.cvtColor(hazy, cv2.COLOR_RGB2BGR)):
        raise IOError(f"Could not write hazy image to {output_file}")

    print(f"   Processed image {image_path}")

    print(f'-> Done! Find outputs in {output_image_path}')

In [15]:
# Haze a single test image with beta=3 and airlight=200.
test_simple(
    model_path="models/mono+stereo_640x192",
    image_path="test_images/1653387338686_jpg.rf.550e4ec3ee16da6a85143e66728513eb.jpg",
    output_image_path="output_test_images",
    beta=3,
    airlight=200,
)

-> Loading model from  models/mono+stereo_640x192
   Loading pretrained encoder
   Loading pretrained decoder


  loaded_dict_enc = torch.load(encoder_path, map_location=device)
  loaded_dict = torch.load(depth_decoder_path, map_location=device)


   Processed image
-> Done! Find outputs in output_test_images
