<a href="https://colab.research.google.com/github/Achille1912/Achille1912/blob/main/Progetto.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install kagglehub torchmetrics
!pip install albumentations

# NOTE(review): the captured install log below shows mmdet 3.3.0 being installed
# together with mmcv-full 1.7.2. mmcv-full 1.x targets the mmdet 2.x generation,
# while mmdet 3.x is built against the `mmcv` (>=2.0) package. Confirm a
# compatible pair, and pin versions so Restart & Run All is reproducible.
!pip install -U openmim
!mim install mmdet
!mim install mmcv-full
# NOTE(review): an unpinned torchvision install may pull a build that does not
# match the preinstalled torch. Consider pinning it or dropping this line.
!pip install torchvision


Looking in links: https://download.openmmlab.com/mmcv/dist/cu124/torch2.6.0/index.html
Collecting mmdet
  Using cached mmdet-3.3.0-py3-none-any.whl.metadata (29 kB)
Ignoring mmcv: markers 'extra == "mim"' don't match your environment
Ignoring mmengine: markers 'extra == "mim"' don't match your environment
Collecting terminaltables (from mmdet)
  Using cached terminaltables-3.1.10-py2.py3-none-any.whl.metadata (3.5 kB)
Using cached mmdet-3.3.0-py3-none-any.whl (2.2 MB)
Using cached terminaltables-3.1.10-py2.py3-none-any.whl (15 kB)
Installing collected packages: terminaltables, mmdet
Successfully installed mmdet-3.3.0 terminaltables-3.1.10
Looking in links: https://download.openmmlab.com/mmcv/dist/cu124/torch2.6.0/index.html
Collecting mmcv-full
  Downloading mmcv-full-1.7.2.tar.gz (607 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m607.9/607.9 kB[0m [31m9.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting addict (

## IMPORTS

In [None]:
import os
import json
import torch
import torch.nn as nn
import torchvision
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image, ImageEnhance
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np

from torchmetrics.detection import MeanAveragePrecision
import kagglehub

import os
import sys
from datetime import datetime
from sklearn.model_selection import train_test_split

torch.cuda.empty_cache()


## CONFIG

In [None]:
# ==== CONFIG ====
class Config:
    """Global experiment settings, read everywhere as class attributes (``Config.X``).

    NOTE: the class body has side effects that run when this cell executes:
    it downloads the Kaggle dataset through ``kagglehub`` and seeds PyTorch's
    CPU and CUDA random number generators.
    """
    # Local directory returned by kagglehub for the downloaded dataset.
    path = kagglehub.dataset_download("orvile/p-vivax-malaria-infected-human-blood-smears")
    BASE_DIR = path + "/malaria/"
    IMG_DIR = os.path.join(BASE_DIR, "images")
    TRAIN_JSON = os.path.join(BASE_DIR, "training.json")
    TEST_JSON = os.path.join(BASE_DIR, "test.json")

    # 7 foreground classes + background.
    # NOTE(review): create_model() is called without arguments (default
    # num_classes=2), so this value is not used by training or evaluation.
    NUM_CLASSES = 8  # 7 + background
    MODEL_SAVE_PATH = "malaria_detection.pth"  # checkpoint written on best validation loss

    # Reproducibility: seeds torch (CPU + CUDA) and forces deterministic cuDNN kernels.
    # NOTE(review): Python's `random` and NumPy are not seeded here, so the
    # Albumentations augmentations may not be reproducible. Confirm.
    RANDOM_SEED = 42
    torch.manual_seed(RANDOM_SEED)
    torch.cuda.manual_seed(RANDOM_SEED)
    torch.backends.cudnn.deterministic = True

    BATCH_SIZE = 4 if torch.cuda.is_available() else 2  # smaller batches on CPU
    NUM_EPOCHS = 5
    LEARNING_RATE = 0.0003
    WEIGHT_DECAY = 0.001
    # Early-stopping patience in epochs without val-loss improvement.
    # NOTE(review): equal to NUM_EPOCHS, so early stopping cannot trigger.
    PATIENCE = 5

    # Square target size, so the (height, width) order used by A.Resize and the
    # (width, height) order used by PIL's Image.resize agree.
    IMG_SIZE = (1024, 1024)

    # One of: FASTER_RCNN_RESNET50, FASTER_RCNN_RESNET101, FASTER_RCNN_RESNEXT101,
    # FASTER_RCNN_MOBILENET, RETINANET_RESNET50, REPPOINTS_RESNET50
    MODEL = "REPPOINTS_RESNET50"
    # "FocalLoss" patches the Faster R-CNN box-head loss; any other value
    # (e.g. "Standard") keeps the default loss. Ignored by RetinaNet / RepPoints.
    Loss = "FocalLoss"

    Optim = "AdamW" # "Adam" or "AdamW"
    Scheduler = "ReduceLROnPlateau" # "ReduceLROnPlateau" or "CosineAnnealingLR"
    val_size = 0.2  # fraction of the training JSON held out for validation
    NUM_WORKERS = 2  # DataLoader worker processes
    USE_AMP = torch.cuda.is_available()  # mixed-precision autocast only on CUDA

### DEVICE SETUP AND CLASS LABELS

In [None]:
# Prefer the GPU when one is visible to PyTorch.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Foreground categories as named in the annotation JSON.
CLASSES = [
    "red blood cell",
    "trophozoite",
    "ring",
    "difficult",
    "schizont",
    "gametocyte",
    "leukocyte",
]
# Label 0 is reserved for background, so foreground ids start at 1.
CLASS_TO_IDX = {name: label for label, name in enumerate(CLASSES, start=1)}

## DATASET

In [None]:
class MalariaDataset(Dataset):
    """Object-detection dataset for the P. vivax blood-smear annotations.

    Each JSON entry carries an ``image`` dict (file name taken from
    ``pathname``, else ``filename``, else ``id``) and an optional ``objects``
    list. Each object has a ``category`` and a ``bounding_box`` whose
    ``minimum``/``maximum`` corners are given as ``c`` (column -> x) and
    ``r`` (row -> y). Objects whose category is not in ``CLASS_TO_IDX`` are skipped.

    Args:
        json_path: Path to the annotation JSON (a list of entries).
        img_dir: Directory holding the image files.
        transform: Optional Albumentations pipeline, called as
            ``transform(image=..., bboxes=..., labels=...)`` with Pascal VOC boxes.
        resize: Only used when ``transform`` is None. Resizes the image to
            ``Config.IMG_SIZE`` and rescales the boxes to match.

    Each item is ``(image_tensor, {"boxes": float32 [N, 4], "labels": int64 [N]})``.
    Images without objects yield ``boxes`` of shape ``[0, 4]``.
    """

    def __init__(self, json_path, img_dir, transform=None, resize=True):
        with open(json_path) as f:
            self.data = json.load(f)
        self.img_dir = img_dir
        self.transform = transform
        self.resize = resize

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _image_name(entry):
        """Return the image file name, preferring ``pathname``, then ``filename``, then ``id``."""
        meta = entry['image']
        return os.path.basename(meta.get('pathname', meta.get('filename', meta.get('id', ''))))

    @staticmethod
    def _parse_objects(entry):
        """Return ``(boxes [N, 4] float32, labels [N] int64)`` for the known categories."""
        boxes, labels = [], []
        for obj in entry.get('objects', []):
            if obj['category'] not in CLASS_TO_IDX:
                continue
            bbox = obj['bounding_box']
            boxes.append([bbox['minimum']['c'], bbox['minimum']['r'],
                          bbox['maximum']['c'], bbox['maximum']['r']])
            labels.append(CLASS_TO_IDX[obj['category']])
        # reshape keeps the [0, 4] shape for images without objects.
        return (np.asarray(boxes, dtype=np.float32).reshape(-1, 4),
                np.asarray(labels, dtype=np.int64).reshape(-1))

    def __getitem__(self, idx):
        entry = self.data[idx]
        img_path = os.path.join(self.img_dir, self._image_name(entry))

        # Load the image and boost its contrast before any other processing.
        img = Image.open(img_path).convert("RGB")
        img = ImageEnhance.Contrast(img).enhance(1.5)
        orig_width, orig_height = img.size

        boxes, labels = self._parse_objects(entry)

        if self.transform:
            # Albumentations works on NumPy HWC images.
            transformed = self.transform(image=np.array(img), bboxes=boxes, labels=labels)
            img_tensor = transformed['image']
            boxes = transformed['bboxes']
            labels = transformed['labels']
        else:
            if self.resize:
                scale_x = Config.IMG_SIZE[0] / orig_width
                scale_y = Config.IMG_SIZE[1] / orig_height
                # Resize the image that is actually tensorised, so that it
                # matches the rescaled boxes.
                img = img.resize(Config.IMG_SIZE)
                boxes = boxes * np.array([scale_x, scale_y, scale_x, scale_y], dtype=np.float32)
            img_tensor = transforms.ToTensor()(img)
            img_tensor = transforms.Normalize(mean=[0.7205, 0.7203, 0.7649],
                                              std=[0.2195, 0.2277, 0.1588])(img_tensor)

        target = {
            "boxes": torch.as_tensor(np.asarray(boxes, dtype=np.float32).reshape(-1, 4)),
            "labels": torch.as_tensor(np.asarray(labels, dtype=np.int64).reshape(-1)),
        }
        return img_tensor, target


## TRANSFORMATION

In [None]:
# import os
# from PIL import Image
# import numpy as np
# from tqdm import tqdm
# import torch

# from torchvision import transforms

# # === CONFIG ===
# IMG_DIR = Config.IMG_DIR

# # === Trasformazione base ===
# to_tensor = transforms.ToTensor()  # converte a [C, H, W] in [0, 1]

# # === Liste per accumulare pixel per canale ===
# mean = torch.zeros(3)
# std = torch.zeros(3)
# n_images = 0

# print("Inizio calcolo media/std da immagini...")

# for fname in tqdm(os.listdir(IMG_DIR)):
#     if not fname.lower().endswith((".png", ".jpg", ".jpeg")):
#         continue  # ignora file non immagine

#     img_path = os.path.join(IMG_DIR, fname)
#     img = Image.open(img_path).convert("RGB")
#     tensor = to_tensor(img)  # shape: [3, H, W]

#     mean += tensor.mean(dim=(1, 2))
#     std += tensor.std(dim=(1, 2))
#     n_images += 1

# mean /= n_images
# std /= n_images

# print(f"\nMedia RGB: {mean}")
# print(f"Deviazione standard RGB: {std}")


In [None]:
# ==== TRASFORMAZIONI ====
import albumentations as A
from albumentations.pytorch import ToTensorV2
def get_transforms(train=True):
    """Build the Albumentations pipeline for training (``train=True``) or evaluation.

    Both pipelines resize to ``Config.IMG_SIZE``, normalise with the dataset's
    per-channel mean/std and convert to a tensor. The training pipeline first
    applies a random horizontal flip and CLAHE. Boxes are Pascal VOC
    ``[xmin, ymin, xmax, ymax]`` and their labels travel in the ``labels`` field.
    """
    steps = []
    if train:
        # Disabled experiments: A.Rotate(limit=5, border_mode=1, p=0.3),
        # A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.3)
        steps += [
            A.HorizontalFlip(p=0.5),
            A.CLAHE(clip_limit=1.5, p=0.3),
        ]
    steps += [
        A.Resize(*Config.IMG_SIZE),
        A.Normalize(mean=[0.7205, 0.7203, 0.7649],
                    std=[0.2195, 0.2277, 0.1588]),
        ToTensorV2(),
    ]
    return A.Compose(
        steps,
        bbox_params=A.BboxParams(format='pascal_voc', label_fields=['labels']),
    )


## MODEL

### Focal Loss

In [None]:
import torch
import torch.nn.functional as F
from torchvision.models.detection.roi_heads import fastrcnn_loss as original_fastrcnn_loss

# Focal Loss
class FocalLoss(nn.Module):
    """Multi-class focal loss built on per-sample cross-entropy.

    Each element is ``alpha * (1 - p_t) ** gamma * CE``, where ``p_t = exp(-CE)``
    is the predicted probability of the true class.

    Args:
        alpha: Constant scale applied to every element.
        gamma: Focusing exponent; larger values down-weight easy examples more.
        reduction: ``'mean'``, ``'sum'``, or anything else for the per-element loss.
    """

    def __init__(self, alpha=0.25, gamma=4.0, reduction='mean'):
        super().__init__()
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction

    def forward(self, inputs, targets):
        per_sample_ce = F.cross_entropy(inputs, targets, reduction='none')
        true_class_prob = torch.exp(-per_sample_ce)
        modulating = (1 - true_class_prob) ** self.gamma
        loss = self.alpha * modulating * per_sample_ce

        if self.reduction == 'sum':
            return loss.sum()
        if self.reduction == 'mean':
            return loss.mean()
        return loss

# Custom Fast R-CNN box-head loss using focal loss for classification.
def custom_fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
    """Fast R-CNN box-head loss with focal loss in place of cross-entropy.

    Accepts the same arguments torchvision's ``fastrcnn_loss`` receives:
    ``labels`` and ``regression_targets`` may be per-image lists (concatenated
    here) or already-concatenated tensors. ``box_regression`` may hold either
    one 4-vector per proposal ``[N, 4]`` or class-specific deltas
    ``[N, num_classes * 4]``. In the class-specific case, each proposal's
    ground-truth class is selected.

    Returns:
        ``(classification_loss, box_loss)`` scalar tensors.
    """
    if isinstance(labels, (list, tuple)):
        labels = torch.cat(labels, dim=0)
    if isinstance(regression_targets, (list, tuple)):
        regression_targets = torch.cat(regression_targets, dim=0)

    # Focal loss for classification.
    classification_loss = FocalLoss()(class_logits, labels)

    # Only foreground proposals (label > 0) contribute to box regression.
    pos_inds = torch.where(labels > 0)[0]
    num_proposals = class_logits.shape[0]
    if box_regression.shape[-1] == 4:
        pos_box_regression = box_regression.reshape(num_proposals, 4)[pos_inds]
    else:
        # Class-specific deltas: keep the 4 values of each proposal's GT class
        # so they line up with the [N, 4] regression targets.
        per_class = box_regression.reshape(num_proposals, box_regression.shape[-1] // 4, 4)
        pos_box_regression = per_class[pos_inds, labels[pos_inds]]

    # NOTE(review): torchvision's own fastrcnn_loss uses beta=1/9; confirm 1.0 is intended.
    box_loss = F.smooth_l1_loss(
        pos_box_regression,
        regression_targets[pos_inds],
        beta=1.0, reduction='sum'
    ) / max(1, labels.numel())

    return classification_loss, box_loss


In [None]:
from torchvision.models.detection import (
    fasterrcnn_resnet50_fpn,
    fasterrcnn_mobilenet_v3_large_fpn,
    FasterRCNN,
    retinanet_resnet50_fpn
)
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models import resnet101
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.models.detection.backbone_utils import BackboneWithFPN
from torchvision.models import resnext101_32x8d



# Focal loss patch (Faster R-CNN only).
def replace_fastrcnn_loss_with_focal(model):
    """Route the Faster R-CNN box-head loss through ``custom_fastrcnn_loss``.

    torchvision's ``RoIHeads.forward`` calls the module-level function
    ``torchvision.models.detection.roi_heads.fastrcnn_loss``, not a method
    on the instance. Assigning ``model.roi_heads.fastrcnn_loss`` therefore
    has no effect, so this patches the module-level name instead.

    The installed wrapper concatenates the per-image ``labels`` /
    ``regression_targets`` lists and selects each proposal's own-class box
    deltas (``[N, 4]``) before delegating to ``custom_fastrcnn_loss``.

    NOTE: the patch is process-wide. Every torchvision RoIHeads in this
    kernel, including models built later with a different ``Config.Loss``,
    uses it until the kernel restarts.

    Args:
        model: Unused beyond documenting intent; kept for interface compatibility.
    """
    import torchvision.models.detection.roi_heads as roi_heads_module

    def focal_fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
        if isinstance(labels, (list, tuple)):
            labels = torch.cat(labels, dim=0)
        if isinstance(regression_targets, (list, tuple)):
            regression_targets = torch.cat(regression_targets, dim=0)
        num_proposals = class_logits.shape[0]
        # [N, num_classes * 4] -> [N, num_classes, 4] -> each proposal's own class.
        per_class = box_regression.reshape(num_proposals, -1, 4)
        own_class = per_class[torch.arange(num_proposals, device=labels.device), labels]
        return custom_fastrcnn_loss(class_logits, own_class, labels, regression_targets)

    roi_heads_module.fastrcnn_loss = focal_fastrcnn_loss


def _fpn_from_resnet(backbone):
    """Wrap a bottleneck ResNet/ResNeXt classifier in an FPN over layer1-layer4 (256 channels)."""
    return_layers = {
        'layer1': '0',
        'layer2': '1',
        'layer3': '2',
        'layer4': '3',
    }
    in_channels_stage2 = 256  # layer1 output channels; each later stage doubles it
    in_channels_list = [in_channels_stage2 * 2 ** stage for stage in range(4)]
    # BackboneWithFPN builds its own IntermediateLayerGetter from the raw
    # classifier. Its first parameter is `backbone` (there is no `body` kwarg).
    return BackboneWithFPN(
        backbone,
        return_layers=return_layers,
        in_channels_list=in_channels_list,
        out_channels=256,
    )


# Build an ImageNet-pretrained ResNet-101 backbone with FPN.
def build_resnet101_backbone():
    """Return an FPN backbone on top of an ImageNet-pretrained ResNet-101."""
    return _fpn_from_resnet(resnet101(weights="DEFAULT"))


def build_resnext101_backbone():
    """Return an FPN backbone on top of an ImageNet-pretrained ResNeXt-101 32x8d."""
    return _fpn_from_resnet(resnext101_32x8d(weights="DEFAULT"))





# Build the detection model with a configurable backbone.
def create_model(num_classes=2):
    """Build the detection model selected by ``Config.MODEL``.

    Args:
        num_classes: Number of output classes including background. The
            default of 2 matches the training loop in this notebook, which
            relabels every object as class 1. NOTE(review): ``Config.NUM_CLASSES``
            is not used here; pass it explicitly for per-class detection.

    Returns:
        A torchvision detection model, or an mmdet detector for ``REPPOINTS_RESNET50``.

    Raises:
        ValueError: if ``Config.MODEL`` is not a supported name.
    """
    if Config.MODEL == "RETINANET_RESNET50":
        import math

        model = retinanet_resnet50_fpn(weights="DEFAULT")
        cls_head = model.head.classification_head
        in_features = cls_head.cls_logits.in_channels
        num_anchors = cls_head.num_anchors
        cls_head.cls_logits = torch.nn.Conv2d(
            in_features, num_anchors * num_classes, kernel_size=3, stride=1, padding=1
        )
        # Re-apply RetinaNet's prior-probability init (p=0.01) so the focal
        # loss does not start from a flood of confident false positives.
        prior_probability = 0.01
        torch.nn.init.normal_(cls_head.cls_logits.weight, std=0.01)
        torch.nn.init.constant_(cls_head.cls_logits.bias,
                                -math.log((1 - prior_probability) / prior_probability))
        cls_head.num_classes = num_classes
        return model

    if Config.MODEL == "REPPOINTS_RESNET50":
        import urllib.request  # not imported at file level
        from mmdet.apis import init_detector

        # NOTE(review): the config comes from the mmdetection *master* branch
        # while the checkpoint is a v2.0 release. The returned mmdet detector
        # also does not follow torchvision's `model(images, targets) -> loss dict`
        # contract used by train_model. Confirm this branch end to end.
        config_url = "https://raw.githubusercontent.com/open-mmlab/mmdetection/master/configs/reppoints/reppoints_moment_r50_fpn_1x_coco.py"
        checkpoint_url = "https://download.openmmlab.com/mmdetection/v2.0/reppoints/reppoints_moment_r50_fpn_1x_coco/reppoints_moment_r50_fpn_1x_coco_20200329-4f0f7cf9.pth"
        config_path = "configs/reppoints/reppoints_moment_r50_fpn_1x_coco.py"
        checkpoint_path = "reppoints_r50_fpn_1x.pth"

        if not os.path.exists(config_path):
            os.makedirs(os.path.dirname(config_path), exist_ok=True)
            urllib.request.urlretrieve(config_url, config_path)
        if not os.path.exists(checkpoint_path):
            urllib.request.urlretrieve(checkpoint_url, checkpoint_path)

        mmdet_device = 'cuda' if torch.cuda.is_available() else 'cpu'
        return init_detector(config_path, checkpoint_path, device=mmdet_device)

    # Faster R-CNN variants: build, then replace the box predictor below.
    faster_rcnn_builders = {
        "FASTER_RCNN_RESNET50": lambda: fasterrcnn_resnet50_fpn(weights="DEFAULT"),
        "FASTER_RCNN_RESNET101": lambda: FasterRCNN(
            backbone=build_resnet101_backbone(), num_classes=num_classes),
        "FASTER_RCNN_MOBILENET": lambda: fasterrcnn_mobilenet_v3_large_fpn(weights="DEFAULT"),
        "FASTER_RCNN_RESNEXT101": lambda: FasterRCNN(
            backbone=build_resnext101_backbone(), num_classes=num_classes),
    }
    if Config.MODEL not in faster_rcnn_builders:
        raise ValueError(f"Backbone '{Config.MODEL}' non supportato.")
    model = faster_rcnn_builders[Config.MODEL]()

    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    if Config.Loss == "FocalLoss":
        replace_fastrcnn_loss_with_focal(model)
    return model


## TRAIN

In [None]:
def _build_optimizer(model):
    """Create the optimizer named by ``Config.Optim`` over the trainable parameters."""
    optimizers = {"AdamW": torch.optim.AdamW, "Adam": torch.optim.Adam}
    if Config.Optim not in optimizers:
        raise ValueError(f"Unsupported optimizer {Config.Optim!r}; expected one of {list(optimizers)}")
    params = [p for p in model.parameters() if p.requires_grad]
    return optimizers[Config.Optim](params, lr=Config.LEARNING_RATE,
                                    weight_decay=Config.WEIGHT_DECAY)


def _build_scheduler(optimizer):
    """Create the LR scheduler named by ``Config.Scheduler``."""
    if Config.Scheduler == "ReduceLROnPlateau":
        return torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', patience=2, factor=0.5)
    if Config.Scheduler == "CosineAnnealingLR":
        return torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=Config.NUM_EPOCHS, eta_min=1e-6)
    raise ValueError(f"Unsupported scheduler {Config.Scheduler!r}")


def _collate_detection(batch):
    """Keep variable-sized detection samples as tuples instead of stacking them."""
    return tuple(zip(*batch))


def _build_loaders():
    """Split TRAIN_JSON into train/val DataLoaders; only the train split is augmented."""
    train_source = MalariaDataset(Config.TRAIN_JSON, Config.IMG_DIR,
                                  transform=get_transforms(train=True), resize=False)
    # Same annotations with deterministic eval transforms, so the validation
    # loss is not measured on randomly flipped / CLAHE-augmented images.
    val_source = MalariaDataset(Config.TRAIN_JSON, Config.IMG_DIR,
                                transform=get_transforms(train=False), resize=False)

    val_size = int(Config.val_size * len(train_source))
    train_size = len(train_source) - val_size
    if val_size == 0 or train_size == 0:
        raise ValueError(f"Config.val_size={Config.val_size} gives {train_size} training and "
                         f"{val_size} validation images; both splits must be non-empty.")
    train_dataset, val_split = torch.utils.data.random_split(
        train_source, [train_size, val_size],
        generator=torch.Generator().manual_seed(Config.RANDOM_SEED)
    )
    # Reuse exactly the same validation indices on the non-augmented copy.
    val_dataset = torch.utils.data.Subset(val_source, val_split.indices)

    train_loader = DataLoader(train_dataset, batch_size=Config.BATCH_SIZE,
                              shuffle=True, num_workers=Config.NUM_WORKERS,
                              collate_fn=_collate_detection)
    val_loader = DataLoader(val_dataset, batch_size=Config.BATCH_SIZE,
                            shuffle=False, num_workers=Config.NUM_WORKERS,
                            collate_fn=_collate_detection)
    return train_loader, val_loader


def _prepare_batch(images, targets):
    """Move a batch to ``device`` and relabel every object as foreground class 1."""
    images = [img.to(device) for img in images]
    targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
    for t in targets:
        # Class-agnostic training (model built with num_classes=2). Remove this
        # if the model is created with Config.NUM_CLASSES and the real labels.
        t['labels'] = torch.ones_like(t['labels'])
    return images, targets


def _validation_loss(model, val_loader):
    """Mean detection loss over ``val_loader`` without touching weights or BN statistics."""
    # torchvision detection models only return a loss dict in train mode, so
    # stay in train mode but freeze BatchNorm running statistics.
    model.train()
    for module in model.modules():
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)):
            module.eval()

    total = 0.0
    with torch.no_grad():
        for images, targets in val_loader:
            images, targets = _prepare_batch(images, targets)
            with torch.amp.autocast('cuda', enabled=Config.USE_AMP):
                loss_dict = model(images, targets)
                if not isinstance(loss_dict, dict):
                    raise ValueError("Model returned predictions instead of a loss dict during validation. Check model mode or target input.")
                total += sum(loss for loss in loss_dict.values()).item()
    return total / len(val_loader)


def train_model():
    """Train the configured detector with early stopping on validation loss.

    The best checkpoint (model + optimizer state, epoch, best loss) is saved
    to ``Config.MODEL_SAVE_PATH``.

    Returns:
        dict with per-epoch lists ``'train_loss'``, ``'val_loss'`` and ``'lr'``.

    Raises:
        ValueError: for an unknown optimizer/scheduler name or an empty split.
    """
    model = create_model().to(device)
    optimizer = _build_optimizer(model)
    scheduler = _build_scheduler(optimizer)
    train_loader, val_loader = _build_loaders()
    # Without loss scaling, float16 gradients under autocast can underflow to zero.
    scaler = torch.amp.GradScaler('cuda', enabled=Config.USE_AMP)

    best_loss = float('inf')
    patience_counter = 0
    history = {'train_loss': [], 'val_loss': [], 'lr': []}

    for epoch in range(Config.NUM_EPOCHS):
        # --- Training phase ---
        model.train()
        epoch_train_loss = 0.0
        for images, targets in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            images, targets = _prepare_batch(images, targets)
            optimizer.zero_grad()
            with torch.amp.autocast('cuda', enabled=Config.USE_AMP):
                loss_dict = model(images, targets)
                losses = sum(loss for loss in loss_dict.values())
            scaler.scale(losses).backward()
            scaler.step(optimizer)
            scaler.update()
            epoch_train_loss += losses.item()

        avg_train_loss = epoch_train_loss / len(train_loader)
        history['train_loss'].append(avg_train_loss)
        history['lr'].append(optimizer.param_groups[0]['lr'])

        # --- Validation phase ---
        print("\nStart Validation... 🔍")
        avg_val_loss = _validation_loss(model, val_loader)
        history['val_loss'].append(avg_val_loss)

        # Only ReduceLROnPlateau takes a metric; passing one to other schedulers
        # is interpreted as an epoch number.
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(avg_val_loss)
        else:
            scheduler.step()

        print(f"Epoch {epoch+1}/{Config.NUM_EPOCHS}")
        print(f"  Train Loss: {avg_train_loss:.4f}")
        print(f"  Val Loss: {avg_val_loss:.4f}")
        print(f"  LR: {optimizer.param_groups[0]['lr']:.2e}")

        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'epoch': epoch,
                'best_loss': best_loss,
            }, Config.MODEL_SAVE_PATH)
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= Config.PATIENCE:
                print(f"\nEarly stopping at epoch {epoch+1}")
                break

    return history


