In [1]:
#Code origin
#Author: Alexander Valentini

from datasets import load_dataset,DatasetDict,Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, DataCollatorForLanguageModeling
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM 
from peft import LoraConfig
import os
import torch
import numpy as np
import wandb

In [2]:
# Data locations. The jsonl paths are relative to the notebook's working directory.
project_dir = os.path.dirname(os.path.abspath(os.getcwd()))  # parent of the cwd; not used by the paths below
previous_checkpoint_path = None  # set to a checkpoint directory to resume an interrupted run

train_data_path = 'datasets/mcqa/mcqa_train_dataset_chattemplate_mcqa.jsonl'
vali_data_path = 'datasets/mcqa/mcqa_validation_dataset_chattemplate_mcqa_halved.jsonl'

# Both splits are loaded as one DatasetDict with "train" and "validation" keys.
dataset = load_dataset(
    'json',
    data_files=dict(train=train_data_path, validation=vali_data_path),
)


In [3]:
# Load the DPO-tuned checkpoint and its tokenizer from the Hub.
model_name = 'AlexVal/dpo_model'
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

tokenizer = AutoTokenizer.from_pretrained(model_name)

# No torch_dtype is passed, so weights load in the default dtype; SDPA attention is requested explicitly.
model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="sdpa")
model = model.to(device)  # place on GPU when one is available


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
# Sanity check: show the padding token and its id (one per line).
print(tokenizer.pad_token, tokenizer.pad_token_id, sep="\n")

<|pad|>
50277


In [5]:
# Completion-only loss: TRL's DataCollatorForCompletionOnlyLM masks the labels of every
# token up to and including this marker, so only the assistant's answer is trained on.
# NOTE(review): the marker is tokenized on its own here. If the tokenizer splits it
# differently inside a full chat-formatted example, the collator will not find it and
# will mask the whole sequence (TRL emits a warning). TRL's docs suggest passing token ids
# encoded with surrounding context in that case — verify on a real training example.
response_template = "<|assistant|>\n"
collator = DataCollatorForCompletionOnlyLM(
    response_template,
    tokenizer=tokenizer,
)

In [6]:
# LoRA adapter settings: rank 128, alpha 64, applied to every linear layer.
# NOTE: currently unused. The peft_config argument in the SFTTrainer cell is commented
# out, so the model is fully fine-tuned.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules="all-linear",
    r=128,
    lora_alpha=64,
    lora_dropout=0.05,
    bias="none",
)


In [7]:
# Trainer hyperparameters.
# NOTE: the recorded run had only 91 optimizer steps in total (1458 training examples,
# per-device batch 8, gradient accumulation 2). The previous eval_steps=500 /
# save_steps=1000 therefore never triggered. No evaluation ran, no checkpoint was
# written during training, and load_best_model_at_end had nothing to select.
# transformers interprets eval_steps/save_steps values in [0, 1) as a fraction of total
# training steps, which keeps the cadence meaningful regardless of dataset size.
training_args = TrainingArguments(
    output_dir="mcqa_model",                # directory for checkpoints/final model; also used as the repository id
    num_train_epochs=1,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,          # accumulate gradients over 2 forward/backward passes per optimizer update
    gradient_checkpointing=True,            # Alexander: use gradient checkpointing to save memory - from tutorial
    optim="adamw_torch_fused",              # Alexander: use fused adamw optimizer for faster training
    logging_steps=10,                       # Alexander: log every 10 steps for better debugging
    # Evaluate and checkpoint at the SAME fractional cadence. Both fractions resolve to the
    # same integer step interval, so every checkpoint has an eval_loss that
    # load_best_model_at_end can compare.
    eval_strategy="steps",
    eval_steps=0.2,                         # evaluate every 20% of total training steps
    save_strategy="steps",
    save_steps=0.2,                         # checkpoint every 20% of total training steps
    save_total_limit=3,
    load_best_model_at_end=True,            # reload the checkpoint with the best eval_loss when training ends
    metric_for_best_model="eval_loss",
    learning_rate=5e-7,                     # deliberately low LR; the original rationale was truncated ("to avoid ...") -- TODO document
    # Mixed precision is currently disabled; enable one of these if the hardware supports it:
    # fp16=True,
    # bf16=True,
    # tf32=True,
    warmup_ratio=0.1,                       # warm up over the first 10% of steps
    lr_scheduler_type="linear",             # linear decay after warmup (not a constant schedule)
    report_to="tensorboard",                # report metrics to tensorboard
)

In [8]:
# Supervised fine-tuning on the "text" field. Sequences are capped at 1024 tokens and not
# packed, and the completion-only collator restricts the loss to assistant answers.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    data_collator=collator,
    train_dataset=dataset['train'],
    eval_dataset=dataset['validation'],
    dataset_text_field="text",
    max_seq_length=1024,
    packing=False,
#    peft_config=peft_config,  # uncomment to train LoRA adapters instead of the full model
)

# resume_from_checkpoint defaults to None in Trainer.train, so one call covers both a fresh
# run (previous_checkpoint_path is None) and resuming from a saved checkpoint.
trainer.train(resume_from_checkpoint=previous_checkpoint_path)

# Write the final model to training_args.output_dir.
trainer.save_model()


Map:   0%|          | 0/1458 [00:00<?, ? examples/s]

Map:   0%|          | 0/182 [00:00<?, ? examples/s]

  0%|          | 0/91 [00:00<?, ?it/s]

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
