## Imports

In [1]:
import os
import re
import json
from pathlib import Path
from typing import Iterator, Iterable, Dict, List, Tuple, Optional
from collections import defaultdict, deque

import pandas as pd
import numpy as np
import glob

import spacy
from spacy.lang.en.stop_words import STOP_WORDS
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.decomposition import LatentDirichletAllocation as LDA

import pyarrow as pa
import pyarrow.parquet as pq

# display prefs: let long text cells (titles, bodies) show up to 300 characters
pd.options.display.max_colwidth = 300


## Configuration

In [2]:
# configurations
# raw reddit dumps (newline-delimited JSON, one record per line)
data_dir = Path('data')
submissions_file = data_dir / 'amitheasshole_submissions.ndjson'
comments_file = data_dir / 'amitheasshole_comments.ndjson'

# flip to True to force the cached artefacts below to be rebuilt
remake_datafile = False

# output folders, created up front so later cells can write straight into them
interim_dir = data_dir / 'parquet'
artifacts_dir = Path('artifacts')
for p in (interim_dir, artifacts_dir):
    p.mkdir(exist_ok=True, parents=True)

# parquet outputs
submissions_parquet = interim_dir / 'submissions_minimal.parquet'
joined_dir = interim_dir / 'joined_parquet'

# number of records to materialise when peeking at a dump
peek_n = 1_000

# modelling sample size and RNG seed
submission_sample_n = 1_000
random_seed = 42


## ndjson loading and displaying

In [3]:
# the comments file is ~22 GB; never load fully into RAM. stream lines lazily.

