Skip to content

Commit

Permalink
upgrade black version in pre-commit config
Browse files Browse the repository at this point in the history
  • Loading branch information
wenh06 committed Mar 26, 2024
1 parent e9d7cb5 commit 6df4d25
Show file tree
Hide file tree
Showing 10 changed files with 32 additions and 18 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/psf/black
rev: 23.12.1
rev: 24.3.0
hooks:
- id: black
args: [--line-length=128, --extend-exclude=.ipynb, --verbose]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
----------
to add
"""

from numbers import Real
from typing import List

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
[2] to add
"""

import multiprocessing as mp
import os
import time
Expand Down
8 changes: 5 additions & 3 deletions benchmarks/train_mtl_cinc2022/data_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1629,9 +1629,11 @@ def load_data(
channels = [channels]
assert set(channels).issubset(self._channels), "invalid channels"
data = {
k: torch.from_numpy(data[k].astype(np.float32))
if data_format.lower() == "channel_first"
else torch.from_numpy(data[k].astype(np.float32).T)
k: (
torch.from_numpy(data[k].astype(np.float32))
if data_format.lower() == "channel_first"
else torch.from_numpy(data[k].astype(np.float32).T)
)
for k in channels
if k in data
}
Expand Down
16 changes: 10 additions & 6 deletions benchmarks/train_mtl_cinc2022/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,9 +345,11 @@ def evaluate(self, data_loader: DataLoader) -> Dict[str, float]:
head_preds_classes = [np.array(all_outputs[0].murmur_output.classes)[np.where(row)[0]] for row in head_bin_preds]
head_labels = all_labels[0]["murmur"][:log_head_num]
head_labels_classes = [
np.array(all_outputs[0].murmur_output.classes)[np.where(row)]
if head_labels.ndim == 2
else np.array(all_outputs[0].murmur_output.classes)[row]
(
np.array(all_outputs[0].murmur_output.classes)[np.where(row)]
if head_labels.ndim == 2
else np.array(all_outputs[0].murmur_output.classes)[row]
)
for row in head_labels
]
log_head_num = min(log_head_num, len(head_scalar_preds))
Expand All @@ -372,9 +374,11 @@ def evaluate(self, data_loader: DataLoader) -> Dict[str, float]:
]
head_labels = all_labels[0]["outcome"][:log_head_num]
head_labels_classes = [
np.array(all_outputs[0].outcome_output.classes)[np.where(row)[0]]
if head_labels.ndim == 2
else np.array(all_outputs[0].outcome_output.classes)[row]
(
np.array(all_outputs[0].outcome_output.classes)[np.where(row)[0]]
if head_labels.ndim == 2
else np.array(all_outputs[0].outcome_output.classes)[row]
)
for row in head_labels
]
log_head_num = min(log_head_num, len(head_scalar_preds))
Expand Down
1 change: 1 addition & 0 deletions benchmarks/train_mtl_cinc2022/utils/_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class PitchShift(torch.nn.Module):
>>> waveform_shift = transform(waveform) # (channel, time)
"""

__constants__ = [
"sample_rate",
"n_steps",
Expand Down
1 change: 1 addition & 0 deletions legacy/_pantompkins.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
source: https://pypi.org/project/wfdb/2.2.1/#files
"""

import numpy as np
import scipy.signal as scisig

Expand Down
16 changes: 10 additions & 6 deletions test/test_pipelines/test_mtl_cinc2022_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3321,9 +3321,11 @@ def evaluate(self, data_loader: DataLoader) -> Dict[str, float]:
head_preds_classes = [np.array(all_outputs[0].murmur_output.classes)[np.where(row)[0]] for row in head_bin_preds]
head_labels = all_labels[0]["murmur"][:log_head_num]
head_labels_classes = [
np.array(all_outputs[0].murmur_output.classes)[np.where(row)]
if head_labels.ndim == 2
else np.array(all_outputs[0].murmur_output.classes)[row]
(
np.array(all_outputs[0].murmur_output.classes)[np.where(row)]
if head_labels.ndim == 2
else np.array(all_outputs[0].murmur_output.classes)[row]
)
for row in head_labels
]
log_head_num = min(log_head_num, len(head_scalar_preds))
Expand All @@ -3348,9 +3350,11 @@ def evaluate(self, data_loader: DataLoader) -> Dict[str, float]:
]
head_labels = all_labels[0]["outcome"][:log_head_num]
head_labels_classes = [
np.array(all_outputs[0].outcome_output.classes)[np.where(row)[0]]
if head_labels.ndim == 2
else np.array(all_outputs[0].outcome_output.classes)[row]
(
np.array(all_outputs[0].outcome_output.classes)[np.where(row)[0]]
if head_labels.ndim == 2
else np.array(all_outputs[0].outcome_output.classes)[row]
)
for row in head_labels
]
log_head_num = min(log_head_num, len(head_scalar_preds))
Expand Down
3 changes: 1 addition & 2 deletions torch_ecg/model_configs/ecg_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@
ECG_UNET_VANILLA_CONFIG.up_mode = "nearest"
ECG_UNET_VANILLA_CONFIG.up_scales = list(repeat(2, ECG_UNET_VANILLA_CONFIG.down_up_block_num))
ECG_UNET_VANILLA_CONFIG.up_num_filters = [
ECG_UNET_VANILLA_CONFIG.init_num_filters * (2**idx)
for idx in range(ECG_UNET_VANILLA_CONFIG.down_up_block_num - 1, -1, -1)
ECG_UNET_VANILLA_CONFIG.init_num_filters * (2**idx) for idx in range(ECG_UNET_VANILLA_CONFIG.down_up_block_num - 1, -1, -1)
]
ECG_UNET_VANILLA_CONFIG.up_deconv_filter_lengths = list(repeat(_base_filter_length, ECG_UNET_VANILLA_CONFIG.down_up_block_num))
ECG_UNET_VANILLA_CONFIG.up_conv_filter_lengths = list(repeat(_base_filter_length, ECG_UNET_VANILLA_CONFIG.down_up_block_num))
Expand Down
1 change: 1 addition & 0 deletions torch_ecg/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@
select_k
"""

from . import ecg_arrhythmia_knowledge as EAK
from ._ecg_plot import ecg_plot
from .download import http_get
Expand Down

0 comments on commit 6df4d25

Please sign in to comment.