In [1]:
import cv2
import numpy as np
import pandas as pd
import nibabel as nib
import tensorflow as tf
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

import sys
import shutil
from pathlib import Path
from typing import Union, List

sys.path.append("..")

from plotting import window_image
from models.pix2pix import Generator, Discriminator
from run_management import get_path_of_directory_with_id
from preprocessing import min_max_normalization, ct_denormalization

tf.get_logger().setLevel("ERROR")

2023-07-24 19:31:39.598134: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-07-24 19:31:41.157571: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


# Utils

In [2]:
def ones_normalization(x):
    return (x - 0.5) * 2.0


def ones_back_normalization(x):
    return (x / 2.0) + 0.5


def min_max_back_normalization(x, x_min, x_max):
    return (x * (x_max - x_min)) + x_min


def compute_metrics(real_image, fake_image):
    return {
        "psnr": tf.image.psnr(real_image, fake_image, max_val=1.0).numpy(),
        "ssim": tf.image.ssim(real_image, fake_image, max_val=1.0).numpy(),
        "ssim_multiscale": tf.image.ssim_multiscale(
            real_image, fake_image, max_val=1.0
        ).numpy(),
    }


def load_models(weights_dir, out_channels, out_activation):
    generator = Generator(out_channels, out_activation)
    discriminator = Discriminator()

    generator_optimizer = tf.keras.optimizers.Adam(1e-5, beta_1=0.5)
    discriminator_optimizer = tf.keras.optimizers.Adam(1e-5, beta_1=0.5)

    ckpt = tf.train.Checkpoint(
        generator_optimizer=generator_optimizer,
        discriminator_optimizer=discriminator_optimizer,
        generator=generator,
        discriminator=discriminator,
    )
    ckpt.restore(tf.train.latest_checkpoint(weights_dir))
    return generator, discriminator


