In [13]:
import h5py
import numpy as np
import matplotlib.pyplot as plt
import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import train_test_split
import pathlib
from random import sample
import torch
import pandas as pd

In [17]:
# Location of the raw HDF5 dataset that the loading cell opens.
DATASET_PATH = './data/brainMRI.h5'

# Accumulators for the preprocessed samples; the preprocessing loop appends
# one image and one mask per dataset entry.
source_image_tensors, source_mask_tensors = [], []

In [15]:
# Preprocessing pipeline: a single, always-applied (p=1.0) resize of each
# image/mask pair to 256x256.
transform = A.Compose(
    [A.Resize(height=256, width=256, p=1.0)]
)

In [16]:
# Load the whole raw dataset into memory. The file is opened explicitly
# read-only so the behavior does not depend on the h5py version's default mode
# and the source file can never be created or modified by this cell.
with h5py.File(DATASET_PATH, 'r') as dataset:
    images = dataset['images'][:]  # `[:]` materializes the dataset as a NumPy array
    masks = dataset['masks'][:]

In [18]:
# Resize every image/mask pair and collect the results.
#
# The accumulators are (re)created here so the cell is idempotent: re-running it
# no longer appends a second copy of every sample to lists left over from an
# earlier execution.
source_image_tensors = []
source_mask_tensors = []

for image, mask in zip(images, masks):
    image = image.astype(np.float32)
    # squeeze() drops singleton dimensions from the stored mask.
    mask = mask.squeeze().astype(np.float32)

    # Move axis 0 to the end before the transform and back to the front after it.
    # assumes images are stored channel-first (C, H, W) and the transform expects
    # channel-last (H, W, C) — TODO confirm
    image = image.transpose(1, 2, 0)
    augmented = transform(image=image, mask=mask)
    image = augmented['image']
    mask = augmented['mask']

    source_image_tensors.append(image.transpose(2, 0, 1))
    source_mask_tensors.append(mask)

In [19]:
# Sanity check: number of preprocessed images accumulated by the loop above.
print(len(source_image_tensors))

3676


In [20]:
# Sanity check: number of preprocessed masks; should equal the image count.
print(len(source_mask_tensors))

3676


In [21]:
# Hold out 20% of all samples as the test set; the fixed random_state keeps
# the split reproducible.
(
    source_image_tensors_train,
    source_image_tensors_test,
    source_mask_tensors_train,
    source_mask_tensors_test,
) = train_test_split(
    source_image_tensors,
    source_mask_tensors,
    test_size=0.2,
    random_state=42,
)

In [22]:
# Carve a validation set (20% of the remaining training samples) out of the
# training split; the training names are rebound to the reduced split.
(
    source_image_tensors_train,
    source_image_tensors_val,
    source_mask_tensors_train,
    source_mask_tensors_val,
) = train_test_split(
    source_image_tensors_train,
    source_mask_tensors_train,
    test_size=0.2,
    random_state=42,
)

In [23]:
# Training-split size after the validation hold-out.
len(source_image_tensors_train)

2352

In [24]:
# Validation-split size.
len(source_image_tensors_val)

588

In [25]:
# Test-split size.
len(source_image_tensors_test)

736

In [26]:
# Collect the training samples whose mask contains at least one non-zero pixel.
#
# The test is np.any(mask) rather than argmax(mask) > 0: argmax returns the flat
# *position* of the maximum, so a mask whose first maximum sits at flat index 0
# would be misclassified as empty. This also avoids a torch copy of every mask.
# assumes background pixels are encoded as 0 — TODO confirm
has_stroke_masks = []
has_stroke_images = []

for image, mask in zip(source_image_tensors_train, source_mask_tensors_train):
    if np.any(mask):
        has_stroke_images.append(image)
        has_stroke_masks.append(mask)



In [27]:
# Number of training samples flagged as containing a stroke; the mask and
# image counts should agree because they are appended in pairs.
print(len(has_stroke_masks))
print(len(has_stroke_images))

481
481


