In [None]:
import os
import sys
import shutil
import time
import yaml
import glob
import random
import numpy as np
import pandas as pd
import cv2
import matplotlib.pyplot as plt
import torch
import torchvision

# Install dependencies if missing
def install_dependencies():
    os.system("pip install -q albumentations==1.4.0 torchmetrics")

# Third-party packages that may be missing on a fresh runtime: try the imports first,
# and only if one of them is unavailable, install the packages and retry the same
# imports (a second ImportError propagates).
try:
    import albumentations as A
    from albumentations.pytorch import ToTensorV2
    from torchmetrics.detection.mean_ap import MeanAveragePrecision
except ImportError:
    print("Installing dependencies...")
    install_dependencies()
    import albumentations as A
    from albumentations.pytorch import ToTensorV2
    from torchmetrics.detection.mean_ap import MeanAveragePrecision

from torch.utils.data import Dataset, DataLoader
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.models.detection.backbone_utils import mobilenet_backbone
from torchvision.models.detection.faster_rcnn import TwoMLPHead, FastRCNNPredictor
from google.colab import drive

# Mount Google Drive (Colab). The call is skipped when /content/drive already exists,
# e.g. when this cell is re-run in a session where Drive is already mounted.
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

# Reproducibility Setup
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch random number generators.

    When CUDA is available, every GPU generator is seeded as well, and cuDNN is
    switched to deterministic kernels with autotuning (``benchmark``) turned off.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

# --- Configuration -----------------------------------------------------------
# Outputs (checkpoints, metric logs, plots) are written to Drive; the dataset is
# read from a local copy of the Drive folder that contains dataset.yaml.
BASE_SAVE_DIR = "/content/drive/MyDrive/Model/FasterRCNN_Small"
DRIVE_YAML_PATH = "/content/drive/MyDrive/Dataset/FINAL_YOLO_SPLIT/dataset.yaml"
LOCAL_DATA_DIR = "/content/local_dataset"

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Computation Device: {DEVICE}")

# Label spaces:
#   YOLO label files : 0 = Brain (ignored), 1 = CSP, 2 = LV
#   Model            : 0 = background,      1 = CSP, 2 = LV
# CSP and LV keep their ids; Brain has no ID_MAPPING entry, so its boxes are dropped.
CLASS_NAMES = ['CSP', 'LV']
NUM_CLASSES = 1 + len(CLASS_NAMES)  # +1 for the background class
ID_MAPPING = {class_id: class_id for class_id in (1, 2)}

os.makedirs(BASE_SAVE_DIR, exist_ok=True)

In [None]:
# Copy the dataset folder (the directory holding dataset.yaml) from Drive to the local
# runtime disk, so training reads from local storage rather than the Drive mount.
# The copy goes into a staging directory that is renamed into place only once it has
# completed, so an interrupted or failed copy never leaves a partial LOCAL_DATA_DIR
# that a later run would treat as an existing, complete dataset.
if not os.path.exists(LOCAL_DATA_DIR):
    print(f"Copying dataset to local runtime: {LOCAL_DATA_DIR}...")
    drive_data_dir = os.path.dirname(DRIVE_YAML_PATH)
    staging_dir = LOCAL_DATA_DIR + ".partial"
    # Remove leftovers from a previously interrupted attempt.
    shutil.rmtree(staging_dir, ignore_errors=True)
    try:
        shutil.copytree(drive_data_dir, staging_dir)
        os.rename(staging_dir, LOCAL_DATA_DIR)
    except Exception as e:
        shutil.rmtree(staging_dir, ignore_errors=True)
        print(f"Error copying dataset: {e}")
        # Without the data, the datasets below would glob no images and training
        # would silently run on empty loaders; fail here instead.
        raise
    print("Dataset setup complete.")
else:
    print(f"Local dataset found at {LOCAL_DATA_DIR}.")

