In [None]:
from glob import glob
import h5py
import os
import numpy as np
from sklearn.decomposition import PCA, TruncatedSVD, KernelPCA, IncrementalPCA
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import pandas as pd

In [None]:
from matplotlib import rcParams
import matplotlib as mpl
# Notebook-wide figure defaults: size in inches and on-screen / saved resolution.
rcParams.update({
    'figure.figsize': [15, 10],
    'figure.dpi': 80,
    'savefig.dpi': 80,
})

# Draw text, axis labels and ticks in white by default; individual plots
# below override this with explicit color='black' where needed.
COLOR = 'white'
for _rc_key in ('text.color', 'axes.labelcolor', 'xtick.color', 'ytick.color'):
    mpl.rcParams[_rc_key] = COLOR

In [None]:
def pca_func(flatten_final,n_comp,pca,ncols=10):
    """Reduce data to ``n_comp`` components and reconstruct it again.

    Parameters
    ----------
    flatten_final : array-like
        Data handed to the reducer's ``fit_transform`` (callers in this
        notebook pass one flattened image per row).
    n_comp : int
        Number of components to keep.
    pca : str
        Which reducer to use: ``'pca'`` (PCA), ``'tsvd'`` (TruncatedSVD)
        or ``'incpca'`` (IncrementalPCA).
    ncols : int, optional
        Unused; kept only so existing calls keep working.

    Returns
    -------
    approximation
        ``inverse_transform`` of the low-dimensional data, i.e. the
        reconstruction in the input space.
    lower_dimensional_data
        Output of ``fit_transform``.

    Raises
    ------
    ValueError
        If ``pca`` is not one of the supported names.
    """
    if pca == 'pca':
        dimr = PCA(n_components=n_comp)
    elif pca == 'tsvd':
        dimr = TruncatedSVD(n_components=n_comp)
    elif pca == 'incpca':
        dimr = IncrementalPCA(n_components=n_comp)
    else:
        # Previously an unknown name fell through and crashed later with an
        # UnboundLocalError on `dimr`; fail early with a clear message.
        raise ValueError(
            f"Unknown reducer {pca!r}; expected one of 'pca', 'tsvd', 'incpca'."
        )

    lower_dimensional_data = dimr.fit_transform(flatten_final)
    approximation = dimr.inverse_transform(lower_dimensional_data)

    return approximation,lower_dimensional_data
    

In [None]:
def visualize_PCAs(flatten_final, approximation_pca, ncols=5, image_shape=(320, 320)):
    """Show randomly chosen samples next to their reconstruction and residual.

    Draws a 3 x ``ncols`` grid: row 0 is the original image, row 1 the
    reconstruction, row 2 ``sinh`` of the absolute difference between them.
    Each column shows one sample drawn with ``np.random.randint`` (global,
    unseeded RNG, so the selection changes between runs).

    Parameters
    ----------
    flatten_final : ndarray
        Flattened originals, one sample per row.
    approximation_pca : ndarray
        Reconstructions with the same shape as ``flatten_final``.
    ncols : int, optional
        Number of random samples (columns) to display.
    image_shape : tuple of int, optional
        Shape each flattened row is reshaped to before display. Defaults to
        the (320, 320) size that was previously hard-coded.
    """
    residue_pca = np.abs(flatten_final - approximation_pca)
    nsample = flatten_final.shape[0]

    # squeeze=False keeps `axes` two-dimensional even when ncols == 1, so the
    # axes[row][col] indexing below works for every column count.
    _, axes = plt.subplots(nrows=3, ncols=ncols, figsize=(15,10), squeeze=False)

    row_labels = ('Original', 'PCA', 'Residuals')

    for idx in range(ncols):
        rand_num = np.random.randint(0, nsample)

        # NOTE(review): sinh grows exponentially, so large residuals saturate
        # the (0, 255) colour range or overflow; confirm this stretch is the
        # intended one.
        panels = (
            flatten_final[rand_num],
            approximation_pca[rand_num],
            np.sinh(residue_pca[rand_num]),
        )

        for row, (image, label) in enumerate(zip(panels, row_labels)):
            ax = axes[row][idx]
            ax.imshow(image.reshape(image_shape),
                      cmap=plt.cm.gray,
                      clim=(0, 255))
            ax.set_yticks([])
            ax.set_xticks([])

            # Only the left-most column carries the row captions.
            if idx == 0:
                ax.set_ylabel(label, fontsize=12, fontweight='bold', color='black')

    plt.subplots_adjust(wspace=0,hspace=0)
    plt.show()
    

