In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import json
import dotenv

from transformers import AutoTokenizer

from llm_ol.experiments.llm.templates import (
    MISTRAL_TEMPLATE,
    PROMPT_TEMPLATE,
    RESPONSE_TEMPLATE,
)

# Load variables from a local .env file into the process environment (python-dotenv).
# NOTE(review): presumably provides Hugging Face Hub credentials/cache settings for
# the tokenizer download below — TODO confirm which variables are actually needed.
dotenv.load_dotenv()

In [None]:
# Tokenizer for the Mistral-7B v0.2 base model. add_prefix_space=False is
# presumably set so that fragments encoded one at a time (see to_tokens below)
# do not each get an extra leading space — TODO confirm for this tokenizer class.
tokenizer = AutoTokenizer.from_pretrained(
    "alpindale/Mistral-7B-v0.2-hf",
    add_prefix_space=False,
)

In [None]:
# Toy training example: a fake paper (title + abstract) and the category
# paths it should be assigned to. Each inner list is one path of category
# names, rendered in order and joined with "->" further below.
title = "TITLE 123"
abstract = (
    "This is an abstract. It is a very good abstract. It is the best abstract."
)

first_path = ["Hello", "baked apples", "world!"]
second_path = ["Earth sciences", "Geology", "Geophysics"]
paths = [first_path, second_path]
del first_path, second_path


def to_tokens(text: str):
    """Encode ``text`` into a list of token ids using the notebook's ``tokenizer``.

    Special tokens are NOT added automatically; in this notebook BOS/EOS are
    passed in explicitly as their own text fragments instead.
    """
    encoded_ids = tokenizer.encode(text, add_special_tokens=False)
    return encoded_ids


# Render the example through the project templates: the user turn comes from
# PROMPT_TEMPLATE, the assistant turn from RESPONSE_TEMPLATE, and the whole
# conversation is wrapped by MISTRAL_TEMPLATE with the tokenizer's BOS/EOS.
prompt = PROMPT_TEMPLATE.render(title=title, abstract=abstract)
response = RESPONSE_TEMPLATE.render(paths=paths)

messages = [
    dict(role="user", content=prompt),
    dict(role="assistant", content=response),
]

full = MISTRAL_TEMPLATE.render(
    messages=messages,
    bos_token=tokenizer.bos_token,
    eos_token=tokenizer.eos_token,
)
print(full)

# Rebuild the same conversation as a list of (text, weight) fragments so each
# fragment can be tokenized separately and every token tagged with a weight.
# Weights used here: 0 for BOS, the user turn and the "->" separators;
# 1 for each path item; 2 for each path-terminating newline and for EOS.
#
# The user turn reuses the already-rendered `prompt` rather than re-typing the
# prompt layout by hand, so it cannot drift from PROMPT_TEMPLATE / `full`.
parts = [
    (tokenizer.bos_token, 0),
    (f"[INST] {prompt} [/INST]", 0),
]
for path in paths:
    for i, item in enumerate(path):
        # Separator goes *between* items: none before the first one.
        if i > 0:
            parts.append(("->", 0))
        parts.append((item, 1))
    parts.append(("\n", 2))
parts.append((tokenizer.eos_token, 2))
print(parts)

# Sanity check: concatenating the fragments should reproduce the template
# rendering printed above. False means the hand-built layout (separators,
# whitespace, newlines) disagrees with MISTRAL_TEMPLATE / RESPONSE_TEMPLATE.
print("parts match full render:", "".join(text for text, _ in parts) == full)

# Tokenize every fragment on its own (in order) and broadcast the fragment's
# weight to each of its tokens, giving two parallel lists of equal length.
encoded_parts = [(to_tokens(text), weight) for text, weight in parts]
tokens = [token_id for ids, _ in encoded_parts for token_id in ids]
weights = [weight for ids, weight in encoded_parts for _ in ids]

print(tokenizer.decode(tokens))
print(tokens)
print(weights)