In [None]:
# %pip (not !pip) installs into the environment of the running kernel.
%pip install -q ultralytics supervision torch torchvision transformers

In [None]:
# Download the dataset folder from Google Drive and extract the archives.
# NOTE: `unzip --q` does NOT mean quiet: for unzip the second dash negates the
# option that follows. `-o` overwrites existing files without prompting, so
# re-running this cell does not hang on an interactive "replace?" question.
!gdown -q --folder 1n98ur5MdsKuRtXTHkeM4PEwklBkDaJe4
!unzip -oq /content/Data/observing.zip
!unzip -oq /content/Data/public_test.zip

In [None]:
# ============================================
# IMPORTS
# ============================================

import os
import json
import cv2
import yaml
import random
import numpy as np
import torch
import torch.nn as nn
from scipy.interpolate import interp1d
from concurrent.futures import ThreadPoolExecutor, as_completed
from ultralytics import YOLOWorld, YOLO


In [None]:

# ============================================
# CONFIGURATION (OPTIMIZED AND ENHANCED)
# ============================================

class Config:
    """Notebook-wide settings: dataset layout, extraction, model and curriculum.

    Every setting is a plain class attribute. The rest of the notebook reads them
    through the module-level ``config`` instance created right after the class.
    """

    # ---- Dataset layout ---------------------------------------------------
    DATASET_ROOT = "train"
    # The two paths below are evaluated once, when this class body runs.
    # Reassigning DATASET_ROOT later does not refresh them automatically.
    ANNOTATIONS_PATH = os.path.join(DATASET_ROOT, "annotations/annotations.json")
    SAMPLES_DIR = os.path.join(DATASET_ROOT, "samples")
    WORK_DIR = "enhanced_mixed_dataset_v2"  # where the generated YOLO dataset is written

    # ---- Frame extraction -------------------------------------------------
    TRAIN_RATIO = 0.8  # fraction of videos assigned to the train split
    IMG_EXT = "jpg"    # file extension of the extracted frame images
    FRAME_STEP = 1     # keep every FRAME_STEP-th frame of each video
    NUM_WORKERS = 8    # size of the extraction thread pool

    # ---- Model ------------------------------------------------------------
    MODEL_WEIGHTS = "yolo12s.pt"
    # Alternative checkpoint kept for experiments:
    # MODEL_WEIGHTS = "yolov8m-worldv2.pt"
    CLASS_NAMES = ["target"]
    PARAM_LIMIT = 50_000_000  # maximum allowed total parameter count

    # ---- Masked-frame augmentation ----------------------------------------
    ENABLE_MASKING = True
    NUM_AUGMENTATIONS_PER_PHASE = 2  # extraction passes per train video per phase

    # ---- Hard negative mining ---------------------------------------------
    BACKGROUND_FRAME_RATIO = 0.1  # chance of saving an unannotated train frame as background

    # ---- Curriculum (drives data generation; epochs are only a reference) -
    CURRICULUM = {
        f"phase{number}": {"epochs": epoch_range, "mask_ratio": ratio, "strategy": "random"}
        for number, (epoch_range, ratio) in enumerate(
            [((0, 2), 0.10), ((2, 5), 0.20), ((5, 10), 0.30)], start=1
        )
    }


config = Config()
print("Configuration loaded (v2: HNM + Augmentations)")


In [None]:

# ============================================
# MODEL OPTIMIZER CLASS
# ============================================

