In [1]:
import os
import sys
import copy
import time
import random
import warnings
from datetime import timedelta, datetime

import numpy as np
import polars as pl
import plotly.graph_objects as go

from tqdm.notebook import tqdm
from IPython.display import display
from ipywidgets.widgets import HBox

import torch
from torch.functional import F
from torch import nn, optim, cuda
from torch.utils.data import DataLoader, Dataset

from sklearn.model_selection import train_test_split

In [2]:
ROOT_PATH = './'
DRIVE_PATH = 'Colab/TimeSeries-TP'

# When on Colab, use Google Drive as the root path to persist and load data
if 'google.colab' in sys.modules:
    from google.colab import drive, output
    output.enable_custom_widget_manager()

    drive.mount('/content/drive')
    ROOT_PATH = os.path.join('/content/drive/My Drive/', DRIVE_PATH)
    os.makedirs(ROOT_PATH, exist_ok=True)
    os.chdir(ROOT_PATH)

In [3]:
RANDOM_SEED = 1984

BATCH_SIZE = 128

TOTAL_EPOCHS = 10

BETA_1 = 0.9
BETA_2 = 0.999
EPS = 1e-8
AMSGRAD = False
WEIGHT_DECAY = 0.01

WARMUP_RATIO = 0.05
LEARNING_RATE = 0.04


EVAL_K = 10


PYTORCH_DEVICE = 'cpu'

# Use NVIDIA GPU if available
if cuda.is_available():
    PYTORCH_DEVICE = 'cuda'

# Use Apple Metal backend if available
if torch.backends.mps.is_available():
    if not torch.backends.mps.is_built():
        print("Your device supports MPS but it is not installed. Checkout https://developer.apple.com/metal/pytorch/")
    else:
        PYTORCH_DEVICE = 'mps'


print (f"Using {PYTORCH_DEVICE} device for PyTorch")

Using cuda device for PyTorch


In [4]:
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
torch.mps.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed_all(RANDOM_SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [5]:
DATA_DIR = "./data/ltafdb-processed"
all_data = []
for root, _, files in os.walk(DATA_DIR):
    for file in tqdm(files, desc="Loading Records", unit="Records"):
        if file.endswith(".pqt.zstd"):
                all_data.append(pl.read_parquet(os.path.join(root, file)))

Loading Records:   0%|          | 0/84 [00:00<?, ?Records/s]

In [6]:
train_data, test_data = train_test_split(all_data[:10], test_size=0.2, random_state=RANDOM_SEED)
test_data, validation_data = train_test_split(test_data, test_size=0.5, random_state=RANDOM_SEED)

In [7]:
class WindowedDataset(Dataset):
    def __init__(
        self,
        dataframes: list[pl.DataFrame],
        window_size: int,
        stride: int,
    ):
        super().__init__()
        self.dataframe = dataframes
        self.window_size = window_size
        self.stride = stride

        self.window_count = [
            (len(df) - window_size) // stride
            for df in self.dataframe
        ]
        self.len = sum(self.window_count)

        self.signals = [df['signal'].explode().to_numpy() for df in self.dataframe]
        self.labels = [df['is_abnormal'].explode().to_numpy() for df in self.dataframe]
        self.starts = [df['start'].explode().to_numpy() for df in self.dataframe]
        self.ends = [df['end'].explode().to_numpy() for df in self.dataframe]


    def __len__(self):
        return self.len
    
    def __getitem__(self, idx):
        df_idx = 0
        while idx >= self.window_count[df_idx]:
            idx -= self.window_count[df_idx]
            df_idx += 1
        
        start = idx * self.stride
        end = start + self.window_size

        series = self.signals[df_idx][self.starts[df_idx][start]:self.ends[df_idx][end-1]]
        label = np.any(self.labels[df_idx][start:end])
        
        return series, label

In [8]:
train_dataset = WindowedDataset(train_data, window_size=20, stride=1)
validation_dataset = WindowedDataset(validation_data, window_size=20, stride=1)
test_dataset = WindowedDataset(test_data, window_size=20, stride=1)

In [9]:
print(f"Training dataset size (windows): {len(train_dataset)}")
print(f"Validation dataset size (windows): {len(validation_dataset)}")
print(f"Test dataset size (windows): {len(test_dataset)}")

Training dataset size (windows): 814556
Validation dataset size (windows): 109604
Test dataset size (windows): 114454


In [10]:
# pos_count = 0
# for _, label in tqdm(train_dataset):
#     if label:
#         pos_count += 1
# print(f"Positive ratio: {pos_count/len(train_dataset):2%}")

In [11]:
class ECGModel(nn.Module):
    def __init__(self, n_features=1, lstm_units=200, lstm_layers=1, lstm_dropout=0, dense_units=50, dense_dropout=0.1):
        super(ECGModel, self).__init__()
        self.n_features = n_features
        self.lstm = nn.LSTM(n_features, lstm_units, num_layers=lstm_layers, batch_first=True, bidirectional=True, dropout=lstm_dropout)
        self.pool = nn.AdaptiveMaxPool1d(1)
        self.fc1 = nn.Linear(2*lstm_units, dense_units)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dense_dropout)
        self.output = nn.Linear(dense_units, 1)

    def forward(self, seq, seq_len):
        packed_seq = nn.utils.rnn.pack_padded_sequence(seq, seq_len.cpu(), batch_first=True, enforce_sorted=False)
        seq, _ = self.lstm(packed_seq)
        seq, _ = nn.utils.rnn.pad_packed_sequence(seq, batch_first=True)
        seq = seq.permute(0, 2, 1)
        seq = self.pool(seq).squeeze(-1)
        seq = self.fc1(seq)
        seq = self.relu(seq)
        seq = self.dropout(seq)
        seq = self.output(seq)
        return seq

