In [15]:
import numpy as np
import pandas as pd
from datasets import load_dataset
from scipy.special import softmax
from collections import Counter

# Helper functions

In [2]:
def convert_hard_to_soft_labels(row, smoothing=0.03, num_classes=2):
    """Turn a row's integer class label into a label-smoothed probability vector.

    Args:
        row: mapping with a ``"hard_labels"`` entry, e.g. a DataFrame row passed
            by ``df.apply(convert_hard_to_soft_labels, axis=1)``. The label may
            arrive as an int or an integral float (rows can be upcast).
        smoothing: total probability mass spread uniformly over all classes.
        num_classes: number of classes (2 = model_a / model_b).

    Returns:
        np.ndarray of shape (num_classes,) summing to 1; with the defaults,
        label 1 -> [0.015, 0.985].

    Raises:
        ValueError: if the label is missing (NaN) or outside [0, num_classes).
    """
    label = row["hard_labels"]
    if pd.isna(label):
        # Typically a winner value that the {"model_a": 0, "model_b": 1} map did not cover.
        raise ValueError("hard_labels is missing (unmapped winner?)")
    # np.eye(...)[1.0] raises IndexError, so normalise integral floats to int.
    label = int(label)
    # A negative index would silently select a row from the end of np.eye.
    if not 0 <= label < num_classes:
        raise ValueError(f"hard_labels={label} outside [0, {num_classes})")
    one_hot = np.eye(num_classes)[label]
    return one_hot * (1 - smoothing) + smoothing / num_classes

# Convert stage 1 data

In [3]:
# read stage 1 data
df = pd.read_parquet("../data/train_combined_stage_1.parquet")
df["hard_labels"] = df["winner"].map({"model_a": 0, "model_b": 1})
# Winners outside {model_a, model_b} (e.g. ties) map to NaN and would only fail
# later inside convert_hard_to_soft_labels; fail fast and show the offending values.
unmapped_winners = df.loc[df["hard_labels"].isna(), "winner"].unique()
assert len(unmapped_winners) == 0, f"unmapped winner values: {unmapped_winners}"
print(df.shape)
df.head(2)

(106064, 11)


Unnamed: 0,id,prompt,response_a,response_b,winner,model_a,model_b,language,split,source,hard_labels
0,58210e39b3fd4441a2bd4a518bb44c2d,What is the difference between OpenCL and CUDA?,OpenCL and CUDA are two different programming ...,OpenCL and CUDA are both programming languages...,model_b,chatglm-6b,koala-13b,English,train,lmsys_33k,1
1,90bfd142157948aba01931726c888e7f,"Fuji vs. Nikon, which is better?",Both Fuji and Nikon are popular camera brands ...,This is a subjective question and the answer d...,model_b,koala-13b,oasst-pythia-12b,English,train,lmsys_33k,1


In [4]:
df["soft_labels"] = df.apply(convert_hard_to_soft_labels, axis=1)
# Dummy "logits" column so every source shares one schema. Series.copy(deep=True)
# does not copy the ndarray objects stored in an object column (only references),
# so copy each array explicitly to keep the two columns independent.
# NOTE(review): for these hard-labelled rows the column holds probabilities, not
# logits (softmax of it != soft_labels); downstream code should not treat it as logits.
df["logits"] = df["soft_labels"].apply(np.copy)
df.head(5)

