# SSL4EO-S12 SimConvNet Demo

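This notebook is a short demo of training a `minerva` `SimConv` model. It first pre-trains the model with self-supervision on a small subset of SSL4EO-S12 Sentinel-2 imagery. It then periodically evaluates the learned backbone on a downstream land-cover segmentation task, using NAIP imagery with Chesapeake Land Cover labels.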

## Imports

In [None]:
from pathlib import Path

from torch.utils.data import DataLoader
from torchgeo.datasets import NAIP, Chesapeake13, stack_samples
from torchgeo.samplers import RandomGeoSampler
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
import torch
import numpy as np
import matplotlib.pyplot as plt
from rasterio.crs import CRS
from segmentation_models_pytorch import PSPNet

In [None]:
from minerva.models import SimConv, MinervaWrapper
from minerva.loss import SegBarlowTwinsLoss
from minerva.utils.utils import get_cuda_device
from minerva.datasets import SSL4EOS12Sentinel2, PairedDataset, stack_sample_pairs, make_transformations
from minerva.samplers import RandomPairGeoSampler
from minerva.transforms import MinervaCompose, AutoNorm

# Index of the CUDA device all models and batches are moved to.
# NOTE(review): behaviour when CUDA is unavailable is decided by minerva's get_cuda_device — confirm.
CUDA_DEVICE_INDEX = 0
device = get_cuda_device(CUDA_DEVICE_INDEX)

In [None]:
# Web Mercator (EPSG:3857) is measured in metres rather than decimal lat/lon degrees,
# so the `res` values given to the datasets below are in metres.
EPSG3857 = CRS.from_epsg(3857)

In [None]:
# Ask for the data root. Surrounding whitespace is stripped and "~" is expanded so that
# pasted or home-relative paths resolve correctly.
root = Path(input("Path to the root directory containing all the data").strip()).expanduser()

# Fail early with a clear message instead of deep inside dataset construction.
if not root.is_dir():
    raise FileNotFoundError(f"Data root directory not found: {root}")

train_root = root / "SSL4EO-S12/100patches/s2a"  # Sentinel-2 patches for pre-training.
test_image_root = root / "NAIP_2013/Test"  # NAIP imagery for the downstream task.
test_mask_root = root / "Chesapeake13"  # Chesapeake land-cover masks for the downstream task.

In [None]:
# Model input size as (channels, height, width): 4 bands and square 128-pixel patches.
n_patch_channels, patch_px = 4, 128
patch_size = (n_patch_channels, patch_px, patch_px)

## Transform Definitions

In [None]:
# Augmentations used to build the paired training samples. Each key is a transform class
# name; "module" names the package it is imported from, and the remaining entries are
# passed as keyword arguments.
_pair_augmentations = {
    "RandomResizedCrop": {"module": "torchvision.transforms", "size": patch_size[1:]},
    "DetachedColorJitter": {
        "module": "minerva.transforms",
        "brightness": 0.8,
        "contrast": 0.8,
        "saturation": 0.8,
        "hue": 0.2,
    },
    "RandomHorizontalFlip": {"module": "torchvision.transforms"},
    "RandomVerticalFlip": {"module": "torchvision.transforms"},
    "GaussianBlur": {"module": "torchvision.transforms", "kernel_size": 25},
}

# All augmentations sit under a single RandomApply with probability 0.8.
transform_params = {"RandomApply": {"p": 0.8, **_pair_augmentations}}
transforms = make_transformations(transform_params, key="image")

# Sentinel-2 normalisation, appended to the end of the composed training transforms.
# AutoNorm is built from the training dataset itself (presumably deriving its statistics
# from it — confirm in minerva).
s2_stats_dataset = SSL4EOS12Sentinel2(train_root, res=10.0, crs=EPSG3857, bands=["B2", "B3", "B4", "B8"])
sentinel_norm = AutoNorm(s2_stats_dataset)
transforms.transforms.append(sentinel_norm)

# Normalisation for the NAIP test imagery, applied to the "image" key only.
naip_stats_dataset = NAIP(test_image_root, res=1.0)
naip_normalise = MinervaCompose(AutoNorm(naip_stats_dataset), "image")

## Dataset Definitions

In [None]:
# Sentinel-2 bands used for training (4 bands, matching patch_size[0]).
s2_bands = ["B2", "B3", "B4", "B8"]

print("Making Train Dataset")
train_dataset = PairedDataset(
    SSL4EOS12Sentinel2,
    train_root,
    res=10.0,
    crs=EPSG3857,
    bands=s2_bands,
    transforms=transforms,
)

print("Making Test Dataset")
test_image_dataset = NAIP(test_image_root, res=1.0, transforms=naip_normalise)
test_mask_dataset = Chesapeake13(test_mask_root, res=1.0)
# Intersect imagery and masks so each test sample provides both an "image" and a "mask".
test_dataset = test_image_dataset & test_mask_dataset

In [None]:
# Pre-training loader: 256 randomly sampled pairs per epoch, batched 16 at a time.
sampler = RandomPairGeoSampler(train_dataset, size=patch_size[1], length=256, max_r=128)
dataloader = DataLoader(train_dataset, sampler=sampler, collate_fn=stack_sample_pairs, batch_size=16, num_workers=2)

# Downstream loader: 128 random image/mask samples, batched 8 at a time.
testsampler = RandomGeoSampler(test_dataset, size=patch_size[1], length=128)
testdataloader = DataLoader(test_dataset, sampler=testsampler, collate_fn=stack_samples, batch_size=8, num_workers=1)

# One fixed batch, reused for evaluating and visualising the downstream model.
# next(iter(...)) loads only this first batch rather than all 16 batches of the loader.
testdata = next(iter(testdataloader))

## Training & Validation Loop

In [None]:
import copy

n_epochs = 1000  # Number of epochs to conduct.
f_val = 25  # Frequency of downstream validation in number of training epochs.

# Loss functions for the SimConvNet and the downstream PSPNet.
crit = SegBarlowTwinsLoss()
xentropy = CrossEntropyLoss()

# Criterions are normally passed to models at init in minerva.
model = SimConv(crit, input_size=patch_size).to(device)
opt = Adam(model.parameters(), lr=1e-3)

# Optimisers need to be set to a model in minerva before training.
model.set_optimiser(opt)

for epoch in range(n_epochs):
    # ---- Self-supervised pre-training epoch ----
    losses = []
    for batch in dataloader:
        # The paired sampler/collate function yields the two halves of each pair as batch[0] and batch[1].
        x_i_batch = batch[0]["image"].to(device).float()
        x_j_batch = batch[1]["image"].to(device).float()

        # Stack both halves into one tensor with a new leading dimension of size 2.
        x_batch = torch.stack([x_i_batch, x_j_batch])

        # Uses MinervaModel.step.
        loss, _ = model.step(x_batch, train=True)
        losses.append(loss.item())

    print(epoch, np.mean(losses))

    # Validate every f_val epochs and after the final epoch, so the fully trained backbone is evaluated.
    if epoch % f_val != 0 and epoch != n_epochs - 1:
        continue

    # ---- Downstream validation ----
    # Freeze a *copy* of the backbone. Calling requires_grad_(False) on model.backbone itself
    # would stop gradients for SimConv pre-training in every later epoch, and sharing the
    # module would let the downstream model's train/eval mode leak back into it.
    encoder = copy.deepcopy(model.backbone).requires_grad_(False)

    # Construct a new PSPNet.
    psp = MinervaWrapper(
        PSPNet,
        input_size=patch_size,
        criterion=xentropy,
        n_classes=13,
        encoder_name="resnet18",
        classes=13,
        in_channels=4,
    ).to(device)

    # Replace its encoder and decoder with our pre-trained encoder (which is a PSP encoder-decoder).
    # NOTE(review): these attributes are set on the MinervaWrapper itself — confirm its forward
    # routes through them rather than through the wrapped PSPNet instance.
    psp.encoder = encoder.encoder
    psp.decoder = encoder.decoder

    # Set up the optimiser for the PSP.
    psp_opt = Adam(psp.parameters(), lr=0.01)
    psp.set_optimiser(psp_opt)

    # Train the downstream PSP for one pass over the test dataloader.
    psp_losses = []
    for sample in testdataloader:
        x = sample["image"].to(device).float()
        y = sample["mask"].to(device).long().squeeze(1)

        psp_loss, _ = psp.step(x, y, train=True)
        psp_losses.append(psp_loss.item())

    # Evaluate on the pre-selected batch (train=False); it is also used for visualisation below.
    image = testdata["image"].to(device).float()
    target = testdata["mask"].to(device).long().squeeze(1)
    final_loss, pred = psp.step(image, target, train=False)

    # Report downstream training and evaluation losses separately rather than averaging them together.
    print(f"Test {epoch}| Train loss: {np.mean(psp_losses)} | Eval loss: {final_loss.item()}")

    # Rows: input image (first 3 bands), ground-truth mask, predicted mask (argmax over the 13 classes).
    # NOTE(review): the first 3 bands are assumed to be RGB for NAIP — confirm band order.
    n_cols = pred.shape[0]
    # squeeze=False keeps axs 2-D even when the batch holds a single sample.
    fig, axs = plt.subplots(3, n_cols, figsize=(10, 4), squeeze=False)
    for col in range(n_cols):
        axs[0, col].imshow(image[col].cpu().numpy()[:3].transpose(1, 2, 0))
        axs[1, col].imshow(target[col].cpu().numpy(), cmap="Set3", vmin=0, vmax=12)
        axs[2, col].imshow(pred[col].detach().argmax(dim=0).cpu().numpy(), cmap="Set3", vmin=0, vmax=12)
    plt.setp(fig.get_axes(), xticks=[], yticks=[])
    plt.show()