class ModelOptimizer:
    """Parameter counting, summary printing and global magnitude pruning.

    ``model`` may be a wrapper that exposes the real ``nn.Module`` as ``.model``
    (e.g. the ``YOLO`` object used in ``analyze_model``) or a plain ``nn.Module``.
    """

    def __init__(self, model):
        self.model = model

    def _torch_module(self):
        """Return the underlying nn.Module, unwrapping ``.model`` when present."""
        return self.model.model if hasattr(self.model, 'model') else self.model

    def count_parameters(self):
        """Count parameters of the underlying module.

        Returns:
            dict with 'total', 'trainable' (requires_grad) and 'non_trainable'.
        """
        module = self._torch_module()
        total = sum(p.numel() for p in module.parameters())
        trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
        return {'total': total, 'trainable': trainable, 'non_trainable': total - trainable}

    def print_model_summary(self, stage="", param_limit=None):
        """Print parameter counts and compare the total against a limit.

        Args:
            stage: label appended to the header line.
            param_limit: limit to compare against; defaults to ``config.PARAM_LIMIT``.
        Returns:
            The dict produced by ``count_parameters``.
        """
        limit = config.PARAM_LIMIT if param_limit is None else param_limit
        params = self.count_parameters()
        print(f"\n{'='*60}\nMODEL STATISTICS {stage}\n{'='*60}")
        print(f"   Total params: {params['total']:,}")
        print(f"   Trainable: {params['trainable']:,}")
        if params['total'] > limit:
            print(f"\n   EXCEEDED LIMIT: {params['total'] - limit:,} params over")
        else:
            print(f"\n   WITHIN LIMIT: {limit - params['total']:,} params remaining")
        print(f"{'='*60}\n")
        return params

    def apply_magnitude_pruning(self, prune_ratio=0.3):
        """Zero the globally smallest-magnitude Conv2d/Linear weights, in place.

        One threshold is computed across all Conv2d and Linear weight tensors, so
        roughly ``prune_ratio`` of those weights are set to zero. Biases and other
        layer types are untouched.

        Args:
            prune_ratio: fraction of weights to zero; values <= 0 prune nothing.
        Returns:
            ``self.model`` (always, including when nothing is pruned).
        """
        print(f"\nApplying {prune_ratio*100:.0f}% Magnitude Pruning...")
        weights = [
            module.weight
            for module in self._torch_module().modules()
            if isinstance(module, (nn.Conv2d, nn.Linear)) and getattr(module, 'weight', None) is not None
        ]
        if not weights or prune_ratio <= 0:
            return self.model

        # torch.quantile refuses inputs above ~16M elements, which real detection
        # models exceed; kthvalue has no such limit. Magnitudes go to CPU/float32
        # so tensors on mixed devices/dtypes can be concatenated.
        magnitudes = torch.cat([w.detach().abs().flatten().float().cpu() for w in weights])
        k = min(magnitudes.numel(), max(1, int(magnitudes.numel() * prune_ratio)))
        threshold = magnitudes.kthvalue(k).values.item()

        pruned_count = 0
        total_count = 0
        with torch.no_grad():
            for weight in weights:
                keep = weight.abs() > threshold
                weight.mul_(keep)
                pruned_count += int((~keep).sum().item())
                total_count += weight.numel()
        print(f"   Pruned {pruned_count:,} / {total_count:,} weights")
        return self.model

print("Model optimizer class loaded")


In [None]:

# ============================================
# 1. MASKING STRATEGIES
# ============================================
class FrameMaskingStrategy:
    """Choose annotated frames to hide so that their boxes can be interpolated.

    ``random_mask`` and ``span_mask`` return *positions* (0-based indices into the
    sorted list of annotated frames). ``keyframe_mask`` returns *frame numbers*
    (keys of ``frame_boxes``). In all cases the first and last annotated frames
    are never chosen, so interpolation has anchors on both sides.
    """

    @staticmethod
    def random_mask(total_frames, mask_ratio=0.3):
        """Return sorted random positions in [1, total_frames - 2].

        At least one position is chosen (when possible), at most total_frames - 2.
        Fewer than 3 frames yields an empty list.
        """
        if total_frames < 3:
            return []
        num_mask = max(1, int(total_frames * mask_ratio))
        num_mask = min(num_mask, total_frames - 2)
        if num_mask <= 0:
            return []
        try:
            return sorted(random.sample(range(1, total_frames - 1), num_mask))
        except ValueError:
            return []

    @staticmethod
    def span_mask(total_frames, mask_ratio=0.3):
        """Return sorted positions covered by random contiguous spans.

        Each span is ~10% of the frames (min 2) and stays inside
        [1, total_frames - 2]. Spans may overlap. Fewer than 5 frames falls back
        to ``random_mask``.
        """
        if total_frames < 5:
            return FrameMaskingStrategy.random_mask(total_frames, mask_ratio)
        span_length = max(2, int(total_frames * 0.1))
        num_spans = max(1, int(total_frames * mask_ratio / span_length))
        masked_indices = []
        for _ in range(num_spans):
            start = random.randint(1, max(2, total_frames - span_length - 1))
            masked_indices.extend(range(start, min(start + span_length, total_frames - 1)))
        return sorted(set(masked_indices))

    @staticmethod
    def keyframe_mask(frame_boxes, mask_ratio=0.3):
        """Return the frame numbers whose first box moved the most.

        Motion is the displacement of the first box's center relative to the
        previous annotated frame. Frames without boxes are skipped. If no motion
        can be measured, random frames are chosen instead. Either way the result
        is a sorted list of frame numbers (keys of ``frame_boxes``).
        """
        if len(frame_boxes) < 3:
            return []
        motion_scores = []
        frame_indices = sorted(frame_boxes.keys())
        for i in range(1, len(frame_indices) - 1):
            prev_frame, curr_frame = frame_indices[i - 1], frame_indices[i]
            if not frame_boxes.get(prev_frame) or not frame_boxes.get(curr_frame):
                continue
            try:
                prev_box, curr_box = frame_boxes[prev_frame][0], frame_boxes[curr_frame][0]
                prev_cx, prev_cy = (prev_box['x1'] + prev_box['x2']) / 2, (prev_box['y1'] + prev_box['y2']) / 2
                curr_cx, curr_cy = (curr_box['x1'] + curr_box['x2']) / 2, (curr_box['y1'] + curr_box['y2']) / 2
                motion = np.sqrt((curr_cx - prev_cx) ** 2 + (curr_cy - prev_cy) ** 2)
                motion_scores.append((curr_frame, motion))
            except (IndexError, KeyError):
                continue
        if not motion_scores:
            # random_mask yields *positions*; translate them to frame numbers so
            # this fallback returns the same kind of value as the main path.
            positions = FrameMaskingStrategy.random_mask(len(frame_indices), mask_ratio)
            return [frame_indices[p] for p in positions]
        motion_scores.sort(key=lambda x: x[1], reverse=True)
        num_mask = max(1, int(len(motion_scores) * mask_ratio))
        return sorted(frame for frame, _ in motion_scores[:num_mask])

