"""
Préparation des données d'évaluation (version alignée sur test.ipynb)

- Calcule hors-ligne les mesures de complexité (lexicales, syntaxiques, discursives)
  * Lexical:  MTLD, LD, LS   (stanza + COW5000 via textcomplexity + WordNet)
  * Syntaxe:  MDD, CS        (stanza: depparse + constituency)
  * Discours: LC, CoH        (spaCy: lemmas + word vectors)

- Écarte toute ligne si une mesure est None/NaN (robustesse exigée)
- Marque et supprime les lignes où Simple domine déjà Complex (>= toutes métriques non-NaN
  + progrès strict dans chaque famille lex/syn/disc), en conservant leurs IDs originaux
- Conserve les IDs d’origine même après filtrage
- Exporte: __augmented.csv/.parquet, __filtered.csv/.parquet, __removed_ids.json, __report.json

Usage:
  python prepare_from_notebook.py --input data_sampled/OSE_adv_ele.csv --sep '\t' --outdir prepared

Tu peux aussi lui donner un autre dataset listé dans ton dict si besoin.
"""

In [1]:
# Cell 1 — Imports & dataset mapping
from __future__ import annotations

# Standard
import json, math
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple

# Third-party
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import nltk
from nltk.corpus import wordnet as wn

import stanza
import spacy

# textcomplexity for COW top-5000 (en.json)
import importlib.resources as pkg_resources
import textcomplexity

# sklearn (NOTE: coherence below is computed with NumPy; these two imports are not referenced in the visible code)
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

# ---- Datasets (.txt TSV)
# Maps a short dataset name to the path of its tab-separated file.
# load_dataset() looks names up in this dict and raises ValueError for
# names that are not listed here.
datasets = {
    'ose_adv_ele': 'data_sampled/OSE_adv_ele.txt',
    'ose_adv_int': 'data_sampled/OSE_adv_int.txt',
    'swipe':       'data_sampled/swipe.txt',
    'vikidia':     'data_sampled/vikidia.txt',
}

def load_data(path: str) -> pd.DataFrame:
    """Read the tab-separated file at *path* and return it as a DataFrame."""
    frame = pd.read_csv(path, sep='\t')
    return frame

def load_dataset(name: str) -> pd.DataFrame:
    """Load one of the datasets registered in the module-level ``datasets`` dict.

    Raises:
        ValueError: if *name* is not a key of ``datasets``.
    """
    if name in datasets:
        return load_data(datasets[name])
    raise ValueError(f"Dataset {name} not found")


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Cell 2 — Resource setup
def _ensure_resources():
    """Check (and where possible fetch) the external resources used by this module.

    Steps, in order:
      1. Ask stanza to download the English models for the processors used
         by the pipeline. Any failure is reported as a ``RuntimeWarning``
         instead of being silently ignored; execution continues.
      2. Probe WordNet; on ``LookupError`` download "wordnet" and "omw-1.4"
         through nltk.
      3. Try to load the spaCy model ``en_core_web_md``.

    Raises:
        RuntimeError: if ``en_core_web_md`` cannot be loaded (the original
            ``OSError`` is chained as the cause).
    """
    import warnings

    try:
        stanza.download("en", processors="tokenize,pos,lemma,depparse,constituency", verbose=False)
    except Exception as exc:
        # Deliberately best-effort: do not abort here, but make the failure
        # visible so a later pipeline-construction error is easier to diagnose.
        warnings.warn(
            f"stanza.download failed ({exc!r}); continuing without a fresh download.",
            RuntimeWarning,
            stacklevel=2,
        )

    try:
        _ = wn.synsets("dog")
    except LookupError:
        nltk.download("wordnet")
        nltk.download("omw-1.4")

    try:
        spacy.load("en_core_web_md")
    except OSError as exc:
        raise RuntimeError(
            "spaCy model 'en_core_web_md' is required. "
            "Install with: python -m spacy download en_core_web_md"
        ) from exc

# Module-level side effect: resources are checked/downloaded as soon as this
# code runs.
_ensure_resources()

# Shared spaCy pipeline; compute_all_measures_df passes it to the discourse
# metrics (lexical cohesion and coherence).
spacy_nlp = spacy.load("en_core_web_md")
# NOTE(review): a "sentencizer" component is appended to the loaded model's
# pipeline. How it interacts with sentence boundaries the model may already
# set is not visible from this file — confirm this is intended.
spacy_nlp.add_pipe("sentencizer")

