This notebook is using Torchgeo's tutorials. For more information, you can find the tutorials on their [website](https://torchgeo.readthedocs.io/en/stable/tutorials/getting_started.html). 

To evaluate a model, you can directly refer to the `Evaluation` section.

# Setup

In [1]:
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from torchgeo.datasets import NAIP, ChesapeakeDE
from torchgeo.datasets.utils import download_url, stack_samples
from torchgeo.models import resnet50 as resnet50_torchgeo
from torchgeo.samplers import RandomGeoSampler

  from .autonotebook import tqdm as notebook_tqdm


# Data

Let's create the splits for the training. For more information about this cell, refer to `1_data_exploration.ipynb` tutorial. 

In [2]:
data_root = tempfile.gettempdir()
naip_url = (
    "https://naipeuwest.blob.core.windows.net/naip/v002/de/2018/de_060cm_2018/38075/"
)
TILES = [
    "m_3807511_ne_18_060_20181104.tif",
    "m_3807511_se_18_060_20181104.tif",
    "m_3807512_nw_18_060_20180815.tif"]
cache = True

In [4]:
# Training set
naip_root = os.path.join(data_root, "naip_train")
download_url(naip_url + TILES[0], naip_root)

chesapeake_root = os.path.join(data_root, "chesapeake")
chesapeake = ChesapeakeDE(chesapeake_root, download=True)

train_chesapeake = ChesapeakeDE(chesapeake_root, cache=cache)
train_naip = NAIP(naip_root, crs=chesapeake.crs, 
                  res=chesapeake.res, cache=cache)

train_dataset = train_chesapeake & train_naip
# train_sampler = GridGeoSampler(train_dataset, size=1000, stride=500)
train_sampler = RandomGeoSampler(train_dataset, size=256, length=10000)

train_dataloader = DataLoader(
    train_dataset, batch_size=8, sampler=train_sampler, 
    collate_fn=stack_samples, 
)

Using downloaded and verified file: /tmp/naip_train/m_3807511_ne_18_060_20181104.tif
Downloading https://cicwebresources.blob.core.windows.net/chesapeakebaylandcover/DE/_DE_STATEWIDE.zip to /tmp/chesapeake/_DE_STATEWIDE.zip


100%|██████████| 287350495/287350495 [06:22<00:00, 750520.28it/s] 


In [5]:
# Validation set
naip_root = os.path.join(data_root, "naip_val")
download_url(naip_url + TILES[1], naip_root)

val_chesapeake = ChesapeakeDE(chesapeake_root, cache=cache)
val_naip = NAIP(naip_root, crs=chesapeake.crs, 
                  res=chesapeake.res, cache=cache)

val_dataset = val_chesapeake & val_naip
val_sampler = RandomGeoSampler(val_dataset, size=256, length=10000)

val_dataloader = DataLoader(
    val_dataset, batch_size=8, sampler=val_sampler, 
    collate_fn=stack_samples, shuffle=False,
)

Using downloaded and verified file: /tmp/naip_val/m_3807511_se_18_060_20181104.tif


# Model

In [6]:
# parameters
num_epochs = 10
batch_size = 8
learning_rate = 0.001

In [7]:
from torchgeo.models import FarSeg

# Load the pre-trained ResNet50 model
# Modify the last layer; there are 12 classes + 0 standing for no_data
model = FarSeg(backbone='resnet50', classes=13, backbone_pretrained=True)

In the code below, we first freeze all the layers except the first and last by setting `requires_grad` to False. Then, we modify the first layer to accept a 4-band image. Finally, we define the loss function and optimizer to only update the parameters that have `requires_grad set to True`. This is achieved using `filter` to select only the trainable parameters.

In [8]:
# Freeze all layers except the first and last
for name, param in model.named_parameters():
    if not ('conv1' in name or 'fc' in name):
        param.requires_grad = False

In [9]:
# Modify the first layer to accept a 4-band image
model.backbone.conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=(3, 3), bias=False)

In [10]:
# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate)

# Re-train resnet 50 weights

In [12]:
from tqdm import tqdm

