In [1]:
import numpy as np
import pandas as pd

import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import timm

import random
import os
from astropy.io import fits
from astropy.table import Table

import matplotlib.pyplot as plt
from astropy.visualization import make_lupton_rgb
plt.style.use('dark_background')

In [2]:
def make_plot_all(objects, data, Q, stretch, minimum):
    """Show every object as a Lupton RGB thumbnail, eight per figure row.

    Parameters
    ----------
    objects : indexable collection of image cubes; for each cube, plane 2 is
        passed to ``make_lupton_rgb`` as red, plane 1 as green, plane 0 as blue.
    data : table whose ``'new_id'`` column supplies the panel titles; read
        positionally with ``.iloc``, so row order must match ``objects``.
    Q, stretch, minimum : forwarded unchanged to ``make_lupton_rgb``.
    """
    panels_per_figure = 8
    total = len(objects)

    # One figure per group of eight consecutive objects; the last may be partial.
    for first in range(0, total, panels_per_figure):
        plt.figure(figsize=(16, 9))
        last = min(first + panels_per_figure, total)
        for slot, idx in enumerate(range(first, last), start=1):
            plt.subplot(1, panels_per_figure, slot)
            plt.title(data['new_id'].iloc[idx])
            cube = objects[idx]
            rgb = make_lupton_rgb(cube[2], cube[1], cube[0],
                                  Q=Q, stretch=stretch, minimum=minimum)
            plt.imshow(rgb, aspect='equal')
            plt.xticks([], [])
            plt.yticks([], [])
        plt.show()
            
def make_one_plot(objects, title, data, Q, stretch, minimum):
    """Render a single image cube as a Lupton RGB picture and save it to disk.

    Parameters
    ----------
    objects : one image cube; plane 2 is passed to ``make_lupton_rgb`` as red,
        plane 1 as green and plane 0 as blue.
    title : file name of the picture; it is appended to the module-level
        ``png_path`` string, which must be defined before this is called.
    data : not used by this function (kept for call compatibility).
    Q, stretch, minimum : forwarded unchanged to ``make_lupton_rgb``.
    """
    plt.figure(figsize=(4, 4))
    red, green, blue = objects[2], objects[1], objects[0]
    composite = make_lupton_rgb(red, green, blue,
                                Q=Q, stretch=stretch, minimum=minimum)
    plt.imshow(composite, aspect='equal')
    # Empty tick lists hide both axes so only the image is saved.
    plt.xticks([], [])
    plt.yticks([], [])
    destination = png_path + title
    plt.savefig(destination, bbox_inches='tight')
    # Close instead of show: this is called in long loops and figures would pile up.
    plt.close()

In [3]:
def write_fit_file(name, x, data):
    """Write an image array and its catalog to ``<name>.fits``.

    The file holds three HDUs in this order: an empty primary HDU, an image
    HDU named ``IMAGE`` containing ``x``, and a binary table built from
    ``data``. Before conversion the ``new_id`` column is cast to ``int`` and
    the ``class`` column to ``str``. An existing file is overwritten.
    """
    catalog = data.astype({'new_id': int, 'class': str})

    hdus = [
        fits.PrimaryHDU(),
        fits.ImageHDU(x, name="IMAGE"),
        fits.BinTableHDU(data=Table.from_pandas(catalog)),
    ]
    hdu_list = fits.HDUList(hdus)
    hdu_list.writeto(name + '.fits', overwrite=True)
    hdu_list.close()

In [4]:
# Saving all training images into one file (not including known strong lenses)
names = ['Single', 'Ring', 'Smooth', 'Companions', 'SDSS_Spirals', 'DES_Spirals', 'Crowded',
        'Most_negatives'] #'Double'
# NOTE(review): absolute, machine-specific path; point it at your own copy of the coadds.
file_path = '/Users/jimenagonzalez/research/DSPL/SpaceWarps_Inspection/training_sample/coadds/'
numpix = 45

# Collect per-class pieces and join them once at the end instead of growing the
# array/frame inside the loop. The empty float64 seed keeps the result dtype
# (float64) and the (3, numpix, numpix) shape check of the former zero-row trick.
image_chunks = [np.zeros((0, 3, numpix, numpix))]
data_chunks = []

