In [1]:
import random
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler
import json
import tqdm

from torchfly.transformers import UnifiedTokenizer
from dialog_utils import DialogFragmentSampler

In [2]:
def process_dialog(dialog):
    """Encode every turn of a dialog with the module-level ``tokenizer``.

    Each turn is read by index: ``turn[0]`` and ``turn[1]`` are joined as
    ``turn[0] + ":" + turn[1]`` and terminated by three newlines before being
    passed to ``tokenizer.encode``.

    Args:
        dialog: iterable of turns, each indexable at positions 0 and 1 with
            string values.

    Returns:
        A list holding one ``tokenizer.encode`` result per turn, in the
        original turn order.
    """
    return [
        tokenizer.encode("".join((turn[0], ":", turn[1], "\n\n\n")))
        for turn in dialog
    ]

In [3]:
# Shared tokenizer; process_dialog looks this name up at call time.
tokenizer = UnifiedTokenizer()

# Load the raw corpus. The encoding is pinned to UTF-8 so that decoding the
# file does not depend on the platform's default locale encoding.
with open("../DialogCorpus/all_dialogs.json", encoding="utf-8") as f:
    all_dialogs = json.load(f)

In [4]:
# Maximum number of tokens allowed in a single turn; a dialog containing any
# longer turn is dropped as a whole.
MAX_TURN_TOKENS = 256

# Maps each kept dialog key to its raw text and its per-turn token ids.
new_all_dialogs = {}

for key, value in tqdm.tqdm(all_dialogs.items()):
    processed_dialog = process_dialog(value)
    lengths = [len(item) for item in processed_dialog]
    # Skip dialogs with no turns (max() of an empty list raises ValueError)
    # as well as dialogs that contain an over-long turn.
    if not lengths or max(lengths) > MAX_TURN_TOKENS:
        continue

    new_all_dialogs[key] = {"text": value, "token_ids": processed_dialog}

100%|██████████| 146255/146255 [03:00<00:00, 812.53it/s]


In [6]:
# Write the filtered corpus (raw text plus token ids) to disk as JSON.
with open("dialog_corpus.json", mode="w") as f:
    json.dump(obj=new_all_dialogs, fp=f)

In [7]:
# Read the filtered corpus back from its JSON file, rebinding new_all_dialogs
# to the parsed content.
with open("dialog_corpus.json", mode="r") as f:
    new_all_dialogs = json.load(fp=f)

In [8]:
class DialogFragmentSampler:
    """Randomly crop a dialog to a contiguous run of turns within a token budget.

    A dialog whose total token count fits the budget is returned untouched.
    Otherwise a fragment is cut that starts on an even turn index and ends
    (exclusively) on an odd index, so it always holds an odd number of turns.
    """

    def __init__(self, max_tokens=1024, max_turns=20):
        """Sample dialog fragments from a dialog

        Args:
            max_tokens: token budget. Internally ``max_tokens - 1`` is used:
                a whole dialog is kept when its total is at most that value,
                and a cropped fragment always totals strictly less than it.
            max_turns: the exclusive end index of a fragment is capped at
                ``start + max_turns - 1``, i.e. a cropped fragment holds at
                most ``max_turns - 1`` turns (but never fewer than one).
        """
        self.max_num_tokens = max_tokens - 1
        self.max_num_turns = max_turns

    def __call__(self, dialog):
        """dialog is a dict which has key "token_ids"

        Args:
            dialog: dict whose ``"token_ids"`` value is a list with one token
                sequence per turn.

        Returns:
            ``dialog`` itself (same object, all keys) when the whole dialog
            fits the budget; otherwise a new dict with the single key
            ``"token_ids"`` holding the selected slice of turns.

        Raises:
            ValueError: if the dialog exceeds the budget and no even-indexed
                candidate start turn is itself shorter than the budget, so no
                non-empty fragment can be produced.
        """
        lengths = np.array([len(item) for item in dialog['token_ids']])

        # if the entire dialog is smaller than the max len
        if lengths.sum() <= self.max_num_tokens:
            return dialog

        cumsum_len = lengths.cumsum()
        # reverse_cumsum_len[i] is the token count of the first (n - i) turns,
        # i.e. the prefix sums read backwards -- not the number of tokens from
        # turn i to the end.
        # NOTE(review): confirm this is the intended eligibility rule.
        reverse_cumsum_len = cumsum_len[::-1]

        # indices whose reversed prefix sum exceeds the budget are candidates
        start_turns = np.arange(len(reverse_cumsum_len)
                               )[reverse_cumsum_len > self.max_num_tokens]
        # Keep even indices only, and only turns that fit the budget on their
        # own: starting on an over-long turn can never yield a valid fragment.
        start_turns = [
            int(idx) for idx in start_turns
            if idx % 2 == 0 and lengths[idx] < self.max_num_tokens
        ]
        if not start_turns:
            raise ValueError(
                "cannot sample a fragment: no even-indexed start turn is "
                f"shorter than {self.max_num_tokens} tokens"
            )
        # randomly choose one
        random_start_turn = random.choice(start_turns)

        # After prepending 0, cumsum_len[j] is the number of tokens before
        # turn j, so new_cumsum_len[j] is the size of turns [start, j).
        cumsum_len = np.concatenate([[0], cumsum_len], axis=0)
        new_cumsum_len = cumsum_len - cumsum_len[random_start_turn]

        # find the maximum end turn (only odd turn) located after the start;
        # start + 1 is odd and within budget because of the filter above, so
        # it is a safe initial value.
        random_end_turn = random_start_turn + 1
        for i in reversed(range(random_start_turn + 1, len(new_cumsum_len))):
            if i % 2 == 1 and new_cumsum_len[i] < self.max_num_tokens:
                random_end_turn = i
                break

        # Apply the turn cap, but never let it shrink the fragment to nothing.
        random_end_turn = min(
            random_end_turn, random_start_turn + self.max_num_turns - 1
        )
        random_end_turn = max(random_end_turn, random_start_turn + 1)

        return {
            "token_ids": dialog['token_ids'][random_start_turn:random_end_turn]
        }

