In [None]:
# Execute the companion notebook inside this kernel so the names it defines become available here.
# NOTE(review): later cells use `all_messages`, `os`, `plt` and `matplotlib` without defining or
# importing them in this notebook — presumably this run provides them; confirm.
%run data_pipelines.ipynb 

In [None]:
from transformers import AutoTokenizer
from collections import defaultdict
from datasets import Dataset, DatasetDict
import tqdm
import numpy as np
import random

In [None]:
class Document:
    """A run of messages from one chat that can be rendered as a single training text.

    `messages` is a list of objects exposing `sender_name`, `receivers` (an iterable
    of names) and `content`; those are the only attributes this class reads.
    """

    # The person every two-party prompt is written around. Kept as a class attribute
    # so the name is spelled once instead of being repeated inside the prompt text.
    OWNER_NAME = "Tiến Dũng Nguyễn"

    def __init__(self, chat_name, messages):
        # Name of the chat the messages were taken from; quoted in group prompts.
        self.chat_name = chat_name
        # Messages in the order they should appear in the rendered text.
        self.messages = messages

    def to_text(self):
        """Render the document as system prompt + instruction + `###`-prefixed messages.

        Participants are every sender plus every receiver of the contained messages.
        With two or more people besides the owner the instruction is the group form,
        listing all participants in sorted order; sorting keeps the text identical
        from one kernel run to the next, which plain set iteration does not guarantee.
        Otherwise the instruction is the two-party form; if nobody other than the owner
        appears (for example no messages at all), the chat name stands in for the
        counterpart instead of raising KeyError.

        Returns:
            str: the rendered training text.
        """
        owner = self.OWNER_NAME
        result = "<<SYS>>Write a realistic text message chat. Avoid repetition.<</SYS>>\n"

        participants = set()
        for msg in self.messages:
            participants.add(msg.sender_name)
            participants.update(msg.receivers)

        # Everyone except the owner, in a stable order.
        others = sorted(participants - {owner})

        if len(others) > 1:
            names = ', '.join(sorted(participants))
            result += f"[INST]Write a chat in the group '{self.chat_name}' between {names}[/INST]\n"
        else:
            # Zero "others" means only the owner (or nobody) appears in these messages.
            single_participant = others[0] if others else self.chat_name
            result += f"[INST]Write a chat between {owner} and {single_participant}[/INST]\n"

        result += " ".join(f"### {message.sender_name}: {message.content}" for message in self.messages)
        return result

    def token_len(self):
        """Return the number of tokens in `to_text()`.

        Uses the module-level `tokenizer`, which must already be defined when this
        method is called.
        """
        return len(tokenizer.encode(self.to_text()))

In [None]:
# `os` is imported here because this notebook never imports it itself; without this
# the cell only works if an earlier run happened to leave `os` in the namespace.
import os

from huggingface_hub import login

# The token is read from the environment so it never appears in the notebook.
# os.getenv returns None when HF_TOKEN is unset, and that None is passed straight through.
# NOTE(review): how `login` behaves for token=None is not visible here — confirm.
login(token=os.getenv('HF_TOKEN'))

In [None]:
# Tokenizer used by Document.token_len to measure rendered documents.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")
# Reuse the end-of-sequence token as the padding token.
# NOTE(review): presumably the checkpoint ships without a pad token — confirm.
tokenizer.pad_token = tokenizer.eos_token
# Pad on the right-hand side of each sequence.
tokenizer.padding_side = "right"

In [None]:
# Token budget per document: clump_messages stops extending a document as soon as
# adding the next message would bring its rendered text to this many tokens or more.
MAX_LENGTH = 200  

