In [1]:
import numpy as np
import pandas as pd

import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import timm

import random
import os
from astropy.io import fits
from astropy.table import Table

import matplotlib.pyplot as plt
from astropy.visualization import make_lupton_rgb
plt.style.use('dark_background')

In [2]:
def make_plot_all(objects, data, Q, stretch, minimum):
    """Show every cutout as a Lupton RGB image, eight panels per figure.

    Parameters
    ----------
    objects : indexable sequence
        One entry per object; entry[0], entry[1] and entry[2] are the three
        band images. They are handed to ``make_lupton_rgb`` in the order
        entry[2], entry[1], entry[0] (red, green, blue arguments).
    data : table-like
        ``data['new_id'].iloc[k]`` supplies the title of panel ``k``.
    Q, stretch, minimum
        Forwarded unchanged to ``make_lupton_rgb``.

    A new 16x9 figure is opened for each group of up to eight objects and
    displayed with ``plt.show()`` once the group has been drawn.
    """
    panels_per_figure = 8
    total = len(objects)

    for first in range(0, total, panels_per_figure):
        plt.figure(figsize=(16,9))
        last = min(first + panels_per_figure, total)
        for slot, idx in enumerate(range(first, last), start=1):
            plt.subplot(1, panels_per_figure, slot)
            plt.title(data['new_id'].iloc[idx])
            bands = objects[idx]
            rgb = make_lupton_rgb(bands[2], bands[1], bands[0],
                                  Q=Q, stretch=stretch, minimum=minimum)
            plt.imshow(rgb, aspect='equal')
            # Hide the axis ticks; only the image and its id matter here.
            plt.xticks([], [])
            plt.yticks([], [])
        plt.show()

In [3]:
def write_fit_file(name, x, data):
    """Write images and their metadata table to ``<name>.fits``.

    Parameters
    ----------
    name : str
        Output path without extension; ``'.fits'`` is appended.
    x : array-like
        Image data stored in an image extension named ``IMAGE``.
    data : pandas.DataFrame
        Metadata table; ``new_id`` is cast to int and ``class`` to str before
        it is converted to an astropy table and stored as a binary table.

    The file layout is: empty primary HDU, image HDU, binary-table HDU.
    An existing file with the same name is overwritten.
    """
    catalog = data.astype({'new_id': int, 'class': str})

    hdus = [
        fits.PrimaryHDU(),
        fits.ImageHDU(x, name="IMAGE"),
        fits.BinTableHDU(data=Table.from_pandas(catalog)),
    ]

    output = fits.HDUList(hdus)
    output.writeto(name + '.fits', overwrite=True)
    output.close()


In [4]:
# Collect the training images and labels of every class into one image stack
# (`images_save`) and one metadata table (`data_save`).
# Known strong lenses are not included here.
names = ['Single', 'Ring', 'Smooth', 'Companions', 'SDSS_Spirals', 'DES_Spirals', 'Crowded',
        'Most_negatives'] #'Double'
file_path = '/Users/jimenagonzalez/research/DSPL/SpaceWarps_Inspection/training_sample/coadds/'
numpix = 45

# Maximum number of objects kept per class; the two spiral classes are capped
# at half the default.
default_num = 330
num_per_class = {'DES_Spirals': 165, 'SDSS_Spirals': 165}

image_chunks = []
data_chunks = []

for name in names:
    num = num_per_class.get(name, default_num)

    # The context manager closes the FITS file even if reading fails.
    with fits.open(file_path + 'new_ids_' + name + '.fits') as hdu_list:
        data = pd.DataFrame(hdu_list[2].data).astype({'new_id': int})

        # Row labels listed in the per-class "remove" file are dropped.
        remove = pd.read_csv(file_path + 'remove_' + name + '.csv', dtype = {'id': int})

        # .copy() so that adding the 'class' column does not write into a slice.
        data = data.drop(remove['id']).iloc[0:num].copy()

        # The surviving row labels index the image array; fancy indexing
        # copies the selected images into memory before the file is closed.
        images = hdu_list[1].data[data.index.to_numpy()]

    # Scalar assignment also works when fewer than `num` rows are left.
    data['class'] = name

    image_chunks.append(images)
    data_chunks.append(data)

    #make_plot_all(images, data, 9, 42, (0, 0, 0))

