In [None]:
import ast
import inspect
import math
import os
import sys
from io import StringIO

import matplotlib.pyplot as plt
import pandas as pd
import torch
from IPython.display import clear_output

from epsclassifiers.cr_chest_classifier import CrChestClassifier
from epsclassifiers.cr_projection_classifier import CrProjectionClassifier
from epsutils.dicom import dicom_utils
from epsutils.gcs import gcs_utils
from epsutils.image import image_utils

# Reports table to browse. May be a local path or a GCS URI and must end in
# ".csv" or ".jsonl". The literal is split only for readability; the pieces
# are concatenated into a single path.
INPUT_FILE_NAME = (
    "/mnt/efs/all-cxr/simonmed/batch1/"
    "simonmed_batch_1_reports_with_image_paths_filtered_standardized"
    "_with_dicom_data_mapped_modalities_mapped_body_parts"
    "_with_uncertain_labels_cleaned_unflagged.csv"
)

# Directory joined in front of every image path read from the table.
# Set to None to use the paths exactly as they are stored.
BASE_IMAGES_PATH = "/mnt/efs/all-cxr/simonmed/images/422ca224-a9f2-4c64-bf7c-bb122ae2a7bb/"

# Column holding the image path (or a list of image paths) of each row.
IMAGE_PATH_COLUMN_NAME = "filtered_image_paths"

# CSV parsing options; they are only used when the input file is a ".csv".
CSV_SEPARATOR = ","
CSV_HEADER = True

# Switches for the classifiers that are run on the images of the shown row.
RUN_PROJECTION_CLASSIFICATION = True
RUN_CHEST_NON_CHEST_CLASSIFICATION = True

# Width and height of the matplotlib figure used when a row has several images.
FIGURE_SIZE = 15

# Fetch the raw text of the input file, either from the local file system or from GCS.
if not gcs_utils.is_gcs_uri(INPUT_FILE_NAME):
    print(f"Loading input file {INPUT_FILE_NAME}")
    with open(INPUT_FILE_NAME, "r") as input_file:
        content = input_file.read()
else:
    print(f"Downloading input file {INPUT_FILE_NAME}")
    uri_parts = gcs_utils.split_gcs_uri(INPUT_FILE_NAME)
    content = gcs_utils.download_file_as_string(
        gcs_bucket_name=uri_parts["gcs_bucket_name"],
        gcs_file_name=uri_parts["gcs_path"],
    )

# Parse the text into a Pandas DataFrame; the parser is chosen by file extension.
print("Converting input file to Pandas dataset")
is_csv = INPUT_FILE_NAME.endswith(".csv")
is_jsonl = INPUT_FILE_NAME.endswith(".jsonl")

if not (is_csv or is_jsonl):
    raise ValueError("Input file type not supported")

text_buffer = StringIO(content)
if is_csv:
    # Row 0 is the header when CSV_HEADER is set; otherwise columns are numbered.
    header_row = 0 if CSV_HEADER else None
    df = pd.read_csv(text_buffer, low_memory=False, sep=CSV_SEPARATOR, header=header_row)
else:
    df = pd.read_json(text_buffer, lines=True)

# Build the projection classifier and load its weights. The weights file is
# looked up in the "models" directory next to the module defining the class.
if RUN_PROJECTION_CLASSIFICATION:
    print("Loading the projection classification model")
    projection_module_dir = os.path.dirname(inspect.getmodule(CrProjectionClassifier).__file__)
    MODEL_PATH = os.path.join(
        projection_module_dir,
        "models/cr_projection_classifier_trained_on_500k_gradient_samples.pt",
    )
    projection_classifier = CrProjectionClassifier()
    projection_state = torch.load(MODEL_PATH)
    projection_classifier.load_state_dict(projection_state)

# Build the chest/non-chest classifier the same way. MODEL_PATH is reused, so
# afterwards it points at the chest classifier weights.
if RUN_CHEST_NON_CHEST_CLASSIFICATION:
    print("Loading the chest/non-chest classification model")
    chest_module_dir = os.path.dirname(inspect.getmodule(CrChestClassifier).__file__)
    MODEL_PATH = os.path.join(
        chest_module_dir,
        "models/cr_chest_classifier_trained_on_600k_gradient_samples.pt",
    )
    chest_classifier = CrChestClassifier()
    chest_state = torch.load(MODEL_PATH)
    chest_classifier.load_state_dict(chest_state)