## TEST AND VISUALIZATION

In [None]:
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from sklearn.metrics import average_precision_score
import matplotlib.patches as patches
import json
from datetime import datetime

class NpEncoder(json.JSONEncoder):
    """JSON encoder that also accepts NumPy values and PyTorch tensors.

    NumPy booleans/integers/floats become Python scalars and arrays become
    nested lists. Tensors become a Python scalar when they hold one element,
    otherwise a nested list (torchmetrics results are tensors). Anything else
    falls back to the default encoder, which raises ``TypeError``.
    """

    def default(self, obj):
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if torch.is_tensor(obj):
            return obj.item() if obj.numel() == 1 else obj.tolist()
        return super().default(obj)

def calculate_iou(box1, box2):
    """Intersection-over-Union of two ``[xmin, ymin, xmax, ymax]`` boxes.

    Returns 0.0 when the union area is not positive.
    """
    ax1, ay1, ax2, ay2 = box1[0], box1[1], box1[2], box1[3]
    bx1, by1, bx2, by2 = box2[0], box2[1], box2[2], box2[3]

    overlap_w = max(0, min(ax2, bx2) - max(ax1, bx1))
    overlap_h = max(0, min(ay2, by2) - max(ay1, by1))
    intersection = overlap_w * overlap_h

    union = (ax2 - ax1) * (ay2 - ay1) + (bx2 - bx1) * (by2 - by1) - intersection
    if union > 0:
        return intersection / union
    return 0.0

