In [None]:
# ======================================================
# CTC per-source balanced subsetter (Colab 1-click)
# - Streams the Reddit, Stackexchange and arXiv JSON files out of the Zenodo tar
# - Keeps PER_SOURCE_TOTAL records per source; with BALANCED_PER_SOURCE that is
#   PER_SOURCE_TOTAL // 2 non-cyber + PER_SOURCE_TOTAL // 2 cyber
#   (currently 20,000 = 10k + 10k per source, 60k combined)
# - Saves one file per source + one combined file (every record has a "source" field)
# ======================================================

# -------- CONFIG --------
PER_SOURCE_TOTAL = 20_000      # total records kept per source
BALANCED_PER_SOURCE = True     # keep 50/50 cyber vs non-cyber within each source
USE_GOOGLE_DRIVE = True        # mount Google Drive so OUT_DIR can live under MyDrive
OUT_DIR = "/content/drive/MyDrive/CTC_by_source"  # folder for outputs (Drive or local)
RANDOM_SEED = 1337

ZENODO_URL = "https://zenodo.org/records/10655913/files/CTC_training_data.tar.gz?download=1"

# File names
# NOTE(review): these names are not passed to make_per_source_balanced in the
# "Run it" section, so they do not control what is written. Their 10k/30k
# suffixes are also stale since PER_SOURCE_TOTAL = 20_000 (20k per source,
# 60k combined); they are kept as-is because downstream cells may load files
# by these names.
REDDIT_OUT = "CTC_Reddit_10k.json"
STACK_OUT  = "CTC_Stackexchange_10k.json"
ARXIV_OUT  = "CTC_arXiv_10k.json"
COMBINED_OUT = "CTC_by_source_30k.json"

# -------- Installs --------
import sys, subprocess, importlib.util
def _pip(pkgs):
    """Quietly pip-install *pkgs* into the running interpreter (raises CalledProcessError on failure)."""
    command = [sys.executable, "-m", "pip", "install", "-q"]
    command.extend(pkgs)
    subprocess.run(command, check=True)


# Install only the third-party packages that are not importable yet.
_REQUIRED_PACKAGES = ("ijson", "requests", "tqdm")
need = [mod for mod in _REQUIRED_PACKAGES if importlib.util.find_spec(mod) is None]
if need:
    _pip(need)

# -------- Imports --------
import json, re, tarfile, random, posixpath
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import ijson, requests
from tqdm import tqdm

# -------- Mount Drive (optional) --------
if USE_GOOGLE_DRIVE:
    # Mount Google Drive at /content/drive so OUT_DIR (under /content/drive/MyDrive)
    # is writable. google.colab is presumably only available inside a Colab runtime.
    from google.colab import drive
    drive.mount("/content/drive")

# -------- Helpers --------
def detect_source_and_label(member_name: str) -> Tuple[Optional[str], Optional[int]]:
    """
    Map a tar member path to its (source, label) pair.

    Expected paths look like
      CTC_training_data_text_export/Reddit/not_cybersecurity.json
      CTC_training_data_text_export/Stackexchange/cybersecurity.json
      CTC_training_data_text_export/arXiv/not_cybersecurity.json

    source: the parent folder name (e.g. "Reddit", "Stackexchange", "arXiv"),
            or None when the path has fewer than three "/"-separated parts.
    label:  0 for a not/non-cybersecurity file, 1 for a cybersecurity file,
            None when the file name mentions neither.
    """
    pieces = member_name.strip().split("/")
    base = pieces[-1].lower()
    # Folder name is case-sensitive in the Zenodo archive, so it is not lowercased.
    src = pieces[-2] if len(pieces) >= 3 else None

    if re.search(r'(?:^|/)(?:not|non)[-_]?cybersecurity\.json$', base, re.I):
        return src, 0
    if re.search(r'(?:^|/)cybersecurity\.json$', base, re.I):
        return src, 1

    # Looser fallback for unexpected names: any "cyber" file, negated by a not/non marker.
    if "cyber" not in base:
        return src, None
    negated = any(marker in base for marker in ("not_", "not-", "non_", "non-"))
    return src, (0 if negated else 1)

