In [2]:
%pip install scipy

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


# 0. Load Data

In [3]:
import numpy as np
import pickle
import glob
import os

folder = '/home/andrewkc/Projects/research-project/deap-dataset/data_preprocessed_python'

files = sorted(glob.glob(os.path.join(folder, 's*.dat')))

# Use lists first
all_data = []
all_labels = []

# Load and extract 'data' and 'labels' from each file
for file in files:
    with open(file, 'rb') as f:
        subject = pickle.load(f)
        all_data.append(subject['data'])     # shape: (40, 40, 8064)
        all_labels.append(subject['labels']) # shape: (40, 4)

In [4]:
# Convert to NumPy arrays                       32 x 40 = 1280
all_data = np.concatenate(all_data, axis=0)     # shape: (1280, 40, 8064)
all_labels = np.concatenate(all_labels, axis=0) # shape: (1280, 4)

# 32 personas vieron 40 videos cada una, de cada video se extrajo señales en 8064 instantes por cada uno de los 40 canales 
# (señales fisiológicas). Nosotros vamos a usar solo los primeros 32 canales de las señales como lo hace el autor del paper 
# (señales capturadas por los electrodos, es decir, señales EEG)

all_data = all_data[:, :32, :]
print(all_data.shape)

(1280, 32, 8064)


# 1. Features Extractors

In [5]:
# Hyperparameters
WINDOW_SIZE  = 128
NUM_SEGMENTS = 8064 / WINDOW_SIZE
print(NUM_SEGMENTS)

63.0


## 1.1. SpatialTemporalFExtractor

In [6]:
import numpy as np
from scipy.signal import butter, lfilter
from math import log, pi, e

class SpatialTemporalFExtractor:
    def __init__(self, sampling_rate=128, window_size_sec=1):
        self.fs = sampling_rate
        self.window_size = window_size_sec * self.fs  # muestras
        self.bands = {
            'theta': (4, 8), # low high Hz
            'alpha': (8, 12),
            'beta': (12, 30),
            'gamma': (30, 50)
        }

        # Mapa espacial Geneva (matriz 9x9 como en el paper)
        # El índice de cada canal EEG en la matriz (basado en la convención Geneva del paper)
        self.channel_map = {
            0:  (0, 3),   # Fp1
            1:  (1, 3),   # AF3
            2:  (2, 2),   # F3
            3:  (2, 0),   # F7
            4:  (3, 1),   # FC5
            5:  (3, 3),   # FC1
            6:  (4, 2),   # C3
            7:  (4, 0),   # T7
            8:  (5, 1),   # CP5
            9:  (5, 3),   # CP1
            10: (6, 2),   # P3
            11: (6, 0),   # P7
            12: (7, 3),   # PO3
            13: (8, 3),   # O1
            14: (8, 4),   # Oz
            15: (6, 4),   # Pz
            16: (0, 5),   # Fp2
            17: (1, 5),   # AF4
            18: (2, 4),   # Fz
            19: (2, 6),   # F4
            20: (2, 8),   # F8
            21: (3, 7),   # FC6
            22: (3, 5),   # FC2
            23: (4, 4),   # Cz
            24: (4, 6),   # C4
            25: (4, 8),   # T8
            26: (5, 7),   # CP6
            27: (5, 5),   # CP2
            28: (6, 6),   # P4
            29: (6, 8),   # P8
            30: (7, 5),   # PO4
            31: (8, 5),   # O2
        }

        self.left_hemisphere = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
        self.right_hemisphere = [16, 17, 19, 20, 21, 22, 24, 25, 26, 27, 28, 29, 30, 31]

    def bandpass_filter(self, signal, low, high, order=4):
        nyq = 0.5 * self.fs
        b, a = butter(order, [low / nyq, high / nyq], btype='band')
        return lfilter(b, a, signal)

    def differential_entropy(self, signal):
        var = np.var(signal)
        return 0.5 * np.log(2 * pi * e * var + 1e-8)  # sumamos 1e-8 para evitar log(0)

    def extract_features(self, signal_3d):
        """
        signal_3d: shape (32, 8064)  → 32 canales, 8064 muestras (60 s)
        return: list of DE-matrices por ventana [num_windows x num_bandas x 9 x 9]
        """
        channels, total_samples = signal_3d.shape
        num_windows = total_samples // self.window_size

        features = []

        for w in range(num_windows):
            start = w * self.window_size
            end = start + self.window_size
            window_data = signal_3d[:, start:end]  # shape: (32, window_size)

            band_matrices = []

            for band_name, (low, high) in self.bands.items():
                band_matrix = np.zeros((9, 9))  # matriz espacial

                for ch in range(32):
                    filtered = self.bandpass_filter(window_data[ch], low, high)
                    de = self.differential_entropy(filtered)

                    if ch in self.channel_map:
                        x, y = self.channel_map[ch]
                        band_matrix[x, y] = de

                band_matrices.append(band_matrix)

            features.append(np.stack(band_matrices))  # shape: (4, 9, 9)

        return np.array(features)  # shape: (num_windows, 4, 9, 9)