Unnamed: 0,id,prompt,response_a,response_b,winner,model_a,model_b,language,split,source,hard_labels,soft_labels,logits
0,58210e39b3fd4441a2bd4a518bb44c2d,What is the difference between OpenCL and CUDA?,OpenCL and CUDA are two different programming ...,OpenCL and CUDA are both programming languages...,model_b,chatglm-6b,koala-13b,English,train,lmsys_33k,1,"[0.015, 0.985]","[0.015, 0.985]"
1,90bfd142157948aba01931726c888e7f,"Fuji vs. Nikon, which is better?",Both Fuji and Nikon are popular camera brands ...,This is a subjective question and the answer d...,model_b,koala-13b,oasst-pythia-12b,English,train,lmsys_33k,1,"[0.015, 0.985]","[0.015, 0.985]"
2,a7c5accc53e649a3bc6b2e41d962ebc4,How to build an arena for chatbots?,Building an arena for chatbots can be done by ...,Building an arena for chatbots is a great way ...,model_b,vicuna-13b,oasst-pythia-12b,English,train,lmsys_33k,1,"[0.015, 0.985]","[0.015, 0.985]"
3,adf27e819a3c494cb6e993f0c660e097,When is it today?,"I'm sorry, I cannot determine the current date...","Today is February 23, 2023.",model_a,vicuna-13b,koala-13b,English,train,lmsys_33k,0,"[0.985, 0.015]","[0.985, 0.015]"
4,c0fc42c6f5f14f2aa5a89f71f8553730,Count from 1 to 10 with step = 3,"1, 4, 7, 10\n\nCounting with a step of 3 means...","1, 4, 7, 10",model_a,vicuna-13b,koala-13b,English,train,lmsys_33k,0,"[0.985, 0.015]","[0.985, 0.015]"


# ORPO DPO Mix dataset

In [5]:
# ORPO-DPO mix from the Hugging Face Hub: one chosen/rejected response pair per prompt.
ds = load_dataset(path="mlabonne/orpo-dpo-mix-40k-flat", split="train")
ds

Found cached dataset parquet (/Users/atharva/.cache/huggingface/datasets/mlabonne___parquet/mlabonne--orpo-dpo-mix-40k-flat-5250468d649a27bf/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)


Dataset({
    features: ['source', 'chosen', 'rejected', 'prompt', 'system', 'question'],
    num_rows: 44245
})

In [6]:
# Convert the DPO mix into the pairwise format. The chosen response goes into
# slot A or B at random so the label is balanced across positions. A fixed seed
# keeps that assignment, and therefore the saved file, reproducible.
# NOTE(review): some pairs have identical chosen/rejected text (e.g. the
# "Albert Einstein" row in this cell's output) yet still get a confident label;
# consider filtering them out.
DPO_SWAP_SEED = 42
swap_rng = np.random.default_rng(DPO_SWAP_SEED)

dpo_mix_records = []
for i, d in enumerate(ds):
    should_swap = swap_rng.random() > 0.5
    if should_swap:
        response_a, response_b, winner = d["rejected"], d["chosen"], "model_b"
    else:
        response_a, response_b, winner = d["chosen"], d["rejected"], "model_a"
    dpo_mix_records.append({
        "id": str(i),
        "prompt": d["prompt"],
        "response_a": response_a,
        "response_b": response_b,
        "winner": winner,
        "model_a": "unknown",
        "model_b": "unknown",
        "language": "English",
        "split": "train",
        "source": "orpo_dpo_mix_40k",
    })

dpo_mix_df = pd.DataFrame(dpo_mix_records)
dpo_mix_df["hard_labels"] = dpo_mix_df["winner"].map({"model_a": 0, "model_b": 1})
dpo_mix_df["soft_labels"] = dpo_mix_df.apply(convert_hard_to_soft_labels, axis=1)
# Dummy logits column for schema consistency. Copy each array because
# Series.copy(deep=True) would leave both columns referencing the same ndarrays.
dpo_mix_df["logits"] = dpo_mix_df["soft_labels"].apply(np.copy)
dpo_mix_df

