In [12]:
# CELL 1: Imports & CONFIG
import os
import json
import random
import re
import string
from collections import Counter, defaultdict
from tqdm import tqdm
import math

import pandas as pd
import numpy as np

# ---------- CONFIG ----------
RAW_JSON = "../data/raw/data.json"                  # input (adjust if needed)
OUTPUT_CSV = "../data/processed/balanced_dataset.csv"

# Create the output directory up front so the final to_csv cannot fail on a
# missing folder. A bare filename has an empty dirname, and os.makedirs("")
# raises FileNotFoundError, so only create the directory when there is one.
_output_dir = os.path.dirname(OUTPUT_CSV)
if _output_dir:
    os.makedirs(_output_dir, exist_ok=True)

# Seed both the stdlib and NumPy global generators for reproducible runs.
RANDOM_STATE = 42
random.seed(RANDOM_STATE)
np.random.seed(RANDOM_STATE)

# target range (11k - 11.5k): per-label sample size is drawn from this interval
TARGET_MIN = 11000
TARGET_MAX = 11500

# text-length buckets (words); both bounds are inclusive upper limits
SHORT_MAX = 19      # <20 words -> short
MEDIUM_MAX = 100    # 20-100 words -> medium
# long -> >100 words

# minimal combined text length in characters to keep (very short posts removed)
MIN_CHARS_KEEP = 20

print("Config:")
print(" RAW_JSON:", RAW_JSON)
print(" OUTPUT_CSV:", OUTPUT_CSV)
print(" Random seed:", RANDOM_STATE)
print(" Target range per label:", TARGET_MIN, "-", TARGET_MAX)


Config:
 RAW_JSON: ../data/raw/data.json
 OUTPUT_CSV: ../data/processed/balanced_dataset.csv
 Random seed: 42
 Target range per label: 11000 - 11500


In [13]:
# CELL 2: Load raw JSON and show counts
with open(RAW_JSON, "r", encoding="utf-8") as raw_file:
    raw = json.load(raw_file)

print("Total raw entries:", len(raw))

df = pd.DataFrame(raw)
# Guarantee the columns used by later cells exist; an absent one is created
# as a column of empty strings.
for col in ["id","title","selftext","label"]:
    if col in df.columns:
        continue
    df[col] = ""

print("Columns available:", df.columns.tolist())
print("Initial label distribution:")
# Missing labels are shown as "missing" for this printout only; df is not modified.
labels_for_display = df["label"].fillna("missing")
print(Counter(labels_for_display.values))


Total raw entries: 122534
Columns available: ['id', 'subreddit', 'title', 'selftext', 'label', 'created_utc', 'url']
Initial label distribution:
Counter({'depression': 36884, 'suicidal': 33736, 'anxiety': 30725, 'normal': 21189})


In [14]:
# CELL 3: Cleaning helpers and execution

def normalize_whitespace(s):
    """Collapse each run of whitespace in `s` to a single space and trim both ends."""
    # str.split() with no argument splits on whitespace runs and discards
    # leading/trailing whitespace, so re-joining with one space does both jobs.
    return " ".join(s.split())

# Translation table that deletes every character in string.punctuation (ASCII punctuation).
_trans_table = str.maketrans("", "", string.punctuation)

def normalize_text_for_dup(s):
    """
    Build the key used for duplicate detection.

    The text is lowercased, ASCII punctuation is deleted, and whitespace runs
    are collapsed to single spaces with the ends trimmed.
    """
    stripped_of_punct = s.lower().translate(_trans_table)
    return " ".join(stripped_of_punct.split())

# 1) Coerce title & selftext to strings (missing -> "") with normalized whitespace
for text_col in ("title", "selftext"):
    df[text_col] = df[text_col].fillna("").astype(str).map(normalize_whitespace)

# 2) Strict rule (user requested): keep a row only when BOTH title and selftext are non-empty
title_present = df["title"].str.strip() != ""
selftext_present = df["selftext"].str.strip() != ""
mask_keep = title_present & selftext_present
df = df[mask_keep].reset_index(drop=True)
print("After dropping rows where title OR selftext is empty:", len(df))

# 3) Build the combined text and discard rows whose combined text is too short
df["combined"] = (df["title"] + " " + df["selftext"]).str.strip()
long_enough = df["combined"].str.len() >= MIN_CHARS_KEEP
df = df[long_enough].reset_index(drop=True)
print(f"After removing combined length < {MIN_CHARS_KEEP} chars:", len(df))

# 4) One row per id; the first occurrence wins
df = df.drop_duplicates(subset=["id"], keep="first").reset_index(drop=True)
print("After dropping duplicate ids:", len(df))

# 5) One row per normalized combined text (case/punctuation/spacing-insensitive)
df["norm_combined"] = df["combined"].map(normalize_text_for_dup)
df = df.drop_duplicates(subset=["norm_combined"], keep="first").reset_index(drop=True)
print("After dropping duplicates by normalized text:", len(df))

