In [1]:
import pandas as pd
import json
import random
from tqdm import tqdm

In [6]:
# Input parquet files (the two OASST1 splits and the ShareGPT dump) plus the
# JSONL file the mixed dataset is written to. All paths are relative, so they
# resolve against the kernel's current working directory.
OASST_PATH_TRAIN = "dataset/OASST1/train.parquet"
OASST_PATH_VALID = "dataset/OASST1/valid.parquet"
SHAREGPT_PATH = "dataset/ShareGPT/train.parquet"
OUTPUT_JSONL = "mixed_dataset.jsonl"

In [None]:
# Samples whose whitespace-separated word count exceeds this limit are dropped
# entirely (they are not truncated). The word count is only a rough proxy for
# the number of model tokens.
MAX_TOKENS = 768

# Share of the final mix taken from each source; the two are meant to sum to 1.
OASST_RATIO = 0.80
SHAREGPT_RATIO = 0.20

# Seed the stdlib RNG so the shuffle-based mixing gives the same dataset on
# every Restart & Run All.
SEED = 42
random.seed(SEED)

In [7]:
def estimate_tokens(text):
    """Approximate the token count of ``text`` as its number of whitespace-separated words."""
    words = text.split()
    return len(words)

In [8]:
# Read both OASST1 splits and stack them into one frame with a fresh index.
print("Loading OASST1...")
oasst_train, oasst_valid = (
    pd.read_parquet(split_path)
    for split_path in (OASST_PATH_TRAIN, OASST_PATH_VALID)
)
oasst_df = pd.concat([oasst_train, oasst_valid], ignore_index=True)

# ShareGPT comes as a single parquet file.
print("Loading ShareGPT-clean...")
share_df = pd.read_parquet(SHAREGPT_PATH)

Loading OASST1...
Loading ShareGPT-clean...


In [23]:
# Peek at the ShareGPT schema: column names first, then the first record.
print(share_df.columns, share_df.iloc[0], sep="\n")

