
Predictor.get_gradients fails when an embedding sets trainable = False #3679

Open · michaeljneely opened this issue Jan 26, 2020 · 6 comments

@michaeljneely commented Jan 26, 2020

Describe the bug

Calling Predictor.get_gradients() returns an empty dictionary

To Reproduce
I am replicating the binary sentiment classification task described in the paper 'Attention is not Explanation' (Jain and Wallace, 2019 - https://arxiv.org/pdf/1902.10186.pdf).

My first experiment is on the Stanford Sentiment Treebank dataset. I need to measure the correlation between the attention weights and the gradients of the loss with respect to the model inputs. I define the following experiment file:

{
  "dataset_reader": {
    "type": "sst_tokens",
    "granularity": "2-class"
  },
  "train_data_path": std.join("/", [std.extVar("PWD"), "ane_research/datasets/SST/train.txt"]),
  "validation_data_path": std.join("/", [std.extVar("PWD"), "ane_research/datasets/SST/dev.txt"]),
  "test_data_path": std.join("/", [std.extVar("PWD"), "ane_research/datasets/SST/test.txt"]),
  "model": {
    "type": "jain_wallace_attention_binary_classifier",
    "word_embeddings": {
      "token_embedders": {
        "tokens": {
          "type": "embedding",
          "pretrained_file": "https://dl.fbaipublicfiles.com/fasttext/vectors-english/wiki-news-300d-1M.vec.zip",
          "embedding_dim": 300,
          "trainable": false
        }
      }
    },
    "encoder": {
      "type": "lstm",
      "bidirectional": true,
      "input_size": 300,
      "hidden_size": 128,
      "num_layers": 1,
    },
    "decoder": {
      "input_dim": 256,
      "num_layers": 1,
      "hidden_dims": 1,
      "activations": ["linear"]
    }
  },
  "iterator": {
    "type": "bucket",
    "sorting_keys": [['tokens', 'num_tokens']],
    "batch_size": 64
  },
  "trainer": {
    "num_epochs": 40,
    "patience": 10,
    "cuda_device": -1,
    "validation_metric": "+auc",
    "optimizer": {
      "type": "adam",
      "weight_decay": 1e-5,
      "amsgrad": true
    }
  }
}

And the following model:

from typing import Dict

import numpy as np
import torch
from overrides import overrides

from allennlp.data import Vocabulary
from allennlp.models import Model
from allennlp.modules import FeedForward, Seq2SeqEncoder, TextFieldEmbedder
from allennlp.nn import util
from allennlp.training.metrics import Auc, CategoricalAccuracy, F1Measure
# AdditiveAttention is presumably the project's own additive (tanh) attention module; its import is not shown in the issue.

@Model.register('jain_wallace_attention_binary_classifier')
class JWAED(Model):
  '''
    Encoder/decoder with attention model for binary classification as described in 'Attention is Not
    Explanation' (Jain and Wallace, 2019 - https://arxiv.org/pdf/1902.10186.pdf)
  '''
  def __init__(self, vocab: Vocabulary, word_embeddings: TextFieldEmbedder, encoder: Seq2SeqEncoder, 
               decoder: FeedForward):
    super().__init__(vocab)
    self.word_embeddings = word_embeddings
    self.num_classes = self.vocab.get_vocab_size('labels')
    self.encoder = encoder
    self.attention = AdditiveAttention(encoder.get_output_dim())
    self.decoder = decoder
    self.metrics = {
      'accuracy': CategoricalAccuracy(),
      'f1_measure': F1Measure(positive_label=1),
      'auc': Auc(positive_label=1)
    }
    self.loss = torch.nn.BCEWithLogitsLoss()

  @overrides
  def forward(self, tokens: Dict[str, torch.LongTensor], label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:

    output_dict = {}

    # encode
    tokens_mask = util.get_text_field_mask(tokens)
    embedded_tokens = self.word_embeddings(tokens)
    encoded_tokens = self.encoder(embedded_tokens, tokens_mask)

    # compute attention
    attention = self.attention(encoded_tokens, tokens_mask)
    context = (attention.unsqueeze(-1) * encoded_tokens).sum(1)

    # decode
    logits = self.decoder(context)
    positive_probability = torch.sigmoid(logits)
    class_probabilities = torch.cat((1 - positive_probability, positive_probability), dim=1)
    output_dict['class_probabilities'] = class_probabilities

    if label is not None:
      loss = self.loss(logits, label.unsqueeze(-1).float())
      output_dict['loss'] = loss
      self.metrics['accuracy'](class_probabilities, label)
      self.metrics['f1_measure'](class_probabilities, label)
      self.metrics['auc'](positive_probability.squeeze(-1), label)

    return output_dict

  @overrides
  def get_metrics(self, reset: bool = False) -> Dict[str, float]:
    precision, recall, f1_measure = self.metrics['f1_measure'].get_metric(reset=reset)
    return {
      # f1 get_metric returns (precision, recall, f1)
      'positive/precision': precision,
      'positive/recall': recall,
      'f1_measure': f1_measure,
      'accuracy': self.metrics['accuracy'].get_metric(reset=reset),
      'auc': self.metrics['auc'].get_metric(reset=reset)
    }

  @overrides
  def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    '''
    Does a simple argmax over the class probabilities, converts indices to string labels, and
    adds a ``'label'`` key to the dictionary with the result.
    '''
    class_probabilities = output_dict['class_probabilities']
    predictions = class_probabilities.cpu().data.numpy()
    argmax_indices = np.argmax(predictions, axis=-1)
    labels = [self.vocab.get_token_from_index(x, namespace='labels') for x in argmax_indices]
    output_dict['label'] = labels
    return output_dict

Once the model has finished training, I attempt to extract the gradients using a predictor:

    archive = load_archive(model_path)
    predictor = Predictor.from_archive(archive, 'text_classifier')
    validation_instances = predictor._dataset_reader.read('/.../datasets/SST/test.txt')
    gradients, outputs = predictor.get_gradients(validation_instances)
    print(gradients)
    # prints {}

Expected behavior
The gradient dictionary should contain the gradients of the loss with respect to the input token embeddings.

System (please complete the following information):

  • OS: OSX
  • Python version: 3.7.0
  • AllenNLP version: v0.9.0
  • PyTorch version: 1.4.0

Additional context
As per Jain and Wallace, I use the pretrained fasttext embeddings. Attention is computed via the 'Additive variant' (tanh).

@matt-gardner (Member) commented Jan 26, 2020

You need to be sure that you actually have labels. Do your test instances have labels in them? Our demos use the interpret_from_json and attack_from_json methods, instead of going straight to get_gradients, because these methods first call json_to_labeled_instances, which runs the model forward to get predictions, then converts those predictions to labels, which are used when calling get_gradients.

If your instances have labels, this should work, because this is the same predictor that we're using for our demo. As a sanity check, you could see if you can get your code to work with our demo model: https://github.com/allenai/allennlp-demo/blob/9580a2d97911e794c7a936262bba4dc41b30bbcf/models.json#L95-L99. If it works, and you're sure that your instances have labels, then my next guess is that the predictor's predictions_to_labeled_instances function is looking for a different key in the output than your model is providing. Actually, looking at your model code, maybe you could jump straight to trying this. The predictor is looking for a key called probs, but you're not providing one. Note that if you jump straight here, you'll be interpreting the gold labels at test time, though, not the model's predictions. I'm not sure that getting gradient-based interpretations for the gold labels is all that meaningful, if it's not what the model actually predicted. So you probably want to directly call predictions_to_labeled_instances after running your model forward, before calling get_gradients.

I'm also not sure what batching is available for any of this. It might work, but I'm not sure - I only ever used it in our demo, one instance at a time. If you run into issues with batching things, let us know.
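
A minimal sketch of the flow described above, assuming the AllenNLP 0.9 Predictor API (the archive path is a placeholder; the sentence is the test example used later in this thread):

from allennlp.models.archival import load_archive
from allennlp.predictors import Predictor

archive = load_archive('/path/to/model.tar.gz')  # placeholder path
predictor = Predictor.from_archive(archive, 'text_classifier')

# json_to_labeled_instances runs the model forward and attaches the *predicted* label,
# so get_gradients then computes gradients for the model's prediction, not the gold label.
inputs = {'sentence': 'If you sometimes like to go to the movies to have fun , Wasabi is a good place to start .'}
labeled_instances = predictor.json_to_labeled_instances(inputs)
gradients, outputs = predictor.get_gradients(labeled_instances)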

@michaeljneely (Author) commented Jan 26, 2020

Hi @matt-gardner. I tried all of your code modification suggestions, including manually turning my validation instances into JSON and trying predictions that way. The gradient dictionary is always empty!

My succinct test (the model predicts the wrong label):

archive = load_archive(model_path)
predictor = Predictor.from_archive(archive, 'text_classifier')
validation_instances = predictor._dataset_reader.read('/.../datasets/SST/test.txt')
test_instance = validation_instances[0]
print(test_instance)
# Instance with fields:
#  	 tokens: TextField of length 21 with text:
#  		[If, you, sometimes, like, to, go, to, the, movies, to, have, fun, ,, Wasabi, is, a, good, place,
#		to, start, .]
# 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
# 	 label: LabelField with label: 1 in namespace: 'labels'.'
prediction = predictor.predict_instance(test_instance)
# {'probs': [0.56487637758255, 0.43512362241744995], 'loss': 0.5711482763290405, 'label': '0'}
predicted_labeled_instance = predictor.predictions_to_labeled_instances(test_instance, prediction)
print(predicted_labeled_instance)
# Instance with fields:
#  	 tokens: TextField of length 21 with text:
#  		[If, you, sometimes, like, to, go, to, the, movies, to, have, fun, ,, Wasabi, is, a, good, place,
#		to, start, .]
# 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
# 	 label: LabelField with label: 0 in namespace: 'labels'.'
gradients, outputs = predictor.get_gradients(predicted_labeled_instance)
print(gradients)
# {}
print(outputs)
#  {'probs': tensor([[0.5649, 0.4351]], grad_fn=<CatBackward>), 'loss': tensor(0.5711, grad_fn=<BinaryCrossEntropyWithLogitsBackward>), 'label': ['0']} 

@michaeljneely (Author) commented Jan 26, 2020

I put some print statements in Predictor.get_gradients. The variable embedding_gradients is always an empty list. Printing the embedding_layer in the _register_embedding_gradient_hooks function yields Embedding(). Is this expected?

@matt-gardner (Member) commented Jan 26, 2020

Can you confirm that you can get this to work locally with our demo model?

@michaeljneely (Author) commented Jan 26, 2020

@matt-gardner I ran the demo sentiment classifier locally and found my problem. Setting trainable: false in the embedding config 'freezes' the parameter (it sets requires_grad = False on the embedding weight, as per the PyTorch documentation), so the backward pass never reaches the embedding and no gradient is ever recorded for it.
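
A small PyTorch illustration of the effect (the Linear layer is just a stand-in for the trainable encoder/decoder):

import torch

emb = torch.nn.Embedding(10, 4)
emb.weight.requires_grad = False          # what "trainable": false amounts to
head = torch.nn.Linear(4, 1)              # stand-in for the rest of the (trainable) model

captured = []
emb.register_backward_hook(lambda module, grad_in, grad_out: captured.append(grad_out))

out = emb(torch.tensor([[1, 2, 3]]))
print(out.requires_grad)                  # False: the embedding output is a constant to autograd
loss = head(out).sum()
loss.backward()                           # gradients reach head, but stop before the embedding
print(captured)                           # [] -- the same symptom as the empty gradient dict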

I had copied the format of the embedding configuration from the paper prediction tutorial without realizing the ramifications.

Hopefully if anyone else has this problem they can learn from my mistake.

Thank you for your help!

@matt-gardner (Member) commented Jan 26, 2020

Huh, we should ideally be robust to that setting, as we typically only get gradients like this after a model is trained, anyway, when all of the parameters are "frozen". I'm going to re-open this and rename it. I have higher-priority stuff to work on for the next while, but I'll label this as contributions welcome in case anyone wants to pick it up. This issue should give a pretty easy minimal test case - just create a small embedding layer with trainable = False.

I believe a potential solution here would be to just force all parameters to require gradients at the top of predictor.get_gradients(). You probably want to reset them to their original value at the end of the method, though, just in case, so we don't have any side effects from the method.
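
A rough sketch of that suggestion (not the actual AllenNLP implementation; the existing hook registration and forward/backward logic is elided):

def get_gradients(self, instances):
    # Temporarily force every parameter to require gradients so the backward pass
    # reaches the (possibly frozen) embedding layer.
    original_requires_grad = {name: p.requires_grad for name, p in self._model.named_parameters()}
    for p in self._model.parameters():
        p.requires_grad = True
    try:
        ...  # existing logic: register embedding hooks, run forward, backward on the loss, collect grads
    finally:
        # Restore the original flags so the method has no side effects.
        for name, p in self._model.named_parameters():
            p.requires_grad = original_requires_grad[name]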

@matt-gardner matt-gardner reopened this Jan 26, 2020
@matt-gardner matt-gardner changed the title Predictor.get_gradients returns empty dictionary Predictor.get_gradients fails when an embedding sets trainable = False Jan 26, 2020