In [None]:
# read the predictions from the algorithm.
import os
import pandas as pd
import numpy as np
import json
from PIL import Image

In [None]:
    
# Pick one model / tuning mode and load its detection predictions from
# detection_results/<model_name>_<tune_mode>_detection_preds.csv.
model_name, tune_mode = "moco", "linear_eval"
model_df = pd.read_csv(
    os.path.join("detection_results", f"{model_name}_{tune_mode}_detection_preds.csv")
)

In [None]:
# Clinical spreadsheet; the cells below read its "dicom_id" column and the
# "No Finding_chexpert" / "No Finding_negbio" label columns.
reflacx_df = pd.read_csv("spreadsheets/reflacx_clinical.csv")

In [None]:
# DICOM ids for which both label columns ("No Finding_chexpert" and
# "No Finding_negbio") equal 1.0.
all_healthy_dicom_id = set(
    reflacx_df[
        (reflacx_df["No Finding_negbio"] == 1.0)
        & (reflacx_df["No Finding_chexpert"] == 1.0)
    ]["dicom_id"]
)

In [None]:
def is_in_heath_ids(image_path, healthy_ids):
    """Return True when the image's DICOM id is contained in ``healthy_ids``.

    The DICOM id is taken to be the file name of ``image_path`` up to (not
    including) its first ``"."``.

    Args:
        image_path: Path to an image file.
        healthy_ids: Container of DICOM ids supporting the ``in`` operator.

    Returns:
        bool: Whether the derived DICOM id is a member of ``healthy_ids``.
    """
    file_name = os.path.basename(image_path)
    # Everything before the first dot; the whole name when there is no dot.
    stem, _, _ = file_name.partition(".")
    return stem in healthy_ids

In [None]:
# Keep only prediction rows whose image has no ground-truth boxes ("[]") and
# whose DICOM id is in the healthy set.
healthy_df = model_df[
    model_df["image_path"].apply(
        lambda path: is_in_heath_ids(path, all_healthy_dicom_id)
    )
    & (model_df["gt_boxes"] == "[]")
]

In [None]:
# Unique image paths among the healthy rows (one entry per image to plot).
all_image_paths = set(healthy_df["image_path"])

In [None]:
# Show the rows whose serialized "gt_boxes" value is shorter than three
# characters (e.g. the empty list "[]").
model_df[model_df["gt_boxes"].apply(len) < 3]

In [None]:
# --- Run configuration ---
# limited_lesion: when true, ground-truth plots only show the lesion class
#                 currently being rendered.
# linear_eval:    selects which prediction files are read (see tune_mode).
limited_lesion, linear_eval = True, True

# image_size is passed to the dataset and used as the reference size when
# scaling boxes onto the images; batch_size is not used in the cells shown here.
image_size, batch_size = 128, 4

# Number of highest-scoring detections kept per lesion class (falsy = keep all).
top_k_score = 5

# Tuning-mode tag embedded in the prediction CSV file names.
if linear_eval:
    tune_mode = "linear_eval"
else:
    tune_mode = "fine_tuned"

In [None]:
# Identifiers of the models whose detection results are visualized. Each name
# is interpolated into the prediction file name
# ("<model_name>_<tune_mode>_detection_preds.csv") and is also used as the
# name of the folder the rendered figures are written to.
# NOTE(review): the "our_*" entries look like in-house variants — confirm.
model_names = [
    "supervised",
    "simsiam",
    "byol",
    "twins",
    "moco",
    "our_simclr",
    "swav",
    "our_improved_v4",
    "our_improved_v4_without_auto",
    "our_improved_v8",
]

In [None]:
from ds.reflacx.lesion_detection import REFLACXLesionDetectionDataset

# Test-split dataset; the cells below only use its ``idx_to_lesion`` method to
# turn a label index into a lesion name.
test_dataset = REFLACXLesionDetectionDataset(
    split_str="test",
    image_size=image_size,
)

In [None]:
import torch
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle


# Default outline / caption colour for each lesion name. Shared by both
# plotting functions so the two kinds of figure stay visually consistent.
_DEFAULT_LESION_COLORS = {
    "Enlarged cardiac silhouette": "yellow",
    "Atelectasis": "red",
    "Pleural abnormality": "orange",
    "Consolidation": "lightgreen",
    "Pulmonary edema": "dodgerblue",
}


def _draw_boxes_on_img(img, annotations, idx_to_lesion_fn, img_size, cmap):
    """Show ``img`` and overlay one captioned rectangle per annotation.

    Args:
        img: Image exposing ``.size`` as ``(width, height)`` (e.g. a PIL image).
        annotations: Iterable of ``(bbox, label, caption_suffix)`` where
            ``bbox`` is indexed as ``[x1, y1, x2, y2]`` and ``caption_suffix``
            is appended to the lesion name in the caption.
        idx_to_lesion_fn: Callable mapping a label to a lesion name.
        img_size: Reference size of the coordinate grid the boxes are
            expressed in; boxes are scaled by ``width / img_size`` and
            ``height / img_size`` to land on ``img``.
        cmap: Mapping from lesion name to a Matplotlib colour, or ``None`` to
            use the default table.

    Returns:
        The Matplotlib figure containing the image and its overlays.

    Raises:
        KeyError: If a lesion name is missing from the colour mapping.
    """
    colors = _DEFAULT_LESION_COLORS if cmap is None else cmap

    fig, ax = plt.subplots(dpi=512)

    # Draw on the axes we just created rather than on pyplot's implicit
    # "current axes". "gray" is the canonical colormap name; the "grey"
    # spelling is only an alias that not every Matplotlib release recognises.
    ax.imshow(img, cmap="gray")

    width, height = img.size
    width_factor = width / img_size
    height_factor = height / img_size

    for bbox, label, caption_suffix in annotations:
        disease = idx_to_lesion_fn(label)
        color = colors[disease]

        # Top-left corner of the box in the displayed image's pixel space.
        left = bbox[0] * width_factor
        top = bbox[1] * height_factor

        ax.add_patch(
            Rectangle(
                (left, top),
                (bbox[2] - bbox[0]) * width_factor,
                (bbox[3] - bbox[1]) * height_factor,
                fill=False,
                color=color,
                linewidth=2,
            )
        )
        ax.text(
            left,
            top,
            f"{disease}{caption_suffix}",
            color="black",
            backgroundcolor=color,
        )
    return fig


def plot_gt_on_img(
    img,
    boxes,
    labels,
    idx_to_lesion_fn,
    img_size,
    cmap=None,
    limited_lesion=None,
):
    """Plot ground-truth boxes, captioned with their lesion name, on ``img``.

    Args:
        img: Image exposing ``.size`` as ``(width, height)`` (e.g. a PIL image).
        boxes: Sequence of boxes indexed as ``[x1, y1, x2, y2]``, expressed on
            an ``img_size`` x ``img_size`` grid.
        labels: Sequence of labels, parallel to ``boxes``.
        idx_to_lesion_fn: Callable mapping a label to a lesion name.
        img_size: Reference size used to scale the boxes onto ``img``.
        cmap: Optional mapping from lesion name to colour; ``None`` selects
            the default five-lesion colour table.
        limited_lesion: When truthy, only boxes whose label equals this value
            are drawn; otherwise every box is drawn.

    Returns:
        The Matplotlib figure. The caller is responsible for closing it.
    """
    annotations = []
    for bbox, label in zip(boxes, labels):
        if limited_lesion and label != limited_lesion:
            continue
        # Ground truth has no score, so the caption is just the lesion name.
        annotations.append((bbox, label, ""))

    return _draw_boxes_on_img(img, annotations, idx_to_lesion_fn, img_size, cmap)


