In [None]:
import numpy as np
import tensorflow as tf
import time
import matplotlib.pyplot as plt
from tensorflow.keras import layers
import sklearn.metrics as skm
import os


In [None]:
# Root directory of the PCA datasets. Set the PSF_DATASET_DIR environment
# variable to override it without editing this cell (e.g. on the cluster:
# '/gpfswork/rech/prk/uzk69cg/psf_dataset2/').
# os.path.join(..., "") guarantees a trailing separator, so code that appends a
# file name to this string directly still produces a valid path.
data_path = os.path.join(
    os.environ.get("PSF_DATASET_DIR", '/Users/as274094/Documents/psf_dataset2/'), ""
)

In [24]:
def calculate_success_rate(confusion_matrix):
    """Fraction of samples whose predicted class is at most one index away from the true class.

    Counts the main diagonal of the confusion matrix plus the first
    super-diagonal and first sub-diagonal, divided by the total count.

    Parameters
    ----------
    confusion_matrix : array_like, 2-D
        Confusion matrix (true classes on one axis, predicted on the other).

    Returns
    -------
    float
        (trace + sum of diagonals at offsets +1 and -1) / sum of all entries.
    """
    near_diagonal = sum(
        np.sum(np.diagonal(confusion_matrix, offset=shift)) for shift in (1, -1)
    )
    within_one_class = np.trace(confusion_matrix) + near_diagonal
    return within_one_class / np.sum(confusion_matrix)

class TrainingCompletionCallback(tf.keras.callbacks.Callback):
    """Print a short summary of a training run when ``fit`` finishes.

    The per-epoch metrics are read from ``self.model.history.history`` (the
    History object Keras attaches to the model). The printed values are those
    of the *last* epoch that ran; when this callback is combined with
    ``EarlyStopping(restore_best_weights=True)`` the model's final weights may
    come from an earlier epoch.

    Metrics that were not recorded (e.g. the ``val_*`` entries when ``fit`` is
    called without validation data) are printed as ``None`` instead of raising
    ``KeyError``, which would otherwise make ``fit`` itself raise at the end.
    """

    @staticmethod
    def _last_value(history, key):
        """Return the last recorded value of ``key`` in ``history``, or None if absent/empty."""
        values = history.get(key)
        return values[-1] if values else None

    def on_train_end(self, logs=None):
        # Guard against the model having no (or a None) `history` attribute,
        # presumably the case if no epoch completed; treat it as "nothing recorded".
        recorded = getattr(getattr(self.model, "history", None), "history", None) or {}

        epochs = len(recorded.get('loss', []))
        final_loss = self._last_value(recorded, 'loss')
        final_val_loss = self._last_value(recorded, 'val_loss')
        final_acc = self._last_value(recorded, 'categorical_accuracy')
        final_val_acc = self._last_value(recorded, 'val_categorical_accuracy')

        print("Training completed. Number of epochs:", epochs, ", Final training loss:", final_loss, ", Final validation loss:", final_val_loss)
        print("Final accuracy:", final_acc, "Final validation accuracy:", final_val_acc)

# One shared summary-printing callback, passed to every fit() in train_model.
completion_callback = TrainingCompletionCallback()

# Seeded Glorot-normal initializer reused by the Dense layers of the models below.
# NOTE(review): reusing one seeded initializer instance may give layers of the
# same shape identical initial weights — confirm this is intended.
initializer = tf.keras.initializers.GlorotNormal(seed=25)

# Spectral-type names used as tick labels for class indices 0..12 in the plots
# (label i <-> class index i), grouped by spectral letter.
star_class_labels = [
    'O5',
    'B0', 'B5',
    'A0', 'A5',
    'F0', 'F5',
    'G0', 'G5',
    'K0', 'K5',
    'M0', 'M5',
]