def compute_mean_iou_per_gt(pred_boxes, gt_boxes, iou_threshold=0.5):
    """Greedily match each ground-truth box to its best still-unused prediction.

    Ground-truth boxes are visited in order. Each takes the unused prediction
    with the highest IoU, and the pair is kept only when that IoU is at least
    ``iou_threshold``.

    Returns:
        ``(mean_iou, matched_pairs)``. ``matched_pairs`` is a list of
        ``(pred_idx, gt_idx, iou)``. ``mean_iou`` averages their IoUs and is
        0.0 when nothing matched or either input is empty.
    """
    if len(pred_boxes) == 0 or len(gt_boxes) == 0:
        return 0.0, []

    iou_matrix = [[calculate_iou(pb, gt) for gt in gt_boxes] for pb in pred_boxes]
    matched_pairs = []
    used_preds = set()

    for gt_idx in range(len(gt_boxes)):
        best_iou = 0.0
        best_pred_idx = None
        for pred_idx, ious in enumerate(iou_matrix):
            if pred_idx in used_preds:
                continue
            if ious[gt_idx] > best_iou:
                best_iou = ious[gt_idx]
                best_pred_idx = pred_idx
        # Require an actual candidate, so thresholds <= 0 cannot yield a pair
        # without any prediction behind it.
        if best_pred_idx is not None and best_iou >= iou_threshold:
            matched_pairs.append((best_pred_idx, gt_idx, best_iou))
            used_preds.add(best_pred_idx)

    mean_iou = np.mean([iou for _, _, iou in matched_pairs]) if matched_pairs else 0.0
    return mean_iou, matched_pairs

