In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import torch
import numpy as np
from pathlib import Path

import sys

sys.path.append("../src")
from metrics.heatmaps import (
    calculate_relevance_mass_accuracy,
    calculate_relevance_rank_accuracy,
)

In [None]:
# Pathology labels evaluated throughout the notebook.
classes = [
    "Enlarged Cardiomediastinum", "Cardiomegaly", "Lung Opacity",
    "Lung Lesion", "Edema", "Consolidation",
    "Atelectasis", "Pneumothorax", "Pleural Effusion",
]

# Attribution methods, training regimes and architectures that get combined
# (as a Cartesian product) when the heatmap figures are generated.
xai_methods = [
    "integrated_gradients",
    "gradient",
    "lrp",
    "smoothgrad",
]
training_types = [
    "pretrained",
    "from_scratch",
]
models = [
    "vit",
    "swinvit",
    "densenet",
]

# Root folders: experiment outputs and the CheXlocalize copy of CheXpert.
save_folder = Path("../results")
dataset_folder = Path("../dataset/chexlocalize/CheXpert/")

# Example images

In [None]:
# Frontal validation images. `Path.glob` yields entries in filesystem order,
# so sort them to get the same list on every machine and run.
images = sorted((dataset_folder / "val").glob("*/*/view1_frontal*.jpg"))

In [None]:
import random

In [None]:
# Number of example images to draw. The plotting cell further down lays them
# out in 2 rows of `number_of_images_to_plot // 2` columns, so an odd value
# leaves the last sampled image unplotted.
number_of_images_to_plot = 6

In [None]:
# Draw a random subset of the images (no seed is set, so the draw changes on
# every run). NOTE(review): the next cell assigns `images_to_plot` again with
# a fresh sample, so this draw is overwritten before it is used.
images_to_plot = list(random.sample(images, k=number_of_images_to_plot))

In [None]:
# Pick the example images and, for each, one of the mask files
# (`view1_frontal*.npy`) stored next to it.
#
# A dedicated seeded generator and sorted inputs keep the example figure
# reproducible; change the seed to get a different selection.
_example_rng = random.Random(0)

# Only images that actually have at least one mask are eligible; otherwise
# `choice` would fail on an empty list.
_masks_by_image = {
    image: sorted(image.parent.glob("view1_frontal*.npy"))
    for image in sorted(images)
}
_images_with_masks = [
    image for image, masks in _masks_by_image.items() if masks
]

images_to_plot = _example_rng.sample(
    _images_with_masks, k=number_of_images_to_plot
)
masks_to_plot = [
    _example_rng.choice(_masks_by_image[image]) for image in images_to_plot
]


In [None]:
masks_to_plot

In [None]:
# Preview of the title used in the example grid: split the mask file name on
# "_", drop the first two tokens and the last one, and join the rest with
# spaces (presumably leaving the pathology name -- verify against file names).
" ".join(masks_to_plot[0].name.split("_")[2:-1])


In [None]:
from matplotlib.colors import ListedColormap
from itertools import product

In [None]:
# Two-entry colormap for the mask overlay: the first entry "#FFFFFF00" is
# white with zero alpha (transparent), the second is yellow.
# Assumes the masks are binary so 0 maps to transparent -- TODO confirm.
cmap = ListedColormap(["#FFFFFF00", "yellow"])


In [None]:
# Example grid: 2 rows of sampled images, each with its mask overlaid and the
# label (taken from the mask file name) shown underneath.
_n_columns = number_of_images_to_plot // 2
# squeeze=False always returns a 2-D axes array, so a single-column grid works.
_, ax = plt.subplots(2, _n_columns, figsize=(15, 10), squeeze=False)
for _axis, _image_path, _mask_path in zip(
    ax.flat, images_to_plot, masks_to_plot
):
    # `img` and `mask` deliberately stay module-level names: the next cell
    # re-plots the last pair drawn here.
    img = plt.imread(_image_path)
    mask = np.load(_mask_path)
    _axis.imshow(img, cmap="gray")
    _axis.imshow(mask, cmap=cmap, alpha=0.3)
    _axis.axis("off")
    _axis.set_title(
        " ".join(_mask_path.name.split("_")[2:-1]),
        loc="center",
        y=-0.1,
    )
plt.savefig(
    "example_images.png",
    bbox_inches="tight",
    pad_inches=0,
)


