In [2]:
!pip install tensorflow



In [3]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

In [4]:
dftrain = pd.read_csv("data/plasticc_train_metadata.csv")

dftest = pd.read_csv("data/plasticc_test_metadata.csv")

dftest = dftest[:20000]

dftest['target'] = dftest['true_target']

dftrain = pd.concat([dftrain, dftest], axis=0, ignore_index=True)

model_nums = {90:'SN Ia', 67:'SNIa-91bg', 52:'SNIax', 42:'SNII', 62:'SNIbc', 95:'SLSN-I', 15:'TDE', 64:'KN', 88:'AGN', 92:'RRL', 65:'M-dwarf', 16:'EB', 53:'Mira', 6:'muLens-Single', 991:'muLens-Binary', 992:'ILOT', 993:'CaRT', 994:'PISN', 995:'muLens-String'}

sn_models = [90, 67, 52, 42, 62, 95, 15]

dftrain = dftrain[dftrain['target'].isin(sn_models)]

dftrain['target_names'] = dftrain['target'].map(model_nums)

In [5]:
# print(dftrain)

In [6]:
# print(dftrain[dftrain['object_id'] == 50409])

In [7]:
lcs = pd.read_csv("data/plasticc_train_lightcurves.csv")
lcs_test = pd.read_csv("data/plasticc_test_lightcurves_01.csv")

lcs = pd.concat([lcs, lcs_test[:3000000]], axis=0, ignore_index=True)

# lcs = pd.merge(lcs, dftrain, on='object_id', how='inner')
# lcs['passband']

In [8]:
# #map passband idx to passband name
# lcs['passband_name'] = ["ugrizY"[i] for i in lcs['passband']]
# lcs_detected = lcs[lcs['detected_bool'] == True]
# cols = sns.color_palette("Spectral_r", 6)
# VRO_bands = "ugrizY"

# transients = np.unique(lcs_detected['object_id'])


# fig, axs = plt.subplots(4, 5, figsize=(20, 15), sharex=True, sharey=False)
# axs = axs.ravel()

# for j in np.arange(20):
#     transient = transients[j]
#     lc = lcs_detected[lcs_detected['object_id'] == transient]
#     for i in np.arange(len(VRO_bands)):
#         band = VRO_bands[i]
#         lc_band = lc[lc['passband_name'] == band]
#         axs[j].errorbar(lc_band['mjd'] - np.nanmin(lc['mjd']), lc_band['flux'], yerr=lc_band['flux_err'], fmt='o', mec='k', c=cols[i], label=band)
#     axs[j].set_ylabel("Flux")
#     axs[j].set_title(dftrain.loc[dftrain['object_id'] == transient, 'target_names'].values[0])
# axs[-1].legend()
# axs[0].set_xlim((-10, 200));
# for idx in np.arange(15, 20):
#     axs[idx].set_xlabel("Days since Discovery");
# fig.tight_layout(w_pad=0.1, h_pad=0.3)

In [9]:
columns = ['object_id', 'mjd', 'passband', 'flux', 'flux_err']
columns2 = ['mjd', 'passband', 'flux', 'flux_err']
# filter wavelengths in angstroms
wavelengths = {
    0: 3671.0,
    1: 4827.0,
    2: 6223.0,
    3: 7546.0,
    4: 8691.0,
    5: 9712.0
}
# print(lcs['passband'])
lcs['passband'] = lcs['passband'].map(wavelengths)
# lcs['passband']

In [10]:
# Prepare Train Data
x_data = []
y_data = []

max_length = 352

for id in dftrain['object_id']:
    if len(lcs[lcs['object_id'] == id]) == 0:
        continue
    x_data.append(lcs[lcs['object_id'] == id][columns].reset_index())
    #print(x_train[-1])
    y_data.append(dftrain[dftrain['object_id'] == id][['object_id', 'true_peakmjd']].values.tolist()[0])

    # Expand light curve data to max length with 0s
    while len(x_data[-1].index) < max_length: 
        x_data[-1].loc[len(x_data[-1].index)] = [len(x_data[-1].index), id, 0, 0, 0, 0]
    x_data[-1] = x_data[-1][columns]
    #print(x_train[-1])


