# Depth Estimation with Masked Image Modeling

This notebook demonstrates monocular depth estimation for robotic applications. It compares the model's predictions frame by frame (PSNR) with the depth maps that the ZED SDK provides for an SVO recording.  
The model is taken from [this repository](https://github.com/SwinTransformer/MIM-Depth-Estimation/tree/main) for the paper ["Revealing the Dark Secrets of Masked Image Modeling (Depth Estimation)"](https://arxiv.org/abs/2205.13543).

## Setup

### Imports

In [None]:
#std libs
from collections import OrderedDict
import math
import os

#non std libs
import torch
from torchvision.transforms import ToTensor
import cv2 as cv
import numpy as np
import matplotlib.pyplot as plt
from pyzed import sl

In [None]:
from models.model import GLPDepth

### Parameters

Name           | Type   | Value
---------------|--------|----------------------------------------------------------------------------------------------------------------------------------------------
`SVO_FILEPATH` | `str`  | Path to the svo recording used for evaluation.
`USE_NYU`      | `bool` | `True`: Use the weights pretrained on NYU Depth V2 ("nyudepthv2")<br>`False`: Use the weights pretrained on the KITTI Eigen split
`CROP_TO_SIZE` | `bool` | `True`: Crop the image to 1216 × 352 pixels (width × height). This discards parts of the image completely.<br>`False`: Resize the image to 1216 × 352 pixels. This distorts the image.
`OUTPUT_FILE`  | `str`  | Suffix of the output CSV file for the evaluation results; the file is written as `<SVO_FILEPATH without its .svo extension>-<OUTPUT_FILE>`.

In [None]:
# User-tunable settings; see the parameter table above for details.
SVO_FILEPATH = "recording_rear.svo"  # SVO recording to evaluate
USE_NYU = True                       # True: NYU Depth V2 weights, False: KITTI weights
CROP_TO_SIZE = False                 # True: crop to 1216x352, False: resize to 1216x352
OUTPUT_FILE = "output.csv"           # appended to the recording name to form the CSV path

### Constants

In [None]:
MAX_DEPTH: float = 20.0  # maximum depth in metres, as specified for the SVO recording

### Definition of utility functions

In [None]:
def PSNR(ground_truth: np.ndarray, estimation: np.ndarray, max_value: float = 255.0) -> float:
    """Peak signal-to-noise ratio (in dB) between a reference and an estimated image.

    Only pixels where both inputs are finite contribute to the mean squared error,
    so NaN and +/-inf entries (invalid measurements) are ignored.

    Args:
        ground_truth: reference image.
        estimation: estimated image; must broadcast against ``ground_truth``.
        max_value: peak signal value used in the PSNR formula (255 for 8-bit data).

    Returns:
        The PSNR in dB. Two sentinel values are used:
        - ``-1.0`` when the valid pixels match exactly (the PSNR would be infinite).
        - ``nan`` when no pixel is valid.
    """
    # Work in float64: subtracting two uint8 arrays would wrap around instead of going negative.
    difference = np.asarray(ground_truth, dtype=np.float64) - np.asarray(estimation, dtype=np.float64);

    # NaN or inf in either input makes the difference non-finite -> exclude those pixels.
    valid = np.isfinite(difference);
    if not valid.any():
        return float("nan");

    mse = float(np.mean(difference[valid] ** 2));
    if mse == 0.0: #no noise
        return -1.0;

    return 20 * math.log10(max_value) - 10 * math.log10(mse);

In [None]:
def crop_image(img: np.ndarray, width: int = 1216, height: int = 352) -> np.ndarray:
    """Crop ``img`` to ``height`` x ``width`` pixels without rescaling it.

    The crop keeps the bottom ``height`` rows and the horizontally centred
    ``width`` columns; everything else is discarded.

    Args:
        img: image whose first two axes are rows and columns (e.g. (H, W) or (H, W, C)).
        width: width of the crop in pixels (default 1216).
        height: height of the crop in pixels (default 352).

    Returns:
        A slice (view) of ``img`` with shape (height, width, ...).

    Raises:
        ValueError: if ``img`` is smaller than the requested crop. Without this check,
            the negative margins would silently slice from the end of the array.
    """
    h_im, w_im = img.shape[:2]
    if h_im < height or w_im < width:
        raise ValueError(
            f"Image of size {w_im}x{h_im} is smaller than the crop size {width}x{height}."
        )

    margin_top = h_im - height          # keep the bottom rows
    margin_left = (w_im - width) // 2   # centre horizontally

    return img[margin_top:  margin_top  + height,
               margin_left: margin_left + width];

def _stretch_to_model_size(img: np.ndarray) -> np.ndarray:
    """Rescale ``img`` to 1216 x 352 pixels (width x height), ignoring its aspect ratio."""
    return cv.resize(img, (1216, 352))

# The strategy is picked once, when this cell runs. Cropping discards image content;
# stretching distorts the image.
resize_image = crop_image if CROP_TO_SIZE else _stretch_to_model_size

### Initialization of camera object for svo playback

In [None]:
# Playback configuration for the ZED SDK: read frames from the SVO recording.
camera_init_parameters = sl.InitParameters()

camera_init_parameters.svo_real_time_mode = False        # don't pace playback by the recorded timestamps
camera_init_parameters.open_timeout_sec = 30             # give up opening after 30 s
camera_init_parameters.coordinate_units = sl.UNIT.METER  # depth measures in metres

camera_init_parameters.set_from_svo_file(SVO_FILEPATH)   # input source: the recording

In [None]:
camera = sl.Camera();
error_code = camera.open(camera_init_parameters);
if error_code != sl.ERROR_CODE.SUCCESS:
    # Abort here: every following cell needs an opened camera. Continuing would only
    # surface confusing follow-up errors far away from the real cause.
    raise SystemExit(f"Failed to open Camera object: {error_code}");

In [None]:
# Recording metadata:
# - the image resolution sizes the frame buffers below;
# - the frame count is used for progress output.
resolution = (
    camera.get_camera_information()
    .camera_configuration
    .camera_resolution
)
nr_frames = camera.get_svo_number_of_frames()

In [None]:
# Pre-allocated host buffers reused by the per-frame retrieval in the evaluation loop.
# Their types match what is retrieved into them:
# - sl.VIEW.LEFT is a 4-channel BGRA image (the loop converts it with COLOR_BGRA2RGB).
# - sl.MEASURE.DEPTH holds one float32 per pixel, in the coordinate_units chosen above.
color_image = sl.Mat(resolution.width, resolution.height, sl.MAT_TYPE.U8_C4, sl.MEM.CPU);
depth_image = sl.Mat(resolution.width, resolution.height, sl.MAT_TYPE.F32_C1, sl.MEM.CPU);

### Setup of pytorch

In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
# The chosen device also becomes PyTorch's default for newly created tensors.
if not torch.cuda.is_available():
    device = torch.device("cpu")
    print("Using CPU.")
else:
    device = torch.device("cuda")
    device_prop = torch.cuda.get_device_properties(device)
    total_gib = round(device_prop.total_memory / 1024**3, 2)
    print(f"Using GPU: {device_prop.name} {total_gib}GiB (CC: {device_prop.major}.{device_prop.minor})")

torch.set_default_device(device)

### Setup of monocular depth estimation model

In [None]:
class Storage:
    """Plain attribute container.

    Mimics the namespace object returned by ``argparse.ArgumentParser.parse_args()``,
    so model settings can be passed as attributes without a command line.
    """

    def __init__(self, **kwargs):
        # Every keyword argument becomes an instance attribute.
        vars(self).update(kwargs)

#### KITTI parameters

```bash
python3 test.py \
--dataset kitti \
--kitti_crop garg_crop \
--data_path ../data/ \
--max_depth 80.0 \
--max_depth_eval 80.0 \
--backbone swin_large_v2 \
--depths 2 2 18 2 \
--num_filters 32 32 32 \
--deconv_kernels 2 2 2 \
--window_size 22 22 22 11 \
--pretrain_window_size 12 12 12 6 \
--use_shift True True False False \
--flip_test \
--shift_window_test \
--shift_size 16 \
--do_evaluate \
--ckpt_dir ckpt/kitti_swin_large.ckpt
```

In [None]:
# Model hyper-parameters matching the KITTI test.py invocation above.
# - max_depth is intentionally not set here; it is assigned in a later cell.
# - drop_path_rate, use_checkpoint, pretrained and num_deconv do not appear on that
#   command line (presumably test.py defaults - verify).
kitti_args = Storage(**{
    "backbone":             "swin_large_v2",
    "depths":               [2, 2, 18, 2],
    "window_size":          [22, 22, 22, 11],
    "pretrain_window_size": [12, 12, 12, 6],
    "drop_path_rate":       0.3,
    "use_checkpoint":       False,
    "use_shift":            [True, True, False, False],
    "pretrained":           '',
    "num_deconv":           3,
    "num_filters":          [32, 32, 32],
    "deconv_kernels":       [2, 2, 2],
})

#### NYU Depth V2 parameters

```bash
python3 test.py \
--dataset nyudepthv2 \
--data_path ../data/ \
--max_depth 10.0 \
--max_depth_eval 10.0  \
--backbone swin_large_v2 \
--depths 2 2 18 2 \
--num_filters 32 32 32 \
--deconv_kernels 2 2 2 \
--window_size 30 30 30 15 \
--pretrain_window_size 12 12 12 6 \
--use_shift True True False False \
--flip_test \
--shift_window_test \
--shift_size 2 \
--do_evaluate \
--ckpt_dir ckpt/nyudepthv2_swin_large.ckpt
```

In [None]:
# Model hyper-parameters matching the NYU Depth V2 test.py invocation above.
# - max_depth is intentionally not set here; it is assigned in a later cell.
# - drop_path_rate, use_checkpoint, pretrained and num_deconv do not appear on that
#   command line (presumably test.py defaults - verify).
nyudepthv2_args = Storage(**{
    "backbone":             "swin_large_v2",
    "depths":               [2, 2, 18, 2],
    "window_size":          [30, 30, 30, 15],
    "pretrain_window_size": [12, 12, 12, 6],
    "drop_path_rate":       0.3,
    "use_checkpoint":       False,
    "use_shift":            [True, True, False, False],
    "pretrained":           '',
    "num_deconv":           3,
    "num_filters":          [32, 32, 32],
    "deconv_kernels":       [2, 2, 2],
})

In [None]:
# as specified for the svo recording
# Use the max_depth each checkpoint was trained/evaluated with: --max_depth 10.0 for NYU
# and 80.0 for KITTI, per the test.py calls above.
# NOTE(review): GLPDepth presumably multiplies its sigmoid output by args.max_depth
# (verify in models/model.py). If so, overriding it with the recording's range
# (MAX_DEPTH) would rescale every prediction; the recording's range belongs in the
# evaluation instead.
nyudepthv2_args.max_depth = 10.0;
kitti_args.max_depth = 80.0;

In [None]:
# Select the hyper-parameters and checkpoint for the chosen pretraining dataset.
if USE_NYU:
    print("Using weights pretrained on nyudepthv2.")
    model_args, checkpoint_path = nyudepthv2_args, "./checkpoints/nyudepthv2_swin_large.ckpt"
else:
    print("Using weights pretrained on kitti.")
    model_args, checkpoint_path = kitti_args, "./checkpoints/kitti_swin_large.ckpt"

model = GLPDepth(args=model_args).to(device);

model_weights: dict = torch.load(checkpoint_path, map_location=device);

# Checkpoints saved from an nn.DataParallel / DistributedDataParallel wrapper prefix every
# key with "module.". Strip that prefix per key so the names match the bare model.
# (The previous version assigned the stripped dict to a misspelled variable and never used it.)
model_weights = OrderedDict(
    (key[len("module."):] if key.startswith("module.") else key, value)
    for key, value in model_weights.items()
)

model.load_state_dict(model_weights);
model.eval();

In [None]:
# Evaluate every frame of the recording: run the model on the left colour image and
# compare its prediction with the ZED depth map via PSNR. One CSV row per frame is
# collected in `output` for the export cell below.

# Rewind first so that re-running this cell evaluates the whole recording again instead
# of hitting END_OF_SVOFILE_REACHED immediately and leaving `output` empty.
camera.set_svo_position(0);

to_tensor = ToTensor();  # stateless transform; build it once instead of once per frame
output = [];

while True:
    # get/go to the next frame
    error_code = camera.grab();
    if error_code == sl.ERROR_CODE.END_OF_SVOFILE_REACHED:
        print("Done" + ' ' * 30)  # trailing spaces overwrite the last '\r' progress line
        break
    elif error_code != sl.ERROR_CODE.SUCCESS:
        raise SystemExit(f"Failed to grab frame: {error_code}");

    # retrieve current camera frame
    error_code = camera.retrieve_image(color_image, sl.VIEW.LEFT, sl.MEM.CPU);
    if error_code != sl.ERROR_CODE.SUCCESS:
        raise SystemExit(f"Failed to retrieve color image: {error_code}");

    # retrieve current depth map
    error_code = camera.retrieve_measure(depth_image, sl.MEASURE.DEPTH, sl.MEM.CPU);
    if error_code != sl.ERROR_CODE.SUCCESS:
        raise SystemExit(f"Failed to retrieve depth image: {error_code}");

    # resize and convert image for model (BGRA -> RGB)
    sized_image: np.ndarray = resize_image(color_image.get_data());
    sized_image = cv.cvtColor(sized_image, cv.COLOR_BGRA2RGB);

    # Resize the depth map and normalize it to [0, 255] relative to MAX_DEPTH, so that
    # the 255 peak used by PSNR corresponds to MAX_DEPTH metres. Keep it as float (no
    # uint8 cast): a cast would truncate the values and mangle any NaN/inf pixels.
    sized_depth: np.ndarray = resize_image(depth_image.get_data());
    sized_depth = 255.0 * sized_depth / MAX_DEPTH;

    # prepare image as a (1, C, H, W) torch.Tensor
    img_tensor = to_tensor(sized_image).unsqueeze(0).to(device);

    # let the model create a depth map from the frame
    with torch.no_grad():
        pred_tensor: torch.Tensor = model(img_tensor)["pred_d"];

    # convert the prediction back to numpy and normalize it like the ground truth
    pred_ndarray: np.ndarray = pred_tensor.squeeze().cpu().numpy();
    pred_ndarray = 255.0 * pred_ndarray / MAX_DEPTH;

    # calculate PSNR
    psnr = PSNR(sized_depth, pred_ndarray);
    current_frame = camera.get_svo_position() + 1;

    output.append(f"{current_frame},{color_image.timestamp.data_ns},{psnr}\n");

    print(f"[{current_frame}/{nr_frames}] PSNR = {psnr} dB", end='\r');

In [None]:
# Write the per-frame results next to the recording as "<recording>-<OUTPUT_FILE>".
# os.path.splitext strips only the trailing extension. str.replace(".svo", "") would also
# remove ".svo" appearing elsewhere in the path and turn "x.svo2" into "x2".
output_path = os.path.splitext(SVO_FILEPATH)[0] + '-' + OUTPUT_FILE;
with open(output_path, 'w') as file:
    file.write("Frame,Timestamp [ns],PSNR [dB]\n")
    file.writelines(output);