In [1]:
import random
import PIL
from matplotlib import pyplot as plt
import numpy as np
import albumentations as A
import os
import glob
import torch
from astropy.io import fits
import umap
import einops
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
from torchvision import transforms
from itertools import product
import numpy as np
from tqdm import tqdm
import pandas as pd
import glob
from PIL import Image
import os


  @numba.jit()
  @numba.jit()
  @numba.jit()
  from .autonotebook import tqdm as notebook_tqdm
  @numba.jit()


In [2]:
injections = glob.glob('/data/scratch/bariskurtkaya/dataset/NIRCAM/1386/injections/*.png')[500:600]
augmentations = glob.glob('/data/scratch/bariskurtkaya/dataset/NIRCAM/1386/sci_imgs/*')[500:600]

In [None]:
def plot_results(imgs):


    _, axes = plt.subplots(nrows=10,ncols=10,figsize=(20,20))

    for i, (row,col) in enumerate(product(range(10),range(10))):


        axes[row][col].imshow(imgs[i])
    
        axes[row][col].set_yticks([])
        axes[row][col].set_xticks([])

    plt.subplots_adjust(wspace=0,hspace=0)
    plt.savefig(f'fig_.jpg',format='jpg',dpi=1000)
    plt.show() 

In [None]:
imgs = []

for i in injections:

    img = PIL.Image.open(i).convert('L')
    img = np.array(img)
    imgs.append(img)

imgs = np.concatenate(np.expand_dims(imgs, axis=0),axis=0)

In [None]:
plot_results(imgs)

In [3]:
def hex_to_RGB(hex_str):
    """ #FFFFFF -> [255,255,255]"""
    #Pass 16 to the integer function for change of base
    return [int(hex_str[i:i+2], 16) for i in range(1,6,2)]

def get_color_gradient(c1, c2, n):
    """
    Given two hex colors, returns a color gradient
    with n colors.
    """
    assert n > 1
    c1_rgb = np.array(hex_to_RGB(c1))/255
    c2_rgb = np.array(hex_to_RGB(c2))/255
    mix_pcts = [x/(n-1) for x in range(n)]
    rgb_colors = [((1-mix)*c1_rgb + (mix*c2_rgb)) for mix in mix_pcts]
    return ["#" + "".join([format(int(round(val*255)), "02x") for val in item]) for item in rgb_colors]



def plot_pca_comps(pca_comps):

    color1 = "#D4CC47"
    color2 = "#7C4D8B"
    num_points = 200
    plt.figure(figsize=(15,15))

    plt.scatter(pca_comps[:,0],pca_comps[:,1],
            color=get_color_gradient(color1, color2, num_points))
    plt.colorbar()
    plt.title("Gradient Scatter")
    #plt.savefig(f'chair_{index}.png',format='png',dpi=100)
    plt.show()
    plt.close()

In [4]:
def visualize(image):
    plt.figure(figsize=(10, 10))
    plt.axis('off')
    plt.imshow(image)

In [5]:
def get_stage3_products(suffix,directory):
    return glob.glob(os.path.join(directory, f'*{suffix}.fits'))

In [6]:

