## Testing the Model on Images

### Helpers

In [None]:
import torch
import matplotlib.pyplot as plt
from torchvision import transforms
from PIL import Image

def label_to_rgb(mask_tensor, color_map):
    """Colour-code a label mask using a label -> per-channel-intensity lookup.

    Args:
        mask_tensor: Tensor of integer class labels. The output height/width
            come from its last two dimensions, and ``mask_tensor == label`` is
            used to index a single (H, W) channel, so in practice this should
            be a 2-D tensor.
        color_map: Mapping from label value to an iterable of per-channel
            intensities (e.g. ``(r, g, b)``). Each intensity is written into a
            ``uint8`` tensor, so values are expected to lie in 0..255.

    Returns:
        A ``torch.uint8`` tensor of shape (3, H, W) on the same device as
        ``mask_tensor``. Pixels whose label is not a key of ``color_map``
        stay 0 (black).
    """
    # Allocate on the mask's device so the boolean indexing below also works
    # when the mask lives on a GPU.
    rgb_image = torch.zeros(
        3,
        mask_tensor.shape[-2],
        mask_tensor.shape[-1],
        dtype=torch.uint8,
        device=mask_tensor.device,
    )

    for label, color in color_map.items():
        # Compute the pixel selection once per label rather than once per channel.
        pixels = mask_tensor == label
        for channel, intensity in enumerate(color):
            rgb_image[channel][pixels] = intensity

    return rgb_image

def display_results(input_image_path, ground_truth_image_path, model, color_map,
                    crop_size=1024, device=None):
    """Show an input image, the model's predicted mask and, optionally, the ground truth.

    The input image is centre-cropped to ``crop_size`` x ``crop_size``,
    converted to a tensor, normalised with a fixed per-channel mean/std and
    passed through ``model``. The per-pixel argmax over dimension 1 of the
    output is colour-coded with ``label_to_rgb``. The same centre crop is
    applied to the displayed input and ground-truth images, so the panels
    cover the same region.

    Args:
        input_image_path: Path to the input image. It is converted to RGB.
        ground_truth_image_path: Path to the label image, or a falsy value
            to omit the third panel.
        model: Segmentation network. It is assumed to return a tensor of
            shape (1, num_classes, H, W) -- TODO confirm against the model
            definition. The model is switched to eval mode as a side effect.
        color_map: Mapping from class label to per-channel intensities,
            forwarded to ``label_to_rgb``.
        crop_size: Side length of the centre crop (default 1024).
        device: Device to run inference on. Defaults to the device of the
            model's first parameter, or CPU if the model has no parameters.
    """
    crop = transforms.CenterCrop(crop_size)
    preprocess = transforms.Compose([
        crop,
        transforms.ToTensor(),
        # Presumably the same normalisation used at training time (these are
        # the usual ImageNet statistics) -- confirm against the training code.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Run on the model's own device instead of relying on a global `device`
    # that is only defined by a later notebook cell.
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    # `with` releases the file handle; convert() returns an in-memory copy.
    with Image.open(input_image_path) as raw_input:
        input_image = raw_input.convert("RGB")
    input_tensor = preprocess(input_image).unsqueeze(0).to(device)

    # Predicted class for each pixel = argmax over dimension 1 of the output.
    model.eval()
    with torch.no_grad():
        prediction = model(input_tensor)
    predicted_mask = prediction.argmax(dim=1)

    predicted_rgb = label_to_rgb(predicted_mask.squeeze(0).cpu(), color_map)
    predicted_image = transforms.ToPILImage()(predicted_rgb)

    _, axes = plt.subplots(1, 3, figsize=(15, 5))

    # Show the cropped (but not normalised) input, i.e. the region the model saw.
    axes[0].imshow(crop(input_image))
    axes[0].set_title("Input Image")
    axes[0].axis("off")

    axes[1].imshow(predicted_image)
    axes[1].set_title("Model Prediction")
    axes[1].axis("off")

    if ground_truth_image_path:
        with Image.open(ground_truth_image_path) as ground_truth:
            axes[2].imshow(crop(ground_truth), cmap="gray")
        axes[2].set_title("Ground Truth")
        axes[2].axis("off")
    else:
        axes[2].remove()

    plt.tight_layout()
    plt.show()

### Loading the Model

In [None]:
from mapillary_vistas_dataset import MapillaryVistasDataset
from unet import UNet
import torch

# Name of the checkpoint to evaluate; it is loaded from
# /virtual/csc490_mapillary/models/<name>/<name>.pth.
SAVED_MODEL_NAME = "unet-256-crop-50-2000-001"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if SAVED_MODEL_NAME:
    saved_model_path = f"/virtual/csc490_mapillary/models/{SAVED_MODEL_NAME}/{SAVED_MODEL_NAME}.pth"
    print(f"Evaluating {saved_model_path}...")
    model = UNet(in_channels=3, out_channels=MapillaryVistasDataset.NUM_CLASSES).to(device)
    # map_location remaps tensors saved from a GPU session onto the current
    # device, so the checkpoint also loads on CPU-only machines.
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    model_info = torch.load(saved_model_path, map_location=device)
    # The checkpoint dict carries the weights plus training-history entries.
    model.load_state_dict(model_info['state_dict'])
    print(f"Accuracies: {model_info['accuracies']}")
    print(f"Losses: {model_info['losses']}")
else:
    # NOTE(review): `model` is not defined on this path, so the "View Image"
    # cell below will fail with a NameError.
    print(f"Evaluating {SAVED_MODEL_NAME}..")

### View Image

In [None]:
# Assuming your model is already loaded and on the correct device.
ROOT_DIR = "/virtual/csc490_mapillary/data_v12"
image = "03G8WHFnNfiJR-457i0MWQ"
path_input_image = ROOT_DIR + f'/validation/images/{image}.jpg'
path_labeled_image = ROOT_DIR + f'/validation/labels/{image}.png'
display_results(path_input_image, path_labeled_image, model, MapillaryVistasDataset.i_to_color)