### 0. Imports and auxiliary functions

In [22]:
import sys, os
sys.path.append("TAALED_1_4_1_Py3") 
sys.path.append("TAACO")            

# --- Core libs ---
import numpy as np
import pandas as pd
from pathlib import Path
import glob, tempfile

# --- HuggingFace (BERT) ---
from transformers import BertTokenizerFast, BertModel
import torch

# --- Progress bar ---
from tqdm import tqdm
tqdm.pandas()

# --- spaCy compatibility (older TAACO/TAALED expect this) ---
# If the installed spaCy has no `spacy.util.set_data_path`, install a no-op
# stand-in that accepts any arguments.
# NOTE(review): presumably the TAASSC/TAALED/TAACO modules reference it when they
# are imported, so keep this shim ABOVE their imports — confirm before reordering.
import spacy
if not hasattr(spacy.util, "set_data_path"):
    spacy.util.set_data_path = lambda *a, **k: None

# --- TAASSC (your local module) ---
# LGR_Analysis and index_list are used by the TAASSC section further down.
import TAASSC_215_dev as tdev
from TAASSC_215_dev import LGR_Analysis, index_list

# --- TAALED setup (needs GUI stubs) ---
import TAALED_1_4_1 as TAALED
class _Root:
    """Headless stand-in for a Tk root: exposes only a no-op update_idletasks()."""
    def update_idletasks(self): pass
# Give TAALED a fake `root`; `system = "L"` pretends we are on Linux/Mac ('L' or 'M')
# so the GUI branches are avoided.
TAALED.root = _Root()
TAALED.system = "L" 

# --- TAACO setup (import + GUI stubs + resource path) ---
# Same headless stubs as TAALED. The working-directory switch TAACO needs is done
# in the TAACO section, around the actual run.
import TAACOnoGUI
from TAACOnoGUI import runTAACO
TAACOnoGUI.root = _Root()
TAACOnoGUI.system = "L"

In [23]:
def read_df(path: "str | os.PathLike[str]") -> pd.DataFrame:
    """Read a DataFrame from CSV or Parquet, chosen by the file extension.

    Args:
        path: File path as a ``str`` or ``pathlib.Path``. A ``.parquet`` suffix
            (case-insensitive) is read with ``pd.read_parquet``; anything else
            is read as CSV.

    Returns:
        The loaded DataFrame.
    """
    # os.fspath accepts both str and Path, so callers can pass either.
    if os.fspath(path).lower().endswith(".parquet"):
        return pd.read_parquet(path)
    return pd.read_csv(path)

def save_csv(df, path):
    """Write ``df`` to ``path`` as CSV without the index, creating parent folders first.

    Prints a one-line confirmation with the destination path.
    """
    parent_dir = Path(path).parent
    parent_dir.mkdir(exist_ok=True, parents=True)
    df.to_csv(path, index=False)
    print(f"Saved: {path}")

---- 

### 0. Define which data we are preprocessing (input data)
- Original data source (includes cleaning + preprocessing steps)
- Rewritten text/data (only preprocessing steps)

##### Setting up the input and output directories
- Input data choice: original/rewritten
- Input directory: where the data is stored (original/rewritten)
- Output directory: where the preprocessed data will be saved

In [5]:
# ==== RUN SETTINGS ====
MODE = "rewrites"  # "original" or "rewrites"

# Original data
ORIGINAL_PATH = "../data/persuade/persuade_2.0_human_scores_demo_id_github.csv"
SAVE_LOW_HIGH = True  # also write splits by economically_disadvantaged (1 = low SES, 0 = high SES)
COLUMNS_KEEP = ['full_text', 'holistic_essay_score', 'race_ethnicity', 'gender', 'grade_level', 'economically_disadvantaged', 'prompt_name']

# Scored essays (high + low) from the model run; their `cv_fold` column is
# merged onto every rewrite file by essay text.
MODEL_DIR = "../model/run_01/"
SCORED_HIGH = "data_high_scored.csv"
SCORED_LOW = "data_low_scored.csv"

# Rewrites data: REWRITES_PATTERN.format(1..6) inside REWRITES_FOLDER
REWRITES_FOLDER = "../data/rewrites/sat/"
REWRITES_PATTERN = "rew_sat_{}.csv"

