In [1]:
from llama3.utils import get_llama3, get_llama3_tokenizer
from llama3.config import MODELS_DIR, DATA_DIR

In [2]:
import os 
# Path of the adapter checkpoint that is later handed to get_llama3 and
# get_llama3_tokenizer via their `adapter_path` argument.
adapter_path = os.path.join(MODELS_DIR, "Meta-Llama-3-8B-Instruct_ft_pop_kar_1_ep_1split")

In [3]:
# Root folder of the on-disk datasets; the "rice" and "wheat" subfolders are
# loaded from it in a later cell.
data_path = os.path.join(DATA_DIR, "raft_qn")

In [4]:
from datasets import load_from_disk, concatenate_datasets

In [5]:
# Load the per-crop datasets saved under data_path (rice first, then wheat)
# and stack them into a single dataset, in that order.
ds_rice, ds_wheat = (
    load_from_disk(os.path.join(data_path, crop_dir))
    for crop_dir in ("rice", "wheat")
)
ds = concatenate_datasets([ds_rice, ds_wheat])  # type: ignore

In [7]:
def inst_format(example):
    """Turn one raw example into a prompt/completion pair.

    Args:
        example: Mapping that provides the ``instruction`` and ``cot_answer``
            fields (both are concatenated/used as strings).

    Returns:
        A new dict with exactly two keys: ``prompt`` (a fixed header followed
        by the instruction text) and ``completion`` (the ``cot_answer`` value).
    """
    # NOTE(review): the header contains the typo "qustion". It is kept
    # byte-for-byte because it is part of the text the model is trained on.
    # NOTE(review): a later cell defines another function with this same name.
    instruction_text = example["instruction"]
    target_text = example["cot_answer"]
    prompt_text = "Answer the qustion based on provided context. Context:\n" + instruction_text
    return {"prompt": prompt_text, "completion": target_text}
# Apply inst_format to every row of `ds` to produce the 'prompt' and
# 'completion' fields used for training.
dataset = ds.map(inst_format)

Dataset({
    features: ['prompt', 'completion'],
    num_rows: 159
})

In [11]:
# Drop the raw source columns from the training dataset; `dataset` is what the
# trainer cell consumes.
dataset = dataset.remove_columns(['id', 'type', 'question', 'context', 'oracle_context', 'cot_answer', 'instruction'])

In [8]:
import dotenv
from datasets import load_from_disk, concatenate_datasets
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import transformers
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model, PeftModel     # type: ignore
from datasets import load_from_disk, concatenate_datasets
import numpy as np
from transformers.training_args import TrainingArguments
from llama3.utils import get_llama3
from trl import SFTTrainer

In [9]:
# Base-model name; only used to derive the output directory. It gets its own
# name so that `model` always refers to the model object and never to a string.
model_name = "Meta-Llama-3-8B-Instruct"
# NOTE(review): this is a relative path, whereas adapter_path is built from
# MODELS_DIR - confirm the two are meant to live in different places.
save_dir = f"./models/{model_name}_ft_mhop"

# Attention projection layers that receive LoRA weights.
target_modules = ['k_proj', 'q_proj', 'v_proj', 'o_proj']

# get_llama3 is a project helper; it is called with the existing adapter path
# and returns the (model, tokenizer) pair.
# NOTE(review): a new LoRA adapter is attached below on top of whatever
# get_llama3 returns. Whether the first adapter is merged or kept as a separate
# PEFT layer is not visible from this notebook - confirm.
model, tokenizer = get_llama3(adapter_path=adapter_path)
model.gradient_checkpointing_enable()
# PEFT helper for preparing a k-bit (quantized) model for training; assumes
# get_llama3 loads the base model quantized - TODO confirm.
model = prepare_model_for_kbit_training(model)

config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=target_modules,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [12]:
# The data collator needs a padding token to batch sequences; EOS is reused.
tokenizer.pad_token = tokenizer.eos_token

# Only the options below are set explicitly; everything else (epochs, learning
# rate, optimizer, logging) is left at the TrainingArguments defaults.
training_args = TrainingArguments(
    auto_find_batch_size=True,
    warmup_ratio=0.1,
    bf16=True,
    output_dir=save_dir + "_outputs",
)

