In [None]:
import copy
import mne
import math
import numpy as np
import pandas as pd
import pickle
import random


from sklearn import preprocessing
from os.path import exists
from sklearn.preprocessing import MinMaxScaler
from sklearn.feature_selection import SelectKBest, f_classif


In [None]:
# --- Notebook-wide configuration -------------------------------------------
# Samples per unit of recording_time (presumably Hz — confirm against the EDF headers).
Sampling_Frequency = 250
# Length of one trail (window) in the time unit above; passed to preprocess() as trail_times.
recording_time = 25
# Default number of spectrum bins kept per channel by SelectKBest in select_features().
k = 400

# Subject counts per group; `healthy` sets the range of subject indices that is shuffled below.
healthy = 14
schizo = 14
# File-name prefixes of the two groups (the preprocess() calls pass 'h' and 's' literally).
file_name = ['h', 's']

# Filter settings. Only iir_params is read as a global (inside preprocess());
# the band edges mirror preprocess()'s own default arguments.
Band_Pass_Low_Range = 0
Band_Pass_High_Range = 50
iir_params = {'order': 2, 'ftype': 'butter'}

In [None]:
# Subject ordering: preprocess() walks this list in order, so it decides which
# subjects land in the training part and which in the test part.
indice = [*range(1, healthy + 1)]
random.shuffle(indice)
# NOTE: the shuffled order above is discarded — the pinned permutation below
# is what the rest of the notebook actually uses.
indice = [1, 12, 5, 11, 13, 10, 2, 9, 6, 8, 4, 3, 7, 14]
print(indice)

In [None]:
def select_channel(raw,selected_channel=['Fp1','Fp2']):
    """Return the data of the requested channels from ``raw``.

    Parameters
    ----------
    raw : object exposing ``ch_names`` (list of channel names) and
        ``get_data(picks=...)``.
    selected_channel : iterable of str
        Channel names to extract. The default list is only read, never mutated.

    Returns
    -------
    Whatever ``raw.get_data(picks=indices)`` returns for the channels that
    were found, in the order they were requested.

    Channels missing from ``raw.ch_names`` are reported on stdout and skipped.
    """
    channel_indices = []
    for channel in selected_channel:
        try:
            channel_indices.append(raw.ch_names.index(channel))
        except ValueError:
            # The channel name must be passed to format(); with no argument the
            # '{}' placeholder raises IndexError and masks the real problem.
            print('Channel {} is not present or again check the names of channel'.format(channel))
    return raw.get_data(picks=channel_indices)
            


In [None]:
def sliding_window_augmentation(data, window_size, stride):
    """Cut a 2-D signal into (possibly overlapping) windows along its second axis.

    ``data`` is converted with ``np.array`` and must be two-dimensional.
    Windows of ``window_size`` columns start at 0, ``stride``, ``2*stride``, ...
    for as long as a full window still fits. The windows are stacked along a
    new leading axis; when no window fits, the result is an empty array.
    """
    signal = np.array(data)
    _, total_columns = signal.shape
    window_starts = range(0, total_columns - window_size + 1, stride)
    windows = [signal[:, begin:begin + window_size] for begin in window_starts]
    return np.array(windows)


In [None]:


def select_features(new_raw, label, samp_freq=Sampling_Frequency, recording_time=25, augment_ratio=2, k=k):
    """Turn a multi-channel recording into a list of ``(feature_vector, label)`` pairs.

    Steps:
      1. Trim ``new_raw`` (indexed as channels x samples) to a whole number of
         trails of ``samp_freq * recording_time`` samples.
      2. Cut it into windows of one trail length with a stride of
         ``trail_length // augment_ratio`` (see sliding_window_augmentation).
      3. For every window and channel take the FFT magnitude and keep the
         ``k`` bins reported by ``SelectKBest(f_classif)``.
      4. Concatenate the kept bins of all channels, min-max scale the resulting
         vector on its own, and pair it with ``label``.

    NOTE(review): the selector is fitted on a single row with a single label,
    so its scores cannot compare classes — verify that the bins it keeps are
    the ones intended.
    """
    samples_per_trail = samp_freq * recording_time
    hop = samples_per_trail // augment_ratio
    whole_trails = np.shape(new_raw)[1] // samples_per_trail
    # Drop the tail that does not fill a complete trail.
    trimmed = new_raw[:, 0:whole_trails * samples_per_trail]
    windows = sliding_window_augmentation(trimmed, samples_per_trail, hop)

    scaler = MinMaxScaler()
    feature = []
    for window in windows:
        kept_per_channel = []
        for channel_signal in window:
            spectrum = np.abs(np.fft.fft(channel_signal))
            selector = SelectKBest(score_func=f_classif, k=k)
            # Only the fitted support mask is used; the transformed output is ignored.
            selector.fit_transform(spectrum.reshape(1, -1), np.array([label]))
            kept_bins = selector.get_support(indices=True)
            kept_per_channel.append(spectrum[kept_bins])
        column = np.concatenate(kept_per_channel, axis=0).reshape(-1, 1)
        # The scaler is re-fitted for every window.
        scaled = scaler.fit_transform(column).ravel()
        feature.append((scaled, label))
    return feature


In [None]:
# Combining all and preprocessing the whole data
def preprocess(subject_name,indices,labels,number_of_subjects=14,train_percent=0.8,Band_Pass_Low_Range=0,Band_Pass_High_Range=50,trail_times=25,augmented_ratio=2):
    """Load, filter and featurise every available subject of one group.

    For each number in ``indices`` the file ``./Dataset/<subject_name><number>.edf``
    is read if it exists; missing files are skipped silently and do not count
    towards the split. Each recording is restricted to EEG channels, IIR
    filtered between ``Band_Pass_Low_Range`` and ``Band_Pass_High_Range``
    (using the module-level ``iir_params``), reduced by ``select_channel`` and
    converted to feature chunks by ``select_features`` with label ``labels``.

    The first ``ceil(train_percent * number_of_subjects)`` files that were found
    feed the training list, all later ones the test list.

    Returns
    -------
    (train_data, test_data) : two lists of the chunks produced by select_features.
    """
    train_quota = math.ceil(train_percent * number_of_subjects)
    train_data, test_data = [], []
    subjects_seen = 0
    for subject_number in indices:
        edf_path = './Dataset/{}{}.edf'.format(subject_name, subject_number)
        if not exists(edf_path):
            continue
        recording = mne.io.read_raw_edf(edf_path, verbose=0, preload=True, eog=None, exclude=(), stim_channel='auto')
        recording = recording.pick_types(eeg=True)
        recording = recording.filter(Band_Pass_Low_Range, Band_Pass_High_Range, iir_params=iir_params, method='iir')
        channel_data = select_channel(recording)
        subject_chunks = select_features(channel_data, label=labels, recording_time=trail_times, augment_ratio=augmented_ratio)
        subjects_seen += 1
        # Subjects are assigned whole: all chunks of one subject go to the same split.
        if subjects_seen <= train_quota:
            train_data += subject_chunks
        else:
            test_data += subject_chunks
    return train_data, test_data



In [None]:
def _fill_short_batches(batches, target_size, max_borrow):
    """Top up, in place, every batch shorter than ``target_size`` with samples borrowed from full batches."""
    used_donors = set()
    while any(len(batch) < target_size for batch in batches):
        for batch in batches:
            missing = target_size - len(batch)
            if missing <= 0:
                continue
            # Only batches that are currently full may lend samples; batches
            # topped up earlier become donors once they reach target_size.
            donors = [j for j, other in enumerate(batches) if len(other) == target_size]
            if not donors:
                raise ValueError('cannot fill short batches: the smaller class has no full batch to draw from')
            # Prefer donors that have not been used yet; once all of them have
            # been used, start over instead of searching forever.
            fresh_donors = [j for j in donors if j not in used_donors]
            if not fresh_donors:
                used_donors.clear()
                fresh_donors = donors
            donor_index = int(np.random.choice(fresh_donors))
            used_donors.add(donor_index)
            donor = batches[donor_index]
            take = min(max_borrow, missing)
            for pick in np.random.choice(len(donor), take, replace=False):
                batch.append(donor[pick])
    return batches