# Outputs
EMB_DIR = "../embeddings/sat/"
OUT_DIR = "../data/processed/sat/"

# Text column fed to TAALED / TAACO / TAASSC and BERT. The cleaning step renames
# 'full_text' to 'text' for the original data; rewrite files use 'rewritten_text'.
TEXT_COL = "text" if MODE == "original" else "rewritten_text"

# Toggle each feature extractor
RUN_TAALED = True
RUN_TAACO = True
RUN_TAASSC = True


---- 

### 1. Load Data

In [6]:
if MODE == "original":
    print("Loading original dataset")
    df = read_df(ORIGINAL_PATH)
elif MODE == "rewrites":
    print("Loading rewrite datasets")
    # Rewrite files are numbered 1..6 via REWRITES_PATTERN. os.path.join works
    # whether or not REWRITES_FOLDER ends with a slash.
    dfs = [
        read_df(os.path.join(REWRITES_FOLDER, REWRITES_PATTERN.format(i)))
        for i in range(1, 7)
    ]
    print(f"Loaded {len(dfs)} rewrite files")
else:
    raise ValueError("MODE must be 'original' or 'rewrites'")


Loading rewrite datasets
Loaded 6 rewrite files


In [7]:
if MODE == 'rewrites':
    # Scored originals (high + low groups) carry the cv_fold assignment.
    df_high = pd.read_csv(os.path.join(MODEL_DIR, SCORED_HIGH))
    df_low = pd.read_csv(os.path.join(MODEL_DIR, SCORED_LOW))
    df_concat = pd.concat([df_high, df_low], ignore_index=True)

    # One cv_fold per essay text. A duplicated text in the lookup would otherwise
    # duplicate rewrite rows in the left merge.
    fold_lookup = df_concat[['text', 'cv_fold']].drop_duplicates(subset='text')
    for i, df_ in enumerate(dfs):
        dfs[i] = df_.merge(fold_lookup, on='text', how='left')

    # Reorder the originals to follow the row order of the first rewrite file, so
    # original and rewrite rows line up index-for-index. De-duplicating first keeps
    # .loc from returning extra rows (which would break the index assignment).
    df_concat = (
        df_concat.drop_duplicates(subset='text')
        .set_index("text", drop=False)
        .loc[dfs[0]["text"]]
    )
    df_concat.index = dfs[0].index

    # The processed-data folder may not exist yet on a fresh checkout.
    Path(OUT_DIR).mkdir(parents=True, exist_ok=True)
    df_concat.to_csv(os.path.join(OUT_DIR, "original.csv"), index=False)


-----

### 2. Data cleaning (Only for original data source)

In [12]:
if MODE == "original":
    # Keep the modelling columns, drop incomplete rows, use a generic 'text' column,
    # and encode the SES flag as 1 (disadvantaged) / 0 (not).
    df = (
        df[COLUMNS_KEEP]
        .dropna()
        .reset_index(drop=True)
        .rename(columns={'full_text': 'text'})
        .assign(
            economically_disadvantaged=lambda frame: frame['economically_disadvantaged'].map(
                {'Economically disadvantaged': 1, 'Not economically disadvantaged': 0}
            )
        )
    )

-----

### 3. Data Preprocessing Steps (NLP tools)
- TAALED
- TAACO
- TAASSC

### TAALED

In [13]:
# Re-apply TAALED's headless stubs. Both steps are idempotent, and the `_Root`
# stub class from the setup cell is reused rather than redefined here.
if not hasattr(spacy.util, "set_data_path"):
    spacy.util.set_data_path = lambda *a, **k: None

TAALED.root = _Root()  # satisfy TAALED's references to a Tk root
TAALED.system = "L"    # pretend we're on Linux/Mac ('L' or 'M'); avoids GUI branches

In [14]:
# ==== TAALED runner (always 'taaled_' prefix) ====
# Index switches passed straight to TAALED.main(); every index is enabled (1)
# except "indout" (0).
# NOTE(review): "mltd"/"mltd_ma" use a different letter order than "mtld_wrap" —
# confirm they match the keys TAALED.main expects.
TAALED_VAR_DICT = dict(
    aw=1, cw=1, fw=1,
    simple_ttr=1, root_ttr=1, log_ttr=1, maas_ttr=1,
    mattr=1, msttr=1, hdd=1,
    mltd=1, mltd_ma=1, mtld_wrap=1,
    indout=0,
)

