In [1]:
import glob
import pickle
from collections import defaultdict
from datetime import datetime
from pathlib import Path
import os

import numpy as np
import pandas as pd
from gensim.models import Word2Vec
from gensim.models.phrases import Phraser
from scipy.spatial.distance import cosine

# ------------------------------------------------------------------
# Configuration
# ------------------------------------------------------------------
# Maps each period label to the calendar years it covers. `range` excludes its
# upper bound, so the three spans are 2011-2016, 2017-2020 and 2021-2024.
# NOTE(review): "before_2016" includes 2016 itself; the label is also used in
# output file names, so it is left unchanged here.
PERIOD_DEFINITIONS = {
    "before_2016": range(2011, 2017),
    "2017_2020": range(2017, 2021),
    "2021_2024": range(2021, 2025),
}
# Subreddit names; used both as sub-directory names and as batch file prefixes.
SUBREDDITS = ("democrats", "republican")
# Root directory holding one sub-directory of pickled comment batches per subreddit.
BASE_DATA_DIR = Path("processed_comments_2")
# Saved gensim models: the phrase (bigram) detector and the shared Word2Vec
# embedding in which all context words are looked up.
BIGRAM_MODEL_PATH = Path("../../models/bigram/political_bigram_1.phr")
NEUTRAL_MODEL_PATH = Path("../../models/neutral/neutral_2.model")
# Per-subreddit frequency tables read by select_target_words (first CSV column is
# the index, remaining columns are selected by year label).
FREQ_FILE_DEM = Path("../../output/word_frequency/word_freq_yearly/democrats_withstopwords_year.csv")
FREQ_FILE_REP = Path("../../output/word_frequency/word_freq_yearly/republican_withstopwords_year.csv")
# Destination directory for the per-period polarization CSVs; created eagerly.
OUTPUT_DIR = Path("../../output/contextual/neutral")
os.makedirs(OUTPUT_DIR, exist_ok=True)
# Comment keys probed, in this order, when looking for a creation date.
DATE_FIELDS = ("year", "created_year", "created_utc", "created")
# Number of tokens taken on each side of a target word as its context.
CONTEXT_WINDOW = 5
# Small constant added to variances to avoid division by zero / log(0).
EPS = 1e-8

# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
def ensure_path(path: Path) -> Path:
    """Return *path* unchanged if it exists on disk.

    Raises:
        FileNotFoundError: if nothing exists at *path*; the message carries the
            resolved (absolute) location.
    """
    if path.exists():
        return path
    raise FileNotFoundError(f"Missing file: {path.resolve()}")

def load_bigram_model(path: Path) -> Phraser:
    """Load the saved phrase (bigram) model stored at *path*.

    Raises:
        FileNotFoundError: if *path* does not exist.
    """
    checked = ensure_path(path)
    return Phraser.load(str(checked))

def load_neutral_model(path: Path) -> Word2Vec:
    """Load the saved Word2Vec model stored at *path*.

    Raises:
        FileNotFoundError: if *path* does not exist.
    """
    checked = ensure_path(path)
    return Word2Vec.load(str(checked))

def period_years(period: str) -> list[str]:
    """Return the years of *period* as strings; an unknown label yields ``[]``."""
    years = PERIOD_DEFINITIONS.get(period, [])
    return list(map(str, years))

def select_target_words(freq_file_dem: Path, freq_file_rep: Path, period: str) -> set[str]:
    """Pick the words that are frequent enough in both frequency tables for *period*.

    A word qualifies in a table when its value is at least 1 in every year of
    the period and the sum over those years exceeds 10. Year columns absent
    from a table are treated as all-zero. The result is the set of index
    labels qualifying in both tables.

    Returns an empty set when *period* has no years defined.

    Raises:
        FileNotFoundError: if either frequency file is missing.
    """
    years = period_years(period)
    if not years:
        return set()

    def qualifying(freq_file: Path) -> set:
        # First CSV column is the index; keep only this period's columns.
        table = pd.read_csv(ensure_path(freq_file), index_col=0)
        table = table.reindex(columns=years).fillna(0)
        seen_every_year = (table >= 1).all(axis=1)
        enough_in_total = table.sum(axis=1) > 10
        return set(table.index[seen_every_year & enough_in_total])

    return qualifying(freq_file_dem) & qualifying(freq_file_rep)