def iter_ndjson_lines(path: Path) -> Iterator[dict]:
    """
    Lazily yield one decoded JSON record per line of an ndjson file.

    Blank lines are skipped. Lines that fail to decode are also skipped, as a best-effort
    guard against corrupted lines in large dumps, so a single bad line never aborts a scan.
    """
    with path.open('r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue
            try:
                yield json.loads(line)
            except json.JSONDecodeError:
                # corrupted line guard; skip
                continue

def peek_ndjson(path: Path, n: int) -> pd.DataFrame:
    """
    Materialise only the first `n` decodable records of an ndjson file as a DataFrame.

    Intended for learning the schema of very large dumps without reading them fully.
    `n <= 0` returns an empty DataFrame without opening the file.
    """
    from itertools import islice  # local import keeps this notebook cell self-contained

    records = iter_ndjson_lines(path)
    try:
        rows = list(islice(records, max(n, 0)))
    finally:
        # release the file handle now instead of whenever the generator is garbage-collected
        records.close()
    return pd.DataFrame(rows)

def normalize_submission_id_from_link_id(link_id: str) -> Optional[str]:
    """
    Strip the type prefix from a comment's `link_id` ('t3_<id>' -> '<id>').

    Everything after the first underscore is returned; an id without an underscore is
    returned unchanged, and a falsy `link_id` (None or '') yields None.
    """
    if not link_id:
        return None
    _prefix, underscore, remainder = link_id.partition('_')
    if underscore:
        return remainder
    return link_id


In [None]:
# materialise just the first `peek_n` records of each dump to learn their schemas
df_sub_peek = peek_ndjson(submissions_file, n=peek_n)
df_com_peek = peek_ndjson(comments_file, n=peek_n)

print('submissions columns:', sorted(df_sub_peek.columns))
print('comments columns:', sorted(df_com_peek.columns))

display(df_sub_peek.head(n=5))
display(df_com_peek.head(n=5))


## Load, convert and combine submissions and comments

#### submissions to parquet

In [None]:
# shrink submissions to only the fields we need and store them in a fast format
sub_fields = ['id', 'title', 'selftext', 'link_flair_text', 'created_utc']

def stream_submissions_to_parquet(src: Path, dst: Path, fields=sub_fields, batch_size: int = 100_000):
    """
    Stream an ndjson submissions dump into one parquet file, keeping only `fields`.

    Records are buffered and flushed every `batch_size` rows, so the dump never has to fit
    in memory. Blank lines, undecodable lines and non-object JSON values are skipped.
    Fields missing from a record become nulls.

    Schema handling:
    - The Arrow schema is taken from the first flushed batch. Any column that is entirely
      null in that batch (inferred as Arrow's `null` type) is widened to string.
    - Every later batch is cast to that schema, because ParquetWriter rejects tables whose
      schema differs from the one the file was opened with (e.g. a batch in which
      `link_flair_text` is never set).

    If `src` yields no records, an empty all-string table is written so `dst` still exists
    for downstream readers. The writer is always closed, even on error.
    """
    writer = None
    schema = None
    rows = []

    def _flush() -> None:
        # convert the buffered dicts to Arrow, align them with the file schema, append, reset
        nonlocal writer, schema
        table = pa.Table.from_pandas(pd.DataFrame(rows, columns=list(fields)), preserve_index=False)
        if schema is None:
            schema = pa.schema([
                pa.field(fld.name, pa.string()) if pa.types.is_null(fld.type) else fld
                for fld in table.schema
            ])
            writer = pq.ParquetWriter(dst, schema)
        if not table.schema.equals(schema):
            table = table.cast(schema)
        writer.write_table(table)
        rows.clear()

    try:
        with src.open('r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue
                try:
                    obj = json.loads(line)
                except json.JSONDecodeError:
                    continue
                if not isinstance(obj, dict):
                    # valid JSON but not a submission object; skip like a corrupted line
                    continue
                rows.append({k: obj.get(k) for k in fields})
                if len(rows) >= batch_size:
                    _flush()
        if rows:
            _flush()
        if writer is None:
            # no usable records at all: still leave a readable (empty) file behind
            empty_schema = pa.schema([pa.field(k, pa.string()) for k in fields])
            pq.write_table(empty_schema.empty_table(), dst)
    finally:
        if writer is not None:
            writer.close()

if submissions_parquet.exists() and not remake_datafile:
    print('exists:', submissions_parquet)
else:
    # (re)build the slim submissions parquet from the raw ndjson dump
    submissions_parquet.parent.mkdir(exist_ok=True, parents=True)
    stream_submissions_to_parquet(submissions_file, submissions_parquet)
    print('wrote:', submissions_parquet)


#### load submissions into ram

In [None]:
# load the slim submissions table, drop rows without an id or with a repeated id, and index by id
sub_df = (
    pd.read_parquet(submissions_parquet, columns=['id', 'title', 'selftext', 'link_flair_text', 'created_utc'])
    .dropna(subset=['id'])
    .drop_duplicates(subset=['id'])
    .set_index('id', drop=True)
)
approx_mb = sub_df.memory_usage(deep=True).sum() / 1e6
print('submissions frame:', sub_df.shape, f'~{approx_mb:.1f} MB in RAM')
display(sub_df.head(3))


#### hierarchical comments and submissions parquet

In [None]:
# build a hierarchical table directly from the raw comments + in-memory submissions index (sub_df)

# knobs
include_op_only = False  # True keeps only comments whose is_submitter flag is set (OP comments)

hier_parquet = interim_dir / 'submissions_hierarchical.parquet'

def read_parquet_nested_safe(path: Path) -> pd.DataFrame:
    """
    Read a parquet file that has list (nested) columns into pandas.

    The table is read with pyarrow, each column's chunks are merged into one
    (`combine_chunks`), and only then is it converted to a DataFrame.
    """
    return pq.read_table(path).combine_chunks().to_pandas()

# type-casting helpers to enforce homogeneous list element types
def _to_str(x):
    return None if x is None else str(x)

def _to_int(x):
    if x is None: return None
    try: return int(x)
    except Exception: return None

def _to_float(x):
    if x is None: return None
    try: return float(x)
    except Exception: return None

def _to_bool(x):
    if x is None: return None
    return bool(x)

def build_hierarchical_direct(comments_path: Path,
                              sub_index_df: pd.DataFrame,
                              out_path: Path,
                              include_op_only: bool = False,
                              limit_lines: Optional[int] = None) -> None:
    """
    Stream comments.ndjson and write a hierarchical table: one row per submission, with that
    submission's comments stored as *typed* list columns (no per-submission cap).

    Only submissions that are present in `sub_index_df` AND receive at least one kept comment
    get a row. Submissions without matching comments are not written. All kept comments are
    held in memory until the single write at the end.

    Args:
        comments_path: ndjson comments dump; each record's `link_id` ('t3_<id>') names its submission.
        sub_index_df: submissions indexed by bare id, with 'title', 'selftext',
            'link_flair_text' and 'created_utc' columns.
        out_path: parquet file to write.
        include_op_only: keep only comments whose `is_submitter` flag is truthy.
        limit_lines: stop after this many raw lines (quick dry-runs). None reads the whole
            file and 0 reads nothing.

    The table is written with an explicit Arrow schema that uses LARGE string/list types for
    the text columns to avoid 2GB offset overflow.
    """
    list_cols = ('comment_ids', 'comment_bodies', 'comment_scores', 'comment_created_utcs', 'is_submitters')
    comments_by_sub: Dict[str, Dict[str, list]] = {}
    sub_meta: Dict[str, Dict[str, object]] = {}

    have_sub = sub_index_df.index
    get_row = sub_index_df.loc

    with comments_path.open('r', encoding='utf-8') as f:
        for i, line in enumerate(f, 1):
            # `is not None` so that limit_lines=0 means "read nothing" rather than "no limit"
            if limit_lines is not None and i > limit_lines:
                break
            if not line.strip():
                continue
            try:
                c = json.loads(line)
            except json.JSONDecodeError:
                continue
            if not isinstance(c, dict):
                # valid JSON but not a comment object; skip like a corrupted line
                continue

            sub_id = normalize_submission_id_from_link_id(c.get('link_id', ''))
            if not sub_id or sub_id not in have_sub:
                continue
            if include_op_only and not c.get('is_submitter', False):
                continue

            buf = comments_by_sub.get(sub_id)
            if buf is None:
                # first kept comment for this submission: snapshot its metadata and start its lists
                s = get_row[sub_id]
                sub_meta[sub_id] = {
                    'submission_id': _to_str(sub_id),
                    'submission_title': _to_str(s.get('title')),
                    'submission_selftext': _to_str(s.get('selftext')),
                    'submission_flair': _to_str(s.get('link_flair_text')),
                    'submission_created_utc': _to_float(s.get('created_utc')),
                }
                buf = comments_by_sub[sub_id] = {name: [] for name in list_cols}

            buf['comment_ids'].append(_to_str(c.get('id')))
            buf['comment_bodies'].append(_to_str(c.get('body')))
            buf['comment_scores'].append(_to_int(c.get('score')))
            buf['comment_created_utcs'].append(_to_float(c.get('created_utc')))
            buf['is_submitters'].append(_to_bool(c.get('is_submitter')))

    # use LARGE string/list types to avoid offset overflow
    STR = pa.large_string()
    LSTR_LIST = pa.large_list(STR)

    schema = pa.schema([
        pa.field('submission_id', STR),
        pa.field('submission_title', STR),
        pa.field('submission_selftext', STR),
        pa.field('submission_flair', STR),
        pa.field('submission_created_utc', pa.float64()),
        pa.field('comment_ids', LSTR_LIST),
        pa.field('comment_bodies', LSTR_LIST),
        pa.field('comment_scores', pa.list_(pa.int64())),
        pa.field('comment_created_utcs', pa.list_(pa.float64())),
        pa.field('is_submitters', pa.list_(pa.bool_())),
        pa.field('n_comments', pa.int64()),
    ])

    # materialise each column straight from the accumulators (no pandas, no per-row dicts).
    # Row order is the order in which submissions first received a kept comment.
    sids = list(sub_meta)
    columns: Dict[str, list] = {
        name: [sub_meta[sid][name] for sid in sids]
        for name in ('submission_id', 'submission_title', 'submission_selftext',
                     'submission_flair', 'submission_created_utc')
    }
    for name in list_cols:
        columns[name] = [comments_by_sub[sid][name] for sid in sids]
    columns['n_comments'] = [len(comments_by_sub[sid]['comment_ids']) for sid in sids]

    # the schema is the single source of truth for column order and Arrow types
    table = pa.table([pa.array(columns[fld.name], type=fld.type) for fld in schema], schema=schema)

    pq.write_table(table, out_path)
    print(f'wrote hierarchical table (direct, typed): {out_path} ({table.num_rows} submissions)')

# build-or-load guard (no intermediate "joined_*" step): reuse the cached file unless it is
# missing or a rebuild is forced via remake_datafile
if hier_parquet.exists() and not remake_datafile:
    hier_df = read_parquet_nested_safe(hier_parquet)
    print(f'using existing hierarchical file: {hier_parquet} ({len(hier_df)} submissions)')
    display(hier_df.head(2))
else:
    build_hierarchical_direct(
        comments_path=comments_file,
        sub_index_df=sub_df,   # submissions indexed by 'id' (from the "load submissions into ram" cell)
        out_path=hier_parquet,
        include_op_only=include_op_only,
        limit_lines=None,      # e.g. 2_000 for a quick dry-run
    )
    # preview by reading only two light columns, so the nested comment lists are never loaded
    preview = (
        pq.read_table(hier_parquet, columns=['submission_id', 'n_comments'])
        .slice(0, 2)
        .to_pandas()
    )
    display(preview)


## Preprocessing

In [8]:
def ensure_spacy(nlp_name: str = 'en_core_web_sm'):
    """
    Load the spaCy pipeline `nlp_name` with its 'parser' and 'textcat' components disabled.

    Raises:
        RuntimeError: the model package is not installed; the message carries the install command.
    """
    disabled = ['parser', 'textcat']
    try:
        pipeline = spacy.load(nlp_name, disable=disabled)
    except OSError as err:
        hint = f"spaCy model {nlp_name!r} not installed. run: python -m spacy download {nlp_name}"
        raise RuntimeError(hint) from err
    return pipeline

nlp = ensure_spacy('en_core_web_sm')  # shared pipeline; also bound as lemmatize()'s default below

url_re = re.compile(r'https?://\S+|www\.\S+')
nonword_re = re.compile(r"[^a-zA-Z']+")
multispace_re = re.compile(r'\s+')

def clean_text(text: Optional[str]) -> str:
    """
    Normalise raw post text for bag-of-words modelling.

    Steps: replace URLs with spaces, lowercase, replace every run of characters other than
    ASCII letters and apostrophes with a single space, collapse whitespace, and strip.
    Digits and accented letters are removed as well. Falsy input yields ''.
    """
    if not text:
        return ''
    without_urls = url_re.sub(' ', text)
    letters_only = nonword_re.sub(' ', without_urls.lower())
    return multispace_re.sub(' ', letters_only).strip()

def lemmatize(text: str, nlp=nlp, do_ner: bool = True) -> Tuple[str, List[str]]:
    """
    Run `nlp` over `text` and return (space-joined lemmas, entity tags).

    Stop words, punctuation and whitespace tokens are dropped from the lemmas. Entity tags
    are formatted 'LABEL:text'; when `do_ner` is False the entity list is empty, although
    the full pipeline still runs. Empty text short-circuits to ('', []).

    NOTE(review): preprocess_submission_row passes text that clean_text has already lowercased
    and stripped of punctuation. spaCy NER usually leans on capitalisation, so entity recall
    is likely reduced — confirm whether NER should see the raw text instead.
    """
    if not text:
        return '', []
    doc = nlp(text)
    kept = (tok.lemma_ for tok in doc if not tok.is_stop and not tok.is_punct and not tok.is_space)
    lemma_str = ' '.join(kept)
    if not do_ner:
        return lemma_str, []
    return lemma_str, [f'{ent.label_}:{ent.text}' for ent in doc.ents]

def preprocess_submission_row(row: dict, do_ner: bool = True) -> dict:
    """
    Turn one submission record into a modelling row.

    `row` must provide 'id', 'title', 'selftext', 'link_flair_text' and 'created_utc' via
    `.get`. The raw text is title + ' ' + selftext (missing parts become '') and is then
    cleaned and lemmatised. The result holds id, flair, created_utc, text_raw, text_clean,
    text_lemmas and ents.
    """
    title = str(row.get('title') or '')
    body = str(row.get('selftext') or '')
    raw = f'{title} {body}'.strip()
    cleaned = clean_text(raw)
    lemmas, ents = lemmatize(cleaned, do_ner=do_ner)
    return dict(
        id=row.get('id'),
        flair=row.get('link_flair_text'),
        created_utc=row.get('created_utc'),
        text_raw=raw,
        text_clean=cleaned,
        text_lemmas=lemmas,
        ents=ents,
    )


In [None]:
# why: filter to usable posts once (before any NER) and cache them for fast reloads later.
#      This creates:
#        1) submissions_usable.parquet  (submissions only)
#        2) submissions_with_comments_usable.parquet  (from hierarchical: title/body/comments)

# thresholds
min_body_len = 50       # chars required in submission selftext
min_title_len = 40      # chars required in title
min_comment_len = 150   # total chars across all comments (hierarchical path)

# placeholder bodies treated as "no body" (compared after stripping whitespace)
removed_markers = {'[removed]', '[deleted]', None, ''}

def is_removed(txt):
    """True when `txt` is None, or its stripped string form is a removed/deleted/empty placeholder."""
    return (txt is None) or (str(txt).strip() in removed_markers)

def text_len(s):
    """Length of `s` as text with surrounding whitespace stripped; 0 for None."""
    return 0 if s is None else len(str(s).strip())

def total_comment_chars(bodies):
    """
    Total characters across all comment bodies. None/empty bodies count as 0.

    Parquet list columns come back from pandas as numpy arrays rather than Python lists,
    so arrays (and tuples) are accepted. Otherwise every row read from the hierarchical
    parquet would score 0. Any other value (None, a scalar, a string) yields 0.
    """
    if isinstance(bodies, np.ndarray):
        bodies = bodies.tolist()
    if not isinstance(bodies, (list, tuple)):
        return 0
    return sum(len(str(b or '')) for b in bodies)

# ---------- submissions-only usable ----------
# diagnostics on the full submissions index frame; computed once and reused for the mask below
diag = pd.DataFrame({
    'has_body': ~sub_df['selftext'].apply(is_removed),
    'title_len': sub_df['title'].apply(text_len),
    'body_len': sub_df['selftext'].apply(text_len),
    'flair_none': sub_df['link_flair_text'].isna()
})
print('total submissions:', len(sub_df))
print('with usable body:', int(diag['has_body'].sum()))
print('flair available:', int((~diag["flair_none"]).sum()))
display(diag.describe())

# keep if (body usable & body_len >= min_body_len) OR (title_len >= min_title_len).
# Reuses the diagnostic columns (same index as sub_df) instead of re-running
# is_removed/text_len over every submission.
usable_mask = (
    (diag['has_body'] & (diag['body_len'] >= min_body_len))
    | (diag['title_len'] >= min_title_len)
)
usable_sub_df = sub_df.loc[usable_mask, ['title','selftext','link_flair_text','created_utc']].copy()
print('usable submissions (submissions-only):', len(usable_sub_df))

# cache (submissions only)
usable_submissions_parquet = interim_dir / 'submissions_usable.parquet'
# uncomment if you want them saved:
# usable_sub_df.to_parquet(usable_submissions_parquet, index=True)
# print('saved:', usable_submissions_parquet)


# ---------- hierarchical usable (submissions + comments) ----------
# derived from the hierarchical parquet written by build_hierarchical_direct()
usable_hier_parquet = interim_dir / 'submissions_with_comments_usable.parquet'

if not ('hier_df' in globals() or ('hier_parquet' in globals() and Path(hier_parquet).exists())):
    print('hierarchical parquet not found — skipping submissions+comments usable cache. '
          'Build it first with build_hierarchical_direct().')
else:
    # reuse the frame if an earlier cell already loaded it; otherwise read it from disk
    if 'hier_df' not in globals():
        hier_df = pd.read_parquet(hier_parquet)

    # usable when ANY text source is substantial: the body, the title, or the combined comments
    usable_hier_mask = (
        (
            ~hier_df['submission_selftext'].apply(is_removed)
            & (hier_df['submission_selftext'].apply(text_len) >= min_body_len)
        )
        | (hier_df['submission_title'].apply(text_len) >= min_title_len)
        | (hier_df['comment_bodies'].apply(total_comment_chars) >= min_comment_len)
    )
    usable_hier_df = hier_df.loc[usable_hier_mask].copy()
    print('usable submissions (with comments):', len(usable_hier_df), 'of', len(hier_df))

    # compact projection: the only columns later hierarchical preprocessing needs
    cols_keep = [
        'submission_id', 'submission_title', 'submission_selftext',
        'submission_flair', 'submission_created_utc',
        'comment_bodies', 'is_submitters', 'n_comments',
    ]
    # uncomment if you want them saved:
    # usable_hier_df[cols_keep].to_parquet(usable_hier_parquet, index=False)
    # print('saved:', usable_hier_parquet)

    # quick peek
    display(usable_hier_df[['submission_id', 'n_comments']].head(2))


#### apply preprocessing to submissions and save file

In [None]:
# why: turn usable submissions into a clean, lemmatised, NER-tagged dataset and cache it for fast reloads

# toggles: the cached file below is reused only when BOTH flags are off and the file exists
overwrite_submissions_preprocessed = False     # set True to force overwrite
build_submissions_preprocessed = True          # set False to skip rebuilding if file exists
submissions_preprocessed_parquet = interim_dir / 'submissions_preprocessed.parquet'

def preprocess_submissions_df(df_sub_indexed: pd.DataFrame, do_ner: bool = True) -> pd.DataFrame:
    """
    Run preprocess_submission_row over every submission in an id-indexed frame.

    df_sub_indexed: dataframe indexed by 'id' with columns ['title','selftext','link_flair_text','created_utc']
    returns: dataframe with ['id','flair','created_utc','text_raw','text_clean','text_lemmas','ents']
    """
    wanted = ['title', 'selftext', 'link_flair_text', 'created_utc']
    processed = [
        preprocess_submission_row({'id': sid, **{col: row.get(col) for col in wanted}}, do_ner=do_ner)
        for sid, row in df_sub_indexed[wanted].iterrows()
    ]
    return pd.DataFrame(processed)

if build_submissions_preprocessed or overwrite_submissions_preprocessed or not submissions_preprocessed_parquet.exists():
    # source: the usable submissions frame (indexed by 'id') from the usable-filtering cell
    if 'usable_sub_df' not in globals():
        raise RuntimeError('usable_sub_df not found. Run block 9 first to create/load it.')
    submissions_preprocessed_df = preprocess_submissions_df(usable_sub_df, do_ner=True)
    submissions_preprocessed_df.to_parquet(submissions_preprocessed_parquet, index=False)
    print('saved:', submissions_preprocessed_parquet, f'({len(submissions_preprocessed_df)} rows)')
else:
    submissions_preprocessed_df = pd.read_parquet(submissions_preprocessed_parquet)
    print('loaded:', submissions_preprocessed_parquet, f'({len(submissions_preprocessed_df)} rows)')

display(submissions_preprocessed_df.head(3))


#### apply preprocessing to submissions+comments and save file

In [None]:
# why: turn hierarchical rows (title + selftext + all comments) into cleaned/lemmatised/NER'd text,
#      but do it in chunks to avoid RAM blowups and to allow resuming.

import math
import pyarrow.dataset as ds
import pyarrow.parquet as pq

# chunked-preprocessing knobs
batch_rows = 1_000              # rows per batch / part file; tune based on CPU/RAM
resume = True                   # skip batches that already have a written part file
use_op_only_text = False        # True = only include OP comments in the text

out_merged_parquet = interim_dir / 'submissions_with_all_comments_preprocessed.parquet'
parts_dir = interim_dir / 'submissions_with_all_comments_preprocessed_parts'
parts_dir.mkdir(exist_ok=True, parents=True)

def ensure_list_of_str(x):
    """
    Coerce a hierarchical list cell into a plain `list[str]`.

    Rules:
    - None -> [] and a bare string -> [string].
    - Lists, tuples, numpy arrays and other iterables -> one string per element, with None
      elements becoming '' consistently for every container type.
    - A 0-d numpy array is unwrapped to its scalar first, so a 0-d string array is not split
      into characters.
    - A non-iterable scalar becomes a one-element list.
    """
    if isinstance(x, np.ndarray):
        x = x.tolist()  # 0-d arrays turn into a plain scalar here
    if x is None:
        return []
    if isinstance(x, str):
        return [x]
    try:
        items = list(x)
    except TypeError:
        # a non-iterable scalar snuck in; treat it as a single-comment list
        return [str(x)]
    return ['' if v is None else str(v) for v in items]

def iter_hier_batches(path: Path, columns=None, batch_size: int = 1000):
    """
    Yield the hierarchical parquet as a sequence of small pandas DataFrames.

    Reads record batches of up to `batch_size` rows (restricted to `columns` when given)
    via ParquetFile.iter_batches, converting each batch to pandas as it is consumed.
    """
    parquet_file = pq.ParquetFile(str(path))
    yield from (
        record_batch.to_pandas()
        for record_batch in parquet_file.iter_batches(batch_size=batch_size, columns=columns)
    )


def _hier_op_flags(raw_flags) -> List[bool]:
    """Coerce an `is_submitters` cell (list or numpy array) to Python bools; [] for any other value."""
    if isinstance(raw_flags, np.ndarray):
        raw_flags = raw_flags.tolist()
    if not isinstance(raw_flags, list):
        return []
    return [bool(v) for v in raw_flags]

def _hier_raw_texts(pdf: pd.DataFrame, op_only: bool) -> Tuple[List[str], List[int]]:
    """
    Build one raw document per row and count the comment bodies that went into it.

    Each document is title + selftext + comment bodies, space-joined and stripped.
    With `op_only`, bodies are kept only where the matching `is_submitters` flag is true.
    If the flags cell is missing or not a list/array, every body is kept (best-effort fallback).
    """
    all_raw: List[str] = []
    n_used: List[int] = []
    for r in pdf.itertuples(index=False):
        bodies = ensure_list_of_str(r.comment_bodies)
        if op_only:
            flags = _hier_op_flags(r.is_submitters)
            if flags:
                bodies = [b for b, is_op in zip(bodies, flags) if is_op]
        n_used.append(len(bodies))
        all_raw.append(' '.join([
            str(r.submission_title or ''),
            str(r.submission_selftext or ''),
            ' '.join(bodies),
        ]).strip())
    return all_raw, n_used

def _write_parquet_atomic(table: pa.Table, path: Path) -> None:
    """
    Write `table` to `path` via a temporary sibling file and a rename (atomic on POSIX).

    An interrupted write therefore never leaves a truncated file under the final name,
    which `resume` would otherwise treat as finished.
    """
    tmp_path = path.with_name(path.name + '.tmp')
    try:
        pq.write_table(table, tmp_path)
        tmp_path.replace(path)
    except BaseException:
        tmp_path.unlink(missing_ok=True)
        raise

def preprocess_hierarchical_in_chunks(hier_path: Path,
                                      out_dir: Path,
                                      batch_rows: int = 1000,
                                      op_only: bool = False,
                                      resume: bool = True):
    """
    Clean, lemmatise and NER-tag the hierarchical parquet in row batches, one part file per batch.

    Each batch of `batch_rows` rows becomes `out_dir/part-XXXXX.parquet` with the columns id,
    flair, created_utc, text_raw, text_clean, text_lemmas, ents, n_comments and
    n_comment_bodies_used. Parts are written through a temp file + rename, so a crash
    mid-write leaves no partial part behind.

    Args:
        hier_path: hierarchical parquet (one row per submission, comments as list columns).
        out_dir: directory receiving the part files.
        batch_rows: rows per batch / part file.
        op_only: build text from OP (is_submitter) comments only.
        resume: skip batches whose part file already exists. Parts are matched by batch index
            only, so resuming is meaningful only with the same `hier_path` and `batch_rows`
            as the run that produced the existing parts.
    """
    cols = [
        'submission_id','submission_title','submission_selftext',
        'submission_flair','submission_created_utc',
        'comment_bodies','is_submitters','n_comments'
    ]
    # fixed output schema (built once); large strings for the potentially long text columns
    schema = pa.schema([
        pa.field('id', pa.string()),
        pa.field('flair', pa.string()),
        pa.field('created_utc', pa.float64()),
        pa.field('text_raw', pa.large_string()),
        pa.field('text_clean', pa.large_string()),
        pa.field('text_lemmas', pa.large_string()),
        pa.field('ents', pa.list_(pa.string())),
        pa.field('n_comments', pa.int64()),
        pa.field('n_comment_bodies_used', pa.int64()),
    ])

    for part_idx, pdf in enumerate(iter_hier_batches(hier_path, columns=cols, batch_size=batch_rows)):
        part_path = out_dir / f'part-{part_idx:05d}.parquet'
        if resume and part_path.exists():
            continue

        # build raw text (title + selftext + comments)
        all_raw, n_used = _hier_raw_texts(pdf, op_only)

        # clean first (cheap), then spaCy once via pipe
        cleaned = [clean_text(t) for t in all_raw]
        docs = list(nlp.pipe(cleaned, batch_size=64, n_process=1))
        lemmas = [' '.join(t.lemma_ for t in doc if not (t.is_stop or t.is_punct or t.is_space)) for doc in docs]
        ents = [[f'{e.label_}:{e.text}' for e in doc.ents] for doc in docs]

        out_df = pd.DataFrame({
            'id': pdf['submission_id'].astype(str).values,
            'flair': pdf['submission_flair'].astype('string').where(pdf['submission_flair'].notna(), None),
            'created_utc': pdf['submission_created_utc'].astype(float).values,
            'text_raw': all_raw,
            'text_clean': cleaned,
            'text_lemmas': lemmas,
            'ents': ents,
            'n_comments': pdf['n_comments'].astype('Int64').values,
            'n_comment_bodies_used': n_used
        })

        table = pa.Table.from_pandas(out_df, schema=schema, preserve_index=False)
        _write_parquet_atomic(table, part_path)
        print(f'wrote {part_path} ({len(out_df)} rows)')

    print('done preprocessing hierarchical in chunks.')

def merge_preprocessed_parts(parts_dir: Path, out_path: Path):
    """
    Merge `parts_dir/part-*.parquet` (in name order) into the single parquet file `out_path`.

    Parts are streamed one at a time into a ParquetWriter, so only one part is held in memory.
    The writer uses the first part's schema. A later part whose schema differs is cast to it;
    an incompatible part raises instead of being silently promoted. The writer is always closed.
    Prints a notice and writes nothing when there are no parts.
    """
    parts = sorted(parts_dir.glob('part-*.parquet'))
    if not parts:
        print('no parts to merge.')
        return

    writer = None
    schema = None
    try:
        for part in parts:
            table = pq.read_table(part)
            if writer is None:
                schema = table.schema
                writer = pq.ParquetWriter(out_path, schema)
            elif not table.schema.equals(schema):
                table = table.cast(schema)
            writer.write_table(table)
    finally:
        if writer is not None:
            writer.close()
    print('merged', len(parts), 'parts into', out_path)

# run it
# prefer the filtered "usable" subset when the usable-filtering cell ran AND its file was saved;
# otherwise stream the full hierarchical parquet.
# NOTE: resume matches existing part files by batch index only — clear parts_dir when
# switching hier_source or batch_rows between runs.
hier_source = interim_dir / 'submissions_with_comments_usable.parquet'
if 'usable_hier_df' not in globals() or not hier_source.exists():
    hier_source = hier_parquet

preprocess_hierarchical_in_chunks(
    hier_path=hier_source,
    out_dir=parts_dir,
    batch_rows=batch_rows,
    op_only=use_op_only_text,
    resume=resume,
)

# optional: merge all parts into a single file (can skip if you like dataset-of-parts)
# merge_preprocessed_parts(parts_dir, out_merged_parquet)


## Sample submissions

In [None]:
# def fetch_random_submissions_df_from_usable(usable_index_df: pd.DataFrame, n: int, seed: int = 42) -> pd.DataFrame:
#     rng = np.random.default_rng(seed)
#     ids = usable_index_df.index.values
#     pick = ids if n >= len(ids) else rng.choice(ids, size=n, replace=False)
#     df = usable_index_df.loc[pick, ['title', 'selftext', 'link_flair_text', 'created_utc']].reset_index()
#     df = df.rename(columns={'index': 'id'})
#     return df

# def build_corpus(df_sub: pd.DataFrame, do_ner: bool = True) -> pd.DataFrame:
#     processed = [preprocess_submission_row(row, do_ner=do_ner) for row in df_sub.to_dict('records')]
#     return pd.DataFrame(processed)

# df_sub_sample = fetch_random_submissions_df_from_usable(usable_sub_df, submission_sample_n, seed=random_seed)
# print('submissions sampled for modelling (usable only):', len(df_sub_sample))
# display(df_sub_sample.head(3))

# df_corpus = build_corpus(df_sub_sample, do_ner=True)
# display(df_corpus.head(3)[['id', 'flair', 'text_raw']])

# work on a copy so the topic/category columns added below do not leak into submissions_preprocessed_df
df_corpus = submissions_preprocessed_df.copy()

## LDA topic extraction

In [None]:
# fit LDA on lemmatised bag-of-words
def fit_lda(texts: Iterable[str], max_features: int = 50_000, n_topics: int = 15, max_df: float = 0.5, min_df: int = 10, random_state: int = 42):
    """
    Fit a batch LDA topic model on a count matrix built from `texts`.

    The vocabulary is capped at `max_features` terms, drops sklearn's English stop words,
    and keeps only terms with document frequency between `min_df` (count) and `max_df`
    (fraction). Returns (lda, vectorizer, W doc-topic matrix, H topic-term matrix).
    """
    vectorizer = CountVectorizer(stop_words='english', max_features=max_features, min_df=min_df, max_df=max_df)
    doc_term = vectorizer.fit_transform(texts)
    lda = LDA(learning_method='batch', n_components=n_topics, random_state=random_state)
    doc_topic = lda.fit_transform(doc_term)
    return lda, vectorizer, doc_topic, lda.components_

lda, vect, W, H = fit_lda(texts=df_corpus['text_lemmas'].tolist(), n_topics=15, random_state=random_seed)

def top_words_per_topic(vect, H, topn: int = 15) -> List[List[str]]:
    """
    For each topic row of the topic-term matrix `H`, return its `topn` highest-weight terms.

    Terms are ordered by descending weight, and names come from `vect.get_feature_names_out()`.
    """
    vocab = np.array(vect.get_feature_names_out())
    return [vocab[np.argsort(H[k])[::-1][:topn]].tolist() for k in range(H.shape[0])]

# show the 15 strongest terms of every fitted topic
topics_top_words = top_words_per_topic(vect=vect, H=H, topn=15)
for i, words in enumerate(topics_top_words):
    print(f'topic {i:02d}:', ', '.join(words))

# map topics to 5 categories using keyword overlap with each topic's top words.
# Topic words come from lemmatised text (lemmatize / CountVectorizer), so keywords only match
# in lemma form. Singular lemmas ('sibling', 'hour') are listed next to the plural originals.
# NOTE(review): 'fiancé' cannot survive clean_text (non-ASCII letters are stripped), and 'in'
# is a stop word removed by both spaCy and CountVectorizer, so neither can ever match.
# 'wedding' and 'law' each appear in two categories.
category_keywords = {
    'finances': {
        'money','pay','paid','rent','bill','bills','loan','debt','card','credit','cash','salary','bonus','split','cost','expensive','cheap','wedding','gift','refund','share','finance'
    },
    'relationship': {
        'relationship','boyfriend','girlfriend','partner','date','dating','romantic','love','cheat','ex','fiancé','fiance','fiancee','breakup','trust','jealous'
    },
    'family_conflict': {
        'mom','dad','mother','father','sister','brother','siblings','sibling','family','cousin','aunt','uncle','inlaws','in','law','grandma','grandpa','child','baby','pregnant','wedding','name'
    },
    'work': {
        'work','job','boss','coworker','manager','shift','hours','hour','office','remote','payroll','promotion','hr','fire','fired','leave','paternity','maternity'
    },
    'societal_norms': {
        'culture','religion','religious','tradition','gender','pronoun','politics','law','legal','illegal','discrimination','racist','ableist','ethics','value','norm','boundary','consent'
    }
}

def score_topic_to_category(words: List[str],
                            keywords: Optional[Dict[str, set]] = None,
                            default: str = 'societal_norms') -> Tuple[str, Dict[str, int]]:
    """
    Score a topic's top words against each category's keyword set.

    Returns (best_category, {category: overlap_count}). Ties go to the category listed first
    in `keywords`. When no category shares any word with the topic, `default` is returned
    instead of whichever category happens to come first. This matches the 'societal_norms'
    fallback used when mapping topics to documents.

    Args:
        words: the topic's top words.
        keywords: category -> keyword collection; defaults to the module-level `category_keywords`.
        default: category returned when every overlap count is zero (or there are no categories).
    """
    if keywords is None:
        keywords = category_keywords
    wordset = set(words)
    scores = {cat: len(wordset & set(kw)) for cat, kw in keywords.items()}
    if not scores or max(scores.values()) == 0:
        return default, scores
    return max(scores, key=scores.get), scores

# one row per topic: its index, best-matching category and the per-category overlap counts
topic_category = [
    {'topic': i, 'category': best, **scores}
    for i, (best, scores) in enumerate(map(score_topic_to_category, topics_top_words))
]

df_topic_map = pd.DataFrame(topic_category).sort_values(['category', 'topic'])
display(df_topic_map)


## Assign topic with NER bias

In [None]:
# each document takes its highest-probability topic, then that topic's mapped category
topic_labels = W.argmax(axis=1)
df_corpus['topic'] = topic_labels

topic_to_cat = dict(zip(df_topic_map['topic'], df_topic_map['category']))
df_corpus['category_initial'] = df_corpus['topic'].map(topic_to_cat).fillna('societal_norms')

def ner_bias_category(ents: List[str], current: str) -> str:
    """
    Nudge a document's category using its entity labels ('LABEL:text' strings).

    - NORP or LAW entities move relationship/work/finances documents to 'societal_norms'.
    - A PERSON entity moves a 'societal_norms' document to 'family_conflict'.
    - Otherwise `current` is returned unchanged.
    """
    labels = {ent.split(':', 1)[0] for ent in ents}
    if current in ('relationship', 'work', 'finances') and labels & {'NORP', 'LAW'}:
        return 'societal_norms'
    if current == 'societal_norms' and 'PERSON' in labels:
        return 'family_conflict'
    return current

# final category = topic-derived category adjusted by the document's entities
df_corpus['category'] = [
    ner_bias_category(entities, initial)
    for entities, initial in zip(df_corpus['ents'], df_corpus['category_initial'])
]

category_counts = (
    df_corpus['category']
    .value_counts()
    .rename_axis('category')
    .reset_index(name='count')
)
display(category_counts)


## Examples

In [None]:
def examples_by_category(df: pd.DataFrame, cat: str, k: int = 5) -> pd.DataFrame:
    """Return a copy of the first `k` rows labelled `cat`, limited to id, flair and raw text."""
    in_cat = df['category'] == cat
    return df.loc[in_cat, ['id', 'flair', 'text_raw']].head(k).copy()

# print up to five example posts for every category, in a fixed order
for cat in ('finances', 'relationship', 'family_conflict', 'work', 'societal_norms'):
    print(f'\n=== {cat.upper()} EXAMPLES ===')
    display(examples_by_category(df_corpus, cat, k=5))
