# Data operations for the PI-CAI dataset

*José Guilherme de Almeida*

Here I develop data operations that retrieve data from the PI-CAI dataset. The outcome is a MONAI data generator that yields 1) the MRI data, 2) the segmentation mask, and 3) the biological data/labels associated with each scan.

In [1]:
import os
import numpy as np
import pandas as pd
from tqdm import tqdm,trange
from glob import glob
from lib.dataoperations.lib.data_functions import *

## Implementing a CSV dataset and labels class

In [2]:
# Location of the PI-CAI marksheet CSV that is parsed below.
clinical_info_path = "/home/jose_almeida/data/PI-CAI/labels/clinical_information/marksheet.csv"

def _max_gleason_sum(lesion_gs:str)->int:
    """Returns the largest Gleason sum listed in a `lesion_GS` field.

    The field is a comma-separated list of entries such as "3+4"; empty
    entries and "N/A" are ignored. Returns 0 when no score is present.
    """
    best = 0
    for entry in lesion_gs.split(','):
        entry = entry.strip()
        if entry != "N/A" and entry != "":
            best = max(best,sum(int(part) for part in entry.split('+')))
    return best

def retrieve_clinical_data(path:str)->dict:
    """Parses the clinical marksheet CSV into a per-study dictionary.

    Args:
        path (str): path to a CSV file whose header contains at least the
            columns study_id, patient_age, psa, prostate_volume,
            histopath_type and lesion_GS.

    Returns:
        dict: maps each study_id (str) to a dictionary with keys "age",
            "psa" and "volume" (raw strings as found in the file), "gs"
            (int, highest Gleason sum among the lesions, 0 if none) and
            "label" (1 if histopath_type is filled in and gs > 6, else 0).
            An empty file yields an empty dictionary.

    Raises:
        KeyError: if one of the required columns is missing in the header.
        ValueError: if a Gleason entry is not of the form "int+int".
    """
    # local import keeps this cell self-contained
    import csv

    relevant_fields = [
        'study_id','patient_age','psa',
        'prostate_volume','histopath_type','lesion_GS']

    with open(path) as o:
        lines = [x.strip() for x in o]
    # blank lines carry no record; parsing row by row (rather than column by
    # column) guarantees that a short row cannot shift the values of the
    # rows that come after it
    rows = list(csv.reader(x for x in lines if x != ""))
    if len(rows) == 0:
        return {}

    header = rows[0]
    missing = [k for k in relevant_fields if k not in header]
    if len(missing) > 0:
        raise KeyError(
            "missing columns in {}: {}".format(path,", ".join(missing)))

    dataset_dict = {}
    for row in rows[1:]:
        record = dict(zip(header,row))
        # fields absent from a truncated row are treated as empty
        s,a,p,v,c,l = [record.get(k,"") for k in relevant_fields]
        l_out = _max_gleason_sum(l)
        # positive only when histopathology exists and Gleason sum is >= 7
        label = 1 if (c.strip() != "" and l_out > 6) else 0
        dataset_dict[s] = {
            "age":a,"psa":p,"volume":v,"gs":l_out,"label":label}

    return dataset_dict

# Parse the marksheet into a dictionary keyed by study identifier.
clinical_data = retrieve_clinical_data(clinical_info_path)
        
# Show the parsed records as a table (studies end up as columns).
pd.DataFrame.from_dict(clinical_data)

Unnamed: 0,1000000,1000001,1000002,1000003,1000004,1000005,1000006,1000007,1000008,1000009,...,1001490,1001491,1001492,1001493,1001494,1001495,1001496,1001497,1001498,1001499
age,74.0,64.0,58.0,72.0,67,64.0,73.0,68.0,81.0,68,...,77,61.0,64.0,49.0,62.0,71.0,81.0,56.0,75,56
psa,7.7,8.7,4.2,13.0,8,12.1,6.2,3.83,11.1,24,...,11,5.4,4.3,4.3,9.1,12.5,5.28,29.6,12,15
volume,55.0,102.0,74.0,71.5,78,51.0,27.0,41.0,56.0,120,...,65,103.0,23.0,34.0,47.0,62.0,44.0,87.0,83,33
gs,0.0,0.0,0.0,0.0,0,7.0,6.0,6.0,7.0,0,...,0,6.0,7.0,0.0,0.0,7.0,7.0,0.0,6,9
label,0.0,0.0,0.0,0.0,0,1.0,0.0,0.0,1.0,0,...,0,0.0,1.0,0.0,0.0,1.0,1.0,0.0,0,1


