In [None]:
#%pylab inline

import tensorflow as tf
import tensorflow_probability as tfp
import tensorflow_datasets as tfds

# Load Datasets

In [None]:
# Project-local dataset module. `tng100` and `eagle` are not referenced by name
# below; importing them presumably registers the 'tng100' / 'eagle' TFDS builders
# that tfds.load uses in the next cell — TODO confirm in sfh.datasets.
from sfh.datasets import setup_environment, tng100, eagle
# NOTE(review): effect not visible here (presumably configures data paths/env) — confirm.
setup_environment()

In [None]:
# Full 'train' split of each simulation, used for the exploratory plots below.
# The 90 % / 10 % modelling split is carved out later in input_fn / input_fn_tng.
dset_eagle, dset_tng = (tfds.load(name, split='train') for name in ('eagle', 'tng100'))

# Plot some EAGLE and TNG100 examples

In [None]:
import matplotlib.pyplot as plt
import numpy as np
print("Train", len(dset_eagle))

# SEDs of the first 10 EAGLE galaxies at their valid wavelengths.
fig, axs = plt.subplots(1, 1)
for example in dset_eagle.take(10):
    # Deliberately left as notebook globals: later cells read `time_vec`
    # (mass-growth integration) and `inds_valid` (TNG SED wavelength selection).
    # Both end up holding the values of the LAST example iterated here.
    time_vec = example['time']
    inds_valid = example['inds_valid']
    axs.scatter(example['wl_sort'][inds_valid], np.log10(example['sed']))
axs.set_xscale('log')
axs.set_xlabel('wl_sort[inds_valid]')
axs.set_ylabel('log10(sed)')
axs.set_title('EAGLE: example SEDs')

# Star-formation histories of the same 10 galaxies.
fig, axs = plt.subplots(1, 1)
for example in dset_eagle.take(10):
    axs.plot(example['time'], example['SFR_Max'])
axs.set_xlabel('time')
axs.set_ylabel('SFR_Max')
axs.set_title('EAGLE: example star-formation histories')

In [None]:
import matplotlib.pyplot as plt
print("Train", len(dset_tng))

# TNG SEDs evaluated at the wavelengths flagged valid in the paired EAGLE example,
# so both simulations are shown on the same wavelength subset.
fig, axs = plt.subplots(1, 1)
for example_tng, example_eagle in zip(dset_tng.take(10), dset_eagle.take(10)):
    eagle_mask = example_eagle['inds_valid']
    axs.scatter(example_tng['wl_sort'][eagle_mask],
                np.log10(example_tng['sed'][eagle_mask]))
axs.set_xscale('log')
axs.set_xlabel('wl_sort[inds_valid] (EAGLE mask)')
axs.set_ylabel('log10(sed)')
axs.set_title('TNG100: example SEDs')

# Star-formation histories of the first 10 TNG galaxies.
fig, axs = plt.subplots(1, 1)
for example in dset_tng.take(10):
    axs.plot(example['time'], example['SFR_Max'])
axs.set_xlabel('time')
axs.set_ylabel('SFR_Max')
axs.set_title('TNG100: example star-formation histories')

# Define datasets for TNG and EAGLE

In [None]:
def preprocessing(example):
    """Use the star-formation history as both model input and target.

    The batched 'SFR_Max' entry is reshaped to (batch, 100, 1) and returned
    as the (features, labels) pair.
    """
    sfh = tf.reshape(example['SFR_Max'], (-1, 100, 1))
    return sfh, sfh

def preprocessing_wmass(example):
    """Build a (batch, 100, 3) tensor [SFR + 1e-5, Mstar, Mstar_Half] per time step.

    The first entry of 'Mstar' and 'Mstar_Half' for each galaxy is broadcast
    along that galaxy's 100 time steps. The same tensor is returned as
    (features, labels).
    """
    sfr = tf.reshape(example['SFR_Max'], (-1, 100, 1))
    # Repeat each galaxy's scalar along its OWN time axis. The former
    # tf.tile(mass, [100]) + reshape produced [m0, m1, ..., m0, m1, ...], which
    # mixed masses of different galaxies into the same row.
    mass = tf.tile(tf.reshape(example['Mstar'][:, 0], (-1, 1, 1)), [1, 100, 1])
    mass_half = tf.tile(tf.reshape(example['Mstar_Half'][:, 0], (-1, 1, 1)), [1, 100, 1])
    sfr = tf.math.add(sfr, 1e-5)  # small offset, presumably to keep SFR strictly positive
    res = tf.concat([sfr, mass, mass_half], axis=2)
    return res, res