Index(['turns', 'source', 'good_natured', 'conversations'], dtype='object')
turns                                                            1
source                                                    sharegpt
good_natured                                                  True
conversations    [{'from': 'human', 'value': '0.002 = 1000 
1 =...
Name: 0, dtype: object


In [24]:
# Show the container type and the content of the first `conversations` cell.
print(
    "\n=== TYPE OF conversations ===",
    type(share_df.iloc[0]["conversations"]),
    "\n=== FIRST conversations ENTRY ===",
    share_df.iloc[0]["conversations"],
    sep="\n",
)


=== TYPE OF conversations ===
<class 'numpy.ndarray'>

=== FIRST conversations ENTRY ===
[{'from': 'human', 'value': '0.002 = 1000 \n1 = x?'}
 {'from': 'gpt', 'value': 'To find the value of x, we can set up a proportion using the given information:\n\n0.002/1000 = 1/x\n\nTo solve for x, we can cross multiply:\n\n0.002 * x = 1000 * 1\n\n0.002x = 1000\n\nDividing both sides by 0.002:\n\nx = 1000 / 0.002\n\nx = 500,000\n\nTherefore, 1 is equal to 500,000 in this proportion.'}]


In [17]:
def extract_oasst_pairs(df):
    """Build prompt/response training strings from an OASST-style message table.

    Every row whose ``role`` is ``"assistant"`` is matched with the row that its
    ``parent_id`` points to. A pair is kept only when that parent exists in
    ``df``, has role ``"prompter"``, and both texts are present and non-empty
    after stripping surrounding whitespace.

    Args:
        df: DataFrame with ``message_id``, ``parent_id``, ``role`` and ``text``
            columns.

    Returns:
        list[str]: One string per kept pair, made of a ``<user>: `` line holding
        the prompt followed by an ``<assistant>: `` line holding the reply, in
        the row order of the assistant messages.
    """
    # itertuples() avoids building a Series per row, which is what made the
    # two iterrows() passes slow on a table of this size.
    rows = list(df.itertuples(index=False))

    # Lookup table message_id -> row, used to resolve each reply's parent.
    msg_map = {row.message_id: row for row in rows}

    pairs = []
    for row in rows:
        # Only assistant messages can close a pair.
        if row.role != "assistant":
            continue

        parent_id = row.parent_id
        if pd.isna(parent_id) or parent_id not in msg_map:
            continue

        parent = msg_map[parent_id]

        # The message being answered must come from the user side.
        if parent.role != "prompter":
            continue

        # Missing text would otherwise be stringified to "None"/"nan" and
        # leak into the training data as if it were real content.
        if pd.isna(parent.text) or pd.isna(row.text):
            continue

        user_text = str(parent.text).strip()
        assistant_text = str(row.text).strip()
        if not user_text or not assistant_text:
            continue

        pairs.append(f"<user>: {user_text}\n<assistant>: {assistant_text}")

    return pairs


In [18]:
# Turn the combined OASST table into prompt/response strings and report the count.
print("Extracting OASST pairs...")
oasst_samples = extract_oasst_pairs(df=oasst_df)
print("OASST extracted pairs:", len(oasst_samples))

Extracting OASST pairs...
OASST extracted pairs: 55668


In [25]:
def extract_sharegpt_pairs(df):
    """Collect (human -> gpt) exchanges from ShareGPT-style conversations.

    Each conversation is scanned for adjacent turns where a ``"human"`` turn is
    immediately followed by a ``"gpt"`` turn; every such exchange becomes one
    sample. Exchanges with missing or blank text on either side are skipped.

    Args:
        df: DataFrame with a ``conversations`` column. Each cell is expected to
            be a sequence (list or numpy array) of mappings with ``"from"`` and
            ``"value"`` keys.

    Returns:
        list[str]: One string per kept exchange, made of a ``<user>: `` line
        holding the human text followed by an ``<assistant>: `` line holding
        the gpt text, in conversation order.
    """
    if len(df) == 0:
        return []

    pairs = []

    # Only one column is needed, so iterate it directly instead of iterrows().
    for conv in df["conversations"]:
        # Missing cells (None/NaN) have no length. Strings do have one, but
        # they are not sequences of turns, so both are skipped.
        if not hasattr(conv, "__len__") or isinstance(conv, (str, bytes)):
            continue

        # numpy array -> plain list so the turns can be sliced and zipped.
        turns = list(conv)

        # Walk through consecutive turns.
        for cur, nxt in zip(turns, turns[1:]):
            # Tolerate malformed turns (e.g. None) instead of crashing on .get.
            if not (hasattr(cur, "get") and hasattr(nxt, "get")):
                continue
            if cur.get("from") != "human" or nxt.get("from") != "gpt":
                continue

            user_raw = cur.get("value")
            assistant_raw = nxt.get("value")
            # A None value would otherwise be stringified to the text "None".
            if user_raw is None or assistant_raw is None:
                continue

            user_text = str(user_raw).strip()
            assistant_text = str(assistant_raw).strip()
            if user_text and assistant_text:
                pairs.append(f"<user>: {user_text}\n<assistant>: {assistant_text}")

    return pairs


In [26]:
# Turn the ShareGPT conversations into prompt/response strings and report the count.
print("Extracting ShareGPT pairs...")
sharegpt_samples = extract_sharegpt_pairs(df=share_df)
print("ShareGPT extracted pairs:", len(sharegpt_samples))

Extracting ShareGPT pairs...
ShareGPT extracted pairs: 282130


In [27]:
def trim_samples(samples, max_tokens):
    """Return the samples whose estimated token count is at most ``max_tokens``.

    Over-long samples are dropped whole, not shortened. Input order is kept and
    a new list is returned.
    """
    return [text for text in samples if estimate_tokens(text) <= max_tokens]

In [28]:

# Drop over-long samples from both sources using the same limit.
print("Filtering lengths...")
oasst_samples, sharegpt_samples = (
    trim_samples(pool, MAX_TOKENS)
    for pool in (oasst_samples, sharegpt_samples)
)

Filtering lengths...


In [29]:
def mix_datasets(a_list, b_list, ratio_a, ratio_b, seed=None):
    """Draw a shuffled mix of two sample lists in the requested proportions.

    The mix is sized from the shorter list: ``int(n * ratio_a)`` items come
    from ``a_list`` and ``int(n * ratio_b)`` from ``b_list``, where ``n`` is
    ``min(len(a_list), len(b_list))``. Neither input list is modified.

    Args:
        a_list: First pool of samples.
        b_list: Second pool of samples.
        ratio_a: Fraction of ``n`` to take from ``a_list``.
        ratio_b: Fraction of ``n`` to take from ``b_list``.
        seed: Optional seed for a private RNG, giving a reproducible mix. When
            ``None`` the module-level ``random`` state is used.

    Returns:
        list: The selected samples from both pools in random order.
    """
    rng = random if seed is None else random.Random(seed)

    total = min(len(a_list), len(b_list))
    target_a = int(total * ratio_a)
    target_b = int(total * ratio_b)

    # Shuffle copies: shuffling the arguments in place would silently reorder
    # the caller's lists.
    a_pool = list(a_list)
    b_pool = list(b_list)
    rng.shuffle(a_pool)
    rng.shuffle(b_pool)

    mixed = a_pool[:target_a] + b_pool[:target_b]
    rng.shuffle(mixed)
    return mixed


In [30]:
# Combine the two filtered pools according to the configured ratios.
print("Mixing datasets...")
mixed_samples = mix_datasets(
    a_list=oasst_samples,
    b_list=sharegpt_samples,
    ratio_a=OASST_RATIO,
    ratio_b=SHAREGPT_RATIO,
)
print("Final mixed dataset size:", len(mixed_samples))


Mixing datasets...
Final mixed dataset size: 55474


In [31]:
print(f"Saving to {OUTPUT_JSONL}...")

# One JSON object per line; ensure_ascii=False keeps non-ASCII text readable.
with open(OUTPUT_JSONL, mode="w", encoding="utf-8") as f:
    for sample in mixed_samples:
        record = json.dumps({"text": sample}, ensure_ascii=False)
        f.write(record + "\n")

print("Done. Dataset ready.")

Saving to mixed_dataset.jsonl...
Done. Dataset ready.