for class_name in names:
    # The two spiral classes contribute half as many objects as the others.
    num = 165 if class_name in ('DES_Spirals', 'SDSS_Spirals') else 330

    with fits.open(file_path + 'new_ids_' + class_name + '.fits') as hdu_list:
        images = hdu_list[1].data
        data = pd.DataFrame(hdu_list[2].data).astype({'new_id': int})

    # `drop` removes rows by index label, so the ids listed in the remove file
    # are matched against the row index — presumably equal to new_id; verify.
    remove = pd.read_csv(file_path + 'remove_' + class_name + '.csv', dtype={'id': int})
    data = data.drop(remove['id'])

    # Keep the first `num` surviving objects; .copy() so the class column is
    # written to an independent frame rather than a slice of `data`.
    images = images[data.index][0:num]
    data = data.iloc[0:num].copy()
    if len(data) < num:
        raise ValueError(
            f'{class_name}: only {len(data)} objects left after removals, expected {num}'
        )
    data['class'] = class_name

    image_chunks.append(images)
    data_chunks.append(data)

images_save = np.concatenate(image_chunks, axis=0)
# DataFrame.append was removed in pandas 2.0; pd.concat is the replacement.
data_save = pd.concat(data_chunks).reset_index(drop=True)
print(len(images_save), len(data_save))

2310 2310


In [5]:
# Preview the first rows of the combined catalog (columns: new_id, class).
data_save.head()

Unnamed: 0,new_id,class
0,1,Single
1,2,Single
2,3,Single
3,4,Single
4,6,Single


In [6]:
png_path = 'sample/pngs/'
# Neither `to_csv` nor `savefig` creates missing folders, so make sure both
# 'sample/' and 'sample/pngs/' exist before anything is written.
os.makedirs(png_path, exist_ok=True)

# (Q, stretch, minimum) handed to make_one_plot for the _1, _2 and _3 picture
# of every object, in that order.
image_columns = ['#image_name1', '#image_name2', '#image_name3']
render_presets = [(9, 42, (0, 0, 0)), (8, 35, (14, 2, 0)), (10.5, 37, (14, 8, 0))]

# Saving each subset (manuscript, pngs) - not including real lenses:
for k in range(len(names)):
    # Select rows and images with the same positional mask so the pairing does
    # not depend on data_save's index labels matching array positions.
    in_class = (data_save['class'] == names[k]).to_numpy()
    data_tmp = data_save[in_class]
    images_tmp = images_save[in_class]

    data_class = pd.DataFrame({
        '#subject_name': data_tmp['new_id'],
        '#class': data_tmp['class'].apply(lambda x: f'train_{x}'),
        '#id': data_tmp['new_id'],
        '#image_name1': data_tmp['new_id'].apply(lambda x: f'Train_{names[k]}_{x}_1.png'),
        '#image_name2': data_tmp['new_id'].apply(lambda x: f'Train_{names[k]}_{x}_2.png'),
        '#image_name3': data_tmp['new_id'].apply(lambda x: f'Train_{names[k]}_{x}_3.png'),
    })

    data_class.to_csv('sample/train_' + names[k] + '.csv')

    # Three PNGs per object, one for each render preset.
    for i in range(len(data_class)):
        for column, (q_value, stretch_value, minimum_value) in zip(image_columns, render_presets):
            make_one_plot(images_tmp[i], data_class[column].iloc[i], data_class,
                          q_value, stretch_value, minimum_value)

# Shows the manuscript of the last class processed in the loop.
data_class.head()

Unnamed: 0,#subject_name,#class,#id,#image_name1,#image_name2,#image_name3
1980,0,train_Most_negatives,0,Train_Most_negatives_0_1.png,Train_Most_negatives_0_2.png,Train_Most_negatives_0_3.png
1981,1,train_Most_negatives,1,Train_Most_negatives_1_1.png,Train_Most_negatives_1_2.png,Train_Most_negatives_1_3.png
1982,2,train_Most_negatives,2,Train_Most_negatives_2_1.png,Train_Most_negatives_2_2.png,Train_Most_negatives_2_3.png
1983,3,train_Most_negatives,3,Train_Most_negatives_3_1.png,Train_Most_negatives_3_2.png,Train_Most_negatives_3_3.png
1984,4,train_Most_negatives,4,Train_Most_negatives_4_1.png,Train_Most_negatives_4_2.png,Train_Most_negatives_4_3.png


