4 changes: 4 additions & 0 deletions configs/sbatch/defaults.yaml
@@ -32,4 +32,8 @@ py_args: "" # arguments for main.py
note: "" # wandb run note
git_checkout: null # if null, no checkout. Use as `git_checkout=some-branch` or `git_checkout=somecommithash`

sweep: false
count: 0
array: 0

dev: false
6 changes: 6 additions & 0 deletions configs/sweep/defaults.yaml
@@ -0,0 +1,6 @@
# default minydra args for sweep.py

method: random
params: sweep_wandb_all.yml
count: 1
name: null
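
sweep.py itself is not part of this diff, so the following is only a plausible sketch (not the author's implementation) of how these defaults (method, params, name) could be turned into a W&B sweep whose id is then handed to sbatch.py via the new sweep= argument. The file sweep_wandb_all.yml is referenced above but its contents are not shown here.

import yaml
import wandb

# load the defaults added in this PR
defaults = yaml.safe_load(open("configs/sweep/defaults.yaml"))
with open(defaults["params"]) as f:  # e.g. sweep_wandb_all.yml (not in this diff)
    parameters = yaml.safe_load(f)

sweep_config = {"method": defaults["method"], "parameters": parameters}
if defaults.get("name"):
    sweep_config["name"] = defaults["name"]

sweep_id = wandb.sweep(sweep_config)
print(sweep_id)  # pass to sbatch.py, e.g. sweep=<entity>/<project>/<sweep_id>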
2 changes: 1 addition & 1 deletion hyperparams_tuning_victor.sh
@@ -2,7 +2,7 @@
#SBATCH -J Sweep
#SBATCH -t 48:00:00
#SBATCH -N 1
#SBATCH --gres=gpu:4
#SBATCH --gres=gpu:1
#SBATCH --ntasks=1
#SBATCH --mem=32GB
#SBATCH --output="/network/scratch/s/schmidtv/ocp/runs/sweep/output-%j.out"
3 changes: 3 additions & 0 deletions main.py
@@ -27,6 +27,7 @@
save_experiment_log,
setup_imports,
setup_logging,
update_from_sbatch_py_vars,
)


@@ -66,6 +67,7 @@ def __call__(self, config):
data_split=config.get("data_split", None),
note=config.get("note", ""),
test_rotation_invariance=config.get("test_ri", None),
wandb_tag=config.get("wandb_tag", None),
)
self.task = registry.get_task_class(config["mode"])(self.config)
self.task.setup(self.trainer)
@@ -168,6 +170,7 @@ def should_continue(self, config):

parser = flags.get_parser()
args, override_args = parser.parse_known_args()
args = update_from_sbatch_py_vars(args)
if not args.mode or not args.config_yml:
args.mode = "train"
# args.config_yml = "configs/is2re/10k/schnet/new_schnet.yml"
4 changes: 4 additions & 0 deletions ocpmodels/common/logger.py
@@ -67,6 +67,9 @@ def __init__(self, config):
wandb_id += f"{slurm_jobid}-"
wandb_id += self.config["cmd"]["timestamp_id"] + "-" + config["model"]

wandb_tag = config.get("wandb_tag")
tags = [wandb_tag] if wandb_tag else []

wandb.init(
config=self.config,
id=wandb_id,
@@ -75,6 +78,7 @@ def __init__(self, config):
project=project,
resume="allow",
notes=self.config["note"],
tags=tags,
)

sbatch_files = list(Path(self.config["run_dir"]).glob("sbatch_script*.sh"))
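
The tag attached here is what makes sweep or batch-submitted runs easy to group and filter in the W&B UI and API. As a small hedged illustration (the entity, project, and tag names below are placeholders, not values from this PR), tagged runs can later be retrieved programmatically:

import wandb

api = wandb.Api()
# MongoDB-style filter on the tags field; "my-sweep" is a hypothetical tag value
runs = api.runs("my-entity/ocp", filters={"tags": {"$in": ["my-sweep"]}})
print([run.name for run in runs])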
12 changes: 12 additions & 0 deletions ocpmodels/common/utils.py
@@ -405,6 +405,7 @@ def build_config(args, args_override):
config["distributed_port"] = args.distributed_port
config["world_size"] = args.num_nodes * args.num_gpus
config["distributed_backend"] = args.distributed_backend
config["wandb_tag"] = args.wandb_tag if hasattr(args, "wandb_tag") else None

return config

@@ -833,3 +834,14 @@ def resolve(path):
pathlib.Path: the resolved Path
"""
return Path(os.path.expandvars(os.path.expanduser(str(path)))).resolve()


def update_from_sbatch_py_vars(args):
sbatch_py_vars = {
k.replace("SBATCH_PY_", "").lower(): v if v != "true" else True
for k, v in os.environ.items()
if k.startswith("SBATCH_PY_")
}
for k, v in sbatch_py_vars.items():
setattr(args, k, v)
return args
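
To make the coupling with sbatch.py (further down in this diff) concrete, here is a minimal, self-contained sketch of what update_from_sbatch_py_vars does. The argparse parser below is a stand-in for the repository's flags.get_parser(), and the environment values are hypothetical; in practice they are exported by the generated job script.

import argparse
import os

from ocpmodels.common.utils import update_from_sbatch_py_vars

# Values normally exported by the sbatch script generated by sbatch.py
os.environ["SBATCH_PY_DISTRIBUTED"] = "true"  # the literal "true" becomes the boolean True
os.environ["SBATCH_PY_NUM_GPUS"] = "4"        # any other value stays a string

parser = argparse.ArgumentParser()
parser.add_argument("--distributed", action="store_true")
parser.add_argument("--num-gpus", default=1)
args = parser.parse_args([])

args = update_from_sbatch_py_vars(args)
print(args.distributed, args.num_gpus)  # True 4  (num_gpus is the string "4")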
14 changes: 13 additions & 1 deletion ocpmodels/datasets/lmdb_dataset.py
@@ -12,6 +12,8 @@
import random
import warnings
from pathlib import Path
from datetime import datetime
import time

import lmdb
import numpy as np
@@ -76,6 +78,7 @@ def __len__(self):
return self.num_samples

def __getitem__(self, idx):
t0 = time.time_ns()
if not self.path.is_file():
# Figure out which db this should be indexed from.
db_idx = bisect.bisect(self._keylen_cumulative, idx)
@@ -96,9 +99,18 @@ def __getitem__(self, idx):
else:
datapoint_pickled = self.env.begin().get(self._keys[idx])
data_object = pyg2_data_transform(pickle.loads(datapoint_pickled))

t1 = time.time_ns()
if self.transform is not None:
data_object = self.transform(data_object)
t2 = time.time_ns()

load_time = (t1 - t0) * 1e-9 # time in s
transform_time = (t2 - t1) * 1e-9 # time in s
total_get_time = (t2 - t0) * 1e-9 # time in s

data_object.load_time = load_time
data_object.transform_time = transform_time
data_object.total_get_time = total_get_time

return data_object

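The three timing attributes attached above travel with each Data object and can be inspected after batching to see where data-loading time goes. A minimal sketch of such a consumer, assuming PyG collates these per-sample floats into lists or tensors (the helper name is illustrative, not part of this PR):

import torch

def mean_loader_timings(batch):
    # load_time / transform_time / total_get_time are seconds per sample
    load = torch.as_tensor(batch.load_time, dtype=torch.float).mean().item()
    transform = torch.as_tensor(batch.transform_time, dtype=torch.float).mean().item()
    total = torch.as_tensor(batch.total_get_time, dtype=torch.float).mean().item()
    return {"load_time": load, "transform_time": transform, "total_get_time": total}
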
2 changes: 2 additions & 0 deletions ocpmodels/trainers/base_trainer.py
@@ -75,6 +75,7 @@ def __init__(
data_split=None,
note="",
test_rotation_invariance=None,
wandb_tag=None,
):
self.name = name
self.cpu = cpu
@@ -152,6 +153,7 @@ def __init__(
},
"slurm": slurm,
"note": note,
"wandb_tag": wandb_tag,
}
# AMP Scaler
self.scaler = torch.cuda.amp.GradScaler() if amp else None
1 change: 1 addition & 0 deletions ocpmodels/trainers/energy_trainer.py
@@ -79,6 +79,7 @@ def __init__(
data_split=None,
test_rotation_invariance=False,
note="",
wandb_tag=None,
):
super().__init__(
task=task,
42 changes: 34 additions & 8 deletions sbatch.py
@@ -18,19 +18,20 @@
#SBATCH --output={output}
{time}

{git_checkout}

# {sbatch_command_line}
# git commit: {git_commit}
# cwd: {cwd}

{git_checkout}
{sbatch_py_vars}

export MASTER_PORT=$(expr 10000 + $(echo -n $SLURM_JOBID | tail -c 4))
echo "Master port $MASTER_PORT"

module load anaconda/3
conda activate {env}

srun --output={output} python main.py {py_args}
srun --output={output} {python_command}
"""


@@ -90,9 +91,23 @@ def get_commit():
return commit


def make_sbatch_py_vars(sbatch_py_vars):
s = ""
for k, v in sbatch_py_vars.items():
k = "SBATCH_PY_" + k.replace("-", "_").upper()
s += k
if v:
s += f"={v}"
else:
s += "=true"
s += "\n"
return s[:-1]


if __name__ == "__main__":
# has the submission been successful?
success = False
sbatch_py_vars = {}

# repository root
root = Path(__file__).resolve().parent
@@ -115,9 +130,12 @@ def get_commit():

# distribute training
if args.ntasks_per_node > 1 and "--distributed" not in args.py_args:
args.py_args += (
f" --distributed --num-nodes {args.nodes} --num-gpus {args.ntasks_per_node}"
)
if args.sweep:
sbatch_py_vars["distributed"] = None
sbatch_py_vars["num-nodes"] = args.nodes
sbatch_py_vars["num-gpus"] = args.ntasks_per_node
else:
args.py_args += f" --distributed --num-nodes {args.nodes} --num-gpus {args.ntasks_per_node}"

# add logdir to main.py's command-line arguments
if "--logdir" not in args.py_args and args.logdir:
@@ -132,6 +150,12 @@ def get_commit():

git_checkout = f"git checkout {args.git_checkout}" if args.git_checkout else ""

if args.sweep:
count = f" --count {args.count}" if args.count else ""
python_command = f"wandb agent{count} {args.sweep}"
else:
python_command = f"python main.py {args.py_args}"

# format string template with defaults + command-line args
script = template.format(
cpus=args.cpus,
@@ -147,8 +171,9 @@ def get_commit():
ntasks=args.ntasks,
output=str(resolve(args.output)),
partition=args.partition,
py_args=args.py_args,
python_command=python_command,
sbatch_command_line=" ".join(["python"] + sys.argv),
sbatch_py_vars=make_sbatch_py_vars(sbatch_py_vars),
time="" if not args.time else f"#SBATCH --time={args.time}",
)

@@ -176,7 +201,8 @@ def get_commit():
f.write(script)

# command to request the job
command = f"sbatch {str(script_path)}"
array = f" --array={args.array}" if args.array else ""
command = f"sbatch{array} {str(script_path)}"
print(f"Executing:\n{command}")
print(f"\nFile content:\n{'=' * 50}\n{script}{'=' * 50}\n")

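To see how the two halves fit together, here is a quick sketch of the environment variables make_sbatch_py_vars emits for a hypothetical distributed job; these are the exact strings that update_from_sbatch_py_vars in ocpmodels/common/utils.py later reads back. It assumes sbatch.py imports cleanly outside of a submission, which it should since the submission logic sits under the __main__ guard.

from sbatch import make_sbatch_py_vars

# Hypothetical values for a 1-node, 4-GPU job
print(make_sbatch_py_vars({"distributed": None, "num-nodes": 1, "num-gpus": 4}))
# SBATCH_PY_DISTRIBUTED=true
# SBATCH_PY_NUM_NODES=1
# SBATCH_PY_NUM_GPUS=4
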