##  Enhancing MoE Performance with Test-Time Augmentation (TTA)

In this notebook, we explore the use of **Test-Time Augmentation (TTA)** to further improve the performance of our **Mixture of Experts (MoE)** model for fire detection.

###  Objective:
To boost the robustness and accuracy of the MoE system by applying augmentations during inference and aggregating predictions to form a more reliable output.

###  What is TTA?
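**Test-Time Augmentation** involves applying a set of transformations to the input image at inference time. These can be geometric (e.g., horizontal flip, scaling, rotation) or photometric (e.g., brightness or contrast changes). The model generates predictions on each augmented version. Predictions from geometric transforms (such as a flip) must first be mapped back into the original image's coordinates. The predictions are then combined — typically via averaging or ensembling — to yield a final result.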



###  This Notebook Covers:
- Integration of TTA within the MoE pipeline.
- Running inference using multiple augmented views of the input.
- Fusing predictions using **Weighted Box Fusion (WBF)** to obtain the final output.





In [None]:
import os
import tempfile

import numpy as np
import torch
from ensemble_boxes import weighted_boxes_fusion
from PIL import Image
from PIL import ImageEnhance
def horizontal_flip(image):
    """Mirror ``image`` left-to-right and return the resulting PIL image."""
    flip_mode = Image.FLIP_LEFT_RIGHT
    mirrored = image.transpose(flip_mode)
    return mirrored

def adjust_brightness(image, factor=1.2):
    """Return a copy of ``image`` with its brightness scaled by ``factor``.

    Per Pillow's ImageEnhance documentation, a factor of 1.0 yields the
    original image and values above 1.0 brighten it.
    """
    return ImageEnhance.Brightness(image).enhance(factor)

def adjust_contrast(image, factor=1.2):
    """Return a copy of ``image`` with its contrast scaled by ``factor``.

    Per Pillow's ImageEnhance documentation, a factor of 1.0 yields the
    original image and values above 1.0 increase contrast.
    """
    return ImageEnhance.Contrast(image).enhance(factor)

def _unflip_normalized_boxes(norm_boxes):
    """Mirror normalized xyxy boxes from a left-right flipped image back to the original frame."""
    # A horizontal flip maps x -> 1 - x; swapping which edge feeds x1/x2 keeps x1 <= x2.
    restored = norm_boxes.copy()
    restored[:, 0] = 1.0 - norm_boxes[:, 2]
    restored[:, 2] = 1.0 - norm_boxes[:, 0]
    return restored


def run_moe_with_tta(image_path, conf_threshold=0.3, iou_threshold=0.1):
    """Run ``run_moe`` on several augmented views of an image and fuse them with WBF.

    Views used: the original image, a horizontal flip, a brightness boost (x1.2)
    and a contrast boost (x1.2). Boxes predicted on the flipped view are mirrored
    back into the original image's coordinates before fusion.

    Args:
        image_path: Path to the input image.
        conf_threshold: Forwarded to ``run_moe`` and used as WBF's ``skip_box_thr``.
        iou_threshold: Forwarded to ``run_moe`` and used as WBF's ``iou_thr``.

    Returns:
        A one-element list holding an ultralytics ``Boxes`` whose rows are
        ``(x1, y1, x2, y2, conf, cls)`` in pixel coordinates, or ``[]`` when
        ``run_moe`` returned nothing for every view.

    NOTE(review): assumes ``run_moe(path, ...)`` returns a list whose first
    element exposes ``xyxy`` / ``conf`` / ``cls`` tensors (as used below).
    """
    original_img = Image.open(image_path).convert("RGB")
    img_w, img_h = original_img.size

    # (name, image, is_hflipped): flipped views need their boxes mirrored back.
    augmentations = [
        ("original", original_img, False),
        ("hflip", horizontal_flip(original_img), True),
        ("bright", adjust_brightness(original_img, factor=1.2), False),
        ("contrast", adjust_contrast(original_img, factor=1.2), False),
    ]

    all_boxes, all_scores, all_labels = [], [], []

    # A private temp dir avoids clobbering/colliding files in the CWD and is
    # removed even if run_moe raises.
    with tempfile.TemporaryDirectory() as tmp_dir:
        for aug_name, aug_img, is_hflipped in augmentations:
            # PNG is lossless, so the "original" view really sees the original pixels.
            temp_path = os.path.join(tmp_dir, f"_temp_{aug_name}.png")
            aug_img.save(temp_path)

            aug_boxes = run_moe(temp_path, conf_threshold=conf_threshold, iou_threshold=iou_threshold)
            if not aug_boxes:
                continue

            boxes = aug_boxes[0].xyxy.cpu().numpy()
            scores = aug_boxes[0].conf.cpu().numpy()
            labels = aug_boxes[0].cls.cpu().numpy()

            # WBF expects coordinates normalized to [0, 1]. All views share the
            # original image size, so the same scale applies to each.
            norm_boxes = boxes.copy()
            norm_boxes[:, [0, 2]] /= img_w
            norm_boxes[:, [1, 3]] /= img_h
            if is_hflipped:
                norm_boxes = _unflip_normalized_boxes(norm_boxes)
            norm_boxes = np.clip(norm_boxes, 0.0, 1.0)

            all_boxes.append(norm_boxes.tolist())
            all_scores.append(scores.tolist())
            all_labels.append(labels.tolist())

    if not all_boxes:
        return []

    fused_boxes, fused_scores, fused_labels = weighted_boxes_fusion(
        all_boxes, all_scores, all_labels,
        iou_thr=iou_threshold,
        skip_box_thr=conf_threshold,
    )

    # Rescale back to pixel coordinates; reshape keeps an empty result as (0, 4).
    fused_boxes = np.array(fused_boxes, dtype=np.float64).reshape(-1, 4)
    fused_boxes[:, [0, 2]] *= img_w
    fused_boxes[:, [1, 3]] *= img_h

    final_tensor = torch.tensor(np.hstack([
        fused_boxes,                               # x1, y1, x2, y2
        np.array(fused_scores).reshape(-1, 1),     # confidence
        np.array(fused_labels).reshape(-1, 1),     # class (fire=0)
    ]), dtype=torch.float32)

    from ultralytics.engine.results import Boxes
    kept_boxes = Boxes(final_tensor, orig_shape=original_img.size[::-1])  # (height, width)

    return [kept_boxes]