### Get splits for BART training

In [1]:
import pandas as pd
import pickle
import re
from sklearn.model_selection import train_test_split
from tqdm import tqdm

In [2]:
# Load the thread data that get_dialogue_instances consumes below.
# NOTE(review): pickle.load can execute arbitrary code while deserialising,
# so this file must come from a trusted source.
with open('data/threads_with_metas_3ut_aug_full.pkl', 'rb') as f:
    threads = pickle.load(f)

In [3]:
# Separator inserted between utterances when dialogue history is joined;
# it is the same string as the eos/sep token in special_tokens_dict.
sep_token = '</s>'

In [4]:
# Accumulators updated in place by preproc_text: every `<uN>` marker and
# every `<sN>` marker found in the text it processes.
utt_set = set()
speakers_set = set()

In [5]:
def get_discourse_tokens(discourse_list):
    """Turn a positional discourse annotation into its three marker tokens.

    Item 1 feeds the `<uN>` token and item 0 the `<to:uN>` token (each value
    is incremented by one); item 2 is wrapped in angle brackets unchanged.
    Returns the tokens as a list in that order.
    """
    own_token = '<u{}>'.format(discourse_list[1] + 1)
    target_token = '<to:u{}>'.format(discourse_list[0] + 1)
    relation_token = '<{}>'.format(discourse_list[2])
    return [own_token, target_token, relation_token]


def get_aug_value(ut, speaker='<s1>'):
    """Serialise one utterance as a space-joined string of markers plus text.

    The pieces, in order: the speaker tag, the three discourse tokens built
    from ut['discourse'], the first element of ut['sentiment'] wrapped in
    angle brackets, and ut['text'].
    """
    parts = [speaker]
    parts.extend(get_discourse_tokens(ut['discourse']))
    parts.append('<' + ut['sentiment'][0] + '>')
    parts.append(ut['text'])
    return ' '.join(parts)


def preproc_text(text, utt_set, speakers_set):
    """Normalise whitespace in `text` and record the markers it contains.

    Parameters
    ----------
    text : object
        Value to clean. Anything that is not a `str` is treated as missing.
    utt_set : set
        Updated in place with every `<uN>` marker found in `text`.
    speakers_set : set
        Updated in place with every `<sN>` marker found in `text`.

    Returns
    -------
    str
        `text` with each run of whitespace collapsed to a single space and
        both ends stripped, or the placeholder 'unk' when `text` is not a
        string or is empty after stripping.
    """
    # Check the type before calling into `re`: re.findall raises TypeError on
    # non-string input, which previously made the 'unk' fallback unreachable.
    if not isinstance(text, str):
        return 'unk'

    utt_set |= set(re.findall(r'<u\d+>', text))
    speakers_set |= set(re.findall(r'<s\d+>', text))

    res = re.sub(r'\s+', ' ', text).strip()
    return res if res else 'unk'


