In [1]:
import torch
import torch.nn as nn
import numpy as np
import nibabel as nib

In [2]:
import kagglehub

# Fetch the latest version of the BraTS 2020 training/validation dataset.
# kagglehub hands back the local directory that holds the downloaded files.
dataset_handle = "awsaf49/brats20-dataset-training-validation"
path = kagglehub.dataset_download(dataset_handle)

print("Path to dataset files:", path)

  from .autonotebook import tqdm as notebook_tqdm


Path to dataset files: /Users/enesdemir/.cache/kagglehub/datasets/awsaf49/brats20-dataset-training-validation/versions/1


In [7]:
import os 
from torch.utils.data import Dataset, DataLoader


class BraTSDataset(Dataset):
    """Map-style dataset yielding one multi-channel volume and its label per case folder.

    Every entry of ``folders`` is a directory holding the NIfTI files of one
    case. Files whose name contains ``"seg"`` (case-insensitive) are used as the
    label; every other NIfTI file becomes one input channel.

    Args:
        folders: Sequence of case-directory paths.
        transform: Optional callable applied to the stacked input tensor and,
            separately, to the label tensor.
            NOTE(review): the same callable is invoked twice, independently.
            A random or intensity-changing transform would therefore
            de-synchronise (or corrupt) the label -- confirm this is intended
            before passing such a transform.
    """

    # File-name suffixes treated as NIfTI volumes; any other entry in a case
    # folder (e.g. a ``.DS_Store`` file) is ignored instead of crashing the loader.
    _NIFTI_SUFFIXES = (".nii", ".nii.gz")

    def __init__(self, folders, transform=None):
        self.transform = transform
        self.folders = folders

    def __len__(self):
        """Return the number of case folders."""
        return len(self.folders)

    def __getitem__(self, idx):
        """Load the case at position ``idx``.

        Returns:
            A tuple ``(input_tensor, label)``. ``input_tensor`` is a float32
            tensor with the input volumes stacked along a new leading channel
            axis, ordered by sorted file name. ``label`` is the float32
            segmentation volume, or ``None`` if the folder has no
            segmentation file.

        Raises:
            RuntimeError: If the folder contains no input NIfTI volume.
        """
        case_dir = self.folders[idx]
        modalities = []
        label = None

        # os.listdir() returns entries in arbitrary order. Sorting guarantees
        # that channel k is the same modality for every case.
        for file_name in sorted(os.listdir(case_dir)):
            lowered = file_name.lower()
            if not lowered.endswith(self._NIFTI_SUFFIXES):
                continue

            img_path = os.path.join(case_dir, file_name)
            volume = torch.from_numpy(nib.load(img_path).get_fdata()).float()

            # Case-insensitive match so a segmentation file with different
            # capitalisation is not silently stacked as an input channel.
            if "seg" in lowered:
                label = volume
            else:
                modalities.append(volume)

        if not modalities:
            raise RuntimeError(f"No input NIfTI volumes found in {case_dir!r}")
        input_tensor = torch.stack(modalities, dim=0)

        if self.transform:
            input_tensor = self.transform(input_tensor)
            # Unlabelled cases keep label=None rather than feeding None to the transform.
            if label is not None:
                label = self.transform(label)

        return input_tensor, label


In [8]:
from sklearn.model_selection import train_test_split

# Build on `path` as returned by kagglehub.dataset_download(...) in the download
# cell rather than hard-coding a user-specific absolute cache location, so the
# notebook runs unchanged on another machine.
real_path = os.path.join(path, 'BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData')

# Keep only case directories: this drops the .csv files as before, plus any
# other stray non-directory entry. Sort the result because os.listdir() returns
# entries in arbitrary order, and without a fixed input order `random_state`
# below would not pin the same train/validation split on every run or machine.
all_folders = sorted(
    entry_path
    for entry_path in (os.path.join(real_path, f) for f in os.listdir(real_path))
    if os.path.isdir(entry_path) and not entry_path.endswith('.csv')
)

# 80/20 split over whole cases.
train_folders, val_folders = train_test_split(all_folders, test_size=0.2, random_state=42)

training_set = BraTSDataset(train_folders)
training_loader = DataLoader(training_set,
                             batch_size=4,
                             shuffle=True,
                             num_workers=0)

validation_set = BraTSDataset(val_folders)
validation_loader = DataLoader(validation_set,
                               batch_size=4,
                               shuffle=False,
                               num_workers=0)

# Sanity check: pull the first batch from each loader and show its shapes
# (training loader first, then validation loader).
for loader in (training_loader, validation_loader):
    for images, labels in loader:
        print(images.shape)
        print(labels.shape)
        break

torch.Size([4, 4, 240, 240, 155])
torch.Size([4, 240, 240, 155])
torch.Size([4, 4, 240, 240, 155])
torch.Size([4, 240, 240, 155])
