
# Tutorial: bAbI6 Training and Preprocessing in Python

In this tutorial, we implement question answering for bAbI6 tasks using the new {py:mod}`~lambeq.experimental.discocirc` module. In bAbI6 tasks, we are given a text that describes people moving between locations, and we are asked questions about where these people are as they move around. More on the bAbI tasks can be found in this [paper](https://arxiv.org/abs/1502.05698) and this [repository](https://github.com/facebookarchive/bAbI-tasks?tab=readme-ov-file) by Facebook.

In [None]:
from pathlib import Path
from typing import Tuple, List
from lambeq.experimental.discocirc import DisCoCircReader
import os
import warnings
import pickle

warnings.filterwarnings('ignore')
os.environ['TOKENIZERS_PARALLELISM'] = 'true'

Before we delve into the code, we first highlight two features of the new {term}`parser <parser>` that will be used in this tutorial: the {term}`sandwich functor <sandwich functor>` and foliated {term}`frames <frame>`.

In previous versions of the {term}`parser <parser>`, the semantic {term}`functor <functor>` provided quantum implementations for boxes, wires and states, but did not specify the quantum implementation of {term}`frames <frame>`. The {term}`sandwich functor <sandwich functor>` tackles this by breaking down a frame into a sequence of layers of boxes together with the frame's content {cite:p}`laakkonen_2024`. Now that a frame consists of several layers, we can decide how to assign operators to them. We can either give each layer its own operator, with different parameters for each layer (this is what the foliated frames option does), or use the same operator for all layers, so that all layers share the same parameters. For more detail on this, we recommend reading the paper explaining the theory behind the new parser {cite:p}`krawchuk_2025`.

## 1. Setting Up Configuration Variables

This cell defines paths and key configuration variables:
- `FILEPATH` specifies the path to the file containing the bAbI6 data.
- `TEXT_LENGTH` specifies the maximum number of sentences in a context for the experiment.
- `MAX_WIDTH` specifies the maximum number of wires in a circuit for the experiment.
- `SANDWICH` is a flag for using the {term}`sandwich functor <sandwich functor>`: True to apply the sandwich functor on the circuits, False to apply the usual semantic {term}`functor <functor>`.
- `FFL` is a flag for activating the foliated {term}`frames <frame>`: True to set different parameters for the different layers of a frame, False to share the same parameters across all layers. This flag only makes sense when the sandwich functor is activated.

In [21]:
# Here we store all the variables needed for the rest of the code: file paths, configurations, parameters...

# The path of the file where the initial bAbI6 data is stored
FILEPATH = '../examples/datasets/babi6_10k.txt'

# Maximum length of the context
TEXT_LENGTH = 4

# Maximum Number of wires in a circuit
MAX_WIDTH = 9

# SANDWICH functor flag
SANDWICH = True

# Foliated frames flag (only meaningful when SANDWICH is True)
FFL = False

# Paths of resulting files for the datasets
TRAINING_DATASET_FILEPATH = 'circuits/tutorial_training_ffl' + str(FFL) + '_sandwich_' + str(SANDWICH) + '.pkl'
VALIDATION_DATASET_FILEPATH = 'circuits/tutorial_validation_ffl' + str(FFL) + '_sandwich_' + str(SANDWICH) + '.pkl'
TEST_DATASET_FILEPATH = 'circuits/tutorial_test_ffl' + str(FFL) + '_sandwich_' + str(SANDWICH) + '.pkl'
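The three dataset paths above follow one naming pattern, and `FFL` is only meaningful when `SANDWICH` is set. A minimal sketch of both ideas is given below. The helper names `validate_flags` and `dataset_path` are hypothetical (they are not part of this tutorial or of lambeq), and the default directory `circuits` simply mirrors the strings used above.

```python
from pathlib import Path


def validate_flags(sandwich: bool, ffl: bool) -> None:
    """Reject the flag combination that the tutorial describes as meaningless."""
    if ffl and not sandwich:
        raise ValueError('FFL=True requires SANDWICH=True')


def dataset_path(split: str, ffl: bool, sandwich: bool,
                 directory: str = 'circuits') -> Path:
    """Build the pickle path for one split (hypothetical helper)."""
    return Path(directory) / f'tutorial_{split}_ffl{ffl}_sandwich_{sandwich}.pkl'


validate_flags(sandwich=True, ffl=False)
example_path = dataset_path('training', ffl=False, sandwich=True)
print(example_path.as_posix())
# circuits/tutorial_training_fflFalse_sandwich_True.pkl
```

Encoding both flags in the file name keeps the datasets produced by different configurations from overwriting each other.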

## 2. Data Preprocessing Function

The next step is to write a function `task_file_reader`, which processes the bAbI6 dataset and returns a list of texts, a list of questions on these texts, a list of answers to these questions, and a list of the lengths of the texts. This function reads and cleans lines from the `FILEPATH`, splits lines into stories, and extracts text sentences, questions, and answers.

After extracting the texts, they are filtered to only keep the ones whose number of sentences is less than or equal to `TEXT_LENGTH`, which we set in the previous cell to determine the maximum number of sentences that we want in a text for better efficiency. This is to make sure that we do not get huge circuits later on when we convert the texts into circuits, which might slow down the experiment. 

After this filtering, the last step is to convert the list of texts from a list of lists of sentences to a list of strings. In other words, we concatenate the sentences in each text (which is a list) to obtain a single string.

In [3]:
# Reading the texts, questions, expected answers, and text lengths from the TXT file
def task_file_reader(path: str | Path) -> Tuple[List[str],
                                                List[str],
                                                List[str],
                                                List[int]]:
    """
    Read the .txt file at `path`.

    Return four lists of equal length:
    texts, questions, answers and text lengths.
    """
    with open(path) as f:
        lines = f.readlines()

    # split the lines into stories
    # record the first line location of new stories
    story_splits = [i for i, line in enumerate(lines) if line[0:2] == '1 ']
    # have no more need for line indices - delete these
    lines = [' '.join(line.split(' ')[1:]) for line in lines]
    # also delete . and \n
    lines = [line.replace('.', '').replace('\n','') for line in lines]
    stories = [lines[i:j] for i, j in zip(story_splits, story_splits[1:]+[None])]

    # create text and QnA pairs
    texts = []
    qnas = []
    text_length = []
    for story in stories:
        # record the lines in the story corresponding to questions
        question_splits = [i for i, line in enumerate(story) if '?' in line]
        for index in question_splits:
            # record the text corresponding to each question
            ctx = [line.lower() for line in story[:index] if '?' not in line]
            texts.append(ctx)
            text_length.append(len(ctx))
            # record the question
            qnas.append(story[index])

    # split qna into questions and answers
    questions = [qna.split('\t')[0].lower().rstrip()[:-1] + " ?" for qna in qnas]
    answers = [qna.split('\t')[1].lower() for qna in qnas]

    # Keep only the entries whose text has at most TEXT_LENGTH sentences
    filtered_data = [
        (text, question, answer, length)
        for text, question, answer, length in zip(texts, questions, answers, text_length)
        if len(text) <= TEXT_LENGTH
    ]

    # Unzip the filtered tuples back into four parallel lists
    texts, questions, answers, text_length = map(list, zip(*filtered_data))

    # Converting the texts from lists of sentences to single strings
    processed_texts_list = [
        ''.join(sentence + '. ' for sentence in text) for text in texts
    ]

    return processed_texts_list, questions, answers, text_length


texts, questions, answers, text_lengths = task_file_reader(FILEPATH)
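
To make the parsing steps concrete, here is a minimal sketch that applies the same line-level transformations to a single question line. The sample line is made up for illustration and is written in the tab-separated layout that the function above assumes (line index, question, answer, supporting fact).

```python
# A made-up line in the layout assumed by task_file_reader:
# "<index> <question>\t<answer>\t<supporting fact id>"
raw_line = '3 Is Sandra in the hallway? \tno\t2\n'

# Drop the leading line index, then the full stops and the newline
line = ' '.join(raw_line.split(' ')[1:])
line = line.replace('.', '').replace('\n', '')

# Split into question and answer, normalising the question mark
question = line.split('\t')[0].lower().rstrip()[:-1] + ' ?'
answer = line.split('\t')[1].lower()

print(question)  # is sandra in the hallway ?
print(answer)    # no
```

The space inserted before the question mark means that a later `split(' ')` yields the question mark as a separate token.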

## 3. Converting The Texts into Circuits

Now that our texts and the rest of the data are pre-processed, we move on to the crucial step of converting them into circuits. We start by initializing the {py:class}`~lambeq.experimental.discocirc.DisCoCircReader`, then we call its {py:meth}`~lambeq.experimental.discocirc.DisCoCircReader.text2circuit` method. We pass `SANDWICH` as the `sandwich` argument, which indicates whether to use the {term}`sandwich functor <sandwich functor>`, and `FFL` as the `foliated_frame_labels` argument, which indicates whether to assign different parameters to the different layers of a frame or the same parameters to all of them.

Moreover, we store the data in a dictionary where each entry includes the text, the corresponding generated {term}`DisCoCirc` circuit, the question, the answer, and the text length.

In [None]:
# Making the circuits from the texts and storing them in the dictionary
reader = DisCoCircReader()
datadict = {}
for i, (text, quest, ans, text_length) in enumerate(
        zip(texts, questions, answers, text_lengths)):
    datadict[i] = {
        'text': text,
        'dsc_diag': reader.text2circuit(text,
                                        sandwich=SANDWICH,
                                        foliated_frame_labels=FFL),
        'question': quest,
        'answer': ans,
        'text_length': text_length,
    }

## 4. Converting The Circuits from DisCoCirc Circuits to Quantum Circuits
While we have the circuits corresponding to the texts ready, they are still {term}`DisCoCirc` circuits, not quantum circuits. Therefore, we need to convert the DisCoCirc circuits into {term}`quantum circuits <quantum circuit>` by applying an {term}`ansatz <ansatz (plural: ansätze)>`. In this case, we chose to apply the `Sim4Ansatz` with 3 layers and one {term}`qubit <qubit>` for each noun. More information on the motivation behind this choice can be found [here](https://arxiv.org/pdf/2409.08777).

In [5]:
from lambeq import Sim4Ansatz
from lambeq import AtomicType

N = AtomicType.NOUN
ansatz = Sim4Ansatz({N:1}, n_layers=3)

for i in datadict.keys():
    datadict[i].update({'text_circuit_sim4_13': ansatz(datadict[i]['dsc_diag'])})

## 5. Assertion Circuits and Further Processing of the Circuits
The main idea of this tutorial is to sequentially compose assertion circuits with the text circuits in order to see the similarity between the texts and the assertions. More details on assertions and question asking in general can be found [here](https://arxiv.org/pdf/2409.08777).

Now that we already have the circuits representing the texts, we need to make the circuits representing the assertions. Remember, in our experiment, we need to have a pair of circuits, one for the affirmative case, and the other for the negative case. However, when adding the box corresponding to the assertion, we have to make sure that the wires of the assertion box correspond to the nouns from the text.  

Below, the function `return_noun_list` returns all the nouns in a text circuit, and the function `return_q_nouns` returns the two nouns in a question. In the latter, we take the second and fifth words of the question as its subject and object respectively. This works because the bAbI6 questions are simple: they all have the format "Is the subject in the location?".

In [6]:
from lambeq.backend.grammar import Ty

def return_noun_list(text):
    noun_list = []
    for b in text.boxes:
        if b.dom == Ty() and b.cod == N:
            noun_list.append(b.name)
            
    return noun_list

In [7]:
def return_q_nouns(question):
    question_words = question.split(' ')
    q_nouns = [question_words[1], question_words[4].strip('?')]
    return q_nouns
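
As a quick check of the indexing, the following sketch repeats the function and applies it to a question in the normalised form produced by `task_file_reader` (lowercase, with a space before the question mark). The example question is made up for illustration.

```python
def return_q_nouns(question):
    question_words = question.split(' ')
    return [question_words[1], question_words[4].strip('?')]


# Words: ['is', 'sandra', 'in', 'the', 'hallway', '?']
q_nouns = return_q_nouns('is sandra in the hallway ?')
print(q_nouns)  # ['sandra', 'hallway']
```

Note that the fixed positions 1 and 4 only work for questions of exactly this shape. A subject or location made of several words would need a different extraction rule.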

We proceed to add the lists obtained from the functions above to the dictionary to be used later when building the assertion boxes.

In [8]:
for i in datadict.keys():
    datadict[i].update({'noun_list_text': return_noun_list(datadict[i]['dsc_diag'])})
    datadict[i].update({'noun_list_question': return_q_nouns(datadict[i]['question'])})

We extracted the nouns in the texts and in their corresponding questions so that we can remove the entries whose question mentions a subject or location that is not present in the text. We also filter by text length, keeping only the entries whose text has fewer than `TEXT_LENGTH` sentences.

In [10]:
reduced_datadict = {
    i: entry
    for i, entry in datadict.items()
    if set(entry['noun_list_question']).issubset(entry['noun_list_text'])
    and entry['text_length'] < TEXT_LENGTH
}
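
The effect of the two conditions can be seen on a toy dictionary. This is a minimal sketch with made-up entries that only contain the fields the filter reads. `TEXT_LENGTH` is repeated here so that the block runs on its own.

```python
TEXT_LENGTH = 4  # same value as in the configuration cell

# Toy entries with made-up nouns
toy_datadict = {
    0: {'noun_list_text': ['sandra', 'hallway'],
        'noun_list_question': ['sandra', 'hallway'], 'text_length': 2},
    1: {'noun_list_text': ['mary', 'kitchen'],
        'noun_list_question': ['mary', 'garden'], 'text_length': 2},
    2: {'noun_list_text': ['john', 'office'],
        'noun_list_question': ['john', 'office'], 'text_length': 4},
}

kept = {
    i: entry
    for i, entry in toy_datadict.items()
    if set(entry['noun_list_question']).issubset(entry['noun_list_text'])
    and entry['text_length'] < TEXT_LENGTH
}
print(sorted(kept))  # [0]
```

Entry 1 is dropped because "garden" never appears in its text, and entry 2 is dropped because its text length is not strictly below `TEXT_LENGTH`.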

Moreover, remember that, to enhance performance, we also wanted to limit the number of wires in every circuit. We do this by checking that the length of every circuit's codomain (that is, its number of open output wires) is less than or equal to `MAX_WIDTH`. The following cell keeps only the entries of `reduced_datadict` whose text circuit has at most `MAX_WIDTH` wires.

In [11]:
# Reducing the size of the dictionary by removing the circuits that have more than a certain number of wires.
def right_cod_size(circuit):
    return len(circuit.cod) <= MAX_WIDTH

# Reassign `reduced_datadict` so that the width filter takes effect in the cells below
reduced_datadict = {
    i: entry
    for i, entry in reduced_datadict.items()
    if right_cod_size(entry['text_circuit_sim4_13'])
}

Now that the text circuits have been filtered, it is time to build the assertion circuits, which will later be sequentially composed with the text circuits.

We start by constructing two boxes, `q1` and `q2`, for the affirmative and negative assertions respectively. An affirmative assertion corresponds to an affirmative answer to the question, while a negative assertion corresponds to a negative answer. For example, if the question related to a text is "Is Emily in the kitchen?", the affirmative assertion would be "Emily is in the kitchen" and the negative assertion would be "Emily is not in the kitchen". Since all the questions have the format "Is subject in location?", we need exactly two boxes for the assertions: one for the "is in" assertions and another for the "is not in" assertions. The purpose of having two generic boxes is that the ML model will later learn the parameters for these boxes. For more details about question asking in {term}`DisCoCirc`, we recommend [this paper](https://arxiv.org/pdf/2409.08777).

Notice that we also create two assertion circuits that are equipped with swaps, the purpose of which will become clearer in later parts of the tutorial.

We apply the same {term}`ansatz <ansatz (plural: ansätze)>` that we applied to the text circuits (`Sim4Ansatz` with 3 layers and one qubit for each wire). Lastly, we add the postselections by sequentially composing the circuit produced by the {term}`ansatz <ansatz (plural: ansätze)>` with a parallel composition of two effects (bras).

In [12]:
from lambeq.backend.grammar import Box
# `Id` is taken from the quantum backend, since it is used on qubit wires below
from lambeq.backend.quantum import Bra, Discard, Id, Swap, qubit
from lambeq import AtomicType, Sim4Ansatz

N = AtomicType.NOUN

q1 = Box('is_in', N@N, N@N)
q2 = Box('is_not_in', N@N, N@N)

ansatz = Sim4Ansatz({N:1}, n_layers=3)
qcirc1 = ansatz(q1)
qcirc2 = ansatz(q2)

# Add the postselections to the assertion circuits
qcirc1_final = qcirc1 >> Bra(0) @ Bra(0)
qcirc2_final = qcirc2 >> Bra(0) @ Bra(0)

is_in_q = qcirc1_final
is_not_in_q = qcirc2_final

# Variants that first swap the two input wires
is_in_q_swp = Swap(qubit, qubit) >> qcirc1 >> Bra(0) @ Bra(0)
is_not_in_q_swp = Swap(qubit, qubit) >> qcirc2 >> Bra(0) @ Bra(0)
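
The swapped variants `is_in_q_swp` and `is_not_in_q_swp` are needed when the subject of the question sits to the right of the location among the output wires of the text circuit. Below is a minimal pure-Python sketch of that decision. It assumes one output wire per text noun, in the order in which the nouns are listed. The helper name `plan_question_wires` and the example nouns are hypothetical.

```python
def plan_question_wires(text_nouns, q_nouns):
    """Return (layer, swap_required) for a question (hypothetical helper).

    `layer` has one entry per output wire: 'keep' for the two question
    nouns and 'discard' for every other wire.
    """
    subject_idx = text_nouns.index(q_nouns[0])
    location_idx = text_nouns.index(q_nouns[1])
    layer = ['keep' if k in (subject_idx, location_idx) else 'discard'
             for k in range(len(text_nouns))]
    # A swap is needed when the subject wire sits to the right of the location wire
    return layer, subject_idx > location_idx


layer, swap_required = plan_question_wires(
    ['kitchen', 'emily', 'john'], ['emily', 'kitchen'])
print(layer)          # ['keep', 'keep', 'discard']
print(swap_required)  # True
```

In the quantum circuits, 'keep' corresponds to an identity wire and 'discard' to a `Discard()` on that qubit.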

## 6. Assembling The Text Circuits with the Question Circuits

Now that we have all the ingredients in place (the text and assertion circuits), it is time to assemble them using sequential composition. However, we need to be careful and only attach the wires of the question boxes to the corresponding wires of the text circuits so that the nouns match. Moreover, we have to discard the wires of the nouns that are not included in the question. In order to do this, we might need to perform some swaps so that the wires that get composed with the question circuit are the corresponding wires from the text circuit.

We start by creating a layer made of identities (on the wires corresponding to the question nouns) and discards (on all the other wires). Sequentially composing this layer with the text circuit leaves us with a circuit whose codomain has two wires, one for each question noun. To attach the assertion boxes correctly, we look up the positions of the question nouns among the nouns of the text circuit and check whether they appear in the right order. If the subject wire comes after the location wire, we use the assertion circuits equipped with swaps, which bring the wires back to the order expected by the assertion. Otherwise, we use the ones without swaps.

Notice that, throughout the next cell, we always have two circuits. The circuit names ending in "pos" denote the circuits corresponding to the affirmative assertions, while their counterparts ending in "neg" denote the ones corresponding to the negative assertions.

**Important note**: This is not the standard way to present the questions/assertions. In this approach, we opted for using post-selections directly, for the sake of simplicity in this tutorial. More complex approaches to assertions on text can be found [here](https://arxiv.org/pdf/2409.08777).

In [13]:
# Bringing everything together by plugging the question asking part to the text circuits
# Iterate over a copy of the keys, since entries may be deleted inside the loop
for i in list(reduced_datadict.keys()):

    entry = reduced_datadict[i]
    text_circuit = entry['text_circuit_sim4_13']
    text_nouns = entry['noun_list_text']
    q_nouns = entry['noun_list_question']
    qid1 = text_nouns.index(q_nouns[0])
    qid2 = text_nouns.index(q_nouns[1])

    if qid1 == qid2:
        print('noun ids are identical, removing entry')
        del reduced_datadict[i]
        continue

    # A swap is needed when the subject wire comes after the location wire
    swap_required = qid1 > qid2

    # Keep the two question wires and discard all the others
    quest_mid_layer = Id(qubit) if (qid1 == 0 or qid2 == 0) else Discard()

    for k in range(1, len(text_circuit.cod)):
        if k == qid1 or k == qid2:
            quest_mid_layer = quest_mid_layer @ Id(qubit)
        else:
            quest_mid_layer = quest_mid_layer @ Discard()

    final_circuit = text_circuit >> quest_mid_layer

    if swap_required:
        final_circuit_pos = final_circuit >> is_in_q_swp
        final_circuit_neg = final_circuit >> is_not_in_q_swp
    else:
        final_circuit_pos = final_circuit >> is_in_q
        final_circuit_neg = final_circuit >> is_not_in_q

    entry.update({'pos_neg_circuit_pair': (final_circuit_pos, final_circuit_neg)})

## 7. Preparing The Datasets for Training
Now that our circuit pairs are ready, we move on to the final step: preparing the data for training a model.

We start by converting the "yes" and "no" answers to 1s and 0s respectively.

In [14]:
bAbI6_datadict = {}
for i in reduced_datadict.keys():
    # Add the updated dictionary with the transformed 'answer'
    bAbI6_datadict.update({
        i: {
            'text': reduced_datadict[i]['text'],
            'question': reduced_datadict[i]['question'],
            'answer': 1 if reduced_datadict[i]['answer'] == 'yes' else 0,  # Transform 'yes' to 1 and 'no' to 0
            'quantum_circ_pair_pos_neg': reduced_datadict[i]['pos_neg_circuit_pair'],
            'text_length': reduced_datadict[i]['text_length']
        }
    })

The next step is to make three sets: training, validation, and test sets. Before splitting, we balance the entries so that, for each text length, there are as many "yes" answers as "no" answers.

In [15]:
import random
from collections import defaultdict

# Add the 'measure' field to each item
for key, value in bAbI6_datadict.items():
    temp = -1 if value['answer'] == 0 else 1
    value['measure'] = temp * value['text_length']

# Group items by absolute value of measure
abs_value_groups = defaultdict(list)
for key, value in bAbI6_datadict.items():
    abs_value = abs(value['measure'])
    abs_value_groups[abs_value].append((key, value))

# Balance signs within each group and ensure diverse sizes
new_balanced_dict = {}
for abs_value, items in abs_value_groups.items():
    # Separate positive and negative items
    positive_items = [(k, v) for k, v in items if v['measure'] > 0]
    negative_items = [(k, v) for k, v in items if v['measure'] < 0]
    
    # Determine the maximum balanced size for this group
    min_size = min(len(positive_items), len(negative_items))
    print("the minimum size is: " + str(min_size))
    
    # Randomly sample from each group to balance
    balanced_positive = random.sample(positive_items, min_size)
    balanced_negative = random.sample(negative_items, min_size)
    
    # Add to the balanced dictionary
    for k, v in balanced_positive + balanced_negative:
        new_balanced_dict[k] = v

the minimum size is: 89
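
To see what the balancing does on a small scale, here is a minimal sketch on made-up entries. It is a simplified variant of the cell above: it groups directly by text length and answer instead of going through the signed `measure` field, and it uses a seeded `random.Random` instance (an addition, not part of the original cell) so that the result is reproducible.

```python
import random
from collections import defaultdict

# Made-up entries: (answer, text_length); 1 stands for 'yes', 0 for 'no'
toy = {k: {'answer': a, 'text_length': n}
       for k, (a, n) in enumerate([(1, 2), (1, 2), (1, 2), (0, 2),
                                   (1, 3), (0, 3), (0, 3)])}

# Group the keys by text length and answer
groups = defaultdict(lambda: {0: [], 1: []})
for key, value in toy.items():
    groups[value['text_length']][value['answer']].append(key)

rng = random.Random(0)  # fixed seed so the sketch is reproducible
balanced = {}
for length, by_answer in groups.items():
    # Largest size for which both answers can be equally represented
    size = min(len(by_answer[0]), len(by_answer[1]))
    for key in rng.sample(by_answer[0], size) + rng.sample(by_answer[1], size):
        balanced[key] = toy[key]

n_yes = sum(v['answer'] for v in balanced.values())
print(len(balanced), n_yes)  # 4 2
```

Each text length contributes the same number of "yes" and "no" entries, so the classes stay balanced within every length as well as overall.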


In [16]:
from sklearn.model_selection import train_test_split

# Convert dictionary into list of keys and values
keys = list(new_balanced_dict.keys())
values = list(new_balanced_dict.values())

# Split into training and temporary (validation + testing)
train_keys, temp_keys, train_values, temp_values = train_test_split(
    keys, values, test_size=0.4, random_state=42  # 60% training, 40% temp
)

# Split the temporary set into validation and testing
val_keys, test_keys, val_values, test_values = train_test_split(
    temp_keys, temp_values, test_size=0.5, random_state=42  # 20% validation, 20% testing
)

# Reconstruct dictionaries for training, validation, and testing
training_dict_bAbI6 = {k: v for k, v in zip(train_keys, train_values)}
validation_dict_bAbI6 = {k: v for k, v in zip(val_keys, val_values)}
test_dict_bAbI6 = {k: v for k, v in zip(test_keys, test_values)}

The following cell checks that the training set is balanced.

In [17]:
yes_count = 0
no_count = 0
for i in training_dict_bAbI6:
    if training_dict_bAbI6[i]['answer'] == 0:
        no_count += 1
    else:
        yes_count += 1

print(yes_count)
print(no_count)

53
53


Now, the final step is to store all of this data in separate files for training, validation, and testing, to be used in part II of this tutorial.

In [19]:
# Redefine the output paths so that the files are written to the current directory
TRAINING_DATASET_FILEPATH = 'tutorial_training_ffl' + str(FFL) + '_sandwich_' + str(SANDWICH) + '.pkl'
VALIDATION_DATASET_FILEPATH = 'tutorial_validation_ffl' + str(FFL) + '_sandwich_' + str(SANDWICH) + '.pkl'
TEST_DATASET_FILEPATH = 'tutorial_test_ffl' + str(FFL) + '_sandwich_' + str(SANDWICH) + '.pkl'

In [20]:
with open(TRAINING_DATASET_FILEPATH, 'wb') as file:
    pickle.dump(training_dict_bAbI6, file)
with open(VALIDATION_DATASET_FILEPATH, 'wb') as file:
    pickle.dump(validation_dict_bAbI6, file)
with open(TEST_DATASET_FILEPATH, 'wb') as file:
    pickle.dump(test_dict_bAbI6, file) 
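
For completeness, here is a minimal sketch of how such a file can be read back with `pickle.load`. It uses a made-up dictionary and a temporary directory rather than the real dataset files. For the real files, which contain circuit objects, we would expect lambeq to be importable in the environment that loads them.

```python
import pickle
import tempfile
from pathlib import Path

# Made-up stand-in for one of the dataset dictionaries
toy_dict = {0: {'text': 'sandra went to the hallway. ', 'answer': 1}}

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / 'toy_training.pkl'
    # Write the dictionary, then read it back
    with open(path, 'wb') as file:
        pickle.dump(toy_dict, file)
    with open(path, 'rb') as file:
        loaded = pickle.load(file)

print(loaded == toy_dict)  # True
```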