In [1]:
!pip install transformers datasets



In [2]:
!pip install sentencepiece



In [3]:
!pip install yake



In [8]:
!pip install rank_bm25

Collecting bm25
  Downloading BM25-1.0.0.tar.gz (1.1 kB)
  Preparing metadata (setup.py) ... [?25ldone
[?25hBuilding wheels for collected packages: bm25
  Building wheel for bm25 (setup.py) ... [?25ldone
[?25h  Created wheel for bm25: filename=BM25-1.0.0-py3-none-any.whl size=1741 sha256=ae46d4ea57b8bdb4e635e92cd434faaefb00bac168931dd35dee14c404a6e15d
  Stored in directory: /raid/kavin-intern-maunendra/.cache/pip/wheels/00/0c/3c/eac477a276d6eebe52cb68ef5140c49bed58e5f418ce262301
Successfully built bm25
Installing collected packages: bm25
Successfully installed bm25-1.0.0


In [5]:
!pip install sklearn



In [15]:
import json
import random
import string
from random import shuffle, randint

import numpy as np
from datasets import DatasetDict, Dataset
from rank_bm25 import BM25Okapi
from sklearn.feature_extraction import _stop_words
from tqdm.autonotebook import tqdm
# from tqdm import tqdm


### Loading Datasets
# Wizard-of-Wikipedia files. data.json and train.json feed the knowledge
# corpus; train / valid_random_split / test_random_split become examples.
# Files are decoded explicitly as UTF-8 (the JSON interchange encoding) rather
# than with the platform default, which is not UTF-8 everywhere.
WOW_DIR = "wizard_of_wiki"


def _load_json(path):
    """Load and return the JSON document stored at ``path`` (UTF-8)."""
    with open(path, "r", encoding="utf-8") as fh:
        return json.load(fh)


wow_data_train = _load_json(f"{WOW_DIR}/train.json")
wow_data = _load_json(f"{WOW_DIR}/data.json")
wow_data_test = _load_json(f"{WOW_DIR}/test_random_split.json")
wow_data_valid = _load_json(f"{WOW_DIR}/valid_random_split.json")

### Knowledge Corpus
# topic2passage maps every topic seen in data.json or train.json (each
# conversation's chosen topic plus every retrieved passage) to its passage,
# stored as a list (later code joins it with " " and .index()es checked
# sentences into it). When a topic appears with different passages, the one
# with the most elements is kept. `topics` / `passages` record each topic and
# the passage it was FIRST seen with, in first-seen order.
topic2passage = {}
passages = []
topics = []


def _register_passage(topic, passage):
    """Add ``topic -> passage`` to the corpus, keeping the longer passage on conflict."""
    if topic not in topic2passage:
        topic2passage[topic] = passage
        topics.append(topic)
        # Store the passage itself (previously the topic title was appended here
        # for retrieved passages, leaving `passages` out of sync with `topics`).
        passages.append(passage)
    elif len(passage) > len(topic2passage[topic]):
        topic2passage[topic] = passage


for conversations in (wow_data, wow_data_train):
    for conv in conversations:
        _register_passage(conv["chosen_topic"], conv["chosen_topic_passage"])
        for dial in conv["dialog"]:
            # retrieved_passages is a list of {topic: passage} dicts; merge them
            # (a repeated topic keeps its first position and its last value).
            retrieved = {k: v for pas in dial["retrieved_passages"] for k, v in pas.items()}
            for topic, passage in retrieved.items():
                _register_passage(topic, passage)

# Sentinel topic for turns that used no passage: an empty passage, so it can be
# looked up and indexed like any other topic.
topic2passage["no_passages_used"] = []

# Dense ids in topic2passage insertion order (the BM25 corpus uses the same order).
topic2id = {topic: i for i, topic in enumerate(topic2passage)}
id2topic = {i: topic for topic, i in topic2id.items()}

### BM25
def bm25_tokenizer(text, stop_words=None):
    """Lower-case, punctuation-strip and stop-word-filter text for BM25.

    Args:
        text: Either a single string, or an iterable of strings (e.g. a passage
            stored as a list of sentences), which is joined with spaces first.
        stop_words: Tokens to drop. Defaults to scikit-learn's
            ``ENGLISH_STOP_WORDS``.

    Returns:
        list[str]: whitespace-split, lower-cased tokens with leading/trailing
        punctuation removed; empty tokens and stop words are dropped.
    """
    if stop_words is None:
        stop_words = _stop_words.ENGLISH_STOP_WORDS
    # A bare string must not go through " ".join(): that would put a space
    # between every character and produce one-letter tokens.
    if not isinstance(text, str):
        text = " ".join(text)
    stripped = (raw.strip(string.punctuation) for raw in text.lower().split())
    return [token for token in stripped if token and token not in stop_words]

# Build the BM25 index over every passage in the knowledge corpus. Row i of the
# index lines up with id2topic[i]: both walk topic2passage in insertion order.
all_para = topic2passage.values()
tokenized_corpus = [bm25_tokenizer(para) for para in tqdm(all_para)]
bm25 = BM25Okapi(tokenized_corpus)


def search(query, n, shuffle_results=True):
    """Return the topics of the ``n`` best BM25 passages for ``query``.

    Uses the module-level ``bm25`` index, ``bm25_tokenizer`` and ``id2topic``.

    Args:
        query: Search text, either a string or a list of sentences.
        n: Number of topics to return; clamped to ``[0, corpus size]``.
        shuffle_results: When True (default, the historical behaviour) the
            top-``n`` topics are returned in random order via
            ``random.shuffle``; when False they are returned best-first.

    Returns:
        list of topic titles (keys of ``topic2passage``).
    """
    # Wrap a plain string so the tokenizer treats it as one sentence instead of
    # space-joining its individual characters.
    if isinstance(query, str):
        query = [query]
    bm25_scores = bm25.get_scores(bm25_tokenizer(query))

    # argpartition(..., -0)[-0:] would select the whole corpus, and n larger
    # than the corpus raises, so clamp first.
    n = max(0, min(n, len(bm25_scores)))
    if n == 0:
        return []

    top_n = np.argpartition(bm25_scores, -n)[-n:]
    bm25_idx = sorted(top_n, key=lambda idx: bm25_scores[idx], reverse=True)
    if shuffle_results:
        shuffle(bm25_idx)
    return [id2topic[int(idx)] for idx in bm25_idx]

### Dataset Pre-Processing
# Every Wizard turn except a Wizard opening turn (j == 0) becomes one example
# with NUM_KNOWLEDGE candidate topics/passages and the positions of the gold
# passage (gold_pass) and gold sentence (gold_sen); -inf marks "no passage".

NO_PASSAGE = "no_passages_used"
NUM_KNOWLEDGE = 7  # candidate passages per example
SEED = 42  # makes search()'s shuffle and the randint below reproducible

random.seed(SEED)

_COLUMNS = (
    "gold_pass", "all_pass", "gold_sen", "all_sen", "last_ut",
    "context", "response", "context_eou", "all_topic",
)


def _topic_from_sentence_key(dial):
    """Recover a topic title from the first key of ``dial["checked_sentence"]``.

    The key is split on "_" and its first and last fields are dropped
    (presumably a prefix and a sentence index -- TODO confirm against the WoW
    format). Returns None when no sentence was checked.
    """
    keys = list(dial["checked_sentence"])
    if not keys:
        return None
    return " ".join(keys[0].split("_")[1:-1])


def _resolve_gold(dial):
    """Return ``(gold_pass, gold_sen, resolved)`` for one Wizard turn.

    ``gold_pass`` is always a key of ``topic2passage`` whose value is a list.
    When neither ``checked_passage`` nor the topic encoded in the
    ``checked_sentence`` key resolves to such an entry, both values are
    NO_PASSAGE and ``resolved`` is False.
    """
    gold_pass = next(iter(dial["checked_passage"].values()), NO_PASSAGE)
    gold_sen = next(iter(dial["checked_sentence"].values()), NO_PASSAGE)

    if gold_pass == NO_PASSAGE and gold_sen != NO_PASSAGE:
        gold_pass = _topic_from_sentence_key(dial)

    # Explicit membership/type test instead of `assert` (stripped under -O).
    if not isinstance(topic2passage.get(gold_pass), list):
        gold_pass = _topic_from_sentence_key(dial)
        if not isinstance(topic2passage.get(gold_pass), list):
            return NO_PASSAGE, NO_PASSAGE, False
    return gold_pass, gold_sen, True


def _build_split(conversations):
    """Build the column dict for one WoW split plus diagnostic counters.

    Returns:
        (columns, stats): ``columns`` maps each name in ``_COLUMNS`` to a list
        with one entry per example. ``stats`` holds ``l7`` (turns with fewer
        than NUM_KNOWLEDGE retrieved topics, replaced by a BM25 search on the
        response), ``g7`` (turns with more, truncated), ``f`` ((conv, turn)
        pairs whose gold sentence was missing from its passage and was appended
        to that example's copy) and ``f2`` ((conv, turn) pairs whose gold
        passage could not be resolved).
    """
    columns = {name: [] for name in _COLUMNS}
    stats = {"l7": 0, "g7": 0, "f": [], "f2": []}

    for i, conv in tqdm(enumerate(conversations), total=len(conversations)):
        history = ""  # "Speaker: text " for every turn so far
        last_ut = ""  # text of the latest turn that did not yield an example
        pseudo = ""   # every turn so far, each followed by " <eou> "

        for j, dial in enumerate(conv["dialog"]):
            speaker = dial["speaker"][2:]
            text = dial["text"]

            if "Wizard" in dial["speaker"] and j > 0:
                knowledge = [list(k)[0] for k in dial["retrieved_passages"]]
                if len(knowledge) < NUM_KNOWLEDGE:
                    # Too few retrieved topics: replace them with BM25 hits.
                    knowledge = search(text, NUM_KNOWLEDGE)
                    stats["l7"] += 1
                elif len(knowledge) > NUM_KNOWLEDGE:
                    knowledge = knowledge[:NUM_KNOWLEDGE]
                    stats["g7"] += 1

                gold_pass, gold_sen, resolved = _resolve_gold(dial)
                if not resolved:
                    stats["f2"].append((i, j))

                if gold_pass != NO_PASSAGE and gold_pass not in knowledge:
                    # Guarantee the gold topic is among the candidates.
                    knowledge[randint(0, len(knowledge) - 1)] = gold_pass

                # Copy: appending a missing gold sentence must not mutate the
                # shared passage in topic2passage (it would leak into later
                # examples, the other splits and every aliasing all_sen).
                all_sen = list(topic2passage[gold_pass])
                if gold_pass == NO_PASSAGE:
                    gid = sid = float("-inf")
                else:
                    gid = knowledge.index(gold_pass)
                    if gold_sen in all_sen:
                        sid = all_sen.index(gold_sen)
                    else:
                        sid = len(all_sen)
                        all_sen.append(gold_sen)
                        stats["f"].append((i, j))

                all_pass = [" ".join(topic2passage[t]) for t in knowledge]
                if gid != float("-inf"):
                    # The gold candidate shows its (possibly extended) sentences.
                    all_pass[gid] = " ".join(all_sen)

                columns["gold_pass"].append(gid)
                columns["gold_sen"].append(sid)
                columns["all_pass"].append(all_pass)
                columns["all_sen"].append(all_sen)
                columns["last_ut"].append(last_ut)
                columns["context"].append(history + speaker + ": ")
                columns["response"].append(text)
                columns["context_eou"].append(pseudo[:-1])
                columns["all_topic"].append(knowledge)
            else:
                last_ut = text

            pseudo += text + " <eou> "
            history += f"{speaker}: {text} "

    return columns, stats


full_dat_train, train_stats = _build_split(wow_data_train)
full_dat_valid, valid_stats = _build_split(wow_data_valid)
full_dat_test, test_stats = _build_split(wow_data_test)

# Keep the test split's counters under their historical global names.
l7, g7, f, f2 = (test_stats[k] for k in ("l7", "g7", "f", "f2"))

# Bundle the three processed splits into one DatasetDict and write it to ./wow.
wow = DatasetDict(
    {
        "train": Dataset.from_dict(full_dat_train),
        "valid": Dataset.from_dict(full_dat_valid),
        "test": Dataset.from_dict(full_dat_test),
    }
)
wow.save_to_disk("wow")

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 184164/184164 [00:14<00:00, 12695.72it/s]
18430it [00:11, 1588.88it/s]
981it [00:00, 1124.27it/s]
965it [00:01, 732.99it/s]
Saving the dataset (2/2 shards): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 74092/74092 [00:00<00:00, 106624.73 examples/s]
Saving the dataset (1/1 shards): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3939/3939 [00:00<00:00, 86314.45 examples/s]
Saving the dataset (1/1 shards): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3865/38

In [7]:
# import rank_bm25

ModuleNotFoundError: No module named 'rank_bm25'