In [None]:
# Overlay for the last image/mask pair left over from the example-grid loop.
# The original referenced `mask_`, which is never defined in this notebook;
# `mask` is the array loaded by that loop.
plt.imshow(img, cmap="gray")
plt.imshow(
    mask == 1,
    cmap=ListedColormap(["black", "green"]),
    alpha=0.5,
    vmin=0,
    vmax=1,
)
plt.axis("off")
plt.show()


# Attributions

In [None]:
def tensor_to_float(x):
    """Convert a stringified scalar tensor such as ``"tensor(0.1234)"`` to a float.

    Used on the metric columns of the result CSVs, which presumably hold the
    textual repr of scalar tensors.

    Args:
        x: A string of the form ``"tensor(<number>)"`` (optionally followed by
            extra arguments, e.g. ``"tensor(0.5, device='cuda:0')"``), a plain
            numeric string, or any object ``float()`` already accepts.

    Returns:
        The parsed value as a Python ``float``.

    Raises:
        ValueError: If no number can be parsed from a string input.
    """
    if not isinstance(x, str):
        # Already numeric (e.g. the CSV column was parsed as floats).
        return float(x)
    payload = x.replace("tensor(", "").replace(")", "")
    try:
        return float(payload)
    except ValueError:
        # The repr carried extra arguments; the value is the first field.
        return float(payload.split(",")[0])


def _minmax_scale(heatmap):
    """Rescale ``heatmap`` linearly to [0, 1]; a constant map becomes all zeros."""
    lowest, highest = heatmap.min(), heatmap.max()
    if highest == lowest:
        # Avoid 0/0 -> NaN, which would leak into the metrics and the plots.
        return np.zeros_like(heatmap)
    return (heatmap - lowest) / (highest - lowest)


def _load_scaled_heatmap(relevance_path, image_size):
    """Load a saved relevance array and turn it into a [0, 1] map of ``image_size``."""
    relevance = torch.tensor(np.load(relevance_path))
    resized = (
        torch.nn.functional.interpolate(
            relevance, size=image_size, mode="bilinear"
        )
        .detach()
        .numpy()[0]
    )
    # Keep positive relevance only, then sum over the leading axis
    # (presumably channels -- TODO confirm the saved array layout).
    return _minmax_scale(np.maximum(resized, 0).sum(axis=0))


def _draw_heatmap_overlay(axis, image, heatmap, gt_mask):
    """Draw image, heatmap and ground-truth mask on ``axis`` with axes hidden."""
    axis.imshow(image, cmap="gray")
    axis.imshow(heatmap, alpha=0.3, cmap="jet", vmin=0, vmax=1)
    axis.imshow(gt_mask, alpha=0.3, cmap="gray")
    axis.axis("off")


def _save_heatmap_overlay(image, heatmap, gt_mask, out_path):
    """Save a single overlay figure to ``out_path`` and close the figure."""
    fig, axis = plt.subplots()
    _draw_heatmap_overlay(axis, image, heatmap, gt_mask)
    fig.savefig(out_path, bbox_inches="tight", pad_inches=0)
    plt.close(fig)


