In [1]:
from oasst_data import ExportMessageNode, read_dataset_message_trees, read_message_trees, visit_threads_depth_first


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
hf_dataset_name = "OpenAssistant/oasst1"
tree_iter = read_dataset_message_trees(hf_dataset_name, split="train+validation")

In [3]:
from torch.utils.data import Dataset, random_split
from torch import Generator

manual_seed=90
generator = Generator()
generator.manual_seed(manual_seed)

class ListDataset(Dataset):
    def __init__(self, data: list):
        super().__init__()
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]
    

lang_codes = ["en"]
val_split = 0.2
lang = "en",
top_k = None
        
def get_data(mode="sft"):
    threads_per_tree = []
    for tree in tree_iter:
        if tree.tree_state != "ready_for_export" or not tree.prompt.review_result or tree.prompt.lang not in lang_codes:
            continue

        if mode in ("sft", "rm"):
            if tree.tree_state != "ready_for_export":
                continue
        elif mode == "rl":
            if tree.tree_state not in ("ready_for_export", "prompt_lottery_waiting"):
                continue

        # extract all threads up to last assistant reply
        threads: list[list[ExportMessageNode]] = []
            
        def thread_filter(thread: list[ExportMessageNode]) -> bool:
            if any(m.deleted or m.synthetic for m in thread):
                return False

            if top_k is not None:
                for i, m in enumerate(thread):
                    if m.role == "assistant":
                        if m.rank is None:
                            if i > 0 and len(thread[i - 1].replies) > 1:
                                return False
                        elif m.rank >= top_k:
                            return False
            return True


        def leaf_filter(thread: list[ExportMessageNode]) -> bool:
                if mode == "sft":
                    # in SFT mode `not thread[-1].replies` finds nodes without children (leaves).
                    # We are interested in those which are role='assistant' but some trees don't end on assistant nodes
                    # but have prompter leaves .. we want to use those trees too .. e.g. remove the last prompter message(s)
                    # so that they end with assistant. The `thread[-2].replies[0] == thread[-1]` check makes sure that only
                    # the FIRST prompter reply is added .. e.g. the parent does not appear multiple times and we can use
                    # pop() to remove superfluous prompter leaf node later.
                    return (
                        len(thread) > 1
                        and not thread[-1].replies
                        and (thread[-1].role == "assistant" or thread[-2].replies[0] == thread[-1])
                        and thread_filter(thread)
                    )
                elif mode == "rm":
                    # for reward models we use thread-fragments ending on prompter messages as prefix and
                    # their (ranked) replies as possible continuations.
                    if thread[-1].replies is None:
                        return False
                    return (
                        thread[-1].role == "prompter"
                        and len([r for r in thread[-1].replies if r.rank is not None]) > 1
                        and thread_filter(thread)
                    )
                elif mode == "rl":
                    # during rl we are interested in all possible prefixes ending in prompter messages
                    return thread[-1].role == "prompter" and not any(m.deleted or m.synthetic for m in thread)

                raise RuntimeError()

        def process_thread(thread: list[ExportMessageNode]):
                if mode == "sft":
                    # ensure roles are strictly alternating between prompter and assistant
                    assert all(m.role == "prompter" for m in thread[0::2]) and all(m.role == "assistant" for m in thread[1::2])
                    conversation: list[[]] = [[m.text,m.role,m.get_label_value("quality"),m.get_label_value("humor"),
                                             m.get_label_value("creativity")]
                                            for m in thread
                                           ]
                    return conversation
                elif mode == "rm":
                    prefix = [m.text for m in thread]
                    replies = [r for r in thread[-1].replies if r.role == "assistant" and r.rank is not None]
                    replies = sorted(replies, key=lambda r: r.rank)
                    replies = [r.text for r in replies]
                    return (prefix, replies)
                elif mode == "rl":
                    return ([m.text for m in thread],)

                raise RuntimeError()

        visit_threads_depth_first(tree.prompt, visitor=threads.append, predicate=leaf_filter)
        if mode == "sft":
            for t in threads:
                if t[-1].role == "prompter":
                    t.pop()
        threads_per_tree.append(threads)
        # split on tree basis, messages from same tree must not end up in different splits
        trees = ListDataset(threads_per_tree)
        splits = random_split(trees, lengths=[1.0 - val_split, val_split], generator=generator)

        def flatten(ds: ListDataset) -> ListDataset:
            return ListDataset([process_thread(thread) for tree_threads in ds for thread in tree_threads])

        train = flatten(splits[0])
        val = flatten(splits[1])
        return train,val,threads_per_tree


