Skip to content

fix(ckpt): correct safetensors path for decimal-suffix stems; return actual save path#44

Merged
kjs11 merged 5 commits into
masterfrom
fix/ckpt-safetensors-path
Jun 2, 2026
Merged

fix(ckpt): correct safetensors path for decimal-suffix stems; return actual save path#44
kjs11 merged 5 commits into
masterfrom
fix/ckpt-safetensors-path

Conversation

@wenh06
Copy link
Copy Markdown
Collaborator

@wenh06 wenh06 commented May 31, 2026

Summary

Two related bugs introduced when the checkpoint format switched from .pth.tar to .safetensors.


Bug A — CkptMixin.save() silently truncates filenames with decimal values

File: torch_ecg/utils/utils_nn.py

Trainer generates checkpoint folder names like …_metric_0.91. pathlib treats .91 as the file extension, so the old path.with_suffix(".safetensors") replaced .91 instead of appending, producing …metric_0.safetensors instead of the correct …metric_0.91.safetensors.

Fix: after determining use_safetensors=True, if path.suffix != ".safetensors" we use string concatenation instead of with_suffix:

path = Path(str(path) + ".safetensors")   # append, never replace

The now-redundant path.with_suffix(".safetensors") inside the single-file save_file call is also removed (path is already normalised at that point).

save() now returns Path (the actual file/directory written) instead of None.


Bug B — BaseTrainer checkpoint cleanup always fails

File: torch_ecg/components/trainer.py

saved_models stored the raw stem path (…metric_0.91), but the file on disk was …metric_0.91.safetensors. Every os.remove(model_to_remove) in the keep_checkpoint_max cleanup therefore raised FileNotFoundError silently.

Fix: save_checkpoint() forwards the Path returned by model.save(); the training loop stores that actual path:

actual_save_path = self.save_checkpoint(str(save_path))
self.saved_models.append(actual_save_path if actual_save_path is not None else save_path)

Directory-style checkpoints (non-single-file mode) are now cleaned up with shutil.rmtree instead of os.remove. shutil is promoted to a top-level import.


Changes

File What changed
torch_ecg/utils/utils_nn.py Path normalisation fix; save() returns Path
torch_ecg/components/trainer.py Use returned path in saved_models; handle dir cleanup; top-level shutil import
test/test_utils/test_utils_nn.py New test_ckpt_decimal_suffix_path covering all three save branches
CHANGELOG.rst Two Fixed entries under Unreleased

Tests

pytest test/test_utils/test_utils_nn.py -k "not test_mixin_classes"
# 9 passed

test_mixin_classes is excluded only because it requires a Dropbox network connection (pre-existing, unrelated).

wenh06 and others added 2 commits March 31, 2026 22:46
…actual save path

Bug A (utils_nn.py CkptMixin.save):
  Path("…metric_0.91").with_suffix(".safetensors") treated ".91" as the
  existing extension and replaced it, silently producing "…metric_0.safetensors".
  Fix: after determining use_safetensors=True, if path.suffix is not already
  ".safetensors" we append rather than replace, giving the correct
  "…metric_0.91.safetensors".  The redundant path.with_suffix(".safetensors")
  inside the single-file branch is removed (path is already normalised).
  save() now returns the final Path instead of None so callers know exactly
  where the file was written.

Bug B (components/trainer.py BaseTrainer):
  saved_models stored the raw stem path while the actual file on disk carried a
  ".safetensors" suffix, so every os.remove() in keep_checkpoint_max cleanup
  raised FileNotFoundError.  The trainer now stores the Path returned by
  save_checkpoint(), which in turn forwards the value returned by model.save().
  Directory-style (non-single-file) checkpoints are now removed with
  shutil.rmtree instead of os.remove.  shutil is promoted to a top-level import.

Also adds test_ckpt_decimal_suffix_path covering all three code paths:
  single-file safetensors, directory safetensors, and pth/torch.save fallback.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Copilot AI review requested due to automatic review settings May 31, 2026 15:16
sig = test_sig.clone()
sig = bp(sig)
with pytest.warns(RuntimeWarning, match="lowcut <= 0"):
sig = bp(sig)
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

The PR is described as a focused fix for two checkpoint-related bugs (decimal-suffix paths being truncated by pathlib.with_suffix, and BaseTrainer silently failing to clean up checkpoints because it stored the un-normalised path). In practice the diff is much broader: it also performs a sweeping migration from numbers.Real to int / float / Union[int, float] across ~40 files, adds input validation (with new warnings/errors) to bandpass_filter, and makes a few small ancillary changes (viewreshape in baseline_removal, Tuple[Union[type(None), int], ...]Tuple[Union[None, int], ...], etc.).

Changes:

  • Fix CkptMixin.save() decimal-suffix truncation; make save() and save_checkpoint() return the actual Path; clean up directory-style checkpoints with shutil.rmtree.
  • Replace numbers.Real annotations/isinstance checks with int/float/Union[int, float] throughout the codebase.
  • Add validation and warnings to bandpass_filter and update related tests.

Reviewed changes

Copilot reviewed 62 out of 62 changed files in this pull request and generated 7 comments.