print("Masking strategies loaded")


In [None]:

# ============================================
# 2. TEMPORAL INTERPOLATION
# ============================================
def interpolate_boxes(frame_boxes, masked_frames, method='cubic'):
    """Estimate boxes for masked frames from the surrounding visible annotations.

    The first box of every visible (unmasked, non-empty) frame is an anchor.
    Each coordinate is interpolated piecewise *linearly*. ``method`` is accepted
    for API compatibility but is currently ignored. Beyond the first/last anchor
    the nearest anchor value is held. Results are clipped to [0, 10000] and
    truncated to int.

    Args:
        frame_boxes: dict frame number -> list of box dicts ('x1','y1','x2','y2').
        masked_frames: frame numbers whose boxes should be estimated.
        method: unused (see above).
    Returns:
        (ground_truth, frames). ground_truth maps each masked frame present in
        ``frame_boxes`` to a one-element list holding the estimated box. frames
        is the list of visible frame numbers, or every frame number when fewer
        than two frames are visible. On any interpolation error the error is
        printed and ``({}, visible_frames)`` is returned.
    """
    frame_numbers = sorted(frame_boxes.keys())
    unmasked = [fr for fr in frame_numbers if fr not in masked_frames]
    if len(unmasked) < 2:
        return {}, frame_numbers

    anchors = [
        (fr, frame_boxes[fr][0])
        for fr in unmasked
        if frame_boxes[fr] and len(frame_boxes[fr]) > 0
    ]
    if len(anchors) < 2:
        return {}, unmasked

    coords = ('x1', 'y1', 'x2', 'y2')
    estimated = {}
    try:
        anchor_frames = [fr for fr, _ in anchors]
        series = {c: [box[c] for _, box in anchors] for c in coords}
        curves = {
            c: interp1d(anchor_frames, series[c], kind='linear', bounds_error=False,
                        fill_value=(series[c][0], series[c][-1]))
            for c in coords
        }
        for fr in masked_frames:
            if fr not in frame_boxes:
                continue
            box = {'frame': fr}
            for c in coords:
                box[c] = int(np.clip(curves[c](fr), 0, 10000))
            estimated[fr] = [box]
    except Exception as e:
        print(f"Interpolation failed: {e}")
        return {}, unmasked
    return estimated, unmasked

print("Temporal interpolation function loaded")


In [None]:

# ============================================
# 3. REMOVE DUPLICATE BOXES
# ============================================
def remove_duplicate_boxes(boxes, iou_threshold=0.95):
    """Drop near-identical boxes, keeping the larger box of each duplicate pair.

    Surviving boxes are compared pairwise in list order. When two of them have
    IoU strictly above ``iou_threshold``, the one with the smaller area is
    discarded (equal areas keep the earlier box). Lists of zero or one box are
    returned unchanged (same object).

    Args:
        boxes: list of dicts with 'x1', 'y1', 'x2', 'y2' keys.
        iou_threshold: IoU above which two boxes count as duplicates.
    Returns:
        A new list of the surviving boxes in their original order.
    """
    if len(boxes) <= 1:
        return boxes

    def area(rect):
        left, top, right, bottom = rect
        return (right - left) * (bottom - top)

    def overlap_ratio(rect_a, rect_b):
        # Intersection over union; disjoint boxes score 0.0.
        left = max(rect_a[0], rect_b[0])
        top = max(rect_a[1], rect_b[1])
        right = min(rect_a[2], rect_b[2])
        bottom = min(rect_a[3], rect_b[3])
        if right < left or bottom < top:
            return 0.0
        shared = (right - left) * (bottom - top)
        union = area(rect_a) + area(rect_b) - shared
        return shared / union if union > 0 else 0.0

    rects = [(bb['x1'], bb['y1'], bb['x2'], bb['y2']) for bb in boxes]
    alive = [True] * len(rects)
    for first in range(len(rects)):
        if not alive[first]:
            continue
        for second in range(first + 1, len(rects)):
            if not alive[second]:
                continue
            if not overlap_ratio(rects[first], rects[second]) > iou_threshold:
                continue
            if area(rects[first]) >= area(rects[second]):
                alive[second] = False
            else:
                # The current box loses; stop comparing it against later boxes.
                alive[first] = False
                break
    return [box for box, survived in zip(boxes, alive) if survived]

