# Torchgeo SimConv Demo

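This notebook is a small demo of training a `minerva` `SimConv` model within the `minerva` framework. The model is trained with the `SegBarlowTwinsLoss` on pairs of Sentinel-2 patches sampled from a small subset of SSL4EO-S12. A small amount of NAIP imagery with Chesapeake Land Cover masks is used to visualise the backbone's outputs during training.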

## Imports

In [None]:
from pathlib import Path

from torch.utils.data import DataLoader
import torch.nn as nn
from torchgeo.datasets import NAIP, Chesapeake13, stack_samples
from torchgeo.samplers import RandomGeoSampler
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
import torch
import numpy as np
import matplotlib.pyplot as plt
from rasterio.crs import CRS

In [None]:
from minerva.models import SimConv
from minerva.loss import SegBarlowTwinsLoss
from minerva.utils.utils import get_cuda_device
from minerva.datasets import SSL4EOS12Sentinel2, PairedDataset, stack_sample_pairs
from minerva.samplers import RandomPairGeoSampler

# Device that the model and every batch are moved to in the training cell.
# NOTE(review): 0 presumably selects the first CUDA device; whether
# get_cuda_device falls back to CPU when CUDA is unavailable is not visible
# here — confirm against minerva.utils.utils.
device = get_cuda_device(0)

In [None]:
# Web Mercator (EPSG:3857): the CRS the paired training dataset is built in.
EPSG3857 = CRS.from_epsg(3857)

In [None]:
# Ask for the directory holding the SSL4EO-S12, NAIP and Chesapeake13 data.
# The prompt tells the user what is expected; `strip` drops stray whitespace
# from pasted paths and `expanduser` lets "~/..." style paths work.
root = Path(input("Path to the data root directory: ").strip()).expanduser()

train_root = root / "SSL4EO-S12/100patches/s2a"
test_image_root = root / "NAIP/NAIP 2013/Test"
test_mask_root = root / "Chesapeake13"

In [None]:
# Model input shape, presumably (channels, height, width): 4 channels to match
# the four bands loaded for training, and patch_size[1] is reused below as the
# spatial size passed to both samplers.
patch_size = (4, 64, 64)

## Dataset Definitions

In [None]:
print("Making Train Dataset")
# Paired Sentinel-2 dataset using four bands (B2, B3, B4, B8) at res=10.0 in
# EPSG:3857; samples are drawn from it in pairs by the sampler below.
train_dataset = PairedDataset(
    SSL4EOS12Sentinel2,
    train_root,
    bands=["B2", "B3", "B4", "B8"],
    crs=EPSG3857,
    res=10.0,
)

print("Making Test Dataset")
# NAIP imagery and Chesapeake13 land-cover masks, both at res=1.0. Combining
# them with `&` yields samples carrying both an "image" and a "mask" entry.
test_image_dataset, test_mask_dataset = (
    NAIP(test_image_root, res=1.0),
    Chesapeake13(test_mask_root, res=1.0),
)
test_dataset = test_image_dataset & test_mask_dataset

In [None]:
# Training loader: 20 sampled pairs per epoch in batches of 6. The pair
# collate function lets the training loop index each side of the pairs as
# batch[0] / batch[1].
sampler = RandomPairGeoSampler(train_dataset, size=patch_size[1], length=20)
dataloader = DataLoader(train_dataset, sampler=sampler, collate_fn=stack_sample_pairs, batch_size=6)

# Test loader: only its first batch is kept, as a fixed set of samples that is
# visualised periodically during training.
testsampler = RandomGeoSampler(test_dataset, size=patch_size[1], length=8)
testdataloader = DataLoader(test_dataset, sampler=testsampler, collate_fn=stack_samples, batch_size=6, num_workers=1)
# Fetch just the first batch; `list(testdataloader)[0]` would load every batch
# only to discard all but one.
testdata = next(iter(testdataloader))

## Training & Validation Loop

In [None]:
crit = SegBarlowTwinsLoss()

# Criteria (loss functions) are normally passed to models at init in minerva.
model = SimConv(crit, input_size=patch_size).to(device)
opt = Adam(model.parameters(), lr=1e-3)

# Optimisers need to be set to a model in minerva before training.
model.set_optimiser(opt)

for epoch in range(101):
    losses = []
    for batch in dataloader:
        # batch[0] and batch[1] hold the two sides of every sampled pair.
        # NOTE(review): dividing by 255 assumes 8-bit pixel values; Sentinel-2
        # products are often stored as larger integers — confirm the scaling
        # for SSL4EO-S12 before trusting the results.
        x_i_batch = batch[0]["image"].to(device).float() / 255.0
        x_j_batch = batch[1]["image"].to(device).float() / 255.0

        # Stack the two views along a new leading dimension of size 2.
        x_batch = torch.stack([x_i_batch, x_j_batch])

        # Uses MinervaModel.step (train=True presumably also applies the
        # optimiser set above).
        loss, pred = model.step(x_batch, train=True)
        losses.append(loss.item())

    print(epoch, np.mean(losses))

    if epoch % 10 == 0:
        encoder = model.backbone
        image = testdata["image"].to(device).float() / 255.0
        target = testdata["mask"].to(device).long().squeeze(1)

        # Run the test batch in eval mode. In train mode, layers such as
        # BatchNorm (if the backbone has any) would update their running
        # statistics from this test batch even under no_grad, and dropout
        # would randomise the output. The previous mode is restored afterwards
        # so training continues exactly as before.
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                pred = encoder(image)
        finally:
            model.train(was_training)

        # Rows: first three image bands, target mask, argmax over the backbone
        # output channels. squeeze=False keeps `axs` 2-D even for a batch of one.
        n_samples = pred.shape[0]
        fig, axs = plt.subplots(3, n_samples, figsize=(10, 4), squeeze=False)
        for i in range(n_samples):
            axs[0, i].imshow(image[i].cpu().numpy()[:3].transpose(1, 2, 0))
            axs[1, i].imshow(target[i].cpu().numpy(), cmap="Set3", vmin=0, vmax=12)
            axs[2, i].imshow(pred[i].argmax(dim=0).cpu().numpy(), cmap="Set3", vmin=0, vmax=12)
        plt.setp(fig.get_axes(), xticks=[], yticks=[])
        plt.show()