In [11]:
# Normalize Data
for i in range(len(x_data)):
    # try:
    #     x_train[i] = x_train[i].values
    # except AttributeError:
    #     pass
    y_data[i][1] -= 59000
    y_data[i][1] /= 2000
    for j in range(len(x_data[i])):
        if x_data[i].loc[j]['mjd'] == 0:
            continue
        x_data[i].at[j, 'mjd'] -= 59000
        x_data[i].at[j, 'mjd'] /= 2000
        x_data[i].at[j, 'passband'] -= 3671
        x_data[i].at[j, 'passband'] /= 6041

In [12]:
x_train, x_test, y_train, y_test = train_test_split(
     x_data, y_data, test_size=0.2, random_state=42)

In [13]:
x_test, x_val, y_test, y_val = train_test_split(
    x_test, y_test, test_size=0.5, random_state=45)

In [14]:
# print(y_test)

In [15]:
# print(dftrain[dftrain['object_id'] == 50409]])

In [16]:

# print(x_train)
x_train_values = []
for i in range(len(x_train)):
    x_train_values.append([])
    for row in x_train[i].values:
        x_train_values[-1].append(row[1:])
y_train_values = [np.array(y[1]) for y in y_train]

x_test_values = []
for i in range(len(x_test)):
    x_test_values.append([])
    for row in x_test[i].values:
        x_test_values[-1].append(row[1:])
y_test_values = [np.array(y[1]) for y in y_test]

x_val_values = []
for i in range(len(x_val)):
    x_val_values.append([])
    for row in x_val[i].values:
        x_val_values[-1].append(row[1:])
y_val_values = [np.array(y[1]) for y in y_val]

# dftest = pd.read_csv('data/plasticc_test_metadata.csv')

# lcs_test = pd.read_csv('data/plasticc_test_lightcurves_01.csv')
# lcs_val = pd.read_csv('data/plasticc_test_lightcurves_02.csv')

# print(lcs, lcs_test, lcs_val)

# lcs_test['passband'] = lcs_test['passband'].map(wavelengths)
# lcs_val['passband'] = lcs_val['passband'].map(wavelengths)

# dftest = dftest[dftest['true_target'].isin(sn_models)]

# dftest['target_names'] = dftest['true_target'].map(model_nums)

# # lcs_test = pd.merge(lcs_test, dftest, on='object_id', how='inner')
# # lcs_val = pd.merge(lcs_val, dftest, on='object_id', how='inner')

# x_test, x_val, y_test, y_val = [], [], [], []

In [17]:
# for id in dftest['object_id']:
#     # print(id)
#     if len(lcs_test[lcs_test['object_id'] == id]) > 0:
#         # print("test: " + id, flush=True)
#         x_test.append(lcs_test[lcs_test['object_id'] == id][columns].reset_index())
#         y_test.append(dftest[dftest['object_id'] == id]['true_peakmjd'])
#         print(len(x_test[-1].index)
#         # while len(x_test[-1].index) < max_length:
#         #     x_test[-1].loc[len(x_test[-1].index)] = [len(x_test[-1].index), 0, 0, 0, 0]
#         # x_test[-1] = x_test[-1][columns]

# print(x_test)


In [18]:
# for id in dftest['object_id']:
#     if len(lcs_val[lcs_val['object_id'] == id]) > 0:
#         # print("val: " + id, flush=True)
#         x_val.append(lcs_val[lcs_val['object_id'] == id][columns].reset_index())
#         y_val.append(dftest[dftest['object_id'] == id]['true_peakmjd'])
#         while len(x_val[-1].index) < max_length:
#             x_val[-1].loc[len(x_val[-1].index)] = [len(x_val[-1].index), 0, 0, 0, 0]
#         x_val[-1] = x_val[-1][columns]

In [19]:
# print(x_train_values[:2])
# y_train_values[:2]

In [20]:
# Build Model

import tensorflow as tf
# import tensorflow_probability as tfp

tf.random.set_seed(1)

