# Setting up the notebook

In [None]:
import warnings

# Settings the warnings to be ignored
warnings.filterwarnings('ignore')

In [None]:
# Uncomment when in Google Colab
!pip install transformers
!pip install torch
!pip install py7zr
!pip install scikit-learn
!pip install Pillow
!pip install pandas
!pip install opencv-python
!pip install albumentations
!pip install evaluate

In [None]:
from torch.utils.data import Dataset, DataLoader
from transformers import AdamW
import torch
from torch import nn
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, recall_score, precision_score, confusion_matrix
from tqdm.notebook import tqdm
import os
from PIL import Image
from transformers import SegformerForSemanticSegmentation, SegformerFeatureExtractor
import pandas as pd
import cv2
import numpy as np
import py7zr
import albumentations as A
import matplotlib.pyplot as plt
import math
from torchvision import transforms

In [None]:
from google.colab import drive
# Make Google Drive (where the dataset archive lives) available under /content/drive.
drive.mount(mountpoint='/content/drive')

In [None]:
!unzip "/content/drive/MyDrive/TrainingDataset_Pan.zip"
# extract files from landfill_dataset.7z
# in_path = './TrainingDataset_Pan.7z'
# out_path = './'
# with py7zr.SevenZipFile(in_path, mode='r') as z:
#     z.extractall(out_path)

# Create Image Segmentation dataset for training and validation