def evaluate_test_set(model, dataset, score_threshold=0.5, save_path=None):
    """Compute box-level precision/recall/F1 and mean IoU over a dataset.

    Predictions scoring below ``score_threshold`` are discarded. TP/FP/FN are
    accumulated over every image, so detections on images without ground
    truth count as false positives. mIoU averages the per-image mean IoU over
    images that have at least one ground-truth box.

    Args:
        model: torchvision-style detector returning dicts with 'boxes' and 'scores'.
        dataset: Iterable of ``(image tensor, target dict with 'boxes')``.
        score_threshold: Minimum prediction score to keep.
        save_path: If given, the results dict is also written there as JSON.

    Returns:
        dict with keys 'miou', 'precision', 'recall', 'f1', 'tp', 'fp', 'fn'.
    """
    model.eval()
    per_image_mious = []
    all_ious = []
    total_tp = total_fp = total_fn = 0

    for img, target in dataset:
        with torch.no_grad():
            prediction = model([img.to(device)])[0]

        pred_scores = prediction['scores'].cpu().numpy()
        pred_boxes = prediction['boxes'].cpu().numpy()[pred_scores >= score_threshold]
        gt_boxes = target['boxes'].cpu().numpy()

        *_, tp, fp, fn = compute_precision_recall_f1(pred_boxes, gt_boxes)
        total_tp += tp
        total_fp += fp
        total_fn += fn

        if len(gt_boxes) > 0:
            mean_iou, matched_pairs = compute_mean_iou_per_gt(pred_boxes, gt_boxes)
            per_image_mious.append(mean_iou)
            all_ious.extend(iou for _, _, iou in matched_pairs)

    mean_iou_all = float(np.mean(per_image_mious)) if per_image_mious else 0.0
    precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0.0
    recall    = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0.0
    f1_score  = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

    results = {
        'miou': mean_iou_all,
        'precision': precision,
        'recall': recall,
        'f1': f1_score,
        'tp': total_tp,
        'fp': total_fp,
        'fn': total_fn
    }

    print(f"\n=== [Test Set Evaluation] ===")
    print(f"📊 Metriche globali:")
    print(f"  - mIoU     : {results['miou']:.4f}")
    print(f"  - Precision: {results['precision']:.4f}")
    print(f"  - Recall   : {results['recall']:.4f}")
    print(f"  - F1 Score : {results['f1']:.4f}")

    print(f"\n🔢 Conteggi:")
    print(f"  - TP: {results['tp']} | FP: {results['fp']} | FN: {results['fn']}")

    if all_ious:
        print(f"\n📈 IoU stats:")
        print(f"  - min   : {min(all_ious):.2f}")
        print(f"  - max   : {max(all_ious):.2f}")
        print(f"  - median: {np.median(all_ious):.2f}")
    else:
        print("\n📉 Nessun IoU calcolabile.")

    if save_path:
        with open(save_path, 'w') as f:
            json.dump(results, f, indent=2, cls=NpEncoder)
        print(f"Results saved to {save_path}")

    return results



