removed flag from simclr/swav (#517)
* removed flag from simclr/swav

* dm

* simclr

* simsiam

* swav

* nodes

* kwargs

* nodes

* nodes

* gpus

* gpus

Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz>
ananyahjha93 and Borda committed Jan 18, 2021
1 parent 55ab214 commit 61d3a26
Showing 5 changed files with 44 additions and 62 deletions.
30 changes: 11 additions & 19 deletions pl_bolts/models/self_supervised/simclr/simclr_module.py
@@ -68,7 +68,7 @@ def __init__(
num_samples: int,
batch_size: int,
dataset: str,
nodes: int = 1,
num_nodes: int = 1,
arch: str = 'resnet50',
hidden_mlp: int = 2048,
feat_dim: int = 128,
@@ -99,7 +99,7 @@ def __init__(
self.save_hyperparameters()

self.gpus = gpus
self.nodes = nodes
self.num_nodes = num_nodes
self.arch = arch
self.dataset = dataset
self.num_samples = num_samples
@@ -127,7 +127,9 @@ def __init__(
self.projection = Projection(input_dim=self.hidden_mlp, hidden_dim=self.hidden_mlp, output_dim=self.feat_dim)

# compute iters per epoch
global_batch_size = self.nodes * self.gpus * self.batch_size if self.gpus > 0 else self.batch_size
nb_gpus = len(self.gpus) if isinstance(gpus, (list, tuple)) else self.gpus
assert isinstance(nb_gpus, int)
global_batch_size = self.num_nodes * nb_gpus * self.batch_size if nb_gpus > 0 else self.batch_size
self.train_iters_per_epoch = self.num_samples // global_batch_size

# define LR schedule
@@ -312,15 +314,10 @@ def add_model_specific_args(parent_parser):
parser.add_argument("--data_dir", type=str, default=".", help="path to download data")

# training params
parser.add_argument("--fast_dev_run", default=1, type=int)
parser.add_argument("--nodes", default=1, type=int, help="number of nodes for training")
parser.add_argument("--gpus", default=1, type=int, help="number of gpus to train on")
parser.add_argument("--num_workers", default=8, type=int, help="num of workers per GPU")
parser.add_argument("--optimizer", default="adam", type=str, help="choose between adam/sgd")
parser.add_argument("--lars_wrapper", action='store_true', help="apple lars wrapper over optimizer used")
parser.add_argument('--exclude_bn_bias', action='store_true', help="exclude bn/bias from weight decay")
parser.add_argument("--max_epochs", default=100, type=int, help="number of total epochs to run")
parser.add_argument("--max_steps", default=-1, type=int, help="max steps")
parser.add_argument("--warmup_epochs", default=10, type=int, help="number of warmup epochs")
parser.add_argument("--batch_size", default=128, type=int, help="batch size per gpu")

@@ -342,6 +339,7 @@ def cli_main():

# model args
parser = SimCLR.add_model_specific_args(parser)
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()

if args.dataset == 'stl10':
@@ -361,8 +359,8 @@
args.jitter_strength = 1.
elif args.dataset == 'cifar10':
val_split = 5000
if args.nodes * args.gpus * args.batch_size > val_split:
val_split = args.nodes * args.gpus * args.batch_size
if args.num_nodes * args.gpus * args.batch_size > val_split:
val_split = args.num_nodes * args.gpus * args.batch_size

dm = CIFAR10DataModule(
data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers, val_split=val_split
@@ -388,7 +386,7 @@ def cli_main():
args.jitter_strength = 1.

args.batch_size = 64
args.nodes = 8
args.num_nodes = 8
args.gpus = 8 # per-node
args.max_epochs = 800

@@ -432,16 +430,10 @@ def cli_main():
model_checkpoint = ModelCheckpoint(save_last=True, save_top_k=1, monitor='val_loss')
callbacks = [model_checkpoint, online_evaluator] if args.online_ft else [model_checkpoint]

trainer = pl.Trainer(
max_epochs=args.max_epochs,
max_steps=None if args.max_steps == -1 else args.max_steps,
gpus=args.gpus,
num_nodes=args.nodes,
distributed_backend='ddp' if args.gpus > 1 else None,
trainer = pl.Trainer.from_argparse_args(
args,
sync_batchnorm=True if args.gpus > 1 else False,
precision=32 if args.fp32 else 16,
callbacks=callbacks,
fast_dev_run=args.fast_dev_run
)

trainer.fit(model, datamodule=dm)
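
The iters-per-epoch block above now tolerates `gpus` being either an int (a GPU count) or a list/tuple of device ids, and the same global batch size drives the CIFAR-10 `val_split` bump so the validation set always holds at least one full global batch. A minimal sketch of that arithmetic, with illustrative numbers rather than anything taken from the commit:

# Sketch of the updated global-batch-size arithmetic (illustrative values only).
def train_iters_per_epoch(gpus, num_nodes, batch_size, num_samples):
    # `gpus` may be an int (count) or a list/tuple of device ids, as in the diff above
    nb_gpus = len(gpus) if isinstance(gpus, (list, tuple)) else gpus
    assert isinstance(nb_gpus, int)
    # on CPU (nb_gpus == 0) the global batch is just the per-process batch
    global_batch_size = num_nodes * nb_gpus * batch_size if nb_gpus > 0 else batch_size
    return num_samples // global_batch_size

print(train_iters_per_epoch(gpus=[0, 1], num_nodes=1, batch_size=128, num_samples=50_000))  # 195
print(train_iters_per_epoch(gpus=0, num_nodes=1, batch_size=128, num_samples=50_000))       # 390

# CIFAR-10 val_split bump from cli_main(): keep at least one full global batch for validation.
val_split = max(5000, 1 * 2 * 128)  # stays 5000 here; would grow to 8192 for 8 nodes x 8 GPUs x 128
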
25 changes: 10 additions & 15 deletions pl_bolts/models/self_supervised/simsiam/simsiam_module.py
@@ -73,7 +73,7 @@ def __init__(
num_samples: int,
batch_size: int,
dataset: str,
nodes: int = 1,
num_nodes: int = 1,
arch: str = 'resnet50',
hidden_mlp: int = 2048,
feat_dim: int = 128,
@@ -106,7 +106,7 @@ def __init__(
self.save_hyperparameters()

self.gpus = gpus
self.nodes = nodes
self.num_nodes = num_nodes
self.arch = arch
self.dataset = dataset
self.num_samples = num_samples
@@ -132,7 +132,9 @@ def __init__(
self.init_model()

# compute iters per epoch
global_batch_size = self.nodes * self.gpus * self.batch_size if self.gpus > 0 else self.batch_size
nb_gpus = len(self.gpus) if isinstance(gpus, (list, tuple)) else self.gpus
assert isinstance(nb_gpus, int)
global_batch_size = self.num_nodes * nb_gpus * self.batch_size if nb_gpus > 0 else self.batch_size
self.train_iters_per_epoch = self.num_samples // global_batch_size

# define LR schedule
@@ -292,7 +294,6 @@ def add_model_specific_args(parent_parser):
parser.add_argument("--data_dir", type=str, default=".", help="path to download data")

# training params
parser.add_argument("--nodes", default=1, type=int, help="number of nodes for training")
parser.add_argument("--num_workers", default=8, type=int, help="num of workers per GPU")
parser.add_argument("--optimizer", default="adam", type=str, help="choose between adam/sgd")
parser.add_argument("--lars_wrapper", action="store_true", help="apple lars wrapper over optimizer used")
@@ -346,8 +347,8 @@ def cli_main():
args.jitter_strength = 1.0
elif args.dataset == "cifar10":
val_split = 5000
if args.nodes * args.gpus * args.batch_size > val_split:
val_split = args.nodes * args.gpus * args.batch_size
if args.num_nodes * args.gpus * args.batch_size > val_split:
val_split = args.num_nodes * args.gpus * args.batch_size

dm = CIFAR10DataModule(
data_dir=args.data_dir,
@@ -376,7 +377,7 @@ def cli_main():
args.jitter_strength = 1.0

args.batch_size = 64
args.nodes = 8
args.num_nodes = 8
args.gpus = 8 # per-node
args.max_epochs = 800

@@ -422,16 +423,10 @@ def cli_main():
dataset=args.dataset,
)

trainer = pl.Trainer(
max_epochs=args.max_epochs,
max_steps=None if args.max_steps == -1 else args.max_steps,
gpus=args.gpus,
num_nodes=args.nodes,
distributed_backend="ddp" if args.gpus > 1 else None,
trainer = pl.Trainer.from_argparse_args(
args,
sync_batchnorm=True if args.gpus > 1 else False,
precision=32 if args.fp32 else 16,
callbacks=[online_evaluator] if args.online_ft else None,
fast_dev_run=args.fast_dev_run,
)

trainer.fit(model, datamodule=dm)
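
All three modules now delegate Trainer flags to Lightning's argparse helpers instead of re-declaring them: `pl.Trainer.add_argparse_args(parser)` registers the standard Trainer options (`--gpus`, `--num_nodes`, `--max_epochs`, `--max_steps`, `--fast_dev_run`, ...), and `pl.Trainer.from_argparse_args(args, ...)` builds the Trainer from the parsed namespace, with explicitly passed keyword arguments taking precedence. A self-contained sketch of the pattern, stripped of the module-specific pieces:

# Minimal sketch of the argparse-driven Trainer setup (not the modules' full CLI).
from argparse import ArgumentParser
import pytorch_lightning as pl

parser = ArgumentParser()
parser.add_argument("--batch_size", default=128, type=int)  # model-specific flag, as before
parser = pl.Trainer.add_argparse_args(parser)               # adds --gpus, --num_nodes, --max_epochs, --max_steps, ...
args = parser.parse_args(["--max_epochs", "10", "--batch_size", "64"])

# explicit keyword arguments take precedence over the parsed Trainer flags
trainer = pl.Trainer.from_argparse_args(args, precision=32, sync_batchnorm=False)
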
36 changes: 16 additions & 20 deletions pl_bolts/models/self_supervised/swav/swav_module.py
@@ -32,7 +32,7 @@ def __init__(
num_samples: int,
batch_size: int,
dataset: str,
nodes: int = 1,
num_nodes: int = 1,
arch: str = 'resnet50',
hidden_mlp: int = 2048,
feat_dim: int = 128,
@@ -63,7 +63,7 @@ def __init__(
Args:
gpus: number of gpus per node used in training, passed to SwAV module
to manage the queue and select distributed sinkhorn
nodes: number of nodes to train on
num_nodes: number of nodes to train on
num_samples: number of image samples used for training
batch_size: batch size per GPU in ddp
dataset: dataset being used for train/val
@@ -102,7 +102,7 @@ def __init__(
self.save_hyperparameters()

self.gpus = gpus
self.nodes = nodes
self.num_nodes = num_nodes
self.arch = arch
self.dataset = dataset
self.num_samples = num_samples
@@ -136,15 +136,17 @@ def __init__(
self.warmup_epochs = warmup_epochs
self.max_epochs = max_epochs

if self.gpus * self.nodes > 1:
if self.gpus or self.num_nodes > 1:
self.get_assignments = self.distributed_sinkhorn
else:
self.get_assignments = self.sinkhorn

self.model = self.init_model()

# compute iters per epoch
global_batch_size = self.nodes * self.gpus * self.batch_size if self.gpus > 0 else self.batch_size
nb_gpus = len(self.gpus) if isinstance(gpus, (list, tuple)) else self.gpus
assert isinstance(nb_gpus, int)
global_batch_size = self.num_nodes * nb_gpus * self.batch_size if nb_gpus > 0 else self.batch_size
self.train_iters_per_epoch = self.num_samples // global_batch_size

# define LR schedule
@@ -429,15 +431,10 @@ def add_model_specific_args(parent_parser):
)

# training params
parser.add_argument("--fast_dev_run", default=1, type=int)
parser.add_argument("--nodes", default=1, type=int, help="number of nodes for training")
parser.add_argument("--gpus", default=1, type=int, help="number of gpus to train on")
parser.add_argument("--num_workers", default=8, type=int, help="num of workers per GPU")
parser.add_argument("--optimizer", default="adam", type=str, help="choose between adam/sgd")
parser.add_argument("--lars_wrapper", action='store_true', help="apple lars wrapper over optimizer used")
parser.add_argument('--exclude_bn_bias', action='store_true', help="exclude bn/bias from weight decay")
parser.add_argument("--max_epochs", default=100, type=int, help="number of total epochs to run")
parser.add_argument("--max_steps", default=-1, type=int, help="max steps")
parser.add_argument("--warmup_epochs", default=10, type=int, help="number of warmup epochs")
parser.add_argument("--batch_size", default=128, type=int, help="batch size per gpu")

@@ -490,6 +487,7 @@ def cli_main():

# model args
parser = SwAV.add_model_specific_args(parser)
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()

if args.dataset == 'stl10':
@@ -532,7 +530,7 @@ def cli_main():
args.jitter_strength = 1.

args.batch_size = 64
args.nodes = 8
args.num_nodes = 8
args.gpus = 8 # per-node
args.max_epochs = 800

@@ -579,22 +577,20 @@ def cli_main():
if args.online_ft:
# online eval
online_evaluator = SSLOnlineEvaluator(
drop_p=0., hidden_dim=None, z_dim=args.hidden_mlp, num_classes=dm.num_classes, dataset=args.dataset
drop_p=0.,
hidden_dim=None,
z_dim=args.hidden_mlp,
num_classes=dm.num_classes,
dataset=args.dataset,
)

model_checkpoint = ModelCheckpoint(save_last=True, save_top_k=1, monitor='val_loss')
callbacks = [model_checkpoint, online_evaluator] if args.online_ft else [model_checkpoint]

trainer = pl.Trainer(
max_epochs=args.max_epochs,
max_steps=None if args.max_steps == -1 else args.max_steps,
gpus=args.gpus,
num_nodes=args.nodes,
distributed_backend='ddp' if args.gpus > 1 else None,
trainer = pl.Trainer.from_argparse_args(
args,
sync_batchnorm=True if args.gpus > 1 else False,
precision=32 if args.fp32 else 16,
callbacks=callbacks,
fast_dev_run=args.fast_dev_run
)

trainer.fit(model, datamodule=dm)
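
Two of the SwAV changes above are easy to sanity-check numerically: the ImageNet preset (batch size 64 per GPU, 8 nodes of 8 GPUs) and the assignment-routine selection, which now picks the distributed Sinkhorn whenever GPUs are requested or more than one node is used. A small illustration, with the ImageNet training-set size assumed rather than taken from the commit:

# ImageNet preset from cli_main() above
num_nodes, gpus_per_node, batch_size = 8, 8, 64
global_batch_size = num_nodes * gpus_per_node * batch_size       # 4096
num_samples = 1_281_167                                          # ImageNet-1k train images (assumed here)
train_iters_per_epoch = num_samples // global_batch_size         # 312

# assignment-routine selection as committed
use_distributed_sinkhorn = bool(gpus_per_node) or num_nodes > 1  # True
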
14 changes: 7 additions & 7 deletions tests/models/rl/test_scripts.py
@@ -5,7 +5,7 @@

@pytest.mark.parametrize(
'cli_args', [
'--env PongNoFrameskip-v4'
' --env PongNoFrameskip-v4'
' --max_steps 10'
' --fast_dev_run 1'
' --warm_start_size 10'
@@ -24,7 +24,7 @@ def test_cli_run_rl_dqn(cli_args):

@pytest.mark.parametrize(
'cli_args', [
'--env PongNoFrameskip-v4'
' --env PongNoFrameskip-v4'
' --max_steps 10'
' --fast_dev_run 1'
' --warm_start_size 10'
@@ -43,7 +43,7 @@ def test_cli_run_rl_double_dqn(cli_args):

@pytest.mark.parametrize(
'cli_args', [
'--env PongNoFrameskip-v4'
' --env PongNoFrameskip-v4'
' --max_steps 10'
' --fast_dev_run 1'
' --warm_start_size 10'
@@ -62,7 +62,7 @@ def test_cli_run_rl_dueling_dqn(cli_args):

@pytest.mark.parametrize(
'cli_args', [
'--env PongNoFrameskip-v4'
' --env PongNoFrameskip-v4'
' --max_steps 10'
' --fast_dev_run 1'
' --warm_start_size 10'
@@ -81,7 +81,7 @@ def test_cli_run_rl_noisy_dqn(cli_args):

@pytest.mark.parametrize(
'cli_args', [
'--env PongNoFrameskip-v4'
' --env PongNoFrameskip-v4'
' --max_steps 10'
' --fast_dev_run 1'
' --warm_start_size 10'
@@ -99,7 +99,7 @@ def test_cli_run_rl_per_dqn(cli_args):


@pytest.mark.parametrize('cli_args', [
'--env CartPole-v0'
' --env CartPole-v0'
' --max_steps 10'
' --fast_dev_run 1'
' --batch_size 10',
@@ -114,7 +114,7 @@ def test_cli_run_rl_reinforce(cli_args):


@pytest.mark.parametrize('cli_args', [
'--env CartPole-v0'
' --env CartPole-v0'
' --max_steps 10'
' --fast_dev_run 1'
' --batch_size 10',
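
The test edits only touch string assembly: each `cli_args` entry is a stack of adjacent string literals that Python concatenates at parse time, and the commit gives the first fragment the same leading space the later fragments already carry, so the joined string splits cleanly on whitespace however it is combined with other argument fragments. A minimal illustration (flag values are only examples):

# adjacent string literals are concatenated at parse time
cli_args = (
    ' --env CartPole-v0'
    ' --max_steps 10'
    ' --fast_dev_run 1'
)
print(repr(cli_args))    # ' --env CartPole-v0 --max_steps 10 --fast_dev_run 1'
print(cli_args.split())  # ['--env', 'CartPole-v0', '--max_steps', '10', '--fast_dev_run', '1']
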
1 change: 0 additions & 1 deletion tests/models/self_supervised/test_scripts.py
@@ -122,7 +122,6 @@ def test_cli_run_self_supervised_byol(cli_args):
' --num_workers 0'
' --queue_length 0'
' --gpus 0'
' --fp32'
]
)
def test_cli_run_self_supervised_swav(cli_args):
