# Torchgeo FCN Demo

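This notebook is a small demo of training both `minerva` and `torchvision` FCNs within the `minerva` framework on a small amount of data. The `minerva` example trains a `SimConv` model self-supervised (with `SegBarlowTwinsLoss`) on a 100-patch subset of SSL4EO-S12 Sentinel-2 imagery. The `torchvision` example was written for NAIP imagery with Chesapeake Land Cover labels and expects image/mask samples.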

## Imports

In [1]:
from pathlib import Path

from torch.utils.data import DataLoader
import torch.nn as nn
from torchgeo.datasets import stack_samples
from torchgeo.samplers import RandomGeoSampler
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
import torch
import numpy as np
import matplotlib.pyplot as plt
from rasterio.crs import CRS

In [2]:
from minerva.models import SimConv
from minerva.loss import SegBarlowTwinsLoss
from minerva.utils.utils import get_cuda_device
from minerva.datasets import SSL4EOS12Sentinel2, PairedDataset, stack_sample_pairs
from minerva.samplers import RandomPairGeoSampler

# Device that the models and batches below are moved to; asks minerva's helper for CUDA device index 0.
# NOTE(review): behaviour when no CUDA device is available is not visible here — confirm before running on CPU.
device = get_cuda_device(0)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Web Mercator (EPSG:3857), passed as `crs=` to the datasets below.
# rasterio documents `CRS.from_epsg` as taking an integer EPSG code, so pass 3857 rather than "3857".
EPSG3857 = CRS.from_epsg(3857)

In [4]:
# Ask for the directory that contains the "SSL4EO-S12" folder so the notebook is not tied to one
# machine's absolute path. Surrounding whitespace (common when pasting) is stripped and "~" expanded.
data_dir = input("Path to the directory containing SSL4EO-S12: ").strip()
root = Path(data_dir).expanduser() / "SSL4EO-S12" / "100patches"

# NOTE(review): train and test both read the same "s2a" directory, so the "test" split overlaps the
# training data. That is acceptable for a smoke-test demo but is not a held-out evaluation.
train_root = root / "s2a"
test_root = root / "s2a"

In [5]:
# Input patch shape as (channels, height, width). The 4 channels line up with the four bands
# requested for the training dataset; the spatial side is also used as the sampler's patch size.
n_bands, patch_px = 4, 64
patch_size = (n_bands, patch_px, patch_px)

## Dataset Definitions

In [6]:
# Sentinel-2 blue, green, red and near-infrared bands; 4 bands to match the model's 4-channel input.
s2_bands = ["B2", "B3", "B4", "B8"]

# Both splits are wrapped in PairedDataset: the samplers below (RandomPairGeoSampler) produce pairs
# of queries and the collate function (stack_sample_pairs) expects pairs of samples, which a bare
# SSL4EOS12Sentinel2 dataset cannot serve. The test split also uses the same bands as training.
print("Making Train Dataset")
train_dataset = PairedDataset(SSL4EOS12Sentinel2, train_root, res=10.0, crs=EPSG3857, bands=s2_bands)

print("Making Test Dataset")
test_dataset = PairedDataset(SSL4EOS12Sentinel2, test_root, res=10.0, crs=EPSG3857, bands=s2_bands)


Making Train Dataset
Making Test Dataset


In [7]:
# Training pipeline: each epoch draws 20 random pairs of patches (side = patch_size[1] pixels)
# and collates them 6 at a time into pairs of batches.
sampler = RandomPairGeoSampler(train_dataset, size=patch_size[1], length=20)
dataloader = DataLoader(
    train_dataset,
    batch_size=6,
    sampler=sampler,
    collate_fn=stack_sample_pairs,
)

# Evaluation pipeline: 8 random pairs per pass, same batch size, loaded in one worker process.
testsampler = RandomPairGeoSampler(test_dataset, size=patch_size[1], length=8)
testdataloader = DataLoader(
    test_dataset,
    batch_size=6,
    sampler=testsampler,
    collate_fn=stack_sample_pairs,
    num_workers=1,
)
# To fetch a single evaluation batch without materialising the whole loader:
#   testdata = next(iter(testdataloader))

## Minerva FCN Example

In [8]:
# Self-supervised training of minerva's SimConv on pairs of patches with a Barlow Twins style loss.
crit = SegBarlowTwinsLoss()

# Criterions are normally passed to models at init in minerva.
model = SimConv(crit, input_size=patch_size).to(device)
opt = Adam(model.parameters(), lr=1e-3)

# Optimisers need to be set on a model in minerva before training.
model.set_optimiser(opt)

n_epochs = 101

for epoch in range(n_epochs):
    losses = []
    for batch in dataloader:
        # Each batch holds the two views of every sampled pair: batch[0] and batch[1].
        # NOTE(review): dividing by 255 is the 8-bit imagery convention; SSL4EO-S12 Sentinel-2
        # patches are usually 16-bit reflectance, so inputs may not land in [0, 1] — TODO confirm.
        x_i_batch = batch[0]["image"].to(device).float() / 255.0
        x_j_batch = batch[1]["image"].to(device).float() / 255.0

        # Stack the two views along a new leading axis: (2, batch, ...).
        x_batch = torch.stack([x_i_batch, x_j_batch])

        # Uses MinervaModel.step; with train=True it presumably runs forward, loss and the optimiser
        # update set via set_optimiser (no manual opt.step() happens here). Predictions are unused.
        loss, _ = model.step(x_batch, train=True)
        losses.append(loss.item())

    print(epoch, np.mean(losses))
    # TODO: add periodic qualitative checks on a held-out batch, e.g. next(iter(testdataloader)).

0 204.65859603881836
1 117.51788902282715
2 101.58391761779785
3 94.92309761047363
4 70.7992000579834
5 68.69688034057617
6 57.2332878112793
7 56.10660171508789
8 57.67892837524414
9 54.56737995147705
10 47.94337558746338
11 45.80367374420166
12 42.35103225708008
13 40.37031173706055
14 38.69046401977539
15 36.342968463897705
16 43.49415636062622
17 32.11010503768921
18 32.62216091156006
19 33.227821350097656
20 57.51835346221924
21 39.57969045639038
22 42.99089336395264
23 38.5181770324707
24 36.35181522369385
25 33.424870014190674
26 32.28605318069458
27 30.376003742218018
28 33.23812532424927
29 27.98432493209839
30 30.210110187530518
31 28.575733184814453
32 27.50907564163208
33 24.115703105926514
34 23.671586513519287
35 33.2522406578064
36 25.604840755462646
37 25.95659351348877
38 24.948445796966553
39 26.852287769317627
40 22.53956365585327
41 23.06619930267334
42 22.46173620223999
43 36.43561315536499
44 24.50643014907837
45 23.82144069671631
46 22.48682403564453
47 27.1503777

## Torchvision FCN Example

In [9]:
# Supervised baseline: torchvision's FCN-ResNet50 predicting 13 classes (0-12, matching vmin/vmax below).
# FIXME(review): this cell cannot run as-is against the pipeline defined above.
#   * `fcn_resnet50` is never imported (hence the NameError below). It is presumably
#     `torchvision.models.segmentation.fcn_resnet50` and needs adding to the imports cell.
#   * `dataloader` collates with `stack_sample_pairs`, and the SimConv cell indexes its batches as
#     `batch[0]` / `batch[1]`, so indexing `sample["image"]` / `sample["mask"]` directly will not match.
#   * Nothing above shows the SSL4EO-S12 Sentinel-2 dataset providing a "mask" key. This cell appears
#     to target the NAIP + Chesapeake Land Cover setup — TODO confirm.
#   * `testdata` is only defined in a commented-out line, so the plotting branch would raise NameError.
fcn = fcn_resnet50(num_classes=13).to(device)
# Swap the backbone's first convolution for one taking 4 input channels, to match the 4-band patches
# (the stock stem presumably expects 3-channel RGB).
fcn.backbone.conv1 = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False).to(device)

crit = CrossEntropyLoss()
opt = Adam(fcn.parameters(), lr=1e-3)

for epoch in range(101):
  losses = []
  for i, sample in enumerate(dataloader):
    # Scale inputs by 1/255 (8-bit imagery convention). Drop a singleton axis from the mask so
    # CrossEntropyLoss gets integer class indices (assumes masks are (B, 1, H, W) — TODO confirm).
    image = sample["image"].to(device).float() / 255.0
    target = sample["mask"].to(device).long().squeeze(1)

    # Manual optimisation step; the segmentation model's logits are read from the "out" entry.
    opt.zero_grad()
    pred = fcn(image)["out"]
    loss = crit(pred, target)
    loss.backward()
    opt.step()
    losses.append(loss.item())

  print(epoch, np.mean(losses))
  # Every 10 epochs, plot inputs, targets and predictions for the `testdata` batch.
  if epoch % 10 == 0:
    with torch.no_grad():
      image = testdata["image"].to(device).float() / 255.0
      target = testdata["mask"].to(device).long().squeeze(1)
      pred = fcn(image)["out"]

      # Rows: first three input bands as an RGB-like composite, ground-truth mask, argmax prediction.
      fig, axs = plt.subplots(3, pred.shape[0], figsize=(10,4))
      for i in range(pred.shape[0]):
        axs[0,i].imshow(image[i].cpu().numpy()[:3].transpose(1,2,0))
        axs[1,i].imshow(target[i].cpu().numpy(), cmap="Set3", vmin=0, vmax=12)
        axs[2,i].imshow(pred[i].detach().argmax(dim=0).cpu().numpy(), cmap="Set3", vmin=0, vmax=12)
      plt.setp(plt.gcf().get_axes(), xticks=[], yticks=[])
      plt.show()

NameError: name 'fcn_resnet50' is not defined