In [None]:
from pathlib import Path
import pandas as pd

# >>> EDIT THIS to your dataset root:
DATASET_PATH = Path("/home/scur1748/ir2-colbert-news/baseline/data")

def load_behaviors(split: str) -> pd.DataFrame:
    """Load behaviors.tsv or behaviors_parsed.tsv and standardize columns."""
    cand = [DATASET_PATH / split / "behaviors.tsv",
            DATASET_PATH / split / "behaviors_parsed.tsv"]
    path = next((p for p in cand if p.exists()), None)
    if path is None:
        raise FileNotFoundError(f"{split}: behaviors file not found (looked for behaviors.tsv / behaviors_parsed.tsv)")
    # Try no-header then header
    try:
        df = pd.read_csv(path, sep="\t", header=None, dtype=str, engine="python", quoting=3, on_bad_lines="skip")
        if df.shape[1] < 5:
            raise ValueError("behaviors has <5 columns")
        df = df.iloc[:, :5]
        df.columns = ["ImpressionID", "UserID", "Time", "History", "Impressions"]
    except Exception:
        df = pd.read_csv(path, sep="\t", header=0, dtype=str, engine="python", quoting=3, on_bad_lines="skip").fillna("")
        # Map likely names to standard
        colmap = {}
        names = [c.lower() for c in df.columns]
        want = ["impressionid","userid","time","history","impressions"]
        # Simple align left-to-right if needed
        if not set(want).issubset(names):
            df = df.iloc[:, :5]
            df.columns = ["ImpressionID", "UserID", "Time", "History", "Impressions"]
        else:
            # normalize casing
            ren = {}
            for c in df.columns:
                lc = c.lower()
                if lc == "impressionid": ren[c] = "ImpressionID"
                elif lc == "userid": ren[c] = "UserID"
                elif lc == "time": ren[c] = "Time"
                elif lc == "history": ren[c] = "History"
                elif lc == "impressions": ren[c] = "Impressions"
            df = df.rename(columns=ren)
    return df.fillna("")

def load_news(split: str) -> pd.DataFrame:
    """Load news.tsv and return df with a 'news_id' column."""
    path = DATASET_PATH / split / "news.tsv"
    if not path.exists():
        raise FileNotFoundError(f"{split}: news.tsv not found")
    # Try header then no-header
    try:
        df = pd.read_csv(path, sep="\t", header=0, dtype=str, engine="python", quoting=3)
    except Exception:
        df = pd.read_csv(path, sep="\t", header=None, dtype=str, engine="python", quoting=3)
    df = df.fillna("")
    # Find news id column robustly
    candidates = [c for c in df.columns if str(c).lower() in {"news_id","newsid","nid","id","news"}]
    if candidates:
        nid_col = candidates[0]
    else:
        # assume first column is news id
        nid_col = df.columns[0]
    # normalize to 'news_id'
    if nid_col != "news_id":
        df = df.rename(columns={nid_col: "news_id"})
    return df

def count_users_and_news(split: str):
    beh = load_behaviors(split)
    news = load_news(split)
    users = beh["UserID"].nunique()
    news_count = news["news_id"].nunique()
    return users, news_count

for split in ["train", "val"]:
    try:
        u, n = count_users_and_news(split)
        print(f"{split}: {u} unique users, {n} unique news")
    except Exception as e:
        print(f"{split}: ERROR -> {e}")


train: 20000 unique users, 17 unique news
val: 50000 unique users, 42415 unique news


In [21]:
# --- ONE CELL: robust downsample that preserves news.tsv JSON exactly ---

import os, re, shutil, random, json
from typing import Set
import pandas as pd

# ------------ edit your paths/params here ------------
in_root  = "/home/scur1748/ir2-colbert-news/baseline/data"
out_root = "/home/scur1748/ir2-colbert-news/baseline/data_downsampled_20k"
seed = 42
train_users = 20000
val_users   = -1     # set >0 to also downsample val
# -----------------------------------------------------

BEH_COLS = ["impression_id", "user_id", "time", "click_history", "impressions"]
SEP = re.compile(r"[,\s|]+")

def read_behaviors(path: str) -> pd.DataFrame:
    # behaviors.tsv in MIND is TSV w/out header (5 cols)
    # Use python engine + QUOTE_NONE to avoid accidental quote parsing
    return pd.read_csv(path, sep="\t", header=None, names=BEH_COLS,
                       dtype=str, engine="python", quoting=3, on_bad_lines="skip")

