In [1]:
import warnings
warnings.filterwarnings("ignore")

import importlib
from pathlib import Path
from argparse import ArgumentParser
from glob import glob 

import torch
from torch.utils.data import DataLoader
from lightning.pytorch.callbacks import ModelCheckpoint

import lightning as L
from lightning import LightningModule

from datasets import Dataset, load_dataset

import torchmetrics
from sklearn.model_selection import train_test_split
from typing import Callable

import wandb

# Allow PyTorch to use faster, reduced-precision internal computation for float32
# matrix multiplications ('medium'); see torch.set_float32_matmul_precision docs.
torch.set_float32_matmul_precision('medium')

def prepare_model(
    model: type[LightningModule],
    data: str,
    batch_size: int,
    learning_rate: float,
    num_workers: int,
    ignore_torch_format: bool,
    train_participants: list[str],
    val_participants: list[str],
    test_participants: list[str],
    dataset: Callable[..., torch.utils.data.Dataset] = None,
    dataset_kwargs: dict = None,
    dataset_preprocessor: Callable = None,
) -> LightningModule:
    """Wrap a LightningModule class with data loading, binary metrics and a BCE training loop.

    Args:
        model: LightningModule *class* to subclass. Its ``forward`` must return
            probabilities in [0, 1], because the loss is ``binary_cross_entropy``.
        data: Path/name passed to ``datasets.load_dataset`` (with ``trust_remote_code=True``);
            the participant lists below are forwarded to it.
        batch_size: DataLoader batch size (stored via ``save_hyperparameters``).
        learning_rate: Adam learning rate (stored via ``save_hyperparameters``).
        num_workers: DataLoader workers; also caps the loader's ``num_proc``.
        ignore_torch_format: If False, the loaded data is switched to torch format.
        train_participants / val_participants / test_participants: Forwarded to the loader.
        dataset: Optional wrapper called as ``dataset(split, **dataset_kwargs)`` per split.
        dataset_kwargs: Extra keyword arguments for ``dataset`` (default: none).
        dataset_preprocessor: Optional callable applied to the loaded data before formatting.

    Returns:
        An instance of the generated ``Model`` subclass.
    """

    class Model(model):
        def __init__(self, batch_size: int, learning_rate: float, num_workers: int):
            # The parent nn.Module must be initialised before sub-modules (the metrics
            # below) are assigned as attributes.
            super().__init__()
            self.save_hyperparameters()

            # Created once here (not in setup) so they are registered sub-modules and are
            # not recreated every time a new stage starts.
            self.train_accuracy = torchmetrics.classification.BinaryAccuracy()
            self.train_f1score = torchmetrics.classification.BinaryF1Score()
            self.train_precision = torchmetrics.classification.BinaryPrecision()

            self.validation_accuracy = torchmetrics.classification.BinaryAccuracy()
            self.validation_f1score = torchmetrics.classification.BinaryF1Score()
            self.validation_precision = torchmetrics.classification.BinaryPrecision()

            self.test_accuracy = torchmetrics.classification.BinaryAccuracy()
            self.test_f1score = torchmetrics.classification.BinaryF1Score()
            self.test_precision = torchmetrics.classification.BinaryPrecision()

            # Populated lazily by prepare_data()/setup().
            self.data = None

        def _load_data(self):
            """Load the dataset, apply the optional preprocessor and torch formatting."""
            loaded = load_dataset(
                data,
                trust_remote_code=True,
                train_participants=train_participants,
                val_participants=val_participants,
                test_participants=test_participants,
                # No more processes than participants, and at least one.
                num_proc=max(1, min(self.hparams.num_workers, len(train_participants))),
            )
            if dataset_preprocessor is not None:
                loaded = dataset_preprocessor(loaded)
            if not ignore_torch_format:
                loaded = loaded.with_format("torch")
            self.data = loaded

        def prepare_data(self):
            if self.data is None:
                self._load_data()

        def _build_dataset(self, *split_names):
            """Return the first existing split among ``split_names``, wrapped by ``dataset``."""
            for name in split_names:
                if name in self.data:
                    split = self.data[name]
                    break
            else:
                raise KeyError(
                    f"None of the splits {split_names} exist; available: {list(self.data.keys())}"
                )
            if dataset is None:
                return split
            return dataset(split, **(dataset_kwargs or {}))

        def setup(self, stage):
            # Lightning runs prepare_data() on one process per node only, so processes that
            # did not run it must load the data here.
            if self.data is None:
                self._load_data()

            # TrainerFn enum -> plain "fit" / "validate" / "test" / "predict".
            stage_name = getattr(stage, "value", stage)
            # NOTE(review): the validation split name is assumed to follow the stage name
            # ("validate"); common alternatives are tried too — confirm against the loader.
            val_splits = ("validate", "validation", "val")
            if stage_name == "fit":
                self.train_dataset = self._build_dataset("fit")
                # Validation during fit must use its own split, not the training data.
                self.val_dataset = self._build_dataset(*val_splits)
                self.dataset = self.train_dataset
            elif stage_name == "validate":
                self.val_dataset = self._build_dataset(*val_splits)
                self.dataset = self.val_dataset
            elif stage_name == "test":
                self.test_dataset = self._build_dataset("test")
                self.dataset = self.test_dataset
            else:
                self.dataset = self._build_dataset(stage_name)

        def training_step(self, batch, batch_idx):
            return self._shared_step(
                batch, batch_idx, "",
                self.train_accuracy, self.train_precision, self.train_f1score,
            )

        def validation_step(self, batch, batch_idx):
            return self._shared_step(
                batch, batch_idx, "val_",
                self.validation_accuracy, self.validation_precision, self.validation_f1score,
            )

        def test_step(self, batch, batch_idx):
            return self._shared_step(
                batch, batch_idx, "test_",
                self.test_accuracy, self.test_precision, self.test_f1score,
            )

        def _shared_step(self, batch, batch_idx, prefix, accuracy, precision, f1score):
            """Compute BCE loss and binary metrics for one batch; log them under ``prefix``."""
            y, y_hat = self._step(batch, batch_idx)

            # BCE needs float targets with the same shape as the predictions; the metrics
            # keep the original (integer) targets.
            step_loss = torch.nn.functional.binary_cross_entropy(y_hat, y.to(y_hat.dtype))

            # Calling a metric updates its epoch-level state and returns this batch's value.
            batch_metrics = {
                f"{prefix}accuracy": accuracy(y_hat, y),
                f"{prefix}precision": precision(y_hat, y),
                f"{prefix}loss": step_loss.detach(),
                f"{prefix}f1": f1score(y_hat, y),
            }
            if wandb.run is not None:
                wandb.log(batch_metrics)

            # Logging the metric objects lets Lightning compute them over the whole epoch
            # and reset them afterwards, instead of accumulating across epochs.
            self.log(f"{prefix}loss", step_loss.detach())
            self.log_dict(
                {
                    f"{prefix}accuracy": accuracy,
                    f"{prefix}precision": precision,
                    f"{prefix}f1": f1score,
                },
                on_step=False,
                on_epoch=True,
            )
            return step_loss

        def _step(self, batch, batch_idx):
            """Forward pass; return flattened ``(targets, predictions)`` of matching shape."""
            x, y = batch
            y_hat = self(x)
            # The model may emit (batch, 1); flatten both sides so loss and metrics see 1-D
            # tensors. reshape (not squeeze) keeps a batch of size 1 one-dimensional.
            return y.reshape(-1), y_hat.reshape(-1)

        def configure_optimizers(self) -> L.pytorch.utilities.types.OptimizerLRScheduler:
            return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

        def _dataloader(self, ds, shuffle=False):
            """Build a DataLoader with the shared hyperparameters."""
            return DataLoader(
                ds,
                batch_size=self.hparams.batch_size,
                # Shuffling is only allowed for map-style datasets.
                shuffle=shuffle and not isinstance(ds, torch.utils.data.IterableDataset),
                num_workers=self.hparams.num_workers,
                pin_memory=True,
            )

        def train_dataloader(self):
            return self._dataloader(self.train_dataset, shuffle=True)

        def val_dataloader(self):
            return self._dataloader(self.val_dataset)

        def test_dataloader(self):
            return self._dataloader(self.test_dataset)

    return Model(
        batch_size=batch_size,
        learning_rate=learning_rate,
        num_workers=num_workers,
    )

