In [None]:
import os
import tempfile

from torch.utils.data import DataLoader
from torchvision.models.segmentation import fcn_resnet50
import torch.nn as nn
from torchgeo.datasets import NAIP, ChesapeakeDE, stack_samples
from torchgeo.datasets.utils import download_url
from torchgeo.samplers import RandomGeoSampler
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
import torch
import numpy as np
import matplotlib.pyplot as plt

In [None]:
from minerva.models import FCN8ResNet18
from minerva.utils.utils import get_cuda_device

# Pick the compute device once via minerva's helper; every model and tensor in the
# cells below is moved onto it with `.to(device)`.
# NOTE(review): presumably falls back to CPU when no CUDA device is available — confirm.
device = get_cuda_device()

In [None]:
# Build a paired imagery/label dataset: four NAIP tiles plus the Chesapeake Delaware
# labels, intersected so each sample carries an "image" and a "mask".
data_root = tempfile.gettempdir()
naip_root = os.path.join(data_root, "naip")
naip_url = "https://naipeuwest.blob.core.windows.net/naip/v002/de/2018/de_060cm_2018/38075/"
tiles = [
    "m_3807511_ne_18_060_20181104.tif",
    "m_3807511_se_18_060_20181104.tif",
    "m_3807512_nw_18_060_20180815.tif",
    "m_3807512_sw_18_060_20180815.tif",
]
# Each tile is fetched into naip_root under its original file name.
for tile in tiles:
    download_url(naip_url + tile, naip_root)

naip = NAIP(naip_root)

chesapeake_root = os.path.join(data_root, "chesapeake")

# Load the labels on the same CRS and resolution as the imagery so the two can be intersected.
chesapeake = ChesapeakeDE(chesapeake_root, crs=naip.crs, res=naip.res, download=True)

dataset = naip & chesapeake

# Sample windows from the intersection (not from NAIP alone) so every random window
# falls where both imagery and labels are indexed; the loaders index `dataset`, so the
# sampler must draw from the same spatial extent.
sampler = RandomGeoSampler(dataset, size=300, length=200)
dataloader = DataLoader(dataset, sampler=sampler, collate_fn=stack_samples, batch_size=32)

# One fixed batch of 8 windows, reused for every visual check during training.
testsampler = RandomGeoSampler(dataset, size=300, length=8)
testdataloader = DataLoader(dataset, sampler=testsampler, collate_fn=stack_samples, batch_size=8, num_workers=4)
# Only the first batch is needed; don't materialise every batch with list().
testdata = next(iter(testdataloader))

In [None]:
crit = CrossEntropyLoss()

# Criterions are normally passed to models at init in minerva.
fcn = FCN8ResNet18(crit, input_size=(4, 300, 300), n_classes=13).to(device)
opt = Adam(fcn.parameters(), lr=1e-3)

# Optimisers need to be set to a model in minerva before training.
fcn.set_optimiser(opt)

N_EPOCHS = 101   # epochs 0..100 inclusive
PLOT_EVERY = 10  # visualise the fixed test batch every N epochs

for epoch in range(N_EPOCHS):
  # Re-enter training mode each epoch, because the plotting branch below switches to eval mode.
  # NOTE(review): assumes the minerva model is a torch.nn.Module (it exposes .to/.parameters
  # and is callable) — confirm.
  fcn.train()
  losses = []
  for sample in dataloader:
    image = sample["image"].to(device).float() / 255.0
    target = sample["mask"].to(device).long().squeeze(1)

    # Uses MinervaModel.step.
    loss, _ = fcn.step(image, target, train=True)
    losses.append(loss.item())

  print(epoch, np.mean(losses))
  if epoch % PLOT_EVERY == 0:
    # Eval mode makes layers such as batch norm use (and not update) their running
    # statistics; no_grad alone does not stop the test batch from altering them.
    fcn.eval()
    with torch.no_grad():
      test_image = testdata["image"].to(device).float() / 255.0
      test_target = testdata["mask"].to(device).long().squeeze(1)
      test_pred = fcn(test_image)

    n_samples = test_pred.shape[0]
    # squeeze=False keeps `axs` 2-D even when the test batch holds a single sample.
    fig, axs = plt.subplots(3, n_samples, figsize=(10, 4), squeeze=False)
    for col in range(n_samples):
      # Rows: first three bands as RGB, ground-truth mask, predicted class per pixel.
      axs[0, col].imshow(test_image[col].cpu().numpy()[:3].transpose(1, 2, 0))
      axs[1, col].imshow(test_target[col].cpu().numpy(), cmap="Set3", vmin=0, vmax=12)
      axs[2, col].imshow(test_pred[col].argmax(dim=0).cpu().numpy(), cmap="Set3", vmin=0, vmax=12)
    plt.setp(fig.axes, xticks=[], yticks=[])
    plt.show()
    plt.close(fig)

In [None]:
# Baseline: torchvision FCN with a ResNet-50 backbone, trained with a plain PyTorch loop.
fcn = fcn_resnet50(num_classes=13).to(device)
# The stock ResNet stem takes 3 input channels; swap it for a 4-channel conv with the same
# geometry so the model accepts the 4-band imagery (the plots show only the first three bands).
fcn.backbone.conv1 = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False).to(device)

crit = CrossEntropyLoss()
# Built after the conv1 swap so the new stem's weights are included in the optimiser.
opt = Adam(fcn.parameters(), lr=1e-3)

N_EPOCHS = 101   # epochs 0..100 inclusive
PLOT_EVERY = 10  # visualise the fixed test batch every N epochs

for epoch in range(N_EPOCHS):
  # Re-enter training mode each epoch, because the plotting branch below switches to eval mode.
  fcn.train()
  losses = []
  for sample in dataloader:
    image = sample["image"].to(device).float() / 255.0
    target = sample["mask"].to(device).long().squeeze(1)

    opt.zero_grad()
    pred = fcn(image)["out"]
    loss = crit(pred, target)
    loss.backward()
    opt.step()
    losses.append(loss.item())

  print(epoch, np.mean(losses))
  if epoch % PLOT_EVERY == 0:
    # Eval mode disables dropout and makes batch norm use (and not update) its running
    # statistics; no_grad alone does not stop the test batch from altering them.
    fcn.eval()
    with torch.no_grad():
      test_image = testdata["image"].to(device).float() / 255.0
      test_target = testdata["mask"].to(device).long().squeeze(1)
      test_pred = fcn(test_image)["out"]

    n_samples = test_pred.shape[0]
    # squeeze=False keeps `axs` 2-D even when the test batch holds a single sample.
    fig, axs = plt.subplots(3, n_samples, figsize=(10, 4), squeeze=False)
    for col in range(n_samples):
      # Rows: first three bands as RGB, ground-truth mask, predicted class per pixel.
      axs[0, col].imshow(test_image[col].cpu().numpy()[:3].transpose(1, 2, 0))
      axs[1, col].imshow(test_target[col].cpu().numpy(), cmap="Set3", vmin=0, vmax=12)
      axs[2, col].imshow(test_pred[col].argmax(dim=0).cpu().numpy(), cmap="Set3", vmin=0, vmax=12)
    plt.setp(fig.axes, xticks=[], yticks=[])
    plt.show()
    plt.close(fig)