def preprocessing_wmass_atan(example):
    """Map a batch to ((squashed SFH, SED), squashed SFH).

    'SFR_Max' is reshaped to (batch, 100, 1) and transformed as
    tanh(asinh(SFR / 40) + 1e-3 + 0.005 * softplus(N(0, 1))). Assuming
    SFR >= 0, this lies in (0, 1), the support of the Beta output used by
    generate_model. The SED is passed through unchanged.
    """
    sfr = tf.reshape(example['SFR_Max'], (-1, 100, 1))
    # Noise shape follows the actual batch (previously hard-coded to 64, which
    # broke any other batch_size).
    jitter = 0.005 * tf.math.softplus(tf.random.normal(shape=tf.shape(sfr)))
    res = tf.math.tanh(tf.math.asinh(sfr / 40) + 1e-3 + jitter)
    return (res, example['sed']), res

def input_fn(mode='train', batch_size=64,
             dataset_name='tng100', data_dir=None,
             include_mass=True, arctan=True):
    """Build the batched tf.data pipeline for the PixelCNN.

    Parameters
    ----------
    mode : str
        'train' -> first 90 % of the TFDS 'train' split, repeated and shuffled.
        Any other value (e.g. 'val', 'test') -> last 10 %, single pass, unshuffled.
    batch_size : int
        Batch size; the incomplete final batch is dropped.
    dataset_name : str
        TFDS builder name, e.g. 'tng100' or 'eagle'.
    data_dir : str or None
        Forwarded to tfds.load (None keeps the TFDS default location).
    include_mass, arctan : bool
        Preprocessing choice: both True -> preprocessing_wmass_atan;
        include_mass only -> preprocessing_wmass; otherwise preprocessing.

    Returns
    -------
    tf.data.Dataset
    """
    keys = ['sed', 'Mstar', 'SFR_Max', 'mass_quantiles', 'time']
    if include_mass and not arctan:
        # preprocessing_wmass also reads 'Mstar_Half'; filtering it out made
        # this path fail with a KeyError.
        keys.append('Mstar_Half')

    if mode == 'train':
        dataset = tfds.load(dataset_name, split='train[:90%]', data_dir=data_dir)
        dataset = dataset.map(lambda x: {k: x[k] for k in keys})
        dataset = dataset.repeat()
        dataset = dataset.shuffle(10000)
    else:
        dataset = tfds.load(dataset_name, split='train[90%:]', data_dir=data_dir)
        dataset = dataset.map(lambda x: {k: x[k] for k in keys})

    dataset = dataset.batch(batch_size, drop_remainder=True)
    if include_mass and arctan:
        dataset = dataset.map(preprocessing_wmass_atan)
    elif include_mass:
        dataset = dataset.map(preprocessing_wmass)
    else:
        dataset = dataset.map(preprocessing)
    # Prepare the next batches while the current one trains (-1 == AUTOTUNE).
    return dataset.prefetch(-1)

In [None]:
def preprocessing(example):
    """Return (SFH, SFH) with 'SFR_Max' reshaped to (batch, 100, 1).

    NOTE: this redefines `preprocessing` from the earlier cell with the same behavior.
    """
    target = tf.reshape(example['SFR_Max'], (-1, 100, 1))
    return target, target

def preprocessing_wmass(example):
    """Build a (batch, 100, 3) tensor [SFR + 1e-5, Mstar, Mstar_Half] per time step.

    The first entry of 'Mstar' and 'Mstar_Half' for each galaxy is broadcast
    along that galaxy's 100 time steps. The same tensor is returned as
    (features, labels). (Redefinition of the earlier cell's function.)
    """
    sfr = tf.reshape(example['SFR_Max'], (-1, 100, 1))
    # Per-galaxy broadcast along the time axis; tf.tile(mass, [100]) + reshape
    # interleaved masses of different galaxies into a single row.
    mass = tf.tile(tf.reshape(example['Mstar'][:, 0], (-1, 1, 1)), [1, 100, 1])
    mass_half = tf.tile(tf.reshape(example['Mstar_Half'][:, 0], (-1, 1, 1)), [1, 100, 1])
    sfr = tf.math.add(sfr, 1e-5)  # small offset, presumably to keep SFR strictly positive
    res = tf.concat([sfr, mass, mass_half], axis=2)
    return res, res