trainer = SFTTrainer(
    model=model,
    # Pass the notebook's tokenizer explicitly. Without it SFTTrainer loads a
    # tokenizer of its own, so the dataset would be tokenized with a different
    # tokenizer instance than the one the collator below uses.
    tokenizer=tokenizer,
    train_dataset=dataset,  # type: ignore
    args=training_args,
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()  # type: ignore


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Map:   0%|          | 0/159 [00:00<?, ? examples/s]

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step,Training Loss


TrainOutput(global_step=60, training_loss=2.213941446940104, metrics={'train_runtime': 584.3203, 'train_samples_per_second': 0.816, 'train_steps_per_second': 0.103, 'total_flos': 2.167924542554112e+16, 'train_loss': 2.213941446940104, 'epoch': 3.0})

In [13]:
# Persist the trained model to save_dir (defined in the model-setup cell).
trainer.save_model(save_dir)


Cannot access gated repo for url https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/resolve/main/config.json.
Access to model meta-llama/Meta-Llama-3-8B-Instruct is restricted. You must be authenticated to access it. - silently ignoring the lookup for the file config.json in meta-llama/Meta-Llama-3-8B-Instruct.


In [60]:
# NOTE(review): this cell carries execution count 60 but sits above the cells
# numbered 42-50. Run top to bottom, it strips 'context', 'oracle_context',
# 'question' and 'cot_answer' from `ds` before the cells below read them -
# confirm the intended position of this cell.
ds = ds.remove_columns(['id', 'type', 'question', 'context', 'oracle_context', 'cot_answer', 'instruction'])

In [42]:
def remove_oracle_context(example):
    """Attach the list of non-oracle sentences to an example.

    Reads ``example['oracle_context']`` and the first entry of
    ``example['context']['sentences']``, keeps every sentence that is not equal
    to the oracle context, and stores the result under ``example['non_oracle']``.

    Args:
        example: Mapping with ``oracle_context`` and a ``context`` dict whose
            ``sentences`` value is a list whose first item is iterable.

    Returns:
        The same ``example`` object, modified in place.
    """
    oracle = example['oracle_context']
    distractors = []
    for candidate in example['context']['sentences'][0]:
        if candidate != oracle:
            distractors.append(candidate)
    example['non_oracle'] = distractors
    return example
# Add the 'non_oracle' field to every example of `ds`.
ds = ds.map(remove_oracle_context)


In [43]:
# Drop the raw columns that the remaining cells do not read. 'question',
# 'oracle_context' and 'cot_answer' are deliberately kept: the prompt-building
# cell below still reads them, and the dataset summary shown after this cell
# lists them. The previous list also named a misspelt 'cot_anwser' column,
# which makes remove_columns raise because no such column exists.
# Filtering on column_names keeps the cell safe to re-run.
_unused_columns = ['type', 'context', 'instruction', 'inst']
ds = ds.remove_columns([name for name in _unused_columns if name in ds.column_names])

In [44]:
# def extract_ans(example):
#     cot_ans = example['cot_answer']
#     ans_split = cot_ans.split('<ANSWER>')
#     if(len(ans_split) == 2):
#         example['answer'] = ans_split[1]
#     else :
#         example['answer'] = ""
#     return example
# ds = ds.map(extract_ans)

In [35]:
# ds = ds.filter(lambda example: len(example["cot_answer"].split("<ANSWER>")) < 2 )

Filter:   0%|          | 0/159 [00:00<?, ? examples/s]

In [45]:
# Display the dataset summary (remaining columns and row count).
ds

Dataset({
    features: ['id', 'question', 'oracle_context', 'cot_answer', 'non_oracle'],
    num_rows: 159
})

In [14]:
# Rebinds `tokenizer` (first assigned together with the model by get_llama3)
# using the same adapter_path; the prompt-building cell below uses it to count
# tokens.
tokenizer = get_llama3_tokenizer(adapter_path=adapter_path)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [46]:
# def get_token_lens(exmaple) :
#    tlen = len(tokenizer.encode(exmaple['cot_answer']))
#    tlen += len(tokenizer.encode(exmaple['question']))
#    tlen += len(tokenizer.encode(exmaple['oracle_context']))
#    for sentence in exmaple['non_oracle']:
#       tlen += len(tokenizer.encode(sentence))
#    exmaple['tlen'] = tlen
#    return exmaple
# ds = ds.map(get_token_lens)

In [50]:
def inst_format(example, max_tokens=4000):
    """Build a context prompt from distractor sentences within a token budget.

    The budget is first charged with the question, the oracle context, the
    chain-of-thought answer and a fixed instruction string. Sentences from
    ``example['non_oracle']`` are then appended in order, each prefixed with
    ``<DOCUMENT>``, until the next one would push the running total above
    ``max_tokens``.

    Args:
        example: Mapping with string fields ``question``, ``oracle_context``
            and ``cot_answer`` and a list of strings under ``non_oracle``.
        max_tokens: Upper bound on the estimated token count. Defaults to
            4000, the value that used to be hard-coded.

    Returns:
        ``example`` unchanged when the fixed parts alone exceed the budget;
        otherwise the same ``example`` object with the built text stored under
        ``example['prompt']``.
    """
    # NOTE(review): a cell earlier in the notebook defines another function
    # with this same name.

    def n_tokens(text):
        # Uses the module-level `tokenizer`.
        # NOTE(review): encode() may add special tokens on every call, which
        # would make this a slight over-estimate - confirm.
        return len(tokenizer.encode(text))

    tlen = sum(n_tokens(example[key]) for key in ('question', 'oracle_context', 'cot_answer'))
    tlen += n_tokens("Context:\n Answer the quetion based on the context")
    if tlen > max_tokens:
        return example

    prompt = "Context:\n"
    for sentence in example['non_oracle']:
        tlen += n_tokens(sentence)
        if tlen > max_tokens:
            break
        # Append this sentence only; interpolating the whole list here would
        # repeat every distractor once per kept sentence.
        prompt += f"<DOCUMENT>{sentence}"

    # NOTE(review): the oracle context and the question are counted in the
    # budget but are not appended to the prompt yet - confirm where they
    # should be placed.
    example['prompt'] = prompt
    return example

    

{'id': ['seed_task_4',
  'seed_task_13',
  'seed_task_0',
  'seed_task_11',
  'seed_task_3'],
 'question': ['What are the characteristics of Javanica variety and where is it mainly found?',
  'What are the recommended rice hybrids and varieties for cultivation in different states of India?',
  'What are some of the states in India where paddy cultivation is under rainfed upland situation?',
  'When should Nitrofen be applied after transplanting rice?',
  'What are the general guidelines for fertilization of high-yielding dwarf varieties of wheat under different agro-climatic conditions according to the All-India Coordinated Wheat Improvement Project?'],
 'oracle_context': ['They may be awned or awnless,  leaves are narrow and dark green in colour. javanica : These varieties are characterized by a sti ff straw, long panicle with awned grains, \nsparse tille ring habit, lo ng durat ion and low sen sitivity to dif ferences in day ligh t. These ar e \nfound m ainly in Indonesia.',
  'Selec