In [2]:
import os
import sys
import random
import numpy as np
from tqdm import tqdm
import nibabel as nib
import raster_geometry as rg
import cv2
import matplotlib.pyplot as plt

sys.path.append('/workspace/MRI-inpainting-project')

from data_scripts.visualization_utils import ImageSliceViewer3D
from data_scripts.datasets import TrainPatchesDataset, HealthyMRIDataset, PathologicalMRIDataset

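## 3D Patches Preparation

For every healthy MRI, crop a cube of `patch_size` voxels per side centred on its synthetic sphere mask. Save three NIfTI files per scan: the MRI cube, a full-volume mask marking where the cube lies, and the sphere mask cropped to the cube.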

In [7]:
# Input locations for the healthy scans and their synthetic sphere masks.
# NOTE(review): these are relative to the notebook's working directory — confirm it is
# launched from the expected folder.
HEALTHY_MRI_DIR = "../../data/healthy_mri/healthy_mri"
HEALTHY_MASK_DIR = "../../data/healthy_mri/healthy_masks/simple-sphere-masks"

healthy_dataset = HealthyMRIDataset(HEALTHY_MRI_DIR, HEALTHY_MASK_DIR)

In [27]:
def _patch_bounds(lo, hi, patch_size, dim_len):
    """Return ``(start, stop)`` of a window of exactly ``patch_size`` voxels around ``[lo, hi]``.

    Args:
        lo, hi: inclusive min / max index of the mask along one axis.
        patch_size: requested window length in voxels.
        dim_len: length of the volume along this axis.

    The window is centred on the mask; if the padding is odd, the extra voxel goes on
    the high side. If the window would cross the volume border, it is shifted inward
    so it always lies inside ``[0, dim_len)``.

    Raises:
        ValueError: if the mask extent exceeds ``patch_size`` or the axis is shorter
            than ``patch_size``.
    """
    extent = hi - lo + 1  # voxel count, inclusive of both ends
    if extent > patch_size:
        raise ValueError(f"mask extent {extent} exceeds patch_size {patch_size}")
    if dim_len < patch_size:
        raise ValueError(f"volume axis of length {dim_len} is smaller than patch_size {patch_size}")
    start = lo - (patch_size - extent) // 2
    # Clamp: a negative start would wrap around in Python slicing, and an overrun stop
    # would be silently truncated; both previously surfaced only as a reshape error.
    start = min(max(start, 0), dim_len - patch_size)
    return start, start + patch_size


def extract_patch(mri, mask, patch_size):
    """Crop a ``patch_size``-sided cube around the voxels where ``mask == 1``.

    Args:
        mri: image volume; must have the same shape as ``mask``.
        mask: pathology mask; voxels equal to 1 define the region to centre on.
        patch_size: edge length of the cube in voxels.

    Returns:
        ``(patch, patch_mask, patch_pathology_mask)``:
        - ``patch``: the MRI cube (a copy).
        - ``patch_mask``: array like ``mask`` (same shape and dtype) that is 1 inside the cube.
        - ``patch_pathology_mask``: ``mask`` cropped to the cube (a copy).

    Raises:
        ValueError: on shape mismatch, an empty mask, or a mask that cannot fit.
    """
    if mri.shape != mask.shape:
        raise ValueError(f"mri shape {mri.shape} != mask shape {mask.shape}")
    coords = np.where(mask == 1)
    if coords[0].size == 0:
        raise ValueError("mask contains no voxels equal to 1")
    window = tuple(
        slice(*_patch_bounds(axis_idx.min(), axis_idx.max(), patch_size, dim_len))
        for axis_idx, dim_len in zip(coords, mask.shape)
    )
    patch_mask = np.zeros_like(mask)
    patch_mask[window] = 1
    # Copies keep the stored patches independent of the source volumes.
    return mri[window].copy(), patch_mask, mask[window].copy()


patches = []

patch_size = 40  # cube edge length in voxels
for sample in tqdm(healthy_dataset):
    mri, mask = sample['mri'], sample['mask']
    patch, patch_mask, patch_pathology_mask = extract_patch(mri, mask, patch_size)

    assert patch.shape == patch_pathology_mask.shape == (patch_size,) * patch.ndim

    patches.append((patch, patch_mask, patch_pathology_mask, sample['filename']))

100%|██████████| 60/60 [00:17<00:00,  3.41it/s]


In [32]:
# Write each extracted patch, its full-volume location mask and its cropped pathology
# mask as NIfTI files named after the source scan.
# NOTE(review): inputs are read from "../../data/..." but outputs go to "../data/...";
# confirm the two relative roots are intended.
# NOTE(review): files are saved with an identity affine, not the source scan's affine;
# confirm downstream code does not rely on real-world orientation/spacing.
PATCHES_OUT_DIR = "../data/raw/patches/simple-sphere-masks"
affine = np.eye(4)

out_dirs = {
    kind: os.path.join(PATCHES_OUT_DIR, kind)
    for kind in ("patches", "patch_masks", "patch_pathology_masks")
}
for out_dir in out_dirs.values():
    # nib.save does not create missing parent directories.
    os.makedirs(out_dir, exist_ok=True)

for patch, patch_mask, patch_pathology_mask, filename in patches:
    nib.save(nib.Nifti1Image(patch, affine=affine),
             os.path.join(out_dirs["patches"], filename))
    nib.save(nib.Nifti1Image(patch_mask, affine=affine),
             os.path.join(out_dirs["patch_masks"], filename))
    nib.save(nib.Nifti1Image(patch_pathology_mask, affine=affine),
             os.path.join(out_dirs["patch_pathology_masks"], filename))

## 