In [31]:
%load_ext autoreload
%autoreload 2
import os
# Make the local `modules` package importable.
# On Colab, if neither ./modules nor modules.zip is present, prompt the user to upload the zip;
# the existence check below relies on the upload landing in the working directory.
if not os.path.exists('./modules') and not os.path.exists('modules.zip'):
    from google.colab import files
    uploaded = files.upload()
if not os.path.exists('./modules') and os.path.exists('modules.zip'):
    # Extract with the stdlib instead of shelling out to `unzip`: it works without the
    # external binary and raises on a corrupt archive instead of failing silently.
    import zipfile
    with zipfile.ZipFile('modules.zip') as modules_archive:
        modules_archive.extractall('.')

import kagglehub
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score
from torch.utils.data import TensorDataset, DataLoader
import optuna
from modules import EEGDataset
from modules.utils import split_and_get_loaders, evaluate_model, manual_write_study_params

# Display the device PyTorch will use; the Trainer and model select it with the same rule.
torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


device(type='cpu')

In [None]:
#! need to modify those for the competition itself
# Trial length passed to EEGDataset (presumably samples per recorded trial, i.e. the period
# after which the stimulus frequency changes — TODO confirm against EEGDataset).
TRIAL_LENGTH = 640

# Download the dataset. The alternative "xuannguyenuet2004/12-class-ssvep-eeg-data"
# dataset was tried and proved unsuitable.
path_1 = kagglehub.dataset_download("girgismicheal/steadystate-visual-evoked-potential-signals")
path_1 = os.path.join(path_1, "SSVEP (BrainWheel)")
# Fail early with a clear message if the expected subfolder is not in the download.
if not os.path.isdir(path_1):
    raise FileNotFoundError(f"Expected dataset folder not found: {path_1}")
print("Downloaded dataset files:", "\n", path_1)

