In [1]:
import os
from glob import glob
from typing import List, Dict, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision.models.detection import fasterrcnn_resnet50_fpn
import torchvision.transforms.functional as TF
from PIL import Image


from torchvision.models.detection.rpn import AnchorGenerator
from collections import OrderedDict



## Gradient reversal layer (GRL)

In [13]:
# Remove training samples that cannot be used: images without a label file,
# unreadable images, unparsable label files and samples whose boxes are
# rejected by `is_invalid`.
#
# NOTE(review): `img_dir`, `lbl_dir`, `load_yolo_label` and `is_invalid` are not
# defined in the visible part of this notebook; they must already exist in the
# kernel before this cell runs.
# WARNING: this cell permanently deletes files from disk.


def _delete_sample(img_path, lbl_path=None):
    """Delete an image and, optionally, its label file.

    The image is removed first; the label is only touched once the image is
    gone, so a locked image never ends up without its label.

    Returns:
        True when the image was removed, False otherwise.
    """
    try:
        os.remove(img_path)
    except OSError as err:
        print("Could not delete (in use):", err)
        return False

    if lbl_path is not None and os.path.exists(lbl_path):
        try:
            os.remove(lbl_path)
        except OSError as err:
            # A left-over label is reported but does not abort the scan.
            print("Could not delete (in use):", err)
    return True


# Number of images that were actually removed (failed deletions are not counted).
deleted = 0

for fname in os.listdir(img_dir):
    if not fname.lower().endswith((".jpg", ".png", ".jpeg")):
        continue

    img_path = os.path.join(img_dir, fname)
    lbl_path = os.path.join(lbl_dir, os.path.splitext(fname)[0] + ".txt")

    # An image without a label file is removed.
    if not os.path.exists(lbl_path):
        print("Missing label â†’ DELETING IMAGE:", img_path)
        if _delete_sample(img_path):
            deleted += 1
        continue

    # Read the image size inside a context manager so the file handle is
    # released before any deletion is attempted.
    try:
        with Image.open(img_path) as img:
            w, h = img.size
    except Exception:
        print("Unreadable image â†’ DELETING:", img_path)
        if _delete_sample(img_path, lbl_path):
            deleted += 1
        continue

    # Parse the label file; a parsing failure removes the whole sample.
    try:
        boxes = load_yolo_label(lbl_path, w, h)
    except Exception:
        print("Invalid label â†’ deleting:", lbl_path)
        if _delete_sample(img_path, lbl_path):
            deleted += 1
        continue

    # Finally drop samples whose boxes fail the validity check.
    if is_invalid(boxes, w, h):
        print("Invalid boxes â†’ DELETING:", img_path)
        if _delete_sample(img_path, lbl_path):
            deleted += 1

print(f"\nðŸ”¥ Deleted {deleted} corrupted images & labels.")


Invalid boxes â†’ DELETING: C:\Users\giorg\Downloads\drone-detection-distance-template\drone-detection-distance\data\virtual\train\images\test_W0_T0_F60_S05_000009.jpg
Invalid boxes â†’ DELETING: C:\Users\giorg\Downloads\drone-detection-distance-template\drone-detection-distance\data\virtual\train\images\test_W1_T2_F75_S120_000002.jpg
Invalid boxes â†’ DELETING: C:\Users\giorg\Downloads\drone-detection-distance-template\drone-detection-distance\data\virtual\train\images\train_W0_T1_F60_S24_000007.jpg
Invalid boxes â†’ DELETING: C:\Users\giorg\Downloads\drone-detection-distance-template\drone-detection-distance\data\virtual\train\images\train_W0_T1_F75_S28_000003.jpg
Invalid boxes â†’ DELETING: C:\Users\giorg\Downloads\drone-detection-distance-template\drone-detection-distance\data\virtual\train\images\train_W0_T2_F60_S41_000006.jpg
Invalid boxes â†’ DELETING: C:\Users\giorg\Downloads\drone-detection-distance-template\drone-detection-distance\data\virtual\train\images\train_W0_T3_F60_S5

