Merge pull request graphnet-team#456 from RasmusOrsoe/simplify_examples
Simplify examples
RasmusOrsoe committed Apr 18, 2023
2 parents 61b0e8e + b720b5b commit 068c834
Showing 8 changed files with 199 additions and 35 deletions.
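The recurring change across both training examples is making Weights & Biases (W&B) logging opt-in rather than unconditional. Below is a minimal, self-contained sketch of the resulting pattern; the WandbLogger import path is assumed from pytorch_lightning and is not part of the diff.

import os
from typing import Optional

from pytorch_lightning.loggers import WandbLogger  # assumed import; not shown in the diff


def build_logger(use_wandb: bool) -> Optional[WandbLogger]:
    """Construct a W&B logger only when tracking is explicitly requested."""
    if not use_wandb:
        return None
    wandb_dir = "./wandb/"
    os.makedirs(wandb_dir, exist_ok=True)  # W&B expects the save directory to exist
    return WandbLogger(
        project="example-script",
        entity="graphnet-team",
        save_dir=wandb_dir,
        log_model=True,
    )

Passing the result to Model.fit via logger= then enables tracking only when requested; the examples below inline the same logic behind a --wandb flag.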
33 changes: 20 additions & 13 deletions examples/04_training/01_train_model.py
@@ -19,11 +19,6 @@
from graphnet.utilities.logging import Logger


-# Make sure W&B output directory exists
-WANDB_DIR = "./wandb/"
-os.makedirs(WANDB_DIR, exist_ok=True)
-
-
def main(
    dataset_config_path: str,
    model_config_path: str,
@@ -34,18 +29,23 @@ def main(
    num_workers: int,
    prediction_names: Optional[List[str]],
    suffix: Optional[str] = None,
+    wandb: bool = False,
) -> None:
    """Run example."""
    # Construct Logger
    logger = Logger()

    # Initialise Weights & Biases (W&B) run
-    wandb_logger = WandbLogger(
-        project="example-script",
-        entity="graphnet-team",
-        save_dir=WANDB_DIR,
-        log_model=True,
-    )
+    if wandb:
+        # Make sure W&B output directory exists
+        wandb_dir = "./wandb/"
+        os.makedirs(wandb_dir, exist_ok=True)
+        wandb_logger = WandbLogger(
+            project="example-script",
+            entity="graphnet-team",
+            save_dir=wandb_dir,
+            log_model=True,
+        )

    # Build model
    model_config = ModelConfig.load(model_config_path)
@@ -80,7 +80,7 @@ def main(
    # Log configurations to W&B
    # NB: Only log to W&B on the rank-zero process in case of multi-GPU
    # training.
-    if rank_zero_only.rank == 0:
+    if wandb and rank_zero_only.rank == 0:
        wandb_logger.experiment.config.update(config)
        wandb_logger.experiment.config.update(model_config.as_dict())
        wandb_logger.experiment.config.update(dataset_config.as_dict())
@@ -98,7 +98,7 @@ def main(
        dataloaders["train"],
        dataloaders["validation"],
        callbacks=callbacks,
-        logger=wandb_logger,
+        logger=wandb_logger if wandb else None,
        **config.fit,
    )

Expand Down Expand Up @@ -166,6 +166,12 @@ def main(
        default=None,
    )

+    parser.add_argument(
+        "--wandb",
+        action="store_true",
+        help="If specified, use Weights & Biases to track the experiment.",
+    )
+
    args = parser.parse_args()

    main(
@@ -178,4 +184,5 @@ def main(
        args.num_workers,
        args.prediction_names,
        args.suffix,
+        args.wandb,
    )
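Note how the config upload above is guarded twice: by the new flag and by the process rank, so multi-GPU (DDP) runs do not upload duplicate entries. A sketch of that guard as a standalone helper; log_configs is a hypothetical name, and rank_zero_only is assumed to come from pytorch_lightning.utilities (its import sits in a collapsed part of the diff):

from pytorch_lightning.utilities import rank_zero_only  # assumed import


def log_configs(wandb_logger, config: dict, wandb: bool) -> None:
    """Upload the run configuration once, from the rank-zero process only."""
    # Guarded twice: by the user's opt-in and by the process rank, so that
    # multi-GPU training does not write duplicate W&B config entries.
    if wandb and rank_zero_only.rank == 0:
        wandb_logger.experiment.config.update(config)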
41 changes: 24 additions & 17 deletions examples/04_training/02_train_model_without_configs.py
@@ -25,10 +25,6 @@
features = FEATURES.PROMETHEUS
truth = TRUTH.PROMETHEUS

-# Make sure W&B output directory exists
-WANDB_DIR = "./wandb/"
-os.makedirs(WANDB_DIR, exist_ok=True)
-

def main(
    path: str,
@@ -40,18 +36,23 @@ def main(
    early_stopping_patience: int,
    batch_size: int,
    num_workers: int,
+    wandb: bool = False,
) -> None:
    """Run example."""
    # Construct Logger
    logger = Logger()

    # Initialise Weights & Biases (W&B) run
-    wandb_logger = WandbLogger(
-        project="example-script",
-        entity="graphnet-team",
-        save_dir=WANDB_DIR,
-        log_model=True,
-    )
+    if wandb:
+        # Make sure W&B output directory exists
+        wandb_dir = "./wandb/"
+        os.makedirs(wandb_dir, exist_ok=True)
+        wandb_logger = WandbLogger(
+            project="example-script",
+            entity="graphnet-team",
+            save_dir=wandb_dir,
+            log_model=True,
+        )

    logger.info(f"features: {features}")
    logger.info(f"truth: {truth}")
@@ -72,9 +73,9 @@ def main(

    archive = os.path.join(EXAMPLE_OUTPUT_DIR, "train_model_without_configs")
    run_name = "dynedge_{}_example".format(config["target"])
-
-    # Log configuration to W&B
-    wandb_logger.experiment.config.update(config)
+    if wandb:
+        # Log configuration to W&B
+        wandb_logger.experiment.config.update(config)

    (
        training_dataloader,
@@ -137,17 +138,16 @@ def main(
        training_dataloader,
        validation_dataloader,
        callbacks=callbacks,
-        logger=wandb_logger,
+        logger=wandb_logger if wandb else None,
        **config["fit"],
    )

    # Get predictions
-    prediction_columns = [config["target"] + "_pred"]
-    additional_attributes = [config["target"]]
+    additional_attributes = model.target_labels
+    assert isinstance(additional_attributes, list)  # mypy

    results = model.predict_as_dataframe(
        validation_dataloader,
-        prediction_columns=prediction_columns,
        additional_attributes=additional_attributes + ["event_no"],
    )

@@ -206,6 +206,12 @@ def main(
"num-workers",
)

parser.add_argument(
"--wandb",
action="store_true",
help="If True, Weights & Biases are used to track the experiment.",
)

args = parser.parse_args()

main(
@@ -218,4 +224,5 @@ def main(
        args.early_stopping_patience,
        args.batch_size,
        args.num_workers,
+        args.wandb,
    )
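Besides the W&B flag, this example stops hard-coding truth and prediction column names: model.target_labels supplies the truth columns, and predict_as_dataframe (see standard_model.py below) now defaults the prediction columns itself. A sketch of the simplified call; model and validation_dataloader are assumed to exist, and the label value is illustrative:

# `model` is a trained StandardModel; `validation_dataloader` its val loader.
additional_attributes = model.target_labels  # e.g. ["energy"], illustrative

results = model.predict_as_dataframe(
    validation_dataloader,
    additional_attributes=additional_attributes + ["event_no"],
)
# Columns: one per entry in model.prediction_labels, plus the requested
# truth attributes and the event index.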
File renamed without changes.
2 changes: 1 addition & 1 deletion src/graphnet/models/coarsening.py
@@ -28,7 +28,7 @@
from torch_geometric.utils import degree

# NOTE: From [https://github.com/pyg-team/pytorch_geometric/pull/4903]
-# TODO: Remove once bumping to torch_geometric>=2.1.0
+# TODO: Remove once bumping to torch_geometric>=2.1.0
# See [https://github.com/pyg-team/pytorch_geometric/blob/master/CHANGELOG.md]

46 changes: 45 additions & 1 deletion src/graphnet/models/model.py
@@ -10,6 +10,7 @@
import pandas as pd
from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.callbacks.callback import Callback
+from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers.logger import Logger as LightningLogger
import torch
from torch import Tensor
@@ -18,6 +19,7 @@

from graphnet.utilities.logging import Logger
from graphnet.utilities.config import Configurable, ModelConfig
+from graphnet.training.callbacks import ProgressBar


class Model(Logger, Configurable, LightningModule, ABC):
@@ -88,8 +90,17 @@ def fit(
        **trainer_kwargs: Any,
    ) -> None:
        """Fit `Model` using `pytorch_lightning.Trainer`."""
-        self.train(mode=True)
+        # Checks
+        if callbacks is None:
+            callbacks = self._create_default_callbacks(
+                val_dataloader=val_dataloader,
+            )
+        elif val_dataloader is not None:
+            callbacks = self._add_early_stopping(
+                val_dataloader=val_dataloader, callbacks=callbacks
+            )

+        self.train(mode=True)
        self._construct_trainers(
            max_epochs=max_epochs,
            gpus=gpus,
@@ -110,6 +121,38 @@ def fit(
            self.warning("[ctrl+c] Exiting gracefully.")
            pass

+    def _create_default_callbacks(self, val_dataloader: DataLoader) -> List:
+        callbacks = [ProgressBar()]
+        callbacks = self._add_early_stopping(
+            val_dataloader=val_dataloader, callbacks=callbacks
+        )
+        return callbacks
+
+    def _add_early_stopping(
+        self, val_dataloader: DataLoader, callbacks: List
+    ) -> List:
+        if val_dataloader is None:
+            return callbacks
+        has_early_stopping = False
+        assert isinstance(callbacks, list)
+        for callback in callbacks:
+            if isinstance(callback, EarlyStopping):
+                has_early_stopping = True
+
+        if not has_early_stopping:
+            callbacks.append(
+                EarlyStopping(
+                    monitor="val_loss",
+                    patience=5,
+                )
+            )
+            self.warning_once(
+                "Got validation dataloader but no EarlyStopping callback. An "
+                "EarlyStopping callback has been added automatically with "
+                "patience=5 and monitor = 'val_loss'."
+            )
+        return callbacks
+
    def predict(
        self,
        dataloader: DataLoader,
@@ -178,6 +221,7 @@ def predict_as_dataframe(
"doesn't resample batches; or do not request "
"`additional_attributes`."
)
self.info(f"Column names for predictions are: \n {prediction_columns}")
predictions_torch = self.predict(
dataloader=dataloader,
gpus=gpus,
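With these changes, Model.fit fills in sensible defaults: with no user-supplied callbacks it adds a ProgressBar, and whenever a validation dataloader is present it ensures an EarlyStopping on val_loss with patience 5. A sketch of the resulting equivalence, assuming model, train_dataloader, and val_dataloader already exist:

from pytorch_lightning.callbacks import EarlyStopping
from graphnet.training.callbacks import ProgressBar

# With the new defaults, these two calls should behave the same when the
# user supplies no callbacks of their own:
model.fit(train_dataloader, val_dataloader)
model.fit(
    train_dataloader,
    val_dataloader,
    callbacks=[ProgressBar(), EarlyStopping(monitor="val_loss", patience=5)],
)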
41 changes: 41 additions & 0 deletions src/graphnet/models/standard_model.py
@@ -8,6 +8,7 @@
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch_geometric.data import Data
+import pandas as pd

from graphnet.models.coarsening import Coarsening
from graphnet.utilities.config import save_model_config
@@ -62,6 +63,18 @@ def __init__(
        self._scheduler_kwargs = scheduler_kwargs or dict()
        self._scheduler_config = scheduler_config or dict()

+    @property
+    def target_labels(self) -> List[str]:
+        """Return target labels."""
+        return [label for task in self._tasks for label in task._target_labels]
+
+    @property
+    def prediction_labels(self) -> List[str]:
+        """Return prediction labels."""
+        return [
+            label for task in self._tasks for label in task._prediction_labels
+        ]
+
    def configure_optimizers(self) -> Dict[str, Any]:
        """Configure the model's optimizer(s)."""
        optimizer = self._optimizer_class(
@@ -175,3 +188,31 @@ def predict(
            gpus=gpus,
            distribution_strategy=distribution_strategy,
        )
+
+    def predict_as_dataframe(
+        self,
+        dataloader: DataLoader,
+        prediction_columns: Optional[List[str]] = None,
+        *,
+        node_level: bool = False,
+        additional_attributes: Optional[List[str]] = None,
+        index_column: str = "event_no",
+        gpus: Optional[Union[List[int], int]] = None,
+        distribution_strategy: Optional[str] = None,
+    ) -> pd.DataFrame:
+        """Return predictions for `dataloader` as a DataFrame.
+
+        Include `additional_attributes` as additional columns in the output
+        DataFrame.
+        """
+        if prediction_columns is None:
+            prediction_columns = self.prediction_labels
+        return super().predict_as_dataframe(
+            dataloader=dataloader,
+            prediction_columns=prediction_columns,
+            node_level=node_level,
+            additional_attributes=additional_attributes,
+            index_column=index_column,
+            gpus=gpus,
+            distribution_strategy=distribution_strategy,
+        )
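The two new properties aggregate labels across all of the model's tasks, and the predict_as_dataframe override uses them as the default column names. A usage sketch; the task targets and the "_pred" suffix are illustrative, not guaranteed by this commit:

# With two hypothetical tasks targeting "energy" and "zenith":
#   model.target_labels      -> ["energy", "zenith"]
#   model.prediction_labels  -> ["energy_pred", "zenith_pred"]  (illustrative)
df = model.predict_as_dataframe(
    validation_dataloader,  # assumed to exist
    additional_attributes=model.target_labels + ["event_no"],
)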