In [None]:
from glob import glob
import h5py
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
from torchvision import datasets
from torch.utils.data import DataLoader
import torchvision.transforms as T
from PIL import Image
import copy

In [None]:
def visualize_noise(flatten_final, noise, ncols=10):
    """Plot clean images, their noisy versions and the residuals as a 3-row grid.

    Column ``j`` shows image ``j`` (the first ``ncols`` images, not a random
    selection). Row 0 is the clean image, row 1 the noisy image and row 2
    ``sinh(|clean - noisy|)``. All rows use a gray colormap clipped to (0, 255).

    Parameters
    ----------
    flatten_final : array-like
        Clean images, indexed along axis 0.
    noise : array-like
        Noisy images, indexed like ``flatten_final``.
    ncols : int, default 10
        Number of columns; capped at the number of images.

    Raises
    ------
    ValueError
        If there is nothing to plot (no images or ``ncols < 1``).
    """
    # Subtract in float64. With unsigned integer inputs (e.g. uint8 images),
    # `flatten_final - noise` would wrap around instead of going negative, and
    # np.sinh of small integer dtypes returns float16, which overflows quickly.
    residuals = np.abs(np.asarray(flatten_final, dtype=np.float64)
                       - np.asarray(noise, dtype=np.float64))
    ncols = min(ncols, flatten_final.shape[0])
    if ncols < 1:
        raise ValueError('visualize_noise needs at least one image and ncols >= 1')

    # squeeze=False keeps `axes` 2-D even when ncols == 1.
    fig, axes = plt.subplots(nrows=3, ncols=ncols, figsize=(20, 8), squeeze=False)

    for idx in range(ncols):
        axes[0][idx].imshow(flatten_final[idx], cmap=plt.cm.gray, clim=(0, 255))
        axes[1][idx].imshow(noise[idx], cmap=plt.cm.gray, clim=(0, 255))
        # np.sinh stretches the residuals non-linearly. With clim=(0, 255),
        # residuals above roughly 6.2 (sinh ~ 255) render as white.
        axes[2][idx].imshow(np.sinh(residuals[idx]), cmap=plt.cm.gray,
                            interpolation='nearest', clim=(0, 255))
        for ax in axes[:, idx]:
            ax.set_xticks([])
            ax.set_yticks([])

    for ax, label in zip(axes[:, 0], ('Original', 'Noisy', 'Residuals')):
        ax.set_ylabel(label, fontsize=10, fontweight='bold')

    plt.subplots_adjust(wspace=0, hspace=0)
    fig.patch.set_facecolor('#423f3b')
    plt.show()
    

In [None]:
# Root folder of the NIRCam data. Without recursive=True, '**' behaves like '*',
# so the glob only matches *sci.h5 files exactly one directory level below it.
directory = '/home/sarperyn/sarperyurtseven/ProjectFiles/dataset/NIRCAM/'
h5_files = glob(os.path.join(directory, '**', '*sci.h5'))

In [None]:
def _h5_for_program(files, program_id):
    """Return the single path in `files` that has a directory component equal to `program_id`.

    Raises FileNotFoundError when zero or several paths match, instead of
    silently opening the wrong program's file.
    """
    matches = [path for path in files if program_id in path.split(os.sep)]
    if len(matches) != 1:
        raise FileNotFoundError(
            f'expected exactly one *sci.h5 file under a {program_id!r} directory, '
            f'found {len(matches)}: {matches}')
    return matches[0]


# Select each file by its program directory rather than by glob() position,
# because glob() order depends on the filesystem.
# NOTE(review): assumes the files sit in sub-directories named after the program
# ID, as the NIRCAM/1386/ and NIRCAM/1441/ paths used later suggest. Confirm.
# Mode 'r+' (read/write, file must exist) is required because later cells call
# create_dataset() on these handles, and mode 'r' is read-only.
data_1441 = h5py.File(_h5_for_program(h5_files, '1441'), 'r+')
data_1386 = h5py.File(_h5_for_program(h5_files, '1386'), 'r+')