def plot_bboxes_on_img(
    img: "Image.Image",
    boxes,
    labels,
    scores,
    idx_to_lesion_fn,
    img_size,
    cmap=None,
):
    """Plot predicted boxes, captioned ``"<lesion>(<score>)"``, on ``img``.

    Args:
        img: Image exposing ``.size`` as ``(width, height)`` (e.g. a PIL
            image); it is not a tensor.
        boxes: Sequence of boxes indexed as ``[x1, y1, x2, y2]``, expressed on
            an ``img_size`` x ``img_size`` grid.
        labels: Sequence of labels, parallel to ``boxes``.
        scores: Sequence of numeric scores, parallel to ``boxes``; each is
            rendered with two decimals.
        idx_to_lesion_fn: Callable mapping a label to a lesion name.
        img_size: Reference size used to scale the boxes onto ``img``.
        cmap: Optional mapping from lesion name to colour; ``None`` selects
            the default five-lesion colour table.

    Returns:
        The Matplotlib figure. The caller is responsible for closing it.
    """
    annotations = (
        (bbox, label, f"({score:.2f})")
        for bbox, label, score in zip(boxes, labels, scores)
    )
    return _draw_boxes_on_img(img, annotations, idx_to_lesion_fn, img_size, cmap)

In [None]:
# Load the detection predictions of the first model in the list.
model_name = model_names[0]
model_df = pd.read_csv(
    os.path.join("detection_results", f"{model_name}_{tune_mode}_detection_preds.csv")
)

In [None]:
# For every model, render its highest-scoring detections per lesion class
# together with the matching ground-truth annotations, and save each pair of
# figures as PNG files under detection_results/<model_name>/.
for model_name in model_names:
    model_df = pd.read_csv(
        os.path.join(
            "detection_results", f"{model_name}_{tune_mode}_detection_preds.csv"
        )
    )

    # The output folder depends only on the model, so create it once per model
    # instead of once per lesion class.
    saving_folder = os.path.join("detection_results", model_name)
    os.makedirs(saving_folder, exist_ok=True)

    # NOTE(review): label indices 1..5 are hard-coded; presumably they are the
    # dataset's five lesion classes — confirm against the dataset definition.
    for lesion_idx in range(1, 6):
        # Detections of this lesion class, most confident first.
        sorted_lesion_df = model_df[model_df["label"] == lesion_idx].sort_values(
            by="score", ascending=False
        )
        if top_k_score:
            sorted_lesion_df = sorted_lesion_df[:top_k_score]

        lesion_name = test_dataset.idx_to_lesion(lesion_idx)

        # idx is the 1-based rank of the detection within its lesion class.
        for idx, (_, instance) in enumerate(sorted_lesion_df.iterrows(), start=1):
            dicom_id = os.path.basename(instance["image_path"]).split(".")[0]
            file_stem = f"{lesion_name}-{idx} [{dicom_id}]"
            pred_box = [
                instance["x1"],
                instance["y1"],
                instance["x2"],
                instance["y2"],
            ]

            # The context manager releases the image file deterministically.
            with Image.open(instance["image_path"]) as img:
                fig = plot_bboxes_on_img(
                    img=img,
                    boxes=np.array([pred_box]),
                    labels=np.array([instance["label"]]),
                    scores=np.array([instance["score"]]),
                    img_size=image_size,
                    idx_to_lesion_fn=test_dataset.idx_to_lesion,
                )
                fig.savefig(os.path.join(saving_folder, f"{file_stem}.png"))
                # Close exactly the figure we created: this frees its memory
                # before the next high-DPI figure is built and leaves any
                # unrelated open figures untouched (unlike plt.close("all")).
                plt.close(fig)

                # "gt_boxes" / "gt_labels" are stored as JSON strings.
                gt_fig = plot_gt_on_img(
                    img=img,
                    boxes=np.array(json.loads(instance["gt_boxes"])),
                    labels=np.array(json.loads(instance["gt_labels"])),
                    img_size=image_size,
                    idx_to_lesion_fn=test_dataset.idx_to_lesion,
                    limited_lesion=lesion_idx if limited_lesion else None,
                )
                gt_fig.savefig(os.path.join(saving_folder, f"{file_stem} (GT).png"))
                plt.close(gt_fig)