# SSL4EO-S12 SimConvNet Demo

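This notebook is a small demo of pre-training a `minerva` `SimConv` model on a subset of SSL4EO-S12 Sentinel-2 imagery. During pre-training it is periodically evaluated downstream on DFC2020 land-cover segmentation. The notebook finishes with downstream scene-classification experiments on BigEarthNet and EuroSAT.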

## Imports

In [None]:
from pathlib import Path

from torch.utils.data import DataLoader, RandomSampler
from torchgeo.datasets import stack_samples, EuroSAT100, EuroSAT
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
import torch
import numpy as np
import matplotlib.pyplot as plt
from rasterio.crs import CRS
from segmentation_models_pytorch import PSPNet
from kornia.color import bgr_to_rgb
from matplotlib.colors import ListedColormap

In [None]:
from minerva.models import SimConv, MinervaWrapper, ResNet18
from minerva.loss import SegBarlowTwinsLoss
from minerva.utils.utils import get_cuda_device, calc_norm_euc_dist
from minerva.datasets import NonGeoSSL4EOS12Sentinel2, PairedNonGeoDataset, DFC2020, stack_sample_pairs
from minerva.transforms import ClassTransform, Normalise, MinervaCompose, make_transformations
from minerva.utils.visutils import get_mlp_cmap
from minerva.utils.utils import find_modes, eliminate_classes, find_empty_classes

In [None]:
# Select the compute device with index 0 via minerva's helper.
cuda_index = 0
device = get_cuda_device(cuda_index)

In [None]:
# Ask for the data root. Surrounding whitespace (e.g. from a pasted path) is stripped
# and a leading ``~`` is expanded so home-relative paths resolve correctly.
root = Path(input("Path to the root directory containing all the data").strip()).expanduser()

train_root = root / "SSL4EO-S12/ssl4eo-s12_100patches/s2a"
val_root = root / "DFC2020"

In [None]:
# Echo the resolved training and validation directories, one per line.
print(train_root, val_root, sep="\n")

In [None]:
# Shape of each training patch as (channels, height, width), plus samples per batch.
patch_height = patch_width = 120
patch_size = (4, patch_height, patch_width)
batch_size = 8

In [None]:
# Divisor handed to the ``Normalise`` transform for the training imagery.
normalisation_factor = 10_000

## Transform Definitions

In [None]:
def _kornia_aug(**params):
    """Build a kornia augmentation spec applied with p=0.2 and ``keepdim=True``.

    ``keepdim=True`` asks kornia to return the same shape as its input rather than
    a batched form.
    """
    return {"module": "kornia.augmentation", "p": 0.2, **params, "keepdim": True}


transform_params = {
    # Scale raw pixel values by ``normalisation_factor``.
    "Normalise": {"module": "minerva.transforms", "norm_value": normalisation_factor},
    # Colour jitter, applied to 20% of samples.
    "RandomApply": {
        "p": 0.2,
        "DetachedColorJitter": {
            "module": "minerva.transforms",
            "brightness": 0.2,
            "contrast": 0.1,
            "saturation": 0.1,
            "hue": 0.15,
        },
    },
    "RandomResizedCrop": _kornia_aug(size=patch_size[1:3], cropping_mode="resample"),
    "RandomHorizontalFlip": _kornia_aug(),
    "RandomGaussianBlur": _kornia_aug(kernel_size=9, sigma=[0.01, 0.2]),
    "RandomGaussianNoise": _kornia_aug(std=0.05),
    "RandomErasing": _kornia_aug(),
}
transformations = make_transformations({"image": transform_params})

## Dataset Definitions

In [None]:
print("Making Train Dataset")
# Four Sentinel-2 bands, in the order B2, B3, B4, B8, with the augmentations defined above.
s2_dataset = NonGeoSSL4EOS12Sentinel2(
    str(train_root),
    bands=["B2", "B3", "B4", "B8"],
    transforms=transformations,
)
# Wrap it so that each item is a pair of patches of the configured spatial size.
train_dataset = PairedNonGeoDataset(s2_dataset, size=patch_size[1:3], max_r=64)

In [None]:
# Each pass over ``dataloader`` draws 256 random samples, collated with stack_sample_pairs.
sampler = RandomSampler(train_dataset, num_samples=256)
dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    sampler=sampler,
    collate_fn=stack_sample_pairs,
    num_workers=0,
)

In [None]:
# Plot a per-channel value histogram for every image in both halves of one paired batch.
batch_pair = next(iter(dataloader))
channel_colours = ("b", "g", "r", "orange")
bins = torch.linspace(0, 10000 / normalisation_factor, 128)

for batch in batch_pair:
    for image in batch["image"]:
        x = torch.from_numpy(np.array(image)).float()
        hist = [torch.histogram(channel, bins=bins) for channel in x]

        plt.figure(figsize=(3, 3))
        for k, colour in enumerate(channel_colours):
            plt.plot(hist[k].bin_edges[:-1], hist[k].hist, color=colour)
        plt.show()

In [None]:
batch_pair = next(iter(dataloader))
n_rows = len(batch_pair[0]["image"])

# One row per sample and one column per half of the pair. ``squeeze=False`` keeps
# ``ax`` 2-D even when there is only a single row.
fig, ax = plt.subplots(nrows=n_rows, ncols=2, figsize=(2, n_rows), squeeze=False)

for j, batch in enumerate(batch_pair):
    for i, image in enumerate(batch["image"]):
        # Channels 0-2 are B2, B3, B4 (blue, green, red); reorder to RGB for display.
        rgb = bgr_to_rgb(image[0:3, :, :]).permute(1, 2, 0)
        ax[i, j].imshow(rgb)
        # Turn off every panel's axes (``plt.axis('off')`` would only affect the current one).
        ax[i, j].set_axis_off()

fig.tight_layout()
fig.show()

In [None]:
print("Making Validation Dataset")
# Validation split of DFC2020 with the high-resolution Sentinel-2 imagery and labels.
dfc_kwargs = dict(split="val", use_s2hr=True, labels=True)
val_dataset = DFC2020(val_root, **dfc_kwargs)

In [None]:
# Collect every validation mask (loads each full sample).
val_labels = []
for sample in val_dataset:
    val_labels.append(sample["mask"])
    

In [None]:
class_dist = find_modes(val_labels, plot=False)

# Drop the classes reported as empty. eliminate_classes returns the reduced classes,
# a dict converting old class indices to the new system, and the matching colours.
empty_classes = find_empty_classes(class_dist)
new_classes, forwards, new_colours = eliminate_classes(
    empty_classes, val_dataset.classes, val_dataset.colours
)


In [None]:
# Scale validation images by 4095 and remap mask labels with ``forwards``.
class_transform = ClassTransform(forwards)
val_transforms = {"image": Normalise(4095), "mask": class_transform}
val_dataset.transforms = MinervaCompose(val_transforms)

In [None]:
valsampler = RandomSampler(val_dataset, num_samples=120)
valdataloader = DataLoader(val_dataset, sampler=valsampler, collate_fn=stack_samples, batch_size=batch_size, num_workers=2)
# Only one batch is kept for visualisation. Take it directly instead of loading and
# discarding every batch of the epoch.
valdata = next(iter(valdataloader))

In [None]:
# Per-channel value histograms over [0, 1] for each image in the visualisation batch.
channel_colours = ("b", "g", "r", "orange")
unit_bins = torch.linspace(0, 1, 128)

for image in valdata["image"]:
    x = torch.from_numpy(np.array(image)).float()
    hist = [torch.histogram(channel, bins=unit_bins) for channel in x]

    plt.figure(figsize=(3, 3))
    for k, colour in enumerate(channel_colours):
        plt.plot(hist[k].bin_edges[:-1], hist[k].hist, color=colour)
    plt.show()

In [None]:
# Location of the EuroSAT100 download.
test_root = root.joinpath("EuroSat100")

In [None]:
# Load bands in the same order as the SSL4EO-S12 pre-training data (B2, B3, B4, B8:
# blue, green, red, NIR), scaled by the same ``normalisation_factor``. This way the
# transferred encoder sees inputs laid out as it was trained on.
test_dataset = EuroSAT100(
    str(test_root),
    split="test",
    bands=["B02", "B03", "B04", "B08"],
    download=True,
    transforms=MinervaCompose({"image": Normalise(normalisation_factor)}),
)

In [None]:
# Criterion for SimConv pre-training, and for the downstream PSPNet
# (label 255 is ignored by the downstream loss).
crit = SegBarlowTwinsLoss()
xentropy = CrossEntropyLoss(ignore_index=255)

# In minerva the criterion is passed to a model when it is constructed.
model = SimConv(
    crit,
    input_size=patch_size,
    feature_dim=512,
    projection_dim=128,
)
model = model.to(device)

# A minerva model must be given its optimiser before it can be trained.
opt = Adam(model.parameters(), lr=1e-3)
model.set_optimiser(opt)
model.determine_output_dim(sample_pairs=True)

## Training & Validation Loop

In [None]:
import copy

n_epochs = 50  # Number of epochs to conduct.
f_val = 10  # Frequency of downstream validation in number of training epochs.
n_classes = len(new_classes)
cmap_style = ListedColormap(list(new_colours.values()), N=len(new_colours))
cmap = get_mlp_cmap(cmap_style, len(new_classes))

# Exponential moving average of the embedding standard deviation, used to track
# representation collapse. It is initialised once, outside the epoch loop, so it is not
# restarted from zero (which reads as "fully collapsed") at the start of every epoch.
ema_weight = 0.9
avg_std = 0.0

for epoch in range(n_epochs):
    losses = []
    euc_dists = []
    collapse_levels = []
    for i, batch in enumerate(dataloader):
        # ``batch`` holds two batches (batch[0], batch[1]): the two halves of each pair.
        x_i_batch = batch[0]["image"].to(device).float()
        x_j_batch = batch[1]["image"].to(device).float()
        x_batch = torch.stack([x_i_batch, x_j_batch])

        # Uses MinervaModel.step.
        loss, pred = model.step(x_batch, train=True)
        losses.append(loss.item())

        # The rest only computes metrics, so no gradients are needed.
        z = pred.detach().flatten(1, -1)
        z_a, z_b = torch.split(z, len(z) // 2, 0)

        euc_dist = calc_norm_euc_dist(z_a.cpu(), z_b.cpu())
        euc_dists.append(euc_dist / len(z_a))

        output = torch.nn.functional.normalize(z_a, dim=1)
        std = torch.std(output, 0).mean()
        avg_std = ema_weight * avg_std + (1 - ema_weight) * std.item()

        # The level of collapse is large if the standard deviation of the l2-normalised
        # output is much smaller than 1 / sqrt(dim). Here dim is the embedding
        # dimension (output.shape[1]), not the batch size.
        collapse_level = max(0.0, 1 - avg_std * np.sqrt(output.shape[1]))
        collapse_levels.append(collapse_level)

    print(
        f"Train {epoch}| Loss: {np.mean(losses)}| Euc_dist: {np.mean(euc_dists)} | Collapse Level: {np.mean(collapse_levels) * 100.0}%"
    )

    if epoch % f_val == 0:
        # Take a frozen *copy* of the backbone for the downstream probe. ``requires_grad_``
        # works in place, so freezing ``model.backbone`` itself would stop SimConv's
        # backbone from learning for the rest of pre-training. Copying also keeps the
        # downstream training from touching the SSL model's modules.
        encoder = copy.deepcopy(model.backbone).requires_grad_(False)

        # Construct a new PSPNet.
        psp = MinervaWrapper(
            PSPNet,
            input_size=patch_size,
            criterion=xentropy,
            n_classes=n_classes,
            encoder_name="resnet18",
            classes=n_classes,
            in_channels=4,
        ).to(device)

        # Replace its encoder and decoder with the pre-trained ones (the backbone is a
        # PSP encoder-decoder).
        # NOTE(review): these assignments set attributes on the wrapper itself. Confirm
        # that MinervaWrapper's forward uses ``psp.encoder``/``psp.decoder`` rather than
        # those of an inner wrapped model.
        psp.decoder = encoder.decoder
        psp.encoder = encoder.encoder

        # Set up the optimiser for the PSP.
        psp_opt = Adam(psp.parameters(), lr=0.001)
        psp.set_optimiser(psp_opt)

        opt_losses = []

        # Train downstream PSP.
        for sample in valdataloader:
            x = sample["image"].to(device).float()
            y = sample["mask"].to(device).long()

            opt_loss, z = psp.step(x, y, train=True)
            opt_losses.append(opt_loss.item())

        # Use the pre-selected batch of data for visualisation of the PSP's results.
        image = valdata["image"].to(device).float()
        target = valdata["mask"].to(device).long()
        final_loss, pred = psp.step(image, target, train=False)
        opt_losses.append(final_loss.item())

        print(f"Val {epoch}| Loss: {np.mean(opt_losses)}")

        # Rows: input image, target mask, predicted mask. squeeze=False keeps ``axs``
        # 2-D even for a single-sample batch.
        fig, axs = plt.subplots(3, pred.shape[0], figsize=(10, 4), squeeze=False)
        for i in range(pred.shape[0]):
            axs[0, i].imshow(image[i].cpu().numpy()[:3].transpose(1, 2, 0))
            axs[1, i].imshow(
                target[i].cpu().numpy(), cmap=cmap, vmin=0, vmax=len(new_classes)
            )
            axs[2, i].imshow(
                pred[i].detach().argmax(dim=0).cpu().numpy(),
                cmap=cmap,
                vmin=0,
                vmax=len(new_classes),
            )
        plt.setp(plt.gcf().get_axes(), xticks=[], yticks=[])
        plt.show()

## Downstream Test Model

In [None]:
from minerva.transforms import SelectChannels
from torchgeo.datasets import BigEarthNet

big_root = root / "BigEarthNet"

# Keep channels 1, 2, 3 and 7 of the "s2" bands and scale by 2048.
# NOTE(review): these are presumably B02, B03, B04, B08 (blue, green, red, NIR), matching
# the pre-training band order. Confirm against torchgeo's BigEarthNet band list.
bigearthnet_transforms = MinervaCompose({"image": [SelectChannels([1, 2, 3, 7]), Normalise(2048)]})

bigearthnet_dataset = BigEarthNet(root=big_root, split="val", bands="s2", download=False, transforms=bigearthnet_transforms)

# Use the notebook-wide ``batch_size`` rather than a separate hard-coded value.
bigearthnet_sampler = RandomSampler(bigearthnet_dataset, num_samples=512)
bigearthnet_dataloader = DataLoader(bigearthnet_dataset, batch_size=batch_size, sampler=bigearthnet_sampler, collate_fn=stack_samples)

In [None]:
batch = next(iter(bigearthnet_dataloader))
n_images = len(batch["image"])

# One panel per image in the batch. ``squeeze=False`` keeps ``ax`` indexable as
# ``ax[0, i]`` even when the batch holds a single image.
fig, ax = plt.subplots(nrows=1, ncols=n_images, figsize=(n_images, 1), squeeze=False)

for i, image in enumerate(batch["image"]):
    # Channels 0-2 are in blue, green, red order; reorder to RGB for display.
    rgb = bgr_to_rgb(image[0:3, :, :]).permute(1, 2, 0)
    ax[0, i].imshow(rgb)
    # Turn off every panel's axes (``plt.axis('off')`` would only affect the current one).
    ax[0, i].set_axis_off()

fig.tight_layout()
fig.show()

In [None]:
# Per-channel value histograms over [0, 1] for a fresh BigEarthNet batch.
channel_colours = ("b", "g", "r", "orange")
bins = torch.linspace(0, 1, 128)
bigearthnet_images = next(iter(bigearthnet_dataloader))["image"]

for image in bigearthnet_images:
    x = torch.from_numpy(np.array(image)).float()
    hist = [torch.histogram(channel, bins=bins) for channel in x]

    plt.figure(figsize=(3, 3))
    for k, colour in enumerate(channel_colours):
        plt.plot(hist[k].bin_edges[:-1], hist[k].hist, color=colour)
    plt.show()

In [None]:
from minerva.models import FlexiSceneClassifier
from torch.nn import BCELoss

n_epochs = 20

# Backbone: an ImageNet-initialised resnet18 MinervaPSP with its own segmentation and
# classification heads disabled. The FlexiSceneClassifier wrapping it is built for 19 classes.
backbone_args = {
    "module": "minerva.models",
    "name": "MinervaPSP",
    "input_size": patch_size,
    "n_classes": 19,
    "encoder_name": "resnet18",
    "encoder_weights": "imagenet",
    "psp_out_channels": 512,
    "segmentation_on": False,
    "classification_on": False,
    "encoder": False,
}

bigearthnet_model = FlexiSceneClassifier(
    criterion=BCELoss(),
    input_size=patch_size,
    n_classes=19,
    fc_dim=512,
    encoder_on=True,
    filter_dim=-1,
    freeze_backbone=False,
    clamp_outputs=True,
    backbone_args=backbone_args,
).to(device)

bigearthnet_model.train()

optimiser = Adam(bigearthnet_model.parameters(), lr=1.0e-2)
bigearthnet_model.set_optimiser(optimiser)

for epoch in range(n_epochs):
    bigearthnet_losses = []
    for batch in bigearthnet_dataloader:
        images = batch["image"].to(device)
        # BCELoss needs float targets.
        labels = batch["label"].to(device, dtype=torch.float)

        bigearthnet_model.optimiser.zero_grad()

        # BCELoss expects inputs in [0, 1], hence the clamp.
        z = bigearthnet_model(images).clamp(0, 1)
        loss = bigearthnet_model.criterion(z, labels)
        loss.backward()
        bigearthnet_model.optimiser.step()

        bigearthnet_losses.append(loss.item())

    print(f"{epoch}: {np.mean(bigearthnet_losses)}")


In [None]:
fine_tune_root = root / "EuroSAT_MS"
# Load bands in the same order as the SSL4EO-S12 pre-training data (B2, B3, B4, B8:
# blue, green, red, NIR), scaled by the same ``normalisation_factor``. This way the
# transferred encoder sees inputs laid out as it was trained on.
fine_tune_dataset = EuroSAT(
    str(fine_tune_root),
    split="train",
    bands=["B02", "B03", "B04", "B08"],
    transforms=MinervaCompose({"image": Normalise(normalisation_factor)}),
)
fine_tune_sampler = RandomSampler(fine_tune_dataset, num_samples=512)
fine_tune_dataloader = DataLoader(fine_tune_dataset, batch_size=batch_size, sampler=fine_tune_sampler, collate_fn=stack_samples)

In [None]:
test_loss = CrossEntropyLoss()
test_model = ResNet18(test_loss, input_size=patch_size, n_classes=10)

# Transplant the pre-trained SimConv encoder into the ResNet18. The stem batch norm
# (``bn1``) moves together with ``conv1``, so the pre-trained convolution is not
# followed by a freshly initialised normalisation layer.
# NOTE(review): assumes both networks expose a torchvision-style ``bn1``; confirm.
# The modules are shared by reference, so fine-tuning ``test_model`` also updates
# ``model.backbone``.
pretrained_encoder = model.backbone.encoder
for layer_name in ("conv1", "bn1", "layer1", "layer2", "layer3", "layer4"):
    setattr(test_model.network, layer_name, getattr(pretrained_encoder, layer_name))

test_model.to(device)

In [None]:
# Use the notebook-wide ``batch_size`` rather than a separate hard-coded value.
test_sampler = RandomSampler(test_dataset)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, sampler=test_sampler, collate_fn=stack_samples)

In [None]:
test_optimiser = Adam(test_model.parameters(), lr=1.0e-3)
test_model.set_optimiser(test_optimiser)

# Fine-tune on the EuroSAT train split.
fine_tune_losses = []
for batch in fine_tune_dataloader:
    images = batch["image"].to(device)
    labels = batch["label"].to(device)

    loss, z = test_model.step(images, labels, train=True)
    fine_tune_losses.append(loss.item())

print(np.mean(fine_tune_losses))

# Evaluate on the held-out EuroSAT100 test split.
# NOTE(review): assumes ``z`` is (batch, n_classes) class scores, as for the training step.
test_losses = []
n_correct = 0
n_seen = 0
for batch in test_dataloader:
    images = batch["image"].to(device)
    labels = batch["label"].to(device)

    loss, z = test_model.step(images, labels, train=False)
    test_losses.append(loss.item())
    n_correct += (z.argmax(dim=1) == labels).sum().item()
    n_seen += len(labels)

print(f"Test loss: {np.mean(test_losses)} | Accuracy: {n_correct / max(n_seen, 1):.2%}")