def save_translations(save_dir: Union[str, Path], generator: tf.keras.Model):
    save_dir = Path(save_dir) if isinstance(save_dir, str) else save_dir
    figures_dir = save_dir / "figures"
    figures_dir.mkdir(exist_ok=True)
    niftis_dir = save_dir / "niftis"
    niftis_dir.mkdir(exist_ok=True)

    examples_dict = {
        "train_026": 45,
        "train_035": 45,
        "train_044": 99,
        "train_058": 104,
        "test_022": 69,
        "test_037": 10
    }
    for i, (patient_id, lesion_idx) in enumerate(examples_dict.items()):
        fig, axs = plt.subplots(figsize=(20, 20))
        examples_dir = Path().absolute().parent / "examples"
        adc_path = examples_dir / f"{patient_id}/{patient_id}_adc.nii.gz"
        ncct_path = examples_dir / f"{patient_id}/{patient_id}_ncct.nii.gz"
        mask_path = examples_dir / f"{patient_id}/masks/{patient_id}_r1_mask.nii.gz"

        ncct = nib.load(ncct_path)
        ncct_arr = ncct.get_fdata()
        ncct_arr_win = window_image(ncct_arr, 40, 80)
        ncct_min, ncct_max = np.min(ncct_arr), np.max(ncct_arr)
        ncct_arr = min_max_normalization(ncct_arr)
        ncct_arr_res = tf.image.resize(ncct_arr, [256, 256]).numpy()
        ncct_arr_win = min_max_normalization(ncct_arr_win)
        ncct_arr_win_res = tf.image.resize(ncct_arr_win, [256, 256]).numpy()

        adc = nib.load(adc_path)
        adc_arr = adc.get_fdata()
        adc_min, adc_max = np.min(adc_arr), np.max(adc_arr)
        adc_arr = min_max_normalization(adc_arr)
        adc_arr_res = tf.image.resize(adc_arr, [256, 256]).numpy()

        # Compute the fake image.
        adc_fake_arr = generator.predict(
            ncct_arr_res.transpose(2, 0, 1)[..., np.newaxis], verbose=0
        )
        adc_fake_arr = adc_fake_arr.transpose(1, 2, 0, 3)[..., 0]

        # Create the plot.
        idxs = np.array([lesion_idx + (i * 6) for i in range(-1, 2)])
        metrics_slices = []
        ncct_slices, adc_fake_slices, adc_slices = [], [], []
        for idx in idxs:
            ncct_slices.append(np.rot90(ncct_arr_win_res[..., idx]))
            adc_fake_slices.append(np.rot90(adc_fake_arr[..., idx]))
            adc_slices.append(np.rot90(adc_arr_res[..., idx]))
            metrics = compute_metrics(
                adc_arr_res[..., idx : idx + 1], adc_fake_arr[..., idx : idx + 1]
            )
            metrics_slices.append(
                "PSNR: {:.3f}, SSIM: {:.3f}".format(metrics["psnr"], metrics["ssim"])
            )
        ncct_plot = np.vstack(ncct_slices)
        adc_fake_plot = np.vstack(adc_fake_slices)
        adc_plot = np.vstack(adc_slices)
        final_plot = np.hstack([ncct_plot, adc_fake_plot, adc_plot])
        axs.imshow(final_plot, cmap="gray")
        axs.annotate(
            "NCCT", (128, 25), ha="center", va="center", fontsize=32, color="white"
        )
        axs.annotate(
            "ADC$_{fake}$",
            (128 + 256 * 1, 25),
            ha="center",
            va="center",
            fontsize=32,
            color="white",
        )
        axs.annotate(
            "ADC",
            (128 + 256 * 2, 25),
            ha="center",
            va="center",
            fontsize=32,
            color="white",
        )
        axs.annotate(
            metrics_slices[0],
            (128 + 256 * 1, 50),
            ha="center",
            va="center",
            fontsize=24,
            color="white",
        )
        axs.annotate(
            metrics_slices[1],
            (128 + 256 * 1, 50 + 256),
            ha="center",
            va="center",
            fontsize=24,
            color="white",
        )
        axs.annotate(
            metrics_slices[2],
            (128 + 256 * 1, 50 + 512),
            ha="center",
            va="center",
            fontsize=24,
            color="white",
        )
        axs.axis("off")
        plt.savefig(
            figures_dir / f"{patient_id}.png", bbox_inches="tight", pad_inches=0
        )
        plt.close()

        # Save the niftis.
        shutil.copy(adc_path, niftis_dir)
        shutil.copy(ncct_path, niftis_dir)
        adc_fake_arr = tf.image.resize(adc_fake_arr, adc_arr.shape[:2]).numpy()
        adc_fake_arr = min_max_back_normalization(adc_fake_arr, adc_min, adc_max)
        adc_fake = nib.Nifti1Image(
            adc_fake_arr.astype(np.float32), adc.affine, adc.header
        )
        nib.save(adc_fake, niftis_dir / f"{patient_id}_adc_fake.nii.gz")


