In [25]:
#334a4ba3ee7a3dc9ff8373e22d7cf2fd31e6198668a4ae16
#!pip install PyWavelets mne  pandas numpy matplotlib

In [26]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import mne
import pywt
from torch.utils.data import DataLoader, Dataset
import torch
import pickle
import os
from datetime import datetime
import sqlite3

In [27]:
def file_to_DataDrame(path):
    """
    Read one EDF recording and return its samples as a labelled DataFrame.

    The result has one row per sample, one column per channel (dots removed
    from the channel names) and a trailing integer ``target`` column:

            Fc5         Fc3         Fc1         ...  Oz         O2         Iz         target
        0   -0.000046   -0.000041   -0.000032   ...  0.000040   0.000108   0.000055   0
        1   -0.000054   -0.000048   -0.000034   ...  0.000064   0.000114   0.000074   0
        ...

    Labelling: every sample receives the code of the most recent annotation
    whose onset is at or before that sample ("T0" -> 0, "T1" -> 1, ...).
    Onsets are converted to sample indices using the sampling rate stored in
    the file, so the result does not depend on exact float equality between a
    hand-built time grid and the annotation onsets. Samples that precede the
    first annotation take the first annotation's code.

    Rows in which every channel is exactly 0 are dropped. Labels are computed
    on the original sample index *before* dropping, so removed rows cannot
    shift the labels of the rows that remain.

    Parameters
    ----------
    path : str or path-like
        Location of the EDF file, passed to ``mne.io.read_raw_edf``.

    Returns
    -------
    pandas.DataFrame
        Channel columns plus ``target``; the index holds the original sample
        numbers of the rows that were kept.

    Raises
    ------
    ValueError
        If the recording carries no annotations (nothing to label with).
    """
    reader = mne.io.read_raw_edf(path, preload=True)
    annotations = reader.annotations
    sfreq = float(reader.info["sfreq"])  # sampling rate in Hz, read from the file

    # get_data() is (channels, samples); transpose to one row per sample.
    df = pd.DataFrame(
        reader.get_data().T,
        columns=[channel.replace(".", "") for channel in reader.ch_names],
    )

    onsets = np.asarray(annotations.onset, dtype=float)
    if onsets.size == 0:
        raise ValueError(f"{path}: recording has no annotations to derive targets from")

    # searchsorted needs ascending onsets; a stable sort keeps ties in file order.
    order = np.argsort(onsets, kind="stable")
    onset_samples = np.round(onsets[order] * sfreq).astype(np.int64)
    codes = np.array(
        [
            int(str(code).replace("T", ""))  # convert T0 to 0, T1 to 1, etc
            for code in np.asarray(annotations.description)[order]
        ],
        dtype=np.int64,
    )

    # Index of the last annotation that starts at or before each sample.
    active = np.searchsorted(onset_samples, np.arange(len(df)), side="right") - 1
    active = np.clip(active, 0, None)  # samples before the first onset
    targets = codes[active]

    keep = ~(df == 0).all(axis=1).to_numpy()  # drop rows with all zeros
    df = df.loc[keep].copy()  # copy so the column assignment below is unambiguous
    df["target"] = targets[keep]
    return df


def save_to_pickle(data, filename):
    """Pickle ``data`` into the file ``filename`` (overwriting it) and print a confirmation."""
    with open(filename, "wb") as sink:
        pickle.Pickler(sink).dump(data)
    print(f"File {filename} saved successfully.")



def load_from_pickle(filename):
    """
    Unpickle and return the object stored in ``filename``, printing a confirmation.

    Unpickling can execute arbitrary code, so only use this on files you trust.
    """
    with open(filename, "rb") as source:
        restored = pickle.Unpickler(source).load()
    print(f"File {filename} loaded successfully.")
    return restored