def preprocessing_wmass_atan_tng(example, valid_mask=None):
    """TNG variant of the atan preprocessing that also subsets the SED.

    Returns ((squashed SFH, masked SED), squashed SFH), where the SFH transform
    is tanh(asinh(SFR / 40) + 1e-3 + 0.005 * softplus(N(0, 1))) on
    'SFR_Max' reshaped to (batch, 100, 1).

    Parameters
    ----------
    example : dict of tensors
        One batch from input_fn_tng.
    valid_mask : boolean (or 0/1) vector, optional
        Wavelengths of the batched 'sed' (axis 1) to keep. When None, falls back
        to the notebook global `inds_valid` left by the EAGLE plotting cell
        (the original hidden dependency). Pass it explicitly to avoid relying
        on cell execution order.
    """
    if valid_mask is None:
        valid_mask = inds_valid  # notebook global; see docstring
    # boolean_mask keeps axis 1 even if a single wavelength is selected, whereas
    # gather(np.squeeze(np.where(mask))) collapsed it to a scalar index.
    sed = tf.boolean_mask(example['sed'], tf.cast(valid_mask, tf.bool), axis=1)
    sfr = tf.reshape(example['SFR_Max'], (-1, 100, 1))
    # Noise follows the actual batch shape instead of a hard-coded 64.
    jitter = 0.005 * tf.math.softplus(tf.random.normal(shape=tf.shape(sfr)))
    res = tf.math.tanh(tf.math.asinh(sfr / 40) + 1e-3 + jitter)
    return (res, sed), res

def input_fn_tng(mode='train', batch_size=64,
             dataset_name='tng100', data_dir=None,
             include_mass=True, arctan=True):
    """Build the batched tf.data pipeline for TNG (SED subset to valid wavelengths).

    Same as input_fn except that the atan path uses preprocessing_wmass_atan_tng.

    Parameters
    ----------
    mode : str
        'train' -> first 90 % of the TFDS 'train' split, repeated and shuffled.
        Any other value -> last 10 %, single pass, unshuffled.
    batch_size : int
        Batch size; the incomplete final batch is dropped.
    dataset_name : str
        TFDS builder name.
    data_dir : str or None
        Forwarded to tfds.load (None keeps the TFDS default location).
    include_mass, arctan : bool
        Preprocessing choice: both True -> preprocessing_wmass_atan_tng;
        include_mass only -> preprocessing_wmass; otherwise preprocessing.

    Returns
    -------
    tf.data.Dataset
    """
    keys = ['sed', 'Mstar', 'SFR_Max', 'mass_quantiles', 'time']
    if include_mass and not arctan:
        # preprocessing_wmass also reads 'Mstar_Half' (previously a KeyError).
        keys.append('Mstar_Half')

    if mode == 'train':
        dataset = tfds.load(dataset_name, split='train[:90%]', data_dir=data_dir)
        dataset = dataset.map(lambda x: {k: x[k] for k in keys})
        dataset = dataset.repeat()
        dataset = dataset.shuffle(10000)
    else:
        dataset = tfds.load(dataset_name, split='train[90%:]', data_dir=data_dir)
        dataset = dataset.map(lambda x: {k: x[k] for k in keys})

    dataset = dataset.batch(batch_size, drop_remainder=True)
    if include_mass and arctan:
        dataset = dataset.map(preprocessing_wmass_atan_tng)
    elif include_mass:
        dataset = dataset.map(preprocessing_wmass)
    else:
        dataset = dataset.map(preprocessing)
    # Prepare the next batches while the current one trains (-1 == AUTOTUNE).
    return dataset.prefetch(-1)

In [None]:
batch_size = 64
epochs = 10

