# Setup

In [2]:
import os
import tempfile
import time
import numpy as np
from typing import Tuple
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader

from torchgeo.datasets import NAIP, ChesapeakeDE
from torchgeo.datasets.utils import download_url, stack_samples
from torchgeo.models import resnet50 as resnet50_torchgeo
from torchgeo.samplers import GridGeoSampler

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchgeo.datasets as dg_datasets
import torchgeo.models as dg_models

# Data

Create the training and validation splits. Each split intersects one NAIP tile with the ChesapeakeDE land-cover labels. For more details on these datasets, see the `1_data_exploration.ipynb` tutorial.

In [3]:
# Shared configuration for both splits.
data_root = tempfile.gettempdir()
naip_url = (
    "https://naipeuwest.blob.core.windows.net/naip/v002/de/2018/de_060cm_2018/38075/"
)
tiles = [
    "m_3807511_ne_18_060_20181104.tif",
    "m_3807511_se_18_060_20181104.tif",
]
cache = True
CHIP_SIZE = 1000  # `size` passed to GridGeoSampler
CHIP_STRIDE = 500  # `stride` passed to GridGeoSampler; half of CHIP_SIZE, so neighbouring chips overlap
BATCH_SIZE = 12


def make_split(split: str, tile: str):
    """Download one NAIP tile and the ChesapeakeDE labels, and intersect them.

    Args:
        split: suffix for the per-split download folders
            (``naip_<split>`` and ``chesapeake_<split>`` under ``data_root``).
        tile: NAIP file name appended to ``naip_url``.

    Returns:
        ``(labels, imagery, dataset)``: the ChesapeakeDE dataset, the NAIP
        dataset, and their ``&``-intersection.
    """
    naip_root = os.path.join(data_root, f"naip_{split}")
    download_url(naip_url + tile, naip_root)

    chesapeake_root = os.path.join(data_root, f"chesapeake_{split}")
    # One instance both downloads (when needed) and serves the labels.
    labels = ChesapeakeDE(chesapeake_root, cache=cache, download=True)
    # Put NAIP on the labels' CRS and resolution so the intersection lines up.
    imagery = NAIP(naip_root, crs=labels.crs, res=labels.res, cache=cache)
    return labels, imagery, labels & imagery


# DataLoader refuses `shuffle=True` together with an explicit sampler
# ("sampler option is mutually exclusive with shuffle"), so chip order is
# left entirely to the sampler. GridGeoSampler walks a fixed grid; use a
# random geo sampler instead if training order should vary between epochs.

# Training set: first NAIP tile.
train_chesapeake, train_naip, train_dataset = make_split("train", tiles[0])
train_sampler = GridGeoSampler(train_dataset, size=CHIP_SIZE, stride=CHIP_STRIDE)
train_dataloader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, sampler=train_sampler,
    collate_fn=stack_samples,
)

# Validation set: second NAIP tile.
val_chesapeake, val_naip, val_dataset = make_split("val", tiles[1])
val_sampler = GridGeoSampler(val_dataset, size=CHIP_SIZE, stride=CHIP_STRIDE)
val_dataloader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, sampler=val_sampler,
    collate_fn=stack_samples,
)

Downloading https://naipeuwest.blob.core.windows.net/naip/v002/de/2018/de_060cm_2018/38075/m_3807511_se_18_060_20181104.tif to /tmp/naip_val/m_3807511_se_18_060_20181104.tif


100%|██████████| 521985441/521985441 [02:52<00:00, 3034274.72it/s]


Downloading https://cicwebresources.blob.core.windows.net/chesapeakebaylandcover/DE/_DE_STATEWIDE.zip to /tmp/chesapeake_val/_DE_STATEWIDE.zip


 75%|███████▌  | 215875584/287350495 [05:37<07:30, 158682.46it/s] 

# Model

In [1]:
# Build a ResNet-50 through torchgeo's factory, asking for pre-trained weights.
USE_PRETRAINED_WEIGHTS = True
model = resnet50_torchgeo(pretrained=USE_PRETRAINED_WEIGHTS)

  from .autonotebook import tqdm as notebook_tqdm
Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-14fe96d1.pth" to /home/anne/.cache/torch/hub/checkpoints/resnet50_a1_0-14fe96d1.pth


In [None]:
# Modify the last layer: replace the classifier head with one sized to the label set.
def find_classes(dataset):
    """Return the ``classes`` sequence of ``dataset`` or of one of its members.

    ``train_dataset`` is the ``&``-intersection of two datasets, so the class
    list may live on a member dataset rather than on the combined object.

    Raises:
        AttributeError: if neither the dataset nor any member exposes ``classes``.
    """
    for candidate in (dataset, *getattr(dataset, "datasets", ())):
        classes = getattr(candidate, "classes", None)
        if classes is not None:
            return classes
    raise AttributeError(
        f"{type(dataset).__name__} and its member datasets have no `classes` "
        "attribute; set num_classes explicitly."
    )


num_classes = len(find_classes(train_dataset))
model.fc = nn.Linear(model.fc.in_features, num_classes)

In [None]:
# Objective and optimizer for fine-tuning: cross-entropy on the head's logits,
# and Adam at its default hyper-parameters over every model parameter.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters())

# Fine-tune the ResNet-50 weights

In [None]:
# Train the model
NUM_EPOCHS = 10


def split_batch(batch):
    """Turn a ``stack_samples`` batch dict into ``(inputs, labels)`` tensors.

    The dataloaders yield dicts keyed by sample field, not ``(inputs, labels)``
    tuples, so images and masks are pulled out by key.
    """
    # NOTE(review): assumes NAIP bands are ordered R, G, B, NIR; only the first
    # three are kept because the pre-trained ResNet stem takes 3 channels.
    # No normalization is applied here — TODO confirm the scaling the weights expect.
    inputs = batch["image"][:, :3].float()
    # NOTE(review): the model ends in a chip-level classifier, but the labels
    # are a per-pixel mask. As a stop-gap each chip is labelled with its most
    # frequent mask value; per-pixel prediction needs a segmentation model.
    # TODO confirm mask values fall in [0, num_classes).
    labels = batch["mask"].flatten(start_dim=1).mode(dim=1).values.long()
    return inputs, labels


for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0
    num_batches = 0
    for batch in train_dataloader:
        inputs, labels = split_batch(batch)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1
    # Report the mean over the epoch, not just the last batch's loss.
    mean_loss = running_loss / max(num_batches, 1)
    print("Epoch {} - Train Loss: {:.4f}".format(epoch, mean_loss))

    # Evaluate on validation set
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in val_dataloader:
            inputs, labels = split_batch(batch)
            predicted = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # Guard against an empty validation loader instead of dividing by zero.
    if total:
        print("Validation Accuracy: {:.2f}%".format(100 * correct / total))
    else:
        print("Validation Accuracy: n/a (empty validation set)")

# Evaluate on test set

In [None]:
# Evaluate on test set
model.eval()
correct = 0
total = 0