def extract_news_from_behaviors(df: pd.DataFrame) -> Set[str]:
    """Collect news IDs from click_history and impressions, trimming labels -0/-1 and odd separators."""
    news_ids: Set[str] = set()

    # History: tokens like Nxxxx separated by space/comma/pipe; '-' means empty
    for h in df["click_history"].fillna("").astype(str):
        if not h or h == "-":
            continue
        for tok in SEP.split(h.strip()):
            tok = tok.strip()
            if not tok:
                continue
            if "-" in tok:  # safety: drop any stuck-on label
                tok = tok.split("-", 1)[0]
            news_ids.add(tok)

    # Impressions: tokens "Nxxxx-0/1" or "Nxxxx"
    for imp in df["impressions"].fillna("").astype(str):
        if not imp:
            continue
        for tok in SEP.split(imp.strip()):
            tok = tok.strip()
            if not tok:
                continue
            if "-" in tok:
                tok = tok.split("-", 1)[0]
            news_ids.add(tok)

    # Keep typical MIND IDs (start with 'N')
    return {nid for nid in news_ids if nid and nid[0] == "N"}

def sample_users(df: pd.DataFrame, target_users: int, seed: int) -> Set[str]:
    users = sorted(df["user_id"].astype(str).unique().tolist())
    if target_users <= 0 or target_users >= len(users):
        return set(users)
    rnd = random.Random(seed)
    return set(rnd.sample(users, target_users))

def subset_behaviors(df: pd.DataFrame, keep_users: Set[str]) -> pd.DataFrame:
    return df[df["user_id"].astype(str).isin(keep_users)].copy()

def filter_news_stream_preserve_json(news_path: str, keep_news_ids: Set[str], out_path: str):
    """
    Stream-filter news.tsv by FIRST FIELD (news_id) and preserve the original line bytes:
    - Writes header line through unchanged if present (starts with 'news_id\\t')
    - For data lines, keeps the line if the first tab-delimited field is in keep_news_ids
    Result: title/abstract and *_entities JSON are preserved exactly (no re-quoting).
    """
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    kept = 0
    total = 0
    header_written = False
    with open(news_path, "r", encoding="utf-8") as fin, open(out_path, "w", encoding="utf-8") as fout:
        first = fin.readline()
        if not first:
            return
        total += 1
        if first.lower().startswith("news_id\t"):
            fout.write(first)  # write header as-is
            header_written = True
        else:
            # Not a header; process the first line as data
            parts = first.rstrip("\n").split("\t", 1)
            nid = parts[0]
            if nid in keep_news_ids:
                fout.write(first)
                kept += 1

        for line in fin:
            total += 1
            parts = line.rstrip("\n").split("\t", 1)
            if not parts:
                continue
            nid = parts[0]
            if header_written and nid.lower() == "news_id":
                # rare duplicated header inside file; pass through
                fout.write(line)
                continue
            if nid in keep_news_ids:
                fout.write(line)
                kept += 1

    print(f"[news] original lines read: {total:,} | kept lines: {kept:,} | header={'yes' if header_written else 'no'}")

def sanity_check_news_entities(news_path: str, max_issues: int = 10):
    """
    Scan the filtered news.tsv and try to validate JSON in title_entities / abstract_entities columns
    WITHOUT rewriting the file. This only reads/validates.
    """
    # Try to read header to find columns; if no header, skip this deep check
    with open(news_path, "r", encoding="utf-8") as f:
        head = f.readline()
        if not head:
            print("[sanity] news.tsv empty.")
            return
        has_header = head.lower().startswith("news_id\t")
    if not has_header:
        print("[sanity] No header detected in news.tsv; skipping JSON validation (downstream parser likely expects header).")
        return

    # With header present, use pandas to read but DO NOT write back
    df = pd.read_csv(news_path, sep="\t", header=0, dtype=str, engine="python", quoting=3, on_bad_lines="skip").fillna("")
    problems = []
    for col in [c for c in df.columns if c.lower() in ("title_entities", "abstract_entities")]:
        for idx, s in df[col].items():
            v = ("" if pd.isna(s) else str(s)).strip()
            if v == "" or v == "None":
                continue
            try:
                json.loads(v)
            except Exception as e:
                problems.append((col, idx, str(e), v[:120]))
                if len(problems) >= max_issues:
                    break
        if len(problems) >= max_issues:
            break

    if problems:
        print("[sanity] JSON problems detected in news entity columns (first few):")
        for col, idx, err, snippet in problems[:max_issues]:
            print(f"  Column {col} | Row {idx}: {err} --> {snippet}")
        print("  (These indicate the SOURCE file had invalid JSON in *_entities; our filtering preserved lines as-is.)")
    else:
        print("[sanity] Entity JSON looks OK in sampled rows.")

