# Generate signal encoding/embeddings

This notebook generates the sequence embeddings used in the final thesis experiment, which fuses autoencoder embeddings with XGBoost (XGB) decision trees.


In [1]:
import warnings
from itertools import takewhile
from glob import glob

import numpy as np
import zarr
from torch.utils.data import DataLoader

from utils import ElapsedTimer
from datasets.SequenceZarr import SequenceZarr
from sequence_autoencoder import SequenceAutoEncoder


In [6]:
# Open the on-disk zarr hierarchy through an explicit DirectoryStore.
# (A read-only alternative would be zarr.open_group("data/ecgs.zarr", mode="r"),
# but this notebook needs write access to store the embeddings.)
store = zarr.DirectoryStore("data/ecgs.zarr")
root = zarr.group(store=store)

# Group that receives one embedding array per checkpoint version;
# require_group creates it on first use and reuses it afterwards.
seq_embeddings = root.require_group("seq_embeddings")
print(root.tree())

# Dataset plus a loader that walks it in order (shuffle=False), so the
# concatenated batches presumably line up row-for-row with the dataset
# records -- TODO confirm against SequenceZarr's indexing.
ds = SequenceZarr(sequence_length=20)
dl = DataLoader(
    dataset=ds,
    collate_fn=SequenceZarr.collate_fn,
    batch_size=128,
    num_workers=1,
    pin_memory=True,
    shuffle=False,
)

/
 ├── beats
 │   ├── r_peak_idxs (43099,) object
 │   ├── valid_r_peak_idxs (43099,) object
 │   ├── window_size_400 (43099,) object
 │   ├── window_size_400_normalized (43099,) object
 │   ├── window_size_400_normalized_flattened (801266, 400, 12) float32
 │   ├── window_size_400_outlier (43099,) int32
 │   └── window_size_400_shape (43099, 3) int32
 ├── cleaned
 │   └── p_signal (43099,) object
 ├── meta
 │   └── record_idx_to_window_400_range (1,) object
 ├── raw
 │   ├── dx (43099,) object
 │   ├── meta (43099, 3) int32
 │   ├── p_signal (43099,) object
 │   └── p_signal_shape (43099, 2) int32
 └── seq_embeddings
     ├── version_0 (43099, 768) float64
     ├── version_1 (43099, 768) float64
     ├── version_10 (43099, 768) float64
     ├── version_11 (43099, 768) float64
     ├── version_12 (43099, 768) float64
     ├── version_13 (43099, 768) float64
     ├── version_14 (43099, 768) float64
     ├── version_15 (43099, 768) float64
     ├── version_16 (43099, 768) float64
     ├─



In [3]:
# Discover the trained checkpoints for both autoencoder stages.
torch_beat_checkpoints = glob("log_beat_autoencoder/*/checkpoints/*.ckpt")
torch_seq_checkpoints = glob("log_sequence_autoencoder/*/*.ckpt")

# The two lists are paired purely by sorted position, and zip() silently
# stops at the shorter one. Surface a count mismatch instead of quietly
# dropping (or misaligning) checkpoints.
if len(torch_beat_checkpoints) != len(torch_seq_checkpoints):
    warnings.warn(
        f"found {len(torch_beat_checkpoints)} beat checkpoints but "
        f"{len(torch_seq_checkpoints)} sequence checkpoints; "
        "pairs are formed by sorted order and the surplus is ignored"
    )

# (beat_checkpoint_path, seq_checkpoint_path) tuples, in lexicographic order
# (so version_10 sorts before version_2).
beat_seq_checkpoints = list(zip(sorted(torch_beat_checkpoints), sorted(torch_seq_checkpoints)))

In [4]:
def generate_seq_embeddings(beat_seq_checkpoint):
    """Compute sequence embeddings for every batch of the module-level ``dl``.

    Args:
        beat_seq_checkpoint: ``(beat_checkpoint, seq_checkpoint)`` pair of
            checkpoint paths. Only ``seq_checkpoint`` is loaded here; the
            beat checkpoint path is accepted for interface compatibility.

    Returns:
        A numpy array formed by concatenating the per-batch embeddings along
        axis 0, in the order the loader yields them.

    Raises:
        ValueError: from ``np.concatenate`` if the loader yields no batches.

    Notes:
        Requires a CUDA device (the model and inputs are moved with
        ``.cuda()``). Prints a progress counter every 7 batches and the
        total number of batches at the end.
    """
    # Local import: the notebook's import cell only pulls in torch.utils.data.
    import torch

    _beat_checkpoint, seq_checkpoint = beat_seq_checkpoint
    model = SequenceAutoEncoder.load_from_checkpoint(seq_checkpoint)
    model.cuda()
    # Inference only: switch off training-mode behaviour (dropout, batch-norm
    # statistics updates) so the embeddings are deterministic.
    model.eval()

    embeddings = []
    # no_grad avoids building autograd graphs that are never used.
    with torch.no_grad():
        # A plain for-loop ends cleanly when the loader is exhausted, without
        # an except clause that could also swallow a StopIteration raised
        # from inside the forward pass.
        for counter, batch in enumerate(dl, start=1):
            beat_windows, seq_lens, _dxs, _str_abbrv_dxs, _str_code_dxs = batch
            beat_windows = [bw.cuda() for bw in beat_windows]
            _pred_classes, embedding, _x_source, _x_hat = model(beat_windows, seq_lens)
            embeddings.append(embedding.detach().cpu().numpy())
            if counter % 7 == 0:
                # Carriage return keeps the progress counter on one line.
                print(counter, end="\r")
    print(len(embeddings))
    return np.concatenate(embeddings, axis=0)


In [5]:
# For every (beat, sequence) checkpoint pair: compute the embeddings and store
# them in the "seq_embeddings" zarr group under the checkpoint's version name.
for beat_seq_checkpoint in beat_seq_checkpoints:
    with ElapsedTimer() as t:
        # Second path component of the beat checkpoint, e.g. "version_0".
        version_str = beat_seq_checkpoint[0].split("/")[1]
        print(version_str)

        # Compute first, so a failed run never leaves an uninitialised array
        # behind in the store.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            embeddings = generate_seq_embeddings(beat_seq_checkpoint)

        expected_shape = (len(ds), 768)
        if embeddings.shape != expected_shape:
            raise ValueError(
                f"{version_str}: embeddings have shape {embeddings.shape}, "
                f"expected {expected_shape}"
            )

        # overwrite=True makes the cell re-runnable: without it, creating an
        # array whose name already exists in the group raises an error.
        embd_container = seq_embeddings.empty(
            version_str,
            shape=expected_shape,
            chunks=(1, 768),
            overwrite=True
        )
        embd_container[:] = embeddings
    print(f"took {t.duration:.2f} seconds")

version_0
337
took 85.97 seconds
version_1
337
took 82.47 seconds
version_10
337
took 80.63 seconds
version_11
337
took 80.83 seconds
version_12
337
took 81.73 seconds
version_13
337
took 80.99 seconds
version_14
337
took 81.82 seconds
version_15
337
took 82.61 seconds
version_16
337
took 81.33 seconds
version_17
337
took 80.20 seconds
version_18
337
took 81.81 seconds
version_19
337
took 80.64 seconds
version_2
337
took 81.04 seconds
version_3
337
took 81.62 seconds
version_4
337
took 82.01 seconds
version_5
337
took 82.02 seconds
version_6
337
took 82.16 seconds
version_7
337
took 82.83 seconds
version_8
337
took 82.62 seconds
version_9
337
took 81.74 seconds
