In [None]:
# Mount Google Drive at /content/drive; the copy cells below read the dataset
# from "/content/drive/My Drive/pova_train".
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Install the third-party libraries used below: segmentation-models-pytorch
# (smp.DeepLabV3Plus) and albumentations (the augmentation pipeline).
# NOTE(review): versions are not pinned and -U upgrades albumentations to the
# latest release; pin both (e.g. `%pip install pkg==X.Y.Z`) for reproducibility.
!pip install segmentation-models-pytorch
!pip install -U albumentations

In [None]:
!rm -r "./pova_train"
!mkdir -p "./pova_train"
!find "/content/drive/My Drive/pova_train" -type f | head -n 1000 | xargs -I {} cp {} "./pova_train" #How many pictures to move from drive to local train folder (too many pictures will take long)

In [None]:
!rm -r "./pova_test"
!mkdir -p "./pova_test"
!find "/content/drive/My Drive/pova_train" -type f | tail -n 200 | xargs -I {} cp {} "./pova_test" #Same as above, just for testing

In [None]:
import albumentations as A
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import segmentation_models_pytorch as smp
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import numpy as np
import os
import cv2
from torch.optim import Adam
import torch.nn as nn
from tqdm import tqdm
from google.colab.patches import cv2_imshow as show

In [None]:
class Augmentation():
  """Random crop / rotation / tone-curve augmentation that favours crops containing road.

  Calling an instance with an ``image`` and its ``mask`` applies the same random
  albumentations pipeline to both and returns albumentations' result dict (keys
  ``"image"`` and ``"mask"``). Up to ``iterations`` attempts are made; the first result
  whose road coverage is above ``threshold`` is returned, otherwise the last attempt.

  Args:
    blur: also apply ``A.Blur`` with a kernel size between 5 and 7.
    colorJitter: also apply ``A.ColorJitter``.
    grayscale: also apply ``A.ToGray``.
    img_size: side length of the square ``A.RandomCrop``.
    iterations: maximum number of attempts; values below 1 are treated as 1.
    threshold: road coverage (fraction of road pixels) a result must exceed to be
      accepted early.
  """
  def __init__(self,
               blur: bool = False, colorJitter: bool = False,
               grayscale: bool = False, img_size: int = 256,
               iterations: int = 1, threshold: float = 0.0):
    self.img_size = img_size
    self.iterations = iterations
    self.threshold = threshold
    self.augmentation_list = [
        A.RandomCrop(width=self.img_size, height=self.img_size),
        A.Rotate(limit=[-360,360]),
        A.RandomToneCurve(scale=0.4),
    ]
    if blur:
      self.augmentation_list.append(A.Blur(blur_limit=(5,7)))
    if colorJitter:
      self.augmentation_list.append(A.ColorJitter())
    if grayscale:
      self.augmentation_list.append(A.ToGray())
    # Build the pipeline once instead of re-creating A.Compose for every sample.
    self._pipeline = A.Compose(self.augmentation_list)

  @staticmethod
  def road_coverage(mask: np.ndarray) -> float:
    """Return the fraction of pixels in ``mask`` that are road.

    A pixel counts as road when its value is >= 128. Binarizing at 128, rather than
    requiring exactly 255, tolerates masks whose values are not perfectly 0/255.
    An empty mask has coverage 0.0 (instead of a 0/0 division).
    """
    if mask.size == 0:
      return 0.0
    return float(np.count_nonzero(mask >= 128)) / mask.size

  def __call__(self, image: np.ndarray, mask: np.ndarray) -> dict[str, np.ndarray]:
    """Augment ``image`` and ``mask`` together, retrying for crops with enough road."""
    # At least one attempt, otherwise `result` would be unbound for iterations <= 0.
    attempts = max(1, self.iterations)
    for _ in range(attempts):
      result = self._pipeline(image=image, mask=mask)
      if self.road_coverage(result["mask"]) > self.threshold:
        break
    return result