# Cache of already-built stanza pipelines, keyed by (lang, use_gpu) so that a
# GPU request is never answered with a previously built CPU pipeline (or the
# other way round).
_STANZA_PIPELINES: Dict[Tuple[str, bool], stanza.Pipeline] = {}
def get_stanza_pipeline(lang: str = "en", use_gpu: bool = False) -> stanza.Pipeline:
    """Return a cached stanza pipeline for *lang*, building it on first use.

    The pipeline is created with the processors
    ``tokenize,pos,lemma,depparse,constituency`` and sentence splitting
    enabled (``tokenize_no_ssplit=False``).

    Args:
        lang: stanza language code.
        use_gpu: forwarded to ``stanza.Pipeline``; part of the cache key.

    Returns:
        The pipeline instance shared by every caller that uses the same
        ``(lang, use_gpu)`` combination.
    """
    key = (lang, bool(use_gpu))
    pipeline = _STANZA_PIPELINES.get(key)
    if pipeline is None:
        pipeline = stanza.Pipeline(
            lang=lang,
            processors="tokenize,pos,lemma,depparse,constituency",
            tokenize_no_ssplit=False,
            use_gpu=use_gpu,
            verbose=False,
        )
        _STANZA_PIPELINES[key] = pipeline
    return pipeline


In [3]:
# Cell 3 — Constants & COW loader
# POS tags counted as content words by lexical_measures_from_doc.
CONTENT_UPOS = {"NOUN", "PROPN", "VERB", "ADJ", "ADV"}   # stanza POS
# POS tags accepted by is_content_token_spacy (PROPN is not in this set).
CONTENT_POS_SPACY = {"NOUN", "VERB", "ADJ", "ADV"}       # spaCy POS
# Names (and column order) of the seven measures:
# lexical (MTLD, LD, LS), syntactic (MDD, CS), discourse (LC, CoH).
METRICS = ["MTLD","LD","LS","MDD","CS","LC","CoH"]

@lru_cache(maxsize=1)
def load_cow_top5000_en() -> set:
    """Return the lower-cased words of the ``most_common`` list in textcomplexity's ``en.json``.

    The file is read from the installed ``textcomplexity`` package; the
    result is cached, so the JSON is parsed only once per process.
    """
    resource = pkg_resources.files(textcomplexity).joinpath("en.json")
    with resource.open("r", encoding="utf-8") as handle:
        lang_def = json.load(handle)

    # Each entry is a [word, xpos] pair; only the word is kept.
    words = set()
    for word, _xpos in lang_def["most_common"]:
        words.add(word.lower())
    return words


In [4]:
# Cell 4 — Lexical metrics (stanza)
def _compute_mtld(tokens: Iterable[str], ttr_threshold: float = 0.72) -> Optional[float]:
    """Compute a single-pass (forward only) MTLD value for *tokens*.

    Falsy tokens are dropped first. The sequence is then walked once: every
    time the running type/token ratio of the current segment falls strictly
    below *ttr_threshold*, one full factor is counted and the segment is
    reset. A non-empty trailing segment contributes a partial factor of
    ``(1 - ttr) / (1 - ttr_threshold)`` clamped to ``[0, 1]``.

    Returns:
        ``len(tokens) / factor_count``, or ``None`` when there are no tokens
        or no (partial) factor was accumulated.
    """
    sequence = list(filter(None, tokens))
    if len(sequence) == 0:
        return None

    factors = 0.0
    seen = set()
    run_length = 0
    for token in sequence:
        run_length += 1
        seen.add(token)
        if len(seen) / run_length < ttr_threshold:
            # Segment exhausted: count it and start a fresh one.
            factors += 1.0
            seen = set()
            run_length = 0

    if run_length > 0:
        remainder_ttr = len(seen) / run_length
        if remainder_ttr < 1.0:
            partial = (1.0 - remainder_ttr) / (1.0 - ttr_threshold)
            factors += max(0.0, min(1.0, partial))

    return len(sequence) / factors if factors != 0 else None

def _compute_lexical_density(total_tokens: int, content_tokens: int) -> Optional[float]:
    """Return ``content_tokens / total_tokens``, or ``None`` when *total_tokens* is 0."""
    return content_tokens / total_tokens if total_tokens != 0 else None

