In [59]:
import pandas as pd
import numpy as np
import os
import shutil
import matplotlib.pyplot as plt
from scipy import interpolate, signal
import warnings
import mne
import random
import torch
warnings.filterwarnings('ignore')

# Raw data preprocessing
We segmented the original tsv file to obtain its valid information, and intercepted the data of the corresponding time period in which it viewed the videos for segmentation to obtain the relevant eye movement data about the 10 videos viewed in disordered order for each subject.

In [60]:
def from_time_2_stamp(time):
    time = time.split(':')
    assert len(time) == 3
    stamp    = 0
    time_h   = 60*60*1000*int(time[0])
    time_m   = 60*1000*int(time[1])
    time_s   = time[2].split('.')[0]
    time_s   = 1000*int(time_s)
    time_ms  = time[2].split('.')[1][:3]
    time_ms  = int(time_ms)
    stamp    = time_h + time_m + time_s + time_ms
    return stamp

In [61]:
def read_trigger(eye_data, trigger_data):

    start_time = eye_data['LocalTimeStamp'].values[0] 
    base       = from_time_2_stamp(start_time)

    trigger = []
    for i in range(trigger_data.shape[0]):
        cur_trigger = []
        cur_trigger.append(int(trigger_data.iloc[i, 0]))
        stamp = from_time_2_stamp(trigger_data.iloc[i, 1].split(' ')[1])
        # cur_trigger.append((stamp - base)*1000)
        cur_trigger.append((stamp - base))
        trigger.append(cur_trigger)
    return trigger

In [62]:
def find_stamp(trigger, start, end):
    start_idx = []
    end_idx   = []
    for i in range(len(trigger)):
        if trigger[i][0] == start:
            start_idx.append(trigger[i][1])
        if trigger[i][0] == end:
            end_idx.append(trigger[i][1])
    return start_idx, end_idx


In [63]:
def find_index_for_a_trigger(time_col, trigger):
    left = 0
    right = len(time_col) - 1
    while right - left > 1:
        mid = int(left + (right - left) / 2)
        if time_col[mid] == trigger:
            return mid
        elif time_col[mid] < trigger:
            left = mid
        else:
            right = mid
    if trigger <= (time_col[left] + time_col[right]) / 2:
        return left
    else:
        return right

In [64]:
def EYE_preprocess(COLUMN_USED, path = './raw_data', clip_path = './raw_data_clip'):
    """
    Main function of eye movement file preprocessing
    """
    if os.path.exists(clip_path):
        shutil.rmtree(clip_path)
    os.mkdir(clip_path)
    id_list = os.listdir(path)
    id_list.sort()
    for id in id_list:
        print("Begin to process person {}.".format(id))
        id_path = os.path.join(path, id)
        clip_path_one=os.path.join(clip_path,id)
        os.mkdir(clip_path_one)
        eye_file, trigger_file, info_file = None, None, None
        for file in os.listdir(id_path):
            if file.split('.')[1] == 'tsv':
                eye_file = os.path.join(id_path, file)
            if file.split('.')[1] == 'csv' and file.split('_')[-2] == 'trigger':
                trigger_file = os.path.join(id_path, file)
            if file.split('.')[1] == 'csv' and file.split('_')[-2] == 'save':
                info_file = os.path.join(id_path, file)
        eye_data = pd.read_csv(eye_file, sep='\t', low_memory = False)
        eye_data = eye_data[COLUMN_USED]
        trigger_data = pd.read_csv(trigger_file)
        trigger = read_trigger(eye_data, trigger_data)
        stamp = find_stamp(trigger, 1, 2)
        for i in range(len(stamp[0])):
            start_ind = find_index_for_a_trigger(eye_data['RecordingTimestamp'], stamp[0][i])
            end_ind = find_index_for_a_trigger(eye_data['RecordingTimestamp'], stamp[1][i])
            # clip_data = eye_data[stamp[0][i]:stamp[1][i]]
            clip_data = eye_data[start_ind:end_ind]
            save_path = os.path.join(clip_path_one, id[2] + '_' + id + '_' + str(i) + '.csv')
            clip_data=clip_data[['GazePointLeftX (ADCSpx)','GazePointRightX (ADCSpx)', 'GazeEventType']]
            clip_data.to_csv(save_path)

# Label Processing
The next step is to determine whether the eye movement signal is focused or not based on the degree of focusing, i.e., the positional difference between the left and right eyes. We retain the *GazeEventType* attribute of the eye movement data and compute the positional difference between the left and right eyes, and use a combination of the standard deviation and the mean to make a judgment.

In addition, considering the large noise level and the existence of segments that did not successfully record valid eye movement data, we firstly interpolate them and finally smooth them to modify the intervals that are too short (usually caused by noise) to obtain a smoother focused and unfocused interval. 

