# NJU Dataset Spatial Mapping

### Test code

In [None]:
import sys
sys.path.insert(0, './learnable_spatial_mapping/')
from learnable_spatial_mapping.LSM import deepNetwork as lsm_network
from torchinfo import summary
import torch
import time

# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# eeg = torch.randn(256, 32, 128*4, 8).to(device)
# numCategories = 2
# numConvolutinFilter = [8, 8, 12]
# winlen = 2
# numFCNeurons = [256, numCategories]
# imageSize = (12, 12)
# imageKernelSize = (13, 3, 3)

# model = lsm_network(
#     eeg,
#     numConvFilter=numConvolutinFilter,
#     winlen=winlen,
#     numFCNeurons=numFCNeurons,
#     numGroup=numCategories,
#     enableMaxPool=False,
#     imageSize=imageSize,
#     imageKernelSize=imageKernelSize,
#     device=device,
# )

# # print(summary(model, input_size=[(256, 32, 128*4, 8)], col_names=["input_size", "output_size", "num_params", "params_percent", "kernel_size"]))

# # """ ART Test """
# with torch.no_grad():
#     start_time = time.time()
#     for i in range(0, 5):
#         output = model(eeg)
#     end_time = time.time()
#     print(f"Training Time ={end_time-start_time}")


### Dataset 

In [None]:
import mne
import h5py
import numpy as np
import os
import scipy.io
import pandas as pd

import sys
sys.path.insert(0, './learnable_spatial_mapping/')
from learnable_spatial_mapping.LSM import deepNetwork as lsm_network
from torchinfo import summary
import torch
import time

# Variables
window_len_sec = 10      # length of one EEG segment, in seconds
choice_direction = 90    # angle compared against |left_angel| when grouping; 0 disables that filter

# Dataset location: .mat files (opened with h5py) plus one CSV per .mat file
# (same stem) under expinfomat_csvs/ providing the 'attended_lr' column.
path = "./NJUNCA_preprocessed_arte_removed/"
expinfo_path = path + "expinfomat_csvs/"
mat_files = [f for f in os.listdir(path) if f.endswith('.mat')]
# The first two entries are info files, so subject data is read from index 2.
# NOTE(review): os.listdir() returns names in arbitrary order, so this
# "skip the first two" convention depends on the file system; confirm it holds.
print(mat_files)

subjects = 21
fs = 128  # sampling frequency
window_size = fs * window_len_sec  # samples per segment (128 * 10 = 1280)
channels = 32  # assume full 32 channels
count = 0  # running number of segments stored in all_segments
side_dict = {
    "right": 0,
    "left": 1,
}

# Load data
all_segments = []   # one array per segment, sliced along axis 1 of the trial data
all_labels = []     # integer label (see side_dict) for each entry of all_segments
all_trials = []     # per-trial metadata, including its [trial_start, trial_end) segment range
pre_trial_idx = 0