# Show the supported commands once, before the interactive loop starts.
command_help = (
    "Commands:",
    "- row number (zero-based indexing) - go to the row number",
    "- '+' - go to the next row",
    "- '-' - go to the previous row",
    "- 'q' - quit",
)
for help_line in command_help:
    print(help_line)

# Interactive viewer loop. Each iteration reads a command, moves to the
# requested row of `df`, runs the enabled classifiers on the images of that row
# and shows the row's fields together with the images. The loop ends on 'q'.

# Index of the row currently shown; starts at -1 so a first '+' lands on row 0.
row_number = -1

while True:
    # Surrounding whitespace is dropped so that e.g. "+ " is still recognized.
    selection = input("Enter command: ").strip()
    clear_output(wait=True)

    # A command that parses as an integer is an absolute, zero-based row number.
    # Only ValueError is caught so that KeyboardInterrupt still stops the loop.
    try:
        selected_row_number = int(selection)
    except ValueError:
        selected_row_number = None

    if selected_row_number is not None:
        row_number = selected_row_number
    elif selection == "+":
        row_number += 1
    elif selection == "-":
        row_number -= 1
    elif selection == "q":
        break
    else:
        print("Invalid command")
        continue

    # Without rows there is nothing to show, and df.iloc[0] would raise IndexError.
    if df.empty:
        print("Input dataset is empty")
        continue

    # Get row; out-of-range requests are clamped to the first/last row.
    row_number = max(0, min(row_number, len(df) - 1))
    row = df.iloc[row_number]

    # Load image data.
    print("Loading image data...")
    sys.stdout.flush()

    # The cell holds either one path or the text form of a list of paths
    # (e.g. "['a.dcm', 'b.dcm']"). literal_eval is tried first; when the cell
    # is not a Python literal the raw value is used as a single path.
    raw_image_paths = row[IMAGE_PATH_COLUMN_NAME]
    try:
        image_paths = ast.literal_eval(raw_image_paths)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        image_paths = raw_image_paths

    image_paths = list(image_paths) if isinstance(image_paths, (list, tuple)) else [image_paths]

    if BASE_IMAGES_PATH is not None:
        image_paths = [os.path.join(BASE_IMAGES_PATH, image_path) for image_path in image_paths]

    # NOTE(review): a window center/width of 0 presumably disables custom
    # windowing in dicom_utils - confirm against that helper.
    numpy_images = [
        dicom_utils.get_dicom_image_fail_safe(
            image_path,
            custom_windowing_parameters={"window_center": 0, "window_width": 0},
        )
        for image_path in image_paths
    ]

    # Run the enabled classifiers; a row without images is not sent to them.
    projections = None
    if RUN_PROJECTION_CLASSIFICATION and numpy_images:
        projections = projection_classifier.predict(images=numpy_images, device="cuda")

    chest_classification_results = None
    if RUN_CHEST_NON_CHEST_CLASSIFICATION and numpy_images:
        chest_classification_results = chest_classifier.predict(images=numpy_images, device="cuda")

    images = [image_utils.numpy_array_to_pil_image(numpy_image) for numpy_image in numpy_images]

    # Show info.
    print(f"Row number: {row_number}")

    for column, value in row.items():
        if isinstance(value, str):
            # Keep every field on one output line.
            value = value.replace("\n", " ")
        print(f"{column}: {value}")

    # Compare against None instead of relying on truthiness: truth-testing a
    # multi-element array or tensor raises, and an empty result is still a result.
    if projections is not None:
        print(f"Projections: {projections}")

    if chest_classification_results is not None:
        print(f"Chest/non-chest classification: {chest_classification_results}")

    # Show image data.
    if not images:
        # A grid with zero rows cannot be created, so report and wait for the next command.
        print("No images to show")
        continue

    if len(images) == 1:
        plt.imshow(images[0])
    else:
        cols = 2
        rows = math.ceil(len(images) / cols)
        # squeeze=False keeps `axs` two-dimensional even when there is a single row.
        fig, axs = plt.subplots(rows, cols, figsize=(FIGURE_SIZE, FIGURE_SIZE), squeeze=False)
        for ax, image in zip(axs.flat, images):
            ax.imshow(image)
        # Hide the frame and ticks of every cell, including the spare cell that
        # is left over when the number of images is odd.
        for ax in axs.flat:
            ax.axis("off")
    plt.tight_layout()
    plt.show()

print("Finished")