In [2]:
!pip install transformers



In [3]:
import textwrap

import numpy as np
import pandas as pd
import torch
from transformers import BertForQuestionAnswering
from transformers import BertTokenizer

In [4]:
# Raw CoQA training split; each row's "data" cell holds one record (story, questions, answers).
COQA_TRAIN_URL = 'http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-train-v1.0.json'
coqa = pd.read_json(COQA_TRAIN_URL)
coqa.head()

Unnamed: 0,version,data
0,1,"{'source': 'wikipedia', 'id': '3zotghdk5ibi9ce..."
1,1,"{'source': 'cnn', 'id': '3wj1oxy92agboo5nlq4r7..."
2,1,"{'source': 'gutenberg', 'id': '3bdcf01ogxu7zdn..."
3,1,"{'source': 'cnn', 'id': '3ewijtffvo7wwchw6rtya..."
4,1,"{'source': 'gutenberg', 'id': '3urfvvm165iantk..."


In [5]:
# Drop the "version" column, which the flattening step below never reads.
# errors="ignore" keeps this cell idempotent: re-running it no longer raises KeyError.
coqa = coqa.drop(columns="version", errors="ignore")

In [6]:
# Flatten CoQA so that every conversational turn becomes one (text, question, answer) row.
# The i-th question of a story is paired with the i-th answer of the same story.
cols = ["text", "question", "answer"]
comp_list = [
    [
        record["story"],
        record["questions"][turn]["input_text"],
        record["answers"][turn]["input_text"],
    ]
    for record in coqa["data"]
    for turn in range(len(record["questions"]))
]
new_df = pd.DataFrame(comp_list, columns=cols)
# Persist the flattened table so later cells can reload it without re-parsing the JSON.
new_df.to_csv("CoQA_data.csv", index=False)

In [7]:
# Reload the flattened (text, question, answer) table written by the previous cell.
data = pd.read_csv("CoQA_data.csv")
data.head(n=5)

Unnamed: 0,text,question,answer
0,"The Vatican Apostolic Library (), more commonl...",When was the Vat formally opened?,It was formally established in 1475
1,"The Vatican Apostolic Library (), more commonl...",what is the library for?,research
2,"The Vatican Apostolic Library (), more commonl...",for what subjects?,"history, and law"
3,"The Vatican Apostolic Library (), more commonl...",and?,"philosophy, science and theology"
4,"The Vatican Apostolic Library (), more commonl...",what was started in 2014?,a project


In [8]:
# SQuAD-finetuned BERT-large checkpoint; model and tokenizer are loaded from the same name
# so their vocabularies agree.
MODEL_NAME = 'bert-large-uncased-whole-word-masking-finetuned-squad'
model = BertForQuestionAnswering.from_pretrained(MODEL_NAME)
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
Some weights of the model checkpoint at bert-large-uncased-whole-word-masking-finetuned-squad were not used when initializing BertForQuestionAnswering: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly ident

In [9]:
# Pick one random (question, text) pair. Seeded so the example, and every output that
# depends on it, is reproducible after Restart & Run All.
RANDOM_SEED = 42
rng = np.random.default_rng(RANDOM_SEED)
random_num = int(rng.integers(0, len(data)))
# random_num is a position in [0, len(data)), so index positionally.
question = data["question"].iloc[random_num]
text = data["text"].iloc[random_num]

In [10]:
# Encode question and passage together as one sequence ([CLS] question [SEP] text [SEP]).
input_ids = tokenizer.encode(question, text)
num_tokens = len(input_ids)
print("The input has a total of {} tokens.".format(num_tokens))

The input has a total of 218 tokens.


In [11]:
# Show each WordPiece token next to its vocabulary id (ids printed with thousands separators).
# The loop variable is named token_id rather than `id` so the builtin id() is not
# shadowed in the kernel's global namespace after this cell runs.
tokens = tokenizer.convert_ids_to_tokens(input_ids)
for token, token_id in zip(tokens, input_ids):
    print('{:8}{:8,}'.format(token, token_id))

[CLS]        101
do         2,079
some       2,070
teachers   5,089
have       2,031
to         2,000
continue   3,613
their      2,037
education   2,495
after      2,044
they       2,027
qualify    7,515
to         2,000
teach      6,570
?          1,029
[SEP]        102
the        1,996
role       2,535
of         1,997
teacher    3,836
is         2,003
often      2,411
formal     5,337
and        1,998
ongoing    7,552
,          1,010
carried    3,344
out        2,041
at         2,012
a          1,037
school     2,082
or         2,030
other      2,060
place      2,173
of         1,997
formal     5,337
education   2,495
.          1,012
in         1,999
many       2,116
countries   3,032
,          1,010
a          1,037
person     2,711
who        2,040
wishes     8,996
to         2,000
become     2,468
a          1,037
teacher    3,836
must       2,442
first      2,034
obtain     6,855
specified   9,675
professional   2,658
qualifications  15,644
or         2,030
credentials  22,4

