In [6]:
"""
Main IDEA : Given a sequence of multimodal inputs, predict the final drone position + orientation.
"""

'\nMain IDEA : Given a sequence of multimodal inputs, predict the final drone position + orientation.\n'

In [1]:
import os
import math
from dataclasses import dataclass
from typing import Tuple, List, Dict, Any

import pandas as pd
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import transforms

In [5]:

@dataclass
class Config:
    """Paths and hyper-parameters shared by the dataset, model and training loop."""

    # --- Data locations -------------------------------------------------
    csv_path: str                    # CSV index read by the dataset
    root_dir: str                    # directory joined onto the rgb/depth file names
    out_dir: str = "./checkpoints"   # where last.pth / best.pth are written

    # --- Image geometry -------------------------------------------------
    img_size: Tuple[int, int] = (224, 224)  # (H, W)

    # --- Sliding-window sampling ----------------------------------------
    seq_len: int = 5      # frames per window
    seq_stride: int = 1   # step between consecutive window starts

    # --- Optimisation ---------------------------------------------------
    batch_size: int = 8
    epochs: int = 20
    lr: float = 1e-4
    weight_decay: float = 1e-5

    # --- Data loading and loss weighting --------------------------------
    num_workers: int = 4
    pos_loss_weight: float = 1.0     # multiplier on the position loss term
    ori_loss_weight: float = 100.0   # multiplier on the orientation loss term

    # --- Model switches -------------------------------------------------
    use_pretrained_backbones: bool = True   # load pretrained weights for the RGB backbone
    use_gru: bool = True                    # run a GRU over the fused per-frame features
    freeze_cnn: bool = False                # disable gradients for the RGB encoder

    rgb_backbone: str = "resnet18"
    # Evaluated once, when the class body is executed.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