# EAGLE pipelines: 'train' = shuffled/repeated first 90 %, 'val' = last 10 %.
dtrain_eagle, dval_eagle = (
    input_fn(mode=split_mode, batch_size=batch_size, dataset_name='eagle')
    for split_mode in ('train', 'val')
)

In [None]:
batch_size = 64
epochs = 10

# TNG100 pipelines built with input_fn_tng (SED subset by the wavelength mask).
dtrain_tng, dval_tng = (
    input_fn_tng(mode=split_mode, batch_size=batch_size, dataset_name='tng100')
    for split_mode in ('train', 'val')
)

# Train on TNG100

In [None]:
""""Keras model implementing PixelCNN."""

import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow import keras
from tensorflow.keras import layers
import sys
import time
# Short aliases used by the model-building code below.
tfd, tfb, tfkl = tfp.distributions, tfp.bijectors, keras.layers

def generate_model(n_timesteps, n_filters, *, n_channels=1, n_components=2, kernel_size=3,
                   n_dilations=5, list_of_dilation_rates=None,
                   list_of_filters=None, learning_rate=0.0002):
    """Generate the conditional PixelCNN Keras model.

    For every time step t the model outputs a Beta distribution over the SFH
    value at t. The distribution is conditioned on SFH values at steps < t,
    via a one-step right shift followed by causal convolutions. It is also
    conditioned on an 8-channel SED embedding that is repeated along the time
    axis.

    Parameters
    ----------
    n_timesteps : int
        Number of time steps of the SFH input.
    n_filters : int
        Length of the SED input.
    n_channels : int, default 1
        Currently unused; the SFH input always has one channel.
    n_components : int, default 2
        Currently unused; the output is a single Beta distribution, not a mixture.
    kernel_size : int, default 3
        Size of the causal convolution kernels.
    n_dilations : int, default 5
        Number of dilated convolutions when `list_of_dilation_rates` is None.
        Layer idx uses dilation rate 2**(idx+1) and 2**(idx+4) filters.
    list_of_dilation_rates : list of int or None, default None
        Explicit dilation rates. If given, `n_dilations` is ignored and
        `list_of_filters` must be given with the same length.
    list_of_filters : list of int or None, default None
        Filters per dilated convolution. Overwritten when
        `list_of_dilation_rates` is None.
    learning_rate : float, default 0.0002
        Adam learning rate.

    Returns
    -------
    keras.Model
        Compiled model with inputs [sfh (batch, n_timesteps, 1),
        sed (batch, n_filters, 1)]. Its output is a Beta distribution with
        batch shape (batch, n_timesteps).

    Raises
    ------
    ValueError
        If `list_of_dilation_rates` is given without a matching-length
        `list_of_filters`.
    """
    if list_of_dilation_rates is None:
        list_of_dilation_rates = [2**(i+1) for i in range(n_dilations)]
        list_of_filters = [2**(i+4) for i in range(n_dilations)]
    elif list_of_filters is None or len(list_of_filters) != len(list_of_dilation_rates):
        raise ValueError(
            "filters and list_of_dilation_rates must have the same length")

    input_sfh = keras.layers.Input(shape=(n_timesteps, 1))
    input_sed = keras.layers.Input(shape=(n_filters, 1))

    # Compress the SED into 8 positive features, then repeat them at every
    # time step so they can be concatenated to the SFH as extra channels.
    sed_net = tf.keras.Sequential([
        tfkl.Input(shape=(n_filters, 1)),
        tfkl.Conv1D(16, 3, strides=2, padding='same', activation='relu'),
        tfkl.Conv1D(32, 3, strides=2, padding='same', activation='relu'),
        tfkl.Conv1D(64, 3, strides=2, padding='same', activation='relu'),
        tfkl.Conv1D(64, 3, strides=1, padding='same', activation='relu'),
        tfkl.Flatten(),
        tfkl.Dense(128, activation='relu'),
        tfkl.Dense(8, activation='softplus'),
        tfkl.Lambda(lambda x: tf.tile(tf.reshape(x, [-1, 1, 8]), [1, n_timesteps, 1])),
    ])

    merged = keras.layers.Concatenate(axis=-1)([input_sfh, sed_net(input_sed)])

    # Shift right by one step (pad a zero step in front, drop the last one), so
    # the prediction at step t never sees the SFH value at t itself.
    # NOTE(review): the SED channels are shifted too, so step 0 sees only zeros.
    net = keras.layers.Lambda(
        lambda x: tf.pad(x, paddings=tf.constant([[0, 0], [1, 0], [0, 0]])))(merged)
    net = keras.layers.Lambda(lambda x: x[:, :-1, :])(net)

    net = keras.layers.Conv1D(filters=16, kernel_size=kernel_size, dilation_rate=1,
                              padding='causal', activation='relu')(net)
    for dilation_rate, nb_filters in zip(list_of_dilation_rates, list_of_filters):
        net = keras.layers.Conv1D(filters=nb_filters, kernel_size=kernel_size,
                                  dilation_rate=dilation_rate, padding='causal',
                                  activation='relu')(net)

    # Two unconstrained parameters per step, mapped to positive Beta concentrations.
    net = keras.layers.Dense(2)(net)
    net = tfp.layers.DistributionLambda(
        make_distribution_fn=lambda t: tfd.Beta(
            concentration1=tf.math.softplus(t[..., 0]) + 1e-3,
            concentration0=tf.math.softplus(t[..., 1]) + 1e-3))(net)

    pixel_cnn = keras.models.Model(inputs=[input_sfh, input_sed], outputs=net)

    def negloglik(y, q):
        """Negative log-likelihood of the target SFH, summed over time steps."""
        return tf.reduce_sum(-q.log_prob(y[..., 0]), -1)

    opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    pixel_cnn.compile(loss=negloglik, optimizer=opt)
    return pixel_cnn

