## 08/11/2023

Transformer architecture adapted from the Keras time-series classification example:
https://keras.io/examples/timeseries/timeseries_transformer_classification/

In [None]:
import numpy as np
from tensorflow import keras
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.backend import clear_session
from contextlib import redirect_stdout
from tensorflow.keras.callbacks import (
    CSVLogger,
    EarlyStopping,
    ModelCheckpoint,
    ReduceLROnPlateau,
)
import warnings
from sklearn.preprocessing import OneHotEncoder
import time
from astronet.t2.model import T2Model
import optuna
from optuna.samplers import TPESampler

In [None]:
#!pip install optuna

In [None]:
# Train on my own data (ZTF simulations).
# Only the 100GP arrays are used so that no padding is needed; otherwise the
# 0pt2GP arrays would be used.
data_dir = '/Users/alexgagliano/Documents/Research/HostClassifier/transformer/ZTF_Data/'
file_tag = 'ZTF_Sim_FullSliced_Padded_100GP_hostPhotTrue_30Cut'


def _load_npz_array(prefix):
    """Return the ``arr_0`` array from ``<data_dir><prefix>_<file_tag>.npz``, closing the archive."""
    with np.load(data_dir + '%s_%s.npz' % (prefix, file_tag)) as archive:
        return archive['arr_0']


X_train = _load_npz_array('X_train')
X_test = _load_npz_array('X_test')
y_train = _load_npz_array('y_train')
y_test = _load_npz_array('y_test')

params = {}
params['class_weight'] = {0: 1, 1: 1, 2: 1}  # not weighting the classes

# Per-sample weights by phase only. The value of channel 0 at the last time step
# (presumably phase -- TODO confirm) picks weight 10 / 5 / 1 for the half-open
# bins [.., 3), [3, 15), [15, ..); these match the test-set phase bins used later.
# NOTE(review): these weights are not passed to model.fit anywhere in this notebook.
last_phase = X_train[:, -1, 0]
weights = np.ones(len(X_train))
weights[last_phase < 3] = 10
weights[(last_phase >= 3) & (last_phase < 15)] = 5
for cls, cls_weight in params['class_weight'].items():
    weights[y_train == cls] *= cls_weight

# remove the fits with only 4 or fewer datapoints
# x_train_GP = [X_train[x, :, 0][5] != 0 for x in np.arange(len(X_train))]
# x_test_GP = [X_test[x, :, 0][5] != 0 for x in np.arange(len(X_test))]
# X_train = X_train[x_train_GP]
# y_train = y_train[x_train_GP]
# X_test = X_test[x_test_GP]
# y_test = y_test[x_test_GP]

# Randomly keep `subset_frac` (50%) of the train and test sets. Seeded so that
# repeated runs of the (seeded) hyperparameter study see the same data.
subset_frac = 0.5
subset_rng = np.random.default_rng(325)
idx_train = subset_rng.choice(len(X_train), size=int(subset_frac * len(X_train)), replace=False)
idx_test = subset_rng.choice(len(X_test), size=int(subset_frac * len(X_test)), replace=False)

# Subset the weights together with the samples so they stay aligned.
X_train, y_train, weights = X_train[idx_train], y_train[idx_train], weights[idx_train]
X_test, y_test = X_test[idx_test], y_test[idx_test]

In [None]:
# Binarize labels. Fit the encoder on the training labels only and reuse it for
# the test labels, so both one-hot matrices share the same column order even
# when a class happens to be missing from one of the random subsets. (An
# unseen test label now raises instead of silently shifting columns.)
label_encoder = OneHotEncoder(max_categories=3, sparse_output=False)
y_train = label_encoder.fit_transform(y_train.reshape(-1, 1))
y_test = label_encoder.transform(y_test.reshape(-1, 1))

num_classes = y_train.shape[1]

print(X_train.shape, y_train.shape)
print(X_test.shape, y_test.shape)