def compute_precision_recall_f1(pred_boxes, gt_boxes, iou_threshold=0.5):
    """Compute TP, FP, FN, precision, recall and F1 for a single image.

    Predictions are processed in the given order. Each is matched to the
    still-unmatched ground-truth box with the highest IoU, and counts as a TP
    when that IoU is at least ``iou_threshold``. Taking the best GT, rather
    than the first one above the threshold, avoids stealing a GT that a later
    prediction fits better. For conventional greedy matching, pass predictions
    sorted by descending score; this function does not sort them.

    Returns:
        ``(precision, recall, f1_score, tp, fp, fn)``.
    """
    matched_gt = set()
    tp = 0

    for pb in pred_boxes:
        best_iou, best_gt = 0.0, None
        for gt_idx, gb in enumerate(gt_boxes):
            if gt_idx in matched_gt:
                continue
            iou = calculate_iou(pb, gb)
            if iou > best_iou:
                best_iou, best_gt = iou, gt_idx
        if best_gt is not None and best_iou >= iou_threshold:
            matched_gt.add(best_gt)
            tp += 1

    fp = len(pred_boxes) - tp
    fn = len(gt_boxes) - tp

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall    = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1_score  = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

    return precision, recall, f1_score, tp, fp, fn


def compute_map(model, dataset, score_threshold=0.5, save_path=None, class_agnostic=False):
    """Compute COCO-style mAP with torchmetrics' ``MeanAveragePrecision``.

    Args:
        model: torchvision-style detector returning dicts with 'boxes', 'scores', 'labels'.
        dataset: Iterable of ``(image tensor, target dict with 'boxes', 'labels')``.
        score_threshold: Predictions scoring below this are dropped before
            scoring. NOTE(review): mAP is normally computed on all detections;
            filtering truncates the precision/recall curve.
        save_path: If given, the metrics are also written there as JSON.
        class_agnostic: If True, every predicted and ground-truth label is set
            to 1 before scoring. Useful when the model was trained with all
            labels set to 1 (as this notebook's training loop does) while the
            dataset keeps the original class ids.

    Returns:
        The dict returned by ``MeanAveragePrecision.compute()`` (tensor values).
    """
    model.eval()
    metric = MeanAveragePrecision(iou_type="bbox")
    preds, targets = [], []

    for img, target in dataset:
        with torch.no_grad():
            pred = model([img.to(device)])[0]

        scores = pred['scores'].cpu()
        keep = scores >= score_threshold
        pred_labels = pred['labels'].cpu()[keep]
        gt_labels = target['labels'].cpu()
        if class_agnostic:
            pred_labels = torch.ones_like(pred_labels)
            gt_labels = torch.ones_like(gt_labels)

        preds.append({
            "boxes": pred['boxes'].cpu()[keep],
            "scores": scores[keep],
            "labels": pred_labels,
        })
        targets.append({
            "boxes": target['boxes'].cpu(),
            "labels": gt_labels,
        })

    metric.update(preds, targets)
    results = metric.compute()

    def _as_python(value):
        # Tensors are not JSON-serialisable; unwrap them for printing and saving.
        if torch.is_tensor(value):
            return value.item() if value.numel() == 1 else value.tolist()
        return value

    readable = {k: _as_python(v) for k, v in results.items()}

    print(f"\n=== [mAP Evaluation] ===")
    for k, v in readable.items():
        if isinstance(v, float):
            print(f"  - {k}: {v:.4f}")
        else:
            print(f"  - {k}: {v}")

    if save_path:
        with open(save_path, 'w') as f:
            json.dump(readable, f, indent=2, cls=NpEncoder)
        print(f"mAP results saved to {save_path}")

    return results


