Skip to content

Commit

Permalink
Adds AblationFlip.is_compatible() implementation.
Browse files Browse the repository at this point in the history
Isolates unit tests. Separates integration tests.

PiperOrigin-RevId: 481989488
  • Loading branch information
RyanMullins authored and LIT team committed Oct 18, 2022
1 parent be51efd commit db94849
Show file tree
Hide file tree
Showing 3 changed files with 349 additions and 263 deletions.
22 changes: 15 additions & 7 deletions lit_nlp/components/ablation_flip.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,15 @@
import collections
import copy
import itertools
from typing import Iterator, List, Optional, Text, Tuple
from typing import Iterator, Optional

from absl import logging
from lit_nlp.api import components as lit_components
from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import model as lit_model
from lit_nlp.api import types
from lit_nlp.components import cf_utils
from lit_nlp.lib import utils

JsonDict = types.JsonDict
Spec = types.Spec
Expand Down Expand Up @@ -91,11 +92,11 @@ def _subset_exists(self, cand_set, sets):

def _gen_ablation_idxs(
self,
loo_scores: List[Tuple[str, int, float]],
loo_scores: list[tuple[str, int, float]],
max_ablations: int,
orig_regression_score: Optional[float] = None,
regression_thresh: Optional[float] = None
) -> Iterator[Tuple[Tuple[str, int], ...]]:
) -> Iterator[tuple[tuple[str, int], ...]]:
"""Generates sets of token positions that are eligible for ablation."""