In [None]:
class SSVEPClassifier(nn.Module):
    """LSTM classifier that maps a batch of EEG windows to class logits.

    The input goes through a (optionally bidirectional) ``batch_first`` LSTM; the LSTM
    output at the last sequence position is projected to ``out_size`` logits.

    Args:
        input_size: number of features per sequence step (last dimension of ``x``).
        out_size: number of output classes.
        hidden_size: LSTM hidden size per direction.
        num_layers: number of stacked LSTM layers.
        dropout: dropout between stacked LSTM layers. PyTorch only applies it between
            layers, so it is forced to 0 when ``num_layers == 1`` (avoids a warning and
            makes the effective value explicit).
        bidirectional: whether the LSTM runs in both directions.
        device: device the whole module is created on; defaults to CUDA when available,
            otherwise CPU.
    """

    def __init__(self, input_size: int, out_size: int, hidden_size: int, num_layers: int, dropout: float, bidirectional: bool, device=None):
        super().__init__()
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.device = device
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dir_mult = 2 if bidirectional else 1

        # Inter-layer dropout is meaningless for a single layer (nn.LSTM warns about it).
        lstm_dropout = dropout if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, dropout=lstm_dropout, bidirectional=bidirectional, batch_first=True)
        self.fc_out = nn.Sequential(
            nn.Linear(hidden_size * self.dir_mult, out_size),
        )
        # Put every submodule on the requested device, not just the LSTM.
        self.to(self.device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return logits of shape [B, out_size] for input ``x`` of shape [B, seq_len, input_size]."""
        # Without explicit (h0, c0) nn.LSTM starts from zero states created on x's device/dtype,
        # so the model keeps working after being moved with .to() / .cpu().
        out, _ = self.lstm(x)  # out: [B, seq_len, hidden_size * dir_mult]
        return self.fc_out(out[:, -1])

In [36]:
class Trainer:
    """Optuna hyperparameter search and final training for ``SSVEPClassifier``.

    ``optimize()`` runs an Optuna study persisted in ``self.storage``. ``train()`` rebuilds
    the model from that study's best parameters, trains it for ``self.train_epochs`` epochs
    and writes checkpoints under ``checkpoint_model_path``. Relies on the notebook globals
    ``path_1`` / ``TRIAL_LENGTH`` and on ``EEGDataset`` and the ``modules.utils`` helpers.
    """

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.train_epochs = 1000
        self.tune_epochs = 25
        self.optuna_n_trials = 120
        self.optuna_timeout = 60 * 10  # seconds; study.optimize stops after this even if trials remain

        self.criterion = nn.CrossEntropyLoss()
        self.model = None
        self.optimizer = None
        self.trial = None

        self.train_loader = None
        self.val_loader = None
        self.test_loader = None
        self.dataset = None

        self.storage = "sqlite:///optuna_studies.db"
        self.study_name = "ssvep_classifier_optimization"

        self.checkpoint_path = "./checkpoints/ssvep"
        self.checkpoint_model_path = os.path.join(self.checkpoint_path, "models")
        os.makedirs(self.checkpoint_model_path, exist_ok=True)

    @property
    def eval_loader(self):
        """Backward-compatible alias for ``val_loader`` (the attribute the loops actually use)."""
        return self.val_loader

    @eval_loader.setter
    def eval_loader(self, loader):
        self.val_loader = loader

    def _save_checkpoint(self, filename: str):
        """Save the model's state_dict, as CPU tensors, to ``checkpoint_model_path/filename``."""
        self.model.cpu()  # store CPU tensors so the checkpoint loads without a GPU
        torch.save(self.model.state_dict(), os.path.join(self.checkpoint_model_path, filename))
        self.model.to(self.device)

    def _train_loop(self, n_epochs: int, should_save=False, should_print=False):
        """Train ``self.model`` for ``n_epochs`` epochs, evaluating on ``val_loader`` after each.

        When ``self.trial`` is set, every epoch's validation score is reported to Optuna and
        the trial is pruned if the pruner asks for it.

        Args:
            n_epochs: number of passes over ``train_loader``.
            should_save: if True, save the latest weights to ``ssvep.pth`` every epoch and the
                weights with the highest validation score so far to ``ssvep_best.pth``.
            should_print: if True, print the validation score and mean training loss per epoch.

        Raises:
            optuna.exceptions.TrialPruned: when running under a trial that the pruner stops.
        """
        assert isinstance(self.optimizer, torch.optim.Optimizer), "optimizer is not a valid optimizer"
        assert isinstance(self.train_loader, DataLoader), "train_loader is not a valid DataLoader"
        if self.trial is None:
            print("Warning: self.trial is None, we are probably in the actual training phase")

        self.model.to(self.device)
        best_evaluation = None
        for epoch in range(n_epochs):
            self.model.train()  # evaluate_model may switch to eval mode, so re-enable every epoch

            total_loss = 0.0
            for x, y in self.train_loader:
                x = x.to(self.device)
                y = y.to(self.device)

                y_pred = self.model(x)  # [B, out_size] logits
                loss = self.criterion(y_pred, y)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                total_loss += loss.item()

            avg_loss = total_loss / max(len(self.train_loader), 1)
            evaluation = evaluate_model(self.model, self.val_loader, self.device)

            if self.trial is not None:
                self.trial.report(evaluation, epoch)
                if self.trial.should_prune():
                    # Must be raised (not just constructed) for Optuna to actually prune.
                    raise optuna.exceptions.TrialPruned()

            if should_print:
                print(f"epoch {epoch}, evaluation {evaluation}, avg_loss {avg_loss}")

            if should_save:
                self._save_checkpoint("ssvep.pth")
                # Higher is better: the study optimizes with direction="maximize".
                if best_evaluation is None or evaluation > best_evaluation:
                    best_evaluation = evaluation
                    self._save_checkpoint("ssvep_best.pth")

    def _prepare_training(self, is_trial):
        """Build the dataset, data loaders, model and optimizer for one run.

        Args:
            is_trial: if True, sample hyperparameters from ``self.trial``; otherwise read the
                best parameters stored in the Optuna study.
        """
        if is_trial:
            assert isinstance(self.trial, optuna.Trial), "trial is None, can't suggest params"

            window_length = self.trial.suggest_categorical("window_length", [128, 160])
            # Stored as "stride" in the study, but it is a divisor: stride = window_length // factor.
            stride_factor = self.trial.suggest_int("stride", 2, 3)

            hidden_size = self.trial.suggest_int("hidden_size", 64, 192, step=32)
            num_layers = self.trial.suggest_int("num_layers", 1, 3)
            dropout = self.trial.suggest_float("dropout", 0, 0.4)
            lr = self.trial.suggest_float("lr", 3e-4, 3e-2, log=True)
            batch_size = self.trial.suggest_categorical("batch_size", [32, 64])

        else:
            best_params = self._get_study().best_params

            window_length = best_params['window_length']
            stride_factor = best_params['stride']
            num_layers = best_params["num_layers"]
            dropout = best_params["dropout"]
            lr = best_params["lr"]
            batch_size = best_params["batch_size"]
            hidden_size = best_params["hidden_size"]

        stride = int(window_length // stride_factor)
        self.dataset = EEGDataset(path_1, TRIAL_LENGTH, window_length, stride=stride)
        unique_freqs = torch.unique(self.dataset.labels)

        # NOTE(review): the original comment says data[0] is C x T. In that case shape[1] is
        # the time axis, and the batch_first LSTM would treat channels as the sequence.
        # Confirm EEGDataset's sample layout.
        input_size = self.dataset.data[0].shape[1]
        out_size = len(unique_freqs)

        self.model = SSVEPClassifier(input_size, out_size, hidden_size, num_layers, dropout, bidirectional=True, device=self.device)
        self.train_loader, self.val_loader, self.test_loader = split_and_get_loaders(self.dataset, batch_size)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr)

    def _objective(self, trial: optuna.Trial):
        """Optuna objective: train briefly with sampled parameters and return the validation score."""
        self.trial = trial
        self._prepare_training(True)

        self._train_loop(self.tune_epochs, should_save=False, should_print=False)
        evaluation = evaluate_model(self.model, self.val_loader, self.device)
        return evaluation

    def _get_study(self):
        """Create the maximizing study in ``self.storage``, or load it if it already exists."""
        return optuna.create_study(study_name=self.study_name, storage=self.storage, direction="maximize", load_if_exists=True)

    def optimize(self, delete_existing=False):
        """Run the hyperparameter search, print a summary and return the best parameters.

        Args:
            delete_existing: if True, delete any stored study with the same name first.
        """
        if delete_existing:
            try:
                optuna.delete_study(study_name=self.study_name, storage=self.storage)
            except KeyError:
                pass  # no study with this name yet — nothing to delete

        study = self._get_study()
        study.optimize(self._objective, n_trials=self.optuna_n_trials, timeout=self.optuna_timeout)

        # Print optimization results
        print("\nStudy statistics:")
        print(f"  Number of finished trials: {len(study.trials)}")
        print(f"  Number of pruned trials: {len(study.get_trials(states=[optuna.trial.TrialState.PRUNED]))}")
        print(f"  Number of complete trials: {len(study.get_trials(states=[optuna.trial.TrialState.COMPLETE]))}")

        print("\nBest trial:")
        trial = study.best_trial
        print(f"  Value: {trial.value}")
        print("\nBest hyperparameters:")
        for key, value in trial.params.items():
            print(f"  {key}: {value}")

        return study.best_params

    def train(self):
        """Train with the study's best hyperparameters and return the final validation score.

        Checkpoints: ``ssvep.pth`` holds the last epoch; ``ssvep_best.pth`` holds the epoch
        with the highest validation score.
        """
        self.trial = None
        self._prepare_training(False)

        self._train_loop(self.train_epochs, should_save=True, should_print=True)
        evaluation = evaluate_model(self.model, self.val_loader, self.device)
        print("done training")
        return evaluation

trainer = Trainer()

In [None]:
# True wipes any stored study with the same name before searching; set False to extend it.
delete_existing = True
trainer.optimize(delete_existing=delete_existing)

In [37]:
# Presumably writes a hand-picked set of hyperparameters into the Optuna study (the name
# suggests so — see modules.utils.manual_write_study_params) so training can run without
# a fresh optimize(). trainer.train() then trains with the study's best parameters and
# returns the final validation score.
manual_write_study_params(trainer.study_name, trainer.storage)
trainer.train()

[I 2025-06-16 16:16:08,926] A new study created in RDB with name: ssvep_classifier_optimization
[I 2025-06-16 16:16:08,985] Using an existing study with name 'ssvep_classifier_optimization' instead of creating a new one.




KeyboardInterrupt: 