In [2]:
import torch
import torch.nn as nn
import torch.autograd as autograd

# --- 1. The Autograd Function (No Change Needed Here) ---
class GradientReverseFn(autograd.Function):
    """Autograd function that is the identity going forward and flips the gradient going back.

    The backward pass multiplies the incoming gradient by ``-lambda_``; no
    gradient is returned for ``lambda_`` itself.
    """

    @staticmethod
    def forward(ctx, x, lambda_):
        # Stash the scaling factor on the context; backward() reads it back.
        ctx.lambda_ = lambda_
        identity = x.view_as(x)
        return identity

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward input: reversed/scaled for x, None for lambda_.
        reversed_grad = grad_output * (-ctx.lambda_)
        return reversed_grad, None


# --- 2. The Module Wrapper (Refined to accept lambda in forward) ---
class GradientReverse(nn.Module):
    """Stateless ``nn.Module`` front-end for ``GradientReverseFn``.

    The reversal strength is supplied on every call rather than stored on the
    module, so the caller can change it from one step to the next.
    """

    def forward(self, x, lambda_):
        # Hand both the features and the current lambda to the autograd function.
        reverse = GradientReverseFn.apply
        return reverse(x, lambda_)

## Domain-adaptive (DA) Faster R-CNN


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.rpn import AnchorGenerator

# Import the refined GRL from the file above
# Assuming GradientReverse is available in the scope