In [12]:
# Encode the passage on its own (with special tokens), truncated to at most 50 tokens.
# NOTE(review): `tokenized_sentence` is not used by any later cell visible here.
tokenized_sentence = tokenizer.encode(
    text,
    add_special_tokens=True,
    padding=True,
    truncation=True,
    max_length=50,
)

In [13]:
# Everything up to and including the first [SEP] is the question (segment A);
# the remaining tokens are the passage (segment B).
sep_idx = input_ids.index(tokenizer.sep_token_id)
num_seg_a = sep_idx + 1
num_seg_b = len(input_ids) - num_seg_a

print("SEP token index: ", sep_idx)
print("Number of tokens in segment A: ", num_seg_a)
print("Number of tokens in segment B: ", num_seg_b)

# 0 for question positions, 1 for passage positions.
segment_ids = [0 if pos < num_seg_a else 1 for pos in range(len(input_ids))]

assert len(segment_ids) == len(input_ids)

SEP token index:  15
Number of tokens in segment A:  16
Number of tokens in segment B:  202


In [14]:
# Feed token ids plus segment ids (0 = question, 1 = passage) to the model.
# This is inference only, so gradient tracking is disabled to avoid building an autograd graph.
with torch.no_grad():
    output = model(torch.tensor([input_ids]), token_type_ids=torch.tensor([segment_ids]))

In [15]:
# Pick the single highest-scoring start and end positions (cast to int for slicing).
answer_start = int(torch.argmax(output.start_logits))
answer_end = int(torch.argmax(output.end_logits))

print("\nQuestion:\n{}".format(question.capitalize()))
# A usable span must not run backwards and must not begin at [CLS], which is not passage text.
if answer_end >= answer_start and tokens[answer_start] != "[CLS]":
    # Join the span, then glue WordPiece continuations ("##xyz") onto the preceding token.
    answer = " ".join(tokens[answer_start:answer_end + 1]).replace(" ##", "")
    print("\nAnswer:\n{}.".format(answer.capitalize()))
else:
    # No valid span: report that instead of printing an undefined or stale `answer`.
    print("I am unable to find the answer to this question. Can you please ask another question?")


Question:
Do some teachers have to continue their education after they qualify to teach?

Answer:
Teachers , like other professionals , may have to continue their education after they qualify , a process known as continuing professional development.


In [16]:
def question_answer(question, text, qa_model=None, qa_tokenizer=None, max_length=512):
    """Answer `question` from `text` with an extractive BERT QA model and print the result.

    Args:
        question: The question string (encoded as segment A).
        text: The passage to search for the answer (encoded as segment B).
        qa_model: QA model to run; defaults to the notebook-level ``model``.
        qa_tokenizer: Matching tokenizer; defaults to the notebook-level ``tokenizer``.
        max_length: Maximum combined length in tokens. Only the passage is truncated
            to fit, because BERT checkpoints accept at most 512 positions.

    Returns:
        The predicted answer string (also printed, capitalized), or the message
        "Unable to find the answer to your question." when no valid span is found.
    """
    # Resolve the defaults lazily so callers can pass their own model/tokenizer.
    qa_model = model if qa_model is None else qa_model
    qa_tokenizer = tokenizer if qa_tokenizer is None else qa_tokenizer
    no_answer = "Unable to find the answer to your question."

    # Tokenize question and text as a pair; truncate only the passage if too long.
    input_ids = qa_tokenizer.encode(question, text, truncation="only_second", max_length=max_length)

    # String version of the tokenized ids, used to rebuild the answer text.
    tokens = qa_tokenizer.convert_ids_to_tokens(input_ids)

    # Segment A (0s) runs up to and including the first [SEP]; segment B (1s) is the rest.
    sep_idx = input_ids.index(qa_tokenizer.sep_token_id)
    num_seg_a = sep_idx + 1
    num_seg_b = len(input_ids) - num_seg_a
    segment_ids = [0] * num_seg_a + [1] * num_seg_b
    assert len(segment_ids) == len(input_ids)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        output = qa_model(torch.tensor([input_ids]), token_type_ids=torch.tensor([segment_ids]))

    answer_start = int(torch.argmax(output.start_logits))
    answer_end = int(torch.argmax(output.end_logits))

    # Default to the fallback so `answer` is always bound, even when end < start.
    answer = no_answer
    if answer_end >= answer_start:
        answer = tokens[answer_start]
        for token in tokens[answer_start + 1:answer_end + 1]:
            # WordPiece continuation pieces ("##xyz") attach without a space.
            if token.startswith("##"):
                answer += token[2:]
            else:
                answer += " " + token

    # A span beginning at [CLS] lies outside the passage; treat it as "no answer".
    if answer.startswith("[CLS]"):
        answer = no_answer

    print("\nPredicted answer:\n{}".format(answer.capitalize()))
    return answer

