In [None]:
import datasets
import numpy as np
import pandas as pd
import torch
import transformers
import os
# Proxy: point the lowercase HTTP, HTTPS and FTP proxy environment variables
# of this process at the same proxy URL.
os.environ.update(
    dict.fromkeys(
        ("http_proxy", "https_proxy", "ftp_proxy"),
        "http://proxy.ad.speechpro.com:3128",
    )
)

In [None]:
# Load the tokenizer of the chit-chat checkpoint, configured with left-side
# truncation and right-side padding.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "cointegrated/rut5-small-chitchat2",
    padding_side='right',
    truncation_side='left',
)

In [None]:
# Read the train and test splits from their JSON Lines files and group them
# into a single DatasetDict keyed by split name.
train = datasets.Dataset.from_json(
    '/home/stc/persona/data/raw/TolokaPersonaChat(train).jsonl'
)
test = datasets.Dataset.from_json(
    '/home/stc/persona/data/raw/TolokaPersonaChat(test).jsonl'
)
ds = datasets.DatasetDict({"train": train, "test": test})

In [None]:
def join_same_person(row):
    """Merge consecutive turns spoken by the same person into a single turn.

    Args:
        row: mapping with a ``'dialog'`` entry, a sequence of turn dicts that
            each provide ``'person'``, ``'text'`` and ``'gk'``.

    Returns:
        ``{"dialog": merged_turns}``. Adjacent turns with equal ``'person'``
        are collapsed into one turn whose ``'text'`` is the texts joined by a
        single space and whose ``'gk'`` is the sorted union of the unique
        ``'gk'`` values. Turns that are not merged keep their content as is.

    The input ``row`` and its turn dicts are never modified: every emitted
    turn is a copy, so merging cannot leak back into the caller's data.
    """
    merged = []
    for turn in row['dialog']:
        if merged and merged[-1]["person"] == turn["person"]:
            last = merged[-1]
            last["text"] = last["text"] + " " + turn["text"]
            # Rebind to a new list (never extend in place): the copied turn
            # still shares its original 'gk' list object with the input.
            # sorted() gives an order that does not depend on set iteration.
            last["gk"] = sorted(set(last["gk"]) | set(turn["gk"]))
        else:
            # Shallow copy so later merges write to our dict, not the input's.
            merged.append(dict(turn))
    return {"dialog": merged}

In [None]:
def get_gk_from_persona(row):
    """Replace each turn's ``'gk'`` indices with the speaker's description items.

    For every turn in ``row['dialog']`` the speaker is looked up in
    ``row['persons']`` via ``turn['person']``; each value in ``turn['gk']`` is
    used to index that speaker's ``'description'``.

    Returns:
        ``{"dialog": turns}`` where every turn is a dict with the keys
        ``"text"``, ``"gks"`` (the selected description items, in the order of
        ``turn['gk']``) and ``"gender"`` (the speaker's ``'gender'``).
    """
    turns = row['dialog']
    personas = row['persons']

    def _resolve(turn):
        # Build the output turn for one input turn.
        speaker = personas[turn['person']]
        facts = [speaker['description'][idx] for idx in turn['gk']]
        speaker_gender = speaker['gender']
        return {"text": turn['text'], "gks": facts, "gender": speaker_gender}

    return {"dialog": [_resolve(turn) for turn in turns]}

In [None]:
def next_answer_sampler(batch):
    """Expand a batch of dialogs into (history, gk, answer) samples.

    For every dialog in ``batch['dialog']`` and every turn index ``i >= 1``
    one sample is produced:

    * ``history`` - turns ``0 .. i-1`` with the ``'gks'`` key removed,
    * ``answer``  - turn ``i`` with the ``'gks'`` key removed,
    * ``gk``      - the ``'gks'`` list of turn ``i``, or ``["<EmptyGK>"]``
      when that list is empty.

    Dialogs with fewer than two turns produce no samples.

    Returns:
        ``{"history": [...], "gk": [...], "answer": [...]}`` - three lists of
        equal length, one entry per sample.

    The input batch is left untouched: ``'gks'`` is dropped from copies of the
    turns, so the same batch can safely be passed to other samplers afterwards.
    """
    historys = []
    gks = []
    answers = []
    for dialog in batch['dialog']:
        # One 'gks'-free copy per turn, shared by every history slice of this
        # dialog (the input turn dicts themselves are never modified).
        stripped = [
            {key: value for key, value in turn.items() if key != 'gks'}
            for turn in dialog
        ]
        for turn_i in range(1, len(dialog)):
            gk = dialog[turn_i]["gks"]
            if len(gk) == 0:
                gk = ["<EmptyGK>"]
            historys.append(stripped[:turn_i])
            gks.append(gk)
            answers.append(stripped[turn_i])
    return {"history": historys, "gk": gks, "answer": answers}