In [None]:
class ImageSegmentationDataset(Dataset):
    """Image segmentation dataset.

    Reads images from ``<root_dir>/images/<split>`` and masks from
    ``<root_dir>/mask/<split>``, where ``<split>`` is ``train`` or ``validation``.
    Images and masks are paired purely by sorted (relative) path, so each mask
    must sort into the same position as its image (e.g. identical file names).
    """

    def __init__(self, root_dir, feature_extractor, transforms=None, train=True):
        """
        Args:
            root_dir (string): Root directory of the dataset containing the images + annotations.
            feature_extractor (SegformerFeatureExtractor): feature extractor to prepare images + segmentation maps.
            transforms (callable, optional): augmentation called as ``transforms(image=..., mask=...)``
                that returns a dict with ``'image'`` and ``'mask'`` (albumentations style).
            train (bool): Whether to load "training" or "validation" images + annotations.

        Raises:
            AssertionError: if the image and mask folders contain different numbers of files.
        """
        self.root_dir = root_dir
        self.feature_extractor = feature_extractor
        self.train = train
        self.transforms = transforms

        sub_path = "train" if self.train else "validation"
        self.img_dir = os.path.join(self.root_dir, "images", sub_path)
        self.ann_dir = os.path.join(self.root_dir, "mask", sub_path)

        self.images = self._list_files(self.img_dir)
        self.annotations = self._list_files(self.ann_dir)
        assert len(self.images) == len(self.annotations), "There must be as many images as there are segmentation maps"

    @staticmethod
    def _list_files(directory):
        """Return every file under ``directory`` (recursively) as a sorted path relative to it."""
        # os.walk also descends into sub-folders; keeping the path relative to
        # `directory` (rather than the bare file name) keeps such files loadable
        # with os.path.join(directory, entry) in __getitem__.
        return sorted(
            os.path.relpath(os.path.join(root, name), directory)
            for root, _dirs, files in os.walk(directory)
            for name in files
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Load pair ``idx`` and return the feature extractor output without its batch dimension.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image or the mask.
        """
        image_path = os.path.join(self.img_dir, self.images[idx])
        mask_path = os.path.join(self.ann_dir, self.annotations[idx])

        image = cv2.imread(image_path)
        if image is None:  # cv2.imread reports unreadable/missing files by returning None
            raise FileNotFoundError(f"Could not read image: {image_path}")
        # OpenCV decodes to BGR. The inference code converts test images to RGB before
        # calling the feature extractor, so convert here as well; otherwise the model is
        # trained on channel-swapped inputs relative to what it sees at test time.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        segmentation_map = cv2.imread(mask_path)
        if segmentation_map is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")
        segmentation_map = cv2.cvtColor(segmentation_map, cv2.COLOR_BGR2GRAY)
        # Convert all 255 to 1 for metrics (assumes masks only contain 0 and 255 — TODO confirm)
        segmentation_map[segmentation_map == 255] = 1

        if self.transforms is not None:
            augmented = self.transforms(image=image, mask=segmentation_map)
            image, segmentation_map = augmented['image'], augmented['mask']

        encoded_inputs = self.feature_extractor(image, segmentation_map, return_tensors="pt")

        # Drop only the leading batch dimension added by return_tensors="pt";
        # a bare squeeze would also collapse any other axis of size 1.
        for key in list(encoded_inputs.keys()):
            encoded_inputs[key] = encoded_inputs[key].squeeze(0)

        return encoded_inputs

In [None]:
# NOTE(review): WIDTH/HEIGHT are not referenced elsewhere in this notebook.
WIDTH = 512
HEIGHT = 512

# Training-time augmentation: random rotation (up to ±90°), horizontal and vertical
# flips, and colour jitter, each applied with probability 0.5. The image and its mask
# are passed to this pipeline together in the dataset's __getitem__.
transform = A.Compose(
    [
        A.Rotate(limit=90, p=0.5),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.ColorJitter(p=0.5),
    ]
)

In [None]:
# Retrieve training set and validation set.
# The path is relative to the working directory that the archive was unzipped into.
root_dir = './TrainingDataset_Pan'
feature_extractor = SegformerFeatureExtractor(reduce_zero_label=False, align=False)

# Only the training split is augmented.
train_dataset = ImageSegmentationDataset(
    root_dir=root_dir,
    feature_extractor=feature_extractor,
    transforms=transform,
)
valid_dataset = ImageSegmentationDataset(
    root_dir=root_dir,
    feature_extractor=feature_extractor,
    transforms=None,
    train=False,
)

In [None]:
print(f"Number of training examples: {len(train_dataset)}")
print(f"Number of validation examples: {len(valid_dataset)}")

In [None]:
# Sanity-check one (augmented) training sample: tensor shapes and the label values present.
encoded_inputs = train_dataset[0]
print("encoded_inputs['pixel_values'] shape: ", encoded_inputs["pixel_values"].shape)
print("encoded_inputs['labels'] shape: ", encoded_inputs["labels"].shape, "unique:", encoded_inputs["labels"].squeeze().unique())
print("encoded_inputs['labels']: ", encoded_inputs['labels'])
# (A bare `.unique()` expression mid-cell displayed nothing and has been removed.)
mask = encoded_inputs["labels"].numpy()

In [None]:
# Visualise the label map of the sample inspected above.
mask = np.asarray(encoded_inputs["labels"])
plt.imshow(mask)

In [None]:
from torch.utils.data import DataLoader
# Training hyperparameters.
batch_size = 1
epochs = 31  # the training loop iterates range(1, epochs), i.e. epochs - 1 passes
lr = 0.001

# Only the training data is shuffled.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, pin_memory=True)

# SegFormer Model Training

In [None]:
# Class names, one per row of the dataset's class dictionary; the row position is the class id.
classes = pd.read_csv(f'{root_dir}/class_dict_seg.csv')['name']
id2label = dict(enumerate(classes))
label2id = {name: class_id for class_id, name in id2label.items()}

In [None]:
# Start from the "nvidia/mit-b5" checkpoint with a head sized for this dataset's classes;
# ignore_mismatched_sizes lets weights whose shapes differ from the checkpoint be re-initialised.
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/mit-b5",
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
    reshape_last_stage=True,
)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Move the model first, then build the optimizer over its (now on-device) parameters.
model.to(device)
# torch.optim.AdamW replaces the deprecated transformers.AdamW. eps=1e-6 matches the
# transformers implementation's default, so the update stays close to the original setup.
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.0001, eps=1e-6)
print(torch.cuda.is_available())
print(device)
print("Model Initialized!")

In [None]:
import evaluate
# Mean-IoU metric from the Hugging Face `evaluate` library, used for validation and test scoring.
# Optional, if GPU memory is tight before training: gc.collect(); torch.cuda.empty_cache()
mean_iou = evaluate.load("mean_iou")

In [None]:
# Per-epoch history, consumed by the plotting code below and in the following cells.
train_loss_values = []
val_loss_values = []
mean_iou_values = []
train_acc_values = []
val_acc_values = []
train_rec_values = []
val_rec_values = []


def _mean(values):
    """Arithmetic mean of a list of numbers; NaN when the list is empty."""
    return sum(values) / len(values) if values else float("nan")


def _batch_metrics(logits, labels):
    """Score one batch of SegFormer logits against its label maps.

    Args:
        logits: model output logits; upsampled here to the labels' spatial size.
        labels: integer label maps for the batch.

    Returns:
        ``(predicted, accuracy, recall)``: the per-pixel argmax prediction and the
        scikit-learn pixel accuracy / binary recall (positive class 1) of the batch.
    """
    # Metrics only: no gradients are needed through the upsampling/argmax.
    with torch.no_grad():
        upsampled_logits = nn.functional.interpolate(logits.detach(), size=labels.shape[-2:], mode="bilinear", align_corners=False)
        predicted = upsampled_logits.argmax(dim=1)
    # 255 is treated as an "ignore" label. The dataset remaps 255 -> 1, so with these
    # masks no pixel is actually excluded.
    valid = labels != 255
    pred_labels = predicted[valid].cpu().numpy()
    true_labels = labels[valid].cpu().numpy()
    return predicted, accuracy_score(true_labels, pred_labels), recall_score(true_labels, pred_labels)


num_labels = model.config.num_labels
for epoch in range(1, epochs):  # epochs 1 .. epochs-1, matching plt_epochs below
    print("Epoch:", epoch)
    accuracies, recalls, losses = [], [], []
    val_accuracies, val_recalls, val_losses = [], [], []

    # ---- training pass ----
    model.train()
    pbar = tqdm(train_dataloader)
    for idx, batch in enumerate(pbar):
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)

        optimizer.zero_grad()
        outputs = model(pixel_values=pixel_values, labels=labels)
        loss = outputs.loss

        _, accuracy, recall = _batch_metrics(outputs.logits, labels)
        accuracies.append(accuracy)
        recalls.append(recall)
        losses.append(loss.item())
        pbar.set_postfix({'Batch': idx, 'Pixel-wise accuracy': _mean(accuracies), 'Recall': _mean(recalls), 'Loss': _mean(losses)})

        loss.backward()
        optimizer.step()

    # ---- validation pass ----
    model.eval()
    with torch.no_grad():
        for batch in valid_dataloader:
            pixel_values = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(pixel_values=pixel_values, labels=labels)
            predicted, accuracy, recall = _batch_metrics(outputs.logits, labels)
            val_accuracies.append(accuracy)
            val_recalls.append(recall)
            val_losses.append(outputs.loss.item())
            # Accumulate every validation batch; the epoch IoU is computed once below
            # (computing it per batch kept only the last batch's value).
            mean_iou.add_batch(predictions=predicted.cpu().numpy(), references=labels.cpu().numpy())

    # ignore_index=255: no pixel carries it after the 255 -> 1 remap. ignore_index=0 would
    # drop every background ground-truth pixel, reducing the foreground "IoU" to recall
    # and pinning the background IoU at 0.
    if val_losses:
        results_mean_iou = mean_iou.compute(num_labels=num_labels, ignore_index=255)
    else:
        results_mean_iou = {'mean_iou': float('nan')}

    mean_iou_values.append(results_mean_iou['mean_iou'])
    train_loss_values.append(_mean(losses))
    val_loss_values.append(_mean(val_losses))
    train_acc_values.append(_mean(accuracies))
    val_acc_values.append(_mean(val_accuracies))
    train_rec_values.append(_mean(recalls))
    val_rec_values.append(_mean(val_recalls))
    print(f"Train Pixel-wise accuracy: {train_acc_values[-1]}"
          f"  Recall: {train_rec_values[-1]}"
          f"  Train Loss: {train_loss_values[-1]}"
          f"  Val Pixel-wise accuracy: {val_acc_values[-1]}"
          f"  Val Recall: {val_rec_values[-1]}"
          f"  Val Loss: {val_loss_values[-1]}"
          f"  IoU: {results_mean_iou}")

    # if epoch % 5 == 0:
    #     torch.save({
    #             'epoch': epoch,
    #             'model_state_dict': model.state_dict(),
    #             'optimizer_state_dict': optimizer.state_dict(),
    #             'loss': sum(val_losses)/len(val_losses),
    #             }, f'./checkpoint_{epoch}_lr:{lr}_bs:{batch_size}.pth')

print(train_loss_values)
plt.figure(figsize=(12, 6))
plt_epochs = np.arange(1, epochs)
plt.plot(plt_epochs, train_loss_values, label='Training Loss')
plt.plot(plt_epochs, val_loss_values, label='Validation Loss')
plt.xticks(plt_epochs)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title(f'Training and Validation Loss LR={lr} BS={batch_size}')
plt.legend()
plt.show()

# Plot the validation mIoU of the model per epoch

In [None]:
# Validation mean IoU per epoch, laid out like the loss/accuracy/recall plots.
plt.figure(figsize=(12, 6))
plt_epochs = np.arange(1, epochs)
plt.plot(plt_epochs, mean_iou_values, label='Mean IoU')
plt.xticks(plt_epochs)
plt.xlabel('Epoch')
plt.ylabel('Mean IoU')
plt.title(f'Mean IoU LR={lr} BS={batch_size}')
plt.legend()  # the line has a label; without a legend it was never shown
plt.show()

# Plot the training and validation accuracies of the model

In [None]:
# Per-epoch pixel-wise accuracy on the training and validation sets.
plt_epochs = np.arange(1, epochs)
acc_fig, acc_ax = plt.subplots(figsize=(12, 6))
acc_ax.plot(plt_epochs, train_acc_values, label='Training Accuracy')
acc_ax.plot(plt_epochs, val_acc_values, label='Validation Accuracy')
acc_ax.set_xticks(plt_epochs)
acc_ax.set_xlabel('Epoch')
acc_ax.set_ylabel('Accuracy')
acc_ax.set_title(f'Training and Validation Accuracy LR={lr} BS={batch_size}')
acc_ax.legend()
plt.show()

In [None]:
# Per-epoch recall on the training and validation sets.
plt_epochs = np.arange(1, epochs)
rec_fig, rec_ax = plt.subplots(figsize=(12, 6))
rec_ax.plot(plt_epochs, train_rec_values, label='Training Recall')
rec_ax.plot(plt_epochs, val_rec_values, label='Validation Recall')
rec_ax.set_xticks(plt_epochs)
rec_ax.set_xlabel('Epoch')
rec_ax.set_ylabel('Recall')
rec_ax.set_title(f'Training and Validation Recall LR={lr} BS={batch_size}')
rec_ax.legend()
plt.show()

In [None]:
# Persist only the learned weights (state_dict); replace the placeholder with a real file path.
torch.save(obj=model.state_dict(), f="<GIVE SAVE PATH>")

# Model Inference

In [None]:
directory = '<GIVE TEST SET PATH>'
# Regular files only (sub-folders are skipped), sorted so the figure order is reproducible.
test_files = sorted(name for name in os.listdir(directory) if os.path.isfile(os.path.join(directory, name)))
dir_size = len(test_files)
if dir_size == 0:
    raise FileNotFoundError(f"No test images found in {directory}")

fig_width = 10  # Adjust this value for the desired width of each image
# Each row holds two 512x512 panels side by side, so a row is half as tall as it is wide.
fig_height = (dir_size * 512) / (2 * 512) * fig_width
# squeeze=False keeps `axs` two-dimensional even for a single test image,
# so axs[count][col] indexing works for any dir_size.
fig, axs = plt.subplots(dir_size, 2, figsize=(fig_width, fig_height), squeeze=False)

count = 0
for filename in test_files:
    image_name = filename
    # NOTE(review): files are read from root_dir/{images,mask}/test rather than from
    # `directory`; presumably `directory` is root_dir/images/test — confirm.
    image = cv2.imread(f'{root_dir}/images/test/{image_name}')
    mask = cv2.imread(f'{root_dir}/mask/test/{image_name}', 0)
    if image is None or mask is None:
        raise FileNotFoundError(f"Missing test image or mask for {image_name}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    for col, panel in enumerate((image, mask)):
        # NOTE(review): imshow() re-applies the default image aspect, so this
        # set_aspect('auto') likely has no lasting effect — pass aspect= to imshow if needed.
        axs[count][col].set_aspect('auto')
        axs[count][col].axis('off')
        axs[count][col].imshow(panel)
    count += 1

pan_sharp = "./TrainingDataset_Pan"  # NOTE(review): not referenced elsewhere in this notebook
plt.show()


In [None]:
# Class names and their display colours from the dataset's class dictionary.
# The colour column names include a leading space (' r', ' g', ' b').
df = pd.read_csv(f'{root_dir}/class_dict_seg.csv')
classes = df['name']
palette = df.loc[:, [' r', ' g', ' b']].to_numpy()

In [None]:
feature_extractor_inference = SegformerFeatureExtractor(do_random_crop=False, do_pad=False)
accuracy_values = []
recall_values = []
precision_values = []
specificity_values = []
mean_iou_values_inf = []
count_inference = 0

# Zero-height figure that only carries the column titles above the per-image rows.
fig2, axs2 = plt.subplots(1, 3, figsize=(fig_width, 0))
for header_ax, header in zip(axs2, ("Test samples", "Ground truth", "SegFormer MiT-b5")):
    header_ax.set_title(header)
    header_ax.axis('off')


def _ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is 0 (metric undefined)."""
    return numerator / denominator if denominator else float("nan")


model.eval()  # set once; nothing in this cell switches back to training mode
# Each per-image figure is a single row of three square panels.
row_height = fig_width / 3

for filename in sorted(os.listdir(directory)):
    f = os.path.join(directory, filename)
    # checking if it is a file
    if not os.path.isfile(f):
        continue
    image_name = filename
    image = cv2.imread(f'{root_dir}/images/test/{image_name}')
    mask_inf = cv2.imread(f'{root_dir}/mask/test/{image_name}', 0)
    if image is None or mask_inf is None:
        raise FileNotFoundError(f"Missing test image or mask for {image_name}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    mask_inf[mask_inf == 255] = 1  # same 255 -> 1 remap as the training dataset

    pixel_values = feature_extractor_inference(image, return_tensors="pt").pixel_values.to(device)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = model(pixel_values=pixel_values)  # logits are of shape (batch_size, num_labels, height/4, width/4)
    logits = outputs.logits.cpu()
    upsampled_logits = nn.functional.interpolate(logits,
            size=image.shape[:-1], # (height, width)
            mode='bilinear',
            align_corners=False)
    # Second, apply argmax on the class dimension
    seg = upsampled_logits.argmax(dim=1)[0]
    seg_np = seg.numpy()

    # Colourise the prediction with the class palette. The palette columns are r, g, b,
    # which is the channel order matplotlib's imshow expects, so no BGR flip is applied.
    color_seg = np.zeros((seg_np.shape[0], seg_np.shape[1], 3), dtype=np.uint8)
    for label, color in enumerate(palette):
        color_seg[seg_np == label, :] = color

    mask_bool = (mask_inf != 255)  # nothing is 255 after the remap; kept as an explicit ignore mask
    predicted = seg_np[mask_bool]
    real = mask_inf[mask_bool]
    # labels=[0, 1] keeps the matrix 2x2 even when an image and its prediction contain a
    # single class; otherwise sklearn returns a 1x1 matrix and this unpacking fails.
    tn, fp, fn, tp = confusion_matrix(real, predicted, labels=[0, 1]).ravel()
    accuracy = _ratio(tp + tn, tn + fp + fn + tp)
    recall = _ratio(tp, tp + fn)
    precision = _ratio(tp, tp + fp)
    specificity = _ratio(tn, tn + fp)
    predicted_inf = [seg_np.astype(np.int64)]
    actual_inf = [mask_inf.astype(np.int64)]
    # ignore_index=255 (unused label): ignore_index=0 would discard every background
    # ground-truth pixel and turn the reported IoU into a recall-derived number.
    results_mean_iou_inf = mean_iou.compute(predictions=predicted_inf, references=actual_inf, ignore_index=255, num_labels=2)

    mean_iou_values_inf.append(results_mean_iou_inf['mean_iou'])
    accuracy_values.append(accuracy)
    recall_values.append(recall)
    precision_values.append(precision)
    specificity_values.append(specificity)

    # One row per image: input, ground truth, prediction.
    fig, axs = plt.subplots(1, 3, figsize=(fig_width, row_height))
    axs[0].set_aspect('auto')
    axs[1].set_aspect('auto')
    axs[2].set_aspect('auto')
    axs[0].imshow(image)
    axs[1].imshow(mask_inf)
    axs[2].imshow(color_seg)
    axs[0].axis('off')
    axs[1].axis('off')
    axs[0].set_title(filename[6:-4])
    # Hide ticks on the prediction panel while keeping its frame.
    plt.setp(axs[2].get_xticklabels(), visible=False)
    plt.setp(axs[2].get_yticklabels(), visible=False)
    axs[2].tick_params(axis='both', which='both', length=0)
    count_inference += 1
    plt.show()

# nanmean skips images where a metric is undefined (e.g. recall on an image with no
# foreground pixels) instead of letting a single NaN make the whole average NaN.
print("Average pixel-wise accuracy: ", np.nanmean(accuracy_values))
print("Average recall: ", np.nanmean(recall_values))
print("Average precision: ", np.nanmean(precision_values))
print("Average specificity: ", np.nanmean(specificity_values))
print("Mean IoU: ", np.nanmean(mean_iou_values_inf))

# Number of parameters of the model


In [None]:
# Parameter counts: those that receive gradients (requires_grad) and the overall total.
pytorch_trainable_params = sum(param.numel() for param in model.parameters() if param.requires_grad)
pytorch_total_params = sum(param.numel() for param in model.parameters())
print(pytorch_trainable_params, pytorch_total_params, sep="\n")