# Setup

In [1]:
import os
import tempfile
import time
import numpy as np
from typing import Tuple
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader

from torchgeo.datasets import NAIP, ChesapeakeDE
from torchgeo.datasets.utils import download_url, stack_samples
from torchgeo.models import resnet50 as resnet50_torchgeo
from torchgeo.samplers import RandomGeoSampler

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchgeo.datasets as dg_datasets
import torchgeo.models as dg_models

# Data

Let's create the splits for the training. For more information about this cell, refer to `1_data_exploration.ipynb` tutorial. 

In [3]:
data_root = tempfile.gettempdir()
naip_url = (
    "https://naipeuwest.blob.core.windows.net/naip/v002/de/2018/de_060cm_2018/38075/"
)
TILES = [
    "m_3807511_ne_18_060_20181104.tif",
    "m_3807511_se_18_060_20181104.tif",
    "m_3807512_nw_18_060_20180815.tif"]
cache = True

In [4]:
# Training set
naip_root = os.path.join(data_root, "naip_train")
download_url(naip_url + TILES[0], naip_root)

chesapeake_root = os.path.join(data_root, "chesapeake_train")
chesapeake = ChesapeakeDE(chesapeake_root, download=True)

train_chesapeake = ChesapeakeDE(chesapeake_root, cache=cache)
train_naip = NAIP(naip_root, crs=chesapeake.crs, 
                  res=chesapeake.res, cache=cache)

train_dataset = train_chesapeake & train_naip
# train_sampler = GridGeoSampler(train_dataset, size=1000, stride=500)
train_sampler = RandomGeoSampler(train_dataset, size=256, length=10000)

train_dataloader = DataLoader(
    train_dataset, batch_size=8, sampler=train_sampler, 
    collate_fn=stack_samples, 
)

Using downloaded and verified file: /tmp/naip_train/m_3807511_ne_18_060_20181104.tif


In [5]:
# Validation set
naip_root = os.path.join(data_root, "naip_val")
download_url(naip_url + TILES[1], naip_root)

# chesapeake_root = os.path.join(data_root, "chesapeake_val")
# chesapeake = ChesapeakeDE(chesapeake_root, download=True)

val_chesapeake = ChesapeakeDE(chesapeake_root, cache=cache)
val_naip = NAIP(naip_root, crs=chesapeake.crs, 
                  res=chesapeake.res, cache=cache)

val_dataset = val_chesapeake & val_naip
# val_sampler = GridGeoSampler(val_dataset, size=1000, stride=500)
val_sampler = RandomGeoSampler(val_dataset, size=256, length=10000)

val_dataloader = DataLoader(
    val_dataset, batch_size=8, sampler=val_sampler, 
    collate_fn=stack_samples, shuffle=False,
)

Using downloaded and verified file: /tmp/naip_val/m_3807511_se_18_060_20181104.tif


In [6]:
next(iter(val_dataloader))["image"]

