In [None]:
import os

import torch
import transformers
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig
from trl import SFTTrainer
from trl import DataCollatorForCompletionOnlyLM

import Datasets

In [2]:
# Log targets for this run. `out_file` receives the metric dumps written later
# in this notebook; both paths are also handed to Datasets.eval.
out_file = 'train_output/phi3_qp.txt'
out_predictions_file = 'train_output/phi3_qp_predictions.txt'

# Both logs are opened in append mode so earlier runs are preserved; every new
# run starts with the same separator banner.
for log_path in (out_file, out_predictions_file):
    # On a fresh checkout the output directory may not exist yet, and
    # open(..., 'a') does not create missing parent directories.
    os.makedirs(os.path.dirname(log_path) or '.', exist_ok=True)
    with open(log_path, 'a') as f:
        f.write('\n \n new training: \n')
        f.write('-----------------------------\n')

In [3]:
# Epoch count forwarded to TrainingArguments as `num_train_epochs`.
num_epochs = 1

# ---------------------------------------------------------------
# Hyper-parameters
# ---------------------------------------------------------------
# Keyword arguments unpacked into transformers.TrainingArguments below.
# logging_steps / logging_strategy are deliberately not set here.
training_config = dict(
    bf16=True,
    do_eval=False,
    learning_rate=5.0e-06,
    log_level="info",
    lr_scheduler_type="cosine",
    num_train_epochs=num_epochs,
    max_steps=-1,
    output_dir="./checkpoint_dir_qp",
    overwrite_output_dir=True,
    per_device_eval_batch_size=4,
    per_device_train_batch_size=4,
    remove_unused_columns=True,
    save_steps=100,
    save_total_limit=1,
    seed=0,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs=dict(use_reentrant=False),
    gradient_accumulation_steps=1,
    warmup_ratio=0.2,
)

# Keyword arguments unpacked into peft.LoraConfig below.
peft_config = dict(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules="all-linear",
    modules_to_save=None,
)

# Materialise the two config objects consumed by the trainer cell.
train_conf = TrainingArguments(**training_config)
peft_conf = LoraConfig(**peft_config)


In [None]:
# --- Base model -------------------------------------------------------------
# Local snapshot of the model (hub id: "microsoft/Phi-3-mini-4k-instruct").
checkpoint_path = 'models/phi3_4k'

# Extra from_pretrained options. They are NOT passed in the call below; append
# `**model_kwargs` to the call when downloading the model from the hub.
model_kwargs = {
    "use_cache": False,
    "trust_remote_code": True,
    # "attn_implementation": "flash_attention_2",  # load the model with flash-attention support
    "torch_dtype": torch.bfloat16,
    "device_map": None,
}
model = AutoModelForCausalLM.from_pretrained(checkpoint_path)

# --- Tokenizer --------------------------------------------------------------
# Local snapshot of the tokenizer (same hub id as the model; an older local
# location was "tokenizer/phi3/").
checkpoint_path = "tokenizer/phi3_4k"
tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)

# Cap the sequence length at 512 (2048 was the earlier setting).
tokenizer.model_max_length = 512
tokenizer.padding_side = 'right'

# Pad with the unk token rather than eos to prevent endless generation.
tokenizer.pad_token = tokenizer.unk_token
tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)

In [None]:
#model.save_pretrained("models/phi3_4k")
#tokenizer.save_pretrained("tokenizer/phi3_4k")

In [None]:
def _load_qp_split(split):
    """Build the dataset for one Numeracy600K comment split ('train', 'dev' or 'test')."""
    return Datasets.gerernate_qqp(f'data/QP/Numeracy600K_comment_{split}.json', tokenizer)


# Load the three splits in the same order as before: train, dev, test.
train_dataset = _load_qp_split('train')
dev_dataset = _load_qp_split('dev')
test_dataset = _load_qp_split('test')


In [None]:
# Sanity check: number of entries in the training split's "comment" column.
print(len(train_dataset["comment"]))

In [None]:
#def eval(epoch,train_data,test_data):
#
#    for mode in ["train","dev"]:
#        true_count = 0
#        total_count = 0
#
#        if mode == "train":
#            inputs = train_data
#        else:
#            inputs = test_data
#
#
#        for i,input in enumerate(inputs["inputs"][:100]):
#            total_count += 1
#            answer = inputs["labels"][i]
#            input = tokenizer.encode(input,return_tensors="pt")
#            input = input.to('mps')
#            outputs = model.generate(input, max_new_tokens=32)
#            text = tokenizer.batch_decode(outputs)[0]
#            if answer in text.split("<|assistant|>")[-1]:
#                true_count += 1
#            #else:
#            #    print(f"wrong predicted {text.split("<|assistant|>")[-1]} but was {answer}")
#
#        print(mode)    
#        print(f'{total_count / len(inputs["inputs"][:100]) *100}% : {true_count/total_count *100}%')
#
#        with open('train_output/phi3_qp.txt','a') as f:
#            f.write(f'{epoch} : {mode} \n')
#            f.write(f'{true_count/total_count *100} % \n')
#
#    with open('train_output/phi3_qp.txt','a') as f:
#        f.write('-----------------------------')
#    

