In [None]:
import json

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.stats import mannwhitneyu, spearmanr, ttest_ind
from sklearn.ensemble import IsolationForest

In [None]:
# Training-set aggregated explanations with per-label Dice/IoU columns.
# Rows with any missing value are dropped. drop=True discards the old RangeIndex
# instead of inserting it as an extra "index" column (the b50 frame is loaded
# the same way).
train_agg_explanations = (
    pd.read_csv("data/tsv2_train_aggregated_sg_explanations_with_dice.csv")
    .dropna()
    .reset_index(drop=True)
)

In [None]:
# Segmentation label id -> label name (id 0 is background).
ID2LABELS = {
    0: "background",
    1: "aorta",
    2: "lung_upper_lobe_left",
    3: "lung_lower_lobe_left",
    4: "lung_upper_lobe_right",
    5: "lung_middle_lobe_right",
    6: "lung_lower_lobe_right",
    7: "trachea",
    8: "heart",
    9: "pulmonary_vein",
    10: "thyroid_gland",
    11: "ribs",
    12: "vertebraes",
    13: "autochthon_left",
    14: "autochthon_right",
    15: "sternum",
    16: "costal_cartilages",
}

labels = list(ID2LABELS.values())

# Column name pattern "<label_1>_explanation_in_<label_2>_mass". Presumably the
# explanation mass for label_1 that falls inside label_2 — semantics come from the
# upstream CSV (TODO confirm).
MASS_COLUMNS_FORMAT = "{label_1}_explanation_in_{label_2}_mass"

# Drop explanations with NA once. This frame does not depend on the label, and the
# training context summary below reads it directly.
aggregated_explanations = train_agg_explanations.dropna().reset_index(drop=True)

# Per-label feature sets: for each label, its explanation mass in every label.
outliers_analysis_datasets = {}
columns = {}
for label in labels:
    mass_columns = [
        MASS_COLUMNS_FORMAT.format(label_1=label, label_2=label_2) for label_2 in labels
    ]
    columns[label] = mass_columns
    outliers_analysis_datasets[label] = aggregated_explanations[mass_columns].copy()

# Diagonal columns: each label's explanation mass inside its own label.
label_mass_in_label_columns = [
    f"{label}_explanation_in_{label}_mass" for label in labels
]
iou_columns = [f"{label}_iou" for label in labels]
dice_columns = [f"{label}_dice" for label in labels]

In [None]:
# Training set: describe() statistics (minus count/std) of each
# "<label>_explanation_in_<label>_mass" column, highest mean first.
train_context = aggregated_explanations[label_mass_in_label_columns].describe().T
train_context = train_context.sort_values(by="mean", ascending=False).drop(
    columns=["count", "std"]
)
# "aorta_explanation_in_aorta_mass" -> "Aorta"
train_context.index = [
    name.split("_explanation")[0].replace("_", " ").title()
    for name in train_context.index
]
train_context = train_context.rename_axis("Label").reset_index()
print(train_context.to_markdown(index=False))

In [None]:
# One IsolationForest per label, fitted on that label's training mass features.
# A fixed random_state keeps the forests reproducible; fit() returns the estimator.
mass_ifs = {
    label: IsolationForest(random_state=123).fit(outliers_analysis_datasets[label])
    for label in labels
}

In [None]:
TSV2_TEST_CSV = "data/tsv2_test_aggregated_sg_explanations_with_dice.csv"
B50_CSV = "data/b50_aggregated_sg_explanations.csv"

# NOTE(review): unlike the train and b50 frames, the tsv2 test frame keeps rows
# containing NaN — confirm this is intended.
test_agg_explanations = pd.read_csv(TSV2_TEST_CSV)

# b50 has no Dice columns in its filename; rows with missing values are dropped.
b50_agg_explanations = pd.read_csv(B50_CSV).dropna().reset_index(drop=True)

In [None]:
# Restrict both evaluation frames to the per-label mass features used in training.
tsv2_test_dataset = {
    name: test_agg_explanations[feature_cols] for name, feature_cols in columns.items()
}

b50_datasets = {
    name: b50_agg_explanations[feature_cols] for name, feature_cols in columns.items()
}

In [None]:
# How contextual are tsv2_test and b50: the same self-mass summary as for the
# training set, first for the tsv2 test split.

