In [1]:
import subprocess
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSequenceClassification, pipeline
from datasets import load_dataset, Dataset
from torch.utils.data import DataLoader
import torch.nn.functional as F
import logging
import numpy as np
from trl import DPOTrainer, DPOConfig, ModelConfig,get_quantization_config,get_kbit_device_map

# Load environment variables from /etc/network_turbo
result = subprocess.run('bash -c "source /etc/network_turbo && env | grep proxy"', shell=True, capture_output=True, text=True)
output = result.stdout
for line in output.splitlines():
    if '=' in line:
        var, value = line.split('=', 1)
        os.environ[var] = value

# Set the model path of qwen sft AND sentiment model
LM_MODEL = "august66/qwen2-sft-final"
SENTIMENT_MODEL = "siebert/sentiment-roberta-large-english"
N_PREFIX_TOKENS = 5


#load dataset
dataset_test = load_dataset("stanfordnlp/imdb", split="test")
dataset_train = load_dataset("stanfordnlp/imdb", split="train")
def prompt_completion_preprocess(example):
    words = example['text'].split()
    prompt = ' '.join(words[:N_PREFIX_TOKENS])
    completion = ' '.join(words[N_PREFIX_TOKENS:])
    return {'prompt': prompt, 'completion': completion}
dataset_test = dataset_test.map(prompt_completion_preprocess, remove_columns=['text', 'label'])
dataset_train = dataset_train.map(prompt_completion_preprocess, remove_columns=['text', 'label'])






In [3]:
qwen_sft_model = AutoModelForCausalLM.from_pretrained(LM_MODEL)
qwen_sft_tokenizer = AutoTokenizer.from_pretrained(LM_MODEL)
qwen_sft_tokenizer.padding_side = "left"
if qwen_sft_tokenizer.pad_token is None:
    qwen_sft_tokenizer.pad_token = qwen_sft_tokenizer.eos_token
pipe_qwen_sft = pipeline(
    'text-generation',
    model = qwen_sft_model,
    tokenizer = qwen_sft_tokenizer,
    device_map = 'auto'
)
prompts_train = dataset_train['prompt']
generated_completions_train = pipe_qwen_sft(
    prompts_train,
    max_new_tokens = 128,
    do_sample = True,
    top_p = 0.95,
    top_k = 50,
    temperature = 0.8,
    num_return_sequences = 2,
    batch_size = 128,
    repetition_penalty = 1.2,
    eos_token_id = qwen_sft_tokenizer.eos_token_id,
)


Device set to use cuda:0


In [4]:
generated_completions_train_flat = Dataset.from_list(list(np.array(generated_completions_train).ravel()))

pipe_sentiment = pipeline(
    'sentiment-analysis',
    model = SENTIMENT_MODEL,
)

train_sentiment_results = pipe_sentiment(
    generated_completions_train_flat['generated_text'],
    batch_size = 128,
)

Device set to use cuda:0


In [16]:
N = len(dataset_test)
prompt_completion_list_train = []
for i in range(N):

    prompt = dataset_train[i]['prompt']
    completion_1 = generated_completions_train_flat[2*i]['generated_text']
    reward_1 = train_sentiment_results[2*i]['score'] if train_sentiment_results[2*i]['score'] == 'POSITIVE' else 1-train_sentiment_results[2*i]['score']
    completion_2 = generated_completions_train_flat[2*i + 1]['generated_text']
    reward_2 = train_sentiment_results[2*i + 1]['score'] if train_sentiment_results[2*i + 1]['score'] == 'POSITIVE' else 1-train_sentiment_results[2*i + 1]['score']
    preference_prob = F.sigmoid(torch.tensor(reward_1-reward_2))
    bernoulli_indicator = torch.bernoulli(preference_prob).item()
    if bernoulli_indicator == 1:
        chosen, rejected = completion_1, completion_2
        reward_chosen, reward_rejected = reward_1, reward_2
    else:
        chosen, rejected = completion_2, completion_1
        reward_chosen, reward_rejected = reward_2, reward_1
    prompt_completion_list_train.append({
        'prompt': prompt,
        'chosen': " ".join(chosen.split()[N_PREFIX_TOKENS:]),
        'rejected': " ".join(rejected.split()[N_PREFIX_TOKENS:]),
        'reward_chosen': reward_chosen,
        'reward_rejected': reward_rejected
    })
    
prompt_completion_dataset_train = Dataset.from_list(prompt_completion_list_train)
dpo_dataset_train = prompt_completion_dataset_train.select_columns(['prompt', 'chosen', 'rejected'])
    


In [24]:
#why random sample?
#what is gradient checking, gradient acc, learning rate 
model_args = ModelConfig(LM_MODEL)
beta = 0.1 
torch_dtype = (
    model_args.torch_dtype if model_args.torch_dtype in ['auto', None] else torch.float16
)


