In [2]:
import utils
import polars as pl

comments_df = utils.augmented_comments()

In [3]:
df = utils.full_dataset()

stories_with_100_points = df.filter(pl.col("score") >= 100)

# filter comments_df to only include comments whose top_level_parent is in the stories_with_100_points dataframe
comments_df = comments_df.filter(
    pl.col("top_level_parent").is_in(stories_with_100_points["id"])
)

comments_df

id,by,time,text,parent,top_level_parent,kids,siblings_count,sibling_rank,nested_level
i64,str,datetime[μs],str,i64,i64,list[i64],u32,i64,i32
194,"""jmzachary""",2007-02-20 22:33:51,"""Thanks for the rationale. I'm …",189,189,"[205, 422, 199]",20,13,0
195,"""jdroid""",2007-02-20 22:36:52,"""You've filled a hole reddit wa…",189,189,[259],20,3,0
199,"""Zak""",2007-02-20 22:48:33,"""I don't think the fact that th…",194,189,"[1644, 1897]",3,3,1
205,"""ninwa""",2007-02-20 23:30:23,"""Really? I was most interested …",194,189,"[210, 209]",3,1,1
209,"""ninwa""",2007-02-20 23:41:34,"""This comment added through the…",205,189,,2,2,2
…,…,…,…,…,…,…,…,…,…
41813369,"""whall6""",2024-10-11 20:25:08,"""If you truly only have 5 minut…",41812596,41811263,,1,1,3
41813371,"""julianeon""",2024-10-11 20:25:25,"""But consider the tradeoff: it'…",41811539,41811263,,13,1,1
41813375,"""throw0101c""",2024-10-11 20:25:51,"""> <i>There is so much coal... …",41811328,41807681,,1,1,3
41813381,"""marcosdumay""",2024-10-11 20:26:21,"""Add forced sedentarism into th…",41812891,41811263,,1,1,2


In [4]:
# Let's limit to comments for which all of their parents have a sibling rank <=
# 3. If a comment isn't in the top 3 of its parent it's less likely to have been
# read and upvoted by enough people to form a representative sample.

# We'll iteratively add comments with sibling_rank <= 3 up to 6 iterations.

filtered_comments_df = comments_df.filter(
    (pl.col("sibling_rank") <= 3) & (pl.col("parent").eq(pl.col("top_level_parent")))
)

for iteration in range(6):
    print(f"Iteration {iteration + 1}, current length: {len(filtered_comments_df)}")
    new_filtered_comments_df = comments_df.filter(
        (
            (pl.col("sibling_rank") <= 3)
            & (pl.col("parent").is_in(filtered_comments_df["id"]))
        )
        | (pl.col("id").is_in(filtered_comments_df["id"]))
    )

    if len(new_filtered_comments_df) == len(filtered_comments_df):
        break

    filtered_comments_df = new_filtered_comments_df

comments_df = filtered_comments_df

del filtered_comments_df
comments_df

Iteration 1, current length: 533314
Iteration 2, current length: 1614377
Iteration 3, current length: 2828828
Iteration 4, current length: 3874258
Iteration 5, current length: 4636895
Iteration 6, current length: 5139157


id,by,time,text,parent,top_level_parent,kids,siblings_count,sibling_rank,nested_level
i64,str,datetime[μs],str,i64,i64,list[i64],u32,i64,i32
195,"""jdroid""",2007-02-20 22:36:52,"""You've filled a hole reddit wa…",189,189,[259],20,3,0
259,"""whatsreal""",2007-02-21 06:18:50,"""HaHa! Yes thank you Paul, I w…",195,189,"[453, 7198]",1,1,1
287,"""ced""",2007-02-21 09:08:16,"""Since community-building is pa…",189,189,[365],20,2,0
353,"""jhenzie""",2007-02-21 18:47:34,"""Paul, can we play with arc…",189,189,[7199],20,1,0
365,"""jdroid""",2007-02-21 19:25:14,"""If I bought yspace.com would y…",287,189,,1,1,1
…,…,…,…,…,…,…,…,…,…
41813352,"""dullcrisp""",2024-10-11 20:22:59,"""A <i>scrappy</i> marketing tea…",41813266,41811263,,1,1,3
41813354,"""ragnese""",2024-10-11 20:23:06,"""It was only true that Chrome w…",41813112,41809698,,1,1,4
41813362,"""layer8""",2024-10-11 20:23:58,"""It’s not simply a limitation i…",41812523,41812523,,31,3,0
41813363,"""nightski""",2024-10-11 20:24:07,"""Nothing ""needs"" to happen. Peo…",41812696,41811263,,1,1,5


