In [None]:
import glob
import os
from matplotlib import pyplot as plt
from rx import Rx
from utils import natural_key
from datasets.calcium_dataset import CalciumDataset

import numpy as np
import skimage
import torch
import torchvision
import torchxrayvision as xrv

# Dataset used throughout the notebook: later cells index it as `rx, label = ds[i]`
# and read per-sample ids via `ds.patients[i]["id"]`.
ds = CalciumDataset()

In [None]:
data_path = '/data/calcium_processed'

# Every patient directory that holds radiographs has an "rx" sub-folder;
# gather those sub-folders in natural (human) sort order.
folders = sorted(glob.glob(f'{data_path}/*/rx/'), key=natural_key)

# Each entry looks like "<data_path>/<patient>/rx/": strip the trailing "/rx/"
# and keep the last path component; empty names are discarded.
patients_ids = list(filter(None, (
    os.path.basename(os.path.normpath(rx_dir[:-4])) for rx_dir in folders
)))

In [None]:
def first_present(metadata, *keys, default='-'):
    """Return ``metadata[key]`` for the first of ``keys`` that is present, else ``default``.

    Only ``KeyError`` counts as "missing"; any other exception propagates instead of
    being hidden (the former bare ``except:`` also swallowed KeyboardInterrupt and
    genuine errors). NOTE(review): assumes ``Rx.metadata`` is a dict-like mapping that
    raises ``KeyError`` for absent names — confirm against the ``Rx`` class.
    """
    for key in keys:
        try:
            return metadata[key]
        except KeyError:
            continue
    return default


# For every patient, print the spacing, intensity range and a few descriptive
# DICOM fields of each radiograph, in natural file order.
for patient in patients_ids:
    img_files = sorted(glob.glob(f'{data_path}/{patient}/rx/*.dcm'), key=natural_key)

    for img_file in img_files:
        rx = Rx(img_file)
        # Fall back to "Imager Pixel Spacing" when "Pixel Spacing" is absent
        # (and to '-' rather than crashing when both are missing).
        print(patient, img_file, first_present(rx.metadata, "Pixel Spacing", "Imager Pixel Spacing"))

        print(rx.img.min(), rx.img.max())

        for key in ("Series Description",
                    "Acquisition Device Processing Description",
                    "Protocol Name"):
            print(first_present(rx.metadata, key))
        print('')

In [None]:
# Load one specific series for inspection. Build the folder from `data_path` instead of
# repeating the absolute path; the second argument is presumably a filename glob inside it.
rx = Rx(f'{data_path}/CAC_439/rx', 'IM-0104*.dcm')

In [None]:
# Show the loaded image, report its array shape and display its metadata.
plt.imshow(rx.img.squeeze(0), cmap='gray')
# A bare mid-cell expression is never displayed, so print the shape explicitly.
print(rx.img.shape)
rx.metadata

In [None]:
# Draw every sample as a small thumbnail and record its tensor shape for the summary statistics.
shapes = []

for i in range(len(ds)):
    patient = ds.patients[i]
    print(patient["id"])
    rx, label = ds[i]

    # Draw directly on the explicit axes rather than via the pyplot state machine.
    fig = plt.figure(figsize=(2, 2))
    ax = fig.add_subplot(1, 1, 1)
    ax.imshow(rx.squeeze(0), cmap='gray')
    plt.show()

    print(f'{patient["id"]} {rx.shape} {label}')
    shapes.append(rx.shape)


In [None]:
import statistics

# Summarise the collected tensor shapes: index 1 is treated as height and index 2
# as width. The comprehension variable is named `shape` so it no longer shadows the
# module alias it previously collided with (`s`).
heights = [shape[1] for shape in shapes]
widths = [shape[2] for shape in shapes]

print(len(widths), min(widths), statistics.mean(widths), statistics.median(widths), max(widths))
print(len(heights), min(heights), statistics.mean(heights), statistics.median(heights), max(heights))

In [None]:
# see https://github.com/mlmed/torchxrayvision/blob/master/scripts/segmentation.ipynb
import numpy as np
import skimage
import torch
import torchvision
import matplotlib.pyplot as plt
import torchxrayvision as xrv

# Run the torchxrayvision chest-structure segmentation model on one sample.
rx, label = ds[1]
# NOTE(review): assumes CalciumDataset yields intensities in [-1, 1];
# torchxrayvision models expect [-1024, 1024] — confirm.
rx = rx * 1024

# Center-crop to a square, then resize to the 512x512 input used by PSPNet.
transform = torchvision.transforms.Compose([
    xrv.datasets.XRayCenterCrop(),
    xrv.datasets.XRayResizer(512),
])
img = torch.from_numpy(transform(rx.detach().numpy()))

model = xrv.baseline_models.chestx_det.PSPNet()
with torch.no_grad():
    pred = model(img)