def compute_metrics_for_patients(
    patients_dirs: List[Union[str, Path]],
    experiment_ids: List[Union[str, int]],
    results_dir: Union[str, Path] = "results",
):  
    metrics = ["psnr", "ssim", "ssim_multiscale"]
    for experiment_id in experiment_ids:
        experiment_dir = Path(
            get_path_of_directory_with_id(experiment_id, results_dir=results_dir)
        )
        weights_dir = experiment_dir / "weights"
        generator, discriminator = load_models(weights_dir, 1, "sigmoid")

        metrics_dict = {"patient_id": [], "slice_index": [], "num_lesion_pixels": []}
        metrics_dict = {**metrics_dict, **{metric: [] for metric in metrics}}
        metrics_dict = {
            **metrics_dict,
            **{f"{metric}_lesion": [] for metric in metrics},
        }
        for patient_dir in tqdm(
            patients_dirs, desc=f"Predicting for patients of experiment {experiment_id}"
        ):
            if isinstance(patient_dir, str):
                patient_dir = Path(patient_dir)

            patient_id = patient_dir.name
            adc_path = patient_dir / f"{patient_id}_adc.nii.gz"
            ncct_path = patient_dir / f"{patient_id}_ncct.nii.gz"
            mask_path = patient_dir / f"masks/{patient_id}_r1_mask.nii.gz"

            ncct = nib.load(ncct_path)
            ncct_arr = ncct.get_fdata()
            ncct_arr_win = window_image(ncct_arr, 40, 80)
            ncct_min, ncct_max = np.min(ncct_arr), np.max(ncct_arr)
            ncct_arr = min_max_normalization(ncct_arr)
            ncct_arr_res = tf.image.resize(ncct_arr, [256, 256]).numpy()
            ncct_arr_win = min_max_normalization(ncct_arr_win)
            ncct_arr_win_res = tf.image.resize(ncct_arr_win, [256, 256]).numpy()

            adc = nib.load(adc_path)
            adc_arr = adc.get_fdata()
            adc_min, adc_max = np.min(adc_arr), np.max(adc_arr)
            adc_arr = min_max_normalization(adc_arr)
            adc_arr_res = tf.image.resize(adc_arr, [256, 256]).numpy()

            # Compute the fake image.
            adc_fake_arr = generator.predict(
                ncct_arr_res.transpose(2, 0, 1)[..., np.newaxis], verbose=0
            )
            adc_fake_arr = adc_fake_arr.transpose(1, 2, 0, 3)[..., 0]

            # Compute metrics.
            assert (
                adc_arr_res.shape == adc_fake_arr.shape
            ), "Real and Fake ADC shapes do not match."

            if mask_path.exists():
                mask = nib.load(mask_path)
                mask_arr = mask.get_fdata()
                mask_arr_res = tf.image.resize(
                    mask_arr, [256, 256], method="nearest"
                ).numpy()

            for i in range(adc_fake_arr.shape[-1]):
                slice_metrics = compute_metrics(
                    adc_arr_res[..., i : i + 1], adc_fake_arr[..., i : i + 1]
                )
                metrics_dict["patient_id"].append(patient_id)
                metrics_dict["slice_index"].append(i)
                if mask_path.exists():
                    metrics_dict["num_lesion_pixels"].append(
                        np.count_nonzero(mask_arr_res[..., i : i + 1])
                    )
                else:
                    metrics_dict["num_lesion_pixels"].append(0)
                for metric in metrics:
                    metrics_dict[metric].append(slice_metrics[metric])
                    metrics_dict[f"{metric}_lesion"].append(slice_metrics[metric])

        pd.DataFrame(metrics_dict).to_csv(
            str(experiment_dir / "evaluation/test_metrics.csv"), index=False
        )


# Compute metrics

In [3]:
pd.read_csv("/home/sangohe/projects/lesion-aware-translation/examples/examples_desc.csv")[["patient_id", "slice_index"]].set_index("patient_id").transpose().to_dict()

{
    'train_026': 45,
    'train_035': 45,
    'train_044': 99,
    'train_058': 104,
    'test_022': 69,
    'test_037': 10
}

{'train_026': 45,
 'train_035': 45,
 'train_044': 99,
 'train_058': 104,
 'test_022': 69,
 'test_037': 10}

In [4]:
patients_txt = "/home/sangohe/projects/lesion-aware-translation/data/APIS_synth-1_0_3_0_10_0_shuffled/test_patients.txt"
patients_dirs = [Path(line.strip()) for line in open(patients_txt, "r").readlines()]

results_dir = "/home/sangohe/projects/lesion-aware-translation/results"
experiment_ids = [7, 8]  # 2 -> dilated weights

compute_metrics_for_patients(patients_dirs, experiment_ids, results_dir)


2023-07-24 19:31:57.902935: E tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:266] failed call to cuInit: CUDA_ERROR_INVALID_DEVICE: invalid device ordinal
2023-07-24 19:31:57.903003: I tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:168] retrieving CUDA diagnostic information for host: 1660b6c49a51
2023-07-24 19:31:57.903019: I tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:175] hostname: 1660b6c49a51
2023-07-24 19:31:57.903132: I tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:199] libcuda reported version is: NOT_FOUND: was unable to find libcuda.so DSO loaded into this program
2023-07-24 19:31:57.903177: I tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:203] kernel reported version is: 525.105.17


Predicting for patients of experiment 7:   0%|          | 0/38 [00:00<?, ?it/s]

