In [None]:
import numpy as np
import tensorflow as tf
import time
from tensorflow.keras import layers
import sklearn.metrics as skm
import os
import matplotlib.pyplot as plt

In [None]:
# Load the dataset
# Directory that holds the PCA_dataset2*.npy files. Set the PSF_DATASET_DIR
# environment variable to point at another location without editing the
# notebook; the default below is the original hard-coded local path.
# Other locations that were used previously:
#   '/gpfswork/rech/prk/uzk69cg/psf_dataset2/'
#   '/Users/as274094/GitHub/Refractored_star_classifier/tensorflow_version/'
data_path = os.path.join(
    os.environ.get('PSF_DATASET_DIR', '/Users/as274094/Documents/psf_dataset2/'),
    '',  # forces a trailing separator: the path is concatenated with a file name when datasets are loaded
)

In [None]:
# Hyperparameters

# --- Network input ---
PCA_components = 24          # width of the first Dense layer's input (see PCA_model.create_model)

# --- Optimisation ---
model_learning_rate = 0.1    # learning rate handed to train_model by the per-dataset training loop
patience_epochs = 10         # early-stopping patience handed to train_model by the same loop
N_epochs = 100               # NOTE(review): not referenced by the visible cells (the training loop passes a literal epoch count)

# --- Ensemble ---
N_committee = 48             # NOTE(review): not referenced by the visible cells; presumably used by a committee notebook

In [None]:
def SEDlisttoC(SED_list):
    """Map SED class indices to their C values via the affine rule C = 1.5 + 0.5 * index.

    Parameters
    ----------
    SED_list : array-like of numbers
        Class indices (any sequence NumPy can turn into an array).

    Returns
    -------
    numpy.ndarray
        Array of the same shape holding the corresponding C values.
    """
    return 1.5 + 0.5 * np.asarray(SED_list)

