## This file downloads and processes the Brain MRI dataset. Do not run it again unless you need to modify the dataset.
## If you want to load the dataset in your program, please use the `load_data` function in the utils.py file.

In [2]:
import kaggle
import numpy as np
import torch
from matplotlib import pyplot as plt
import os
from shutil import copyfile
from PIL import Image
from torch.utils.data import DataLoader
import torch.nn as nn
from collections import OrderedDict
from tqdm import tqdm
import torch.optim as optim
import pickle
from utils import load_data
from config import DATA_HOME
from config import DATA_SET

# Run on the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Kaggle identifier of the dataset that gets downloaded
kaggle_path = 'mateuszbuda/lgg-mri-segmentation'

# Directories under DATA_HOME: the per-patient source folders, and the two
# flat folders that store the sorted imgs and masks
path = f'{DATA_HOME}/kaggle_3m'
path_img = f'{DATA_HOME}/img'
path_mask = f'{DATA_HOME}/mask'

In [2]:
# Download the dataset and arrange the directories in a convenient manner.
# Do not run this cell again unless you have to: it is time-consuming.
# You should set up your Kaggle authentication information first.
kaggle.api.authenticate()
kaggle.api.dataset_download_files(kaggle_path, path=DATA_HOME, unzip=True)

# copyfile does not create missing parent directories, so make sure both
# destination folders exist before anything is copied into them.
os.makedirs(path_img, exist_ok=True)
os.makedirs(path_mask, exist_ok=True)

# Every sub-directory of `path` is treated as one patient folder; loose files
# sitting next to those folders are skipped.
patients = os.listdir(path)
folds = [
    os.path.join(path, patient)
    for patient in patients
    if os.path.isdir(os.path.join(path, patient))
]

# Collect the full path of every file inside every patient folder.
all_images = []
for fold in folds:
    all_images.extend(os.path.join(fold, name) for name in os.listdir(fold))

# Copy each file into the flat img/ or mask/ folder. A mask is stored under
# the name of the image it belongs to (the "_mask" suffix is dropped), so the
# two folders end up with identical file names.
mask_suffix = '_mask'
for img in all_images:
    old_name = os.path.basename(img)
    stem, ext = os.path.splitext(old_name)
    # Decide on the file name only: testing the full path would misclassify
    # every file whenever DATA_HOME itself contains the word "mask".
    if stem.endswith(mask_suffix):
        new_name = os.path.join(path_mask, stem[:-len(mask_suffix)] + ext)
    else:
        new_name = os.path.join(path_img, old_name)
    copyfile(img, new_name)



In [8]:
# This cell loads the data into X and Y as numpy arrays for training.
# Get the file paths, sorted so that image i and mask i share a file name.
imgs = sorted(os.path.join(path_img, name) for name in os.listdir(path_img))
masks = sorted(os.path.join(path_mask, name) for name in os.listdir(path_mask))

# Validate the pairing explicitly (a plain assert would be stripped under -O,
# and a length mismatch would otherwise go unnoticed or fail as IndexError).
if len(imgs) != len(masks):
    raise ValueError(
        f"Found {len(imgs)} images but {len(masks)} masks; "
        "every image needs exactly one mask"
    )
for img_file, mask_file in zip(imgs, masks):
    if os.path.basename(img_file) != os.path.basename(mask_file):
        raise ValueError(f"Image {img_file} is paired with mask {mask_file}")

# Read files into numpy arrays, scaled from [0, 255] to [0, 1].
X = np.empty((len(imgs), 3, 256, 256), dtype='float32')
Y = np.empty((len(masks), 1, 256, 256), dtype='float32')

for i, img_file in enumerate(imgs):
    # The context manager closes the file handle once the pixels are copied.
    with Image.open(img_file) as picture:
        # Move the channel axis from last to first position.
        X[i, :, :, :] = np.moveaxis(np.asarray(picture, dtype='float32'), -1, 0) / 255
for i, mask_file in enumerate(masks):
    with Image.open(mask_file) as picture:
        # Add a leading channel axis of size 1 to the mask.
        Y[i, :, :, :] = np.asarray(picture).reshape(1, 256, 256) / 255

print("Shape of imgs: ", X.shape)
print("Shape of masks: ", Y.shape)

Shape of imgs:  (3929, 3, 256, 256)
Shape of masks:  (3929, 1, 256, 256)


In [9]:
# Pair each image array with its mask array, then serialize the resulting
# list of (image, mask) tuples to the DATA_SET file.
data_set = [(image, mask) for image, mask in zip(X, Y)]
with open(DATA_SET, mode='wb') as handle:
    pickle.dump(data_set, handle)


In [10]:
# Sanity check: shape of the image array in the first (image, mask) pair.
data_set[0][0].shape

(3, 256, 256)

In [11]:
# Sanity check: container type of the object that was pickled.
type(data_set)

list