In [None]:
class YOLODataset(Dataset):
    """Detection dataset that reads YOLO-format labels for torchvision detectors.

    Layout: ``<root_dir>/<split>/images/*.jpg|*.png`` with a matching
    ``<root_dir>/<split>/labels/<stem>.txt`` per image, each row being
    ``class x_center y_center width height`` (normalized coordinates). Only YOLO
    class ids present in ``ID_MAPPING`` are kept (and remapped); other rows are skipped.

    Each item is ``(image, target)``:
      * ``image``: CHW tensor from ``ToTensorV2``/the transform, scaled to [0, 1]
        when it arrives as uint8.
      * ``target``: ``boxes`` (absolute xyxy, float32), ``labels`` (int64),
        ``image_id``, ``area`` and ``iscrowd``.
    """

    def __init__(self, root_dir, split='train', transforms=None):
        self.root_dir = root_dir
        self.split = split
        self.transforms = transforms
        self.img_dir = os.path.join(root_dir, split, 'images')
        self.label_dir = os.path.join(root_dir, split, 'labels')

        # Sorted so that indices (and therefore image_id) are stable across runs.
        self.img_files = sorted(glob.glob(os.path.join(self.img_dir, "*.jpg")) +
                                glob.glob(os.path.join(self.img_dir, "*.png")))

        # YOLO class id -> model class id (0 is reserved for background).
        self.target_mapping = ID_MAPPING

    def __len__(self):
        return len(self.img_files)

    def _read_image(self, idx):
        """Return ``(idx, rgb_image)`` for the first decodable image at or after ``idx``.

        Unreadable files are skipped by moving on to the next index (wrapping
        around), with at most one pass over the split instead of unbounded recursion.

        Raises:
            RuntimeError: if no image in the split can be decoded.
        """
        num_files = len(self.img_files)
        for offset in range(num_files):
            candidate = (idx + offset) % num_files
            image = cv2.imread(self.img_files[candidate])
            if image is not None:
                return candidate, cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        raise RuntimeError(f"No readable images found in {self.img_dir}")

    def _read_labels(self, label_path, w, h):
        """Parse one YOLO label file into absolute xyxy boxes and model class ids.

        Blank or malformed rows (fewer than 5 fields) are skipped. Boxes are
        clipped to the ``w`` x ``h`` image and dropped if clipping leaves no area.
        A missing label file yields no boxes.
        """
        boxes, labels = [], []
        if not os.path.exists(label_path):
            return boxes, labels

        with open(label_path, 'r') as f:
            for line in f:
                parts = line.split()
                if len(parts) < 5:
                    continue
                cls_id_raw = int(float(parts[0]))
                if cls_id_raw not in self.target_mapping:
                    continue

                x_c, y_c, bw, bh = map(float, parts[1:5])

                # Normalized center/size -> absolute corners, clipped to the image.
                x_min = max(0, (x_c - bw / 2) * w)
                y_min = max(0, (y_c - bh / 2) * h)
                x_max = min(w, (x_c + bw / 2) * w)
                y_max = min(h, (y_c + bh / 2) * h)

                if x_max > x_min and y_max > y_min:
                    boxes.append([x_min, y_min, x_max, y_max])
                    labels.append(self.target_mapping[cls_id_raw])
        return boxes, labels

    def __getitem__(self, idx):
        idx, image = self._read_image(idx)
        h, w, _ = image.shape

        file_name = os.path.basename(self.img_files[idx])
        label_path = os.path.join(self.label_dir, os.path.splitext(file_name)[0] + ".txt")
        boxes, labels = self._read_labels(label_path, w, h)

        # Albumentations receives plain Python lists (not torch tensors) for the
        # pascal_voc boxes and their label field.
        transformed = None
        if self.transforms:
            try:
                transformed = self.transforms(image=image, bboxes=boxes, labels=labels)
            except Exception as exc:
                # Best effort: keep the sample usable (untransformed image with its
                # untransformed boxes, which stay consistent), but make it visible.
                import warnings
                warnings.warn(
                    f"Transform failed for {file_name} ({exc!r}); "
                    f"using the untransformed image and boxes."
                )

        if transformed is not None:
            image = transformed['image']
            boxes = [list(box[:4]) for box in transformed['bboxes']]
            labels = list(transformed['labels'])
        else:
            image = ToTensorV2()(image=image)["image"]

        # reshape keeps the empty case at (0, 4) / (0,).
        boxes_t = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        labels_t = torch.as_tensor(labels, dtype=torch.int64).reshape(-1)

        # Drop boxes that geometric augmentation collapsed to zero width/height;
        # torchvision's detection models reject such boxes in training mode.
        keep = (boxes_t[:, 2] > boxes_t[:, 0]) & (boxes_t[:, 3] > boxes_t[:, 1])
        boxes_t, labels_t = boxes_t[keep], labels_t[keep]

        target = {
            "boxes": boxes_t,
            "labels": labels_t,
            "image_id": torch.tensor([idx]),
            "area": (boxes_t[:, 3] - boxes_t[:, 1]) * (boxes_t[:, 2] - boxes_t[:, 0]),
            "iscrowd": torch.zeros((len(boxes_t),), dtype=torch.int64),
        }

        # ToTensorV2 keeps the uint8 dtype; scale pixel values to [0, 1].
        if isinstance(image, torch.Tensor) and image.dtype == torch.uint8:
            image = image.float() / 255.0

        return image, target

