# Dataset

This notebook showcases what the dataset looks like. It shows how to interact with the dataset and introduces some of the helper functions this repo uses.

## Base

### Imports

In [None]:
from matplotlib import pyplot as plt
from matplotlib import rcParams
from scipy import signal
import pandas as pd
import numpy as np
import requests
import torch
import os

### Dataset and Figure Settings

These are a few settings for the dataset and the figures.

In [None]:
# This is the name we'll set for the file
fname = 'motor_imagery.npz'
# This is where we download the dataset from
url = "https://osf.io/ksqv8/download"

# Download only when the file is not already present next to the notebook.
if not os.path.isfile(fname):
  try:
    # Without a timeout a stalled connection would block this cell forever.
    r = requests.get(url, timeout=60)
  except requests.RequestException:
    # RequestException is the base class of ConnectionError, Timeout, etc.,
    # so every transport failure ends up with the same message.
    print("!!! Failed to download data !!!")
  else:
    if r.status_code != requests.codes.ok:
      print("!!! Failed to download data !!!")
    else:
      # Only write the file once a complete, successful response is in hand.
      with open(fname, "wb") as fid:
        fid.write(r.content)

# Figure settings, applied in one go: default canvas size, font size,
# hidden top/right spines and automatic layout adjustment.
rcParams.update({
    'figure.figsize': [20, 4],
    'font.size': 15,
    'axes.spines.top': False,
    'axes.spines.right': False,
    'figure.autolayout': True,
})

### Helper Functions

In [None]:
class process:
    """
    This class is used to preprocess the data.
    The input is a dict with a key 'V' containing the voltage data.

    No instance state is kept, so `process.preprocess(data)` can be called
    directly on the class; calling it on an instance works as well.
    """

    def __init__(self):
        """
        Nothing to initialise; present so that `process()` can be constructed.
        """

    @staticmethod
    def preprocess(data, highpass_hz=50, lowpass_hz=10, fs=1000, order=3):
        """
        Turn the voltage array in `data['V']` into a normalised power envelope.

        Parameters
        ----------
        data : dict
            Must contain the key 'V'; filtering runs along axis 0 of that array.
        highpass_hz : float, optional
            Cut-off of the initial high-pass filter (default 50).
        lowpass_hz : float, optional
            Cut-off of the low-pass filter that smooths the squared signal
            (default 10).
        fs : float, optional
            Sampling rate handed to `signal.butter` (default 1000).
        order : int, optional
            Order of both Butterworth filters (default 3).

        Returns
        -------
        np.ndarray
            Array with the same shape as `data['V']`, divided by its mean
            along axis 0.
        """
        V = data['V'].astype('float32')

        # High-pass, applied forward and backward by filtfilt.
        b, a = signal.butter(order, [highpass_hz], btype='high', fs=fs)
        V = signal.filtfilt(b, a, V, axis=0)

        # Squared magnitude followed by a low-pass gives a smooth envelope.
        V = np.abs(V) ** 2
        b, a = signal.butter(order, [lowpass_hz], btype='low', fs=fs)
        V = signal.filtfilt(b, a, V, axis=0)

        # Normalise so that the mean along axis 0 is 1.
        return V / V.mean(0)

