In [9]:
import sys

In [10]:
# get train and validation data
from pathlib import Path
import importlib

DATA_CLEANER_DIR = Path("../1-data-cleaning-pipelines/motor-imaginary")

def get_motor_imaginary_train_and_validation_data():
    """Run the centralized train/validation cleaning pipeline and return its splits.

    Makes ``DATA_CLEANER_DIR`` importable (appended to ``sys.path`` once), then
    imports ``centralize_train_data_cleaner``. If that module was already
    imported in this kernel it is reloaded instead, so edits to the cleaner are
    picked up without restarting the kernel.

    Returns:
        ``(train_features, train_labels, val_features, val_labels)`` exactly as
        returned by ``centralize_train_data_cleaner.pipeline()``.
    """
    target_dir = str(DATA_CLEANER_DIR.resolve())
    if target_dir not in sys.path:
        sys.path.append(target_dir)

    module_name = "centralize_train_data_cleaner"
    if module_name in sys.modules:
        # Already imported in this kernel: reload to pick up source edits.
        cleaner = importlib.reload(sys.modules[module_name])
    else:
        # First use: a plain import is enough. Reloading right after importing
        # would execute the module's top-level code a second time.
        cleaner = importlib.import_module(module_name)

    train_features, train_labels, val_features, val_labels = cleaner.pipeline()

    return train_features, train_labels, val_features, val_labels

train_features, train_labels, val_features, val_labels = get_motor_imaginary_train_and_validation_data()
# Sanity checks: one label per trial in each split, and both splits share the
# same per-trial shape (everything after the first/sample axis).
assert len(train_features) == len(train_labels), "train features/labels length mismatch"
assert len(val_features) == len(val_labels), "val features/labels length mismatch"
assert train_features.shape[1:] == val_features.shape[1:], "train/val trial shapes differ"
train_features.shape, train_labels.shape, val_features.shape, val_labels.shape

[Fetcher] Starting data preparation...
[Download] Skip: BCICIV_2a_gdf.zip already exists
[Extract] Skip: already extracted at /home/kanathipp/Stuffs/Works/final-project-federated-learning/0-raw-data/motor-imaginary/data
[Fetcher] Completed.


(torch.Size([2928, 22, 176]),
 torch.Size([2928]),
 torch.Size([880, 22, 176]),
 torch.Size([880]))

In [11]:
# get test data
from pathlib import Path
import importlib

DATA_CLEANER_DIR = Path("../1-data-cleaning-pipelines/motor-imaginary")

def get_motor_imaginary_test_data():
    """Run the centralized test-data cleaning pipeline and return the test split.

    Makes ``DATA_CLEANER_DIR`` importable (appended to ``sys.path`` once), then
    imports ``centralize_test_data_cleaner``. If that module was already
    imported in this kernel it is reloaded instead, so edits to the cleaner are
    picked up without restarting the kernel.

    Returns:
        ``(test_features, test_labels)`` exactly as returned by
        ``centralize_test_data_cleaner.pipeline()``.
    """
    target_dir = str(DATA_CLEANER_DIR.resolve())
    if target_dir not in sys.path:
        sys.path.append(target_dir)

    module_name = "centralize_test_data_cleaner"
    if module_name in sys.modules:
        # Already imported in this kernel: reload to pick up source edits.
        cleaner = importlib.reload(sys.modules[module_name])
    else:
        # First use: a plain import is enough. Reloading right after importing
        # would execute the module's top-level code a second time.
        cleaner = importlib.import_module(module_name)

    test_features, test_labels = cleaner.pipeline()

    return test_features, test_labels

test_features, test_labels = get_motor_imaginary_test_data()
# Sanity checks: one label per test trial, and test trials have the same
# per-trial shape as the training trials loaded in the previous cell.
assert len(test_features) == len(test_labels), "test features/labels length mismatch"
assert test_features.shape[1:] == train_features.shape[1:], "test/train trial shapes differ"
test_features.shape, test_labels.shape

[Fetcher] Starting data preparation...
[Download] Skip: BCICIV_2a_gdf.zip already exists
[Extract] Skip: already extracted at /home/kanathipp/Stuffs/Works/final-project-federated-learning/0-raw-data/motor-imaginary/data
[Fetcher] Completed.


(torch.Size([5256, 22, 176]), torch.Size([5256]))

In [12]:
from pathlib import Path
import importlib
MODEL_DIR = Path("../2-models")

def get_eegnet_pytorch_model():
    """Return the ``EEGNet`` attribute of the ``eegnet_pytorch`` module in ``MODEL_DIR``.

    Makes ``MODEL_DIR`` importable (appended to ``sys.path`` once), then imports
    ``eegnet_pytorch``. If that module was already imported in this kernel it is
    reloaded instead, so edits to the model file are picked up without a restart.

    Returns:
        ``eegnet_pytorch.EEGNet`` itself, not instantiated. Callers construct it,
        e.g. ``EEGNet(in_channel=22)``.
    """
    target_dir = str(MODEL_DIR.resolve())
    if target_dir not in sys.path:
        sys.path.append(target_dir)

    module_name = "eegnet_pytorch"
    if module_name in sys.modules:
        # Already imported in this kernel: reload to pick up source edits.
        model_module = importlib.reload(sys.modules[module_name])
    else:
        # First use: a plain import is enough. Reloading right after importing
        # would execute the module's top-level code a second time.
        model_module = importlib.import_module(module_name)

    return model_module.EEGNet

# Bind the (uninstantiated) EEGNet class loaded from MODEL_DIR at notebook scope;
# EEGNetLightningModel.__init__ below constructs it via EEGNet(in_channel=...).
EEGNet = get_eegnet_pytorch_model()

