# UNet model to segment landfills from satellite images

In [1]:
import cv2
import torch
import torch.nn as nn
import monai

import os
import matplotlib.pyplot as plt
import numpy as np
import rasterio

from torch.utils.data import Dataset

In [15]:
class ImageDataset(Dataset):
    """Segmentation dataset backed by an ``images/`` and a ``masks/`` directory.

    The image file list is sorted, shuffled with a fixed seed and sliced into a
    reproducible 60/20/20 train/val/test split, so instances created with the
    three different ``set`` values (and the same seed) are disjoint and together
    cover every image.
    """

    def __init__(self, root_dir, set = 'train', seed=0):
        """
        Args:
            root_dir (string): Directory containing the ``images`` and ``masks`` subdirectories.
            set (string): Split to expose: ``'train'`` (first 60% of the shuffled
                file list), ``'val'`` (next 20%) or ``'test'`` (last 20%).
            seed (int): Seed of the shuffle. All splits must share it to stay disjoint.

        Raises:
            ValueError: If ``set`` is not ``'train'``, ``'val'`` or ``'test'``.
        """
        if set not in ('train', 'val', 'test'):
            raise ValueError('set must be "train", "val" or "test"')
        self.root_dir = root_dir
        self.images_dir = os.path.join(root_dir, 'images')
        self.masks_dir = os.path.join(root_dir, 'masks')
        self.set = set
        # Sort first so the split does not depend on the arbitrary order of os.listdir.
        names = sorted(os.listdir(self.images_dir))
        # A private RandomState gives exactly the permutation that seeding the global
        # NumPy RNG would, without clobbering the caller's global random state.
        np.random.RandomState(seed).shuffle(names)
        n = len(names)
        bounds = {
            'train': (0, int(n * 0.6)),
            'val': (int(n * 0.6), int(n * 0.8)),
            'test': (int(n * 0.8), n),
        }
        start, stop = bounds[set]
        self.image_names = names[start:stop]

    def __len__(self):
        """Number of images in this split."""
        return len(self.image_names)

    @staticmethod
    def _binarize(mask):
        """Return ``mask`` as a float32 0/1 array: pixels above half the mask's maximum become 1.

        Thresholding relative to the maximum handles masks stored as 0/1, 0/255 or
        floats in [0, 1] alike; an all-zero mask stays all zero.
        """
        peak = mask.max()
        if peak <= 0:
            return np.zeros(mask.shape, dtype=np.float32)
        return (mask > peak / 2).astype(np.float32)

    def __getitem__(self, idx):
        """Return ``{'image': ..., 'mask': ...}`` resized to 256x256.

        ``image`` keeps at most its first three channels (channels-last; assumes an
        RGB/RGBA file). ``mask`` is a single-channel (256, 256) float32 array of 0/1.
        """
        name = self.image_names[idx]
        img_name = os.path.join(self.images_dir, name)
        # The mask name reuses the token after the first underscore of the image name
        # (e.g. ``img_12.png`` -> ``mask_12.png``). Adjust to your naming convention.
        mask_name = os.path.join(self.masks_dir, f'mask_{name.split("_")[1]}')
        image = plt.imread(img_name)
        mask = plt.imread(mask_name)

        # Resize both image and mask to 256x256; drop a possible alpha channel from the image.
        image = cv2.resize(image, (256, 256))[:, :, :3]
        mask = cv2.resize(mask, (256, 256))
        # Masks may be stored with several channels or as a single channel.
        if mask.ndim == 3:
            mask = mask[:, :, 0]
        # Resizing interpolates edge values, so threshold back to a binary mask.
        mask = self._binarize(mask)

        return {'image': image, 'mask': mask}

In [11]:
def split_dataset(dataset, split_ratio=0.8, train_ratio=0.8, seed=0):
    """Split ``dataset`` reproducibly into train, validation and test subsets.

    Indices are shuffled with a fixed seed. The first ``split_ratio`` fraction
    forms the train+val pool and the rest is the test set; the first
    ``train_ratio`` fraction of the pool is the training set and the remainder
    the validation set (64/16/20 with the defaults).

    Args:
        dataset: Indexable dataset supporting ``len()``.
        split_ratio (float): Fraction of all samples placed in the train+val pool.
        train_ratio (float): Fraction of the train+val pool used for training.
        seed (int): Seed of the shuffle.

    Returns:
        tuple: ``(train_dataset, val_dataset, test_dataset)`` as ``torch.utils.data.Subset``.
    """
    n = len(dataset)
    indices = list(range(n))
    # Private RandomState: same permutation as seeding the global NumPy RNG,
    # but the caller's global random state is left untouched.
    np.random.RandomState(seed).shuffle(indices)
    split = int(np.floor(split_ratio * n))
    pool, test_indices = indices[:split], indices[split:]
    n_train = int(split * train_ratio)
    train_indices, val_indices = pool[:n_train], pool[n_train:]

    train_dataset = torch.utils.data.Subset(dataset, train_indices)
    val_dataset = torch.utils.data.Subset(dataset, val_indices)
    test_dataset = torch.utils.data.Subset(dataset, test_indices)
    return train_dataset, val_dataset, test_dataset