A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.2.6 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/andrewkc/.local/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "/home/andrewkc/.local/lib/python3.10/site-packages/traitlets/config/application.py", line 1075, in launch_instance
    app.start()
  File "/home/andrewkc/.local/lib/python3.10/si

AttributeError: _ARRAY_API not found

ImportError: numpy.core.multiarray failed to import

## 1.2. ConnectiveFExtractor

In [5]:
import numpy as np
from scipy.signal import butter, lfilter, hilbert
from math import log, pi, e


class ConnectiveFExtractor:
    def __init__(self, sampling_rate=128, window_size_sec=1, mode="pcc"):
        self.fs = sampling_rate
        self.window_size = window_size_sec * self.fs
        self.mode = mode.lower()
        
        # Definición de las bandas cerebrales
        self.bands = {
            'theta': (4, 8),
            'alpha': (8, 12),
            'beta': (12, 30),
            'gamma': (30, 50)
        }

        # Pares de canales simétricos (izquierdo, derecho) para la conectividad
        self.paired_channels = [
            (i, j) for i in range(32) for j in range(i+1, 32) 
        ]

    def bandpass_filter(self, signal, low, high, order=4):
        """Filtra la señal en un rango de frecuencias utilizando un filtro pasa-banda."""
        nyq = 0.5 * self.fs
        b, a = butter(order, [low / nyq, high / nyq], btype='band')
        return lfilter(b, a, signal)

    def phase_locking_value(self, signal1, signal2):
        """Calcula el Phase Locking Value (PLV) entre dos señales EEG."""
        analytic_signal_1 = hilbert(signal1)
        analytic_signal_2 = hilbert(signal2)
        phase_diff = np.angle(analytic_signal_1) - np.angle(analytic_signal_2)
        return np.abs(np.mean(np.exp(1j * phase_diff)))

    def pearson_correlation(self, signal1, signal2):
        """Calcula la correlación de Pearson entre dos señales EEG."""
        return np.corrcoef(signal1, signal2)[0, 1]

    def phase_lag_index(self, signal1, signal2):
        """Calcula el Phase Lag Index (PLI) entre dos señales EEG."""
        analytic_signal_1 = hilbert(signal1)
        analytic_signal_2 = hilbert(signal2)
        phase_diff = np.angle(analytic_signal_1) - np.angle(analytic_signal_2)
        return np.abs(np.mean(np.sign(np.sin(phase_diff))))

    def extract_pcc_features(self, signal_3d):
        """
        Extrae las características de conectividad para PCC: (num_windows, 4, 32, 32)
        """
        channels, total_samples = signal_3d.shape
        num_windows = total_samples // self.window_size

        all_features = []

        for w in range(num_windows):
            start = w * self.window_size
            end = start + self.window_size
            window_data = signal_3d[:, start:end]  

            pcc_matrix = np.zeros((32, 32))

            for band_name, (low, high) in self.bands.items():
                filtered_signals = []

                for i in range(32):
                    filtered_signal = self.bandpass_filter(window_data[i], low, high)
                    filtered_signals.append(filtered_signal)

                for i, j in self.paired_channels:
                    left_signal = filtered_signals[i]
                    right_signal = filtered_signals[j]

                    pcc_matrix[i, j] = self.pearson_correlation(left_signal, right_signal)

                    pcc_matrix[j, i] = pcc_matrix[i, j]

                pcc_matrix = (pcc_matrix + 1) / 2

                all_features.append(pcc_matrix)
        all_features_array = np.array(all_features)  
        return all_features_array  

    def extract_plv_features(self, signal_3d):
        """
        Extrae las características de conectividad para PLV: (num_windows, 4, 32, 32)
        """
        channels, total_samples = signal_3d.shape
        num_windows = total_samples // self.window_size

        all_features = []

        for w in range(num_windows):
            start = w * self.window_size
            end = start + self.window_size
            window_data = signal_3d[:, start:end]  

            plv_matrix = np.zeros((32, 32))

            for band_name, (low, high) in self.bands.items():
                filtered_signals = []

                for i in range(32):
                    filtered_signal = self.bandpass_filter(window_data[i], low, high)
                    filtered_signals.append(filtered_signal)
                for i, j in self.paired_channels:
                    left_signal = filtered_signals[i]
                    right_signal = filtered_signals[j]

                    plv_matrix[i, j] = self.phase_locking_value(left_signal, right_signal)

                    plv_matrix[j, i] = plv_matrix[i, j]

                plv_matrix = (plv_matrix + 1) / 2

                all_features.append(plv_matrix)

        
        all_features_array = np.array(all_features)  
        return all_features_array 

    def extract_pli_features(self, signal_3d):
        """
        Extrae las características de conectividad para PLI: (num_windows, 4, 32, 32)
        """
        channels, total_samples = signal_3d.shape
        num_windows = total_samples // self.window_size

        all_features = []

        for w in range(num_windows):
            start = w * self.window_size
            end = start + self.window_size
            window_data = signal_3d[:, start:end]  

            pli_matrix = np.zeros((32, 32))

            for band_name, (low, high) in self.bands.items():
                filtered_signals = []

            
                for i in range(32):
                    filtered_signal = self.bandpass_filter(window_data[i], low, high)
                    filtered_signals.append(filtered_signal)

            
                for i, j in self.paired_channels:
                    left_signal = filtered_signals[i]
                    right_signal = filtered_signals[j]

                    pli_matrix[i, j] = self.phase_lag_index(left_signal, right_signal)

                    pli_matrix[j, i] = pli_matrix[i, j]

          
                pli_matrix = (pli_matrix + 1) / 2

                all_features.append(pli_matrix)

      
        all_features_array = np.array(all_features)  
        return all_features_array  

    def extract_features(self, signal_3d):
        if self.mode == "pcc":
            return self.extract_pcc_features(signal_3d)
        elif self.mode == "plv":
            return self.extract_plv_features(signal_3d)
        elif self.mode == "pli":
            return self.extract_pli_features(signal_3d)
        else:
            raise ValueError(f"Modo de conectividad no soportado: {self.mode}")