In [28]:
def read_all_file_df(num_exp=[3, 4], num_people=2):
    """
    Load several EDF recordings and concatenate them into one DataFrame.

    Files are read from ``files/S{subject:03d}/S{subject:03d}R{run:02d}.edf``
    for every subject in ``range(1, num_people)`` and every run in ``num_exp``.

    NOTE(review): the upper bound is exclusive, so ``num_people=2`` loads only
    subject 1. This is kept as-is because existing results depend on it;
    confirm whether ``num_people`` was meant to be inclusive.

    Parameters
    ----------
    num_exp : iterable of int
        Run numbers to load for each subject. The default list is only
        iterated, never mutated.
    num_people : int
        Exclusive upper bound on the subject number (subjects start at 1).

    Returns
    -------
    pandas.DataFrame
        The per-file frames stacked row-wise in load order. Each file's own
        index is preserved (so index values repeat across files). An empty
        DataFrame is returned when no file is selected.
    """
    # Collect first and concatenate once: growing a frame inside the loop
    # re-copies all previously loaded rows on every iteration.
    frames = [
        file_to_DataDrame(f"files/S{subject:03d}/S{subject:03d}R{run:02d}.edf")
        for subject in range(1, num_people)
        for run in num_exp
    ]
    if not frames:
        return pd.DataFrame()
    return pd.concat(frames, axis=0)

In [29]:
# Load the default selection (runs 3 and 4). With the default num_people=2 the
# subject loop is range(1, 2), so only subject S001 is read.
df = read_all_file_df()

Extracting EDF parameters from /home/daniel/repos/Decoding_of_EEG/files/S001/S001R03.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...
Extracting EDF parameters from /home/daniel/repos/Decoding_of_EEG/files/S001/S001R04.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...


In [30]:
import sqlite3
import pickle


def create_database(db_path):
    """
    Ensure the SQLite database at ``db_path`` has a ``wavelet_transforms`` table.

    The table has an auto-incrementing ``id``, a ``cwt_data`` BLOB and an
    integer ``target``. Calling this on a database that already has the table
    is a no-op (``CREATE TABLE IF NOT EXISTS``).

    Parameters
    ----------
    db_path : str or path-like
        SQLite database file; created by ``sqlite3.connect`` if missing.
    """
    conn = sqlite3.connect(db_path)
    try:
        conn.execute(
            """
        CREATE TABLE IF NOT EXISTS wavelet_transforms (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            cwt_data BLOB,
            target INTEGER
        )
    """
        )
        conn.commit()
    finally:
        # Always release the file handle, even if the statement fails.
        conn.close()


def insert_cwt_data(db_path, cwt_data, targets):
    """
    Append one ``wavelet_transforms`` row per sample of a CWT window.

    ``cwt_data`` must be 3-D; its last axis is treated as the sample axis.
    The array is reordered with ``transpose(2, 0, 1)`` and the two remaining
    axes are flattened, so each sample becomes a single 1-D vector. Every
    vector is cast to float32, pickled and stored in the ``cwt_data`` BLOB
    together with ``targets[i]`` for that sample.

    Parameters
    ----------
    db_path : str or path-like
        SQLite database that already contains the ``wavelet_transforms`` table.
    cwt_data : numpy.ndarray
        3-D array whose last axis indexes samples.
    targets : sequence
        One label per sample, at least ``cwt_data.shape[2]`` long. NumPy
        scalars are converted to native Python numbers before binding, because
        ``sqlite3`` does not store NumPy integer scalars as SQL integers.

    Raises
    ------
    ValueError
        If fewer targets than samples are supplied; nothing is written then.
    """
    per_sample = cwt_data.transpose(2, 0, 1)
    # Flatten everything except the sample axis: one vector per database row.
    per_sample = per_sample.reshape(per_sample.shape[0], -1)

    labels = np.asarray(targets).tolist()  # NumPy scalars -> plain int/float
    if len(labels) < len(per_sample):
        raise ValueError(
            f"got {len(labels)} targets for {len(per_sample)} samples"
        )

    rows = (
        (pickle.dumps(np.array(vector, dtype=np.float32)), label)
        for vector, label in zip(per_sample, labels)
    )

    conn = sqlite3.connect(db_path)
    try:
        conn.executemany(
            "INSERT INTO wavelet_transforms (cwt_data, target) VALUES (?, ?)",
            rows,
        )
        conn.commit()
    finally:
        # Close even when an insert fails; uncommitted rows are discarded.
        conn.close()

    # Report only after the commit has actually succeeded.
    print(f"success save: {per_sample.shape}")

