Updates how agatha model uses hparams
JSybrandt committed May 4, 2020
1 parent 942fd92 commit dd998f8
Showing 4 changed files with 100 additions and 132 deletions.
66 changes: 8 additions & 58 deletions agatha/ml/hypothesis_predictor/__main__.py
@@ -1,67 +1,17 @@
 from argparse import ArgumentParser, Namespace
 from agatha.ml.hypothesis_predictor import hypothesis_predictor as hp
 from pytorch_lightning import Trainer
-from pytorch_lightning.callbacks import ModelCheckpoint
-from pathlib import Path
-from copy import deepcopy
-from socket import gethostname
 
+if __name__ == "__main__":
+  parser = hp.HypothesisPredictor.add_argparse_args(ArgumentParser())
+  args = parser.parse_args()
+
-def train(
-    training_args:Namespace,
-    graph_db:Path,
-    entity_db:Path,
-    embedding_dir:Path,
-    verbose:bool,
-):
-  if verbose:
+  if args.verbose:
-    print("Started training on:", gethostname())
-  assert graph_db.is_file()
-  assert entity_db.is_file()
-  assert embedding_dir.is_dir()
-  trainer = Trainer.from_argparse_args(training_args)
-  model = hp.HypothesisPredictor(training_args)
-  model.verbose = verbose
-  model.configure_paths(
-      graph_db=graph_db,
-      entity_db=entity_db,
-      embedding_dir=embedding_dir,
-  )
+    print(args)
 
+  trainer = Trainer.from_argparse_args(args)
+  model = hp.HypothesisPredictor(args)
   model.prepare_for_training()
   trainer.fit(model)
 
-if __name__ == "__main__":
-  parser = ArgumentParser()
-  parser = Trainer.add_argparse_args(parser)
-  parser = hp.HypothesisPredictor.add_argparse_args(parser)
-  # These arguments will be serialized with the model
-  training_args = deepcopy(parser.parse_known_args()[0])
-
-  # These arguments will be forgotten after training is complete
-  parser.add_argument(
-    "--graph-db",
-    type=Path,
-    help="Location of graph sqlite3 lookup table."
-  )
-  parser.add_argument(
-    "--entity-db",
-    type=Path,
-    help="Location of entity sqlite3 lookup table."
-  )
-  parser.add_argument(
-    "--embedding-dir",
-    type=Path,
-    help="Location of graph embedding hdf5 files."
-  )
-  parser.add_argument("--verbose", type=bool)
-  all_args = parser.parse_args()
-
-  if all_args.verbose:
-    print(all_args)
-  train(
-    training_args=training_args,
-    graph_db=all_args.graph_db,
-    entity_db=all_args.entity_db,
-    embedding_dir=all_args.embedding_dir,
-    verbose=all_args.verbose
-  )
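
With the path flags folded into `add_argparse_args`, a single parser now drives both the `Trainer` and the model. A minimal sketch of the new invocation, with invented flag values (a real run would also supply the remaining model flags registered in `add_argparse_args` below):

```python
from argparse import ArgumentParser
from agatha.ml.hypothesis_predictor import hypothesis_predictor as hp

# One parser for Trainer and model flags alike. All values here are
# hypothetical placeholders, not paths from the actual project.
parser = hp.HypothesisPredictor.add_argparse_args(ArgumentParser())
args = parser.parse_args([
  "--graph-db", "/data/graph.sqlite3",
  "--entity-db", "/data/entities.sqlite3",
  "--embedding-dir", "/data/embeddings",
  "--verbose",
])
# From here the flow matches the new __main__:
#   trainer = Trainer.from_argparse_args(args)
#   model = hp.HypothesisPredictor(args)
#   model.prepare_for_training()
#   trainer.fit(model)
```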
66 changes: 55 additions & 11 deletions agatha/ml/hypothesis_predictor/hypothesis_predictor.py
@@ -11,14 +11,31 @@
 from agatha.ml.hypothesis_predictor import predicate_util
 from pathlib import Path
 import os
+from pytorch_lightning import Trainer
+from agatha.ml.util import hparam_util
 
 class HypothesisPredictor(pl.LightningModule):
   def __init__(self, hparams:Namespace):
     super(HypothesisPredictor, self).__init__()
-    self.hparams = hparams
+    # If the hparams have been setup with paths, typical for training
+    if (
+        hasattr(hparams, "graph_db")
+        and hasattr(hparams, "entity_db")
+        and hasattr(hparams, "embedding_dir")
+    ):
+      self.configure_paths(
+          graph_db=hparams.graph_db,
+          entity_db=hparams.entity_db,
+          embedding_dir=hparams.embedding_dir,
+      )
+    else: # Otherwise, the user will need to call configure_paths themselves
+      self.graph = None
+      self.embeddings = None
+    # Clear paths, don't want to serialize them later
+    self.hparams = hparam_util.remove_paths_from_namespace(hparams)
 
-    self.verbose = False
-    self.distributed = False
+    # Set when init_process_group is called
+    self._distributed = False
 
     # Layers
     ## Graph Emb Input
@@ -44,9 +61,6 @@ def __init__(self, hparams:Namespace):
     # Loss Fn
     self.loss_fn = torch.nn.MarginRankingLoss(margin=self.hparams.margin)
 
-    # Helper data, set by configure_paths
-    self.embeddings = None
-    self.graph = None
     # Helper data, set by prepare_for_training
     self.training_predicates = None
     self.validation_predicates = None
@@ -55,9 +69,12 @@
     self.predicate_batch_generator = None
 
   def _vprint(self, *args, **kwargs):
-    if self.verbose:
+    if self.hparams.verbose:
       print(*args, **kwargs)
 
+  def set_verbose(self, val:bool)->None:
+    self.hparams.verbose = val
+
   def configure_paths(
       self,
       graph_db:Path,
@@ -138,7 +155,7 @@ def _configure_dataloader(
   )->torch.utils.data.DataLoader:
     self.preload()
     sampler = None
-    if self.distributed:
+    if self._distributed:
       shuffle = False
       sampler=torch.utils.data.distributed.DistributedSampler(predicate_dataset)
     return torch.utils.data.DataLoader(
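
For context on the `_distributed` gate above: `DistributedSampler` shards the dataset across processes and handles shuffling itself, so `shuffle=True` must be dropped whenever a sampler is supplied. A standalone sketch of the same pattern, using a toy dataset rather than Agatha's predicate data:

```python
import torch

# Toy map-style dataset standing in for the predicate dataset.
dataset = torch.utils.data.TensorDataset(torch.arange(10).unsqueeze(1))

shuffle = True
sampler = None
# Mirrors the gate above: only use a DistributedSampler when a process
# group is actually running; DataLoader forbids shuffle=True with a sampler.
if torch.distributed.is_available() and torch.distributed.is_initialized():
  shuffle = False
  sampler = torch.utils.data.distributed.DistributedSampler(dataset)

loader = torch.utils.data.DataLoader(
    dataset, batch_size=2, shuffle=shuffle, sampler=sampler,
)
```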
@@ -294,12 +311,37 @@ def optimizer_step(
 
   @staticmethod
   def add_argparse_args(parser:ArgumentParser)->ArgumentParser:
+    """Used to add all model parameters to argparse
+
+    This static function allows for the easy configuration of argparse for the
+    construction and training of the Agatha deep learning model. Example usage:
+
+    ```python3
+    parser = HypothesisPredictor.add_argparse_args(ArgumentParser())
+    args = parser.parse_args()
+    trainer = Trainer.from_argparse_args(args)
+    model = HypothesisPredictor(args)
+    ```
+
+    Note, many of the arguments, such as the location of training databases or
+    the paths used to save the model during training, will _NOT_ be serialized
+    with the model. These can be configured either from `args` directly after
+    parsing, or through `configure_paths` after training.
+
+    Args:
+      parser: An argparse parser to be configured. Will receive all necessary
+        training and model parameter flags.
+
+    Returns:
+      A reference to the input argument parser.
     """
-    These arguments will be serialized along with the model after training.
-    Path-specific arguments will be passed in separately.
-    """
+    parser = Trainer.add_argparse_args(parser)
     parser.add_argument("--dataloader-workers", type=int)
     parser.add_argument("--dim", type=int)
+    parser.add_argument("--embedding-dir", type=Path)
+    parser.add_argument("--entity-db", type=Path)
+    parser.add_argument("--graph-db", type=Path)
     parser.add_argument("--lr", type=float)
     parser.add_argument("--margin", type=float)
     parser.add_argument("--negative-scramble-rate", type=int)
@@ -311,6 +353,7 @@ def add_argparse_args(parser:ArgumentParser)->ArgumentParser:
     parser.add_argument("--transformer-heads", type=int)
     parser.add_argument("--transformer-layers", type=int)
     parser.add_argument("--validation-fraction", type=float)
+    parser.add_argument("--verbose", action="store_true")
     parser.add_argument("--warmup-steps", type=int)
     parser.add_argument("--weight-decay", type=float)
     return parser
@@ -354,6 +397,7 @@ def init_ddp_connection(
         torch_backend,
         os.environ["MASTER_ADDR"], proc_rank
     )
+    self._distributed = True
     torch.distributed.init_process_group(
         torch_backend,
         rank=proc_rank,
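
Since the path hparams are stripped before checkpointing, a restored model has no graph or embedding handles until it is re-pointed at local data. A hedged sketch of the post-training flow (checkpoint and data locations are placeholders):

```python
from pathlib import Path
from agatha.ml.hypothesis_predictor import hypothesis_predictor as hp

# load_from_checkpoint is the standard pytorch-lightning restore hook;
# the checkpoint path here is a placeholder.
model = hp.HypothesisPredictor.load_from_checkpoint("model.ckpt")

# Re-attach machine-specific resources that were deliberately not serialized.
model.configure_paths(
    graph_db=Path("/data/graph.sqlite3"),
    entity_db=Path("/data/entities.sqlite3"),
    embedding_dir=Path("/data/embeddings"),
)
model.set_verbose(True)
```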
37 changes: 37 additions & 0 deletions agatha/ml/util/hparam_util.py
@@ -0,0 +1,37 @@
+from copy import deepcopy
+from pathlib import Path
+from argparse import Namespace
+
+def remove_paths_from_namespace(hparams:Namespace):
+  """Removes variables from the namespace that end in _db, _dir, or _path
+
+  The model is going to include all hparams in the checkpoint. This is a
+  problem for path variables that are needed during training, but are not
+  wanted in the release of the model. For instance, during training we are
+  going to need to tell the model about the embeddings and helper database
+  locations, as well as where to save the model. These paths are machine
+  specific. When we release the model, or even when we start to move files
+  around, these paths will not be consistent.
+
+  Args:
+    hparams: The result of calling parse_args.
+
+  Returns:
+    A copy of hparams with no variables ending in _db, _dir, or _path. Also
+    removes any variables of type Path.
+
+  """
+  hparams = deepcopy(hparams)
+  attributes = list(hparams.__dict__.keys())
+  for attr in attributes:
+    if (
+        isinstance(getattr(hparams, attr), Path)
+        or (
+          attr.endswith("_db")
+          or attr.endswith("_path")
+          or attr.endswith("_dir")
+        )
+    ):
+      delattr(hparams, attr)
+  return hparams
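
A quick sketch of the helper's effect on a parsed namespace (flag values invented): anything typed as `Path` or named `*_db`, `*_path`, or `*_dir` is dropped, and the original namespace is left untouched by the `deepcopy`.

```python
from argparse import Namespace
from pathlib import Path
from agatha.ml.util import hparam_util

hparams = Namespace(
    lr=0.02,                                # survives
    margin=0.1,                             # survives
    graph_db=Path("/data/graph.sqlite3"),   # dropped: Path instance and *_db
    model_dir="checkpoints",                # dropped: *_dir
)
clean = hparam_util.remove_paths_from_namespace(hparams)
print(clean)    # Namespace(lr=0.02, margin=0.1)
print(hparams)  # unchanged, thanks to the deepcopy
```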
63 changes: 0 additions & 63 deletions scripts/train_hypothesis_predictor.sh

This file was deleted.