def reservoir_insert(reservoir: List[dict], k: int, item: dict, seen: int, rng: random.Random):
    """
    Offer *item* to a size-*k* reservoir (one step of Algorithm R).

    *seen* is the 1-based count of items offered so far, including *item*.
    While the reservoir holds fewer than *k* entries, *item* is appended.
    Otherwise a slot is drawn uniformly from [0, seen) and *item* overwrites it
    when that slot is < k (i.e. with probability k / seen).
    Mutates *reservoir* in place; returns None.
    """
    if len(reservoir) >= k:
        slot = rng.randint(1, seen) - 1  # randint is inclusive on both ends
        if slot < k:
            reservoir[slot] = item
        return
    reservoir.append(item)

def write_json_array(path: Path, items: List[dict]):
    """
    Write *items* to *path* as a JSON array, one element per line.

    Missing parent directories are created. The file is UTF-8 and non-ASCII
    text is written verbatim (ensure_ascii=False). An empty list produces
    "[\\n\\n]\\n".
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as fh:
        fh.write("[\n")
        separator = ""
        for obj in items:
            fh.write(separator)  # nothing before the first element
            json.dump(obj, fh, ensure_ascii=False)
            separator = ",\n"
        fh.write("\n]\n")

def make_per_source_balanced(
    url: str,
    out_dir: Path,
    per_source_total: int = 10_000,
    balanced: bool = True,
    seed: int = 1337,
    stop_when_full: bool = False,
    file_names: Optional[Dict[str, str]] = None,
    combined_name: str = "CTC_by_source_30k.json",
    timeout: float = 60.0,
) -> Dict:
    """
    Stream the CTC tar.gz at *url* and write a per-source subset.

    Sources are "Reddit", "Stackexchange" and "arXiv". Each ``.json`` member is
    mapped to (source, label) by ``detect_source_and_label``. Each string item
    in it becomes a record ``{"text": item, "label": 0|1, "source": source}``.

    Sampling:
      - balanced=True: one reservoir per (source, label), each holding
        ``per_source_total // 2`` records.
      - balanced=False: one reservoir per source holding ``per_source_total``
        records drawn from both labels together (class mix follows the data).
      - stop_when_full=False (default): every item is offered to its reservoir,
        so the kept records are a uniform random sample of each bucket. This
        reads the whole archive.
      - stop_when_full=True: stop reading a member as soon as its bucket is
        full, and stop the download once every bucket is full. This is much
        faster, but the subset is just the first items of each file.

    Args:
      url: download URL of the ``.tar.gz`` archive (streamed, never saved).
      out_dir: output folder; created if missing.
      per_source_total: records to keep per source.
      balanced: see "Sampling".
      seed: seeds both the sampling and the final shuffles.
      stop_when_full: see "Sampling".
      file_names: optional {source: file name} overrides for the per-source files.
      combined_name: file name of the combined output.
      timeout: requests connect/read timeout in seconds.

    Returns:
      {"paths": {source or "combined": str path}, "diags": counters}.

    Raises:
      requests.HTTPError: if the download returns an error status.
    """
    rng = random.Random(seed)
    sources = ["Reddit", "Stackexchange", "arXiv"]
    names = {
        "Reddit": "CTC_Reddit_10k.json",
        "Stackexchange": "CTC_Stackexchange_10k.json",
        "arXiv": "CTC_arXiv_10k.json",
    }
    if file_names:
        names.update(file_names)

    # Bucket keys: (source, label) when balanced, (source, "all") otherwise, so the
    # unbalanced mode really mixes both classes instead of dropping label 1.
    if balanced:
        k_per_class = per_source_total // 2
        targets = {(s, c): k_per_class for s in sources for c in (0, 1)}
    else:
        targets = {(s, "all"): per_source_total for s in sources}
    res = {bucket: [] for bucket in targets}
    seen = {bucket: 0 for bucket in targets}

    diags = {
        "members_seen": 0,
        "items_seen_total": 0,
        "per_source_used": {s: {0: 0, 1: 0} for s in sources},
        "per_source_seen": {s: {0: 0, 1: 0} for s in sources},
        "members_skipped": 0,
    }

    def all_filled():
        return all(len(res[b]) >= targets[b] for b in targets)

    with requests.get(url, stream=True, timeout=timeout) as r:
        r.raise_for_status()
        with tarfile.open(fileobj=r.raw, mode="r|gz") as tf:
            for member in tf:
                if not member.name.endswith(".json"):
                    continue
                diags["members_seen"] += 1

                source, label = detect_source_and_label(member.name)
                if source not in sources or label not in (0, 1):
                    diags["members_skipped"] += 1
                    continue

                bucket = (source, label if balanced else "all")
                need = targets[bucket]
                reservoir = res[bucket]
                # Nothing to gain from this member: zero quota, or fast mode and already full.
                if need <= 0 or (stop_when_full and len(reservoir) >= need):
                    continue

                fobj = tf.extractfile(member)
                if fobj is None:
                    diags["members_skipped"] += 1
                    continue

                with tqdm(total=None, unit="items",
                          desc=f"Streaming {member.name} -> {source}, label={label}") as pbar:
                    for item in ijson.items(fobj, "item"):
                        diags["items_seen_total"] += 1
                        if not isinstance(item, str):
                            continue
                        diags["per_source_seen"][source][label] += 1
                        seen[bucket] += 1
                        # Offer every item: replacement after the reservoir fills is
                        # what makes the kept subset uniformly random.
                        obj = {"text": item, "label": label, "source": source}
                        reservoir_insert(reservoir, need, obj, seen[bucket], rng)
                        pbar.update(1)
                        if stop_when_full and len(reservoir) >= need:
                            break

                if stop_when_full and all_filled():
                    break

    # write per-source files and collect the combined set
    out_dir.mkdir(parents=True, exist_ok=True)
    paths = {}
    combined = []
    for s in sources:
        kept = [obj for bucket, objs in res.items() if bucket[0] == s for obj in objs]
        combined.extend(kept)
        for c in (0, 1):
            diags["per_source_used"][s][c] = sum(1 for obj in kept if obj["label"] == c)
        items = list(kept)
        rng.shuffle(items)
        p = out_dir / names[s]
        write_json_array(p, items)
        paths[s] = str(p)

    # write combined file
    rng.shuffle(combined)
    combined_path = out_dir / combined_name
    write_json_array(combined_path, combined)
    paths["combined"] = str(combined_path)

    return {"paths": paths, "diags": diags}

# -------- Run it --------
# Build the per-source subsets from the Zenodo archive using the CONFIG above.
out_dir = Path(OUT_DIR)
result = make_per_source_balanced(
    url=ZENODO_URL,
    out_dir=out_dir,
    per_source_total=PER_SOURCE_TOTAL,
    balanced=BALANCED_PER_SOURCE,
    seed=RANDOM_SEED,
)

# Report where everything went, then dump the counters.
print("\n✅ Done.")
print("Files written:")
for which, written_path in result["paths"].items():
    print(f"  {which}: {written_path}")

print("\nDiagnostics:")
print(json.dumps(result["diags"], indent=2))

# quick sanity: reload every written file and report size, label/source balance
# and a short text preview.
from collections import Counter


def _preview(text, width=120):
    """Return the first *width* characters of *text*, with "…" appended if it was cut."""
    return text[:width] + ("…" if len(text) > width else "")


for s, p in result["paths"].items():
    if s == "combined":
        continue
    data = json.loads(Path(p).read_text(encoding="utf-8"))
    print(f"\n{s}: loaded {len(data)}")
    print("label counts:", Counter(d["label"] for d in data))
    # A source absent from the archive yields an empty file; data[0] would raise IndexError.
    print("sample:", _preview(data[0]["text"]) if data else "<empty>")
# combined
cp = result["paths"]["combined"]
cdata = json.loads(Path(cp).read_text(encoding="utf-8"))
print(f"\ncombined: loaded {len(cdata)}")
print("label counts:", Counter(d["label"] for d in cdata))
print("source counts:", Counter(d["source"] for d in cdata))

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


Streaming CTC_training_data_text_export/Reddit/not_cybersecurity.json -> Reddit, label=0: 4184184items [00:34, 121534.35items/s]
Streaming CTC_training_data_text_export/Reddit/cybersecurity.json -> Reddit, label=1: 164750items [00:01, 137964.91items/s]
Streaming CTC_training_data_text_export/arXiv/not_cybersecurity.json -> arXiv, label=0: 28996items [00:25, 1117.91items/s]
Streaming CTC_training_data_text_export/arXiv/cybersecurity.json -> arXiv, label=1: 12132items [00:13, 906.16items/s]
Streaming CTC_training_data_text_export/Stackexchange/not_cybersecurity.json -> Stackexchange, label=0: 4842461items [01:57, 41220.27items/s] 
Streaming CTC_training_data_text_export/Stackexchange/cybersecurity.json -> Stackexchange, label=1: 10000items [00:00, 60432.91items/s]



✅ Done.
Files written:
  Reddit: /content/drive/MyDrive/CTC_by_source/CTC_Reddit_10k.json
  Stackexchange: /content/drive/MyDrive/CTC_by_source/CTC_Stackexchange_10k.json
  arXiv: /content/drive/MyDrive/CTC_by_source/CTC_arXiv_10k.json
  combined: /content/drive/MyDrive/CTC_by_source/CTC_by_source_30k.json

Diagnostics:
{
  "members_seen": 6,
  "items_seen_total": 9242523,
  "per_source_used": {
    "Reddit": {
      "0": 10000,
      "1": 10000
    },
    "Stackexchange": {
      "0": 10000,
      "1": 10000
    },
    "arXiv": {
      "0": 10000,
      "1": 10000
    }
  },
  "per_source_seen": {
    "Reddit": {
      "0": 4184184,
      "1": 164750
    },
    "Stackexchange": {
      "0": 4842461,
      "1": 10000
    },
    "arXiv": {
      "0": 28996,
      "1": 12132
    }
  },
  "members_skipped": 0
}

Reddit: loaded 20000
label counts: Counter({0: 10000, 1: 10000})
sample: Look both ways before crossing... ok, at least look one way... no, not that way.... 

Stackexchange: load

In [None]:
def prepare_splits_for_texts(X, y):
    """
    Stratified train / validation / test split of (X, y).

    The test set is split off first (test_size=TEST_SIZE, stratified on y).
    The validation set is then split from the remainder
    (test_size=VAL_SIZE_WITHIN_TRAIN, stratified on the remaining labels).
    Both splits use random_state=RANDOM_STATE.

    Returns (X_train, y_train, X_val, y_val, X_test, y_test).
    """
    X_rest, X_test_text, y_rest, y_test = train_test_split(
        X, y, test_size=TEST_SIZE, random_state=RANDOM_STATE, stratify=y
    )
    X_train_text, X_val_text, y_train, y_val = train_test_split(
        X_rest,
        y_rest,
        test_size=VAL_SIZE_WITHIN_TRAIN,
        random_state=RANDOM_STATE,
        stratify=y_rest,
    )
    return X_train_text, y_train, X_val_text, y_val, X_test_text, y_test

def bench_source(name: str, X: List[str], y: List[int]):
    """
    Split (X, y) and evaluate every method in METHODS on that one source.

    Returns a DataFrame with one row per method holding the source name,
    method name, test accuracy and whatever timing fields train_eval_method
    returns.
    """
    print(f"\n===== {name}: {len(X)} samples =====")
    # (X_train, y_train, X_val, y_val, X_test, y_test) — the order train_eval_method expects.
    splits = prepare_splits_for_texts(X, y)
    rows = []
    for method in METHODS.keys():
        print(f"\n[{name}] Method: {method}")
        acc, _report, timings = train_eval_method(*splits, method)
        rows.append({"source": name, "method": method, "accuracy": acc, **timings})
        print(f"  accuracy={acc:.4f}")
    return pd.DataFrame(rows)

# Load per-source data
(Xr, yr), (Xs, ys), (Xa, ya) = (
    load_ctc_json(RED_PATH),
    load_ctc_json(STK_PATH),
    load_ctc_json(ARX_PATH),
)

# Benchmark every method on each source, stack the results into one table and save it.
df_r, df_s, df_a = (
    bench_source("Reddit", Xr, yr),
    bench_source("Stackexchange", Xs, ys),
    bench_source("arXiv", Xa, ya),
)

df_all = pd.concat([df_r, df_s, df_a], ignore_index=True)
df_all.to_csv(WORKDIR / "vectorizer_bench_results.csv", index=False)
df_all.head()

In [None]:
def plot_accuracy_bars(df, title):
    """
    Bar chart of mean accuracy per method, best method first.

    *df* needs "method" and "accuracy" columns. Rows that share a method
    (e.g. one row per source) are averaged into a single bar, so bars never
    overdraw each other.
    """
    mean_acc = df.groupby("method")["accuracy"].mean().sort_values(ascending=False)
    plt.figure(figsize=(10, 5))
    plt.bar(mean_acc.index.tolist(), mean_acc.tolist())
    plt.xticks(rotation=60, ha="right")
    plt.ylim(0, 1)
    plt.title(title)
    plt.ylabel("Accuracy")
    plt.grid(axis="y", alpha=0.3)
    plt.show()

def plot_accuracy_per_source(df_all):
    """Draw one accuracy-by-method bar chart per source, methods sorted best-first."""
    for source_name in ("Reddit", "Stackexchange", "arXiv"):
        ranked = df_all.loc[df_all["source"] == source_name].sort_values(
            "accuracy", ascending=False
        )
        plt.figure(figsize=(10, 5))
        plt.bar(ranked["method"], ranked["accuracy"])
        plt.xticks(rotation=60, ha="right")
        plt.ylim(0, 1)
        plt.title(f"{source_name} — Accuracy by method")
        plt.ylabel("Accuracy")
        plt.grid(axis="y", alpha=0.3)
        plt.show()

def plot_timing(df_all, timing_key: str, title: str):
    """
    Draw one bar chart per source of the *timing_key* column (seconds) per method,
    fastest first. Sources whose frame lacks that column are skipped.
    """
    for source_name in ("Reddit", "Stackexchange", "arXiv"):
        subset = df_all[df_all["source"] == source_name]
        if timing_key in subset.columns:
            fastest_first = subset.sort_values(timing_key, ascending=True)
            plt.figure(figsize=(10, 5))
            plt.bar(fastest_first["method"], fastest_first[timing_key])
            plt.xticks(rotation=60, ha="right")
            plt.title(f"{source_name} — {title}")
            plt.ylabel("Seconds")
            plt.grid(axis="y", alpha=0.3)
            plt.show()

plot_accuracy_per_source(df_all)
# (column in df_all, chart title) for each timing measurement, drawn in this order.
for _timing_col, _timing_title in (
    ("fit_vectorizer_s", "Vectorizer/Embedder Fit Time (s)"),
    ("transform_s", "Transform/Embed Time (s)"),
    ("fit_clf_s", "Classifier Fit Time (s)"),
    ("infer_s", "Inference Time (s)"),
):
    plot_timing(df_all, _timing_col, _timing_title)