# EEG Explorer

## 1. Trabalhando com dados brutos

In [1]:
import ipywidgets as widgets
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px
import scipy
import os

from src.file import File
from src.data import Formatter

### 1.1. Importando arquivos do bucket

In [2]:
raw = File.get_files_from(resource="raw")

df = pd.read_csv(f"{File.get_path_by(resource='raw')}/{raw[0]}", delimiter="\t")
df.head()

Unnamed: 0.1,Unnamed: 0,Index,Fp1,Fp2,C3,C4,P7,P8,O1,O2,...,other.6,other.7,other.8,other.9,other.10,other.11,other.12,Timestamp,other.13,Timestamp (Formatted)
0,0,0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,...,28.0,32.0,1.0,128.0,0.0,0.0,0.0,1670348000.0,0.0,2022-12-06 14:32:35.467166
1,1,1.0,-27263.506327,-28970.989856,10233.690426,10234.493994,10222.87166,-65185.202457,-24651.519068,-28402.612648,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1670348000.0,0.0,2022-12-06 14:32:35.483128
2,2,2.0,-110571.24692,-114105.008761,58725.746235,58736.930882,58726.703818,-1307.959726,49300.675563,48016.423156,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1670348000.0,0.0,2022-12-06 14:32:35.487188
3,3,3.0,-154142.625668,-155786.204539,70149.767671,70163.745795,70138.803896,-51118.9193,62646.308276,61035.890705,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1670348000.0,0.0,2022-12-06 14:32:35.490187
4,4,4.0,-110415.579124,-111002.382939,46598.342596,46609.843158,46588.287136,-87282.959423,14976.516493,6464.93445,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1670348000.0,0.0,2022-12-06 14:32:35.493189


### 1.2. Renomeando arquivos brutos

In [3]:
File.rename_raw_files()

### 1.3. Formatando os sinais

In [4]:
renamed = File.get_files_from(resource="renamed")
dataframes = [pd.read_csv(f"{File.get_path_by(resource='renamed')}/{file}", delimiter="\t") for file in renamed]

#### 1.3.1. Normalizando timestamps

In [5]:
# df["Timestamp"] = df["Timestamp"] - df["Timestamp"].min()

dataframes = Formatter.normalize_timestamps_for(dataframes)

#### 1.3.2. Removendo colunas indesejadas

In [6]:
# df.drop de cada coluna marcada como "other", "Unnamed" ou "Timestamp (Formatted)"

dataframes = Formatter.remove_other_columns_for(dataframes)

##### DataFrame pós formatação

In [7]:
dataframes[0].head()

Unnamed: 0,Index,Fp1,Fp2,C3,C4,P7,P8,O1,O2,Timestamp
0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,1.0,-147837.835293,-128334.954779,-145905.169356,-146224.866357,-115510.104955,-137346.932274,79150.410491,-116625.501707,0.017
2,2.0,-141958.745355,-134553.120679,-149981.881676,-150155.487675,-118665.210148,-146172.943255,62684.394441,-121244.355588,0.019945
3,3.0,-86579.571555,-99937.108152,-100662.779887,-100315.165557,-112596.666824,-106055.987603,-25638.121443,-111502.660692,0.022997
4,4.0,-86704.718972,-89918.341031,-90199.056589,-89976.924953,-107100.439918,-92472.944256,-5784.452651,-105001.4323,0.027965


In [8]:
File.write_dataframes_in(path=File.get_path_by(resource="formatted"),
                         dataframes=dataframes,
                         filenames=renamed)

### 1.4. Visualização dos sinais

In [9]:
# Dependências
filenames = File.get_files_from(resource="formatted")
channels = ["Fp1", "Fp2", "C3", "C4", "P7", "P8", "O1", "O2"]

# Widgets
window_widget = widgets.IntSlider(min=1,
                                  max=1000,
                                  step=10,
                                  value=50,
                                  description="Janela")
file_widget = widgets.Combobox(options=filenames, 
                               value=filenames[0],
                               description="Arquivo")
channel_widget = widgets.Dropdown(options=channels,
                                  value=channels[0],
                                  description="Canal")