def quaternion_normalize(q: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """
    Scale quaternions towards unit length along the last axis.

    Args:
        q: tensor whose last dimension holds the quaternion components, shape (..., 4).
        eps: small constant added to the norm so an all-zero quaternion does
            not cause a division by zero.

    Returns:
        Tensor of the same shape as ``q``.
    """
    magnitude = q.norm(dim=-1, keepdim=True)
    denominator = magnitude + eps
    return q / denominator


def load_image(path: str, resize: Tuple[int, int], is_depth: bool) -> Image.Image:
    """
    Load an image from disk, coerce it to a suitable mode and resize it.

    Args:
        path: location of the image file.
        resize: target size as (height, width).
        is_depth: if True the file is treated as a depth map, otherwise as RGB.

    Returns:
        The resized PIL image. It is a new image object, independent of the
        file handle opened here.
    """
    # Modes accepted unchanged for depth maps:
    # "I;16" = 16-bit grayscale, "I" = 32-bit int grayscale,
    # "F" = 32-bit float grayscale, "L" = 8-bit grayscale.
    depth_modes = ("I;16", "I", "F", "L")

    # Depth uses nearest-neighbour so no new depth values are interpolated;
    # RGB uses bilinear for a smoother result.
    resample = Image.NEAREST if is_depth else Image.BILINEAR

    # PIL expects (width, height) while `resize` is given as (height, width).
    target_size = (resize[1], resize[0])

    # The context manager closes the underlying file deterministically instead
    # of relying on garbage collection (this function runs once per frame).
    with Image.open(path) as source:
        if is_depth:
            img = source if source.mode in depth_modes else source.convert("I")
        else:
            img = source if source.mode == "RGB" else source.convert("RGB")

        # resize() returns a new image, so the result stays valid after
        # `source` is closed.
        return img.resize(target_size, resample)


In [7]:
class DronePoseDataset(Dataset):
    """
    Sliding-window dataset built from a CSV with per-frame modality and pose columns.

    Each item is a window of ``cfg.seq_len`` consecutive CSV rows. The sensor
    inputs of every row in the window are returned, together with the pose
    (position + quaternion) of the window's last row only.

    Expected CSV columns:
        - RGB & depth: rgb_file, depth_file
        - Sonar: sonar_front, sonar_back, sonar_left, sonar_right
        - IMU: imu_ang_vel_x/y/z, imu_lin_acc_x/y/z
        - Ground truth: gt_pos_x/y/z, gt_orient_x/y/z/w
        - Optional: timestamp (rows are sorted by it when present)
    """

    # Column groups read from each CSV row, in the order they are stacked.
    SONAR_COLS = ("sonar_front", "sonar_back", "sonar_left", "sonar_right")
    IMU_COLS = ("imu_ang_vel_x", "imu_ang_vel_y", "imu_ang_vel_z",
                "imu_lin_acc_x", "imu_lin_acc_y", "imu_lin_acc_z")
    POS_COLS = ("gt_pos_x", "gt_pos_y", "gt_pos_z")
    QUAT_COLS = ("gt_orient_x", "gt_orient_y", "gt_orient_z", "gt_orient_w")

    def __init__(self, cfg: Config, split: str = "train", train_split: float = 0.9):
        """
        Args:
            cfg: configuration providing csv_path, root_dir, img_size,
                seq_len and seq_stride.
            split: "train" selects the first ``train_split`` fraction of the
                rows; any other value selects the remaining rows.
            train_split: fraction of rows assigned to the training split.

        Raises:
            ValueError: if ``cfg.seq_len`` or ``cfg.seq_stride`` is smaller than 1.
        """
        super().__init__()

        if cfg.seq_len < 1 or cfg.seq_stride < 1:
            raise ValueError(
                f"seq_len and seq_stride must be >= 1, got "
                f"seq_len={cfg.seq_len}, seq_stride={cfg.seq_stride}"
            )

        self.cfg = cfg
        self.split = split

        df = pd.read_csv(cfg.csv_path)

        if "timestamp" in df.columns:
            df = df.sort_values("timestamp").reset_index(drop=True)
        self.df = df

        # Contiguous split: leading rows are train, trailing rows are validation.
        n_train = int(train_split * len(df))
        if split == "train":
            self.df_split = df.iloc[:n_train].reset_index(drop=True)
        else:
            self.df_split = df.iloc[n_train:].reset_index(drop=True)

        # Start offsets of all full windows of length seq_len, stepping by
        # seq_stride. Empty when the split has fewer than seq_len rows.
        n_rows = len(self.df_split)
        self.indices: List[int] = list(range(0, n_rows - cfg.seq_len + 1, cfg.seq_stride))

        # RGB: to tensor, then normalize with ImageNet statistics.
        self.rgb_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.ConvertImageDtype(torch.float32),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
        # Depth: to float tensor, with no mean/std normalization.
        # NOTE(review): ToTensor rescales 8-bit images to [0, 1] and
        # ConvertImageDtype appears to divide integer tensors by the dtype's
        # maximum, so the resulting depth scale likely depends on the file's
        # bit depth. Confirm this matches the intended depth units.
        self.depth_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.ConvertImageDtype(torch.float32),
        ])

    def __len__(self):
        """Number of sliding windows available in this split."""
        return len(self.indices)

    def _row_to_modalities(self, row: pd.Series) -> Dict[str, Any]:
        """
        Load every modality and the ground-truth pose for a single CSV row.

        Returns:
            Dict with keys "rgb", "depth" (image tensors), "sonar" (4 values),
            "imu" (6 values), "pos" (3 values) and "quat" (4 values, ordered
            x, y, z, w). The vectors are float32 tensors.
        """
        rgb_path = os.path.join(self.cfg.root_dir, str(row["rgb_file"]))
        depth_path = os.path.join(self.cfg.root_dir, str(row["depth_file"]))

        rgb_img = load_image(rgb_path, self.cfg.img_size, is_depth=False)
        depth_img = load_image(depth_path, self.cfg.img_size, is_depth=True)

        def as_vector(columns) -> torch.Tensor:
            # Gather the named columns of this row into one float32 vector.
            return torch.tensor([row[c] for c in columns], dtype=torch.float32)

        return {
            "rgb": self.rgb_transform(rgb_img),
            "depth": self.depth_transform(depth_img),
            "sonar": as_vector(self.SONAR_COLS),
            "imu": as_vector(self.IMU_COLS),
            "pos": as_vector(self.POS_COLS),
            "quat": as_vector(self.QUAT_COLS),
        }

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """
        Return the window starting at ``self.indices[idx]``.

        The sensor tensors are stacked along a new leading time dimension of
        length seq_len. "pos" and "quat" come from the last row of the window.
        """
        start = self.indices[idx]
        window = self.df_split.iloc[start:start + self.cfg.seq_len]

        frames = [self._row_to_modalities(row) for _, row in window.iterrows()]

        # Reuse the already-loaded final frame for the target pose instead of
        # reading and decoding its images from disk a second time.
        last = frames[-1]

        return {
            "rgb": torch.stack([f["rgb"] for f in frames], dim=0),
            "depth": torch.stack([f["depth"] for f in frames], dim=0),
            "sonar": torch.stack([f["sonar"] for f in frames], dim=0),
            "imu": torch.stack([f["imu"] for f in frames], dim=0),
            "pos": last["pos"],
            "quat": last["quat"],
        }