def get_dialogue_instances(threads, utt_set, speakers_set):
    """Build one generation instance per utterance from index 2 onwards.

    For every thread, each utterance at position `i >= 2` becomes a row whose
    history is the utterances before it and whose response is the utterance
    itself. Utterances are joined with the module-level `sep_token`.

    Parameters
    ----------
    threads : iterable of dict
        Each thread is read through the keys 'id', 'dialogue', 'grounding'
        and 'meta' (for 'title'); each utterance through 'id', 'speaker',
        'text', 'amr', 'discourse' and 'sentiment'.
    utt_set, speakers_set : set
        Passed to preproc_text, which adds the `<uN>` / `<sN>` markers it
        finds in every field.

    Returns
    -------
    pandas.DataFrame
        One row per kept instance with the columns thread_id, id, history,
        history_aug, history_amr, history_discourse, addr_amr, response,
        response_aug, grounding and title. Every value has been passed
        through preproc_text; a field that fails preprocessing becomes 'unk'.
    """
    utter_covered = set()  # each utterance is generated only once
    joiner = f' {sep_token} '

    result = []
    for thr in tqdm(threads):
        speakers = {}
        dialogue = thr['dialogue']

        for i, ut in enumerate(dialogue):
            # Speakers get thread-local tags <s1>, <s2>, ... in order of
            # first appearance.
            speaker = ut['speaker']
            if speaker not in speakers:
                speakers[speaker] = '<s' + str(len(speakers) + 1) + '>'

            # The first two utterances are only ever used as history.
            if i < 2:
                continue

            instance_id = thr['id'] + '_' + ut['id']
            if instance_id in utter_covered:
                continue
            utter_covered.add(instance_id)

            history = dialogue[:i]
            # Whitespace-split augmented response: the first three tokens
            # (speaker, <uN>, <to:uN>) go to the end of the source, the rest
            # forms the target.
            aug_tokens = get_aug_value(ut, speakers[speaker]).split()

            utter_dict = {
                'thread_id': thr['id'],
                'id': instance_id,
                'history': joiner.join(
                    [speakers[ut_his['speaker']] + ' ' + ut_his['text'] for ut_his in history]
                    + [speakers[speaker]]),
                'history_aug': joiner.join(
                    [get_aug_value(ut_his, speakers[ut_his['speaker']]) for ut_his in history]
                    + [' '.join(aug_tokens[:3])]),
                'history_amr': joiner.join([ut_his['amr'] for ut_his in history]),
                'history_discourse': joiner.join(
                    [' '.join(get_discourse_tokens(ut_his['discourse'])) for ut_his in history]),
                'addr_amr': dialogue[i - 1]['amr'],
                'response': ut['text'],
                'response_aug': ' '.join(aug_tokens[3:]),
                'grounding': thr['grounding'],
                'title': thr['meta']['title'],
            }

            for k in utter_dict:
                try:
                    utter_dict[k] = preproc_text(utter_dict[k], utt_set, speakers_set)
                except Exception:
                    # Best effort: a field that cannot be cleaned becomes the
                    # placeholder. KeyboardInterrupt/SystemExit are no longer
                    # swallowed, so the loop can still be interrupted.
                    utter_dict[k] = 'unk'

            # Keep only responses longer than three characters; this also
            # drops responses that were replaced by the 'unk' placeholder.
            if len(utter_dict['response']) > 3:
                result.append(utter_dict)

    return pd.DataFrame(result)

In [6]:
# Build every instance; utt_set and speakers_set are filled as a side effect.
dialogue_df = get_dialogue_instances(threads, utt_set, speakers_set)

100%|██████████| 39803/39803 [00:24<00:00, 1601.39it/s]


In [7]:
# (number of instances, number of columns)
dialogue_df.shape

(74069, 11)

In [8]:
# Split on thread ids rather than on rows, so that all instances of one
# thread land in the same subset.
train_threads, val_threads = train_test_split(
    list(dialogue_df['thread_id'].unique()),
    test_size=0.1,
    random_state=575,
)
train_df = dialogue_df[dialogue_df['thread_id'].isin(train_threads)]
val_df = dialogue_df[dialogue_df['thread_id'].isin(val_threads)]

In [9]:
# Row/column counts of the train and validation subsets.
train_df.shape, val_df.shape

((67597, 11), (6472, 11))

In [11]:
# Write both subsets as tab-separated files without the index column.
# preproc_text has already collapsed every whitespace run (tabs and newlines
# included) to a single space, so field values contain no tab characters.
train_df.to_csv('bart_input/train_reddit_dial_df.csv', sep='\t', index=False)
val_df.to_csv('bart_input/val_reddit_dial_df.csv', sep='\t', index=False)

In [12]:
# Relation tokens added to the tokenizer's special tokens. They use the same
# `<label>` form that get_discourse_tokens emits for its third token.
# NOTE(review): presumably this covers every relation label in the data — verify.
disco_rels = ['<negativereaction>',
             '<other>',
             '<appreciation>',
             '<unk>',
             '<elaboration>',
             '<answer>',
             '<question>',
             '<humor>',
             '<announcement>',
             '<agreement>',
             '<disagreement>']