tensor([[[[ 50.,  47.,  35.,  ..., 191., 191., 189.],
          [ 51.,  37.,  46.,  ..., 181., 181., 189.],
          [ 50.,  58.,  71.,  ..., 186., 187., 183.],
          ...,
          [183., 180., 179.,  ..., 186., 187., 190.],
          [182., 186., 180.,  ..., 190., 183., 189.],
          [177., 187., 181.,  ..., 187., 184., 189.]],

         [[ 54.,  43.,  42.,  ..., 182., 187., 178.],
          [ 54.,  43.,  44.,  ..., 174., 175., 178.],
          [ 46.,  57.,  61.,  ..., 181., 176., 176.],
          ...,
          [181., 166., 166.,  ..., 180., 181., 173.],
          [172., 173., 171.,  ..., 181., 180., 178.],
          [172., 173., 171.,  ..., 177., 179., 178.]],

         [[ 51.,  43.,  42.,  ..., 158., 158., 150.],
          [ 50.,  50.,  53.,  ..., 152., 147., 150.],
          [ 45.,  53.,  63.,  ..., 155., 151., 142.],
          ...,
          [155., 142., 140.,  ..., 158., 153., 153.],
          [148., 151., 145.,  ..., 156., 154., 151.],
          [150., 147., 148.,  ...

# Model

In [7]:
# parameters
num_epochs = 10
batch_size = 8
learning_rate = 0.001

In [8]:
from torchgeo.models import FarSeg

# Load the pre-trained ResNet50 model
# Modify the last layer; there are 12 classes + 0 standing for no_data
model = FarSeg(backbone='resnet50', classes=13, backbone_pretrained=True)

In the code below, we first freeze all the layers except the first and last by setting `requires_grad` to False. Then, we modify the first layer to accept a 4-band image. Finally, we define the loss function and optimizer to only update the parameters that have `requires_grad set to True`. This is achieved using `filter` to select only the trainable parameters.

In [9]:
# Freeze all layers except the first and last
for name, param in model.named_parameters():
    if not ('conv1' in name or 'fc' in name):
        param.requires_grad = False

In [10]:
# Modify the first layer to accept a 4-band image
model.backbone.conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=(3, 3), bias=False)

In [11]:
# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate)

# Re-train resnet 50 weights

In [12]:
from tqdm import tqdm

# Train the model
for epoch in tqdm(range(10)):
    model.train()
    for i, sample in enumerate(train_dataloader):
        inputs = sample["image"].float()
        labels = sample["mask"]
        # labels = torch.nn.functional.one_hot(sample["mask"], num_classes=13).float()
        # labels = labels.squeeze(dim=1).permute(0, 3, 1, 2)

        optimizer.zero_grad()
        outputs = model(inputs)
        
        loss = criterion(outputs.view(outputs.shape[0],
                                      outputs.shape[1],-1), 
                         labels.view(labels.shape[0],-1), 
                         )
        loss.backward()
        optimizer.step()
    print("Epoch {} - Train Loss: {:.4f}".format(epoch, loss.item()))
    path = os.path.join("res",'try2-weights_{}_epochs.pt'.format(epoch))
    torch.save(model.state_dict(), path)

    # Evaluate on validation set
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for sample in val_dataloader:
            inputs = sample["image"].float()
            labels = sample["mask"]
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print("Validation Accuracy: {:.2f}%".format(100 * correct / total))

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch 0 - Train Loss: 1.8597


 10%|█         | 1/10 [1:26:34<12:59:09, 5194.41s/it]

Validation Accuracy: 26963242.43%
Epoch 1 - Train Loss: 1.6553


 20%|██        | 2/10 [2:55:03<11:41:34, 5261.79s/it]

Validation Accuracy: 27833313.72%
Epoch 2 - Train Loss: 1.6314


 30%|███       | 3/10 [4:21:49<10:10:53, 5236.26s/it]

Validation Accuracy: 27014443.73%
Epoch 3 - Train Loss: 1.5880


 40%|████      | 4/10 [5:47:45<8:40:27, 5204.55s/it] 

Validation Accuracy: 26168970.60%
Epoch 4 - Train Loss: 1.7663


 50%|█████     | 5/10 [7:12:55<7:10:52, 5170.55s/it]

Validation Accuracy: 27271754.02%
Epoch 5 - Train Loss: 1.7553


 60%|██████    | 6/10 [8:37:11<5:42:05, 5131.47s/it]

Validation Accuracy: 26465652.95%
Epoch 6 - Train Loss: 1.4781


 70%|███████   | 7/10 [10:01:31<4:15:25, 5108.35s/it]

Validation Accuracy: 26647050.64%
Epoch 7 - Train Loss: 1.6379


 80%|████████  | 8/10 [11:25:11<2:49:20, 5080.13s/it]

Validation Accuracy: 27517238.86%
Epoch 8 - Train Loss: 1.6162


 90%|█████████ | 9/10 [12:48:50<1:24:20, 5061.00s/it]

Validation Accuracy: 27153539.46%
Epoch 9 - Train Loss: 1.4646


100%|██████████| 10/10 [14:13:49<00:00, 5122.93s/it] 

Validation Accuracy: 26814570.27%





- 1 hour to train 1 epoch (try 1: after 2/3 epochs)
- train 10 epochs in: (try2: losses= 1.8597, 1.6553, 1.6314, 1.5880, 1.7663, 1.7553, 1.4781, 1.6379, 1.6162, )

# Evaluate on test set

In [None]:
# Evaluate on test set
# Set the model to evaluation mode
model.eval()

# Evaluate on test set
correct = 0
total = 0
with torch.no_grad():
    for sample in test_dataloader:
        inputs = sample["image"].float()
        labels = sample["mask"]
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print("Test Accuracy: {:.2f}%".format(100 * correct / total))