# Função de plotagem
def plot_signal(file, channel, window):
    delta = window
    df = pd.read_csv(f"{File.get_path_by(resource='formatted')}/{file}", delimiter=",")
    timestamps = df["Timestamp"].to_numpy()
    channel_data = df[channel].to_numpy()
    
    moving_average = np.convolve(channel_data, np.ones(512)/512, mode="valid")
    first_derivative = 100*(moving_average[:-delta] - moving_average[delta:])/delta
    fig = px.line(x=timestamps[:len(first_derivative)],
                  y=[channel_data[:len(first_derivative)],
                     moving_average[:len(first_derivative)],
                     first_derivative])
    fig.update_yaxes(range=[moving_average[10]-5E3, moving_average[10]+5E3])
    fig.show()

# Display dos widgets
out = widgets.interactive_output(plot_signal, 
                                 {"file": file_widget, 
                                  "channel": channel_widget, 
                                  "window": window_widget})

widgets.VBox([widgets.HBox([file_widget, channel_widget, window_widget]), out])

VBox(children=(HBox(children=(Combobox(value='react_5.csv', description='Arquivo', options=('react_5.csv', 'ch…

## 2. Filtrando trechos indesejados dos sinais

In [10]:
import ipywidgets as widgets
import os

from src.file import File
from src.data import Truncate
from src.data import TruncateIntervals
from src.data import Plotter

### 2.1. Seleção de intervalos com os trechos indesejados para cada sinal

#### 2.1.1. Selecionando trechos de maneira interativa

In [11]:
# Dependências
output = widgets.Output()
trunc_intervals = TruncateIntervals(truncate_intervals_path=File.get_path_by(resource="truncation_intervals"))
plotter = Plotter(files_path=File.get_path_by(resource="formatted"),
                  output=output)
filenames = File.get_files_from(resource="formatted")
channels = ["Fp1", "Fp2", "C3", "C4", "P7", "P8", "O1", "O2"]

file_widget_changed = False
channel_widget_changed = False

# Widgets
file_w = widgets.Combobox(options=filenames,
                          value=filenames[0],
                          description="Arquivo",
                          layout=widgets.Layout(width="300px"))
channel_w = widgets.Dropdown(options=channels,
                             value = channels[0],
                             description="Canal",
                             layout=widgets.Layout(width = "300px"))
refresh_btn_w = widgets.Button(description="Refresh",
                               layout=widgets.Layout(width = "70px"))
refresh_btn_w.layout.margin = "0px 0px 0px 50px"

signal_controller_hboxw = widgets.HBox([file_w, channel_w, refresh_btn_w])

label_w = widgets.Label(value="Intervalo de corte")
bottom_limit_w = widgets.FloatText(layout=widgets.Layout(width="100px"))
top_limit_w = widgets.FloatText(layout = widgets.Layout(width="100px"))
add_interval_btn_w = widgets.Button(description="Add",
                                    layout=widgets.Layout(width="70px"))
pop_interval_btn_w = widgets.Button(description="Pop",
                                    layout=widgets.Layout(width="70px"))
save_interval_btn_w = widgets.Button(description="Save",
                                     layout=widgets.Layout(width="70px"))

trunc_intervals_hboxw = widgets.HBox([label_w,
                                      bottom_limit_w,
                                      top_limit_w,
                                      add_interval_btn_w,
                                      pop_interval_btn_w,
                                      save_interval_btn_w],
                                     layout=widgets.Layout(justify_content='center'))

# Definição de eventos para os widgets
def file_widget_change_handler(change):
    global file_widget_changed
    if change["type"] == "change" and change["name"] == "value":
        file_widget_changed = True

def channel_widget_change_handler(change):
    global channel_widget_changed
    if change["type"] == "change" and change["name"] == "value":
        channel_widget_changed = True

def refresh_button_clicked(btn):
    global file_widget_changed
    global channel_widget_changed
    if file_widget_changed:
        trunc_intervals.save_current_file_intervals()
        trunc_intervals.load_file_intervals(file_w.value)
        plotter.load_signal(file_w.value, channel_w.value)
        plotter.plot_signal(trunc_intervals.get_channel_intervals(channel_w.value))
        file_widget_changed = False
        channel_widget_changed = False
    elif channel_widget_changed:
        plotter.change_current_fig(channel_w.value)
        plotter.plot_signal(trunc_intervals.get_channel_intervals(channel_w.value))
        channel_widget_changed = False

def add_intervals_clicked(btn):
    start = bottom_limit_w.value
    end = top_limit_w.value
    trunc_intervals.add_interval_by_channel(channel_w.value, start, end)
    plotter.plot_signal(trunc_intervals.get_channel_intervals(channel_w.value))

def pop_intervals_clicked(btn):
    trunc_intervals.pop_interval_by_channel(channel_w.value)
    plotter.plot_signal(trunc_intervals.get_channel_intervals(channel_w.value))

def save_intervals_clicked(btn):
    trunc_intervals.save_current_file_intervals()

file_w.observe(file_widget_change_handler)
channel_w.observe(channel_widget_change_handler)
refresh_btn_w.on_click(refresh_button_clicked)
add_interval_btn_w.on_click(add_intervals_clicked)
pop_interval_btn_w.on_click(pop_intervals_clicked)
save_interval_btn_w.on_click(save_intervals_clicked)

# Display dos widgets
trunc_intervals.load_file_intervals(file_w.value)
plotter.load_signal(file_w.value, channel_w.value)
plotter.plot_signal(trunc_intervals.get_channel_intervals(channel_w.value))

display(signal_controller_hboxw, output, trunc_intervals_hboxw)

HBox(children=(Combobox(value='react_5.csv', description='Arquivo', layout=Layout(width='300px'), options=('re…

Output()

HBox(children=(Label(value='Intervalo de corte'), FloatText(value=0.0, layout=Layout(width='100px')), FloatTex…

### 2.2. Truncamento dos sinais originais a partir dos intervalos selecionados

In [12]:
intervals_origin_dir = \
    "pre_calculated_truncation_intervals" \
    if input("Deseja utilizar intervalos pré calculados? [y/n] ") == "y" else "truncation_intervals"

Deseja utilizar intervalos pré calculados? [y/n]  y


In [13]:
print(f"'{intervals_origin_dir}' selecionado!")

'pre_calculated_truncation_intervals' selecionado!


In [14]:
truncation_intervals = File.get_files_from(resource=intervals_origin_dir)
filenames = [file.replace('json', 'csv') for file in truncation_intervals]

trc = Truncate(files_path=File.get_path_by(resource="formatted"),
               trunc_intervals_path=File.get_path_by(resource=intervals_origin_dir))

truncated_dfs = []
for file in filenames:
    trc.setup_by_filename(file)
    df = trc.truncate()
    truncated_dfs.append(df)

File.write_dataframes_in(path=File.get_path_by(resource="truncated"),
                         dataframes=truncated_dfs,
                         filenames=filenames)

## 3. Preparando arquivos para a fase de treino da rede neural

In [15]:
from src.file import File
from src.data import Continuous

### 3.1. Separando arquivos originais em N arquivos com sinal contínuo

In [16]:
intervals_origin_dir = \
    "pre_calculated_truncation_intervals" \
    if input("Deseja utilizar intervalos pré calculados? [y/n] ") == "y" else "truncation_intervals"

Deseja utilizar intervalos pré calculados? [y/n]  y


In [17]:
print(f"'{intervals_origin_dir}' selecionado!")

'pre_calculated_truncation_intervals' selecionado!


In [18]:
cnts = Continuous(input_data_path=File.get_path_by(resource="truncated"),
                  output_data_path=File.get_path_by(resource="continuous"),
                                                   truncate_intervals_path=File.get_path_by(resource=intervals_origin_dir))
truncated = File.get_files_from(resource="truncated")

for file in truncated:
    cnts.process_file(file)

### 3.2. Adicionando arquivos não truncados junto aos fragmentos contínuos

In [19]:
File.add_not_fragmented_files_to_continuous()

### 3.3. Separando arquivos em treino e teste

In [20]:
File.generate_train_test_files()

## 4. Preparando o input da rede e multiplicando a quantidade de dados

In [8]:
from src.file import File
from src.data import Windowing
import datetime
import os

import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow import keras
from tensorflow.keras import layers

from sklearn.model_selection import train_test_split

### 4.1. Conferindo a disponibilidade de uma GPU

In [4]:
gpus = tf.config.list_physical_devices('GPU')

if gpus:
    print(f"GPUs disponíveis: {len(gpus)}")
    for gpu in gpus:
        print(f"  - {gpu}")
else:
    print("Nenhuma GPU encontrada. O TensorFlow está usando a CPU.")

GPUs disponíveis: 1
  - PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')


### 4.2. Janelamento dos dados

In [5]:
windowing = Windowing(input_data_path=File.get_path_by(resource="pre_training", subdirs="cafe/train"), window_channels=["C3"], window_size=255, normalize=True)
x_all, y_all = windowing.process()

windowing.set_input_data_path(input_data_path=File.get_path_by(resource="pre_training", subdirs="cafe/test"))

x_test, y_test = windowing.process()

### 4.2. Separando dados de treino e validação

In [6]:
x_train, x_validation, y_train, y_validation = train_test_split(x_all, y_all, test_size=0.3, random_state=1024)

### 4.3. Criando a rede neural

In [9]:
early_stopping = EarlyStopping(monitor='val_loss', patience=3)
logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1)

model = keras.Sequential()
model.add(layers.Conv1D(128, 50, strides=1, padding='valid', activation='relu', input_shape=(255,1)))
model.add(layers.MaxPooling1D(pool_size=2))
model.add(layers.Flatten())
model.add(layers.Dense(50, activation='relu'))
model.add(layers.Dense(1, activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(x_train, y_train, validation_data=(x_validation, y_validation), 
          epochs=10, batch_size=512, verbose=1, shuffle=True, 
          callbacks=[early_stopping, tensorboard_callback])

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
I0000 00:00:1736473585.211234   90410 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 5529 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 4060 Ti, pci bus id: 0000:2b:00.0, compute capability: 8.9


Epoch 1/10


I0000 00:00:1736473590.276561   91121 service.cc:148] XLA service 0x7f38b4008000 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1736473590.276863   91121 service.cc:156]   StreamExecutor device (0): NVIDIA GeForce RTX 4060 Ti, Compute Capability 8.9
2025-01-09 22:46:30.306076: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
I0000 00:00:1736473590.424119   91121 cuda_dnn.cc:529] Loaded cuDNN version 90300
2025-01-09 22:46:31.046878: W external/local_xla/xla/tsl/framework/bfc_allocator.cc:306] Allocator (GPU_0_bfc) ran out of memory trying to allocate 32.57GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.


[1m  6/237[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m2s[0m 11ms/step - accuracy: 0.5498 - loss: 0.7749

I0000 00:00:1736473591.444639   91121 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m236/237[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 8ms/step - accuracy: 0.5947 - loss: 0.6741 

2025-01-09 22:46:33.778264: W external/local_xla/xla/tsl/framework/bfc_allocator.cc:306] Allocator (GPU_0_bfc) ran out of memory trying to allocate 9.55GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.


[1m237/237[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 15ms/step - accuracy: 0.5949 - loss: 0.6740 - val_accuracy: 0.6550 - val_loss: 0.6384
Epoch 2/10
[1m237/237[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 10ms/step - accuracy: 0.6581 - loss: 0.6316 - val_accuracy: 0.6922 - val_loss: 0.5891
Epoch 3/10
[1m237/237[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 10ms/step - accuracy: 0.6945 - loss: 0.5898 - val_accuracy: 0.7054 - val_loss: 0.5808
Epoch 4/10
[1m237/237[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 10ms/step - accuracy: 0.7077 - loss: 0.5769 - val_accuracy: 0.7078 - val_loss: 0.5820
Epoch 5/10
[1m237/237[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 10ms/step - accuracy: 0.7213 - loss: 0.5598 - val_accuracy: 0.7287 - val_loss: 0.5512
Epoch 6/10
[1m237/237[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 9ms/step - accuracy: 0.7244 - loss: 0.5534 - val_accuracy: 0.7220 - val_loss: 0.5421
Epoch 7/10
[1m237/237[0m [32m━━

<keras.src.callbacks.history.History at 0x7f39bf311450>

In [10]:
loss, accuracy = model.evaluate(x_test, y_test)

print(f"Loss: {loss}")
print(f"Accuracy: {accuracy}")

[1m1760/1760[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 2ms/step - accuracy: 0.5732 - loss: 0.7159
Loss: 0.7161311507225037
Accuracy: 0.5713626146316528