def save_results_to_file(test_results, map_results, out_path="results"):
    """Write the run config plus test and mAP metrics to a timestamped JSON file.

    The file is ``<out_path>/results_<YYYYmmdd_HHMMSS>.json``. In Colab a
    browser download is also triggered; elsewhere a notice is printed.

    Args:
        test_results: dict from ``evaluate_test_set``.
        map_results: dict from ``compute_map`` (may contain tensors).
        out_path: Output directory, created if missing.
    """
    os.makedirs(out_path, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filepath = os.path.join(out_path, f"results_{timestamp}.json")

    def convert(o):
        # JSON-native values pass through unchanged so numbers stay numbers
        # in the saved config; only unknown objects are stringified.
        if isinstance(o, torch.Tensor):
            return o.item() if o.numel() == 1 else o.tolist()
        if isinstance(o, np.ndarray):
            return o.tolist()
        if isinstance(o, np.generic):
            return o.item()
        if o is None or isinstance(o, (bool, int, float, str)):
            return o
        if isinstance(o, (list, tuple)):
            return [convert(item) for item in o]
        return str(o)

    # Serialise the public attributes of the Config class.
    config_dict = {
        key: convert(val) for key, val in vars(Config).items()
        if not key.startswith("__") and not callable(val)
    }

    combined_results = {
        "config": config_dict,
        "test_metrics": json.loads(json.dumps(test_results, default=convert)),
        "map_metrics": json.loads(json.dumps(map_results, default=convert)),
    }

    with open(filepath, 'w') as f:
        json.dump(combined_results, f, indent=4)

    print(f"📁 Risultati salvati in: {filepath}")

    try:
        from google.colab import files
        files.download(filepath)
        print("⬇️ File pronto per il download!")
    except ImportError:
        print("⚠️ Non sei in Colab, download automatico non disponibile.")





def visualize_results(model_path, dataset, num_samples=3, score_threshold=0.5):
    """Load a checkpoint, report and save test metrics, then plot a few samples.

    Ground-truth boxes are drawn green when matched by a prediction and blue
    otherwise. Kept predictions are drawn red and labelled with their score,
    plus the IoU when matched.

    Args:
        model_path: Checkpoint written by ``train_model`` (dict with 'model_state_dict').
        dataset: Dataset yielding ``(normalised image tensor, target dict)``.
        num_samples: Number of leading samples to plot (clamped to ``len(dataset)``).
        score_threshold: Minimum prediction score used for metrics and drawing.
    """
    model = create_model().to(device)
    # map_location lets a GPU-trained checkpoint load on a CPU-only runtime.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    test_results = evaluate_test_set(model, dataset, score_threshold)
    map_results = compute_map(model, dataset, score_threshold)

    save_results_to_file(test_results, map_results)

    # Undo the dataset normalisation (same statistics as get_transforms /
    # MalariaDataset) so the displayed colours are correct.
    norm_mean = torch.tensor([0.7205, 0.7203, 0.7649]).view(3, 1, 1)
    norm_std = torch.tensor([0.2195, 0.2277, 0.1588]).view(3, 1, 1)

    for i in range(min(num_samples, len(dataset))):
        img, target = dataset[i]
        with torch.no_grad():
            pred = model([img.to(device)])[0]

        img_disp = (img.cpu() * norm_std + norm_mean).clamp(0, 1).permute(1, 2, 0).numpy()

        fig, ax = plt.subplots(figsize=(12, 8))
        ax.imshow(img_disp)

        gt_boxes = target['boxes'].cpu().numpy()
        pred_boxes = pred['boxes'].cpu().numpy()
        pred_scores = pred['scores'].cpu().numpy()

        keep = pred_scores >= score_threshold
        pred_boxes = pred_boxes[keep]
        pred_scores = pred_scores[keep]

        mean_iou, matched_pairs = compute_mean_iou_per_gt(pred_boxes, gt_boxes)
        precision, recall, f1, tp, fp, fn = compute_precision_recall_f1(pred_boxes, gt_boxes)

        print(f"\n=== Sample {i+1} ===")
        print(f"GT: {len(gt_boxes)}, Predetti (score ≥ {score_threshold}): {len(pred_boxes)}")
        print(f"mIoU: {mean_iou:.4f}")
        print(f"TP: {tp}, FP: {fp}, FN: {fn}")
        print(f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")

        matched_gt_indices = {gt_idx for (_, gt_idx, _) in matched_pairs}
        iou_by_pred = {pred_idx: iou for (pred_idx, _, iou) in matched_pairs}

        for gt_idx, (x1, y1, x2, y2) in enumerate(gt_boxes):
            color = 'green' if gt_idx in matched_gt_indices else 'blue'
            ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                       fill=False, color=color, linewidth=2))

        for idx, (box, score) in enumerate(zip(pred_boxes, pred_scores)):
            x1, y1, x2, y2 = box
            ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                       fill=False, color='red', linewidth=2))
            text = f"{score:.2f}"
            if idx in iou_by_pred:
                text += f" (IoU: {iou_by_pred[idx]:.2f})"
            ax.text(x1, y1, text, color='red', fontsize=8, va='bottom')

        ax.axis('off')
        ax.set_title(f"Sample {i+1} | mIoU: {mean_iou:.2f}")
        plt.show()
        plt.close(fig)


