# Libraries

In [None]:
import cv2
import glob
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib
import shutil
import random
import time
from pathlib import Path

# Root directory of the LandCover.ai dataset, relative to the working directory.
DATA_ROOT = "../input/landcoverai"
# Seaborn theme for the figures that follow.
sns.set_style("dark")

In [None]:
# Directory where split_images writes the 512x512 tiles.
OUTPUT_DIR = os.path.join(os.getcwd(), "output")
# Derived from DATA_ROOT (set in the first cell) so the dataset location is
# configured in exactly one place instead of being repeated here.
IMGS_DIR = os.path.join(DATA_ROOT, "images")
MASKS_DIR = os.path.join(DATA_ROOT, "masks")
# glob order is filesystem-dependent; sort so images and masks line up by
# filename and the ordering is reproducible.
IMG_PATHS = sorted(glob.glob(os.path.join(IMGS_DIR, "*.tif")))
MASK_PATHS = sorted(glob.glob(os.path.join(MASKS_DIR, "*.tif")))


# Dataset

In [None]:
# Preview four of the original high-resolution aerial images in a 2x2 grid.
images_list = sorted(glob.glob(os.path.join(DATA_ROOT, "images", "*.tif")))
samples = [0, 1, 2, 3]
fig, ax = plt.subplots(figsize=(9, 9), nrows=2, ncols=2)
for axis, sample in zip(ax.flat, samples):
    # cv2 loads BGR while matplotlib expects RGB. The uint8 array is shown
    # directly: dividing by 255 would only create a large float64 copy of
    # the full-resolution raster without changing what is displayed.
    axis.imshow(cv2.cvtColor(cv2.imread(images_list[sample]), cv2.COLOR_BGR2RGB))
    axis.axis("off")
plt.suptitle("Sample (s) of high resolution images", fontsize=15)
plt.tight_layout(pad=0.8)
# plt.savefig("Samples.png")
plt.show()

# Splitting

In [None]:
IMAGE_SIZE = 512

def split_images(TARGET_SIZE = IMAGE_SIZE, imgs_dir=None, masks_dir=None,
                 output_dir=None):
    """
    Split the aerial images and their masks into square tiles of side
    TARGET_SIZE and write them into ``output_dir``.

    Images are written as ``<name>_<k>.jpg`` and masks as ``<name>_<k>_m.png``.
    ``k`` is the tile's row-major position over the full tile grid. It also
    advances for the partial tiles on the right/bottom edges, which are not
    written, so the indices are not contiguous. SegmentationDataset looks tiles
    up by the names listed in ``DATA_ROOT/<mode>.txt``, so this numbering must
    stay as it is.

    Args:
        TARGET_SIZE: Side length, in pixels, of each square tile.
        imgs_dir: Directory containing the source ``*.tif`` images
            (default: IMGS_DIR).
        masks_dir: Directory containing the source ``*.tif`` masks
            (default: MASKS_DIR).
        output_dir: Destination directory (default: OUTPUT_DIR). It is deleted
            and recreated, so any previous contents are lost. This only
            happens after the inputs have been validated.

    Raises:
        ValueError: If images and masks do not pair up one-to-one by filename,
            or a pair has different spatial dimensions.
        FileNotFoundError: If OpenCV cannot read an image or mask.
    """
    tic = time.time()
    print(f"Splitting the images...\n")
    imgs_dir = IMGS_DIR if imgs_dir is None else imgs_dir
    masks_dir = MASKS_DIR if masks_dir is None else masks_dir
    output_dir = OUTPUT_DIR if output_dir is None else output_dir

    img_paths = sorted(glob.glob(os.path.join(imgs_dir, "*.tif")))
    mask_paths = sorted(glob.glob(os.path.join(masks_dir, "*.tif")))

    # Check the pairing up front: zip() would silently drop unmatched files,
    # and this must happen before the existing output is destroyed.
    img_names = [os.path.splitext(os.path.basename(p))[0] for p in img_paths]
    mask_names = [os.path.splitext(os.path.basename(p))[0] for p in mask_paths]
    if img_names != mask_names:
        raise ValueError(
            "Images and masks do not pair up by filename "
            f"({len(img_names)} images vs {len(mask_names)} masks)."
        )

    if Path(output_dir).is_dir():
        shutil.rmtree(output_dir)
    os.makedirs(output_dir)

    for i, (name, img_path, mask_path) in enumerate(zip(img_names, img_paths, mask_paths)):
        img = cv2.imread(img_path)
        mask = cv2.imread(mask_path)
        # cv2.imread returns None instead of raising on unreadable files.
        if img is None or mask is None:
            raise FileNotFoundError(f"Could not read the image/mask pair for {name!r}.")
        if img.shape[:2] != mask.shape[:2]:
            raise ValueError(
                f"Image and mask sizes differ for {name!r}: "
                f"{img.shape[:2]} vs {mask.shape[:2]}."
            )

        k = 0
        for y in range(0, img.shape[0], TARGET_SIZE):
            for x in range(0, img.shape[1], TARGET_SIZE):
                img_tile = img[y:y + TARGET_SIZE, x:x + TARGET_SIZE]
                # Only full-size tiles are saved; k advances regardless.
                if img_tile.shape[:2] == (TARGET_SIZE, TARGET_SIZE):
                    mask_tile = mask[y:y + TARGET_SIZE, x:x + TARGET_SIZE]
                    cv2.imwrite(os.path.join(output_dir, "{}_{}.jpg".format(name, k)), img_tile)
                    cv2.imwrite(os.path.join(output_dir, "{}_{}_m.png".format(name, k)), mask_tile)
                k += 1

        print("Processed {} {}/{}".format(name, i + 1, len(img_paths)))
    mins, sec = divmod(time.time() - tic, 60)
    print(f"Execution completed in {int(mins)} minutes and {sec:.2f} seconds.")

In [None]:
# Tile side length in pixels; every aerial image/mask pair is cut into
# IMAGE_SIZE x IMAGE_SIZE crops.
IMAGE_SIZE = 512
split_images(IMAGE_SIZE)

In [None]:
labels_cmap = matplotlib.colors.ListedColormap(["#000000", "#A9A9A9",
        "#8B8680", "#D3D3D3", "#FFFFFF"])

