# Full code to localize, crop, and save ROIs and their label masks

In [3]:
import os
import numpy as np
import cv2
import torch
import matplotlib.pyplot as plt
from scipy.ndimage import label
from scipy.spatial.distance import euclidean
from statistics import median
from PIL import Image

# Function to process a single slice
def process_slice(im, net):
    """Locate the centre of the largest predicted foreground region in a 2-D slice.

    The slice is resized to 256x256, scaled by its maximum, passed through *net*,
    and the centroid of the largest connected component predicted as class 1 or 2
    is mapped back into the coordinate frame of the original slice.

    Args:
        im: 2-D image array of shape (H, W).
        net: callable taking a (1, 1, 256, 256) float32 tensor. Its output is assumed
            to be class scores of shape (1, C, 256, 256), inferred from ``pred[0]`` and
            the argmax over dim 0. TODO confirm against the network definition.

    Returns:
        ``(center_x, center_y)`` in original-pixel coordinates (x = column, y = row),
        or ``None`` when no pixel is predicted as class 1 or 2.
    """
    og_y, og_x = im.shape
    # cv2.resize rejects some dtypes (e.g. int32, bool), and channel views taken from a
    # stacked array are non-contiguous, so work on a contiguous float32 copy.
    im_float = np.ascontiguousarray(im, dtype=np.float32)
    im_resized = cv2.resize(im_float, (256, 256), interpolation=cv2.INTER_LINEAR)

    # Scale to [0, 1] by the maximum. An all-zero (or non-positive) slice would
    # otherwise divide by zero and feed NaN/inf into the network.
    peak = im_resized.max()
    im_normalized = im_resized / peak if peak > 0 else np.zeros_like(im_resized)
    inputd = torch.from_numpy(im_normalized.astype(np.float32, copy=False))[None, None, :, :]  # add batch and channel dims

    with torch.no_grad():
        pred = net(inputd)
    # argmax over raw scores equals argmax over their softmax, so softmax is skipped.
    seg = torch.argmax(pred[0], dim=0).cpu().numpy()

    roi_mask = (seg == 1) | (seg == 2)  # foreground = classes 1 and 2
    labeled_array, num_features = label(roi_mask)
    if num_features == 0:
        return None

    # Pixel count per label in a single pass; label 0 is background and is excluded.
    sizes = np.bincount(labeled_array.ravel())
    sizes[0] = 0
    # Ties go to the highest label, matching max() over (size, label) tuples.
    largest_component = int(np.flatnonzero(sizes == sizes.max())[-1])

    roi_coords = np.argwhere(labeled_array == largest_component)
    center_y, center_x = roi_coords.mean(axis=0)
    scale_y, scale_x = og_y / im_resized.shape[0], og_x / im_resized.shape[1]
    # cv2.resize aligns pixel centres: resized index d maps to source coordinate
    # (d + 0.5) * scale - 0.5, not d * scale.
    return (center_x + 0.5) * scale_x - 0.5, (center_y + 0.5) * scale_y - 0.5

# Function to draw ROI and corresponding label masks
def draw_roi_with_labels(image, labels, center, roi_size=128):
    """Crop a fixed-size square ROI centred on *center* from an image and its label mask.

    Args:
        image: array whose first two axes are (H, W); any trailing axes are kept.
        labels: array with the same leading (H, W) layout as *image*.
        center: ``(x, y)`` position in pixel coordinates (x = column, y = row).
        roi_size: side length of the square ROI in pixels.

    Returns:
        ``(cropped_image, cropped_labels)``, each of shape ``(roi_size, roi_size, ...)``
        with the input dtype. Parts of the window outside the image are zero-filled
        rather than clipped, so every ROI has the same shape even at the border.
    """
    half_size = roi_size // 2
    # floor rather than int(): int() truncates toward zero, which places windows
    # inconsistently when the centre lies within half_size of the top/left edge.
    start_y = int(np.floor(center[1] - half_size))
    start_x = int(np.floor(center[0] - half_size))
    cropped_image = _crop_padded(image, start_y, start_x, roi_size)
    cropped_labels = _crop_padded(labels, start_y, start_x, roi_size)
    return cropped_image, cropped_labels