In [7]:
def manuscript_pngs_real(images, data, catalog):
    """Write the manuscript CSV and the three PNGs per object for one catalog.

    Parameters
    ----------
    images : indexable collection of image cubes, read positionally and in the
        same order as the rows of ``data``.
    data : table with a ``'COADD_OBJECT_ID'`` column; the id is used as subject
        name, as id and inside every PNG file name.
    catalog : string label used in the class (``train_<catalog>``), in the PNG
        names and in the CSV path ``sample/train_<catalog>.csv``.
    """
    object_ids = data['COADD_OBJECT_ID']

    fields = {
        '#subject_name': object_ids,
        '#class': 'train_' + catalog,
        '#id': object_ids,
    }
    for version in (1, 2, 3):
        fields[f'#image_name{version}'] = object_ids.apply(
            lambda x, version=version: f'Train_{catalog}_{x}_{version}.png'
        )
    data_class = pd.DataFrame(fields)

    data_class.to_csv('sample/train_' + catalog + '.csv')

    # (file-name column, Q, stretch, minimum) for the three renderings.
    renderings = (
        ('#image_name1', 9, 42, (0, 0, 0)),
        ('#image_name2', 8, 35, (14, 2, 0)),
        ('#image_name3', 10.5, 37, (14, 8, 0)),
    )
    for row in range(len(data_class)):
        for column, q_value, stretch_value, minimum_value in renderings:
            file_name = data_class[column].iloc[row]
            make_one_plot(images[row], file_name, data_class,
                          q_value, stretch_value, minimum_value)


In [8]:
# Adding the train sample of known catalogs: Jacobs
# NOTE: this cell extends images_save/data_save in place, so running it twice
# adds the catalog twice. Re-run the cell that builds the training sample first.
with fits.open('real_lenses/Jacobs_train.fits') as hdu_list:
    images_jacobs = hdu_list[1].data
    data_jacobs = Table(hdu_list[2].data).to_pandas()

# Catalog rows for the new images: ids restart at 0 within this class.
data_tmp = pd.DataFrame(columns = data_save.columns)
data_tmp['new_id'] = range(0, len(images_jacobs))
data_tmp['class'] = ['Known_jacobs']*len(images_jacobs)

images_save = np.concatenate([images_save, images_jacobs], axis = 0)
# DataFrame.append was removed in pandas 2.0; pd.concat keeps the same
# behaviour (existing index labels are preserved, not renumbered).
data_save = pd.concat([data_save, data_tmp])

print(len(images_jacobs))
print(len(images_save), len(data_save))

#Saving the manuscript and the PNGs
manuscript_pngs_real(images_jacobs, data_jacobs, 'Jacobs')

152
2462 2462


In [9]:
# Adding the train sample of known catalogs: O'Donnell
# NOTE: this cell extends images_save/data_save in place, so running it twice
# adds the catalog twice.
# NOTE(review): the *_jacobs variable names are reused here although they now
# hold the O'Donnell ('Jack') sample; kept unchanged for compatibility.
with fits.open('real_lenses/Jack_train.fits') as hdu_list:
    images_jacobs = hdu_list[1].data
    data_jacobs = Table(hdu_list[2].data).to_pandas()

# Catalog rows for the new images: ids restart at 0 within this class.
data_tmp = pd.DataFrame(columns = data_save.columns)
data_tmp['new_id'] = range(0, len(images_jacobs))
data_tmp['class'] = ['Known_jack']*len(images_jacobs)

images_save = np.concatenate([images_save, images_jacobs], axis = 0)
# DataFrame.append was removed in pandas 2.0; pd.concat keeps the same
# behaviour (existing index labels are preserved, not renumbered).
data_save = pd.concat([data_save, data_tmp])

print(len(images_jacobs))
print(len(images_save), len(data_save))

#Saving the manuscript and the PNGs
manuscript_pngs_real(images_jacobs, data_jacobs, 'Jack')

48
2510 2510


In [10]:
# Write all accumulated images and their catalog to 'complete_training_sample.fits'
# (an existing file with that name is overwritten).
write_fit_file('complete_training_sample', images_save, data_save)