In [4]:
import os
import numpy as np
import matplotlib.pyplot as plt

from tensorflow import keras
from tensorflow.keras import layers
import tensorflow as tf

from dipy.io.image import load_nifti
from dipy.align.reslice import reslice
from scipy.ndimage import affine_transform

# Data-parallel strategy over every device TensorFlow can see. In the
# recorded output below only the CPU is visible, giving a single replica.
strategy = tf.distribute.MirroredStrategy()
# Number of replicas in sync -- despite the name this counts the CPU too
# when no GPU is available (see output: 1 device, empty GPU list).
num_gpus = strategy.num_replicas_in_sync
print(f'Number of devices: {num_gpus}')

# Physical GPUs, then all logical devices, for a quick sanity check.
gpus = tf.config.list_physical_devices('GPU')
print(gpus)
print(tf.config.list_logical_devices())


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
Number of devices: 1
[]
[LogicalDevice(name='/device:CPU:0', device_type='CPU')]


In [3]:
def get_dataset_list(dataset):
    """Return the image entries listed for ``dataset``.

    Args:
        dataset: ``'CC'`` reads ``./dataset_cc359.txt``; ``'NFBS'`` reads
            ``./dataset.txt``; ``'both'`` reads ``./dataset.txt`` followed by
            ``./dataset_cc359.txt``.

    Returns:
        List of entries, one per non-blank line of the list file(s), in file
        order, with the line terminator removed.

    Raises:
        ValueError: if ``dataset`` is not one of the names above (previously
            this surfaced as an ``UnboundLocalError``).
    """
    list_files = {
        'CC': ['./dataset_cc359.txt'],
        'NFBS': ['./dataset.txt'],
        'both': ['./dataset.txt', './dataset_cc359.txt'],
    }
    if dataset not in list_files:
        raise ValueError(
            f'Unknown dataset {dataset!r}; expected one of {sorted(list_files)}')

    dataset_list = []
    for file in list_files[dataset]:
        with open(file, 'r') as f:
            for line in f:
                # rstrip instead of line[:-1]: the final line may have no
                # trailing newline, and slicing would chop a real character.
                entry = line.rstrip('\n')
                # Blank lines would become empty paths that cannot be loaded.
                if entry.strip():
                    dataset_list.append(entry)
    print('Total Images in dataset: ', len(dataset_list))
    return dataset_list

def datasetHelperFunc(path, cc_root='/N/project/grg_data/data/skullstripping_datasets/CC359/'):
    """Load one volume and its mask, preprocess both, and return float32 tensors.

    Args:
        path: One entry from ``get_dataset_list`` (``str``, or ``bytes`` when
            called through ``tf.numpy_function``). Entries containing ``'CC'``
            are treated as CC359 file names: the volume is read from
            ``cc_root + 'Original/' + path`` and the mask from
            ``cc_root + 'STAPLE/' + path[:-7] + '_staple.nii.gz'``. Any other
            entry is a full volume path whose mask is
            ``path[:-7] + 'mask.nii.gz'`` (``[:-7]`` strips ``'.nii.gz'``).
        cc_root: Root directory of the CC359 dataset (with trailing slash).

    Returns:
        ``(volume, mask)`` as ``tf.float32`` tensors, each with a trailing
        channel axis of size 1. The volume is min-max scaled to [0, 1]
        (all zeros if it is constant).

    Note:
        ``transform_img`` is not defined in the visible cells; it is assumed
        to return ``(array, affine)`` like ``transform_img_brats`` -- confirm.
    """
    if isinstance(path, bytes):
        path = path.decode('utf-8')

    if 'CC' in path:
        vol_path = cc_root + 'Original/' + path
        mask_path = cc_root + 'STAPLE/' + path[:-7] + '_staple.nii.gz'
    else:
        vol_path = path
        mask_path = path[:-7] + 'mask.nii.gz'

    vol, affine, voxsize = load_nifti(vol_path, return_voxsize=True)
    mask, _ = load_nifti(mask_path)
    mask[mask < 1] = 0  # Values <1 in the mask are background
    vol = vol * mask

    mask, _ = transform_img(mask, affine, voxsize)
    # Handling negative pixels, occurred as a result of preprocessing
    mask = np.abs(mask)
    mask = np.expand_dims(mask, -1)

    transform_vol, _ = transform_img(vol, affine, voxsize)
    # Handling negative pixels, occurred as a result of preprocessing
    transform_vol = np.abs(transform_vol)
    vmin, vmax = np.min(transform_vol), np.max(transform_vol)
    if vmax > vmin:
        transform_vol = (transform_vol - vmin) / (vmax - vmin)
    else:
        # Constant volume: min-max scaling would be 0/0 and fill it with NaN.
        transform_vol = np.zeros_like(transform_vol)
    transform_vol = np.expand_dims(transform_vol, -1)
    return tf.convert_to_tensor(transform_vol, tf.float32), tf.convert_to_tensor(mask, tf.float32)


dataset_name = 'both'
dataset_list = get_dataset_list(dataset_name)

bs = 48
# Tag used in the checkpoint and reconstruction output paths, e.g. 'B48-both'.
suffix = f'B{bs}-{dataset_name}'

# Held-out test split: the trailing entries that do not fill a whole batch
# of `bs` (presumably the full batches were used for training -- confirm).
test_size = len(dataset_list) % bs
if test_size == 0:
    raise ValueError(
        f'{len(dataset_list)} images is a multiple of bs={bs}; '
        'there are no leftover images for the test split.')
# Slice from an explicit start index: dataset_list[-test_size:] would return
# the WHOLE list (training data included) when test_size == 0.
lis = dataset_list[len(dataset_list) - test_size:]
dataset = tf.data.Dataset.from_tensor_slices(lis)

# datasetHelperFunc does file I/O and NumPy work, so it runs via numpy_function.
dataset = dataset.map(
    lambda x: tf.numpy_function(func=datasetHelperFunc, inp=[x], Tout=[tf.float32, tf.float32]),
    num_parallel_calls=tf.data.experimental.AUTOTUNE)

# test_size < bs, so this single batch holds every test volume.
test_dataset = dataset.batch(bs).prefetch(tf.data.experimental.AUTOTUNE).take(1)



Total Images in dataset:  484


In [4]:
# Rebuild the VQ-VAE with the architecture used for training so the checkpoint
# variables line up. NOTE(review): VQVAE is not defined in any visible cell;
# it must be defined/imported before this cell runs.
vqvae_config = dict(
    in_channels=1,
    out_channels=1,
    num_channels=(32, 64, 128),
    num_res_channels=(32, 64, 128),
    num_res_layers=3,
    # Per-level tuples, presumably (stride, kernel, dilation, padding[, output_padding])
    # as in MONAI's VQVAE -- confirm against the VQVAE definition.
    downsample_parameters=((2, 4, 1, 'same'),) * 3,
    # NOTE(review): the last upsample tuple has 4 entries while the others
    # have 5 -- confirm that is intentional.
    upsample_parameters=((2, 4, 1, 'same', 0), (2, 4, 1, 'same', 0), (2, 4, 1, 'same')),
    num_embeddings=256,
    embedding_dim=32,
    num_gpus=float(num_gpus),
    kernel_resize=False,
)
model = VQVAE(**vqvae_config)

test_epoch = 64
checkpoint_path = os.path.join(
    './checkpoints-vqvae-monai-scaled-128', suffix, f'{test_epoch}.ckpt')
model.load_weights(checkpoint_path)

<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x2b34ebe2a740>

In [None]:
# Reconstruct the held-out batch, report the MSE and save inputs and
# reconstructions as .npy files for later inspection.
reconst_dir = './reconst_scaled_vqvae3d_monai'
os.makedirs(reconst_dir, exist_ok=True)  # np.save does not create missing directories
for x, _ in test_dataset.take(1):
    np.save(os.path.join(reconst_dir, f'original-{suffix}.npy'), x.numpy())
    reconst = model(x)
    loss = tf.reduce_mean((reconst - x) ** 2)
    print(f'Test Loss is {loss}')
    np.save(os.path.join(reconst_dir, f'reconst3d-{suffix}-epoch{test_epoch}.npy'), reconst.numpy())

In [None]:
# Peek at the shape of the first (and only) batch of the test split.
for x, _ in test_dataset.take(1):
    batch_shape = x.numpy().shape
    print(batch_shape)

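# Test BraTS

Load a single BraTS 2021 case, mask the T1 volume with its segmentation, reslice and pad it to a 128×128×128 grid, and inspect slices interactively.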

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from dipy.io.image import load_nifti
from dipy.align.reslice import reslice
from scipy.ndimage import affine_transform
from fury.actor import slicer

In [7]:
# Load one BraTS 2021 case: the T1 volume and its segmentation.
# TODO: hard-coded local Windows paths; make these configurable.
brats_t1_path = r'D:\DiPY\SyntheticMRI\BraTS2021_00495_t1.nii.gz'
brats_seg_path = r'D:\DiPY\SyntheticMRI\BraTS2021_00495_seg.nii.gz'

# Geometry (affine, voxel size) comes from the T1 image; the segmentation's
# header is no longer allowed to overwrite it.
vol, affine, voxsize = load_nifti(brats_t1_path, return_voxsize=True)
mask, _ = load_nifti(brats_seg_path)
vol = vol.astype(np.float32)
# Binarize: values <1 are background, everything else is the region of
# interest. If the seg file holds multi-valued labels (BraTS label maps
# typically do -- confirm), multiplying by the raw values would rescale the
# intensities instead of just zeroing the background.
mask = (mask >= 1).astype(np.float32)
vol = vol * mask  # zero out the background or non-region of interest areas.

In [8]:
def _center_pad_or_crop(array, final_shape):
    """Center-crop/zero-pad the leading ``len(final_shape)`` axes of ``array`` to exactly ``final_shape``."""
    # Crop axes larger than the target first: np.pad rejects negative widths.
    crop = []
    for size, target in zip(array.shape, final_shape):
        start = max(size - target, 0) // 2
        crop.append(slice(start, start + min(size, target)))
    cropped = array[tuple(crop)]

    # Pad; any odd leftover voxel goes on the "after" side so the result is
    # exactly `final_shape` (using diff // 2 on both sides lost one voxel).
    pad_width = []
    for size, target in zip(cropped.shape, final_shape):
        total = target - size
        pad_width.append((total // 2, total - total // 2))
    # Leave any trailing axes (e.g. a channel axis) untouched.
    pad_width += [(0, 0)] * (cropped.ndim - len(final_shape))
    return np.pad(cropped, pad_width, mode='constant', constant_values=0)


def transform_img_brats(image, affine, voxsize, final_shape=(128, 128, 128), new_voxsize=(2, 2, 2)):
    """Reslice ``image`` to ``new_voxsize`` voxels and center it in a ``final_shape`` grid.

    Steps:
      1. dipy ``reslice`` to ``new_voxsize``.
      2. fury ``slicer(...).resliced_array()`` on the resliced data and affine.
      3. Center zero-pad (or center-crop, if larger) each axis to ``final_shape``.

    Args:
        image: Array to transform.
        affine: Voxel-to-world affine of ``image``.
        voxsize: Voxel size of ``image``.
        final_shape: Target shape of the leading axes.
        new_voxsize: Voxel size to reslice to.

    Returns:
        ``(transformed_img, affine)``. NOTE(review): ``affine`` is the *input*
        affine, unchanged; it does not describe the resliced/padded grid.
        It is kept for backward compatibility (callers here discard it).
    """
    temp_image, affine_temp = reslice(image, affine, voxsize, new_voxsize)
    temp_image = slicer(temp_image, affine_temp).resliced_array()
    print(temp_image.shape)

    transformed_img = _center_pad_or_crop(temp_image, final_shape)
    print("Padded image shape:", transformed_img.shape)

    return transformed_img, affine

In [9]:
# Bring the BraTS mask and masked volume onto the padded grid, then apply the
# same post-processing as datasetHelperFunc (abs, min-max scale, channel axis).
# NOTE(review): this overwrites `mask`; re-running the cell without re-running
# the loading cell would transform the already-transformed mask.
mask, _ = transform_img_brats(mask, affine, voxsize)
transform_vol, _ = transform_img_brats(vol, affine, voxsize)

mask = np.abs(mask)  # Handling negative pixels, occurred as a result of preprocessing
mask = np.expand_dims(mask, -1)

transform_vol = np.abs(transform_vol)  # Handling negative pixels, occurred as a result of preprocessing

vmin, vmax = np.min(transform_vol), np.max(transform_vol)
if vmax > vmin:
    transform_vol = (transform_vol - vmin) / (vmax - vmin)
else:
    # Constant volume: min-max scaling would be 0/0 and fill it with NaN.
    transform_vol = np.zeros_like(transform_vol)
transform_vol = np.expand_dims(transform_vol, -1)
print(transform_vol.shape)

(120, 120, 78)
Padded image shape: (128, 128, 128)
(120, 120, 78)
Padded image shape: (128, 128, 128)
(128, 128, 128, 1)


In [10]:
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display

import os
# Interactive viewer for the masked BraTS volume *before* reslicing/padding.
brain_images = vol
print(brain_images.shape)


def plot_brain_slices(slice_index):
    """Show slice ``slice_index`` along the third axis of the global ``brain_images``."""
    fig, ax = plt.subplots(figsize=(5, 5))
    slice_2d = brain_images[:, :, slice_index]
    if slice_2d.ndim == 3:  # trailing channel axis, e.g. from np.expand_dims(..., -1)
        slice_2d = slice_2d[..., 0]
    ax.imshow(slice_2d, cmap='gray')
    # This shows input data, not a model output, so no epoch in the title.
    ax.set_title(f'Input volume, Slice {slice_index}')
    ax.axis('off')

    plt.show()


slice_slider = widgets.IntSlider(min=0, max=brain_images.shape[2]-1, step=1, value=brain_images.shape[2]//2, description='Slice Index')
widgets.interactive(plot_brain_slices, slice_index=slice_slider)

(240, 240, 155)


interactive(children=(IntSlider(value=77, description='Slice Index', max=154), Output()), _dom_classes=('widge…

In [11]:
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display

import os
# Interactive viewer for the resliced, padded and min-max scaled BraTS volume.
brain_images = transform_vol
print(brain_images.shape)


def plot_brain_slices(slice_index):
    """Show slice ``slice_index`` along the third axis of the global ``brain_images``."""
    fig, ax = plt.subplots(figsize=(5, 5))
    slice_2d = brain_images[:, :, slice_index]
    if slice_2d.ndim == 3:  # drop the trailing channel axis added by np.expand_dims(..., -1)
        slice_2d = slice_2d[..., 0]
    ax.imshow(slice_2d, cmap='gray')
    # This shows preprocessed input data, not a model output, so no epoch in the title.
    ax.set_title(f'Preprocessed volume, Slice {slice_index}')
    ax.axis('off')

    plt.show()


slice_slider = widgets.IntSlider(min=0, max=brain_images.shape[2]-1, step=1, value=brain_images.shape[2]//2, description='Slice Index')
widgets.interactive(plot_brain_slices, slice_index=slice_slider)

(128, 128, 128, 1)


interactive(children=(IntSlider(value=64, description='Slice Index', max=127), Output()), _dom_classes=('widge…