In [4]:
train_sft,_,tsft = get_data(mode="sft")
train_rm,_,trm = get_data(mode="rm")



In [5]:
len(tsft)

1

In [None]:
for t in threads_per_tree[0]:
    print([[m.text[:100],m.role,m.get_label_value("quality"),
      m.get_label_value("humor"),
      m.get_label_value("creativity")] for m in t])
    print('======'*20)

In [None]:
from datasets import load_dataset
hf_dataset_name = "OpenAssistant/oasst1"
dataset = load_dataset(hf_dataset_name,split="train+validation")

In [None]:
len(dataset)
def convert_hf_message(row: dict) -> None:
    emojis = row.get("emojis")
    if emojis:
        row["emojis"] = dict(zip(emojis["name"], emojis["count"]))
    labels = row.get("labels")
    if labels:
        row["labels"] = {
            name: {"value": value, "count": count}
            for name, value, count in zip(labels["name"], labels["value"], labels["count"])
        }

In [None]:
tree_dict: dict = None
parents: list = None
trees = []
for row in dataset:
    convert_hf_message(row)
    if row["parent_id"] is None:
        if tree_dict:
            trees.append([row['role'],row['text']])

        tree_dict = {
            "message_tree_id": row["message_id"],
            "tree_state": row["tree_state"],
            "prompt": row,
        }
        parents = []
    else:
        while parents[-1]["message_id"] != row["parent_id"]:
            parents.pop()
        parent = parents[-1]
        if "replies" not in parent:
            parent["replies"] = []
        parent["replies"].append(row)

    row.pop("message_tree_id", None)
    row.pop("tree_state", None)
    parents.append(row)
    if tree_dict:
        trees.append([row['role'],row['text']])

In [None]:
len(tree_dict)

In [None]:
trees[1]

In [None]:
import yaml
from pathlib import Path

c = {}

for config_file in Path('.').glob("**/*.yaml"):
        no_conf = False
        with config_file.open("r") as f:
            c.update(yaml.safe_load(f)["defaults"])

In [None]:
c["datasets"]

In [None]:
for k in c["datasets"]:
    if type(k)==dict:
        k = list(k.keys())[0]
    if k in SUMMARIZATION_DATASETS:
        print(k)

In [None]:
for k in c["datasets"]:
    if type(k)==dict:
        k = list(k.keys())[0]
    if k in QA_DATASETS:
        print(k)

In [None]:
for k in c["datasets"]:
    if type(k)==dict:
        k = list(k.keys())[0]
    if k in INSTRUCTION_DATASETS:
        print(k)
        

In [None]:
SUMMARIZATION_DATASETS = [
    "xsum",
    "cnn_dailymail",
    "samsum",
    "multi_news",
    "scitldr",
    "billsum",
    "debate_sum",
    "tldr_news",
]

