This notebook performs data augmentation in which no cropping takes place. This is qualitatively different from the other data augmentation notebook, which should be executed first.

All FFT descriptors are calculated within the notebook as well.

# Import libraries

In [None]:
import os
import h5py
import matplotlib.pyplot as plt
import numpy as np
import numpy.matlib
from scipy import ndimage
from collections import Counter
import itertools
from scipy.signal import get_window
import cv2
from collections import defaultdict
from scipy.stats import mode
from scipy import stats

In [None]:
# Directory in which all .npy and .json files produced by this notebook
# are saved ('.' is the current working directory)
save_path = '.'

# Load all data

In [None]:
# Folder(s) in which the h5 files with the simulated images can be found;
# add one list entry per folder
folders = ['.']


In [None]:
# Collect all augmented-image files from every folder.
# `files` holds the bare file names, `images` the corresponding full paths.
# Each folder is listed only once, and the names are sorted because
# os.listdir returns entries in arbitrary order while the order of `images`
# fixes the sample order of every array built from it.
files = []
images = []
for folder in folders:
    matching_names = sorted(name for name in os.listdir(folder)
                            if 'augmented_images_rotation&shear' in name)
    files.extend(matching_names)
    images.extend(os.path.join(folder, name) for name in matching_names)

In [None]:
# Display the names of the files that were found
files

As an example, load one of the image files and inspect its contents:

In [None]:
# Open the first file read-only and print its top-level keys and the shape
# of the image stack. The context manager closes the file even if one of
# the prints fails, and the shape is taken from the dataset itself so the
# full image stack does not have to be loaded into memory.
with h5py.File(images[0], 'r') as file:
    print(file.keys())
    print(file.get('Image_rotation_and_shear').get('Rotated_and_sheared_images').shape)

# Load rotated and sheared images and calculate FFT