Unnamed: 0,id,prompt,response_a,response_b,winner,model_a,model_b,language,split,source,hard_labels,soft_labels,logits
0,0,"The setting is an otherworldly, yet eerily fam...","As you step onto the teleportation platform, t...","As you step onto the teleportation platform, t...",model_a,unknown,unknown,English,train,orpo_dpo_mix_40k,0,"[0.985, 0.015]","[0.985, 0.015]"
1,1,How many colors are traditionally recognized i...,"Traditionally, a visible spectrum or optical r...","Traditionally, a visible spectrum or optical r...",model_b,unknown,unknown,English,train,orpo_dpo_mix_40k,1,"[0.015, 0.985]","[0.015, 0.985]"
2,2,"In a basket, there are 20 oranges, 60 apples, ...",First let's start by identifying the number of...,First let's start by identifying the number of...,model_b,unknown,unknown,English,train,orpo_dpo_mix_40k,1,"[0.015, 0.985]","[0.015, 0.985]"
3,3,Which famous physicist developed the theory of...,Albert Einstein,Albert Einstein,model_b,unknown,unknown,English,train,orpo_dpo_mix_40k,1,"[0.015, 0.985]","[0.015, 0.985]"
4,4,Find 40 * 865. Exclude words; show only the math.,"34,600\n\n40 x 865 = 34,600 \n\nNote: The resp...",40 * 865 = 34600,model_b,unknown,unknown,English,train,orpo_dpo_mix_40k,1,"[0.015, 0.985]","[0.015, 0.985]"
...,...,...,...,...,...,...,...,...,...,...,...,...,...
44240,44240,Is it true that Mount Everest is the tallest m...,While Mount Everest is the highest peak above ...,"Yes, it is true that Mount Everest is the tall...",model_a,unknown,unknown,English,train,orpo_dpo_mix_40k,0,"[0.985, 0.015]","[0.985, 0.015]"
44241,44241,Can you physically manipulate objects in your ...,"Yes, I can physically manipulate objects in my...","No, I cannot physically manipulate objects in ...",model_b,unknown,unknown,English,train,orpo_dpo_mix_40k,1,"[0.015, 0.985]","[0.015, 0.985]"
44242,44242,How long would it take you to drive to the nea...,It would take me approximately 15 minutes to d...,"I'm an AI and don't have a physical presence, ...",model_b,unknown,unknown,English,train,orpo_dpo_mix_40k,1,"[0.015, 0.985]","[0.015, 0.985]"
44243,44243,Does the color of nasal secretion indicate the...,"No, the color of nasal secretion does not nece...","No, the color of nasal secretion or sputum (cl...",model_b,unknown,unknown,English,train,orpo_dpo_mix_40k,1,"[0.015, 0.985]","[0.015, 0.985]"


# LMSYS 125k pseudo-labelled ensemble (gemma2 + llama3, stage 1)

In [7]:
def load_lmsys_data(path):
    """Read one LMSYS pseudo-label parquet file without its ``winner`` column.

    The winner is re-derived later from the ensembled logits, so the
    per-file value is discarded here.
    """
    frame = pd.read_parquet(path)
    return frame.drop(columns="winner")

In [8]:
# Pseudo-label files: a plain and a `_tta` variant for each of gemma2 and llama3.
lmsys_pseudo_label_paths = (
    "../data/lmsys1m_125k_pseudo_label_gemma2stage1.parquet",
    "../data/lmsys1m_125k_pseudo_label_gemma2stage1_tta.parquet",
    "../data/lmsys1m_125k_pseudo_label_llama3stage1.parquet",
    "../data/lmsys1m_125k_pseudo_label_llama3stage1_tta.parquet",
)
lmsys_gemma2_df, lmsys_gemma2_tta_df, lmsys_llama3_df, lmsys_llama3_tta_df = (
    load_lmsys_data(p) for p in lmsys_pseudo_label_paths
)
for loaded in (lmsys_gemma2_df, lmsys_llama3_df, lmsys_gemma2_tta_df, lmsys_llama3_tta_df):
    print(loaded.shape)

(125878, 8)
(125878, 8)
(125878, 8)
(125878, 8)


In [9]:
# Blend each model's plain and `_tta` pseudo-label logits. The two files are
# joined on every non-logit column. An inner merge silently drops (or duplicates)
# rows that do not match one-to-one, so assert that the row count survives.
cols_to_merge = [col for col in lmsys_gemma2_tta_df.columns if col != "logits"]

