### <center>**Reading and Cleaning Annotation Data for Custom PyTorch Object Detection**</center> 

In [1]:
# Import necessary packages
%matplotlib inline
import json
import os
import shutil
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

plt.ion(); # interactive mode

##### Load annotation data into dataframe

In [None]:
# Function for reading JSON as dictionary
def read_json(filename: str) -> dict:
    """Load a JSON file and return its parsed content.

    Args:
        filename: Path of the JSON file to read.

    Returns:
        The object produced by ``json.load`` for the file.

    Raises:
        Exception: If the file cannot be opened or parsed. The underlying
            error is chained as ``__cause__`` so its traceback is preserved.
    """
    try:
        # Read as UTF-8 explicitly instead of relying on the platform default
        # encoding, which can fail on non-ASCII text.
        with open(filename, "r", encoding="utf-8") as f:
            data = json.load(f)
    except Exception as e:
        raise Exception(f"Reading {filename} file encountered an error: {e}") from e
    return data

# Function to create a DataFrame from a list of records
def create_dataframe(data: list) -> pd.DataFrame:
    """Flatten a list of (possibly nested) records into a DataFrame.

    Args:
        data: Records to flatten, passed straight to ``pd.json_normalize``.

    Returns:
        DataFrame with one row per record; nested keys become dot-separated
        column names (e.g. ``bounding_box.h``).
    """
    flattened = pd.json_normalize(data)
    return flattened

# Main function to iterate over files in directory and add to df
def main(directory: str = "C:/Users/exx/Deep Learning/UAV_Waterfowl_Detection/Annotations/") -> pd.DataFrame:
    """Read every annotation file in ``directory`` into one DataFrame.

    Each file's ``annotations`` array is flattened, tagged with the image name
    it belongs to, and the bounding boxes are converted from
    (x, y, width, height) to (xmin, ymin, xmax, ymax).

    Args:
        directory: Folder containing the annotation JSON files. Defaults to
            the project's annotation directory.

    Returns:
        DataFrame with one row per annotation, including the columns
        ``img_name``, ``label``, ``xmin``, ``ymin``, ``xmax`` and ``ymax``.

    Raises:
        ValueError: If ``directory`` contains no files to read.
    """
    records = []

    # Iterate over files in directory
    for filename in os.listdir(directory):
        f = os.path.join(directory, filename)
        if not os.path.isfile(f):
            print(f"Skipping non-file: {filename}")
            continue

        # Read the JSON file as python dictionary
        data = read_json(filename=f)

        # Create the dataframe for the array items in annotations key
        file_df = create_dataframe(data=data['annotations'])

        # Derive the image name from the annotation file's stem rather than a
        # fixed-width slice of the path, which is only correct when every stem
        # is exactly 25 characters long.
        img_stem = os.path.splitext(filename)[0]
        file_df.insert(loc=0, column='img_name', value=f'{img_stem}.JPG')

        file_df = file_df.rename(columns={
            "name": "label",
            "bounding_box.h": "bbox_height",
            "bounding_box.w": "bbox_width",
            "bounding_box.x": "bbox_x_topLeft",
            "bounding_box.y": "bbox_y_topLeft",
            "polygon.paths": "polygon_path"
        })

        records.append(file_df)

    if not records:
        # pd.concat([]) would raise a far less helpful "No objects to concatenate".
        raise ValueError(f"No annotation files found in {directory}")

    # Concatenate all records into a single DataFrame
    annos_df = pd.concat(records, ignore_index=True)

    # Convert x, y, h, w to xmin, ymin, xmax, ymax
    annos_df['xmin'] = annos_df['bbox_x_topLeft']
    annos_df['ymin'] = annos_df['bbox_y_topLeft']
    annos_df['xmax'] = annos_df['bbox_x_topLeft'] + annos_df['bbox_width']
    annos_df['ymax'] = annos_df['bbox_y_topLeft'] + annos_df['bbox_height']

    # Drop unnecessary columns
    annos_df = annos_df.drop(columns=['bbox_height', 'bbox_width', 'bbox_x_topLeft',
                                      'bbox_y_topLeft', 'id', 'slot_names', 'polygon_path'])

    return annos_df

if __name__ == "__main__":
    df = main()
    print(df.head())

##### Pre-process annotation dataframe

In [3]:
# Get the unique image names
unique_img_names = df['img_name'].unique()

# Collect the images whose size is (5184, 3888); their annotations are dropped below
invalid_img_names = []
for img_name in unique_img_names:
    img_path = f'C:/Users/exx/Deep Learning/UAV_Waterfowl_Detection/Images/{img_name}'
    # Image.open only reads the header and keeps the file open; the context
    # manager closes each handle instead of leaking one per image.
    with Image.open(img_path) as img:
        if img.size == (5184, 3888):
            invalid_img_names.append(img_name)

# remove from df
df = df[~df['img_name'].isin(invalid_img_names)]

In [None]:
# Identify classes with fewer than 200 occurrences as negative classes
class_counts = df['label'].value_counts()
negative_classes = class_counts[class_counts < 200].index.tolist()

print(f'Total classes: {len(class_counts)}')
print(f'Negative classes: {len(negative_classes)}')
print(f'Positive classes: {len(class_counts) - len(negative_classes)}')

# Add 'Hen' to the list of negative classes
if 'Hen' not in negative_classes:
    negative_classes.append('Hen')

# Mark negative classes and 'Hen' as background (0), keep every other label as-is
target_raw = df['label'].where(~df['label'].isin(negative_classes), 0)

# Convert labels to categorical data and get the numeric codes.
# `assign` builds a new frame instead of writing into `df`, which may be a
# filtered slice of an earlier frame (chained assignment).
df = df.assign(target=pd.Categorical(target_raw).codes)

# filter out images with only negative classes
df = df.groupby('img_name').filter(lambda x: x['target'].ne(0).any())

# filter out images with invalid bounding boxes
df = df.groupby('img_name').filter(lambda x: ((x['xmin'] < x['xmax']) & (x['ymin'] < x['ymax'])).all())

# Create a dictionary using df['target'] (numeric code) as the keys and df['label'] (class name) as the values
label_dict = dict(zip(df['target'], df['label']))

# Code 0 collects several original class names; give it a single name
label_dict[0] = 'NEGATIVE'

# Drop the original 'label' column and rename 'target' to 'label'
df = df.drop(columns=['label']).rename(columns={'target': 'label'})

# Save df as csv in directory
df.to_csv('C:/Users/exx/Deep Learning/UAV_Waterfowl_Detection/RetinaNet/preprocessed_annotations.csv', index=False)

##### Filter images after pre-processing

In [5]:
# Names of the images that survived pre-processing
img_names = df['img_name'].unique().tolist()

# Target directory 'filtered_images': empty it if it already exists, create it otherwise
new_dir = 'C:/Users/exx/Deep Learning/UAV_Waterfowl_Detection/RetinaNet/filtered_images'
if os.path.exists(new_dir):
    for file in os.listdir(new_dir):
        stale_path = os.path.join(new_dir, file)
        os.remove(stale_path)
else:
    os.makedirs(new_dir)

# Copy every remaining image (with metadata) into the new directory
for img in img_names:
    source_path = f'C:/Users/exx/Deep Learning/UAV_Waterfowl_Detection/Images/{img}'
    shutil.copy2(source_path, new_dir)

### <center>**Transform and Augment Image and Annotation Data for Custom PyTorch Object Detection**</center> 

In [6]:
# import necessary packages
import numpy as np
from collections import defaultdict
import torchvision
torchvision.disable_beta_transforms_warning()
import torch
import torch.backends.cudnn as cudnn
cudnn.benchmark = True
from torchvision import transforms as _transforms, tv_tensors
import torchvision.transforms.v2 as T
from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights
import utils

