## Import some libraries

In [1]:
import numpy as np                                      # for dealing with data
import pandas as pd
from scipy.signal import butter, sosfiltfilt, sosfreqz  # for filtering
import matplotlib.pyplot as plt                         # for plotting
import seaborn as sns
from sklearn.model_selection import cross_val_score
from sklearn.metrics import accuracy_score, precision_score, recall_score, roc_auc_score, confusion_matrix, roc_curve, auc
import os
from os import listdir
from os.path import isfile, join, isdir

# For IC-U-Net
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Activation, Permute, Dropout
from tensorflow.keras.layers import Conv2D, MaxPooling2D, AveragePooling2D
from tensorflow.keras.layers import SeparableConv2D, DepthwiseConv2D
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import SpatialDropout2D
from tensorflow.keras.layers import Input, Flatten
from tensorflow.keras import backend as K
from tensorflow.keras import optimizers
from tensorflow.keras import utils as np_utils
from tensorflow.keras.callbacks import ModelCheckpoint

# For ASR
import mne
import asrpy
from asrpy import asr_calibrate, asr_process

## Settings

In [2]:
# --- Acquisition / filter settings -------------------------------------------
fs = 200.0      # sampling rate in Hz
lowcut = 1.0    # band-pass lower cut-off in Hz
highcut = 20.0  # band-pass upper cut-off in Hz
# NOTE(review): the visible generate_epoch passes its own literal cut-offs to the
# filter, so lowcut/highcut are not consulted there - confirm which is intended.

# --- Epoch window, in milliseconds relative to the stimulus -------------------
epoch_s = -100  # epoch start
epoch_e = 600   # epoch end
bl_s = -100     # baseline-correction window start
bl_e = 0        # baseline-correction window end

# Epoch length in samples. It is derived as (end index - start index), using the
# same int() truncation generate_epoch applies to its slice bounds, so the
# pre-allocated arrays match the extracted epochs for any window / sampling rate.
# The previous abs(epoch_s) + abs(epoch_e) form was only right when the window
# straddles the stimulus (epoch_s <= 0 <= epoch_e).
epoch_len = int(epoch_e * (fs / 1000)) - int(epoch_s * (fs / 1000))
print(epoch_len)

# --- Dataset layout ----------------------------------------------------------
train_subj_num = 16      # subjects in ./data/train
test_subj_num = 10       # subjects in ./data/test
stimulus_per_subj = 340  # stimuli (epochs) per subject, summed over its sessions
trial_per_subj = 5       # session files per subject

# EEG channel (column) names extracted from every recording.
channels = [
    'Fp1', 'Fp2', 'F7', 'F3', 'Fz', 'F4', 'F8',
    'FT7', 'FC3', 'FCz', 'FC4', 'FT8',
    'T7', 'C3', 'Cz', 'C4', 'T8',
    'TP7', 'CP3', 'CPz', 'CP4', 'TP8',
    'P7', 'P3', 'Pz', 'P4', 'P8',
    'O1', 'POz', 'O2'
]
print(len(channels))

140
30


## Bandpass Filter

In [3]:
# For butterworth band pass filter
def butter_bandpass_filter(raw_data, fs, lowcut = 1.0, highcut = 40.0, order = 5):
    """Zero-phase Butterworth band-pass filter.

    Args:
        raw_data: array-like signal; filtered along its last axis
            (the default axis of ``sosfiltfilt``).
        fs: sampling rate in Hz.
        lowcut: lower cut-off frequency in Hz.
        highcut: upper cut-off frequency in Hz.
        order: order passed to ``scipy.signal.butter``.

    Returns:
        The filtered signal, same shape as ``raw_data``.
    """
    nyquist = 0.5 * fs
    # butter() expects the band edges as fractions of the Nyquist frequency.
    band = [lowcut / nyquist, highcut / nyquist]
    # Second-order sections are used for numerical robustness; sosfiltfilt runs
    # the filter forward and backward, which cancels the phase delay.
    sections = butter(order, band, analog = False, btype = 'band', output = 'sos')
    return sosfiltfilt(sections, raw_data)

## Load data

In [4]:
# Ground-truth labels for the training epochs (the original note gives 5440 rows).
train_labels = pd.read_csv("./data/TrainLabels.csv")

# One CSV file per recording session. Only "*.csv" entries are kept so that stray
# directory entries (e.g. ".ipynb_checkpoints", ".DS_Store") cannot change the
# file count and break the fixed-shape reshape below.
train_list_arr = np.array(sorted(
    name for name in listdir('./data/train') if name.endswith('.csv')))
# Sorted names are grouped row-wise: one row per subject, one column per session.
train_list_np = np.reshape(
    train_list_arr, (train_subj_num, trial_per_subj)) #! (16, 5)

test_list_arr = np.array(sorted(
    name for name in listdir('./data/test') if name.endswith('.csv')))