In [None]:
def calc_fft(img, padding=(0, 0), power=2,
             sigma=None, r_cut=None,
             thresholding=False, apply_window=True, output_size=None,
             output_shape=(64, 64)):
    """Given a HAADF image, calculate the HAADF-FFT descriptor.

    The image is min-max normalized, optionally windowed and Fourier
    transformed. The exponentiated FFT amplitude is shifted so that the zero
    frequency lies in the center, normalized, optionally modified (central
    cut, thresholding), and finally a window of size ``output_shape`` around
    the center is cut out and normalized again.

    Parameters:

    img: np.array
        2D HAADF input image
    padding: tuple
        Currently unused; kept so that existing calls remain valid.
    power: int
        Number by which FFT amplitude is exponentiated
        in order to suppress small fluctuations and
        emphasize peaks
    sigma: float or None
        Standard deviation of the Gaussian window employed to suppress
        the central part of the FFT (the FFT is multiplied by one minus the
        window). In the standard setting (sigma=None), no cutting employed.
    r_cut: float or None
        Radius (in pixels) of the circular region around the center of
        the FFT that is set to zero.
        In the standard setting (r_cut=None), no cutting employed.
    thresholding: bool
        [incompletely implemented] If True, apply thresholding
        procedure to mitigate influence of central peak: all values
        above 10% of the maximum are set to the maximum.
    apply_window: bool
        If True, multiply the image with a 2D window (square root of the
        outer product of two Hann windows) before the FFT.
    output_size: tuple or None
        Passed as ``s`` to np.fft.fft2, i.e. the size of the FFT. If None,
        the FFT size is given by img.shape[0] and img.shape[1]. If an entry
        is smaller than the image size, the image is cropped; if larger,
        zero padding is applied.
    output_shape: tuple
        Shape (rows, columns) of the window that is cut out around the
        center of the FFT and returned. Must not exceed the FFT size.

    Returns:

    np.array
        float32 array of shape ``output_shape``, min-max normalized to [0, 1]

    Raises:

    ValueError
        If ``output_shape`` is larger than the FFT in any dimension.
    """

    # First step: normalize image
    img = cv2.normalize(img, None,
                        alpha=0, beta=1,
                        norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)

    if apply_window:
        # windowing: sqrt of the separable 2D Hann window.
        # 'hann' is the canonical name ('hanning' is only a legacy alias).
        window_rows = get_window('hann', img.shape[0])
        window_cols = get_window('hann', img.shape[1])
        img_windowed = img * np.sqrt(np.outer(window_rows, window_cols))
    else:
        img_windowed = img

    # Calculate FFT
    f = np.fft.fft2(img_windowed, s=output_size)

    # Calculate power spectrum (or higher order exponential)
    fshift = np.fft.fftshift(np.power(np.abs(f), power))

    # Normalization
    fshift = cv2.normalize(fshift, None,
                           alpha=0, beta=1,
                           norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)

    # Remove central part of image, several options:
    # Circular cut:
    if r_cut is not None:
        xc = (fshift.shape[0] - 1.0) / 2.0
        yc = (fshift.shape[1] - 1.0) / 2.0
        # coordinates relative to the geometric center of the array
        x, y = np.ogrid[-xc:fshift.shape[0] - xc, -yc:fshift.shape[1] - yc]
        mask_out = x * x + y * y <= r_cut * r_cut
        fshift[mask_out] = 0.0

    # cut using gaussian window:
    if sigma is not None:
        # One window per axis, so that non-square FFTs are handled as well
        gauss_rows = get_window(('gaussian', sigma), fshift.shape[0])
        gauss_cols = get_window(('gaussian', sigma), fshift.shape[1])
        w = np.sqrt(np.outer(gauss_rows, gauss_cols))
        fshift = fshift * (1 - w)

    if thresholding:
        # Rescale to [0, 1], saturate everything above 10% of the maximum
        # and rescale again, which limits the dominance of the central peak.
        fshift = cv2.normalize(fshift, None,
                               alpha=0, beta=1,
                               norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)
        fshift = fshift / .1
        fshift[fshift > 1] = 1
        fshift = cv2.normalize(fshift, None,
                               alpha=0, beta=1,
                               norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)

    # Cut out window of size output_shape around center of FFT
    fft_rows, fft_cols = fshift.shape[0], fshift.shape[1]
    out_rows, out_cols = output_shape[0], output_shape[1]
    if out_rows > fft_rows or out_cols > fft_cols:
        # Without this check negative start indices would silently wrap
        # around or trigger an obscure IndexError.
        raise ValueError(
            'output_shape {} exceeds the FFT size {}'.format(
                (out_rows, out_cols), (fft_rows, fft_cols)))
    row_start = fft_rows // 2 - out_rows // 2
    col_start = fft_cols // 2 - out_cols // 2
    # Copy into a contiguous float64 array before the final normalization
    output2 = np.array(fshift[row_start:row_start + out_rows,
                              col_start:col_start + out_cols],
                       dtype=np.float64)

    output2 = cv2.normalize(output2, None,
                            alpha=0, beta=1,
                            norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)

    return output2

In [None]:
# Load every rotated/sheared image from all files, together with a text
# label and its FFT descriptor.
all_images = []
all_images_fft = []
all_labels = []

# FFT parameters
thresholding = True
r_cut = None
sigma = None

for image_path in images:

    print(image_path)

    # Structure name = file name without its last four '_'-separated tokens.
    # os.path.basename strips the folder part with the platform's own path
    # separator (the paths were built with os.path.join).
    current_structure = '_'.join(os.path.basename(image_path).split('_')[:-4])

    # The context manager closes the file once all of its images are read
    with h5py.File(image_path, 'r') as file:

        # Extract datasets for each group
        file_data = file.get("Image_rotation_and_shear")
        img = file_data.get("Rotated_and_sheared_images")

        # The last two axes enumerate the rotations and the shears
        img_shape = img.shape

        for rot in range(img_shape[-2]):
            for shear in range(img_shape[-1]):

                # Read the image once and reuse it for storage and FFT
                current_image = img[:, :, rot, shear]

                all_images.append(current_image)
                all_labels.append(current_structure + '_rot_{}_shear_{}'.format(rot, shear))

                fft_desc = calc_fft(current_image,
                                    r_cut=r_cut, thresholding=thresholding,
                                    sigma=sigma)
                all_images_fft.append(fft_desc)

In [None]:
# Number of FFT descriptors (one per file, rotation and shear)
len(all_images_fft)

In [None]:
# Save HAADF images as a single array
# (assumes all images share the same shape - TODO confirm)
np.save(os.path.join(save_path, 'X_haadf.npy'), np.array(all_images))

# Save the full text labels (structure name plus rotation and shear index)
np.save(os.path.join(save_path, 'y_fulllabels.npy'), all_labels)

In [None]:
# Shape of the stacked image array (number of images first)
print(np.array(all_images).shape)

