In [2]:
#334a4ba3ee7a3dc9ff8373e22d7cf2fd31e6198668a4ae16
#!pip install PyWavelets mne  pandas numpy matplotlib

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import mne
import pywt
from torch.utils.data import DataLoader, Dataset
import torch
import pickle
import os

In [5]:
def file_to_DataDrame(path):
    """
    This function takes in a file path and returns a dataframe with the data and the target values
    format:
        Fc5	        Fc3	        Fc1	        ...	Oz	        O2	        Iz	        target
    0	-0.000046	-0.000041	-0.000032	...	0.000040	0.000108	0.000055	0
    1	-0.000054	-0.000048	-0.000034	...	0.000064	0.000114	0.000074	0
    ...
    """

    reader = mne.io.read_raw_edf(path, preload=True)
    annotations = reader.annotations  # get the values of the annotations
    codes = annotations.description  # get the codes from the annotations

    df = pd.DataFrame(
        reader.get_data().T,
        columns=[channel.replace(".", "") for channel in reader.ch_names],
    )  # transpose the data to get the right shape
    df = df[~(df == 0).all(axis=1)]  # remove rows with all zeros
    timeArray = np.array(
        [round(x, 10) for x in np.arange(0, len(df) / 160, 0.00625)]
    )  # create an array of time values

    codeArray = []
    counter = 0
    for timeVal in timeArray:
        if (
            timeVal in annotations.onset
        ):  # if the time value is in the onset array, add the corresponding code to the codeArray
            counter += 1
        code_of_target = int(
            codes[counter - 1].replace("T", "")
        )  # convert T0 to 0, T1 to 1, etc
        codeArray.append(code_of_target)

    df["target"] = np.array(codeArray).T
    return df


def save_to_pickle(data, file_path):
    with open(file_path, "wb") as f:
        pickle.dump(data, f)


def load_from_pickle(file_path):
    with open(file_path, "rb") as f:
        data = pickle.load(f)
    return data

In [6]:
def read_all_file_df(num_exp=[3, 4], num_people=2):
    """condct all files in one dataframe"""
    all_df = pd.DataFrame()
    for subject in range(1, num_people):
        for file in num_exp:
            fileName = f"files/S{subject:03d}/S{subject:03d}R{file:02d}.edf"
            df = file_to_DataDrame(fileName)
            all_df = pd.concat([all_df, df], axis=0)
    return all_df

In [7]:
df = read_all_file_df()

Extracting EDF parameters from /home/daniel/repos/Decoding_of_EEG/files/S001/S001R03.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...
Extracting EDF parameters from /home/daniel/repos/Decoding_of_EEG/files/S001/S001R04.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...


In [8]:
len(df)

39840

In [45]:
def df_to_CWTfiles(df, num_of_rows):
    """i"""
    for i in range(0, len(df), num_of_rows):
        signals = df.iloc[i : i + num_of_rows].values
        all_cwt= np.zeros((num_of_rows,100,65))
        if signals.shape == (1000,64):
            signals=signals.transpose(1,0)
        for signal in signals:
            signal = (signal - signal.mean()) / signal.std()
            time = np.linspace(0, len(signal) / 160, len(signal))
            widths = np.geomspace(1, 200, num=100)  # range of scales
            sampling_period = np.diff(time).mean()  # 0.006251562890722681
            cwtmatr, _ = pywt.cwt(signal, widths, "cgau4", sampling_period=sampling_period)
            cwtmatr= np.abs(cwtmatr[:-1,:-1])
            cwtmatr = np.abs(cwtmatr)
            print(cwtmatr.shape)
        
        print(int(i/num_of_rows))
        save_to_pickle(cwtmatr,f"cwt_data{i}")

        del cwtmatr

In [46]:
num_of_rows = 1000
df_to_CWTfiles(df,1000)

0
1


KeyboardInterrupt: 

In [None]:
zer=np.zeros((4,100))

In [None]:
len(zer)

4

In [None]:
for i in range(0, len(df), num_of_rows):
    chunks = df.iloc[i : i + num_of_rows].values
    print(len(chunks[:,1]))
    chanks_to_CWTchanks(chunks)

