## 1. Imports

In [1]:
import os
from pathlib import Path
import torch
from torch.utils.data import Dataset, DataLoader
import segmentation_models_pytorch as smp
from PIL import Image
import numpy as np
import pandas as pd
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
from tqdm import tqdm
# !pip install torchmetrics
from torchmetrics import JaccardIndex

  from .autonotebook import tqdm as notebook_tqdm


## 2. Dataset and Transformations

In [3]:
class BinaryPeatlandDataset(Dataset):
    """Custom PyTorch Dataset for loading images and BINARY (Path/BG) masks.

    Every regular, non-hidden file in ``images_dir`` is treated as one sample.
    Its mask is looked up in ``masks_dir`` under the same file name.

    Args:
        images_dir: Directory containing the input images.
        masks_dir: Directory containing one mask per image, named identically
            to the corresponding image file.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            and expected to return a mapping with ``'image'`` and ``'mask'`` keys.
    """

    def __init__(self, images_dir, masks_dir, transform=None):
        self.images_dir = Path(images_dir)
        self.masks_dir = Path(masks_dir)
        self.transform = transform
        # Keep regular, non-hidden files only: stray entries such as `.DS_Store`
        # or a `.ipynb_checkpoints/` directory would otherwise make `Image.open`
        # fail in __getitem__.
        self.image_filenames = sorted(
            name
            for name in os.listdir(self.images_dir)
            if not name.startswith(".") and (self.images_dir / name).is_file()
        )

    def __len__(self):
        """Return the number of image files discovered in ``images_dir``."""
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Load the ``idx``-th image/mask pair.

        Returns:
            A tuple ``(image, mask)``. ``image`` is whatever the transform
            produced, or the raw RGB NumPy array when no transform is set.
            ``mask`` is always an int64 tensor.
        """
        img_name = self.image_filenames[idx]
        img_path = self.images_dir / img_name
        mask_path = self.masks_dir / img_name
        # Context managers release the underlying file handles deterministically.
        with Image.open(img_path) as img:
            image = np.array(img.convert("RGB"))
        with Image.open(mask_path) as mask_img:
            mask = np.array(mask_img)
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']
        # `torch.as_tensor` passes tensors through unchanged and wraps NumPy
        # arrays, so this also works when no transform converted the mask.
        mask = torch.as_tensor(mask).long()
        return image, mask

# --- Evaluation-time preprocessing ---
# Target spatial size (in pixels) and the per-channel statistics used for normalization.
IMG_HEIGHT, IMG_WIDTH = 480, 640
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Pipeline order: resize -> normalize -> convert to tensor.
eval_transform = A.Compose(
    [
        A.Resize(height=IMG_HEIGHT, width=IMG_WIDTH),
        A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD, max_pixel_value=255.0),
        ToTensorV2(),
    ]
)

## 3. Configurations

In [4]:
RUN_NAME = "binary_unet_2025-08-03_22-53-32" # <--- CHANGE THIS

# --- Paths and settings derived automatically ---
METRICS_DIR = Path("metrics") / RUN_NAME
MODEL_PATH = METRICS_DIR / "best_binary_model.pth"
BATCH_SIZE = 4

# Device preference: CUDA first, then Apple MPS, then CPU.
# `torch.backends.mps` is missing on PyTorch builds that predate the MPS backend,
# so look it up defensively instead of assuming the attribute exists.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    DEVICE = "cuda"
elif _mps_backend is not None and _mps_backend.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

print(f"Evaluating model from run: {RUN_NAME}")
print(f"Using device: {DEVICE}")

Evaluating model from run: binary_unet_2025-08-03_22-53-32
Using device: mps


## 4. Load Dataset

In [5]:
# Expected layout: <processed root>/val/images and <processed root>/val/masks.
BASE_PROCESSED_DIR = Path("../data/processed/binary_segmentation")
VAL_IMG_DIR = BASE_PROCESSED_DIR.joinpath("val", "images")
VAL_MASK_DIR = BASE_PROCESSED_DIR.joinpath("val", "masks")

# Note: the validation split doubles as the test set for this evaluation.
test_dataset = BinaryPeatlandDataset(VAL_IMG_DIR, VAL_MASK_DIR, transform=eval_transform)

# shuffle=False keeps the sample order fixed; num_workers=0 loads in the main process.
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
)
print(f"Loaded {len(test_dataset)} samples for testing.")

Loaded 203 samples for testing.


## 5. Load Model

In [6]:
# Rebuild the network with no pretrained encoder weights (encoder_weights=None);
# the parameters are then loaded from the saved checkpoint below.
model = smp.Unet(
    "resnet34",
    encoder_weights=None,
    in_channels=3,
    classes=2,
)
model = model.to(DEVICE)

# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
model.load_state_dict(
    torch.load(MODEL_PATH, map_location=torch.device(DEVICE))
)
model.eval()
print("Model loaded successfully.")

Model loaded successfully.


## 6. Evaluation

In [7]:
NUM_CLASSES = 2
CLASS_NAMES = ["Background", "Path"]

# average=None keeps one IoU score per class instead of a single aggregate.
jaccard = JaccardIndex(task="multiclass", num_classes=NUM_CLASSES, average=None).to(DEVICE)
total_correct_pixels = 0
total_pixels = 0

with torch.no_grad():
    for images, masks in tqdm(test_loader, desc="Evaluating on Test Set"):
        images = images.to(DEVICE)
        masks = masks.to(DEVICE)
        # The predicted class per pixel is the arg-max over dimension 1 of the model output.
        preds = model(images).argmax(dim=1)
        jaccard.update(preds, masks)
        total_correct_pixels += preds.eq(masks).sum().item()
        total_pixels += masks.numel()

# Pixel accuracy is reported as a percentage; IoU values stay as returned by the metric.
pixel_accuracy = (total_correct_pixels / total_pixels) * 100
iou_per_class = jaccard.compute()
mean_iou = iou_per_class.mean()

print("\n--- Evaluation Complete ---")
print(f"Overall Pixel Accuracy: {pixel_accuracy:.2f}%")
print(f"Mean IoU (mIoU): {mean_iou:.4f}")
print("\nIoU per Class:")
for class_name, class_iou in zip(CLASS_NAMES, iou_per_class):
    print(f"  - {class_name}: {class_iou:.4f}")

Evaluating on Test Set: 100%|██████████| 51/51 [00:21<00:00,  2.32it/s]



--- Evaluation Complete ---
Overall Pixel Accuracy: 95.72%
Mean IoU (mIoU): 0.8372

IoU per Class:
  - Background: 0.9519
  - Path: 0.7226


## 7. Save Metrics

In [8]:
# One row per metric: the two overall scores first, then one IoU row per class.
metric_names = ['Pixel Accuracy', 'Mean IoU', *(f'IoU_{name}' for name in CLASS_NAMES)]
metric_values = [pixel_accuracy, mean_iou.item(), *iou_per_class.cpu().tolist()]
metrics_data = {'Metric': metric_names, 'Value': metric_values}
metrics_df = pd.DataFrame(metrics_data)

# Written into the run's metrics directory.
output_csv_path = METRICS_DIR / "test_set_evaluation.csv"
metrics_df.to_csv(output_csv_path, index=False)
print(f"\nEvaluation metrics saved to: {output_csv_path}")


Evaluation metrics saved to: metrics/binary_unet_2025-08-03_22-53-32/test_set_evaluation.csv


## 8. Visualize

In [9]:
def visualize_predictions(dataset, model, device, num_samples=5, output_dir=None):
    """Save side-by-side image / ground-truth / prediction figures for the first samples.

    For each of the first ``num_samples`` items of ``dataset`` a three-panel
    figure is written to ``sample_<i>_comparison.png`` in the output directory.

    Args:
        dataset: Indexable collection yielding ``(image_tensor, gt_mask)`` pairs.
            The image is permuted from channel-first to channel-last and
            de-normalized with ``IMAGENET_STD`` / ``IMAGENET_MEAN`` for display.
        model: Network used for prediction; it is switched to eval mode.
        device: Device the input batch is moved to before the forward pass.
        num_samples: Maximum number of samples to render. Values larger than
            ``len(dataset)`` are clamped so indexing never runs past the end.
        output_dir: Directory for the PNG files. Defaults to
            ``METRICS_DIR / "visualizations"`` when omitted.

    Returns:
        None. The number of figures actually written is printed.
    """
    if output_dir is not None:
        vis_dir = Path(output_dir)
    else:
        vis_dir = METRICS_DIR / "visualizations"
    # parents=True so a custom, not-yet-existing directory tree can be used.
    vis_dir.mkdir(parents=True, exist_ok=True)
    color_map = np.array([[0, 0, 0], [60, 16, 152]], dtype=np.uint8) # BG: Black, Path: Purple

    # Clamp the request to what the dataset can actually provide.
    sample_count = max(0, min(num_samples, len(dataset)))

    model.eval()
    with torch.no_grad():
        for i in range(sample_count):
            image_tensor, gt_mask = dataset[i]
            # Undo the normalization so the image displays with natural colors.
            display_image = image_tensor.permute(1, 2, 0).cpu().numpy()
            display_image = (display_image * IMAGENET_STD) + IMAGENET_MEAN
            display_image = np.clip(display_image, 0, 1)

            input_tensor = image_tensor.unsqueeze(0).to(device)
            output = model(input_tensor)
            pred_mask = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()

            # Index the palette with the class ids to get RGB masks.
            gt_mask_color = color_map[gt_mask.cpu().numpy()]
            pred_mask_color = color_map[pred_mask]

            fig, ax = plt.subplots(1, 3, figsize=(20, 6))
            ax[0].imshow(display_image); ax[0].set_title("Original Image"); ax[0].axis('off')
            ax[1].imshow(gt_mask_color); ax[1].set_title("Ground Truth Mask"); ax[1].axis('off')
            ax[2].imshow(pred_mask_color); ax[2].set_title("Model Prediction"); ax[2].axis('off')
            # Save and close this specific figure rather than the implicit current one.
            fig.savefig(vis_dir / f"sample_{i}_comparison.png")
            plt.close(fig)

    print(f"Saved {sample_count} visualization samples to: {vis_dir}")

# Render comparison figures for the first few evaluation samples (default count).
visualize_predictions(dataset=test_dataset, model=model, device=DEVICE)

Saved 5 visualization samples to: metrics/binary_unet_2025-08-03_22-53-32/visualizations
