In [1]:
import torch
import cv2
import matplotlib.pyplot as plt
from model import TransferLearning, TorchUNET, CustomUNET1, CustomUNET2
import argparse
import os
from time import time
import numpy as np
import pandas as pd
from sklearn.utils import shuffle
from tqdm import tqdm

In [10]:
# Which trained network to evaluate. MODEL_CLASS picks the architecture
# family and MODEL_GEN the generation within it; both are also used as
# sub-folder names for the checkpoint and output directories below.
MODEL_CLASS, MODEL_GEN = "custom-UNET", "gen-2"
_MODEL_SUBDIR = (MODEL_CLASS, MODEL_GEN)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Input / output locations, relative to the notebook's working directory.
MODELS_FOLDER = "../models"
DATA_PATH = "../inputs/nyu_data/data/nyu2_test"
MODEL_PATH = os.path.join(MODELS_FOLDER, *_MODEL_SUBDIR)
SAVE_PATH = os.path.join("../outputs", *_MODEL_SUBDIR)

# Window size and the half-window value derived from it.
# NOTE(review): neither constant is referenced by the cells visible here.
POOL_SIZE = 3
STRIDE = (POOL_SIZE - 1) // 2

In [11]:
# Checkpoint files are selected by the part of the file name before the first
# "."; only files whose prefix is one of these tags are considered.
CHECKPOINT_TAGS = ("3", "7")

checkpoint_files = sorted(
    name for name in os.listdir(MODEL_PATH) if name.split(".")[0] in CHECKPOINT_TAGS
)
if not checkpoint_files:
    raise FileNotFoundError(
        f"No checkpoint with prefix in {CHECKPOINT_TAGS} found in {MODEL_PATH!r}"
    )
# When several files match, take the lexicographically last one so the choice
# does not depend on the (arbitrary) order returned by os.listdir.
model_path = checkpoint_files[-1]

# Build the architecture that matches the configured class / generation.
if MODEL_CLASS == "torch-UNET":
    depth_model = TorchUNET().get_model()
elif MODEL_CLASS == "Transfer-Learning":
    depth_model = TransferLearning()
elif MODEL_CLASS == "custom-UNET" and MODEL_GEN == "gen-1":
    depth_model = CustomUNET1()
elif MODEL_CLASS == "custom-UNET" and MODEL_GEN == "gen-2":
    depth_model = CustomUNET2()
else:
    raise ValueError(
        f"Unsupported model configuration: MODEL_CLASS={MODEL_CLASS!r}, MODEL_GEN={MODEL_GEN!r}"
    )

# Move to the target device and restore the trained weights (done once here
# instead of being repeated in every branch).
depth_model = depth_model.to(DEVICE)
depth_model.load_state_dict(
    torch.load(os.path.join(MODEL_PATH, model_path), map_location=DEVICE, weights_only=True)
)

# Switch layers such as dropout / batch-norm to inference behaviour; a freshly
# constructed nn.Module is in training mode by default.
depth_model.eval()

In [4]:
# Maps each sample key (the file-name part before the first "_") to a pair:
#     {
#         key1: (rgb_image1, depth_image1),
#         key2: (rgb_image2, depth_image2),
#     }
# A slot stays None until the corresponding file has been read.
image_data = {}

# Read images from the given path
image_files = os.listdir(DATA_PATH)

# Classify and store images based on their endings
for image_file in tqdm(image_files, desc="Processing images"):
    is_depth = image_file.endswith("_depth.png")
    is_color = image_file.endswith("_colors.png")
    if not (is_depth or is_color):
        continue  # Unrelated file: nothing to store.

    key = os.path.splitext(image_file)[0].split("_")[0]  # Part before '_'
    image_path = os.path.join(DATA_PATH, image_file)

    # NOTE(review): cv2.imread with default flags is used for depth maps too,
    # as in the original code; confirm that the default decoding is what the
    # ground-truth depth should use.
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals an unreadable file by returning None.
        tqdm.write(f"Warning: could not read {image_path}; skipping")
        continue

    rgb_slot, depth_slot = image_data.get(key, (None, None))
    if is_color:
        # Always convert OpenCV's BGR channel order to RGB, regardless of
        # whether the matching depth file was seen first.
        rgb_slot = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    else:
        depth_slot = image
    image_data[key] = (rgb_slot, depth_slot)

Processing images: 100%|██████████| 1308/1308 [00:08<00:00, 146.64it/s]


In [12]:
# Collected (color image, ground-truth depth, normalised prediction) triples,
# one per evaluated sample.
outputs = []

progress_bar = tqdm(image_data.items(), desc="Evaluating", total=len(image_data))

for key, (color_img, ground_truth) in progress_bar:
    if color_img is None:
        # A depth file without a matching color file cannot be evaluated.
        tqdm.write(f"Warning: no color image for sample {key!r}; skipping")
        continue

    # Resize to 320x240 (cv2.resize takes (width, height)), then convert the
    # HxWxC array to a 1xCxHxW float tensor on the target device.
    # NOTE(review): pixel values are passed through unscaled; confirm this
    # matches the preprocessing used when the model was trained.
    image_tensor = (
        torch.from_numpy(cv2.resize(color_img, (320, 240)))
        .float()
        .permute(2, 0, 1)
        .unsqueeze(0)
        .to(DEVICE)
    )

    with torch.no_grad():
        # The timed region covers the forward pass and the copy back to host.
        start = time()
        predicted = depth_model(image_tensor).cpu().squeeze().numpy()
        end = time()

    inference_time = end - start

    # Min-max normalise to [0, 1]. A constant prediction has zero range and
    # would otherwise divide by zero and fill the map with NaN.
    predicted_min = predicted.min()
    predicted_range = predicted.max() - predicted_min
    if predicted_range == 0:
        predicted = np.zeros_like(predicted)
    else:
        predicted = (predicted - predicted_min) / predicted_range

    progress_bar.set_description(f"Evaluating (Inference Time: {inference_time:.4f} s/image)")

    outputs.append((color_img, ground_truth, predicted))

Evaluating (Inference Time: 0.0412 s/image): 100%|██████████| 654/654 [00:28<00:00, 23.16it/s]


In [13]:
# Write every normalised prediction to SAVE_PATH as "<index>.png", where the
# index is the sample's position in `outputs`, rendered with the viridis
# colour map.
os.makedirs(SAVE_PATH, exist_ok=True)

save_progress = tqdm(enumerate(outputs), desc="Saving images", total=len(outputs))
for index, sample in save_progress:
    depth_map = sample[2]  # (color image, ground truth, prediction)
    target_file = os.path.join(SAVE_PATH, f"{index}.png")
    plt.imsave(target_file, depth_map, cmap="viridis")

Saving images: 100%|██████████| 654/654 [00:18<00:00, 35.94it/s]