def objective(trial):
    """Optuna objective: train one T2Model configuration and return its validation loss.

    Hyperparameters sampled from ``trial``: ``hidden_layer_size`` (feed-forward
    width), ``embed_dim``, ``num_layers``, ``num_heads`` and ``droprate``.
    Reads the module-level ``X_train``, ``y_train``, ``X_test``, ``y_test`` and
    ``num_classes``.

    Side effects: writes this trial's best (lowest ``val_loss``) weights to a
    per-trial checkpoint and records that path in
    ``trial.user_attrs['weights_path']``.

    Returns
    -------
    float
        Loss (``score[0]`` of ``model.evaluate``) on ``(X_test, y_test)`` using the
        best checkpointed weights. NOTE(review): the same split drives checkpoint
        selection, hyperparameter selection and the final metrics, so those
        metrics are optimistic.
    """
    # Clear clutter from previous Keras session graphs.
    clear_session()

    BATCH_SIZE = 64
    EPOCHS = 10

    outputfn = '100GP_Transformer_FullSimTraining_FullOpt'
    outputdir = '/Users/alexgagliano/Documents/Research/HostClassifier/transformer/'

    # Hidden layer size of the feed-forward network inside each transformer block.
    ff_dim = trial.suggest_int("hidden_layer_size", 32, 256)
    # Multiples of 32 are always divisible by the 2/4/8 heads sampled below.
    embed_dim = trial.suggest_int("embed_dim", 32, 256, step=32)
    num_layers = trial.suggest_categorical("num_layers", [2, 4, 8])
    num_heads = trial.suggest_categorical("num_heads", [2, 4, 8])
    droprate = trial.suggest_float("droprate", 0.2, 0.8)
    # Number of filters in the ConvEmbedding block; should be equal to embed_dim.
    num_filters = embed_dim

    _, timesteps, num_features = X_train.shape  # (samples, TIMESTEPS, num_features)
    input_shape = (BATCH_SIZE, timesteps, num_features)

    model = T2Model(
        input_dim=input_shape,
        embed_dim=embed_dim,
        num_heads=num_heads,
        ff_dim=ff_dim,
        num_filters=num_filters,
        num_classes=num_classes,
        num_layers=num_layers,
        droprate=droprate,
    )

    opt = tf.keras.optimizers.legacy.Adam(learning_rate=1.e-3)
    model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=["acc"])

    # One checkpoint per trial: a shared path is overwritten by every trial and
    # would leave only the last trial's weights on disk.
    weights_path = outputdir + '/models/Model_%s_trial%d_CheckpointWeights.sav' % (outputfn, trial.number)
    trial.set_user_attr("weights_path", weights_path)

    mc = ModelCheckpoint(weights_path, monitor='val_loss', mode='min', verbose=1,
                         save_weights_only=True, save_best_only=True)
    es = EarlyStopping(min_delta=0.001, mode="min", monitor="val_loss", patience=20,
                       restore_best_weights=True, verbose=1)

    model.fit(
        X_train,
        y_train,
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
        validation_data=(X_test, y_test),
        callbacks=[mc, es],
        verbose=2,
    )

    # patience (20) exceeds EPOCHS (10), so EarlyStopping never stops training and
    # never restores its best weights. Reload the best checkpoint explicitly so the
    # returned score describes the weights actually saved for this trial.
    # mc.best stays infinite if no epoch was ever saved (e.g. val_loss was NaN).
    if np.isfinite(mc.best):
        model.load_weights(weights_path)

    model.build_graph(input_shape)
    model.summary()  # summary() prints itself and returns None

    # Evaluate the model on the validation set.
    score = model.evaluate(X_test, y_test, verbose=0)

    # Return the loss; switch to score[1] to return the accuracy instead.
    return score[0]

In [None]:
if __name__ == "__main__":
    # A fixed seed makes the TPE sampler propose the same hyperparameter sequence each run.
    study = optuna.create_study(direction="minimize", sampler=TPESampler(seed=325))
    study.optimize(objective, n_trials=20)

    # Report the study size first, then details of the best trial.
    print(f"Number of finished trials: {len(study.trials)}")

    print("Best trial:")
    trial = study.best_trial

    print(f"  Value: {trial.value}")

    print("  Params: ")
    for param_name, param_value in trial.params.items():
        print(f"    {param_name}: {param_value}")

