In [13]:
import argparse

from utils import (print_yaml_config,init_or_resume_from, read_yaml)
from model_training.training_utils import load_for_inference
from constants import TOKENIZER_SEPECIAL_TOKENS
from training_datasets.dataset_utils import load_sft_dataset, load_rm_dataset
from training_datasets.collators import DialogueDataCollator




def main(conf):
    print(f"\n{'==='*10} Following are the configuration for training{'==='*10}")
    print_yaml_config(conf)
    
    device_map = {"":0}
    assert "llama" in conf.model_name.lower(), "Currently only llama model supported"
    special_tokens = TOKENIZER_SEPECIAL_TOKENS["llama"]
    model, tokenizer = load_for_inference(device_map,conf,special_tokens,conf.model_name,False)
    _ , eval_ds = load_sft_dataset(conf,special_tokens["eos_token"])

    eval_collate_fn = DialogueDataCollator(
        tokenizer,
        max_length=conf.collator["val_max_length"],
        random_offset_probability=conf.collator["random_offset_probability"],
        label_masking=conf.collator["label_masking"],
        samples_mixing=False,
        use_system_prefix=conf.collator["use_system_prefix"],
        system_prefix=conf.collator["system_prefix"],
        )
    return model, eval_ds, eval_collate_fn, tokenizer

config = {}
conf = read_yaml('./config.yaml')
config.update(conf["default"])
config.update(conf["sft_eval"])
config["name_suffix"] = ""
config["debug"] = False
config["subset"] = "sft_eval"

# Create a Namespace object for config
config_ns = argparse.Namespace(**config)


init_or_resume_from(config_ns)
# config_ns.model_name = config_ns.base_model_name
model, eval_ds, collator, tokenizer = main(config_ns)


!!python/object:argparse.Namespace
adam_beta1: 0.9
adam_beta2: 0.95
adam_epsilon: 1e-12
adapter_number: final_checkpoint
adpater_name: null
checkpoint_name: null
checkpoint_number: final_checkpoint
collator:
  label_masking: true
  random_offset_probability: 0.5
  samples_mixing: true
  system_prefix: null
  use_system_prefix: false
  val_max_length: 2048
dataset:
  alpaca:
    max_val_set: 200
    val_split: 0.05
  dolly:
    max_val_set: 300
    val_split: 0.05
  math_instruction:
    max_val_set: 200
    val_split: 0.05
  oasst_export:
    lang: en,bg,ca,cs,da,de,en,es,fr,hr,hu,it,nl,pl,pt,ro,ru,sl,sr,sv,uk
    max_val_set: null
    val_split: 0.05
  vicuna:
    max_val_set: 800
    val_split: 0.05
debug: false
debug_set: 100
dtype: bf16
eval_accumulation_steps: null
eval_batch: 4
eval_steps: 100
gradient_accumulation_steps: 16
gradient_checkpointing: true
hpt_data_frac: null
init_from_adapter: output/LLama-2-7b_pre_sft_4bit_lr_1e5_bs_64_adam_hf/final_checkpoint
int4_training: true

Loading checkpoint shards: 100%|██████████| 2/2 [00:08<00:00,  4.20s/it]


tokenizer size 32003
Resizing embeddings to 32016
===loading the vicuna dataset===

Size of vicuna training data: 55858
Size of vicuna validation data: 800
===loading the dolly dataset===

Size of dolly training data: 14250
Size of dolly validation data: 300
===loading the alpaca dataset===

Size of alpaca training data: 19019
Size of alpaca validation data: 200
===loading the math_instruction dataset===

    features: ['INSTRUCTION', 'RESPONSE', 'SOURCE'],
    num_rows: 8792
}) were invalid.
Size of math_instruction training data: 8351
Size of math_instruction validation data: 200
===loading the oasst_export dataset===

OASST HF dataset: len(train)=38192, len(eval)=1992


In [77]:
# data = eval_ds["vicuna"][0]
# tokens = collator([data])
input = """<|prompter|>When was google created? Why did it dominate over other search engines?<s><|assistant|>"""
tokens = tokenizer(input,return_tensors='pt').to("cuda:0")

print(len(tokens.input_ids[0]))
# label_masks = tokens.pop("label_masks")
# targets = tokens.pop("targets")

output = model.generate(**tokens,
                        max_new_tokens=400,
                        do_sample=True,
                        top_p=0.9,
                        top_k=0,
                        repetition_penalty=1.2,
                        temperature=0.8)
# import torch

# with torch.no_grad():
#         outputs = model(
#             input_ids=tokens["input_ids"],
#             use_cache=False,
#         )

# logits = outputs.get("logits")
# pred_ids = torch.argmax(logits, dim=-1)
# len(pred_ids[0])

19


In [78]:
print("Output:\n" + 100 * '-')
print(tokenizer.decode(output[0], skip_special_tokens=False))



Output:
----------------------------------------------------------------------------------------------------
<s><|prompter|> When was google created? Why did it dominate over other search engines?<s><|assistant|> Google was founded by Larry Page and Sergey Brin in 1998.
Google's success can be attributed to several factors, including its ability to quickly index the vast amount of information available online, its simple user interface that allows users to easily navigate through the search results, and its emphasis on providing high-quality search results based on a complex algorithm known as PageRank . Additionally, Google has also developed various products and services such as Gmail, YouTube, Android operating system for mobile devices, etc., which have further increased its popularity among internet users</s>


In [33]:
tokenizer.decode(tokens.input_ids[0])

"the game scores the difference between their score and the total of all the other players' scores.\n\nI hope this helps you create a fun and exciting Rummy-based card game! Let me know if you have any further questions.</s><|prompter|> I am going to add a mechanic to the game where players can choose a new rule to add to the game after every round. These new rules are written on cards called Rule Cards. The players can look at the Rule Cards and choose one. The rule written on the card will then be in effect for the rest of the game. If I give you the examples of Rule Cards that I have already thought of, can you help me think of more of them?</s><|assistant|> Certainly! I'd be happy to help you come up with more Rule Cards. What are the examples of Rule Cards that you have already thought of? Knowing those can help me come up with additional ideas that fit with the theme and mechanics of your game.</s><|prompter|> 1. Wild Cards - different cards can be wild\n2. Change the number of c

In [74]:
print("".join(eval_ds["alpaca"][6]))
# 

<|prompter|>Write a function to create a retail price tag given a cost, taxes and discounts.
cost = 5
tax = 0.08
discount = 0.2</s><|assistant|>def price_tag(cost, tax, discount):
  tax_amount = cost * tax 
  discounted_price = cost + tax_amount - (discount * cost)
  return "Retail Price: $" + str(round(discounted_price, 2))

price_tag(cost, tax, discount) # returns "Retail Price: $7.20"</s>


In [64]:
# model.generate