def process_split(split_name: str, in_root: str, out_root: str, seed: int, target_users: int):
    in_split = os.path.join(in_root, split_name)
    out_split = os.path.join(out_root, split_name)
    os.makedirs(out_split, exist_ok=True)

    beh_in  = os.path.join(in_split, "behaviors.tsv")
    news_in = os.path.join(in_split, "news.tsv")
    ent_in  = os.path.join(in_split, "entity_embedding.vec")
    rel_in  = os.path.join(in_split, "relation_embedding.vec")

    beh_out  = os.path.join(out_split, "behaviors.tsv")
    news_out = os.path.join(out_split, "news.tsv")
    ent_out  = os.path.join(out_split, "entity_embedding.vec")
    rel_out  = os.path.join(out_split, "relation_embedding.vec")

    print(f"[{split_name}] Loading behaviors from {beh_in}")
    beh_df = read_behaviors(beh_in)
    print(f"[{split_name}] behavior rows: {len(beh_df):,} | unique users: {beh_df['user_id'].nunique():,}")

    keep_users = sample_users(beh_df, target_users, seed)
    print(f"[{split_name}] users kept: {len(keep_users):,}")
    beh_sub = subset_behaviors(beh_df, keep_users)

    beh_sub.to_csv(beh_out, sep="\t", header=False, index=False)
    print(f"[{split_name}] behaviors written -> {beh_out} (rows: {len(beh_sub):,})")

    keep_news_ids = extract_news_from_behaviors(beh_sub)
    print(f"[{split_name}] referenced news ids (from behaviors): {len(keep_news_ids):,}")

    print(f"[{split_name}] Filtering news.tsv (stream, preserve JSON) ...")
    filter_news_stream_preserve_json(news_in, keep_news_ids, news_out)

    # quick sanity: count lines & header in output
    with open(news_out, "r", encoding="utf-8") as f:
        first = f.readline()
        has_header = first.lower().startswith("news_id\t")
        nlines = 1 + sum(1 for _ in f) if first else 0
    print(f"[{split_name}] news.tsv written -> {news_out} | lines={nlines:,} | header={'yes' if has_header else 'no'}")

    # Validate entity JSON if header present
    if has_header:
        sanity_check_news_entities(news_out)

    # Copy embeddings unchanged
    if os.path.exists(ent_in):
        shutil.copyfile(ent_in, ent_out)
    if os.path.exists(rel_in):
        shutil.copyfile(rel_in, rel_out)

    print(f"[{split_name}] âœ… Done!\n")

# -------------------- run --------------------
process_split("train", in_root, out_root, seed, train_users)

if val_users == -1:
    print("[val] Copying validation split unchanged...")
    in_val = os.path.join(in_root, "val")
    out_val = os.path.join(out_root, "val")
    os.makedirs(out_val, exist_ok=True)
    for fname in ["behaviors.tsv", "news.tsv", "entity_embedding.vec", "relation_embedding.vec"]:
        src = os.path.join(in_val, fname)
        dst = os.path.join(out_val, fname)
        if os.path.exists(src):
            shutil.copyfile(src, dst)
    print("[val] âœ… Done (unchanged).")
else:
    process_split("val", in_root, out_root, seed, val_users)

print("ðŸŽ‰ All done!")


[train] Loading behaviors from /home/scur1748/ir2-colbert-news/baseline/data/train/behaviors.tsv
[train] behavior rows: 156,965 | unique users: 50,000
[train] users kept: 20,000
[train] behaviors written -> /home/scur1748/ir2-colbert-news/baseline/data_downsampled_20k/train/behaviors.tsv (rows: 62,668)
[train] referenced news ids (from behaviors): 38,411
[train] Filtering news.tsv (stream, preserve JSON) ...
[news] original lines read: 51,282 | kept lines: 38,411 | header=no
[train] news.tsv written -> /home/scur1748/ir2-colbert-news/baseline/data_downsampled_20k/train/news.tsv | lines=38,411 | header=no
[train] âœ… Done!

[val] Copying validation split unchanged...
[val] âœ… Done (unchanged).
ðŸŽ‰ All done!