print("Duplicate removal function loaded")


In [None]:

# ============================================
# 4. ENHANCED FRAME EXTRACTION (UPDATED WITH HNM)
# ============================================

def _build_bbox_dict(video_id, ann_dict):
    """Group a video's annotated boxes by frame number and drop near-duplicates.

    Returns (bbox_dict, duplicates_removed). bbox_dict maps frame number -> list
    of box dicts, as stored under annotations[*]['bboxes'] in ``ann_dict``.
    """
    bbox_dict = {}
    for interval in ann_dict.get(video_id, {}).get("annotations", []):
        for bb in interval.get("bboxes", []):
            bbox_dict.setdefault(bb["frame"], []).append(bb)
    dup_removed = 0
    for frame_idx, boxes in bbox_dict.items():
        kept = remove_duplicate_boxes(boxes, iou_threshold=0.95)
        dup_removed += len(boxes) - len(kept)
        bbox_dict[frame_idx] = kept
    return bbox_dict, dup_removed


def _choose_masked_frames(bbox_dict, mask_strategy, mask_ratio):
    """Pick annotated frames to mask and interpolate replacement boxes for them.

    Returns (masked_frames, ground_truth): masked frame numbers and a dict
    frame -> interpolated boxes. Both are empty when masking is disabled,
    mask_ratio <= 0, or the video has 4 or fewer annotated frames.
    """
    if not (config.ENABLE_MASKING and mask_ratio > 0 and len(bbox_dict) > 4):
        return [], {}
    frame_indices = sorted(bbox_dict.keys())
    if mask_strategy == 'keyframe':
        # keyframe_mask already returns frame numbers.
        masked_frames = FrameMaskingStrategy.keyframe_mask(bbox_dict, mask_ratio)
    else:
        if mask_strategy == 'random':
            positions = FrameMaskingStrategy.random_mask(len(frame_indices), mask_ratio)
        elif mask_strategy == 'span':
            positions = FrameMaskingStrategy.span_mask(len(frame_indices), mask_ratio)
        else:
            positions = []
        # random/span return positions in the sorted frame list -> map to frames.
        masked_frames = [frame_indices[i] for i in positions if i < len(frame_indices)]
    ground_truth = {}
    if masked_frames:
        ground_truth, _ = interpolate_boxes(bbox_dict, masked_frames, method='cubic')
    return masked_frames, ground_truth


def _yolo_label_lines(boxes, img_w, img_h):
    """Convert pixel xyxy boxes to YOLO label lines '0 cx cy w h' (normalised).

    Boxes with non-positive width/height, or that leave the image after clipping
    to its bounds, are dropped.
    """
    lines = []
    for bb in boxes:
        x1, y1, x2, y2 = bb["x1"], bb["y1"], bb["x2"], bb["y2"]
        if x2 <= x1 or y2 <= y1:
            continue
        x1, y1, x2, y2 = max(0, x1), max(0, y1), min(img_w, x2), min(img_h, y2)
        cx, cy = (x1 + x2) / 2 / img_w, (y1 + y2) / 2 / img_h
        bw, bh = (x2 - x1) / img_w, (y2 - y1) / img_h
        if not (0 <= cx <= 1 and 0 <= cy <= 1 and 0 < bw <= 1 and 0 < bh <= 1):
            continue
        lines.append(f"0 {cx:.6f} {cy:.6f} {bw:.6f} {bh:.6f}")
    return lines


