# Importing libraries

In [None]:
import pandas as pd
import numpy as np
from sklearn.decomposition import PCA
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.preprocessing import LabelEncoder, StandardScaler
from tensorflow.keras.models import Sequential
from tensorflow.keras.models import load_model
from tensorflow.keras.layers import GRU, LSTM, Dense, Dropout, Bidirectional
from tensorflow.keras.optimizers import Adam
from sklearn.utils import class_weight
from sklearn.model_selection import train_test_split

In [None]:
# `import importlib` alone does not guarantee that the `util` submodule is
# bound as an attribute, so import it explicitly.
import importlib.util
from pathlib import Path


def _load_local_module(module_name, file_path):
    """Load the Python source file at `file_path` and return it as a module.

    Parameters:
        module_name: name given to the created module object.
        file_path: path of the ``.py`` file to execute.

    Returns:
        The executed module object.

    Raises:
        ImportError: if no import spec/loader can be built for `file_path`.
    """
    spec = importlib.util.spec_from_file_location(module_name, str(file_path))
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load module {module_name!r} from {file_path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


# Built with pathlib so the separator is right on every OS (the literal
# "..\\utils\\..." form only resolves on Windows). The path is still relative
# to the current working directory.
_UTILS_DIR = Path("..") / "utils"

preprocessing = _load_local_module("preprocessing", _UTILS_DIR / "preprocessing.py")
fspliter = _load_local_module("fspliter", _UTILS_DIR / "files_spliter.py")
results = _load_local_module("results", _UTILS_DIR / "results.py")

# Shared encoder: the split functions fit it, the evaluation cells use it to
# map predicted class indices back to the original state labels.
label_encoder = LabelEncoder()

# Preprocessing

In [None]:
def train_test_validation_split(data):
    """Build day-based train/validation/test sets from a multi-day recording.

    Days 1 and 2 are concatenated into the training set, day 4 is the
    validation set and day 3 is the test set. The scaler and the PCA are both
    fitted on the training days only and then applied to the other days, so
    no statistics from the validation/test days leak into training.

    Parameters:
        data: raw recording accepted by ``preprocessing.do_preprocessing``.

    Returns:
        ``(X_train, X_val, X_test, y_train, y_val, y_test)``. Each ``X`` is a
        DataFrame with a fresh 0-based index holding the scaled ``EEGv`` and
        ``EMGv`` columns followed by 30 PCA components of the ``bin`` columns.
        Each ``y`` is the ``state_encoded`` column of the matching day(s),
        with its index left as returned by ``fspliter.retrieve_day``.

    Side effects:
        Refits the module-level ``label_encoder`` on the ``state`` column.
    """
    # preprocessing and encoding
    data_processed = preprocessing.do_preprocessing(data)
    data_processed['state_encoded'] = label_encoder.fit_transform(data_processed['state'])

    # selection of features: everything except the first column and the last
    # two (presumably `state` and the `state_encoded` column added above —
    # TODO confirm against do_preprocessing's output)
    feature_columns = data_processed.columns[1:-2]
    bin_columns = [col for col in feature_columns if 'bin' in col]

    # split data into days
    day1 = fspliter.retrieve_day(data_processed, 1)
    day2 = fspliter.retrieve_day(data_processed, 2)
    day3 = fspliter.retrieve_day(data_processed, 3)
    day4 = fspliter.retrieve_day(data_processed, 4)

    # Concatenate day1 and day2 to form the train set. The other days are
    # copied so the scaled values can be written without touching
    # data_processed.
    train_data = pd.concat([day1, day2])
    val_data = day4.copy()
    test_data = day3.copy()

    # Scaling: fit on the training days only, then apply the same transform
    # to every subset.
    scaler = StandardScaler().fit(train_data[feature_columns])
    for subset in (train_data, val_data, test_data):
        subset[feature_columns] = scaler.transform(subset[feature_columns])

    # PCA on bin columns (fitted on the training days only)
    pca = PCA(n_components=30)
    X_train_pca = pca.fit_transform(train_data[bin_columns])
    X_val_pca = pca.transform(val_data[bin_columns])
    X_test_pca = pca.transform(test_data[bin_columns])

    # Concatenate PCA results with EEGv and EMGv; the index is reset so both
    # parts line up row by row.
    X_train = pd.concat([train_data[['EEGv', 'EMGv']].reset_index(drop=True),
                         pd.DataFrame(X_train_pca)], axis=1)
    X_val = pd.concat([val_data[['EEGv', 'EMGv']].reset_index(drop=True),
                       pd.DataFrame(X_val_pca)], axis=1)
    X_test = pd.concat([test_data[['EEGv', 'EMGv']].reset_index(drop=True),
                        pd.DataFrame(X_test_pca)], axis=1)

    # Labels
    y_train = train_data['state_encoded']
    y_val = val_data['state_encoded']
    y_test = test_data['state_encoded']

    return X_train, X_val, X_test, y_train, y_val, y_test


In [None]:
def train_test_validation_split_on_day_3(data):
    """Split day 3 of a recording chronologically into train/val/test sets.

    The first 5400 rows of day 3 are dropped (the original variable name
    says this is the first six hours — that implies 4-second rows, TODO
    confirm). The remaining rows are preprocessed and split in order,
    without shuffling, into the first 70 % (train), the next 15 %
    (validation) and the last 15 % (test).

    The scaler and the PCA are fitted on the training rows only and then
    applied to all rows, so the later validation/test rows do not influence
    the transforms used for training.

    Parameters:
        data: raw recording accepted by ``fspliter.retrieve_day``.

    Returns:
        ``(X_train, X_val, X_test, y_train, y_val, y_test)``. The ``X``
        frames hold the scaled ``EEGv`` and ``EMGv`` columns followed by 30
        PCA components of the ``bin`` columns. The ``y`` series are the
        matching slices of ``state_encoded``.

    Side effects:
        Refits the module-level ``label_encoder`` on the ``state`` column.
    """
    day3 = fspliter.retrieve_day(data, 3)
    day3_without_first_6_hours = day3.iloc[5400:]

    # preprocessing and encoding
    data_processed = preprocessing.do_preprocessing(day3_without_first_6_hours)
    data_processed['state_encoded'] = label_encoder.fit_transform(data_processed['state'])

    # selection of features: everything except the first column and the last
    # two (presumably `state` and `state_encoded` — TODO confirm)
    feature_columns = data_processed.columns[1:-2]
    bin_columns = [col for col in feature_columns if 'bin' in col]

    # Ask train_test_split itself how many leading rows form the training
    # part, so this count always agrees with the real split further down.
    train_positions, _ = train_test_split(
        np.arange(len(data_processed)), test_size=0.3, shuffle=False)
    n_train = len(train_positions)

    # Scaling: statistics come from the training rows only.
    scaler = StandardScaler().fit(data_processed[feature_columns].iloc[:n_train])
    data_processed[feature_columns] = scaler.transform(data_processed[feature_columns])

    # PCA on bin columns: components come from the training rows only.
    pca = PCA(n_components=30).fit(data_processed[bin_columns].iloc[:n_train])
    bin_components = pca.transform(data_processed[bin_columns])

    # Concatenate PCA results with EEGv and EMGv; the index is reset so both
    # parts line up row by row.
    df_pca = pd.concat([data_processed[['EEGv', 'EMGv']].reset_index(drop=True),
                        pd.DataFrame(bin_components)], axis=1)

    # Chronological 70 / 15 / 15 split (shuffle=False keeps row order).
    X_train, X_temp, y_train, y_temp = train_test_split(
        df_pca, data_processed['state_encoded'], test_size=0.3, shuffle=False)
    X_val, X_test, y_val, y_test = train_test_split(
        X_temp, y_temp, test_size=0.5, shuffle=False)

    return X_train, X_val, X_test, y_train, y_val, y_test


# Load data

In [None]:
# Fetch the recording with index 0 and build its day-3 train/val/test split.
mouse = fspliter.get_mice(0)
day3_splits = train_test_validation_split_on_day_3(mouse)
X_train, X_val, X_test, y_train, y_val, y_test = day3_splits

# Model

In [None]:
# Two stacked bidirectional LSTM layers followed by a small dense classifier.
# `input_shape` is given to the Bidirectional wrapper, which is the layer
# Sequential actually inspects. When it is passed to the wrapped LSTM it is
# not seen by Sequential, and the model stays unbuilt until the first fit.
# The time axis is left as None so any window length is accepted; the windows
# built in this notebook are 2 * 10 + 1 = 21 rows long.
model = Sequential([
    Bidirectional(
        LSTM(100, return_sequences=True),
        input_shape=(None, X_train.shape[1]),
    ),
    Dropout(0.3),
    Bidirectional(LSTM(100)),
    Dropout(0.3),
    Dense(50, activation='relu'),
    Dense(3, activation='softmax')  # 3 output units for 3 classes (w, n, r)
])

In [None]:
def create_bidirectional_sequences(data, n):
    """Cut `data` into overlapping windows centred on each interior position.

    For every position ``c`` with ``n <= c < len(data) - n`` the window
    ``data[c - n : c + n + 1]`` is taken, i.e. `n` rows before ``c``, the row
    at ``c`` and `n` rows after it (``2 * n + 1`` rows in total).

    Parameters:
        data: sliceable sequence (list, NumPy array, ...).
        n: number of rows kept on each side of the centre.

    Returns:
        ``np.ndarray`` stacking the ``len(data) - 2 * n`` windows in order;
        an empty array when `data` is too short to hold a single window.
    """
    centre_stop = len(data) - n
    windows = [data[centre - n: centre + n + 1] for centre in range(n, centre_stop)]
    return np.array(windows)

In [None]:
def create_sequences(data, sequence_length):
    """Cut `data` into overlapping forward windows of `sequence_length` rows.

    A window starts at every position ``s`` in
    ``range(len(data) - sequence_length)`` and covers
    ``data[s : s + sequence_length]``. The window that would end exactly on
    the last row is therefore not produced.

    Parameters:
        data: sliceable sequence (list, NumPy array, ...).
        sequence_length: number of rows per window.

    Returns:
        ``np.ndarray`` stacking the windows in order; an empty array when
        `data` has no more than `sequence_length` rows.
    """
    window_starts = range(len(data) - sequence_length)
    return np.array([data[start:start + sequence_length] for start in window_starts])


In [None]:
# Integer-encoded labels + softmax output -> sparse categorical cross-entropy.
model.compile(optimizer=Adam(learning_rate=0.001), loss='sparse_categorical_crossentropy', metrics=['accuracy'])

X_train_array = X_train.values
X_val_array = X_val.values

# Windows of 10 rows on either side of each position (21 rows per window).
# create_sequences(<array>, 10) is the forward-only alternative.
X_train_sequences = create_bidirectional_sequences(X_train_array, 10)
X_val_sequences = create_bidirectional_sequences(X_val_array, 10)

# Dropping the first 20 labels gives as many labels as windows (len - 2 * 10).
# NOTE(review): this pairs each window with the label of its LAST row; the
# label of the window's centre row would be y[10:-10]. Confirm which one is
# intended, and keep the evaluation cells in sync with whatever is used here.
y_train_adjusted = y_train[20:]
y_val_adjusted = y_val[20:]

# Balanced class weights, keyed by the actual class id. Keying by position
# (enumerate) gives weights to the wrong classes whenever a class is missing
# from the training labels.
present_classes = np.unique(y_train_adjusted)
class_weights = class_weight.compute_class_weight(class_weight="balanced", classes=present_classes, y=y_train_adjusted)
class_weights_dict = {int(label): weight for label, weight in zip(present_classes, class_weights)}

history = model.fit(X_train_sequences, y_train_adjusted, epochs=2, batch_size=64, validation_data=(X_val_sequences, y_val_adjusted), verbose=1, class_weight=class_weights_dict)

model.summary()

# Saving model

In [None]:
#model.save('Saved_model/Mouse_3_state_classification.h5')

# Training and validation results

In [None]:
import matplotlib.pyplot as plt

# Per-epoch metrics recorded by model.fit
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs = range(1, len(acc) + 1)

# One row per subplot:
# (position, training curve, validation curve, training label,
#  validation label, title, y-axis label)
panels = (
    (1, acc, val_acc, 'Training accuracy', 'Validation accuracy',
     'Training and Validation Accuracy', 'Accuracy'),
    (2, loss, val_loss, 'Training loss', 'Validation loss',
     'Training and Validation Loss', 'Loss'),
)

# Accuracy on the left, loss on the right
plt.figure(figsize=(12, 5))
for position, train_curve, val_curve, train_label, val_label, title, y_label in panels:
    plt.subplot(1, 2, position)
    plt.plot(epochs, train_curve, 'bo-', label=train_label)
    plt.plot(epochs, val_curve, 'ro-', label=val_label)
    plt.title(title)
    plt.xlabel('Epochs')
    plt.ylabel(y_label)
    plt.legend()

plt.tight_layout()
plt.show()


# Load model if needed

In [None]:
#model = load_model('Saved_model/Mouse_3_state_classification.h5')

# Testing model

In [None]:
# Window the held-out features the same way as the training data
# (10 rows on either side of each position).
X_test_sequences = create_bidirectional_sequences(X_test.values, 10)

# Drop the first 20 labels so there is one label per window.
# NOTE(review): this is the label of each window's last row, the same
# convention as in the training cell — confirm it is the intended one.
y_test_adjusted = y_test[20:]

# One softmax row per window; the most probable column is the predicted class.
X_test_pred = model.predict(X_test_sequences)
predicted_labels = np.argmax(X_test_pred, axis=1)

In [None]:
# Map the integer class ids (true first, then predicted) back to the original
# state labels, then report the scores for the three states.
y_test_original, y_pred_original = (
    label_encoder.inverse_transform(encoded)
    for encoded in (y_test_adjusted, predicted_labels)
)

results.scores(y_test_original, y_pred_original, ('n', 'r', 'w'))

## Testing model on other mice

In [None]:
def test_model_on_other_mice(model, mice):
    """Score `model` on the day-3 test split of another recording.

    The recording is fetched with ``fspliter.get_mice(mice)`` and split with
    ``train_test_validation_split_on_day_3``; only the test part is used.
    It is windowed with 10 rows on either side, predicted, decoded back to
    state labels and passed to ``results.scores``.

    NOTE(review): the split function refits the scaler, the PCA and the
    module-level ``label_encoder`` on this recording, so its features are
    not projected with the transforms the model was trained on — confirm
    this is acceptable for cross-mouse evaluation.

    Parameters:
        model: object whose ``predict`` returns one row of class scores per
            window.
        mice: identifier passed to ``fspliter.get_mice``.

    Returns:
        None; the scores are reported by ``results.scores``.
    """
    recording = fspliter.get_mice(mice)
    splits = train_test_validation_split_on_day_3(recording)
    test_features, test_labels = splits[2], splits[5]

    windows = create_bidirectional_sequences(test_features.values, 10)

    # One label per window: drop the first 20 (same convention as training).
    labels_adjusted = test_labels[20:]

    class_scores = model.predict(windows)
    predicted_ids = class_scores.argmax(axis=1)

    true_states = label_encoder.inverse_transform(labels_adjusted)
    predicted_states = label_encoder.inverse_transform(predicted_ids)
    results.scores(true_states, predicted_states, ('n', 'r', 'w'))

### Testing model on same strain

In [None]:
# Evaluate the trained model on the recording with index 1.
test_model_on_other_mice(model=model, mice=1)

### Testing model on other strain

In [None]:
# Evaluate the trained model on the recording with index 4.
test_model_on_other_mice(model=model, mice=4)