In [None]:
"""
USAGE :
1. Depth Processing:
depth_model = SmallDepthCNN(out_dim=128)
depth_feats = depth_model(depth_tensor)  # depth_tensor: (B,T,1,H,W)

2. RGB Processing:
rgb_model = make_resnet_feature_extractor("resnet18", pretrained=True, out_dim=256)
rgb_feats = rgb_model(rgb_tensor)  # rgb_tensor: (B,T,3,H,W)
"""

class SmallDepthCNN(nn.Module):
    """
    Small convolutional encoder for single-channel depth images.

    Input:  (B, 1, H, W) for a single frame, or (B, T, 1, H, W) for a sequence.
    Output: (B, out_dim) or (B, T, out_dim) respectively.
    """

    def __init__(self, out_dim: int = 128):
        """
        Args:
            out_dim: size of the feature vector produced per frame.
        """
        super().__init__()

        # Four stride-2 conv stages (1 -> 16 -> 32 -> 64 -> 128 channels),
        # followed by global average pooling to a 1x1 map.
        self.net = nn.Sequential(
            nn.Conv2d(1, 16, 5, stride=2, padding=2), nn.BatchNorm2d(16), nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
        )

        # Pooled (N, 128) features are projected to (N, out_dim).
        self.fc = nn.Linear(128, out_dim)

    def forward(self, x):
        """
        Encode one frame per batch element or a whole sequence of frames.

        Args:
            x: (B, 1, H, W) or (B, T, 1, H, W).

        Returns:
            (B, out_dim) for 4-D input, (B, T, out_dim) for 5-D input.

        Raises:
            ValueError: if ``x`` is neither 4-D nor 5-D.
        """
        if x.dim() not in (4, 5):
            raise ValueError(
                f"SmallDepthCNN expects a 4-D or 5-D tensor, got {x.dim()}-D input"
            )

        is_seq = x.dim() == 5

        if is_seq:
            B, T = x.shape[:2]
            # Fold time into the batch dimension. reshape (not view) so that
            # non-contiguous inputs, e.g. after a transpose, are accepted too.
            x = x.reshape(B * T, *x.shape[2:])

        feat = self.fc(self.net(x).flatten(1))

        if is_seq:
            feat = feat.reshape(B, T, -1)

        return feat


