In [3]:
from pathlib import Path
import pandas as pd

# >>> EDIT THIS to your dataset root:
DATASET_PATH = Path("/home/scur1748/ir2-colbert-news/baseline/data")

def load_behaviors(split: str) -> pd.DataFrame:
    """Load behaviors.tsv or behaviors_parsed.tsv and standardize columns."""
    cand = [DATASET_PATH / split / "behaviors.tsv",
            DATASET_PATH / split / "behaviors_parsed.tsv"]
    path = next((p for p in cand if p.exists()), None)
    if path is None:
        raise FileNotFoundError(f"{split}: behaviors file not found (looked for behaviors.tsv / behaviors_parsed.tsv)")
    # Try no-header then header
    try:
        df = pd.read_csv(path, sep="\t", header=None, dtype=str, engine="python", quoting=3, on_bad_lines="skip")
        if df.shape[1] < 5:
            raise ValueError("behaviors has <5 columns")
        df = df.iloc[:, :5]
        df.columns = ["ImpressionID", "UserID", "Time", "History", "Impressions"]
    except Exception:
        df = pd.read_csv(path, sep="\t", header=0, dtype=str, engine="python", quoting=3, on_bad_lines="skip").fillna("")
        # Map likely names to standard
        colmap = {}
        names = [c.lower() for c in df.columns]
        want = ["impressionid","userid","time","history","impressions"]
        # Simple align left-to-right if needed
        if not set(want).issubset(names):
            df = df.iloc[:, :5]
            df.columns = ["ImpressionID", "UserID", "Time", "History", "Impressions"]
        else:
            # normalize casing
            ren = {}
            for c in df.columns:
                lc = c.lower()
                if lc == "impressionid": ren[c] = "ImpressionID"
                elif lc == "userid": ren[c] = "UserID"
                elif lc == "time": ren[c] = "Time"
                elif lc == "history": ren[c] = "History"
                elif lc == "impressions": ren[c] = "Impressions"
            df = df.rename(columns=ren)
    return df.fillna("")

def load_news(split: str) -> pd.DataFrame:
    """Load news.tsv and return df with a 'news_id' column."""
    path = DATASET_PATH / split / "news.tsv"
    if not path.exists():
        raise FileNotFoundError(f"{split}: news.tsv not found")
    # Try header then no-header
    try:
        df = pd.read_csv(path, sep="\t", header=0, dtype=str, engine="python", quoting=3)
    except Exception:
        df = pd.read_csv(path, sep="\t", header=None, dtype=str, engine="python", quoting=3)
    df = df.fillna("")
    # Find news id column robustly
    candidates = [c for c in df.columns if str(c).lower() in {"news_id","newsid","nid","id","news"}]
    if candidates:
        nid_col = candidates[0]
    else:
        # assume first column is news id
        nid_col = df.columns[0]
    # normalize to 'news_id'
    if nid_col != "news_id":
        df = df.rename(columns={nid_col: "news_id"})
    return df

def count_users_and_news(split: str):
    beh = load_behaviors(split)
    news = load_news(split)
    users = beh["UserID"].nunique()
    news_count = news["news_id"].nunique()
    return users, news_count

for split in ["train", "val"]:
    try:
        u, n = count_users_and_news(split)
        print(f"{split}: {u} unique users, {n} unique news")
    except Exception as e:
        print(f"{split}: ERROR -> {e}")


train: 50000 unique users, 51281 unique news
val: 50000 unique users, 42415 unique news


In [6]:
# === Downsample MIND by users & ensure all expected files exist ===
from pathlib import Path
import pandas as pd
import numpy as np
import shutil
import re

# --- CONFIG: update these paths ---
INPUT_ROOT  = Path("/home/scur1748/ir2-colbert-news/baseline/data")       # your current MIND root
OUTPUT_ROOT = Path("/home/scur1748/ir2-colbert-news/baseline/data_downsampled_20k")  # where to write the 20k subset
KEEP_USERS  = 20_000
SEED        = 42
# ----------------------------------

IMPR_SPLIT = re.compile(r"\s+")
PAIR_SPLIT = re.compile(r"[-:]")