def clump_messages(messages_by_chat=None, max_length=None):
    """Greedily pack each chat's consecutive messages into documents under a token budget.

    For every chat, messages are taken in order. A document starts with one message
    and keeps absorbing the next one for as long as the rendered text stays strictly
    below `max_length` tokens. A document always holds at least one message, so a
    single message that alone reaches the budget still becomes its own (oversized)
    document.

    Args:
        messages_by_chat: mapping of chat name -> list of messages. Defaults to the
            module-level `all_messages`.
        max_length: token budget per document. Defaults to the module-level
            `MAX_LENGTH`.

    Returns:
        list[Document]: the documents, chat by chat, in message order.
    """
    # Resolve defaults at call time so re-assigning the globals in a later cell is honoured.
    if messages_by_chat is None:
        messages_by_chat = all_messages
    if max_length is None:
        max_length = MAX_LENGTH

    documents = []
    for chat_name, message_list in tqdm.tqdm(messages_by_chat.items()):
        pointer = 0
        while pointer < len(message_list):
            size = 1
            current_document = Document(chat_name, [message_list[pointer]])

            while pointer + size < len(message_list):
                next_message = message_list[pointer + size]
                candidate = Document(chat_name, current_document.messages + [next_message])

                # The whole candidate is re-tokenised on purpose: token counts are not
                # additive across messages, so only the full text gives the true length.
                if candidate.token_len() >= max_length:
                    break

                current_document = candidate
                size += 1

            documents.append(current_document)
            # Resume with the first message that did not fit.
            pointer += size

    return documents

# Build the documents once; later cells reuse `documents`.
documents = clump_messages()
# len(documents) is the number of clumped documents, not of individual messages.
print(f"{len(documents):,} documents")

In [None]:
# Rendered text and token length of every document, both in `documents` order.
data = [doc.to_text() for doc in documents]
lengths = [doc.token_len() for doc in documents]
# Count the messages themselves; counting '###' in the rendered text would also count
# any '###' that occurs inside message content.
counts = sum(len(doc.messages) for doc in documents)

# max(..., 1) avoids ZeroDivisionError on an empty corpus. ':.2f' prints two decimals;
# a bare ':.2' means two significant digits and turns e.g. 12.3 into '1.2e+01'.
print(f'There are {counts:,} messages; average {counts/max(len(documents), 1):.2f} messages in each of {len(documents):,} documents')

In [None]:
# Histogram of document token lengths.
# NOTE(review): `plt` and `matplotlib` are not imported in this notebook — presumably
# they come from the notebook executed by %run; confirm.
fig, ax = plt.subplots(1, 1)
ax.set_xlabel('Number of tokens in a document')
ax.set_ylabel('Count of documents')
# Show y tick labels with thousands separators.
ax.get_yaxis().set_major_formatter(matplotlib.ticker.FuncFormatter(lambda y, p: format(int(y), ',')))
hist_bins = range(0, MAX_LENGTH+50, 10)
# Cap outliers at the last bin edge so oversized documents pile up in the final bar.
# Capping above the last edge (as MAX_LENGTH+100 did) puts them outside every bin,
# and values outside the bins are not drawn at all.
l2 = [min(hist_bins[-1], l) for l in lengths]
_ = ax.hist(l2, bins=hist_bins, color='darkorange', rwidth=0.5)

In [None]:
# Fixed seed so the shuffle order is repeatable.
random.seed(42)
# Rebuild `data` in `documents` order before shuffling: shuffling in place a list that
# an earlier execution of this cell already shuffled would give a different order,
# and therefore a different train/test split, every time the cell is re-run.
data = [doc.to_text() for doc in documents]
random.shuffle(data)

In [None]:
# Index of the first held-out item: the first 95% of `data` (rounded down) is the
# training portion, everything from `split` onward is the test portion.
split = int(0.95 * len(data))
train = data[:split]
test = data[split:]

In [None]:
# Wrap each portion as a single-column ('text') Dataset, train first and then test,
# and bundle both under their split names.
train_dataset, test_dataset = (
    Dataset.from_dict({'text': rows}) for rows in (train, test)
)
dataset = DatasetDict({
    'train': train_dataset,
    'test': test_dataset,
})

In [None]:
# Upload both splits to the Hub repository "simme", requesting a private repository.
# NOTE(review): no namespace is given — presumably it resolves to the logged-in account; confirm.
dataset.push_to_hub("simme", private=True)