In [None]:
import pickle
import pandas as pd

In [None]:
import os
from tqdm import tqdm
import matplotlib.pyplot as plt

In [None]:
# Base directory that the data and tokenizer paths in this notebook are built from.
ROOT = "/gpfs/space/projects/stud_ml_22/NLP"
# Directory handed to AutoTokenizer.from_pretrained further down in the notebook.
PATH_TO_CONVERTED_TOKENIZER = os.path.join(ROOT, "llama/7B_converted/")

In [None]:
# NOTE(review): unpickling can execute arbitrary code, so only point this at a
# file that comes from a trusted source.
questions_path = os.path.join(ROOT, "data/course_questions.pkl")
with open(questions_path, 'rb') as pickle_file:
    data = pickle.load(pickle_file, encoding='utf8')
data

In [None]:
# Replace the current index with a fresh default one; the old index is kept as a column.
# NOTE(review): this cell is not idempotent -- running it again moves the new
# index into yet another column, so execute it only once per kernel session.
data = data.reset_index()

In [None]:
# Hold out a random 5% of the rows for validation; the fixed random_state keeps
# the split identical across runs.
val_data = data.sample(frac=0.05, random_state=42)
val_data.head()

In [None]:
# Build the training split in this cell so it also works on a fresh kernel
# (it used to read train_data before any earlier cell had defined it).
train_data = data.drop(val_data.index)
# Number of training rows divided by 8 -- presumably the batch size, i.e. the
# number of steps per epoch; TODO confirm.
len(train_data) / 8

In [None]:
# Training split: every row that was not sampled into val_data.
train_data = data.drop(val_data.index)

# Eyeball some question/answer pairs. The loop stops right after the first row
# whose index label is above 200 (that row is still printed).
for row_label, row in train_data.iterrows():
    print(row['question'])
    print(row['answer'])

    if row_label > 200:
        break

## Testing the dataset preparation code from the LLaMA fine-tuning tutorial

In [None]:
import torch
from torch.utils.data import IterableDataset
from transformers import AutoTokenizer

In [None]:
def chars_token_ratio(dataset, tokenizer, nb_examples=400):
    """
    Estimate the average number of characters per token in the dataset.

    Args:
        dataset: Table exposing ``iterrows()`` and ``len()`` (e.g. a pandas
            DataFrame). Each row is rendered with ``prepare_sample_text``.
        tokenizer: Object exposing ``is_fast``. When it is truthy the tokens
            are taken from ``tokenizer(text).tokens()``, otherwise from
            ``tokenizer.tokenize(text)``.
        nb_examples (int): Upper bound on how many leading rows are sampled.

    Returns:
        float: Total number of characters divided by total number of tokens
        over the sampled rows.

    Raises:
        ZeroDivisionError: If no tokens were counted, e.g. for an empty dataset.
    """
    # The loop ends at whichever is shorter, so size the progress bar to the
    # number of rows that will really be visited; otherwise it never completes
    # when the dataset has fewer than nb_examples rows.
    nb_rows = min(nb_examples, len(dataset))

    total_characters, total_tokens = 0, 0
    for _, example in tqdm(zip(range(nb_examples), dataset.iterrows()), total=nb_rows):
        # iterrows() yields (index_label, row); only the row is needed.
        text = prepare_sample_text(example[1])
        total_characters += len(text)
        if tokenizer.is_fast:
            total_tokens += len(tokenizer(text).tokens())
        else:
            total_tokens += len(tokenizer.tokenize(text))

    return total_characters / total_tokens


def prepare_sample_text(example):
    """Prepare the text from a sample of the dataset.

    Args:
        example: Mapping-like row with ``'question'`` and ``'answer'`` entries
            (e.g. a pandas Series produced by ``DataFrame.iterrows()``).

    Returns:
        str: The sample rendered as ``"Question: ...\\n\\nAnswer: ..."``.
    """
    # No debug print here: this runs once per row, so printing would flood the
    # notebook output and the progress bar.
    text = f"Question: {example['question']}\n\nAnswer: {example['answer']}"
    return text


