In [8]:
import os
import sys

# Make the project root (needed for the `src.*` imports below) and ../src
# importable. Each entry is only added once, so re-running this cell does not
# keep growing sys.path with duplicates.
for _extra_path in (os.path.abspath("../src"), os.path.abspath("../")):
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)

In [9]:
from pathlib import Path
import torch
import numpy as np
from src.models.unet import UNet
from src.models.slim_unet import SlimUNet
from src.visualization import plot_image_and_mask_and_prediction
from src.utils import load_and_normalize_tiff,load_mask
from skimage.transform import resize
from tqdm import tqdm 
import warnings

# Set matplotlib to display inline
warnings.filterwarnings("ignore")

In [10]:
# Run inference on CUDA when a GPU is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Side length (pixels) that input images are resized to before inference.
IMG_SIZE = 512

In [11]:
def load_model(model_path, model_cls=None, n_channels=4, n_classes=1):
    """Build a segmentation network, load trained weights and switch to eval mode.

    Args:
        model_path: Path to a state_dict saved with ``torch.save``.
        model_cls: Network class to instantiate. Defaults to ``UNet``. It must
            accept ``n_channels`` and ``n_classes`` keyword arguments.
            NOTE(review): presumably ``SlimUNet`` fits too; confirm its
            constructor signature before passing it.
        n_channels: Number of input bands the checkpoint was trained with.
        n_classes: Number of output channels the checkpoint was trained with.

    Returns:
        The model, moved to DEVICE and in eval mode.
    """
    # Resolve the default lazily so the def itself does not depend on UNet.
    if model_cls is None:
        model_cls = UNet
    model = model_cls(n_channels=n_channels, n_classes=n_classes).to(DEVICE)
    # weights_only=True restricts unpickling to tensors and primitive
    # containers, so a tampered checkpoint cannot run arbitrary code.
    state_dict = torch.load(model_path, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model

In [12]:
def predict(model, image_tensor, threshold):
    """Return a binary 0/1 mask for a single image tensor.

    Args:
        model: Network whose output is treated as logits (a sigmoid is applied).
        image_tensor: One image without a batch dimension, (C, H, W).
        threshold: Probabilities strictly greater than this become 1.

    Returns:
        np.ndarray of dtype uint8 with every size-1 dimension squeezed out
        (presumably (H, W) for a single-class model; confirm).
    """
    batched = image_tensor.unsqueeze(0)  # prepend a batch dimension of 1
    with torch.no_grad():
        logits = model(batched.to(DEVICE))
        prob_map = logits.sigmoid().squeeze().cpu().numpy()
    foreground = prob_map > threshold
    return foreground.astype(np.uint8)

In [13]:
def _resize_to_model_input(image):
    """Bilinearly resize a (C, H, W) array to (C, IMG_SIZE, IMG_SIZE); no-op if already that size."""
    if image.shape[1:] == (IMG_SIZE, IMG_SIZE):
        return image
    # skimage's resize works on channels-last arrays, hence the transposes.
    resized = resize(image.transpose(1, 2, 0), (IMG_SIZE, IMG_SIZE),
                     order=1, mode='reflect', preserve_range=True)
    return resized.transpose(2, 0, 1)


def _normalize_bands(image):
    """Min-max scale each band of a (C, H, W) array to [0, 1] independently.

    Bands with no range (constant, or containing NaN) are set to 0. The result
    is a new float32 array. Working in float avoids truncating the scaled
    values when the input has an integer dtype.
    """
    image = image.astype(np.float32)  # copy: the caller's array is not mutated
    band_min = image.min(axis=(1, 2), keepdims=True)
    band_max = image.max(axis=(1, 2), keepdims=True)
    span = band_max - band_min
    has_range = span > 0
    # Divide by 1 where there is no range; those bands are zeroed just below.
    scaled = (image - band_min) / np.where(has_range, span, 1)
    return np.where(has_range, scaled, 0).astype(np.float32)


def main(data_folder, mask_folder, model_path, threshold=0.5, pattern="*.tif"):
    """Run the segmentation model over a folder of images and plot each result.

    For each file in ``data_folder`` matching ``pattern``, the function:
    1. Loads the image and the same-named ground-truth mask from ``mask_folder``.
    2. Resizes the image to IMG_SIZE x IMG_SIZE and min-max normalises each band.
    3. Predicts a binary mask.
    4. Plots the prediction next to the image and the true mask.

    Args:
        data_folder: Directory containing the input images.
        mask_folder: Directory containing ground-truth masks named like the images.
        model_path: Path to the state_dict passed to ``load_model``.
        threshold: Probability cut-off used to binarise the prediction.
        pattern: Glob pattern selecting image files (default ``"*.tif"``).

    A failure on one image is reported and that image is skipped, so one bad
    file does not abort the whole run.
    """
    data_folder = Path(data_folder)
    mask_folder = Path(mask_folder)
    model = load_model(model_path)

    image_files = sorted(data_folder.glob(pattern))
    failures = []
    for image_file in tqdm(image_files, desc="Processing Images", unit="image"):
        try:
            image = load_and_normalize_tiff(image_file)
            true_mask = load_mask(mask_folder / image_file.name)

            # The model is built with n_channels=4 and expects channels-first input.
            if image.ndim != 3:
                raise ValueError(f"Expected a (C, H, W) image, got shape {image.shape}.")
            if image.shape[0] != 4:
                raise ValueError(f"Image must have 4 channels (RGB + IR), found {image.shape[0]} channels.")

            image = _normalize_bands(_resize_to_model_input(image))

            image_tensor = torch.tensor(image).float()  # (C, H, W)
            prediction = predict(model, image_tensor, threshold)

            # Nearest-neighbour keeps labels binary. This should be a no-op when
            # the model already returns an IMG_SIZE x IMG_SIZE map.
            prediction = resize(prediction, (IMG_SIZE, IMG_SIZE), order=0,
                                mode='reflect', preserve_range=True)
            prediction = (prediction > 0.5).astype(np.uint8)

            plot_image_and_mask_and_prediction(image, true_mask, prediction,
                                               title=image_file.stem, visualize=True)

        except Exception as e:
            # Best-effort batch run: report and continue with the next image.
            # tqdm.write keeps the progress bar intact, unlike print.
            failures.append(image_file.name)
            tqdm.write(f"Error processing {image_file.name}: {str(e)}")

    if failures:
        tqdm.write(f"{len(failures)} of {len(image_files)} images failed.")

In [14]:
# Run inference on every image under test/ and plot each prediction.
TEST_DATA_DIR = "test/data"
TEST_MASK_DIR = "test/masks"
MODEL_PATH = "../outputs/models/Unet_model.pth"
# Probabilities above this value are labelled foreground. 0.35 is below
# main's 0.5 default, so more pixels are marked positive.
PREDICTION_THRESHOLD = 0.35

main(
    data_folder=TEST_DATA_DIR,
    mask_folder=TEST_MASK_DIR,
    model_path=MODEL_PATH,
    threshold=PREDICTION_THRESHOLD,
)

Processing Images: 100%|██████████| 120/120 [02:03<00:00,  1.03s/image]
