In [1]:
import os
import h5py
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from astropy.visualization import (ZScaleInterval, ImageNormalize)
import tensorflow as tf
from tensorflow import keras

In [2]:
# Directory holding the cutout file, taken from the SCRATCH environment variable.
# The trailing slash lets later cells build paths by plain string concatenation.
cutout_dir = f'{os.path.expandvars("$SCRATCH")}/'
# Directory that contains the tiles.list index read by the next cell.
image_dir = "/home/anahoban/projects/rrg-kyi/astro/cfis/W3/"

In [4]:
# Get tile ids.
# Only use tiles with all five channels: each line of tiles.list is a
# space-separated list of per-channel entries, and a line is kept only when it
# has exactly five of them.
tile_ids = []

# Open the list exactly once, inside a context manager, so the handle is closed
# even if parsing raises.
with open(image_dir + "tiles.list", "r") as tile_list:
    for tile in tile_list:
        # Strip only the line terminator; slicing off the last character would
        # corrupt a final line that has no trailing newline.
        tile = tile.rstrip("\n")
        channels = tile.split(" ")
        if len(channels) == 5: # Order is u,g,r,i,z
            tile_ids.append(channels[0][5:12]) # XXX.XXX id

In [7]:
# Read-only handle on the filtered cutouts file. It stays open because the
# loader functions defined further down read from it on demand.
hf = h5py.File(f"{cutout_dir}cutouts_filtered.h5", mode="r")

In [8]:
# Total number of entries under IMAGES for the first five tiles in tile_ids.
# The tile position is echoed on each iteration, then the grand total is printed.
n_cutouts = 0
for i in range(5):
    print(i)
    n_cutouts = n_cutouts + len(hf.get(f"{tile_ids[i]}/IMAGES"))
print(n_cutouts)


0
1
2
3
4
57399


In [11]:
# Number of entries under IMAGES for the tile at position 12 of tile_ids
# (the same position that val_indices selects).
n_valid = len(hf.get(f"{tile_ids[12]}/IMAGES"))

In [12]:
# Positions in tile_ids used for training and for validation.
# NOTE: both names are assigned again in a later cell before training starts.
train_indices = range(0, 5)
val_indices = [12]

In [13]:
# Training hyper-parameters (previously tried values kept as trailing comments).
BATCH_SIZE = 256 #128
CUTOUT_SIZE = 128
N_EPOCHS = 8 #12

# Zero-filled (batch, height, width, band) buffers, one per band selection:
# 2 bands for "cfis", 3 for PS1, 5 for "all". get_cutouts writes weight maps
# into these in place. They are sized with the BATCH_SIZE in effect when this
# cell runs.
weights_cfis, weights_ps1, weights_all = (
    np.zeros((BATCH_SIZE, CUTOUT_SIZE, CUTOUT_SIZE, n_bands)) for n_bands in (2, 3, 5)
)

In [14]:
def get_test_cutouts(index, n_cutouts, cutout_size, bands="all", start=0):
    """Load up to ``n_cutouts`` consecutive cutouts and their weight maps from one tile.

    Reads the datasets ``c{start}``, ``c{start + 1}``, ... from the ``IMAGES`` and
    ``WEIGHTS`` groups of tile ``tile_ids[index]`` in the module-level HDF5 file
    ``hf``, keeping only the requested channels of each dataset.

    Args:
        index: Position of the tile in ``tile_ids``.
        n_cutouts: Maximum number of cutouts to load.
        cutout_size: Height and width of every cutout.
        bands: ``"all"`` keeps channels 0-4, ``"cfis"`` keeps channels 0 and 2,
            and any other value keeps channels 1, 3 and 4 (PS1).
        start: Number of the first cutout dataset to read.

    Returns:
        A tuple ``(sources, weights)`` of float arrays with shape
        ``(n_loaded, cutout_size, cutout_size, n_bands)``. ``n_loaded`` equals
        ``n_cutouts`` when the tile holds enough cutouts from ``start`` onwards;
        otherwise the arrays are truncated to the cutouts actually read instead
        of the function silently returning ``None``.
    """
    if bands == "all":
        band_indices = [0, 1, 2, 3, 4]
    elif bands == "cfis":
        band_indices = [0, 2]
    else: # PS1
        band_indices = [1, 3, 4]

    shape = (n_cutouts, cutout_size, cutout_size, len(band_indices))
    sources = np.zeros(shape)
    weights = np.zeros(shape)

    img_group = hf.get(tile_ids[index] + "/IMAGES")
    wt_group = hf.get(tile_ids[index] + "/WEIGHTS")

    # Never read past the end of the tile, and never more than was requested
    # (this also makes n_cutouts == 0 return empty arrays instead of raising).
    stop = min(len(img_group), start + n_cutouts)
    n = 0
    for i in range(start, stop):
        sources[n] = np.array(img_group.get(f"c{i}"))[:, :, band_indices]
        weights[n] = np.array(wt_group.get(f"c{i}"))[:, :, band_indices]
        n += 1

    if n < n_cutouts:
        # The tile ran out of cutouts: drop the unused, still-zero rows.
        sources, weights = sources[:n], weights[:n]
    return (sources, weights)