class Exonet(nn.Module):
    
    def __init__(self, convdim_enc_outputs:list, convdim_dec_outputs:list, kernels_enc:list, strides_enc:list, kernels_dec:list, strides_dec:list):
        
        super(Exonet,self).__init__()
        
        self.convdim_enc = convdim_enc_outputs
        self.convdim_dec = convdim_dec_outputs
        self.kernels_enc = kernels_enc
        self.strides_enc = strides_enc
        self.kernels_dec = kernels_dec
        self.strides_dec = strides_dec
        self.C       = 8 
        
        self.encoder  = nn.Sequential(
                        
            nn.Conv2d(in_channels=1, out_channels=self.C, stride=self.strides_enc[0], kernel_size=self.kernels_enc[0]), #1
            nn.BatchNorm2d(self.C),
            nn.LeakyReLU(),
            
            nn.Conv2d(in_channels=self.C, out_channels=self.C*2, stride=self.strides_enc[1], kernel_size=self.kernels_enc[1]), #2
            nn.BatchNorm2d(self.C*2),
            nn.LeakyReLU(),
            
            nn.Conv2d(in_channels=self.C*2, out_channels=self.C*2, stride=self.strides_enc[2], kernel_size=self.kernels_enc[2]), #3
            nn.BatchNorm2d(self.C*2),
            nn.LeakyReLU(),
            
            nn.Conv2d(in_channels=self.C*2, out_channels=self.C*2, stride=self.strides_enc[3], kernel_size=self.kernels_enc[3]), #4 
            nn.BatchNorm2d(self.C*2),
            nn.LeakyReLU(),
            
            nn.Conv2d(in_channels=self.C*2, out_channels=self.C*4, stride=self.strides_enc[4], kernel_size=self.kernels_enc[4]), #5
            nn.BatchNorm2d(self.C*4),
            nn.LeakyReLU(),
            
            nn.Conv2d(in_channels=self.C*4, out_channels=self.C*8, stride=self.strides_enc[5], kernel_size=self.kernels_enc[5]), #6
            nn.BatchNorm2d(self.C*8),
            nn.LeakyReLU(),
            
            nn.Conv2d(in_channels=self.C*8, out_channels=self.C*16, stride=self.strides_enc[6], kernel_size=self.kernels_enc[6]), #7
            nn.BatchNorm2d(self.C*16),
            nn.LeakyReLU(),
            
        
        ) 
        
        self.fc1 = nn.Sequential(
        
                nn.Linear((self.C*16)*convdim_outputs[-1]**2,4096),
                nn.SiLU(),
                nn.Linear(4096,2048),
                nn.SiLU(),
                nn.Linear(2048,1024),
                nn.SiLU(),
        )

        self.latent = nn.Linear(1024,1024)

        self.fc2   = nn.Sequential(

                nn.Linear(1024,2048),
                nn.SiLU(),
                nn.Linear(2048,4096),
                nn.SiLU(),
                nn.Linear(4096,(self.C*16)*convdim_outputs[-1]**2),
                nn.SiLU(),

        )

        self.decoder = nn.Sequential(

                        
            nn.ConvTranspose2d(in_channels=self.C*16, out_channels=self.C*8, stride=self.strides_dec[0], kernel_size=self.kernels_dec[0]), #1
            nn.BatchNorm2d(self.C*8),
            nn.SiLU(),
            
            nn.ConvTranspose2d(in_channels=self.C*8, out_channels=self.C*4, stride=self.strides_dec[1], kernel_size=self.kernels_dec[1]), #2
            nn.BatchNorm2d(self.C*4),
            nn.SiLU(),
            
            nn.ConvTranspose2d(in_channels=self.C*4, out_channels=self.C*2, stride=self.strides_dec[2], kernel_size=self.kernels_dec[2]), #3
            nn.BatchNorm2d(self.C*2),
            nn.SiLU(),
            
            nn.ConvTranspose2d(in_channels=self.C*2, out_channels=self.C*2, stride=self.strides_dec[3], kernel_size=self.kernels_dec[3]), #4 
            nn.BatchNorm2d(self.C*2),
            nn.SiLU(),
            
            nn.ConvTranspose2d(in_channels=self.C*2, out_channels=self.C, stride=self.strides_dec[4], kernel_size=self.kernels_dec[4]), #5
            nn.BatchNorm2d(self.C),
            nn.SiLU(),
            
            nn.ConvTranspose2d(in_channels=self.C, out_channels=self.C, stride=self.strides_dec[5], kernel_size=self.kernels_dec[5]), #6
            nn.BatchNorm2d(self.C),
            nn.SiLU(),
            
            nn.ConvTranspose2d(in_channels=self.C, out_channels=1, stride=self.strides_dec[6], kernel_size=self.kernels_dec[6]), #7
            nn.BatchNorm2d(1),
            nn.SiLU(),
            
        ) 
        
    def forward(self,x):
        
        bs       = x.size(0)

        x       = self.encoder(x)
        x       = x.view(x.size(0),-1)

        x       = self.fc1(x)
        latents = self.latent(x)
        x       = self.fc2(latents)

        x       = x.view(bs,self.C*16,convdim_outputs[-1],convdim_outputs[-1])
        x       = self.decoder(x)
        
        return x

In [7]:
model =torch.load('/home/sarperyn/sarperyurtseven/ProjectFiles/models/model_exp-aug0_epoch-139.pt')

In [8]:
class SynDataset(Dataset):

    def __init__(self, image_paths):

        self.image_paths = image_paths
        self.transform  = transforms.Compose([
        transforms.ToTensor(),
        ])


    def __len__(self,):

        return len(self.image_paths)

    def __getitem__(self, index):
        
        image_path = self.image_paths[index]
        image      = Image.open(image_path).convert('L')
        image = self.transform(image)

        if torch.isnan(image).any().item():
            torch.nan_to_num(image)
            
        return image
    

In [None]:
# injections = glob.glob('/data/scratch/bariskurtkaya/dataset/NIRCAM/1386/injections/*.png')[:100]
# augmentations = glob.glob('/data/scratch/bariskurtkaya/dataset/NIRCAM/1386/sci_imgs/*')[:100]
# test = injections + augmentations
# random.shuffle(test)

In [15]:
test = []

with open('/home/sarperyn/sarperyurtseven/ProjectFiles/notebooks/text_dirs.txt','r') as file:

    img = file.readlines()
    for i in img:

        test.append(i)

In [16]:
test