# create dataloaders


In [13]:
# Build the three dataset splits. ImageDataset itself performs a seeded
# 60/20/20 train/val/test split, so the partition is reproducible across runs.
path_to_data = 'data/img_msk'
bs = 4
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

train_dataset, val_dataset, test_dataset = (
    ImageDataset(path_to_data, set=split_name) for split_name in ('train', 'val', 'test')
)

In [14]:
# Report how many samples ended up in each split.
print('\n'.join(
    f'Length of {split_name} dataset: {len(split_ds)}'
    for split_name, split_ds in (('train', train_dataset), ('val', val_dataset), ('test', test_dataset))
))


Length of train dataset: 591
Length of val dataset: 197
Length of test dataset: 198


In [None]:
# Wrap each split in a DataLoader. Only the training loader shuffles; the
# evaluation loaders keep a fixed order so validation/test runs are reproducible.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=bs, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=bs, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=bs, shuffle=False)


In [None]:
# U-Net (Ronneberger et al., 2015) with padded 3x3 convolutions, so the output
# mask keeps the input's spatial size (e.g. a 3x256x256 image -> 1x256x256 logits).


def _double_conv(in_channels, out_channels):
    """Return two 3x3 same-padded convolutions, each followed by a ReLU."""
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, 3, padding=1),
        nn.ReLU(),
    )


class UNet(nn.Module):
    """U-Net encoder/decoder with skip connections.

    Each encoder level applies two 3x3 convolutions, then halves the resolution
    with 2x2 max-pooling. Each decoder level doubles it again with a 2x2
    transposed convolution, concatenates the matching encoder features (skip
    connection) and applies two more 3x3 convolutions. With the defaults the
    widths are 64-128-256-512 with a 1024-channel bottleneck, as in the original
    article, but convolutions are padded so height and width are preserved.

    Args:
        in_channels (int): Channels of the input image (3 for RGB).
        out_channels (int): Channels of the output map (1 for a binary mask).
        base_channels (int): Width of the first level; doubled at every level down.
        depth (int): Number of pooling/upsampling steps. Input height and width
            must be divisible by ``2 ** depth`` (256 works for the default 4).
    """

    def __init__(self, in_channels=3, out_channels=1, base_channels=64, depth=4):
        super(UNet, self).__init__()
        widths = [base_channels * 2 ** level for level in range(depth + 1)]
        self.depth = depth
        self.pool = nn.MaxPool2d(2, 2)

        self.encoder = nn.ModuleList()
        prev = in_channels
        for width in widths[:-1]:
            self.encoder.append(_double_conv(prev, width))
            prev = width
        self.bottleneck = _double_conv(prev, widths[-1])

        self.upsamplers = nn.ModuleList()
        self.decoder = nn.ModuleList()
        for width in reversed(widths[:-1]):
            # The upsampler halves the channels to `width`; concatenating the skip
            # features (also `width` channels) gives the 2 * width decoder input.
            self.upsamplers.append(nn.ConvTranspose2d(width * 2, width, 2, stride=2))
            self.decoder.append(_double_conv(width * 2, width))

        # 1x1 convolution to the output channels; no activation, i.e. raw logits.
        self.head = nn.Conv2d(widths[0], out_channels, 1)

    def forward(self, x):
        """Map a (B, in_channels, H, W) batch to (B, out_channels, H, W) raw logits.

        Raises:
            ValueError: If H or W is not divisible by ``2 ** depth``.
        """
        factor = 2 ** self.depth
        if x.shape[-2] % factor or x.shape[-1] % factor:
            raise ValueError(
                f'input height and width must be divisible by {factor}, '
                f'got {tuple(x.shape[-2:])}'
            )
        skips = []
        for block in self.encoder:
            x = block(x)
            skips.append(x)
            x = self.pool(x)
        x = self.bottleneck(x)
        # Deepest skip first, matching the decoder order.
        for upsample, block, skip in zip(self.upsamplers, self.decoder, reversed(skips)):
            x = torch.cat([skip, upsample(x)], dim=1)
            x = block(x)
        return self.head(x)

# Instantiate the network, then show its layer structure and its total number of parameters.
model = UNet()
print(model)
print(sum(map(torch.Tensor.numel, model.parameters())))


In [None]:
# --- Optimisation settings ---
nb_epochs = 30
learning_rate = 1e-4  # same value as 10e-5

# MONAI's combined Dice + cross-entropy loss. `sigmoid=True` makes the loss apply
# a sigmoid to the predictions itself, so the model is expected to output raw logits.
criterion = monai.losses.DiceCELoss(reduction='mean', sigmoid=True, squared_pred=True)

# Adam over all parameters of the model.
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

