## Libraries

In [29]:
# Standard library imports
import os
import sys
from pathlib import Path

# Third-party imports
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
import numpy as np
import buteo as beo
from tqdm import tqdm

# Local imports
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
from utils.constants import DATA_FOLDER

## Model Definition

The model takes as input an image containing the 9 spectral bands from Sentinel-2 (normalised and split into patches of 32x32). The output is the building density of each pixel, ranging from 0 to 100. Because we are predicting continuous values rather than classes, this is a regression task.

In this architecture, an encoder block increases the number of channels to 128. Because of the padding, the spatial resolution of the feature maps is unchanged. In the decoder block, further convolutions bring the number of channels back down to a single channel that contains the building density.

To help the model, we clamp the output values to the range 0-100, since values outside this range can never be correct.

In [23]:
class SimpleConvNet(nn.Module):
  """
  Fully convolutional encoder/decoder that maps a multi-channel image to a
  single-channel image whose values are clamped to [output_min, output_max].

  Every convolution uses a 3x3 kernel with padding 1, so the spatial size of
  the output equals the spatial size of the input.

  Args:
    input_channels: Number of channels in the input image.
    output_min: Lower bound applied to every output value.
    output_max: Upper bound applied to every output value.
  """
  def __init__(self, input_channels: int, output_min: float, output_max: float) -> None:
    super().__init__()
    self.output_min = output_min
    self.output_max = output_max

    # Encoder (no bottleneck): input_channels -> 64 -> 128
    encoder_layers = [
      *self._conv_relu_bn(input_channels, 64),
      *self._conv_relu_bn(64, 128),
    ]
    self.encoder = nn.Sequential(*encoder_layers)

    # Decoder: 128 -> 64 -> 1; the final convolution has no activation or norm
    decoder_layers = [
      *self._conv_relu_bn(128, 64),
      nn.Conv2d(64, 1, kernel_size=3, padding=1),
    ]
    self.decoder = nn.Sequential(*decoder_layers)

  @staticmethod
  def _conv_relu_bn(in_channels: int, out_channels: int) -> list:
    """
    Build one 3x3 convolution -> ReLU -> BatchNorm stage as a list of layers.
    """
    return [
      nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
      nn.ReLU(inplace=True),
      nn.BatchNorm2d(out_channels),
    ]

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Run the encoder and decoder, then clamp the result.

    Args:
      x: Input tensor with `input_channels` channels on dimension 1.

    Returns:
      A single-channel tensor with values in [output_min, output_max].
    """
    decoded = self.decoder(self.encoder(x))

    # Restrict predictions to the configured output range
    return torch.clamp(decoded, self.output_min, self.output_max)

## Initialise the Model

In [24]:
input_channels = 9 # Sentinel 2 initially.

# Seed PyTorch's global RNG before the model is built so that weight
# initialisation (and later DataLoader shuffling) is repeatable across runs.
SEED = 42
torch.manual_seed(SEED)

# Initialise the model
model = SimpleConvNet(input_channels, 0.0, 100.0) # Since we know the labels will always be [0.0, 100.0]

# Constants for the model
EPOCHS = 10
BATCH_SIZE = 16
LEARNING_RATE = 0.001

In [25]:
class NumpyDataset(Dataset):
  """
  A simple dataset class that wraps a pair of numpy arrays as float tensors.

  Args:
    x_train: Input samples, indexed by sample along the first axis.
    y_train: Targets, indexed by sample along the first axis.
    data_is_channel_last: When True, both arrays are passed through
      `beo.channel_last_to_first` before being converted to tensors.

  Raises:
    ValueError: If `x_train` and `y_train` hold a different number of samples.
  """
  def __init__(self, x_train: np.ndarray, y_train: np.ndarray, data_is_channel_last: bool = False) -> None:
    if data_is_channel_last:
      x_train = beo.channel_last_to_first(x_train)
      y_train = beo.channel_last_to_first(y_train)

    self.x_train = torch.from_numpy(x_train).float()
    self.y_train = torch.from_numpy(y_train).float()

    # __len__ is derived from x_train only, so a mismatch would otherwise
    # surface late as an IndexError or silently ignore trailing targets.
    if len(self.x_train) != len(self.y_train):
      raise ValueError(
        f"x_train and y_train must contain the same number of samples, "
        f"got {len(self.x_train)} and {len(self.y_train)}"
      )

  def __len__(self) -> int:
    """
    Returns the number of samples in the dataset.
    """
    return len(self.x_train)

  def __getitem__(self, idx: int) -> tuple:
    """
    Returns the (input, target) pair stored at position `idx`.
    """
    return self.x_train[idx], self.y_train[idx]



## Load the Data

In [26]:
# Open the archive once and read both arrays from it; the context manager
# closes the underlying file handle as soon as the arrays are in memory.
with np.load(os.path.join(DATA_FOLDER, 'train.npz')) as train_npz:
    x_train = train_npz['x_s2'] # Initially we only load the S2 data
    y_train = train_npz['y']

# Prepare the data for pytorch
def callback(x: np.ndarray, y: np.ndarray) -> tuple:
    """
    Convert a pair of numpy arrays into float32 PyTorch tensors.

    Returns:
        A `(x, y)` tuple holding the two converted tensors, in that order.
    """
    x_tensor, y_tensor = (torch.from_numpy(array).float() for array in (x, y))
    return x_tensor, y_tensor

### Create Dataset and DataLoader

In [None]:
dataset = NumpyDataset(x_train, y_train, data_is_channel_last=True)

# Pinned host memory only speeds up host-to-GPU copies, so request it only
# when CUDA is available; CPU-only runs then avoid the pin_memory warning.
dataloader = DataLoader(
  dataset,
  batch_size=BATCH_SIZE,
  shuffle=True,
  pin_memory=torch.cuda.is_available(),
  num_workers=0,
)

In [None]:
# Diagnostic only: report Apple-silicon (MPS) backend support.
# `torch` is already imported in the imports cell at the top of the notebook.
print(torch.backends.mps.is_available())  # Check if MPS is available
print(torch.backends.mps.is_built())      # Check if PyTorch was built with MPS support

## Train the Model

In [28]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Define the loss function and optimizer
criterion = nn.MSELoss()
optimiser = Adam(model.parameters(), lr=LEARNING_RATE)

# Put the model in training mode explicitly so BatchNorm uses batch statistics
# even if an earlier cell left it in eval mode.
model.train()

# Training loop
for epoch in range(EPOCHS):
  running_loss = 0.0

  # Initialise the progress bar for training; the epoch label is fixed for the
  # whole epoch, so it is set once here instead of on every batch.
  train_pbar = tqdm(
    dataloader,
    total=len(dataloader),
    ncols=120,
    desc=f"Epoch [{epoch + 1}/{EPOCHS}]",
  )

  for i, (inputs, targets) in enumerate(train_pbar):
    # Move inputs and targets to the device
    inputs, targets = inputs.to(device), targets.to(device)

    # Zero the gradients
    optimiser.zero_grad()

    # Forward pass
    outputs = model(inputs)

    # Compute the loss
    loss = criterion(outputs, targets)

    # Backward pass and optimization
    loss.backward()
    optimiser.step()

    # Track the running mean of the per-batch loss for this epoch
    current_loss = loss.item()
    running_loss += current_loss
    mean_loss = running_loss / (i + 1)

    # Update the progress bar (four decimal places)
    print_dict = { 'loss': f'{mean_loss:.4f}'}
    train_pbar.set_postfix(print_dict)

Epoch [1/10]: 100%|████████████████████████████████████████████████| 1966/1966 [05:56<00:00,  5.52it/s, loss=184.975243]
Epoch [2/10]: 100%|████████████████████████████████████████████████| 1966/1966 [05:50<00:00,  5.62it/s, loss=157.885668]
Epoch [3/10]: 100%|████████████████████████████████████████████████| 1966/1966 [05:41<00:00,  5.76it/s, loss=146.522431]
Epoch [4/10]: 100%|████████████████████████████████████████████████| 1966/1966 [04:55<00:00,  6.65it/s, loss=137.826457]
Epoch [5/10]: 100%|████████████████████████████████████████████████| 1966/1966 [05:00<00:00,  6.55it/s, loss=133.220347]
Epoch [6/10]: 100%|████████████████████████████████████████████████| 1966/1966 [05:17<00:00,  6.20it/s, loss=128.514122]
Epoch [7/10]: 100%|████████████████████████████████████████████████| 1966/1966 [05:14<00:00,  6.24it/s, loss=126.024043]
Epoch [8/10]: 100%|████████████████████████████████████████████████| 1966/1966 [05:17<00:00,  6.18it/s, loss=123.199071]
Epoch [9/10]: 100%|█████████████

In [None]:
# Persist the trained weights to disk
model_folder = '../models'
weights_path = os.path.join(model_folder, 'model_01.pth')

# Create the target directory (and any missing parents) before writing
Path(model_folder).mkdir(parents=True, exist_ok=True)

# Only the state dict (parameters and buffers) is saved, not the module object
torch.save(model.state_dict(), weights_path)

# Drop the notebook's references to the training data
del dataset
del dataloader