def make_resnet_feature_extractor(backbone="resnet18", pretrained=True, out_dim=256):
    """
    Build an RGB feature extractor from a torchvision ResNet (18/34/50).

    Args:
        backbone: one of "resnet18", "resnet34", "resnet50".
        pretrained: if truthy, load the torchvision DEFAULT weights; otherwise
            the backbone is randomly initialised.
        out_dim: size of the projected feature vector per frame.

    Returns:
        An nn.Module with sub-modules ``body`` (the ResNet without its final
        layer) and ``proj`` (a linear layer mapping to ``out_dim``).

    Raises:
        ValueError: if ``backbone`` is not a supported name.
    """
    # backbone name -> (weights enum name, width of the pooled ResNet features)
    specs = {
        "resnet18": ("ResNet18_Weights", 512),
        "resnet34": ("ResNet34_Weights", 512),
        "resnet50": ("ResNet50_Weights", 2048),
    }
    if backbone not in specs:
        raise ValueError(f"Unsupported backbone: {backbone}")

    weights_name, feat_dim = specs[backbone]
    weights = getattr(torchvision.models, weights_name).DEFAULT if pretrained else None
    resnet = getattr(torchvision.models, backbone)(weights=weights)

    # Keep every child except the last one (the classification layer).
    body = nn.Sequential(*list(resnet.children())[:-1])
    # Reduce the pooled ResNet features to the desired out_dim (e.g. 256).
    proj = nn.Linear(feat_dim, out_dim)

    return _ResnetFeat(body, proj)


class _ResnetFeat(nn.Module):
    """
    Apply a ResNet body and a linear projection per frame.

    (B, T, C, H, W) is folded to (B*T, C, H, W), passed through the body and
    the projection, then unfolded back to (B, T, out_dim). A 4-D input
    (B, C, H, W) yields (B, out_dim).

    Defined at module level (rather than inside the factory) so the class is
    created once and instances can be pickled.
    """

    def __init__(self, body, proj):
        super().__init__()
        self.body = body
        self.proj = proj

    def forward(self, x):
        is_seq = x.dim() == 5
        if is_seq:
            B, T = x.shape[:2]
            # reshape (not view) so non-contiguous inputs are accepted too.
            x = x.reshape(B * T, *x.shape[2:])

        f = self.proj(self.body(x).flatten(1))

        if is_seq:
            f = f.reshape(B, T, -1)
        return f

