In [1]:
import numpy as np
import pandas as pd
import os

from sklearn.model_selection import train_test_split
import SimpleITK as sitk
import cv2

import torch as tc
import torch.nn as nn
from torchvision import models

In [2]:
def get_files(data_path, folder, extension=".nrrd", exclude_suffix=".seg.nrrd"):
    """Collect image file paths from ``<data_path>/<folder>/<case>/``.

    Only files ending in ``extension`` are kept. Files ending in
    ``exclude_suffix`` (the segmentation masks, which the Dataset loads next
    to their image) are skipped. Non-directory entries directly under
    ``<data_path>/<folder>`` are ignored.

    Both directory levels are traversed in sorted order. ``os.listdir``
    returns entries in arbitrary order, and the seeded train/val/test split
    downstream is only reproducible if its input order is fixed.

    Args:
        data_path: root directory of the data.
        folder: name of the source collection under ``data_path``.
        extension: suffix a file must have to be collected.
        exclude_suffix: suffix of files to skip; pass ``None``/``""`` to keep all.

    Returns:
        list[str]: paths joined from ``data_path``, ``folder``, case folder and file name.
    """
    folder_path = os.path.join(data_path, folder)
    names = []
    for subfolder in sorted(os.listdir(folder_path)):
        subfolder_path = os.path.join(folder_path, subfolder)
        if not os.path.isdir(subfolder_path):
            # Stray files (README, checksums, ...) would make listdir fail.
            continue
        for file in sorted(os.listdir(subfolder_path)):
            if not file.endswith(extension):
                continue
            if exclude_suffix and file.endswith(exclude_suffix):
                continue
            names.append(os.path.join(subfolder_path, file))
    return names

def split_data(names_list, pre_test_size=0.2, val_ratio=0.5, seed=5):
    """Split a list of names into (train, val, test) with two seeded splits.

    First split: ``pre_test_size`` of ``names_list`` is held out and the
    remainder is the training set. Second split: the held-out pool is divided
    again using ``val_ratio`` as sklearn's ``test_size``.

    NOTE(review): despite its name, ``val_ratio`` is therefore the fraction of
    the held-out pool that goes to *test*, not validation. At the default 0.5
    the two parts are roughly equal, so the distinction only matters for
    other values. Very small lists may make ``train_test_split`` raise.

    Args:
        names_list: items to split.
        pre_test_size: fraction held out for val+test.
        val_ratio: ``test_size`` of the second split (see note above).
        seed: ``random_state`` used for both splits.

    Returns:
        tuple: (train_names, val_names, test_names).
    """
    train_part, held_out = train_test_split(
        names_list, test_size=pre_test_size, random_state=seed
    )
    val_part, test_part = train_test_split(
        held_out, random_state=seed, test_size=val_ratio
    )
    return (train_part, val_part, test_part)
    
Data_path = r"/kaggle/input/mri-data/Data"

# Collect the image paths of each source collection, in this fixed order.
Dongyang_data_names, KiTS_data_names, Rider_data_names = [
    get_files(Data_path, source) for source in ("Dongyang", "KiTS", "Rider")
]

# Split every collection on its own (same seed), so each source contributes
# to all three partitions.
(
    (D_train_names, D_val_names, D_test_names),
    (K_train_names, K_val_names, K_test_names),
    (R_train_names, R_val_names, R_test_names),
) = [
    split_data(collection)
    for collection in (Dongyang_data_names, KiTS_data_names, Rider_data_names)
]

# Pool the per-source partitions: Dongyang first, then KiTS, then Rider.
train_data_names, val_data_names, test_data_names = (
    dongyang_part + kits_part + rider_part
    for dongyang_part, kits_part, rider_part in zip(
        (D_train_names, D_val_names, D_test_names),
        (K_train_names, K_val_names, K_test_names),
        (R_train_names, R_val_names, R_test_names),
    )
)

