In [1]:
import copy
import torch
import torch.nn as nn
from tqdm import tqdm
from transformers import BlenderbotSmallTokenizer

In [2]:
# Load the pretrained tokenizer published under 'facebook/blenderbot-90M';
# later cells use it to turn dialogue text into token ids and read its
# eos/pad token ids.
tokenizer = BlenderbotSmallTokenizer.from_pretrained('facebook/blenderbot-90M')

In [3]:
def str_to_msg(txt, ignore_fields=''):
    """
    Convert formatted string to ParlAI message dict.

    :param txt:
        formatted string to convert. String format is tab-separated fields,
        with colon separating field name and contents. Inside a field value
        the two-character sequences ``\\t`` and ``\\n`` stand for a tab and a
        newline, and ``__PIPE__`` stands for a literal ``|``.
    :param ignore_fields:
        (default '') comma-separated field names to not
        include in the msg dict even if they're in the string.
    :return:
        dict mapping field names to converted values, always containing an
        ``episode_done`` key (False when absent or ignored), or None when
        ``txt`` is None or the empty string. Fields that contain no colon
        are skipped.
    """

    # Fields whose value is a '|'-separated list of strings.
    list_fields = ('label_candidates', 'labels', 'eval_labels', 'text_candidates')
    # Spellings of ``episode_done`` that mean "not done". Every other
    # non-empty value stays truthy, as it was with a plain ``bool(value)``.
    false_spellings = ('', '0', 'false', 'no')

    def tostr(txt):
        txt = str(txt)
        txt = txt.replace('\\t', '\t')
        txt = txt.replace('\\n', '\n')
        txt = txt.replace('__PIPE__', '|')
        return txt

    def tolist(txt):
        # Split first, then unescape each item, so an escaped pipe inside an
        # item is not mistaken for a separator.
        return [tostr(item) for item in txt.split('|')]

    def tobool(txt):
        # ``bool('False')`` is True, so the text has to be inspected.
        return txt.strip().lower() not in false_spellings

    def convert(key, value):
        if key in list_fields:
            return tolist(value)
        if key == 'episode_done':
            return tobool(value)
        # 'text', 'id' and every unknown field are plain strings.
        return tostr(value)

    if txt is None or txt == '':
        return None

    ignored = ignore_fields.split(',')
    msg = {}
    for field in txt.split('\t'):
        key, sep, value = field.partition(':')
        if not sep:
            # No colon: there is no field name to store the value under.
            continue
        if key not in ignored:
            msg[key] = convert(key, value)
    msg.setdefault('episode_done', False)
    return msg

In [4]:
class ParlaiFormatDataset(torch.utils.data.Dataset):
    """
    In-memory dataset of (context, label) token-id tensors built from a
    ParlAI-format text file.

    Each non-blank line of the file is parsed with ``str_to_msg`` and yields
    one example. The context is the running history of the current episode:
    every newline-separated piece of each turn's ``text`` followed by the
    tokenizer's EOS id, interleaved with the previous turns' labels. The
    history is cleared after a turn whose ``episode_done`` is true.
    """

    def __init__(self,
                 parlai_format_data_path,
                 tokenizer,
                 text_truncate,
                 label_truncate=None):
        """
        :param parlai_format_data_path:
            path of the ParlAI-format file, one turn per line.
        :param tokenizer:
            callable returning a mapping with an ``'input_ids'`` list; must
            also expose ``eos_token_id``.
        :param text_truncate:
            the context keeps only its last ``text_truncate`` token ids.
        :param label_truncate:
            the label keeps only its last ``label_truncate`` token ids;
            defaults to ``text_truncate``.
        """
        if label_truncate is None:
            label_truncate = text_truncate

        def encode(text):
            # Raw token ids only: no padding, no masks.
            return tokenizer(text,
                             padding=False,
                             return_token_type_ids=False,
                             return_attention_mask=False)['input_ids']

        with open(parlai_format_data_path) as f:
            parsed = (str_to_msg(line.strip()) for line in f)
            # str_to_msg returns None for an empty line; drop those so a
            # blank or trailing empty line does not crash the loop below.
            turns = [msg for msg in parsed if msg is not None]

        self.data = []

        history = []
        label_tokens = []
        for turn in tqdm(turns):
            # The previous turn's label becomes part of this turn's context.
            history.extend(label_tokens)
            for text in turn['text'].split('\n'):
                history.extend(encode(text))
                history.append(tokenizer.eos_token_id)

            # Only the first label of the turn is used as the target.
            label_tokens = encode(turn['labels'][0])
            label_tokens.append(tokenizer.eos_token_id)

            # Negative slices keep the most recent tokens of both sequences.
            self.data.append((torch.tensor(history[-text_truncate:]),
                              torch.tensor(label_tokens[-label_truncate:])))

            if turn['episode_done']:
                history = []
                label_tokens = []

    def __getitem__(self, idx):
        """Return the ``(context, label)`` tensor pair stored at ``idx``."""
        return self.data[idx]

    def __len__(self):
        """Return the number of examples."""
        return len(self.data)

In [5]:
class Collator:
    """
    Collate function for a DataLoader: turns a list of (context, label)
    tensor pairs into two batch-first tensors padded with the tokenizer's
    pad token id.
    """

    def __init__(self, tokenizer):
        # Only ``pad_token_id`` is read from the tokenizer.
        self.tokenizer = tokenizer

    def _pad(self, sequences):
        # Stack variable-length 1-D tensors into one (batch, max_len) tensor.
        return nn.utils.rnn.pad_sequence(
            sequences,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id,
        )

    def __call__(self, batch):
        # Transpose [(x0, y0), (x1, y1), ...] into (x0, x1, ...), (y0, y1, ...).
        contexts, labels = zip(*batch)
        padded_contexts = self._pad(contexts)
        padded_labels = self._pad(labels)
        return padded_contexts, padded_labels

In [15]:
# Build the training examples from data/train.txt. text_truncate=128 keeps
# only the last 128 token ids of each context; label_truncate is not given,
# so labels are capped at the same length.
dataset = ParlaiFormatDataset('data/train.txt',
                              tokenizer,
                              text_truncate=128)

100%|██████████| 27018/27018 [00:19<00:00, 1418.84it/s]


In [16]:
# Shuffled batches of 4 examples; Collator pads contexts and labels within
# each batch using the tokenizer's pad token id.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True,
                                         collate_fn=Collator(tokenizer))

In [18]:
# Count the batches whose padded context is exactly 128 tokens wide, i.e. at
# least one context in the batch reached the text_truncate limit used above.
sum(batch[0].shape[1] == 128 for batch in dataloader)

6652

In [19]:
# Total number of batches in one pass over the dataset, for comparison with
# the count of full-width batches.
len(dataloader)

6755

In [20]:
# Grab a single batch for inspection. next(iter(...)) raises StopIteration
# on an empty loader instead of silently leaving `data` unbound or stale
# from an earlier run.
data = next(iter(dataloader))

In [23]:
# data is the (contexts, labels) pair returned by Collator; show the shape
# of the padded label batch.
data[1].shape

torch.Size([4, 24])