gemma_df = pd.merge(lmsys_gemma2_df, lmsys_gemma2_tta_df, on=cols_to_merge, suffixes=("", "_tta"))
assert len(gemma_df) == len(lmsys_gemma2_df), "gemma2 plain/TTA merge changed the row count"
w = 0.6  # weight of the plain gemma2 logits; 1 - w goes to the TTA logits
# Stack via .tolist() into a 2-D float array. np.array(series) would build a
# 1-D object array of per-row arrays and depend on element-wise object arithmetic.
gemma_df["logits_gemma2"] = (
    w * np.array(gemma_df["logits"].tolist()) + (1 - w) * np.array(gemma_df["logits_tta"].tolist())
).tolist()

llama_df = pd.merge(lmsys_llama3_df, lmsys_llama3_tta_df, on=cols_to_merge, suffixes=("", "_tta"))
assert len(llama_df) == len(lmsys_llama3_df), "llama3 plain/TTA merge changed the row count"
w = 0.5  # weight of the plain llama3 logits; 1 - w goes to the TTA logits
llama_df["logits_llama3"] = (
    w * np.array(llama_df["logits"].tolist()) + (1 - w) * np.array(llama_df["logits_tta"].tolist())
).tolist()
display(gemma_df.head(2))
display(llama_df.head(2))

Unnamed: 0,id,prompt,response_a,response_b,model_a,model_b,language,logits,logits_tta,logits_gemma2
0,d8b8a298a65a4169963967da44c5b05b,"crie uma programa em html, js, cs, que solici...",Aqui está um exemplo de como você pode criar u...,"Aqui está um programa em HTML, JavaScript e CS...",meta-llama/Meta-Llama-3-8B-Instruct,mistralai/Mistral-Nemo-Instruct-2407,Portuguese,"[0.26953125, 0.78515625]","[0.353515625, 0.84765625]","[0.303125, 0.81015625]"
1,6dfbc4eeb6624cdd929cabeeb6f99d26,Erzähle mir eine lustige aber lehrreiche Kurzg...,### Eine lustige und lehrreiche Kurzgeschichte...,"Es war einmal ein junges, sprechendes Ferkel n...",Qwen/Qwen2.5-3B-Instruct,vicuna-13b,German,"[0.74609375, 0.6796875]","[0.25390625, 1.046875]","[0.54921875, 0.8265625]"


Unnamed: 0,id,prompt,response_a,response_b,model_a,model_b,language,logits,logits_tta,logits_llama3
0,d8b8a298a65a4169963967da44c5b05b,"crie uma programa em html, js, cs, que solici...",Aqui está um exemplo de como você pode criar u...,"Aqui está um programa em HTML, JavaScript e CS...",meta-llama/Meta-Llama-3-8B-Instruct,mistralai/Mistral-Nemo-Instruct-2407,Portuguese,"[0.41796875, 0.67578125]","[0.55859375, 0.482421875]","[0.48828125, 0.5791015625]"
1,c65335723f7f452187c3a1db28017e41,Gere uma lista com 200 frases mais utilizadas ...,Aqui está a lista de 200 frases mais utilizada...,"1. ""Hello, welcome to [Airport Name]!""\n ""Ol...",meta-llama/Meta-Llama-3-8B-Instruct,mistralai/Mistral-Nemo-Instruct-2407,Portuguese,"[1.1875, 0.1572265625]","[1.1171875, 0.07568359375]","[1.15234375, 0.116455078125]"


In [10]:
# Columns shared by both per-model frames (everything except the logits); used as the join key.
common_columns = ['id', 'prompt', 'response_a', 'response_b', 'model_a', 'model_b', 'language']

lmsys_df = pd.merge(
    gemma_df[common_columns + ['logits_gemma2']],
    llama_df[common_columns + ['logits_llama3']],
    on=common_columns,
)
# The inner merge silently drops (or multiplies) unmatched rows. Both models were
# loaded from files of equal length, so the row count must be preserved.
assert len(lmsys_df) == len(gemma_df) == len(llama_df), "gemma2/llama3 merge changed the row count"
lmsys_df.head(3)