In [None]:
# Save the FFT descriptors of the HAADF images (same order as X_haadf.npy)
np.save(os.path.join(save_path, 'X_fft.npy'), np.array(all_images_fft))

### Define mapping between text labels and integer labels

In [None]:
# Reduce every full label to its first three '_'-separated tokens
# (presumably the structure name; verify against the file naming scheme)
a = [full_label.split('_')[:3] for full_label in all_labels]
b = ['_'.join(name_tokens) for name_tokens in a]

In [None]:
# Sorted array of the distinct class names
unique_labels = np.unique(b)
print(unique_labels)

In [None]:
# Two-way lookup between class names and their integer index
# (index = position in the sorted `unique_labels`)
numerical_to_text_labels = dict(enumerate(unique_labels))
text_to_numerical_labels = {class_name: class_index
                            for class_index, class_name in enumerate(unique_labels)}
print(numerical_to_text_labels, text_to_numerical_labels)

In [None]:
import json

# Persist both label mappings so that other notebooks can reuse them
with open(os.path.join(save_path, 'text_to_numerical_labels.json'), 'w') as f:
    json.dump(text_to_numerical_labels, f)
    
# Note: JSON object keys are always strings, so the integer keys of this
# mapping are written as "0", "1", ... and must be converted back with
# int() after loading.
with open(os.path.join(save_path, 'numerical_to_text_labels.json'), 'w') as f:
    json.dump(numerical_to_text_labels, f)

In [None]:
# Integer class label of every image, in the same order as `all_labels`
converted_labels = [text_to_numerical_labels[class_name] for class_name in b]

In [None]:
# Save the integer class labels
np.save(os.path.join(save_path,
                     'y.npy'), np.array(converted_labels))

# Add noise

In [None]:
from scipy.ndimage import gaussian_filter
from skimage.util import random_noise

In [None]:
# Extract subselection, otherwise may run into memory problems - at least
# if not run on high-performance computing cluster.
# Only every second rotation and every second shear is kept.

raw_images = []
raw_labels = []

for image_path in images:

    print(image_path)

    # Structure name = file name without its last four '_'-separated tokens.
    # os.path.basename strips the folder part with the platform's own path
    # separator (the paths were built with os.path.join).
    current_structure = '_'.join(os.path.basename(image_path).split('_')[:-4])

    # The context manager closes the file once all of its images are read
    with h5py.File(image_path, 'r') as file:

        # Extract datasets for each group
        file_data = file.get("Image_rotation_and_shear")
        img = file_data.get("Rotated_and_sheared_images")

        # The last two axes enumerate the rotations and the shears
        img_shape = img.shape

        for rot in range(0, img_shape[-2], 2):
            for shear in range(0, img_shape[-1], 2):

                raw_images.append(img[:, :, rot, shear])
                raw_labels.append(current_structure + '_rot_{}_shear_{}'.format(rot, shear))
print(len(raw_images))

## Poisson noise

In [None]:
# Build the Poisson data set: for every raw image keep the normalized
# original and add `iterations` noisy copies of it.
# NOTE(review): no random seed is set here, so the noisy copies presumably
# differ between runs - confirm whether reproducibility is required.
iterations = 2
images_w_poisson = []
labels_w_poisson = []

for img, current_structure in zip(raw_images, raw_labels):

    # Rescale intensities to the range [0, 1]
    current_image = cv2.normalize(img, None, alpha=0, beta=1,
                                  norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)

    # Noise-free image first, followed by its noisy realizations
    noisy_images = [random_noise(current_image, mode='poisson')
                    for _ in range(iterations)]
    noisy_labels = ['{}_pois_it_{}'.format(current_structure, it)
                    for it in range(iterations)]

    images_w_poisson.extend([current_image] + noisy_images)
    labels_w_poisson.extend([current_structure] + noisy_labels)

In [None]:
# Each raw image contributes itself plus `iterations` noisy copies
print(len(images_w_poisson))

In [None]:
# Distinct label multiplicities; a result of array([1]) means that
# every label occurs exactly once
np.unique(list(Counter(labels_w_poisson).values()))

In [None]:
# Show only the first few labels instead of dumping the complete list
labels_w_poisson[:10]

# Add Blurring

In [None]:
# Gaussian blurring: every image of the Poisson set is kept as it is and
# additionally blurred once per entry of `widths` (sigma of the filter).
widths = [2, 4]
images_w_poisson_w_gaussian = []
labels_w_poisson_w_gaussian = []