In [None]:
# Fresh model: 100 SFH time steps conditioned on a 125-point SED.
pixel_cnn = generate_model(n_timesteps=100, n_filters=125)
pixel_cnn.summary()

In [None]:
# The training pipeline repeats indefinitely, so an epoch is defined by steps_per_epoch.
hist = pixel_cnn.fit(
    dtrain_tng,
    validation_data=dval_tng,
    steps_per_epoch=1000,
    epochs=epochs,
)

# Test on TNG100

In [None]:
# NumPy iterator over TNG validation batches; with input_fn_tng's defaults each
# element is ((sfh, sed), sfh).
dset_test = dval_tng.as_numpy_iterator()

In [None]:
# One validation batch: data[0][0] is the SFH input, data[0][1] the SED.
data = next(dset_test)

In [None]:
ind = 15
true = data[0][0][ind, :, 0]                                      # true SFH of galaxy `ind`
sed = np.repeat(data[0][1][ind].reshape([1, 125, 1]), 64, axis=0)  # its SED, once per chain
# 64 parallel chains, each seeded with the true value at the first time step.
sample = np.zeros([64, 100, 1])
sample[:, 0, 0] = true[0]

In [None]:
# Autoregressive sampling: each pass re-evaluates the model on the partially
# filled chains and keeps only the draw for the next time step.
for step in range(1, 100):
    draws = pixel_cnn((sample, sed)).sample()
    sample[:, step, 0] = draws[:, step]

In [None]:
plt.plot(true, label='true SFH')
for chain in sample[:, :, 0]:
    plt.plot(chain, color='C1', alpha=0.1)  # every posterior chain, faint
plt.plot(sample[1, :, 0], color='C1', alpha=1., label='individual sample')
posterior_mean = sample.mean(axis=0)[:, 0]
plt.plot(posterior_mean, '--', color='red', label='mean posterior')
plt.legend(loc='upper left')

# Check summaries (mass-assembly times)

In [None]:
import pdb
def find_summaries(mass, time, percentiles=np.linspace(0.1, 0.9, 9)):
    """Return the times at which a mass history is closest to fractions of mass[0].

    For each fraction p in `percentiles`, the first index whose mass is closest
    to p * mass[0] is selected and the corresponding entry of `time` returned.
    Callers in this notebook pass a flipped cumulative mass, so mass[0] is the
    total mass.

    Parameters
    ----------
    mass : array-like, 1-D
        Mass history.
    time : array-like, 1-D
        Times, indexed like `mass`.
    percentiles : array-like of float, default 0.1, 0.2, ..., 0.9
        Mass fractions to locate.

    Returns
    -------
    numpy.ndarray of float32
        One time per entry of `percentiles`.
    """
    mass = np.asarray(mass)
    time = np.asarray(time)
    targets = mass[0] * np.asarray(percentiles)[:, None]
    # argmin returns the FIRST minimum, matching the former min() + np.where lookup.
    indices = np.argmin(np.abs(mass[None, :] - targets), axis=1)
    return time[indices].astype('float32')

