# Assignment 3: Option 2 - ELEC5304

### Install Key Dependencies

In [None]:
# 1.1 Install YOLOv8 and dependencies
!pip install -U ultralytics
!pip install -U sympy

Collecting ultralytics
  Downloading ultralytics-8.3.146-py3-none-any.whl.metadata (37 kB)
Collecting ultralytics-thop>=2.0.0 (from ultralytics)
  Downloading ultralytics_thop-2.0.14-py3-none-any.whl.metadata (9.4 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.8.0->ultralytics)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.8.0->ultralytics)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.8.0->ultralytics)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.8.0->ultralytics)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.8.0->ultralytics)
  Downloading n

In [None]:
# Mount Google Drive at /content/drive (Colab only); the dataset YAML used for
# training is read from /content/drive/MyDrive/Datasets.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


### Baseline Model - No Fourier Loss

In [None]:
from ultralytics import YOLO
import torch
import torch.nn as nn
import yaml
import os
import numpy as np
from pathlib import Path
import shutil

# Custom Callback for Model Saving
class PrecisionMAPCallback:
    """Snapshot the best checkpoint among epochs whose box precision clears a floor.

    An instance is meant to be registered as a trainer callback and invoked with the
    trainer once per epoch. When the epoch's box precision is at least
    ``precision_threshold`` and its mAP@50 beats every epoch snapshotted so far,
    ``<save_dir>/weights/last.pt`` is copied to ``<save_dir>/weights/achieved.pt``.
    """

    def __init__(self, precision_threshold=0.75):
        """
        Args:
            precision_threshold (float): Minimum box precision an epoch must reach
                before its weights are considered for ``achieved.pt``.
        """
        self.precision_threshold = precision_threshold
        # mAP@50 of the checkpoint currently stored as achieved.pt (0 until one exists).
        self.best_map50 = 0
        # True once any epoch has met the precision floor; used to announce it only once.
        self.precision_threshold_reached = False
        # String path of achieved.pt after the first successful copy, otherwise None.
        self.best_map50_path = None

    def __call__(self, trainer):
        """Inspect ``trainer.metrics`` and refresh ``achieved.pt`` when warranted.

        Args:
            trainer: Object exposing ``metrics`` (a mapping from metric name to value)
                and ``save_dir`` (the run directory that contains ``weights/``).
        """
        metrics = trainer.metrics

        # Metrics missing from this epoch are treated as 0.
        current_map50 = metrics.get('metrics/mAP50(B)', 0)
        current_precision = metrics.get('metrics/precision(B)', 0)

        if current_precision < self.precision_threshold:
            return

        if not self.precision_threshold_reached:
            print(f"\n🎯 Precision threshold reached! (Precision: {current_precision:.4f})")
            self.precision_threshold_reached = True

        if current_map50 <= self.best_map50:
            return

        weights_dir = Path(trainer.save_dir) / 'weights'
        last_path = weights_dir / 'last.pt'
        achieved_path = weights_dir / 'achieved.pt'

        if not last_path.exists():
            # Nothing to copy. best_map50 is deliberately left unchanged so that a
            # later qualifying epoch can still produce achieved.pt.
            return

        shutil.copy(str(last_path), str(achieved_path))

        # Record the new best only after the copy has succeeded, so best_map50
        # always describes the file that is actually on disk.
        self.best_map50 = current_map50
        self.best_map50_path = str(achieved_path)
        print(f"Saved achieved.pt: Precision={current_precision:.4f} (≥{self.precision_threshold}), "
              f"mAP@50={current_map50:.4f} (best so far)")

# Custom Architecture Configuration
def create_deep_multiscale_config():
    """Write the deep multi-scale YOLOv8 model definition to a YAML file.

    Each layer row is ``[from, repeats, module, args]``. The trailing comments give
    the layer index (the value that ``from`` entries in the head refer to) and the
    feature-pyramid level the layer operates at.

    Returns:
        str: Path of the written file, ``'deep_multiscale_yolo.yaml'`` in the
        current working directory.
    """
    # NOTE: every row is written out as its own list literal on purpose. Re-using a
    # single list object for identical rows would make yaml.dump emit anchors/aliases.
    backbone_layers = [
        # Stage 1: high-resolution features (small lesions)
        [-1, 1, 'Conv', [48, 3, 2]],     # 0  P1/2
        [-1, 2, 'Conv', [48, 3, 1]],     # 1  extra convolutions

        # Stage 2
        [-1, 1, 'Conv', [96, 3, 2]],     # 2  P2/4
        [-1, 3, 'C2f', [96, True]],      # 3

        # Stage 3: small/medium lesions
        [-1, 1, 'Conv', [192, 3, 2]],    # 4  P3/8
        [-1, 4, 'C2f', [192, True]],     # 5
        [-1, 1, 'Conv', [192, 3, 1]],    # 6  refinement (P3 skip source)

        # Stage 4: medium lesions
        [-1, 1, 'Conv', [384, 3, 2]],    # 7  P4/16
        [-1, 6, 'C2f', [384, True]],     # 8
        [-1, 1, 'Conv', [384, 3, 1]],    # 9  refinement (P4 skip source)

        # Stage 5: large lesions
        [-1, 1, 'Conv', [576, 3, 2]],    # 10 P5/32
        [-1, 3, 'C2f', [576, True]],     # 11
        [-1, 1, 'SPPF', [576, 5]],       # 12 multi-scale pooling
        [-1, 1, 'Conv', [576, 1, 1]],    # 13 channel adjustment
    ]

    head_layers = [
        # Top-down (FPN) path
        [-1, 1, 'Conv', [384, 1, 1]],                    # 14
        [-1, 1, 'nn.Upsample', [None, 2, 'nearest']],    # 15
        [[-1, 9], 1, 'Concat', [1]],                     # 16 merge with P4
        [-1, 3, 'C2f', [384, False]],                    # 17

        [-1, 1, 'Conv', [192, 1, 1]],                    # 18
        [-1, 1, 'nn.Upsample', [None, 2, 'nearest']],    # 19
        [[-1, 6], 1, 'Concat', [1]],                     # 20 merge with P3
        [-1, 3, 'C2f', [192, False]],                    # 21 P3 output

        # Bottom-up (PAN) path
        [-1, 1, 'Conv', [192, 3, 2]],                    # 22
        [[-1, 18], 1, 'Concat', [1]],                    # 23
        [-1, 3, 'C2f', [384, False]],                    # 24 P4 output

        [-1, 1, 'Conv', [384, 3, 2]],                    # 25
        [[-1, 14], 1, 'Concat', [1]],                    # 26
        [-1, 3, 'C2f', [576, False]],                    # 27 P5 output

        # Detection over the three output scales
        [[21, 24, 27], 1, 'Detect', ['nc']],             # 28
    ]

    model_spec = {
        'nc': 7,                   # number of classes
        'depth_multiple': 0.67,    # scales the per-layer repeat counts
        'width_multiple': 0.75,    # scales the per-layer channel counts
        'backbone': backbone_layers,
        'head': head_layers,
    }

    config_path = 'deep_multiscale_yolo.yaml'
    with open(config_path, 'w') as stream:
        yaml.dump(model_spec, stream)

    return config_path


def train_deep_multiscale(data_path, epochs=200):
    """Build the deep multi-scale detector from its YAML definition and train it.

    Args:
        data_path: Dataset YAML handed to the trainer as ``data``.
        epochs (int): Number of training epochs.

    Returns:
        tuple: ``(model, results)`` - the ``YOLO`` wrapper that was trained and the
        value returned by ``model.train``.
    """
    detector = YOLO(create_deep_multiscale_config())

    # After every fit epoch, copy last.pt to achieved.pt once precision >= 0.75 and
    # mAP@50 improves (see PrecisionMAPCallback).
    detector.add_callback("on_fit_epoch_end", PrecisionMAPCallback(precision_threshold=0.75))

    # Learning-rate schedule tuned for a small dataset.
    schedule = {
        'lr0': 0.002,
        'lrf': 0.00001,          # very low final LR
        'momentum': 0.937,
        'weight_decay': 0.002,   # higher weight decay for regularization
        'warmup_epochs': 20,
        'warmup_momentum': 0.5,
        'warmup_bias_lr': 0.01,
    }

    # Loss-component weights: localization is weighted far above classification.
    loss_gains = {'box': 12.0, 'cls': 1.5, 'dfl': 2.5}

    # Augmentation strengths/probabilities.
    augmentation = {
        'hsv_h': 0.015,
        'hsv_s': 0.6,
        'hsv_v': 0.4,
        'degrees': 10,
        'translate': 0.15,
        'scale': 0.7,            # wide scale jitter for multi-scale targets
        'shear': 2.0,
        'perspective': 0.0001,
        'fliplr': 0.5,
        'flipud': 0.1,           # occasional vertical flip for leaves
        'mosaic': 0.9,
        'mixup': 0.15,
        'copy_paste': 0.3,
    }

    # Optimizer and stopping behaviour.
    # NOTE(review): Ultralytics documents `dropout` as applying to classification
    # training - confirm it has any effect for this detection run.
    optimization = {
        'optimizer': 'AdamW',
        'patience': 150,
        'close_mosaic': 100,
        'dropout': 0.1,
    }

    # Input size, checkpoint cadence and the detection settings passed to the trainer.
    io_and_detection = {
        'imgsz': 640,
        'rect': False,
        'save_period': 20,       # checkpoint every 20 epochs
        'conf': 0.001,
        'iou': 0.5,
        'max_det': 300,
    }

    hardware = {'device': 0, 'amp': True, 'workers': 8}

    # Run bookkeeping; training starts from scratch (pretrained=False).
    run = {
        'project': 'runs/deep_multiscale',
        'name': 'deep_disease_detector',
        'exist_ok': True,
        'pretrained': False,
        'plots': True,
        'save': True,
        'cache': True,           # cache images for faster epochs
    }

    train_kwargs = {
        'data': data_path,
        'epochs': epochs,
        'batch': 8,              # small batch for the deeper model
        **schedule,
        **loss_gains,
        **augmentation,
        **optimization,
        **io_and_detection,
        **hardware,
        **run,
    }
    results = detector.train(**train_kwargs)

    return detector, results