In [3]:
class StrToOneHot:
    """Callable that one-hot encodes a value from a fixed list of options.

    The options are sorted on construction and each one is assigned the
    index of its position in the sorted list; calling the instance with an
    option returns a 1-D float tensor with a single 1 at that index.
    Values that were not given on construction raise a KeyError.
    """
    def __init__(self,sl:list[str]):
        ordered = sorted(sl)
        self.sl = ordered
        # option -> position in the sorted list
        self.match = dict(zip(ordered,range(len(ordered))))
        self.n = len(sl)

    def __call__(self,X:str):
        position = self.match[X]
        one_hot = np.zeros([self.n])
        one_hot[position] = 1
        return torch.Tensor(one_hot)

class TabularData(monai.data.Dataset):
    """Dataset over a dictionary of tabular records.

    Each entry of `data_dict` maps an identifier to a record (a dictionary
    of field name -> raw value). Fields declared as "categorical" are
    one-hot encoded with StrToOneHot, every other field is converted to a
    one-element float tensor.

    Args:
        data_dict (dict): identifier -> record dictionary.
        data_types (dict, optional): field name -> "categorical" or any
            other string (treated as numerical). Defaults to None, in which
            case all fields of the first record are treated as numerical.

    Items can be retrieved either by identifier or by integer position in
    `self.keys`.
    """
    def __init__(self,data_dict:dict,
                 data_types:dict=None)->None:
        self.data_dict = data_dict
        self.data_types = data_types
        self.keys = list(data_dict.keys())

        if self.data_types is None:
            # an empty data_dict has no record to infer the fields from
            first_record = (
                self.data_dict[self.keys[0]] if len(self.keys) > 0 else {})
            self.data_types = {k:"numerical" for k in first_record}
        self.setup_transforms()

    def setup_transforms(self):
        """Builds one conversion callable per field in `self.data_types`."""
        self.transforms = {}
        for k in self.data_types:
            if self.data_types[k] == "categorical":
                # empty strings mark missing values and are not a category
                all_options = {
                    self.data_dict[kk][k] for kk in self.data_dict}
                all_options = [x for x in all_options if x != ""]
                self.transforms[k] = StrToOneHot(all_options)
            else:
                self.transforms[k] = self.to_float

    def remove_incomplete_keys(self):
        """Drops from `self.keys` every identifier whose record cannot be
        converted (missing field, non-numeric value, unknown category)."""
        o = []
        for i in self.keys:
            try:
                self[i]
            except (KeyError,ValueError,TypeError):
                # conversion failed: the record is incomplete, skip it
                continue
            o.append(i)
        self.keys = o

    def to_float(self,X):
        """Converts a single value to a one-element float tensor."""
        return torch.Tensor([float(X)])

    def __len__(self):
        # integer indexing goes through self.keys, so the length has to
        # follow it (it shrinks after remove_incomplete_keys)
        return len(self.keys)

    def __getitem__(self,i):
        """Returns the converted record for identifier (or position) `i`."""
        if isinstance(i,int):
            i = self.keys[i]

        record = self.data_dict[i]
        return {
            k:self.transforms[k](record[k])
            for k in self.transforms}

# Tabular dataset over the clinical records: the Gleason sum ("gs") is
# one-hot encoded, every other field is converted to a float tensor.
clinical_dataset = TabularData(
    clinical_data,
    {'age':'numerical','psa':'numerical',
    'volume':'numerical','gs':'categorical','label':'numerical'})

# Keep only the studies whose record converts without raising.
clinical_dataset.remove_incomplete_keys()

# Measure how long each individual item retrieval takes (in seconds).
import time
times = []
for i in clinical_dataset.keys:
    a = time.time()
    clinical_dataset[i]
    b = time.time()
    times.append(b-a)

## Using MONAIDataset to create an MRI scan dataset

In [4]:
# Collect, for every study, the path of each MRI modality and record the
# voxel grid size and spacing stored in the header of each file.
all_files = sorted(glob("/home/jose_almeida/data/PI-CAI/dataset/*/*/*mha"))

path_dictionary = {}
reader = sitk.ImageFileReader()
# reader option, only needs to be switched on once
reader.LoadPrivateTagsOn()
modalities = ['adc','hbv','t2w']
metadata_dictionary = {"shape":{k:[] for k in modalities},
                       "spacing":{k:[] for k in modalities}}
