# Training the Extended Mask2Former on the UAV-SOD Drone Dataset

In [4]:
# Import libraries
import os, json
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from src.data_set_up import SOD_Data
from models.extended_mask2former_model import ExtendedMask2Former


# Filesystem locations used by this notebook: the dataset root, the
# preprocessing statistics file, and the category-id -> name map.
base_dir = "data/uav_sod_data/"
data_info_path = "src/data_info/uav_data_preprocessing.json"
map_path = "src/code_map.json"

### Select the compute device

In [5]:
# Device that the model and every batch are moved to. Prefer an NVIDIA GPU
# (CUDA), then Apple-Silicon GPU (MPS), and fall back to the CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

### Load class names, normalization statistics, and dataset paths

In [7]:
# Load the category-id -> name map of the UAV-SOD Drone dataset.
# `with` closes the file even if json.load raises, and the handle is not
# bound to `map`, which would shadow the builtin of the same name.
with open(map_path, encoding="utf-8") as code_map_file:
    code_map = json.load(code_map_file)
classes = code_map["UAV_SOD_DRONE"]["CATEGORY_ID_TO_NAME"]

number_classes = len(classes)


# Load the mean and standard deviation of the train data; both are passed to
# transforms.Normalize below.
with open(data_info_path, encoding="utf-8") as data_info_file:
    data = json.load(data_info_file)
mean = data["uav_data"]["mean"]
standard_deviation = data["uav_data"]["std"]


# Define train, test and validation paths
train_path = os.path.join(base_dir, "train")
test_path = os.path.join(base_dir, "test")
validation_path = os.path.join(base_dir, "validation")

[0.47168887769178647, 0.5133378785061212, 0.5267694282792925]
[0.12757125130655708, 0.09660376509464505, 0.07690925977558358]


In [None]:

def collate_fn(batch):
    """Transpose a list of samples into one tuple per field.

    Given ``[(image_1, annotations_1), ..., (image_n, annotations_n)]`` this
    returns ``((image_1, ..., image_n), (annotations_1, ..., annotations_n))``.
    Items are kept as separate elements rather than being stacked into a
    single tensor. An empty batch yields ``()``.
    """
    columns = zip(*batch)
    return tuple(columns)

# Per-split image transforms. Only the training split is augmented; test and
# validation images are converted to tensors and normalized with the
# train-set statistics, so they share one transform object.
# NOTE(review): torchvision's RandomHorizontalFlip here acts on the image only.
# Unless SOD_Data also flips the matching boxes/masks, annotations will not
# line up with flipped images — confirm against SOD_Data.
eval_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=standard_deviation),
])

data_transform = {
    "train": transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=standard_deviation),
    ]),
    "test": eval_transform,
    "validation": eval_transform,
}


# Dataset and DataLoader.
# Sub-directory names must be relative: os.path.join discards every earlier
# component when it meets an absolute one, so "/images" would resolve to the
# filesystem root instead of "<split>/images".
train_dataset      = SOD_Data(os.path.join(train_path, "images"),      os.path.join(train_path, "annotations"),      transform=data_transform["train"])
test_dataset       = SOD_Data(os.path.join(test_path, "images"),       os.path.join(test_path, "annotations"),       transform=data_transform["test"])
validation_dataset = SOD_Data(os.path.join(validation_path, "images"), os.path.join(validation_path, "annotations"), transform=data_transform["validation"])

batch_size = 4

# collate_fn keeps variable-sized images/annotations as tuples instead of stacking.
train_loader      = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
test_loader       = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
validation_loader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

In [None]:

# Model. It must live on the same device as the batches below, otherwise the
# forward pass mixes CPU weights with device tensors.
# (assumes ExtendedMask2Former is a torch.nn.Module — it exposes .parameters()/.train())
model = ExtendedMask2Former(number_classes)
model.to(device)

# Optimizer and Scheduler: the learning rate is multiplied by 0.1 every 7 epochs.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# Training Loop
num_epochs = 10
model.train()

for epoch in range(num_epochs):
    running_loss = 0.0
    num_batches = 0

    for images, targets in train_loader:
        images = [image.to(device) for image in images]
        # assumes every value in a target dict is a tensor (anything else has
        # no .to) — TODO confirm against SOD_Data's annotation format
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        optimizer.zero_grad()
        outputs = model(images)
        loss = model.compute_loss(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    # Report the mean loss over the epoch rather than just the last batch;
    # max(..., 1) avoids dividing by zero if the loader yielded nothing.
    epoch_loss = running_loss / max(num_batches, 1)
    print(f"Epoch {epoch+1}, Loss: {epoch_loss}")
    scheduler.step()