Unnamed: 0,id,prompt,response_a,response_b,model_a,model_b,language,logits_gemma2,logits_llama3
0,d8b8a298a65a4169963967da44c5b05b,"crie uma programa em html, js, cs, que solici...",Aqui está um exemplo de como você pode criar u...,"Aqui está um programa em HTML, JavaScript e CS...",meta-llama/Meta-Llama-3-8B-Instruct,mistralai/Mistral-Nemo-Instruct-2407,Portuguese,"[0.303125, 0.81015625]","[0.48828125, 0.5791015625]"
1,6dfbc4eeb6624cdd929cabeeb6f99d26,Erzähle mir eine lustige aber lehrreiche Kurzg...,### Eine lustige und lehrreiche Kurzgeschichte...,"Es war einmal ein junges, sprechendes Ferkel n...",Qwen/Qwen2.5-3B-Instruct,vicuna-13b,German,"[0.54921875, 0.8265625]","[0.6328125, 0.705078125]"
2,7e67947c57c643be8d2d4439d97b2519,Ik wil een recept voor baklava,"Hier is een basisrecept voor baklava, een klas...",Baklava is een traditionele gestoomde taart ui...,Qwen/Qwen2.5-3B-Instruct,vicuna-13b,Dutch,"[-0.390625, 1.6484375]","[-0.6748046875, 2.34375]"


In [11]:
# Weighted ensemble of the two models' logits: w for gemma2, 1 - w for llama3
# (about 0.79 / 0.21; weight calculated on the validation set).
w = 0.7878
gemma2_logits = np.array(lmsys_df["logits_gemma2"].tolist())
llama3_logits = np.array(lmsys_df["logits_llama3"].tolist())
lmsys_df['logits'] = (w * gemma2_logits + (1 - w) * llama3_logits).tolist()
lmsys_df.head(2)

Unnamed: 0,id,prompt,response_a,response_b,model_a,model_b,language,logits_gemma2,logits_llama3,logits
0,d8b8a298a65a4169963967da44c5b05b,"crie uma programa em html, js, cs, que solici...",Aqui está um exemplo de como você pode criar u...,"Aqui está um programa em HTML, JavaScript e CS...",meta-llama/Meta-Llama-3-8B-Instruct,mistralai/Mistral-Nemo-Instruct-2407,Portuguese,"[0.303125, 0.81015625]","[0.48828125, 0.5791015625]","[0.34241515624999996, 0.7611264453125]"
1,6dfbc4eeb6624cdd929cabeeb6f99d26,Erzähle mir eine lustige aber lehrreiche Kurzg...,### Eine lustige und lehrreiche Kurzgeschichte...,"Es war einmal ein junges, sprechendes Ferkel n...",Qwen/Qwen2.5-3B-Instruct,vicuna-13b,German,"[0.54921875, 0.8265625]","[0.6328125, 0.705078125]","[0.56695734375, 0.800783515625]"


In [12]:
# Keep only the ensembled logits, then derive the label columns shared with the other sources:
# the argmax gives the hard label / winner, and softmax gives the soft labels.
lmsys_df = (
    lmsys_df
    .drop(columns=['logits_gemma2', 'logits_llama3'])
    .assign(hard_labels=lambda frame: frame["logits"].apply(np.argmax, axis=-1))
    .assign(
        winner=lambda frame: frame["hard_labels"].map({0: "model_a", 1: "model_b"}),
        split="train",
        source="lmsys_125k_multilingual_gemma2_llama3_ensemble_stage1",
        soft_labels=lambda frame: frame["logits"].apply(softmax, axis=-1),
    )
)
lmsys_df.head(2)

