Advice from fed

In linear space, noise is a multiplicative factor, so I need to estimate a noise profile from each spectrum and scale it. To do that, smooth the original spectrum and take the difference between the original and the smoothed spectrum; those residuals serve as the noise profile (until we can talk to Marc and get the code that extracts the noise profile more cleverly).

This can also be done additively in log space.

In [16]:
import sys
from os.path import isfile

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from scipy.signal import savgol_filter
from sklearn.metrics import confusion_matrix

from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.metrics import CategoricalAccuracy
from tensorflow.keras.optimizers import Nadam
from tensorflow.keras import callbacks
from tensorflow_addons.metrics import F1Score

# My packages
sys.path.insert(0, "/Users/admin/Code/SCS/scs/")
import data_degrading as dd
import data_preparation as dp
import data_augmentation as da
from prepare_datasets_for_training import extract
import data_plotting as dplt
import scs_config

sys.path.insert(0, "/Users/admin/Code/SCS/scs/models/")
import feed_forward
import transformer_encoder

In [2]:
# When True, every cached parquet stage below is recomputed and rewritten
# instead of being read back from disk.
overwrite = True
# Single seeded generator handed to the noise, split and augmentation steps.
rng = np.random.RandomState(seed=1415)

In [67]:
def gen_noise(spectrum, noise_scale, rng, window_length=11, polyorder=1):
    """Return scaled smoothing residuals of a spectrum, used as a noise profile.

    The spectrum is smoothed with a Savitzky-Golay filter and the residuals
    (spectrum minus smoothed spectrum) are multiplied by ``noise_scale``.

    Parameters
    ----------
    spectrum : array_like
        Flux values. Smoothing runs along the last axis, so a 2-D array is
        treated as one spectrum per row.
    noise_scale : float or array_like
        Factor applied to the residuals. An array is matched against the
        leading (per-spectrum) axes, i.e. one scale per row.
    rng : object
        Not used by this function; kept so existing
        ``gen_noise(fluxes, noise_scale, rng)`` calls keep working.
    window_length : int, optional
        Savitzky-Golay window length. Odd by default so the fit is evaluated
        on each sample itself; an even window (the former value of 10)
        evaluates it half a sample away, which leaks the local slope into
        the residuals, and older SciPy releases reject even windows.
    polyorder : int, optional
        Polynomial order of the Savitzky-Golay fit.

    Returns
    -------
    numpy.ndarray
        Scaled residuals, with the shape of ``spectrum`` broadcast against
        the per-spectrum shape of ``noise_scale``.
    """
    spectrum = np.asarray(spectrum)
    # Filter every spectrum in one call instead of looping row by row.
    smoothed = savgol_filter(
        spectrum, window_length, polyorder, mode="mirror", axis=-1
    )
    residuals = spectrum - smoothed
    # Trailing axis so a per-spectrum scale multiplies whole rows rather
    # than broadcasting along the wavelength axis.
    scale = np.expand_dims(np.asarray(noise_scale), -1)
    return residuals * scale

In [17]:
# Read the raw spectra table; everything downstream is derived from it.
file_df_raw = "/Users/admin/Code/SCS/data/raw/sn_data.parquet"
df_raw = pd.read_parquet(path=file_df_raw)

In [17]:
# Data Preprocessing

## Inject Noise

# Residual-noise scale, used both for the injection below and for the
# training-set augmentation further down. Previously it was only defined in a
# commented-out line after its first use, so running this cell top to bottom
# raised NameError.
# NOTE(review): 0.25 is the value from that commented-out line — confirm.
noise_scale = 0.25

data = dp.extract_dataframe(df_raw)
index, wvl, flux_columns, metadata_columns, df_fluxes, df_metadata, fluxes = data
fluxes_noise = fluxes + gen_noise(fluxes, noise_scale, rng)
# Write the noisy fluxes into a copy: assigning into df_raw itself meant a
# re-run of this cell would extract already-noisy fluxes and add noise again.
df_noisy = df_raw.copy()
df_noisy[flux_columns] = fluxes_noise

