diff --git a/pl_bolts/models/regression/logistic_regression.py b/pl_bolts/models/regression/logistic_regression.py index 1d9838473b..f57bd85cfe 100644 --- a/pl_bolts/models/regression/logistic_regression.py +++ b/pl_bolts/models/regression/logistic_regression.py @@ -139,7 +139,7 @@ def cli_main(): # data X, y = load_iris(return_X_y=True) - loaders = SklearnDataModule(X, y, batch_size=args.batch_size) + loaders = SklearnDataModule(X, y, batch_size=args.batch_size, num_workers=0) # train trainer = pl.Trainer.from_argparse_args(args) diff --git a/pl_bolts/models/self_supervised/amdim/amdim_module.py b/pl_bolts/models/self_supervised/amdim/amdim_module.py index 11f4f9b73a..f1cd01a28f 100644 --- a/pl_bolts/models/self_supervised/amdim/amdim_module.py +++ b/pl_bolts/models/self_supervised/amdim/amdim_module.py @@ -53,6 +53,7 @@ def __init__( data_dir: str = '', num_classes: int = 10, batch_size: int = 200, + num_workers: int = 16, **kwargs, ): """ @@ -213,7 +214,7 @@ def train_dataloader(self): batch_size=self.hparams.batch_size, pin_memory=True, drop_last=True, - num_workers=16, + num_workers=self.hparams.num_workers, ) return loader @@ -227,7 +228,7 @@ def val_dataloader(self): batch_size=self.hparams.batch_size, pin_memory=True, drop_last=True, - num_workers=16, + num_workers=self.hparams.num_workers, ) return loader @@ -342,6 +343,7 @@ def add_model_specific_args(parent_parser): # data parser.add_argument('--data_dir', default=os.getcwd(), type=str) + parser.add_argument('--num_workers', type=int, default=16) return parser diff --git a/pl_bolts/models/self_supervised/byol/byol_module.py b/pl_bolts/models/self_supervised/byol/byol_module.py index 8d31909d66..e52f19f4fc 100644 --- a/pl_bolts/models/self_supervised/byol/byol_module.py +++ b/pl_bolts/models/self_supervised/byol/byol_module.py @@ -156,7 +156,7 @@ def add_model_specific_args(parent_parser): # Data parser.add_argument('--data_dir', type=str, default='.') - parser.add_argument('--num_workers', default=0, type=int) + parser.add_argument('--num_workers', default=8, type=int) # optim parser.add_argument('--batch_size', type=int, default=256) diff --git a/tests/models/rl/test_scripts.py b/tests/models/rl/test_scripts.py index a468fe2c99..7ea9b3d024 100644 --- a/tests/models/rl/test_scripts.py +++ b/tests/models/rl/test_scripts.py @@ -17,7 +17,7 @@ def test_cli_run_rl_dqn(cli_args): """Test running CLI for an example with default params.""" from pl_bolts.models.rl.dqn_model import cli_main - cli_args = cli_args.split(' ') if cli_args else [] + cli_args = cli_args.strip().split(' ') if cli_args else [] with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): cli_main() @@ -36,7 +36,7 @@ def test_cli_run_rl_double_dqn(cli_args): """Test running CLI for an example with default params.""" from pl_bolts.models.rl.double_dqn_model import cli_main - cli_args = cli_args.split(' ') if cli_args else [] + cli_args = cli_args.strip().split(' ') if cli_args else [] with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): cli_main() @@ -55,7 +55,7 @@ def test_cli_run_rl_dueling_dqn(cli_args): """Test running CLI for an example with default params.""" from pl_bolts.models.rl.dueling_dqn_model import cli_main - cli_args = cli_args.split(' ') if cli_args else [] + cli_args = cli_args.strip().split(' ') if cli_args else [] with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): cli_main() @@ -74,7 +74,7 @@ def test_cli_run_rl_noisy_dqn(cli_args): """Test running CLI for an example with default params.""" from pl_bolts.models.rl.noisy_dqn_model import cli_main - cli_args = cli_args.split(' ') if cli_args else [] + cli_args = cli_args.strip().split(' ') if cli_args else [] with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): cli_main() @@ -93,7 +93,7 @@ def test_cli_run_rl_per_dqn(cli_args): """Test running CLI for an example with default params.""" from pl_bolts.models.rl.per_dqn_model import cli_main - cli_args = cli_args.split(' ') if cli_args else [] + cli_args = cli_args.strip().split(' ') if cli_args else [] with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): cli_main() @@ -108,7 +108,7 @@ def test_cli_run_rl_reinforce(cli_args): """Test running CLI for an example with default params.""" from pl_bolts.models.rl.reinforce_model import cli_main - cli_args = cli_args.split(' ') if cli_args else [] + cli_args = cli_args.strip().split(' ') if cli_args else [] with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): cli_main() @@ -123,6 +123,6 @@ def test_cli_run_rl_vanilla_policy_gradient(cli_args): """Test running CLI for an example with default params.""" from pl_bolts.models.rl.vanilla_policy_gradient_model import cli_main - cli_args = cli_args.split(' ') if cli_args else [] + cli_args = cli_args.strip().split(' ') if cli_args else [] with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): cli_main() diff --git a/tests/models/self_supervised/test_scripts.py b/tests/models/self_supervised/test_scripts.py index 76ef8d6053..9bfdc3886f 100644 --- a/tests/models/self_supervised/test_scripts.py +++ b/tests/models/self_supervised/test_scripts.py @@ -7,13 +7,20 @@ @pytest.mark.parametrize( - 'cli_args', [f"--data_dir {DATASETS_PATH} --max_epochs 1 --max_steps 3 --fast_dev_run 1 --batch_size 2"] + 'cli_args', [ + f"--data_dir {DATASETS_PATH}" + " --max_epochs 1" + " --max_steps 3" + " --fast_dev_run 1" + " --batch_size 2" + " --num_workers 0" + ] ) def test_cli_run_self_supervised_amdim(cli_args): """Test running CLI for an example with default params.""" from pl_bolts.models.self_supervised.amdim.amdim_module import cli_main - cli_args = cli_args.split(' ') if cli_args else [] + cli_args = cli_args.strip().split(' ') if cli_args else [] with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): cli_main() @@ -21,83 +28,130 @@ def test_cli_run_self_supervised_amdim(cli_args): # TODO: this test is hanging (runs for more then 10min) so we need to use GPU or optimize it... @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") @pytest.mark.parametrize( - 'cli_args', - [f'--data_dir {DATASETS_PATH} --max_epochs 1 --max_steps 3 --fast_dev_run 1 --batch_size 2 --encoder resnet18'] + 'cli_args', [ + f' --data_dir {DATASETS_PATH} --max_epochs 1' + ' --max_steps 3' + ' --fast_dev_run 1' + ' --batch_size 2' + ' --encoder resnet18' + ' --num_workers 0' + ] ) def test_cli_run_self_supervised_cpc(cli_args): """Test running CLI for an example with default params.""" from pl_bolts.models.self_supervised.cpc.cpc_module import cli_main - cli_args = cli_args.split(' ') if cli_args else [] + cli_args = cli_args.strip().split(' ') if cli_args else [] with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): cli_main() @pytest.mark.parametrize( - 'cli_args', [f'--data_dir {DATASETS_PATH} --max_epochs 1 --max_steps 3 --fast_dev_run 1 --batch_size 2'] + 'cli_args', [ + f' --data_dir {DATASETS_PATH}' + ' --max_epochs 1' + ' --max_steps 3' + ' --fast_dev_run 1' + ' --batch_size 2' + ' --num_workers 0' + ] ) def test_cli_run_self_supervised_moco(cli_args): """Test running CLI for an example with default params.""" from pl_bolts.models.self_supervised.moco.moco2_module import cli_main - cli_args = cli_args.split(' ') if cli_args else [] + cli_args = cli_args.strip().split(' ') if cli_args else [] with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): cli_main() @pytest.mark.parametrize( 'cli_args', [ - f'--data_dir {DATASETS_PATH} --gpus 0 --fp32 --max_epochs 1 --max_steps 3 --fast_dev_run 1' - ' --batch_size 2 --online_ft' + f' --data_dir {DATASETS_PATH}' + ' --max_epochs 1' + ' --max_steps 3' + ' --fast_dev_run 1' + ' --batch_size 2' + ' --num_workers 0' + ' --online_ft' + ' --gpus 0' + ' --fp32' ] ) def test_cli_run_self_supervised_simclr(cli_args): """Test running CLI for an example with default params.""" from pl_bolts.models.self_supervised.simclr.simclr_module import cli_main - cli_args = cli_args.split(' ') if cli_args else [] + cli_args = cli_args.strip().split(' ') if cli_args else [] with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): cli_main() @pytest.mark.parametrize( - 'cli_args', - [f'--data_dir {DATASETS_PATH} --max_epochs 1 --max_steps 3 --fast_dev_run 1 --batch_size 2 --online_ft'] + 'cli_args', [ + f' --data_dir {DATASETS_PATH}' + ' --max_epochs 1' + ' --max_steps 3' + ' --fast_dev_run 1' + ' --batch_size 2' + " --num_workers 0" + ' --online_ft' + ] ) def test_cli_run_self_supervised_byol(cli_args): """Test running CLI for an example with default params.""" from pl_bolts.models.self_supervised.byol.byol_module import cli_main - cli_args = cli_args.split(' ') if cli_args else [] + cli_args = cli_args.strip().split(' ') if cli_args else [] with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): cli_main() @pytest.mark.parametrize( 'cli_args', [ - f'--dataset cifar10 --data_dir {DATASETS_PATH} --max_epochs 1 --max_steps 3 --fast_dev_run 1 --batch_size 2' - ' --gpus 0 --arch resnet18 --hidden_mlp 512 --fp32 --sinkhorn_iterations 1 --nmb_prototypes 2 --queue_length 0' + ' --dataset cifar10' + f' --data_dir {DATASETS_PATH}' + ' --max_epochs 1' + ' --max_steps 3' + ' --fast_dev_run 1' + ' --batch_size 2' + ' --arch resnet18' + ' --hidden_mlp 512' + ' --sinkhorn_iterations 1' + ' --nmb_prototypes 2' + ' --num_workers 0' + ' --queue_length 0' + ' --gpus 0' + ' --fp32' ] ) def test_cli_run_self_supervised_swav(cli_args): """Test running CLI for an example with default params.""" from pl_bolts.models.self_supervised.swav.swav_module import cli_main - cli_args = cli_args.split(' ') if cli_args else [] + cli_args = cli_args.strip().split(' ') if cli_args else [] with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): cli_main() @pytest.mark.parametrize( 'cli_args', [ - f'--dataset cifar10 --data_dir {DATASETS_PATH} --max_epochs 1 --max_steps 3 --fast_dev_run 1 --batch_size 2' - ' --gpus 0 --fp32 --online_ft' + ' --dataset cifar10' + f' --data_dir {DATASETS_PATH}' + ' --max_epochs 1' + ' --max_steps 3' + ' --fast_dev_run 1' + ' --batch_size 2' + ' --num_workers 0' + ' --online_ft' + ' --gpus 0' + ' --fp32' ] ) def test_cli_run_self_supervised_simsiam(cli_args): """Test running CLI for an example with default params.""" from pl_bolts.models.self_supervised.simsiam.simsiam_module import cli_main - cli_args = cli_args.split(' ') if cli_args else [] + cli_args = cli_args.strip().split(' ') if cli_args else [] with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): cli_main() diff --git a/tests/models/test_gans.py b/tests/models/test_gans.py index b47fe7eb21..70f0c9c00a 100644 --- a/tests/models/test_gans.py +++ b/tests/models/test_gans.py @@ -7,13 +7,15 @@ @pytest.mark.parametrize( - "dm_cls", [pytest.param(MNISTDataModule, id="mnist"), - pytest.param(CIFAR10DataModule, id="cifar10")] + "dm_cls", [ + pytest.param(MNISTDataModule, id="mnist"), + pytest.param(CIFAR10DataModule, id="cifar10"), + ] ) def test_gan(tmpdir, datadir, dm_cls): seed_everything() - dm = dm_cls(data_dir=datadir) + dm = dm_cls(data_dir=datadir, num_workers=0) model = GAN(*dm.size()) trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir) trainer.fit(model, datamodule=dm) diff --git a/tests/models/test_scripts.py b/tests/models/test_scripts.py index ea745f883d..5da6c976c8 100644 --- a/tests/models/test_scripts.py +++ b/tests/models/test_scripts.py @@ -6,17 +6,21 @@ from tests import DATASETS_PATH +@pytest.mark.parametrize('dataset_name', ['mnist', 'cifar10']) @pytest.mark.parametrize( 'cli_args', [ - f'--dataset mnist --data_dir {DATASETS_PATH} --max_epochs 1' - ' --batch_size 2 --limit_train_batches 2 --limit_val_batches 2', - f'--dataset cifar10 --data_dir {DATASETS_PATH} --max_epochs 1' - ' --batch_size 2 --limit_train_batches 2 --limit_val_batches 2', + ' --dataset %(dataset_name)s' + f' --data_dir {DATASETS_PATH}' + ' --max_epochs 1' + ' --batch_size 2' + ' --limit_train_batches 2' + ' --limit_val_batches 2' ] ) -def test_cli_run_basic_gan(cli_args): +def test_cli_run_basic_gan(cli_args, dataset_name): from pl_bolts.models.gans.basic.basic_gan_module import cli_main + cli_args = cli_args % {'dataset_name': dataset_name} with mock.patch("argparse._sys.argv", ["any.py"] + cli_args.strip().split()): cli_main() @@ -25,14 +29,18 @@ def test_cli_run_basic_gan(cli_args): @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") @pytest.mark.parametrize( 'cli_args', [ - f'--data_dir {DATASETS_PATH} --max_epochs 1 --limit_train_batches 2' - ' --limit_val_batches 2 --batch_size 2 --encoder resnet18', + f' --data_dir {DATASETS_PATH}' + ' --max_epochs 1' + ' --limit_train_batches 2' + ' --limit_val_batches 2' + ' --batch_size 2' + ' --encoder resnet18' ] ) def test_cli_run_cpc(cli_args): from pl_bolts.models.self_supervised.cpc.cpc_module import cli_main - cli_args = cli_args.split(' ') if cli_args else [] + cli_args = cli_args.strip().split(' ') if cli_args else [] with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): cli_main() @@ -42,35 +50,45 @@ def test_cli_run_mnist(cli_args): """Test running CLI for an example with default params.""" from pl_bolts.models.mnist_module import cli_main - cli_args = cli_args.split(' ') if cli_args else [] + cli_args = cli_args.strip().split(' ') if cli_args else [] with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): cli_main() @pytest.mark.parametrize( 'cli_args', [ - f'--dataset cifar10 --data_dir {DATASETS_PATH} --max_epochs 1 --batch_size 2 --fast_dev_run 1', + ' --dataset cifar10' + f' --data_dir {DATASETS_PATH}' + ' --max_epochs 1' + ' --batch_size 2' + ' --fast_dev_run 1' + ' --num_workers 0' ] ) def test_cli_run_basic_ae(cli_args): """Test running CLI for an example with default params.""" from pl_bolts.models.autoencoders.basic_ae.basic_ae_module import cli_main - cli_args = cli_args.split(' ') if cli_args else [] + cli_args = cli_args.strip().split(' ') if cli_args else [] with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): cli_main() @pytest.mark.parametrize( 'cli_args', [ - f'--dataset cifar10 --data_dir {DATASETS_PATH} --max_epochs 1 --batch_size 2 --fast_dev_run 1', + ' --dataset cifar10' + f' --data_dir {DATASETS_PATH}' + ' --max_epochs 1' + ' --batch_size 2' + ' --fast_dev_run 1' + ' --num_workers 0' ] ) def test_cli_run_basic_vae(cli_args): """Test running CLI for an example with default params.""" from pl_bolts.models.autoencoders.basic_vae.basic_vae_module import cli_main - cli_args = cli_args.split(' ') if cli_args else [] + cli_args = cli_args.strip().split(' ') if cli_args else [] with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): cli_main() @@ -80,7 +98,7 @@ def test_cli_run_lin_regression(cli_args): """Test running CLI for an example with default params.""" from pl_bolts.models.regression.linear_regression import cli_main - cli_args = cli_args.split(' ') if cli_args else [] + cli_args = cli_args.strip().split(' ') if cli_args else [] with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): cli_main() @@ -90,7 +108,7 @@ def test_cli_run_log_regression(cli_args): """Test running CLI for an example with default params.""" from pl_bolts.models.regression.logistic_regression import cli_main - cli_args = cli_args.split(' ') if cli_args else [] + cli_args = cli_args.strip().split(' ') if cli_args else [] with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): cli_main() @@ -102,6 +120,6 @@ def test_cli_run_vision_image_gpt(cli_args): """Test running CLI for an example with default params.""" from pl_bolts.models.vision.image_gpt.igpt_module import cli_main - cli_args = cli_args.split(' ') if cli_args else [] + cli_args = cli_args.strip().split(' ') if cli_args else [] with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): cli_main()