In [None]:
import numpy as np
from sklearn.model_selection import train_test_split, train_test_split
import pandas as pd

import utils.load_data as ld
import utils.load_SAM40_data as ld_SAM40
import utils.features as f
import utils.classifiers as clfs
import utils.variables as v
import utils.metrics as m

In [None]:
def load_and_shape_data(data_type, label_type, feature, kfold, new_ica = False):
    """Load the dataset and reshape it to one row per (recording, channel).

    Parameters
    ----------
    data_type : str
        Data variant forwarded to the ``utils.load_data`` loaders (e.g. ``'raw'``).
    label_type : str
        Label scheme forwarded to the ``utils.load_data`` loaders (e.g. ``'stai'``).
    feature : bool
        Only used when ``kfold`` is True. If True, each recording's label is
        repeated ``v.NUM_CHANNELS`` times and ``f.time_series_features`` is
        applied to the data. If False, the first two data axes are merged and
        each label is repeated once per channel of the data.
    kfold : bool
        If True, load with ``ld.load_kfold_data`` and reshape as described.
        If False, load with ``ld.load_data`` and return its six arrays as-is.
    new_ica : bool, default False
        Forwarded to ``f.time_series_features`` (feature path only).

    Returns
    -------
    tuple
        ``kfold=True``:  ``(train_data, test_data, train_labels, test_labels)``.
        ``kfold=False``: ``(train_data, test_data, val_data, train_labels,
        test_labels, val_labels)`` -- note the different tuple length.

    Raises
    ------
    ValueError
        Raw-signal path only: the number of labels does not equal the number
        of recordings (``data.shape[0]``) in the matching data array.
    """
    def _flatten_channels(data, labels):
        # Merge the first two axes of `data` (assumed (recordings, channels, ...)
        # -- TODO confirm against ld.load_kfold_data) and repeat every label once
        # per channel, so that row k of the data keeps its recording's label.
        n_records, n_channels = data.shape[0], data.shape[1]
        labels = np.ravel(labels)
        if labels.shape[0] != n_records:
            raise ValueError(
                f'Expected {n_records} labels (one per recording), got {labels.shape[0]}')
        flat_data = np.reshape(data, (n_records * n_channels,) + data.shape[2:])
        # Repeat by the data's own channel count instead of a hard-coded 8.
        flat_labels = np.repeat(labels, n_channels)
        return flat_data, flat_labels

    if not kfold:
        # NOTE(review): this path skips the balance report, the reshaping and the
        # feature extraction, and ignores `feature` / `new_ica`.
        train_data, test_data, val_data, train_labels, test_labels, val_labels = ld.load_data(data_type, label_type, epoched = True, binary = True)
        return train_data, test_data, val_data, train_labels, test_labels, val_labels

    train_data, test_data, train_labels, test_labels = ld.load_kfold_data(data_type, label_type, epoched = False, binary = True)

    # Fraction of label-0 ("non-stressed") samples in each split.
    print('\n---- Balanced dataset? ----')
    print(f'Section of non-stressed in train set: {np.sum(train_labels == 0)/len(train_labels)}')
    print(f'Section of non-stressed in test set: {np.sum(test_labels == 0)/len(test_labels)}')

    if feature:
        # Reshape labels to fit (n_recordings*n_channels, 1) and flatten. The
        # reshape also fails loudly if the label count differs from data.shape[0].
        train_labels = np.repeat(train_labels, repeats = v.NUM_CHANNELS, axis = 0).reshape((train_data.shape[0]*v.NUM_CHANNELS,1)).ravel()
        test_labels = np.repeat(test_labels, repeats = v.NUM_CHANNELS, axis = 0).reshape((test_data.shape[0]*v.NUM_CHANNELS,1)).ravel()

        # Other extractors in utils.features: fractal_features, entropy_features,
        # hjorth_features, freq_band_features, kymatio_wave_scattering.
        train_data = f.time_series_features(train_data, new_ica)
        test_data = f.time_series_features(test_data, new_ica)
        return train_data, test_data, train_labels, test_labels

    train_data, train_labels = _flatten_channels(train_data, train_labels)
    test_data, test_labels = _flatten_channels(test_data, test_labels)
    return train_data, test_data, train_labels, test_labels



