In [None]:
import os
import sys
import pandas as pd
import matplotlib.pyplot as plt
import torch
from dotenv import load_dotenv
from monai.transforms import LoadImaged

# Make the parent directory (expected to be the project root) importable so
# that packages living under 'src' can be imported from this notebook.
# The path is resolved relative to the current working directory.
module_path = os.path.abspath(os.pardir)
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

from src.data.transforms import RidgeletTransformd

# Use one matplotlib style sheet for every figure drawn in this notebook.
NOTEBOOK_STYLE = 'seaborn-v0_8-whitegrid'
plt.style.use(NOTEBOOK_STYLE)

# Signal that the imports and path setup above ran without raising.
print("Setup complete.")

In [None]:
# Load environment variables from the .env file one directory above the
# current working directory (expected to be the project root).
load_dotenv(dotenv_path='../.env')

# Root folder of the image files and root folder of the project's data.
IMAGE_ROOT_DIR = os.getenv("MIMIC_CXR_P_FOLDERS_PATH")
PROJECT_DATA_FOLDER_PATH = os.getenv("PROJECT_DATA_FOLDER_PATH")

# Fail fast and name exactly which variables are unset or empty, so the user
# does not have to guess which entry of the .env file is wrong.
_required_env = {
    "MIMIC_CXR_P_FOLDERS_PATH": IMAGE_ROOT_DIR,
    "PROJECT_DATA_FOLDER_PATH": PROJECT_DATA_FOLDER_PATH,
}
_missing_env = [name for name, value in _required_env.items() if not value]
if _missing_env:
    raise ValueError(
        "Please ensure MIMIC_CXR_P_FOLDERS_PATH, PROJECT_DATA_FOLDER_PATH are set in your .env file. "
        f"Missing or empty: {', '.join(_missing_env)}."
    )

# Only confirm that the values were found; the paths themselves are not
# printed so they do not end up in saved notebook output.
print("Image root directory loaded.")
print("Project data folder path loaded.")

In [None]:
def _first_present(record, *candidates):
    """Return the first non-missing value stored under any of ``candidates``.

    Args:
        record: Mapping of column name to value for a single row.
        *candidates: Column names to try, in priority order.

    Returns:
        The value of the first candidate key that exists in ``record`` and is
        not NaN/None.

    Raises:
        KeyError: If no candidate key holds a usable value.
    """
    for key in candidates:
        if key in record and pd.notna(record[key]):
            return record[key]
    raise KeyError(f"None of the columns {candidates} hold a value for the selected sample.")


# --- 1. Define paths and load the validation split ---
split_folder_name = "split_2000"
val_csv_path = os.path.join(PROJECT_DATA_FOLDER_PATH, "splits", split_folder_name, "validation.csv")
df_val = pd.read_csv(val_csv_path)

# --- 2. Construct the full path to the metadata CSV file ---
metadata_dir = os.getenv("MIMIC_CXR_METADATA_PATH")
# Without this check an unset variable surfaces as an opaque TypeError from
# os.path.join(None, ...).
if not metadata_dir:
    raise ValueError("Please ensure MIMIC_CXR_METADATA_PATH is set in your .env file.")
metadata_filename = "mimic-cxr-2.0.0-metadata.csv"
full_metadata_path = os.path.join(metadata_dir, metadata_filename)

# --- 3. Load the metadata and merge with the validation split ---
print(f"Loading metadata from: {full_metadata_path}")
if not os.path.exists(full_metadata_path):
    raise FileNotFoundError(f"Error: The metadata file was not found at the expected path: {full_metadata_path}")

df_meta = pd.read_csv(full_metadata_path)
# Left merge keeps every validation row. If the validation CSV already has
# 'subject_id'/'study_id' columns, pandas suffixes them '_x' (validation) and
# '_y' (metadata); otherwise the merged columns keep their plain names.
df_merged = pd.merge(df_val, df_meta[['dicom_id', 'subject_id', 'study_id']], on='dicom_id', how='left')

# --- 4. Select a sample that specifically has a fracture ---
# Filter the DataFrame for records where the 'fracture' column is 1.
if 'fracture' not in df_merged.columns:
    raise KeyError("The validation split has no 'fracture' column. Cannot select a sample.")
df_fractures = df_merged[df_merged['fracture'] == 1]

