# Torchgeo FCN Demo

This notebook is a small demo that uses a small amount of Sentinel-2 imagery from the SSL4EO-S12 dataset. It trains both a `minerva` `SimConv` model (with a Barlow Twins segmentation loss) and a `torchvision` FCN within the `minerva` framework.

## Imports

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchgeo.datasets import stack_samples
from torchgeo.samplers import RandomGeoSampler
from torchvision.models.segmentation import fcn_resnet50

In [None]:
# minerva components used in this notebook: the dataset, the loss, the model
# and a helper that picks the device everything is moved to.
from minerva.datasets import SSL4EOS12Sentinel2
from minerva.loss import SegBarlowTwinsLoss
from minerva.models import SimConv
from minerva.utils.utils import get_cuda_device

# Device handle passed to `.to(device)` for models and tensors below (index 0).
device = get_cuda_device(0)

In [None]:
# Ask for the parent directory of the SSL4EO-S12 download. A visible prompt makes
# it clear that the notebook is waiting for input, and `~` is expanded.
root = Path(input("Parent directory of SSL4EO-S12/: ").strip()).expanduser() / "SSL4EO-S12/ssl4eo-s12/"

# NOTE(review): train and test currently point at the same directory, so the
# "test" visualisations below are drawn from the training data.
# TODO: point `test_root` at a held-out split.
train_root = root / "s2a"
test_root = root / "s2a"

## Dataset Definitions

In [None]:
train_dataset = SSL4EOS12Sentinel2(train_root)
test_dataset = SSL4EOS12Sentinel2(test_root)

# 200 random 256x256 patches per epoch, in batches of 32.
sampler = RandomGeoSampler(train_dataset, size=256, length=200)
dataloader = DataLoader(train_dataset, sampler=sampler, collate_fn=stack_samples, batch_size=32)

# A single fixed batch of 8 patches, reused for every visualisation below.
testsampler = RandomGeoSampler(test_dataset, size=256, length=8)
testdataloader = DataLoader(test_dataset, sampler=testsampler, collate_fn=stack_samples, batch_size=8, num_workers=4)
# Only the first batch is needed; `next(iter(...))` avoids materialising every batch.
testdata = next(iter(testdataloader))

## Minerva SimConv Example

In [None]:
crit = SegBarlowTwinsLoss()

# Criterions are normally passed to models at init in minerva.
model = SimConv(crit, input_size=(4, 256, 256), n_classes=13).to(device)
opt = Adam(model.parameters(), lr=1e-3)

# Optimisers need to be set to a model in minerva before training.
model.set_optimiser(opt)

for epoch in range(101):
  losses = []
  # Re-enable training mode (the visualisation below switches to eval mode).
  model.train()
  for sample in dataloader:
    # NOTE(review): dividing by 255 assumes 8-bit pixel values — confirm for this data.
    image = sample["image"].to(device).float() / 255.0

    # No label tensor is defined for this model, so `None` is passed in the target slot.
    # NOTE(review): assumes MinervaModel.step accepts a None target for SimConv — confirm.
    loss, pred = model.step(image, None, train=True)
    losses.append(loss.item())

  print(epoch, np.mean(losses))
  if epoch % 10 == 0:
    # Eval mode so normalisation/dropout layers behave deterministically while plotting.
    model.eval()
    with torch.no_grad():
      image = testdata["image"].to(device).float() / 255.0
      # NOTE(review): assumes the forward pass returns a (batch, classes, H, W) tensor — confirm for SimConv.
      pred = model(image)

      # Row 0: first three image channels; row 1: per-pixel argmax of the prediction.
      fig, axs = plt.subplots(2, pred.shape[0], figsize=(10,4), squeeze=False)
      for j in range(pred.shape[0]):
        axs[0,j].imshow(image[j].cpu().numpy()[:3].transpose(1,2,0))
        axs[1,j].imshow(pred[j].detach().argmax(dim=0).cpu().numpy(), cmap="Set3", vmin=0, vmax=12)
      plt.setp(plt.gcf().get_axes(), xticks=[], yticks=[])
      plt.show()

## Torchvision FCN Example

In [None]:
# torchvision FCN-ResNet50 with its first convolution replaced so that the
# backbone accepts 4 input channels. The model is moved to the device once,
# after the swap.
fcn = fcn_resnet50(num_classes=13)
fcn.backbone.conv1 = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
fcn = fcn.to(device)

crit = CrossEntropyLoss()
opt = Adam(fcn.parameters(), lr=1e-3)

for epoch in range(101):
  losses = []
  # Re-enable training mode (the visualisation below switches to eval mode).
  fcn.train()
  for sample in dataloader:
    image = sample["image"].to(device).float() / 255.0
    # NOTE(review): assumes each sample carries a "mask" of shape (B, 1, H, W) — confirm for this dataset.
    target = sample["mask"].to(device).long().squeeze(1)

    opt.zero_grad()
    pred = fcn(image)["out"]
    loss = crit(pred, target)
    loss.backward()
    opt.step()
    losses.append(loss.item())

  print(epoch, np.mean(losses))
  if epoch % 10 == 0:
    # Eval mode: BatchNorm uses (and does not update) its running statistics and
    # dropout is disabled, so the test batch does not leak into the model.
    fcn.eval()
    with torch.no_grad():
      image = testdata["image"].to(device).float() / 255.0
      target = testdata["mask"].to(device).long().squeeze(1)
      pred = fcn(image)["out"]

      # Rows: first three image channels, ground-truth mask, predicted classes.
      fig, axs = plt.subplots(3, pred.shape[0], figsize=(10,4), squeeze=False)
      for j in range(pred.shape[0]):
        axs[0,j].imshow(image[j].cpu().numpy()[:3].transpose(1,2,0))
        axs[1,j].imshow(target[j].cpu().numpy(), cmap="Set3", vmin=0, vmax=12)
        axs[2,j].imshow(pred[j].detach().argmax(dim=0).cpu().numpy(), cmap="Set3", vmin=0, vmax=12)
      plt.setp(plt.gcf().get_axes(), xticks=[], yticks=[])
      plt.show()