for f in all_files:
    # the study identifier and the modality are the last two
    # underscore-separated tokens of the path
    study_id = f.split('_')[-2]
    mod = f.split('_')[-1].split('.')[0]
    if mod not in modalities:
        # nothing is stored for other modalities, so skip the header read
        continue

    path_dictionary.setdefault(study_id,{})[mod] = f

    # only the header is read here, not the voxel data
    reader.SetFileName(f)
    reader.ReadImageInformation()
    sh = reader.GetSize()
    sp = reader.GetSpacing()

    metadata_dictionary['shape'][mod].append(sh)
    metadata_dictionary['spacing'][mod].append(sp)

In [37]:
# Median grid size per modality (integer number of voxels per axis).
shape_dict = {
    k:tuple(np.int32(np.median(np.array(metadata_dictionary['shape'][k]),axis=0)))
    for k in modalities}

# Median voxel spacing per modality. Spacing is a physical size and must stay
# floating point: truncating it to integers turns sub-unit spacings into 0.
# NOTE(review): MONAI documents non-positive pixdim components as "keep the
# original spacing", which is why the truncated values did not fail - confirm.
spacing_dict = {
    k:tuple(float(x) for x in
            np.median(np.array(metadata_dictionary['spacing'][k]),axis=0))
    for k in modalities}

# Every modality is resampled to the median T2W spacing and then resized to
# the median T2W grid size so that all volumes share the same shape.
transforms = []
for k in modalities:
    transforms.extend(
        [monai.transforms.Spacingd([k],tuple(spacing_dict["t2w"])),
         monai.transforms.Resized([k],tuple(shape_dict["t2w"]))])

In [38]:
# Image dataset over the per-study modality paths collected above, with the
# spacing/resize transforms attached. MONAIDataset is not defined in this
# notebook (presumably provided by the star import from
# lib.dataoperations.lib.data_functions - TODO confirm).
dataset = MONAIDataset(
    path_dictionary,data_type="mha",padding=False,orientation="RAS",
    image_keys=modalities,transforms=transforms)

In [39]:
# Sanity check: iterate over the dataset and print the shape of each modality.
for d in dataset:
    print({k:d[k].shape for k in modalities})

