Skip to content

Commit

Permalink
Merge pull request graphnet-team#426 from MortenHolmRep/batch_training
Browse files Browse the repository at this point in the history
Bash script for training on multiple models
  • Loading branch information
MortenHolmRep committed Feb 10, 2023
2 parents 3aae0b6 + 2a97797 commit 5b059be
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 2 deletions.
2 changes: 1 addition & 1 deletion configs/models/dynedge_position_example.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ arguments:
arguments: {}
class_name: MSELoss
loss_weight: null
target_labels: ["position_x", "position_y", "position_y"]
target_labels: ["position_x", "position_y", "position_z"]
transform_inference: null
transform_prediction_and_target: null
transform_support: null
Expand Down
14 changes: 13 additions & 1 deletion examples/04_training/01_train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def main(
batch_size: int,
num_workers: int,
prediction_names: Optional[List[str]],
suffix: Optional[str] = None,
) -> None:
"""Run example."""
# Initialise Weights & Biases (W&B) run
Expand Down Expand Up @@ -62,7 +63,10 @@ def main(
dataloader={"batch_size": batch_size, "num_workers": num_workers},
)

archive = os.path.join(EXAMPLE_OUTPUT_DIR, "train_model")
if suffix is not None:
archive = os.path.join(EXAMPLE_OUTPUT_DIR, f"train_model_{suffix}")
else:
archive = os.path.join(EXAMPLE_OUTPUT_DIR, "train_model")
run_name = "dynedge_{}_example".format("_".join(config.target))

# Construct dataloaders
Expand Down Expand Up @@ -154,6 +158,13 @@ def main(
default=None,
)

parser.add_argument(
"--suffix",
type=str,
help="Name addition to folder (default: %(default)s)",
default=None,
)

args = parser.parse_args()

main(
Expand All @@ -165,4 +176,5 @@ def main(
args.batch_size,
args.num_workers,
args.prediction_names,
args.suffix,
)
62 changes: 62 additions & 0 deletions examples/04_training/01_train_models.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#!/bin/bash

#### This script enables the user to run multiple trainings in sequence on the
#### same database but for different model configs.
# To execute this file, copy the file path and run in a terminal: $ bash <filepath>

# Fail on use of undefined variables to catch typos in variable names early.
set -u

# Directory containing this script, so the relative config paths below resolve
# regardless of the caller's working directory.
bash_directory=$(dirname -- "$(readlink -f "${BASH_SOURCE[0]}")")

## Global settings; applied to every model.
# Path to the dataset configuration file in the GraphNeT directory.
dataset_config=$(realpath "${bash_directory}/../../configs/datasets/training_example_data_sqlite.yml")
# Which GPU to use; more information can be obtained with the module nvitop.
gpus=0
# Maximum number of epochs; if used, this greatly affects learning-rate scheduling.
max_epochs=5
# Early-stopping threshold (epochs without improvement).
early_stopping_patience=5
# Number of events per batch.
batch_size=16
# Number of CPU workers to use.
num_workers=2

## Model-dependent settings; applied to each model in sequence.
# Path to the model configuration files in the GraphNeT directory.
model_directory=$(realpath "${bash_directory}/../../configs/models")
# List of model configurations to train.
declare -a model_configs=(
    "${model_directory}/example_direction_reconstruction_model.yml"
    "${model_directory}/example_energy_reconstruction_model.yml"
    "${model_directory}/example_vertex_position_reconstruction_model.yml"
)

# Suffix appended to each model's output directory name.
declare -a suffixes=(
    "direction"
    "energy"
    "position"
)

# Space-separated prediction-name outputs per model. Each entry is expanded
# UNQUOTED below so it word-splits into separate --prediction-names arguments.
declare -a prediction_names=(
    "zenith_pred zenith_kappa_pred azimuth_pred azimuth_kappa_pred"
    "energy_pred"
    "position_x_pred position_y_pred position_z_pred"
)

for i in "${!model_configs[@]}"; do
    # Original used ${prediction_names[i][@]}, which is a bash "bad
    # substitution"; ${prediction_names[$i]} is the correct expansion.
    echo "training iteration ${i} on ${model_configs[$i]} with output variables ${prediction_names[$i]}"
    # Runs in the foreground, so each training finishes before the next starts
    # (no 'wait' needed — there are no background jobs).
    python "${bash_directory}/01_train_model.py" \
        --dataset-config "${dataset_config}" \
        --model-config "${model_configs[$i]}" \
        --gpus "${gpus}" \
        --max-epochs "${max_epochs}" \
        --early-stopping-patience "${early_stopping_patience}" \
        --batch-size "${batch_size}" \
        --num-workers "${num_workers}" \
        --prediction-names ${prediction_names[$i]} \
        --suffix "${suffixes[$i]}"
done
echo "all trainings are done."

0 comments on commit 5b059be

Please sign in to comment.