In [1]:
import os
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import Dataset,DataLoader
import torch
import pandas as pd
import numpy as np
import nibabel as nib
from nibabel import ecat

In [2]:
class ScanDataSet(Dataset):
    """In-memory dataset pairing ECAT scans with one-hot encoded labels.

    Scan files are discovered recursively below ``image_root``, sorted by path,
    and paired row-by-row with the ``Label`` column of the CSV at
    ``label_root``. The CSV row order must therefore match the sorted file
    order. Frame 0 of every scan is loaded eagerly when the dataset is built.

    Args:
        image_root: Directory searched recursively for scan files.
        label_root: Path of a CSV file with an integer ``Label`` column.
        filetype: Substring (case-insensitive) a filename must contain to be
            loaded; ``""`` matches every file.
        num_classes: Length of each one-hot label vector. Defaults to 6.

    Attributes:
        samples: List of ``(scan, label)`` tuples, where ``scan`` is a tensor
            with a trailing singleton axis and ``label`` a one-hot row.
        disease: One-hot label matrix with one row per CSV row.

    Raises:
        FileNotFoundError: No file below ``image_root`` matches ``filetype``.
        ValueError: A label lies outside ``[0, num_classes)`` or the CSV has
            fewer rows than there are scans.
    """

    def __init__(self, image_root, label_root, filetype, num_classes=6):
        self.image_root = image_root
        self.label_root = label_root
        self.filetype = filetype  # indicates which kind of scan; "" matches all
        self.num_classes = num_classes
        self.samples = []
        self.disease = None  # replaced by the one-hot matrix in _init_dataset
        self._init_dataset()

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        scan, disease = self.samples[idx]
        return scan, disease

    def _load_labels(self):
        """Read the ``Label`` column and return it as a one-hot matrix."""
        label_table = pd.read_csv(self.label_root)
        labels = label_table.Label.astype(np.int64).to_numpy()

        # Negative values would otherwise index from the end and silently
        # switch on the wrong class, so reject everything out of range.
        out_of_range = (labels < 0) | (labels >= self.num_classes)
        if out_of_range.any():
            raise ValueError(
                "Labels must lie in [0, {}); found {}".format(
                    self.num_classes, sorted(set(labels[out_of_range].tolist()))
                )
            )

        # Row i of the identity matrix is the one-hot vector for class i.
        return np.eye(self.num_classes, dtype=int)[labels]

    def _find_scans(self):
        """Return the sorted paths of all files whose name contains ``filetype``."""
        wanted = self.filetype.lower()  # compare both sides in lower case
        scan_paths = [
            os.path.join(dir_name, filename)
            for dir_name, _subdirs, file_list in os.walk(self.image_root)
            for filename in file_list
            if wanted in filename.lower()
        ]
        scan_paths.sort()
        return scan_paths

    def _init_dataset(self):
        """Load the labels and scans and fill ``self.disease`` / ``self.samples``."""
        import warnings

        diseases = self._load_labels()
        self.disease = diseases

        scan_paths = self._find_scans()
        if not scan_paths:
            raise FileNotFoundError(
                "No file containing {!r} found under {!r}".format(
                    self.filetype, self.image_root
                )
            )
        if len(diseases) < len(scan_paths):
            raise ValueError(
                "Found {} scans but only {} labels in {!r}".format(
                    len(scan_paths), len(diseases), self.label_root
                )
            )
        if len(diseases) > len(scan_paths):
            # Kept non-fatal for backward compatibility, but the row-by-row
            # pairing is likely misaligned when the counts differ.
            warnings.warn(
                "{} labels but only {} scans; surplus labels are ignored".format(
                    len(diseases), len(scan_paths)
                )
            )

        samples = []
        for scan_path, one_hot in zip(scan_paths, diseases):
            frame = ecat.load(scan_path).get_frame(0)
            # Append a trailing axis of size 1 before converting to a tensor.
            scan = torch.from_numpy(frame[..., np.newaxis])
            samples.append((scan, one_hot))
        self.samples = samples
                
    

