-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Updates how agatha model uses hparams
- Loading branch information
Showing
4 changed files
with
100 additions
and
132 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,67 +1,17 @@ | ||
from argparse import ArgumentParser, Namespace | ||
from agatha.ml.hypothesis_predictor import hypothesis_predictor as hp | ||
from pytorch_lightning import Trainer | ||
from pytorch_lightning.callbacks import ModelCheckpoint | ||
from pathlib import Path | ||
from copy import deepcopy | ||
from socket import gethostname | ||
|
||
if __name__ == "__main__": | ||
parser = hp.HypothesisPredictor.add_argparse_args(ArgumentParser()) | ||
args = parser.parse_args() | ||
|
||
def train( | ||
training_args:Namespace, | ||
graph_db:Path, | ||
entity_db:Path, | ||
embedding_dir:Path, | ||
verbose:bool, | ||
): | ||
if verbose: | ||
if args.verbose: | ||
print("Started training on:", gethostname()) | ||
assert graph_db.is_file() | ||
assert entity_db.is_file() | ||
assert embedding_dir.is_dir() | ||
trainer = Trainer.from_argparse_args(training_args) | ||
model = hp.HypothesisPredictor(training_args) | ||
model.verbose = verbose | ||
model.configure_paths( | ||
graph_db=graph_db, | ||
entity_db=entity_db, | ||
embedding_dir=embedding_dir, | ||
) | ||
print(args) | ||
|
||
trainer = Trainer.from_argparse_args(args) | ||
model = hp.HypothesisPredictor(args) | ||
model.prepare_for_training() | ||
trainer.fit(model) | ||
|
||
if __name__ == "__main__": | ||
parser = ArgumentParser() | ||
parser = Trainer.add_argparse_args(parser) | ||
parser = hp.HypothesisPredictor.add_argparse_args(parser) | ||
# These arguments will be serialized with the model | ||
training_args = deepcopy(parser.parse_known_args()[0]) | ||
|
||
# These arguments will be forgotten after training is complete | ||
parser.add_argument( | ||
"--graph-db", | ||
type=Path, | ||
help="Location of graph sqlite3 lookup table." | ||
) | ||
parser.add_argument( | ||
"--entity-db", | ||
type=Path, | ||
help="Location of entity sqlite3 lookup table." | ||
) | ||
parser.add_argument( | ||
"--embedding-dir", | ||
type=Path, | ||
help="Location of graph embedding hdf5 files." | ||
) | ||
parser.add_argument("--verbose", type=bool) | ||
all_args = parser.parse_args() | ||
|
||
if all_args.verbose: | ||
print(all_args) | ||
train( | ||
training_args=training_args, | ||
graph_db=all_args.graph_db, | ||
entity_db=all_args.entity_db, | ||
embedding_dir=all_args.embedding_dir, | ||
verbose=all_args.verbose | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from copy import deepcopy | ||
from pathlib import Path | ||
from argparse import Namespace | ||
|
||
def remove_paths_from_namespace(hparams:Namespace): | ||
"""Removes variables from the namespace that ends in _db, _dir, or _path | ||
The model is going to include all hparams in the checkpoint. This is a | ||
problem for path variables that are needed during training, but are not | ||
wanted in the release of the model. For instance, during training we are | ||
going to need to tell the model about the embeddings and helper database | ||
locations, as well as where to save the model. These paths are machine | ||
specific. When we release the model, or even when we start to move files | ||
around, these paths will not be consistent. | ||
Args: | ||
hparams: The result of calling parse_args. | ||
Returns: | ||
A copy of hparams with no variables ending in _db, _dir, or _path. Also | ||
removes any variables of type Path. | ||
""" | ||
|
||
hparams = deepcopy(hparams) | ||
attributes = list(hparams.__dict__.keys()) | ||
for attr in attributes: | ||
if ( | ||
isinstance(getattr(hparams, attr), Path) | ||
or ( | ||
attr.endswith("_db") | ||
or attr.endswith("_path") | ||
or attr.endswith("_dir") | ||
) | ||
): | ||
delattr(hparams, attr) | ||
return hparams |
This file was deleted.
Oops, something went wrong.