## Example of the `aitlas` toolbox in the context of image segmentation
---
```
Author: Ana Kostovska
Organisation: Bias Variance Labs
Website: https://www.bvlabs.ai/
Ljubljana, 2024
```
---

### Importing required packages

In [None]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
from aitlas.datasets import TiiLIDARDatasetSegmentation
from aitlas.models import HRNet
model_config = TiiLIDARDatasetSegmentation.get_fixed_model_config()

### Loading train, validation and test data

Input parameters for train, validation and test data:

- **batch_size**: The number of samples processed before the model is updated. A larger batch size can speed up processing but requires more memory.
- **num_workers**: The number of worker processes used to load the data. More workers can speed up data loading considerably, but they also increase memory and CPU usage.
- **object_class**: A parameter that specifies the type of archaeological object you are interested in processing, e.g., 'AO', 'barrow', 'enclosure', 'ringfort'.
- **object_class_band_id**: An integer parameter identifying the band where the annotations for a specific object class are located within the segmentation masks.
- **visualisation_type**: The visualisation type used for the patches, e.g., 'SLRM'.
- **DFM_quality**: A comma-separated string of the annotation qualities to include in the processed data, e.g., '1,2'.
- **keep_empty_patches**: A boolean parameter that controls if empty patches are kept. Set to False when training since "empty" data can't be used for training. For testing or validation, True can be used to check how the model handles empty patches.
- **shuffle**: Determines whether the data is shuffled before being processed.
- **data_dir**: The directory path where the input data is stored.
- **annotations_dir**: The directory path where the segmentation masks are stored.
- **transforms**: A list of transformations applied to the input data during processing.
- **target_transforms**: A list of transformations applied to the segmentation masks during processing.
- **joint_transforms**: Transformations applied simultaneously to both the input data and segmentation masks.
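The three dataset configurations used in this notebook differ only in their directories, the `shuffle` flag and the training-time augmentation. As a minimal sketch, the shared settings could be factored into a helper. `make_dataset_config` and the directory layout under `root` (`samples/<split>` and `labels/segmentation_masks/<split>`) are assumptions made for illustration and are not part of `aitlas`; the resulting dictionary would be passed to `TiiLIDARDatasetSegmentation` in the same way as the hand-written configurations.

```python
import os


def make_dataset_config(root, split, train=False, **overrides):
    """Build a dataset config dict for one split (hypothetical helper, not part of aitlas)."""
    config = {
        "batch_size": 16,
        "num_workers": 4,
        "object_class": "ringfort",
        "object_class_band_id": 2,
        "visualisation_type": "SLRM",
        "DFM_quality": "1,2",
        # Shuffle only the training split; keep validation/test order fixed.
        "shuffle": train,
        "keep_empty_patches": False,
        # Assumed layout: <root>/samples/<split> and
        # <root>/labels/segmentation_masks/<split>.
        "data_dir": os.path.join(root, "samples", split),
        "annotations_dir": os.path.join(root, "labels", "segmentation_masks", split),
        "transforms": ["aitlas.transforms.Transpose"],
        "target_transforms": ["aitlas.transforms.Transpose"],
    }
    if train:
        # Random flips/rotations act on image and mask together, training only.
        config["joint_transforms"] = ["aitlas.transforms.FlipHVRandomRotate"]
    # Per-split exceptions, e.g. keep_empty_patches=True for testing.
    config.update(overrides)
    return config


train_config = make_dataset_config("demo_data", "train", train=True)
test_config = make_dataset_config("demo_data", "test", keep_empty_patches=True)
```

Keeping the common values in one place makes it harder for the splits to drift apart, for example by changing `object_class` for training but not for testing.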

In [None]:
batch_size = 16
num_workers = 4
object_class = "ringfort"
object_class_band_id = 2
visualisation_type = "SLRM"

In [None]:
train_dataset_config = {
    "batch_size": batch_size,
    "num_workers": num_workers,
    "object_class": object_class,
    "object_class_band_id": object_class_band_id,
    "visualisation_type": visualisation_type,
    "DFM_quality": '1,2',
    "shuffle": True,
    "keep_empty_patches": False,
    "data_dir": "/Users/anakostovska/Dropbox/aitlas_v1/retrain_model/demo_data/samples/train",
    "annotations_dir": "/Users/anakostovska/Dropbox/aitlas_v1/retrain_model/demo_data/labels/segmentation_masks/train",
    "joint_transforms": ["aitlas.transforms.FlipHVRandomRotate"],
    "transforms": ["aitlas.transforms.Transpose"],
    "target_transforms": ["aitlas.transforms.Transpose"]
}
train_dataset = TiiLIDARDatasetSegmentation(train_dataset_config)

validation_dataset_config = {
    "batch_size": batch_size,
    "num_workers": num_workers,
    "object_class": object_class,
    "object_class_band_id": object_class_band_id,
    "visualisation_type": visualisation_type,
    "DFM_quality": '1,2',
    "shuffle": False,
    "keep_empty_patches": False,
    "data_dir": "/Users/anakostovska/Dropbox/aitlas_v1/retrain_model/demo_data/samples/validation",
    "annotations_dir": "/Users/anakostovska/Dropbox/aitlas_v1/retrain_model/demo_data/labels/segmentation_masks/validation",
    "transforms": ["aitlas.transforms.Transpose"],
    "target_transforms": ["aitlas.transforms.Transpose"]
}
validation_dataset = TiiLIDARDatasetSegmentation(validation_dataset_config)

test_dataset_config = {
    "batch_size": batch_size,
    "num_workers": num_workers,
    "object_class": object_class,
    "object_class_band_id": object_class_band_id,
    "visualisation_type": visualisation_type,
    "DFM_quality": '1,2',
    "shuffle": False,
    "keep_empty_patches": False,
    "data_dir": "/Users/anakostovska/Dropbox/aitlas_v1/retrain_model/demo_data/samples/test",
    "annotations_dir": "/Users/anakostovska/Dropbox/aitlas_v1/retrain_model/demo_data/labels/segmentation_masks/test",
    "transforms": ["aitlas.transforms.Transpose"],
    "target_transforms": ["aitlas.transforms.Transpose"]
}
test_dataset = TiiLIDARDatasetSegmentation(test_dataset_config)

len(train_dataset), len(validation_dataset), len(test_dataset)

### Model creation

In [None]:
model = HRNet(model_config)
model.prepare()

### Loading pretrained ADAF model (optional)

If you want to train from scratch, skip this step. To continue from an existing model instead, set `model_path` to its checkpoint file and run the cell to load the model into memory.

In [None]:
model_path = "/Users/anakostovska/Dropbox/aitlas_v1/inference/data/models/semantic_segmentation/ringfort_HRNet_SLRM_512px_pretrained_train_12_val_124_with_Transformation.tar" 
model.load_model(model_path)

### Training the model

Input parameters: 
- **epochs**: The total number of training cycles the model will undergo. Each epoch represents one complete pass of the training dataset through the model.
- **model_directory**: Path to the directory where the trained model and its checkpoints will be saved. This is used for storing the model during and after training.
- **run_id**: Name of the subdirectory within `model_directory` where the results of this run are stored, so that different runs are kept apart.
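The checkpoint path used in the evaluation cell at the end of this notebook has the form `<model_directory>/<run_id>/best_checkpoint_<timestamp>_<n>.pth.tar`. Assuming that naming pattern holds for your runs (it is inferred from that single path, not from the `aitlas` documentation), the following sketch picks the most recent best checkpoint of a run instead of copying the path by hand. `latest_best_checkpoint` is a hypothetical helper.

```python
import glob
import os
import re

# File name pattern inferred from the example path in this notebook (an assumption).
_CHECKPOINT_RE = re.compile(r"best_checkpoint_(\d+)_(\d+)\.pth\.tar$")


def latest_best_checkpoint(model_directory, run_id):
    """Return the newest best checkpoint of a run, or None if there is none."""
    run_dir = os.path.join(model_directory, run_id)
    candidates = []
    for path in glob.glob(os.path.join(run_dir, "best_checkpoint_*.pth.tar")):
        match = _CHECKPOINT_RE.search(os.path.basename(path))
        if match:
            # Sort key: (timestamp, trailing number), both read from the file name.
            candidates.append((int(match.group(1)), int(match.group(2)), path))
    return max(candidates)[2] if candidates else None
```

The returned path could then be passed as `model_path` to `model.evaluate`; a `None` result means the run directory contains no file matching the assumed pattern.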

In [None]:
epochs = 20
model_directory = "./models/semantic_segmentation/"
run_id = 'ringfort_1_2'

In [None]:
model.train_and_evaluate_model(
    train_dataset=train_dataset,
    val_dataset=validation_dataset,
    epochs=epochs,
    model_directory=model_directory,
    run_id=run_id
);

### Model evaluation

In [None]:
model = HRNet(model_config)
model.prepare()
model.running_metrics.reset()
model_path = "/Users/anakostovska/Dropbox/aitlas_v1/retrain_model/models/semantic_segmentation/ringfort_1_2/best_checkpoint_1710336422_1.pth.tar" # update the path!
model.evaluate(dataset=test_dataset, model_path=model_path)
model.running_metrics.get_scores(model.metrics)
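Segmentation quality is commonly summarised with intersection over union (IoU). Which metrics `get_scores` reports depends on the model configuration, so the following is only a sketch of how IoU is computed for one pair of binary masks. It uses plain Python and small hand-made masks rather than real model output, and `binary_iou` is a hypothetical helper, not an `aitlas` function.

```python
def binary_iou(pred_mask, true_mask):
    """Intersection over union of two equally sized binary masks (nested lists of 0/1)."""
    intersection = 0
    union = 0
    for pred_row, true_row in zip(pred_mask, true_mask):
        for p, t in zip(pred_row, true_row):
            intersection += 1 if (p and t) else 0
            union += 1 if (p or t) else 0
    # Convention chosen here: two empty masks agree perfectly (avoids division by zero).
    return 1.0 if union == 0 else intersection / union


pred = [[0, 1, 1],
        [0, 1, 0]]
true = [[0, 1, 0],
        [0, 1, 1]]
iou = binary_iou(pred, true)  # 2 shared pixels out of 4 marked in either mask
```

The empty-mask convention matters when `keep_empty_patches` is set to True for testing, because patches without any annotated object would otherwise produce a division by zero.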