def current_gk_sampler(batch):
    """Expand a batch of dialogs into (turn, gk) samples.

    Every turn of every dialog in ``batch['dialog']`` yields one sample per
    item of its ``'gks'`` list; a turn whose ``'gks'`` list is empty yields a
    single sample with the placeholder ``'<EmptyGK>'``.

    Returns:
        ``{"turn": [...], "gk": [...]}`` - two lists of equal length. Each
        entry of ``"turn"`` is the turn without its ``'gks'`` key, each entry
        of ``"gk"`` is a single item (not a list).

    The input batch is left untouched: ``'gks'`` is dropped from a copy of
    each turn rather than popped from the caller's dicts.
    """
    turns = []
    gks = []
    for dialog in batch['dialog']:
        for turn in dialog:
            turn_gks = turn['gks'] if len(turn['gks']) > 0 else ['<EmptyGK>']
            # Copy without 'gks'; the same copy is reused for every gk of
            # this turn.
            stripped = {key: value for key, value in turn.items() if key != 'gks'}
            for gk in turn_gks:
                turns.append(stripped)
                gks.append(gk)
    return {"turn": turns, "gk": gks}


def next_gk_sampler(batch):
    """Expand a batch of dialogs into (history, gk, all_gks) samples.

    For every dialog in ``batch['dialog']`` and every turn index ``i >= 1``,
    one sample is produced per item of turn ``i``'s ``'gks'`` list:

    * ``history`` - turns ``0 .. i-1`` with the ``'gks'`` key removed,
    * ``gk``      - the single item of the upcoming turn's ``'gks'``,
    * ``all_gks`` - the complete ``'gks'`` list of the upcoming turn.

    When the upcoming turn's ``'gks'`` list is empty, a single sample is
    produced with ``gk == '<EmptyGK>'`` and ``all_gks == ['<EmptyGK>']``.
    Dialogs with fewer than two turns produce no samples.

    Returns:
        ``{"history": [...], "gk": [...], "all_gks": [...]}`` - three lists of
        equal length, one entry per sample.

    The input batch is left untouched: histories are built from copies of the
    turns, so calling this (or another sampler) again on the same batch still
    finds every turn's ``'gks'``.
    """
    historys = []
    gks = []
    all_gks = []
    for dialog in batch['dialog']:
        # 'gks'-free copies of the turns; the input dicts are not modified.
        stripped = [
            {key: value for key, value in turn.items() if key != 'gks'}
            for turn in dialog
        ]
        for turn_i in range(1, len(dialog)):
            history = stripped[:turn_i]
            answer_gks = dialog[turn_i]['gks']
            if len(answer_gks) == 0:
                answer_gks = ['<EmptyGK>']
            for gk in answer_gks:
                historys.append(history)
                gks.append(gk)
                all_gks.append(answer_gks)
    return {"history": historys, "gk": gks, "all_gks": all_gks}

In [None]:
# Per-row preprocessing: merge consecutive same-speaker turns, then resolve the
# gk indices against the persona descriptions and drop the "persons" column.
new_ds = ds.map(join_same_person).map(
    get_gk_from_persona, remove_columns=["persons"]
)

# Each sampler runs in batched mode (two rows per call) and replaces all
# existing columns with the columns it returns.
next_answer_ds = new_ds.map(
    next_answer_sampler,
    batched=True,
    batch_size=2,
    remove_columns=new_ds['train'].column_names,
)
current_gk_ds = new_ds.map(
    current_gk_sampler,
    batched=True,
    batch_size=2,
    remove_columns=new_ds['train'].column_names,
)
next_gk_ds = new_ds.map(
    next_gk_sampler,
    batched=True,
    batch_size=2,
    remove_columns=new_ds['train'].column_names,
)


In [None]:
# Write each sampled dataset to its own directory (relative to the current
# working directory).
for _dataset, _target_dir in (
    (next_answer_ds, 'next_answer'),
    (current_gk_ds, 'current_gk'),
    (next_gk_ds, 'next_gk'),
):
    _dataset.save_to_disk(_target_dir)

In [None]:
# Spot-check: display a single sample (index 159) of the next-gk train split.
next_gk_ds['train'][159]