class DAFasterRCNN_Global_Stable(nn.Module):
    """Faster R-CNN with an image-level adversarial domain classifier.

    In training mode ``forward`` returns a dict of losses: the detector's own
    losses (only when ``domain == "source"``) plus ``"loss_da_im"``, the
    cross-entropy of a small domain classifier fed with globally pooled
    backbone features passed through a gradient reversal layer.
    In eval mode ``forward`` returns the detector's detections.

    Args:
        num_classes: forwarded to ``fasterrcnn_resnet50_fpn``.
        lambda_da_max: gradient-reversal coefficient reached after warm-up.
        da_warmup_steps: number of steps over which the coefficient ramps
            linearly from 0 to ``lambda_da_max``.
    """

    # Domain names accepted by forward() while training.
    _DOMAINS = ("source", "target")

    def __init__(self, num_classes, lambda_da_max=0.01, da_warmup_steps=1000):
        super().__init__()
        self.lambda_da_max = lambda_da_max
        self.da_warmup_steps = da_warmup_steps
        self.current_step = 0  # updated from outside via set_current_step()

        # ----------- BASE DETECTOR -------------
        # Small anchors; one size tuple per feature map.
        anchor_sizes = (
            (8, 16, 32),
            (16, 32, 64),
            (32, 64, 128),
            (64, 128, 256),
            (128, 256, 512),
        )
        anchor_generator = AnchorGenerator(
            sizes=anchor_sizes,
            aspect_ratios=((0.5, 1.0, 2.0),) * len(anchor_sizes),
        )

        self.detector = fasterrcnn_resnet50_fpn(
            weights=None,
            num_classes=num_classes,
            rpn_anchor_generator=anchor_generator,
        )

        self.grl = GradientReverse()

        # The domain head is created here, not on the first forward pass, so
        # that its parameters are part of model.parameters() / state_dict()
        # from the start. A head created lazily would be missing from any
        # optimizer built right after construction and would never be trained.
        # Assumes the backbone returns one feature map per anchor size tuple,
        # each with `backbone.out_channels` channels.
        input_dim = len(anchor_sizes) * self.detector.backbone.out_channels
        self.im_domain_head = self._make_domain_head(input_dim)

    @staticmethod
    def _make_domain_head(input_dim):
        """Return the two-way domain classifier for `input_dim` pooled features."""
        return nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 2),
        )

    def set_current_step(self, step):
        """Record the current training iteration used by warmup_lambda()."""
        self.current_step = step

    def warmup_lambda(self):
        """Return the gradient-reversal coefficient for the current step.

        Linear ramp from 0 to ``lambda_da_max`` over ``da_warmup_steps`` steps,
        constant afterwards.
        """
        if self.current_step >= self.da_warmup_steps:
            return self.lambda_da_max

        progress = self.current_step / self.da_warmup_steps
        return progress * self.lambda_da_max

    def _build_domain_head(self, feature_dict):
        """Rebuild the domain head to match the given backbone feature dict.

        Kept for backward compatibility. The head is replaced by a freshly
        initialised one, so any optimizer created earlier does not know about
        the new parameters.
        """
        num_scales = len(feature_dict)
        feat_dim = next(iter(feature_dict.values())).shape[1]
        input_dim = num_scales * feat_dim

        self.im_domain_head = self._make_domain_head(input_dim).to(
            next(self.parameters()).device
        )

    def forward(self, images, targets=None, domain=None):
        """Compute training losses or inference detections.

        Args:
            images: list of image tensors.
            targets: list of target dicts; required for ``domain="source"``
                while training.
            domain: ``"source"`` or ``"target"``; required while training.

        Returns:
            Training mode: dict of loss tensors. Eval mode: the detector's
            detections.

        Raises:
            ValueError: in training mode when ``domain`` is not ``"source"``
                or ``"target"``.
        """
        if not self.training:
            # Inference needs neither the pooled features nor the domain head,
            # so go straight to the detector. Its previous mode is restored
            # afterwards instead of being forced to train mode.
            was_training = self.detector.training
            self.detector.eval()
            try:
                with torch.no_grad():
                    return self.detector(images)
            finally:
                self.detector.train(was_training)

        if domain not in self._DOMAINS:
            raise ValueError("Need domain='source' or 'target' during training.")

        device = images[0].device

        # --- Transform + backbone (features for the domain classifier) ---
        images_t, _ = self.detector.transform(images, targets)
        features = self.detector.backbone(images_t.tensors)
        if isinstance(features, torch.Tensor):
            features = OrderedDict([("0", features)])

        losses = {}

        # --- DETECTION (SOURCE ONLY) ---
        # NOTE(review): the detector runs its own transform and backbone pass
        # here, so source batches go through the backbone twice per step.
        if domain == "source":
            losses.update(self.detector(images, targets))

        # --- DOMAIN ADAPTATION LOSS ---
        # Domain label: 0 = source, 1 = target.
        dom_label = 0 if domain == "source" else 1
        dom_label_tensor = torch.full(
            (len(images),), dom_label, dtype=torch.long, device=device
        )

        # Global average pool every feature map and concatenate along channels.
        pooled = [
            F.adaptive_avg_pool2d(f, 1).flatten(1)
            for f in features.values()
        ]
        im_feat = torch.cat(pooled, dim=1)

        # Gradient reversal with the warm-up coefficient for this step.
        rev_feat = self.grl(im_feat, self.warmup_lambda())

        logits = self.im_domain_head(rev_feat)
        losses["loss_da_im"] = F.cross_entropy(logits, dom_label_tensor)
        return losses

## Dataset for Faster R-CNN

