In [None]:
import os
from datasets import load_dataset

# Hugging Face access token from the environment; if HF_TOKEN is unset this is None
# and load_dataset falls back to the library's default credentials.
token = os.environ.get('HF_TOKEN')
ds_100k = load_dataset(
    "GGmorello/FLAMES_results",
    "100k",
    token=token,
)
# Alternative input (disabled): first 10k rows of the 'infilled' config, e.g.
# load_dataset('GGmorello/FLAMES', 'infilled', split='train[:10000]', token=token)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Materialize the 'train' split of the 100k config as a pandas DataFrame.
df_100k = ds_100k["train"].to_pandas()

In [4]:
import re
import math
import pandas as pd

# ===================== internals (filtering / scoring) ======================

# Matches any run of whitespace; _norm replaces each run with a single space.
_WS_RE = re.compile(r'\s+')
# Matches text whose first non-space char is '(' and last non-space char is ')',
# capturing everything in between as 'inner' (DOTALL so newlines are included).
# NOTE: it does not check that those two parens form a matching pair, so
# "(a) && (b)" also matches, with inner = "a) && (b".
_PAREN_TRIM_RE = re.compile(r'^\s*\((?P<inner>.*)\)\s*$', re.DOTALL)

def _strip_wrapping_parens(s: str) -> str:
    prev, cur = None, (s or "").strip()
    while prev != cur:
        prev = cur
        m = _PAREN_TRIM_RE.match(cur)
        if not m:
            break
        cur = m.group('inner').strip()
    return cur

def _norm(s: str) -> str:
    """Canonical text of an expression, used for scoring, filtering and dedup.

    Steps: strip surrounding whitespace (None -> ""), drop ONE trailing ';',
    remove parentheses wrapping the whole expression, then collapse every
    whitespace run to a single space and strip again.
    """
    text = (s or "").strip()
    text = text[:-1] if text.endswith(';') else text
    unwrapped = _strip_wrapping_parens(text)
    return _WS_RE.sub(' ', unwrapped).strip()

def _is_effectively_empty(x) -> bool:
    """Treat None/NaN/empty/placeholder-like as empty; lists/dicts by length."""
    if x is None:
        return True
    # pandas NaN
    try:
        if pd.isna(x):
            return True
    except Exception:
        pass
    # containers
    if isinstance(x, (list, tuple, dict, set)):
        return len(x) == 0
    # strings
    s = str(x).strip()
    if s == "":
        return True
    if s.lower() in {"nan", "none", "null"}:
        return True
    if s in {"[]", "{}"}:
        return True
    return False

# triviality / address checks (operate on normalized text)
_ATOM = r'(?:[A-Za-z_]\w*(?:\.[A-Za-z_]\w*|\[[^\[\]\(\)]*\])*)|(?:0x[0-9A-Fa-f]+)|(?:\d+(?:\s*\*\s*1e\d+)?)'
_SIMPLE_CMP_OPS = r'(?:==|!=|<=|>=|<|>)'
_FN_CALL_ONLY_RE   = re.compile(r'^\s*[A-Za-z_]\w*(?:\.[A-Za-z_]\w*)*\s*\([^()]*\)\s*$', re.DOTALL)
_IDENT_ONLY_RE     = re.compile(r'^\s*[A-Za-z_]\w*(?:\.[A-Za-z_]\w*|\[[^\[\]\(\)]*\])*\s*$', re.DOTALL)
_NEG_SIMPLE_RE     = re.compile(r'^\s*!\s*(?:' + _ATOM + r'|' + r'[A-Za-z_]\w*(?:\.[A-Za-z_]\w*)*\s*\([^()]*\))\s*$', re.DOTALL)
_SIMPLE_CMP_RE     = re.compile(r'^\s*(' + _ATOM + r')\s*' + _SIMPLE_CMP_OPS + r'\s*(' + _ATOM + r')\s*$', re.DOTALL)
_HARDCODED_ADDR_RE = re.compile(r'(?:\baddress\s*\(\s*0x?0\s*\)|\b0x[0-9A-Fa-f]{40}\b)')