In [12]:
def collate_fn(data: list[tuple[np.ndarray, bool]]):
    seq_lengths = torch.tensor([len(d[0]) for d in data], dtype=torch.long)
    padded_seqs = torch.nn.utils.rnn.pad_sequence([torch.tensor(d[0], dtype=torch.float32) for d in data], batch_first=True, padding_value=0.0)

    return padded_seqs.unsqueeze(-1), seq_lengths, torch.tensor([d[1] for d in data], dtype=torch.float32)

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8, collate_fn=collate_fn)
validation_dataloader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8, collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8, collate_fn=collate_fn)

In [13]:
model = ECGModel()
model.to(PYTORCH_DEVICE)

criterion = nn.BCEWithLogitsLoss()
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, betas=(BETA_1, BETA_2), eps=EPS, weight_decay=WEIGHT_DECAY)
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=LEARNING_RATE, total_steps=TOTAL_EPOCHS * len(train_dataloader), 
    pct_start=WARMUP_RATIO, cycle_momentum=False, anneal_strategy='linear')

In [14]:
def train(
    model: ECGModel,
    optimizer: optim.Optimizer,
    scheduler: optim.lr_scheduler._LRScheduler | None,
    criterion: nn.modules.loss._Loss,
    train_loader: DataLoader,
    validation_loader: DataLoader,
    epochs: int,
    device: str,
):
    loss_history = []

    steps = 0
    for epoch in tqdm(range(epochs), desc="Epoch"):
        model.train()
        epoch_loss = 0

        for batch in tqdm(train_loader, desc="Batch", leave=False, total=len(train_loader)):
            model.zero_grad()


            inputs, seq_len, targets = batch

            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = model(inputs, seq_len).squeeze(1)
            loss = criterion(outputs, targets)

            loss.backward()
            optimizer.step()
            scheduler.step()

            epoch_loss += loss.item()
            steps += 1
        
        print(f"Epoch {epoch + 1} loss: {epoch_loss / steps}")
        loss_history.append(epoch_loss / steps)
    
    return loss_history

# def validate(
#     model: ECGModel,
#     criterion: nn.modules.loss._Loss,
#     validation_loader: DataLoader,
#     device: str,
# ):
#     model.eval()
#     with torch.no_grad():
#         loss = 0
#         for batch in validation_loader:
#             inputs, targets = batch
#             inputs = inputs.to(device)
#             targets = targets.to(device)

#             outputs = model(inputs)
#             loss += criterion(outputs, targets).item()
        
#         return loss / len(validation_loader)

In [15]:
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    loss_history = train(model, optimizer, scheduler, criterion, train_dataloader, validation_dataloader, TOTAL_EPOCHS, PYTORCH_DEVICE)

Epoch:   0%|          | 0/10 [00:00<?, ?it/s]

Batch:   0%|          | 0/6364 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [33]:
import gc

gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()