def _detect_filename_col(res):
    # Find filename column in TAALED's CSV
    cands = {"filename","file","file_name","textname","doc","document","name"}
    for c in res.columns:
        if c.lower() in cands:
            return c
    return res.columns[0]

def _run_taaled_once(df, *, text_col, var_dict):
    # Run TAALED on one DataFrame and merge taaled_* metrics back
    if text_col not in df.columns:
        raise ValueError(f"TEXT_COL='{text_col}' not found in columns: {list(df.columns)}")
    df = df.copy()

    with tempfile.TemporaryDirectory() as tmp_dir:
        for i, txt in df[text_col].items():
            with open(os.path.join(tmp_dir, f"{i}.txt"), "w", encoding="utf-8") as f:
                f.write(txt if isinstance(txt, str) else "")

        out_csv = os.path.join(tmp_dir, "taaled_out.csv")
        import tkinter.messagebox as mb

        def _no_popup(*args, **kwargs):
            # do nothing instead of showing a dialog
            return None

        mb.showinfo = _no_popup
        
        TAALED.main(tmp_dir, out_csv, var_dict)
        res = pd.read_csv(out_csv)

    fn_col = _detect_filename_col(res)
    res["__idx__"] = res[fn_col].astype(str).str.replace(".txt", "", regex=False)
    df["__idx__"] = df.index.astype(str)

    metric_cols = [c for c in res.columns if c not in (fn_col, "__idx__")]
    res = res.rename(columns={c: f"taaled_{c}" for c in metric_cols})

    merged = df.merge(res.drop(columns=[fn_col]), on="__idx__", how="left")
    return merged.drop(columns="__idx__")

def run_taaled(df_or_list, *, mode, text_col, var_dict):
    """Run TAALED on a single DataFrame ('original') or on each DataFrame of a list ('rewrites').

    Raises:
        ValueError: for any other ``mode``.
    """
    shared = dict(text_col=text_col, var_dict=var_dict)
    if mode == "original":
        return _run_taaled_once(df_or_list, **shared)
    if mode == "rewrites":
        return [_run_taaled_once(frame, **shared) for frame in df_or_list]
    raise ValueError("mode must be 'original' or 'rewrites'")

# ---- EXECUTE ----
if RUN_TAALED:
    if MODE == "original":
        df = run_taaled(df, mode="original", text_col=TEXT_COL, var_dict=TAALED_VAR_DICT)
    elif MODE == "rewrites":
        # Ensure each rewrite df has TEXT_COL; rename here if your files use another name
        for i, d in enumerate(dfs):
            if TEXT_COL not in d.columns and "full_text" in d.columns:
                dfs[i] = d.rename(columns={"full_text": TEXT_COL})
        dfs = run_taaled(dfs, mode="rewrites", text_col=TEXT_COL, var_dict=TAALED_VAR_DICT)
    print("TAALED done.")
else:
    print("RUN_TAALED is False — skipping.")

TAALED done.


------

### TAACO

In [15]:
# TAACO switches: source-text comparisons and the tagged/diagnostic outputs are off;
# every word, overlap and "other" index in between is on. Insertion order matches
# the original literal.
opts = {
    **dict.fromkeys(
        ["sourceKeyOverlap", "sourceLSA", "sourceLDA", "sourceWord2vec"],
        False,
    ),
    **dict.fromkeys(
        [
            "wordsAll", "wordsContent", "wordsFunction",
            "wordsNoun", "wordsPronoun", "wordsArgument",
            "wordsVerb", "wordsAdjective", "wordsAdverb",
            "overlapSentence", "overlapParagraph",
            "overlapAdjacent", "overlapAdjacent2",
            "otherTTR", "otherConnectives", "otherGivenness",
            "overlapLSA", "overlapLDA", "overlapWord2vec",
            "overlapSynonym", "overlapNgrams",
        ],
        True,
    ),
    **dict.fromkeys(["outputTagged", "outputDiagnostic"], False),
}

# ==== TAACO runner (simplified, always 'taaco_' prefix) ====