In [None]:
def visualize(original, sample1, sample2=None, sample3=None, sample4=None, sample5=None,
              ncols=10, labels=None, min_index=150, savepath='Augmentation.png'):
    """Show random frames of `original` next to up to five transformed stacks.

    Each column draws one random frame index, shared by every row. Each row shows
    that frame from one stack, so every sample must be indexable with the same
    frame indices as `original`.

    Parameters
    ----------
    original : array-like
        Reference stack, indexed along axis 0 (must have `.shape`).
    sample1, sample2, sample3, sample4, sample5 : array-like or None
        Transformed stacks drawn in rows 1..5. `sample2`-`sample5` are optional;
        `None` entries are skipped and fewer rows are drawn.
    ncols : int, default 10
        Number of random frames (columns) to display.
    labels : sequence of str, optional
        One label per drawn row, original first. Defaults to the labels of the
        rotation/flip figure.
    min_index : int, default 150
        Lower bound of the random frame index. It is lowered to `nsample - 1`
        when the stack is shorter.
    savepath : str or None, default 'Augmentation.png'
        Where to save the figure (300 dpi). Pass `None` to skip saving.

    Raises
    ------
    ValueError
        If `original` has no frames or fewer labels than rows are given.
    """
    samples = (sample1, sample2, sample3, sample4, sample5)
    rows = [original] + [s for s in samples if s is not None]
    if labels is None:
        labels = ('Original', 'Rotated 45˚', 'Rotated 90˚', 'Rotated 120˚',
                  'Flipped', 'F + R 210˚')
    if len(labels) < len(rows):
        raise ValueError(f'need {len(rows)} row labels, got {len(labels)}')

    nsample = original.shape[0]
    if nsample == 0:
        raise ValueError('original must contain at least one frame')
    # Keep randint's low < high for short stacks (randint(150, n) fails for n <= 150).
    low = min(min_index, nsample - 1)

    # squeeze=False keeps `axes` 2-D for any number of rows/columns.
    fig, axes = plt.subplots(nrows=len(rows), ncols=ncols, figsize=(20, 10), squeeze=False)

    for idx in range(ncols):
        rand_num = np.random.randint(low, nsample)
        for row, stack in enumerate(rows):
            ax = axes[row][idx]
            ax.imshow(stack[rand_num], cmap='gray', clim=(0, 255))
            ax.set_xticks([])
            ax.set_yticks([])

    for row, label in enumerate(labels[:len(rows)]):
        axes[row][0].set_ylabel(label, fontsize=10, fontweight='bold')

    plt.subplots_adjust(wspace=0, hspace=0)
    # Save before show(): saving after show() can produce a blank file on some backends.
    if savepath is not None:
        fig.savefig(savepath, dpi=300, bbox_inches='tight', pad_inches=0)
    plt.show()
    

# 1386

In [None]:
# Shape of the first dataset in the 1441 file. This cell is self-contained, because
# keys_1441 is only defined later in the 1441 section. The shape is read from HDF5
# metadata without loading the array.
data_1441[next(iter(data_1441.keys()))].shape

In [None]:
# Stack every dataset of the 1386 file along axis 0, in key order, with a single
# concatenate. Growing the array inside a loop re-copied it on every step and
# failed when the file had fewer than two datasets.
keys_1386 = list(data_1386.keys())
final_1386 = np.concatenate([np.array(data_1386[key]) for key in keys_1386])

In [None]:
# torch.as_tensor shares memory with the NumPy array like torch.from_numpy, but it
# returns an existing tensor unchanged, so re-running this cell no longer raises.
final_1386 = torch.as_tensor(final_1386)

## Resizing (center crop)

In [None]:
#resized_imgs = [T.CenterCrop(size=size)(final_1386) for size in [160,120]]

In [None]:
#visualize(final_1386,resized_imgs[0],resized_imgs[1])

In [None]:
#resized_160 = resized_imgs[0]

In [None]:
def augment(tensor,
            angles=(30, 45, 60, 75, 90, 105, 120, 135, 180, 200, 210, 245, 275, 310, 340),
            exact_angles=False):
    """Build rotated, flipped and flipped+rotated copies of an image stack.

    Parameters
    ----------
    tensor : torch.Tensor
        Image stack. torchvision treats the last two dims as (H, W); presumably
        (N, H, W) here (TODO confirm).
    angles : sequence of numbers
        Rotation settings; one rotated (and one flipped+rotated) output per entry.
    exact_angles : bool, default False
        If False (original behaviour), each `d` goes to `T.RandomRotation(degrees=d)`.
        That draws a single angle uniformly from [-d, d] and applies it to the
        whole stack. If True, each stack is rotated by exactly `d` degrees via
        `T.functional.rotate`.

    Returns
    -------
    augmented_data : np.ndarray
        Along axis 0: all rotated stacks, then all flipped+rotated stacks,
        then the flipped stack.
    rotated_imgs : list of torch.Tensor
        One rotated stack per angle.
    flipped_imgs : list of torch.Tensor
        Single-element list holding the horizontally flipped stack.
    flipped_rotated_imgs : list of torch.Tensor
        One rotated copy of the flipped stack per angle.
    """
    angles = list(angles)

    def _rotate(imgs, d):
        # One rotation per angle setting; the random path keeps the original RNG call order.
        if exact_angles:
            return T.functional.rotate(imgs, d)
        return T.RandomRotation(degrees=d)(imgs)

    rotated_imgs = [_rotate(tensor, d) for d in angles]
    # p=1 makes the horizontal flip deterministic.
    flipped_imgs = [T.RandomHorizontalFlip(p=1)(tensor)]
    flipped_rotated_imgs = [_rotate(flipped_imgs[0], d) for d in angles]

    # A single concatenate over all pieces. Previously each stack was grown one
    # array at a time (quadratic copying) and fewer than two angles crashed.
    pieces = [np.array(img) for img in rotated_imgs]
    pieces += [np.array(img) for img in flipped_rotated_imgs]
    pieces.append(np.array(flipped_imgs[0]))
    augmented_data = np.concatenate(pieces, axis=0)

    return augmented_data, rotated_imgs, flipped_imgs, flipped_rotated_imgs