In [3]:
def show_scan(sample, sliceNr=63):
    """Display one slice of a scan together with its label.

    Args:
        sample: ``(image, label)`` pair. ``image`` is a 3-D volume, optionally
            with a trailing axis of size 1; anything ``np.asarray`` can
            convert (e.g. a CPU tensor) is accepted.
        sliceNr: Index along the third axis of the slice to draw.

    Raises:
        ValueError: ``image`` is not a 3-D volume after dropping a trailing
            singleton axis.
        IndexError: ``sliceNr`` is outside the third axis of the volume.
    """
    # Imported here because no visible cell of the notebook imports pyplot.
    import matplotlib.pyplot as plt

    image = sample[0]
    label = sample[1]

    volume = np.asarray(image)
    if volume.ndim == 4 and volume.shape[-1] == 1:
        volume = volume[..., 0]  # drop the trailing singleton axis
    if volume.ndim != 3:
        raise ValueError(
            "expected a 3-D volume, got shape {}".format(volume.shape)
        )

    depth = volume.shape[2]
    if not -depth <= sliceNr < depth:
        raise IndexError(
            "sliceNr {} is out of range for {} slices".format(sliceNr, depth)
        )

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(volume[:, :, sliceNr])
    ax.set_title("Slice {} - label {}".format(sliceNr, np.asarray(label).tolist()))

In [4]:
# Where the scans and the CSV with the labels live.
image_root = 'projectfiles_PE2I/scans/ecat_scans/'
label_root = 'projectfiles_PE2I/patientlist.csv'
filetype = "1.v"  # only filenames containing this substring are loaded

# Build the dataset, then serve it in shuffled batches of four using four
# worker processes.
dataset = ScanDataSet(image_root=image_root, label_root=label_root, filetype=filetype)
dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True, num_workers=4)

In [68]:
# Iterate over the loader once and print the shapes of each batch's two parts.
for i_batch, sample_batched in enumerate(dataloader):
    scan_batch, label_batch = sample_batched
    print(i_batch, scan_batch.size(), label_batch.size())

0 torch.Size([4, 128, 128, 128, 1]) torch.Size([4, 6])
1 torch.Size([4, 128, 128, 128, 1]) torch.Size([4, 6])
2 torch.Size([4, 128, 128, 128, 1]) torch.Size([4, 6])
3 torch.Size([4, 128, 128, 128, 1]) torch.Size([4, 6])
4 torch.Size([4, 128, 128, 128, 1]) torch.Size([4, 6])
5 torch.Size([4, 128, 128, 128, 1]) torch.Size([4, 6])
6 torch.Size([4, 128, 128, 128, 1]) torch.Size([4, 6])


In [23]:
# Fetch one (scan, label) pair and print the mean value of its scan.
sample = dataset[7]
print(sample[0].mean())


tensor(18.6005, dtype=torch.float64)


In [55]:
import torchvision

# The original cell passed the undefined name `Datasest` as the mean.
# NOTE(review): presumably the dataset-wide mean was intended - confirm.
# Averaging the per-scan means equals the global mean only when every scan
# has the same number of voxels.
dataset_mean = torch.stack([scan.mean() for scan, _ in dataset]).mean().item()

# Normalize takes one mean/std entry per channel. std stays at 1 as in the
# original call, so the transform only centres the data.
# NOTE(review): Normalize documents channel-first input while these scans
# carry their singleton axis last; a one-element mean should still broadcast
# over the whole scan, but verify on a real sample.
normData = torchvision.transforms.Normalize(mean=[dataset_mean], std=[1.0])

In [56]:
# NOTE(review): leftover scratch cell. The bare name `sam` looks truncated, and
# the TypeError recorded below says a Normalize object was indexed - presumably
# the intent was to apply the transform to a scan tensor. Confirm or delete.
sam

TypeError: 'Normalize' object does not support indexing