## Degrade data

R = 100

file_df_R = f"/Users/admin/Code/SCS/data/R{R}/df_R.parquet"
file_df_C = f"/Users/admin/Code/SCS/data/R{R}/df_C.parquet"
# Cached files are reused only when overwrite is False and both exist.
# NOTE(review): cached files do not record the noise_scale they were built
# with; keep overwrite=True when changing it.
if (not overwrite) and (isfile(file_df_R) and isfile(file_df_C)):
    df_R = pd.read_parquet(file_df_R)
    df_C = pd.read_parquet(file_df_C)
else:
    df_C, df_R = dd.degrade_dataframe(R, df_noisy)
    df_R.to_parquet(file_df_R)
    df_C.to_parquet(file_df_C)

## Clean data

phase_range = (-20, 50)
ptp_range = (0.1, 100)
wvl_range = (4500, 7000)

file_df_RP = f"/Users/admin/Code/SCS/data/R{R}/df_RP.parquet"
file_df_CP = f"/Users/admin/Code/SCS/data/R{R}/df_CP.parquet"
if (not overwrite) and (isfile(file_df_RP) and isfile(file_df_CP)):
    df_RP = pd.read_parquet(file_df_RP)
    df_CP = pd.read_parquet(file_df_CP)
else:
    df_RP = dp.preproccess_dataframe(df_R, phase_range=phase_range, ptp_range=ptp_range, wvl_range=wvl_range)
    df_CP = dp.preproccess_dataframe(df_C, phase_range=phase_range, ptp_range=ptp_range, wvl_range=wvl_range)
    df_RP.to_parquet(file_df_RP)
    df_CP.to_parquet(file_df_CP)

## Train-Test split

train_frac = 0.50

file_df_RP_trn = f"/Users/admin/Code/SCS/data/R{R}/df_RP_trn.parquet"
file_df_CP_trn = f"/Users/admin/Code/SCS/data/R{R}/df_CP_trn.parquet"
file_df_RP_tst = f"/Users/admin/Code/SCS/data/R{R}/df_RP_tst.parquet"
file_df_CP_tst = f"/Users/admin/Code/SCS/data/R{R}/df_CP_tst.parquet"
if (not overwrite) and (isfile(file_df_RP_trn) and isfile(file_df_CP_trn) and isfile(file_df_RP_tst) and isfile(file_df_CP_tst)):
    df_RP_trn = pd.read_parquet(file_df_RP_trn)
    df_CP_trn = pd.read_parquet(file_df_CP_trn)
    df_RP_tst = pd.read_parquet(file_df_RP_tst)
    df_CP_tst = pd.read_parquet(file_df_CP_tst)
else:
    # NOTE(review): the shared rng advances between these two calls, so the
    # R and C splits presumably differ if split_data draws from it — confirm
    # whether they are meant to match.
    df_RP_trn, df_RP_tst = dp.split_data(df_RP, train_frac, rng)
    df_CP_trn, df_CP_tst = dp.split_data(df_CP, train_frac, rng)
    df_RP_trn.to_parquet(file_df_RP_trn)
    df_CP_trn.to_parquet(file_df_CP_trn)
    df_RP_tst.to_parquet(file_df_RP_tst)
    df_CP_tst.to_parquet(file_df_CP_tst)

## Augment training set
# noise_scale is set in the "Inject Noise" section above.
spike_scale = 3
max_spikes = 5

file_df_RPA_trn = f"/Users/admin/Code/SCS/data/R{R}/df_RPA_trn.parquet"
file_df_CPA_trn = f"/Users/admin/Code/SCS/data/R{R}/df_CPA_trn.parquet"
df_RPA_trn = da.augment(df_RP_trn, rng, wvl_range=wvl_range, noise_scale=noise_scale, spike_scale=spike_scale, max_spikes=max_spikes)
df_CPA_trn = da.augment(df_CP_trn, rng, wvl_range=wvl_range, noise_scale=noise_scale, spike_scale=spike_scale, max_spikes=max_spikes)
df_RPA_trn.to_parquet(file_df_RPA_trn)
df_CPA_trn.to_parquet(file_df_CPA_trn)