In [3]:
class Dataset(tc.utils.data.Dataset):
    """2-D slice dataset built from paired ``X.nrrd`` image / ``X.seg.nrrd`` mask volumes.

    Each volume is read with SimpleITK, transposed, cropped on axis 1 when
    that axis is 666 long (central 512 kept), and split along its last axis
    into 2-D slices. ``images[k]``, ``masks[k]`` and ``images_names[k]``
    always describe the same slice. Names have the form
    ``"<file name>Slice<index in volume>"``.

    Slices are stored raw after construction. Call ``Preprocess_Data()`` to
    resize, filter and normalise them into float32 torch tensors.
    """

    # Volumes whose axis 1 is exactly this long are cropped to [77, 589),
    # i.e. the central 512 entries.
    _CROP_FROM = 666
    _CROP_START, _CROP_STOP = 77, 589

    def __init__(self, data_path, data_names):
        """Load every image volume in ``data_names`` together with its mask.

        Args:
            data_path: directory joined in front of each name. An absolute
                name makes this a no-op, because ``os.path.join`` keeps it.
            data_names: image file names/paths. Names ending in ``.seg.nrrd``
                are skipped, because masks are loaded next to their image.

        Raises:
            ValueError: if an image and its mask have a different number of
                slices. The image and mask lists would otherwise fall out of
                step silently.
        """
        self.data_path = data_path
        self.images = []
        self.masks = []
        self.images_names = []

        image_names = [n for n in data_names if not n.endswith(".seg.nrrd")]
        total = len(image_names)
        for count, name in enumerate(image_names, start=1):
            image_path = os.path.join(data_path, name)
            image = self._load_volume(image_path)
            mask = self._load_volume(self._mask_path(image_path))
            if image.shape[2] != mask.shape[2]:
                raise ValueError(
                    f"{name}: image has {image.shape[2]} slices "
                    f"but its mask has {mask.shape[2]}"
                )

            base_name = os.path.basename(name)
            for i in range(image.shape[2]):
                self.images.append(image[:, :, i])
                self.masks.append(mask[:, :, i])
                self.images_names.append(base_name + "Slice" + str(i))

            print(f"Data loading progress: {count / total * 100:.1f} %")

    @staticmethod
    def _mask_path(image_path):
        """Map an image path to its mask path by suffix: ``X.nrrd`` -> ``X.seg.nrrd``."""
        root, ext = os.path.splitext(image_path)
        return root + ".seg" + ext

    @classmethod
    def _crop_axis1(cls, volume):
        """Crop axis 1 to [77, 589) when it is exactly 666 long; otherwise return it unchanged."""
        if volume.shape[1] == cls._CROP_FROM:
            return volume[:, cls._CROP_START:cls._CROP_STOP, :]
        return volume

    @classmethod
    def _load_volume(cls, path):
        """Read ``path`` with SimpleITK, transpose the array and apply the axis-1 crop."""
        volume = sitk.GetArrayFromImage(sitk.ReadImage(path)).T
        return cls._crop_axis1(volume)

    @staticmethod
    def _min_max(array):
        """Min-max normalise to float32 in [0, 1]; a constant array maps to all zeros."""
        values = np.asarray(array, dtype=np.float64)
        lo, hi = values.min(), values.max()
        if hi == lo:
            # Avoid 0/0 -> NaN on blank/constant slices.
            return np.zeros(values.shape, dtype=np.float32)
        return ((values - lo) / (hi - lo)).astype(np.float32)

    @classmethod
    def _to_uint8(cls, array):
        """Min-max rescale ``array`` to 0..255 and cast it to uint8."""
        return np.rint(cls._min_max(array) * 255).astype(np.uint8)

    def Preprocess_Data(self, size=224):
        """Resize, filter and normalise every stored slice in place.

        Images:
            1. For names starting with ``"R"``, pixels equal to -2000 are set to 0.
            2. Resize to ``size`` x ``size`` with OpenCV's default (bilinear) interpolation.
            3. Apply a 3x3 median blur.
            4. Rescale to uint8, because ``cv2.equalizeHist`` only accepts
               8-bit single-channel input, then histogram-equalise.
            5. Min-max normalise to [0, 1].

        Masks are resized with nearest-neighbour interpolation, so label
        values are not blended into fractional ones.

        Afterwards ``images`` and ``masks`` hold float32 torch tensors.
        Intended to be called once; a second call would filter the images again.

        Args:
            size: output side length in pixels (default 224).
        """
        for i, (image, name) in enumerate(zip(self.images, self.images_names)):
            # Make a C-ordered float32 copy. Transposed-volume slices are not
            # C-contiguous, and the copy keeps the loaded volume unmodified.
            pixels = np.array(image, dtype=np.float32, order="C")
            if name.startswith("R"):
                pixels[pixels == -2000] = 0
            pixels = cv2.resize(pixels, (size, size))
            pixels = cv2.medianBlur(pixels, 3)
            pixels = cv2.equalizeHist(self._to_uint8(pixels))
            self.images[i] = tc.from_numpy(self._min_max(pixels))

        for i, mask in enumerate(self.masks):
            labels = np.array(mask, dtype=np.float32, order="C")
            labels = cv2.resize(labels, (size, size), interpolation=cv2.INTER_NEAREST)
            self.masks[i] = tc.from_numpy(labels)

    def __len__(self):
        """Number of 2-D slices."""
        return len(self.images)

    def __getitem__(self, index):
        """Return the ``(image, mask)`` pair for slice ``index``.

        These are numpy views before ``Preprocess_Data`` and float32 tensors after it.
        """
        return self.images[index], self.masks[index]

    def Get_name(self, index):
        """Return the ``"<file name>Slice<i>"`` name of slice ``index``."""
        return self.images_names[index]

In [4]:
# Three hand-picked image volumes, one from each source collection
# (Dongyang, KiTS, Rider). Presumably used to spot-check loading and
# preprocessing on a small sample. TODO confirm how later cells use it, and
# whether these cases may also appear in the train split.
Testing_files = ['/kaggle/input/mri-data/Data/Dongyang/D6/D6.nrrd', 
                 '/kaggle/input/mri-data/Data/KiTS/K10/K10.nrrd', 
                 '/kaggle/input/mri-data/Data/Rider/R4/R4.nrrd']