In [None]:
# Rebuild the best trial's architecture and load its checkpointed weights.
# `model`, `outputdir` and `outputfn` exist only as locals inside objective(), so
# they are defined here at module level for the evaluation/plotting cells below.
# These constants must match the ones used in objective().
outputdir = '/Users/alexgagliano/Documents/Research/HostClassifier/transformer/'
outputfn = '100GP_Transformer_FullSimTraining_FullOpt'
BATCH_SIZE = 64

best_trial = study.best_trial
best_params = best_trial.params
_, timesteps, num_features = X_train.shape
input_shape = (BATCH_SIZE, timesteps, num_features)

clear_session()
model = T2Model(
    input_dim=input_shape,
    embed_dim=best_params["embed_dim"],
    num_heads=best_params["num_heads"],
    ff_dim=best_params["hidden_layer_size"],
    num_filters=best_params["embed_dim"],  # objective() ties num_filters to embed_dim
    num_classes=num_classes,
    num_layers=best_params["num_layers"],
    droprate=best_params["droprate"],
)
model.compile(
    loss='categorical_crossentropy',
    optimizer=tf.keras.optimizers.legacy.Adam(learning_rate=1.e-3),
    metrics=["acc"],
)
model.build_graph(input_shape)

# Prefer the per-trial checkpoint path if objective() recorded one. Otherwise fall
# back to the shared path, which holds only the LAST trial's weights and will not
# match the best trial's architecture unless the two trials coincide.
weights_path = best_trial.user_attrs.get(
    "weights_path",
    outputdir + '/models/Model_%s_CheckpointWeights.sav' % outputfn,
)
model.load_weights(weights_path)
loss, accuracy = model.evaluate(X_test, y_test)

In [None]:
# Marker cell: prints a fixed string.
marker = "test"
print(marker)

In [None]:
# Evaluate the results as a function of phase.
# Convert one-hot test labels back to integer class indices. Guarded so that
# re-running this cell does not argmax already-converted (1-D) labels.
if y_test.ndim > 1:
    y_test = y_test.argmax(axis=1)

# Largest value of channel 0 per test sample (presumably the latest phase
# covered -- TODO confirm). Computed once and reused for every bin.
max_phase = np.nanmax(X_test[:, :, 0], axis=1)

# Half-open bins [.., 3), [3, 15), [15, ..) so that samples at exactly 3 or 15
# are not silently dropped from every bin. Rows whose phase is all-NaN compare
# False and fall in no bin.
in_first3 = max_phase < 3
in_mid15 = (max_phase >= 3) & (max_phase < 15)
in_last15 = max_phase >= 15

X_test_first3, y_test_first3 = X_test[in_first3], y_test[in_first3]
X_test_mid15, y_test_mid15 = X_test[in_mid15], y_test[in_mid15]
X_test_last15, y_test_last15 = X_test[in_last15], y_test[in_last15]

In [None]:
######################### calculating statistics ######################
# Import the submodule explicitly: `import sklearn` alone does not guarantee that
# `sklearn.metrics` is loaded. This still binds the name `sklearn` for later cells.
import sklearn.metrics
from sklearn.model_selection import StratifiedKFold

# Macro F1 per phase bin, reported as median +/- std over 5 stratified folds.
# The model is fixed; the folds only partition each evaluation bin.
for X, y in [(X_test_first3, y_test_first3), (X_test_mid15, y_test_mid15), (X_test_last15, y_test_last15)]:
    f1 = []
    cv = StratifiedKFold(n_splits=5)
    for train, test in cv.split(X, y):
        f1.append(sklearn.metrics.f1_score(y[test], np.argmax(model.predict(X[test]), axis=1), average='macro'))
    nanmed = np.nanmedian(f1)
    nanstd = np.nanstd(f1)
    print("f1: %.2f +/- %.2f" % (nanmed, nanstd))

