# Filter and Rescore

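This notebook applies n-gram filtering to previously scored phrase data, re-aggregates the summed log-probability scores by minimum token count, and saves the results for each data type, corpus and model.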

In [70]:
import sys
import contextlib
import io

import pandas as pd

from typing import Iterable, Optional, Union
from from_root import from_root

sys.path.insert(0, str(from_root("src")))

from n_gram_filtering import apply_ngram_filtering

In [71]:
def create_summary_by_token_num(
    per_phrase_table: pd.DataFrame,
    *,
    group_cols: Optional[Union[str, Iterable[str]]] = None,
    key_col: str = "num_tokens",
    sum_cols: Optional[Iterable[str]] = None,
) -> pd.DataFrame:
    """
    Summarise `sum_cols` cumulatively over "at least t" thresholds of `key_col`.

    Group (optionally) by `group_cols`. For each group and each threshold t in
    the sorted unique non-null values of `key_col` within that group, sum
    `sum_cols` over the rows where key_col >= t.

    Parameters
    ----------
    per_phrase_table:
        Input table; must contain `key_col`, every group column and every
        sum column.
    group_cols:
        A single column name, an iterable of column names, or None for no
        grouping. Rows with missing group values are kept (dropna=False).
    key_col:
        Column providing the thresholds. Threshold values are cast with
        `int(...)` in the output.
    sum_cols:
        Columns to sum. If None, it is inferred as all numeric columns
        excluding `group_cols` and `key_col`.

    Returns
    -------
    pd.DataFrame
        One row per (group, threshold) with columns: the group columns,
        "min_token_size" (the threshold), "n_rows" (rows meeting the
        threshold) and one column per summed column. Rows are sorted by the
        group columns and then "min_token_size". If there is nothing to
        summarise (empty input, or no non-null `key_col` values) an empty
        frame with those columns is returned.
    """
    if group_cols is None:
        group_cols_list: list[str] = []
    elif isinstance(group_cols, str):
        group_cols_list = [group_cols]
    else:
        group_cols_list = list(group_cols)

    if sum_cols is None:
        exclude = set(group_cols_list) | {key_col}
        sum_cols_list = [
            c for c in per_phrase_table.columns
            if c not in exclude and pd.api.types.is_numeric_dtype(per_phrase_table[c])
        ]
    else:
        sum_cols_list = list(sum_cols)

    def _build_rows(df: pd.DataFrame, group_vals: dict) -> list[dict]:
        """Return one summary row per threshold found in `df[key_col]`."""
        token_thresholds = sorted(df[key_col].dropna().unique())
        out: list[dict] = []
        for t in token_thresholds:
            filt = df[df[key_col] >= t]
            sums = filt[sum_cols_list].sum(numeric_only=True)

            row = {**group_vals, "min_token_size": int(t), "n_rows": int(len(filt))}
            row.update(sums.to_dict())
            out.append(row)
        return out

    rows: list[dict] = []
    if group_cols_list:
        for keys, gdf in per_phrase_table.groupby(group_cols_list, dropna=False, sort=False):
            # Grouping by a length-1 list yields a scalar key on older pandas
            # but a 1-tuple on pandas >= 2.0. Wrap only when it is not already
            # a tuple, otherwise the group value would be stored as `(value,)`.
            if not isinstance(keys, tuple):
                keys = (keys,)
            group_vals = dict(zip(group_cols_list, keys))
            rows.extend(_build_rows(gdf, group_vals))
        sort_by = group_cols_list + ["min_token_size"]
    else:
        rows.extend(_build_rows(per_phrase_table, {}))
        sort_by = ["min_token_size"]

    if not rows:
        # Sorting a column-less frame would raise KeyError; return an empty
        # frame that still carries the expected output columns.
        return pd.DataFrame(
            columns=group_cols_list + ["min_token_size", "n_rows"] + sum_cols_list
        )

    return pd.DataFrame(rows).sort_values(sort_by).reset_index(drop=True)

In [72]:
def apply_filtering_and_rescore(
    df,
    metadata,
    model_loc = "/Volumes/BCross/models/gpt2",
    token_col = 'tokens',
    min_tokens = 2,
    sum_cols = None
):
    """
    Filter scored phrases, re-aggregate the scores and keep completed problems.

    Steps:
      1. Run `apply_ngram_filtering` on `df`, forwarding `model_loc`,
         `token_col` and `min_tokens`.
      2. Aggregate the filtered rows with `create_summary_by_token_num`,
         grouped by data_type / corpus / scoring_model / max_context_tokens /
         problem / target and thresholded on "num_tokens".
      3. Left-join the "problem_completed" flag from `metadata` and keep only
         rows where it is True (rows without a metadata match are dropped).

    Parameters
    ----------
    df:
        Per-phrase scored table passed to `apply_ngram_filtering`.
        NOTE(review): the filtered result is assumed to contain "num_tokens",
        the six grouping columns and `sum_cols` - confirm against
        `apply_ngram_filtering`.
    metadata:
        Table containing at least "data_type", "corpus", "scoring_model",
        "max_context_tokens", "problem" and "problem_completed".
        NOTE(review): if it holds several rows per join key the merge will
        duplicate score rows - confirm the keys are unique.
    model_loc, token_col, min_tokens:
        Forwarded unchanged to `apply_ngram_filtering`.
    sum_cols:
        Score columns to sum. Defaults to "no_context_sum_log_probs",
        "known_sum_log_probs" and "unknown_sum_log_probs".

    Returns
    -------
    pd.DataFrame
        Columns: data_type, corpus, scoring_model, max_context_tokens,
        min_token_size, problem, target, followed by `sum_cols`.
    """
    if sum_cols is None:
        sum_cols = [
            "no_context_sum_log_probs",
            "known_sum_log_probs",
            "unknown_sum_log_probs",
        ]
    else:
        # Accept any iterable; a list is needed for the concatenation below.
        sum_cols = list(sum_cols)

    problem_keys = [
        "data_type", "corpus", "scoring_model", "max_context_tokens", "problem"
    ]
    output_keys = [
        "data_type", "corpus", "scoring_model", "max_context_tokens",
        "min_token_size", "problem", "target"
    ]

    meta_slim = metadata[problem_keys + ["problem_completed"]]

    # Forward the caller's settings; previously these were hard-coded and the
    # corresponding parameters were silently ignored.
    filtered_df = apply_ngram_filtering(
        df,
        model_loc = model_loc,
        token_col = token_col,
        min_tokens = min_tokens
    )

    scored_df = create_summary_by_token_num(
        per_phrase_table=filtered_df,
        group_cols=problem_keys + ["target"],
        key_col="num_tokens",
        sum_cols=sum_cols
    )

    # Drop the helper "n_rows" column and fix the output column order.
    scored_df = scored_df[output_keys + sum_cols]

    scored_df = scored_df.sort_values(by=output_keys).reset_index(drop=True)

    final_df = (
        pd.merge(
            scored_df,
            meta_slim,
            how="left",
            on=problem_keys,
        )
        # NaN (no metadata match) compares unequal to True and is dropped.
        .loc[lambda d: d["problem_completed"].eq(True)]
        .drop(columns=["problem_completed"])
    )

    return final_df

In [73]:
# Batch driver: for every data type / corpus / model combination, read each
# raw-level phrase-score workbook with its metadata, filter and rescore it,
# then write all levels combined to one workbook per combination.
models = ["gpt2"]

corpora = [
    "Wiki", "Enron", "Perverted Justice", "StackExchange", "ACL",
    "TripAdvisor", "The Apricity", "Koppel's Blogs", "The Telegraph",
    "Reddit"
]

data_types = ["training", "test"]

# Suffixes of the per-level input files (phrase_scores_<level>.xlsx).
raw_levels = [
    "raw_100", "raw_200", "raw_300", "raw_400", "raw_500", "raw_600",
    "raw_700", "raw_800", "raw_900", "raw_1000", "raw"
]

# NOTE(review): "logrpobs" looks like a misspelling of "logprobs"; left as-is
# because it must match the directory name on disk - confirm.
base_loc = "/Volumes/BCross/av_datasets_experiments/ngram_masking_logrpobs"

for d_type in data_types:
    for corpus in corpora:
        for model in models:
            print(f"Working on: {d_type} | {corpus} | {model}")
            read_files_dir = f"{base_loc}/{d_type}/{corpus}/{model}/raw_results"
            save_file_loc = f"{base_loc}/{d_type}/{corpus}/{model}/filtered_agg_scores.xlsx"

            data_list = []
            for level in raw_levels:
                data_loc = f"{read_files_dir}/phrase_scores_{level}.xlsx"
                metadata_loc = f"{read_files_dir}/problem_completed_metadata_{level}.xlsx"

                df = pd.read_excel(data_loc, engine="openpyxl")
                metadata = pd.read_excel(metadata_loc, engine="openpyxl")

                # Suppress anything printed to stdout during filtering/rescoring
                with contextlib.redirect_stdout(io.StringIO()):
                    rescored_df = apply_filtering_and_rescore(
                        df,
                        metadata
                    )

                data_list.append(rescored_df)

            # Stack every raw level into a single table for this combination.
            results = pd.concat(data_list, ignore_index=True)

            results.to_excel(save_file_loc, engine="openpyxl", index=False)

Working on: training | Wiki | gpt2
Working on: training | Enron | gpt2
Working on: training | Perverted Justice | gpt2
Working on: training | StackExchange | gpt2
Working on: training | ACL | gpt2
Working on: training | TripAdvisor | gpt2
Working on: training | The Apricity | gpt2
Working on: training | Koppel's Blogs | gpt2
Working on: training | The Telegraph | gpt2
Working on: training | Reddit | gpt2
Working on: test | Wiki | gpt2
Working on: test | Enron | gpt2
Working on: test | Perverted Justice | gpt2
Working on: test | StackExchange | gpt2
Working on: test | ACL | gpt2
Working on: test | TripAdvisor | gpt2
Working on: test | The Apricity | gpt2
Working on: test | Koppel's Blogs | gpt2
Working on: test | The Telegraph | gpt2
Working on: test | Reddit | gpt2
