# Training the Extended Mask2Former on the UAV-SOD Drone Dataset

In [1]:
# Import libraries
import os, json
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from src.data_set_up import SOD_Data
from models.extended_mask2former_model import ExtendedMask2Former
from models.efpn_backbone.anchors import Anchors


# Data locations, relative to the notebook's working directory
map_path, data_info_path, base_dir = (
    "src/code_map.json",                          # category-id -> name map
    "src/data_info/uav_data_preprocessing.json",  # training mean / std
    "data/uav_sod_data/",                         # root of the train/test/validation splits
)

  Referenced from: <2BD1B165-EC09-3F68-BCE4-8FE4E70CA7E2> /Users/stamatiosorphanos/Documents/MCs_Thesis/SOD_Thesis/master_thesis/lib/python3.11/site-packages/torchvision/image.so
  warn(


### Select the compute device (Apple MPS if available, otherwise CPU)

In [14]:
# Choose where the model and data live: Apple's Metal (MPS) backend when it is
# available, otherwise fall back to the CPU.
if torch.backends.mps.is_available():  # type: ignore
    device = "mps"
else:
    device = "cpu"

'mps'

### Set up basic static data

- Get the number of classes
- Load the mean and standard deviation of the training data
- Build the data paths for the train, test and validation splits

In [3]:
# Load the class map and the training-set normalization statistics.
# Each file is read through a `with` block, so it is closed even if
# json.load raises. Using a helper also avoids shadowing the built-in `map`
# and reusing one `data` name for two different JSON documents.
def _load_json(path):
    """Parse and return the JSON document stored at ``path``."""
    with open(path, encoding="utf-8") as json_file:
        return json.load(json_file)


# Load the classes of the UAV-SOD Drone dataset
code_map = _load_json(map_path)
classes = code_map["UAV_SOD_DRONE"]["CATEGORY_ID_TO_NAME"]

# The number of classes plus the background
number_classes = len(classes) + 1

# Load the mean and standard deviation for the train data
data_info = _load_json(data_info_path)
mean = data_info["uav_data"]["mean"]
standard_deviation = data_info["uav_data"]["std"]

# Define train, test and validation path
train_path = os.path.join(base_dir, "train")
test_path = os.path.join(base_dir, "test")
validation_path = os.path.join(base_dir, "validation")

### Datasets and DataLoaders
- Collate function
- Data transformations
- DataLoader and Dataset

In [4]:
# Data transform function: colour jitter augments only the training split.
# Every split is then converted to a tensor and normalized with the mean/std
# loaded from the data-info file.
def _tensor_and_normalize():
    """Return fresh ToTensor + Normalize steps shared by all three splits."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=standard_deviation),
    ]


data_transform = {
    "train": transforms.Compose(
        [transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)]
        + _tensor_and_normalize()
    ),
    "test": transforms.Compose(_tensor_and_normalize()),
    "validation": transforms.Compose(_tensor_and_normalize()),
}


def collate_fn(batch):
    """Transpose a list of ``(image, target)`` samples into ``(images, targets)``.

    Returning tuples instead of stacked tensors leaves stacking to the
    training loop. A named module-level function, unlike a lambda, can be
    pickled, which DataLoader requires when ``num_workers > 0`` under the
    ``spawn`` multiprocessing start method (the default on macOS).
    """
    return tuple(zip(*batch))


def _make_dataset(split_path, transform):
    """Build an SOD_Data split from ``<split_path>/images`` and ``<split_path>/annotations``."""
    return SOD_Data(
        os.path.join(split_path, "images"),
        os.path.join(split_path, "annotations"),
        transform,
    )


# Dataset and DataLoader
train_dataset      = _make_dataset(train_path, data_transform["train"])
test_dataset       = _make_dataset(test_path, data_transform["test"])
validation_dataset = _make_dataset(validation_path, data_transform["validation"])

# NOTE(review): these loaders use batch_size=1. The `batch_size`
# hyperparameter defined with the optimizer settings is not applied here.
train_loader      = DataLoader(train_dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)
test_loader       = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)
validation_loader = DataLoader(validation_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)

In [5]:
def validate(model, val_loader, device, anchors=None):
    """Compute the mean loss of ``model`` over ``val_loader`` without gradients.

    Args:
        model: Network exposing ``eval()``, ``__call__`` and ``compute_loss``.
        val_loader: Sized iterable of ``(images, targets)`` batches, where
            ``images`` is a sequence of tensors to stack and ``targets`` is a
            sequence of dicts of tensors.
        device: Device the batch tensors are moved to.
        anchors: Optional anchor tensor forwarded to ``model.compute_loss``,
            the same way ``train`` computes its loss. When ``None``, the loss is
            computed as ``compute_loss(outputs, targets)``, as before.

    Returns:
        The average per-batch loss as a Python float.

    Raises:
        ValueError: If ``val_loader`` contains no batches.

    The model is left in eval mode; ``train`` switches it back with
    ``model.train()``.
    """
    model.eval()
    num_batches = len(val_loader)
    if num_batches == 0:
        raise ValueError("val_loader is empty; cannot compute a validation loss")
    # Only pass anchors when supplied, so models whose compute_loss takes
    # just (outputs, targets) keep working.
    loss_args = () if anchors is None else (anchors,)
    val_loss = 0.0
    with torch.no_grad():
        for images, targets in val_loader:
            images = torch.stack(images).to(device)
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            outputs = model(images)
            loss = model.compute_loss(outputs, targets, *loss_args)
            val_loss += loss.item()
    return val_loss / num_batches

In [6]:
# Assuming feature_map_shapes, scales, and aspect_ratios are defined
feature_map_shapes = [(38, 38)]
scales = [32]
aspect_ratios = [0.5, 1, 2]

anchors = torch.tensor(Anchors.generate_anchors(feature_map_shapes, scales, aspect_ratios), dtype=torch.float32).to(device)

In [7]:
# Hyperparameters
num_epochs = 1
learning_rate = 0.001
# NOTE(review): batch_size is not used by the DataLoaders, which are built
# with batch_size=1; wire it in there if a larger batch is intended.
batch_size = 2

# Model, optimizer and learning-rate schedule. StepLR multiplies the lr by 0.1
# every 5 scheduler steps (one step per epoch in the training loop), so it has
# no effect while num_epochs < 5.
model = ExtendedMask2Former(num_classes=number_classes).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=5, gamma=0.1)

Loaded pretrained weights for efficientnet-b7


In [8]:
def train(model, train_loader, optimizer, device, anchors=None):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: Network exposing ``train()``, ``__call__`` and ``compute_loss``.
        train_loader: Sized iterable of ``(images, targets)`` batches, where
            ``images`` is a sequence of tensors to stack and ``targets`` is a
            sequence of dicts of tensors.
        optimizer: Optimizer updated once per batch.
        device: Device the batch tensors are moved to.
        anchors: Anchor tensor forwarded to ``model.compute_loss``. When
            ``None``, the loss is computed as ``compute_loss(outputs, targets)``.

    Returns:
        The average per-batch loss as a Python float.

    Raises:
        ValueError: If ``train_loader`` contains no batches.
    """
    model.train()
    num_batches = len(train_loader)
    if num_batches == 0:
        raise ValueError("train_loader is empty; cannot compute a training loss")
    loss_args = () if anchors is None else (anchors,)
    running_loss = 0.0
    for images, targets in train_loader:
        images = torch.stack(images).to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        optimizer.zero_grad()
        outputs = model(images)
        loss = model.compute_loss(outputs, targets, *loss_args)
        loss.backward()
        optimizer.step()
        # .item() detaches the value, so the graph is not kept alive across batches.
        running_loss += loss.item()
    return running_loss / num_batches



# Train for num_epochs epochs, reporting the mean loss after each one.
for epoch_number in range(1, num_epochs + 1):
    train_loss = train(model, train_loader, optimizer, device, anchors)
    print(f'Epoch {epoch_number}/{num_epochs}, Training Loss: {train_loss:.4f}')

    # Validation is currently disabled; this call needs a `validate` that
    # accepts an `anchors` argument.
    # validation_loss = validate(model, validation_loader, device, anchors)
    # print(f'Validation Loss: {validation_loss:.4f}')

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()

# # Test the model
# test_loss = test(model, test_loader, device, anchors)
# print(f'Test Loss: {test_loss:.4f}')

RuntimeError: result type Float can't be cast to the desired output type Byte

In [None]:
# def validate(model, val_loader, device, anchors):
#     model.eval()
#     val_loss = 0
#     with torch.no_grad():
#         for images, targets in val_loader:
#             images = torch.stack(images).to(device)
#             targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
#             outputs, _ = model(images)
#             loss = model.compute_loss(outputs, targets, anchors)
#             val_loss += loss.item()
#     return val_loss / len(val_loader)

# def test(model, test_loader, device, anchors):
#     model.eval()
#     test_loss = 0
#     with torch.no_grad():
#         for images, targets in test_loader:
#             images = torch.stack(images).to(device)
#             targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
#             outputs, _ = model(images)
#             loss = model.compute_loss(outputs, targets, anchors)
#             test_loss += loss.item()
#     return test_loss / len(test_loader)