In [None]:
class CustomImageDataset(Dataset):
    """Satellite image / road mask pairs read from a flat directory.

    A pair is formed by ``<id>_sat.jpg`` and ``<id>_mask.png`` with the same ``<id>``
    (the filename up to its last underscore); files without a partner are ignored.
    Pairs are sorted by id, so the order (and therefore ``clip``) does not depend on
    the order in which ``os.listdir`` happens to return files.

    Each item is a dict ``{"image": ..., "mask": ...}``. Without ``transform`` the
    image is the RGB array loaded by OpenCV and the mask the grayscale array; with
    ``transform`` the dict returned by ``transform(image, mask)`` is passed through.

    Args:
        dataset_dir: directory containing the image and mask files.
        transform: optional callable ``(image, mask) -> dict`` (e.g. ``Augmentation``).
        clip: if > 0, keep only the first ``clip`` pairs.
    """

    def __init__(self, dataset_dir, transform=None, clip = 0):
        self.dataset_dir = dataset_dir
        self.transform = transform

        files = os.listdir(dataset_dir)
        # rsplit keeps ids that themselves contain underscores intact.
        images = {f.rsplit('_', 1)[0]: os.path.join(dataset_dir, f) for f in files if f.endswith('_sat.jpg')}
        labels = {f.rsplit('_', 1)[0]: os.path.join(dataset_dir, f) for f in files if f.endswith('_mask.png')}
        self.image_data = [(images[key], labels[key]) for key in sorted(images) if key in labels]
        if clip > 0:
            self.image_data = self.image_data[:clip]

    def __len__(self):
        return len(self.image_data)

    def __getitem__(self, idx):
        """Load pair ``idx`` and return it as a dict (see the class docstring).

        Raises:
            FileNotFoundError: if OpenCV cannot read the image or the mask file.
        """
        img_path, label_path = self.image_data[idx]

        image = cv2.imread(img_path, cv2.IMREAD_COLOR)
        label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals unreadable/missing files by returning None, not by raising.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        if label is None:
            raise FileNotFoundError(f"Could not read mask: {label_path}")

        # OpenCV loads BGR; downstream code works with RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform is not None:
            return self.transform(image, label)

        # Same dict format as the transform output, so consumers need a single code path.
        return {
            "image": image,
            "mask": label,
        }

In [None]:
# TEST DATA AUGMENTATION
# Show five random augmentations of the same (first) sample, each with its mask and
# road coverage, to sanity-check the Augmentation settings.
aug_preview_dataset = CustomImageDataset("./pova_train", transform=Augmentation(threshold=0.025, iterations=10, colorJitter=True), clip=1000)
for _ in range(5):
  data = aug_preview_dataset[0]
  # cv2_imshow expects BGR, but the dataset yields RGB images.
  show(cv2.cvtColor(data["image"], cv2.COLOR_RGB2BGR))
  show(data["mask"])
  print(Augmentation.road_coverage(data["mask"]))

