In [1]:
import glob
import os
import re
import numpy as np

import torch
from torch.utils.data import Dataset
from tqdm import tqdm

import config
from utils import load_image

In [4]:
import pandas as pd

In [7]:
dlt = []
empty_fld = [109, 123, 709]
df = pd.read_csv("../input/train_labels.csv")
X = df['BraTS21ID'].values
Y = df['MGMT_value'].values

for i in empty_fld:
    j = np.where(X == i)
    dlt.append(j)
    X = np.delete(X, j)
    
Y = np.delete(Y,dlt)
X

array([   0,    2,    3,    5,    6,    8,    9,   11,   12,   14,   17,
         18,   19,   20,   21,   22,   24,   25,   26,   28,   30,   31,
         32,   33,   35,   36,   43,   44,   45,   46,   48,   49,   52,
         53,   54,   56,   58,   59,   60,   61,   62,   63,   64,   66,
         68,   70,   71,   72,   74,   77,   78,   81,   84,   85,   87,
         88,   89,   90,   94,   95,   96,   97,   98,   99,  100,  102,
        104,  105,  106,  107,  108,  110,  111,  112,  113,  116,  117,
        120,  121,  122,  124,  128,  130,  132,  133,  134,  136,  137,
        138,  139,  140,  142,  143,  144,  146,  147,  148,  149,  150,
        151,  154,  155,  156,  157,  158,  159,  160,  162,  165,  166,
        167,  169,  170,  171,  172,  176,  177,  178,  183,  184,  185,
        186,  187,  188,  191,  192,  193,  194,  195,  196,  197,  199,
        201,  203,  204,  206,  209,  210,  211,  212,  214,  216,  217,
        218,  219,  220,  221,  222,  227,  228,  2

In [64]:
path = f"../input/reduced_dataset/00000/FLAIR/*"
files = sorted(
    glob.glob(path),
    key=lambda var: [
        int(x) if x.isdigit() else x for x in re.findall(r"[^0-9]|[0-9]+", var)
    ],
)

In [65]:
len(files)

202

In [66]:
image_numbers = []
for image_path in files:

    match = re.search(r"Image-(\d+).png", image_path)
    if match:
        image_numbers.append(int(match.group(1)))

# Sort the image numbers and find the median.
image_numbers.sort()
start = image_numbers[0]
middle = image_numbers[len(image_numbers) // 2]

In [67]:
middle

190

In [68]:
# # middle = len(files) // 2
num_imgs2 = 64 // 2
p1 = max(start, middle - num_imgs2)
p2 = min(len(files), middle + num_imgs2)
image_stack = [load_image(f) for f in files[p1:p2]]

In [69]:
p1

158

In [70]:
p2

202

In [71]:
image_stack

[array([[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]], dtype=float32),
 array([[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]], dtype=float32),
 array([[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]], dtype=float32),
 array([[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],


In [72]:
img3d = np.stack(image_stack).T
if img3d.shape[-1] < 32:
    n_zero = np.zeros((112, 112, 32 - img3d.shape[-1]))
    img3d = np.concatenate((img3d, n_zero), axis=-1)

In [73]:
np = np.expand_dims(img3d, 0)

In [75]:
np.shape

(1, 112, 112, 44)

In [8]:
class BrainRSNADataset(Dataset):
    def __init__(
        self, patient_path, paths, targets, transform=None, mri_type="FLAIR", is_train=True, ds_type="forgot", do_load=True
    ):
        
        self.patient_path = patient_path
        self.paths = paths   
        self.targets = targets
        self.type = mri_type

        self.transform = transform
        self.is_train = is_train
        self.folder = "train" if self.is_train else "test"
        self.do_load = do_load
        self.ds_type = ds_type  

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        _id = self.paths[index]
        target = self.targets[index]
        _3d_images = self.load_images_3d(_id)
        _3d_images = torch.tensor(_3d_images).float()
        if self.is_train:
            return {"image": _3d_images, "target": target}
        else:
            return {"image": _3d_images, "target": target}
        

    '''
    def _prepare_biggest_images(self):
        big_image_indexes = {}
        if (f"big_image_indexes_{self.ds_type}.pkl" in os.listdir("../input/"))\
            and (self.do_load) :
            print("Loading the best images indexes for all the cases...")
            big_image_indexes = joblib.load(f"../input/big_image_indexes_{self.ds_type}.pkl")
            return big_image_indexes
        else:
            
            print("Caulculating the best scans for every case...")
            for row in tqdm(self.data.iterrows(), total=len(self.data)):
                case_id = str(int(row[1].BraTS21ID)).zfill(5)
                path = f"../input/{self.folder}/{case_id}/{self.type}/*.dcm"
                files = sorted(
                    glob.glob(path),
                    key=lambda var: [
                        int(x) if x.isdigit() else x for x in re.findall(r"[^0-9]|[0-9]+", var)
                    ],
                )
                resolutions = [utils.extract_cropped_image_size(f) for f in files]
                middle = np.array(resolutions).argmax()
                big_image_indexes[case_id] = middle

            joblib.dump(big_image_indexes, f"../input/big_image_indexes_{self.ds_type}.pkl")
            return big_image_indexes'''



    def get_middle(self, files):
        image_numbers = []
        for image_path in files:
            match = re.search(r"Image-(\d+).png", image_path)
            if match:
                image_numbers.append(int(match.group(1)))

        # Sort the image numbers and find the median.
        image_numbers.sort()
        return image_numbers[len(image_numbers) // 2]


    
    def load_images_3d(
        self,
        case_id,
        num_imgs=config.NUM_IMAGES_3D,
        img_size=config.IMAGE_SIZE,
        rotate=0,
    ):
        case_id = str(case_id).zfill(5)

        path = f"../input/reduced_dataset/{case_id}/{self.type}/*.png"
        files = sorted(
            glob.glob(path),
            key=lambda var: [
                int(x) if x.isdigit() else x for x in re.findall(r"[^0-9]|[0-9]+", var)
            ],
        )

        middle = len(files) // 2
        if len(files) <= 64:
            image_stack = [load_image(f) for f in files]
        else:
            p1 = middle - 32 #max(0, middle - num_imgs2)
            p2 = middle + 32 #min(len(files), middle + num_imgs2)
            image_stack = [load_image(f) for f in files[p1:p2]]
            
            
            
        '''num_imgs2 = num_imgs // 2
        p1 = max(0, middle - num_imgs2)
        p2 = min(len(files), middle + num_imgs2)
        image_stack = [load_dicom_image(f, rotate=rotate) for f in files[p1:p2]]'''
        
        img3d = np.stack(image_stack).T
        if img3d.shape[-1] < num_imgs:
            n_zero = np.zeros((img_size, img_size, num_imgs - img3d.shape[-1]))
            img3d = np.concatenate((img3d, n_zero), axis=-1)

        if np.min(img3d) < np.max(img3d):
            img3d = img3d - np.min(img3d)
            img3d = img3d / np.max(img3d)

        return np.expand_dims(img3d, 0)

In [11]:
dataset = BrainRSNADataset(
                            patient_path='../input/reduced_dataset/',
                            paths=X, 
                            targets= Y,
                            mri_type='FLAIR',
                            ds_type=f"train_FLAIR"
                            )


In [12]:
batch = dataset[0]
batch['image'].shape

torch.Size([1, 112, 112, 64])