class plots:
    """
    For plotting the dataset.

    Every helper opens a new 20x10 matplotlib figure. Array arguments are
    indexed as `data[:, channel]` and drawn against `trange`. The helpers are
    static, so they can be called on the class (`plots.all_channels1(...)`)
    or on an instance.
    """

    def __init__(self):
        """
        Nothing to initialise; present so that `plots()` can be constructed.
        """

    @staticmethod
    def _start_figure(title):
        """Open a new 20x10 figure and, if `title` is non-empty, set it as suptitle."""
        plt.figure(figsize=(20, 10))
        if title:
            plt.suptitle(title, fontsize=20)

    @staticmethod
    def _draw_channel(traces, channel, trange):
        """Plot column `channel` of every array in `traces` on the current axes."""
        for trace in traces:
            plt.plot(trange, trace[:, channel])
        plt.title('ch%d' % channel)
        plt.xticks([0, 1000, 2000])
        plt.ylim([0, 4])

    @staticmethod
    def _draw_grid(traces, trange, title):
        """Draw one subplot per channel, ten per row, for all arrays in `traces`."""
        # Use the number of columns actually present instead of a fixed 46, so
        # narrower arrays do not raise IndexError and wider ones are not cut off.
        n_channels = min(trace.shape[1] for trace in traces)
        # Keep the original 5x10 layout and only add rows when it is too small.
        n_rows = max(5, -(-n_channels // 10))
        plots._start_figure(title)
        for j in range(n_channels):
            plt.subplot(n_rows, 10, j + 1)
            plots._draw_channel(traces, j, trange)

    @staticmethod
    def singlechannel1(data, channel, trange, title=''):
        """Plot one channel of `data` against `trange` on a new figure."""
        plots._start_figure(title)
        plots._draw_channel([data], channel, trange)

    @staticmethod
    def singlechannel2(data, data2, channel, trange, title=''):
        """Plot the same channel of `data` and `data2` on shared axes."""
        plots._start_figure(title)
        plots._draw_channel([data, data2], channel, trange)

    @staticmethod
    def all_channels1(data, trange, title=''):
        """Plot every channel of `data` in its own subplot."""
        plots._draw_grid([data], trange, title)

    @staticmethod
    def all_channels2(data, data2, trange, title=''):
        """Plot every channel of `data` and `data2`, overlaid per subplot."""
        plots._draw_grid([data, data2], trange, title)

## Basics of the Dataset

The dataset this repo uses is an ECoG dataset that contains recordings of real and imagined movements for a couple of different stimuli. For this project, we primarily focus on imagining moving the tongue and imagining moving the hand.

In [None]:
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary Python
# objects, so only load this file from a source you trust.
alldat = np.load(fname, allow_pickle=True)['dat']

# Take the two recordings stored under the first entry of the dataset.
# Later cells treat index [i][0] as the real-movement recording and
# index [i][1] as the imagined-movement recording of participant i.
dat1 = alldat[0][0]
dat2 = alldat[0][1]

# Show which fields each recording provides.
print(dat1.keys())
print(dat2.keys())

* `dat['V']`: continuous voltage data (time by channels)
* `dat['srate']`: acquisition rate (1000 Hz). All stimulus times are in units of this.  
* `dat['t_on']`: time of stimulus onset in data samples
* `dat['t_off']`: time of stimulus offset, always 400 samples after `t_on`
* `dat['stim_id']`: identity of stimulus (11 = tongue, 12 = hand), real or imagined
* `dat['scale_uv']`: scale factor to multiply the data values to get to microvolts (uV). 
* `dat['locs']`: 3D electrode positions on the brain surface

In [None]:
# nilearn / nimare are only needed for this interactive electrode view.
from nilearn import plotting
from nimare import utils

# No plt.figure() here: view_markers builds its own interactive view, so a
# matplotlib figure opened beforehand would only show up as an empty plot.
locs = dat1['locs']
# NOTE(review): tal2mni presumably converts Talairach coordinates to MNI
# space before plotting - confirm against the nimare documentation.
view = plotting.view_markers(utils.tal2mni(locs),
                             marker_labels=[str(k) for k in range(locs.shape[0])],
                             marker_color='purple',
                             marker_size=5)
view

In [None]:
# Broadband power in time-varying windows, computed step by step.
from scipy import signal

# Subject 0, experiment 0 (real movements).
dat1 = alldat[0][0]

# Voltage data, cast to float32 before filtering.
V = dat1['V'].astype('float32')

# Step 1: 3rd-order Butterworth high-pass at 50 Hz, applied along axis 0.
b, a = signal.butter(N=3, Wn=[50], btype='high', fs=1000)
V = signal.filtfilt(b, a, V, axis=0)

# Step 2: squared magnitude, smoothed by a 10 Hz low-pass = approximate power.
V = np.square(np.abs(V))
b, a = signal.butter(N=3, Wn=[10], btype='low', fs=1000)
V = signal.filtfilt(b, a, V, axis=0)

# Step 3: divide every channel by its own mean so that mean power is 1.
V = V / V.mean(axis=0)

In [None]:
# Cut a 2000-sample window after each stimulus onset, then average the
# broadband power separately over tongue and hand trials.
nt, nchan = V.shape
nstim = len(dat1['t_on'])

# Row k of `ts` holds the sample indices t_on[k] + 0 ... t_on[k] + 1999.
trange = np.arange(2000)
ts = dat1['t_on'][:, None] + trange
V_epochs = V[ts, :].reshape((nstim, 2000, nchan))

# stim_id 11 = tongue, 12 = hand
V_tongue = V_epochs[dat1['stim_id'] == 11].mean(axis=0)
V_hand = V_epochs[dat1['stim_id'] == 12].mean(axis=0)

In [None]:
# Trial-averaged response per channel, used to look for electrodes that
# distinguish tongue from hand movements.
# Note that the behaviors happen some time after the visual cue.

plt.figure(figsize=(20, 10))
for j in range(46):
  ax = plt.subplot(5, 10, j + 1)
  # Only the hand average is drawn; the tongue average is currently disabled:
  #plt.plot(trange, V_tongue[:, j])
  plt.plot(trange, V_hand[:, j])
  plt.title(f'ch{j}')
  plt.xticks((0, 1000, 2000))
  plt.ylim((0, 4))
plt.show()

In [None]:
# Every trial of electrode 20, which has a good response to hand movements.
# Rows are trials, ordered by stimulus id.
plt.subplot(1, 3, 1)
isort = np.argsort(dat1['stim_id'])
plt.imshow(V_epochs[isort][:, :, 20].astype(np.float32),
           cmap='magma',
           vmin=0, vmax=7,
           aspect='auto')
plt.colorbar()
plt.show()

In [None]:
# Same single-trial view for electrode 42, which seems to respond to tongue
# movements. Rows are trials, ordered by stimulus id.
plt.subplot(1, 3, 1)
isort = np.argsort(dat1['stim_id'])
plt.imshow(V_epochs[isort][:, :, 42].astype(np.float32),
           cmap='magma',
           vmin=0, vmax=7,
           aspect='auto')
plt.colorbar()
plt.show()

## Setting up the Dataset

In [None]:
# How much of the dataset to process. Options are 'test' and 'real':
# 'test' processes only the first participant, 'real' processes every
# participant in the file.
mode = 'test'

# Stimulus whose trials are kept when epoching: 12 is the hand, 11 is the tongue
desired_stim = 12

In [None]:
# Load the data
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary Python
# objects, so only load this file from a source you trust.
DataLoad = np.load(fname, allow_pickle=True)['dat']
# Summarise what was loaded: container type, length, shape, and the keys of
# the first recording
type(DataLoad), len(DataLoad), DataLoad.shape, DataLoad[0][0].keys()

In [None]:
# Processed epochs per participant index: real and imagined movement.
realV, imagV = {}, {}

# Metadata per participant index, for the real and imagined recordings.
realMeta, imagineMeta = {}, {}

In [None]:
# Fields copied from each recording into the metadata dictionaries.
desiredKeys = [
    't_off',
    'stim_id',
    't_on',
    'V',
    'scale_uv',
    'locs',
    'srate',
]

In [None]:
# Number of participants to process: one in 'test' mode, all of them in 'real' mode.
if mode == 'test':
    length_of_data = 1
elif mode == 'real':
    length_of_data = len(DataLoad)
else:
    # Fail here with a clear message instead of a NameError on
    # length_of_data in a later cell.
    raise ValueError(f"mode must be 'test' or 'real', got {mode!r}")

In [None]:
# Sample offsets 0..1999 that make up one epoch after each stimulus onset.
trange = np.arange(2000)

In [None]:
def _epochs_for_stim(recording, stim, window):
    """
    Preprocess one recording and cut out the trials of a single stimulus.

    Parameters
    ----------
    recording : dict
        Must provide the keys 'V', 't_on' and 'stim_id'.
    stim : int
        Value of 'stim_id' whose trials are kept.
    window : np.ndarray
        Sample offsets added to every stimulus onset.

    Returns
    -------
    np.ndarray
        Preprocessed data reshaped to (trials, len(window), channels),
        restricted to the trials whose 'stim_id' equals `stim`.
    """
    power = process.preprocess(recording)
    nchan = power.shape[1]
    onsets = recording['t_on']
    # One row of sample indices per trial: onset + every offset in the window.
    ts = onsets[:, np.newaxis] + window
    epochs = np.reshape(power[ts, :], (len(onsets), len(window), nchan))
    return epochs[recording['stim_id'] == stim]


# The helper keeps its intermediates local, so this loop no longer overwrites
# the V_epochs / ts / nstim / nchan variables created by earlier cells.
for i in range(length_of_data):
    real_rec, imag_rec = DataLoad[i][0], DataLoad[i][1]
    print(f"Sample rate of participant (real) {i}: {real_rec['srate']}")
    print(f"Sample rate of participant (imagine) {i}: {imag_rec['srate']}")

    # Same treatment for both recordings; only the destination dicts differ.
    for rec, epoch_store, meta_store in ((real_rec, realV, realMeta),
                                         (imag_rec, imagV, imagineMeta)):
        epoch_store[i] = _epochs_for_stim(rec, desired_stim, trange)
        print(epoch_store[i].shape)
        meta_store[i] = {key: rec[key] for key in desiredKeys}

## Exploring the Dataset

realV contains the preprocessed and properly filtered data for the real movement trials

- realV

`realMeta` contains the metadata for the real movement trials

- realMeta

`imagV` contains the preprocessed and properly filtered data for the imagined movement trials

- imagV

`imagineMeta` contains the metadata for the imagined movement trials

- imagineMeta

In [None]:
# Participant 0, real movement: draw the first and the 30th kept trial, each
# in its own figure (one line per channel).
for trial in (0, 29):
    plt.plot(trange, realV[0][trial, :])
    plt.show()

In [None]:
# First kept real-movement trial of every processed participant, one subplot per channel.
for i in range(length_of_data):
    first_trial = realV[i][0]
    plots.all_channels1(first_trial, trange, title=f'Real Movement: Participant {i}')

In [None]:
# First kept imagined-movement trial of every processed participant, one subplot per channel.
for i in range(length_of_data):
    first_trial = imagV[i][0]
    plots.all_channels1(first_trial, trange, title=f'Imagined Movement: Participant {i}')

In [None]:
# Difference between the first kept real and imagined trials, per channel.
for i in range(length_of_data):
    plots.all_channels1(realV[i][0] - imagV[i][0], trange, f'Real Movement - Imagined Movement: Participant {i}')

In [None]:
# Overlay the first kept real trial and the first kept imagined trial of
# participant 0 in every channel subplot (real is drawn first).
plots.all_channels2(realV[0][0], imagV[0][0], trange)

In [None]:
# Channel 0 of participant 0's first kept real-movement trial on its own figure.
plots.singlechannel1(realV[0][0], 0, trange)

## Making a PyTorch Dataset

In [None]:
class EEGDataset(torch.utils.data.Dataset):
    """
    Minimal Dataset wrapper that serves items straight from `data`.
    """

    def __init__(self, data):
        """Keep a reference to `data`; nothing is copied or converted."""
        self.data = data

    def __len__(self):
        """Number of items, as reported by `len(data)`."""
        n_items = len(self.data)
        return n_items

    def __getitem__(self, index):
        """Return `data[index]` unchanged."""
        item = self.data[index]
        return item

In [None]:
# Wrap participant 0's real-movement epochs and plot the item at index 1.
EEG = EEGDataset(realV[0])
plt.plot(trange, EEG[1])
plt.show()