In [1]:
import glob
import os
import json
import time
import datetime

import torch
import fastai
import pandas as pd
import numpy as np
import nibabel as nib

from fastai.vision.all import *
from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score
from sklearn.utils.class_weight import compute_sample_weight

from fastMONAI.vision_all import *
from monai.networks.nets import UNet
from monai.losses import DiceCELoss

import scipy.ndimage
from sklearn.model_selection import train_test_split
from skimage.measure import label, regionprops

from useful_functions import *

  warn(


In [2]:
# Root of the BIDS dataset. Set the BIDS_DIR environment variable to point the
# notebook at another copy without editing it; defaults to "bids-new".
bids_dir = os.environ.get("BIDS_DIR", "bids-new")

In [3]:
# Protocol of the multi-echo GRE acquisition; only sessions that contain it are analysed.
QSM_PROTOCOL_NAME = "t2starME_qsm_tra_Iso1.4mm_INPHASE_bipolar_RUN_THIS_ONE"

# Collect <bids_dir>/sub-*/ses-* directories whose first-echo magnitude sidecar
# was acquired with QSM_PROTOCOL_NAME, in sorted path order.
session_dirs = []
for json_path in sorted(glob.glob(os.path.join(bids_dir, "sub*", "ses*", "anat", "*echo-01*mag*json"))):
    with open(json_path, 'r') as json_file:
        json_data = json.load(json_file)
    # ProtocolName is optional in BIDS sidecars: skip files without it rather than raising KeyError.
    if json_data.get('ProtocolName') != QSM_PROTOCOL_NAME:
        continue
    # <bids_dir>/sub-*/ses-*/anat/<file>.json -> <bids_dir>/sub-*/ses-*
    session_dir = os.path.dirname(os.path.dirname(json_path))
    # A session may hold several matching sidecars (e.g. repeated runs). List it once,
    # otherwise every per-session file glob later on would return duplicates.
    if session_dir not in session_dirs:
        session_dirs.append(session_dir)
print(f"{len(session_dirs)} sessions found")
session_dirs

18 sessions found


['bids-new/sub-z0034542/ses-20220715',
 'bids-new/sub-z0186251/ses-20221107',
 'bids-new/sub-z0237546/ses-20230508',
 'bids-new/sub-z0445614/ses-20230510',
 'bids-new/sub-z0705200/ses-20230104',
 'bids-new/sub-z0755228/ses-20211108',
 'bids-new/sub-z1167038/ses-20220315',
 'bids-new/sub-z1181657/ses-20220315',
 'bids-new/sub-z1262112/ses-20220314',
 'bids-new/sub-z1472355/ses-20221222',
 'bids-new/sub-z1568577/ses-20230510',
 'bids-new/sub-z1728751/ses-20220328',
 'bids-new/sub-z1778013/ses-20220715',
 'bids-new/sub-z1818796/ses-20230313',
 'bids-new/sub-z2007565/ses-20220715',
 'bids-new/sub-z2904752/ses-20220826',
 'bids-new/sub-z3171177/ses-20230313',
 'bids-new/sub-z3278008/ses-20211109']

In [4]:
def glob_session_files(sessions, pattern, subdir="extra_data"):
    """Return every path matching ``<session>/<subdir>/<pattern>`` across ``sessions``.

    Matches are concatenated in session order and, within a session, in the order
    ``glob.glob`` yields them (unsorted). Sort the result where a stable order matters.
    """
    return [path for session in sessions for path in glob.glob(os.path.join(session, subdir, pattern))]


def existing_with_suffix(paths, suffix):
    """For each path, return the sibling named with ``suffix`` inserted before ``.nii``, if it exists.

    The rename uses ``str.replace``, so every ``.nii`` in the path is rewritten. Missing
    files are skipped, so the result can be shorter than ``paths``.
    """
    return [match for path in paths for match in glob.glob(path.replace(".nii", suffix + ".nii"))]


def without_subject(paths, subject):
    """Return ``paths`` minus every entry that contains the substring ``subject``."""
    return [path for path in paths if subject not in path]


# GRE-derived and T1 images: one file per session is expected. Each list is sorted so
# that the lists pair up by position.
qsm_files = sorted(glob_session_files(session_dirs, "*qsm_echo2-and-echo4.*"))
seg_clean_files = sorted(glob_session_files(session_dirs, "sub*ses*segmentation_clean.*"))
t1_resampled_files = sorted(glob_session_files(session_dirs, "*t1_tra*_resampled.nii*"))
t2s_files = sorted(glob_session_files(session_dirs, "*t2starmap.nii*"))
swi_files = sorted(glob_session_files(session_dirs, "*swi.nii*"))
mag_files = sorted(glob_session_files(session_dirs, "magnitude_combined.nii"))
t1_files = [t1_file.replace("_resampled", "") for t1_file in t1_resampled_files]
seg_files = [seg_clean_file.replace("_clean", "") for seg_clean_file in seg_clean_files]

# CT volumes: names containing '_na_' or '_Pelvis_', excluding T1, ROI, already-resliced
# and segmentation files.
# NOTE(review): `extra_files` is deliberately left unsorted. CT order therefore follows
# filesystem listing order, and the CT train/test split depends on it. Sorting would
# probably change which sessions land in the test split of the already-trained CT checkpoint.
CT_NAME_PATTERNS = ['_na_', '_Pelvis_']
CT_EXCLUDED_PATTERNS = ['_t1_tra_', 'ATX', 'AXT', 'ROI', 'resliced', 'segmentation']
extra_files = glob_session_files(session_dirs, "*.nii*")
ct_files = [
    extra_file for extra_file in extra_files
    if any(pattern in extra_file for pattern in CT_NAME_PATTERNS)
    and not any(pattern in extra_file for pattern in CT_EXCLUDED_PATTERNS)
]

ct_seg_files = existing_with_suffix(ct_files, "_segmentation_clean")
ct_resliced_files = existing_with_suffix(ct_files, "_resliced")
ct_resliced_seg_files = existing_with_suffix(ct_resliced_files, "_segmentation_clean")

# Subject removed from every CT list. NOTE(review): the reason is not recorded here.
EXCLUDED_CT_SUBJECT = 'z0237546'
ct_files = without_subject(ct_files, EXCLUDED_CT_SUBJECT)
ct_seg_files = without_subject(ct_seg_files, EXCLUDED_CT_SUBJECT)
ct_resliced_files = without_subject(ct_resliced_files, EXCLUDED_CT_SUBJECT)
ct_resliced_seg_files = without_subject(ct_resliced_seg_files, EXCLUDED_CT_SUBJECT)

print(f"{len(ct_files)} CT images found.")
print(f"{len(ct_seg_files)} CT segmentations found.")
print(f"{len(ct_resliced_files)} resliced CT images found.")
print(f"{len(ct_resliced_seg_files)} resliced CT segmentations found.")
print(f"{len(qsm_files)} QSM images found.")
print(f"{len(mag_files)} magnitude images found.")
print(f"{len(t2s_files)} T2* maps found.")
print(f"{len(swi_files)} SWI images found.")
print(f"{len(t1_files)} T1w files found.")
print(f"{len(t1_resampled_files)} resampled T1w files found.")
print(f"{len(seg_files)} GRE segmentations found.")
print(f"{len(seg_clean_files)} cleaned GRE segmentations found.")

17 CT images found.
17 CT segmentations found.
17 resliced CT images found.
17 resliced CT segmentations found.
18 QSM images found.
18 magnitude images found.
18 T2* maps found.
18 SWI images found.
18 T1w files found.
18 resampled T1w files found.
18 GRE segmentations found.
18 cleaned GRE segmentations found.


In [5]:
def assert_paired(reference, others):
    """Assert that each list in ``others`` pairs position-by-position with ``reference``.

    The file lists are matched purely by index, so equal lengths are necessary but not
    sufficient. The i-th entries must also live in the same directory: for every list
    checked here, that is the same session's ``extra_data`` folder.

    Args:
        reference: list of paths.
        others: mapping of a descriptive name -> list of paths.
    """
    for name, paths in others.items():
        assert len(paths) == len(reference), f"{name}: {len(paths)} files, expected {len(reference)}"
        for ref_path, path in zip(reference, paths):
            assert os.path.dirname(ref_path) == os.path.dirname(path), f"{name}: {path} is not paired with {ref_path}"


assert_paired(qsm_files, {
    'seg_clean_files': seg_clean_files,
    't2s_files': t2s_files,
    'swi_files': swi_files,
    'mag_files': mag_files,
    't1_resampled_files': t1_resampled_files,
})
assert_paired(ct_files, {'ct_seg_files': ct_seg_files})
assert_paired(ct_resliced_files, {'ct_resliced_seg_files': ct_resliced_seg_files})

In [6]:
def get_center_slices(mask):
    """Return the rounded centroid of every connected component of ``mask``.

    ``mask`` is labelled with ``skimage.measure.label``. For each region, in
    ``regionprops`` order, a list of its centroid coordinates (one per array axis,
    each passed through ``round``) is returned. An all-zero mask yields ``[]``.
    """
    centroids = []
    for region in regionprops(label(mask)):
        centroids.append(list(map(round, region.centroid)))
    return centroids

class SetVrange(DisplayedTransform):
    """Zero every voxel outside the closed range [vmin, vmax].

    Values equal to ``vmin`` or ``vmax`` are kept. Anything strictly below ``vmin`` or
    strictly above ``vmax`` becomes 0: values are zeroed, not clipped.
    """
    def __init__(self, vmin, vmax):
        super().__init__()
        # store_attr records the arguments in __stored_args__, which DisplayedTransform
        # includes in its displayed name.
        store_attr('vmin,vmax')

    def encodes(self, o:MedImage):
        # Work on a copy so the caller's tensor is not modified.
        o = o.clone()
        o[(o < self.vmin) | (o > self.vmax)] = 0
        return o
    
def show_images(x, y, figsize=None, fig_out=None):
    """Show a 2-D slice through every marker (label 1) of each sample, with the mask overlaid.

    Args:
        x: image batch. Channel 0 of each sample is drawn in grey. It is indexed as
            ``x[i][0][:, k, :]``, so presumably (B, C, D, H, W) — TODO confirm.
        y: mask batch. With more than one channel it is reduced by ``argmax`` over
            dim 1 (per-class scores); otherwise its rounded values are used as labels.
        figsize: optional ``(width, height)`` in inches. By default each panel is
            2 x 2 inches, plus 0.05 in between neighbouring panels.
        fig_out: optional path. When truthy, the figure is also saved there.

    The layout has one row per mask sample (at least one) and one column per marker,
    capped at 7 columns. Each panel slices the volume along axis 1 at the marker
    centroid's axis-1 coordinate and overlays the non-zero labels at 60% opacity.
    """
    n_rows = max(1, y.shape[0])

    if y.shape[1] > 1:
        mask = torch.argmax(y, dim=1).unsqueeze(1).cpu().numpy()
    else:
        mask = y.cpu().numpy()
    mask = np.array(np.round(mask), dtype=int)
    data = x.cpu().numpy()

    # Marker centroids per sample, computed once and reused for both layout and plotting.
    centers = [get_center_slices(np.array(mask[i][0] == 1, dtype=int)) for i in range(n_rows)]
    n_cols = min(7, max([1] + [len(sample_centers) for sample_centers in centers]))

    if figsize is None:
        panel_size, spacing = 2, 0.05
        figsize = (panel_size * n_cols + spacing * (n_cols - 1),
                   panel_size * n_rows + spacing * (n_rows - 1))

    fig, axes = plt.subplots(ncols=n_cols, nrows=n_rows, figsize=figsize, squeeze=False)
    for ax in axes.flat:
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_axis_off()

    for i, sample_centers in enumerate(centers):
        for j, center in enumerate(sample_centers[:n_cols]):
            k = center[1]
            image_slice = data[i][0][:, k, :]
            label_slice = mask[i][0][:, k, :]
            axes[i, j].imshow(image_slice, cmap='gray', interpolation='nearest')
            axes[i, j].imshow(label_slice, cmap='Set1', vmin=1, vmax=9,
                              alpha=np.array(label_slice != 0, dtype=int) * 0.6,
                              interpolation='nearest')
    fig.tight_layout()
    if fig_out:
        fig.savefig(fig_out)
    plt.show()
    plt.close(fig)
    
@typedispatch
def show_batch(x:MedImage, y:MedMask, samples, ctxs=None, max_n=6, nrows=None, ncols=2, figsize=None, **kwargs):
    """Show a batch of medical images with their masks (override of fastai's ``show_batch``).

    Delegates to ``show_images`` and forwards ``figsize``. The other arguments exist to
    match fastai's ``show_batch`` dispatch signature and are ignored.
    """
    show_images(x, y, figsize=figsize)

@typedispatch
def show_results(x:MedImage, y:MedMask, samples, outs, ctxs=None, max_n=6, nrows=None, ncols=2, figsize=None, fig_out='out.png', **kwargs):
    """Plot target masks and predicted masks as two separate figures.

    ``outs`` is a sequence of per-sample outputs. The first element of each is stacked
    into a batch and drawn like a mask. With ``fig_out='dir/name.png'`` the figures are
    saved as ``dir/name_targ.png`` and ``dir/name_pred.png``. Pass a falsy ``fig_out``
    (e.g. ``None``) to only display them. ``figsize`` is forwarded to ``show_images``;
    the remaining arguments match fastai's ``show_results`` signature and are ignored.
    """
    pred_masks = torch.stack([out[0] for out in outs], dim=0)
    if fig_out:
        # splitext strips only the final extension, so dots in directory names
        # (e.g. './out.png') are preserved.
        stem = os.path.splitext(fig_out)[0]
        targ_out, pred_out = f"{stem}_targ.png", f"{stem}_pred.png"
    else:
        targ_out = pred_out = None
    show_images(x, y, figsize=figsize, fig_out=targ_out)
    show_images(x, pred_masks, figsize=figsize, fig_out=pred_out)
    

In [7]:
def _to_numpy(a):
    """Return ``a`` as a NumPy array. Torch tensors are detached and copied to the CPU first."""
    if isinstance(a, torch.Tensor):
        a = a.detach().cpu().numpy()
    return np.asarray(a)


def _count_marker_components(pred, targ, marker_label=1):
    """Count marker connected components in a batch of predicted and target label maps.

    Both inputs have the sample axis first. Values are rounded and compared with
    ``marker_label``, so every other class (e.g. label 2) counts as background.
    If a sample's prediction and target differ in shape, both are squeezed before
    comparison. This happens because ``learn.pred.argmax(dim=1)`` has no channel axis,
    whereas the target presumably keeps a singleton one — TODO confirm.

    Each sample's marker map is dilated once with SciPy's default structuring element
    and then labelled with ``scipy.ndimage.label``. Dilation is applied per sample, so
    markers in neighbouring samples of a batch cannot grow into or merge with each other.

    Returns:
        ``(n_pred, n_targ, n_overlap)``: predicted components, target components, and
        components of the predicted-and-target overlap, each summed over the batch.

    Raises:
        ValueError: if ``pred`` and ``targ`` hold a different number of samples.
    """
    pred = np.round(_to_numpy(pred)) == marker_label
    targ = np.round(_to_numpy(targ)) == marker_label
    if len(pred) != len(targ):
        raise ValueError(f"pred has {len(pred)} samples but targ has {len(targ)}")

    n_pred = n_targ = n_overlap = 0
    for p, t in zip(pred, targ):
        if p.shape != t.shape:
            p, t = np.squeeze(p), np.squeeze(t)
        p = scipy.ndimage.binary_dilation(p)
        t = scipy.ndimage.binary_dilation(t)
        n_pred += scipy.ndimage.label(p)[1]
        n_targ += scipy.ndimage.label(t)[1]
        n_overlap += scipy.ndimage.label(p & t)[1]
    return n_pred, n_targ, n_overlap


class _MarkerCountMetric(fastai.metrics.Metric):
    """Shared state and accumulation for the marker-count metrics below.

    Since the last ``reset()`` it accumulates:
      * ``pred_marker_count``: predicted marker components,
      * ``targ_marker_count``: target marker components,
      * ``overlap_count``: components of the predicted/target overlap.
    Subclasses define ``value``.
    """
    def __init__(self):
        super().__init__()
        self.reset()

    def reset(self):
        self.targ_marker_count = 0
        self.pred_marker_count = 0
        self.overlap_count = 0

    def accumulate(self, learn=None, pred=None, targ=None):
        """Add one batch's counts.

        fastai calls this as ``accumulate(learn)``. In that case the prediction is
        ``learn.pred.argmax(dim=1)`` and the target is ``learn.y``. Callers may instead
        pass ``pred`` and ``targ`` label maps directly.
        """
        if pred is None or targ is None:
            pred = learn.pred.argmax(dim=1)
            targ = learn.y
        n_pred, n_targ, n_overlap = _count_marker_components(pred, targ)
        self.pred_marker_count += n_pred
        self.targ_marker_count += n_targ
        self.overlap_count += n_overlap


class MarkersIdentified(_MarkerCountMetric):
    """Identified-marker rate: ``overlap_count / targ_marker_count`` (denominator at least 1).

    NOTE(review): this counts overlap components. A target covered by several separate
    predicted fragments is therefore counted more than once, and the value can exceed 1.
    """
    @property
    def value(self):
        return float(self.overlap_count) / max(1., float(self.targ_marker_count))


class SuperfluousMarkers(_MarkerCountMetric):
    """Superfluous-marker rate: ``(pred_marker_count - overlap_count) / pred_marker_count``.

    The denominator is at least 1. NOTE(review): a single prediction overlapping several
    targets yields several overlap components, so the value can drop below 0.
    """
    @property
    def value(self):
        return float(self.pred_marker_count - self.overlap_count) / max(1., float(self.pred_marker_count))

In [8]:
# Empty table for per-model statistics.
# NOTE(review): not referenced by any of the cells shown here.
model_stats = pd.DataFrame(data=None)

In [9]:
# Models to evaluate, in iteration (and plotting) order. Each value is
#   [checkpoint name passed to learn.load,
#    list of input-file lists (one list per input channel; zipped into ';'-joined rows),
#    segmentation-file list,
#    line colour for the precision-recall plot]
models = dict([
    ('CT',     ["CT-Resliced-20230526-114140-best",     [ct_resliced_files],              ct_resliced_seg_files, '#a6cee3']),
    ('QSM+T1', ["QSM-T1-NOWEIGHT-20230526-152731-best", [qsm_files, t1_resampled_files], seg_clean_files,       '#1f78b4']),
    ('QSM',    ["QSM-NOWEIGHT-20230526-145608-best",    [qsm_files],                      seg_clean_files,       '#b2df8a']),
    ('T1',     ["T1-Resampled-20230526-120944-best",    [t1_resampled_files],             seg_clean_files,       '#33a02c']),
    ('SWI',    ["SWI-20230529-152018-best",             [swi_files],                      seg_clean_files,       '#e31a1c']),
    ('GRE',    ["GRE-Magnitude-20230526-130737-best",   [mag_files],                      seg_clean_files,       '#fb9a99']),
])

In [10]:
# Train/test split passed to sklearn's train_test_split. An int test_size is an
# absolute number of samples. Presumably this must match the split used when the
# checkpoints were trained — TODO confirm.
random_state, test_size = 42, 3

# Training hyper-parameters (not referenced by the evaluation cells shown here).
training_epochs, lr = 700, 0.003

# Patch size in voxels passed to PadOrCrop. The original comment also noted 128, 160, 160.
crop_size = [80] * 3

In [11]:
def split_model_df(model_name):
    """Return ``(df, train_df, test_df)`` for one entry of ``models``.

    ``df`` has one row per sample. ``in_files`` holds that sample's per-channel input
    paths joined with ';', and ``seg_files`` holds its segmentation. The split uses the
    module-level ``test_size`` and ``random_state``, so repeated calls give the same split.
    """
    _, in_file_lists, seg_file_list, _ = models[model_name]
    df = pd.DataFrame({'in_files': [';'.join(pair) for pair in zip(*in_file_lists)], 'seg_files': seg_file_list})
    train_df, test_df = train_test_split(df, test_size=test_size, random_state=random_state)
    return df, train_df, test_df


def make_eval_dblock():
    """Build the evaluation datablock: training-time crop/normalisation, no random augmentation.

    ``IndexSplitter([])`` puts every row in the ``train`` subset, so loaders built from
    it are read through ``.train``. Reads the module-level ``crop_size``,
    ``requires_resampling`` and ``suggested_voxelsize``.
    """
    return MedDataBlock(
        blocks=(ImageBlock(cls=MedImage), MedMaskBlock),
        splitter=IndexSplitter([]),
        get_x=ColReader('in_files'),
        get_y=ColReader('seg_files'),
        item_tfms=[
            PadOrCrop(crop_size),
            ZNormalization(),
        ],
        reorder=requires_resampling,
        resample=suggested_voxelsize
    )


def count_markers(dl):
    """Run ``learn.model`` over every batch of ``dl`` and return a filled ``MarkersIdentified``.

    Inference runs in eval mode under ``torch.no_grad()``. Nothing is back-propagated,
    so no autograd graph needs to be kept alive.
    """
    counter = MarkersIdentified()
    learn.model.eval()
    with torch.no_grad():
        for xb, yb in dl:
            pred = torch.argmax(learn.model(xb), dim=1).unsqueeze(1).to(dtype=torch.float)
            counter.accumulate(pred=pred.cpu(), targ=yb.cpu())
    return counter


def print_marker_counts(counter):
    """Print the identified-marker rate and the raw component counts held by ``counter``."""
    print(counter.value)
    print(f"Predicted markers: {counter.pred_marker_count}")
    print(f"Correct markers: {counter.overlap_count}")
    print(f"Incorrect markers: {counter.pred_marker_count - counter.overlap_count}")
    print(f"Target markers: {counter.targ_marker_count}")


def validate_and_report(dl, title):
    """Run ``learn.validate`` on ``dl``, print its metrics under ``title``, and return ``(loss, metrics)``."""
    loss, *metrics = learn.validate(ds_idx=0, dl=dl)
    print(title)
    print(f"Dice score: {metrics[0]}; Markers identified: {metrics[1]}; Superfluous markers: {metrics[2]}")
    return loss, metrics


def calc_stuff(x, y):
    """Compute voxel-wise ROC and precision-recall for the marker channel (1) of ``learn.model(x)``.

    The channel-1 output is min-max rescaled to [0, 1] and flattened; the rescaling is
    monotonic, so the curves are unchanged. The target is ``y == 1``, flattened. The ROC
    uses class-balanced sample weights; the PR curve does not.

    Returns:
        ``(fpr, tpr, precision, recall, thresholds, average_precision, roc_auc)``
    """
    learn.model.eval()
    with torch.no_grad():
        pred = learn.model(x)[:, 1, :, :, :].unsqueeze(1).cpu().numpy()
    pred -= np.min(pred)
    pred_max = np.max(pred)
    if pred_max > 0:  # constant output: leave it at 0 instead of producing NaNs via 0/0
        pred /= pred_max
    pred = pred.flatten()
    targ = (y.cpu() == 1).to(dtype=torch.int).detach().numpy().flatten()

    # ROC / AUC
    sample_weight = compute_sample_weight(class_weight="balanced", y=targ, indices=None)
    fpr, tpr, thresholds = roc_curve(targ, pred, sample_weight=sample_weight)
    roc_auc = auc(fpr, tpr)

    # Precision-recall
    precision, recall, _ = precision_recall_curve(targ, pred)
    average_precision = average_precision_score(targ, pred)

    return fpr, tpr, precision, recall, thresholds, average_precision, roc_auc


# Resampling settings are derived once from the CT training split and reused for every
# model below, so no model depends on 'CT' being processed first.
print("Getting resampling suggestions...")
_, ct_train_df, _ = split_model_df('CT')
med_dataset = MedDataset(
    img_list=ct_train_df.seg_files.tolist(),
    dtype=MedMask
)
suggested_voxelsize, requires_resampling = med_dataset.suggestion()
largest_imagesize = med_dataset.get_largest_img_size(resample=suggested_voxelsize)
print(f"Suggested voxel size: {suggested_voxelsize}")
print(f"Requires resampling: {requires_resampling}")
print(f"Largest image size: {largest_imagesize}")

fig, ax = plt.subplots()

for model in models.keys():
    print(f"=== {model} ===")
    checkpoint_name, in_file_lists, _, plot_colour = models[model]
    num_data_channels = len(in_file_lists)

    print("Splitting...")
    df, train_df, test_df = split_model_df(model)

    print("Creating datablock...")
    # Training-time datablock (with augmentation); it is only needed to construct the Learner.
    dblock = MedDataBlock(
        blocks=(ImageBlock(cls=MedImage), MedMaskBlock),
        splitter=RandomSplitter(),
        get_x=ColReader('in_files'),
        get_y=ColReader('seg_files'),
        item_tfms=[
            PadOrCrop(crop_size),
            RandomFlip(axes=("LR",)),
            RandomFlip(axes=("AP",)),
            RandomAffine(degrees=(90, 90, 90)),
            ZNormalization(),
        ],
        reorder=requires_resampling,
        resample=suggested_voxelsize
    )

    print("Creating dataloaders...")
    dls = DataLoaders.from_dblock(dblock, train_df, bs=4)

    print("Creating learner...")
    learn = Learner(
        dls,
        model=UNet(
            spatial_dims=3,
            in_channels=num_data_channels,
            out_channels=3,  # background, marker, calcification
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2
        ),
        loss_func=DiceCELoss(
            to_onehot_y=True,
            include_background=True,
            softmax=True,
            ce_weight=torch.Tensor([1, 1, 1])
            # alternative tried: ce_weight=torch.Tensor([0, 0.9998, 0.0002])
        ),
        opt_func=ranger,
        metrics=[multi_dice_score, MarkersIdentified(), SuperfluousMarkers()]
    )

    print("Loading model...")
    learn = learn.load(checkpoint_name)

    # Put the weights on the dataloaders' device, so the direct learn.model(x) calls
    # below see inputs and parameters on the same device.
    device = getattr(dls, 'device', None)
    print(f"Moving model to {device}...")
    if device is not None:
        learn.model.to(device)

    print("Computing metrics...")
    # One deterministic datablock serves both splits.
    dblock_train_eval = dblock_test_eval = make_eval_dblock()
    # NOTE(review): fastai's DataLoader has no `sampler` parameter (extra kwargs appear to
    # be ignored), so these `.train` loaders may still shuffle. That does not affect the
    # aggregate counts and curves computed here.
    dls_train_eval = DataLoaders.from_dblock(dblock_train_eval, train_df, bs=1, sampler=SequentialSampler)
    dls_test_eval = DataLoaders.from_dblock(dblock_test_eval, test_df, bs=1, sampler=SequentialSampler)

    correct_markers = count_markers(dls_train_eval.train)
    print_marker_counts(correct_markers)
    correct_markers = count_markers(dls_test_eval.train)
    print_marker_counts(correct_markers)

    loss, metrics = validate_and_report(dls_train_eval.train, "TRAINING SET METRICS")
    loss, metrics = validate_and_report(dls_test_eval.train, "TEST SET METRICS")

    # Load each whole split as a single batch for the voxel-wise curves.
    dls_train_eval = DataLoaders.from_dblock(dblock_train_eval, train_df, bs=len(dls_train_eval.train_ds), sampler=SequentialSampler)
    dls_test_eval = DataLoaders.from_dblock(dblock_test_eval, test_df, bs=len(dls_test_eval.train_ds), sampler=SequentialSampler)
    train_x, train_y = dls_train_eval.train.one_batch()
    test_x, test_y = dls_test_eval.train.one_batch()  # NOTE(review): not used in this cell

    fpr, tpr, precision, recall, thresholds, average_precision, roc_auc = calc_stuff(train_x, train_y)
    # For an ROC plot instead: ax.plot(fpr, tpr, ..., label=f'{model} (AUC = {round(roc_auc, 2)})')
    ax.plot(recall, precision, color=plot_colour, label=f'{model}')

ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel('Recall')
ax.set_ylabel('Precision')
ax.set_title('Precision-Recall Curve\n(Training set)')
ax.legend(loc="lower right")
fig.savefig("out.png", dpi=400)
plt.show()


#plt.plot([0, 1], [0, 1], color='navy', linestyle='--')
#plt.xlim([0.0, 1.0])
#plt.ylim([0.0, 1.05])
#plt.xlabel('False Positive Rate')
#plt.ylabel('True Positive Rate')
#plt.title('Receiver Operating Characteristic (ROC) Curve\n(Test set)')
#plt.legend(loc="lower right")
#plt.savefig("out.png", dpi=400)
#plt.show()

=== CT ===
Splitting...
Getting resampling suggestions...
Suggested voxel size: [1.4, 1.4, 1.4]
Requires resampling: False
Largest image size: [146.0, 160.0, 72.0]
Creating datablock...
Creating dataloaders...
Creating learner...
Loading model...
Transferring to CUDA...


  elif with_opt: warn("Saved filed doesn't contain an optimizer state.")


RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx

<Figure size 640x480 with 0 Axes>