In [41]:
!ml cuda/10.2.89.440
import fastai
from torchvision.models.video import r3d_18
import torch
from IPython.display import display, HTML
from glob import glob
import os
import pandas as pd

In [105]:
import numpy as np
import nibabel as nib

class Masked_NIfTI_2D(torch.utils.data.Dataset):
    def __init__(self, seg_files, qsm_files, transform=None, target_transform=None):
        self.seg_files = seg_files
        self.qsm_files = qsm_files
        self.transform = transform
        self.target_transform = target_transform
        assert(len(self.seg_files) == len(self.qsm_files))

        # determine 2D sample locations
        self.sample_dict = {} # e.g. sample_dict[sample_num] == [image_num, slice_num]
        idx = 0
        for image_num in range(len(self.qsm_files)):
            num_slices = len(self.load_data(image_num)[0])
            for slice_num in range(num_slices):
                self.sample_dict[idx] = [image_num, slice_num]
                idx += 1
        self.total_samples = idx

    def __len__(self):
        return self.total_samples

    def __getitem__(self, idx):
        # convert idx to list if tensor
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # convert idx to image and slice numbers
        image_num, slice_num = self.sample_dict[idx]

        # load data
        qsm_samples, seg_samples = self.load_data(image_num)

        # extract relevant slice
        qsm = qsm_samples[slice_num]
        seg = seg_samples[slice_num]

        # transform if needed
        if self.transform:
            qsm = self.transform(qsm)
        if self.target_transform:
            seg = self.target_transform(seg)

        return qsm, seg

    def __iter__(self):
        for idx in range(self.total_samples):
            yield self.__getitem__(idx)

    def load_data(self, idx):
        seg = nib.load(self.seg_files[idx]).get_fdata()
        qsm = nib.load(self.qsm_files[idx]).get_fdata()
        seg_indices = np.unique(np.where(seg > 0)[2])
        qsm_samples = list(np.swapaxes(qsm[:,:,seg_indices], 0, 2))
        seg_samples = list(np.swapaxes(seg[:,:,seg_indices], 0, 2))
        return qsm_samples, seg_samples


In [106]:
qsm_files = sorted(glob("data/bids/sub-*/ses-*/extra_data/*qsm.nii*"))
seg_files = sorted(glob("data/bids/sub-*/ses-*/extra_data/*segmentation_clean_seeds.nii*")) # TODO: Add calcifications
t2s_files = sorted(glob("data/bids/sub-*/ses-*/extra_data/*t2starmap.nii*"))
mag_files = sorted(glob("data/bids/sub-*/ses-*/extra_data/*magnitude_combined.nii*"))
assert(len(qsm_files) == len(seg_files) == len(t2s_files) == len(mag_files))
print(f"{len(qsm_files)} samples found in data/bids.")

10 samples found in data/bids.


In [107]:
# define input data
df = pd.DataFrame({
    'qsm'  : qsm_files,
    'masks': seg_files
})

In [108]:
dataset = Masked_NIfTI_2D(seg_files, qsm_files)

In [109]:
list(dataset)

[(array([[ 0.        , -0.01450469, -0.0035951 , ..., -0.00200724,
           0.00163556,  0.        ],
         [-0.00552831, -0.01002956, -0.0002994 , ..., -0.00092974,
           0.00048357,  0.00502054],
         [-0.00479086, -0.00405344,  0.02738266, ...,  0.00020972,
          -0.00020449,  0.00138315],
         ...,
         [ 0.10386628, -0.01853787,  0.15043744, ..., -0.06075585,
           0.013323  , -0.02298198],
         [ 0.00885139, -0.0038359 , -0.01126584, ..., -0.00113866,
           0.00309037, -0.0033987 ],
         [ 0.        ,  0.0060779 ,  0.04656878, ..., -0.01661353,
          -0.00339992,  0.        ]]),
  array([[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]])),
 (array([[ 0.00000000e+00, -1.17161470e-02,  6.35763537e-03, ...,
          -1.32529635e-03, -5.5823143

In [110]:
len(dataset)

68

In [124]:
# dataloader
#  TODO: Optimisation -> 2D model may be faster and lower memory than 3D model - only include slices with seeds - this may even make more sense, and allow batch size to increase
#                     - ideally keep the same resolution for 2D
dls = fastai.data.load.DataLoader(
    dataset=dataset, # dataframe
    #path='.', # path to root data directory
    bs=2,     # how many samples per batch to load (overridden by batch_size if specified)
    val_bs=1, # batch size for validation (defaults to bs)
    splitter=fastai.data.transforms.RandomSplitter(valid_pct=0.2) # random split 20% / 80%
)

In [125]:
list(dls)

[(tensor([[[ 0.0000e+00, -1.4505e-02, -3.5951e-03,  ..., -2.0072e-03,
             1.6356e-03,  0.0000e+00],
           [-5.5283e-03, -1.0030e-02, -2.9940e-04,  ..., -9.2974e-04,
             4.8357e-04,  5.0205e-03],
           [-4.7909e-03, -4.0534e-03,  2.7383e-02,  ...,  2.0972e-04,
            -2.0449e-04,  1.3831e-03],
           ...,
           [ 1.0387e-01, -1.8538e-02,  1.5044e-01,  ..., -6.0756e-02,
             1.3323e-02, -2.2982e-02],
           [ 8.8514e-03, -3.8359e-03, -1.1266e-02,  ..., -1.1387e-03,
             3.0904e-03, -3.3987e-03],
           [ 0.0000e+00,  6.0779e-03,  4.6569e-02,  ..., -1.6614e-02,
            -3.3999e-03,  0.0000e+00]],
  
          [[ 0.0000e+00, -1.1716e-02,  6.3576e-03,  ..., -1.3253e-03,
            -5.5823e-04,  0.0000e+00],
           [-5.6444e-03,  1.5169e-03,  1.0177e-02,  ..., -3.8274e-04,
             7.1802e-04,  3.2151e-03],
           [-3.9268e-03,  7.1297e-03,  9.8988e-03,  ...,  6.5598e-04,
             6.4347e-04,  1.0717e-03],

In [126]:
def dice(input, target):
    iflat = input.contiguous().view(-1)
    tflat = target.contiguous().view(-1)
    intersection = (iflat * tflat).sum()
    return ((2. * intersection) / (iflat.sum() + tflat.sum()))

def dice_score(input, target):
    return dice(input.argmax(1), target)

def dice_loss(input, target): 
    return 1 - dice(input.softmax(1)[:, 1], target)

def loss(input, target):
    return dice_loss(input, target) + nn.CrossEntropyLoss()(input, target[:, 0])

In [128]:
# build a unet learner from dls and arch
from locale import normalize


learn = fastai.vision.learner.unet_learner(
    dls=dls,            # data loader
    arch=r3d_18,        # model architecture
    n_out=2,            # number of final filters (by default inferred from dls where possible)
    loss_func=loss,     # loss function for evaluation during training
    lr=0.001,           # learning rate??
    metrics=dice_score, # performance measure for humans (and model selection?)
    model_dir='models', # save directory for trained model
    normalize=False,    # had to disable this to prevent error...
    cbs=[fastai.callback.all.SaveModelCallback(monitor='dice_score', with_opt=True)] # saves the model's best during training and loads it at the end
)
learn = learn.to_fp16() # use half-precision floats for the learner

AttributeError: 'Masked_NIfTI_2D' object has no attribute 'train'