In [None]:
import cv2 as cv
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
#from scipy.signal import convolve2d  # TODO: use torch.nn.functional
import torch.nn.functional as F
import os

In [None]:
# region DEBUG
# Global switch for debug output and self-checks; stays off unless the
# environment enables it in the __main__ guard below.
DEBUG = False


def print_debug(*args, **kwargs):
    """Forward the arguments to ``print`` when ``DEBUG`` is truthy.

    Accepts exactly what ``print`` accepts (including ``sep``, ``end``,
    ``file`` and ``flush``). Does nothing when ``DEBUG`` is falsy.
    """
    if DEBUG:
        # Unpack so the output matches a plain print() call instead of
        # showing the repr of the args tuple and the kwargs dict.
        print(*args, **kwargs)


if __name__ == "__main__":
    # Only the value "true" (case-insensitive) enables debug mode. Reducing the
    # lookup to a bool keeps any other non-empty value (e.g. "false") from
    # being left in DEBUG as a truthy string.
    DEBUG = os.environ.get("PYTHON_DEBUG_MODE", "").lower() == "true"
    if DEBUG:
        print("DEBUG mode is enabled")
# endregion

# Largest column shift tried by the disparity loops in this file; they iterate
# over range(MAX_DISPARITY + 1), so this bound is inclusive.
MAX_DISPARITY = 30
# Not referenced in the visible part of the file; presumably the side length
# passed to convolve2d_torch as kernel_size — TODO confirm against callers.
SMOOTH_KERNEL_SIZE = 5

# Plotting helper: coerce a torch tensor (or any array-like) to a numpy array.
def _to_numpy(x):
    """Return ``x`` as a numpy array.

    Tensors are detached from the autograd graph and moved to the CPU before
    conversion; every other input goes through ``np.asarray``.
    """
    if not torch.is_tensor(x):
        return np.asarray(x)
    detached = x.detach()
    return detached.cpu().numpy()

def load_image_in_grayscale(filepath) -> torch.Tensor:
    """Read an image file in grayscale and return it as a float tensor.

    Args:
        filepath: Path handed to ``cv.imread``.

    Returns:
        The pixel values read with ``cv.IMREAD_GRAYSCALE``, converted to a
        float32 tensor via ``.float()``.

    Raises:
        FileNotFoundError: If ``cv.imread`` returns ``None``, which is how it
            reports a path it could not read.
    """
    img_np = cv.imread(filepath, cv.IMREAD_GRAYSCALE)
    if img_np is None:
        raise FileNotFoundError(filepath)
    return torch.from_numpy(img_np).float()

def sum_of_abs_diff(nparray1: np.array, nparray2: np.array) -> int:
    """Return the sum of absolute element-wise differences as a Python scalar.

    The scalar type follows the input dtype (an int for integer arrays, a
    float for floating-point arrays).

    NOTE(review): the subtraction runs in the inputs' own dtype, so unsigned
    integer inputs would wrap around — confirm callers pass signed or float data.
    """
    difference = nparray1 - nparray2
    magnitudes = np.abs(difference)
    total = magnitudes.sum()
    return total.item()


def scanlines(tb_left: np.array, tb_right: np.array):
    """Find the disparity that best matches one fixed scanline segment.

    A 100-column segment of row 152, starting at column 102 of ``tb_left``,
    is compared with the same row of ``tb_right`` shifted left by each
    disparity in ``[0, MAX_DISPARITY]``. The match cost is the sum of
    absolute differences.

    Returns:
        The disparity with the lowest cost (the smallest one on ties), or
        ``None`` if no candidate window fit inside ``tb_right``.
    """
    row = 152
    left_start = 102
    width = 100
    reference = tb_left[row][left_start : left_start + width]

    best_cost = None
    best_disparity = None
    for disparity in range(MAX_DISPARITY + 1):
        lo = left_start - disparity
        hi = lo + width
        # Only score windows that lie entirely inside the right image.
        if lo >= 0 and hi <= tb_right.shape[1]:
            candidate = tb_right[row][lo:hi]
            cost = sum_of_abs_diff(reference, candidate)
            # Strict comparison keeps the first disparity seen on equal cost.
            if best_cost is None or cost < best_cost:
                best_cost, best_disparity = cost, disparity
    print_debug(f"g_best: {best_cost}, d_best: {best_disparity}")

    return best_disparity


