In [None]:
#The following code combines three .fits images of the same field, taken with different filters (i, r, g), into one .jpg RGB image.
#It uses as input the unique.data file created by 01_gettables.ipynb.
#For each image, this code loads the three .fits bands, aligns them and then creates 5 .jpg files.
#Five filter functions are applied to the raw count matrices before merging them into an RGB image.
#There are two sinh filters (with different brightness and zero threshold), one sqrt filter and two Lupton filters (the second Lupton filter, filter 5, has the same parameters, but its brightness and contrast are enhanced with PIL ImageEnhance).
#You can specify a different folder for each filter. For the nth filter, set the path_filtern variable (e.g. path_filter1); it is the path placed before '/%s.jpg' when the fjpgn (e.g. fjpg1) file name is built.

In [None]:
from astropy.io import fits
import numpy as np
from PIL import Image, ImageEnhance
from astropy.table import Table
from astropy.visualization import make_lupton_rgb
import os.path
from reproject import reproject_interp

In [None]:
# Rows of the catalogue to process: the conversion loop visits istart .. iend-1.
istart, iend = 0, 13000

# Counter updated by the conversion loop and printed when it finishes.
nskip = 0

# Plain-text table with the .fits data and image names, written by 01_gettables.ipynb.
xuniq = Table.read('unique.data', format='ascii')

# Input folders.
# Every band file is expected to be named name.band.fits (e.g. photo1.r.fits);
# the files saved by galaxy_detection/traning/02_downloadfits.py already follow this scheme.
# The r, g and b suffixes refer to the .jpg channel that the .fits data is put into.
path_rband = 'images'
path_gband = 'images'
path_bband = 'images'

# Output folders.
# All five renderings of an image share the same file name, so each filter needs its own folder.
path_filter1 = 'filtro1'
path_filter2 = 'filtro2'
path_filter3 = 'filtro3'
path_filter4 = 'filtro4'
path_filter5 = 'filtro5'

In [None]:
# r_band, g_band and b_band are 2-D numpy arrays of identical shape with the data from the .fits files
# stretch_type is the function you want to use for the stretch. You can choose between log, sqrt, lupton, sinh, power and exp
# If you want to use a power function, write it as 'power_n', where n is the power you want to use (non-integers can be used)
# factor increases the brightness of the image by multiplying the output matrix by said factor (before converting to image). For a lupton stretch it is passed on as the stretch parameter
# centered uses the mean of the stretched image matrix as the zero point (values below the zero point are shown as black). It is ignored if zero is greater than 0
# normalized, if true, rescales the data by its overall range (max - min) before applying the stretch. It is ignored when using a lupton stretch
# min_zero subtracts the minimum value of the input matrices so every value is greater than or equal to zero (always done for log and sqrt)
# zero sets the percentile of the pixels which are set to black (zero=1 will map the lowest 1% of the pixels to black)
# trunc clips high values: every pixel above median + trunc * (standard deviation) is set to that cut value (trunc=0 disables the clipping)

def filter_rgb(r_band,g_band,b_band,stretch_type='log',factor=1,centered=False,normalized=False,min_zero=False,zero=-1,trunc=0):
    """Stretch three aligned band arrays and merge them into one RGB PIL image.

    Parameters
    ----------
    r_band, g_band, b_band : numpy arrays of identical 2-D shape
        Pixel data shown in the red, green and blue channel respectively.
    stretch_type : str
        'log', 'sqrt', 'sinh', 'exp', 'lupton' or 'power_n' (n is the exponent,
        e.g. 'power_0.5').
    factor : number
        Multiplier applied to the stretched data before the 8-bit conversion.
        For 'lupton' it is passed as ``stretch`` to make_lupton_rgb; the
        default value 1 is treated as "not set" and replaced by 5.
    centered : bool
        Use the mean of the stretched data as the black level instead of its
        minimum. Ignored when ``zero`` > 0 or when the stretched data contains
        an infinite value.
    normalized : bool
        Rescale the data to the [0, 1] range before stretching. Not used by
        the 'lupton' stretch.
    min_zero : bool
        Subtract the global minimum of the three bands first. Always done for
        the 'log' and 'sqrt' stretches.
    zero : number
        When > 0, percentile of the stretched data that is used as the black
        level. For 'lupton' it is passed as ``minimum`` (negative values
        become 0).
    trunc : number
        When non-zero, values above median + trunc * standard deviation are
        clipped to that cut value.

    Returns
    -------
    PIL RGB image, or None when ``stretch_type`` is not recognised.

    Raises
    ------
    ValueError
        If a power stretch is requested without a usable exponent.
    """
    if stretch_type == 'lupton':
        # factor == 1 is the function default, so it means "use the usual stretch of 5".
        stretch = 5 if factor == 1 else factor
        minimum = 0 if zero < 0 else zero
        rgb = make_lupton_rgb(r_band , g_band , b_band , stretch = stretch , minimum = minimum)
        return Image.fromarray(rgb)

    # Validate the stretch name up front so no work is wasted on a bad request.
    is_power = stretch_type[0:5] == 'power'
    if is_power:
        pieces = stretch_type.split('_')
        if len(pieces) < 2:
            raise ValueError("power stretch must be written as 'power_n', got %r" % (stretch_type,))
        power = float(pieces[1])
    elif stretch_type not in ('log', 'sqrt', 'sinh', 'exp'):
        return None

    r, g, b = r_band, g_band, b_band

    # log and sqrt need non-negative input, so the global minimum is always removed for them.
    if min_zero or stretch_type in ('log', 'sqrt'):
        floor = np.nanmin([r, g, b])
        r, g, b = r - floor, g - floor, b - floor

    if trunc != 0:
        cut = np.nanmedian([r, g, b]) + trunc * np.nanstd([r, g, b])
        r, g, b = np.minimum(r, cut), np.minimum(g, cut), np.minimum(b, cut)

    if normalized:
        rgb_min = np.nanmin([r, g, b])
        span = float(np.nanmax([r, g, b])) - rgb_min
        # Shift to zero first, otherwise dividing by the range does not give [0, 1].
        r, g, b = r - rgb_min, g - rgb_min, b - rgb_min
        # A flat (or all-NaN) image has no usable range; leave it shifted only.
        if span > 0:
            r, g, b = r / span, g / span, b / span

    if stretch_type == 'log':
        # The small offset keeps log() finite at the zero pixels.
        ir, ig, ib = np.log(r + 0.000001), np.log(g + 0.000001), np.log(b + 0.000001)
    elif stretch_type == 'sqrt':
        # The blue channel is boosted by 1.3 before the stretch.
        ir, ig, ib = np.sqrt(r), np.sqrt(g), np.sqrt(b * 1.3)
    elif stretch_type == 'sinh':
        ir, ig, ib = np.sinh(r), np.sinh(g), np.sinh(b)
    elif stretch_type == 'exp':
        ir, ig, ib = np.exp(r), np.exp(g), np.exp(b)
    else:
        ir, ig, ib = np.power(r, power), np.power(g, power), np.power(b, power)

    # Choose the value that will be shown as black.
    if zero > 0:
        offset = np.nanpercentile([ir, ig, ib], zero)
    elif centered and np.nanmax([ir, ig, ib]) != float('inf') and np.nanmin([ir, ig, ib]) != -float('inf'):
        offset = np.nanmean([ir, ig, ib])
    else:
        offset = np.nanmin([ir, ig, ib])

    channels = []
    for data in (ir, ig, ib):
        scaled = (data - offset) * factor
        # NaN pixels (e.g. outside the footprint of a reprojected band) become black,
        # and the explicit clip keeps the 8-bit conversion well defined.
        prepared = np.clip(np.nan_to_num(scaled), 0.0, 255.0).astype(np.float32)
        channels.append(Image.fromarray(prepared).convert('L'))

    return Image.merge("RGB", tuple(channels))