model_kwargs = dict(
    revision = model_args.model_revision,
    torch_dtype = torch_dtype,
    attn_implementation = model_args.attn_implementation,
    trust_remote_code = model_args.trust_remote_code,
)

model = AutoModelForCausalLM.from_pretrained(
    model_args.model_name_or_path,
    **model_kwargs
) 
ref_model = AutoModelForCausalLM.from_pretrained(
    model_args.model_name_or_path,
    **model_kwargs,
)

tokenizer = AutoTokenizer.from_pretrained(
    model_args.model_name_or_path,
    padding_side = "left",
    trust_remote_code = model_args.trust_remote_code,
)

training_args = DPOConfig(

        gradient_checkpointing=True,
        per_device_train_batch_size=32,
        gradient_accumulation_steps=4,
        learning_rate=5.0e-7,
        logging_steps=50,
        bf16=True,
        num_train_epochs=1,
        push_to_hub=True,  
        output_dir = "/root/autodl-tmp/.autodl/DPO_tldr",
        report_to = 'none',
        beta = beta,
        hub_model_id = f'august66/qwen2-sft-dpo-imdb-beta-{beta}',
    )

trainer = DPOTrainer(
    model=model,
    ref_model=ref_model,
    args=training_args,
    train_dataset=dpo_dataset_train,
    processing_class = tokenizer
)

trainer.train()

Extracting prompt in train dataset:   0%|          | 0/25000 [00:00<?, ? examples/s]

Applying chat template to train dataset:   0%|          | 0/25000 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/25000 [00:00<?, ? examples/s]

Step,Training Loss
50,0.6938
100,0.6933
150,0.694


TrainOutput(global_step=196, training_loss=0.6921567722242705, metrics={'train_runtime': 705.5743, 'train_samples_per_second': 35.432, 'train_steps_per_second': 0.278, 'total_flos': 0.0, 'train_loss': 0.6921567722242705, 'epoch': 1.0})

In [61]:
prompts_test = dataset_test['prompt']
dpo_model = trainer.model
dpo_tokenizer = trainer.processing_class

dpo_pipe = pipeline(
    'text-generation',
    model = dpo_model,
    tokenizer = dpo_tokenizer,
)
dpo_completions_test = dpo_pipe(
    prompts_test,
    max_new_tokens = 100,
    eos_token_id = dpo_tokenizer.eos_token_id,
    return_full_text = False,
    batch_size = 128,
    temperature = 0.5
)


Device set to use cuda:0


In [62]:
dpo_completion_test_flat= Dataset.from_list(list(np.array(dpo_completions_test).ravel()))
pipe = pipeline(
    'sentiment-analysis',
    model = SENTIMENT_MODEL,
)

dpo_sentiment_analysis_test = pipe(
    dpo_completion_test_flat['generated_text'],
    batch_size = 128,
    truncation = True,
    padding = True,
)

Device set to use cuda:0


In [63]:
total_score = 0
for i in range(len(dpo_sentiment_analysis_test)):
    score = dpo_sentiment_analysis_test[i]['score']
    if dpo_sentiment_analysis_test[i]['label'] == 'NEGATIVE':
        score = 1 - score
    total_score += score
average_score = total_score / len(dpo_sentiment_analysis_test)

In [77]:
dpo_sentiment_analysis_test

[{'label': 'POSITIVE', 'score': 0.9986749291419983},
 {'label': 'POSITIVE', 'score': 0.998848557472229},
 {'label': 'NEGATIVE', 'score': 0.9990302324295044},
 {'label': 'POSITIVE', 'score': 0.9988120794296265},
 {'label': 'NEGATIVE', 'score': 0.9995121955871582},
 {'label': 'NEGATIVE', 'score': 0.9995121955871582},
 {'label': 'NEGATIVE', 'score': 0.9994966983795166},
 {'label': 'NEGATIVE', 'score': 0.9994966983795166},
 {'label': 'NEGATIVE', 'score': 0.9995044469833374},
 {'label': 'NEGATIVE', 'score': 0.9995121955871582},
 {'label': 'NEGATIVE', 'score': 0.9995121955871582},
 {'label': 'POSITIVE', 'score': 0.998848557472229},
 {'label': 'POSITIVE', 'score': 0.9989350438117981},
 {'label': 'POSITIVE', 'score': 0.9985895752906799},
 {'label': 'POSITIVE', 'score': 0.9989012479782104},
 {'label': 'POSITIVE', 'score': 0.9956005811691284},
 {'label': 'NEGATIVE', 'score': 0.9995044469833374},
 {'label': 'NEGATIVE', 'score': 0.9995121955871582},
 {'label': 'POSITIVE', 'score': 0.99778300523757

In [76]:
len(dpo_sentiment_analysis_test)

25000