def plot_heatmap(
    save_folder,
    dataset_folder,
    model,
    training_type,
    class_name,
    xai_method,
    images_save_folder=Path("images"),
    selected_patient=None,
):
    """Plot aligned / normal / misaligned heatmaps for one patient and score them.

    Three relevance maps are loaded for the same image: from the
    ``normal_mask`` and ``inverse_mask`` folders of ``finetuned_{model}``
    (reported as "positive"/aligned and "negative"/misaligned) and from the
    plain ``{model}`` folder ("normal"). Each map is resized to the image,
    clipped to positive values, summed over its leading axis and min-max
    scaled. Four PNG files are written to ``images_save_folder``: one overlay
    per map plus a stacked three-panel figure.

    Args:
        save_folder: Root folder (``Path``) of the experiment results.
        dataset_folder: Root folder (``Path``) holding images and mask files.
        model: Model name used in the result folder names.
        training_type: Sub-folder name of the training regime.
        class_name: Label name; spaces become ``_`` in the mask file name.
        xai_method: Attribution method name used in the file names.
        images_save_folder: Output folder for the figures; created if missing.
        selected_patient: Relative image path (string ending in ``.jpg``) to
            plot. When ``None``, the row with the highest ``mass_accuracy``
            in the ``normal_mask`` CSV is used.

    Returns:
        ``(best_mass_patient, accuracies)`` when ``selected_patient`` is
        ``None``, otherwise only ``accuracies``. ``accuracies`` maps
        ``"{training_type}_{model}_{class_name}_{xai_method}"`` to the mass
        and rank accuracy of the positive, negative and normal heatmaps.

    Raises:
        ValueError: If no row of the CSV has a usable ``mass_accuracy``.
    """
    finetuned_dir = save_folder / f"finetuned_{model}" / training_type
    normal_dir = finetuned_dir / "normal_mask"
    inverse_dir = finetuned_dir / "inverse_mask"

    if selected_patient is not None:
        best_mass_patient = selected_patient
    else:
        # The CSV is only needed to pick the patient.
        normal_data = pd.read_csv(
            normal_dir / f"_{class_name}_{xai_method}_output.csv"
        )
        normal_data["mass_accuracy"] = normal_data["mass_accuracy"].apply(
            tensor_to_float
        )
        # sort_values places NaN last, so drop NaN scores before taking the
        # final row; otherwise a NaN row would be picked as the "best" one.
        ranked = normal_data.dropna(subset=["mass_accuracy"]).sort_values(
            by="mass_accuracy"
        )
        if ranked.empty:
            raise ValueError(
                f"No valid mass_accuracy for {model}/{training_type}/"
                f"{class_name}/{xai_method}"
            )
        best_mass_patient = ranked.iloc[-1]["path"]

    image = plt.imread(dataset_folder / best_mass_patient)
    gt_mask = np.load(
        dataset_folder
        / best_mass_patient.replace(
            ".jpg", f"_{class_name.replace(' ', '_')}_mask.npy"
        )
    )
    # Only height/width are a valid target size, even for a colour image.
    image_size = tuple(image.shape[:2])

    relevance_name = best_mass_patient.replace(
        ".jpg", f"{class_name}_{xai_method}_relevance.npy"
    )
    heatmaps = {
        "positive": _load_scaled_heatmap(
            normal_dir / relevance_name, image_size
        ),
        "negative": _load_scaled_heatmap(
            inverse_dir / relevance_name, image_size
        ),
        "normal": _load_scaled_heatmap(
            save_folder / f"{model}" / training_type / relevance_name,
            image_size,
        ),
    }

    run_tag = f"{training_type}_{model}_{class_name}_{xai_method}"
    accuracies = {
        run_tag: {
            kind: {
                # Fresh tensors per call, as in the original implementation.
                "mass_accuracy": calculate_relevance_mass_accuracy(
                    torch.tensor(heatmap), torch.tensor(gt_mask)
                ),
                "rank_accuracy": calculate_relevance_rank_accuracy(
                    torch.tensor(heatmap), torch.tensor(gt_mask)
                ),
            }
            for kind, heatmap in heatmaps.items()
        }
    }

    out_dir = Path(images_save_folder)
    # savefig does not create missing folders.
    out_dir.mkdir(parents=True, exist_ok=True)

    panels = [
        ("aligned", "Aligned", "positive"),
        ("normal", "Normal", "normal"),
        ("misaligned", "Misaligned", "negative"),
    ]
    for file_prefix, _, kind in panels:
        _save_heatmap_overlay(
            image,
            heatmaps[kind],
            gt_mask,
            out_dir / f"{file_prefix}_heatmap_{run_tag}.png",
        )

    fig, axes = plt.subplots(3, 1, figsize=(15, 10))
    for axis, (_, title, kind) in zip(axes, panels):
        _draw_heatmap_overlay(axis, image, heatmaps[kind], gt_mask)
        axis.set_title(title)
    fig.savefig(
        out_dir / f"positive_normal_negative_heatmap_{run_tag}.png",
        bbox_inches="tight",
        pad_inches=0,
    )
    plt.close(fig)

    # Return shape depends on how the patient was chosen (kept for callers).
    if selected_patient is None:
        return best_mass_patient, accuracies
    return accuracies


In [None]:
from multiprocessing import Pool
from itertools import product
from tqdm import tqdm


def generate_plot(args):
    """Pool worker: render the figures for one parameter combination.

    Args:
        args: ``(class_name, xai_method, training_type, model)`` tuple.

    Returns:
        The accuracies dict returned by ``plot_heatmap`` for the
        automatically selected patient.
    """
    class_name, xai_method, training_type, model = args
    _best_patient, accuracies = plot_heatmap(
        save_folder=save_folder,
        dataset_folder=dataset_folder,
        model=model,
        training_type=training_type,
        class_name=class_name,
        xai_method=xai_method,
    )
    return accuracies


# Every (class, method, training type, model) combination to render.
args_list = list(product(classes, xai_methods, training_types, models))


