
# Tutorial: bAbI6 Training and Preprocessing in Python

In this tutorial, we will implement question answering for the bAbI6 task using the new experimental {py:mod}`~lambeq.experimental.discocirc` module. In bAbI6, a text describes people moving between locations, and the goal is to answer yes/no questions about where they are. More on the bAbI tasks can be found in this [paper](https://arxiv.org/abs/1502.05698) and this [repository](https://github.com/facebookarchive/bAbI-tasks?tab=readme-ov-file) by Facebook (now Meta).

In [None]:
from pathlib import Path
from typing import Tuple, List
from lambeq.experimental.discocirc import DisCoCircReader
import os
import warnings
import pickle

warnings.filterwarnings('ignore')
os.environ['TOKENIZERS_PARALLELISM'] = 'true'

Before we delve into the code, we highlight two features of the new {term}`parser <parser>` that this tutorial relies on: the {term}`sandwich functor <sandwich functor>` and foliated {term}`frames <frame>`.

In previous versions of the {term}`parser <parser>`, the semantic {term}`functor <functor>` provided quantum implementations for boxes, wires and states, but did not specify how {term}`frames <frame>` should be implemented as quantum circuits. The {term}`sandwich functor <sandwich functor>` addresses this by breaking a frame down into a sequence of boxes interleaved with the frame's contents {cite:p}`laakkonen_2024`. We can then decide how to assign operators to these boxes: either each box gets its own operator with its own parameters (foliated frames), or all boxes share the same operator and therefore the same parameters. For more details, we recommend {cite:p}`krawchuk_2025`.
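To make the difference between the two options concrete, here is a toy sketch in plain Python, independent of lambeq. It uses a hypothetical labelling scheme for the boxes a frame is broken into (the helper `frame_layer_labels`, the frame label `F` and the `_k` suffixes are illustrative only). It relies on the fact that lambeq ansätze name a box's parameter symbols after the box's label and type, so identically labelled boxes end up sharing parameters.

```python
from typing import List


def frame_layer_labels(frame_label: str, n_layers: int, foliated: bool) -> List[str]:
    """Toy labels for the boxes a frame is broken into (hypothetical scheme)."""
    if foliated:
        # one distinct label per layer -> one parameter set per layer
        return [f'{frame_label}_{k}' for k in range(n_layers)]
    # the same label for every layer -> all layers share one parameter set
    return [frame_label] * n_layers


for foliated in (False, True):
    labels = frame_layer_labels('F', n_layers=3, foliated=foliated)
    print(foliated, labels, 'distinct parameter sets:', len(set(labels)))
# False ['F', 'F', 'F'] distinct parameter sets: 1
# True ['F_0', 'F_1', 'F_2'] distinct parameter sets: 3
```

With shared labels, a frame costs one set of parameters no matter how many layers it has; with foliated labels, the parameter count grows with the number of layers.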

## 1. Setting Up Configuration Variables

This cell defines paths and key configuration variables:
- `FILEPATH` specifies the path to the file containing the bAbI6 data.
- `TEXT_LENGTH` specifies the maximum number of sentences in a text; longer texts are discarded.
- `MAX_WIDTH` specifies the maximum number of wires in a circuit; wider circuits are discarded.
- `SANDWICH` is a flag for using the {term}`sandwich functor <sandwich functor>`: `True` applies the sandwich functor to the circuits, `False` applies the usual semantic {term}`functor <functor>`.
- `FFL` is a flag for activating foliated {term}`frames <frame>`: `True` gives the different layers (boxes) of a frame different parameters, `False` gives all the layers the same parameters. This flag only makes sense when the sandwich functor is activated.

We also define file paths to store the prepared training, validation, and test datasets. 

In [None]:
# Here we store all the variables needed for the rest of the code: file paths, configurations, parameters...

# The path of the file where the initial bAbI6 data is stored
FILEPATH = '../examples/datasets/babi6_10k.txt'

# Maximum number of sentences in a text
TEXT_LENGTH = 4

# Maximum number of wires in a circuit
MAX_WIDTH = 9

# Sandwich functor flag
SANDWICH = True

# Foliated frame labels flag (only meaningful when SANDWICH is True)
FFL = False

# Paths of resulting files for the datasets for training the model later
TRAINING_DATASET_FILEPATH = 'tutorial_training_ffl' + str(FFL) + '_sandwich_' + str(SANDWICH) + '.pkl'
VALIDATION_DATASET_FILEPATH = 'tutorial_validation_ffl' + str(FFL) + '_sandwich_' + str(SANDWICH) + '.pkl'
TEST_DATASET_FILEPATH = 'tutorial_test_ffl' + str(FFL) + '_sandwich_' + str(SANDWICH) + '.pkl'
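The three file names differ only in the split name, so they could equally be produced by a single helper. The sketch below, `dataset_path`, is hypothetical and not part of the original tutorial: it reproduces the names built above and rejects `FFL=True` with `SANDWICH=False`, the combination that, as noted in the list above, does not make sense. It raises an error rather than warning because the first cell silences warnings with `warnings.filterwarnings('ignore')`.

```python
def dataset_path(split: str, ffl: bool, sandwich: bool) -> str:
    """Build the pickle file name of one split ('training', 'validation' or 'test')."""
    # foliated frame labels are only meaningful together with the sandwich functor
    if ffl and not sandwich:
        raise ValueError('FFL=True requires SANDWICH=True')
    return f'tutorial_{split}_ffl{ffl}_sandwich_{sandwich}.pkl'


print(dataset_path('training', ffl=False, sandwich=True))
# tutorial_training_fflFalse_sandwich_True.pkl
```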

## 2. Data Preprocessing Function

The next step is to write a function, `task_file_reader`, that processes the bAbI6 dataset and returns four lists: the texts (each as a single string), the questions, the answers (encoded as 1 for "yes" and 0 for "no") and the text lengths.

The function only keeps the texts with at most `TEXT_LENGTH` sentences, the limit set in the previous cell. This keeps the circuits generated from the texts later on small, since very large circuits would slow down the experiment.

In [3]:
# Reading the texts, questions, expected answers, and text lengths from the TXT file
def task_file_reader(path: str | Path) -> Tuple[List[str],
                                                List[str],
                                                List[int],
                                                List[int]]:
    """
    Read the bAbI6 .txt file at `path`.

    Return 4 lists of equal length: the texts (each a single string),
    the questions, the answers (1 for 'yes', 0 for 'no') and the text
    lengths (number of sentences). Only texts with at most TEXT_LENGTH
    sentences are kept.
    """
    with open(path) as f:
        lines = f.readlines()

    # split the lines into stories:
    # record the line index at which each new story starts (line id 1)
    story_splits = [i for i, line in enumerate(lines) if line[0:2] == '1 ']
    # the line ids are no longer needed, so delete them
    lines = [' '.join(line.split(' ')[1:]) for line in lines]
    # also delete '.' and '\n'
    lines = [line.replace('.', '').replace('\n', '') for line in lines]
    stories = [lines[i:j] for i, j in zip(story_splits, story_splits[1:] + [None])]

    # create text and QnA pairs
    texts = []
    qnas = []
    text_length = []
    for story in stories:
        # record the lines in the story corresponding to questions
        question_splits = [i for i, line in enumerate(story) if '?' in line]
        for index in question_splits:
            # the text of a question is every non-question line before it
            text = [line.lower() for line in story[:index] if '?' not in line]
            texts.append(text)
            text_length.append(len(text))
            # record the question
            qnas.append(story[index])

    # split each (tab-separated) QnA line into question and answer
    questions = [qna.split('\t')[0].lower().rstrip()[:-1] + " ?" for qna in qnas]
    answers = [qna.split('\t')[1].lower() for qna in qnas]

    # convert answers to 1s and 0s to be ready for training later on
    answers = [1 if ans == 'yes' else 0 for ans in answers]

    # keep only the texts with at most TEXT_LENGTH sentences
    filtered_data = [
        (text, question, answer, length)
        for text, question, answer, length in zip(texts, questions, answers, text_length)
        if length <= TEXT_LENGTH
    ]
    texts, questions, answers, text_length = map(list, zip(*filtered_data))

    # convert each text from a list of sentences into a single string
    processed_texts_list = [''.join(sentence + '. ' for sentence in text) for text in texts]

    return processed_texts_list, questions, answers, text_length


texts, questions, answers, text_lengths = task_file_reader(FILEPATH)
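To see what these parsing steps do, here is a self-contained sketch that reimplements them on an in-memory list of lines instead of a file. The toy lines are invented for illustration but follow the bAbI format: every line starts with an id that restarts at 1 for each new story, and a question line holds the question, the answer and the ids of the supporting facts, separated by tabs. The function `parse_babi_lines` and the `toy_*` variables are hypothetical names, chosen so as not to overwrite the notebook's `texts`, `questions` and `answers`. Unlike `task_file_reader`, the sketch applies the length filter while parsing rather than afterwards, which gives the same result.

```python
from typing import List, Tuple


def parse_babi_lines(lines: List[str], max_sentences: int
                     ) -> Tuple[List[str], List[str], List[int], List[int]]:
    """Toy version of `task_file_reader` that works on a list of lines."""
    # a new story starts wherever the line id is 1
    starts = [i for i, line in enumerate(lines) if line.startswith('1 ')]
    # drop the line ids, full stops and newlines
    lines = [' '.join(line.split(' ')[1:]).replace('.', '').replace('\n', '')
             for line in lines]
    stories = [lines[i:j] for i, j in zip(starts, starts[1:] + [None])]

    texts, questions, answers, lengths = [], [], [], []
    for story in stories:
        for idx, line in enumerate(story):
            if '?' not in line:
                continue
            # the text of a question is every statement that precedes it
            text = [s.lower() for s in story[:idx] if '?' not in s]
            if len(text) > max_sentences:
                continue
            question, answer = line.split('\t')[:2]
            texts.append(''.join(s + '. ' for s in text))
            questions.append(question.lower().rstrip()[:-1] + ' ?')
            answers.append(1 if answer.lower() == 'yes' else 0)
            lengths.append(len(text))
    return texts, questions, answers, lengths


toy_lines = [
    '1 Mary moved to the bathroom.\n',
    '2 Sandra journeyed to the bedroom.\n',
    '3 Is Sandra in the hallway? \tno\t2\n',
    '4 Mary went back to the kitchen.\n',
    '5 Is Mary in the kitchen? \tyes\t4\n',
]
toy_texts, toy_questions, toy_answers, toy_lengths = parse_babi_lines(toy_lines, 2)
print(toy_questions)  # ['is sandra in the hallway ?']
```

With a limit of 2 sentences, only the first question survives: the second one comes after three statements.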

## 3. Converting The Texts into Circuits

Now that our texts and the rest of the data are ready and pre-processed, we move on to the crucial step of converting the texts into circuits. We first initialize the {py:class}`~lambeq.experimental.discocirc.DisCoCircReader`, then call its {py:meth}`~lambeq.experimental.discocirc.DisCoCircReader.text2circuit` method. Its `sandwich` argument indicates whether to use the {term}`sandwich functor <sandwich functor>`, and its `foliated_frame_labels` argument indicates whether the different layers of a frame get different parameters or share the same ones.

We store the results in a dictionary in which each entry holds the text, the corresponding {term}`DisCoCirc` circuit, the question, the answer and the text length.

In [None]:
# Making the circuits from the texts and storing them in the dictionary
reader = DisCoCircReader()
datadict = {}
for i, (text, quest, ans, text_length) in enumerate(zip(texts, questions, answers, text_lengths)):
    datadict[i] = {
        'text': text,
        'dsc_diag': reader.text2circuit(text, sandwich=SANDWICH, foliated_frame_labels=FFL),
        'question': quest,
        'answer': ans,
        'text_length': text_length,
    }

## 4. Converting The Circuits from DisCoCirc Circuits to Quantum Circuits
While the circuits corresponding to the texts are ready, they are still {term}`DisCoCirc` circuits, not quantum circuits. We therefore convert them into {term}`quantum circuits <quantum circuit>` by applying an {term}`ansatz <ansatz (plural: ansätze)>`. Here we choose the Sim4Ansatz with 3 layers and one {term}`qubit <qubit>` per noun wire. This choice of ansatz has shown good experimental results; more information on the motivation behind it can be found [here](https://arxiv.org/pdf/2409.08777).

In [6]:
from lambeq import Sim4Ansatz
from lambeq import AtomicType

N = AtomicType.NOUN
ansatz = Sim4Ansatz({N:1}, n_layers=3)

for i in datadict.keys():
    datadict[i].update({'text_circuit_sim4_13': ansatz(datadict[i]['dsc_diag'])})

## 5. Assertion Circuits and Further Processing of the Circuits
The main idea of this tutorial is to compose assertion circuits sequentially with the text circuits, in order to measure the similarity between each text and an assertion about it. More details on assertions, and on implementing questions in general, can be found [here](https://arxiv.org/pdf/2409.08777).

We already have the circuits representing the texts; next, we need circuits representing the assertions. For each question, we need a pair of circuits: one for the affirmative case and one for the negative case. When adding the box corresponding to the assertion, we have to make sure that its wires match the wires representing the corresponding nouns in the text circuit.

Below, the function `return_noun_list` returns all the nouns in a text circuit, and the function `return_q_nouns` returns the two nouns in a question, taking the second and fifth words as the person and the location respectively. This works because every bAbI6 question has the format `Is <person> in the <location>?`, which the preprocessing above turns into strings such as `is sandra in the hallway ?`.

**Important note**: this is not the standard way to implement questions/assertions; we chose the simplest approach for the sake of this tutorial. More complex approaches to assertions on text can be found [here](https://arxiv.org/pdf/2409.08777).

In [7]:
from lambeq.backend.grammar import Ty

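def return_noun_list(text):
    """Return the names of the noun states in a DisCoCirc text circuit."""
    noun_list = []
    for b in text.boxes:
        # noun states are the boxes with an empty domain and a noun codomain
        if b.dom == Ty() and b.cod == N:
            noun_list.append(b.name)

    return noun_list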

In [8]:
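def return_q_nouns(question):
    """Return [person, location] from a question such as 'is sandra in the hallway ?'."""
    question_words = question.split(' ')
    # the person is the 2nd word and the location the 5th
    q_nouns = [question_words[1], question_words[4].strip('?')]
    return q_nouns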
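The fixed positions used by `return_q_nouns` come from splitting the preprocessed question on spaces. The sketch below shows that split and adds a hypothetical stricter variant, `return_q_nouns_strict` (not part of the original tutorial), which checks the expected format with a regular expression and raises an error instead of silently returning the wrong words when a question does not match.

```python
import re
from typing import List

# Expected shape of a preprocessed bAbI6 question, e.g. 'is sandra in the hallway ?'
QUESTION_PATTERN = re.compile(r'is (\w+) in the (\w+) \?')


def return_q_nouns_strict(question: str) -> List[str]:
    """Hypothetical stricter variant of `return_q_nouns`."""
    match = QUESTION_PATTERN.fullmatch(question)
    if match is None:
        raise ValueError(f'unexpected question format: {question!r}')
    person, location = match.groups()
    return [person, location]


print('is sandra in the hallway ?'.split(' '))
# ['is', 'sandra', 'in', 'the', 'hallway', '?']
print(return_q_nouns_strict('is sandra in the hallway ?'))
# ['sandra', 'hallway']
```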

We proceed to add the lists obtained from the functions above to the dictionary to be used later when building the assertion boxes.

In [9]:
for i in datadict.keys():
    datadict[i].update({'noun_list_text': return_noun_list(datadict[i]['dsc_diag'])})
    datadict[i].update({'noun_list_question': return_q_nouns(datadict[i]['question'])})

We extract both lists so that we can remove the entries whose question mentions a person or a location that does not appear in the text.

In [10]:
reduced_datadict = {
    i: entry 
    for i, entry in datadict.items() 
    if set(entry['noun_list_question']).issubset(entry['noun_list_text'])
}

To keep the experiment efficient, we also limit the width of the circuits: we only keep the entries whose text circuit has at most `MAX_WIDTH` open wires, that is, whose codomain has length at most `MAX_WIDTH`.

In [11]:
# Reducing the size of the dictionary by removing the circuits that have more than MAX_WIDTH wires
def right_cod_size(circuit):
    return len(circuit.cod) <= MAX_WIDTH

# Rebind `reduced_datadict` so that the following cells use the filtered entries
reduced_datadict = {
    i: entry
    for i, entry in reduced_datadict.items()
    if right_cod_size(entry['text_circuit_sim4_13'])
}

Now that the text circuits have been filtered, we move on to building the assertion circuits, which will later be composed sequentially with the text circuits.

We first construct two boxes, `q1` and `q2`, for the affirmative and negative assertions respectively. An affirmative assertion corresponds to a "yes" answer to the question, and a negative assertion to a "no" answer. For example, for the question "Is Emily in the kitchen?", the affirmative assertion is "Emily is in the kitchen" and the negative assertion is "Emily is not in the kitchen". Since every bAbI6 question has the format `Is <person> in the <location>?`, two boxes suffice: one for the "is in" assertions and one for the "is not in" assertions. Both boxes are generic, so their parameters are learned later by the model during training. For more details on the choice of implementation of assertions in {term}`DisCoCirc`, we recommend [this paper](https://arxiv.org/pdf/2409.08777).

We also build swapped versions of both assertion circuits, for the case where the location wire comes before the person wire in the text circuit.

We apply the same {term}`ansatz <ansatz (plural: ansätze)>` as for the text circuits (Sim4Ansatz with 3 layers and one qubit per noun wire).

In [12]:
from lambeq.backend.grammar import Box
from lambeq.backend.quantum import Bra, Discard, qubit, Id, Swap
from lambeq import AtomicType, Sim4Ansatz

N = AtomicType.NOUN

# Generic assertion boxes acting on the (person, location) wires
q1 = Box('is_in', N @ N, N @ N)
q2 = Box('is_not_in', N @ N, N @ N)

# Same ansatz as for the text circuits: one qubit per noun wire, 3 layers
ansatz = Sim4Ansatz({N: 1}, n_layers=3)
qcirc1 = ansatz(q1)
qcirc2 = ansatz(q2)

# Post-select both output qubits on |0>
is_in_q = qcirc1 >> Bra(0) @ Bra(0)
is_not_in_q = qcirc2 >> Bra(0) @ Bra(0)

# Swapped versions, for when the location wire comes before the person wire
is_in_q_swp = Swap(qubit, qubit) >> qcirc1 >> Bra(0) @ Bra(0)
is_not_in_q_swp = Swap(qubit, qubit) >> qcirc2 >> Bra(0) @ Bra(0)
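Each assertion circuit ends with `Bra(0) @ Bra(0)`, which post-selects both of its output qubits on |0⟩. As a rough, library-free illustration of what the composed circuit computes: if, after the unused wires are discarded, the person and location qubits are described by a 2-qubit density matrix ρ, and the assertion's ansatz circuit acts as a unitary U, then the circuit evaluates to the number ⟨00|UρU†|00⟩, which lies between 0 and 1. The function `assertion_value` and the toy matrices are hypothetical inputs chosen for illustration, not values extracted from the lambeq circuits.

```python
def assertion_value(U, rho):
    """Return <00| U rho U^dagger |00> for a 4x4 matrix U and a 4x4 density matrix rho.

    Matrices are nested lists in the basis order |00>, |01>, |10>, |11>,
    so <00| U selects row 0 of U.
    """
    value = sum(U[0][j] * rho[j][k] * complex(U[0][k]).conjugate()
                for j in range(4) for k in range(4))
    return complex(value).real


identity = [[1 if i == j else 0 for j in range(4)] for i in range(4)]
maximally_mixed = [[0.25 if i == j else 0 for j in range(4)] for i in range(4)]
print(assertion_value(identity, maximally_mixed))  # 0.25
```

A state already in |00⟩ gives 1 for the identity, while for a maximally mixed state every unitary gives 0.25; the trainable parameters of `is_in` and `is_not_in` shape U, and hence these values.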

## 6. Assembling The Text Circuits with the Question Circuits

Now that we have all the ingredients in place (the text and assertion circuits), it is time to assemble them using sequential composition. Each assertion wire must be connected only to its matching text wire, so that the nouns align; this is where the swapped assertion circuits are needed. Moreover, we discard the wires of the nouns that do not appear in the question.

Notice that, throughout the next cell, we always build two circuits: names ending in "aff" refer to the circuits corresponding to the affirmative assertions, and their counterparts ending in "neg" to the ones corresponding to the negative assertions.

In [13]:
# Bringing everything together by plugging the assertion circuits into the text circuits
# (loop over a copy of the keys, since entries may be deleted inside the loop)
for i in list(reduced_datadict.keys()):

    text_circuit = reduced_datadict[i]['text_circuit_sim4_13']
    text_nouns = reduced_datadict[i]['noun_list_text']
    q_nouns = reduced_datadict[i]['noun_list_question']
    # wire positions of the person and the location, assuming the text
    # circuit's wires follow the order of the nouns in `text_nouns`
    qid1 = text_nouns.index(q_nouns[0])
    qid2 = text_nouns.index(q_nouns[1])

    # the person must enter the first wire of the assertion circuit
    swap_required = qid1 > qid2

    if qid1 == qid2:
        print('noun ids are identical, removing entry')
        del reduced_datadict[i]
        continue

    # keep the two question wires and discard all the others
    quest_mid_layer = Id(qubit) if (qid1 == 0 or qid2 == 0) else Discard()

    for k in range(1, len(text_circuit.cod)):
        if k == qid1 or k == qid2:
            quest_mid_layer = quest_mid_layer @ Id(qubit)
        else:
            quest_mid_layer = quest_mid_layer @ Discard()

    final_circuit = text_circuit >> quest_mid_layer

    if swap_required:
        final_circuit_aff = final_circuit >> is_in_q_swp
        final_circuit_neg = final_circuit >> is_not_in_q_swp
    else:
        final_circuit_aff = final_circuit >> is_in_q
        final_circuit_neg = final_circuit >> is_not_in_q

    reduced_datadict[i].update({'quantum_circ_pair_aff_neg': (final_circuit_aff, final_circuit_neg)})
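The wire bookkeeping in this loop can be checked without lambeq. The hypothetical helper `question_wire_layout` below mirrors it: given the number of text wires and the positions `qid1` (person) and `qid2` (location), it reports which wires are kept or discarded and whether the swapped assertion circuit is needed. Like the loop, it assumes that the wires of a text circuit follow the order of the nouns returned by `return_noun_list`; the toy noun list is invented for illustration.

```python
from typing import List, Tuple


def question_wire_layout(n_wires: int, qid1: int, qid2: int) -> Tuple[List[str], bool]:
    """Return ('keep' or 'discard' for each text wire, whether a swap is needed)."""
    if qid1 == qid2:
        raise ValueError('person and location are on the same wire')
    layout = ['keep' if k in (qid1, qid2) else 'discard' for k in range(n_wires)]
    # the kept wires stay in text order, so a location wire that comes before
    # the person wire has to be swapped into second position
    swap_required = qid1 > qid2
    return layout, swap_required


toy_nouns = ['mary', 'bathroom', 'sandra', 'bedroom']  # toy noun order of a text circuit
print(question_wire_layout(len(toy_nouns), toy_nouns.index('sandra'), toy_nouns.index('bathroom')))
# (['discard', 'keep', 'keep', 'discard'], True)
```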

## 7. Preparing The Datasets for Training

With our circuit pairs prepared, we now prepare the data for training the model, which is done in part II of this tutorial. We create three datasets: training, validation and test. First, we balance the data: for every text length, we keep the same number of positive and negative answers, sampling at most 100 of each.

In [20]:
import random
from collections import defaultdict

# Add a 'measure' field to each item: +text_length for "yes", -text_length for "no"
for key, value in reduced_datadict.items():
    sign = -1 if value['answer'] == 0 else 1
    value['measure'] = sign * value['text_length']

# Group items by absolute value of measure, i.e. by text length
abs_value_groups = defaultdict(list)
for key, value in reduced_datadict.items():
    abs_value = abs(value['measure'])
    abs_value_groups[abs_value].append((key, value))

# Within each text length, keep equally many positive and negative items
new_balanced_dict = {}
max_length = 100  # maximum number of items of each sign per text length
for abs_value, items in abs_value_groups.items():
    # Separate positive and negative items
    positive_items = [(k, v) for k, v in items if v['measure'] > 0]
    negative_items = [(k, v) for k, v in items if v['measure'] < 0]

    # Determine the balanced size for this group
    n_per_sign = min(len(positive_items), len(negative_items), max_length)

    # Randomly sample from each sign to balance
    balanced_positive = random.sample(positive_items, n_per_sign)
    balanced_negative = random.sample(negative_items, n_per_sign)

    # Add to the balanced dictionary
    for k, v in balanced_positive + balanced_negative:
        new_balanced_dict[k] = v
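The cell above samples without a fixed seed, so re-running it can select different entries. A hypothetical, reproducible variant is sketched below: it groups the entries directly by text length and answer (instead of encoding both in the signed `measure` field) and samples with a seeded `random.Random` instance. The function name and its `cap` and `seed` parameters are illustrative; `cap` plays the role of `max_length`, and answers are assumed to be encoded as 0 and 1 as in `task_file_reader`.

```python
import random
from collections import defaultdict


def balance_by_length_and_answer(data: dict, cap: int = 100, seed: int = 0) -> dict:
    """For every text length, keep equally many 'yes' (1) and 'no' (0) entries.

    At most `cap` entries of each answer are kept per text length, sampled
    with a fixed seed so that the result is reproducible.
    """
    rng = random.Random(seed)
    groups = defaultdict(lambda: {0: [], 1: []})
    for key, entry in data.items():
        groups[entry['text_length']][entry['answer']].append(key)

    balanced = {}
    for length in sorted(groups):
        yes_keys, no_keys = groups[length][1], groups[length][0]
        n = min(len(yes_keys), len(no_keys), cap)
        for key in rng.sample(yes_keys, n) + rng.sample(no_keys, n):
            balanced[key] = data[key]
    return balanced


toy_data = {i: {'text_length': 1 + i % 2, 'answer': int(i < 5)} for i in range(10)}
print(len(balance_by_length_and_answer(toy_data)))  # 8
```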

Lastly, we split the data into training, validation and test sets, with 60%, 25% and the remaining 15% of the entries respectively. The split is done separately for each answer, so every set contains a balanced number of positive and negative answers.

In [17]:
import random
from collections import defaultdict

# Split ratios: the test set receives the remaining ~15% of each answer class
train_ratio = 0.6
val_ratio = 0.25

# We group the keys by answer
answer_to_keys = defaultdict(list)
for key, value in new_balanced_dict.items():
    answer = value['answer']
    answer_to_keys[answer].append(key)

# Initializing the dictionaries for the training, validation, and test datasets
training_dict_bAbI6, validation_dict_bAbI6, test_dict_bAbI6 = {}, {}, {}

# For each answer, we split the keys proportionally and add them to the splits
for label, keys in answer_to_keys.items():
    random.shuffle(keys)  # shuffle in place for randomness

    total = len(keys)
    n_train = int(train_ratio * total)
    n_val = int(val_ratio * total)

    train_keys = keys[:n_train]
    val_keys = keys[n_train:n_train + n_val]
    # the test set takes all remaining keys, which also absorbs rounding
    test_keys = keys[n_train + n_val:]

    # Populating the datasets
    for k in train_keys:
        training_dict_bAbI6[k] = new_balanced_dict[k]
    for k in val_keys:
        validation_dict_bAbI6[k] = new_balanced_dict[k]
    for k in test_keys:
        test_dict_bAbI6[k] = new_balanced_dict[k]

print(f"Training set size: {len(training_dict_bAbI6)}")
print(f"Validation set size: {len(validation_dict_bAbI6)}")
print(f"Test set size: {len(test_dict_bAbI6)}")

Training set size: 62
Validation set size: 26
Test set size: 16
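These sizes follow from the integer truncation in the split. Each answer class is split on its own, so every printed size is twice a per-class size, and the output above corresponds to 52 entries per class (104 in total; this number is inferred from the printed sizes, not printed by the notebook). The helper `split_sizes` is hypothetical and only reproduces the arithmetic of the cell.

```python
from typing import Tuple


def split_sizes(total: int, train_ratio: float = 0.6,
                val_ratio: float = 0.25) -> Tuple[int, int, int]:
    """Per-class (train, validation, test) sizes, using int() truncation as in the cell."""
    n_train = int(train_ratio * total)
    n_val = int(val_ratio * total)
    n_test = total - n_train - n_val  # the test set takes all remaining keys
    return n_train, n_val, n_test


print(split_sizes(52))  # (31, 13, 8), i.e. 62 / 26 / 16 over both classes
```

Because `int()` rounds down, the training and validation sets can only lose entries to rounding, and the test set absorbs them.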


The following cell checks that the training set is balanced between "yes" and "no" answers.

In [18]:
yes_count = 0
no_count = 0
for i in training_dict_bAbI6:
    if training_dict_bAbI6[i]['answer'] == 0:
        no_count += 1
    else:
        yes_count += 1

print(yes_count)
print(no_count)

31
31


Now, the final step is to store all of this data in separate files for training, validation, and testing, to be used in part II of this tutorial.

In [19]:
with open(TRAINING_DATASET_FILEPATH, 'wb') as file:
    pickle.dump(training_dict_bAbI6, file)
with open(VALIDATION_DATASET_FILEPATH, 'wb') as file:
    pickle.dump(validation_dict_bAbI6, file)
with open(TEST_DATASET_FILEPATH, 'wb') as file:
    pickle.dump(test_dict_bAbI6, file) 
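Before moving on to part II, it can be worth reloading the files and checking that no entry ended up in more than one split. The helpers below, `load_splits` and `splits_are_disjoint`, are a hypothetical sketch using only the standard library. Since the stored dictionaries contain lambeq circuits, lambeq must be installed for them to be unpickled.

```python
import pickle
from typing import Dict, List


def load_splits(*paths: str) -> List[Dict]:
    """Load the pickled dataset dictionaries, in the order given."""
    splits = []
    for path in paths:
        with open(path, 'rb') as file:
            splits.append(pickle.load(file))
    return splits


def splits_are_disjoint(*splits: Dict) -> bool:
    """Return True if no key appears in more than one split."""
    seen = set()
    for split in splits:
        if seen.intersection(split):
            return False
        seen.update(split)
    return True
```

In this notebook, `splits_are_disjoint(*load_splits(TRAINING_DATASET_FILEPATH, VALIDATION_DATASET_FILEPATH, TEST_DATASET_FILEPATH))` should return `True`, because the split assigns each key to exactly one set.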