In [5]:
# limit all comments to ones that have at least 5 siblings and fewer than 25
# siblings. If there are more than 25 I don't trust the peoples' votes since
# they probably haven't had time to read all of them.
comments_df = comments_df.filter(
    (pl.col("siblings_count") >= 5) & (pl.col("siblings_count") < 25)
)

comments_df

id,by,time,text,parent,top_level_parent,kids,siblings_count,sibling_rank,nested_level
i64,str,datetime[μs],str,i64,i64,list[i64],u32,i64,i32
195,"""jdroid""",2007-02-20 22:36:52,"""You've filled a hole reddit wa…",189,189,[259],20,3,0
287,"""ced""",2007-02-21 09:08:16,"""Since community-building is pa…",189,189,[365],20,2,0
353,"""jhenzie""",2007-02-21 18:47:34,"""Paul, can we play with arc…",189,189,[7199],20,1,0
27552,"""pg""",2007-06-12 19:02:46,"""Don't worry, I'm not going to …",27550,27550,"[28315, 27559, … 28523]",7,1,0
27560,"""danielha""",2007-06-12 19:15:56,"""Prediction: this submission re…",27550,27550,[28098],7,2,0
…,…,…,…,…,…,…,…,…,…
41812977,"""shadowgovt""",2024-10-11 19:50:42,"""Their incentive is really to m…",41810118,41809698,,7,-1,2
41813078,"""itronitron""",2024-10-11 19:59:06,"""This is the type of comment th…",41812493,41811263,"[41813116, 41813266]",9,3,1
41813079,"""ragnese""",2024-10-11 19:59:16,"""The vast, <i>vast</i>, majorit…",41809962,41809698,,12,-1,1
41813239,"""glenstein""",2024-10-11 20:13:25,"""I think this is such a helpful…",41812493,41811263,,9,2,1


In [6]:
# filter comments that have non-null text and text length > 10
comments_df = comments_df.filter(
    (pl.col("text").is_not_null())
    & (pl.col("text").str.len_chars() > 10)
    & (pl.col("sibling_rank") > 0)
)

comments_df

id,by,time,text,parent,top_level_parent,kids,siblings_count,sibling_rank,nested_level
i64,str,datetime[μs],str,i64,i64,list[i64],u32,i64,i32
195,"""jdroid""",2007-02-20 22:36:52,"""You've filled a hole reddit wa…",189,189,[259],20,3,0
287,"""ced""",2007-02-21 09:08:16,"""Since community-building is pa…",189,189,[365],20,2,0
353,"""jhenzie""",2007-02-21 18:47:34,"""Paul, can we play with arc…",189,189,[7199],20,1,0
27552,"""pg""",2007-06-12 19:02:46,"""Don't worry, I'm not going to …",27550,27550,"[28315, 27559, … 28523]",7,1,0
27560,"""danielha""",2007-06-12 19:15:56,"""Prediction: this submission re…",27550,27550,[28098],7,2,0
…,…,…,…,…,…,…,…,…,…
41812565,"""vessenes""",2024-10-11 19:14:29,"""Tirzepatide and Semaglutide ar…",41812493,41811263,[41812815],9,1,1
41812643,"""squidlogic""",2024-10-11 19:23:13,"""In my experience fitness is le…",41812339,41811263,"[41812886, 41812696, … 41813092]",6,1,3
41812717,"""jrflowers""",2024-10-11 19:29:53,"""> People in the 50s weren't sl…",41812339,41811263,"[41813315, 41812727]",6,3,3
41813078,"""itronitron""",2024-10-11 19:59:06,"""This is the type of comment th…",41812493,41811263,"[41813116, 41813266]",9,3,1


In [7]:
from tqdm import tqdm
from typing import List, TypedDict, Dict
from collections import defaultdict
import datetime


progress_bar = tqdm(total=len(comments_df), desc="Processing groups")


class GroupableComment(TypedDict):
    id: int
    parent: int
    time: datetime.datetime
    sibling_rank: int
    top_level_parent: int


class RewardPair(TypedDict):
    chosen: int  # the id of the chosen comment
    rejected: int  # the id of the rejected comment
    chosen_rank: int
    rejected_rank: int
    top_level_parent: int