def create_plots(processes=8):
    """Render all combinations in ``args_list`` with a process pool.

    Args:
        processes: Number of worker processes (defaults to the previously
            hard-coded 8).

    Returns:
        List with one ``generate_plot`` result per entry of ``args_list``,
        in the same order (``imap`` preserves input order).
    """
    with Pool(processes) as p:
        accs = list(
            tqdm(p.imap(generate_plot, args_list), total=len(args_list))
        )
    return accs


In [None]:
# Render every combination in parallel; `accs` holds one accuracies dict per
# entry of `args_list`, in the same order.
accs = create_plots()

In [None]:
from itertools import product
from tqdm import tqdm

# Serial version of the plot generation. Each call returns a one-entry dict
# keyed by "{training_type}_{model}_{class_name}_{xai_method}"; merge them so
# `accuracies` ends up populated (the original discarded every result).
accuracies = {}
_combinations = list(product(classes, xai_methods, training_types, models))
for class_name, xai_method, training_type, model in tqdm(
    _combinations, total=len(_combinations)
):
    _, acc = plot_heatmap(
        save_folder,
        dataset_folder,
        model,
        training_type,
        class_name,
        xai_method,
    )
    accuracies.update(acc)


In [None]:
!ls

# ViT + IG example attributions

In [None]:
# Per-sample metric table; the cells below read its `model`, `xai_method`,
# `pretraining`, `label` and `path` columns.
diff_data = pd.read_csv("../results/metrics_diff_results.csv")


In [None]:
# Model / attribution method this section focuses on; used both to filter
# `diff_data` and as arguments to the plotting calls below.
selected_model = "vit"
selected_attribution = "integrated_gradients"

In [None]:
# Keep only the rows that belong to the selected model and attribution method.
_matches_selection = (diff_data["model"] == selected_model) & (
    diff_data["xai_method"] == selected_attribution
)
diff_data = diff_data.loc[_matches_selection]


In [None]:
diff_data.shape

In [None]:
diff_data.columns

In [None]:
def _continuous_attribution_map(
    relevance_path, image_size, na_attribution_threshold
):
    """Load a relevance array and return a [0, 1] map with weak pixels set to -1."""
    relevance = torch.tensor(np.load(relevance_path))
    resized = (
        torch.nn.functional.interpolate(
            relevance, size=image_size, mode="bilinear"
        )
        .detach()
        .numpy()[0]
    )
    # Positive relevance only, summed over the leading axis
    # (presumably channels -- TODO confirm the saved array layout).
    heatmap = np.maximum(resized, 0).sum(axis=0)
    lowest, highest = heatmap.min(), heatmap.max()
    if highest == lowest:
        # Constant map: avoid 0/0 -> NaN.
        heatmap = np.zeros_like(heatmap)
    else:
        heatmap = (heatmap - lowest) / (highest - lowest)
    if na_attribution_threshold:
        # -1 lies below the plotted range (vmin=0), so these pixels are drawn
        # with the colormap's "under" colour.
        heatmap = np.where(heatmap < na_attribution_threshold, -1, heatmap)
    return heatmap


def _save_image_with_layers(image, layers, out_path):
    """Draw ``image`` in grey, stack ``layers`` on top, save and close the figure."""
    fig, axis = plt.subplots()
    axis.imshow(image, cmap="gray")
    for data, imshow_kwargs in layers:
        axis.imshow(data, **imshow_kwargs)
    axis.axis("off")
    fig.savefig(out_path, bbox_inches="tight", pad_inches=0)
    plt.close(fig)


