# Tree Species Classification in Orthoimages of Brandenburg

In [None]:
import numpy as np
import pandas as pd
import os
from tqdm import tqdm_notebook as tqdm
from sklearn.preprocessing import LabelEncoder
from PIL import Image
import math
import matplotlib.pyplot as plt
import rasterio
import sys
from pathlib import Path
from typing import List, Dict
import re

import torch
from torch import permute, nan_to_num
from torch.utils.data import Dataset
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import Dataset
import torchvision
from torchvision.transforms import ToTensor, Resize
import torchvision.transforms as transforms
import torchvision.models as models

from trees_brandenburg import preprocess, plotting
from trees_brandenburg.modelling import train, inference
from trees_brandenburg.external import transfer

## Data Overview

The goal of this notebook is to develop a CNN which can classify different tree species in orthoimages. The data used are orthoimages from Brandenburg from different years and seasons. For more detail about the images, see [here](https://geobroker.geobasis-bb.de/gbss.php?MODE=GetProductInformation&PRODUCTID=253b7d3d-6b42-47dc-b127-682de078b7ae).
All images with a resolution of 20cm were downloaded and further processed into final image tiles of 100px by 100px.

The table below shows the number of training samples per tree species. The most dominant species by far is *Pinus sylvestris*, followed by *Alnus rubra* and *Quercus robur*. The dataset contains 33 different species. In a first effort, however, we're only trying to classify the five most prominent ones.

To make our lives easier, the raw data (`data/raw`) is renamed and moved into a different folder (`data/processed`).

| **Species** | **Number of training samples** |
|:-----------:|--------------------------------|
| GKI         | 201633                         |
| RER         | 5876                           |
| SEI         | 2834                           |
| GBI         | 2482                           |
| TEI         | 2169                           |
| GDG         | 1893                           |
| RBU         | 1741                           |
| ELA         | 1524                           |
| REI         | 1305                           |
| GFI         | 1186                           |
| PAS         | 599                            |
| RO          | 435                            |
| EI          | 315                            |
| BPA         | 217                            |
| BAH         | 191                            |
| JLA         | 169                            |
| WKI         | 145                            |
| WEB         | 114                            |
| SKI         | 101                            |
| WLI         | 86                             |
| KTA         | 78                             |
| SAH         | 71                             |
| AS          | 59                             |
| HBU         | 58                             |
| WLS         | 34                             |
| HLS         | 31                             |
| WER         | 26                             |
| GES         | 25                             |
| STK         | 22                             |
| BFI         | 17                             |
| SFI         | 13                             |
| EIS         | 13                             |
| HPA         | 12                             |

In [None]:
# Restrict the experiment to the five most frequent species codes from the table above.
class_subset: List[str] = ["GKI", "RER", "SEI", "GBI", "TEI"]
src: Path = Path("../data/raw")
subset: Path = Path("../data/processed/imgs")
# Builds the processed image folder for the selected classes from the raw data
# (presumably renames/moves the files -- see the project's `preprocess` module).
preprocess.generate_subset(src, subset, class_subset)

metadata = pd.read_csv("../data/processed/sensing-dates.csv")
# `DataFrame.query` returns a new frame instead of filtering in place, so the
# result has to be assigned; otherwise every month stays in `metadata` and the
# later tile-based filtering never applies the April-October restriction.
metadata = metadata.query("month >= 4 and month <= 10")

# One row per image; the head is shown as the cell output.
data = preprocess.generate_data_overview(subset)
data.head()

Next, we exclude all images that were not taken between April and October. We can see that quite a substantial number of images are removed this way.

In [None]:
# Keep only the images whose tile appears in the sensing-date metadata and
# report the row count on either side of the filter.
print(f"Rows before filtering: {data.shape[0]}")
valid_tiles = metadata.tile
data = preprocess.filter_image_df(data, valid_tiles)
print(f"Rows after filtering: {data.shape[0]}")
data.head()

While this further reduces the training data for now, we set apart 5% of the available data for testing at the end. When more training data is available, this share should be increased.

In [None]:
# Set 5% of the rows aside as a hold-out set that is never seen during training.
# A fixed `random_state` makes the split reproducible across notebook runs
# (42 matches the `random_seed` used for the train/validation split further down).
hold_out: pd.DataFrame = data.sample(frac=0.05, random_state=42)
data = data.drop(hold_out.index)
# `reset_index()` without `drop=True` keeps the previous index as an "index" column.
hold_out = hold_out.reset_index()
data = data.reset_index()
print("Final data:")
print(f"\t{data.shape[0]} images can be used for training/fine tuning")
print(f"\t{hold_out.shape[0]} will be used for final validation")

## Setup Training

Define key-parameters for the training of the deep learning model. This is required early in the code as for example batch size is needed to define the DataLoader in the correct way.


In [None]:
# Hyper-parameters shared by the data loaders and the training loops below.
random_seed: int = 42            # seed for shuffling the train/validation indices
shuffle_dataset: bool = True     # shuffle indices before splitting
validation_split: float = 0.3    # fraction of the rows used for validation
batch_size: int = 1024
n_epochs: int = 20

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device: torch.device = torch.device("cuda:0")
else:
    device: torch.device = torch.device("cpu")

Prepare the data split and create data loaders that will allow you to load the data into the model training process.

In [None]:
# Split the row positions of `data` into a validation part (the first `split`
# positions) and a training part (the remainder), optionally after a seeded shuffle.
dataset_size: int = len(data)
split: int = int(math.floor(validation_split * dataset_size))
indices: List[int] = list(range(dataset_size))
if shuffle_dataset:
    # Seeding NumPy's global RNG makes the shuffled order repeatable.
    np.random.seed(random_seed)
    np.random.shuffle(indices)
validation_indices = indices[:split]
training_indices = indices[split:]
training_sampler = SubsetRandomSampler(training_indices)
validation_sampler = SubsetRandomSampler(validation_indices)

Define data transformations and create a Pytorch dataset class instance. Do the transforms make sense when used for the fine-tuning procedure below? Actually not sure...

In [None]:
# Per-channel normalisation (x - 0.5) / 0.5 for three channels.
# NOTE(review): assumes the dataset yields float tensors scaled to [0, 1] -- confirm
# against `train.TreeSpeciesClassificationDataset`.
transform = transforms.Compose([transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# The stray trailing `<` that made this line a syntax error has been removed.
dataset = train.TreeSpeciesClassificationDataset(data, transform)

Create a Dataloader for the training dataset.

In [None]:
# Both loaders read from the same dataset; the samplers decide which rows each
# loader draws. Pinned memory is only requested when running on a CUDA device.
loader_options = dict(
    batch_size=batch_size,
    pin_memory="cuda" in device.type,
    num_workers=10,
)
train_loader = torch.utils.data.DataLoader(dataset, sampler=training_sampler, **loader_options)
validation_loader = torch.utils.data.DataLoader(dataset, sampler=validation_sampler, **loader_options)

We do some plotting

In [None]:
# Map each encoded class id back to its species code.
# The previous `drop_duplicates(ignore_index=True).to_dict()["labels"]` keyed the
# dict by row position (order of first appearance), which only coincides with
# `encoded_labels` by accident; key explicitly by the encoded label instead.
label_pairs = data[["encoded_labels", "labels"]].drop_duplicates()
reverse_labels: Dict[int, str] = {
    int(code): name
    for code, name in zip(label_pairs["encoded_labels"], label_pairs["labels"])
}
fig, ax = plotting.plot_images(train_loader, reverse_labels, figsize=(15, 10))

Instantiate the model and set up our optimizer. Here, we also adjust the weights (not fully understood by me) to work with imbalanced classes.

In [None]:
cnn_model = train.CNN()

# Number of training rows per encoded class id, indexed by that id.
# Start from zeros (not `np.empty`) so a class that is absent from `data` has a
# well-defined count instead of uninitialised memory.
label_counts = data.encoded_labels.value_counts()
num_samples = np.zeros((len(class_subset),))
num_samples[label_counts.index.values] = label_counts.values

# Inverse-frequency class weights: rare classes contribute more to the loss, which
# counteracts the strong class imbalance. Classes without samples get weight 0
# rather than an infinite weight.
weights = np.divide(1.0, num_samples, out=np.zeros_like(num_samples), where=num_samples > 0)
normalized_weights = weights / np.sum(weights) # norm to 1

criterion = nn.CrossEntropyLoss(weight=torch.from_numpy(normalized_weights).float())
optimizer = optim.Adam(cnn_model.parameters(), lr=0.0001)

In [None]:
# Train `cnn_model` for `n_epochs`, validating after every epoch.
# Per-epoch metrics are collected in `train_loss`/`train_acc`/`val_loss`/`val_acc`
# for the plots further down.
# NOTE(review): neither the model nor the batches are moved to `device` in this
# loop, so training runs wherever `cnn_model` and the loader tensors already live.

# Log the batch loss every `print_every` training batches (the loop previously
# hard-coded 20 and ignored this variable; 20 is kept to preserve the output).
print_every = 20
valid_loss_min = np.inf  # best (lowest) mean validation loss seen so far
val_loss = []
val_acc = []
train_loss = []
train_acc = []
total_step = len(train_loader)

FINAL_MODEL_PATH: Path = Path("../data/processed") / 'model_scripted.pt'

for epoch in range(1, n_epochs+1):
    running_loss = 0.0
    correct = 0
    total = 0
    print(f'Epoch {epoch}\n')
    for batch_idx, (data_, target_) in enumerate(train_loader):
        optimizer.zero_grad() # zero the parameter gradients
        # forward + backward + optimize
        outputs = cnn_model(data_)
        loss = criterion(outputs, target_.long())
        loss.backward()
        optimizer.step()
        # accumulate statistics
        running_loss += loss.item()
        # presumably the number of correct predictions in the batch -- TODO confirm
        correct += train.accuracy(outputs, target_)
        total += target_.size(0)
        if batch_idx % print_every == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
                   .format(epoch, n_epochs, batch_idx, total_step, loss.item()))
    epoch_train_loss = running_loss / total_step
    epoch_train_acc = 100 * correct / total
    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    # Report this epoch's loss; the previous `np.mean(train_loss)` averaged over
    # all epochs so far and therefore lagged behind the actual training progress.
    print(f'\ntrain loss: {epoch_train_loss:.4f}, train acc: {epoch_train_acc:.4f}')

    batch_loss = 0
    total_t = 0
    correct_t = 0
    cnn_model.eval()
    with torch.inference_mode():
        for data_t, target_t in validation_loader:
            outputs_t = cnn_model(data_t)
            loss_t = criterion(outputs_t, target_t.long())
            batch_loss += loss_t.item()
            _, pred_t = torch.max(outputs_t, dim=1)
            correct_t += torch.sum(pred_t == target_t).item()
            total_t += target_t.size(0)
        epoch_val_loss = batch_loss / len(validation_loader)
        epoch_val_acc = 100 * correct_t / total_t
        val_acc.append(epoch_val_acc)
        val_loss.append(epoch_val_loss)
        # The number of validation batches is the same every epoch, so comparing the
        # mean loss is equivalent to comparing the summed loss and matches `val_loss`.
        network_learned = epoch_val_loss < valid_loss_min
        print(f'validation loss: {epoch_val_loss:.4f}, validation acc: {epoch_val_acc:.4f}\n')
        # Saving the best weights
        if network_learned:
            valid_loss_min = epoch_val_loss
            torch.save(cnn_model.state_dict(), Path("../data/processed") / 'model_classification_tutorial.pt')
            print('Detected network improvement, saving current model')
    cnn_model.train()

# NOTE: the scripted model holds the weights of the *last* epoch; the best-scoring
# weights are the state dict saved as 'model_classification_tutorial.pt' above.
model_scripted = torch.jit.script(cnn_model)
model_scripted.save(FINAL_MODEL_PATH)

## Post-Training

After we're done with the model training, let us plot both the training/validation loss and accuracy.

In [None]:
# Per-epoch training and validation loss in one figure.
fig = plt.figure(figsize=(20, 10))
for loss_curve, curve_name in ((train_loss, 'train'), (val_loss, 'validation')):
    plt.plot(loss_curve, label=curve_name)
plt.title("Train - Validation Loss")
plt.xlabel('num_epochs', fontsize=12)
plt.ylabel('loss', fontsize=12)
plt.legend(loc='best')

In [None]:
# Per-epoch training and validation accuracy in one figure.
fig = plt.figure(figsize=(20, 10))
for acc_curve, curve_name in ((train_acc, 'train'), (val_acc, 'validation')):
    plt.plot(acc_curve, label=curve_name)
plt.title("Train - Validation Accuracy")
plt.xlabel('num_epochs', fontsize=12)
plt.ylabel('accuracy', fontsize=12)
plt.legend(loc='best')

### Plotting Validation

Why exactly are we plotting the validation data here? Whatever...

In [None]:
# Reload the scripted CNN from disk and switch it to evaluation mode.
model = torch.jit.load(FINAL_MODEL_PATH)
model.eval()

# Plot predictions on the validation split: this section is about validation, and
# training samples (previously passed via `train_loader`) would show the model on
# data it has already fitted.
fig, ax = plotting.plot_validation(validation_loader, model, reverse_labels, figsize=(15, 10))

### Some first thoughts on the results

The model trained above is a lot of things, but not good. The drastically decreasing validation accuracy can likely be attributed to overfitting[^1]. The training data is not that large for a deep learning problem set, and in addition the data is not *clean*. For comparison, the MNIST dataset consists of 60000 training samples and 10000 validation samples for a much simpler problem set.

Multiple pathways may lead to improvements:

1. Simpler model architecture. The model would be able to learn less but could also be less prone to overfitting
1. Stronger regularization would prevent overfitting
1. Cleaner data
1. More data
1. Using a pre-trained model



Even though it's listed last above, I want to set up a baseline (i.e. *learn* and predict the majority class at all times) and try out an already pre-trained model.

## Base Line Prediction

**TODO:** The accuracy metric is dependent on class distribution and thus problematic! Nonetheless, it's used as an accuracy metric here.

[^1]: https://datascience.stackexchange.com/questions/47720/validation-loss-increases-and-validation-accuracy-decreases

In [None]:
# Accuracy of a classifier that always predicts the most frequent class.
majority_class = num_samples.argmax()
majority_hits = (data["encoded_labels"] == majority_class).sum()
baseline_accuracy = majority_hits / len(data)
print(baseline_accuracy)

## Pre-trained model

PyTorch offers a plethora of pre-trained models. There won't be an exhaustive search for which model performs best on the given dataset here and I simply use a ResNet.

A tutorial is provided [here](https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html#finetuning-the-convnet). Note that I don't employ an exponential learning rate decay and use the CrossEntropy loss and Adam optimizer in accordance with their original instantiations above.

> Something's not quite working. The GPU-utilization is rather low which indicates a bottleneck *somewhere*
> 
> either in the model itself (rather unlikely)
> or in the approach to load the data (more likely).

In [None]:
# Fine-tune an ImageNet-pretrained ResNet-18 on the tree species data.
res_weights = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
# Preprocessing pipeline shipped with the pretrained weights (used for the hold-out set below).
inference_preprocess = models.ResNet18_Weights.DEFAULT.transforms()

# Replace the final fully connected layer so the network outputs one logit per class.
original_features = res_weights.fc.in_features
res_weights.fc = nn.Linear(original_features, len(num_samples))
res_weights = res_weights.to(device)

# Same class-weighted loss and Adam settings as for the CNN above, but with the
# loss weights placed on `device`.
criterion = nn.CrossEntropyLoss(weight=torch.from_numpy(normalized_weights).float().to(device))
optimizer = optim.Adam(res_weights.parameters(), lr=0.0001)

dataloaders = {
    "train": train_loader,
    "val": validation_loader
}

dataset_sizes = {
    "train": len(training_sampler),
    "val": len(validation_sampler)
}

# The positional `None` is presumably the learning-rate scheduler slot (no decay is
# used) -- TODO confirm against `transfer.train_model`.
model_ft = transfer.train_model(res_weights, dataloaders, dataset_sizes, criterion, optimizer, device, None, num_epochs=3)

# Make sure the output directory exists so the save cannot fail after training.
model_dir: Path = Path("../models")
model_dir.mkdir(parents=True, exist_ok=True)
model_scripted = torch.jit.script(model_ft)
model_scripted.save(model_dir / "fine-tunned-resnet18.pt")


In [None]:
# Reload the fine-tuned, scripted ResNet-18 from disk.
# TODO access class predictions differently (see below or `torch.max(model_ft(images), 1).indices == labels`)
model_ft = torch.jit.load(Path("../models", "fine-tunned-resnet18.pt"))
#fig, ax = plotting.plot_ft_validation(validation_loader, model_ft, reverse_labels, figsize=(15, 10))

How does the newly-trained model perform with completely unseen data?
Note, that the original transformations of the ResNet are applied here.

In [None]:
# Dataset and loader for the hold-out rows, using the preprocessing that ships with
# the pretrained ResNet weights. Pinned memory is only requested on a CUDA device.
hold_out_dataset = train.TreeSpeciesClassificationDataset(hold_out, inference_preprocess)
hold_out_loader = torch.utils.data.DataLoader(
    hold_out_dataset,
    batch_size=64,
    pin_memory="cuda" in device.type,
    num_workers=3,
)

In [None]:
# Overall accuracy of the fine-tuned model on the hold-out set.
# Evaluation mode makes layers such as batch norm use their inference behaviour.
model_ft.eval()
number_correct = torch.tensor(0, device=device)
with torch.inference_mode():
    for images, labels in hold_out_loader:
        # Use the configured device instead of a hard-coded `.cuda()`, which fails
        # on machines without a GPU.
        images, labels = images.to(device), labels.to(device)
        number_correct += torch.sum(torch.argmax(model_ft(images), 1) == labels)

print(number_correct / len(hold_out_dataset))

Hm, ok, so with just two epochs of fine-tuning, we achieve an overall accuracy of 27% - drastically lower compared to the training accuracy!

### Tuning the Hyperparameters

...

### Predicting into the area

...

- other interesting idea: include nir channel and ditch green or blue channel!
- What about weighted frequency or better yet, a frequency independent accuracy metric?!