# Example prediction on a single test image

This notebook gives example code for making a disparity prediction for a single test image.

The file `test_simple.py` shows a more complete version of this code, which additionally:
- Can run on GPU or CPU (this notebook only runs on CPU)
- Can predict for a whole folder of images, not just a single image
- Saves predictions to `.npy` files and disparity images.

In [None]:
from __future__ import absolute_import, division, print_function

import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image as pil
import torch
from torchvision import transforms

import networks

## Setting up the network and loading weights

In [None]:
# Root directory containing the trained runs, and the run to load from it.
model_path = "/media/data/datasets/penitto/networks/monodepth2"
model_name = "evo_scratch"

In [None]:
# Checkpoint files of the chosen run, both stored under models/weights_28.
encoder_path, depth_decoder_path = (
    os.path.join(model_path, model_name, "models", "weights_28", fname)
    for fname in ("encoder.pth", "depth.pth")
)

In [None]:
# Build the networks; the trained weights are loaded into them in the next cell.
# NOTE(review): the positional args presumably mean "ResNet depth 50" and
# "no pretrained initialisation" — confirm against networks.ResnetEncoder.
encoder = networks.ResnetEncoder(50, False)
# The decoder is sized from the encoder's channel counts; scales=range(4)
# presumably requests outputs at scales 0-3 (the prediction cell reads scale 0).
depth_decoder = networks.DepthDecoder(num_ch_enc=encoder.num_ch_enc, scales=range(4))

In [None]:
# LOADING PRETRAINED MODEL
# Both checkpoints are mapped onto the CPU. torch.load unpickles the file, so
# only load checkpoints from trusted locations.


loaded_dict_enc = torch.load(encoder_path, map_location='cpu')
# The encoder checkpoint also holds non-parameter entries (the prediction cell
# reads 'height' and 'width' from it), so keep only keys the model defines.
# Fetch the model's key set once rather than rebuilding the whole state dict
# for every checkpoint entry.
encoder_keys = encoder.state_dict().keys()
filtered_dict_enc = {
    k: v for k, v in loaded_dict_enc.items() if k in encoder_keys
}
encoder.load_state_dict(filtered_dict_enc)

loaded_dict = torch.load(depth_decoder_path, map_location='cpu')
depth_decoder.load_state_dict(loaded_dict)

# Switch both networks to inference behaviour (affects layers such as
# batch-norm and dropout).
encoder.eval()
depth_decoder.eval();

In [None]:
# Hard-coded reference total, presumably the expected encoder + decoder
# parameter counts; compare with the two counts computed below.
sum((25557032, 9014100))

In [None]:
# Total number of scalar parameters in the encoder.
sum(map(torch.numel, encoder.parameters()))

In [None]:
# Total number of scalar parameters in the depth decoder.
sum(map(torch.numel, depth_decoder.parameters()))

In [None]:
def readlines(filename):
    """Return the lines of the text file ``filename`` as a list.

    Line terminators are stripped (``str.splitlines`` semantics).
    """
    with open(filename, 'r') as handle:
        text = handle.read()
    return text.splitlines()

In [None]:
# Evaluation list: one entry per line; later cells join each line's first two
# whitespace-separated tokens with '_' to form an image file name.
lst = readlines(filename='/home/penitto/mono_depth/eval_imgs/ev_img.txt')

In [None]:
# Display the entries read from the evaluation list.
lst

In [None]:
def disp_to_depth(disp, min_depth, max_depth):
    """Map the network's sigmoid output to a (disparity, depth) pair.

    ``disp`` is rescaled linearly onto the disparity range
    [1 / max_depth, 1 / min_depth]; depth is the reciprocal of that scaled
    disparity. The formula is given in the 'additional considerations'
    section of the paper.

    Returns:
        (scaled_disp, depth)
    """
    lowest_disp = 1 / max_depth
    highest_disp = 1 / min_depth
    scaled = lowest_disp + (highest_disp - lowest_disp) * disp
    return scaled, 1 / scaled

In [None]:
def normalize_image(x):
    """Min-max rescale ``x`` so that its values span [0, 1].

    A constant input (max == min) is divided by 1e5 instead of zero, which
    yields all zeros.
    """
    top = float(x.max())
    bottom = float(x.min())
    if top == bottom:
        span = 1e5
    else:
        span = top - bottom
    return (x - bottom) / span