def visualize_dataset(num_samples = 8, seed = 42,
                     w = 10, h = 10, nrows = 4, ncols = 4, save_title = None,
                     pad = 0.8, indices = None):
    """
    Plot tiles from OUTPUT_DIR next to their label masks.

    Each figure row holds two samples, each drawn as an (image, mask) pair.
    The grid therefore has ceil(n / 2) rows and 4 columns, where n is the
    number of samples shown.

    Args:
        num_samples: Number of random tiles to draw when ``indices`` is None.
        seed: Seed for NumPy's global RNG used to pick the random tiles.
        w, h: Figure width and height in inches.
        nrows, ncols: Unused. Kept for backward compatibility; the grid shape
            is derived from the number of samples.
        save_title: If given, the figure is saved as ``<save_title>.png``.
        pad: Padding passed to ``plt.tight_layout``.
        indices: Explicit positions into the glob-ordered list of tiles.
            Overrides ``num_samples`` and ``seed``.
    """
    data_list = list(glob.glob(os.path.join(OUTPUT_DIR, "*.jpg")))
    if indices is None:
        np.random.seed(seed)
        indices = np.random.randint(low=0, high=len(data_list), size=num_samples)
    n = len(indices)
    n_rows = -(-n // 2)  # ceil(n / 2): two samples per row
    sns.set_style("white")
    # squeeze=False keeps `ax` 2-D even when there is only one row.
    fig, ax = plt.subplots(figsize=(w, h), nrows=n_rows, ncols=4, squeeze=False)
    for i, idx in enumerate(indices):
        r, rem = divmod(i, 2)
        img_path = data_list[idx]
        # cv2 loads BGR; matplotlib expects RGB.
        img = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
        mask_path = os.path.splitext(img_path)[0] + "_m.png"
        # Class ids are taken from channel 1 of the mask PNG (presumably all
        # channels carry the same label values; TODO confirm).
        mask = cv2.imread(mask_path)[:, :, 1]
        ax[r, 2 * rem].imshow(img)
        ax[r, 2 * rem].set_title("Sample" + str(i + 1))
        # vmin/vmax centre the 5 colormap entries on the integer labels 0..4;
        # "nearest" avoids blending class ids when the mask is resampled.
        ax[r, 2 * rem + 1].imshow(mask, cmap=labels_cmap, interpolation="nearest",
                                  vmin=-0.5, vmax=4.5)
        ax[r, 2 * rem + 1].set_title("Mask" + str(i + 1))
    # With an odd count, the last row has an empty (image, mask) pair.
    if n % 2:
        ax[-1, 2].axis("off")
        ax[-1, 3].axis("off")
    plt.suptitle("Samples of 512 x 512 images", fontsize = 20)
    plt.tight_layout(pad=pad)
    if save_title is not None:
        plt.savefig(save_title + ".png")
    plt.show()

In [None]:
# Show a fixed, hand-picked set of 8 tiles with their masks and save the figure.
visualize_dataset(
    indices=[0, 1, 17, 20, 29, 5, 6, 7],
    num_samples=8,
    w=12,
    h=12,
    pad=1.4,
    save_title="Visualize_dataset",
)

In [None]:
from torch.utils.data import Dataset

class SegmentationDataset(Dataset):
    """
    Dataset over the tiles written by ``split_images``.

    Reads the tile names for ``mode`` from ``<data_root>/<mode>.txt``. Each
    image (``<name>.jpg``) and mask (``<name>_m.png``) is loaded from
    ``output_dir``, optional albumentations-style transforms are applied, and
    the pair is returned as ``(image, label)`` tensors. The image is a float
    CHW tensor scaled to [0, 1] (BGR channel order, as loaded by OpenCV); the
    label is an int64 HW tensor.

    Tiles where ``unlabeled_threshold`` (default 90%) or more of the mask
    pixels are 0 are skipped at load time. In that case ``__getitem__``
    substitutes a randomly chosen other tile from the same subset.

    Args:
        mode: One of "train", "val", "test".
        ratio: If given, keep only ``int(ratio * N)`` *distinct* tiles, chosen
            using ``seed``.
        transforms: Callable accepting ``image=`` and ``mask=`` keyword
            arguments and returning a dict with "image" and "mask" keys.
        seed: Seed used for the ``ratio`` subset selection.
        unlabeled_threshold: Fraction of 0-labelled pixels at or above which a
            tile is skipped.
        data_root: Directory holding the split ``.txt`` files
            (default: DATA_ROOT).
        output_dir: Directory holding the tiles (default: OUTPUT_DIR).

    Raises:
        ValueError: If ``mode`` is not one of the supported splits.
    """

    # Upper bound on substitutions in __getitem__, so a subset consisting
    # only of mostly-unlabeled tiles cannot loop forever.
    _MAX_RESAMPLE_ATTEMPTS = 100

    def __init__(self, mode="train", ratio=None, transforms=None, seed=42,
                 unlabeled_threshold=0.9, data_root=None, output_dir=None):
        if mode not in ("train", "test", "val"):
            raise ValueError(f"mode should be either train, val or test ... not {mode}.")
        self.mode = mode
        self.transforms = transforms
        self.output_dir = OUTPUT_DIR if output_dir is None else output_dir
        self.data_root = DATA_ROOT if data_root is None else data_root
        self.unlabeled_threshold = unlabeled_threshold

        with open(os.path.join(self.data_root, self.mode + ".txt")) as f:
            self.img_names = f.read().splitlines()

        n_total = len(self.img_names)
        if ratio is not None:
            n_keep = int(ratio * n_total)
            print(f"Using {100 * ratio:.2f}% of the initial {mode} set --> {n_keep}|{n_total}")
            # Sample without replacement: randint could pick the same tile
            # several times, so the "ratio" subset held fewer distinct tiles.
            np.random.seed(seed)
            self.indices = np.random.choice(n_total, size=n_keep, replace=False)
        else:
            print(f"Using the whole {mode} set --> {n_total}")
            self.indices = list(range(n_total))

    def __len__(self):
        return len(self.indices)

    def _load(self, position):
        """Read the (BGR image, label mask) pair for ``self.indices[position]``."""
        name = self.img_names[self.indices[position]]
        img = cv2.imread(os.path.join(self.output_dir, name + ".jpg"))
        mask = cv2.imread(os.path.join(self.output_dir, name + "_m.png"))
        # cv2.imread returns None rather than raising on a missing/corrupt file.
        if img is None or mask is None:
            raise FileNotFoundError(f"Could not read tile {name!r} from {self.output_dir!r}.")
        # Labels are taken from channel 1 of the mask PNG.
        return img, mask[:, :, 1]

    def __getitem__(self, item):
        for _ in range(self._MAX_RESAMPLE_ATTEMPTS):
            img, label = self._load(item)
            unlabeled_ratio = np.count_nonzero(label == 0) / label.size
            if unlabeled_ratio < self.unlabeled_threshold:
                break
            # Too few labelled pixels: substitute a random tile from this subset.
            # NOTE(review): this also happens for val/test, so evaluation depends
            # on NumPy's RNG state and may not visit every tile exactly once.
            item = np.random.randint(0, len(self.indices))
        # If every attempt was rejected, the last tile loaded is used as-is.

        if self.transforms is None:
            img = np.transpose(img, (2, 0, 1))  # HWC -> CHW
        else:
            transformed = self.transforms(image=img, mask=label)
            img = np.transpose(transformed["image"], (2, 0, 1))
            label = transformed["mask"]

        return torch.tensor(img, dtype=torch.float32) / 255, torch.tensor(label, dtype=torch.int64)

In [None]:
import torch
from torch.utils.data import DataLoader

device = "cpu" if not torch.cuda.is_available() else "cuda"  # prefer the GPU when PyTorch sees one

# Dataloader

In [None]:
train_set = SegmentationDataset(mode = "train")
train_dloader = DataLoader(train_set, batch_size=8, num_workers=2)

class_dist = {"background": 0, "building": 0,
              "woodland": 0, "water": 0, "road": 0}
label_mapping = {0: "background", 1: "building",
                 2: "woodland", 3: "water", 4: "road"}

# Count training pixels per class. A single bincount per batch replaces five
# boolean-masked copies of the mask tensor.
# NOTE(review): SegmentationDataset substitutes random tiles for mostly
# unlabeled ones, so these counts reflect the tiles actually served.
for _, mask in train_dloader:
    counts = torch.bincount(mask.flatten(), minlength=len(label_mapping))
    for class_label, class_name in label_mapping.items():
        class_dist[class_name] += int(counts[class_label])

In [None]:

# temp_list = sorted([(l,s) for (l,s) in class_dist.items()], key= lambda x: x[1])
# labels = [x[0] for x in temp_list]
# support = [x[1] for x in temp_list]

# sns.set_style("dark")
# fig, ax = plt.subplots(figsize = (10,8))
# ax.bar(labels, support, color = "#36454F")
# ax.set_yscale("log")
# ax.set_title("The distribution of the training set with 512x512 images",
#             fontsize = 17)
# ax.set_ylabel("Number of pixels")
# plt.savefig("Barplt.png")
# plt.show()

# Augmentation

In [None]:
import albumentations as A

In [None]:
from albumentations import (
    GaussianBlur, ElasticTransform, GridDistortion, OpticalDistortion,
    ShiftScaleRotate, ChannelShuffle, CLAHE, ISONoise, CoarseDropout,
    MotionBlur, RandomFog, RandomRain, RandomSnow, Solarize, Equalize, 
    InvertImg, Posterize, RandomSunFlare, RandomShadow, RandomBrightnessContrast
)
import cv2
import glob
import os
import matplotlib.pyplot as plt

# Define the additional Albumentations transformations
additional_transforms = [
    GaussianBlur(blur_limit=(3, 7), p=1),
    # NOTE(review): `alpha_affine` is deprecated/removed in newer
    # albumentations releases; confirm against the installed version.
    ElasticTransform(alpha=1, sigma=50, alpha_affine=None, p=1),
    GridDistortion(p=1),
    OpticalDistortion(distort_limit=0.5, shift_limit=0.5, p=1),
    ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=45, p=1),
    ChannelShuffle(p=1),
    CLAHE(p=1),
    ISONoise(p=1),
    CoarseDropout(max_holes=8, max_height=16, max_width=16, p=1),
    MotionBlur(blur_limit=7, p=1),
    RandomFog(fog_coef_lower=0.1, fog_coef_upper=0.3, alpha_coef=0.08, p=1),
    RandomRain(slant_lower=-10, slant_upper=10, drop_length=20, drop_color=(200, 200, 200), p=1),
    RandomSnow(snow_point_lower=0.1, snow_point_upper=0.3, p=1),
    Solarize(p=1),
    Equalize(p=1),
    InvertImg(p=1),
    Posterize(num_bits=4, p=1),
    RandomSunFlare(flare_roi=(0.0, 0.0, 1.0, 0.5), angle_lower=0.0, p=1),
    RandomShadow(shadow_roi=(0.0, 0.5, 1.0, 1.0), p=1),
    RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=1)
]

# Titles come from the transform classes themselves, so they cannot drift
# out of sync with the list above.
additional_transforms_names = [type(t).__name__ for t in additional_transforms]

# Read the NUM_SAMPLE-th tile (in glob order) from the split output directory.
NUM_SAMPLE = 4
trainpath_list = list(glob.glob(os.path.join(OUTPUT_DIR, "*.jpg")))
img = cv2.imread(trainpath_list[NUM_SAMPLE])
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # Convert image to RGB

# One panel for the original plus one per transform. The grid height is
# derived from the list length: a fixed 4x5 grid leaves only 19 free panels
# for 20 transforms and silently dropped the last one.
ncols = 5
n_panels = len(additional_transforms) + 1
nrows = -(-n_panels // ncols)  # ceil division
fig, ax = plt.subplots(figsize=(16, 3 * nrows), nrows=nrows, ncols=ncols, squeeze=False)
panels = ax.ravel()

# Display original image
panels[0].imshow(img)
panels[0].axis("off")
panels[0].set_title("True image")

# Apply and display each transformation in its own panel
for panel, transform, name in zip(panels[1:], additional_transforms, additional_transforms_names):
    panel.imshow(transform(image=img)["image"])
    panel.axis("off")
    panel.set_title(name)

# Blank out any leftover panels in the last row.
for panel in panels[n_panels:]:
    panel.axis("off")

plt.suptitle("Data Augmentation", fontsize=17)
plt.tight_layout(pad=1)
plt.savefig("augmentations.png")
plt.show()


In [None]:
import albumentations as A

def _build_train_transforms():
    """Return the albumentations pipeline used to augment training tiles.

    Each grouped step wraps its members in ``A.OneOf(..., p=0.5)``: with
    probability 0.5 exactly one member of the group is applied.
    ShiftScaleRotate and CLAHE are standalone steps, each also with p=0.5.
    NOTE(review): SegmentationDataset feeds cv2 BGR images into this pipeline,
    so RGB-specific colour ops (HSV shift, sepia, RGB shift) act on swapped
    channels. Confirm that this is intended.
    """
    color_group = A.OneOf([
        A.HueSaturationValue(hue_shift_limit=40, sat_shift_limit=40, val_shift_limit=30, p=1),
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.5, p=1),
        A.RGBShift(r_shift_limit=20, g_shift_limit=20, b_shift_limit=20, p=1),
        A.ToSepia(p=1),
        A.Solarize(threshold=128, p=1),
    ], p=0.5)

    # Geometric ops; albumentations applies them to the mask as well.
    spatial_group = A.OneOf([
        A.RandomRotate90(p=1),
        A.HorizontalFlip(p=1),
        A.VerticalFlip(p=1),
        A.Transpose(p=1),
        A.RandomSizedCrop(min_max_height=(248, 512), height=512, width=512, p=1),
        A.Perspective(scale=(0.05, 0.1), p=1),
    ], p=0.5)

    warp_group = A.OneOf([
        A.ElasticTransform(alpha=1, sigma=50, alpha_affine=None, p=1),
        A.GridDistortion(num_steps=5, distort_limit=0.3, p=1),
        A.OpticalDistortion(distort_limit=0.5, shift_limit=0.5, p=1),
        A.PiecewiseAffine(scale=(0.03, 0.05), p=1),
    ], p=0.5)

    shift_scale_rotate = A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=45, p=0.5)

    blur_group = A.OneOf([
        A.GaussianBlur(blur_limit=(3, 7), p=1),
        A.MotionBlur(blur_limit=7, p=1),
        A.MedianBlur(blur_limit=7, p=1),
    ], p=0.5)

    noise_group = A.OneOf([
        A.ISONoise(p=1),
        A.GaussNoise(var_limit=(10.0, 50.0), p=1),
    ], p=0.5)

    contrast = A.CLAHE(clip_limit=2.0, tile_grid_size=(8, 8), p=0.5)

    # Occlusion simulation (single-member group, gated at p=0.5).
    dropout_group = A.OneOf([
        A.CoarseDropout(max_holes=8, max_height=16, max_width=16, p=1),
    ], p=0.5)

    return A.Compose([
        color_group,
        spatial_group,
        warp_group,
        shift_scale_rotate,
        blur_group,
        noise_group,
        contrast,
        dropout_group,
    ])


