# Evaluate UNet and UNetPP CNN models

## Notebook Purpose

The purpose of this notebook is to test **U-Net** and **U-Net++** models based on various metrics.  

To keep the notebook clean and focused, a separate `load_unets` file is included, which handles the loading of the models.


## Requirements for Using the Notebook

To successfully use this notebook, the following paths and configurations are required:

### 1. Model Checkpoint Paths
Ensure the U-Net and U-Net++ model checkpoints are available at the specified paths:
- **U-Net Checkpoint:** `models/200.sav`
- **U-Net++ Checkpoint:** `models/dice_unetpp.sav`

### 2. Data Directories
The dataset should be organized as follows:
- **Image Directory:** `test_data/images`
- **Label Directory:** `test_data/labels`

### 3. Label File Naming Convention
The label files must follow a specific naming pattern to be correctly matched with the image files. In the dataset loader, the label path is constructed as:
```python
label_path = os.path.join(self.label_dir, file.replace('clip_reproject', 'clip_reproject_classified'))


### Imports

In [None]:
import numpy as np
import torch
from osgeo import gdal
import skimage as sk
import os
from torch.utils.data import Dataset, DataLoader
import rasterio
from torchvision import transforms
import json
from load_unets import load_models

### Dataset for UNet and UNetPP

### CustomSatelliteDataset

This class is based on an implementation originally created by **Zoltán Gugolya**,  
simplified and adapted for testing purposes.


In [None]:
class CustomSatelliteDataset(Dataset):
    def __init__(self, image_dir, label_dir, transform=None):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.transform = transform
        self.image_filenames = [f for f in os.listdir(image_dir) if f.lower().endswith('.tif')]
        self.images, self.labels = self.load_tif_images_and_labels()

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)
            
        image = image.permute(1, 0, 2)
        
        image = torch.tensor(image, dtype=torch.float32)
        label = torch.tensor(label, dtype=torch.long)
        
        return image, label

    def load_tif_images_and_labels(self):
        tif_images = []
        tif_labels = []

        for file in self.image_filenames:
            image_path = os.path.join(self.image_dir, file)
            label_path = os.path.join(self.label_dir, file.replace('clip_reproject', 'clip_reproject_classified'))

            try:
                ds = gdal.Open(image_path, gdal.GA_ReadOnly)
                rows = ds.RasterYSize
                cols = ds.RasterXSize
                bands = ds.RasterCount
                array = ds.ReadAsArray().astype(dtype="float32")

                if array.shape != (bands, rows, cols):
                    array = np.reshape(array, [bands, rows, cols])

                max_value = array.max(axis=(1, 2), keepdims=True)
                array = array / max_value

                with rasterio.open(label_path) as label_img:
                    label_array = label_img.read(1).astype('f4')

                labeled_img = sk.measure.label(label_array)
                regions = sk.measure.regionprops(labeled_img)

                for props in regions:
                    minr, minc, maxr, maxc = props.bbox
                    center = np.array([(minr + maxr) // 2, (minc + maxc) // 2])
                    start_x = max(center[0] - 64, 0)
                    start_y = max(center[1] - 64, 0)
                    end_x = min(start_x + 128, array.shape[1])
                    end_y = min(start_y + 128, array.shape[2])

                    cut = array[:, start_x:end_x, start_y:end_y]
                    cut_label = label_array[start_x:end_x, start_y:end_y]

                    if cut.shape[1:] != (128, 128):
                        cut = self.pad_to_shape(cut, (128, 128))
                        cut_label = self.pad_to_shape(cut_label, (128, 128))

                    tif_images.append(cut)
                    tif_labels.append(cut_label)

            except Exception as e:
                print(f"Error occurred while loading {image_path} or {label_path}: {e}")
        
        return np.stack(tif_images, axis=0), np.stack(tif_labels, axis=0)

    def pad_to_shape(self, array, target_shape):
        if len(array.shape) == 2:
            padding = [(max(target_shape[i] - array.shape[i], 0) // 2,
                        max(target_shape[i] - array.shape[i], 0) - max(target_shape[i] - array.shape[i], 0) // 2)
                       for i in range(2)]
        elif len(array.shape) == 3:
            padding = [(0, 0)] + [(max(target_shape[i] - array.shape[i + 1], 0) // 2,
                                   max(target_shape[i] - array.shape[i + 1], 0) - max(target_shape[i] - array.shape[i + 1], 0) // 2)
                                  for i in range(2)]
        else:
            raise ValueError("Invalid input shape for padding")

        return np.pad(array, padding, mode='constant', constant_values=0)

In [None]:
transform = transforms.Compose([
    transforms.ToTensor(),
])

image_dir_merged = 'test_data/images'
label_dir_merged = 'test_data/labels'

test_dataset_merged = CustomSatelliteDataset(image_dir_merged, label_dir_merged, transform=transform)

test_merged_loader = DataLoader(test_dataset_merged, batch_size=8, shuffle=False)

In [None]:
print(len(test_dataset_merged))

for images, labels in test_merged_loader:
    print(images.shape, labels.shape)

### Load models (UNet, UNetPP)

In [11]:
unet_checkpoint_path = 'models/200.sav'
unetpp_checkpoint_path = 'models/dice_unetpp.sav'

model_unet, model_unetpp = load_models(unet_checkpoint_path, unetpp_checkpoint_path)

### Create predictions

In [12]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ground_truths_merged = []

#### UNet

In [None]:
predictions_unet_merged = []

model_unet.to(device)
model_unet.eval()

for inputs, labels in test_merged_loader:
    inputs = inputs.to(device)
    labels = labels.to(device)
    with torch.no_grad():
        outputs_unet = model_unet.predict(inputs)
        preds_unet = (outputs_unet > 0.5).float()
    predictions_unet_merged.append(preds_unet.cpu())
    ground_truths_merged.append(labels.cpu())

#### UNetPP

In [14]:
predictions_unetpp_merged = []

model_unetpp.to(device)
model_unetpp.eval()

for inputs, labels in test_merged_loader:
    inputs = inputs.to(device)
    labels = labels.to(device)
    with torch.no_grad():
        outputs_unetpp = model_unetpp.predict(inputs)
        preds_unetpp = (outputs_unetpp > 0.5).float()
    predictions_unetpp_merged.append(preds_unetpp.cpu())

  image = torch.tensor(image, dtype=torch.float32)


In [None]:
# Merge predictions and ground truths into a single tensor along dimension 0.
# This is necessary when predictions and ground truths are collected in batches during evaluation
# and need to be combined into one tensor for overall metric computation.
predictions_unet_merged = torch.cat(predictions_unet_merged, dim=0)
predictions_unetpp_merged = torch.cat(predictions_unetpp_merged, dim=0)
ground_truths_merged = torch.cat(ground_truths_merged, dim=0)

# Flatten the merged tensors to convert them into a 1D format.
# This step is required for computing metrics like accuracy, precision, recall, or F1-score,
# which typically operate on flat arrays of predictions and labels.
preds_unet_flat_merged = predictions_unet_merged.view(-1)
preds_unetpp_flat_merged = predictions_unetpp_merged.view(-1)
labels_flat_merged = ground_truths_merged.view(-1)

In [16]:
binary_labels_merged = torch.where(labels_flat_merged == 100, 1, 0)
print("Unique values in preds_unet_flat:", preds_unet_flat_merged.unique())
print("Unique values in labels_flat:", binary_labels_merged.unique())

Unique values in preds_unet_flat: tensor([0., 1.])
Unique values in labels_flat: tensor([0, 1])


### METRICES

#### Unet

In [None]:
TP_unet = ((preds_unet_flat_merged == 1) & (binary_labels_merged == 1)).sum().item()
TN_unet = ((preds_unet_flat_merged == 0) & (binary_labels_merged == 0)).sum().item()
FP_unet = ((preds_unet_flat_merged == 1) & (binary_labels_merged == 0)).sum().item()
FN_unet = ((preds_unet_flat_merged == 0) & (binary_labels_merged == 1)).sum().item()

# Accuary
accuracy_unet = (TP_unet + TN_unet) / (TP_unet + TN_unet + FP_unet + FN_unet)
# Precision
precision_unet = TP_unet / (TP_unet + FP_unet) if (TP_unet + FP_unet) > 0 else 0
# Recall
recall_unet = TP_unet / (TP_unet + FN_unet) if (TP_unet + FN_unet) > 0 else 0
# F1-score
f1_score_unet = 2 * (precision_unet * recall_unet) / (precision_unet + recall_unet) if (precision_unet + recall_unet) > 0 else 0

torch.Size([475136]) torch.Size([475136])


#### UnetPP

In [None]:
TP_unetpp = ((preds_unetpp_flat_merged == 1) & (binary_labels_merged == 1)).sum().item()
TN_unetpp = ((preds_unetpp_flat_merged == 0) & (binary_labels_merged == 0)).sum().item()
FP_unetpp = ((preds_unetpp_flat_merged == 1) & (binary_labels_merged == 0)).sum().item()
FN_unetpp = ((preds_unetpp_flat_merged == 0) & (binary_labels_merged == 1)).sum().item()

# Accuracy
accuracy_unetpp = (TP_unetpp + TN_unetpp) / (TP_unetpp + TN_unetpp + FP_unetpp + FN_unetpp)
# Precision
precision_unetpp = TP_unetpp / (TP_unetpp + FP_unetpp) if (TP_unetpp + FP_unetpp) > 0 else 0
# VRecall
recall_unetpp = TP_unetpp / (TP_unetpp + FN_unetpp) if (TP_unetpp + FN_unetpp) > 0 else 0
# F1-score
f1_score_unetpp = 2 * (precision_unetpp * recall_unetpp) / (precision_unetpp + recall_unetpp) if (precision_unetpp + recall_unetpp) > 0 else 0

### Summary

In [None]:
metrics_unet = {
    'True Positives (TP)': TP_unet,
    'False Positives (FP)': FP_unet,
    'True Negatives (TN)': TN_unet,
    'False Negatives (FN)': FN_unet,
    'Accuracy': accuracy_unet,
    'Precision': precision_unet,
    'Recall': recall_unet,
    'F1-score': f1_score_unet
}


metrics_unetpp = {
    'True Positives (TP)': TP_unetpp,
    'False Positives (FP)': FP_unetpp,
    'True Negatives (TN)': TN_unetpp,
    'False Negatives (FN)': FN_unetpp,
    'Accuracy': accuracy_unetpp,
    'Precision': precision_unetpp,
    'Recall': recall_unetpp,
    'F1-score': f1_score_unetpp
}

In [None]:
def convert_numpy_types(obj):
    if isinstance(obj, np.integer):
        return int(obj)
    elif isinstance(obj, np.floating):
        return float(obj)
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    else:
        return obj


results = {
    'UNet': metrics_unet,
    'UNet++': metrics_unetpp
}

with open('model_metrics_unets.json', 'w', encoding='utf-8') as f:
    json.dump(results, f, ensure_ascii=False, indent=4, default=convert_numpy_types)