Skip to content

Commit

Permalink
ruff: RET (#1026)
Browse files — browse the repository at this point in the history
  • Loading branch information
Borda committed May 31, 2023
1 parent ec81c59 commit 2d61e05
Show file tree
Hide file tree
Showing 50 changed files with 146 additions and 304 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ select = [
extend-select = [
"C4", # see: https://pypi.org/project/flake8-comprehensions
"PT", # see: https://pypi.org/project/flake8-pytest-style
# "RET", # see: https://pypi.org/project/flake8-return
"RET", # see: https://pypi.org/project/flake8-return
"SIM", # see: https://pypi.org/project/flake8-simplify
"YTT", # see: https://pypi.org/project/flake8-2020
# "ANN", # see: https://pypi.org/project/flake8-annotations
Expand Down
4 changes: 2 additions & 2 deletions src/pl_bolts/callbacks/data_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,8 @@ def hook(_: Module, inp: Sequence, out: Sequence) -> None:
self.log_histograms(inp, group=input_group_name)
self.log_histograms(out, group=output_group_name)

handle = module.register_forward_hook(hook)
return handle
# handler
return module.register_forward_hook(hook)


@under_review()
Expand Down
4 changes: 2 additions & 2 deletions src/pl_bolts/callbacks/knn_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ def predict(self, query_feature: Tensor, feature_bank: Tensor, target_bank: Tens
# weighted score ---> [B, C]
pred_scores = torch.sum(one_hot_label.view(B, -1, self.num_classes) * sim_weight.unsqueeze(dim=-1), dim=1)

pred_labels = pred_scores.argsort(dim=-1, descending=True)
return pred_labels
# pred_labels
return pred_scores.argsort(dim=-1, descending=True)

def to_device(self, batch: Tensor, device: Union[str, torch.device]) -> Tuple[Tensor, Tensor]:
# get the labeled batch
Expand Down
2 changes: 1 addition & 1 deletion src/pl_bolts/callbacks/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def dicts_to_table(
none_keys = [k for k, v in d.items() if v is None]
if skip_none_lines and none_keys:
continue
elif replace_values:
if replace_values:
for k in d:
if k in replace_values and d[k] in replace_values[k]:
d[k] = replace_values[k][d[k]]
Expand Down
6 changes: 2 additions & 4 deletions src/pl_bolts/callbacks/verification/batch_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,11 @@ def __init__(
self._sample_idx = sample_idx

def message(self, *args: Any, **kwargs: Any) -> str:
message = (
return (
"Your model is mixing data across the batch dimension."
" This can lead to wrong gradient updates in the optimizer."
" Check the operations that reshape and permute tensor dimensions in your model."
)
return message

def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
verification = BatchGradientVerification(pl_module)
Expand Down Expand Up @@ -189,8 +188,7 @@ def default_output_mapping(data: Any) -> Tensor:
batches = default_input_mapping(data)
# cannot use .flatten(1) because of tensors with shape (B, )
batches = [batch.view(batch.size(0), -1).float() for batch in batches]
combined = torch.cat(batches, 1) # combined batch has shape (B, N)
return combined
return torch.cat(batches, 1) # combined batch has shape (B, N)


@under_review()
Expand Down
17 changes: 5 additions & 12 deletions src/pl_bolts/datamodules/cityscapes_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,15 +142,14 @@ def train_dataloader(self) -> DataLoader:
**self.extra_args,
)

loader = DataLoader(
return DataLoader(
dataset,
batch_size=self.batch_size,
shuffle=self.shuffle,
num_workers=self.num_workers,
drop_last=self.drop_last,
pin_memory=self.pin_memory,
)
return loader

def val_dataloader(self) -> DataLoader:
"""Cityscapes val set."""
Expand All @@ -167,15 +166,14 @@ def val_dataloader(self) -> DataLoader:
**self.extra_args,
)

loader = DataLoader(
return DataLoader(
dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
drop_last=self.drop_last,
)
return loader

def test_dataloader(self) -> DataLoader:
"""Cityscapes test set."""
Expand All @@ -191,29 +189,24 @@ def test_dataloader(self) -> DataLoader:
target_transform=target_transforms,
**self.extra_args,
)
loader = DataLoader(
return DataLoader(
dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
drop_last=self.drop_last,
pin_memory=self.pin_memory,
)
return loader

def _default_transforms(self) -> Callable:
cityscapes_transforms = transform_lib.Compose(
return transform_lib.Compose(
[
transform_lib.ToTensor(),
transform_lib.Normalize(
mean=[0.28689554, 0.32513303, 0.28389177], std=[0.18696375, 0.19017339, 0.18720214]
),
]
)
return cityscapes_transforms

def _default_target_transforms(self) -> Callable:
cityscapes_target_transforms = transform_lib.Compose(
[transform_lib.ToTensor(), transform_lib.Lambda(lambda t: t.squeeze())]
)
return cityscapes_target_transforms
return transform_lib.Compose([transform_lib.ToTensor(), transform_lib.Lambda(lambda t: t.squeeze())])
7 changes: 2 additions & 5 deletions src/pl_bolts/datamodules/experience_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ def __init__(self, generate_batch: Callable) -> None:
self.generate_batch = generate_batch

def __iter__(self) -> Iterator:
iterator = self.generate_batch()
return iterator
return self.generate_batch() # iterator


# Experience Sources
Expand Down Expand Up @@ -191,9 +190,7 @@ def env_step(self, env_idx: int, env: Env, action: List[int]) -> Experience:
self.cur_rewards[env_idx] += r
self.cur_steps[env_idx] += 1

exp = Experience(state=self.states[env_idx], action=action[0], reward=r, done=is_done, new_state=next_state)

return exp
return Experience(state=self.states[env_idx], action=action[0], reward=r, done=is_done, new_state=next_state)

def update_env_stats(self, env_idx: int) -> None:
"""To be called at the end of the history tail generation during the termination state. Updates the stats
Expand Down
7 changes: 2 additions & 5 deletions src/pl_bolts/datamodules/imagenet_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def train_transform(self) -> Callable:
),
])
"""
preprocessing = transform_lib.Compose(
return transform_lib.Compose(
[
transform_lib.RandomResizedCrop(self.image_size),
transform_lib.RandomHorizontalFlip(),
Expand All @@ -233,8 +233,6 @@ def train_transform(self) -> Callable:
]
)

return preprocessing

def val_transform(self) -> Callable:
"""The standard imagenet transforms for validation.
Expand All @@ -251,15 +249,14 @@ def val_transform(self) -> Callable:
])
"""

preprocessing = transform_lib.Compose(
return transform_lib.Compose(
[
transform_lib.Resize(self.image_size + 32),
transform_lib.CenterCrop(self.image_size),
transform_lib.ToTensor(),
imagenet_normalization(),
]
)
return preprocessing

@staticmethod
def add_dataset_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
Expand Down
12 changes: 4 additions & 8 deletions src/pl_bolts/datamodules/kitti_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,45 +94,41 @@ def __init__(
)

def train_dataloader(self) -> DataLoader:
loader = DataLoader(
return DataLoader(
self.trainset,
batch_size=self.batch_size,
shuffle=self.shuffle,
num_workers=self.num_workers,
drop_last=self.drop_last,
pin_memory=self.pin_memory,
)
return loader

def val_dataloader(self) -> DataLoader:
loader = DataLoader(
return DataLoader(
self.valset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
drop_last=self.drop_last,
pin_memory=self.pin_memory,
)
return loader

def test_dataloader(self) -> DataLoader:
loader = DataLoader(
return DataLoader(
self.testset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
drop_last=self.drop_last,
pin_memory=self.pin_memory,
)
return loader

def _default_transforms(self) -> Callable:
kitti_transforms = transforms.Compose(
return transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(
mean=[0.35675976, 0.37380189, 0.3764753], std=[0.32064945, 0.32098866, 0.32325324]
),
]
)
return kitti_transforms
9 changes: 3 additions & 6 deletions src/pl_bolts/datamodules/sklearn_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,34 +163,31 @@ def _init_datasets(
self.test_dataset = SklearnDataset(x_test, y_test)

def train_dataloader(self) -> DataLoader:
loader = DataLoader(
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
shuffle=self.shuffle,
num_workers=self.num_workers,
drop_last=self.drop_last,
pin_memory=self.pin_memory,
)
return loader

def val_dataloader(self) -> DataLoader:
loader = DataLoader(
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
drop_last=self.drop_last,
pin_memory=self.pin_memory,
)
return loader

def test_dataloader(self) -> DataLoader:
loader = DataLoader(
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
drop_last=self.drop_last,
pin_memory=self.pin_memory,
)
return loader
12 changes: 4 additions & 8 deletions src/pl_bolts/datamodules/ssl_imagenet_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,15 +95,14 @@ def train_dataloader(self, num_images_per_class: int = -1, add_normalize: bool =
split="train",
transform=transforms,
)
loader: DataLoader = DataLoader(
return DataLoader(
dataset,
batch_size=self.batch_size,
shuffle=self.shuffle,
num_workers=self.num_workers,
drop_last=self.drop_last,
pin_memory=self.pin_memory,
)
return loader

def val_dataloader(self, num_images_per_class: int = 50, add_normalize: bool = False) -> DataLoader:
transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms
Expand All @@ -115,15 +114,14 @@ def val_dataloader(self, num_images_per_class: int = 50, add_normalize: bool = F
split="val",
transform=transforms,
)
loader: DataLoader = DataLoader(
return DataLoader(
dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
drop_last=self.drop_last,
pin_memory=self.pin_memory,
)
return loader

def test_dataloader(self, num_images_per_class: int, add_normalize: bool = False) -> DataLoader:
transforms = self._default_transforms() if self.test_transforms is None else self.test_transforms
Expand All @@ -135,16 +133,14 @@ def test_dataloader(self, num_images_per_class: int, add_normalize: bool = False
split="test",
transform=transforms,
)
loader: DataLoader = DataLoader(
return DataLoader(
dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
drop_last=self.drop_last,
pin_memory=self.pin_memory,
)
return loader

def _default_transforms(self) -> Callable:
transforms = transform_lib.Compose([transform_lib.ToTensor(), imagenet_normalization()])
return transforms
return transform_lib.Compose([transform_lib.ToTensor(), imagenet_normalization()])

0 comments on commit 2d61e05

Please sign in to comment.