In [1]:
import datasets
import numpy as np
import pandas as pd
import torch
import transformers
import os
# Route HTTP, HTTPS and FTP traffic of this process through one proxy endpoint.
# NOTE(review): the endpoint is hard-coded and overrides any value already set
# in the environment.
for _proxy_var in ("http_proxy", "https_proxy", "ftp_proxy"):
    os.environ[_proxy_var] = "http://proxy.ad.speechpro.com:3128"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def join_same_person(row):
    """Merge consecutive dialog turns spoken by the same person.

    Adjacent turns whose ``"person"`` values are equal are collapsed into one
    turn: their ``"text"`` values are joined with a single space and their
    ``"gk"`` lists are united without duplicates, keeping first-seen order.

    Args:
        row: Mapping with a ``"dialog"`` list of turn dicts; each turn must
            provide ``"person"``, ``"text"`` and ``"gk"``.
            (``"gk"`` presumably holds grounding-knowledge references;
            confirm against the raw data.)

    Returns:
        ``{"dialog": merged_turns}``. The turn dicts are fresh copies, so
        ``row`` itself is never modified.
    """
    merged = []
    for turn in row['dialog']:
        if merged and merged[-1]["person"] == turn["person"]:
            last = merged[-1]
            last["text"] = last["text"] + " " + turn["text"]
            # dict.fromkeys drops duplicates while keeping insertion order,
            # unlike a set union whose iteration order is unspecified.
            last["gk"] = list(dict.fromkeys(list(last["gk"]) + list(turn["gk"])))
        else:
            # Shallow-copy so later merges write to our dict, not the caller's.
            merged.append(dict(turn))
    return {"dialog": merged}

def get_gk_from_persona(row):
    """Replace each turn's knowledge indices with the persona facts they point to.

    For every turn, the speaker record is looked up in ``row['persons']`` via
    the turn's ``'person'`` value; each entry of the turn's ``'gk'`` is used to
    index that speaker's ``'description'``.

    Returns:
        ``{"dialog": [...]}`` where each turn is a dict with ``"text"``,
        ``"gks"`` (the looked-up description entries) and ``"gender"``
        (copied from the speaker record).
    """
    turns = row['dialog']
    speakers = row['persons']

    def resolve(turn):
        # Build the output record for a single turn.
        speaker = speakers[turn['person']]
        facts = [speaker['description'][idx] for idx in turn['gk']]
        gender = speaker['gender']
        return {"text": turn['text'], "gks": facts, "gender": gender}

    return {"dialog": [resolve(turn) for turn in turns]}

In [3]:
def next_answer_sampler(batch):
    """Expand dialogs into (history, knowledge, answer) samples.

    For every turn except the first of each dialog, one sample is emitted:
    the preceding turns form the history, the turn itself is the answer, and
    the turn's ``"gks"`` list is the knowledge (``["<EmptyGK>"]`` when that
    list is empty). The ``"gks"`` key is absent from every turn placed in the
    history and answer outputs.

    Args:
        batch: Mapping with a ``"dialog"`` list; each dialog is a list of turn
            dicts carrying a ``"gks"`` list.

    Returns:
        ``{"history": [...], "gk": [...], "answer": [...]}`` with one entry
        per sample. ``batch`` is not modified.
    """
    historys = []
    gks = []
    answers = []
    for dialog in batch['dialog']:
        # Strip "gks" on copies, once per dialog, instead of popping the key
        # from the caller's dicts.
        stripped = [
            {key: value for key, value in turn.items() if key != 'gks'}
            for turn in dialog
        ]
        for turn_i in range(1, len(dialog)):
            gk = dialog[turn_i]["gks"]
            historys.append(stripped[:turn_i])
            # Copy the list so the output does not alias the input's "gks".
            gks.append(list(gk) if len(gk) > 0 else ["<EmptyGK>"])
            answers.append(stripped[turn_i])
    return {"history": historys, "gk": gks, "answer": answers}

def current_gk_sampler(batch):
    """Pair every turn with the knowledge attached to that same turn.

    One sample is emitted per turn of every dialog. The knowledge is the
    turn's ``'gks'`` list, or ``['<EmptyGK>']`` when that list is empty.
    The placeholder is a list, the same container type as in the non-empty
    case, so the ``"gk"`` column has one consistent type.

    Args:
        batch: Mapping with a ``'dialog'`` list; each dialog is a list of turn
            dicts carrying a ``'gks'`` list.

    Returns:
        ``{"turn": [...], "gk": [...]}`` where each turn is a copy of the
        input turn without its ``'gks'`` key. ``batch`` is not modified.
    """
    turns = []
    gks = []
    for dialog in batch['dialog']:
        for turn in dialog:
            facts = turn['gks']
            # Emit a gks-free copy rather than popping the key from the input.
            turns.append({key: value for key, value in turn.items() if key != 'gks'})
            gks.append(list(facts) if len(facts) > 0 else ['<EmptyGK>'])
    return {"turn": turns, "gk": gks}