# Clamp the upper bound so a folder with fewer files does not raise IndexError.
for subj in range(2, min(subjects, len(mat_files))):
    data_path = path + mat_files[subj]
    csv_path = expinfo_path + mat_files[subj].replace('.mat', '.csv')
    df = pd.read_csv(csv_path)['attended_lr']

    # The context manager guarantees the HDF5 handle is closed for every subject.
    with h5py.File(data_path, 'r') as file:
        ref_data = file['data']
        ref_eeg = ref_data['eeg']
        ref_leftangle = ref_data['event']['leftWav']
        ref_rightangle = ref_data['event']['rightWav']
        trials = len(ref_eeg[:])

        for trial in range(trials):
            try:
                # Angles are stored behind two levels of HDF5 object references.
                left_angel_reg = file[ref_leftangle['value'][trial][0]]
                left_angel = file[left_angel_reg[0][0]][0][0]
                right_angel_reg = file[ref_rightangle['value'][trial][0]]
                right_angel = file[right_angel_reg[0][0]][0][0]

                # Keep only trials whose two angles mirror each other.
                if (left_angel*-1) != right_angel:
                    continue
                print(f"{mat_files[subj]}-{trial}: {left_angel}/{right_angel}")

                # Resolve the label first: if it is missing or unknown, the trial
                # is rejected before anything is added to the shared lists.
                attended_side = df[trial]
                label = side_dict[attended_side]

                ref = ref_eeg[trial][0]
                eeg_data = np.array(file[ref][:])  # axis 1 is sliced into windows below
                trial_len = eeg_data.shape[1]

                # Non-overlapping windows of window_size samples; the remainder is dropped.
                n_windows = trial_len // window_size
                trial_segments = [
                    eeg_data[:, win * window_size:(win + 1) * window_size]
                    for win in range(n_windows)
                ]
            except Exception as err:
                # Best-effort loading: report the failing trial (with the cause) and move on.
                print(mat_files[subj], trial, 'load error', repr(err))
            else:
                # Commit the whole trial at once so segments, labels and the
                # recorded index range can never get out of sync.
                all_segments.extend(trial_segments)
                all_labels.extend([label] * len(trial_segments))
                count += len(trial_segments)

                # write trial info dict
                trial_info = {
                    "subject_name": mat_files[subj].split('.')[0],
                    "trial": trial,
                    "attention_side": attended_side,
                    "left_angel": left_angel,
                    "right_angel": right_angel,
                    "trial_start": pre_trial_idx,
                    "trial_end": count,
                }
                all_trials.append(trial_info)
                pre_trial_idx = count
    print(f"{mat_files[subj]} cumulated segments:{count}")

if not all_segments:
    raise RuntimeError(
        "No EEG segments were loaded; check the dataset path, the CSV labels "
        "and the order of mat_files."
    )

# Convert to numpy arrays
X = np.stack(all_segments)  # shape: (total_segments, rows of eeg_data, window_size)
y = np.array(all_labels)    # shape: (total_segments,)

print("Data shape:", X.shape)
print("Labels shape:", y.shape)

# Inclusive subject-number ranges; each range is one cross-validation group.
group_ = [(2, 7), (8, 15), (16, 19), (21, 23), (25, 27)]  # 5 groups
group_index_list = [[] for _ in group_]      # segment indices belonging to each group

cum_trials = 0  # number of segment indices assigned to any group
unassigned_subjects = set()
for info_dict in all_trials:
    # choice_direction == 0 keeps every trial; otherwise keep only matching angles.
    if choice_direction != 0 and abs(info_dict['left_angel']) != choice_direction:
        continue

    subject_int = int(info_dict['subject_name'].replace('S', ''))

    # Assign the trial's segment indices to the group covering this subject.
    for group_id, (low, high) in enumerate(group_):
        if low <= subject_int <= high:
            trial_range = range(info_dict['trial_start'], info_dict['trial_end'])
            cum_trials += len(trial_range)
            group_index_list[group_id].extend(trial_range)
            break
    else:
        # No range matched: remember the subject so the omission is visible.
        unassigned_subjects.add(subject_int)

if unassigned_subjects:
    print(f"Subjects outside every group (ignored): {sorted(unassigned_subjects)}")

print(f"Total trials: {cum_trials}")
cum_ratio = 0
for idx, g_l in enumerate(group_index_list):
    # Guard the division so an empty selection reports 0% instead of crashing.
    trial_ratio = len(g_l) / cum_trials if cum_trials else 0.0
    cum_ratio += trial_ratio
    print(f"Group {idx+1}:(S{group_[idx][0]:02d}~S{group_[idx][1]:02d}) | total trials:{len(g_l)}, trials ratio:{trial_ratio*100:.1f}%, cumulated ratio:{cum_ratio*100:.1f}%")
    # print(len(g_l), trial_ratio, cum_ratio)


### Run

In [None]:
from torch.utils.data import DataLoader
import numpy as np
from sklearn.metrics import accuracy_score

batch_size = 64
num_epochs = 100

group_acc_list = []   # best validation accuracy of each evaluated fold
pred_list = []        # predictions of every evaluated fold, concatenated
true_list = []        # matching ground-truth labels