def plot_1d_array(array, title, xlabel=None, ylabel=None, save_image=True):
    """Plot a 1D sequence against its indices and show the figure.

    Args:
        array: Sequence of values; its length determines the x range.
        title: Figure title, also used as the file stem when saving.
        xlabel: Label for the x axis.
        ylabel: Label for the y axis.
        save_image: When true, write ``figure/<title>.png`` before showing.
    """
    domain = range(len(array))
    plt.plot(domain, array, marker="o")
    # Each text goes to its own slot: x label, y label, title.
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True)
    if save_image:
        # savefig does not create missing directories.
        os.makedirs("figure", exist_ok=True)
        plt.savefig(f"figure/{title}.png")
    plt.show()


def plot_2d_array_as_image(array2d: np.ndarray, title, save_image=True):
    """Show a 2D array as a grayscale image with a colorbar.

    Args:
        array2d: Image data passed to ``plt.imshow``.
        title: Figure title, also used as the file stem when saving.
        save_image: When true, write ``figure/<title>.png`` before showing.
    """
    plt.imshow(array2d, cmap="gray")
    plt.title(title)
    plt.colorbar()
    if save_image:
        # savefig does not create missing directories.
        os.makedirs("figure", exist_ok=True)
        plt.savefig(f"figure/{title}.png")
    plt.show()


def shift_array(nparray: np.ndarray, d: int) -> np.ndarray:
    """Shift the columns of a 2D array by ``d`` positions, filling with zeros.

    Args:
        nparray: 2D array to shift; it is not modified.
        d: Column offset. Positive values move content to the right, negative
            values to the left, and zero returns a copy.

    Returns:
        A new array with the shape and dtype of ``nparray``. Columns vacated
        by the shift are zero; a shift of at least the array width in either
        direction yields all zeros.
    """
    shifted = np.zeros_like(nparray)
    width = nparray.shape[1]
    # Nothing survives a shift this large. Returning here also avoids the
    # mismatched slice shapes that a large negative d would otherwise produce.
    if abs(d) >= width:
        return shifted
    if d > 0:
        shifted[:, d:] = nparray[:, :-d]
    elif d < 0:
        shifted[:, :d] = nparray[:, -d:]
    else:
        shifted[:, :] = nparray
    return shifted


if DEBUG:
    # Debug-only self-check: right shifts by one and two columns must
    # zero-fill the vacated columns on the left.
    a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    assert np.array_equal(shift_array(a, 1), [[0, 1, 2], [0, 4, 5], [0, 7, 8]])
    assert np.array_equal(shift_array(a, 2), [[0, 0, 1], [0, 0, 4], [0, 0, 7]])


def auto_correlation(tb_right):
    """Sample the right image's self-difference at one pixel for each disparity.

    q3: for each disparity d in [0, MAX_DISPARITY], compute the absolute
    difference between the right image and a copy of it shifted by d columns,
    and record the value at (152, 152).

    Returns:
        A list with one value per disparity, ordered from d = 0 upwards.
        When DEBUG is set, the list is also plotted with plot_1d_array.
    """
    probe_row, probe_col = 152, 152
    auto_correlations = [
        np.abs(tb_right - shift_array(tb_right, shift))[probe_row, probe_col].item()
        for shift in range(MAX_DISPARITY + 1)
    ]
    if DEBUG:
        plot_1d_array(
            auto_correlations,
            title="auto_correlation",
            xlabel="disparity",
            ylabel="auto-correlation value at (152, 152)",
        )
    return auto_correlations



def convolve2d_torch(array: np.ndarray, kernel_size: int):
    """Sum each pixel's square neighbourhood using ``torch`` conv2d.

    The input is convolved with a ``kernel_size`` x ``kernel_size`` kernel of
    ones, zero-padded by ``kernel_size // 2`` on every side.

    Args:
        array: 2D image data (numpy array, tensor or nested sequence); it is
            converted to float32.
        kernel_size: Side length of the all-ones kernel. Odd sizes keep the
            input shape; an even size makes each output dimension one larger.

    Returns:
        A 2D float32 numpy array of window sums.
    """
    # conv2d expects (batch, channel, height, width).
    image = torch.as_tensor(array, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
    # conv2d does not promote dtypes, so the kernel must match the input's
    # dtype (and device) exactly.
    kernel = torch.ones(
        (1, 1, kernel_size, kernel_size), dtype=image.dtype, device=image.device
    )
    convolved = nn.functional.conv2d(image, kernel, padding=kernel_size // 2)
    if DEBUG:
        assert convolved.shape == image.shape

    # Drop only the batch and channel axes so one-row or one-column inputs
    # stay two-dimensional.
    return convolved.squeeze(0).squeeze(0).detach().cpu().numpy()
    