# Concatenate once instead of growing the arrays inside the loop. The empty
# float64 seed keeps the float64 result dtype and (3, numpix, numpix) shape
# check that the previous np.zeros placeholder row provided.
images_save = np.concatenate([np.zeros((0, 3, numpix, numpix)), *image_chunks], axis = 0)

# DataFrame.append was removed in pandas 2.0; pd.concat keeps the original
# row labels exactly as append did.
if data_chunks:
    data_save = pd.concat(data_chunks)
else:
    data_save = pd.DataFrame(columns = ['new_id', 'class'])

# Keep 'new_id' and 'class' as the leading columns.
leading = ['new_id', 'class']
data_save = data_save[leading + [col for col in data_save.columns if col not in leading]]

assert len(images_save) == len(data_save), "images and metadata are out of sync"
print(len(images_save), len(data_save))

2310 2310


In [5]:
# Display the metadata table assembled so far (columns: new_id, class).
data_save

Unnamed: 0,new_id,class
1,1,Single
2,2,Single
3,3,Single
4,4,Single
6,6,Single
...,...,...
374,374,Most_negatives
375,375,Most_negatives
376,376,Most_negatives
377,377,Most_negatives


In [6]:
# Adding the train sample of known catalogs: Jacobs
# NOTE: this cell extends images_save / data_save in place, so running it a
# second time appends the sample again; re-run the notebook from the top instead.
with fits.open('real_lenses/Jacobs_train.fits') as hdu_list:
    images_jacobs = hdu_list[1].data
    # Catalogue table of the known lenses (not used further in this cell).
    data_jacobs = Table(hdu_list[2].data).to_pandas()

# Known lenses get fresh ids 0..N-1 and their own class label.
num_known = len(images_jacobs)
data_tmp = pd.DataFrame(
    {'new_id': np.arange(num_known), 'class': ['Known_jacobs'] * num_known},
    columns = data_save.columns,
)

images_save = np.append(images_save, images_jacobs, axis = 0)
# DataFrame.append was removed in pandas 2.0; pd.concat is the replacement.
data_save = pd.concat([data_save, data_tmp])

print(len(images_jacobs))
print(len(images_save), len(data_save))

152
2462 2462


In [7]:
# Adding the train sample of known catalogs: O'Donnell
# NOTE: this cell extends images_save / data_save in place, so running it a
# second time appends the sample again; re-run the notebook from the top instead.
# The *_jacobs variable names are reused here for the O'Donnell ("Jack") data.
with fits.open('real_lenses/Jack_train.fits') as hdu_list:
    images_jacobs = hdu_list[1].data
    # Catalogue table of the known lenses (not used further in this cell).
    data_jacobs = Table(hdu_list[2].data).to_pandas()

# Known lenses get fresh ids 0..N-1 and their own class label.
num_known = len(images_jacobs)
data_tmp = pd.DataFrame(
    {'new_id': np.arange(num_known), 'class': ['Known_jack'] * num_known},
    columns = data_save.columns,
)

images_save = np.append(images_save, images_jacobs, axis = 0)
# DataFrame.append was removed in pandas 2.0; pd.concat is the replacement.
data_save = pd.concat([data_save, data_tmp])

print(len(images_jacobs))
print(len(images_save), len(data_save))

48
2510 2510


In [8]:
# Write the combined image stack and metadata table to disk. write_fit_file
# appends '.fits' to the name and overwrites an existing file.
write_fit_file('complete_training_sample', images_save, data_save)