<a href="https://colab.research.google.com/github/AlexanderLontke/ssl-remote-sensing/blob/vae-segmentation/notebooks/Segmentation_Downstream_Task.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Segmentation downstream task: SEN12FLOOD

Model: ResNetUNet \\
Data: SEN12FLOOD \\
Pretrained weights: encoder checkpoints from the pretext tasks, plus a randomly initialised baseline



### Environment setup

In [None]:
!pip install ssl_remote_sensing@git+https://github.com/AlexanderLontke/ssl-remote-sensing.git@vae-segmentation

In [None]:
!pip install rasterio torchmetrics

In [None]:
# !pip install wandb

In [None]:
# # from ssl_remote_sensing.downstream_tasks.segmentatio
from ssl_remote_sensing.data.dfc2020 import DFC2020
from ssl_remote_sensing.downstream_tasks.segmentation.utils import (
    patch_first_conv,
    get_metrics,
)
from ssl_remote_sensing.downstream_tasks.segmentation.model import ResNetUNet
from ssl_remote_sensing.pretext_tasks.vae.model import VariationalAutoencoder
from ssl_remote_sensing.pretext_tasks.vae.utils import reproducibility
from ssl_remote_sensing.constants import RANDOM_INITIALIZATION
from ssl_remote_sensing.pretext_tasks.utils import (
    load_encoder_checkpoint_from_pretext_model,
)

In [None]:
import rasterio
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import os
import cv2
from albumentations.pytorch import ToTensorV2
import albumentations as A
import torch.nn as nn
import torch.nn.functional as F
import random
from tqdm import tqdm
from torchmetrics import JaccardIndex
from sklearn.metrics import confusion_matrix, accuracy_score, jaccard_score
import gdown
import tarfile

In [None]:
import wandb

# Authenticate with Weights & Biases; the training loop below starts one wandb run per checkpoint.
wandb.login()

In [None]:
# Compute device: CUDA when available, otherwise the CPU. Kept as a torch.device object
# (not just its `.type` string) so it has the same type as every other `device` value here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
from google.colab import drive

# Mount Google Drive (Colab only); the dataset archive and all checkpoints are read from it.
drive.mount("/content/drive")

In [None]:
# Location of the compressed DFC2020 archive on the mounted Drive.
dfc2020_path = "/content/drive/MyDrive/dfc2020/dfc_2020.tar.gz"

In [None]:
# Unpack the DFC2020 archive into /content/ (the dataset cells below read /content/dfc_data).
# An explicit "data" extraction filter refuses members with absolute paths or `..`
# components and avoids the DeprecationWarning Python 3.12+ emits when `filter` is
# omitted; interpreters without the filter API keep the previous behaviour.
with tarfile.open(dfc2020_path, mode="r") as tar:
    if hasattr(tarfile, "data_filter"):
        tar.extractall(path="/content/", filter="data")
    else:
        tar.extractall(path="/content/")

In [None]:
# Sample counts per DFC2020 subset:
# subset - 'val' 986
# subset - 'test' 5128

# train_set for sentinel-2: the 'test' subset (the larger one, per the counts above) is
# used as the training set. Sentinel-1 is disabled and all three Sentinel-2 flags are on
# (presumably the low/medium/high-resolution band groups — confirm against DFC2020).
train_set = DFC2020(
    "/content/dfc_data",
    subset="test",
    use_s1=False,
    use_s2lr=True,
    use_s2hr=True,
    use_s2mr=True,
    no_savanna=True,
)

In [None]:
# Visual sanity check of one training observation.
train_set.visualize_observation(170)

In [None]:
# Number of samples in the training set.
len(train_set)

In [None]:
# Validation set: the DFC2020 'val' subset with the same Sentinel-2-only options as train_set.
val_set = DFC2020(
    "/content/dfc_data",
    subset="val",
    use_s1=False,
    use_s2lr=True,
    use_s2hr=True,
    use_s2mr=True,
    no_savanna=True,
)