def _hardness(s: str) -> float:
    """Heuristic difficulty score of a normalized expression string (higher = harder).

    A weighted sum of surface-level token counts; the expression is not parsed.
    Features: boolean connectives/negations, arithmetic operators, comparisons
    beyond the first, parentheses, brackets, '.' (member access; also counts
    decimal points), function-call openings, and total length.
    """
    # Logical negation: a '!' NOT followed by '=', so the '!' of a '!=' comparison
    # is not double-counted (that operator is already counted in cmp_ops).
    negations = len(re.findall(r'!(?!=)', s))
    bool_ops = s.count('&&') + s.count('||') + negations * 0.5
    # Arithmetic operator chars not immediately preceded by '<', '>' or '='.
    arith_ops = len(re.findall(r'(?<!<|>|=)[\+\-\*/%]', s))
    cmp_ops = len(re.findall(_SIMPLE_CMP_OPS, s))
    parens = s.count('(') + s.count(')')
    indexing = s.count('[') + s.count(']')
    members = s.count('.')
    fn_calls = len(re.findall(r'[A-Za-z_]\w*\s*\(', s))
    return (
        2.0 * bool_ops
        + 1.5 * arith_ops
        + 1.5 * max(0, cmp_ops - 1)   # the first comparison is "free"
        + 0.5 * parens
        + 0.4 * indexing
        + 0.3 * members
        + 0.6 * fn_calls
        + 0.01 * len(s)
    )

def _apply_trivial_filters(norm_series: pd.Series,
                           drop_addr=True,
                           drop_ident_only=True,
                           drop_fn_call_only=True,
                           drop_neg_simple=True,
                           drop_simple_cmp=True) -> pd.Series:
    """Boolean KEEP mask over ``norm_series`` (True = row survives every enabled filter).

    Each enabled flag drops rows whose normalized text is "trivial" in that sense:
      drop_addr          contains a hard-coded address (_HARDCODED_ADDR_RE.search)
      drop_ident_only    is a lone identifier            (_IDENT_ONLY_RE)
      drop_fn_call_only  is a lone flat function call    (_FN_CALL_ONLY_RE)
      drop_neg_simple    is '!' + atom or flat call      (_NEG_SIMPLE_RE)
      drop_simple_cmp    is '<atom> <cmp> <atom>' with no '&&', '||', '?', ':'
                         and no arithmetic operator      (_SIMPLE_CMP_RE)
    """
    def _is_simple_cmp(text: str) -> bool:
        # Connectives, ternaries or arithmetic make an expression non-simple.
        if any(tok in text for tok in ('&&', '||', '?', ':')):
            return False
        if re.search(r'(?<!<|>|=)[\+\-\*/%]', text):
            return False
        return bool(_SIMPLE_CMP_RE.match(text))

    rules = [
        (drop_addr,         lambda t: bool(_HARDCODED_ADDR_RE.search(t))),
        (drop_ident_only,   lambda t: bool(_IDENT_ONLY_RE.match(t))),
        (drop_fn_call_only, lambda t: bool(_FN_CALL_ONLY_RE.match(t))),
        (drop_neg_simple,   lambda t: bool(_NEG_SIMPLE_RE.match(t))),
        (drop_simple_cmp,   _is_simple_cmp),
    ]

    survivors = pd.Series(True, index=norm_series.index)
    for enabled, is_trivial in rules:
        if enabled:
            survivors &= ~norm_series.map(is_trivial)
    return survivors

def _take_in_band(scored: pd.DataFrame, k: int, q_low: float, q_high: float) -> pd.Index:
    """Return index of picked rows under current band; widen to full range if needed."""
    if scored.empty:
        return scored.index[:0]
    lo = scored['hardness'].quantile(q_low) if 0 < q_low < 1 else scored['hardness'].min()
    hi = scored['hardness'].quantile(q_high) if 0 < q_high < 1 else scored['hardness'].max()
    band = scored[(scored['hardness'] >= lo) & (scored['hardness'] <= hi)]
    if len(band) >= k:
        return band.sort_values(['hardness', 'norm'], ascending=[False, True]).index[:k]
    # widen all the way if short
    band = scored
    return band.sort_values(['hardness', 'norm'], ascending=[False, True]).index[:k]

# ============================= main selector =================================