test_list_np = np.reshape(
    test_list_arr, (test_subj_num, trial_per_subj)) #! (10, 5)

print(train_list_np.shape, test_list_np.shape)

# Empty accumulators with the per-subject epoch layout
# (subjects, stimuli, channels, samples).
train_data_list = np.empty(
    (0, stimulus_per_subj, len(channels), epoch_len), np.float64) #! (0, 340, 30, 140) or (0, 340, 30, 156)
test_data_list = np.empty(
    (0, stimulus_per_subj, len(channels), epoch_len), np.float64) #! (0, 340, 30, 140) or (0, 340, 30, 156)

print(train_data_list.shape, test_data_list.shape)

(16, 5) (10, 5)
(0, 340, 30, 140) (0, 340, 30, 140)


## IC-U-Net (requires the `model/` directory and its `.pth.tar` checkpoint files)

In [5]:
import numpy as np
import csv
import time
import torch
import os
import shutil
from scipy.signal import decimate, resample_poly, firwin, lfilter

import torch
from torch import nn
import torch.nn.functional as F

# Set CUDA_VISIBLE_DEVICES to "0" for this process. By CUDA convention this limits
# CUDA-enabled libraries to the first GPU; it only takes effect if it is set before
# CUDA is initialised.
os.environ["CUDA_VISIBLE_DEVICES"]="0"

class VGGBlock(nn.Module):
    """Two stacked ``Conv1d -> BatchNorm1d -> ReLU`` stages.

    Args:
        in_channels: channels of the input tensor.
        middle_channels: channels produced by the first convolution.
        out_channels: channels produced by the second convolution.
        ks: kernel size of both convolutions. The padding is ``(ks - 1) / 2``
            truncated to int, so an odd ``ks`` preserves the sequence length.
    """

    def __init__(self, in_channels, middle_channels, out_channels, ks=7):
        super().__init__()
        conv_args = {"kernel_size": ks, "padding": int((ks - 1) / 2)}

        # The attribute names below end up in state_dict keys; keep them stable
        # so that saved checkpoints continue to match.
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv1d(in_channels, middle_channels, **conv_args)
        self.bn1 = nn.BatchNorm1d(middle_channels)
        self.conv2 = nn.Conv1d(middle_channels, out_channels, **conv_args)
        self.bn2 = nn.BatchNorm1d(out_channels)

    def forward(self, x):
        """Apply both stages to ``x`` (batch, channels, length) and return the result."""
        hidden = self.relu(self.bn1(self.conv1(x)))
        return self.relu(self.bn2(self.conv2(hidden)))

class NestedUNet3(nn.Module):
    """Nested 1-D U-Net (three pooling levels, densely connected decoder nodes).

    Node ``convR_C`` sits at resolution level ``R`` (0 = full length) and dense
    column ``C``; each decoder node consumes every earlier node of its row plus
    the up-sampled node from the row below.

    Args:
        num_classes: output channels of each of the three 1x1 heads.
        input_channels: channels of the input signal.
        deep_supervision: stored on the instance only; ``forward`` always
            returns all three head outputs regardless of its value.
        **kwargs: accepted and ignored.

    Note:
        The input length must be divisible by 8, otherwise the pooled and
        up-sampled tensors no longer line up for concatenation.
    """

    def __init__(self, num_classes, input_channels=30, deep_supervision=False, **kwargs):
        super().__init__()

        f0, f1, f2, f3 = 32, 64, 128, 256

        self.deep_supervision = deep_supervision

        self.pool = nn.MaxPool1d(2)
        self.up = nn.Upsample(scale_factor=2, mode='linear', align_corners=False)

        # Encoder column (C = 0).
        self.conv0_0 = VGGBlock(input_channels, f0, f0)
        self.conv1_0 = VGGBlock(f0, f1, f1)
        self.conv2_0 = VGGBlock(f1, f2, f2)
        self.conv3_0 = VGGBlock(f2, f3, f3)

        # First decoder column: one same-row input + the up-sampled lower row.
        self.conv0_1 = VGGBlock(f0 + f1, f0, f0)
        self.conv1_1 = VGGBlock(f1 + f2, f1, f1)
        self.conv2_1 = VGGBlock(f2 + f3, f2, f2)

        # Second decoder column: two same-row inputs + the up-sampled lower row.
        self.conv0_2 = VGGBlock(f0 * 2 + f1, f0, f0)
        self.conv1_2 = VGGBlock(f1 * 2 + f2, f1, f1)

        # Third decoder column.
        self.conv0_3 = VGGBlock(f0 * 3 + f1, f0, f0)

        # One 1x1 output head per full-resolution decoder node.
        self.final1 = nn.Conv1d(f0, num_classes, kernel_size=1)
        self.final2 = nn.Conv1d(f0, num_classes, kernel_size=1)
        self.final3 = nn.Conv1d(f0, num_classes, kernel_size=1)

    def forward(self, input):
        """Return ``(output1, output2, output3)``, the heads on x0_1, x0_2 and x0_3."""
        def merge(*tensors):
            # Concatenate along the channel axis.
            return torch.cat(list(tensors), 1)

        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x0_1 = self.conv0_1(merge(x0_0, self.up(x1_0)))

        x2_0 = self.conv2_0(self.pool(x1_0))
        x1_1 = self.conv1_1(merge(x1_0, self.up(x2_0)))
        x0_2 = self.conv0_2(merge(x0_0, x0_1, self.up(x1_1)))

        x3_0 = self.conv3_0(self.pool(x2_0))
        x2_1 = self.conv2_1(merge(x2_0, self.up(x3_0)))
        x1_2 = self.conv1_2(merge(x1_0, x1_1, self.up(x2_1)))
        x0_3 = self.conv0_3(merge(x0_0, x0_1, x0_2, self.up(x1_2)))

        return self.final1(x0_1), self.final2(x0_2), self.final3(x0_3)

