In [1]:
"""
Improved visualization script for MONAI transformations applied to medical images.
Includes robust error handling for different image types.
"""

import os
import numpy as np
import torch
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import monai
from monai.data import PILReader
from monai.transforms import (
    LoadImaged,
    AsChannelFirstd,
    AddChanneld,
    ScaleIntensityd,
    SpatialPadd,
    RandSpatialCropd,
    RandAxisFlipd,
    RandRotate90d,
    RandGaussianNoised,
    RandAdjustContrastd,
    RandGaussianSmoothd,
    RandHistogramShiftd,
    RandZoomd,
    EnsureTyped,
    Compose,
)
from skimage import io, img_as_ubyte
import traceback
import pickle
from tqdm import tqdm
from monai.data import Dataset, DataLoader
from PIL import Image


# Short alias for os.path.join.
join = os.path.join

In [2]:
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')


def _list_image_files(directory):
    """Return the sorted names of files in ``directory`` that end with a known image extension."""
    return sorted(f for f in os.listdir(directory) if f.endswith(_IMAGE_EXTENSIONS))


def _to_uint8_image(array):
    """Scale a float image from [0, 1] to uint8, saturating out-of-range values instead of wrapping."""
    return (np.clip(array, 0.0, 1.0) * 255).astype(np.uint8)


