In [None]:
import matplotlib.pyplot as plt
import numpy as np
import util
import nibabel as nib
import numpy as np
import torch
import torch.utils.data as data
import torchio as tio
import os
import sklearn.model_selection as skm
from imbalanced_regression.qsm.datasets import QSM
from imbalanced_regression.utils import get_lds_kernel_window
import logging
from scipy.ndimage import convolve1d
from torch.utils import data
import torchio.transforms as transforms

In [None]:
# Get case IDs.
# The case number is taken from characters [-9:-7] of each row of the case list
# (presumably rows are file names ending in "NN.nii.gz" — TODO confirm against the file).
# np.loadtxt is given the path directly, so no file handle is left open; atleast_1d keeps a
# single-row case list iterable.
CASE_LIST_PATH = '/home/ali/RadDBS-QSM/data/docs/cases_90'
lists = np.atleast_1d(np.loadtxt(CASE_LIST_PATH, comments="#", delimiter=",", unpack=False, dtype=str))
case_id = [row[-9:-7] for row in lists]

# Load scores
file_dir = '/home/ali/RadDBS-QSM/data/docs/QSM anonymus- 6.22.2023-1528.csv'
motor_df = util.filter_scores(file_dir, 'pre-dbs updrs', 'stim', 'CORNELL ID')
# Find cases with all required scores
subs, pre_imp, post_imp, pre_updrs_off = util.get_full_cases(motor_df,
                                                             'CORNELL ID',
                                                             'OFF (pre-dbs updrs)',
                                                             'ON (pre-dbs updrs)',
                                                             'OFF meds ON stim 6mo')
# Find overlap between scored subjects and nii
ids = np.asarray(case_id).astype(int)
# NOTE(review): case 62 is excluded explicitly; the reason is not recorded here — document it.
ids = ids[ids != 62]
# np.isin replaces np.in1d, which is deprecated in NumPy 2.x (same result for 1-D input).
cases_idx = np.isin(subs, ids)
ccases = subs[cases_idx]
per_change = post_imp[cases_idx]


In [None]:
# Pair each scored case (ccases, aligned index-by-index with per_change) with its QSM volume
# and segmentation. The case number is read from characters [18:20] of the QSM file name.
nii_paths = []
seg_nii_paths = []
qsm_dir = '/home/ali/RadDBS-QSM/data/nii/qsm/'
seg_dir = '/home/ali/RadDBS-QSM/data/nii/seg/'
qsm_niis = sorted(os.listdir(qsm_dir))
seg_niis = sorted(os.listdir(seg_dir))
for case in ccases:
    matches = [file for file in qsm_niis if int(case) == int(file[18:20])]
    # per_change holds exactly one target per case: a missing or duplicated volume would
    # otherwise shift every later image/target pair (or fail obscurely in train_test_split).
    if len(matches) != 1:
        raise ValueError(f'Expected exactly one QSM file for case {case}, found {len(matches)}: {matches}')
    file = matches[0]
    seg_name = 'labels_2iMag' + file[18:20] + '.nii.gz'
    if seg_name not in seg_niis:
        raise FileNotFoundError(f'No segmentation {seg_name} in {seg_dir} for case {case}')
    nii_paths.append(os.path.join(qsm_dir, file))
    seg_nii_paths.append(os.path.join(seg_dir, seg_name))

# Images, masks and targets are split in the same call so they stay index-aligned;
# random_state fixes the split for reproducibility.
train_dir, test_dir, train_seg, test_seg, y_train, y_test = skm.train_test_split(nii_paths, seg_nii_paths, per_change, test_size=0.2, random_state=1)
train_dir, val_dir, train_seg, val_seg, y_train, y_val = skm.train_test_split(train_dir, train_seg, y_train, test_size=0.2, random_state=1)

In [None]:
# Re-split with a 0.33 hold-out fraction. The mask paths are split together with the images and
# targets; splitting only images/targets here would leave train_seg/val_seg/test_seg from the
# previous cell, pairing each volume with another case's mask. The image/target results are
# the same as splitting them alone, because train_test_split applies one permutation to all arrays.
train_dir, test_dir, train_seg, test_seg, y_train, y_test = skm.train_test_split(
    nii_paths, seg_nii_paths, per_change, test_size=0.33, random_state=1)
