In [1]:
import os
import glob
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn import Linear, Conv1d, BatchNorm1d, Dropout
from torch_geometric.nn import ChebConv, global_mean_pool  # <--- UPGRADE: ChebConv
from torch_geometric.data import Data, DataLoader

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
TRAIN_FOLDER = "SMNI_CMI_TRAIN"
TEST_FOLDER = "SMNI_CMI_TEST"
BATCH_SIZE = 16
EPOCHS = 30
LEARNING_RATE = 0.0005  
CORRELATION_THRESHOLD = 0.6  

In [3]:
STANDARD_CHANNELS = [
    'FP1', 'FP2', 'F7', 'F8', 'AF1', 'AF2', 'FZ', 'F4', 'F3', 'FC6', 'FC5', 'FC2', 'FC1', 
    'T8', 'T7', 'CZ', 'C3', 'C4', 'CP5', 'CP6', 'CP1', 'CP2', 'P3', 'P4', 'PZ', 'P8', 'P7', 
    'PO2', 'PO1', 'O2', 'O1', 'X', 'AF7', 'AF8', 'F5', 'F6', 'FT7', 'FT8', 'FPZ', 'FC4', 'FC3', 
    'C6', 'C5', 'F2', 'F1', 'TP8', 'TP7', 'AFZ', 'CP3', 'CP4', 'P5', 'P6', 'C1', 'C2', 'PO7', 
    'PO8', 'FCZ', 'POZ', 'OZ', 'P2', 'P1', 'CPZ', 'nd', 'Y'
]

In [None]:
def compute_correlation_matrix(dataset_list):
    """
    Computes the average correlation matrix across a subset of training data
    to determine which electrodes usually talk to each other.
    """
    print("Computing Functional Connectivity (Correlation Graph)...")
    num_nodes = 64
    sum_corr = np.zeros((num_nodes, num_nodes))
    count = 0
    
    sample_size = min(len(dataset_list), 50)
    
    for i in range(sample_size):

        data_x = dataset_list[i].x.numpy()
        
        corr = np.abs(np.corrcoef(data_x))
        
        corr = np.nan_to_num(corr)
        sum_corr += corr
        count += 1
        
    avg_corr = sum_corr / count
    return avg_corr

In [None]:
def get_functional_edges(dataset_list, threshold=0.5):
    """ Creates edge_index based on correlation > threshold """
    avg_corr = compute_correlation_matrix(dataset_list)
    
    rows, cols = np.where(avg_corr > threshold)
    
    mask = rows != cols
    rows = rows[mask]
    cols = cols[mask]
    
    edge_index = torch.tensor([rows, cols], dtype=torch.long)
    print(f"Graph Created! Connectivity Density: {len(rows)/(64*64):.2%}")
    return edge_index

In [None]:
def load_data(folder_path):
    dataset = []
    if not os.path.exists(folder_path): return []
    csv_files = glob.glob(os.path.join(folder_path, "*.csv"))
    
    for file_path in csv_files:
        try:
            df = pd.read_csv(file_path)
            df.columns = df.columns.str.strip()
            if 'sensor position' not in df.columns and 'sensor pos' in df.columns:
                df.rename(columns={'sensor pos': 'sensor position'}, inplace=True)

            grouped = df.groupby('trial number')
            for trial_num, trial_data in grouped:
                pivot_df = trial_data.pivot_table(index='sensor position', columns='sample num', values='sensor value')
                pivot_df = pivot_df.reindex(STANDARD_CHANNELS).fillna(0)
                
                if pivot_df.shape != (64, 256): continue
                
                subject_id = trial_data['subject identifier'].iloc[0]
                y_label = 1 if subject_id == 'a' else 0
                
                x = torch.tensor(pivot_df.values, dtype=torch.float)
                y = torch.tensor([y_label], dtype=torch.long)
                
                data = Data(x=x, y=y) 
                dataset.append(data)
        except: pass
    return dataset

In [None]:
class EEG_ChebNet(torch.nn.Module):
    def __init__(self, num_nodes=64, num_classes=2):
        super(EEG_ChebNet, self).__init__()
        
        self.conv1 = Conv1d(1, 16, kernel_size=10, stride=2)
        self.bn1 = BatchNorm1d(16)
        self.conv2 = Conv1d(16, 32, kernel_size=5, stride=2)
        self.bn2 = BatchNorm1d(32)
        
        self.flatten_size = 32 * 60 

        self.cheb1 = ChebConv(self.flatten_size, 128, K=3)
        self.cheb2 = ChebConv(128, 64, K=3)

        self.fc = Linear(64, num_classes)
        self.dropout = Dropout(p=0.5) 

    def forward(self, x, edge_index, batch):
        x = x.unsqueeze(1) 
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.dropout(x) 
        x = F.relu(self.bn2(self.conv2(x)))
        x = x.view(x.size(0), -1)
        
        x = F.relu(self.cheb1(x, edge_index))
        x = self.dropout(x)
        x = self.cheb2(x, edge_index)
        
        x = global_mean_pool(x, batch) 
        return self.fc(x)