## MAIN

In [None]:
# ==== MAIN ====
if __name__ == "__main__":
    print("Starting training... ⏳")

    # Snapshot of the key hyper-parameters for this run.
    run_config = {
        "model": Config.MODEL,
        "lr": Config.LEARNING_RATE,
        "batch_size": Config.BATCH_SIZE,
        "num_epochs": Config.NUM_EPOCHS,
        "weight_decay": Config.WEIGHT_DECAY,
        "patience": Config.PATIENCE,
        "optimizer": Config.Optim,
        "scheduler": Config.Scheduler,
        "img_size": Config.IMG_SIZE,
    }
    print(run_config)
    history = train_model()

    # Training vs. validation loss curves.
    fig, ax = plt.subplots()
    for key, label in (('train_loss', 'Training Loss'), ('val_loss', 'Validation Loss')):
        ax.plot(history[key], label=label)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('Training and Validation Loss Over Epochs')
    ax.legend()
    plt.show()

    print("\nStarting Test and Visualize Samples...💡")
    test_dataset = MalariaDataset(Config.TEST_JSON, Config.IMG_DIR,
                                  transform=get_transforms(train=False), resize=True)
    visualize_results(Config.MODEL_SAVE_PATH, test_dataset, num_samples=3, score_threshold=0.5)