<a href="https://colab.research.google.com/github/Justin-Hwang/EEG-AD-FTD-Detection/blob/main/Multiclass_Ablation_Studies.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# 그중에 2025 Lab Research 폴더 안을 확인
!ls "/content/drive/MyDrive/2025_Lab_Research"

'AD vs FTD vs CN Binary Classification'   eeg_holdout-4.db
'Colab Files'				  eeg_holdout-5.db
'Data Preparation.gdoc'			  eeg_holdout-6.db
 eeg_dataset.py				  eeg_holdout-7.db
 EEGformer_model_training.ipynb		  eeg_holdout-8.db
 eegformer_optuna_cv_3.db		  eeg_holdout-9.db
 eegformer_optuna_cv_4.db		  eeg_holdout.db
 eegformer_optuna_cv_5.db		  eeg_holdout_fixed_1.db
 eeg_grid_search-10.db			  eeg_optuna_trial_1.db
 eeg_grid_search-11.db			  eeg_optuna_trial_2.db
 eeg_grid_search-12.db			  eeg_optuna_trial_3.db
 eeg_grid_search-13.db			 'EEG Transformer Architecture.gdoc'
 eeg_grid_search-14.db			 'Lab Info'
 eeg_grid_search-15.db			 'Lab Research Paper Review'
 eeg_grid_search-16.db			 'Meeting Note.gdoc'
 eeg_grid_search-17.db			  model-data
 eeg_grid_search-18.db			  model-data.zip
 eeg_grid_search-2.db			  model_optimized_2.py
 eeg_grid_search-3.db			  model_optimized_3.py
 eeg_grid_search-4.db			  model_optimized_4.py
 eeg_grid_search-5.db			  model_optimized_5.py
 eeg_grid_sear

In [None]:
import sys
sys.path.append('/content/drive/MyDrive/2025_Lab_Research')

In [None]:
import torch
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Running on", DEVICE)  # → “cuda” 가 뜨면 GPU 정상

Running on cuda


In [None]:
import wandb
wandb.login()  # 첫 실행 시 API 키 입력

[34m[1mwandb[0m: Currently logged in as: [33mjh8032[0m ([33mjh8032-new-york-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [None]:
!pip install optuna
!pip install wandb
!pip install mne



### Search the best hyperparameter using Hold-out set
- Use Block = 1
- Use Head = [2,3]
- Overfitting occurred

In [None]:
import os
import json
import time
import gc
import multiprocessing

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split
import optuna
import wandb

from eeg_dataset import EEGDataset
from models import EEGformer

# ─── Grid Search 후보값 ────────────────────────────────────────
LR_CHOICES          = [5e-4, 1e-3]
WD_CHOICES          = [1e-3, 5e-4, 1e-4]
NUM_FILTERS         = 120
NUM_BLOCK_CHOICES   = [1]
NUM_HEAD_CHOICES    = [2, 3]
SEGMENT_CHOICES     = [5]

# ─── Training configuration ───────────────────────────────────
MAX_EPOCHS  = 100
PATIENCE    = 10
BATCH_SIZE  = 32
NUM_WORKERS = max(1, min(4, os.cpu_count() - 1))

# ─── Data paths & device ──────────────────────────────────────
DATA_DIR   = '/content/drive/MyDrive/2025_Lab_Research/model-data'
LABEL_FILE = "labels.json"
DEVICE     = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ─── Scheduler 하이퍼파라미터 ─────────────────────────────────
FIXED_STEP_SIZE = 5
FIXED_GAMMA     = 0.5


def objective_holdout(trial):
    # ─── 1) Grid용 sampling ───────────────────────────────────
    lr           = trial.suggest_categorical("lr", LR_CHOICES)
    weight_decay = trial.suggest_categorical("weight_decay", WD_CHOICES)
    num_blocks   = trial.suggest_categorical("num_blocks", NUM_BLOCK_CHOICES)
    num_heads    = trial.suggest_categorical("num_heads", NUM_HEAD_CHOICES)
    num_segments = trial.suggest_categorical("num_segments", SEGMENT_CHOICES)

    # ─── 2) 데이터 로드 ────────────────────────────────────────
    with open(os.path.join(DATA_DIR, LABEL_FILE), "r") as f:
        all_meta = json.load(f)
    train_meta = [d for d in all_meta if d["type"] == "train"]
    full_ds    = EEGDataset(DATA_DIR, train_meta)
    labels     = [d["label"] for d in train_meta]
    n_samples  = len(full_ds)
    input_length = full_ds[0][0].shape[-1]

    # ─── 3) Hold-out split ─────────────────────────────────────
    train_idx, val_idx = train_test_split(
        list(range(n_samples)),
        test_size=0.2,
        stratify=labels,
        random_state=42
    )
    train_loader = DataLoader(
        Subset(full_ds, train_idx),
        batch_size=BATCH_SIZE, shuffle=True,  num_workers=NUM_WORKERS
    )
    val_loader = DataLoader(
        Subset(full_ds, val_idx),
        batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
    )

    # ─── 4) W&B init ───────────────────────────────────────────
    wandb.init(
        project="eeg-holdout-grid-search-2",
        config={
            "lr": lr,
            "weight_decay": weight_decay,
            "num_blocks": num_blocks,
            "num_heads": num_heads,
            "num_segments": num_segments
        }
    )

    print(f"\n===== Trial {trial.number} =====")
    print(
        f" lr={lr:.2e}, wd={weight_decay:.2e}, "
        f"blocks={num_blocks}, heads={num_heads}, segs={num_segments}"
    )

    # ─── 5) Model / optimizer / loss ──────────────────────────
    model = EEGformer(
        in_channels  = 19,
        input_length = input_length,
        kernel_size  = 10,
        num_filters  = NUM_FILTERS,
        num_heads    = num_heads,
        num_blocks   = num_blocks,
        num_segments = num_segments,
        num_classes  = 3
    ).to(DEVICE)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=weight_decay
    )
    criterion = nn.CrossEntropyLoss()

    # ─── 6) Scheduler ──────────────────────────────────────────
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=FIXED_GAMMA,
        patience=FIXED_STEP_SIZE,
        min_lr=1e-6
    )

    # ─── 7) Training loop w/ Early Stopping & Pruning ─────────
    best_val_loss     = float("inf")
    epochs_no_improve = 0

    best_train_loss = best_train_acc = best_val_acc = None

    for epoch in range(1, MAX_EPOCHS + 1):
        t0 = time.time()

        # — train —
        model.train()
        tloss = tcorrect = ttotal = 0
        for X, y in train_loader:
            X, y = X.to(DEVICE), y.to(DEVICE)
            optimizer.zero_grad()
            logits = model(X)
            loss   = criterion(logits, y)
            loss.backward()
            optimizer.step()

            tloss    += loss.item()
            tcorrect += (logits.argmax(1) == y).sum().item()
            ttotal   += y.size(0)
        train_loss = tloss / len(train_loader)
        train_acc  = tcorrect / ttotal

        # — validate —
        model.eval()
        vloss = vcorrect = vtotal = 0
        with torch.no_grad():
            for X, y in val_loader:
                X, y    = X.to(DEVICE), y.to(DEVICE)
                logits  = model(X)
                loss    = criterion(logits, y)
                vloss   += loss.item()
                vcorrect+= (logits.argmax(1) == y).sum().item()
                vtotal  += y.size(0)
        val_loss = vloss / len(val_loader)
        val_acc  = vcorrect / vtotal
        elapsed  = time.time() - t0

        # — report & pruning check —
        trial.report(val_loss, epoch)
        if trial.should_prune():
            wandb.finish()
            print(f"▸ Trial {trial.number} pruned at epoch {epoch}")
            raise optuna.TrialPruned()

        # — print & log —
        print(
            f"Epoch {epoch:03d} | "
            f"train_loss={train_loss:.4f} acc={train_acc:.4f} | "
            f"val_loss={val_loss:.4f} acc={val_acc:.4f} | "
            f"time={elapsed:.1f}s"
        )
        wandb.log({
            "epoch":               epoch,
            "train_loss":          train_loss,
            "train_accuracy":      train_acc,
            "validation_loss":     val_loss,
            "validation_accuracy": val_acc,
        }, step=epoch)

        scheduler.step(val_loss)

        # — early stopping logic & save best metrics —
        if val_loss < best_val_loss:
            best_val_loss     = val_loss
            epochs_no_improve = 0
            best_train_loss   = train_loss
            best_train_acc    = train_acc
            best_val_acc      = val_acc
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= PATIENCE:
                print(f"★ Early stopping at epoch {epoch}")
                break

    # store best-epoch metrics
    trial.set_user_attr("best_train_loss", best_train_loss)
    trial.set_user_attr("best_train_acc",  best_train_acc)
    trial.set_user_attr("best_val_acc",    best_val_acc)

    wandb.finish()
    gc.collect()
    return best_val_loss


if __name__ == "__main__":
    multiprocessing.freeze_support()

    # ─── GridSampler용 파라미터 그리드 ─────────────────────────
    param_grid = {
        "lr":            LR_CHOICES,
        "weight_decay":  WD_CHOICES,
        "num_blocks":    NUM_BLOCK_CHOICES,
        "num_heads":     NUM_HEAD_CHOICES,
        "num_segments":  SEGMENT_CHOICES,
    }

    sampler = optuna.samplers.GridSampler(param_grid)
    study = optuna.create_study(
        direction="minimize",
        sampler=sampler,
        pruner=optuna.pruners.MedianPruner(n_warmup_steps=5),
        study_name="eeg_holdout_grid_search-2",
        storage="sqlite:////content/drive/MyDrive/2025_Lab_Research/eeg_grid_search-2.db",
        load_if_exists=True
    )
    study.optimize(objective_holdout)  # grid 크기만큼 자동 실행

    # ─── 결과 출력 ─────────────────────────────────────────────
    best = study.best_trial
    print("\n===== Best Trial =====")
    print(f"best_val_loss       = {best.value:.6f}")
    print(f"best_train_loss     = {best.user_attrs['best_train_loss']:.6f}")
    print(f"best_train_accuracy = {best.user_attrs['best_train_acc']:.4f}")
    print(f"best_val_accuracy   = {best.user_attrs['best_val_acc']:.4f}")
    print("best params:")
    for k, v in best.params.items():
        print(f"  {k}: {v}")


Attempting to create new mne-python configuration file:
/root/.mne/mne-python.json
Now using CUDA device 0
Enabling CUDA with 39.14 GiB available memory


[I 2025-05-02 03:30:35,554] A new study created in RDB with name: eeg_holdout_grid_search-2



===== Trial 0 =====
 lr=1.00e-03, wd=1.00e-03, blocks=1, heads=2, segs=5
Epoch 001 | train_loss=1.0711 acc=0.4210 | val_loss=1.0745 acc=0.4317 | time=246.8s
Epoch 002 | train_loss=1.0677 acc=0.4311 | val_loss=1.0742 acc=0.4317 | time=16.8s
Epoch 003 | train_loss=1.0686 acc=0.4311 | val_loss=1.0774 acc=0.4317 | time=17.0s
Epoch 004 | train_loss=1.0699 acc=0.4311 | val_loss=1.0864 acc=0.4317 | time=16.9s
Epoch 005 | train_loss=1.0675 acc=0.4311 | val_loss=1.0744 acc=0.4317 | time=16.8s
Epoch 006 | train_loss=1.0673 acc=0.4311 | val_loss=1.0762 acc=0.4317 | time=16.9s
Epoch 007 | train_loss=1.0675 acc=0.4311 | val_loss=1.0745 acc=0.4317 | time=16.9s
Epoch 008 | train_loss=1.0673 acc=0.4311 | val_loss=1.0763 acc=0.4317 | time=16.9s
Epoch 009 | train_loss=1.0673 acc=0.4311 | val_loss=1.0756 acc=0.4317 | time=16.7s
Epoch 010 | train_loss=1.0654 acc=0.4311 | val_loss=1.0720 acc=0.4317 | time=17.0s
Epoch 011 | train_loss=1.0257 acc=0.5045 | val_loss=1.0469 acc=0.4984 | time=16.7s
Epoch 012 | 

0,1
epoch,▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇███
train_accuracy,▁▁▁▁▁▁▁▁▁▁▂▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇███
train_loss,███████████▇▆▆▆▆▆▅▅▅▄▄▄▄▃▃▂▁▁▁
validation_accuracy,▁▁▁▁▁▁▁▁▁▁▃▅▅▆▆▆▆▇▆▇▇▇▅▇▇▇████
validation_loss,▅▅▅▅▅▅▅▅▅▅▅▄▃▃▃▃▂▂▂▁▁▁▃▃▂▁▄▂█▇

0,1
epoch,30.0
train_accuracy,0.87495
train_loss,0.31952
validation_accuracy,0.67857
validation_loss,1.17028


[I 2025-05-02 03:43:12,928] Trial 0 finished with value: 0.7875745651267824 and parameters: {'lr': 0.001, 'weight_decay': 0.001, 'num_blocks': 1, 'num_heads': 2, 'num_segments': 5}. Best is trial 0 with value: 0.7875745651267824.



===== Trial 1 =====
 lr=1.00e-03, wd=1.00e-04, blocks=1, heads=3, segs=5
Epoch 001 | train_loss=1.0779 acc=0.4062 | val_loss=1.0814 acc=0.4317 | time=18.8s
Epoch 002 | train_loss=1.0707 acc=0.4311 | val_loss=1.0787 acc=0.4317 | time=18.4s
Epoch 003 | train_loss=1.0680 acc=0.4311 | val_loss=1.0807 acc=0.4317 | time=18.6s
Epoch 004 | train_loss=1.0707 acc=0.4311 | val_loss=1.0746 acc=0.4317 | time=18.5s
Epoch 005 | train_loss=1.0688 acc=0.4311 | val_loss=1.0752 acc=0.4317 | time=18.5s
Epoch 006 | train_loss=1.0689 acc=0.4311 | val_loss=1.0742 acc=0.4317 | time=18.7s
Epoch 007 | train_loss=1.0691 acc=0.4311 | val_loss=1.0754 acc=0.4317 | time=18.5s
Epoch 008 | train_loss=1.0678 acc=0.4311 | val_loss=1.0750 acc=0.4317 | time=18.6s
Epoch 009 | train_loss=1.0676 acc=0.4311 | val_loss=1.0741 acc=0.4317 | time=18.4s
Epoch 010 | train_loss=1.0614 acc=0.4447 | val_loss=1.0183 acc=0.5606 | time=18.7s
Epoch 011 | train_loss=0.9945 acc=0.5569 | val_loss=0.9825 acc=0.5637 | time=18.5s
Epoch 012 | t

[W 2025-05-02 03:55:14,332] Trial 1 failed with parameters: {'lr': 0.001, 'weight_decay': 0.0001, 'num_blocks': 1, 'num_heads': 3, 'num_segments': 5} because of the following error: KeyboardInterrupt().
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/optuna/study/_optimize.py", line 197, in _run_trial
    value_or_values = func(trial)
                      ^^^^^^^^^^^
  File "<ipython-input-8-712a74a241c1>", line 137, in objective_holdout
    loss.backward()
  File "/usr/local/lib/python3.11/dist-packages/torch/_tensor.py", line 626, in backward
    torch.autograd.backward(
  File "/usr/local/lib/python3.11/dist-packages/torch/autograd/__init__.py", line 347, in backward
    _engine_run_backward(
  File "/usr/local/lib/python3.11/dist-packages/torch/autograd/graph.py", line 823, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

KeyboardInterrupt: 

### Revised the EEGformer model
- Additional Dropout in CNNDecoder
- Dropout Rate = 0.3

In [None]:
import os
import json
import time
import gc
import multiprocessing

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split
import optuna
import wandb

from eeg_dataset import EEGDataset
from models import EEGformer

# ─── Grid Search 후보값 ────────────────────────────────────────
LR_CHOICES          = [5e-4, 1e-3]
WD_CHOICES          = [1e-3, 5e-4, 1e-4]
NUM_FILTERS         = 120
NUM_BLOCK_CHOICES   = [1]
NUM_HEAD_CHOICES    = [2, 3]
SEGMENT_CHOICES     = [5]

# ─── Training configuration ───────────────────────────────────
MAX_EPOCHS  = 100
PATIENCE    = 10
BATCH_SIZE  = 32
NUM_WORKERS = max(1, min(4, os.cpu_count() - 1))

# ─── Data paths & device ──────────────────────────────────────
DATA_DIR   = '/content/drive/MyDrive/2025_Lab_Research/model-data'
LABEL_FILE = "labels.json"
DEVICE     = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ─── Scheduler 하이퍼파라미터 ─────────────────────────────────
FIXED_STEP_SIZE = 5
FIXED_GAMMA     = 0.5


def objective_holdout(trial):
    # ─── 1) Grid용 sampling ───────────────────────────────────
    lr           = trial.suggest_categorical("lr", LR_CHOICES)
    weight_decay = trial.suggest_categorical("weight_decay", WD_CHOICES)
    num_blocks   = trial.suggest_categorical("num_blocks", NUM_BLOCK_CHOICES)
    num_heads    = trial.suggest_categorical("num_heads", NUM_HEAD_CHOICES)
    num_segments = trial.suggest_categorical("num_segments", SEGMENT_CHOICES)

    # ─── 2) 데이터 로드 ────────────────────────────────────────
    with open(os.path.join(DATA_DIR, LABEL_FILE), "r") as f:
        all_meta = json.load(f)
    train_meta = [d for d in all_meta if d["type"] == "train"]
    full_ds    = EEGDataset(DATA_DIR, train_meta)
    labels     = [d["label"] for d in train_meta]
    n_samples  = len(full_ds)
    input_length = full_ds[0][0].shape[-1]

    # ─── 3) Hold-out split ─────────────────────────────────────
    train_idx, val_idx = train_test_split(
        list(range(n_samples)),
        test_size=0.2,
        stratify=labels,
        random_state=42
    )
    train_loader = DataLoader(
        Subset(full_ds, train_idx),
        batch_size=BATCH_SIZE, shuffle=True,  num_workers=NUM_WORKERS
    )
    val_loader = DataLoader(
        Subset(full_ds, val_idx),
        batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
    )

    # ─── 4) W&B init ───────────────────────────────────────────
    wandb.init(
        project="eeg-holdout-grid-search-3",
        config={
            "lr": lr,
            "weight_decay": weight_decay,
            "num_blocks": num_blocks,
            "num_heads": num_heads,
            "num_segments": num_segments
        }
    )

    print(f"\n===== Trial {trial.number} =====")
    print(
        f" lr={lr:.2e}, wd={weight_decay:.2e}, "
        f"blocks={num_blocks}, heads={num_heads}, segs={num_segments}"
    )

    # ─── 5) Model / optimizer / loss ──────────────────────────
    model = EEGformer(
        in_channels  = 19,
        input_length = input_length,
        kernel_size  = 10,
        num_filters  = NUM_FILTERS,
        num_heads    = num_heads,
        num_blocks   = num_blocks,
        num_segments = num_segments,
        num_classes  = 3
    ).to(DEVICE)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=weight_decay
    )
    criterion = nn.CrossEntropyLoss()

    # ─── 6) Scheduler ──────────────────────────────────────────
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=FIXED_GAMMA,
        patience=FIXED_STEP_SIZE,
        min_lr=1e-6
    )

    # ─── 7) Training loop w/ Early Stopping & Pruning ─────────
    best_val_loss     = float("inf")
    epochs_no_improve = 0

    best_train_loss = best_train_acc = best_val_acc = None

    for epoch in range(1, MAX_EPOCHS + 1):
        t0 = time.time()

        # — train —
        model.train()
        tloss = tcorrect = ttotal = 0
        for X, y in train_loader:
            X, y = X.to(DEVICE), y.to(DEVICE)
            optimizer.zero_grad()
            logits = model(X)
            loss   = criterion(logits, y)
            loss.backward()
            optimizer.step()

            tloss    += loss.item()
            tcorrect += (logits.argmax(1) == y).sum().item()
            ttotal   += y.size(0)
        train_loss = tloss / len(train_loader)
        train_acc  = tcorrect / ttotal

        # — validate —
        model.eval()
        vloss = vcorrect = vtotal = 0
        with torch.no_grad():
            for X, y in val_loader:
                X, y    = X.to(DEVICE), y.to(DEVICE)
                logits  = model(X)
                loss    = criterion(logits, y)
                vloss   += loss.item()
                vcorrect+= (logits.argmax(1) == y).sum().item()
                vtotal  += y.size(0)
        val_loss = vloss / len(val_loader)
        val_acc  = vcorrect / vtotal
        elapsed  = time.time() - t0

        # — report & pruning check —
        trial.report(val_loss, epoch)
        if trial.should_prune():
            wandb.finish()
            print(f"▸ Trial {trial.number} pruned at epoch {epoch}")
            raise optuna.TrialPruned()

        # — print & log —
        print(
            f"Epoch {epoch:03d} | "
            f"train_loss={train_loss:.4f} acc={train_acc:.4f} | "
            f"val_loss={val_loss:.4f} acc={val_acc:.4f} | "
            f"time={elapsed:.1f}s"
        )
        wandb.log({
            "epoch":               epoch,
            "train_loss":          train_loss,
            "train_accuracy":      train_acc,
            "validation_loss":     val_loss,
            "validation_accuracy": val_acc,
        }, step=epoch)

        scheduler.step(val_loss)

        # — early stopping logic & save best metrics —
        if val_loss < best_val_loss:
            best_val_loss     = val_loss
            epochs_no_improve = 0
            best_train_loss   = train_loss
            best_train_acc    = train_acc
            best_val_acc      = val_acc
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= PATIENCE:
                print(f"★ Early stopping at epoch {epoch}")
                break

    # store best-epoch metrics
    trial.set_user_attr("best_train_loss", best_train_loss)
    trial.set_user_attr("best_train_acc",  best_train_acc)
    trial.set_user_attr("best_val_acc",    best_val_acc)

    wandb.finish()
    gc.collect()
    return best_val_loss


if __name__ == "__main__":
    multiprocessing.freeze_support()

    # ─── GridSampler용 파라미터 그리드 ─────────────────────────
    param_grid = {
        "lr":            LR_CHOICES,
        "weight_decay":  WD_CHOICES,
        "num_blocks":    NUM_BLOCK_CHOICES,
        "num_heads":     NUM_HEAD_CHOICES,
        "num_segments":  SEGMENT_CHOICES,
    }

    sampler = optuna.samplers.GridSampler(param_grid)
    study = optuna.create_study(
        direction="minimize",
        sampler=sampler,
        pruner=optuna.pruners.MedianPruner(n_warmup_steps=5),
        study_name="eeg_holdout_grid_search-3",
        storage="sqlite:////content/drive/MyDrive/2025_Lab_Research/eeg_grid_search-3.db",
        load_if_exists=True
    )
    study.optimize(objective_holdout)  # grid 크기만큼 자동 실행

    # ─── 결과 출력 ─────────────────────────────────────────────
    best = study.best_trial
    print("\n===== Best Trial =====")
    print(f"best_val_loss       = {best.value:.6f}")
    print(f"best_train_loss     = {best.user_attrs['best_train_loss']:.6f}")
    print(f"best_train_accuracy = {best.user_attrs['best_train_acc']:.4f}")
    print(f"best_val_accuracy   = {best.user_attrs['best_val_acc']:.4f}")
    print("best params:")
    for k, v in best.params.items():
        print(f"  {k}: {v}")


Now using CUDA device 0
Enabling CUDA with 39.14 GiB available memory


[I 2025-05-02 04:07:56,477] A new study created in RDB with name: eeg_holdout_grid_search-3



===== Trial 0 =====
 lr=1.00e-03, wd=1.00e-03, blocks=1, heads=2, segs=5
Epoch 001 | train_loss=1.0777 acc=0.4276 | val_loss=1.0760 acc=0.4317 | time=16.8s
Epoch 002 | train_loss=1.0727 acc=0.4252 | val_loss=1.0768 acc=0.4317 | time=16.8s
Epoch 003 | train_loss=1.0718 acc=0.4299 | val_loss=1.0752 acc=0.4317 | time=16.8s
Epoch 004 | train_loss=1.0675 acc=0.4283 | val_loss=1.0759 acc=0.4317 | time=16.9s
Epoch 005 | train_loss=1.0698 acc=0.4136 | val_loss=1.0764 acc=0.4317 | time=16.8s
Epoch 006 | train_loss=1.0680 acc=0.4307 | val_loss=1.0738 acc=0.4317 | time=17.1s
Epoch 007 | train_loss=1.0673 acc=0.4283 | val_loss=1.0761 acc=0.4317 | time=16.7s
Epoch 008 | train_loss=1.0689 acc=0.4295 | val_loss=1.0740 acc=0.4317 | time=16.7s
Epoch 009 | train_loss=1.0636 acc=0.4217 | val_loss=1.0790 acc=0.4317 | time=17.1s
Epoch 010 | train_loss=1.0056 acc=0.5379 | val_loss=0.9730 acc=0.5543 | time=16.8s
Epoch 011 | train_loss=0.9493 acc=0.5814 | val_loss=0.9982 acc=0.5326 | time=16.8s
Epoch 012 | t

0,1
epoch,▁▁▂▂▂▃▃▃▃▄▄▄▅▅▅▆▆▆▆▇▇▇██
train_accuracy,▁▁▁▁▁▁▁▁▁▄▄▅▅▅▆▆▆▇▇▇▇███
train_loss,█████████▇▆▆▅▅▄▃▃▃▃▂▂▁▁▁
validation_accuracy,▁▁▁▁▁▁▁▁▁▅▄▅▅▇▇▆▆▇▇▇▇█▇▇
validation_loss,█████████▄▅▃▂▁▁▂▂▄▃▇▁▅▅▆

0,1
epoch,24.0
train_accuracy,0.75107
train_loss,0.62836
validation_accuracy,0.61025
validation_loss,1.02721


[I 2025-05-02 04:14:46,615] Trial 0 finished with value: 0.9058446770622617 and parameters: {'lr': 0.001, 'weight_decay': 0.001, 'num_blocks': 1, 'num_heads': 2, 'num_segments': 5}. Best is trial 0 with value: 0.9058446770622617.



===== Trial 1 =====
 lr=1.00e-03, wd=1.00e-04, blocks=1, heads=3, segs=5
Epoch 001 | train_loss=1.0718 acc=0.4299 | val_loss=1.0743 acc=0.4317 | time=18.6s
Epoch 002 | train_loss=1.0691 acc=0.4330 | val_loss=1.0817 acc=0.4317 | time=18.4s
Epoch 003 | train_loss=1.0666 acc=0.4307 | val_loss=1.0743 acc=0.4317 | time=18.5s
Epoch 004 | train_loss=1.0663 acc=0.4311 | val_loss=1.0765 acc=0.4317 | time=18.5s
Epoch 005 | train_loss=1.0686 acc=0.4311 | val_loss=1.0746 acc=0.4317 | time=18.5s
Epoch 006 | train_loss=1.0671 acc=0.4311 | val_loss=1.0758 acc=0.4317 | time=18.4s
Epoch 007 | train_loss=1.0684 acc=0.4311 | val_loss=1.0768 acc=0.4317 | time=18.6s
Epoch 008 | train_loss=1.0673 acc=0.4311 | val_loss=1.0744 acc=0.4317 | time=18.5s
Epoch 009 | train_loss=1.0661 acc=0.4311 | val_loss=1.0746 acc=0.4317 | time=18.6s
Epoch 010 | train_loss=1.0662 acc=0.4311 | val_loss=1.0755 acc=0.4317 | time=18.5s
Epoch 011 | train_loss=1.0689 acc=0.4311 | val_loss=1.0743 acc=0.4317 | time=18.6s
★ Early stopp

0,1
epoch,▁▂▂▃▄▅▅▆▇▇█
train_accuracy,▁█▃▄▄▄▄▄▄▄▄
train_loss,█▅▂▁▄▂▄▂▁▁▄
validation_accuracy,▁▁▁▁▁▁▁▁▁▁▁
validation_loss,▁█▁▃▁▂▃▁▁▂▁

0,1
epoch,11.0
train_accuracy,0.43107
train_loss,1.06887
validation_accuracy,0.43168
validation_loss,1.0743


[I 2025-05-02 04:18:12,383] Trial 1 finished with value: 1.074283242225647 and parameters: {'lr': 0.001, 'weight_decay': 0.0001, 'num_blocks': 1, 'num_heads': 3, 'num_segments': 5}. Best is trial 0 with value: 0.9058446770622617.



===== Trial 2 =====
 lr=5.00e-04, wd=5.00e-04, blocks=1, heads=3, segs=5
Epoch 001 | train_loss=1.0721 acc=0.4202 | val_loss=1.0784 acc=0.4317 | time=18.5s
Epoch 002 | train_loss=1.0684 acc=0.4237 | val_loss=1.0759 acc=0.4317 | time=18.4s
Epoch 003 | train_loss=1.0679 acc=0.4307 | val_loss=1.0744 acc=0.4317 | time=18.5s
Epoch 004 | train_loss=1.0677 acc=0.4303 | val_loss=1.0745 acc=0.4317 | time=18.5s
Epoch 005 | train_loss=1.0698 acc=0.4307 | val_loss=1.0742 acc=0.4317 | time=18.4s
Epoch 006 | train_loss=1.0670 acc=0.4315 | val_loss=1.0744 acc=0.4317 | time=18.5s
Epoch 007 | train_loss=1.0683 acc=0.4311 | val_loss=1.0742 acc=0.4317 | time=18.6s
Epoch 008 | train_loss=1.0684 acc=0.4311 | val_loss=1.0733 acc=0.4317 | time=18.4s
Epoch 009 | train_loss=1.0643 acc=0.4326 | val_loss=1.0717 acc=0.4317 | time=18.7s
Epoch 010 | train_loss=1.0662 acc=0.4303 | val_loss=1.0735 acc=0.4317 | time=18.5s
Epoch 011 | train_loss=1.0234 acc=0.4835 | val_loss=1.0725 acc=0.5512 | time=18.6s
Epoch 012 | t

0,1
epoch,▁▁▂▂▂▂▃▃▃▄▄▄▄▅▅▅▅▆▆▆▇▇▇▇██
train_accuracy,▁▁▁▁▁▁▁▁▁▁▃▅▅▅▆▆▆▆▆▇▇▇▇▇██
train_loss,██████████▇▆▆▅▅▅▄▄▄▄▃▃▂▂▁▁
validation_accuracy,▁▁▁▁▁▁▁▁▁▁▆▅▇▇▇███▇██▇▇▇▇▆
validation_loss,███████████▄▃▁▂▁▁▂▁▁▃▄▇▄▇▆

0,1
epoch,26.0
train_accuracy,0.71495
train_loss,0.6408
validation_accuracy,0.56366
validation_loss,1.03248


[I 2025-05-02 04:26:18,665] Trial 2 finished with value: 0.9298576144945054 and parameters: {'lr': 0.0005, 'weight_decay': 0.0005, 'num_blocks': 1, 'num_heads': 3, 'num_segments': 5}. Best is trial 0 with value: 0.9058446770622617.



===== Trial 3 =====
 lr=1.00e-03, wd=5.00e-04, blocks=1, heads=3, segs=5
Epoch 001 | train_loss=1.0763 acc=0.4089 | val_loss=1.0811 acc=0.4317 | time=18.5s
Epoch 002 | train_loss=1.0717 acc=0.4237 | val_loss=1.0799 acc=0.4317 | time=18.4s
Epoch 003 | train_loss=1.0723 acc=0.4303 | val_loss=1.0741 acc=0.4317 | time=18.6s
Epoch 004 | train_loss=1.0614 acc=0.4501 | val_loss=1.0737 acc=0.4317 | time=18.5s
Epoch 005 | train_loss=0.9815 acc=0.5635 | val_loss=0.9694 acc=0.5699 | time=18.8s
Epoch 006 | train_loss=0.9453 acc=0.5864 | val_loss=0.9613 acc=0.5481 | time=18.7s
Epoch 007 | train_loss=0.9344 acc=0.5880 | val_loss=0.9573 acc=0.5745 | time=19.2s
Epoch 008 | train_loss=0.9264 acc=0.5876 | val_loss=0.9573 acc=0.5823 | time=18.8s
Epoch 009 | train_loss=0.9008 acc=0.6101 | val_loss=0.9391 acc=0.5683 | time=18.9s
Epoch 010 | train_loss=0.9062 acc=0.6016 | val_loss=0.9219 acc=0.5839 | time=18.8s
Epoch 011 | train_loss=0.8640 acc=0.6245 | val_loss=0.9408 acc=0.6040 | time=18.9s
Epoch 012 | t

0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
train_accuracy,▁▁▂▂▅▅▅▅▆▆▆▆▆▇▇▇▇███
train_loss,████▆▆▆▅▅▅▄▄▄▃▃▃▂▂▁▁
validation_accuracy,▁▁▁▁▇▆▇▇▇▇█▇▇▇▇▇█▇█▇
validation_loss,▆▆▆▆▃▂▂▂▂▁▂▁▃▃▁▃▃▆▃█

0,1
epoch,20.0
train_accuracy,0.70019
train_loss,0.6682
validation_accuracy,0.57919
validation_loss,1.12586


[I 2025-05-02 04:32:36,865] Trial 3 finished with value: 0.9219413456462678 and parameters: {'lr': 0.001, 'weight_decay': 0.0005, 'num_blocks': 1, 'num_heads': 3, 'num_segments': 5}. Best is trial 0 with value: 0.9058446770622617.



===== Trial 4 =====
 lr=5.00e-04, wd=1.00e-04, blocks=1, heads=2, segs=5
Epoch 001 | train_loss=1.0698 acc=0.4303 | val_loss=1.0751 acc=0.4317 | time=17.1s
Epoch 002 | train_loss=1.0701 acc=0.4307 | val_loss=1.0748 acc=0.4317 | time=17.4s
Epoch 003 | train_loss=1.0715 acc=0.4311 | val_loss=1.0756 acc=0.4317 | time=17.2s
Epoch 004 | train_loss=1.0709 acc=0.4318 | val_loss=1.0752 acc=0.4317 | time=17.3s
Epoch 005 | train_loss=1.0664 acc=0.4311 | val_loss=1.0785 acc=0.4317 | time=17.3s
Epoch 006 | train_loss=1.0692 acc=0.4311 | val_loss=1.0751 acc=0.4317 | time=17.1s
Epoch 007 | train_loss=1.0686 acc=0.4311 | val_loss=1.0743 acc=0.4317 | time=17.2s
Epoch 008 | train_loss=1.0689 acc=0.4311 | val_loss=1.0750 acc=0.4317 | time=16.9s
Epoch 009 | train_loss=1.0677 acc=0.4311 | val_loss=1.0745 acc=0.4317 | time=17.2s
Epoch 010 | train_loss=1.0680 acc=0.4311 | val_loss=1.0744 acc=0.4317 | time=17.0s
Epoch 011 | train_loss=1.0669 acc=0.4311 | val_loss=1.0755 acc=0.4317 | time=17.3s
Epoch 012 | t

0,1
epoch,▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇███
train_accuracy,▁▁▁▁▁▁▁▁▁▁▁▁▃▄▄▄▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇████
train_loss,████████████▇▆▆▆▅▅▅▄▄▄▄▄▃▃▃▃▃▃▂▂▁▁▁
validation_accuracy,▁▁▁▁▁▁▁▁▁▁▁▁▆▅▆▅▆▇▇▇███▇▇█▇█▇▇███▇█
validation_loss,███████████▇▅▅▃▃▃▂▁▁▂▁▂▂▁▂▃▁▃▂▂▄▆▅▆

0,1
epoch,35.0
train_accuracy,0.82563
train_loss,0.45747
validation_accuracy,0.62888
validation_loss,1.02684


[I 2025-05-02 04:42:41,069] Trial 4 finished with value: 0.8606603259132022 and parameters: {'lr': 0.0005, 'weight_decay': 0.0001, 'num_blocks': 1, 'num_heads': 2, 'num_segments': 5}. Best is trial 4 with value: 0.8606603259132022.



===== Trial 5 =====
 lr=1.00e-03, wd=1.00e-04, blocks=1, heads=2, segs=5
Epoch 001 | train_loss=1.0829 acc=0.4016 | val_loss=1.0754 acc=0.4317 | time=17.1s
Epoch 002 | train_loss=1.0687 acc=0.4214 | val_loss=1.0755 acc=0.4317 | time=17.5s
Epoch 003 | train_loss=1.0727 acc=0.4190 | val_loss=1.0746 acc=0.4317 | time=17.2s
Epoch 004 | train_loss=1.0703 acc=0.4318 | val_loss=1.0745 acc=0.4317 | time=16.9s


[W 2025-05-02 04:44:01,082] Trial 5 failed with parameters: {'lr': 0.001, 'weight_decay': 0.0001, 'num_blocks': 1, 'num_heads': 2, 'num_segments': 5} because of the following error: KeyboardInterrupt().
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/optuna/study/_optimize.py", line 197, in _run_trial
    value_or_values = func(trial)
                      ^^^^^^^^^^^
  File "<ipython-input-8-e29a196a779c>", line 140, in objective_holdout
    tloss    += loss.item()
                ^^^^^^^^^^^
KeyboardInterrupt
[W 2025-05-02 04:44:01,085] Trial 5 failed with value None.


KeyboardInterrupt: 

### Revised the model due to overfitting
- CNNDecoder dropout rate = 0.5
- Increase the Weight Decay

In [None]:
import os
import json
import time
import gc
import multiprocessing

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split
import optuna
import wandb

from eeg_dataset import EEGDataset
from models import EEGformer

# ─── Grid Search 후보값 ────────────────────────────────────────
LR_CHOICES          = [5e-4, 1e-3]
WD_CHOICES          = [5e-3, 5e-4]
NUM_FILTERS         = 120
NUM_BLOCK_CHOICES   = [1]
NUM_HEAD_CHOICES    = [1, 2, 3]
SEGMENT_CHOICES     = [5]

# ─── Training configuration ───────────────────────────────────
MAX_EPOCHS  = 100
PATIENCE    = 10
BATCH_SIZE  = 32
NUM_WORKERS = max(1, min(4, os.cpu_count() - 1))

# ─── Data paths & device ──────────────────────────────────────
DATA_DIR   = '/content/drive/MyDrive/2025_Lab_Research/model-data'
LABEL_FILE = "labels.json"
DEVICE     = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ─── Scheduler 하이퍼파라미터 ─────────────────────────────────
FIXED_STEP_SIZE = 5
FIXED_GAMMA     = 0.5


def objective_holdout(trial):
    # ─── 1) Grid용 sampling ───────────────────────────────────
    lr           = trial.suggest_categorical("lr", LR_CHOICES)
    weight_decay = trial.suggest_categorical("weight_decay", WD_CHOICES)
    num_blocks   = trial.suggest_categorical("num_blocks", NUM_BLOCK_CHOICES)
    num_heads    = trial.suggest_categorical("num_heads", NUM_HEAD_CHOICES)
    num_segments = trial.suggest_categorical("num_segments", SEGMENT_CHOICES)

    # ─── 2) 데이터 로드 ────────────────────────────────────────
    with open(os.path.join(DATA_DIR, LABEL_FILE), "r") as f:
        all_meta = json.load(f)
    train_meta = [d for d in all_meta if d["type"] == "train"]
    full_ds    = EEGDataset(DATA_DIR, train_meta)
    labels     = [d["label"] for d in train_meta]
    n_samples  = len(full_ds)
    input_length = full_ds[0][0].shape[-1]

    # ─── 3) Hold-out split ─────────────────────────────────────
    train_idx, val_idx = train_test_split(
        list(range(n_samples)),
        test_size=0.2,
        stratify=labels,
        random_state=42
    )
    train_loader = DataLoader(
        Subset(full_ds, train_idx),
        batch_size=BATCH_SIZE, shuffle=True,  num_workers=NUM_WORKERS
    )
    val_loader = DataLoader(
        Subset(full_ds, val_idx),
        batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
    )

    # ─── 4) W&B init ───────────────────────────────────────────
    wandb.init(
        project="eeg-holdout-grid-search-4",
        config={
            "lr": lr,
            "weight_decay": weight_decay,
            "num_blocks": num_blocks,
            "num_heads": num_heads,
            "num_segments": num_segments
        }
    )

    print(f"\n===== Trial {trial.number} =====")
    print(
        f" lr={lr:.2e}, wd={weight_decay:.2e}, "
        f"blocks={num_blocks}, heads={num_heads}, segs={num_segments}"
    )

    # ─── 5) Model / optimizer / loss ──────────────────────────
    model = EEGformer(
        in_channels  = 19,
        input_length = input_length,
        kernel_size  = 10,
        num_filters  = NUM_FILTERS,
        num_heads    = num_heads,
        num_blocks   = num_blocks,
        num_segments = num_segments,
        num_classes  = 3
    ).to(DEVICE)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=weight_decay
    )
    criterion = nn.CrossEntropyLoss()

    # ─── 6) Scheduler ──────────────────────────────────────────
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=FIXED_GAMMA,
        patience=FIXED_STEP_SIZE,
        min_lr=1e-6
    )

    # ─── 7) Training loop w/ Early Stopping & Pruning ─────────
    best_val_loss     = float("inf")
    epochs_no_improve = 0

    best_train_loss = best_train_acc = best_val_acc = None

    for epoch in range(1, MAX_EPOCHS + 1):
        t0 = time.time()

        # — train —
        model.train()
        tloss = tcorrect = ttotal = 0
        for X, y in train_loader:
            X, y = X.to(DEVICE), y.to(DEVICE)
            optimizer.zero_grad()
            logits = model(X)
            loss   = criterion(logits, y)
            loss.backward()
            optimizer.step()

            tloss    += loss.item()
            tcorrect += (logits.argmax(1) == y).sum().item()
            ttotal   += y.size(0)
        train_loss = tloss / len(train_loader)
        train_acc  = tcorrect / ttotal

        # — validate —
        model.eval()
        vloss = vcorrect = vtotal = 0
        with torch.no_grad():
            for X, y in val_loader:
                X, y    = X.to(DEVICE), y.to(DEVICE)
                logits  = model(X)
                loss    = criterion(logits, y)
                vloss   += loss.item()
                vcorrect+= (logits.argmax(1) == y).sum().item()
                vtotal  += y.size(0)
        val_loss = vloss / len(val_loader)
        val_acc  = vcorrect / vtotal
        elapsed  = time.time() - t0

        # — report & pruning check —
        trial.report(val_loss, epoch)
        if trial.should_prune():
            wandb.finish()
            print(f"▸ Trial {trial.number} pruned at epoch {epoch}")
            raise optuna.TrialPruned()

        # — print & log —
        print(
            f"Epoch {epoch:03d} | "
            f"train_loss={train_loss:.4f} acc={train_acc:.4f} | "
            f"val_loss={val_loss:.4f} acc={val_acc:.4f} | "
            f"time={elapsed:.1f}s"
        )
        wandb.log({
            "epoch":               epoch,
            "train_loss":          train_loss,
            "train_accuracy":      train_acc,
            "validation_loss":     val_loss,
            "validation_accuracy": val_acc,
        }, step=epoch)

        scheduler.step(val_loss)

        # — early stopping logic & save best metrics —
        if val_loss < best_val_loss:
            best_val_loss     = val_loss
            epochs_no_improve = 0
            best_train_loss   = train_loss
            best_train_acc    = train_acc
            best_val_acc      = val_acc
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= PATIENCE:
                print(f"★ Early stopping at epoch {epoch}")
                break

    # store best-epoch metrics
    trial.set_user_attr("best_train_loss", best_train_loss)
    trial.set_user_attr("best_train_acc",  best_train_acc)
    trial.set_user_attr("best_val_acc",    best_val_acc)

    wandb.finish()
    gc.collect()
    return best_val_loss


if __name__ == "__main__":
    multiprocessing.freeze_support()

    # ─── GridSampler용 파라미터 그리드 ─────────────────────────
    param_grid = {
        "lr":            LR_CHOICES,
        "weight_decay":  WD_CHOICES,
        "num_blocks":    NUM_BLOCK_CHOICES,
        "num_heads":     NUM_HEAD_CHOICES,
        "num_segments":  SEGMENT_CHOICES,
    }

    sampler = optuna.samplers.GridSampler(param_grid)
    study = optuna.create_study(
        direction="minimize",
        sampler=sampler,
        pruner=optuna.pruners.MedianPruner(n_warmup_steps=5),
        study_name="eeg_holdout_grid_search-4",
        storage="sqlite:////content/drive/MyDrive/2025_Lab_Research/eeg_grid_search-4.db",
        load_if_exists=True
    )
    study.optimize(objective_holdout)  # grid 크기만큼 자동 실행

    # ─── 결과 출력 ─────────────────────────────────────────────
    best = study.best_trial
    print("\n===== Best Trial =====")
    print(f"best_val_loss       = {best.value:.6f}")
    print(f"best_train_loss     = {best.user_attrs['best_train_loss']:.6f}")
    print(f"best_train_accuracy = {best.user_attrs['best_train_acc']:.4f}")
    print(f"best_val_accuracy   = {best.user_attrs['best_val_acc']:.4f}")
    print("best params:")
    for k, v in best.params.items():
        print(f"  {k}: {v}")


Now using CUDA device 0
Enabling CUDA with 39.14 GiB available memory


[I 2025-05-02 05:02:56,608] A new study created in RDB with name: eeg_holdout_grid_search-4



===== Trial 0 =====
 lr=1.00e-03, wd=5.00e-03, blocks=1, heads=1, segs=5
Epoch 001 | train_loss=1.0764 acc=0.4101 | val_loss=1.0752 acc=0.4317 | time=16.6s
Epoch 002 | train_loss=1.0723 acc=0.4179 | val_loss=1.0758 acc=0.4317 | time=16.8s
Epoch 003 | train_loss=1.0691 acc=0.4287 | val_loss=1.0744 acc=0.4317 | time=16.8s
Epoch 004 | train_loss=1.0667 acc=0.4322 | val_loss=1.0742 acc=0.4317 | time=16.5s
Epoch 005 | train_loss=1.0709 acc=0.4322 | val_loss=1.0745 acc=0.4317 | time=16.7s
Epoch 006 | train_loss=1.0680 acc=0.4311 | val_loss=1.0747 acc=0.4317 | time=16.9s
Epoch 007 | train_loss=1.0678 acc=0.4307 | val_loss=1.0746 acc=0.4317 | time=16.8s
Epoch 008 | train_loss=1.0664 acc=0.4303 | val_loss=1.0743 acc=0.4317 | time=16.5s
Epoch 009 | train_loss=1.0674 acc=0.4311 | val_loss=1.0750 acc=0.4317 | time=16.7s
Epoch 010 | train_loss=1.0671 acc=0.4311 | val_loss=1.0716 acc=0.4317 | time=16.4s
Epoch 011 | train_loss=1.0626 acc=0.4322 | val_loss=1.0571 acc=0.4317 | time=16.7s
Epoch 012 | t

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
train_accuracy,▁▁▁▁▁▁▁▁▁▁▁▂▂▃▄▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇█▇████
train_loss,███████████▇▇▇▆▅▅▄▄▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁
validation_accuracy,▁▁▁▁▁▁▁▁▁▁▂▂▂▅▆▆▆▆▆▇▇▆▆█▇▇▇██▇▇██▇██████
validation_loss,██████████▇▇▇▆▅▄▄▃▄▃▃▂▃▂▂▂▃▁▂▂▁▁▁▂▂▁▁▂▁▂

0,1
epoch,45.0
train_accuracy,0.79883
train_loss,0.48234
validation_accuracy,0.72205
validation_loss,0.68842


[I 2025-05-02 05:15:34,806] Trial 0 finished with value: 0.6386472043536958 and parameters: {'lr': 0.001, 'weight_decay': 0.005, 'num_blocks': 1, 'num_heads': 1, 'num_segments': 5}. Best is trial 0 with value: 0.6386472043536958.



===== Trial 1 =====
 lr=1.00e-03, wd=5.00e-04, blocks=1, heads=3, segs=5
Epoch 001 | train_loss=1.0796 acc=0.4132 | val_loss=1.0757 acc=0.4317 | time=18.5s
Epoch 002 | train_loss=1.0757 acc=0.4140 | val_loss=1.0742 acc=0.4317 | time=18.6s
Epoch 003 | train_loss=1.0726 acc=0.4194 | val_loss=1.0750 acc=0.4317 | time=18.5s
Epoch 004 | train_loss=1.0690 acc=0.4210 | val_loss=1.0739 acc=0.4317 | time=18.9s
Epoch 005 | train_loss=1.0596 acc=0.4501 | val_loss=1.0207 acc=0.5668 | time=18.7s
Epoch 006 | train_loss=1.0027 acc=0.5468 | val_loss=1.0058 acc=0.5621 | time=18.7s
Epoch 007 | train_loss=0.9713 acc=0.5728 | val_loss=0.9684 acc=0.5264 | time=18.6s
Epoch 008 | train_loss=0.9407 acc=0.5856 | val_loss=0.9816 acc=0.5792 | time=18.6s
Epoch 009 | train_loss=0.9147 acc=0.6031 | val_loss=0.9445 acc=0.5870 | time=18.7s
Epoch 010 | train_loss=0.8854 acc=0.6136 | val_loss=0.9190 acc=0.5839 | time=18.7s
Epoch 011 | train_loss=0.8837 acc=0.6113 | val_loss=0.9402 acc=0.5839 | time=18.7s
Epoch 012 | t

[W 2025-05-02 05:19:21,532] Trial 1 failed with parameters: {'lr': 0.001, 'weight_decay': 0.0005, 'num_blocks': 1, 'num_heads': 3, 'num_segments': 5} because of the following error: KeyboardInterrupt().
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/optuna/study/_optimize.py", line 197, in _run_trial
    value_or_values = func(trial)
                      ^^^^^^^^^^^
  File "<ipython-input-8-b80845e0b137>", line 140, in objective_holdout
    tloss    += loss.item()
                ^^^^^^^^^^^
KeyboardInterrupt
[W 2025-05-02 05:19:21,535] Trial 1 failed with value None.


KeyboardInterrupt: 

In [None]:
import os
import json
import time
import gc
import multiprocessing

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split
import optuna
import wandb

from eeg_dataset import EEGDataset
from models import EEGformer

# ─── Grid Search 후보값 ────────────────────────────────────────
LR_CHOICES          = [5e-4, 1e-3]
WD_CHOICES          = [5e-3, 5e-4]
NUM_FILTERS         = 120
NUM_BLOCK_CHOICES   = [1]
NUM_HEAD_CHOICES    = [1, 2, 3]
SEGMENT_CHOICES     = [5]

# ─── Training configuration ───────────────────────────────────
MAX_EPOCHS  = 100
PATIENCE    = 10
BATCH_SIZE  = 32
NUM_WORKERS = max(1, min(4, os.cpu_count() - 1))

# ─── Data paths & device ──────────────────────────────────────
DATA_DIR   = '/content/drive/MyDrive/2025_Lab_Research/model-data'
LABEL_FILE = "labels.json"
DEVICE     = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ─── Scheduler 하이퍼파라미터 ─────────────────────────────────
FIXED_STEP_SIZE = 5
FIXED_GAMMA     = 0.5


def objective_holdout(trial):
    # ─── 1) Grid용 sampling ───────────────────────────────────
    lr           = trial.suggest_categorical("lr", LR_CHOICES)
    weight_decay = trial.suggest_categorical("weight_decay", WD_CHOICES)
    num_blocks   = trial.suggest_categorical("num_blocks", NUM_BLOCK_CHOICES)
    num_heads    = trial.suggest_categorical("num_heads", NUM_HEAD_CHOICES)
    num_segments = trial.suggest_categorical("num_segments", SEGMENT_CHOICES)

    # ─── 2) 데이터 로드 ────────────────────────────────────────
    with open(os.path.join(DATA_DIR, LABEL_FILE), "r") as f:
        all_meta = json.load(f)
    train_meta = [d for d in all_meta if d["type"] == "train"]
    full_ds    = EEGDataset(DATA_DIR, train_meta)
    labels     = [d["label"] for d in train_meta]
    n_samples  = len(full_ds)
    input_length = full_ds[0][0].shape[-1]

    # ─── 3) Hold-out split ─────────────────────────────────────
    train_idx, val_idx = train_test_split(
        list(range(n_samples)),
        test_size=0.2,
        stratify=labels,
        random_state=42
    )
    train_loader = DataLoader(
        Subset(full_ds, train_idx),
        batch_size=BATCH_SIZE, shuffle=True,  num_workers=NUM_WORKERS
    )
    val_loader = DataLoader(
        Subset(full_ds, val_idx),
        batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
    )

    # ─── 4) W&B init ───────────────────────────────────────────
    wandb.init(
        project="eeg-holdout-grid-search-4",
        config={
            "lr": lr,
            "weight_decay": weight_decay,
            "num_blocks": num_blocks,
            "num_heads": num_heads,
            "num_segments": num_segments
        }
    )

    print(f"\n===== Trial {trial.number} =====")
    print(
        f" lr={lr:.2e}, wd={weight_decay:.2e}, "
        f"blocks={num_blocks}, heads={num_heads}, segs={num_segments}"
    )

    # ─── 5) Model / optimizer / loss ──────────────────────────
    model = EEGformer(
        in_channels  = 19,
        input_length = input_length,
        kernel_size  = 10,
        num_filters  = NUM_FILTERS,
        num_heads    = num_heads,
        num_blocks   = num_blocks,
        num_segments = num_segments,
        num_classes  = 3
    ).to(DEVICE)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=weight_decay
    )
    criterion = nn.CrossEntropyLoss()

    # ─── 6) Scheduler ──────────────────────────────────────────
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=FIXED_GAMMA,
        patience=FIXED_STEP_SIZE,
        min_lr=1e-6
    )

    # ─── 7) Training loop w/ Early Stopping & Pruning ─────────
    best_val_loss     = float("inf")
    epochs_no_improve = 0

    best_train_loss = best_train_acc = best_val_acc = None

    for epoch in range(1, MAX_EPOCHS + 1):
        t0 = time.time()

        # — train —
        model.train()
        tloss = tcorrect = ttotal = 0
        for X, y in train_loader:
            X, y = X.to(DEVICE), y.to(DEVICE)
            optimizer.zero_grad()
            logits = model(X)
            loss   = criterion(logits, y)
            loss.backward()
            optimizer.step()

            tloss    += loss.item()
            tcorrect += (logits.argmax(1) == y).sum().item()
            ttotal   += y.size(0)
        train_loss = tloss / len(train_loader)
        train_acc  = tcorrect / ttotal

        # — validate —
        model.eval()
        vloss = vcorrect = vtotal = 0
        with torch.no_grad():
            for X, y in val_loader:
                X, y    = X.to(DEVICE), y.to(DEVICE)
                logits  = model(X)
                loss    = criterion(logits, y)
                vloss   += loss.item()
                vcorrect+= (logits.argmax(1) == y).sum().item()
                vtotal  += y.size(0)
        val_loss = vloss / len(val_loader)
        val_acc  = vcorrect / vtotal
        elapsed  = time.time() - t0

        # — report & pruning check —
        trial.report(val_loss, epoch)
        if trial.should_prune():
            wandb.finish()
            print(f"▸ Trial {trial.number} pruned at epoch {epoch}")
            raise optuna.TrialPruned()

        # — print & log —
        print(
            f"Epoch {epoch:03d} | "
            f"train_loss={train_loss:.4f} acc={train_acc:.4f} | "
            f"val_loss={val_loss:.4f} acc={val_acc:.4f} | "
            f"time={elapsed:.1f}s"
        )
        wandb.log({
            "epoch":               epoch,
            "train_loss":          train_loss,
            "train_accuracy":      train_acc,
            "validation_loss":     val_loss,
            "validation_accuracy": val_acc,
        }, step=epoch)

        scheduler.step(val_loss)

        # — early stopping logic & save best metrics —
        if val_loss < best_val_loss:
            best_val_loss     = val_loss
            epochs_no_improve = 0
            best_train_loss   = train_loss
            best_train_acc    = train_acc
            best_val_acc      = val_acc
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= PATIENCE:
                print(f"★ Early stopping at epoch {epoch}")
                break

    # store best-epoch metrics
    trial.set_user_attr("best_train_loss", best_train_loss)
    trial.set_user_attr("best_train_acc",  best_train_acc)
    trial.set_user_attr("best_val_acc",    best_val_acc)

    wandb.finish()
    gc.collect()
    return best_val_loss


if __name__ == "__main__":
    multiprocessing.freeze_support()

    # ─── GridSampler용 파라미터 그리드 ─────────────────────────
    param_grid = {
        "lr":            LR_CHOICES,
        "weight_decay":  WD_CHOICES,
        "num_blocks":    NUM_BLOCK_CHOICES,
        "num_heads":     NUM_HEAD_CHOICES,
        "num_segments":  SEGMENT_CHOICES,
    }

    sampler = optuna.samplers.GridSampler(param_grid)
    study = optuna.create_study(
        direction="minimize",
        sampler=sampler,
        pruner=optuna.pruners.MedianPruner(n_warmup_steps=5),
        study_name="eeg_holdout_grid_search-4",
        storage="sqlite:////content/drive/MyDrive/2025_Lab_Research/eeg_grid_search-4.db",
        load_if_exists=True
    )
    study.optimize(objective_holdout)  # grid 크기만큼 자동 실행

    # ─── 결과 출력 ─────────────────────────────────────────────
    best = study.best_trial
    print("\n===== Best Trial =====")
    print(f"best_val_loss       = {best.value:.6f}")
    print(f"best_train_loss     = {best.user_attrs['best_train_loss']:.6f}")
    print(f"best_train_accuracy = {best.user_attrs['best_train_acc']:.4f}")
    print(f"best_val_accuracy   = {best.user_attrs['best_val_acc']:.4f}")
    print("best params:")
    for k, v in best.params.items():
        print(f"  {k}: {v}")


### Stronger Weight Decay
- Block = 1, Head = 1


In [None]:
import os
import json
import time
import gc
import multiprocessing

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split
import optuna
import wandb

from eeg_dataset import EEGDataset
from models import EEGformer

# ─── Grid Search 후보값 ────────────────────────────────────────
LR_CHOICES          = [1e-3]
WD_CHOICES          = [1e-2]
NUM_FILTERS         = 120
NUM_BLOCK_CHOICES   = [1]
NUM_HEAD_CHOICES    = [1]
SEGMENT_CHOICES     = [5]

# ─── Training configuration ───────────────────────────────────
MAX_EPOCHS  = 100
PATIENCE    = 10
BATCH_SIZE  = 32
NUM_WORKERS = max(1, min(4, os.cpu_count() - 1))

# ─── Data paths & device ──────────────────────────────────────
DATA_DIR   = '/content/drive/MyDrive/2025_Lab_Research/model-data'
LABEL_FILE = "labels.json"
DEVICE     = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ─── Scheduler 하이퍼파라미터 ─────────────────────────────────
FIXED_STEP_SIZE = 5
FIXED_GAMMA     = 0.5


def objective_holdout(trial):
    # ─── 1) Grid용 sampling ───────────────────────────────────
    lr           = trial.suggest_categorical("lr", LR_CHOICES)
    weight_decay = trial.suggest_categorical("weight_decay", WD_CHOICES)
    num_blocks   = trial.suggest_categorical("num_blocks", NUM_BLOCK_CHOICES)
    num_heads    = trial.suggest_categorical("num_heads", NUM_HEAD_CHOICES)
    num_segments = trial.suggest_categorical("num_segments", SEGMENT_CHOICES)

    # ─── 2) 데이터 로드 ────────────────────────────────────────
    with open(os.path.join(DATA_DIR, LABEL_FILE), "r") as f:
        all_meta = json.load(f)
    train_meta = [d for d in all_meta if d["type"] == "train"]
    full_ds    = EEGDataset(DATA_DIR, train_meta)
    labels     = [d["label"] for d in train_meta]
    n_samples  = len(full_ds)
    input_length = full_ds[0][0].shape[-1]

    # ─── 3) Hold-out split ─────────────────────────────────────
    train_idx, val_idx = train_test_split(
        list(range(n_samples)),
        test_size=0.2,
        stratify=labels,
        random_state=42
    )
    train_loader = DataLoader(
        Subset(full_ds, train_idx),
        batch_size=BATCH_SIZE, shuffle=True,  num_workers=NUM_WORKERS
    )
    val_loader = DataLoader(
        Subset(full_ds, val_idx),
        batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
    )

    # ─── 4) W&B init ───────────────────────────────────────────
    wandb.init(
        project="eeg-holdout-grid-search-5",
        config={
            "lr": lr,
            "weight_decay": weight_decay,
            "num_blocks": num_blocks,
            "num_heads": num_heads,
            "num_segments": num_segments
        }
    )

    print(f"\n===== Trial {trial.number} =====")
    print(
        f" lr={lr:.2e}, wd={weight_decay:.2e}, "
        f"blocks={num_blocks}, heads={num_heads}, segs={num_segments}"
    )

    # ─── 5) Model / optimizer / loss ──────────────────────────
    model = EEGformer(
        in_channels  = 19,
        input_length = input_length,
        kernel_size  = 10,
        num_filters  = NUM_FILTERS,
        num_heads    = num_heads,
        num_blocks   = num_blocks,
        num_segments = num_segments,
        num_classes  = 3
    ).to(DEVICE)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=weight_decay
    )
    criterion = nn.CrossEntropyLoss()

    # ─── 6) Scheduler ──────────────────────────────────────────
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=FIXED_GAMMA,
        patience=FIXED_STEP_SIZE,
        min_lr=1e-6
    )

    # ─── 7) Training loop w/ Early Stopping & Pruning ─────────
    best_val_loss     = float("inf")
    epochs_no_improve = 0

    best_train_loss = best_train_acc = best_val_acc = None

    for epoch in range(1, MAX_EPOCHS + 1):
        t0 = time.time()

        # — train —
        model.train()
        tloss = tcorrect = ttotal = 0
        for X, y in train_loader:
            X, y = X.to(DEVICE), y.to(DEVICE)
            optimizer.zero_grad()
            logits = model(X)
            loss   = criterion(logits, y)
            loss.backward()
            optimizer.step()

            tloss    += loss.item()
            tcorrect += (logits.argmax(1) == y).sum().item()
            ttotal   += y.size(0)
        train_loss = tloss / len(train_loader)
        train_acc  = tcorrect / ttotal

        # — validate —
        model.eval()
        vloss = vcorrect = vtotal = 0
        with torch.no_grad():
            for X, y in val_loader:
                X, y    = X.to(DEVICE), y.to(DEVICE)
                logits  = model(X)
                loss    = criterion(logits, y)
                vloss   += loss.item()
                vcorrect+= (logits.argmax(1) == y).sum().item()
                vtotal  += y.size(0)
        val_loss = vloss / len(val_loader)
        val_acc  = vcorrect / vtotal
        elapsed  = time.time() - t0

        # — report & pruning check —
        trial.report(val_loss, epoch)
        if trial.should_prune():
            wandb.finish()
            print(f"▸ Trial {trial.number} pruned at epoch {epoch}")
            raise optuna.TrialPruned()

        # — print & log —
        print(
            f"Epoch {epoch:03d} | "
            f"train_loss={train_loss:.4f} acc={train_acc:.4f} | "
            f"val_loss={val_loss:.4f} acc={val_acc:.4f} | "
            f"time={elapsed:.1f}s"
        )
        wandb.log({
            "epoch":               epoch,
            "train_loss":          train_loss,
            "train_accuracy":      train_acc,
            "validation_loss":     val_loss,
            "validation_accuracy": val_acc,
        }, step=epoch)

        scheduler.step(val_loss)

        # — early stopping logic & save best metrics —
        if val_loss < best_val_loss:
            best_val_loss     = val_loss
            epochs_no_improve = 0
            best_train_loss   = train_loss
            best_train_acc    = train_acc
            best_val_acc      = val_acc
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= PATIENCE:
                print(f"★ Early stopping at epoch {epoch}")
                break

    # store best-epoch metrics
    trial.set_user_attr("best_train_loss", best_train_loss)
    trial.set_user_attr("best_train_acc",  best_train_acc)
    trial.set_user_attr("best_val_acc",    best_val_acc)

    wandb.finish()
    gc.collect()
    return best_val_loss


if __name__ == "__main__":
    multiprocessing.freeze_support()

    # ─── GridSampler용 파라미터 그리드 ─────────────────────────
    param_grid = {
        "lr":            LR_CHOICES,
        "weight_decay":  WD_CHOICES,
        "num_blocks":    NUM_BLOCK_CHOICES,
        "num_heads":     NUM_HEAD_CHOICES,
        "num_segments":  SEGMENT_CHOICES,
    }

    sampler = optuna.samplers.GridSampler(param_grid)
    study = optuna.create_study(
        direction="minimize",
        sampler=sampler,
        pruner=optuna.pruners.MedianPruner(n_warmup_steps=5),
        study_name="eeg_holdout_grid_search-5",
        storage="sqlite:////content/drive/MyDrive/2025_Lab_Research/eeg_grid_search-5.db",
        load_if_exists=True
    )
    study.optimize(objective_holdout)  # grid 크기만큼 자동 실행

    # ─── 결과 출력 ─────────────────────────────────────────────
    best = study.best_trial
    print("\n===== Best Trial =====")
    print(f"best_val_loss       = {best.value:.6f}")
    print(f"best_train_loss     = {best.user_attrs['best_train_loss']:.6f}")
    print(f"best_train_accuracy = {best.user_attrs['best_train_acc']:.4f}")
    print(f"best_val_accuracy   = {best.user_attrs['best_val_acc']:.4f}")
    print("best params:")
    for k, v in best.params.items():
        print(f"  {k}: {v}")


Now using CUDA device 0
Enabling CUDA with 39.14 GiB available memory


[I 2025-05-02 05:23:41,610] A new study created in RDB with name: eeg_holdout_grid_search-5



===== Trial 0 =====
 lr=1.00e-03, wd=1.00e-02, blocks=1, heads=1, segs=5
Epoch 001 | train_loss=1.0783 acc=0.4070 | val_loss=1.0755 acc=0.4317 | time=16.7s
Epoch 002 | train_loss=1.0727 acc=0.4190 | val_loss=1.0761 acc=0.4317 | time=16.6s
Epoch 003 | train_loss=1.0711 acc=0.4249 | val_loss=1.0742 acc=0.4317 | time=16.5s
Epoch 004 | train_loss=1.0682 acc=0.4338 | val_loss=1.0744 acc=0.4317 | time=16.5s
Epoch 005 | train_loss=1.0714 acc=0.4307 | val_loss=1.0764 acc=0.4317 | time=16.6s
Epoch 006 | train_loss=1.0672 acc=0.4307 | val_loss=1.0705 acc=0.4317 | time=16.7s
Epoch 007 | train_loss=1.0533 acc=0.4598 | val_loss=1.0750 acc=0.4317 | time=17.0s
Epoch 008 | train_loss=1.0687 acc=0.4280 | val_loss=1.0671 acc=0.4317 | time=16.6s
Epoch 009 | train_loss=1.0299 acc=0.4955 | val_loss=1.0729 acc=0.3416 | time=16.4s
Epoch 010 | train_loss=1.0026 acc=0.5472 | val_loss=1.0091 acc=0.5435 | time=16.6s
Epoch 011 | train_loss=0.9708 acc=0.5693 | val_loss=0.9890 acc=0.5512 | time=16.8s
Epoch 012 | t

0,1
epoch,▁▁▁▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇███
train_accuracy,▁▁▁▁▁▁▂▁▂▃▃▃▄▄▄▄▄▅▄▅▅▅▆▆▆▅▆▆▇▇███
train_loss,█████████▇▇▇▇▇▆▆▆▆▆▅▅▄▄▃▄▅▄▄▃▂▂▁▁
validation_accuracy,▃▃▃▃▃▃▃▃▁▆▆▅▆▆▆▇▆▇▇▇▆▇██▇▆▇█▇█▇▇▇
validation_loss,▃▃▃▃▃▃▃▃▃▃▂▂▂▃▂▂▂▂▂▂▃▂▁▁▄▃▃▂▄▅▆▇█

0,1
epoch,33.0
train_accuracy,0.89553
train_loss,0.27233
validation_accuracy,0.61801
validation_loss,1.52626


[I 2025-05-02 05:32:55,611] Trial 0 finished with value: 0.851824280761537 and parameters: {'lr': 0.001, 'weight_decay': 0.01, 'num_blocks': 1, 'num_heads': 1, 'num_segments': 5}. Best is trial 0 with value: 0.851824280761537.



===== Best Trial =====
best_val_loss       = 0.851824
best_train_loss     = 0.582561
best_train_accuracy = 0.7417
best_val_accuracy   = 0.6537
best params:
  lr: 0.001
  weight_decay: 0.01
  num_blocks: 1
  num_heads: 1
  num_segments: 5


### Even though setting stronger weight decay, the model overfitting
- This is because I didn't change the learning rate according to the weight decay. The learning rate should be addressed to decreased as the weight decay set high

In [None]:
import os
import json
import time
import gc
import multiprocessing

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split
import optuna
import wandb

from eeg_dataset import EEGDataset
from models import EEGformer

# ─── Grid Search 후보값 ────────────────────────────────────────
LR_CHOICES          = [5e-4]
WD_CHOICES          = [1e-2]
NUM_FILTERS         = 120
NUM_BLOCK_CHOICES   = [1]
NUM_HEAD_CHOICES    = [1]
SEGMENT_CHOICES     = [5]

# ─── Training configuration ───────────────────────────────────
MAX_EPOCHS  = 100
PATIENCE    = 10
BATCH_SIZE  = 32
NUM_WORKERS = max(1, min(4, os.cpu_count() - 1))

# ─── Data paths & device ──────────────────────────────────────
DATA_DIR   = '/content/drive/MyDrive/2025_Lab_Research/model-data'
LABEL_FILE = "labels.json"
DEVICE     = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ─── Scheduler 하이퍼파라미터 ─────────────────────────────────
FIXED_STEP_SIZE = 5
FIXED_GAMMA     = 0.5


def objective_holdout(trial):
    # ─── 1) Grid용 sampling ───────────────────────────────────
    lr           = trial.suggest_categorical("lr", LR_CHOICES)
    weight_decay = trial.suggest_categorical("weight_decay", WD_CHOICES)
    num_blocks   = trial.suggest_categorical("num_blocks", NUM_BLOCK_CHOICES)
    num_heads    = trial.suggest_categorical("num_heads", NUM_HEAD_CHOICES)
    num_segments = trial.suggest_categorical("num_segments", SEGMENT_CHOICES)

    # ─── 2) 데이터 로드 ────────────────────────────────────────
    with open(os.path.join(DATA_DIR, LABEL_FILE), "r") as f:
        all_meta = json.load(f)
    train_meta = [d for d in all_meta if d["type"] == "train"]
    full_ds    = EEGDataset(DATA_DIR, train_meta)
    labels     = [d["label"] for d in train_meta]
    n_samples  = len(full_ds)
    input_length = full_ds[0][0].shape[-1]

    # ─── 3) Hold-out split ─────────────────────────────────────
    train_idx, val_idx = train_test_split(
        list(range(n_samples)),
        test_size=0.2,
        stratify=labels,
        random_state=42
    )
    train_loader = DataLoader(
        Subset(full_ds, train_idx),
        batch_size=BATCH_SIZE, shuffle=True,  num_workers=NUM_WORKERS
    )
    val_loader = DataLoader(
        Subset(full_ds, val_idx),
        batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
    )

    # ─── 4) W&B init ───────────────────────────────────────────
    wandb.init(
        project="eeg-holdout-grid-search-6",
        config={
            "lr": lr,
            "weight_decay": weight_decay,
            "num_blocks": num_blocks,
            "num_heads": num_heads,
            "num_segments": num_segments
        }
    )

    print(f"\n===== Trial {trial.number} =====")
    print(
        f" lr={lr:.2e}, wd={weight_decay:.2e}, "
        f"blocks={num_blocks}, heads={num_heads}, segs={num_segments}"
    )

    # ─── 5) Model / optimizer / loss ──────────────────────────
    model = EEGformer(
        in_channels  = 19,
        input_length = input_length,
        kernel_size  = 10,
        num_filters  = NUM_FILTERS,
        num_heads    = num_heads,
        num_blocks   = num_blocks,
        num_segments = num_segments,
        num_classes  = 3
    ).to(DEVICE)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=weight_decay
    )
    criterion = nn.CrossEntropyLoss()

    # ─── 6) Scheduler ──────────────────────────────────────────
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=FIXED_GAMMA,
        patience=FIXED_STEP_SIZE,
        min_lr=1e-6
    )

    # ─── 7) Training loop w/ Early Stopping & Pruning ─────────
    best_val_loss     = float("inf")
    epochs_no_improve = 0

    best_train_loss = best_train_acc = best_val_acc = None

    for epoch in range(1, MAX_EPOCHS + 1):
        t0 = time.time()

        # — train —
        model.train()
        tloss = tcorrect = ttotal = 0
        for X, y in train_loader:
            X, y = X.to(DEVICE), y.to(DEVICE)
            optimizer.zero_grad()
            logits = model(X)
            loss   = criterion(logits, y)
            loss.backward()
            optimizer.step()

            tloss    += loss.item()
            tcorrect += (logits.argmax(1) == y).sum().item()
            ttotal   += y.size(0)
        train_loss = tloss / len(train_loader)
        train_acc  = tcorrect / ttotal

        # — validate —
        model.eval()
        vloss = vcorrect = vtotal = 0
        with torch.no_grad():
            for X, y in val_loader:
                X, y    = X.to(DEVICE), y.to(DEVICE)
                logits  = model(X)
                loss    = criterion(logits, y)
                vloss   += loss.item()
                vcorrect+= (logits.argmax(1) == y).sum().item()
                vtotal  += y.size(0)
        val_loss = vloss / len(val_loader)
        val_acc  = vcorrect / vtotal
        elapsed  = time.time() - t0

        # — report & pruning check —
        trial.report(val_loss, epoch)
        if trial.should_prune():
            wandb.finish()
            print(f"▸ Trial {trial.number} pruned at epoch {epoch}")
            raise optuna.TrialPruned()

        # — print & log —
        print(
            f"Epoch {epoch:03d} | "
            f"train_loss={train_loss:.4f} acc={train_acc:.4f} | "
            f"val_loss={val_loss:.4f} acc={val_acc:.4f} | "
            f"time={elapsed:.1f}s"
        )
        wandb.log({
            "epoch":               epoch,
            "train_loss":          train_loss,
            "train_accuracy":      train_acc,
            "validation_loss":     val_loss,
            "validation_accuracy": val_acc,
        }, step=epoch)

        scheduler.step(val_loss)

        # — early stopping logic & save best metrics —
        if val_loss < best_val_loss:
            best_val_loss     = val_loss
            epochs_no_improve = 0
            best_train_loss   = train_loss
            best_train_acc    = train_acc
            best_val_acc      = val_acc
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= PATIENCE:
                print(f"★ Early stopping at epoch {epoch}")
                break

    # store best-epoch metrics
    trial.set_user_attr("best_train_loss", best_train_loss)
    trial.set_user_attr("best_train_acc",  best_train_acc)
    trial.set_user_attr("best_val_acc",    best_val_acc)

    wandb.finish()
    gc.collect()
    return best_val_loss


if __name__ == "__main__":
    multiprocessing.freeze_support()

    # ─── GridSampler용 파라미터 그리드 ─────────────────────────
    param_grid = {
        "lr":            LR_CHOICES,
        "weight_decay":  WD_CHOICES,
        "num_blocks":    NUM_BLOCK_CHOICES,
        "num_heads":     NUM_HEAD_CHOICES,
        "num_segments":  SEGMENT_CHOICES,
    }

    sampler = optuna.samplers.GridSampler(param_grid)
    study = optuna.create_study(
        direction="minimize",
        sampler=sampler,
        pruner=optuna.pruners.MedianPruner(n_warmup_steps=5),
        study_name="eeg_holdout_grid_search-6",
        storage="sqlite:////content/drive/MyDrive/2025_Lab_Research/eeg_grid_search-6.db",
        load_if_exists=True
    )
    study.optimize(objective_holdout)  # grid 크기만큼 자동 실행

    # ─── 결과 출력 ─────────────────────────────────────────────
    best = study.best_trial
    print("\n===== Best Trial =====")
    print(f"best_val_loss       = {best.value:.6f}")
    print(f"best_train_loss     = {best.user_attrs['best_train_loss']:.6f}")
    print(f"best_train_accuracy = {best.user_attrs['best_train_acc']:.4f}")
    print(f"best_val_accuracy   = {best.user_attrs['best_val_acc']:.4f}")
    print("best params:")
    for k, v in best.params.items():
        print(f"  {k}: {v}")


[I 2025-05-02 05:43:13,744] A new study created in RDB with name: eeg_holdout_grid_search-6



===== Trial 0 =====
 lr=5.00e-04, wd=1.00e-02, blocks=1, heads=1, segs=5
Epoch 001 | train_loss=1.0805 acc=0.3907 | val_loss=1.0744 acc=0.4317 | time=16.8s
Epoch 002 | train_loss=1.0749 acc=0.4117 | val_loss=1.0745 acc=0.4317 | time=16.5s
Epoch 003 | train_loss=1.0699 acc=0.4252 | val_loss=1.0763 acc=0.4317 | time=16.9s
Epoch 004 | train_loss=1.0696 acc=0.4210 | val_loss=1.0765 acc=0.4317 | time=16.8s
Epoch 005 | train_loss=1.0736 acc=0.4307 | val_loss=1.0746 acc=0.4317 | time=16.6s
Epoch 006 | train_loss=1.0723 acc=0.4183 | val_loss=1.0760 acc=0.4317 | time=16.7s
Epoch 007 | train_loss=1.0714 acc=0.4315 | val_loss=1.0738 acc=0.4317 | time=16.6s
Epoch 008 | train_loss=1.0642 acc=0.4384 | val_loss=1.0558 acc=0.4317 | time=16.7s
Epoch 009 | train_loss=1.0025 acc=0.5483 | val_loss=0.9978 acc=0.5062 | time=16.7s
Epoch 010 | train_loss=0.9708 acc=0.5771 | val_loss=0.9783 acc=0.5233 | time=16.4s
Epoch 011 | train_loss=0.9411 acc=0.5899 | val_loss=0.9799 acc=0.5730 | time=16.8s
Epoch 012 | t

0,1
epoch,▁▁▁▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇███
train_accuracy,▁▁▁▁▂▁▂▂▃▃▄▄▄▄▄▄▅▅▅▆▆▆▆▇▇▇▇▇████
train_loss,████████▇▇▇▇▇▆▆▆▅▅▅▄▄▄▃▃▃▂▂▂▂▁▁▁
validation_accuracy,▁▁▁▁▁▁▁▁▃▃▅▅▅▅▅▆▇▇▆▇▇████▇█▇████
validation_loss,▇▇▇▇▇▇▇▇▆▅▅▅▅▄▄▄▃▃▅▃▁▁▄▃▃▂▄▅▅▆▇█

0,1
epoch,32.0
train_accuracy,0.92427
train_loss,0.22318
validation_accuracy,0.71118
validation_loss,1.12488


[I 2025-05-02 05:52:09,596] Trial 0 finished with value: 0.7523488317217145 and parameters: {'lr': 0.0005, 'weight_decay': 0.01, 'num_blocks': 1, 'num_heads': 1, 'num_segments': 5}. Best is trial 0 with value: 0.7523488317217145.



===== Best Trial =====
best_val_loss       = 0.752349
best_train_loss     = 0.569233
best_train_accuracy = 0.7717
best_val_accuracy   = 0.6988
best params:
  lr: 0.0005
  weight_decay: 0.01
  num_blocks: 1
  num_heads: 1
  num_segments: 5


### Add drop out in transformer model
- After Depthwise Conv(ODCM)
- After TransformerBlock Attention output (before residual connection)
- After Token Embedding (RTM/STM/TTM)
- After CNNDecoder Conv layers


In [None]:
import os
import json
import time
import gc
import multiprocessing

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split
import optuna
import wandb

from eeg_dataset import EEGDataset
from models import EEGformer

# ─── Grid Search 후보값 ────────────────────────────────────────
LR_CHOICES          = [5e-4]
WD_CHOICES          = [5e-3, 5e-4, 5e-5]
NUM_FILTERS         = 120
NUM_BLOCK_CHOICES   = [1]
NUM_HEAD_CHOICES    = [1]
SEGMENT_CHOICES     = [5]

# ─── Training configuration ───────────────────────────────────
MAX_EPOCHS  = 100
PATIENCE    = 10
BATCH_SIZE  = 32
NUM_WORKERS = max(1, min(4, os.cpu_count() - 1))

# ─── Data paths & device ──────────────────────────────────────
DATA_DIR   = '/content/drive/MyDrive/2025_Lab_Research/model-data'
LABEL_FILE = "labels.json"
DEVICE     = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ─── Scheduler 하이퍼파라미터 ─────────────────────────────────
FIXED_STEP_SIZE = 5
FIXED_GAMMA     = 0.5


def objective_holdout(trial):
    # ─── 1) Grid용 sampling ───────────────────────────────────
    lr           = trial.suggest_categorical("lr", LR_CHOICES)
    weight_decay = trial.suggest_categorical("weight_decay", WD_CHOICES)
    num_blocks   = trial.suggest_categorical("num_blocks", NUM_BLOCK_CHOICES)
    num_heads    = trial.suggest_categorical("num_heads", NUM_HEAD_CHOICES)
    num_segments = trial.suggest_categorical("num_segments", SEGMENT_CHOICES)

    # ─── 2) 데이터 로드 ────────────────────────────────────────
    with open(os.path.join(DATA_DIR, LABEL_FILE), "r") as f:
        all_meta = json.load(f)
    train_meta = [d for d in all_meta if d["type"] == "train"]
    full_ds    = EEGDataset(DATA_DIR, train_meta)
    labels     = [d["label"] for d in train_meta]
    n_samples  = len(full_ds)
    input_length = full_ds[0][0].shape[-1]

    # ─── 3) Hold-out split ─────────────────────────────────────
    train_idx, val_idx = train_test_split(
        list(range(n_samples)),
        test_size=0.2,
        stratify=labels,
        random_state=42
    )
    train_loader = DataLoader(
        Subset(full_ds, train_idx),
        batch_size=BATCH_SIZE, shuffle=True,  num_workers=NUM_WORKERS
    )
    val_loader = DataLoader(
        Subset(full_ds, val_idx),
        batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
    )

    # ─── 4) W&B init ───────────────────────────────────────────
    wandb.init(
        project="eeg-holdout-grid-search-7",
        config={
            "lr": lr,
            "weight_decay": weight_decay,
            "num_blocks": num_blocks,
            "num_heads": num_heads,
            "num_segments": num_segments
        }
    )

    print(f"\n===== Trial {trial.number} =====")
    print(
        f" lr={lr:.2e}, wd={weight_decay:.2e}, "
        f"blocks={num_blocks}, heads={num_heads}, segs={num_segments}"
    )

    # ─── 5) Model / optimizer / loss ──────────────────────────
    model = EEGformer(
        in_channels  = 19,
        input_length = input_length,
        kernel_size  = 10,
        num_filters  = NUM_FILTERS,
        num_heads    = num_heads,
        num_blocks   = num_blocks,
        num_segments = num_segments,
        num_classes  = 3
    ).to(DEVICE)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=weight_decay
    )
    criterion = nn.CrossEntropyLoss()

    # ─── 6) Scheduler ──────────────────────────────────────────
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=FIXED_GAMMA,
        patience=FIXED_STEP_SIZE,
        min_lr=1e-6
    )

    # ─── 7) Training loop w/ Early Stopping & Pruning ─────────
    best_val_loss     = float("inf")
    epochs_no_improve = 0

    best_train_loss = best_train_acc = best_val_acc = None

    for epoch in range(1, MAX_EPOCHS + 1):
        t0 = time.time()

        # — train —
        model.train()
        tloss = tcorrect = ttotal = 0
        for X, y in train_loader:
            X, y = X.to(DEVICE), y.to(DEVICE)
            optimizer.zero_grad()
            logits = model(X)
            loss   = criterion(logits, y)
            loss.backward()
            optimizer.step()

            tloss    += loss.item()
            tcorrect += (logits.argmax(1) == y).sum().item()
            ttotal   += y.size(0)
        train_loss = tloss / len(train_loader)
        train_acc  = tcorrect / ttotal

        # — validate —
        model.eval()
        vloss = vcorrect = vtotal = 0
        with torch.no_grad():
            for X, y in val_loader:
                X, y    = X.to(DEVICE), y.to(DEVICE)
                logits  = model(X)
                loss    = criterion(logits, y)
                vloss   += loss.item()
                vcorrect+= (logits.argmax(1) == y).sum().item()
                vtotal  += y.size(0)
        val_loss = vloss / len(val_loader)
        val_acc  = vcorrect / vtotal
        elapsed  = time.time() - t0

        # — report & pruning check —
        trial.report(val_loss, epoch)
        if trial.should_prune():
            wandb.finish()
            print(f"▸ Trial {trial.number} pruned at epoch {epoch}")
            raise optuna.TrialPruned()

        # — print & log —
        print(
            f"Epoch {epoch:03d} | "
            f"train_loss={train_loss:.4f} acc={train_acc:.4f} | "
            f"val_loss={val_loss:.4f} acc={val_acc:.4f} | "
            f"time={elapsed:.1f}s"
        )
        wandb.log({
            "epoch":               epoch,
            "train_loss":          train_loss,
            "train_accuracy":      train_acc,
            "validation_loss":     val_loss,
            "validation_accuracy": val_acc,
        }, step=epoch)

        scheduler.step(val_loss)

        # — early stopping logic & save best metrics —
        if val_loss < best_val_loss:
            best_val_loss     = val_loss
            epochs_no_improve = 0
            best_train_loss   = train_loss
            best_train_acc    = train_acc
            best_val_acc      = val_acc
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= PATIENCE:
                print(f"★ Early stopping at epoch {epoch}")
                break

    # store best-epoch metrics
    trial.set_user_attr("best_train_loss", best_train_loss)
    trial.set_user_attr("best_train_acc",  best_train_acc)
    trial.set_user_attr("best_val_acc",    best_val_acc)

    wandb.finish()
    gc.collect()
    return best_val_loss


if __name__ == "__main__":
    multiprocessing.freeze_support()

    # ─── GridSampler용 파라미터 그리드 ─────────────────────────
    param_grid = {
        "lr":            LR_CHOICES,
        "weight_decay":  WD_CHOICES,
        "num_blocks":    NUM_BLOCK_CHOICES,
        "num_heads":     NUM_HEAD_CHOICES,
        "num_segments":  SEGMENT_CHOICES,
    }

    sampler = optuna.samplers.GridSampler(param_grid)
    study = optuna.create_study(
        direction="minimize",
        sampler=sampler,
        pruner=optuna.pruners.MedianPruner(n_warmup_steps=5),
        study_name="eeg_holdout_grid_search-7",
        storage="sqlite:////content/drive/MyDrive/2025_Lab_Research/eeg_grid_search-7.db",
        load_if_exists=True
    )
    study.optimize(objective_holdout)  # grid 크기만큼 자동 실행

    # ─── 결과 출력 ─────────────────────────────────────────────
    best = study.best_trial
    print("\n===== Best Trial =====")
    print(f"best_val_loss       = {best.value:.6f}")
    print(f"best_train_loss     = {best.user_attrs['best_train_loss']:.6f}")
    print(f"best_train_accuracy = {best.user_attrs['best_train_acc']:.4f}")
    print(f"best_val_accuracy   = {best.user_attrs['best_val_acc']:.4f}")
    print("best params:")
    for k, v in best.params.items():
        print(f"  {k}: {v}")


[I 2025-05-02 06:28:51,249] A new study created in RDB with name: eeg_holdout_grid_search-7



===== Trial 0 =====
 lr=5.00e-04, wd=5.00e-05, blocks=1, heads=1, segs=5
Epoch 001 | train_loss=1.0794 acc=0.4113 | val_loss=1.0761 acc=0.4317 | time=17.1s
Epoch 002 | train_loss=1.0762 acc=0.4070 | val_loss=1.0751 acc=0.4317 | time=17.0s
Epoch 003 | train_loss=1.0725 acc=0.4171 | val_loss=1.0743 acc=0.4317 | time=16.7s
Epoch 004 | train_loss=1.0737 acc=0.4272 | val_loss=1.0743 acc=0.4317 | time=16.8s
Epoch 005 | train_loss=1.0670 acc=0.4338 | val_loss=1.0746 acc=0.4317 | time=16.7s
Epoch 006 | train_loss=1.0716 acc=0.4322 | val_loss=1.0743 acc=0.4317 | time=17.0s
Epoch 007 | train_loss=1.0687 acc=0.4303 | val_loss=1.0755 acc=0.4317 | time=16.5s
Epoch 008 | train_loss=1.0709 acc=0.4315 | val_loss=1.0745 acc=0.4317 | time=16.9s
Epoch 009 | train_loss=1.0701 acc=0.4315 | val_loss=1.0768 acc=0.4317 | time=16.6s
Epoch 010 | train_loss=1.0700 acc=0.4307 | val_loss=1.0745 acc=0.4317 | time=16.6s
Epoch 011 | train_loss=1.0698 acc=0.4307 | val_loss=1.0742 acc=0.4317 | time=16.7s
Epoch 012 | t

0,1
epoch,▁▁▂▂▂▂▃▃▃▄▄▄▄▅▅▅▅▆▆▆▇▇▇▇██
train_accuracy,▂▁▄▆██▇▇▇▇▇▇▇▇▇▇▆▇█▇▇█▇▇▇▇
train_loss,█▆▄▅▂▄▃▄▃▃▃▁▂▄▃▂▄▂▃▁▃▃▂▂▁▂
validation_accuracy,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation_loss,▇▅▃▃▄▃▅▄█▄▃▃▄▄▃▁▃▅▃█▃▄▃▃▄▃

0,1
epoch,26.0
train_accuracy,0.43107
train_loss,1.06758
validation_accuracy,0.43168
validation_loss,1.07404


[I 2025-05-02 06:36:12,736] Trial 0 finished with value: 1.0731083296594166 and parameters: {'lr': 0.0005, 'weight_decay': 5e-05, 'num_blocks': 1, 'num_heads': 1, 'num_segments': 5}. Best is trial 0 with value: 1.0731083296594166.



===== Trial 1 =====
 lr=5.00e-04, wd=5.00e-04, blocks=1, heads=1, segs=5
Epoch 001 | train_loss=1.0815 acc=0.3860 | val_loss=1.0746 acc=0.4317 | time=17.2s
Epoch 002 | train_loss=1.0773 acc=0.3992 | val_loss=1.0749 acc=0.4317 | time=17.3s
Epoch 003 | train_loss=1.0770 acc=0.4000 | val_loss=1.0753 acc=0.4317 | time=16.8s
Epoch 004 | train_loss=1.0728 acc=0.4171 | val_loss=1.0754 acc=0.4317 | time=17.4s
Epoch 005 | train_loss=1.0720 acc=0.4163 | val_loss=1.0745 acc=0.4317 | time=17.7s
Epoch 006 | train_loss=1.0694 acc=0.4268 | val_loss=1.0744 acc=0.4317 | time=17.5s
Epoch 007 | train_loss=1.0726 acc=0.4171 | val_loss=1.0744 acc=0.4317 | time=17.3s
Epoch 008 | train_loss=1.0718 acc=0.4264 | val_loss=1.0747 acc=0.4317 | time=17.1s
Epoch 009 | train_loss=1.0654 acc=0.4497 | val_loss=1.0609 acc=0.4596 | time=17.3s
Epoch 010 | train_loss=1.0103 acc=0.5414 | val_loss=0.9821 acc=0.5714 | time=17.1s
Epoch 011 | train_loss=0.9592 acc=0.5790 | val_loss=0.9560 acc=0.5652 | time=17.3s
Epoch 012 | t

0,1
epoch,▁▁▂▂▂▂▃▃▃▄▄▄▅▅▅▅▆▆▆▇▇▇▇██
train_accuracy,▁▁▁▂▂▂▂▂▂▄▅▅▅▅▆▆▆▆▆▇▇▇▇██
train_loss,█████████▇▆▆▆▅▅▄▄▃▃▃▃▂▂▁▁
validation_accuracy,▁▁▁▁▁▁▁▁▂▆▆▆▇▇▇▇▇▇██▆█▇█▇
validation_loss,▅▅▅▅▅▅▅▅▄▃▂▂▂▁▁▂▁▂▃▄▂▆▅▇█

0,1
epoch,25.0
train_accuracy,0.75534
train_loss,0.57294
validation_accuracy,0.60714
validation_loss,1.20127


[I 2025-05-02 06:43:21,365] Trial 1 finished with value: 0.9222010232153393 and parameters: {'lr': 0.0005, 'weight_decay': 0.0005, 'num_blocks': 1, 'num_heads': 1, 'num_segments': 5}. Best is trial 1 with value: 0.9222010232153393.



===== Trial 2 =====
 lr=5.00e-04, wd=5.00e-03, blocks=1, heads=1, segs=5
Epoch 001 | train_loss=1.0883 acc=0.3938 | val_loss=1.0766 acc=0.4317 | time=16.6s
Epoch 002 | train_loss=1.0721 acc=0.4369 | val_loss=1.0758 acc=0.4317 | time=16.8s
Epoch 003 | train_loss=1.0739 acc=0.4159 | val_loss=1.0742 acc=0.4317 | time=16.8s
Epoch 004 | train_loss=1.0728 acc=0.4229 | val_loss=1.0748 acc=0.4317 | time=16.6s
Epoch 005 | train_loss=1.0689 acc=0.4295 | val_loss=1.0751 acc=0.4317 | time=16.7s
Epoch 006 | train_loss=1.0692 acc=0.4322 | val_loss=1.0774 acc=0.4317 | time=16.9s
Epoch 007 | train_loss=1.0713 acc=0.4307 | val_loss=1.0742 acc=0.4317 | time=16.9s
Epoch 008 | train_loss=1.0713 acc=0.4311 | val_loss=1.0744 acc=0.4317 | time=16.8s
Epoch 009 | train_loss=1.0664 acc=0.4326 | val_loss=1.0754 acc=0.4317 | time=17.0s
Epoch 010 | train_loss=1.0663 acc=0.4311 | val_loss=1.0743 acc=0.4317 | time=17.3s
Epoch 011 | train_loss=1.0685 acc=0.4318 | val_loss=1.0742 acc=0.4317 | time=17.1s
Epoch 012 | t

0,1
epoch,▁▁▁▂▂▂▂▂▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▇▇▇▇▇███
train_accuracy,▁▂▁▁▂▂▂▂▂▂▂▂▂▄▄▄▅▅▅▅▆▅▆▆▆▆▇▇▇▇████
train_loss,█████████████▇▇▆▆▆▅▅▄▅▄▄▄▃▃▃▂▂▂▁▁▁
validation_accuracy,▁▁▁▁▁▁▁▁▁▁▁▁▅▅▅▅▆▆▆▇▆▇▇▇▇▇███▇████
validation_loss,████████████▆▅▅▄▃▃▂▂▃▁▁▁▂▂▃▂▃▄▇▄▆▇

0,1
epoch,34.0
train_accuracy,0.85165
train_loss,0.3828
validation_accuracy,0.68168
validation_loss,1.05871


[I 2025-05-02 06:52:57,979] Trial 2 finished with value: 0.8170908462433588 and parameters: {'lr': 0.0005, 'weight_decay': 0.005, 'num_blocks': 1, 'num_heads': 1, 'num_segments': 5}. Best is trial 2 with value: 0.8170908462433588.



===== Best Trial =====
best_val_loss       = 0.817091
best_train_loss     = 0.669932
best_train_accuracy = 0.7247
best_val_accuracy   = 0.6553
best params:
  lr: 0.0005
  weight_decay: 0.005
  num_blocks: 1
  num_heads: 1
  num_segments: 5


In [None]:
import os
import json
import time
import gc
import multiprocessing

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split
import optuna
import wandb

from eeg_dataset import EEGDataset
from models import EEGformer

# ─── Grid Search 후보값 ────────────────────────────────────────
LR_CHOICES          = [5e-4]
WD_CHOICES          = [5e-4]
NUM_FILTERS         = 120
NUM_BLOCK_CHOICES   = [1]
NUM_HEAD_CHOICES    = [2, 3, 4]
SEGMENT_CHOICES     = [5]

# ─── Training configuration ───────────────────────────────────
MAX_EPOCHS  = 100
PATIENCE    = 10
BATCH_SIZE  = 32
NUM_WORKERS = max(1, min(4, os.cpu_count() - 1))

# ─── Data paths & device ──────────────────────────────────────
DATA_DIR   = '/content/drive/MyDrive/2025_Lab_Research/model-data'
LABEL_FILE = "labels.json"
DEVICE     = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ─── Scheduler 하이퍼파라미터 ─────────────────────────────────
FIXED_STEP_SIZE = 5
FIXED_GAMMA     = 0.5


def objective_holdout(trial):
    # ─── 1) Grid용 sampling ───────────────────────────────────
    lr           = trial.suggest_categorical("lr", LR_CHOICES)
    weight_decay = trial.suggest_categorical("weight_decay", WD_CHOICES)
    num_blocks   = trial.suggest_categorical("num_blocks", NUM_BLOCK_CHOICES)
    num_heads    = trial.suggest_categorical("num_heads", NUM_HEAD_CHOICES)
    num_segments = trial.suggest_categorical("num_segments", SEGMENT_CHOICES)

    # ─── 2) 데이터 로드 ────────────────────────────────────────
    with open(os.path.join(DATA_DIR, LABEL_FILE), "r") as f:
        all_meta = json.load(f)
    train_meta = [d for d in all_meta if d["type"] == "train"]
    full_ds    = EEGDataset(DATA_DIR, train_meta)
    labels     = [d["label"] for d in train_meta]
    n_samples  = len(full_ds)
    input_length = full_ds[0][0].shape[-1]

    # ─── 3) Hold-out split ─────────────────────────────────────
    train_idx, val_idx = train_test_split(
        list(range(n_samples)),
        test_size=0.2,
        stratify=labels,
        random_state=42
    )
    train_loader = DataLoader(
        Subset(full_ds, train_idx),
        batch_size=BATCH_SIZE, shuffle=True,  num_workers=NUM_WORKERS
    )
    val_loader = DataLoader(
        Subset(full_ds, val_idx),
        batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
    )

    # ─── 4) W&B init ───────────────────────────────────────────
    wandb.init(
        project="eeg-holdout-grid-search-8",
        config={
            "lr": lr,
            "weight_decay": weight_decay,
            "num_blocks": num_blocks,
            "num_heads": num_heads,
            "num_segments": num_segments
        }
    )

    print(f"\n===== Trial {trial.number} =====")
    print(
        f" lr={lr:.2e}, wd={weight_decay:.2e}, "
        f"blocks={num_blocks}, heads={num_heads}, segs={num_segments}"
    )

    # ─── 5) Model / optimizer / loss ──────────────────────────
    model = EEGformer(
        in_channels  = 19,
        input_length = input_length,
        kernel_size  = 10,
        num_filters  = NUM_FILTERS,
        num_heads    = num_heads,
        num_blocks   = num_blocks,
        num_segments = num_segments,
        num_classes  = 3
    ).to(DEVICE)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=weight_decay
    )
    criterion = nn.CrossEntropyLoss()

    # ─── 6) Scheduler ──────────────────────────────────────────
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=FIXED_GAMMA,
        patience=FIXED_STEP_SIZE,
        min_lr=1e-6
    )

    # ─── 7) Training loop w/ Early Stopping & Pruning ─────────
    best_val_loss     = float("inf")
    epochs_no_improve = 0

    best_train_loss = best_train_acc = best_val_acc = None

    for epoch in range(1, MAX_EPOCHS + 1):
        t0 = time.time()

        # — train —
        model.train()
        tloss = tcorrect = ttotal = 0
        for X, y in train_loader:
            X, y = X.to(DEVICE), y.to(DEVICE)
            optimizer.zero_grad()
            logits = model(X)
            loss   = criterion(logits, y)
            loss.backward()
            optimizer.step()

            tloss    += loss.item()
            tcorrect += (logits.argmax(1) == y).sum().item()
            ttotal   += y.size(0)
        train_loss = tloss / len(train_loader)
        train_acc  = tcorrect / ttotal

        # — validate —
        model.eval()
        vloss = vcorrect = vtotal = 0
        with torch.no_grad():
            for X, y in val_loader:
                X, y    = X.to(DEVICE), y.to(DEVICE)
                logits  = model(X)
                loss    = criterion(logits, y)
                vloss   += loss.item()
                vcorrect+= (logits.argmax(1) == y).sum().item()
                vtotal  += y.size(0)
        val_loss = vloss / len(val_loader)
        val_acc  = vcorrect / vtotal
        elapsed  = time.time() - t0

        # — report & pruning check —
        trial.report(val_loss, epoch)
        if trial.should_prune():
            wandb.finish()
            print(f"▸ Trial {trial.number} pruned at epoch {epoch}")
            raise optuna.TrialPruned()

        # — print & log —
        print(
            f"Epoch {epoch:03d} | "
            f"train_loss={train_loss:.4f} acc={train_acc:.4f} | "
            f"val_loss={val_loss:.4f} acc={val_acc:.4f} | "
            f"time={elapsed:.1f}s"
        )
        wandb.log({
            "epoch":               epoch,
            "train_loss":          train_loss,
            "train_accuracy":      train_acc,
            "validation_loss":     val_loss,
            "validation_accuracy": val_acc,
        }, step=epoch)

        scheduler.step(val_loss)

        # — early stopping logic & save best metrics —
        if val_loss < best_val_loss:
            best_val_loss     = val_loss
            epochs_no_improve = 0
            best_train_loss   = train_loss
            best_train_acc    = train_acc
            best_val_acc      = val_acc
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= PATIENCE:
                print(f"★ Early stopping at epoch {epoch}")
                break

    # store best-epoch metrics
    trial.set_user_attr("best_train_loss", best_train_loss)
    trial.set_user_attr("best_train_acc",  best_train_acc)
    trial.set_user_attr("best_val_acc",    best_val_acc)

    wandb.finish()
    gc.collect()
    return best_val_loss


if __name__ == "__main__":
    multiprocessing.freeze_support()

    # ─── GridSampler용 파라미터 그리드 ─────────────────────────
    param_grid = {
        "lr":            LR_CHOICES,
        "weight_decay":  WD_CHOICES,
        "num_blocks":    NUM_BLOCK_CHOICES,
        "num_heads":     NUM_HEAD_CHOICES,
        "num_segments":  SEGMENT_CHOICES,
    }

    sampler = optuna.samplers.GridSampler(param_grid)
    study = optuna.create_study(
        direction="minimize",
        sampler=sampler,
        pruner=optuna.pruners.MedianPruner(n_warmup_steps=5),
        study_name="eeg_holdout_grid_search-8",
        storage="sqlite:////content/drive/MyDrive/2025_Lab_Research/eeg_grid_search-8.db",
        load_if_exists=True
    )
    study.optimize(objective_holdout)  # grid 크기만큼 자동 실행

    # ─── 결과 출력 ─────────────────────────────────────────────
    best = study.best_trial
    print("\n===== Best Trial =====")
    print(f"best_val_loss       = {best.value:.6f}")
    print(f"best_train_loss     = {best.user_attrs['best_train_loss']:.6f}")
    print(f"best_train_accuracy = {best.user_attrs['best_train_acc']:.4f}")
    print(f"best_val_accuracy   = {best.user_attrs['best_val_acc']:.4f}")
    print("best params:")
    for k, v in best.params.items():
        print(f"  {k}: {v}")


[I 2025-05-02 07:10:43,774] A new study created in RDB with name: eeg_holdout_grid_search-8



===== Trial 0 =====
 lr=5.00e-04, wd=5.00e-04, blocks=1, heads=4, segs=5
Epoch 001 | train_loss=1.0736 acc=0.4175 | val_loss=1.0744 acc=0.4317 | time=21.9s
Epoch 002 | train_loss=1.0740 acc=0.4047 | val_loss=1.0769 acc=0.4317 | time=21.7s
Epoch 003 | train_loss=1.0730 acc=0.4175 | val_loss=1.0748 acc=0.4317 | time=21.6s
Epoch 004 | train_loss=1.0718 acc=0.4291 | val_loss=1.0749 acc=0.4317 | time=21.6s
Epoch 005 | train_loss=1.0680 acc=0.4260 | val_loss=1.0744 acc=0.4317 | time=21.7s
Epoch 006 | train_loss=1.0693 acc=0.4311 | val_loss=1.0745 acc=0.4317 | time=21.4s
Epoch 007 | train_loss=1.0675 acc=0.4280 | val_loss=1.0751 acc=0.4317 | time=21.8s
Epoch 008 | train_loss=1.0679 acc=0.4318 | val_loss=1.0743 acc=0.4317 | time=21.7s
Epoch 009 | train_loss=1.0686 acc=0.4291 | val_loss=1.0748 acc=0.4317 | time=21.7s
Epoch 010 | train_loss=1.0699 acc=0.4315 | val_loss=1.0742 acc=0.4317 | time=21.9s
Epoch 011 | train_loss=1.0688 acc=0.4303 | val_loss=1.0748 acc=0.4317 | time=21.6s
Epoch 012 | t

0,1
epoch,▁▁▂▂▂▃▃▃▄▄▄▅▅▅▅▆▆▆▇▇▇██
train_accuracy,▄▁▄▇▆█▇█▇██████████████
train_loss,██▇▆▃▄▂▃▃▅▄▄▄▂▂▃▃▄▁▅▄▃▁
validation_accuracy,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation_loss,▂█▃▃▁▂▃▁▂▁▃▁▁▃▂▂▁▁▂▁▁▁▁

0,1
epoch,23.0
train_accuracy,0.43107
train_loss,1.06586
validation_accuracy,0.43168
validation_loss,1.07439


[I 2025-05-02 07:19:04,836] Trial 0 finished with value: 1.074195731253851 and parameters: {'lr': 0.0005, 'weight_decay': 0.0005, 'num_blocks': 1, 'num_heads': 4, 'num_segments': 5}. Best is trial 0 with value: 1.074195731253851.



===== Trial 1 =====
 lr=5.00e-04, wd=5.00e-04, blocks=1, heads=3, segs=5
Epoch 001 | train_loss=1.0770 acc=0.4027 | val_loss=1.0778 acc=0.4317 | time=18.5s
Epoch 002 | train_loss=1.0740 acc=0.4175 | val_loss=1.0748 acc=0.4317 | time=18.8s
Epoch 003 | train_loss=1.0700 acc=0.4264 | val_loss=1.0751 acc=0.4317 | time=18.7s
Epoch 004 | train_loss=1.0710 acc=0.4241 | val_loss=1.0743 acc=0.4317 | time=18.8s
Epoch 005 | train_loss=1.0688 acc=0.4276 | val_loss=1.0743 acc=0.4317 | time=18.5s
Epoch 006 | train_loss=1.0714 acc=0.4256 | val_loss=1.0742 acc=0.4317 | time=18.5s
Epoch 007 | train_loss=1.0697 acc=0.4315 | val_loss=1.0742 acc=0.4317 | time=18.5s
Epoch 008 | train_loss=1.0687 acc=0.4318 | val_loss=1.0753 acc=0.4317 | time=18.5s
Epoch 009 | train_loss=1.0681 acc=0.4303 | val_loss=1.0748 acc=0.4317 | time=18.7s
Epoch 010 | train_loss=1.0681 acc=0.4303 | val_loss=1.0744 acc=0.4317 | time=18.6s
Epoch 011 | train_loss=1.0683 acc=0.4287 | val_loss=1.0745 acc=0.4317 | time=18.7s
Epoch 012 | t

0,1
epoch,▁▁▂▂▃▃▄▄▅▅▆▆▇▇██
train_accuracy,▁▅▇▆▇▇████▇█████
train_loss,█▆▃▄▃▄▃▂▂▂▂▃▃▁▃▁
validation_accuracy,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation_loss,█▂▃▁▁▁▁▃▂▁▂▁▁▂▁▁

0,1
epoch,16.0
train_accuracy,0.43107
train_loss,1.06698
validation_accuracy,0.43168
validation_loss,1.07439


[I 2025-05-02 07:24:04,792] Trial 1 finished with value: 1.074192376363845 and parameters: {'lr': 0.0005, 'weight_decay': 0.0005, 'num_blocks': 1, 'num_heads': 3, 'num_segments': 5}. Best is trial 1 with value: 1.074192376363845.



===== Trial 2 =====
 lr=5.00e-04, wd=5.00e-04, blocks=1, heads=2, segs=5
Epoch 001 | train_loss=1.0759 acc=0.4198 | val_loss=1.0793 acc=0.4317 | time=17.2s
Epoch 002 | train_loss=1.0741 acc=0.4276 | val_loss=1.0761 acc=0.4317 | time=17.3s
Epoch 003 | train_loss=1.0706 acc=0.4299 | val_loss=1.0745 acc=0.4317 | time=17.0s
Epoch 004 | train_loss=1.0700 acc=0.4315 | val_loss=1.0754 acc=0.4317 | time=17.1s
Epoch 005 | train_loss=1.0688 acc=0.4280 | val_loss=1.0751 acc=0.4317 | time=17.1s
Epoch 006 | train_loss=1.0681 acc=0.4268 | val_loss=1.0752 acc=0.4317 | time=16.9s
Epoch 007 | train_loss=1.0683 acc=0.4315 | val_loss=1.0754 acc=0.4317 | time=17.1s
Epoch 008 | train_loss=1.0698 acc=0.4307 | val_loss=1.0744 acc=0.4317 | time=17.3s
Epoch 009 | train_loss=1.0683 acc=0.4311 | val_loss=1.0741 acc=0.4317 | time=16.9s
Epoch 010 | train_loss=1.0654 acc=0.4353 | val_loss=1.0624 acc=0.4317 | time=17.2s
Epoch 011 | train_loss=1.0033 acc=0.5565 | val_loss=0.9699 acc=0.5668 | time=17.2s
Epoch 012 | t

0,1
epoch,▁▁▂▂▂▂▃▃▃▃▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇██
train_accuracy,▁▁▁▁▁▁▁▁▁▁▄▄▅▅▅▅▆▆▇▇▇▇▇████
train_loss,██████████▇▆▆▆▅▅▄▄▃▃▃▂▂▂▁▁▁
validation_accuracy,▁▁▁▁▁▁▁▁▁▁▆▅▇▆▇▇▇████████▇█
validation_loss,█████████▇▅▄▃▃▂▂▁▃▃▁▁▂▃▂▂▂▄

0,1
epoch,27.0
train_accuracy,0.74796
train_loss,0.59003
validation_accuracy,0.61025
validation_loss,0.94048


[I 2025-05-02 07:31:48,919] Trial 2 finished with value: 0.8549597462018331 and parameters: {'lr': 0.0005, 'weight_decay': 0.0005, 'num_blocks': 1, 'num_heads': 2, 'num_segments': 5}. Best is trial 2 with value: 0.8549597462018331.



===== Best Trial =====
best_val_loss       = 0.854960
best_train_loss     = 0.804907
best_train_accuracy = 0.6517
best_val_accuracy   = 0.6025
best params:
  lr: 0.0005
  weight_decay: 0.0005
  num_blocks: 1
  num_heads: 2
  num_segments: 5


In [None]:
import os
import json
import time
import gc
import multiprocessing

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split
import optuna
import wandb

from eeg_dataset import EEGDataset
from models import EEGformer

# ─── Grid Search 후보값 ────────────────────────────────────────
LR_CHOICES          = [5e-4]
WD_CHOICES          = [5e-3]
NUM_FILTERS         = 120
NUM_BLOCK_CHOICES   = [1]
NUM_HEAD_CHOICES    = [2, 3, 4]
SEGMENT_CHOICES     = [5]

# ─── Training configuration ───────────────────────────────────
MAX_EPOCHS  = 100
PATIENCE    = 10
BATCH_SIZE  = 32
NUM_WORKERS = max(1, min(4, os.cpu_count() - 1))

# ─── Data paths & device ──────────────────────────────────────
DATA_DIR   = '/content/drive/MyDrive/2025_Lab_Research/model-data'
LABEL_FILE = "labels.json"
DEVICE     = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ─── Scheduler 하이퍼파라미터 ─────────────────────────────────
FIXED_STEP_SIZE = 5
FIXED_GAMMA     = 0.5


def objective_holdout(trial):
    # ─── 1) Grid용 sampling ───────────────────────────────────
    lr           = trial.suggest_categorical("lr", LR_CHOICES)
    weight_decay = trial.suggest_categorical("weight_decay", WD_CHOICES)
    num_blocks   = trial.suggest_categorical("num_blocks", NUM_BLOCK_CHOICES)
    num_heads    = trial.suggest_categorical("num_heads", NUM_HEAD_CHOICES)
    num_segments = trial.suggest_categorical("num_segments", SEGMENT_CHOICES)

    # ─── 2) 데이터 로드 ────────────────────────────────────────
    with open(os.path.join(DATA_DIR, LABEL_FILE), "r") as f:
        all_meta = json.load(f)
    train_meta = [d for d in all_meta if d["type"] == "train"]
    full_ds    = EEGDataset(DATA_DIR, train_meta)
    labels     = [d["label"] for d in train_meta]
    n_samples  = len(full_ds)
    input_length = full_ds[0][0].shape[-1]

    # ─── 3) Hold-out split ─────────────────────────────────────
    train_idx, val_idx = train_test_split(
        list(range(n_samples)),
        test_size=0.2,
        stratify=labels,
        random_state=42
    )
    train_loader = DataLoader(
        Subset(full_ds, train_idx),
        batch_size=BATCH_SIZE, shuffle=True,  num_workers=NUM_WORKERS
    )
    val_loader = DataLoader(
        Subset(full_ds, val_idx),
        batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
    )

    # ─── 4) W&B init ───────────────────────────────────────────
    wandb.init(
        project="eeg-holdout-grid-search-9",
        config={
            "lr": lr,
            "weight_decay": weight_decay,
            "num_blocks": num_blocks,
            "num_heads": num_heads,
            "num_segments": num_segments
        }
    )

    print(f"\n===== Trial {trial.number} =====")
    print(
        f" lr={lr:.2e}, wd={weight_decay:.2e}, "
        f"blocks={num_blocks}, heads={num_heads}, segs={num_segments}"
    )

    # ─── 5) Model / optimizer / loss ──────────────────────────
    model = EEGformer(
        in_channels  = 19,
        input_length = input_length,
        kernel_size  = 10,
        num_filters  = NUM_FILTERS,
        num_heads    = num_heads,
        num_blocks   = num_blocks,
        num_segments = num_segments,
        num_classes  = 3
    ).to(DEVICE)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=weight_decay
    )
    criterion = nn.CrossEntropyLoss()

    # ─── 6) Scheduler ──────────────────────────────────────────
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=FIXED_GAMMA,
        patience=FIXED_STEP_SIZE,
        min_lr=1e-6
    )

    # ─── 7) Training loop w/ Early Stopping & Pruning ─────────
    best_val_loss     = float("inf")
    epochs_no_improve = 0

    best_train_loss = best_train_acc = best_val_acc = None

    for epoch in range(1, MAX_EPOCHS + 1):
        t0 = time.time()

        # — train —
        model.train()
        tloss = tcorrect = ttotal = 0
        for X, y in train_loader:
            X, y = X.to(DEVICE), y.to(DEVICE)
            optimizer.zero_grad()
            logits = model(X)
            loss   = criterion(logits, y)
            loss.backward()
            optimizer.step()

            tloss    += loss.item()
            tcorrect += (logits.argmax(1) == y).sum().item()
            ttotal   += y.size(0)
        train_loss = tloss / len(train_loader)
        train_acc  = tcorrect / ttotal

        # — validate —
        model.eval()
        vloss = vcorrect = vtotal = 0
        with torch.no_grad():
            for X, y in val_loader:
                X, y    = X.to(DEVICE), y.to(DEVICE)
                logits  = model(X)
                loss    = criterion(logits, y)
                vloss   += loss.item()
                vcorrect+= (logits.argmax(1) == y).sum().item()
                vtotal  += y.size(0)
        val_loss = vloss / len(val_loader)
        val_acc  = vcorrect / vtotal
        elapsed  = time.time() - t0

        # — report & pruning check —
        trial.report(val_loss, epoch)
        if trial.should_prune():
            wandb.finish()
            print(f"▸ Trial {trial.number} pruned at epoch {epoch}")
            raise optuna.TrialPruned()

        # — print & log —
        print(
            f"Epoch {epoch:03d} | "
            f"train_loss={train_loss:.4f} acc={train_acc:.4f} | "
            f"val_loss={val_loss:.4f} acc={val_acc:.4f} | "
            f"time={elapsed:.1f}s"
        )
        wandb.log({
            "epoch":               epoch,
            "train_loss":          train_loss,
            "train_accuracy":      train_acc,
            "validation_loss":     val_loss,
            "validation_accuracy": val_acc,
        }, step=epoch)

        scheduler.step(val_loss)

        # — early stopping logic & save best metrics —
        if val_loss < best_val_loss:
            best_val_loss     = val_loss
            epochs_no_improve = 0
            best_train_loss   = train_loss
            best_train_acc    = train_acc
            best_val_acc      = val_acc
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= PATIENCE:
                print(f"★ Early stopping at epoch {epoch}")
                break

    # store best-epoch metrics
    trial.set_user_attr("best_train_loss", best_train_loss)
    trial.set_user_attr("best_train_acc",  best_train_acc)
    trial.set_user_attr("best_val_acc",    best_val_acc)

    wandb.finish()
    gc.collect()
    return best_val_loss


if __name__ == "__main__":
    multiprocessing.freeze_support()

    # ─── GridSampler용 파라미터 그리드 ─────────────────────────
    param_grid = {
        "lr":            LR_CHOICES,
        "weight_decay":  WD_CHOICES,
        "num_blocks":    NUM_BLOCK_CHOICES,
        "num_heads":     NUM_HEAD_CHOICES,
        "num_segments":  SEGMENT_CHOICES,
    }

    sampler = optuna.samplers.GridSampler(param_grid)
    study = optuna.create_study(
        direction="minimize",
        sampler=sampler,
        pruner=optuna.pruners.MedianPruner(n_warmup_steps=5),
        study_name="eeg_holdout_grid_search-9",
        storage="sqlite:////content/drive/MyDrive/2025_Lab_Research/eeg_grid_search-9.db",
        load_if_exists=True
    )
    study.optimize(objective_holdout)  # grid 크기만큼 자동 실행

    # ─── 결과 출력 ─────────────────────────────────────────────
    best = study.best_trial
    print("\n===== Best Trial =====")
    print(f"best_val_loss       = {best.value:.6f}")
    print(f"best_train_loss     = {best.user_attrs['best_train_loss']:.6f}")
    print(f"best_train_accuracy = {best.user_attrs['best_train_acc']:.4f}")
    print(f"best_val_accuracy   = {best.user_attrs['best_val_acc']:.4f}")
    print("best params:")
    for k, v in best.params.items():
        print(f"  {k}: {v}")


[I 2025-05-02 07:46:37,065] A new study created in RDB with name: eeg_holdout_grid_search-9



===== Trial 0 =====
 lr=5.00e-04, wd=5.00e-03, blocks=1, heads=4, segs=5
Epoch 001 | train_loss=1.0747 acc=0.4252 | val_loss=1.0825 acc=0.4317 | time=21.8s
Epoch 002 | train_loss=1.0718 acc=0.4283 | val_loss=1.0744 acc=0.4317 | time=21.7s
Epoch 003 | train_loss=1.0690 acc=0.4353 | val_loss=1.0752 acc=0.4317 | time=21.8s
Epoch 004 | train_loss=1.0689 acc=0.4350 | val_loss=1.0756 acc=0.4317 | time=21.7s
Epoch 005 | train_loss=1.0711 acc=0.4338 | val_loss=1.0742 acc=0.4317 | time=21.7s
Epoch 006 | train_loss=1.0685 acc=0.4280 | val_loss=1.0744 acc=0.4317 | time=21.6s
Epoch 007 | train_loss=1.0698 acc=0.4299 | val_loss=1.0746 acc=0.4317 | time=21.7s
Epoch 008 | train_loss=1.0661 acc=0.4315 | val_loss=1.0746 acc=0.4317 | time=21.7s
Epoch 009 | train_loss=1.0704 acc=0.4315 | val_loss=1.0743 acc=0.4317 | time=21.6s
Epoch 010 | train_loss=1.0673 acc=0.4311 | val_loss=1.0748 acc=0.4317 | time=21.7s
Epoch 011 | train_loss=1.0680 acc=0.4307 | val_loss=1.0742 acc=0.4317 | time=22.0s
Epoch 012 | t

0,1
epoch,▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
train_accuracy,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▃▃▃▃▄▄▄▄▅▅▅▆▆▆▆▆▇▇███
train_loss,█████████████████▇▇▇▇▇▇▆▆▆▅▅▅▅▅▄▄▄▄▃▂▂▁▁
validation_accuracy,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▄▄▅▄▅▅▅▆▆▆▆▆▇▇▇█▇▇██▇█
validation_loss,▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▄▄▄▄▄▃▃▃▃▃▂▃▃▅▃▂▁▃▁▅▂█▂▄▇

0,1
epoch,63.0
train_accuracy,0.91495
train_loss,0.23775
validation_accuracy,0.71739
validation_loss,1.09627


[I 2025-05-02 08:09:26,020] Trial 0 finished with value: 0.7734349171320597 and parameters: {'lr': 0.0005, 'weight_decay': 0.005, 'num_blocks': 1, 'num_heads': 4, 'num_segments': 5}. Best is trial 0 with value: 0.7734349171320597.



===== Trial 1 =====
 lr=5.00e-04, wd=5.00e-03, blocks=1, heads=3, segs=5
Epoch 001 | train_loss=1.0745 acc=0.4249 | val_loss=1.0752 acc=0.4317 | time=18.8s
Epoch 002 | train_loss=1.0752 acc=0.4214 | val_loss=1.0750 acc=0.4317 | time=18.5s
Epoch 003 | train_loss=1.0721 acc=0.4283 | val_loss=1.0749 acc=0.4317 | time=18.7s
Epoch 004 | train_loss=1.0711 acc=0.4249 | val_loss=1.0751 acc=0.4317 | time=18.6s
Epoch 005 | train_loss=1.0708 acc=0.4311 | val_loss=1.0750 acc=0.4317 | time=18.8s
Epoch 006 | train_loss=1.0687 acc=0.4330 | val_loss=1.0757 acc=0.4317 | time=18.6s
Epoch 007 | train_loss=1.0699 acc=0.4268 | val_loss=1.0743 acc=0.4317 | time=18.7s
Epoch 008 | train_loss=1.0691 acc=0.4256 | val_loss=1.0781 acc=0.4317 | time=18.6s
Epoch 009 | train_loss=1.0703 acc=0.4303 | val_loss=1.0740 acc=0.4317 | time=18.6s
Epoch 010 | train_loss=1.0674 acc=0.4334 | val_loss=1.0746 acc=0.4317 | time=18.8s
Epoch 011 | train_loss=1.0630 acc=0.4365 | val_loss=1.0561 acc=0.5543 | time=18.7s
Epoch 012 | t

0,1
epoch,▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
train_accuracy,▁▁▁▁▁▁▁▁▁▁▁▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇████
train_loss,███████████▇▇▇▇▇▆▆▆▆▅▅▅▅▄▄▄▄▃▃▃▃▃▂▂▁▁▁
validation_accuracy,▁▁▁▁▁▁▁▁▁▁▅▅▄▅▆▆▆▆▅▆▇▆▇▇██▇▇█▇█▇▇▇▇██▇
validation_loss,▃▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▂▂▁▂▁▄▆▃▄▃▄▅▆▆█

0,1
epoch,38.0
train_accuracy,0.93864
train_loss,0.16386
validation_accuracy,0.64907
validation_loss,1.80009


[I 2025-05-02 08:21:18,935] Trial 1 finished with value: 0.7810127351965223 and parameters: {'lr': 0.0005, 'weight_decay': 0.005, 'num_blocks': 1, 'num_heads': 3, 'num_segments': 5}. Best is trial 0 with value: 0.7734349171320597.



===== Trial 2 =====
 lr=5.00e-04, wd=5.00e-03, blocks=1, heads=2, segs=5
Epoch 001 | train_loss=1.0776 acc=0.3895 | val_loss=1.0771 acc=0.4317 | time=17.3s
Epoch 002 | train_loss=1.0716 acc=0.4260 | val_loss=1.0753 acc=0.4317 | time=17.4s
Epoch 003 | train_loss=1.0715 acc=0.4241 | val_loss=1.0757 acc=0.4317 | time=17.3s
Epoch 004 | train_loss=1.0728 acc=0.4190 | val_loss=1.0754 acc=0.4317 | time=17.1s
Epoch 005 | train_loss=1.0718 acc=0.4225 | val_loss=1.0742 acc=0.4317 | time=17.0s
Epoch 006 | train_loss=1.0711 acc=0.4233 | val_loss=1.0742 acc=0.4317 | time=17.5s
Epoch 007 | train_loss=1.0675 acc=0.4307 | val_loss=1.0739 acc=0.4317 | time=17.2s
Epoch 008 | train_loss=1.0704 acc=0.4225 | val_loss=1.0780 acc=0.4317 | time=17.1s
Epoch 009 | train_loss=1.0705 acc=0.4276 | val_loss=1.0719 acc=0.4317 | time=17.2s
Epoch 010 | train_loss=1.0512 acc=0.4505 | val_loss=1.0504 acc=0.4984 | time=17.1s
Epoch 011 | train_loss=0.9871 acc=0.5534 | val_loss=0.9778 acc=0.5683 | time=17.0s
Epoch 012 | t

0,1
epoch,▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇███
train_accuracy,▁▂▂▁▂▂▂▂▂▂▄▄▄▅▅▅▅▅▅▅▆▆▆▆▇▇████
train_loss,██████████▇▇▆▆▆▆▅▅▅▅▄▄▄▃▃▂▂▁▁▁
validation_accuracy,▁▁▁▁▁▁▁▁▁▃▅▅▅▆▅▆▆▆▆▆▆▇▇▇▇▇█▇██
validation_loss,█████████▇▄▄▄▃▂▃▂▂▁▁▂▁▄▂▂▁▆▄▃▄

0,1
epoch,30.0
train_accuracy,0.79728
train_loss,0.44936
validation_accuracy,0.67547
validation_loss,0.96939


[I 2025-05-02 08:30:01,256] Trial 2 finished with value: 0.8837438935325259 and parameters: {'lr': 0.0005, 'weight_decay': 0.005, 'num_blocks': 1, 'num_heads': 2, 'num_segments': 5}. Best is trial 0 with value: 0.7734349171320597.



===== Best Trial =====
best_val_loss       = 0.773435
best_train_loss     = 0.505497
best_train_accuracy = 0.7852
best_val_accuracy   = 0.6941
best params:
  lr: 0.0005
  weight_decay: 0.005
  num_blocks: 1
  num_heads: 4
  num_segments: 5


In [None]:
import os
import json
import time
import gc
import multiprocessing

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split
import optuna
import wandb

from eeg_dataset import EEGDataset
from model_optimized import EEGformer

# ─── Grid Search 후보값 ────────────────────────────────────────
LR_CHOICES          = [5e-4]
WD_CHOICES          = [5e-2, 5e-3]
NUM_FILTERS         = 120
NUM_BLOCK_CHOICES   = [1]
NUM_HEAD_CHOICES    = [2, 3]
SEGMENT_CHOICES     = [5]

# ─── Training configuration ───────────────────────────────────
MAX_EPOCHS  = 100
PATIENCE    = 10
BATCH_SIZE  = 32
NUM_WORKERS = max(1, min(4, os.cpu_count() - 1))

# ─── Data paths & device ──────────────────────────────────────
DATA_DIR   = '/content/drive/MyDrive/2025_Lab_Research/model-data'
LABEL_FILE = "labels.json"
DEVICE     = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ─── Scheduler 하이퍼파라미터 ─────────────────────────────────
FIXED_STEP_SIZE = 5
FIXED_GAMMA     = 0.5


def objective_holdout(trial):
    # ─── 1) Grid용 sampling ───────────────────────────────────
    lr           = trial.suggest_categorical("lr", LR_CHOICES)
    weight_decay = trial.suggest_categorical("weight_decay", WD_CHOICES)
    num_blocks   = trial.suggest_categorical("num_blocks", NUM_BLOCK_CHOICES)
    num_heads    = trial.suggest_categorical("num_heads", NUM_HEAD_CHOICES)
    num_segments = trial.suggest_categorical("num_segments", SEGMENT_CHOICES)

    # ─── 2) 데이터 로드 ────────────────────────────────────────
    with open(os.path.join(DATA_DIR, LABEL_FILE), "r") as f:
        all_meta = json.load(f)
    train_meta = [d for d in all_meta if d["type"] == "train"]
    full_ds    = EEGDataset(DATA_DIR, train_meta)
    labels     = [d["label"] for d in train_meta]
    n_samples  = len(full_ds)
    input_length = full_ds[0][0].shape[-1]

    # ─── 3) Hold-out split ─────────────────────────────────────
    train_idx, val_idx = train_test_split(
        list(range(n_samples)),
        test_size=0.2,
        stratify=labels,
        random_state=42
    )
    train_loader = DataLoader(
        Subset(full_ds, train_idx),
        batch_size=BATCH_SIZE, shuffle=True,  num_workers=NUM_WORKERS
    )
    val_loader = DataLoader(
        Subset(full_ds, val_idx),
        batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
    )

    # ─── 4) W&B init ───────────────────────────────────────────
    wandb.init(
        project="eeg-holdout-grid-search-10",
        config={
            "lr": lr,
            "weight_decay": weight_decay,
            "num_blocks": num_blocks,
            "num_heads": num_heads,
            "num_segments": num_segments
        }
    )

    print(f"\n===== Trial {trial.number} =====")
    print(
        f" lr={lr:.2e}, wd={weight_decay:.2e}, "
        f"blocks={num_blocks}, heads={num_heads}, segs={num_segments}"
    )

    # ─── 5) Model / optimizer / loss ──────────────────────────
    model = EEGformer(
        in_channels  = 19,
        input_length = input_length,
        kernel_size  = 10,
        num_filters  = NUM_FILTERS,
        num_heads    = num_heads,
        num_blocks   = num_blocks,
        num_segments = num_segments,
        num_classes  = 3
    ).to(DEVICE)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=weight_decay
    )
    criterion = nn.CrossEntropyLoss()

    # ─── 6) Scheduler ──────────────────────────────────────────
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=FIXED_GAMMA,
        patience=FIXED_STEP_SIZE,
        min_lr=1e-6
    )

    # ─── 7) Training loop w/ Early Stopping & Pruning ─────────
    best_val_loss     = float("inf")
    epochs_no_improve = 0

    best_train_loss = best_train_acc = best_val_acc = None

    for epoch in range(1, MAX_EPOCHS + 1):
        t0 = time.time()

        # — train —
        model.train()
        tloss = tcorrect = ttotal = 0
        for X, y in train_loader:
            X, y = X.to(DEVICE), y.to(DEVICE)
            optimizer.zero_grad()
            logits = model(X)
            loss   = criterion(logits, y)
            loss.backward()
            optimizer.step()

            tloss    += loss.item()
            tcorrect += (logits.argmax(1) == y).sum().item()
            ttotal   += y.size(0)
        train_loss = tloss / len(train_loader)
        train_acc  = tcorrect / ttotal

        # — validate —
        model.eval()
        vloss = vcorrect = vtotal = 0
        with torch.no_grad():
            for X, y in val_loader:
                X, y    = X.to(DEVICE), y.to(DEVICE)
                logits  = model(X)
                loss    = criterion(logits, y)
                vloss   += loss.item()
                vcorrect+= (logits.argmax(1) == y).sum().item()
                vtotal  += y.size(0)
        val_loss = vloss / len(val_loader)
        val_acc  = vcorrect / vtotal
        elapsed  = time.time() - t0

        # — report & pruning check —
        trial.report(val_loss, epoch)
        if trial.should_prune():
            wandb.finish()
            print(f"▸ Trial {trial.number} pruned at epoch {epoch}")
            raise optuna.TrialPruned()

        # — print & log —
        print(
            f"Epoch {epoch:03d} | "
            f"train_loss={train_loss:.4f} acc={train_acc:.4f} | "
            f"val_loss={val_loss:.4f} acc={val_acc:.4f} | "
            f"time={elapsed:.1f}s"
        )
        wandb.log({
            "epoch":               epoch,
            "train_loss":          train_loss,
            "train_accuracy":      train_acc,
            "validation_loss":     val_loss,
            "validation_accuracy": val_acc,
        }, step=epoch)

        scheduler.step(val_loss)

        # — early stopping logic & save best metrics —
        if val_loss < best_val_loss:
            best_val_loss     = val_loss
            epochs_no_improve = 0
            best_train_loss   = train_loss
            best_train_acc    = train_acc
            best_val_acc      = val_acc
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= PATIENCE:
                print(f"★ Early stopping at epoch {epoch}")
                break

    # store best-epoch metrics
    trial.set_user_attr("best_train_loss", best_train_loss)
    trial.set_user_attr("best_train_acc",  best_train_acc)
    trial.set_user_attr("best_val_acc",    best_val_acc)

    wandb.finish()
    gc.collect()
    return best_val_loss


if __name__ == "__main__":
    multiprocessing.freeze_support()

    # ─── GridSampler용 파라미터 그리드 ─────────────────────────
    param_grid = {
        "lr":            LR_CHOICES,
        "weight_decay":  WD_CHOICES,
        "num_blocks":    NUM_BLOCK_CHOICES,
        "num_heads":     NUM_HEAD_CHOICES,
        "num_segments":  SEGMENT_CHOICES,
    }

    sampler = optuna.samplers.GridSampler(param_grid)
    study = optuna.create_study(
        direction="minimize",
        sampler=sampler,
        pruner=optuna.pruners.MedianPruner(n_warmup_steps=5),
        study_name="eeg_holdout_grid_search-10",
        storage="sqlite:////content/drive/MyDrive/2025_Lab_Research/eeg_grid_search-10.db",
        load_if_exists=True
    )
    study.optimize(objective_holdout)  # grid 크기만큼 자동 실행

    # ─── 결과 출력 ─────────────────────────────────────────────
    best = study.best_trial
    print("\n===== Best Trial =====")
    print(f"best_val_loss       = {best.value:.6f}")
    print(f"best_train_loss     = {best.user_attrs['best_train_loss']:.6f}")
    print(f"best_train_accuracy = {best.user_attrs['best_train_acc']:.4f}")
    print(f"best_val_accuracy   = {best.user_attrs['best_val_acc']:.4f}")
    print("best params:")
    for k, v in best.params.items():
        print(f"  {k}: {v}")


[I 2025-05-02 09:27:06,085] A new study created in RDB with name: eeg_holdout_grid_search-10



===== Trial 0 =====
 lr=5.00e-04, wd=5.00e-02, blocks=1, heads=3, segs=5
Epoch 001 | train_loss=1.1809 acc=0.3662 | val_loss=1.0806 acc=0.4317 | time=19.8s
Epoch 002 | train_loss=1.1256 acc=0.3895 | val_loss=1.0773 acc=0.4317 | time=19.4s
Epoch 003 | train_loss=1.1110 acc=0.3880 | val_loss=1.0810 acc=0.4317 | time=19.7s
Epoch 004 | train_loss=1.0968 acc=0.4062 | val_loss=1.0821 acc=0.4317 | time=19.5s
Epoch 005 | train_loss=1.0939 acc=0.3973 | val_loss=1.0792 acc=0.4317 | time=19.7s
Epoch 006 | train_loss=1.0886 acc=0.4058 | val_loss=1.0796 acc=0.4317 | time=19.6s
Epoch 007 | train_loss=1.0853 acc=0.4066 | val_loss=1.0749 acc=0.4317 | time=19.6s
Epoch 008 | train_loss=1.0801 acc=0.4093 | val_loss=1.0729 acc=0.4317 | time=19.6s
Epoch 009 | train_loss=1.0767 acc=0.4078 | val_loss=1.0753 acc=0.4317 | time=19.5s
Epoch 010 | train_loss=1.0774 acc=0.4082 | val_loss=1.0632 acc=0.4317 | time=19.5s
Epoch 011 | train_loss=1.0674 acc=0.4299 | val_loss=1.0396 acc=0.5450 | time=19.6s
Epoch 012 | t

0,1
epoch,▁▁▁▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇███
train_accuracy,▁▁▁▂▂▂▂▂▂▂▂▃▅▅▅▆▅▆▆▆▇▆▇▆▇▇▇▇█▇███
train_loss,█▇▇▇▇▇▇▇▇▇▇▆▅▅▄▄▄▄▃▃▃▃▃▃▂▂▂▁▂▁▁▁▁
validation_accuracy,▁▁▁▁▁▁▁▁▁▁▄▂▃▅▆▆▇▇▇▇██▇▇▆▇▇▇▆▆▆▇▆
validation_loss,██████████▇▇▆▄▃▃▃▃▂▂▂▂▁▃▅▂▃▃▇▃▅▄▃

0,1
epoch,33.0
train_accuracy,0.7134
train_loss,0.6239
validation_accuracy,0.64596
validation_loss,0.78822


[I 2025-05-02 09:37:55,864] Trial 0 finished with value: 0.672920666989826 and parameters: {'lr': 0.0005, 'weight_decay': 0.05, 'num_blocks': 1, 'num_heads': 3, 'num_segments': 5}. Best is trial 0 with value: 0.672920666989826.



===== Trial 1 =====
 lr=5.00e-04, wd=5.00e-03, blocks=1, heads=3, segs=5
Epoch 001 | train_loss=1.1699 acc=0.3674 | val_loss=1.0817 acc=0.4317 | time=19.7s
Epoch 002 | train_loss=1.1243 acc=0.3860 | val_loss=1.0750 acc=0.4317 | time=19.7s
Epoch 003 | train_loss=1.1024 acc=0.3868 | val_loss=1.0765 acc=0.4317 | time=19.7s
Epoch 004 | train_loss=1.0980 acc=0.3872 | val_loss=1.0717 acc=0.4317 | time=19.8s
Epoch 005 | train_loss=1.0787 acc=0.4179 | val_loss=1.0380 acc=0.5217 | time=19.7s
Epoch 006 | train_loss=1.0567 acc=0.4489 | val_loss=0.9920 acc=0.5870 | time=19.8s
Epoch 007 | train_loss=1.0064 acc=0.5126 | val_loss=0.9131 acc=0.6025 | time=19.6s
Epoch 008 | train_loss=0.9598 acc=0.5546 | val_loss=0.8720 acc=0.6382 | time=19.6s
Epoch 009 | train_loss=0.9216 acc=0.5732 | val_loss=0.8418 acc=0.6708 | time=19.6s
Epoch 010 | train_loss=0.9104 acc=0.5930 | val_loss=0.8247 acc=0.6661 | time=19.7s
Epoch 011 | train_loss=0.8814 acc=0.5860 | val_loss=0.7986 acc=0.6957 | time=19.5s
Epoch 012 | t

0,1
epoch,▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇███
train_accuracy,▁▁▁▁▂▂▄▄▅▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇████
train_loss,█▇▇▇▇▇▆▆▅▅▅▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁
validation_accuracy,▁▁▁▁▃▅▅▆▇▆▇███▆▇▇▇▇▇▆▆▆▇▆▇▇▇▇▆▇▇▇█▇▇▇
validation_loss,████▇▆▅▄▄▄▃▃▂▂▃▂▃▂▂▂▃▃▂▁▂▂▁▂▂▃▄▂▃▂▂▃▃

0,1
epoch,37.0
train_accuracy,0.74214
train_loss,0.5626
validation_accuracy,0.70807
validation_loss,0.76616


[I 2025-05-02 09:50:02,248] Trial 1 finished with value: 0.6642108644757952 and parameters: {'lr': 0.0005, 'weight_decay': 0.005, 'num_blocks': 1, 'num_heads': 3, 'num_segments': 5}. Best is trial 1 with value: 0.6642108644757952.



===== Trial 2 =====
 lr=5.00e-04, wd=5.00e-03, blocks=1, heads=2, segs=5
Epoch 001 | train_loss=1.1348 acc=0.3786 | val_loss=1.0839 acc=0.4317 | time=17.5s
Epoch 002 | train_loss=1.1228 acc=0.4039 | val_loss=1.0790 acc=0.4317 | time=17.6s
Epoch 003 | train_loss=1.1083 acc=0.3996 | val_loss=1.0784 acc=0.4317 | time=18.0s
Epoch 004 | train_loss=1.1000 acc=0.3918 | val_loss=1.0705 acc=0.5233 | time=17.7s
Epoch 005 | train_loss=1.0829 acc=0.4190 | val_loss=1.0476 acc=0.4876 | time=17.2s
Epoch 006 | train_loss=1.0358 acc=0.4812 | val_loss=0.9707 acc=0.5807 | time=17.3s
Epoch 007 | train_loss=1.0025 acc=0.5200 | val_loss=0.8998 acc=0.6180 | time=17.3s
Epoch 008 | train_loss=0.9618 acc=0.5612 | val_loss=0.9030 acc=0.6149 | time=17.2s
Epoch 009 | train_loss=0.9150 acc=0.5670 | val_loss=0.8771 acc=0.6227 | time=17.4s
Epoch 010 | train_loss=0.9103 acc=0.5709 | val_loss=0.8627 acc=0.6102 | time=17.5s
Epoch 011 | train_loss=0.8859 acc=0.5880 | val_loss=0.8849 acc=0.5761 | time=17.3s
Epoch 012 | t

0,1
epoch,▁▁▂▂▂▂▃▃▃▄▄▄▄▅▅▅▅▆▆▆▇▇▇▇██
train_accuracy,▁▂▁▁▂▃▄▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇███
train_loss,███▇▇▇▆▅▅▅▄▄▄▃▃▃▃▂▂▂▂▂▂▁▁▁
validation_accuracy,▁▁▁▃▂▅▆▆▆▆▅▇▆▆▇█▆▇▇▆▇▄▇▇▆▆
validation_loss,████▇▆▅▅▄▄▄▃▃▃▂▁▃▂▂▇▃▆▅▅▅▆

0,1
epoch,26.0
train_accuracy,0.70058
train_loss,0.65527
validation_accuracy,0.64441
validation_loss,0.97048


[I 2025-05-02 09:57:36,608] Trial 2 finished with value: 0.7074131766955057 and parameters: {'lr': 0.0005, 'weight_decay': 0.005, 'num_blocks': 1, 'num_heads': 2, 'num_segments': 5}. Best is trial 1 with value: 0.6642108644757952.



===== Trial 3 =====
 lr=5.00e-04, wd=5.00e-02, blocks=1, heads=2, segs=5
Epoch 001 | train_loss=1.1706 acc=0.3825 | val_loss=1.0724 acc=0.4317 | time=17.3s
Epoch 002 | train_loss=1.1569 acc=0.3581 | val_loss=1.0840 acc=0.4767 | time=17.5s
Epoch 003 | train_loss=1.1227 acc=0.3977 | val_loss=1.0793 acc=0.4317 | time=17.3s
Epoch 004 | train_loss=1.1076 acc=0.3872 | val_loss=1.0732 acc=0.4317 | time=17.4s
Epoch 005 | train_loss=1.0906 acc=0.3977 | val_loss=1.0726 acc=0.4317 | time=17.3s
Epoch 006 | train_loss=1.0905 acc=0.3984 | val_loss=1.0611 acc=0.4581 | time=17.5s
Epoch 007 | train_loss=1.0593 acc=0.4505 | val_loss=1.0015 acc=0.5870 | time=17.0s
Epoch 008 | train_loss=0.9978 acc=0.5227 | val_loss=0.9774 acc=0.6009 | time=17.7s
Epoch 009 | train_loss=0.9568 acc=0.5550 | val_loss=0.8785 acc=0.6289 | time=17.4s
Epoch 010 | train_loss=0.9202 acc=0.5685 | val_loss=0.8944 acc=0.6382 | time=17.2s
Epoch 011 | train_loss=0.9090 acc=0.5724 | val_loss=0.8257 acc=0.6366 | time=17.1s
Epoch 012 | t

0,1
epoch,▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇███
train_accuracy,▁▁▂▂▂▂▃▄▅▅▅▅▅▅▅▆▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇█▇██
train_loss,██▇▇▇▇▇▆▆▅▅▄▄▄▄▄▄▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁
validation_accuracy,▁▂▁▁▁▂▅▅▆▆▆▆▄▆▇▇▇▅▇▇█▇▆█▇██▇▇▄█▆▆▆▅▇▇
validation_loss,▅▅▅▅▅▅▄▄▃▃▂▂▄▃▂▁▁▃▁▂▁▁▄▂▂▃▁▃▅▄▅▅▇▅█▅▅

0,1
epoch,37.0
train_accuracy,0.73398
train_loss,0.5715
validation_accuracy,0.66615
validation_loss,1.01655


[I 2025-05-02 10:08:24,202] Trial 3 finished with value: 0.7243047186306545 and parameters: {'lr': 0.0005, 'weight_decay': 0.05, 'num_blocks': 1, 'num_heads': 2, 'num_segments': 5}. Best is trial 1 with value: 0.6642108644757952.



===== Best Trial =====
best_val_loss       = 0.664211
best_train_loss     = 0.674144
best_train_accuracy = 0.6862
best_val_accuracy   = 0.7003
best params:
  lr: 0.0005
  weight_decay: 0.005
  num_blocks: 1
  num_heads: 3
  num_segments: 5


In [None]:
import os
import json
import time
import gc
import multiprocessing

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split
import optuna
import wandb

from eeg_dataset import EEGDataset
from model_optimized_2 import EEGformer

# ─── Grid Search 후보값 ────────────────────────────────────────
LR_CHOICES          = [5e-4]
WD_CHOICES          = [5e-2, 5e-3]
NUM_FILTERS         = 120
NUM_BLOCK_CHOICES   = [1]
NUM_HEAD_CHOICES    = [2, 3]
SEGMENT_CHOICES     = [5]

# ─── Training configuration ───────────────────────────────────
MAX_EPOCHS  = 100
PATIENCE    = 10
BATCH_SIZE  = 32
NUM_WORKERS = max(1, min(4, os.cpu_count() - 1))

# ─── Data paths & device ──────────────────────────────────────
DATA_DIR   = '/content/drive/MyDrive/2025_Lab_Research/model-data'
LABEL_FILE = "labels.json"
DEVICE     = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ─── Scheduler 하이퍼파라미터 ─────────────────────────────────
FIXED_STEP_SIZE = 5
FIXED_GAMMA     = 0.5


def objective_holdout(trial):
    # ─── 1) Grid용 sampling ───────────────────────────────────
    lr           = trial.suggest_categorical("lr", LR_CHOICES)
    weight_decay = trial.suggest_categorical("weight_decay", WD_CHOICES)
    num_blocks   = trial.suggest_categorical("num_blocks", NUM_BLOCK_CHOICES)
    num_heads    = trial.suggest_categorical("num_heads", NUM_HEAD_CHOICES)
    num_segments = trial.suggest_categorical("num_segments", SEGMENT_CHOICES)

    # ─── 2) 데이터 로드 ────────────────────────────────────────
    with open(os.path.join(DATA_DIR, LABEL_FILE), "r") as f:
        all_meta = json.load(f)
    train_meta = [d for d in all_meta if d["type"] == "train"]
    full_ds    = EEGDataset(DATA_DIR, train_meta)
    labels     = [d["label"] for d in train_meta]
    n_samples  = len(full_ds)
    input_length = full_ds[0][0].shape[-1]

    # ─── 3) Hold-out split ─────────────────────────────────────
    train_idx, val_idx = train_test_split(
        list(range(n_samples)),
        test_size=0.2,
        stratify=labels,
        random_state=42
    )
    train_loader = DataLoader(
        Subset(full_ds, train_idx),
        batch_size=BATCH_SIZE, shuffle=True,  num_workers=NUM_WORKERS
    )
    val_loader = DataLoader(
        Subset(full_ds, val_idx),
        batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
    )

    # ─── 4) W&B init ───────────────────────────────────────────
    wandb.init(
        project="eeg-holdout-grid-search-11",
        config={
            "lr": lr,
            "weight_decay": weight_decay,
            "num_blocks": num_blocks,
            "num_heads": num_heads,
            "num_segments": num_segments
        }
    )

    print(f"\n===== Trial {trial.number} =====")
    print(
        f" lr={lr:.2e}, wd={weight_decay:.2e}, "
        f"blocks={num_blocks}, heads={num_heads}, segs={num_segments}"
    )

    # ─── 5) Model / optimizer / loss ──────────────────────────
    model = EEGformer(
        in_channels  = 19,
        input_length = input_length,
        kernel_size  = 10,
        num_filters  = NUM_FILTERS,
        num_heads    = num_heads,
        num_blocks   = num_blocks,
        num_segments = num_segments,
        num_classes  = 3
    ).to(DEVICE)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=weight_decay
    )
    criterion = nn.CrossEntropyLoss()

    # ─── 6) Scheduler ──────────────────────────────────────────
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=FIXED_GAMMA,
        patience=FIXED_STEP_SIZE,
        min_lr=1e-6
    )

    # ─── 7) Training loop w/ Early Stopping & Pruning ─────────
    best_val_loss     = float("inf")
    epochs_no_improve = 0

    best_train_loss = best_train_acc = best_val_acc = None

    for epoch in range(1, MAX_EPOCHS + 1):
        t0 = time.time()

        # — train —
        model.train()
        tloss = tcorrect = ttotal = 0
        for X, y in train_loader:
            X, y = X.to(DEVICE), y.to(DEVICE)
            optimizer.zero_grad()
            logits = model(X)
            loss   = criterion(logits, y)
            loss.backward()
            optimizer.step()

            tloss    += loss.item()
            tcorrect += (logits.argmax(1) == y).sum().item()
            ttotal   += y.size(0)
        train_loss = tloss / len(train_loader)
        train_acc  = tcorrect / ttotal

        # — validate —
        model.eval()
        vloss = vcorrect = vtotal = 0
        with torch.no_grad():
            for X, y in val_loader:
                X, y    = X.to(DEVICE), y.to(DEVICE)
                logits  = model(X)
                loss    = criterion(logits, y)
                vloss   += loss.item()
                vcorrect+= (logits.argmax(1) == y).sum().item()
                vtotal  += y.size(0)
        val_loss = vloss / len(val_loader)
        val_acc  = vcorrect / vtotal
        elapsed  = time.time() - t0

        # — report & pruning check —
        trial.report(val_loss, epoch)
        if trial.should_prune():
            wandb.finish()
            print(f"▸ Trial {trial.number} pruned at epoch {epoch}")
            raise optuna.TrialPruned()

        # — print & log —
        print(
            f"Epoch {epoch:03d} | "
            f"train_loss={train_loss:.4f} acc={train_acc:.4f} | "
            f"val_loss={val_loss:.4f} acc={val_acc:.4f} | "
            f"time={elapsed:.1f}s"
        )
        wandb.log({
            "epoch":               epoch,
            "train_loss":          train_loss,
            "train_accuracy":      train_acc,
            "validation_loss":     val_loss,
            "validation_accuracy": val_acc,
        }, step=epoch)

        scheduler.step(val_loss)

        # — early stopping logic & save best metrics —
        if val_loss < best_val_loss:
            best_val_loss     = val_loss
            epochs_no_improve = 0
            best_train_loss   = train_loss
            best_train_acc    = train_acc
            best_val_acc      = val_acc
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= PATIENCE:
                print(f"★ Early stopping at epoch {epoch}")
                break

    # store best-epoch metrics
    trial.set_user_attr("best_train_loss", best_train_loss)
    trial.set_user_attr("best_train_acc",  best_train_acc)
    trial.set_user_attr("best_val_acc",    best_val_acc)

    wandb.finish()
    gc.collect()
    return best_val_loss


if __name__ == "__main__":
    multiprocessing.freeze_support()

    # ─── GridSampler용 파라미터 그리드 ─────────────────────────
    param_grid = {
        "lr":            LR_CHOICES,
        "weight_decay":  WD_CHOICES,
        "num_blocks":    NUM_BLOCK_CHOICES,
        "num_heads":     NUM_HEAD_CHOICES,
        "num_segments":  SEGMENT_CHOICES,
    }

    sampler = optuna.samplers.GridSampler(param_grid)
    study = optuna.create_study(
        direction="minimize",
        sampler=sampler,
        pruner=optuna.pruners.MedianPruner(n_warmup_steps=5),
        study_name="eeg_holdout_grid_search-11",
        storage="sqlite:////content/drive/MyDrive/2025_Lab_Research/eeg_grid_search-11.db",
        load_if_exists=True
    )
    study.optimize(objective_holdout)  # grid 크기만큼 자동 실행

    # ─── 결과 출력 ─────────────────────────────────────────────
    best = study.best_trial
    print("\n===== Best Trial =====")
    print(f"best_val_loss       = {best.value:.6f}")
    print(f"best_train_loss     = {best.user_attrs['best_train_loss']:.6f}")
    print(f"best_train_accuracy = {best.user_attrs['best_train_acc']:.4f}")
    print(f"best_val_accuracy   = {best.user_attrs['best_val_acc']:.4f}")
    print("best params:")
    for k, v in best.params.items():
        print(f"  {k}: {v}")


Now using CUDA device 0
Enabling CUDA with 39.14 GiB available memory


[I 2025-05-02 10:54:18,347] Using an existing study with name 'eeg_holdout_grid_search-11' instead of creating a new one.
[W 2025-05-02 10:54:18,405] `GridSampler` is re-evaluating a configuration because the grid has been exhausted. This may happen due to a timing issue during distributed optimization or when re-running optimizations on already finished studies.



===== Trial 4 =====
 lr=5.00e-04, wd=5.00e-02, blocks=1, heads=2, segs=5
Epoch 001 | train_loss=1.1692 acc=0.3297 | val_loss=1.0784 acc=0.4317 | time=17.3s
Epoch 002 | train_loss=1.1136 acc=0.3748 | val_loss=1.0748 acc=0.4317 | time=17.2s
Epoch 003 | train_loss=1.0828 acc=0.4202 | val_loss=1.0749 acc=0.4317 | time=17.0s
Epoch 004 | train_loss=1.0910 acc=0.4016 | val_loss=1.0742 acc=0.4317 | time=17.4s
Epoch 005 | train_loss=1.0854 acc=0.4155 | val_loss=1.0745 acc=0.4317 | time=17.6s
Epoch 006 | train_loss=1.0799 acc=0.4120 | val_loss=1.0742 acc=0.4317 | time=17.0s
Epoch 007 | train_loss=1.0787 acc=0.4117 | val_loss=1.0740 acc=0.4317 | time=17.1s
Epoch 008 | train_loss=1.0791 acc=0.4163 | val_loss=1.0740 acc=0.4317 | time=17.1s
Epoch 009 | train_loss=1.0709 acc=0.4186 | val_loss=1.0744 acc=0.4317 | time=17.0s
Epoch 010 | train_loss=1.0924 acc=0.3973 | val_loss=1.0781 acc=0.4317 | time=17.0s
Epoch 011 | train_loss=1.0766 acc=0.4183 | val_loss=1.0779 acc=0.4317 | time=16.8s
Epoch 012 | t

0,1
epoch,▁▁▁▁▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇███
train_accuracy,▁▂▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▄▅▆▆▆▇▇▇▇▇▇▇▇▇██████
train_loss,█▇▇▇▇▇▇▇▇▇▇▇▇▆▇▇▆▆▆▆▆▆▅▅▅▄▄▃▃▃▃▃▂▃▂▁▁▁▁▁
validation_accuracy,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▃▆▆▇▇▇▇██▆▄▇▆█▆▇▇▇█▇▇
validation_loss,███████████████████▇▇▆▅▅▄▄▆▄▂▁▅▃▂▁▃▂▂▃▃▂

0,1
epoch,57.0
train_accuracy,0.63456
train_loss,0.70932
validation_accuracy,0.64441
validation_loss,0.82578


[I 2025-05-02 11:10:38,217] Trial 4 finished with value: 0.7751180188996452 and parameters: {'lr': 0.0005, 'weight_decay': 0.05, 'num_blocks': 1, 'num_heads': 2, 'num_segments': 5}. Best is trial 4 with value: 0.7751180188996452.



===== Best Trial =====
best_val_loss       = 0.775118
best_train_loss     = 0.795321
best_train_accuracy = 0.6093
best_val_accuracy   = 0.6739
best params:
  lr: 0.0005
  weight_decay: 0.05
  num_blocks: 1
  num_heads: 2
  num_segments: 5


In [None]:
import os
import json
import time
import gc
import multiprocessing

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split
import optuna
import wandb

from eeg_dataset import EEGDataset
from model_optimized_2 import EEGformer

# ─── Grid Search 후보값 ────────────────────────────────────────
LR_CHOICES          = [5e-4]
WD_CHOICES          = [5e-2, 5e-3]
NUM_FILTERS         = 120
NUM_BLOCK_CHOICES   = [1]
NUM_HEAD_CHOICES    = [2, 3, 4]
SEGMENT_CHOICES     = [5]

# ─── Training configuration ───────────────────────────────────
MAX_EPOCHS  = 100
PATIENCE    = 10
BATCH_SIZE  = 32
NUM_WORKERS = max(1, min(4, os.cpu_count() - 1))

# ─── Data paths & device ──────────────────────────────────────
DATA_DIR   = '/content/drive/MyDrive/2025_Lab_Research/model-data'
LABEL_FILE = "labels.json"
DEVICE     = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ─── Scheduler 하이퍼파라미터 ─────────────────────────────────
FIXED_STEP_SIZE = 5
FIXED_GAMMA     = 0.5


def objective_holdout(trial):
    # ─── 1) Grid용 sampling ───────────────────────────────────
    lr           = trial.suggest_categorical("lr", LR_CHOICES)
    weight_decay = trial.suggest_categorical("weight_decay", WD_CHOICES)
    num_blocks   = trial.suggest_categorical("num_blocks", NUM_BLOCK_CHOICES)
    num_heads    = trial.suggest_categorical("num_heads", NUM_HEAD_CHOICES)
    num_segments = trial.suggest_categorical("num_segments", SEGMENT_CHOICES)

    # ─── 2) 데이터 로드 ────────────────────────────────────────
    with open(os.path.join(DATA_DIR, LABEL_FILE), "r") as f:
        all_meta = json.load(f)
    train_meta = [d for d in all_meta if d["type"] == "train"]
    full_ds    = EEGDataset(DATA_DIR, train_meta)
    labels     = [d["label"] for d in train_meta]
    n_samples  = len(full_ds)
    input_length = full_ds[0][0].shape[-1]

    # ─── 3) Hold-out split ─────────────────────────────────────
    train_idx, val_idx = train_test_split(
        list(range(n_samples)),
        test_size=0.2,
        stratify=labels,
        random_state=42
    )
    train_loader = DataLoader(
        Subset(full_ds, train_idx),
        batch_size=BATCH_SIZE, shuffle=True,  num_workers=NUM_WORKERS
    )
    val_loader = DataLoader(
        Subset(full_ds, val_idx),
        batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
    )

    # ─── 4) W&B init ───────────────────────────────────────────
    wandb.init(
        project="eeg-holdout-grid-search-16",
        config={
            "lr": lr,
            "weight_decay": weight_decay,
            "num_blocks": num_blocks,
            "num_heads": num_heads,
            "num_segments": num_segments
        }
    )

    print(f"\n===== Trial {trial.number} =====")
    print(
        f" lr={lr:.2e}, wd={weight_decay:.2e}, "
        f"blocks={num_blocks}, heads={num_heads}, segs={num_segments}"
    )

    # ─── 5) Model / optimizer / loss ──────────────────────────
    model = EEGformer(
        in_channels  = 19,
        input_length = input_length,
        kernel_size  = 10,
        num_filters  = NUM_FILTERS,
        num_heads    = num_heads,
        num_blocks   = num_blocks,
        num_segments = num_segments,
        num_classes  = 3
    ).to(DEVICE)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=weight_decay
    )
    criterion = nn.CrossEntropyLoss()

    # ─── 6) Scheduler ──────────────────────────────────────────
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=FIXED_GAMMA,
        patience=FIXED_STEP_SIZE,
        min_lr=1e-6
    )

    # ─── 7) Training loop w/ Early Stopping & Pruning ─────────
    best_val_loss     = float("inf")
    epochs_no_improve = 0

    best_train_loss = best_train_acc = best_val_acc = None

    for epoch in range(1, MAX_EPOCHS + 1):
        t0 = time.time()

        # — train —
        model.train()
        tloss = tcorrect = ttotal = 0
        for X, y in train_loader:
            X, y = X.to(DEVICE), y.to(DEVICE)
            optimizer.zero_grad()
            logits = model(X)
            loss   = criterion(logits, y)
            loss.backward()
            optimizer.step()

            tloss    += loss.item()
            tcorrect += (logits.argmax(1) == y).sum().item()
            ttotal   += y.size(0)
        train_loss = tloss / len(train_loader)
        train_acc  = tcorrect / ttotal

        # — validate —
        model.eval()
        vloss = vcorrect = vtotal = 0
        with torch.no_grad():
            for X, y in val_loader:
                X, y    = X.to(DEVICE), y.to(DEVICE)
                logits  = model(X)
                loss    = criterion(logits, y)
                vloss   += loss.item()
                vcorrect+= (logits.argmax(1) == y).sum().item()
                vtotal  += y.size(0)
        val_loss = vloss / len(val_loader)
        val_acc  = vcorrect / vtotal
        elapsed  = time.time() - t0

        # — report & pruning check —
        trial.report(val_loss, epoch)
        if trial.should_prune():
            wandb.finish()
            print(f"▸ Trial {trial.number} pruned at epoch {epoch}")
            raise optuna.TrialPruned()

        # — print & log —
        print(
            f"Epoch {epoch:03d} | "
            f"train_loss={train_loss:.4f} acc={train_acc:.4f} | "
            f"val_loss={val_loss:.4f} acc={val_acc:.4f} | "
            f"time={elapsed:.1f}s"
        )
        wandb.log({
            "epoch":               epoch,
            "train_loss":          train_loss,
            "train_accuracy":      train_acc,
            "validation_loss":     val_loss,
            "validation_accuracy": val_acc,
        }, step=epoch)

        scheduler.step(val_loss)

        # — early stopping logic & save best metrics —
        if val_loss < best_val_loss:
            best_val_loss     = val_loss
            epochs_no_improve = 0
            best_train_loss   = train_loss
            best_train_acc    = train_acc
            best_val_acc      = val_acc
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= PATIENCE:
                print(f"★ Early stopping at epoch {epoch}")
                break

    # store best-epoch metrics
    trial.set_user_attr("best_train_loss", best_train_loss)
    trial.set_user_attr("best_train_acc",  best_train_acc)
    trial.set_user_attr("best_val_acc",    best_val_acc)

    wandb.finish()
    torch.cuda.empty_cache()
    gc.collect()
    return best_val_loss


if __name__ == "__main__":
    multiprocessing.freeze_support()

    # ─── GridSampler용 파라미터 그리드 ─────────────────────────
    param_grid = {
        "lr":            LR_CHOICES,
        "weight_decay":  WD_CHOICES,
        "num_blocks":    NUM_BLOCK_CHOICES,
        "num_heads":     NUM_HEAD_CHOICES,
        "num_segments":  SEGMENT_CHOICES,
    }

    sampler = optuna.samplers.GridSampler(param_grid)
    study = optuna.create_study(
        direction="minimize",
        sampler=sampler,
        pruner=optuna.pruners.MedianPruner(n_warmup_steps=5),
        study_name="eeg_holdout_grid_search-16",
        storage="sqlite:////content/drive/MyDrive/2025_Lab_Research/eeg_grid_search-16.db",
        load_if_exists=True
    )
    study.optimize(objective_holdout)  # grid 크기만큼 자동 실행

    # ─── 결과 출력 ─────────────────────────────────────────────
    best = study.best_trial
    print("\n===== Best Trial =====")
    print(f"best_val_loss       = {best.value:.6f}")
    print(f"best_train_loss     = {best.user_attrs['best_train_loss']:.6f}")
    print(f"best_train_accuracy = {best.user_attrs['best_train_acc']:.4f}")
    print(f"best_val_accuracy   = {best.user_attrs['best_val_acc']:.4f}")
    print("best params:")
    for k, v in best.params.items():
        print(f"  {k}: {v}")


Attempting to create new mne-python configuration file:
/root/.mne/mne-python.json
Now using CUDA device 0
Enabling CUDA with 39.14 GiB available memory


[I 2025-05-02 11:56:30,533] A new study created in RDB with name: eeg_holdout_grid_search-16



===== Trial 0 =====
 lr=5.00e-04, wd=5.00e-03, blocks=1, heads=4, segs=5
Epoch 001 | train_loss=1.1536 acc=0.3783 | val_loss=1.0772 acc=0.4317 | time=242.5s
Epoch 002 | train_loss=1.1305 acc=0.3829 | val_loss=1.0801 acc=0.4317 | time=21.3s
Epoch 003 | train_loss=1.1068 acc=0.3899 | val_loss=1.0782 acc=0.4317 | time=21.3s
Epoch 004 | train_loss=1.0975 acc=0.4004 | val_loss=1.0767 acc=0.4317 | time=21.1s
Epoch 005 | train_loss=1.0883 acc=0.4082 | val_loss=1.0784 acc=0.4317 | time=21.1s
Epoch 006 | train_loss=1.0875 acc=0.4093 | val_loss=1.0786 acc=0.4317 | time=21.3s
Epoch 007 | train_loss=1.0885 acc=0.4132 | val_loss=1.0779 acc=0.4317 | time=21.1s
Epoch 008 | train_loss=1.0757 acc=0.4167 | val_loss=1.0757 acc=0.4317 | time=21.2s
Epoch 009 | train_loss=1.0830 acc=0.4132 | val_loss=1.0702 acc=0.4317 | time=21.1s
Epoch 010 | train_loss=1.0785 acc=0.4237 | val_loss=1.0759 acc=0.4317 | time=21.1s
Epoch 011 | train_loss=1.0736 acc=0.4210 | val_loss=1.0739 acc=0.4317 | time=21.2s
Epoch 012 | 

0,1
epoch,▁▁▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇██
train_accuracy,▁▂▃▄▅▅▆▆▆▇▇▆▇▇▇█▇▇█
train_loss,█▆▄▃▃▃▃▂▂▂▁▂▁▁▂▁▁▁▁
validation_accuracy,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation_loss,▆█▇▆▇▇▆▅▁▅▄▃▄▄▄▃▄▃▂

0,1
epoch,19.0
train_accuracy,0.42913
train_loss,1.0678
validation_accuracy,0.43168
validation_loss,1.07131


[I 2025-05-02 12:07:15,745] Trial 0 finished with value: 1.0702224600882757 and parameters: {'lr': 0.0005, 'weight_decay': 0.005, 'num_blocks': 1, 'num_heads': 4, 'num_segments': 5}. Best is trial 0 with value: 1.0702224600882757.



===== Trial 1 =====
 lr=5.00e-04, wd=5.00e-02, blocks=1, heads=3, segs=5
Epoch 001 | train_loss=1.1664 acc=0.3903 | val_loss=1.0771 acc=0.4317 | time=18.2s
Epoch 002 | train_loss=1.1262 acc=0.3895 | val_loss=1.0758 acc=0.4317 | time=18.1s
Epoch 003 | train_loss=1.1338 acc=0.3860 | val_loss=1.0729 acc=0.4317 | time=18.0s
Epoch 004 | train_loss=1.1194 acc=0.3837 | val_loss=1.0957 acc=0.4317 | time=18.3s
Epoch 005 | train_loss=1.1117 acc=0.3942 | val_loss=1.0761 acc=0.4317 | time=18.0s
Epoch 006 | train_loss=1.0970 acc=0.4136 | val_loss=1.0777 acc=0.4317 | time=18.1s
Epoch 007 | train_loss=1.0934 acc=0.3950 | val_loss=1.0851 acc=0.4317 | time=18.0s
Epoch 008 | train_loss=1.0803 acc=0.4245 | val_loss=1.0821 acc=0.4317 | time=18.1s
Epoch 009 | train_loss=1.0826 acc=0.4120 | val_loss=1.0854 acc=0.4317 | time=18.3s
Epoch 010 | train_loss=1.0828 acc=0.4070 | val_loss=1.0827 acc=0.4317 | time=18.1s
Epoch 011 | train_loss=1.0770 acc=0.4206 | val_loss=1.0773 acc=0.4317 | time=18.2s
Epoch 012 | t

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_accuracy,▁▁▁▁▂▂▂▂▂▃▃▄▄▄▅▅▆▆▆▆▆▆▆▇▆▇▇▇▇▇▇▇▇▇▇█████
train_loss,█▇▇▇▇▇▇▇▇▆▆▆▅▅▅▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▂▁▁▁
validation_accuracy,▁▁▁▁▁▁▁▁▁▄▅▅▆▆▆▅▅▆▇▇▇▇▅▇▆▇▇▆▇▆▆▇▆▇▇▆██▇█
validation_loss,█████████▇▆▆▅▆▆▅▅▅▃▃▃▂▂▄▂▂▂▁▂▂▂▁▂▁▁▂▁▁▂▁

0,1
epoch,50.0
train_accuracy,0.65126
train_loss,0.71699
validation_accuracy,0.68012
validation_loss,0.7604


[I 2025-05-02 12:22:25,554] Trial 1 finished with value: 0.7499674530256362 and parameters: {'lr': 0.0005, 'weight_decay': 0.05, 'num_blocks': 1, 'num_heads': 3, 'num_segments': 5}. Best is trial 1 with value: 0.7499674530256362.



===== Trial 2 =====
 lr=5.00e-04, wd=5.00e-03, blocks=1, heads=2, segs=5
Epoch 001 | train_loss=1.2016 acc=0.3250 | val_loss=1.0832 acc=0.3416 | time=17.0s
Epoch 002 | train_loss=1.1391 acc=0.3740 | val_loss=1.0831 acc=0.4317 | time=17.1s
Epoch 003 | train_loss=1.1150 acc=0.3903 | val_loss=1.0805 acc=0.3416 | time=17.0s
Epoch 004 | train_loss=1.1150 acc=0.3841 | val_loss=1.0790 acc=0.4317 | time=16.7s
Epoch 005 | train_loss=1.1010 acc=0.3992 | val_loss=1.0783 acc=0.4317 | time=16.9s
Epoch 006 | train_loss=1.0900 acc=0.3992 | val_loss=1.0753 acc=0.4317 | time=16.9s
Epoch 007 | train_loss=1.0879 acc=0.4171 | val_loss=1.0770 acc=0.4317 | time=17.3s
Epoch 008 | train_loss=1.0843 acc=0.4058 | val_loss=1.0749 acc=0.4317 | time=17.0s
Epoch 009 | train_loss=1.0745 acc=0.4295 | val_loss=1.0738 acc=0.4317 | time=17.2s
Epoch 010 | train_loss=1.0767 acc=0.4128 | val_loss=1.0728 acc=0.4317 | time=17.1s
Epoch 011 | train_loss=1.0813 acc=0.4144 | val_loss=1.0702 acc=0.4317 | time=17.0s
Epoch 012 | t

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇████
train_accuracy,▁▂▂▂▃▃▃▃▃▃▄▄▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇███████
train_loss,█▇▇▇▇▆▆▆▆▆▆▆▅▅▅▄▄▄▄▄▃▃▃▃▃▃▂▃▂▂▂▂▂▂▂▂▁▁▁▁
validation_accuracy,▁▃▁▃▃▃▃▃▃▃▄▅▇▇▇█▇█▇███▆█▇█▇▇█▇████▇▆▇██▇
validation_loss,███████████▆▆▅▄▄▄▂▃▂▂▂▄▁▃▁▃▂▁▂▂▁▁▁▂▃▄▁▁▃

0,1
epoch,46.0
train_accuracy,0.64272
train_loss,0.70944
validation_accuracy,0.56832
validation_loss,0.90963


[I 2025-05-02 12:35:30,751] Trial 2 finished with value: 0.7560889039720807 and parameters: {'lr': 0.0005, 'weight_decay': 0.005, 'num_blocks': 1, 'num_heads': 2, 'num_segments': 5}. Best is trial 1 with value: 0.7499674530256362.



===== Trial 3 =====
 lr=5.00e-04, wd=5.00e-03, blocks=1, heads=3, segs=5
Epoch 001 | train_loss=1.1793 acc=0.3600 | val_loss=1.0827 acc=0.3416 | time=18.1s
Epoch 002 | train_loss=1.1415 acc=0.3767 | val_loss=1.0948 acc=0.3416 | time=18.1s
Epoch 003 | train_loss=1.1208 acc=0.3751 | val_loss=1.0772 acc=0.3416 | time=18.1s
Epoch 004 | train_loss=1.1144 acc=0.3837 | val_loss=1.0887 acc=0.3416 | time=18.2s
Epoch 005 | train_loss=1.1074 acc=0.3938 | val_loss=1.0837 acc=0.3416 | time=18.2s
Epoch 006 | train_loss=1.0895 acc=0.4148 | val_loss=1.0789 acc=0.4317 | time=18.3s
Epoch 007 | train_loss=1.0934 acc=0.4117 | val_loss=1.0722 acc=0.4332 | time=18.2s
Epoch 008 | train_loss=1.0885 acc=0.4070 | val_loss=1.0745 acc=0.4317 | time=18.0s
Epoch 009 | train_loss=1.0720 acc=0.4194 | val_loss=1.0560 acc=0.5202 | time=18.3s
Epoch 010 | train_loss=1.0823 acc=0.4221 | val_loss=1.0732 acc=0.4053 | time=18.1s
Epoch 011 | train_loss=1.0627 acc=0.4482 | val_loss=1.0232 acc=0.5621 | time=18.3s
Epoch 012 | t

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
train_accuracy,▁▁▁▂▂▂▂▂▂▃▄▅▅▅▅▆▆▆▆▆▆▇▆▇▇▇▇▇▇▇▇▇▇▇▇█████
train_loss,██▇▇▇▇▇▇▇▆▆▅▅▅▅▅▄▄▄▄▄▃▄▃▃▂▃▂▂▂▂▂▂▂▂▁▂▁▁▁
validation_accuracy,▁▁▁▁▁▃▃▃▂▅▆▆▆▇▇▇▆▆▇▆█▇██▇▇█▇▇▇▆██▆▄▇▅▇▇▇
validation_loss,▅▆▅▅▅▅▅▅▅▅▅▄▄▃▃▃▃▄▂▃▂▂▁▂▂▂▁▃▂▂▄▁▂▅▆▂█▃▃▃

0,1
epoch,45.0
train_accuracy,0.68621
train_loss,0.64485
validation_accuracy,0.64286
validation_loss,0.86116


[I 2025-05-02 12:49:13,239] Trial 3 finished with value: 0.7206419763110933 and parameters: {'lr': 0.0005, 'weight_decay': 0.005, 'num_blocks': 1, 'num_heads': 3, 'num_segments': 5}. Best is trial 3 with value: 0.7206419763110933.



===== Trial 4 =====
 lr=5.00e-04, wd=5.00e-02, blocks=1, heads=2, segs=5
Epoch 001 | train_loss=1.1965 acc=0.3169 | val_loss=1.0846 acc=0.4317 | time=17.4s
Epoch 002 | train_loss=1.1646 acc=0.3196 | val_loss=1.0996 acc=0.2236 | time=17.7s
Epoch 003 | train_loss=1.1237 acc=0.3565 | val_loss=1.0868 acc=0.4317 | time=17.7s
Epoch 004 | train_loss=1.1024 acc=0.3833 | val_loss=1.0875 acc=0.4317 | time=17.4s
Epoch 005 | train_loss=1.0874 acc=0.4074 | val_loss=1.0760 acc=0.4317 | time=17.4s
Epoch 006 | train_loss=1.0795 acc=0.4186 | val_loss=1.0225 acc=0.5730 | time=17.6s
Epoch 007 | train_loss=1.0320 acc=0.4885 | val_loss=1.0213 acc=0.6025 | time=17.4s
Epoch 008 | train_loss=1.0016 acc=0.5231 | val_loss=0.9806 acc=0.6227 | time=17.4s
Epoch 009 | train_loss=0.9736 acc=0.5503 | val_loss=0.9631 acc=0.6196 | time=17.1s
Epoch 010 | train_loss=0.9499 acc=0.5631 | val_loss=0.8848 acc=0.6211 | time=17.6s
Epoch 011 | train_loss=0.9267 acc=0.5845 | val_loss=0.8956 acc=0.6273 | time=17.6s
Epoch 012 | t

0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_accuracy,▁▁▂▂▃▃▄▅▅▆▆▆▆▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇████
train_loss,██▇▇▇▇▆▆▅▅▅▅▄▄▄▄▃▃▃▃▃▃▃▃▂▂▂▂▃▂▂▂▂▂▂▂▁▁▁▁
validation_accuracy,▄▁▄▄▄▆▇▇▇▇▇▇▆▇▇▇▇▇▇▇▇▇▇▇▇██▇████▇█▇▇▇▇▇▇
validation_loss,█████▇▇▆▆▄▄▄▄▄▄▄▂▃▂▂▃▂▂▃▃▂▂▁▃▁▂▃▆▄▄▃▅▆▄▃

0,1
epoch,40.0
train_accuracy,0.68155
train_loss,0.65972
validation_accuracy,0.64441
validation_loss,0.83544


[I 2025-05-02 13:00:58,342] Trial 4 finished with value: 0.7121514422552926 and parameters: {'lr': 0.0005, 'weight_decay': 0.05, 'num_blocks': 1, 'num_heads': 2, 'num_segments': 5}. Best is trial 4 with value: 0.7121514422552926.



===== Trial 5 =====
 lr=5.00e-04, wd=5.00e-02, blocks=1, heads=4, segs=5
Epoch 001 | train_loss=1.1755 acc=0.3425 | val_loss=1.0770 acc=0.4317 | time=21.2s
Epoch 002 | train_loss=1.1183 acc=0.3918 | val_loss=1.0792 acc=0.4317 | time=21.3s
Epoch 003 | train_loss=1.1112 acc=0.3833 | val_loss=1.0755 acc=0.4317 | time=21.6s
Epoch 004 | train_loss=1.1088 acc=0.3779 | val_loss=1.0753 acc=0.4317 | time=21.2s
Epoch 005 | train_loss=1.0924 acc=0.3988 | val_loss=1.0726 acc=0.4317 | time=21.2s
Epoch 006 | train_loss=1.0923 acc=0.4113 | val_loss=1.0724 acc=0.4317 | time=21.4s
Epoch 007 | train_loss=1.0778 acc=0.4047 | val_loss=1.0761 acc=0.4317 | time=21.3s
Epoch 008 | train_loss=1.0834 acc=0.3911 | val_loss=1.0779 acc=0.4317 | time=21.3s


0,1
epoch,▁▂▃▄▅▆▇█
train_accuracy,▁▆▅▅▇█▇▆
train_loss,█▄▃▃▂▂▁▁
validation_accuracy,▁▁▁▁▁▁▁▁
validation_loss,▆█▄▄▁▁▅▇

0,1
epoch,8.0
train_accuracy,0.39107
train_loss,1.08343
validation_accuracy,0.43168
validation_loss,1.07786


[I 2025-05-02 13:04:11,959] Trial 5 pruned. 


▸ Trial 5 pruned at epoch 9

===== Best Trial =====
best_val_loss       = 0.712151
best_train_loss     = 0.745850
best_train_accuracy = 0.6233
best_val_accuracy   = 0.7034
best params:
  lr: 0.0005
  weight_decay: 0.05
  num_blocks: 1
  num_heads: 2
  num_segments: 5


In [None]:
torch.cuda.empty_cache()

In [None]:
import os
import json
import time
import gc
import multiprocessing

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split
import optuna
import wandb

from eeg_dataset import EEGDataset
from model_optimized_3 import EEGformer

# ─── Grid Search 후보값 ────────────────────────────────────────
LR_CHOICES          = [5e-4]
WD_CHOICES          = [5e-3, 5e-4]
NUM_FILTERS         = 120
NUM_BLOCK_CHOICES   = [2]
NUM_HEAD_CHOICES    = [1, 2, 3]
SEGMENT_CHOICES     = [5]

# ─── Training configuration ───────────────────────────────────
MAX_EPOCHS  = 100
PATIENCE    = 10
BATCH_SIZE  = 32
NUM_WORKERS = max(1, min(4, os.cpu_count() - 1))

# ─── Data paths & device ──────────────────────────────────────
DATA_DIR   = '/content/drive/MyDrive/2025_Lab_Research/model-data'
LABEL_FILE = "labels.json"
DEVICE     = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ─── Scheduler 하이퍼파라미터 ─────────────────────────────────
FIXED_STEP_SIZE = 5
FIXED_GAMMA     = 0.5


def objective_holdout(trial):
    # ─── 1) Grid용 sampling ───────────────────────────────────
    lr           = trial.suggest_categorical("lr", LR_CHOICES)
    weight_decay = trial.suggest_categorical("weight_decay", WD_CHOICES)
    num_blocks   = trial.suggest_categorical("num_blocks", NUM_BLOCK_CHOICES)
    num_heads    = trial.suggest_categorical("num_heads", NUM_HEAD_CHOICES)
    num_segments = trial.suggest_categorical("num_segments", SEGMENT_CHOICES)

    # ─── 2) 데이터 로드 ────────────────────────────────────────
    with open(os.path.join(DATA_DIR, LABEL_FILE), "r") as f:
        all_meta = json.load(f)
    train_meta = [d for d in all_meta if d["type"] == "train"]
    full_ds    = EEGDataset(DATA_DIR, train_meta)
    labels     = [d["label"] for d in train_meta]
    n_samples  = len(full_ds)
    input_length = full_ds[0][0].shape[-1]

    # ─── 3) Hold-out split ─────────────────────────────────────
    train_idx, val_idx = train_test_split(
        list(range(n_samples)),
        test_size=0.2,
        stratify=labels,
        random_state=42
    )
    train_loader = DataLoader(
        Subset(full_ds, train_idx),
        batch_size=BATCH_SIZE, shuffle=True,  num_workers=NUM_WORKERS
    )
    val_loader = DataLoader(
        Subset(full_ds, val_idx),
        batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
    )

    # ─── 4) W&B init ───────────────────────────────────────────
    wandb.init(
        project="eeg-holdout-grid-search-17",
        config={
            "lr": lr,
            "weight_decay": weight_decay,
            "num_blocks": num_blocks,
            "num_heads": num_heads,
            "num_segments": num_segments
        }
    )

    print(f"\n===== Trial {trial.number} =====")
    print(
        f" lr={lr:.2e}, wd={weight_decay:.2e}, "
        f"blocks={num_blocks}, heads={num_heads}, segs={num_segments}"
    )

    # ─── 5) Model / optimizer / loss ──────────────────────────
    model = EEGformer(
        in_channels  = 19,
        input_length = input_length,
        kernel_size  = 10,
        num_filters  = NUM_FILTERS,
        num_heads    = num_heads,
        num_blocks   = num_blocks,
        num_segments = num_segments,
        num_classes  = 3
    ).to(DEVICE)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=weight_decay
    )
    criterion = nn.CrossEntropyLoss()

    # ─── 6) Scheduler ──────────────────────────────────────────
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=FIXED_GAMMA,
        patience=FIXED_STEP_SIZE,
        min_lr=1e-6
    )

    # ─── 7) Training loop w/ Early Stopping & Pruning ─────────
    best_val_loss     = float("inf")
    epochs_no_improve = 0

    best_train_loss = best_train_acc = best_val_acc = None

    for epoch in range(1, MAX_EPOCHS + 1):
        t0 = time.time()

        # — train —
        model.train()
        tloss = tcorrect = ttotal = 0
        for X, y in train_loader:
            X, y = X.to(DEVICE), y.to(DEVICE)
            optimizer.zero_grad()
            logits = model(X)
            loss   = criterion(logits, y)
            loss.backward()
            optimizer.step()

            tloss    += loss.item()
            tcorrect += (logits.argmax(1) == y).sum().item()
            ttotal   += y.size(0)
        train_loss = tloss / len(train_loader)
        train_acc  = tcorrect / ttotal

        # — validate —
        model.eval()
        vloss = vcorrect = vtotal = 0
        with torch.no_grad():
            for X, y in val_loader:
                X, y    = X.to(DEVICE), y.to(DEVICE)
                logits  = model(X)
                loss    = criterion(logits, y)
                vloss   += loss.item()
                vcorrect+= (logits.argmax(1) == y).sum().item()
                vtotal  += y.size(0)
        val_loss = vloss / len(val_loader)
        val_acc  = vcorrect / vtotal
        elapsed  = time.time() - t0

        # — report & pruning check —
        trial.report(val_loss, epoch)
        if trial.should_prune():
            wandb.finish()
            print(f"▸ Trial {trial.number} pruned at epoch {epoch}")
            raise optuna.TrialPruned()

        # — print & log —
        print(
            f"Epoch {epoch:03d} | "
            f"train_loss={train_loss:.4f} acc={train_acc:.4f} | "
            f"val_loss={val_loss:.4f} acc={val_acc:.4f} | "
            f"time={elapsed:.1f}s"
        )
        wandb.log({
            "epoch":               epoch,
            "train_loss":          train_loss,
            "train_accuracy":      train_acc,
            "validation_loss":     val_loss,
            "validation_accuracy": val_acc,
        }, step=epoch)

        scheduler.step(val_loss)

        # — early stopping logic & save best metrics —
        if val_loss < best_val_loss:
            best_val_loss     = val_loss
            epochs_no_improve = 0
            best_train_loss   = train_loss
            best_train_acc    = train_acc
            best_val_acc      = val_acc
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= PATIENCE:
                print(f"★ Early stopping at epoch {epoch}")
                break

    # store best-epoch metrics
    trial.set_user_attr("best_train_loss", best_train_loss)
    trial.set_user_attr("best_train_acc",  best_train_acc)
    trial.set_user_attr("best_val_acc",    best_val_acc)

    wandb.finish()
    torch.cuda.empty_cache()
    gc.collect()
    return best_val_loss


if __name__ == "__main__":
    multiprocessing.freeze_support()

    # ─── GridSampler용 파라미터 그리드 ─────────────────────────
    param_grid = {
        "lr":            LR_CHOICES,
        "weight_decay":  WD_CHOICES,
        "num_blocks":    NUM_BLOCK_CHOICES,
        "num_heads":     NUM_HEAD_CHOICES,
        "num_segments":  SEGMENT_CHOICES,
    }

    sampler = optuna.samplers.GridSampler(param_grid)
    study = optuna.create_study(
        direction="minimize",
        sampler=sampler,
        pruner=optuna.pruners.MedianPruner(n_warmup_steps=5),
        study_name="eeg_holdout_grid_search-17",
        storage="sqlite:////content/drive/MyDrive/2025_Lab_Research/eeg_grid_search-17.db",
        load_if_exists=True
    )
    study.optimize(objective_holdout)  # grid 크기만큼 자동 실행

    # ─── 결과 출력 ─────────────────────────────────────────────
    best = study.best_trial
    print("\n===== Best Trial =====")
    print(f"best_val_loss       = {best.value:.6f}")
    print(f"best_train_loss     = {best.user_attrs['best_train_loss']:.6f}")
    print(f"best_train_accuracy = {best.user_attrs['best_train_acc']:.4f}")
    print(f"best_val_accuracy   = {best.user_attrs['best_val_acc']:.4f}")
    print("best params:")
    for k, v in best.params.items():
        print(f"  {k}: {v}")


Now using CUDA device 0
Enabling CUDA with 39.14 GiB available memory


[I 2025-05-02 13:28:55,853] A new study created in RDB with name: eeg_holdout_grid_search-17



===== Trial 0 =====
 lr=5.00e-04, wd=5.00e-04, blocks=2, heads=3, segs=5
Epoch 001 | train_loss=1.1312 acc=0.3802 | val_loss=1.0756 acc=0.4317 | time=29.7s
Epoch 002 | train_loss=1.1196 acc=0.3802 | val_loss=1.0803 acc=0.4317 | time=29.5s
Epoch 003 | train_loss=1.1044 acc=0.3946 | val_loss=1.0756 acc=0.4317 | time=29.2s
Epoch 004 | train_loss=1.0924 acc=0.3977 | val_loss=1.0836 acc=0.4317 | time=29.2s
Epoch 005 | train_loss=1.0939 acc=0.3860 | val_loss=1.0761 acc=0.4317 | time=29.5s
Epoch 006 | train_loss=1.0760 acc=0.4155 | val_loss=1.0739 acc=0.4317 | time=29.2s
Epoch 007 | train_loss=1.0944 acc=0.4082 | val_loss=1.0833 acc=0.4317 | time=29.3s
Epoch 008 | train_loss=1.0808 acc=0.4035 | val_loss=1.0827 acc=0.4317 | time=29.4s
Epoch 009 | train_loss=1.0862 acc=0.4101 | val_loss=1.0833 acc=0.4317 | time=29.2s
Epoch 010 | train_loss=1.0798 acc=0.4249 | val_loss=1.0863 acc=0.4317 | time=29.1s
Epoch 011 | train_loss=1.0787 acc=0.4074 | val_loss=1.0821 acc=0.4317 | time=29.2s
Epoch 012 | t

0,1
epoch,▁▁▂▂▃▃▄▄▅▅▆▆▇▇██
train_accuracy,▁▁▃▃▂▆▅▄▅▇▅█▇▇▇▇
train_loss,█▇▅▄▄▂▄▂▃▂▂▁▁▂▂▂
validation_accuracy,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation_loss,▁▂▁▃▁▁▂▂▂▃▂█▂▃▂▁

0,1
epoch,16.0
train_accuracy,0.4268
train_loss,1.07471
validation_accuracy,0.43168
validation_loss,1.07606


[I 2025-05-02 13:36:48,896] Trial 0 finished with value: 1.0738752626237416 and parameters: {'lr': 0.0005, 'weight_decay': 0.0005, 'num_blocks': 2, 'num_heads': 3, 'num_segments': 5}. Best is trial 0 with value: 1.0738752626237416.



===== Trial 1 =====
 lr=5.00e-04, wd=5.00e-03, blocks=2, heads=2, segs=5
Epoch 001 | train_loss=1.1864 acc=0.3511 | val_loss=1.0804 acc=0.4317 | time=23.7s
Epoch 002 | train_loss=1.1254 acc=0.3860 | val_loss=1.0826 acc=0.4317 | time=23.8s
Epoch 003 | train_loss=1.1194 acc=0.3950 | val_loss=1.0746 acc=0.4317 | time=23.9s
Epoch 004 | train_loss=1.1071 acc=0.4000 | val_loss=1.0723 acc=0.4317 | time=24.0s
Epoch 005 | train_loss=1.0927 acc=0.4109 | val_loss=1.0715 acc=0.4519 | time=23.9s
Epoch 006 | train_loss=1.0487 acc=0.4792 | val_loss=0.9982 acc=0.6025 | time=23.9s
Epoch 007 | train_loss=1.0306 acc=0.5056 | val_loss=0.9829 acc=0.5994 | time=23.8s
Epoch 008 | train_loss=0.9799 acc=0.5460 | val_loss=0.9550 acc=0.6087 | time=23.8s
Epoch 009 | train_loss=0.9331 acc=0.5837 | val_loss=0.9110 acc=0.6320 | time=23.8s
Epoch 010 | train_loss=0.9065 acc=0.5860 | val_loss=0.8887 acc=0.6335 | time=23.7s
Epoch 011 | train_loss=0.8735 acc=0.6019 | val_loss=0.8471 acc=0.6522 | time=23.8s
Epoch 012 | t

0,1
epoch,▁▁▂▂▂▂▃▃▃▃▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇██
train_accuracy,▁▂▂▂▂▃▄▄▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇████
train_loss,█▇▇▇▇▆▆▆▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▁▁▁▁
validation_accuracy,▁▁▁▁▂▅▅▅▆▆▇▃█▆▇▆█▆▇▇▇▇▇▇▇▆▆
validation_loss,█████▆▆▆▅▄▃▄▃▃▃▃▁▄▃▂▃▄▂▂▂▂▂

0,1
epoch,27.0
train_accuracy,0.74408
train_loss,0.5678
validation_accuracy,0.64286
validation_loss,0.78289


[I 2025-05-02 13:47:34,427] Trial 1 finished with value: 0.7162879818961734 and parameters: {'lr': 0.0005, 'weight_decay': 0.005, 'num_blocks': 2, 'num_heads': 2, 'num_segments': 5}. Best is trial 1 with value: 0.7162879818961734.



===== Trial 2 =====
 lr=5.00e-04, wd=5.00e-04, blocks=2, heads=1, segs=5
Epoch 001 | train_loss=1.1393 acc=0.3588 | val_loss=1.0758 acc=0.4317 | time=18.1s
Epoch 002 | train_loss=1.1282 acc=0.3767 | val_loss=1.0789 acc=0.4317 | time=18.3s
Epoch 003 | train_loss=1.1026 acc=0.4019 | val_loss=1.0744 acc=0.4317 | time=18.3s
Epoch 004 | train_loss=1.1037 acc=0.3771 | val_loss=1.0746 acc=0.4317 | time=18.2s
Epoch 005 | train_loss=1.0851 acc=0.4117 | val_loss=1.0744 acc=0.4317 | time=18.1s
Epoch 006 | train_loss=1.0782 acc=0.4019 | val_loss=1.0839 acc=0.3416 | time=18.1s
Epoch 007 | train_loss=1.0734 acc=0.4427 | val_loss=1.0739 acc=0.3385 | time=18.3s
Epoch 008 | train_loss=1.0245 acc=0.5091 | val_loss=0.9524 acc=0.6009 | time=18.1s
Epoch 009 | train_loss=0.9687 acc=0.5515 | val_loss=0.9500 acc=0.6056 | time=18.2s
Epoch 010 | train_loss=0.9320 acc=0.5829 | val_loss=0.9923 acc=0.5186 | time=18.2s
Epoch 011 | train_loss=0.9243 acc=0.5860 | val_loss=0.8931 acc=0.6413 | time=18.0s
Epoch 012 | t

0,1
epoch,▁▁▁▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇███
train_accuracy,▁▁▂▁▂▂▂▃▄▅▅▅▅▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇███
train_loss,████▇▇▇▇▆▆▆▅▅▄▄▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁
validation_accuracy,▃▃▃▃▃▁▁▅▆▄▆▆▄▇▆▆▇▇▇▇▇▇█▇▆▆█▆▆▇▇▇▇
validation_loss,███████▆▆▇▅▄▆▃▄▃▃▃▃▃▂▂▁▂▄▅▁▄▃▂▂▂▂

0,1
epoch,33.0
train_accuracy,0.76699
train_loss,0.53228
validation_accuracy,0.70652
validation_loss,0.70533


[I 2025-05-02 13:57:37,244] Trial 2 finished with value: 0.6300574328218188 and parameters: {'lr': 0.0005, 'weight_decay': 0.0005, 'num_blocks': 2, 'num_heads': 1, 'num_segments': 5}. Best is trial 2 with value: 0.6300574328218188.



===== Trial 3 =====
 lr=5.00e-04, wd=5.00e-04, blocks=2, heads=2, segs=5
Epoch 001 | train_loss=1.1651 acc=0.3417 | val_loss=1.0988 acc=0.4317 | time=23.9s
Epoch 002 | train_loss=1.1195 acc=0.3724 | val_loss=1.0813 acc=0.4317 | time=23.9s
Epoch 003 | train_loss=1.1053 acc=0.3926 | val_loss=1.0727 acc=0.4317 | time=23.9s
Epoch 004 | train_loss=1.0687 acc=0.4252 | val_loss=1.0897 acc=0.4317 | time=23.8s
Epoch 005 | train_loss=1.0437 acc=0.4800 | val_loss=0.9976 acc=0.5963 | time=23.8s
Epoch 006 | train_loss=0.9677 acc=0.5592 | val_loss=0.9756 acc=0.5730 | time=23.8s
Epoch 007 | train_loss=0.9268 acc=0.5829 | val_loss=0.8813 acc=0.6366 | time=23.9s
Epoch 008 | train_loss=0.9037 acc=0.5981 | val_loss=0.9349 acc=0.5730 | time=23.9s
Epoch 009 | train_loss=0.8949 acc=0.6008 | val_loss=0.8561 acc=0.6289 | time=23.8s
Epoch 010 | train_loss=0.8759 acc=0.6202 | val_loss=0.8534 acc=0.6087 | time=23.9s
Epoch 011 | train_loss=0.8398 acc=0.6229 | val_loss=0.8263 acc=0.6211 | time=23.8s
Epoch 012 | t

0,1
epoch,▁▁▁▂▂▂▃▃▃▃▃▄▄▄▅▅▅▅▅▆▆▆▇▇▇▇▇██
train_accuracy,▁▂▂▂▃▅▅▅▅▆▆▆▆▆▆▇▆▇▇▇▇▇▇▇▇▇███
train_loss,██▇▇▇▆▅▅▅▅▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▁▁▁
validation_accuracy,▁▁▁▁▅▄▆▄▆▅▅▇▇▅▇█▆▇█▇▇▆▇▆▇▇▆▇▇
validation_loss,████▆▆▄▅▄▄▄▃▃▄▂▁▂▂▁▃▁▃▂▂▂▂▄▃▄

0,1
epoch,29.0
train_accuracy,0.76544
train_loss,0.52494
validation_accuracy,0.67391
validation_loss,0.82371


[I 2025-05-02 14:09:10,844] Trial 3 finished with value: 0.6644889655567351 and parameters: {'lr': 0.0005, 'weight_decay': 0.0005, 'num_blocks': 2, 'num_heads': 2, 'num_segments': 5}. Best is trial 2 with value: 0.6300574328218188.



===== Trial 4 =====
 lr=5.00e-04, wd=5.00e-03, blocks=2, heads=1, segs=5
Epoch 001 | train_loss=1.1563 acc=0.3864 | val_loss=1.0747 acc=0.4317 | time=18.2s
Epoch 002 | train_loss=1.1275 acc=0.3868 | val_loss=1.0750 acc=0.4317 | time=18.2s
Epoch 003 | train_loss=1.0930 acc=0.4035 | val_loss=1.0744 acc=0.4317 | time=18.1s
Epoch 004 | train_loss=1.1034 acc=0.3915 | val_loss=1.0863 acc=0.4317 | time=18.3s
Epoch 005 | train_loss=1.0922 acc=0.4066 | val_loss=1.0807 acc=0.4317 | time=18.1s
Epoch 006 | train_loss=1.0965 acc=0.4078 | val_loss=1.0798 acc=0.4317 | time=18.3s
Epoch 007 | train_loss=1.0952 acc=0.4074 | val_loss=1.0850 acc=0.4317 | time=18.2s
Epoch 008 | train_loss=1.0870 acc=0.4039 | val_loss=1.0876 acc=0.4317 | time=18.0s
Epoch 009 | train_loss=1.0877 acc=0.4078 | val_loss=1.0774 acc=0.4317 | time=18.1s
Epoch 010 | train_loss=1.0931 acc=0.3829 | val_loss=1.0773 acc=0.4317 | time=18.0s
Epoch 011 | train_loss=1.0815 acc=0.4276 | val_loss=1.0759 acc=0.4317 | time=18.1s
Epoch 012 | t

0,1
epoch,▁▂▂▃▃▄▅▅▆▆▇▇█
train_accuracy,▂▂▄▂▅▅▅▄▅▁█▅▆
train_loss,█▅▂▃▂▃▂▂▂▂▁▁▁
validation_accuracy,▁▁▁▁▁▁▁▁▁▁▁▁▁
validation_loss,▁▁▁▇▄▄▇█▃▃▂▁▁

0,1
epoch,13.0
train_accuracy,0.41437
train_loss,1.08113
validation_accuracy,0.43168
validation_loss,1.07513


[I 2025-05-02 14:13:09,071] Trial 4 finished with value: 1.0744036038716633 and parameters: {'lr': 0.0005, 'weight_decay': 0.005, 'num_blocks': 2, 'num_heads': 1, 'num_segments': 5}. Best is trial 2 with value: 0.6300574328218188.



===== Trial 5 =====
 lr=5.00e-04, wd=5.00e-03, blocks=2, heads=3, segs=5
Epoch 001 | train_loss=1.1470 acc=0.3720 | val_loss=1.0757 acc=0.4317 | time=29.3s
Epoch 002 | train_loss=1.1211 acc=0.3899 | val_loss=1.0831 acc=0.4596 | time=29.3s
Epoch 003 | train_loss=1.1192 acc=0.3852 | val_loss=1.0760 acc=0.4317 | time=29.3s
Epoch 004 | train_loss=1.0912 acc=0.3946 | val_loss=1.0752 acc=0.5326 | time=29.3s
Epoch 005 | train_loss=1.0688 acc=0.4326 | val_loss=1.0450 acc=0.5186 | time=29.2s
Epoch 006 | train_loss=1.0316 acc=0.4862 | val_loss=1.0473 acc=0.5171 | time=29.3s
Epoch 007 | train_loss=0.9751 acc=0.5282 | val_loss=0.9226 acc=0.6242 | time=29.5s
Epoch 008 | train_loss=0.9431 acc=0.5635 | val_loss=0.8909 acc=0.6273 | time=29.2s
Epoch 009 | train_loss=0.8889 acc=0.5934 | val_loss=0.8421 acc=0.6304 | time=29.2s
Epoch 010 | train_loss=0.8760 acc=0.6124 | val_loss=0.7831 acc=0.6599 | time=29.4s
Epoch 011 | train_loss=0.8420 acc=0.6163 | val_loss=0.8462 acc=0.6599 | time=29.2s
Epoch 012 | t

0,1
epoch,▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▇▇▇▇███
train_accuracy,▁▁▁▁▂▃▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇█████
train_loss,███▇▇▇▆▆▅▅▅▄▄▄▄▃▃▃▃▃▃▃▃▂▂▂▂▁▁▁▁
validation_accuracy,▁▂▁▃▃▃▅▆▆▆▆▆▆▆▇▇█▇▄▇█▇▅█▇█▆█▇▇▇
validation_loss,████▇▇▆▅▄▃▄▃▃▄▃▂▂▂▄▃▁▂▃▁▂▂▃▂▂▄▃

0,1
epoch,31.0
train_accuracy,0.78485
train_loss,0.4959
validation_accuracy,0.68478
validation_loss,0.77909


[I 2025-05-02 14:28:20,094] Trial 5 finished with value: 0.6330114475318364 and parameters: {'lr': 0.0005, 'weight_decay': 0.005, 'num_blocks': 2, 'num_heads': 3, 'num_segments': 5}. Best is trial 2 with value: 0.6300574328218188.



===== Best Trial =====
best_val_loss       = 0.630057
best_train_loss     = 0.649820
best_train_accuracy = 0.7087
best_val_accuracy   = 0.7376
best params:
  lr: 0.0005
  weight_decay: 0.0005
  num_blocks: 2
  num_heads: 1
  num_segments: 5


### Blocks = 2 -> 3 으로 표현력 늘려보기
- Random Seed 바꿔보기
- 추후에 여기서 잘 나온 거 기반으로 5 Fold Cross Validation 돌려보기

In [None]:
import os
import json
import time
import gc
import multiprocessing

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split
import optuna
import wandb

from eeg_dataset import EEGDataset
from model_optimized_3 import EEGformer

# ─── Grid Search 후보값 ────────────────────────────────────────
LR_CHOICES          = [5e-4]
WD_CHOICES          = [5e-3, 5e-4, 5e-6]
NUM_FILTERS         = 120
NUM_BLOCK_CHOICES   = [3]
NUM_HEAD_CHOICES    = [1, 2, 3]
SEGMENT_CHOICES     = [5]

# ─── Training configuration ───────────────────────────────────
MAX_EPOCHS  = 100
PATIENCE    = 15
BATCH_SIZE  = 32
NUM_WORKERS = max(1, min(4, os.cpu_count() - 1))

# ─── Data paths & device ──────────────────────────────────────
DATA_DIR   = '/content/drive/MyDrive/2025_Lab_Research/model-data'
LABEL_FILE = "labels.json"
DEVICE     = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ─── Scheduler 하이퍼파라미터 ─────────────────────────────────
FIXED_STEP_SIZE = 5
FIXED_GAMMA     = 0.5


def objective_holdout(trial):
    # ─── 1) Grid용 sampling ───────────────────────────────────
    lr           = trial.suggest_categorical("lr", LR_CHOICES)
    weight_decay = trial.suggest_categorical("weight_decay", WD_CHOICES)
    num_blocks   = trial.suggest_categorical("num_blocks", NUM_BLOCK_CHOICES)
    num_heads    = trial.suggest_categorical("num_heads", NUM_HEAD_CHOICES)
    num_segments = trial.suggest_categorical("num_segments", SEGMENT_CHOICES)

    # ─── 2) 데이터 로드 ────────────────────────────────────────
    with open(os.path.join(DATA_DIR, LABEL_FILE), "r") as f:
        all_meta = json.load(f)
    train_meta = [d for d in all_meta if d["type"] == "train"]
    full_ds    = EEGDataset(DATA_DIR, train_meta)
    labels     = [d["label"] for d in train_meta]
    n_samples  = len(full_ds)
    input_length = full_ds[0][0].shape[-1]

    # ─── 3) Hold-out split ─────────────────────────────────────
    train_idx, val_idx = train_test_split(
        list(range(n_samples)),
        test_size=0.2,
        stratify=labels,
        random_state=42
    )
    train_loader = DataLoader(
        Subset(full_ds, train_idx),
        batch_size=BATCH_SIZE, shuffle=True,  num_workers=NUM_WORKERS
    )
    val_loader = DataLoader(
        Subset(full_ds, val_idx),
        batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
    )

    # ─── 4) W&B init ───────────────────────────────────────────
    wandb.init(
        project="eeg-holdout-grid-search-18",
        config={
            "lr": lr,
            "weight_decay": weight_decay,
            "num_blocks": num_blocks,
            "num_heads": num_heads,
            "num_segments": num_segments
        }
    )

    print(f"\n===== Trial {trial.number} =====")
    print(
        f" lr={lr:.2e}, wd={weight_decay:.2e}, "
        f"blocks={num_blocks}, heads={num_heads}, segs={num_segments}"
    )

    # ─── 5) Model / optimizer / loss ──────────────────────────
    model = EEGformer(
        in_channels  = 19,
        input_length = input_length,
        kernel_size  = 10,
        num_filters  = NUM_FILTERS,
        num_heads    = num_heads,
        num_blocks   = num_blocks,
        num_segments = num_segments,
        num_classes  = 3
    ).to(DEVICE)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=weight_decay
    )
    criterion = nn.CrossEntropyLoss()

    # ─── 6) Scheduler ──────────────────────────────────────────
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=FIXED_GAMMA,
        patience=FIXED_STEP_SIZE,
        min_lr=1e-6
    )

    # ─── 7) Training loop w/ Early Stopping & Pruning ─────────
    best_val_loss     = float("inf")
    epochs_no_improve = 0

    best_train_loss = best_train_acc = best_val_acc = None

    for epoch in range(1, MAX_EPOCHS + 1):
        t0 = time.time()

        # — train —
        model.train()
        tloss = tcorrect = ttotal = 0
        for X, y in train_loader:
            X, y = X.to(DEVICE), y.to(DEVICE)
            optimizer.zero_grad()
            logits = model(X)
            loss   = criterion(logits, y)
            loss.backward()
            optimizer.step()

            tloss    += loss.item()
            tcorrect += (logits.argmax(1) == y).sum().item()
            ttotal   += y.size(0)
        train_loss = tloss / len(train_loader)
        train_acc  = tcorrect / ttotal

        # — validate —
        model.eval()
        vloss = vcorrect = vtotal = 0
        with torch.no_grad():
            for X, y in val_loader:
                X, y    = X.to(DEVICE), y.to(DEVICE)
                logits  = model(X)
                loss    = criterion(logits, y)
                vloss   += loss.item()
                vcorrect+= (logits.argmax(1) == y).sum().item()
                vtotal  += y.size(0)
        val_loss = vloss / len(val_loader)
        val_acc  = vcorrect / vtotal
        elapsed  = time.time() - t0

        # — report & pruning check —
        trial.report(val_loss, epoch)
        if trial.should_prune():
            wandb.finish()
            print(f"▸ Trial {trial.number} pruned at epoch {epoch}")
            raise optuna.TrialPruned()

        # — print & log —
        print(
            f"Epoch {epoch:03d} | "
            f"train_loss={train_loss:.4f} acc={train_acc:.4f} | "
            f"val_loss={val_loss:.4f} acc={val_acc:.4f} | "
            f"time={elapsed:.1f}s"
        )
        wandb.log({
            "epoch":               epoch,
            "train_loss":          train_loss,
            "train_accuracy":      train_acc,
            "validation_loss":     val_loss,
            "validation_accuracy": val_acc,
        }, step=epoch)

        scheduler.step(val_loss)

        # — early stopping logic & save best metrics —
        if val_loss < best_val_loss:
            best_val_loss     = val_loss
            epochs_no_improve = 0
            best_train_loss   = train_loss
            best_train_acc    = train_acc
            best_val_acc      = val_acc
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= PATIENCE:
                print(f"★ Early stopping at epoch {epoch}")
                break

    # store best-epoch metrics
    trial.set_user_attr("best_train_loss", best_train_loss)
    trial.set_user_attr("best_train_acc",  best_train_acc)
    trial.set_user_attr("best_val_acc",    best_val_acc)

    wandb.finish()
    torch.cuda.empty_cache()
    gc.collect()
    return best_val_loss


if __name__ == "__main__":
    multiprocessing.freeze_support()

    # ─── GridSampler용 파라미터 그리드 ─────────────────────────
    param_grid = {
        "lr":            LR_CHOICES,
        "weight_decay":  WD_CHOICES,
        "num_blocks":    NUM_BLOCK_CHOICES,
        "num_heads":     NUM_HEAD_CHOICES,
        "num_segments":  SEGMENT_CHOICES,
    }

    sampler = optuna.samplers.GridSampler(param_grid)
    study = optuna.create_study(
        direction="minimize",
        sampler=sampler,
        pruner=optuna.pruners.MedianPruner(n_warmup_steps=5),
        study_name="eeg_holdout_grid_search-18",
        storage="sqlite:////content/drive/MyDrive/2025_Lab_Research/eeg_grid_search-18.db",
        load_if_exists=True
    )
    study.optimize(objective_holdout)  # grid 크기만큼 자동 실행

    # ─── 결과 출력 ─────────────────────────────────────────────
    best = study.best_trial
    print("\n===== Best Trial =====")
    print(f"best_val_loss       = {best.value:.6f}")
    print(f"best_train_loss     = {best.user_attrs['best_train_loss']:.6f}")
    print(f"best_train_accuracy = {best.user_attrs['best_train_acc']:.4f}")
    print(f"best_val_accuracy   = {best.user_attrs['best_val_acc']:.4f}")
    print("best params:")
    for k, v in best.params.items():
        print(f"  {k}: {v}")


Attempting to create new mne-python configuration file:
/root/.mne/mne-python.json
Now using CUDA device 0
Enabling CUDA with 39.14 GiB available memory


[I 2025-05-03 03:41:53,234] A new study created in RDB with name: eeg_holdout_grid_search-18



===== Trial 0 =====
 lr=5.00e-04, wd=5.00e-04, blocks=3, heads=3, segs=5
Epoch 001 | train_loss=1.1592 acc=0.3616 | val_loss=1.0927 acc=0.4317 | time=183.0s
Epoch 002 | train_loss=1.1299 acc=0.3751 | val_loss=1.0919 acc=0.4317 | time=40.4s
Epoch 003 | train_loss=1.1114 acc=0.3817 | val_loss=1.0768 acc=0.4317 | time=40.6s
Epoch 004 | train_loss=1.0910 acc=0.4062 | val_loss=1.0754 acc=0.4317 | time=40.4s
Epoch 005 | train_loss=1.0962 acc=0.3977 | val_loss=1.0743 acc=0.4317 | time=40.4s
Epoch 006 | train_loss=1.0975 acc=0.4004 | val_loss=1.0736 acc=0.4317 | time=40.3s
Epoch 007 | train_loss=1.0898 acc=0.4012 | val_loss=1.0775 acc=0.4317 | time=40.4s
Epoch 008 | train_loss=1.0764 acc=0.4066 | val_loss=1.0763 acc=0.4317 | time=40.6s
Epoch 009 | train_loss=1.0783 acc=0.4198 | val_loss=1.0746 acc=0.4317 | time=40.3s
Epoch 010 | train_loss=1.0646 acc=0.4272 | val_loss=1.0789 acc=0.4317 | time=40.3s
Epoch 011 | train_loss=1.0739 acc=0.4167 | val_loss=1.0636 acc=0.4317 | time=40.3s
Epoch 012 | 

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇██
train_accuracy,▁▁▁▂▂▂▂▂▂▂▂▂▃▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇██████
train_loss,███▇▇▇▇▇▇▇▇▇▇▆▅▅▅▅▅▄▄▄▄▃▃▃▃▂▃▂▂▂▂▂▂▁▁▁▁▁
validation_accuracy,▁▁▁▁▁▁▁▁▁▁▁▅▆▆▄▆▅▆▆▇▄▆▅▆▇▆▇█▇█▇▆▇▇▇▆▆▆▇▅
validation_loss,██████████▇▆▆▄▅▄▄▃▄▂▅▃▄▂▂▅▂▂▂▁▃▄▅▂▄▆▇▇▄█

0,1
epoch,43.0
train_accuracy,0.79612
train_loss,0.48118
validation_accuracy,0.59317
validation_loss,1.07291


[I 2025-05-03 04:13:30,024] Trial 0 finished with value: 0.7071966386976696 and parameters: {'lr': 0.0005, 'weight_decay': 0.0005, 'num_blocks': 3, 'num_heads': 3, 'num_segments': 5}. Best is trial 0 with value: 0.7071966386976696.



===== Trial 1 =====
 lr=5.00e-04, wd=5.00e-06, blocks=3, heads=1, segs=5
Epoch 001 | train_loss=1.1435 acc=0.3577 | val_loss=1.0758 acc=0.4317 | time=23.6s
Epoch 002 | train_loss=1.1207 acc=0.3802 | val_loss=1.0770 acc=0.4317 | time=23.8s
Epoch 003 | train_loss=1.1043 acc=0.3965 | val_loss=1.0782 acc=0.4317 | time=23.8s
Epoch 004 | train_loss=1.0955 acc=0.3814 | val_loss=1.0771 acc=0.4317 | time=23.8s
Epoch 005 | train_loss=1.0822 acc=0.4194 | val_loss=1.0687 acc=0.5124 | time=23.8s
Epoch 006 | train_loss=1.1063 acc=0.3810 | val_loss=1.0939 acc=0.4317 | time=23.7s
Epoch 007 | train_loss=1.1042 acc=0.3891 | val_loss=1.0661 acc=0.4317 | time=23.7s
Epoch 008 | train_loss=1.0684 acc=0.4315 | val_loss=1.0526 acc=0.5916 | time=23.6s
Epoch 009 | train_loss=1.0186 acc=0.5029 | val_loss=1.0184 acc=0.5668 | time=23.7s
Epoch 010 | train_loss=1.0062 acc=0.5118 | val_loss=1.0571 acc=0.3727 | time=23.7s
Epoch 011 | train_loss=0.9728 acc=0.5530 | val_loss=1.0631 acc=0.3478 | time=23.9s
Epoch 012 | t

0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██
train_accuracy,▁▁▂▁▂▁▁▂▃▃▄▄▄▅▅▅▅▅▆▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇█▇███
train_loss,███▇▇██▇▇▇▆▆▆▅▅▅▅▅▄▅▄▄▃▄▃▃▃▃▃▃▂▃▂▂▂▂▁▁▁▁
validation_accuracy,▃▃▃▃▄▃▃▆▅▁▁▆▇▇▇██▇▆▇█▇█▄███▇█▇█▆▇▇█▆▆▇▅▅
validation_loss,███████▇▇▇▇▅▅▄▄▄▃▄▄▄▁▂▁▄▁▁▂▂▂▁▁▃▁▂▂▅▄▂▅▆

0,1
epoch,42.0
train_accuracy,0.79845
train_loss,0.49634
validation_accuracy,0.57609
validation_loss,0.96411


[I 2025-05-03 04:30:09,750] Trial 1 finished with value: 0.6982603782699222 and parameters: {'lr': 0.0005, 'weight_decay': 5e-06, 'num_blocks': 3, 'num_heads': 1, 'num_segments': 5}. Best is trial 1 with value: 0.6982603782699222.



===== Trial 2 =====
 lr=5.00e-04, wd=5.00e-04, blocks=3, heads=1, segs=5
Epoch 001 | train_loss=1.1314 acc=0.3584 | val_loss=1.0756 acc=0.4317 | time=23.9s
Epoch 002 | train_loss=1.1067 acc=0.3810 | val_loss=1.0760 acc=0.4317 | time=23.7s
Epoch 003 | train_loss=1.0949 acc=0.3946 | val_loss=1.0763 acc=0.4317 | time=23.6s
Epoch 004 | train_loss=1.0905 acc=0.3942 | val_loss=1.0749 acc=0.4317 | time=23.8s
Epoch 005 | train_loss=1.0794 acc=0.4190 | val_loss=1.0689 acc=0.4317 | time=23.7s
Epoch 006 | train_loss=1.0673 acc=0.4342 | val_loss=1.0983 acc=0.3416 | time=23.6s
Epoch 007 | train_loss=1.0226 acc=0.4983 | val_loss=0.9701 acc=0.6056 | time=23.7s
Epoch 008 | train_loss=0.9678 acc=0.5464 | val_loss=0.9184 acc=0.5901 | time=24.0s
Epoch 009 | train_loss=0.9223 acc=0.5783 | val_loss=0.9658 acc=0.5124 | time=23.7s
Epoch 010 | train_loss=0.9115 acc=0.5814 | val_loss=0.9145 acc=0.6273 | time=23.6s
Epoch 011 | train_loss=0.8918 acc=0.5996 | val_loss=0.8426 acc=0.6398 | time=23.7s
Epoch 012 | t

0,1
epoch,▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇███
train_accuracy,▁▁▂▂▂▂▃▄▅▅▅▅▅▅▆▆▆▆▆▆▆▆▇▆▇▇▇▇▇█▇█████
train_loss,████▇▇▇▆▆▆▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▁▁▁▁
validation_accuracy,▃▃▃▃▃▁▆▆▄▆▇▅▆▇▇▇██▇▇▇▇█▇▇▇▇▇█▇▆▆▅▇▆▇
validation_loss,▆▆▆▆▆▇▅▄▅▄▃▄▃▃▂▃▂▁▂▂▁▁▁▁▃▂▄▃▁▃▆▆█▄▅▃

0,1
epoch,36.0
train_accuracy,0.78563
train_loss,0.47894
validation_accuracy,0.65062
validation_loss,0.85122


[I 2025-05-03 04:44:27,409] Trial 2 finished with value: 0.7222818647112165 and parameters: {'lr': 0.0005, 'weight_decay': 0.0005, 'num_blocks': 3, 'num_heads': 1, 'num_segments': 5}. Best is trial 1 with value: 0.6982603782699222.



===== Trial 3 =====
 lr=5.00e-04, wd=5.00e-04, blocks=3, heads=2, segs=5
Epoch 001 | train_loss=1.1429 acc=0.3588 | val_loss=1.0759 acc=0.4317 | time=32.4s
Epoch 002 | train_loss=1.1154 acc=0.3907 | val_loss=1.0803 acc=0.4317 | time=32.2s
Epoch 003 | train_loss=1.1014 acc=0.3907 | val_loss=1.0754 acc=0.4317 | time=32.2s
Epoch 004 | train_loss=1.0963 acc=0.3833 | val_loss=1.0744 acc=0.4317 | time=32.3s
Epoch 005 | train_loss=1.0938 acc=0.4000 | val_loss=1.0750 acc=0.4317 | time=32.3s
Epoch 006 | train_loss=1.0822 acc=0.4194 | val_loss=1.0758 acc=0.4317 | time=32.3s
Epoch 007 | train_loss=1.0751 acc=0.4120 | val_loss=1.0732 acc=0.4317 | time=32.1s
Epoch 008 | train_loss=1.0720 acc=0.4249 | val_loss=1.0656 acc=0.4317 | time=32.4s
Epoch 009 | train_loss=1.0596 acc=0.4528 | val_loss=1.0664 acc=0.4441 | time=32.2s
Epoch 010 | train_loss=1.0010 acc=0.5247 | val_loss=0.9649 acc=0.6009 | time=32.4s
Epoch 011 | train_loss=0.9427 acc=0.5740 | val_loss=0.8823 acc=0.6211 | time=32.3s
Epoch 012 | t

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██
train_accuracy,▁▁▁▁▂▂▃▅▅▅▅▅▆▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇███████████
train_loss,█████▇▇▇▆▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁
validation_accuracy,▁▁▁▁▁▁▁▁▅▅▆▇▇▇▇▇▇▅▄▆▅▆▇▆▅▇███▆▆▇▇██▇███▇
validation_loss,███████▅▅▄▃▃▃▃▃▄▅▃▃▂▃▃▂▂▃▂▁▂▁▁▃▄▂▂▁▄▃▂▂▃

0,1
epoch,54.0
train_accuracy,0.82913
train_loss,0.38215
validation_accuracy,0.7236
validation_loss,0.74834


[I 2025-05-03 05:13:36,413] Trial 3 finished with value: 0.6183182043688638 and parameters: {'lr': 0.0005, 'weight_decay': 0.0005, 'num_blocks': 3, 'num_heads': 2, 'num_segments': 5}. Best is trial 3 with value: 0.6183182043688638.



===== Trial 4 =====
 lr=5.00e-04, wd=5.00e-06, blocks=3, heads=3, segs=5
Epoch 001 | train_loss=1.1454 acc=0.3522 | val_loss=1.0839 acc=0.4317 | time=40.5s
Epoch 002 | train_loss=1.1230 acc=0.3887 | val_loss=1.0779 acc=0.4317 | time=40.6s
Epoch 003 | train_loss=1.1002 acc=0.3930 | val_loss=1.0815 acc=0.4317 | time=40.4s
Epoch 004 | train_loss=1.0935 acc=0.3922 | val_loss=1.0751 acc=0.4317 | time=40.3s
Epoch 005 | train_loss=1.0865 acc=0.4109 | val_loss=1.0790 acc=0.4317 | time=40.2s
Epoch 006 | train_loss=1.1033 acc=0.3771 | val_loss=1.0743 acc=0.4317 | time=40.3s
Epoch 007 | train_loss=1.0907 acc=0.3981 | val_loss=1.0752 acc=0.4317 | time=40.6s
Epoch 008 | train_loss=1.0781 acc=0.4252 | val_loss=1.0762 acc=0.4317 | time=40.4s
Epoch 009 | train_loss=1.0822 acc=0.4012 | val_loss=1.0753 acc=0.4317 | time=40.3s
Epoch 010 | train_loss=1.0776 acc=0.4128 | val_loss=1.0754 acc=0.4317 | time=40.3s
Epoch 011 | train_loss=1.0686 acc=0.4326 | val_loss=1.0734 acc=0.4317 | time=40.5s
Epoch 012 | t

0,1
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_accuracy,▁▂▂▂▂▁▂▂▂▂▂▂▂▃▄▅▅▅▆▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇█████
train_loss,██▇▇▇▇▇▇▇▇▇▇▇▇▆▆▅▅▄▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▁▁▁▁▁▁
validation_accuracy,▁▁▁▁▁▁▁▁▁▁▁▁▁▅▆▆▆▇▅▇▇█▆█▇▇█▇▇█████▇▆▇▇▆▆
validation_loss,█████████████▇▆▅▅▃▄▃▂▂▄▁▂▂▁▃▂▂▁▃▂▂▂▆▅▄▄▅

0,1
epoch,44.0
train_accuracy,0.72311
train_loss,0.58053
validation_accuracy,0.63043
validation_loss,0.92208


[I 2025-05-03 05:43:19,069] Trial 4 finished with value: 0.6824441452821096 and parameters: {'lr': 0.0005, 'weight_decay': 5e-06, 'num_blocks': 3, 'num_heads': 3, 'num_segments': 5}. Best is trial 3 with value: 0.6183182043688638.



===== Trial 5 =====
 lr=5.00e-04, wd=5.00e-03, blocks=3, heads=3, segs=5
Epoch 001 | train_loss=1.1323 acc=0.3697 | val_loss=1.0850 acc=0.3416 | time=40.5s
Epoch 002 | train_loss=1.1073 acc=0.3961 | val_loss=1.0774 acc=0.4161 | time=40.3s
Epoch 003 | train_loss=1.0941 acc=0.3992 | val_loss=1.0921 acc=0.3416 | time=40.3s
Epoch 004 | train_loss=1.0960 acc=0.3845 | val_loss=1.0745 acc=0.4317 | time=40.4s


0,1
epoch,▁▃▆█
train_accuracy,▁▇█▅
train_loss,█▃▁▁
validation_accuracy,▁▇▁█
validation_loss,▅▂█▁

0,1
epoch,4.0
train_accuracy,0.38447
train_loss,1.09599
validation_accuracy,0.43168
validation_loss,1.07447


[I 2025-05-03 05:46:42,962] Trial 5 pruned. 


▸ Trial 5 pruned at epoch 5



===== Trial 6 =====
 lr=5.00e-04, wd=5.00e-03, blocks=3, heads=2, segs=5
Epoch 001 | train_loss=1.1348 acc=0.3612 | val_loss=1.0933 acc=0.4317 | time=32.3s
Epoch 002 | train_loss=1.1161 acc=0.3728 | val_loss=1.0793 acc=0.4317 | time=32.2s
Epoch 003 | train_loss=1.1098 acc=0.3926 | val_loss=1.0765 acc=0.4317 | time=32.2s
Epoch 004 | train_loss=1.0956 acc=0.4058 | val_loss=1.0754 acc=0.4317 | time=32.3s


0,1
epoch,▁▃▆█
train_accuracy,▁▃▆█
train_loss,█▅▄▁
validation_accuracy,▁▁▁▁
validation_loss,█▃▁▁

0,1
epoch,4.0
train_accuracy,0.40583
train_loss,1.09557
validation_accuracy,0.43168
validation_loss,1.07542


[I 2025-05-03 05:49:26,015] Trial 6 pruned. 


▸ Trial 6 pruned at epoch 5



===== Trial 7 =====
 lr=5.00e-04, wd=5.00e-03, blocks=3, heads=1, segs=5
Epoch 001 | train_loss=1.1422 acc=0.3724 | val_loss=1.0795 acc=0.4317 | time=23.9s
Epoch 002 | train_loss=1.1130 acc=0.3903 | val_loss=1.0757 acc=0.4317 | time=23.7s
Epoch 003 | train_loss=1.1079 acc=0.3876 | val_loss=1.0774 acc=0.4317 | time=23.7s
Epoch 004 | train_loss=1.0945 acc=0.3814 | val_loss=1.0768 acc=0.4317 | time=23.6s


0,1
epoch,▁▃▆█
train_accuracy,▁█▇▄
train_loss,█▄▃▁
validation_accuracy,▁▁▁▁
validation_loss,█▁▄▃

0,1
epoch,4.0
train_accuracy,0.38136
train_loss,1.09445
validation_accuracy,0.43168
validation_loss,1.07683


[I 2025-05-03 05:51:26,440] Trial 7 pruned. 


▸ Trial 7 pruned at epoch 5



===== Trial 8 =====
 lr=5.00e-04, wd=5.00e-06, blocks=3, heads=2, segs=5
Epoch 001 | train_loss=1.1411 acc=0.3487 | val_loss=1.1201 acc=0.3416 | time=32.5s
Epoch 002 | train_loss=1.1171 acc=0.3619 | val_loss=1.0757 acc=0.4317 | time=32.3s
Epoch 003 | train_loss=1.1008 acc=0.3895 | val_loss=1.0791 acc=0.3416 | time=32.5s
Epoch 004 | train_loss=1.0933 acc=0.3845 | val_loss=1.0762 acc=0.4317 | time=32.4s


0,1
epoch,▁▃▆█
train_accuracy,▁▃█▇
train_loss,█▄▂▁
validation_accuracy,▁█▁█
validation_loss,█▁▂▁

0,1
epoch,4.0
train_accuracy,0.38447
train_loss,1.0933
validation_accuracy,0.43168
validation_loss,1.0762


[I 2025-05-03 05:54:10,560] Trial 8 pruned. 


▸ Trial 8 pruned at epoch 5

===== Best Trial =====
best_val_loss       = 0.618318
best_train_loss     = 0.479521
best_train_accuracy = 0.7833
best_val_accuracy   = 0.7500
best params:
  lr: 0.0005
  weight_decay: 0.0005
  num_blocks: 3
  num_heads: 2
  num_segments: 5


### K fold Cross Validation Test using best hyperparameter
- Block=1, Head=3
- ReduceLRPlateu --> Not stable

In [None]:
import os
import json
import time
import gc

import torch
import torch.nn as nn
import wandb
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import KFold

from eeg_dataset import EEGDataset
from model_optimized_3 import EEGformer

# ─── Fixed Hyperparameters ─────────────────────────────────────
LR = 5e-4
WEIGHT_DECAY = 5e-2
NUM_FILTERS = 120
NUM_BLOCKS = 1
NUM_HEADS = 3
NUM_SEGMENTS = 5

# ─── Training configuration ───────────────────────────────────
MAX_EPOCHS = 100
BATCH_SIZE = 32
NUM_WORKERS = max(1, min(4, os.cpu_count() - 1))

# ─── Data paths & device ──────────────────────────────────────
DATA_DIR = '/content/drive/MyDrive/2025_Lab_Research/model-data'
LABEL_FILE = "labels.json"
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ─── Scheduler parameters ─────────────────────────────────────
FIXED_STEP_SIZE = 5
FIXED_GAMMA = 0.5

def train_and_evaluate():
    # 1) 데이터 로드
    with open(os.path.join(DATA_DIR, LABEL_FILE), "r") as f:
        all_meta = json.load(f)
    train_meta = [d for d in all_meta if d["type"] == "train"]
    full_ds = EEGDataset(DATA_DIR, train_meta)
    labels = [d["label"] for d in train_meta]
    input_length = full_ds[0][0].shape[-1]

    # 2) 5-Fold split
    kf = KFold(n_splits=5, shuffle=True, random_state=42)
    fold_results = []

    for fold, (train_idx, val_idx) in enumerate(kf.split(full_ds, labels), 1):
        print(f"\n===== Fold {fold} =====")
        # wandb run 시작
        wandb.init(
            project="eeg-5fold-cv-2",
            name=f"fold_{fold}",
            config={
                "lr": LR,
                "weight_decay": WEIGHT_DECAY,
                "num_blocks": NUM_BLOCKS,
                "num_heads": NUM_HEADS,
                "num_segments": NUM_SEGMENTS,
                "batch_size": BATCH_SIZE,
                "max_epochs": MAX_EPOCHS
            }
        )

        train_loader = DataLoader(
            Subset(full_ds, train_idx),
            batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS
        )
        val_loader = DataLoader(
            Subset(full_ds, val_idx),
            batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
        )

        # 3) 모델 / 옵티마이저 / 손실함수
        model = EEGformer(
            in_channels=19,
            input_length=input_length,
            kernel_size=10,
            num_filters=NUM_FILTERS,
            num_heads=NUM_HEADS,
            num_blocks=NUM_BLOCKS,
            num_segments=NUM_SEGMENTS,
            num_classes=3
        ).to(DEVICE)

        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=LR,
            weight_decay=WEIGHT_DECAY
        )
        criterion = nn.CrossEntropyLoss()
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=FIXED_GAMMA,
            patience=FIXED_STEP_SIZE,
            min_lr=1e-6
        )

        best_train_loss = best_train_acc = None
        best_val_loss = float("inf")
        best_val_acc = 0

        # 4) Epoch 루프
        for epoch in range(1, MAX_EPOCHS + 1):
            t0 = time.time()
            # — train —
            model.train()
            tloss = tcorrect = ttotal = 0
            for X, y in train_loader:
                X, y = X.to(DEVICE), y.to(DEVICE)
                optimizer.zero_grad()
                logits = model(X)
                loss = criterion(logits, y)
                loss.backward()
                optimizer.step()
                tloss    += loss.item()
                tcorrect += (logits.argmax(1) == y).sum().item()
                ttotal   += y.size(0)
            train_loss = tloss / len(train_loader)
            train_acc  = tcorrect / ttotal

            # — validate —
            model.eval()
            vloss = vcorrect = vtotal = 0
            with torch.no_grad():
                for X, y in val_loader:
                    X, y = X.to(DEVICE), y.to(DEVICE)
                    logits = model(X)
                    loss   = criterion(logits, y)
                    vloss  += loss.item()
                    vcorrect += (logits.argmax(1) == y).sum().item()
                    vtotal   += y.size(0)
            val_loss = vloss / len(val_loader)
            val_acc  = vcorrect / vtotal
            elapsed  = time.time() - t0

            # — 터미널 출력 —
            print(
                f"Epoch {epoch:03d} | "
                f"train_loss={train_loss:.4f} acc={train_acc:.4f} | "
                f"val_loss={val_loss:.4f} acc={val_acc:.4f} | "
                f"time={elapsed:.1f}s"
            )

            # — wandb 로깅 —
            wandb.log({
                "epoch":               epoch,
                "train_loss":          train_loss,
                "train_accuracy":      train_acc,
                "validation_loss":     val_loss,
                "validation_accuracy": val_acc,
            }, step=epoch)

            # Scheduler
            scheduler.step(val_loss)

            # best 갱신
            if val_loss < best_val_loss:
                best_val_loss   = val_loss
                best_val_acc    = val_acc
                best_train_loss = train_loss
                best_train_acc  = train_acc

        # Fold 종료 시 summary 기록
        print(
            f"Fold {fold} best_train_loss={best_train_loss:.4f}, "
            f"best_train_acc={best_train_acc:.4f}, "
            f"best_val_loss={best_val_loss:.4f}, "
            f"best_val_acc={best_val_acc:.4f}"
        )
        wandb.summary["best_train_loss"]     = best_train_loss
        wandb.summary["best_train_accuracy"] = best_train_acc
        wandb.summary["best_val_loss"]       = best_val_loss
        wandb.summary["best_val_accuracy"]   = best_val_acc

        fold_results.append({
            "train_loss": best_train_loss,
            "train_acc":  best_train_acc,
            "val_loss":   best_val_loss,
            "val_acc":    best_val_acc
        })

        # cleanup
        wandb.finish()
        torch.cuda.empty_cache()
        gc.collect()

    # ─── 5-Fold Average Metrics 로깅 ─────────────────────────────
    avg = {k: sum(res[k] for res in fold_results) / len(fold_results)
           for k in fold_results[0]}

    # 별도 W&B run 으로 평균 지표 기록
    wandb.init(
        project="eeg-5fold-cv-2",
        name="fold_average",
        reinit=True,
        config={
            "lr": LR,
            "weight_decay": WEIGHT_DECAY,
            "num_blocks": NUM_BLOCKS,
            "num_heads": NUM_HEADS,
            "num_segments": NUM_SEGMENTS,
            "batch_size": BATCH_SIZE,
            "max_epochs": MAX_EPOCHS
        }
    )
    wandb.log({
        "avg_train_loss":      avg["train_loss"],
        "avg_train_accuracy":  avg["train_acc"],
        "avg_val_loss":        avg["val_loss"],
        "avg_val_accuracy":    avg["val_acc"],
    })
    wandb.summary.update({
        "avg_train_loss":      avg["train_loss"],
        "avg_train_accuracy":  avg["train_acc"],
        "avg_val_loss":        avg["val_loss"],
        "avg_val_accuracy":    avg["val_acc"],
    })
    wandb.finish()

    # ─── 5-Fold CV Summary 터미널 출력 ────────────────────────────
    print("\n===== 5-Fold CV Summary =====")
    for i, res in enumerate(fold_results, 1):
        print(
            f" Fold {i:2d}: "
            f"train_loss = {res['train_loss']:.4f}, "
            f"train_acc = {res['train_acc']:.4f}, "
            f"val_loss = {res['val_loss']:.4f}, "
            f"val_acc = {res['val_acc']:.4f}"
        )
    print(
        f"\n Average: "
        f"train_loss = {avg['train_loss']:.4f}, "
        f"train_acc = {avg['train_acc']:.4f}, "
        f"val_loss = {avg['val_loss']:.4f}, "
        f"val_acc = {avg['val_acc']:.4f}"
    )


if __name__ == "__main__":
    train_and_evaluate()


Now using CUDA device 0
Enabling CUDA with 39.14 GiB available memory

===== Fold 1 =====


Epoch 001 | train_loss=1.1964 acc=0.3340 | val_loss=1.0918 acc=0.3587 | time=18.6s
Epoch 002 | train_loss=1.1114 acc=0.3643 | val_loss=1.1037 acc=0.4006 | time=18.2s
Epoch 003 | train_loss=1.1302 acc=0.3662 | val_loss=1.0936 acc=0.4022 | time=18.0s
Epoch 004 | train_loss=1.1028 acc=0.3887 | val_loss=1.0953 acc=0.4286 | time=18.1s
Epoch 005 | train_loss=1.0821 acc=0.4074 | val_loss=1.0885 acc=0.5078 | time=18.3s
Epoch 006 | train_loss=1.0702 acc=0.4400 | val_loss=1.6323 acc=0.3587 | time=18.1s
Epoch 007 | train_loss=1.0474 acc=0.4695 | val_loss=1.0324 acc=0.5233 | time=18.1s
Epoch 008 | train_loss=0.9879 acc=0.5480 | val_loss=0.9448 acc=0.6040 | time=18.2s
Epoch 009 | train_loss=0.9583 acc=0.5705 | val_loss=0.9453 acc=0.6320 | time=18.1s
Epoch 010 | train_loss=0.9459 acc=0.5693 | val_loss=0.9224 acc=0.6227 | time=18.2s
Epoch 011 | train_loss=0.9050 acc=0.5957 | val_loss=0.8497 acc=0.6398 | time=18.0s
Epoch 012 | train_loss=0.8867 acc=0.6054 | val_loss=0.9269 acc=0.5885 | time=18.2s
Epoc

0,1
epoch,▁▁▂▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇██
train_accuracy,▁▁▂▄▄▅▅▅▅▆▅▆▆▇▇▇███▇████████████████████
train_loss,██▇▇▆▅▅▄▄▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation_accuracy,▁▂▂▁▆▅▇█▇▇▇▇▇▇█▇▇█▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇
validation_loss,▄▄█▃▃▂▂▂▁▁▂▁▁▁▃▂▃▄▃▃▄▃▃▃▃▄▄▄▄▃▃▄▄▃▃▃▄▄▄▄

0,1
best_train_accuracy,0.69049
best_train_loss,0.67951
best_val_accuracy,0.71118
best_val_loss,0.71002
epoch,100.0
train_accuracy,0.83806
train_loss,0.36556
validation_accuracy,0.67702
validation_loss,1.05496



===== Fold 2 =====


Epoch 001 | train_loss=1.1512 acc=0.3751 | val_loss=1.0792 acc=0.4255 | time=18.1s
Epoch 002 | train_loss=1.1054 acc=0.3903 | val_loss=1.0804 acc=0.4255 | time=18.0s
Epoch 003 | train_loss=1.0996 acc=0.3957 | val_loss=1.0763 acc=0.4255 | time=18.2s
Epoch 004 | train_loss=1.0974 acc=0.3996 | val_loss=1.0657 acc=0.5031 | time=18.1s
Epoch 005 | train_loss=1.0610 acc=0.4520 | val_loss=0.9951 acc=0.5621 | time=18.2s
Epoch 006 | train_loss=0.9868 acc=0.5363 | val_loss=0.9157 acc=0.6351 | time=18.1s
Epoch 007 | train_loss=0.9582 acc=0.5608 | val_loss=0.9044 acc=0.6537 | time=18.0s
Epoch 008 | train_loss=0.9286 acc=0.5783 | val_loss=0.8829 acc=0.6242 | time=18.2s
Epoch 009 | train_loss=0.8926 acc=0.5849 | val_loss=0.8285 acc=0.6786 | time=18.1s
Epoch 010 | train_loss=0.8613 acc=0.5981 | val_loss=0.8492 acc=0.6925 | time=18.3s
Epoch 011 | train_loss=0.8509 acc=0.6058 | val_loss=0.8419 acc=0.6335 | time=18.0s
Epoch 012 | train_loss=0.8210 acc=0.6249 | val_loss=0.7974 acc=0.6910 | time=18.3s
Epoc

KeyboardInterrupt: 

### Learning Rate Scheduler 바꾸기
- ReduceLRPlateu -> CosineAnnealingWarmupRestarts

In [None]:
import os
import json
import time
import gc

import torch
import torch.nn as nn
import wandb
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import KFold

from eeg_dataset import EEGDataset
from model_optimized_3 import EEGformer

# ─── Fixed Hyperparameters ─────────────────────────────────────
LR = 5e-4
WEIGHT_DECAY = 5e-2
NUM_FILTERS = 120
NUM_BLOCKS = 1
NUM_HEADS = 3
NUM_SEGMENTS = 5

# ─── Training configuration ───────────────────────────────────
MAX_EPOCHS = 100
BATCH_SIZE = 32
NUM_WORKERS = max(1, min(4, os.cpu_count() - 1))

# ─── Data paths & device ──────────────────────────────────────
DATA_DIR = '/content/drive/MyDrive/2025_Lab_Research/model-data'
LABEL_FILE = "labels.json"
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ─── Scheduler parameters ─────────────────────────────────────
FIXED_STEP_SIZE = 5
FIXED_GAMMA = 0.5

def train_and_evaluate():
    # 1) 데이터 로드
    with open(os.path.join(DATA_DIR, LABEL_FILE), "r") as f:
        all_meta = json.load(f)
    train_meta = [d for d in all_meta if d["type"] == "train"]
    full_ds = EEGDataset(DATA_DIR, train_meta)
    labels = [d["label"] for d in train_meta]
    input_length = full_ds[0][0].shape[-1]

    # 2) 5-Fold split
    kf = KFold(n_splits=5, shuffle=True, random_state=42)
    fold_results = []

    for fold, (train_idx, val_idx) in enumerate(kf.split(full_ds, labels), 1):
        print(f"\n===== Fold {fold} =====")
        # wandb run 시작
        wandb.init(
            project="eeg-5fold-cv-3",
            name=f"fold_{fold}",
            config={
                "lr": LR,
                "weight_decay": WEIGHT_DECAY,
                "num_blocks": NUM_BLOCKS,
                "num_heads": NUM_HEADS,
                "num_segments": NUM_SEGMENTS,
                "batch_size": BATCH_SIZE,
                "max_epochs": MAX_EPOCHS
            }
        )

        train_loader = DataLoader(
            Subset(full_ds, train_idx),
            batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS
        )
        val_loader = DataLoader(
            Subset(full_ds, val_idx),
            batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
        )

        # 3) 모델 / 옵티마이저 / 손실함수
        model = EEGformer(
            in_channels=19,
            input_length=input_length,
            kernel_size=10,
            num_filters=NUM_FILTERS,
            num_heads=NUM_HEADS,
            num_blocks=NUM_BLOCKS,
            num_segments=NUM_SEGMENTS,
            num_classes=3
        ).to(DEVICE)

        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=LR,
            weight_decay=WEIGHT_DECAY
        )
        criterion = nn.CrossEntropyLoss()
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            T_0=10,
            T_mult=2,
            eta_min=1e-6
            )

        best_train_loss = best_train_acc = None
        best_val_loss = float("inf")
        best_val_acc = 0

        # 4) Epoch 루프
        for epoch in range(1, MAX_EPOCHS + 1):
            t0 = time.time()
            # — train —
            model.train()
            tloss = tcorrect = ttotal = 0
            for X, y in train_loader:
                X, y = X.to(DEVICE), y.to(DEVICE)
                optimizer.zero_grad()
                logits = model(X)
                loss = criterion(logits, y)
                loss.backward()
                optimizer.step()
                tloss    += loss.item()
                tcorrect += (logits.argmax(1) == y).sum().item()
                ttotal   += y.size(0)
            train_loss = tloss / len(train_loader)
            train_acc  = tcorrect / ttotal

            # — validate —
            model.eval()
            vloss = vcorrect = vtotal = 0
            with torch.no_grad():
                for X, y in val_loader:
                    X, y = X.to(DEVICE), y.to(DEVICE)
                    logits = model(X)
                    loss   = criterion(logits, y)
                    vloss  += loss.item()
                    vcorrect += (logits.argmax(1) == y).sum().item()
                    vtotal   += y.size(0)
            val_loss = vloss / len(val_loader)
            val_acc  = vcorrect / vtotal
            elapsed  = time.time() - t0

            # — 터미널 출력 —
            print(
                f"Epoch {epoch:03d} | "
                f"train_loss={train_loss:.4f} acc={train_acc:.4f} | "
                f"val_loss={val_loss:.4f} acc={val_acc:.4f} | "
                f"time={elapsed:.1f}s"
            )

            # — wandb 로깅 —
            wandb.log({
                "epoch":               epoch,
                "train_loss":          train_loss,
                "train_accuracy":      train_acc,
                "validation_loss":     val_loss,
                "validation_accuracy": val_acc,
            }, step=epoch)

            # Scheduler
            scheduler.step()

            # best 갱신
            if val_loss < best_val_loss:
                best_val_loss   = val_loss
                best_val_acc    = val_acc
                best_train_loss = train_loss
                best_train_acc  = train_acc

        # Fold 종료 시 summary 기록
        print(
            f"Fold {fold} best_train_loss={best_train_loss:.4f}, "
            f"best_train_acc={best_train_acc:.4f}, "
            f"best_val_loss={best_val_loss:.4f}, "
            f"best_val_acc={best_val_acc:.4f}"
        )
        wandb.summary["best_train_loss"]     = best_train_loss
        wandb.summary["best_train_accuracy"] = best_train_acc
        wandb.summary["best_val_loss"]       = best_val_loss
        wandb.summary["best_val_accuracy"]   = best_val_acc

        fold_results.append({
            "train_loss": best_train_loss,
            "train_acc":  best_train_acc,
            "val_loss":   best_val_loss,
            "val_acc":    best_val_acc
        })

        # cleanup
        wandb.finish()
        torch.cuda.empty_cache()
        gc.collect()

    # ─── 5-Fold Average Metrics 로깅 ─────────────────────────────
    avg = {k: sum(res[k] for res in fold_results) / len(fold_results)
           for k in fold_results[0]}

    # 별도 W&B run 으로 평균 지표 기록
    wandb.init(
        project="eeg-5fold-cv-3",
        name="fold_average",
        reinit=True,
        config={
            "lr": LR,
            "weight_decay": WEIGHT_DECAY,
            "num_blocks": NUM_BLOCKS,
            "num_heads": NUM_HEADS,
            "num_segments": NUM_SEGMENTS,
            "batch_size": BATCH_SIZE,
            "max_epochs": MAX_EPOCHS
        }
    )
    wandb.log({
        "avg_train_loss":      avg["train_loss"],
        "avg_train_accuracy":  avg["train_acc"],
        "avg_val_loss":        avg["val_loss"],
        "avg_val_accuracy":    avg["val_acc"],
    })
    wandb.summary.update({
        "avg_train_loss":      avg["train_loss"],
        "avg_train_accuracy":  avg["train_acc"],
        "avg_val_loss":        avg["val_loss"],
        "avg_val_accuracy":    avg["val_acc"],
    })
    wandb.finish()

    # ─── 5-Fold CV Summary 터미널 출력 ────────────────────────────
    print("\n===== 5-Fold CV Summary =====")
    for i, res in enumerate(fold_results, 1):
        print(
            f" Fold {i:2d}: "
            f"train_loss = {res['train_loss']:.4f}, "
            f"train_acc = {res['train_acc']:.4f}, "
            f"val_loss = {res['val_loss']:.4f}, "
            f"val_acc = {res['val_acc']:.4f}"
        )
    print(
        f"\n Average: "
        f"train_loss = {avg['train_loss']:.4f}, "
        f"train_acc = {avg['train_acc']:.4f}, "
        f"val_loss = {avg['val_loss']:.4f}, "
        f"val_acc = {avg['val_acc']:.4f}"
    )


if __name__ == "__main__":
    train_and_evaluate()


Now using CUDA device 0
Enabling CUDA with 39.14 GiB available memory

===== Fold 1 =====


Epoch 001 | train_loss=1.1428 acc=0.3771 | val_loss=1.0937 acc=0.3587 | time=18.6s
Epoch 002 | train_loss=1.1296 acc=0.3755 | val_loss=1.0822 acc=0.4783 | time=18.3s
Epoch 003 | train_loss=1.0983 acc=0.3930 | val_loss=1.0740 acc=0.4161 | time=18.2s
Epoch 004 | train_loss=1.0355 acc=0.4901 | val_loss=0.9692 acc=0.5901 | time=18.1s
Epoch 005 | train_loss=0.9776 acc=0.5518 | val_loss=0.9922 acc=0.6227 | time=18.2s
Epoch 006 | train_loss=0.9668 acc=0.5569 | val_loss=0.9329 acc=0.5730 | time=18.2s
Epoch 007 | train_loss=0.9282 acc=0.5876 | val_loss=0.9263 acc=0.6289 | time=18.4s
Epoch 008 | train_loss=0.9201 acc=0.5852 | val_loss=0.8918 acc=0.6382 | time=18.0s
Epoch 009 | train_loss=0.9010 acc=0.6050 | val_loss=0.8902 acc=0.6382 | time=18.1s
Epoch 010 | train_loss=0.9084 acc=0.6039 | val_loss=0.8965 acc=0.6320 | time=18.1s
Epoch 011 | train_loss=0.9142 acc=0.5872 | val_loss=0.9102 acc=0.6553 | time=18.1s
Epoch 012 | train_loss=0.8876 acc=0.6070 | val_loss=0.8223 acc=0.6460 | time=18.4s
Epoc

0,1
epoch,▁▁▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇████
train_accuracy,▁▃▄▄▄▄▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇████████▇█▇████████
train_loss,█▇▇▆▆▆▆▅▅▄▄▃▄▃▄▄▄▃▃▃▃▃▂▂▂▁▁▁▁▁▁▁▂▂▂▂▂▁▁▁
validation_accuracy,▁▃▂▅▆▆▇▆▇████▆██▅▇▆▇▇█▇▆▇▇▆▇▇▆▆▆▇▇▆▅▆▇▆▆
validation_loss,▃▃▃▂▂▁▁▁▁▂▂▁▂▂▂▂▁▃▂▃▂▃▃▄▄▄▄▄▄▄▄▅▄▅▅▆▅█▆▆

0,1
best_train_accuracy,0.78252
best_train_loss,0.53374
best_val_accuracy,0.6941
best_val_loss,0.68011
epoch,100.0
train_accuracy,0.83961
train_loss,0.34271
validation_accuracy,0.59317
validation_loss,1.59282



===== Fold 2 =====


Epoch 001 | train_loss=1.1421 acc=0.3460 | val_loss=1.0796 acc=0.4255 | time=18.2s
Epoch 002 | train_loss=1.0977 acc=0.3926 | val_loss=1.0727 acc=0.4255 | time=18.1s
Epoch 003 | train_loss=1.0634 acc=0.4447 | val_loss=1.0168 acc=0.4891 | time=18.3s
Epoch 004 | train_loss=1.0039 acc=0.5076 | val_loss=0.9399 acc=0.6242 | time=18.1s
Epoch 005 | train_loss=0.9778 acc=0.5348 | val_loss=0.9044 acc=0.6553 | time=18.1s
Epoch 006 | train_loss=0.9542 acc=0.5569 | val_loss=0.8612 acc=0.6755 | time=18.2s
Epoch 007 | train_loss=0.9248 acc=0.5860 | val_loss=0.8604 acc=0.6848 | time=18.0s
Epoch 008 | train_loss=0.8801 acc=0.6101 | val_loss=0.8552 acc=0.7003 | time=18.3s
Epoch 009 | train_loss=0.8657 acc=0.6245 | val_loss=0.8288 acc=0.6988 | time=18.2s
Epoch 010 | train_loss=0.8589 acc=0.6280 | val_loss=0.8281 acc=0.7081 | time=18.1s
Epoch 011 | train_loss=0.8787 acc=0.6066 | val_loss=0.8907 acc=0.5807 | time=18.1s
Epoch 012 | train_loss=0.8753 acc=0.6085 | val_loss=0.7857 acc=0.6801 | time=18.0s
Epoc

KeyboardInterrupt: 

In [None]:
Epoch 021 | train_loss=0.6985 acc=0.6928 | val_loss=0.6842 acc=0.7469 | time=18.1s

### ReducePlateuLR Step Size -> 10,  gamma factor-> 0.8
### Model Revision
- CNNDecoder CNN Layer N1 = 16, N2= 32 revision
- CNNDecoder dropout rate increase to 0.5
- Transformer Block Encoder  dropout rate = 0.2
- Weight Decay to be strong 5e-2 -> 7e-2
- LR gamma factor 0.5 -> 0.7 revision
- Number of segments = 5 -> 3: prevent overfitting and focus on Global Context

In [None]:
import os
import json
import time
import gc

import torch
import torch.nn as nn
import wandb
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import KFold

from eeg_dataset import EEGDataset
from model_optimized_4 import EEGformer

# ─── Fixed Hyperparameters ─────────────────────────────────────
LR = 5e-4
WEIGHT_DECAY = 5e-4
NUM_FILTERS = 120
NUM_BLOCKS = 1
NUM_HEADS = 3
NUM_SEGMENTS = 3

# ─── Training configuration ───────────────────────────────────
MAX_EPOCHS = 100
BATCH_SIZE = 32
NUM_WORKERS = max(1, min(4, os.cpu_count() - 1))

# ─── Data paths & device ──────────────────────────────────────
DATA_DIR = '/content/drive/MyDrive/2025_Lab_Research/model-data'
LABEL_FILE = "labels.json"
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ─── Scheduler parameters ─────────────────────────────────────
FIXED_STEP_SIZE = 5
FIXED_GAMMA = 0.7

def train_and_evaluate():
    # 1) 데이터 로드
    with open(os.path.join(DATA_DIR, LABEL_FILE), "r") as f:
        all_meta = json.load(f)
    train_meta = [d for d in all_meta if d["type"] == "train"]
    full_ds = EEGDataset(DATA_DIR, train_meta)
    labels = [d["label"] for d in train_meta]
    input_length = full_ds[0][0].shape[-1]

    # 2) 5-Fold split
    kf = KFold(n_splits=5, shuffle=True, random_state=42)
    fold_results = []

    for fold, (train_idx, val_idx) in enumerate(kf.split(full_ds, labels), 1):
        print(f"\n===== Fold {fold} =====")
        # wandb run 시작
        wandb.init(
            project="eeg-5fold-cv-4",
            name=f"fold_{fold}",
            config={
                "lr": LR,
                "weight_decay": WEIGHT_DECAY,
                "num_blocks": NUM_BLOCKS,
                "num_heads": NUM_HEADS,
                "num_segments": NUM_SEGMENTS,
                "batch_size": BATCH_SIZE,
                "max_epochs": MAX_EPOCHS
            }
        )

        train_loader = DataLoader(
            Subset(full_ds, train_idx),
            batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS
        )
        val_loader = DataLoader(
            Subset(full_ds, val_idx),
            batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
        )

        # 3) 모델 / 옵티마이저 / 손실함수
        model = EEGformer(
            in_channels=19,
            input_length=input_length,
            kernel_size=10,
            num_filters=NUM_FILTERS,
            num_heads=NUM_HEADS,
            num_blocks=NUM_BLOCKS,
            num_segments=NUM_SEGMENTS,
            num_classes=3
        ).to(DEVICE)

        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=LR,
            weight_decay=WEIGHT_DECAY
        )
        criterion = nn.CrossEntropyLoss()
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=FIXED_GAMMA,
            patience=FIXED_STEP_SIZE,
            min_lr=1e-6
        )

        best_train_loss = best_train_acc = None
        best_val_loss = float("inf")
        best_val_acc = 0

        # 4) Epoch 루프
        for epoch in range(1, MAX_EPOCHS + 1):
            t0 = time.time()
            # — train —
            model.train()
            tloss = tcorrect = ttotal = 0
            for X, y in train_loader:
                X, y = X.to(DEVICE), y.to(DEVICE)
                optimizer.zero_grad()
                logits = model(X)
                loss = criterion(logits, y)
                loss.backward()
                optimizer.step()
                tloss    += loss.item()
                tcorrect += (logits.argmax(1) == y).sum().item()
                ttotal   += y.size(0)
            train_loss = tloss / len(train_loader)
            train_acc  = tcorrect / ttotal

            # — validate —
            model.eval()
            vloss = vcorrect = vtotal = 0
            with torch.no_grad():
                for X, y in val_loader:
                    X, y = X.to(DEVICE), y.to(DEVICE)
                    logits = model(X)
                    loss   = criterion(logits, y)
                    vloss  += loss.item()
                    vcorrect += (logits.argmax(1) == y).sum().item()
                    vtotal   += y.size(0)
            val_loss = vloss / len(val_loader)
            val_acc  = vcorrect / vtotal
            elapsed  = time.time() - t0

            # — 터미널 출력 —
            print(
                f"Epoch {epoch:03d} | "
                f"train_loss={train_loss:.4f} acc={train_acc:.4f} | "
                f"val_loss={val_loss:.4f} acc={val_acc:.4f} | "
                f"time={elapsed:.1f}s"
            )

            # — wandb 로깅 —
            wandb.log({
                "epoch":               epoch,
                "train_loss":          train_loss,
                "train_accuracy":      train_acc,
                "validation_loss":     val_loss,
                "validation_accuracy": val_acc,
            }, step=epoch)

            # Scheduler
            scheduler.step(val_loss)

            # best 갱신
            if val_loss < best_val_loss:
                best_val_loss   = val_loss
                best_val_acc    = val_acc
                best_train_loss = train_loss
                best_train_acc  = train_acc

        # Fold 종료 시 summary 기록
        print(
            f"Fold {fold} best_train_loss={best_train_loss:.4f}, "
            f"best_train_acc={best_train_acc:.4f}, "
            f"best_val_loss={best_val_loss:.4f}, "
            f"best_val_acc={best_val_acc:.4f}"
        )
        wandb.summary["best_train_loss"]     = best_train_loss
        wandb.summary["best_train_accuracy"] = best_train_acc
        wandb.summary["best_val_loss"]       = best_val_loss
        wandb.summary["best_val_accuracy"]   = best_val_acc

        fold_results.append({
            "train_loss": best_train_loss,
            "train_acc":  best_train_acc,
            "val_loss":   best_val_loss,
            "val_acc":    best_val_acc
        })

        # cleanup
        wandb.finish()
        torch.cuda.empty_cache()
        gc.collect()

    # ─── 5-Fold Average Metrics 로깅 ─────────────────────────────
    avg = {k: sum(res[k] for res in fold_results) / len(fold_results)
           for k in fold_results[0]}

    # 별도 W&B run 으로 평균 지표 기록
    wandb.init(
        project="eeg-5fold-cv-4",
        name="fold_average",
        reinit=True,
        config={
            "lr": LR,
            "weight_decay": WEIGHT_DECAY,
            "num_blocks": NUM_BLOCKS,
            "num_heads": NUM_HEADS,
            "num_segments": NUM_SEGMENTS,
            "batch_size": BATCH_SIZE,
            "max_epochs": MAX_EPOCHS
        }
    )
    wandb.log({
        "avg_train_loss":      avg["train_loss"],
        "avg_train_accuracy":  avg["train_acc"],
        "avg_val_loss":        avg["val_loss"],
        "avg_val_accuracy":    avg["val_acc"],
    })
    wandb.summary.update({
        "avg_train_loss":      avg["train_loss"],
        "avg_train_accuracy":  avg["train_acc"],
        "avg_val_loss":        avg["val_loss"],
        "avg_val_accuracy":    avg["val_acc"],
    })
    wandb.finish()

    # ─── 5-Fold CV Summary 터미널 출력 ────────────────────────────
    print("\n===== 5-Fold CV Summary =====")
    for i, res in enumerate(fold_results, 1):
        print(
            f" Fold {i:2d}: "
            f"train_loss = {res['train_loss']:.4f}, "
            f"train_acc = {res['train_acc']:.4f}, "
            f"val_loss = {res['val_loss']:.4f}, "
            f"val_acc = {res['val_acc']:.4f}"
        )
    print(
        f"\n Average: "
        f"train_loss = {avg['train_loss']:.4f}, "
        f"train_acc = {avg['train_acc']:.4f}, "
        f"val_loss = {avg['val_loss']:.4f}, "
        f"val_acc = {avg['val_acc']:.4f}"
    )


if __name__ == "__main__":
    train_and_evaluate()


Attempting to create new mne-python configuration file:
/root/.mne/mne-python.json
Now using CUDA device 0
Enabling CUDA with 39.14 GiB available memory

===== Fold 1 =====


Epoch 001 | train_loss=1.1635 acc=0.3173 | val_loss=1.0962 acc=0.3587 | time=192.9s
Epoch 002 | train_loss=1.1223 acc=0.3577 | val_loss=1.0910 acc=0.4006 | time=19.0s
Epoch 003 | train_loss=1.0954 acc=0.3957 | val_loss=1.0888 acc=0.4006 | time=18.7s
Epoch 004 | train_loss=1.0854 acc=0.4221 | val_loss=1.0899 acc=0.4006 | time=19.0s
Epoch 005 | train_loss=1.0866 acc=0.4058 | val_loss=1.0899 acc=0.4006 | time=19.0s
Epoch 006 | train_loss=1.0859 acc=0.4066 | val_loss=1.0885 acc=0.4006 | time=18.5s
Epoch 007 | train_loss=1.0766 acc=0.4194 | val_loss=1.0895 acc=0.4006 | time=19.0s
Epoch 008 | train_loss=1.0763 acc=0.4276 | val_loss=1.0882 acc=0.4006 | time=18.8s
Epoch 009 | train_loss=1.0770 acc=0.4175 | val_loss=1.0856 acc=0.4006 | time=19.3s
Epoch 010 | train_loss=1.0738 acc=0.4287 | val_loss=1.0855 acc=0.4006 | time=19.1s
Epoch 011 | train_loss=1.0485 acc=0.4532 | val_loss=1.0444 acc=0.5994 | time=19.3s
Epoch 012 | train_loss=0.9995 acc=0.5243 | val_loss=0.9694 acc=0.5745 | time=19.2s
Epo

0,1
epoch,▁▁▁▂▂▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇██
train_accuracy,▁▁▁▂▃▄▅▅▆▆▆▆▆▆▇▇▇▇▇██▇██████████████████
train_loss,█████▆▅▅▅▄▄▄▄▄▄▃▃▃▃▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation_accuracy,▁▁▁▆▇█▇█▇█▆▇▇▇█▇▇▇█▇█▇▇▇▇██▇██▇▇▇█▇▇███▇
validation_loss,▆▆▆▆▆▆▆▅▄▃▂▁▂▁▁▃▅▂▆▃▄▅▆▅▅▆▆▆▆▆▅▇▇█▆▇▇▇▆▇

0,1
best_train_accuracy,0.69282
best_train_loss,0.69853
best_val_accuracy,0.74689
best_val_loss,0.68421
epoch,100.0
train_accuracy,0.84078
train_loss,0.35132
validation_accuracy,0.6677
validation_loss,1.13687



===== Fold 2 =====


Epoch 001 | train_loss=1.1601 acc=0.3748 | val_loss=1.0952 acc=0.3665 | time=18.6s
Epoch 002 | train_loss=1.1454 acc=0.3697 | val_loss=1.0762 acc=0.4270 | time=18.5s
Epoch 003 | train_loss=1.1187 acc=0.3856 | val_loss=1.0760 acc=0.4255 | time=18.4s
Epoch 004 | train_loss=1.1237 acc=0.3779 | val_loss=1.0739 acc=0.5326 | time=18.8s
Epoch 005 | train_loss=1.1111 acc=0.3860 | val_loss=1.0736 acc=0.4752 | time=18.4s
Epoch 006 | train_loss=1.0955 acc=0.3915 | val_loss=1.0696 acc=0.4270 | time=18.9s
Epoch 007 | train_loss=1.0820 acc=0.4183 | val_loss=1.0134 acc=0.5730 | time=18.3s
Epoch 008 | train_loss=1.0317 acc=0.4905 | val_loss=0.9829 acc=0.6165 | time=18.6s
Epoch 009 | train_loss=1.0027 acc=0.5433 | val_loss=1.0021 acc=0.6320 | time=18.5s
Epoch 010 | train_loss=0.9560 acc=0.5736 | val_loss=0.9903 acc=0.6056 | time=18.3s
Epoch 011 | train_loss=0.9299 acc=0.5759 | val_loss=0.8320 acc=0.6444 | time=18.8s
Epoch 012 | train_loss=0.8908 acc=0.5981 | val_loss=0.8974 acc=0.6211 | time=18.3s
Epoc

KeyboardInterrupt: 

### Change the block = 1 -> 2, head = 3 -> 2

- After changing the segment = 3, it turns out that the model shows too high validation loss and accuracy is stuck around 66%
- Block = 2, Head = 2
- Remaining the dropout rate and layer num as before
- Turn back to segment = 5

In [None]:
import os
import json
import time
import gc

import torch
import torch.nn as nn
import wandb
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import KFold

from eeg_dataset import EEGDataset
from model_optimized_4 import EEGformer

# ─── Fixed Hyperparameters ─────────────────────────────────────
LR = 5e-4
WEIGHT_DECAY = 5e-4
NUM_FILTERS = 120
NUM_BLOCKS = 2
NUM_HEADS = 2
NUM_SEGMENTS = 5

# ─── Training configuration ───────────────────────────────────
MAX_EPOCHS = 100
BATCH_SIZE = 32
NUM_WORKERS = max(1, min(4, os.cpu_count() - 1))

# ─── Data paths & device ──────────────────────────────────────
DATA_DIR = '/content/drive/MyDrive/2025_Lab_Research/model-data'
LABEL_FILE = "labels.json"
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ─── Scheduler parameters ─────────────────────────────────────
FIXED_STEP_SIZE = 5
FIXED_GAMMA = 0.7

def train_and_evaluate():
    # 1) 데이터 로드
    with open(os.path.join(DATA_DIR, LABEL_FILE), "r") as f:
        all_meta = json.load(f)
    train_meta = [d for d in all_meta if d["type"] == "train"]
    full_ds = EEGDataset(DATA_DIR, train_meta)
    labels = [d["label"] for d in train_meta]
    input_length = full_ds[0][0].shape[-1]

    # 2) 5-Fold split
    kf = KFold(n_splits=5, shuffle=True, random_state=42)
    fold_results = []

    for fold, (train_idx, val_idx) in enumerate(kf.split(full_ds, labels), 1):
        print(f"\n===== Fold {fold} =====")
        # wandb run 시작
        wandb.init(
            project="eeg-5fold-cv-5",
            name=f"fold_{fold}",
            config={
                "lr": LR,
                "weight_decay": WEIGHT_DECAY,
                "num_blocks": NUM_BLOCKS,
                "num_heads": NUM_HEADS,
                "num_segments": NUM_SEGMENTS,
                "batch_size": BATCH_SIZE,
                "max_epochs": MAX_EPOCHS
            }
        )

        train_loader = DataLoader(
            Subset(full_ds, train_idx),
            batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS
        )
        val_loader = DataLoader(
            Subset(full_ds, val_idx),
            batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
        )

        # 3) 모델 / 옵티마이저 / 손실함수
        model = EEGformer(
            in_channels=19,
            input_length=input_length,
            kernel_size=10,
            num_filters=NUM_FILTERS,
            num_heads=NUM_HEADS,
            num_blocks=NUM_BLOCKS,
            num_segments=NUM_SEGMENTS,
            num_classes=3
        ).to(DEVICE)

        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=LR,
            weight_decay=WEIGHT_DECAY
        )
        criterion = nn.CrossEntropyLoss()
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=FIXED_GAMMA,
            patience=FIXED_STEP_SIZE,
            min_lr=1e-6
        )

        best_train_loss = best_train_acc = None
        best_val_loss = float("inf")
        best_val_acc = 0

        # 4) Epoch 루프
        for epoch in range(1, MAX_EPOCHS + 1):
            t0 = time.time()
            # — train —
            model.train()
            tloss = tcorrect = ttotal = 0
            for X, y in train_loader:
                X, y = X.to(DEVICE), y.to(DEVICE)
                optimizer.zero_grad()
                logits = model(X)
                loss = criterion(logits, y)
                loss.backward()
                optimizer.step()
                tloss    += loss.item()
                tcorrect += (logits.argmax(1) == y).sum().item()
                ttotal   += y.size(0)
            train_loss = tloss / len(train_loader)
            train_acc  = tcorrect / ttotal

            # — validate —
            model.eval()
            vloss = vcorrect = vtotal = 0
            with torch.no_grad():
                for X, y in val_loader:
                    X, y = X.to(DEVICE), y.to(DEVICE)
                    logits = model(X)
                    loss   = criterion(logits, y)
                    vloss  += loss.item()
                    vcorrect += (logits.argmax(1) == y).sum().item()
                    vtotal   += y.size(0)
            val_loss = vloss / len(val_loader)
            val_acc  = vcorrect / vtotal
            elapsed  = time.time() - t0

            # — 터미널 출력 —
            print(
                f"Epoch {epoch:03d} | "
                f"train_loss={train_loss:.4f} acc={train_acc:.4f} | "
                f"val_loss={val_loss:.4f} acc={val_acc:.4f} | "
                f"time={elapsed:.1f}s"
            )

            # — wandb 로깅 —
            wandb.log({
                "epoch":               epoch,
                "train_loss":          train_loss,
                "train_accuracy":      train_acc,
                "validation_loss":     val_loss,
                "validation_accuracy": val_acc,
            }, step=epoch)

            # Scheduler
            scheduler.step(val_loss)

            # best 갱신
            if val_loss < best_val_loss:
                best_val_loss   = val_loss
                best_val_acc    = val_acc
                best_train_loss = train_loss
                best_train_acc  = train_acc

        # Fold 종료 시 summary 기록
        print(
            f"Fold {fold} best_train_loss={best_train_loss:.4f}, "
            f"best_train_acc={best_train_acc:.4f}, "
            f"best_val_loss={best_val_loss:.4f}, "
            f"best_val_acc={best_val_acc:.4f}"
        )
        wandb.summary["best_train_loss"]     = best_train_loss
        wandb.summary["best_train_accuracy"] = best_train_acc
        wandb.summary["best_val_loss"]       = best_val_loss
        wandb.summary["best_val_accuracy"]   = best_val_acc

        fold_results.append({
            "train_loss": best_train_loss,
            "train_acc":  best_train_acc,
            "val_loss":   best_val_loss,
            "val_acc":    best_val_acc
        })

        # cleanup
        wandb.finish()
        torch.cuda.empty_cache()
        gc.collect()

    # ─── 5-Fold Average Metrics 로깅 ─────────────────────────────
    avg = {k: sum(res[k] for res in fold_results) / len(fold_results)
           for k in fold_results[0]}

    # 별도 W&B run 으로 평균 지표 기록
    wandb.init(
        project="eeg-5fold-cv-5",
        name="fold_average",
        reinit=True,
        config={
            "lr": LR,
            "weight_decay": WEIGHT_DECAY,
            "num_blocks": NUM_BLOCKS,
            "num_heads": NUM_HEADS,
            "num_segments": NUM_SEGMENTS,
            "batch_size": BATCH_SIZE,
            "max_epochs": MAX_EPOCHS
        }
    )
    wandb.log({
        "avg_train_loss":      avg["train_loss"],
        "avg_train_accuracy":  avg["train_acc"],
        "avg_val_loss":        avg["val_loss"],
        "avg_val_accuracy":    avg["val_acc"],
    })
    wandb.summary.update({
        "avg_train_loss":      avg["train_loss"],
        "avg_train_accuracy":  avg["train_acc"],
        "avg_val_loss":        avg["val_loss"],
        "avg_val_accuracy":    avg["val_acc"],
    })
    wandb.finish()

    # ─── 5-Fold CV Summary 터미널 출력 ────────────────────────────
    print("\n===== 5-Fold CV Summary =====")
    for i, res in enumerate(fold_results, 1):
        print(
            f" Fold {i:2d}: "
            f"train_loss = {res['train_loss']:.4f}, "
            f"train_acc = {res['train_acc']:.4f}, "
            f"val_loss = {res['val_loss']:.4f}, "
            f"val_acc = {res['val_acc']:.4f}"
        )
    print(
        f"\n Average: "
        f"train_loss = {avg['train_loss']:.4f}, "
        f"train_acc = {avg['train_acc']:.4f}, "
        f"val_loss = {avg['val_loss']:.4f}, "
        f"val_acc = {avg['val_acc']:.4f}"
    )


if __name__ == "__main__":
    train_and_evaluate()


Now using CUDA device 0
Enabling CUDA with 39.14 GiB available memory

===== Fold 1 =====


Epoch 001 | train_loss=1.1641 acc=0.3674 | val_loss=1.0901 acc=0.4006 | time=25.0s
Epoch 002 | train_loss=1.1204 acc=0.3895 | val_loss=1.0927 acc=0.4006 | time=24.2s
Epoch 003 | train_loss=1.1167 acc=0.3876 | val_loss=1.1304 acc=0.4006 | time=24.3s
Epoch 004 | train_loss=1.0982 acc=0.3988 | val_loss=1.0918 acc=0.4006 | time=24.2s
Epoch 005 | train_loss=1.0962 acc=0.3973 | val_loss=1.0883 acc=0.4006 | time=24.1s
Epoch 006 | train_loss=1.0847 acc=0.4268 | val_loss=1.0964 acc=0.4006 | time=24.0s
Epoch 007 | train_loss=1.0878 acc=0.4113 | val_loss=1.0909 acc=0.4006 | time=24.1s
Epoch 008 | train_loss=1.0616 acc=0.4505 | val_loss=1.0474 acc=0.5311 | time=24.2s
Epoch 009 | train_loss=1.0163 acc=0.5130 | val_loss=1.0542 acc=0.5590 | time=24.3s
Epoch 010 | train_loss=0.9332 acc=0.5825 | val_loss=0.8620 acc=0.6304 | time=24.1s
Epoch 011 | train_loss=0.8714 acc=0.5996 | val_loss=0.8327 acc=0.6522 | time=24.3s
Epoch 012 | train_loss=0.8411 acc=0.6202 | val_loss=0.8908 acc=0.6351 | time=24.1s
Epoc

KeyboardInterrupt: 

In [None]:
import os
import json
import time
import gc

import torch
import torch.nn as nn
import wandb
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import KFold

from eeg_dataset import EEGDataset
from model_optimized_4 import EEGformer

# ─── Fixed Hyperparameters ─────────────────────────────────────
LR = 5e-4
WEIGHT_DECAY = 5e-2
NUM_FILTERS = 120
NUM_BLOCKS = 2
NUM_HEADS = 2
NUM_SEGMENTS = 5

# ─── Training configuration ───────────────────────────────────
MAX_EPOCHS = 100
BATCH_SIZE = 32
NUM_WORKERS = max(1, min(4, os.cpu_count() - 1))

# ─── Data paths & device ──────────────────────────────────────
DATA_DIR = '/content/drive/MyDrive/2025_Lab_Research/model-data'
LABEL_FILE = "labels.json"
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ─── Scheduler parameters ─────────────────────────────────────
FIXED_STEP_SIZE = 5
FIXED_GAMMA = 0.7

def train_and_evaluate():
    # 1) 데이터 로드
    with open(os.path.join(DATA_DIR, LABEL_FILE), "r") as f:
        all_meta = json.load(f)
    train_meta = [d for d in all_meta if d["type"] == "train"]
    full_ds = EEGDataset(DATA_DIR, train_meta)
    labels = [d["label"] for d in train_meta]
    input_length = full_ds[0][0].shape[-1]

    # 2) 5-Fold split
    kf = KFold(n_splits=5, shuffle=True, random_state=42)
    fold_results = []

    for fold, (train_idx, val_idx) in enumerate(kf.split(full_ds, labels), 1):
        print(f"\n===== Fold {fold} =====")
        # wandb run 시작
        wandb.init(
            project="eeg-5fold-cv-6",
            name=f"fold_{fold}",
            config={
                "lr": LR,
                "weight_decay": WEIGHT_DECAY,
                "num_blocks": NUM_BLOCKS,
                "num_heads": NUM_HEADS,
                "num_segments": NUM_SEGMENTS,
                "batch_size": BATCH_SIZE,
                "max_epochs": MAX_EPOCHS
            }
        )

        train_loader = DataLoader(
            Subset(full_ds, train_idx),
            batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS
        )
        val_loader = DataLoader(
            Subset(full_ds, val_idx),
            batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
        )

        # 3) 모델 / 옵티마이저 / 손실함수
        model = EEGformer(
            in_channels=19,
            input_length=input_length,
            kernel_size=10,
            num_filters=NUM_FILTERS,
            num_heads=NUM_HEADS,
            num_blocks=NUM_BLOCKS,
            num_segments=NUM_SEGMENTS,
            num_classes=3
        ).to(DEVICE)

        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=LR,
            weight_decay=WEIGHT_DECAY
        )
        criterion = nn.CrossEntropyLoss()
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=FIXED_GAMMA,
            patience=FIXED_STEP_SIZE,
            min_lr=1e-6
        )

        best_train_loss = best_train_acc = None
        best_val_loss = float("inf")
        best_val_acc = 0

        # 4) Epoch 루프
        for epoch in range(1, MAX_EPOCHS + 1):
            t0 = time.time()
            # — train —
            model.train()
            tloss = tcorrect = ttotal = 0
            for X, y in train_loader:
                X, y = X.to(DEVICE), y.to(DEVICE)
                optimizer.zero_grad()
                logits = model(X)
                loss = criterion(logits, y)
                loss.backward()
                optimizer.step()
                tloss    += loss.item()
                tcorrect += (logits.argmax(1) == y).sum().item()
                ttotal   += y.size(0)
            train_loss = tloss / len(train_loader)
            train_acc  = tcorrect / ttotal

            # — validate —
            model.eval()
            vloss = vcorrect = vtotal = 0
            with torch.no_grad():
                for X, y in val_loader:
                    X, y = X.to(DEVICE), y.to(DEVICE)
                    logits = model(X)
                    loss   = criterion(logits, y)
                    vloss  += loss.item()
                    vcorrect += (logits.argmax(1) == y).sum().item()
                    vtotal   += y.size(0)
            val_loss = vloss / len(val_loader)
            val_acc  = vcorrect / vtotal
            elapsed  = time.time() - t0

            # — 터미널 출력 —
            print(
                f"Epoch {epoch:03d} | "
                f"train_loss={train_loss:.4f} acc={train_acc:.4f} | "
                f"val_loss={val_loss:.4f} acc={val_acc:.4f} | "
                f"time={elapsed:.1f}s"
            )

            # — wandb 로깅 —
            wandb.log({
                "epoch":               epoch,
                "train_loss":          train_loss,
                "train_accuracy":      train_acc,
                "validation_loss":     val_loss,
                "validation_accuracy": val_acc,
            }, step=epoch)

            # Scheduler
            scheduler.step(val_loss)

            # best 갱신
            if val_loss < best_val_loss:
                best_val_loss   = val_loss
                best_val_acc    = val_acc
                best_train_loss = train_loss
                best_train_acc  = train_acc

        # Fold 종료 시 summary 기록
        print(
            f"Fold {fold} best_train_loss={best_train_loss:.4f}, "
            f"best_train_acc={best_train_acc:.4f}, "
            f"best_val_loss={best_val_loss:.4f}, "
            f"best_val_acc={best_val_acc:.4f}"
        )
        wandb.summary["best_train_loss"]     = best_train_loss
        wandb.summary["best_train_accuracy"] = best_train_acc
        wandb.summary["best_val_loss"]       = best_val_loss
        wandb.summary["best_val_accuracy"]   = best_val_acc

        fold_results.append({
            "train_loss": best_train_loss,
            "train_acc":  best_train_acc,
            "val_loss":   best_val_loss,
            "val_acc":    best_val_acc
        })

        # cleanup
        wandb.finish()
        torch.cuda.empty_cache()
        gc.collect()

    # ─── 5-Fold Average Metrics 로깅 ─────────────────────────────
    avg = {k: sum(res[k] for res in fold_results) / len(fold_results)
           for k in fold_results[0]}

    # 별도 W&B run 으로 평균 지표 기록
    wandb.init(
        project="eeg-5fold-cv-6",
        name="fold_average",
        reinit=True,
        config={
            "lr": LR,
            "weight_decay": WEIGHT_DECAY,
            "num_blocks": NUM_BLOCKS,
            "num_heads": NUM_HEADS,
            "num_segments": NUM_SEGMENTS,
            "batch_size": BATCH_SIZE,
            "max_epochs": MAX_EPOCHS
        }
    )
    wandb.log({
        "avg_train_loss":      avg["train_loss"],
        "avg_train_accuracy":  avg["train_acc"],
        "avg_val_loss":        avg["val_loss"],
        "avg_val_accuracy":    avg["val_acc"],
    })
    wandb.summary.update({
        "avg_train_loss":      avg["train_loss"],
        "avg_train_accuracy":  avg["train_acc"],
        "avg_val_loss":        avg["val_loss"],
        "avg_val_accuracy":    avg["val_acc"],
    })
    wandb.finish()

    # ─── 5-Fold CV Summary 터미널 출력 ────────────────────────────
    print("\n===== 5-Fold CV Summary =====")
    for i, res in enumerate(fold_results, 1):
        print(
            f" Fold {i:2d}: "
            f"train_loss = {res['train_loss']:.4f}, "
            f"train_acc = {res['train_acc']:.4f}, "
            f"val_loss = {res['val_loss']:.4f}, "
            f"val_acc = {res['val_acc']:.4f}"
        )
    print(
        f"\n Average: "
        f"train_loss = {avg['train_loss']:.4f}, "
        f"train_acc = {avg['train_acc']:.4f}, "
        f"val_loss = {avg['val_loss']:.4f}, "
        f"val_acc = {avg['val_acc']:.4f}"
    )


if __name__ == "__main__":
    train_and_evaluate()



===== Fold 1 =====


0,1
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▅▅▆▆▆▆▆▇▇███
train_accuracy,▁▃▄▄▅▅▅▆▆▆▆▇▆▆▇▇▇▇▇▇████████████████████
train_loss,██▇▇▇▆▅▅▅▄▄▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation_accuracy,▁▄▆▆██▇▇█▇█▇▇▇▇▇▇▇▇▆▇▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆
validation_loss,▅▅▂▃▁▂▁▁▃▁▂▂▃▃▄▃▄▄▄▅▄▆▆▇▆▆▅▆▆▇▆▇▆▆▆▇▆▇▇█

0,1
epoch,88.0
train_accuracy,0.88583
train_loss,0.25898
validation_accuracy,0.66925
validation_loss,1.40101


Epoch 001 | train_loss=1.1807 acc=0.3557 | val_loss=1.1243 acc=0.2391 | time=23.8s
Epoch 002 | train_loss=1.1356 acc=0.3755 | val_loss=1.0979 acc=0.4006 | time=23.8s
Epoch 003 | train_loss=1.1200 acc=0.4027 | val_loss=1.0907 acc=0.4006 | time=24.1s
Epoch 004 | train_loss=1.1065 acc=0.3969 | val_loss=1.0937 acc=0.4006 | time=24.0s
Epoch 005 | train_loss=1.1034 acc=0.3981 | val_loss=1.0901 acc=0.4006 | time=23.8s
Epoch 006 | train_loss=1.0948 acc=0.3918 | val_loss=1.0897 acc=0.4006 | time=23.8s
Epoch 007 | train_loss=1.0820 acc=0.4210 | val_loss=1.0957 acc=0.4006 | time=23.8s
Epoch 008 | train_loss=1.0715 acc=0.4190 | val_loss=1.0869 acc=0.4006 | time=23.8s
Epoch 009 | train_loss=1.0759 acc=0.4295 | val_loss=1.0733 acc=0.4006 | time=23.9s
Epoch 010 | train_loss=1.0440 acc=0.4664 | val_loss=1.0426 acc=0.5559 | time=23.8s
Epoch 011 | train_loss=0.9854 acc=0.5472 | val_loss=0.9198 acc=0.6258 | time=24.1s
Epoch 012 | train_loss=0.9262 acc=0.5841 | val_loss=0.8927 acc=0.6180 | time=23.9s
Epoc

0,1
epoch,▁▁▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇████
train_accuracy,▁▁▂▂▄▄▅▅▅▆▆▆▇▇▇▇████████████████████████
train_loss,█████▆▅▄▄▄▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation_accuracy,▁▁▁▁▄▆▆▇▇▇▆██▇▆▇▆▇█▅▆▇▆▆▇▆▆▆▆▆▆▆▆▆▇▆▆▆▆▆
validation_loss,▄▄▄▄▄▃▃▁▂▃▃▂▁▃▆▃▇▆▇▇▆▇▇▆▇▇▇▇████▇▇▇▇▇█▇▇

0,1
best_train_accuracy,0.70019
best_train_loss,0.67081
best_val_accuracy,0.74224
best_val_loss,0.68714
epoch,100.0
train_accuracy,0.86058
train_loss,0.31192
validation_accuracy,0.66149
validation_loss,1.4391



===== Fold 2 =====


Epoch 001 | train_loss=1.1596 acc=0.3817 | val_loss=1.0851 acc=0.4255 | time=24.0s


KeyboardInterrupt: 

#### Num_Filters = 120 -> 60
#### N1=16 ->8, N2=32->16
- Due to Overfitting
- Decrease the N1 = 16 -> 8, N2 = 32 -> 16

In [None]:
import os
import json
import time
import gc

import torch
import torch.nn as nn
import wandb
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import KFold

from eeg_dataset import EEGDataset
from model_optimized_5 import EEGformer

# ─── Fixed Hyperparameters ─────────────────────────────────────
LR = 5e-4
WEIGHT_DECAY = 5e-2
NUM_FILTERS = 60
NUM_BLOCKS = 2
NUM_HEADS = 2
NUM_SEGMENTS = 5

# ─── Training configuration ───────────────────────────────────
MAX_EPOCHS = 100
BATCH_SIZE = 32
NUM_WORKERS = max(1, min(4, os.cpu_count() - 1))

# ─── Data paths & device ──────────────────────────────────────
DATA_DIR = '/content/drive/MyDrive/2025_Lab_Research/model-data'
LABEL_FILE = "labels.json"
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ─── Scheduler parameters ─────────────────────────────────────
FIXED_STEP_SIZE = 5
FIXED_GAMMA = 0.7

def train_and_evaluate():
    # 1) 데이터 로드
    with open(os.path.join(DATA_DIR, LABEL_FILE), "r") as f:
        all_meta = json.load(f)
    train_meta = [d for d in all_meta if d["type"] == "train"]
    full_ds = EEGDataset(DATA_DIR, train_meta)
    labels = [d["label"] for d in train_meta]
    input_length = full_ds[0][0].shape[-1]

    # 2) 5-Fold split
    kf = KFold(n_splits=5, shuffle=True, random_state=42)
    fold_results = []

    for fold, (train_idx, val_idx) in enumerate(kf.split(full_ds, labels), 1):
        print(f"\n===== Fold {fold} =====")
        # wandb run 시작
        wandb.init(
            project="eeg-5fold-cv-7",
            name=f"fold_{fold}",
            config={
                "lr": LR,
                "weight_decay": WEIGHT_DECAY,
                "num_blocks": NUM_BLOCKS,
                "num_heads": NUM_HEADS,
                "num_segments": NUM_SEGMENTS,
                "batch_size": BATCH_SIZE,
                "max_epochs": MAX_EPOCHS
            }
        )

        train_loader = DataLoader(
            Subset(full_ds, train_idx),
            batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS
        )
        val_loader = DataLoader(
            Subset(full_ds, val_idx),
            batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
        )

        # 3) 모델 / 옵티마이저 / 손실함수
        model = EEGformer(
            in_channels=19,
            input_length=input_length,
            kernel_size=10,
            num_filters=NUM_FILTERS,
            num_heads=NUM_HEADS,
            num_blocks=NUM_BLOCKS,
            num_segments=NUM_SEGMENTS,
            num_classes=3
        ).to(DEVICE)

        base_lr = LR
        base_wd = WEIGHT_DECAY
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=base_lr,
            weight_decay=base_wd
        )
        criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=FIXED_GAMMA,
            patience=FIXED_STEP_SIZE,
            min_lr=1e-6
        )

        best_train_loss = best_train_acc = None
        best_val_loss = float("inf")
        best_val_acc = 0

        # 4) Epoch 루프
        for epoch in range(1, MAX_EPOCHS + 1):
            t0 = time.time()
            # — train —
            model.train()
            tloss = tcorrect = ttotal = 0
            for X, y in train_loader:
                X, y = X.to(DEVICE), y.to(DEVICE)
                optimizer.zero_grad()
                logits = model(X)
                loss = criterion(logits, y)
                loss.backward()
                optimizer.step()
                tloss    += loss.item()
                tcorrect += (logits.argmax(1) == y).sum().item()
                ttotal   += y.size(0)
            train_loss = tloss / len(train_loader)
            train_acc  = tcorrect / ttotal

            # — validate —
            model.eval()
            vloss = vcorrect = vtotal = 0
            with torch.no_grad():
                for X, y in val_loader:
                    X, y = X.to(DEVICE), y.to(DEVICE)
                    logits = model(X)
                    loss   = criterion(logits, y)
                    vloss  += loss.item()
                    vcorrect += (logits.argmax(1) == y).sum().item()
                    vtotal   += y.size(0)
            val_loss = vloss / len(val_loader)
            val_acc  = vcorrect / vtotal
            elapsed  = time.time() - t0

            # — 터미널 출력 —
            print(
                f"Epoch {epoch:03d} | "
                f"train_loss={train_loss:.4f} acc={train_acc:.4f} | "
                f"val_loss={val_loss:.4f} acc={val_acc:.4f} | "
                f"time={elapsed:.1f}s"
            )

            # — wandb 로깅 —
            wandb.log({
                "epoch":               epoch,
                "train_loss":          train_loss,
                "train_accuracy":      train_acc,
                "validation_loss":     val_loss,
                "validation_accuracy": val_acc,
            }, step=epoch)

            # Scheduler
            scheduler.step(val_loss)

            # Weight Decay Scheduling
            current_lr = optimizer.param_groups[0]['lr']
            new_wd = base_wd * (current_lr / LR)
            for g in optimizer.param_groups:
                g['weight_decay'] = new_wd

            # best 갱신
            if val_loss < best_val_loss:
                best_val_loss   = val_loss
                best_val_acc    = val_acc
                best_train_loss = train_loss
                best_train_acc  = train_acc

        # Fold 종료 시 summary 기록
        print(
            f"Fold {fold} best_train_loss={best_train_loss:.4f}, "
            f"best_train_acc={best_train_acc:.4f}, "
            f"best_val_loss={best_val_loss:.4f}, "
            f"best_val_acc={best_val_acc:.4f}"
        )
        wandb.summary["best_train_loss"]     = best_train_loss
        wandb.summary["best_train_accuracy"] = best_train_acc
        wandb.summary["best_val_loss"]       = best_val_loss
        wandb.summary["best_val_accuracy"]   = best_val_acc

        fold_results.append({
            "train_loss": best_train_loss,
            "train_acc":  best_train_acc,
            "val_loss":   best_val_loss,
            "val_acc":    best_val_acc
        })

        # cleanup
        wandb.finish()
        torch.cuda.empty_cache()
        gc.collect()

    # ─── 5-Fold Average Metrics 로깅 ─────────────────────────────
    avg = {k: sum(res[k] for res in fold_results) / len(fold_results)
           for k in fold_results[0]}

    # 별도 W&B run 으로 평균 지표 기록
    wandb.init(
        project="eeg-5fold-cv-7",
        name="fold_average",
        reinit=True,
        config={
            "lr": LR,
            "weight_decay": WEIGHT_DECAY,
            "num_blocks": NUM_BLOCKS,
            "num_heads": NUM_HEADS,
            "num_segments": NUM_SEGMENTS,
            "batch_size": BATCH_SIZE,
            "max_epochs": MAX_EPOCHS
        }
    )
    wandb.log({
        "avg_train_loss":      avg["train_loss"],
        "avg_train_accuracy":  avg["train_acc"],
        "avg_val_loss":        avg["val_loss"],
        "avg_val_accuracy":    avg["val_acc"],
    })
    wandb.summary.update({
        "avg_train_loss":      avg["train_loss"],
        "avg_train_accuracy":  avg["train_acc"],
        "avg_val_loss":        avg["val_loss"],
        "avg_val_accuracy":    avg["val_acc"],
    })
    wandb.finish()

    # ─── 5-Fold CV Summary 터미널 출력 ────────────────────────────
    print("\n===== 5-Fold CV Summary =====")
    for i, res in enumerate(fold_results, 1):
        print(
            f" Fold {i:2d}: "
            f"train_loss = {res['train_loss']:.4f}, "
            f"train_acc = {res['train_acc']:.4f}, "
            f"val_loss = {res['val_loss']:.4f}, "
            f"val_acc = {res['val_acc']:.4f}"
        )
    print(
        f"\n Average: "
        f"train_loss = {avg['train_loss']:.4f}, "
        f"train_acc = {avg['train_acc']:.4f}, "
        f"val_loss = {avg['val_loss']:.4f}, "
        f"val_acc = {avg['val_acc']:.4f}"
    )


if __name__ == "__main__":
    train_and_evaluate()


Now using CUDA device 0
Enabling CUDA with 39.14 GiB available memory

===== Fold 1 =====


Epoch 001 | train_loss=1.1397 acc=0.3767 | val_loss=1.0947 acc=0.4006 | time=16.5s
Epoch 002 | train_loss=1.1166 acc=0.3817 | val_loss=1.0917 acc=0.3587 | time=16.4s
Epoch 003 | train_loss=1.1068 acc=0.3969 | val_loss=1.0911 acc=0.4006 | time=16.8s
Epoch 004 | train_loss=1.0894 acc=0.4054 | val_loss=1.0915 acc=0.4006 | time=16.5s
Epoch 005 | train_loss=1.0901 acc=0.4050 | val_loss=1.0934 acc=0.4006 | time=16.4s
Epoch 006 | train_loss=1.0871 acc=0.4198 | val_loss=1.0925 acc=0.4006 | time=16.7s
Epoch 007 | train_loss=1.0848 acc=0.4132 | val_loss=1.0887 acc=0.4006 | time=16.6s
Epoch 008 | train_loss=1.0714 acc=0.4353 | val_loss=1.0761 acc=0.4006 | time=16.4s
Epoch 009 | train_loss=1.0620 acc=0.4781 | val_loss=1.0360 acc=0.5854 | time=16.5s
Epoch 010 | train_loss=1.0379 acc=0.5002 | val_loss=0.9679 acc=0.6149 | time=16.9s
Epoch 011 | train_loss=0.9817 acc=0.5569 | val_loss=0.9548 acc=0.6444 | time=16.3s
Epoch 012 | train_loss=0.9680 acc=0.5709 | val_loss=0.9775 acc=0.6413 | time=16.5s
Epoc

KeyboardInterrupt: 

#### Block = 2, Head = 2
- Early Stopping
- OneCycleLR
- Prevent overfitting: 5e-2 -> 1e-1


In [None]:
import os
import json
import time
import gc

import torch
import torch.nn as nn
import wandb
from wandb import Settings
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import KFold

from eeg_dataset import EEGDataset
from model_optimized_5 import EEGformer

# ─── Fixed Hyperparameters ─────────────────────────────────────
LR = 1e-3
WEIGHT_DECAY = 1e-1
NUM_FILTERS = 60
NUM_BLOCKS = 2
NUM_HEADS = 2
NUM_SEGMENTS = 5

# ─── Training configuration ───────────────────────────────────
MAX_EPOCHS = 100
BATCH_SIZE = 32
PATIENCE = 10                   # ← MOD: EarlyStopping patience
NUM_WORKERS = max(1, min(4, os.cpu_count() - 1))

# ─── Data paths & device ──────────────────────────────────────
DATA_DIR = '/content/drive/MyDrive/2025_Lab_Research/model-data'
LABEL_FILE = "labels.json"
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def train_and_evaluate():
    # 1) 데이터 로드
    with open(os.path.join(DATA_DIR, LABEL_FILE), "r") as f:
        all_meta = json.load(f)
    train_meta = [d for d in all_meta if d["type"] == "train"]
    full_ds = EEGDataset(DATA_DIR, train_meta)
    labels = [d["label"] for d in train_meta]
    input_length = full_ds[0][0].shape[-1]

    # 2) 5-Fold split
    kf = KFold(n_splits=5, shuffle=True, random_state=42)
    fold_results = []

    for fold, (train_idx, val_idx) in enumerate(kf.split(full_ds, labels), 1):
        print(f"\n===== Fold {fold} =====")
        wandb.init(
            project="eeg-5fold-cv-12",
            name=f"fold_{fold}",
            config={
                "lr": LR,
                "weight_decay": WEIGHT_DECAY,
                "num_blocks": NUM_BLOCKS,
                "num_heads": NUM_HEADS,
                "num_segments": NUM_SEGMENTS,
                "batch_size": BATCH_SIZE,
                "max_epochs": MAX_EPOCHS
            },
            settings=Settings(init_timeout=120)
        )

        train_loader = DataLoader(
            Subset(full_ds, train_idx),
            batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS
        )
        val_loader = DataLoader(
            Subset(full_ds, val_idx),
            batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
        )

        # 3) 모델 생성
        model = EEGformer(
            in_channels=19,
            input_length=input_length,
            kernel_size=10,
            num_filters=NUM_FILTERS,
            num_heads=NUM_HEADS,
            num_blocks=NUM_BLOCKS,
            num_segments=NUM_SEGMENTS,
            num_classes=3
        ).to(DEVICE)

        # ← MOD: base_lr/base_wd 저장
        base_lr = LR
        base_wd = WEIGHT_DECAY

        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=base_lr,
            weight_decay=base_wd
        )

        criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

        total_steps = MAX_EPOCHS * len(train_loader)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=base_lr,
            total_steps=total_steps,
            pct_start=0.2,
            anneal_strategy='cos',
            div_factor=10,
            final_div_factor=100
        )

        # ← MOD: EarlyStopping 변수 초기화
        best_train_loss = best_train_acc = None
        best_val_loss = float("inf")
        best_val_acc = 0
        no_improve = 0

        # 4) Epoch 루프
        for epoch in range(1, MAX_EPOCHS + 1):
            t0 = time.time()

            # — train —
            model.train()
            tloss = tcorrect = ttotal = 0
            for X, y in train_loader:
                X, y = X.to(DEVICE), y.to(DEVICE)
                optimizer.zero_grad()
                logits = model(X)
                loss = criterion(logits, y)
                loss.backward()
                optimizer.step()

                # ← MOD: OneCycleLR 배치 단위 step
                scheduler.step()

                # ← MOD: lr 변화에 맞춰 weight_decay 조정
                cur_lr = optimizer.param_groups[0]['lr']
                new_wd = base_wd * (cur_lr / base_lr)
                for g in optimizer.param_groups:
                    g['weight_decay'] = new_wd

                tloss    += loss.item()
                tcorrect += (logits.argmax(1) == y).sum().item()
                ttotal   += y.size(0)

            train_loss = tloss / len(train_loader)
            train_acc  = tcorrect / ttotal

            # — validate —
            model.eval()
            vloss = vcorrect = vtotal = 0
            with torch.no_grad():
                for X, y in val_loader:
                    X, y = X.to(DEVICE), y.to(DEVICE)
                    logits = model(X)
                    loss   = criterion(logits, y)
                    vloss  += loss.item()
                    vcorrect += (logits.argmax(1) == y).sum().item()
                    vtotal   += y.size(0)
            val_loss = vloss / len(val_loader)
            val_acc  = vcorrect / vtotal
            elapsed  = time.time() - t0

            print(
                f"Epoch {epoch:03d} | "
                f"train_loss={train_loss:.4f} acc={train_acc:.4f} | "
                f"val_loss={val_loss:.4f} acc={val_acc:.4f} | "
                f"time={elapsed:.1f}s"
            )

            wandb.log({
                "epoch":               epoch,
                "train_loss":          train_loss,
                "train_accuracy":      train_acc,
                "validation_loss":     val_loss,
                "validation_accuracy": val_acc,
                "lr":                  optimizer.param_groups[0]['lr']
            }, step=epoch)

            # ← MOD: EarlyStopping 로직
            if val_loss < best_val_loss:
                best_val_loss   = val_loss
                best_train_loss = train_loss
                best_train_acc  = train_acc
                best_val_acc    = val_acc
                no_improve      = 0
            else:
                no_improve += 1
                if no_improve >= PATIENCE:
                    print(f"Early stopping at epoch {epoch}")
                    break

        # Fold 결과 기록
        print(
            f"Fold {fold} best_train_loss={best_train_loss:.4f}, "
            f"best_train_acc={best_train_acc:.4f}, "
            f"best_val_loss={best_val_loss:.4f}, "
            f"best_val_acc={best_val_acc:.4f}"
        )
        wandb.summary["best_train_loss"]     = best_train_loss
        wandb.summary["best_train_accuracy"] = best_train_acc
        wandb.summary["best_val_loss"]       = best_val_loss
        wandb.summary["best_val_accuracy"]   = best_val_acc

        fold_results.append({
            "train_loss": best_train_loss,
            "train_acc":  best_train_acc,
            "val_loss":   best_val_loss,
            "val_acc":    best_val_acc
        })

        wandb.finish()
        torch.cuda.empty_cache()
        gc.collect()

    # ─── 5-Fold Average Metrics 기록 ─────────────────────────────
    avg = {k: sum(res[k] for res in fold_results) / len(fold_results)
           for k in fold_results[0]}

    wandb.init(
        project="eeg-5fold-cv-12",
        name="fold_average",
        reinit=True,
        config={
            "lr": LR,
            "weight_decay": WEIGHT_DECAY,
            "num_blocks": NUM_BLOCKS,
            "num_heads": NUM_HEADS,
            "num_segments": NUM_SEGMENTS,
            "batch_size": BATCH_SIZE,
            "max_epochs": MAX_EPOCHS
        },
        settings=Settings(init_timeout=120)
    )
    wandb.log({
        "avg_train_loss":      avg["train_loss"],
        "avg_train_accuracy":  avg["train_acc"],
        "avg_val_loss":        avg["val_loss"],
        "avg_val_accuracy":    avg["val_acc"],
    })
    wandb.summary.update({
        "avg_train_loss":      avg["train_loss"],
        "avg_train_accuracy":  avg["train_acc"],
        "avg_val_loss":        avg["val_loss"],
        "avg_val_accuracy":    avg["val_acc"],
    })
    wandb.finish()

if __name__ == "__main__":
    train_and_evaluate()


Now using CUDA device 0
Enabling CUDA with 39.14 GiB available memory

===== Fold 1 =====


Epoch 001 | train_loss=1.1429 acc=0.3817 | val_loss=1.0908 acc=0.4006 | time=16.5s
Epoch 002 | train_loss=1.1409 acc=0.3635 | val_loss=1.0923 acc=0.4006 | time=16.5s
Epoch 003 | train_loss=1.1410 acc=0.3631 | val_loss=1.0905 acc=0.4006 | time=16.5s
Epoch 004 | train_loss=1.1272 acc=0.3650 | val_loss=1.0909 acc=0.4006 | time=16.7s
Epoch 005 | train_loss=1.1080 acc=0.3988 | val_loss=1.0937 acc=0.4006 | time=16.5s
Epoch 006 | train_loss=1.1047 acc=0.3926 | val_loss=1.0899 acc=0.4006 | time=16.4s
Epoch 007 | train_loss=1.0957 acc=0.3969 | val_loss=1.0729 acc=0.4488 | time=16.5s
Epoch 008 | train_loss=1.0476 acc=0.4885 | val_loss=0.9947 acc=0.5947 | time=16.5s
Epoch 009 | train_loss=1.0207 acc=0.5293 | val_loss=0.9561 acc=0.6366 | time=16.2s
Epoch 010 | train_loss=0.9783 acc=0.5798 | val_loss=0.9371 acc=0.6102 | time=16.4s
Epoch 011 | train_loss=0.9537 acc=0.6008 | val_loss=0.8983 acc=0.6460 | time=16.5s
Epoch 012 | train_loss=0.9410 acc=0.5872 | val_loss=0.9097 acc=0.6149 | time=16.7s
Epoc

0,1
epoch,▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇███
lr,▁▁▁▂▂▂▃▃▄▄▅▆▆▇▇▇█████████████████▇▇▇
train_accuracy,▁▁▁▁▂▂▂▃▄▅▅▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇█████
train_loss,████▇▇▇▆▆▅▅▅▅▄▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁
validation_accuracy,▁▁▁▁▁▁▂▅▆▆▆▆▆▆▇▅▆█▇▇▇▇▇▅▅█▇▇▇▇▇▇▇▅▇▇
validation_loss,███████▆▅▄▄▄▃▃▂▅▃▂▃▃▂▂▃▆▆▁▂▂▂▂▃▃▅▆▃▄

0,1
best_train_accuracy,0.70097
best_train_loss,0.78334
best_val_accuracy,0.71584
best_val_loss,0.78803
epoch,36.0
lr,0.0009
train_accuracy,0.7701
train_loss,0.69876
validation_accuracy,0.64907
validation_loss,0.92513



===== Fold 2 =====


Epoch 001 | train_loss=1.1798 acc=0.3841 | val_loss=1.0839 acc=0.4255 | time=16.4s
Epoch 002 | train_loss=1.1702 acc=0.3709 | val_loss=1.0856 acc=0.4255 | time=16.5s
Epoch 003 | train_loss=1.1351 acc=0.3829 | val_loss=1.0856 acc=0.4255 | time=16.6s


[34m[1mwandb[0m: 500 encountered ({"error":"context deadline exceeded"}), retrying request
[34m[1mwandb[0m: 500 encountered ({"error":"context deadline exceeded"}), retrying request


Epoch 004 | train_loss=1.1425 acc=0.3596 | val_loss=1.0885 acc=0.4255 | time=16.7s
Epoch 005 | train_loss=1.1245 acc=0.3794 | val_loss=1.0833 acc=0.4255 | time=16.6s
Epoch 006 | train_loss=1.1170 acc=0.3837 | val_loss=1.0816 acc=0.4255 | time=16.4s
Epoch 007 | train_loss=1.1127 acc=0.3841 | val_loss=1.0812 acc=0.4255 | time=16.7s
Epoch 008 | train_loss=1.0991 acc=0.3907 | val_loss=1.0837 acc=0.4255 | time=16.6s
Epoch 009 | train_loss=1.0931 acc=0.4085 | val_loss=1.0789 acc=0.4255 | time=16.4s
Epoch 010 | train_loss=1.0941 acc=0.4058 | val_loss=1.0746 acc=0.4255 | time=16.6s
Epoch 011 | train_loss=1.0832 acc=0.4179 | val_loss=1.0619 acc=0.4845 | time=16.6s
Epoch 012 | train_loss=1.0546 acc=0.4854 | val_loss=0.9922 acc=0.5947 | time=16.4s
Epoch 013 | train_loss=1.0125 acc=0.5383 | val_loss=0.9416 acc=0.6506 | time=16.3s
Epoch 014 | train_loss=0.9695 acc=0.5849 | val_loss=0.8911 acc=0.6661 | time=16.6s
Epoch 015 | train_loss=0.9534 acc=0.5849 | val_loss=0.9271 acc=0.6366 | time=16.5s
Epoc

0,1
epoch,▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇███
lr,▁▁▁▂▂▂▃▃▄▄▅▆▆▇▇▇█████████████████▇▇▇
train_accuracy,▁▁▁▁▁▁▁▂▂▂▂▃▄▅▅▅▅▆▆▆▆▆▆▇▇▆▇▇▇▇▇▇█▇██
train_loss,██▇▇▇▇▇▇▇▇▆▆▅▅▄▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▁▂▁▁
validation_accuracy,▁▁▁▁▁▁▁▁▁▁▂▅▆▆▆▆▆▇▆▅▇▇▆▇▅▇▇█▇▆▆▇█▅▇█
validation_loss,██████████▇▆▄▃▄▄▃▂▂▄▃▂▂▁▅▁▃▁▂▄▃▁▂▃▂▁

0,1
best_train_accuracy,0.6567
best_train_loss,0.83622
best_val_accuracy,0.70807
best_val_loss,0.79969
epoch,36.0
lr,0.0009
train_accuracy,0.73786
train_loss,0.73996
validation_accuracy,0.73447
validation_loss,0.80906



===== Fold 3 =====


Epoch 001 | train_loss=1.1651 acc=0.3511 | val_loss=1.0836 acc=0.4410 | time=16.5s
Epoch 002 | train_loss=1.1422 acc=0.3654 | val_loss=1.0819 acc=0.4410 | time=16.4s
Epoch 003 | train_loss=1.1404 acc=0.3596 | val_loss=1.0818 acc=0.4410 | time=16.4s
Epoch 004 | train_loss=1.1220 acc=0.3833 | val_loss=1.0823 acc=0.4410 | time=16.3s
Epoch 005 | train_loss=1.1190 acc=0.3724 | val_loss=1.0826 acc=0.4410 | time=16.5s
Epoch 006 | train_loss=1.1043 acc=0.4101 | val_loss=1.0820 acc=0.4410 | time=16.5s
Epoch 007 | train_loss=1.1056 acc=0.3957 | val_loss=1.0793 acc=0.4472 | time=16.6s
Epoch 008 | train_loss=1.0922 acc=0.3988 | val_loss=1.0585 acc=0.5140 | time=16.3s
Epoch 009 | train_loss=1.0558 acc=0.4676 | val_loss=1.0001 acc=0.6180 | time=16.8s
Epoch 010 | train_loss=1.0307 acc=0.5336 | val_loss=0.9371 acc=0.6382 | time=17.4s
Epoch 011 | train_loss=0.9688 acc=0.5751 | val_loss=0.9394 acc=0.6242 | time=17.0s
Epoch 012 | train_loss=0.9511 acc=0.5946 | val_loss=0.9211 acc=0.6475 | time=16.8s
Epoc

0,1
epoch,▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
lr,▁▁▁▂▂▂▃▃▄▄▅▆▆▇▇▇█████████████████▇▇▇▇▇
train_accuracy,▁▁▁▂▁▂▂▂▃▄▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇███████
train_loss,███▇▇▇▇▇▆▆▅▅▄▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁
validation_accuracy,▁▁▁▁▁▁▁▃▅▆▅▆▄▆▇▇▇▇▇██▆█▅▇█▅█▄▇█▆▇▇▇▆▆▆
validation_loss,███████▇▆▅▅▄▆▄▃▄▃▃▂▃▂▄▃▄▃▂▄▁▄▃▂▄▄▂▃▅▄▄

0,1
best_train_accuracy,0.70524
best_train_loss,0.76497
best_val_accuracy,0.72205
best_val_loss,0.77376
epoch,38.0
lr,0.00088
train_accuracy,0.76427
train_loss,0.70263
validation_accuracy,0.65683
validation_loss,0.89702



===== Fold 4 =====


Epoch 001 | train_loss=1.1785 acc=0.3647 | val_loss=1.0804 acc=0.4534 | time=16.6s
Epoch 002 | train_loss=1.1375 acc=0.3814 | val_loss=1.0838 acc=0.4534 | time=16.5s
Epoch 003 | train_loss=1.1326 acc=0.3860 | val_loss=1.0833 acc=0.4534 | time=16.6s
Epoch 004 | train_loss=1.1189 acc=0.3926 | val_loss=1.0808 acc=0.4534 | time=16.5s
Epoch 005 | train_loss=1.1223 acc=0.3790 | val_loss=1.0792 acc=0.4534 | time=16.5s
Epoch 006 | train_loss=1.1168 acc=0.3825 | val_loss=1.0799 acc=0.4534 | time=16.7s
Epoch 007 | train_loss=1.1067 acc=0.3899 | val_loss=1.0800 acc=0.4534 | time=16.5s
Epoch 008 | train_loss=1.0985 acc=0.4043 | val_loss=1.0756 acc=0.4565 | time=16.5s
Epoch 009 | train_loss=1.0846 acc=0.4202 | val_loss=1.0526 acc=0.5109 | time=16.9s
Epoch 010 | train_loss=1.0485 acc=0.4792 | val_loss=0.9925 acc=0.6320 | time=16.6s
Epoch 011 | train_loss=1.0024 acc=0.5480 | val_loss=0.9689 acc=0.5854 | time=16.2s
Epoch 012 | train_loss=0.9686 acc=0.5763 | val_loss=0.9151 acc=0.6522 | time=16.3s
Epoc

0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██
lr,▁▁▁▂▂▂▃▃▄▄▅▆▆▇▇▇█████████████████▇▇▇▇▇▇▇
train_accuracy,▁▁▁▁▁▁▁▂▂▃▄▅▅▅▅▅▅▆▅▆▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇█▇██
train_loss,█▇▇▇▇▇▇▇▇▆▅▅▅▅▄▄▄▄▄▃▃▃▃▃▃▃▃▂▃▂▂▂▂▂▂▂▁▁▁▁
validation_accuracy,▁▁▁▁▁▁▁▁▂▅▄▆▆▆▆▆▇▆▇▇▄▇▆█▇█▇▇█▆█▇▇▇▆█▇▅▇▇
validation_loss,████████▇▆▆▄▄▄▄▄▃▄▃▂▅▃▃▁▂▁▃▁▁▃▁▂▃▂▂▃▄▄▂▃

0,1
best_train_accuracy,0.70369
best_train_loss,0.77415
best_val_accuracy,0.75
best_val_loss,0.74715
epoch,41.0
lr,0.00084
train_accuracy,0.77476
train_loss,0.69183
validation_accuracy,0.71118
validation_loss,0.85523



===== Fold 5 =====


Epoch 001 | train_loss=1.2039 acc=0.3571 | val_loss=1.1145 acc=0.3359 | time=16.8s
Epoch 002 | train_loss=1.1825 acc=0.3641 | val_loss=1.1155 acc=0.3359 | time=16.5s
Epoch 003 | train_loss=1.1675 acc=0.3769 | val_loss=1.1068 acc=0.3359 | time=17.1s
Epoch 004 | train_loss=1.1506 acc=0.3676 | val_loss=1.0913 acc=0.4355 | time=16.6s
Epoch 005 | train_loss=1.1326 acc=0.3824 | val_loss=1.0854 acc=0.4355 | time=16.7s
Epoch 006 | train_loss=1.1221 acc=0.3750 | val_loss=1.0863 acc=0.4355 | time=16.7s
Epoch 007 | train_loss=1.1107 acc=0.3859 | val_loss=1.0859 acc=0.4635 | time=16.5s
Epoch 008 | train_loss=1.0883 acc=0.4173 | val_loss=1.0353 acc=0.5739 | time=16.5s
Epoch 009 | train_loss=1.0361 acc=0.4973 | val_loss=0.9642 acc=0.6112 | time=16.5s
Epoch 010 | train_loss=1.0040 acc=0.5454 | val_loss=0.9724 acc=0.5910 | time=16.7s
Epoch 011 | train_loss=0.9771 acc=0.5675 | val_loss=0.9276 acc=0.6423 | time=16.5s
Epoch 012 | train_loss=0.9606 acc=0.5776 | val_loss=0.9373 acc=0.6065 | time=16.3s
Epoc

0,1
epoch,▁▁▁▂▂▂▂▂▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▇▇▇▇▇███
lr,▁▁▁▂▂▂▃▃▄▄▅▆▆▇▇▇█████████████████▇
train_accuracy,▁▁▁▁▁▁▂▂▄▄▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇██▇███
train_loss,██▇▇▇▇▇▆▅▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁
validation_accuracy,▁▁▁▃▃▃▃▅▆▆▆▆▇▆▄▇█▇▇▆▇▇▇██▅▇▇▇▃█▇▇▇
validation_loss,▇▇▇▇▇▇▇▆▄▄▃▄▃▄▆▂▂▂▁▂▁▂▂▁▁▄▂▁▁█▂▃▂▂

0,1
best_train_accuracy,0.6778
best_train_loss,0.81969
best_val_accuracy,0.73095
best_val_loss,0.80494
epoch,34.0
lr,0.00093
train_accuracy,0.73719
train_loss,0.73504
validation_accuracy,0.67341
validation_loss,0.85218




0,1
avg_train_accuracy,▁
avg_train_loss,▁
avg_val_accuracy,▁
avg_val_loss,▁

0,1
avg_train_accuracy,0.68888
avg_train_loss,0.79568
avg_val_accuracy,0.72538
avg_val_loss,0.78271


### Block = 1, Head = 3
- Regularization to be strong
- Plateu detect ->  Weight Decay 1.5 times strong

In [None]:
import os
import json
import time
import gc

import torch
import torch.nn as nn
import wandb
from wandb import Settings
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import KFold

from eeg_dataset import EEGDataset
from model_optimized_5 import EEGformer

# ─── Fixed Hyperparameters ─────────────────────────────────────
LR = 1e-3
WEIGHT_DECAY = 5e-2
NUM_FILTERS   = 60
NUM_BLOCKS    = 1
NUM_HEADS     = 3
NUM_SEGMENTS  = 5

# ─── Training configuration ───────────────────────────────────
MAX_EPOCHS   = 100
BATCH_SIZE   = 32
ES_PATIENCE  = 15        # ← EarlyStopping 전용 patience
WD_PATIENCE  = 5         # ← WD 강화 전용 patience
MAX_WD       = 0.5       # ← WD 상한값
NUM_WORKERS  = max(1, min(4, os.cpu_count() - 1))

# ─── Data paths & device ──────────────────────────────────────
DATA_DIR   = '/content/drive/MyDrive/2025_Lab_Research/model-data'
LABEL_FILE = "labels.json"
DEVICE     = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def train_and_evaluate():
    # 1) 데이터 로드
    with open(os.path.join(DATA_DIR, LABEL_FILE), "r") as f:
        all_meta = json.load(f)
    train_meta = [d for d in all_meta if d["type"] == "train"]
    full_ds     = EEGDataset(DATA_DIR, train_meta)
    labels      = [d["label"] for d in train_meta]
    input_length= full_ds[0][0].shape[-1]

    # 2) 5-Fold split
    kf = KFold(n_splits=5, shuffle=True, random_state=42)
    fold_results = []

    for fold, (train_idx, val_idx) in enumerate(kf.split(full_ds, labels), 1):
        print(f"\n===== Fold {fold} =====")
        wandb.init(
            project="eeg-5fold-cv-14",
            name=f"fold_{fold}",
            config={
                "lr": LR,
                "weight_decay": WEIGHT_DECAY,
                "num_blocks": NUM_BLOCKS,
                "num_heads": NUM_HEADS,
                "num_segments": NUM_SEGMENTS,
                "batch_size": BATCH_SIZE,
                "max_epochs": MAX_EPOCHS,
                "es_patience": ES_PATIENCE,
                "wd_patience": WD_PATIENCE,
                "max_wd": MAX_WD
            },
            settings=Settings(init_timeout=120)
        )

        train_loader = DataLoader(
            Subset(full_ds, train_idx),
            batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS
        )
        val_loader = DataLoader(
            Subset(full_ds, val_idx),
            batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
        )

        # 3) 모델 생성
        model = EEGformer(
            in_channels=19,
            input_length=input_length,
            kernel_size=10,
            num_filters=NUM_FILTERS,
            num_heads=NUM_HEADS,
            num_blocks=NUM_BLOCKS,
            num_segments=NUM_SEGMENTS,
            num_classes=3
        ).to(DEVICE)

        # ← MOD: base_lr/base_wd 저장 (base_wd를 plateau 때 갱신)
        base_lr = LR
        base_wd = WEIGHT_DECAY

        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=base_lr,
            weight_decay=base_wd
        )

        criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

        total_steps = MAX_EPOCHS * len(train_loader)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=base_lr,
            total_steps=total_steps,
            pct_start=0.2,
            anneal_strategy='cos',
            div_factor=10,
            final_div_factor=100
        )

        # ← MOD: EarlyStopping & Plateau counters
        best_train_loss = best_train_acc = None
        best_val_loss = float("inf")
        best_val_acc = 0
        es_count = 0       # EarlyStopping counter
        wd_count = 0       # WD-plateau counter

        # 4) Epoch 루프
        for epoch in range(1, MAX_EPOCHS + 1):
            t0 = time.time()

            # — train —
            model.train()
            tloss = tcorrect = ttotal = 0
            for X, y in train_loader:
                X, y = X.to(DEVICE), y.to(DEVICE)
                optimizer.zero_grad()
                logits = model(X)
                loss = criterion(logits, y)
                loss.backward()
                optimizer.step()

                # ← MOD: OneCycleLR 배치 단위 step
                scheduler.step()

                # ← MOD: lr 변화에 맞춰 weight_decay 재계산
                cur_lr = optimizer.param_groups[0]['lr']
                new_wd = base_wd * (cur_lr / base_lr)
                for g in optimizer.param_groups:
                    g['weight_decay'] = new_wd

                tloss    += loss.item()
                tcorrect += (logits.argmax(1) == y).sum().item()
                ttotal   += y.size(0)

            train_loss = tloss / len(train_loader)
            train_acc  = tcorrect / ttotal

            # — validate —
            model.eval()
            vloss = vcorrect = vtotal = 0
            with torch.no_grad():
                for X, y in val_loader:
                    X, y = X.to(DEVICE), y.to(DEVICE)
                    logits = model(X)
                    loss   = criterion(logits, y)
                    vloss  += loss.item()
                    vcorrect += (logits.argmax(1) == y).sum().item()
                    vtotal   += y.size(0)
            val_loss = vloss / len(val_loader)
            val_acc  = vcorrect / vtotal
            elapsed  = time.time() - t0

            print(
                f"Epoch {epoch:03d} | "
                f"train_loss={train_loss:.4f} acc={train_acc:.4f} | "
                f"val_loss={val_loss:.4f} acc={val_acc:.4f} | "
                f"time={elapsed:.1f}s"
            )

            # ← MOD: weight_decay & lr도 함께 로깅
            wandb.log({
                "epoch":               epoch,
                "train_loss":          train_loss,
                "train_accuracy":      train_acc,
                "validation_loss":     val_loss,
                "validation_accuracy": val_acc,
                "lr":                  cur_lr,
                "weight_decay":        new_wd
            }, step=epoch)

            # ← MOD: EarlyStopping & WD 스케줄링 분리
            if val_loss < best_val_loss:
                best_val_loss   = val_loss
                best_train_loss = train_loss
                best_train_acc  = train_acc
                best_val_acc    = val_acc
                es_count = wd_count = 0
            else:
                es_count += 1
                wd_count += 1

                # WD_PATIENCE마다 base_wd 강화 + 즉시 로그
                if wd_count >= WD_PATIENCE:
                    base_wd = min(base_wd * 1.5, MAX_WD)
                    for g in optimizer.param_groups:
                        # 재계산 후 적용
                        g['weight_decay'] = base_wd * (cur_lr / base_lr)
                    # ← MOD: 강화된 값 즉시 로깅
                    wandb.log({
                        "epoch":          epoch,
                        "weight_decay":   optimizer.param_groups[0]['weight_decay'],
                        "note":           "WD increased"
                    }, step=epoch)
                    print(f"  → base_wd increased to {base_wd:.5f} at epoch {epoch}")
                    wd_count = 0

                # ES_PATIENCE마다 학습 중단
                if es_count >= ES_PATIENCE:
                    print(f"Early stopping at epoch {epoch}")
                    break

        # Fold 결과 기록
        print(
            f"Fold {fold} best_train_loss={best_train_loss:.4f}, "
            f"best_train_acc={best_train_acc:.4f}, "
            f"best_val_loss={best_val_loss:.4f}, "
            f"best_val_acc={best_val_acc:.4f}"
        )
        wandb.summary["best_train_loss"]     = best_train_loss
        wandb.summary["best_train_accuracy"] = best_train_acc
        wandb.summary["best_val_loss"]       = best_val_loss
        wandb.summary["best_val_accuracy"]   = best_val_acc

        fold_results.append({
            "train_loss": best_train_loss,
            "train_acc":  best_train_acc,
            "val_loss":   best_val_loss,
            "val_acc":    best_val_acc
        })

        wandb.finish()
        torch.cuda.empty_cache()
        gc.collect()

    # ─── 5-Fold Average Metrics 기록 ─────────────────────────────
    avg = {k: sum(res[k] for res in fold_results) / len(fold_results)
           for k in fold_results[0]}

    wandb.init(
        project="eeg-5fold-cv-14",
        name="fold_average",
        reinit=True,
        config={
            "lr": LR,
            "weight_decay": WEIGHT_DECAY,
            "num_blocks": NUM_BLOCKS,
            "num_heads": NUM_HEADS,
            "num_segments": NUM_SEGMENTS,
            "batch_size": BATCH_SIZE,
            "max_epochs": MAX_EPOCHS
        },
        settings=Settings(init_timeout=120)
    )
    wandb.log({
        "avg_train_loss":      avg["train_loss"],
        "avg_train_accuracy":  avg["train_acc"],
        "avg_val_loss":        avg["val_loss"],
        "avg_val_accuracy":    avg["val_acc"],
    })
    wandb.summary.update({
        "avg_train_loss":      avg["train_loss"],
        "avg_train_accuracy":  avg["train_acc"],
        "avg_val_loss":        avg["val_loss"],
        "avg_val_accuracy":    avg["val_acc"],
    })
    wandb.finish()

if __name__ == "__main__":
    train_and_evaluate()


Now using CUDA device 0
Enabling CUDA with 39.14 GiB available memory

===== Fold 1 =====


Epoch 001 | train_loss=1.1593 acc=0.3417 | val_loss=1.0958 acc=0.3587 | time=16.5s
Epoch 002 | train_loss=1.1366 acc=0.3522 | val_loss=1.0967 acc=0.3587 | time=16.7s
Epoch 003 | train_loss=1.1226 acc=0.3666 | val_loss=1.0929 acc=0.3587 | time=16.5s
Epoch 004 | train_loss=1.1099 acc=0.3798 | val_loss=1.0907 acc=0.3618 | time=16.7s
Epoch 005 | train_loss=1.1122 acc=0.3926 | val_loss=1.0865 acc=0.4379 | time=16.6s
Epoch 006 | train_loss=1.0964 acc=0.3984 | val_loss=1.0780 acc=0.5404 | time=16.6s
Epoch 007 | train_loss=1.0801 acc=0.4404 | val_loss=1.0419 acc=0.5683 | time=16.5s
Epoch 008 | train_loss=1.0426 acc=0.4936 | val_loss=0.9907 acc=0.6056 | time=16.5s
Epoch 009 | train_loss=1.0044 acc=0.5235 | val_loss=0.9840 acc=0.5792 | time=16.4s
Epoch 010 | train_loss=0.9868 acc=0.5495 | val_loss=0.9637 acc=0.6351 | time=16.5s
Epoch 011 | train_loss=0.9637 acc=0.5790 | val_loss=0.9285 acc=0.6382 | time=16.2s
Epoch 012 | train_loss=0.9385 acc=0.5903 | val_loss=0.9332 acc=0.6180 | time=16.7s
Epoc

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
lr,▁▁▁▂▂▂▃▃▄▅▆▆▇▇▇███████████████▇▇▇▇▇▇▇▇▇▆
train_accuracy,▁▁▁▂▂▂▃▃▄▅▅▅▅▅▅▅▆▆▆▆▆▇▆▆▇▇▇▇▇▇▇▇▇▇▇█████
train_loss,██▇▇▇▇▇▆▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁
validation_accuracy,▁▁▁▁▃▄▅▆▆▆▆▆▆▆▆▆▅▆▇▇▆█▆█▇██▇▆▇▇▇▇▇▇▇█▇█▇
validation_loss,██████▇▆▅▄▅▃▄▄▄▃▇▃▂▂▃▂▃▂▂▁▁▁▃▂▂▃▂▂▂▃▃▁▂▃
weight_decay,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▅▄▄▄▆▆▆▆█

0,1
best_train_accuracy,0.70136
best_train_loss,0.7676
best_val_accuracy,0.72516
best_val_loss,0.7683
epoch,45
lr,0.00078
note,WD increased
train_accuracy,0.76272
train_loss,0.69276
validation_accuracy,0.6677



===== Fold 2 =====


Epoch 001 | train_loss=1.1613 acc=0.3670 | val_loss=1.0812 acc=0.4255 | time=16.3s
Epoch 002 | train_loss=1.1437 acc=0.3786 | val_loss=1.0812 acc=0.4255 | time=16.3s
Epoch 003 | train_loss=1.1447 acc=0.3701 | val_loss=1.0827 acc=0.4255 | time=16.1s
Epoch 004 | train_loss=1.1316 acc=0.3643 | val_loss=1.0820 acc=0.4255 | time=16.3s
Epoch 005 | train_loss=1.1254 acc=0.3810 | val_loss=1.0827 acc=0.4488 | time=16.3s
Epoch 006 | train_loss=1.1112 acc=0.3895 | val_loss=1.0891 acc=0.3665 | time=16.4s
Epoch 007 | train_loss=1.1041 acc=0.3872 | val_loss=1.0811 acc=0.4255 | time=16.3s
Epoch 008 | train_loss=1.0987 acc=0.3950 | val_loss=1.0791 acc=0.4255 | time=16.4s
Epoch 009 | train_loss=1.0872 acc=0.4058 | val_loss=1.0325 acc=0.5839 | time=16.2s
Epoch 010 | train_loss=1.0393 acc=0.5091 | val_loss=0.9785 acc=0.6351 | time=16.3s
Epoch 011 | train_loss=0.9943 acc=0.5398 | val_loss=0.9624 acc=0.6398 | time=16.2s
Epoch 012 | train_loss=0.9700 acc=0.5588 | val_loss=0.8991 acc=0.6677 | time=16.1s
Epoc

0,1
epoch,▁▁▁▂▂▂▂▂▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▇▇▇▇▇███
lr,▁▁▁▂▂▂▃▃▄▄▅▆▆▇▇▇█████████████████▇
train_accuracy,▁▁▁▁▁▁▁▂▂▄▄▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇██████
train_loss,███▇▇▇▇▇▇▆▅▅▄▄▃▃▃▃▃▃▂▂▃▂▂▂▂▂▁▁▁▁▁▁
validation_accuracy,▂▂▂▂▃▁▂▂▅▆▇▇▇▆▇█▅▆█▆▇▅▇▇█▇██▇█▆█▇█
validation_loss,████████▆▅▅▃▂▄▄▂▄▃▁▃▂█▂▂▂▃▃▂▁▁▃▂▂▂
weight_decay,▁▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▃▃▃▃▃▄▄▄▄▄▆▆▆▆▆█

0,1
best_train_accuracy,0.64078
best_train_loss,0.86618
best_val_accuracy,0.71273
best_val_loss,0.81566
epoch,34
lr,0.00093
note,WD increased
train_accuracy,0.71612
train_loss,0.77703
validation_accuracy,0.69255



===== Fold 3 =====


Epoch 001 | train_loss=1.2098 acc=0.3515 | val_loss=1.0817 acc=0.4410 | time=16.6s
Epoch 002 | train_loss=1.1582 acc=0.3759 | val_loss=1.0848 acc=0.4410 | time=16.5s
Epoch 003 | train_loss=1.1458 acc=0.3790 | val_loss=1.0814 acc=0.4410 | time=16.2s
Epoch 004 | train_loss=1.1402 acc=0.3783 | val_loss=1.0817 acc=0.4410 | time=16.1s
Epoch 005 | train_loss=1.1278 acc=0.3810 | val_loss=1.0843 acc=0.4410 | time=16.1s
Epoch 006 | train_loss=1.1295 acc=0.3926 | val_loss=1.0832 acc=0.4410 | time=16.3s
Epoch 007 | train_loss=1.1273 acc=0.3674 | val_loss=1.0816 acc=0.4410 | time=16.2s
Epoch 008 | train_loss=1.1111 acc=0.3891 | val_loss=1.0862 acc=0.4410 | time=16.1s
  → base_wd increased to 0.07500 at epoch 8
Epoch 009 | train_loss=1.1056 acc=0.3953 | val_loss=1.0808 acc=0.4410 | time=16.3s
Epoch 010 | train_loss=1.0979 acc=0.3876 | val_loss=1.0820 acc=0.4410 | time=16.3s
Epoch 011 | train_loss=1.0900 acc=0.3992 | val_loss=1.0801 acc=0.4410 | time=16.3s
Epoch 012 | train_loss=1.0820 acc=0.4272 | 

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██
lr,▁▁▁▂▂▄▄▅▆▆▇██████████████▇▇▇▇▇▇▇▇▇▆▆▆▆▆▅
train_accuracy,▁▁▁▁▁▁▂▂▂▂▄▄▅▅▅▅▅▅▆▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇██████
train_loss,█▇▇▇▇▇▇▇▇▆▅▅▅▅▄▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁
validation_accuracy,▁▁▁▁▁▁▁▁▁▁▆▅▇▅▇▆▆▇▇▆██▇▅▇▄▇█▆▆█▆▇▇█▅▇▆█▆
validation_loss,█████████▇▅▅▅▅▄▃▃▂▄▂▂▂▃▂▁▃▃▁▂▃▃▃▃▃▃▄▃▆▃▆
weight_decay,▁▁▁▁▁▂▂▃▃▃▃▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▅▅▅▄▆▆▆█

0,1
best_train_accuracy,0.72621
best_train_loss,0.74249
best_val_accuracy,0.67702
best_val_loss,0.81188
epoch,51
lr,0.00067
note,WD increased
train_accuracy,0.79728
train_loss,0.64661
validation_accuracy,0.62888



===== Fold 4 =====


Epoch 001 | train_loss=1.2026 acc=0.3452 | val_loss=1.1071 acc=0.3168 | time=16.3s
Epoch 002 | train_loss=1.1695 acc=0.3654 | val_loss=1.1015 acc=0.3168 | time=16.3s
Epoch 003 | train_loss=1.1595 acc=0.3810 | val_loss=1.1021 acc=0.3168 | time=16.6s
Epoch 004 | train_loss=1.1538 acc=0.3748 | val_loss=1.0993 acc=0.3168 | time=16.3s
Epoch 005 | train_loss=1.1316 acc=0.3891 | val_loss=1.1045 acc=0.3168 | time=16.4s
Epoch 006 | train_loss=1.1226 acc=0.3849 | val_loss=1.0859 acc=0.4099 | time=16.6s
Epoch 007 | train_loss=1.0872 acc=0.4280 | val_loss=1.0110 acc=0.5994 | time=16.5s
Epoch 008 | train_loss=1.0375 acc=0.5002 | val_loss=1.0113 acc=0.5450 | time=16.5s
Epoch 009 | train_loss=0.9960 acc=0.5468 | val_loss=0.9305 acc=0.6351 | time=16.5s
Epoch 010 | train_loss=0.9726 acc=0.5786 | val_loss=0.9294 acc=0.6351 | time=16.5s
Epoch 011 | train_loss=0.9562 acc=0.5833 | val_loss=0.9197 acc=0.6351 | time=16.3s
Epoch 012 | train_loss=0.9348 acc=0.5965 | val_loss=0.9058 acc=0.6506 | time=16.5s
Epoc

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
lr,▁▁▁▂▂▂▃▄▄▅▆▆▇▇██████████████▇▇▇▇▇▇▇▇▇▆▆▆
train_accuracy,▁▁▂▁▂▂▃▄▅▅▅▅▅▅▅▅▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇████
train_loss,██▇▇▇▇▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▃▃▃▃▂▃▂▂▂▂▂▂▂▁▁▁▁
validation_accuracy,▁▁▁▁▁▃▅▇▇▇▇▇▆▆▇▇▆▆▇▇▇▇▇█▇▇▇██▆▇█▇▇▆▇▆█▇▇
validation_loss,██████▆▄▄▄▄▃▄▄▃▂▄▃▃▂▃▂▃▁▂▂▂▁▂▂▂▂▂▂▅▄▅▃▅▅
weight_decay,▁▁▁▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▄▄▄▃▃▃▅▅▅▄▆▆▆▆█

0,1
best_train_accuracy,0.71068
best_train_loss,0.76009
best_val_accuracy,0.7205
best_val_loss,0.78544
epoch,47
lr,0.00074
note,WD increased
train_accuracy,0.79806
train_loss,0.65527
validation_accuracy,0.64907



===== Fold 5 =====


Epoch 001 | train_loss=1.1645 acc=0.3381 | val_loss=1.0874 acc=0.4355 | time=16.1s
Epoch 002 | train_loss=1.1551 acc=0.3606 | val_loss=1.0913 acc=0.4355 | time=16.4s
Epoch 003 | train_loss=1.1567 acc=0.3424 | val_loss=1.1020 acc=0.4246 | time=16.4s
Epoch 004 | train_loss=1.1367 acc=0.3637 | val_loss=1.0933 acc=0.4355 | time=16.4s
Epoch 005 | train_loss=1.1163 acc=0.3715 | val_loss=1.0977 acc=0.4355 | time=16.4s
Epoch 006 | train_loss=1.1208 acc=0.3800 | val_loss=1.0879 acc=0.4355 | time=16.5s
  → base_wd increased to 0.07500 at epoch 6
Epoch 007 | train_loss=1.1093 acc=0.3855 | val_loss=1.0913 acc=0.4355 | time=16.3s
Epoch 008 | train_loss=1.1028 acc=0.3839 | val_loss=1.0902 acc=0.4355 | time=16.4s
Epoch 009 | train_loss=1.1008 acc=0.3781 | val_loss=1.0854 acc=0.4355 | time=16.5s
Epoch 010 | train_loss=1.0947 acc=0.3956 | val_loss=1.0874 acc=0.4355 | time=16.4s
Epoch 011 | train_loss=1.0775 acc=0.4410 | val_loss=1.0764 acc=0.5008 | time=16.1s
Epoch 012 | train_loss=1.0591 acc=0.4647 | 

0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██
lr,▁▁▁▂▂▂▃▃▄▄▅▆▆▇▇▇█████████████████▇▇▇▇▇▇▇
train_accuracy,▁▁▁▁▂▂▂▂▂▂▃▃▄▅▅▅▆▆▆▆▇▆▇▆▇▇▇▇▇▇▇▇████████
train_loss,████▇▇▇▇▇▇▇▆▆▅▅▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁
validation_accuracy,▂▂▂▂▂▂▂▂▂▂▃▁▅▆▆▅▆▅▇▇▅▇▇▇▇█▇▇█▇▆█▇▅▇▇▇▇█▇
validation_loss,██████████▇█▆▅▅▅▃▄▃▂▅▃▃▃▂▁▂▂▂▃▅▁▂▄▂▂▃▁▂▂
weight_decay,▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▄▄▄▄▄▆▆▆▆█

0,1
best_train_accuracy,0.67352
best_train_loss,0.82884
best_val_accuracy,0.72628
best_val_loss,0.76866
epoch,41
lr,0.00084
note,WD increased
train_accuracy,0.73525
train_loss,0.73335
validation_accuracy,0.68274




0,1
avg_train_accuracy,▁
avg_train_loss,▁
avg_val_accuracy,▁
avg_val_loss,▁

0,1
avg_train_accuracy,0.69051
avg_train_loss,0.79304
avg_val_accuracy,0.71234
avg_val_loss,0.78999
