## 

In [1]:
# Load the autoreload extension
%load_ext autoreload

# Set autoreload mode
%autoreload 2

import hydra
from hydra.utils import instantiate
from omegaconf import OmegaConf


with hydra.initialize(version_base=None, config_path="./config"):
    cfg = hydra.compose(config_name="train")
    print(OmegaConf.to_yaml(cfg))

epochs: 20
dataloader:
  _target_: torch.utils.data.DataLoader
  _partial_: true
  batch_size: 32
  shuffle: true
  num_workers: 0
model:
  _target_: model.GateCornerCNN
optimizer:
  _target_: torch.optim.Adam
  _partial_: true
  lr: 0.001
loss:
  _target_: torch.nn.MSELoss
dataset:
  dataset:
    _target_: dataset.DroneDataset
    dir_dataset: /workspaces/AE4353-Y24/data/AutonomousFlightData/test



In [2]:
from torch.utils.data import Subset
from dataset import DroneDataset


model = instantiate(cfg.model)

train_dataset = DroneDataset("/workspaces/AE4353-Y24/data/AutonomousFlightData/test")
split_index = int(0.85 * len(train_dataset))
train_set = Subset(train_dataset, range(0, split_index))

val_dataset = DroneDataset("/workspaces/AE4353-Y24/data/AutonomousFlightData/test")
val_set = Subset(val_dataset, range(split_index, len(val_dataset)))

train_loader = instantiate(cfg.dataloader)(train_set)
val_loader = instantiate(cfg.dataloader)(val_set)

optimizer = instantiate(cfg.optimizer)(model.parameters())
criterion = instantiate(cfg.loss)

num_epochs = cfg.epochs

Loaded 802 images from /workspaces/AE4353-Y24/data/AutonomousFlightData/test/autonomous_flight-01a-ellipse.h5
Total images loaded: 802
Total targets loaded: 802
Loaded 802 images from /workspaces/AE4353-Y24/data/AutonomousFlightData/test/autonomous_flight-01a-ellipse.h5
Total images loaded: 802
Total targets loaded: 802


In [3]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

In [5]:
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm


def train_model(
    model, train_loader, val_loader, num_epochs, criterion, optimizer, device
):
    model.to(device)

    for epoch in range(num_epochs):
        model.train()  # Set the model to training mode

        running_loss = 0.0
        for images, targets in tqdm(train_loader):
            images = images.to(device)
            targets = [
                t.to(device) for t in targets
            ]  # Assuming targets are a list of tensors

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(images)

            # Compute loss (assuming targets are in the correct shape)
            # If targets are a list, you may need to stack them before passing to the criterion
            loss = criterion(outputs, torch.stack(targets))

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        avg_loss = running_loss / len(train_loader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

        # Validation after each epoch
        validate_model(model, val_loader, criterion, device)

    print("Training complete.")


def validate_model(model, val_loader, criterion, device):

    model.eval()

    val_loss = 0

    with torch.no_grad():

        for images, targets in val_loader:
            images = images.to(device)
            targets = [t.to(device) for t in targets]

            # Forward pass
            outputs = model(images)

            # Compute validation loss
            loss = criterion(outputs, torch.stack(targets))
            val_loss += loss.item()

    avg_val_loss = val_loss / len(val_loader)
    print(f"Validation Loss: {avg_val_loss:.4f}")

In [6]:
train_model(model, train_loader, val_loader, num_epochs, criterion, optimizer, device)

100%|██████████| 22/22 [00:02<00:00,  7.98it/s]


Epoch [1/20], Loss: 0.1606
Validation Loss: 0.1227


100%|██████████| 22/22 [00:02<00:00, 10.39it/s]


Epoch [2/20], Loss: 0.0840
Validation Loss: 0.1629


100%|██████████| 22/22 [00:02<00:00, 10.52it/s]


Epoch [3/20], Loss: 0.0646
Validation Loss: 0.1326


100%|██████████| 22/22 [00:02<00:00, 10.57it/s]


Epoch [4/20], Loss: 0.0587
Validation Loss: 0.1864


100%|██████████| 22/22 [00:02<00:00, 10.63it/s]


Epoch [5/20], Loss: 0.0503
Validation Loss: 0.1772


100%|██████████| 22/22 [00:02<00:00, 10.55it/s]


Epoch [6/20], Loss: 0.0432
Validation Loss: 0.1701


100%|██████████| 22/22 [00:02<00:00, 10.61it/s]


Epoch [7/20], Loss: 0.0422
Validation Loss: 0.1402


100%|██████████| 22/22 [00:02<00:00, 10.11it/s]


Epoch [8/20], Loss: 0.0385
Validation Loss: 0.1522


100%|██████████| 22/22 [00:02<00:00, 10.05it/s]


Epoch [9/20], Loss: 0.0349
Validation Loss: 0.1273


100%|██████████| 22/22 [00:02<00:00, 10.66it/s]


Epoch [10/20], Loss: 0.0377
Validation Loss: 0.0861


100%|██████████| 22/22 [00:02<00:00, 10.45it/s]


Epoch [11/20], Loss: 0.0346
Validation Loss: 0.1265


100%|██████████| 22/22 [00:02<00:00, 10.33it/s]


Epoch [12/20], Loss: 0.0337
Validation Loss: 0.1099


100%|██████████| 22/22 [00:02<00:00, 10.31it/s]


Epoch [13/20], Loss: 0.0377
Validation Loss: 0.1599


100%|██████████| 22/22 [00:02<00:00, 10.05it/s]


Epoch [14/20], Loss: 0.0311
Validation Loss: 0.1391


100%|██████████| 22/22 [00:02<00:00, 10.22it/s]


Epoch [15/20], Loss: 0.0289
Validation Loss: 0.1085


100%|██████████| 22/22 [00:02<00:00, 10.53it/s]


Epoch [16/20], Loss: 0.0275
Validation Loss: 0.0789


100%|██████████| 22/22 [00:02<00:00, 10.12it/s]


Epoch [17/20], Loss: 0.0282
Validation Loss: 0.1116


100%|██████████| 22/22 [00:02<00:00, 10.55it/s]


Epoch [18/20], Loss: 0.0285
Validation Loss: 0.1864


100%|██████████| 22/22 [00:02<00:00, 10.30it/s]


Epoch [19/20], Loss: 0.0265
Validation Loss: 0.1039


100%|██████████| 22/22 [00:02<00:00, 10.15it/s]


Epoch [20/20], Loss: 0.0256
Validation Loss: 0.1014
Training complete.