In [None]:
# Cumulative mass growth of the true SFH vs. posterior chain 18.
# `time_vec` is the notebook global left by the EAGLE plotting loop.
deltat = time_vec[1:] - time_vec[:-1]
mgrowth_true, mgrowth_pred = (np.cumsum(deltat * np.flip(sfh[1:]))
                              for sfh in (true, sample[18, :, 0]))

for curve, colour, tag in ((mgrowth_true, 'blue', "true"), (mgrowth_pred, 'red', 'pred')):
    plt.plot(np.flip(time_vec[1:]), curve, color=colour, label=tag)
plt.legend()

In [None]:
# t50 (mass fraction 0.5 = index 4 of find_summaries' default percentiles) for
# every galaxy of one TNG validation batch: true SFH vs. a single posterior chain.
t50_pred = []
t50_true = []
for j in range(1):  # number of validation batches to evaluate
    try:
        data = next(dset_test)
    except StopIteration:
        # Only iterator exhaustion is swallowed. The former bare `except:` also
        # hid model/shape errors (and KeyboardInterrupt), silently leaving the
        # t50 lists empty or partial.
        break
    for ind in range(data[0][0].shape[0]):
        # 64 chains for galaxy `ind`, seeded with its true first SFH value.
        sample = np.zeros([64, 100, 1])
        true = data[0][0][ind, :, 0]
        sed = data[0][1][ind].reshape([1, -1, 1]).repeat(64, axis=0)
        sample[:, 0, 0] = true[0]
        for i in range(99):
            tmp = pixel_cnn((sample, sed)).sample()
            sample[:, i + 1, 0] = tmp[:, i + 1]
        mgrowth_true = np.cumsum(deltat * np.flip(true[1:]))
        mgrowth_pred = np.cumsum(deltat * np.flip(sample[18, 1:, 0]))  # chain 18 only
        # NOTE(review): this pairs flip(mgrowth)[k] with flip(time_vec)[k]. The
        # mass-growth plot pairs mgrowth[j] with flip(time_vec[1:])[j]. These
        # alignments differ — confirm which one is intended.
        summ_true = find_summaries(np.flip(mgrowth_true), np.flip(time_vec))
        summ_pred = find_summaries(np.flip(mgrowth_pred), np.flip(time_vec))
        t50_true.append(summ_true[4])
        t50_pred.append(summ_pred[4])
        
            

In [None]:
# Predicted vs. true t50, with the 1:1 reference line over [40, 90].
one_to_one = np.linspace(40, 90, 100)
plt.scatter(t50_true, t50_pred)
plt.plot(one_to_one, one_to_one)

# Test on EAGLE

In [None]:
# NOTE(review): this rebinds `dset_eagle` — earlier the raw TFDS dataset from the
# loading cell — to a NumPy iterator over EAGLE validation batches. The plotting
# cells that call dset_eagle.take(...) will fail if re-run after this cell.
dset_eagle = dval_eagle.as_numpy_iterator()

In [None]:
# One EAGLE validation batch: data[0][0] is the SFH input, data[0][1] the SED.
data = next(dset_eagle)

In [None]:
ind = 19
true = data[0][0][ind, :, 0]                                      # true SFH of galaxy `ind`
sed = np.repeat(data[0][1][ind].reshape([1, 125, 1]), 64, axis=0)  # its SED, once per chain
# 64 parallel chains, each seeded with the true value at the first time step.
sample = np.zeros([64, 100, 1])
sample[:, 0, 0] = true[0]

In [None]:
# Autoregressive sampling: each pass re-evaluates the model on the partially
# filled chains and keeps only the draw for the next time step.
for step in range(1, 100):
    draws = pixel_cnn((sample, sed)).sample()
    sample[:, step, 0] = draws[:, step]