def CtoSEDarray(c_values, variance):
    """Convert continuous C predictions into integer SED class ids.

    Values are binned in steps of 0.5 starting at 1.25, so the half-open
    interval [1.25, 1.75) maps to class 0 and [7.25, 7.75) maps to class 12.
    Anything outside [1.25, 7.75) -- including NaN/inf -- or any entry whose
    variance exceeds 1.00 is assigned the reject id 20.

    Parameters
    ----------
    c_values : array-like of float
        Predicted C values.
    variance : array-like of float
        Per-prediction variance, broadcastable against ``c_values``.

    Returns
    -------
    numpy.ndarray of int
        Class ids in 0..12, or 20 for rejected entries.
    """
    c_values = np.asarray(c_values)
    variance = np.asarray(variance)

    # Upper bound is exclusive: c == 7.75 would otherwise floor-divide to 13,
    # which is neither a valid class nor the reject id. NaN fails both
    # comparisons and is therefore rejected as well.
    in_range = (c_values >= 1.25) & (c_values < 7.75)
    too_uncertain = variance > 1.00

    # Swap out-of-range values for a harmless placeholder so NaN/inf never
    # reach the integer cast (which would yield garbage and a RuntimeWarning).
    safe_c = np.where(in_range, c_values, 1.25)
    sed_classes = ((safe_c - 1.25) // 0.5).astype(int)

    return np.where(in_range & ~too_uncertain, sed_classes, 20)

def calculate_success_rate(confusion_matrix):
    """Fraction of samples that land on the main diagonal or one step off it.

    An entry counts as a success when the predicted index equals the true
    index or differs from it by exactly one position in the matrix.

    Note: every row/column is treated alike, so if the matrix carries a
    trailing special class, its adjacency to the preceding class is counted
    as a "neighbour" hit too.

    Parameters
    ----------
    confusion_matrix : 2-D array-like
        Confusion matrix of counts.

    Returns
    -------
    float
        (exact hits + off-by-one hits) / total number of entries counted.
    """
    exact_hits = np.trace(confusion_matrix)
    upper_hits = np.trace(confusion_matrix, offset=1)
    lower_hits = np.trace(confusion_matrix, offset=-1)

    n_samples = np.sum(confusion_matrix)
    return (exact_hits + (upper_hits + lower_hits)) / n_samples

In [None]:
class TrainingCompletionCallback(tf.keras.callbacks.Callback):
    """Keras callback that prints a one-line summary once training has ended.

    The summary reports how many epochs were recorded and the last recorded
    training and validation losses, read from the model's History object.
    A 'val_loss' entry must exist, i.e. fit() has to be given validation data.
    """

    def on_train_end(self, logs=None):
        recorded = self.model.history.history
        train_losses = recorded['loss']
        val_losses = recorded['val_loss']

        print("Training completed. Number of epochs:", len(train_losses), ", Final training loss:", train_losses[-1], ", Final validation loss:", val_losses[-1])


# Single shared instance; PCA_model.train_model passes it to every fit() call.
completion_callback = TrainingCompletionCallback()

In [None]:
class PCA_model:
    """Small dense regressor trained on one PCA dataset, plus its evaluation helpers.

    The dataset is a pickled dict stored as ``<data_path>/<dataset_name>.npy``
    with the keys 'train_stars_pca', 'validation_stars_pca', 'test_stars_pca',
    'train_C', 'validation_C', 'test_C', 'validation_SEDs' and 'test_SEDs'.
    The network regresses the C value; predictions are converted to SED class
    ids with ``CtoSEDarray`` for the classification metrics.

    Relies on the notebook-level names ``data_path``, ``PCA_components``,
    ``completion_callback``, ``CtoSEDarray`` and ``calculate_success_rate``.
    """

    # Real spectral classes are ids 0..12; 20 is the reject id CtoSEDarray emits.
    N_SED_CLASSES = 13
    UNKNOWN_CLASS = 20
    # Fixed label order for every sklearn metric, so that row/column i always
    # means the same class even when a class is absent from a given split.
    CLASS_LABELS = list(range(N_SED_CLASSES)) + [UNKNOWN_CLASS]

    def __init__(self, dataset_name):
        """Load the train/validation/test arrays of ``dataset_name`` from ``data_path``."""
        self.dataset_name = dataset_name
        # NOTE(security): allow_pickle=True unpickles arbitrary objects; only load trusted files.
        # os.path.join works whether or not data_path ends with a separator.
        dataset = np.load(os.path.join(data_path, dataset_name + ".npy"), allow_pickle=True)[()]
        self.x_train = dataset['train_stars_pca']
        self.x_val = dataset['validation_stars_pca']
        self.x_test = dataset['test_stars_pca']
        self.y_train = dataset['train_C']
        self.y_val = dataset['validation_C']
        self.y_test = dataset['test_C']
        self.SED_val = dataset['validation_SEDs']
        self.SED_test = dataset['test_SEDs']

    def create_model(self):
        """Build a fresh 26-26-1 network (sigmoid hidden layers, linear output) in ``self.model``."""
        initializer = tf.keras.initializers.GlorotNormal(seed = None)
        self.model = tf.keras.Sequential([
            layers.Dense(26, input_shape=[PCA_components], activation='sigmoid', kernel_initializer= initializer),
            layers.Dense(26, activation='sigmoid', kernel_initializer= initializer),
            layers.Dense(1, activation = 'linear', kernel_initializer= initializer)
        ])

    def train_model(self, learning_rate, training_epochs, patience_epochs):
        """Compile and fit ``self.model`` with MSE loss, Adam and early stopping.

        The model is recompiled (new Adam optimizer) on every call, so calling
        this again continues from the current weights with fresh optimizer state.
        Stores the Keras History in ``self.learning`` and the wall-clock duration
        in ``self.training_time``.

        Parameters
        ----------
        learning_rate : float
            Adam learning rate.
        training_epochs : int
            Maximum number of epochs.
        patience_epochs : int
            Early-stopping patience on 'val_loss'; the best weights are restored.
        """
        self.model.compile(
            loss = tf.keras.losses.MeanSquaredError(),
            optimizer = tf.keras.optimizers.Adam(learning_rate = learning_rate)
        )
        early_stopping = tf.keras.callbacks.EarlyStopping(monitor = "val_loss", patience = patience_epochs, restore_best_weights=True)

        start_time = time.time() # Measure training time
        self.learning = self.model.fit(self.x_train, self.y_train, epochs= training_epochs, verbose = 0, callbacks = [completion_callback,early_stopping], validation_data=(self.x_val,self.y_val))

        # Calculate the training time
        end_time = time.time()
        self.training_time = end_time - start_time
        print("Total training time:", self.training_time, "seconds")

    def _evaluate(self, x, y_true, sed_true, split_name, verbose):
        """Predict on ``x`` and return (mse, per-class F1, mean F1, confusion matrix, success rate)."""
        c_pred = self.model.predict(x, verbose = 1).reshape(-1)
        # No variance estimate for a single model: pass zeros so nothing is rejected on variance.
        sed_pred = CtoSEDarray(c_pred, np.zeros_like(c_pred))

        # Flatten the targets so a column-vector target cannot broadcast
        # against the 1-D predictions into an (N, N) matrix.
        y_flat = np.asarray(y_true).reshape(-1)
        mse = np.mean((y_flat - c_pred)**2)

        # Explicit labels keep index i == class i regardless of which classes
        # happen to occur in this split; without them a missing class shifts
        # every later entry and silently corrupts the [:13] slice below.
        # sklearn may warn about labels with no samples; their F1 is reported as 0.
        f1 = skm.f1_score(sed_true, sed_pred, labels = self.CLASS_LABELS, average = None)
        f1_mean = np.mean(f1[:self.N_SED_CLASSES])
        confusion = skm.confusion_matrix(sed_true, sed_pred, labels = self.CLASS_LABELS)
        # NOTE(review): the last row/column is the reject class, so a true M5
        # predicted as "unknown" sits on the +1 diagonal and counts as a
        # neighbour hit here -- confirm that this is intended.
        success_rate = calculate_success_rate(confusion)

        if(verbose):
            print(f"Prediction results for the {split_name} set")
            print('MSE:', mse)
            print('\nF1 score for each class:', f1)
            print('Average F1 score:', f1_mean)
            print("\nConfusion matrix:")
            print(confusion)
            print('\nSuccess rate:', success_rate)

        return mse, f1, f1_mean, confusion, success_rate

    def predict_test(self, verbose = True):
        """Evaluate on the test split and store the ``*_test`` metric attributes."""
        (self.mse_test,
         self.f1_test,
         self.f1_mean_test,
         self.confusion_matrix_test,
         self.success_rate_test) = self._evaluate(self.x_test, self.y_test, self.SED_test, "test", verbose)

    def predict_val(self, verbose = True):
        """Evaluate on the validation split and store the ``*_val`` metric attributes."""
        (self.mse_val,
         self.f1_val,
         self.f1_mean_val,
         self.confusion_matrix_val,
         self.success_rate_val) = self._evaluate(self.x_val, self.y_val, self.SED_val, "validation", verbose)

    def save_model(self):
        """Save the Keras model to ``single_models/my_model_<dataset_name>.h5``."""
        os.makedirs("single_models", exist_ok=True)
        self.model.save(f"single_models/my_model_{self.dataset_name}.h5")

    def plot_loss(self):
        """Plot training and validation loss per epoch for the latest ``train_model`` call."""
        loss_evolution = self.learning.history["loss"]
        val_loss_evolution = self.learning.history["val_loss"]

        plt.figure(figsize = (9,5))
        plt.plot(loss_evolution,label = "Train set")
        plt.plot(val_loss_evolution,label = "Validation set")
        plt.xlabel("Epochs")
        plt.ylabel("Loss function value")
        plt.legend()
        plt.title("Loss function evolution")
        print("Total training time:", self.training_time, "seconds")
        # These are the losses of the last epoch run, which need not be the
        # epoch whose weights early stopping restored.
        print("Training loss:", loss_evolution[-1], ", Validation loss:", val_loss_evolution[-1])

    def plot_cf_matrix(self):
        """Show the test confusion matrix as a heatmap (requires a prior ``predict_test``).

        Rows are the 13 true spectral types; columns are the 13 predicted
        types followed by the reject class, labelled '???'.
        """
        star_class_labels = ['O5','B0','B5','A0','A5','F0','F5','G0','G5','K0','K5','M0','M5']
        n_real = self.N_SED_CLASSES

        plt.figure(figsize= (12,10))
        heatmap = plt.imshow(self.confusion_matrix_test[:n_real,:], cmap='Blues')
        plt.xticks(np.arange(n_real + 1), star_class_labels + ['???'])
        plt.yticks(np.arange(n_real), star_class_labels)
        plt.colorbar(heatmap)
        plt.xlabel("Estimated spectral type")
        plt.ylabel("True spectral type")
        plt.show()

In [None]:
# One model per PCA dataset variant; the letter is the dataset-name suffix.
letters = list('ABCDE')

# Build, train and evaluate each model, keeping them in dataset order.
# NOTE(review): the epoch budget here is a literal 12, not the N_epochs
# hyperparameter defined above -- confirm which one is intended.
model_list = []
for letter in letters:
    model = PCA_model(f'PCA_dataset2{letter}')
    model.create_model()
    model.train_model(
        learning_rate=model_learning_rate,
        training_epochs=12,
        patience_epochs=patience_epochs,
    )
    model.predict_test()
    model.predict_val()
    model_list.append(model)



In [None]:
modelB = model_list[1]

# Plot the loss function evolution, leaving out the first epochs
# (presumably so the large initial losses do not dominate the y-axis).

skip_epochs = 10
full_loss_evolution = modelB.learning.history["loss"]
full_val_loss_evolution = modelB.learning.history["val_loss"]
if len(full_loss_evolution) <= skip_epochs:
    # The run ended within the skipped window (e.g. early stopping): show the
    # whole history instead of an empty plot and an IndexError on [-1] below.
    skip_epochs = 0

loss_evolution = full_loss_evolution[skip_epochs:]
val_loss_evolution = full_val_loss_evolution[skip_epochs:]
# Real (0-based) epoch indices, so the x-axis does not restart at 0 after slicing.
epoch_axis = range(skip_epochs, skip_epochs + len(loss_evolution))

plt.figure(figsize = (9,5))
plt.plot(epoch_axis, loss_evolution,label = "Train set")
plt.plot(epoch_axis, val_loss_evolution,label = "Validation set")
plt.xlabel("Epochs")
plt.ylabel("Loss function value")
plt.legend()
plt.title("Loss function evolution")
print("Total training time:", modelB.training_time, "seconds")
print("Training loss:", loss_evolution[-1], ", Validation loss:", val_loss_evolution[-1])

In [None]:
# Train modelB further with a smaller learning rate and a larger epoch budget.
# create_model is not called again, so the existing network is reused;
# train_model recompiles it with a new Adam optimizer.
# modelB is model_list[1], so that list entry is the object being retrained.
modelB.train_model(
    learning_rate=0.005,
    training_epochs=100,
    patience_epochs=10,
)


In [None]:
# Plot the test-set confusion matrix of the first trained model (dataset letter 'A')

model_list[0].plot_cf_matrix()

In [None]:
# modelB was retrained above, so recompute its stored metrics: the metrics
# table below reads the success_rate_* / f1_mean_* attributes that
# predict_val and predict_test set on the model object.
modelB.predict_val()
modelB.predict_test()

In [None]:
# One row per model in model_list; columns are
# [success rate test, mean F1 test, success rate validation, mean F1 validation].
metrics = np.array([
    [m.success_rate_test, m.f1_mean_test, m.success_rate_val, m.f1_mean_val]
    for m in model_list
])
print(metrics)

In [None]:
# Plot the metrics of the different datasets

fig, ax = plt.subplots(figsize=(18, 10))

# (column of `metrics`, legend label, colour, marker) in drawing/legend order.
# marker=None falls back to matplotlib's default scatter marker.
series_specs = (
    (0, "Success rate test", 'red', "x"),
    (2, "Success rate validation", 'blue', "x"),
    (1, "F1 score test ", 'red', None),
    (3, "F1 score validation ", 'blue', None),
)
for col_idx, series_label, series_color, series_marker in series_specs:
    ax.scatter(
        letters,
        metrics[:, col_idx],
        label=series_label,
        color=series_color,
        marker=series_marker,
        s=100,
    )

ax.set_title('Performance of different PCA decomposition methods', fontsize=25)
ax.set_xlabel("PCA dataset", fontsize=20)
ax.set_ylabel("Metric", fontsize=20)
ax.tick_params(axis='both', labelsize=15)
ax.grid(True)
ax.legend(fontsize=12)
plt.show()