def main(
    img_path="../data/preprocessing_outputs/normalized_data/images",
    gt_path="../data/preprocessing_outputs/normalized_data/labels",
    target_path="../data/preprocessing_outputs/transformed_images_labels",
    input_size=256,
    num_workers=4,
    seed=None,
):
    """Apply the random augmentation pipeline to every image/label pair and save the results as PNGs.

    Images are listed from ``img_path`` and labels from ``gt_path``; both
    listings are sorted and paired by position. Each pair goes through the
    MONAI transform chain once, and the outputs are written to
    ``<target_path>/images/<name>.png`` and ``<target_path>/labels/<name>.png``,
    where ``<name>`` is the image file name without its extension.

    Args:
        img_path: Directory containing the input images (.png/.jpg/.jpeg).
        gt_path: Directory containing the label images (.png/.jpg/.jpeg).
        target_path: Output root; ``images`` and ``labels`` sub-directories
            are created inside it if missing.
        input_size: Spatial size used for padding and random cropping.
        num_workers: Number of worker processes used by the DataLoader.
        seed: Optional integer. When given, ``monai.utils.set_determinism`` is
            called with it before the pipeline is built so that runs can be
            repeated. Note that this changes global RNG/backends state.

    Raises:
        ValueError: If the number of images differs from the number of labels.
            Positional pairing would otherwise silently drop or misalign pairs.

    NOTE(review): pairing is positional; only the counts are validated here
    because the label naming convention is not visible in this file.
    NOTE(review): ``AddChanneld`` and ``AsChannelFirstd`` are deprecated in
    newer MONAI releases (``EnsureChannelFirstd`` is the documented
    replacement) - confirm against the installed MONAI version.
    """
    image_files = _list_image_files(img_path)
    label_files = _list_image_files(gt_path)

    # zip() would silently truncate to the shorter list, so fail loudly instead.
    if len(image_files) != len(label_files):
        raise ValueError(
            f"Found {len(image_files)} images in {img_path!r} but "
            f"{len(label_files)} labels in {gt_path!r}; cannot pair them by position."
        )
    if not image_files:
        print(f"No images found in {img_path}; nothing to do.")
        return

    if seed is not None:
        monai.utils.set_determinism(seed=seed)

    image_out_dir = os.path.join(target_path, "images")
    label_out_dir = os.path.join(target_path, "labels")
    os.makedirs(image_out_dir, exist_ok=True)
    os.makedirs(label_out_dir, exist_ok=True)

    train_transforms = Compose([
        LoadImaged(keys=["img", "label"], reader=PILReader, dtype=np.uint8),
        AddChanneld(keys=["label"], allow_missing_keys=True),
        AsChannelFirstd(keys=["img"], channel_dim=-1, allow_missing_keys=True),
        ScaleIntensityd(keys=["img"], allow_missing_keys=True),
        SpatialPadd(keys=["img", "label"], spatial_size=input_size),
        RandSpatialCropd(keys=["img", "label"], roi_size=input_size, random_size=False),
        RandAxisFlipd(keys=["img", "label"], prob=0.5),
        RandRotate90d(keys=["img", "label"], prob=0.5, spatial_axes=[0, 1]),
        # Intensity augmentations are applied to the image only.
        RandGaussianNoised(keys=["img"], prob=0.25, mean=0, std=0.1),
        RandAdjustContrastd(keys=["img"], prob=0.25, gamma=(1, 2)),
        RandGaussianSmoothd(keys=["img"], prob=0.25, sigma_x=(1, 2)),
        RandHistogramShiftd(keys=["img"], prob=0.25, num_control_points=3),
        RandZoomd(keys=["img", "label"], prob=0.15, min_zoom=0.8, max_zoom=1.5, mode=["area", "nearest"]),
        EnsureTyped(keys=["img", "label"]),
    ])

    data_dicts = [
        {
            "img": os.path.join(img_path, img),
            "label": os.path.join(gt_path, lbl),
            "name": os.path.splitext(img)[0],
        }
        for img, lbl in zip(image_files, label_files)
    ]

    dataset = Dataset(data=data_dicts, transform=train_transforms)
    loader = DataLoader(dataset, batch_size=1, num_workers=num_workers)

    for batch in loader:
        img_name = batch["name"][0]

        # Drop only the batch axis; a blanket squeeze() would also remove a
        # single-channel axis and break the channel-last conversion below.
        img_chw = batch["img"][0].numpy()
        transformed_img = np.moveaxis(img_chw, 0, -1)
        if transformed_img.shape[-1] == 1:
            # PIL expects 2-D arrays for single-channel images.
            transformed_img = transformed_img[..., 0]

        transformed_label = batch["label"].squeeze().numpy()

        # Noise/contrast/histogram augmentations can push intensities outside
        # [0, 1]; clip before the uint8 cast so they saturate instead of wrapping.
        Image.fromarray(_to_uint8_image(transformed_img)).save(
            os.path.join(image_out_dir, f"{img_name}.png")
        )
        # NOTE(review): the *255 scaling assumes labels take values in {0, 1};
        # larger label values would overflow uint8 - confirm the label encoding.
        Image.fromarray((transformed_label * 255).astype(np.uint8)).save(
            os.path.join(label_out_dir, f"{img_name}.png")
        )
        print(f"Saved {img_name}.png and corresponding label.")


In [3]:
# Run the augmentation export using main()'s built-in paths; one line is printed per saved pair.
main()



Saved cell_00001.png and corresponding label.
Saved cell_00002.png and corresponding label.
Saved cell_00003.png and corresponding label.
Saved cell_00004.png and corresponding label.
Saved cell_00005.png and corresponding label.
Saved cell_00006.png and corresponding label.
Saved cell_00007.png and corresponding label.
Saved cell_00008.png and corresponding label.
Saved cell_00009.png and corresponding label.
Saved cell_00010.png and corresponding label.
Saved cell_00011.png and corresponding label.
Saved cell_00012.png and corresponding label.
Saved cell_00013.png and corresponding label.
Saved cell_00014.png and corresponding label.
Saved cell_00015.png and corresponding label.
Saved cell_00016.png and corresponding label.
Saved cell_00017.png and corresponding label.
Saved cell_00018.png and corresponding label.
Saved cell_00019.png and corresponding label.
Saved cell_00020.png and corresponding label.
Saved cell_00021.png and corresponding label.
Saved cell_00022.png and correspon