In [None]:
plt.plot(true, label='true SFH')
for chain in sample[:, :, 0]:
    plt.plot(chain, color='C1', alpha=0.1)  # every posterior chain, faint
plt.plot(sample[1, :, 0], color='C1', alpha=1., label='individual sample')
posterior_mean = sample.mean(axis=0)[:, 0]
plt.plot(posterior_mean, '--', color='red', label='mean posterior')
plt.legend(loc='upper left')

# Check summaries (mass-assembly times)

In [None]:
t = (time_vec[1:] + time_vec[:-1]) / 2.   # mid-points of consecutive time entries
deltat = time_vec[1:] - time_vec[:-1]     # spacing between consecutive time entries
for shown in (t.shape, t, true.shape, deltat):
    print(shown)

In [None]:
deltat = time_vec[1:] - time_vec[:-1]
print(deltat)
# NOTE(review): unlike the other mass-growth cells, the SFH is NOT flipped here,
# and the posterior mean (not chain 18) is used — confirm which is intended.
posterior_mean = sample.mean(axis=0)[:, 0]
mgrowth_true = np.cumsum(deltat * true[1:])
mgrowth_pred = np.cumsum(deltat * posterior_mean[1:])

for curve, colour, tag in ((mgrowth_true, 'blue', "true"), (mgrowth_pred, 'red', 'pred')):
    plt.plot(np.flip(time_vec[1:]), curve, color=colour, label=tag)
plt.legend()

In [None]:
# Flip so index 0 holds the cumulative total, which find_summaries uses as reference.
summ_true, summ_pred = (find_summaries(np.flip(growth), np.flip(time_vec))
                        for growth in (mgrowth_true, mgrowth_pred))

In [None]:
# Times at the default 0.1, 0.2, ..., 0.9 mass fractions (true, then predicted).
print(summ_true)
print(summ_pred)

In [None]:
# t50 (mass fraction 0.5 = index 4 of find_summaries' default percentiles) for
# every galaxy of one EAGLE validation batch: true SFH vs. a single posterior chain.
t50_pred = []
t50_true = []
for j in range(1):  # number of validation batches to evaluate
    try:
        data = next(dset_eagle)
    except StopIteration:
        # Only iterator exhaustion is swallowed. The former bare `except:` also
        # hid model/shape errors (and KeyboardInterrupt), silently leaving the
        # t50 lists empty or partial.
        break
    for ind in range(data[0][0].shape[0]):
        # 64 chains for galaxy `ind`, seeded with its true first SFH value.
        sample = np.zeros([64, 100, 1])
        true = data[0][0][ind, :, 0]
        sed = data[0][1][ind].reshape([1, -1, 1]).repeat(64, axis=0)
        sample[:, 0, 0] = true[0]
        for i in range(99):
            tmp = pixel_cnn((sample, sed)).sample()
            sample[:, i + 1, 0] = tmp[:, i + 1]
        mgrowth_true = np.cumsum(deltat * np.flip(true[1:]))
        mgrowth_pred = np.cumsum(deltat * np.flip(sample[18, 1:, 0]))  # chain 18 only
        # NOTE(review): this pairs flip(mgrowth)[k] with flip(time_vec)[k]. The
        # mass-growth plots pair mgrowth[j] with flip(time_vec[1:])[j] — confirm
        # which alignment is intended.
        summ_true = find_summaries(np.flip(mgrowth_true), np.flip(time_vec))
        summ_pred = find_summaries(np.flip(mgrowth_pred), np.flip(time_vec))
        t50_true.append(summ_true[4])
        t50_pred.append(summ_pred[4])
        
            

In [None]:
# Predicted vs. true t50, with the 1:1 reference line over [40, 90].
one_to_one = np.linspace(40, 90, 100)
plt.scatter(t50_true, t50_pred)
plt.plot(one_to_one, one_to_one)


# Train on EAGLE

In [None]:
# Fresh, untrained model (replaces the TNG-trained one): 100 SFH steps, 125-point SED.
pixel_cnn = generate_model(n_timesteps=100, n_filters=125)
pixel_cnn.summary()

