In [81]:
import argparse
from typing import Optional, Union

import pandas as pd
import numpy as np
import torch
import torch.nn as nn

from dataclasses import dataclass

import datasets
from datasets import Dataset

from sklearn.metrics import log_loss

from transformers import (
    AutoTokenizer,
    AutoConfig,
    EarlyStoppingCallback,
    AutoModelForCausalLM,
    AutoModelForMultipleChoice,
    TrainingArguments,
    Trainer,
    RobertaForMultipleChoice,
    AutoModelForSequenceClassification,
    LlamaModel,
    LlamaForSequenceClassification,
    BitsAndBytesConfig,
    get_polynomial_decay_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
    TrainerCallback,
)
from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy

from peft import (
    get_peft_config,
    PeftModel,
    PeftConfig,
    get_peft_model,
    LoraConfig,
    TaskType,
    prepare_model_for_kbit_training
)
import os

from utils import *

import random
def seed_everything(seed=None):
    '''
    固定seed
    :param seed: int, 随机种子
    '''
    max_seed_value = np.iinfo(np.uint32).max
    min_seed_value = np.iinfo(np.uint32).min

    if (seed is None) or not (min_seed_value <= seed <= max_seed_value):
        seed = random.randint(np.iinfo(np.uint32).min, np.iinfo(np.uint32).max)
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    return seed


seed_everything(42)

42

In [82]:
data_path = "dataset/non_overlap/train_33k.json"
prompt_type = 3
MAX_INPUT = 1900
if_train = True
split = False
if_drop_duplicate = False
keep = 'last'
df_train , df_valid = load_split_data(data_path, prompt_type, MAX_INPUT, if_train, split, False, if_drop_duplicate, keep)
test = df_train


100%|██████████| 95028/95028 [00:50<00:00, 1882.63it/s]


In [None]:
df_train[df_train.length >= 1900 * 0.75]

In [None]:
test

In [None]:
o_data = pd.read_json(data_path)

In [83]:
test['length'] = test['prompt_response'].apply(lambda x: len(x.split(" ")))
# test = test.sort_values(by = ['length'], ascending = False).reset_index(drop = True)
# data = test[:5]
# data

In [None]:
test.loc[test.length >= 1900 * 0.75].reset_index(drop = True)

In [84]:
over = test.loc[test.over_max_length == 1].reset_index(drop = True)
over

Unnamed: 0,id,prompt_response,label,overflow_prompt,over_max_length,overflow_response_a,overflow_response_b,length
0,3254113,#Prompt\nCan you help create Dan's thoughts as...,C,Can you help create Dan's thoughts as Liz cane...,1,As Liz leads Dan to the room where he will rec...,As Liz begins to describe the details of the p...,2302
1,16748032,#Prompt\nwrite a clickable and graphic tic tac...,B,write a clickable and graphic tic tac toe prog...,1,"import graphics from pygame import image,...",Creating a full graphical and clickable Tic Ta...,1634
2,35088788,#Prompt\nplease write a nice casual and friend...,C,please write a nice casual and friendly reply ...,1,"Dear [username],\n\nThank you for sharing your...",Hey there!\n\nI couldn't help but resonate wit...,1767
3,37697874,#Prompt\nRole: You are an Enterprise AI system...,C,Role: You are an Enterprise AI system designed...,1,"AI: {\n""reply"": ""I'm just an AI, I don't have ...",Role: You are an Enterprise AI system designed...,1987
4,43053669,#Prompt\nSummarize this article:\n\nAmabel Hol...,A,"Summarize this article:\n\nAmabel Holland, a f...",1,"Amabel Holland, an autistic board-game designe...",The article highlights the rise in popularity ...,2356
...,...,...,...,...,...,...,...,...
829,8af51cda208b4b62986f62beadccfed8,#Prompt\nHey! Write a tic-tac toy game on pyth...,B,Hey! Write a tic-tac toy game on python please.,1,"Sure, here's an example of a tic-tac-toe game ...",Sure! Here's a simple implementation of a Tic ...,1659
830,0724e6e2022c4ca18bbb97937bd8ee47,#Prompt\nrefactor this code : \npublic class T...,A,refactor this code : \npublic class TennisGame...,1,"Sure, here's a refactored version of the code:...","Sure, here is an updated version of the`Tennis...",1845
831,d8ae2cc49fe4465aaee9706e6220bc38,#Prompt\nimport { Link } from 'preact-router/m...,C,import { Link } from 'preact-router/match';\n\...,1,To make the links work in your full response f...,You're looking to make the rest of the links i...,1508
832,75f89aaf9cec44509ba54780712d025e,"#Prompt\n{\n ""ai_tutor"": {\n ""Author...",C,"{\n ""ai_tutor"": {\n ""Author"": ""JushB...",1,"Your chatbot, Mr. Ranedeer, can assist you wit...",The AI tutor's author has put a lot of thought...,1789