## 1.3. IntesityDifferenceFExtractor

In [6]:
import numpy as np
from scipy.signal import butter, lfilter
from math import log, pi, e

class IntensityDifferenceFExtractor:
    def __init__(self, sampling_rate=128, window_size_sec=1):
        self.fs = sampling_rate
        self.window_size = window_size_sec * sampling_rate

        # Bandas cerebrales
        self.bands = {
            'theta': (4, 8),
            'alpha': (8, 12),
            'beta': (12, 30),
            'gamma': (30, 50)
        }

        # Pares de canales simétricos (izquierdo, derecho)
        self.paired_channels = [
            (0, 16),   # Fp1 - Fp2
            (1, 17),   # AF3 - AF4
            (2, 19),   # F3 - F4
            (3, 20),   # F7 - F8
            (4, 21),   # FC5 - FC6
            (5, 22),   # FC1 - FC2
            (6, 24),   # C3 - C4
            (7, 25),   # T7 - T8
            (8, 26),   # CP5 - CP6
            (9, 27),   # CP1 - CP2
            (10, 28),  # P3 - P4
            (11, 29),  # P7 - P8
            (12, 30),  # PO3 - PO4
            (13, 31),  # O1 - O2
        ]

    def bandpass_filter(self, signal, low, high, order=4):
        nyq = 0.5 * self.fs
        b, a = butter(order, [low / nyq, high / nyq], btype='band')
        return lfilter(b, a, signal)

    def extract_features(self, signal_3d):
        """
        signal_3d: (32, 8064)  → EEG de una muestra
        Return: (num_windows, 4, 14, 14) → correlación de diferencias por banda
        """
        channels, total_samples = signal_3d.shape
        num_windows = total_samples // self.window_size

        all_features = []

        for w in range(num_windows):
            start = w * self.window_size
            end = start + self.window_size
            window_data = signal_3d[:, start:end]

            band_corrs = []

            for band_name, (low, high) in self.bands.items():
                diffs = []

                for left, right in self.paired_channels:
                    left_filtered = self.bandpass_filter(window_data[left], low, high)
                    right_filtered = self.bandpass_filter(window_data[right], low, high)

                    diff_signal = left_filtered - right_filtered
                    diffs.append(diff_signal)

                diffs_matrix = np.stack(diffs) 
                corr_matrix = np.corrcoef(diffs_matrix)  
                corr_matrix = (corr_matrix + 1) / 2    

                band_corrs.append(corr_matrix)

            # apilar 4 bandas (4, 14, 14)
            all_features.append(np.stack(band_corrs))

        return np.array(all_features)  # (num_windows, 4, 14, 14)