def _compute_lexical_sophistication_cow(content_forms: Iterable[str], cow_top5000: set) -> Optional[float]:
    """Return the share of non-empty *content_forms* that are absent from *cow_top5000*.

    Returns ``None`` when no non-empty form is supplied.
    """
    kept = [form for form in content_forms if form]
    if not kept:
        return None
    rare = [form for form in kept if form not in cow_top5000]
    return len(rare) / len(kept)

def lexical_measures_from_doc(doc: stanza.Document) -> Dict[str, Optional[float]]:
    """Compute MTLD, LD and LS from a parsed stanza document.

    Words tagged ``PUNCT`` and words whose lower-cased lemma (falling back to
    the surface text) is empty are ignored. MTLD runs over the lemmas; LD is
    the share of remaining words whose UPOS is in ``CONTENT_UPOS``; LS is the
    share of those content words whose lower-cased surface form is not in
    the COW top-5000 set.

    Returns:
        Dict with keys "MTLD", "LD", "LS"; a value is ``None`` when the
        corresponding measure is undefined.
    """
    cow = load_cow_top5000_en()

    lemmas: List[str] = []          # every counted word, as lower-cased lemma
    surface_forms: List[str] = []   # content words only, as lower-cased text
    for sent in doc.sentences:
        for word in sent.words:
            if word.upos == "PUNCT":
                continue
            lemma = (word.lemma or word.text or "").lower()
            if not lemma:
                continue
            lemmas.append(lemma)
            if word.upos in CONTENT_UPOS:
                surface_forms.append((word.text or "").lower())

    mtld = _compute_mtld(lemmas) if lemmas else None
    density = _compute_lexical_density(len(lemmas), len(surface_forms))
    sophistication = _compute_lexical_sophistication_cow(surface_forms, cow)
    return {"MTLD": mtld, "LD": density, "LS": sophistication}

def lexical_measures_from_text(text: str, lang: str = "en") -> Dict[str, Optional[float]]:
    """Parse *text* with the cached stanza pipeline and return its lexical measures.

    Non-string or blank input yields a dict with all three measures set to
    ``None`` without running the pipeline.
    """
    has_content = isinstance(text, str) and bool(text.strip())
    if has_content:
        parsed = get_stanza_pipeline(lang)(text)
        return lexical_measures_from_doc(parsed)
    return {"MTLD": None, "LD": None, "LS": None}


In [5]:
# Cell 5 — Syntactic metrics (stanza)
def mdd_from_doc(doc: stanza.Document) -> Optional[float]:
    """Return the mean dependency distance of *doc*, averaged over sentences.

    Within a sentence the distance of a word is ``abs(word.id - word.head)``;
    words without a head or attached to the root (head 0) are skipped.
    Sentences that contribute no distance are left out of the average.

    Returns:
        The mean of the per-sentence means, or ``None`` if no sentence
        produced a distance.
    """
    per_sentence = []
    for sent in doc.sentences:
        spans = [
            abs(word.id - word.head)
            for word in sent.words
            if word.head and word.head != 0
        ]
        if spans:
            per_sentence.append(sum(spans) / len(spans))
    return sum(per_sentence) / len(per_sentence) if per_sentence else None

def _count_clauses_in_tree(tree) -> int:
    """Count the clause-level nodes (label starting with "S") in a constituency tree.

    Only phrasal nodes are eligible: a node is counted when its label starts
    with "S" AND at least one of its children has children of its own.
    This excludes
      * leaves, whose label is the word itself (e.g. "She", "So"), and
      * preterminals, whose label is a POS tag (e.g. "SYM"),
    both of which would otherwise be mistaken for clauses.

    Args:
        tree: a node exposing ``label`` and ``children`` (such as a stanza
            constituency tree), or ``None``.

    Returns:
        Number of clause nodes in the subtree rooted at *tree*; 0 for ``None``.
    """
    if tree is None:
        return 0

    children = getattr(tree, "children", None) or ()
    # Phrasal = dominates something deeper than a single word.
    is_phrasal = any(getattr(child, "children", None) for child in children)
    label = getattr(tree, "label", "") or ""

    count = 1 if is_phrasal and isinstance(label, str) and label.startswith("S") else 0
    for child in children:
        if hasattr(child, "label"):
            count += _count_clauses_in_tree(child)
    return count

