"""

Created by Xin Zhang
### References
[1] Bashivan P, et al. Learning Representations from EEG with Deep Recurrent-Convolutional Neural Networks. International Conference on Learning Representations (ICLR), 2016.

[2] EEGLearn-Pytorch, Utils.py. https://github.com/numediart/EEGLearn-Pytorch/blob/master/Utils.py

[3] Bagchi S, Bathula D R. EEG-ConvTransformer for single-trial EEG-based visual stimulus classification. Pattern Recognition, 2022, 129: 108757.

Forked from https://github.com/numediart/EEGLearn-Pytorch/blob/master/Utils.py
"""

In [25]:
import scipy.io as sio
import numpy as np
import torch
from preprocessing.aep import azim_proj, gen_images
import einops

Here we generate fake EEG data to save time. To re-train the model from the cited papers, you can use the public dataset available at https://purl.stanford.edu/bq914sc3730. Other datasets are also supported, as long as they meet requirements like the following.

Each raw EEG trial should have shape [Time, Channels] and a corresponding label y. Each channel (electrode) has a 3D coordinate [X, Y, Z].
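Before converting a new dataset, it can help to check that each trial matches this layout. The helper below, `check_trial`, is hypothetical (it is not part of [2] or of `preprocessing.aep`); it is a minimal sketch assuming NumPy arrays.

```python
import numpy as np


def check_trial(eeg_trial, xyz):
    """Hypothetical helper: validate one trial and its electrode coordinates.

    eeg_trial: array of shape [Time, Channels]
    xyz:       array of shape [Channels, 3], one [X, Y, Z] per electrode
    Returns (Time, Channels) if the shapes are consistent, else raises ValueError.
    """
    eeg_trial = np.asarray(eeg_trial)
    xyz = np.asarray(xyz)
    if eeg_trial.ndim != 2:
        raise ValueError(f"expected [Time, Channels], got shape {eeg_trial.shape}")
    n_channels = eeg_trial.shape[1]
    if xyz.shape != (n_channels, 3):
        raise ValueError(f"expected coordinates of shape ({n_channels}, 3), got {xyz.shape}")
    return eeg_trial.shape


# Example: a 2048-sample, 32-channel trial with matching coordinates
shape = check_trial(np.zeros((2048, 32)), np.zeros((32, 3)))
```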

In [87]:
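num_samples = 64
time = 2048
channels = 32
# Use the GPU when one is available, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
xyz_electrode = torch.randn(size=(channels, 3))  # fake 3D electrode coordinates [X, Y, Z]
eeg = torch.rand(size=(num_samples, time, channels), device=device)  # fake EEG: [num_samples, Time, Channels]
labels = torch.randint(low=0, high=19, size=(num_samples,), device=device)  # fake labels in [0, 19)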

Here we use the FFT to extract the power spectrum in each time window.
Citation [1] states: "Fast Fourier Transform (FFT) is performed on the time series for each trial to estimate the power spectrum of the signal. Oscillatory cortical activity related to memory operations primarily exists in three frequency bands of theta (4-7Hz), alpha (8-13Hz), and beta (13-30Hz) (Bashivan et al., 2014; Jensen & Tesche, 2002)."
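To make the band-power idea concrete, the sketch below computes the power of a synthetic 10 Hz sine in each of the three bands. The 1000 Hz sampling rate and 1 s duration are assumptions chosen so that FFT bins are exactly 1 Hz apart; each band is taken as a half-open interval (lo, hi], matching the masks used later in this notebook.

```python
import numpy as np

fs = 1000                        # assumed sampling rate in Hz
n = 1000                         # 1 s of signal -> bins exactly 1 Hz apart
t = np.arange(n) / fs
x = np.sin(2 * np.pi * 10 * t)   # pure 10 Hz oscillation

spectrum = np.fft.rfft(x, norm="forward")   # one-sided spectrum, scaled by 1/n
freqs = np.fft.rfftfreq(n, d=1 / fs)
power = np.abs(spectrum) ** 2

bands = {"theta": (4, 7), "alpha": (8, 13), "beta": (13, 30)}
band_power = {
    name: power[(freqs > lo) & (freqs <= hi)].sum()
    for name, (lo, hi) in bands.items()
}
dominant = max(band_power, key=band_power.get)
```

Essentially all of the power lands in the alpha band, as expected for a 10 Hz oscillation.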

There is one difference between [1] and [3]; [3] describes it as follows:
"However, contrary to the three frequency power bands from the earlier work, the AEP and interpolation are applied to the preprocessed signal to form a single channel mesh of G1 × G2 per time-frame."
From this short description, we assume the authors divide the EEG signal into fragments and apply the FFT to each fragment to obtain one 'time-frame'. If you interpret it differently, contributions that correct the code are welcome.

In [None]:
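time_win = 128
sample_rate = 1000
# Split the time axis into consecutive windows: [n, Time, C] -> [n, frames, C, time_win]
eeg_ = einops.rearrange(eeg, 'n (f tw) c -> n f c tw', n=num_samples, tw=time_win, c=channels)
print(eeg_.shape)
# A 128-sample window at 1000 Hz gives bins 1000/128 = 7.8 Hz apart, so no bin would fall
# in theta or alpha. Zero-padding each window to n_fft samples interpolates the spectrum
# onto ~1 Hz bins (it does not add true frequency resolution).
n_fft = 1024
power = torch.abs(torch.fft.fft(eeg_, n=n_fft, dim=-1, norm='forward'))
freqs = torch.fft.fftfreq(n=n_fft, d=1/sample_rate).to(power.device)
# Boolean masks over the frequency axis (negative frequencies are excluded)
theta_pass = (freqs > 4) & (freqs <= 7)
alpha_pass = (freqs > 8) & (freqs <= 13)
beta_pass = (freqs > 13) & (freqs <= 30)

theta = power[..., theta_pass]
alpha = power[..., alpha_pass]
beta = power[..., beta_pass]
print(theta.shape)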
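How many FFT bins land in each band depends on the bin spacing, `sample_rate / n` for an n-point FFT. The sketch below counts the bins per band for a 128-point FFT and for a 1024-point (zero-padded) FFT at an assumed 1000 Hz sampling rate. At 128 points the bins are 7.8125 Hz apart, so the theta (4-7 Hz) and alpha (8-13 Hz) masks select nothing and those band features would be all zeros.

```python
import numpy as np

bands = {"theta": (4, 7), "alpha": (8, 13), "beta": (13, 30)}


def bins_per_band(n_fft, sample_rate=1000):
    """Count FFT bins whose frequency falls in each (lo, hi] band."""
    freqs = np.fft.fftfreq(n_fft, d=1 / sample_rate)
    return {name: int(((freqs > lo) & (freqs <= hi)).sum()) for name, (lo, hi) in bands.items()}


coarse = bins_per_band(128)    # bins every 1000 / 128 = 7.8125 Hz
fine = bins_per_band(1024)     # bins every 1000 / 1024 ≈ 0.98 Hz
print(coarse)  # {'theta': 0, 'alpha': 0, 'beta': 2}
print(fine)    # {'theta': 3, 'alpha': 5, 'beta': 17}
```

Zero-padding only interpolates the spectrum between the original bins; the true frequency resolution is still set by the window length, so a longer `time_win` is the other option when the recording allows it.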

"Sum of squared absolute values within each of the three frequency bands was computed and used as separate measurement for each electrode."
After this step we have three tensors, each with shape [num_samples, frames, channels].

In [None]:
# "Sum of squared absolute values" per band: [n, frames, C, bins] -> [n, frames, C]
theta = (theta ** 2).sum(dim=-1)
alpha = (alpha ** 2).sum(dim=-1)
beta = (beta ** 2).sum(dim=-1)
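The quoted definition is a sum of squared magnitudes, which is the square of the L2 norm rather than the norm itself. Both give the same ordering of electrodes within a band, but they differ in scale. A small check with made-up values:

```python
import torch

# Made-up band magnitudes |X| for two electrodes, two frequency bins each
mag = torch.tensor([[3.0, 4.0], [1.0, 2.0]])

band_power = (mag ** 2).sum(dim=-1)                 # sum of squared magnitudes -> [25., 5.]
l2 = torch.linalg.vector_norm(mag, ord=2, dim=-1)   # square root of the same sum -> [5., 2.2361]
```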

The features are arranged as in the official code [2]:
"Features are arranged in band and electrodes order (theta_1, theta_2..., theta_64, alpha_1, alpha_2, ..., beta_64)."
(The quote assumes 64 electrodes; here there are `channels` = 32.) The feature tensor therefore has shape [num_samples, frames, channels*3].

In [None]:
features = torch.cat([theta, alpha, beta], dim=-1)  # [n, frames, 3*C]: all theta, then alpha, then beta
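With this layout, the value for band `b` (0 = theta, 1 = alpha, 2 = beta) and electrode `e` sits at index `b * channels + e` of the last axis; as far as we can tell from [2], `gen_images` relies on the same layout when it splits the feature vector back into one color per band. The sketch below uses made-up values that encode band and electrode (`100 * band + electrode`) so the layout is easy to check; `feature_index` is a hypothetical helper.

```python
import torch

n, frames, channels = 2, 3, 4
electrode = torch.arange(channels, dtype=torch.float32)

# Made-up band features: value = 100 * band + electrode, broadcast to [n, frames, channels]
theta = (0 + electrode).expand(n, frames, channels)
alpha = (100 + electrode).expand(n, frames, channels)
beta = (200 + electrode).expand(n, frames, channels)

features = torch.cat([theta, alpha, beta], dim=-1)   # [n, frames, 3 * channels]


def feature_index(band, e, channels=channels):
    """Hypothetical helper: position of (band, electrode) in the last axis."""
    return band * channels + e
```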

Recall that the electrode coordinates and labels were defined earlier.
Now we project the 3D coordinates (xyz) onto 2D.

In [None]:
# Project every 3D electrode position onto the 2D plane (azimuthal projection, as in [2])
locs_2d = [azim_proj(e) for e in xyz_electrode.numpy()]
print(np.shape(locs_2d))  # (channels, 2)
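`azim_proj` comes from the local `preprocessing.aep` module. Assuming it follows [2], it converts each position to spherical coordinates and maps it to the plane with an azimuthal equidistant projection centered on the top of the head (+z): the distance from the origin is the angle from the pole, and the direction is the azimuth. The NumPy function below, `azim_proj_sketch`, is an illustrative re-implementation under that assumption, not the module's code.

```python
import numpy as np


def azim_proj_sketch(pos):
    """Azimuthal equidistant projection of one 3D position onto 2D (illustrative)."""
    x, y, z = pos
    elev = np.arctan2(z, np.hypot(x, y))  # elevation above the xy-plane
    az = np.arctan2(y, x)                 # azimuth within the xy-plane
    rho = np.pi / 2 - elev                # angular distance from the +z pole
    return rho * np.cos(az), rho * np.sin(az)


vertex = azim_proj_sketch((0.0, 0.0, 1.0))   # electrode at the top of the head -> origin
side = azim_proj_sketch((1.0, 0.0, 0.0))     # electrode on the "equator" -> distance pi/2
```

Because only angles are used, the projection ignores the radius: scaling an electrode position leaves its 2D location unchanged.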

Finally, it is time to generate the images. We need one image for each sample and each time frame, so we first merge the sample and frame axes and then interpolate:

In [None]:
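# Merge samples and frames so that each row is one frame: [(n f), 3*C], as a NumPy array
feats = einops.rearrange(features, 'n f c3 -> (n f) c3').cpu().numpy()
# gen_images (as in [2]) interpolates every row at once and returns [(n f), colors, m, m]
images = gen_images(locs=np.array(locs_2d),
                    features=feats,
                    n_gridpoints=32,
                    normalize=True)
images = torch.from_numpy(np.asarray(images, dtype=np.float32))
# Split the rows back into samples and frames: [n, f, colors, m, m]
time_frame_3d = einops.rearrange(images, '(n f) color h w -> n f color h w', n=num_samples)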
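Roughly speaking, for each row of `feats` and each band, `gen_images` in [2] places the electrode values at their 2D positions and interpolates them onto an `n_gridpoints` × `n_gridpoints` grid with SciPy's `griddata` (cubic), setting grid points outside the electrodes' convex hull to zero. The single-frame, single-band sketch below illustrates that step; `interp_frame` is a hypothetical helper and omits the normalization and augmentation options of the original.

```python
import numpy as np
from scipy.interpolate import griddata


def interp_frame(locs_2d, values, n_gridpoints=32):
    """Hypothetical helper: cubic interpolation of one frame onto an m x m grid."""
    locs_2d = np.asarray(locs_2d, dtype=float)
    # Regular grid spanning the bounding box of the projected electrodes
    grid_x, grid_y = np.mgrid[
        locs_2d[:, 0].min():locs_2d[:, 0].max():n_gridpoints * 1j,
        locs_2d[:, 1].min():locs_2d[:, 1].max():n_gridpoints * 1j,
    ]
    img = griddata(locs_2d, values, (grid_x, grid_y), method="cubic", fill_value=np.nan)
    return np.nan_to_num(img)  # grid points outside the convex hull become 0


rng = np.random.default_rng(0)
locs = rng.uniform(-1.0, 1.0, size=(32, 2))   # made-up 2D electrode positions
img = interp_frame(locs, rng.random(32))      # made-up values for one band, one frame
```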

The result has shape [num_samples, frames, colors, m, m].
Since the input of EEG-ConvTransformer is [batch, time-frame, 1, m, m], we average over the color (band) dimension.

In [None]:
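import os

time_frame_3d = torch.mean(time_frame_3d, dim=2, keepdim=True)  # [n, f, 1, m, m]
os.makedirs("sample_data", exist_ok=True)
# savemat expects NumPy arrays, so move the tensors to the CPU first
sio.savemat("sample_data/time_frames.mat", {"img": time_frame_3d.cpu().numpy()})
sio.savemat("sample_data/labels.mat", {"lab": labels.cpu().numpy()})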
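When the saved files are read back with `scipy.io.loadmat`, 1-D arrays come back as 2-D row vectors, because `savemat` stores them as MATLAB matrices (`oned_as='row'` by default). The sketch below round-trips made-up arrays through a temporary directory to show the shapes to expect.

```python
import os
import tempfile

import numpy as np
import scipy.io as sio

frames = np.random.rand(4, 16, 1, 32, 32).astype(np.float32)   # [n, f, 1, m, m]
labels = np.random.randint(0, 19, size=4)                       # [n]

with tempfile.TemporaryDirectory() as tmp:
    sio.savemat(os.path.join(tmp, "time_frames.mat"), {"img": frames})
    sio.savemat(os.path.join(tmp, "labels.mat"), {"lab": labels})
    loaded_img = sio.loadmat(os.path.join(tmp, "time_frames.mat"))["img"]
    loaded_lab = sio.loadmat(os.path.join(tmp, "labels.mat"))["lab"]

restored_labels = loaded_lab.ravel()   # back to shape [n]
```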