Remove inplace modification in Covariance (#1691)
The `Covariance` metric was modifying its `predictions` and `gold_labels` tensors in-place, which led to problems if you called something like:

```
loss = some_function(pred, labels)
covariance(pred, labels)  # mutated pred and labels in-place
loss.backward()           # autograd's saved tensors are now stale
```

PyTorch then complains as follows when `loss.backward()` runs:

```
=============================================== FAILURES ================================================
___________ TestEventFactualityRegression.test_selective_tagger_role_can_train_save_and_load ____________

self = <tests.models.selective_regressor.event_factuality_test.TestEventFactualityRegression testMethod=test_selective_tagger_role_can_train_save_and_load>

    def test_selective_tagger_role_can_train_save_and_load(self):
>       self.ensure_model_can_train_save_and_load(self.param_file)

tests/models/selective_regressor/event_factuality_test.py:20:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
contexteval/common/model_test_case.py:46: in ensure_model_can_train_save_and_load
    model = train_model_from_file(param_file, save_dir)
../../../miniconda3/envs/elmo_eval/lib/python3.6/site-packages/allennlp/commands/train.py:132: in train_model_from_file
    return train_model(params, serialization_dir, file_friendly_logging, recover)
../../../miniconda3/envs/elmo_eval/lib/python3.6/site-packages/allennlp/commands/train.py:320: in train_model
    metrics = trainer.train()
../../../miniconda3/envs/elmo_eval/lib/python3.6/site-packages/allennlp/training/trainer.py:720: in train
    train_metrics = self._train_epoch(epoch)
../../../miniconda3/envs/elmo_eval/lib/python3.6/site-packages/allennlp/training/trainer.py:487: in _train_epoch
    loss.backward()
../../../miniconda3/envs/elmo_eval/lib/python3.6/site-packages/torch/tensor.py:93: in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

tensors = (tensor(1.9771, grad_fn=<AddBackward>),), grad_tensors = (tensor(1.),), retain_graph = False
create_graph = False, grad_variables = None

    def backward(tensors, grad_tensors=None, retain_graph=None, create_graph=False, grad_variables=None):
        r"""Computes the sum of gradients of given tensors w.r.t. graph leaves.

        The graph is differentiated using the chain rule. If any of ``tensors``
        are non-scalar (i.e. their data has more than one element) and require
        gradient, the function additionally requires specifying ``grad_tensors``.
        It should be a sequence of matching length, that contains gradient of
        the differentiated function w.r.t. corresponding tensors (``None`` is an
        acceptable value for all tensors that don't need gradient tensors).

        This function accumulates gradients in the leaves - you might need to zero
        them before calling it.

        Arguments:
            tensors (sequence of Tensor): Tensors of which the derivative will be
                computed.
            grad_tensors (sequence of (Tensor or None)): Gradients w.r.t.
                each element of corresponding tensors. None values can be specified for
                scalar Tensors or ones that don't require grad. If a None value would
                be acceptable for all grad_tensors, then this argument is optional.
            retain_graph (bool, optional): If ``False``, the graph used to compute the grad
                will be freed. Note that in nearly all cases setting this option to ``True``
                is not needed and often can be worked around in a much more efficient
                way. Defaults to the value of ``create_graph``.
            create_graph (bool, optional): If ``True``, graph of the derivative will
                be constructed, allowing to compute higher order derivative products.
                Defaults to ``False``.
        """
        if grad_variables is not None:
            warnings.warn("'grad_variables' is deprecated. Use 'grad_tensors' instead.")
            if grad_tensors is None:
                grad_tensors = grad_variables
            else:
                raise RuntimeError("'grad_tensors' and 'grad_variables' (deprecated) "
                                   "arguments both passed to backward(). Please only "
                                   "use 'grad_tensors'.")

        tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tuple(tensors)

        if grad_tensors is None:
            grad_tensors = [None] * len(tensors)
        elif isinstance(grad_tensors, torch.Tensor):
            grad_tensors = [grad_tensors]
        else:
            grad_tensors = list(grad_tensors)

        grad_tensors = _make_grads(tensors, grad_tensors)
        if retain_graph is None:
            retain_graph = create_graph

        Variable._execution_engine.run_backward(
            tensors, grad_tensors, retain_graph, create_graph,
>           allow_unreachable=True)  # allow_unreachable flag
E       RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

../../../miniconda3/envs/elmo_eval/lib/python3.6/site-packages/torch/autograd/__init__.py:90: RuntimeError
----------------------------------------- Captured stdout call ------------------------------------------
<class 'contexteval.contextualizers.contextualizer.Contextualizer'>
<class 'contexteval.contextualizers.contextualizer.Contextualizer'>
----------------------------------------- Captured stderr call ------------------------------------------
15it [00:00, 85.22it/s]
100%|██████████| 15/15 [00:00<00:00, 258907.65it/s]
15it [00:00, 85.53it/s]
15it [00:00, 80.51it/s]
30it [00:00, 319363.25it/s]
  0%|          | 0/1 [00:00<?, ?it/s]
===Flaky Test Report===

test_batch_classifications_are_consistent passed 1 out of the required 1 times. Success!

===End Flaky Test Report===
========================== 1 failed, 1 passed, 1160 deselected in 6.36 seconds ==========================
```
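For context, the failure does not depend on AllenNLP at all. Here is a minimal sketch in plain PyTorch (hypothetical tensors, not the metric's real inputs) that reproduces the same `RuntimeError` by mutating a tensor that autograd saved for the backward pass:

```
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
pred = x * 2               # non-leaf tensor tracked by autograd
loss = (pred ** 2).sum()   # backward of ** needs pred's original values

# Simulate what the old Covariance code did: mask the tensor in-place.
mask = torch.tensor([1.0, 1.0, 0.0])
pred *= mask               # bumps pred's internal version counter

loss.backward()            # RuntimeError: one of the variables needed for
                           # gradient computation has been modified by an
                           # inplace operation
```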
nelson-liu committed on Aug 29, 2018 (parent a0506a7, commit 4b6f8d1)
Showing 1 changed file with 2 additions and 2 deletions.
allennlp/training/metrics/covariance.py:

```
@@ -54,8 +54,8 @@ def __call__(self,
 
         if mask is not None:
             mask = mask.view(-1)
-            predictions *= mask
-            gold_labels *= mask
+            predictions = predictions * mask
+            gold_labels = gold_labels * mask
             num_batch_items = torch.sum(mask).item()
         else:
             num_batch_items = gold_labels.numel()
```
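The out-of-place form allocates a fresh tensor instead of rewriting the caller's one, so the graph that `loss` was built from stays intact. A quick way to see the difference (a sketch; `_version` is PyTorch's internal counter of in-place modifications):

```
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x * 2
mask = torch.tensor([1.0, 1.0, 0.0])

masked = y * mask          # out-of-place: new tensor, y untouched
print(y._version)          # 0 -- y is still safe to backprop through

y *= mask                  # in-place: rewrites y's storage
print(y._version)          # 1 -- any op that saved y will now error
```

The out-of-place multiply costs one extra allocation per call, but a metric should never mutate tensors it did not create, so that trade-off is the right one here.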
