In [1]:
import sys

In [2]:
# We use the motor imagery dataset for testing.
from pathlib import Path
import importlib

DATA_CLEANER_DIR = Path("../1-data-cleaning-pipelines/motor-imaginary")

def get_motor_imaginery_train_and_validation_data():
    """Run the motor-imagery cleaning pipeline and return its train/validation split.

    ``DATA_CLEANER_DIR`` is added to ``sys.path`` (once), then the
    ``centralize_train_data_cleaner`` module is imported and reloaded so edits to
    it on disk are picked up without restarting the kernel.

    Returns:
        A 4-tuple ``(train_labels, train_features, val_labels, val_features)``
        unpacked from ``centralize_train_data_cleaner.pipeline()``.
    """
    cleaner_path = str(DATA_CLEANER_DIR.resolve())
    if cleaner_path not in sys.path:
        sys.path.append(cleaner_path)

    cleaner = importlib.reload(importlib.import_module("centralize_train_data_cleaner"))

    train_labels, train_features, val_labels, val_features = cleaner.pipeline()
    return train_labels, train_features, val_labels, val_features

train_labels, train_features, val_labels, val_features = get_motor_imaginery_train_and_validation_data()

# Sanity-check the splits before training: exactly one label per example in each
# split, and identical per-example feature shape across train and validation.
assert train_labels.shape[0] == train_features.shape[0], "train labels/features length mismatch"
assert val_labels.shape[0] == val_features.shape[0], "val labels/features length mismatch"
assert train_features.shape[1:] == val_features.shape[1:], "train/val per-example shapes differ"

train_labels.shape, train_features.shape, val_labels.shape, val_features.shape

[Fetcher] Starting data preparation...
[Download] Skip: BCICIV_2a_gdf.zip already exists
[Extract] Skip: already extracted at /home/kanathipp/Stuffs/Works/final-project-federated-learning/0-raw-data/motor-imaginary/data
[Fetcher] Completed.


(torch.Size([2928]),
 torch.Size([2928, 22, 176]),
 torch.Size([880]),
 torch.Size([880, 22, 176]))

In [5]:
from pathlib import Path
import importlib
MODEL_DIR = Path("../2-models")

def get_eegnet_pytorch_model():
    """Return the ``EEGNet`` class from a freshly reloaded ``eegnet_pytorch`` module.

    ``MODEL_DIR`` is added to ``sys.path`` (once) so the module is importable, and
    the module is reloaded on every call so edits to it on disk take effect without
    restarting the kernel.
    """
    models_path = str(MODEL_DIR.resolve())
    if models_path not in sys.path:
        sys.path.append(models_path)

    eegnet_module = importlib.reload(importlib.import_module("eegnet_pytorch"))
    return eegnet_module.EEGNet

# Module-level name that EEGNetLightningModel.__init__ reads when it builds the network.
EEGNet = get_eegnet_pytorch_model()

In [16]:
# PyTorch Lightning model

from pytorch_lightning import LightningModule, Trainer
import torchmetrics
from torch.utils.data import TensorDataset, DataLoader
import torch
import torch.nn as nn

