In [None]:
import sys
sys.path.insert(0, '../')
import numpy as np
import torch
import matplotlib.pyplot as plt
from dataio.transformation.imageTransformations import RandomFlipTransform, RandomElasticTransform, RandomAffineTransform, RandomNoiseTransform
from dataio.transformation.imageTransformations import StandardizeImage
from gsprep.visual_tools.visual import display, display_4D

In [None]:
dataset_path = "D:/GitHub/StrokeLesionPredict-BIO503/data/working_data/pct_unet_all_2016_2017/rescaled_data_set.npz"
channel = 0
subj_id = None  # set to a subject id to look it up in 'ids'; otherwise `subj` is used as the index
subj = 0

# Open the archive once; the context manager closes the underlying file handle when done.
# NOTE(review): allow_pickle=True is presumably needed because some stored arrays (e.g. 'ids')
# are object arrays — only load trusted files this way.
with np.load(dataset_path, allow_pickle=True) as dataset:
    ids = dataset['ids']
    if subj_id is not None:
        matches = np.argwhere(ids == subj_id)
        if matches.size == 0:
            raise ValueError(f'Subject id {subj_id!r} not found in {dataset_path}')
        subj = matches[0, 0]

    # Keep only the first four input channels, as float64.
    raw_images = dataset['ct_inputs'][subj][..., 0:4].astype(np.float64)
    # Some archives store the lesion ground truth under 'lesion_GT' instead of 'ct_lesion_GT';
    # choose the key explicitly instead of catching every exception.
    label_key = 'ct_lesion_GT' if 'ct_lesion_GT' in dataset.files else 'lesion_GT'
    raw_labels = dataset[label_key][subj].astype(np.float64)
    raw_mask = dataset['brain_masks'][subj]

raw_images.shape

In [None]:
# Overview of the loaded 4-channel input volume via the gsprep display_4D helper.
display_4D(raw_images)

In [None]:
# Show each of the four input channels (last axis) together with the brain mask.
for ch_idx in range(4):
    display(raw_images[..., ch_idx], raw_mask)

In [None]:
# Wrap the numpy volumes as torch tensors (torch.from_numpy shares memory with the arrays).
# The label volume gets a trailing singleton axis so it has a last "channel" axis like the images.
images = torch.from_numpy(raw_images)
labels = torch.from_numpy(np.expand_dims(raw_labels, axis=-1))

# Settings shared by the augmentation cells below.
seed = 7533
max_output_channels = 2
print(images.shape)
labels.shape

In [None]:
# Flip along axis 1; with p=1 and flip_probability=1 the flip is presumably always applied.
flip_axis = 1  # a single axis index (int)
random_flip_prob = 1
flip = RandomFlipTransform(axes=flip_axis, flip_probability=1, p=random_flip_prob,
                           seed=seed, max_output_channels=max_output_channels)
for tensor_in in (images, labels):
    print(tensor_in.numpy().shape)
flipped_image, flipped_label = flip(images, labels)

# Reference: first channel before the flip.
display(images.numpy()[..., 0], mask=labels.numpy())

for tensor_out in (flipped_image, flipped_label):
    print(tensor_out.numpy().shape)
# Every flipped channel with the flipped label.
for ch_idx in range(4):
    display(flipped_image.numpy()[..., ch_idx], mask=flipped_label.numpy())

In [None]:
# Elastic deformation with max_displacement 12 on the first two axes and 0 on the third,
# using 7 control points per axis.
elastic_kwargs = dict(
    max_displacement=[12, 12, 0],
    num_control_points=(7, 7, 7),
    image_interpolation='bspline',
    seed=seed,
    p=1,
    max_output_channels=max_output_channels,
    verbose=True,
)
elastic = RandomElasticTransform(**elastic_kwargs)

elastic_image, elastic_label = elastic(images, labels)