In [9]:
class DialogCorpusDataset(Dataset):
    """Dataset that yields one randomly sampled dialog fragment per index.

    Items are lists of per-turn token-id sequences; ``collate`` converts a
    single item into a list of ``(1, turn_length)`` long tensors.
    """

    def __init__(self, data, tokenizer):
        """
        Args:
            data: mapping whose values are dialog dicts containing a
                ``"token_ids"`` entry; the keys are discarded.
            tokenizer: object providing ``encode``; its ``max_len`` attribute
                is overwritten with 4096 (the caller's object is mutated).
        """
        # only interested in the values
        self.data = list(data.values())
        self.tokenizer = tokenizer
        self.tokenizer.max_len = 4096
        # token ids of the three-newline turn separator
        self.turn_ending = tokenizer.encode("\n\n\n")
        self.sampler = DialogFragmentSampler(max_tokens=800)

    def __len__(self):
        """Return the number of dialogs."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the ``"token_ids"`` of a fragment sampled from dialog ``index``.

        The sampler is invoked on every access, so repeated reads of the same
        index may return different fragments.
        """
        dialog = self.data[index]
        dialog_fragment = self.sampler(dialog)
        return dialog_fragment["token_ids"]

    def collate(self, batch):
        """Turn a batch holding one sample into a list of per-turn tensors.

        Args:
            batch: list of samples as produced by ``__getitem__``. Only the
                first sample is used; any further samples are ignored.

        Returns:
            A list with one ``torch.LongTensor`` of shape ``(1, len(turn))``
            per turn of the first sample.
        """
        # only one item in the batch
        turns = batch[0]
        # No random position offset is drawn here: it was never part of the
        # returned value, and drawing it failed for samples over 1024 tokens.
        return [torch.LongTensor([item]) for item in turns]

In [10]:
# Train on the whole filtered corpus. This binds a second name to the same
# dict object; no copy is made.
train_data = new_all_dialogs

In [11]:
# Tokenizer handed to the dataset; the dataset constructor sets its max_len.
tokenizer = UnifiedTokenizer()

train_dataset = DialogCorpusDataset(data=train_data, tokenizer=tokenizer)
# Draw dataset indices in random order.
train_sampler = RandomSampler(data_source=train_dataset)

# One dialog per step; the dataset's own collate builds the per-turn tensors.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=1,
    sampler=train_sampler,
    collate_fn=train_dataset.collate,
)

In [28]:
# Walk the loader and record, per batch, the token count summed over all of
# its turn tensors (dimension 1 of each tensor). The scan stops at the first
# batch totalling exactly 15 tokens, leaving `batch` bound to it --
# presumably so that it can be inspected afterwards.
lengths = []
for batch in tqdm.tqdm(train_dataloader):
    max_len = sum(item.shape[1] for item in batch)

    if max_len != 15:
        lengths.append(max_len)
    else:
        break


  0%|          | 0/142298 [00:00<?, ?it/s][A
  0%|          | 481/142298 [00:00<00:29, 4805.95it/s][A
  1%|          | 1059/142298 [00:00<00:27, 5060.24it/s][A
  1%|          | 1684/142298 [00:00<00:26, 5364.64it/s][A
  2%|▏         | 2278/142298 [00:00<00:25, 5524.35it/s][A
  2%|▏         | 2853/142298 [00:00<00:24, 5589.67it/s][A
  2%|▏         | 3454/142298 [00:00<00:24, 5709.29it/s][A
  3%|▎         | 4056/142298 [00:00<00:23, 5792.79it/s][A
  3%|▎         | 4658/142298 [00:00<00:23, 5856.43it/s][A
  4%|▎         | 5252/142298 [00:00<00:23, 5879.27it/s][A
  4%|▍         | 5859/142298 [00:01<00:23, 5931.43it/s][A
  5%|▍         | 6471/142298 [00:01<00:22, 5986.61it/s][A
  5%|▍         | 7064/142298 [00:01<00:22, 5968.72it/s][A
  5%|▌         | 7658/142298 [00:01<00:22, 5959.28it/s][A
  6%|▌         | 8250/142298 [00:01<00:22, 5905.89it/s][A
  6%|▌         | 8838/142298 [00:01<00:22, 5819.15it/s][A
  7%|▋         | 9435/142298 [00:01<00:22, 5861.25it/s][A
  7%|▋    

KeyboardInterrupt: 

In [36]:
# Decode the turn at index 6 of whatever `batch` was last bound by the scan
# loop; each element is a (1, length) tensor, so [0] selects its only row.
# NOTE(review): depends on kernel state -- fails if that batch has fewer than
# seven turns.
tokenizer.decode(batch[6][0].tolist())

'user:Pretty soon, but no earlier than Thursday September 8\n\n\n'