In [None]:
# Visual sanity check of one validation observation.
val_set.visualize_observation(170)

In [None]:
# Number of samples in the validation set.
len(val_set)

In [None]:
# Shuffled mini-batch loader over the DFC2020 training set.
# NOTE(review): `train_loader` is reassigned in the SEN12FLOODS data cell further down,
# so the training loop does not consume this DFC2020 loader.
train_loader = DataLoader(
    train_set,
    batch_size=16,
    shuffle=True,
    num_workers=1,
    pin_memory=True,
    drop_last=False,
)

In [None]:
# Loader over the DFC2020 validation set. This was previously assigned to `train_loader`,
# which silently replaced the training loader from the previous cell with validation data.
# Evaluation order does not matter, so batches are not shuffled.
val_loader = DataLoader(
    val_set,
    batch_size=16,
    shuffle=False,
    num_workers=1,
    pin_memory=True,
    drop_last=False,
)

In [None]:
# Value appended to the checkpoint list (next cell) to get a random-initialisation baseline run.
RANDOM_INITIALIZATION

In [None]:
# Every pretext-task checkpoint on Drive, followed by the random-initialisation entry;
# the training loop below runs one downstream segmentation experiment per path.
g_drive_path = "/content/drive/MyDrive/deep_learning_checkpoints"
# os.listdir returns entries in arbitrary order; sort so runs happen in a reproducible order.
check_point_paths = sorted(os.listdir(g_drive_path))
check_point_paths.append(RANDOM_INITIALIZATION)
check_point_paths = [f"{g_drive_path}/{name}" for name in check_point_paths]
check_point_paths

In [None]:
# for file in check_point_paths:
#   if file == '/content/drive/MyDrive/deep_learning_checkpoints/random':
#     print(file)

### Utils

In [None]:
# def display_outputs(idx=None, multi=False):
#     # Pick a random index if none is specified
#     if not idx:
#         idx = random.randint(0, len(valset))
#     print('Validation image ID: {}'.format(idx))

#     # Get Sentinel 2 and Sentinel 1 data
#     s2_data = torch.unsqueeze(valset.__getitem__(idx)['s2_img'].float().to(device), 0)
#     s1_data = torch.unsqueeze(valset.__getitem__(idx)['s1_img'].float().to(device), 0)

#     # Get predictions from the model
#     if multi:
#         output = model(s1_data, s2_data)
#     else:
#         output = model(s2_data)

#     # Threshold the output to generate the binary map (FYI: the threshold value "0" can be tuned as any other hyperparameter)
#     output_binary = torch.zeros(output.shape)
#     output_binary[output >= 0] = 1

#     get_metrics(valset.__getitem__(idx)['mask'], output_binary)

#     fig, axes = plt.subplots(1, 3, figsize=(15, 7))
#     axes[0].imshow(np.transpose(valset.__getitem__(idx)['s2_img'][[3,2,1],:,:], (1, 2, 0)) / valset.__getitem__(idx)['s2_img'].max())
#     axes[0].set_title('True Color Sentinel-2')
#     axes[2].imshow(valset.__getitem__(idx)['mask'], cmap='Blues')
#     axes[2].set_title('Groundtruth')
#     axes[1].imshow(output_binary.squeeze(), cmap='Blues')
#     axes[1].set_title('Predicted Mask')

### Hyperparameter setup

In [None]:
# Hyperparameters for the downstream segmentation runs
class Hparams:
    """Container for training hyperparameters.

    An instance's ``__dict__`` is passed as the wandb run config in the training loop,
    so every attribute set here is logged with each run, including attributes the
    visible training loop does not read (e.g. ``latent_dim``, ``temperature``).
    """

    def __init__(self):
        # Path of the checkpoint currently being fine-tuned; set per run by the training loop.
        self.checkpoint_name = None
        self.epochs = 10  # number of downstream training epochs per checkpoint
        self.seed = 1234  # randomness seed (presumably read by reproducibility())
        self.save = "./saved_model"
        self.gradient_accumulation_steps = 1  # gradient accumulation steps
        self.batch_size = 16
        self.lr = 1e-3
        self.weight_decay = 1e-6
        self.latent_dim = 256
        self.optim = "Adam"  # "Adam" or "SGD" (the options the training loop accepts)
        self.embedding_size = 128  # papers value is 128
        self.temperature = 0.5  # 0.1 or 0.5
        self.cuda = True  # use CUDA
        self.transform = False
        self.split = False