def train(name: str, model: LightningModule, epochs: int, callbacks=None):
    """Fit ``model`` for up to ``epochs`` epochs and return the Trainer used.

    Args:
        name: Run name; the Trainer's ``default_root_dir`` is ``./checkpoints/{name}``.
        model: LightningModule instance to fit (it provides its own dataloaders).
        epochs: Maximum number of training epochs.
        callbacks: Optional list of Lightning callbacks, e.g.
            ``[ModelCheckpoint(save_top_k=1, monitor="val_accuracy", mode="max", save_last=True)]``.
            ``None`` (the default) leaves Lightning's default callbacks unchanged.

    Returns:
        The ``L.Trainer`` after ``fit``, so callers can run e.g. ``trainer.test(model)``.
    """
    trainer = L.Trainer(
        max_epochs=epochs,
        callbacks=callbacks,
        accelerator="auto",
        devices="auto",
        strategy="auto",
        profiler="simple",
        default_root_dir=f"./checkpoints/{name}",
    )

    trainer.fit(model=model)

    return trainer

In [2]:
# Fully-qualified module paths for the model and dataset implementations used below.
# Alternative pairing: 'sia.models.wickstrom_2020' with 'sia.datasets.wickstrom_2020'.
model, dataset = 'sia.models.time_series', 'sia.datasets.stepping_dataset'