In [None]:
def visualize_latent(low_dim_data,dimension):
    """Scatter-plot low-dimensional data in 2-D or 3-D.

    Parameters
    ----------
    low_dim_data : ndarray
        Array whose transpose is unpacked into ``ax.scatter`` (one column
        per plotted coordinate).
    dimension : int
        2 for a flat scatter plot, 3 for a ``projection='3d'`` plot.

    Raises
    ------
    ValueError
        If ``dimension`` is neither 2 nor 3 (previously the call silently
        produced no figure).
    """
    if dimension not in (2, 3):
        raise ValueError(f"dimension must be 2 or 3, got {dimension!r}")

    if dimension == 2:
        fig, ax = plt.subplots(nrows=1,ncols=1,figsize=(15,15))
    else:
        fig = plt.figure()
        ax = fig.add_subplot(projection='3d')

    ax.scatter(*low_dim_data.T)
    ax.set_xlabel('$pc_1$')
    ax.set_ylabel('$pc_2$')
    if dimension == 3:
        ax.set_zlabel('$pc_3$')

    # Dark figure background so the white default label colours are readable.
    fig.patch.set_facecolor('#423f3b')
    plt.show()


## Data loading

In [None]:
# Root folder of the NIRCAM dataset. Set the NIRCAM_DATA_DIR environment
# variable to run the notebook on another machine; the default is the
# original hard-coded location.
directory = os.environ.get(
    'NIRCAM_DATA_DIR',
    '/home/sarperyn/sarperyurtseven/ProjectFiles/dataset/NIRCAM/',
)

In [None]:
# Collect the HDF5 files that sit one folder below `directory`.
# NOTE(review): without recursive=True, glob treats '**' like '*', so deeper
# nesting is not searched - confirm that is intended.
h5_pattern = os.path.join(directory, '**/*.h5')
h5_files = glob(h5_pattern)

In [None]:
def _pick_h5(paths, program_id, fallback_index):
    """Return the single path containing `program_id`, else `paths[fallback_index]`."""
    matches = [path for path in paths if program_id in path]
    return matches[0] if len(matches) == 1 else paths[fallback_index]


# glob() returns files in arbitrary order, so indexing alone can silently swap
# the two programs. Prefer the path that names the program and fall back to
# the original position when the name is absent or ambiguous.
data_1441 = h5py.File(_pick_h5(h5_files, '1441', 0), 'r')
data_1386 = h5py.File(_pick_h5(h5_files, '1386', 1), 'r')

## 1386

In [None]:
# Stack every dataset of the 1386 file along the first axis. A single
# concatenate avoids re-copying the growing array on each loop iteration and
# also works when the file holds only one dataset.
keys_1386 = list(data_1386.keys())
final_1386 = np.concatenate([np.array(data_1386[key]) for key in keys_1386])

final_1386.shape

In [None]:
#approx_pca_1386,lowdim_pca_1386             = pca_func(flatten_final_1386,2,pca='pca')
#approx_tsvd_1386,lowdim_tsvd_1386           = pca_func(flatten_final_1386,2,pca='tsvd')
#approx_incpca_1386,lowdim_incpca_1386       = pca_func(flatten_final_1386,2,pca='incpca')

In [None]:
#visualize_PCAs(flatten_final_1386,approx_pca_1386,approx_tsvd_1386,approx_incpca_1386)

In [None]:
# Flatten each image to one row. reshape (unlike np.resize) raises if the
# images are not 320x320 instead of silently repeating or truncating pixels.
flatten_final_1386 = final_1386.reshape(final_1386.shape[0], 320*320)


def center_function(x):
    """Subtract the global (scalar) mean of `x` from every element."""
    return x - x.mean()


centered_data_1386 = center_function(flatten_final_1386)

In [None]:
# Two-component PCA of the centred 1386 stack: reconstruction + latent coordinates.
approx_pca_1386, lowdim_pca_1386 = pca_func(centered_data_1386, n_comp=2, pca='pca')