def batches_division(pos_class_train,neg_class_train,batch_size=32,train_set=True):
    """Build shuffled mini-batches holding the same number of samples per class.

    Each class is cut into consecutive slices of ``batch_size // 2`` samples.
    The number of batches is ``max(len(pos), len(neg)) // batch_size`` when
    ``train_set`` is true and ``min(...) // batch_size`` otherwise. If the
    smaller class cannot fill all of its slices, the short slices are topped
    up with samples drawn at random from that same class's full slices, so
    samples of the smaller class may repeat.

    Parameters
    ----------
    pos_class_train, neg_class_train : list
        Samples of the two classes.
    batch_size : int
        Size of a combined batch; half of it comes from each class.
    train_set : bool
        Selects the batch-count rule described above.

    Returns
    -------
    (combined_batches, pos_batches, neg_batches)
        ``combined_batches`` is a shuffled list of shuffled batches mixing both
        classes; the other two are the per-class slices they were built from.

    Raises
    ------
    ValueError
        If slices must be topped up but the smaller class has no full slice.
    """
    half_batch = batch_size // 2
    # Borrow at least one sample per top-up so filling progresses when half_batch is 1.
    samples_select = max(1, half_batch // 2)

    class_sizes = (len(pos_class_train), len(neg_class_train))
    num_batches = (max(class_sizes) if train_set else min(class_sizes)) // batch_size

    pos_batches = [pos_class_train[i * half_batch:(i + 1) * half_batch] for i in range(num_batches)]
    neg_batches = [neg_class_train[i * half_batch:(i + 1) * half_batch] for i in range(num_batches)]

    # The smaller class is the one that may come up short. The topped-up copy
    # must replace that same class's batches, never the other class's.
    fill_negatives = len(pos_class_train) > len(neg_class_train)
    minority_batches = neg_batches if fill_negatives else pos_batches
    if any(len(batch) < half_batch for batch in minority_batches):
        filled = _fill_short_batches(copy.deepcopy(minority_batches), half_batch, samples_select)
        if fill_negatives:
            neg_batches = filled
        else:
            pos_batches = filled

    combined_batches = []
    for pos_batch, neg_batch in zip(pos_batches, neg_batches):
        combined_batch = []
        for pos_chunk, neg_chunk in zip(pos_batch, neg_batch):
            combined_batch.append(pos_chunk)
            combined_batch.append(neg_chunk)
        random.shuffle(combined_batch)
        combined_batches.append(combined_batch)
    random.shuffle(combined_batches)

    return combined_batches, pos_batches, neg_batches


In [None]:
# Build per-class train/test chunks ('h' files get label 0, 's' files get label 1).
pos_class_train, pos_class_test = preprocess('h', indices=indice, labels=0, trail_times=recording_time, augmented_ratio=3, train_percent=0.8)
neg_class_train, neg_class_test = preprocess('s', indices=indice, labels=1, trail_times=recording_time, augmented_ratio=3, train_percent=0.8)

# Only the combined, class-balanced batches are needed for training.
training_data, _, _ = batches_division(pos_class_train, neg_class_train, batch_size=32)

# random.shuffle() shuffles in place and returns None, so the list has to be
# built first and shuffled afterwards; assigning the call's result would leave
# testing_data as None.
testing_data = pos_class_test + neg_class_test
random.shuffle(testing_data)

In [None]:
# Persist both splits as pickles under ./samples/ (the directory must already exist).
for pickle_path, payload in (
    ('./samples/training_augmented.pkl', training_data),
    ('./samples/testing_augmented.pkl', testing_data),
):
    with open(pickle_path, 'wb') as f:
        pickle.dump(payload, f)