In [None]:
DATASET_FORMAT_MAPPING = {
    "squad_v2": {"index_fn": 'index_squad_v2'},
    "ua_squad": {
        "index_fn": 'index_uasquad',
        "name": "FIdo-AI/ua-squad",
        "params": {"field": "data"},
        "no_val": True,
    },
    "trivia_qa_nocontext": {
        "index_fn": 'index_trivia_qa_nocontext',
        "name": "trivia_qa",
        "params": {"name": "rc.nocontext"},
    },
    "trivia_qa_context": {"index_fn": 'index_trivia_qa_context', "name": "trivia_qa", "params": {"name": "rc"}},
    "adversarial_qa": {
        "index_fn": 'index_adversarial_qa',
        "params": {"name": "adversarialQA"},
    },
    "gsm8k_hard": {"index_fn": 'index_gsm_hard', "name": "reasoning-machines/gsm-hard", "no_val": True},
    "gsm8k": {"index_fn": 'index_gsm8k', "params": {"name": "main"}, "validation": "test"},
    "wikihow": {"name": "b-mc2/wikihow_lists", "index_fn": 'index_wikihow', "no_val": True},
    "essay_instruction": {
        "name": "ChristophSchuhmann/essays-with-instructions",
        "index_fn": 'index_essay_instruction',
        "no_val": True,
    },
    "math_qa": {
        "index_fn": 'index_math_qa',
    },
    "reddit_eli5": {"name": "eli5", "index_fn": 'index_eli5', "split_postfix": "_eli5"},
    "reddit_askh": {"name": "eli5", "index_fn": 'index_eli5', "split_postfix": "_askh"},
    "reddit_asks": {"name": "eli5", "index_fn": 'index_eli5', "split_postfix": "_asks"},
}

INSTRUCTION_DATASETS = {
    # Note humaneval_mbpp_codegen_qa returns a code string that we would want to at least wrap in ``` marks`
    "humaneval_mbpp_codegen_qa": "OllieStanley/humaneval-mbpp-codegen-qa",
    # Write unit tests to do task X
    "humaneval_mbpp_testgen_qa": "OllieStanley/humaneval-mbpp-testgen-qa",
    "grade_school_math_instructions": "qwedsacf/grade-school-math-instructions",
    "recipes": "dctanner/oa_recipes",
    "ubuntu_dialogue_qa": "sedthh/ubuntu_dialogue_qa",
    "cmu_wiki_qa": "sedthh/cmu_wiki_qa",
    "youtube_subs_howto100m": "totuta/youtube_subs_howto100M",
    "iapp_wiki_qa_squad": "wannaphong/iapp_wiki_qa_squad_oa",
    "zhihu-kol": "wangrui6/zhihu-kol",
    "minimath": "kentsui/minimath",
    "oa_wiki_qa_bart_10000row": "michaelthwan/oa_wiki_qa_bart_10000row",
    "oa_leet10k": "ehartford/oa_leet10k",
    "poem_instructions": "checkai/instruction-poems",
    "oa_stackexchange": "donfu/oa-stackexchange",
    "tell_a_joke": "mikegarts/oa_tell_a_joke_20000",
    "wizardlm_70k": "ehartford/WizardLM_alpaca_evol_instruct_70k_unfiltered",
    "megacode": "rombodawg/MegaCodeTraining112k",
    "evol_instruct_code": "nickrosh/Evol-Instruct-Code-80k-v1",
}


QA_DATASETS = list(DATASET_FORMAT_MAPPING.keys())


In [6]:
from training_datasets.dataset_utils import load_sft_dataset
from model.training_utils import get_sft_tokenizer
from config import SFT_TRAINING_CONFIG,TOKENIZER_CONFIG

In [7]:
tokenizer, eos_token= get_sft_tokenizer(SFT_TRAINING_CONFIG,TOKENIZER_CONFIG)
train_ds , eval_ds = load_sft_dataset(eos_token)


Downloading (…)lve/main/config.json: 100%|██████████| 507/507 [00:00<00:00, 4.98MB/s]
Downloading (…)okenizer_config.json: 100%|██████████| 593/593 [00:00<00:00, 4.33MB/s]
Downloading tokenizer.model: 100%|██████████| 534k/534k [00:01<00:00, 337kB/s]
Downloading (…)cial_tokens_map.json: 100%|██████████| 330/330 [00:00<00:00, 1.10MB/s]


[]


Downloading data: 100%|██████████| 673M/673M [02:24<00:00, 4.67MB/s]
Downloading data files: 100%|██████████| 1/1 [02:24<00:00, 144.11s/it]
Extracting data files: 100%|██████████| 1/1 [00:00<00:00, 192.60it/s]
Generating train split: 94145 examples [00:02, 31616.62 examples/s]


Size of vicuna training data: 47038
Size of vicuna training data: 11760


