### Import packages

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets

import os
from random import shuffle
from PIL import Image
import h5py

In [2]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SequentialSampler

### Parameters

In [3]:
MEDICAL_DATA = '/net/archive/groups/plggneurony/swilczyn/medical/datasets/medical-4channels-cropped.h5'

BATCH_SIZE = 1024

### Dataset

In [4]:
class MedicalDataset(Dataset):
    def __init__(self, images, masks, patch_size=32):
        self.images = images
        self.masks = masks
        
        self.patch_size = patch_size
        
        self.images_count, self.image_size, _, _ = self.images.shape
        self.patches_per_side = self.image_size - self.patch_size + 1
        self.patches_per_image = self.patches_per_side ** 2
    
    def __len__(self):
        return self.images_count * self.patches_per_image
    
    def __getitem__(self, idx):
        image_idx = idx // self.patches_per_image
        patch_idx = idx % self.patches_per_image
        
        image = self.images[image_idx]
        mask = self.masks[image_idx]
        
        y = patch_idx // self.patches_per_side
        x = patch_idx % self.patches_per_side
        
        patch = image[y: y + self.patch_size, x: x + self.patch_size].transpose(2, 0, 1)
        label = mask.item((y + self.patch_size // 2, x + self.patch_size // 2, 0))
        
        return (patch, label)

### Patches preprocessing

In [5]:
def preprocess_patches(dest, patch_size):
    with h5py.File(MEDICAL_DATA, 'r') as f:
        images = f[dest]['images'][...]
        masks = f[dest]['masks'][...]
        
    dataset = MedicalDataset(images, masks, patch_size=patch_size)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=8, sampler=SequentialSampler(dataset))
    
    patches_labels = torch.zeros(len(dataset), dtype=torch.int64)

    progress = widgets.IntProgress(max=len(loader))
    display(progress)

    for i, data in enumerate(loader):
        _, labels = data

        patches_labels[BATCH_SIZE * i: BATCH_SIZE * i + len(labels)] = labels

        progress.value += 1
        progress.description = '{:.3f} %'.format(100. * progress.value / progress.max)
        
    targets = patches_labels.numpy()

    positive_indices = targets.nonzero()[0]
    negative_indices = np.where(targets == 0)[0]
        
    with h5py.File(MEDICAL_DATA, 'a') as f:
        positives_save_dest = '{}/patches/{}/positives'.format(dest, patch_size)
        negatives_save_dest = '{}/patches/{}/negatives'.format(dest, patch_size)
        
        if positives_save_dest in f:
            del f[positives_save_dest]
            
        if negatives_save_dest in f:
            del f[negatives_save_dest]
        
        f.create_dataset(positives_save_dest, data=positive_indices, compression="gzip")
        f.create_dataset(negatives_save_dest, data=negative_indices, compression="gzip")

In [6]:
for patch_size in [22]:
    for dest in ['train', 'test']:
        print('Processing {} set with patch size {}'.format(dest, patch_size))
        preprocess_patches(dest=dest, patch_size=patch_size)

Processing train set with patch size 22


IntProgress(value=0, max=52888)

Processing test set with patch size 22


IntProgress(value=0, max=21246)

### Prepare validation set

In [5]:
def create_validation_set(patch_size, test_slice=0.05):
    with h5py.File(MEDICAL_DATA, 'r') as f:
        test_positive_indices = f['test']['patches'][str(patch_size)]['positives'][...]
        test_negative_indices = f['test']['patches'][str(patch_size)]['negatives'][...]

        test_images = f['test']['images'][...]
        test_masks = f['test']['masks'][...]
        
    test_dataset = MedicalDataset(test_images, test_masks, patch_size=patch_size)
        
    samples = 2**int(np.log2(test_positive_indices.shape[0] * test_slice))

    val_positive_indices = np.random.choice(test_positive_indices, size=samples, replace=False)
    val_negative_indices = np.random.choice(test_negative_indices, size=samples, replace=False)

    print('Validation positive samples: {}'.format(val_positive_indices.shape[0]))
    print('Validation negative samples: {}'.format(val_negative_indices.shape[0]))
    
    val_indices = np.hstack([val_positive_indices, val_negative_indices])
    np.random.shuffle(val_indices)
    
    val_patches, val_labels = [], []

    for i in val_indices:
        patch, label = test_dataset[i]

        val_patches.append(patch.transpose(1, 2, 0))
        val_labels.append(label)

    val_patches = np.array(val_patches)
    val_labels = np.array(val_labels)

    print('Validation patches: {}, {}'.format(val_patches.shape, val_patches.dtype))
    print('Validation labels: {}, {}'.format(val_labels.shape, val_labels.dtype))
    print('Validation indices: {}, {}'.format(val_indices.shape, val_indices.dtype))
    
    with h5py.File(MEDICAL_DATA, 'a') as f:
        val_indices_save_dest = 'validation/patches/{}/indices'.format(patch_size)
        val_patches_save_dest = 'validation/patches/{}/patches'.format(patch_size)
        val_labels_save_dest = 'validation/patches/{}/labels'.format(patch_size)
        
        if val_indices_save_dest in f:
            del f[val_indices_save_dest]
        
        if val_patches_save_dest in f:
            del f[val_patches_save_dest]
            
        if val_labels_save_dest in f:
            del f[val_labels_save_dest]
        
        f.create_dataset(val_indices_save_dest, data=val_indices, compression="gzip")
        f.create_dataset(val_patches_save_dest, data=val_patches, compression="gzip")
        f.create_dataset(val_labels_save_dest, data=val_labels, compression="gzip")

In [6]:
for patch_size in [28]:
    print('Creating validation set for patch size {}'.format(patch_size))
    create_validation_set(patch_size, test_slice=0.05)

Creating validation set for patch size 28
Validation positive samples: 32768
Validation negative samples: 32768
Validation patches: (65536, 28, 28, 3), float64
Validation labels: (65536,), int64
Validation indices: (65536,), int64