# Final counts
print("Label counts after cleaning & de-dup:")
label_counts = Counter(df["label"].values)
for label_name, n_rows in label_counts.items():
    print(f"  {label_name}: {n_rows}")


After dropping rows where title OR selftext is empty: 114872
After removing combined length < 20 chars: 114860
After dropping duplicate ids: 114860
After dropping duplicates by normalized text: 114077
Label counts after cleaning & de-dup:
  suicidal: 32795
  depression: 36048
  anxiety: 29770
  normal: 15464


In [15]:
# CELL 4: Assign text-length buckets and define sampling function

def word_count_bucket(s):
    """
    Return the length bucket of text `s` based on its whitespace-delimited word count:
    "short" if the count is <= SHORT_MAX, "medium" if it is <= MEDIUM_MAX, else "long".
    """
    n_words = len(s.split())
    if n_words <= SHORT_MAX:
        return "short"
    return "medium" if n_words <= MEDIUM_MAX else "long"

# Per-row word count ("wc") and length bucket, both computed from the combined text
combined_text = df["combined"]
df["wc"] = combined_text.map(lambda text: len(text.split()))
df["bucket"] = combined_text.map(word_count_bucket)

print("Bucket distribution (overall):")
bucket_distribution = df["bucket"].value_counts()
print(bucket_distribution)

# sampling helper: proportional by bucket within a class
def stratified_sample_by_buckets(class_df, target, seed=RANDOM_STATE):
    """
    Downsample one label's rows to `target` rows while keeping each length
    bucket's share of the sample proportional to its share of `class_df`.

    class_df: DataFrame for one label with a 'bucket' column.
    target: desired number of rows; must be >= 0.
    seed: random_state passed to every DataFrame.sample call in this function.

    Returns a shuffled DataFrame with a fresh 0..n-1 index. When class_df has
    more than `target` rows the result has exactly `target` rows; otherwise
    all rows of class_df are returned (shuffled) and nothing is upsampled.

    Raises ValueError if `target` is negative.
    """
    if target < 0:
        # Bucket targets are clamped at 0 below, so the "subtract" adjustment
        # loop could never reach a negative target and would never terminate.
        raise ValueError(f"target must be non-negative, got {target}")

    if len(class_df) <= target:
        # nothing to downsample: return every row, shuffled
        return class_df.sample(frac=1, random_state=seed).reset_index(drop=True)

    # Work on a positional index. The refill step below removes already-taken
    # rows with drop(index=...), which would also remove untaken rows sharing a
    # label if the caller's index had repeats. Sampling is positional, so this
    # does not change which rows are picked.
    class_df = class_df.reset_index(drop=True)

    # compute bucket proportions
    bucket_counts = class_df["bucket"].value_counts().to_dict()
    total = sum(bucket_counts.values())
    # initial bucket targets (rounded)
    bucket_targets = {b: max(0, int(round((cnt/total)*target))) for b,cnt in bucket_counts.items()}

    # adjust rounding to sum to target
    assigned = sum(bucket_targets.values())
    # if assigned != target, adjust by adding/subtracting 1 to buckets by largest fractional remainders
    if assigned != target:
        # exact (unrounded) share minus the rounded target, per bucket
        fractions = {b: (bucket_counts[b]/total)*target - bucket_targets[b] for b in bucket_counts}
        # sort buckets by descending fractional part if we need to add, ascending if subtract
        diff = target - assigned
        if diff > 0:
            order = sorted(fractions.items(), key=lambda x: -x[1])
            i = 0
            while diff > 0:
                bucket_targets[order[i % len(order)][0]] += 1
                diff -= 1
                i += 1
        elif diff < 0:
            order = sorted(fractions.items(), key=lambda x: x[1])
            i = 0
            while diff < 0:
                b = order[i % len(order)][0]
                # never push a bucket target below zero; skip to the next bucket
                if bucket_targets[b] > 0:
                    bucket_targets[b] -= 1
                    diff += 1
                i += 1

    # Now sample per bucket; if a bucket doesn't have enough, take all and reassign remainder
    sampled_parts = []
    remaining_to_fill = 0
    for b, btarget in bucket_targets.items():
        bucket_df = class_df[class_df["bucket"] == b]
        if len(bucket_df) >= btarget:
            sampled_parts.append(bucket_df.sample(n=btarget, random_state=seed))
        else:
            # take all and accumulate deficit
            sampled_parts.append(bucket_df)
            remaining_to_fill += (btarget - len(bucket_df))

    if remaining_to_fill > 0:
        # collect pool of rows not already taken
        taken_idx = pd.concat(sampled_parts).index
        pool = class_df.drop(index=taken_idx)
        if len(pool) >= remaining_to_fill:
            sampled_parts.append(pool.sample(n=remaining_to_fill, random_state=seed))
        else:
            # Defensive fallback (introduces duplicate rows): sample with
            # replacement from class_df to meet exact size. Should not be
            # needed here because class_df has more than `target` rows.
            extra = class_df.sample(n=remaining_to_fill, replace=True, random_state=seed)
            sampled_parts.append(extra)

    # shuffle the per-bucket pieces together and renumber the rows
    result = pd.concat(sampled_parts, axis=0).sample(frac=1, random_state=seed).reset_index(drop=True)
    # sanity: trim/extend to target if necessary
    if len(result) > target:
        result = result.sample(n=target, random_state=seed).reset_index(drop=True)
    elif len(result) < target:
        # append random rows with replacement (defensive; duplicates rows)
        extra = class_df.sample(n=(target - len(result)), replace=True, random_state=seed)
        result = pd.concat([result, extra], axis=0).reset_index(drop=True)
    return result


