## 1. Imports

In [4]:
import os
from pathlib import Path
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import segmentation_models_pytorch as smp
from PIL import Image
import numpy as np
import pandas as pd
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
from tqdm import tqdm
from torchmetrics import JaccardIndex
from transformers import Dinov2Model, AutoImageProcessor

## 2. Dataset and Model Classes

In [5]:
# Side length, in pixels, shared by the evaluation resize transform and by the
# final upsampling step of the segmentation model's forward pass.
DINO_IMAGE_SIZE = 224

In [6]:
class PeatlandDinoDataset(Dataset):
    """Image/mask segmentation dataset that runs images through a DinoV2 image processor.

    Images and masks are paired by filename: each file found in ``images_dir``
    is expected to have a mask with the same name in ``masks_dir``.

    Args:
        images_dir: Directory containing the input images.
        masks_dir: Directory containing the masks, one per image, same filenames.
        image_processor: Callable invoked as
            ``image_processor(pil_image, return_tensors="pt")`` whose result
            exposes a ``pixel_values`` tensor with a leading batch dimension.
        transform: Optional callable invoked as
            ``transform(image=ndarray, mask=ndarray)`` and returning a mapping
            with ``'image'`` and ``'mask'`` entries.
    """
    def __init__(self, images_dir, masks_dir, image_processor, transform=None):
        self.images_dir = Path(images_dir)
        self.masks_dir = Path(masks_dir)
        self.image_processor = image_processor
        self.transform = transform
        # Keep only regular, non-hidden files. A plain directory listing also
        # returns entries such as `.DS_Store` or `.ipynb_checkpoints/`, which
        # cannot be opened as images and would fail later in __getitem__.
        self.image_filenames = sorted(
            entry.name
            for entry in self.images_dir.iterdir()
            if entry.is_file() and not entry.name.startswith(".")
        )

    def __len__(self):
        """Return the number of image files discovered in ``images_dir``."""
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Return ``(pixel_values, mask)`` for the sample at position ``idx``.

        ``pixel_values`` is the processor output with its batch dimension
        removed; ``mask`` is the (optionally transformed) mask as a long tensor.
        """
        img_name = self.image_filenames[idx]
        img_path = self.images_dir / img_name
        mask_path = self.masks_dir / img_name

        # Context managers release the underlying file handles deterministically.
        with Image.open(img_path) as img_file:
            image = img_file.convert("RGB")
        with Image.open(mask_path) as mask_file:
            mask = np.array(mask_file)

        if self.transform is not None:
            augmented = self.transform(image=np.array(image), mask=mask)
            image = Image.fromarray(augmented['image'])
            mask = augmented['mask']

        # NOTE(review): the processor is called with its own defaults. If its
        # config resizes and/or center-crops (the stock DINOv2 config appears
        # to — confirm), the image is altered after the transform while the
        # mask is not, so the two may no longer be pixel-aligned. Left unchanged
        # here because evaluation must match the preprocessing used in training.
        pixel_values = self.image_processor(image, return_tensors="pt").pixel_values.squeeze(0)
        mask = torch.from_numpy(mask).long()
        return pixel_values, mask

class DinoV2ForSemanticSegmentation(nn.Module):
    """Frozen DINOv2 backbone followed by a small convolutional segmentation head.

    The backbone's patch tokens are reshaped into a 2-D feature map, passed
    through three conv + 2x-upsample stages and a 1x1 classifier, and the
    resulting logits are resized to ``DINO_IMAGE_SIZE`` x ``DINO_IMAGE_SIZE``.

    Args:
        num_classes: Number of output channels (one logit map per class).
    """
    def __init__(self, num_classes=5):
        super().__init__()
        self.dinov2 = Dinov2Model.from_pretrained("facebook/dinov2-base")
        # The backbone is frozen; only the head has trainable parameters.
        for param in self.dinov2.parameters():
            param.requires_grad = False
        # Layer order is kept as-is so saved state_dict keys ("head.0", ...) still match.
        self.head = nn.Sequential(
            nn.Conv2d(768, 256, kernel_size=3, padding=1), nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(256, 128, kernel_size=3, padding=1), nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(128, 64, kernel_size=3, padding=1), nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(64, num_classes, kernel_size=1)
        )

    def forward(self, pixel_values):
        """Return per-class logits of shape (batch, num_classes, DINO_IMAGE_SIZE, DINO_IMAGE_SIZE).

        Args:
            pixel_values: Image batch of shape (batch, channels, height, width).

        Raises:
            ValueError: If the number of patch tokens does not match the patch
                grid implied by the input size and the backbone's patch size.
        """
        # Only the final layer is used, so don't ask the backbone to keep
        # every intermediate hidden state in memory.
        outputs = self.dinov2(pixel_values)
        # Drop the first token (index 0) and keep the per-patch tokens.
        patch_tokens = outputs.last_hidden_state[:, 1:, :]
        batch_size, seq_len, num_channels = patch_tokens.shape

        # Recover the patch grid from the input shape instead of a float
        # square root, so non-square inputs are laid out correctly.
        patch_size = self.dinov2.config.patch_size
        grid_h = pixel_values.shape[-2] // patch_size
        grid_w = pixel_values.shape[-1] // patch_size
        if grid_h * grid_w != seq_len:
            raise ValueError(
                f"Expected {grid_h}x{grid_w}={grid_h * grid_w} patch tokens "
                f"for input of shape {tuple(pixel_values.shape)}, got {seq_len}."
            )

        feature_map = patch_tokens.permute(0, 2, 1).reshape(batch_size, num_channels, grid_h, grid_w)
        logits = self.head(feature_map)
        final_logits = nn.functional.interpolate(
            logits, size=(DINO_IMAGE_SIZE, DINO_IMAGE_SIZE), mode='bilinear', align_corners=False
        )
        return final_logits

## 3. Transforms

In [7]:
# Pin the slow image processor explicitly: the library default is slated to
# switch to the fast implementation, which produces slightly different pixel
# values. Pinning keeps this evaluation reproducible across library versions.
processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base", use_fast=False)
# Evaluation applies no augmentation; image and mask are only resized to the model resolution.
eval_transform = A.Compose([A.Resize(height=DINO_IMAGE_SIZE, width=DINO_IMAGE_SIZE)])

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


## 4. Configuration

In [9]:
# Name of the training run whose artifacts are evaluated below.
RUN_NAME = "dinov2_2025-08-03_17-48-54" # <--- CHANGE THIS TO THE RUN NAME TO EVALUATE

# All artifacts of a run (checkpoint, metrics, figures) live under metrics/<run name>.
METRICS_DIR = Path("metrics") / RUN_NAME
# Early stopping stored the best checkpoint under this filename.
MODEL_PATH = METRICS_DIR / "best_model.pth"
BATCH_SIZE = 4

# Prefer CUDA, then Apple MPS, and fall back to the CPU.
DEVICE = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

print(f"Evaluating model from run: {RUN_NAME}")
print(f"Using device: {DEVICE}")

Evaluating model from run: dinov2_2025-08-03_17-48-54
Using device: mps


## 5. Load Test Dataset

In [10]:
# Locations of the held-out test split (images and masks share filenames).
BASE_PROCESSED_DIR = Path("../data/processed/segmentation")
TEST_IMG_DIR = BASE_PROCESSED_DIR.joinpath("test", "images")
TEST_MASK_DIR = BASE_PROCESSED_DIR.joinpath("test", "masks")

test_dataset = PeatlandDinoDataset(
    TEST_IMG_DIR,
    TEST_MASK_DIR,
    processor,
    transform=eval_transform,
)

# Fixed order and in-process loading keep the evaluation deterministic.
loader_options = dict(batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, **loader_options)
print(f"Loaded {len(test_dataset)} samples for testing.")

Loaded 204 samples for testing.


## 6. Load Trained Model

In [11]:
model = DinoV2ForSemanticSegmentation(num_classes=5).to(DEVICE)
# The checkpoint is passed straight to load_state_dict, i.e. it is a plain
# state_dict, so restrict unpickling to tensors/containers instead of
# executing arbitrary pickled objects from the file.
state_dict = torch.load(MODEL_PATH, map_location=torch.device(DEVICE), weights_only=True)
model.load_state_dict(state_dict)
model.eval()
print("Model loaded successfully.")


Model loaded successfully.


## 7. Run Evaluation and Calculate Metrics

In [12]:
CLASS_NAMES = ["PATH", "NATURAL_GROUND", "TREE", "VEGETATION", "IGNORE"]
# Derived from the names so the two can never drift apart.
NUM_CLASSES = len(CLASS_NAMES)
# NOTE(review): IGNORE is scored like any other class, so it contributes to
# both pixel accuracy and mIoU. Confirm this is intended; JaccardIndex accepts
# an `ignore_index` argument if those pixels should be excluded.
jaccard = JaccardIndex(task="multiclass", num_classes=NUM_CLASSES, average=None).to(DEVICE)
total_correct_pixels, total_pixels = 0, 0

# Make the cell safe to re-run on its own, whatever mode the model was left in.
model.eval()
with torch.no_grad():
    for images, masks in tqdm(test_loader, desc="Evaluating on Test Set"):
        images, masks = images.to(DEVICE), masks.to(DEVICE)
        outputs = model(images)
        # Class with the highest logit at every pixel.
        preds = torch.argmax(outputs, dim=1)
        jaccard.update(preds, masks)
        total_correct_pixels += (preds == masks).sum().item()
        total_pixels += torch.numel(masks)

# An empty loader would otherwise surface as an opaque ZeroDivisionError below.
if total_pixels == 0:
    raise ValueError("The test loader yielded no samples; check the test image/mask directories.")

# Accuracy is reported as a percentage; IoU values are fractions in [0, 1].
pixel_accuracy = (total_correct_pixels / total_pixels) * 100
iou_per_class = jaccard.compute()
mean_iou = iou_per_class.mean()

print("\n--- Evaluation Complete ---")
print(f"Overall Pixel Accuracy: {pixel_accuracy:.2f}%")
print(f"Mean IoU (mIoU): {mean_iou:.4f}")
print("\nIoU per Class:")
for class_name, class_iou in zip(CLASS_NAMES, iou_per_class):
    print(f"  - {class_name}: {class_iou:.4f}")

Evaluating on Test Set: 100%|██████████| 51/51 [00:20<00:00,  2.52it/s]



--- Evaluation Complete ---
Overall Pixel Accuracy: 86.29%
Mean IoU (mIoU): 0.6565

IoU per Class:
  - PATH: 0.7115
  - NATURAL_GROUND: 0.7868
  - TREE: 0.2320
  - VEGETATION: 0.7450
  - IGNORE: 0.8070


## 8. Save Metrics

In [13]:
# Two parallel columns: a label per metric and its value. Pixel accuracy is a
# percentage, the IoU entries are fractions, exactly as computed above.
metric_names = ['Pixel Accuracy', 'Mean IoU']
metric_names.extend(f'IoU_{name}' for name in CLASS_NAMES)

metric_values = [pixel_accuracy, mean_iou.item()]
metric_values.extend(iou_per_class.cpu().numpy().tolist())

metrics_data = {'Metric': metric_names, 'Value': metric_values}
metrics_df = pd.DataFrame(metrics_data)

# Store the table next to the evaluated run's other artifacts.
output_csv_path = METRICS_DIR / "test_set_evaluation.csv"
metrics_df.to_csv(output_csv_path, index=False)
print(f"\nEvaluation metrics saved to: {output_csv_path}")


Evaluation metrics saved to: metrics/dinov2_2025-08-03_17-48-54/test_set_evaluation.csv


## 9. Visualize Predictions on Sample Images

In [14]:
def visualize_predictions(dataset, model, device, num_samples=5):
    """Save side-by-side image / ground-truth / prediction figures for the first samples.

    One PNG per sample is written to ``METRICS_DIR / "visualizations"``.

    Args:
        dataset: Indexable dataset returning ``(image_tensor, mask_tensor)`` and
            exposing ``images_dir`` and ``image_filenames``.
        model: Segmentation model returning per-class logits for a batch.
        device: Device the input tensor is moved to before inference.
        num_samples: Maximum number of samples (taken from index 0) to plot;
            clamped to the dataset size.
    """
    vis_dir = METRICS_DIR / "visualizations"
    # parents=True so the call also works if the run directory does not exist yet.
    vis_dir.mkdir(parents=True, exist_ok=True)
    # RGB colour per class index, used to render integer masks as images.
    color_map = np.array([
        [60, 16, 152], [132, 41, 246], [110, 193, 228],
        [254, 221, 58], [0, 0, 0]
    ], dtype=np.uint8)

    # Never index past the end of the dataset, and report what was really saved.
    num_to_plot = max(0, min(num_samples, len(dataset)))

    model.eval()
    with torch.no_grad():
        for i in range(num_to_plot):
            # Get processed data for the model
            image_tensor, gt_mask = dataset[i]

            # Load the original, unprocessed image for clean visualization
            original_image_path = dataset.images_dir / dataset.image_filenames[i]
            with Image.open(original_image_path) as raw_image:
                display_image = raw_image.convert("RGB")

            # Get prediction
            input_tensor = image_tensor.unsqueeze(0).to(device)
            output = model(input_tensor)
            pred_mask = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()

            # Apply color map to masks
            gt_mask_color = color_map[gt_mask.cpu().numpy()]
            pred_mask_color = color_map[pred_mask]

            fig, axes = plt.subplots(1, 3, figsize=(20, 6))
            fig.suptitle(f"Sample {i}", fontsize=16)

            panels = (
                (display_image, "Original Image"),
                (gt_mask_color, "Ground Truth Mask"),
                (pred_mask_color, "Model Prediction"),
            )
            for axis, (panel, title) in zip(axes, panels):
                axis.imshow(panel)
                axis.set_title(title)
                axis.axis('off')

            # Save and close this specific figure rather than pyplot's
            # implicit "current" one, so figures never accumulate in memory.
            fig.savefig(vis_dir / f"sample_{i}_comparison.png")
            plt.close(fig)

    print(f"Saved {num_to_plot} visualization samples to: {vis_dir}")

# Render the first few test samples (default num_samples=5) with the loaded model.
visualize_predictions(test_dataset, model, DEVICE)


Saved 5 visualization samples to: metrics/dinov2_2025-08-03_17-48-54/visualizations
