In [None]:
import sys
sys.path.append('/tf/data')

import os
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa

from tqdm import tqdm

from general_func import load_dataset

In [None]:
# Choose which subset load_dataset returns: set `positive = False` to process the
# negative subset instead. Each subset is written to its own subdirectory.
positive = True
SAVE_ROOT = "/tf/data/augmented_1slice_64_3ch/"

# NOTE(review): BATCH_SIZE is not used anywhere in this notebook; kept for compatibility.
BATCH_SIZE = 16

# Load the selected subset.
ds = load_dataset(positive=positive)

# Output directory: '<SAVE_ROOT>/1' for the positive subset, '<SAVE_ROOT>/0' otherwise.
savepath = os.path.join(SAVE_ROOT, '1' if positive else '0')
print(savepath)

In [None]:
# Raw scan geometry and the target resolution of the saved images.
orig_height = 162
orig_hight = orig_height  # alias preserving the original (misspelt) variable name
orig_width = 141
num_slices = 46
new_height = 64
new_width = 64

# Index of the single slice kept from each volume: the middle one.
# Derived from num_slices instead of hard-coded (46 // 2 == 23, the index used before).
slice_index = num_slices // 2

# Element 1 of each dataset feature tuple is taken as the scan volume
# (presumably the image data -- TODO confirm against load_dataset). The reshape also
# validates that the data matches the expected (volumes, slices, H, W, 1) geometry.
scan_volumes = np.reshape(
    np.array([x[1] for x, _ in ds]),
    [len(ds), num_slices, orig_height, orig_width, 1],
)

# Keep one (H, W, 1) slice per volume.
train_scans = [volume[slice_index] for volume in scan_volumes]

# Normalisation (not an augmentation): linearly scale each image to zero mean and
# unit variance before any augmentation is applied.
train_scans = [tf.image.per_image_standardization(scan) for scan in train_scans]
print('Scans:', len(train_scans))

In [None]:
# Rotation: every scan currently in the list gets one extra copy per angle (radians).
# The comprehension is fully built before the list is extended, so only the
# pre-existing scans are rotated. 'nearest' fill repeats edge pixels into the exposed
# corners to mimic the actual image rather than padding with a constant.
rotation_angles = (-0.05, -0.025, 0.025, 0.05)
train_scans += [
    tfa.image.rotate(images=scan, angles=angle, fill_mode='nearest')
    for scan in tqdm(train_scans, desc='Rotation')
    for angle in rotation_angles
]
print('Scans:', len(train_scans))

In [None]:
# Contrast: add one copy of every current scan per fixed contrast factor.
# New images are collected first and then appended, so the originals plus all
# rotated copies are each augmented exactly once.
contrast_factors = (0.8, 0.9, 1.1, 1.2)
train_scans += [
    tf.image.adjust_contrast(scan, contrast_factor=factor)
    for scan in tqdm(train_scans, desc='Contrast')
    for factor in contrast_factors
]
print('Scans:', len(train_scans))

In [None]:
# Brightness: add one copy of every current scan per fixed brightness delta.
# Fixed deltas are used instead of tf.image.random_brightness so the output is reproducible.
brightness_deltas = (-0.1, 0.1)
train_scans += [
    tf.image.adjust_brightness(scan, delta=delta)
    for scan in tqdm(train_scans, desc='Brightness')
    for delta in brightness_deltas
]
print('Scans:', len(train_scans))

In [None]:
# Mirror every current scan horizontally, doubling the number of images.
train_scans += [
    tf.image.flip_left_right(scan)
    for scan in tqdm(train_scans, desc='Flip images')
]
print('Scans:', len(train_scans))

In [None]:
# Downsample every image to the target resolution (tf.image.resize, default method).
# NOTE(review): this does not preserve the aspect ratio (e.g. 162x141 -> 64x64);
# consider tf.image.resize_with_pad if the distortion matters.
target_size = [new_height, new_width]
train_scans = [
    tf.image.resize(image, target_size)
    for image in tqdm(train_scans, desc = 'Resize')
]
print('Scans:', len(train_scans))

In [None]:
# Replicate the single grayscale channel into three channels (RGB layout).
train_scans = [
    tf.image.grayscale_to_rgb(gray)
    for gray in tqdm(train_scans, desc = 'Convert to RGB')
]
print('Scans:', len(train_scans))

In [None]:
# Materialise each tensor as a NumPy array so it can be written with np.save.
train_scans_np = [scan.numpy() for scan in train_scans]

# Sanity-check the pipeline output before anything is written to disk.
expected_shape = (new_height, new_width, 3)
assert train_scans_np, 'No scans to save'
bad_indices = [idx for idx, img in enumerate(train_scans_np) if img.shape != expected_shape]
assert not bad_indices, f'{len(bad_indices)} images do not have shape {expected_shape}'
print(train_scans_np[0].shape)

In [None]:
# Write each image to <savepath>/image_<index>.npy (np.save appends the .npy suffix).
# Indices are zero-padded to at least 5 digits (the original width), widening if
# needed so that lexicographic file order still matches index order beyond 99,999
# images (the augmentation multiplies the scan count by 150).
# NOTE: existing files in savepath are not removed; a previous, larger run leaves
# extra files behind.
os.makedirs(savepath, exist_ok=True)
pad_width = max(5, len(str(len(train_scans_np) - 1)))
for i, image in enumerate(train_scans_np):
    np.save(file=os.path.join(savepath, f"image_{i:0{pad_width}d}"),
            arr=image)

print('Images saved to', savepath)