In [40]:
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
from model import Model
from tqdm import tqdm
import cv2
import os
from time import time

In [41]:
# Directory that holds the checkpoints of the model generation being evaluated.
MODEL_PATH = "../models/gen-4"

# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Directory containing the "*_colors.png" / "*_depth.png" test pairs.
DATA_PATH = "../inputs/nyu_data/data/nyu2_test"

In [42]:
# Build the network on the selected device and restore its trained weights.
depth_model = Model().to(DEVICE)

# The checkpoint location is derived from MODEL_PATH so the config cell stays
# the single source of truth. map_location remaps the stored tensors onto
# DEVICE, which lets a checkpoint saved on a GPU load on a CPU-only machine.
depth_model.load_state_dict(
    torch.load(
        os.path.join(MODEL_PATH, "3.pth"),
        map_location=DEVICE,
        weights_only=True,
    )
)

<All keys matched successfully>

In [43]:
# Quick introspection of the loaded model object. Alternatives kept for reference:
#   print(list(depth_model.named_parameters()))
#   print(list(depth_model.named_buffers()))
attribute_names = dir(depth_model)
print(attribute_names)

['T_destination', '__annotations__', '__call__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__setstate__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_apply', '_backward_hooks', '_backward_pre_hooks', '_buffers', '_call_impl', '_compiled_call_impl', '_forward_hooks', '_forward_hooks_always_called', '_forward_hooks_with_kwargs', '_forward_pre_hooks', '_forward_pre_hooks_with_kwargs', '_get_backward_hooks', '_get_backward_pre_hooks', '_get_name', '_is_full_backward_hook', '_load_from_state_dict', '_load_state_dict_post_hooks', '_load_state_dict_pre_hooks', '_maybe_warn_non_full_backward_hook', '_modules', '_named_members', '_non_persistent_buffers_set', '_parameters', '_register_load_state_dict_pre_hoo

In [44]:
# Collect every 2-D convolution in the network, then prune each one in turn.
conv_layers = [
    layer for _, layer in depth_model.named_modules() if isinstance(layer, nn.Conv2d)
]

for conv in conv_layers:
    # L1-magnitude unstructured pruning of the layer's "weight" with amount=0.2.
    prune.l1_unstructured(conv, name="weight", amount=0.2)
    # Drop the pruning re-parametrization so "weight" is a plain parameter again.
    prune.remove(conv, "weight")

In [45]:
# Maps a sample key (the file-name part before the first "_") to its pair of
# images. Either element is None until the matching file has been read.
#
# STORAGE FORMAT:
# {
#     key1: (image1, depth1),
#     key2: (image2, depth2),
# }
image_data = {}

# os.listdir returns entries in arbitrary order; sorting keeps the sample
# order (and therefore the evaluation order) reproducible between runs.
image_files = sorted(os.listdir(DATA_PATH))

# Classify and store images based on their file-name endings
for image_file in tqdm(image_files, desc="Processing images"):
    image_name, _ = os.path.splitext(image_file)
    key = image_name.split("_")[0]  # Extract the part before '_'

    image_path = os.path.join(DATA_PATH, image_file)

    # Start from whatever half of the pair has already been loaded.
    color_image, depth_image = image_data.get(key, (None, None))

    if image_file.endswith("_depth.png"):
        # NOTE(review): default imread flags yield an 8-bit, 3-channel array;
        # if the depth PNGs are 16-bit or single-channel, cv2.IMREAD_UNCHANGED
        # may be needed -- confirm against the dataset.
        depth_image = cv2.imread(image_path)
    elif image_file.endswith("_colors.png"):
        color_image = cv2.imread(image_path)
        # Always convert BGR -> RGB, regardless of whether the depth file was
        # seen first, so every stored color image uses the same channel order.
        # imread returns None for unreadable files, which cvtColor cannot take.
        if color_image is not None:
            color_image = cv2.cvtColor(color_image, cv2.COLOR_BGR2RGB)
    else:
        # Not part of a color/depth pair; leave image_data untouched.
        continue

    image_data[key] = (color_image, depth_image)

Processing images: 100%|██████████| 1308/1308 [00:08<00:00, 149.55it/s]


In [46]:
# Collected results, one (color image, ground-truth depth, 1 - prediction)
# tuple per evaluated sample.
outputs = []

# Put the network in inference mode so layers such as dropout or batch-norm
# (if the model contains any) stop using their training-time behaviour.
depth_model.eval()

progress_bar = tqdm(image_data.items(), desc="Evaluating", total=len(image_data))

for key, (color_img, ground_truth) in progress_bar:
    if color_img is None:
        # The loader stores None when a sample has no (readable) color file;
        # cv2.resize would fail on it, so report the sample and move on.
        tqdm.write(f"Skipping {key}: no color image available")
        continue

    # cv2.resize takes (width, height), giving a 480x640 image; it is then
    # rearranged from HWC to CHW and given a leading batch dimension of 1.
    image_tensor = (
        torch.from_numpy(cv2.resize(color_img, (640, 480)))
        .float()
        .permute(2, 0, 1)
        .unsqueeze(0)
        .to(DEVICE)
    )

    with torch.no_grad():
        # The timed region includes the copy back to host memory via .cpu().
        start = time()
        predicted = depth_model(image_tensor).cpu().squeeze().numpy()
        end = time()

    inference_time = end - start

    progress_bar.set_description(
        f"Evaluating (Inference Time: {inference_time:.4f} s/image)"
    )

    # NOTE(review): the prediction is stored inverted (1 - predicted);
    # presumably this matches how the model's target was encoded -- confirm.
    outputs.append((color_img, ground_truth, 1 - predicted))

Evaluating (Inference Time: 0.1088 s/image):   5%|▌         | 33/654 [00:03<01:12,  8.61it/s]


KeyboardInterrupt: 