In [None]:
# Predict for the first entry of ``lst`` only (widen the slice for more).
# The encoder checkpoint records the input resolution the network expects;
# every image is resized to it before the forward pass. Read it once.
feed_height = loaded_dict_enc['height']
feed_width = loaded_dict_enc['width']

for entry in lst[:1]:
    tokens = entry.split()
    # The image file name is the first two tokens of the entry joined by '_'.
    image_path = (
        "/home/penitto/mono_depth/eval_imgs/" + tokens[0] + '_' + tokens[1] + '.jpg'
    )

    # Decode to RGB and release the file handle straight away.
    with pil.open(image_path) as raw_image:
        input_image = raw_image.convert('RGB')
    original_width, original_height = input_image.size

    input_image_resized = input_image.resize((feed_width, feed_height), pil.LANCZOS)
    # ToTensor gives a (C, H, W) float tensor in [0, 1]; add a batch dimension.
    input_image_pytorch = transforms.ToTensor()(input_image_resized).unsqueeze(0)

    with torch.no_grad():
        features = encoder(input_image_pytorch)
        outputs = depth_decoder(features)

    # Scale-0 disparity output of the decoder.
    disp = outputs[("disp", 0)]

    # Output path for a colour-mapped prediction (this cell does not write it).
    new_path = os.path.splitext(image_path)[0] + '_depth.png'

    # Upsample the prediction back to the original image resolution.
    disp_resized = (
        torch.nn.functional.interpolate(
            disp, (original_height, original_width), mode="bilinear", align_corners=False
        )
        .squeeze()
        .cpu()
        .numpy()
    )

    # Invert the min-max normalised disparity to get a relative, unscaled
    # depth-like map. normalize_image maps the smallest disparity to exactly
    # 0, so clamp the denominator to keep that pixel finite instead of inf.
    disp_resized_np = 1 / np.maximum(normalize_image(disp_resized), 1e-6)
    # Clip the colour range at the 99th percentile so a few far pixels do not
    # wash out the rest of the map.
    vmax = np.percentile(disp_resized_np, 99)

    plt.figure(figsize=(10, 10))
    plt.subplot(211)
    plt.imshow(input_image)
    plt.title("Input", fontsize=22)
    plt.axis('off')

    plt.subplot(212)
    plt.imshow(disp_resized_np, cmap='magma', vmax=vmax)
    plt.title("Inverse normalised disparity", fontsize=22)
    plt.axis('off');

## Prediction using the PyTorch model

In [None]:
# Display the output path derived for the last image processed by the
# prediction loop.
new_path

## Plotting

In [None]:
# Height and width of one evaluation image. The path is relative to the
# notebook's working directory. cv2.imread returns None rather than raising
# when a file is missing or unreadable, so fail with a clear error instead of
# an AttributeError on ``.shape``.
probe_path = (
    '../../eval_imgs/ckad_01_ckad_2020-10-29-17-01-56_0_1603980116931038171.jpg'
)
probe_image = cv2.imread(probe_path)
if probe_image is None:
    raise FileNotFoundError("cv2.imread could not read {!r}".format(probe_path))
probe_image.shape[:2]

In [None]:
# Re-save every PNG listed in ``lst`` (presumably the depth maps for the
# evaluation images) as a colour-mapped ``<name>_alt.png`` next to it.
for entry in lst:
    tokens = entry.split()
    stem = "/home/penitto/mono_depth/eval_imgs/" + tokens[0] + '_' + tokens[1]
    image_path = stem + '.png'
    save_path = stem + '_alt.png'
    # Copy the pixels out and close the file immediately; a bare pil.open
    # keeps one handle open per image until garbage collection.
    with pil.open(image_path) as img:
        input_image = np.array(img)
    # Clip the colour range at the 95th percentile of the pixel values.
    vmax = np.percentile(input_image, 95)
    plt.imsave(save_path, input_image, cmap='magma', vmax=vmax)

In [None]:
# Colour-map the last image loaded by the loop above, using its 95th
# percentile as the upper limit, and save it to a fixed path.
vmax = np.percentile(a=input_image, q=95)
plt.imsave(fname='/home/penitto/s.png', arr=input_image, cmap='magma', vmax=vmax)