In [1]:
import numpy as np

from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from giskard import scan
from giskard import Dataset, Model
import tqdm

In [2]:
dataset = load_dataset("multi_nli", "validation_matched")
df = dataset["validation_matched"].to_pandas()

# we throw away "neutral" (dim 1) and take the probability of
# "entailment" (2) as the probability of the label being true
mapping = {0:"False", 2: "True"}
df = df.replace({"label": mapping})
df = df[df.label != 1]
len(df)

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation_matched split: 0 examples [00:00, ? examples/s]

Generating validation_mismatched split: 0 examples [00:00, ? examples/s]

6692

## Wrapping the dataset

In [3]:
gsk_dataset = Dataset(df, target="label")

In [4]:
nli_model = AutoModelForSequenceClassification.from_pretrained('MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7')
tokenizer = AutoTokenizer.from_pretrained('MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7')
device = "cpu" # running on mac

## Wrapping the model

In [5]:
def prediction_function(df):
    predictions = []
    for index, row in tqdm.tqdm(df.iterrows()):
        premise = row.premise
        hypothesis = row.hypothesis

        # run through model pre-trained on MNLI
        x = tokenizer.encode(premise, hypothesis, return_tensors='pt')
        logits = nli_model(x.to(device))[0]

        # we throw away "neutral" (dim 1) and take the probability of
        # "entailment" (2) as the probability of the label being true
        entail_contradiction_logits = logits[:,[0,2]]
        probs = entail_contradiction_logits.softmax(dim=1)
        #prob_label_is_true = probs[:,1]
        predictions.append(probs.cpu().detach().numpy()[0])
    return np.array(predictions)

model_args = {"model": prediction_function,
              "model_type": "classification",
              "feature_names": ["premise", "hypothesis"],
              "classification_labels": ["True", "False"]}
gsk_hf_model = Model(**model_args)
assert (gsk_hf_model.model_predict(df[["premise", "hypothesis"]].head()) == prediction_function(df[["premise", "hypothesis"]].head())).all()

5it [00:02,  2.25it/s]
5it [00:02,  2.33it/s]


# Default Scan Results

In [6]:
results = scan(gsk_hf_model, gsk_dataset)

10it [00:03,  3.10it/s]
1it [00:00,  3.57it/s]
6681it [31:52,  3.49it/s]


Your model is successfully validated.
🔎 Running scan…
2023-09-21 15:44:41,849 pid:3782 MainThread giskard.scanner.logger INFO     Running detectors: ['PerformanceBiasDetector', 'TextPerturbationDetector', 'EthicalBiasDetector', 'DataLeakageDetector', 'StochasticityDetector', 'OverconfidenceDetector', 'UnderconfidenceDetector', 'SpuriousCorrelationDetector']
Running detector PerformanceBiasDetector…2023-09-21 15:44:41,853 pid:3782 MainThread giskard.scanner.logger INFO     PerformanceBiasDetector: Running
2023-09-21 15:44:41,853 pid:3782 MainThread giskard.scanner.logger INFO     PerformanceBiasDetector: Calculating loss




2023-09-21 15:44:47,481 pid:3782 MainThread giskard.scanner.logger INFO     PerformanceBiasDetector: Loss calculated (took 0:00:05.626354)
2023-09-21 15:44:47,482 pid:3782 MainThread giskard.scanner.logger INFO     PerformanceBiasDetector: Finding data slices
2023-09-21 15:50:18,811 pid:3782 MainThread giskard.scanner.logger INFO     PerformanceBiasDetector: 34 slices found (took 0:05:31.317688)
2023-09-21 15:50:18,816 pid:3782 MainThread giskard.scanner.logger INFO     PerformanceBiasDetector: Analyzing issues
2023-09-21 15:50:18,926 pid:3782 MainThread giskard.scanner.logger INFO     PerformanceBiasDetector: Testing 34 slices for performance issues.
2023-09-21 15:50:51,385 pid:3782 MainThread giskard.scanner.logger INFO     PerformanceBiasDetector: Testing slice `avg_whitespace(premise)` >= 0.178 AND `avg_whitespace(premise)` < 0.184	Precision = 0.032 (global 0.041) Δm = -0.206	is_issue = True
2023-09-21 15:51:53,907 pid:3782 MainThread giskard.scanner.logger INFO     PerformanceBias

1000it [06:12,  2.68it/s]