tsv2_test_context = test_agg_explanations[label_mass_in_label_columns].describe().T
tsv2_test_context = tsv2_test_context.sort_values(by="mean", ascending=False).drop(
    columns=["count", "std"]
)
tsv2_test_context.index = [
    name.split("_explanation")[0].replace("_", " ").title()
    for name in tsv2_test_context.index
]
tsv2_test_context = tsv2_test_context.rename_axis("Label").reset_index()
print(tsv2_test_context.to_markdown(index=False))

In [None]:
# Same self-mass summary for the b50 dataset, highest mean first.
b50_context = b50_agg_explanations[label_mass_in_label_columns].describe().T
b50_context = b50_context.sort_values(by="mean", ascending=False).drop(
    columns=["count", "std"]
)
b50_context.index = [
    name.split("_explanation")[0].replace("_", " ").title()
    for name in b50_context.index
]
b50_context = b50_context.rename_axis("Label").reset_index()
print(b50_context.to_markdown(index=False))

In [None]:
def _predictions_and_scores(datasets):
    """Return {label: (predict(), decision_function())} for every fitted label forest.

    Per scikit-learn's IsolationForest, predict() yields -1 for outliers and 1 for
    inliers, and lower decision_function() values are more anomalous.
    """
    results = {}
    for name, forest in mass_ifs.items():
        features = datasets[name]
        results[name] = (forest.predict(features), forest.decision_function(features))
    return results


b50_outliers_per_label = _predictions_and_scores(b50_datasets)

tsv2_test_outliers_per_label = _predictions_and_scores(tsv2_test_dataset)

In [None]:
# Spot check: (predict labels, decision_function scores) of the heart forest on tsv2 test.
tsv2_test_outliers_per_label["heart"]

In [None]:
# Per-label segmentation quality on the tsv2 test set.
test_dice = test_agg_explanations.loc[:, dice_columns].copy()
test_iou = test_agg_explanations.loc[:, iou_columns].copy()

In [None]:
# Per-sample decision_function score averaged over all label forests (display only).
np.stack(
    [scores for _, scores in tsv2_test_outliers_per_label.values()]
).mean(axis=0)

