# Training the Extended Mask2Former on the UAV-SOD Drone Dataset

In [1]:
# Import libraries
import pandas as pd
import os, json, statistics, time
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from src.data_set_up import SOD_Data
from models.extended_mask2former_model import ExtendedMask2Former
from models.efpn_backbone.anchors import Anchors
from src.helpers import train, evaluate_model


# Locations of the inputs this notebook reads
base_dir = "data/uav_sod_data/"  # root folder holding the dataset splits
data_info_path = "src/data_info/uav_data_preprocessing.json"  # mean / std of the data
map_path = "src/code_map.json"  # category-id -> name mapping

### Select the compute device (GPU if available, otherwise CPU)

In [2]:
# Pick the device that will host both the model and the data:
# CUDA when a GPU is visible to PyTorch, the CPU otherwise
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

## Set up basic static data

- Get the number of classes (plus one for the background)
- Get the mean and standard deviation of the training data
- Create the data paths for the train, test and validation splits

In [3]:
# Load the classes of the UAV-SOD Drone dataset.
# A context manager guarantees the file is closed even if parsing fails, and
# the handle is deliberately not called `map` so the builtin stays usable.
with open(map_path, encoding="utf-8") as code_map_file:
    code_map = json.load(code_map_file)
classes = code_map["UAV_SOD_DRONE"]["CATEGORY_ID_TO_NAME"]


# The number of classes plus the background
number_classes = len(classes) + 1


# Load the mean and standard deviation for the train data
with open(data_info_path, encoding="utf-8") as data_info_file:
    data_info = json.load(data_info_file)
mean = data_info["uav_data"]["mean"]
standard_deviation = data_info["uav_data"]["std"]


# Define train, test and validation path (one definition per split)
train_path = os.path.join(base_dir, "train")
test_path = os.path.join(base_dir, "test")
validation_path = os.path.join(base_dir, "validation")

## Dataset - Dataloader
- Collate function
- Data transformations
- DataLoader and Dataset

In [4]:
# Per-split preprocessing pipelines. Only the training split gets colour-jitter
# augmentation; every split is converted to a tensor and normalised with the
# mean / standard deviation loaded earlier.
data_transform = {
    "train": transforms.Compose([
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=standard_deviation),
    ]),
    **{
        split_name: transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=standard_deviation),
        ])
        for split_name in ("test", "validation")
    },
}


def _collate_batch(batch):
    """Regroup a list of per-sample tuples into one tuple per field.

    NOTE(review): presumably each sample is an (image, target) pair -- confirm
    against SOD_Data.__getitem__.
    """
    return tuple(zip(*batch))


# One dataset per split; each reads its own images/ and annotations/ folders
train_dataset = SOD_Data(
    f"{train_path}/images", f"{train_path}/annotations", data_transform['train']
)
test_dataset = SOD_Data(
    f"{test_path}/images", f"{test_path}/annotations", data_transform['test']
)
validation_dataset = SOD_Data(
    f"{validation_path}/images",
    f"{validation_path}/annotations",
    data_transform['validation'],
)

# Only the training loader shuffles; all loaders use batches of four samples
train_loader = DataLoader(
    train_dataset, batch_size=4, shuffle=True, collate_fn=_collate_batch
)
test_loader = DataLoader(
    test_dataset, batch_size=4, shuffle=False, collate_fn=_collate_batch
)
validation_loader = DataLoader(
    validation_dataset, batch_size=4, shuffle=False, collate_fn=_collate_batch
)

## Bounding Box Heuristics

To create representative anchors, we first compute the dataset's bounding-box statistics (the aspect ratios, and the mean and standard deviation of the width and height). Anchors that reflect these statistics help the model find the bounding boxes faster.


In [None]:
# Collect the main bounding-box statistics of the training set
bbox_stats = train_dataset.analyze_bounding_boxes()

# Pull out the mean and the standard deviation of the box width and height
mean_width, mean_height = bbox_stats['mean_width'], bbox_stats['mean_height']
std_width, std_height = bbox_stats['std_width'], bbox_stats['std_height']

# Report the distinct aspect ratios (sorted) followed by the size statistics
print("Aspect Ratios:", sorted(set(bbox_stats['aspect_ratios'])))
print("Mean Width:", mean_width)
print("Mean Height:", mean_height)
print("Width Std Dev:", std_width)
print("Height Std Dev:", std_height)

## Generate Anchors

In [None]:
# Anchor configuration, chosen by hand after inspecting the statistics above
feature_map_shapes = [(19, 19)]  # shape of each feature map that receives anchors
scales = [32]                    # anchor scales
aspect_ratios = [0.5, 1.0]       # anchor aspect ratios

# Generate the anchors and store them as a float32 tensor
generated_anchors = Anchors.generate_anchors(feature_map_shapes, scales, aspect_ratios)
anchors = torch.tensor(generated_anchors, dtype=torch.float32)

print("The number of anchors is: {}".format(anchors.shape[0]))

## Instantiate the ExtendedMask2Former model with all the required parameters

In [None]:
# Hyperparameters and the dataset tag handed to train()
num_epochs = 75
learning_rate = 0.001
dataset = "uav"

# Build the ExtendedMask2Former model, then move it and the anchors to the device
model = ExtendedMask2Former(
    num_classes=number_classes,
    num_anchors=anchors.size(0),
    device=device,
)
model = model.to(device)
anchors = anchors.to(device)

# Adam optimizer plus a step scheduler (learning rate x0.1 every 5 scheduler steps)
# NOTE(review): `scheduler` is not among the arguments of the train() call in
# this notebook, so it only takes effect if it is stepped elsewhere -- confirm.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

### Train Loop

In [None]:
# Fit the model on the training split; train() is expected to save the model
# and the related training info (see src.helpers)
train(
    model,
    train_loader,
    device,
    anchors,
    optimizer,
    num_epochs,
    dataset,
)

## Load Trained model

In [None]:
# Provide some time to save the model from the training
# NOTE(review): only needed if train() writes the checkpoint asynchronously -- confirm.
time.sleep(10)

# Load the trained model we saved before. `map_location=device` remaps the
# stored tensors onto the device selected at the top of the notebook, so a
# checkpoint written on a GPU can still be loaded on a CPU-only machine.
# NOTE(review): recent PyTorch releases default to `weights_only=True`, which
# rejects fully pickled modules; if this file stores the whole model, pass
# `weights_only=False` (only for checkpoints you produced yourself) -- confirm
# how train() saves it.
trained_model = torch.load("model_uav_75.pt", map_location=device)

### Test Loop

In [None]:
# Run the loaded model over the test split; evaluate_model() is expected to
# save the resulting evaluation info (see src.helpers)
evaluate_model(
    trained_model,
    test_loader,
    device,
    anchors,
)