# Gold marker segmentation with QSM

## Imports

In [None]:
import fastai
from glob import glob
import fastai.vision.learner
import fastai.vision.models
import fastai.data.core
import fastai.callback.all
import fastai.losses
import numpy as np
import nibabel as nib
import torch
import cv2
import skimage.measure
import scipy.ndimage
from useful_functions import *
import fastMONAI.vision_all

## Prepare data
### Locate input data

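The inputs are 3D NIfTI images (QSM maps, segmentations, T2* maps and combined magnitude images), stored per subject and session under `data/bids/sub-*/ses-*/extra_data`.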

In [None]:
# Input data: all modalities live in the same BIDS "extra_data" folders, so only the
# filename pattern differs between the globs. Sorting orders every list by
# subject/session path; presumably each session holds exactly one file per modality,
# so qsm_files[i] and seg_files[i] belong together (the dataset below relies on this).
bids_extra_data = "data/bids/sub-*/ses-*/extra_data"
qsm_files = sorted(glob(f"{bids_extra_data}/*qsm.nii*"))
seg_files = sorted(glob(f"{bids_extra_data}/*segmentation*clean_seeds.nii*"))
t2s_files = sorted(glob(f"{bids_extra_data}/*t2starmap.nii*"))
mag_files = sorted(glob(f"{bids_extra_data}/*magnitude_combined.nii*"))

# Every QSM volume needs a segmentation. Raise explicitly (an assert is stripped under
# `python -O`) and report both counts so a mismatch is easy to diagnose.
if len(qsm_files) != len(seg_files):
    raise ValueError(
        f"Found {len(qsm_files)} QSM images but {len(seg_files)} segmentations; "
        "expected exactly one segmentation per QSM image."
    )
print(f"{len(qsm_files)} NIfTI image sets found in data/bids (QSM, segmentations).")

In [None]:
# Wrap the first QSM volume in a fastMONAI MedDataset.
# NOTE(review): QSM values are continuous (they are clipped to [-2, 2] in the dataset
# below), yet the file is loaded as MedMask here; confirm whether MedImage was intended.
med_dataset = fastMONAI.vision_all.MedDataset(
    max_workers=12,
    dtype=fastMONAI.vision_all.MedMask,
    img_list=[qsm_files[0]],
)

### Load samples as a PyTorch dataset and fastai `DataLoaders`

In [None]:
class QSM_3D_With_Seg(torch.utils.data.Dataset):
    """Dataset of 3D QSM volumes paired with their segmentation masks.

    Item ``i`` loads ``qsm_files[i]`` and ``seg_files[i]`` with nibabel, rescales the
    QSM intensities, resamples both volumes so every axis has ``size`` samples, and
    returns them as ``(TensorImage, TensorMask)``.

    Args:
        seg_files: Segmentation paths; ``seg_files[i]`` must match ``qsm_files[i]``.
        qsm_files: QSM paths. Their count defines the dataset length.
        transform: Stored as ``self.transform``; not applied by this class.
        size: Edge length of the cubic output grid (default 224).
    """

    def __init__(self, seg_files, qsm_files, transform=None, size=224):
        self.seg_files = seg_files
        self.qsm_files = qsm_files
        self.transform = transform
        self.size = size
        # Names for the mask labels; presumably index i names label value i — confirm.
        self.vocab = np.array(['Prostate', 'Gold marker'])

    def __len__(self):
        return len(self.qsm_files)

    def __getitem__(self, idx):
        # Samplers may pass a tensor index; convert it to a plain Python value first.
        if torch.is_tensor(idx):
            idx = idx.tolist()

        qsm = nib.load(self.qsm_files[idx]).get_fdata()
        seg = nib.load(self.seg_files[idx]).get_fdata()

        qsm = self._scale_qsm(qsm)

        # QSM intensities are continuous, so they keep zoom's cubic spline (order=3).
        # The mask must use order=0 (nearest neighbour): a spline invents fractional
        # and overshooting values that the int64 cast would truncate into wrong labels.
        qsm = torch.Tensor(self._resize(qsm, self.size, order=3))
        seg = torch.Tensor(self._resize(seg, self.size, order=0)).to(torch.int64)

        return fastai.torch_core.TensorImage(qsm), fastai.torch_core.TensorMask(seg)

    def __iter__(self):
        """Yield every ``(image, mask)`` pair in index order."""
        for idx in range(len(self)):
            yield self[idx]

    @staticmethod
    def _scale_qsm(qsm):
        """Clip QSM to [-2, 2], map that range linearly onto [0, 1], then standardise
        the volume to mean 0.485 and standard deviation 0.229.

        NOTE(review): 0.485 / 0.229 presumably mirror the first channel of the usual
        ImageNet normalisation statistics for the pretrained encoder — confirm.
        """
        qsm = np.interp(np.clip(qsm, -2, +2), (-2, +2), (0, 1))
        std = qsm.std()
        if std == 0:
            # A constant volume would otherwise give 0 / 0 = NaN everywhere.
            std = 1.0
        return (qsm - qsm.mean()) / std * 0.229 + 0.485

    @staticmethod
    def _resize(volume, size, order):
        """Resample ``volume`` with ``scipy.ndimage.zoom`` so every axis has ``size`` samples.

        ``order`` is the spline interpolation order (0 = nearest neighbour);
        ``mode='nearest'`` controls how values beyond the edges are extended.
        """
        factors = tuple(size / extent for extent in volume.shape)
        return scipy.ndimage.zoom(volume, factors, order=order, mode='nearest')

In [None]:
# The first 10 (segmentation, QSM) pairs are used for training; the rest are held out.
train_ds = QSM_3D_With_Seg(seg_files=seg_files[:10], qsm_files=qsm_files[:10])
valid_ds = QSM_3D_With_Seg(seg_files=seg_files[10:], qsm_files=qsm_files[10:])
dls = fastai.data.core.DataLoaders.from_dsets(
    train_ds,
    valid_ds,
    batch_size=2,
    device='cuda:0',
)
print(f"Training set contains {len(train_ds)} samples.")
print(f"Validation set contains {len(valid_ds)} samples.")

In [None]:
# Pull one batch from the training loader to inspect what the dataset produces.
batch = dls.train.one_batch() # batch[0]: QSM volumes, batch[1]: masks; batch[k][i] is sample i (no channel axis)

In [None]:
# First sample of the batch: its QSM volume (x) and its mask (y), copied to the CPU.
x, y = (part[0].cpu() for part in batch[:2])

In [None]:
# Display the mask's shape; the dataset resamples each axis to 224 voxels.
y.shape

In [None]:
# Presumably plots the value histogram of the QSM volume `x` together with its mask `y`.
# NOTE(review): show_histogram comes from useful_functions (not visible here); confirm
# what `dim=2` and `n_ticks=10` control there.
show_histogram(x, title="Input - After creating dataset", mask=y, dim=2, n_ticks=10)

## Prepare learner

In [None]:
def dice(input, target):
    """Dice coefficient ``2 * sum(A * B) / (sum(A) + sum(B))`` of two tensors.

    Both tensors are flattened, so they only need the same number of elements. Values
    act as (soft) binary masks: ``input`` may hold hard 0/1 labels or foreground
    probabilities.

    Returns:
        A 0-d floating-point tensor. When both inputs sum to zero (two empty masks),
        the masks agree perfectly and 1 is returned instead of 0 / 0 = NaN.
    """
    iflat = input.contiguous().view(-1)
    tflat = target.contiguous().view(-1)
    intersection = (iflat * tflat).sum()
    numerator = 2. * intersection
    denominator = iflat.sum() + tflat.sum()
    if denominator == 0:
        return torch.ones_like(numerator)
    return numerator / denominator

def dice_score(input, target, verbose=True):
    """Dice metric on hard (argmax) predictions; used as the learner's metric.

    Args:
        input: Model logits with class scores on axis 1.
        target: Mask batch compared against ``input.argmax(1)``.
        verbose: If True (the default), print how many connected components
            ("seeds") the first sample's target and prediction contain.

    Returns:
        ``dice(input.argmax(1), target)``.
    """
    if verbose:
        # Take the argmax on the original device so only one label volume, not the
        # whole batch of logits, is copied to the CPU for connected-component labelling.
        pred = input[0].argmax(0).cpu()
        # return_num yields the component count directly (labels run 1..N).
        _, num_seeds_target = skimage.measure.label(np.asarray(target[0].cpu()), return_num=True)
        _, num_seeds_pred = skimage.measure.label(np.asarray(pred), return_num=True)
        print("num_seeds_target", num_seeds_target)
        print("num_seeds_pred", num_seeds_pred)
    return dice(input.argmax(1), target)

def dice_loss(input, target):
    """Soft Dice loss: 1 minus the Dice overlap of class-1 probabilities with ``target``."""
    foreground_prob = torch.softmax(input, dim=1)[:, 1]
    return 1 - dice(foreground_prob, target)

def loss(input, target):
    """Combined soft-Dice + cross-entropy loss for segmentation logits.

    ``input`` holds logits with classes on axis 1. ``target`` may either have the
    shape of ``input`` without its class axis, or carry an extra singleton channel
    axis at position 1, which is dropped for the cross-entropy term.
    """
    # cross_entropy expects targets shaped like `input` minus the class axis, so only
    # strip axis 1 when `target` has the same rank as `input` (i.e. a channel axis).
    ce_target = target[:, 0] if target.ndim == input.ndim else target
    return dice_loss(input, target) + torch.nn.functional.cross_entropy(input, ce_target)

In [None]:
# U-Net with a ResNet-34 encoder and two output classes (matching the dataset vocab).
# normalize=False because QSM_3D_With_Seg already rescales its inputs itself;
# presumably this stops fastai from adding its own normalisation — confirm.
learn = fastai.vision.learner.unet_learner(
    dls=dls,
    arch=fastai.vision.models.resnet34,
    n_out=2,
    normalize=False,
    loss_func=fastai.losses.CrossEntropyLossFlat(axis=1),
    metrics=dice_score,
    model_dir='models',
)

In [None]:
# Display fastai's summary of the learner's model.
learn.summary()

In [None]:
# Re-inspect a batch after creating the learner.
# NOTE(review): dls.one_batch() presumably draws from the (shuffled) training loader,
# so this may be a different sample than the one plotted before — compare like with like.
batch = dls.one_batch()  # batch[0]: QSM volumes, batch[1]: masks
# Use the whole first volume, as in the earlier cell, so `x` has the same shape as the
# 3D mask `y`. Indexing batch[0][0][0] picked a single 2D slice, whose value range can
# differ from the volume's.
x = batch[0][0].cpu()
y = batch[1][0].cpu()

show_histogram(x, title="Input - After creating learner", mask=y)

In [None]:
# Inspect a validation batch after creating the learner.
batch = dls.valid.one_batch()  # batch[0]: QSM volumes, batch[1]: masks
# Use the whole first volume so `x` has the same shape as the 3D mask `y`; indexing
# batch[0][0][0] picked a single 2D slice, whose value range can differ from the volume's.
x = batch[0][0].cpu()
y = batch[1][0].cpu()

show_histogram(x, title="Input - After creating learner", mask=y)

## Train

In [None]:
# Optional: uncomment to run fastai's learning-rate finder when choosing `base_lr` for fine_tune.
#learn.lr_find()

In [None]:
# Train with fastai's fine_tune schedule: 3 epochs, base learning rate 5e-4.
learn.fine_tune(epochs=3, base_lr=5e-4)

## Test

In [None]:
# Visualise model outputs alongside their targets using fastai's show_results.
learn.show_results()

### Test on training data

In [None]:
# Show one training sample's ground-truth mask, then the model's prediction for it.
sample_idx = 0  # must be smaller than the loader's batch size (batch_size=2 above)
batch = dls.train.one_batch()  # batch[0]: QSM volumes, batch[1]: masks
# Use the whole 3D volume so `x` matches the 3D mask `y`
# (batch[0][sample_idx][0] would select a single 2D slice).
x = batch[0][sample_idx].cpu()
y = batch[1][sample_idx].cpu()

show_histogram(x, title="Ground truth (from training set)", mask=y)

# NOTE(review): learn.predict presumably returns (decoded, class index, probabilities);
# confirm the extra unsqueeze(0) matches what predict expects for a single item.
_, _, prediction = learn.predict(batch[0][sample_idx].unsqueeze(0))
# Round the probabilities to 0/1; channel 1 is presumably the 'Gold marker' class.
prediction = torch.round(prediction)

show_histogram(x, title="Prediction", mask=prediction[1])

### Test on validation data

In [None]:
# Show one validation sample's ground-truth mask, then the model's prediction for it.
# The same index is used for image, mask and prediction so all three plots refer to
# the same sample. It must be smaller than the loader's batch size (batch_size=2
# above); the previous index 4 was out of range.
sample_idx = 0
batch = dls.valid.one_batch()  # batch[0]: QSM volumes, batch[1]: masks
# Use the whole 3D volume so `x` matches the 3D mask `y`.
x = batch[0][sample_idx].cpu()
y = batch[1][sample_idx].cpu()

show_histogram(x, title="Ground truth (from validation set)", mask=y)

# NOTE(review): learn.predict presumably returns (decoded, class index, probabilities);
# confirm the extra unsqueeze(0) matches what predict expects for a single item.
_, _, prediction = learn.predict(batch[0][sample_idx].unsqueeze(0))
# Round the probabilities to 0/1; channel 1 is presumably the 'Gold marker' class.
prediction = torch.round(prediction)

show_histogram(x, title="Prediction", mask=prediction[1])

In [None]:
# Plot a confusion matrix of the learner's predictions against the targets.
# NOTE(review): presumably computed voxel-wise over the whole validation set, holding
# every prediction in memory; with 224^3 volumes confirm this fits (fastai also offers
# SegmentationInterpretation).
interp = fastai.interpret.ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()