Skip to content

Commit

Permalink
update workers - hanging tests (#521)
Browse files Browse the repository at this point in the history
* tests

* yapf

* space

* num_workers

* strip

* gpus

* fp32

* num_workers

* num_workers

* num_workers
  • Loading branch information
Borda committed Jan 18, 2021
1 parent a82e9f4 commit 99232c3
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 49 deletions.
2 changes: 1 addition & 1 deletion pl_bolts/models/regression/logistic_regression.py
Expand Up @@ -139,7 +139,7 @@ def cli_main():

# data
X, y = load_iris(return_X_y=True)
loaders = SklearnDataModule(X, y, batch_size=args.batch_size)
loaders = SklearnDataModule(X, y, batch_size=args.batch_size, num_workers=0)

# train
trainer = pl.Trainer.from_argparse_args(args)
Expand Down
6 changes: 4 additions & 2 deletions pl_bolts/models/self_supervised/amdim/amdim_module.py
Expand Up @@ -53,6 +53,7 @@ def __init__(
data_dir: str = '',
num_classes: int = 10,
batch_size: int = 200,
num_workers: int = 16,
**kwargs,
):
"""
Expand Down Expand Up @@ -213,7 +214,7 @@ def train_dataloader(self):
batch_size=self.hparams.batch_size,
pin_memory=True,
drop_last=True,
num_workers=16,
num_workers=self.hparams.num_workers,
)
return loader

Expand All @@ -227,7 +228,7 @@ def val_dataloader(self):
batch_size=self.hparams.batch_size,
pin_memory=True,
drop_last=True,
num_workers=16,
num_workers=self.hparams.num_workers,
)
return loader

Expand Down Expand Up @@ -342,6 +343,7 @@ def add_model_specific_args(parent_parser):

# data
parser.add_argument('--data_dir', default=os.getcwd(), type=str)
parser.add_argument('--num_workers', type=int, default=16)
return parser


Expand Down
2 changes: 1 addition & 1 deletion pl_bolts/models/self_supervised/byol/byol_module.py
Expand Up @@ -156,7 +156,7 @@ def add_model_specific_args(parent_parser):

# Data
parser.add_argument('--data_dir', type=str, default='.')
parser.add_argument('--num_workers', default=0, type=int)
parser.add_argument('--num_workers', default=8, type=int)

# optim
parser.add_argument('--batch_size', type=int, default=256)
Expand Down
14 changes: 7 additions & 7 deletions tests/models/rl/test_scripts.py
Expand Up @@ -17,7 +17,7 @@ def test_cli_run_rl_dqn(cli_args):
"""Test running CLI for an example with default params."""
from pl_bolts.models.rl.dqn_model import cli_main

cli_args = cli_args.split(' ') if cli_args else []
cli_args = cli_args.strip().split(' ') if cli_args else []
with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
cli_main()

Expand All @@ -36,7 +36,7 @@ def test_cli_run_rl_double_dqn(cli_args):
"""Test running CLI for an example with default params."""
from pl_bolts.models.rl.double_dqn_model import cli_main

cli_args = cli_args.split(' ') if cli_args else []
cli_args = cli_args.strip().split(' ') if cli_args else []
with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
cli_main()

Expand All @@ -55,7 +55,7 @@ def test_cli_run_rl_dueling_dqn(cli_args):
"""Test running CLI for an example with default params."""
from pl_bolts.models.rl.dueling_dqn_model import cli_main

cli_args = cli_args.split(' ') if cli_args else []
cli_args = cli_args.strip().split(' ') if cli_args else []
with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
cli_main()

Expand All @@ -74,7 +74,7 @@ def test_cli_run_rl_noisy_dqn(cli_args):
"""Test running CLI for an example with default params."""
from pl_bolts.models.rl.noisy_dqn_model import cli_main

cli_args = cli_args.split(' ') if cli_args else []
cli_args = cli_args.strip().split(' ') if cli_args else []
with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
cli_main()

Expand All @@ -93,7 +93,7 @@ def test_cli_run_rl_per_dqn(cli_args):
"""Test running CLI for an example with default params."""
from pl_bolts.models.rl.per_dqn_model import cli_main

cli_args = cli_args.split(' ') if cli_args else []
cli_args = cli_args.strip().split(' ') if cli_args else []
with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
cli_main()

Expand All @@ -108,7 +108,7 @@ def test_cli_run_rl_reinforce(cli_args):
"""Test running CLI for an example with default params."""
from pl_bolts.models.rl.reinforce_model import cli_main

cli_args = cli_args.split(' ') if cli_args else []
cli_args = cli_args.strip().split(' ') if cli_args else []
with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
cli_main()

Expand All @@ -123,6 +123,6 @@ def test_cli_run_rl_vanilla_policy_gradient(cli_args):
"""Test running CLI for an example with default params."""
from pl_bolts.models.rl.vanilla_policy_gradient_model import cli_main

cli_args = cli_args.split(' ') if cli_args else []
cli_args = cli_args.strip().split(' ') if cli_args else []
with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
cli_main()
92 changes: 73 additions & 19 deletions tests/models/self_supervised/test_scripts.py
Expand Up @@ -7,97 +7,151 @@


@pytest.mark.parametrize(
'cli_args', [f"--data_dir {DATASETS_PATH} --max_epochs 1 --max_steps 3 --fast_dev_run 1 --batch_size 2"]
'cli_args', [
f"--data_dir {DATASETS_PATH}"
" --max_epochs 1"
" --max_steps 3"
" --fast_dev_run 1"
" --batch_size 2"
" --num_workers 0"
]
)
def test_cli_run_self_supervised_amdim(cli_args):
"""Test running CLI for an example with default params."""
from pl_bolts.models.self_supervised.amdim.amdim_module import cli_main

cli_args = cli_args.split(' ') if cli_args else []
cli_args = cli_args.strip().split(' ') if cli_args else []
with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
cli_main()


# TODO: this test is hanging (runs for more then 10min) so we need to use GPU or optimize it...
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
@pytest.mark.parametrize(
'cli_args',
[f'--data_dir {DATASETS_PATH} --max_epochs 1 --max_steps 3 --fast_dev_run 1 --batch_size 2 --encoder resnet18']
'cli_args', [
f' --data_dir {DATASETS_PATH} --max_epochs 1'
' --max_steps 3'
' --fast_dev_run 1'
' --batch_size 2'
' --encoder resnet18'
' --num_workers 0'
]
)
def test_cli_run_self_supervised_cpc(cli_args):
"""Test running CLI for an example with default params."""
from pl_bolts.models.self_supervised.cpc.cpc_module import cli_main

cli_args = cli_args.split(' ') if cli_args else []
cli_args = cli_args.strip().split(' ') if cli_args else []
with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
cli_main()


@pytest.mark.parametrize(
'cli_args', [f'--data_dir {DATASETS_PATH} --max_epochs 1 --max_steps 3 --fast_dev_run 1 --batch_size 2']
'cli_args', [
f' --data_dir {DATASETS_PATH}'
' --max_epochs 1'
' --max_steps 3'
' --fast_dev_run 1'
' --batch_size 2'
' --num_workers 0'
]
)
def test_cli_run_self_supervised_moco(cli_args):
"""Test running CLI for an example with default params."""
from pl_bolts.models.self_supervised.moco.moco2_module import cli_main

cli_args = cli_args.split(' ') if cli_args else []
cli_args = cli_args.strip().split(' ') if cli_args else []
with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
cli_main()


@pytest.mark.parametrize(
'cli_args', [
f'--data_dir {DATASETS_PATH} --gpus 0 --fp32 --max_epochs 1 --max_steps 3 --fast_dev_run 1'
' --batch_size 2 --online_ft'
f' --data_dir {DATASETS_PATH}'
' --max_epochs 1'
' --max_steps 3'
' --fast_dev_run 1'
' --batch_size 2'
' --num_workers 0'
' --online_ft'
' --gpus 0'
' --fp32'
]
)
def test_cli_run_self_supervised_simclr(cli_args):
"""Test running CLI for an example with default params."""
from pl_bolts.models.self_supervised.simclr.simclr_module import cli_main

cli_args = cli_args.split(' ') if cli_args else []
cli_args = cli_args.strip().split(' ') if cli_args else []
with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
cli_main()


@pytest.mark.parametrize(
'cli_args',
[f'--data_dir {DATASETS_PATH} --max_epochs 1 --max_steps 3 --fast_dev_run 1 --batch_size 2 --online_ft']
'cli_args', [
f' --data_dir {DATASETS_PATH}'
' --max_epochs 1'
' --max_steps 3'
' --fast_dev_run 1'
' --batch_size 2'
" --num_workers 0"
' --online_ft'
]
)
def test_cli_run_self_supervised_byol(cli_args):
"""Test running CLI for an example with default params."""
from pl_bolts.models.self_supervised.byol.byol_module import cli_main

cli_args = cli_args.split(' ') if cli_args else []
cli_args = cli_args.strip().split(' ') if cli_args else []
with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
cli_main()


@pytest.mark.parametrize(
'cli_args', [
f'--dataset cifar10 --data_dir {DATASETS_PATH} --max_epochs 1 --max_steps 3 --fast_dev_run 1 --batch_size 2'
' --gpus 0 --arch resnet18 --hidden_mlp 512 --fp32 --sinkhorn_iterations 1 --nmb_prototypes 2 --queue_length 0'
' --dataset cifar10'
f' --data_dir {DATASETS_PATH}'
' --max_epochs 1'
' --max_steps 3'
' --fast_dev_run 1'
' --batch_size 2'
' --arch resnet18'
' --hidden_mlp 512'
' --sinkhorn_iterations 1'
' --nmb_prototypes 2'
' --num_workers 0'
' --queue_length 0'
' --gpus 0'
' --fp32'
]
)
def test_cli_run_self_supervised_swav(cli_args):
"""Test running CLI for an example with default params."""
from pl_bolts.models.self_supervised.swav.swav_module import cli_main

cli_args = cli_args.split(' ') if cli_args else []
cli_args = cli_args.strip().split(' ') if cli_args else []
with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
cli_main()


@pytest.mark.parametrize(
'cli_args', [
f'--dataset cifar10 --data_dir {DATASETS_PATH} --max_epochs 1 --max_steps 3 --fast_dev_run 1 --batch_size 2'
' --gpus 0 --fp32 --online_ft'
' --dataset cifar10'
f' --data_dir {DATASETS_PATH}'
' --max_epochs 1'
' --max_steps 3'
' --fast_dev_run 1'
' --batch_size 2'
' --num_workers 0'
' --online_ft'
' --gpus 0'
' --fp32'
]
)
def test_cli_run_self_supervised_simsiam(cli_args):
"""Test running CLI for an example with default params."""
from pl_bolts.models.self_supervised.simsiam.simsiam_module import cli_main

cli_args = cli_args.split(' ') if cli_args else []
cli_args = cli_args.strip().split(' ') if cli_args else []
with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
cli_main()
8 changes: 5 additions & 3 deletions tests/models/test_gans.py
Expand Up @@ -7,13 +7,15 @@


@pytest.mark.parametrize(
"dm_cls", [pytest.param(MNISTDataModule, id="mnist"),
pytest.param(CIFAR10DataModule, id="cifar10")]
"dm_cls", [
pytest.param(MNISTDataModule, id="mnist"),
pytest.param(CIFAR10DataModule, id="cifar10"),
]
)
def test_gan(tmpdir, datadir, dm_cls):
seed_everything()

dm = dm_cls(data_dir=datadir)
dm = dm_cls(data_dir=datadir, num_workers=0)
model = GAN(*dm.size())
trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir)
trainer.fit(model, datamodule=dm)
Expand Down

0 comments on commit 99232c3

Please sign in to comment.