# Directory with one CSV per participant; also passed to `load_dataset` below.
data_dir = './data/ecg_model'

In [3]:
# Participant IDs are the CSV file stems. glob() order depends on the filesystem, so sort
# it: otherwise the `[:20]` subset and the seeded split are not reproducible across machines.
participants = sorted(Path(path).stem for path in glob(f'{data_dir}/*.csv'))
# Use the first 20 participants; hold out 20% of them for testing.
train_participants, test_participants = train_test_split(participants[:20], test_size=0.2, random_state=42)

In [4]:
# Last dotted component of the model path; used as the run / checkpoint directory name.
model_name = model.rsplit('.', 1)[-1]
# Import the model module first, then the dataset module.
model_module, dataset_module = map(importlib.import_module, (model, dataset))

In [5]:
sampling_rate = 1_000  # samples per second — presumably the ECG recording rate in Hz; TODO confirm

In [6]:
import optuna
import pandas as pd
from tabulate import tabulate

In [7]:
# Exploratory load of a small participant subset (first 10 train / first 10 test participants),
# used only to inspect one sample in the following cells; prepare_model() loads its own data.
# NOTE(review): val_participants is not passed here — presumably the loader's default applies; confirm.
d = load_dataset(
    data_dir,
    trust_remote_code=True,
    train_participants=train_participants[:10], 
    test_participants=test_participants[:10],
)

In [8]:
# Return torch tensors when rows are accessed.
d = d.with_format(type="torch")

In [9]:
# Sanity check: print the first (signal window, label) pair from the 'fit' split.
# Window length is 60 * sampling_rate samples — presumably 60 s; confirm the unit the
# Dataset's second argument expects.
for x, y in dataset_module.Dataset(d['fit'], 60 * sampling_rate):
    print(x, y)
    break

tensor([-1.9722e-05, -1.9795e-05, -1.9857e-05,  ..., -6.5197e-05,
        -6.4483e-05, -6.3775e-05]) tensor(0)


In [10]:
def encode(baseline = 0, mental_stress = -1, high_physical_activity = -1, moderate_physical_activity = -1, low_physical_activity = -1):
    """Build a batched label encoder mapping activity names to class ids.

    Each keyword gives the class id assigned to one activity group. -1 is the
    "drop" value that ``clean`` filters out; activity names not in any group
    also map to -1.

    Returns:
        ``inner(labels) -> {'label': [int, ...]}``, suitable for
        ``Dataset.map(..., batched=True, input_columns=['label'])``.
    """
    def inner(labels):
        # (activity names, class id) in priority order; the first group containing
        # the label determines its code.
        groups = (
            (['Sitting', 'Recov1', 'Recov2', 'Recov3', 'Recov4', 'Recov5', 'Recov6'], baseline),
            (['TA', 'SSST_Sing_countdown', 'Pasat', 'Raven', 'TA_repeat', 'Pasat_repeat'], mental_stress),
            (['Treadmill1', 'Treadmill2', 'Treadmill3', 'Treadmill4', 'Walking_fast_pace', 'Cycling', 'stairs_up_and_down'], high_physical_activity),
            (['Walking_own_pace', 'Dishes', 'Vacuum'], moderate_physical_activity),
            (['Standing', 'Lying_supine', 'Recov_standing'], low_physical_activity),
        )

        def code_for(label):
            return next((code for names, code in groups if label in names), -1)

        encoded = [code_for(label) for label in labels]
        return {'label': encoded}

    return inner