def _detect_taaco_filename_col(res):
    """Find which column in TAACO's CSV contains the filenames."""
    candidates = {"filename","file","file_name","textname","doc","document","name"}
    for c in res.columns:
        if c.lower() in candidates:
            return c
    for c in res.columns:  # fallback: looks like '*.txt'
        try:
            if res[c].astype(str).str.endswith(".txt").any():
                return c
        except Exception:
            pass
    return res.columns[0]

def _run_taaco_once(df, *, text_col, opts):
    """Run TAACO on a single DataFrame and merge taaco_ metrics back."""
    if text_col not in df.columns:
        raise ValueError(f"TEXT_COL='{text_col}' not found")
    df = df.copy()

    with tempfile.TemporaryDirectory() as tmp_dir:
        for i, txt in df[text_col].items():
            with open(os.path.join(tmp_dir, f"{i}.txt"), "w", encoding="utf-8") as f:
                f.write(txt if isinstance(txt, str) else "")

        out_csv = os.path.join(tmp_dir, "taaco_out.csv")
        runTAACO(tmp_dir, out_csv, opts)
        res = pd.read_csv(out_csv)

    fn_col = _detect_taaco_filename_col(res)
    res["__idx__"] = res[fn_col].astype(str).str.replace(".txt", "", regex=False)
    df["__idx__"] = df.index.astype(str)

    metric_cols = [c for c in res.columns if c not in (fn_col, "__idx__")]
    res = res.rename(columns={c: f"taaco_{c}" for c in metric_cols})

    merged = df.merge(res.drop(columns=[fn_col]), on="__idx__", how="left")
    return merged.drop(columns="__idx__")

def run_taaco(df_or_list, *, mode, text_col, opts):
    """Run TAACO on a single DataFrame ('original') or on each DataFrame of a list ('rewrites').

    Raises:
        ValueError: for any other ``mode``.
    """
    if mode != "original" and mode != "rewrites":
        raise ValueError("mode must be 'original' or 'rewrites'")
    if mode == "rewrites":
        return [_run_taaco_once(frame, text_col=text_col, opts=opts) for frame in df_or_list]
    return _run_taaco_once(df_or_list, text_col=text_col, opts=opts)

# ---- EXECUTE ----
if RUN_TAACO:
    # TAACO is run from inside its own folder (presumably so its relative resource
    # paths resolve — confirm); the original directory is always restored.
    orig_cwd = os.getcwd()
    taaco_dir = "TAACO"
    os.chdir(taaco_dir)

    try:
        if MODE == "original":
            df = run_taaco(df, mode="original", text_col=TEXT_COL, opts=opts)

        elif MODE == "rewrites":
            for i, d in enumerate(dfs):
                if TEXT_COL not in d.columns and "full_text" in d.columns:
                    dfs[i] = d.rename(columns={"full_text": TEXT_COL})
            dfs = run_taaco(dfs, mode="rewrites", text_col=TEXT_COL, opts=opts)

        print("TAACO done.")

    finally:
        # Always restore the original directory even if TAACO crashes
        os.chdir(orig_cwd)
        print(f"→ reverted working dir to: {os.getcwd()}")
else:
    print("RUN_TAACO is False — skipping.")




Loading Spacy
Loading Spacy Model
Starting TAACO...
Loading LSA vector space...
Loading LDA vector space...
Loading word2vec vector space...
outdir: /var/folders/lm/psfj50g95wgczw9z2l6zf32m0000gn/T/tmpd903j3w4/taaco_out.csv
key_out_dir: /var/folders/lm/psfj50g95wgczw9z2l6zf32m0000gn/T/tmpd903j3w4
TAACO is processing 1 of 958 files


  return(np.sum([v for v in A * np.log2(A/B) if not np.isnan(v)]))
  return(np.sum([v for v in A * np.log2(A/B) if not np.isnan(v)]))
  return(np.sum([v for v in A * np.log2(A/B) if not np.isnan(v)]))