# Order tokens by their leave-one-out ablation scores. (Note that these
Expand Down Expand Up @@ -133,7 +134,7 @@ def _get_tokens(self,
def _create_cf(self,
example: JsonDict,
input_spec: Spec,
ablation_idxs: List[Tuple[str, int]]) -> JsonDict:
ablation_idxs: list[tuple[str, int]]) -> JsonDict:
# Build a dictionary mapping input fields to the token idxs to be ablated
# from that field.
ablation_idxs_per_field = collections.defaultdict(list)
Expand Down Expand Up @@ -174,8 +175,8 @@ def _generate_leave_one_out_ablation_score(
input_spec: Spec,
output_spec: Spec,
orig_output: JsonDict,
pred_key: Text,
fields_to_ablate: List[str]) -> List[Tuple[str, int, float]]:
pred_key: str,
fields_to_ablate: list[str]) -> list[tuple[str, int, float]]:
# Returns a list of triples: field, token_idx and leave-one-out score.
ret = []
for field in input_spec.keys():
Expand All @@ -191,6 +192,13 @@ def _generate_leave_one_out_ablation_score(
ret.append((field, i, loo_score))
return ret

def is_compatible(self, model: lit_model.Model) -> bool:
  """Checks whether this generator can be used with the given model.

  Args:
    model: the model whose input and output specs are inspected.

  Returns:
    True only if utils.find_spec_keys() reports at least one input field of
    a supported input type and at least one output field of a supported
    prediction type.
  """
  input_types = (types.SparseMultilabel, types.TextSegment, types.URL)
  pred_types = (types.MulticlassPreds, types.RegressionScore)
  # Both specs are always queried, mirroring a plain two-step lookup.
  has_supported_input = bool(
      utils.find_spec_keys(model.input_spec(), input_types))
  has_supported_pred = bool(
      utils.find_spec_keys(model.output_spec(), pred_types))
  return has_supported_input and has_supported_pred

def config_spec(self) -> types.Spec:
return {
NUM_EXAMPLES_KEY: types.TextSegment(default=str(NUM_EXAMPLES_DEFAULT)),
Expand All @@ -212,7 +220,7 @@ def generate(self,
example: JsonDict,
model: lit_model.Model,
dataset: lit_dataset.Dataset,
config: Optional[JsonDict] = None) -> List[JsonDict]:
config: Optional[JsonDict] = None) -> list[JsonDict]:
"""Identify minimal sets of token albations that alter the prediction."""
del dataset # Unused.

Expand Down
294 changes: 294 additions & 0 deletions lit_nlp/components/ablation_flip_int_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,294 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for lit_nlp.components.ablation_flip."""

from typing import Iterable, Iterator

from absl.testing import absltest
from lit_nlp.api import types
from lit_nlp.components import ablation_flip
from lit_nlp.examples.models import glue_models
import numpy as np

# TODO(lit-dev): Move glue_models out of lit_nlp/examples


import transformers

# Remote archives of the tiny fine-tuned checkpoints exercised by these tests.
BERT_TINY_PATH = 'https://storage.googleapis.com/what-if-tool-resources/lit-models/sst2_tiny.tar.gz'  # pylint: disable=line-too-long
STSB_PATH = 'https://storage.googleapis.com/what-if-tool-resources/lit-models/stsb_tiny.tar.gz'  # pylint: disable=line-too-long

# Rebind each constant to the value cached_path() returns for its URL.
# NOTE(review): presumably this downloads the archive once, extracts it, and
# yields a local path -- confirm against the transformers documentation.
BERT_TINY_PATH, STSB_PATH = [
    transformers.file_utils.cached_path(url, extract_compressed_file=True)
    for url in (BERT_TINY_PATH, STSB_PATH)
]


class SST2ModelNonRequiredField(glue_models.SST2Model):
  """SST2Model variant whose 'sentence' input is declared as not required."""

  def input_spec(self):
    """Returns the parent input spec with 'sentence' replaced.

    The replacement is a TextSegment with required=False and an empty-string
    default.
    """
    optional_sentence = types.TextSegment(required=False, default='')
    relaxed_spec = super().input_spec()
    relaxed_spec['sentence'] = optional_sentence
    return relaxed_spec


class SST2ModelWithPredictCounter(glue_models.SST2Model):
  """SST2Model variant that counts how many times predict() is called."""

  def __init__(self, *args, **kw):
    """Forwards all arguments to the parent and zeroes the call counter."""
    super().__init__(*args, **kw)
    # Incremented once per predict() call (not once per input example).
    self.predict_counter = 0

  def predict(self,
              inputs: Iterable[types.JsonDict],
              scrub_arrays=True,
              **kw) -> Iterator[types.JsonDict]:
    """Delegates to the parent predict() and records that a call happened.

    The counter is only bumped after the parent call has returned, so a call
    that raises is not counted.
    """
    outputs = super().predict(inputs, scrub_arrays, **kw)
    self.predict_counter = self.predict_counter + 1
    return outputs


class ModelBasedAblationFlipTest(absltest.TestCase):
  """Integration tests that run AblationFlip against real GLUE models.

  The fixtures are an SST-2 classification model (single 'sentence' field,
  'probas' prediction key) and an STS-B regression model ('sentence1' and
  'sentence2' fields, 'score' prediction key), both loaded from the
  module-level checkpoint paths.
  """

  def setUp(self):
    """Creates the generator, the models, and their base configs."""
    super().setUp()
    self.ablation_flip = ablation_flip.AblationFlip()

    # Classification model that classifies a given input sentence.
    self.classification_model = glue_models.SST2Model(BERT_TINY_PATH)
    self.classification_config = {ablation_flip.PREDICTION_KEY: 'probas'}

    # Classification model with the 'sentence' field marked as
    # non-required.
    self.classification_model_non_required_field = SST2ModelNonRequiredField(
        BERT_TINY_PATH)

    # Classification model with a counter to count number of predict calls.
    # TODO(ataly): Consider setting up a Mock object to count number of
    # predict calls.
    self.classification_model_with_predict_counter = (
        SST2ModelWithPredictCounter(BERT_TINY_PATH))

    # Regression model determining similarity between two input sentences.
    self.regression_model = glue_models.STSBModel(STSB_PATH)
    self.regression_config = {ablation_flip.PREDICTION_KEY: 'score'}

  def test_ablation_flip_num_ex(self):
    """The number of counterfactuals follows NUM_EXAMPLES_KEY (0, 1, 2)."""
    ex = {'sentence': 'this long movie was terrible'}
    self.classification_config[ablation_flip.NUM_EXAMPLES_KEY] = 0
    self.classification_config[ablation_flip.FIELDS_TO_ABLATE_KEY] = [
        'sentence'
    ]
    self.assertEmpty(
        self.ablation_flip.generate(ex, self.classification_model, None,
                                    self.classification_config))
    self.classification_config[ablation_flip.NUM_EXAMPLES_KEY] = 1
    self.assertLen(
        self.ablation_flip.generate(ex, self.classification_model, None,
                                    self.classification_config), 1)
    self.classification_config[ablation_flip.NUM_EXAMPLES_KEY] = 2
    self.assertLen(
        self.ablation_flip.generate(ex, self.classification_model, None,
                                    self.classification_config), 2)

  def test_ablation_flip_num_ex_multi_input(self):
    """NUM_EXAMPLES_KEY is honored when two fields may be ablated."""
    ex = {'sentence1': 'this long movie is terrible',
          'sentence2': 'this short movie is great'}
    self.regression_config[ablation_flip.NUM_EXAMPLES_KEY] = 2
    thresh = 2
    self.regression_config[ablation_flip.REGRESSION_THRESH_KEY] = thresh
    self.regression_config[ablation_flip.FIELDS_TO_ABLATE_KEY] = [
        'sentence1',
        'sentence2',
    ]
    self.assertLen(
        self.ablation_flip.generate(ex, self.regression_model, None,
                                    self.regression_config), 2)

  def test_ablation_flip_long_sentence(self):
    """The number of predict calls stays bounded on a long sentence."""
    sentence = (
        'this was a terrible terrible movie but I am a writing '
        'a nice long review for testing whether AblationFlip '
        'can handle long sentences with a bounded number of '
        'predict calls.')
    ex = {'sentence': sentence}
    self.classification_config[ablation_flip.NUM_EXAMPLES_KEY] = 100
    self.classification_config[ablation_flip.MAX_ABLATIONS_KEY] = 100
    self.classification_config[ablation_flip.FIELDS_TO_ABLATE_KEY] = [
        'sentence'
    ]
    model = self.classification_model_with_predict_counter
    cfs = self.ablation_flip.generate(
        ex, model, None, self.classification_config)

    # This example must yield 19 ablation_flips.
    self.assertLen(cfs, 19)

    # Number of predict calls made by ablation_flip should be upper-bounded by
    # <number of tokens in sentence> + 2**MAX_ABLATABLE_TOKENS
    # NOTE(review): this takes len() of the tokenizer call's return value and
    # assumes that is the token count -- confirm for the tokenizer in use.
    num_tokens = len(model.tokenizer(sentence))
    num_predict_calls = model.predict_counter
    self.assertLessEqual(num_predict_calls,
                         num_tokens + 2**ablation_flip.MAX_ABLATABLE_TOKENS)

    # We use a smaller value of MAX_ABLATABLE_TOKENS and check that the
    # number of predict calls is smaller, and that the prediction bound still
    # holds.
    # The module-level constant is restored after the test so the override
    # cannot leak into other test methods; the current value is captured here,
    # before it is overwritten.
    self.addCleanup(setattr, ablation_flip, 'MAX_ABLATABLE_TOKENS',
                    ablation_flip.MAX_ABLATABLE_TOKENS)
    model.predict_counter = 0
    ablation_flip.MAX_ABLATABLE_TOKENS = 5
    # Re-run generation under the smaller limit; without this call the counter
    # would still be 0 and the assertions below would pass vacuously.
    self.ablation_flip.generate(ex, model, None, self.classification_config)
    self.assertGreater(model.predict_counter, 0)
    self.assertLessEqual(model.predict_counter, num_predict_calls)
    self.assertLessEqual(model.predict_counter,
                         num_tokens + 2**ablation_flip.MAX_ABLATABLE_TOKENS)

  def test_ablation_flip_freeze_fields(self):
    """Fields left out of FIELDS_TO_ABLATE_KEY are returned unchanged."""
    ex = {'sentence1': 'this long movie is terrible',
          'sentence2': 'this long movie is great'}
    self.regression_config[ablation_flip.NUM_EXAMPLES_KEY] = 10
    thresh = 2
    self.regression_config[ablation_flip.REGRESSION_THRESH_KEY] = thresh
    self.regression_config[ablation_flip.FIELDS_TO_ABLATE_KEY] = [
        'sentence1'
    ]
    cfs = self.ablation_flip.generate(ex, self.regression_model, None,
                                      self.regression_config)
    for cf in cfs:
      self.assertEqual(cf['sentence2'], ex['sentence2'])

  def test_ablation_flip_max_ablations(self):
    """MAX_ABLATIONS_KEY=1 yields one-token ablations, or nothing at all."""
    ex = {'sentence': 'this movie is terrible'}
    ex_tokens = self.ablation_flip.tokenize(ex['sentence'])
    self.classification_config[ablation_flip.NUM_EXAMPLES_KEY] = 1
    self.classification_config[ablation_flip.MAX_ABLATIONS_KEY] = 1
    self.classification_config[ablation_flip.FIELDS_TO_ABLATE_KEY] = [
        'sentence'
    ]
    cfs = self.ablation_flip.generate(
        ex, self.classification_model, None, self.classification_config)
    cf_tokens = self.ablation_flip.tokenize(list(cfs)[0]['sentence'])
    self.assertLen(cf_tokens, len(ex_tokens) - 1)

    # With this longer input, a single ablation is expected to produce no
    # counterfactual.
    ex = {'sentence': 'this long movie is terrible and horrible.'}
    self.classification_config[ablation_flip.NUM_EXAMPLES_KEY] = 1
    self.classification_config[ablation_flip.MAX_ABLATIONS_KEY] = 1
    self.classification_config[ablation_flip.FIELDS_TO_ABLATE_KEY] = [
        'sentence'
    ]
    cfs = self.ablation_flip.generate(
        ex, self.classification_model, None, self.classification_config)
    self.assertEmpty(cfs)

  def test_ablation_flip_max_ablations_multi_input(self):
    """MAX_ABLATIONS_KEY bounds the total ablations across both fields."""
    ex = {'sentence1': 'this movie is terrible',
          'sentence2': 'this movie is great'}
    ex_tokens1 = self.ablation_flip.tokenize(ex['sentence1'])
    ex_tokens2 = self.ablation_flip.tokenize(ex['sentence2'])

    self.regression_config[ablation_flip.NUM_EXAMPLES_KEY] = 20
    self.regression_config[ablation_flip.REGRESSION_THRESH_KEY] = 2
    max_ablations = 1
    self.regression_config[ablation_flip.MAX_ABLATIONS_KEY] = max_ablations
    self.regression_config[ablation_flip.FIELDS_TO_ABLATE_KEY] = [
        'sentence1',
        'sentence2',
    ]
    cfs = self.ablation_flip.generate(ex, self.regression_model, None,
                                      self.regression_config)
    for cf in cfs:
      # Total number of tokens removed across both fields should be no more
      # than max_ablations.
      cf_tokens1 = self.ablation_flip.tokenize(cf['sentence1'])
      cf_tokens2 = self.ablation_flip.tokenize(cf['sentence2'])
      self.assertGreaterEqual(
          len(cf_tokens1) + len(cf_tokens2),
          len(ex_tokens1) + len(ex_tokens2) - max_ablations)

  def test_ablation_flip_yields_multi_field_ablations(self):
    """At least one counterfactual removes tokens from both fields."""
    ex = {'sentence1': 'this short movie is awesome',
          'sentence2': 'this short movie is great'}
    ex_tokens1 = self.ablation_flip.tokenize(ex['sentence1'])
    ex_tokens2 = self.ablation_flip.tokenize(ex['sentence2'])

    self.regression_config[ablation_flip.NUM_EXAMPLES_KEY] = 20
    self.regression_config[ablation_flip.REGRESSION_THRESH_KEY] = 2
    self.regression_config[ablation_flip.MAX_ABLATIONS_KEY] = 5
    self.regression_config[ablation_flip.FIELDS_TO_ABLATE_KEY] = [
        'sentence1',
        'sentence2',
    ]
    cfs = self.ablation_flip.generate(ex, self.regression_model, None,
                                      self.regression_config)

    # Verify that at least one counterfactual involves ablations across
    # multiple fields.
    multi_field_ablation_found = False
    for cf in cfs:
      cf_tokens1 = self.ablation_flip.tokenize(cf['sentence1'])
      cf_tokens2 = self.ablation_flip.tokenize(cf['sentence2'])
      if ((len(cf_tokens1) < len(ex_tokens1))
          and (len(cf_tokens2) < len(ex_tokens2))):
        multi_field_ablation_found = True
        break
    self.assertTrue(multi_field_ablation_found)

  def test_ablation_flip_changes_pred_class(self):
    """Every counterfactual gets a different argmax class than the input."""
    ex = {'sentence': 'this long movie is terrible'}
    ex_output = list(self.classification_model.predict([ex]))[0]
    pred_class = str(np.argmax(ex_output['probas']))
    self.assertEqual('0', pred_class)
    self.classification_config[ablation_flip.FIELDS_TO_ABLATE_KEY] = [
        'sentence'
    ]
    cfs = self.ablation_flip.generate(ex, self.classification_model, None,
                                      self.classification_config)
    cf_outputs = self.classification_model.predict(cfs)
    for cf_output in cf_outputs:
      self.assertNotEqual(np.argmax(ex_output['probas']),
                          np.argmax(cf_output['probas']))

  def test_ablation_flip_changes_regression_score(self):
    """Every counterfactual lands on the other side of the threshold."""
    ex = {'sentence1': 'this long movie is terrible',
          'sentence2': 'this short movie is great'}
    self.regression_config[ablation_flip.NUM_EXAMPLES_KEY] = 2
    ex_output = list(self.regression_model.predict([ex]))[0]
    thresh = 2
    self.regression_config[ablation_flip.REGRESSION_THRESH_KEY] = thresh
    self.regression_config[ablation_flip.FIELDS_TO_ABLATE_KEY] = [
        'sentence1',
        'sentence2',
    ]
    cfs = self.ablation_flip.generate(ex, self.regression_model, None,
                                      self.regression_config)
    cf_outputs = self.regression_model.predict(cfs)
    for cf_output in cf_outputs:
      self.assertNotEqual((ex_output['score'] <= thresh),
                          (cf_output['score'] <= thresh))

  def test_ablation_flip_fails_without_pred_key(self):
    """generate() raises AssertionError when no config is supplied."""
    ex = {'sentence': 'this long movie is terrible'}
    with self.assertRaises(AssertionError):
      self.ablation_flip.generate(ex, self.classification_model, None, None)

  def test_ablation_flip_required_field(self):
    """A one-token field is only ablated when the model marks it optional."""
    ex = {'sentence': 'terrible'}
    self.classification_config[ablation_flip.NUM_EXAMPLES_KEY] = 1
    self.classification_config[ablation_flip.FIELDS_TO_ABLATE_KEY] = [
        'sentence'
    ]
    self.assertEmpty(
        self.ablation_flip.generate(
            ex, self.classification_model, None, self.classification_config))
    self.assertLen(
        self.ablation_flip.generate(
            ex, self.classification_model_non_required_field,
            None, self.classification_config), 1)

# Run the absltest runner only when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()

0 comments on commit db94849

Please sign in to comment.