# Prediction file for Yolo ensembling and TTA

# In [1]:

import numpy as np
import sys
import pandas as pd
from tqdm import tqdm
import shutil

sys.path.append("../")
from detection import YoloInference,EnsembleYolo
from notebooks.utils_notebook import filter_boxes

# In [None]:
# Maps the integer label returned by the ensemble to the class name that is
# written to the "class" column of the predictions CSV.
itos = {
    0: 'B', 1: 'BA', 2: 'EO', 3: 'Er', 4: 'LAM3', 5: 'LF', 6: 'LGL',
    7: 'LH_lyAct', 8: 'LLC', 9: 'LM', 10: 'LY', 11: 'LZMG', 12: 'LyB',
    13: 'Lysee', 14: 'M', 15: 'MBL', 16: 'MM', 17: 'MO', 18: 'MoB',
    19: 'PM', 20: 'PNN', 21: 'SS', 22: 'Thromb',
}

# Directory holding the images; each image path is built as data_path/NAME.
data_path = "../data/Cytologia/images"

test_csv_path = "../data/Cytologia/test.csv"
csv_path = "../data/Cytologia/predictions.csv"
# Work on a copy of the test CSV so the original file is never modified.
# The destination is csv_path itself (not a repeated literal), so the file
# copied here is guaranteed to be the one read below and saved at the end.
shutil.copy(test_csv_path, csv_path)

df = pd.read_csv(csv_path)

# Four detector checkpoints (yolo11n, yolo11m, yolov10n, yolov10m according to
# their paths), all loaded on the GPU and combined into one ensemble.
yolo_engine1 = YoloInference(
    "../models/detection/Cytologia_yolo_msk_black/yolo11n/384/curated/train/weights/best.pt",
    device="cuda",
)
yolo_engine2 = YoloInference(
    "../models/detection/Cytologia_msk_iou/yolo11m/384/curated250/train/weights/best.pt",
    device="cuda",
)
yolo_engine3 = YoloInference(
    "../models/detection/Cytologia_yolo_msk_black/yolov10n/384/curated/train/weights/best.pt",
    device="cuda",
)
yolo_engine4 = YoloInference(
    "../models/detection/Cytologia_msk_iou/yolov10m/384/curated250/train/weights/best.pt",
    device="cuda",
)
# NOTE(review): the keyword is spelled "meta_idenfier" here; presumably this
# matches the EnsembleYolo signature -- confirm before renaming.
ens_engine = EnsembleYolo(
    [yolo_engine1, yolo_engine2, yolo_engine3, yolo_engine4],
    use_probs=True,
    meta_idenfier=None,
)

# Add missing columns
# Only the prediction columns that are actually absent are created, so values
# already present in the CSV are not wiped when just some columns are missing.
for col in ['x1', 'y1', 'x2', 'y2', 'class']:
    if col not in df.columns:
        df[col] = np.nan
# 'class' receives string labels later on; keeping it as float64 (the dtype of
# an all-NaN column) makes pandas warn -- and, in newer versions, raise -- when
# a string is assigned into it, so store it as object dtype up front.
df['class'] = df['class'].astype(object)

# Get unique NAME to infer each image only once
names = df["NAME"].unique()
tqdm_names = tqdm(names)
# Columns filled in for every row of the predictions CSV.
pred_cols = ["x1", "y1", "x2", "y2", "class"]
# Number of CSV rows for which the ensemble produced no box.
no_dets = 0
for name in tqdm_names:
    img_path = f"{data_path}/{name}"

    # All CSV rows that refer to this image; each one is expected to receive
    # one predicted box.
    occurences = df[df["NAME"] == name]
    trustii_ids = occurences["trustii_id"].tolist()
    boxes, scores, labels = ens_engine.predict(img_path, conf=0.00001, verbose=False)

    boxes = list(boxes)
    scores = list(scores)
    labels = list(labels)

    # More detections than rows for this image: let filter_boxes reduce them.
    # NOTE(review): presumably it returns at most len(occurences) boxes --
    # confirm against notebooks.utils_notebook.filter_boxes.
    if len(occurences) < len(boxes):
        boxes, scores, labels = filter_boxes(occurences, boxes, scores, labels)

    # Pair rows with boxes positionally. zip() stops at the shortest input, so
    # a surplus of boxes can no longer cause an IndexError on trustii_ids.
    for trustii_id, box, label in zip(trustii_ids, boxes, labels):
        x1, y1, x2, y2 = box
        cls = itos[label]
        df.loc[df["trustii_id"] == trustii_id, pred_cols] = [x1, y1, x2, y2, cls]

    # case where a wbc is not detected (even with this very low threshold):
    # the remaining rows get an empty box and the fallback class 'PNN'.
    for trustii_id in trustii_ids[len(boxes):]:
        df.loc[df["trustii_id"] == trustii_id, pred_cols] = [0, 0, 0, 0, 'PNN']
        # Count the miss so the summary below reports the real number
        # (the counter was previously never incremented).
        no_dets += 1

print("Number of undetected WBCs: ",no_dets)
df.to_csv(csv_path, index=False)
print("CSV mis à jour avec succès.")