# Training-time augmentation pipeline (colour, spatial, warp, shift/scale/rotate,
# blur, noise, CLAHE and dropout steps, each gated with probability 0.5).
transforms = _build_train_transforms()


# U-Net Models

In [None]:
!pip install segmentation-models-pytorch

import segmentation_models_pytorch as smp

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
target_names = np.array(["background", "building", "woodland", "water", "road"])

# Loss function - Jaccard (mean IoU) loss on multiclass masks.
# smp's `classes` argument is a *list of class indices* that contribute to the
# loss, not the number of classes, so passing `classes=5` was a misuse.
# Leaving it unset lets every class contribute; the loss infers the class
# count from the prediction tensor.
loss_fn = smp.losses.JaccardLoss(mode = "multiclass").to(device)

# Hyperparameters
batch_size = 8
epochs = 30
lr = 5e-5

# Preparing datasets and DataLoaders: an augmented 60% subset of train,
# the full test split, and a 70% subset of val.
train_set = SegmentationDataset(mode = "train", transforms = transforms,
                               ratio = 0.6)
test_set = SegmentationDataset(mode = "test")
val_set = SegmentationDataset(mode = "val", ratio = 0.7)

train_dloader = DataLoader(train_set, batch_size = batch_size,
                           shuffle = True, num_workers = 2)
test_dloader = DataLoader(test_set, batch_size = batch_size, num_workers = 2)
val_dloader = DataLoader(val_set, batch_size=batch_size, num_workers = 2)

In [None]:
from landcoverutil import training_loop

## ResNet-50 encoder

In [None]:
# U-Net with an ImageNet-pretrained ResNet-50 encoder and a 5-class output head.
model_resnet50 = smp.Unet(
    encoder_name="resnet50",
    encoder_weights="imagenet",
    classes=5,
)
model_resnet50 = model_resnet50.to(device)

In [None]:
def count_parameters(model):
    """Return the total number of elements across parameters with requires_grad=True."""
    trainable = filter(lambda param: param.requires_grad, model.parameters())
    return sum(param.numel() for param in trainable)

trainable_params = count_parameters(model_resnet50)  # UNet+ResNet-50 size
print("Total trainable parameters: {}".format(trainable_params))

In [None]:
from landcoverutil import training_loop

epochs = 50  # overrides the value set in the setup cell

# Training starts! (L2 regularisation, no early stopping, checkpoints saved.)
training_loop(
    model_resnet50,
    train_dloader,
    val_dloader,
    epochs,
    lr,
    loss_fn,
    mod_epochs=1,
    regularization="L2",
    reg_lambda=1e-6,
    early_stopping=False,
    patience=5,
    verbose=True,
    model_title="UNet with Resnet encoder",
    save=True,
    stopping_criterion="loss",
)

In [None]:
# labels_cmap = matplotlib.colors.ListedColormap(["#FFFFFF", "#C9E4CA", "#F7DC6F", "#F2C464", "#FFC080"])

labels_cmap = matplotlib.colors.ListedColormap(["#000000", "#A9A9A9",
        "#8B8680", "#D3D3D3", "#FFFFFF"])

def visualize_preds(model, train_set, title, num_samples = 4, seed = 42,
                    w = 10, h = 10, save_title = None, indices = None):
    """
    Plot input tiles, ground-truth labels and model predictions side by side.

    Args:
        model: Segmentation model; the predicted class is the argmax over
            output dimension 1.
        train_set: Dataset returning (CHW image tensor, HW label tensor).
            Despite the name, any split can be passed (e.g. the test set).
        title: Figure title.
        num_samples: Number of random samples when ``indices`` is None.
        seed: Seed for NumPy's global RNG used to pick the random samples.
        w, h: Figure width and height in inches.
        save_title: If given, the figure is saved as ``<save_title>.png``.
        indices: Explicit dataset indices to plot; overrides the random draw.
    """
    # Run inference wherever the model's weights live.
    device = next(model.parameters()).device
    if indices is None:
        np.random.seed(seed)
        indices = np.random.randint(low=0, high=len(train_set), size=num_samples)
    sns.set_style("white")
    # One row per sample; squeeze=False keeps `ax` 2-D even for one row.
    fig, ax = plt.subplots(figsize=(w, h), nrows=len(indices), ncols=3,
                           squeeze=False)
    model.eval()
    with torch.no_grad():  # inference only, no autograd graph needed
        for i, idx in enumerate(indices):
            X, y = train_set[idx]
            logits = model(X.unsqueeze(0).to(device))
            preds = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy()

            # CHW -> HWC for imshow. SegmentationDataset keeps OpenCV's BGR
            # order, so reverse the channels to show true colours.
            img = X.permute(1, 2, 0).cpu().numpy()[:, :, ::-1]

            ax[i, 0].imshow(img)
            ax[i, 0].set_title("True Image")
            ax[i, 0].axis("off")
            # vmin/vmax centre the 5 colormap entries on labels 0..4;
            # "nearest" avoids blending class ids when resampled.
            ax[i, 1].imshow(y, cmap=labels_cmap, interpolation="nearest",
                            vmin=-0.5, vmax=4.5)
            ax[i, 1].set_title("Labels")
            ax[i, 1].axis("off")
            ax[i, 2].imshow(preds, cmap=labels_cmap, interpolation="nearest",
                            vmin=-0.5, vmax=4.5)
            ax[i, 2].set_title("Predictions")
            ax[i, 2].axis("off")
    fig.suptitle(title, fontsize = 20)
    plt.tight_layout()
    if save_title is not None:
        plt.savefig(save_title + ".png")
    plt.show()
    
    
# Plot four fixed test tiles for the ResNet-50 model.
visualize_preds(
    model_resnet50,
    test_set,
    title="Predictions - UNet+Resnet50",
    save_title="UNet+Resnet50",
    w=12,
    h=12,
    indices=[957, 961, 1476, 1578],
)

In [None]:
import torchmetrics
import torchvision.transforms.functional as TF
import torch.nn.functional as F

def segmentation_test_loop_with_metrics(model, test_loader, num_classes=5, device="cpu"):
    """
    Evaluate a segmentation model pixel-wise on ``test_loader``.

    Every pixel is treated as one multiclass prediction. The predicted class
    is the argmax of the softmax over output dimension 1.

    Args:
        model: The trained model to evaluate.
        test_loader: DataLoader yielding (images, labels) batches.
        num_classes: Number of classes in the dataset.
        device: Device to run the evaluation on; must match the model's device.

    Returns:
        Tuple ``(precision, recall, f1_score, confusion_matrix, acc, jaccard,
        class_probs)``:
        - precision, recall, f1_score: per-class tensors of length num_classes.
        - confusion_matrix: (num_classes, num_classes) tensor; rows are true
          classes, columns predicted classes (torchmetrics convention).
        - acc: overall pixel accuracy (micro average).
        - jaccard: IoU averaged over classes (torchmetrics default "macro").
        - class_probs: dict mapping class id to the mean softmax confidence
          of the pixels predicted as that class (0.0 if never predicted).
    """
    common = dict(task='multiclass', num_classes=num_classes)
    precision = torchmetrics.Precision(average=None, **common).to(device)
    recall = torchmetrics.Recall(average=None, **common).to(device)
    f1_score = torchmetrics.F1Score(average=None, **common).to(device)
    acc = torchmetrics.Accuracy(average="micro", multidim_average="global", **common).to(device)
    jaccard = torchmetrics.JaccardIndex(**common).to(device)
    confusion_matrix = torchmetrics.ConfusionMatrix(**common).to(device)
    metrics = (precision, recall, f1_score, acc, jaccard, confusion_matrix)

    # Per-class running sums, kept on `device`. This avoids a host sync for
    # every class of every batch, which per-class `.item()` calls would force.
    prob_sums = torch.zeros(num_classes, dtype=torch.float64, device=device)
    pixel_counts = torch.zeros(num_classes, dtype=torch.long, device=device)

    model.eval()
    with torch.no_grad():
        for X, y in test_loader:
            X = X.to(device)
            y = y.to(device)

            class_scores = F.softmax(model(X), dim=1)  # probabilities, not logits
            # Confidence and index of the most likely class for every pixel.
            probs, preds = torch.max(class_scores, dim=1)

            flat_preds = preds.flatten()
            prob_sums += torch.bincount(flat_preds, weights=probs.flatten().double(),
                                        minlength=num_classes)[:num_classes]
            pixel_counts += torch.bincount(flat_preds, minlength=num_classes)[:num_classes]

            for metric in metrics:
                metric.update(preds, y)

    # Mean confidence per predicted class.
    class_probs = {}
    for label in range(num_classes):
        count = int(pixel_counts[label].item())
        class_probs[label] = prob_sums[label].item() / count if count > 0 else 0.0

    return (precision.compute(), recall.compute(), f1_score.compute(),
            confusion_matrix.compute(), acc.compute(), jaccard.compute(), class_probs)


