In [1]:
import yaml
import logging

from utils import read_jsonl, write_json
from evaluator import DPOModelEvaluator, repository_check

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("mnlp-2024-auto-evaluator")

# Basic repository check to ensure the submission is correct
repository_check()

# Load the main configuration file.
# A file that cannot be parsed is logged and re-raised: the required paths below have
# no defaults, so carrying on with an empty config would only fail a few lines later
# with a misleading KeyError.
with open("main_config.yaml") as f:
    try:
        # safe_load returns None for an empty document; fall back to an empty dict so
        # the `.get(...)` lookups below keep working.
        main_config = yaml.safe_load(f) or {}
    except Exception as e:
        logger.error(f"Error loading main_config.yaml: {e}! Please check the file format.")
        raise

if not isinstance(main_config, dict):
    raise TypeError(
        f"main_config.yaml must contain a mapping at the top level, got {type(main_config).__name__}."
    )

# Fail early, and by name, when a mandatory entry is absent.
required_keys = ("policy_model_path", "reference_model_path", "test_data_path")
missing_keys = [key for key in required_keys if key not in main_config]
if missing_keys:
    raise KeyError(f"main_config.yaml is missing required key(s): {missing_keys}")

# Load the task type to identify the model class
task_type = main_config.get("task_type", "causal_lm")

# Load the evaluation methods and the required paths
eval_method = main_config.get("eval_method", ["mcqa"])
if isinstance(eval_method, str):
    # A scalar `eval_method: mcqa` in the YAML yields a str; wrap it so that the
    # membership checks below are not substring checks and list methods are available.
    eval_method = [eval_method]
policy_model_path = main_config["policy_model_path"]
reference_model_path = main_config["reference_model_path"]
test_data_path = main_config["test_data_path"]

# Load the test data
test_data = read_jsonl(test_data_path)

# Load the model arguments
dpo_model_args = main_config.get("dpo_model_args", {})
rag_model_args = main_config.get("rag_model_args", {})
quantized_model_args = main_config.get("quantized_model_args", {})

# Initialize the metrics dictionary
metrics = {
    "team_name": main_config.get("team_name", "Team Name"),
    "task_type": task_type,
}

# Ensure that the evaluation methods are not conflicting
assert not ("reward" in eval_method and "mcqa" in eval_method), "You cannot evaluate both reward and mcqa at the same time!"


In [2]:
# Number of records returned by read_jsonl(test_data_path).
len(test_data)

356

In [3]:
from torch.utils.data import DataLoader

# Only a prefix of the test set is scored (presumably to keep the run short).
# Set EVAL_SUBSET_SIZE to None to evaluate every record in test_data.
EVAL_SUBSET_SIZE = 50
EVAL_BATCH_SIZE = 2

test_dataloader = DataLoader(test_data[:EVAL_SUBSET_SIZE], batch_size=EVAL_BATCH_SIZE)

