In [1]:
import random
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler
import json
import tqdm

from torchfly.transformers import UnifiedTokenizer
from dialog_utils import DialogFragmentSampler

In [2]:
def process_dialog(dialog):
    """Encode every turn of a dialog into token ids.

    Each turn is read by index: ``turn[0]`` and ``turn[1]`` are joined with a
    colon, three newlines are appended, and the resulting string is encoded
    with the notebook-level ``tokenizer``.

    Returns:
        A list holding one encoded turn per input turn, in the original order.
    """
    def encode_turn(turn):
        # "<turn[0]>:<turn[1]>" followed by the three-newline turn terminator
        rendered = turn[0] + ":" + turn[1] + "\n\n\n"
        return tokenizer.encode(rendered)

    return [encode_turn(turn) for turn in dialog]

In [3]:
# Tokenizer instance that process_dialog looks up as a global at call time.
tokenizer = UnifiedTokenizer()

# Load the raw corpus. Later cells iterate it with .items(), so the JSON root
# is expected to be an object mapping a dialog key to its list of turns.
with open("../DialogCorpus/all_dialogs.json") as f:
    all_dialogs = json.load(f)

In [4]:
# Tokenize every dialog and keep the ones whose opening is short enough.
# Result: dialog key -> {"text": original turns, "token_ids": per-turn token ids}.
new_all_dialogs = {}

for key, value in tqdm.tqdm(all_dialogs.items()):
    processed_dialog = process_dialog(value)
    # Drop dialogs whose first two turns together exceed 512 tokens.
    # Slicing (instead of indexing [0] and [1]) keeps this check valid for
    # dialogs with fewer than two turns rather than raising IndexError.
    # NOTE(review): only the first two turns are length-checked; later turns
    # may be arbitrarily long — confirm this is the intended filter.
    if sum(len(turn_ids) for turn_ids in processed_dialog[:2]) > 512:
        continue

    new_all_dialogs[key] = {"text": value, "token_ids": processed_dialog}

100%|██████████| 146255/146255 [02:48<00:00, 867.61it/s] 


In [6]:
# save the file
# Writes the filtered, tokenized corpus to dialog_corpus.json in the current
# working directory, overwriting any existing file of that name.
with open("dialog_corpus.json", "w") as f:
    json.dump(new_all_dialogs, f)

In [7]:
# Reload the tokenized corpus from disk; this rebinds new_all_dialogs to the
# copy parsed back from dialog_corpus.json.
with open("dialog_corpus.json", "r") as f:
    new_all_dialogs = json.load(f)

