Skip to content

Commit

Permalink
add test for slice indexing for builtin datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
wenh06 committed Apr 1, 2024
1 parent 0e71981 commit 3d476c2
Show file tree
Hide file tree
Showing 14 changed files with 62 additions and 11 deletions.
3 changes: 2 additions & 1 deletion test/test_databases/test_afdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@

with pytest.warns(RuntimeWarning):
reader = AFDB(_CWD, verbose=3)
reader.download()
if len(reader) == 0:
reader.download()


class TestAFDB:
Expand Down
3 changes: 2 additions & 1 deletion test/test_databases/test_apnea_ecg.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@

with pytest.warns(RuntimeWarning):
reader = ApneaECG(_CWD)
reader.download()
if len(reader) == 0:
reader.download()


class TestApneaECG:
Expand Down
3 changes: 2 additions & 1 deletion test/test_databases/test_cinc2017.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@

with pytest.warns(RuntimeWarning):
reader = CINC2017(_CWD)
reader.download()
if len(reader) == 0:
reader.download()


class TestCINC2017:
Expand Down
5 changes: 5 additions & 0 deletions test/test_databases/test_cinc2020.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,11 @@ def test_getitem(self):
)
assert target.ndim == 1 and target.shape == (len(config.classes),)

# test slice indexing
data, target = ds[:2]
assert data.shape == (2, len(config.leads), config.input_len)
assert target.shape == (2, len(config.classes))

def test_load_one_record(self):
for rec in ds.records:
data, target = ds._load_one_record(rec)
Expand Down
5 changes: 5 additions & 0 deletions test/test_databases/test_cinc2021.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,11 @@ def test_getitem(self):
)
assert target.ndim == 1 and target.shape == (len(config.classes),)

# test slice indexing
data, target = ds[:2]
assert data.shape == (2, len(config.leads), config.input_len)
assert target.shape == (2, len(config.classes))

def test_load_one_record(self):
for rec in ds.records:
data, target = ds._load_one_record(rec)
Expand Down
12 changes: 11 additions & 1 deletion test/test_databases/test_cpsc2019.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@


reader = CPSC2019(_CWD)
reader.download()
if len(reader) == 0:
reader.download()


class TestCPSC2019:
Expand Down Expand Up @@ -144,5 +145,14 @@ def test_getitem(self):
assert data.ndim == 2 and data.shape == (1, config_1.input_len)
assert bin_mask.ndim == 2 and bin_mask.shape == (config_1.input_len, 1)

# test slice indexing
data, bin_mask = ds[:2]
assert data.ndim == 3 and data.shape == (2, 1, config.input_len)
assert bin_mask.ndim == 3 and bin_mask.shape == (
2,
config.input_len // config.reduction,
1,
)

def test_properties(self):
assert str(ds) == repr(ds)
3 changes: 2 additions & 1 deletion test/test_databases/test_cpsc2020.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@


reader = CPSC2020(_CWD)
reader.download()
if len(reader) == 0:
reader.download()


class TestCPSC2020:
Expand Down
6 changes: 6 additions & 0 deletions test/test_databases/test_cpsc2021.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,12 @@ def test_getitem(self):
assert data.shape == (config.n_leads, input_len)
assert qrs_mask.shape == (input_len, 1)

# test slice indexing
data, af_mask, weight_mask = ds[:2]
assert data.ndim == 3 and data.shape == (2, config.n_leads, input_len)
assert af_mask.ndim == 3 and af_mask.shape == (2, input_len, 1)
assert weight_mask.ndim == 3 and weight_mask.shape == (2, input_len, 1)

def test_properties(self):
assert ds.task == "main"
assert ds_1.task == "rr_lstm"
Expand Down
3 changes: 2 additions & 1 deletion test/test_databases/test_ltafdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@

with pytest.warns(RuntimeWarning):
reader = LTAFDB(_CWD)
reader.download()
if len(reader) == 0:
reader.download()


class TestLTAFDB:
Expand Down
11 changes: 10 additions & 1 deletion test/test_databases/test_ludb.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@


reader = LUDB(_CWD)
reader.download()
if len(reader) == 0:
reader.download()


class TestLUDB:
Expand Down Expand Up @@ -140,6 +141,14 @@ def test_getitem(self):
assert signals.shape == (config.n_leads, config.input_len)
assert labels.shape == (config.input_len, len(config.classes))

# test slice indexing
signals, labels = ds[:2]
assert signals.shape == (2, config.n_leads, config.input_len)
# NOTE that the (segmentation) labels have collapsed lead dimension
# so the shape is (n_samples, signal_len, n_classes)
# instead of (n_samples, n_leads, signal_len, n_classes)
assert labels.shape == (2, config.input_len, len(config.classes))

def test_properties(self):
signals_shape = ds.signals.shape # (n_samples, n_leads, signal_len)
labels_shape = ds.labels.shape # (n_samples, n_leads, signal_len, n_classes)
Expand Down
8 changes: 7 additions & 1 deletion test/test_databases/test_mitdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@


reader = MITDB(_CWD)
reader.download()
if len(reader) == 0:
reader.download()


class TestMITDB:
Expand Down Expand Up @@ -164,6 +165,11 @@ def test_getitem(self):

# `ds_rhythm` and `ds_af` have bugs now

# test slice indexing
data, ann = ds[:2]
assert data.shape == (2, config.n_leads, config[TASK].input_len)
assert ann.shape == (2, config[TASK].input_len, 1)

def test_load_seg_data(self):
seg = ds.all_segments[list(ds.all_segments)[0]][0]
data = ds._load_seg_data(seg)
Expand Down
3 changes: 2 additions & 1 deletion test/test_databases/test_ptb_xl.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@


reader = PTBXL(_CWD)
reader.download()
if len(reader) == 0:
reader.download()


class TestPTBXL:
Expand Down
3 changes: 2 additions & 1 deletion test/test_databases/test_qtdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@


reader = QTDB(_CWD)
reader.download()
if len(reader) == 0:
reader.download()


class TestQTDB:
Expand Down
5 changes: 4 additions & 1 deletion torch_ecg/databases/datasets/ludb/ludb_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ def __len__(self) -> int:
return len(self.records)

def __getitem__(self, index: Union[int, slice]) -> Tuple[np.ndarray, np.ndarray]:
if isinstance(index, slice):
return collate_fn([self[i] for i in range(*index.indices(len(self)))])
if self.config.use_single_lead:
rec_idx, lead_idx = divmod(index, len(self.leads))
else:
Expand All @@ -100,7 +102,8 @@ def __getitem__(self, index: Union[int, slice]) -> Tuple[np.ndarray, np.ndarray]
labels = labels[lead_idx, ...]
else:
# merge labels in all leads to one
# TODO: map via self.waveform_priority
# TODO: map via self.waveform_priority,
# or make it configurable whether to collapse the lead dimension
labels = np.max(labels, axis=0)
sampfrom = randint(self.config.start_from, signals.shape[1] - self.config.end_at - self.siglen)
sampto = sampfrom + self.siglen
Expand Down

0 comments on commit 3d476c2

Please sign in to comment.