TAACO is processing 2 of 958 files
TAACO is processing 3 of 958 files
TAACO is processing 4 of 958 files
TAACO is processing 5 of 958 files
TAACO is processing 6 of 958 files
TAACO is processing 7 of 958 files
TAACO is processing 8 of 958 files
TAACO is processing 9 of 958 files
TAACO is processing 10 of 958 files
TAACO is processing 11 of 958 files
TAACO is processing 12 of 958 files
TAACO is processing 13 of 958 files
TAACO is processing 14 of 958 files
TAACO is processing 15 of 958 files
TAACO is processing 16 of 958 files
TAACO is processing 17 of 958 files
TAACO is processing 18 of 958 files
TAACO is processing 19 of 958 files
TAACO is processing 20 of 958 files
TAACO is processing 21 of 958 files
TAACO is processing 22 of 958 files
TAACO is processing 23 of 958 files
TAACO is processing 24 of 958 files
TAACO is processing 25 of 958 files
TAACO is processing 26 of 958 files
TAACO is processing 27 of 958 files
TAACO is processing 28 of 958 files
TAACO is processing 29 of 958 files


  return(np.sum([v for v in A * np.log2(A/B) if not np.isnan(v)]))
  return(np.sum([v for v in A * np.log2(A/B) if not np.isnan(v)]))
  return(np.sum([v for v in A * np.log2(A/B) if not np.isnan(v)]))


TAACO is processing 2 of 958 files
TAACO is processing 3 of 958 files
TAACO is processing 4 of 958 files
TAACO is processing 5 of 958 files
TAACO is processing 6 of 958 files
TAACO is processing 7 of 958 files
TAACO is processing 8 of 958 files
TAACO is processing 9 of 958 files
TAACO is processing 10 of 958 files
TAACO is processing 11 of 958 files
TAACO is processing 12 of 958 files
TAACO is processing 13 of 958 files
TAACO is processing 14 of 958 files
TAACO is processing 15 of 958 files
TAACO is processing 16 of 958 files
TAACO is processing 17 of 958 files
TAACO is processing 18 of 958 files
TAACO is processing 19 of 958 files
TAACO is processing 20 of 958 files
TAACO is processing 21 of 958 files
TAACO is processing 22 of 958 files
TAACO is processing 23 of 958 files
TAACO is processing 24 of 958 files
TAACO is processing 25 of 958 files
TAACO is processing 26 of 958 files
TAACO is processing 27 of 958 files
TAACO is processing 28 of 958 files
TAACO is processing 29 of 958 files


  return(np.sum([v for v in A * np.log2(A/B) if not np.isnan(v)]))
  return(np.sum([v for v in A * np.log2(A/B) if not np.isnan(v)]))
  return(np.sum([v for v in A * np.log2(A/B) if not np.isnan(v)]))


TAACO is processing 2 of 958 files
TAACO is processing 3 of 958 files
TAACO is processing 4 of 958 files
TAACO is processing 5 of 958 files
TAACO is processing 6 of 958 files
TAACO is processing 7 of 958 files
TAACO is processing 8 of 958 files
TAACO is processing 9 of 958 files
TAACO is processing 10 of 958 files
TAACO is processing 11 of 958 files
TAACO is processing 12 of 958 files
TAACO is processing 13 of 958 files
TAACO is processing 14 of 958 files
TAACO is processing 15 of 958 files
TAACO is processing 16 of 958 files
TAACO is processing 17 of 958 files
TAACO is processing 18 of 958 files
TAACO is processing 19 of 958 files
TAACO is processing 20 of 958 files
TAACO is processing 21 of 958 files
TAACO is processing 22 of 958 files
TAACO is processing 23 of 958 files
TAACO is processing 24 of 958 files
TAACO is processing 25 of 958 files
TAACO is processing 26 of 958 files
TAACO is processing 27 of 958 files
TAACO is processing 28 of 958 files
TAACO is processing 29 of 958 files


  return(np.sum([v for v in A * np.log2(A/B) if not np.isnan(v)]))
  return(np.sum([v for v in A * np.log2(A/B) if not np.isnan(v)]))
  return(np.sum([v for v in A * np.log2(A/B) if not np.isnan(v)]))


TAACO is processing 2 of 958 files
TAACO is processing 3 of 958 files
TAACO is processing 4 of 958 files
TAACO is processing 5 of 958 files
TAACO is processing 6 of 958 files
TAACO is processing 7 of 958 files
TAACO is processing 8 of 958 files
TAACO is processing 9 of 958 files
TAACO is processing 10 of 958 files
TAACO is processing 11 of 958 files
TAACO is processing 12 of 958 files
TAACO is processing 13 of 958 files
TAACO is processing 14 of 958 files
TAACO is processing 15 of 958 files
TAACO is processing 16 of 958 files
TAACO is processing 17 of 958 files
TAACO is processing 18 of 958 files
TAACO is processing 19 of 958 files
TAACO is processing 20 of 958 files
TAACO is processing 21 of 958 files
TAACO is processing 22 of 958 files
TAACO is processing 23 of 958 files
TAACO is processing 24 of 958 files
TAACO is processing 25 of 958 files
TAACO is processing 26 of 958 files
TAACO is processing 27 of 958 files
TAACO is processing 28 of 958 files
TAACO is processing 29 of 958 files


  return(np.sum([v for v in A * np.log2(A/B) if not np.isnan(v)]))
  return(np.sum([v for v in A * np.log2(A/B) if not np.isnan(v)]))
  return(np.sum([v for v in A * np.log2(A/B) if not np.isnan(v)]))


TAACO is processing 2 of 958 files
TAACO is processing 3 of 958 files
TAACO is processing 4 of 958 files
TAACO is processing 5 of 958 files
TAACO is processing 6 of 958 files
TAACO is processing 7 of 958 files
TAACO is processing 8 of 958 files
TAACO is processing 9 of 958 files
TAACO is processing 10 of 958 files
TAACO is processing 11 of 958 files
TAACO is processing 12 of 958 files
TAACO is processing 13 of 958 files
TAACO is processing 14 of 958 files
TAACO is processing 15 of 958 files
TAACO is processing 16 of 958 files
TAACO is processing 17 of 958 files
TAACO is processing 18 of 958 files
TAACO is processing 19 of 958 files
TAACO is processing 20 of 958 files
TAACO is processing 21 of 958 files
TAACO is processing 22 of 958 files
TAACO is processing 23 of 958 files
TAACO is processing 24 of 958 files
TAACO is processing 25 of 958 files
TAACO is processing 26 of 958 files
TAACO is processing 27 of 958 files
TAACO is processing 28 of 958 files
TAACO is processing 29 of 958 files


  return(np.sum([v for v in A * np.log2(A/B) if not np.isnan(v)]))
  return(np.sum([v for v in A * np.log2(A/B) if not np.isnan(v)]))
  return(np.sum([v for v in A * np.log2(A/B) if not np.isnan(v)]))


TAACO is processing 2 of 958 files
TAACO is processing 3 of 958 files
TAACO is processing 4 of 958 files
TAACO is processing 5 of 958 files
TAACO is processing 6 of 958 files
TAACO is processing 7 of 958 files
TAACO is processing 8 of 958 files
TAACO is processing 9 of 958 files
TAACO is processing 10 of 958 files
TAACO is processing 11 of 958 files
TAACO is processing 12 of 958 files
TAACO is processing 13 of 958 files
TAACO is processing 14 of 958 files
TAACO is processing 15 of 958 files
TAACO is processing 16 of 958 files
TAACO is processing 17 of 958 files
TAACO is processing 18 of 958 files
TAACO is processing 19 of 958 files
TAACO is processing 20 of 958 files
TAACO is processing 21 of 958 files
TAACO is processing 22 of 958 files
TAACO is processing 23 of 958 files
TAACO is processing 24 of 958 files
TAACO is processing 25 of 958 files
TAACO is processing 26 of 958 files
TAACO is processing 27 of 958 files
TAACO is processing 28 of 958 files
TAACO is processing 29 of 958 files


----

### TAASSC

In [16]:
# ==== TAASSC runner (simplified, always 'taassc_' prefix) ====

def _run_taassc_once(df, *, text_col, index_list):
    """Run TAASSC (LGR_Analysis) on one DataFrame and merge 'taassc_' metrics back."""
    if text_col not in df.columns:
        raise ValueError(f"TEXT_COL='{text_col}' not found in columns: {list(df.columns)[:12]}...")
    df = df.copy()

    records = []
    for txt in df[text_col].fillna(""):
        try:
            res = LGR_Analysis(txt)  # dict of TAASSC metrics
            row = {m: res.get(m, float('nan')) for m in index_list}
        except Exception:
            row = {m: float('nan') for m in index_list}
        records.append(row)

    metrics_df = pd.DataFrame.from_records(records, index=df.index)
    metrics_df.columns = [f"taassc_{m}" for m in index_list]

    return pd.concat([df, metrics_df], axis=1)