def _crop_padded(array, start_y, start_x, size):
    """Return array[start_y:start_y+size, start_x:start_x+size], zero-filling out-of-bounds pixels."""
    out = np.zeros((size, size) + array.shape[2:], dtype=array.dtype)
    height, width = array.shape[:2]
    # Overlap between the requested window and the image, in image coordinates.
    y0, y1 = max(start_y, 0), min(start_y + size, height)
    x0, x1 = max(start_x, 0), min(start_x + size, width)
    if y1 > y0 and x1 > x0:  # the window may miss the image entirely
        out[y0 - start_y:y1 - start_y, x0 - start_x:x1 - start_x] = array[y0:y1, x0:x1]
    return out

# Main function to process all cases
def _group_files_by_case(input_dir):
    """Return {case_id: [filename, ...]} for the .npy slices in *input_dir*, each list sorted by slice number.

    The case id is the text between the first ``Case_`` and the next underscore. The
    slice number is the integer after the first ``_slice_``; for example
    ``..._Case_N006_slice_3_...npy`` gives ("N006", 3). Files lacking either pattern
    are reported and skipped instead of raising.
    """
    import re

    entries = {}
    # Sorted listing so the case order and tie-breaking do not depend on os.listdir order.
    for name in sorted(os.listdir(input_dir)):
        if not name.endswith(".npy"):
            continue
        case_match = re.search(r"Case_([^_]*)", name)
        slice_match = re.search(r"_slice_(\d+)", name)
        if case_match is None or slice_match is None:
            print(f"Skipping file {name}: cannot parse case id and slice number.")
            continue
        entries.setdefault(case_match.group(1), []).append((int(slice_match.group(1)), name))
    return {case_id: [name for _, name in sorted(items)] for case_id, items in entries.items()}


def _correct_centers(centers, threshold=0):
    """Replace missing or outlying slice centres with the per-case median centre.

    Args:
        centers: list of ``(x, y)`` tuples or ``None``, one entry per slice.
        threshold: maximum Euclidean distance from the median centre a centre may
            have before it is replaced. ``None`` uses twice the standard deviation of
            those distances. With 0, every centre not exactly at the median is
            replaced, i.e. all slices of the case share the median centre.

    Returns:
        A new list aligned with *centers*. If no centre was found, it is returned
        unchanged (all ``None``).
    """
    valid = [c for c in centers if c is not None]
    if not valid:
        return list(centers)
    median_center = (median([c[0] for c in valid]), median([c[1] for c in valid]))
    if threshold is None:
        threshold = 2 * np.std([euclidean(c, median_center) for c in valid])
    print(f"Outlier detection threshold: {threshold:.2f}")
    return [
        median_center if (c is None or euclidean(c, median_center) > threshold) else c
        for c in centers
    ]


def _save_overlay(cropped_roi, cropped_labels, overlay_path):
    """Save a side-by-side PNG: the cropped image, and the image with labels overlaid."""
    fig, axs = plt.subplots(1, 2, figsize=(12, 6))
    try:
        axs[0].imshow(cropped_roi, cmap="gray", interpolation="none")
        axs[0].set_title("Cropped CMR Image")
        axs[0].axis("off")
        axs[1].imshow(cropped_roi, cmap="gray", interpolation="none")
        axs[1].imshow(cropped_labels, cmap="jet", alpha=0.5, interpolation="none")
        axs[1].set_title("CMR with Labels Overlay")
        axs[1].axis("off")
        fig.savefig(overlay_path, bbox_inches="tight")
    finally:
        plt.close(fig)  # always release the figure, even if saving fails