class softmax_model:
    """Wrapper around a Keras softmax classifier trained on a PCA star dataset.

    Loads a pre-computed dataset from ``data_path``, holds the Keras model
    passed in as ``architecture``, and provides helpers to train it, evaluate
    it on the validation/test splits (per-class F1, mean F1, confusion matrix,
    success rate) and plot the results.

    Parameters
    ----------
    dataset_name : str
        File name (without ``.npy``) inside ``data_path``. The file holds a
        pickled mapping with the keys ``'N_components'``, ``'train_stars_pca'``,
        ``'validation_stars_pca'``, ``'test_stars_pca'``, ``'train_SEDs'``,
        ``'validation_SEDs'`` and ``'test_SEDs'``.
    architecture : tf.keras.Model
        Uncompiled Keras model; its output must have ``n_classes`` units to
        match the one-hot labels used by the loss.
    """

    # Number of spectral classes (one-hot width and metric label set).
    n_classes = 13

    def __init__(self, dataset_name, architecture):
        self.dataset_name = dataset_name
        # NOTE: allow_pickle=True can execute arbitrary code stored in the file;
        # only load trusted, locally generated datasets.
        dataset = np.load(os.path.join(data_path, dataset_name + ".npy"), allow_pickle=True)[()]
        self.PCA_components = dataset['N_components']
        self.x_train = dataset['train_stars_pca']
        self.x_val = dataset['validation_stars_pca']
        self.x_test = dataset['test_stars_pca']
        # Class-index labels, used by the sklearn metrics.
        self.y_val = dataset['validation_SEDs']
        self.y_test = dataset['test_SEDs']
        # One-hot labels, used by the categorical cross-entropy loss.
        self.y_train_cat = tf.keras.utils.to_categorical(dataset['train_SEDs'], num_classes=self.n_classes)
        self.y_val_cat = tf.keras.utils.to_categorical(dataset['validation_SEDs'], num_classes=self.n_classes)
        self.y_test_cat = tf.keras.utils.to_categorical(dataset['test_SEDs'], num_classes=self.n_classes)
        # One Keras History object per call to train_model.
        self.learning = []
        self.model = architecture

    def train_model(self, learning_rate, training_epochs, patience_epochs):
        """Compile and fit the model (Adam + early stopping on ``val_loss``).

        Parameters
        ----------
        learning_rate : float
            Adam learning rate.
        training_epochs : int
            Maximum number of epochs.
        patience_epochs : int
            Epochs without ``val_loss`` improvement before stopping; the best
            weights are then restored.

        The History returned by ``fit`` is appended to ``self.learning``.
        """
        self.model.compile(
            loss='categorical_crossentropy',
            optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
            # Keras expects a list of metrics (recent versions reject a bare string).
            metrics=['categorical_accuracy'],
        )
        early_stopping = tf.keras.callbacks.EarlyStopping(
            monitor="val_loss", patience=patience_epochs, restore_best_weights=True
        )

        learn = self.model.fit(
            self.x_train, self.y_train_cat,
            epochs=training_epochs,
            verbose=0,
            callbacks=[completion_callback, early_stopping],
            validation_data=(self.x_val, self.y_val_cat),
            shuffle=True,
        )

        self.learning.append(learn)

    def _evaluate(self, x, y_true):
        """Return (per-class F1, mean F1, confusion matrix, success rate) for inputs ``x``."""
        class_predictions = np.argmax(self.model.predict(x), axis=1)
        # A fixed label set keeps exactly one F1 entry and one confusion-matrix
        # row/column per class even when a class is absent from a split, so the
        # results stay aligned with `star_class_labels` in the plots.
        labels = np.arange(self.n_classes)
        f1 = skm.f1_score(y_true, class_predictions, labels=labels, average=None)
        cf_matrix = skm.confusion_matrix(y_true, class_predictions, labels=labels)
        return f1, np.mean(f1), cf_matrix, calculate_success_rate(cf_matrix)

    @staticmethod
    def _print_results(split_name, f1_mean, cf_matrix, success_rate):
        """Print the evaluation summary of one data split."""
        print(f"Prediction results for the {split_name} set")
        print('Average F1 score:', f1_mean)
        print("\nConfusion matrix:")
        print(cf_matrix)
        print('\nSuccess rate:', success_rate)

    def predict_test(self, verbose = True):
        """Evaluate on the test split.

        Sets ``f1_test``, ``f1_mean_test``, ``confusion_matrix_test`` and
        ``success_rate_test``; prints them when ``verbose`` is true.
        """
        (self.f1_test, self.f1_mean_test,
         self.confusion_matrix_test, self.success_rate_test) = self._evaluate(self.x_test, self.y_test)

        if verbose:
            self._print_results("test", self.f1_mean_test, self.confusion_matrix_test, self.success_rate_test)

    def predict_val(self, verbose = True):
        """Evaluate on the validation split.

        Sets ``f1_val``, ``f1_mean_val``, ``confusion_matrix_val`` and
        ``success_rate_val``; prints them when ``verbose`` is true.
        """
        (self.f1_val, self.f1_mean_val,
         self.confusion_matrix_val, self.success_rate_val) = self._evaluate(self.x_val, self.y_val)

        if verbose:
            self._print_results("validation", self.f1_mean_val, self.confusion_matrix_val, self.success_rate_val)

    def _model_path(self, N_model = 1):
        """HDF5 path shared by save_model and load_model for model number ``N_model``."""
        return f"saved_models/{self.dataset_name}/my_model_{N_model}.h5"

    def save_model(self, N_model = 1):
        """Save the Keras model to ``saved_models/<dataset_name>/my_model_<N_model>.h5``."""
        path = self._model_path(N_model)
        # Writing the HDF5 file fails if its parent directory does not exist.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        self.model.save(path)

    def load_model(self, N_model = 1):
        """Replace ``self.model`` with the model written by ``save_model(N_model)``."""
        self.model = tf.keras.models.load_model(self._model_path(N_model))

    def plot_training_evolution(self):
        """Plot per-epoch loss and categorical accuracy of the latest ``train_model`` run.

        Also prints the last-epoch training/validation loss and accuracy.
        """
        history = self.learning[-1].history
        loss_evolution = history["loss"]
        val_loss_evolution = history["val_loss"]
        acc_evolution = history['categorical_accuracy']
        val_acc_evolution = history['val_categorical_accuracy']

        plt.figure(figsize=(9, 5))
        plt.subplot(121)
        plt.plot(loss_evolution, label="Train set")
        plt.plot(val_loss_evolution, label="Validation set")
        plt.xlabel("Epochs")
        plt.ylabel("Loss function value")
        plt.legend()
        plt.title("Loss function evolution")

        plt.subplot(122)
        plt.plot(acc_evolution, label="Train set")
        plt.plot(val_acc_evolution, label="Validation set")
        plt.xlabel("Epochs")
        plt.ylabel("Accuracy")
        plt.legend()
        plt.title("Categorical accuracy evolution")

        print("Training loss:", loss_evolution[-1], ", Validation loss:", val_loss_evolution[-1])
        print("Training accuracy:", acc_evolution[-1], ", Validation accuracy:", val_acc_evolution[-1])
        plt.show()

    def plot_cf_matrix(self):
        """Plot the test-set confusion matrix (requires ``predict_test`` to have run)."""
        class_ticks = np.arange(self.n_classes)

        plt.figure(figsize=(12, 10))
        heatmap = plt.imshow(self.confusion_matrix_test, cmap='Blues')
        plt.xticks(class_ticks, star_class_labels)
        plt.yticks(class_ticks, star_class_labels)
        plt.colorbar(heatmap)
        plt.xlabel("Estimated spectral type")
        plt.ylabel("True spectral type")
        plt.show()

    def plot_metrics(self):
        """Bar-plot per-class test F1 with the mean F1 and success rate as reference lines.

        Requires ``predict_test`` to have run.
        """
        plt.figure(figsize=(9, 5))
        plt.bar(np.arange(self.n_classes), height=self.f1_test, tick_label=star_class_labels, label="F1 score")
        plt.axhline(self.f1_mean_test, color='red', linestyle='--', label='F1 score average')
        plt.axhline(self.success_rate_test, color='purple', label='Success rate')
        plt.xlabel("Spectral class")
        plt.ylabel("Metric")
        plt.legend()
        plt.show()

