# Pretrained NLI Contradiction Model

Use a pretrained model from Facebook Research designed to classify sentence pairs as "Entailment" (agree), "Neutral", or "Contradiction".

https://huggingface.co/ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli

In [2]:
# https://github.com/facebookresearch/anli/blob/main/src/hg_api/interactive.py
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

hg_model_hub_name = "ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli"

# Will take a moment to download
tokenizer = AutoTokenizer.from_pretrained(hg_model_hub_name)
model = AutoModelForSequenceClassification.from_pretrained(hg_model_hub_name)

def evaluate(tokenizer, model, premise, hypothesis):
    max_length = 256

    tokenized_input_seq_pair = tokenizer.encode_plus(premise, hypothesis,
                                                     max_length=max_length,
                                                     return_token_type_ids=True, truncation=True)

    # Build integer tensors directly and add a batch dimension of size 1.
    input_ids = torch.tensor(tokenized_input_seq_pair['input_ids'], dtype=torch.long).unsqueeze(0)
    # BART has no 'token_type_ids'; remove the line below if you are using BART.
    token_type_ids = torch.tensor(tokenized_input_seq_pair['token_type_ids'], dtype=torch.long).unsqueeze(0)
    attention_mask = torch.tensor(tokenized_input_seq_pair['attention_mask'], dtype=torch.long).unsqueeze(0)

    # Inference only, so skip gradient tracking.
    with torch.no_grad():
        outputs = model(input_ids,
                        attention_mask=attention_mask,
                        token_type_ids=token_type_ids)
    # Note:
    # "id2label": {
    #     "0": "entailment",
    #     "1": "neutral",
    #     "2": "contradiction"
    # },

    predicted_probability = torch.softmax(outputs[0], dim=1)[0].tolist()  # batch_size only one

    #print("Premise:", premise)
    #print("Hypothesis:", hypothesis)
    print("Prediction:")
    print("Entailment:", predicted_probability[0])
    print("Neutral:", predicted_probability[1])
    print("Contradiction:", predicted_probability[2])

    print("="*20)



Some weights of the model checkpoint at ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
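
The `evaluate` function turns the model's three raw scores into probabilities with a softmax and reads them in the `id2label` order noted in its comments. A minimal pure-Python sketch of that step follows; the logits are made up for illustration, and `softmax`, `label_probabilities`, and `ID2LABEL` are hypothetical helper names, not part of `transformers`.

```python
import math

# Label order copied from the "id2label" note in the cell above.
ID2LABEL = {0: "entailment", 1: "neutral", 2: "contradiction"}

def softmax(logits):
    """Convert raw scores to probabilities that sum to 1."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def label_probabilities(logits):
    """Pair each probability with its NLI label name."""
    return {ID2LABEL[i]: p for i, p in enumerate(softmax(logits))}

# Hypothetical logits, for illustration only (not real model output).
example_logits = [-2.0, 1.0, 0.5]
probs = label_probabilities(example_logits)
top_label = max(probs, key=probs.get)
print(probs)
print("Top label:", top_label)
```

Returning a labelled dictionary like this, rather than printing inside the function, would make the scores easier to collect across many sentence pairs.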


In [3]:
premise = "The office is responsible for writing a report on the project status within 6 months."
hypothesis_contradict = "Interim project status reports are not required."
hypothesis_entail = "Interim project status reports will be required."
hypothesis_neutral = "Operations—activities or processes associated with the programs to be housed in a completed facility and those processes which are necessary to run the facility."

print("[Contradiction]", end=' '); evaluate(tokenizer, model, premise, hypothesis_contradict)
print("[Entailment]", end=' '); evaluate(tokenizer, model, premise, hypothesis_entail)
print("[Neutral]", end=' '); evaluate(tokenizer, model, premise, hypothesis_neutral)

[Contradiction] Prediction:
Entailment: 0.005922275595366955
Neutral: 0.5431420207023621
Contradiction: 0.4509357511997223
[Entailment] Prediction:
Entailment: 0.16487528383731842
Neutral: 0.5480648279190063
Contradiction: 0.2870599031448364
[Neutral] Prediction:
Entailment: 0.12228141725063324
Neutral: 0.6374253630638123
Contradiction: 0.2402932494878769
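
All three outputs above rank Neutral first, so taking the top class would not separate the contradictory pair from the others. One possible way to read the numbers as a binary signal is to score the margin of contradiction over entailment. The sketch below uses the printed probabilities (rounded); the scoring rule and the `THRESHOLD` value are assumptions picked by eye from these three data points, not a validated method.

```python
# Probabilities copied (rounded) from the printed output above.
observed = {
    "contradict": {"entailment": 0.0059, "neutral": 0.5431, "contradiction": 0.4509},
    "entail": {"entailment": 0.1649, "neutral": 0.5481, "contradiction": 0.2871},
    "neutral": {"entailment": 0.1223, "neutral": 0.6374, "contradiction": 0.2403},
}

def contradiction_score(probs):
    """Margin of contradiction over entailment; ignores the neutral mass."""
    return probs["contradiction"] - probs["entailment"]

# Hypothetical cut-off chosen by eye from three examples; not validated.
THRESHOLD = 0.3

def is_contradiction(probs, threshold=THRESHOLD):
    """Flag a pair when the contradiction margin exceeds the threshold."""
    return contradiction_score(probs) > threshold

scores = {name: contradiction_score(p) for name, p in observed.items()}
flags = {name: is_contradiction(p) for name, p in observed.items()}
print(scores)
print(flags)
```

With so few examples the threshold is only a placeholder; a labeled set would be needed to choose it properly.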


> **Initial observations and thoughts**
> 
> Note that these are based on a single example, but here we go:
> * For our problem statement, we don't necessarily care about the distinction between *entailment* and *neutral*. We really just want to catch *contradiction*.
> * For the contradictory statement, the entailment probability was extremely low (about 0.006), and the contradiction probability (about 0.45) was much higher than for the entailment (0.29) and neutral (0.24) statements. That said, *neutral* was still the top-scoring class for all three hypotheses.
> * Ideally we'd be able to fine-tune these models on our specific vocab and use-case...
>   * Perhaps we can train an estimator to take these output probabilities and produce a fine-tuned binary classification of contradiction or not
>   * We'd need a labeled training set
> * For now, we can just do some stats and create a custom heuristic for our contradiction classification!
> * Named entities will likely trip us up a bit: policies often define their own terms, so two policies may both say "The Headquarters" while referring to two different headquarters!
> * I like this approach to get us started. We could also try the much more informed policy-parsing approach, perhaps using existing Deloitte models/assets such as RegExplorer?
> * The next step is to run this inference on sentence pairs across documents.
>   * This might take a while, since the number of pairs grows quadratically with the number of sentences... I wonder what `tokenizer.encode_plus` is doing. Could we run tokenization ahead of time and save the results to a file?
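
To get a feel for the cost of the pairwise step, here is a small sketch that enumerates cross-document sentence pairs with `itertools.product`. The two toy documents and the helper name `cross_document_pairs` are invented for illustration; real documents would first need sentence splitting, which is not shown.

```python
from itertools import product

def cross_document_pairs(doc_a, doc_b):
    """Return an iterator over (premise, hypothesis) pairs, one sentence from each document."""
    return product(doc_a, doc_b)

# Toy documents, invented for illustration.
doc_a = [
    "Reports are due within 6 months.",
    "The office writes the report.",
]
doc_b = [
    "Interim reports are not required.",
    "Interim reports will be required.",
    "The facility opens in May.",
]

pairs = list(cross_document_pairs(doc_a, doc_b))
pair_count = len(doc_a) * len(doc_b)  # grows as the product of the two lengths
print(pair_count, "pairs to score")
for premise, hypothesis in pairs[:2]:
    print(premise, "|", hypothesis)
```

Two documents of 1,000 sentences each would give 1,000,000 pairs in one direction. NLI is directional (premise versus hypothesis), so scoring both orders would double that.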