In [None]:
# Inspect the first training example's 'text' field.
# NOTE(review): presumably the full chat-formatted sample the trainer consumes — verify against Datasets.gerernate_qqp.
print(train_dataset['text'][0])

In [None]:
# Inspect the first training example's 'inputs' field.
# NOTE(review): presumably the prompt that is fed to generation during evaluation — verify.
print(train_dataset['inputs'][0])

In [None]:
# Inspect the first training example's 'labels' field.
# NOTE(review): presumably the expected answer for the prompt above — verify.
print(train_dataset['labels'][0])

In [None]:
# Baseline evaluation before any fine-tuning, on the train and dev splits.
# NOTE(review): the device is hard-coded to Apple's 'mps' backend; this call
# will fail on machines without MPS support — confirm that is intended.
model.to('mps')
# The third argument is 0 here; it is presumably an epoch/round index used in
# the log output — TODO confirm against Datasets.eval.
Datasets.eval(model,tokenizer,0,train_dataset,'train',out_file,out_predictions_file)
Datasets.eval(model,tokenizer,0,dev_dataset,'dev',out_file,out_predictions_file)

In [None]:
# Chat markers that open the user turn and the assistant turn in each sample.
instruction_template = "<|user|>"
response_template = "<|assistant|>"

# Completion-only collator keyed on the two markers above.
# NOTE(review): per the TRL docs this limits the loss to the assistant
# response tokens — confirm for the installed TRL version.
collator = DataCollatorForCompletionOnlyLM(
    response_template=response_template,
    instruction_template=instruction_template,
    tokenizer=tokenizer,
    mlm=False,
)

# Supervised fine-tuning trainer with LoRA. max_seq_length, dataset_text_field
# and packing are intentionally not passed.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=train_conf,
    peft_config=peft_conf,
    data_collator=collator,
    train_dataset=train_dataset,
    eval_dataset=dev_dataset,
)


In [None]:
# Number of train -> evaluate rounds. Each round is one full trainer.train()
# call, i.e. `num_train_epochs` epochs as configured in TrainingArguments.
# NOTE(review): every trainer.train() call is a separate training run on the
# same model object; confirm how the learning-rate warmup/schedule is expected
# to behave across repeated calls.
num_rounds = 6

for i in range(num_rounds):

    # ---------
    # Training
    # ---------
    train_result = trainer.train()

    metrics = train_result.metrics
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()

    with open(out_file, 'a') as f:
        f.write(f'\n train metrics {i}: \n')
        f.write(str(metrics))

    #############
    # Evaluation
    #############
    #tokenizer.padding_side = 'left'
    metrics = trainer.evaluate()
    metrics["eval_samples"] = len(dev_dataset)
    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)

    with open(out_file, 'a') as f:
        f.write(f'\n dev metrics {i}: \n')
        f.write(str(metrics))

    # Generation-based evaluation on both splits. The pre-training baseline is
    # logged with index 0, so the completed round is passed as i + 1; passing a
    # constant 0 made every round indistinguishable from the baseline.
    # NOTE(review): assumes the third argument of Datasets.eval is the
    # epoch/round index written to the log — confirm against its definition.
    Datasets.eval(model, tokenizer, i + 1, train_dataset, 'train', out_file, out_predictions_file)
    Datasets.eval(model, tokenizer, i + 1, dev_dataset, 'dev', out_file, out_predictions_file)

In [None]:
# Evaluate on the test split, passing the same two log paths as the earlier
# Datasets.eval calls.
# NOTE(review): the third argument is 0, the same value used for the baseline
# evaluation — confirm what it represents and whether 0 is intended here.
Datasets.eval(model,tokenizer,0,test_dataset,'test',out_file,out_predictions_file)

In [None]:
# Persist the model and tokenizer for the QP task.
# NOTE(review): the trainer was built with a peft_config, so the trained
# adapter may be owned by trainer.model rather than this `model` reference —
# confirm this call writes the weights you expect.
model.save_pretrained('models/qp/phi3/')
# Save the tokenizer under the same task name as the model and the rest of
# this notebook's artifacts ("qp"); the previous target 'tokenizer/qqa/phi3/'
# looked like a leftover from a different task.
tokenizer.save_pretrained('tokenizer/qp/phi3/')