In [28]:
# Class imbalance: training samples without a stroke per sample with one.
# Divides by the stroke count, so this raises ZeroDivisionError if none were found.
(len(source_image_tensors_train) - len(has_stroke_masks)) / len(has_stroke_masks)

3.8898128898128896

In [29]:
# Oversample the stroke-positive class: append three extra copies of every
# stroke image/mask pair to the training split, in place.
# NOTE: not idempotent — each re-run of this cell appends three more copies.
source_image_tensors_train.extend(has_stroke_images * 3)
source_mask_tensors_train.extend(has_stroke_masks * 3)

In [30]:
# Re-scan the (now extended) training split for samples whose mask contains at
# least one non-zero pixel.
#
# np.any(mask) is used instead of np.argmax(mask) > 0: argmax yields the flat
# position of the maximum, so a mask whose first maximum is at flat index 0
# would be wrongly treated as empty.
# assumes background pixels are encoded as 0 — TODO confirm
has_stroke_masks = []
has_stroke_images = []

for image, mask in zip(source_image_tensors_train, source_mask_tensors_train):
    if np.any(mask):
        has_stroke_images.append(image)
        has_stroke_masks.append(mask)


In [31]:
# Count of stroke-positive masks found by the re-scan above.
len(has_stroke_masks)

1924

In [32]:
# Count of stroke-positive images found by the re-scan above.
len(has_stroke_images)

1924

In [33]:
# Total training images, including the duplicates appended by oversampling.
len(source_image_tensors_train)

3795

In [34]:
# Total training masks, including the duplicates appended by oversampling.
len(source_mask_tensors_train)

3795

In [35]:
# Invariant: training images and masks must stay paired one-to-one.
assert len(source_image_tensors_train) == len(source_mask_tensors_train)

In [37]:
# Invariant: the stroke-positive image and mask lists must have equal length.
assert len(has_stroke_images) == len(has_stroke_masks)

In [38]:
# Fraction of the training split that has no stroke, measured after oversampling.
(len(source_mask_tensors_train) - len(has_stroke_images)) / len(source_mask_tensors_train)

0.4930171277997365

In [39]:
# Persist the training split (including the oversampled duplicates) to HDF5;
# mode 'w' creates the file, truncating any existing one.
# NOTE(review): unlike the test/val files, this file name has no trailing underscore.
with h5py.File('./data/train_dataset.h5', 'w') as dataset:
    for key, samples in (
        ('images', source_image_tensors_train),
        ('masks', source_mask_tensors_train),
    ):
        dataset.create_dataset(key, data=samples)

In [40]:
# Persist the test split to HDF5; mode 'w' creates the file, truncating any
# existing one.
with h5py.File('./data/test_dataset_.h5', 'w') as dataset:
    for key, samples in (
        ('images', source_image_tensors_test),
        ('masks', source_mask_tensors_test),
    ):
        dataset.create_dataset(key, data=samples)

In [41]:
# Persist the validation split to HDF5; mode 'w' creates the file, truncating
# any existing one.
with h5py.File('./data/val_dataset_.h5', 'w') as dataset:
    for key, samples in (
        ('images', source_image_tensors_val),
        ('masks', source_mask_tensors_val),
    ):
        dataset.create_dataset(key, data=samples)

In [42]:
# Read the saved training split back from disk and collect the samples whose
# mask contains at least one non-zero pixel.
# Note: this rebinds `images` and `masks`, replacing the raw arrays loaded earlier.
with h5py.File('./data/train_dataset.h5', 'r') as dataset:
    images = dataset['images'][:]
    masks = dataset['masks'][:]

# np.any(mask) is used instead of argmax(mask) > 0: argmax returns the flat
# position of the maximum, so a mask whose first maximum is at flat index 0
# would be wrongly treated as empty. It also avoids a torch copy of every mask.
# assumes background pixels are encoded as 0 — TODO confirm
has_stroke_masks = []
has_stroke_images = []

for image, mask in zip(images, masks):
    if np.any(mask):
        has_stroke_images.append(image)
        has_stroke_masks.append(mask)
        