def next_gk_sampler(batch):
    """Expand dialogs into (history, single knowledge item, all knowledge) samples.

    For every turn except the first of each dialog, the preceding turns form
    the history. One sample is emitted per entry of that turn's ``'gks'``
    list; a turn with an empty list yields a single sample whose knowledge is
    ``'<EmptyGK>'``.

    Args:
        batch: Mapping with a ``'dialog'`` list; each dialog is a list of turn
            dicts carrying a ``'gks'`` list.

    Returns:
        ``{"history": [...], "gk": [...], "all_gks": [...]}`` with one entry
        per sample: the history turns (without their ``'gks'`` key), the
        selected knowledge item, and the full knowledge list of the answered
        turn. ``batch`` is not modified.
    """
    historys = []
    gks = []
    all_gks = []
    for dialog in batch['dialog']:
        # Strip "gks" on copies so the caller's turn dicts stay intact.
        stripped = [
            {key: value for key, value in turn.items() if key != 'gks'}
            for turn in dialog
        ]
        for turn_i in range(1, len(dialog)):
            history = stripped[:turn_i]
            answer_gks = dialog[turn_i]['gks']
            if len(answer_gks) > 0:
                for gk in answer_gks:
                    historys.append(history)
                    gks.append(gk)
                    # Fresh list per sample: output rows must not alias each
                    # other or the input.
                    all_gks.append(list(answer_gks))
            else:
                historys.append(history)
                gks.append('<EmptyGK>')
                all_gks.append(['<EmptyGK>'])
    return {"history": historys, "gk": gks, "all_gks": all_gks}

In [6]:
# Load the raw train/val splits from JSON-lines files.
train = datasets.Dataset.from_json('../raw/TolokaPersonaChat(train).jsonl')
val = datasets.Dataset.from_json('../raw/TolokaPersonaChat(val).jsonl')
# The test split is currently disabled:
#test = datasets.Dataset.from_json('../raw/all_dialogs.jsonl')
ds = datasets.DatasetDict({"train": train, "val": val})  # , "test": test

# Per-example preprocessing: merge same-speaker turns, then resolve knowledge
# references against the personas (the "persons" column is dropped afterwards).
new_ds = ds.map(join_same_person).map(get_gk_from_persona, remove_columns=["persons"])

# All three samplers run in batched mode, two dialogs per call, and replace
# every existing column with the columns they return.
sampler_kwargs = dict(
    remove_columns=new_ds['train'].column_names,
    batched=True,
    batch_size=2,
)
next_answer_ds = new_ds.map(next_answer_sampler, **sampler_kwargs)
current_gk_ds = new_ds.map(current_gk_sampler, **sampler_kwargs)
next_gk_ds = new_ds.map(next_gk_sampler, **sampler_kwargs)

Using custom data configuration default-88c455e72692d40c
Found cached dataset json (/home/posokhov@ad.speechpro.com/.cache/huggingface/datasets/json/default-88c455e72692d40c/0.0.0)
Using custom data configuration default-9b97865f25fbb5b3
Found cached dataset json (/home/posokhov@ad.speechpro.com/.cache/huggingface/datasets/json/default-9b97865f25fbb5b3/0.0.0)
100%|██████████| 9018/9018 [00:02<00:00, 3447.26ex/s]
100%|██████████| 995/995 [00:00<00:00, 4429.19ex/s]
100%|██████████| 9018/9018 [00:01<00:00, 4582.40ex/s]
100%|██████████| 995/995 [00:00<00:00, 6119.57ex/s]
100%|██████████| 4509/4509 [00:06<00:00, 736.88ba/s]
100%|██████████| 498/498 [00:00<00:00, 731.31ba/s]
100%|██████████| 4509/4509 [00:03<00:00, 1238.03ba/s]
100%|██████████| 498/498 [00:00<00:00, 1170.23ba/s]
100%|██████████| 4509/4509 [00:05<00:00, 813.58ba/s]
100%|██████████| 498/498 [00:00<00:00, 895.64ba/s]


In [7]:
# Persist each derived dataset to its own directory.
# NOTE(review): 'next_answer' is written under '../processed/data/' while the
# other two go directly under '../processed/' - confirm this is intended.
for _target_dir, _dataset in (
    ('../processed/data/next_answer', next_answer_ds),
    ('../processed/current_gk', current_gk_ds),
    ('../processed/next_gk', next_gk_ds),
):
    _dataset.save_to_disk(_target_dir)