In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import torchxrayvision as xrv

from epsutils.dicom import dicom_utils
from epsutils.image import image_utils

Load the segmentation model

In [None]:
# ChestX-Det PSPNet from torchxrayvision; its output channels are paired with `model.targets`
# by index in the plotting cells below.
model = xrv.baseline_models.chestx_det.PSPNet()
# Switch to evaluation mode (affects layers such as dropout / batch-norm). The trailing `;`
# stops Jupyter from printing the very long module repr returned by `.eval()`.
model.eval();

Load and preprocess the DICOM image

In [None]:
IMAGE_PATH = "./images/cardiomegaly_1_front.dcm"
RESOLUTION = 512  # Side length the X-ray is resized to before being fed to the model.

# NOTE(review): window_center=0 / window_width=0 are passed to epsutils as-is; presumably this
# means "no custom windowing" rather than a literal zero-width window — confirm in dicom_utils.
img_dicom = dicom_utils.get_dicom_image(IMAGE_PATH, custom_windowing_parameters={"window_center": 0, "window_width": 0})
img_pil = image_utils.numpy_array_to_pil_image(img_dicom, convert_to_uint8=True, convert_to_rgb=True)
img_rgb = np.array(img_pil)  # presumably (H, W, 3) uint8 given convert_to_uint8/convert_to_rgb

img_norm = xrv.datasets.normalize(img_rgb, 255)  # Convert 8-bit image to [-1024, 1024] range.
img_gray = img_norm.mean(2)[None, ...]  # Average the color channels and add a leading axis -> (1, H, W).

transform = torchvision.transforms.Compose([xrv.datasets.XRayCenterCrop(), xrv.datasets.XRayResizer(RESOLUTION)])

# `img` is the name the later cells read: the preprocessed single-channel tensor.
img = torch.from_numpy(transform(img_gray))
assert tuple(img.shape) == (1, RESOLUTION, RESOLUTION), f"unexpected input shape {tuple(img.shape)}"

Run prediction

In [None]:
# Single forward pass for inference: autograd tracking is switched off for the call.
# `pred` holds the raw model output; the sigmoid is applied in a later cell.
with torch.set_grad_enabled(False):
    pred = model(img)

Show raw model outputs (pre-sigmoid logits) for each target

In [None]:
# Input image followed by one panel per target channel. Expects `pred` to still hold the raw
# model output, i.e. run this before the sigmoid/threshold cell.
assert pred.shape[1] == len(model.targets), "model output channels do not match model.targets"

n_panels = len(model.targets) + 1
fig, axes = plt.subplots(1, n_panels, figsize=(26, 5))

axes[0].imshow(img[0], cmap='gray')
axes[0].set_title("Input")
axes[0].axis('off')

for ax, target, channel in zip(axes[1:], model.targets, pred[0]):
    ax.imshow(channel)
    ax.set_title(target)
    ax.axis('off')

fig.tight_layout()
plt.show()

Convert logits to probabilities (sigmoid) and threshold them at 0.5 into binary masks

In [None]:
PROB_THRESHOLD = 0.5  # Probability above which a pixel is counted as part of a target.

# `pred` comes in as raw model output and is rebound below to binary masks (later cells read
# `pred`). Re-running this cell alone would push the masks through the sigmoid a second time,
# so fail loudly instead of silently producing wrong masks.
if bool(torch.all((pred == 0) | (pred == 1))):
    raise RuntimeError("`pred` already holds binary masks; re-run the prediction cell first.")

pred_logits = pred
pred_probs = torch.sigmoid(pred_logits)
# Strict `>` maps every pixel to exactly 0 or 1 (the old `<`/`>` pair left values equal to the
# threshold untouched).
pred = (pred_probs > PROB_THRESHOLD).to(pred_probs.dtype)

Show segmentation results

In [None]:
# Input image followed by the thresholded mask of each target channel.
assert pred.shape[1] == len(model.targets), "model output channels do not match model.targets"

n_panels = len(model.targets) + 1
fig, axes = plt.subplots(1, n_panels, figsize=(26, 5))

axes[0].imshow(img[0], cmap='gray')
axes[0].set_title("Input")
axes[0].axis('off')

for ax, target, mask in zip(axes[1:], model.targets, pred[0]):
    ax.imshow(mask)
    ax.set_title(target)
    ax.axis('off')

fig.tight_layout()
plt.show()

Overlay the heart segmentation on the original image

In [None]:
# Look the heart channel up by name rather than hard-coding its position in `model.targets`.
heart_idx = model.targets.index("Heart")
heart_mask = np.asarray(pred[0, heart_idx])
# Hide background pixels so only the segmented region is coloured; an unmasked binary mask
# would tint the whole image with the colormap's low end.
heart_overlay = np.ma.masked_where(heart_mask == 0, heart_mask)

fig, (ax_orig, ax_overlay) = plt.subplots(1, 2, figsize=(20, 10))

ax_orig.imshow(img[0], cmap='gray')
ax_orig.set_title("Original Image")
ax_orig.axis('off')

ax_overlay.imshow(img[0], cmap='gray')
# Fixed vmin/vmax so the single foreground value maps to the top of the colormap.
ax_overlay.imshow(heart_overlay, cmap='jet', alpha=0.5, vmin=0, vmax=1)
ax_overlay.set_title("Original Image with Segmentation")
ax_overlay.axis('off')

fig.tight_layout()
plt.show()