In [None]:
# Macro one-vs-rest AUROC for each phase bin, summarised as the median and then
# the std over five stratified folds of that bin.
phase_bins = [(X_test_first3, y_test_first3), (X_test_mid15, y_test_mid15), (X_test_last15, y_test_last15)]
for X, y in phase_bins:
    splitter = StratifiedKFold(n_splits=5)
    auroc = [
        sklearn.metrics.roc_auc_score(y[held_out], model.predict(X[held_out]), average='macro', multi_class='ovr')
        for _, held_out in splitter.split(X, y)
    ]
    print(np.nanmedian(auroc))
    print(np.nanstd(auroc))

In [None]:
# Macro-averaged precision for each phase bin: median, then std, over 5 stratified folds.
for X, y in ((X_test_first3, y_test_first3), (X_test_mid15, y_test_mid15), (X_test_last15, y_test_last15)):
    prec = []
    for _, fold_idx in StratifiedKFold(n_splits=5).split(X, y):
        fold_probs = model.predict(X[fold_idx])
        fold_pred = np.argmax(fold_probs, axis=1)
        prec.append(sklearn.metrics.precision_score(y[fold_idx], fold_pred, average='macro'))
    print(np.nanmedian(prec))
    print(np.nanstd(prec))

In [None]:
# Macro-averaged recall for each phase bin: median, then std, over 5 stratified folds.
for X, y in [(X_test_first3, y_test_first3), (X_test_mid15, y_test_mid15), (X_test_last15, y_test_last15)]:
    rec = []
    cv = StratifiedKFold(n_splits=5)
    for _, test in cv.split(X, y):
        # Score only this fold's held-out slice, so the spread reflects
        # fold-to-fold variation (scoring the whole bin gives 5 identical values).
        rec.append(sklearn.metrics.recall_score(y[test], np.argmax(model.predict(X[test]), axis=1), average='macro'))
    print(np.nanmedian(rec))
    print(np.nanstd(rec))

In [None]:
# Balanced accuracy for each phase bin, reported as median then std across 5 stratified folds.
for X, y in [(X_test_first3, y_test_first3), (X_test_mid15, y_test_mid15), (X_test_last15, y_test_last15)]:
    bacc = []
    fold_iter = StratifiedKFold(n_splits=5).split(X, y)
    for _, idx in fold_iter:
        y_hat = np.argmax(model.predict(X[idx]), axis=1)
        bacc.append(sklearn.metrics.balanced_accuracy_score(y[idx], y_hat))
    med, spread = np.nanmedian(bacc), np.nanstd(bacc)
    print(med)
    print(spread)

In [None]:
# Plain accuracy for each phase bin, reported as median then std across 5 stratified folds.
score_fn = sklearn.metrics.accuracy_score
bins_to_score = [(X_test_first3, y_test_first3), (X_test_mid15, y_test_mid15), (X_test_last15, y_test_last15)]
for X, y in bins_to_score:
    acc = [
        score_fn(y[idx], np.argmax(model.predict(X[idx]), axis=1))
        for _, idx in StratifiedKFold(n_splits=5).split(X, y)
    ]
    print(np.nanmedian(acc))
    print(np.nanstd(acc))