Unnamed: 0,id,prompt,response_a,response_b,model_a,model_b,language,logits,hard_labels,winner,split,source,soft_labels
0,d8b8a298a65a4169963967da44c5b05b,"crie uma programa em html, js, cs, que solici...",Aqui está um exemplo de como você pode criar u...,"Aqui está um programa em HTML, JavaScript e CS...",meta-llama/Meta-Llama-3-8B-Instruct,mistralai/Mistral-Nemo-Instruct-2407,Portuguese,"[0.34241515624999996, 0.7611264453125]",1,model_b,train,lmsys_125k_multilingual_gemma2_llama3_ensemble...,"[0.3968251684318443, 0.6031748315681557]"
1,6dfbc4eeb6624cdd929cabeeb6f99d26,Erzähle mir eine lustige aber lehrreiche Kurzg...,### Eine lustige und lehrreiche Kurzgeschichte...,"Es war einmal ein junges, sprechendes Ferkel n...",Qwen/Qwen2.5-3B-Instruct,vicuna-13b,German,"[0.56695734375, 0.800783515625]",1,model_b,train,lmsys_125k_multilingual_gemma2_llama3_ensemble...,"[0.44180834989076534, 0.5581916501092347]"


# Combine all data

In [13]:
# All sources must share one schema. Otherwise concat silently adds NaN-filled
# columns that would only surface at training time.
expected_columns = set(df.columns)
for part_name, part in (("dpo_mix_df", dpo_mix_df), ("lmsys_df", lmsys_df)):
    mismatch = set(part.columns) ^ expected_columns
    assert not mismatch, f"{part_name} schema differs from df: {mismatch}"
final_df = pd.concat([df, dpo_mix_df, lmsys_df], ignore_index=True)
final_df

Unnamed: 0,id,prompt,response_a,response_b,winner,model_a,model_b,language,split,source,hard_labels,soft_labels,logits
0,58210e39b3fd4441a2bd4a518bb44c2d,What is the difference between OpenCL and CUDA?,OpenCL and CUDA are two different programming ...,OpenCL and CUDA are both programming languages...,model_b,chatglm-6b,koala-13b,English,train,lmsys_33k,1,"[0.015, 0.985]","[0.015, 0.985]"
1,90bfd142157948aba01931726c888e7f,"Fuji vs. Nikon, which is better?",Both Fuji and Nikon are popular camera brands ...,This is a subjective question and the answer d...,model_b,koala-13b,oasst-pythia-12b,English,train,lmsys_33k,1,"[0.015, 0.985]","[0.015, 0.985]"
2,a7c5accc53e649a3bc6b2e41d962ebc4,How to build an arena for chatbots?,Building an arena for chatbots can be done by ...,Building an arena for chatbots is a great way ...,model_b,vicuna-13b,oasst-pythia-12b,English,train,lmsys_33k,1,"[0.015, 0.985]","[0.015, 0.985]"
3,adf27e819a3c494cb6e993f0c660e097,When is it today?,"I'm sorry, I cannot determine the current date...","Today is February 23, 2023.",model_a,vicuna-13b,koala-13b,English,train,lmsys_33k,0,"[0.985, 0.015]","[0.985, 0.015]"
4,c0fc42c6f5f14f2aa5a89f71f8553730,Count from 1 to 10 with step = 3,"1, 4, 7, 10\n\nCounting with a step of 3 means...","1, 4, 7, 10",model_a,vicuna-13b,koala-13b,English,train,lmsys_33k,0,"[0.985, 0.015]","[0.985, 0.015]"
...,...,...,...,...,...,...,...,...,...,...,...,...,...
276182,646ef0b0fda143cdb89cd32ed18c940c,bom dia,Bom dia!,,model_a,meta-llama/Meta-Llama-3-8B-Instruct,microsoft/Phi-3.5-mini-instruct,Portuguese,train,lmsys_125k_multilingual_gemma2_llama3_ensemble...,0,"[0.973595048493256, 0.026404951506744156]","[2.3919259375, -1.21551796875]"
276183,2a272049f63c4722b426a32eda916b57,测试,测试,你好。,model_b,alpaca-13b,meta-llama/Llama-3.1-70B-Instruct,Chinese,train,lmsys_125k_multilingual_gemma2_llama3_ensemble...,1,"[0.28934563217230397, 0.7106543678276961]","[0.05055058593750002, 0.9491148437500001]"
276184,534e6da0e64244d4a44128f151ade350,今天,今天。,",",model_a,meta-llama/Llama-3.1-8B-Instruct,fastchat-t5-3b,Japanese,train,lmsys_125k_multilingual_gemma2_llama3_ensemble...,0,"[0.8911540994648454, 0.10884590053515462]","[1.6360987500000002, -0.46648548828124997]"
276185,f4f9ee96a26b479a99e7ddc585685c15,سلام,سلام!,خوب,model_a,chatglm-6b,microsoft/Phi-3-mini-4k-instruct,Persian,train,lmsys_125k_multilingual_gemma2_llama3_ensemble...,0,"[0.6113861613452167, 0.38861383865478316]","[0.684996328125, 0.231853701171875]"