In [25]:
# Model 1: two sigmoid hidden layers of 26 units on 24 input features
# (presumably the PCA dimension of 'PCA_dataset2B24' — keep input_shape in sync).
architecture = tf.keras.Sequential(
    [
        layers.Dense(
            26,
            input_shape=[24],
            activation='sigmoid',
            kernel_initializer=initializer,
        ),
        layers.Dense(26, activation='sigmoid', kernel_initializer=initializer),
        layers.Dense(13, activation='softmax', kernel_initializer=initializer),
    ]
)

model = softmax_model('PCA_dataset2B24', architecture)

Training completed. Number of epochs: 66 , Final training loss: 0.8153238892555237 , Final validation loss: 0.84336256980896
Final accuracy: 0.6318125128746033 Final validation accuracy: 0.6244000196456909
Prediction results for the validation set
Average F1 score: 0.6321016442215686

Confusion matrix:
[[1089  289  112   15    5    0    0    0    0    0    0    0    0]
 [ 781  522  198   24    3    0    0    0    0    0    0    0    0]
 [  28   66  605  659  126    9    1    0    0    0    0    0    0]
 [   3   24  227  856  396   30    2    0    0    0    0    0    0]
 [   0    2   14  188 1070  297   16    0    3    0    0    0    0]
 [   0    0    1    4  154 1030  278   16   18    1    0    0    0]
 [   0    0    1    3   16  425  887  143   97    6    2    0    0]
 [   0    0    0    0    2   33  271  388  685  186    3    0    0]
 [   0    0    0    0    0    8   75  202  807  411   13    0    0]
 [   0    0    0    0    0    0   19   36  415 1050   44    0    0]
 [   0    0    0

In [None]:
# Train model 1 (up to 100 epochs, early stopping after 10 epochs without
# val_loss improvement), then evaluate both splits and plot the learning curves.
model.train_model(learning_rate=0.05, training_epochs=100, patience_epochs=10)
model.predict_val()
model.predict_test()
model.plot_training_evolution()

In [None]:
# Model 2: three ReLU hidden layers of 100 units on 30 input features
# (presumably the PCA dimension of 'PCA_dataset2B30').
architecture = tf.keras.Sequential([
    layers.Dense(100, input_shape=[30], activation='relu', kernel_initializer=initializer),
    # Two more identical hidden layers, created in the same order as before.
    *(layers.Dense(100, activation='relu', kernel_initializer=initializer) for _ in range(2)),
    layers.Dense(13, activation='softmax', kernel_initializer=initializer),
])
model2 = softmax_model('PCA_dataset2B30', architecture)

In [None]:
# Train model 2 with a smaller learning rate, then evaluate and plot.
model2.train_model(learning_rate=0.001, training_epochs=100, patience_epochs=10)
model2.predict_val()
model2.predict_test()
model2.plot_training_evolution()

In [None]:
# Model 3: three regularised hidden blocks
# (Dense(50, ReLU, L2 1e-3) -> BatchNormalization -> Dropout(0.1)) on 30 input features.
# NOTE(review): unlike models 1 and 2, the hidden Dense layers are not given
# `kernel_initializer=initializer`, so they use Keras' default initializer with
# no seed — confirm this is intended.
architecture = tf.keras.Sequential(
    [
        # Hidden block 1 (also declares the 30-feature input).
        layers.Dense(
            50,
            input_shape=[30],
            activation='relu',
            kernel_regularizer=tf.keras.regularizers.l2(1e-3),
        ),
        layers.BatchNormalization(),
        layers.Dropout(0.1),
        # Hidden block 2.
        layers.Dense(50, activation='relu', kernel_regularizer=tf.keras.regularizers.l2(1e-3)),
        layers.BatchNormalization(),
        layers.Dropout(0.1),
        # Hidden block 3.
        layers.Dense(50, activation='relu', kernel_regularizer=tf.keras.regularizers.l2(1e-3)),
        layers.BatchNormalization(),
        layers.Dropout(0.1),
        # Output layer.
        layers.Dense(13, activation='softmax', kernel_initializer=initializer),
    ]
)
model3 = softmax_model('PCA_dataset2B30', architecture)

In [None]:
# Train model 3 with the same schedule as model 2, then evaluate and plot.
model3.train_model(learning_rate=0.001, training_epochs=100, patience_epochs=10)
model3.predict_val()
model3.predict_test()
model3.plot_training_evolution()

In [None]:
# Load a trained model
# NOTE(review): this rebinds `model` from the softmax_model wrapper built above
# to the raw Keras model returned by load_model, so softmax_model methods
# (model.predict_val(), model.plot_training_evolution(), ...) are no longer
# available on `model` after running this cell.
# NOTE(review): hardcoded, machine-specific absolute path — consider making it
# relative or configurable.

model_path = '/Users/as274094/GitHub/Refractored_star_classifier/tensorflow_version/single_model/my_model.h5'
model = tf.keras.models.load_model(model_path)


In [None]:
# Save the model
# `model` is either the softmax_model wrapper built above or, if the
# "Load a trained model" cell was run, a bare Keras model. Save the underlying
# Keras model in both cases (the wrapper itself has no `.save`).
keras_model_to_save = model.model if isinstance(model, softmax_model) else model

os.makedirs("single_model", exist_ok=True)
keras_model_to_save.save("single_model/my_model.h5")

# Plotting

In [None]:
# Plot the loss function and accuracy evolution of the latest training run of a
# softmax_model. `model3` is the last model trained above and is not rebound by
# the load cell; change `model_to_plot` to inspect another run.
model_to_plot = model3
training_history = model_to_plot.learning[-1].history

loss_evolution = training_history["loss"]
val_loss_evolution = training_history["val_loss"]
acc_evolution = training_history['categorical_accuracy']
val_acc_evolution = training_history['val_categorical_accuracy']

fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(9, 5))

ax_loss.plot(loss_evolution, label="Train set")
ax_loss.plot(val_loss_evolution, label="Validation set")
ax_loss.set(xlabel="Epochs", ylabel="Loss function value", title="Loss function evolution")
ax_loss.legend()

ax_acc.plot(acc_evolution, label="Train set")
ax_acc.plot(val_acc_evolution, label="Validation set")
ax_acc.set(xlabel="Epochs", ylabel="Accuracy", title="Categorical accuracy evolution")
ax_acc.legend()

print("Training loss:", loss_evolution[-1], ", Validation loss:", val_loss_evolution[-1])
print("Training accuracy:", acc_evolution[-1], ", Validation accuracy:", val_acc_evolution[-1])
plt.show()