def extract_frames_with_masking(
    video_id,
    ann_dict,
    mode="train",
    augmentation_id=0,
    mask_strategy='random',
    mask_ratio=0.3
):
    """Extract frames and YOLO labels for one video, with masking and HNM.

    Every ``config.FRAME_STEP``-th frame is considered:
      * annotated frames are written with their labels. Masked frames use the
        interpolated boxes instead of the original ones. A frame is skipped if
        none of its boxes is valid, so it is never saved as a false background.
      * unannotated frames are, for ``mode == "train"`` only, kept as background
        (empty label file) with probability ``config.BACKGROUND_FRAME_RATIO``.

    Args:
        video_id: folder name under ``config.SAMPLES_DIR``.
        ann_dict: video_id -> annotation record.
        mode: "train" writes to WORK_DIR/train; anything else to WORK_DIR/val.
        augmentation_id: included in output file names to keep passes distinct.
        mask_strategy: 'random', 'span' or 'keyframe' (anything else: no mask).
        mask_ratio: fraction of annotated frames to mask; <= 0 disables masking.
    Returns:
        dict with 'status' ('success', 'missing_video' or 'cannot_open') and
        'video_id'. On success it also has 'aug_id', 'frames_saved',
        'masked_frames' (masked frames actually saved) and 'duplicates_removed'.
    """
    video_dir = os.path.join(config.SAMPLES_DIR, video_id)
    video_path = os.path.join(video_dir, "drone_video.mp4")
    if not os.path.exists(video_path):
        return {'status': 'missing_video', 'video_id': video_id}

    # Labels and masks are prepared before the capture is opened, so an exception
    # here cannot leak an open VideoCapture.
    bbox_dict, dup_removed = _build_bbox_dict(video_id, ann_dict)
    masked_frames, ground_truth = _choose_masked_frames(bbox_dict, mask_strategy, mask_ratio)
    masked_set = set(masked_frames)

    split = 'train' if mode == "train" else 'val'
    img_out = os.path.join(config.WORK_DIR, split, 'images')
    lbl_out = os.path.join(config.WORK_DIR, split, 'labels')

    saved = 0
    masked_count = 0
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            return {'status': 'cannot_open', 'video_id': video_id}

        idx = -1
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # Advance the counter first: a `continue` below must never skip it,
            # otherwise every later frame gets the wrong index/labels.
            idx += 1
            if idx % config.FRAME_STEP != 0:
                continue

            stem = f"{video_id}_aug{augmentation_id:04d}_frame_{idx:06d}"
            img_path = os.path.join(img_out, f"{stem}.{config.IMG_EXT}")
            txt_path = os.path.join(lbl_out, f"{stem}.txt")

            if idx in bbox_dict:
                # CASE 1: positive frame (annotated, possibly masked).
                is_masked = idx in masked_set and idx in ground_truth
                boxes_to_save = ground_truth[idx] if is_masked else bbox_dict[idx]
                # Normalise with the real frame size, not container metadata.
                img_h, img_w = frame.shape[:2]
                lines = _yolo_label_lines(boxes_to_save, img_w, img_h)
                if not lines:
                    continue
                if not cv2.imwrite(img_path, frame):
                    continue
                with open(txt_path, "w") as f:
                    f.write("\n".join(lines))
                saved += 1
                if is_masked:
                    masked_count += 1

            elif mode == "train" and random.random() < config.BACKGROUND_FRAME_RATIO:
                # CASE 2: hard negative mining - background frame, empty label.
                if cv2.imwrite(img_path, frame):
                    with open(txt_path, "w") as f:
                        f.write("")
                    saved += 1
    finally:
        cap.release()

    return {
        'status': 'success',
        'video_id': video_id,
        'aug_id': augmentation_id,
        'frames_saved': saved,
        'masked_frames': masked_count,
        'duplicates_removed': dup_removed
    }

print("Frame extraction function loaded (V2: HNM enabled)")


In [None]:

# ============================================
# 5. CURRICULUM LEARNING CONTROLLER
# ============================================

class CurriculumController:
    """Map epoch numbers to curriculum phases (mask ratio and strategy).

    ``curriculum_config`` maps phase name -> dict with 'epochs' (a half-open
    ``(start, end)`` pair), 'mask_ratio' and 'strategy'. ``self.phases`` holds
    the ``(name, config)`` pairs ordered by start epoch.
    """

    def __init__(self, curriculum_config):
        self.curriculum = curriculum_config

        def start_epoch(item):
            return item[1]['epochs'][0]

        self.phases = sorted(curriculum_config.items(), key=start_epoch)

    def get_phase(self, epoch):
        """Return ``(name, config)`` of the phase whose range contains ``epoch``.

        Epochs outside every range fall back to the last phase. With no phases
        at all, ``(None, None)`` is returned.
        """
        for name, cfg in self.phases:
            first, last = cfg['epochs']
            if first <= epoch < last:
                return name, cfg
        if not self.phases:
            return None, None
        last_name, last_cfg = self.phases[-1]
        return last_name, last_cfg

    def get_config(self, epoch):
        """Return ``{'phase', 'mask_ratio', 'strategy'}`` for ``epoch``.

        Falls back to phase 'unknown' with no masking when there are no phases.
        """
        name, cfg = self.get_phase(epoch)
        if cfg is None:
            return {'phase': 'unknown', 'mask_ratio': 0.0, 'strategy': 'random'}
        return dict(phase=name, mask_ratio=cfg['mask_ratio'], strategy=cfg['strategy'])

    def print_schedule(self):
        """Print every phase plus the global augmentation / HNM settings."""
        banner = "=" * 60
        print("\n" + banner + "\nCURRICULUM LEARNING SCHEDULE (FOR DATA GENERATION)\n" + banner)
        for name, cfg in self.phases:
            first, last = cfg['epochs']
            print(f"\n{name.upper()}: (Reference for Epochs {first}-{last})")
            print(f"  Mask Ratio: {cfg['mask_ratio']*100:.0f}%")
            print(f"  Strategy: {cfg['strategy']}")
            print(f"  Augmentations: {config.NUM_AUGMENTATIONS_PER_PHASE} versions")
        print(f"\n  Background HNM Ratio: {config.BACKGROUND_FRAME_RATIO*100:.0f}% (for all phases)")
        print(banner + "\n")

print("Curriculum controller loaded")