We end up using the *Focus* attribute to record whether it's focused or not. To make subsequent EEG cutting and alignment easier, we chose to splice the 10-video slices to obtain a total data file about each subject. 

In addition, since the images are played in disorganized order, they need to be sorted chronologically according to the timestamp at the end of the merge to ensure that the eye-tracking labels are saved strictly in chronological order.

In [65]:
def Denoise(data, interval = 10):
    sequence = data['Focus'].tolist()
    i = 0
    while i < len(sequence):
        if sequence[i] == 0:
            count = 0
            j = i
            while j < len(sequence) and sequence[j] == 0:
                count += 1
                j += 1
            if count <= interval:
                for k in range(i, j):
                    sequence[k] = 1
            i = j
        else:
            i += 1
    data['Focus'] = sequence
    return data

In [66]:
def Focus(data, interval = 10):
    for i in range(len(data['GazePointLeftX (ADCSpx)'])):
        if (pd.isnull(data['GazePointLeftX (ADCSpx)'][i]) and not pd.isnull(data['GazePointRightX (ADCSpx)'][i])) or (not pd.isnull(data['GazePointLeftX (ADCSpx)'][i]) and pd.isnull(data['GazePointRightX (ADCSpx)'][i])):
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                data['GazePointLeftX (ADCSpx)'][i] = np.nan
                data['GazePointRightX (ADCSpx)'][i] = np.nan
    data['GazePointLeftX (ADCSpx)'] = data['GazePointLeftX (ADCSpx)'].interpolate(method = 'cubic')
    data['GazePointRightX (ADCSpx)'] = data['GazePointRightX (ADCSpx)'].interpolate(method = 'cubic')
    data['Gaze difference'] = data['GazePointLeftX (ADCSpx)']-data['GazePointRightX (ADCSpx)']
    mean = data['Gaze difference'].mean()
    std = data['Gaze difference'].std()
    data['Focus'] = 1
    for i in data.index.tolist():
        data['Focus'][i] = 1 if data['Gaze difference'][i] <= (mean + std) else 0
    data = data.drop(columns = ['GazePointLeftX (ADCSpx)', 'GazePointRightX (ADCSpx)'])
    return Denoise(data = data, interval = interval)

In [67]:
def Display_focus(data, name):
    fig, ax1 = plt.subplots(figsize = (5,4))
    ax1.set_xlabel('Timestamp')
    ax1.set_ylabel('Gaze difference', color='b')
    ax1.plot(data.index.tolist(),data['Gaze difference'], color = 'b', ls = '-', lw = 1)
    # ax1.set_xticks(data.index.tolist())
    ax1.tick_params(axis = 'y', labelcolor = 'b')
    ax2 = ax1.twinx()  
    ax2.set_ylabel('Focus', color='r') 
    ax2.plot(data.index.tolist(),data['Focus'], color = 'r', ls = ':', lw = 0.5)
    ax2.tick_params(axis = 'y', labelcolor = 'r')
    fig.tight_layout() 
    plt.title(name)
    plt.show()

In [68]:
def Label_process(pathbase, savebase, display = False, interval = 10):
    """
    Main function of label processing
    """
    if os.path.exists(savebase):
        shutil.rmtree(savebase)
    os.mkdir(savebase)
    person = os.listdir(pathbase)
    person.sort()
    for j in person:
        print("Begin to process person {}.".format(j))
        perpath = os.path.join(pathbase,j)
        eye_list = os.listdir(perpath)
        eye_list.sort()
        persave = os.path.join(savebase,j)
        os.mkdir(persave)
        for i in eye_list:
            path = os.path.join(perpath,i)
            data = pd.read_csv(path)
            if data['GazePointLeftX (ADCSpx)'].isnull().all() and data['GazePointRightX (ADCSpx)'].isnull().all():
                print("{} has no valid data!".format(i))
                continue
            data = Focus(data = data, interval = interval)
            if display:
                Display_focus(data = data, name = i)
            data = data.drop(columns = 'Gaze difference')
            savepath = os.path.join(persave,i)
            data.columns = ['Timestamp', 'GazeEventType', 'Focus']
            for k in data.index.tolist():
                data['GazeEventType'][k] = 1 if data['GazeEventType'][k] == 'Fixation' else 0
            data.to_csv(savepath, index = False)
            print("Finished processing {}!".format(i))

# Alignment of EEG raw data and eye movement label
In the following, we preprocessed the EEG signals. Considering that the sampling frequencies of EEG and eye movement are different, we first need to downsample the EEG signal so that it can match the eye movement label perfectly. According to the event label, the data of the corresponding time period of viewing the video is taken out and downsampled, at this time, the data is strictly in chronological order, and then, it is merged with the eye movement labels to become a complete preliminary dataset.