# Evaluate UNet+ResNet-50 on the test split.
(precision, recall, f1_score, confusion_matrix,
 acc, jaccard, class_probs) = segmentation_test_loop_with_metrics(
    model=model_resnet50,
    test_loader=test_dloader,
    num_classes=5,
    device=device,
)

# Per-class scores, then the global ones.
for i, (p, r, f) in enumerate(zip(precision, recall, f1_score)):
    print(f"Class {i}: Precision: {p.item():.4f}, Recall: {r.item():.4f}, F1-Score: {f.item():.4f}")
print(f"Confusion Matrix:\n {confusion_matrix.cpu().numpy()}")
print(f"Accuracy: {acc.item():.4f}")
print(f"Jaccard Index (Mean IoU): {jaccard.item():.4f}")


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import torchmetrics
from sklearn.preprocessing import label_binarize
from sklearn.metrics import roc_curve, auc

# Class labels for the segmentation task
# Display names indexed by class id. They must follow the label encoding used
# by the masks (0 = background/unlabeled, 1 = building, 2 = woodland,
# 3 = water, 4 = road); the previous order started at 'Building' and shifted
# every per-class report, confusion-matrix tick and ROC legend by one class.
classes = ['Unlabeled', 'Building', 'Woodland', 'Water', 'Road']

def segmentation_test_loop_with_metrics(model, test_loader, num_classes=5, device="cpu"):
    """
    Evaluate a segmentation model pixel-wise on ``test_loader``.

    Every pixel is treated as one multiclass prediction. The predicted class
    is the argmax of the softmax over output dimension 1.

    Args:
        model: The trained model to evaluate.
        test_loader: DataLoader yielding (images, labels) batches.
        num_classes: Number of classes in the dataset.
        device: Device to run the evaluation on; must match the model's device.

    Returns:
        Tuple ``(precision, recall, f1_score, confusion_matrix, acc, jaccard,
        class_probs)``:
        - precision, recall, f1_score: per-class tensors of length num_classes.
        - confusion_matrix: (num_classes, num_classes) tensor; rows are true
          classes, columns predicted classes (torchmetrics convention).
        - acc: overall pixel accuracy (micro average).
        - jaccard: IoU averaged over classes (torchmetrics default "macro").
        - class_probs: dict mapping class id to the mean softmax confidence
          of the pixels predicted as that class (0.0 if never predicted).
    """
    common = dict(task='multiclass', num_classes=num_classes)
    precision = torchmetrics.Precision(average=None, **common).to(device)
    recall = torchmetrics.Recall(average=None, **common).to(device)
    f1_score = torchmetrics.F1Score(average=None, **common).to(device)
    acc = torchmetrics.Accuracy(average="micro", multidim_average="global", **common).to(device)
    jaccard = torchmetrics.JaccardIndex(**common).to(device)
    confusion_matrix = torchmetrics.ConfusionMatrix(**common).to(device)
    metrics = (precision, recall, f1_score, acc, jaccard, confusion_matrix)

    # Per-class running sums, kept on `device`. This avoids a host sync for
    # every class of every batch, which per-class `.item()` calls would force.
    prob_sums = torch.zeros(num_classes, dtype=torch.float64, device=device)
    pixel_counts = torch.zeros(num_classes, dtype=torch.long, device=device)

    model.eval()
    with torch.no_grad():
        for X, y in test_loader:
            X = X.to(device)
            y = y.to(device)

            class_scores = F.softmax(model(X), dim=1)  # probabilities, not logits
            # Confidence and index of the most likely class for every pixel.
            probs, preds = torch.max(class_scores, dim=1)

            flat_preds = preds.flatten()
            prob_sums += torch.bincount(flat_preds, weights=probs.flatten().double(),
                                        minlength=num_classes)[:num_classes]
            pixel_counts += torch.bincount(flat_preds, minlength=num_classes)[:num_classes]

            for metric in metrics:
                metric.update(preds, y)

    # Mean confidence per predicted class.
    class_probs = {}
    for label in range(num_classes):
        count = int(pixel_counts[label].item())
        class_probs[label] = prob_sums[label].item() / count if count > 0 else 0.0

    return (precision.compute(), recall.compute(), f1_score.compute(),
            confusion_matrix.compute(), acc.compute(), jaccard.compute(), class_probs)


def plot_confusion_matrix(confusion_matrix, class_names, normalize=False):
    """
    Plot a confusion matrix as an annotated Seaborn heatmap.

    Args:
        confusion_matrix: Square tensor (anything with ``.cpu()``) or
            array-like of counts. Rows are true labels, columns predicted
            labels.
        class_names: Tick labels, indexed by class id.
        normalize: If True, divide each row by its total, so each cell shows
            the fraction of that true class assigned to each prediction (the
            diagonal becomes per-class recall). Pixel-level counts can be very
            large and hard to read as integers.
    """
    if hasattr(confusion_matrix, "cpu"):
        cm = confusion_matrix.cpu().numpy()
    else:
        cm = np.asarray(confusion_matrix)
    fmt = "d"
    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        # Rows with no true pixels stay all-zero instead of dividing by zero.
        cm = np.divide(cm, row_totals, out=np.zeros(cm.shape, dtype=float),
                       where=row_totals > 0)
        fmt = ".2f"
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt=fmt, cmap="Blues", xticklabels=class_names, yticklabels=class_names)
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.show()


def plot_multiclass_roc(model, test_loader, num_classes=5, device="cpu",
                        pixels_per_batch=20000, seed=0, class_names=None):
    """
    Plot one-vs-rest ROC curves, one per class, from pixel-level scores.

    The model output is treated as (N, C, H, W) class scores and the labels as
    (N, H, W). Both are flattened to one row per pixel, because scikit-learn's
    ROC utilities only accept 1-D targets and scores. To keep memory bounded,
    at most ``pixels_per_batch`` randomly chosen pixels are kept from each
    batch.

    Args:
        model: The trained model.
        test_loader: DataLoader yielding (images, labels) batches.
        num_classes: Number of classes.
        device: Device to run the evaluation on.
        pixels_per_batch: Pixel sample size per batch, or None to keep all
            pixels (may need a lot of memory on a full test set).
        seed: Seed for the pixel sampling, for reproducible curves.
        class_names: Legend names indexed by class id (default: the global
            ``classes`` list).
    """
    names = classes if class_names is None else class_names
    generator = torch.Generator().manual_seed(seed)
    y_true = []
    y_scores = []

    model.eval()
    with torch.no_grad():
        for X, y in test_loader:
            probs = F.softmax(model(X.to(device)), dim=1).cpu()
            # (N, C, H, W) -> (N*H*W, C) and (N, H, W) -> (N*H*W,)
            probs = probs.permute(0, 2, 3, 1).reshape(-1, probs.shape[1])
            labels = y.reshape(-1).cpu()
            if pixels_per_batch is not None and labels.numel() > pixels_per_batch:
                keep = torch.randperm(labels.numel(), generator=generator)[:pixels_per_batch]
                probs, labels = probs[keep], labels[keep]
            y_true.append(labels)
            y_scores.append(probs)

    y_true = torch.cat(y_true).numpy()
    y_scores = torch.cat(y_scores).numpy()

    plt.figure(figsize=(10, 8))
    for i in range(num_classes):
        positives = (y_true == i).astype(int)
        # A ROC curve is undefined when a class has no positives or no negatives.
        if positives.min() == positives.max():
            print(f"Skipping ROC curve for class {names[i]}: only one outcome present.")
            continue
        fpr, tpr, _ = roc_curve(positives, y_scores[:, i])
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, lw=2, label=f'ROC curve of class {names[i]} (area = {roc_auc:.2f})')

    plt.plot([0, 1], [0, 1], 'k--', lw=2)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic for Multi-Class')
    plt.legend(loc="lower right")
    plt.show()


# Evaluate UNet+ResNet-50 on the test split and print a labelled table.
(precision, recall, f1_score, confusion_matrix,
 acc, jaccard, class_probs) = segmentation_test_loop_with_metrics(
    model=model_resnet50,
    test_loader=test_dloader,
    num_classes=5,
    device=device,
)

# Table rows follow the order of `classes` (indexed by class id).
print(f"{'Class':<15}{'Precision':<10}{'Recall':<10}{'F1-Score':<10}")
for i, class_name in enumerate(classes):
    p, r, f = precision[i].item(), recall[i].item(), f1_score[i].item()
    print(f"{class_name:<15}{p:<10.4f}{r:<10.4f}{f:<10.4f}")
print(f"\nAccuracy: {acc.item():.4f}")
print(f"Mean IoU (Jaccard Index): {jaccard.item():.4f}")




In [None]:
# Heatmap of the test-set confusion matrix, ticks labelled by class name.
plot_confusion_matrix(confusion_matrix, class_names=classes)

In [None]:
def plot_multiclass_roc(model, test_loader, num_classes=5, device="cpu",
                        pixels_per_batch=20000, seed=0, class_names=None):
    """
    Plot one-vs-rest ROC curves, one per class, from pixel-level scores.

    The model output is treated as (N, C, H, W) class scores and the labels as
    (N, H, W). Both are flattened to one row per pixel, because scikit-learn's
    ROC utilities only accept 1-D targets and scores. To keep memory bounded,
    at most ``pixels_per_batch`` randomly chosen pixels are kept from each
    batch.

    Args:
        model: The trained model.
        test_loader: DataLoader yielding (images, labels) batches.
        num_classes: Number of classes.
        device: Device to run the evaluation on.
        pixels_per_batch: Pixel sample size per batch, or None to keep all
            pixels (may need a lot of memory on a full test set).
        seed: Seed for the pixel sampling, for reproducible curves.
        class_names: Legend names indexed by class id (default: the global
            ``classes`` list).
    """
    names = classes if class_names is None else class_names
    generator = torch.Generator().manual_seed(seed)
    y_true = []
    y_scores = []

    model.eval()
    with torch.no_grad():
        for X, y in test_loader:
            probs = F.softmax(model(X.to(device)), dim=1).cpu()  # probability scores
            # (N, C, H, W) -> (N*H*W, C) and (N, H, W) -> (N*H*W,)
            probs = probs.permute(0, 2, 3, 1).reshape(-1, probs.shape[1])
            labels = y.reshape(-1).cpu()
            if pixels_per_batch is not None and labels.numel() > pixels_per_batch:
                keep = torch.randperm(labels.numel(), generator=generator)[:pixels_per_batch]
                probs, labels = probs[keep], labels[keep]
            y_true.append(labels)   # shape (P,)
            y_scores.append(probs)  # shape (P, num_classes)

    y_true = torch.cat(y_true).numpy()
    y_scores = torch.cat(y_scores).numpy()

    plt.figure(figsize=(10, 8))
    for i in range(num_classes):
        positives = (y_true == i).astype(int)
        # A ROC curve is undefined when a class has no positives or no negatives.
        if positives.min() == positives.max():
            print(f"Skipping ROC curve for class {names[i]}: only one outcome present.")
            continue
        fpr, tpr, _ = roc_curve(positives, y_scores[:, i])
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, lw=2, label=f'ROC curve of class {names[i]} (area = {roc_auc:.2f})')

    plt.plot([0, 1], [0, 1], 'k--', lw=2)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic for Multi-Class')
    plt.legend(loc="lower right")
    plt.show()

