In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys
import cv2
import torch
import random
import numpy as np
from tqdm import tqdm
from typing import Tuple
import matplotlib.pyplot as plt

os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
sys.path.append("..")

from src.models.unet import UNet
from src.lightning_models.unet_lightning_model import UNetLightningModel
from src.datasets.sky_cover_dataset import SkyCoverModule
from src.utils.file import get_paths_recursive
from src.datasets.sky_finder import (
    get_sky_finder_masks,
    get_sky_finder_bounding_boxes,
    get_sky_finder_paths_dict,
)
from src.config import (
    DEVICE,
    UNET_CHECKPOINT_PATH,
    SKY_COVER_WIDTH,
    SKY_COVER_HEIGHT,
    SKY_FINDER_IMAGES_PATH,
    SEED,
)

In [None]:
# Get model
def get_model():
    """
    Build the U-Net, restore it from the Lightning checkpoint and prepare it for inference.

    Returns:
        The `model` attribute of the restored Lightning module, moved to `DEVICE` and
        switched to evaluation mode.
    """
    # Arguments forwarded to the Lightning wrapper when it is rebuilt from the checkpoint.
    # NOTE(review): the zero learning rate / weight decay look like placeholders for
    # inference-only use — confirm against UNetLightningModel.
    checkpoint_kwargs = dict(
        model=UNet(pretrained=True),
        learning_rate=0,
        weight_decay=0,
        name="unet",
        dataset="sky_finder",
    )
    restored = UNetLightningModel.load_from_checkpoint(
        UNET_CHECKPOINT_PATH, **checkpoint_kwargs
    )

    # Only the wrapped network is needed from here on
    network = restored.model.to(DEVICE)
    network.eval()

    return network

In [None]:
def get_image(
    image_file_path: str,
    mask: np.ndarray,
    bounding_box: Tuple[int, int, int, int],
    mean: np.ndarray = np.array([0.485, 0.456, 0.406], dtype=np.float32),
    std: np.ndarray = np.array([0.229, 0.224, 0.225], dtype=np.float32),
) -> np.ndarray:
    """
    Load an image, crop it, inpaint everything outside the mask, then resize and normalize it.

    Args:
        image_file_path (str): Path to the image file.
        mask (np.ndarray): Mask aligned with the full (uncropped) image. Truthy pixels are kept,
            falsy pixels are inpainted. Boolean and integer masks are both accepted.
            Assumed to be single-channel (required by `cv2.inpaint`) — TODO confirm.
        bounding_box (Tuple[int, int, int, int]): Crop region as (x_min, y_min, x_max, y_max)
            in pixels; the max bounds are exclusive.
        mean (np.ndarray, optional): Per-channel mean for normalization. Defaults to [0.485, 0.456, 0.406].
        std (np.ndarray, optional): Per-channel standard deviation for normalization. Defaults to [0.229, 0.224, 0.225].

    Returns:
        np.ndarray: RGB image of shape (SKY_COVER_HEIGHT, SKY_COVER_WIDTH, 3), scaled to [0, 1]
            and then standardized with `mean` and `std`.

    Raises:
        ValueError: If the image file cannot be read.
    """
    # Read image
    image = cv2.imread(image_file_path, cv2.IMREAD_COLOR)
    if image is None:
        raise ValueError(f"❌ Failed to read image: {image_file_path}")

    # Convert BGR to RGB
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Crop image and mask based on bounding box
    x_min, y_min, x_max, y_max = bounding_box
    image = image[y_min:y_max, x_min:x_max]
    mask = mask[y_min:y_max, x_min:x_max]

    # Apply inpainting to fill the ground.
    # Coerce to bool before inverting: `~` is a logical NOT only for boolean arrays. On an
    # integer mask (e.g. uint8 0/1) it is a bitwise NOT, which would mark every pixel for
    # inpainting.
    keep = mask.astype(bool)
    inpaint_mask = np.where(keep, 0, 255).astype(np.uint8)
    image = cv2.inpaint(image, inpaint_mask, 3, cv2.INPAINT_TELEA)

    # Resize and normalize image (cv2.resize takes the target size as (width, height))
    image = cv2.resize(image, (SKY_COVER_WIDTH, SKY_COVER_HEIGHT))
    image = image / 255.0
    image = (image - mean) / std

    return image

