In [1]:
#general 
import os
import numpy as np
import json
import random
from pathlib import Path 

#deep learning
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau

from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.callbacks.progress.tqdm_progress import TQDMProgressBar
from pytorch_lightning import Trainer, seed_everything
try:
  from pytorch_lightning.utilities.distributed import rank_zero_only
except ImportError:
  from pytorch_lightning.utilities.rank_zero import rank_zero_only  

import albumentations as A

#flair-one baseline modules 
from py_module.utils import load_data, subset_debug
from py_module.datamodule import OCS_DataModule
from py_module.model import SMP_Unet_meta
from py_module.task_module import SegmentationTask
from py_module.writer import PredictionWriter

##############################################################################################
# paths and naming
path_data = "./toy_dataset_flair-one/" # toy (or full) dataset folder
path_metadata_file = "./metadata/flair-one_TOY_metadata.json" # json file containing the metadata

out_folder = "/content/gdrive/MyDrive/models_output/" # output directory for logs and predictions.
out_model_name = "flair-one-baseline_argu" # to keep track
##############################################################################################

##############################################################################################
# tasking
use_weights = True 
class_weights = [1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,0.0]

use_metadata = False
use_augmentation = True
##############################################################################################

##############################################################################################
# training hyper-parameters
batch_size = 8
learning_rate = 0.02
num_epochs = 20
##############################################################################################

##############################################################################################
# computational ressources
accelerator = 'gpu' # set to 'cpu' if GPU not available
gpus_per_node = 1 # set to 1 if mono-GPU
num_nodes = 1 # set to 1 if mono-GPU
strategy = None # Put this parameter to None if train on only one GPU or on CPUs. If multiple GPU, set to 'ddp'
num_workers = 0
##############################################################################################

##############################################################################################
# display
enable_progress_bar = True
progress_rate = 10 #tqdm update rate during training 
##############################################################################################

out_dir = Path(out_folder, out_model_name)
out_dir.mkdir(parents=True, exist_ok=True)

seed_everything(2022, workers=True)

@rank_zero_only
def step_loading(path_data, path_metadata_file: str, use_metadata: bool) -> dict:
    print('+'+'-'*29+'+', '   LOADING DATA   ', '+'+'-'*29+'+')
    train, val, test = load_data(path_data, path_metadata_file, use_metadata=use_metadata)
    return train, val, test


@rank_zero_only
def print_recap():
    print('\n+'+'='*80+'+',f"{'Model name: '+out_model_name : ^80}", '+'+'='*80+'+', f"{'[---TASKING---]'}", sep='\n')
    for info, val in zip(["use weights", "use metadata", "use augmentation"], [use_weights, use_metadata, use_augmentation]): print(f"- {info:25s}: {'':3s}{val}")
    print('\n+'+'-'*80+'+', f"{'[---DATA SPLIT---]'}", sep='\n')
    for split_name, d in zip(["train", "val", "test"], [dict_train, dict_val, dict_test]): print(f"- {split_name:25s}: {'':3s}{len(d['IMG'])} samples")
    print('\n+'+'-'*80+'+', f"{'[---HYPER-PARAMETERS---]'}", sep='\n')
    for info, val in zip(["batch size", "learning rate", "epochs", "nodes", "GPU per nodes", "accelerator", "workers"], [batch_size, learning_rate, num_epochs, num_nodes, gpus_per_node, accelerator, num_workers]): print(f"- {info:25s}: {'':3s}{val}")        
    print('\n+'+'-'*80+'+', '\n')

dict_train, dict_val, dict_test = step_loading(path_data, path_metadata_file, use_metadata=use_metadata)  
print_recap()



if use_augmentation == True:
    transform_set = A.Compose([ 
                                A.VerticalFlip(p=0.5),
                                A.HorizontalFlip(p=0.5),
                                A.RandomRotate90(p=0.5)])
else:
    transform_set = None

dm = OCS_DataModule(
    dict_train=dict_train,
    dict_val=dict_val,
    dict_test=dict_test,
    batch_size=batch_size,
    num_workers=num_workers,
    drop_last=True,
    num_classes=13,
    num_channels=5,
    use_metadata=use_metadata,
    use_augmentations=transform_set)

  from .autonotebook import tqdm as notebook_tqdm