2023-09-21 16:00:21,804 pid:3782 MainThread giskard.scanner.logger INFO     TextPerturbationDetector: Testing `premise` for perturbation `Transform to uppercase`	Fail rate: 0.037



1000it [05:50,  2.85it/s]

2023-09-21 16:06:12,574 pid:3782 MainThread giskard.scanner.logger INFO     TextPerturbationDetector: Testing `hypothesis` for perturbation `Transform to uppercase`	Fail rate: 0.032



1000it [04:28,  3.73it/s]

2023-09-21 16:10:40,852 pid:3782 MainThread giskard.scanner.logger INFO     TextPerturbationDetector: Testing `premise` for perturbation `Transform to lowercase`	Fail rate: 0.008



1000it [04:54,  3.39it/s]

2023-09-21 16:15:35,652 pid:3782 MainThread giskard.scanner.logger INFO     TextPerturbationDetector: Testing `hypothesis` for perturbation `Transform to lowercase`	Fail rate: 0.005



1000it [06:10,  2.70it/s]

2023-09-21 16:21:45,816 pid:3782 MainThread giskard.scanner.logger INFO     TextPerturbationDetector: Testing `premise` for perturbation `Transform to title case`	Fail rate: 0.012



1000it [05:02,  3.31it/s]

2023-09-21 16:26:48,161 pid:3782 MainThread giskard.scanner.logger INFO     TextPerturbationDetector: Testing `hypothesis` for perturbation `Transform to title case`	Fail rate: 0.015



1000it [04:44,  3.51it/s]

2023-09-21 16:31:34,380 pid:3782 MainThread giskard.scanner.logger INFO     TextPerturbationDetector: Testing `premise` for perturbation `Add typos`	Fail rate: 0.045



1000it [04:23,  3.80it/s]


2023-09-21 16:35:58,301 pid:3782 MainThread giskard.scanner.logger INFO     TextPerturbationDetector: Testing `hypothesis` for perturbation `Add typos`	Fail rate: 0.078


1000it [05:28,  3.04it/s]

2023-09-21 16:41:27,306 pid:3782 MainThread giskard.scanner.logger INFO     TextPerturbationDetector: Testing `premise` for perturbation `Punctuation Removal`	Fail rate: 0.014



1000it [09:29,  1.76it/s]

2023-09-21 16:50:56,877 pid:3782 MainThread giskard.scanner.logger INFO     TextPerturbationDetector: Testing `hypothesis` for perturbation `Punctuation Removal`	Fail rate: 0.007





 1 issues detected. (Took 0:56:47.999879)
Running detector EthicalBiasDetector…2023-09-21 16:50:56,949 pid:3782 MainThread giskard.scanner.logger INFO     EthicalBiasDetector: Running with transformations=['Switch Gender', 'Switch Religion', 'Switch countries from high- to low-income and vice versa'] threshold=None output_sensitivity=None num_samples=None


1000it [10:36,  1.57it/s]

2023-09-21 17:01:35,263 pid:3782 MainThread giskard.scanner.logger INFO     EthicalBiasDetector: Testing `premise` for perturbation `Switch Gender`	Fail rate: 0.133



787it [09:08,  1.43it/s]

2023-09-21 17:10:44,401 pid:3782 MainThread giskard.scanner.logger INFO     EthicalBiasDetector: Testing `hypothesis` for perturbation `Switch Gender`	Fail rate: 0.274



149it [01:05,  2.29it/s]

2023-09-21 17:11:51,176 pid:3782 MainThread giskard.scanner.logger INFO     EthicalBiasDetector: Testing `premise` for perturbation `Switch Religion`	Fail rate: 0.128



86it [00:43,  1.96it/s]

2023-09-21 17:12:36,400 pid:3782 MainThread giskard.scanner.logger INFO     EthicalBiasDetector: Testing `hypothesis` for perturbation `Switch Religion`	Fail rate: 0.360



413it [04:35,  1.50it/s]

2023-09-21 17:18:42,650 pid:3782 MainThread giskard.scanner.logger INFO     EthicalBiasDetector: Testing `premise` for perturbation `Switch countries from high- to low-income and vice versa`	Fail rate: 0.278



226it [04:04,  1.08s/it]

2023-09-21 17:23:27,304 pid:3782 MainThread giskard.scanner.logger INFO     EthicalBiasDetector: Testing `hypothesis` for perturbation `Switch countries from high- to low-income and vice versa`	Fail rate: 0.469
 6 issues detected. (Took 0:32:30.402893)