In [None]:
#Dice loss function
class DiceLoss(nn.Module):
    """Soft Dice loss for binary segmentation, computed from raw logits.

    All elements of the batch are pooled into a single Dice score, and the loss is
    ``1 - dice``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, preds, targets, smooth=1):
        """Return ``1 - (2*sum(p*t) + smooth) / (sum(p) + sum(t) + smooth)`` as a scalar tensor.

        Args:
            preds: logits of any shape; a sigmoid is applied here.
            targets: 0/1 ground truth with the same number of elements as ``preds``;
                cast to the dtype of the probabilities.
            smooth: additive smoothing that keeps the ratio defined for empty masks.
        """
        # reshape (not view) so non-contiguous inputs, e.g. permuted tensors, also work.
        probs = torch.sigmoid(preds).reshape(-1)
        targets = targets.reshape(-1).to(probs.dtype)

        intersection = (probs * targets).sum()
        dice = (2. * intersection + smooth) / (probs.sum() + targets.sum() + smooth)
        return 1 - dice

In [None]:
#Experiment params
num_epochs = 20
BATCH_SIZE = 8
LEARNING_RATE = 1e-3
# Fixes the train/val split, the shuffling order and the random init of the
# non-pretrained layers, so runs are reproducible.
SEED = 42
torch.manual_seed(SEED)

#Initialize dataset (no transform: items are un-augmented, original-resolution pairs)
# NOTE(review): if full-resolution batches exhaust GPU memory, lower BATCH_SIZE or
# pass a cropping transform such as Augmentation(...) here.
dataset = CustomImageDataset("./pova_train", clip=1000)
#Split into train and validation sets (80 / 20)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=torch.Generator().manual_seed(SEED)
)


#Create DataLoaders
# drop_last=True: a trailing batch with a single sample can make BatchNorm layers
# raise in train mode (they need more than one value per channel).
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2,
    drop_last=True, generator=torch.Generator().manual_seed(SEED),
)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

#Define model
model = smp.DeepLabV3Plus(
    encoder_name="efficientnet-b4",
    encoder_weights="imagenet",
    in_channels=3,
    classes=1
)

#Define loss and optimizer
loss_fn = DiceLoss()
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)

In [None]:
#Move model to device (GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


def batch_to_tensors(batch, device):
    """Convert a collated CustomImageDataset batch into model-ready tensors.

    The dataset yields dicts of numpy arrays, so the default collate produces
    ``{"image": uint8 (B, H, W, 3), "mask": uint8 (B, H, W)}``. Unpacking such a batch
    directly (``for images, labels in loader``) would yield the dict KEYS, so the
    tensors are taken out explicitly here.

    Returns:
        images: float32 (B, 3, H, W) scaled to [0, 1].
        masks: float32 (B, 1, H, W); 1 where the mask value is >= 128, else 0.
    """
    # Move the compact uint8 data first, then convert on the device.
    images = batch["image"].to(device).permute(0, 3, 1, 2).contiguous().float().div(255.0)
    masks = (batch["mask"].to(device) >= 128).float().unsqueeze(1)
    return images, masks


train_loss_record = []
val_loss_record = []

for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    for batch in tqdm(train_loader, desc=f"Training Epoch {epoch + 1}/{num_epochs}"):
        images, labels = batch_to_tensors(batch, device)

        #Forward pass
        outputs = model(images)

        #Compute loss
        loss = loss_fn(outputs, labels)

        #Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # max(..., 1) avoids a ZeroDivisionError when the loader yields no batches.
    train_loss /= max(len(train_loader), 1)
    print(f"Epoch {epoch + 1}/{num_epochs}, Training Loss: {train_loss:.4f}")

    #Keep track of train loss scores
    train_loss_record.append(train_loss)

    #Validation
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in val_loader:
            images, labels = batch_to_tensors(batch, device)
            outputs = model(images)
            loss = loss_fn(outputs, labels)
            val_loss += loss.item()

    val_loss /= max(len(val_loader), 1)
    print(f"Epoch {epoch + 1}/{num_epochs}, Validation Loss: {val_loss:.4f}")

    #Keep track of validation loss scores
    val_loss_record.append(val_loss)


In [None]:
#Plot training and validation loss
fig, ax = plt.subplots(figsize=(10, 5))
# The training log numbers epochs from 1, so the x-axis does too.
ax.plot(range(1, len(train_loss_record) + 1), train_loss_record, label='Training Loss')
ax.plot(range(1, len(val_loss_record) + 1), val_loss_record, label='Validation Loss')
ax.set(title='Training and validation loss', xlabel='Epoch', ylabel='Loss')
ax.legend()
plt.show()

In [None]:
#NOTE: This is here just to visually check the model

# How many test samples to display (None = all of them).
NUM_PREVIEW = None

#Load the test dataset
test_dataset = CustomImageDataset("./pova_test", transform=None)

#Ensure model is in evaluation mode
model.eval()

#Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

#Function to display an image, predicted mask, and actual mask
def display_prediction(image, predicted_mask, actual_mask):
    """Show the input image, the predicted mask and the ground-truth mask side by side.

    Args:
        image: float tensor (3, H, W) with values in [0, 1].
        predicted_mask: tensor (H, W) with the binarised prediction.
        actual_mask: tensor (H, W) or (1, H, W) with the ground truth.
    """
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    panels = [
        (image.permute(1, 2, 0).cpu().numpy(), None, "Input Image"),  # CHW -> HWC for display
        (predicted_mask.cpu().numpy(), 'gray', "Model Prediction"),
        (actual_mask.squeeze(0).cpu().numpy(), 'gray', "Actual Mask"),
    ]
    for ax, (picture, cmap, title) in zip(axes, panels):
        ax.imshow(picture, cmap=cmap)
        ax.set_title(title)
        ax.axis("off")
    plt.show()
    # Release the figure; many figures are created in the loop below.
    plt.close(fig)

num_samples = len(test_dataset) if NUM_PREVIEW is None else min(NUM_PREVIEW, len(test_dataset))

#Iterate through the test dataset
for idx in range(num_samples):
    # Items are dicts of numpy arrays (RGB uint8 image, uint8 mask); unpacking the
    # dict directly would yield its keys instead of the data.
    sample = test_dataset[idx]
    # HWC uint8 -> (1, 3, H, W) float in [0, 1]
    image = (
        torch.from_numpy(sample["image"]).permute(2, 0, 1).contiguous()
        .float().div(255.0).unsqueeze(0).to(device)
    )
    # Binarise the ground truth at 128 (road = 1).
    actual_mask = torch.from_numpy((sample["mask"] >= 128).astype(np.float32))

    #Run the image through the model
    with torch.no_grad():
        logits = model(image)
        #Apply sigmoid and threshold to get a binary mask of shape (1, 1, H, W)
        predicted_mask = (torch.sigmoid(logits) > 0.5).float()

    #Display the input image, model prediction, and actual mask
    display_prediction(image.squeeze(0), predicted_mask[0, 0], actual_mask)