# Playground for integrating Weights & Biases (W&B) and starting hyperparameter tuning

## Setup

In [5]:
import re
import os
import random
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import ModelCheckpoint
from sklearn.utils import class_weight
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
import wandb
from wandb.keras import WandbCallback

### Import preprocessed DataFrames

In [None]:
# Load the preprocessed CCE dataset; the two other dataset loads below are
# currently commented out.
cce_csv_path = 'testdata/cce.csv'
cce_df = pd.read_csv(cce_csv_path)
# scce_df = pd.read_csv('testdata/ssce.csv')
# multilable_df = pd.read_csv('testdata/multilable.csv') ## WIP

### SCCE / CCE Model

In [None]:
# Build stratified train/validation/test splits, one-hot encode the charge
# labels, and define + compile the categorical cross-entropy (CCE) model.

# Features are the sequence vectors, labels the precursor charge.
# NOTE(review): pd.read_csv returns list-like columns as strings -- confirm
# that 'sequence_vector' holds equal-length integer lists when this cell runs.
X_2 = np.array(cce_df['sequence_vector'].tolist())
y_2 = np.array(cce_df['precursor_charge'])
max_len = max(cce_df.loc[:, 'sequence_vector'].apply(len))  # Find the maximum length

# Hold out 20 % of all samples as the test set, preserving class proportions.
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
train_val_indices, test_indices = next(sss.split(X_2, y_2))
X_2_train_val, X_2_test = X_2[train_val_indices], X_2[test_indices]
y_2_train_val, y_2_test = y_2[train_val_indices], y_2[test_indices]

# Split the remaining samples again: 20 % of them become the validation set.
sss_train_val = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
train_indices, val_indices = next(sss_train_val.split(X_2_train_val, y_2_train_val))
X_2_train, X_2_val = X_2_train_val[train_indices], X_2_train_val[val_indices]
y_2_train, y_2_val = y_2_train_val[train_indices], y_2_train_val[val_indices]

num_classes = 8  # Number of precursor charge classes (1 to 7, plus an extra class for 'None' charge)
y_2_train_encoded = tf.keras.utils.to_categorical(y_2_train, num_classes)
y_2_val_encoded = tf.keras.utils.to_categorical(y_2_val, num_classes)
# Fix: encode the *test* labels. The previous code encoded y_2_train here, so
# the "test" targets neither matched X_2_test in length nor in content.
y_2_test_encoded = tf.keras.utils.to_categorical(y_2_test, num_classes)

# The Embedding layer needs the vocabulary size (largest token index + 1), not
# the sequence length; using max_len only worked when it happened to exceed
# every token index.
vocab_size = int(X_2.max()) + 1

# Define model: embed each token, flatten the sequence, classify the charge.
model_cce = tf.keras.models.Sequential([
    tf.keras.layers.Embedding(input_dim=vocab_size, output_dim=20, input_length=X_2.shape[1]),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(num_classes, activation='softmax')
])

# Compile the model; one-hot targets pair with categorical cross-entropy.
model_cce.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

# Keep only the model with the highest validation accuracy seen so far.
checkpoint_callback = ModelCheckpoint('precursor_charge_prediction_model_v1/cce_wo7_allSequences.h5', monitor='val_accuracy', save_best_only=True, mode='max')

# Stop training once validation loss has not improved for 3 epochs.
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3)

In [None]:
# Fit the CCE model on the training split and validate on the held-out
# validation split after every epoch. The W&B callback is not wired in yet.
fit_callbacks = [checkpoint_callback, early_stopping]  # , wandb_callback]
history_cce = model_cce.fit(
    X_2_train,
    y_2_train_encoded,
    epochs=10,
    batch_size=32,
    validation_data=(X_2_val, y_2_val_encoded),
    callbacks=fit_callbacks,
)

In [None]:
# Pull the per-epoch training curves out of the Keras History object
recorded = history_cce.history
loss, val_loss = recorded['loss'], recorded['val_loss']
accuracy, val_accuracy = recorded['accuracy'], recorded['val_accuracy']

# One x position per recorded epoch, counted from 1
epochs = range(1, len(loss) + 1)

# Two side-by-side panels: loss on the left, accuracy on the right
fig2, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

panels = (
    (ax1, 'Loss', loss, val_loss),
    (ax2, 'Accuracy', accuracy, val_accuracy),
)
for axis, metric_name, train_curve, val_curve in panels:
    # Training curve in blue, validation curve in red
    axis.plot(epochs, train_curve, 'b', label=f'Training {metric_name}')
    axis.plot(epochs, val_curve, 'r', label=f'Validation {metric_name}')
    axis.set_title(f'Training and Validation {metric_name}')
    axis.set_xlabel('Epochs')
    axis.set_ylabel(metric_name)
    axis.legend()

# Adjust spacing so the panels do not overlap, then render
plt.tight_layout()
plt.show()

## Check in with Franzi's group for reporting

### Multilabel Model
#### WIP in precursor_charge_predictor

## Model Testing
### Check whether the models predict only charge 2 or other charges as well. Because charge 2 is overrepresented, the model's best bet could be to always output charge state 2.

## Hyperparameter Tuning