def process_group(group: List[GroupableComment]) -> list[RewardPair]:
    """
    The goal of this function is to create a reward dataset for Hacker News
    comments. It takes in a dataframe containing all the comments within a given
    hour, and then breaks them down by the parent id.

    Within each parent id group, it orders the comments by sibling rank (that
    is, their relative position in the thread). Lower sibling rank means the
    comment is higher in the thread, which we can assume means it had more
    upvotes and is better.

    We take each comment and its directly following sibling and mark those as
    "chosen" and "rejected", creating a pairwise dataset we can use for reward
    modeling.
    """
    if len(group) == 0:
        return []

    sorted_group = sorted(group, key=lambda x: x["sibling_rank"])

    top_comment = sorted_group[0]
    other_comments = sorted_group[1:]

    group_reward_pairs = []

    for other_comment in other_comments:
        # Make sure the comments were posted within 30 minutes of each other. If
        # they were too far apart, they likely aren't comparable since the older
        # one is likely to win.
        if (other_comment["time"] - top_comment["time"]).total_seconds() > 1800:
            break

        group_reward_pairs.append(
            {
                "chosen": top_comment["id"],
                "rejected": other_comment["id"],
                "chosen_rank": top_comment["sibling_rank"],
                "rejected_rank": other_comment["sibling_rank"],
                "top_level_parent": top_comment["top_level_parent"],
            }
        )

    return group_reward_pairs


reward_pairs: list[RewardPair] = []
parent_group: list[GroupableComment] = []
current_parent = None

sorted_by_parent = comments_df.select(
    pl.col("id"),
    pl.col("sibling_rank"),
    pl.col("parent"),
    pl.col("time"),
    pl.col("top_level_parent"),
).sort("parent")

for comment in sorted_by_parent.iter_rows(named=True):
    comment = GroupableComment(**comment)

    # We already sorted by parent, so if the current comment's parent isn't the
    # same as the previous one, we know we've moved on to a new parent and can
    # process the previous set of comments.
    if comment["parent"] != current_parent:
        reward_pairs.extend(process_group(parent_group))

        parent_group = []
        current_parent = comment["parent"]

    parent_group.append(comment)
    progress_bar.update(1)

# Close the progress bar
progress_bar.close()

pairs_df = pl.DataFrame(reward_pairs)

pairs_df

Processing groups:   0%|          | 1044/1104093 [00:00<01:45, 10439.96it/s]

Processing groups: 100%|██████████| 1104093/1104093 [00:01<00:00, 563691.17it/s]


chosen,rejected,chosen_rank,rejected_rank,top_level_parent
i64,i64,i64,i64,i64
353,287,1,2,189
353,195,1,3,189
27552,27560,1,2,27550
77373,77337,1,2,77246
77666,77438,1,2,77246
…,…,…,…,…
41812258,41812168,1,3,41811263
41812339,41812380,1,2,41811263
41812339,41812405,1,3,41811263
41812643,41812457,1,2,41811263


In [8]:
import numpy as np

# Get unique top_level_parent values
unique_parents = pairs_df["top_level_parent"].unique()

# Randomly assign splits
np.random.seed(42)  # for reproducibility
split_assignments = np.random.choice(
    ["train", "test", "val"], size=len(unique_parents), p=[0.8, 0.1, 0.1]
)

# Create a dictionary mapping top_level_parent to split
parent_to_split = dict(zip(unique_parents, split_assignments))

# Add the new 'split' column
pairs_df = pairs_df.with_columns(
    pl.col("top_level_parent").replace_strict(parent_to_split).alias("split")
)

# Verify the split proportions
split_counts = pairs_df["split"].value_counts().sort("count", descending=True)
print("Split counts:")
print(split_counts)

# Calculate percentages
total = split_counts["count"].sum()
percentages = (split_counts["count"] / total * 100).round(2)
print("\nSplit percentages:")
print(percentages)

Split counts:
shape: (3, 2)
┌───────┬────────┐
│ split ┆ count  │
│ ---   ┆ ---    │
│ str   ┆ u32    │
╞═══════╪════════╡
│ train ┆ 297776 │
│ test  ┆ 37610  │
│ val   ┆ 36025  │
└───────┴────────┘

Split percentages:
shape: (3,)
Series: 'count' [f64]
[
	80.17
	10.13
	9.7
]


In [10]:
sample_pairs = pairs_df.sample(n=100000, seed=42)

In [11]:
from utils import build_all_prompts

chosen_prompts_v1 = build_all_prompts(sample_pairs["chosen"], version="v1")
chosen_prompts_v2 = build_all_prompts(sample_pairs["chosen"], version="v2")
rejected_prompts_v1 = build_all_prompts(sample_pairs["rejected"], version="v1")
rejected_prompts_v2 = build_all_prompts(sample_pairs["rejected"], version="v2")

