In [1]:
# from sia import Dataset

# dataset = Dataset("Stress-in-Action")

# dataset.attach("./data/ecg_raw/*.edf")

In [2]:
import torch
torch.cuda.is_available()

True

In [3]:
import mne
import numpy as np
import pandas as pd
from typing import List, Dict
from torch.utils.data import Dataset as TorchDataset, IterableDataset
from tabulate import tabulate
from IPython.display import (
    display, display_html, display_png, display_svg
)

In [4]:
import lightning as L

from joblib import Parallel, delayed

In [5]:
from pyarrow.parquet import ParquetDataset

In [7]:
from pyarrow.parquet import ParquetDataset
from sia.utils import get_file_paths, tqdm_joblib

pd.options.mode.chained_assignment = None
target_labels = ['TA', 'SSST_Sing_countdown', 'Pasat', 'Raven', 'TA_repeat', 'Pasat_repeat']

class Dataset(TorchDataset):#IterableDataset):
    def __init__(self, name: str, window: int = 1000):
        self.name = name
        self.window = window
        self.dataset = None

    def attach(self, arg: str):
        if isinstance(arg, str):
            files = get_file_paths(arg)
            self.dataset = ParquetDataset(files)
        elif hasattr(arg, "__len__"):
            files = arg
            self.dataset = ParquetDataset(files)
        else: 
            raise TypeError("The argument type is not supported")
        
        return self
    
    def __len__(self):
        return sum(p.count_rows() for p in self.dataset.fragments)

    def __getitem__(self, idx):
        i = idx
        print("Loop over pieces")
        for piece in ds.fragments:
            print("PIECE ITERATION")
            if i - piece.count_rows() < 0:
                print("FOUND")
                window = piece.take(list(range(idx, idx+self.window)), columns=['ECG_Clean', 'category'])
                break
            else:
                print("NEXT")
                i -= piece.count_rows()
        
        if window is None:
            raise IndexError("Index out of range")
        
        print("PARSE DATA")
        signal = window['ECG_Clean'].to_numpy()
        label = window['category'].to_numpy()
        label[np.isin(label, target_labels)] = 1
        label[~np.isin(label, target_labels)] = 0
        
        signal = torch.tensor(signal)
        label = torch.tensor(label.astype(int))

        print("FOUND")
        return signal, torch.tensor(1) if torch.mode(label, 0)[0] == 1 else torch.tensor(0)
        
    def __repr__(self):
        data = []

        data.append(["name", self.name])

        if len(self.dataset.fragments) > 0:
            data.append(["files", len(self.dataset.fragments)])
            
            data.append(["length", [f.count_rows() for f in self.dataset.fragments]])

        return tabulate(data, tablefmt="fancy_grid")

In [8]:
# from sia.utils import get_file_paths, tqdm_joblib

# pd.options.mode.chained_assignment = None
# target_labels = ['TA', 'SSST_Sing_countdown', 'Pasat', 'Raven', 'TA_repeat', 'Pasat_repeat']

# class Dataset(TorchDataset):#IterableDataset):
#     def __init__(self, name: str, window: int = 1000):
#         self.name = name
#         self.window = window

#         self.files = []

#     def attach(self, arg: str):
#         def count_lines(file: str) -> int:
#             n = len(pd.read_feather(file))
#             return (file, n)

#         if isinstance(arg, str):
#             files = get_file_paths(arg)
#             self.files = Parallel(n_jobs=4)(delayed(count_lines)(f) for f in files)
#         elif hasattr(arg, "__len__"):
#             files = arg
#             self.files = Parallel(n_jobs=4)(delayed(count_lines)(f) for f in files)
#         else: 
#             raise TypeError("The argument type is not supported")
        
#         return self
    
#     def __len__(self):
#         return sum([f[1] for f in self.files]) - self.window

#     def __iter__(self):
#         for f in self.files:
#             df = pd.read_csv(f[0])
#             for window in list(df.rolling(self.window)):
#                 signal = window['ECG_Clean']
#                 label = window['category']
#                 label[label.isin(target_labels)] = 1
#                 label[~label.isin(target_labels)] = 0
                