# Leave-one-group-out: each group is held out once, the rest is used for training.
# NOTE(review): run_train selects its best epoch on the held-out loader, so the
# accuracies collected here are optimistic estimates.
for valid_group in range(len(group_index_list)):
    print(f"--------------------------------------------------------------------------------------------------------------\n\n")
    print(f"\nGroup:{valid_group}\n")
    print(f"--------------------------------------------------------------------------------------------------------------\n\n")

    test_idx = group_index_list[valid_group]
    train_idx = [i for g, indices in enumerate(group_index_list) if g != valid_group for i in indices]

    # An empty split cannot be trained or scored; skip it instead of crashing later.
    if not test_idx or not train_idx:
        print(f"Group {valid_group} skipped: empty train or test split")
        continue

    # Indexing with a list already returns a new array, so no extra .copy() is needed.
    X_train = X[train_idx]
    X_test = X[test_idx]
    train_y = y[train_idx]
    test_y = y[test_idx]

    print(f"{X_train.shape} - {X_test.shape} - {train_y.shape} - {test_y.shape}")
    train_dataset = EEGDataset(X_train, train_y)
    test_dataset = EEGDataset(X_test, test_y)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

    val_acc, val_pred, val_true = run_train(train_loader, val_loader, winlen=window_len_sec, batch_size=batch_size, num_epochs=num_epochs)

    print(f"Group {valid_group} Accuracy: {val_acc:.4f}")
    group_acc_list.append(val_acc)
    pred_list.extend(val_pred)
    true_list.extend(val_true)

# === Summary ===
if group_acc_list:
    group_acc_array = np.array(group_acc_list)
    std_acc = group_acc_array.std()
    # Pooled accuracy over all held-out predictions (not the mean of fold accuracies).
    acc = accuracy_score(true_list, pred_list)

    print("\n=== Cross-Validation Summary ===")
    print(f"Accuracy: {acc:.3f}")
    print(f"Group Std  Accuracy: {std_acc:.3f}")
else:
    print("\n=== Cross-Validation Summary ===")
    print("No fold was evaluated.")

### Run Train

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import os
import torch
from torch.utils.data import Dataset
import numpy as np
from model.utils import *

class EEGDataset(Dataset):
    """Map-style dataset that pairs EEG segments with integer labels.

    Args:
        X: indexable collection of segments; ``X[i]`` must be a NumPy array.
        y: indexable collection of labels aligned with ``X``.
        apply_filter: when true, every segment is passed through
            ``filterBank().__step__`` before being returned.

    Each item is a ``(float32 tensor, int64 tensor)`` pair.
    """

    def __init__(self, X, y, apply_filter=True):
        self.X, self.y = X, y
        self.apply_filter = apply_filter
        # The filter bank is only instantiated when it will actually be used.
        if apply_filter:
            self.filter = filterBank()

    def __len__(self):
        """Return the number of segments."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the (optionally filtered) segment ``idx`` and its label as tensors."""
        sample = self.X[idx]
        target = self.y[idx]

        if not self.apply_filter:
            features = sample
        else:
            # Move the leading axis of the filter output to the end.
            # Assumes a 3-D result laid out as (bands, channels, time),
            # giving (channels, time, bands) -- TODO confirm against filterBank.
            bank_output = self.filter.__step__(sample)
            features = np.transpose(bank_output, (1, 2, 0))

        feature_tensor = torch.from_numpy(features).float()
        label_tensor = torch.tensor(target).long()
        return feature_tensor, label_tensor
    