def collate_fn(batch):
    """Transpose a list of ``(image, target)`` samples into ``(images, targets)`` tuples.

    Targets are dicts holding a variable number of boxes, which the default
    stacking collate cannot batch, so each field is returned as a tuple instead.
    """
    transposed = zip(*batch)
    return tuple(transposed)

In [None]:
def get_transforms(condition='Raw'):
    """Build the albumentations pipeline for one training condition.

    Args:
        condition: ``'Tuned'`` prepends geometric augmentation (affine, horizontal
            flip, perspective) to the resize; any other value returns the plain
            resize-and-tensorize pipeline.

    Returns:
        An ``A.Compose`` called with ``image``, ``bboxes`` (pascal_voc, absolute
        xyxy) and ``labels``. Boxes that keep less than 10% of their area after
        augmentation are dropped (``min_visibility=0.1``).
    """
    bbox_params = A.BboxParams(format='pascal_voc', label_fields=['labels'], min_visibility=0.1)

    # Shared tail: fixed 640x640 size, then HWC numpy -> CHW tensor.
    base_ops = [
        A.Resize(height=640, width=640),
        ToTensorV2()
    ]

    if condition != 'Tuned':
        return A.Compose(base_ops, bbox_params=bbox_params)

    aug_ops = [
        A.Affine(
            scale=(0.6, 1.4),
            # Independent, symmetric shifts per axis. A bare (0, 0.2) tuple samples a
            # single non-negative fraction used for both x and y, i.e. it only ever
            # shifts the image towards the bottom-right.
            translate_percent={"x": (-0.2, 0.2), "y": (-0.2, 0.2)},
            rotate=(-45, 45),
            shear=(-5, 5),
            p=1.0
        ),
        A.HorizontalFlip(p=0.5),
        A.Perspective(scale=(0.05, 0.1), p=0.5)
    ]
    return A.Compose(aug_ops + base_ops, bbox_params=bbox_params)

In [None]:
def get_model_mobilenetv3_small(num_classes):
    """Assemble a lightweight Faster R-CNN on a pretrained MobileNetV3-Small FPN backbone.

    The FPN's output feature maps are discovered with a dry run on a 640x640
    random input; the anchor generator and the multi-scale RoI pooler are then
    sized to exactly those maps. The default box head is replaced by a slimmer
    two-layer MLP with a 512-wide representation.

    Args:
        num_classes: number of classes, background included.

    Returns:
        The assembled ``FasterRCNN`` model.
    """
    backbone = mobilenet_backbone(
        backbone_name="mobilenet_v3_small",
        weights="DEFAULT",
        fpn=True,
    )
    backbone.out_channels = 256

    # Dry run: learn the names (and hence the count) of the emitted feature maps.
    probe = torch.randn(1, 3, 640, 640).to(next(backbone.parameters()).device)
    with torch.no_grad():
        level_names = list(backbone(probe).keys())
    n_levels = len(level_names)

    print(f"Backbone feature maps detected: {n_levels}")

    # One anchor size per feature map, assigned in feature-map order, and the same
    # three aspect ratios on every level.
    candidate_sizes = [32, 64, 128, 256, 512, 640]
    anchor_generator = AnchorGenerator(
        sizes=tuple((size,) for size in candidate_sizes[:n_levels]),
        aspect_ratios=tuple((0.5, 1.0, 2.0) for _ in range(n_levels)),
    )

    pooled_size = 7
    roi_pooler = torchvision.ops.MultiScaleRoIAlign(
        featmap_names=level_names,
        output_size=pooled_size,
        sampling_ratio=2,
    )

    model = FasterRCNN(
        backbone,
        num_classes=num_classes,
        rpn_anchor_generator=anchor_generator,
        box_roi_pool=roi_pooler,
    )

    # Slimmer box head: flattened pooled features -> 512-d MLP -> class/box outputs.
    hidden_dim = 512
    model.roi_heads.box_head = TwoMLPHead(backbone.out_channels * pooled_size * pooled_size, hidden_dim)
    model.roi_heads.box_predictor = FastRCNNPredictor(hidden_dim, num_classes)

    return model