def plot_heatmap(
    save_folder,
    model,
    training_type,
    class_name,
    xai_method,
    patient_path,
    mask_alpha=0.5,
    attribution_alpha=0.5,
    attribution_cmap="jet",
    mask_cmap="gray",
    na_attribution_threshold=0.01,
    images_save_folder="images",
):
    """Save aligned / misaligned attribution overlays for one chosen patient.

    NOTE: this definition replaces the earlier ``plot_heatmap`` from the
    "Attributions" section, which has a different signature. The image and
    mask are read from the module-level ``dataset_folder``.

    Five PNG files are written to ``images_save_folder``: the aligned
    (``normal_mask``) and misaligned (``inverse_mask``) attribution maps, each
    with and without the ground-truth mask, plus ``gt_mask.png``.

    Args:
        save_folder: Root folder (``Path``) of the experiment results.
        model: Model name used in the result folder names.
        training_type: Sub-folder name of the training regime.
        class_name: Label name; spaces become ``_`` in the mask file name.
        xai_method: Attribution method name used in the file names.
        patient_path: Relative image path (string ending in ``.jpg``).
        mask_alpha: Opacity of the ground-truth mask layer.
        attribution_alpha: Opacity of the attribution layer.
        attribution_cmap: Colormap of the attribution layer.
        mask_cmap: Colormap of the mask layer.
        na_attribution_threshold: Normalised attributions below this value
            are set to -1. A falsy value disables the thresholding.
        images_save_folder: Output folder (string or ``Path``); created if
            missing.
    """
    # Accept plain strings too: the default "images" could not be combined
    # with "/" before.
    out_dir = Path(images_save_folder)
    out_dir.mkdir(parents=True, exist_ok=True)

    image = plt.imread(dataset_folder / patient_path)
    gt_mask = np.load(
        dataset_folder
        / patient_path.replace(
            ".jpg", f"_{class_name.replace(' ', '_')}_mask.npy"
        )
    )
    # Only height/width are a valid target size, even for a colour image.
    image_size = tuple(image.shape[:2])

    finetuned_dir = save_folder / f"finetuned_{model}" / training_type
    relevance_name = patient_path.replace(
        ".jpg", f"{class_name}_{xai_method}_relevance.npy"
    )
    positive_heatmap = _continuous_attribution_map(
        finetuned_dir / "normal_mask" / relevance_name,
        image_size,
        na_attribution_threshold,
    )
    negative_heatmap = _continuous_attribution_map(
        finetuned_dir / "inverse_mask" / relevance_name,
        image_size,
        na_attribution_threshold,
    )

    attribution_kwargs = dict(
        alpha=attribution_alpha, cmap=attribution_cmap, vmin=0, vmax=1
    )
    positive_layer = (positive_heatmap, attribution_kwargs)
    negative_layer = (negative_heatmap, attribution_kwargs)
    mask_layer = (gt_mask, dict(alpha=mask_alpha, cmap=mask_cmap))

    run_tag = f"{training_type}_{model}_{class_name}_{xai_method}"
    figures = [
        (f"aligned_heatmap_{run_tag}.png", [positive_layer, mask_layer]),
        (f"aligned_heatmap_{run_tag}_without_mask.png", [positive_layer]),
        ("gt_mask.png", [mask_layer]),
        (f"misaligned_heatmap_{run_tag}.png", [negative_layer, mask_layer]),
        (f"misaligned_heatmap_{run_tag}_without_mask.png", [negative_layer]),
    ]
    for file_name, layers in figures:
        _save_image_with_layers(image, layers, out_dir / file_name)

In [None]:
from matplotlib.colors import ListedColormap
import seaborn as sns


In [None]:
# Alternative that was tried: attribution_cmap = plt.cm.jet.copy()
# Light-to-red sequential colormap for the attribution overlay.
attribution_cmap = sns.light_palette("red", as_cmap=True)
# Values below the plotted range are drawn fully transparent ("#FFFFFF00" has
# zero alpha). plot_heatmap sets sub-threshold pixels to -1 and plots with
# vmin=0, so those pixels pick up this colour.
attribution_cmap.set_under(color="#FFFFFF00")
# Two-entry colormap for the ground-truth mask: transparent first, yellow second.
mask_cmap = ListedColormap(["#FFFFFF00", "yellow"])
# Bare expression: displays the colormap as the cell output.
attribution_cmap


In [None]:
# Results root; same value as the `save_folder` set in the configuration cell
# near the top of the notebook.
save_folder = Path("../results")


In [None]:
# Hand-picked examples: positional row numbers into the filtered `diff_data`
# (used with `.iloc` below), so they are only meaningful for the current
# model/attribution filter and row order.
patients = [
    571,
    994,
    758,
]
# Output sub-folder name for each entry of `patients`, paired by position
# (presumably describing where the finding sits in the image -- verify).
patients_name = [
    "upper_right",
    "middle",
    "lower_left",
]

In [None]:
# Continuous-valued overlays for the hand-picked patients, one folder each.
for patient, patient_name in zip(patients, patients_name):
    cur_data = diff_data.iloc[patient]
    patient_folder = Path(f"vit_ig/{patient_name}")
    # savefig does not create missing folders, so make sure the target exists.
    patient_folder.mkdir(parents=True, exist_ok=True)
    plot_heatmap(
        save_folder,
        selected_model,
        cur_data.pretraining,
        cur_data.label,
        selected_attribution,
        patient_path=cur_data.path,
        attribution_cmap=attribution_cmap,
        mask_cmap=mask_cmap,
        na_attribution_threshold=0.1,
        mask_alpha=0.3,
        images_save_folder=patient_folder,
    )