## Ready dataset for ML

df_trn = df_RPA_trn
df_tst = df_RP_tst
Xtrn, Ytrn, num_trn, num_wvl, num_classes = extract(df_trn)
Xtst, Ytst, num_tst, num_wvl, num_classes = extract(df_tst)

# Machine Learning

## Initialize model

input_shape = Xtrn.shape[1:]
units = [1024, 1024, 1024]
model = feed_forward.model(input_shape, num_classes, units, activation="relu", dropout=0.1)
model.summary()

## Metrics, loss, and optimizer

lr0 = 1e-5

loss = CategoricalCrossentropy()
acc = CategoricalAccuracy(name="ca")
f1 = F1Score(num_classes=num_classes, average="macro", name="f1")
opt = Nadam(learning_rate=lr0)
model.compile(loss=loss, optimizer=opt, metrics=[acc, f1])

## Callbacks

# NOTE(review): validation_data below is the test set, so early stopping with
# restore_best_weights selects weights on test data and the reported test
# metrics are optimistic. Consider holding out a separate validation split.
early = callbacks.EarlyStopping(
    monitor="val_loss",
    min_delta=0,
    patience=10,
    verbose=2,
    mode="min",
    restore_best_weights=True,
)

file_log = f"/Users/admin/Code/SCS/data/R{R}/history.log"
logger = callbacks.CSVLogger(file_log, append=False)

cbs = [early, logger]

## Fit model to training set

epochs = 10_000
batch_size = 32
verbose = 2

history = model.fit(
    Xtrn,
    Ytrn,
    validation_data=(Xtst, Ytst),
    epochs=epochs,
    batch_size=batch_size,
    verbose=verbose,
    callbacks=cbs,
)

## Predict on testing set

loss_trn, ca_trn, f1_trn = model.evaluate(x=Xtrn, y=Ytrn, verbose=0)
loss_tst, ca_tst, f1_tst = model.evaluate(x=Xtst, y=Ytst, verbose=0)

# Printed in test/train pairs: F1, categorical accuracy, loss.
print(f"{f1_tst:.4f}")
print(f"{f1_trn:.4f}")
print(f"{ca_tst:.4f}")
print(f"{ca_trn:.4f}")
print(f"{loss_tst:.4f}")
print(f"{loss_trn:.4f}")

In [17]:
# Analysis

log = pd.read_csv(file_log)

## Loss curves

fig = dplt.plot_loss(log, scale=6)
fig.show()

## Confusion Matrix

Ptrn = model.predict(Xtrn)
Ptst = model.predict(Xtst)

# Collapse per-class score / one-hot rows to integer class indices.
Ptrn_flat = np.argmax(Ptrn, axis=1)
Ptst_flat = np.argmax(Ptst, axis=1)

Ytrn_flat = np.argmax(Ytrn, axis=1)
Ytst_flat = np.argmax(Ytst, axis=1)

SNtypes_int = np.unique(Ytrn_flat)
SNtypes_str = [scs_config.SN_Stypes_int_to_str[sn] for sn in SNtypes_int]

# Pin the label set so each matrix has exactly one row/column per entry of
# SNtypes_str, in the same order. Without `labels`, confusion_matrix only uses
# the classes present in that particular (true, predicted) pair, which can
# shrink the matrix and misalign the tick labels passed to plot_cm. Samples
# whose true or predicted class is not in SNtypes_int are left out.
CMtrn = confusion_matrix(Ytrn_flat, Ptrn_flat, labels=SNtypes_int)
CMtst = confusion_matrix(Ytst_flat, Ptst_flat, labels=SNtypes_int)

dplt.plot_cm(CMtst, SNtypes_str, R, normalize=True)
# dplt.plot_cm(CMtrn, SNtypes_str, R, normalize=True)