class DoubleConv(nn.Module):
    """(convolution => [BN] => Sigmoid) * 2

    Both convolutions use ``kernel_size`` with padding ``(kernel_size - 1) / 2``
    truncated to int, so an odd kernel size preserves the sequence length.
    """

    def __init__(self, in_channels, out_channels,kernel_size=7):
        super().__init__()
        pad = int((kernel_size - 1) / 2)

        # Build the two stages in order; the positions inside the Sequential
        # (0..5) are part of the state_dict keys and must not change.
        stages = []
        for stage_in in (in_channels, out_channels):
            stages.append(nn.Conv1d(stage_in, out_channels, kernel_size=kernel_size, padding=pad))
            stages.append(nn.BatchNorm1d(out_channels))
            stages.append(nn.Sigmoid())
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` (batch, channels, length) through both stages."""
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv.

    The sequence length is halved by ``MaxPool1d(2)`` before ``DoubleConv``
    maps ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        halve = nn.MaxPool1d(2)
        convs = DoubleConv(in_channels, out_channels,kernel_size)
        # Order matters: the DoubleConv lives at index 1 in the state_dict keys.
        self.maxpool_conv = nn.Sequential(halve, convs)

    def forward(self, x):
        """Pool then convolve ``x`` (batch, channels, length)."""
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv.

    Unlike a classic U-Net decoder step, no skip connection is concatenated:
    ``forward`` only up-samples ``x1`` and convolves it.

    Args:
        in_channels: channels of the tensor being up-sampled.
        out_channels: channels produced by the double convolution.
        kernel_size: kernel size of the double convolution.
        bilinear: if True, up-sample with parameter-free linear interpolation;
            otherwise use a learned stride-2 transposed convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size, bilinear=True):
        super().__init__()

        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='linear', align_corners=False)
        else:
            # The tensor reaching self.up has in_channels channels (nothing is
            # concatenated afterwards), so the transposed convolution must map
            # in_channels -> in_channels. The former in_channels // 2 sizing
            # raised a channel-mismatch error on the first forward pass.
            self.up = nn.ConvTranspose1d(in_channels, in_channels, kernel_size=2, stride=2)

        self.conv = DoubleConv(in_channels, out_channels, kernel_size)

    def forward(self, x1, x2):
        """Up-sample ``x1`` by 2 and apply the double convolution.

        ``x2`` (the encoder feature map passed by the caller) is accepted for
        interface compatibility but is not used.
        """
        return self.conv(self.up(x1))


class OutConv(nn.Module):
    """Final projection: a single biased ``Conv1d`` with length-preserving padding
    for odd kernel sizes (padding is ``(kernel_size - 1) / 2`` truncated to int)."""

    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=int((kernel_size - 1) / 2),
            bias=True,
        )

    def forward(self, x):
        """Project ``x`` (batch, in_channels, length) to ``out_channels`` channels."""
        return self.conv(x)

class UNet1(nn.Module):
    """1-D encoder/decoder: three pooling steps down, three up-sampling steps up.

    Args:
        n_channels: channels of the input signal.
        n_classes: channels of the output.
        bilinear: stored on the instance; it is not forwarded to the ``Up``
            blocks, which therefore use their own default.
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        self.n_channels, self.n_classes, self.bilinear = n_channels, n_classes, bilinear

        # Encoder: channel width doubles while the kernel shrinks 7 -> 7 -> 5 -> 3.
        self.inc = DoubleConv(n_channels, 64, kernel_size=7)
        self.down1 = Down(64, 128, kernel_size=7)
        self.down2 = Down(128, 256,kernel_size=5)
        self.down3 = Down(256, 512,kernel_size=3)
        # Decoder: channel width halves at every step.
        self.up1 = Up(512, 256, kernel_size=3)
        self.up2 = Up(256, 128, kernel_size=3)
        self.up3 = Up(128, 64, kernel_size=3)
        self.outc = OutConv(64, n_classes,kernel_size=1)

    def forward(self, x):
        """Return the ``n_classes``-channel output for ``x`` (batch, n_channels, length)."""
        enc1 = self.inc(x)
        enc2 = self.down1(enc1)
        enc3 = self.down2(enc2)
        bottom = self.down3(enc3)
        # The encoder maps are handed to Up as its second argument.
        dec = self.up1(bottom, enc3)
        dec = self.up2(dec, enc2)
        dec = self.up3(dec, enc1)
        return self.outc(dec)

def resample(signal, fs):
    """Bring a multi-channel signal towards a 256 Hz sampling rate.

    Args:
        signal: iterable of 1-D channel arrays (e.g. a (channels, samples) array).
        fs: current sampling rate in Hz.

    Returns:
        A float64 array built from the per-channel results.

    NOTE(review): the rate change uses an *integer* factor obtained by
    truncation, so only integer multiples/fractions of 256 Hz are converted
    exactly. For example fs=200 gives an up-sampling factor of int(256 / 200) = 1,
    i.e. the samples come back unchanged and still at 200 Hz. The visible caller
    indexes the result with original-rate sample indices, so this is left as is.
    """
    target_fs = 256  # desired sample rate

    if fs > target_fs:
        factor = int(fs / target_fs)  # down-sampling factor
        converted = [decimate(channel, factor) for channel in signal]
    elif fs < target_fs:
        factor = int(target_fs / fs)  # up-sampling factor
        converted = [resample_poly(channel, factor, 1) for channel in signal]
    else:
        converted = signal

    return np.array(converted).astype(np.float64)

def FIR_filter(signal, lowcut, highcut):
    """Band-pass ``signal`` with a 1000-tap FIR filter designed for 256 Hz data.

    Args:
        signal: array filtered along its last axis (the ``lfilter`` default).
        lowcut: lower pass-band edge in Hz.
        highcut: upper pass-band edge in Hz.

    Returns:
        The filtered signal, same shape as ``signal``.

    NOTE(review): ``lfilter`` is causal, so the output of this linear-phase
    filter lags the input by roughly (numtaps - 1) / 2 samples - confirm that
    callers relying on sample-accurate timing account for this.
    """
    design_fs = 256.0  # sampling rate the filter is designed for
    numtaps = 1000     # number of FIR filter taps
    # pass_zero=False with two cut-offs yields a band-pass design.
    taps = firwin(numtaps, [lowcut, highcut], pass_zero=False, fs=design_fs)
    return lfilter(taps, 1.0, signal)


def read_train_data(file_name):
    """Load a purely numeric CSV file as a float64 array (one array row per CSV row)."""
    with open(file_name, 'r', newline='') as f:
        rows = list(csv.reader(f))

    # Cells arrive as strings; the conversion to float happens in astype().
    return np.array(rows).astype(np.float64)


def cut_data(raw_data):
    """Split a (channels, samples) signal into 1024-sample CSV chunks in ./temp2/.

    Chunk ``i`` is written to ``./temp2/<i>.csv``. When the signal holds fewer
    than two full chunks, the whole signal is written as the single chunk 0.

    Args:
        raw_data: 2-D array-like, converted to float64.

    Returns:
        The number of chunk files written.

    NOTE(review): with two or more chunks, samples beyond the last full
    1024-sample boundary are not written, so the reassembled signal can be
    shorter than the input.
    """
    chunk_len = 1024
    signal = np.array(raw_data).astype(np.float64)
    total = max(1, int(len(signal[0]) / chunk_len))

    for i in range(total):
        table = signal if total == 1 else signal[:, i * chunk_len:(i + 1) * chunk_len]
        filename = './temp2/' + str(i) + '.csv'
        with open(filename, 'w', newline='') as csvfile:
            print('Writing file {}'.format(filename))
            csv.writer(csvfile).writerows(table)
    return total


def glue_data(file_name, total):
    """Reassemble the decoded chunk files into one (channels, samples) array.

    Reads ``<file_name>output<i>.csv`` for ``i`` in ``range(total)`` and joins
    the chunks along the sample axis. At every seam, the last column of the
    data glued so far and column 1 of the incoming chunk are both replaced by
    their average.

    Args:
        file_name: path prefix (typically a directory ending in '/').
        total: number of chunk files.

    Returns:
        The concatenated float64 array, or the integer 0 when ``total`` is 0.

    NOTE(review): the seam uses column index 1 (the second sample) of the
    incoming chunk, not column 0 - confirm this is intended.
    """
    pieces = []
    for idx in range(total):
        chunk_path = file_name + 'output{}.csv'.format(str(idx))
        with open(chunk_path, 'r', newline='') as f:
            piece = np.array(list(csv.reader(f))).astype(np.float64)

        if pieces:
            previous = pieces[-1]
            seam = (previous[:, -1] + piece[:, 1]) / 2
            previous[:, -1] = seam
            piece[:, 1] = seam
        pieces.append(piece)

    if not pieces:
        return 0
    return np.concatenate(pieces, axis=1)


def save_data(data, filename):
    """Write ``data`` (an iterable of rows) to ``filename`` as CSV, overwriting it."""
    with open(filename, 'w', newline='') as csvfile:
        csv.writer(csvfile).writerows(data)

def dataDelete(path):
    """Recursively delete the directory ``path``.

    Best effort: an ``OSError`` (for example when the path does not exist) is
    printed rather than raised.
    """
    try:
        shutil.rmtree(path)
    except OSError as err:
        print(err)


def decode_data(data, std_num, mode=5):
    """Run one normalised chunk through a pretrained denoising network.

    Args:
        data: 2-D array (channels, samples); both supported networks are built
            for 30 channels. A batch axis is added internally.
        std_num: scale factor multiplied back onto the network output (the
            caller passes the chunk's standard deviation). A value of 0 leaves
            the output unscaled.
        mode: "ICUNet" (``UNet1``) or "UNetpp" (``NestedUNet3``). The default
            ``5`` is kept for signature compatibility but is not a valid
            selection.

    Returns:
        float64 array of shape (1, channels, samples_out).

    Raises:
        ValueError: if ``mode`` does not name a supported network. This
            replaces the unbound-variable error the unknown-mode path used
            to hit.
    """
    if mode == "ICUNet":
        model = UNet1(n_channels=30, n_classes=30)
        resumeLoc = './model/ICUNet/modelsave' + '/BEST_checkpoint.pth.tar'
    elif mode == "UNetpp":
        model = NestedUNet3(num_classes=30)
        resumeLoc = './model/UNetpp/modelsave' + '/checkpoint.pth.tar'
    else:
        raise ValueError(
            'Unknown denoising mode {!r}; expected "ICUNet" or "UNetpp"'.format(mode))

    checkpoint = torch.load(resumeLoc, map_location='cpu')
    # strict=False tolerates key mismatches, which would silently leave layers
    # at their random initialisation - so report any parameter left unloaded.
    load_result = model.load_state_dict(checkpoint['state_dict'], strict=False)
    if load_result.missing_keys:
        print('Warning: {} parameter(s) missing from {} kept their initial values, e.g. {}'.format(
            len(load_result.missing_keys), resumeLoc, load_result.missing_keys[:3]))

    model.eval()
    with torch.no_grad():
        batch = torch.Tensor(data[np.newaxis, :, :])
        if mode == "UNetpp":
            # NestedUNet3 returns three heads; the deepest one is used.
            _, _, decode = model(batch)
        else:
            decode = model(batch)
        # Undo the caller's normalisation. The comparison is done on the float
        # value: truncating to int first skipped the rescale whenever
        # |std_num| < 1.
        if float(std_num) != 0.0:
            decode = decode * std_num
    return decode.cpu().numpy().astype(np.float64)

def preprocessing(signal, samplerate):
    """Prepare a multi-channel signal for chunk-wise denoising.

    Recreates an empty ``./temp2/`` working directory, then resamples,
    FIR band-pass filters (1-40 Hz) and splits the signal into chunk files.

    Args:
        signal: (channels, samples) array.
        samplerate: sampling rate of ``signal`` in Hz.

    Returns:
        The number of chunk files written by ``cut_data``.
    """
    work_dir = "./temp2/"
    # Start from an empty working directory: drop leftovers of a previous run.
    if os.path.exists(work_dir):
        dataDelete(work_dir)
    os.makedirs(work_dir, exist_ok=True)

    resampled = resample(signal, samplerate)
    band_limited = FIR_filter(resampled, 1, 40) #! In original code highcut = 50
    return cut_data(band_limited)


# model = tf.keras.models.load_model('./denoise_model/')
def reconstruct(model_name, total):
    """Denoise every chunk in ./temp2/ and glue the results back together.

    For each chunk ``./temp2/<i>.csv`` the data is standardised with the
    chunk-wide mean and standard deviation, decoded by the selected network
    (which multiplies the standard deviation back on) and written to
    ``./temp2/output<i>.csv``. The working directory is removed afterwards.

    Args:
        model_name: network selector forwarded to ``decode_data``.
        total: number of chunk files, as returned by ``preprocessing``.

    Returns:
        The glued (channels, samples) array from ``glue_data``.
    """
    second1 = time.time()

    # -------------------decode_data---------------------------
    for i in range(total):
        chunk = read_train_data('./temp2/{}.csv'.format(str(i)))

        std = np.std(chunk)
        avg = np.average(chunk)

        # A perfectly flat chunk has std == 0; dividing by it would feed NaNs
        # to the network. In that case only the mean is removed.
        scale = std if std != 0 else 1.0
        normalised = (chunk - avg) / scale

        d_data = decode_data(normalised, std, model_name)[0]
        save_data(d_data, "./temp2/output{}.csv".format(str(i)))

    # --------------------glue_data----------------------------
    data_after_denoised = glue_data("./temp2/", total)
    # -------------------delete_data---------------------------
    dataDelete("./temp2/")
    second2 = time.time()

    print("Reconstruct has been success ", second2 - second1, "sec(s)")
    return data_after_denoised

## Set up denoising type

In [19]:
# Denoising method selector. generate_epoch reads this module-level name when it
# runs, so changing it here changes what every later call does.
# Values handled by generate_epoch: "ASR", "filter", "ICUNet"; any other value
# (including "" and "ICA", which has no dedicated branch) means no denoising.
denoise_type = ""
# denoise_type = "ASR"
# denoise_type = "filter"
# denoise_type = "ICUNet"
# denoise_type = "ICA"

## Epoch Generation

In [20]:
def generate_epoch(file_path, channels, fs, eeg_filter, stimulus_times=None, baseline=True,  epoch_s=-100, epoch_e=600, bl_s=-100, bl_e=0, denoise_method=None):
    """Cut stimulus-locked, optionally baseline-corrected epochs from one recording.

    The whole recording is denoised first; epochs are then sliced around each
    stimulus.

    Args:
        file_path: CSV file with one column per channel and, unless
            ``stimulus_times`` is given, a ``FeedBackEvent`` column whose value
            1 marks a stimulus row.
        channels: list of channel (column) names to extract.
        fs: sampling rate in Hz.
        eeg_filter: callable used for the "filter" method, invoked per channel
            as ``eeg_filter(signal_1d, fs, 1, 50)``.
        stimulus_times: optional stimulus onsets in seconds; converted to
            sample indices with ``fs``.
        baseline: subtract the mean of the baseline window from every epoch.
        epoch_s, epoch_e: epoch window in ms relative to the stimulus.
        bl_s, bl_e: baseline window in ms relative to the stimulus.
        denoise_method: "filter", "ICUNet", "ASR", or anything else for no
            denoising. ``None`` (default) falls back to the module-level
            ``denoise_type``.

    Returns:
        float64 array of shape (n_stimuli, n_channels, epoch_len). ``epoch_len``
        follows from the epoch window and ``fs`` (140 for the defaults at
        200 Hz), except for "ICUNet", which uses a fixed 136 samples. Epochs
        that do not fit inside the recording are reported and left as zeros.

    Raises:
        ValueError: if ``baseline`` is requested but the baseline window is
            empty or lies outside the epoch.
    """
    method = denoise_type if denoise_method is None else denoise_method
    use_icunet = method == "ICUNet"
    # Fixed epoch length used for ICUNet output; later cells pre-allocate the same size.
    icunet_epoch_len = 136

    # --- Window arithmetic, all in samples -----------------------------------
    samples_per_ms = fs / 1000
    e_s = int(epoch_s * samples_per_ms)  # epoch start offset (-20 by default)
    e_e = int(epoch_e * samples_per_ms)  # epoch end offset (120 by default)
    epoch_len = e_e - e_s                # length of the slice actually taken
    # Baseline window as offsets inside the epoch (0 and 20 by default).
    b_s = int((bl_s - epoch_s) * samples_per_ms)
    b_e = int((bl_e - epoch_s) * samples_per_ms)

    final_epoch_len = icunet_epoch_len if use_icunet else epoch_len
    if baseline and not (0 <= b_s < b_e <= final_epoch_len):
        raise ValueError(
            'Baseline window [{}, {}] ms is empty or outside the epoch [{}, {}] ms'.format(
                bl_s, bl_e, epoch_s, epoch_e))

    # --- Load the recording ---------------------------------------------------
    recording = pd.read_csv(file_path)
    raw_eeg = np.ascontiguousarray(recording[channels].values.T, dtype=np.float64)
    print('EEG shape after channel selection', raw_eeg.shape)

    if stimulus_times is None:
        mark_indices = recording[recording['FeedBackEvent'] == 1].index.to_numpy()
    else:
        mark_indices = np.round(np.asarray(
            stimulus_times).flatten() * fs).astype(int)

    print('epoch_len: ', epoch_len)

    if method == "ICUNet":
        print('Denoising By ICUNet...')
    elif method == "ASR":
        print('Denoising By ASR...')
    elif method == "filter":
        print('Filtering only...')
    else:
        print('No any denoising method...')

    # --- Step 1: denoise the entire recording ---------------------------------
    # `clean` is indexable by channel position and yields that channel's 1-D signal.
    if method == "filter":
        clean = [eeg_filter(channel_signal, fs, 1, 50) for channel_signal in raw_eeg]
    elif method == "ICUNet":
        total_file_num = preprocessing(raw_eeg, fs)
        clean = reconstruct("ICUNet", total_file_num)
    elif method == "ASR":
        M, T = asr_calibrate(raw_eeg, fs, cutoff=5)
        clean = asr_process(raw_eeg, fs, M, T)
    else:
        clean = raw_eeg

    # --- Step 2: slice the epochs ---------------------------------------------
    final_epoch = np.zeros((len(mark_indices), len(channels), final_epoch_len), dtype=np.float64)

    for i, mark_idx in enumerate(mark_indices):
        start, stop = mark_idx + e_s, mark_idx + e_e
        for j, channel in enumerate(channels):
            channel_signal = clean[j]
            if use_icunet:
                # The denoised signal may be shorter than the recording, so the
                # window is clipped, then truncated / zero-padded to a fixed length.
                epoch = channel_signal[max(0, start):min(len(channel_signal), stop)]
                if len(epoch) > final_epoch_len:
                    epoch = epoch[:final_epoch_len]
                elif len(epoch) < final_epoch_len:
                    padded_epoch = np.zeros(final_epoch_len)
                    padded_epoch[:len(epoch)] = epoch
                    epoch = padded_epoch
            elif start < 0:
                # A negative start would wrap around to the end of the signal;
                # treat it as out of range so the epoch is skipped below.
                epoch = channel_signal[0:0]
            else:
                epoch = channel_signal[start:stop]

            # If the length is not enough, skip
            if len(epoch) != final_epoch_len:
                print(f'Epoch length not match for event {i}, channel {channel}, expected {final_epoch_len}, got {len(epoch)}, skip')
                continue

            if baseline:
                # Not in place: `epoch` can be a view of the denoised signal, and
                # an in-place subtraction would corrupt samples shared with any
                # overlapping epoch.
                epoch = epoch - np.mean(epoch[b_s:b_e])

            final_epoch[i, j, :] = epoch

    return final_epoch

## Preparing data

In [21]:
# Accumulators for all training subjects.
# Final shape: (16, 340, 30, 140), or (16, 340, 30, 136) when ICUNet is used.
train_data_list = np.empty((0, stimulus_per_subj, len(channels), epoch_len), np.float64)
train_data_list_ICUNet = np.empty((0, stimulus_per_subj, len(channels), 136), np.float64)

# Epochs are generated only when no cached array exists on disk.
# NOTE(review): the cache file name does not encode denoise_type, and when the
# file already exists this cell leaves both accumulators empty.
if not isfile("./data/train_data.npy"):
    for training_participant_id in range(train_subj_num): # 0-15
        subject_dir_list = train_list_np[training_participant_id]
        subject_epoch = np.empty((0, len(channels), epoch_len), np.float64) # (0, 30, 140)
        subject_epoch_ICUNet = np.empty((0, len(channels), 136), np.float64) # (0, 30, 136)

        # Stack the epochs of this subject's sessions along the stimulus axis.
        for trial_id in range(trial_per_subj): # 0-4
            subject_dir = subject_dir_list[trial_id]
            print('Subject directory: ', subject_dir)

            data = generate_epoch(
                './data/train/'+subject_dir, channels, fs, butter_bandpass_filter,
                epoch_s = epoch_s, epoch_e = epoch_e, bl_s = bl_s, bl_e = bl_e)
            print('Epoched data shape: ', data.shape)

            if denoise_type == "ICUNet":
                subject_epoch_ICUNet = np.concatenate([subject_epoch_ICUNet, data], axis=0) # Last shape: (340, 30, 136)
            else:
                subject_epoch = np.concatenate([subject_epoch, data], axis=0) # Last shape: (340, 30, 140)

        # Add a leading subject axis, report, and append to the accumulator.
        if denoise_type == "ICUNet":
            subject_epoch_ICUNet = subject_epoch_ICUNet[np.newaxis, ...]
            print('Epoched subject data shape: ' + str(subject_epoch_ICUNet.shape))
            train_data_list_ICUNet = np.concatenate([train_data_list_ICUNet, subject_epoch_ICUNet], axis=0)
        else:
            subject_epoch = subject_epoch[np.newaxis, ...]
            print('Epoched subject data shape: ' + str(subject_epoch.shape))
            train_data_list = np.concatenate([train_data_list, subject_epoch], axis=0)

    # Persist the epochs so later runs can skip the generation above.
    if denoise_type == "ICUNet":
        print('Train data list denoised: ', train_data_list_ICUNet.shape)
        np.save('./data/train_data.npy', train_data_list_ICUNet)
    else:
        print('Train data list: ', train_data_list.shape)
        np.save('./data/train_data.npy', train_data_list)


Subject directory:  Data_S02_Sess01.csv
EEG shape after channel selection (30, 132001)
epoch_len:  140
No any denoising method...
Epoched data shape:  (60, 30, 140)
Subject directory:  Data_S02_Sess02.csv
EEG shape after channel selection (30, 128001)
epoch_len:  140
No any denoising method...
Epoched data shape:  (60, 30, 140)
Subject directory:  Data_S02_Sess03.csv
EEG shape after channel selection (30, 127001)
epoch_len:  140
No any denoising method...
Epoched data shape:  (60, 30, 140)
Subject directory:  Data_S02_Sess04.csv
EEG shape after channel selection (30, 128001)
epoch_len:  140
No any denoising method...
Epoched data shape:  (60, 30, 140)
Subject directory:  Data_S02_Sess05.csv
EEG shape after channel selection (30, 196001)
epoch_len:  140
No any denoising method...
Epoched data shape:  (100, 30, 140)
Epoched subject data shape: (1, 340, 30, 140)
Subject directory:  Data_S06_Sess01.csv
EEG shape after channel selection (30, 132001)
epoch_len:  140
No any denoising method..

In [22]:
# Preparing testing data

# Accumulators for all test subjects.
# Final shape: (10, 340, 30, 140), or (10, 340, 30, 136) when ICUNet is used.
test_data_list = np.empty((0, stimulus_per_subj, len(channels), epoch_len), np.float64)
test_data_list_ICUNet = np.empty((0, stimulus_per_subj, len(channels), 136), np.float64)

# Epochs are generated only when no cached array exists on disk.
# NOTE(review): the cache file name does not encode denoise_type, and when the
# file already exists this cell leaves both accumulators empty.
if not isfile("./data/test_data.npy"):
    for testing_participant_id in range(test_subj_num): # 0-9
        subject_dir_list = test_list_np[testing_participant_id]
        subject_epoch = np.empty((0, len(channels), epoch_len), np.float64) # (0, 30, 140)
        subject_epoch_denoised = np.empty((0, len(channels), 136), np.float64) # (0, 30, 136)

        # Stack the epochs of this subject's sessions along the stimulus axis.
        for trial_id in range(trial_per_subj): # 0-4
            subject_dir = subject_dir_list[trial_id]
            print('Subject directory: ', subject_dir)

            data = generate_epoch(
                './data/test/'+subject_dir, channels, fs, butter_bandpass_filter,
                baseline=True, epoch_s = epoch_s, epoch_e = epoch_e, bl_s = bl_s, bl_e = bl_e)
            print('Epoched data shape: ', data.shape)

            if denoise_type == "ICUNet":
                subject_epoch_denoised = np.concatenate([subject_epoch_denoised, data], axis=0) # Last shape: (340, 30, 136)
            else:
                subject_epoch = np.concatenate([subject_epoch, data], axis=0) # Last shape: (340, 30, 140)

        # Add a leading subject axis, report, and append to the accumulator.
        if denoise_type == "ICUNet":
            subject_epoch_denoised = subject_epoch_denoised[np.newaxis, ...]
            print('Epoched subject data shape: ' + str(subject_epoch_denoised.shape))
            test_data_list_ICUNet = np.concatenate([test_data_list_ICUNet, subject_epoch_denoised], axis=0)
        else:
            subject_epoch = subject_epoch[np.newaxis, ...]
            print('Epoched subject data shape: ' + str(subject_epoch.shape))
            test_data_list = np.concatenate([test_data_list, subject_epoch], axis=0) # Last shape: (10, 340, 30, 140)

    # Persist the epochs so later runs can skip the generation above.
    if denoise_type == "ICUNet":
        print('test data list denoised: ', test_data_list_ICUNet.shape)
        np.save('./data/test_data.npy', test_data_list_ICUNet)
    else:
        print('test data list: ', test_data_list.shape)
        np.save('./data/test_data.npy', test_data_list)

Subject directory:  Data_S01_Sess01.csv
EEG shape after channel selection (30, 127401)
epoch_len:  140
No any denoising method...
Epoched data shape:  (60, 30, 140)
Subject directory:  Data_S01_Sess02.csv
EEG shape after channel selection (30, 120801)
epoch_len:  140
No any denoising method...
Epoched data shape:  (60, 30, 140)
Subject directory:  Data_S01_Sess03.csv
EEG shape after channel selection (30, 120801)
epoch_len:  140
No any denoising method...
Epoched data shape:  (60, 30, 140)
Subject directory:  Data_S01_Sess04.csv
EEG shape after channel selection (30, 123001)
epoch_len:  140
No any denoising method...
Epoched data shape:  (60, 30, 140)
Subject directory:  Data_S01_Sess05.csv
EEG shape after channel selection (30, 194001)
epoch_len:  140
No any denoising method...
Epoched data shape:  (100, 30, 140)
Epoched subject data shape: (1, 340, 30, 140)
Subject directory:  Data_S03_Sess01.csv
EEG shape after channel selection (30, 138001)
epoch_len:  140
No any denoising method..