In [None]:
import argparse
from typing import Optional, Union

import pandas as pd
import numpy as np
import torch
import torch.nn as nn

from dataclasses import dataclass

import datasets
from datasets import Dataset
from torch.utils.data import Dataset, DataLoader

from sklearn.metrics import log_loss

from transformers import (
    AutoTokenizer,
    AutoConfig,
    EarlyStoppingCallback,
    AutoModelForCausalLM,
    AutoModelForMultipleChoice,
    TrainingArguments,
    Trainer,
    RobertaForMultipleChoice,
    AutoModelForSequenceClassification,
    LlamaModel,
    LlamaForSequenceClassification,
    BitsAndBytesConfig,
    get_polynomial_decay_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
    TrainerCallback,
)
from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy

from peft import (
    get_peft_config,
    PeftModel,
    PeftConfig,
    get_peft_model,
    LoraConfig,
    TaskType,
)
import os

import random
from random import randint
def seed_everything(seed=None):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    :param seed: int or None. When None, or outside the uint32 range, a random
        seed inside that range is drawn instead.
    :return: the seed that was actually applied.
    """
    uint32_info = np.iinfo(np.uint32)
    lowest, highest = uint32_info.min, uint32_info.max

    seed_is_usable = (seed is not None) and (lowest <= seed <= highest)
    if not seed_is_usable:
        seed = random.randint(lowest, highest)

    # Python-level randomness.
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)

    # NumPy and PyTorch (CPU, current GPU, all GPUs).
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Disable cuDNN autotuning and request deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    return seed


# Fix every RNG up front so the rest of the notebook is reproducible.
seed_everything(42)

from utils import load_split_data, load_json

In [None]:
from torch.utils.data import Dataset
class InstructionDataSet(Dataset):
    """Torch dataset that wraps each row's ``prompt_response`` in the chat prompt template.

    Each item is a dict with:
      * ``"input_ids"``: despite the key name, this is the *decoded prompt text*
        (template head + truncated prompt/response + options + model turn).
        ``collate_fn`` tokenizes it again when building a batch.
      * ``"id"``: the row's ``id`` value.

    :param data: DataFrame with at least the columns ``id`` and ``prompt_response``.
    :param tokenizer: callable tokenizer returning a mapping with ``'input_ids'`` and
        exposing ``decode``.
    :param max_source_length: maximum token count kept for ``prompt_response``.
    :param max_target_length: stored but not used by the code in this class.
    """

    # Fixed pieces of the prompt; reproduced byte-for-byte from the original template.
    _TEMPLATE_HEAD = "<start_of_turn>user\nHere are two question-answering dialogues. Compare two model performance on answering question, determine which is better.\n\n"
    _TEMPLATE_OPTIONS = "\n###options\nA. Model A\nB. Model B\nC. Tie\n<end_of_turn>\n"
    _TEMPLATE_MODEL_TURN = "<start_of_turn>model\n"

    def __init__(self, data, tokenizer, max_source_length, max_target_length):
        super(InstructionDataSet, self).__init__()
        self.data = data
        self.tokenizer = tokenizer
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length

        # The template pieces are identical for every sample, so tokenize them once
        # here instead of on every __getitem__ call.
        # The head keeps its first token; every later piece drops it with [1:]
        # (presumably the BOS token added by add_special_tokens=True - TODO confirm).
        self._head_ids = self.tokenizer(
            text=self._TEMPLATE_HEAD, add_special_tokens=True, padding=False
        )['input_ids']
        self._options_ids = self.tokenizer(
            text=self._TEMPLATE_OPTIONS, add_special_tokens=True, padding=False
        )['input_ids'][1:]
        self._model_turn_ids = self.tokenizer(
            text=self._TEMPLATE_MODEL_TURN, add_special_tokens=True, padding=False
        )['input_ids'][1:]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # DataLoader samplers hand out positions 0..len-1, so look rows up by
        # position. A label lookup (.loc) breaks as soon as the frame's index is
        # not a fresh RangeIndex (e.g. after filtering or sorting without reset_index).
        row = self.data.iloc[index]

        # Truncate only the variable part; which side is cut is decided by the
        # tokenizer's own truncation settings.
        body_ids = self.tokenizer(
            text=row['prompt_response'],
            add_special_tokens=True,
            truncation=True,
            max_length=self.max_source_length,
            padding=False,
        )['input_ids'][1:]

        input_ids = self._head_ids + body_ids + self._options_ids + self._model_turn_ids
        # Decode back to text (special tokens kept) because batching re-tokenizes strings.
        input_text = self.tokenizer.decode(input_ids, skip_special_tokens=False)
        return {
            "input_ids": input_text,
            "id": row['id'],
        }

from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union

def collate_fn(batch):
    """Turn a list of dataset items into one padded, tokenized batch.

    Relies on the module-level ``tokenizer`` and ``MAX_LENGTH`` being defined
    before the DataLoader starts iterating.

    :param batch: list of dicts carrying ``'input_ids'`` (prompt text) and ``'id'``.
    :return: ``(tokenizer output as "pt" tensors, list of ids)``.
    """
    prompt_texts = [sample['input_ids'] for sample in batch]
    sample_ids = [sample['id'] for sample in batch]

    # Pad to the longest prompt in the batch; the +50 leaves room for the
    # template tokens that surround the truncated prompt/response text.
    encoded = tokenizer(
        prompt_texts,
        add_special_tokens=True,
        truncation=True,
        max_length=MAX_LENGTH + 50,
        padding='longest',
        return_tensors="pt",
    )
    return encoded, sample_ids

In [None]:
from utils import load_split_data
# Settings for loading the unlabeled pool that will be pseudo-labeled.
data_path = "dataset/1M/35k_in_1M.json"
prompt_type = 3
MAX_INPUT = 1900
if_train = False
split = False
if_drop_duplicate = False
keep = 'last'

# Arguments are positional; their meaning is defined by utils.load_split_data.
df_train, df_valid = load_split_data(
    data_path,
    prompt_type,
    MAX_INPUT,
    if_train,
    split,
    False,
    if_drop_duplicate,
    keep,
)
# The first returned frame is what gets scored below.
test = df_train

In [None]:
# Raw JSON as stored on disk, loaded only for the ad-hoc inspection in the next cell.
tmp = pd.read_json(data_path)

In [None]:
# Spot check: number of raw rows whose id equals this one hard-coded value.
len(tmp.loc[tmp.id == '4af73ffd64'].prompt)

In [None]:
# Character length of each prompt/response text (assumes the column holds strings - TODO confirm).
# NOTE: `test` is the same object as `df_train`, so this adds the column to both names.
test['length'] = test['prompt_response'].apply(len)

In [None]:
# Longest samples first, with a fresh 0..n-1 index.
# Batches then contain similarly sized prompts, which limits padding.
test = test.sort_values(by = ['length'], ascending = False).reset_index(drop = True)
test

In [None]:
from tqdm import tqdm
def inference(model, test_dataloader):
    """Score every batch and return per-sample probabilities over the options A / B / C.

    Relies on module-level ``device``, ``A_TOKEN_IDS``, ``B_TOKEN_IDS`` and
    ``C_TOKEN_IDS``.

    :param model: object exposing ``generate`` (called with ``max_new_tokens=1``).
    :param test_dataloader: iterable yielding ``(batch_input, ids)`` pairs.
    :return: list of ``[id, probs]`` pairs, one per sample, where ``probs`` is a
        NumPy row produced by a softmax over the selected option logits.
    """
    collected = []
    for batch_input, idx in tqdm(test_dataloader):
        # Move every tensor of the tokenizer output onto the target device.
        for key in batch_input.keys():
            batch_input[key] = batch_input[key].to(device)

        with torch.no_grad():
            # Generate a single token and keep its score vector.
            generated = model.generate(
                **batch_input,
                max_new_tokens=1,
                return_dict_in_generate=True,
                output_scores=True,
            )
            first_step_scores = generated.scores[0]
            # Keep only the columns of the three answer tokens, then renormalize.
            option_logits = torch.cat(
                [
                    first_step_scores[:, A_TOKEN_IDS],
                    first_step_scores[:, B_TOKEN_IDS],
                    first_step_scores[:, C_TOKEN_IDS],
                ],
                dim=-1,
            )
            option_probs = torch.softmax(option_logits, dim=-1).cpu().numpy()

        collected.extend([idx[i], option_probs[i]] for i in range(len(idx)))
    return collected

In [None]:
# Model / runtime configuration.
device = torch.device("cuda:0")
base_model = 'google/gemma-2-9b-it'
# Directory used below for both the tokenizer and the PEFT adapter weights.
model_path = "output/morning-waterfall-460/checkpoint-5200_888"
# Token budget for the prompt/response text; same value as MAX_INPUT above.
MAX_LENGTH = 1900

In [None]:
# Tokenizer comes from the checkpoint directory; over-long inputs are cut from the left.
tokenizer = AutoTokenizer.from_pretrained(model_path, truncation_side = 'left')
# NOTE(review): trust_remote_code=True executes code shipped with the model repo - keep only if required.
config = AutoConfig.from_pretrained(base_model, trust_remote_code=True)

# 4-bit NF4 quantization with double quantization; matmuls are computed in fp16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,  
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True
)
# Quantized base model; device placement is delegated to device_map="auto".
base_model_0 = AutoModelForCausalLM.from_pretrained(base_model,
                                                 config=config,
                                                 quantization_config=bnb_config,
                                                 torch_dtype=torch.float16,
                                                 device_map="auto",
                                                 trust_remote_code=True)
# base_model_0.config.pad_token_id = tokenizer.pad_token_id
# base_model_0.resize_token_embeddings(len(tokenizer))
new_model = model_path
# Attach the adapter stored in the checkpoint directory.
# NOTE(review): calling .to(device) on a 4-bit model already placed by device_map="auto"
# may be redundant or unsupported depending on library versions - confirm.
model0 = PeftModel.from_pretrained(base_model_0, new_model).to(device)
#model0 = model0.merge_and_unload()
model0.eval()

In [None]:
# Token ids of the answer letters. [1:] drops the first id of each encoding
# (presumably the BOS token added by add_special_tokens=True - TODO confirm).
# Each value is a list, so `score[:, X_TOKEN_IDS]` in `inference` selects one column per id.
A_TOKEN_IDS = tokenizer('A',add_special_tokens=True, truncation=True, max_length=1024)['input_ids'][1:]
B_TOKEN_IDS = tokenizer('B',add_special_tokens=True, truncation=True, max_length=1024)['input_ids'][1:]
C_TOKEN_IDS = tokenizer('C',add_special_tokens=True, truncation=True, max_length=1024)['input_ids'][1:]


In [None]:
batch_size = 4
# The last argument (max_target_length=1) is stored by the dataset but not used by it.
tokenized_dataset = InstructionDataSet(test, tokenizer, MAX_LENGTH, 1)

# No shuffle argument is passed, so the length-sorted order of `test` is kept.
test_dataloader = torch.utils.data.DataLoader(tokenized_dataset, batch_size = batch_size ,collate_fn=collate_fn)

In [None]:
# Long-running cell: one generate() call per batch. Result is a list of [id, probs] pairs.
sub_pred = inference(model = model0, test_dataloader = test_dataloader)

In [None]:
# Peek at the id stored in the first prediction pair.
sub_pred[0][0]

In [None]:
# Flatten the inference output into rows of [id, p_a, p_b, p_tie].
# `inference` always returns one flat [id, probs] pair per sample, whatever the
# batch size, so a single code path covers every `batch_size`. The former
# `batch_size == 1` branch indexed into the id itself (`item[0][0].item()`) and
# would crash; it also shadowed the builtin `id`.
processed_data = [[sample_id] + probs.tolist() for sample_id, probs in sub_pred]

new_columns = ['id', 'winner_model_a', 'winner_model_b', 'winner_tie']
df = pd.DataFrame(processed_data, columns=new_columns)
# Average the probabilities when the same id was scored more than once.
df = df.groupby('id').mean().reset_index()

# Probability matrix (rows follow `df`, i.e. sorted by id after the groupby).
prediction = np.array(df[new_columns[1:]])
# Bring `test` into the same one-row-per-id, id-sorted layout.
test = test.drop_duplicates(subset = ['id']).reset_index(drop = True)
test = test.sort_values(by = ['id']).reset_index(drop = True)

In [None]:
# Reload the raw source rows so predictions can be attached to the original columns.
data = pd.read_json(data_path)

In [None]:
# Attach the averaged probabilities to every source row by id.
final = data.merge(df, how = 'left', on = 'id')
# Row counts must match on all three frames: the merge must not multiply rows,
# and `df` must hold exactly one row per source row.
assert len(final) == len(data) == len(df)
final

In [None]:
# Display which source file the predictions above belong to.
data_path

In [None]:
# Persist the source rows together with their pseudo-label probabilities.
# `index=False` was removed: pandas documents it as unsupported for the default
# orient ('columns'); older versions raise ValueError and newer ones ignore it,
# so the written file is unchanged wherever the original call worked.
final.to_json("dataset/persudo_label/35k_in_1M_prediction.json")

# 合并

In [None]:
# p: model predictions (CSV); ex: the extra labeled data they are compared against.
p = pd.read_csv("dataset/prediction.csv")
ex = pd.read_json("dataset/ex70k.json")

In [None]:
# Prefix the predicted probabilities so they do not clash with the label columns in `ex`.
p = p.rename(columns = {'winner_model_a':"p_winner_model_a", 'winner_model_b':"p_winner_model_b",  'winner_tie':"p_winner_tie"})
# Side-by-side concat aligns on the row index, not on id.
# NOTE(review): assumes `ex` and `p` have the same rows in the same order - confirm.
final = pd.concat([ex, p], axis = 1)
# Drop every column named 'id'; new ids are assigned further down.
final = final.drop(columns= ['id'])

In [None]:
def get_p_label(row):
    """Return the index (0 = model A, 1 = model B, 2 = tie) of the highest predicted probability.

    :param row: object exposing ``p_winner_model_a``, ``p_winner_model_b`` and
        ``p_winner_tie`` as attributes.
    :return: int position of the first maximum; ties resolve to the lowest index.
    """
    probabilities = (row.p_winner_model_a, row.p_winner_model_b, row.p_winner_tie)

    best_position = 0
    best_value = probabilities[0]
    for position in (1, 2):
        # Strictly greater, so an earlier option wins when values are equal.
        if probabilities[position] > best_value:
            best_position = position
            best_value = probabilities[position]
    return best_position

# Predicted class per row: 0 = model A, 1 = model B, 2 = tie.
final['p_label'] = final.apply(get_p_label, axis = 1)

def get_label(row):
    """Return the index of the last one-hot label column that equals 1.

    :param row: mapping-like object indexed by ``'winner_model_a'``,
        ``'winner_model_b'`` and ``'winner_tie'``.
    :return: 0, 1 or 2.
    :raises IndexError: when none of the three columns equals 1.
    """
    option_columns = ('winner_model_a', 'winner_model_b', 'winner_tie')
    matches = []
    for position, column in enumerate(option_columns):
        if row[column] == 1:
            matches.append(position)
    # The last hit wins if more than one column is set.
    return matches[-1]

# Ground-truth class per row, using the same 0 / 1 / 2 encoding as `p_label`.
final['label'] = final.apply(get_label, axis = 1)

In [None]:
# Rows where the prediction agrees with the label, kept only when the model is
# confident: at least one class probability must reach threshold1.
threshold1 = 0.9
filter_same = final[final['p_label'] == final['label']].reset_index(drop=True)
filter_list = (
    (filter_same['p_winner_model_a'] >= threshold1)
    | (filter_same['p_winner_model_b'] >= threshold1)
    | (filter_same['p_winner_tie'] >= threshold1)
)
filter_same = filter_same[filter_list].reset_index(drop=True)

In [None]:
# Keep rows whose `difference` value is at least 1, plus every labeled tie.
# NOTE(review): `difference` is presumably a column coming from ex70k.json - confirm it exists.
filter_list = (filter_same.difference >= 1) | (filter_same.winner_tie == 1)
filter_same = filter_same.loc[filter_list,:].reset_index(drop = True)

In [None]:
# Rows where the prediction disagrees with the label, kept only when at least
# one class probability reaches threshold2.
threshold2 = 0.6
filter_dif = final[final['p_label'] != final['label']].reset_index(drop=True)
filter_list = (
    (filter_dif['p_winner_model_a'] >= threshold2)
    | (filter_dif['p_winner_model_b'] >= threshold2)
    | (filter_dif['p_winner_tie'] >= threshold2)
)
filter_dif = filter_dif[filter_list].reset_index(drop=True)

In [None]:
# Same secondary filter as for filter_same: `difference` >= 1, or a labeled tie.
# NOTE(review): `difference` is presumably a column coming from ex70k.json - confirm it exists.
filter_list = (filter_dif.difference >= 1) | (filter_dif.winner_tie == 1)
filter_dif = filter_dif.loc[filter_list,:].reset_index(drop = True)


In [None]:
# Inspect the rows where the model disagrees with the label.
filter_dif

In [None]:
# Inspect the rows where the model agrees with the label.
filter_same

In [None]:
# Assign fresh synthetic integer ids (the original 'id' columns were dropped above).
# `randint(...) + i` could produce the same id twice; sampling without replacement
# from one shared range makes the ids unique within each frame and across both.
# NOTE(review): uniqueness against ids used in other datasets is still not checked.
n_dif, n_same = len(filter_dif), len(filter_same)
new_ids = random.sample(range(100, 1_000_000 + n_dif + n_same), n_dif + n_same)
filter_dif['id'] = new_ids[:n_dif]
filter_same['id'] = new_ids[n_dif:]

In [None]:
# Columns written to the filtered pseudo-label files.
save_columns = ['prompt', 'model_a', 'model_b', 'winner_model_a', 'winner_model_b', 'winner_tie', 'response_a', 'response_b', 'id']
# File names embed the threshold as a percentage (e.g. thr60 / thr90).
# `index=False` was removed: pandas documents it as unsupported for the default
# orient ('columns'); older versions raise ValueError and newer ones ignore it,
# so the written files are unchanged wherever the original calls worked.
filter_dif[save_columns].to_json(f'dataset/70k_dif_thr{int(threshold2 * 100)}.json')
filter_same[save_columns].to_json(f'dataset/70k_same_thr{int(threshold1 * 100)}.json')

# 检查是否与valid重复

In [None]:
# Check: reload the validation split and the two filtered files from disk.
# NOTE: the file names hard-code thr60 / thr90 and must match threshold2 / threshold1 above.

valid = pd.read_json("dataset/non_overlap/valid.json")
filter_dif = pd.read_json("dataset/70k_dif_thr60.json")
filter_same = pd.read_json("dataset/70k_same_thr90.json")

In [None]:
def get_set_prompt_response(data):
    """Add a ``set_prompt_response`` column holding ``set(prompt + response_a + response_b)``.

    The three fields are concatenated with ``+`` before being turned into a set,
    so lists give a set of their elements and strings give a set of characters.

    :param data: DataFrame with ``prompt``, ``response_a`` and ``response_b`` columns.
    :return: the same DataFrame object, modified in place.
    """
    data['set_prompt_response'] = [
        set(record.prompt + record.response_a + record.response_b)
        for record in data.itertuples()
    ]
    return data

In [None]:
# Add the comparison key column to all three frames (each frame is modified in place).
valid = get_set_prompt_response(valid)
filter_dif = get_set_prompt_response(filter_dif)
filter_same = get_set_prompt_response(filter_same)

In [None]:
# The validation split must not overlap with either filtered set.
# Each validation key is tested for membership in the other frame's key array,
# so the cost grows with len(valid) * len(other).
assert len([idx for idx, i in enumerate(valid.set_prompt_response.values) if i in filter_dif.set_prompt_response.values]) == 0
assert len([idx for idx, i in enumerate(valid.set_prompt_response.values) if i in filter_same.set_prompt_response.values]) == 0

In [None]:
# Number of validation rows whose key also appears in filter_dif (0 when the assert above passes).
len([idx for idx, i in enumerate(valid.set_prompt_response.values) if i in filter_dif.set_prompt_response.values])