# ['Left Clavicle', 'Right Clavicle', 'Left Scapula', 'Right Scapula',
#  'Left Lung', 'Right Lung', 'Left Hilus Pulmonis', 'Right Hilus Pulmonis',
#  'Heart', 'Aorta', 'Facies Diaphragmatica', 'Mediastinum',  'Weasand', 'Spine']

In [None]:
# One row: the input image followed by one raw model-output map per segmentation target.
fig, axes = plt.subplots(1, len(model.targets) + 1, figsize=(26, 5), squeeze=False)
row = axes[0]
row[0].imshow(img[0], cmap='gray')
for i, target in enumerate(model.targets):
    ax = row[i + 1]
    ax.imshow(pred[0, i])
    ax.set_title(target)
    ax.axis('off')
fig.tight_layout()

In [None]:
# Binarise the model output into per-target masks.
# sigmoid(x) > 0.5 is equivalent to x > 0, so the logits are thresholded directly:
#  - every pixel ends up exactly 0 or 1 (the old sigmoid + two-mask version left
#    logits of exactly 0 at 0.5);
#  - the cell is idempotent: re-running it on an existing {0, 1} mask leaves it
#    unchanged, whereas re-applying the sigmoid turned 0 into 0.5.
pred = (pred > 0).float()

In [None]:
# Same layout as the logit plot, now showing the thresholded masks.
n_panels = len(model.targets) + 1
fig, axes = plt.subplots(1, n_panels, figsize=(26, 5), squeeze=False)
axes[0, 0].imshow(img[0], cmap='gray')
for i in range(n_panels - 1):
    panel = axes[0, i + 1]
    panel.imshow(pred[0, i])
    panel.set_title(model.targets[i])
    panel.axis('off')
fig.tight_layout()

In [None]:
# For every sample show: input image | heart logits | binary heart mask.
model = xrv.baseline_models.chestx_det.PSPNet()
# Look the channel up by name instead of hard-coding index 8.
heart_idx = model.targets.index('Heart')
# The preprocessing pipeline does not depend on the sample, so build it once.
transform = torchvision.transforms.Compose([xrv.datasets.XRayCenterCrop(), xrv.datasets.XRayResizer(512)])

for i in range(len(ds)):
    patient = ds.patients[i]
    rx, label = ds[i]

    # NOTE(review): assumes ds yields intensities in [-1, 1];
    # torchxrayvision expects [-1024, 1024] — confirm.
    rx = rx * 1024
    img = torch.from_numpy(transform(rx.detach().numpy()))

    with torch.no_grad():
        pred = model(img)

    plt.figure(figsize=(5, 3))
    plt.subplot(1, 3, 1)
    plt.imshow(img[0], cmap='gray')
    plt.title(patient['id'])

    plt.subplot(1, 3, 2)
    plt.imshow(pred[0, heart_idx])
    plt.title(model.targets[heart_idx])
    plt.axis('off')

    # sigmoid(x) > 0.5  <=>  x > 0: thresholding the logits directly yields a strictly
    # binary mask (exact-0.5 pixels are no longer left unbinarised).
    pred = (pred > 0).float()

    plt.subplot(1, 3, 3)
    plt.imshow(pred[0, heart_idx])
    plt.title(model.targets[heart_idx])
    plt.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# Run heart segmentation for a single patient chosen by id.
search_patient = 'CAC_001'

model = xrv.baseline_models.chestx_det.PSPNet()
# Look the channel up by name instead of hard-coding index 8.
heart_idx = model.targets.index('Heart')

# Locate the dataset index of the requested patient. The previous loop fell through
# when the id was absent, leaving `i` at the last index (or at a stale value from an
# earlier cell), and silently plotted the wrong radiograph.
i = next((idx for idx in range(len(ds)) if ds.patients[idx]['id'] == search_patient), None)
if i is None:
    raise ValueError(f'patient {search_patient!r} not found in dataset')
patient = ds.patients[i]

rx, label = ds[i]
# NOTE(review): assumes ds yields intensities in [-1, 1];
# torchxrayvision expects [-1024, 1024] — confirm.
rx = rx * 1024
transform = torchvision.transforms.Compose([xrv.datasets.XRayCenterCrop(), xrv.datasets.XRayResizer(512)])
img = torch.from_numpy(transform(rx.detach().numpy()))

with torch.no_grad():
    pred = model(img)

plt.figure(figsize=(5, 3))
plt.subplot(1, 3, 1)
plt.imshow(img[0], cmap='gray')
plt.title(patient['id'])

plt.subplot(1, 3, 2)
plt.imshow(pred[0, heart_idx])
plt.title(model.targets[heart_idx])
plt.axis('off')

# sigmoid(x) > 0.5  <=>  x > 0: thresholding the logits directly yields a strictly binary mask.
pred = (pred > 0).float()

plt.subplot(1, 3, 3)
plt.imshow(pred[0, heart_idx])
plt.title(model.targets[heart_idx])
plt.axis('off')

plt.tight_layout()
plt.show()