In [None]:
from modelscope import AutoModelForCausalLM, AutoTokenizer
import torch

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using device: {device}')

def create_qwen_model(model_name='Qwen/Qwen2.5-3B-Instruct'):
    """Load a causal language model and its tokenizer from the same checkpoint.

    Args:
        model_name: Checkpoint identifier handed to both ``from_pretrained``
            calls. Defaults to the Qwen2.5 3B instruct checkpoint, so calling
            the function with no arguments behaves as before.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # A single identifier feeds both loaders so the model and tokenizer can
    # never drift apart when the checkpoint is changed.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype='auto',
        device_map='auto',
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer

Using device: cuda


In [None]:
# Policy model: the copy that is optimised with DPO later in the notebook.
# The tokenizer returned alongside it is the one reused by the following cells.
model_PI, tokenizer = create_qwen_model()
# Reference ("base") model: a second, independent load of the same checkpoint;
# the tokenizer returned with it is discarded.
model_base, _ = create_qwen_model()

In [3]:
def chat(model, tokenizer, prompt: str, max_new_tokens: int = 512):
    """Generate a single assistant reply for ``prompt``.

    The prompt is wrapped in a fixed system turn plus one user turn, rendered
    with the tokenizer's chat template, and passed to ``model.generate``.

    Args:
        model: Model exposing ``device`` and ``generate``.
        tokenizer: Tokenizer exposing ``apply_chat_template``, ``__call__`` and
            ``batch_decode``.
        prompt: The user message.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded text of the newly generated tokens only (the prompt tokens
        are stripped before decoding).
    """
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Follow the model's own placement instead of a notebook-level global so
    # the inputs always land on the device that holds the model.
    model_inputs = tokenizer([text], return_tensors='pt').to(model.device)
    # Unpack the whole encoding so the attention mask reaches generate()
    # together with the input ids.
    generated_ids = model.generate(**model_inputs, max_new_tokens=max_new_tokens)
    # generate() returns prompt + continuation; keep only the continuation.
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return response

In [4]:
# A tiny preference dataset: each record carries a prompt, the preferred
# answer ('favour') and the dispreferred answer ('rejected').
train_data = [
    dict(
        prompt='你是谁？',
        favour='我是通义千问',
        rejected='我是阿里云开发的大语言模型，我叫通义千问',
    ),
    dict(
        prompt='谁发明了你？',
        favour='阿里巴巴',
        rejected='Ian Goodfellow',
    ),
]

def dpo_to_message(dpo_pairs):
    """Expand preference pairs into chat-style conversations.

    Each pair is a mapping with the keys 'prompt', 'favour' and 'rejected'.
    For every pair two three-turn conversations (system, user, assistant) are
    built: one ending in the favoured answer, one ending in the rejected one.

    Returns:
        ``(favour_msg, reject_msg)`` - two lists of conversations, index-aligned
        with ``dpo_pairs``.
    """
    def _conversation(question, answer):
        # Fresh dicts on every call so no message object is shared between lists.
        return [
            {'role': 'system', 'content': 'You are a helpful assistant.'},
            {'role': 'user', 'content': question},
            {'role': 'assistant', 'content': answer},
        ]

    # Single pass over the input, building both variants side by side.
    both = [
        (_conversation(item['prompt'], item['favour']),
         _conversation(item['prompt'], item['rejected']))
        for item in dpo_pairs
    ]
    favour_msg = [good for good, _bad in both]
    reject_msg = [bad for _good, bad in both]
    return favour_msg, reject_msg

In [5]:
# Preprocess training data
def preprocess(tokenizer, batch_messages):
    """Tokenise chat conversations into padded input, target and mask tensors.

    Every message is serialised in the ChatML layout
    ``<|im_start|>{role}\\n{content}<|im_end|>\\n`` - the same layout for all
    roles. In the targets, the role header and (for 'system' / 'user' turns)
    the content are replaced by -100; the ``<|im_start|>``, ``<|im_end|>`` and
    trailing newline tokens are kept, as is the assistant content.

    Args:
        tokenizer: Callable mapping a string to an object whose ``input_ids``
            attribute is a list of token ids.
        batch_messages: List of conversations; each conversation is a list of
            ``{'role': ..., 'content': ...}`` dicts.

    Returns:
        ``(batch_input_idx, batch_target_idx, batch_mask)`` - three long
        tensors of identical shape. Inputs are right-padded with the
        ``<|endoftext|>`` id, targets with -100, and the mask is 1 wherever
        the input differs from the padding id.
    """
    im_start_id = tokenizer('<|im_start|>').input_ids
    im_end_id = tokenizer('<|im_end|>').input_ids
    newline_id = tokenizer('\n').input_ids
    padding_id = tokenizer('<|endoftext|>').input_ids
    ignore = [-100]

    input_ls, target_ls = [], []
    for messages in batch_messages:
        input_ids, target_ids = [], []
        for msg in messages:
            role_id = tokenizer(msg['role']).input_ids
            content_id = tokenizer(msg['content']).input_ids
            # The newline sits between role and content for every role, so
            # prompt turns are laid out exactly like the assistant turn.
            header = role_id + newline_id
            input_ids.extend(im_start_id + header + content_id + im_end_id + newline_id)
            if msg['role'] in ['system', 'user']:
                # Prompt turns: hide both the header and the content.
                hidden = ignore * (len(header) + len(content_id))
                target_ids.extend(im_start_id + hidden + im_end_id + newline_id)
            else:
                # Any other role (the assistant reply): hide only the header.
                target_ids.extend(
                    im_start_id + ignore * len(header) + content_id + im_end_id + newline_id
                )
        input_ls.append(input_ids)
        target_ls.append(target_ids)

    # Right-pad every sequence to the longest one in the batch.
    max_len = max(len(input_ids) for input_ids in input_ls)
    for input_ids, target_ids in zip(input_ls, target_ls):
        input_ids.extend(padding_id * (max_len - len(input_ids)))
        target_ids.extend(ignore * (max_len - len(target_ids)))
    batch_input_idx = torch.tensor(input_ls, dtype=torch.long)
    batch_target_idx = torch.tensor(target_ls, dtype=torch.long)
    batch_mask = batch_input_idx.ne(padding_id[0]).type(torch.long)  # padding mask
    return batch_input_idx, batch_target_idx, batch_mask


In [6]:
# Training mode for the policy that is being optimised.
model_PI.train()
# The reference model only supplies fixed log-probabilities for the DPO loss:
# freeze its parameters and keep it in inference mode so it stays constant.
model_base.requires_grad_(False)
model_base.eval()

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 2048)
    (layers): ModuleList(
      (0-35): 36 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (k_proj): Linear(in_features=2048, out_features=256, bias=True)
          (v_proj): Linear(in_features=2048, out_features=256, bias=True)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=2048, out_features=11008, bias=False)
          (up_proj): Linear(in_features=2048, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((2048,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((2048,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((2048,), eps=1e-06)
    (rotary_emb):

In [7]:
def dpo_prob_calc(target_ids, pi_logits, base_logits):
    """Per-sequence mean log-probability of the target tokens under two models.

    Positions whose target is -100 are excluded from the average.

    Args:
        target_ids: Integer tensor of target token ids; -100 marks positions
            to ignore.
        pi_logits: Policy-model logits with one more trailing (vocabulary)
            dimension than ``target_ids``.
        base_logits: Reference-model logits, same layout as ``pi_logits``.

    Returns:
        ``(pi_prob_final, base_prob_final)`` - for each model, the log-softmax
        value of the target token averaged over the non-ignored positions of
        the last ``target_ids`` dimension. A row with no valid position yields
        0 instead of NaN.
    """
    # Despite the "prob" names these are log-probabilities.
    pi_prob = torch.log_softmax(pi_logits, dim=-1)
    base_prob = torch.log_softmax(base_logits, dim=-1)

    valid = target_ids != -100
    # -100 is not a legal gather index; point ignored slots at token 0 and
    # zero their contribution afterwards via the mask.
    indexes = target_ids.masked_fill(~valid, 0).unsqueeze(-1)

    pi_prob_wrt_target = torch.gather(pi_prob, dim=-1, index=indexes).squeeze(-1) * valid
    base_prob_wrt_target = torch.gather(base_prob, dim=-1, index=indexes).squeeze(-1) * valid

    # Clamp so a fully masked row divides by 1 (result 0) rather than by 0,
    # which would produce NaN and poison the batch loss.
    n_valid = valid.sum(-1).clamp(min=1)
    pi_prob_final = pi_prob_wrt_target.sum(-1) / n_valid
    base_prob_final = base_prob_wrt_target.sum(-1) / n_valid
    return pi_prob_final, base_prob_final

def dpo_loss_calc(params, beta=.1):
    """Compute the DPO loss for a batch of chosen / rejected sequences.

    Targets are shifted left by one position and logits are truncated by one
    position so that the logits at step t are scored against the token at
    step t + 1.

    Args:
        params: Dict with the keys 'chosen_target_ids', 'reject_target_ids',
            'pi_chosen_logits', 'pi_reject_logits', 'base_chosen_logits' and
            'base_reject_logits'.
        beta: Scale applied to the policy-vs-reference margin before the
            log-sigmoid. Defaults to 0.1, the value previously hard-coded.

    Returns:
        Scalar tensor ``-mean(logsigmoid(beta * margin))``; it is non-negative
        and shrinks as the policy favours the chosen answer more strongly than
        the reference model does.
    """
    chosen_target_ids = params['chosen_target_ids'][:, 1:]
    pi_chosen_logits = params['pi_chosen_logits'][:, :-1, :]
    base_chosen_logits = params['base_chosen_logits'][:, :-1, :]
    pi_chosen_prob, base_chosen_prob = dpo_prob_calc(chosen_target_ids, pi_chosen_logits, base_chosen_logits)

    reject_target_ids = params['reject_target_ids'][:, 1:]
    pi_reject_logits = params['pi_reject_logits'][:, :-1, :]
    base_reject_logits = params['base_reject_logits'][:, :-1, :]
    pi_reject_prob, base_reject_prob = dpo_prob_calc(reject_target_ids, pi_reject_logits, base_reject_logits)

    # How much more each model prefers the chosen answer over the rejected one.
    pi_prob_diff = pi_chosen_prob - pi_reject_prob
    base_prob_diff = base_chosen_prob - base_reject_prob

    # The optimiser minimises this value, so the log-sigmoid must be negated:
    # without the minus sign gradient descent would push the policy towards
    # the rejected answers.
    loss = -torch.nn.functional.logsigmoid(beta * (pi_prob_diff - base_prob_diff)).mean()
    return loss

In [None]:
optimizer = torch.optim.SGD(model_PI.parameters(), lr=1e-5)

epochs = 10
# NOTE(review): `vocab` is not used in this cell; kept in case another cell reads it.
vocab = tokenizer.get_vocab()

# The dataset is fixed, so build the conversations, tokenise them and move the
# tensors to the device once instead of repeating the work every epoch.
chosen_msg, reject_msg = dpo_to_message(train_data)
chosen_input_ids, chosen_target_ids, chosen_mask_ids = (
    tensor.to(device) for tensor in preprocess(tokenizer, chosen_msg)
)
reject_input_ids, reject_target_ids, reject_mask_ids = (
    tensor.to(device) for tensor in preprocess(tokenizer, reject_msg)
)

for epoch in range(epochs):
    # Policy forward passes: these carry the gradients.
    pi_chosen_logits = model_PI(input_ids=chosen_input_ids, attention_mask=chosen_mask_ids).logits
    pi_reject_logits = model_PI(input_ids=reject_input_ids, attention_mask=reject_mask_ids).logits
    # Reference forward passes: the optimiser only updates model_PI, so no
    # autograd graph is needed here; skipping it saves activation memory.
    with torch.no_grad():
        base_chosen_logits = model_base(input_ids=chosen_input_ids, attention_mask=chosen_mask_ids).logits
        base_reject_logits = model_base(input_ids=reject_input_ids, attention_mask=reject_mask_ids).logits
    # calculate loss
    loss = dpo_loss_calc({
        'chosen_target_ids': chosen_target_ids,
        'reject_target_ids': reject_target_ids,
        'pi_chosen_logits': pi_chosen_logits,
        'pi_reject_logits': pi_reject_logits,
        'base_chosen_logits': base_chosen_logits,
        'base_reject_logits': base_reject_logits,
    })
    print(f'Epoch: {epoch + 1}/{epochs}, Loss: {loss.item():.3f}')
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [None]:
# Put the fine-tuned policy into inference mode before sampling from it.
model_PI.eval()
# Two prompts that appear in train_data plus one that does not.
prompts = [
    '你是谁？',
    '谁发明了你？',
    '讲讲Transformer模型',
]
for prompt in prompts:
    print(chat(model_PI, tokenizer, prompt))