Show a summary per file
File Description
torch_ecg/utils/utils_nn.py Append .safetensors instead of with_suffix; return saved Path; drop Real.
torch_ecg/components/trainer.py Use returned Path from save(); shutil.rmtree for dir checkpoints; promote shutil import.
torch_ecg/utils/utils_signal_t.py New input validation in bandpass_filter; viewreshape in baseline_removal; Realint/float.
torch_ecg/utils/utils_metrics.py, utils_interval.py, _preproc.py, _edr.py Realint/float; _getxy now returns a Python float via .item().
torch_ecg/models/_nets.py Realfloat/(float, dict) in several isinstance checks (regression: int no longer accepted).
torch_ecg/models/cnn/{xception,resnet,regnet,mobilenet}.py, models/loss.py, models/ecg_fcn.py Real(int, float, …) in isinstance checks and annotations.
torch_ecg/preprocessors/.py, _preprocessors/normalize.py, augmenters/.py Type-hint cleanup (Realint/float); minor dtype additions in augmenters.
torch_ecg/databases/**/*.py, components/{inputs,outputs,metrics,loggers}.py Doc/type updates from Real to int/float; Union[str, type(None)]Union[str, None].
test/test_utils/test_utils_nn.py New regression test test_ckpt_decimal_suffix_path covering the three save branches.
test/test_preprocessors.py, test/test_preprocessors_t.py, test/test_databases/test_shhs.py Tests updated to drop Real and to cover new bandpass_filter validation.
CHANGELOG.rst Two Fixed entries for the checkpoint bugs (the broader Real migration and bandpass validation are not noted).
benchmarks/**/*.py, torch_ecg/preprocessors/README.md Annotation updates from Real to int/float and minor import cleanups.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread torch_ecg/models/_nets.py Outdated
Comment thread torch_ecg/models/_nets.py Outdated
Comment thread torch_ecg/models/_nets.py Outdated
Comment thread CHANGELOG.rst
Comment on lines +75 to +85
- Fix ``CkptMixin.save()`` silently truncating checkpoint filenames that contain
decimal values (e.g. ``…metric_0.91``): ``pathlib.Path.with_suffix(".safetensors")``
treated ``.91`` as the existing suffix and replaced it, producing
``…metric_0.safetensors`` instead of the correct ``…metric_0.91.safetensors``.
The method now appends ``.safetensors`` for paths without a recognised
extension, and returns the final ``Path`` used so callers can track it.
- Fix ``BaseTrainer`` checkpoint cleanup permanently failing: ``saved_models``
stored the raw stem path while the actual file on disk had a ``.safetensors``
suffix, causing every ``os.remove()`` call to raise ``FileNotFoundError``.
The trainer now stores the path returned by ``save_checkpoint()``, and handles
both single-file (``os.remove``) and directory (``shutil.rmtree``) checkpoints.
Comment on lines +485 to +495
if effective_lowcut is not None:
if effective_lowcut <= 0:
warnings.warn(
"lowcut <= 0 in bandpass_filter; disabling high-pass side.",
RuntimeWarning,
)
effective_lowcut = None
elif effective_lowcut >= nyquist:
raise ValueError(
f"lowcut must be less than Nyquist frequency (fs/2={nyquist}), " f"got lowcut={effective_lowcut!r}"
)
Comment on lines +493 to +495
raise ValueError(
f"lowcut must be less than Nyquist frequency (fs/2={nyquist}), " f"got lowcut={effective_lowcut!r}"
)
Comment on lines +779 to +785
Returns
-------
Path, optional
The actual path the checkpoint was saved to (suffix may differ
from ``path`` after normalisation, e.g. ``.safetensors``).
Returns ``None`` when the model does not implement ``save()``.

Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: c274891267

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread torch_ecg/models/_nets.py Outdated
Comment on lines 885 to 888
if isinstance(dropouts, (float, dict)):
_dropouts = list(repeat(dropouts, self.__num_convs))
else:
_dropouts = list(dropouts) # type: ignore
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Accept integer dropout values

When callers pass dropouts=0 (a common way to disable dropout and previously accepted because numbers.Real includes int), this now falls into the else branch and executes list(0), raising TypeError during model construction. This also affects the same replacement in BranchedConv/SeqLin; keep int in the scalar check so existing integer dropout configs continue to work.

Useful? React with 👍 / 👎.

@codecov
Copy link
Copy Markdown

codecov Bot commented May 31, 2026

❌ 1 Tests Failed:

Tests completed Failed Passed Skipped
483 1 482 34
View the top 1 failed test(s) by shortest run time
test/test_databases/test_cpsc2020.py::TestCPSC2020::test_locate_premature_beats
Stack Traces | 0.001s run time
self = <test_cpsc2020.TestCPSC2020 object at 0x7fd49b24e100>

    def test_locate_premature_beats(self):
>       premature_beat_intervals = reader.locate_premature_beats(0)

test/test_databases/test_cpsc2020.py:78: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
.../databases/cpsc_databases/cpsc2020.py:577: in locate_premature_beats
    premature_intervals = get_optimal_covering(
torch_ecg/utils/utils_interval.py:584: in get_optimal_covering
    if (tot_start > min([item if isinstance(item, (int, float)) else item[0] for item in to_cover])) or (
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

.0 = <iterator object at 0x7fd4017cf8b0>

>   if (tot_start > min([item if isinstance(item, (int, float)) else item[0] for item in to_cover])) or (
        tot_end < max([item if isinstance(item, (int, float)) else item[-1] for item in to_cover])
    ):
E   IndexError: invalid index to scalar variable.

torch_ecg/utils/utils_interval.py:584: IndexError

To view more test analytics, go to the Test Analytics Dashboard
📋 Got 3 mins? Take this short survey to help us improve Test Analytics.

kjs11 and others added 3 commits June 2, 2026 13:05
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
@kjs11 kjs11 enabled auto-merge June 2, 2026 05:07
@kjs11 kjs11 added this pull request to the merge queue Jun 2, 2026
Merged via the queue into master with commit 02547df Jun 2, 2026
8 of 13 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants