In [8]:
import pandas as pd
import re

In [9]:
# Load the SQAD split; the first CSV column becomes the row index.
df = pd.read_csv("dataset/sqad_split.csv", index_col=0)
# Row 5146 is excluded; the reason is not recorded in this notebook (TODO confirm).
df = df.drop(index=5146)
# Each document is stored as the list of its newline-separated lines.
df["sentences"] = df["text"].map(lambda text: text.split("\n"))
df.head()

Unnamed: 0,text,question,answer,answer_sentence,sentences
1,Létající jaguár je novela spisovatele Josefa F...,Kdo je autorem novely Létající jaguár?,Josefa Formánka,Létající jaguár je novela spisovatele Josefa F...,[Létající jaguár je novela spisovatele Josefa ...
2,"Houby (Fungi, dříve Mycophyta) představují vel...",Jak se nazývá věda zabývající se houbami?,mykologie,Věda zabývající se houbami se nazývá mykologie.,"[Houby (Fungi, dříve Mycophyta) představují ve..."
3,Čokoláda je obvyklá součást nejrůznějších druh...,"Jak se nazývá strom, jehož zrna jsou využívána...",Theobroma cacao,"Čokoláda se vyrábí z kvašených, pražených a ml...",[Čokoláda je obvyklá součást nejrůznějších dru...
4,Václav Havel (5. října 1936 Praha – 18. prosin...,Kdo se stal prvním prezidentem České republiky?,Václav Havel,Václav Havel (5. října 1936 Praha – 18. prosin...,[Václav Havel (5. října 1936 Praha – 18. prosi...
5,"Pampeliška (Taraxacum), či také smetánka, je z...",Do jaké čeledi rostlin patří pampeliška?,hvězdnicovité,"Pampeliška (Taraxacum), či také smetánka, je z...","[Pampeliška (Taraxacum), či také smetánka, je ..."


## Fix mismatching answers

In [10]:
def get_sentence_index(row):
    """Return the position of ``row.answer_sentence`` in ``row.sentences``.

    The first exact match wins; ``-1`` is returned when the answer sentence
    is not one of the entries.
    """
    candidates = row.sentences
    target = row.answer_sentence
    if target in candidates:
        return candidates.index(target)
    return -1


# Rows whose answer sentence is not a single entry of `sentences`: the answer
# sentence itself contains "\n", so it was split across several entries.
# Merge those entries back into one sentence and store the joined answer.
mismatched = df[df.apply(get_sentence_index, axis=1) == -1]

for idx, row in mismatched.iterrows():
    anss = row.answer_sentence.split("\n")
    sntcs = row.sentences

    try:
        i_beg = sntcs.index(anss[0])
        # Search from i_beg so the end can never resolve before the start.
        i_end = sntcs.index(anss[-1], i_beg)
    except ValueError:
        # The answer lines cannot be located at all; dump the row for manual inspection.
        print(repr(row.sentences))
        print(repr(row.answer_sentence))
        print()
        continue

    sentences = sntcs[:i_beg] + [" ".join(sntcs[i_beg:i_end+1])] + sntcs[i_end+1:]
    answer = " ".join(anss)

    # Write through `.at`: `df.loc[idx].col = value` assigns to a temporary
    # Series and is silently lost when pandas returns a copy.
    df.at[idx, "sentences"] = sentences
    df.at[idx, "answer_sentence"] = answer
    print(f"{idx: 6d}", len(answer.split(" ")))


     8 16
   433 16
   466 21
   600 11
   630 41
   636 29
   652 18
   825 19
   931 79
   953 61
   954 61
   955 61
  1307 29
  1972 20
  2439 20
  2441 20
  2702 73
  5260 35
  6623 144
  8264 22
  9356 24
  9927 62
 10026 23
 10768 21
 11516 28
 11525 25
 11576 54
 11719 38
 13341 31
 13346 79
 13368 50
 13444 31


## Check if any mismatching answers remain

In [11]:
# Index of the answer sentence inside each row's sentence list (-1 = not found).
starts = df.apply(get_sentence_index, axis=1)
df["start"] = starts
# Display the rows that still have no matching sentence; an empty frame means none remain.
df.loc[starts.eq(-1)]

Unnamed: 0,text,question,answer,answer_sentence,sentences,start


## Create data for chosen task, convert to SQuAD format

In [12]:
from random import randint, random

# Separator placed (padded with spaces) between and around sentences when a
# context string is assembled; also replaced by "<br>" in highlight_answer.
sentencesep=" "


def get_context(row, margin=5):
    """Pick a random window of sentences around the answer sentence.

    The window holds ``pre`` sentences before ``row.start`` and ``post``
    sentences after it, where ``pre`` is drawn from ``[0, margin]`` and
    ``post`` from ``[margin - pre, 2 * margin - pre]``; it therefore spans
    between ``margin + 1`` and ``2 * margin + 1`` sentences.

    Args:
        row: object with ``sentences`` (list of str) and ``start`` (index of
            the answer sentence within ``sentences``).
        margin: controls the window size as described above.

    Returns:
        ``row.sentences`` itself when the window is longer than the document.
        Otherwise a new list of consecutive sentences, wrapping around the
        document boundaries cyclically (past the last sentence it continues
        with the first; before the first it continues from the last).
    """
    pre = randint(0, margin)
    post = randint(margin - pre, 2 * margin - pre)

    sentences = row.sentences
    count = len(sentences)
    window = pre + post + 1

    if window > count:
        return sentences

    first = row.start - pre
    # Modular indexing keeps the window contiguous in cyclic document order,
    # so wrapped sentences are never inserted in the middle of the window.
    return [sentences[(first + offset) % count] for offset in range(window)]


def create_training_data(df, task, margin=5, dupe=1.0):
    """Build SQuAD-style documents with one or more paragraphs per row.

    For every row a random context window is drawn with ``get_context`` and
    the answer is located in it case-insensitively. The emitted answer text
    is the exact span found in the context, so
    ``context[answer_start:answer_start + len(text)] == text`` always holds.

    Args:
        df: frame with ``question``, ``answer``, ``sentences`` and ``start``
            columns; alternative answers are separated by ``" # "``.
        task: accepted for call compatibility; not used by this function.
        margin: forwarded to ``get_context``.
        dupe: expected number of paragraphs per row. A further paragraph is
            generated while ``random() <= remaining`` and ``remaining`` drops
            by one per paragraph (e.g. 1.5 gives one paragraph plus a second
            one about half of the time).

    Returns:
        List of ``{"title": str, "paragraphs": [...]}`` dicts; rows for which
        no paragraph could be produced are omitted.

    Prints the number of abandoned attempts (answer not found in the context,
    or a yes/no answer) and the number of documents produced.
    """
    # Imported locally: a later cell runs `import random`, which rebinds the
    # notebook-global name `random` from this function to the module.
    from random import random as draw

    # Yes/no answers ("ano"/"ne") cannot be extracted as a span of the context.
    yes_no = {"ano", "ne", "ano.", "ne."}

    skipped = 0
    data = []
    for idx, row in df.iterrows():
        remaining = dupe
        pars = []

        while draw() <= remaining:
            text_sntcs = get_context(row, margin)
            context = f"{sentencesep} " + f" {sentencesep} ".join(text_sntcs) + f" {sentencesep}"

            match = None
            for candidate in row.answer.split(" # "):
                if candidate.lower() in yes_no:
                    break
                # Match on the original context so the offset and the matched
                # text refer to the same, unmodified string.
                match = re.search(re.escape(candidate), context, re.IGNORECASE)
                if match is not None:
                    break

            if match is None:
                skipped += 1
                break

            pars.append({
                "qas": [{
                    "id": f"{idx}.{len(pars)}",
                    "question": row.question,
                    "answers": [{
                        "text": match.group(0),
                        "answer_start": match.start()
                    }],
                    "is_impossible": False
                }],
                "context": context
            })
            remaining -= 1

        if pars:
            data.append({"title": str(idx), "paragraphs": pars})

    print(skipped, len(data))

    return data


## View data

In [15]:
import ipywidgets as widgets
from IPython.display import HTML

def highlight_answer(text, answer, start):
    """Return `text` as HTML with the answer highlighted.

    A ">" marker is inserted at offset `start`. When `answer` is non-empty,
    every occurrence of it is wrapped in a yellow-background <span>;
    otherwise each `sentencesep` is replaced by "<br>".
    """
    marked = text[:start] + ">" + text[start:]
    if not answer:
        return marked.replace(sentencesep, "<br>")
    wrapped = f'<span style="background-color: #CCCC00">{answer}</span>'
    return marked.replace(answer, wrapped)

def view_data(data):
    """Display a slider-driven browser over the answered questions in `data`.

    Args:
        data: list of SQuAD-style documents, each with a "paragraphs" list
            whose items carry "context" and "qas" (questions with "id",
            "question" and "answers").

    Questions without answers are skipped; only the first answer of each
    question is shown. Nothing but a short message is displayed when there
    is no answered question.
    """
    questions = {}
    for doc in data:
        for par in doc["paragraphs"]:
            for qq in par["qas"]:
                if not qq["answers"]:
                    continue
                first = qq["answers"][0]
                questions[qq["id"]] = {
                    "question": qq["question"],
                    "answer": first["text"],
                    "start": first["answer_start"],
                    "context": par["context"]
                }

    keys = list(questions)
    if not keys:
        # An IntSlider needs max >= min, so there is nothing to browse.
        print("No answered questions to display.")
        return

    slider = widgets.IntSlider(
        value=0,
        min=0,
        max=len(keys) - 1,  # last valid index into `keys`
        step=1,
        description='Test:',
        disabled=False,
        continuous_update=False,
        orientation='horizontal',
        readout=True,
        readout_format='d'
    )

    left = widgets.Output(layout={"width": "100%"})

    def f(val):
        q = questions[keys[val]]

        left.clear_output()
        with left:
            display(HTML(
                q["question"] + "<hr>" + q["answer"] + "<hr>" + highlight_answer(q["context"], q["answer"], q["start"])
            ))

    # Wires the slider to `f`; the returned Output widget is not displayed
    # because `f` renders into `left` itself.
    widgets.interactive_output(f, {'val': slider})
    display(widgets.VBox([
        slider,
        widgets.HBox([left])
    ]))

In [17]:
# margin=0 narrows the context to the answer sentence alone; dupe=1 gives
# (effectively) one paragraph per row.
data = create_training_data(df, "answer extraction", 0, 1)

2309 11163


In [18]:
# Browse the generated question / answer / context entries with a slider.
view_data(data)

VBox(children=(IntSlider(value=0, continuous_update=False, description='Test:', max=11163), HBox(children=(Out…

## Save the dataset

In [19]:
import json
import random

def write_squad_json(path, data, version="42.0"):
    """Serialize `data` to `path` as a JSON object with "version" and "data" keys.

    Args:
        path: destination file; it is created or overwritten.
        data: JSON-serializable content stored under "data".
        version: value stored under "version".
    """
    payload = dict(version=version, data=data)

    with open(path, "w") as sink:
        json.dump(payload, sink)

def save_dataset(name, dataset=None, dev_part=0.05, seed=None):
    """Shuffle the documents and write a train/dev split as SQuAD JSON files.

    Writes ``dataset/{name}_train.json`` and ``dataset/{name}_dev.json`` and
    prints the sizes of the train and dev parts.

    Args:
        name: file name stem for the two output files.
        dataset: list of documents to split; when omitted, the notebook-global
            ``data`` is used.
        dev_part: fraction of documents that goes to the dev file.
        seed: optional seed for the shuffle; ``None`` gives a different split
            on each call.
    """
    # Imported locally under a distinct name: the notebook also runs
    # `from random import random`, which can rebind the global `random`.
    import random as random_module

    # Shuffle a copy so the caller's list keeps its order.
    documents = list(data if dataset is None else dataset)
    random_module.Random(seed).shuffle(documents)

    split = int(len(documents) * dev_part)
    dev = documents[:split]
    train = documents[split:]
    print(len(train), len(dev))

    write_squad_json(f"dataset/{name}_train.json", train)
    write_squad_json(f"dataset/{name}_dev.json", dev)

In [None]:
# Write the shuffled split to dataset/sqad_extract_train.json and dataset/sqad_extract_dev.json.
save_dataset("sqad_extract")