def run_train(train_loader, val_loader, winlen=2, batch_size=64, num_epochs=100,
              learning_rate=1e-4, early_stop_patience=30,
              checkpoint_path="best_model.pt"):
    """Train an ``lsm_network`` classifier and return its best validation result.

    Args:
        train_loader: iterable of ``(x, y)`` batches used for optimisation.
        val_loader: iterable of ``(x, y)`` batches scored after every epoch.
        winlen: window length passed to the network; also sizes the dummy
            input as ``128 * winlen`` samples (128 presumably being the
            sampling rate -- confirm against the dataset cell).
        batch_size: unused; kept for backward compatibility with callers.
        num_epochs: maximum number of epochs.
        learning_rate: initial Adam learning rate.
        early_stop_patience: stop after this many consecutive epochs without
            an improvement of the best validation accuracy.
        checkpoint_path: file that receives the best model's ``state_dict``.
            Pass a distinct path per fold to avoid overwriting checkpoints.

    Returns:
        ``(best_val_acc, best_val_pred, best_val_true)``: the highest
        validation accuracy and the predictions/labels of that epoch as NumPy
        arrays (empty lists if accuracy never rose above the initial 0.0).

    Raises:
        ValueError: if either loader yields no samples.

    Note:
        The best epoch is chosen on ``val_loader`` itself, so the returned
        accuracy is optimistically biased when that loader is also the final
        test set.
    """
    # === Hyperparameters ===
    l2_factor = 1e-5
    lr_drop_factor = 0.5
    patience = 10       # scheduler patience (epochs), distinct from early_stop_patience
    cooldown = 0

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Dummy batch handed to the network constructor. It is created on the CPU
    # and then moved, so the CPU random stream is consumed as before.
    eeg = torch.randn(256, 32, 128*winlen, 4).to(device)
    numCategories = 2
    numConvolutinFilter = [8, 8, 12]
    numFCNeurons = [256, numCategories]
    imageSize = (12, 12)
    imageKernelSize = (13, 3, 3)

    model = lsm_network(
        eeg,
        numConvFilter=numConvolutinFilter,
        winlen=winlen,
        numFCNeurons=numFCNeurons,
        numGroup=numCategories,
        enableMaxPool=False,
        imageSize=imageSize,
        imageKernelSize=imageKernelSize,
        device=device,
    )
    # Drop the local reference so the dummy batch does not occupy device
    # memory for the whole training run.
    del eeg

    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=l2_factor)
    # The deprecated `verbose` flag is not used; learning-rate changes are
    # logged explicitly after scheduler.step() below.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=lr_drop_factor, patience=patience, cooldown=cooldown,
        threshold=0.015, threshold_mode='abs'
    )

    best_val_acc = 0.0
    best_val_pred = []
    best_val_true = []
    epochs_since_improvement = 0  # used for early stopping

    for epoch in range(num_epochs):
        # === Training ===
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        for x, y in tqdm(train_loader, desc=f"Epoch {epoch+1} Training"):
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            out = model(x)
            loss = loss_fn(out, y)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            pred = out.argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)

        if total == 0:
            raise ValueError("train_loader yielded no samples")

        train_acc = correct / total
        avg_loss = running_loss / len(train_loader)
        print(f"[Epoch {epoch+1}] Loss: {avg_loss:.4f}, Train Acc: {train_acc:.4f}")

        # === Validation ===
        model.eval()
        val_correct = 0
        val_total = 0
        val_preds = []
        val_labels = []

        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                out = model(x)
                pred = out.argmax(dim=1)
                val_correct += (pred == y).sum().item()
                val_total += y.size(0)

                val_preds.append(pred.cpu())
                val_labels.append(y.cpu())

        if val_total == 0:
            raise ValueError("val_loader yielded no samples")

        val_acc = val_correct / val_total
        print(f"[Epoch {epoch+1}] Val Acc: {val_acc:.4f}")

        lr_before = optimizer.param_groups[0]["lr"]
        scheduler.step(val_acc)
        lr_after = optimizer.param_groups[0]["lr"]
        if lr_after != lr_before:
            print(f"[Epoch {epoch+1}] Reducing learning rate to {lr_after:.4e}")

        # Save best model and predictions
        if val_acc > best_val_acc + 1e-4:  # small epsilon to avoid float rounding issues
            best_val_acc = val_acc
            torch.save(model.state_dict(), checkpoint_path)
            print(f"Best model saved at epoch {epoch+1}")
            best_val_pred = torch.cat(val_preds).numpy()
            best_val_true = torch.cat(val_labels).numpy()
            epochs_since_improvement = 0
        else:
            epochs_since_improvement += 1
            print(f"No improvement for {epochs_since_improvement} epoch(s)")

        # === Early stopping condition ===
        if epochs_since_improvement >= early_stop_patience:
            print(f"Early stopping triggered at epoch {epoch+1}")
            break

    # === End of Training ===
    return best_val_acc, best_val_pred, best_val_true