def comment_year(comment: dict) -> int | None:
    for field in DATE_FIELDS:
        value = comment.get(field)
        if value is None:
            continue
        if isinstance(value, (int, float)):
            return datetime.utcfromtimestamp(value).year
        if isinstance(value, str):
            try:
                return datetime.fromisoformat(value[:19]).year
            except ValueError:
                continue
    return None

def year_to_period(year: int | None) -> str | None:
    """Return the label of the first period containing *year*.

    ``None`` is returned when *year* is ``None`` or falls in no period.
    """
    if year is None:
        return None
    matching = (
        label for label, span in PERIOD_DEFINITIONS.items() if year in span
    )
    return next(matching, None)

def iter_comments(base_dir: Path, subreddit: str):
    """Yield every comment stored in a subreddit's pickled batch files.

    Batch files are those matching ``<base_dir>/<subreddit>/<subreddit>_batch*.pkl``
    and are visited in lexicographic path order. Each file is fully loaded and
    closed before its items are yielded.
    """
    batch_glob = str(base_dir / subreddit / f"{subreddit}_batch*.pkl")
    batch_files = glob.glob(batch_glob)
    batch_files.sort()
    for batch_file in batch_files:
        # NOTE: unpickling executes arbitrary code; only read trusted files.
        with open(batch_file, "rb") as handle:
            batch = pickle.load(handle)
        yield from batch

def extract_context_vectors(
    base_dir: Path,
    subreddit: str,
    bigram_model: Phraser,
    target_words: set[str],
    period: str,
    embedding_model: Word2Vec,
) -> dict[str, list[np.ndarray]]:
    """Collect one context vector per occurrence of each target word.

    Every comment of *subreddit* whose year falls in *period* has its
    ``processed_text`` tokens passed through *bigram_model*. For each
    occurrence of a target word, the embeddings of the in-vocabulary tokens
    within ``CONTEXT_WINDOW`` positions on either side (the occurrence itself
    excluded) are averaged into a single vector. Occurrences with no
    in-vocabulary neighbour are ignored.

    Returns:
        Mapping from target word to the list of its context vectors.
    """
    keyed_vectors = embedding_model.wv
    known = keyed_vectors.key_to_index
    collected: dict[str, list[np.ndarray]] = defaultdict(list)

    for comment in iter_comments(base_dir, subreddit):
        raw_tokens = comment.get("processed_text")
        if not raw_tokens:
            continue
        if year_to_period(comment_year(comment)) != period:
            continue

        phrased = bigram_model[raw_tokens]
        total = len(phrased)
        for pos, word in enumerate(phrased):
            if word not in target_words:
                continue
            lo = max(0, pos - CONTEXT_WINDOW)
            hi = min(total, pos + CONTEXT_WINDOW + 1)
            neighbours = [
                phrased[j]
                for j in range(lo, hi)
                if j != pos and phrased[j] in known
            ]
            if neighbours:
                embeddings = [keyed_vectors[n] for n in neighbours]
                collected[word].append(np.mean(embeddings, axis=0))
    return collected

def summarize_context_vectors(context_vectors: dict[str, list[np.ndarray]]) -> dict[str, dict]:
    """Reduce each word's context vectors to summary statistics.

    For every word with at least one vector the result holds:

    * ``"centroid"``: element-wise mean of the vectors,
    * ``"variance"``: per-dimension variance averaged over dimensions (float),
    * ``"count"``: number of vectors.

    Words with an empty vector list are omitted.
    """
    summary: dict[str, dict] = {}
    for term, samples in context_vectors.items():
        if samples:
            matrix = np.vstack(samples)  # one row per occurrence
            per_dim_variance = np.var(matrix, axis=0)
            summary[term] = dict(
                centroid=np.mean(matrix, axis=0),
                variance=float(np.mean(per_dim_variance)),
                count=len(samples),
            )
    return summary