# N. Main

In [7]:
extractor = SpatialTemporalFExtractor()

# Ejemplo con una muestra (solo EEG)
sample_idx = 0
eeg_data = all_data[sample_idx, :, :]  # (32, 8064)

print(eeg_data.shape)

features = extractor.extract_features(eeg_data)

print("Shape:", features.shape)  # (num_windows, 4, 9, 9)

# show DE in first window

# banda theta
print(features[0][0])

# banda alpha
print(features[0][1])

# banda betha
print(features[0][2])

# banda gamma
print(features[0][3])

(32, 8064)


Shape: (63, 4, 9, 9)
[[0.         0.         0.         1.51340617 0.         1.74132303
  0.         0.         0.        ]
 [0.         0.         0.         1.78640185 0.         1.48885533
  0.         0.         0.        ]
 [1.9574312  0.         2.007377   0.         1.75232527 0.
  2.11976898 0.         1.39411543]
 [0.         1.52375448 0.         1.24102146 0.         1.71527093
  0.         1.03676285 0.        ]
 [1.75245835 0.         1.86812605 0.         1.66335345 0.
  1.39536882 0.         1.69289967]
 [0.         1.61410412 0.         1.2750583  0.         1.48725367
  0.         1.77284199 0.        ]
 [1.80639345 0.         1.79867897 0.         1.33371762 0.
  2.26326665 0.         2.31115073]
 [0.         0.         0.         1.32517214 0.         2.41096827
  0.         0.         0.        ]
 [0.         0.         0.         1.72037233 1.84274597 2.24762634
  0.         0.         0.        ]]
