In [None]:
import sys
from io import StringIO

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from IPython.display import clear_output

from epsutils.gcs import gcs_utils
from xrv_segmentor import XrvSegmentor, BodyPart

# --- Configuration ---------------------------------------------------------
# Table of images to browse: either a GCS URI or a local path. The extension
# selects the parser below (".csv" or ".jsonl").
INPUT_FILE_NAME = "gs://gradient-crs/archive/training/individual-labels/cardiomegaly/gradient-crs-all-batches-chest-images-with-standard-cardiomegaly-label-training.jsonl"
# Names of the columns holding the image path and the labels of each row.
IMAGE_PATH_COLUMN_NAME = "image_path"
LABELS_COLUMN_NAME = "labels"
# Substring replacements (old -> new) applied to every image path before the
# image is handed to the segmentor.
PATH_SUBSTITUTIONS = {"GRADIENT-DATABASE/CR/": "/mnt/efs/all-cxr/gradient/"}
# Used for both the width and the height of the matplotlib figure.
FIGURE_SIZE = 15

# Display matplotlib backend.
print(f"Using the following matplotlib backend: {matplotlib.get_backend()}")

# Read the whole input file into memory as text, from GCS or from local disk.
if gcs_utils.is_gcs_uri(INPUT_FILE_NAME):
    print(f"Downloading input file {INPUT_FILE_NAME}")
    gcs_data = gcs_utils.split_gcs_uri(INPUT_FILE_NAME)
    content = gcs_utils.download_file_as_string(
        gcs_bucket_name=gcs_data["gcs_bucket_name"],
        gcs_file_name=gcs_data["gcs_path"],
    )
else:
    print(f"Loading input file {INPUT_FILE_NAME}")
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding.
    with open(INPUT_FILE_NAME, "r", encoding="utf-8") as input_file:
        content = input_file.read()

# Parse the text into a Pandas DataFrame; the parser is chosen by extension.
print("Converting input file to Pandas dataset")
if INPUT_FILE_NAME.endswith(".csv"):
    df = pd.read_csv(StringIO(content), low_memory=False)
elif INPUT_FILE_NAME.endswith(".jsonl"):
    # lines=True: one JSON object per line.
    df = pd.read_json(StringIO(content), lines=True)
else:
    raise ValueError("Input file type not supported")

# Create the segmentor once; the viewer loop reuses this instance for every row.
print("Instantiating segmentor")
segmentor = XrvSegmentor()

# Show the command reference, then flush so it is visible before the first
# input prompt appears.
command_help = (
    "Commands:",
    "- row number (zero-based indexing) - go to the row number",
    "- '+' - go to the next row",
    "- '-' - go to the previous row",
    "- 'q' - quit",
)
for help_line in command_help:
    print(help_line)
sys.stdout.flush()

def _draw_panel(ax, image, title, mask=None):
    """Render one viewer panel on *ax*.

    Draws *image* in grayscale, optionally overlays *mask* as a translucent
    "jet" layer, sets *title* and hides the axis decorations.
    """
    ax.imshow(image, cmap="gray")
    if mask is not None:
        # Scale by 255 and cast to uint8 before overlaying.
        # NOTE(review): presumably mask values lie in [0, 1] - confirm against
        # XrvSegmentor.
        ax.imshow((mask * 255).astype(np.uint8), cmap="jet", alpha=0.2)
    ax.set_title(title)
    ax.axis("off")


# Interactive viewer: read a command, move to the requested row, run the
# segmentor on that row's image and plot the image next to the lungs and
# heart overlays. Repeats until the user enters 'q'.

# Fail early with a clear message: an empty table would otherwise surface as
# an IndexError from df.iloc on the first command.
if len(df) == 0:
    raise ValueError("Input file contains no rows")

# No row has been shown yet; starting at -1 makes an initial '+' land on row 0.
row_number = -1
# Figure drawn for the previous command, closed before the next one is created
# so figures do not pile up on backends that keep them open after show().
fig = None

while True:
    # Surrounding whitespace is ignored so that e.g. "q " still quits.
    selection = input("Enter command: ").strip()
    clear_output(wait=True)

    if selection == "q":
        break
    if selection == "+":
        row_number += 1
    elif selection == "-":
        row_number -= 1
    else:
        # Anything else must be an absolute row number.
        try:
            row_number = int(selection)
        except ValueError:
            print("Invalid command")
            continue

    # Clamp to the valid range, so out-of-range requests show the first or
    # last row instead of failing.
    row_number = max(0, min(row_number, len(df) - 1))

    # Get image path and labels.
    row = df.iloc[row_number]
    image_path = row[IMAGE_PATH_COLUMN_NAME]
    labels = row[LABELS_COLUMN_NAME]

    for key, value in PATH_SUBSTITUTIONS.items():
        image_path = image_path.replace(key, value)

    print(f"Row number: {row_number}")
    print(f"Image path: {image_path}")
    print(f"Labels: {labels}")

    # Run segmentation. Flush first so the progress message is shown while
    # the segmentor is still working.
    print("Running segmentation")
    sys.stdout.flush()
    result = segmentor.segment(image=image_path, body_parts=[BodyPart.LUNGS, BodyPart.HEART])
    image = result["image"]
    masks = result["segmentation_masks"]

    print("Segmentation complete, rendering images")
    sys.stdout.flush()

    # Plot results: plain image, lungs overlay, heart overlay.
    if fig is not None:
        plt.close(fig)
    fig, axs = plt.subplots(1, 3, figsize=(FIGURE_SIZE, FIGURE_SIZE))

    # masks[0] / masks[1] are treated as lungs / heart.
    # NOTE(review): presumably masks follow the order of body_parts - confirm.
    _draw_panel(axs[0], image, "Image")
    _draw_panel(axs[1], image, "Lungs Segmentation", mask=masks[0])
    _draw_panel(axs[2], image, "Heart Segmentation", mask=masks[1])

    fig.tight_layout()
    plt.show()

print("Finished")