In [None]:
import os
import torch
import easydict
import numpy as np
import torch.nn as nn
from PIL import Image
from helper import load_model
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision import transforms
# Presumably a workaround for the "duplicate OpenMP runtime" abort that can occur
# when several libraries each bundle their own OpenMP runtime -- TODO confirm.
# NOTE(review): this is set after torch/numpy have already been imported above;
# verify it still takes effect here, otherwise it may need to precede those imports.
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"


In [None]:
# Configuration handed to load_model(); EasyDict gives attribute-style access
# to the same keys.
# NOTE(review): the checkpoint sits under "NYUv2" but its file name mentions
# "kitti" -- confirm this is the intended weights file.
args = easydict.EasyDict(
    {
        "simple": False,
        "bins": 36,
        "act": "linear",
        "backbone": "DepthHistB",
        "path_pretrained": "./Models/pretrained/swin_base_patch4_window7_224_22k.pth",
        "path_pth_model": "./checkpoints/NYUv2/model_kitti_cauchy.pt",
        "kernel": "cauchy",
    }
)


In [None]:
# Build the model from the configuration in `args` using the project's helper.
# NOTE(review): presumably this also restores the weights referenced by
# args.path_pth_model -- confirm in helper.load_model.
model = load_model(args)

In [None]:
def denormalize(x):
    """Undo the per-channel mean/std normalization applied to an image batch.

    Args:
        x (torch.Tensor): Normalized tensor that broadcasts against a
            (1, 3, 1, 1) statistics tensor, e.g. shape (N, 3, H, W).

    Returns:
        torch.Tensor: ``x * std + mean``, computed on ``x``'s device.
    """
    stats_shape = (1, 3, 1, 1)
    # Create the statistics directly on the input's device so the arithmetic
    # below never mixes devices.
    channel_mean = torch.tensor([0.485, 0.456, 0.406], device=x.device).view(stats_shape)
    channel_std = torch.tensor([0.229, 0.224, 0.225], device=x.device).view(stats_shape)
    return x * channel_std + channel_mean


def preprocess_image(image_path):
    """
    Read an image from disk and convert it into a normalized, batched tensor.

    The image is decoded as RGB, scaled to [0, 1], rearranged from
    height-width-channel to channel-first order, normalized per channel, and
    given a leading batch axis of size 1.

    Args:
        image_path (str): Path to the input image.

    Returns:
        torch.Tensor: Float32 tensor of shape (1, 3, H, W).
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    # Force three channels regardless of the file's native mode.
    rgb = Image.open(image_path).convert("RGB")

    # (H, W, 3) float32 pixels in [0, 1].
    pixels = np.array(rgb).astype(np.float32) / 255.0

    # Channel-first layout: (3, H, W).
    chw = torch.from_numpy(pixels).permute(2, 0, 1)

    normalized = transforms.Normalize(mean=channel_mean, std=channel_std)(chw)

    # Prepend the batch axis.
    return normalized.unsqueeze(0)


In [None]:

def plot_images(input_image, output_image_1):
    """
    Show an RGB image and a depth map side by side.

    Parameters:
    input_image (torch tensor): RGB image of shape (1, 3, H, W) with values on
        a 0-255 scale. Values are rounded to the nearest integer and clipped
        to [0, 255] before display.
    output_image_1 (torch tensor): Depth map of shape (1, 1, H, W).
    """
    # Move both tensors to host memory before converting; calling .numpy() on
    # a tensor that lives on an accelerator raises.
    rgb_np = input_image.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
    depth_np = output_image_1.squeeze(0).squeeze(0).detach().cpu().numpy()

    # A bare astype(uint8) truncates (254.99998 -> 254) and wraps values that
    # fall outside [0, 255]; round and clip first so float error from
    # de-normalization cannot shift or corrupt pixel values.
    rgb_uint8 = np.clip(np.rint(rgb_np), 0, 255).astype(np.uint8)

    plt.figure(figsize=(16, 12))

    # Left panel: the RGB input.
    plt.subplot(1, 2, 1)
    plt.imshow(rgb_uint8)
    plt.title("RGB image")
    plt.axis('off')

    # Right panel: the depth map.
    plt.subplot(1, 2, 2)
    plt.imshow(depth_np, cmap='magma_r')
    plt.title("Depth map")
    plt.axis('off')

    plt.show()

In [None]:
# Load and normalize the demo frame into a (1, 3, H, W) tensor.
# NOTE(review): absolute, machine-specific path -- point this at a local image
# before running on another machine.
image = preprocess_image("/home/rcam/Pictures/Results/0000000005.png")
# Put the model in evaluation mode before inference.
# Assuming `model` is a torch.nn.Module, eval() returns the module itself, so
# the reassignment keeps the same object -- TODO confirm.
model = model.eval()

In [None]:
# Forward pass only: disable autograd so no graph is recorded and no
# activations are kept alive for a backward pass that never happens.
with torch.no_grad():
    model_outputs = model(image)
    # Resize the first model output to the input's spatial size (H, W).
    # NOTE(review): presumably model_outputs[0] is the depth prediction --
    # confirm against the model's forward().
    pred = F.interpolate(model_outputs[0], size=image.shape[2:])

# plot_images casts the RGB input to uint8, so rescale it to 0-255 first.
plot_images(denormalize(image) * 255, pred)