[[0.         0.         0.         2.15085133 0.         2.0962779

In [8]:
extractor = IntensityDifferenceFExtractor()

# Ejemplo con una muestra (solo EEG)
sample_idx = 0
eeg_data = all_data[sample_idx, :, :]  # (32, 8064)

print(eeg_data.shape)

features = extractor.extract_features(eeg_data)

print("Shape:", features.shape)  # (num_windows, 4, 14, 14)

# show DE in first window

# banda theta
print(features[0][0].shape)

# banda alpha
print(features[0][1].shape)

# banda betha
print(features[0][2].shape)

# banda gamma
print(features[0][3].shape)
print(features[0][3])

(32, 8064)


Shape: (63, 4, 14, 14)
(14, 14)
(14, 14)
(14, 14)
(14, 14)
[[1.         0.53974052 0.69650649 0.69100437 0.72656924 0.80917866
  0.43602078 0.37708162 0.43415412 0.44100911 0.45575801 0.5787175
  0.32836195 0.41294431]
 [0.53974052 1.         0.71782178 0.79487765 0.65903764 0.34934772
  0.87437901 0.82664365 0.69314371 0.5658603  0.65748663 0.50007457
  0.78340562 0.73901959]
 [0.69650649 0.71782178 1.         0.81797664 0.86837801 0.70639784
  0.75794229 0.71484652 0.68901612 0.58314572 0.64605916 0.64074801
  0.57255465 0.59934008]
 [0.69100437 0.79487765 0.81797664 1.         0.81747881 0.5187412
  0.77612854 0.68154705 0.73878888 0.62081371 0.68199789 0.66646171
  0.62004959 0.65589218]
 [0.72656924 0.65903764 0.86837801 0.81747881 1.         0.72193048
  0.72085835 0.66018264 0.72243089 0.64788169 0.73752117 0.68720193
  0.53866561 0.60681688]
 [0.80917866 0.34934772 0.70639784 0.5187412  0.72193048 1.
  0.3988134  0.35180157 0.44329125 0.52803047 0.48381466 0.58923133
  0.323610

In [9]:

extractor = ConnectiveFExtractor(sampling_rate=128, window_size_sec=1)

sample_idx = 0
eeg_data = all_data[sample_idx, :, :]  # (32, 8064)

print("Shape de los datos EEG:", eeg_data.shape)

pcc_features = extractor.extract_pcc_features(eeg_data)
plv_features = extractor.extract_plv_features(eeg_data)
pli_features = extractor.extract_pli_features(eeg_data)


print("Shape de PCC:", pcc_features.shape)  
print("Shape de PLV:", plv_features.shape)  
print("Shape de PLI:", pli_features.shape)  

# Banda theta 
print("PCC para banda theta - ventana 0:")
print(pcc_features[0][0])  

print("PLV para banda theta - ventana 0:")
print(plv_features[0][0])  

print("PLI para banda theta - ventana 0:")
print(pli_features[0][0]) 

# Banda alpha
print("PCC para banda alpha - ventana 0:")
print(pcc_features[0][1])  

print("PLV para banda alpha - ventana 0:")
print(plv_features[0][1])  

print("PLI para banda alpha - ventana 0:")
print(pli_features[0][1]) 

# Banda beta
print("PCC para banda beta - ventana 0:")
print(pcc_features[0][2])  

print("PLV para banda beta - ventana 0:")
print(plv_features[0][2]) 

print("PLI para banda beta - ventana 0:")
print(pli_features[0][2])  

# Banda gamma
print("PCC para banda gamma - ventana 0:")
print(pcc_features[0][3])  

print("PLV para banda gamma - ventana 0:")
print(plv_features[0][3]) 
print("PLI para banda gamma - ventana 0:")
print(pli_features[0][3])  


Shape de los datos EEG: (32, 8064)


Shape de PCC: (252, 32, 32)
Shape de PLV: (252, 32, 32)
Shape de PLI: (252, 32, 32)
PCC para banda theta - ventana 0:
[ 0.5         0.94663652  0.7868775   0.87057236  0.72111498  0.84497555
 -0.16707376 -0.5798889  -0.70692235 -0.63505269 -0.68653438 -0.66155896
 -0.86482348 -0.79775783 -0.70768684 -0.82051316  0.88104914  0.89525078
  0.80568917  0.61153187  0.67428556  0.56988231  0.40353229  0.30537097
  0.19367364  0.44654505 -0.22254823 -0.00548786 -0.50654389 -0.21451576
 -0.75301968 -0.71276776]
PLV para banda theta - ventana 0:
[0.5        0.82302664 0.64459835 0.93497477 0.57013153 0.60847176
 0.35823363 0.47590236 0.63013992 0.6693908  0.5487395  0.52880367
 0.74092256 0.71941595 0.63400171 0.71174746 0.83342087 0.86698178
 0.56506875 0.32768939 0.5638101  0.39590007 0.05146652 0.25989769
 0.57624765 0.45299948 0.58538415 0.41340587 0.1765038  0.15504596
 0.30872853 0.43755348]
PLI para banda theta - ventana 0:
[0.5      0.421875 0.46875  0.0625   0.4375   0.140625 0.125    

## Modelo


In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

class STCCNN(nn.Module):
    def __init__(self, num_classes=2, dropout_rate=0.5):
        super(STCCNN, self).__init__()
        
        self.spatial_conv1 = nn.Conv2d(4, 32, kernel_size=5, stride=1, padding=2)
        self.spatial_conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.spatial_conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)

        self.connectivity_conv1 = nn.Conv2d(4, 32, kernel_size=7, stride=1, padding=3)
        self.connectivity_conv2 = nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2)
        self.connectivity_conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)

        # Intensity difference input shape: (batch_size, 4, 14, 14)
        self.intensity_conv1 = nn.Conv2d(4, 32, kernel_size=5, stride=1, padding=2)
        self.intensity_conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.intensity_conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)

        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.spatial_flatten_size = 9 * 9 * 128
        self.connectivity_flatten_size = 8 * 8 * 128
        self.intensity_flatten_size = 14 * 14 * 128

        total_features = self.spatial_flatten_size + self.connectivity_flatten_size + self.intensity_flatten_size
        self.fc1 = nn.Linear(total_features, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, num_classes)

        self.dropout = nn.Dropout(dropout_rate)

        self.bn_spatial1 = nn.BatchNorm2d(32)
        self.bn_spatial2 = nn.BatchNorm2d(64)
        self.bn_spatial3 = nn.BatchNorm2d(128)

        self.bn_connectivity1 = nn.BatchNorm2d(32)
        self.bn_connectivity2 = nn.BatchNorm2d(64)
        self.bn_connectivity3 = nn.BatchNorm2d(128)

        self.bn_intensity1 = nn.BatchNorm2d(32)
        self.bn_intensity2 = nn.BatchNorm2d(64)
        self.bn_intensity3 = nn.BatchNorm2d(128)

    def spatial_branch(self, x):
        x = self.spatial_conv1(x)
        x = self.bn_spatial1(x)
        x = F.selu(x)

        x = self.spatial_conv2(x)
        x = self.bn_spatial2(x)
        x = F.selu(x)

        x = self.spatial_conv3(x)
        x = self.bn_spatial3(x)
        x = F.selu(x)

        return x.view(x.size(0), -1)

    def connectivity_branch(self, x):
        x = self.connectivity_conv1(x)
        x = self.bn_connectivity1(x)
        x = F.selu(x)

        x = self.connectivity_conv2(x)
        x = self.bn_connectivity2(x)
        x = F.selu(x)
        x = self.maxpool(x)

        x = self.connectivity_conv3(x)
        x = self.bn_connectivity3(x)
        x = F.selu(x)
        x = self.maxpool(x)

        return x.view(x.size(0), -1)

    def intensity_branch(self, x):
        x = self.intensity_conv1(x)
        x = self.bn_intensity1(x)
        x = F.selu(x)

        x = self.intensity_conv2(x)
        x = self.bn_intensity2(x)
        x = F.selu(x)

        x = self.intensity_conv3(x)
        x = self.bn_intensity3(x)
        x = F.selu(x)

        return x.view(x.size(0), -1)

    def forward(self, spatial_input, connectivity_input, intensity_input):
        spatial_features = F.normalize(self.spatial_branch(spatial_input), p=2, dim=1)
        connectivity_features = F.normalize(self.connectivity_branch(connectivity_input), p=2, dim=1)
        intensity_features = F.normalize(self.intensity_branch(intensity_input), p=2, dim=1)

        fused = torch.cat([spatial_features, connectivity_features, intensity_features], dim=1)
        x = F.relu(self.fc1(fused))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        return self.fc3(x)

    def predict_proba(self, spatial_input, connectivity_input, intensity_input):
        return F.softmax(self.forward(spatial_input, connectivity_input, intensity_input), dim=1)

    def fit(self, dataloader, optimizer, loss_fn, epochs=10, device="cuda"):
        self.to(device)
        self.train()

        for epoch in range(epochs):
            total_loss = 0.0
            for spatial, connectivity, intensity, labels in dataloader:
                spatial = spatial.to(device)
                connectivity = connectivity.to(device)
                intensity = intensity.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()
                output = self.forward(spatial, connectivity, intensity)
                loss = loss_fn(output, labels)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            avg_loss = total_loss / len(dataloader)
            print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")


In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class STCCNN(nn.Module):
    def __init__(self, num_classes=2, dropout_rate=0.5):
        super(STCCNN, self).__init__()
        
        # Rama de características espaciales-temporales (3D-DE features)
        # Input: 9x9x4
        self.spatial_conv1 = nn.Conv2d(4, 32, kernel_size=5, stride=1, padding=2)
        self.spatial_conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.spatial_conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        
        # Rama de características de conectividad
        # Input: 32x32x4
        self.connectivity_conv1 = nn.Conv2d(4, 32, kernel_size=7, stride=1, padding=3)
        self.connectivity_conv2 = nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2)
        self.connectivity_conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        
        # Rama de características de diferencia de intensidad
        # Input: 9x9x4 (similar a spatial pero con diferentes características)
        self.intensity_conv1 = nn.Conv2d(4, 32, kernel_size=5, stride=1, padding=2)
        self.intensity_conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.intensity_conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        
        # Max pooling layers
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        
        # Calcular tamaños después del flatten
        self.spatial_flatten_size = 9 * 9 * 128      # 10368
        self.connectivity_flatten_size = 8 * 8 * 128 # 8192 (después de 2 maxpools)
        self.intensity_flatten_size = 9 * 9 * 128    # 10368
        
        # Capas totalmente conectadas después de la fusión
        total_features = self.spatial_flatten_size + self.connectivity_flatten_size + self.intensity_flatten_size
        self.fc1 = nn.Linear(total_features, 256)  # Aumentamos el tamaño para manejar más características
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, num_classes)
        
        # Dropout layers
        self.dropout = nn.Dropout(dropout_rate)
        
        # Normalización por lotes para cada rama
        self.bn_spatial1 = nn.BatchNorm2d(32)
        self.bn_spatial2 = nn.BatchNorm2d(64)
        self.bn_spatial3 = nn.BatchNorm2d(128)
        
        self.bn_connectivity1 = nn.BatchNorm2d(32)
        self.bn_connectivity2 = nn.BatchNorm2d(64)
        self.bn_connectivity3 = nn.BatchNorm2d(128)
        
        self.bn_intensity1 = nn.BatchNorm2d(32)
        self.bn_intensity2 = nn.BatchNorm2d(64)
        self.bn_intensity3 = nn.BatchNorm2d(128)
    
    def spatial_branch(self, x):
        """
        Rama para características espaciales-temporales (3D-DE features)
        Input shape: (batch_size, 4, 9, 9)
        """
        # Primera capa convolucional
        x = self.spatial_conv1(x)
        x = self.bn_spatial1(x)
        x = F.selu(x)
        
        # Segunda capa convolucional
        x = self.spatial_conv2(x)
        x = self.bn_spatial2(x)
        x = F.selu(x)
        
        # Tercera capa convolucional
        x = self.spatial_conv3(x)
        x = self.bn_spatial3(x)
        x = F.selu(x)
        
        # Flatten
        x = x.view(x.size(0), -1)
        
        return x
    
    def connectivity_branch(self, x):
        """
        Rama para características de conectividad
        Input shape: (batch_size, 4, 32, 32)
        """
        # Primera capa convolucional (7x7)
        x = self.connectivity_conv1(x)
        x = self.bn_connectivity1(x)
        x = F.selu(x)
        
        # Segunda capa convolucional (5x5) + MaxPooling
        x = self.connectivity_conv2(x)
        x = self.bn_connectivity2(x)
        x = F.selu(x)
        x = self.maxpool(x)  # 32x32 -> 16x16
        
        # Tercera capa convolucional (3x3) + MaxPooling
        x = self.connectivity_conv3(x)
        x = self.bn_connectivity3(x)
        x = F.selu(x)
        x = self.maxpool(x)  # 16x16 -> 8x8
        
        # Flatten
        x = x.view(x.size(0), -1)
        
        return x
    
    def intensity_branch(self, x):
        """
        Rama para características de diferencia de intensidad
        Input shape: (batch_size, 4, 9, 9)
        """
        # Primera capa convolucional
        x = self.intensity_conv1(x)
        x = self.bn_intensity1(x)
        x = F.selu(x)
        
        # Segunda capa convolucional
        x = self.intensity_conv2(x)
        x = self.bn_intensity2(x)
        x = F.selu(x)
        
        # Tercera capa convolucional
        x = self.intensity_conv3(x)
        x = self.bn_intensity3(x)
        x = F.selu(x)
        
        # Flatten
        x = x.view(x.size(0), -1)
        
        return x
    
    def forward(self, spatial_input, connectivity_input, intensity_input):
        """
        Forward pass del modelo STC-CNN mejorado
        
        Args:
            spatial_input: Tensor de características espaciales-temporales (batch_size, 4, 9, 9)
            connectivity_input: Tensor de características de conectividad (batch_size, 4, 32, 32)
            intensity_input: Tensor de características de diferencia de intensidad (batch_size, 4, 9, 9)
        
        Returns:
            output: Logits de clasificación (batch_size, num_classes)
        """
        # Procesar cada rama por separado
        spatial_features = self.spatial_branch(spatial_input)
        connectivity_features = self.connectivity_branch(connectivity_input)
        intensity_features = self.intensity_branch(intensity_input)
        
        # Normalización de características antes de la fusión
        spatial_features = F.normalize(spatial_features, p=2, dim=1)
        connectivity_features = F.normalize(connectivity_features, p=2, dim=1)
        intensity_features = F.normalize(intensity_features, p=2, dim=1)
        
        # Fusión: concatenar las características de las tres ramas
        fused_features = torch.cat([spatial_features, connectivity_features, intensity_features], dim=1)
        
        # Primera capa completamente conectada
        x = self.fc1(fused_features)
        x = F.relu(x)
        x = self.dropout(x)
        
        # Segunda capa completamente conectada
        x = self.fc2(x)
        x = F.relu(x)
        x = self.dropout(x)
        
        # Tercera capa completamente conectada (salida)
        x = self.fc3(x)
        
        return x
    
    def predict_proba(self, spatial_input, connectivity_input, intensity_input):
        """
        Obtener probabilidades de predicción usando softmax
        """
        logits = self.forward(spatial_input, connectivity_input, intensity_input)
        return F.softmax(logits, dim=1)

