
Intgrads subclassed models #362

Merged: 39 commits into master, Mar 26, 2021

Conversation

@gipster (Contributor) commented Mar 12, 2021

Adds support for TensorFlow and Keras models with no explicit inputs (subclassed models).

In the present code base, tensor inputs are converted to lists at explain time in order to handle models with multiple inputs. If a model has a single input, the tensor input is converted to a list with one element.

List inputs are handled well by models with an explicit input layer (functional and sequential models). When no explicit input layer is present (as in subclassed models), the input must be in the format expected by the first layer.
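
As a minimal illustration of the issue (a hypothetical subclassed model, not code from this PR):

import tensorflow as tf

class SubclassedModel(tf.keras.Model):
    # No explicit Input layer: the first layer sees whatever call() receives.
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(1)

    def call(self, x):
        return self.dense(x)

m = SubclassedModel()
x = tf.ones((2, 3))
m(x)    # works: the Dense layer receives a (2, 3) tensor
m([x])  # the one-element list reaches the Dense layer unchanged and may
        # error or silently change shapes, depending on the layer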

To support subclassed models, this pull request makes two changes:

  • The attributions are calculated with two different functions, _calculate_attributions_list_input and _calculate_attributions_tensor_input, for list inputs and tensor inputs, respectively.
  • For subclassed models with no explicit input, the input types and output shapes of the model, which are needed to calculate the attributions, are inferred at explain time with a single mock call of the model (see the sketch below).
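
A minimal sketch of the mock-call idea (the helper name and exact logic are illustrative, not the implementation in alibi/explainers/integrated_gradients.py):

import numpy as np
import tensorflow as tf

def infer_metadata(model: tf.keras.Model, X: np.ndarray):
    # A subclassed model exposes no input layer to introspect, so call it
    # once on a single instance to build it and read off the output shape;
    # the input dtype comes from the data itself.
    mock_output = model(X[:1])
    return X.dtype, mock_output.shape[1:]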

# Conflicts:
#	alibi/explainers/integrated_gradients.py
#	alibi/explainers/tests/test_integrated_gradients.py
@gipster gipster linked an issue Mar 12, 2021 that may be closed by this pull request
@gipster gipster requested a review from jklaise March 16, 2021 15:30
@gipster gipster marked this pull request as draft March 16, 2021 15:30
@gipster gipster added the 'Type: Method extension' label (Extensions to existing methods) Mar 16, 2021
@codecov codecov bot commented Mar 24, 2021

Codecov Report

Merging #362 (e0a6b28) into master (9a0ac31) will increase coverage by 0.01%.
The diff coverage is 87.55%.


@@            Coverage Diff             @@
##           master     #362      +/-   ##
==========================================
+ Coverage   87.72%   87.74%   +0.01%     
==========================================
  Files          56       56              
  Lines        7217     7390     +173     
==========================================
+ Hits         6331     6484     +153     
- Misses        886      906      +20     
Impacted Files                                          Coverage Δ
alibi/explainers/integrated_gradients.py                88.54% <84.40%> (-1.07%) ⬇️
...libi/explainers/tests/test_integrated_gradients.py   95.66% <90.51%> (-3.72%) ⬇️

@jklaise jklaise marked this pull request as ready for review March 26, 2021 12:54
@jklaise jklaise merged commit 58d9796 into SeldonIO:master Mar 26, 2021
@loukasilias

n_steps = 100
method = "gausslegendre"
internal_batch_size = 768
nb_samples = 10
ig  = IntegratedGradients(model,
                          layer=model.bert_model,
                          n_steps=n_steps,
                          method=method,
                          internal_batch_size=internal_batch_size)

The model consists of BERT + custom layers + Dense layers for binary classification.
What do I have to change in my code to get the contribution of each token/word towards the final output/prediction?

@jklaise (Member) commented Mar 31, 2021

@loukasilias have you installed alibi from source since this PR was merged (pip install git+https://github.com/SeldonIO/alibi.git)? We will do a proper release soon, after sorting out a few other issues.

@loukasilias

Thanks, it works now. However, another error arises.

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)

def process_sentences(sentence1):
    z = tokenizer.encode_plus(sentence1, add_special_tokens=True,
                              padding='max_length', max_length=512,
                              truncation=True, return_token_type_ids=True,
                              return_attention_mask=True, return_tensors='np')
    return [z['input_ids'], z['attention_mask']]

n_steps = 768
method = "gausslegendre"
internal_batch_size = 512
nb_samples = 10
ig  = IntegratedGradients(model,
                          layer=model.bert_model,
                          n_steps=n_steps,
                          method=method,
                          internal_batch_size=internal_batch_size)

x_test_sample = X_test[:nb_samples]
predictions = [model.predict(process_sentences(samples)).squeeze() for samples in x_test_sample]
explanation = ig.explain(x_test_sample,
                         baselines=None,
                         target=predictions)

The BERT model takes a sequence of tokens as input and outputs embeddings with dimensions [batch_size, sequence_length, vector_dimensionality] = [batch_size, 512, 768].

x_test_sample and predictions have size (10,)

I am getting the following error:

StagingError: in user code:

    <ipython-input-10-22f6afb59993>:13 call  *
        output_sentence_1 = self.bert_model(input_ids = inputs[0], attention_mask = inputs[1])
    /opt/conda/lib/python3.7/site-packages/transformers/models/bert/modeling_tf_bert.py:891 call  *
        outputs = self.bert(
    /opt/conda/lib/python3.7/site-packages/transformers/models/bert/modeling_tf_bert.py:654 call  *
        embedding_output = self.embeddings(
    /opt/conda/lib/python3.7/site-packages/transformers/models/bert/modeling_tf_bert.py:192 call  *
        return self._embedding(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
    /opt/conda/lib/python3.7/site-packages/transformers/models/bert/modeling_tf_bert.py:207 _embedding  *
        seq_length = input_shape[1]

    IndexError: list index out of range

What is the mistake now? Thank you.

@jklaise (Member) commented Mar 31, 2021

@loukasilias is it possible to provide a self-contained script of the code that results in this error, so we can debug it more efficiently? I suspect the issue is that model has to work on input of the type returned by process_sentences(x_test_sample), whereas you are passing x_test_sample directly to explain.

@loukasilias

Do you want the full code, i.e. model + generator? I pass the input_ids and attention_mask obtained from bert_tokenizer to the model.

Should I pass x_test_samples through process_sentences? Then I will have input_ids and attention_mask per sample in x_test_sample. How can I pass them to explain?

@jklaise (Member) commented Mar 31, 2021

@loukasilias the input to explain should be a numpy array (or a list of numpy arrays for multi-input models, which I think is your use case). explain does support List[np.ndarray] for multi-input models, so you should be able to call explain with [input_ids, attention_mask] as long as those two inputs are numpy arrays.

In general, the input to explain should be exactly the same format as the input to model.predict.
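
For example, the call pattern might look like this (a sketch assuming the model, tokenizer, ig and x_test_sample from this thread):

z = tokenizer.batch_encode_plus(list(x_test_sample), add_special_tokens=True,
                                padding='max_length', max_length=512,
                                truncation=True, return_attention_mask=True,
                                return_tensors='np')
X = [z['input_ids'], z['attention_mask']]   # two numpy arrays, each (batch, 512)

# model.predict and ig.explain receive the same multi-input format
predictions = (model.predict(X).squeeze() > 0.5).astype(int)
explanation = ig.explain(X, baselines=None, target=predictions)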

@loukasilias

It doesn't work.
I set nb_samples to 1.
x is in the same format as the input to model.predict. Dimensions: [2,1,512]

x_test_sample = X_test[:nb_samples]

for samples in x_test_sample:
    x = process_sentences(samples)
    
predictions = [model.predict(process_sentences(samples)).squeeze() for samples in x_test_sample]

explanation = ig.explain(x,baselines=None,target=predictions)

I also tried with: predictions = [1*(model.predict(process_sentences(samples)).squeeze()>0.5) for samples in x_test_sample]

I am getting the following error:

StagingError: in user code:

    <ipython-input-10-22f6afb59993>:13 call  *
        output_sentence_1 = self.bert_model(input_ids = inputs[0], attention_mask = inputs[1])
    /opt/conda/lib/python3.7/site-packages/alibi/explainers/integrated_gradients.py:204 wrapper  *
        layer.result = func(*args, **kwargs)
    /opt/conda/lib/python3.7/site-packages/transformers/models/bert/modeling_tf_bert.py:891 call  *
        outputs = self.bert(

    KeyError: 'input_ids'

@jklaise (Member) commented Mar 31, 2021

@loukasilias does your model work on batches or single instances? It's a bit unclear what's going on, as you seem to be preparing predictions by calling the model on each sample one by one, but feeding in a batch for explain. The assumption of IntegratedGradients is that the model always works on batches, so in particular if nb_samples=1, the leading batch dimension should also be 1 (and if it's a multi-input model, a list of arrays where each array has a leading batch dimension).

What is the dimension of the output of model? The target should be a list whose length is the same as the batch to be explained and each entry an integer denoting the output to be explained (e.g. the predicted class if it's a classifier).
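
To illustrate the expected shapes (hypothetical arrays, assuming the ig instance from above):

import numpy as np

# Even for a single instance, keep the leading batch dimension of 1.
input_ids = np.zeros((1, 512), dtype=np.int64)
attention_mask = np.ones((1, 512), dtype=np.int64)
X = [input_ids, attention_mask]   # multi-input model: list of arrays

target = [0]   # one integer per instance in the batch
explanation = ig.explain(X, baselines=None, target=target)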

I'm not sure what's going on in the last error, it appears to be some issue with BERT as it can't recognize the input_ids keyword.

Do you have some self-contained code that you can share so we can test this? Otherwise I'm not sure we're making much progress.

@loukasilias

Model:

class MyModel(tf.keras.Model):
    
    def __init__(self, flag):
        
        super(MyModel,self).__init__()
        self.bert_model = TFBertModel.from_pretrained("bert-base-uncased")
        self.bert_model.trainable = flag
        self.layer_3 = Dense(units = 128, activation = 'relu')
        self.layer_4 = Dense(units = 1, activation = 'sigmoid')
        self.pooling = tf.keras.layers.GlobalMaxPool1D()

    def call(self,inputs):
        output_sentence_1 = self.bert_model(input_ids = inputs[0], attention_mask = inputs[1])
        output_sentence_1 = self.pooling(output_sentence_1['last_hidden_state'])
        layer_output = self.layer_3(output_sentence_1)
        output = self.layer_4(layer_output)

        return output

Generator:

class BertGenerator(tf.keras.utils.Sequence):
    
    def __init__(self, X_train, y_train, batch_size = 16):
        self.batch_size = batch_size
        self.X_train = X_train
        self.y_train = y_train
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
        self.indexes = np.arange(len(self.X_train))
        self.on_epoch_end()
    
    
    def __len__(self):
        
        return len(self.X_train) // self.batch_size
    
    def __getitem__(self,idx):
        indexes = self.indexes[idx * self.batch_size : (idx + 1) * self.batch_size]
        batch_sentences = self.X_train[indexes]
        batch_labels = self.y_train[indexes]
        
        
        z = self.tokenizer.batch_encode_plus(batch_sentences, add_special_tokens = True, padding = 'longest', return_token_type_ids=True, return_attention_mask = True,  return_tensors = 'np')

        return [z['input_ids'], z['attention_mask']], np.array(batch_labels) 
    
    def on_epoch_end(self):

        np.random.RandomState(42).shuffle(self.indexes)

Train & Evaluate:

model = MyModel(flag = False)

model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
              loss="binary_crossentropy",
              metrics=["acc", tf.keras.metrics.Recall(), tf.keras.metrics.Precision(),
                       tf.keras.metrics.AUC(), tf.keras.metrics.TrueNegatives(),
                       tf.keras.metrics.FalsePositives()])

train_data = BertGenerator(X_train, y_train, batch_size=6)
eval_data = BertGenerator(X_eval, y_eval, batch_size=16)
history = model.fit(train_data, epochs=5, verbose=1,
                    class_weight=class_weight_function(y_train),
                    validation_data=eval_data,
                    callbacks=[tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10),
                               tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', mode='min',
                                                                    factor=0.1, patience=3)])

test_data = BertGenerator(X_test, y_test, batch_size=16)
result = model.evaluate(test_data, verbose=0)
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)

def process_sentences(sentence1):
    z = tokenizer.batch_encode_plus(sentence1, add_special_tokens=True,
                                    padding='max_length', truncation=True,
                                    max_length=512, return_token_type_ids=True,
                                    return_attention_mask=True, return_tensors='np')
    return [z['input_ids'], z['attention_mask']]

def encode_sentences(sentence1):
    z = tokenizer.encode_plus(sentence1, add_special_tokens=True,
                              padding='max_length', truncation=True,
                              max_length=512, return_token_type_ids=True,
                              return_attention_mask=True, return_tensors='np')
    return [z['input_ids'], z['attention_mask']]

Explain:

n_steps = 768
method = "gausslegendre"
internal_batch_size = 512
nb_samples = 10
ig  = IntegratedGradients(model.bert_model,
                          layer=None,
                          n_steps=n_steps,
                          method=method,
                          internal_batch_size=internal_batch_size)
x_test_sample = X_test[:nb_samples]

inputs_ids = []
attentions_mask = []
for samples in x_test_sample:
    x = encode_sentences(samples)
    inputs_ids.append(x[0])
    attentions_mask.append(x[1])

predictions = 1*(model.predict(process_sentences(x_test_sample)).squeeze()>0.5)

explanation = ig.explain([np.array(inputs_ids),np.array(attentions_mask)],baselines=None,target=predictions)

Error:
ValueError: Cannot reshape a tensor with 768 elements to shape [1,1,512,1] (512 elements) for '{{node tf_bert_model_1/bert/embeddings/LayerNorm/Reshape}} = Reshape[T=DT_FLOAT, Tshape=DT_INT32](tf_bert_model_1/bert/embeddings/LayerNorm/Reshape/ReadVariableOp, tf_bert_model_1/bert/embeddings/LayerNorm/Reshape/shape)' with input shapes: [768], [4] and with input tensors computed as partial shapes: input[1] = [1,1,512,1].

@jklaise (Member) commented Mar 31, 2021

Thanks for this, I'll set some time aside to investigate. Do you have some examples of the data X_train and y_train so I can run the example end-to-end?

Btw, it looks like the output of the model is 1-dimensional; does it represent a class probability in 2-class classification? In that case you may treat it the same as a regression case and provide target=None to explain the probability of the predicted class. An alternative is to have a 2-dimensional softmax output and provide the classes predicted by the model for each instance in the batch.
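
A sketch of the two options (assuming the X and ig from earlier in the thread; the softmax variant presumes a model retrained with two outputs):

# Option 1: 1-D sigmoid output, treat like regression and explain the
# scalar probability directly.
explanation = ig.explain(X, baselines=None, target=None)

# Option 2: 2-D softmax output, pass the predicted class per instance.
preds = model.predict(X).argmax(axis=1)
explanation = ig.explain(X, baselines=None, target=preds)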

@loukasilias

dataset.zip

import pandas as pd
import numpy as np
dataset = pd.read_csv('dataset.csv')
dataset['outcome_class'] = dataset['outcome_class'].replace('t',0)
dataset['outcome_class'] = dataset['outcome_class'].replace('d',1)
dataset = dataset.sample(frac=1).reset_index(drop = True)

import tensorflow as tf
tf.config.experimental_run_functions_eagerly(True)
from tensorflow.keras import backend as K, initializers
from transformers import BertConfig, BertTokenizer, TFBertModel
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, auc, roc_curve, classification_report, recall_score, precision_score
from sklearn.model_selection import StratifiedKFold, train_test_split
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.optimizers import Adam

X = dataset['q1']
y = dataset['outcome_class']
X = np.array(X.tolist())
y = np.array(y.tolist())
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=42)
X_train, X_eval, y_train, y_eval = train_test_split(X_train, y_train, test_size=0.2, random_state=42)

import sklearn

def class_weight_function(y_train):
    
    class_weights = sklearn.utils.class_weight.compute_class_weight('balanced',
                                                      np.unique(y_train),
                                                      y_train)
    
    class_weights = {0 : class_weights[0], 1 : class_weights[1]}
    
    return class_weights

Yes, the model outputs a probability.
For target=predictions, I adjusted the code so that each prediction is 0 or 1.

@jklaise (Member) commented Apr 1, 2021

@loukasilias there are two issues. A minor one is that you've configured the BERT model with keyword arguments, whereas it should be configured to take list inputs as detailed here.

The second, bigger issue is that the BERT model is not a layer in its own right; we need to dig deeper to find the appropriate layer to be explained. This is exactly the same issue as #377.
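
The keyword-argument fix might look like this (a sketch of MyModel.call only; TFBertModel accepts a list [input_ids, attention_mask] as its first positional argument):

def call(self, inputs):
    # pass the list positionally rather than as keyword arguments
    output_sentence_1 = self.bert_model(inputs)
    output_sentence_1 = self.pooling(output_sentence_1['last_hidden_state'])
    layer_output = self.layer_3(output_sentence_1)
    return self.layer_4(layer_output)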

Successfully merging this pull request may close these issues:

AttributeError: Layer my_model_2 is not connected, no input to return.