In [7]:
class MAVdroneDataset(torch.utils.data.Dataset):
    """Dataset Loader for Waterfowl Drone Imagery

    One item corresponds to one unique image name in the annotation CSV,
    together with all bounding boxes and labels annotated on that image.
    """

    def __init__(self, csv_file, root_dir, transforms):
        """
        Arguments:
            csv_file (string): Path to the CSV file with annotations.
            root_dir (string): Directory containing all images.
            transforms (callable): Transformation to be applied on a sample.
        """
        self.df = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transforms = transforms
        # Item index -> image name, in order of first appearance in the CSV
        self.unique_image_names = self.df['img_name'].unique()

    def __getitem__(self, idx):
        """Return ``(image, target)`` for the ``idx``-th unique image.

        ``target`` is a dict with the keys ``boxes``, ``labels``,
        ``image_id``, ``area`` and ``iscrowd``. If ``transforms`` is set it is
        applied to ``(image, target)`` before returning.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        image_name = self.unique_image_names[idx]

        # Select this image's annotation rows once and reuse them below,
        # instead of re-scanning the whole DataFrame for each field.
        image_rows = self.df[self.df['img_name'] == image_name]

        image_path = os.path.join(self.root_dir, image_name)

        image = Image.open(image_path).convert('RGB')
        image = np.array(image, dtype=np.uint8)
        image = torch.from_numpy(image).permute(2, 0, 1)  # HWC -> CHW tensor

        # Bounding boxes and labels
        boxes = image_rows[['xmin', 'ymin', 'xmax', 'ymax']].values
        labels = image_rows['label'].values

        labels = torch.as_tensor(labels, dtype=torch.int64)  # (n_objects)
        boxes = torch.as_tensor(boxes, dtype=torch.float32)

        # Calculate area as (ymax - ymin) * (xmax - xmin)
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        # Assume no crowd annotations
        iscrowd = torch.zeros((len(labels),), dtype=torch.int64)

        # Create target dictionary
        target = {
            'boxes': tv_tensors.BoundingBoxes(boxes, format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=(image.shape[1], image.shape[2])),
            'labels': labels,
            'image_id': torch.tensor([idx]),
            'area': area,
            'iscrowd': iscrowd
        }

        image = tv_tensors.Image(image)

        if self.transforms:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self):
        """Number of unique images in the annotation CSV."""
        return len(self.unique_image_names)

##### Data transformation function

In [8]:
def get_transform(train: bool):
    """
    Build the torchvision v2 transform pipeline for a dataset split.

    Args:
        train (bool): Whether the transform is for training or validation/testing.
            When True, random augmentations are inserted between the dtype
            conversion and the final resize.

    Returns:
        A ``T.Compose`` of the selected transforms.
    """
    # Always start by converting to a float32 image scaled by ToDtype
    pipeline = [
        T.ToImage(),
        T.ToDtype(torch.float32, scale=True),
    ]

    if train:
        pipeline.extend([
            T.RandomHorizontalFlip(0.5),
            T.RandomApply([T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1)], p=0.25),
            # NOTE(review): kernel_size=1 looks like it makes this blur a no-op — confirm intent
            T.RandomApply([T.GaussianBlur(kernel_size=1, sigma=(0.05, 0.25))], p=0.25),
            T.RandomIoUCrop(min_scale=0.75, max_scale=1.5, min_aspect_ratio=16/9, max_aspect_ratio=16/9),
            # Clean up boxes after the crop
            T.ClampBoundingBoxes(),
            T.SanitizeBoundingBoxes(),
        ])

    # Final resize is applied to every split
    pipeline.append(T.Resize(size=(800,), max_size=1333, interpolation=T.InterpolationMode.BILINEAR))
    return T.Compose(pipeline)

##### Helper functions for plotting image and annotations

In [9]:
# classes are values in label_dict
# NOTE(review): len(classes) is later used as num_classes; this assumes the
# label codes left in label_dict are contiguous from 0 — confirm.
classes = list(label_dict.values())

# reverse label dictionary for mapping predictions to classes
rev_label_dict = {v: k for k, v in label_dict.items()}

# distinct colors
bbox_colors = ['#f032e6', '#ffffff', '#ffe119', '#3cb44b', '#42d4f4',
                    '#f58231', '#e6194B', '#dcbeff', '#469990', '#4363d8']

# label color map for plotting color-coded boxes by class; colors are reused
# cyclically so that more classes than colors does not raise an IndexError
label_color_map = {k: bbox_colors[i % len(bbox_colors)] for i, k in enumerate(label_dict.keys())}

# function for reshaping boxes 
def get_box(boxes):
    """Return ``boxes`` as a float array with four columns.

    The input is converted to a NumPy array and reshaped to ``(-1, 4)``. A
    single box is returned as shape ``(1, 4)``; anything else is passed
    through ``np.squeeze``.
    """
    as_rows = np.array(boxes).astype('float').reshape(-1, 4)
    is_single_box = as_rows.shape[0] == 1
    return as_rows if is_single_box else np.squeeze(as_rows)


# function for plotting image
def img_show(image, ax = None, figsize = (6, 9)):
    """Show ``image`` on ``ax`` with the x-axis ticks on top and return the axes.

    A new figure of size ``figsize`` is created when no axes is supplied.
    """
    target_ax = plt.subplots(figsize = figsize)[1] if ax is None else ax
    target_ax.xaxis.tick_top()
    target_ax.imshow(image)
    return target_ax
 

def plot_bbox(ax, boxes, labels):
    """Draw one bounding box and its class label on ``ax``.

    Args:
        ax: Matplotlib axes to draw on.
        boxes: Array-like holding one box as (xmin, ymin, xmax, ymax), e.g.
            the ``(1, 4)`` array returned by ``get_box``.
        labels: Object whose ``.item()`` yields the label id of the box.
    """
    # Unpack plain Python floats: handing 1-element arrays to matplotlib relies
    # on implicit array-to-scalar conversion, which NumPy has deprecated.
    xmin, ymin, xmax, ymax = (float(v) for v in np.asarray(boxes, dtype=float).reshape(-1)[:4])
    label_id = labels.item()

    # add box to the image and use label_color_map to color-code by bounding box class if exists else 'black'
    ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                    fill = False,
                    color = label_color_map.get(label_id, 'black'),
                    linewidth = 1.5))
    # add label text at the (xmax, ymax) corner using label_dict if label exists else the raw label id
    ax.text(xmax, ymax,
            label_dict.get(label_id, label_id),
            fontsize = 8,
            bbox = dict(facecolor = 'white', alpha = 0.8, pad = 0, edgecolor = 'none'),
            color = 'black')


# function for plotting all boxes and labels on the image using the get_box, img_show, and plot_bbox helpers
def plot_detections(image, boxes, labels, ax = None):
    """Show a channel-first ``image`` and overlay every box with its label.

    The image is permuted to channel-last before display; ``boxes[i]`` is
    paired with ``labels[i]``.
    """
    channel_last = image.permute(1, 2, 0)
    ax = img_show(channel_last, ax = ax)
    n_boxes = len(boxes)
    for idx in range(n_boxes):
        plot_bbox(ax, get_box(boxes[idx]), labels[idx])

##### Plot sample batch to confirm data loads and transforms correctly

In [10]:
# Build a sample dataset from the pre-processed annotations and the filtered
# images, using the training transforms
sample_dataset = MAVdroneDataset(
    csv_file='C:/Users/exx/Deep Learning/UAV_Waterfowl_Detection/RetinaNet/preprocessed_annotations.csv',
    root_dir='C:/Users/exx/Deep Learning/UAV_Waterfowl_Detection/RetinaNet/filtered_images',
    transforms=get_transform(train=True),
)

# Shuffled loader of 8 images per batch, collated with utils.collate_fn in the main process
sample_data_loader = torch.utils.data.DataLoader(
    sample_dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=utils.collate_fn,
    num_workers=0,
)

In [None]:
# store images and annotation targets from sample batch
batch = next(iter(sample_data_loader))
images, targets = batch
images = list(image for image in images)
targets = [{k: v for k, v in t.items()} for t in targets]

# Clamp pixel values to [0, 1] for display (torch-native instead of np.clip on tensors)
images = [image.clamp(0, 1) for image in images]

# Plot all samples from the batch in a grid of subplots.
# Size the grid from the images actually returned: a batch can be smaller than
# batch_size when the dataset has fewer images, which would otherwise raise an
# IndexError below.
n_samples = len(images)
plt.figure(figsize =(12, n_samples * 4))
for i in range(n_samples):
    ax = plt.subplot(n_samples, 2, 1 + i)
    plot_detections(images[i], targets[i]['boxes'], targets[i]['labels'], ax = ax)
    plt.axis('off')
    plt.title(f"Sample {i + 1}")

#### Load RetinaNet with ResNet FPN backbone and add alpha, gamma, and dropout params
##### Adapted from: https://arxiv.org/abs/1708.02002

In [12]:
import torch
import torch.nn as nn
from torchvision.models.detection import RetinaNet
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.models.detection.retinanet import RetinaNetClassificationHead, RetinaNetRegressionHead
from typing import Optional, Callable, List
from torchvision.ops import sigmoid_focal_loss

def _sum(x: List[torch.Tensor]) -> torch.Tensor:
    res = x[0]
    for i in x[1:]:
        res = res + i
    return res

class CustomRetinaNetClassificationHead(RetinaNetClassificationHead):
    """RetinaNet classification head with configurable focal-loss parameters and dropout.

    Extends ``RetinaNetClassificationHead`` by storing ``alpha`` and
    ``gamma_loss`` for the focal loss and by applying ``nn.Dropout`` between
    the ``conv`` stack and the final ``cls_logits`` layer.
    """

    def __init__(self, in_channels, num_anchors, num_classes, alpha=0.25, gamma_loss=2.0, prior_probability=0.01, norm_layer: Optional[Callable[..., nn.Module]] = None, dropout_prob=0.05):
        """
        Args:
            in_channels, num_anchors, num_classes, prior_probability, norm_layer:
                Forwarded unchanged to ``RetinaNetClassificationHead.__init__``.
            alpha: ``alpha`` passed to ``sigmoid_focal_loss`` in ``compute_loss``.
            gamma_loss: ``gamma`` passed to ``sigmoid_focal_loss`` in ``compute_loss``.
            dropout_prob: Probability ``p`` of the dropout layer used in ``forward``.
        """
        super().__init__(in_channels, num_anchors, num_classes, prior_probability, norm_layer)
        self.alpha = alpha
        self.gamma_loss = gamma_loss
        self.dropout = nn.Dropout(p=dropout_prob)

    def compute_loss(self, targets, head_outputs, matched_idxs):
        """Compute the focal classification loss averaged over the images.

        For each image the summed focal loss over the valid anchors is divided
        by ``max(1, num_foreground)``; the per-image losses are then summed and
        divided by ``len(targets)``.

        Args:
            targets: Per-image dicts; the ``"labels"`` entry is read here.
            head_outputs: Dict whose ``"cls_logits"`` entry holds the logits.
            matched_idxs: Per-image anchor-to-target indices; values ``>= 0``
                are treated as foreground matches.
        """
        losses = []

        cls_logits = head_outputs["cls_logits"]

        for targets_per_image, cls_logits_per_image, matched_idxs_per_image in zip(targets, cls_logits, matched_idxs):
            # determine only the foreground
            foreground_idxs_per_image = matched_idxs_per_image >= 0
            num_foreground = foreground_idxs_per_image.sum()

            # create the target classification
            # (one-hot: 1.0 at [foreground anchor, label of its matched target], 0 elsewhere)
            gt_classes_target = torch.zeros_like(cls_logits_per_image)
            gt_classes_target[
                foreground_idxs_per_image,
                targets_per_image["labels"][matched_idxs_per_image[foreground_idxs_per_image]],
            ] = 1.0

            # find indices for which anchors should be ignored
            # (BETWEEN_THRESHOLDS is not defined in this class — presumably inherited; verify)
            valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS

            # compute the classification loss with custom alpha and gamma_loss
            losses.append(
                sigmoid_focal_loss(
                    cls_logits_per_image[valid_idxs_per_image],
                    gt_classes_target[valid_idxs_per_image],
                    alpha=self.alpha,
                    gamma=self.gamma_loss,
                    reduction="sum",
                )
                / max(1, num_foreground)
            )

        return _sum(losses) / len(targets)

    def forward(self, x):
        """Return classification logits for all feature maps, concatenated along dim 1.

        Args:
            x: Iterable of feature maps; each is processed independently.

        Returns:
            Tensor of size (N, sum of H*W*A over the feature maps, K), where
            K is ``self.num_classes``.
        """
        all_cls_logits = []
        for features in x:
            cls_logits = self.conv(features)
            cls_logits = self.dropout(cls_logits)  # Apply dropout
            cls_logits = self.cls_logits(cls_logits)

            # Permute classification output from (N, A * K, H, W) to (N, HWA, K).
            N, _, H, W = cls_logits.shape
            cls_logits = cls_logits.view(N, -1, self.num_classes, H, W)
            cls_logits = cls_logits.permute(0, 3, 4, 1, 2)
            cls_logits = cls_logits.reshape(N, -1, self.num_classes)  # Size=(N, HWA, K)

            all_cls_logits.append(cls_logits)

        return torch.cat(all_cls_logits, dim=1)

class CustomRetinaNetRegressionHead(RetinaNetRegressionHead):
    """RetinaNet box-regression head with dropout before the final regression layer."""

    def __init__(self, in_channels, num_anchors, norm_layer: Optional[Callable[..., nn.Module]] = None, dropout_prob=0.05):
        """Initialise the parent head and create a dropout layer with probability ``dropout_prob``."""
        super().__init__(in_channels, num_anchors, norm_layer)
        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self, x):
        """Return box regression outputs for all feature maps, concatenated along dim 1.

        Each feature map goes through ``conv``, dropout and ``bbox_reg`` and is
        rearranged from (N, 4 * A, H, W) to (N, HWA, 4).
        """
        per_level_outputs = []
        for feature_map in x:
            hidden = self.dropout(self.conv(feature_map))  # dropout on the conv stack output
            deltas = self.bbox_reg(hidden)

            batch, _, height, width = deltas.shape
            deltas = (
                deltas.view(batch, -1, 4, height, width)
                .permute(0, 3, 4, 1, 2)
                .reshape(batch, -1, 4)  # Size=(N, HWA, 4)
            )
            per_level_outputs.append(deltas)

        return torch.cat(per_level_outputs, dim=1)

def get_retinanet_model(depth, num_classes, alpha=0.25, gamma_loss=2.0, trainable_backbone_layers=4, dropout_prob=0.05):
    """Build a RetinaNet with a ResNet-FPN backbone and the custom heads.

    Args:
        depth: ResNet depth; one of 18, 34, 50, 101 or 152.
        num_classes: Number of classes passed to ``RetinaNet`` and the classification head.
        alpha: Focal-loss ``alpha`` for the classification head.
        gamma_loss: Focal-loss ``gamma`` for the classification head.
        trainable_backbone_layers: ``trainable_layers`` passed to ``resnet_fpn_backbone``.
        dropout_prob: Dropout probability used in both custom heads.

    Returns:
        The assembled ``RetinaNet`` model.

    Raises:
        ValueError: If ``depth`` is not one of the supported values.
    """
    # (depth, backbone name, name of the weights enum in torchvision.models)
    supported_backbones = (
        (18, 'resnet18', 'ResNet18_Weights'),
        (34, 'resnet34', 'ResNet34_Weights'),
        (50, 'resnet50', 'ResNet50_Weights'),
        (101, 'resnet101', 'ResNet101_Weights'),
        (152, 'resnet152', 'ResNet152_Weights'),
    )
    for supported_depth, backbone_name, weights_enum in supported_backbones:
        if depth == supported_depth:
            break
    else:
        raise ValueError("Unsupported model depth")

    # Create the backbone with FPN
    backbone = resnet_fpn_backbone(
        backbone_name=backbone_name,
        weights=getattr(torchvision.models, weights_enum).DEFAULT,
        trainable_layers=trainable_backbone_layers,
    )

    # Create the RetinaNet model with the custom backbone
    model = RetinaNet(backbone, num_classes=num_classes)

    # Read the head dimensions from the default head before replacing it
    default_cls_head = model.head.classification_head
    in_channels = default_cls_head.cls_logits.in_channels
    num_anchors = default_cls_head.num_anchors

    # Swap in the custom classification and regression heads
    model.head.classification_head = CustomRetinaNetClassificationHead(
        in_channels, num_anchors, num_classes,
        alpha=alpha, gamma_loss=gamma_loss, dropout_prob=dropout_prob,
    )
    model.head.regression_head = CustomRetinaNetRegressionHead(
        in_channels, num_anchors, dropout_prob=dropout_prob,
    )

    return model

In [None]:
# Print the architecture of a ResNet-101 based model as a sanity check
print(
    get_retinanet_model(
        depth=101,
        num_classes=len(classes),
        alpha=0.5,
        gamma_loss=3.0,
        trainable_backbone_layers=4,
        dropout_prob=0.05,
    )
)

#### Use stratified sampling to split multi-label dataset into train, val, test sets

In [None]:
# Seeded generator (MT19937, seed 51) so the data splits are reproducible
rng = np.random.default_rng(np.random.MT19937(np.random.SeedSequence(51)))

# One group of annotation rows per image
image_groups = df.groupby('img_name')

# image name -> list of labels annotated on that image
image_class_distribution = {}
for image_name, group in image_groups:
    labels = group['label'].tolist()
    image_class_distribution[image_name] = labels

# Image names and their label lists, in the same order
all_images = list(image_class_distribution)
all_labels = list(image_class_distribution.values())

# Binary (image x class) presence matrix used for stratification
unique_labels = sorted(df['label'].unique())
label_to_index = dict(zip(unique_labels, range(len(unique_labels))))
binary_labels = np.zeros((len(all_images), len(unique_labels)), dtype=int)

for i, labels in enumerate(all_labels):
    for label in labels:
        column = label_to_index[label]
        binary_labels[i, column] = 1

# Split ratios for train / validation / test
train_ratio, val_ratio, test_ratio = 0.8, 0.15, 0.05

# Function to perform stratified sampling
def stratified_split(all_images, binary_labels, train_ratio, val_ratio, rng):
    """Greedily assign shuffled images to train, val and test splits.

    Images are visited in an order shuffled by ``rng``. An image goes to
    train if adding its label vector keeps every class count within
    ``train_ratio`` of that class's total; otherwise to val under the same
    test with ``val_ratio``; otherwise to test.

    Args:
        all_images: Sequence of images; only its length is used.
        binary_labels: (n_images, n_classes) presence matrix.
        train_ratio: Per-class share allowed in the train split.
        val_ratio: Per-class share allowed in the val split.
        rng: Generator providing ``shuffle``.

    Returns:
        Tuple ``(train_indices, val_indices, test_indices)`` of lists of
        positions into ``all_images``.
    """
    visit_order = np.arange(len(all_images))
    rng.shuffle(visit_order)

    totals = np.sum(binary_labels, axis=0)
    # Per-class capacity of the two capped splits
    train_cap = train_ratio * totals
    val_cap = val_ratio * totals

    members = {"train": [], "val": [], "test": []}
    running = {name: np.zeros_like(totals) for name in members}

    for sample in visit_order:
        presence = binary_labels[sample]
        if np.all(running["train"] + presence <= train_cap):
            destination = "train"
        elif np.all(running["val"] + presence <= val_cap):
            destination = "val"
        else:
            destination = "test"
        members[destination].append(sample)
        running[destination] += presence

    return members["train"], members["val"], members["test"]

# Run the stratified split; the returned indices are positions into all_images
train_indices, val_indices, test_indices = stratified_split(all_images, binary_labels, train_ratio, val_ratio, rng)

# Image name -> position in df['img_name'].unique()
image_to_unique_index = {name: position for position, name in enumerate(df['img_name'].unique())}

# Translate each split from all_images positions to positions in the unique-name order.
# NOTE: after this step the three index lists no longer index all_images.
train_indices, val_indices, test_indices = (
    [image_to_unique_index[all_images[position]] for position in split_positions]
    for split_positions in (train_indices, val_indices, test_indices)
)

# Function to get class distribution
def get_class_distribution(images, image_class_distribution):
    """Count how often each label occurs across ``images``.

    Args:
        images: Iterable of image names.
        image_class_distribution: Mapping from image name to its list of labels.

    Returns:
        ``defaultdict(int)`` mapping label -> number of occurrences.
    """
    tally = defaultdict(int)
    every_label = (lbl for name in images for lbl in image_class_distribution[name])
    for lbl in every_label:
        tally[lbl] += 1
    return tally

# Get train, val, and test images.
# The split indices were remapped to positions in df['img_name'].unique(), so
# the names must be looked up in that ordering. Indexing all_images (which is
# in groupby order) would select the wrong images when the two orders differ.
dataset_image_order = df['img_name'].unique().tolist()
train_images = [dataset_image_order[idx] for idx in train_indices]
val_images = [dataset_image_order[idx] for idx in val_indices]
test_images = [dataset_image_order[idx] for idx in test_indices]

train_class_distribution = get_class_distribution(train_images, image_class_distribution)
val_class_distribution = get_class_distribution(val_images, image_class_distribution)
test_class_distribution = get_class_distribution(test_images, image_class_distribution)

# label -> row indices of df carrying that label (iterating one column is far
# cheaper than iterrows and yields the same lists)
class_indices = {label: [] for label in df['label'].unique()}

for idx, label in df['label'].items():
    class_indices[label].append(idx)

# Express each split's count as a fraction of that class's total annotations
train_class_distribution = {k: v / len(class_indices[k]) for k, v in train_class_distribution.items()}
val_class_distribution = {k: v / len(class_indices[k]) for k, v in val_class_distribution.items()}
test_class_distribution = {k: v / len(class_indices[k]) for k, v in test_class_distribution.items()}

print("Train class distribution:", dict(sorted(train_class_distribution.items())))
print("Validation class distribution:", dict(sorted(val_class_distribution.items())))
print("Test class distribution:", dict(sorted(test_class_distribution.items())))

#### Create weighted random sampler to handle class imbalances during training

In [15]:
# train_labels is a list of lists, where each list contains the labels for an image
train_labels = [image_class_distribution[image] for image in train_images]

# Flatten the list of lists into a single list of labels
flattened_train_labels = [label for sublist in train_labels for label in sublist]

# Count occurrences per class (ignoring background/negative class 0), ordered by label id
positive_label_counts = pd.Series([label for label in flattened_train_labels if label != 0]).value_counts().sort_index()
train_class_counts = positive_label_counts.tolist()

# Calculate the total count of labels
train_total_count = sum(train_class_counts)

# Calculate class weights (inverse frequency), in the same order as train_class_counts
train_class_weights = [train_total_count / count for count in train_class_counts]
train_class_weights = torch.tensor(train_class_weights, dtype=torch.float32)

# Look weights up by label id instead of by `label - 1`: that positional index
# gave label 0 the last class's weight (index -1) and picked the wrong weight
# whenever the label ids present in the training split are not contiguous.
weight_by_label = {
    int(label): train_class_weights[position].item()
    for position, label in enumerate(positive_label_counts.index)
}

# Defensive fallback for an image without any positive label
fallback_weight = min(weight_by_label.values(), default=1.0)

# store label weights for each train image (negative-class boxes carry no weight)
train_label_weights = [
    torch.tensor([weight_by_label[label] for label in labels if label != 0], dtype=torch.float32)
    for labels in train_labels
]

# Calculate the average weight for each image
train_image_weights = [
    torch.mean(weights).item() if weights.numel() > 0 else fallback_weight
    for weights in train_label_weights
]

In [16]:
# Draw as many samples per epoch as there are training images, with replacement,
# in proportion to the per-image weights
train_sampler = torch.utils.data.WeightedRandomSampler(
    weights=train_image_weights,
    num_samples=len(train_image_weights),
    replacement=True,
)

### <center>**Tune Model Hyperparameters using Ray Tune**</center> 

##### Ray Tune trainable

In [17]:
from datetime import datetime
import gc
from engine_gradientAccumulation import train_one_epoch, evaluate
from concurrent.futures import ThreadPoolExecutor
from coco_utils import get_coco_api_from_dataset
import ray
from ray import train, tune
from ray.tune import JupyterNotebookReporter
from ray.tune.search import ConcurrencyLimiter
from ray.tune.schedulers.hb_bohb import HyperBandForBOHB
from ray.tune.search.bohb import TuneBOHB
from pathlib import Path
import ray.cloudpickle as pickle
import random

In [18]:
# Set random seed for reproducible training
def set_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch (CPU and all CUDA devices).

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    # NumPy's legacy global RNG was previously left unseeded, so any code
    # drawing from np.random.* was not reproducible.
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def calculate_f1_score(precision, recall):
    """Return the F1 score (harmonic mean) of ``precision`` and ``recall``.

    Returns 0.0 when ``precision + recall`` is zero.
    """
    denominator = precision + recall
    if denominator == 0:
        return 0.0
    numerator = 2 * (precision * recall)
    return numerator / denominator


def train_MAVdroneDataset(config):
    import ray
    import torch
    import utils
    import tempfile
    from pathlib import Path
    import ray.cloudpickle as pickle
    from torch_lr_finder import LRFinder, TrainDataLoaderIter

    # Custom data loader iterator
    class CustomTrainDataLoaderIter(TrainDataLoaderIter):
        def inputs_labels_from_batch(self, batch_data):
            inputs = [image.to('cuda') for image in batch_data[0]]
            labels = [{k: v.to('cuda') for k, v in t.items()} for t in batch_data[1]]
            return inputs, labels

    # function for finding optimal learning rate given hyperparameters
    def train_lr_finder(config, dataset_train, accumulation_steps):
        # Create data loader
        data_loader_train = torch.utils.data.DataLoader(dataset_train, batch_size=config["batch_size"],
                                                    sampler=config["train_sampler"],
                                                    collate_fn=utils.collate_fn,
                                                    num_workers=0, pin_memory=True)

        # Construct custom RetinaNet model (do this inside function to avoid conflicts with main function)
        model = get_retinanet_model(
            depth=config["resnet"], num_classes=len(classes), alpha=config["alpha"], 
            gamma_loss=config["gamma_loss"], trainable_backbone_layers=int(config["backbone_lyrs"]), 
            dropout_prob=0.05
        ).to('cuda')

        # Define the optimizer
        optimizer = torch.optim.SGD(
            model.parameters(), lr=1e-7, momentum=config["momentum"], weight_decay=config["weight_decay"]
        )

        # Create custom iterator
        train_iter = CustomTrainDataLoaderIter(data_loader_train)

        grad_scaler = torch.GradScaler()

        amp_config = {
            'device_type': 'cuda',
        }

        # Use LRFinder to find optimal learning rate
        class CustomLRFinder(LRFinder):
            def __init__(self, model, optimizer, criterion, device=None, amp_backend="native", amp_config=None, grad_scaler=None):
                super().__init__(model, optimizer, criterion, device)
                self.amp_backend = amp_backend
                self.amp_config = amp_config
                self.grad_scaler = grad_scaler or torch.GradScaler()

            def _train_batch(self, train_iter, accumulation_steps, non_blocking_transfer=True):
                self.model.train()
                total_loss = 0
                
                self.optimizer.zero_grad()
                for _ in range(accumulation_steps):
                    inputs, labels = next(train_iter)
                    inputs, labels = self._move_to_device(inputs, labels, non_blocking=non_blocking_transfer)

                    # Forward pass with mixed precision
                    with torch.autocast(device_type="cuda"):
                        outputs = self.model(inputs, labels)  # Ensure targets are passed here
                        loss = sum(loss for loss in outputs.values())  # Sum the losses

                    # loss should be averaged in each step
                    loss /= accumulation_steps

                    # Backward pass with mixed precision
                    self.grad_scaler.scale(loss).backward()

                    total_loss += loss

                self.grad_scaler.step(self.optimizer)
                self.grad_scaler.update()

                return total_loss.item()

        # Use the custom LRFinder
        lr_finder = CustomLRFinder(model, optimizer, None, device='cuda', amp_backend='torch', amp_config=amp_config, grad_scaler=grad_scaler)

        lr_finder.range_test(train_iter, 
                            end_lr=1, 
                            num_iter=100, 
                            step_mode='exp', 
                            accumulation_steps=accumulation_steps 
                            )

        # Plot the learning rate finder results
        ax, suggested_lr = lr_finder.plot()
        
        lr_finder.reset()

        return suggested_lr

    # Set random seed for reproducible training
    set_seed(51) 
    
    # Get dataset references directly from config
    dataset_train = ray.get(config["dataset_train_ref"])
    data_loader_val = ray.get(config["data_loader_val_ref"])
    train_coco_ds = ray.get(config["train_coco_ds_ref"])
    val_coco_ds = ray.get(config["val_coco_ds_ref"])

    # Use gradient accumulation due to memory constraints (batch_size=16 maxes out GPU)
    training_steps = [
        {"step": 0, "batch_size": config["batch_size"], "epochs": 20, "print_freq": 25, "accumulation_steps": 8}, # [ 4*8 ] --> 32
        {"step": 1, "batch_size": config["batch_size"], "epochs": 15, "print_freq": 25, "accumulation_steps": 16}, # [ 4*16 ] --> 64
        {"step": 2, "batch_size": config["batch_size"], "epochs": 10, "print_freq": 25, "accumulation_steps": 32}, # [ 4*32 ] --> 128
        {"step": 3, "batch_size": config["batch_size"], "epochs": 5, "print_freq": 25, "accumulation_steps": 64} # [ 4*64 ] --> 256
    ]

    # Determine the optimal learning rate given hyperparameter configuration
    suggested_lr = train_lr_finder(config, dataset_train, training_steps[0]["accumulation_steps"])

    # add optimal lr to config
    config["lr"] = suggested_lr

    # Construct custom RetinaNet model
    model = get_retinanet_model(depth=config["resnet"],
                                num_classes=len(classes), 
                                alpha=config["alpha"], 
                                gamma_loss=config["gamma_loss"],
                                trainable_backbone_layers=int(config["backbone_lyrs"]),
                                dropout_prob=0.05)

    # Explicitly set CUDA_VISIBLE_DEVICES
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

    # Force reinitialization of CUDA context
    torch.cuda.device_count()

    device = "cpu" 
    if torch.cuda.is_available():
        device = "cuda:0"
    model.to(device)

    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    # Construct an optimizer
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=suggested_lr,
                                momentum=config["momentum"], 
                                weight_decay=config["weight_decay"])

    # Learning rate scheduler
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=int(config["step_size"]),
                                                   gamma=config["gamma_lr"])

    # Load checkpoint if available
    checkpoint = train.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            data_path = Path(checkpoint_dir) / "data.pkl"
            with open(data_path, "rb") as fp:
                checkpoint_state = pickle.load(fp)
            start_epoch = checkpoint_state["epoch"] + 1
            model.load_state_dict(checkpoint_state["model_state_dict"])
            optimizer.load_state_dict(checkpoint_state["optimizer_state_dict"])
            current_step = checkpoint_state["current_step"]
            batch_size = checkpoint_state["batch_size"]
            accumulation_steps = checkpoint_state["accumulation_steps"]
    else:
        start_epoch = 0
        current_step = 0
        batch_size = config["batch_size"]
        accumulation_steps = 1

    # Initialize step index
    step_index = current_step

    # loop through training_steps during training to increase batch size
    while step_index < len(training_steps):
        step = training_steps[step_index]

        batch_size = step['batch_size']
        total_epochs = step['epochs']
        print_freq = step['print_freq']
        accumulation_steps = step['accumulation_steps']

        # Calculate the remaining epochs for the current step
        remaining_epochs = total_epochs - (start_epoch % total_epochs)

        # use the sampler for weighted sampling to address class imbalances
        data_loader = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size,
                                                  sampler=config["train_sampler"],
                                                  collate_fn=utils.collate_fn,
                                                  num_workers=0, pin_memory=True)

        print(f'Training step {step["step"]}... batch size: {batch_size*accumulation_steps}')

        for epoch in range(start_epoch, start_epoch + remaining_epochs):
            train_metric_logger, val_metric_logger = train_one_epoch(model, optimizer, data_loader, device,
                                                                     epoch, print_freq, accumulation_steps,
                                                                     data_loader_val)

            # evaluate on the val dataset
            train_coco_evaluator, val_coco_evaluator = evaluate(model, data_loader_val, val_coco_ds, device, data_loader, train_coco_ds)

            # update the learning rate
            lr_scheduler.step()

            checkpoint_data = {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "current_step": step["step"],
                "batch_size": batch_size,
                "accumulation_steps": accumulation_steps,
            }

            with tempfile.TemporaryDirectory() as checkpoint_dir:
                data_path = Path(checkpoint_dir) / "data.pkl"
                with open(data_path, "wb") as fp:
                    pickle.dump(checkpoint_data, fp)
                train.report(
                    {"epoch": epoch,
                     "lr": suggested_lr,
                     "train_loss": train_metric_logger.loss.avg,  # metric_logger object
                     "val_loss": val_metric_logger.loss.avg,
                     "train_mAP": train_coco_evaluator.coco_eval['bbox'].stats[0],  # mAP (IoU=0.50:0.95)
                     "val_mAP": val_coco_evaluator.coco_eval['bbox'].stats[0],
                     "train_mAR": train_coco_evaluator.coco_eval['bbox'].stats[8],  # mAR (IoU=0.50:0.95)
                     "val_mAR": val_coco_evaluator.coco_eval['bbox'].stats[8],
                     "train_f1": calculate_f1_score(train_coco_evaluator.coco_eval['bbox'].stats[0],  # AP (IoU=0.50:0.95)
                                                    train_coco_evaluator.coco_eval['bbox'].stats[8]  # AR (IoU=0.50:0.95)
                                                    ),
                     "val_f1": calculate_f1_score(val_coco_evaluator.coco_eval['bbox'].stats[0],
                                                  val_coco_evaluator.coco_eval['bbox'].stats[8]
                                                  )},
                     checkpoint=train.Checkpoint.from_directory(checkpoint_dir),
                )

        # set start_epoch to the next epoch for the next training step
        start_epoch += remaining_epochs
        step_index += 1

    print('Tuning Trial Complete!')


# test set accuracy of best model
def test_best_model(best_trial, best_checkpoint):
    """Score the best tuning trial's checkpoint on the held-out test split.

    A RetinaNet is rebuilt from the hyperparameters in ``best_trial.config``,
    the weights pickled in the checkpoint's ``data.pkl`` are loaded into it,
    and ``evaluate`` is run on the test loader / COCO dataset that the trial
    config references in the Ray object store. mAP, mAR and the F1 score
    derived from them are printed.

    :param best_trial: tuning result whose ``config`` carries the tuned
        hyperparameters plus the ``data_loader_test_ref`` / ``test_coco_ds_ref``
        object refs
    :param best_checkpoint: checkpoint exposing ``as_directory()``; its
        ``data.pkl`` must contain a ``"model_state_dict"`` entry
    :return: None
    """
    cfg = best_trial.config
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # Same architecture the trial trained, so the stored weights fit
    best_model = get_retinanet_model(
        depth=cfg["resnet"],
        num_classes=len(classes),
        alpha=cfg["alpha"],
        gamma_loss=cfg["gamma_loss"],
        trainable_backbone_layers=int(cfg["backbone_lyrs"]),
        dropout_prob=0.05,
    )
    best_model.to(device)

    # The checkpoint is a pickle written by the tuning trainable itself;
    # pickle must only ever be used on files from a trusted source.
    with best_checkpoint.as_directory() as checkpoint_dir:
        with open(Path(checkpoint_dir) / "data.pkl", "rb") as fp:
            best_checkpoint_data = pickle.load(fp)

        best_model.load_state_dict(best_checkpoint_data["model_state_dict"])

    # Test data lives in the Ray object store; resolve the refs from the config
    data_loader_test = ray.get(cfg["data_loader_test_ref"])
    test_coco_ds = ray.get(cfg["test_coco_ds_ref"])

    # No training loader/dataset is supplied, so only the test split is evaluated
    test_results = evaluate(best_model, data_loader_test, test_coco_ds, device,
                            train_data_loader=None, train_coco_ds=None)

    bbox_stats = test_results.coco_eval["bbox"].stats
    mean_ap = bbox_stats[0]  # IoU=0.50:0.95
    mean_ar = bbox_stats[8]  # IoU=0.50:0.95

    print(f'Best trial test set mAP: {mean_ap}')
    print(f'Best trial test set mAR: {mean_ar}')
    print(f'Best trial test set f1-score: {calculate_f1_score(mean_ap, mean_ar)}')


def trial_dirname_creator(trial):
    """Return the directory name for a trial: just its trial ID as text.

    :param trial: object exposing a ``trial_id`` attribute
    :return: the trial ID formatted as a string
    """
    trial_id = trial.trial_id
    return "{}".format(trial_id)


def create_coco_datasets(train_dataset, val_dataset, test_dataset):
    """
    Convert the three dataset splits to COCO API objects.

    Each split is passed to get_coco_api_from_dataset on a thread pool, so the
    three conversions are submitted together before any result is awaited.

    :param train_dataset: torch.utils.data.Dataset
    :param val_dataset: torch.utils.data.Dataset
    :param test_dataset: torch.utils.data.Dataset
    :return: train_coco_ds, val_coco_ds, test_coco_ds
    """
    splits = (train_dataset, val_dataset, test_dataset)

    with ThreadPoolExecutor() as pool:
        # Queue every conversion first, then collect results in split order
        pending = [pool.submit(get_coco_api_from_dataset, split) for split in splits]
        train_coco_ds, val_coco_ds, test_coco_ds = [job.result() for job in pending]

    return train_coco_ds, val_coco_ds, test_coco_ds

##### Main Tuning Program

In [None]:
def main(num_samples, max_num_epochs, restore_path=""):
    """Run, or resume, the BOHB hyperparameter search for the RetinaNet detector.

    Builds the train/val/test subsets and their COCO ground-truth objects,
    puts them in the Ray object store, and tunes ``train_MAVdroneDataset``
    with ``TuneBOHB`` + ``HyperBandForBOHB``, maximising ``val_f1``. If
    ``restore_path`` is a restorable experiment directory the tuner is restored
    from it (unfinished trials resumed, errored trials not) instead of starting
    a new search. After fitting, the best trial's final metrics are printed and
    its best checkpoint is evaluated on the test set via ``test_best_model``.

    Reads notebook globals defined in earlier cells: ``train_indices``,
    ``val_indices``, ``test_indices`` and ``train_sampler``.

    :param num_samples: number of configurations handed to ``TuneConfig``
    :param max_num_epochs: ``max_t`` of the HyperBand scheduler, measured in
        ``training_iteration`` units
    :param restore_path: directory of a previous experiment to resume; a path
        that cannot be restored starts a fresh search
    :return: ``(train_coco_ds, val_coco_ds, test_coco_ds, results, best_trial)``
    """
    ray.shutdown()
    # Call-site recording is configured when Ray starts, so the flag has to be
    # exported before ray.init(); setting it afterwards does not affect the
    # session that is already running.
    os.environ["RAY_record_ref_creation_sites"] = "1"
    ray.init()
    print(ray._private.utils.get_ray_temp_dir())

    # All three splits read the same annotation file and image folder
    annotations_csv = 'C:/Users/exx/Deep Learning/UAV_Waterfowl_Detection/RetinaNet/preprocessed_annotations.csv'
    image_dir = 'C:/Users/exx/Deep Learning/UAV_Waterfowl_Detection/RetinaNet/filtered_images/'

    # Prepare datasets; only the training split uses the training transforms
    dataset = MAVdroneDataset(
        csv_file=annotations_csv,
        root_dir=image_dir,
        transforms=get_transform(train=True)
    )

    dataset_val = MAVdroneDataset(
        csv_file=annotations_csv,
        root_dir=image_dir,
        transforms=get_transform(train=False)
    )

    dataset_test = MAVdroneDataset(
        csv_file=annotations_csv,
        root_dir=image_dir,
        transforms=get_transform(train=False)
    )

    # Subset using a 80/15/5 split for train, validation, and test datasets
    dataset_train = torch.utils.data.Subset(dataset, train_indices)
    dataset_val = torch.utils.data.Subset(dataset_val, val_indices)
    dataset_test = torch.utils.data.Subset(dataset_test, test_indices)

    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, batch_size=1, shuffle=False,
        collate_fn=utils.collate_fn, num_workers=0, pin_memory=True
    )

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, shuffle=False,
        collate_fn=utils.collate_fn, num_workers=0, pin_memory=True
    )

    # Create COCO dataset objects for train, val, and test datasets
    train_coco_ds, val_coco_ds, test_coco_ds = create_coco_datasets(dataset_train, dataset_val, dataset_test)

    # Object refs belong to the Ray session started above, so they are created
    # anew on every call (also when an earlier experiment is being restored).
    dataset_train_ref = ray.put(dataset_train)
    data_loader_val_ref = ray.put(data_loader_val)
    data_loader_test_ref = ray.put(data_loader_test)
    train_coco_ds_ref = ray.put(train_coco_ds)
    val_coco_ds_ref = ray.put(val_coco_ds)
    test_coco_ds_ref = ray.put(test_coco_ds)

    # Search space (tune.* entries) plus constants and data refs passed to every trial
    config = {
        "resnet": tune.choice([18, 34, 50]),
        "momentum": tune.uniform(0.4, 0.9),
        "weight_decay": tune.loguniform(0.00001, 0.01),
        "step_size": tune.choice([5, 10, 15]),
        "gamma_lr": tune.uniform(0.1, 0.5),
        "alpha": tune.uniform(0.2, 0.8),
        "gamma_loss": tune.uniform(1.5, 3.5),
        "backbone_lyrs": tune.choice([2, 3, 4]),
        "batch_size": 4, # constant handle OOM errors when tuning trials concurrently
        "dataset_train_ref": dataset_train_ref,
        "data_loader_val_ref": data_loader_val_ref,
        "data_loader_test_ref": data_loader_test_ref,
        "train_coco_ds_ref": train_coco_ds_ref,
        "val_coco_ds_ref": val_coco_ds_ref,
        "test_coco_ds_ref": test_coco_ds_ref,
        "train_sampler": train_sampler
    }

    restore_dir = os.path.abspath(restore_path)

    if tune.Tuner.can_restore(restore_dir):
        tuner = tune.Tuner.restore(
            restore_dir,
            trainable=train_MAVdroneDataset,
            param_space=config,  # pass same config with new ObjectRefs
            resume_unfinished=True,
            resume_errored=False
        )
        print(f"Tuner Restored from {restore_path}")
    else:
        algo = TuneBOHB(
            points_to_evaluate=[
                {"resnet": 34,
                 "momentum": 0.9,
                 "weight_decay": 0.0005,
                 "step_size": 5,
                 "gamma_lr": 0.1,
                 "alpha": 0.7,
                 "gamma_loss": 3.0,
                 "backbone_lyrs": 4}
            ],  # starting point for search
            seed=51  # set for identical initial configurations
        )

        # At most two trials at a time; each one requests half a GPU below
        algo = ConcurrencyLimiter(algo, max_concurrent=2)

        scheduler = HyperBandForBOHB(
            time_attr="training_iteration",
            max_t=int(max_num_epochs),
            reduction_factor=4,
            stop_last_trials=False,
        )

        reporter = JupyterNotebookReporter(overwrite=True,
            metric_columns=["epoch", "lr", "train_loss", "val_loss", "train_mAP", "val_mAP", "train_mAR", "val_mAR", "train_f1", "val_f1"],
            parameter_columns=["resnet", "momentum", "weight_decay", "step_size", "gamma_lr", "batch_size", "alpha", "gamma_loss", "backbone_lyrs"],
            sort_by_metric=True
        )

        # Dictionary to store the reported val_f1 scores of each trial
        val_f1_history = defaultdict(list)

        def custom_stop(trial_id, result):
            """Return True when a trial's val_f1 has plateaued.

            Every reported ``val_f1`` is appended to the trial's history. Once
            five values exist, the newest one is compared with the one five
            entries back; a relative gain below 0.5% (or both being zero)
            stops the trial.
            """
            # Ensure the required keys are in the result dictionary
            required_keys = ["training_iteration", "val_f1"]
            if all(key in result for key in required_keys):
                # Append the current val_f1 score to the trial's history
                val_f1_history[trial_id].append(result["val_f1"])

                # Check if there are at least 5 epochs recorded
                if len(val_f1_history[trial_id]) >= 5:
                    # Calculate the improvement over the last 5 epochs
                    initial_f1 = val_f1_history[trial_id][-5]
                    current_f1 = val_f1_history[trial_id][-1]

                    # Check if initial_f1 is zero to avoid division by zero
                    if initial_f1 == 0:
                        if current_f1 == 0:
                            return True  # No improvement if both initial and current f1 are zero (Stop)
                        improvement = float('inf')  # Set improvement to infinity if initial_f1 is zero but current_f1 is not to avoid zero division
                    else:
                        improvement = (current_f1 - initial_f1) / initial_f1

                    # Check if the improvement is less than 0.5%
                    if improvement < 0.005:
                        return True # (Stop)
            return False

        tuner = tune.Tuner(
            tune.with_resources(
                train_MAVdroneDataset,
                resources={"cpu": 18.0, "gpu": float(1/2)} # each trial uses 18 CPUs and 0.5 GPU
            ),
            run_config=train.RunConfig(
                name=f"BOHB_RetinaNet_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
                failure_config=train.FailureConfig(max_failures=2),
                stop=custom_stop,
                progress_reporter=reporter,
            ),
            tune_config=tune.TuneConfig(
                metric="val_f1",
                mode="max",
                search_alg=algo,
                scheduler=scheduler,
                num_samples=int(num_samples),
                trial_dirname_creator=trial_dirname_creator
            ),
            param_space=config
        )
    results = tuner.fit()

    best_trial = results.get_best_result("val_f1", "max")

    print("Best trial config: {}".format(best_trial.config))
    print()
    print("Best trial final training loss: {}".format(best_trial.metrics["train_loss"]))
    print("Best trial final validation loss: {}".format(best_trial.metrics["val_loss"]))
    print("Best trial final training mAP: {}".format(best_trial.metrics["train_mAP"]))
    print("Best trial final validation mAP: {}".format(best_trial.metrics["val_mAP"]))
    print("Best trial final training mAR: {}".format(best_trial.metrics["train_mAR"]))
    print("Best trial final validation mAR: {}".format(best_trial.metrics["val_mAR"]))
    print("Best trial final training f1-score: {}".format(best_trial.metrics["train_f1"]))
    print("Best trial final validation f1-score: {}".format(best_trial.metrics["val_f1"]))
    print()

    # Score the checkpoint with the highest val_f1 of the best trial on the test split
    best_checkpoint = best_trial.get_best_checkpoint(metric="val_f1", mode="max")

    test_best_model(best_trial, best_checkpoint)

    return train_coco_ds, val_coco_ds, test_coco_ds, results, best_trial

if __name__ == "__main__":
    # Release cached GPU memory and unreachable Python objects before the search starts
    torch.cuda.empty_cache()
    gc.collect()

    # Run the hyperparameter search. The returned COCO datasets and best trial are kept
    # as notebook globals; the training cell further down passes them to its own main().
    # NOTE(review): the ".../FALSE" path looks like a placeholder that is not expected to
    # be restorable, which makes main() start a fresh search — confirm.
    train_coco_ds, val_coco_ds, test_coco_ds, results, best_trial = main(num_samples=100,
                                                                         max_num_epochs=50,
                                                                         restore_path="C:/Users/exx/ray_results/FALSE") # set restore_path to the path of the experiment to restore

### <center>**Train Model Using Tuned Hyperparameters**</center> 

In [None]:
from torch.utils.tensorboard import SummaryWriter
import torch.profiler

def _snapshot_state_to_cpu(state):
    """Return an independent CPU copy of a (possibly nested) state_dict.

    ``state_dict()`` hands back references to the live parameter and buffer
    tensors, which keep changing as training continues. This copies every
    tensor to CPU memory and rebuilds the surrounding dicts/lists so the
    snapshot is frozen at the moment it is taken.
    """
    if torch.is_tensor(state):
        return state.detach().to("cpu", copy=True)
    if isinstance(state, dict):
        snapshot = {key: _snapshot_state_to_cpu(value) for key, value in state.items()}
        if hasattr(state, "_metadata"):
            # Module state dicts are OrderedDicts carrying version metadata that
            # load_state_dict() consults; keep the type and the metadata.
            snapshot = type(state)(snapshot)
            snapshot._metadata = state._metadata
        return snapshot
    if isinstance(state, (list, tuple)):
        return type(state)(_snapshot_state_to_cpu(value) for value in state)
    return state


def main(train_coco_ds, val_coco_ds, best_trial):
    """Train the RetinaNet with the tuned hyperparameters and log to TensorBoard.

    Training runs through four stages that keep the per-step batch size from
    ``best_trial.config`` but raise the number of gradient-accumulation steps.
    After every epoch the model is evaluated on the train and validation data,
    scalars are written to TensorBoard, and a checkpoint dict (metrics plus
    CPU snapshots of the model and optimizer state) is appended to the result.

    Reads notebook globals defined in earlier cells: ``classes``,
    ``train_indices`` and ``val_indices``.

    :param train_coco_ds: COCO ground truth for the training subset, passed to ``evaluate``
    :param val_coco_ds: COCO ground truth for the validation subset, passed to ``evaluate``
    :param best_trial: tuning result; its ``config`` supplies the hyperparameters,
        ``batch_size`` and ``train_sampler``
    :return: list with one checkpoint dict per epoch
    """
    # Set seed
    set_seed(51)

    batch_size = best_trial.config["batch_size"]

    # Effective batch size per stage is batch_size * accumulation_steps
    training_steps = [
        {"step": 0, "batch_size": batch_size, "epochs": 20, "print_freq": 25, "accumulation_steps": 4},
        {"step": 1, "batch_size": batch_size, "epochs": 15, "print_freq": 25, "accumulation_steps": 8},
        {"step": 2, "batch_size": batch_size, "epochs": 10, "print_freq": 25, "accumulation_steps": 16},
        {"step": 3, "batch_size": batch_size, "epochs": 5, "print_freq": 25, "accumulation_steps": 32}
    ]

    # load model
    model = get_retinanet_model(depth = best_trial.config["resnet"],
                                num_classes=len(classes),
                                alpha = best_trial.config["alpha"],
                                gamma_loss=best_trial.config["gamma_loss"],
                                trainable_backbone_layers = int(best_trial.config["backbone_lyrs"]),
                                dropout_prob = 0.05)

    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
    model.to(device)

    # The tuning trainable writes the LR-finder result into its own copy of the
    # config and also reports it as the "lr" metric. Fall back to the metric in
    # case the config held by the driver does not carry the "lr" key.
    learning_rate = best_trial.config.get("lr")
    if learning_rate is None:
        learning_rate = best_trial.metrics["lr"]

    # construct an optimizer - SGD w/ momentum and weight decay
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=learning_rate,
                                momentum=best_trial.config["momentum"],
                                weight_decay=best_trial.config["weight_decay"])

    # and a learning rate scheduler
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size=best_trial.config["step_size"],
                                                    gamma=best_trial.config["gamma_lr"])

    # initialize tensorboard writer in folder named f"{current_datetime}" and using name "RetinaNet"
    current_datetime = datetime.now().strftime("%Y%m%d-%H%M%S")
    writer = SummaryWriter(log_dir=f'C:/Users/exx/Documents/GitHub/SSD_VGG_PyTorch/runs/RetinaNet/{current_datetime}')

    # Store one checkpoint dictionary for each epoch in a list of dictionaries.
    checkpoints = []

    # Initialize the profiler (https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html)
    # NOTE(review): profiler.step() is never called in this function and the profiler is
    # not passed to train_one_epoch, so the wait/warmup/active schedule presumably never
    # advances and no trace is written — confirm, and add step() calls where profiling
    # is actually wanted.
    profiler = torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ],
        schedule=torch.profiler.schedule(
            wait=1,
            warmup=1,
            active=3,
            repeat=2),
        on_trace_ready=torch.profiler.tensorboard_trace_handler(writer.log_dir),
        record_shapes=True,
        profile_memory=True,
        with_stack=True
    )

    # load training and val datasets (same files, different transforms)
    annotations_csv = 'C:/Users/exx/Deep Learning/UAV_Waterfowl_Detection/RetinaNet/preprocessed_annotations.csv'
    image_dir = 'C:/Users/exx/Deep Learning/UAV_Waterfowl_Detection/RetinaNet/filtered_images/'

    dataset = MAVdroneDataset(csv_file=annotations_csv,
                              root_dir=image_dir,
                              transforms=get_transform(train=True))

    dataset_val = MAVdroneDataset(csv_file=annotations_csv,
                                  root_dir=image_dir,
                                  transforms=get_transform(train=False))

    # subset using a 80/15/5 split for train, validation, and test datasets
    dataset = torch.utils.data.Subset(dataset, train_indices)
    dataset_val = torch.utils.data.Subset(dataset_val, val_indices)

    data_loader_val = torch.utils.data.DataLoader(dataset_val, batch_size=1, shuffle=False,
                                                  collate_fn=utils.collate_fn, num_workers=0,
                                                  pin_memory=True)

    start_epoch = 0

    # loop through training_steps during training to increase the effective batch size
    for step in training_steps:
        batch_size = step['batch_size']
        num_epochs = step['epochs']
        print_freq = step['print_freq']
        accumulation_steps = step['accumulation_steps']

        # training loader for this stage; the tuned sampler handles class imbalance
        data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                                  sampler=best_trial.config["train_sampler"],
                                                  collate_fn=utils.collate_fn, num_workers=0,
                                                  pin_memory=True)

        print(f'Beginning training step {step["step"]}... batch size: {batch_size*accumulation_steps}')

        #########################################################
        ##               The main training loop                ##
        #########################################################
        with profiler:
            for epoch in range(start_epoch, num_epochs + start_epoch):
                # Monitor memory usage at the start of the epoch
                print(f"Epoch {epoch} - Memory allocated: {torch.cuda.memory_allocated(device)} bytes")

                train_metric_logger, val_metric_logger = train_one_epoch(model, optimizer, data_loader, device,
                                                                         epoch, print_freq, accumulation_steps,
                                                                         data_loader_val)

                # evaluate on the validation dataset
                train_coco_evaluator, val_coco_evaluator = evaluate(model, data_loader_val, val_coco_ds, device,
                                                                    data_loader, train_coco_ds)

                # update the learning rate
                lr_scheduler.step()

                train_stats = train_coco_evaluator.coco_eval['bbox'].stats
                val_stats = val_coco_evaluator.coco_eval['bbox'].stats

                # store training and validation metrics in checkpoint dictionary.
                checkpoint = {
                    "epoch": epoch,
                    "train_loss": train_metric_logger.loss.avg, # average across entire training epoch
                    "train_bbox_loss": train_metric_logger.bbox_regression.avg,
                    "train_class_loss": train_metric_logger.classification.avg,
                    "val_loss": val_metric_logger.loss.avg,
                    "val_bbox_loss": val_metric_logger.bbox_regression.avg,
                    "val_class_loss": val_metric_logger.classification.avg,
                    "train_mAP_50": train_stats[1],
                    "train_mAR_100": train_stats[8],
                    "val_mAP_50": val_stats[1],
                    "val_mAR_100": val_stats[8],
                    "train_f1": calculate_f1_score(train_stats[0], # IoU=0.50:0.95
                                                   train_stats[8]  # IoU=0.50:0.95
                                                   ),
                    "val_f1": calculate_f1_score(val_stats[0],
                                                 val_stats[8]
                                                 ),
                    # state_dict() returns references to the live tensors; without a
                    # copy every stored epoch would end up holding the final weights.
                    # Snapshots are kept on the CPU so they do not consume GPU memory
                    # (host RAM grows by one model + optimizer state per epoch).
                    "model_state_dict": _snapshot_state_to_cpu(model.state_dict()),
                    "optimizer_state_dict": _snapshot_state_to_cpu(optimizer.state_dict())
                }

                # append checkpoint to checkpoints list
                checkpoints.append(checkpoint)

                # report training and validation scalars to tensorboard
                writer.add_scalar('Loss/Train', np.array(float(checkpoint["train_loss"])), epoch) # use tags to group scalars
                writer.add_scalar('Loss/Val', np.array(float(checkpoint["val_loss"])), epoch)
                writer.add_scalar('mAP@50/Train', np.array(float(checkpoint["train_mAP_50"])), epoch)
                writer.add_scalar('mAP@50/Val', np.array(float(checkpoint["val_mAP_50"])), epoch)
                writer.add_scalar('mAR@100/Train', np.array(float(checkpoint["train_mAR_100"])), epoch)
                writer.add_scalar('mAR@100/Val', np.array(float(checkpoint["val_mAR_100"])), epoch)
                writer.add_scalar('F1/Train', np.array(float(checkpoint["train_f1"])), epoch)
                writer.add_scalar('F1/Val', np.array(float(checkpoint["val_f1"])), epoch)

                # Clear CUDA cache and collect garbage to check for memory leaks
                torch.cuda.empty_cache()
                gc.collect()

                # Monitor memory usage at the end of the epoch
                print(f"Epoch {epoch} - Max memory allocated: {torch.cuda.max_memory_allocated(device)} bytes")

            # the next stage continues the global epoch count where this one ended
            start_epoch += num_epochs

    print('All Training Steps Complete!')

    # close tensorboard writer
    writer.close()

    return checkpoints

if __name__ == "__main__":
    # Release cached GPU memory and unreachable Python objects before training
    torch.cuda.empty_cache()
    gc.collect()

    # Set environment variable to avoid memory fragmentation
    # NOTE(review): PyTorch presumably reads PYTORCH_CUDA_ALLOC_CONF when its CUDA
    # allocator is first initialised; if an earlier cell in this kernel already used
    # the GPU this assignment may have no effect — confirm, or set it at the top of
    # the notebook.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    
    # Train with the tuned hyperparameters; uses the COCO datasets and best trial
    # returned by the tuning cell above. The result is one checkpoint dict per epoch.
    checkpoints = main(train_coco_ds, val_coco_ds, best_trial)

In [None]:
# Select the epoch checkpoint with the highest validation mAP@50
best_train_epoch = max(checkpoints, key=lambda ckpt: ckpt['val_mAP_50'])

# Rebuild the tuned architecture so the stored weights can be loaded into it
model = get_retinanet_model(
    depth=best_trial.config["resnet"],
    num_classes=len(classes),
    alpha=best_trial.config["alpha"],
    gamma_loss=best_trial.config["gamma_loss"],
    trainable_backbone_layers=int(best_trial.config["backbone_lyrs"]),
    dropout_prob=0.05,
)

# Restore the weights recorded for the selected epoch
model.load_state_dict(best_train_epoch["model_state_dict"])

# Write the weights to a date-stamped .pth file in the working directory.
# NOTE: the file name always says "ResNet50", whatever depth best_trial.config["resnet"] holds.
torch.save(model.state_dict(), f'RetinaNet_ResNet50_FPN_DuckNet_{datetime.now().strftime("%m%d%Y")}.pth')

In [None]:
# Build metric-only copies of the checkpoints for the text summary.
# A new dict is created per epoch: list.copy() is shallow, so deleting keys from the
# dicts of a copied list would also strip the state dicts from `checkpoints` (and
# from `best_train_epoch`) and make this cell fail with KeyError when re-run.
checkpoints_copy = [
    {key: value for key, value in c.items()
     if key not in ("model_state_dict", "optimizer_state_dict")}
    for c in checkpoints
]

# save checkpoints list to text file, one epoch dict per line
with open('C:/Users/exx/Deep Learning/UAV_Waterfowl_Detection/RetinaNet/tuned_model_checkpoints.txt', 'w') as f:
    for item in checkpoints_copy:
        f.write("%s\n" % item)

### <center>**Model Inference on Test Dataset**</center> 

##### Load the test dataset

In [None]:
# Map every test-split index to the name of its image
test_dict = {index: image_name for index, image_name in zip(test_indices, test_images)}

# Keep a plain-text backup of the mapping, one "index:image_name" pair per line
with open('C:/Users/exx/Deep Learning/UAV_Waterfowl_Detection/RetinaNet/test_dict.txt', 'w') as f:
    for key, value in test_dict.items():
        f.write(f'{key}:{value}\n')

In [None]:
# Test split: evaluation transforms only, restricted to test_indices
dataset_test = torch.utils.data.Subset(
    MAVdroneDataset(
        csv_file='C:/Users/exx/Deep Learning/UAV_Waterfowl_Detection/RetinaNet/preprocessed_annotations.csv',
        root_dir='C:/Users/exx/Deep Learning/UAV_Waterfowl_Detection/RetinaNet/filtered_images/',
        transforms=get_transform(train=False),
    ),
    test_indices,
)

# One image per batch, in dataset order
data_loader_test = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=1,
    shuffle=False,
    collate_fn=utils.collate_fn,
    num_workers=0,
    pin_memory=True,
)

In [None]:
# Evaluate the reloaded model on the test split (CPU, no training data supplied)
test_performance = evaluate(
    model,
    data_loader_test,
    test_coco_ds,
    device=torch.device('cpu'),
    train_data_loader=None,
    train_coco_ds=None,
)

# COCO summary vector: [0] = mAP at IoU 0.50:0.95, [1] = mAP at IoU 0.50, [8] = mAR with 100 detections
test_bbox_stats = test_performance.coco_eval["bbox"].stats
print(f'Best trial test set mAP_50: {test_bbox_stats[1]}')
print(f'Best trial test set mAR_100: {test_bbox_stats[8]}')
print(f'Best trial test set f1 score: {calculate_f1_score(test_bbox_stats[0], test_bbox_stats[8])}')

##### Calculate performance metrics on every image in test dataset

In [None]:
from torchmetrics.detection.mean_ap import MeanAveragePrecision

In [None]:
# One entry per test image: the metric dict returned by compute() plus the image name
results = []

# COCO-style bbox mAP/mAR with per-class values, at 1, 10 and 100 detections per image
metric = MeanAveragePrecision(iou_type="bbox",
                              class_metrics=True,
                              max_detection_thresholds=[1, 10, 100]
                              )

# Inference runs on the CPU with the model in eval mode
model.to('cpu')
model.eval()

for images, targets in data_loader_test:
    # image ids of the batch (the test loader uses batch_size=1, so a single id);
    # used below to look up the image name in test_dict
    # assumes target['image_id'] equals the dataset index used as key in test_dict — TODO confirm
    image_id = [target['image_id'].item() for target in targets]

    # convert boxes in targets to tensors
    targets = [{k: torch.tensor(v) if k == 'boxes' else v for k, v in t.items()} for t in targets]

    # filter targets to only include boxes and labels keys
    ground_truth = [{k: v for k, v in t.items() if k in ('boxes', 'labels')} for t in targets]

    # forward pass without gradient tracking; the output is handed to the metric as predictions
    with torch.no_grad():
        prediction = model(images, targets)

    # Add this image to the metric state and recompute. The metric is never reset inside
    # the loop, so each compute() result covers all images processed so far, not just
    # the current one.
    metric.update(prediction, ground_truth)
    mean_AP = metric.compute()

    # append image name to mean_AP
    mean_AP['image_name'] = test_dict[image_id[0]]

    # Append the metric dict (now including the image name) to the results list.
    results.append(mean_AP)

##### Store per-image test dataset metrics as dataframe

In [None]:
# use pandas to create a dataframe of image names and mAP values
# DataFrame column name -> key of the corresponding scalar in each torchmetrics result
metric_columns = {
    'mAP': 'map',
    'mAP_50': 'map_50',
    'mAP_75': 'map_75',
    'mAP_small': 'map_small',
    'mAP_medium': 'map_medium',
    'mAP_large': 'map_large',
    'mAR_1': 'mar_1',
    'mAR_10': 'mar_10',
    'mAR_100': 'mar_100',
    'mAR_small': 'mar_small',
    'mAR_medium': 'mar_medium',
    'mAR_large': 'mar_large',
}

# One row per entry in results: the image name followed by the scalar metrics
img_results_df = pd.DataFrame()
img_results_df['image_name'] = [result['image_name'] for result in results]
for column_name, metric_key in metric_columns.items():
    img_results_df[column_name] = [result[metric_key].item() for result in results]

# -1.0 is the placeholder for metrics that could not be computed; store NaN instead
img_results_df = img_results_df.replace(-1.0, np.nan)

In [None]:
# The torchmetrics state accumulates over every update(), so the last row already holds
# the values for the complete test set; keep its numeric entries only.
final_metrics = img_results_df.iloc[-1].drop('image_name')

##### Print final (cumulative) test dataset metrics as table

In [None]:
from prettytable import PrettyTable

# Render the final test-set metrics (`final_metrics`, a Series indexed by
# metric name) as a two-column text table, print it, and save both the table
# and the full per-image dataframe to disk.

# create a pretty table object
x = PrettyTable()

cols = ['Metric', 'Value']

# add column headers
x.field_names = cols

# Column one holds the metric names, column two the values as percentages.
# Iterate label/value pairs directly rather than using `final_metrics[i]`:
# an integer key on a label-indexed Series only works through pandas'
# deprecated positional fallback and is slated to become a label lookup.
for metric_name, metric_value in final_metrics.items():
    x.add_row([metric_name, f'{metric_value*100:.2f}%'])

# print table
print(x)

# save table as txt file
with open('C:/Users/exx/Deep Learning/UAV_Waterfowl_Detection/RetinaNet/testDataset_image_summary_table.txt', 'w') as f:
    print(x, file = f)

# save the per-image results dataframe to csv
img_results_df.to_csv('C:/Users/exx/Deep Learning/UAV_Waterfowl_Detection/RetinaNet/per_image_results_test_dataset.csv', index = False)

##### Store per-class test dataset metrics as dataframe

In [None]:
# Build the per-class metrics table: one row per entry in `results`, holding
# the image name plus three parallel lists (class names, per-class mAP,
# per-class mAR@100).
#
# The three lists are filtered together, so entry k of each list always refers
# to the same class. Filtering them independently (classes by mAP only, mAR on
# its own) can leave the lists misaligned whenever only one of the two metrics
# is missing for a class.
class_records = []
for result in results:
    # tensors -> numpy arrays; atleast_1d keeps a 0-dim (scalar) tensor iterable
    class_ids = np.atleast_1d(result['classes'].numpy())
    ap_values = np.atleast_1d(result['map_per_class'].numpy())
    ar_values = np.atleast_1d(result['mar_100_per_class'].numpy())

    # replace -1.0 values with NaN
    ap_values = np.where(ap_values == -1.0, np.nan, ap_values)
    ar_values = np.where(ar_values == -1.0, np.nan, ar_values)

    # drop a class if either its mAP or its mAR is NaN; integer labels are
    # replaced with class names via label_dict
    kept = [
        (label_dict.get(class_id), ap, ar)
        for class_id, ap, ar in zip(class_ids, ap_values, ar_values)
        if not (np.isnan(ap) or np.isnan(ar))
    ]

    class_records.append({
        'image_name': result['image_name'],
        'classes': [name for name, _, _ in kept],
        'map_per_class': [ap for _, ap, _ in kept],
        'mar_100_per_class': [ar for _, _, ar in kept],
    })

# explicit column list keeps the expected columns even when `results` is empty
class_res_df = pd.DataFrame(
    class_records,
    columns = ['image_name', 'classes', 'map_per_class', 'mar_100_per_class'],
)

In [None]:
# Pull the class names, per-class mAP and per-class mAR@100 lists out of the
# last row of class_res_df in one step. Per the original author's note the
# metric values accumulate over the dataset, so the last image's row is used
# as the final result -- NOTE(review): confirm the metric is not reset per image.
classes, class_map, class_mar_100 = class_res_df[
    ['classes', 'map_per_class', 'mar_100_per_class']
].iloc[-1]

##### Print per-class metrics for test dataset as table

In [None]:
# Render the per-class results as a three-column text table
# (class name, mAP, mAR@100), print it, and write both the table and the full
# per-class dataframe to disk.
cols = ['Class', 'mAP', 'mAR_100']

x = PrettyTable()
x.field_names = cols

# one table row per class; the two metric lists are indexed in step with `classes`
for position, class_name in enumerate(classes):
    map_cell = f'{class_map[position]*100:.2f}%'
    mar_cell = f'{class_mar_100[position]*100:.2f}%'
    x.add_row([class_name, map_cell, mar_cell])

print(x)

# text copy of the table
with open('C:/Users/exx/Deep Learning/UAV_Waterfowl_Detection/RetinaNet/testDataset_class_summary_table.txt', 'w') as table_file:
    print(x, file = table_file)

# csv copy of the per-class dataframe
class_res_df.to_csv('C:/Users/exx/Deep Learning/UAV_Waterfowl_Detection/RetinaNet/per_class_results_test_dataset.csv', index = False)

##### Load test data into one batch

In [None]:
# Build a loader whose single batch is the whole test dataset (batch size equals
# the dataset length, no shuffling), then run the model once over that batch.
data_loader_test_singleBatch = torch.utils.data.DataLoader(
    dataset_test,
    batch_size = len(dataset_test),
    shuffle = False,
    collate_fn = utils.collate_fn,
    num_workers = 0,
)

# pull the one and only batch out of the loader
images, targets = next(iter(data_loader_test_singleBatch))
images = list(images)

# copy each target dict, wrapping its 'boxes' entry with torch.tensor
targets = [
    {key: torch.tensor(value) if key == 'boxes' else value for key, value in target.items()}
    for target in targets
]

# inference is done on the CPU with the model in evaluation mode
model.to('cpu')
model.eval()

# no gradients are needed for prediction
with torch.no_grad():
    predictions = model(images, targets)

##### Post-process model predictions for plotting on original images

In [None]:
# Post-process the batch predictions in place:
#   1. drop every detection whose score is not above the threshold, and
#   2. rescale the surviving boxes from the transformed-tensor resolution to
#      the resolution of the original image file.
# NOTE: this cell replaces each tensor in `images` with the opened PIL image
# and rescales `predictions` in place, so it must run exactly once after the
# inference cell; re-running it without re-running inference will fail.
score_threshold = 0.5

for pred in predictions:
    # build the mask once so boxes, labels and scores are filtered identically
    keep = pred['scores'] > score_threshold
    for key in ('boxes', 'labels', 'scores'):
        pred[key] = pred[key][keep]

# resize boxes to original image shape
for i in range(len(images)):
    # image tensors are laid out [C, H, W]: height is dim 1 and width is dim 2
    # (taking width from dim 1 mis-scales boxes on non-square inputs)
    tran_h, tran_w = images[i].shape[-2:]

    images[i] = Image.open('C:/Users/exx/Deep Learning/UAV_Waterfowl_Detection/RetinaNet/filtered_images/' + test_images[i])

    # PIL reports size as (width, height)
    orig_w, orig_h = images[i].size

    # per-coordinate scale factors in the order (x, y, x, y)
    scale = torch.tensor([orig_w / tran_w,
                          orig_h / tran_h,
                          orig_w / tran_w,
                          orig_h / tran_h]).view(1, 4)

    predictions[i]['boxes'] = predictions[i]['boxes'] * scale

### <center>**Plot Model Predictions for Images in Test Dataset**</center> 

In [None]:
def plot_bbox_predicted(ax, boxes, labels, scores): # modify plot_bbox to add confidence scores
    """Draw one predicted box with a "<class name> <score>" caption on `ax`.

    Args:
        ax: matplotlib axes to draw on.
        boxes: 2-D indexable whose columns 0..3 are read as left, top, right,
            bottom (presumably a single-row array from get_box -- TODO confirm).
        labels: single-element value exposing `.item()`; used as the key into
            label_color_map and label_dict.
        scores: single-element value exposing `.item()`; shown with two decimals.

    Returns:
        The same `ax`, after the rectangle and caption have been added.
    """
    class_id = labels.item()
    # colour comes from label_color_map when the class has an entry, else black
    box_color = label_color_map[class_id] if class_id in label_color_map else 'black'

    left, top = boxes[:, 0], boxes[:, 1]
    width = boxes[:, 2] - left
    height = boxes[:, 3] - top

    outline = plt.Rectangle((left, top), width, height,
                            fill = False,
                            color = box_color,
                            linewidth = 1.5)
    ax.add_patch(outline)

    # caption is the class name from label_dict followed by the score,
    # placed 100 data units above the box's top-left corner
    caption = f"{label_dict[class_id]} {scores.item():.2f}"
    ax.text(left, top - 100,
        s = caption,
        color = 'black',
        fontsize = 6,
        verticalalignment = 'top',
        bbox = {'color': box_color, 'pad': 0})
    return ax


# function for plotting all predictions on images
def plot_predictions(image, boxes, labels, scores, ax = None):
    """Show `image` and overlay every predicted box with its label and score.

    Args:
        image: image passed straight to img_show.
        boxes: sequence of boxes; each one is passed through get_box before drawing.
        labels: per-box labels, indexed in step with `boxes`.
        scores: per-box scores, indexed in step with `boxes`.
        ax: optional axes forwarded to img_show.
    """
    ax = img_show(image, ax = ax)
    for idx, raw_box in enumerate(boxes):
        plot_bbox_predicted(ax, get_box(raw_box), labels[idx], scores[idx])

In [None]:
# Plot up to 32 samples from the batch in an 8 x 4 grid of subplots.
# The count is capped at the number of available images so a test batch with
# fewer than 32 images does not raise an IndexError.
n_plots = min(32, len(images))

plt.figure(figsize = (24, 36))
for i in range(n_plots):
    ax = plt.subplot(8, 4, 1 + i)
    plot_predictions(images[i], predictions[i]['boxes'], predictions[i]['labels'], predictions[i]['scores'], ax = ax)
    plt.axis('off')
    plt.title(test_images[i])

plt.show()

##### Run inference on full dataset to get model estimates of abundance

In [None]:
# Run the model over every image in the full dataset, one image per batch, and
# collect one record per image: the model outputs plus the image name.
dataset_all = MAVdroneDataset(csv_file = 'C:/Users/exx/Deep Learning/UAV_Waterfowl_Detection/RetinaNet/preprocessed_annotations.csv',
                                root_dir = 'C:/Users/exx/Deep Learning/UAV_Waterfowl_Detection/RetinaNet/filtered_images/', 
                                transforms = get_transform(train = False))

data_loader_all = torch.utils.data.DataLoader(dataset_all, batch_size = 1, shuffle = False,
                                            collate_fn = utils.collate_fn, num_workers = 0,
                                            pin_memory = True)

# make sure the model is in evaluation mode regardless of which cells ran before
model.eval()

# get model predictions for every image in data_loader_all
model_predictions_all = []

# loop variables are named batch_* so the `images` / `targets` objects created
# by the earlier plotting cells are not overwritten
for batch_images, batch_targets in data_loader_all:
    # image_id identifies which image this batch holds
    image_id = [target['image_id'].item() for target in batch_targets]

    # targets are not needed for inference, so only the images are passed
    with torch.no_grad():
        outputs = model(list(batch_images))

    # The model returns a list with one dict per image, and batch_size is 1, so
    # take element 0 (indexing the list with a string key raises TypeError).
    # Tensors are converted to plain lists so the saved CSV holds the full
    # values instead of an abbreviated tensor repr.
    prediction = {key: value.cpu().tolist() for key, value in outputs[0].items()}

    # attach the image name to the prediction record
    # NOTE(review): test_dict is defined elsewhere; its name suggests it maps
    # test-set image ids, while this loop runs over dataset_all -- confirm the
    # ids line up, or build an id -> name mapping for the full dataset.
    prediction['image_name'] = test_dict[image_id[0]]

    # append the record for this image to the results list
    model_predictions_all.append(prediction)

In [None]:
# Turn the collected per-image prediction records into a dataframe
# (one row per entry of model_predictions_all) ...
model_predictions_df = pd.DataFrame(data = model_predictions_all)

# ... and write it out, without the index column, for comparison with ground truth
model_predictions_df.to_csv(
    'C:/Users/exx/Deep Learning/UAV_Waterfowl_Detection/RetinaNet/model_predictions_full_dataset.csv',
    index = False,
)