# Evaluator wrapping the policy model described in main_config.yaml.
evaluator = DPOModelEvaluator(
    task_type=task_type,
    policy_model_path=policy_model_path,
    dpo_model_args=dpo_model_args
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [4]:
# Score the policy model on the MCQA batches and record the result first, so the
# (expensive) score is kept even if the bookkeeping below is skipped.
policy_acc = evaluator.scoring_mcqa(test_dataloader)
metrics["policy_acc"] = policy_acc

# Mark "mcqa" as handled. Guarded so that re-running this cell, or a config whose
# eval_method never listed "mcqa", does not raise ValueError from list.remove.
if "mcqa" in eval_method:
    eval_method.remove("mcqa")

2024-05-27 17:11:46,751 - INFO - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-27 17:12:05,077 - INFO - Trained peft adapter loaded
  0%|          | 0/2 [00:00<?, ?it/s]

True answer: D


 50%|█████     | 1/2 [00:30<00:30, 30.67s/it]

['B', 'C', 'B', 'B', 'B']
True answer: C


100%|██████████| 2/2 [00:56<00:00, 28.46s/it]


['B', 'B', 'B', 'C', 'C', 'B']


  0%|          | 0/2 [00:00<?, ?it/s]

True answer: B


 50%|█████     | 1/2 [00:18<00:18, 18.85s/it]

['C', 'C', 'C', 'C', 'B', 'C', 'C', 'C']
True answer: D


100%|██████████| 2/2 [00:46<00:00, 23.39s/it]


['C', 'B', 'C', 'C', 'C', 'B', 'B', 'C', 'B']


  0%|          | 0/2 [00:00<?, ?it/s]

True answer: A


 50%|█████     | 1/2 [00:31<00:31, 31.08s/it]

['A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A']
True answer: B


100%|██████████| 2/2 [00:57<00:00, 28.90s/it]


['B', 'C', 'C', 'C', 'B', 'C', 'C', 'B', 'C', 'C']


  0%|          | 0/2 [00:00<?, ?it/s]

True answer: A


 50%|█████     | 1/2 [00:21<00:21, 21.51s/it]

['A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A']
True answer: B


100%|██████████| 2/2 [00:44<00:00, 22.33s/it]


['C', 'B', 'C', 'C', 'C', 'B', 'C']


  0%|          | 0/2 [00:00<?, ?it/s]

True answer: C


 50%|█████     | 1/2 [01:05<01:05, 65.93s/it]

['D', 'D', 'D', 'D', 'D', 'D', 'D']
True answer: A


100%|██████████| 2/2 [03:15<00:00, 97.55s/it] 


['A', 'A', 'A', 'A', 'B', 'A', 'A']


  0%|          | 0/2 [00:00<?, ?it/s]

True answer: A


 50%|█████     | 1/2 [00:44<00:44, 44.74s/it]

['A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A']
True answer: C


100%|██████████| 2/2 [01:23<00:00, 41.93s/it]


['C', 'C', 'C', 'C', 'C', 'C', 'B', 'C', 'C', 'C']


  0%|          | 0/2 [00:00<?, ?it/s]

True answer: A


 50%|█████     | 1/2 [00:23<00:23, 23.38s/it]

['A', 'A', 'A', 'A', 'A', 'B', 'A']
True answer: B


100%|██████████| 2/2 [00:42<00:00, 21.07s/it]


['B', 'B', 'B', 'B', 'B', 'B', 'B']


  0%|          | 0/2 [00:00<?, ?it/s]

True answer: B


 50%|█████     | 1/2 [00:23<00:23, 23.10s/it]

['B', 'B', 'B', 'B', 'B', 'B', 'B', 'B']
True answer: B


100%|██████████| 2/2 [00:52<00:00, 26.43s/it]


['C', 'B', 'C', 'B', 'B', 'B']


  0%|          | 0/2 [00:00<?, ?it/s]

True answer: C


 50%|█████     | 1/2 [00:26<00:26, 26.96s/it]

['C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C']
True answer: A


100%|██████████| 2/2 [01:00<00:00, 30.21s/it]


['A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A']


  0%|          | 0/2 [00:00<?, ?it/s]

True answer: A


 50%|█████     | 1/2 [00:22<00:22, 22.31s/it]

['A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A']
True answer: A


100%|██████████| 2/2 [00:50<00:00, 25.27s/it]


['A', 'C', 'C', 'A', 'A', 'C', 'C', 'B']


  0%|          | 0/2 [00:00<?, ?it/s]

True answer: C


 50%|█████     | 1/2 [00:27<00:27, 27.24s/it]

['C', 'C', 'C', 'C', 'C', 'C']
True answer: A


100%|██████████| 2/2 [00:56<00:00, 28.42s/it]


['A', 'A', 'A', 'B', 'A', 'A', 'A', 'A']


  0%|          | 0/2 [00:00<?, ?it/s]

True answer: B


 50%|█████     | 1/2 [00:23<00:23, 23.76s/it]

['C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C']
True answer: C


100%|██████████| 2/2 [00:55<00:00, 27.82s/it]


['A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A']


  0%|          | 0/2 [00:00<?, ?it/s]

True answer: B


 50%|█████     | 1/2 [00:27<00:27, 27.10s/it]

['B', 'B', 'B', 'B', 'B', 'B', 'B', 'B', 'B', 'B']
True answer: A


100%|██████████| 2/2 [00:45<00:00, 22.55s/it]


['C', 'C', 'C', 'C', 'C', 'C', 'B']


  0%|          | 0/2 [00:00<?, ?it/s]

True answer: B


 50%|█████     | 1/2 [00:29<00:29, 29.42s/it]

['B', 'B', 'B', 'B', 'B', 'B', 'B']
True answer: C


100%|██████████| 2/2 [00:50<00:00, 25.04s/it]


['C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C']


  0%|          | 0/2 [00:00<?, ?it/s]

True answer: D


 50%|█████     | 1/2 [00:33<00:33, 33.04s/it]

['B', 'B', 'B', 'B', 'B', 'B', 'B']
True answer: C


100%|██████████| 2/2 [01:03<00:00, 31.85s/it]


['C', 'C', 'B', 'C', 'B']


  0%|          | 0/2 [00:00<?, ?it/s]

True answer: C


 50%|█████     | 1/2 [00:27<00:27, 27.07s/it]

['C', 'C', 'C', 'C', 'C', 'C', 'B', 'C', 'B', 'C']
True answer: C


100%|██████████| 2/2 [00:50<00:00, 25.29s/it]


['C', 'D', 'D', 'B', 'C', 'C', 'B', 'C', 'B', 'C']


  0%|          | 0/2 [00:00<?, ?it/s]

True answer: B


 50%|█████     | 1/2 [00:36<00:36, 36.05s/it]

['A', 'A', 'B', 'A']
True answer: D


100%|██████████| 2/2 [01:04<00:00, 32.03s/it]


['B', 'C', 'B', 'C', 'C', 'C', 'C', 'C', 'C']


  0%|          | 0/2 [00:00<?, ?it/s]

True answer: A


 50%|█████     | 1/2 [01:31<01:31, 91.95s/it]

['A', 'A', 'A', 'B', 'A', 'A', 'A', 'A', 'A']
True answer: D


100%|██████████| 2/2 [01:53<00:00, 56.53s/it]


['D', 'D', 'D', 'D', 'D', 'D', 'D', 'D', 'D']


  0%|          | 0/2 [00:00<?, ?it/s]

True answer: D


 50%|█████     | 1/2 [00:23<00:23, 23.91s/it]

['D', 'D', 'D', 'D', 'D', 'D', 'D', 'D', 'D', 'D']
True answer: B


100%|██████████| 2/2 [00:41<00:00, 20.94s/it]


['B', 'B', 'A', 'A']


  0%|          | 0/2 [00:00<?, ?it/s]

True answer: B


 50%|█████     | 1/2 [00:25<00:25, 25.04s/it]

['D', 'D', 'D', 'C', 'D', 'D', 'D', 'D', 'D', 'D']
True answer: A


100%|██████████| 2/2 [00:52<00:00, 26.28s/it]


['B', 'A', 'A', 'A', 'A', 'A', 'B', 'B']


  0%|          | 0/2 [00:00<?, ?it/s]

True answer: A


 50%|█████     | 1/2 [00:32<00:32, 32.47s/it]

['A', 'B', 'B', 'B', 'B', 'B']
True answer: C


100%|██████████| 2/2 [00:55<00:00, 27.77s/it]


['B', 'C', 'B', 'C', 'C', 'C', 'C', 'C', 'C']


  0%|          | 0/2 [00:00<?, ?it/s]

True answer: B


 50%|█████     | 1/2 [00:40<00:40, 40.33s/it]

['B', 'C', 'C', 'C', 'C', 'B', 'C', 'C', 'C', 'C']
True answer: B


100%|██████████| 2/2 [00:58<00:00, 29.37s/it]


['B', 'D', 'D', 'D', 'D', 'D']


  0%|          | 0/2 [00:00<?, ?it/s]

True answer: D


 50%|█████     | 1/2 [00:34<00:34, 34.84s/it]

['D', 'B', 'D', 'B', 'D', 'D', 'D', 'B']
True answer: A


100%|██████████| 2/2 [00:51<00:00, 25.57s/it]


['C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C']


  0%|          | 0/2 [00:00<?, ?it/s]

True answer: D


 50%|█████     | 1/2 [00:21<00:21, 21.12s/it]

['D', 'D', 'D', 'D', 'D', 'D', 'D', 'D', 'D']
True answer: A


100%|██████████| 2/2 [00:51<00:00, 25.53s/it]


['A', 'B', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A']


  0%|          | 0/2 [00:00<?, ?it/s]

True answer: D


 50%|█████     | 1/2 [00:24<00:24, 24.95s/it]

['A', 'A', 'A', 'A', 'A', 'A', 'A', 'A']
True answer: D


100%|██████████| 2/2 [00:42<00:00, 21.25s/it]

['B', 'C', 'B', 'C', 'B', 'C', 'C', 'C']





In [5]:
# Show the value returned by evaluator.scoring_mcqa(test_dataloader).
policy_acc

0.5

In [3]:
# Instantiate the policy model through the evaluator's model class, from policy_model_path.
# NOTE(review): this cell's saved execution count is lower than that of the cells above it,
# so it appears to have been run out of order; confirm the notebook still works with
# Restart Kernel -> Run All.
policy_model = evaluator.model_class.from_pretrained(policy_model_path)



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

2024-05-27 14:41:26,903 - INFO - Trained peft adapter loaded


In [7]:
from transformers import TrainingArguments
 
# Trainer configuration, grouped by concern.
args = TrainingArguments(
    # -- output, checkpoints, reporting --
    output_dir="llama3_new",           # checkpoint directory (author's earlier value: "doplhin-dpo")
    save_steps=500,                    # write a checkpoint every 500 steps
    save_total_limit=2,                # keep at most two checkpoints on disk
    push_to_hub=False,                 # do not upload to the Hub
    report_to="tensorboard",           # send metrics to TensorBoard
    logging_steps=25,                  # log every 25 steps
    # -- evaluation cadence --
    evaluation_strategy="steps",       # evaluate on a step schedule ...
    eval_steps=700,                    # ... every 700 steps
    # -- batch sizes and epochs --
    num_train_epochs=1,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=1,     # no accumulation: update after every batch
    # -- optimizer and schedule --
    optim="adamw_torch_fused",         # fused AdamW implementation
    learning_rate=5e-5,
    lr_scheduler_type="cosine",        # cosine learning-rate schedule
    warmup_ratio=0.1,                  # warm up over the first 10% of steps
    max_grad_norm=0.3,                 # gradient clipping threshold
    # -- memory and precision --
    gradient_checkpointing=True,       # use gradient checkpointing to save memory
    bf16=True,                         # bfloat16 precision
    tf32=True,                         # tf32 precision
)

# Keyword arguments for the DPO loss, read by the trainer cell.
dpo_args = dict(
    beta=0.1,                          # beta factor in the DPO loss
    loss_type="sigmoid",               # DPO loss variant
)

# Length limits passed to the trainer as max_prompt_length / max_length
# (author's earlier values: 1024 and 1512).
prompt_length = 402
max_seq_length = 912

In [9]:
from trl import DPOTrainer
from datasets import Dataset

ds = Dataset.from_json("datasets/dpo_hf_dataset.json")
# Fixed seed so the 90/10 split, and therefore the evaluation set, is identical on every run.
ds = ds.train_test_split(test_size=0.1, seed=42)

# Trainer built for evaluation of the already-loaded policy model.
trainer_for_eval = DPOTrainer(
    policy_model.pretrained_model,
    ref_model=None, # set to none since we use peft
    # peft_config=peft_config,
    args=args,
    # NOTE(review): the held-out split is passed as train_dataset too. Presumably this is
    # because the trainer is only used for evaluation; confirm before calling .train(),
    # which would otherwise train on the evaluation data.
    train_dataset=ds["test"],
    eval_dataset=ds["test"],
    tokenizer=evaluator.policy_tokenizer,
    max_length=max_seq_length,
    max_prompt_length=prompt_length,
    beta=dpo_args["beta"],
    loss_type=dpo_args["loss_type"],
)
# Also expose the trainer as an attribute of the model object.
policy_model.dpo_trainer = trainer_for_eval

policy_model.dpo_trainer

Map:   0%|          | 0/2674 [00:00<?, ? examples/s]

<trl.trainer.dpo_trainer.DPOTrainer at 0x7f2ae572fca0>