In [15]:
# Load the first 50 cutouts (all five bands) of the tile at position 13 as a
# test sample and display the resulting array shape.
# CFIS-only or PS1-only samples can be loaded the same way by passing
# bands="cfis" or bands="ps1"; the call returns a (sources, weights) tuple.
test_index = 13
sources_test_all, weights_test_all = get_test_cutouts(test_index, 50, CUTOUT_SIZE, bands="all")
sources_test_all.shape

(50, 128, 128, 5)

In [16]:
def get_cutouts(tile_indices, batch_size, cutout_size, bands="all"):
    """Endlessly generate autoencoder batches from the cutouts of the given tiles.

    Cutouts are read in order from the ``IMAGES`` group of every tile in
    ``tile_indices`` (module-level HDF5 file ``hf``), cycling over the tiles
    forever. A partially filled batch carries over from one tile, and from one
    pass, to the next. The matching ``WEIGHTS`` datasets are read into a weight
    buffer as well, but that buffer is not part of what is yielded.

    Args:
        tile_indices: Re-iterable collection of positions in ``tile_ids``.
        batch_size: Number of cutouts per yielded batch.
        cutout_size: Height and width of every cutout.
        bands: ``"all"`` keeps channels 0-4, ``"cfis"`` keeps channels 0 and 2,
            and any other value keeps channels 1, 3 and 4 (PS1).

    Yields:
        ``(inputs, targets)`` where both elements are the same array of shape
        ``(batch_size, cutout_size, cutout_size, n_bands)`` (the target of the
        autoencoder is its input). Every yielded batch is an independent copy,
        so it stays valid after the generator is advanced.

    Raises:
        ValueError: If a full pass over ``tile_indices`` finds no cutouts at
            all (the generator would otherwise loop forever without yielding).
    """
    if bands == "all":
        band_indices = [0, 1, 2, 3, 4]
        weights = weights_all
    elif bands == "cfis":
        band_indices = [0, 2]
        weights = weights_cfis
    else: # PS1
        band_indices = [1, 3, 4]
        weights = weights_ps1

    batch_shape = (batch_size, cutout_size, cutout_size, len(band_indices))
    sources = np.zeros(batch_shape)

    # The shared module-level weight buffer was sized when its cell ran. If it
    # cannot hold this batch geometry, fall back to a private buffer instead of
    # failing with an IndexError / broadcast error in the middle of training.
    if weights.shape[0] < batch_size or weights.shape[1:] != batch_shape[1:]:
        weights = np.zeros(batch_shape)

    b = 0 # counter for batch
    while True:
        seen_this_pass = 0
        for i in tile_indices:
            img_group = hf.get(tile_ids[i] + "/IMAGES")
            wt_group = hf.get(tile_ids[i] + "/WEIGHTS")
            n_cutouts = len(img_group)
            seen_this_pass += n_cutouts
            for n in range(n_cutouts):
                sources[b] = np.array(img_group.get(f"c{n}"))[:, :, band_indices]
                weights[b] = np.array(wt_group.get(f"c{n}"))[:, :, band_indices]
                b += 1
                if b == batch_size:
                    b = 0
                    # Hand out a copy: `sources` is overwritten while the next
                    # batch is assembled, which would corrupt a batch that the
                    # consumer is still holding.
                    batch = sources.copy()
                    yield (batch, batch)
        if seen_this_pass == 0:
            raise ValueError("get_cutouts: the requested tiles contain no cutouts")

In [17]:
def train_autoencoder(model, train_indices, val_indices, n_epochs, batch_size, cutout_size, bands="all"):
    """Fit ``model`` on batches streamed from the training tiles, validating on the validation tiles.

    The number of steps per epoch (and of validation steps) is the number of
    complete batches that fit into the cutouts of the respective tiles.

    Args:
        model: Compiled model exposing a Keras-style ``fit`` method.
        train_indices: Positions in ``tile_ids`` of the training tiles.
        val_indices: Positions in ``tile_ids`` of the validation tiles.
        n_epochs: Number of epochs to train for.
        batch_size: Number of cutouts per batch.
        cutout_size: Height and width of every cutout.
        bands: Band selection forwarded to ``get_cutouts``.

    Returns:
        Tuple ``(model, history)`` where ``history`` is whatever ``model.fit`` returned.
    """
    def count_cutouts(indices):
        # Sum of the IMAGES group sizes over the given tiles.
        return sum(len(hf.get(tile_ids[i] + "/IMAGES")) for i in indices)

    train_steps = count_cutouts(train_indices) // batch_size
    val_steps = count_cutouts(val_indices) // batch_size

    train_batches = get_cutouts(train_indices, batch_size, cutout_size, bands)
    val_batches = get_cutouts(val_indices, batch_size, cutout_size, bands)

    history = model.fit(
        train_batches,
        epochs=n_epochs,
        steps_per_epoch=train_steps,
        validation_data=val_batches,
        validation_steps=val_steps,
    )
    return model, history