# ============================================
# MODEL ANALYSIS UTILITY
# ============================================
def analyze_model(model_path):
    """Print a parameter breakdown by layer type and check it against PARAM_LIMIT.

    Loads ``model_path`` with ``YOLO``. Parameters of leaf modules (modules with
    no children) are grouped by class name, and the ten largest groups are
    printed with their share of the total. The total is then compared with
    ``config.PARAM_LIMIT``. Prints only; returns None.
    """
    rule = "=" * 60
    print("\n" + rule + "\nMODEL ANALYSIS\n" + rule)
    model = YOLO(model_path)
    total = ModelOptimizer(model).count_parameters()['total']
    print("\nLayer-wise Parameter Count:")
    print("-" * 60)

    by_type = {}
    for _, layer in model.model.named_modules():
        if any(True for _ in layer.children()):
            continue  # only leaf modules, so nothing is counted twice
        n_params = sum(p.numel() for p in layer.parameters())
        if n_params <= 0:
            continue
        bucket = by_type.setdefault(type(layer).__name__, {'count': 0, 'params': 0})
        bucket['count'] += 1
        bucket['params'] += n_params

    ranked = sorted(by_type.items(), key=lambda kv: kv[1]['params'], reverse=True)
    for type_name, stats in ranked[:10]:
        print(f"   {type_name:20s}: {stats['count']:3d} layers, {stats['params']:12,d} params "
              f"({stats['params']/total*100:5.2f}%)")
    print("-" * 60 + f"\n   {'TOTAL':20s}: {total:,} params\n")

    print("Parameter Limit Check:")
    if total > config.PARAM_LIMIT:
        print(f"   Exceeds limit: {total:,} / {config.PARAM_LIMIT:,}")
    else:
        print(f"   Within limit: {total:,} / {config.PARAM_LIMIT:,}")
    print("=" * 60)

print("Model analysis function loaded")


In [None]:

# ============================================
# 6. MAIN PIPELINE (UPDATED WITH AUGMENTATIONS)
# ============================================

def _load_and_split_videos():
    """Load the annotation JSON and split its video ids into train / val.

    The id list is shuffled after ``random.seed(42)`` so the split is repeatable.
    Returns ``(ann_dict, train_videos, val_videos)``; ann_dict maps
    video_id -> its annotation record.
    """
    with open(config.ANNOTATIONS_PATH, "r") as f:
        annotations = json.load(f)
    video_ids = [a["video_id"] for a in annotations]
    ann_dict = {a["video_id"]: a for a in annotations}
    random.seed(42)
    random.shuffle(video_ids)
    split_idx = int(len(video_ids) * config.TRAIN_RATIO)
    train_videos, val_videos = video_ids[:split_idx], video_ids[split_idx:]
    print(f"Loaded {len(video_ids)} videos (Train: {len(train_videos)}, Val: {len(val_videos)})")
    return ann_dict, train_videos, val_videos


def _generate_dataset(curriculum, ann_dict, train_videos, val_videos):
    """Run frame extraction for all train augmentations and all val videos.

    Each train video is extracted ``NUM_AUGMENTATIONS_PER_PHASE`` times for every
    curriculum phase, each pass with a unique augmentation id. Each val video is
    extracted once with masking disabled. A job that raises is reported and
    recorded with an 'error: ...' status instead of aborting the whole build.

    Returns ``(stats_list, num_train_augmentations)``.
    """
    print("\nPreparing all data augmentations in one go (HNM enabled)...")
    stats_list = []
    aug_id = 0
    with ThreadPoolExecutor(max_workers=config.NUM_WORKERS) as ex:
        future_to_vid = {}
        # 1. Train data: one job per (phase, video, augmentation pass).
        for phase_name, phase_config in curriculum.phases:
            print(f"  -> Submitting jobs for Phase: {phase_name.upper()}")
            mask_ratio, mask_strategy = phase_config['mask_ratio'], phase_config['strategy']
            for vid in train_videos:
                for _ in range(config.NUM_AUGMENTATIONS_PER_PHASE):
                    fut = ex.submit(
                        extract_frames_with_masking,
                        vid, ann_dict, "train", aug_id, mask_strategy, mask_ratio
                    )
                    future_to_vid[fut] = vid
                    aug_id += 1
        # 2. Val data: mask_ratio 0.0 disables masking; HNM only runs in "train" mode.
        print("\n  -> Submitting jobs for VALIDATION data (No masking, No HNM)")
        for vid in val_videos:
            fut = ex.submit(extract_frames_with_masking, vid, ann_dict, "val", 0, 'random', 0.0)
            future_to_vid[fut] = vid
        # 3. Collect results as jobs finish.
        n_jobs = len(future_to_vid)
        print(f"\nWaiting for {n_jobs} total jobs to complete...")
        for i, fut in enumerate(as_completed(future_to_vid), 1):
            try:
                result = fut.result()
            except Exception as e:
                result = {'status': f'error: {e!r}', 'video_id': future_to_vid[fut]}
                print(f"[{i}/{n_jobs}] {result['video_id']}: {result['status']}")
                stats_list.append(result)
                continue
            stats_list.append(result)
            if i % 100 == 0 or i == n_jobs:
                if result['status'] == 'success':
                    msg = (f"[{i}/{n_jobs}] {result['video_id']}_aug{result['aug_id']:04d}: "
                           f"{result['frames_saved']} frames")
                    if result['masked_frames'] > 0:
                        msg += f" ({result['masked_frames']} masked)"
                    print(msg)
                else:
                    print(f"[{i}/{n_jobs}] {result['video_id']}: {result['status']}")
    return stats_list, aug_id


def _write_data_yaml():
    """Write the Ultralytics dataset YAML into WORK_DIR and return its path."""
    data_yaml = {
        "train": os.path.abspath(os.path.join(config.WORK_DIR, 'train', 'images')),
        "val": os.path.abspath(os.path.join(config.WORK_DIR, 'val', 'images')),
        "nc": len(config.CLASS_NAMES),
        "names": config.CLASS_NAMES
    }
    data_path = os.path.join(config.WORK_DIR, "data.yaml")
    with open(data_path, "w") as f:
        yaml.dump(data_yaml, f)
    print(f"\ndata.yaml created at: {data_path}")
    return data_path


def main_pipeline():
    """Build the masked/HNM dataset, write data.yaml and fine-tune the detector.

    NOTE: existing files in WORK_DIR are not cleared (directories are created
    with exist_ok=True), so re-runs add to previously generated frames.

    Returns:
        Path of ``best.pt`` inside the Ultralytics run directory.
    Raises:
        RuntimeError: if no extraction job succeeded.
    """
    print("ENHANCED YOLO WORLD TRAINING PIPELINE (V2: HNM + AUG)")
    print("="*60)

    ann_dict, train_videos, val_videos = _load_and_split_videos()

    curriculum = CurriculumController(config.CURRICULUM)
    curriculum.print_schedule()

    os.makedirs(config.WORK_DIR, exist_ok=True)
    for subdir in ['train/images', 'train/labels', 'val/images', 'val/labels']:
        os.makedirs(os.path.join(config.WORK_DIR, subdir), exist_ok=True)

    stats_list, num_train_augs = _generate_dataset(curriculum, ann_dict, train_videos, val_videos)

    success_stats = [s for s in stats_list if s['status'] == 'success']
    total_frames = sum(s['frames_saved'] for s in success_stats)
    total_masked = sum(s['masked_frames'] for s in success_stats)
    print(f"\n{'='*60}\nTOTAL DATASET STATISTICS (ALL PHASES MIXED):\n"
          f"   Total frames saved: {total_frames}\n"
          f"   Total masked frames: {total_masked} ({total_masked/max(1,total_frames)*100:.1f}%)\n"
          f"   Total train augmentations: {num_train_augs}\n{'='*60}")
    if not success_stats:
        raise RuntimeError("No extraction job succeeded; check config.SAMPLES_DIR and the annotations.")

    data_path = _write_data_yaml()

    print("\n" + "="*60 + "\nSTARTING TRAINING (V2)\n" + "="*60)
    model = YOLO(config.MODEL_WEIGHTS)

    monitor = CurriculumController(config.CURRICULUM)

    def on_train_epoch_end(trainer):
        # Log which curriculum phase (used for data generation) this epoch maps to.
        epoch = trainer.epoch + 1
        phase_config = monitor.get_config(epoch)
        if epoch == 1 or epoch % 5 == 0 or epoch == trainer.epochs:
            print(f"\n-- Epoch {epoch}/{trainer.epochs} -- "
                  f"Curriculum Phase (Reference): {phase_config['phase']} --")

    # Ultralytics only runs callbacks for its own event names. "on_epoch_end" is
    # not one of them (it would be silently ignored), so use "on_train_epoch_end".
    model.add_callback("on_train_epoch_end", on_train_epoch_end)
    model.train(
        data=data_path,
        epochs=15,
        imgsz=896,              # keep in sync with the imgsz used at inference
        batch=36,

        freeze=0,               # 0 = no layers frozen; the whole network is fine-tuned

        lr0=0.01,               # initial LR (fairly high for AdamW)
        lrf=0.1,                # final LR = lr0 * lrf
        optimizer="AdamW",
        warmup_epochs=1.0,

        # NOTE(review): Ultralytics documents `erasing` for classification
        # training; likely a no-op for detection. Frames are already masked here.
        erasing=0.0,
        dropout=0.0,
        mosaic=0.1,             # mosaic kept at a low probability (not disabled)
        mixup=0.1,              # mixup kept at a low probability (not disabled)
        hsv_h=0.03,             # hue jitter
        hsv_s=0.3,              # saturation jitter
        hsv_v=0.3,              # brightness (value) jitter
        degrees=0.1,            # very small rotation range
        translate=0.0,          # translation disabled
        scale=0.5,              # scale jitter gain

        patience=5,
        save_period=5,
        workers=8,
        close_mosaic=10,        # mosaic switched off for the last 10 epochs
        project=os.path.join(config.WORK_DIR, "runs"),
        name="fast_finetune_yolo",
        exist_ok=True,
        verbose=True
    )

    print("\n" + "="*60 + "\nTRAINING COMPLETE!\n" + "="*60)
    best_model_path = os.path.join(model.trainer.save_dir, 'weights', 'best.pt')
    print(f"Best model: {best_model_path}")

    return best_model_path