# Overlay the deformed label on slice 45 of the axis just before the channel axis (first channel).
slice_idx = 45
plt.imshow(elastic_image.numpy()[..., slice_idx, 0], cmap='gray')
plt.imshow(elastic_label.numpy()[..., slice_idx, 0], cmap='Blues', alpha=0.4)

# First channel before and after the deformation, each with its label.
display(images.numpy()[..., 0], mask=labels.numpy())
display(elastic_image.numpy()[..., 0], mask=elastic_label.numpy())

In [None]:
# Zoom-only affine: no translation, no rotation, scale range collapsed to 1.6.
shift_val, rotate_val, scale_val = (0, 0), 0, (1.6, 1.6)  # translation / rotation / scaling ranges

affine_kwargs = dict(
    scales=scale_val, degrees=rotate_val, translation=shift_val,
    isotropic=True, default_pad_value=0,
    image_interpolation='bspline', seed=seed, p=1,
    max_output_channels=max_output_channels, verbose=True,
)
affine = RandomAffineTransform(**affine_kwargs)

affine_image, affine_label = affine(images, labels)

display(affine_image.numpy()[..., 0], mask=affine_label.numpy())
# Single-slice overlay of the transformed label on the transformed first channel.
slice_idx = 35
plt.imshow(affine_image.numpy()[..., slice_idx, 0], cmap='gray')
plt.imshow(affine_label.numpy()[..., slice_idx, 0], cmap='Blues', alpha=0.4)

In [None]:
# Rotation + zoom affine: no translation, rotation range (5, 5), scale range collapsed to 1.8.
shift_val, rotate_val, scale_val = (0, 0), (5, 5), (1.8, 1.8)  # translation / rotation / scaling ranges

affine_kwargs = dict(
    scales=scale_val, degrees=rotate_val, translation=shift_val,
    isotropic=True, default_pad_value=0,
    image_interpolation='bspline', seed=seed, p=1,
    max_output_channels=max_output_channels, verbose=True,
)
affine = RandomAffineTransform(**affine_kwargs)

affine_image, affine_label = affine(images, labels)

display(affine_image.numpy()[..., 0], mask=affine_label.numpy())
# Single-slice overlay of the transformed label on the transformed first channel.
slice_idx = 35
plt.imshow(affine_image.numpy()[..., slice_idx, 0], cmap='gray')
plt.imshow(affine_label.numpy()[..., slice_idx, 0], cmap='Blues', alpha=0.4)

In [None]:
# Translation-only affine: no rotation, unit scale, scalar translation parameter of 10.
shift_val = 10  # translation range (scalar form)
rotate_val = (0, 0)  # rotation range
scale_val = (1.0, 1.0)  # scaling range

# Uses a fixed seed of 54 here rather than the shared `seed`.
affine = RandomAffineTransform(scales=scale_val, degrees=rotate_val, translation=shift_val,
                               isotropic=True, default_pad_value=0,
                               image_interpolation='bspline', seed=54, p=1,
                               max_output_channels=max_output_channels, verbose=True)

affine_image, affine_label = affine(images, labels)

display(affine_image.numpy()[..., 0], mask=affine_label.numpy())
# Single-slice overlay of the transformed label on the transformed first channel.
plt.imshow(affine_image.numpy()[..., 35, 0], cmap='gray')
plt.imshow(affine_label.numpy()[..., 35, 0], cmap='Blues', alpha=0.4)

In [None]:
# Noise mean = mean intensity inside the brain mask.
# The mask is cast to bool so it acts as a boolean selector: an integer 0/1 mask would
# otherwise be interpreted as integer (fancy) indices along the first axis.
# NOTE(review): if raw_mask is spatial-only (3-D), this pools all 4 channels into one mean — confirm intended.
noise_mean = np.mean(raw_images[raw_mask.astype(bool)])
print('Mean of input image:', noise_mean)
noise_std = (0.75, 0.75)  # range of noise std

noise = RandomNoiseTransform(mean=noise_mean, std=noise_std, seed=seed, p=1,
                             max_output_channels=max_output_channels)