Global seed set to 2022
  rank_zero_deprecation(


+-----------------------------+    LOADING DATA    +-----------------------------+

                      Model name: flair-one-baseline_argu                       
[---TASKING---]
- use weights              :    True
- use metadata             :    False
- use augmentation         :    True

+--------------------------------------------------------------------------------+
[---DATA SPLIT---]
- train                    :    50697 samples
- val                      :    11015 samples
- test                     :    15700 samples

+--------------------------------------------------------------------------------+
[---HYPER-PARAMETERS---]
- batch size               :    8
- learning rate            :    0.02
- epochs                   :    20
- nodes                    :    1
- GPU per nodes            :    1
- accelerator              :    gpu
- workers                  :    0

+--------------------------------------------------------------------------------+ 



### Model
<font color='#90c149'>Note:</font> the next cell will trigger the download of ResNet34 (default for U-Net architecture in pytorch-lightning) with pre-trained weights.

In [2]:
model = SMP_Unet_meta(n_channels=5, n_classes=13, use_metadata=use_metadata)

### Loss

In [3]:
if use_weights == True:
    with torch.no_grad():
        class_weights = torch.FloatTensor(class_weights)
    criterion = nn.CrossEntropyLoss(weight=class_weights)
else:
    criterion = nn.CrossEntropyLoss()

### Optimizer

In [4]:
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

### Scheduler

In [5]:
scheduler = ReduceLROnPlateau(
    optimizer=optimizer,
    mode="min",
    factor=0.5,
    patience=10,
    cooldown=4,
    min_lr=1e-7,
)

### Pytorch lightning module

In [6]:
seg_module = SegmentationTask(
    model=model,
    num_classes=dm.num_classes,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    use_metadata=use_metadata
)

### Callbacks

In [7]:
ckpt_callback = ModelCheckpoint(
    monitor="val_loss",
    dirpath=os.path.join(out_dir,"checkpoints"),
    filename="ckpt-{epoch:02d}-{val_loss:.2f}"+'_'+out_model_name,
    save_top_k=1,
    mode="min",
    save_weights_only=True, # can be changed accordingly
)

early_stop_callback = EarlyStopping(
    monitor="val_loss",
    min_delta=0.00,
    patience=30, # if no improvement after 30 epoch, stop learning. 
    mode="min",
)

prog_rate = TQDMProgressBar(refresh_rate=progress_rate)

callbacks = [
    ckpt_callback, 
    early_stop_callback,
    prog_rate,
]

### Loggers

In [8]:
logger = TensorBoardLogger(
    save_dir=out_dir,
    name=Path("tensorboard_logs"+'_'+out_model_name).as_posix()
)

loggers = [
    logger
]

## <font color='#90c149'>Launch the training</font>

<br/><hr>

Defining a `Trainer` allows for to automate tasks, such as enabling/disabling grads, running the dataloaders or invoking the callbacks when needed.
<hr><br/>

In [9]:
#### instanciation of  Trainer
trainer = Trainer(
    accelerator=accelerator,
    devices=gpus_per_node,
    strategy=strategy,
    num_nodes=num_nodes,
    max_epochs=num_epochs,
    num_sanity_val_steps=0,
    callbacks = callbacks,
    logger=loggers,
    enable_progress_bar = enable_progress_bar,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


<br/><hr>

<font color='#90c149'>Let's launch the training.</font>
<br/><hr>

In [10]:
trainer.fit(seg_module, datamodule=dm)

You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name          | Type                   | Params
---------------------------------------------------------
0 | model         | SMP_Unet_meta          | 24.4 M
1 | criterion     | CrossEntropyLoss       | 0     
2 | train_metrics | MulticlassJaccardIndex | 0     
3 | val_metrics   | MulticlassJaccardIndex | 0     
4 | train_loss    | MeanMetric             | 0     
5 | val_loss      | MeanMetric             | 0     
---------------------------------------------------------
24.4 M    Trainable params
0         Non-trainable params
24.4 M    Total params
97.778    Total estimated mode

Epoch 0:   8%|▊         | 650/7713 [05:43<1:02:13,  1.89it/s, loss=1.31, v_num=1]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


## <font color='#90c149'>Check metrics on the validation dataset</font>

<br/><hr> 

To give an idea on the training results, we call validate on the trainer to print some metrics. <hr><br/>

In [28]:
trainer.validate(seg_module, datamodule=dm)

NameError: name 'trainer' is not defined

## <font color='#90c149'>Inference and predictions export</font>

<br/><hr>

For inference, we define a new callback, `PredictionWriter`, which is used to export the predictions on the test dataset.<br/><br/>
<font color='#90c149'>Note:</font> the callback exports the files with the mandotary formatting of outputs (files named <font color='red'><b> PRED_{ID].tif</b></font>, with datatype <font color='red'><b>uint8</b></font> and <font color='red'><b>LZW</b></font> compression), using Pillow.
Check the <font color='#D7881C'><em>writer.py</em></font> file for details.<br/><br/>

We instantiate a new `Trainer` with this newly defined callback and call predict.
<hr><br/>

In [29]:
writer_callback = PredictionWriter(        
    output_dir=os.path.join(out_dir, "predictions"+"_"+out_model_name),
    write_interval="batch",
)

#### instanciation of prediction Trainer
trainer = Trainer(
    accelerator=accelerator,
    devices=gpus_per_node,
    strategy=strategy,
    num_nodes=num_nodes,
    callbacks = [writer_callback],
    enable_progress_bar = enable_progress_bar,
)

NameError: name 'out_dir' is not defined

In [41]:
trainer.predict(seg_module, datamodule=dm)

@rank_zero_only
def print_finish():
    print('--  [FINISHED.]  --', f'output dir : {out_dir}', sep='\n')
print_finish()

Missing logger folder: C:\Users\marka\Documents\New folder\lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

--  [FINISHED.]  --
output dir : \content\gdrive\MyDrive\models_output\flair-one-baseline_argu


## <font color='#90c149'>Visual checking of predictions</font>

<br/><hr>

<font color='#90c149'>For the test set, obviously, you do not have access to the masks.</font> Nevertheless, we can visually display some predictions alongside the RGB images.<br/><br/>

First, we create lists containing the paths to the test RGB images (`images_test`) as well as the predicted semantic segmentation masks (`predictions`).<br/><br/>



We then display some random couples of predictions together with their corresponding aerial RGB images.<br/><br/>

<font color='#90c149'><em>Note 1</em></font>: if you are using the toy dataset, don't expect accurate predictions. A set of $200$ training samples will give limited results.<br/> 
<font color='#90c149'><em>Note 2</em></font>: rasterio will yield a <em>NotGeoreferencedWarning</em> regarding the predictions files. This is normal as the prediction files have been written without any geographical information, which is expected by rasterio. This kind of information is not important for assessing the model outputs, so we can just omit the warning.
<hr><br/>

In [42]:
from py_module.data_display import display_predictions, get_data_paths

images_test = sorted(list(get_data_paths(Path(path_data,'test'), 'IMG*.tif')), key=lambda x: int(x.split('_')[-1][:-4]))
predictions = sorted(list(get_data_paths(Path(out_dir, "predictions"+'_'+out_model_name), 'PRED*.tif')), key=lambda x: int(x.split('_')[-1][:-4]))
display_predictions(images_test, predictions, nb_samples=2)

## <font color='#90c149'>Metric calculation: mIoU</font>

<br/><hr>

As mentioned before, the masks of the test set are not available. However, the following cell describes the code that is used to calculate the metric used over the test set and to consequently rank the best models. Again, the toy dataset contains $50$ test pastches, while the full FLAIR-one dataset contains $15,700$ test patches.<br/><br/>

The calculation of the mean Intersection-over-Union (`mIou`) is based on the confusion matrix $C$, which is determined for each test patch. The confusion matrices are subsequently summed providing the confusion matrix describing the test set. Per-class IoU, defined as the ratio between true positives divided by the sum of false positives, false negatives and true positives is calculated from the summed confusion matrix as follows: <br/><br/>
    $$
    IoU_i = \frac{C_{i,i}}
    {C_{i,i} + \sum_{j \neq i}\left(C_{i,j} + C_{j,i} \right)} = \frac{TP}{TP+FP+FN}
    $$
<br>
The final `mIou` is then the average of the per-class IoUs. 


<font color='#90c149'><em>Note:</em></font> as the <font color='#90c149'><em>'other'</em></font> class is <font color='#90c149'>not well defined (void)</font>, its IoU is <font color='#90c149'>removed</font> and therefore does not contribute to the calculation of the `mIou`. In other words,  the remaining per-class IoUs (all except 'other') are averaged by 12 and not 13 to obtain the final `mIou`.</font>

<hr><br/>

In [44]:
import re
import numpy as np
from PIL import Image
from sklearn.metrics import confusion_matrix


def generate_miou(path_truth: str, path_pred: str) -> list:
  
    #################################################################################################
    def get_data_paths (path, filter):
        for path in Path(path).rglob(filter):
             yield path.resolve().as_posix()  
                
    def calc_miou(cm_array):
        m = np.nan
        with np.errstate(divide='ignore', invalid='ignore'):
            ious = np.diag(cm_array) / (cm_array.sum(0) + cm_array.sum(1) - np.diag(cm_array))
        m = np.nansum(ious[:-1]) / (np.logical_not(np.isnan(ious[:-1]))).sum()
        return m.astype(float), ious[:-1]      

    #################################################################################################
                       
    truth_images = sorted(list(get_data_paths(Path(path_truth), 'MSK*.tif')), key=lambda x: int(x.split('_')[-1][:-4]))
    preds_images  = sorted(list(get_data_paths(Path(path_pred), 'PRED*.tif')), key=lambda x: int(x.split('_')[-1][:-4]))
    if len(truth_images) != len(preds_images): 
        print('[WARNING !] mismatch number of predictions and test files.')
    if truth_images[0][-10:-4] != preds_images[0][-10:-4] or truth_images[-1][-10:-4] != preds_images[-1][-10:-4]: 
        print('[WARNING !] unsorted images and masks found ! Please check filenames.') 
        
    patch_confusion_matrices = []

    for u in range(len(truth_images)):
        target = np.array(Image.open(truth_images[u]))-1 # -1 as model predictions start at 0 and turth at 1.
        target[target>12]=12  ### remapping masks to reduced baseline nomenclature.
        preds = np.array(Image.open(preds_images[u]))         
        patch_confusion_matrices.append(confusion_matrix(target.flatten(), preds.flatten(), labels=list(range(13))))

    sum_confmat = np.sum(patch_confusion_matrices, axis=0)
    mIou, ious = calc_miou(sum_confmat) 

    return mIou, ious


#if name == "__main__":  
#    truth_msk = './reference/
#    pred_msk  = './predictions/'
#    mIou = generate_miou(truth_images, truth_msk)

In [45]:
truth_msk = './reference/'
pred_msk  = './predictions/'
mIou = generate_miou(truth_images, truth_msk)

NameError: name 'truth_images' is not defined