Predicting for patients of experiment 8:   0%|          | 0/38 [00:00<?, ?it/s]

# Create figures

In [5]:
results_dir = "/home/sangohe/projects/lesion-aware-translation/results"
experiment_ids = [7, 8]
for experiment_id in experiment_ids:
    experiment_dir = Path(get_path_of_directory_with_id(experiment_id, results_dir=results_dir))
    weights_dir = experiment_dir / "weights"
    generator, discriminator = load_models(weights_dir, 1, "sigmoid")
    save_translations(experiment_dir / "evaluation", generator)

In [None]:
results_dir = "/home/sangohe/projects/lesion-aware-translation/results"
experiment_ids = [2] # 2 -> dilated weights
for experiment_id in experiment_ids:
    experiment_dir = Path(get_path_of_directory_with_id(experiment_id, results_dir=results_dir))
    weights_dir = experiment_dir / "weights"
    generator, discriminator = load_models(weights_dir, 1, "sigmoid")

# Todos

- [ ] create a plot with the validation samples
- [ ] compute the psnr and ssim for the test samples (and test patches)

# Read the images from different evaluation folders and plot them side by side

In [6]:
results_dir = "/home/sangohe/projects/lesion-aware-translation/results"
experiment_ids = [7, 8]
patient_ids = ["train_026", "train_035", "train_044", "train_058"]

comparison_plots = []
for patient_id in patient_ids:
    figures = []
    for experiment_id in experiment_ids:
        experiment_dir = Path(get_path_of_directory_with_id(experiment_id, results_dir=results_dir))
        figures_dir = experiment_dir / "evaluation" / "figures"
        figure = cv2.imread(str(figures_dir / f"{patient_id}.png"))
        figure = cv2.cvtColor(figure, cv2.COLOR_BGR2RGB)
        figure = np.vstack([np.zeros((100, figure.shape[1], 3)).astype(np.uint8), figure])
        figures.append(figure)
        figures.append(np.ones((figure.shape[0], 10, 3)).astype(np.uint8) * 255)
    figures.pop()
    comparison_plots.append(np.hstack(figures))

figures_dir = Path("/home/sangohe/projects/lesion-aware-translation/figures")
for i, (experiment_id, patient_id, comparison_plot) in enumerate(zip(experiment_ids, patient_ids, comparison_plots)):
    plt.imshow(comparison_plot)
    annotation = experiment_dir.name.split("-")[3]
    annotation = "normal" if not "weights" in annotation else annotation
    plt.annotate("Weights", (256*3, 45), ha="center", va="center", color="white", fontsize=5)
    plt.annotate("Dilated Weights", (256*9, 45), ha="center", va="center", color="white", fontsize=5)
    plt.annotate("None", (256*15, 45), ha="center", va="center", color="white", fontsize=5)
    plt.axis("off")
    plt.savefig(figures_dir / f"{patient_id}_models_comparison_new.png", bbox_inches="tight", pad_inches=0, dpi=500)
    plt.close()

In [7]:
figures_dir = Path("/home/sangohe/projects/lesion-aware-translation/figures")

for patient_id in patient_ids[:-1]:
    figure_path = figures_dir / f"{patient_id}_models_comparison.png"
    figure_path_new = figures_dir / f"{patient_id}_models_comparison_new.png"
    figure = cv2.imread(str(figure_path))
    figure_new = cv2.imread(str(figure_path_new))

    new_figure = np.vstack([figure, (np.ones((5, figure_new.shape[1], 3)) * 255).astype(np.uint8), figure_new])
    cv2.imwrite(str(figures_dir / f"{patient_id}_full_comparison.png"), new_figure)


[ WARN:0@2372.208] global loadsave.cpp:244 findDecoder imread_('/home/sangohe/projects/lesion-aware-translation/figures/train_026_models_comparison.png'): can't open/read file: check file path/integrity


ValueError: all the input arrays must have same number of dimensions, but the array at index 0 has 2 dimension(s) and the array at index 1 has 3 dimension(s)

In [8]:
np.ones((100, figure_new.shape[1])).shape

(100, 2480)