In [12]:
class PoseNetMultiModal(nn.Module):
    """
    Multimodal pose regressor over a sequence of frames.

    Per frame, RGB, depth, sonar and IMU inputs are encoded separately,
    concatenated and fused by an MLP. The fused sequence is optionally passed
    through a GRU, and the representation of the last time step is mapped to
    a 3-D position and a normalized 4-D quaternion.
    """

    def __init__(self, cfg: Config, rgb_feat_dim=256, depth_feat_dim=128,
                 sonar_dim=32, imu_dim=64, fused_dim=256, rnn_hidden=256):
        """
        Args:
            cfg: configuration; reads rgb_backbone, use_pretrained_backbones,
                use_gru and freeze_cnn.
            rgb_feat_dim: output size of the RGB encoder.
            depth_feat_dim: output size of the depth encoder.
            sonar_dim: output size of the sonar MLP.
            imu_dim: output size of the IMU MLP.
            fused_dim: width of the fusion MLP.
            rnn_hidden: hidden size of the GRU (used only when cfg.use_gru).
        """
        super().__init__()

        self.cfg = cfg

        self.rgb_encoder = make_resnet_feature_extractor(cfg.rgb_backbone,
                                                         cfg.use_pretrained_backbones,
                                                         rgb_feat_dim)

        self.depth_encoder = SmallDepthCNN(out_dim=depth_feat_dim)

        # Sonar has 4 readings per frame, IMU has 6.
        self.sonar_mlp = nn.Sequential(nn.Linear(4, 32), nn.ReLU(),
                                       nn.Linear(32, sonar_dim), nn.ReLU())

        self.imu_mlp = nn.Sequential(nn.Linear(6, 64), nn.ReLU(),
                                     nn.Linear(64, imu_dim), nn.ReLU())

        # Width of the concatenated per-frame feature vector of all modalities.
        in_fuse = rgb_feat_dim + depth_feat_dim + sonar_dim + imu_dim

        self.fuse = nn.Sequential(nn.Linear(in_fuse, fused_dim), nn.ReLU(),
                                  nn.Linear(fused_dim, fused_dim), nn.ReLU())

        self.use_gru = cfg.use_gru

        if self.use_gru:
            # The GRU consumes the fused features over time (batch-first).
            self.rnn = nn.GRU(fused_dim, rnn_hidden, batch_first=True)
            head_in = rnn_hidden
        else:
            head_in = fused_dim

        # Final head: 3 position values followed by 4 quaternion values.
        self.head = nn.Sequential(nn.Linear(head_in, 128), nn.ReLU(),
                                  nn.Linear(128, 7))

    def train(self, mode: bool = True):
        """
        Switch training/eval mode, keeping a frozen RGB encoder in eval mode.

        When ``cfg.freeze_cnn`` is set, disabling gradients alone does not stop
        the encoder's BatchNorm layers from updating their running statistics
        in training mode. Keeping the encoder in eval mode makes the freeze
        effective for those statistics as well.
        """
        super().train(mode)
        if mode and self.cfg.freeze_cnn:
            self.rgb_encoder.eval()
        return self

    def forward(self, rgb, depth, sonar, imu):
        """
        Args:
            rgb:   (B, T, 3, H, W)
            depth: (B, T, 1, H, W)
            sonar: (B, T, 4)
            imu:   (B, T, 6)

        Returns:
            Tuple ``(pos, quat)`` with shapes (B, 3) and (B, 4); ``quat`` is
            passed through ``quaternion_normalize``.
        """
        rgb_f = self.rgb_encoder(rgb)
        depth_f = self.depth_encoder(depth)
        sonar_f = self.sonar_mlp(sonar)
        imu_f = self.imu_mlp(imu)

        fused = torch.cat([rgb_f, depth_f, sonar_f, imu_f], dim=-1)
        fused = self.fuse(fused)

        if self.use_gru:
            out, _ = self.rnn(fused)
            feat_T = out[:, -1, :]   # GRU output at the last time step
        else:
            feat_T = fused[:, -1, :]  # fused features of the last frame only

        pose = self.head(feat_T)

        pos, quat = pose[:, :3], quaternion_normalize(pose[:, 3:])

        return pos, quat


In [13]:
class PoseLoss(nn.Module):
    """
    Weighted sum of a position MSE and a quaternion MSE.

    The orientation term is sign-invariant: ``q`` and ``-q`` describe the same
    rotation, so the target is flipped to the hemisphere of the prediction
    before the MSE is taken.
    """

    def __init__(self, pos_w=1.0, ori_w=100.0):
        """
        Args:
            pos_w: weight of the position term.
            ori_w: weight of the orientation term.
        """
        super().__init__()
        self.pos_w = pos_w
        self.ori_w = ori_w

    def forward(self, pos_pred, quat_pred, pos_gt, quat_gt):
        """
        Args:
            pos_pred, pos_gt: predicted / target positions, last dim 3.
            quat_pred, quat_gt: predicted / target quaternions, last dim 4.
                The target is normalized here; the prediction is used as given.

        Returns:
            Tuple ``(total, pos_loss, ori_loss)`` where
            ``total = pos_w * pos_loss + ori_w * ori_loss``.
        """
        pos_loss = F.mse_loss(pos_pred, pos_gt)

        quat_gt = quaternion_normalize(quat_gt)

        # Resolve the double cover: where the target points away from the
        # prediction (negative dot product), use -target instead. Without
        # this, a prediction of -q for a target q would be heavily penalised
        # although both encode the same orientation.
        dot = (quat_pred * quat_gt).sum(dim=-1, keepdim=True)
        flip = (dot < 0).to(quat_gt.dtype)
        quat_gt = quat_gt * (1.0 - 2.0 * flip)

        ori_loss = F.mse_loss(quat_pred, quat_gt)

        return self.pos_w * pos_loss + self.ori_w * ori_loss, pos_loss, ori_loss