It should be noted that, since each sample contains two label attributes, in order to be more accurate, we set that only when the attributes of *GazeEventType* and *Focus* are both 1, we will determine that it is a positive sample, otherwise, it is a negative sample, and we use new labels to indicate the positive and negative attributes. After such an operation it is still possible that the interval is too short, i.e., there is noise, so we perform smoothing again.

This is followed by the separation of positive and negative sample intervals. We separated the positive and negative samples into two separate data sets based on the labels. However, considering that if the intervals are too short it may cause problems in subsequent feature extraction, which is usually due to the fact that too short intervals lead to the inability to obtain enough frequency bands when they are subjected to the short-time Fourier transform, we screened them again, and obtained more robust data after eliminating the shorter intervals.

In [69]:
def Down_sample(perpath, eye_path):
    for name in os.listdir(perpath):
        if name.split('.')[1] == 'cnt':
            if name[-9:-4] == 'curry':
                cntpath_curry = os.path.join(perpath, name)
            else:
                cntpath = os.path.join(perpath, name)
    cntindex = mne.io.read_raw_cnt(cntpath)
    cntdata = mne.io.read_raw_cnt(cntpath_curry)
    raw_data, _ = cntdata.get_data(return_times = True)
    events, dict = mne.events_from_annotations(cntindex)
    start, end = events[events[:,2] == dict['1']][:,0], events[events[:,2] == dict['2']][:,0]
    videos = []
    eye_list = os.listdir(eye_path)
    for i in range(start.shape[0]):
        video = raw_data[:,start[i]:end[i]].tolist()
        eye_file = pd.read_csv(os.path.join(eye_path, eye_list[i]))
        for k in range(len(video)):
            video[k] = signal.resample(video[k], eye_file.shape[0]).tolist()
        video = np.array(video)
        videos.append(video.T)
    return videos

In [70]:
def Aligning(videos, labels):
    datasets = videos
    datasets1 = datasets[:,0:32]
    datasets2 = datasets[:,33:42]
    datasets3 = datasets[:,43:(datasets.shape[1]-3)]
    datasets = np.concatenate((datasets1, datasets2, datasets3), axis = 1)
    final_labels = []
    for i in range(labels.shape[0]):
        if labels[i][0] == 1 and labels[i][1] == 1:
            final_labels.append(1)
        else:
            final_labels.append(0)
    i = 0
    interval = 10
    sequence = final_labels
    while i < len(sequence):
        if sequence[i] == 0:
            count = 0
            j = i
            while j < len(sequence) and sequence[j] == 0:
                count += 1
                j += 1
            if count <= interval:
                for k in range(i, j):
                    sequence[k] = 1
            i = j
        else:
            i += 1
    final_labels = np.array(sequence)
    final_labels = np.expand_dims(final_labels, axis = 1)
    datasets = np.concatenate((datasets, final_labels), axis = 1)
    return datasets

In [71]:
def Segmentation(dataset):
    pos_data, neg_data = [], []
    start = end = 0
    for i in range(dataset.shape[0]-1):
        if dataset[i][-1] != dataset[i+1][-1] or (i+1) % 5001 == 0:
            end = i
            if end - start + 1 >= 50: 
                if dataset[i][-1] == 1:
                    pos_data.append(dataset[start:end+1,:])
                else:
                    neg_data.append(dataset[start:end+1,:])
                start = i+1
    return pos_data, neg_data

# Feature extraction
In the following, we need to find the relevant features of the EEG signals, so we perform a full-length short-time Fourier transform for all intervals, which is extracted as 5 samples of feature 67 channels, reshape them into a column, and subsequently add labels and merge the set of positive and negative samples to get the final usable dataset about each subject.

In [72]:
def get_average_psd(energy_graph, freq_bands, sample_freq, stft_n=256):
    start_index = int(np.floor(freq_bands[0] / sample_freq * stft_n))
    end_index = int(np.floor(freq_bands[1] / sample_freq * stft_n))
    ave_psd = np.mean(energy_graph[:, start_index - 1:end_index] ** 2, axis=1)
    return ave_psd

In [73]:
def extract_psd_feature(raw_data, freq_bands, stft_n=256):
    n_channels, n_samples = raw_data.shape
    psd_feature = np.zeros((len(freq_bands), n_channels))
    start_index, end_index = 0, n_samples
    window_data = raw_data[:, start_index:end_index]
    hdata = window_data * signal.hann(n_samples)
    fft_data = np.fft.fft(hdata, n=stft_n)
    energy_graph = np.abs(fft_data[:, 0: int(stft_n / 2)])
    for band_index, band in enumerate(freq_bands):
        band_ave_psd = get_average_psd(energy_graph, band, n_samples, stft_n)
        psd_feature[band_index, :] = band_ave_psd
    return psd_feature.T

