In [None]:
# %pip install datasets pandas spacy tqdm

In [None]:
# !python -m spacy download en_core_web_md

In [None]:
from datasets import load_dataset

# Training split of the TL;DR data; each row carries a raw `prompt` and a `completion`.
ds = load_dataset(path="trl-lib/tldr", split="train")

In [None]:
import re

# Compiled once. DOTALL lets POST span multiple lines; the leading \s* tolerates
# leading whitespace, and the negative lookahead makes POST end at the LAST
# "TL;DR:" marker, so a post that itself mentions "TL;DR:" is not truncated.
_PROMPT_RE = re.compile(
    r"\s*SUBREDDIT:\s*(?P<subreddit>.+?)\s+TITLE:\s*(?P<title>.+?)"
    r"\s+POST:\s*(?P<post>.+?)\s+TL;DR:(?!.*TL;DR:)",
    flags=re.DOTALL,
)


def split_prompt(example):
    """Split a raw TL;DR prompt into its subreddit, title and post parts.

    Args:
        example: row with a string field "prompt".

    Returns:
        {"subreddit", "title", "post"} as strings when the prompt matches the
        SUBREDDIT/TITLE/POST/TL;DR layout; otherwise subreddit and title are
        None and post is the whole prompt (such rows are filtered out later).
    """
    text = example["prompt"]
    match = _PROMPT_RE.match(text)
    if match is None:
        return {"subreddit": None, "title": None, "post": text}
    return match.groupdict()

# Replace the raw `prompt` column with the parsed subreddit/title/post columns.
ds = ds.map(function=split_prompt, remove_columns=["prompt"])

In [None]:
# Rename `completion` -> `tldr`, then drop rows whose required fields are
# missing, empty, or whitespace-only. The text is only stripped in the next
# cell, so a "   " value would otherwise survive here and become "" later.
REQUIRED_FIELDS = ("subreddit", "title", "post", "tldr")
ds = ds.rename_column("completion", "tldr")
ds = ds.filter(lambda x: all(x[col] and x[col].strip() for col in REQUIRED_FIELDS))

In [None]:
import unicodedata

def clean_text(example, cols=("title", "post", "tldr")):
    """Return a copy of `example` with each column in `cols` NFKC-normalized, then stripped.

    Normalizing first matters: NFKC can expand a non-space character into a
    leading space plus combining mark (e.g. U+00A8 -> " \\u0308"), which a
    strip done beforehand would leave behind.

    Args:
        example: row mapping; every column named in `cols` must be a string.
        cols: names of the text columns to clean. Other columns are copied as-is.

    Returns:
        A new dict; the input row is not mutated.
    """
    cleaned = {col: unicodedata.normalize("NFKC", example[col]).strip() for col in cols}
    return {**example, **cleaned}

# Normalize and trim the free-text columns row by row.
ds = ds.map(function=clean_text)

In [None]:
import spacy
from tqdm.auto import tqdm

nlp = spacy.load("en_core_web_md")

def compute_similar(example):
    doc_post = nlp(example["post"])
    doc_tldr = nlp(example["tldr"])
    tldr_text = example["tldr"].lower()
    similar = {}
    for chunk in doc_post.noun_chunks:
        phrase = chunk.text.strip().lower()
        if len(phrase) < 3:
            continue
        # Binary importance if phrase appears in the TL;DR summary
        important = 1 if phrase in tldr_text else 0
        # Similarity score via spaCy vectors
        sim_score = float(doc_tldr.similarity(nlp(phrase)))
        if sim_score >= 0.75: # Threshold for similarity
            important = 1
        similar[phrase] = (important, sim_score)
    return {"similar": similar}

# Apply with a progress bar
records = []
for row in tqdm(ds, total=len(ds)):
    rec = dict(row)
    rec.update(compute_similar(row))
    records.append(rec)

# Convert back into a Dataset
from datasets import Dataset
ds = Dataset.from_pandas(pd.DataFrame(records))


In [None]:
def keep_top_k(example, k: int = 30):
    items = sorted(
        example["similar"].items(),
        key=lambda kv: kv[1][1],
        reverse=True
    )[:k]
    example["similar"] = dict(items)
    return example

ds = ds.map(keep_top_k)

ds = ds.filter(lambda x: len(x["similar"]) > 0)

In [None]:
import json

# Convert the Dataset to a pandas DataFrame
df = ds.to_pandas()

# Serialize the nested `similar` dict into JSON strings
df['similar'] = df['similar'].apply(json.dumps)

# Write out to CSV (no index column)
df.to_csv("tldr_preprocessed.csv", index=False)

# loading the preprocessed data later:
# import pandas as pd
# df = pd.read_csv("tldr_preprocessed.csv")
# df['similar'] = df['similar'].apply(json.loads)