In [None]:
# Hyperparameters shared by every segmentation run below.
train_config = Hparams()

In [None]:
# Seed randomness from the config (presumably via train_config.seed — confirm in reproducibility()).
reproducibility(train_config)

### Directory & Wandb setup

## Data

SEN12FLOOD \\


*   13 Bands


In [None]:
# SEN12FLOODS train/val splits read from /content/chips/. These loaders replace the
# DFC2020 loaders built earlier and are the ones the training loop consumes.
# NOTE(review): SEN12FLOODS is not imported anywhere in this notebook; add the import from
# the module that defines it, otherwise this cell raises NameError on a fresh kernel.
trainset = SEN12FLOODS(root="/content/chips/", transforms=True, split="train")

valset = SEN12FLOODS(root="/content/chips/", split="val")

# Reshuffle training batches every epoch; validation order is irrelevant.
# NOTE(review): batch_size=8 here, while train_config.batch_size (logged to wandb) is 16.
train_loader = DataLoader(trainset, batch_size=8, shuffle=True, pin_memory=True)

val_loader = DataLoader(valset, batch_size=8, pin_memory=True)

In [None]:
# Peek at a single training batch to confirm tensor shapes. Fetching it once means both
# shapes come from the same batch and the loader is iterated only once.
sample_batch = next(iter(train_loader))
print("[LOG] Shape of mask image is:", sample_batch["mask"].shape)
print("[LOG] Shape of sentinel-2 image is:", sample_batch["s2_img"].shape)

In [None]:
# Visual sanity check of one SEN12FLOODS training observation.
trainset.visualize_observation(196)

In [None]:
# Validation samples 127 and 37 are also the ones plotted with predictions after training.
valset.visualize_observation(127)

In [None]:
valset.visualize_observation(42)

In [None]:
valset.visualize_observation(37)

## Model

In [None]:
# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device used: {}".format(device))
# Learning rate shared by every run in the training loop.
learning_rate = train_config.lr
# The loss function is created once, in its own cell below (it was previously defined twice).

In [None]:
# Initialise the loss function and move it to the GPU if available.
# Binary segmentation loss computed on raw logits (BCEWithLogitsLoss applies the sigmoid itself).
criterion = torch.nn.BCEWithLogitsLoss().to(device)

In [None]:
# Shape of one Sentinel-2 input batch; the training loop patches the encoder for 13 input channels.
next(iter(train_loader))["s2_img"].shape

In [None]:
# load_encoder_checkpoint_from_pretext_model(
#         path_to_checkpoint='/content/drive/MyDrive/deep_learning_checkpoints/random'
#     )

In [None]:
# for file in check_point_paths:
#     print(file)
#     print(load_encoder_checkpoint_from_pretext_model(
#         path_to_checkpoint=file
#     ))

In [None]:
# Downstream fine-tuning. For every entry in `check_point_paths`: build a ResNetUNet around
# that checkpoint's encoder, train it on `train_loader`, keep the weights with the lowest
# validation loss, log per-epoch metrics to wandb, then plot two validation predictions.

# Best-model checkpoints are written here and the directory must already exist. Uses the
# same "/content/drive/MyDrive" mount path as the other Drive paths in this notebook.
model_dir = "/content/drive/MyDrive/deep_learning_segmentation_checkpoints"
assert os.path.exists(model_dir), f"Checkpoint directory does not exist: {model_dir}"