Bucket distribution (overall):
bucket
long      82173
medium    27858
short      4046
Name: count, dtype: int64


In [16]:
# CELL 5: Compute a target per label (random in [TARGET_MIN, TARGET_MAX]) and perform stratified sampling
labels = sorted(df["label"].unique())
print("Labels present:", labels)

# Draw the targets from a dedicated generator seeded with RANDOM_STATE instead of
# the global `random` state. The global state advances on every call, so
# re-running this cell on its own used to produce different targets. A freshly
# seeded random.Random yields the same sequence as random.seed(RANDOM_STATE)
# followed directly by these draws.
target_rng = random.Random(RANDOM_STATE)

label_targets = {}
for lbl in labels:
    # choose target in range (both bounds inclusive)
    t = target_rng.randint(TARGET_MIN, TARGET_MAX)
    # safety: if fewer than t rows are available for this label, cap t at the
    # available count (we only downsample here)
    available = int((df["label"] == lbl).sum())
    if available < t:
        t = available
    label_targets[lbl] = t

print("Targets per label (will not upsample):")
print(label_targets)

# perform sampling, one label at a time
sampled_parts = []
for lbl in tqdm(labels, desc="Sampling labels"):
    class_df = df[df["label"] == lbl].reset_index(drop=True)
    target = label_targets[lbl]
    if len(class_df) <= target:
        print(f"  {lbl}: has {len(class_df)} <= target {target}, keeping all")
        sampled = class_df.sample(frac=1, random_state=RANDOM_STATE).reset_index(drop=True)
    else:
        print(f"  {lbl}: downsampling {len(class_df)} -> {target}")
        sampled = stratified_sample_by_buckets(class_df, target, seed=RANDOM_STATE)
    sampled_parts.append(sampled)

# Concatenate the per-label samples and shuffle so labels are interleaved.
balanced_df = pd.concat(sampled_parts, axis=0).sample(frac=1, random_state=RANDOM_STATE).reset_index(drop=True)
print("Combined balanced size:", len(balanced_df))
print("Label distribution after sampling:")
print(Counter(balanced_df["label"].values))


Labels present: ['anxiety', 'depression', 'normal', 'suicidal']
Targets per label (will not upsample):
{'anxiety': 11327, 'depression': 11057, 'normal': 11012, 'suicidal': 11379}


Sampling labels:  25%|██▌       | 1/4 [00:00<00:00,  7.35it/s]

  anxiety: downsampling 29770 -> 11327


Sampling labels:  50%|█████     | 2/4 [00:00<00:00,  8.09it/s]

  depression: downsampling 36048 -> 11057
  normal: downsampling 15464 -> 11012
  suicidal: downsampling 32795 -> 11379


Sampling labels: 100%|██████████| 4/4 [00:00<00:00,  8.83it/s]

Combined balanced size: 44775
Label distribution after sampling:
Counter({'suicidal': 11379, 'anxiety': 11327, 'depression': 11057, 'normal': 11012})





In [17]:
# CELL 6: Finalize columns and save CSV with only id, title, selftext, label
# Keep just the export columns; as a safety net, keep one row per id (first wins).
final_df = (
    balanced_df[["id", "title", "selftext", "label"]]
    .drop_duplicates(subset=["id"], keep="first")
    .reset_index(drop=True)
)
print("Final dataset size after ensuring unique ids:", len(final_df))

# Write the dataset without the positional index column
final_df.to_csv(OUTPUT_CSV, index=False, encoding="utf-8")
print("Saved balanced dataset to:", OUTPUT_CSV)

# Report the per-label row counts of what was written
print("Final counts per label:")
final_label_counts = Counter(final_df["label"].values)
print(final_label_counts)


Final dataset size after ensuring unique ids: 44775
Saved balanced dataset to: ../data/processed/balanced_dataset.csv
Final counts per label:
Counter({'suicidal': 11379, 'anxiety': 11327, 'depression': 11057, 'normal': 11012})