#                 signal = torch.tensor(signal)
#                 label = torch.tensor(label)

#                 yield signal, label

#     def __getitem__(self, idx):
#         i = idx
#         for f in self.files:
#             if i - f[1] < 0:
#                 break
#             else:
#                 i -= f[1]

#         row = pd.read_feather(f[0], columns=['ECG_Clean', 'category'])
#         row = row.iloc[i:i+self.window]
        
#         signal = row['ECG_Clean']
#         label = row['category']
#         label[label.isin(target_labels)] = 1
#         label[~label.isin(target_labels)] = 0

#         label = label.to_numpy(dtype=int)
        
#         signal = torch.tensor(signal.to_numpy())
#         label = torch.tensor(label)

#         return signal, torch.tensor(1) if torch.mode(label, 0)[0] == 1 else torch.tensor(0)
        
#     def __repr__(self):
#         data = []

#         data.append(["name", self.name])

#         if len(self.files) > 0:
#             data.append(["files", len(self.files)])
#             data.append(["file paths", [f[0] for f in self.files]])
#             data.append(["length", [f[1] for f in self.files]])

#         return tabulate(data, tablefmt="fancy_grid")

In [9]:
import glob
from sklearn.model_selection import train_test_split

In [10]:
participants = glob.glob("./data/parquet/*.parquet")
train_participants, test_participants = train_test_split(participants, test_size=0.2)

In [11]:
ds_train = Dataset("Stress-in-Action")
ds_train.attach(train_participants)
ds_train

╒════════╤═══════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════╕
│ name   │ Stress-in-Action                                                  

In [12]:
ds_test = Dataset("Stress-in-Action")
ds_test.attach(test_participants)
ds_test

╒════════╤═════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════╕
│ name   │ Stress-in-Action                                                                                                                                                                                                                            │
├────────┼─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ files  │ 26                                                                                                                                                                                                                                          │
├───

In [13]:
from torch.utils.data import DataLoader

In [14]:
train_dataloader = DataLoader(ds_train, batch_size=32, shuffle=False, drop_last=True, num_workers=4, pin_memory=True)
test_dataloader = DataLoader(ds_test, batch_size=32, shuffle=False, drop_last=True, pin_memory=True)

In [15]:
train_features, train_labels = next(iter(train_dataloader))
train_features, train_labels

In [None]:
import wandb

In [None]:
wandb.init(
    # set the wandb project where this run will be logged
    project="stress-in-action",
    
    # track hyperparameters and run metadata
    config={
        "learning_rate": 0.02,
        "architecture": "CNN",
        "dataset": "SiA",
        "epochs": 11,
    }
)

In [198]:
from lightning.pytorch.utilities.types import OptimizerLRScheduler
from torch import nn

class Test(L.LightningModule):
    def __init__(self):
        super().__init__()
        
        self.layers = nn.Sequential(
            nn.Linear(1000, 10).double(),
            nn.Tanh(),
            nn.Linear(10, 1).double(),
            nn.Softmax(1),
        )

        self.layers.cuda(0)

    def forward(self, x):
        embedding = self.layers(x)
        return embedding

    def configure_optimizers(self) -> OptimizerLRScheduler:
        optimizer = torch.optim.Adam(self.parameters(), lr=0.02)
        return optimizer
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.functional.cross_entropy(y_hat, y)
        wandb.log({"loss": loss})
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.functional.cross_entropy(y_hat, y)
        wandb.log({"val_loss": loss})
        return loss

In [199]:
model = Test()
trainer = L.Trainer(max_epochs=11, accelerator="gpu", devices=1)
trainer.fit(model, train_dataloader, test_dataloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name   | Type       | Params
--------------------------------------
0 | layers | Sequential | 10.0 K
--------------------------------------
10.0 K    Trainable params
0         Non-trainable params
10.0 K    Total params
0.040     Total estimated model params size (MB)


Epoch 0:   0%|          | 43/1514804 [13:37<7994:40:40,  0.05it/s, v_num=31]

In [None]:
wandb.finish()

0,1
loss,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_loss,▁▁

0,1
loss,0.0
val_loss,0.0
