In [None]:
# ----------- Step 0: Required Libraries -----------
import os
import xml.etree.ElementTree as ET
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
from torchvision.ops import box_iou
from tqdm import tqdm

# ----------- Step 1: Custom Dataset Loader -----------

class FLIRDataset(Dataset):
    """Images from ``img_dir`` paired with XML annotations from ``annot_dir``.

    Every file in ``img_dir`` whose name ends in ``.jpeg`` or ``.jpg`` is one
    sample. Its annotation is the file in ``annot_dir`` with the same stem
    and an ``.xml`` suffix. Each sample is ``(image, target)`` where
    ``target`` is a dict with:

    * ``"boxes"``: float32 tensor of shape ``(N, 4)`` holding
      ``[xmin, ymin, xmax, ymax]`` exactly as written in the XML.
    * ``"labels"``: int64 tensor of shape ``(N,)`` (1 = person, 2 = vehicle).

    NOTE: ``transform`` is applied to the image only. A geometric transform
    (e.g. a resize) is therefore not reflected in the returned boxes.
    """

    def __init__(self, img_dir, annot_dir, transform=None):
        """Index the image files found directly inside ``img_dir``.

        Args:
            img_dir: Directory containing the ``.jpeg`` / ``.jpg`` images.
            annot_dir: Directory containing one ``<stem>.xml`` per image.
            transform: Optional callable applied to the loaded PIL image.
        """
        self.img_dir = img_dir
        self.annot_dir = annot_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> image mapping stable across runs and machines.
        self.images = sorted(
            f for f in os.listdir(img_dir) if f.endswith(('.jpeg', '.jpg'))
        )

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load image ``idx`` and its annotation as ``(image, target)``."""
        image_name = self.images[idx]
        img_path = os.path.join(self.img_dir, image_name)
        annot_path = self._annotation_path(image_name)

        image = Image.open(img_path).convert("RGB")

        boxes, labels = self.parse_annotation(annot_path)

        if self.transform:
            image = self.transform(image)

        return image, self._build_target(boxes, labels)

    def _annotation_path(self, image_name):
        """Return the XML path for ``image_name`` (same stem, ``.xml`` suffix).

        Swapping the extension via ``splitext`` handles both ``.jpeg`` and
        ``.jpg``; a plain ``replace('.jpeg', ...)`` would leave ``.jpg`` names
        untouched and point at the image itself.
        """
        stem, _ext = os.path.splitext(image_name)
        return os.path.join(self.annot_dir, stem + '.xml')

    @staticmethod
    def _build_target(boxes, labels):
        """Convert parsed box/label lists into the target tensor dict."""
        return {
            # reshape keeps the (N, 4) layout even when N == 0; an empty list
            # would otherwise become a 1-D tensor of shape (0,).
            "boxes": torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            "labels": torch.tensor(labels, dtype=torch.int64),
        }

    def parse_annotation(self, xml_file):
        """Read the ``<object>`` entries of ``xml_file``.

        Objects whose ``<name>`` is not ``person`` or ``vehicle`` are skipped.

        Returns:
            ``(boxes, labels)``: ``boxes`` is a list of
            ``[xmin, ymin, xmax, ymax]`` floats and ``labels`` the matching
            list of integer class ids.
        """
        tree = ET.parse(xml_file)
        root = tree.getroot()
        boxes = []
        labels = []

        label_map = {'person': 1, 'vehicle': 2}  # Label encoding

        for obj in root.findall('object'):
            label = obj.find('name').text
            if label not in label_map:
                continue
            bbox = obj.find('bndbox')
            boxes.append(
                [float(bbox.find(tag).text) for tag in ('xmin', 'ymin', 'xmax', 'ymax')]
            )
            labels.append(label_map[label])
        return boxes, labels

# ----------- Step 2: Model (Simple CNN + Detection Head) -----------

class SimpleDetector(nn.Module):
    """Small fully-convolutional detector with two parallel heads.

    A stack of four stride-2 3x3 convolutions (each followed by ReLU)
    produces a 256-channel feature map at roughly 1/16 of the input
    resolution. A classification head and a box-regression head are then
    applied to that shared feature map.

    Args:
        num_classes: Number of output channels of the classification head.
    """

    def __init__(self, num_classes=3):
        super().__init__()

        # Backbone: 3 -> 32 -> 64 -> 128 -> 256 channels, halving H and W
        # at every convolution.
        widths = (3, 32, 64, 128, 256)
        backbone_layers = []
        for in_ch, out_ch in zip(widths, widths[1:]):
            backbone_layers.append(
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1)
            )
            backbone_layers.append(nn.ReLU())
        self.feature_extractor = nn.Sequential(*backbone_layers)

        # Heads are built after the backbone, classification first.
        self.cls_head = self._make_head(num_classes)
        self.reg_head = self._make_head(4)

    @staticmethod
    def _make_head(out_channels):
        """Build a 3x3 conv + ReLU + 1x1 conv head on 256-channel features."""
        return nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, out_channels, kernel_size=1),
        )

    def forward(self, x):
        """Return ``(cls_logits, bbox_preds)`` for the image batch ``x``.

        ``cls_logits`` has ``num_classes`` channels and ``bbox_preds`` has 4
        channels; both share the spatial size of the backbone output.
        """
        shared = self.feature_extractor(x)
        return self.cls_head(shared), self.reg_head(shared)

# ----------- Step 3: Loss Functions -----------

def detection_loss(cls_logits, bbox_preds, targets):
    """Return a simplified classification + box-regression loss.

    Both prediction maps are averaged over their spatial dimensions (dims 2
    and 3), giving one class-score vector and one 4-value box per batch
    element. The result is the sum of the cross-entropy between the pooled
    scores and ``targets["labels"]`` and the smooth-L1 distance between the
    pooled boxes and ``targets["boxes"]``, both with their default
    (mean-reduced) settings.

    This is a demonstration loss, not a per-location detection loss.

    NOTE(review): with pooled scores of shape (B, C), cross-entropy expects
    class-index labels of shape (B,); presumably this only works with one
    label and one box per image. Confirm against what the data loader
    produces.
    """
    # Classification term: spatially pooled logits against the labels.
    classification_term = nn.functional.cross_entropy(
        cls_logits.mean(dim=(2, 3)),
        targets["labels"],
    )

    # Regression term: spatially pooled boxes against the target boxes.
    regression_term = nn.functional.smooth_l1_loss(
        bbox_preds.mean(dim=(2, 3)),
        targets["boxes"].float(),
    )

    return classification_term + regression_term

# ----------- Step 4: Training -----------

def train(model, dataloader, optimizer, epochs=10):
    """Train ``model`` and print the mean batch loss after every epoch.

    Args:
        model: Module returning ``(cls_logits, bbox_preds)`` for an image
            batch; it must have at least one parameter.
        dataloader: Iterable yielding ``(images, targets)`` batches, where
            ``targets`` is a dict of tensors.
        optimizer: Optimizer bound to ``model``'s parameters.
        epochs: Number of passes over ``dataloader``.
    """
    model.train()
    # Move each batch to wherever the model lives instead of assuming CUDA,
    # so the loop also runs on CPU-only hosts or with a CPU model.
    device = next(model.parameters()).device

    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        for images, targets in tqdm(dataloader):
            images = images.to(device)
            targets = {k: v.to(device) for k, v in targets.items()}

            optimizer.zero_grad()
            cls_logits, bbox_preds = model(images)
            loss = detection_loss(cls_logits, bbox_preds, targets)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        # Count the batches actually seen so an empty loader reports 0.0
        # instead of raising ZeroDivisionError.
        mean_loss = total_loss / num_batches if num_batches else 0.0
        print(f"Epoch [{epoch+1}/{epochs}] Loss: {mean_loss:.4f}")

# ----------- Step 5: Evaluation -----------

def evaluate(model, dataloader):
    """Run ``model`` over ``dataloader`` in eval mode and print summaries.

    For every batch, the mean of the classification logits and the mean of
    the box predictions are printed. No metric is computed and nothing is
    returned; the model is left in eval mode.

    Args:
        model: Module returning ``(cls_logits, bbox_preds)`` for an image
            batch; it must have at least one parameter.
        dataloader: Iterable yielding ``(images, targets)`` batches; the
            targets are not used here.
    """
    model.eval()
    # Follow the model's device instead of assuming CUDA is available.
    device = next(model.parameters()).device
    with torch.no_grad():
        for images, _targets in tqdm(dataloader):
            images = images.to(device)
            cls_logits, bbox_preds = model(images)
            # You can expand this into full mAP calculation
            print("Sample Predictions:", cls_logits.mean().item(), bbox_preds.mean().item())

# ----------- Step 6: Pipeline -----------

if __name__ == "__main__":
    # Dataset paths (placeholders; point them at the local dataset folders).
    train_img_dir = "/path_to/FLIR/train/images"
    train_annot_dir = "/path_to/FLIR/train/annotations"

    val_img_dir = "/path_to/FLIR/val/images"
    val_annot_dir = "/path_to/FLIR/val/annotations"

    # NOTE(review): FLIRDataset applies the transform to the image only, so
    # the boxes stay in the original pixel coordinates after this resize.
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])

    train_dataset = FLIRDataset(train_img_dir, train_annot_dir, transform=transform)
    val_dataset = FLIRDataset(val_img_dir, val_annot_dir, transform=transform)

    # NOTE(review): no collate_fn is given, so the default collation stacks
    # the per-image target tensors; presumably this requires every image in
    # a batch to have the same number of boxes. Confirm against the data.
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)

    # Model: use the GPU when one is available, otherwise fall back to CPU
    # instead of failing on an unconditional .cuda() call.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SimpleDetector(num_classes=3).to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    # Train
    train(model, train_loader, optimizer, epochs=10)

    # Evaluate
    evaluate(model, val_loader)