1000
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'num

KeyboardInterrupt: 

In [None]:
num_of_rows = 1000
chunks = [df[i : i + num_of_rows] for i in range(0, df.shape[0], num_of_rows)]
chunks.value
chanks_to_CWTchanks(chunks.value)

AttributeError: 'list' object has no attribute 'value'

In [None]:
class EEGDataset(torch.utils.data.Dataset):
    def __init__(self, data, target):
        self.data = data
        self.target = target

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.target[idx]

In [None]:
dataset = EEGDataset(df.iloc[:, :-1].values, df.iloc[:, -1].values)

In [None]:
dataset.__getitem__(0)

In [None]:
signal = df.iloc[:1000, 0]  # First column

signals = df.iloc[:1000, :-1]

In [None]:
# conwert ot np array
signals = signals.values
signal = signal.values

In [None]:
signal2 = signals[:,0]
print(signal2.shape)
print(signal2.dtype)
print(type(signal2))
print(signal.shape)
print(signal.dtype)
print(type(signal))
is_same = np.array_equal(signal2, signal)
print(is_same)

In [None]:
# normalization signal
signal = (signal - signal.mean()) / signal.std()

In [None]:
time = np.linspace(0, len(signal) / 160, len(signal))
widths = np.geomspace(1, 200, num=100)  # range of scales
sampling_period = np.diff(time).mean()  # 0.006251562890722681
print(signal.shape)
cwtmatr, freqs = pywt.cwt(signal, widths, "cgau4", sampling_period=sampling_period)
# cwtmatr= np.abs(cwtmatr[:-1,:-1])
cwtmatr = np.abs(cwtmatr)
# cwtmatr= torch.tensor(cwtmatr)

In [None]:
signal.shape

In [None]:
plt.figure(figsize=(20, 3))
print(cwtmatr.shape)
plt.pcolormesh(time, freqs, cwtmatr)
maxval = np.max(freqs)
plt.yscale("log")
plt.ylabel("Frequency [Hz]")
plt.xlabel("Time [s]")
plt.colorbar()
plt.show()

In [None]:
#if shpae in == 100x64 do not transpose

if signals.shape == (1000,64):
    signals=signals.transpose(1,0)
for signal in signals:
    signal = (signal - signal.mean()) / signal.std()
    time = np.linspace(0, len(signal) / 160, len(signal))
    widths = np.geomspace(1, 200, num=100)  # range of scales
    sampling_period = np.diff(time).mean()  # 0.006251562890722681
    cwtmatr, freqs = pywt.cwt(signal, widths, "cgau4", sampling_period=sampling_period)
    # cwtmatr= np.abs(cwtmatr[:-1,:-1])
    cwtmatr = np.abs(cwtmatr)
    # cwtmatr= torch.tensor(cwtmatr)
    plt.figure(figsize=(20, 3))
    plt.pcolormesh(time, freqs, cwtmatr)
    maxval = np.max(freqs)
    plt.yscale("log")
    plt.ylabel("Frequency [Hz]")
    plt.xlabel("Time [s]")
    plt.colorbar()
    plt.show()
    


In [None]:
signal2 = (signal2 - signal2.mean()) / signal2.std()
time = np.linspace(0, len(signal2) / 160, len(signal2))
widths = np.geomspace(1, 200, num=100)  # range of scales
sampling_period = np.diff(time).mean()  # 0.006251562890722681
print(signal.shape)
cwtmatr, freqs = pywt.cwt(signal2, widths, "cgau4", sampling_period=sampling_period)
# cwtmatr= np.abs(cwtmatr[:-1,:-1])
cwtmatr = np.abs(cwtmatr)
# cwtmatr= torch.tensor(cwtmatr)

In [None]:
plt.figure(figsize=(20, 3))
plt.pcolormesh(time, freqs, cwtmatr)
maxval = np.max(freqs)
plt.yscale("log")
plt.ylabel("Frequency [Hz]")
plt.xlabel("Time [s]")
plt.colorbar()
plt.show()