In [13]:
# Every marker to register with the tokenizer: the collected `<uN>` tokens,
# their `<to:uN>` counterparts, the collected speaker tags, the sentiment
# tokens, the discourse relations and `<init>`.
special_tokens_list = [
    *sorted(utt_set),
    *sorted(u.replace('<', '<to:') for u in utt_set),
    *sorted(speakers_set),
    '<Negative>', '<Neutral>', '<Positive>',
    *disco_rels,
    '<init>',
]

In [14]:
# Number of entries in the additional special token list.
len(special_tokens_list)

99

In [15]:
# Token map that is pickled below and later passed to
# tokenizer.add_special_tokens. Note that '<unk>' is set as unk_token here and
# also appears in additional_special_tokens, because it is listed in disco_rels.
special_tokens_dict = {'additional_special_tokens': special_tokens_list,
                         'bos_token': '<s>',
                         'eos_token': '</s>',
                         'unk_token': '<unk>',
                         'sep_token': '</s>',
                         'pad_token': '<pad>',
                         'cls_token': '<s>',
                         'mask_token': '<mask>'}

In [16]:
# Persist the token map next to the CSV splits so it can be reloaded later.
with open('bart_input/special_tokens_map_reddit_dial.pkl', 'wb') as f:
    pickle.dump(special_tokens_dict, f)

### Calculate source & target lengths

In [20]:
import numpy as np
import torch
from transformers import BartForConditionalGeneration, BartTokenizer

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_name_or_path = "facebook/bart-base"

tokenizer = BartTokenizer.from_pretrained(model_name_or_path)
# The model is instantiated only to confirm that the checkpoint loads; the
# cells visible below use the tokenizer alone.
model =  BartForConditionalGeneration.from_pretrained(model_name_or_path).to(device) # to check load

In [21]:
# Reload the saved token map and register it with the tokenizer;
# num_added_toks holds the value returned by add_special_tokens.
# NOTE(review): the model's embedding matrix is not resized in this cell.
# Call model.resize_token_embeddings(len(tokenizer)) before feeding `model`
# any of the new tokens — confirm whether that happens elsewhere.
with open('bart_input/special_tokens_map_reddit_dial.pkl', 'rb') as f:
    special_tokens_dict = pickle.load(f)
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)

In [22]:
# Token-count statistics for every non-id column of the training split:
# column -> (mean, median, 95th percentile) of the encoded length.
lens = {}
for column in train_df:
    # Columns whose name contains 'id' (thread_id, id) are skipped.
    if 'id' in column:
        continue
    num_tokens_text = [
        len(tokenizer.encode(record))
        for record in tqdm(train_df[column].values)
    ]
    lens[column] = (
        np.mean(num_tokens_text),
        np.median(num_tokens_text),
        np.quantile(num_tokens_text, 0.95),
    )

  0%|          | 123/68445 [00:00<02:35, 438.55it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1055 > 1024). Running this sequence through the model will result in indexing errors
100%|██████████| 68445/68445 [01:28<00:00, 773.69it/s] 
100%|██████████| 68445/68445 [01:32<00:00, 740.33it/s] 
100%|██████████| 68445/68445 [02:02<00:00, 557.01it/s] 
100%|██████████| 68445/68445 [00:12<00:00, 5403.48it/s]
100%|██████████| 68445/68445 [00:38<00:00, 1800.54it/s]
100%|██████████| 68445/68445 [00:26<00:00, 2631.42it/s]
100%|██████████| 68445/68445 [00:27<00:00, 2517.32it/s]
100%|██████████| 68445/68445 [00:37<00:00, 1830.03it/s] 
100%|██████████| 68445/68445 [00:19<00:00, 3600.38it/s]


In [23]:
# column -> (mean, median, 95th percentile) token counts
lens

{'history': (193.21510702023522, 123.0, 570.0),
 'history_aug': (210.1300314120827, 139.0, 597.0),
 'history_amr': (374.19500328731095, 258.0, 1016.0),
 'history_discourse': (15.91492439184747, 13.0, 33.0),
 'addr_amr': (99.22618160566878, 67.0, 271.0),
 'response': (42.63747534516765, 25.0, 134.0),
 'response_aug': (44.620147563737305, 27.0, 136.0),
 'grounding': (86.53854920008766, 2.0, 841.0),
 'title': (30.278325662941047, 26.0, 63.0)}