In [None]:
import os
import sys
from io import StringIO

import matplotlib.pyplot as plt
import pandas as pd
from IPython.display import clear_output

from epsutils.dicom import dicom_utils
from epsutils.gcs import gcs_utils
from epsutils.image import image_utils

# File listing the rows to browse. It may be a GCS URI or a local path and must end
# in ".csv" or ".jsonl"; any other extension is rejected with ValueError further down.
INPUT_FILE_NAME = "/workspace/models/intern_vit_classifier-training-on-gradient_cr_airspace_opacity/pxp6bjmf/misclassified_epoch_4_20250227_114821_utc.jsonl"
# Column of the input file whose value is passed to the DICOM loader as the image path.
IMAGE_PATH_COLUMN_NAME = "file_name"
# Directory that image paths are made relative to; the resulting path components
# are used to look up the matching report row.
BASE_IMAGE_PATH = "/workspace/CR/22JUL2024"
# GCS URI of the reports CSV. It must provide the patient_id, accession_number,
# study_uid and report_text columns.
GCS_REPORTS_FILE = "gs://report_csvs/cleaned/CR/labels_for_binary_classification/GRADIENT_CR_22JUL2024_chest_with_image_paths_with_airspace_opacity_labels.csv"

# Read the raw text of the input file, either from GCS or from the local disk.
if gcs_utils.is_gcs_uri(INPUT_FILE_NAME):
    print(f"Downloading input file {INPUT_FILE_NAME}")
    gcs_data = gcs_utils.split_gcs_uri(INPUT_FILE_NAME)
    content = gcs_utils.download_file_as_string(gcs_bucket_name=gcs_data["gcs_bucket_name"], gcs_file_name=gcs_data["gcs_path"])
else:
    print(f"Loading input file {INPUT_FILE_NAME}")
    # Decode explicitly as UTF-8 so the result does not depend on the locale of the
    # machine the notebook happens to run on.
    with open(INPUT_FILE_NAME, "r", encoding="utf-8") as file:
        content = file.read()

# Parse the text into a DataFrame; the parser is picked from the file extension.
print("Converting input file to Pandas dataset")
if INPUT_FILE_NAME.endswith(".csv"):
    df = pd.read_csv(StringIO(content))
elif INPUT_FILE_NAME.endswith(".jsonl"):
    # lines=True: one JSON object per line.
    df = pd.read_json(StringIO(content), lines=True)
else:
    raise ValueError("Input file type not supported")

# Download the reports CSV from GCS as text.
print(f"Downloading reports file {GCS_REPORTS_FILE}")
gcs_data = gcs_utils.split_gcs_uri(GCS_REPORTS_FILE)
content = gcs_utils.download_file_as_string(gcs_bucket_name=gcs_data["gcs_bucket_name"], gcs_file_name=gcs_data["gcs_path"])

# Parse it and keep only the columns needed to look up a report by image path.
print("Converting reports file to Pandas dataset")
# The identifier columns are later compared against path components, which are
# always strings. Read them as text so that all-digit identifiers are not inferred
# as integers (which would never compare equal and would drop leading zeros).
reports_df = pd.read_csv(
    StringIO(content),
    dtype={"patient_id": str, "accession_number": str, "study_uid": str},
)
reports_df = reports_df[["patient_id", "accession_number", "study_uid", "report_text"]]

# Interactive browser: repeatedly reads a command, moves to the requested row of
# `df`, then shows that row, its image and the matching report text.
print("Commands:")
print("- row number (zero-based indexing) - go to the row number")
print("- '+' - go to the next row")
print("- '-' - go to the previous row")
print("- 'q' - quit")

# Start before the first row so that an initial '+' lands on row 0.
row_number = -1

while True:
    # Surrounding whitespace is ignored so that e.g. "+ " is still a valid command.
    selection = input("Enter command: ").strip()
    clear_output(wait=True)

    # A plain integer jumps straight to that row. Only ValueError is caught, so
    # interrupts and unrelated errors are not silently swallowed.
    try:
        selected_row_number = int(selection)
    except ValueError:
        selected_row_number = None

    if selected_row_number is not None:
        row_number = selected_row_number
    elif selection == "+":
        row_number += 1
    elif selection == "-":
        row_number -= 1
    elif selection == "q":
        break
    else:
        print("Invalid command")
        continue

    # With no rows there is nothing to index; report it instead of raising IndexError.
    if df.empty:
        print("Input file contains no rows")
        continue

    # Clamp to the valid range so out-of-range requests show the first/last row.
    row_number = max(0, min(row_number, len(df) - 1))
    row = df.iloc[row_number]

    # Load image. The message is flushed so it is visible while the image loads.
    print("Loading image...", flush=True)
    image_path = row[IMAGE_PATH_COLUMN_NAME]
    image = dicom_utils.get_dicom_image(image_path, custom_windowing_parameters={"window_center": 0, "window_width": 0})
    image = image_utils.numpy_array_to_pil_image(image)

    # Get report text. The path components below BASE_IMAGE_PATH are used as keys:
    # component 0 is the patient id, 1 the accession number and 3 the study UID.
    print("Looking for report text...")
    relative_path = os.path.relpath(image_path, BASE_IMAGE_PATH)
    # os.path.relpath returns a path separated by os.sep.
    segments = relative_path.split(os.sep)
    report_text = None
    # Paths with fewer than four components cannot identify a report.
    if len(segments) > 3:
        matching_reports = reports_df[
            (reports_df["patient_id"] == segments[0])
            & (reports_df["accession_number"] == segments[1])
            & (reports_df["study_uid"] == segments[3])
        ]
        if not matching_reports.empty:
            # If several rows match, the first one is shown.
            report_text = matching_reports["report_text"].iloc[0]

    print(f"Row number: {row_number}")
    print(f"Image path: {image_path}")
    print(row)

    plt.imshow(image)
    plt.show()

    print("Report text:")
    print(report_text)

print("Finished")