In [None]:
def _binary_attribution_map(
    relevance_path, image_size, na_attribution_threshold
):
    """Load a relevance array and binarise its normalised map at the threshold."""
    relevance = torch.tensor(np.load(relevance_path))
    resized = (
        torch.nn.functional.interpolate(
            relevance, size=image_size, mode="bilinear"
        )
        .detach()
        .numpy()[0]
    )
    # Positive relevance only, summed over the leading axis
    # (presumably channels -- TODO confirm the saved array layout).
    heatmap = np.maximum(resized, 0).sum(axis=0)
    lowest, highest = heatmap.min(), heatmap.max()
    if highest == lowest:
        # Constant map: avoid 0/0 -> NaN, which would compare False against
        # the threshold and turn the whole binary map into ones.
        heatmap = np.zeros_like(heatmap)
    else:
        heatmap = (heatmap - lowest) / (highest - lowest)
    if na_attribution_threshold:
        heatmap = np.where(heatmap < na_attribution_threshold, 0, 1)
    return heatmap


def _save_binary_overlay(image, layers, out_path):
    """Draw ``image`` in grey, stack ``layers`` on top, save and close the figure."""
    fig, axis = plt.subplots()
    axis.imshow(image, cmap="gray")
    for data, imshow_kwargs in layers:
        axis.imshow(data, **imshow_kwargs)
    axis.axis("off")
    fig.savefig(out_path, bbox_inches="tight", pad_inches=0)
    plt.close(fig)


def plot_binary_heatmap(
    save_folder,
    model,
    training_type,
    class_name,
    xai_method,
    patient_path,
    mask_alpha=0.5,
    attribution_alpha=0.5,
    attribution_cmap="jet",
    mask_cmap="gray",
    na_attribution_threshold=0.01,
    images_save_folder="images",
):
    """Save thresholded (0/1) attribution overlays for one chosen patient.

    Same outputs and file names as ``plot_heatmap`` in the cell above, but the
    normalised attribution maps are binarised: pixels below
    ``na_attribution_threshold`` become 0 and all others 1. The image and mask
    are read from the module-level ``dataset_folder``.

    Five PNG files are written to ``images_save_folder``: the aligned
    (``normal_mask``) and misaligned (``inverse_mask``) maps, each with and
    without the ground-truth mask, plus ``gt_mask.png``.

    Args:
        save_folder: Root folder (``Path``) of the experiment results.
        model: Model name used in the result folder names.
        training_type: Sub-folder name of the training regime.
        class_name: Label name; spaces become ``_`` in the mask file name.
        xai_method: Attribution method name used in the file names.
        patient_path: Relative image path (string ending in ``.jpg``).
        mask_alpha: Opacity of the ground-truth mask layer.
        attribution_alpha: Opacity of the attribution layer.
        attribution_cmap: Colormap of the attribution layer.
        mask_cmap: Colormap of the mask layer.
        na_attribution_threshold: Binarisation threshold on the normalised
            map. A falsy value leaves the map continuous.
        images_save_folder: Output folder (string or ``Path``); created if
            missing.
    """
    # Accept plain strings too: the default "images" could not be combined
    # with "/" before.
    out_dir = Path(images_save_folder)
    out_dir.mkdir(parents=True, exist_ok=True)

    image = plt.imread(dataset_folder / patient_path)
    gt_mask = np.load(
        dataset_folder
        / patient_path.replace(
            ".jpg", f"_{class_name.replace(' ', '_')}_mask.npy"
        )
    )
    # Only height/width are a valid target size, even for a colour image.
    image_size = tuple(image.shape[:2])

    finetuned_dir = save_folder / f"finetuned_{model}" / training_type
    relevance_name = patient_path.replace(
        ".jpg", f"{class_name}_{xai_method}_relevance.npy"
    )
    positive_heatmap = _binary_attribution_map(
        finetuned_dir / "normal_mask" / relevance_name,
        image_size,
        na_attribution_threshold,
    )
    negative_heatmap = _binary_attribution_map(
        finetuned_dir / "inverse_mask" / relevance_name,
        image_size,
        na_attribution_threshold,
    )

    attribution_kwargs = dict(
        alpha=attribution_alpha, cmap=attribution_cmap, vmin=0, vmax=1
    )
    positive_layer = (positive_heatmap, attribution_kwargs)
    negative_layer = (negative_heatmap, attribution_kwargs)
    mask_layer = (gt_mask, dict(alpha=mask_alpha, cmap=mask_cmap))

    run_tag = f"{training_type}_{model}_{class_name}_{xai_method}"
    figures = [
        (f"aligned_heatmap_{run_tag}.png", [positive_layer, mask_layer]),
        (f"aligned_heatmap_{run_tag}_without_mask.png", [positive_layer]),
        ("gt_mask.png", [mask_layer]),
        (f"misaligned_heatmap_{run_tag}.png", [negative_layer, mask_layer]),
        (f"misaligned_heatmap_{run_tag}_without_mask.png", [negative_layer]),
    ]
    for file_name, layers in figures:
        _save_binary_overlay(image, layers, out_dir / file_name)

