In [None]:
print("Importing...")
from pprint import pprint
import json
import sys
import os
from shutil import rmtree
from copy import deepcopy
from os.path import join
from importlib import reload

from os.path import isdir
from os.path import join
from os import mkdir
from shutil import rmtree

import numpy as np
import pandas as pd
from scipy import stats
from matplotlib import pyplot as plt
from sklearn.model_selection import ParameterGrid
from sklearn.metrics import f1_score

from tensorflow.keras.utils import to_categorical
import tensorflow as tf
from tensorflow import keras
from keras import callbacks
from keras import losses
from keras import metrics
from keras import optimizers
from keras import layers
from keras import regularizers
import tensorflow_addons as tfa

# Put the sibling `../scs` directory (resolved against the current working
# directory) at the front of the module search path. This only works when the
# notebook is run from a directory whose parent contains `scs`.
sys.path.insert(0, os.path.join(os.getcwd(), "../scs"))
# NOTE(review): the modules below are presumably the project-local ones made
# importable by the path entry above -- confirm.
import scs_config as scsc
import data_loading as dl
import data_degrading as dd
import data_preparation as dp
import data_augmentation as da
import data_plotting as dplt
import learn
import lr_schedules
import hp_sets
import prepare_dataset

In [None]:
def run_SCS(dir_model, R, hp, file_raw_data=None, resume=False):
    """Prepare data for, train, and evaluate one model inside `dir_model`.

    Steps, in order: create the output directories, delete any stale results
    and hyper-parameter files, (unless `resume`) build the train/test
    dataframes and write them to parquet, reload them from parquet, build and
    compile the model, train it, then write the evaluation results and the
    loss-curve figure.

    NOTE(review): the helpers used here (`prepare_R_data`, `load_sn_data`,
    `extract`, `add_dim`, `write_json`, `get_model`, `compile_model`,
    `get_lr_schedule`, `get_callbacks`, `train`, `evaluate`) are referenced by
    bare name and must be defined or imported elsewhere in the notebook --
    confirm they are in scope before calling this function.

    Arguments:
        dir_model: Directory that holds everything for this model. The
            sub-directories `data` and `results` are created if missing.
        R: Forwarded unchanged to `prepare_R_data`.
        hp: Mapping of hyper-parameters. Keys read directly here:
            "phase_range", "ptp_range", "wvl_range", "train_frac",
            "noise_scale", "spike_scale", "max_spikes", "random_state",
            "add_dim", "swap", "lr0", "epochs", "batch_size". The whole
            mapping is also passed to `get_model` and `get_lr_schedule` and
            written to `<dir_model>/hp.json`.
        file_raw_data: Forwarded to `prepare_R_data`; only used when `resume`
            is False.
        resume: When True, skip dataset preparation and reuse the parquet
            files already present in `<dir_model>/data`.

    Returns:
        None. Outputs are written to disk under `dir_model`.
    """
    dir_model_data = join(dir_model, "data")
    dir_model_results = join(dir_model, "results")

    file_model_history = join(dir_model, "history.log")
    file_model_hp = join(dir_model, "hp.json")
    file_model_results = join(dir_model_results, "results.json")
    file_model_curves = join(dir_model_results, "curves.pdf")
    file_df_trn = join(dir_model_data, "df_trn.parquet")
    file_df_tst = join(dir_model_data, "df_tst.parquet")

    # Every file written below lives in one of these directories; create them
    # up front so the writes cannot fail on a missing directory.
    os.makedirs(dir_model_data, exist_ok=True)
    os.makedirs(dir_model_results, exist_ok=True)

    # Drop outputs of a previous run so a failure part-way through cannot
    # leave results that look like they belong to this run.
    for stale_file in (file_model_results, file_model_hp):
        if os.path.isfile(stale_file):
            os.remove(stale_file)

    if not resume:
        df_trn, df_tst = prepare_R_data(
            R,
            file_raw_data,
            phase_range=hp["phase_range"],
            ptp_range=hp["ptp_range"],
            wvl_range=hp["wvl_range"],
            train_frac=hp["train_frac"],
            noise_scale=hp["noise_scale"],
            spike_scale=hp["spike_scale"],
            max_spikes=hp["max_spikes"],
            random_state=hp["random_state"],
        )
        df_trn.to_parquet(file_df_trn)
        df_tst.to_parquet(file_df_tst)

    # Always reload from disk so fresh and resumed runs go through the same
    # loading path.
    df_trn = load_sn_data(file_df_trn)
    df_tst = load_sn_data(file_df_tst)

    # TODO: Add a function call here to generate some summary statistics
    # and/or plots based on df_trn and df_tst.

    # Only the arrays and the class count are used below; `num_classes` from
    # the test set overwrites the one from the training set, as before.
    Xtrn, Ytrn, _, _, num_classes = extract(df_trn)
    Xtst, Ytst, _, _, num_classes = extract(df_tst)
    if hp["add_dim"]:
        Xtrn = add_dim(Xtrn, swap=hp["swap"])
        Xtst = add_dim(Xtst, swap=hp["swap"])

    write_json(hp, file_model_hp)

    input_shape = Xtrn.shape[1:]
    model = get_model(input_shape, num_classes, hp)
    model.summary()

    compile_model(model, num_classes, hp["lr0"])
    lr_schedule = get_lr_schedule(hp)
    # Named `training_callbacks` so the imported `callbacks` module is not
    # shadowed inside this function.
    training_callbacks = get_callbacks(dir_model, lr_schedule)

    history = train(
        model,
        Xtrn,
        Ytrn,
        Xtst,
        Ytst,
        hp["epochs"],
        hp["batch_size"],
        training_callbacks,
    )

    results = evaluate(model, Xtrn, Ytrn, Xtst, Ytst, verbose=0)
    write_json(results, file_model_results)

    # NOTE(review): `history.log` is presumably written during training by one
    # of the callbacks from `get_callbacks` -- confirm.
    log = pd.read_csv(file_model_history)
    fig = dplt.plot_loss(log)
    fig.savefig(file_model_curves)
    fig.clf()


