<a href="https://colab.research.google.com/github/Wycology/dl_tea_mapping/blob/main/dl_tea.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# <font color='green'><b> SATELLITE DATA FOR AGRICULTURAL ECONOMISTS</b></font>


<font color='blue'><b>THEORY AND PRACTICE</b></font>

**Mapping tea plantations in Central Kenya: _Deep Learning Approach_**


*David Wuepper, Lisa Biber-Freudenberger, Hadi, Wyclife Agumba Oluoch*

[Land Economics Group](https://www.ilr1.uni-bonn.de/en/research/research-groups/land-economics), University of Bonn, Bonn, Germany

---

# **Background**


---

In this tutorial, we introduce the basics of using a deep learning approach to segment tea fields, with a practical example at the foot of Mount Kenya. We obtained a Sentinel-2 satellite image (10 m pixels) from [Google Earth Engine](https://code.earthengine.google.com/60cf3e783458009bd8378eaded30f5c7). We obtained the labels by manually digitizing tea plantations in QGIS over the Google Satellite Hybrid basemap. The labels cover only a small portion of the downloaded satellite image, so that we can train the model there and then use it to segment tea fields elsewhere.
We used `torchgeo` for this modeling task for two reasons:
1. It is simple to use and takes care of tedious steps such as georeferencing, chipping the image into small patches, and turning the labels into raster masks.
2. It samples training chips on the fly from anywhere the image and labels overlap, so we make maximal use of the labeled area.
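To see what `torchgeo` spares us, here is a minimal NumPy sketch of manual chipping: cutting a (bands, height, width) array into non-overlapping square tiles on a regular grid. It is similar in spirit to what a grid sampler such as `GridGeoSampler` does in map coordinates, but `chip_offsets` and the toy array are hypothetical, not `torchgeo` code.

```python
import numpy as np


def chip_offsets(height, width, size, stride):
    """Top-left (row, col) offsets of size x size windows on a regular grid.

    Hypothetical helper: windows that would run past the edge are skipped,
    so a few border pixels may never be covered.
    """
    rows = range(0, height - size + 1, stride)
    cols = range(0, width - size + 1, stride)
    return [(r, c) for r in rows for c in cols]


# Toy "image": 11 bands, 100 x 70 pixels of placeholder values
toy_image = np.zeros((11, 100, 70), dtype=np.float32)

offsets = chip_offsets(toy_image.shape[1], toy_image.shape[2], size=32, stride=32)
chips = [toy_image[:, r:r + 32, c:c + 32] for r, c in offsets]
print(len(chips), chips[0].shape)  # 6 (11, 32, 32): a 3 x 2 grid of chips
```

Done by hand, each chip would also need its own georeference and a matching label chip cut from the digitized polygons; that bookkeeping is what `torchgeo` handles for us.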

## Loading libraries
---
Since `torchgeo` is not pre-installed in Colab, we have to install it. We also install `torchseg`, which provides ready-made segmentation architectures such as U-Net. The other supporting libraries are already available in Colab, so we only import them.

In [1]:
# Install libraries not already available in colab
!pip install torchgeo
!pip install torchseg

Collecting torchgeo
  Downloading torchgeo-0.7.0-py3-none-any.whl.metadata (19 kB)
Collecting fiona>=1.8.22 (from torchgeo)
  Downloading fiona-1.10.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (56 kB)
Collecting kornia>=0.7.4 (from torchgeo)
  Downloading kornia-0.8.0-py2.py3-none-any.whl.metadata (17 kB)
Collecting lightly!=1.4.26,>=1.4.5 (from torchgeo)
  Downloading lightly-1.5.20-py3-none-any.whl.metadata (37 kB)
Collecting lightning!=2.3.*,!=2.5.0,>=2 (from lightning[pytorch-extra]!=2.3.*,!=2.5.0,>=2->torchgeo)
  Downloading lightning-2.5.1.post0-py3-none-any.whl.metadata (39 kB)
Collecting rasterio!=1.4.0,!=1.4.1,!=1.4.2,>=1.3.3 (from torchgeo)
  Downloading rasterio-1.4.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.1 kB)
Collecting rtree>=1.0.1 (from torchgeo)
  Downloading rtree-1.4.0-py3-none-manylinu

Collecting torchseg
  Downloading torchseg-0.0.1a4-py3-none-any.whl.metadata (12 kB)
Downloading torchseg-0.0.1a4-py3-none-any.whl (67 kB)
Installing collected packages: torchseg
Successfully installed torchseg-0.0.1a4


In [2]:
# Import the necessary libraries
import json
import torch
import rasterio
import torchseg
import torchgeo
import torch.nn as nn
from tqdm import tqdm
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchgeo.transforms import AppendNDVI
from rasterio.transform import from_bounds
from torchgeo.samplers import RandomGeoSampler, GridGeoSampler
from torchgeo.datasets import VectorDataset, RasterDataset, stack_samples

After importing the libraries, we confirm the working directory. This matters because both our image and our GeoPackage (_.gpkg_) are read from this location (`/content` in Colab), so we need to be sure of the path. The IPython `pwd` magic command prints it.

In [None]:
pwd # Confirm the working directory
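If you prefer plain Python over the `pwd` magic, the sketch below prints the working directory and lists the GeoTIFF and GeoPackage files found there, so a missing upload shows up before the dataset classes are defined. It assumes, as the rest of this notebook does, that the data sit in `/content`; `find_inputs` is a hypothetical helper.

```python
import os
from pathlib import Path


def find_inputs(directory="."):
    """List the .tif and .gpkg files directly inside `directory` (hypothetical helper)."""
    folder = Path(directory)
    tifs = sorted(p.name for p in folder.glob("*.tif"))
    gpkgs = sorted(p.name for p in folder.glob("*.gpkg"))
    return tifs, gpkgs


print("Working directory:", os.getcwd())

# Fall back to the current directory when not running on Colab
data_dir = "/content" if Path("/content").exists() else "."
tifs, gpkgs = find_inputs(data_dir)
print("GeoTIFFs:", tifs)
print("GeoPackages:", gpkgs)
```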

# Defining the Dataset
---
With `torchgeo`, we do not have to pre-cut the satellite image into small chips of, say, 256 by 256 pixels; the library extracts chips on the fly. We only need to tell it where the satellite image is. The folder may in fact contain several large images. Here, our image is the only _.tif_ file in the working directory, so we define the class as follows:

In [None]:
# Define the GeoTiff dataset class

class GeoTiffDataset(RasterDataset):
  filename_glob = "*.tif"
raster_data = GeoTiffDataset(paths = "/content")
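`filename_glob` is a shell-style wildcard pattern that decides which files under `paths` belong to the dataset. The sketch below uses Python's `fnmatchcase` to show which (hypothetical) file names `"*.tif"` would pick up. Matching is case-sensitive on Colab's Linux file system, and `.tiff` is not matched. Note that any other _.tif_ in the folder, such as a prediction raster written later, would also match if the cell were re-run; a more specific pattern such as `"s2_*.tif"` avoids that.

```python
from fnmatch import fnmatchcase

# Hypothetical file names, for illustration only
candidates = ["s2_d.tif", "s2_d.TIF", "s2_d.tiff", "labels.gpkg", "predicted_output.tif"]

# fnmatchcase is case-sensitive, like globbing on a Linux file system
matched = [name for name in candidates if fnmatchcase(name, "*.tif")]
print(matched)  # ['s2_d.tif', 'predicted_output.tif']
```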

`raster_data` is now a blueprint of the satellite image(s) in the directory: `torchgeo` indexes where the files are and reads pixels only when a chip is requested. Next, we do the same for the labels. The label data is a _.gpkg_ file with a class column stating whether each polygon is tea or not tea. Here, that column is called `tea_no_tea`. Passing its name is essential, because the library uses it under the hood to burn the class values into a binary label raster, which it then intersects with the satellite image. We achieve this as follows:

In [None]:
# Define the label dataset class (vector here, but a raster mask also works).
# Remember to pass label_name: without it, every polygon, tea or not, is burned in as 1. Never forget this!

class LabelDataset(VectorDataset):
  filename_glob = "*.gpkg"
label_data = LabelDataset(paths = "/content", label_name = "tea_no_tea")
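Under the hood, the polygons are turned into a raster by burning each polygon's `tea_no_tea` value into the pixels it covers (as far as we can tell, `torchgeo` relies on rasterio for this). The sketch below re-implements the idea in plain NumPy for a single polygon, using an even-odd test on pixel centres; `burn_polygon` and the toy coordinates are hypothetical.

```python
import numpy as np


def burn_polygon(polygon, value, bounds, res, out=None):
    """Burn `value` into pixels whose centres fall inside `polygon`.

    polygon: list of (x, y) vertices in map units
    bounds:  (minx, miny, maxx, maxy) of the output grid
    res:     pixel size in map units
    Background pixels stay 0. Even-odd ray casting; a sketch, not torchgeo code.
    """
    minx, miny, maxx, maxy = bounds
    width = int(round((maxx - minx) / res))
    height = int(round((maxy - miny) / res))
    if out is None:
        out = np.zeros((height, width), dtype=np.int64)

    # Pixel-centre coordinates; row 0 is the top (north) edge
    xs = minx + (np.arange(width) + 0.5) * res
    ys = maxy - (np.arange(height) + 0.5) * res
    px, py = np.meshgrid(xs, ys)

    inside = np.zeros((height, width), dtype=bool)
    n = len(polygon)
    for i in range(n):
        x1, y1 = polygon[i]
        x2, y2 = polygon[(i + 1) % n]
        # Does a ray going right from the pixel centre cross this edge?
        crosses = (y1 > py) != (y2 > py)
        with np.errstate(divide="ignore", invalid="ignore"):
            x_at_y = x1 + (py - y1) * (x2 - x1) / (y2 - y1)
        inside ^= crosses & (px < x_at_y)
    out[inside] = value
    return out


tea = [(0, 0), (20, 0), (20, 20), (0, 20)]  # toy "tea" polygon, class value 1
toy_mask = burn_polygon(tea, value=1, bounds=(0, 0, 40, 40), res=10)
print(toy_mask)  # 1s in the bottom-left 2 x 2 pixels, 0 elsewhere
```

Assuming tea is coded 1 and not tea 0, pixels that no polygon covers also end up as 0, i.e. they are treated as "not tea" rather than unknown, so the labeled area should be digitized completely.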

## Combining raster_data and label_data

Now that we have blueprints of both `raster_data` and `label_data`, the next step is to intersect the two. This behaves like _cropping_ or _clipping_ the raster to the extent of `label_data`, which is why it does not matter that the raster extent is larger than the label extent. Training chips are only obtained where the two datasets intersect (overlap); regions outside `label_data` are never sampled. We achieve this intersection with the _&_ operator:

In [None]:
# Create the intersection of the raster and vector/label datasets

training_data = raster_data & label_data
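Conceptually, `&` keeps only the region covered by both datasets. For one image and one label layer this boils down to intersecting their bounding boxes, as in the sketch below; `torchgeo` does the equivalent over its spatial index of all files. `intersect_bounds` and the coordinates (in metres) are illustrative assumptions.

```python
def intersect_bounds(a, b):
    """Overlap of two (minx, miny, maxx, maxy) boxes, or None if they do not overlap."""
    minx, miny = max(a[0], b[0]), max(a[1], b[1])
    maxx, maxy = min(a[2], b[2]), min(a[3], b[3])
    if minx >= maxx or miny >= maxy:
        return None
    return (minx, miny, maxx, maxy)


# Hypothetical extents: a large image and a smaller labelled area inside it
image_bounds = (300000, 9950000, 340000, 9990000)
label_bounds = (310000, 9960000, 320000, 9965000)
print(intersect_bounds(image_bounds, label_bounds))  # (310000, 9960000, 320000, 9965000)
```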

You will notice the printout **Converting LabelDataset res from (0.0001, 0.0001) to (10.0, 10.0)**. It tells us that the vector label polygons are rasterized on the fly to a binary raster with the same pixel size as our Sentinel-2 image (10 m). Another important point is that the pixels of both layers are **aligned**: the image chip and the mask chip for a given location cover exactly the same pixel grid.
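Why the conversion matters: for a given query window, the number of pixels is its extent divided by the resolution. Once the label layer uses the raster's 10 m resolution, image and mask chips for the same window have identical shapes. The 0.0001 appears to be `torchgeo`'s default resolution for vector datasets (in the units of the CRS). A small sketch with a hypothetical `grid_shape` helper and made-up coordinates:

```python
def grid_shape(bounds, res):
    """(height, width) in pixels of a (minx, miny, maxx, maxy) window at pixel size `res`."""
    minx, miny, maxx, maxy = bounds
    return round((maxy - miny) / res), round((maxx - minx) / res)


query = (310000, 9960000, 310320, 9960320)  # hypothetical 320 m x 320 m window

print(grid_shape(query, 10.0))    # (32, 32): image and mask chips line up pixel for pixel
print(grid_shape(query, 0.0001))  # (3200000, 3200000): what the default label res would give
```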

In [None]:
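# Append variables derivable from the bands, such as NDVI.
# index_nir and index_red are zero-based positions of the NIR (B8) and red (B4) bands;
# check them against the band order of your Earth Engine export.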

append_ndvi = AppendNDVI(index_nir = 7, index_red = 3)
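`AppendNDVI` computes the normalized difference vegetation index, NDVI = (NIR − Red) / (NIR + Red), from the two band positions and appends it as an extra channel. The NumPy sketch below shows the same arithmetic on a single (bands, height, width) chip. `append_ndvi_np` is hypothetical and the 10-band toy chip is an assumption; `torchgeo`'s own transform works on batched tensors and may differ in details such as the small epsilon.

```python
import numpy as np


def append_ndvi_np(image, index_nir, index_red, eps=1e-10):
    """Append NDVI = (NIR - Red) / (NIR + Red) as an extra band to a (C, H, W) array."""
    nir = image[index_nir].astype(np.float64)
    red = image[index_red].astype(np.float64)
    ndvi = (nir - red) / (nir + red + eps)  # eps avoids 0/0 on empty pixels
    return np.concatenate([image, ndvi[np.newaxis].astype(image.dtype)], axis=0)


# Toy 10-band chip of 2 x 2 pixels: red (index 3) = 0.1 and NIR (index 7) = 0.5 everywhere
chip = np.zeros((10, 2, 2), dtype=np.float32)
chip[3] = 0.1
chip[7] = 0.5
out = append_ndvi_np(chip, index_nir=7, index_red=3)
print(out.shape)      # (11, 2, 2)
print(out[-1, 0, 0])  # about 0.667 = (0.5 - 0.1) / (0.5 + 0.1)
```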

In [None]:
# Define the sampler: it draws 1000 random 32 x 32-pixel chip locations from the overlap area

sampler = RandomGeoSampler(dataset = training_data, size = 32, length = 1000)
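The sampler does not read pixels itself; it generates locations. Each of the `length = 1000` samples is a random window inside the area where image and labels overlap, and in `torchgeo` 0.7 `size = 32` is measured in pixels by default, i.e. 32 × 10 m = 320 m on a side. The sketch below mimics that idea in plain Python; `random_window` and the extent are hypothetical, and `torchgeo`'s actual sampling logic is more involved.

```python
import random


def random_window(bounds, size_px, res, rng=random):
    """Pick a random size_px x size_px window (in map units) fully inside `bounds`.

    Rough illustration of random geo-sampling; not torchgeo's implementation.
    Assumes `bounds` is at least size_px * res wide and tall.
    """
    minx, miny, maxx, maxy = bounds
    side = size_px * res  # 32 pixels x 10 m = 320 m
    x0 = rng.uniform(minx, maxx - side)
    y0 = rng.uniform(miny, maxy - side)
    return (x0, y0, x0 + side, y0 + side)


overlap = (310000, 9960000, 320000, 9965000)  # hypothetical image/label overlap (m)
rng = random.Random(0)
windows = [random_window(overlap, 32, 10.0, rng) for _ in range(1000)]  # length = 1000
print(windows[0])
```

Because windows are drawn independently, they can overlap one another, so the same labeled pixels may be seen several times per epoch.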

In [None]:
# Define a customized collate function to append NDVI to the sampled images
def custom_collate_fn(samples):
  for sample in samples:
    sample["image"] = append_ndvi(sample["image"])[0]
  return stack_samples(samples)
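The custom collate function runs once per batch: it appends NDVI to each chip and then hands the list of sample dictionaries to `stack_samples`, which, to our understanding, stacks tensor-valued entries along a new batch axis and keeps other entries, such as the CRS, as lists. A NumPy sketch of that last step, with a hypothetical `stack_dicts` helper and a made-up CRS string:

```python
import numpy as np


def stack_dicts(samples):
    """Combine a list of per-chip dicts into one dict of batched arrays.

    Hypothetical stand-in for stack_samples: array-valued keys are stacked
    along a new first (batch) axis, other values are collected into lists.
    """
    batched = {}
    for key in samples[0]:
        values = [s[key] for s in samples]
        if isinstance(values[0], np.ndarray):
            batched[key] = np.stack(values)
        else:
            batched[key] = values
    return batched


toy_samples = [
    {"image": np.zeros((11, 32, 32)), "mask": np.zeros((32, 32)), "crs": "EPSG:32737"}
    for _ in range(50)
]
toy_batch = stack_dicts(toy_samples)
print(toy_batch["image"].shape, toy_batch["mask"].shape)  # (50, 11, 32, 32) (50, 32, 32)
```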

In [None]:
# Initialize the dataloader. This is the function that will be serving the role of availing batches of extracted samples for model training
dataloader = DataLoader(
    dataset = training_data,
    batch_size = 50,
    sampler = sampler,
    collate_fn = custom_collate_fn # normally would be stack_samples
)
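With these settings, one pass over the DataLoader (one epoch) draws `length` random chips and groups them into batches. A quick calculation in plain Python, reusing the numbers from the cells above:

```python
import math

length = 1000     # samples drawn by RandomGeoSampler per epoch
batch_size = 50   # chips per batch in the DataLoader
chip_px = 32      # chip side in pixels

batches_per_epoch = math.ceil(length / batch_size)
pixels_per_epoch = length * chip_px * chip_px

print(batches_per_epoch)  # 20
print(pixels_per_epoch)   # 1024000 labelled pixels (chips may overlap)
```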

In [None]:
# Iterate through the dataloader to confirm that it is able to load the data

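for batch in dataloader:
  image = batch["image"][:, :12, :, :]  # bands + appended NDVI
  mask = batch["mask"]                   # labels live under the "mask" key, not in the image

  print(f"Image batch shape: {tuple(image.shape)}")  # (batch, channels, height, width)
  print(f"Mask batch shape: {tuple(mask.shape)}")    # (batch, height, width)
  break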
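Printing shapes is a first check. A slightly stronger one, sketched below with a hypothetical `check_batch` helper, also verifies that the mask values are valid class indices and returns the channel count that `in_channels` in the model must match. It works on NumPy arrays and on CPU tensors (via `np.asarray`), so in the notebook one could call `check_batch(image, mask)` inside the loop above.

```python
import numpy as np


def check_batch(image, mask, num_classes=2):
    """Basic sanity checks on one batch; raises AssertionError if something is off.

    image: (B, C, H, W); mask: (B, H, W) with integer class values.
    Returns the number of input channels the model must accept.
    """
    image, mask = np.asarray(image), np.asarray(mask)
    assert image.ndim == 4, f"expected (B, C, H, W), got {image.shape}"
    assert mask.shape == (image.shape[0],) + image.shape[2:], "mask must be (B, H, W)"
    values = set(np.unique(mask).tolist())
    assert values <= set(range(num_classes)), f"unexpected mask values {values}"
    return image.shape[1]


channels = check_batch(np.zeros((50, 11, 32, 32)), np.ones((50, 32, 32), dtype=np.int64))
print(channels)  # 11
```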

In [None]:
# We use the U-Net from torchseg, so we do not need to build the architecture from scratch.
# encoder_weights = False: the ResNet-18 encoder starts from random weights (it is not pretrained).
model = torchseg.Unet(
    encoder_name = "resnet18",
    encoder_weights = False,
    in_channels = 11,  # image bands + appended NDVI; must match the training batches
    classes = 2        # tea and not tea
)
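A ResNet encoder in a U-Net typically halves the spatial size five times (an encoder depth of 5, which we assume is torchseg's default), so the input height and width should be divisible by 2^5 = 32 for the decoder's skip connections to line up. Our 32 x 32 chips shrink to a single pixel at the bottleneck, and the same constraint is why the prediction cell pads the full image to a multiple of 32. A small check with a hypothetical `encoder_sizes` helper:

```python
def encoder_sizes(size, depth=5):
    """Spatial size after each of `depth` stride-2 downsamplings (hypothetical helper)."""
    sizes = [size]
    for _ in range(depth):
        if sizes[-1] % 2:
            raise ValueError(f"{sizes[-1]} px cannot be halved exactly")
        sizes.append(sizes[-1] // 2)
    return sizes


print(encoder_sizes(32))  # [32, 16, 8, 4, 2, 1]
```

A size such as 100 fails at the third halving (100, 50, 25), which is the kind of mismatch that padding avoids.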

In [None]:
# Use cuda if available, otherwise cpu

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

In [None]:
# Define the loss function and optimizer

criterion = nn.CrossEntropyLoss(ignore_index = -1)
optimizer = torch.optim.Adam(model.parameters(), lr = 0.001)
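For segmentation, `nn.CrossEntropyLoss` treats every pixel as a small classification problem: the model's two raw scores per pixel are turned into log-probabilities with a softmax, the log-probability of the true class is negated, and the result is averaged over all pixels whose label is not `ignore_index`. Since our labels are only 0 and 1, `ignore_index = -1` would only matter if unlabeled pixels were coded -1. A NumPy sketch of the computation (`pixel_cross_entropy` is hypothetical):

```python
import numpy as np


def pixel_cross_entropy(logits, target, ignore_index=-1):
    """Mean per-pixel cross-entropy, mimicking nn.CrossEntropyLoss for segmentation.

    logits: (B, K, H, W) raw scores; target: (B, H, W) integer classes.
    Pixels equal to ignore_index are left out of the mean.
    """
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    valid = target != ignore_index
    safe_target = np.where(valid, target, 0)  # placeholder index for ignored pixels
    picked = np.take_along_axis(log_probs, safe_target[:, None], axis=1)[:, 0]
    return -picked[valid].mean()


logits = np.zeros((1, 2, 2, 2))          # equal scores for both classes everywhere
target = np.array([[[0, 1], [1, -1]]])   # last pixel is ignored
print(pixel_cross_entropy(logits, target))  # log(2), about 0.693
```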

In [None]:
# Containers for the training metrics, and the number of epochs
metrics = {"loss": [], "accuracy": []}
num_epochs = 20

In [None]:
for epoch in range(num_epochs):
  model.train()
  epoch_loss = 0.0
  total_correct = 0
  total_pixels = 0

  with tqdm(dataloader, desc = f"Epoch {epoch + 1} / {num_epochs}") as pbar:
    for batch in pbar:
      images = batch["image"][:, :12, :, :].to(device)
      masks = batch["mask"].to(device)

      # Forward pass
      outputs = model(images)
      loss = criterion(outputs, masks.long())
      epoch_loss += loss.item()

      # Backpropagation
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      # Calculate accuracy
      preds = torch.argmax(outputs, dim = 1)
      total_correct += (preds == masks).sum().item()
      total_pixels += masks.numel()

      pbar.set_postfix(loss = loss.item())

  epoch_accuracy = total_correct / total_pixels * 100
  metrics["loss"].append(epoch_loss)
  metrics["accuracy"].append(epoch_accuracy)

  print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%")

  # Save training metrics
  with open("/content/training_metrics.json", "w") as f:
    json.dump(metrics, f)

  print("Training metrics saved to '/content/training_metrics.json'")


In [None]:
# Load training metrics from JSON file
with open("/content/training_metrics.json", "r") as f:
    metrics = json.load(f)

# Extract loss and accuracy values
loss_values = metrics["loss"]
accuracy_values = metrics["accuracy"]
epochs = range(1, len(loss_values) + 1)

# Create the plots
fig, ax1 = plt.subplots()

# Plot loss
ax1.set_xlabel("Epochs")
ax1.set_ylabel("Loss", color="tab:red")
ax1.plot(epochs, loss_values, label="Loss", color="tab:red")
ax1.tick_params(axis="y", labelcolor="tab:red")

# Create a second y-axis for accuracy
ax2 = ax1.twinx()
ax2.set_ylabel("Accuracy (%)", color="tab:blue")
ax2.plot(epochs, accuracy_values, label="Accuracy", color="tab:blue")
ax2.tick_params(axis="y", labelcolor="tab:blue")

fig.tight_layout()
plt.title("Training Loss & Accuracy Over Epochs")
plt.show()

In [None]:
# Put the trained model in evaluation mode (batch norm uses its running statistics)
model.eval()

# Path to your input raster
raster_path = "/content/s2_d.tif"
output_path = "/content/predicted_output.tif"

import numpy as np  # used by np.pad below

# Read the raster file
with rasterio.open(raster_path) as src:
    image = src.read()  # Shape: (bands, height, width)
    transform = src.transform
    crs = src.crs
    profile = src.profile

# Get original height and width
orig_h, orig_w = image.shape[1], image.shape[2]

# Compute padding needed
pad_h = (32 - (orig_h % 32)) % 32  # Ensure divisibility by 32
pad_w = (32 - (orig_w % 32)) % 32

# Pad image (bottom, right)
image_padded = np.pad(image, ((0, 0), (0, pad_h), (0, pad_w)), mode='reflect')

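# Convert to tensor, move to device and append NDVI exactly as during training
image_tensor = torch.tensor(image_padded, dtype=torch.float32).unsqueeze(0).to(device)  # (1, bands, H, W)
image_tensor = append_ndvi(image_tensor)[:, :12, :, :]  # (1, bands + 1, H, W), same channels as in training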

# Run the model
with torch.no_grad():
    output = model(image_tensor)

# Convert predictions to class labels
pred_mask_padded = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()  # (H, W)

# Remove padding to match original shape
pred_mask = pred_mask_padded[:orig_h, :orig_w]

# Save the output raster
profile.update(dtype=rasterio.uint8, count=1, height=orig_h, width=orig_w, nodata=255)  # 0 is a valid class (not tea), so use an unused value as nodata

with rasterio.open(output_path, "w", **profile) as dst:
    dst.write(pred_mask.astype(rasterio.uint8), 1)

print(f"Prediction saved at {output_path}")
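The padding step rounds the image size up to the next multiple of 32 so it fits the U-Net, and the prediction is cropped back afterwards. The formula `(32 - n % 32) % 32` gives, for example, 24 extra rows for 1,000 rows and 0 for a dimension that is already divisible (the outer `% 32` handles that case). Reflect padding mirrors the border pixels rather than filling with zeros. A standalone NumPy sketch with a hypothetical `pad_to_multiple` helper and random toy data:

```python
import numpy as np


def pad_to_multiple(image, multiple=32):
    """Reflect-pad a (bands, H, W) array at the bottom/right so H and W are multiples of `multiple`."""
    h, w = image.shape[1:]
    pad_h = (multiple - h % multiple) % multiple  # 0 when h is already divisible
    pad_w = (multiple - w % multiple) % multiple
    padded = np.pad(image, ((0, 0), (0, pad_h), (0, pad_w)), mode="reflect")
    return padded, (h, w)


toy_image = np.random.default_rng(0).random((3, 1000, 1003)).astype(np.float32)
padded, (h, w) = pad_to_multiple(toy_image)
print(padded.shape)                                  # (3, 1024, 1024)
print(np.array_equal(padded[:, :h, :w], toy_image))  # True: cropping undoes the padding
```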


In [None]:
# Path to the predicted output
output_path = "/content/predicted_output.tif"

# Read the predicted raster
with rasterio.open(output_path) as src:
    pred_mask = src.read(1)  # Read the first (and only) band

# Plot the predicted mask
plt.figure(figsize=(8, 6))
plt.imshow(pred_mask, cmap="gray")  # Use "gray" or "viridis" for better contrast
plt.colorbar(label="Class Label")
plt.title("Predicted Segmentation Mask")
plt.axis("off")  # Hide axis labels
plt.show()