In [None]:
import torch
import matplotlib.pyplot as plt
import random
from Segnet import SegNet

In [None]:
def make_predictions(
    model: torch.nn.Module,
    data: list,
    device: torch.device = "cpu"
):
    """Run ``model`` on every sample in ``data`` individually and collect the outputs.

    The model is moved to ``device`` and switched to eval mode (both mutate the
    passed-in module). Inference runs under ``torch.inference_mode()``, so no
    autograd graph is recorded.

    Args:
        model: the network to evaluate.
        data: iterable of unbatched input tensors; each one gets a leading batch
            dimension of size 1 added before the forward pass.
        device: device the model and each input are moved to.

    Returns:
        A list with one output tensor per input, with all size-1 dimensions
        squeezed away. Outputs are left on ``device``.
    """
    model.to(device)
    model.eval()
    with torch.inference_mode():
        outputs = [
            # unsqueeze(0) turns a single sample into a batch of one
            model(sample.unsqueeze(0).to(device)).squeeze()
            for sample in data
        ]
    return outputs

In [None]:
# Load the best checkpoint, predict depth for a few random test samples, and
# plot image / ground-truth depth / predicted depth side by side.

SPLIT_SEED = 42  # NOTE(review): must match the seed used for the training split — TODO confirm
N_SAMPLES = 5    # number of test samples to show (clamped to the test-set size)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
# weights_only=False unpickles arbitrary Python objects: only load trusted files.
checkpoint = torch.load('best_depth_model.pth', map_location=device, weights_only=False)
model = SegNet(
    in_channels=3,
    out_channels=1,
    features=checkpoint['config']['features']
)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"features: {checkpoint['config']['features']}\nlearning rate: {checkpoint['config']['lr']}\nbatch size: {checkpoint['config']['batch_size']}\nloss: {checkpoint['test_loss']}")

# NOTE(review): `dataset` is not defined in any cell above; it must be created
# before this cell runs, otherwise a fresh kernel raises NameError here.
# A seeded generator makes the split reproducible. Without one, every run draws a
# different "test" split that can overlap the samples the model was trained on.
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
_, _, testset = torch.utils.data.random_split(
    dataset, [0.7, 0.15, 0.15], generator=split_generator
)

n_shown = min(N_SAMPLES, len(testset))
if n_shown == 0:
    raise ValueError("test split is empty; nothing to visualise")
samples = random.Random(SPLIT_SEED).sample(range(len(testset)), n_shown)
predictions = make_predictions(model, [testset[i][0] for i in samples], device=device)

# Show original image, ground truth depth, and predicted depth.
# squeeze=False keeps `axes` 2-D even when only a single column is drawn.
fig, axes = plt.subplots(3, n_shown, figsize=(3 * n_shown, 9), squeeze=False)

for i, idx in enumerate(samples):
    image, depth_map = testset[idx]
    pred = predictions[i].cpu().numpy()

    # Convert tensors back to numpy (HWC for the image) for plotting
    if isinstance(image, torch.Tensor):
        image = image.cpu().permute(1, 2, 0).numpy()
    if isinstance(depth_map, torch.Tensor):
        depth_map = depth_map.cpu().numpy().squeeze()

    # Draw ground truth and prediction on the same grey scale, so scale errors
    # in the prediction are visible instead of being hidden by auto-scaling.
    vmin, vmax = float(depth_map.min()), float(depth_map.max())

    # Row 1: Original images
    axes[0, i].imshow(image)
    axes[0, i].set_title(f'Image {idx}')
    axes[0, i].axis('off')

    # Row 2: Ground truth depth maps
    axes[1, i].imshow(depth_map, cmap='gray', vmin=vmin, vmax=vmax)
    axes[1, i].set_title(f'Ground Truth {idx}')
    axes[1, i].axis('off')

    # Row 3: Predicted depth maps
    axes[2, i].imshow(pred, cmap='gray', vmin=vmin, vmax=vmax)
    axes[2, i].set_title(f'Predicted {idx}')
    axes[2, i].axis('off')

plt.tight_layout()
plt.show()