In [None]:
# One-vs-rest ROC curves of UNet+ResNet-50 on the test split.
plot_multiclass_roc(
    model=model_resnet50,
    test_loader=test_dloader,
    num_classes=5,
    device=device,
)

In [None]:
# Persist only the weights (state_dict) of the ResNet-50 U-Net.
torch.save(obj=model_resnet50.state_dict(), f="resnet50.pth")

## ResNet-101 encoder

In [None]:
# U-Net with an ImageNet-pretrained ResNet-101 encoder and a 5-class output head.
model_resnet101 = smp.Unet(
    encoder_name="resnet101",
    encoder_weights="imagenet",
    classes=5,
)
model_resnet101 = model_resnet101.to(device)

In [None]:
def count_parameters(model):
    """Count the scalar entries of every parameter that will receive gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

trainable_params = count_parameters(model_resnet101)  # UNet+ResNet-101 size
print("Total trainable parameters: {}".format(trainable_params))

In [None]:
epochs = 50

# Training starts! Same settings as the ResNet-50 run, different encoder.
training_loop(
    model_resnet101,
    train_dloader,
    val_dloader,
    epochs,
    lr,
    loss_fn,
    mod_epochs=1,
    regularization="L2",
    reg_lambda=1e-6,
    early_stopping=False,
    patience=5,
    verbose=True,
    model_title="UNet with Resnet encoder 101",
    save=True,
    stopping_criterion="loss",
)

In [None]:
# torch.save(model_resnet101.state_dict(), "resnet101.pth")

In [None]:
def visualize_preds(model, train_set, title, num_samples = 4, seed = 42,
                    w = 10, h = 10, save_title = None, indices = None):
    """
    Plot sample images next to their ground-truth masks and model predictions.

    Each row shows one sample: the input image, its label mask and the
    arg-max prediction of ``model``. Both masks are drawn with ``labels_cmap``.

    Args:
        model: Segmentation model whose output has the class axis at dim 1.
        train_set: Indexable dataset returning ``(X, y)`` pairs, where ``X`` is
            a (C, H, W) image tensor and ``y`` the label mask. Any split works.
        title: Figure title.
        num_samples: Number of random samples to draw when ``indices`` is None.
        seed: Seed applied to NumPy's global RNG on every call.
        w: Figure width in inches.
        h: Figure height in inches.
        save_title: If given, the figure is also saved to ``save_title + ".png"``.
        indices: Explicit dataset indices to show. When given, one row is drawn
            per index and ``num_samples`` is ignored.
    """
    # Feed inputs to whichever device holds the model's weights, so the two
    # never end up on different devices.
    device = next(model.parameters()).device
    np.random.seed(seed)
    # `is None`, not `== None`: with a NumPy array of indices `== None` is
    # elementwise and the `if` would raise.
    if indices is None:
        indices = np.random.randint(low = 0, high = len(train_set),
                                    size = num_samples)
    sns.set_style("white")
    # One row per requested index. squeeze=False keeps `ax` 2-D even for a
    # single row, so ax[i, j] indexing always works.
    fig, ax = plt.subplots(figsize = (w,h), nrows = len(indices), ncols = 3,
                           squeeze = False)
    model.eval()
    for i, idx in enumerate(indices):
        X, y = train_set[idx]
        with torch.no_grad():
            logits = model(X[None, :, :, :].to(device))
        preds = torch.argmax(logits, dim = 1).squeeze().cpu().numpy()

        # (C, H, W) -> (H, W, C), the layout imshow expects
        image = X.detach().cpu().permute(1, 2, 0).numpy()
        mask = y.cpu().numpy() if torch.is_tensor(y) else np.asarray(y)

        ax[i,0].imshow(image)
        ax[i,0].set_title("True Image")
        ax[i,0].axis("off")
        # "nearest" keeps the categorical masks crisp; interpolation=None would
        # fall back to the rcParams default and blend neighbouring label colours.
        ax[i,1].imshow(mask, cmap = labels_cmap, interpolation = "nearest",
                      vmin = -0.5, vmax = 4.5)
        ax[i,1].set_title("Labels")
        ax[i,1].axis("off")
        ax[i,2].imshow(preds, cmap = labels_cmap, interpolation = "nearest",
                      vmin = -0.5, vmax = 4.5)
        ax[i,2].set_title("Predictions")
        ax[i,2].axis("off")
    fig.suptitle(title, fontsize = 20)
    plt.tight_layout()
    if save_title is not None:
        plt.savefig(save_title + ".png")
    plt.show()

# Four fixed test samples so the ResNet variants can be compared side by side
visualize_preds(
    model_resnet101,
    test_set,
    title="Predictions - UNet+Resnet101",
    save_title="UNet+Resnet101",
    h=12,
    w=12,
    indices=[957, 961, 1476, 1578],
)

In [None]:
import torchmetrics
import torchvision.transforms.functional as TF
import torch.nn.functional as F

def segmentation_test_loop_with_metrics(model, test_loader, num_classes=5, device="cpu"):
    """
    Evaluate a segmentation model on a test set with pixel-level metrics.

    Every output position (each pixel) is scored as one sample. The predicted
    class is the arg-max of the softmax over dim 1 (the class axis) of the
    model output.

    Args:
        model: The trained model to evaluate. It is not moved here, so it must
            already live on ``device``.
        test_loader: Dataloader yielding ``(X, y)`` batches of the test dataset.
        num_classes: The number of classes in the dataset.
        device: Device to run the evaluation on ("cpu" or "cuda").

    Returns:
        A 7-tuple ``(precision, recall, f1_score, confusion_matrix, acc,
        jaccard, class_probs)``:
        precision: Precision score for each class.
        recall: Recall score for each class.
        f1_score: F1 score for each class.
        confusion_matrix: (num_classes, num_classes) confusion matrix.
        acc: Overall (micro-averaged) pixel accuracy.
        jaccard: Jaccard index (IoU) over the entire test set, with
            torchmetrics' default averaging over classes.
        class_probs: Dict mapping each class index to the mean softmax
            probability of the pixels predicted as that class (0 if the class
            was never predicted).
    """
    # Per-class metrics (average=None keeps one value per class)
    precision = torchmetrics.Precision(task='multiclass', num_classes=num_classes, average=None).to(device)
    recall = torchmetrics.Recall(task='multiclass', num_classes=num_classes, average=None).to(device)
    f1_score = torchmetrics.F1Score(task='multiclass', num_classes=num_classes, average=None).to(device)

    # Aggregate metrics
    acc = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes, average="micro", multidim_average="global").to(device)
    jaccard = torchmetrics.JaccardIndex(task='multiclass', num_classes=num_classes).to(device)
    confusion_matrix = torchmetrics.ConfusionMatrix(task='multiclass', num_classes=num_classes).to(device)
    metrics = (precision, recall, f1_score, acc, jaccard, confusion_matrix)

    model.eval()

    # Running totals stay on `device` and are read back once at the end. This
    # avoids two host-device syncs (.item()) per class per batch.
    prob_sums = torch.zeros(num_classes, dtype=torch.float64, device=device)
    pixel_counts = torch.zeros(num_classes, dtype=torch.int64, device=device)

    with torch.no_grad():
        for X, y in test_loader:
            X = X.to(device)
            y = y.to(device)

            probabilities = F.softmax(model(X), dim=1)
            # Winning-class probability and its index for every pixel
            probs, preds = torch.max(probabilities, dim=1)

            for label in range(num_classes):
                predicted_as_label = preds == label
                prob_sums[label] += probs[predicted_as_label].sum()
                pixel_counts[label] += predicted_as_label.sum()

            for metric in metrics:
                metric.update(preds, y)

    # Mean confidence per predicted class; 0 for classes never predicted
    class_probs = {}
    for label, (prob_sum, count) in enumerate(zip(prob_sums.tolist(), pixel_counts.tolist())):
        class_probs[label] = prob_sum / count if count > 0 else 0

    return (precision.compute(), recall.compute(), f1_score.compute(),
            confusion_matrix.compute(), acc.compute(), jaccard.compute(),
            class_probs)


# Evaluate the ResNet-101 U-Net on the test set
precision, recall, f1_score, confusion_matrix, acc, jaccard, class_probs = (
    segmentation_test_loop_with_metrics(
        model=model_resnet101,
        test_loader=test_dloader,
        num_classes=5,
        device=device,
    )
)

# Per-class scores first, then the aggregate metrics
for i, (p, r, f) in enumerate(zip(precision, recall, f1_score)):
    print(f"Class {i}: Precision: {p.item():.4f}, Recall: {r.item():.4f}, F1-Score: {f.item():.4f}")
print(f"Confusion Matrix:\n {confusion_matrix.cpu().numpy()}")
print(f"Accuracy: {acc.item():.4f}")
print(f"Jaccard Index (Mean IoU): {jaccard.item():.4f}")


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import torchmetrics
from sklearn.preprocessing import label_binarize
from sklearn.metrics import roc_curve, auc

# Class labels for the segmentation task, in class-index order:
# classes[i] is the name shown for predicted/true class index i.
# NOTE(review): the LandCover.ai masks are documented as 0 = background,
# 1 = building, 2 = woodland, 3 = water, 4 = road. Unless the dataset code
# remaps labels, this list is shifted by one. Confirm against labels_cmap.
classes = [
    'Building',
    'Woodland',
    'Water',
    'Road',
    'Unlabeled',
]

def segmentation_test_loop_with_metrics(model, test_loader, num_classes=5, device="cpu"):
    """
    Evaluate a segmentation model on a test set with pixel-level metrics.

    Every output position (each pixel) is scored as one sample. The predicted
    class is the arg-max of the softmax over dim 1 (the class axis) of the
    model output.

    Args:
        model: The trained model to evaluate. It is not moved here, so it must
            already live on ``device``.
        test_loader: Dataloader yielding ``(X, y)`` batches of the test dataset.
        num_classes: The number of classes in the dataset.
        device: Device to run the evaluation on ("cpu" or "cuda").

    Returns:
        A 7-tuple ``(precision, recall, f1_score, confusion_matrix, acc,
        jaccard, class_probs)``:
        precision: Precision score for each class.
        recall: Recall score for each class.
        f1_score: F1 score for each class.
        confusion_matrix: (num_classes, num_classes) confusion matrix.
        acc: Overall (micro-averaged) pixel accuracy.
        jaccard: Jaccard index (IoU) over the entire test set, with
            torchmetrics' default averaging over classes.
        class_probs: Dict mapping each class index to the mean softmax
            probability of the pixels predicted as that class (0 if the class
            was never predicted).
    """
    # Per-class metrics (average=None keeps one value per class)
    precision = torchmetrics.Precision(task='multiclass', num_classes=num_classes, average=None).to(device)
    recall = torchmetrics.Recall(task='multiclass', num_classes=num_classes, average=None).to(device)
    f1_score = torchmetrics.F1Score(task='multiclass', num_classes=num_classes, average=None).to(device)

    # Aggregate metrics
    acc = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes, average="micro", multidim_average="global").to(device)
    jaccard = torchmetrics.JaccardIndex(task='multiclass', num_classes=num_classes).to(device)
    confusion_matrix = torchmetrics.ConfusionMatrix(task='multiclass', num_classes=num_classes).to(device)
    metrics = (precision, recall, f1_score, acc, jaccard, confusion_matrix)

    model.eval()

    # Running totals stay on `device` and are read back once at the end. This
    # avoids two host-device syncs (.item()) per class per batch.
    prob_sums = torch.zeros(num_classes, dtype=torch.float64, device=device)
    pixel_counts = torch.zeros(num_classes, dtype=torch.int64, device=device)

    with torch.no_grad():
        for X, y in test_loader:
            X = X.to(device)
            y = y.to(device)

            probabilities = F.softmax(model(X), dim=1)
            # Winning-class probability and its index for every pixel
            probs, preds = torch.max(probabilities, dim=1)

            for label in range(num_classes):
                predicted_as_label = preds == label
                prob_sums[label] += probs[predicted_as_label].sum()
                pixel_counts[label] += predicted_as_label.sum()

            for metric in metrics:
                metric.update(preds, y)

    # Mean confidence per predicted class; 0 for classes never predicted
    class_probs = {}
    for label, (prob_sum, count) in enumerate(zip(prob_sums.tolist(), pixel_counts.tolist())):
        class_probs[label] = prob_sum / count if count > 0 else 0

    return (precision.compute(), recall.compute(), f1_score.compute(),
            confusion_matrix.compute(), acc.compute(), jaccard.compute(),
            class_probs)


def plot_confusion_matrix(confusion_matrix, class_names, normalize=False):
    """
    Plot a confusion matrix as an annotated Seaborn heatmap.

    Args:
        confusion_matrix: Square matrix with true labels along the rows and
            predicted labels along the columns (matching the axis titles).
            A torch tensor on any device or any array-like is accepted.
        class_names: Names for the matrix indices, in index order.
        normalize: If True, divide every row by its total so each cell shows the
            fraction of that true class given each prediction. Rows summing to
            zero stay zero. The default (False) plots the raw values.
    """
    if torch.is_tensor(confusion_matrix):
        confusion_matrix = confusion_matrix.detach().cpu().numpy()
    matrix = np.asarray(confusion_matrix)

    if normalize:
        row_totals = matrix.sum(axis=1, keepdims=True)
        matrix = np.divide(matrix, row_totals,
                           out=np.zeros(matrix.shape, dtype=float),
                           where=row_totals != 0)

    # Integer counts are annotated as integers; anything else gets two
    # decimals ("d" raises on float values).
    fmt = "d" if np.issubdtype(matrix.dtype, np.integer) else ".2f"

    plt.figure(figsize=(8, 6))
    sns.heatmap(matrix, annot=True, fmt=fmt, cmap="Blues", xticklabels=class_names, yticklabels=class_names)
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.show()


def plot_multiclass_roc(model, test_loader, num_classes=5, device="cpu"):
    """
    Plot one-vs-rest ROC curves for every class of a segmentation model.

    Each pixel of each test image is one sample. The softmax score for class
    ``i`` at that pixel is ranked against whether the true label there is ``i``.
    Legend names come from the module-level ``classes`` list.

    Args:
        model: The trained model. Its output must have the class axis at dim 1,
            e.g. (batch, num_classes, H, W). A plain (batch, num_classes)
            classifier output works too.
        test_loader: Dataloader yielding ``(X, y)`` batches, where ``y`` holds
            integer labels with one entry per output pixel.
        num_classes: Number of classes.
        device: Device to run the evaluation on; the model must already be there.
    """
    y_true = []
    y_scores = []

    model.eval()
    with torch.no_grad():
        for X, y in test_loader:
            scores = F.softmax(model(X.to(device)), dim=1)
            # Move the class axis last before flattening. Reshaping
            # (N, C, H, W) straight to (-1, C) would mix classes and pixels.
            scores = torch.movedim(scores, 1, -1).reshape(-1, num_classes)
            y_scores.append(scores.cpu())
            y_true.append(y.reshape(-1).cpu())

    # One row per pixel: labels (P,) and class scores (P, num_classes)
    y_true = torch.cat(y_true).numpy()
    y_scores = torch.cat(y_scores).numpy()

    fpr = dict()
    tpr = dict()
    roc_auc = dict()

    for i in range(num_classes):
        # One-vs-rest target for class i. A boolean mask stands in for
        # label_binarize, which rejects multi-dimensional label arrays.
        fpr[i], tpr[i], _ = roc_curve(y_true == i, y_scores[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    # Plotting all ROC curves
    plt.figure(figsize=(10, 8))
    for i in range(num_classes):
        plt.plot(fpr[i], tpr[i], lw=2, label=f'ROC curve of class {classes[i]} (area = {roc_auc[i]:.2f})')
    
    plt.plot([0, 1], [0, 1], 'k--', lw=2)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic for Multi-Class')
    plt.legend(loc="lower right")
    plt.show()


# Evaluate the ResNet-101 U-Net and collect every metric
precision, recall, f1_score, confusion_matrix, acc, jaccard, class_probs = (
    segmentation_test_loop_with_metrics(model=model_resnet101, test_loader=test_dloader,
                                        num_classes=5, device=device)
)

# Per-class table (name, precision, recall, F1), then the aggregate metrics
print(f"{'Class':<15}{'Precision':<10}{'Recall':<10}{'F1-Score':<10}")
for i, class_name in enumerate(classes):
    row_values = (precision[i].item(), recall[i].item(), f1_score[i].item())
    print(f"{class_name:<15}" + "".join(f"{value:<10.4f}" for value in row_values))
print(f"\nAccuracy: {acc.item():.4f}")
print(f"Mean IoU (Jaccard Index): {jaccard.item():.4f}")




In [None]:
# Heatmap of the pixel-level confusion matrix, labelled with the class names
plot_confusion_matrix(confusion_matrix=confusion_matrix, class_names=classes)

In [None]:
def plot_multiclass_roc1(model, test_loader, num_classes=5, device="cpu"):
    """
    Plot one-vs-rest ROC curves for every class of a segmentation model.

    Each pixel of each test image is one sample. The softmax score for class
    ``i`` at that pixel is ranked against whether the true label there is ``i``.
    Legend names come from the module-level ``classes`` list.

    Args:
        model: The trained model. Its output must have the class axis at dim 1,
            e.g. (batch, num_classes, H, W). A plain (batch, num_classes)
            classifier output works too.
        test_loader: Dataloader yielding ``(X, y)`` batches, where ``y`` holds
            integer labels with one entry per output pixel.
        num_classes: Number of classes.
        device: Device to run the evaluation on; the model must already be there.
    """
    y_true = []
    y_scores = []

    model.eval()
    with torch.no_grad():
        for X, y in test_loader:
            scores = F.softmax(model(X.to(device)), dim=1)  # Get probability scores
            # Move the class axis last before flattening. Reshaping
            # (N, C, H, W) straight to (-1, C) would mix classes and pixels.
            scores = torch.movedim(scores, 1, -1).reshape(-1, num_classes)
            y_scores.append(scores.cpu())  # Predicted probabilities, one row per pixel
            y_true.append(y.reshape(-1).cpu())  # True labels, one entry per pixel

    # One row per pixel: labels (P,) and class scores (P, num_classes)
    y_true = torch.cat(y_true).numpy()
    y_scores = torch.cat(y_scores).numpy()

    fpr = dict()
    tpr = dict()
    roc_auc = dict()

    for i in range(num_classes):
        # One-vs-rest target for class i. A boolean mask stands in for
        # label_binarize, which rejects multi-dimensional label arrays.
        fpr[i], tpr[i], _ = roc_curve(y_true == i, y_scores[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    # Plotting all ROC curves
    plt.figure(figsize=(10, 8))
    for i in range(num_classes):
        plt.plot(fpr[i], tpr[i], lw=2, label=f'ROC curve of class {classes[i]} (area = {roc_auc[i]:.2f})')
    
    plt.plot([0, 1], [0, 1], 'k--', lw=2)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic for Multi-Class')
    plt.legend(loc="lower right")
    plt.show()

In [None]:
# One-vs-rest ROC curves for the ResNet-101 U-Net on the test set
plot_multiclass_roc1(
    model=model_resnet101,
    test_loader=test_dloader,
    num_classes=5,
    device=device,
)

## ResNet-152

In [None]:
# U-Net with an ImageNet-pretrained ResNet-152 encoder and 5 output classes.
# NOTE(review): unlike the ResNet-101 model this one is placed on the CPU
# (the torchsummary call uses device="cpu"); presumably training_loop moves it
# to the GPU. Confirm before evaluating with device=device.
model_resnet152 = smp.Unet(
    encoder_name="resnet152",
    encoder_weights="imagenet",
    classes=5,
).to(device="cpu")

In [None]:
!pip install torchsummary

In [None]:
from torchsummary import summary
# Layer-by-layer summary for a 3x512x512 input, computed on the CPU
summary(
    model_resnet152,
    input_size=(3, 512, 512),
    device="cpu",
)

In [None]:
def count_parameters(model):
    """Count the elements of every parameter of *model* that requires gradients."""
    trainable = filter(lambda param: param.requires_grad, model.parameters())
    return sum(map(lambda param: param.numel(), trainable))

# Only parameters with requires_grad=True are counted
trainable_params = count_parameters(model=model_resnet152)
print("Total trainable parameters: " + str(trainable_params))

In [None]:
from landcoverutil import training_loop

epochs = 50

# Training starts! (U-Net + ResNet-152 encoder)
training_loop(
    model_resnet152, train_dloader, val_dloader, epochs, lr, loss_fn,
    mod_epochs=1,
    regularization="L2",
    reg_lambda=1e-6,
    early_stopping=False,
    patience=5,
    verbose=True,
    model_title="UNet with Resnet encoder 152",
    save=True,
    stopping_criterion="loss",
)

In [None]:
def visualize_preds(model, train_set, title, num_samples = 4, seed = 42,
                    w = 10, h = 10, save_title = None, indices = None):
    """
    Plot sample images next to their ground-truth masks and model predictions.

    Each row shows one sample: the input image, its label mask and the
    arg-max prediction of ``model``. Both masks are drawn with ``labels_cmap``.

    Args:
        model: Segmentation model whose output has the class axis at dim 1.
        train_set: Indexable dataset returning ``(X, y)`` pairs, where ``X`` is
            a (C, H, W) image tensor and ``y`` the label mask. Any split works.
        title: Figure title.
        num_samples: Number of random samples to draw when ``indices`` is None.
        seed: Seed applied to NumPy's global RNG on every call.
        w: Figure width in inches.
        h: Figure height in inches.
        save_title: If given, the figure is also saved to ``save_title + ".png"``.
        indices: Explicit dataset indices to show. When given, one row is drawn
            per index and ``num_samples`` is ignored.
    """
    # Feed inputs to whichever device holds the model's weights, so the two
    # never end up on different devices.
    device = next(model.parameters()).device
    np.random.seed(seed)
    # `is None`, not `== None`: with a NumPy array of indices `== None` is
    # elementwise and the `if` would raise.
    if indices is None:
        indices = np.random.randint(low = 0, high = len(train_set),
                                    size = num_samples)
    sns.set_style("white")
    # One row per requested index. squeeze=False keeps `ax` 2-D even for a
    # single row, so ax[i, j] indexing always works.
    fig, ax = plt.subplots(figsize = (w,h), nrows = len(indices), ncols = 3,
                           squeeze = False)
    model.eval()
    for i, idx in enumerate(indices):
        X, y = train_set[idx]
        with torch.no_grad():
            logits = model(X[None, :, :, :].to(device))
        preds = torch.argmax(logits, dim = 1).squeeze().cpu().numpy()

        # (C, H, W) -> (H, W, C), the layout imshow expects
        image = X.detach().cpu().permute(1, 2, 0).numpy()
        mask = y.cpu().numpy() if torch.is_tensor(y) else np.asarray(y)

        ax[i,0].imshow(image)
        ax[i,0].set_title("True Image")
        ax[i,0].axis("off")
        # "nearest" keeps the categorical masks crisp; interpolation=None would
        # fall back to the rcParams default and blend neighbouring label colours.
        ax[i,1].imshow(mask, cmap = labels_cmap, interpolation = "nearest",
                      vmin = -0.5, vmax = 4.5)
        ax[i,1].set_title("Labels")
        ax[i,1].axis("off")
        ax[i,2].imshow(preds, cmap = labels_cmap, interpolation = "nearest",
                      vmin = -0.5, vmax = 4.5)
        ax[i,2].set_title("Predictions")
        ax[i,2].axis("off")
    fig.suptitle(title, fontsize = 20)
    plt.tight_layout()
    if save_title is not None:
        plt.savefig(save_title + ".png")
    plt.show()

# Same four fixed test samples as for ResNet-101, for a direct comparison
visualize_preds(
    model_resnet152,
    test_set,
    title="Predictions - UNet+Resnet152",
    save_title="UNet+Resnet152",
    h=12,
    w=12,
    indices=[957, 961, 1476, 1578],
)

In [None]:
import torchmetrics
import torchvision.transforms.functional as TF
import torch.nn.functional as F

def segmentation_test_loop_with_metrics(model, test_loader, num_classes=5, device="cpu"):
    """
    Evaluate a segmentation model on a test set with pixel-level metrics.

    Every output position (each pixel) is scored as one sample. The predicted
    class is the arg-max of the softmax over dim 1 (the class axis) of the
    model output.

    Args:
        model: The trained model to evaluate. It is not moved here, so it must
            already live on ``device``.
        test_loader: Dataloader yielding ``(X, y)`` batches of the test dataset.
        num_classes: The number of classes in the dataset.
        device: Device to run the evaluation on ("cpu" or "cuda").

    Returns:
        A 7-tuple ``(precision, recall, f1_score, confusion_matrix, acc,
        jaccard, class_probs)``:
        precision: Precision score for each class.
        recall: Recall score for each class.
        f1_score: F1 score for each class.
        confusion_matrix: (num_classes, num_classes) confusion matrix.
        acc: Overall (micro-averaged) pixel accuracy.
        jaccard: Jaccard index (IoU) over the entire test set, with
            torchmetrics' default averaging over classes.
        class_probs: Dict mapping each class index to the mean softmax
            probability of the pixels predicted as that class (0 if the class
            was never predicted).
    """
    # Per-class metrics (average=None keeps one value per class)
    precision = torchmetrics.Precision(task='multiclass', num_classes=num_classes, average=None).to(device)
    recall = torchmetrics.Recall(task='multiclass', num_classes=num_classes, average=None).to(device)
    f1_score = torchmetrics.F1Score(task='multiclass', num_classes=num_classes, average=None).to(device)

    # Aggregate metrics
    acc = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes, average="micro", multidim_average="global").to(device)
    jaccard = torchmetrics.JaccardIndex(task='multiclass', num_classes=num_classes).to(device)
    confusion_matrix = torchmetrics.ConfusionMatrix(task='multiclass', num_classes=num_classes).to(device)
    metrics = (precision, recall, f1_score, acc, jaccard, confusion_matrix)

    model.eval()

    # Running totals stay on `device` and are read back once at the end. This
    # avoids two host-device syncs (.item()) per class per batch.
    prob_sums = torch.zeros(num_classes, dtype=torch.float64, device=device)
    pixel_counts = torch.zeros(num_classes, dtype=torch.int64, device=device)

    with torch.no_grad():
        for X, y in test_loader:
            X = X.to(device)
            y = y.to(device)

            probabilities = F.softmax(model(X), dim=1)
            # Winning-class probability and its index for every pixel
            probs, preds = torch.max(probabilities, dim=1)

            for label in range(num_classes):
                predicted_as_label = preds == label
                prob_sums[label] += probs[predicted_as_label].sum()
                pixel_counts[label] += predicted_as_label.sum()

            for metric in metrics:
                metric.update(preds, y)

    # Mean confidence per predicted class; 0 for classes never predicted
    class_probs = {}
    for label, (prob_sum, count) in enumerate(zip(prob_sums.tolist(), pixel_counts.tolist())):
        class_probs[label] = prob_sum / count if count > 0 else 0

    return (precision.compute(), recall.compute(), f1_score.compute(),
            confusion_matrix.compute(), acc.compute(), jaccard.compute(),
            class_probs)


# Evaluate the ResNet-152 U-Net on the test set
precision, recall, f1_score, confusion_matrix, acc, jaccard, class_probs = (
    segmentation_test_loop_with_metrics(
        model=model_resnet152,
        test_loader=test_dloader,
        num_classes=5,
        device=device,
    )
)

# Per-class scores first, then the aggregate metrics
for i, (p, r, f) in enumerate(zip(precision, recall, f1_score)):
    print(f"Class {i}: Precision: {p.item():.4f}, Recall: {r.item():.4f}, F1-Score: {f.item():.4f}")
print(f"Confusion Matrix:\n {confusion_matrix.cpu().numpy()}")
print(f"Accuracy: {acc.item():.4f}")
print(f"Jaccard Index (Mean IoU): {jaccard.item():.4f}")


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import torchmetrics
from sklearn.preprocessing import label_binarize
from sklearn.metrics import roc_curve, auc

# Class labels for the segmentation task, in class-index order:
# classes[i] is the name shown for predicted/true class index i.
# NOTE(review): the LandCover.ai masks are documented as 0 = background,
# 1 = building, 2 = woodland, 3 = water, 4 = road. Unless the dataset code
# remaps labels, this list is shifted by one. Confirm against labels_cmap.
classes = [
    'Building',
    'Woodland',
    'Water',
    'Road',
    'Unlabeled',
]

def segmentation_test_loop_with_metrics(model, test_loader, num_classes=5, device="cpu"):
    """
    Evaluate a segmentation model on a test set with pixel-level metrics.

    Every output position (each pixel) is scored as one sample. The predicted
    class is the arg-max of the softmax over dim 1 (the class axis) of the
    model output.

    Args:
        model: The trained model to evaluate. It is not moved here, so it must
            already live on ``device``.
        test_loader: Dataloader yielding ``(X, y)`` batches of the test dataset.
        num_classes: The number of classes in the dataset.
        device: Device to run the evaluation on ("cpu" or "cuda").

    Returns:
        A 7-tuple ``(precision, recall, f1_score, confusion_matrix, acc,
        jaccard, class_probs)``:
        precision: Precision score for each class.
        recall: Recall score for each class.
        f1_score: F1 score for each class.
        confusion_matrix: (num_classes, num_classes) confusion matrix.
        acc: Overall (micro-averaged) pixel accuracy.
        jaccard: Jaccard index (IoU) over the entire test set, with
            torchmetrics' default averaging over classes.
        class_probs: Dict mapping each class index to the mean softmax
            probability of the pixels predicted as that class (0 if the class
            was never predicted).
    """
    # Per-class metrics (average=None keeps one value per class)
    precision = torchmetrics.Precision(task='multiclass', num_classes=num_classes, average=None).to(device)
    recall = torchmetrics.Recall(task='multiclass', num_classes=num_classes, average=None).to(device)
    f1_score = torchmetrics.F1Score(task='multiclass', num_classes=num_classes, average=None).to(device)

    # Aggregate metrics
    acc = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes, average="micro", multidim_average="global").to(device)
    jaccard = torchmetrics.JaccardIndex(task='multiclass', num_classes=num_classes).to(device)
    confusion_matrix = torchmetrics.ConfusionMatrix(task='multiclass', num_classes=num_classes).to(device)
    metrics = (precision, recall, f1_score, acc, jaccard, confusion_matrix)

    model.eval()

    # Running totals stay on `device` and are read back once at the end. This
    # avoids two host-device syncs (.item()) per class per batch.
    prob_sums = torch.zeros(num_classes, dtype=torch.float64, device=device)
    pixel_counts = torch.zeros(num_classes, dtype=torch.int64, device=device)

    with torch.no_grad():
        for X, y in test_loader:
            X = X.to(device)
            y = y.to(device)

            probabilities = F.softmax(model(X), dim=1)
            # Winning-class probability and its index for every pixel
            probs, preds = torch.max(probabilities, dim=1)

            for label in range(num_classes):
                predicted_as_label = preds == label
                prob_sums[label] += probs[predicted_as_label].sum()
                pixel_counts[label] += predicted_as_label.sum()

            for metric in metrics:
                metric.update(preds, y)

    # Mean confidence per predicted class; 0 for classes never predicted
    class_probs = {}
    for label, (prob_sum, count) in enumerate(zip(prob_sums.tolist(), pixel_counts.tolist())):
        class_probs[label] = prob_sum / count if count > 0 else 0

    return (precision.compute(), recall.compute(), f1_score.compute(),
            confusion_matrix.compute(), acc.compute(), jaccard.compute(),
            class_probs)


def plot_confusion_matrix(confusion_matrix, class_names, normalize=False):
    """
    Plot a confusion matrix as an annotated Seaborn heatmap.

    Args:
        confusion_matrix: Square matrix with true labels along the rows and
            predicted labels along the columns (matching the axis titles).
            A torch tensor on any device or any array-like is accepted.
        class_names: Names for the matrix indices, in index order.
        normalize: If True, divide every row by its total so each cell shows the
            fraction of that true class given each prediction. Rows summing to
            zero stay zero. The default (False) plots the raw values.
    """
    if torch.is_tensor(confusion_matrix):
        confusion_matrix = confusion_matrix.detach().cpu().numpy()
    matrix = np.asarray(confusion_matrix)

    if normalize:
        row_totals = matrix.sum(axis=1, keepdims=True)
        matrix = np.divide(matrix, row_totals,
                           out=np.zeros(matrix.shape, dtype=float),
                           where=row_totals != 0)

    # Integer counts are annotated as integers; anything else gets two
    # decimals ("d" raises on float values).
    fmt = "d" if np.issubdtype(matrix.dtype, np.integer) else ".2f"

    plt.figure(figsize=(8, 6))
    sns.heatmap(matrix, annot=True, fmt=fmt, cmap="Blues", xticklabels=class_names, yticklabels=class_names)
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.show()


def plot_multiclass_roc(model, test_loader, num_classes=5, device="cpu"):
    """
    Plot one-vs-rest ROC curves for every class of a segmentation model.

    Each pixel of each test image is one sample. The softmax score for class
    ``i`` at that pixel is ranked against whether the true label there is ``i``.
    Legend names come from the module-level ``classes`` list.

    Args:
        model: The trained model. Its output must have the class axis at dim 1,
            e.g. (batch, num_classes, H, W). A plain (batch, num_classes)
            classifier output works too.
        test_loader: Dataloader yielding ``(X, y)`` batches, where ``y`` holds
            integer labels with one entry per output pixel.
        num_classes: Number of classes.
        device: Device to run the evaluation on; the model must already be there.
    """
    y_true = []
    y_scores = []

    model.eval()
    with torch.no_grad():
        for X, y in test_loader:
            scores = F.softmax(model(X.to(device)), dim=1)
            # Move the class axis last before flattening. Reshaping
            # (N, C, H, W) straight to (-1, C) would mix classes and pixels.
            scores = torch.movedim(scores, 1, -1).reshape(-1, num_classes)
            y_scores.append(scores.cpu())
            y_true.append(y.reshape(-1).cpu())

    # One row per pixel: labels (P,) and class scores (P, num_classes)
    y_true = torch.cat(y_true).numpy()
    y_scores = torch.cat(y_scores).numpy()

    fpr = dict()
    tpr = dict()
    roc_auc = dict()

    for i in range(num_classes):
        # One-vs-rest target for class i. A boolean mask stands in for
        # label_binarize, which rejects multi-dimensional label arrays.
        fpr[i], tpr[i], _ = roc_curve(y_true == i, y_scores[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    # Plotting all ROC curves
    plt.figure(figsize=(10, 8))
    for i in range(num_classes):
        plt.plot(fpr[i], tpr[i], lw=2, label=f'ROC curve of class {classes[i]} (area = {roc_auc[i]:.2f})')
    
    plt.plot([0, 1], [0, 1], 'k--', lw=2)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic for Multi-Class')
    plt.legend(loc="lower right")
    plt.show()


# Evaluate the ResNet-152 U-Net and collect every metric
precision, recall, f1_score, confusion_matrix, acc, jaccard, class_probs = (
    segmentation_test_loop_with_metrics(model=model_resnet152, test_loader=test_dloader,
                                        num_classes=5, device=device)
)

# Per-class table (name, precision, recall, F1), then the aggregate metrics
print(f"{'Class':<15}{'Precision':<10}{'Recall':<10}{'F1-Score':<10}")
for i, class_name in enumerate(classes):
    row_values = (precision[i].item(), recall[i].item(), f1_score[i].item())
    print(f"{class_name:<15}" + "".join(f"{value:<10.4f}" for value in row_values))
print(f"\nAccuracy: {acc.item():.4f}")
print(f"Mean IoU (Jaccard Index): {jaccard.item():.4f}")




In [None]:
# Heatmap of the pixel-level confusion matrix, labelled with the class names
plot_confusion_matrix(confusion_matrix=confusion_matrix, class_names=classes)

In [None]:
import torch
import torch.nn.functional as F
from sklearn.preprocessing import label_binarize
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt
import numpy as np

def plot_multiclass_roc(model, test_loader, classes=('Building', 'Woodland', 'Water', 'Road', 'Unlabeled'), device="cpu"):
    """
    Plot one-vs-rest ROC curves for every class of a segmentation model.

    Each pixel of each test image is one sample. The softmax score for class
    ``i`` at that pixel is ranked against whether the true label there is ``i``.

    Args:
        model: The trained model. Its output must have the class axis at dim 1,
            e.g. (N, num_classes, H, W). A plain (N, num_classes) classifier
            output works too.
        test_loader: Dataloader yielding ``(X, y)`` batches, where ``y`` holds
            integer labels with one entry per output pixel.
        classes: Class names in index order (any sequence). Their count sets the
            number of classes. The default is a tuple, so no mutable object is
            shared between calls.
        device: Device to run the evaluation on; the model must already be there.
    """
    num_classes = len(classes)
    y_true = []
    y_scores = []

    model.eval()
    with torch.no_grad():
        for X, y in test_loader:
            X = X.to(device)

            scores = F.softmax(model(X), dim=1)  # Probability scores, class axis at dim 1
            # Move the class axis last BEFORE flattening. A bare
            # reshape(-1, num_classes) of (N, C, H, W) data would interleave
            # classes and pixels, so each row would not be one pixel's scores.
            scores = torch.movedim(scores, 1, -1).reshape(-1, num_classes)
            y_scores.append(scores.cpu())  # (N * H * W, num_classes) per batch
            y_true.append(y.reshape(-1).cpu())  # (N * H * W,) per batch

    # One row per pixel across the whole test set
    y_true_flat = torch.cat(y_true).numpy()  # Shape: (P,)
    y_scores_flat = torch.cat(y_scores).numpy()  # Shape: (P, num_classes)

    fpr = dict()
    tpr = dict()
    roc_auc = dict()

    for i in range(num_classes):
        # One-vs-rest target for class i (the same values as column i of
        # label_binarize, but also correct when there are only two classes).
        fpr[i], tpr[i], _ = roc_curve(y_true_flat == i, y_scores_flat[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    # Plotting all ROC curves
    plt.figure(figsize=(10, 8))
    for i in range(num_classes):
        plt.plot(fpr[i], tpr[i], lw=2, label=f'ROC curve of class {classes[i]} (area = {roc_auc[i]:.2f})')
    
    plt.plot([0, 1], [0, 1], 'k--', lw=2)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic for Multi-Class')
    plt.legend(loc="lower right")
    plt.show()


In [None]:
# One-vs-rest ROC curves for the ResNet-152 U-Net (default class names)
plot_multiclass_roc(
    model=model_resnet152,
    test_loader=test_dloader,
    device=device,
)

In [None]:
# Persist only the trained weights (state_dict) of the ResNet-152 U-Net
torch.save(obj=model_resnet152.state_dict(), f="resnet152.pth")

Thanks to Christos Nikou (https://github.com/ChrisNick92)