This repository is currently being overhauled. The older experiments can be found in the `deprecated/` directory.

In [3]:
from glob import glob
from pathlib import Path

# Training
The **[new]** training pipeline is defined in this notebook.\
This is still a work in progress; the older CLI that was previously used to launch training can be found in `deprecated/4_experiments_jul-aug/train.py`.

In [2]:
from src import prepare
from src.models import RnnModule
from src.datamodules import MultiParticipantDataModule
from src.datasets import WindowedDataset

In [4]:
root_dir = './data/signal/'

def list_participants(directory: str) -> list:
    """Return the participant IDs found in ``directory``.

    A participant ID is the file stem of each ``*.csv`` file directly inside
    ``directory``. The result is sorted because ``glob`` returns matches in a
    filesystem-dependent order, which would otherwise make any downstream
    (seeded) train/val/test split differ between machines.
    """
    # Path(...) / '*.csv' avoids the double slash produced by f'{dir}/*.csv'
    # when `directory` already ends in '/'.
    return sorted(Path(p).stem for p in glob(str(Path(directory) / '*.csv')))

participants = list_participants(root_dir)

In [None]:
from sklearn.model_selection import train_test_split

# Fixed seed so the participant split (and therefore every reported metric)
# is reproducible across kernel restarts and machines.
SPLIT_SEED = 42

# The split is done on the participant list (not on individual samples), so
# each participant ends up in exactly one of train / validation / test.
# 20% of participants go to test, then 25% of the remaining 80% go to
# validation -> roughly 60% / 20% / 20%.
# Sorting first makes the seeded split independent of file-discovery order.
train_participants, test_participants = train_test_split(
    sorted(participants), test_size=0.2, random_state=SPLIT_SEED
)
train_participants, val_participants = train_test_split(
    train_participants, test_size=0.25, random_state=SPLIT_SEED
)

In [None]:
# The model combines a module that tracks metrics with a model defined in src/models/*.
model = prepare(RnnModule)

# The datamodule combines a data module capable of loading multiple participants
# with a dataset defined in src/datasets/*.
# Participant lists come from the split cell above: `train_participants`,
# `val_participants` and `test_participants` (the previous name
# `validation_participants` was never defined and raised a NameError).
datamodule = MultiParticipantDataModule(
    root_dir,
    train_participants, val_participants, test_participants,
    batch_size=64,
    dataset=WindowedDataset,
    standardize=True
)

In [None]:
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping

# Stop training once the validation loss has not improved for 15 consecutive
# validation checks (EarlyStopping minimises the monitored value by default).
early_stopping = EarlyStopping(monitor="val_loss", patience=15)

# Keep only the single checkpoint with the highest validation binary accuracy,
# plus a "last" checkpoint of the most recent state.
# NOTE(review): checkpoint selection and early stopping monitor different
# metrics; `trainer.test(ckpt_path="best")` picks the accuracy-best checkpoint.
checkpoint = ModelCheckpoint(
    monitor="val_BinaryAccuracy",
    mode="max",
    save_top_k=1,
    save_last=True,
)

callbacks = [early_stopping, checkpoint]

In [None]:
import lightning as L

# Lightning's default output directory for this run, one folder per model class.
# NOTE(review): with a logger attached, ModelCheckpoint may derive its save
# directory from the logger rather than from `default_root_dir`; confirm
# where checkpoints actually land.
checkpoint_root = f"./checkpoints/{type(model).__name__}"

# Runs are reported to the "stress-in-action" Weights & Biases project.
wandb_logger = L.pytorch.loggers.WandbLogger(project="stress-in-action")

trainer = L.Trainer(
    default_root_dir=checkpoint_root,
    logger=wandb_logger,
    callbacks=callbacks,
    max_epochs=100,
    # Let Lightning choose the hardware, device count and distribution strategy.
    accelerator="auto",
    devices="auto",
    strategy="auto",
    # Report a summary of time spent in each training action.
    profiler="simple",
)

In [None]:
# Train the model on the train/validation participants.
# (An unused `Tuner` instance was removed; create one here and call e.g.
# `lr_find` / `scale_batch_size` before `fit` if tuning is wanted.)
trainer.fit(
    model=model,
    datamodule=datamodule
)

# `ckpt_path="best"` reloads the checkpoint that ModelCheckpoint ranked best
# during `fit`. The datamodule must be passed again: without it Lightning
# looks for the test dataloader on the LightningModule itself
# (`model.test_dataloader`), not on the datamodule used during `fit`.
trainer.test(
    datamodule=datamodule,
    ckpt_path="best"
)