In [None]:
# Two-entry colormaps for the binary overlays; the first entry "#FFFFFF00" is
# white with zero alpha (transparent), the second is the visible colour.
# This replaces the seaborn-based `attribution_cmap` defined earlier.
attribution_cmap = ListedColormap(["#FFFFFF00", "red"])
mask_cmap = ListedColormap(["#FFFFFF00", "yellow"])

In [None]:
# Same hand-picked rows and folder names as in the continuous-heatmap section
# above: positional row numbers into the filtered `diff_data` (used with
# `.iloc` below).
patients = [
    571,
    994,
    758,
]
# Output sub-folder name for each entry of `patients`, paired by position.
patients_name = [
    "upper_right",
    "middle",
    "lower_left",
]

In [None]:
# Binary (thresholded) overlays for the hand-picked patients.
# NOTE(review): this writes the same file names into the same
# vit_ig/<patient_name> folders as the continuous-heatmap loop above, so those
# earlier images are overwritten -- confirm this is intended.
for patient, patient_name in zip(patients, patients_name):
    cur_data = diff_data.iloc[patient]
    patient_folder = Path(f"vit_ig/{patient_name}")
    # savefig does not create missing folders, so make sure the target exists.
    patient_folder.mkdir(parents=True, exist_ok=True)
    plot_binary_heatmap(
        save_folder,
        selected_model,
        cur_data.pretraining,
        cur_data.label,
        selected_attribution,
        patient_path=cur_data.path,
        attribution_cmap=attribution_cmap,
        mask_cmap=mask_cmap,
        na_attribution_threshold=0.1,
        mask_alpha=0.3,
        images_save_folder=patient_folder,
    )


In [None]:
# Metric rows of the hand-picked patients (positional selection).
patients_data = diff_data.iloc[patients]


In [None]:
# Store the selected rows next to the figures; the `vit_ig` folder must
# already exist, since to_csv does not create it.
patients_data.to_csv("vit_ig/selected_patients.csv", index=False)

# Dataset counts

## CheXpert

In [None]:
import pandas as pd

In [None]:
# Load the three split tables in one pass (train, val, test order).
train, val, test = (
    pd.read_csv(split_csv)
    for split_csv in (
        "../dataset/train_split.csv",
        "../dataset/val_split.csv",
        "../dataset/test_split.csv",
    )
)

In [None]:
# Merge the test rows into `val` so both are counted together below.
# NOTE(review): this overwrites `val` in place, so re-running the cell without
# reloading the CSVs appends the test rows a second time.
val = pd.concat([val, test], axis=0).reset_index(drop=True)

In [None]:
# Row-label abbreviation applied to the count tables.
classes_renames = {"Enlarged Cardiomediastinum": "Enl. Card."}
# Column headers for the label values 0 and 1.
columns_names = dict(zip((0, 1), ("Negative", "Positive")))

In [None]:
def get_val_counts(df, first_class_col):
    """Count how often each label value occurs per class column.

    Every column from ``first_class_col`` to the last column is treated as a
    class. Missing labels and ``-1`` are both counted as ``0``.

    Args:
        df: Label table with one column per class.
        first_class_col: Name of the first class column.

    Returns:
        DataFrame with one row per class (sorted by name, then renamed via the
        module-level ``classes_renames``) and one integer column per label
        value in ascending order (renamed via ``columns_names``).
    """
    labels = df.loc[:, first_class_col:].fillna(0).replace(-1, 0)
    counts = labels.apply(pd.Series.value_counts).T
    # A value that never occurs for a class comes back as NaN; it is a count
    # of zero, and filling it keeps the table integer-typed.
    counts = counts.fillna(0).astype(int)
    # value_counts orders by frequency, so fix the column order explicitly.
    counts = counts.sort_index().sort_index(axis=1)
    return counts.rename(index=classes_renames, columns=columns_names)


In [None]:
# Per-class label counts of the training split, for every column from
# "Enlarged Cardiomediastinum" onwards.
train_chexpert = get_val_counts(train, "Enlarged Cardiomediastinum")