def clean(dataset, mapping=None, num_proc=4, batch_size=2048):
    """Re-encode activity labels with ``encode(**mapping)`` and drop rows mapped to -1.

    Args:
        dataset: A Hugging Face ``Dataset``/``DatasetDict`` with a ``label`` column.
        mapping: Keyword overrides for ``encode`` (e.g. ``{'mental_stress': 1}``);
            ``None`` means encode's defaults.
        num_proc: Number of worker processes for the map and filter passes.
        batch_size: Rows per batch for the map and filter passes.

    Returns:
        The dataset with integer labels, with every ``-1`` (excluded/unknown) row removed.
    """
    print("--- Cleaning ---")
    encoded = dataset.map(
        # `mapping or {}` avoids a shared mutable default argument.
        encode(**(mapping or {})),
        batched=True,
        batch_size=batch_size,
        input_columns=['label'],
        num_proc=num_proc,
    )
    print("--- Filtering ---")
    # Batched filter: the predicate receives a list of labels and returns a list of bools,
    # which avoids one Python call per row.
    return encoded.filter(
        lambda labels: [label != -1 for label in labels],
        batched=True,
        batch_size=batch_size,
        input_columns=['label'],
        num_proc=num_proc,
    )

In [11]:
# Carve a validation set (20%, seeded) out of the training participants.
# NOTE(review): this rebinds `train_participants` from itself, so re-running this cell without
# first re-running the participant-split cell shrinks the training set again.
# TODO: perform the train/val/test splits together in one cell.
train_participants, val_participants = train_test_split(train_participants, test_size=0.2, random_state=42)

In [12]:
# Build the Lightning model from the configured model/dataset modules, train it, then
# evaluate it on the held-out test participants.
# NOTE(review): this rebinds `model` (previously the module-path string), so the cell that
# computes `model_name` from `model` cannot be re-run after this one.
model = prepare_model(
    model=model_module.Model, # assuming all models are named Model.
    data=data_dir,
    dataset=dataset_module.Dataset,
    batch_size=100,
    learning_rate=0.01,
    num_workers=8,
    ignore_torch_format=False,
    train_participants=train_participants,
    val_participants=val_participants,
    test_participants=test_participants,
    dataset_kwargs={
        # 60 * sampling_rate samples per window (presumably 60 s; see sampling_rate).
        'window': 60 * sampling_rate
    },
    # Binary task: mental-stress activities -> 1, baseline -> 0 (encode's default),
    # every other activity -> -1 and is filtered out by clean().
    dataset_preprocessor=lambda data: clean(data, { 'mental_stress': 1})
)
trainer = train(
    model_name,
    model=model,
    epochs=11
)

trainer.test(model)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


--- Cleaning ---
--- Filtering ---

SETUP TrainerFn.FITTING


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name                 | Type            | Params
----------------------------------------------------------
0  | rnn                  | LSTM            | 14.4 M
1  | fc                   | Linear          | 61    
2  | dropout              | Dropout         | 0     
3  | batch_norm           | BatchNorm1d     | 120   
4  | train_accuracy       | BinaryAccuracy  | 0     
5  | train_f1score        | BinaryF1Score   | 0     
6  | train_precision      | BinaryPrecision | 0     
7  | validation_accuracy  | BinaryAccuracy  | 0     
8  | validation_f1score   | BinaryF1Score   | 0     
9  | validation_precision | BinaryPrecision | 0     
10 | test_accuracy        | BinaryAccuracy  | 0     
11 | test_f1score         | BinaryF1Score   | 0     
12 | test_precision       | BinaryPrecision | 0     
----------------------------------------------------------
14.4 M    Trainable params
0         Non-trainable params
14.4 M    Total params
57.660    Total 

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]