In [None]:
def evaluate_map_complete(model, dataloader, device):
    """Compute COCO-style mAP@[.50:.95] and mAP@.50, globally and per class.

    Runs ``model`` in eval mode over ``dataloader`` once and feeds the same
    predictions to two torchmetrics accumulators: one over the default IoU
    thresholds and one restricted to IoU 0.5.

    Returns:
        dict with
          * ``'map'`` / ``'map_50'``: global mAP@[.50:.95] and mAP@.50 (floats);
          * ``'map_per_class'`` / ``'map_50_per_class'``: 1-D tensors with one value
            per class in ``'classes'`` order (torchmetrics only reports classes that
            appear in predictions or targets, and uses -1 where undefined);
          * ``'classes'``: list of the class ids those entries belong to, or ``None``
            if the installed torchmetrics does not report them.
    """
    model.eval()
    metric_global = MeanAveragePrecision(class_metrics=True).to(device)
    metric_50 = MeanAveragePrecision(class_metrics=True, iou_thresholds=[0.5]).to(device)

    with torch.no_grad():
        for images, targets in dataloader:
            images = [img.to(device) for img in images]
            # The metric only needs boxes and labels.
            t_clean = [{k: v.to(device) for k, v in t.items() if k in ('boxes', 'labels')} for t in targets]

            outputs = model(images)
            metric_global.update(outputs, t_clean)
            metric_50.update(outputs, t_clean)

    res_global = metric_global.compute()
    res_50 = metric_50.compute()

    # Both accumulators saw identical data, so they report the same class set.
    classes = res_global.get('classes')
    return {
        'map': res_global['map'].item(),
        'map_50': res_global['map_50'].item(),
        # atleast_1d: with a single class present the per-class value can be 0-d.
        'map_per_class': torch.atleast_1d(res_global['map_per_class']),
        'map_50_per_class': torch.atleast_1d(res_50['map_per_class']),
        'classes': torch.atleast_1d(classes).tolist() if classes is not None else None,
    }

def evaluate_best_f1(model, dataloader, device, num_classes, iou_threshold=0.5):
    """Find, per class, the confidence threshold that maximizes F1 on ``dataloader``.

    Detections in each image are visited in descending score order and greedily
    matched to the best-overlapping *unmatched* ground-truth box *of the same
    class*. A detection is a true positive when that IoU exceeds ``iou_threshold``.
    Per class, all detections are then ranked by score and the point on the
    cumulative precision/recall curve with the highest F1 is reported.

    Args:
        model: detector returning, in eval mode, a list of dicts with ``boxes``,
            ``scores`` and ``labels`` per image.
        dataloader: yields ``(images, targets)`` with ``boxes``/``labels`` targets.
        device: device the images and targets are moved to.
        num_classes: number of classes including background (id 0 is skipped).
        iou_threshold: IoU a match must exceed to count as a true positive.

    Returns:
        ``{class_id: {'p', 'r', 'f1', 'thres'}}`` for class ids ``1..num_classes-1``;
        all values are 0.0 for a class without any detections.
    """
    model.eval()
    class_preds = {cls_id: [] for cls_id in range(1, num_classes)}
    class_gt_counts = {cls_id: 0 for cls_id in range(1, num_classes)}

    with torch.no_grad():
        for images, targets in dataloader:
            images = [img.to(device) for img in images]
            outputs = model(images)

            for output, target in zip(outputs, targets):
                gt_boxes = target['boxes'].to(device)
                gt_labels = target['labels'].to(device)

                for cls_id in class_gt_counts:
                    class_gt_counts[cls_id] += (gt_labels == cls_id).sum().item()

                _match_image_predictions(output, gt_boxes, gt_labels, iou_threshold, class_preds)

    return {
        cls_id: _best_f1_point(class_preds[cls_id], class_gt_counts[cls_id])
        for cls_id in range(1, num_classes)
    }