In [18]:
def masked_MSE_with_uncertainty(y_true, y_pred):
    """Per-element squared error, scaled by a predicted divisor and masked where the target is zero.

    Args:
        y_true: Target tensor; it is reshaped to the shape of channel 0 of
            ``y_pred``, so it must hold the same number of elements.
        y_pred: Prediction tensor whose last axis has at least two entries:
            index 0 is the predicted image and index 1 is the value the
            residual is divided by.
            NOTE(review): the function name suggests index 1 is a predicted
            uncertainty (sigma) - confirm against the model that uses this loss.

    Returns:
        Tensor with the shape of ``y_pred[..., 0]`` holding
        ``((y_true - y_pred_image) / weight) ** 2``, set to 0 wherever
        ``y_true`` is exactly 0 (masked element) or ``weight`` is 0.
    """
    y_pred_image = tf.gather(y_pred, indices=0, axis=-1)
    y_true_image = tf.reshape(y_true, shape=tf.shape(y_pred_image))
    weight = tf.gather(y_pred, indices=1, axis=-1)
    # `tensorflow.keras` has no top-level abs/sign/square, so the tf ops are
    # used directly. |sign(x)| is 1 for non-zero targets and 0 for zeros.
    mask = tf.abs(tf.sign(y_true_image))
    # divide_no_nan returns 0 where the divisor is 0 instead of inf/nan.
    scaled_residual = tf.math.divide_no_nan(y_true_image - y_pred_image, weight)
    return mask * tf.square(scaled_residual)

In [19]:
# Tile split and batch size actually used for the training run below.
# These replace the values assigned to the same names in earlier cells.
train_indices = range(0, 4)
val_indices = [12]
BATCH_SIZE = 128

In [20]:
def create_autoencoder2(shape):
    """Build a small convolutional autoencoder with a residual-style shortcut.

    The encoder adds two branches computed from the input: a 16-filter then
    32-filter Conv+BatchNorm stack, and a single 32-filter Conv+BatchNorm
    shortcut. The decoder applies two transposed convolutions (32 then 16
    filters) followed by a linear convolution that outputs ``shape[2]``
    channels. No strides or pooling are specified anywhere in the model.

    Args:
        shape: Input shape passed to ``keras.Input``; ``shape[2]`` is used as
            the number of output channels.

    Returns:
        An uncompiled ``keras.Model`` mapping the input to the decoded output.
    """
    layers = keras.layers

    def conv_bn(tensor, filters):
        # 3x3 ReLU convolution followed by batch normalisation.
        conv = layers.Conv2D(filters, kernel_size=3, activation='relu', padding='same')(tensor)
        return layers.BatchNormalization()(conv)

    def deconv(tensor, filters):
        # 4x4 ReLU transposed convolution.
        return layers.Conv2DTranspose(filters, kernel_size=4, activation='relu', padding='same')(tensor)

    input_img = keras.Input(shape=shape)

    # Layers are created in the same order as before (main branch, then
    # shortcut) so that auto-generated layer names stay unchanged.
    main_branch = conv_bn(conv_bn(input_img, 16), 32)
    shortcut = conv_bn(input_img, 32)
    encoded = layers.Add()([main_branch, shortcut])

    upsampled = deconv(deconv(encoded, 32), 16)
    decoded = layers.Conv2D(shape[2], kernel_size=3, activation='linear', padding='same')(upsampled)

    return keras.Model(input_img, decoded)

In [22]:
# Two-channel autoencoder for the "cfis" band selection, trained with Adam on
# plain mean squared error.
autoencoder_cfis = create_autoencoder2((CUTOUT_SIZE, CUTOUT_SIZE, 2))
autoencoder_cfis.compile(loss='mean_squared_error', optimizer="adam")

In [None]:
# Train the CFIS autoencoder on the selected tiles; keeps the trained model and
# the history object returned by the fit.
autoencoder_cfis, history_cfis = train_autoencoder(
    autoencoder_cfis,
    train_indices,
    val_indices,
    n_epochs=N_EPOCHS,
    batch_size=BATCH_SIZE,
    cutout_size=CUTOUT_SIZE,
    bands="cfis",
)
   

Epoch 1/8