In [None]:
def _nan_safe_spearmanr(x, y):
    """Spearman correlation of ``x`` and ``y`` over positions where both are non-NaN.

    scipy's default ``nan_policy="propagate"`` makes the whole result NaN as soon
    as one value is missing (e.g. a missing Dice value for a label).
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    keep = ~(np.isnan(x) | np.isnan(y))
    return spearmanr(x[keep], y[keep])


# Correlate each label's test Dice with that label's forest decision_function score.
test_data = []

for label, (_, outlier_scores) in tsv2_test_outliers_per_label.items():
    test_score = _nan_safe_spearmanr(test_dice[f"{label}_dice"], outlier_scores)
    test_data.append(
        {
            "label": label,
            "p-value": test_score.pvalue,
            "spearmanr": test_score.correlation,
        }
    )
    print(f"p-val for {label}: {test_score.pvalue:,.3g}")

# Same check on per-sample means across all labels.
mean_outlier_score = np.mean(
    [outlier_scores for _, outlier_scores in tsv2_test_outliers_per_label.values()],
    axis=0,
)
# DataFrame.mean skips NaN by default, so each sample averages the Dice values it has.
mean_dice = test_dice.mean(axis=1)
test_score = _nan_safe_spearmanr(mean_dice, mean_outlier_score)
print(
    f"p-val for mean outlier score: {test_score.pvalue:,.3g}, spearmanr: {test_score.correlation}"
)

In [None]:
# Per-label correlation table rendered as LaTeX, ordered by ascending p-value.
test_data = pd.DataFrame(test_data)
print_data = (
    test_data.assign(label=test_data["label"].str.replace("_", " ", regex=False))
    .set_index("label")
    # Sort on the numeric p-values *before* formatting. Sorting the formatted
    # strings is lexicographic, so e.g. "1.2e-05" would land after "0.3".
    .sort_values("p-value")
)
print_data["p-value"] = print_data["p-value"].map("{:,.3g}".format)
print_data["spearmanr"] = print_data["spearmanr"].map("{:,.3g}".format)
print(print_data.to_latex())

In [None]:
# Apply seaborn's default theme to the figures that follow.
sns.set_theme()

In [None]:
# Per test sample: mean decision_function score across forests vs. mean Dice.
fig, ax = plt.subplots()
sns.scatterplot(x=mean_outlier_score, y=mean_dice, ax=ax)
ax.set_xlabel("Mean Outlier Score")
ax.set_ylabel("Mean Dice")
ax.set_title("Mean Outlier Score vs Mean Dice")
plt.show()

In [None]:
# One panel per foreground label (labels[0], "background", is skipped): the
# label forest's decision_function score vs. that label's Dice on tsv2 test.
foreground_labels = labels[1:]
n_cols = 4
# Derive the grid from the label count (16 labels -> 4x4, 12x12 inches) so that
# zip() below never silently drops labels.
n_rows = max(1, -(-len(foreground_labels) // n_cols))  # ceiling division
fig, axs = plt.subplots(
    n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), squeeze=False
)
flat_axes = axs.flatten()

for ax, label in zip(flat_axes, foreground_labels):
    sns.scatterplot(
        x=tsv2_test_outliers_per_label[label][1], y=test_dice[f"{label}_dice"], ax=ax
    )
    label_for_print = label.replace("_", " ").title()
    ax.set_title(label_for_print)
    ax.set_xlabel("Outlier score")
    ax.set_ylabel(f"{label_for_print} Dice")

# Hide grid cells left over when the label count is not a multiple of n_cols.
for ax in flat_axes[len(foreground_labels):]:
    ax.axis("off")

plt.tight_layout()
plt.show()

In [None]:
# Patients flagged as outliers (predict() == -1) by at least one label's forest on b50.
# Patient ids come from the leading "/"-separated component of `patient`; they are
# parsed once instead of once per label.
b50_patient_ids = b50_agg_explanations.patient.apply(lambda x: int(x.split("/")[0]))
flagged_by_any_label = np.any(
    [outliers == -1 for outliers, _ in b50_outliers_per_label.values()], axis=0
)

b50_outliers = np.unique(b50_patient_ids[flagged_by_any_label])
b50_outliers, b50_outliers.shape

In [None]:
# Label whose forest's decision_function score is used as the per-row anomaly score.
# Alternative considered: the mean of the scores of all label forests
# (np.mean over b50_outliers_per_label values, axis=0).
outlier_score_label = "lung_lower_lobe_right"
b50_with_outlier_score = b50_agg_explanations.copy()
b50_with_outlier_score["outlier_score"] = b50_outliers_per_label[outlier_score_label][1]
# Numeric patient id = leading "/"-separated component of the `patient` path.
b50_with_outlier_score["patient_id"] = b50_with_outlier_score.patient.apply(
    lambda x: int(x.split("/")[0])
)

In [None]:
# b50 patient metadata (CSV) and patient descriptions (JSON).
meta_data = pd.read_csv("b50/patient_data/patient_metadata_1.0/csv_1.0.csv")

# JSON text is UTF-8 (RFC 8259). Without an explicit encoding, open() uses the
# platform's locale encoding, which can mis-decode non-ASCII text (e.g. on Windows).
with open("b50/patient_data/descriptions_1.0/json_1.0.json", encoding="utf-8") as f:
    patient_descriptions = json.load(f)

patient_descriptions = pd.DataFrame(patient_descriptions)

In [None]:
# Split the metadata into patients flagged by any forest and all remaining rows.
# NOTE(review): "inliers" is every metadata row not flagged, which includes any
# patient absent from b50_agg_explanations (e.g. removed by dropna) — confirm intended.
is_b50_outlier = meta_data["patient_id"].isin(b50_outliers)
b50_outliers_metadata = meta_data[is_b50_outlier]
b50_inliers_metadata = meta_data[~is_b50_outlier]

In [None]:
# Mann-Whitney U test (SciPy default: two-sided) on age, outliers vs. inliers.
outlier_ages = b50_outliers_metadata["age"]
inlier_ages = b50_inliers_metadata["age"]
u_stat, u_p_val = mannwhitneyu(outlier_ages, inlier_ages)
print(f"Mann-Whitney U test P-value: {u_p_val}")

In [None]:
# Student's t-test on age (SciPy default equal_var=True).
# NOTE(review): if group sizes/variances differ, Welch's test (equal_var=False)
# may be more appropriate.
age_ttest = ttest_ind(b50_outliers_metadata["age"], b50_inliers_metadata["age"])
t_stat, t_p_val = age_ttest
print(f"t-test P-value: {t_p_val}")

In [None]:
# Share of each sex among flagged patients (keyword instead of a bare positional bool).
b50_outliers_metadata["sex"].value_counts(normalize=True)

In [None]:
# Share of each sex among non-flagged patients (keyword instead of a bare positional bool).
b50_inliers_metadata["sex"].value_counts(normalize=True)

In [None]:
# Description records belonging to the flagged patients.
outliers_descriptions = patient_descriptions.loc[
    patient_descriptions["patient_id"].isin(b50_outliers)
]

In [None]:
# Preview the first flagged-patient records.
# NOTE(review): these are patient records — check outputs for PII before sharing.
outliers_descriptions.head()

In [None]:
# Switch to seaborn's "whitegrid" style for the remaining figures.
sns.set_theme(style="whitegrid")

In [None]:
# One row per patient_id (its first row), ordered by decision_function score.
# Lowest scores are the most anomalous, highest the most normal.
ranked_b50_patients = b50_with_outlier_score.drop_duplicates(subset=["patient_id"])[
    ["patient", "outlier_score"]
].sort_values(by="outlier_score")
top_3_b50_outliers = ranked_b50_patients.iloc[:3]
top_3_b50_inliers = ranked_b50_patients.iloc[-3:]

In [None]:
# Show the selected patients (`patient` path and `outlier_score`).
top_3_b50_outliers, top_3_b50_inliers

In [None]:
# Fixed slice indices for the figure below: one inner list per image, in the same
# order as top_3_b50_outliers / top_3_b50_inliers, giving
# [index along axis 0, index along axis 1, index along axis 2].
b50_slices_outliers = [[181, 153, 85], [93, 140, 143], [174, 150, 133]]
b50_slices_inliers = [[92, 95, 113], [100, 98, 81], [81, 89, 147]]

In [None]:
from generation_utils import B50_FOLDER, get_transform, read_path

In [None]:
transform = get_transform()


def load_image(path):
    """Read the image at ``path``, apply ``transform``, and return a squeezed NumPy array.

    ``transform`` presumably returns a torch tensor (given ``.numpy()``) — confirm in
    generation_utils.
    """
    transformed = transform(read_path(path))
    return transformed.squeeze().numpy()

In [None]:
# `patient` values are paths joined onto B50_FOLDER.
outliers_imgs = [
    load_image(str(B50_FOLDER / relative_path))
    for relative_path in top_3_b50_outliers.patient
]
inliers_imgs = [
    load_image(str(B50_FOLDER / relative_path))
    for relative_path in top_3_b50_inliers.patient
]

In [None]:
def _draw_slice_rows(axes, images, slice_indices):
    """Fill a (rows x 3) axes grid: row i shows three slices of ``images[i]``.

    Column ``k`` takes index ``slice_indices[i][k]`` along array axis ``k``
    (equivalent to ``img[s]``, ``img[:, s]``, ``img[:, :, s]``), rotated 90 degrees.
    """
    for row_axes, img, slices in zip(axes, images, slice_indices):
        for axis, (ax, index) in enumerate(zip(row_axes, slices)):
            ax.imshow(np.rot90(np.take(img, index, axis=axis)), cmap="gray")
            ax.axis("off")
            ax.set_aspect("auto")


fig = plt.figure(figsize=(10, 5))
subfigs = fig.subfigures(1, 2, wspace=0.1)

subfigs[0].suptitle("Top 3 Anomaly Scores", y=1.01)
subfigs[1].suptitle("Bottom 3 Anomaly Scores", y=1.01)
ax = subfigs[0].subplots(3, 3)
# NOTE(review): plt.subplots_adjust targets the parent figure and runs after the
# first panel's axes exist but before the second's — verify both panels end up
# with the same tight spacing.
plt.subplots_adjust(hspace=0.01, wspace=0.01)
_draw_slice_rows(ax, outliers_imgs, b50_slices_outliers)

ax = subfigs[1].subplots(3, 3)
_draw_slice_rows(ax, inliers_imgs, b50_slices_inliers)

plt.tight_layout()
plt.savefig("top3_outliers_inliers.pdf", bbox_inches="tight")