def _read_behaviors_any(path: Path) -> pd.DataFrame:
    """
    Standardize behaviors to 5 columns: ImpressionID, UserID, Time, History, Impressions.
    Works for both behaviors.tsv and behaviors_parsed.tsv variants when they contain those fields.
    Falls back to best-effort renaming.
    """
    # try no header
    df = pd.read_csv(path, sep="\t", header=None, dtype=str, engine="python", quoting=3, on_bad_lines="skip")
    if df.shape[1] >= 5:
        df = df.iloc[:, :5]
        df.columns = ["ImpressionID","UserID","Time","History","Impressions"]
    else:
        # try header-based
        df = pd.read_csv(path, sep="\t", header=0, dtype=str, engine="python", quoting=3, on_bad_lines="skip")
        expected = ["ImpressionID","UserID","Time","History","Impressions"]
        for i, col in enumerate(df.columns[:5]):
            if i < 5 and col not in expected:
                df = df.rename(columns={col: expected[i]})
        for col in expected:
            if col not in df.columns:
                df[col] = ""
        df = df[expected]
    return df.fillna("")

def _extract_news_from_row(history: str, impressions: str):
    ids = set()
    if history:
        ids |= {n for n in IMPR_SPLIT.split(history.strip()) if n}
    if impressions:
        for tok in IMPR_SPLIT.split(impressions.strip()):
            tok = tok.strip()
            if not tok:
                continue
            parts = PAIR_SPLIT.split(tok)
            if parts:
                ids.add(parts[0])
    return ids

def _collect_news_ids(behaviors_df: pd.DataFrame):
    keep = set()
    for _, r in behaviors_df.iterrows():
        keep |= _extract_news_from_row(r.get("History",""), r.get("Impressions",""))
    return keep

def _read_news_any(news_path: Path) -> pd.DataFrame:
    """
    Read news.tsv even if there is no header with 'NewsID'.
    Assumes first column is the news id if header missing.
    """
    try:
        df = pd.read_csv(news_path, sep="\t", dtype=str, engine="python", quoting=3)
        if "NewsID" not in df.columns:
            # rename first column to NewsID
            first = df.columns[0]
            df = df.rename(columns={first: "NewsID"})
    except Exception:
        df = pd.read_csv(news_path, sep="\t", header=None, dtype=str, engine="python", quoting=3)
        # pad to at least 8 columns for MIND-like format
        while df.shape[1] < 8:
            df[df.shape[1]] = ""
        df = df.iloc[:, :8]
        df.columns = ["NewsID","Category","SubCategory","Title","Abstract","URL","TitleEntities","AbstractEntities"]
    return df.fillna("")

def _filter_news_tables(split_dir: Path, out_dir: Path, keep_news: set):
    news_path = split_dir / "news.tsv"
    if news_path.exists():
        news_df = _read_news_any(news_path)
        f_news = news_df[news_df["NewsID"].astype(str).isin(keep_news)].copy()
        out_dir.mkdir(parents=True, exist_ok=True)
        f_news.to_csv(out_dir / "news.tsv", sep="\t", index=False, header=True)
    # optional parsed
    np_path = split_dir / "news_parsed.tsv"
    if np_path.exists():
        np_df = pd.read_csv(np_path, sep="\t", dtype=str, engine="python", quoting=3).fillna("")
        # find the news id column
        cand = [c for c in np_df.columns if c.lower() in ("newsid","news_id","nid","id","news")]
        if cand:
            col = cand[0]
        else:
            col = np_df.columns[0]
        np_f = np_df[np_df[col].astype(str).isin(keep_news)].copy()
        np_f.to_csv(out_dir / "news_parsed.tsv", sep="\t", index=False, header=True)

def _subset_user2int(split_dir: Path, out_dir: Path, keep_users: set, remap=False):
    u2i = split_dir / "user2int.tsv"
    if not u2i.exists():
        return {}, set()
    try:
        df = pd.read_csv(u2i, sep="\t", dtype=str, engine="python", quoting=3)
        cols = list(df.columns)
        if len(cols) < 2:
            raise ValueError
        user_col = cols[0]
        int_col  = cols[1]
    except Exception:
        df = pd.read_csv(u2i, sep="\t", header=None, dtype=str, engine="python", quoting=3)
        if df.shape[1] < 2:
            raise ValueError("user2int.tsv must have at least two columns")
        df = df.iloc[:, :2]
        df.columns = ["user","int"]
        user_col, int_col = "user", "int"

    df = df.fillna("")
    f = df[df[user_col].astype(str).isin(keep_users)].copy()
    if remap:
        f = f.sort_values(by=user_col).reset_index(drop=True)
        f[int_col] = np.arange(len(f)).astype(str)
    out_dir.mkdir(parents=True, exist_ok=True)
    f.to_csv(out_dir / "user2int.tsv", sep="\t", index=False, header=True)

    user_to_int = dict(zip(f[user_col].astype(str), f[int_col].astype(str)))
    int_set = set(f[int_col].astype(str))
    return user_to_int, int_set