In [None]:
#np.save('augmented',augmented_data)

In [None]:
augmented_data, rotated_imgs, flipped_imgs, flipped_rotated_imgs = augment(final_1386)

In [None]:
visualize(final_1386,rotated_imgs[0],rotated_imgs[1],rotated_imgs[2],flipped_imgs[0],flipped_rotated_imgs[10])

In [None]:
DIR = f'/home/sarperyn/sarperyurtseven/ProjectFiles/dataset/NIRCAM/1386/1386_psfstack_160.h5'

## Shifting

In [None]:
def random_shift(imgs, patch=150, star_yx=(100, 75), max_offset=150):
    """Move the star window of every frame to a random position by swapping patches.

    For frame `i`, the `patch` x `patch` window whose top-left corner is `star_yx`
    (assumed to contain the star; TODO confirm) is written at a random top-left
    corner `(y[i], x[i])`, with `y` and `x` drawn from `np.random.randint(0, max_offset)`.
    The content previously at that random window is written into the star window.
    Because the windows usually overlap, the part of the old star window not
    covered by the new one can still hold a displaced fragment of the star.

    Parameters
    ----------
    imgs : np.ndarray or torch.Tensor
        Stack indexed as `imgs[i, row, col]`. Frames must be large enough for every
        window to fit (with the defaults, at least 299 x 299).
    patch : int, default 150
        Side length of the swapped square windows.
    star_yx : tuple of int, default (100, 75)
        Top-left (row, col) of the star window.
    max_offset : int, default 150
        Exclusive upper bound of the random top-left row/col.

    Returns
    -------
    Same type as `imgs`: a deep copy with the patches swapped; `imgs` is not modified.
    """
    star_y, star_x = star_yx
    ys, xs = np.random.randint(0, max_offset, (2, len(imgs)))
    samples = copy.deepcopy(imgs)
    stars = copy.deepcopy(samples[:, star_y:star_y + patch, star_x:star_x + patch])

    for i, (y, x) in enumerate(zip(ys, xs)):
        target = (i, slice(y, y + patch), slice(x, x + patch))
        displaced = copy.deepcopy(samples[target])
        # Fill the old star window first, then place the star. The windows
        # overlap, so the reverse order overwrote part of the freshly placed star.
        samples[i, star_y:star_y + patch, star_x:star_x + patch] = displaced
        samples[target] = stars[i]

    return samples

In [None]:
# Flip and randomly shift the first two rotated 1386 stacks.
# NOTE(review): the "_135"/"_270" suffixes do not match augment()'s default
# angle list, whose first two settings are 30 and 45. Confirm naming.
flipped_imgs_135 = [T.RandomHorizontalFlip(p=1)(rotated_imgs[0])]
shifting_135 = random_shift(rotated_imgs[0])
flipped_imgs_270 = [T.RandomHorizontalFlip(p=1)(rotated_imgs[1])]
shifting_270 = random_shift(rotated_imgs[1])

In [None]:
# Randomly shifted copy of the full 1386 stack. `shifted` was never defined in this
# notebook, so this cell (and the 'shifted' dataset written below) raised NameError.
# This mirrors `shifted_1441 = random_shift(final_1441)` in the 1441 section.
shifted = random_shift(final_1386)
# Only entries [0] and [1] are stored (as shifted_rotated_135 / _270). range(135, 270)
# rotated the whole stack 135 times, and its first two entries used degrees 135 and 136.
shifted_rotated = [T.RandomRotation(degrees=d)(shifted) for d in (135, 270)]
shifted_flipped = [T.RandomHorizontalFlip(p=1)(shifted)]

In [None]:
data_1386.create_dataset('rotated_135',data=rotated_imgs[0])
data_1386.create_dataset('rotated_270',data=rotated_imgs[1])
data_1386.create_dataset('flipped',data=flipped_imgs[0])

In [None]:
data_1386.create_dataset('shifted',data=shifted)

