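# Variational Autoencoders for galaxy images

Data preparation: load the GAMA and Zou galaxy cutouts, rescale each band, measure each galaxy's position angle with photutils, and save the images plus labels (extended with the position angle and a nominal uncertainty) to `FullSample_26k_photoz_wOrientation.npz`.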

In [1]:
import h5py
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import pandas as pd

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.datasets import MNIST
from torch import tensor as tt
import torch.nn.functional as F
from torchvision.utils import make_grid
from astropy.stats import sigma_clip

# import required libraries for orientation fitting
from photutils.isophote import Ellipse
import warnings
from astropy.stats import sigma_clip
from astropy.stats import sigma_clipped_stats
from photutils.morphology import data_properties
from photutils.background import Background2D, MedianBackground
from astropy.convolution import convolve
from photutils.segmentation import make_2dgaussian_kernel
from photutils.segmentation import SourceFinder
from photutils.segmentation import deblend_sources
from photutils.segmentation import SourceCatalog

In [2]:
# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

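# Load the GAMA and Zou galaxy samples

Images and labels are read from `Gama_16k_specz.npz` and `Zou_10k_specz.npz` (key `x` = image cutouts, key `y` = labels). For reference only, the Galaxy10 dataset (not the data loaded here) is described at: https://astronn.readthedocs.io/en/latest/galaxy10.html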

In [3]:
# Load both samples. Using `with` closes each .npz file handle once the arrays
# have been read (indexing an NpzFile reads the array fully into memory).
with np.load("Gama_16k_specz.npz") as Gama:
    images1 = Gama['x']
    # get rid of galaxy_ID (presumably column 0) -- though we might need it later
    labels1 = Gama['y'][:, 1:6]

with np.load("Zou_10k_specz.npz") as Zou:
    images2 = Zou['x']
    # pick the Zou columns that line up with the five GAMA label columns
    labels2 = Zou['y'][:, [1, 2, 3, -2, -1]]

images = np.concatenate([images1, images2])

# Convert from Msol/yr to Msol/Gyr. Adding 9 equals multiplying by 1e9 only if
# this column holds log10 values -- presumably log10(SFR); TODO confirm.
labels1[:, 2] += 9
labels = np.concatenate([labels1, labels2])

In [4]:
#normalization scaling taken from https://arxiv.org/pdf/2005.12039.pdf
meanmax_g = np.nanmean([np.nanmax(x) for x in images[:, :, :, 0]])
meanmax_r = np.nanmean([np.nanmax(x) for x in images[:, :, :, 1]])
meanmax_z = np.nanmean([np.nanmax(x) for x in images[:, :, :, 2]])

images_rescaled = []
for img in images:    
    beta = 10.
    
    test_g = np.tanh(np.arcsinh(beta*img[:, :, 0]/meanmax_g))
    test_r = np.tanh(np.arcsinh(beta*img[:, :, 1]/meanmax_r))
    test_z = np.tanh(np.arcsinh(beta*img[:, :, 2]/meanmax_z))

    newImg = np.moveaxis(np.array([test_z, test_r, test_g]), 0, -1)
    
    images_rescaled.append(newImg)

In [5]:
# Materialize both as fresh float32 arrays in one conversion step each.
images_rescaled = np.array(images_rescaled, dtype=np.float32)
labels = np.array(labels, dtype=np.float32)

In [175]:
# Measure each galaxy's position angle (PA) with photutils: subtract a 2-D
# background from the band-summed image, detect and deblend sources, and take
# the orientation of the source whose centroid is closest to the image centre.
#
# Results:
#   pas         -- PA in radians for every *non-empty* image, in index order
#                  (NaN when no usable source was found). Empty images are
#                  skipped, so pas[k] belongs to the k-th kept image, not image k.
#   idxs_toDrop -- indices of images with at least one all-zero band.

CENTER_PIX = 50.0  # assumed pixel position of the target galaxy -- TODO confirm against the image shape
NPIXELS = 10       # minimum number of connected pixels for a detection / deblended segment

# Loop-invariant photutils objects, built once instead of once per image.
bkg_estimator = MedianBackground()
kernel = make_2dgaussian_kernel(3.0, size=5)  # FWHM = 3.0 pix
finder = SourceFinder(npixels=NPIXELS, progress_bar=False)


def fit_position_angle(img_tot):
    """Return the orientation (radians) of the source closest to CENTER_PIX.

    Parameters
    ----------
    img_tot : 2-D array
        Band-summed image.

    Returns
    -------
    float
        Orientation of the closest detected source, converted from degrees to
        radians, or NaN if nothing was detected or no centroid distance is finite.
    """
    bkg = Background2D(img_tot, (10, 10), filter_size=(3, 3),
                       bkg_estimator=bkg_estimator)
    # TODO(original author): background subtraction was flagged as possibly wrong -- investigate
    img_sub = img_tot - bkg.background
    threshold = 1.1 * bkg.background_rms

    convolved_data = convolve(img_sub, kernel)
    segment_map = finder(convolved_data, threshold)
    if segment_map is None:  # SourceFinder returns None when no source is detected
        return np.nan

    segm_deblend = deblend_sources(convolved_data, segment_map,
                                   npixels=NPIXELS, nlevels=32, contrast=0.001,
                                   progress_bar=False)
    tbl = SourceCatalog(img_sub, segm_deblend,
                        convolved_data=convolved_data).to_table().to_pandas()

    dist = np.sqrt((tbl['xcentroid'].to_numpy() - CENTER_PIX)**2
                   + (tbl['ycentroid'].to_numpy() - CENTER_PIX)**2)
    if not np.any(np.isfinite(dist)):
        return np.nan
    # nanargmin picks the first of any tied closest sources
    closest = int(np.nanargmin(dist))
    return float(np.deg2rad(tbl['orientation'].iloc[closest]))


pas = []
idxs_toDrop = []
# Scope the suppression of photutils/astropy warnings to this loop instead of
# silencing all warnings for the rest of the session.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    for j, img in enumerate(images_rescaled):
        # flag images with any all-zero band to drop
        if not (np.any(img[:, :, 0]) and np.any(img[:, :, 1]) and np.any(img[:, :, 2])):
            print("Found an empty image at index %i!" % j)
            idxs_toDrop.append(j)
            continue

        # combine all bands - there were some galaxies missing r-band photometry
        img_tot = img[:, :, 0] + img[:, :, 1] + img[:, :, 2]
        pa = fit_position_angle(img_tot)
        if np.isnan(pa):
            print("No position angle found for image %i" % j)
        pas.append(pa)

        if j % 100 == 0:
            print("Finished %i" % j)

Finished 0
Finished 100
Finished 200
Finished 300
Finished 400
Finished 500
Finished 600
Finished 700
Finished 800
Finished 900
Finished 1000
Finished 1100
Finished 1200
Finished 1300
Finished 1400
Finished 1500
Finished 1600
Finished 1700
Finished 1800
Finished 1900
Finished 2000
Finished 2100
Finished 2200
Finished 2300
Finished 2400
Found an empty image!
Finished 2500
Finished 2600
Finished 2700
Finished 2800
Finished 2900
Finished 3000
Finished 3100
Finished 3200
Finished 3300
Finished 3400
Finished 3500
Finished 3600
Finished 3700
Finished 3800
Finished 3900
Finished 4000
Finished 4100
Finished 4200
Found an empty image!
Finished 4300
Finished 4400
Finished 4500
Finished 4600
Finished 4700
Finished 4800
Finished 4900
Finished 5000
Finished 5100
Finished 5200
Finished 5300
Finished 5400
Finished 5500
Finished 5600
Finished 5700
Finished 5800
Finished 5900
Finished 6000
Finished 6100
Finished 6200
Finished 6300
Finished 6400
Finished 6500
Finished 6600
Finished 6700
Finished 6800
Fi

In [None]:
# How well did we do? Overlay the measured orientation on one galaxy.
# `pas` has no entries for dropped (empty) images, so the image index must be
# shifted by the number of dropped images before it to find its angle.
i = 650
if i in idxs_toDrop:
    raise ValueError("image %i was dropped as empty and has no position angle" % i)
pa_idx = i - sum(d < i for d in idxs_toDrop)
pas_rad = pas[pa_idx]
print(pas_rad)

fig, ax = plt.subplots()
ax.imshow(images_rescaled[i], origin='lower')
# solid line: PA = 0 reference direction (+x axis); dashed line: measured PA
ax.plot([75, 50], [50, 50], c='w', zorder=100)
ax.plot([50, 50 + 25*np.cos(pas_rad)], [50, 50 + 25*np.sin(pas_rad)],
        ls='--', c='w', zorder=100)
ax.set(title="Image %i: PA = %.2f rad" % (i, pas_rad),
       xlabel="x [pix]", ylabel="y [pix]")
plt.show()

In [228]:
pas = np.array(pas)

# Failed fits are stored as NaN; report how many there were.
n_failed = np.isnan(pas).sum()
print(n_failed)

0


In [229]:
pa_errs = 1e-1  # nominal PA uncertainty applied to every galaxy (presumably in radians, like `pas`)

In [242]:
# Use the drop list computed by the orientation-fitting loop instead of a
# hard-coded snapshot (a previous run found [2464, 4233, 12127, 12472, 14432]),
# so the kept indices cannot drift out of sync with `pas`.
# np.setdiff1d returns *sorted* indices. The previous list(set(...) - set(...))
# relied on set iteration order, which Python does not guarantee, to keep
# images, labels and `pas` aligned.
idx_keep = np.setdiff1d(np.arange(len(images_rescaled)), idxs_toDrop)
assert len(idx_keep) == len(pas), "kept images and position angles are misaligned"

In [243]:
images_rescaled_keep = images_rescaled[idx_keep]
labels_keep = labels[idx_keep]

# Append two label columns: the position angle and its uncertainty.
# `pa_errs` may be a scalar (one nominal value for all galaxies) or a
# per-galaxy array, so broadcast it to one value per row. Stacking the bare
# scalar with np.vstack fails: a (1, 1) block cannot be concatenated with the
# (n_cols, N) transposed labels.
pas_col = np.asarray(pas, dtype=float)
pa_errs_col = np.broadcast_to(np.asarray(pa_errs, dtype=float), pas_col.shape)

labels_wPas = np.column_stack([labels_keep, pas_col])
labels_keep_pa = np.column_stack([labels_wPas, pa_errs_col])
assert len(labels_keep_pa) == len(images_rescaled_keep), "images and labels are misaligned"

np.savez("FullSample_26k_photoz_wOrientation.npz", x=images_rescaled_keep, y=labels_keep_pa)