In [None]:
def load_and_shape_psd_data(label_type):
    """Load data via ``ld.load_psd_data`` and reshape to one row per (recording, channel).

    Parameters
    ----------
    label_type : str
        Label scheme forwarded to ``ld.load_psd_data`` (called with ``binary=True``).

    Returns
    -------
    tuple
        ``(train_data, test_data, train_labels, test_labels)``. Each data array
        has its first two axes merged, and each label is repeated once per
        channel (``data.shape[1]``) so that rows and labels stay aligned.

    Raises
    ------
    ValueError
        The number of labels does not equal the number of recordings
        (``data.shape[0]``) in the matching data array.
    """
    def _flatten_channels(data, labels):
        # Merge the first two axes of `data` (assumed (recordings, channels, ...)
        # -- TODO confirm against ld.load_psd_data) and repeat every label once
        # per channel, so that row k of the data keeps its recording's label.
        n_records, n_channels = data.shape[0], data.shape[1]
        labels = np.ravel(labels)
        if labels.shape[0] != n_records:
            raise ValueError(
                f'Expected {n_records} labels (one per recording), got {labels.shape[0]}')
        flat_data = np.reshape(data, (n_records * n_channels,) + data.shape[2:])
        # Repeat by the data's own channel count instead of a hard-coded 8.
        flat_labels = np.repeat(labels, n_channels)
        return flat_data, flat_labels

    train_data, test_data, train_labels, test_labels = ld.load_psd_data(label_type, binary = True)
    train_data, train_labels = _flatten_channels(train_data, train_labels)
    test_data, test_labels = _flatten_channels(test_data, test_labels)
    return train_data, test_data, train_labels, test_labels

In [None]:
def load_and_shape_SAM40_data():
    """Load the SAM40 'Arithmetic' recordings to be used as an external test set.

    Keeps only the channels listed in ``selected_channels_names``, extracts
    features with ``f.time_series_features(..., SAM40=True)`` and builds a 0/1
    label for every feature row.

    Returns
    -------
    test_data_SAM40
        Feature matrix returned by ``f.time_series_features``.
    label : np.ndarray of int
        0/1 labels from the ``t1_math``, ``t2_math`` and ``t3_math`` columns of
        ``ld_SAM40.load_labels()``, each repeated so that
        ``len(label) == test_data_SAM40.shape[0]``.

    Raises
    ------
    ValueError
        The number of feature rows is not a positive multiple of the number
        of labels.
    """
    selected_channels_names = ['Fp2', 'F4', 'FC6', 'T8', 'Oz', 'O1', 'C3', 'FT9']

    dataset_SAM40_ = ld_SAM40.load_dataset('raw', 'Arithmetic')

    channels = ld_SAM40.load_channels()
    selected_chan_index = [channels.index(elem) for elem in selected_channels_names]
    print(selected_chan_index)

    # Select the channels on axis 1 with a single fancy index, which keeps the
    # original axis order (assumed (recordings, channels, samples) -- TODO
    # confirm). Stacking per-channel slices instead yields
    # (channels, recordings, samples), and np.reshape cannot swap axes: it
    # would silently mix samples from different recordings and channels.
    selected_channels_dataset = np.asarray(dataset_SAM40_)[:, selected_chan_index, :]
    print(selected_channels_dataset.shape)

    test_data_SAM40 = f.time_series_features(selected_channels_dataset, False, SAM40=True)

    labels = ld_SAM40.load_labels()
    # NOTE(review): assumes the recordings in dataset_SAM40_ are ordered like
    # t1, t2, t3 concatenated -- confirm against ld_SAM40.load_dataset.
    label = pd.concat([labels['t1_math'], labels['t2_math'],
                    labels['t3_math']]).to_numpy()

    # Change labels from T/F to 1/0. astype(bool) applies the same truthiness
    # test as `if value:`, and the int cast also converts a bool-dtype array,
    # which element-wise assignment of 1/0 would leave as True/False.
    label = label.astype(bool).astype(int)

    print(label)
    print(label.shape)

    n_rows = test_data_SAM40.shape[0]
    if label.shape[0] == 0 or n_rows % label.shape[0] != 0:
        raise ValueError(
            f'{n_rows} feature rows cannot be split evenly across {label.shape[0]} labels')
    # Presumably the feature rows are grouped per recording, so each label is
    # repeated once for every row its recording produced.
    label = np.repeat(label, n_rows // label.shape[0])

    print(label.shape)
    return test_data_SAM40, label

In [None]:
# Experiment configuration: which data variant and label scheme to use,
# whether to extract per-channel features, and whether to use the k-fold loader.
data_type, label_type = 'raw', 'stai'
feature, kfold = True, True

train_data, test_data, train_labels, test_labels = load_and_shape_data(
    data_type,
    label_type,
    feature=feature,
    kfold=kfold,
)

In [None]:
# Load SAM40 as an additional test set and hand it, together with the main
# train/test split, to the SAM40 SVM routine.
SAM40_data, SAM40_labels = load_and_shape_SAM40_data()
clfs.svm_classification_SAM40(
    train_data, test_data, SAM40_data,
    train_labels, test_labels, SAM40_labels,
)

In [None]:
# KNN classification (clfs.knn_classification) on the train/test split loaded
# above; the header line records which data/label configuration is scored.
print(f'{data_type} data with {label_type} labels')
clfs.knn_classification(train_data, test_data, train_labels, test_labels)

In [None]:
# SVM classification (clfs.svm_classification) on the same train/test split;
# the header line records which data/label configuration is scored.
print(f'{data_type} data with {label_type} labels')
clfs.svm_classification(train_data, test_data, train_labels, test_labels)