In [13]:
# Pytorch Lightning Model 

from pytorch_lightning import LightningModule, Trainer
import torchmetrics
from torch.utils.data import TensorDataset, DataLoader
import torch
import torch.nn as nn

class EEGNetLightningModel(LightningModule):
    """LightningModule that trains ``EEGNet`` as a ``NUM_CLASSES``-way classifier.

    Uses cross-entropy loss and Adam. A separate ``torchmetrics`` multiclass
    accuracy is tracked for each of the train, validation and test splits.
    Each split's dataloader wraps a ``(features, labels)`` tensor pair in a
    ``TensorDataset``.

    Args:
        in_channel: Passed to ``EEGNet(in_channel=...)``. The default of 22
            matches the ``[N, 22, 176]`` tensors loaded earlier in this notebook.
        learning_rate: Adam learning rate.
        batch_size: Batch size for every dataloader.
        num_workers: ``DataLoader`` worker processes.
        train_data, val_data, test_data: Optional ``(features, labels)`` tensor
            pairs. When ``None`` (the default), the matching dataloader falls
            back to the notebook globals (``train_features``/``train_labels``,
            ``val_features``/``val_labels``, ``test_features``/``test_labels``).
            Those globals are looked up each time the dataloader is built.
    """

    # Number of classes the accuracy metrics are configured for.
    # NOTE(review): must match EEGNet's output width (defined in ../2-models) — confirm.
    NUM_CLASSES = 4

    def __init__(self, in_channel=22, learning_rate=1e-3, batch_size=12,
                 num_workers=2, train_data=None, val_data=None, test_data=None):
        super().__init__()
        self.model = EEGNet(in_channel=in_channel)
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Explicit datasets; None means "use the notebook globals" (see docstring).
        self._train_data = train_data
        self._val_data = val_data
        self._test_data = test_data

        # One metric per split so their running states never mix.
        self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=self.NUM_CLASSES)
        self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=self.NUM_CLASSES)
        self.test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=self.NUM_CLASSES)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    # ---------- dataloaders ----------
    def _make_loader(self, data, shuffle):
        """Wrap a ``(features, labels)`` pair in a ``DataLoader`` using this module's settings."""
        features, labels = data
        return DataLoader(TensorDataset(features, labels),
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          shuffle=shuffle)

    def train_dataloader(self):
        data = self._train_data if self._train_data is not None else (train_features, train_labels)
        return self._make_loader(data, shuffle=True)

    def val_dataloader(self):
        data = self._val_data if self._val_data is not None else (val_features, val_labels)
        return self._make_loader(data, shuffle=False)

    def test_dataloader(self):
        data = self._test_data if self._test_data is not None else (test_features, test_labels)
        return self._make_loader(data, shuffle=False)

    # ---------- steps ----------
    def _shared_step(self, batch, metric):
        """Run the model on one batch, update ``metric`` and return the loss."""
        signal, label = batch
        logits = self(signal.float())       # [B, NUM_CLASSES]
        target = label.long().view(-1)      # [B]; flattened in case labels carry an extra dim
        loss = self.criterion(logits, target)
        metric.update(logits, target)
        return loss

    # Each step logs the Metric object itself (not a tensor), which hands
    # compute() and reset() over to Lightning at epoch end. That is why there
    # are no on_*_epoch_end reset hooks: resetting there as well is redundant
    # and risks clearing the state before Lightning computes the epoch value.
    def training_step(self, batch, batch_idx):
        loss = self._shared_step(batch, self.train_acc)
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('train_acc', self.train_acc, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._shared_step(batch, self.val_acc)
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('val_acc', self.val_acc, on_step=False, on_epoch=True, prog_bar=True)

    def test_step(self, batch, batch_idx):
        loss = self._shared_step(batch, self.test_acc)
        self.log('test_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('test_acc', self.test_acc, on_step=False, on_epoch=True, prog_bar=True)

# Seed torch's default RNG (CPU and CUDA) before building the model. This makes
# EEGNet's weight initialisation and the shuffled training order repeatable.
# It does not force deterministic cuDNN kernels.
SEED = 42
torch.manual_seed(SEED)
pytorch_lightning_model = EEGNetLightningModel()

In [None]:
# NOTE(review): the saved `trainer.fit` output in this notebook reports
# "`max_epochs=100` reached", while this cell sets max_epochs=40 (and shows
# In [None]). The saved outputs are therefore stale relative to this code;
# re-run the notebook top to bottom before trusting the reported test metrics.
trainer = Trainer(max_epochs=40)

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [15]:
# Train with the train/val dataloaders defined on the LightningModule itself
# (no dataloaders or datamodule are passed here).
trainer.fit(pytorch_lightning_model)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type               | Params | Mode 
---------------------------------------------------------
0 | model     | EEGNet             | 1.5 K  | train
1 | train_acc | MulticlassAccuracy | 0      | train
2 | val_acc   | MulticlassAccuracy | 0      | train
3 | test_acc  | MulticlassAccuracy | 0      | train
4 | criterion | CrossEntropyLoss   | 0      | train
---------------------------------------------------------
1.5 K     Trainable params
0         Non-trainable params
1.5 K     Total params
0.006     Total estimated model params size (MB)
23        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]



Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.


In [16]:
# Evaluate on the module's test_dataloader using the model's current weights,
# i.e. those left after the final training epoch. No ckpt_path is passed, so
# no "best" checkpoint is restored.
trainer.test(pytorch_lightning_model)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

[{'test_loss': 1.136815071105957, 'test_acc': 0.5232115387916565}]