def cs_from_doc(doc: stanza.Document) -> Optional[float]:
    """Return the average clause count per sentence of *doc*.

    Sentences without a ``constituency`` tree are skipped.

    Returns:
        Mean of the per-sentence clause counts, or ``None`` when no sentence
        has a tree.
    """
    trees = (getattr(sent, "constituency", None) for sent in doc.sentences)
    per_sentence = [_count_clauses_in_tree(tree) for tree in trees if tree is not None]
    if len(per_sentence) == 0:
        return None
    return sum(per_sentence) / len(per_sentence)

def syntactic_measures_from_doc(doc: stanza.Document) -> Dict[str, Optional[float]]:
    """Bundle the two syntactic measures of *doc* under the keys "MDD" and "CS"."""
    mean_dependency_distance = mdd_from_doc(doc)
    clauses_per_sentence = cs_from_doc(doc)
    return {"MDD": mean_dependency_distance, "CS": clauses_per_sentence}

def syntactic_measures_from_text(text: str, lang: str = "en") -> Dict[str, Optional[float]]:
    """Parse *text* with the cached stanza pipeline and return its syntactic measures.

    Non-string or blank input yields ``{"MDD": None, "CS": None}`` without
    running the pipeline.
    """
    has_content = isinstance(text, str) and bool(text.strip())
    if has_content:
        parsed = get_stanza_pipeline(lang)(text)
        return syntactic_measures_from_doc(parsed)
    return {"MDD": None, "CS": None}


In [6]:
# Cell 6 — Discourse metrics (spaCy + WordNet)
def is_content_token_spacy(tok) -> bool:
    """Tell whether *tok* is an alphabetic, non-stop-word token whose POS is in ``CONTENT_POS_SPACY``."""
    alphabetic = tok.is_alpha
    if not alphabetic:
        return alphabetic
    if tok.is_stop:
        return False
    return tok.pos_ in CONTENT_POS_SPACY

@lru_cache(maxsize=100000)
def get_related_lemmas(lemma: str) -> set:
    """Collect the WordNet lemma names related to *lemma*.

    For every synset of the lower-cased lemma the result gathers: the
    synset's own lemma names and their antonyms, plus the lemma names of its
    hypernyms, hyponyms, part/member/substance meronyms and sibling synsets
    (other hyponyms of its hypernyms). Names are lower-cased with
    underscores replaced by spaces; the query lemma itself is removed.

    Results are memoised with ``lru_cache``; the returned set is the cached
    object, so callers should not mutate it.
    """
    lemma = lemma.lower()

    def clean(name):
        return name.lower().replace("_", " ")

    related = set()
    for synset in wn.synsets(lemma):
        for own in synset.lemmas():
            related.add(clean(own.name()))
            related.update(clean(opposite.name()) for opposite in own.antonyms())

        parents = synset.hypernyms()
        neighbours = [
            *parents,
            *synset.hyponyms(),
            *synset.part_meronyms(),
            *synset.member_meronyms(),
            *synset.substance_meronyms(),
        ]
        # Siblings: the other children of each hypernym.
        for parent in parents:
            neighbours.extend(sib for sib in parent.hyponyms() if not sib == synset)

        for other in neighbours:
            related.update(clean(entry.name()) for entry in other.lemmas())

    related.discard(lemma)
    return related

def lexical_cohesion_single(text: str, nlp) -> float:
    """Return a lexical-cohesion score for *text*.

    The text is parsed with *nlp*. Each sentence is reduced to the set of
    lower-cased lemmas of its content tokens; sentences with an empty set
    are dropped. For every ordered pair of remaining sentences (earlier,
    later) the score adds the number of shared lemmas plus, for each lemma
    of the earlier sentence, the number of its WordNet-related lemmas found
    in the later one. The total is divided by the number of alphabetic
    tokens in the document.

    Returns 0.0 for non-string/blank text, for documents without alphabetic
    tokens, and when fewer than two sentences (or fewer than two non-empty
    lemma sets) are available.
    """
    if not (isinstance(text, str) and text.strip()):
        return 0.0

    doc = nlp(text)
    alpha_count = len([tok for tok in doc if tok.is_alpha])
    if alpha_count == 0:
        return 0.0

    sentences = list(doc.sents)
    if len(sentences) < 2:
        return 0.0

    bags: List[set] = []
    for sent in sentences:
        bag = {tok.lemma_.lower() for tok in sent if is_content_token_spacy(tok)}
        if bag:
            bags.append(bag)
    if len(bags) < 2:
        return 0.0

    links = 0
    for position, earlier in enumerate(bags[:-1]):
        for later in bags[position + 1:]:
            links += len(earlier & later)
            links += sum(len(get_related_lemmas(lemma) & later) for lemma in earlier)
    return float(links) / float(alpha_count)