['/data/scratch/bariskurtkaya/dataset/NIRCAM/1386/sci_imgs/jw01386-a3001_t004_nircam_f360m-maskrnd-sub320a335r_psfstack_sci_262_horflip_rot_shift.jpg\n',
 '/data/scratch/bariskurtkaya/dataset/NIRCAM/1386/sci_imgs/jw01386-a3001_t004_nircam_f360m-maskrnd-sub320a335r_psfstack_sci_72_horflip_rot.jpg\n',
 '/data/scratch/bariskurtkaya/dataset/NIRCAM/1386/sci_imgs/jw01386-a3001_t004_nircam_f300m-maskrnd-sub320a335r_psfstack_sci_436_verflip_horflip_rot.jpg\n',
 '/data/scratch/bariskurtkaya/dataset/NIRCAM/1386/sci_imgs/jw01386-a3001_t004_nircam_f300m-maskrnd-sub320a335r_psfstack_sci_278.jpg\n',
 '/data/scratch/bariskurtkaya/dataset/NIRCAM/1386/sci_imgs/jw01386-a3001_t004_nircam_f360m-maskrnd-sub320a335r_psfstack_sci_109_horflip.jpg\n',
 '/data/scratch/bariskurtkaya/dataset/NIRCAM/1386/sci_imgs/jw01386-a3001_t004_nircam_f410m-maskrnd-sub320a335r_psfstack_sci_16_shift_rot_ver.jpg\n',
 '/data/scratch/bariskurtkaya/dataset/NIRCAM/1386/sci_imgs/jw01386-a3001_t004_nircam_f300m-maskrnd-sub320a335r_psf

In [17]:
syndata        = SynDataset(image_paths=test)
syndata_loader = DataLoader(dataset=syndata, batch_size=200, shuffle=True)
batch = next(iter(syndata_loader)).to('cuda:2')

FileNotFoundError: [Errno 2] No such file or directory: '/data/scratch/bariskurtkaya/dataset/NIRCAM/1386/sci_imgs/jw01386-a3001_t002_nircam_f300m-maskrnd-sub320a335r_psfstack_sci_16_horflip_shift.jpg\n'

In [None]:
x       = model.encoder(batch)
x       = x.view(x.size(0),-1)
x       = model.fc1(x)
latents = model.latent(x)

In [None]:
umap_comps = umap.UMAP().fit(latents.detach().cpu().numpy())

In [None]:
plot_pca_comps(umap_comps.embedding_)

In [None]:
def create_batch_numpy(img_paths):
    img_numpy = []
    for img_dir in img_paths:
        image   = PIL.Image.open(img_dir)
        image   = np.expand_dims(np.array(image),axis=0)
        img_numpy.append(image)
    img_numpy = np.concatenate(img_numpy)

    return img_numpy

In [None]:
def get_dict(files_dict,img_dirs):
    
    for file_name, dirs in files_dict.items():

        for img_dir in img_dirs:
            
            f_name = '_'.join(img_dir.split('/')[-1].split('.')[0].split('_')[:5])

            if file_name == f_name:

                dirs.append(img_dir)

    return files_dict
        

In [None]:
def create_psf_dict(file_names):

    files_dict = {}

    for file in file_names:
    
        name = file.split('/')[-1].split('.')[0]
        files_dict[name] = []

    return files_dict

In [None]:
def get_batch_dict(final_dict):

    batch_dict = {}

    for file_name, dirs in final_dict.items():

        batch = create_batch_numpy(dirs)

        batch_dict[file_name] = batch

    return batch_dict


In [None]:
def save_arrays_to_fits(psfstacks_nircam_1386, batch_dict):

    for idx, fits_file in enumerate(psfstacks_nircam_1386):

        with fits.open(fits_file, mode='update') as hdul:
            
            hdul[1].data = batch_dict[psfstacks_nircam_1386[idx].split('.')[0].split('/')[-1]]
            hdul.flush()  # changes are written back to original.fits

In [None]:
img_dirs = sorted(glob.glob('/data/scratch/bariskurtkaya/dataset/NIRCAM/1386/sci_imgs/*'))

In [None]:
len(img_dirs)

In [None]:
directory_1386_nircam = f'/data/scratch/bariskurtkaya/dataset/NIRCAM/1386/mastDownload/JWST'
psfstacks_nircam_1386 = get_stage3_products(suffix='psfstack',directory=directory_1386_nircam)

In [None]:
file_names = create_psf_dict(psfstacks_nircam_1386)

In [None]:
final_dict = get_dict(file_names, img_dirs)

In [None]:
batch_dict = get_batch_dict(final_dict)

In [None]:
save_arrays_to_fits(psfstacks_nircam_1386, batch_dict)

In [None]:
f1 = fits.open(psfstacks_nircam_1386[0])