# Training the Extended Mask2Former on the UAV-SOD Drone Dataset

In [1]:
# Import libraries
import os, json
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from src.data_set_up import SOD_Data
from models.extended_mask2former_model import ExtendedMask2Former


# Project-relative locations: dataset root, dataset statistics file, and
# the JSON file holding the category-id-to-name map
base_dir = "data/uav_sod_data/"
data_info_path = "src/data_info/uav_data_preprocessing.json"
map_path = "src/code_map.json"

### Set up the compute device

In [3]:
# Select the device that will host the model and the data.
# Preference order: CUDA GPU, then Apple-Silicon MPS, then CPU. The getattr
# guard keeps the cell working on PyTorch builds without the MPS backend.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = "cuda"
elif _mps_backend is not None and _mps_backend.is_available():
    device = "mps"
else:
    device = "cpu"

### Set up basic static data

- Get the number of classes
- Get the mean and standard deviation of the training data
- Create the data paths for the train, test, and validation splits

In [4]:
# Load the classes of the UAV-SOD Drone dataset.
# `with` closes the file even when json.load raises, and the handle is not
# named `map`, which would shadow the built-in for the rest of the notebook.
with open(map_path, encoding="utf-8") as code_map_file:
    code_map = json.load(code_map_file)
classes = code_map["UAV_SOD_DRONE"]["CATEGORY_ID_TO_NAME"]

# Number of entries in the category-id-to-name map
number_classes = len(classes)


# Load the mean and standard deviation for the train data
# (both are later handed to transforms.Normalize)
with open(data_info_path, encoding="utf-8") as data_info_file:
    data_info = json.load(data_info_file)
mean = data_info["uav_data"]["mean"]
standard_deviation = data_info["uav_data"]["std"]


# One directory per split, each located directly under the dataset root
train_path, test_path, validation_path = (
    os.path.join(base_dir, split_name)
    for split_name in ("train", "test", "validation")
)

### Dataset - Dataloader
- Collate function
- Data transformations
- DataLoader and Dataset

In [5]:
# Batch size shared by the three DataLoaders created below
LOADER_BATCH_SIZE = 2


def collate_batch(batch):
    """Regroup a list of dataset samples into per-field tuples.

    ``zip(*batch)`` transposes the list of samples drawn by the DataLoader,
    so N ``(image, target)`` samples become
    ``((image_1, ..., image_N), (target_1, ..., target_N))``. Nothing is
    stacked here, so this function does not require the samples of a batch
    to share a shape.

    Defined once and reused by every loader instead of repeating the same
    ``lambda`` three times.
    """
    return tuple(zip(*batch))


def build_eval_transform():
    """Return the augmentation-free pipeline used for test and validation."""
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=standard_deviation),
    ])


# Transformations for each one of the type of data
data_transform = {
    # NOTE(review): RandomHorizontalFlip is handed to SOD_Data as `transform`.
    # If SOD_Data applies it to the image only, the annotations are not
    # mirrored along with the image -- confirm in src/data_set_up.py.
    "train": transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=standard_deviation),
    ]),
    "test": build_eval_transform(),
    "validation": build_eval_transform(),
}

# Dataset and DataLoader
# Paths are built with os.path.join, matching how the split directories are built.
train_dataset = SOD_Data(
    os.path.join(train_path, "images"),
    os.path.join(train_path, "annotations"),
    transform=data_transform["train"],
)
test_dataset = SOD_Data(
    os.path.join(test_path, "images"),
    os.path.join(test_path, "annotations"),
    transform=data_transform["test"],
)
validation_dataset = SOD_Data(
    os.path.join(validation_path, "images"),
    os.path.join(validation_path, "annotations"),
    transform=data_transform["validation"],
)

# Only the training split is shuffled; test and validation keep dataset order
train_loader = DataLoader(
    train_dataset, batch_size=LOADER_BATCH_SIZE, shuffle=True, collate_fn=collate_batch
)
test_loader = DataLoader(
    test_dataset, batch_size=LOADER_BATCH_SIZE, shuffle=False, collate_fn=collate_batch
)
validation_loader = DataLoader(
    validation_dataset, batch_size=LOADER_BATCH_SIZE, shuffle=False, collate_fn=collate_batch
)

In [6]:
def validate(model, val_loader, device):
    """Return the mean per-batch loss of ``model`` over ``val_loader``.

    Side effect: the model is switched to evaluation mode and left there;
    the caller has to call ``model.train()`` before training again.

    Args:
        model: Object providing ``eval()``, ``__call__(images)`` and
            ``compute_loss(outputs, targets)``; the returned loss must
            support ``.item()``.
        val_loader: Sized iterable yielding ``(images, targets)`` batches,
            where ``images`` is a sequence of tensors and ``targets`` a
            sequence of dicts whose values support ``.to(device)``.
        device: Device every batch is moved to before the forward pass.

    Returns:
        float: Sum of the batch losses divided by the number of batches.

    Raises:
        ValueError: If ``val_loader`` yields no batches; without this guard
            the final division fails with an uninformative ZeroDivisionError.
    """
    num_batches = len(val_loader)
    if num_batches == 0:
        raise ValueError("val_loader is empty; cannot compute a validation loss")

    model.eval()
    total_loss = 0.0
    # Gradients are not needed for evaluation
    with torch.no_grad():
        for images, targets in val_loader:
            # torch.stack requires all images of the batch to share one shape.
            # NOTE(review): the training loop passes the model a *list* of
            # image tensors instead of a stacked tensor -- confirm which input
            # form the model expects and align the two call sites.
            images = torch.stack(images).to(device)
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            outputs = model(images)
            total_loss += model.compute_loss(outputs, targets).item()
    return total_loss / num_batches

In [7]:
# Model, Optimizer and Training Loop setup
# NOTE(review): the saved output of this cell ends with
# "AttributeError: 'tuple' object has no attribute 'split'". The failing call
# is not visible in this notebook -- re-run with a full traceback to locate it.
model = ExtendedMask2Former(num_classes=number_classes).to(device)

# Hyperparameters
num_epochs = 1
learning_rate = 0.001
# NOTE(review): `batch_size` is not read anywhere in this cell, and the
# DataLoaders are created earlier with a batch size of 2. Wire it into the
# loaders or remove it so the two values cannot drift apart.
batch_size = 8


optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# StepLR multiplies the learning rate by `gamma` every `step_size` epochs
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)


def train_one_epoch(model, loader, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        model: Model providing ``train()``, ``__call__(images)`` and
            ``compute_loss(outputs, targets)``.
        loader: Sized iterable yielding ``(images, targets)`` batches.
        optimizer: Optimizer whose gradients are reset before, and stepped
            after, every backward pass.
        device: Device every batch is moved to before the forward pass.

    Returns:
        float: Sum of the batch losses divided by the number of batches.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    num_batches = len(loader)
    if num_batches == 0:
        raise ValueError("loader is empty; cannot train for an epoch")

    model.train()
    running_loss = 0.0
    for images, targets in loader:
        # The model receives a list of per-image tensors (no stacking here)
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        optimizer.zero_grad()
        outputs = model(images)
        loss = model.compute_loss(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    return running_loss / num_batches


for epoch in range(num_epochs):
    epoch_loss = train_one_epoch(model, train_loader, optimizer, device)
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}')

    # Advance the learning-rate schedule once per epoch, after the optimizer steps
    scheduler.step()

    # Validation
    validation_loss = validate(model, validation_loader, device)
    print(f'Validation Loss: {validation_loss:.4f}')

Loaded pretrained weights for efficientnet-b7


AttributeError: 'tuple' object has no attribute 'split'