In [1]:
import numpy as np

from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from giskard import scan
from giskard import Dataset, Model
import tqdm

In [2]:
# Load MultiNLI. "validation_matched" is a *split*, not a configuration name,
# so it must not be passed as the second positional argument (`name`) of
# `load_dataset`; the split is selected by key on the next line instead.
dataset = load_dataset("multi_nli")
df = dataset["validation_matched"].to_pandas()

# Turn the 3-way task into a binary one: rows whose label is 1 ("neutral")
# are dropped, and the two remaining integer labels are renamed to the
# strings that are later passed as `classification_labels` to Giskard.
# NOTE(review): the Hugging Face `multi_nli` dataset card lists the labels as
# 0=entailment, 1=neutral, 2=contradiction, which would make "True" mean
# *contradiction* here -- confirm. If so, this mapping and the order of
# `classification_labels` in the model-wrapping cell have to be flipped
# together; flipping only one of them would invert every prediction.
mapping = {0: "False", 2: "True"}
df = df[df.label != 1].replace({"label": mapping})
len(df)

6692

## Wrapping the dataset

In [3]:
# Wrap the binarised frame for Giskard; the "label" column (holding the
# strings "True"/"False") is declared as the prediction target.
gsk_dataset = Dataset(
    df,
    target="label",
)

In [4]:
# Single source of truth for the checkpoint shared by tokenizer and model.
MODEL_NAME = 'facebook/bart-large-mnli'
device = "cpu" # running on mac

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
nli_model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
# `prediction_function` moves its inputs to `device`, so the model has to live
# on the same device; otherwise changing `device` above would fail at
# inference time with a device mismatch. `eval()` makes inference mode
# (dropout disabled) explicit.
nli_model = nli_model.to(device).eval()

## Wrapping the model

In [5]:
def prediction_function(df):
    """Score premise/hypothesis pairs with the MNLI classifier.

    Every row is encoded as a (premise, hypothesis) pair with the notebook's
    `tokenizer` and run through `nli_model` one row at a time. The "neutral"
    logit (dim 1) is discarded and a softmax is taken over the two remaining
    logits.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with `premise` and `hypothesis` columns.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(df), 2)``. Column 0 is the probability derived
        from logit 0 and column 1 the probability derived from logit 2
        ("entailment"); each row sums to 1.
    """
    # Local import keeps this cell self-contained: torch is only needed here
    # to switch off autograd.
    import torch

    predictions = []
    # Inference only: without no_grad() every forward pass would build an
    # autograd graph that is thrown away immediately afterwards.
    with torch.no_grad():
        pairs = zip(df["premise"], df["hypothesis"])
        # `total` lets tqdm show a real progress bar instead of a bare counter.
        for premise, hypothesis in tqdm.tqdm(pairs, total=len(df)):
            # run through model pre-trained on MNLI; truncate to the
            # tokenizer's configured maximum length so an over-long pair
            # cannot overflow the model's input size
            x = tokenizer.encode(premise, hypothesis, truncation=True,
                                 return_tensors='pt')
            logits = nli_model(x.to(device))[0]

            # we throw away "neutral" (dim 1) and renormalise over the two
            # remaining classes: dim 0 and "entailment" (dim 2)
            entail_contradiction_logits = logits[:, [0, 2]]
            probs = entail_contradiction_logits.softmax(dim=1)
            predictions.append(probs.cpu().numpy()[0])

    if not predictions:
        # Keep the (n_rows, n_classes) contract for an empty frame as well.
        return np.empty((0, 2), dtype=np.float32)
    return np.array(predictions)

# Giskard wrapper configuration: the prediction callable, the task type, the
# columns handed to the callable, and the label names for the output columns.
# NOTE(review): `prediction_function` returns [softmax(logit 0),
# softmax(logit 2 = "entailment")], so with this label order the entailment
# column is named "False". That appears to be consistent with the dataset
# mapping {0: "False", 2: "True"} only if dataset label 0 is entailment --
# confirm, and change both places together if the naming is to be fixed.
model_args = dict(
    model=prediction_function,
    model_type="classification",
    feature_names=["premise", "hypothesis"],
    classification_labels=["True", "False"],
)
gsk_hf_model = Model(**model_args)

# Sanity check on the first rows: predicting through the Giskard wrapper must
# give exactly the same values as calling the raw function directly.
assert (
    gsk_hf_model.model_predict(df[["premise", "hypothesis"]].head())
    == prediction_function(df[["premise", "hypothesis"]].head())
).all()

5it [00:01,  4.28it/s]
5it [00:01,  4.32it/s]


# Default Scan Results

In [6]:
# Run the default Giskard scan of the wrapped model against the wrapped
# dataset. The recorded run below took well over an hour on CPU.
results = scan(model=gsk_hf_model, dataset=gsk_dataset)

10it [00:02,  4.26it/s]
1it [00:00,  5.53it/s]
6681it [48:46,  2.28it/s]