In [None]:
def plot_PR_wCV(model, params, ax, X, y, encoding_dict, fnstr='', plotpath='./', save=True):
    """Plot mean per-class precision-recall curves over stratified folds of (X, y).

    The already-trained ``model`` is not refit: the folds only split the
    evaluation set so that a +/-1 std band can be drawn around each mean curve.

    Parameters
    ----------
    model : object
        Anything with ``predict(X)`` returning per-class scores of shape
        (n_samples, n_outputs). Assumes output column ``k`` scores integer
        label ``k`` -- TODO confirm for new models.
    params : dict
        Must contain ``'Nclass'``, the number of palette colours to generate.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    X : numpy.ndarray
        Model inputs, indexed along axis 0.
    y : numpy.ndarray
        Integer class labels aligned with ``X``.
    encoding_dict : dict
        Maps integer label to display name; ``"SN"`` is stripped in the legend.
    fnstr : str
        Suffix for the saved file name.
    plotpath : str
        Directory to save into.
    save : bool
        If True, save the figure that owns ``ax`` as a PNG.

    Returns
    -------
    accuracy : float
        Mean of the per-fold accuracies, in percent.
    allAcc : list of float
        Per-fold accuracies, in percent.
    """
    import seaborn as sns
    from sklearn.metrics import auc, precision_recall_curve
    from sklearn.model_selection import StratifiedKFold

    nsplit = 5
    cv = StratifiedKFold(n_splits=nsplit)
    classes = np.unique(y)
    colors = sns.color_palette('Dark2', params['Nclass'])
    mean_r = np.linspace(0, 1, 100)

    # The model is fixed, so predict once per fold and reuse the scores for every class.
    fold_results = []
    allAcc = []
    for _, test in cv.split(X, y):
        probas = model.predict(X[test])
        predictions = np.argmax(probas, axis=1)
        allAcc.append(np.sum(predictions == y[test]) / len(y[test]) * 100)
        fold_results.append((y[test], probas))
    accuracy = np.mean(allAcc)

    for cls in classes:
        # Index scores, colours and names by the label itself, not by its position
        # in `classes`, so a bin missing a class still maps each curve correctly.
        label = int(cls)
        ps = []
        aucs = []
        for y_fold, probas in fold_results:
            precision, recall, _ = precision_recall_curve(y_fold, probas[:, label], pos_label=cls)
            # recall is returned in decreasing order; np.interp needs increasing x.
            ps.append(np.interp(mean_r, recall[::-1], precision[::-1]))
            aucs.append(auc(recall, precision))
        mean_p = np.mean(ps, axis=0)
        mean_auc = auc(mean_r, mean_p)
        std_auc = np.std(aucs)
        name = encoding_dict[label].replace("SN", "")
        if std_auc < 0.01:
            curve_label = r'%s (%0.2f $\pm$ <0.01)' % (name, mean_auc)
        else:
            curve_label = r'%s (%0.2f $\pm$ %0.2f)' % (name, mean_auc, std_auc)
        ax.plot(mean_r, mean_p, color=colors[label], label=curve_label, lw=2, alpha=.8)
        std_p = np.std(ps, axis=0)
        ps_upper = np.minimum(mean_p + std_p, 1)
        ps_lower = np.maximum(mean_p - std_p, 0)
        ax.fill_between(mean_r, ps_lower, ps_upper, color=colors[label], alpha=.3)

    # Curves are drawn as precision (y) against recall (x).
    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.legend()
    ax.set_xlim([-0.05, 1.05])
    ax.set_ylim([-0.05, 1.05])
    if save:
        ax.figure.savefig(plotpath + "/Combined_MeanPR_Curve_%s.png" % fnstr, dpi=150)
    return accuracy, allAcc