def process_cases(input_dir, output_rois_dir, output_labels_dir, net, output_visual_dir, threshold=0):
    """Localize, crop, and save the ROI of every slice in *input_dir*, grouped by case.

    Each ``.npy`` file is indexed as ``data[:, :, 0]`` (image) and ``data[:, :, 1]``
    (label mask), so it is assumed to hold an (H, W, C) array with C >= 2. Per case:

    1. every slice is localized independently with ``process_slice(image, net)``;
    2. missing or outlying centres are replaced by the case's median centre (see
       *threshold*);
    3. a ROI is cropped around each centre with ``draw_roi_with_labels``. The image
       crop is saved to *output_rois_dir*, the label crop to *output_labels_dir*
       (both under the input filename), and an overlay PNG ``<file>_overlay.png``
       goes to *output_visual_dir*.

    Args:
        threshold: maximum distance from the median centre before a slice centre is
            replaced. The default 0 makes all slices of a case use the median centre;
            ``None`` uses 2 x the standard deviation of the distances.
    """
    # All three output directories are created; savefig does not create missing directories.
    for directory in (output_rois_dir, output_labels_dir, output_visual_dir):
        os.makedirs(directory, exist_ok=True)

    for case_id, case_files in _group_files_by_case(input_dir).items():
        print(f"Processing case: {case_id}")

        # Step 1: load the slices and localize each one independently.
        images, label_masks, centers = [], [], []
        for file in case_files:
            data = np.load(os.path.join(input_dir, file))
            cmr_image, labels = data[:, :, 0], data[:, :, 1]
            images.append(cmr_image)
            label_masks.append(labels)
            centers.append(process_slice(cmr_image, net))

        # Step 2: median centerization across the case's slices.
        corrected_centers = _correct_centers(centers, threshold)

        # Step 3: crop ROIs, save outputs, and write the overlay visualisation.
        for file, cmr_image, labels, corrected_center in zip(case_files, images, label_masks, corrected_centers):
            if corrected_center is None:
                print(f"Skipping slice {file} due to missing corrected center.")
                continue

            cropped_roi, cropped_labels = draw_roi_with_labels(cmr_image, labels, corrected_center)
            np.save(os.path.join(output_rois_dir, file), cropped_roi)
            np.save(os.path.join(output_labels_dir, file), cropped_labels)
            _save_overlay(cropped_roi, cropped_labels, os.path.join(output_visual_dir, f"{file}_overlay.png"))

            print(f"Processed slice: {file}")


    


In [5]:
# Input and output directories
input_dir = r"Set this to the directory where your slices are stored"
output_rois_dir = r"Set this to the directory where you want to store the localized slices "
output_labels_dir = r"Set this to the directory where you want to store the localized labels "
output_visual_dir =  r"Set this to the directory where you want to store the the visualisations of the localized data "

# Load the MONAI model
import monai
parser = monai.bundle.load_bundle_config("..", "train.json")
net = parser.get_parsed_content("network_def")
# weights_only=True restricts unpickling to tensors and plain containers. This is
# safer and avoids torch's weights_only FutureWarning echoed in this cell's output.
# map_location="cpu" lets a GPU-saved checkpoint load on a CPU-only machine; the
# inference code builds CPU tensors (torch.from_numpy) anyway.
net.load_state_dict(torch.load("../models/model.pt", map_location="cpu", weights_only=True))
net.eval()
print("Model loaded and ready for inference.")

# Process all cases
process_cases(input_dir, output_rois_dir, output_labels_dir, net, output_visual_dir)

  net.load_state_dict(torch.load("../models/model.pt"))


Model loaded and ready for inference.
Processing case: N006
Outlier detection threshold: 0.00
Processed slice: EMIDEC_EMIDEC_Case_N006_slice_1_NoInf_NoReflowN.npy
Processed slice: EMIDEC_EMIDEC_Case_N006_slice_2_NoInf_NoReflowN.npy
Processed slice: EMIDEC_EMIDEC_Case_N006_slice_3_NoInf_NoReflowN.npy
Processed slice: EMIDEC_EMIDEC_Case_N006_slice_4_NoInf_NoReflowN.npy
Processed slice: EMIDEC_EMIDEC_Case_N006_slice_5_NoInf_NoReflowN.npy
Processed slice: EMIDEC_EMIDEC_Case_N006_slice_6_NoInf_NoReflowN.npy
Processed slice: EMIDEC_EMIDEC_Case_N006_slice_7_NoInf_NoReflowN.npy
Processed slice: EMIDEC_EMIDEC_Case_N006_slice_8_NoInf_NoReflowN.npy
Processed slice: EMIDEC_EMIDEC_Case_N006_slice_9_NoInf_NoReflowN.npy
Processing case: N012
Outlier detection threshold: 0.00
Processed slice: EMIDEC_EMIDEC_Case_N012_slice_1_NoInf_NoReflowN.npy
Processed slice: EMIDEC_EMIDEC_Case_N012_slice_2_NoInf_NoReflowN.npy
Processed slice: EMIDEC_EMIDEC_Case_N012_slice_3_NoInf_NoReflowN.npy
Processed slice: EMIDE