# Large Raster Object Recognition Model Training

### Summary

This notebook gathers the code needed to train a RetinaNet-based object recognition model using PyTorch Lightning.

What makes this model and training process stand out is that it is mainly oriented towards training on very large raster images, namely satellite imagery of the earth's surface, astronomical imagery, or any similar task where raster data often comes in the form of large, high-resolution files.  

In this notebook you will find many utilities made to deal with large image files and datasets, from a pipeline for cleaning and augmenting the data, to a PyTorch Lightning based machine learning workflow that performs a grid search for optimal hyperparameters while training an object recognition model on tiles that were generated from the original rasters.

### Main Features

- Use of multiprocessing Pools to parallelize the heavy computation needed to preprocess the datasets and images.
- Use of tensor operations to split images and generate bounding boxes from each object's position and diameter.
- Flexible functions that can be fed a variable amount of filters or operations created on the fly.
- Custom implementations of common data augmentation techniques for images and datasets.
- Many plotting and logging utilities to facilitate visualization and debugging.
- Grid search through different hyperparameters.
- PyTorch Lightning module wrapping a RetinaNet-based model.
- A ModelConfig class for storing model hyperparameters.
- Custom model logging for training, validation and testing results, supporting checkpointing.
- A simple heatmap plot to review the results of the best performing model after training.

### Disclaimer

The images and data used for model training are not in the public domain. As a result, some of the debugging plots might be off by default.  
This notebook and the adjacent modules serve exclusively as a code sample. As such, I do not plan on regularly maintaining them.

### How to use

The required steps for running this code are the following:
1. Fill the local `data/datasets` directory with csv files, one dataset per raster tiff file.

2. Fill the local `data/images` directory with one raster tiff file per dataset/csv file.

3. Fill the DF_FILENAMES and IMAGE_FILENAMES variables with the names of your dataset and raster files. Make sure the i-th dataset corresponds to the i-th raster file for all entries.

4. Run the notebook
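
Since step 3 relies on the two lists lining up by position, a quick check before running the notebook can catch mismatched lengths early. The following is a minimal sketch; `check_pairing` is a hypothetical helper that is not part of the notebook's modules, and it only verifies the list lengths, not the file contents.

```python
def check_pairing(df_filenames, image_filenames):
    """Pair the i-th dataset with the i-th raster, failing early on a length mismatch."""
    if len(df_filenames) != len(image_filenames):
        raise ValueError(
            f"Got {len(df_filenames)} datasets but {len(image_filenames)} rasters."
        )
    return list(zip(df_filenames, image_filenames))


# Example values mirroring the placeholders used in this notebook.
pairs = check_pairing(
    ['example_df1.csv', 'example_df2.csv'],
    ['example_im1.tiff', 'example_im2.tiff'],
)
```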

In [None]:
DF_FILENAMES = ['example_df1.csv', 'example_df2.csv']
IMAGE_FILENAMES = ['example_im1.tiff', 'example_im2.tiff']

## Module imports

Here, the necessary modules for data preprocessing and model training are imported.  
They are split into three groups: standard library modules first, third-party modules second and custom local modules third.

In [None]:
import logging
import os
import random
import shutil
import time
from collections.abc import Iterable
from multiprocessing import Pool
from typing import Optional, Any, Callable
from functools import partial

import numpy as np
import pandas as pd
import torch
import wandb
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from PIL import Image

from ImageData import ImageData
from fileio import identifier_from_tilename, recreate_directory, recreate_temp_directory
from logger import LoggerProtocol, DefaultModelLogger, get_logging_key
from imutils import rotate_bbox_180, rotate_bbox_270, rotate_bbox_90, rotate_center
from implot import plot_image_with_annotations
from LargeRasterOR import LargeRasterORModel

## Logger Setup

This function defines a standard configuration for the logger that will keep track of the application code for preprocessing the data. This configuration includes a handler for dumping the logs into a file and is set to INFO level by default.  
Right before the training a second logger will take care of the model's logging.

In [None]:
def configure_logger():
    logger = logging.getLogger(__name__)

    logger.setLevel(logging.INFO)

    # Re-running this cell would otherwise attach duplicate handlers
    # and repeat every log line.
    if logger.handlers:
        return logger

    # Console handler and file handler.
    c_handler = logging.StreamHandler()
    f_handler = logging.FileHandler('application.log')

    c_format = logging.Formatter('%(name)s - %(levelname)s - %(message)s')
    f_format = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    c_handler.setFormatter(c_format)
    f_handler.setFormatter(f_format)

    logger.addHandler(c_handler)
    logger.addHandler(f_handler)

    return logger

setup_logger = configure_logger()

## Data directories
Directories where the dataframes (csv files) and image files (tiff files) should be found by default.

In [None]:
DF_DIR = os.path.join('data', 'datasets')
IMG_DIR = os.path.join('data', 'images')

## Dataset and image reading
Multiprocessing is used to leverage multiple cores when reading large datasets.  
`MAX_DIAMETER` defines the limit above which the diameter of an object in a dataframe is considered a product of data noise.

For the data used in this notebook's training, a transform will be applied to turn the diameter units into pixel values. The `MAX_DIAMETER` variable will then correspond to 75 pixels.

In [None]:
df_filepaths = [os.path.join(DF_DIR, name) for name in DF_FILENAMES]

In [None]:
from fileio import read_raster

tiff_file_paths = [
    os.path.join(IMG_DIR, identifier) for identifier in IMAGE_FILENAMES
]

with Pool() as p:
    images = p.map(read_raster, tiff_file_paths)

In [None]:
MAX_DIAMETER = 25

## Preprocessing Datasets
After defining the functions that perform each step in the preprocessing pipeline, they are applied to each of the dataframes in parallel through a chain of `pipe` calls. This makes the code more readable, maintainable and open to new additions to the preprocessing pipeline.

Augmentation process:  

1. Remove rows with NaN or non-numeric values in the object's diameter column.
2. Set most uncommon labels as 'Unknown' to focus training on relevant data.
3. Generate bounding boxes from each datapoint's `diameter`.
4. Correct swapped minimum and maximum coordinates in noisy datapoints.
5. Filter out bounding boxes that expand outside the original image or are otherwise unreasonably large.
6. Create rotated versions of each image.
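
Steps 3 and 5 can be illustrated on a toy DataFrame. This is a simplified sketch rather than the notebook's own implementation: `add_bboxes` and `keep_inside` are hypothetical names, the column names `pixel_x`, `pixel_y` and `diameter` follow the ones used in this notebook, and the raster size and diameter limit are made-up values.

```python
import pandas as pd


def add_bboxes(df):
    """Build a square bounding box around each center from its diameter."""
    half = df["diameter"] / 2
    return df.assign(
        xmin=df["pixel_x"] - half,
        xmax=df["pixel_x"] + half,
        ymin=df["pixel_y"] - half,
        ymax=df["pixel_y"] + half,
    )


def keep_inside(df, width, height, max_diameter):
    """Keep boxes that lie fully inside the raster and are below the diameter limit."""
    inside = (
        (df["xmin"] >= 0) & (df["ymin"] >= 0)
        & (df["xmax"] <= width) & (df["ymax"] <= height)
        & (df["diameter"] < max_diameter)
    )
    return df[inside]


toy = pd.DataFrame({
    "pixel_x": [10.0, 2.0, 50.0],
    "pixel_y": [10.0, 2.0, 50.0],
    "diameter": [4.0, 8.0, 30.0],
})

# Row 1 sticks out of the raster and row 2 exceeds the diameter limit.
result = toy.pipe(add_bboxes).pipe(keep_inside, width=100, height=100, max_diameter=25)
```

Chaining with `pipe` keeps every step a plain function of a DataFrame, so a new step can be added by appending one more call to the chain.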

In [None]:
def remove_non_numeric_rows(df: pd.DataFrame, column_name: str, default_diameter: float=0.5) -> pd.DataFrame:
    """
    Removes all rows with non-numeric or NaN values in a given column of a DataFrame.
    
    Parameters:
    - df (pd.DataFrame): The DataFrame whose column will be filtered.
    - column_name (str): The name of the column that will be filtered.
    - default_diameter (float): Zero or negative values will be replaced by `default_diameter` after filtering.

    Returns:
    - pd.DataFrame: The provided DataFrame after its rows have been filtered.
    """
    
    # Non-numeric entries become NaN so they can be dropped together with missing values.
    df[column_name] = pd.to_numeric(df[column_name], errors='coerce')
    
    df.dropna(subset=[column_name], inplace=True)

    # Use `column_name` throughout instead of a hard-coded `diameter` attribute.
    df[column_name] = df[column_name].astype('float64')
    df.loc[df[column_name] <= 0, column_name] = default_diameter

    return df

In [None]:
def gather_common_labels(dataframes: Iterable[pd.DataFrame], percentile: float) -> "pd.Series[bool]":
    """
    Finds the smallest set of labels that are needed to cover up to a given percentile of a dataset.

    This function calculates the total count of rows of each label in a Dataframe object,
    sorts the labels from most common to least common and creates a pd.Series[bool] object
    indexed by those labels.
     
    The value for a given label will be True for the most common, then the second, third, etc until
    there are enough of that label's examples in the original DataFrame to cover up to a given `percentile`.

    Parameters:
    - dataframes (Iterable[pd.DataFrame]): A collection of DataFrame objects.
    - percentile (float): The percentile of samples that the top labels need to cover.

    Returns:
    - pd.Series: A pandas Series containing boolean values. True for labels needed to cover up to the percentile and False for labels that are not needed.
    """
    total_count = pd.Series(dtype='int64')

    for df in dataframes:

        if "label" not in df.columns:
            setup_logger.warning("'label' column not found in df. Skipping this DataFrame.")
            continue

        count = df['label'].value_counts()

        total_count = total_count.add(count, fill_value=0)

    total_count = total_count.sort_values(ascending=False).astype('int64') 

    cumulative_sum = total_count.cumsum()
    cumulative_percentage = 100 * cumulative_sum / cumulative_sum.iloc[-1]

    top_labels = cumulative_percentage <= percentile

    return top_labels

In [None]:
def filter_label_column(df: pd.DataFrame, label_mask: pd.Series, label_column_name: str) -> pd.DataFrame:
    """
    Sets the least common labels to "Unknown" based on the boolean value provided by `label_mask`.

    Parameters:
    - df (pd.DataFrame): A pandas DataFrame object containing the column `label_column_name`.
    - label_mask (pd.Series): A pandas Series of booleans indexed by label. If `False` for a given label (or if the label is missing from the index), all rows with that label are relabeled as "Unknown".
    - label_column_name (str): Name of the column that holds the labels.

    Returns:
    - pd.DataFrame: DataFrame object after relabeling.
    """
    if label_column_name not in df.columns:
        raise ValueError(f"Column {label_column_name} was not found in the provided DataFrame.")

    mask = ~df[label_column_name].isin(label_mask.index[label_mask])
    df.loc[mask, label_column_name] = "Unknown"
    return df

In [None]:
def generate_bboxes(df: pd.DataFrame) -> pd.DataFrame:
    """
    Calculates a bounding box for a given center and diameter in each row of a DataFrame, then returns it.

    Parameters:
    - df (pd.DataFrame): A pandas DataFrame object containing objects with pixel coordinates that correspond to a raster image.

    Returns:
    - pd.DataFrame: DataFrame augmented with the bounding boxes' pixel coordinates for each example.
    """
    df['diameter'] = df['diameter'].astype('float64')
    df['xmin'] = df['pixel_x'] - df['diameter']/2
    df['xmax'] = df['pixel_x'] + df['diameter']/2
    df['ymin'] = df['pixel_y'] - df['diameter']/2
    df['ymax'] = df['pixel_y'] + df['diameter']/2
    return df


In [None]:
def correct_bounding_boxes(df: pd.DataFrame) -> pd.DataFrame:
    """
    Ensures all pixel-coordinate bounding boxes in a DataFrame follow the format: (xmin, ymin, xmax, ymax).

    Swaps the minimum and maximum values of a coordinate whenever they are inverted,
    and logs how many rows were corrected.

    Parameters:
    - df (pd.DataFrame): DataFrame containing the columns `xmin`, `ymin`, `xmax` and `ymax`.

    Returns:
    - pd.DataFrame: The DataFrame with corrected bounding box coordinates.
    """
    swapped_count = 0

    for index, row in df.iterrows():
        swap_made = False

        if row['xmin'] > row['xmax']:
            df.at[index, 'xmin'], df.at[index, 'xmax'] = row['xmax'], row['xmin']
            swap_made = True

        if row['ymin'] > row['ymax']:
            df.at[index, 'ymin'], df.at[index, 'ymax'] = row['ymax'], row['ymin']
            swap_made = True

        if swap_made:
            swapped_count += 1

    setup_logger.info(f"Total number of coordinates swapped: {swapped_count}")
    return df

In [None]:
def min_filter(x_min, y_min, x_max, y_max, diam, sizex, sizey):
    return 0 <= x_min and 0 <= y_min

def max_filter(x_min, y_min, x_max, y_max, diam, sizex, sizey):
    return x_max <= sizex and y_max <= sizey

def max_diameter_filter(x_min, y_min, x_max, y_max, diam, sizex, sizey):
    return diam < MAX_DIAMETER

In [None]:
def filter_invalid_bboxes(df: pd.DataFrame, sizex: int, sizey: int, *filters: Callable) -> pd.DataFrame:
    """
    Removes rows from a DataFrame when they aren't validated by one of the provided filters.

    Parameters:
    - df (pd.DataFrame): A pandas DataFrame object containing objects with pixel coordinates for a given raster.
    - sizex (int): The x coordinate limit for bounding boxes in the DataFrame.
    - sizey (int): The y coordinate limit for bounding boxes in the DataFrame.
    - filters: Callable filters that receive x_min, y_min, x_max, y_max, diam, sizex, sizey and return True or False.

    Returns:
    - pd.DataFrame: Resulting DataFrame after filtering its bounding boxes.
    """

    def filter_invalid_row(row):
        x_min, y_min, x_max, y_max, diam = row['xmin'], row['ymin'], row['xmax'], row['ymax'], row['diameter']
        
        return all(filter_fn(x_min, y_min, x_max, y_max, diam, sizex, sizey) for filter_fn in filters)

    mask = df.apply(filter_invalid_row, axis=1)
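    mask[mask.isna()] = False
    # Report how many rows pass the filters through the application logger.
    setup_logger.info(f"Bounding boxes kept (True) vs. removed (False): {mask.value_counts().to_dict()}")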
    return df[mask]

In [None]:
def apply_column_transform(df: pd.DataFrame, label_column_name: str, transform: Callable) -> pd.DataFrame:
    '''
    Applies a given transform (Callable function that returns a value) to a specified DataFrame column, returns the DataFrame.

    Parameters:
    - df (pd.DataFrame): A pandas DataFrame object containing the column named as `label_column_name`.
    - label_column_name (str): Name of the column in which a transform will be applied.
    - transform(Callable): A Callable that takes a value and returns a value. Will be applied to the specified column.

    Returns:
    - pd.DataFrame: Resulting DataFrame after applying the transform to the specified column.
    '''
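    # Unpack a dict so the column named by `label_column_name` is replaced;
    # a plain keyword would create a column literally called "label_column_name".
    return df.assign(**{label_column_name: df[label_column_name].apply(transform)})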

In [None]:
DEFAULT_OBJECT_DIAMETER = 0.5
DIAMETER_COLUMN_NAME = 'diameter'
LABEL_COLUMN_NAME = 'label'
transform = lambda x: 3*x

In [None]:
def preprocess_df(arg_tuple: tuple[str, tuple[int, int]]) -> pd.DataFrame:
    """
    Applies a series of functions (preprocessing steps) to a DataFrame and returns the result.

    Parameters:
    - arg_tuple (tuple[str, tuple[int, int]]): A tuple containing the DataFrame's filepath and the shape of the raster it annotates.
    """
    df_name = arg_tuple[0]
    raster_shape = arg_tuple[1]
    
    df = pd.read_csv(df_name)

    return (
        df
        .pipe(remove_non_numeric_rows, DIAMETER_COLUMN_NAME, default_diameter=DEFAULT_OBJECT_DIAMETER)
        .pipe(apply_column_transform, DIAMETER_COLUMN_NAME, transform)
        .pipe(generate_bboxes)
        .pipe(correct_bounding_boxes)
        .pipe(filter_invalid_bboxes, raster_shape[1], raster_shape[0], min_filter, max_filter, max_diameter_filter)
    )

### Preprocessing DataFrames

Before augmenting image data, each DataFrame is preprocessed in parallel by applying a series of preprocessing steps, previously detailed in "Preprocessing Datasets".

Details on the `preprocess_df` function:
- The `preprocess_df` function accepts a picklable tuple of arguments so that this information can be passed to the separate processes.
- It uses method chaining with `pipe` to modify a dataframe in a readable and extensible manner, avoiding DataFrame copying where possible, as the datasets are presumed to be large. This favours preventing memory issues over accessibility when debugging.
- It accepts filepaths and integer values to minimize inter-process context passing. Each process creates and modifies its own DataFrame until all DataFrames are merged after preprocessing.
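
Arguments handed to `Pool.map` are pickled before they reach the worker processes, which is why a plain tuple of a filepath and a shape is a convenient payload. The sketch below only checks that such a tuple survives a pickle round trip and unpacks the way a worker would need it. `describe_task` is a hypothetical stand-in for `preprocess_df`, the path and shape are made-up values, and the shape is assumed to be `(height, width)`.

```python
import pickle


def describe_task(arg_tuple):
    """Unpack a (filepath, raster shape) tuple the way a worker would."""
    df_path, (height, width) = arg_tuple
    return f"{df_path}: {width}x{height}"


task = ("data/datasets/example_df1.csv", (2048, 4096))

# Round trip through pickle, as happens when a task is sent to a worker process.
restored = pickle.loads(pickle.dumps(task))
summary = describe_task(restored)
```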

In [None]:
df_paths_and_raster_sizes = zip(df_filepaths, (image.shape for image in images))

with Pool() as p:
    dfs = p.map(preprocess_df, df_paths_and_raster_sizes)

top_labels = gather_common_labels(dfs, 80)
dfs = [df.pipe(filter_label_column, top_labels, LABEL_COLUMN_NAME) for df in dfs]

### Augmenting Image Data
Image rotations are applied to the original rasters to augment the data for training the LargeRasterOR model.
A multiprocessing Pool is used to process each image separately. The datapoints (DataFrames and images) are passed to the `rotate_single_annotation` function as `ImageData` objects, instances of the `ImageData` dataclass.  

The function is partially applied with `functools.partial` before being passed to each process, which makes it easy to hand the output directory to the multiprocessing pool.
Annotations are aggregated into a single list of DataFrames. Each DataFrame is expanded by adding the `image_path` column, which will be used later to link each DataFrame to its corresponding raster.
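
To see how a bounding box has to move when its image is rotated, the sketch below maps a box through one counterclockwise quarter turn, which is what `np.rot90(image, 1)` performs, and checks the result against the rotated pixels. `rotate_bbox_ccw90` is a hypothetical helper written for this illustration; it is not the `rotate_bbox_90` function from `imutils`, whose conventions may differ.

```python
import numpy as np


def rotate_bbox_ccw90(xmin, ymin, xmax, ymax, width):
    """Map a box through one counterclockwise quarter turn of its image.

    A point (x, y) of the original image lands at (y, width - x),
    so the minimum and maximum x coordinates swap roles.
    """
    return ymin, width - xmax, ymax, width - xmin


height, width = 4, 6
image = np.zeros((height, width), dtype=np.uint8)

# Paint the pixels covered by the original box (xmin, ymin, xmax, ymax).
box = (1, 0, 3, 2)
image[box[1]:box[3], box[0]:box[2]] = 1

rotated = np.rot90(image, 1)
new_box = rotate_bbox_ccw90(*box, width)

# The rotated box should cover exactly the painted pixels of the rotated image.
patch = rotated[new_box[1]:new_box[3], new_box[0]:new_box[2]]
```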


In [None]:
Image.MAX_IMAGE_PIXELS = 200000000

In [None]:
def rotate_single_annotation(data: ImageData, output_directory: str, angles: Optional[list[int]]=None) -> list[pd.DataFrame]:
    """
    Rotates an image and its corresponding annotations by a given angle. Returns both rotated and original versions.

    If provided a collection of angles, it creates a DataFrame and an image rotation for each angle.
    Returns all rotated DataFrame versions including the original. Stores all rotated images as well
    as the original image in `output_directory`.
    

    Parameters:
    - data (ImageData): An ImageData object containing an Image and a DataFrame with its corresponding objects.
    - output_directory (str): A path to the directory where the rotated images will be stored.
    - angles (Optional[list[int]]): The angles used to create the rotated annotations. Supported values are 90, 180 and 270; defaults to all three.

    Returns:
    - list[pd.DataFrame]: List containing the original DataFrame as well as all of its rotated copies.
    """

    original_image, original_df, identifier =  data.image, data.df, data.image_identifier
    
    new_annotations = []

    df = original_df[['xmin', 'ymin', 'xmax', 'ymax', 'pixel_x', 'pixel_y']].copy()
    output_path = os.path.join(output_directory, f'{identifier}_0.tiff')
    df['image_path'] = output_path
    df['label'] = original_df['label']

    os.makedirs(output_directory, exist_ok=True)
    Image.fromarray(original_image).save(output_path)

    h, w = original_image.shape[:2]

    new_annotations.append(df)

    rotate_func = {90: rotate_bbox_90, 180: rotate_bbox_180, 270: rotate_bbox_270}

    if angles is None:
        angles = [90, 180, 270]

    for angle in angles:
        rotated = np.rot90(original_image, angle // 90)

        rotated_image_path = os.path.join(output_directory, f'{identifier}_{angle}.tiff')
        Image.fromarray(rotated).save(rotated_image_path)

        rotated_df = df[['xmin', 'ymin', 'xmax', 'ymax', 'pixel_x', 'pixel_y']].copy()

        rotated_df[['xmin', 'ymin', 'xmax', 'ymax']] = rotated_df.apply(
            lambda row: rotate_func[angle](row['xmin'], row['ymin'], row['xmax'], row['ymax'], w, h),
            axis=1,
            result_type="expand")

        rotated_df[['pixel_x', 'pixel_y']] = rotated_df.apply(
            lambda row: rotate_center(row['pixel_x'], row['pixel_y'], angle, w, h),
            axis=1,
            result_type="expand")

        rotated_df['image_path'] = rotated_image_path
        rotated_df['label'] = original_df['label']

        new_annotations.append(rotated_df)

    return new_annotations

In [None]:
IMAGE_OUTPUT_DIRECTORY = 'data/output_images'

In [None]:
image_names = [os.path.splitext(filename)[0] for filename in IMAGE_FILENAMES]

datapoints = [ImageData(image, identifier, df) for image, identifier, df in zip(images, image_names, dfs)]

curried_rotate_single_annotation = partial(rotate_single_annotation, output_directory=IMAGE_OUTPUT_DIRECTORY, angles=[90, 180, 270])

with Pool() as p:
    annotations_nested_list = p.map(curried_rotate_single_annotation, datapoints)

annotations_list = [annotation for sublist in annotations_nested_list for annotation in sublist]

## Plotting datasets
In this section, we can sanity-check the creation of rotated images, as well as the preprocessing of the DataFrames containing the annotations, by plotting the raster files and their corresponding bounding boxes using utilities from the `implot` module.

In [None]:
PLOT_IMAGE = False

%matplotlib widget

all_annotations = pd.concat(annotations_list, ignore_index=True)

image_path = os.path.join(IMAGE_OUTPUT_DIRECTORY, os.listdir(IMAGE_OUTPUT_DIRECTORY)[0])

# Only report an error when the type is actually wrong,
# not when plotting is simply switched off.
if not isinstance(all_annotations, pd.DataFrame):
    setup_logger.error("all_annotations is not a DataFrame. Please check its type.")

elif PLOT_IMAGE:
    filtered_annotations = all_annotations[all_annotations['image_path'] == image_path]
    original_image = np.array(Image.open(image_path))
    plot_image_with_annotations(original_image, filtered_annotations)

## Tile splitting functions
The following is a set of functions whose main purpose is to split the original raster files into tiles.  
The resulting tiles will be dumped in the directory specified by `TILING_DIRECTORY`, which by default is `./tiling_dir`.

The tiling directory is further split into 1 directory per tiling size, as multiple tile sizes may be specified for performing grid search while training. Each directory inside the tiling directory will be named `tiling_{size}` where size is the side length of the tiles contained within that directory.

While generating the tiles, the previously generated dataframes containing image information will be read, and the coordinates of each bounding box within the raster file will be localized to the corresponding coordinates inside the tile where that bounding box ends up after the process of splitting.  
After generating tiles for an image, a series of optional filters may be passed to the `make_tiles` function to remove tiles after splitting based on any criterion. As an example, one might want to remove a tile that ended up with no objects in it, or containing just black pixels, if they are not relevant for training the target model.
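
The localization of a bounding box to its tile can be sketched for a single box in plain Python. The sketch assumes that a box belongs to the tile containing its center and that tiles are numbered row by row; `locate_in_tile` and the sizes used are hypothetical and only meant to illustrate the idea.

```python
def clamp(value, upper):
    """Restrict a coordinate to the range [0, upper]."""
    return max(0, min(value, upper))


def locate_in_tile(bbox, tile_size, width):
    """Return the tile number holding the box center and the box in tile coordinates."""
    xmin, ymin, xmax, ymax = bbox
    center_x, center_y = (xmin + xmax) / 2, (ymin + ymax) / 2

    # Column and row of the tile that contains the center of the box.
    col, row = int(center_x // tile_size), int(center_y // tile_size)

    # Ceiling division: a partial tile at the right edge still counts.
    n_tiles_x = (width + tile_size - 1) // tile_size
    tile_number = row * n_tiles_x + col

    # Shift the box by the tile origin, then clip whatever sticks out.
    x0, y0 = col * tile_size, row * tile_size
    local_box = (
        clamp(xmin - x0, tile_size), clamp(ymin - y0, tile_size),
        clamp(xmax - x0, tile_size), clamp(ymax - y0, tile_size),
    )
    return tile_number, local_box


# A box straddling a tile border in a raster that is 1000 pixels wide.
tile_number, local_box = locate_in_tile((500, 250, 530, 270), tile_size=256, width=1000)
```

Clipping means that a box crossing a tile border is truncated in the tile that owns its center, rather than being duplicated in the neighbouring tile.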

This process is done for multiple image sizes in parallel using multiprocessing. It is initialized by the make_tiles function, and checks whether the resulting target directories already exist before generating them from the start, as it is a costly process. This is made to facilitate debugging but can be prevented by setting the `force` argument to True, initiating the tiling process regardless of whether the target directories already exist.
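
The decision of whether to regenerate tiles boils down to a single condition. Below is a minimal sketch of it, using a temporary directory so it can run anywhere; `should_generate` is a hypothetical helper and `tiling_256` is just an example of the `tiling_{size}` naming described above.

```python
import os
import tempfile


def should_generate(target_directory, force=False):
    """Tiles are generated when forced or when the target directory is missing."""
    return force or not os.path.exists(target_directory)


with tempfile.TemporaryDirectory() as tiling_root:
    target = os.path.join(tiling_root, "tiling_256")

    first_run = should_generate(target)               # directory does not exist yet
    os.makedirs(target)
    second_run = should_generate(target)              # directory exists, skip
    forced_run = should_generate(target, force=True)  # regenerate anyway
```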

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
def split_and_save_tiles(tiling_directory: str, tile_size: int, image_name: str, image_tensor: torch.Tensor, bboxes: torch.Tensor, labels: pd.Series):
    """
    Splits a given image into a grid of tiles and localizes its bounding boxes to them.

    Each bounding box is assigned to the tile that contains its center.
    Dumps the resulting images (tiles) into a specified directory, then builds a
    DataFrame holding each bounding box in the coordinates of its corresponding
    tile, along with its label and the "tile_path" column.

    Parameters:
    - tiling_directory (str): Path where the resulting tiles will be dumped.
    - tile_size (int): The side length of the tiles the image will be broken down into.
    - image_name (str): Identifier of the source image, used to name the tile files.
    - image_tensor (torch.Tensor): The image to split, with height and width as its first two dimensions.
    - bboxes (torch.Tensor): Tensor of shape (N, 4) holding (xmin, ymin, xmax, ymax) for each object in the image.
    - labels (pd.Series): Labels for the objects, in the same order as `bboxes`.

    Returns:
    - pd.DataFrame: Annotations with bounding box coordinates relative to each tile instead of the full image.
    """
    
    
    height, width = image_tensor.shape[:2]
    
    tile_number = 0
    new_rows = []
    tile_paths = []
    out_labels = []

    n_tiles_x = (width + tile_size - 1) // tile_size
    n_tiles_y = (height + tile_size - 1) // tile_size
    
    for i in range(n_tiles_y):
        for j in range(n_tiles_x):
            xmin_tile = j * tile_size
            ymin_tile = i * tile_size
            xmax_tile = min(xmin_tile + tile_size, width)
            ymax_tile = min(ymin_tile + tile_size, height)
            
            tile = image_tensor[ymin_tile:ymax_tile, xmin_tile:xmax_tile]
            tile_to_save = tile.cpu().numpy().astype(np.uint8)
            tile_name = f"tile_{image_name}_{tile_number}.png"
            tile_location = os.path.join(tiling_directory, tile_name)
            Image.fromarray(tile_to_save).save(tile_location)
            
            tile_box = torch.tensor([xmin_tile, ymin_tile, xmax_tile, ymax_tile])
            
            centers = (bboxes[:, :2] + bboxes[:, 2:]) // 2

            valid_centers = (centers >= tile_box[:2]) & (centers < tile_box[2:])
            valid_idxs = valid_centers.all(dim=1)

            valid_boxes = bboxes[valid_idxs]
            xmin_new = torch.clamp(valid_boxes[:, 0] - tile_box[0], min=0, max=tile_size)
            ymin_new = torch.clamp(valid_boxes[:, 1] - tile_box[1], min=0, max=tile_size)
            xmax_new = torch.clamp(valid_boxes[:, 2] - tile_box[0], min=0, max=tile_size)
            ymax_new = torch.clamp(valid_boxes[:, 3] - tile_box[1], min=0, max=tile_size)

            valid_idxs_np = valid_idxs.cpu().numpy()
            out_labels.extend(labels[valid_idxs_np].tolist())

            new_rows.extend(torch.stack([xmin_new, ymin_new, xmax_new, ymax_new], dim=1).tolist())
            tile_paths.extend([tile_location] * len(xmin_new))

            tile_number += 1
    
    tiled_df = pd.DataFrame(new_rows, columns=['xmin', 'ymin', 'xmax', 'ymax'])
    # Use the labels collected per tile so they stay aligned with `new_rows`.
    tiled_df['label'] = out_labels
    tiled_df['tile_path'] = [os.path.basename(tile_path) for tile_path in tile_paths]
    
    return tiled_df

In [None]:
def generate_split_raster_annotation(patch_size: int, common_directory: str, df: pd.DataFrame) -> pd.DataFrame:
    """
    Splits a given image and a DataFrame into multiple tiles. 
    Adds the column "tile_path" to the DataFrame and returns it.

    Splits a given image into a grid of tiles, taking into account the original tile
    for each DataFrame's row.
    Dumps the resulting images (tiles) into a specified directory, then adds to the 
    original DataFrame the necessary information to recover each row's bounding
    box in its corresponding tile before returning it.        

    Parameters:
    - patch_size (int): The side length of the tiles the image will be broken down into.

    - common_directory (str): Path where the resulting tiles will be dumped.

    - df (pd.Dataframe): DataFrame containing labels for all the objects in an image.

    Returns:
    - pd.DataFrame: The merged annotations containing bounding box coordinates for each of the tiles instead of the full image.
    """

    # Use positional access: the index may no longer contain the label 0 after filtering.
    image_path = df['image_path'].iloc[0]
    image_identifier = identifier_from_tilename(image_path)
    image_array = read_raster(image_path)
    
    bboxes = torch.tensor(df[['xmin', 'ymin', 'xmax', 'ymax']].values)
    # The labels live in the 'label' column created during augmentation.
    df = split_and_save_tiles(common_directory, patch_size, image_identifier, torch.from_numpy(image_array), bboxes, df['label'])

    df.dropna(subset=['tile_path'], inplace=True)
    
    return df

In [None]:
CENSORED_PIXELS_LIMIT = 2500

In [None]:
def filter_tiles(common_dir: str, df: pd.DataFrame, pixel_limit: int, filters: Optional[Iterable[Callable]]=None, spare_tiles: int = 0) -> pd.DataFrame:
    """
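    Filters out image tiles based on a set of conditions.

    Tiles that fail any of the filters are deleted from disk, and their rows are removed from `df`.

    Parameters:
    - common_dir (str): Path to a directory containing the tiles to be filtered.
    - df (pd.DataFrame): Annotations containing a `tile_path` column with the file name of each tile.
    - pixel_limit (int): The upper limit of black pixels for the black_pixel_filter.
    - spare_tiles (int): Number of tiles to spare before starting the filter operation.
    - filters: Callable filters that receive a numpy array of the image and `pixel_limit`, and return True or False.

    Returns:
    - pd.DataFrame: The annotations that refer to tiles still present after filtering.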
    """
    if filters is None or not filters:
        return df

    tile_dir = common_dir
    tile_files = sorted([f for f in os.listdir(tile_dir) if f.startswith('tile_') and f.endswith('.png')])
    
    remaining_tiles = set(tile_files)

    for tile_file in tile_files[spare_tiles:]:
        tile_path = os.path.join(tile_dir, tile_file)
        
        with Image.open(tile_path) as img:
            arr = np.array(img)

        # Delete the tile only after its file handle has been closed.
        if not all(filter_fn(arr, pixel_limit) for filter_fn in filters):
            os.remove(tile_path)
            remaining_tiles.discard(tile_file)

    # Annotation rows whose tile no longer exists on disk.
    dangling_tile_entries = set(df['tile_path']) - remaining_tiles

    setup_logger.info(f'Deleting {len(dangling_tile_entries)} dangling file entries')

    df = df[~df['tile_path'].isin(dangling_tile_entries)]

    return df


In [None]:
def black_pixel_filter(image_array, pixel_limit):
    black_pixel_count = np.sum(np.all(image_array == 0, axis=-1))
    return black_pixel_count <= pixel_limit

In [None]:
def make_single_size(size: int, tiling_directory: str, tile_filters: Optional[Iterable[Callable]]=None, force: Optional[bool]=False) -> None:
    '''
    Generates tiles of a single size from an image, stores the tiles and annotations in `tiling_directory`.

    Parameters:
    - size (int): The side length (pixels) of the tiles to be generated.
    - tiling_directory (str): The directory where the generated tiles will be stored.
    - tile_filters (Optional[Iterable[Callable]]): A list of filter functions to apply to the tiles. Each function takes a tile as input and returns a boolean indicating whether the tile should be kept.
    - force (Optional[bool]): If True, generate tiles again regardless of whether the directory already exists.
    '''
    common_directory = os.path.join(tiling_directory, f"tiling_{size}")

    if not os.path.exists(common_directory) or force:

        setup_logger.debug(f'In make_single_size: {common_directory=}')

        common_directory = recreate_directory(common_directory)
        
        all_annotations = []
        for annotation in annotations_list:
            all_annotations.append(generate_split_raster_annotation(size, common_directory, annotation))

        merged_annotations = pd.concat(all_annotations, ignore_index=True)

        setup_logger.debug(f'merged_annotations for size {size} have shape {merged_annotations.shape}')

        filtered_annotations = filter_tiles(common_directory, merged_annotations, CENSORED_PIXELS_LIMIT, tile_filters, spare_tiles=0,)

        filtered_annotations.to_csv(os.path.join(common_directory, "tile.csv"), index=False)

def make_tiles(size_grid: list[int], tiling_directory: str, force: bool=False, tile_filters: Optional[Iterable[Callable]]=None):
    '''
    Generates tiles of multiple sizes in parallel.

    Parameters:
    - size_grid (list[int]): A list of integers representing the dimensions (both height and width) of the tiles to be generated.
    - tiling_directory (str): The directory where the generated tiles will be stored.
    - tile_filters (Optional[Iterable[Callable]]): A list of filter functions to apply to the tiles. Each function takes a tile as input and returns a boolean indicating whether the tile should be kept.
    - force (Optional[bool]): If True, generate tiles again regardless of whether the directory already exists.
    '''
    curried_make_tile = partial(make_single_size, tiling_directory=tiling_directory, tile_filters=tile_filters, force=force)
    
    with Pool() as p:
        p.map(curried_make_tile, size_grid)

## Dataset splitting
The following function serves the purpose of splitting the tile and annotation data inside the tiling directory into proper training, validation and testing sets.  
This process starts by taking a portion of the total tiles available for a given tile size, indicated by the `sample_size` parameter as a proportion between 0 and 1.  
Training the model with different sample sizes is supported, since when dealing with large amounts of data it is sometimes preferable to perform shorter runs on a fraction of the dataset to check whether the training cycle works as expected.

After splitting the tiles and annotations into training, testing, and validation sets, all labels that were selected for the sampled data will be returned as a set. This is dynamically checked to avoid feeding labels to the model that have no corresponding bounding boxes present in the data.
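
The order of the splits can be sketched on tile names alone. This is a simplified illustration under stated assumptions, not the function defined below: `split_tiles`, the fixed seed and the toy labels are hypothetical, and no files are copied.

```python
import random


def split_tiles(tile_names, sample_size, validation_proportion=0.25, test_proportion=0.20, seed=0):
    """Sample a share of the tiles, then carve out test and validation subsets."""
    rng = random.Random(seed)

    # 1. Keep only a `sample_size` share of all tiles.
    n_sampled = int(len(tile_names) * sample_size)
    sampled = rng.sample(list(tile_names), n_sampled)

    # 2. Reserve a share of the sampled tiles for testing.
    n_test = int(n_sampled * test_proportion)
    test, train_val = sampled[:n_test], sampled[n_test:]

    # 3. Reserve a share of the remaining tiles for validation.
    n_val = int(len(train_val) * validation_proportion)
    validation, train = train_val[:n_val], train_val[n_val:]
    return train, validation, test


tiles = [f"tile_{i}.png" for i in range(100)]
# Toy labels, one per tile, to show how the set of sampled labels is collected.
tile_labels = {name: ("crater" if i % 2 else "boulder") for i, name in enumerate(tiles)}

train, validation, test = split_tiles(tiles, sample_size=0.5)
labels_in_sample = {tile_labels[name] for name in train + validation + test}
```

Note that the validation proportion is taken from what remains after the test tiles are set aside, not from the full sample.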

In [None]:
def split_dataset_train_test(tiling_directory: str, sized_tiling_dir_name: str, tmpdir: str, sample_size: float, validation_proportion: float=0.25, test_proportion: float=0.20):
    """
    Splits a group of tiles and their annotations into training, validation, and testing sets. 
    Returns the paths of the training and testing directories, the number of annotations used for validation,
    and the set of labels present in the sampled data.

    First, the function calculates the number of tiles in the directory specified by 'tiling_directory' and 'sized_tiling_dir_name'.
    It then randomly selects a portion of them based on 'sample_size'.
    Further splits the sampled tiles into testing and training tiles following the "test_proportion" parameter.
    Splits the training tiles into training and validations by following the "validation_proportion" parameter.

    Creates temporary training and testing directories. The training directory contains both training and validation tiles,
    as well as a file "tile.csv" with annotations that correspond to training tiles, and "val.csv", with annotations
    that correspond to validation tiles.

    Returns the newly created training and testing directories, the number of annotations used for validation,
    and the set of labels found in the resulting training, validation, and testing annotations.
    
    Parameters:
    - tiling_directory (str): The directory where the original tiles are stored.
    - sized_tiling_dir_name (str): Name of the subdirectory of 'tiling_directory' that contains the image tiles for a specific patch size.
    - tmpdir (str): Path of the temporary directory where training and testing directories will be created.
    - sample_size (float in the range [0, 1]): The proportion of the dataset that will be sampled, then split into training, validation, and testing examples.
    - validation_proportion (Optional[float] in the range [0, 1]): The proportion of the training dataset that will be used for validation. Defaults to 0.25.
    - test_proportion (Optional[float] in the range [0, 1]): The proportion of the sampled dataset that will be used for testing. Defaults to 0.20.

    Returns:
    - Tuple[str, str, int, set[str]]: A 4-tuple containing the path of the training directory, the path of the testing directory, the number of annotations used for validation, and the set of labels present in the sampled data.
    """

    setup_logger.info('Splitting dataset...')
    # Count the tiles available for this tile size
    tile_files = [f for f in os.listdir(os.path.join(tiling_directory, sized_tiling_dir_name)) if f.startswith('tile_') and f.endswith('.png')]
    tile_number = len(tile_files)
    setup_logger.debug(f'There are {tile_number} tile files in {os.path.join(tiling_directory, sized_tiling_dir_name)}')

    # Randomly select a sample of tiles that will be used
    used_tiles = random.sample(tile_files, int(tile_number * sample_size))
    test_tiles = set(random.sample(used_tiles, int(len(used_tiles) * test_proportion)))
    train_tiles = set(used_tiles) - test_tiles

    setup_logger.debug(f'From sample size {sample_size} we have {len(used_tiles)} used tiles, {len(test_tiles)} for testing and {len(train_tiles)} for training')

    # Create training and testing directories in temporary directory
    train_dir = os.path.join(tmpdir, 'training')
    test_dir = os.path.join(tmpdir, 'testing')
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(test_dir, exist_ok=True)

    setup_logger.info('Copying tiles to temporary directories')
    # Copy test tiles to the test directory
    test_sources = [os.path.join(tiling_directory, sized_tiling_dir_name, tile) for tile in test_tiles]
    list(map(lambda src: shutil.copy(src, test_dir), test_sources))

    # Copy train tiles to the train directory
    train_sources = [os.path.join(tiling_directory, sized_tiling_dir_name, tile) for tile in train_tiles]
    list(map(lambda src: shutil.copy(src, train_dir), train_sources))

    setup_logger.debug(f'Found {len(test_sources)}/{len(test_tiles)} test sources')
    setup_logger.debug(f'Found {len(train_sources)}/{len(train_tiles)} train sources')

    setup_logger.info('Splitting annotations into train and test')

    # Split the annotation file into training and testing based on tile name
    annotations_df = pd.read_csv(os.path.join(tiling_directory, sized_tiling_dir_name, "tile.csv"))
    train_annotations = annotations_df[annotations_df['tile_path'].isin(train_tiles)]
    test_annotations = annotations_df[annotations_df['tile_path'].isin(test_tiles)]

    setup_logger.info(f'Train annotations length: {len(train_annotations)}')
    setup_logger.info(f'Test annotations length: {len(test_annotations)}')

    setup_logger.info('Splitting annotations into train and validation')
    
    # Split the training csv into train and validation data
    image_paths = train_annotations.tile_path.unique()
    # Sample without replacement so the same tile is never picked twice for validation
    valid_paths = np.random.choice(image_paths, int(len(image_paths)*validation_proportion), replace=False)

    setup_logger.debug(f'For validation split: Selected {len(valid_paths)} paths from {len(image_paths)} paths using {validation_proportion=}')

    valid_annotations = train_annotations.loc[train_annotations.tile_path.isin(valid_paths)]
    train_annotations = train_annotations.loc[~train_annotations.tile_path.isin(valid_paths)]

    setup_logger.info('Saving annotations')

    # Save these annotations to their respective directories
    train_annotations.to_csv(os.path.join(train_dir, "tile.csv"), index=False)
    valid_annotations.to_csv(os.path.join(train_dir, "val.csv"), index=False)
    test_annotations.to_csv(os.path.join(test_dir, "tile.csv"), index=False)

    # Dynamically check which labels are left after taking a data sample
    label_set = (
        set(train_annotations['label'])
        .union(set(valid_annotations['label']))
        .union(set(test_annotations['label']))
    )

    setup_logger.debug('Dataset Lengths:')
    setup_logger.debug(f'Train annotations length: {len(train_annotations)}')
    setup_logger.debug(f'Test annotations length: {len(test_annotations)}')
    setup_logger.debug(f'Validation annotations length: {len(valid_annotations)}')

    return train_dir, test_dir, len(valid_annotations), label_set

## Hyperparameters

The hyperparameter grid is defined before setting up the environment for training the model.  
This is done right before generating the tiles from the raster files, defining all the necessary side lengths for the generated tiles in the process.

Afterwards, the default configuration for training the model is defined.
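
To get a feel for how quickly such a grid grows, the sketch below expands a dictionary of value lists into one dictionary per combination. The values mirror the ones used in this notebook; the `grids` dictionary and the `combinations` list are illustrative names, not objects the notebook defines.

```python
from itertools import product

# Illustrative grid; the keys match the hyperparameter names used in this notebook
grids = {
    "patch_size": [250, 225, 275],
    "n_epochs": [2, 3, 4],
    "sample_size": [0.2, 0.6, 1],
    "learning_rate": [0.001, 0.002],
    "batch_size": [1, 2, 3],
}

# product() yields one tuple per combination; zip it back to the key names
combinations = [dict(zip(grids, values)) for values in product(*grids.values())]

print(len(combinations))  # 3 * 3 * 3 * 2 * 3 = 162
print(combinations[0])
```

Note that `itertools.product` returns a one-shot iterator, so materialising it into a list, as done here, is one way to know the number of runs in advance or to iterate over the grid more than once.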

In [None]:
from itertools import product

patch_size_grid = [250, 225, 275]
batch_size_grid = [1, 2, 3]
n_epochs_grid = [2, 3, 4]
sample_size_grid = [0.2, 0.6, 1]
lr_grid = [0.001, 0.002]

hyperparameter_grid = product(patch_size_grid, n_epochs_grid, sample_size_grid, lr_grid, batch_size_grid)

## Environment Setup
In this section, the tiling directory containing the raster tiles is generated from the rotated images and preprocessed annotations.  
The constants for the tiling directory and for the column holding the model's labels are set first, and the default model configuration is defined afterwards.

In [None]:
RUN_BASE_NAME = "Example run {}"

In [None]:
FORCE_CREATION = False
TILING_DIRECTORY = 'tiling_dir/'
CLASS_COLUMN_NAME = 'label'

make_tiles(patch_size_grid, TILING_DIRECTORY, force=FORCE_CREATION)

In [None]:
from ModelConfig import ModelConfig

# Default model config
conf = ModelConfig(
    workers=24,
    accelerator = 'gpu'
)

conf.validation.pr_compute_interval = 1

conf.test.iou_threshold = 0.05
conf.test.score_threshold = 0.01

## Factorial analysis loop
The following function performs a training iteration over a given hyperparameter combination.
Each iteration performs the following steps:
1. Extract the hyperparameter combination and configure the model with the current values.
2. Create a pytorch lightning trainer for the model using the provided callbacks and logger.
3. Train the model, logging the training loss for labels and bounding boxes.
4. Perform validation and log the results periodically, according to the validation interval set in the model configuration.
5. Test the model on the provided testing data and log the results.

By default, a checkpoint of the best model is saved to the local "checkpoints" directory whenever validation produces the best score recorded so far for that model variant.
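
The "keep only the best" rule can be illustrated in plain Python. This is not how pytorch lightning's `ModelCheckpoint` is implemented; it is a minimal sketch of the decision it makes when a single best checkpoint is kept and the monitored metric is minimised. The function name and the loss values are made up for the example.

```python
def epochs_that_checkpoint(val_losses):
    """Return the best (epoch, loss) pair and the epochs at which a new best was found."""
    best = None
    saved = []
    for epoch, loss in enumerate(val_losses):
        # A checkpoint is written only when the monitored value improves
        if best is None or loss < best[1]:
            best = (epoch, loss)
            saved.append(epoch)
    return best, saved

# Made-up validation losses for four epochs
best, saved_epochs = epochs_that_checkpoint([0.90, 0.72, 0.75, 0.61])
print(best)          # (3, 0.61)
print(saved_epochs)  # [0, 1, 3]
```

With only one checkpoint kept, each new best replaces the previous one, so only the file from the last improving epoch remains at the end of the run.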

In [None]:
def train_iteration(hyperparams: dict[str, Any], train_dir: str, test_dir: str, label_set: set[str], logger: LoggerProtocol, artifact_logger_callback: Callable, pl_callbacks: list[Callable]):
    """
    Creates a new model, trains it according to the selected hyperparameters and evaluates it after training.

    Parameters:
    - hyperparams (dict[str, Any]): Dictionary containing the values for each hyperparameter.
    - train_dir (str): Path to the directory that contains training and validation examples (tiles and annotation files).
    - test_dir (str): Path to the directory that contains testing examples (tiles and annotation files).
    - label_set (set[str]): Set containing all possible training and testing data categories.
    - logger (LoggerProtocol): An object that implements the LoggerProtocol.
    - artifact_logger_callback (Callable): Callback for logging artifacts during training.
    - pl_callbacks (list[Callable]): A list of callbacks that will be passed to the model's Trainer.
    
    Returns:
    - Tuple[float, float, float]: Box precision, box recall, and class recall obtained on the testing data.
    """

    batch_size = hyperparams['batch_size']
    patch_size = hyperparams['patch_size']
    lr = hyperparams['learning_rate']
    n_epochs = hyperparams['n_epochs']
    sample_size = hyperparams['sample_size']

    # Model creation
    model = LargeRasterORModel(label_set)
    
    # Model configuration
    conf.batch_size = batch_size
    conf.train.epochs = n_epochs
    conf.train.lr = lr
    conf.train.csv_file = os.path.join(train_dir, "tile.csv")
    conf.train.root_dir = train_dir

    conf.validation.csv_file = os.path.join(train_dir, "val.csv")
    conf.validation.root_dir = train_dir

    conf.test.csv_file = os.path.join(test_dir, "tile.csv")
    conf.test.root_dir = test_dir

    model.conf = conf

    model.create_trainer(logger=logger, callbacks=pl_callbacks)

    # Read the training annotations once and sanity-check the label column
    train_labels = pd.read_csv(os.path.join(train_dir, "tile.csv")).label
    setup_logger.debug(f'{train_labels.dtype=}')
    setup_logger.debug(f'{train_labels.head()=}')
    assert all(isinstance(item, (str, bytes, os.PathLike)) for item in train_labels), "Labels are not the expected type in preprocessing."

    # Training starts
    start_time = time.time()
    model.trainer.fit(model)
    elapsed_time = time.time() - start_time

    # Log time taken for training
    if logger is not None:
        logger.log({"training_time": elapsed_time})

    # Evaluation
    log_filename = f"patch_{patch_size}_epochs_{n_epochs}_sample_{sample_size}.txt"
    log_filepath = os.path.join(os.getcwd(), log_filename)

    try:
        print('Entering evaluation')
        results = model.trainer.test(model)[0]

        box_precision_run = results["box_precision"]
        box_recall_run = results["box_recall"]
        class_recall_run = results["class_recall"]

        with open(log_filepath, 'a') as log_file:
            log_file.write(str(box_precision_run) + "\n")
            log_file.write(str(box_recall_run) + "\n")
            log_file.write(str(class_recall_run) + "\n")

        key = get_logging_key(hyperparams)

        if logger is not None:
            logger.log(
                {
                    f"precision_{key}": box_precision_run,
                    f"box_recall_{key}": box_recall_run,
                }
            )

            artifact_logger_callback(logger, log_filepath, "evaluation_logs", "log_file", "Log file containing evaluation results")
        
        return box_precision_run, box_recall_run, class_recall_run

    except Exception as e:
        error_msg = f"Error during evaluation for patch_size {patch_size} and n_epochs {n_epochs}: {e}"
        if logger is not None:
            logger.log({"error_message": error_msg})

        setup_logger.exception(error_msg)

        with open(log_filepath, 'a') as log_file:
            log_file.write(error_msg + "\n")

        # Return NaN placeholders so the caller can still unpack three values
        nan = float('nan')
        return nan, nan, nan

In [None]:
WANDB_ACTIVE = False

precision_results = {}
box_recall = {}
class_recall = {}

dicts = (precision_results, box_recall, class_recall)

run_number = 1

# Main Loop
if __name__ == "__main__":

    for patch_size, n_epochs, sample_size, lr, batch_size in hyperparameter_grid:

        tmpdir = recreate_temp_directory(setup_logger)

        train_dir, test_dir, validation_count, label_set = split_dataset_train_test(TILING_DIRECTORY, f"tiling_{patch_size}", tmpdir, sample_size)

        setup_logger.debug(f'train_dir contents: {os.listdir(train_dir)}')

        test_examples = [f for f in os.listdir(test_dir) if f.endswith('.png')]
        train_examples = [f for f in os.listdir(train_dir) if f.endswith('.png')]

        setup_logger.debug(f'Test directory has {len(test_examples)} examples')
        setup_logger.debug(f'Train directory has {len(train_examples)} examples')

        run_hyperparams = {
            'patch_size': patch_size,
            'learning_rate': lr,
            'n_epochs': n_epochs,
            'sample_size': sample_size,
            'batch_size': batch_size
        }

        checkpoint_callback = ModelCheckpoint(
            dirpath="checkpoints",
            filename="model-{epoch:02d}-{box_precision:.2f}",
            save_top_k=1,
            monitor="val_loss",
            verbose=True,
            save_last=True,
        )

        logger = DefaultModelLogger(log_save_dir='model_logging')
        artifact_logging = DefaultModelLogger.default_artifact_logging

        if WANDB_ACTIVE:
            with wandb.init(project='LargeRasterOR', entity='entity', name=RUN_BASE_NAME.format(run_number),
                config=run_hyperparams, settings=wandb.Settings(start_method="thread")) as run:

                logger = WandbLogger(project="", entity="", log_model="", experiment=run)

                def wandb_artifact_logging(logger: LoggerProtocol, filepath: str, name: str, type: str, description: str):
                    artifact = wandb.Artifact(
                        name=name,
                        type=type,
                        description=description
                    )
                    
                    artifact.add_file(filepath)

                    run.log_artifact(artifact)

                artifact_logging = wandb_artifact_logging
                
                checkpoint_callback = ModelCheckpoint(
                    dirpath=os.path.join(run.dir, "checkpoints"),
                    filename="model-{epoch:02d}-{box_precision:.2f}",
                    save_top_k=1,
                    monitor="val_loss",
                    verbose=True,
                    save_last=True,
                )

                box_precision_run, box_recall_run, class_recall_run = train_iteration(run_hyperparams, train_dir, test_dir, label_set, logger, artifact_logging, [checkpoint_callback])
        else:
            box_precision_run, box_recall_run, class_recall_run = train_iteration(run_hyperparams, train_dir, test_dir, label_set, logger, artifact_logging, [checkpoint_callback])

        key = get_logging_key(run_hyperparams)
        precision_results[key] = box_precision_run
        box_recall[key] = box_recall_run
        class_recall[key] = class_recall_run

        run_number += 1


    setup_logger.info("\nPrecision Results:")
    for key, value in precision_results.items():
        setup_logger.info(f"{key}: {value:.4f}")