def _filter_behaviors_parsed_if_present(split_dir: Path, out_dir: Path, keep_users: set, keep_user_ints: set, user_to_int: dict):
    bp = split_dir / "behaviors_parsed.tsv"
    if not bp.exists():
        return
    df = pd.read_csv(bp, sep="\t", dtype=str, engine="python", quoting=3).fillna("")
    cols_lower = {c.lower(): c for c in df.columns}

    # 1) If it has a 'userid' (any case), filter on it
    for key in ("userid","user_id","user"):
        if key in cols_lower:
            c = cols_lower[key]
            df_f = df[df[c].astype(str).isin(keep_users)].copy()
            df_f.to_csv(out_dir / "behaviors_parsed.tsv", sep="\t", index=False, header=True)
            return

    # 2) Else try to filter using user-int column if present
    #    Find the column with the highest overlap with kept user ints
    best_col, best_overlap = None, -1
    for c in df.columns:
        col_vals = set(df[c].astype(str).unique())
        overlap = len(col_vals & keep_user_ints)
        if overlap > best_overlap:
            best_overlap, best_col = overlap, c

    if best_col is not None and best_overlap > 0:
        df_f = df[df[best_col].astype(str).isin(keep_user_ints)].copy()
        df_f.to_csv(out_dir / "behaviors_parsed.tsv", sep="\t", index=False, header=True)
    else:
        # If we cannot identify a user column, just copy the original so the file exists.
        shutil.copy2(bp, out_dir / "behaviors_parsed.tsv")

def _copy_if_exists(src_dir: Path, dst_dir: Path, name: str):
    p = src_dir / name
    if p.exists():
        dst_dir.mkdir(parents=True, exist_ok=True)
        shutil.copy2(p, dst_dir / name)

def _downsample_split(split: str):
    split_in  = INPUT_ROOT / split
    split_out = OUTPUT_ROOT / split
    split_out.mkdir(parents=True, exist_ok=True)

    # 1) choose behaviors file (prefer behaviors.tsv)
    behaviors_path = None
    for fname in ["behaviors.tsv", "behaviors_parsed.tsv"]:
        p = split_in / fname
        if p.exists():
            behaviors_path = p
            break
    if behaviors_path is None:
        raise FileNotFoundError(f"No behaviors file found in {split_in}")

    # 2) read & sample users
    behaviors = _read_behaviors_any(behaviors_path)
    users = behaviors["UserID"].astype(str).unique().tolist()
    rng = np.random.default_rng(SEED)
    k = min(KEEP_USERS, len(users))
    keep_users = set(rng.choice(users, size=k, replace=False).tolist())

    # 3) filter behaviors and write the same filename back
    behaviors_f = behaviors[behaviors["UserID"].astype(str).isin(keep_users)].copy()
    # Write both variants for convenience:
    # - If original was behaviors.tsv, write behaviors.tsv
    # - Additionally, also write a filtered behaviors_parsed.tsv so you don't "miss" it
    behaviors_f.to_csv(split_out / "behaviors.tsv", sep="\t", index=False, header=True)

    # 4) collect referenced news and filter news files
    keep_news = _collect_news_ids(behaviors_f)
    _filter_news_tables(split_in, split_out, keep_news)

    # 5) user2int: subset (no remap to avoid breaking int references)
    user_to_int, keep_user_ints = _subset_user2int(split_in, split_out, keep_users, remap=False)

    # 6) behaviors_parsed.tsv: filter if present (by UserID or by int overlap). If not identifiable, copy.
    _filter_behaviors_parsed_if_present(split_in, split_out, keep_users, keep_user_ints, user_to_int)

    # 7) Ensure mapping files exist in output (copy-through)
    for name in ["category2int.tsv", "entity2int.tsv", "word2int.tsv",
                 "entity_embedding.vec", "relation_embedding.vec"]:
        _copy_if_exists(split_in, split_out, name)

    print(f"[{split}] kept users: {len(keep_users):,} | kept news: {len(keep_news):,}")

# --- run for train & val ---
OUTPUT_ROOT.mkdir(parents=True, exist_ok=True)
for split in ["train", "val"]:
    if (INPUT_ROOT / split).exists():
        _downsample_split(split)
    else:
        print(f"Skipping missing split: {split}")

print("Done.")


[train] kept users: 20,000 | kept news: 38,341
[val] kept users: 20,000 | kept news: 30,105
Done.