def sentence_vector(sent, D: int) -> np.ndarray:
    """Average the vectors of the tokens in *sent* that have one and are neither punctuation nor whitespace.

    Returns a float32 zero vector of length *D* when no token qualifies.
    """
    usable = [
        tok.vector
        for tok in sent
        if tok.has_vector and not (tok.is_punct or tok.is_space)
    ]
    return np.mean(usable, axis=0) if usable else np.zeros(D, dtype="float32")

def coherence_single(text: str, nlp) -> float:
    """Return the mean cosine similarity between consecutive sentence vectors of *text*.

    Sentence vectors come from ``sentence_vector``; pairs in which either
    vector has zero norm are skipped.

    Returns 0.0 for non-string/blank text, for texts with fewer than two
    sentences, and when no pair could be scored.

    Raises:
        RuntimeError: if the vocabulary of *nlp* reports a vector length of 0.
    """
    if not isinstance(text, str) or not text.strip():
        return 0.0

    dim = nlp.vocab.vectors_length
    if dim == 0:
        raise RuntimeError("spaCy model must have word vectors (use 'en_core_web_md').")

    sentences = list(nlp(text).sents)
    if len(sentences) < 2:
        return 0.0

    vectors = [sentence_vector(sentence, dim) for sentence in sentences]
    similarities: List[float] = []
    for left, right in zip(vectors, vectors[1:]):
        scale = np.linalg.norm(left) * np.linalg.norm(right)
        if scale != 0.0:
            similarities.append(float(np.dot(left, right) / scale))

    if not similarities:
        return 0.0
    return float(np.mean(similarities))

def compute_discourse_measures_series(texts: pd.Series, nlp) -> Tuple[np.ndarray, np.ndarray]:
    """Score every text in *texts* for lexical cohesion and coherence.

    Returns:
        ``(lc, coh)``: two float32 arrays in the iteration order of *texts*,
        holding the results of ``lexical_cohesion_single`` and
        ``coherence_single`` respectively.
    """
    cohesion_scores = [lexical_cohesion_single(text, nlp) for text in texts]
    coherence_scores = [coherence_single(text, nlp) for text in texts]
    return (
        np.asarray(cohesion_scores, dtype="float32"),
        np.asarray(coherence_scores, dtype="float32"),
    )


In [7]:
# Cell 7 — All measures DF(column) => 7 columns
def compute_all_measures_df(df: pd.DataFrame, column: str, lang: str = "en") -> pd.DataFrame:
    """Compute the seven complexity measures for every text in ``df[column]``.

    Lexical and syntactic measures come from one stanza parse per row;
    discourse measures (LC, CoH) come from the shared spaCy pipeline.

    Failure handling:
      * A row whose stanza processing fails keeps NaN in the five stanza
        columns and a ``RuntimeWarning`` naming the row is emitted.
      * If the discourse computation fails, LC and CoH are left as NaN for
        all rows (and a warning is emitted) rather than being filled with
        0.0: a 0.0 would look like a valid score and slip through any
        NaN-based filtering applied afterwards.

    Args:
        df: input frame.
        column: name of the text column; missing values are treated as "".
        lang: stanza language code used to fetch the pipeline.

    Returns:
        A float64 DataFrame indexed like *df* with one column per entry of
        ``METRICS``; undefined measures are NaN.
    """
    import warnings

    out = pd.DataFrame(index=df.index, columns=METRICS, dtype="float64")
    pipe = get_stanza_pipeline(lang)
    texts = df[column].fillna("").astype(str)
    stanza_cols = ["MTLD", "LD", "LS", "MDD", "CS"]

    # Lexical + syntactic (stanza)
    for idx, text in tqdm(texts.items(), desc=f"Stanza {column}", total=len(texts)):
        if not text.strip():
            continue  # blank text: the row keeps its initial NaN values
        try:
            doc = pipe(text)
            measures = {**lexical_measures_from_doc(doc), **syntactic_measures_from_doc(doc)}
            out.loc[idx, stanza_cols] = [
                np.nan if measures[name] is None else measures[name]
                for name in stanza_cols
            ]
        except Exception as exc:
            warnings.warn(
                f"stanza measures failed for row {idx!r} of column {column!r}: {exc!r}",
                RuntimeWarning,
                stacklevel=2,
            )
            out.loc[idx, stanza_cols] = np.nan

    # Discourse (spaCy)
    try:
        lc_vec, coh_vec = compute_discourse_measures_series(texts, spacy_nlp)
        out.loc[texts.index, "LC"] = lc_vec
        out.loc[texts.index, "CoH"] = coh_vec
    except Exception as exc:
        warnings.warn(
            f"discourse measures failed for column {column!r}: {exc!r}; LC/CoH set to NaN",
            RuntimeWarning,
            stacklevel=2,
        )
        out.loc[texts.index, "LC"] = np.nan
        out.loc[texts.index, "CoH"] = np.nan

    return out