{'adc': torch.Size([1, 384, 384, 21]), 'hbv': torch.Size([1, 384, 384, 21]), 't2w': torch.Size([1, 384, 384, 21])}
{'adc': torch.Size([1, 384, 384, 21]), 'hbv': torch.Size([1, 384, 384, 21]), 't2w': torch.Size([1, 384, 384, 21])}
{'adc': torch.Size([1, 384, 384, 21]), 'hbv': torch.Size([1, 384, 384, 21]), 't2w': torch.Size([1, 384, 384, 21])}
{'adc': torch.Size([1, 384, 384, 21]), 'hbv': torch.Size([1, 384, 384, 21]), 't2w': torch.Size([1, 384, 384, 21])}
{'adc': torch.Size([1, 384, 384, 21]), 'hbv': torch.Size([1, 384, 384, 21]), 't2w': torch.Size([1, 384, 384, 21])}
{'adc': torch.Size([1, 384, 384, 21]), 'hbv': torch.Size([1, 384, 384, 21]), 't2w': torch.Size([1, 384, 384, 21])}
{'adc': torch.Size([1, 384, 384, 21]), 'hbv': torch.Size([1, 384, 384, 21]), 't2w': torch.Size([1, 384, 384, 21])}
{'adc': torch.Size([1, 384, 384, 21]), 'hbv': torch.Size([1, 384, 384, 21]), 't2w': torch.Size([1, 384, 384, 21])}
{'adc': torch.Size([1, 384, 384, 21]), 'hbv': torch.Size([1, 384, 384, 21]), 't2

KeyboardInterrupt: 

## Using MONAIDataset to create a segmentation map dataset

In [None]:
# Paths of the expert lesion delineations, in two variants ("resampled" and
# "original"). NOTE(review): only the "original" paths are used below;
# labels_path / all_label_paths are computed but not referenced in this cell.
labels_path = "/home/jose_almeida/data/PI-CAI/labels/csPCa_lesion_delineations/human_expert/resampled/*"
all_label_paths = glob(os.path.join(labels_path,"*"))
labels_path_original = "/home/jose_almeida/data/PI-CAI/labels/csPCa_lesion_delineations/human_expert/original/*"
all_label_paths_original = glob(os.path.join(labels_path_original,"*"))

# Map each study identifier to its segmentation file.
path_dictionary_labels = {}
for label_path in all_label_paths_original:
    # the identifier is the last underscore-separated token of the path,
    # up to the first dot
    study_id = label_path.split('_')[-1].split('.')[0]
    path_dictionary_labels[study_id] = {"segmentation_labels":label_path}

# Dataset yielding the segmentation map of each study under the
# "segmentation_labels" key.
segmentation_map_dataset = MONAIDataset(
    path_dictionary_labels,data_type="nifti",padding=False,
    orientation="RAS",image_keys=["segmentation_labels"])

## The MultiDataset class

In [64]:
class MultiDataset(monai.data.Dataset):
    """Combines several key-indexed datasets into a single dataset.

    Every dataset must expose a `keys` attribute (list of identifiers) and
    return a dictionary when indexed with one of its identifiers. Only the
    identifiers present in all datasets are kept, and an item is the merge
    of the dictionaries returned by each dataset, restricted to
    `data_keys`.

    Args:
        *datasets (monai.data.Dataset): datasets to be combined.
        data_keys (list): names of the entries kept in each item.
        transforms (monai.transforms.Transform, optional): callable applied
            to the merged dictionary. Defaults to None (no transform).
    """
    def __init__(self,
                 *datasets:monai.data.Dataset,
                 data_keys:list,
                 transforms:monai.transforms.Transform=None):
        self.datasets = datasets
        self.data_keys = data_keys
        self.transforms = transforms

        self.get_common_keys()

    def get_common_keys(self):
        """Stores in `self.keys` the identifiers shared by all datasets,
        in the order in which they appear in the first dataset."""
        self.keys = []
        if len(self.datasets) == 0:
            return
        other_key_sets = [set(d.keys) for d in self.datasets[1:]]
        seen = set()
        for k in self.datasets[0].keys:
            # membership (not occurrence counting) decides, so an identifier
            # listed twice in one dataset is neither dropped nor duplicated
            if k in seen:
                continue
            seen.add(k)
            if all(k in key_set for key_set in other_key_sets):
                self.keys.append(k)

    def __len__(self):
        return len(self.keys)

    def __getitem__(self,i):
        """Returns the merged item for identifier (or position) `i`."""
        if isinstance(i,int):
            i = self.keys[i]

        output = {}
        for dataset in self.datasets:
            # fetch the item once: indexing may load and transform images,
            # which is expensive and should not be repeated for every key
            item = dataset[i]
            output.update(
                {k:item[k] for k in item if k in self.data_keys})

        if self.transforms is not None:
            output = self.transforms(output)

        return output
        
# Combine images, clinical records and segmentation maps, keeping only the
# entries listed in data_keys for each item.
comb_dataset = MultiDataset(
    dataset,clinical_dataset,segmentation_map_dataset,
    data_keys=["adc","hbv","t2w","segmentation_labels","gs","volume","psa"])

# Batched loading of the combined dataset with 4 worker processes.
comb_dataloader = torch.utils.data.DataLoader(
    comb_dataset,num_workers=4,batch_size=2,prefetch_factor=10)

print("Number of common keys:",len(comb_dataset.keys))
# The string literal below is a disabled debugging loop (it is never
# executed): it would print the shape of every entry of every item.
"""
for i in range(len(comb_dataset)):
    dat = comb_dataset[i]
    print({k:dat[k].shape for k in dat})
"""
# Iterate once over the whole dataloader (with a progress bar) to check that
# every batch can be produced.
for element in tqdm(comb_dataloader):
    pass

Number of common keys: 1237
{'adc': torch.Size([1, 116, 114, 31])}
{'adc': torch.Size([1, 120, 128, 21])}
{'adc': torch.Size([1, 120, 128, 22])}
{'adc': torch.Size([1, 120, 128, 23])}
{'adc': torch.Size([1, 120, 128, 21])}
{'adc': torch.Size([1, 84, 128, 19])}
{'adc': torch.Size([1, 120, 128, 21])}
{'adc': torch.Size([1, 120, 128, 19])}
{'adc': torch.Size([1, 256, 256, 27])}
{'adc': torch.Size([1, 120, 128, 19])}
{'adc': torch.Size([1, 224, 224, 16])}
{'adc': torch.Size([1, 84, 128, 19])}


KeyboardInterrupt: 