def select_hard_rows_robust(
    df: pd.DataFrame,
    predicate_col: str = 'predicate',
    result_col: str = 'results',
    hardness_source: str = 'predicate',   # or 'results'
    k: int = 5000,
    q_low: float = 0.40,
    q_high: float = 0.85,
    dedup_mode: str = 'norm',             # 'norm' | 'pair' | 'none'
    verbose: bool = True
) -> pd.DataFrame:
    """
    Select up to ``k`` non-trivial, mid-to-hard rows from ``df``.

    Pipeline:
      1) Require BOTH predicate & results to be non-empty (_is_effectively_empty).
      2) Normalize the hardness-source column (_norm); drop rows that normalize to "".
      3) Drop trivial patterns (_apply_trivial_filters) and dedupe.
      4) Score with _hardness and pick from the [q_low, q_high] quantile band
         (_take_in_band, which widens to the full range when the band is short).
    If fewer than ``k`` rows come out, the trivial filters and the dedup mode are
    relaxed step by step (dedup: norm -> pair -> none, starting at ``dedup_mode``)
    until ``k`` rows are reached or every setting has been tried.

    Args:
        df: input frame. NOTE(review): its index is used as row identity and is
            assumed unique (true for the default RangeIndex of to_pandas()).
        predicate_col, result_col: columns that must both be non-empty.
        hardness_source: which of the two columns is normalized, filtered and scored.
        k: number of rows wanted.
        q_low, q_high: quantiles bounding the preferred hardness band; a value
            outside (0, 1) means the min / max hardness respectively.
        dedup_mode: most aggressive dedup to start from: 'norm' (normalized source
            text), 'pair' (raw predicate+result pair) or 'none'.
        verbose: print drop-off statistics for each stage.

    Returns:
        The ORIGINAL rows (values and index unchanged) plus a float 'hardness'
        column, at most ``k`` rows. If ``k`` cannot be reached, the largest
        selection produced by any tried setting is returned.

    Raises:
        ValueError: unknown ``hardness_source`` or ``dedup_mode``.
        KeyError: ``predicate_col`` or ``result_col`` missing from ``df``.
    """
    # Dedup modes from most to least aggressive. A row kept by 'norm' is always kept
    # by 'pair' (identical raw pairs share the same normalized source text), and
    # 'none' keeps everything, so relaxing only ever moves rightwards.
    dedup_ladder = ('norm', 'pair', 'none')

    if hardness_source not in ('predicate', 'results'):
        raise ValueError("hardness_source must be 'predicate' or 'results'")
    if dedup_mode not in dedup_ladder:
        raise ValueError("dedup_mode must be 'norm' | 'pair' | 'none'")
    missing = [c for c in (predicate_col, result_col) if c not in df.columns]
    if missing:
        raise KeyError(f"Missing columns: {missing}")

    # -------- Stage 0: both columns non-empty (robust) --------
    mask_both = (~df[predicate_col].map(_is_effectively_empty)) & (~df[result_col].map(_is_effectively_empty))
    base = df.loc[mask_both].copy()
    if verbose:
        print(f"[drop-off] total={len(df)}  both_non_empty={len(base)}  dropped={len(df)-len(base)}")
    if base.empty:
        return base.assign(hardness=pd.Series(dtype=float))

    src_col = predicate_col if hardness_source == 'predicate' else result_col

    # -------- Stage 1: normalized text (for scoring/filters only) --------
    norm = base[src_col].astype(str).map(_norm)
    nonempty_norm = norm.str.len() > 0
    base = base.loc[nonempty_norm].copy()
    norm = norm.loc[nonempty_norm]
    if verbose:
        print(f"[drop-off] after_norm_nonempty={len(base)}  dropped={mask_both.sum()-nonempty_norm.sum()}")
    if base.empty:
        return base.assign(hardness=pd.Series(dtype=float))

    # -------- Stage 2: trivial filters, strictest first --------
    drop_flags_list = [
        # (addr, ident_only, fn_call_only, neg_simple, simple_cmp)
        (True,  True,  True,  True,  True),   # strictest
        (True,  True,  True,  True,  False),  # allow simple comparisons
        (True,  True,  False, True,  False),  # also allow pure fn calls
        (True,  False, False, True,  False),  # also allow identifier-only
        (False, False, False, True,  False),  # also allow hardcoded addr
        (False, False, False, False, False),  # allow everything (no trivial filter)
    ]
    dedup_modes = dedup_ladder[dedup_ladder.index(dedup_mode):]

    def _dedup(w: pd.DataFrame, n: pd.Series, mode: str):
        # Keep the first occurrence of each dedup key. Positional boolean masks avoid
        # any dependence on the index's name or dtype.
        if mode == 'none':
            return w, n
        if mode == 'norm':
            first = ~n.duplicated(keep='first')
        else:  # 'pair'
            key = pd.DataFrame({'p': w[predicate_col].astype(str).to_numpy(),
                                'r': w[result_col].astype(str).to_numpy()})
            first = ~key.duplicated(keep='first')
        mask = first.to_numpy()
        return w[mask], n[mask]

    def _select(w: pd.DataFrame, n: pd.Series, mode: str) -> pd.DataFrame:
        # Dedup, score and band-select one candidate pool; rows + 'hardness'.
        if w.empty:
            return w.assign(hardness=pd.Series(dtype=float))
        w, n = _dedup(w, n, mode)
        if verbose:
            print(f"[dedup] mode={mode}  after_dedup={len(w)}")
        scored = pd.DataFrame({'hardness': n.map(_hardness).to_numpy(),
                               'norm': n.to_numpy()},
                              index=w.index)
        picked_idx = _take_in_band(scored, k, q_low, q_high)
        selected = w.loc[picked_idx].copy()
        selected['hardness'] = scored.loc[picked_idx, 'hardness'].to_numpy()
        return selected

    # -------- Stage 3: relax progressively until k rows or settings exhausted --------
    best, best_cfg = None, None
    for flags in drop_flags_list:
        # The trivial-filter mask depends only on the flags: compute it once per set.
        keep = _apply_trivial_filters(norm,
                                      drop_addr=flags[0],
                                      drop_ident_only=flags[1],
                                      drop_fn_call_only=flags[2],
                                      drop_neg_simple=flags[3],
                                      drop_simple_cmp=flags[4])
        w_kept = base.loc[keep]
        n_kept = norm.loc[keep]
        if verbose:
            print(f"[filters] kept={len(w_kept)}  dropped={len(base)-len(w_kept)}  flags={flags}")

        for dm in dedup_modes:
            out = _select(w_kept, n_kept, dm)
            if best is None or len(out) > len(best):
                best, best_cfg = out, (flags, dm)
            if len(out) >= k:
                if verbose:
                    print(f"[success] got {len(out)} rows (>= {k}) with flags={flags}, dedup={dm}")
                return out.head(k)

    # Could not reach k: return the largest selection seen across all settings.
    if verbose:
        print(f"[final] could not reach k={k}; returning {len(best)} rows "
              f"(largest result: flags={best_cfg[0]}, dedup={best_cfg[1]}).")
    return best

