In [4]:
import os
import torch
import numpy as np
from PIL import Image

CLASS_INDEX_COLORS = [
    (128, 64, 128),         # road - Purple
    (244, 35, 232),         # flat - Pink
    (220, 20, 60),          # human - Red
    (0, 0, 142),            # vehicle - Blue
    (70, 70, 70),           # construction - Gray
    (107, 142, 35),         # background - Green
]

OUTPUT_DIR_COLOR = "converted_predictions/color"
OUTPUT_DIR_OBJECT = "converted_predictions/object"

os.makedirs(OUTPUT_DIR_COLOR, exist_ok=True)
os.makedirs(OUTPUT_DIR_OBJECT, exist_ok=True)


def convert_macroclass_to_color(pred_class):
    """Convert [H, W] prediction (0-5) to RGB image using macro class colors"""
    pred_np = pred_class.cpu().numpy()
    h, w = pred_np.shape
    color_img = np.zeros((h, w, 3), dtype=np.uint8)

    for class_idx, color in enumerate(CLASS_INDEX_COLORS):
        mask = pred_np == class_idx
        color_img[mask] = color
        
    return Image.fromarray(color_img)


def save_object_mask_channel(pred_object):
    """Compute binary object mask image"""
    object_mask = (pred_object > 0.5).cpu().numpy().astype(np.uint8) * 255
    return Image.fromarray(object_mask)


def process_prediction_file(pred_tensor: torch.Tensor, index: int):
    """
    Saves the macro-class color image and binary object mask image of a single prediction tensor.
    
    Parameters:
    - pred_tensor: [7, H, W] tensor output from the model
    - index: integer to use in the filename (e.g., 1 for image_1)
    """
    name = f"image_{index}"

    # Color
    pred_macro = pred_tensor[:6]  # [6, H, W]
    class_indices = pred_macro.argmax(dim=0)
    color_img = convert_macroclass_to_color(class_indices)
    color_img.save(os.path.join(OUTPUT_DIR_COLOR, f"{name}_color.png"))

    # Object mask
    object_img = save_object_mask_channel(pred_tensor[6])
    object_img.save(os.path.join(OUTPUT_DIR_OBJECT, f"{name}_object.png"))


INPUT_DIR = "saved_predictions"
prediction_files = sorted([f for f in os.listdir(INPUT_DIR) if f.endswith(".pt")])

for pred_file in sorted(prediction_files):
    idx = int(os.path.splitext(pred_file)[0].split('_')[1])
    
    path = os.path.join(INPUT_DIR, pred_file)
    pred_tensor = torch.load(path)  # [7, H, W]
    process_prediction_file(pred_tensor, idx)

print("All predictions processed and saved.")

  pred_tensor = torch.load(path)  # [7, H, W]


All predictions processed and saved.