model = tf.keras.Sequential([
    tf.keras.layers.Input((352, 4)),
    tf.keras.layers.Masking(mask_value=np.array([0, 0, 0, 0])),
    tf.keras.layers.GRU(50, return_sequences=True, activation='tanh'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.GRU(50,activation='tanh'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(units=1, activation='sigmoid')])
    # tfp.layers.DistributionLambda(lambda t: tfp.distributions.Normal(loc=t[..., :1],
                           #scale=1e-3 + tf.math.softplus(0.05 * t[...,1:])))])


model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),loss='mean_squared_error')
# print(model.summary())

In [21]:
print(len(x_train_values), len(y_train_values))

11132 11132


In [22]:
### Train Model
model.fit(np.array(x_train_values), np.array(y_train_values), epochs=50, validation_data=(np.array(x_val_values), np.array(y_val_values)))



Epoch 1/50
[1m348/348[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m61s[0m 169ms/step - loss: 0.0152 - val_loss: 0.0038
Epoch 2/50
[1m348/348[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m58s[0m 168ms/step - loss: 0.0041 - val_loss: 0.0030
Epoch 3/50
[1m348/348[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m59s[0m 169ms/step - loss: 0.0031 - val_loss: 0.0026
Epoch 4/50
[1m348/348[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m59s[0m 169ms/step - loss: 0.0026 - val_loss: 0.0019
Epoch 5/50
[1m348/348[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m60s[0m 174ms/step - loss: 0.0019 - val_loss: 0.0014
Epoch 6/50
[1m348/348[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m61s[0m 174ms/step - loss: 0.0013 - val_loss: 0.0011
Epoch 7/50
[1m348/348[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m63s[0m 180ms/step - loss: 0.0012 - val_loss: 8.9616e-04
Epoch 8/50
[1m348/348[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m62s[0m 178ms/step - loss: 0.0010 - val_loss: 8.8175e-04
Epoch 9/

<keras.src.callbacks.history.History at 0x3e491fc50>

In [23]:
y_pred = model.predict(np.array(x_test_values))

[1m44/44[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 38ms/step


In [24]:
print(y_test, y_pred)

[[95564.0, 0.6931620000000003], [50409.0, 0.7065509999999995], [8079158.0, 0.5101365000000005], [105025507.0, 0.5964690000000009], [12362.0, 0.6021365000000005], [94059.0, 0.4721915000000008], [44943.0, 0.4494455000000016], [99050.0, 0.5197734999999993], [88877.0, 0.49753499999999984], [135054.0, 0.6221310000000012], [36685.0, 0.6106035000000011], [71828.0, 0.6036579999999995], [81798.0, 0.6612849999999999], [7044535.0, 0.8096464999999989], [88876.0, 0.7673420000000005], [39613.0, 0.8285275000000002], [17207.0, 0.46509000000000017], [93923383.0, 0.5809160000000011], [43465.0, 0.6326330000000017], [23245045.0, 0.30435150000000066], [9705.0, 0.7856955000000017], [5709606.0, 0.7562829999999995], [24091.0, 0.6680155000000013], [57363.0, 0.4512285000000011], [96552995.0, 0.2953669999999984], [130085491.0, 0.7985155000000014], [11403.0, 0.5389919999999984], [15870.0, 0.7730370000000003], [10989329.0, 0.4081270000000004], [8646.0, 0.4455099999999984], [11928.0, 0.42783000000000176], [2299.0, 

In [25]:
def display_object(object_id, lightcurve_data, metadata, predicted=0):
    lightcurve_data = lightcurve_data[lightcurve_data['object_id'] == object_id]
    passbands = []
    true_peaktime = metadata[metadata['object_id'] == object_id]['true_peakmjd']
    print(true_peaktime)

    for freq in range(6):
        passbands.append(lightcurve_data[lightcurve_data['passband'] == freq][['mjd', 'flux']])

    for p in passbands:
        plt.scatter(p['mjd'], p['flux'])
        print(true_peaktime)
        plt.scatter(true_peaktime, [0])
        if predicted > 0:
            plt.scatter([predicted], [0])
        

    # plt.xlabel("Modified Julien Date")

    # plt.ylabel("Flux")
    
    # plt.show()
    plt.show()


In [None]:
display_object(95564, lcs, dftrain, 0.6931620000000003*2000+59000)