def _match_image_predictions(output, gt_boxes, gt_labels, iou_threshold, class_preds):
    """Greedily match one image's detections to its GT and append ``(score, is_tp)`` per class."""
    scores = output['scores']
    if len(scores) == 0:
        return

    order = torch.argsort(scores, descending=True)
    pred_boxes = output['boxes'][order]
    pred_scores = scores[order]
    pred_labels = output['labels'][order]

    iou_matrix = torchvision.ops.box_iou(pred_boxes, gt_boxes) if len(gt_boxes) > 0 else None
    matched = torch.zeros(len(gt_boxes), dtype=torch.bool, device=gt_boxes.device)

    for p_idx in range(len(pred_boxes)):
        p_label = int(pred_labels[p_idx].item())
        # Background (0) or ids outside the evaluated range are ignored.
        if p_label not in class_preds:
            continue

        is_tp = False
        if iou_matrix is not None:
            # Only still-unmatched GT boxes of the predicted class are candidates, so a
            # correct detection is not rejected because another class's box (or an
            # already-claimed box) happens to overlap it more.
            eligible = (gt_labels == p_label) & ~matched
            if bool(eligible.any()):
                candidate_ious = iou_matrix[p_idx].masked_fill(~eligible, -1.0)
                best_iou, best_gt = candidate_ious.max(dim=0)
                if best_iou.item() > iou_threshold:
                    matched[best_gt] = True
                    is_tp = True

        class_preds[p_label].append((pred_scores[p_idx].item(), is_tp))


def _best_f1_point(preds, total_gt):
    """Return precision/recall/F1/threshold at the max-F1 point of one class's ranked detections."""
    if len(preds) == 0:
        return {'p': 0.0, 'r': 0.0, 'f1': 0.0, 'thres': 0.0}

    preds_np = np.array(sorted(preds, key=lambda x: x[0], reverse=True))
    scores = preds_np[:, 0]
    tp_status = preds_np[:, 1].astype(int)

    tp_cumsum = np.cumsum(tp_status)
    fp_cumsum = np.cumsum(1 - tp_status)

    precisions = tp_cumsum / (tp_cumsum + fp_cumsum + 1e-16)
    recalls = tp_cumsum / total_gt if total_gt > 0 else np.zeros_like(tp_cumsum)
    f1_scores = 2 * (precisions * recalls) / (precisions + recalls + 1e-16)

    best_idx = np.argmax(f1_scores)
    return {
        'p': precisions[best_idx],
        'r': recalls[best_idx],
        'f1': f1_scores[best_idx],
        'thres': scores[best_idx]
    }

