# Exploratory Data Analysis (EDA)

## Imports

In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt

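## Data

Load case `image_001` of Task07 (Pancreas) at each processing stage: raw, preprocessed stage 0, padded preprocessed stage 0, and preprocessed stage 1.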

In [4]:
# Same Pancreas case (image_001) at each processing stage, for side-by-side comparison.
_pancreas_dir = '../dataset/Task07_Pancreas'
raw_data = np.load(f'{_pancreas_dir}/raw/imagesTr/image_001.npy')
preprocessed_s0_data = np.load(f'{_pancreas_dir}/preprocessed/stage_0/imagesTr/image_001.npy')
preprocessed_s0_pad_data = np.load(f'{_pancreas_dir}/preprocessed_pad/stage_0/imagesTr/image_001.npy')
preprocessed_s1_data = np.load(f'{_pancreas_dir}/preprocessed/stage_1/imagesTr/image_001.npy')

In [6]:
# Report the array shape of the same case at every processing stage.
for _label, _volume in (
    ('raw_data', raw_data),
    ('preprocessed_s0_data', preprocessed_s0_data),
    ('preprocessed_s0_pad_data', preprocessed_s0_pad_data),
    ('preprocessed_s1_data', preprocessed_s1_data),
):
    print(f'{_label} shape: ', _volume.shape)
del _label, _volume  # don't leave loop variables behind in the notebook namespace

raw_data shape:  (48, 512, 512)
crop_data shape:  (110, 512, 512)
preprocessed_s0_data shape:  (107, 255, 255)
preprocessed_s0_pad_data shape:  (107, 255, 386)
preprocessed_s1_data shape:  (110, 411, 411)


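## Visualisation

Distribution of image widths (last array axis) across the stage-0 preprocessed Pancreas training images, followed by an interactive slice viewer.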

In [8]:
# Width (last array axis) of every stage-0 preprocessed Pancreas training image.
# mmap_mode='r' memory-maps each file, so `.shape` comes from the .npy header
# without reading the whole volume into RAM just to inspect its size.
# (The directory variable is not called `dir`, to avoid shadowing the built-in.)
stage0_images_dir = '../dataset/Task07_Pancreas/preprocessed/stage_0/imagesTr'
shapes = [
    np.load(os.path.join(stage0_images_dir, fname), mmap_mode='r').shape[-1]
    for fname in os.listdir(stage0_images_dir)
    if fname.endswith('.npy')
]

fig, ax = plt.subplots()
ax.hist(shapes, bins=20)
ax.set_title('Histogram of image width')
ax.set_xlabel('Image width')
ax.set_ylabel('Frequency')
plt.show()

  plt.show()


In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt

from ipywidgets import interact
import ipywidgets as widgets

def plot_slice(dataset, category, stage, type, image, slice):
    """Display one slice (along the first array axis) of a saved ``.npy`` volume.

    The parameter names are kept as-is even though ``type`` and ``slice``
    shadow built-ins, because ``interact`` binds widgets to parameters by name.

    Args:
        dataset: Task suffix, e.g. ``'02_Heart'`` -> ``../dataset/Task02_Heart``.
        category: ``'raw'``, ``'preprocessed'`` or ``'preprocessed_pad'``.
        stage: Stage sub-folder (e.g. ``'stage_0'``). It is only used for
            ``'preprocessed'``; ``'preprocessed_pad'`` always reads
            ``stage_0`` and ``'raw'`` has no stage folder.
        type: ``'image'`` or ``'label'``. This selects the ``{type}sTr``
            folder and the file prefix.
        image: Case number, zero-padded to 3 digits in the file name.
        slice: 1-based slice index, clamped into ``[1, depth]``.

    If the file does not exist, ``"image not found"`` is printed and an
    all-zero placeholder volume of shape (130, 320, 320) is shown instead.
    """
    image_number = str(image).zfill(3)
    category_dir = f'../dataset/Task{dataset}/{category}'
    if category == 'preprocessed':
        path = f'{category_dir}/{stage}/{type}sTr/{type}_{image_number}.npy'
    elif category == 'preprocessed_pad':
        # Only the stage_0 padded variant is read, whatever `stage` says.
        path = f'{category_dir}/stage_0/{type}sTr/{type}_{image_number}.npy'
    else:
        path = f'{category_dir}/{type}sTr/{type}_{image_number}.npy'

    if os.path.exists(path):
        im = np.load(path)
        print("image shape:", im.shape)
    else:
        print("image not found")
        im = np.zeros((130, 320, 320))

    # Convert the 1-based widget value to a 0-based index and clamp it on BOTH
    # sides. Without the lower bound, slice <= 0 would become a negative index
    # and silently wrap around to the last slice(s).
    slice_idx = max(0, min(slice - 1, im.shape[0] - 1))
    plt.imshow(im[slice_idx], cmap='gray')
    plt.title(f'{type}_{image_number}  slice {slice_idx + 1}/{im.shape[0]}')
    plt.show()

# Interactive slice viewer. Each IntSlider's initial `value` is set explicitly
# inside [min, max], rather than relying on ipywidgets to clamp an
# out-of-range value. The trailing `;` suppresses the echoed function repr.
interact(plot_slice,
         dataset=widgets.Dropdown(options=['02_Heart', '07_Pancreas', '09_Spleen'], value='02_Heart'),
         category=widgets.Dropdown(options=['raw', 'preprocessed', 'preprocessed_pad'], value='raw'),
         stage=widgets.Dropdown(options=['stage_0', 'stage_1'], value='stage_0'),
         type=widgets.Dropdown(options=['image', 'label'], value='image'),
         image=widgets.IntSlider(min=3, max=30, step=1, value=3),
         slice=widgets.IntSlider(min=1, max=130, step=1, value=1));

interactive(children=(Dropdown(description='dataset', options=('02_Heart', '07_Pancreas', '09_Spleen'), value=…

<function __main__.plot_slice(dataset, category, stage, type, image, slice)>

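## DataLoader Test

Load the preprocessed stage-0 Heart data through `load_data` and inspect the shapes of one training batch.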

In [13]:
import sys

# Make the local nnUNet package importable. The guard keeps re-runs of this
# cell from appending duplicate entries to sys.path.
if 'nnUNet' not in sys.path:
    sys.path.append('nnUNet')

from data_loader import load_data

# Other tasks: '../dataset/Task07_Pancreas', '../dataset/Task09_Spleen'
task_folder_path = '../dataset/Task02_Heart'
train_dataloader, val_dataloader = load_data(task_folder_path,
                                             dataset_type="preprocessed",
                                             stage="stage_0",
                                             val_size=0.2,
                                             batch_size=5,
                                             shuffle=True,
                                             resize=(120, 320, 320))

# Take a single training batch to check tensor shapes. The next cell plots it.
# If the loader is empty, this raises StopIteration here, rather than leaving
# `images`/`labels` undefined and failing later.
images, labels = next(iter(train_dataloader))
print("image shape:", images.shape)
print("label shape:", labels.shape)

100%|██████████| 20/20 [00:00<00:00, 56.98it/s]
100%|██████████| 20/20 [00:00<00:00, 49.19it/s]


image shape: torch.Size([5, 1, 120, 320, 320])
label shape: torch.Size([5, 1, 120, 320, 320])


In [14]:
# Show one slice of every volume in the batch (channel 0), next to its label.
# The row count follows the actual batch size, and the slice index is clamped
# to the volume depth, so a short batch or shallower resize does not raise IndexError.
SLICE_IDX = 80  # chosen for the (120, 320, 320) resize used above
n_samples = images.shape[0]
slice_idx = min(SLICE_IDX, images.shape[2] - 1)
# squeeze=False keeps `ax` 2-D even when the batch has a single sample.
fig, ax = plt.subplots(n_samples, 2, figsize=(10, 4 * n_samples), squeeze=False)
for i in range(n_samples):
    ax[i, 0].imshow(images[i][0][slice_idx], cmap='gray')
    ax[i, 0].set_title(f'image {i}, slice {slice_idx}')
    ax[i, 1].imshow(labels[i][0][slice_idx], cmap='gray')
    ax[i, 1].set_title(f'label {i}, slice {slice_idx}')
# save the plot as a png file
fig.savefig('images.png')