In [1]:
from string import Template
from sentencepiece import SentencePieceProcessor
from logging import getLogger
from typing import List
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


logger = getLogger()


class Tokenizer:
    def __init__(self, model_path: str):
        # reload tokenizer
        assert os.path.isfile(model_path), model_path
        self.sp_model = SentencePieceProcessor(model_file=model_path)
        logger.info(f"Reloaded SentencePiece model from {model_path}")

        # BOS / EOS token IDs
        self.n_words: int = self.sp_model.vocab_size()
        self.bos_id: int = self.sp_model.bos_id()
        self.eos_id: int = self.sp_model.eos_id()
        self.pad_id: int = self.sp_model.pad_id()
        logger.info(
            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
        )
        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
        assert type(s) is str
        t = self.sp_model.encode(s)
        if bos:
            t = [self.bos_id] + t
        if eos:
            t = t + [self.eos_id]
        return t

    def decode(self, t: List[int]) -> str:
        return self.sp_model.decode(t)

In [3]:
tokenizer = Tokenizer('./models/LLaMA_tokenizer.model')

In [77]:
import json

with open('dialog_dataset.json', 'r') as f:
    dataset = json.load(f)
    
new_dataset = []
sessions = ['session_1', 'session_2', 'session_3']
for data in dataset.values():
    for i, session in enumerate(sessions[1:]):
        for phrase_idx in range(1, len(data[session]['dialog']), 2):
            dialog = '\n'.join(data[session]['dialog'][:phrase_idx])
            context = '\n'.join([data[session]['context'] for session in sessions[:i+1]])
            answer = data[session]['dialog'][phrase_idx]
            facts = data[session]['facts']
            new_dataset.append({'dialog': dialog, 'context': context, 'answer': answer, 'facts': facts})

In [68]:
print(context)

* The bot_0 mentions watching the TV show "Game of Thrones" and enjoys walks on the beach.
* bot_1 likes crime books and coffee, specifically caramel cappuccino, and enjoys music, particularly pop music.
* bot_1 mentions that they moved away from their parents but still live in the same city, and that they have parties often.
1. The bot_0 enjoys going to the beach.
2. The bot_1 sells auto parts and enjoys hard rock and industrial music. They mention that they have never been outside of the US except for Canada, but are planning to go to Japan with their wife.
3. The bot_0 talks about going to Thailand to visit the beach and go to an elephant sanctuary for charity work, and mentions that they used to go to Disney a lot as a child.


In [24]:
dialog = '\n'.join(dataset['train:ordered_1234']['session_2']['dialog'][:1])
answer = dataset['train:ordered_1234']['session_2']['dialog'][1]
facts = '\n\n'.join([f'Facts about {person}:\n' + '\n'.join(fact) 
                     for person, fact in dataset['train:ordered_1234']['session_1']['facts'].items()])
context = dataset['train:ordered_1234']['session_1']['context']

In [25]:
tokenized_facts = tokenizer.encode(facts, bos=False, eos=False)
tokenized_context = tokenizer.encode(context, bos=False, eos=False)
tokenized_dialog = tokenizer.encode(dialog, bos=False, eos=False)
tokenized_answer = tokenizer.encode(answer, bos=False, eos=False)

In [36]:
input_text = torch.tensor([tokenizer.bos_id] + tokenized_facts + tokenized_context + tokenized_dialog + tokenized_answer)
target = torch.tensor([tokenizer.pad_id]*len(tokenized_facts) + [tokenizer.pad_id]*len(tokenized_context) + \
                      [tokenizer.pad_id]*len(tokenized_dialog) + tokenized_answer + [tokenizer.eos_id])

In [78]:
from torch.utils.data import Dataset, DataLoader

class ClassifierDataset(Dataset):

    def __init__(self, data):
        super().__init__()

        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):

        sample = self.data[index]

        dialog = 'Dialog:\n' + sample['dialog']
        context = 'Context:\n' + sample['context']
        answer = sample['answer']
        facts = '\n\n'.join([f'Facts about {person}:\n' + '\n'.join(facts) 
                             for person, facts in sample['facts'].items()])

        return context, facts, dialog, answer

In [123]:
class Collator:

    def __init__(self, tokenizer, max_length=2048):

        self.tokenizer = tokenizer

        self.max_length = max_length

    def __call__(self, batch):

        inputs = []
        targets = []

        max_len = 0
        for context, facts, dialog, answer in batch:
            tokenized_context = tokenizer.encode(context, bos=False, eos=False)
            tokenized_facts = tokenizer.encode(facts, bos=False, eos=False)
            tokenized_dialog = tokenizer.encode(dialog, bos=False, eos=False)
            tokenized_answer = tokenizer.encode(answer, bos=False, eos=False)
            
            input_tokens = [self.tokenizer.bos_id] + tokenized_context + tokenized_facts + tokenized_dialog + tokenized_answer
            target_tokens = [0]*(len(input_tokens) - 1 - len(tokenized_answer)) + tokenized_answer + [self.tokenizer.eos_id]
            
            assert len(input_tokens) == len(target_tokens)
            
            inputs.append(input_tokens)
            targets.append(target_tokens)
            
            if len(inputs[-1]) > max_len:
                max_len = len(input_tokens)

        inputs = [tokens + [0]*(max_len - len(tokens)) for tokens in inputs]
        targets = [tokens + [0]*(max_len - len(tokens)) for tokens in targets]

        return torch.tensor(inputs), torch.LongTensor(targets)

In [124]:
from random import sample

BATCH_SIZE = 8

train_dataset = ClassifierDataset(data=new_dataset[:int(len(new_dataset)*.8)])
valid_dataset = ClassifierDataset(data=new_dataset[int(len(new_dataset)*.8):])

collator = Collator(tokenizer=tokenizer)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, collate_fn=collator, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, collate_fn=collator, shuffle=False)

In [125]:
for x, y in train_loader:
    break

In [127]:
tokenizer.decode(y[0,].tolist())

' ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇ 

In [129]:
len(new_dataset)

677