In [18]:
class DialogFragmentSampler:
    """Sample a contiguous run of turns from a dialog under a token budget.

    A dialog that already fits the budget is returned untouched. Otherwise a
    random even-indexed start turn is chosen and the fragment is extended as
    far as the budget allows, ending (exclusively) on an odd turn index.
    """

    def __init__(self, max_tokens=1024, max_turns=20):
        """Sample dialog fragments from a dialog

        Args:
            max_tokens: token budget. One token is held back, so the internal
                limit is ``max_tokens - 1``.
            max_turns: bound on the fragment span; the end index is clamped to
                ``start + max_turns - 1``, i.e. at most ``max_turns - 1`` turns
                are returned for a sampled fragment.
        """
        self.max_num_tokens = max_tokens - 1
        self.max_num_turns = max_turns

    def __call__(self, dialog):
        """dialog is a dict which has key "token_ids"

        Args:
            dialog: mapping whose ``"token_ids"`` entry is a sequence of
                per-turn token sequences.

        Returns:
            ``dialog`` itself (same object, all keys) when its total length is
            at most the internal limit; otherwise a new dict containing only
            ``"token_ids"`` with the sampled slice of turns, whose total length
            is strictly below the internal limit.

        Raises:
            ValueError: if the dialog exceeds the budget and no eligible start
                turn is short enough to fit in the budget on its own.
        """
        token_ids = dialog['token_ids']
        lengths = np.array([len(item) for item in token_ids], dtype=np.int64)

        # if the entire dialog is smaller than the max len
        if lengths.sum() <= self.max_num_tokens:
            return dialog

        cumsum_len = lengths.cumsum()
        # NOTE(review): this is the prefix-sum array in reverse order, not a
        # suffix sum; it is kept as-is so the set of candidate start turns is
        # unchanged — confirm this is the intended selection rule.
        reverse_cumsum_len = cumsum_len[::-1]

        # based on the reverse cumsum, we can have a range to select from
        candidate_turns = np.arange(len(reverse_cumsum_len)
                                    )[reverse_cumsum_len > self.max_num_tokens]
        # Keep even indices only, and only turns that fit the budget by
        # themselves: a start turn that is already too long can never yield a
        # non-empty fragment.
        start_turns = [
            int(idx) for idx in candidate_turns
            if idx % 2 == 0 and lengths[idx] < self.max_num_tokens
        ]
        if not start_turns:
            raise ValueError(
                "cannot sample a dialog fragment: no even-indexed start turn "
                "fits within the token budget of %d" % self.max_num_tokens
            )
        # randomly choose one
        random_start_turn = random.choice(start_turns)

        # fragment_len[i] == number of tokens in turns [random_start_turn, i)
        prefix_len = np.concatenate([[0], cumsum_len], axis=0)
        fragment_len = prefix_len - prefix_len[random_start_turn]

        # find the maximum end turn (only odd turn); searching only past the
        # start guarantees a non-empty slice. start + 1 is odd and is known to
        # fit because of the start-turn filter above.
        random_end_turn = random_start_turn + 1
        for i in range(len(fragment_len) - 1, random_start_turn, -1):
            if i % 2 == 1 and fragment_len[i] < self.max_num_tokens:
                random_end_turn = i
                break

        random_end_turn = min(
            random_end_turn, random_start_turn + self.max_num_turns - 1
        )

        dialog_fragment = {
            "token_ids": token_ids[random_start_turn:random_end_turn]
        }

        assert sum(
            [len(item) for item in dialog_fragment["token_ids"]]
        ) < self.max_num_tokens

        return dialog_fragment

In [19]:
class DialogCorpusDataset(Dataset):
    """Dataset yielding one sampled dialog fragment per corpus entry.

    Each item is the ``"token_ids"`` list (one token sequence per turn) that
    the fragment sampler returns for the corresponding dialog.
    """

    def __init__(self, data, tokenizer):
        """
        Args:
            data: mapping of dialog key -> dialog dict; only the values are
                kept, in the mapping's iteration order.
            tokenizer: tokenizer object; its ``max_len`` attribute is set to
                4096 here (this mutates the object passed in).
        """
        # only interested in the values
        self.data = list(data.values())
        self.tokenizer = tokenizer
        self.tokenizer.max_len = 4096
        self.turn_ending = tokenizer.encode("\n\n\n")
        self.sampler = DialogFragmentSampler(max_tokens=800)

    def __len__(self):
        """Number of dialogs in the corpus."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the per-turn token ids of a fragment sampled from dialog ``index``."""
        dialog = self.data[index]
        dialog_fragment = self.sampler(dialog)
        return dialog_fragment["token_ids"]

    def collate(self, batch):
        """Turn a single-item batch into a list of per-turn tensors.

        Args:
            batch: list produced by the DataLoader; only ``batch[0]`` is used,
                so the loader must run with ``batch_size=1``.

        Returns:
            A list with one ``torch.LongTensor`` of shape ``(1, turn_len)`` per
            turn. Position ids are not part of the returned batch, so none are
            generated here; see ``_random_position_ids`` for that computation.
        """
        # only one item in the batch
        turns = batch[0]
        return [torch.LongTensor([item]) for item in turns]

    def _random_position_ids(self, turns, max_positions=1024):
        """Build per-turn position ids starting at a random offset.

        Args:
            turns: sequence of per-turn token sequences.
            max_positions: exclusive upper bound for generated positions.
                NOTE(review): 1024 presumably matches the model's position
                range — confirm.

        Returns:
            A list with one ``(1, turn_len)`` tensor of consecutive positions
            per turn; positions continue across turns.

        Raises:
            ValueError: if the turns hold more than ``max_positions`` tokens.
        """
        total_len = sum(len(item) for item in turns)
        if total_len > max_positions:
            raise ValueError(
                "fragment has %d tokens but only %d positions are available"
                % (total_len, max_positions)
            )
        # make random positions
        start_position = random.randint(0, max_positions - total_len)

        position_ids = []
        for item in turns:
            end_position = start_position + len(item)
            position_ids.append(
                torch.arange(start_position, end_position).unsqueeze(0)
            )
            start_position = end_position
        return position_ids