Your model is successfully validated.
🔎 Running scan…
2023-09-19 19:23:20,016 pid:38636 MainThread giskard.scanner.logger INFO     Running detectors: ['PerformanceBiasDetector', 'TextPerturbationDetector', 'EthicalBiasDetector', 'DataLeakageDetector', 'StochasticityDetector', 'OverconfidenceDetector', 'UnderconfidenceDetector', 'SpuriousCorrelationDetector']
Running detector PerformanceBiasDetector…2023-09-19 19:23:20,018 pid:38636 MainThread giskard.scanner.logger INFO     PerformanceBiasDetector: Running
2023-09-19 19:23:20,019 pid:38636 MainThread giskard.scanner.logger INFO     PerformanceBiasDetector: Calculating loss
2023-09-19 19:23:29,223 pid:38636 MainThread giskard.scanner.logger INFO     PerformanceBiasDetector: Loss calculated (took 0:00:09.202001)
2023-09-19 19:23:29,224 pid:38636 MainThread giskard.scanner.logger INFO     PerformanceBiasDetector: Finding data slices
2023-09-19 19:29:10,531 pid:38636 MainThread giskard.scanner.logger INFO     PerformanceBiasDetector: 3 sli

1000it [08:26,  1.97it/s]

2023-09-19 19:39:39,739 pid:38636 MainThread giskard.scanner.logger INFO     TextPerturbationDetector: Testing `premise` for perturbation `Transform to uppercase`	Fail rate: 0.020



1000it [05:52,  2.83it/s]

2023-09-19 19:45:32,891 pid:38636 MainThread giskard.scanner.logger INFO     TextPerturbationDetector: Testing `hypothesis` for perturbation `Transform to uppercase`	Fail rate: 0.027



1000it [05:29,  3.03it/s]

2023-09-19 19:51:02,944 pid:38636 MainThread giskard.scanner.logger INFO     TextPerturbationDetector: Testing `premise` for perturbation `Transform to lowercase`	Fail rate: 0.013



1000it [03:28,  4.80it/s]

2023-09-19 19:54:31,519 pid:38636 MainThread giskard.scanner.logger INFO     TextPerturbationDetector: Testing `hypothesis` for perturbation `Transform to lowercase`	Fail rate: 0.005



1000it [03:56,  4.23it/s]

2023-09-19 19:58:27,959 pid:38636 MainThread giskard.scanner.logger INFO     TextPerturbationDetector: Testing `premise` for perturbation `Transform to title case`	Fail rate: 0.010



1000it [03:54,  4.27it/s]

2023-09-19 20:02:22,234 pid:38636 MainThread giskard.scanner.logger INFO     TextPerturbationDetector: Testing `hypothesis` for perturbation `Transform to title case`	Fail rate: 0.015



1000it [04:51,  3.43it/s]

2023-09-19 20:07:15,588 pid:38636 MainThread giskard.scanner.logger INFO     TextPerturbationDetector: Testing `premise` for perturbation `Add typos`	Fail rate: 0.045



1000it [04:36,  3.62it/s]

2023-09-19 20:11:53,363 pid:38636 MainThread giskard.scanner.logger INFO     TextPerturbationDetector: Testing `hypothesis` for perturbation `Add typos`	Fail rate: 0.068



1000it [04:24,  3.78it/s]

2023-09-19 20:16:18,283 pid:38636 MainThread giskard.scanner.logger INFO     TextPerturbationDetector: Testing `premise` for perturbation `Punctuation Removal`	Fail rate: 0.004



1000it [04:11,  3.98it/s]

2023-09-19 20:20:30,165 pid:38636 MainThread giskard.scanner.logger INFO     TextPerturbationDetector: Testing `hypothesis` for perturbation `Punctuation Removal`	Fail rate: 0.006
 1 issues detected. (Took 0:49:16.954040)
Running detector EthicalBiasDetector…2023-09-19 20:20:30,168 pid:38636 MainThread giskard.scanner.logger INFO     EthicalBiasDetector: Running with transformations=['Switch Gender', 'Switch Religion', 'Switch countries from high- to low-income and vice versa'] threshold=None output_sensitivity=None num_samples=None



1000it [05:21,  3.11it/s]

2023-09-19 20:26:16,734 pid:38636 MainThread giskard.scanner.logger INFO     EthicalBiasDetector: Testing `premise` for perturbation `Switch Gender`	Fail rate: 0.209



787it [05:30,  2.38it/s]

2023-09-19 20:32:25,899 pid:38636 MainThread giskard.scanner.logger INFO     EthicalBiasDetector: Testing `hypothesis` for perturbation `Switch Gender`	Fail rate: 0.327



151it [01:19,  1.89it/s]

2023-09-19 20:34:31,029 pid:38636 MainThread giskard.scanner.logger INFO     EthicalBiasDetector: Testing `premise` for perturbation `Switch Religion`	Fail rate: 0.139



87it [00:33,  2.62it/s]

2023-09-19 20:35:49,304 pid:38636 MainThread giskard.scanner.logger INFO     EthicalBiasDetector: Testing `hypothesis` for perturbation `Switch Religion`	Fail rate: 0.333