def _make_iou_metric():
    """Return a two-class Jaccard index (IoU) metric placed on `device`.

    torchmetrics >= 0.11 requires an explicit ``task`` argument; older releases reject
    it, so fall back to the legacy ``JaccardIndex(num_classes=2)`` signature.
    """
    try:
        metric = JaccardIndex(task="multiclass", num_classes=2)
    except (TypeError, ValueError):
        metric = JaccardIndex(num_classes=2)
    return metric.to(device)


def _run_epoch(model, loader, criterion, jaccard, optimizer=None, desc=""):
    """Make one pass over `loader`; return ``(mean loss, mean IoU, mean pixel accuracy)``.

    Trains (forward, backward, optimizer step) when `optimizer` is given; otherwise runs
    in eval mode with autograd disabled. Each returned value is the unweighted mean of the
    per-batch values, as a Python float; pixel accuracy is a fraction in [0, 1].

    Raises:
        ValueError: if `loader` yields no batches.
    """
    n_batches = len(loader)
    if n_batches == 0:
        raise ValueError("DataLoader is empty; cannot compute epoch metrics")

    training = optimizer is not None
    model.train(training)

    total_loss = total_iou = total_acc = 0.0
    progress = tqdm(loader, desc=f"{desc}: ", total=n_batches, position=0, leave=True)
    with torch.set_grad_enabled(training):
        for batch in progress:
            data = batch["s2_img"].float().to(device)
            # The loss compares the logits against `mask.unsqueeze(1)`, so the metrics use
            # the same target. Comparing against the raw mask instead broadcasts across the
            # batch axis and compares every prediction with every mask.
            target = batch["mask"].float().to(device).unsqueeze(1)

            output = model(data)
            loss = criterion(output, target)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # A logit >= 0 is a sigmoid probability >= 0.5. Built on the output's own device
            # rather than filling a CPU tensor through a (possibly GPU) boolean mask.
            pred_mask = (output.detach() >= 0).float()

            # .item() keeps only the number, not the autograd graph behind `loss`.
            total_loss += loss.item()
            total_iou += jaccard(pred_mask.int(), target.int()).item()
            total_acc += (pred_mask == target).float().mean().item()

            progress.set_description(f"{desc}: {total_loss / n_batches:.4f}")

    return total_loss / n_batches, total_iou / n_batches, total_acc / n_batches


def display_outputs(idx=None, multi=False):
    """Plot one validation sample: true-colour Sentinel-2, predicted mask and ground truth.

    Also calls `get_metrics` on the sample. Reads the globals `model`, `model_name` and
    `valset`, i.e. whichever model the training loop most recently loaded.

    Args:
        idx: index into `valset`; a random valid index is drawn when None.
        multi: feed Sentinel-1 and Sentinel-2 inputs (`model(s1, s2)`) instead of
            Sentinel-2 only.
    """
    if idx is None:
        # randrange excludes len(valset); randint(0, len(valset)) could pick one past the end.
        idx = random.randrange(len(valset))
    print("Validation image ID: {}".format(idx))

    sample = valset[idx]
    s2_img = sample["s2_img"]
    s2_data = torch.unsqueeze(s2_img.float().to(device), 0)

    with torch.no_grad():
        if multi:
            # Sentinel-1 input is only fetched when the model actually consumes it.
            s1_data = torch.unsqueeze(sample["s1_img"].float().to(device), 0)
            output = model(s1_data, s2_data)
        else:
            output = model(s2_data)

    # Threshold the output to generate the binary map (FYI: the threshold value "0" can be
    # tuned as any other hyperparameter). Moved to CPU for get_metrics and matplotlib.
    output_binary = (output >= 0).float().cpu()

    get_metrics(sample["mask"], output_binary)

    fig, axes = plt.subplots(1, 3, figsize=(15, 7))
    # Channels 3, 2, 1 displayed as RGB, scaled by the image maximum.
    axes[0].imshow(s2_img[[3, 2, 1], :, :].permute(1, 2, 0) / s2_img.max())
    axes[0].set_title("True Color Sentinel-2")
    axes[2].imshow(sample["mask"], cmap="Blues")
    axes[2].set_title("Groundtruth")
    axes[1].imshow(output_binary.squeeze(), cmap="Blues")
    axes[1].set_title(f"Predicted Mask-{model_name}")