def run_taassc(df_or_list, *, mode, text_col, index_list):
    """Run TAASSC on a single DataFrame ('original') or on each DataFrame of a list ('rewrites').

    Raises:
        ValueError: for any other ``mode``.
    """
    if mode == "original":
        frames, single = [df_or_list], True
    elif mode == "rewrites":
        frames, single = df_or_list, False
    else:
        raise ValueError("mode must be 'original' or 'rewrites'")

    results = [_run_taassc_once(frame, text_col=text_col, index_list=index_list) for frame in frames]
    return results[0] if single else results

# ---- EXECUTE ----
if RUN_TAASSC:
    if MODE == "original":
        df = run_taassc(df, mode="original", text_col=TEXT_COL, index_list=index_list)
    elif MODE == "rewrites":
        # Ensure each rewrite df has TEXT_COL; rename here if needed
        for i, d in enumerate(dfs):
            if TEXT_COL not in d.columns and "full_text" in d.columns:
                dfs[i] = d.rename(columns={"full_text": TEXT_COL})
        dfs = run_taassc(dfs, mode="rewrites", text_col=TEXT_COL, index_list=index_list)
    print("TAASSC done.")
else:
    print("RUN_TAASSC is False — skipping.")

TAASSC done.


-----

### 4. Save Preprocessed Data

In [17]:
# ==== SAVE CLEANED DATASETS (CSV only) ====

if MODE == "original":
    print("Saving original dataset and SES splits...")
    base = Path(OUT_DIR)
    base.mkdir(exist_ok=True, parents=True)

    save_csv(df, base / "original_full.csv")  # full dataset

    if SAVE_LOW_HIGH:
        # economically_disadvantaged: 1 -> low SES, 0 -> high SES
        ses_flag = df["economically_disadvantaged"]
        low = df.loc[ses_flag == 1].reset_index(drop=True)
        high = df.loc[ses_flag == 0].reset_index(drop=True)
        save_csv(low, base / "original_low_SES.csv")
        save_csv(high, base / "original_high_SES.csv")

elif MODE == "rewrites":
    print("Saving all rewrite datasets...")
    base = Path(OUT_DIR)
    base.mkdir(exist_ok=True, parents=True)

    for number, frame in enumerate(dfs, start=1):
        save_csv(frame, base / f"rewrite_{number}.csv")

    print(f"Saved {len(dfs)} rewritten datasets to {base}")

print("All saves complete.")


Saving all rewrite datasets...
Saved: ../data/processed/sat/rewrite_1.csv
Saved: ../data/processed/sat/rewrite_2.csv
Saved: ../data/processed/sat/rewrite_3.csv
Saved: ../data/processed/sat/rewrite_4.csv
Saved: ../data/processed/sat/rewrite_5.csv
Saved: ../data/processed/sat/rewrite_6.csv
Saved 6 rewritten datasets to ../data/processed/sat
All saves complete.


----

### Embeddings

In [18]:
# Prefer Apple MPS, then CUDA, otherwise fall back to CPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased').eval().to(device)  # inference mode

def get_embeddings(text: str) -> np.ndarray:
    """Embed one text with BERT and return its [CLS] vector as a 1-D numpy array.

    The text is truncated to 512 tokens. The forward pass runs without gradients on
    the module-level ``device``, and the vector is copied back to the CPU.
    """
    encoded = tokenizer(
        text,
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=512,
    )
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        hidden = model(**encoded).last_hidden_state  # [batch=1, seq, dim]

    # Position 0 of the single sequence is the [CLS] token.
    return hidden[0, 0, :].cpu().numpy()

# ==== EMBEDDINGS SAVE (BERT [CLS]) ====
def _embed_df(df, text_col):
    """Stack the [CLS] embedding of every row of ``df[text_col]`` into a 2-D array.

    Missing texts are embedded as "". A progress bar is shown when ``tqdm`` is
    available in this module's globals. An empty frame yields a
    ``(0, model.config.hidden_size)`` float array.
    """
    texts = df[text_col].fillna("").tolist()
    if "tqdm" in globals():
        texts_iter = tqdm(texts, desc="Embedding", total=len(texts))
    else:
        texts_iter = texts

    vectors = [get_embeddings(t) for t in texts_iter]
    if not vectors:
        return np.zeros((0, model.config.hidden_size), dtype=float)
    return np.vstack(vectors)