In [None]:
# Min-max rescale the latent coordinates into [interval_min, interval_max].
# One global min/max is used, so both components share the same scale.
interval_min, interval_max = 0, 1
pc_lo_1386 = np.min(lowdim_pca_1386)
pc_hi_1386 = np.max(lowdim_pca_1386)
unit_lowdimpca_1386 = (lowdim_pca_1386 - pc_lo_1386) / (pc_hi_1386 - pc_lo_1386)
scaled_lowdimpca = unit_lowdimpca_1386 * (interval_max - interval_min) + interval_min

In [None]:
# Scatter the rescaled 1386 latent space and write it to disk.
pc1_1386 = scaled_lowdimpca[:, 0]
pc2_1386 = scaled_lowdimpca[:, 1]

plt.figure(figsize=(8, 8))
plt.scatter(pc1_1386, pc2_1386)
# Explicit black overrides the notebook-wide white text colour.
plt.xlabel('$pc_1$', color='black')
plt.ylabel('$pc_2$', color='black')
plt.xticks(color='black')
plt.yticks(color='black')
plt.savefig('pca_1386_latent_nircam.png', dpi=300, bbox_inches='tight', pad_inches=0);

In [None]:
# approx_pca_1386 reconstructs the mean-subtracted data, so add the global
# mean back before comparing with the raw images; otherwise the residual row
# contains that constant offset on top of the real reconstruction error.
visualize_PCAs(flatten_final_1386, approx_pca_1386 + flatten_final_1386.mean())

## 1441

In [None]:
# Names of the datasets stored in the 1441 file.
keys_1441 = list(data_1441.keys())

In [None]:
# Stack every dataset of the 1441 file along the first axis. A single
# concatenate avoids re-copying the growing array on each loop iteration and
# also works when the file holds only one dataset.
keys_1441 = list(data_1441.keys())
final_1441 = np.concatenate([np.array(data_1441[key]) for key in keys_1441])

In [None]:
# Flatten each image to one row. reshape (unlike np.resize) raises if the
# images are not 320x320 instead of silently repeating or truncating pixels.
flatten_final_1441 = final_1441.reshape(final_1441.shape[0], 320*320)
flatten_final_1441.shape

In [None]:
#approx_tsvd_1441,lowdim_tsvd_1441           = pca_func(flatten_final_1441,3,pca='tsvd')
#approx_incpca_1441,lowdim_incpca_1441       = pca_func(flatten_final_1441,3,pca='incpca')

In [None]:
#visualize_PCAs(flatten_final_1441,approx_pca_1441,approx_tsvd_1441,approx_incpca_1441)

In [None]:
def center_function(x):
    """Subtract the global (scalar) mean of `x` from every element."""
    return x - x.mean()


centered_data_1441 = center_function(flatten_final_1441)

In [None]:
# Two-component PCA of the centred 1441 stack: reconstruction + latent coordinates.
approx_pca_1441, lowdim_pca_1441 = pca_func(centered_data_1441, n_comp=2, pca='pca')

In [None]:
# Min-max rescale the latent coordinates into [interval_min, interval_max].
# One global min/max is used, so both components share the same scale.
interval_min, interval_max = 0, 1
pc_lo_1441 = np.min(lowdim_pca_1441)
pc_hi_1441 = np.max(lowdim_pca_1441)
unit_lowdimpca_1441 = (lowdim_pca_1441 - pc_lo_1441) / (pc_hi_1441 - pc_lo_1441)
scaled_lowdimpca_1441 = unit_lowdimpca_1441 * (interval_max - interval_min) + interval_min

In [None]:
# Scatter the rescaled 1441 latent space and write it to disk.
pc1_1441 = scaled_lowdimpca_1441[:, 0]
pc2_1441 = scaled_lowdimpca_1441[:, 1]

plt.figure(figsize=(8, 8))
plt.scatter(pc1_1441, pc2_1441)
# Explicit black overrides the notebook-wide white text colour.
plt.xlabel('$pc_1$', color='black')
plt.ylabel('$pc_2$', color='black')
plt.xticks(color='black')
plt.yticks(color='black')
plt.savefig('pca_1441_latent_nircam.png', dpi=300, bbox_inches='tight', pad_inches=0);

In [None]:
#visualize_latent(scaled_lowdimpca_1441,3)