noise_image, noise_label = noise(images, labels)

display(images.numpy()[..., 0], mask=labels.numpy())
display(noise_image.numpy()[..., 0], mask=noise_label.numpy())
plt.imshow(noise_image.numpy()[..., 35, 0], cmap='gray')

In [None]:
# Check that the labels are unchanged at the output of the noise transform (must be 0).
# labels.numel() replaces a hard-coded 79*95*79 voxel count so the check follows the actual volume size.
labels.numel() - torch.eq(labels, noise_label).sum()

In [None]:
# Standardize the image channels. NOTE(review): norm_flag presumably selects per channel which
# channels are standardized (here the first three, not the fourth) — confirm against StandardizeImage.
norm = StandardizeImage(norm_flag=[True, True, True, False])

norm_image, norm_label = norm(images, labels)

display(images.numpy()[..., 0], mask=labels.numpy())
# Display with the label returned by the transform, like the other augmentation cells.
display(norm_image.numpy()[..., 0], mask=norm_label.numpy())

In [None]:
# Check that the labels are unchanged at the output of the StandardizeImage transform (must be 0).
# labels.numel() replaces a hard-coded 79*95*79 voxel count so the check follows the actual volume size.
labels.numel() - torch.eq(labels, norm_label).sum()

In [None]:
# Compare global and per-channel mean/std before and after standardization.
print("Image Global Mean : \t\t", images.mean().item())
print("Standardized Global Mean : \t", norm_image.mean().item())

print("\nImage Global StD : \t\t", images.std().item())
print("Standardized Global StD : \t", norm_image.std().item())

# One block of statistics per channel (last axis), for every channel rather than a hand-copied subset.
for ch_idx in range(images.shape[-1]):
    print(f"\nChannel {ch_idx}")
    print(f"Channel {ch_idx} Mean : \t\t", images[..., ch_idx].mean().item())
    print(f"Standardized Channel {ch_idx} Mean : \t", norm_image[..., ch_idx].mean().item())
    print(f"Channel {ch_idx} StD : \t\t", images[..., ch_idx].std().item())
    print(f"Standardized Channel {ch_idx} StD : \t", norm_image[..., ch_idx].std().item())

In [None]:
plt.rcParams.update({'figure.figsize': (7, 5), 'figure.dpi': 100})

# Display names for the channels; presumably in the channel order of `images` — confirm against the dataset.
exp_names = ["Tmax", "CBF", "MTT", "CBV"]
exp_considered = 0

# Flatten the selected channel once; reuse it for the histogram and the summary statistics.
channel_values = images[..., exp_considered].numpy().flatten()

plt.hist(channel_values, bins=100)
plt.gca().set(title=f'Original Image {exp_names[exp_considered]} Histogram', xlabel='Pixel Value', ylabel='Count')

for stat_label, stat_fn in (("Max : ", np.max), ("Min : ", np.min), ("Mean : ", np.mean), ("Std : ", np.std)):
    print(stat_label, stat_fn(channel_values))

In [None]:
# Display names for the channels; presumably in the channel order of `norm_image` — confirm against the dataset.
exp_names = ["Tmax", "CBF", "MTT", "CBV"]
exp_considered = 0

# Flatten the selected standardized channel once; reuse it for the histogram and the statistics.
norm_values = norm_image[..., exp_considered].numpy().flatten()

plt.hist(norm_values, bins=100)
plt.gca().set(title=f'Standardized Image {exp_names[exp_considered]} Histogram', xlabel='Pixel Value', ylabel='Count')

for stat_label, stat_fn in (("Max : ", np.max), ("Min : ", np.min), ("Mean : ", np.mean), ("Std : ", np.std)):
    print(stat_label, stat_fn(norm_values))

In [None]:
from torchsample import transforms as ts

# torchsample transform used below to move the channel axis to the front.
chanfirst = ts.ChannelsFirst()

