Updates Agatha to fix checkpoint bug and on_epoch_end failure
JSybrandt committed May 1, 2020
1 parent 63d91e7 commit 0f6e072
Showing 1 changed file with 41 additions and 17 deletions.
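The change reworks the model's step/epoch interface: _step now returns a (loss, metrics) tuple instead of a single dict, training_step repackages that tuple for logging, validation_step exposes a top-level val_loss entry, and epoch-end aggregation averages tensors directly. As a rough illustration of why a flat val_loss key matters for the checkpoint fix (the helper name and the monitor="val_loss" convention implied here are assumptions, not taken from this commit), the following sketch mirrors the new validation_step contract:

# Hypothetical sketch, not part of this commit: mirrors the new
# validation_step contract so that a callback monitoring "val_loss"
# (e.g. a model checkpoint) can find a single scalar under that key.
import torch

def validation_step_like(loss: torch.Tensor, metrics: dict) -> dict:
  # Prefix each extra metric with "val_" and expose the loss as "val_loss".
  val_metrics = {f"val_{k}": v for k, v in metrics.items()}
  val_metrics["val_loss"] = loss
  return val_metrics

out = validation_step_like(torch.tensor(0.25), {"accuracy": torch.tensor(0.9)})
assert set(out) == {"val_loss", "val_accuracy"}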
58 changes: 41 additions & 17 deletions agatha/ml/hypothesis_predictor/hypothesis_predictor.py
@@ -176,7 +176,27 @@ def forward(self, predicate_embeddings:torch.FloatTensor)->torch.FloatTensor:
   def get_device(self):
     return next(self.parameters()).device
 
-  def _step(self, positive_predicates:List[str])->Dict[str, Any]:
+  def _step(
+      self,
+      positive_predicates:List[str]
+  )->Tuple[torch.Tensor, Dict[str, Any]]:
+    """ Performs a forward pass of the model during training.
+
+    Used in both training_step and validation_step, this function accepts a set
+    of predicate names and performs a forward pass of the hypothesis generation
+    training routine. This involves generating negative samples for each
+    positive example and evaluating metrics that quantify the difference
+    between the two.
+
+    Args:
+      positive_predicates: A list of predicate names, each in the form,
+        p:subj:verb:obj
+
+    Returns:
+      The first element is the loss tensor, used for back propagation. The
+      second element is a dict containing all extra metrics.
+
+    """
     pos, negs = self.predicate_batch_generator.generate(positive_predicates)
     positive_predictions = self.forward(
         predicate_util
@@ -197,40 +217,44 @@ def _step(self, positive_predicates:List[str])->Dict[str, Any]:
             positive_predictions.new_ones(len(positive_predictions))
         )
     )
-    return dict(
-        loss=sum(partial_losses)
+    loss = torch.mean(torch.stack(partial_losses))
+    return (
+        loss,
+        dict(
+        )
     )

   def training_step(
       self,
       positive_predictions:List[str],
       batch_idx:int
   )->Dict[str, Any]:
-    return self._step(positive_predictions)
+    loss, metrics = self._step(positive_predictions)
+    return dict(
+        loss=loss,
+        progress_bar=metrics,
+        log=metrics
+    )

   def validation_step(
       self,
       positive_predictions:List[str],
       batch_idx:int
   )->Dict[str, Any]:
-    return self._step(positive_predictions)
+    loss, metrics = self._step(positive_predictions)
+    val_metrics = {f"val_{k}": v for k, v in metrics.items()}
+    val_metrics["val_loss"] = loss
+    return val_metrics

   def _on_epoch_end(
       self,
       outputs:List[Dict[str,torch.Tensor]]
   )->Dict[str, Dict[str,torch.Tensor]]:
-    metric2values = defaultdict(list)
-    for output in outputs:
-      for k, v in output.items():
-        metric2values[k].append(v)
-    metric2value = {}
-    for k in metric2values:
-      if len(metric2values[k]) > 0:
-        metric2values[k] = sum(metric2values[k]) / len(metric2values[k])
-    return dict(
-        log=metric2values,
-        progress_bar=metric2values
-    )
+    keys = outputs[0].keys()
+    return {
+        key: torch.mean(torch.stack([o[key] for o in outputs]))
+        for key in keys
+    }

   def validation_epoch_end(self, outputs:List[Dict[str,torch.Tensor]]):
     return self._on_epoch_end(outputs)
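
For reference, the averaging pattern introduced in _on_epoch_end can be exercised on its own. The sketch below uses made-up metric values; only the torch.stack / torch.mean aggregation comes from the diff above.

# Standalone sketch of the epoch-end aggregation above; metric values are
# fabricated for illustration only.
import torch

outputs = [
  {"val_loss": torch.tensor(0.9), "val_accuracy": torch.tensor(0.4)},
  {"val_loss": torch.tensor(0.7), "val_accuracy": torch.tensor(0.6)},
]
keys = outputs[0].keys()
epoch_metrics = {
  key: torch.mean(torch.stack([o[key] for o in outputs]))
  for key in keys
}
# epoch_metrics == {"val_loss": tensor(0.8000), "val_accuracy": tensor(0.5000)}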