In [None]:
over['over_prompt_response'] = "#Prompt\n" + over['overflow_prompt'] + "\n\n" + "#Response\n" + "##Model A\n" + over['overflow_response_a'] + "\n\n" + "##Model B\n" + over['overflow_response_b']

In [None]:
over.loc[over['over_prompt_response'] == over['prompt_response']]

In [None]:
over = o_data.loc[o_data.id.isin(over.id)].reset_index(drop = True)

In [None]:
print(over.loc[3,'overflow_prompt'])

In [None]:
print(over.loc[4,'prompt_response'])

In [None]:
print(test.loc[test.over_max_length == 0,'prompt_response'].values[0])

In [None]:
test.loc[test.over_max_length == 0]

In [None]:
check = test.loc[test.id == 16748032]
check

In [None]:
print(check.prompt_response.values[-1])

In [None]:
o_data.loc[o_data.id==16748032].prompt.values

In [None]:
o_data.loc[o_data.id==16748032].prompt.values[0][1]

In [None]:
o_data.loc[o_data.id==16748032].response_a.values[0][2]

In [None]:
o_data.loc[o_data.id==16748032].response_b.values[0][2]

In [None]:
o_data.loc[o_data.id=='26dc950ef0'].prompt.values[0][-1]

In [None]:
print(test.loc[test.id == '39036b3a02'].prompt_response.values[0])

In [None]:
o_data.loc[o_data.id == '39036b3a02'].prompt.values[0]

In [None]:
print(check.prompt_response.values[0])

In [None]:
check.prompt_response.values[-2]

In [None]:
test.loc[17,]

In [None]:
o_data.loc[o_data.id == '240a03e332'].prompt.values

In [None]:
device = torch.device("cuda:0")
base_model = 'google/gemma-2-9b-it'
model_path = "output/morning-waterfall-460/checkpoint-5200_888"
MAX_LENGTH = 1900

In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_path, truncation_side = 'left')
config = AutoConfig.from_pretrained(base_model, trust_remote_code=True)

In [None]:
t = check.prompt_response.values[0]
len(tokenizer(t)['input_ids'])

In [102]:
def adjust_values(A, B, a_space, b_space, ex_space):
    # 计算A和a_space的差值
    a_diff = a_space - A
    b_diff = b_space - B
    
    # 第一种情况：A小于a_space，B小于b_space
    if A < a_space and B < b_space:
        ex_space += a_diff + b_diff
        return A, B, ex_space

    # 第二种情况：如果A和B都各自大于自己的space
    elif A > a_space and B > b_space:
        total_extra_needed = (A - a_space) + (B - b_space)
        if total_extra_needed > ex_space:
            A = int(a_space + ex_space / 2)
            B = int(b_space + ex_space / 2)
            ex_space = 0
        else:
            a_space = A
            b_space = B
            ex_space -= total_extra_needed
            
        return A, B, ex_space
        
    # 第三种情况：A或者B其中有一个大于a_space, b_space
    elif A > a_space or B > b_space:
        # 如果A大于a_space但是B小于b_space
        if A > a_space and B < b_space:
            extra_needed = A - a_space
            ex_space += b_space - B
            #够用
            if ex_space >= extra_needed:
                ex_space -= extra_needed
                
            else:
                #不够用
                #b_space = B + available_space
                A = a_space + ex_space
                ex_space = 0

        # 如果B大于b_space但是A小于a_space
        elif B > b_space and A < a_space:
            extra_needed = B - b_space
            ex_space += a_space - A
            
            if ex_space >= extra_needed:
                ex_space -= extra_needed
                
            else:
                B = b_space + ex_space
                ex_space = 0

        return A, B, ex_space
    

