In [122]:
import onnxruntime as rt
import pydicom
import numpy as np
import torch
from torchvision import transforms
import os
from PIL import Image, ImageDraw, ImageOps



In [3]:
# Load the model from the ONNX file
model = rt.InferenceSession('model.onnx')
input_name = model.get_inputs()[0].name

In [4]:
t = torch.nn.Sequential(
    transforms.Resize((512, 512)),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
)
scripted_transforms = torch.jit.script(t)

In [5]:
input_name

'input.1'

In [6]:
def get_first_of_dicom_field_as_int(x):
    if type(x) == pydicom.multival.MultiValue:
        return int(x[0])
    return int(x)

In [7]:
def window_image(img, window_center, window_width, intercept, slope):
    """
    Get windowed image from dicom

    Inputs:
        - original image
        - window_center
        - window_width
        - intercept
        - slope
    """
    img = img * slope + intercept
    img_min = window_center - window_width // 2
    img_max = window_center + window_width // 2
    img[img < img_min] = img_min
    img[img > img_max] = img_max
    return img


In [160]:
def overlay_text_on_dicom(image_path, text):
    # Load the DICOM image
    ds = pydicom.dcmread(image_path)

    # Apply VOI LUT to get the actual pixel data
    pixel_data = np.zeros((512,512))

    # Create a PIL image from the pixel data
    img = Image.fromarray(pixel_data)

    # Draw text on the image
    draw = ImageDraw.Draw(img)
    draw.text((10, 10), text, fill="red")

    # Save the modified image back to DICOM format
    img_array = np.array(img)
    ds.PixelData = img_array.tobytes()
    # ds.save_as(image_path.replace('.dcm', '_overlay.dcm'))

    return Image.fromarray(ds.pixel_array).convert("RGB")
    

In [161]:
def preprocess(dcm):

    metadata = {
        "intercept": dcm.RescaleIntercept,
        "slope": dcm.RescaleSlope,
    }

    images = []
    for window_center, window_width in [[40, 80], [80, 200], [600, 2800]]:
        metadata["window_center"] = window_center
        metadata["window_width"] = window_width
        metadata = {k: get_first_of_dicom_field_as_int(v) for k, v in metadata.items()}
        # print("Shp:", img_dicom.pixel_array.shape)
        img = window_image(dcm.pixel_array, **metadata)
        images.append(img)

    stacked = np.stack(images).astype(np.float32)
    stacked = stacked[:,:,:,np.newaxis]
    stacked = np.transpose(stacked, (3, 0, 1, 2))
    return torch.from_numpy(stacked)


In [162]:
def process_image(file):
    dcm = pydicom.dcmread(file)
    input = preprocess(dcm)
    input = scripted_transforms(input).cpu().numpy()

    # Execute the inference via ONNX Runtime
    outputs = model.run(None, {input_name: input})

    outputs = np.argmax(outputs[0])

    text = "Positive" if outputs == 1 else "Negative"

    # Normalize the background (input) image
    dcmpixel = dcm.pixel_array
    background = 255 * ( 1.0 / dcmpixel.max() * (dcmpixel - dcmpixel.min()) )
    background = background.astype(np.ubyte)
    background_image = Image.fromarray(background).convert("RGB")

    overlay_image = overlay_text_on_dicom(file, text)
    # overlay_image = ImageOps.colorize(overlay_image, black="black", white='yellow')


    # Blend the two images
    final_image = Image.blend(overlay_image, background_image, 0.75)
    final_array = np.array(final_image).astype(np.uint8) 

    print(final_array.shape)

    return final_array

    # # Write the final image back to a new DICOM (color) image 
    # dcm.SeriesInstanceUID = series_uid
    # dcm.SOPInstanceUID = generate_uid()
    # dcm.SeriesNumber = dcm.SeriesNumber + settings["series_offset"]
    # dcm.file_meta.TransferSyntaxUID = pydicom.uid.ExplicitVRLittleEndian
    # dcm.Rows = final_image.height
    # dcm.Columns = final_image.width
    # dcm.PhotometricInterpretation = "RGB"
    # dcm.SamplesPerPixel = 3
    # dcm.BitsStored = 8
    # dcm.BitsAllocated = 8
    # dcm.HighBit = 7
    # dcm.PixelRepresentation = 0
    # dcm.PixelData = final_array.tobytes()
    # dcm.SeriesDescription = "SEG(" + dcm.SeriesDescription + ")"
    # dcm.save_as(dcm_file_out)  

    # return outputs
    



In [163]:
# process_image('/dataNAS/people/arogya/projects/ich-evaluation/outputs/3dq-test-data/any/ID_27fe0a13c6/ID_459b77eea.dcm')

In [164]:
DIR = '/dataNAS/people/arogya/projects/ich-evaluation/outputs/3dq-test-data/any/ID_27fe0a13c6'
images = os.listdir(DIR)
for image in images:
    out = process_image(f'{DIR}/{image}')
    # out[0].save_as("overlayed.dcm")
    # print(out)

(512, 512, 3)
(512, 512, 3)
(512, 512, 3)
(512, 512, 3)
(512, 512, 3)
(512, 512, 3)
(512, 512, 3)
(512, 512, 3)
(512, 512, 3)
(512, 512, 3)
(512, 512, 3)
(512, 512, 3)
(512, 512, 3)
(512, 512, 3)
(512, 512, 3)
(512, 512, 3)
(512, 512, 3)
(512, 512, 3)
(512, 512, 3)
(512, 512, 3)
(512, 512, 3)
(512, 512, 3)
(512, 512, 3)
(512, 512, 3)
(512, 512, 3)
(512, 512, 3)
(512, 512, 3)
(512, 512, 3)
(512, 512, 3)
(512, 512, 3)
(512, 512, 3)
(512, 512, 3)
(512, 512, 3)
(512, 512, 3)
(512, 512, 3)
(512, 512, 3)
(512, 512, 3)
(512, 512, 3)
(512, 512, 3)
(512, 512, 3)