In [74]:
def Extracton(data, freq_band, label, stft_n=256):
    for i in range(len(data)):
        data[i] = extract_psd_feature(data[i].T[0:data[i].shape[1]-1,:], freq_bands = freq_band, stft_n = stft_n)
        data[i] = data[i].reshape(1, data[i].size).squeeze().tolist()
        if label == 1:
            data[i].append(1)
        else:
            data[i].append(0)
    return np.array(data)

In [75]:
def Feature_extracton(path, savebase, data_path, freq_bands):
    """
    Main function of alignment and feature extraction
    """
    if os.path.exists(data_path):
        shutil.rmtree(data_path)
    os.mkdir(data_path)
    person = []
    for name in os.listdir(path):
        person.append(name)
    person.sort()
    for j in person:
        print("Begin to process person {}.".format(j))
        perpath = os.path.join(path,j)
        eye_path = os.path.join(savebase,j)
        os.mkdir(os.path.join(data_path, j))
        videos = Down_sample(perpath = perpath, eye_path = eye_path)
        id_list = os.listdir(eye_path)
        indices = []
        for i in range(len(id_list)):
            index = []
            tmp = pd.read_csv(os.path.join(eye_path, id_list[i]))
            index.append(tmp['Timestamp'][0])
            index.append(os.path.join(eye_path, id_list[i]))
            indices.append(index)
        indices = sorted(indices, key=(lambda x:x[0]))
        for i in range(len(indices)):
            label = pd.read_csv(indices[i][1])[['GazeEventType', 'Focus']]
            label = np.array(label)
            datasets = Aligning(videos = videos[i], labels = label)
            pos_data, neg_data = Segmentation(dataset = datasets)
            pos_data = Extracton(data = pos_data, freq_band = freq_bands, label = 1)
            neg_data = Extracton(data = neg_data, freq_band = freq_bands, label = 0)
            data = np.concatenate((pos_data, neg_data), axis = 0)
            data = data[~np.isnan(data).any(axis = 1)]
            np.random.shuffle(data)
            print(data.shape)
            np.save(os.path.join(data_path, j , str(i) + '.npy'), data)

# Project configurations
Here we define the various necessary paths, whether or not data preprocessing is required for a particular session, and the parameters that need to be read corresponding to certain data.

In [76]:
COLUMN_USED = ['RecordingTimestamp', 'LocalTimeStamp',
               'GazePointLeftX (ADCSpx)', 'GazePointRightX (ADCSpx)', 'GazeEventType']
preprocess, label, feature_extracton = False, False, False
path = './raw_data'
clip_path = './raw_data_clip'
savebase = './raw_data_focus'
data_path = './dataset'
freq_bands = [(1,4),(4,8),(8,14),(14,31),(31,49)]
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
print(device)

cuda:1


# Preprocessing Operations
Here we select the module corresponding to the execution of data preprocessing by the configuration above.

In [77]:
if preprocess:
    EYE_preprocess(COLUMN_USED=COLUMN_USED,path = path, clip_path = clip_path)
if label:
    Label_process(pathbase = clip_path, savebase = savebase, display = False, interval = 10)
if feature_extracton:
    Feature_extracton(path = path, savebase = savebase, data_path = data_path, freq_bands = freq_bands)

Begin to process person huangsiye_20210529_1.
Used Annotations descriptions: ['1', '2', '3', '4', '5', '6']
(49, 311)
(57, 311)
(61, 311)
(39, 311)
(110, 311)
(101, 311)
(19, 311)
(43, 311)
(73, 311)
(59, 311)
Begin to process person huangsiye_20210531_2.
Used Annotations descriptions: ['1', '2', '3']
(72, 311)
(67, 311)
(70, 311)
(137, 311)
(62, 311)
(30, 311)
(48, 311)
(40, 311)
(106, 311)
(26, 311)
Begin to process person huangsiye_20210604_3.
Used Annotations descriptions: ['1', '2', '3']
(31, 311)
(48, 311)
(84, 311)
(34, 311)
(13, 311)
(80, 311)
(38, 311)
(52, 311)
(101, 311)
(52, 311)
Begin to process person liangjie_20210424_2.
Used Annotations descriptions: ['1', '2', '3', '6']
(139, 311)
(153, 311)
(128, 311)
(238, 311)
(98, 311)
(54, 311)
(68, 311)
(67, 311)
(150, 311)
(88, 311)
Begin to process person liuzhiwei_20210608_1.
Used Annotations descriptions: ['1', '2', '3', '4', '5', '6']
(86, 311)
(88, 311)
(115, 311)
(42, 311)
(158, 311)
(163, 311)
(34, 311)
(66, 311)
(171, 31