<center><img src="https://drive.google.com/uc?export=view&id=1ygAs8EMNlIim2ypwmvQn9yN1LbY3hWHV" alt="Drawing"  width="30%"/></center>

# <center><strong>Data Visualization and Baseline replication</strong></center>
<br/>

<br/><center>This notebook lets you visualize the data used in the **FLAIR #2 challenge**.<br/>The code below works both with the toy dataset (subset) provided in the starting-kit alongside this notebook and with the full FLAIR #2 dataset, which is accessible after registering for the competition.</center> <br/> 
<center>**We also strongly advise you to read the data technical description provided in the datapaper.**</center>
<br/> <br/> 
  

<hr style="height:1.5px;border-width:0;color:red;background-color:red">    

# <font color='red'>PART-1: Data visualization with the toy dataset</font>

First, let's import relevant functions from the <font color='#D7881C'><em>data_display.py</em></font> file. 
<br/>

In [None]:
from pathlib import Path
import yaml
from os.path import join
import numpy as np
import sys

# Necessary to load from src
module_path = str(Path.cwd().parents[0])
if module_path not in sys.path:
    sys.path.append(module_path)

from data_display import (display_nomenclature,
                            display_samples, 
                            display_time_serie,
                            display_all_with_semantic_class, 
                            display_all)
from src.load_data import load_data

## <font color='#90c149'>Nomenclatures</font>

<br/><hr>

Next, we display the semantic land-cover classes used in the FLAIR #2 dataset. You will see that <font color='#90c149'>two nomenclatures are available</font>: 
<ul>
    <li>the <strong><font color='#90c149'>full nomenclature</font></strong> corresponds to the semantic classes used by experts in photo-interpretation to label the pixels of the ground-truth images.</li>
    <li>the <font color='#90c149'><b>main (baseline) nomenclature</b></font> is a simplified version of the full nomenclature. It merges into the class 'other' the classes that are either strongly under-represented or irrelevant to this challenge.</li>
</ul>        
See the associated datapaper for additional details on these nomenclatures.<br/><br/>

<font color='#90c149'>Note:</font> in the data exploration part, we employ the full nomenclature. For the second part related to the challenge baseline, the main nomenclature is used. <br/><hr><br/> 
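To make the relationship between the two nomenclatures concrete, here is a minimal NumPy sketch that remaps a full-nomenclature mask to a reduced one with a lookup table. The class codes are assumptions chosen for illustration only (full nomenclature assumed to use codes 1 to 19, codes 1 to 12 kept, every higher code merged into a hypothetical 'other' code 13); check the datapaper for the actual correspondence.

```python
import numpy as np

# Hypothetical codes for illustration: full nomenclature assumed to use 1..19,
# main nomenclature assumed to keep 1..12 and merge everything else into 13 ('other').
NUM_FULL_CLASSES = 19
OTHER_CODE = 13

# Lookup table indexed by full-nomenclature code (index 0 is unused here).
lut = np.arange(NUM_FULL_CLASSES + 1, dtype=np.uint8)
lut[OTHER_CODE:] = OTHER_CODE  # every code >= 13 becomes 'other'


def to_main_nomenclature(mask: np.ndarray) -> np.ndarray:
    """Remap a full-nomenclature mask (integer codes) to the reduced nomenclature."""
    return lut[mask]


full_mask = np.array([[1, 5, 13], [14, 19, 12]], dtype=np.uint8)
main_mask = to_main_nomenclature(full_mask)
print(main_mask.tolist())  # [[1, 5, 13], [13, 13, 12]]
```

A lookup table keeps the remapping a single vectorized indexing operation, whatever the number of classes.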

In [None]:
display_nomenclature()

## <font color='#90c149'>Data display</font>

<br/><hr>

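We start by creating lists containing the paths to the aerial input images (`images`) and supervision masks (`labels`), as well as to the Sentinel-2 time series, their cloud masks, the products used to retrieve acquisition dates, and the centroids locating each aerial image in its Sentinel-2 super area.<hr><br/>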

In [None]:
config_path = "../flair-2-config.yml" # Change to yours
with open(config_path, "r") as f:
    config = yaml.safe_load(f)

In [None]:
# Creation of the train, val and test dictionaries with the data file paths
d_train, d_val, d_test = load_data(config)

images = d_train["PATH_IMG"]
labels = d_train["PATH_LABELS"]
sentinel_images = d_train["PATH_SP_DATA"]
sentinel_masks = d_train["PATH_SP_MASKS"] # Cloud masks
sentinel_products = d_train["PATH_SP_DATES"] # Needed to get the dates of the sentinel images
centroids = d_train["SP_COORDS"] # Position of the aerial image in the sentinel super area


### Visualization

<br/><hr>

Let's display some random samples of image/mask (IMG/MSK) pairs. <font color='#90c149'>Re-run the cell below for a different image.</font> We also plot the Sentinel-2 super area, super patch and patch. Although the patch itself is not used in practice, it is shown to give an idea of what the Sentinel-2 data looks like. The red rectangle shows the extent of the aerial RGB image within the Sentinel-2 image. <hr><br/>

In [None]:
display_samples(images, labels, sentinel_images, centroids)
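The super patch is located in the Sentinel-2 super area through the `centroids` list. The following is a minimal sketch of that cropping step, using a hypothetical helper `crop_super_patch`. It assumes a time-series array of shape (T, C, H, W), a centroid given as (row, col) in pixels of the super area, and an illustrative window size of 40 pixels; the actual layout and size used by the baseline may differ.

```python
import numpy as np


def crop_super_patch(sp_data: np.ndarray, centroid, size: int) -> np.ndarray:
    """Crop a size x size window centred on `centroid` from a (T, C, H, W) array.

    Assumes `centroid` is (row, col) in pixel coordinates of the super area and
    that the window fits entirely inside it (no padding is performed).
    """
    row, col = (int(v) for v in centroid)
    half = size // 2
    top, left = row - half, col - half
    height, width = sp_data.shape[-2], sp_data.shape[-1]
    if top < 0 or left < 0 or top + size > height or left + size > width:
        raise ValueError("window falls outside the super area")
    # Slice the two spatial axes only, keeping all dates and bands.
    return sp_data[..., top:top + size, left:left + size]


# Toy time series: 5 dates, 10 bands, 100 x 100 pixels of super area.
series = np.random.rand(5, 10, 100, 100).astype(np.float32)
patch = crop_super_patch(series, centroid=(50, 60), size=40)
print(patch.shape)  # (5, 10, 40, 40)
```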

<br/><hr>
We can also plot a few images from the Sentinel-2 time series along with their acquisition dates. Dates with too much cloud coverage are filtered out.

<hr><br/>

In [None]:
display_time_serie(sentinel_images, sentinel_masks, sentinel_products, nb_samples=3)
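The cloud filtering mentioned above can be sketched as follows, with a hypothetical helper `clear_date_indices`. This is not the baseline's implementation: it assumes the cloud masks have shape (T, 2, H, W) and hold cloud and snow probabilities in [0, 100], and the two thresholds are illustrative values.

```python
import numpy as np


def clear_date_indices(cloud_masks: np.ndarray, prob_threshold: float = 50.0,
                       max_cover: float = 0.3) -> np.ndarray:
    """Return the indices of dates whose cloud/snow cover is at most `max_cover`.

    Assumes `cloud_masks` has shape (T, 2, H, W) with cloud and snow probabilities
    in [0, 100]; a pixel counts as covered if either probability exceeds
    `prob_threshold`.
    """
    covered = (cloud_masks > prob_threshold).any(axis=1)  # (T, H, W) booleans
    cover_fraction = covered.mean(axis=(1, 2))             # covered share per date
    return np.flatnonzero(cover_fraction <= max_cover)


masks = np.zeros((4, 2, 10, 10))
masks[1, 0] = 90        # date 1: fully cloudy
masks[3, 1, :5] = 80    # date 3: half covered by snow
clear_dates = clear_date_indices(masks)
print(clear_dates)  # [0 2]
```

The returned indices can then select the usable dates of the time series, e.g. `series[clear_dates]`.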

<br/><hr>

Next, let's have a closer look at a specific semantic class.<br/> By setting `semantic_class` to a class number (*e.g.*, `semantic_class=1` for building or `semantic_class=5` for water), we can visualize the images containing pixels of this class (using the full nomenclature).<br/>
<font color='#90c149'>Note:</font> for Colab users, this can take some time. <hr><br/>

In [None]:
display_all_with_semantic_class(images, labels, semantic_class=1)
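The selection behind this kind of display boils down to checking each mask for the requested class. Below is a minimal sketch on in-memory arrays; the helpers `indices_with_class` and `class_fraction` are hypothetical, and in practice each mask would first be read from its file (for instance with rasterio).

```python
import numpy as np


def indices_with_class(masks, semantic_class: int) -> list:
    """Return the indices of masks containing at least one pixel of `semantic_class`."""
    return [i for i, m in enumerate(masks) if np.any(m == semantic_class)]


def class_fraction(mask: np.ndarray, semantic_class: int) -> float:
    """Share of the mask's pixels that belong to `semantic_class`."""
    return float(np.mean(mask == semantic_class))


toy_masks = [np.array([[1, 1], [2, 5]]),
             np.array([[2, 2], [2, 2]]),
             np.array([[5, 1], [1, 1]])]
selected = indices_with_class(toy_masks, semantic_class=1)
fractions = [class_fraction(toy_masks[i], 1) for i in selected]
print(selected, fractions)  # [0, 2] [0.5, 0.75]
```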

<br/><hr> 

We can directly display all images (be sure to use the toy dataset!).<br/> <hr><br/>

In [None]:
display_all(images, labels)

<br><br>
<hr style="height:3px;border-width:0;color:red;background-color:red">   

# <center><font color='red'>PART-2: Baseline </font></center>

<br/><hr>

In this second part, we use the toy dataset to train a model similar to the FLAIR #2 baseline provided with the challenge.<br/> 
<font color='#90c149'>Note:</font> the presented pipeline can also be applied to the full dataset.

First, let's check whether GPU resources are available in our execution environment. If not, make sure to set `accelerator` to `'cpu'` in the configuration file.
<hr><br/>

In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0: print('No GPU found.')
else: print(gpu_info)
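The `!nvidia-smi` syntax above is IPython shell magic, so it only works inside Jupyter or IPython. Outside a notebook, a comparable check could use the standard `subprocess` module, as in this sketch with a hypothetical helper `get_gpu_info`. Note that a working `nvidia-smi` is only a proxy for a usable GPU; the value of `accelerator` derived here is an illustration, not a replacement for the configuration file.

```python
import subprocess


def get_gpu_info():
    """Return the output of nvidia-smi, or None if no NVIDIA driver/GPU is reachable."""
    try:
        result = subprocess.run(["nvidia-smi"], capture_output=True,
                                text=True, check=True)
    except (OSError, subprocess.CalledProcessError):
        # OSError covers a missing executable; CalledProcessError a failing driver.
        return None
    return result.stdout


gpu_info = get_gpu_info()
accelerator = "gpu" if gpu_info is not None else "cpu"
print(gpu_info if gpu_info is not None else "No GPU found.")
```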

<br/><hr>

The cell below imports the required libraries, classes and functions, including those in the <font color='#D7881C'><em>src</em></font> folder provided with this starting-kit. If you are running this notebook in a local environment, make sure all necessary libraries are installed (refer to the <font color='red'>README.md</font> file).

This baseline relies on <font color='#90c149'><em>pytorch-lightning</em></font>, a high-level Python framework built on top of PyTorch. It supports multi-GPU training, which significantly speeds up computation of the baseline on the full FLAIR #2 dataset. Training on a single GPU is also possible, as demonstrated in this notebook.

In this notebook, we also take advantage of the <font color='#90c149'><em>segmentation-models-pytorch</em></font> library, which provides a variety of segmentation architectures (*e.g.*, U-Net, PSPNet, ...) with pre-trained encoders.
<hr><br/>

In [None]:
import os
from pathlib import Path 
import sys

import numpy as np
import torch
import torch.nn as nn
import albumentations as A  # needed for the augmentation pipeline defined below

from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.callbacks.progress.tqdm_progress import TQDMProgressBar
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.utilities.distributed import rank_zero_only 

# Necessary to load from src
module_path = str(Path.cwd().parents[0])
if module_path not in sys.path:
    sys.path.append(module_path)

from src.backbones.txt_model import TimeTexture_flair
from src.datamodule import DataModule
from src.task_module import SegmentationTask
from src.utils_prints import print_config, print_metrics
from src.utils_dataset import read_config
from src.load_data import load_data
from src.prediction_writer import PredictionWriter
#from src.metrics import generate_miou

## <font color='#90c149'>Task and parameters</font>

<br/><hr>

The toy dataset is composed of $26$ aerial patches of $512 \times 512$ pixels with their corresponding semantic masks and $26$ Sentinel-2 patches, plus $12$ test patches. The full FLAIR #2 dataset contains $61,712$ aerial patches and $41,029$ Sentinel-2 acquisitions as training set, and $16,050$ aerial and $10,215$ satellite patches as test set.<br/><br/>

The next cell loads the configuration file, which defines <font color='#90c149'>the paths and hyper-parameters</font>. 


We recommend starting with the given default values to check that everything works (see the datapaper for the baseline hyper-parameters).
<hr><br/>

In [None]:
config_path = "../flair-2-config-notebooks.yml" # Change to yours
config =  read_config(config_path)

print_config(config)

## <font color='#90c149'>Dataloaders</font>

<br/><hr>

The following cell loads the data into a pytorch-lightning `DataModule`. It takes as input the dictionaries containing the data paths and the configuration.

We fix the global seed (Python `random`, NumPy and torch) with `seed_everything`.

<hr><br/>
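What fixing the seed buys can be illustrated with a simplified stand-in covering only Python's `random` and NumPy (the function `seed_python_and_numpy` is hypothetical). pytorch-lightning's `seed_everything` additionally seeds torch and, with `workers=True`, also takes care of seeding the dataloader workers.

```python
import random

import numpy as np


def seed_python_and_numpy(seed: int) -> None:
    """Simplified stand-in for seed_everything: seeds Python's random and NumPy only."""
    random.seed(seed)
    np.random.seed(seed)


seed_python_and_numpy(2022)
first = (random.random(), np.random.rand(3).tolist())
seed_python_and_numpy(2022)
second = (random.random(), np.random.rand(3).tolist())
print(first == second)  # True: same seed, same random draws
```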

In [None]:
out_dir = Path(config["out_folder"], config["out_model_name"])
out_dir.mkdir(parents=True, exist_ok=True)
seed_everything(2022, workers=True)

dict_train, dict_val, dict_test = load_data(config)

# Augmentation: random vertical/horizontal flips and 90-degree rotations
if config["use_augmentation"]:
    transform_set = A.Compose([A.VerticalFlip(p=0.5),
                               A.HorizontalFlip(p=0.5),
                               A.RandomRotate90(p=0.5)])
else:
    transform_set = None

dm = DataModule(
    dict_train = dict_train,
    dict_val = dict_val,
    dict_test = dict_test,
    config=config,
    drop_last = True)
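For segmentation, geometric augmentations must be applied identically to an image and its mask, otherwise the labels no longer line up with the pixels. The sketch below reproduces the idea of the flip/rotation pipeline in plain NumPy; it is not the albumentations implementation. The helper `random_flip_rotate` is hypothetical and assumes an (H, W, C) image and an (H, W) mask; the rotation factor `k` is drawn from 0 to 3 (0 meaning no rotation).

```python
import numpy as np


def random_flip_rotate(image, mask, rng, p=0.5):
    """Apply the same random flips and 90-degree rotation to an (H, W, C) image
    and its (H, W) mask, each transform being triggered with probability p."""
    if rng.random() < p:
        image, mask = image[::-1], mask[::-1]              # vertical flip
    if rng.random() < p:
        image, mask = image[:, ::-1], mask[:, ::-1]        # horizontal flip
    if rng.random() < p:
        k = int(rng.integers(0, 4))                        # 0, 90, 180 or 270 degrees
        image = np.rot90(image, k, axes=(0, 1))
        mask = np.rot90(mask, k, axes=(0, 1))
    return np.ascontiguousarray(image), np.ascontiguousarray(mask)


rng = np.random.default_rng(0)
mask = np.arange(16).reshape(4, 4)
image = np.stack([mask] * 3, axis=-1)   # each channel equals the mask
aug_image, aug_mask = random_flip_rotate(image, mask, rng, p=1.0)
print(np.array_equal(aug_image[..., 0], aug_mask))  # True: still aligned
```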

## <font color='#90c149'>Learning setup</font>

<br/><hr>

Next, we define our <font color='#90c149'>model, criterion, optimizer and callbacks</font>.

The `U-T&T` model has two branches that extract spatial and temporal information from the very high resolution aerial images and the high resolution satellite images, respectively. The two architectures that constitute these branches are: 
- `U-Net` (spatial/texture branch): for the aerial imagery patches, a U-Net architecture is adopted. Its encoder is a ResNet34 backbone whose weights are pre-trained on ImageNet, for a total of ≈ 24.4 M parameters. This is similar to the architecture used for the FLAIR #1 baselines.
- `U-TAE` (spatio-temporal branch): the spatial and temporal information supplied by the Sentinel-2 time series is explored with a U-TAE architecture. This U-Net based architecture includes a Temporal self-Attention Encoder (TAE) taking as input the lowest resolution features of the convolutional encoder and yielding a set of temporal attention masks further applied to all resolutions upon decoding.

The architecture also includes a fusion module, which takes as input the U-TAE embedding (last feature maps of the U-TAE decoder) and is applied to each stage of the U-Net branch. See the datapaper for more details on the fusion method.


If `use_metadata = True`, a custom multilayer perceptron (MLP) encoding the metadata is added to the U-Net.

As criterion, we use two `Cross Entropy` losses, one per branch, which are summed to obtain the final loss. Each criterion can be initialized with different per-class weights to give more or less importance to particular classes.

The pytorch-lightning module `SegmentationTask` organizes and manages the different loops and steps (*e.g.*, training, validation) that would otherwise be implemented manually in PyTorch.

Finally we define `callbacks` (save model checkpoints, stop if learning is stuck with a patience threshold and display progress) as well as a `logger` (tensorboard logs).

<hr><br/>

### Model
<font color='#90c149'>Note:</font> the next cell will trigger the download of ResNet34 with pre-trained weights (the default encoder of the U-Net architecture in segmentation-models-pytorch).

In [None]:
model = TimeTexture_flair(config)

### Loss

In [None]:
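# Each value of config['weights_aerial_satellite'] holds a pair of class weights:
# [weight for the aerial (U-Net) branch, weight for the satellite (U-TAE) branch].
class_weights = np.array(list(config['weights_aerial_satellite'].values()), dtype=np.float32)
weights_aer = torch.tensor(class_weights[:, 0])
weights_sat = torch.tensor(class_weights[:, 1])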
criterion_vhr = nn.CrossEntropyLoss(weight=weights_aer)
criterion_hr = nn.CrossEntropyLoss(weight=weights_sat)
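To make the weighting explicit, here is a NumPy sketch of a class-weighted cross-entropy with mean reduction, mirroring the definition of `torch.nn.CrossEntropyLoss(weight=...)`, whose mean divides by the sum of the weights of the target classes rather than by the pixel count. The helper `weighted_cross_entropy` is hypothetical, and for simplicity both branches share one target of the same resolution here, which need not be the case in the baseline.

```python
import numpy as np


def weighted_cross_entropy(logits, target, class_weights):
    """Class-weighted cross-entropy with weighted-mean reduction.

    logits: (N, C, H, W) raw scores; target: (N, H, W) integer class indices;
    class_weights: (C,) weights.
    """
    # Numerically stable log-softmax over the class axis.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Log-probability of the true class at every pixel: (N, H, W).
    true_log_probs = np.take_along_axis(log_probs, target[:, None], axis=1)[:, 0]
    w = class_weights[target]  # weight of the true class at every pixel
    return float(-(w * true_log_probs).sum() / w.sum())


rng = np.random.default_rng(0)
num_classes = 13
target = rng.integers(0, num_classes, size=(2, 8, 8))
weights = np.ones(num_classes)
aerial_logits = rng.normal(size=(2, num_classes, 8, 8))
sat_logits = rng.normal(size=(2, num_classes, 8, 8))
# Final loss = sum of the two branch losses.
total_loss = (weighted_cross_entropy(aerial_logits, target, weights)
              + weighted_cross_entropy(sat_logits, target, weights))
print(total_loss)
```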

### Optimizer

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])

### Pytorch lightning module

In [None]:
seg_module = SegmentationTask(
    model=model,
    num_classes=config["num_classes"],
    criterion=nn.ModuleList([criterion_vhr, criterion_hr]),
    optimizer=optimizer,
    config=config
)

### Callbacks

In [None]:
# Callbacks

ckpt_callback = ModelCheckpoint(
    monitor="val_loss",
    dirpath=os.path.join(out_dir,"checkpoints"),
    filename="ckpt-{epoch:02d}-{val_loss:.2f}"+'_'+config["out_model_name"],
    save_top_k=1,
    mode="min",
    save_weights_only=True, # can be changed accordingly
)

early_stop_callback = EarlyStopping(
    monitor="val_loss",
    min_delta=0.00,
    patience=30, # stop training if val_loss has not improved for 30 epochs
    mode="min",
)

prog_rate = TQDMProgressBar(refresh_rate=config["progress_rate"])

callbacks = [
    ckpt_callback, 
    early_stop_callback,
    prog_rate,
]
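The patience logic of `EarlyStopping` with `mode="min"` can be summarized by the following simplified pure-Python sketch. `EarlyStoppingSketch` is a hypothetical class written for illustration, not the pytorch-lightning implementation: an epoch counts as an improvement only if the monitored value beats the best one by more than `min_delta`, and training stops once `patience` consecutive epochs bring no improvement.

```python
class EarlyStoppingSketch:
    """Simplified patience logic for a monitored value that should decrease."""

    def __init__(self, patience: int, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.wait_count = 0

    def should_stop(self, current: float) -> bool:
        if current < self.best - self.min_delta:
            # Improvement: remember the new best value and reset the counter.
            self.best = current
            self.wait_count = 0
        else:
            self.wait_count += 1
        return self.wait_count >= self.patience


stopper = EarlyStoppingSketch(patience=2)
val_losses = [1.0, 0.8, 0.85, 0.9, 0.7]
stopped_at = next(epoch for epoch, loss in enumerate(val_losses)
                  if stopper.should_stop(loss))
print(stopped_at)  # 3: two epochs in a row without improving on 0.8
```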


### Loggers

In [None]:
logger = TensorBoardLogger(
    save_dir=out_dir,
    name=Path("tensorboard_logs"+'_'+config['out_model_name']).as_posix()
)

loggers = [
    logger
]

## <font color='#90c149'>Launch the training</font>

<br/><hr>

Defining a `Trainer` automates tasks such as enabling/disabling gradients, running the dataloaders or invoking the callbacks when needed.
<hr><br/>

In [None]:
# Instantiation of the Trainer
trainer = Trainer(
        accelerator=config["accelerator"],
        devices=config["gpus_per_node"],
        strategy=config["strategy"],
        num_nodes=config["num_nodes"],
        max_epochs=config["num_epochs"],
        num_sanity_val_steps=0,
        callbacks = callbacks,
        logger=loggers,
        enable_progress_bar = config["enable_progress_bar"],
    )

<br/><hr>

<font color='#90c149'>Let's launch the training.</font>
<br/><hr>

In [None]:
trainer.fit(seg_module, datamodule=dm)

## <font color='#90c149'>Check metrics on the validation dataset</font>

<br/><hr> 

To give an idea of the training results, we call `validate` on the trainer to print some metrics. <hr><br/>

In [None]:
trainer.validate(seg_module, datamodule=dm)

## <font color='#90c149'>Inference and predictions export</font>

<br/><hr>

For inference, we define a new callback, `PredictionWriter`, which is used to export the predictions on the test dataset.<br/><br/>
<font color='#90c149'>Note:</font> the callback exports the files with the mandatory output formatting (files named <font color='red'><b>PRED_{ID}.tif</b></font>, with datatype <font color='red'><b>uint8</b></font> and <font color='red'><b>LZW</b></font> compression), using Pillow.
Check the <font color='#D7881C'><em>src/prediction_writer.py</em></font> file for details.<br/><br/>

We instantiate a new `Trainer` with this callback and call `predict`.
<hr><br/>
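As an illustration of the required output format, the sketch below writes a class map as an LZW-compressed `uint8` TIFF with Pillow and reads it back. The helper `write_prediction` is hypothetical, not the code of `PredictionWriter`, and it assumes the input images are named `IMG_{ID}.tif` so that the ID can be reused in the prediction file name.

```python
import tempfile
from pathlib import Path

import numpy as np
from PIL import Image


def write_prediction(pred: np.ndarray, img_name: str, output_dir) -> Path:
    """Save a (H, W) class map as PRED_{ID}.tif (uint8, LZW compression).

    Assumes the input image is named IMG_{ID}.tif.
    """
    image_id = Path(img_name).stem.split("_")[-1]
    out_path = Path(output_dir) / f"PRED_{image_id}.tif"
    Image.fromarray(pred.astype(np.uint8)).save(out_path, compression="tiff_lzw")
    return out_path


with tempfile.TemporaryDirectory() as tmp_dir:
    pred = np.array([[0, 1], [2, 12]], dtype=np.uint8)
    out_path = write_prediction(pred, "IMG_012345.tif", tmp_dir)
    with Image.open(out_path) as im:
        compression = im.info.get("compression")
        restored = np.array(im)
print(out_path.name, compression)  # PRED_012345.tif tiff_lzw
```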

In [None]:
# Predict
writer_callback = PredictionWriter(        
    output_dir = os.path.join(out_dir, "predictions"+"_"+config["out_model_name"]),
    write_interval = "batch",
)

# Instantiation of the prediction Trainer
trainer = Trainer(
    accelerator = config["accelerator"],
    devices = config["gpus_per_node"],
    strategy = config["strategy"],
    num_nodes = config["num_nodes"],
    callbacks = [writer_callback],
    enable_progress_bar = config["enable_progress_bar"],
)

In [None]:
trainer.predict(seg_module, datamodule=dm)

@rank_zero_only
def print_finish():
    print('--  [FINISHED.]  --', f'output dir : {out_dir}', sep='\n')
print_finish()

## <font color='#90c149'>Visual checking of predictions</font>

<br/><hr>

<font color='#90c149'>For the test set, obviously, you do not have access to the masks.</font> Nevertheless, we can visually display some predictions alongside the RGB images.<br/><br/>

First, we create lists containing the paths to the test RGB images (`images_test`) and to the predicted semantic segmentation masks (`predictions`).<br/><br/>

We then display some random pairs of predictions together with their corresponding aerial RGB images.<br/><br/>

<font color='#90c149'><em>Note 1</em></font>: if you are using the toy dataset, don't expect accurate predictions. A set of $200$ training samples will give limited results.<br/> 
<font color='#90c149'><em>Note 2</em></font>: rasterio will raise a <em>NotGeoreferencedWarning</em> for the prediction files. This is expected: the predictions are written without the geographic information that rasterio looks for. This information is not needed to assess the model outputs, so the warning can safely be ignored.
<hr><br/>

In [None]:
from data_display import display_predictions, get_data_paths

images_test = dict_test["PATH_IMG"]
pred_dir = Path(out_dir, "predictions" + "_" + config["out_model_name"])
# Sort the prediction files by the numeric ID in their name (PRED_{ID}.tif)
predictions = sorted(
    list(get_data_paths(pred_dir, 'PRED*.tif')),
    key=lambda x: int(x.split('_')[-1][:-4]),
)

In [None]:
display_predictions(images_test, predictions, nb_samples=2)

## <font color='#90c149'>Metric calculation: mIoU</font>

<br/><hr>

As mentioned before, the masks of the test set are not available. However, the following cell shows the code used to calculate the metric over the test set and thus to rank the models. Again, the toy dataset contains $50$ test patches, while the full FLAIR #2 dataset contains $16,050$ test patches.<br/><br/>

The calculation of the mean Intersection-over-Union (`mIou`) is based on the confusion matrix $C$, which is computed for each test patch. These confusion matrices are then summed to obtain the confusion matrix of the whole test set. The per-class IoU, defined as the number of true positives divided by the sum of true positives, false positives and false negatives, is calculated from the summed confusion matrix as follows: <br/><br/>
    $$
    IoU_i = \frac{C_{i,i}}
    {C_{i,i} + \sum_{j \neq i}\left(C_{i,j} + C_{j,i} \right)} = \frac{TP}{TP+FP+FN}
    $$
<br>
The final `mIou` is then the average of the per-class IoUs. 


<font color='#90c149'><em>Note:</em></font> as the <font color='#90c149'><em>'other'</em></font> class is <font color='#90c149'>not well defined (void)</font>, its IoU is <font color='#90c149'>removed</font> and therefore does not contribute to the `mIou`. In other words, the remaining per-class IoUs (all except 'other') are averaged over 12 classes rather than 13 to obtain the final `mIou`.

<hr><br/>

In [None]:
#truth_msk = config['data']['path_labels_test']
#pred_msk  = os.path.join(out_dir, "predictions"+"_"+config["out_model_name"])
#mIou, ious = generate_miou(truth_msk, pred_msk)
#print_metrics(mIou, ious)
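Since `generate_miou` is not shown here, the following NumPy sketch illustrates the computation described above: per-patch confusion matrices are summed, per-class IoUs are computed as $TP/(TP+FP+FN)$, and the 'other' class is left out of the average. The helpers `confusion_matrix` and `miou_from_confusion` are hypothetical; the example uses 3 toy classes with class 2 playing the role of 'other', and classes absent from both truth and prediction get an IoU of 0 here, which is an assumption that the official script may handle differently.

```python
import numpy as np


def confusion_matrix(truth, pred, num_classes):
    """C[i, j] counts the pixels of true class i predicted as class j."""
    idx = truth.astype(np.int64).ravel() * num_classes + pred.astype(np.int64).ravel()
    return np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)


def miou_from_confusion(conf, ignore_class=None):
    """Per-class IoU = TP / (TP + FP + FN); mIoU averages them, skipping `ignore_class`."""
    tp = np.diag(conf).astype(np.float64)
    fn = conf.sum(axis=1) - tp  # pixels of class i predicted as another class
    fp = conf.sum(axis=0) - tp  # pixels predicted as class i belonging to another class
    denom = tp + fp + fn
    ious = np.divide(tp, denom, out=np.zeros_like(tp), where=denom > 0)
    keep = np.ones(len(ious), dtype=bool)
    if ignore_class is not None:
        keep[ignore_class] = False
    return float(ious[keep].mean()), ious


# Two toy "patches" with 3 classes; class 2 plays the role of 'other'.
truth_patches = [np.array([[0, 0], [1, 2]]), np.array([[1, 1], [0, 2]])]
pred_patches = [np.array([[0, 1], [1, 2]]), np.array([[1, 1], [0, 0]])]
total_conf = sum(confusion_matrix(t, p, 3) for t, p in zip(truth_patches, pred_patches))
miou, ious = miou_from_confusion(total_conf, ignore_class=2)
print(ious.tolist(), miou)  # [0.5, 0.75, 0.5] 0.625
```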

<br/><br/><br/><br/>

### <center><strong>For any feedback, request, suggestion or simply to say hi, we are reachable at: ai-challenge@ign.fr!</strong></center>
<br/>
<font size=2.5> <b>@IGN, May 2023</b></font>
<img src="https://drive.google.com/uc?export=view&id=14clxUsTGj7i6oXt6q9FQeaxzjIi3biI2" alt="Drawing"  width="100%"/>