In [37]:
wrapper = textwrap.TextWrapper(width=150)

# Sample news article used as the passage for question_answer().
# Adjacent string literals are concatenated, which keeps each source line readable.
# NOTE(review): the original literal was split mid-string between "rushed to their aid. "
# and "According to ..."; it is joined here with that single space — confirm against the source article.
article = (
    "A 10-nation coalition has been announced to fight missile and drone attacks by Houthi militants on ships transiting in the southern Red Sea and the Gulf of Aden. "
    "The patrol force was announced by US Secretary of Defense Lloyd Austin on Monday and will see the participation of the United Kingdom, Bahrain, Canada, France, Italy, the Netherlands, Norway, Seychelles and Spain.\n\n"
    "\"This is an international challenge that demands collective action,\" Defence Secretary Lloyd Austin said in a statement released from Bahrain. "
    "\"Therefore today I am announcing the establishment of Operation Prosperity Guardian, an important new multinational security initiative.\"\n\n\n\n"
    "The patrol mission will be coordinated by Combined Task Force 153 set up in April 2022 to improve maritime security in the Red Sea, Bab el-Mandeb and the Gulf of Aden. "
    "The Task Force has 39 member nations.\n\n\n\n"
    "While some countries will conduct joint patrols, others provide intelligence support. "
    "\"Besides these 10, several other countries have also agreed to be involved in the operation but prefer not to be publicly named,\" AP quoted a defence official on the condition of anonymity.\n\n\n\n"
    "The formation of a patrol force comes as Iran-backed Houthi rebels escalate attacks on tankers, cargo ships and other vessels in the Red Sea to avenge Israel's invasion of the Gaza Strip. "
    "The move has affected the cargo movement as this transit route carries up to 12 per cent of global trade. "
    "Though Houthis said \"no harm will be dealt\" to ships heading to ports around the world except for Israeli ports, multiple shipping companies had ordered their ships to hold in place and not enter the Bab el-Mandeb Strait until the situation can be addressed.\n\n\n\n"
    "On Monday, a cargo ship Swan Atlantic came under attack from Houthis, following which a US warship, USS Carney, rushed to their aid. "
    "According to the US Central Command, the chemical/oil tanker, a Cayman Islands-flagged ship, was attacked by a one-way attack drone and an anti-ship ballistic missile launched from a Houthi-controlled area in Yemen.\n\n\n\n"
    "Another vessel also came under Houthi attack on the same day. "
    "The bulk cargo ship MSC Clara reported an explosion in the water near its location, CENTCOM added.\n\n\n\n"
    "At present, the US has deployed two warships - the USS Carney and the USS Mason -- through the Bab el-Mandeb Strait daily to respond to attacks from the Houthis."
)
print(wrapper.fill(article))

A 10-nation coalition has been announced to fight missile and drone attacks by Houthi militants on ships transiting in the southern Red Sea and the
Gulf of Aden. The patrol force was announced by US Secretary of Defense Lloyd Austin on Monday and will see the participation of the United Kingdom,
Bahrain, Canada, France, Italy, the Netherlands, Norway, Seychelles and Spain.  "This is an international challenge that demands collective action,"
Defence Secretary Lloyd Austin said in a statement released from Bahrain. "Therefore today I am announcing the establishment of Operation Prosperity
Guardian, an important new multinational security initiative."    The patrol mission will be coordinated by Combined Task Force 153 set up in April
2022 to improve maritime security in the Red Sea, Bab el-Mandeb and the Gulf of Aden. The Task Force has 39 member nations.    While some countries
will conduct joint patrols, others provide intelligence support. "Besides these 10, several other countries h

In [38]:
# Example query against the article. Assign the result so the cell shows only the
# printed prediction, not an echoed return value.
question = "Whose participation will be seen by the patrol force?"
predicted_answer = question_answer(question, article)


Predicted answer:
United kingdom , bahrain , canada , france , italy , the netherlands , norway , seychelles and spain