Running detector DataLeakageDetector…2023-09-21 17:23:27,346 pid:3782 MainThread giskard.scanner.logger INFO     DataLeakageDetector: Running



1it [00:00,  2.61it/s]
1it [00:00,  3.21it/s]
1it [00:00,  2.79it/s]
1it [00:00,  2.08it/s]
1it [00:00,  2.63it/s]
1it [00:00,  2.41it/s]
1it [00:00,  2.94it/s]
1it [00:00,  3.32it/s]
1it [00:00,  2.07it/s]
1it [00:00,  2.59it/s]
1it [00:00,  3.25it/s]
1it [00:00,  1.94it/s]
1it [00:00,  2.08it/s]
1it [00:00,  2.13it/s]
1it [00:00,  2.64it/s]
1it [00:00,  3.08it/s]
1it [00:00,  2.35it/s]
1it [00:00,  1.93it/s]
1it [00:00,  1.84it/s]
1it [00:00,  3.33it/s]
1it [00:00,  3.92it/s]
1it [00:00,  1.24it/s]
1it [00:00,  2.54it/s]
1it [00:00,  2.78it/s]
1it [00:00,  3.01it/s]
1it [00:00,  3.07it/s]
1it [00:00,  2.87it/s]
1it [00:00,  3.26it/s]
1it [00:00,  3.50it/s]
1it [00:00,  3.83it/s]
1it [00:00,  3.24it/s]
1it [00:00,  3.09it/s]
1it [00:00,  3.22it/s]
1it [00:00,  2.62it/s]
1it [00:00,  3.09it/s]
1it [00:00,  2.46it/s]
1it [00:00,  2.48it/s]
1it [00:00,  2.82it/s]
1it [00:00,  2.73it/s]
1it [00:00,  2.62it/s]
1it [00:00,  3.48it/s]
1it [00:00,  3.66it/s]
1it [00:00,  2.83it/s]
1it [00:00

 0 issues detected. (Took 0:00:38.753444)
Running detector StochasticityDetector…2023-09-21 17:24:06,103 pid:3782 MainThread giskard.scanner.logger INFO     StochasticityDetector: Running



100it [00:32,  3.10it/s]
100it [00:41,  2.43it/s]

 0 issues detected. (Took 0:01:13.458618)
Running detector OverconfidenceDetector…2023-09-21 17:25:19,561 pid:3782 MainThread giskard.scanner.logger INFO     OverconfidenceDetector: Running
2023-09-21 17:25:19,562 pid:3782 MainThread giskard.scanner.logger INFO     OverconfidenceDetector: Calculating loss
2023-09-21 17:25:19,637 pid:3782 MainThread giskard.scanner.logger INFO     OverconfidenceDetector: Loss calculated (took 0:00:00.073837)
2023-09-21 17:25:19,638 pid:3782 MainThread giskard.scanner.logger INFO     OverconfidenceDetector: Finding data slices





2023-09-21 17:27:46,830 pid:3782 MainThread giskard.scanner.logger INFO     OverconfidenceDetector: 3 slices found (took 0:02:27.191048)
2023-09-21 17:27:46,832 pid:3782 MainThread giskard.scanner.logger INFO     OverconfidenceDetector: Analyzing issues
2023-09-21 17:27:46,909 pid:3782 MainThread giskard.scanner.logger INFO     OverconfidenceDetector: Using overconfidence threshold = 0.5
2023-09-21 17:27:47,035 pid:3782 MainThread giskard.scanner.logger INFO     OverconfidenceDetector: 0 issues found (took 0:00:00.202777)
 0 issues detected. (Took 0:02:27.475981)
Running detector UnderconfidenceDetector…2023-09-21 17:27:47,037 pid:3782 MainThread giskard.scanner.logger INFO     UnderconfidenceDetector: Running
2023-09-21 17:27:47,038 pid:3782 MainThread giskard.scanner.logger INFO     UnderconfidenceDetector: Calculating loss
2023-09-21 17:27:47,123 pid:3782 MainThread giskard.scanner.logger INFO     UnderconfidenceDetector: Loss calculated (took 0:00:00.082467)
2023-09-21 17:27:47,125

In [7]:
results.to_html("MoritzLaurer--mDeBERTa-v3-base-xnli-multilingual-nli-2mil7.html")