Skip to content

Commit

Permalink
ruff: I (#1024)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Borda and pre-commit-ci[bot] committed May 31, 2023
1 parent d52c665 commit ec81c59
Show file tree
Hide file tree
Showing 50 changed files with 65 additions and 106 deletions.
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ repos:
- mdformat_frontmatter
exclude: CHANGELOG.md

- repo: https://github.com/PyCQA/isort
rev: 5.12.0
hooks:
- id: isort
#- repo: https://github.com/PyCQA/isort
# rev: 5.12.0
# hooks:
# - id: isort

- repo: https://github.com/psf/black
rev: 23.3.0
Expand All @@ -63,7 +63,7 @@ repos:
- flake8-pytest-style

- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.269
rev: v0.0.270
hooks:
- id: ruff
args: ["--fix"]
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ line-length = 120
select = [
"E", "W", # see: https://pypi.org/project/pycodestyle
"F", # see: https://pypi.org/project/pyflakes
"I", # see: isort
# "D", # see: https://pypi.org/project/pydocstyle
# "N", # see: https://pypi.org/project/pep8-naming
]
Expand Down
5 changes: 2 additions & 3 deletions tests/callbacks/test_data_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@

import pytest
import torch
from pytorch_lightning import Trainer
from torch import nn

from pl_bolts.callbacks import ModuleDataMonitor, TrainingDataMonitor
from pl_bolts.datamodules import MNISTDataModule
from pl_bolts.models import LitMNIST
from pytorch_lightning import Trainer
from torch import nn


# @pytest.mark.parametrize(("log_every_n_steps", "max_steps", "expected_calls"), [pytest.param(3, 10, 3)])
Expand Down
4 changes: 2 additions & 2 deletions tests/callbacks/test_ort.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
# limitations under the License.

import pytest
from pl_bolts.callbacks import ORTCallback
from pl_bolts.utils import _TORCH_ORT_AVAILABLE
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.exceptions import MisconfigurationException

from pl_bolts.callbacks import ORTCallback
from pl_bolts.utils import _TORCH_ORT_AVAILABLE
from tests.helpers.boring_model import BoringModel

if _TORCH_ORT_AVAILABLE:
Expand Down
3 changes: 1 addition & 2 deletions tests/callbacks/test_param_update_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@

import pytest
import torch
from torch import nn

from pl_bolts.callbacks.byol_updates import BYOLMAWeightUpdate
from torch import nn


@pytest.mark.parametrize("initial_tau", [-0.1, 0.0, 0.996, 1.0, 1.1])
Expand Down
4 changes: 2 additions & 2 deletions tests/callbacks/test_sparseml.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@

import pytest
import torch
from pl_bolts.callbacks import SparseMLCallback
from pl_bolts.utils import _SPARSEML_AVAILABLE
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.exceptions import MisconfigurationException

from pl_bolts.callbacks import SparseMLCallback
from pl_bolts.utils import _SPARSEML_AVAILABLE
from tests.helpers.boring_model import BoringModel

if _SPARSEML_AVAILABLE:
Expand Down
3 changes: 1 addition & 2 deletions tests/callbacks/test_variational_callbacks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from pytorch_lightning.loggers.base import DummyLogger

from pl_bolts.callbacks import LatentDimInterpolator
from pl_bolts.models.gans import GAN
from pytorch_lightning.loggers.base import DummyLogger


def test_latent_dim_interpolator():
Expand Down
4 changes: 2 additions & 2 deletions tests/callbacks/verification/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import pytest
import torch
import torch.nn as nn
from pl_bolts.callbacks.verification.base import VerificationBase
from pl_bolts.utils import _PL_GREATER_EQUAL_1_4
from pytorch_lightning import LightningModule
from pytorch_lightning.utilities import move_data_to_device

from pl_bolts.callbacks.verification.base import VerificationBase
from pl_bolts.utils import _PL_GREATER_EQUAL_1_4
from tests import _MARK_REQUIRE_GPU


Expand Down
6 changes: 3 additions & 3 deletions tests/callbacks/verification/test_batch_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

import pytest
import torch
from pl_bolts.callbacks import BatchGradientVerificationCallback
from pl_bolts.callbacks.verification.batch_gradient import default_input_mapping, default_output_mapping, selective_eval
from pl_bolts.utils import BatchGradientVerification
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import Tensor, nn

from pl_bolts.callbacks import BatchGradientVerificationCallback
from pl_bolts.callbacks.verification.batch_gradient import default_input_mapping, default_output_mapping, selective_eval
from pl_bolts.utils import BatchGradientVerification
from tests import _MARK_REQUIRE_GPU


Expand Down
5 changes: 2 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@

import pytest
import torch
from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector
from pytorch_lightning.utilities.imports import _IS_WINDOWS

from pl_bolts.utils import _TORCHVISION_AVAILABLE, _TORCHVISION_LESS_THAN_0_13
from pl_bolts.utils.stability import UnderReviewWarning
from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector
from pytorch_lightning.utilities.imports import _IS_WINDOWS

# GitHub Actions use this path to cache datasets.
# Use `datadir` fixture where possible and use `DATASETS_PATH` in
Expand Down
3 changes: 1 addition & 2 deletions tests/datamodules/test_dataloader.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import torch
from torch.utils.data import DataLoader

from pl_bolts.datamodules.async_dataloader import AsynchronousLoader
from pl_bolts.datasets.cifar10_dataset import CIFAR10
from torch.utils.data import DataLoader


def test_async_dataloader(datadir):
Expand Down
1 change: 0 additions & 1 deletion tests/datamodules/test_datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import pytest
import torch
from PIL import Image

from pl_bolts.datamodules import (
BinaryEMNISTDataModule,
BinaryMNISTDataModule,
Expand Down
3 changes: 1 addition & 2 deletions tests/datamodules/test_experience_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
import gym
import numpy as np
import torch
from torch.utils.data import DataLoader

from pl_bolts.datamodules.experience_source import (
BaseExperienceSource,
DiscountedExperienceSource,
Expand All @@ -14,6 +12,7 @@
ExperienceSourceDataset,
)
from pl_bolts.models.rl.common.agents import Agent
from torch.utils.data import DataLoader


class DummyAgent(Agent):
Expand Down
3 changes: 1 addition & 2 deletions tests/datamodules/test_sklearn_dataloaders.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from warnings import warn

import numpy as np
from pytorch_lightning import seed_everything

from pl_bolts.datamodules.sklearn_datamodule import SklearnDataModule
from pytorch_lightning import seed_everything

try:
from sklearn.utils import shuffle as sk_shuffle
Expand Down
3 changes: 1 addition & 2 deletions tests/datasets/test_array_dataset.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import numpy as np
import pytest
import torch
from pytorch_lightning.utilities import exceptions

from pl_bolts.datasets import ArrayDataset, DataModel
from pl_bolts.datasets.utils import to_tensor
from pytorch_lightning.utilities import exceptions


class TestArrayDataset:
Expand Down
1 change: 0 additions & 1 deletion tests/datasets/test_base_dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import numpy as np
import pytest
import torch

from pl_bolts.datasets.base_dataset import DataModel
from pl_bolts.datasets.utils import to_tensor

Expand Down
5 changes: 2 additions & 3 deletions tests/datasets/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@
import numpy as np
import pytest
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms as transform_lib

from pl_bolts.datasets import (
BinaryEMNIST,
BinaryMNIST,
Expand All @@ -19,6 +16,8 @@
from pl_bolts.datasets.sr_mnist_dataset import SRMNIST
from pl_bolts.utils import _PIL_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms as transform_lib

if _PIL_AVAILABLE:
from PIL import Image
Expand Down
1 change: 0 additions & 1 deletion tests/datasets/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import numpy as np
import torch.testing

from pl_bolts.datasets.utils import to_tensor


Expand Down
3 changes: 1 addition & 2 deletions tests/losses/test_rl_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@

import numpy as np
import torch
from torch import Tensor

from pl_bolts.losses.rl import double_dqn_loss, dqn_loss, per_dqn_loss
from pl_bolts.models.rl.common.gym_wrappers import make_environment
from pl_bolts.models.rl.common.networks import CNN
from torch import Tensor


class TestRLLoss(TestCase):
Expand Down
1 change: 0 additions & 1 deletion tests/metrics/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import pytest
import torch

from pl_bolts.metrics.aggregation import accuracy, mean, precision_at_k


Expand Down
7 changes: 3 additions & 4 deletions tests/models/gans/integration/test_gans.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import warnings

import pytest
from pl_bolts.datamodules import CIFAR10DataModule, MNISTDataModule
from pl_bolts.datasets.sr_mnist_dataset import SRMNIST
from pl_bolts.models.gans import DCGAN, GAN, SRGAN, SRResNet
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data.dataloader import DataLoader
from torchvision import transforms as transform_lib

from pl_bolts.datamodules import CIFAR10DataModule, MNISTDataModule
from pl_bolts.datasets.sr_mnist_dataset import SRMNIST
from pl_bolts.models.gans import DCGAN, GAN, SRGAN, SRResNet


@pytest.mark.parametrize(
"dm_cls",
Expand Down
3 changes: 1 addition & 2 deletions tests/models/gans/unit/test_basic_components.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import pytest
import torch
from pytorch_lightning import seed_everything

from pl_bolts.models.gans.basic.components import Discriminator, Generator
from pytorch_lightning import seed_everything


@pytest.mark.parametrize(
Expand Down
1 change: 0 additions & 1 deletion tests/models/regression/test_logistic_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import operator

import pytorch_lightning as pl

from pl_bolts import datamodules
from pl_bolts.models import regression

Expand Down
3 changes: 1 addition & 2 deletions tests/models/rl/integration/test_actor_critic_models.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import argparse

from pytorch_lightning import Trainer

from pl_bolts.models.rl.advantage_actor_critic_model import AdvantageActorCritic
from pl_bolts.models.rl.sac_model import SAC
from pytorch_lightning import Trainer


def test_a2c():
Expand Down
3 changes: 1 addition & 2 deletions tests/models/rl/integration/test_policy_models.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import argparse
from unittest import TestCase

from pytorch_lightning import Trainer

from pl_bolts.models.rl.reinforce_model import Reinforce
from pl_bolts.models.rl.vanilla_policy_gradient_model import VanillaPolicyGradient
from pytorch_lightning import Trainer


class TestPolicyModels(TestCase):
Expand Down
3 changes: 1 addition & 2 deletions tests/models/rl/integration/test_value_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@
from unittest import TestCase

import pytest
from pytorch_lightning import Trainer

from pl_bolts.models.rl.double_dqn_model import DoubleDQN
from pl_bolts.models.rl.dqn_model import DQN
from pl_bolts.models.rl.dueling_dqn_model import DuelingDQN
from pl_bolts.models.rl.noisy_dqn_model import NoisyDQN
from pl_bolts.models.rl.per_dqn_model import PERDQN
from pytorch_lightning import Trainer


class TestValueModels(TestCase):
Expand Down
3 changes: 1 addition & 2 deletions tests/models/rl/unit/test_a2c.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import argparse

import torch
from torch import Tensor

from pl_bolts.models.rl.advantage_actor_critic_model import AdvantageActorCritic
from torch import Tensor


def test_a2c_loss():
Expand Down
3 changes: 1 addition & 2 deletions tests/models/rl/unit/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
import gym
import numpy as np
import torch
from torch import Tensor

from pl_bolts.models.rl.common.agents import ActorCriticAgent, Agent, PolicyAgent, ValueAgent
from torch import Tensor


class TestAgents(TestCase):
Expand Down
1 change: 0 additions & 1 deletion tests/models/rl/unit/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import numpy as np
import torch

from pl_bolts.models.rl.common.memory import Buffer, Experience, MultiStepBuffer, PERBuffer, ReplayBuffer


Expand Down
3 changes: 1 addition & 2 deletions tests/models/rl/unit/test_ppo.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import numpy as np
import torch
from pl_bolts.models.rl.ppo_model import PPO
from pytorch_lightning import Trainer
from torch import Tensor

from pl_bolts.models.rl.ppo_model import PPO


def test_discount_rewards():
"""Test calculation of discounted rewards."""
Expand Down
3 changes: 1 addition & 2 deletions tests/models/rl/unit/test_reinforce.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
import gym
import numpy as np
import torch
from torch import Tensor

from pl_bolts.datamodules.experience_source import DiscountedExperienceSource
from pl_bolts.models.rl.common.agents import Agent
from pl_bolts.models.rl.common.gym_wrappers import ToTensor
from pl_bolts.models.rl.common.networks import MLP
from pl_bolts.models.rl.reinforce_model import Reinforce
from torch import Tensor


class TestReinforce(TestCase):
Expand Down
3 changes: 1 addition & 2 deletions tests/models/rl/unit/test_sac.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import argparse

import torch
from torch import Tensor

from pl_bolts.models.rl.sac_model import SAC
from torch import Tensor


def test_sac_loss():
Expand Down

0 comments on commit ec81c59

Please sign in to comment.