def precision_inference(image_path, model_path, conf_threshold=0.35, tta=False):
    """Run detection on one image with a precision-oriented confidence threshold.

    The earlier version ran five forward passes (original, "flipped", three image
    sizes) and then returned only the first one, so the other four were wasted work.
    The "flipped" pass was also not flipped: ``fliplr`` is a training-augmentation
    probability, not a prediction-time transform. This version runs the single pass
    whose result was actually returned.

    Args:
        image_path: Image (or any source the model accepts) to run detection on.
        model_path: Path to the trained weights to load.
        conf_threshold (float): Minimum confidence for a detection to be kept.
        tta (bool): When True, request Ultralytics' built-in test-time augmentation
            (``augment=True``), which merges the augmented passes internally.
            Defaults to False, which reproduces the previous return value.

    Returns:
        The value returned by calling the model on ``image_path``.
    """
    model = YOLO(model_path)

    predict_kwargs = {'conf': conf_threshold, 'iou': 0.5}
    if tta:
        # Only added on request, so the default call is unchanged from before.
        predict_kwargs['augment'] = True

    return model(image_path, **predict_kwargs)

# Evaluate against metrics
def evaluate_precision(model_path, data_path):
    """Sweep confidence thresholds on the validation data and re-evaluate at the best one.

    For each candidate threshold the model is validated and its mAP@50 printed. The
    threshold with the strictly highest mAP@50 wins (the earliest one on ties); if no
    candidate scores above 0 the fallback threshold 0.25 is used. A final validation
    run at the winning threshold additionally saves JSON results and plots.

    Args:
        model_path: Path to the weights to evaluate.
        data_path: Dataset YAML passed to ``model.val``.

    Returns:
        tuple: ``(final_metrics, best_conf)`` - the metrics object of the final
        validation run and the selected confidence threshold.
    """
    detector = YOLO(model_path)
    candidate_thresholds = (0.15, 0.25, 0.35, 0.45, 0.55)

    # Fallback threshold / score used until some candidate beats a score of 0.
    best_conf, best_map50 = 0.25, 0

    for threshold in candidate_thresholds:
        map50 = detector.val(data=data_path, conf=threshold, iou=0.5).box.map50
        print(f"Conf={threshold}: mAP@50={map50:.4f}")

        # Strict comparison: an equal score never displaces an earlier threshold.
        if map50 > best_map50:
            best_conf, best_map50 = threshold, map50

    print(f"\nBest mAP@50={best_map50:.4f} at conf={best_conf}")

    # Final pass at the chosen threshold, this time persisting JSON and plots.
    final_kwargs = {
        'data': data_path,
        'conf': best_conf,
        'iou': 0.5,
        'save_json': True,
        'plots': True,
    }
    return detector.val(**final_kwargs), best_conf

if __name__ == "__main__":
    # Configuration: dataset definition on the mounted Drive.
    DATA_PATH = '/content/drive/MyDrive/Datasets/leaf7.yaml'

    print("Training Model for Disease Detection")

    # Train the model from scratch.
    model, results = train_deep_multiscale(DATA_PATH, epochs=500)

    # Evaluate the best checkpoint and pick a confidence threshold.
    # The path mirrors the project/name passed to model.train in train_deep_multiscale.
    print("\nEvaluating model performance...")
    metrics, best_conf = evaluate_precision(
        'runs/deep_multiscale/deep_disease_detector/weights/best.pt',
        DATA_PATH
    )

    print(f"\nFinal Results:")
    print(f"mAP@50: {metrics.box.map50:.4f}")
    print(f"mAP@50-95: {metrics.box.map:.4f}")
    # box.p / box.r hold one value per class and cannot be rendered with a float
    # format spec (TypeError); box.mp / box.mr are their means over the classes.
    print(f"Precision: {metrics.box.mp:.4f}")
    print(f"Recall: {metrics.box.mr:.4f}")
    print(f"Best confidence threshold: {best_conf}")





Training Model for Disease Detection
Ultralytics 8.3.146 🚀 Python-3.11.12 torch-2.6.0+cu124 CUDA:0 (NVIDIA L4, 22693MiB)
[34m[1mengine/trainer: [0magnostic_nms=False, amp=True, augment=False, auto_augment=randaugment, batch=8, bgr=0.0, box=12.0, cache=True, cfg=None, classes=None, close_mosaic=100, cls=1.5, conf=0.001, copy_paste=0.3, copy_paste_mode=flip, cos_lr=False, cutmix=0.0, data=/content/drive/MyDrive/Datasets/leaf7.yaml, degrees=10, deterministic=True, device=0, dfl=2.5, dnn=False, dropout=0.1, dynamic=False, embed=None, epochs=500, erasing=0.4, exist_ok=True, fliplr=0.5, flipud=0.1, format=torchscript, fraction=1.0, freeze=None, half=False, hsv_h=0.015, hsv_s=0.6, hsv_v=0.4, imgsz=640, int8=False, iou=0.5, keras=False, kobj=1.0, line_width=None, lr0=0.002, lrf=1e-05, mask_ratio=4, max_det=300, mixup=0.15, mode=train, model=deep_multiscale_yolo.yaml, momentum=0.937, mosaic=0.9, multi_scale=False, name=deep_disease_detector, nbs=64, nms=False, opset=None, optimize=False, 

[34m[1mtrain: [0mScanning /content/drive/MyDrive/Datasets/train_attempt_2/labels.cache... 143 images, 0 backgrounds, 0 corrupt: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 143/143 [00:00<?, ?it/s]




[34m[1mtrain: [0mCaching images (0.2GB RAM): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 143/143 [00:00<00:00, 535.23it/s]

[34m[1malbumentations: [0mBlur(p=0.01, blur_limit=(3, 7)), MedianBlur(p=0.01, blur_limit=(3, 7)), ToGray(p=0.01, method='weighted_average', num_output_channels=3), CLAHE(p=0.01, clip_limit=(1.0, 4.0), tile_grid_size=(8, 8))





[34m[1mval: [0mFast image access ‚úÖ (ping: 0.4¬±0.1 ms, read: 26.6¬±5.9 MB/s, size: 65.2 KB)


[34m[1mval: [0mScanning /content/drive/MyDrive/Datasets/test/labels.cache... 31 images, 0 backgrounds, 0 corrupt: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 31/31 [00:00<?, ?it/s]


KeyboardInterrupt: 

### Model with Fourier-Based Loss

2D Fast Fourier Transform  Module -
This FourierLoss module is designed to enhance the model's ability to detect fine-grained textures, edges, and subtle irregular patterns in tomato leaf images—features that are often indicative of disease. The constructor takes three arguments: loss_type, which was set to L1 to measure spectral differences; hf_mask_ratio, which defines what fraction of the maximum radius to keep as high-frequency; and lambda_fourier, which sets how strongly this Fourier-domain penalty is weighted in the overall loss and was set low. By transforming image patches into the frequency domain using a centered 2D Fast Fourier Transform (FFT), we can separate out low-frequency components (broad color and illumination information) from high-frequency components (edges, speckles, and fine texture). The create_hf_mask function builds a circular binary mask in the frequency domain by first computing the distance of each frequency coordinate to the center (where the zero-frequency or DC component resides) and then zeroing out all frequencies within a radius proportional to hf_mask_ratio, thereby keeping only frequencies above that cutoff (a ratio of 0.15 here). Caching ensures that masks of the same size and device are reused efficiently. In the forward pass, if the input prediction and ground truth have three channels, they are averaged to one channel (since texture information is largely independent of color channels) and then both are passed through compute_fft to obtain magnitude spectra. After masking out the low frequencies, the high-frequency spectra of the prediction (pmh) and ground truth (gmh) are compared using an L1 or MSE loss. Finally, this high-frequency discrepancy is scaled by lambda_fourier and added to the base detection loss, encouraging the network to match not only bounding-box localization and classification goals but also to reproduce the detailed textures that signify disease on the leaf surface.

Overriding YOLO detection Loss
In this override of the YOLOv8 detection loss, we first compute the standard detection loss via super().__call__ (named base_loss) and then prepare to accumulate an additional Fourier-based term (fourier_term) on the same device as the input images. We extract imgs and their corresponding targets (bounding boxes) from the batch, and initialise fourier_term to zero. For each image–target pair, we check whether any ground-truth boxes exist and, if so, we take the first box (tgt[0]) and read its normalised center coordinates (x, y). By multiplying these by the image width W and height H, we compute integer pixel indices (cx, cy) for the center. We then define a fixed patch height/width ph=64 and calculate the top-left corner (x1, y1) by subtracting half the patch size from (cx, cy), clamping at zero so the patch remains inside the image. The code slices out a ph×ph patch from the image tensor using these coordinates (patch = img[:, y1:y1+ph, x1:x1+ph]). This patch is then passed to the self.fourier_loss module, currently using the same patch for both pred and gt (a placeholder, identity comparison), and its scalar output is accumulated into fourier_term. If gradient computation is enabled, the script prints the numeric value of fourier_term for debugging. Finally, the method returns the sum of base_loss and fourier_term, effectively augmenting the standard YOLOv8 loss with a frequency-domain term that encourages the network to capture fine texture details within the cropped object region.

Fourier Loss in Training
In FourierDetectionTrainer, we define a custom trainer class that injects our Fourier-based loss into the YOLOv8 training loop. The FourierDetectionTrainer constructor calls its parent's __init__ to set up the standard YOLO configuration, overrides, and callbacks, then stores a default fourier_config (with an L1 loss, a 15% high-frequency cutoff, and a weight of 0.1). By overriding the criterion method, we ensure that the first time it is called, we create a FourierDetectionLoss object bound to the current model and configuration, and that every subsequent call reuses that same loss instance. When criterion is invoked during training, it simply delegates to our FourierDetectionLoss, returning a scalar that combines the original detection loss with the frequency-domain term.

The FourierYOLO subclass then hooks this trainer into the overall YOLO task mapping. By overriding the task_map property, we grab the parent's map of tasks, assign FourierDetectionTrainer to the "detect" task so that detection training uses our custom trainer, and return the modified map. Finally, the train method is overridden only to call super().train(...), preserving all the original training behavior, including data loading, optimiser setup, etc., while ensuring that the "detect" stage now uses our Fourier-augmented loss. This setup cleanly injects a frequency-domain regularizer into the standard YOLOv8 pipeline without altering any of YOLO's core training logic.

In summary, by explicitly separating low and high frequencies through a centered 2D FFT and ring-shaped masking, the FourierLoss steers YOLO to pay special attention to fine edges, speckles, and micro-textures that are crucial for detecting small or subtle disease spots on tomato leaves. High-frequency masking removes distracting broad illumination or color variations (low frequencies) and forces the network to match the magnitude spectra of predicted and ground-truth patches in those high-frequency bands. This frequency-domain penalty—weighted by lambda_fourier—complements YOLO's standard spatial loss, resulting in more discriminative convolutional features, tighter bounding boxes, and improved detection of small objects. Empirically, integrating Fourier-based loss yields higher overall mAP, a pronounced increase in AP Small (since small lesions manifest primarily as high-frequency irregularities), and better box precision, demonstrating that frequency-domain supervision is a powerful adjunct to spatial detection objectives.



In [None]:
"""
Ultra-Minimal YOLOv8 for Leaf Disease Detection
Optimized for mAP@50 > 0.75 and high box precision
Includes Fourier-based loss integration (Assignment Option 2)
"""

from ultralytics import YOLO
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.utils.loss import v8DetectionLoss
import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml
import os
import numpy as np
from pathlib import Path
import shutil

from pathlib import Path
import shutil
import torch
import torch.nn as nn
import torch.nn.functional as F
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.utils.loss import v8DetectionLoss
from ultralytics import YOLO

# ============ Custom Callback for Model Saving ============
class PrecisionMAPCallback:
    """
    Trainer callback that keeps ``achieved.pt`` pointing at the best qualifying epoch.

    An epoch qualifies when its box precision is ≥ ``precision_threshold``. Among
    qualifying epochs, the one with the highest mAP@50 has its ``last.pt`` copied to
    ``achieved.pt`` in the run's ``weights`` directory.
    """
    def __init__(self, precision_threshold=0.75):
        """
        Args:
            precision_threshold (float): Minimum box precision required before an
                epoch's weights may be copied to ``achieved.pt``.
        """
        self.precision_threshold = precision_threshold
        # mAP@50 of the weights currently stored as achieved.pt (0 while none exist).
        self.best_map50 = 0
        # Becomes True the first time the precision threshold is met.
        self.precision_threshold_reached = False
        # Path (as str) of achieved.pt once it has been written, otherwise None.
        self.best_map50_path = None

    def __call__(self, trainer):
        """
        Check the latest metrics and update ``achieved.pt`` if this epoch is the best
        qualifying one so far.

        Args:
            trainer: Object exposing ``metrics`` (mapping of metric name to value) and
                ``save_dir`` (run directory containing ``weights/``).
        """
        metrics = trainer.metrics
        # A metric that is absent for this epoch counts as 0.
        current_map50 = metrics.get('metrics/mAP50(B)', 0)
        current_precision = metrics.get('metrics/precision(B)', 0)

        # Epochs below the precision floor are ignored entirely.
        if current_precision < self.precision_threshold:
            return

        # Announce the first time the floor is met.
        if not self.precision_threshold_reached:
            print(f"Precision threshold reached! (Precision: {current_precision:.4f})")
            self.precision_threshold_reached = True

        # Only a strictly better mAP@50 replaces the stored snapshot.
        if current_map50 <= self.best_map50:
            return

        weights_dir = Path(trainer.save_dir) / 'weights'
        last_path = weights_dir / 'last.pt'
        achieved_path = weights_dir / 'achieved.pt'

        # Without last.pt there is nothing to snapshot. The stored best is left
        # untouched so a later qualifying epoch is not locked out.
        if not last_path.exists():
            return

        shutil.copy(str(last_path), str(achieved_path))

        # Bookkeeping happens only after a successful copy, keeping best_map50 in
        # step with the file on disk.
        self.best_map50 = current_map50
        self.best_map50_path = str(achieved_path)
        print(
            f"Saved achieved.pt: Precision={current_precision:.4f} (≥{self.precision_threshold}), "
            f"mAP@50={current_map50:.4f} (best so far)"
        )



# ============ Fourier Loss Classes ============

class FourierLoss(nn.Module):
    """
    Frequency-domain loss that penalizes differences between the high-frequency
    magnitude spectra of two image patches.
    """
    def __init__(self, loss_type='l1', hf_mask_ratio=0.15, lambda_fourier=0.1):
        """
        Args:
            loss_type (str): 'l1' selects an L1 loss on the masked spectra; any other
                             value selects MSE.
            hf_mask_ratio (float): Sets the low-frequency cutoff radius as
                                   ``hf_mask_ratio * min(H, W)`` (a fraction of the
                                   shorter side, not of the maximum radius).
                                   Frequencies farther than this from the spectrum
                                   centre are treated as high-frequency.
            lambda_fourier (float): Factor the loss is multiplied by before it is
                                    returned.
        """
        super().__init__()
        self.loss_type = loss_type
        self.hf_mask_ratio = hf_mask_ratio
        self.lambda_fourier = lambda_fourier
        # Masks already built, keyed by (H, W, device, hf_mask_ratio).
        self._mask_cache = {}

    def compute_fft(self, patch):
        """
        Compute the centred magnitude spectrum of a patch.

        Args:
            patch (Tensor): Tensor whose last two dimensions are (H, W).
        Returns:
            Tensor: Magnitude of the 2D FFT over the last two dimensions, with the
            zero-frequency component shifted to the centre. Same shape as ``patch``.
        """
        spectrum = torch.fft.fft2(patch, dim=(-2, -1))
        return torch.abs(torch.fft.fftshift(spectrum, dim=(-2, -1)))

    def create_hf_mask(self, shape, device):
        """
        Build (or fetch from the cache) the binary high-pass mask for a spectrum.

        Args:
            shape (tuple): Shape of the spectrum; only the last two entries (H, W)
                           are used.
            device (torch.device): Device on which the mask is created.
        Returns:
            Tensor: Float mask of shape (H, W): 1 where the distance to the centre
            (H // 2, W // 2) is greater than ``hf_mask_ratio * min(H, W)``, else 0.
        """
        H, W = (int(side) for side in shape[-2:])

        # Key on the spatial size only. The mask does not depend on leading
        # (batch/channel) dimensions, so including them would rebuild and store the
        # identical mask once per distinct leading shape.
        key = (H, W, device, self.hf_mask_ratio)
        cached = self._mask_cache.get(key)
        if cached is not None:
            return cached

        cy, cx = H // 2, W // 2               # centre of the shifted spectrum
        r = min(H, W) * self.hf_mask_ratio    # low-frequency cutoff radius

        y = torch.arange(H, device=device).view(-1, 1)
        x = torch.arange(W, device=device).view(1, -1)
        dist = ((y - cy) ** 2 + (x - cx) ** 2).sqrt()

        # Strictly greater: a frequency exactly on the cutoff circle is masked out.
        mask = (dist > r).float()
        self._mask_cache[key] = mask
        return mask

    def forward(self, pred, gt):
        """
        Compute the weighted high-frequency spectral loss between two patches.

        Args:
            pred (Tensor): Predicted patch. A 3-D input is treated as (C, H, W) and
                           averaged over its first dimension; other ranks are used
                           as given.
            gt (Tensor): Reference patch with the same shape as ``pred``.
        Returns:
            Tensor: Scalar ``lambda_fourier * loss``, where ``loss`` is the L1 or MSE
            distance between the masked magnitude spectra.
        """
        # Collapse channels so that one spectrum per patch is compared.
        if pred.dim() == 3:
            pred, gt = pred.mean(0), gt.mean(0)

        pm = self.compute_fft(pred)
        gm = self.compute_fft(gt)

        # Zero the low frequencies of both spectra with the same mask.
        mask = self.create_hf_mask(pm.shape, pm.device)
        pmh, gmh = pm * mask, gm * mask

        if self.loss_type == 'l1':
            loss = F.l1_loss(pmh, gmh)
        else:
            loss = F.mse_loss(pmh, gmh)

        return self.lambda_fourier * loss



class FourierDetectionLoss(v8DetectionLoss):
    """
    YOLOv8 detection loss extended with a Fourier-domain term evaluated on one image
    patch per labelled image.

    NOTE(review): the Fourier term currently compares each patch with itself (see
    ``__call__``), so it evaluates to exactly zero and passes no gradient to the
    network. It is a placeholder until a prediction-dependent patch is available.
    """

    # Side length, in pixels, of the square crop taken around a box centre.
    PATCH_SIZE = 64

    def __init__(self, model, fourier_config):
        """
        Args:
            model: The detection model this loss is built for (forwarded to the
                   parent loss).
            fourier_config (dict): Keyword arguments for FourierLoss
                                   (loss_type, hf_mask_ratio, lambda_fourier).
        """
        super().__init__(model)
        self.fourier_loss = FourierLoss(**fourier_config)

    def __call__(self, preds, batch):
        """
        Compute the detection loss plus the Fourier term.

        Args:
            preds: Raw model outputs, forwarded unchanged to the parent loss.
            batch (dict): Training batch in the layout the parent loss consumes:
                - 'img': images, shape (B, C, H, W).
                - 'bboxes': all boxes of the batch stacked in one (N, 4) tensor of
                  normalised (x_center, y_center, width, height).
                - 'batch_idx': for each box, the index of the image it belongs to.
        Returns:
            tuple: ``(loss, loss_items)`` with the same structure as the parent
            loss. ``loss`` is the summed detection loss plus the Fourier term;
            ``loss_items`` are the parent's detached components, unchanged.
        """
        # The parent returns a pair (loss for backprop, detached items for logging).
        # It must be unpacked: adding a tensor to the tuple itself raises TypeError.
        det_loss, loss_items = super().__call__(preds, batch)

        imgs = batch['img']
        fourier_term = torch.zeros((), device=imgs.device)

        boxes = batch.get('bboxes')
        owners = batch.get('batch_idx')
        if boxes is not None and owners is not None and len(boxes) > 0:
            owners = owners.view(-1)
            H, W = imgs.shape[-2:]
            ph = self.PATCH_SIZE

            for i, img in enumerate(imgs):
                # Rows of 'bboxes' that belong to image i.
                rows = torch.nonzero(owners == i).flatten()
                if rows.numel() == 0:
                    continue  # unlabelled image: contributes nothing

                # First ground-truth box of this image (could be extended to all).
                x, y = boxes[rows[0], :2].tolist()
                cx, cy = int(x * W), int(y * H)

                # Top-left corner of the crop, clamped on both sides so the patch
                # stays ph x ph whenever the image is at least that large.
                x1 = min(max(cx - ph // 2, 0), max(W - ph, 0))
                y1 = min(max(cy - ph // 2, 0), max(H - ph, 0))
                # float(): keep the FFT in full precision even for half-precision images.
                patch = img[:, y1:y1 + ph, x1:x1 + ph].float()

                # NOTE(review): placeholder - both arguments are the same input
                # patch, so this adds exactly 0 (see class docstring).
                fourier_term = fourier_term + self.fourier_loss(patch, patch)

        # Debug output, emitted only when gradients are enabled.
        if torch.is_grad_enabled():
            print(f"[FourierLoss]   term = {fourier_term.item():.6f}")

        # Reduce the detection loss to a scalar before adding the scalar Fourier
        # term, so the term is counted once however many components the parent returns.
        return det_loss.sum() + fourier_term, loss_items



class FourierDetectionTrainer(DetectionTrainer):
    """
    Custom trainer for YOLOv8 detection that installs FourierDetectionLoss as the
    model's loss criterion.
    """
    def __init__(self, cfg=None, overrides=None, callbacks=None, fourier_config=None, _callbacks=None):
        """
        Args:
            cfg: Trainer configuration; when None the parent's own default is used.
            overrides: Dictionary of override arguments for training.
            callbacks: Callbacks forwarded to the parent trainer.
            fourier_config (dict, optional): Configuration for FourierLoss.
            _callbacks: Alias of ``callbacks``. Ultralytics' ``Model.train`` builds the
                trainer with ``overrides=...`` and ``_callbacks=...`` keywords, which
                the previous all-positional signature could not accept.
        """
        hooks = callbacks if callbacks is not None else _callbacks
        if cfg is None:
            super().__init__(overrides=overrides, _callbacks=hooks)
        else:
            super().__init__(cfg, overrides, hooks)

        # If no Fourier configuration is passed, use default values.
        self._fourier_config = fourier_config or {
            'loss_type': 'l1',
            'hf_mask_ratio': 0.15,
            'lambda_fourier': 0.1
        }

    def set_model_attributes(self):
        """
        Run the parent hook, then attach the Fourier loss to the model.

        NOTE(review): relies on Ultralytics computing the training loss through
        ``model.criterion`` and only creating the default criterion when that
        attribute is unset - verify against the installed version.
        """
        super().set_model_attributes()
        self._fourier_criterion = FourierDetectionLoss(self.model, self._fourier_config)
        self.model.criterion = self._fourier_criterion

    def criterion(self, preds, batch):
        """
        Evaluate the Fourier-augmented loss directly.

        Kept for callers that invoke the loss through the trainer; the loss object
        is created on first use and reused afterwards.
        """
        if not hasattr(self, '_fourier_criterion'):
            self._fourier_criterion = FourierDetectionLoss(self.model, self._fourier_config)
        return self._fourier_criterion(preds, batch)



class FourierYOLO(YOLO):
    """
    YOLO wrapper whose 'detect' task is trained with FourierDetectionTrainer.
    """
    @property
    def task_map(self):
        """Return the parent's task map with the 'detect' trainer swapped out."""
        task_table = super().task_map
        detect_entry = task_table['detect']
        detect_entry['trainer'] = FourierDetectionTrainer
        return task_table

    def train(self, **kwargs):
        """
        Forward all keyword arguments to the parent's ``train``; the trainer class
        itself is selected through ``task_map``.
        """
        return super().train(**kwargs)


# ============ Configuration-Creation Functions ============
def create_custom_yolo_config():
    """
    Write the reduced-architecture YOLOv8 model definition to
    'custom_yolo_minimal.yaml' in the current directory and return that path.

    Each layer row is [from, repeats, module, args]; the trailing comments give
    the layer index that the 'from' fields of later rows refer to.
    """
    backbone_layers = [
        [-1, 1, 'Conv', [16, 3, 2]],    # 0
        [-1, 1, 'Conv', [32, 3, 2]],    # 1
        [-1, 1, 'C2f', [32, True]],     # 2
        [-1, 1, 'Conv', [64, 3, 2]],    # 3
        [-1, 2, 'C2f', [64, True]],     # 4
        [-1, 1, 'Conv', [128, 3, 2]],   # 5
        [-1, 2, 'C2f', [128, True]],    # 6
        [-1, 1, 'Conv', [256, 3, 2]],   # 7
        [-1, 1, 'C2f', [256, True]],    # 8
        [-1, 1, 'SPPF', [256, 5]],      # 9
    ]
    # Every row is its own list object on purpose: sharing one list between two
    # rows would make yaml.dump emit anchors/aliases instead of plain lists.
    head_layers = [
        [-1, 1, 'Conv', [128, 1, 1]],                    # 10
        [-1, 1, 'nn.Upsample', [None, 2, 'nearest']],    # 11
        [[-1, 6], 1, 'Concat', [1]],                     # 12
        [-1, 1, 'C2f', [128, False]],                    # 13
        [-1, 1, 'Conv', [64, 1, 1]],                     # 14
        [-1, 1, 'nn.Upsample', [None, 2, 'nearest']],    # 15
        [[-1, 4], 1, 'Concat', [1]],                     # 16
        [-1, 1, 'C2f', [64, False]],                     # 17 -> Detect input
        [-1, 1, 'Conv', [64, 3, 2]],                     # 18
        [[-1, 14], 1, 'Concat', [1]],                    # 19
        [-1, 1, 'C2f', [128, False]],                    # 20 -> Detect input
        [-1, 1, 'Conv', [128, 3, 2]],                    # 21
        [[-1, 10], 1, 'Concat', [1]],                    # 22
        [-1, 1, 'C2f', [256, False]],                    # 23 -> Detect input
        [[17, 20, 23], 1, 'Detect', ['nc']],             # 24
    ]
    model_spec = {
        'nc': 7,
        'depth_multiple': 0.33,
        'width_multiple': 0.25,
        'backbone': backbone_layers,
        'head': head_layers,
    }

    target = 'custom_yolo_minimal.yaml'
    with open(target, 'w') as stream:
        yaml.dump(model_spec, stream)
    return target

def create_ultra_minimal_yolo_config():
    """
    Write the ultra-minimal YOLOv8 model definition to
    'custom_yolo_ultra_minimal.yaml' in the current directory and return that path.

    Each layer row is [from, repeats, module, args]; the trailing comments give
    the layer index that the 'from' fields of later rows refer to.
    """
    backbone_layers = [
        [-1, 1, 'Conv', [8, 3, 2]],     # 0
        [-1, 1, 'Conv', [16, 3, 2]],    # 1
        [-1, 1, 'C2f', [16, True]],     # 2
        [-1, 1, 'Conv', [32, 3, 2]],    # 3
        [-1, 1, 'C2f', [32, True]],     # 4
        [-1, 1, 'Conv', [64, 3, 2]],    # 5
        [-1, 1, 'C2f', [64, True]],     # 6
        [-1, 1, 'Conv', [128, 3, 2]],   # 7
        [-1, 1, 'C2f', [128, True]],    # 8
        [-1, 1, 'SPPF', [128, 5]],      # 9
    ]
    # Rows are deliberately separate list objects so yaml.dump writes plain
    # lists rather than anchors/aliases.
    head_layers = [
        [-1, 1, 'Conv', [64, 1, 1]],                     # 10
        [-1, 1, 'nn.Upsample', [None, 2, 'nearest']],    # 11
        [[-1, 6], 1, 'Concat', [1]],                     # 12
        [-1, 1, 'C2f', [64, False]],                     # 13
        [-1, 1, 'Conv', [32, 1, 1]],                     # 14
        [-1, 1, 'nn.Upsample', [None, 2, 'nearest']],    # 15
        [[-1, 4], 1, 'Concat', [1]],                     # 16
        [-1, 1, 'C2f', [32, False]],                     # 17 -> Detect input
        [-1, 1, 'Conv', [32, 3, 2]],                     # 18
        [[-1, 14], 1, 'Concat', [1]],                    # 19
        [-1, 1, 'C2f', [64, False]],                     # 20 -> Detect input
        [-1, 1, 'Conv', [64, 3, 2]],                     # 21
        [[-1, 10], 1, 'Concat', [1]],                    # 22
        [-1, 1, 'C2f', [128, False]],                    # 23 -> Detect input
        [[17, 20, 23], 1, 'Detect', ['nc']],             # 24
    ]
    model_spec = {
        'nc': 7,
        'depth_multiple': 0.25,
        'width_multiple': 0.125,
        'backbone': backbone_layers,
        'head': head_layers,
    }

    target = 'custom_yolo_ultra_minimal.yaml'
    with open(target, 'w') as stream:
        yaml.dump(model_spec, stream)
    return target

def create_nano_yolo_config():
    """
    Write the nano (very lightweight) YOLOv8 model definition to
    'custom_yolo_nano.yaml' in the current directory and return that path.

    Unlike the other configs in this file, the head has no 1x1 Conv before each
    upsample, so the head is two layers shorter and Detect reads layers 15/18/21.
    Each layer row is [from, repeats, module, args]; the trailing comments give
    the layer index that the 'from' fields of later rows refer to.
    """
    backbone_layers = [
        [-1, 1, 'Conv', [16, 3, 2]],    # 0
        [-1, 1, 'Conv', [32, 3, 2]],    # 1
        [-1, 1, 'C2f', [32, True]],     # 2
        [-1, 1, 'Conv', [64, 3, 2]],    # 3
        [-1, 2, 'C2f', [64, True]],     # 4
        [-1, 1, 'Conv', [128, 3, 2]],   # 5
        [-1, 2, 'C2f', [128, True]],    # 6
        [-1, 1, 'Conv', [256, 3, 2]],   # 7
        [-1, 1, 'C2f', [256, True]],    # 8
        [-1, 1, 'SPPF', [256, 5]],      # 9
    ]
    # Rows are deliberately separate list objects so yaml.dump writes plain
    # lists rather than anchors/aliases.
    head_layers = [
        [-1, 1, 'nn.Upsample', [None, 2, 'nearest']],    # 10
        [[-1, 6], 1, 'Concat', [1]],                     # 11
        [-1, 1, 'C2f', [128, False]],                    # 12
        [-1, 1, 'nn.Upsample', [None, 2, 'nearest']],    # 13
        [[-1, 4], 1, 'Concat', [1]],                     # 14
        [-1, 1, 'C2f', [64, False]],                     # 15 -> Detect input
        [-1, 1, 'Conv', [64, 3, 2]],                     # 16
        [[-1, 12], 1, 'Concat', [1]],                    # 17
        [-1, 1, 'C2f', [128, False]],                    # 18 -> Detect input
        [-1, 1, 'Conv', [128, 3, 2]],                    # 19
        [[-1, 9], 1, 'Concat', [1]],                     # 20
        [-1, 1, 'C2f', [256, False]],                    # 21 -> Detect input
        [[15, 18, 21], 1, 'Detect', ['nc']],             # 22
    ]
    model_spec = {
        'nc': 7,
        'depth_multiple': 0.33,
        'width_multiple': 0.25,
        'backbone': backbone_layers,
        'head': head_layers,
    }

    target = 'custom_yolo_nano.yaml'
    with open(target, 'w') as stream:
        yaml.dump(model_spec, stream)
    return target

def create_optimized_config():
    """
    Write the 'optimized' YOLOv8 model definition (depth/width multiples of 0.5)
    to 'optimized_yolo.yaml' in the current directory and return that path.

    Each layer row is [from, repeats, module, args]; the trailing comments give
    the layer index that the 'from' fields of later rows refer to.
    """
    backbone_layers = [
        [-1, 1, 'Conv', [32, 3, 2]],    # 0
        [-1, 1, 'Conv', [64, 3, 2]],    # 1
        [-1, 2, 'C2f', [64, True]],     # 2
        [-1, 1, 'Conv', [128, 3, 2]],   # 3
        [-1, 3, 'C2f', [128, True]],    # 4
        [-1, 1, 'Conv', [256, 3, 2]],   # 5
        [-1, 3, 'C2f', [256, True]],    # 6
        [-1, 1, 'Conv', [512, 3, 2]],   # 7
        [-1, 2, 'C2f', [512, True]],    # 8
        [-1, 1, 'SPPF', [512, 5]],      # 9
    ]
    # Rows are deliberately separate list objects so yaml.dump writes plain
    # lists rather than anchors/aliases.
    head_layers = [
        [-1, 1, 'Conv', [256, 1, 1]],                    # 10
        [-1, 1, 'nn.Upsample', [None, 2, 'nearest']],    # 11
        [[-1, 6], 1, 'Concat', [1]],                     # 12
        [-1, 2, 'C2f', [256, False]],                    # 13
        [-1, 1, 'Conv', [128, 1, 1]],                    # 14
        [-1, 1, 'nn.Upsample', [None, 2, 'nearest']],    # 15
        [[-1, 4], 1, 'Concat', [1]],                     # 16
        [-1, 2, 'C2f', [128, False]],                    # 17 -> Detect input
        [-1, 1, 'Conv', [128, 3, 2]],                    # 18
        [[-1, 13], 1, 'Concat', [1]],                    # 19
        [-1, 2, 'C2f', [256, False]],                    # 20 -> Detect input
        [-1, 1, 'Conv', [256, 3, 2]],                    # 21
        [[-1, 9], 1, 'Concat', [1]],                     # 22
        [-1, 2, 'C2f', [512, False]],                    # 23 -> Detect input
        [[17, 20, 23], 1, 'Detect', ['nc']],             # 24
    ]
    model_spec = {
        'nc': 7,
        'depth_multiple': 0.5,
        'width_multiple': 0.5,
        'backbone': backbone_layers,
        'head': head_layers,
    }

    target = 'optimized_yolo.yaml'
    with open(target, 'w') as stream:
        yaml.dump(model_spec, stream)
    return target

def create_deep_multiscale_config():
    """
    Write the deeper multi-scale YOLOv8 model definition to
    'deep_multiscale_yolo.yaml' in the current directory and return that path.

    The backbone has 14 layers (0-13), so head indices start at 14 and Detect
    reads layers 21/24/27. Each layer row is [from, repeats, module, args]; the
    trailing comments give the layer index that later 'from' fields refer to.
    """
    backbone_layers = [
        [-1, 1, 'Conv', [48, 3, 2]],    # 0
        [-1, 2, 'Conv', [48, 3, 1]],    # 1
        [-1, 1, 'Conv', [96, 3, 2]],    # 2
        [-1, 3, 'C2f', [96, True]],     # 3
        [-1, 1, 'Conv', [192, 3, 2]],   # 4
        [-1, 4, 'C2f', [192, True]],    # 5
        [-1, 1, 'Conv', [192, 3, 1]],   # 6
        [-1, 1, 'Conv', [384, 3, 2]],   # 7
        [-1, 6, 'C2f', [384, True]],    # 8
        [-1, 1, 'Conv', [384, 3, 1]],   # 9
        [-1, 1, 'Conv', [576, 3, 2]],   # 10
        [-1, 3, 'C2f', [576, True]],    # 11
        [-1, 1, 'SPPF', [576, 5]],      # 12
        [-1, 1, 'Conv', [576, 1, 1]],   # 13
    ]
    # Rows are deliberately separate list objects so yaml.dump writes plain
    # lists rather than anchors/aliases.
    head_layers = [
        [-1, 1, 'Conv', [384, 1, 1]],                    # 14
        [-1, 1, 'nn.Upsample', [None, 2, 'nearest']],    # 15
        [[-1, 9], 1, 'Concat', [1]],                     # 16
        [-1, 3, 'C2f', [384, False]],                    # 17
        [-1, 1, 'Conv', [192, 1, 1]],                    # 18
        [-1, 1, 'nn.Upsample', [None, 2, 'nearest']],    # 19
        [[-1, 6], 1, 'Concat', [1]],                     # 20
        [-1, 3, 'C2f', [192, False]],                    # 21 -> Detect input
        [-1, 1, 'Conv', [192, 3, 2]],                    # 22
        [[-1, 18], 1, 'Concat', [1]],                    # 23
        [-1, 3, 'C2f', [384, False]],                    # 24 -> Detect input
        [-1, 1, 'Conv', [384, 3, 2]],                    # 25
        [[-1, 14], 1, 'Concat', [1]],                    # 26
        [-1, 3, 'C2f', [576, False]],                    # 27 -> Detect input
        [[21, 24, 27], 1, 'Detect', ['nc']],             # 28
    ]
    model_spec = {
        'nc': 7,
        'depth_multiple': 0.67,
        'width_multiple': 0.75,
        'backbone': backbone_layers,
        'head': head_layers,
    }

    target = 'deep_multiscale_yolo.yaml'
    with open(target, 'w') as stream:
        yaml.dump(model_spec, stream)
    return target

# ============ High-Precision Training ============
def train_high_precision(data_path, epochs=400):
    """
    Train the 'optimized' architecture from scratch with settings aimed at
    high mAP@50 and box precision.

    Args:
        data_path: Dataset YAML passed to the trainer as ``data``.
        epochs: Number of training epochs.

    Returns:
        Tuple ``(model, results)``: the YOLO wrapper and whatever ``train`` returned.
    """
    detector = YOLO(create_optimized_config())

    settings = {
        # Data and schedule
        'data': data_path,
        'epochs': epochs,
        'imgsz': 640,
        'batch': 12,
        'lr0': 0.005,
        'lrf': 0.0001,
        'momentum': 0.9,
        'weight_decay': 0.001,
        'warmup_epochs': 15,
        'warmup_momentum': 0.5,
        'warmup_bias_lr': 0.05,
        # Loss gains
        'box': 10.0,
        'cls': 1.0,
        'dfl': 2.0,
        # Augmentation
        'hsv_h': 0.01,
        'hsv_s': 0.5,
        'hsv_v': 0.3,
        'degrees': 5,
        'translate': 0.1,
        'scale': 0.5,
        'shear': 1.0,
        'perspective': 0.0,
        'fliplr': 0.5,
        'flipud': 0.0,
        'mosaic': 0.8,
        'mixup': 0.05,
        'copy_paste': 0.1,
        # Optimizer and stopping
        'optimizer': 'AdamW',
        'patience': 100,
        'close_mosaic': 50,
        'nbs': 64,
        'save_period': 10,
        # Validation / runtime
        'conf': 0.001,
        'iou': 0.6,
        'max_det': 300,
        'device': 0,
        'amp': True,
        # Output
        'project': 'runs/precision',
        'name': 'high_precision_yolo',
        'exist_ok': True,
        'pretrained': False,
        'plots': True,
        'save': True,
        'save_txt': False,
        'save_conf': True,
        'save_crop': False,
    }
    outcome = detector.train(**settings)
    return detector, outcome

# ============ Deep Multi-Scale Model Training ============
def train_deep_multiscale(data_path, epochs=200):
    """
    Train the deep multi-scale architecture from scratch, with a
    PrecisionMAPCallback (threshold 0.75) hooked on "on_fit_epoch_end".

    Args:
        data_path: Dataset YAML passed to the trainer as ``data``.
        epochs: Number of training epochs.

    Returns:
        Tuple ``(model, results)``: the YOLO wrapper and whatever ``train`` returned.
    """
    detector = YOLO(create_deep_multiscale_config())

    # Register the checkpoint-selection callback before training starts.
    detector.add_callback(
        "on_fit_epoch_end",
        PrecisionMAPCallback(precision_threshold=0.75),
    )

    settings = {
        # Data and schedule
        'data': data_path,
        'epochs': epochs,
        'batch': 8,
        'lr0': 0.002,
        'lrf': 0.00001,
        'momentum': 0.937,
        'weight_decay': 0.002,
        'warmup_epochs': 20,
        'warmup_momentum': 0.5,
        'warmup_bias_lr': 0.01,
        # Loss gains
        'box': 12.0,
        'cls': 1.5,
        'dfl': 2.5,
        # Augmentation
        'hsv_h': 0.015,
        'hsv_s': 0.6,
        'hsv_v': 0.4,
        'degrees': 10,
        'translate': 0.15,
        'scale': 0.7,
        'shear': 2.0,
        'perspective': 0.0001,
        'fliplr': 0.5,
        'flipud': 0.1,
        'mosaic': 0.9,
        'mixup': 0.15,
        'copy_paste': 0.3,
        # Optimizer, stopping and regularisation
        'optimizer': 'AdamW',
        'patience': 150,
        'close_mosaic': 100,
        'dropout': 0.1,
        'imgsz': 640,
        'rect': False,
        'save_period': 20,
        # Validation / runtime
        'conf': 0.001,
        'iou': 0.5,
        'max_det': 300,
        'device': 0,
        'amp': True,
        'workers': 8,
        # Output
        'project': 'runs/deep_multiscale',
        'name': 'deep_disease_detector',
        'exist_ok': True,
        'pretrained': False,
        'plots': True,
        'save': True,
        'cache': True,
    }
    outcome = detector.train(**settings)
    return detector, outcome

# ============ Small Dataset Optimization ============
def train_for_small_dataset(data_path, epochs=600):
    """
    Two-stage training strategy for small datasets with heavy regularization.

    Stage 1 pre-trains with heavy augmentation at 512 px; stage 2 continues on
    the same model object with lighter augmentation at 640 px and larger loss gains.

    Args:
        data_path: Dataset YAML passed to the trainer as ``data``.
        epochs: Total epoch budget. One third goes to stage 1 and the rest to
            stage 2, so the default of 600 gives the 200 + 400 split.

    Returns:
        The YOLO model object after both stages.
    """
    config_path = create_deep_multiscale_config()
    model = YOLO(config_path)

    # Honour the caller's budget instead of hard-coding 200/400; each stage
    # gets at least one epoch.
    stage1_epochs = max(1, epochs // 3)
    stage2_epochs = max(1, epochs - stage1_epochs)

    # Stage 1: Pre-training with heavy augmentation
    # NOTE(review): save=False disables periodic checkpoints - confirm the
    # installed Ultralytics version still leaves stage-1 weights on the model
    # object for stage 2 to continue from.
    print("Stage 1: Pre-training with heavy augmentation...")
    model.train(
        data=data_path,
        epochs=stage1_epochs,
        imgsz=512,
        batch=4,
        lr0=0.001,
        mosaic=1.0,
        mixup=0.3,
        copy_paste=0.5,
        degrees=20,
        translate=0.3,
        scale=0.9,
        hsv_h=0.03,
        hsv_s=0.8,
        hsv_v=0.5,
        weight_decay=0.005,
        dropout=0.2,
        save=False,
        device=0,
    )

    # Stage 2: Fine-tuning with less augmentation.
    # `resume=True` is deliberately not passed: it is meant for continuing an
    # interrupted run from its checkpoint with that run's saved settings, not
    # for starting a new stage with different hyperparameters.
    print("Stage 2: Fine-tuning with balanced augmentation...")
    model.train(
        data=data_path,
        epochs=stage2_epochs,
        imgsz=640,
        batch=6,
        lr0=0.0005,
        mosaic=0.7,
        mixup=0.1,
        copy_paste=0.2,
        degrees=10,
        translate=0.15,
        scale=0.7,
        weight_decay=0.002,
        dropout=0.1,
        box=15.0,
        cls=2.0,
        dfl=3.0,
        project='runs/small_dataset',
        name='final_model',
        patience=200,
        device=0,
    )
    return model

# ============ Multi-Scale Training for Better Generalization ============
def train_multiscale(data_path, epochs=300):
    """
    Train in three stages of increasing image size (416 -> 512 -> 640).

    All stages run on the same model object, each with its own learning rate
    and batch size.

    Args:
        data_path: Dataset YAML passed to the trainer as ``data``.
        epochs: Total epoch budget, split evenly across the three stages (any
            remainder goes to the last stage). The default of 300 gives
            100 epochs per stage.

    Returns:
        The YOLO model object after the final stage.
    """
    config_path = create_optimized_config()
    model = YOLO(config_path)

    # Honour the caller's budget instead of hard-coding 100 epochs per stage.
    stage_epochs = max(1, epochs // 3)
    final_epochs = max(1, epochs - 2 * stage_epochs)

    # Stage 1: Small images for stability
    print("Stage 1: Training on 416x416...")
    model.train(
        data=data_path,
        epochs=stage_epochs,
        imgsz=416,
        batch=20,
        lr0=0.01,
        box=8.0,
        cls=1.0,
        save=False,
        device=0,
    )

    # Stage 2: Medium images.
    # `resume=True` is deliberately not passed in stages 2 and 3: it is meant
    # for continuing an interrupted run with that run's saved settings, not for
    # starting a new stage with a different image size and learning rate.
    print("Stage 2: Training on 512x512...")
    model.train(
        data=data_path,
        epochs=stage_epochs,
        imgsz=512,
        batch=16,
        lr0=0.005,
        save=False,
        device=0,
    )

    # Stage 3: Full resolution with fine-tuning
    print("Stage 3: Fine-tuning on 640x640...")
    model.train(
        data=data_path,
        epochs=final_epochs,
        imgsz=640,
        batch=12,
        lr0=0.001,
        box=12.0,
        project='runs/multiscale',
        name='multiscale_yolo',
        device=0,
    )
    return model

# ============ Ensemble Strategy ============
def train_ensemble(data_path):
    """
    Train three models with different recipes and return them as a list.

    The members are, in order: the high-precision recipe (300 epochs), a
    heavy-augmentation run, and an SGD run. The last two each start from a
    freshly written 'optimized' architecture config.
    """
    # Member 1 reuses the dedicated high-precision recipe.
    print("Training Model 1: Standard")
    first_member, _ = train_high_precision(data_path, epochs=300)
    ensemble = [first_member]

    # Members 2 and 3 share the same base settings and differ only in the extras.
    variants = (
        (
            "Training Model 2: Heavy Augmentation",
            {'mosaic': 1.0, 'mixup': 0.2, 'copy_paste': 0.4, 'degrees': 15, 'scale': 0.9},
            'model2_aug',
        ),
        (
            "Training Model 3: SGD Optimizer",
            {'optimizer': 'SGD', 'lr0': 0.01, 'momentum': 0.937},
            'model3_sgd',
        ),
    )
    for banner, extras, run_name in variants:
        print(banner)
        member = YOLO(create_optimized_config())
        member.train(
            data=data_path,
            epochs=300,
            imgsz=640,
            batch=12,
            **extras,
            project='runs/ensemble',
            name=run_name,
        )
        ensemble.append(member)

    return ensemble

# ============ Advanced Inference ============
def precision_inference(image_path, model_path, conf_threshold=0.35, augment=False):
    """
    Run precision-oriented inference on one image.

    Earlier versions also ran a "flipped" pass and three rescaled passes, but
    every one of those results was discarded, and the flip pass used
    ``fliplr=True``, which is a training augmentation setting and does not flip
    the image at prediction time. Those passes are removed; opt in to
    test-time augmentation with ``augment=True`` instead.

    Args:
        image_path: Image (or other source accepted by the model) to run on.
        model_path: Path to the trained weights.
        conf_threshold: Minimum confidence for a detection to be kept.
        augment: When True, pass ``augment=True`` to the predictor to request
            Ultralytics test-time augmentation. Defaults to False, which returns
            the plain single-pass results exactly as before.

    Returns:
        The results object returned by the model call.
    """
    model = YOLO(model_path)

    predict_args = {'conf': conf_threshold, 'iou': 0.5}
    if augment:
        predict_args['augment'] = True
    return model(image_path, **predict_args)

# ============ Evaluation with Multiple Metrics ============
def evaluate_precision(model_path, data_path):
    """
    Sweep a fixed set of confidence thresholds, keep the one with the highest
    mAP@50, then run a final validation at that threshold with JSON and plots.

    Args:
        model_path: Path to the trained weights.
        data_path: Dataset YAML passed to ``val`` as ``data``.

    Returns:
        Tuple ``(final_metrics, best_conf)``. ``best_conf`` stays at 0.25 when
        no threshold achieves an mAP@50 above 0.
    """
    detector = YOLO(model_path)
    top_conf, top_map50 = 0.25, 0

    for candidate in (0.15, 0.25, 0.35, 0.45, 0.55):
        map50 = detector.val(data=data_path, conf=candidate, iou=0.5).box.map50
        print(f"Conf={candidate}: mAP@50={map50:.4f}")
        # Strict comparison: on ties the earliest (lowest) threshold wins.
        if map50 > top_map50:
            top_conf, top_map50 = candidate, map50

    print(f"\nBest mAP@50={top_map50:.4f} at conf={top_conf}")

    report = detector.val(
        data=data_path,
        conf=top_conf,
        iou=0.5,
        save_json=True,
        plots=True,
    )
    return report, top_conf

# ============ Main Training Pipeline ============
if __name__ == "__main__":
    # Configuration
    DATA_PATH = '/content/drive/MyDrive/ELEC5304/dataset.yaml'
    # Single source of truth for the checkpoint produced by train_deep_multiscale.
    BEST_WEIGHTS = 'runs/deep_multiscale/deep_disease_detector/weights/best.pt'

    print("Training Deep Multi-Scale YOLOv8 for Disease Detection")
    print("Optimized for small, medium, and large disease detection")
    print("Will save model when precision ‚â• 0.75 and continue for best mAP@50")
    print("-" * 50)

    # Train Deep Multi-Scale Model
    model, results = train_deep_multiscale(DATA_PATH, epochs=500)

    # Evaluate and find best threshold
    print("\nEvaluating model performance...")
    metrics, best_conf = evaluate_precision(BEST_WEIGHTS, DATA_PATH)

    print(f"\nFinal Results:")
    print(f"mAP@50: {metrics.box.map50:.4f}")
    print(f"mAP@50-95: {metrics.box.map:.4f}")
    # `box.p` / `box.r` are per-class arrays and cannot be formatted with
    # ':.4f'; the scalar means are `box.mp` / `box.mr`.
    print(f"Precision: {metrics.box.mp:.4f}")
    print(f"Recall: {metrics.box.mr:.4f}")
    print(f"Best confidence threshold: {best_conf}")

    # Performance by class. `box.maps` is indexed by class id (mAP@50-95), not
    # by object size, so the report is per class; `ap50` holds one AP@50 value
    # per entry of `ap_class_index`.
    print("\nPerformance by class:")
    class_names = getattr(metrics, 'names', None) or {}
    for row, cls_idx in enumerate(metrics.box.ap_class_index):
        try:
            label = class_names[int(cls_idx)]
        except (KeyError, IndexError, TypeError):
            label = f"class {int(cls_idx)}"
        print(f"{label}: AP@50={metrics.box.ap50[row]:.4f}")

    # Example inference instruction
    print("\nFor inference, use:")
    print(
        f"results = precision_inference('image.jpg', "
        f"'{BEST_WEIGHTS}', "
        f"conf_threshold={best_conf})"
    )


Training Deep Multi-Scale YOLOv8 for Disease Detection
Optimized for small, medium, and large disease detection
Will save model when precision ‚â• 0.75 and continue for best mAP@50
--------------------------------------------------
Ultralytics 8.3.146 üöÄ Python-3.11.12 torch-2.6.0+cu124 CUDA:0 (Tesla T4, 15095MiB)
[34m[1mengine/trainer: [0magnostic_nms=False, amp=True, augment=False, auto_augment=randaugment, batch=8, bgr=0.0, box=12.0, cache=True, cfg=None, classes=None, close_mosaic=100, cls=1.5, conf=0.001, copy_paste=0.3, copy_paste_mode=flip, cos_lr=False, cutmix=0.0, data=/content/drive/MyDrive/ELEC5304/dataset.yaml, degrees=10, deterministic=True, device=0, dfl=2.5, dnn=False, dropout=0.1, dynamic=False, embed=None, epochs=500, erasing=0.4, exist_ok=True, fliplr=0.5, flipud=0.1, format=torchscript, fraction=1.0, freeze=None, half=False, hsv_h=0.015, hsv_s=0.6, hsv_v=0.4, imgsz=640, int8=False, iou=0.5, keras=False, kobj=1.0, line_width=None, lr0=0.002, lrf=1e-05, mask_ratio

100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 755k/755k [00:00<00:00, 15.3MB/s]


                   from  n    params  module                                       arguments                     
  0                  -1  1      1160  ultralytics.nn.modules.conv.Conv             [3, 40, 3, 2]                 
  1                  -1  1     14480  ultralytics.nn.modules.conv.Conv             [40, 40, 3, 1]                
  2                  -1  1     26064  ultralytics.nn.modules.conv.Conv             [40, 72, 3, 2]                
  3                  -1  2     62784  ultralytics.nn.modules.block.C2f             [72, 72, 2, True]             
  4                  -1  1     93600  ultralytics.nn.modules.conv.Conv             [72, 144, 3, 2]               
  5                  -1  3    353952  ultralytics.nn.modules.block.C2f             [144, 144, 3, True]           
  6                  -1  1    186912  ultralytics.nn.modules.conv.Conv             [144, 144, 3, 1]              
  7                  -1  1    373824  ultralytics.nn.modules.conv.Conv             [144




 24                  -1  2    997632  ultralytics.nn.modules.block.C2f             [288, 288, 2, False]          
 25                  -1  1    747072  ultralytics.nn.modules.conv.Conv             [288, 288, 3, 2]              
 26            [-1, 14]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 27                  -1  2   2305152  ultralytics.nn.modules.block.C2f             [576, 432, 2, False]          
 28        [21, 24, 27]  1   2305909  ultralytics.nn.modules.head.Detect           [7, [144, 288, 432]]          
deep_multiscale_YOLO summary: 177 layers, 15,767,693 parameters, 15,767,677 gradients, 51.8 GFLOPs

Freezing layer 'model.28.dfl.conv.weight'
[34m[1mAMP: [0mrunning Automatic Mixed Precision (AMP) checks...
Downloading https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11n.pt to 'yolo11n.pt'...


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5.35M/5.35M [00:00<00:00, 64.7MB/s]


[34m[1mAMP: [0mchecks passed ‚úÖ
[34m[1mtrain: [0mFast image access ‚úÖ (ping: 0.5¬±0.1 ms, read: 0.1¬±0.0 MB/s, size: 34.1 KB)


[34m[1mtrain: [0mScanning /content/drive/MyDrive/ELEC5304/train/labels.cache... 580 images, 0 backgrounds, 0 corrupt: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 580/580 [00:00<?, ?it/s]




[34m[1mtrain: [0mCaching images (0.7GB RAM): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 580/580 [00:03<00:00, 145.18it/s]


[34m[1malbumentations: [0mBlur(p=0.01, blur_limit=(3, 7)), MedianBlur(p=0.01, blur_limit=(3, 7)), ToGray(p=0.01, method='weighted_average', num_output_channels=3), CLAHE(p=0.01, clip_limit=(1.0, 4.0), tile_grid_size=(8, 8))
[34m[1mval: [0mFast image access ‚úÖ (ping: 0.7¬±0.3 ms, read: 0.1¬±0.0 MB/s, size: 33.9 KB)


[34m[1mval: [0mScanning /content/drive/MyDrive/ELEC5304/test/labels.cache... 31 images, 0 backgrounds, 0 corrupt: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 31/31 [00:00<?, ?it/s]




[34m[1mval: [0mCaching images (0.0GB RAM): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 31/31 [00:03<00:00,  8.27it/s]


Plotting labels to runs/deep_multiscale/deep_disease_detector/labels.jpg... 
[34m[1moptimizer:[0m AdamW(lr=0.002, momentum=0.937) with parameter groups 81 weight(decay=0.0), 88 weight(decay=0.002), 87 bias(decay=0.0)
Image sizes 640 train, 640 val
Using 2 dataloader workers
Logging results to [1mruns/deep_multiscale/deep_disease_detector[0m
Starting training for 500 epochs...

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      1/500      3.09G      5.711      14.74      6.703         36        640: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 73/73 [00:29<00:00,  2.44it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2/2 [00:02<00:00,  1.10s/it]

                   all         31        119    0.00106      0.164      0.012    0.00331






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      2/500       3.9G      5.836      12.85      5.885         22        640: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 73/73 [00:21<00:00,  3.41it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2/2 [00:00<00:00,  4.48it/s]

                   all         31        119       0.29      0.143     0.0403     0.0135






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      3/500         4G      5.291      11.13      5.258         23        640: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 73/73 [00:22<00:00,  3.30it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2/2 [00:00<00:00,  4.32it/s]

                   all         31        119      0.366     0.0972     0.0641     0.0212






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      4/500      4.05G      4.612      9.636      4.529         35        640: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 73/73 [00:23<00:00,  3.16it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2/2 [00:00<00:00,  3.42it/s]

                   all         31        119     0.0807       0.43      0.161     0.0591






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      5/500      4.12G      4.326      8.807      4.093         23        640:  97%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñã| 71/73 [00:21<00:00,  3.85it/s]

### Running the Trained Models

In [None]:
import yaml
from pathlib import Path
import numpy as np
from ultralytics import YOLO

# Paths (update if you saved the model elsewhere)
MODEL_PATH = 'runs/deep_multiscale/deep_disease_detector/weights/best_fourier.pt'
DATA_YAML  = '/content/drive/MyDrive/Datasets/leaf7.yaml'  # dataset config

# 1. Load the trained detector
model = YOLO(MODEL_PATH)

# 2. Build class-frequency weights from the TEST split
with open(DATA_YAML) as f:
    data_cfg = yaml.safe_load(f)


# -------------------- Resolve test/val split -------------------- #
# Prefer an explicit 'test' split; otherwise fall back to 'val'.
raw_test_root = data_cfg.get('test') or data_cfg.get('val')
if raw_test_root is None:
    raise ValueError(
        "Neither 'test' nor 'val' keys found in the dataset YAML. "
        "Add a path, e.g.  test: /content/drive/.../images"
    )

test_root = Path(raw_test_root)
if not test_root.is_absolute():
    # Relative split entries are taken relative to the dataset root: the YAML's
    # 'path' key when present, otherwise the folder holding the YAML file.
    dataset_root = Path(data_cfg.get('path') or Path(DATA_YAML).parent)
    test_root = dataset_root / test_root

# Labels live where the last 'images' path component is replaced by 'labels'
# (covers both <split>/images and images/<split>). If the path has no 'images'
# component, fall back to a sibling 'labels' folder.
path_parts = list(test_root.parts)
if 'images' in path_parts:
    last_images = len(path_parts) - 1 - path_parts[::-1].index('images')
    path_parts[last_images] = 'labels'
    label_root = Path(*path_parts)
else:
    label_root = test_root.parent / 'labels'
print(f"Using split from: {test_root}")

# 'nc' is optional in some dataset YAMLs; the length of 'names' is equivalent.
num_classes = data_cfg.get('nc') or len(data_cfg['names'])
class_counts = np.zeros(num_classes, dtype=np.int64)
for lb_file in label_root.glob('*.txt'):
    for line in lb_file.read_text().splitlines():
        fields = line.split()
        if not fields:
            continue  # tolerate blank lines in label files
        cls_id = int(fields[0])
        class_counts[cls_id] += 1

if class_counts.sum() == 0:
    raise RuntimeError("No labels found in the test set ‚Äì check DATA_YAML path")

# Fraction of ground-truth boxes belonging to each class.
class_weights = class_counts / class_counts.sum()

# 3. Validate on the resolved split
split_name = 'test' if data_cfg.get('test') else 'val'
metrics = model.val(
    data=DATA_YAML,
    split=split_name,
    conf=0.25,       # confidence threshold applied to predictions before scoring
    iou=0.50,        # IoU threshold used by NMS (not the mAP matching IoU)
    verbose=False)

box_metrics = metrics.box

# Overall box precision / recall. Different metric-object layouts are probed
# so the cell works across Ultralytics versions.
if hasattr(box_metrics, "tp_class"):          # Older API
    tp = np.array(box_metrics.tp_class)
    fp = np.array(box_metrics.fp_class)
    fn = np.array(box_metrics.fn_class)

    total_tp = tp.sum()
    total_fp = fp.sum()
    total_fn = fn.sum()

    box_precision = float(total_tp / (total_tp + total_fp + 1e-12))
    box_recall    = float(total_tp / (total_tp + total_fn + 1e-12))
else:
    mp_obj = box_metrics.mp
    mr_obj = box_metrics.mr

    # Handle both callable and attribute variants
    box_precision = float(mp_obj() if callable(mp_obj) else mp_obj)
    box_recall    = float(mr_obj() if callable(mr_obj) else mr_obj)

# Per-class AP@50 (exact IoU = 0.50)
all_ap = getattr(box_metrics, "all_ap", None)
if isinstance(all_ap, (list, np.ndarray)) and np.ndim(all_ap) == 2:
    # Only index the first column when the table is really 2-D; an empty
    # result falls through to the branches below instead of raising.
    per_class_ap50 = np.array(all_ap)[:, 0]          # first column is IoU 0.50
elif hasattr(box_metrics, "ap50") and not callable(box_metrics.ap50):
    per_class_ap50 = np.array(box_metrics.ap50)                  # already per-class array
elif hasattr(box_metrics, "ap50") and callable(box_metrics.ap50):
    per_class_ap50 = np.array(box_metrics.ap50())
elif hasattr(box_metrics, "ap50_class"):
    per_class_ap50 = np.array(box_metrics.ap50_class)
else:
    per_class_ap50 = np.array(box_metrics.maps)                  # fallback (IoU sweep)

# The AP arrays above hold one entry per class listed in `ap_class_index`,
# which can be fewer than the dataset's classes. Scatter them into a
# full-length vector so a class missing from the split contributes 0 instead
# of triggering the length check below.
per_class_ap50 = np.atleast_1d(np.asarray(per_class_ap50, dtype=float))
seen_classes = getattr(box_metrics, "ap_class_index", None)
if (
    per_class_ap50.shape[0] != class_weights.shape[0]
    and seen_classes is not None
    and len(seen_classes) == per_class_ap50.shape[0]
):
    aligned_ap50 = np.zeros(class_weights.shape[0], dtype=float)
    aligned_ap50[np.asarray(seen_classes, dtype=int)] = per_class_ap50
    per_class_ap50 = aligned_ap50

# Class-weighted mAP@50
if per_class_ap50.shape[0] != class_weights.shape[0]:
    raise ValueError(
        f"Mismatch between per‚Äëclass AP array ({per_class_ap50.shape[0]}) "
        f"and class‚Äëweight array ({class_weights.shape[0]}). "
        "Check the metric extraction logic."
    )
weighted_map50 = float(np.dot(per_class_ap50, class_weights))

# Custom AP_S @ 0.50 computed directly from predictions
# This routine scans the test images, collects all ground-truth and predicted
# boxes whose area is < 32^2 pixels (COCO definition of "small"), and then
# computes AP at IoU 0.50 from scratch.

import PIL.Image as Image

def _iou_one_to_many(box, boxes):
    """Return the IoU of one xyxy `box` against every row of the (N, 4) xyxy array `boxes`."""
    inter_w = np.maximum(0, np.minimum(box[2], boxes[:, 2]) - np.maximum(box[0], boxes[:, 0]))
    inter_h = np.maximum(0, np.minimum(box[3], boxes[:, 3]) - np.maximum(box[1], boxes[:, 1]))
    inter = inter_w * inter_h
    union = (
        (box[2] - box[0]) * (box[3] - box[1])
        + (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
        - inter
    )
    # Epsilon keeps the ratio finite for degenerate (zero-area) box pairs
    return inter / (union + 1e-6)


def _eleven_point_ap(tp_sorted, num_gt):
    """Return the 11-point interpolated AP for TP flags sorted by descending score.

    `num_gt` must be > 0 and `tp_sorted` non-empty; the caller guarantees both.
    """
    tp_cum = np.cumsum(tp_sorted)
    fp_cum = np.cumsum(1 - tp_sorted)
    # No epsilon in the denominators: num_gt > 0 and tp_cum + fp_cum >= 1 here.
    # An epsilon would keep recall strictly below 1.0 so the t = 1.0 point could
    # never be reached, capping a perfect detector at 10/11.
    recalls = tp_cum / num_gt
    precisions = tp_cum / (tp_cum + fp_cum)

    ap = 0.0
    for t in np.linspace(0, 1, 11):
        # Small tolerance: linspace yields e.g. 0.30000000000000004, which an
        # exact recall of 3/10 would otherwise fail to reach.
        reached = recalls >= t - 1e-9
        if reached.any():
            ap += precisions[reached].max()
    return float(ap / 11)


def compute_ap_s50(model, test_root, label_root, iou_thr=0.50, area_thr=32**2):
    """Return the class-agnostic AP for 'small' objects (area < area_thr) at a fixed IoU.

    Args:
        model: Callable detector; ``model(path, conf=..., iou=..., verbose=...)[0]``
            must expose ``.boxes`` with ``.shape``, ``.xyxy`` and ``.conf``.
        test_root: ``pathlib.Path`` of the image directory (.jpg/.jpeg/.png are used).
        label_root: ``pathlib.Path`` of the directory holding one
            ``<image stem>.txt`` label file per image, each non-blank line being
            ``class cx cy w h`` with coordinates normalised to the image size.
        iou_thr: IoU a prediction needs with a ground-truth box to count as a TP.
        area_thr: Pixel-area limit; only boxes strictly smaller are considered.

    Returns:
        float: 11-point interpolated AP (VOC 2007 style); ``nan`` when there is
        no small ground-truth box at all; ``0.0`` when there are small
        ground-truth boxes but no small predictions.

    Raises:
        ValueError: if a non-blank label line does not have exactly five fields.

    Known limitations (unchanged from the original logic):
        * Class ids are ignored when matching.
        * Images without a label file, or without any small ground-truth box,
          are skipped before prediction, so small predictions on those images
          are never counted as false positives.
        * Small predictions are matched against small ground-truth boxes only.

    Relies on the file-level ``import PIL.Image as Image``.
    """
    all_scores, all_tp = [], []
    num_gt = 0

    # Iterate through every image in the split
    for img_path in sorted(test_root.glob('*')):
        if img_path.suffix.lower() not in {'.jpg', '.jpeg', '.png'}:
            continue

        lbl_file = label_root / f"{img_path.stem}.txt"
        if not lbl_file.exists():
            continue

        # Context manager so the image file handle is closed right away
        with Image.open(img_path) as img:
            w, h = img.size

        # Load the small GT boxes, converting normalised cxcywh to pixel xyxy
        gts = []
        for line in lbl_file.read_text().splitlines():
            fields = line.split()
            if not fields:
                continue  # tolerate blank / trailing empty lines
            _, cx, cy, bw, bh = map(float, fields)
            x1 = (cx - bw / 2) * w
            y1 = (cy - bh / 2) * h
            x2 = (cx + bw / 2) * w
            y2 = (cy + bh / 2) * h
            if (x2 - x1) * (y2 - y1) < area_thr:
                gts.append([x1, y1, x2, y2])
        num_gt += len(gts)
        if len(gts) == 0:
            continue
        gts = np.array(gts)

        # Run prediction with a low confidence floor to capture all candidates.
        # NOTE(review): Ultralytics documents the `iou` inference argument as the
        # NMS threshold; the matching threshold is passed for it here exactly as
        # before. `verbose=False` only silences the per-image log lines.
        preds = model(img_path, conf=0.001, iou=iou_thr, verbose=False)[0]
        if preds.boxes.shape[0] == 0:
            continue
        boxes = preds.boxes.xyxy.cpu().numpy()
        scores = preds.boxes.conf.cpu().numpy()

        # Keep only predictions whose box area is small
        areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
        mask_small = areas < area_thr
        boxes = boxes[mask_small]
        scores = scores[mask_small]
        if boxes.shape[0] == 0:
            continue

        # Highest confidence first; stable sort keeps tie order deterministic
        order = np.argsort(-scores, kind="stable")
        boxes = boxes[order]
        scores = scores[order]

        # Greedy matching: each prediction is compared with its best-IoU GT box
        # and is a TP only if that box is still unmatched.
        matched = np.zeros(len(gts), dtype=bool)
        for box, sc in zip(boxes, scores):
            ious = _iou_one_to_many(box, gts)
            best_idx = np.argmax(ious)
            if ious[best_idx] >= iou_thr and not matched[best_idx]:
                all_tp.append(1)
                matched[best_idx] = True
            else:
                all_tp.append(0)
            all_scores.append(sc)

    if num_gt == 0:
        return float('nan')

    # Compute precision-recall and AP over all images jointly
    all_scores = np.array(all_scores)
    all_tp = np.array(all_tp)
    if all_scores.size == 0:
        return 0.0
    ranking = np.argsort(-all_scores, kind="stable")

    # 11-point interpolation (VOC 2007 style, sufficient for single IoU)
    return _eleven_point_ap(all_tp[ranking], num_gt)

# AP for small objects, computed by the custom routine defined above
ap_s = compute_ap_s50(model, test_root, label_root)


# 5. Present the results
# Labels use plain ASCII hyphens so they render correctly regardless of the
# notebook / terminal encoding.
print("\n================ Evaluation Metrics ================\n")
print(f"{'Metric':35s} | {'Value':>9s}")
print("-" * 49)
print(f"{'Box precision (global)':35s} | {box_precision:9.4f}")
print(f"{'Class-weighted mAP@50':35s} | {weighted_map50:9.4f}")
print(f"{'AP_S  (small objects, IoU = 0.50)':35s} | {ap_s:9.4f}")
print("-" * 49 + "\n")

# One row per class id: its AP@50 next to the number of test instances
print("Per-class AP@50 and instance counts")
print("------------------------------------")
for cls_id, (ap, cnt) in enumerate(zip(per_class_ap50, class_counts)):
    print(f"Class {cls_id:2d} | AP50: {ap:6.3f} | instances in test: {cnt}")


Using split from: /content/drive/MyDrive/Datasets/test/images
Ultralytics 8.3.145 🚀 Python-3.11.12 torch-2.6.0+cu124 CUDA:0 (NVIDIA L4, 22693MiB)
deep_multiscale_YOLO summary (fused): 96 layers, 15,753,981 parameters, 0 gradients, 51.4 GFLOPs
[34m[1mval: [0mFast image access ‚úÖ (ping: 0.3¬±0.1 ms, read: 37.5¬±3.9 MB/s, size: 66.8 KB)


[34m[1mval: [0mScanning /content/drive/MyDrive/Datasets/test/labels.cache... 31 images, 0 backgrounds, 0 corrupt: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 31/31 [00:00<?, ?it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2/2 [00:01<00:00,  1.49it/s]


                   all         31        119      0.858      0.707      0.812      0.522
Speed: 2.6ms preprocess, 8.8ms inference, 0.0ms loss, 1.2ms postprocess per image
Results saved to [1mruns/detect/val22[0m

image 1/1 /content/drive/MyDrive/Datasets/test/images/IMG_0233_JPG.rf.b27c3e2dd843cab5e0f652fd1ce5659a.jpg: 640x640 4 early_blights, 10 late_blights, 36 spider_mitess, 4 target_spots, 19 mosaic_viruss, 10.8ms
Speed: 2.4ms preprocess, 10.8ms inference, 1.8ms postprocess per image at shape (1, 3, 640, 640)

image 1/1 /content/drive/MyDrive/Datasets/test/images/IMG_0249_JPG.rf.412df0b52b549fc121a346eb8a957ab0.jpg: 640x640 1 early_blight, 10 late_blights, 1 leaf_mold, 5 septoria_leaf_spots, 2 spider_mitess, 2 target_spots, 5 mosaic_viruss, 10.7ms
Speed: 2.5ms preprocess, 10.7ms inference, 1.8ms postprocess per image at shape (1, 3, 640, 640)

image 1/1 /content/drive/MyDrive/Datasets/test/images/IMG_0277_JPG.rf.84073f240decbc79e30716373f3dd507.jpg: 640x640 8 late_blights, 3 leaf

MODEL 2