In [20]:
EMB_DIR = Path(EMB_DIR)
# np.save does not create missing folders, so create the target directory up front.
EMB_DIR.mkdir(parents=True, exist_ok=True)

if MODE == "original":
    X_full = _embed_df(df, TEXT_COL)
    np.save(EMB_DIR / "embeddings_original_full.npy", X_full)

    if SAVE_LOW_HIGH:
        # economically_disadvantaged: 1 -> low SES, 0 -> high SES
        low = df[df["economically_disadvantaged"] == 1].reset_index(drop=True)
        high = df[df["economically_disadvantaged"] == 0].reset_index(drop=True)

        X_low = _embed_df(low, TEXT_COL) if len(low) else np.zeros((0, X_full.shape[1]), dtype=float)
        X_high = _embed_df(high, TEXT_COL) if len(high) else np.zeros((0, X_full.shape[1]), dtype=float)

        np.save(EMB_DIR / "embeddings_original_low.npy", X_low)
        np.save(EMB_DIR / "embeddings_original_high.npy", X_high)

    # Report for both cases, not only when the SES splits are written.
    print(f"Saved embeddings to {EMB_DIR}")

elif MODE == "rewrites":

    items = [(f"rewrite_{i+1}", d) for i, d in enumerate(dfs)]

    for name, d in items:
        # Ensure expected text column
        if TEXT_COL not in d.columns and "full_text" in d.columns:
            d = d.rename(columns={"full_text": TEXT_COL})
        if TEXT_COL not in d.columns:
            raise KeyError(f"{name}: expected '{TEXT_COL}' column. Got: {list(d.columns)[:12]}...")

        X = _embed_df(d, TEXT_COL)
        np.save(EMB_DIR / f"embeddings_{name}.npy", X)
        print(f"{name}: saved {X.shape} to {EMB_DIR / f'embeddings_{name}.npy'}")

    print(f"Saved {len(items)} rewrite embeddings to {EMB_DIR}")

Embedding: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 958/958 [00:26<00:00, 36.20it/s]


rewrite_1: saved (958, 768) to ../embeddings/sat/embeddings_rewrite_1.npy


Embedding: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 958/958 [00:29<00:00, 32.82it/s]


rewrite_2: saved (958, 768) to ../embeddings/sat/embeddings_rewrite_2.npy


Embedding: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 958/958 [00:28<00:00, 33.32it/s]


rewrite_3: saved (958, 768) to ../embeddings/sat/embeddings_rewrite_3.npy


Embedding: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 958/958 [00:28<00:00, 33.13it/s]


rewrite_4: saved (958, 768) to ../embeddings/sat/embeddings_rewrite_4.npy


Embedding: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 958/958 [00:30<00:00, 31.75it/s]


rewrite_5: saved (958, 768) to ../embeddings/sat/embeddings_rewrite_5.npy


Embedding: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 958/958 [00:31<00:00, 30.64it/s]

rewrite_6: saved (958, 768) to ../embeddings/sat/embeddings_rewrite_6.npy
Saved 6 rewrite embeddings to ../embeddings/sat





### Embedding for the original df matching the rewrites

In [21]:
# Embed the original essays, row-aligned with the rewrites. df_concat is built in
# the cv_fold merge cell and only exists in "rewrites" mode. A separate column
# name is used so the global TEXT_COL (used by all earlier cells) is not overwritten.
ORIGINAL_TEXT_COL = "text"  # column holding the original essay text in df_concat

if MODE == "rewrites":
    emb_dir = Path(EMB_DIR)
    emb_dir.mkdir(parents=True, exist_ok=True)

    # Compute embeddings
    X = _embed_df(df_concat, ORIGINAL_TEXT_COL)

    # Save as .npy (report the full destination path)
    out_path = emb_dir / "embeddings_original.npy"
    np.save(out_path, X)

    print(f"Saved embeddings to {out_path} with shape {X.shape}")
else:
    print("Original-vs-rewrite embeddings only apply when MODE == 'rewrites' - skipping.")

Embedding: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 958/958 [00:28<00:00, 33.36it/s]

Saved embeddings to embeddings_original.npy with shape (958, 768)