In [16]:
# How often each model name appears in either slot across the combined data.
Counter([*final_df["model_a"], *final_df["model_b"]])

Counter({'unknown': 88490,
         'vicuna-13b': 59605,
         'alpaca-13b': 14114,
         'koala-13b': 10525,
         'chatglm-6b': 9246,
         'meta-llama/Meta-Llama-3-8B-Instruct': 8633,
         'mistralai/Mistral-7B-Instruct-v0.3': 8416,
         'claude-1': 7944,
         'allenai/Llama-3.1-Tulu-3-8B': 7103,
         'Qwen/Qwen2.5-7B-Instruct': 7086,
         'meta-llama/Llama-3.1-8B-Instruct': 7041,
         'Qwen/Qwen2-1.5B-Instruct': 7037,
         'microsoft/Phi-3-mini-4k-instruct': 6999,
         'google/gemma-2-9b-it': 6993,
         'Qwen/Qwen2.5-14B-Instruct-AWQ': 6971,
         'mistralai/Mistral-7B-Instruct-v0.2': 6967,
         'Nexusflow/Starling-LM-7B-beta': 6941,
         'google/gemma-2-2b-it': 6917,
         'meta-llama/Llama-3.2-3B-Instruct': 6905,
         'Qwen/Qwen2.5-3B-Instruct': 6800,
         'Qwen/Qwen2-7B-Instruct': 6766,
         'oasst-pythia-12b': 6453,
         'gpt-4-1106-preview': 6080,
         'llama-2-13b-chat': 6019,
         'vicuna-3

In [17]:
# dropna=False so a missing source (e.g. from a schema mismatch) shows up instead of being hidden.
final_df["source"].value_counts(dropna=False)

source
lmsys_125k_multilingual_gemma2_llama3_ensemble_stage1    125878
current_comp                                              48439
orpo_dpo_mix_40k                                          44245
prev_comp                                                 39716
lmsys_33k                                                 17909
Name: count, dtype: int64

In [18]:
# dropna=False so unmapped winners (NaN hard labels) are counted rather than silently dropped.
final_df["hard_labels"].value_counts(dropna=False)

hard_labels
1    139807
0    136380
Name: count, dtype: int64

In [19]:
# dropna=False so rows with no language are visible alongside the explicit "unknown" bucket.
final_df["language"].value_counts(dropna=False)

language
English       126034
Chinese        23801
Russian        21585
Portuguese     15402
unknown        14029
               ...  
Telugu             1
Maori              1
Khmer              1
Assamese           1
Dzongkha           1
Name: count, Length: 155, dtype: int64

In [20]:
# dropna=False so rows with a missing split are counted instead of being hidden.
final_df["split"].value_counts(dropna=False)

split
train    271348
valid      4839
Name: count, dtype: int64

In [14]:
# Persist the combined hard- and pseudo-labelled training set (without the RangeIndex).
final_df.to_parquet(path="../data/train_orig_plus_pseudo_v1.parquet", index=False)