for current_image, current_label in zip(images_w_poisson, labels_w_poisson):

    # Unblurred image first, then one blurred copy per width
    images_w_poisson_w_gaussian.append(current_image)
    images_w_poisson_w_gaussian.extend(
        gaussian_filter(current_image, sigma=width) for width in widths)

    # Labels in the same order as the images above
    labels_w_poisson_w_gaussian.append(current_label)
    labels_w_poisson_w_gaussian.extend(
        current_label + '_gwidth_' + str(width) for width in widths)

In [None]:
# Each input image contributes itself plus one blurred copy per width
print(len(images_w_poisson_w_gaussian))

In [None]:
# Distinct label multiplicities; a result of array([1]) means that
# every label occurs exactly once
np.unique(list(Counter(labels_w_poisson_w_gaussian).values()))

In [None]:
# Show only the first few labels instead of dumping the complete list
labels_w_poisson_w_gaussian[:10]

# Add Gaussian noise

In [None]:
# Variances of the additive Gaussian noise
var_list = [0.005, 0.01]
images_w_poisson_w_gaussian_w_gnoise = []
labels_w_poisson_w_gaussian_w_gnoise = []

for current_image, current_label in zip(images_w_poisson_w_gaussian, labels_w_poisson_w_gaussian):
    
    if 'pois' in current_label or 'gwidth' in current_label:
        # keep images that already carry Poisson noise and/or blurring;
        # pristine images themselves are not copied into the final set
        images_w_poisson_w_gaussian_w_gnoise.append(current_image)
        labels_w_poisson_w_gaussian_w_gnoise.append(current_label)
        
    if 'pois' in current_label:
        # don't add gaussian noise AND poisson noise
        continue
    
    # Reached only for pristine and blurred-only images:
    # add one noisy copy per variance
    for var in var_list:
        distorted_image = random_noise(current_image, mode='gaussian', var=var)
        
        images_w_poisson_w_gaussian_w_gnoise.append(distorted_image)
        labels_w_poisson_w_gaussian_w_gnoise.append(current_label + '_gnoisevar_' + str(var))

In [None]:
# Distinct label multiplicities; a result of array([1]) means that
# every label occurs exactly once
np.unique(list(Counter(labels_w_poisson_w_gaussian_w_gnoise).values()))

In [None]:
# Total number of distorted images
print(len(labels_w_poisson_w_gaussian_w_gnoise))

In [None]:
# Show only the first few labels instead of dumping the complete list
labels_w_poisson_w_gaussian_w_gnoise[:10]

# Calculate FFT for distorted images

In [None]:
# FFT parameters (same settings as for the undistorted images)
thresholding = True
sigma = None
r_cut = None

# One FFT descriptor per distorted image, in the order of the image list
ffts_distorted = [
    calc_fft(distorted_image, r_cut=r_cut,
             thresholding=thresholding,
             sigma=sigma)
    for distorted_image in images_w_poisson_w_gaussian_w_gnoise
]

# The labels are carried over unchanged
labels_distorted = list(labels_w_poisson_w_gaussian_w_gnoise)

In [None]:
# Number of FFT descriptors of distorted images
len(ffts_distorted)

In [None]:
# Save the distorted HAADF images
np.save(os.path.join(save_path, 'X_distorted_HAADF.npy'), np.asarray(images_w_poisson_w_gaussian_w_gnoise))

# Save the FFT descriptors of the distorted images
np.save(os.path.join(save_path, 'X_fft_distorted.npy'), np.asarray(ffts_distorted))

# Save the full text labels of the distorted images (the integer class
# labels are written separately to y_fft_distorted_int.npy)
np.save(os.path.join(save_path, 'y_fft_distorted.npy'), np.asarray(labels_distorted))

In [None]:
# Translate the labels of the distorted images into integer class labels:
# the first three '_'-separated tokens form the class name, which is then
# looked up in the text-to-integer mapping
a = [distorted_label.split('_')[:3] for distorted_label in labels_distorted]
b = ['_'.join(name_tokens) for name_tokens in a]
converted_labels = [text_to_numerical_labels[class_name] for class_name in b]

int_labels_path = os.path.join(save_path, 'y_fft_distorted_int.npy')
np.save(int_labels_path, np.asarray(converted_labels))