for filename in check_point_paths:
    # Record which checkpoint this run fine-tunes (logged through the wandb config below).
    train_config.checkpoint_name = filename
    # Load Encoder from different pre-text architectures
    encoder = load_encoder_checkpoint_from_pretext_model(
        path_to_checkpoint=filename,
    )
    # Adapt the encoder's first convolution from 3 to 13 input channels
    # (presumably one per Sentinel-2 band).
    patch_first_conv(encoder, 13, default_in_channels=3)

    model_name = filename.split("/")[-1].split(".")[0]
    model_path = os.path.join(model_dir, f"segmentation_{model_name}.ckpt")

    wandb.init(
        project="ssl-remote-sensing-segmentation",
        name=model_name,
        config=train_config.__dict__,
    )

    # Model setup
    model = ResNetUNet(1, encoder=encoder)
    model.to(device)

    # Initialise the optimizer.
    # NOTE(review): train_config.weight_decay is logged to wandb but not passed to the optimizer.
    if train_config.optim == "Adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    elif train_config.optim == "SGD":
        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    else:
        # Fail loudly instead of leaving `optimizer` undefined (or stale from a previous run).
        raise ValueError(f"Unsupported optimizer: {train_config.optim!r}")

    epochs = train_config.epochs

    # Per-epoch history, kept so it can be plotted later.
    train_losses, train_accs, train_ious = [], [], []
    val_losses, val_accs, val_ious = [], [], []

    jaccard = _make_iou_metric()
    # Start at +inf so the first epoch always saves a checkpoint; otherwise `model_path`
    # might never be written and the reload below would fail or load an older file.
    best_val_loss = float("inf")

    for epoch in range(epochs):
        epoch_train_loss, epoch_train_ious, epoch_train_accs = _run_epoch(
            model,
            train_loader,
            criterion,
            jaccard,
            optimizer=optimizer,
            desc=f"Epoch = {epoch + 1}, Train Loss",
        )
        epoch_val_loss, epoch_val_ious, epoch_val_accs = _run_epoch(
            model, val_loader, criterion, jaccard, desc="Validation Loss"
        )

        # Save only the best model (lowest validation loss so far).
        if epoch_val_loss <= best_val_loss:
            best_val_loss = epoch_val_loss
            torch.save(model.state_dict(), model_path)
            print("Saving Model...")

        # save result to wandb
        wandb.log(
            {
                "train_loss_segmentation": epoch_train_loss,
                "val_loss_segmentation": epoch_val_loss,
                "train_iou_segmentation": epoch_train_ious,
                "val_iou_segmentation": epoch_val_ious,
                "train_acc_segmentation": epoch_train_accs,
                "val_acc_segmentation": epoch_val_accs,
            }
        )

        train_losses.append(epoch_train_loss)
        val_losses.append(epoch_val_loss)

        train_ious.append(epoch_train_ious)
        val_ious.append(epoch_val_ious)
        print(f"train_iou is {epoch_train_ious:.4f}, val_iou is {epoch_val_ious:.4f}")

        train_accs.append(epoch_train_accs)
        val_accs.append(epoch_val_accs)
        print(f"train_acc is {epoch_train_accs:.4f}, val_acc is {epoch_val_accs:.4f}")

    print("Finished Training")
    # Close this checkpoint's wandb run before the next iteration starts a new one.
    wandb.finish()

    # Reload the best weights saved during training into the model.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)

    print("Sample image: ", model_name)
    display_outputs(37)
    display_outputs(127)
    display_outputs(127)

Function copied from: https://github.com/qubvel/segmentation_models.pytorch