def unnormalize_image(
    image: np.ndarray,
    mean: np.ndarray = np.array([0.485, 0.456, 0.406], dtype=np.float32),
    std: np.ndarray = np.array([0.229, 0.224, 0.225], dtype=np.float32),
) -> np.ndarray:
    """
    Undo the mean/std normalization and convert the image back to 8-bit.

    Args:
        image (np.ndarray): Normalized image whose last axis broadcasts against `mean` and `std`.
        mean (np.ndarray, optional): Per-channel mean used for normalization. Defaults to [0.485, 0.456, 0.406].
        std (np.ndarray, optional): Per-channel standard deviation used for normalization. Defaults to [0.229, 0.224, 0.225].

    Returns:
        np.ndarray: Image as uint8 in [0, 255], with the same shape as the input.
    """
    image = image * std + mean
    image = np.clip(image, 0, 1)

    # Round to nearest instead of truncating: floating-point error can leave a value just
    # below the original integer level, which a plain cast would turn into one level darker.
    image = np.rint(image * 255).astype(np.uint8)

    return image

In [None]:
# Get model
model = get_model()

# Get image file paths
image_file_paths = get_paths_recursive(
    folder_path=SKY_FINDER_IMAGES_PATH,
    match_pattern="*.jpg",
    path_type="f",
    recursive=True,
)

# Shuffle reproducibly: sort first so the result does not depend on the order in which the
# paths were returned, then shuffle with a dedicated generator seeded from the project SEED
# (this leaves the global `random` state untouched).
image_file_paths = sorted(image_file_paths)
random.Random(SEED).shuffle(image_file_paths)
print(f"✅ Found {len(image_file_paths)} images.")

# Per-camera masks and bounding boxes; the prediction loop looks both up by camera ID
paths_dict = get_sky_finder_paths_dict()
masks = get_sky_finder_masks(paths_dict)
bounding_boxes = get_sky_finder_bounding_boxes(paths_dict)

In [None]:
# Predict a mask for every image and show it next to the preprocessed network input
for image_file_path in tqdm(
    image_file_paths, desc="⌛ Predicting masks...", unit="file"
):
    # The camera ID is the name of the directory that directly contains the image
    # (os.path keeps this independent of the platform's path separator)
    camera_id = os.path.basename(os.path.dirname(image_file_path))

    if camera_id not in masks:
        print(f"❌ Camera ID {camera_id} not found in masks.")
        continue
    mask = masks[camera_id]

    if camera_id not in bounding_boxes:
        print(f"❌ Camera ID {camera_id} not found in bounding boxes.")
        continue
    bounding_box = bounding_boxes[camera_id]

    image = get_image(
        image_file_path=image_file_path, mask=mask, bounding_box=bounding_box
    )

    # HWC numpy image -> 1xCxHxW float tensor on the target device
    input_image = (
        torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2).float().to(DEVICE)
    )

    # Inference only: skip autograd bookkeeping to save memory and time
    with torch.no_grad():
        prediction = model(input_image).squeeze().cpu().numpy()

    # Map the prediction to an 8-bit grayscale image for display
    prediction = np.clip(prediction, 0, 1)
    prediction = (prediction * 255).astype(np.uint8)

    # Plot input and prediction side by side
    fig, (input_axis, prediction_axis) = plt.subplots(1, 2, figsize=(10, 5))
    input_axis.imshow(unnormalize_image(image))
    input_axis.set_title("Input Image")
    input_axis.axis("off")
    prediction_axis.imshow(prediction, cmap="gray")
    prediction_axis.set_title("Predicted Mask")
    prediction_axis.axis("off")
    fig.tight_layout()
    plt.show()

    # Release the figure so figures do not accumulate over the loop
    plt.close(fig)