def adjust(current_lengths, prompt_length_space=300, response_length_space=800):
    prompt_length = current_lengths[0]
    response_a_length = current_lengths[1]
    response_b_length = current_lengths[2]
    #先看prompt的额度
    ex_space = max(0, prompt_length_space - prompt_length)
    response_a_length, response_b_length, ex_space = adjust_values(response_a_length, response_b_length, response_length_space, response_length_space, ex_space)
    prompt_length = min(prompt_length, prompt_length_space)
    prompt_length += ex_space

    return prompt_length, response_a_length, response_b_length

In [107]:
from torch.utils.data import Dataset
class InstructionDataSet(Dataset):
    def __init__(self, data, tokenizer, max_source_length, max_target_length):
        super(InstructionDataSet, self).__init__()
        #self.data = data.sample(len(data), random_state=0).reset_index(drop=True)
        self.data = data
        self.tokenizer = tokenizer
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        # self.A_token = self.tokenizer.encode(text='A', add_special_tokens=False, truncation=True, )
        # self.B_token = self.tokenizer.encode(text='B', add_special_tokens=False, truncation=True, )
        # self.C_token = self.tokenizer.encode(text='C', add_special_tokens=False, truncation=True, )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        now_data = self.data.loc[index]
        over_max_length = now_data['over_max_length']
        
        templete_part1 = "<start_of_turn>user\nHere are two question-answering dialogues. Compare two model performance on answering question, determine which is better.\n\n"
        templete_part1_input_ids = self.tokenizer(text=templete_part1, add_special_tokens=True, padding=False)['input_ids']
        
        templete_part2 = "\n###options\nA. Model A\nB. Model B\nC. Tie\n<end_of_turn>\n"
        templete_part2_input_ids = self.tokenizer(text=templete_part2, add_special_tokens=True, padding=False)['input_ids'][1:]
        #print(f"templete_part2 is {templete_part2_input_ids}")
        templete_part3 = "<start_of_turn>model\n"
        templete_part3_input_ids = self.tokenizer(text=templete_part3, add_special_tokens=True, padding=False)['input_ids'][1:]
        
        templete_part4_input_ids = self.tokenizer(text="\n\n", add_special_tokens=False, padding=False)['input_ids']
        
        if over_max_length:
            prompt = "#Prompt\n" + now_data['overflow_prompt']
            r_a = "#Response\n" + "##Model A\n" + now_data['overflow_response_a']
            r_b = "##Model B\n" + now_data['overflow_response_b']
            
            prompt_ids = self.tokenizer(text=prompt, add_special_tokens=False, truncation=False, padding=False)['input_ids']
            model_a_input_ids = self.tokenizer(text=r_a, add_special_tokens=False, truncation=False, padding=False)['input_ids']
            model_b_input_ids = self.tokenizer(text=r_b, add_special_tokens=False, truncation=False, padding=False)['input_ids']

            if len(prompt_ids) + len(model_a_input_ids) + len(model_b_input_ids) <= self.max_source_length:
                prompt_response_ids = prompt_ids + model_a_input_ids + model_b_input_ids
            
            else:
                '''
                prompt 和 response 按照 300， 800， 800
                response 优先
                多的再给prompt
                '''
                length = [len(prompt_ids), len(model_a_input_ids), len(model_b_input_ids)]
                print(f"before {length}")
                prompt_max_length, a_max_length, b_max_length = adjust(length)
                prompt_ids = prompt_ids[:prompt_max_length] + templete_part4_input_ids
                model_a_input_ids = model_a_input_ids[:a_max_length] + templete_part4_input_ids
                model_b_input_ids = model_a_input_ids[:b_max_length] + templete_part4_input_ids
                print(f"after {[prompt_max_length, a_max_length, b_max_length]}")
                prompt_response_ids = prompt_ids + model_a_input_ids + model_b_input_ids
        
        else:
            prompt_response = now_data['prompt_response']
            #print(f"id is {now_data['id']}")
            #print(prompt_response)
            prompt_response_ids = self.tokenizer(text=prompt_response, add_special_tokens=True, truncation=True,
                                              max_length=self.max_source_length, padding=False)['input_ids'][1:]
            #print(prompt_response_ids)        
            
            
        label = now_data['label']
        label_ids = self.tokenizer.encode(text=label, add_special_tokens=False)
        input_ids = templete_part1_input_ids + prompt_response_ids + templete_part2_input_ids + templete_part3_input_ids + label_ids + [self.tokenizer.eos_token_id]
        labels = [-100] * (len(input_ids) - 2) + label_ids + [self.tokenizer.eos_token_id]
        #print(f"input is {self.tokenizer.decode(input_ids)}")
        return {
            "input_ids": input_ids,
            "labels": labels
        }