def train(model, x, y, optimizer, criterion, epochs):
    """Run ``epochs`` optimisation steps of ``model`` on the single batch ``(x, y)``.

    Args:
        model (nn.Module): Network to optimise; it is switched to training mode.
        x (torch.Tensor): Input batch passed to ``model``.
        y (torch.Tensor): Target batch, in the form ``criterion`` expects.
        optimizer: torch optimiser over ``model``'s parameters.
        criterion: Callable ``criterion(output, y)`` returning a scalar loss tensor.
        epochs (int): Number of gradient steps taken on this same batch.

    Returns:
        tuple: ``(model, loss)`` where ``loss`` is the loss tensor of the last step,
        or ``None`` when ``epochs`` is 0.
    """
    # Enable train-time behaviour (dropout, batch-norm statistics) in case the
    # model was left in eval mode by a previous evaluation.
    model.train()
    loss = None
    for epoch in range(epochs):
        optimizer.zero_grad()
        output = model(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()
        print(f'Epoch {epoch+1}/{epochs}, Loss: {loss.item()}')
    return model, loss

def test(model, x, y, threshold=0.5):
    """Evaluate ``model`` on one batch; print and return its loss and IoU.

    The loss comes from the module-level ``criterion``. The model is assumed to
    return raw logits (matching a loss configured with ``sigmoid=True``), so
    predictions are binarised as ``sigmoid(logits) > threshold``.

    Args:
        model (nn.Module): Network to evaluate. It runs in eval mode, and its
            previous train/eval mode is restored afterwards.
        x (torch.Tensor): Input batch.
        y (torch.Tensor): Target mask batch.
        threshold (float): Probability above which a pixel is predicted positive.

    Returns:
        tuple: ``(pred, loss, iou)``. ``pred`` is the boolean prediction as a NumPy
        array, ``loss`` is the loss tensor and ``iou`` is ``IoU(pred, y)``.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            output = model(x)
            loss = criterion(output, y)
    finally:
        # Restore the caller's mode so a following training step is unaffected.
        model.train(was_training)
    print(f'Loss: {loss.item()}')
    # Threshold probabilities, not logits: logits > 0.5 would mean p > ~0.62.
    pred = (torch.sigmoid(output) > threshold).cpu().numpy()
    iou = IoU(pred, y.cpu().numpy())
    print(f'IoU: {iou}')
    return pred, loss, iou

def IoU(pred, target):
    """Return the intersection over union of two masks of the same shape.

    For binary (0/1 or boolean) masks this is the Jaccard index. For fractional
    values it acts as a soft IoU: ``sum(p*t) / (sum(p) + sum(t) - sum(p*t))``.

    Returns:
        float: The IoU (in [0, 1] for binary masks). When both masks are empty
        the union is 0, and 1.0 is returned (perfect agreement) instead of 0/0.
    """
    intersection = (pred * target).sum()
    union = pred.sum() + target.sum() - intersection
    if union == 0:
        return 1.0
    return float(intersection / union)

def iou_rim(mask1, mask2):
    """
    Calculate the Intersection over Union (IoU) of two binary masks.

    Any non-zero value counts as foreground.

    Parameters:
        mask1 (np.array): First binary mask.
        mask2 (np.array): Second binary mask, same shape as ``mask1``.

    Returns:
        float: IoU score. It is 1.0 when both masks are empty, instead of 0/0.
    """
    # Use the builtin `bool`: the `np.bool` alias was removed in NumPy 1.24
    # (it only reappeared in NumPy 2.0).
    mask1 = np.asarray(mask1).astype(bool)
    mask2 = np.asarray(mask2).astype(bool)

    # Intersection and union pixel counts.
    intersection = np.logical_and(mask1, mask2).sum()
    union = np.logical_or(mask1, mask2).sum()
    if union == 0:
        return 1.0
    return float(intersection / union)

def _to_model_inputs(batch):
    """Turn a collated ``ImageDataset`` batch into ``(images, masks)`` tensors on ``device``.

    The dataset yields channels-last images and single-channel 2-D masks, so the
    default collation gives (B, H, W, C) images and (B, H, W) masks. The network
    takes (B, C, H, W), and the target must match its (B, 1, H, W) output.
    """
    images = batch['image'].permute(0, 3, 1, 2).float().to(device)
    masks = batch['mask'].unsqueeze(1).float().to(device)
    return images, masks


# nn.Module.to moves the parameters in place, so the optimizer created earlier
# still references the same parameter objects.
model = model.to(device)

val_IoUs = []  # mean validation IoU (averaged over batches) after each epoch
for epoch in range(nb_epochs):
    print(f'Epoch {epoch + 1}/{nb_epochs}')
    # One optimisation step per training batch.
    for batch in train_loader:
        x, y = _to_model_inputs(batch)
        model, train_loss = train(model, x, y, optimizer, criterion, epochs=1)
    # Validate on held-out data, not on the training batch.
    batch_ious = []
    for batch in val_loader:
        x, y = _to_model_inputs(batch)
        output, val_loss, val_iou = test(model, x, y)
        batch_ious.append(val_iou)
    val_IoUs.append(float(np.mean(batch_ious)))
    print('')