In [1]:
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import pandas as pd
import checklist
from checklist.editor import Editor
from checklist.expect import Expect
from checklist.pred_wrapper import PredictorWrapper
from checklist.test_types import MFT
from typing import List
import warnings
warnings.filterwarnings('ignore')

# MFTs: Introduction
In this notebook, we will create Minimum Functionality Tests (MFTs) for a generative language model. MFTs test one specific function of a language model. They are analogous to unit tests in traditional software engineering.

## Setup generative model
Before we can test anything, we need to set up our language model. We will use the HuggingFace transformers library to load a GPT2 model.

First, we create a tokenizer. The tokenizer is responsible for splitting strings into tokens (whole words or pieces of words), then converting those tokens into integer IDs that our model can understand.

In [2]:
# Load pretrained model tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# Demonstrate what the tokenizer does
tokenizer.encode("Wherefore art thou Romeo?")

[8496, 754, 1242, 14210, 43989, 30]
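
The IDs above index into GPT-2's subword vocabulary. As a rough illustration of the idea, the sketch below uses a tiny hypothetical vocabulary (`TOY_VOCAB`) and greedy longest-match splitting. GPT-2's actual byte-pair encoding uses a vocabulary of roughly 50,000 entries and a different procedure, so the pieces and IDs here are assumptions, not what `GPT2Tokenizer` produces.

```python
# Hypothetical toy vocabulary mapping text pieces to integer IDs.
TOY_VOCAB = {"Where": 0, "fore": 1, " art": 2, " thou": 3, " Romeo": 4, "?": 5}
ID_TO_PIECE = {idx: piece for piece, idx in TOY_VOCAB.items()}


def toy_encode(text):
    """Split text greedily into the longest known pieces and return their IDs."""
    ids = []
    pos = 0
    while pos < len(text):
        candidates = [piece for piece in TOY_VOCAB if text.startswith(piece, pos)]
        if not candidates:
            raise ValueError(f"no vocabulary piece matches at position {pos}")
        piece = max(candidates, key=len)
        ids.append(TOY_VOCAB[piece])
        pos += len(piece)
    return ids


def toy_decode(ids):
    """Map IDs back to pieces and join them into the original string."""
    return "".join(ID_TO_PIECE[i] for i in ids)


ids = toy_encode("Wherefore art thou Romeo?")
print(ids)              # [0, 1, 2, 3, 4, 5]
print(toy_decode(ids))  # Wherefore art thou Romeo?
```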

Our tokenizer has turned the human-readable text into a list of numbers that the model understands. Next, let's load the GPT2 model.

In [3]:
# Load pretrained model (weights)
model = GPT2LMHeadModel.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id)
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # fall back to CPU when no GPU is available
model.eval()
model.to(device)
"Model loaded"

'Model loaded'

Generating text with the model requires a bit of work. Let's write a function `generate_sentences` to handle the text generation.

`generate_sentences` has one parameter, `prompts`, which is a list of strings. A prompt is a string that the model uses as a starting point for generating new text. It gives the model context about what kind of text should be generated.

`generate_sentences` returns a list containing one generated continuation per prompt.

In [4]:
def generate_sentences(prompts: List[str]) -> List[str]:
    sentences = []
    for prompt in prompts:
        token_tensor = tokenizer.encode(prompt, return_tensors='pt').to(device) # return_tensors = "pt" returns a PyTorch tensor
        out = model.generate(
            token_tensor,
            do_sample=True,
            min_length=10,
            max_length=50,
            num_beams=1,
            no_repeat_ngram_size=2,
            early_stopping=False,
            output_scores=True,
            return_dict_in_generate=True)
        text = tokenizer.decode(out.sequences[0], skip_special_tokens=True)
        # The decoded text starts with the prompt; keep only the generated continuation
        sentences.append(text[len(prompt):])
    return sentences

In [5]:
generate_sentences(["Wherefore art thou Romeo?"])

[' and how shall the day be unto thee? Let not thy love be so tender, neither shalt thou waken in wrath; nor shalt you do any hurt like for thee: nor shall thou stand in thy place; let']
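
Because `generate_sentences` passes `do_sample=True`, the model draws each next token at random from its predicted distribution, so rerunning the cell will generally produce different text. The sketch below illustrates the difference between greedy and sampled selection using a made-up distribution. `NEXT_TOKEN_PROBS`, `greedy_pick`, and `sample_pick` are hypothetical names for illustration and are not part of transformers.

```python
import random

# Hypothetical next-token probabilities for some prompt (made-up numbers).
NEXT_TOKEN_PROBS = {" English": 0.5, " Spanish": 0.3, " vernacular": 0.2}


def greedy_pick(probs):
    """Always choose the single most likely token."""
    return max(probs, key=probs.get)


def sample_pick(probs, rng):
    """Draw one token at random, weighted by its probability."""
    tokens = list(probs)
    weights = [probs[token] for token in tokens]
    return rng.choices(tokens, weights=weights, k=1)[0]


print(greedy_pick(NEXT_TOKEN_PROBS))  # ' English' every time

rng = random.Random()  # unseeded: the draws differ between runs
print([sample_pick(NEXT_TOKEN_PROBS, rng) for _ in range(5)])
```

Passing a fixed seed to `random.Random` makes the toy draws repeatable, which is the same reason seeding is useful when test results need to be reproduced.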

Now that everything is ready, we can write our first MFT.

## MFT - Language prompt
For this MFT, we will expect the model to create a reasonable continuation of a prompt. The model will be prompted with strings like "In {country} the most commonly spoken language is " where {country} is a placeholder for a country such as Spain.

We need to create a rule to determine whether the model passes our test. The criteria for passing or failing are entirely user defined. We will consider this MFT to pass if the model's output contains any language name, which demonstrates that the model understands the general context of the prompt. The mentioned language doesn't have to be accurate: for example, "In Spain the most commonly spoken language is Indonesian" would pass our test, because Indonesian is a language. The language may also be located anywhere in the output: for example, "In Spain the most commonly spoken language is not easy to learn. Spanish has many complicated conjugations." would also pass our test.
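
As a minimal sketch of this rule, the function below checks whether any name from a list occurs anywhere in a response. `mentions_any_language` and `sample_languages` are hypothetical names used only for illustration, and the short list stands in for the full list of languages the notebook loads from a CSV file.

```python
from typing import Iterable


def mentions_any_language(response: str, languages: Iterable[str]) -> bool:
    """Return True if any language name occurs anywhere in the response."""
    return any(language in response for language in languages)


# Small illustrative list of language names (assumption, not the real data).
sample_languages = ["Spanish", "Indonesian", "English"]

# An inaccurate language still passes.
print(mentions_any_language("Indonesian", sample_languages))  # True

# A language mentioned later in the output also passes.
print(mentions_any_language(
    "not easy to learn. Spanish has many complicated conjugations.",
    sample_languages,
))  # True

# No language name at all fails.
print(mentions_any_language("spoken by most people.", sample_languages))  # False
```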

In a later section of this notebook, there is another version of this MFT that is stricter, requiring the correct language to be mentioned in the response.

### Handwritten MFT
First, we will write the MFT by hand. Then, we'll use Checklist's MFT class to demonstrate how Checklist helps us create the MFT much more quickly.

#### Generate prompts from template
We will use Checklist's Editor class to quickly create the prompts.

In [6]:
editor = Editor()
prompt_strs = editor.template("In {country} the most commonly spoken language is ")
prompt_strs.data = prompt_strs.data[0:10]
prompt_strs.data

['In China the most commonly spoken language is ',
 'In India the most commonly spoken language is ',
 'In United States the most commonly spoken language is ',
 'In Indonesia the most commonly spoken language is ',
 'In Brazil the most commonly spoken language is ',
 'In Pakistan the most commonly spoken language is ',
 'In Nigeria the most commonly spoken language is ',
 'In Bangladesh the most commonly spoken language is ',
 'In Russia the most commonly spoken language is ',
 'In Mexico the most commonly spoken language is ']

#### Language CSV
We need a list of languages to check whether the model's output contains a language. To save some time, we will read language names from a CSV file. The data comes from the ISO language codes dataset at https://datahub.io/core/language-codes

In [7]:
import urllib.request
urllib.request.urlretrieve('https://datahub.io/core/language-codes/r/language-codes.csv', 'language-codes.csv')
lang_codes_csv = pd.read_csv('language-codes.csv')
lang_codes_csv

Unnamed: 0,alpha2,English
0,aa,Afar
1,ab,Abkhazian
2,ae,Avestan
3,af,Afrikaans
4,ak,Akan
...,...,...
179,yi,Yiddish
180,yo,Yoruba
181,za,Zhuang; Chuang
182,zh,Chinese
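
Some entries in the `English` column, such as `Zhuang; Chuang`, hold several names separated by semicolons. A plain substring check against the whole entry only matches if the response contains the full string `Zhuang; Chuang`. One possible preprocessing step, shown here as a sketch with a hypothetical helper `split_language_names`, is to split such entries into individual names.

```python
from typing import List


def split_language_names(entries: List[str]) -> List[str]:
    """Split entries like 'Zhuang; Chuang' into separate, stripped names."""
    names = []
    for entry in entries:
        for name in entry.split(";"):
            name = name.strip()
            if name:
                names.append(name)
    return names


print(split_language_names(["Yoruba", "Zhuang; Chuang", "Chinese"]))
# ['Yoruba', 'Zhuang', 'Chuang', 'Chinese']
```

The tests in this notebook use the unsplit list, so this step is optional.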


#### Run the MFT
Now we're ready to create the MFT. We will create 3 Pandas dataframes, one each for prompts, responses, and results. Then, we will loop over the prompts, send each prompt to the model, and determine if it passes or fails the test. Each prompt and its test result will be recorded in the dataframes.

In [8]:
prompts = pd.DataFrame({"id": [], "prompt": []})
responses = pd.DataFrame({"id": [], "response": []})
results = pd.DataFrame({"id": [], "p/f": []})
langs = lang_codes_csv["English"].tolist()

model_responses = generate_sentences(prompt_strs.data)

for (i, response) in enumerate(model_responses):
    pf = 'fail'
    
    # Check if any language from the CSV data is in the generated string
    for l in langs:
        if l in response:
            pf = 'pass'
            break

    # Note: DataFrame.append was removed in pandas 2.0, so this cell requires an older pandas version
    prompts = prompts.append({"id": i, "prompt": prompt_strs.data[i]}, ignore_index=True)
    responses = responses.append({"id": i, "response": response}, ignore_index=True)
    results = results.append({"id": i, "p/f": pf}, ignore_index=True)

#### Show test results
Now let's look at the results of our test.

In [9]:
pd.set_option("max_colwidth", 250)

In [10]:
prompts

Unnamed: 0,id,prompt
0,0.0,In China the most commonly spoken language is
1,1.0,In India the most commonly spoken language is
2,2.0,In United States the most commonly spoken language is
3,3.0,In Indonesia the most commonly spoken language is
4,4.0,In Brazil the most commonly spoken language is
5,5.0,In Pakistan the most commonly spoken language is
6,6.0,In Nigeria the most commonly spoken language is
7,7.0,In Bangladesh the most commonly spoken language is
8,8.0,In Russia the most commonly spoken language is
9,9.0,In Mexico the most commonly spoken language is


In [11]:
responses

Unnamed: 0,id,response
0,0.0,"一郭妶, the official Chinese transliteration of 高極度棘和.\n\nBut in the latter part of the 19th century, they brought"
1,1.0,"vernacular spoken in India. One of the reasons is many language and culture are not shared as widely. However, many people find themselves conversing together in the future, like Indians, as well as in"
2,2.0,"vernacular urn, spoken in all over the country about two-thirds of the time.\n\nAs it turns out, in many ways as well, English is the only official language in any"
3,3.0,"ൂຈೳಈ, followed by เส and ยഢ. In Malaysia there is 몬비람"
4,4.0,한국어 고가장까 트정 件관은.\n\n(In
5,5.0,"vernodá, commonly called nihār (plural nārs), which is not a contraction of ෌ or taurí, but is a form of the word rah"
6,6.0,"vernacular spoken by 80 percent of the people in Nigeria.\n\nLanguage Development\n, based on English, can be used in education in different areas. Languages like Chinese, Cantonese, Chinese/"
7,7.0,"vernacular Bengali, with a small number of local speakers, including Bengalis speaking here but also speakers of other languages including English and Bengala (Bengala is another local language).\n\nFor"
8,8.0,"свите (говкно). Its dialect was invented by the Soviet state.\n\nIn the early 20th century, the Russian language was widely spoken through the medium"
9,9.0,"vernacular in many Latin American countries of the world. The English pronunciation is in Spanish, French and some other Romance languages. In Argentina, Spanish and Español are used as their official languages,"


In [12]:
results

Unnamed: 0,id,p/f
0,0.0,pass
1,1.0,fail
2,2.0,pass
3,3.0,pass
4,4.0,fail
5,5.0,fail
6,6.0,pass
7,7.0,pass
8,8.0,pass
9,9.0,pass


We can merge all the dataframes to make the results easier to read.

In [13]:
merged = pd.merge(responses, results, on="id")
merged = pd.merge(prompts, merged, on="id")
merged

Unnamed: 0,id,prompt,response,p/f
0,0.0,In China the most commonly spoken language is,"一郭妶, the official Chinese transliteration of 高極度棘和.\n\nBut in the latter part of the 19th century, they brought",pass
1,1.0,In India the most commonly spoken language is,"vernacular spoken in India. One of the reasons is many language and culture are not shared as widely. However, many people find themselves conversing together in the future, like Indians, as well as in",fail
2,2.0,In United States the most commonly spoken language is,"vernacular urn, spoken in all over the country about two-thirds of the time.\n\nAs it turns out, in many ways as well, English is the only official language in any",pass
3,3.0,In Indonesia the most commonly spoken language is,"ൂຈೳಈ, followed by เส and ยഢ. In Malaysia there is 몬비람",pass
4,4.0,In Brazil the most commonly spoken language is,한국어 고가장까 트정 件관은.\n\n(In,fail
5,5.0,In Pakistan the most commonly spoken language is,"vernodá, commonly called nihār (plural nārs), which is not a contraction of ෌ or taurí, but is a form of the word rah",fail
6,6.0,In Nigeria the most commonly spoken language is,"vernacular spoken by 80 percent of the people in Nigeria.\n\nLanguage Development\n, based on English, can be used in education in different areas. Languages like Chinese, Cantonese, Chinese/",pass
7,7.0,In Bangladesh the most commonly spoken language is,"vernacular Bengali, with a small number of local speakers, including Bengalis speaking here but also speakers of other languages including English and Bengala (Bengala is another local language).\n\nFor",pass
8,8.0,In Russia the most commonly spoken language is,"свите (говкно). Its dialect was invented by the Soviet state.\n\nIn the early 20th century, the Russian language was widely spoken through the medium",pass
9,9.0,In Mexico the most commonly spoken language is,"vernacular in many Latin American countries of the world. The English pronunciation is in Spanish, French and some other Romance languages. In Argentina, Spanish and Español are used as their official languages,",pass


Finally, let's display the failing tests.

In [14]:
merged.loc[merged['p/f'] == 'fail']

Unnamed: 0,id,prompt,response,p/f
1,1.0,In India the most commonly spoken language is,"vernacular spoken in India. One of the reasons is many language and culture are not shared as widely. However, many people find themselves conversing together in the future, like Indians, as well as in",fail
4,4.0,In Brazil the most commonly spoken language is,한국어 고가장까 트정 件관은.\n\n(In,fail
5,5.0,In Pakistan the most commonly spoken language is,"vernodá, commonly called nihār (plural nārs), which is not a contraction of ෌ or taurí, but is a form of the word rah",fail
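
The handwritten version leaves summary statistics to us. Here is a minimal sketch of computing the failure rate for this run, using only the standard library and the pass/fail values shown in the results table above. The variable names are illustrative.

```python
from collections import Counter

# Pass/fail outcomes copied from the results table above, in row order.
outcomes = ["pass", "fail", "pass", "pass", "fail",
            "fail", "pass", "pass", "pass", "pass"]

counts = Counter(outcomes)
fail_rate = counts["fail"] / len(outcomes)

print(f"Test cases: {len(outcomes)}")
print(f"Fails (rate): {counts['fail']} ({fail_rate:.1%})")
# Test cases: 10
# Fails (rate): 3 (30.0%)
```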


### Test with Checklist

Next, let's try running the MFT with Checklist. We will no longer need to keep track of results in Pandas dataframes, since Checklist will track the results for us.

#### Create the expectation function
To determine whether an example passes or fails the test, Checklist uses an expectation function. An expectation function receives the example, then returns `True` if the example passes the test, or `False` if the example fails.

In [15]:
def response_contains_language(x, pred, conf, label=None, meta=None):
    for l in langs:
        if l in pred:
            return True
    return False

We will wrap this function with `Expect.single`, which causes the expectation function to be called for each example. In other cases, you might want an expectation function that checks multiple examples simultaneously. See the tutorial notebook "3. Test types, expectation functions, running tests" for detailed information about expectation functions.

In [16]:
contains_language_expect_fn = Expect.single(response_contains_language)
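
Conceptually, wrapping a per-example function means applying it to every example in turn and collecting one result per example. The sketch below illustrates only that idea. `apply_to_each` and `mentions_english` are hypothetical helpers with a simplified two-argument signature, and this is not Checklist's implementation of `Expect.single`.

```python
from typing import Callable, List


def apply_to_each(
    check: Callable[[str, str], bool],
) -> Callable[[List[str], List[str]], List[bool]]:
    """Turn a per-example check into one that handles lists of examples."""
    def check_all(inputs: List[str], preds: List[str]) -> List[bool]:
        return [check(x, pred) for x, pred in zip(inputs, preds)]
    return check_all


def mentions_english(x: str, pred: str) -> bool:
    # Simplified per-example check (assumption: only the input and prediction are used).
    return "English" in pred


check_all = apply_to_each(mentions_english)
print(check_all(["prompt A", "prompt B"], ["English is common", "vernacular"]))
# [True, False]
```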

Now we can feed our prompts and expectation function into the MFT constructor.

In [17]:
test = MFT(**prompt_strs, name='Language in response', description='The response contains a language.', expect=contains_language_expect_fn)

In order to run the test, Checklist also needs a function that generates the model's predictions for the inputs. The function receives all inputs (prompts) as a list, and must return the results in a tuple `(model_predictions, confidences)`, where `model_predictions` is a list of all the predictions, and `confidences` is a list of the model's scores for those predictions.

We will not be using confidences in this test. Checklist provides a wrapper function `PredictorWrapper.wrap_predict()` that outputs a tuple with a confidence score of 1 for any prediction. We can use it to wrap `generate_sentences` so the predictions will have a confidence score as needed.

In [18]:
wrapped_generator = PredictorWrapper.wrap_predict(generate_sentences)
wrapped_generator(["In Brazil the most commonly spoken language is "])

(['усский авого (Friesen). A native of Romania has a fondness for many of the classical styles of Chinese and Japanese, and has an interest'],
 array([1.]))
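
The wrapper's effect can be pictured with a small stand-in. `with_unit_confidence` and `fake_generate` below are hypothetical names used for illustration. This is a sketch of the behaviour described above, not Checklist's code, and it returns a plain list of confidences where the output above shows a NumPy array.

```python
from typing import Callable, List, Tuple


def with_unit_confidence(
    predict: Callable[[List[str]], List[str]],
) -> Callable[[List[str]], Tuple[List[str], List[float]]]:
    """Wrap a predictor so it returns (predictions, confidences of 1.0)."""
    def wrapped(inputs: List[str]) -> Tuple[List[str], List[float]]:
        preds = predict(inputs)
        return preds, [1.0] * len(preds)
    return wrapped


def fake_generate(prompts: List[str]) -> List[str]:
    # Stand-in for generate_sentences so the sketch runs without a model.
    return [prompt.upper() for prompt in prompts]


wrapped = with_unit_confidence(fake_generate)
print(wrapped(["hello", "world"]))
# (['HELLO', 'WORLD'], [1.0, 1.0])
```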

Now we're ready to run the test. The first argument to `test.run()` is the generator function we just created. We will also set the optional parameter `overwrite=True` so the test can be re-run without an error. If `overwrite=False`, Checklist rejects subsequent test runs to prevent us from accidentally overwriting our test results.

In [19]:
test.run(wrapped_generator, overwrite=True)

Predicting 10 examples


To see the results, we can use the `summary` function.

In [20]:
def format_example(x, pred, conf, label=None, meta=None): 
    return 'Prompt:      %s\nCompletion:      %s' % (x, pred) 

In [21]:
test.summary(format_example_fn = format_example)

Test cases:      10
Fails (rate):    2 (20.0%)

Example fails:
Prompt:      In Nigeria the most commonly spoken language is 
Completion:      ಡೕక, which means "The one with the blue skin," and it's pronounced like: the one for the yellow one."
----
Prompt:      In Pakistan the most commonly spoken language is 
Completion:      vernacular - this is achieved by having the accent applied throughout the dialect. While some languages, such as ফ঴টকড and েଦ�
----


Test results can also be explored visually by using the `visual_summary` function.

In [22]:
test.visual_summary()

TestSummarizer(stats={'npassed': 8, 'nfailed': 2, 'nfiltered': 0}, summarizer={'name': 'Language in response',…

## MFT - Language prompt with accurate response

Let's make our test a little stricter to better understand the model's behavior. We will now require the model to respond with the correct language instead of any language in general. To simplify the logic, we will limit the prompts to specific countries. By using the `meta=True` argument for `editor.template()`, the country associated with each prompt will be stored in the `country_prompts` object.


In [23]:
country_prompts = editor.template("In {country} the most commonly spoken language is ", country = ["United States", "France", "Guatemala", "Mongolia", "Japan"], meta=True)
correct_responses = {
    "United States": "English",
    "France": "French",
    "Guatemala": "Spanish",
    "Mongolia": "Mongolian",
    "Japan": "Japanese"
}

The country metadata can be accessed with `country_prompts.meta`.

In [24]:
country_prompts.meta

[{'country': 'United States'},
 {'country': 'France'},
 {'country': 'Guatemala'},
 {'country': 'Mongolia'},
 {'country': 'Japan'}]
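
The relationship between `data` and `meta` can be sketched without Checklist: each prompt is produced by filling the placeholder, and the value used is recorded at the same index. `fill_template` is a hypothetical helper for illustration and does not reproduce the internals of `editor.template()`.

```python
from typing import Dict, List, Tuple


def fill_template(
    template: str, countries: List[str]
) -> Tuple[List[str], List[Dict[str, str]]]:
    """Fill {country} for each value and record the value used per prompt."""
    data = [template.format(country=country) for country in countries]
    meta = [{"country": country} for country in countries]
    return data, meta


data, meta = fill_template(
    "In {country} the most commonly spoken language is ",
    ["United States", "France"],
)
for prompt, info in zip(data, meta):
    print(repr(prompt), info)
# 'In United States the most commonly spoken language is ' {'country': 'United States'}
# 'In France the most commonly spoken language is ' {'country': 'France'}
```

Because the two lists share indices, `meta[i]["country"]` identifies the country that was substituted into `data[i]`.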

### Handwritten Test

In [25]:
prompts = pd.DataFrame({"id": [], "prompt": []})
responses = pd.DataFrame({"id": [], "response": []})
test_results = pd.DataFrame({"id": [], "p/f": []})

model_responses = generate_sentences(country_prompts.data)

for (i, response) in enumerate(model_responses):
    pf = 'fail'
    country = country_prompts.meta[i]["country"]
    
    # Check if the correct language is in the response
    language = correct_responses[country]
    if language in response:
        pf = 'pass'

    # Note: DataFrame.append was removed in pandas 2.0, so this cell requires an older pandas version
    prompts = prompts.append({"id": i, "prompt": country_prompts.data[i]}, ignore_index=True)
    responses = responses.append({"id": i, "response": response}, ignore_index=True)
    test_results = test_results.append({"id": i, "p/f": pf}, ignore_index=True)


#### Show test results
Let's look at our test results. The first dataframe contains the prompts given to the model.

In [26]:
prompts

Unnamed: 0,id,prompt
0,0.0,In United States the most commonly spoken language is
1,1.0,In France the most commonly spoken language is
2,2.0,In Guatemala the most commonly spoken language is
3,3.0,In Mongolia the most commonly spoken language is
4,4.0,In Japan the most commonly spoken language is


The next dataframe shows the model's response to each prompt (not including the prompt itself).

In [27]:
responses

Unnamed: 0,id,response
0,0.0,ˈpɑd or pɠɒk.\n\n(Pronunciation) Upright or uncouth Uˈɡr\n. It is not a part of the
1,1.0,"French, and as mentioned above, English is second-last in terms of popularity. The country is also home to a number of French speakers. All of these languages are in English, so if you"
2,2.0,"vernacular; in most cases in Guatemala it is ""Gún-Vícida."" One word from what I have seen at the border crossing is that of ivia-la-Cí"
3,3.0,"ichl. It is probably the only language which has not been used in rural areas,"" Heinz wrote.\n\nHinterborn said there is no evidence that his son has been killed by the"
4,4.0,仰渏 (昀時).\n\nWhat is it?\n and why should people use it to find their future\n (or their worst possible ending). When talking about Japanese language


The final dataframe shows the pass/fail status of each test case.

In [28]:
test_results

Unnamed: 0,id,p/f
0,0.0,fail
1,1.0,pass
2,2.0,fail
3,3.0,fail
4,4.0,pass


### Testing with Checklist
Now let's run the test with Checklist. All we need is a new expectation function. The rest of the process is the same as before.

In [29]:
def response_contains_correct_language(x, pred, conf, label=None, meta=None):
    # Look up the expected language for the country stored in the prompt's metadata
    language = correct_responses[meta['country']]
    return language in pred

In [30]:
correct_language_expect_fn = Expect.single(response_contains_correct_language)

In [31]:
test = MFT(**country_prompts, name='Correct language in response', description='The response contains the correct language for the country in the prompt.', expect=correct_language_expect_fn)

In [32]:
test.run(wrapped_generator, overwrite=True)

Predicting 5 examples


In [33]:
test.summary(format_example_fn = format_example)

Test cases:      5
Fails (rate):    4 (80.0%)

Example fails:
Prompt:      In France the most commonly spoken language is 
Completion:      рас днизыми, сирдкивик Русскай and Иойтнке
----
Prompt:      In United States the most commonly spoken language is 
Completion:      Ѥпроиит ди с оке улек. An тальные is an attempt to indicate that you
----
Prompt:      In Guatemala the most commonly spoken language is 
Completion:      ícoban, and its use is highly regulated. Its name is used on beaches and beaches of the U.S., where children are often involved. And it is also used for surfing, surfing
----


In [34]:
test.visual_summary()

TestSummarizer(stats={'npassed': 1, 'nfailed': 4, 'nfiltered': 0}, summarizer={'name': 'Correct language in re…