def polarization_measure(stats_left: dict, stats_right: dict) -> dict[str, dict]:
    """Compare the two sides' context statistics for every shared word.

    For each word present in both inputs the result holds:

    * ``"cosine"``: cosine distance between the two centroids, or 0.0 when
      either centroid is all zeros;
    * ``"kl"``: ``log(v2 / v1) + (v1 + ||c1 - c2||^2) / v2 - 1`` computed from
      the scalar variances (``EPS`` added to guard the divisions), where
      1 = left and 2 = right; the value is not symmetric in its arguments;
    * ``"count_left"`` / ``"count_right"``: occurrence counts on each side.
    """
    shared_words = set(stats_left).intersection(stats_right)
    measures: dict[str, dict] = {}
    for term in shared_words:
        left, right = stats_left[term], stats_right[term]
        mu_l, mu_r = left["centroid"], right["centroid"]
        var_l, var_r = left["variance"], right["variance"]

        # Cosine distance is undefined for a zero vector; report 0.0 instead.
        if np.any(mu_l) and np.any(mu_r):
            distance = cosine(mu_l, mu_r)
        else:
            distance = 0.0

        log_ratio = np.log((var_r + EPS) / (var_l + EPS))
        squared_gap = np.sum((mu_l - mu_r) ** 2)
        divergence = log_ratio + (var_l + squared_gap) / (var_r + EPS) - 1

        measures[term] = {
            "word": term,
            "cosine": float(distance),
            "kl": float(divergence),
            "count_left": left["count"],
            "count_right": right["count"],
        }
    return measures

def save_period_results(period: str, polarization: dict[str, dict]) -> pd.DataFrame:
    """Write one period's polarization rows to CSV and return them.

    Rows are ordered by descending ``"cosine"`` and written, without the
    index, to ``OUTPUT_DIR / "polarization_<period>.csv"``; the directory is
    created if needed. The sorted DataFrame is returned.
    """
    rows = pd.DataFrame(polarization.values())
    ranked = rows.sort_values(by="cosine", ascending=False)
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    destination = OUTPUT_DIR / f"polarization_{period}.csv"
    ranked.to_csv(destination, index=False)
    return ranked

# ------------------------------------------------------------------
# Pipeline
# ------------------------------------------------------------------
# Load the two saved models once; both are reused for every subreddit/period.
bigram_model = load_bigram_model(BIGRAM_MODEL_PATH)
neutral_model = load_neutral_model(NEUTRAL_MODEL_PATH)
# Target vocabulary per period: words frequent enough in BOTH frequency tables.
targets_per_period = {
    period: select_target_words(FREQ_FILE_DEM, FREQ_FILE_REP, period)
    for period in PERIOD_DEFINITIONS
}

# all_stats[subreddit][period] -> {word: {"centroid", "variance", "count"}}
all_stats = {subreddit: {} for subreddit in SUBREDDITS}
for subreddit in SUBREDDITS:
    for period, targets in targets_per_period.items():
        if not targets:
            print(f"No shared targets for {period}; skipping {subreddit}.")
            all_stats[subreddit][period] = {}
            continue
        print(f"[{subreddit}] {period}: {len(targets)} targets")
        # NOTE(review): every call re-reads all of the subreddit's batch files,
        # so each subreddit's data is scanned once per period.
        vectors = extract_context_vectors(
            base_dir=BASE_DATA_DIR,
            subreddit=subreddit,
            bigram_model=bigram_model,
            target_words=targets,
            period=period,
            embedding_model=neutral_model,
        )
        all_stats[subreddit][period] = summarize_context_vectors(vectors)

# Compare the two subreddits period by period and write one CSV per period.
for period in PERIOD_DEFINITIONS:
    stats_dem = all_stats["democrats"].get(period, {})
    stats_rep = all_stats["republican"].get(period, {})
    polarization = polarization_measure(stats_dem, stats_rep)
    # An empty result means no word has context statistics on both sides.
    # NOTE(review): if this happens for every period although targets exist,
    # check whether any comment is actually being assigned to a period.
    if not polarization:
        print(f"No overlapping words for {period}; CSV skipped.")
        continue
    save_period_results(period, polarization)
    print(f"Saved {period} results to {OUTPUT_DIR / f'polarization_{period}.csv'}")

[democrats] before_2016: 4889 targets
[democrats] 2017_2020: 11138 targets
[democrats] 2021_2024: 13324 targets
[republican] before_2016: 4889 targets
[republican] 2017_2020: 11138 targets
[republican] 2021_2024: 13324 targets
No overlapping words for before_2016; CSV skipped.
No overlapping words for 2017_2020; CSV skipped.
No overlapping words for 2021_2024; CSV skipped.