def evaluate_and_print(model, dataloader, device, class_names, current_epoch, num_epochs):
    """Validate ``model``, print a per-class results table and return global metrics.

    Combines ``evaluate_map_complete`` (mAP@.50 and mAP@[.50:.95]) with
    ``evaluate_best_f1`` (precision/recall/F1 at each class's best threshold).
    Classes are listed in ``class_names`` order as model ids 1, 2, ...

    Returns:
        dict with ``'map50'``, ``'map'`` (global values) and ``'precision'``,
        ``'recall'`` (means of the per-class best-F1 precision/recall).
    """
    print(f"\nEvaluating Epoch {current_epoch}/{num_epochs}...")
    num_classes = len(class_names) + 1

    map_results = evaluate_map_complete(model, dataloader, device)
    best_f1_results = evaluate_best_f1(model, dataloader, device, num_classes)

    # torchmetrics reports per-class values only for classes present in predictions
    # or targets, ordered as in its 'classes' output; look values up by class id
    # instead of by position so a missing class cannot shift the others.
    classes = map_results.get('classes')
    map50_by_id = _per_class_by_id(map_results['map_50_per_class'], classes)
    map5095_by_id = _per_class_by_id(map_results['map_per_class'], classes)

    empty_f1 = {'p': 0.0, 'r': 0.0, 'f1': 0.0, 'thres': 0.0}
    csv_data = []
    total_p, total_r = 0, 0

    for cls_id, class_name in enumerate(class_names, start=1):
        res = best_f1_results.get(cls_id, empty_f1)
        precision, recall, f1, thres = res['p'], res['r'], res['f1'], res['thres']

        total_p += precision
        total_r += recall

        csv_data.append({
            "Class": class_name,
            "mAP 50": round(map50_by_id.get(cls_id, 0.0), 4),
            "mAP 50-95": round(map5095_by_id.get(cls_id, 0.0), 4),
            "Best F1": round(f1, 4),
            "Best Conf": round(thres, 3),
            "Precision": round(precision, 4),
            "Recall": round(recall, 4)
        })

    valid_classes = len(class_names)
    avg_p = total_p / valid_classes if valid_classes > 0 else 0.0
    avg_r = total_r / valid_classes if valid_classes > 0 else 0.0

    csv_data.append({
        "Class": "GLOBAL (ALL)",
        "mAP 50": round(map_results['map_50'], 4),
        "mAP 50-95": round(map_results['map'], 4),
        "Best F1": "-",
        "Best Conf": "-",
        "Precision": round(avg_p, 4),
        "Recall": round(avg_r, 4)
    })

    df_results = pd.DataFrame(csv_data)
    print(f"\nVALIDATION RESULTS (Epoch {current_epoch}/{num_epochs})")
    print("="*95)
    print(df_results.to_string(index=False))
    print("="*95)

    return {
        'map50': map_results['map_50'],
        'map': map_results['map'],
        'precision': avg_p,
        'recall': avg_r
    }


def _per_class_by_id(values, classes=None):
    """Return ``{class_id: value}`` for a per-class metric tensor.

    ``classes`` gives the class id of each entry of ``values``. When it is ``None``,
    entries are assumed to correspond to class ids 1, 2, ... in order.
    """
    flat_values = torch.atleast_1d(torch.as_tensor(values)).flatten().tolist()
    if classes is None:
        class_ids = range(1, len(flat_values) + 1)
    else:
        class_ids = torch.atleast_1d(torch.as_tensor(classes)).flatten().tolist()
    return {int(c): float(v) for c, v in zip(class_ids, flat_values)}

In [None]:
def train_one_epoch(model, optimizer, data_loader, device, epoch):
    """Run one optimisation pass over ``data_loader`` and return the mean batch loss.

    The model is switched to train mode and called as ``model(images, targets)``,
    which must return a dict of loss tensors. Their sum is back-propagated,
    gradients are clipped to a global norm of 10 to avoid explosions, and the
    optimizer steps.

    Args:
        model: detection model returning a loss dict in train mode.
        optimizer: optimizer over the model's parameters.
        data_loader: yields ``(images, targets)`` batches.
        device: device that images and every target tensor are moved to.
        epoch: not used by the loop; kept for call-site compatibility.

    Returns:
        Average summed loss per batch, or 0.0 for an empty loader.
    """
    model.train()
    running_loss = 0
    n_batches = 0

    for batch_images, batch_targets in data_loader:
        batch_images = [img.to(device) for img in batch_images]
        batch_targets = [
            {key: value.to(device) for key, value in tgt.items()}
            for tgt in batch_targets
        ]

        loss_dict = model(batch_images, batch_targets)
        total_loss = sum(loss_dict.values())

        optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
        optimizer.step()

        running_loss += total_loss.item()
        n_batches += 1

    return running_loss / max(n_batches, 1)