In [8]:
# Cell 8 — ComplexityScore & vectorized dominance
@dataclass(frozen=True)
class ComplexityScore:
    """Immutable bundle of the seven complexity measures of one text.

    The measures are grouped into three families exposed as dict
    properties: ``lex`` (MTLD, LD, LS), ``syn`` (MDD, CS) and ``dis``
    (LC, CoH). Missing values are represented as NaN.
    """

    # Lexical family
    MTLD: float
    LD: float
    LS: float
    # Syntactic family
    MDD: float
    CS: float
    # Discourse family
    LC: float
    CoH: float

    @property
    def lex(self):
        """Lexical measures keyed by name."""
        return dict(MTLD=self.MTLD, LD=self.LD, LS=self.LS)

    @property
    def syn(self):
        """Syntactic measures keyed by name."""
        return dict(MDD=self.MDD, CS=self.CS)

    @property
    def dis(self):
        """Discourse measures keyed by name."""
        return dict(LC=self.LC, CoH=self.CoH)

    @staticmethod
    def from_row(row: pd.Series, prefix: str) -> "ComplexityScore":
        """Build a score from the ``{prefix}_{measure}`` entries of *row*.

        Any entry that is missing or cannot be converted to float becomes NaN.
        """
        def read(name):
            try:
                return float(row[f"{prefix}_{name}"])
            except Exception:
                return float("nan")

        names = ("MTLD", "LD", "LS", "MDD", "CS", "LC", "CoH")
        return ComplexityScore(**{name: read(name) for name in names})

    @staticmethod
    def _isnan(x: Optional[float]) -> bool:
        """True for None, NaN, and anything that cannot be converted to float."""
        if x is None:
            return True
        try:
            return math.isnan(float(x))
        except Exception:
            return True

    def _ge_or_ignore_nan(self, a: float, b: float) -> bool:
        """True when *b* is missing, or when *a* is present and ``a >= b``."""
        if self._isnan(b):
            return True
        if self._isnan(a):
            return False
        return a >= b

    def _strict_progress(self, new: Dict[str, float], old: Dict[str, float]) -> bool:
        """True when at least one key has both values present and ``new > old``."""
        for key, candidate in new.items():
            baseline = old[key]
            if self._isnan(candidate) or self._isnan(baseline):
                continue
            if candidate > baseline:
                return True
        return False

    def dominates(self, target: "ComplexityScore",
                  previous: "ComplexityScore|None" = None,
                  length_ok: bool = True) -> bool:
        """Tell whether this score dominates *target*.

        Requires this score to be >= *target* on every measure in
        ``METRICS`` (measures missing in *target* are ignored) and
        *length_ok* to be truthy. When *previous* is given, each of the
        three families must additionally show strict progress over it.
        """
        covers_target = all(
            self._ge_or_ignore_nan(getattr(self, name), getattr(target, name))
            for name in METRICS
        )
        if not covers_target:
            return False
        if not length_ok:
            return False
        if previous is None:
            return True
        return all(
            self._strict_progress(getattr(self, family), getattr(previous, family))
            for family in ("lex", "syn", "dis")
        )