train_dir, val_dir, train_seg, val_seg, y_train, y_val = skm.train_test_split(
    train_dir, train_seg, y_train, test_size=0.33, random_state=1)

In [None]:
import numpy as np
import nibabel as nib
import numpy as np
import torch
import torch.utils.data as data
import torchio as tio
from imbalanced_regression.utils import get_lds_kernel_window
import logging
from scipy.ndimage import convolve1d
from torch.utils import data
import torchio.transforms as transforms

class QSM(data.Dataset):
    """Dataset of QSM volumes paired with masks and scalar regression targets.

    Each item is ``(img, label, weight)``:
      * ``img`` is the volume restricted to the slices (last axis) in which the mask has at
        least one non-zero voxel, with ``nx`` voxels removed from each side of the first two
        axes, centre-cropped / zero-padded to exactly ``nz`` slices, randomly flipped when
        ``split == 'train'``, and intensity-rescaled to [0, 1];
      * ``label`` is ``targets[index]``;
      * ``weight`` is a 1-element float32 array from ``_prepare_weights`` (1.0 when no
        re-weighting is used).

    Args:
        data_dir: sequence of paths to the image NIfTI files.
        mask_dir: sequence of paths to the mask NIfTI files, index-aligned with ``data_dir``.
        targets: sequence of regression targets, index-aligned with ``data_dir``.
        nz: number of slices along the last axis after cropping/padding.
        nx: number of voxels removed from EACH side of the first two axes (not an output size).
        split: 'train' enables random flips; any other value disables augmentation.
        reweight, lds, lds_kernel, lds_ks, lds_sigma: re-weighting options, see ``_prepare_weights``.
        verbose: if True, print a diagnostic line each time an item is loaded.
    """

    def __init__(self, data_dir, mask_dir, targets, nz, nx, split='train', reweight='none',
                 lds=False, lds_kernel='gaussian', lds_ks=5, lds_sigma=2, verbose=False):
        # nib.load returns image objects; the voxel arrays are read later in __getitem__.
        self.images_list = [nib.load(image_path) for image_path in data_dir]
        self.masks_list = [nib.load(mask_path) for mask_path in mask_dir]
        self.data_dir = data_dir
        self.mask_dir = mask_dir
        self.targets = targets
        self.nz = nz
        self.nx = nx
        self.split = split
        self.verbose = verbose
        self.weights = self._prepare_weights(reweight=reweight, lds=lds, lds_kernel=lds_kernel, lds_ks=lds_ks, lds_sigma=lds_sigma)

    def __len__(self):
        return len(self.images_list)

    def __getitem__(self, index):
        """Load, mask-select and transform one volume; return ``(img, label, weight)``."""
        case_dir = self.data_dir[index]
        volume = np.asarray(self.images_list[index].dataobj)
        mask = np.asarray(self.masks_list[index].dataobj)
        if getattr(self, 'verbose', False):
            print('Applying mask of shape ', str(mask.shape), ' to image of size ', str(volume.shape), ' for ', case_dir)
        # Keep only the slices along the last axis where the mask has any non-zero voxel.
        img = torch.from_numpy(volume[:, :, ~(mask == 0).all((0, 1))])
        self.img_size = img.shape
        transform = self.get_transform(img, self.nx, self.nz)
        # torchio expects a channel-first 4-D tensor; drop only that channel axis afterwards so
        # a genuine size-1 spatial axis (e.g. nz == 1) is not squeezed away too.
        img = transform(torch.unsqueeze(img, 0)).squeeze(0)
        weight = np.asarray([self.weights[index]]).astype('float32') if self.weights is not None else np.asarray([np.float32(1.)])
        return img, self.targets[index], weight

    def get_transform(self, img, nx, nz):
        """Build the torchio pipeline for one (channel-less, 3-D) volume ``img``.

        Only ``img.shape`` is used. The pipeline:
          1. removes ``nx`` voxels from each side of the first two axes;
          2. fits the last axis to exactly ``nz`` slices, either by centre-cropping or by
             zero-padding symmetrically (an odd remainder goes to the end);
          3. applies random flips on all axes when ``self.split == 'train'``;
          4. rescales intensities to [0, 1].

        Returns:
            A ``transforms.Compose`` to apply to the channel-first 4-D tensor.
        """
        self.img = img
        self.img_size = img.shape
        n_slices = self.img_size[2]
        if n_slices > nz:
            # The crop must live inside the returned pipeline: cropping a separate copy of
            # ``img`` here would never reach the tensor the caller transforms.
            excess = n_slices - nz
            fit_depth = transforms.Crop((0, 0, 0, 0, excess // 2, excess - excess // 2))
        else:
            deficit = nz - n_slices
            fit_depth = transforms.Pad((0, 0, 0, 0, deficit // 2, deficit - deficit // 2))
        steps = [transforms.Crop((nx, nx, nx, nx, 0, 0)), fit_depth]
        if self.split == 'train':
            steps.append(transforms.RandomFlip(axes=['LR', 'AP', 'IS']))
        steps.append(transforms.RescaleIntensity(out_min_max=(0, 1)))
        return transforms.Compose(steps)

    def _prepare_weights(self, reweight, max_target=1, lds=False, lds_kernel='gaussian', lds_ks=5, lds_sigma=2):
        """Compute per-sample weights from the frequency of each integer target bin.

        Each target is binned as ``int(label)`` clamped to ``[0, max_target - 1]``. With the
        default ``max_target=1`` every target falls into the same bin, so the weights are uniform.

        Args:
            reweight: 'none', 'inverse' (bin counts clipped to [5, 1000]) or 'sqrt_inv'
                (square-rooted counts).
            max_target: number of integer bins.
            lds: if True, smooth the bin counts with ``get_lds_kernel_window`` before inverting.
            lds_kernel, lds_ks, lds_sigma: kernel type, size and sigma passed to
                ``get_lds_kernel_window``.

        Returns:
            A list of float weights scaled to sum to ``len(targets)``, or None when
            ``reweight == 'none'`` or there are no targets.
        """
        assert reweight in {'none', 'inverse', 'sqrt_inv'}
        assert reweight != 'none' if lds else True, \
            "Set reweight to \'sqrt_inv\' (default) or \'inverse\' when using LDS"

        labels = self.targets
        # Return before binning: weights are unused here, and binning must not fail on unusual targets.
        if reweight == 'none' or not len(labels):
            return None

        def bin_of(label):
            # Clamp on both sides: int() of a target <= -1 would otherwise be a negative key.
            return min(max_target - 1, max(0, int(label)))

        value_dict = {x: 0 for x in range(max_target)}
        for label in labels:
            value_dict[bin_of(label)] += 1
        if reweight == 'sqrt_inv':
            value_dict = {k: np.sqrt(v) for k, v in value_dict.items()}
        elif reweight == 'inverse':
            value_dict = {k: np.clip(v, 5, 1000) for k, v in value_dict.items()}  # clip weights for inverse re-weight
        num_per_label = [value_dict[bin_of(label)] for label in labels]
        print(f"Using re-weighting: [{reweight.upper()}]")

        if lds:
            lds_kernel_window = get_lds_kernel_window(lds_kernel, lds_ks, lds_sigma)
            print(f'Using LDS: [{lds_kernel.upper()}] ({lds_ks}/{lds_sigma})')
            smoothed_value = convolve1d(
                np.asarray([v for _, v in value_dict.items()]), weights=lds_kernel_window, mode='constant')
            num_per_label = [smoothed_value[bin_of(label)] for label in labels]

        weights = [np.float32(1 / x) for x in num_per_label]
        scaling = len(weights) / np.sum(weights)
        weights = [scaling * x for x in weights]
        return weights


In [None]:
# Training split: slices outside the mask are dropped, each volume is fitted to nz=128 slices,
# and nx=128 voxels are cropped from every in-plane edge (see QSM.get_transform).
train_dataset = QSM(
    data_dir=train_dir,
    mask_dir=train_seg,
    targets=y_train,
    nz=128,
    nx=128,
    split='train',
)

In [None]:
# Shuffled mini-batches of 5 training volumes; the final partial batch is kept.
train_loader = data.DataLoader(
    dataset=train_dataset,
    batch_size=5,
    shuffle=True,
    num_workers=1,
    pin_memory=True,
    drop_last=False,
)

In [None]:
# Smoke test: pull every training batch once and report its index.
for idx, batch in enumerate(train_loader):
    inputs, targets, weights = batch
    print(idx)

Applying mask of shape  (512, 512, 352)  to image of size  (512, 512, 352)  for  /home/ali/RadDBS-QSM/data/nii/qsm/QSM_e10_imaginary_44.nii.gz