# ================================ usage ======================================
# Select up to 5,000 non-trivial, mid-to-hard rows from the 100k results split (hardness scored on the predicate).
K_HARD = 5000  # number of rows to select

hard_5k = select_hard_rows_robust(
    df_100k,
    predicate_col='predicate',
    result_col='results',
    hardness_source='predicate',  # or 'results'
    k=K_HARD, q_low=0.40, q_high=0.85,
    dedup_mode='norm',            # try 'pair' if norm is too aggressive
    verbose=True
)
# Sanity checks: both text columns non-empty, every row scored, at most K_HARD rows,
# and a unique index (the index is written to the CSV as the row id).
assert (~hard_5k['predicate'].map(_is_effectively_empty)).all()
assert (~hard_5k['results'].map(_is_effectively_empty)).all()
assert 'hardness' in hard_5k.columns and hard_5k['hardness'].notna().all()
assert len(hard_5k) <= K_HARD
assert hard_5k.index.is_unique
if len(hard_5k) < K_HARD:
    print(f"[warn] only {len(hard_5k)} rows selected (< {K_HARD}); see the drop-off log above.")


[drop-off] total=31072  both_non_empty=5000  dropped=26072
[drop-off] after_norm_nonempty=5000  dropped=0
[filters] kept=2964  dropped=2036  flags=(True, True, True, True, True), dedup=norm
[dedup] after_dedup=2964
[filters] kept=2964  dropped=2036  flags=(True, True, True, True, True), dedup=pair
[dedup] after_dedup=2964
[filters] kept=2964  dropped=2036  flags=(True, True, True, True, True), dedup=none
[dedup] after_dedup=2964
[filters] kept=4389  dropped=611  flags=(True, True, True, True, False), dedup=norm
[dedup] after_dedup=4389
[filters] kept=4389  dropped=611  flags=(True, True, True, True, False), dedup=pair
[dedup] after_dedup=4389
[filters] kept=4389  dropped=611  flags=(True, True, True, True, False), dedup=none
[dedup] after_dedup=4389
[filters] kept=4808  dropped=192  flags=(True, True, False, True, False), dedup=norm
[dedup] after_dedup=4808
[filters] kept=4808  dropped=192  flags=(True, True, False, True, False), dedup=pair
[dedup] after_dedup=4808
[filters] kept=4808 

In [5]:
# Persist the selection; the original row index is written as the first CSV column.
hard_5k.to_csv(path_or_buf="disl-hardinv.csv", index=True)