In [13]:
pip install tqdm

Note: you may need to restart the kernel to use updated packages.


In [16]:
import os
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
import time

class EEGFeatureDataset(Dataset):
    def __init__(self, signals, labels, spatial_extractor, intensity_extractor, connectivity_extractor, cache_path='eeg_features.pt'):
        self.cache_path = cache_path

        if os.path.exists(cache_path):
            print(f"📦 Cargando características preprocesadas desde: {cache_path}")
            cached = torch.load(cache_path)
            self.spatial_features = cached['spatial']
            self.intensity_features = cached['intensity']
            self.connectivity_features = cached['connectivity']
            self.targets = cached['labels']
            return

        # Si no existe el archivo cache, se hace el procesamiento
        print("⚙️ Iniciando extracción de características...")
        start_time = time.time()

        self.spatial_features = []
        self.intensity_features = []
        self.connectivity_features = []
        self.targets = []

        for i in tqdm(range(len(signals)), desc="Extrayendo EEG..."):
            x = signals[i]

            spatial = spatial_extractor.extract_features(x)
            intensity = intensity_extractor.extract_features(x)
            connectivity = connectivity_extractor.extract_features(x)

            label = int(labels[i][0] > 5)  # binarización de valence

            for j in range(spatial.shape[0]):
                self.spatial_features.append(torch.tensor(spatial[j], dtype=torch.float32))
                self.intensity_features.append(torch.tensor(intensity[j], dtype=torch.float32))
                self.connectivity_features.append(torch.tensor(connectivity[j], dtype=torch.float32))
                self.targets.append(torch.tensor(label, dtype=torch.long))

        # Guardar en disco
        print(f"💾 Guardando características extraídas en {cache_path}")
        torch.save({
            'spatial': self.spatial_features,
            'intensity': self.intensity_features,
            'connectivity': self.connectivity_features,
            'labels': self.targets
        }, cache_path)

        elapsed = time.time() - start_time
        print(f"✅ Extracción completada en {elapsed:.2f} segundos.")

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        return (
            self.spatial_features[idx],
            self.connectivity_features[idx],
            self.intensity_features[idx],
            self.targets[idx]
        )


In [15]:
# Inicializar extractores
spatial_extractor = SpatialTemporalFExtractor()
intensity_extractor = IntensityDifferenceFExtractor()
connectivity_extractor = ConnectiveFExtractor()  # Suponemos que ya está implementado
connectivity_extractor.mode = "pcc"
# Crear dataset
dataset = EEGFeatureDataset(all_data, all_labels, spatial_extractor, intensity_extractor, connectivity_extractor)

# DataLoader
from torch.utils.data import DataLoader
train_loader = DataLoader(dataset, batch_size=64, shuffle=True)

⚙️ Iniciando extracción de características...


Extrayendo EEG...:   2%|▏         | 27/1280 [03:13<2:29:10,  7.14s/it]

In [None]:
# Instanciar el modelo
model = STCCNN(num_classes=2)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Definir loss y optimizador
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Entrenar
model.fit(train_loader, optimizer, criterion, device, epochs=10)