print("Main pipeline function loaded (V2: HSV Augs)")


In [None]:

# ============================================
# 7. INFERENCE
# ============================================

def run_inference(model_path, test_root, output_json, conf=0.35, iou=0.55, imgsz=896, device=None):
    """Run the detector on every ``<test_root>/<video_id>/drone_video.mp4``.

    Args:
        model_path: weights file loaded with ``YOLO``.
        test_root: directory whose sub-directories are video ids.
        output_json: path of the submission JSON to write.
        conf: confidence threshold passed to ``predict``.
        iou: NMS IoU threshold passed to ``predict``.
        imgsz: inference image size; should match the training ``imgsz``.
        device: torch device string; defaults to "cuda" when available, else "cpu".
    Returns:
        The submission list that was written. Each entry is
        ``{"video_id", "detections"}``, with detections ``[{"bboxes": [...]}]``
        or ``[]`` when nothing was detected. Videos without drone_video.mp4 are
        reported and left out.
    """
    print("\n" + "="*60 + "\nRUNNING INFERENCE\n" + "="*60)
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    model = YOLO(model_path)
    video_dirs = sorted(d for d in os.listdir(test_root) if os.path.isdir(os.path.join(test_root, d)))
    submission = []
    for i, vid in enumerate(video_dirs, 1):
        video_path = os.path.join(test_root, vid, "drone_video.mp4")
        if not os.path.exists(video_path):
            print(f"[{i}/{len(video_dirs)}] {vid}: drone_video.mp4 not found, skipped")
            continue
        bboxes_per_frame = []
        results = model.predict(
            source=video_path, conf=conf, iou=iou, imgsz=imgsz,
            stream=True, verbose=False, device=device
        )
        # One result per decoded frame, in order, so enumerate gives the frame index.
        for frame_idx, r in enumerate(results):
            if r.boxes is None or len(r.boxes) == 0:
                continue
            for (x1, y1, x2, y2) in r.boxes.xyxy.cpu().numpy():
                bboxes_per_frame.append({
                    "frame": frame_idx,
                    "x1": int(x1),
                    "y1": int(y1),
                    "x2": int(x2),
                    "y2": int(y2)
                })
        submission.append({
            "video_id": vid,
            "detections": [{"bboxes": bboxes_per_frame}] if bboxes_per_frame else []
        })
        print(f"[{i}/{len(video_dirs)}] {vid}: {len(bboxes_per_frame)} detections")
    with open(output_json, "w") as f:
        json.dump(submission, f, indent=2, ensure_ascii=False)
    print(f"\nSaved submission to: {output_json}")
    return submission

print("Inference function loaded")


In [None]:
# ============================================
# 8. MAIN EXECUTION
# ============================================

if __name__ == "__main__":
    print("""
    YOLO WORLD + MASKED + CURRICULUM (V2: HNM + AUG)
    Enhanced Pipeline for Video Object Detection
    """)
    # Report the compute device.
    print(f"Device: {'CUDA' if torch.cuda.is_available() else 'CPU'}")
    if torch.cuda.is_available():
        print(f"    GPU: {torch.cuda.get_device_name(0)}")

    # Configure paths. ANNOTATIONS_PATH / SAMPLES_DIR were computed from
    # DATASET_ROOT when Config was defined. Recompute them here so that changing
    # DATASET_ROOT actually takes effect.
    config.DATASET_ROOT = "train"
    config.ANNOTATIONS_PATH = os.path.join(config.DATASET_ROOT, "annotations/annotations.json")
    config.SAMPLES_DIR = os.path.join(config.DATASET_ROOT, "samples")
    TEST_ROOT = "public_test/samples"
    OUTPUT_JSON = "submission_optimized_v2.json"

    try:
        best_model_path = main_pipeline()

        if best_model_path:
            analyze_model(best_model_path)

        print("\n" + "="*60)
        if os.path.exists(TEST_ROOT) and best_model_path:
            print("Auto-running inference on test set...")
            run_inference(
                model_path=best_model_path,
                test_root=TEST_ROOT,
                output_json=OUTPUT_JSON,
                conf=0.35  # keep conf=0.35 so results stay comparable with earlier runs
            )
        else:
            print("Skipping inference: Test set path not found or training failed.")

    except Exception as e:
        # Report with full traceback; the cell itself does not re-raise.
        print(f"\nAn error occurred in the pipeline: {e}")
        import traceback
        traceback.print_exc()