In [None]:
# approx_pca_1441 reconstructs the mean-subtracted data, so add the global
# mean back before comparing with the raw images; otherwise the residual row
# contains that constant offset on top of the real reconstruction error.
visualize_PCAs(flatten_final_1441, approx_pca_1441 + flatten_final_1441.mean())

In [None]:
#visualize_latent(lowdim_incpca_1441,3)

Now we can run PCA on each PSF stack in 1441 separately, computing them one at a time.

In [None]:
# Load the first three datasets of the 1441 file as separate arrays.
batch1, batch2, batch3 = [
    np.array(data_1441[keys_1441[position]]) for position in range(3)
]

In [None]:
# Flatten each image to one row. reshape (unlike np.resize) raises if the
# images are not 320x320 instead of silently repeating or truncating pixels.
flatten_batch1 = batch1.reshape(batch1.shape[0], 320*320)
flatten_batch2 = batch2.reshape(batch2.shape[0], 320*320)
flatten_batch3 = batch3.reshape(batch3.shape[0], 320*320)

In [None]:
# Subtract each batch's own global mean.
centered_data_1441_1, centered_data_1441_2, centered_data_1441_3 = map(
    center_function,
    (flatten_batch1, flatten_batch2, flatten_batch3),
)

In [None]:
# Independent two-component PCA per batch: (reconstruction, latent coordinates).
(
    (approx1_pca_1441, low_dim_pca_1441_1),
    (approx2_pca_1441, low_dim_pca_1441_2),
    (approx3_pca_1441, low_dim_pca_1441_3),
) = [
    pca_func(centered_batch, 2, pca='pca')
    for centered_batch in (centered_data_1441_1, centered_data_1441_2, centered_data_1441_3)
]

In [None]:
# approx1_pca_1441 reconstructs the mean-subtracted batch, so add the batch
# mean back before comparing with the raw images; otherwise the residual row
# contains that constant offset on top of the real reconstruction error.
visualize_PCAs(flatten_batch1, approx1_pca_1441 + flatten_batch1.mean())

In [None]:
# Min-max rescale the batch-1 latent coordinates into [interval_min, interval_max]
# (one global min/max, so both components share the same scale), then plot them.
interval_min, interval_max = 0, 1
pc_lo_1441_1 = np.min(low_dim_pca_1441_1)
pc_hi_1441_1 = np.max(low_dim_pca_1441_1)
unit_lowdimpca_1441_1 = (low_dim_pca_1441_1 - pc_lo_1441_1) / (pc_hi_1441_1 - pc_lo_1441_1)
scaled_lowdimpca_1441_1 = unit_lowdimpca_1441_1 * (interval_max - interval_min) + interval_min

pc1_1441_1 = scaled_lowdimpca_1441_1[:, 0]
pc2_1441_1 = scaled_lowdimpca_1441_1[:, 1]

plt.figure(figsize=(8, 8))
plt.scatter(pc1_1441_1, pc2_1441_1)
# Explicit black overrides the notebook-wide white text colour.
plt.title('PCA', color='black')
plt.xlabel('$pc_1$', color='black')
plt.ylabel('$pc_2$', color='black')
plt.xticks(color='black')
plt.yticks(color='black');

In [None]:
# Only the PCA reconstruction exists for this batch (the tsvd/incpca arrays
# are never computed) and visualize_PCAs takes a single reconstruction.
# Add the batch mean back because the PCA was fitted on mean-subtracted data.
visualize_PCAs(flatten_batch2, approx2_pca_1441 + flatten_batch2.mean())

In [None]:
# The only latent representation computed for batch 2 is the two-component
# PCA, so plot that in 2-D (low_dim_tsvd_1441_2 is never defined).
visualize_latent(low_dim_pca_1441_2, 2)

In [None]:
# Only the PCA reconstruction exists for this batch (the tsvd/incpca arrays
# are never computed) and visualize_PCAs takes a single reconstruction.
# Add the batch mean back because the PCA was fitted on mean-subtracted data.
visualize_PCAs(flatten_batch3, approx3_pca_1441 + flatten_batch3.mean())

In [None]:
# The only latent representation computed for batch 3 is the two-component
# PCA, so plot that in 2-D (low_dim_tsvd_1441_3 is never defined).
visualize_latent(low_dim_pca_1441_3, 2)