In [15]:
def collate_fn(batch: List[Dict[str, Any]]):
    """
    Merge a list of per-sample tensor dicts into one dict of batched tensors.

    The keys are taken from the first sample; for each key the corresponding
    tensors of all samples are stacked along a new leading dimension.
    """
    first_sample = batch[0]
    collated = {}
    for key in first_sample:
        per_sample = [sample[key] for sample in batch]
        collated[key] = torch.stack(per_sample, dim=0)
    return collated


def train_one_epoch(model, loss_fn, optimizer, loader, device):
    """
    Run one optimisation pass over ``loader``.

    Puts the model in training mode, and for every batch moves the tensors to
    ``device``, computes the loss, back-propagates and steps the optimizer.

    Args:
        model: called as ``model(rgb, depth, sonar, imu)`` and returning
            ``(pos_pred, quat_pred)``.
        loss_fn: called as ``loss_fn(pos_pred, quat_pred, pos_gt, quat_gt)``
            and returning ``(total, pos_loss, ori_loss)`` tensors.
        optimizer: optimizer over the model's trainable parameters.
        loader: iterable of batch dicts with keys
            "rgb", "depth", "sonar", "imu", "pos", "quat".
        device: device the batch tensors are moved to.

    Returns:
        ``(total, pos, ori)`` losses averaged over all samples seen.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    total_sum, pos_sum, ori_sum = 0.0, 0.0, 0.0
    n_samples = 0

    for batch in loader:
        rgb = batch["rgb"].to(device)
        depth = batch["depth"].to(device)
        sonar = batch["sonar"].to(device)
        imu = batch["imu"].to(device)
        pos_gt = batch["pos"].to(device)
        quat_gt = batch["quat"].to(device)

        optimizer.zero_grad()
        pos_pred, quat_pred = model(rgb, depth, sonar, imu)
        loss, pos_l, ori_l = loss_fn(pos_pred, quat_pred, pos_gt, quat_gt)

        loss.backward()
        optimizer.step()

        # Weight each batch by its size so a shorter final batch does not
        # count as much as a full one in the epoch average.
        batch_size = pos_gt.size(0)
        total_sum += loss.item() * batch_size
        pos_sum += pos_l.item() * batch_size
        ori_sum += ori_l.item() * batch_size
        n_samples += batch_size

    if n_samples == 0:
        raise ValueError("train_one_epoch: the loader produced no samples")

    return total_sum / n_samples, pos_sum / n_samples, ori_sum / n_samples



def evaluate(model, loss_fn, loader, device):
    """
    Compute the average losses over ``loader`` without updating the model.

    The model is put in eval mode and the whole pass runs under
    ``torch.no_grad()``, so no gradients are tracked.

    Args:
        model: called as ``model(rgb, depth, sonar, imu)`` and returning
            ``(pos_pred, quat_pred)``.
        loss_fn: called as ``loss_fn(pos_pred, quat_pred, pos_gt, quat_gt)``
            and returning ``(total, pos_loss, ori_loss)`` tensors.
        loader: iterable of batch dicts with keys
            "rgb", "depth", "sonar", "imu", "pos", "quat".
        device: device the batch tensors are moved to.

    Returns:
        ``(total, pos, ori)`` losses averaged over all samples seen.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    total_sum, pos_sum, ori_sum = 0.0, 0.0, 0.0
    n_samples = 0

    with torch.no_grad():
        for batch in loader:
            rgb = batch["rgb"].to(device)
            depth = batch["depth"].to(device)
            sonar = batch["sonar"].to(device)
            imu = batch["imu"].to(device)
            pos_gt = batch["pos"].to(device)
            quat_gt = batch["quat"].to(device)

            pos_pred, quat_pred = model(rgb, depth, sonar, imu)
            loss, pos_l, ori_l = loss_fn(pos_pred, quat_pred, pos_gt, quat_gt)

            # Weight each batch by its size so a shorter final batch does not
            # count as much as a full one in the average.
            batch_size = pos_gt.size(0)
            total_sum += loss.item() * batch_size
            pos_sum += pos_l.item() * batch_size
            ori_sum += ori_l.item() * batch_size
            n_samples += batch_size

    if n_samples == 0:
        raise ValueError("evaluate: the loader produced no samples")

    return total_sum / n_samples, pos_sum / n_samples, ori_sum / n_samples