def simple_dominates_complex_flags(df: pd.DataFrame) -> pd.Series:
    """Flag the rows of *df* in which the Simple text dominates the Complex one.

    A row is flagged when both conditions hold:
      * for every measure in ``METRICS``, ``simple_<m> >= complex_<m>``
        (a NaN on the complex side is ignored);
      * in each family (lexical, syntactic, discourse) at least one measure
        has both values present and ``simple > complex``.

    Args:
        df: frame holding ``simple_<m>`` and ``complex_<m>`` columns for
            every measure.

    Returns:
        A boolean Series aligned on ``df.index`` (as the signature
        promises), so it can be assigned as a column or used as a mask.
    """
    def family_progress(measures):
        # True where some measure of the family strictly increased.
        simple = df[[f"simple_{m}" for m in measures]].astype(float).to_numpy()
        complex_ = df[[f"complex_{m}" for m in measures]].astype(float).to_numpy()
        both_present = ~np.isnan(simple) & ~np.isnan(complex_)
        return ((simple > complex_) & both_present).any(axis=1)

    ge_ok = np.ones(len(df), dtype=bool)
    for m in METRICS:
        simple = df[f"simple_{m}"].astype(float).to_numpy()
        complex_ = df[f"complex_{m}"].astype(float).to_numpy()
        ge_ok &= (simple >= complex_) | np.isnan(complex_)

    flags = (
        ge_ok
        & family_progress(["MTLD", "LD", "LS"])
        & family_progress(["MDD", "CS"])
        & family_progress(["LC", "CoH"])
    )
    return pd.Series(flags, index=df.index, dtype=bool)


