In [1]:
import time

import numpy as np
import pywt
import h5py
import torch

In [2]:
# Load the first two ELM events from the HDF5 archive. Each event group holds a
# 'signals' dataset, read here as float32 and then transposed (.T); the later
# cells slide windows over, and run the CWT along, axis 0 of the result.
# NOTE(review): assumes the file contains at least two event groups.
path = 'unlabeled-elm-events/elm-events.hdf5'
with h5py.File(path, 'r') as hf:
    ids = list(hf.keys())
    first_elm_id, second_elm_id = ids[0], ids[1]
    first_elm, second_elm = (
        np.array(hf[elm_id]['signals'], dtype=np.float32)
        for elm_id in (first_elm_id, second_elm_id)
    )
first_elm, second_elm = first_elm.T, second_elm.T

In [3]:
# CWT scales: `num` log-spaced widths from 1 up to max_scale, rounded to ints.
# With max_scale = 512 (a power of two) this is the dyadic ladder 1, 2, 4, ..., 512.
max_scale = 512
num = 1 + int(np.log2(max_scale))
widths = (
    np.geomspace(1, max_scale, num=num, endpoint=True)
    .round()
    .astype(int)
)

In [None]:
# NOTE(review): disabled cell. It runs the CWT over each whole event in a single
# call, instead of per sliding window as the timing cells below do. No visible
# cell reads first_elm_cwt / second_elm_cwt; either delete this cell or re-enable
# it with a markdown note explaining what it is compared against.
# first_elm_cwt, _ = pywt.cwt(first_elm, scales=widths, wavelet="morl", axis=0)
# first_elm_cwt = np.transpose(first_elm_cwt, (1, 0, 2))
#
# second_elm_cwt, _ = pywt.cwt(second_elm, scales=widths, wavelet="morl", axis=0)
# second_elm_cwt = np.transpose(second_elm_cwt, (1, 0, 2))

In [4]:
# float32 torch versions of both events, used as the input to the torch timing cell.
first_elm_tensor, second_elm_tensor = (
    torch.tensor(elm, dtype=torch.float32) for elm in (first_elm, second_elm)
)

In [7]:
# Benchmark: CWT of every length-`max_scale` sliding window of the first event,
# passing NumPy slices straight to pywt. `total_time_taken_np` is the MEAN
# seconds per window (the name is kept in case later cells read it).
# NOTE(review): single-pass wall-clock timings; cell order and warm-up effects can
# matter at this scale. Consider %timeit for a more robust comparison.
assert len(first_elm) >= max_scale, "event is shorter than one CWT window"
times_np = []
for i in range(len(first_elm) - max_scale + 1):
    # perf_counter is monotonic and high-resolution; time.time() can jump.
    start = time.perf_counter()
    sws = first_elm[i: i+max_scale]
    sws_cwt_np, _ = pywt.cwt(sws, scales=widths, wavelet="morl", axis=0)
    # Swap the first two axes of the CWT output, presumably
    # (scale, window, channel) -> (window, scale, channel). Confirm against pywt docs.
    sws_cwt_np = np.transpose(sws_cwt_np, (1, 0, 2))
    times_np.append(time.perf_counter() - start)
total_time_taken_np = np.mean(times_np)
print(f'Mean time per window with numpy array: {total_time_taken_np} s')

Total time taken with numpy array: 0.04691557268607057


In [8]:
# Benchmark: the same windowed CWT, but each window is sliced from the torch
# tensor and converted to NumPy for pywt. The result is then converted back to a
# float32 tensor. The CWT itself still runs in pywt/NumPy, so this measures the
# tensor<->ndarray round-trip on top of it. `total_time_taken_pt` is the MEAN
# seconds per window (the name is kept in case later cells read it).
assert len(first_elm_tensor) >= max_scale, "event is shorter than one CWT window"
times_pt = []
for i in range(len(first_elm) - max_scale + 1):
    # perf_counter is monotonic and high-resolution; time.time() can jump.
    start = time.perf_counter()
    sws = first_elm_tensor[i:i+max_scale]
    sws = sws.numpy()
    sws_cwt_pt, _ = pywt.cwt(sws, scales=widths, wavelet="morl", axis=0)
    # Swap the first two axes, matching the NumPy benchmark cell.
    sws_cwt_pt = np.transpose(sws_cwt_pt, (1, 0, 2))
    sws_cwt_pt = torch.as_tensor(sws_cwt_pt, dtype=torch.float32)
    times_pt.append(time.perf_counter() - start)
total_time_taken_pt = np.mean(times_pt)
print(f'Mean time per window with torch tensor: {total_time_taken_pt} s')

Total time taken with torch tensor: 0.04509874770985038


In [16]:
# Scratch check: multilevel 'db4' DWT of a random 5-D batch along axis 2
# (length 512). pywt.wavedec returns a list of level + 1 coefficient arrays.
# NOTE(review): the (16, 1, 512, 8, 8) layout presumably mirrors a
# (batch, channel, time, 8, 8) model input. Confirm.
rng = np.random.default_rng(0)  # seeded so Restart & Run All reproduces x
x = rng.random((16, 1, 512, 8, 8))
x_dwt = pywt.wavedec(x, 'db4', level=6, axis=2)
assert len(x_dwt) == 6 + 1