In [4]:
class YoloTxtDetectionDataset(Dataset):
    """Detection dataset reading images and YOLO-format ``.txt`` labels.

    Each item is ``(image_tensor, target)`` where ``target`` holds
    ``boxes`` (absolute ``x1, y1, x2, y2``), ``labels``, ``image_id``,
    ``area`` and ``iscrowd``.
    """

    def __init__(self, img_dir: str, label_dir: str, transforms=None, label_offset: int = 0):
        """
        img_dir: folder with images (.jpg / .png / .jpeg, case-insensitive)
        label_dir: folder with YOLO txt labels named after the image stem
        transforms: callable(img, target) -> (img, target)
        label_offset: integer added to every class id read from the label
            files. The default 0 keeps the ids unchanged. YOLO class ids are
            usually 0-based while torchvision detection models reserve label 0
            for background, so pass 1 when the files use 0-based ids.
        """
        self.img_paths = sorted(
            p
            for p in glob(os.path.join(img_dir, "*"))
            if p.lower().endswith((".jpg", ".png", ".jpeg"))
        )
        self.label_dir = label_dir
        self.transforms = transforms
        self.label_offset = label_offset

    def __len__(self):
        """Number of images found in ``img_dir``."""
        return len(self.img_paths)

    def _load_yolo_labels(self, img_path: str, w: int, h: int):
        """Read the label file belonging to ``img_path``.

        Returns ``(boxes, labels)``: a float32 ``(N, 4)`` tensor of absolute
        ``x1, y1, x2, y2`` boxes and an int64 ``(N,)`` tensor of class ids.
        A missing or empty label file yields empty tensors.

        Raises:
            ValueError: when a non-blank line cannot be parsed as
                ``class cx cy w h``.
        """
        base = os.path.splitext(os.path.basename(img_path))[0]
        label_path = os.path.join(self.label_dir, base + ".txt")

        boxes = []
        labels = []

        if os.path.exists(label_path):
            with open(label_path, "r") as f:
                for line_no, raw in enumerate(f, start=1):
                    fields = raw.split()
                    if not fields:
                        continue  # blank line
                    try:
                        cls = int(fields[0])
                        x_c, y_c, bw, bh = map(float, fields[1:5])
                    except ValueError as err:
                        # Re-raise with the location so bad files are easy to find.
                        raise ValueError(
                            f"{label_path}:{line_no}: malformed YOLO label line {raw.strip()!r}"
                        ) from err

                    # Normalised cx, cy, w, h -> absolute x1, y1, x2, y2.
                    x_c *= w
                    y_c *= h
                    bw *= w
                    bh *= h
                    boxes.append([x_c - bw / 2, y_c - bh / 2, x_c + bw / 2, y_c + bh / 2])
                    labels.append(cls + self.label_offset)

        if boxes:
            return (
                torch.tensor(boxes, dtype=torch.float32),
                torch.tensor(labels, dtype=torch.int64),
            )
        # No label file or no boxes: keep the (0, 4) / (0,) shapes.
        return torch.empty((0, 4), dtype=torch.float32), torch.empty((0,), dtype=torch.int64)

    def __getitem__(self, idx):
        """Load image ``idx`` and build its target dict."""
        img_path = self.img_paths[idx]
        # Close the file as soon as the pixel data has been converted.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")
        w, h = img.size

        boxes, labels = self._load_yolo_labels(img_path, w, h)

        target = {
            "boxes": boxes,
            "labels": labels,
            "image_id": torch.tensor([idx]),
            # Also valid for an empty (0, 4) tensor: the result is an empty float tensor.
            "area": (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]),
            "iscrowd": torch.zeros((boxes.shape[0],), dtype=torch.int64),
        }

        if self.transforms:
            img, target = self.transforms(img, target)

        # PIL -> tensor, [C, H, W], float32 in [0, 1].
        img = TF.to_tensor(img)
        return img, target


def det_collate(batch):
    """Collate ``(image, target)`` pairs into ``(list_of_images, list_of_targets)``.

    Images are kept in a plain list instead of being stacked, so samples in a
    batch do not need to share a size.
    """
    columns = tuple(zip(*batch))
    imgs, targets = columns
    return [*imgs], [*targets]


## Datasets

In [5]:
# NOTE(review): machine-specific absolute path; every dataset folder below hangs off it.
_DATA_ROOT = r"C:\Users\giorg\Downloads\drone-detection-distance-template\drone-detection-distance\data"

# Real-world images and their labels.
real_img_dir = _DATA_ROOT + r"\real\train\images"
real_lbl_dir = _DATA_ROOT + r"\real\train\labels"

# Virtual images and their labels.
virt_img_dir = _DATA_ROOT + r"\virtual\train\images"
virt_lbl_dir = _DATA_ROOT + r"\virtual\train\labels"

# Virtual data is the source domain, real data is the target domain.
source_dataset = YoloTxtDetectionDataset(img_dir=virt_img_dir, label_dir=virt_lbl_dir)
target_dataset = YoloTxtDetectionDataset(img_dir=real_img_dir, label_dir=real_lbl_dir)  # labels ignored in training


In [6]:
from torch.utils.data import DataLoader
import torch
import itertools
import torch.nn.utils as utils


