In [1]:
from pathlib import Path
import shutil

import h5py

import numpy as np
import matplotlib.pyplot as plt
from skimage.draw import polygon, polygon2mask

from sklearn.model_selection import train_test_split

In [None]:
from kedro.extras.datasets.pickle import PickleDataSet
from kedro.config import ConfigLoader

import os, sys
sys.path.append(os.path.abspath('..'))
sys.path.append(os.path.abspath('../src/'))

from src.tagseg.data.acdc_dataset import AcdcDataSet
from src.tagseg.data.dmd_dataset import DmdDataSet, DmdTimeDataSet, DmdH5DataSet
from src.tagseg.pipelines.data_splitting.nodes import split_data

In [3]:
# Directory holding every raw DMD subject file (the same `all/` folder that is
# split into train/test further down). Kept as a plain str, presumably what
# DmdH5DataSet._load_except expects — TODO confirm.
filepath_raw = '../data/01_raw/dmd_alex/all/'

In [4]:
# Build the H5-backed DMD dataset object, then load the samples straight from
# the raw directory through its private `_load_except` helper (defined in
# src.tagseg.data.dmd_dataset; its exact semantics are not visible here).
dmd_dataset = DmdH5DataSet(filepath='../data/03_primary/dmd_alex_wtv.pt')
ds = dmd_dataset._load_except(filepath_raw)

In [5]:
# Number of samples in `ds` (the recorded 1250 matches the 50 subjects x 25
# images reported in the split section below).
len(ds)

1250

In [None]:
# Grid of every 2D sample: the image in grey with its label overlaid, each
# panel zoomed to a square window centred on the label.
N = 15  # panels per row
HALF_WIDTH = 40  # half-size (pixels) of the zoom window around the label

n_rows = -(-len(ds) // N)  # ceil division: no spurious empty row
# squeeze=False keeps `ax` 2-D even when everything fits in one row.
fig, ax = plt.subplots(n_rows, N, figsize=(40, 200), squeeze=False)

for i in range(len(ds)):

    m, n = i // N, i % N

    im, la = ds[i]

    ax[m, n].imshow(im[0], cmap='gray')
    ax[m, n].imshow(np.ma.masked_where(la[0] == 0, la[0]), alpha=0.3)
    ax[m, n].axis('off')

    # Centroid (row, col) of the label pixels. Taking the last two index arrays
    # assumes `la` is 3-D, e.g. (C, H, W) — TODO confirm.
    rows, cols = np.where(la == 1)[1:]
    if rows.size == 0:
        # No label pixels: keep the full view rather than set NaN limits,
        # which matplotlib rejects with a ValueError.
        continue
    y, x = rows.mean(), cols.mean()

    ax[m, n].set_xlim(x - HALF_WIDTH, x + HALF_WIDTH)
    # imshow draws row 0 at the top (inverted y-axis); pass (bottom, top) in
    # that same inverted order, otherwise the crop is rendered upside down.
    ax[m, n].set_ylim(y + HALF_WIDTH, y - HALF_WIDTH)

# Hide the unused panels at the end of the last row.
for unused in ax.flat[len(ds):]:
    unused.axis('off')

Split the subjects in `all/` into `train/` and `test/` folders (80/20 by subject, so all 2D images of a subject end up in the same split).

In [4]:
# Folder with every raw subject (image + ROI files) to be split below.
filepath_all = Path('..') / 'data' / '01_raw' / 'dmd_alex' / 'all'

In [5]:
def get_img_path(roi_path: Path) -> Path:
    """Return the image file that belongs to an ROI file.

    The image lives in the same directory. Its name is the ROI stem with the
    last ``_``-separated token dropped, plus an ``.h5`` suffix, e.g.
    ``<dir>/subj01_roi.h5`` -> ``<dir>/subj01.h5``.

    Args:
        roi_path: Path to the ROI file. A ``str`` is also accepted and is
            converted to ``Path``.

    Returns:
        Path of the matching ``.h5`` image. Existence is not checked.
    """
    # The old annotation said `str`, but `.parent`/`.stem` need a Path.
    roi_path = Path(roi_path)
    return roi_path.parent / ('_'.join(roi_path.stem.split('_')[:-1]) + '.h5')

# Pair every ROI file (stem ending in `_roi`) with its image file. Sorting makes
# the subject order independent of filesystem iteration order, so the
# downstream train_test_split sees the same input on every machine.
subjects = sorted(
    (get_img_path(roi_path), roi_path)
    for roi_path in filepath_all.iterdir()
    if roi_path.stem.split('_')[-1] == 'roi'
)

# Fail early if an ROI file has no matching image. Otherwise the copy step
# would only fail after it has already wiped the split folder.
missing_images = [im_p for im_p, _ in subjects if not im_p.exists()]
assert not missing_images, f'ROI files without a matching image: {missing_images}'

# Assumed number of 2D images per subject (not checked here; consistent with
# len(ds) == 1250 for 50 subjects in the cells above).
SLICES_PER_SUBJECT = 25
print(f'Total of {len(subjects)} equating to {len(subjects) * SLICES_PER_SUBJECT} 2D images.')

Total of 50 equating to 1250 2D images.


In [6]:
# Subject-level 80/20 split. A fixed seed makes re-running this cell reproduce
# the same train/test membership; the next cell deletes and rewrites the split
# folders, so an unseeded split would silently reshuffle subjects.
SEED = 42
splits_paths = train_test_split(subjects, test_size=0.2, random_state=SEED)

In [8]:
# Recreate data/01_raw/dmd_alex/{train,test} from scratch and copy every
# subject's image file followed by its ROI file into its split folder.
split_root = Path('../data/01_raw/dmd_alex/')
for split, split_path in zip(['train', 'test'], splits_paths):
    save_path = split_root / split

    # Remove any previous split so subjects from an older run do not linger.
    if save_path.exists():
        shutil.rmtree(save_path)
    save_path.mkdir()

    # Flatten (image, roi) pairs, preserving the image-then-roi copy order.
    for src_file in (p for pair in split_path for p in pair):
        shutil.copy(src_file, save_path / src_file.name)

In [None]:
# Model-input dataset (presumably produced by the kedro pipeline). kedro's
# PickleDataSet unpickles the file, so only point it at files you created.
model_input_store = PickleDataSet(filepath='../data/05_model_input/model_input.pt')
dataset = model_input_store.load()

In [None]:
# Number of samples in the model-input dataset.
len(dataset)

In [None]:
# First sample. Unpacking into two names assumes each item is an
# (image, label) pair — TODO confirm against the dataset class.
# Note: this rebinds `im`/`la` left over from the plotting loop above.
im, la = dataset[0]

In [None]:
# Shapes of the first image and label (assumes both are array/tensor-like
# objects exposing `.shape`).
im.shape, la.shape