def run_experiment(condition_name, epochs=100, patience=15, seed=42):
    """Train and validate one Faster R-CNN configuration with early stopping.

    Builds train/val datasets from ``LOCAL_DATA_DIR``, trains with AdamW plus
    cosine annealing, validates every epoch, keeps the checkpoint with the best
    val mAP@50 and logs per-epoch metrics to CSV.

    Args:
        condition_name: experiment name. Its first whitespace-separated token selects
            the training transform (``'Tuned'`` enables augmentation). It is also the
            sub-directory of ``BASE_SAVE_DIR`` receiving ``best_model.pth`` and
            ``metrics.csv``.
        epochs: maximum number of epochs (also the cosine-annealing period).
        patience: stop after this many consecutive epochs without a higher val mAP@50.
        seed: RNG seed applied at the start of the experiment so that each condition
            is reproducible on its own, regardless of what ran before it in the kernel.

    Returns:
        ``(best_map50, total_minutes)``.
    """
    print(f"\nSTARTING EXPERIMENT: {condition_name} | Patience: {patience}")

    # Re-seed per experiment: otherwise a later condition starts from whatever RNG
    # state earlier runs left behind (head init, shuffling, augmentation), making the
    # Raw-vs-Tuned comparison depend on execution order.
    set_seed(seed)

    # Dataset Setup
    train_transform = get_transforms(condition=condition_name.split(' ')[0])
    val_transform = get_transforms(condition='Raw')

    train_ds = YOLODataset(LOCAL_DATA_DIR, split='train', transforms=train_transform)
    val_ds = YOLODataset(LOCAL_DATA_DIR, split='val', transforms=val_transform)

    train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, collate_fn=collate_fn, num_workers=2)
    val_loader = DataLoader(val_ds, batch_size=4, shuffle=False, collate_fn=collate_fn, num_workers=2)

    # Model Setup
    model = get_model_mobilenetv3_small(NUM_CLASSES)
    model.to(DEVICE)

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(params, lr=0.0001, weight_decay=0.0005)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-6)

    # Save Setup
    save_path = os.path.join(BASE_SAVE_DIR, condition_name)
    os.makedirs(save_path, exist_ok=True)

    history = []
    best_map = 0.0
    epochs_no_improve = 0
    total_start_time = time.time()

    print(f"Saving models to: {save_path}")

    for epoch in range(epochs):
        epoch_start = time.time()

        # Training
        train_loss = train_one_epoch(model, optimizer, train_loader, DEVICE, epoch)
        lr_scheduler.step()

        current_lr = optimizer.param_groups[0]['lr']
        print(f"Epoch {epoch+1}/{epochs} | Loss: {train_loss:.4f} | LR: {current_lr:.6f}")

        # Evaluation
        eval_metrics = evaluate_and_print(model, val_loader, DEVICE, CLASS_NAMES, epoch+1, epochs)

        current_map50 = eval_metrics['map50']
        duration = time.time() - epoch_start

        # Checkpointing. The first epoch always saves, so best_model.pth exists even
        # if val mAP50 never rises above its initial value.
        if epoch == 0 or current_map50 > best_map:
            best_map = current_map50
            epochs_no_improve = 0
            torch.save(model.state_dict(), os.path.join(save_path, "best_model.pth"))
            status_msg = "Model Saved!"
        else:
            epochs_no_improve += 1
            status_msg = f"No improvement for {epochs_no_improve}/{patience} epochs"

        print(f"Summary: Loss: {train_loss:.4f} | Time: {duration:.1f}s | {status_msg}")

        # Logging (rewritten every epoch so the CSV survives an interrupted run)
        epoch_data = {
            'epoch': epoch+1,
            'train_loss': train_loss,
            'mAP50_Global': current_map50,
            'mAP50-95_Global': eval_metrics['map'],
            'Precision_BestF1_Avg': eval_metrics['precision'],
            'Recall_BestF1_Avg': eval_metrics['recall'],
            'Time': duration
        }
        history.append(epoch_data)
        pd.DataFrame(history).to_csv(os.path.join(save_path, "metrics.csv"), index=False)

        if epochs_no_improve >= patience:
            print(f"\nEarly stopping triggered. Best mAP50: {best_map:.4f}")
            break

    total_duration = (time.time() - total_start_time) / 60
    print(f"\nTraining Finished. Best mAP50: {best_map:.4f} | Total Time: {total_duration:.1f}m")
    return best_map, total_duration

