In [None]:
import math
import os
import sys
from io import StringIO

import matplotlib.pyplot as plt
import pandas as pd
from IPython.display import clear_output

from epsutils.dicom import dicom_utils
from epsutils.gcs import gcs_utils
from epsutils.image import image_utils

# Viewer configuration. Edit these values before running the cell.

# Input table to browse: a local path or a GCS URI, ending in ".csv" or ".jsonl".
INPUT_FILE_NAME = "/mnt/efs/all-cxr/segmed/batch1/chest_non_chest_classificaton_results.csv"
# Prefix joined onto every image path read from the table; set to None to use paths as-is.
BASE_IMAGES_PATH = "/mnt/efs/all-cxr/segmed/batch1"
# Key used to read the image path(s) from a row (an integer label when the CSV has no header).
IMAGE_PATH_COLUMN_NAME = 0
# Field separator passed to pandas when the input is a CSV file.
CSV_SEPARATOR = ";"
# True if the first CSV line is a header row; False to treat every line as data.
CSV_HEADER = False
# Image paths are made relative to this directory to derive the report lookup keys.
REPORTS_BASE_IMAGES_PATH = "/workspace/CR"
# GCS URI of the reports CSV; None disables report lookup.
GCS_REPORTS_FILE = None
# Width and height passed as figsize when several images are shown in a grid.
FIGURE_SIZE = 15

# Read the raw input text: local paths are opened directly, GCS URIs are downloaded.
if not gcs_utils.is_gcs_uri(INPUT_FILE_NAME):
    print(f"Loading input file {INPUT_FILE_NAME}")
    with open(INPUT_FILE_NAME, "r") as input_file:
        content = input_file.read()
else:
    print(f"Downloading input file {INPUT_FILE_NAME}")
    gcs_data = gcs_utils.split_gcs_uri(INPUT_FILE_NAME)
    content = gcs_utils.download_file_as_string(
        gcs_bucket_name=gcs_data["gcs_bucket_name"],
        gcs_file_name=gcs_data["gcs_path"],
    )

# Parse the text into a DataFrame; the parser is chosen from the file extension.
print("Converting input file to Pandas dataset")
content_buffer = StringIO(content)
if INPUT_FILE_NAME.endswith(".jsonl"):
    df = pd.read_json(content_buffer, lines=True)
elif INPUT_FILE_NAME.endswith(".csv"):
    # header=0 uses the first line as column names; header=None keeps it as data.
    header_row = 0 if CSV_HEADER else None
    df = pd.read_csv(content_buffer, low_memory=False, sep=CSV_SEPARATOR, header=header_row)
else:
    raise ValueError("Input file type not supported")

# Optionally fetch the reports table from GCS; without it no report text is shown.
if GCS_REPORTS_FILE is None:
    reports_df = None
else:
    print(f"Downloading reports file {GCS_REPORTS_FILE}")
    gcs_data = gcs_utils.split_gcs_uri(GCS_REPORTS_FILE)
    content = gcs_utils.download_file_as_string(
        gcs_bucket_name=gcs_data["gcs_bucket_name"],
        gcs_file_name=gcs_data["gcs_path"],
    )

    # Parse the CSV and keep only the lookup keys plus the report text itself.
    print("Converting reports file to Pandas dataset")
    report_columns = ["patient_id", "accession_number", "study_uid", "report_text"]
    reports_df = pd.read_csv(StringIO(content), low_memory=False)[report_columns]

# Show the available commands before entering the interactive loop.
command_help = (
    "Commands:",
    "- row number (zero-based indexing) - go to the row number",
    "- '+' - go to the next row",
    "- '-' - go to the previous row",
    "- 'q' - quit",
)
for help_line in command_help:
    print(help_line)

# Start just before the first row so that an initial '+' lands on row 0.
row_number = -1

# Interactive viewer loop: read a command, move to the requested row, then show
# the row's image(s), the row contents and (when available) the report text.
while True:
    selection = input("Enter command: ")
    clear_output(wait=True)

    # A numeric command jumps straight to that row. Only ValueError is caught so
    # that unrelated errors (e.g. KeyboardInterrupt) are not silently swallowed.
    try:
        selected_row_number = int(selection)
    except ValueError:
        selected_row_number = None

    if selected_row_number is not None:
        row_number = selected_row_number
    elif selection == "+":
        row_number += 1
    elif selection == "-":
        row_number -= 1
    elif selection == "q":
        break
    else:
        print("Invalid command")
        continue

    # Nothing can be displayed for an empty dataset; df.iloc[0] would raise IndexError.
    if df.empty:
        print("Input dataset is empty")
        continue

    # Get row, clamping the requested index into the valid range.
    row_number = max(0, min(row_number, len(df) - 1))
    row = df.iloc[row_number]

    # Load image data.
    print("Loading image data...")
    sys.stdout.flush()
    image_paths = row[IMAGE_PATH_COLUMN_NAME]
    # A row may hold a single path or a list of paths; normalize to a list.
    image_paths = image_paths if isinstance(image_paths, list) else [image_paths]

    if BASE_IMAGES_PATH is not None:
        image_paths = [os.path.join(BASE_IMAGES_PATH, path) for path in image_paths]

    images = [
        image_utils.numpy_array_to_pil_image(
            dicom_utils.get_dicom_image(path, custom_windowing_parameters={"window_center": 0, "window_width": 0})
        )
        for path in image_paths
    ]

    # The report lookup and the summary below use the last image of the row.
    # Set it explicitly so a row without images cannot reuse a path from a previous row.
    image_path = image_paths[-1] if image_paths else None

    # Get report text.
    if reports_df is not None and image_path is not None:
        print("Looking for report text...")
        # NOTE(review): assumes the path relative to REPORTS_BASE_IMAGES_PATH has the patient id,
        # accession number and study UID at segments 1, 2 and 4 - confirm against the data layout.
        relative_path = os.path.relpath(image_path, REPORTS_BASE_IMAGES_PATH)
        segments = relative_path.split("/")
        matches_image = (
            (reports_df["patient_id"] == segments[1])
            & (reports_df["accession_number"] == segments[2])
            & (reports_df["study_uid"] == segments[4])
        )
        filtered_df = reports_df[matches_image]
        report_text = filtered_df["report_text"].iloc[0] if not filtered_df.empty else None
    else:
        report_text = None

    print(f"Row number: {row_number}")
    print(f"Image path: {image_path}")
    print(row)

    # Show image data.
    if images:
        if len(images) == 1:
            plt.imshow(images[0])
        else:
            cols = 2
            rows = math.ceil(len(images) / cols)
            # squeeze=False always returns a 2-D array of axes, even for a single grid row.
            _, axs = plt.subplots(rows, cols, figsize=(FIGURE_SIZE, FIGURE_SIZE), squeeze=False)
            grid_axes = axs.ravel()
            # Hide every cell first so an unused trailing cell (odd image count) stays blank.
            for ax in grid_axes:
                ax.axis("off")
            for ax, image in zip(grid_axes, images):
                ax.imshow(image)
        plt.tight_layout()
        plt.show()
    else:
        print("No images to display for this row")

    print("Report text:")
    print(report_text)

print("Finished")