In [20]:
# The whole corpus is used for training; train_data is another name for the
# same dict object as new_all_dialogs, not a copy.
train_data = new_all_dialogs

In [21]:
# Rebinds the global tokenizer to a fresh instance; DialogCorpusDataset.__init__
# sets max_len = 4096 on whichever instance it receives.
tokenizer = UnifiedTokenizer()

train_dataset = DialogCorpusDataset(train_data, tokenizer)
# Draw dataset indices in random order.
train_sampler = RandomSampler(train_dataset)

# batch_size must stay at 1: the dataset's collate function only reads batch[0].
train_dataloader = DataLoader(
    dataset=train_dataset,
    sampler=train_sampler,
    batch_size=1,
    collate_fn=train_dataset.collate
)

In [22]:
# One full pass over the loader, recording the token count of every sampled
# fragment.
lengths = []
for batch in tqdm.tqdm(train_dataloader):
    # Each item is a (1, turn_len) tensor, so shape[1] is that turn's token
    # count. Despite its name, max_len is the SUM over all turns in the batch.
    max_len = sum([item.shape[1] for item in batch])
    lengths.append(max_len)



  0%|          | 0/145043 [00:00<?, ?it/s][A[A

  0%|          | 474/145043 [00:00<00:30, 4734.60it/s][A[A

  1%|          | 1047/145043 [00:00<00:28, 4994.89it/s][A[A

  1%|          | 1643/145043 [00:00<00:27, 5249.91it/s][A[A

  2%|▏         | 2421/145043 [00:00<00:24, 5816.76it/s][A[A

  2%|▏         | 3246/145043 [00:00<00:22, 6381.05it/s][A[A

  3%|▎         | 4110/145043 [00:00<00:20, 6922.89it/s][A[A

  3%|▎         | 4880/145043 [00:00<00:19, 7137.94it/s][A[A

  4%|▍         | 5729/145043 [00:00<00:18, 7495.30it/s][A[A

  5%|▍         | 6560/145043 [00:00<00:17, 7720.60it/s][A[A

  5%|▌         | 7371/145043 [00:01<00:17, 7833.37it/s][A[A

  6%|▌         | 8188/145043 [00:01<00:17, 7931.10it/s][A[A

  6%|▌         | 8982/145043 [00:01<00:17, 7848.93it/s][A[A

  7%|▋         | 9777/145043 [00:01<00:17, 7878.22it/s][A[A

  7%|▋         | 10651/145043 [00:01<00:16, 8117.35it/s][A[A

  8%|▊         | 11465/145043 [00:01<00:16, 8051.77it/s][A[A

  

 57%|█████▋    | 83003/145043 [00:13<00:11, 5541.20it/s][A[A

 58%|█████▊    | 83575/145043 [00:13<00:10, 5593.50it/s][A[A

 58%|█████▊    | 84155/145043 [00:13<00:10, 5653.17it/s][A[A

 58%|█████▊    | 84726/145043 [00:13<00:10, 5666.52it/s][A[A

 59%|█████▉    | 85307/145043 [00:13<00:10, 5707.46it/s][A[A

 59%|█████▉    | 85889/145043 [00:13<00:10, 5738.46it/s][A[A

 60%|█████▉    | 86464/145043 [00:13<00:10, 5704.09it/s][A[A

 60%|██████    | 87056/145043 [00:13<00:10, 5766.11it/s][A[A

 60%|██████    | 87633/145043 [00:13<00:10, 5704.23it/s][A[A

 61%|██████    | 88211/145043 [00:13<00:09, 5723.69it/s][A[A

 61%|██████    | 88784/145043 [00:14<00:09, 5651.32it/s][A[A

 62%|██████▏   | 89350/145043 [00:14<00:09, 5610.82it/s][A[A

 62%|██████▏   | 89916/145043 [00:14<00:09, 5625.43it/s][A[A

 62%|██████▏   | 90499/145043 [00:14<00:09, 5684.34it/s][A[A

 63%|██████▎   | 91068/145043 [00:14<00:09, 5656.87it/s][A[A

 63%|██████▎   | 91634/145043 [00:14<00:

In [105]:
# Largest total token count among the fragments recorded in `lengths`.
max(lengths)

1017