In [None]:
# Apply ChannelsFirst to the image and label tensors together.
chanf_outputs = chanfirst(images, labels)
chanf_image, chanf_label = chanf_outputs

In [None]:
# Shape after ChannelsFirst; compare with the shape of `images`.
chanf_image.shape

In [None]:
# Shape of the original (channels indexed on the last axis) image tensor.
images.size()

In [None]:
# Pad image and label towards a 96x96x96 spatial size.
# NOTE(review): the last target entry is 1 while `images` has 4 channels on its last axis —
# presumably no padding is applied there; confirm against torchsample's Pad semantics.
pad_target = [96, 96, 96, 1]
pad = ts.Pad(size=pad_target)
pad_image, pad_label = pad(images, labels)

In [None]:
# Shape after padding.
pad_image.shape

# Combinations

In [None]:
# Combined affine: translation 100, rotation range (5, 5), scale range collapsed to 1.8.
# NOTE(review): the units of `translation` are not visible here; 100 may exceed the volume extent — confirm.
shift_val, rotate_val, scale_val = 100, (5, 5), (1.8, 1.8)  # translation / rotation / scaling ranges

affine_kwargs = dict(
    scales=scale_val, degrees=rotate_val, translation=shift_val,
    isotropic=True, default_pad_value=0,
    image_interpolation='bspline', seed=32, p=1,
    max_output_channels=max_output_channels, verbose=True,
)
affine = RandomAffineTransform(**affine_kwargs)

affine_image, affine_label = affine(images, labels)

# First channel before and after, each with its label, then a single-slice overlay.
display(images.numpy()[..., 0], mask=labels.numpy())
display(affine_image.numpy()[..., 0], mask=affine_label.numpy())
slice_idx = 35
plt.imshow(affine_image.numpy()[..., slice_idx, 0], cmap='gray')
plt.imshow(affine_label.numpy()[..., slice_idx, 0], cmap='Blues', alpha=0.4)

# FIXME(review): this fragment appears to be copied from a training script. None of `arguments`,
# `json_file_to_pyobj`, `get_dataset`, `get_dataset_path`, `get_dataset_transformation`,
# `get_model` or `network_debug` is defined or imported anywhere in this notebook as visible,
# so on Restart & Run All the first statement below raises NameError. Import/define them
# (or move this code back into the training script) before running this cell.

# Parse input arguments
json_filename = arguments.config

# Load options
json_opts = json_file_to_pyobj(json_filename)
train_opts = json_opts.training

# Architecture type
arch_type = train_opts.arch_type

# Setup Dataset and Augmentation
ds_class = get_dataset(arch_type)
ds_path = get_dataset_path(arch_type, json_opts.data_path)
ds_transform = get_dataset_transformation(arch_type, opts=json_opts.augmentation,
                                          max_output_channels=json_opts.model.output_nc)

# Setup channels
channels = json_opts.data_opts.channels
# NOTE(review): only the channel count is checked; the error message also mentions
# patch/scale size dimensions, but that part of the condition is commented out.
if len(channels) != json_opts.model.input_nc :
        # or len(channels) != getattr(json_opts.augmentation, arch_type).scale_size[-1]:
    raise Exception('Number of data channels must match number of model channels, and patch and scale size dimensions')

# Setup the NN Model
model = get_model(json_opts.model)
if network_debug:
    print('# of pars: ', model.get_number_parameters())
    print('fp time: {0:.3f} sec\tbp time: {1:.3f} sec per sample'.format(*model.get_fp_bp_time()))
    # NOTE(review): exit() ends the Python process; inside a notebook this presumably
    # shuts down the kernel — confirm this is wanted here.
    exit()

# Setup Data Loader
split_opts = json_opts.data_split
train_dataset = ds_class(ds_path, split='train',      transform=ds_transform['train'], preload_data=train_opts.preloadData,
                         train_size=split_opts.train_size, test_size=split_opts.test_size,
                         valid_size=split_opts.validation_size, split_seed=split_opts.seed, channels=channels)