In [None]:
# Enable IPython's autoreload extension; mode 2 re-imports modules before each
# cell runs, so edits to imported project code are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
import json
import dotenv
import random

from transformers import AutoTokenizer
from tokenizers.tools import EncodingVisualizer

from llm_ol.experiments.llm.templates import (
    MISTRAL_TEMPLATE,
    PROMPT_TEMPLATE,
    RESPONSE_TEMPLATE,
)

# Load variables from a `.env` file (if one is found) into the process
# environment before anything below reads them.
dotenv.load_dotenv()

In [None]:
# Checkpoint whose tokenizer is inspected throughout this notebook.
TOKENIZER_NAME = "alpindale/Mistral-7B-v0.2-hf"

# `add_prefix_space=False` is forwarded to the tokenizer unchanged.
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME, add_prefix_space=False)

In [None]:
# Read the training set: a JSON Lines file with one JSON document per line.
# UTF-8 is requested explicitly so decoding does not depend on the platform's
# locale default, and blank lines are skipped because `json.loads` rejects
# empty input.
with open("out/experiments/llm/v2/train_dataset.jsonl", encoding="utf-8") as f:
    items = [json.loads(line) for line in f if line.strip()]

In [None]:
# Draw one example at random and pull out the fields used to build the prompt
# and the response.
item = random.choice(items)
title = item["title"]
abstract = item["abstract"]
paths = item["paths"]


def to_tokens(text: str):
    """Encode `text` with the notebook-level `tokenizer` and return the result.

    `add_special_tokens=False` is passed so the tokenizer does not insert
    special tokens of its own around the text.
    """
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    return token_ids


# Render the user-side prompt and the assistant-side target for the sampled item.
prompt = PROMPT_TEMPLATE.render(title=title, abstract=abstract)
response = RESPONSE_TEMPLATE.render(paths=paths)

# One user turn followed by one assistant turn.
messages = [
    {"role": role, "content": content}
    for role, content in (("user", prompt), ("assistant", response))
]

# Wrap the conversation in the chat template, print it, then tokenize it.
full = MISTRAL_TEMPLATE.render(
    messages=messages,
    bos_token=tokenizer.bos_token,
    eos_token=tokenizer.eos_token,
)
print(full)
full_tokens = to_tokens(full)

# Hard-coded token ids used as markers further down.
# NOTE(review): presumably these are this tokenizer's ids for the closing
# instruction tag, the "->" separator and "\n" respectively - confirm with
# `tokenizer.encode(...)` before reusing them with another tokenizer.
inst_end = [733, 28748, 16289, 28793]
arrow = 3193
linebreak = 13


def find_index(list_, sublist):
    """Return the first position at which `sublist` occurs contiguously in `list_`.

    Each candidate window is a slice of `list_` compared to `sublist` with `==`.
    An empty `sublist` matches at index 0.

    Raises:
        ValueError: if no window of `list_` equals `sublist`.
    """
    window = len(sublist)
    matches = (
        pos
        for pos in range(len(list_) - window + 1)
        if list_[pos : pos + window] == sublist
    )
    first = next(matches, None)
    if first is None:
        raise ValueError(f"Sublist {sublist} not found in list")
    return first


# Index of the first response token: the position right after the first
# occurrence of the `inst_end` marker in `full_tokens`. `find_index` raises
# ValueError if the marker is absent.
resp_start_idx = find_index(full_tokens, inst_end) + len(inst_end)
# Disabled alternative: group the response tokens into a nested list, starting
# a new outer entry at each `linebreak` and a new inner entry at each `arrow`.
# resp_parts = [[[]]]
# for token in full_tokens[resp_start_idx:]:
#     if token == linebreak:
#         resp_parts.append([[]])
#     elif token == arrow:
#         resp_parts[-1].append([])
#     else:
#         resp_parts[-1][-1].append(token)

# Build one weight per token, aligned index-for-index with `full_tokens`:
#   0 - every token before `resp_start_idx`
#   1 - response tokens that are neither a separator nor EOS
#   2 - separator tokens (`linebreak` or `arrow`)
#   4 - the EOS token
weights = [0] * resp_start_idx
word = []  # response tokens seen since the last separator / EOS
for token in full_tokens[resp_start_idx:]:
    if token == linebreak or token == arrow:
        weights += [1] * len(word) + [2]
        word = []
    elif token == tokenizer.eos_token_id:
        # Flush the pending word first; otherwise its tokens get no weight and
        # the EOS weight lands on the wrong position.
        weights += [1] * len(word) + [4]
        word = []
    else:
        word.append(token)

# Tokens after the last separator / EOS still need their weights.
weights += [1] * len(word)

assert len(weights) == len(full_tokens), "weights must align with full_tokens"


# Print (token string, weight) pairs in rows of `tokens_per_line` so the
# alignment can be checked by eye. `zip` stops at the shorter of the two
# sequences in each row.
tokens_per_line = 20

for start in range(0, len(full_tokens), tokens_per_line):
    stop = start + tokens_per_line
    tokens = tokenizer.convert_ids_to_tokens(full_tokens[start:stop])
    weight = weights[start:stop]
    print(list(zip(tokens, weight)))

# for token, weight in zip(full_tokens, weights):
#     print(repr(tokenizer.decode(token)), weight)

# for path in resp_parts:
#     for words in path:
#         print(repr(tokenizer.decode(words)))
#     print()

# parts = [
#     (f"{tokenizer.bos_token}", 0),
#     (f"[INST] Title: {title}\n{abstract} [/INST]", 0),
# ]
# for path in paths:
#     for i, item in enumerate(path):
#         parts.append((item, 1))
#         if i < len(path) - 1:
#             parts.append(("->", 0))
#     parts.append(("\n", 2))
# parts.append((f"{tokenizer.eos_token}", 2))
# print(parts)

# tokens = []
# weights = []
# for part, w in parts:
#     part_tokens = to_tokens(part)
#     tokens += part_tokens
#     weights += [w] * len(part_tokens)
# print(tokenizer.decode(tokens))
# print(tokens)
# print(weights)

In [None]:
# Tokenize the rendered conversation into token strings and print them next to
# their ids, `tokens_per_line` pairs per row with a blank line between rows.
tokens_per_line = 20
tokens = tokenizer.tokenize(full, add_special_tokens=False)

row_starts = range(0, len(tokens), tokens_per_line)
for row_start in row_starts:
    row_tokens = tokens[row_start : row_start + tokens_per_line]
    row_ids = tokenizer.convert_tokens_to_ids(row_tokens)
    pairs = list(zip(row_tokens, row_ids))
    print(pairs)
    print()

In [None]:
# Probe: display what the tokenizer returns for a separator string, with no
# special tokens added. Swap in the commented line to probe "->" instead.
# tokenizer.encode("->", add_special_tokens=False)
tokenizer.encode("\n", add_special_tokens=False)