data_1386.create_dataset('rotated_flipped_135',data=flipped_imgs_135[0])
data_1386.create_dataset('rotated_flipped_270',data=flipped_imgs_270[0])
data_1386.create_dataset('rotated_shifted_135',data=shifting_135)
data_1386.create_dataset('rotated_shifted_270',data=shifting_270)

data_1386.create_dataset('shifted_rotated_135',data=shifted_rotated[0])
data_1386.create_dataset('shifted_rotated_270',data=shifted_rotated[1])
data_1386.create_dataset('shifted_flipped_270',data=shifted_flipped[0])

In [None]:
data_1386.close()

# 1441

In [None]:
# Stack every dataset of the 1441 file along axis 0, in key order, with a single
# concatenate. Growing the array inside a loop re-copied it on every step and
# failed when the file had fewer than two datasets.
keys_1441 = list(data_1441.keys())
final_1441 = np.concatenate([np.array(data_1441[key]) for key in keys_1441])

In [None]:
final_1441.shape

In [None]:
final_1441 = torch.from_numpy(final_1441)

In [None]:
augmented_data, rotated_imgs, flipped_imgs, flipped_rotated_imgs = augment(final_1441)

In [None]:
rotated_imgs[0].shape

In [None]:
visualize(final_1441,rotated_imgs[0],rotated_imgs[1],rotated_imgs[2],flipped_imgs[0],flipped_rotated_imgs[10])

In [None]:
#data_1441.create_dataset('stacked_gaussian4_10_noise',data=noisy_gauss_1441)
#data_1441.create_dataset('stacked_exp_3_noise',data=noisy_exp_1441)
#data_1441.create_dataset('stacked_rayleigh_5_noise',data=noisy_rayleigh_1441)

In [None]:
rotated_imgs_1441 = [T.RandomRotation(degrees=d)(final_1441) for d in range(135,270)]

In [None]:
flipped_imgs_1441 = [T.RandomHorizontalFlip(p=1)(final_1441)]

In [None]:
rotated135_flipped = [T.RandomHorizontalFlip(p=0.98)(rotated_imgs_1441[0])]
rotated270_flipped = [T.RandomHorizontalFlip(p=0.98)(rotated_imgs_1441[1])]

In [None]:
np.save('augmented_1441',data_14)

In [None]:
shifted_1441 = random_shift(final_1441)

In [None]:
flipped_imgs_135_1441 = [T.RandomHorizontalFlip(p=1)(rotated_imgs_1441[0])]
flipped_imgs_270_1441 = [T.RandomHorizontalFlip(p=1)(rotated_imgs_1441[1])]
shifting_135_1441 = random_shift(rotated_imgs_1441[0])
shifting_270_1441 = random_shift(rotated_imgs_1441[1])

In [None]:
shifted_rotated_1441 = [T.RandomRotation(degrees=d)(shifted_1441) for d in range(135,270)]
shifted_flipped_1441 = [T.RandomHorizontalFlip(p=1)(shifted_1441)]

In [None]:
#data_1441.create_dataset('rotated_135',data=rotated_imgs_1441[0])
#data_1441.create_dataset('rotated_270',data=rotated_imgs_1441[1])
data_1441.create_dataset('flipped',data=flipped_imgs_1441[0])
data_1441.create_dataset('rotated135_flipped',data=rotated135_flipped[0])
data_1441.create_dataset('rotated270_flipped',data=rotated270_flipped[0])

In [None]:
data_1441.create_dataset('shifted',data=shifted_1441)


data_1441.create_dataset('rotated_flipped_135',data=flipped_imgs_135_1441[0])
data_1441.create_dataset('rotated_flipped_270',data=flipped_imgs_270_1441[0])
data_1441.create_dataset('rotated_shifted_135',data=shifting_135_1441)
data_1441.create_dataset('rotated_shifted_270',data=shifting_270_1441)

data_1441.create_dataset('shifted_rotated_135',data=shifted_rotated_1441[0])
data_1441.create_dataset('shifted_rotated_270',data=shifted_rotated_1441[1])
data_1441.create_dataset('shifted_flipped',data=shifted_flipped_1441[0])

In [None]:
data_1441.close()

In [None]:
data_1386.close()

In [None]:
croped_imgs = [T.CenterCrop(size=size)(final_1441) for size in [160,120]]

In [None]:
visualize(final_1441,croped_imgs[0],croped_imgs[1])

In [None]:
DIR = f'/home/sarperyn/sarperyurtseven/ProjectFiles/dataset/NIRCAM/1441/1441_psfstack_160.h5'

In [None]:
with h5py.File(DIR,'w') as hf:
    
    hf.create_dataset('fullpackage1441_160',data=croped_imgs[0])  