# Use the GPU when one is visible, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- Hyper-parameters ---
num_classes = 2       # e.g. background + 1 drone class -> check your IDs
da_weight = 0.001     # weight of the domain loss in the total loss (0.01 was exploding)
lambda_da_max = 1     # gradient-reversal coefficient reached once warm-up is over
warmup_steps = 1000   # iterations of the linear lambda ramp

# --- Model ---
model = DAFasterRCNN_Global_Stable(
    num_classes=num_classes,
    lambda_da_max=lambda_da_max,
    da_warmup_steps=warmup_steps,
)
model = model.to(device)

# --- Optimizer ---
# Only parameters that exist on the model at this point are handed to SGD.
_sgd_settings = dict(
    lr=0.001,  # kept low for stability
    momentum=0.9,
    weight_decay=0.0005,
)
optimizer = torch.optim.SGD(model.parameters(), **_sgd_settings)

# --- Data loaders (same settings for both domains) ---
_loader_settings = dict(batch_size=2, shuffle=True, collate_fn=det_collate)
source_loader = DataLoader(source_dataset, **_loader_settings)
target_loader = DataLoader(target_dataset, **_loader_settings)


import torch
import torch.nn.utils as utils
import itertools

def _endless_batches(loader):
    """Yield batches from `loader` forever, starting a new pass when one ends.

    Unlike `itertools.cycle`, nothing is cached: every pass calls `iter(loader)`
    again, so memory use stays flat and a shuffling loader reshuffles.
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            # Without this guard an empty loader would spin here forever.
            raise ValueError("target_loader yielded no batches")


def train_da_stable(
    model,
    source_loader,
    target_loader,
    optimizer,
    num_epochs=10,
    da_weight=0.001,
    device="cuda"
):
    """Train a domain-adaptive detector on labelled source and unlabelled target batches.

    Every step runs one source batch (detection losses + domain loss) and one
    target batch (domain loss only) and optimises
    ``det_loss + da_weight * (da_src + da_tgt)``. A checkpoint named
    ``da_frcnn_epoch_{epoch}.pth`` is written to the working directory after
    each epoch.

    Args:
        model: module called as ``model(images, targets, domain=...)`` that
            returns a dict of losses and provides ``set_current_step`` and
            ``warmup_lambda``.
        source_loader: iterable of ``(images, targets)`` batches; one pass is
            one epoch.
        target_loader: iterable of ``(images, _)`` batches; re-iterated as
            often as needed.
        optimizer: optimizer stepped once per source batch.
        num_epochs: number of passes over ``source_loader``.
        da_weight: weight of the summed domain losses.
        device: device the batches are moved to.

    Raises:
        ValueError: if ``target_loader`` yields no batches.
    """
    model.train()
    target_iter = _endless_batches(target_loader)
    global_step = 0

    # Fallback for any loss key missing from the model's output.
    zero = torch.tensor(0.0, device=device)
    # Defined up front so the checkpoint code works even if source_loader is empty.
    total_loss = zero

    for epoch in range(num_epochs):

        for src_imgs, src_tgts in source_loader:

            # Keep the model's lambda schedule in sync with the iteration count.
            global_step += 1
            model.set_current_step(global_step)

            src_imgs = [img.to(device) for img in src_imgs]
            src_tgts = [{k: v.to(device) for k, v in t.items()} for t in src_tgts]

            # --- SOURCE FORWARD PASS (task loss + DA loss) ---
            src_loss_dict = model(src_imgs, src_tgts, domain="source")

            det_loss = (
                src_loss_dict.get("loss_classifier", zero) +
                src_loss_dict.get("loss_box_reg", zero) +
                src_loss_dict.get("loss_objectness", zero) +
                src_loss_dict.get("loss_rpn_box_reg", zero)
            )
            da_src = src_loss_dict.get("loss_da_im", zero)

            # --- TARGET FORWARD PASS (DA loss only) ---
            tgt_imgs, _ = next(target_iter)
            tgt_imgs = [img.to(device) for img in tgt_imgs]

            tgt_loss_dict = model(tgt_imgs, targets=None, domain="target")
            da_tgt = tgt_loss_dict.get("loss_da_im", zero)

            # --- TOTAL LOSS & OPTIMIZATION ---
            total_loss = det_loss + da_weight * (da_src + da_tgt)

            optimizer.zero_grad()
            total_loss.backward()

            # Clip the global gradient norm before stepping.
            utils.clip_grad_norm_(model.parameters(), max_norm=5.0)

            optimizer.step()

            # --- LOGGING ---
            if global_step % 10 == 0:
                current_lambda = model.warmup_lambda()
                print(
                    f"[Epoch {epoch} Step {global_step}] "
                    f"det={det_loss.item():.4f} "
                    f"DA_src={da_src.item():.4f} "
                    f"DA_tgt={da_tgt.item():.4f} "
                    f"lambda={current_lambda:.6f} "
                    f"total={total_loss.item():.4f}"
                )

        # --- CHECKPOINT (once per epoch, after the last batch) ---
        save_path = f"da_frcnn_epoch_{epoch}.pth"
        torch.save({
            'epoch': epoch,
            'global_step': global_step,  # needed to resume the lambda schedule
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': total_loss.item(),  # last loss computed in this epoch
        }, save_path)
        print(f"\nModel checkpoint saved to {save_path}\n")

    print("\nTraining complete.")

# Call training:
#train_da_stable(model, source_loader, target_loader, optimizer, num_epochs=10)


In [7]:
# Run 10 epochs of domain-adaptive training with the model, loaders, optimizer,
# DA weight and device defined earlier in the notebook.
train_da_stable(model, source_loader, target_loader, optimizer, num_epochs=10, da_weight=da_weight, device=device)

[Epoch 0 Step 10] det=0.7699 DA_src=0.7602 DA_tgt=0.6704 lambda=0.010000 total=0.7713
[Epoch 0 Step 20] det=0.6884 DA_src=0.7575 DA_tgt=0.6658 lambda=0.020000 total=0.6898
[Epoch 0 Step 30] det=0.2815 DA_src=0.9188 DA_tgt=0.5210 lambda=0.030000 total=0.2829
[Epoch 0 Step 40] det=0.0570 DA_src=1.8915 DA_tgt=0.2334 lambda=0.040000 total=0.0591
[Epoch 0 Step 50] det=0.7231 DA_src=2.0333 DA_tgt=0.1388 lambda=0.050000 total=0.7253
[Epoch 0 Step 60] det=0.1527 DA_src=2.0868 DA_tgt=0.1036 lambda=0.060000 total=0.1549


KeyboardInterrupt: 

## Check corrupted images/labels


In [11]:
# Scan every source sample and report suspicious boxes.
# Each entry is (predicate, message); predicates run lazily, in this order.
_box_checks = (
    (lambda b: torch.isnan(b).any(), "NaN boxes at index:"),
    (
        lambda b: (b[:, 2] <= b[:, 0]).any() or (b[:, 3] <= b[:, 1]).any(),
        "Invalid box size at index:",
    ),
    (lambda b: (b < 0).any(), "Negative box coordinate at index:"),
)

for idx in range(len(source_dataset)):
    try:
        _, sample_target = source_dataset[idx]
        sample_boxes = sample_target["boxes"]

        for is_bad, message in _box_checks:
            if is_bad(sample_boxes):
                print(message, idx)

    except Exception as err:
        # The scan stops at the first sample that cannot be loaded or checked.
        print("Error at index:", idx, "->", str(err))
        break


Negative box coordinate at index: 26
Negative box coordinate at index: 38
Negative box coordinate at index: 148
Negative box coordinate at index: 159
Negative box coordinate at index: 337
Negative box coordinate at index: 595
Negative box coordinate at index: 676
Negative box coordinate at index: 1099
Negative box coordinate at index: 1308
Negative box coordinate at index: 1604
Negative box coordinate at index: 1682
Invalid box size at index: 1835
Invalid box size at index: 2046
Invalid box size at index: 2066
Invalid box size at index: 2136
Invalid box size at index: 2702
Invalid box size at index: 2793