In [9]:
# Cell 9 — Prepare & save helpers
def prepare_dataset(
    df_in: pd.DataFrame,
    col_simple: str = "Simple",
    col_complex: str = "Complex",
    lang: str = "en",
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Augment a Simple/Complex dataset with complexity measures and filter it.

    Steps:
      1. Copy the input; if it has no "id" column, create one from the
         current index values (whatever the index is named) and switch to a
         fresh positional index. Ids are cast to int.
      2. Compute the seven measures for both text columns.
      3. Drop every row with a missing measure.
      4. Flag rows where Simple already dominates Complex and split the
         data into kept / removed rows.

    Args:
        df_in: input frame (not modified).
        col_simple: name of the simplified-text column.
        col_complex: name of the complex-text column.
        lang: stanza language code.

    Returns:
        ``(df_aug_clean, df_kept, df_removed)``: the NaN-free augmented frame
        including the ``simple_dominates_complex`` flag, the rows to keep
        (without the flag column) and the flagged rows.

    Raises:
        AssertionError: if either text column is missing. Raised explicitly
            so the check also runs under ``python -O``.
    """
    df = df_in.copy()

    if "id" not in df.columns:
        # Take the ids straight from the index values: relying on
        # reset_index() naming the new column "index" breaks for a named
        # index or when an "index" column already exists.
        original_ids = df.index.to_numpy()
        df = df.reset_index(drop=True)
        df.insert(0, "id", original_ids)
    df["id"] = df["id"].astype(int)

    if col_simple not in df.columns or col_complex not in df.columns:
        raise AssertionError(f"Missing required columns: {col_simple}, {col_complex}")

    m_simple  = compute_all_measures_df(df, column=col_simple,  lang=lang).add_prefix("simple_")
    m_complex = compute_all_measures_df(df, column=col_complex, lang=lang).add_prefix("complex_")

    df_aug = pd.concat([df[["id", col_simple, col_complex]], m_simple, m_complex], axis=1)

    # Rows with any missing measure are discarded before the dominance check.
    measure_cols = [*m_simple.columns.tolist(), *m_complex.columns.tolist()]
    df_aug_clean = df_aug.dropna(subset=measure_cols).copy()

    df_aug_clean["simple_dominates_complex"] = simple_dominates_complex_flags(df_aug_clean)

    df_removed = df_aug_clean[df_aug_clean["simple_dominates_complex"]].copy()
    df_kept    = df_aug_clean[~df_aug_clean["simple_dominates_complex"]].copy()

    ordered_cols = [
        "id", col_simple, col_complex,
        *[f"simple_{k}" for k in METRICS],
        *[f"complex_{k}" for k in METRICS],
        "simple_dominates_complex",
    ]
    df_aug_clean = df_aug_clean[ordered_cols]
    df_kept      = df_kept[[c for c in ordered_cols if c != "simple_dominates_complex"]]

    return df_aug_clean, df_kept, df_removed

def save_artifacts(
    df_aug: pd.DataFrame,
    df_kept: pd.DataFrame,
    df_removed: pd.DataFrame,
    out_dir: Path,
    base_name: str,
):
    """Write the prepared frames and a small report under *out_dir*.

    Files written (all prefixed with ``base_name``):
      * ``__augmented.csv`` / ``__filtered.csv``: always;
      * ``__augmented.parquet`` / ``__filtered.parquet``: best effort; a
        failure emits a ``RuntimeWarning`` and the function carries on;
      * ``__removed_ids.json``: ids of the removed rows and their count;
      * ``__report.json``: row counts, removal ratio and metric names.

    Args:
        df_aug: augmented frame after NaN filtering.
        df_kept: rows kept after the dominance filter.
        df_removed: rows removed by the dominance filter (needs an "id" column).
        out_dir: destination directory, created if missing.
        base_name: prefix for every output file.
    """
    import warnings

    out_dir.mkdir(parents=True, exist_ok=True)

    df_aug.to_csv(out_dir / f"{base_name}__augmented.csv", index=False)
    df_kept.to_csv(out_dir / f"{base_name}__filtered.csv", index=False)
    try:
        df_aug.to_parquet(out_dir / f"{base_name}__augmented.parquet", index=False)
        df_kept.to_parquet(out_dir / f"{base_name}__filtered.parquet", index=False)
    except Exception as exc:
        # Parquet is optional; the CSV files above remain the reference
        # output. Report the problem instead of hiding it.
        warnings.warn(
            f"Parquet export skipped for {base_name!r}: {exc!r}",
            RuntimeWarning,
            stacklevel=2,
        )

    removed_ids = df_removed["id"].astype(int).tolist()
    with (out_dir / f"{base_name}__removed_ids.json").open("w", encoding="utf-8") as f:
        json.dump({"removed_ids": removed_ids, "count": len(removed_ids)}, f, ensure_ascii=False, indent=2)

    report = {
        "total_rows_after_nan_filter": int(len(df_aug)),
        "kept_rows": int(len(df_kept)),
        "removed_rows": int(len(df_removed)),
        "removed_ratio": float(len(df_removed) / max(1, len(df_aug))),
        "metrics": METRICS,
        # NOTE(review): the count below follows the note's own breakdown and
        # does not include the "id" column that the exported frames carry.
        "columns_final_expected": 16,
        "note": "16 = Simple, Complex + 7 mesures Simple + 7 mesures Complex",
    }
    with (out_dir / f"{base_name}__report.json").open("w", encoding="utf-8") as f:
        json.dump(report, f, ensure_ascii=False, indent=2)


In [10]:
# Cell 10 — Example run
# Dataset key (looked up in the `datasets` dict) and output directory for this run.
DATASET_NAME = "ose_adv_ele"   # choose one: ose_adv_ele | ose_adv_int | swipe | vikidia
OUT_DIR = Path("prepared")

# Load the raw TSV and show a quick preview.
df_raw = load_dataset(DATASET_NAME)
print("Rows loaded:", len(df_raw))
# NOTE(review): `display` is not imported in this file — presumably the
# IPython/Jupyter builtin, so this line would fail in a plain Python run.
display(df_raw.head(3))

# Compute all measures, drop incomplete rows, and split kept / removed rows.
df_aug, df_kept, df_removed = prepare_dataset(df_raw, col_simple="Simple", col_complex="Complex", lang="en")

print("After NaN filter:", len(df_aug))
print("Kept:", len(df_kept), "| Removed (Simple dominates):", len(df_removed))

# Persist CSV/Parquet outputs, the removed ids and the JSON report.
save_artifacts(df_aug, df_kept, df_removed, OUT_DIR, base_name=DATASET_NAME)

print("\nSaved to:", OUT_DIR.resolve())


Rows loaded: 189


Unnamed: 0,Simple,Complex
0,"﻿When you see the word Amazon, what’s the firs...","﻿When you see the word Amazon, what’s the firs..."
1,"﻿To tourists, Amsterdam still seems very liber...","﻿Amsterdam still looks liberal to tourists, wh..."
2,"﻿Anitta, a music star from Brazil, has million...","﻿Brazil’s latest funk sensation, Anitta, has w..."


Stanza Simple: 100%|██████████| 189/189 [42:14<00:00, 13.41s/it]
Stanza Complex: 100%|██████████| 189/189 [1:05:30<00:00, 20.79s/it]


After NaN filter: 189
Kept: 189 | Removed (Simple dominates): 0

Saved to: C:\Users\rroll\Documents\GitHub\ISH_projet_agno_agent_intelligent\prepared