In [None]:
# `R` is passed to `prepare_dataset` and is part of the model directory name.
R = 100
data_dir_original = "/home/2649/repos/SCS/data/"
dir_models = "/lustre/lrspec/users/2649/models/transformer_testing"

# Start from a clean model directory: an existing `dir_model` tree is deleted
# unconditionally and then recreated together with its sub-directories.
dir_model = join(dir_models, f"{R}_dev")
dir_backup = join(dir_model, "backup")
dir_model_data = join(dir_model, "data")
if isdir(dir_model):
    rmtree(dir_model)
# `os.makedirs` also creates missing parents (e.g. `dir_models`), where a
# plain `mkdir` would raise.
for directory in (dir_model, dir_backup, dir_model_data):
    os.makedirs(directory, exist_ok=True)

# NOTE(review): these are presumably the files `prepare_dataset` writes into
# `dir_model_data` -- confirm the names against that function.
file_trn = join(dir_model_data, "sn_data_trn.RPA.parquet")
file_tst = join(dir_model_data, "sn_data_tst.RP.parquet")

# Work on a copy so the shared defaults in `scsc` are never modified.
hp = deepcopy(scsc.default_hyper_parameters)

hp["train_frac"] = 0.80
hp["noise_scale"] = 0.15848931924611134
hp["spike_scale"] = 1.045639552591273
hp["max_spikes"] = 3

# Prepare the dataset from the original dataset dataframe `sn_data_file`.
# NOTE(review): the four directory arguments all point at `dir_model_data`;
# check the `prepare_dataset` signature for what each one is used for.
sn_data_file = join(data_dir_original, "sn_data.parquet")
prepare_dataset.prepare_dataset(
    R,
    sn_data_file,
    dir_model_data,
    dir_model_data,
    dir_model_data,
    dir_model_data,
    hp["phase_range"],
    hp["ptp_range"],
    hp["wvl_range"],
    hp["train_frac"],
    hp["noise_scale"],
    hp["spike_scale"],
    hp["max_spikes"],
    random_state=hp["random_state"],
)

# Load the dataset.
df_trn = dl.load_sn_data(file_trn)
df_tst = dl.load_sn_data(file_tst)
dataset, num_wvl, num_classes = learn.prepare_datasets_for_training(
    df_trn, df_tst
)
Xtrn, Ytrn, Xtst, Ytst = dataset

In [None]:
# Show the shapes of `Xtrn` and `Ytrn` as unpacked from the prepared dataset.
Xtrn.shape, Ytrn.shape

In [None]:
# Show the shapes of `Xtst` and `Ytst` as unpacked from the prepared dataset.
Xtst.shape, Ytst.shape

In [None]:
# Re-execute the `learn` module so edits made to it on disk are picked up
# without restarting the kernel.
reload(learn)

# Keyword arguments forwarded unchanged to `learn.devmodel`.
devmodel_settings = {
    "num_wvls": Xtrn.shape[1],
    "num_classes": num_classes,
    "initial_projection": 100,
    "num_transformer_blocks": 1,
    "num_heads": 4,
    "key_dim": 4,
    "filters": 512,
    "num_feed_forward_layers": 3,
    "feed_forward_layer_size": 1024,
    "kr_l2": 0,
    "br_l2": 0,
    "ar_l2": 0,
    "dropout_attention": 0,
    "dropout_projection": 0,
    "dropout_feed_forward": 0,
}
model = learn.devmodel(**devmodel_settings)

In [None]:
# Call `summary()` on the model built by `learn.devmodel` to inspect it.
model.summary()