413it [02:31,  2.73it/s]

2023-09-19 20:39:37,280 pid:38636 MainThread giskard.scanner.logger INFO     EthicalBiasDetector: Testing `premise` for perturbation `Switch countries from high- to low-income and vice versa`	Fail rate: 0.288



226it [01:23,  2.72it/s]

2023-09-19 20:41:45,452 pid:38636 MainThread giskard.scanner.logger INFO     EthicalBiasDetector: Testing `hypothesis` for perturbation `Switch countries from high- to low-income and vice versa`	Fail rate: 0.500
 6 issues detected. (Took 0:21:15.252321)
Running detector DataLeakageDetector…2023-09-19 20:41:45,459 pid:38636 MainThread giskard.scanner.logger INFO     DataLeakageDetector: Running



1it [00:00,  3.26it/s]
1it [00:00,  4.48it/s]
1it [00:00,  3.68it/s]
1it [00:00,  3.62it/s]
1it [00:00,  5.44it/s]
1it [00:00,  3.88it/s]
1it [00:00,  4.01it/s]
1it [00:00,  4.37it/s]
1it [00:00,  3.44it/s]
1it [00:00,  2.91it/s]
1it [00:00,  5.05it/s]
1it [00:00,  3.66it/s]
1it [00:00,  3.65it/s]
1it [00:00,  2.57it/s]
1it [00:00,  3.73it/s]
1it [00:00,  4.54it/s]
1it [00:00,  4.78it/s]
1it [00:00,  2.62it/s]
1it [00:00,  4.05it/s]
1it [00:00,  5.80it/s]
1it [00:00,  5.34it/s]
1it [00:00,  2.33it/s]
1it [00:00,  5.63it/s]
1it [00:00,  3.39it/s]
1it [00:00,  3.97it/s]
1it [00:00,  3.68it/s]
1it [00:00,  3.89it/s]
1it [00:00,  4.37it/s]
1it [00:00,  4.36it/s]
1it [00:00,  5.97it/s]
1it [00:00,  3.94it/s]
1it [00:00,  4.36it/s]
1it [00:00,  4.74it/s]
1it [00:00,  3.21it/s]
1it [00:00,  4.14it/s]
1it [00:00,  2.87it/s]
1it [00:00,  2.80it/s]
1it [00:00,  4.35it/s]
1it [00:00,  4.63it/s]
1it [00:00,  3.02it/s]
1it [00:00,  5.52it/s]
1it [00:00,  5.71it/s]
1it [00:00,  3.82it/s]
1it [00:00

 0 issues detected. (Took 0:00:28.476499)
Running detector StochasticityDetector…2023-09-19 20:42:13,937 pid:38636 MainThread giskard.scanner.logger INFO     StochasticityDetector: Running


100it [00:24,  4.05it/s]
100it [00:27,  3.60it/s]

 0 issues detected. (Took 0:00:52.515258)
Running detector OverconfidenceDetector…2023-09-19 20:43:06,455 pid:38636 MainThread giskard.scanner.logger INFO     OverconfidenceDetector: Running
2023-09-19 20:43:06,456 pid:38636 MainThread giskard.scanner.logger INFO     OverconfidenceDetector: Calculating loss
2023-09-19 20:43:06,518 pid:38636 MainThread giskard.scanner.logger INFO     OverconfidenceDetector: Loss calculated (took 0:00:00.059140)
2023-09-19 20:43:06,519 pid:38636 MainThread giskard.scanner.logger INFO     OverconfidenceDetector: Finding data slices





2023-09-19 20:43:22,818 pid:38636 MainThread giskard.scanner.logger INFO     OverconfidenceDetector: 0 slices found (took 0:00:16.297994)
2023-09-19 20:43:22,819 pid:38636 MainThread giskard.scanner.logger INFO     OverconfidenceDetector: Analyzing issues
2023-09-19 20:43:22,868 pid:38636 MainThread giskard.scanner.logger INFO     OverconfidenceDetector: Using overconfidence threshold = 0.5
2023-09-19 20:43:22,870 pid:38636 MainThread giskard.scanner.logger INFO     OverconfidenceDetector: 0 issues found (took 0:00:00.050156)
 0 issues detected. (Took 0:00:16.414924)
Running detector UnderconfidenceDetector…2023-09-19 20:43:22,871 pid:38636 MainThread giskard.scanner.logger INFO     UnderconfidenceDetector: Running
2023-09-19 20:43:22,872 pid:38636 MainThread giskard.scanner.logger INFO     UnderconfidenceDetector: Calculating loss
2023-09-19 20:43:22,909 pid:38636 MainThread giskard.scanner.logger INFO     UnderconfidenceDetector: Loss calculated (took 0:00:00.035175)
2023-09-19 20:43

In [7]:
# Persist the scan report as a standalone HTML file in the working directory.
report_path = "facebook--bart-large-mnli.html"
results.to_html(report_path)