class EEGNetLightningModel(LightningModule):
    """PyTorch Lightning wrapper around ``EEGNet`` for multiclass EEG classification.

    Every constructor argument defaults to the value this notebook originally
    hard-coded, so ``EEGNetLightningModel()`` behaves as before.

    Args:
        in_channel: Number of input channels, passed as ``EEGNet(in_channel=...)``.
        num_classes: Class count used by the accuracy metrics.
            NOTE(review): must match ``EEGNet``'s output width; its constructor is
            not visible in this notebook, so confirm against ``eegnet_pytorch``.
        learning_rate: Learning rate for the Adam optimizer.
        batch_size: Batch size for both the training and validation dataloaders.
        num_workers: Number of ``DataLoader`` worker processes.
        train_dataset: Dataset yielding ``(signal, label)`` pairs for training.
            If ``None``, a ``TensorDataset`` is built from the notebook globals
            ``train_features`` / ``train_labels`` when the dataloader is requested.
        val_dataset: Same as ``train_dataset`` but for validation; falls back to
            the globals ``val_features`` / ``val_labels``.
    """

    def __init__(self, in_channel=22, num_classes=4, learning_rate=1e-3,
                 batch_size=12, num_workers=2, train_dataset=None, val_dataset=None):
        super().__init__()
        # Record the scalar hyperparameters (checkpoints/loggers) but never the data.
        self.save_hyperparameters(ignore=["train_dataset", "val_dataset"])

        self.model = EEGNet(in_channel=in_channel)
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.num_workers = num_workers
        self._train_dataset = train_dataset
        self._val_dataset = val_dataset

        # Separate metric objects for train and validation so their states never mix.
        self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    # ---------- dataloaders ----------
    def train_dataloader(self):
        """Shuffled training loader over ``train_dataset`` (or the notebook globals)."""
        dataset = self._train_dataset
        if dataset is None:
            # Backward-compatible fallback on the globals created by the data-loading cell.
            dataset = TensorDataset(train_features, train_labels)
        return DataLoader(dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=True)

    def val_dataloader(self):
        """Unshuffled validation loader over ``val_dataset`` (or the notebook globals)."""
        dataset = self._val_dataset
        if dataset is None:
            # Backward-compatible fallback on the globals created by the data-loading cell.
            dataset = TensorDataset(val_features, val_labels)
        return DataLoader(dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=False)

    # ---------- steps ----------
    def _shared_step(self, batch):
        """Run the model on a ``(signal, label)`` batch; return ``(loss, logits, target)``."""
        signal, label = batch
        logits = self(signal.float())    # [B, num_classes]
        target = label.long().view(-1)   # [B] class indices
        loss = self.criterion(logits, target)
        return loss, logits, target

    def training_step(self, batch, batch_idx):
        loss, logits, target = self._shared_step(batch)
        self.train_acc.update(logits, target)

        # on_step=False, on_epoch=True: Lightning aggregates the values over the whole
        # epoch. Accuracy is logged as a Metric object (not a tensor), so Lightning
        # also owns its compute() and reset() at epoch end.
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('train_acc', self.train_acc, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, logits, target = self._shared_step(batch)
        self.val_acc.update(logits, target)

        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('val_acc', self.val_acc, on_step=False, on_epoch=True, prog_bar=True)

    # ---------- epoch-end hooks (no step outputs are collected) ----------
    def _print_epoch_summary(self, prefix, acc_key, loss_key):
        """Print the epoch-level accuracy and loss that Lightning aggregated.

        Values are read from ``trainer.callback_metrics`` rather than by calling
        ``compute()``/``reset()`` on the metric objects directly. The metrics are
        logged as Metric objects, so Lightning already computes and resets them.
        Resetting them here as well could leave Lightning computing an empty metric.
        Missing keys print as 0.0.
        """
        metrics = self.trainer.callback_metrics
        acc = float(metrics.get(acc_key, torch.tensor(0.0)))
        loss = float(metrics.get(loss_key, torch.tensor(0.0)))
        print(prefix, round(acc, 2), round(loss, 2))

    def on_train_epoch_end(self):
        self._print_epoch_summary('train acc loss', 'train_acc', 'train_loss')

    def on_validation_epoch_end(self):
        self._print_epoch_summary('val   acc loss', 'val_acc', 'val_loss')

# Seed the Python, NumPy and PyTorch RNGs before the model is built, so weight
# initialisation and DataLoader shuffling are repeatable across kernel restarts
# (GPU kernels may still be nondeterministic). workers=True also derives seeds
# for the DataLoader worker processes.
from pytorch_lightning import seed_everything

SEED = 42
seed_everything(SEED, workers=True)
pytorch_lightning_model = EEGNetLightningModel()

In [None]:
# Training length for this run.
# NOTE(review): the saved `trainer.fit` log further down reports "`max_epochs=10`
# reached", so that output is stale relative to this setting. Restart & Run All
# to regenerate it.
MAX_EPOCHS = 20
trainer = Trainer(max_epochs=MAX_EPOCHS)

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [18]:
# Train using the module's own train/val dataloaders; validation runs after every epoch.
trainer.fit(model=pytorch_lightning_model)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type               | Params | Mode 
---------------------------------------------------------
0 | model     | EEGNet             | 1.5 K  | train
1 | train_acc | MulticlassAccuracy | 0      | train
2 | val_acc   | MulticlassAccuracy | 0      | train
3 | criterion | CrossEntropyLoss   | 0      | train
---------------------------------------------------------
1.5 K     Trainable params
0         Non-trainable params
1.5 K     Total params
0.006     Total estimated model params size (MB)
22        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

val   acc loss 0.54 1.36


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

val   acc loss 0.65 0.94
train acc loss 0.55 1.14


Validation: |          | 0/? [00:00<?, ?it/s]

val   acc loss 0.65 0.91
train acc loss 0.61 1.04


Validation: |          | 0/? [00:00<?, ?it/s]

val   acc loss 0.65 0.9
train acc loss 0.61 1.02


Validation: |          | 0/? [00:00<?, ?it/s]

val   acc loss 0.65 0.89
train acc loss 0.61 1.0


Validation: |          | 0/? [00:00<?, ?it/s]

val   acc loss 0.66 0.88
train acc loss 0.62 0.97


Validation: |          | 0/? [00:00<?, ?it/s]

val   acc loss 0.66 0.88
train acc loss 0.62 0.95


Validation: |          | 0/? [00:00<?, ?it/s]

val   acc loss 0.66 0.88
train acc loss 0.62 0.95


Validation: |          | 0/? [00:00<?, ?it/s]

val   acc loss 0.65 0.88
train acc loss 0.62 0.95


Validation: |          | 0/? [00:00<?, ?it/s]

val   acc loss 0.66 0.85
train acc loss 0.62 0.94


Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.


val   acc loss 0.67 0.85
train acc loss 0.62 0.93