In [None]:
# Same counts for `val`, which at this point also contains the test rows.
val_chexpert = get_val_counts(val, "Enlarged Cardiomediastinum")

## CheXlocalize

In [None]:
# CheXlocalize test labels; `df` is reused for the validation labels below.
df = pd.read_csv("../dataset/chexlocalize/CheXpert/test_labels.csv")

In [None]:
def get_val_counts(df, first_class_col):
    """Count how often each label value occurs per class column.

    Every column from ``first_class_col`` to the last column is treated as a
    class. Missing labels and ``-1`` are both counted as ``0``.

    Args:
        df: Label table with one column per class.
        first_class_col: Name of the first class column.

    Returns:
        DataFrame with one row per class (sorted by name, then renamed via the
        module-level ``classes_renames``) and one integer column per label
        value in ascending order (renamed via ``columns_names``).
    """
    labels = df.loc[:, first_class_col:].fillna(0).replace(-1, 0)
    counts = labels.apply(pd.Series.value_counts).T
    # A value that never occurs for a class comes back as NaN; it is a count
    # of zero, and filling it keeps the table integer-typed.
    counts = counts.fillna(0).astype(int)
    # value_counts orders by frequency, so fix the column order explicitly.
    counts = counts.sort_index().sort_index(axis=1)
    return counts.rename(index=classes_renames, columns=columns_names)


In [None]:
# Row-label abbreviation applied to the count tables (same mapping as in the
# CheXpert section).
classes_renames = {"Enlarged Cardiomediastinum": "Enl. Card."}
# Column headers for the label values 0 and 1.
columns_names = dict(zip((0, 1), ("Negative", "Positive")))

In [None]:
# Label counts for the CheXlocalize test labels, starting at "No Finding".
# This table is labelled "Fine-tuning" in the final LaTeX table header.
test_chexlocalize = get_val_counts(df, "No Finding")

In [None]:
# CheXlocalize validation labels (replaces the test labels held in `df`).
df = pd.read_csv("../dataset/chexlocalize/CheXpert/val_labels.csv")

In [None]:
# Label counts for the CheXlocalize validation labels.
val_chexlocalize = get_val_counts(df, "No Finding")

In [None]:
pd.MultiIndex.from_product([["CheXpert", "CheXlocalize"], ["Train", "Test"], ["Negative", "Positive"]])

In [None]:
# A label value that never occurs for a class can show up as NaN (and turn the
# column into floats). Originally only `val_chexlocalize` was filled and cast,
# so treat all four count tables the same way before joining.
_train_counts, _val_counts, _loc_test_counts, _loc_val_counts = (
    table.fillna(0).astype(int)
    for table in (
        train_chexpert,
        val_chexpert,
        test_chexlocalize,
        val_chexlocalize,
    )
)

full_df_chexpert = _train_counts.join(
    _val_counts, lsuffix="_train", rsuffix="_val"
)
full_df_chexlocalize = _loc_test_counts.join(
    _loc_val_counts, lsuffix="_test", rsuffix="_val"
)

# Left join: keep only the classes present in the CheXpert tables.
full_df = full_df_chexpert.join(
    full_df_chexlocalize,
    how="left",
    lsuffix="_chexpert",
    rsuffix="_chexlocalize",
)


In [None]:
full_df_chexpert

In [None]:
# Keep the flat joined column names before they are replaced by the
# MultiIndex below. NOTE(review): `_cols` is not read again in the visible
# part of the notebook.
_cols = full_df.columns

In [None]:
pd.MultiIndex.from_product([["CheXpert", "CheXlocalize"], ["Train", "Test"], ["Negative", "Positive"]])

In [None]:
# Three-level header (dataset, split, label value) for the LaTeX table, in the
# column order produced by the joins above. The dataset name is spelled
# "CheXpert", matching the rest of the notebook (it was "Chexpert").
new_cols = pd.MultiIndex.from_tuples(
    [
        (dataset, split, outcome)
        for dataset, splits in (
            ("CheXpert", ("Training", "Validation")),
            ("CheXlocalize", ("Fine-tuning", "Validation")),
        )
        for split in splits
        for outcome in ("Negative", "Positive")
    ]
)


In [None]:
# Replace the flat, suffixed column names with the three-level header; this
# relies on `full_df` having exactly as many columns as `new_cols`, in the
# same order.
full_df.columns = new_cols

In [None]:
# Emit the table as LaTeX source; print() keeps the line breaks readable for
# copy-pasting.
print(full_df.to_latex())