# Check if any fracture cases exist in the validation set.
if df_fractures.empty:
    raise ValueError("No fracture cases were found in the validation split. Cannot select a sample.")

# Select the first record from the filtered list of fractures.
sample_record = df_fractures.iloc[0].to_dict()

# Prefer the validation-split value ('_x'), then the unsuffixed column, then
# the metadata value ('_y'), so the lookup works whether or not the merge had
# to disambiguate duplicate column names.
subject_id = str(int(_first_present(sample_record, 'subject_id_x', 'subject_id', 'subject_id_y')))
study_id = str(int(_first_present(sample_record, 'study_id_x', 'study_id', 'study_id_y')))
dicom_id = str(sample_record['dicom_id'])

# Construct the relative path based on the standard MIMIC-CXR-JPG directory structure:
# p<first two digits of subject>/p<subject>/s<study>/<dicom>.jpg
image_relative_path = os.path.join(
    f"p{subject_id[:2]}",
    f"p{subject_id}",
    f"s{study_id}",
    f"{dicom_id}.jpg"
)
full_image_path = os.path.join(IMAGE_ROOT_DIR, image_relative_path)

# Verify the file really exists before reporting success, so a wrong root
# folder or ID is reported here rather than later inside the image loader.
if not os.path.isfile(full_image_path):
    raise FileNotFoundError(f"Error: The image file was not found at the expected path: {full_image_path}")

# --- 5. Prepare the dictionary for MONAI ---
sample_dict = {"image": full_image_path}

print(f"\nSuccessfully located image with a fracture for visualization:\n{full_image_path}")

In [None]:
from monai.transforms import Compose, LoadImaged, Resized, EnsureChannelFirstd
import torch

# Side length, in pixels, of the square image handed to the Ridgelet transform.
# NOTE(review): the original note says the FRT is most efficient with
# power-of-2 sizes - confirm against the RidgeletTransformd implementation.
SQUARE_SIZE = 512

# 1. Preprocessing steps, applied in order to the "image" entry:
#    - LoadImaged reads the file referenced by the path,
#    - EnsureChannelFirstd puts a channel axis first (e.g. [1, H, W]),
#    - Resized rescales the spatial dimensions to SQUARE_SIZE x SQUARE_SIZE.
preprocess_steps = [
    LoadImaged(keys=["image"]),
    EnsureChannelFirstd(keys=["image"]),
    Resized(keys=["image"], spatial_size=(SQUARE_SIZE,) * 2),
]
preprocess_pipeline = Compose(preprocess_steps)

# Run the preprocessing on the sample selected earlier.
loaded_and_resized_dict = preprocess_pipeline(sample_dict)

# 2. Build the Ridgelet transform and
# 3. apply it to the square, channel-first image.
ridgelet_transformer = RidgeletTransformd(keys=["image"], threshold_ratio=0.2)
transformed_dict = ridgelet_transformer(loaded_and_resized_dict)

# Keep handles to both versions of the image for the comparison plot:
# the preprocessed input and the output of the Ridgelet transform.
original_image_tensor = loaded_and_resized_dict['image']
transformed_image_tensor = transformed_dict['image']

for shape_label, image_tensor in (
    ("Original (resized) image shape:", original_image_tensor),
    ("Transformed image shape:", transformed_image_tensor),
):
    print(shape_label, image_tensor.shape)

In [None]:
# Drop singleton dimensions (the channel axis) and convert to NumPy so that
# matplotlib can render each image as a 2-D grayscale array.
original_image_np, transformed_image_np = (
    image_tensor.squeeze().numpy()
    for image_tensor in (original_image_tensor, transformed_image_tensor)
)

# Side-by-side comparison: input on the left, Ridgelet output on the right.
fig, axes = plt.subplots(1, 2, figsize=(16, 8))

panels = (
    (original_image_np, 'Original Image'),
    (transformed_image_np, 'Reconstructed Image (after Ridgelet)'),
)
for ax, (image_np, panel_title) in zip(axes, panels):
    # The array is transposed before display - presumably because the loader
    # returns the two spatial axes swapped relative to matplotlib's
    # (row, column) convention; confirm against the image reader in use.
    ax.imshow(image_np.T, cmap='gray')
    ax.set_title(panel_title, fontsize=16)
    ax.axis('off')

plt.tight_layout()
plt.show()