In [None]:
# Convert catalogue rows istart .. iend-1 into five .jpg renderings each.
# A row is skipped when any of its three .fits bands is missing, and a rendering
# that already exists on disk is not recomputed, so the cell can be re-run safely.

# Reset the counters here so that re-running this cell does not accumulate old counts.
nskip = 0
nsaved = 0

# One entry per output: (folder, filter_rgb keyword arguments, jpeg quality, enhance with PIL).
filter_specs = [
    (path_filter1, dict(stretch_type='sinh', factor=400, zero=0.1), 73, False),
    (path_filter2, dict(stretch_type='sinh', factor=350, zero=0.03), 73, False),
    (path_filter3, dict(stretch_type='sqrt', factor=500, trunc=6, zero=0.01), 73, False),
    (path_filter4, dict(stretch_type='lupton'), 95, False),
    # Filter 5 is the lupton rendering with brightness and contrast doubled afterwards.
    (path_filter5, dict(stretch_type='lupton'), 94, True),
]

# Saving fails if the target folder is missing, so create the folders first.
for out_dir, _, _, _ in filter_specs:
    os.makedirs(out_dir, exist_ok=True)

# Never index past the end of the catalogue.
for i in range(istart, min(iend, len(xuniq))):
    print (i)
    name = xuniq['imagename'][i]

    # filter_rgb below receives (i, r, g) as its (red, green, blue) channels, so the
    # i band is read from path_rband, the r band from path_gband and the g band from path_bband.
    fnamei='%s/%s.i.fits'%(path_rband,name)
    fnamer='%s/%s.r.fits'%(path_gband,name)
    fnameg='%s/%s.g.fits'%(path_bband,name)

    # All three bands are needed; a missing one means none of the five .jpgs can be made.
    missing = [fname for fname in (fnamei, fnamer, fnameg) if not os.path.isfile(fname)]
    if missing:
        print('i=%d file does not exist %s skipping...'%(i,', '.join(missing)))
        nskip=nskip+len(filter_specs)
        continue

    # Only render the outputs that are not on disk yet.
    pending = []
    for out_dir, kwargs, quality, enhance in filter_specs:
        fjpg = '%s/%s.jpg'%(out_dir,name)
        if not os.path.isfile(fjpg):
            pending.append((fjpg, kwargs, quality, enhance))
    if not pending:
        continue

    # The context manager closes the three files again; all use of the pixel data
    # stays inside the block because the arrays may be backed by the open files.
    with fits.open(fnamer) as imrr, fits.open(fnameg) as imgg, fits.open(fnamei) as imii:
        imr=imrr[0]
        ir = imr.data
        w=ir.shape[1]
        h=ir.shape[0]
        # Resample the g and i bands onto the pixel grid of the r band.
        ig, footprint1 = reproject_interp(imgg[0], imr.header)
        ii, footprint2 = reproject_interp(imii[0], imr.header)

        for fjpg, kwargs, quality, enhance in pending:
            image = filter_rgb(ii,ir,ig,**kwargs)
            if enhance:
                image = ImageEnhance.Brightness(image).enhance(2)
                image = ImageEnhance.Contrast(image).enhance(2)
            image.save(fjpg,'jpeg',quality=quality)
            nsaved=nsaved+1
            print('i=%d saved %s w=%d h=%d'%(i,fjpg,w,h))

print('skipped files = %d'%(nskip))
print('saved files = %d'%(nsaved))