In [8]:
if __name__ == "__main__":
    
    # Load raw data
    train_data_raw = load_data(TRAIN_FOLDER)
    test_data_raw = load_data(TEST_FOLDER)
    
    if not train_data_raw: 
        print("No data found."); exit()

    print("Generating Intelligent Graph Edges...")
    smart_edge_index = get_functional_edges(train_data_raw, threshold=CORRELATION_THRESHOLD)
    
    for d in train_data_raw: d.edge_index = smart_edge_index
    for d in test_data_raw: d.edge_index = smart_edge_index

    train_loader = DataLoader(train_data_raw, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_data_raw, batch_size=BATCH_SIZE, shuffle=False)
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = EEG_ChebNet().to(device)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4) 
    criterion = torch.nn.CrossEntropyLoss()
    
    print("\n--- Starting Training (ChebNet) ---")
    for epoch in range(EPOCHS):
        model.train()
        total_loss = 0
        correct = 0; total = 0
        
        for batch in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad()
            out = model(batch.x, batch.edge_index, batch.batch)
            loss = criterion(out, batch.y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            correct += (out.argmax(1) == batch.y).sum().item()
            total += batch.y.size(0)
            
        train_acc = correct/total
        
        # Validation
        model.eval()
        v_correct = 0; v_total = 0
        with torch.no_grad():
            for batch in test_loader:
                batch = batch.to(device)
                out = model(batch.x, batch.edge_index, batch.batch)
                v_correct += (out.argmax(1) == batch.y).sum().item()
                v_total += batch.y.size(0)
        
        val_acc = v_correct/v_total
        print(f"Epoch {epoch+1:02d} | Loss: {total_loss/len(train_loader):.4f} | Train: {train_acc:.4f} | Test: {val_acc:.4f}")

Generating Intelligent Graph Edges...
Computing Functional Connectivity (Correlation Graph)...
Graph Created! Connectivity Density: 30.71%


  edge_index = torch.tensor([rows, cols], dtype=torch.long)
  train_loader = DataLoader(train_data_raw, batch_size=BATCH_SIZE, shuffle=True)
  test_loader = DataLoader(test_data_raw, batch_size=BATCH_SIZE, shuffle=False)



--- Starting Training (ChebNet) ---
Epoch 01 | Loss: 0.7546 | Train: 0.5791 | Test: 0.6729
Epoch 02 | Loss: 0.5929 | Train: 0.6859 | Test: 0.7146
Epoch 03 | Loss: 0.5481 | Train: 0.7286 | Test: 0.7167
Epoch 04 | Loss: 0.5288 | Train: 0.7244 | Test: 0.7417
Epoch 05 | Loss: 0.4750 | Train: 0.7799 | Test: 0.6875
Epoch 06 | Loss: 0.4963 | Train: 0.7286 | Test: 0.7229
Epoch 07 | Loss: 0.4084 | Train: 0.7991 | Test: 0.6917
Epoch 08 | Loss: 0.4189 | Train: 0.7991 | Test: 0.7167
Epoch 09 | Loss: 0.3633 | Train: 0.8397 | Test: 0.7250
Epoch 10 | Loss: 0.3808 | Train: 0.8162 | Test: 0.7125
Epoch 11 | Loss: 0.3478 | Train: 0.8419 | Test: 0.7125
Epoch 12 | Loss: 0.3264 | Train: 0.8568 | Test: 0.7104
Epoch 13 | Loss: 0.2714 | Train: 0.8846 | Test: 0.7375
Epoch 14 | Loss: 0.2519 | Train: 0.8803 | Test: 0.6750
Epoch 15 | Loss: 0.2422 | Train: 0.9103 | Test: 0.7333
Epoch 16 | Loss: 0.2536 | Train: 0.8953 | Test: 0.7167
Epoch 17 | Loss: 0.1746 | Train: 0.9316 | Test: 0.7146
Epoch 18 | Loss: 0.1355 | Tr