In [31]:
def df_to_CWTfiles(
    df, num_of_rows=1000, wave="cgau4", frq=160, resolution=100, db_path="cwt_data.db"
):
    """
    Split ``df`` into fixed-size windows, wavelet-transform every channel and
    store the result in the SQLite database at ``db_path``.

    The last column of ``df`` is taken as the per-sample target; every other
    column is treated as a signal channel, whatever the number of columns.
    For each complete window of ``num_of_rows`` consecutive rows:

    1. each channel is min-max scaled to [0, 1] within the window (a channel
       that is constant over the window becomes all zeros instead of NaN),
    2. ``pywt.cwt`` is applied with ``resolution`` geometrically spaced scales
       between 1 and 200 and the magnitude of the coefficients is kept,
    3. the per-channel results are stacked along a new first axis and handed
       to ``insert_cwt_data`` together with the window's targets.

    A trailing window shorter than ``num_of_rows`` is skipped.

    Parameters
    ----------
    df : pandas.DataFrame
        Signal columns followed by one target column.
    num_of_rows : int
        Window length in samples.
    wave : str
        Wavelet name passed to ``pywt.cwt``.
    frq : int or float
        Sampling frequency in Hz. It only sets the ``sampling_period`` given
        to ``pywt.cwt``; the frequencies returned by that call are discarded
        here.
    resolution : int
        Number of wavelet scales.
    db_path : str or path-like
        SQLite file to append to.

    Raises
    ------
    ValueError
        If a window has no signal column besides the target column.
    """
    # Create the database if it does not exist yet.
    create_database(db_path)

    # Identical for every window and channel, so compute them once.
    widths = np.geomspace(1, 200, num=resolution)
    sampling_period = 1.0 / frq

    for start in range(0, len(df) - num_of_rows + 1, num_of_rows):
        window = df.iloc[start : start + num_of_rows].values
        if window.shape[1] < 2:
            raise ValueError(
                "df needs at least one signal column in addition to the target column"
            )
        targets = window[:, -1]

        channel_cwts = []
        for signal in window[:, :-1].T:  # one iteration per channel
            lowest = np.min(signal)
            span = np.max(signal) - lowest
            if span == 0:
                # Flat channel: min-max scaling would divide by zero.
                signal = np.zeros(len(signal), dtype=float)
            else:
                signal = (signal - lowest) / span
            cwtmatr, _freqs = pywt.cwt(
                signal, widths, wave, sampling_period=sampling_period
            )
            channel_cwts.append(np.abs(cwtmatr))

        array_cwt = np.stack(channel_cwts, axis=0)
        insert_cwt_data(db_path, array_cwt, targets)  # write to the database

In [32]:
# Transform `df` in windows of 1000 samples using 10 wavelet scales and append
# the results to cwt_data.db. Rows are only ever inserted, so re-running this
# cell against an existing database file adds the same data again.
df_to_CWTfiles(
    df, num_of_rows=1000, wave="cgau4", frq=160, resolution=10, db_path="cwt_data.db"
)

success save: (1000, 640)
success save: (1000, 640)
success save: (1000, 640)
success save: (1000, 640)
success save: (1000, 640)
success save: (1000, 640)
success save: (1000, 640)
success save: (1000, 640)
success save: (1000, 640)
success save: (1000, 640)
success save: (1000, 640)
success save: (1000, 640)
success save: (1000, 640)
success save: (1000, 640)
success save: (1000, 640)
success save: (1000, 640)
success save: (1000, 640)
success save: (1000, 640)
success save: (1000, 640)
success save: (1000, 640)
success save: (1000, 640)
success save: (1000, 640)
success save: (1000, 640)
success save: (1000, 640)
success save: (1000, 640)
success save: (1000, 640)
success save: (1000, 640)
success save: (1000, 640)
success save: (1000, 640)
success save: (1000, 640)
success save: (1000, 640)
success save: (1000, 640)
success save: (1000, 640)
success save: (1000, 640)
success save: (1000, 640)
success save: (1000, 640)
success save: (1000, 640)
success save: (1000, 640)
success save