Downloading data: 100%|██████████| 7.87M/7.87M [00:01<00:00, 4.20MB/s]
Downloading data files: 100%|██████████| 1/1 [00:01<00:00,  1.91s/it]
Extracting data files: 100%|██████████| 1/1 [00:00<00:00, 320.81it/s]
Generating train split: 100%|██████████| 15015/15015 [00:00<00:00, 613331.33 examples/s]


Size of dolly training data: 12000
Size of dolly training data: 3001


Downloading data: 100%|██████████| 8.06M/8.06M [00:02<00:00, 3.95MB/s]
Downloading data files: 100%|██████████| 1/1 [00:02<00:00,  2.06s/it]
Extracting data files: 100%|██████████| 1/1 [00:00<00:00, 463.82it/s]
Generating train split: 20022 examples [00:00, 360937.97 examples/s]


Size of alpaca training data: 16016
Size of alpaca training data: 4005


Downloading data: 100%|██████████| 2.55M/2.55M [00:00<00:00, 2.60MB/s]
Downloading data files: 100%|██████████| 1/1 [00:01<00:00,  1.01s/it]
Extracting data files: 100%|██████████| 1/1 [00:00<00:00, 285.46it/s]
Generating train split: 100%|██████████| 8792/8792 [00:00<00:00, 612034.80 examples/s]


    features: ['INSTRUCTION', 'RESPONSE', 'SOURCE'],
    num_rows: 8792
}) were invalid.
Size of math_instruction training data: 7032
Size of math_instruction training data: 1759


In [8]:
len(train_ds)

82086

In [10]:
tt = train_ds[100]
print(f'len {len(tt)}')
print('\n'.join([t for t in tt]))

len 2
<|prompter|>How to code based on a Figma design.</s>
<|assistant|>Figma is a popular design tool that allows designers to create and share user interface (UI) designs for websites and applications. As a developer, you can use these designs as a blueprint to build the UI for your project. Here's a step-by-step guide on how to code based on a Figma design:

1. Access the Figma design: Obtain the Figma design file or URL from the designer. You may need to sign up for a free account if you don't have one already.
2. Familiarize yourself with the design: Review the design and understand the layout, components, and color schemes used. Pay close attention to typography, spacing, and any custom components.
3. Export assets: Figma allows you to export images, icons, and other assets in various formats (e.g., PNG, SVG, JPEG). Export any necessary assets, and optimize them for the web to ensure fast loading times.
4. Choose a development environment: Set up a development environment suitabl

In [18]:
import torch
isinstance(t, torch.utils.data.IterableDataset)

False

In [9]:
from config import SFT_TRAINING_CONFIG, DIALOGUE_COLLATOR_CONFIG,SFT_DATASET_CONFIG

dict(**DIALOGUE_COLLATOR_CONFIG,**SFT_DATASET_CONFIG,**SFT_TRAINING_CONFIG)

{'max_length': 1024,
 'random_offset_probability': None,
 'label_masking': True,
 'samples_mixing': True,
 'use_system_prefix': False,
 'system_prefix': None,
 'vicuna': {'class': functools.partial(<class 'training_datasets.sft_dataset.Vicuna'>, input_max_length=1024),
  'val_split': 0.2},
 'dolly': {'class': training_datasets.sft_dataset.DatabrickDolly15k,
  'val_split': 0.2},
 'alpaca': {'class': training_datasets.sft_dataset.AlpacaBaseDataset,
  'val_split': 0.2},
 'math_instruction': {'class': training_datasets.sft_dataset.MathInstruction,
  'val_split': 0.2},
 'cache_dir': './cache',
 'model_name': '',
 'train_batch': 1,
 'eval_batch': 1,
 'lr': 1e-05,
 'num_train_epochs': 3,
 'gradient_accumulation_steps': 1,
 'eval_accumulation_steps': 1,
 'log_steps': 500,
 'eval_steps': 1000,
 'save_steps': 5000,
 'warmup_steps': 20,
 'weight_decay': 0.0,
 'dtype': 'fp16',
 'gradient_checkpointing': True,
 'adam_beta1': '',
 'adam_beta2': '',
 'adam_epsilon': '',
 'resume_from_checkpoint': Non