class ConstantLengthDataset(IterableDataset):
    """
    Iterable dataset that returns constant length chunks of tokens built from rows of a table.

    Rows are read with ``dataset.iterrows()``, rendered with ``prepare_sample_text``, tokenized in
    batches, joined with a separator token appended after every sample, and cut into pieces of
    exactly ``seq_length`` tokens. A trailing piece shorter than ``seq_length`` is dropped.

        Args:
            tokenizer (Tokenizer): Callable tokenizer; ``tokenizer(list_of_str, truncation=False)``
                must return a mapping with an ``"input_ids"`` list of token-id lists.
            dataset (pandas.DataFrame): Table whose rows provide the fields read by
                ``prepare_sample_text``.
            infinite (bool): If True the iterator is reset after dataset reaches end else stops.
            seq_length (int): Length of token sequences to return.
            num_of_sequences (int): Number of token sequences to keep in buffer.
            chars_per_token (float): Number of characters per token used to estimate number of
                tokens in text buffer.
            eos_token_id (int, optional): Separator token id used only when
                ``tokenizer.eos_token_id`` is None.

        Raises:
            ValueError: If neither the tokenizer nor ``eos_token_id`` provides a separator id.
    """

    def __init__(
        self,
        tokenizer,
        dataset,
        infinite=False,
        seq_length=1024,
        num_of_sequences=1024,
        chars_per_token=3.6,
        eos_token_id=None,
    ):
        self.tokenizer = tokenizer
        # Compare against None explicitly: 0 is a legitimate token id and must
        # not be treated as "missing".
        if tokenizer.eos_token_id is not None:
            self.concat_token_id = tokenizer.eos_token_id
        elif eos_token_id is not None:
            self.concat_token_id = eos_token_id
        else:
            raise ValueError(
                "tokenizer.eos_token_id is None; pass eos_token_id explicitly"
            )
        self.dataset = dataset
        self.seq_length = seq_length
        self.infinite = infinite
        # Running count of chunks yielded so far (across all iterations).
        self.current_size = 0
        # Character budget of one text buffer before it is tokenized.
        self.max_buffer_size = seq_length * chars_per_token * num_of_sequences

    def __iter__(self):
        """Yield dicts with ``input_ids`` and ``labels`` LongTensors of length ``seq_length``."""
        iterator = self.dataset.iterrows()
        # Rows read since the iterator was last (re)started; used to detect an
        # empty dataset so infinite mode cannot spin forever.
        rows_since_restart = 0
        more_examples = True
        while more_examples:
            buffer, buffer_len = [], 0
            while buffer_len < self.max_buffer_size:
                try:
                    _, row = next(iterator)
                except StopIteration:
                    if self.infinite and rows_since_restart > 0:
                        # Restart over the rows (not over the column labels,
                        # which is what iter(DataFrame) would give).
                        iterator = self.dataset.iterrows()
                        rows_since_restart = 0
                        continue
                    more_examples = False
                    break
                rows_since_restart += 1
                buffer.append(prepare_sample_text(row))
                buffer_len += len(buffer[-1])

            if not buffer:
                # Nothing was buffered, so there is nothing left to tokenize.
                break

            tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
            all_token_ids = []
            for tokenized_input in tokenized_inputs:
                all_token_ids.extend(tokenized_input)
                all_token_ids.append(self.concat_token_id)
            for i in range(0, len(all_token_ids), self.seq_length):
                input_ids = all_token_ids[i : i + self.seq_length]
                # Only full-length chunks are emitted; the short tail is dropped.
                if len(input_ids) == self.seq_length:
                    self.current_size += 1
                    yield {
                        "input_ids": torch.LongTensor(input_ids),
                        "labels": torch.LongTensor(input_ids),
                    }


In [None]:
# Load the tokenizer from the converted-tokenizer directory configured near the top of the notebook.
tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

In [None]:
# Try the packing on a small slice (first 100 rows) in single-pass mode.
dataset = ConstantLengthDataset(tokenizer, data[:100], infinite=False)

In [None]:
# Walk the packed dataset once and show every chunk it yields.
for sample in dataset:
    print(sample)

In [None]:
# Measure characters per token on the same 100-row slice; the result can be
# compared with the chars_per_token default (3.6) of ConstantLengthDataset.
chars_token_ratio(data[:100], tokenizer)