In [108]:
length = [500,600,600]
adjust(length)

(700, 600, 600)

In [131]:
tokenized_dataset = InstructionDataSet(over,tokenizer, 1900, 1)

In [129]:
test

Unnamed: 0,id,prompt_response,label,overflow_prompt,over_max_length,overflow_response_a,overflow_response_b,length
0,30192,#Prompt\nIs it morally right to try to have a ...,A,,0,,,890
1,53567,#Prompt\nWhat is the difference between marria...,B,,0,,,1167
2,65089,#Prompt\nexplain function calling. how would y...,C,,0,,,432
3,96401,#Prompt\nHow can I create a test set for a ver...,A,,0,,,819
4,370945,"#Prompt\n""Bacteria is life on Mars but a heart...",B,,0,,,144
...,...,...,...,...,...,...,...,...
79697,8777c4945d85469d96cd26fc2ea6f64a,#Prompt\nwho is the president of the U.S.A?\n\...,C,,0,,,39
79698,86063a921be548989c55b85497ab009a,#Prompt\nhow to train lora for stable diffusio...,A,,0,,,698
79699,6685a3b3863f4554887e432f7dbbe8a5,#Prompt\n남녀 섹스 체위 자세 10가지를 적어줘\n\n#Response\n#...,B,,0,,,79
79700,f72930b382e949ea879e7abf3cb1e587,#Prompt\nhow to evaluate a language model outp...,A,,0,,,545


In [132]:
print(len(tokenized_dataset[0]['input_ids']))
print(tokenizer.decode(tokenized_dataset[0]['input_ids']))

before [2204, 373, 282]
after [1245, 373, 282]
1955
before [2204, 373, 282]
after [1245, 373, 282]
<bos><start_of_turn>user
Here are two question-answering dialogues. Compare two model performance on answering question, determine which is better.

#Prompt
Can you help create Dan's thoughts as Liz canes him? 
Although she has warned him that her role as Justice of the Peace required her to deliver the cane strokes iof his sentence with severity (and that someone at the Ministry would review the video of her administration of the caning to ensure she did, so she could not go easy on him),  Dan is surprised at the harshness with which Liz delivers his caning. She  is compelled by her professional duties to adhere strictly to the protocol, ensuring that the punishment is carried out to the letter to avoid reprimand or disciplinary action from her own supervisors, Dan somehow did not expect it to be as harsh as it was.   The fact that the punishment is being recorded and will be reviewed ad