In [16]:
def run_training(cfg: Config):
    """
    Train a PoseNetMultiModal model and write checkpoints to ``cfg.out_dir``.

    Builds the train/validation datasets and loaders, the model, loss,
    AdamW optimizer and a ReduceLROnPlateau scheduler driven by the
    validation loss. After every epoch ``last.pth`` is overwritten; whenever
    the validation loss improves, ``best.pth`` is written as well.

    Args:
        cfg: experiment configuration.

    Raises:
        ValueError: if either split contains no full window of
            ``cfg.seq_len`` rows.
    """
    os.makedirs(cfg.out_dir, exist_ok=True)

    train_set = DronePoseDataset(cfg, "train")
    val_set = DronePoseDataset(cfg, "val")

    # Fail early with a clear message instead of an obscure error from the
    # sampler or a division by zero in the epoch averages.
    if len(train_set) == 0 or len(val_set) == 0:
        raise ValueError(
            f"Not enough data for seq_len={cfg.seq_len}: "
            f"{len(train_set)} training and {len(val_set)} validation windows"
        )

    # Pinned host memory only helps (and is only warning-free) for CUDA transfers.
    loader_kwargs = dict(batch_size=cfg.batch_size, num_workers=cfg.num_workers,
                         pin_memory=str(cfg.device).startswith("cuda"),
                         collate_fn=collate_fn)
    train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_set, shuffle=False, **loader_kwargs)

    model = PoseNetMultiModal(cfg).to(cfg.device)

    if cfg.freeze_cnn:
        # Freeze only the convolutional backbone. The projection layer on top
        # of it is freshly initialised and must stay trainable, otherwise the
        # RGB branch would be stuck with a random linear map.
        backbone = getattr(model.rgb_encoder, "body", model.rgb_encoder)
        for p in backbone.parameters():
            p.requires_grad = False

    loss_fn = PoseLoss(cfg.pos_loss_weight, cfg.ori_loss_weight)

    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable, lr=cfg.lr, weight_decay=cfg.weight_decay)

    # Halve the learning rate after 3 epochs without validation improvement.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                           factor=0.5, patience=3)

    best_val = math.inf

    for epoch in range(1, cfg.epochs + 1):
        tr_loss, tr_pos, tr_ori = train_one_epoch(model, loss_fn, optimizer, train_loader, cfg.device)
        va_loss, va_pos, va_ori = evaluate(model, loss_fn, val_loader, cfg.device)
        scheduler.step(va_loss)

        print(f"Epoch {epoch:03d} | Train {tr_loss:.4f} (pos {tr_pos:.4f}, ori {tr_ori:.4f}) "
              f"| Val {va_loss:.4f} (pos {va_pos:.4f}, ori {va_ori:.4f})")

        # The scheduler state is stored too, so a resumed run keeps its
        # learning-rate schedule.
        ckpt = {'epoch': epoch, 'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
                'scheduler_state': scheduler.state_dict(),
                'cfg': dict(vars(cfg))}
        torch.save(ckpt, os.path.join(cfg.out_dir, 'last.pth'))

        if va_loss < best_val:
            best_val = va_loss
            torch.save({**ckpt, 'val_loss': va_loss}, os.path.join(cfg.out_dir, 'best.pth'))
            print(f"  → New best model (val {va_loss:.4f}) saved.")