In [None]:
def plot_training_results(log_csv_path, title_suffix=""):
    """Plot the per-epoch curves logged in a ``metrics.csv`` and save them as a PNG.

    Draws train loss, averaged best-F1 precision/recall, mAP@50 and mAP@50-95 on
    a 2x3 grid (sixth cell removed). Each series is shown raw (faint) and
    exponentially smoothed (bold). The image is written next to the CSV with the
    same stem and a ``.png`` extension, then displayed.

    Args:
        log_csv_path: path of the metrics CSV written by ``run_experiment``.
        title_suffix: text appended to the figure title.

    Returns:
        None. A missing or empty log only prints a message.
    """
    if not os.path.exists(log_csv_path):
        print(f"Log file not found: {log_csv_path}")
        return

    df = pd.read_csv(log_csv_path)
    if df.empty:
        # smooth_curve needs at least one point.
        print(f"Log file is empty: {log_csv_path}")
        return
    epochs = df['epoch']

    metrics_config = [
        ('train_loss', 'Train/Total Loss', '#1f77b4'),
        ('Precision_BestF1_Avg', 'Metrics/Precision (Avg)', '#ff7f0e'),
        ('Recall_BestF1_Avg', 'Metrics/Recall (Avg)', '#2ca02c'),
        ('mAP50_Global', 'Metrics/mAP 50', '#d62728'),
        ('mAP50-95_Global', 'Metrics/mAP 50-95', '#9467bd')
    ]

    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    fig.suptitle(f'Training Results: {title_suffix}', fontsize=16, fontweight='bold')
    axes = axes.flatten()

    def smooth_curve(scalars, weight=0.6):
        """Exponential moving average seeded with the first value."""
        last = scalars[0]
        smoothed = []
        for point in scalars:
            smoothed_val = last * weight + (1 - weight) * point
            smoothed.append(smoothed_val)
            last = smoothed_val
        return smoothed

    for i, (col_name, title, color) in enumerate(metrics_config):
        ax = axes[i]
        if col_name in df.columns:
            ax.plot(epochs, df[col_name], color=color, alpha=0.3, linewidth=1, label='Raw')
            smoothed_data = smooth_curve(df[col_name].values)
            ax.plot(epochs, smoothed_data, color=color, linewidth=2.5, label='Smooth')
            ax.set_title(title, fontsize=12, fontweight='bold')
            ax.set_xlabel('Epochs')
            ax.grid(True, linestyle='--', alpha=0.6)
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)

    fig.delaxes(axes[5])
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])

    # Swap only the extension: str.replace('.csv', ...) would also rewrite '.csv'
    # inside directory names, and would overwrite the log itself if the path did
    # not end in '.csv'.
    save_path = os.path.splitext(log_csv_path)[0] + '.png'
    fig.savefig(save_path, dpi=300)
    print(f"Graph saved to: {save_path}")
    plt.show()
    plt.close(fig)

if __name__ == "__main__":
    results = []

    # Baseline condition: resize only.
    map_raw, time_raw = run_experiment("Raw", epochs=100)
    results.append({'Condition': 'Raw', 'mAP': map_raw})

    # Augmented condition: "Tuned" geometric augmentation.
    map_tuned, time_tuned = run_experiment("Tuned", epochs=100)
    results.append({'Condition': 'Tuned', 'mAP': map_tuned})

    # Bar chart of the best validation mAP of each condition, value printed on each bar.
    df = pd.DataFrame(results)
    fig, ax = plt.subplots(figsize=(8, 6))
    bars = ax.bar(df['Condition'], df['mAP'], color=['gray', 'firebrick'])

    ax.set_title("Comparison: Faster R-CNN (MobileNetV3)\nRaw vs Augmented")
    ax.set_ylabel("Mean Average Precision (mAP)")
    ax.set_ylim(0, 1.1)
    ax.grid(axis='y', alpha=0.3)

    for bar in bars:
        bar_top = bar.get_height()
        ax.text(bar.get_x() + bar.get_width() / 2., bar_top + 0.01,
                f'{bar_top:.4f}', ha='center', va='bottom', fontsize=12, fontweight='bold')

    fig.savefig(os.path.join(BASE_SAVE_DIR, 'final_result_map.png'))
    plt.show()

    # Detailed per-epoch curves for each condition.
    for condition, tag in (('Raw', 'RAW'), ('Tuned', 'TUNED')):
        print(f"\nDisplaying {condition} Results:")
        plot_training_results(os.path.join(BASE_SAVE_DIR, condition, 'metrics.csv'),
                              title_suffix=f"Faster R-CNN ({tag})")