In [None]:
# The training pipeline repeats indefinitely, so an epoch is defined by steps_per_epoch.
hist = pixel_cnn.fit(
    dtrain_eagle,
    validation_data=dval_eagle,
    steps_per_epoch=1000,
    epochs=epochs,
)

In [None]:
# New iterator over the (unshuffled) EAGLE validation batches, starting again from
# the first batch, for the EAGLE-trained model. Rebinds `dset_eagle` once more.
dset_eagle = dval_eagle.as_numpy_iterator()

In [None]:
# One EAGLE validation batch: data[0][0] is the SFH input, data[0][1] the SED.
data = next(dset_eagle)

In [None]:
ind = 19
true = data[0][0][ind, :, 0]                                      # true SFH of galaxy `ind`
sed = np.repeat(data[0][1][ind].reshape([1, 125, 1]), 64, axis=0)  # its SED, once per chain
# 64 parallel chains, each seeded with the true value at the first time step.
sample = np.zeros([64, 100, 1])
sample[:, 0, 0] = true[0]

In [None]:
# NOTE(review): identical to the previous cell. Re-running it is harmless because
# `sample` is rebuilt from zeros, but one of the two cells can be removed.
ind=19
sample = np.zeros([64,100,1])
true = data[0][0][ind,:,0]
sed = data[0][1][ind].reshape([1,125,1]).repeat(64,axis=0)
# seed every chain with the true SFH value at the first time step
sample[:,0,0] = true[0]

In [None]:
# Autoregressive sampling: each pass re-evaluates the model on the partially
# filled chains and keeps only the draw for the next time step.
for step in range(1, 100):
    draws = pixel_cnn((sample, sed)).sample()
    sample[:, step, 0] = draws[:, step]

In [None]:
plt.plot(true, label='true SFH')
for chain in sample[:, :, 0]:
    plt.plot(chain, color='C1', alpha=0.1)  # every posterior chain, faint
plt.plot(sample[1, :, 0], color='C1', alpha=1., label='individual sample')
posterior_mean = sample.mean(axis=0)[:, 0]
plt.plot(posterior_mean, '--', color='red', label='mean posterior')
plt.legend(loc='upper left')

In [None]:
# t50 (mass fraction 0.5 = index 4 of find_summaries' default percentiles) for
# every galaxy of one EAGLE validation batch, using the EAGLE-trained model.
t50_pred = []
t50_true = []
for j in range(1):  # number of validation batches to evaluate
    try:
        data = next(dset_eagle)
    except StopIteration:
        # Only iterator exhaustion is swallowed. The former bare `except:` also
        # hid model/shape errors (and KeyboardInterrupt), silently leaving the
        # t50 lists empty or partial.
        break
    for ind in range(data[0][0].shape[0]):
        # 64 chains for galaxy `ind`, seeded with its true first SFH value.
        sample = np.zeros([64, 100, 1])
        true = data[0][0][ind, :, 0]
        sed = data[0][1][ind].reshape([1, -1, 1]).repeat(64, axis=0)
        sample[:, 0, 0] = true[0]
        for i in range(99):
            tmp = pixel_cnn((sample, sed)).sample()
            sample[:, i + 1, 0] = tmp[:, i + 1]
        mgrowth_true = np.cumsum(deltat * np.flip(true[1:]))
        mgrowth_pred = np.cumsum(deltat * np.flip(sample[18, 1:, 0]))  # chain 18 only
        # NOTE(review): this pairs flip(mgrowth)[k] with flip(time_vec)[k]. The
        # mass-growth plots pair mgrowth[j] with flip(time_vec[1:])[j] — confirm
        # which alignment is intended.
        summ_true = find_summaries(np.flip(mgrowth_true), np.flip(time_vec))
        summ_pred = find_summaries(np.flip(mgrowth_pred), np.flip(time_vec))
        t50_true.append(summ_true[4])
        t50_pred.append(summ_pred[4])
        
            

In [None]:
# Predicted vs. true t50, with the 1:1 reference line over [40, 90].
one_to_one = np.linspace(40, 90, 100)
plt.scatter(t50_true, t50_pred)
plt.plot(one_to_one, one_to_one)