# Train the model
for epoch in tqdm(range(10)):
    model.train()
    for i, sample in enumerate(train_dataloader):
        inputs = sample["image"].float()
        labels = sample["mask"]
        # labels = torch.nn.functional.one_hot(sample["mask"], num_classes=13).float()
        # labels = labels.squeeze(dim=1).permute(0, 3, 1, 2)

        optimizer.zero_grad()
        outputs = model(inputs)
        
        loss = criterion(outputs.view(outputs.shape[0],
                                      outputs.shape[1],-1), 
                         labels.view(labels.shape[0],-1), 
                         )
        loss.backward()
        optimizer.step()
    print("Epoch {} - Train Loss: {:.4f}".format(epoch, loss.item()))
    path = os.path.join("res",'tl-weights_{}_epochs.pt'.format(epoch))
    torch.save(model.state_dict(), path)

    # Evaluate on validation set
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for sample in val_dataloader:
            inputs = sample["image"].float()
            labels = sample["mask"]
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print("Validation Accuracy: {:.2f}%".format(100 * correct / total))

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch 0 - Train Loss: 1.8597


 10%|█         | 1/10 [1:26:34<12:59:09, 5194.41s/it]

Validation Accuracy: 26963242.43%
Epoch 1 - Train Loss: 1.6553


 20%|██        | 2/10 [2:55:03<11:41:34, 5261.79s/it]

Validation Accuracy: 27833313.72%
Epoch 2 - Train Loss: 1.6314


 30%|███       | 3/10 [4:21:49<10:10:53, 5236.26s/it]

Validation Accuracy: 27014443.73%
Epoch 3 - Train Loss: 1.5880


 40%|████      | 4/10 [5:47:45<8:40:27, 5204.55s/it] 

Validation Accuracy: 26168970.60%
Epoch 4 - Train Loss: 1.7663


 50%|█████     | 5/10 [7:12:55<7:10:52, 5170.55s/it]

Validation Accuracy: 27271754.02%
Epoch 5 - Train Loss: 1.7553


 60%|██████    | 6/10 [8:37:11<5:42:05, 5131.47s/it]

Validation Accuracy: 26465652.95%
Epoch 6 - Train Loss: 1.4781


 70%|███████   | 7/10 [10:01:31<4:15:25, 5108.35s/it]

Validation Accuracy: 26647050.64%
Epoch 7 - Train Loss: 1.6379


 80%|████████  | 8/10 [11:25:11<2:49:20, 5080.13s/it]

Validation Accuracy: 27517238.86%
Epoch 8 - Train Loss: 1.6162


 90%|█████████ | 9/10 [12:48:50<1:24:20, 5061.00s/it]

Validation Accuracy: 27153539.46%
Epoch 9 - Train Loss: 1.4646


100%|██████████| 10/10 [14:13:49<00:00, 5122.93s/it] 

Validation Accuracy: 26814570.27%





# Evaluate on test set

In [11]:
import torch

# Load the model architecture
from torchgeo.models import FarSeg

In [12]:
# Validation set
naip_root = os.path.join(data_root, "naip_test")
download_url(naip_url + TILES[2], naip_root)

test_chesapeake = ChesapeakeDE(chesapeake_root, cache=cache)
test_naip = NAIP(naip_root, crs=chesapeake.crs, 
                  res=chesapeake.res, cache=cache)

test_dataset = test_chesapeake & test_naip
test_sampler = RandomGeoSampler(test_dataset, size=256, length=1000)

test_dataloader = DataLoader(
    test_dataset, batch_size=8, sampler=test_sampler, 
    collate_fn=stack_samples, shuffle=False,
)

Downloading https://naipeuwest.blob.core.windows.net/naip/v002/de/2018/de_060cm_2018/38075/m_3807512_nw_18_060_20180815.tif to /tmp/naip_test/m_3807512_nw_18_060_20180815.tif


100%|██████████| 489865657/489865657 [01:27<00:00, 5571523.95it/s]


In [13]:
# set the backbone_pretrained argument to False when initializing the model architecture, so that the weights you load
# are not overwritten by the pre-trained weights.
model = FarSeg(backbone='resnet50', classes=13, backbone_pretrained=False)
model.backbone.conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2,
                                 padding=(3, 3), bias=False)

In [14]:
# Load the model weights
weights_file = "res/tl-weights_9_epochs.pt"
model.load_state_dict(torch.load(weights_file))

<All keys matched successfully>

In [15]:
# Set the model to evaluation mode
model.eval()

# Define the loss function
criterion = nn.CrossEntropyLoss()

# Evaluate on test set
test_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
    for sample in test_dataloader:
        inputs = sample["image"].float()
        labels = sample["mask"]
        outputs = model(inputs)
        loss = criterion(outputs.view(outputs.shape[0], outputs.shape[1], -1), labels.view(labels.shape[0], -1))
        test_loss += loss.item() * inputs.size(0)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

test_loss /= len(test_dataloader.dataset)
print("Test Loss: {:.4f}".format(test_loss))
print("Test Accuracy: {:.2f}%".format(100 * correct / total))


Test Loss: 1595.2490
Test Accuracy: 26273584.70%