sample_pairs_v1 = sample_pairs.with_columns(
    pl.Series("chosen_prompt", chosen_prompts_v1),
    pl.Series("rejected_prompt", rejected_prompts_v1),
)

sample_pairs_v2 = sample_pairs.with_columns(
    pl.Series("chosen_prompt", chosen_prompts_v2),
    pl.Series("rejected_prompt", rejected_prompts_v2),
)

sample_pairs


Building prompts:   0%|          | 0/100000 [00:00<?, ?it/s]

Building prompts: 100%|██████████| 100000/100000 [00:54<00:00, 1840.04it/s]
Building prompts: 100%|██████████| 100000/100000 [00:49<00:00, 2005.72it/s]
Building prompts: 100%|██████████| 100000/100000 [00:53<00:00, 1868.94it/s]
Building prompts: 100%|██████████| 100000/100000 [00:49<00:00, 2026.30it/s]


chosen,rejected,chosen_rank,rejected_rank,top_level_parent,split
i64,i64,i64,i64,i64,str
8883038,8881764,1,3,8878903,"""val"""
30012617,30012962,1,2,30011382,"""train"""
8525987,8524085,1,3,8523421,"""train"""
15055589,15055588,1,2,15054903,"""train"""
19970816,19969988,1,2,19968496,"""train"""
…,…,…,…,…,…
1713593,1713509,1,2,1713276,"""train"""
7564905,7564825,1,2,7564680,"""test"""
27519642,27519251,1,3,27514437,"""train"""
4986067,4985725,1,3,4985517,"""train"""


In [12]:
import numpy as np
from datasets import Dataset, DatasetDict


# Create a DatasetDict to hold all splits
dataset_dict_v1 = DatasetDict(
    {
        "train": Dataset.from_polars(
            sample_pairs_v1.filter(pl.col("split") == "train").sample(n=30000, seed=42)
        ),
        "validation": Dataset.from_polars(
            sample_pairs_v1.filter(pl.col("split") == "val").sample(n=500, seed=42)
        ),
        "test": Dataset.from_polars(
            sample_pairs_v1.filter(pl.col("split") == "test").sample(n=1000, seed=42)
        ),
    }
)

dataset_dict_v1.save_to_disk("./data/sample_pairs_v1")

dataset_dict_v2 = DatasetDict(
    {
        "train": Dataset.from_polars(
            sample_pairs_v2.filter(pl.col("split") == "train").sample(n=30000, seed=42)
        ),
        "validation": Dataset.from_polars(
            sample_pairs_v2.filter(pl.col("split") == "val").sample(n=500, seed=42)
        ),
        "test": Dataset.from_polars(
            sample_pairs_v2.filter(pl.col("split") == "test").sample(n=1000, seed=42)
        ),
    }
)

dataset_dict_v2.save_to_disk("./data/sample_pairs_v2")

Saving the dataset (0/1 shards):   0%|          | 0/30000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/500 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/30000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/500 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1000 [00:00<?, ? examples/s]

In [13]:
from dotenv import load_dotenv

load_dotenv("/workspace/.env")

dataset_dict_v1.push_to_hub("OpenPipe/best-hn-comment-pairs-v1")
dataset_dict_v2.push_to_hub("OpenPipe/best-hn-comment-pairs-v2")

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/30 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?it/s]



Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/30 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?it/s]

CommitInfo(commit_url='https://huggingface.co/datasets/OpenPipe/best-hn-comment-pairs-v2/commit/0c19b0ee5a3bb1d1caf8071359897f22338f90b0', commit_message='Upload dataset', commit_description='', oid='0c19b0ee5a3bb1d1caf8071359897f22338f90b0', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/OpenPipe/best-hn-comment-pairs-v2', endpoint='https://huggingface.co', repo_type='dataset', repo_id='OpenPipe/best-hn-comment-pairs-v2'), pr_revision=None, pr_num=None)

In [16]:
top_comments = comments_df.filter(pl.col("sibling_rank") == 1).filter(
    (pl.col("id").is_in(dataset_dict_v1["train"]["chosen"]).not_())
    & (pl.col("id").is_in(dataset_dict_v1["train"]["rejected"]).not_())
)

top_comments = top_comments.with_columns(
    pl.Series("prompt", build_all_prompts(top_comments["id"], version="v1"))
)

top_comments.write_parquet("./data/top_comments.parquet")


Building prompts: 100%|██████████| 340702/340702 [02:44<00:00, 2065.44it/s]