def plot_ROC_wCV(model, params, ax, X, y, encoding_dict, fnstr='', plotpath='./', save=True):
    """Plot mean per-class ROC curves over stratified folds of (X, y).

    The already-trained ``model`` is not refit: the folds only split the
    evaluation set so that a +/-1 std band can be drawn around each mean curve.

    Parameters
    ----------
    model : object
        Anything with ``predict(X)`` returning per-class scores of shape
        (n_samples, n_outputs). Assumes output column ``k`` scores integer
        label ``k`` -- TODO confirm for new models.
    params : dict
        Must contain ``'Nclass'``, the number of palette colours to generate.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    X : numpy.ndarray
        Model inputs, indexed along axis 0.
    y : numpy.ndarray
        Integer class labels aligned with ``X``.
    encoding_dict : dict
        Maps integer label to display name; ``"SN"`` is stripped in the legend.
    fnstr : str
        Suffix for the saved file name.
    plotpath : str
        Directory to save into.
    save : bool
        If True, save the figure that owns ``ax`` as a PNG.

    Returns
    -------
    accuracy : float
        Mean of the per-fold accuracies, in percent.
    allAcc : list of float
        Per-fold accuracies, in percent.
    """
    import seaborn as sns
    from sklearn.metrics import auc, roc_curve
    from sklearn.model_selection import StratifiedKFold

    nsplit = 5
    cv = StratifiedKFold(n_splits=nsplit)
    classes = np.unique(y)
    colors = sns.color_palette('Dark2', params['Nclass'])
    mean_fpr = np.linspace(0, 1, 100)

    # The model is fixed, so predict once per fold and reuse the scores for every class.
    fold_results = []
    allAcc = []
    for _, test in cv.split(X, y):
        probas = model.predict(X[test])
        predictions = np.argmax(probas, axis=1)
        allAcc.append(np.sum(predictions == y[test]) / len(y[test]) * 100)
        fold_results.append((y[test], probas))
    accuracy = np.mean(allAcc)

    for cls in classes:
        # Index scores, colours and names by the label itself, not by its position
        # in `classes`, so a bin missing a class still maps each curve correctly.
        label = int(cls)
        tprs = []
        aucs = []
        for y_fold, probas in fold_results:
            fpr, tpr, _ = roc_curve(y_fold, probas[:, label], pos_label=cls)
            tprs.append(np.interp(mean_fpr, fpr, tpr))
            tprs[-1][0] = 0.0  # pin the interpolated curve to the origin
            aucs.append(auc(fpr, tpr))
        mean_tpr = np.mean(tprs, axis=0)
        mean_tpr[-1] = 1.0  # and to (1, 1)
        mean_auc = auc(mean_fpr, mean_tpr)
        std_auc = np.std(aucs)
        name = encoding_dict[label].replace("SN", "")
        if std_auc < 0.01:
            curve_label = r'%s (%0.2f $\pm$ <0.01)' % (name, mean_auc)
        else:
            curve_label = r'%s (%0.2f $\pm$ %0.2f)' % (name, mean_auc, std_auc)
        ax.plot(mean_fpr, mean_tpr, color=colors[label], label=curve_label, lw=2, alpha=.8)
        std_tpr = np.std(tprs, axis=0)
        tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
        tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
        ax.fill_between(mean_fpr, tprs_lower, tprs_upper, color=colors[label], alpha=.3)

    # Chance-level diagonal.
    ax.plot([0, 1], [0, 1], linestyle='--', lw=2, color='k', alpha=.8)
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.legend()
    ax.set_xlim([-0.05, 1.05])
    ax.set_ylim([-0.05, 1.05])
    if save:
        ax.figure.savefig(plotpath + "/Combined_MeanROC_Curve_%s.png" % fnstr, dpi=150)
    return accuracy, allAcc

In [None]:
import matplotlib.pyplot as plt 
import seaborn as sns
import pandas as pd
from sklearn.metrics import precision_recall_curve, roc_curve, auc, confusion_matrix
from numpy import interp
sns.set_context("talk")

outdir = '/Users/alexgagliano/Documents/Research/HostClassifier/transformer/'
params = {'Nclass': 3}
encoding_dict = {0: 'SN II', 1: 'SN Ia', 2: 'SN Ibc'}
# Output-file stem; must match `outputfn` inside objective(). Defined here because
# objective() only creates it as a local, so it is otherwise undefined at module level.
outputfn = '100GP_Transformer_FullSimTraining_FullOpt'

# (file-name suffix, inputs, labels) for each phase bin, in plotting order.
plot_bins = [
    ("_First3", X_test_first3, y_test_first3),
    ("_Mid15", X_test_mid15, y_test_mid15),
    ("_Last15", X_test_last15, y_test_last15),
]

######### ROC curves ################
for suffix, X_bin, y_bin in plot_bins:
    fig, c_ax = plt.subplots(1, 1, figsize=(8, 8))
    plot_ROC_wCV(model, params, c_ax, X_bin, y_bin, encoding_dict,
                 fnstr=outputfn + suffix, plotpath=outdir + '/plots/', save=True)

######### Precision-recall curves ################
for suffix, X_bin, y_bin in plot_bins:
    fig, c_ax = plt.subplots(1, 1, figsize=(8, 8))
    plot_PR_wCV(model, params, c_ax, X_bin, y_bin, encoding_dict,
                fnstr=outputfn + suffix, plotpath=outdir + '/plots/', save=True)