In [40]:
import ast
import os
import glob

from pathlib import Path

import numpy as np
import pandas as pd

In [41]:
def create_results_doc_pipeline(doc_loc, write_excel=True, save_dir=None, phrase_loc=None):
    """Pipeline to manually get the results from a document"""
    
    doc_name = os.path.basename(doc_loc)
    
    print(f"Processing Document: {doc_name}")
    # Read the sheets as dataframes
    docs = pd.read_excel(doc_loc, sheet_name="docs")
    known = pd.read_excel(doc_loc, sheet_name="known")
    unknown = pd.read_excel(doc_loc, sheet_name="unknown")
    no_context = pd.read_excel(doc_loc, sheet_name="no context")
    metadata = pd.read_excel(doc_loc, sheet_name="metadata")
    
    # Get phrases to keep
    if phrase_loc:
        phrase_list = pd.read_excel(phrase_loc)
        phrases_to_keep = phrase_list[phrase_list['keep_phrase'] == 1].copy()

        # Convert the stringified tuples into actual tuples, then into lists
        phrases_to_keep['tokens'] = phrases_to_keep['tokens'].apply(lambda x: list(ast.literal_eval(x)) if isinstance(x, str) else list(x))
        phrases_to_keep = phrases_to_keep[['phrase']]
        
        reference_phrases = no_context[no_context['phrase_type'] == 'reference'].copy()

        # Perform the merge using the tuple-based key
        merged_phrases = pd.merge(reference_phrases, phrases_to_keep, on='phrase', how='inner')
        merged_phrases = merged_phrases[['phrase_num']]

        no_context = pd.merge(no_context, merged_phrases, on='phrase_num', how='inner')
        known = pd.merge(known, merged_phrases, on='phrase_num', how='inner')
        unknown= pd.merge(unknown, merged_phrases, on='phrase_num', how='inner')
        
    # Get the base LLR table
    cols = ['phrase_num', 'phrase_occurence', 'original_phrase']
    llr_base = (
        pd.concat([known[cols], unknown[cols]], ignore_index=True)
        .drop_duplicates()
        .sort_values(cols, ascending=[True, True, True])  # explicit
        .reset_index(drop=True)
    )
    
    # Now get the phrase statistics
    # 1) No context phrase stats
    no_context_phrase_stats = (
        no_context
        .assign(ref_log_prob = no_context['sum_log_probs']
                .where(no_context['phrase_type'].eq('reference')))
        .assign(ref_raw_prob=no_context['raw_prob']
            .where(no_context['phrase_type'].eq('reference')))
        .groupby('phrase_num', dropna=False)
        .agg(
            no_context_log_prob=('ref_log_prob', 'sum'),
            num_phrases=('phrase_num', 'size'),
            sum_raw_prob=('raw_prob', 'sum'),
            reference_prob=('ref_raw_prob', 'max')
        )
        .assign(
            phrases_kept=lambda d: d['num_phrases'],
            pmf_no_context=lambda d: d['reference_prob'].div(d['sum_raw_prob']),
            llr_no_context=lambda d: np.where(d['pmf_no_context'] > 0, -np.log10(d['pmf_no_context']), 0.0)
        )
        .drop(columns=['sum_raw_prob', 'reference_prob'])
    )

    # 2) Known phrase stats
    known_phrase_stats = (
        known
        .assign(ref_log_prob = known['sum_log_probs_phrase']
                .where(known['phrase_type'].eq('reference')))
        .assign(ref_raw_prob=known['raw_prob']
                .where(known['phrase_type'].eq('reference')))
        .groupby(['phrase_num', 'phrase_occurence'], dropna=False)
        .agg(
            known_log_prob=('ref_log_prob', 'sum'),
            sum_raw_prob=('raw_prob', 'sum'),
            reference_prob=('ref_raw_prob', 'max')
        )
        .assign(
            pmf_known=lambda d: d['reference_prob'].div(d['sum_raw_prob']),
            llr_known=lambda d: np.where(d['pmf_known'] > 0, -np.log10(d['pmf_known']), 0.0)
        )
        .drop(columns=['sum_raw_prob', 'reference_prob'])
    )

    # 3) Unknown phrase stats
    unknown_phrase_stats = (
        unknown
        .assign(ref_log_prob = unknown['sum_log_probs_phrase']
                .where(unknown['phrase_type'].eq('reference')))
        .assign(ref_raw_prob=unknown['raw_prob']
                .where(unknown['phrase_type'].eq('reference')))
        .groupby(['phrase_num', 'phrase_occurence'], dropna=False)
        .agg(
            unknown_log_prob=('ref_log_prob', 'sum'),
            sum_raw_prob=('raw_prob', 'sum'),
            reference_prob=('ref_raw_prob', 'max')
        )
        .assign(
            pmf_unknown=lambda d: d['reference_prob'].div(d['sum_raw_prob']),
            llr_unknown=lambda d: np.where(d['pmf_unknown'] > 0, -np.log10(d['pmf_unknown']), 0.0)
        )
        .drop(columns=['sum_raw_prob', 'reference_prob'])
    )
    
    # ---- START OF RANK TEST ---- #
    # Rank all phrases including the reference phrase
    unknown['rank_including_ref'] = (
        unknown
        .groupby(['phrase_num', 'phrase_occurence'])['sum_log_probs_phrase']
        .rank(ascending=False, method='first')   # or method='dense' if you prefer
    )

    # reference_log_probs: from the 'reference' row
    ref = (
        unknown[unknown['phrase_type'] == 'reference']
        .groupby(['phrase_num', 'phrase_occurence'])['sum_log_probs_phrase']
        .first()                 # or .iloc[0] via .agg('first')
        .rename('reference_log_probs')
    )

    # max_log_probs: from the top-ranked row (rank_including_ref == 1)
    top = (
        unknown[unknown['rank_including_ref'] == 1]
        .groupby(['phrase_num', 'phrase_occurence'])['sum_log_probs_phrase']
        .first()
        .rename('max_log_probs')
    )

    # combine into one dataframe
    agg_df = (
        pd.concat([ref, top], axis=1)
        .reset_index()
    )

    # difference
    agg_df['unknown_ref_vs_top_rank'] = agg_df['reference_log_probs'] - agg_df['max_log_probs']

    agg_df = agg_df.drop(columns=['reference_log_probs', 'max_log_probs']).reset_index()
    
    # ---- END OF RANK TEST ---- #

    # ---- START OF REF vs BEST-OF-REST (rank 0 vs rank 1) ---- #

    keys = ['phrase_num', 'phrase_occurence']
    
    ref = (
        unknown[unknown['rank'] == 0]
        .groupby(keys, dropna=False)['sum_log_probs_phrase']
        .first()
        .rename('ref_logp')
    )

    best_rest = (
        unknown[unknown['rank'] == 1]
        .groupby(keys, dropna=False)['sum_log_probs_phrase']
        .first()
        .rename('best_rest_logp')
    )

    agg_df_ref_vs_best = (
        pd.concat([ref, best_rest], axis=1)
        .reset_index()
    )

    # LLR(ref vs best alternative)
    agg_df_ref_vs_best['unknown_ref_vs_best_rest'] = agg_df_ref_vs_best['ref_logp'] - agg_df_ref_vs_best['best_rest_logp']

    agg_df_ref_vs_best = agg_df_ref_vs_best.drop(columns=['ref_logp', 'best_rest_logp'])

    # ---- END OF REF vs BEST-OF-REST ---- #
    
    # Create final LLR table
    LLR = (
        llr_base
        .assign(
            phrase_num=llr_base['phrase_num'].astype('string'),
            phrase_occurence=pd.to_numeric(llr_base['phrase_occurence'], errors='coerce').astype('Int64')
        )
        .merge(no_context_phrase_stats, on='phrase_num', how='left')
        .merge(known_phrase_stats, on=['phrase_num','phrase_occurence'], how='left')
        .merge(unknown_phrase_stats, on=['phrase_num','phrase_occurence'], how='left')
        .merge(agg_df, on=['phrase_num','phrase_occurence'], how='left')
        .merge(agg_df_ref_vs_best, on=['phrase_num','phrase_occurence'], how='left')
    )
    LLR['known_vs_no_context_log_prob'] = LLR['known_log_prob'] - LLR['no_context_log_prob'] 
    LLR['unknown_vs_no_context_log_prob'] = LLR['unknown_log_prob'] - LLR['no_context_log_prob'] 

    LLR = LLR[['phrase_num', 'phrase_occurence', 'original_phrase', 'num_phrases', 'phrases_kept', 'no_context_log_prob',
               'known_log_prob', 'unknown_log_prob', 'known_vs_no_context_log_prob', 'unknown_vs_no_context_log_prob',
               'pmf_no_context', 'pmf_known', 'pmf_unknown', 'llr_no_context', 'llr_known',
               'llr_unknown', 'unknown_ref_vs_top_rank', 'unknown_ref_vs_best_rest']]
    
    # Summarise the LLR table for the metadata
    LLR_summary = pd.DataFrame([{
        'num_phrases': LLR['phrase_num'].nunique(),
        'phrases_kept': LLR.loc[LLR['phrases_kept'] > 0, 'phrase_num'].nunique(),
        'known_log_prob': LLR['known_log_prob'].sum(skipna=True),
        'unknown_log_prob': LLR['unknown_log_prob'].sum(skipna=True),
        'llr_no_context': LLR['llr_no_context'].sum(skipna=True),
        'llr_known': LLR['llr_known'].sum(skipna=True),
        'llr_unknown': LLR['llr_unknown'].sum(skipna=True),
        'llr_unknown_vs_top_rank': LLR['unknown_ref_vs_top_rank'].sum(skipna=True),
        'unknown_ref_vs_best_rest': LLR['unknown_ref_vs_best_rest'].sum(skipna=True)
    }])

    LLR_summary = LLR_summary.assign(
        normalised_llr_no_context=lambda d: d['llr_no_context'] / d['phrases_kept'],
        normalised_llr_known=lambda d: d['llr_known'] / d['phrases_kept'],
        normalised_llr_unknown=lambda d: d['llr_unknown'] / d['phrases_kept']
    )
    
    # Create final metadata table
    # 1. Drop any overlapping columns in metadata
    overlapping_cols = LLR_summary.columns.intersection(metadata.columns)
    metadata_final = metadata.drop(columns=overlapping_cols, errors='ignore')

    # 2. Concatenate new values
    metadata_final = pd.concat([metadata_final, LLR_summary], axis=1)
    
    if write_excel:
        
        print("Writing file")
        path = Path(save_dir + '/' + doc_name)
        
        # Choose writer mode safely
        writer_mode = "a" if path.exists() else "w"
        writer_kwargs = {"engine": "openpyxl", "mode": writer_mode}
        if writer_mode == "a":
            writer_kwargs["if_sheet_exists"] = "replace"  # only valid in append mode
        

        with pd.ExcelWriter(path, **writer_kwargs) as writer:
            # Write sheets
            docs.to_excel(writer, index=False, sheet_name="docs")
            known.to_excel(writer, index=False, sheet_name="known")
            unknown.to_excel(writer, index=False, sheet_name="unknown")
            no_context.to_excel(writer, index=False, sheet_name="no context")
            LLR.to_excel(writer, index=False, sheet_name="LLR")
            metadata_final.to_excel(writer, index=False, sheet_name="metadata")
            
            # wb = writer.book
            # wb._sheets = ["docs", "metadata", "no context", "known", "unknown", "LLR"]
    
    return metadata_final

In [42]:
def process_directory(
    read_dir,
    save_dir,
    result_save_loc=None,   # optional now
    phrase_loc="/Volumes/BCross/paraphrase examples slurm/wiki-phrase-list-reviewed.xlsx"
):

    print(f"Reading from: {read_dir}")
    print(f"Saving processed files to: {save_dir}")
    
    if result_save_loc:
        print(f"Final combined output: {result_save_loc}\n")
    else:
        print("No result_save_loc provided → combined file will NOT be saved.\n")

    os.makedirs(save_dir, exist_ok=True)

    # If result_save_loc exists, do not overwrite
    if result_save_loc and os.path.exists(result_save_loc):
        print(f"Output already exists: {result_save_loc}. Exiting.")
        return

    # Read all Excel files
    xlsx_files = glob.glob(os.path.join(read_dir, "*.xlsx"))
    print(f"Found {len(xlsx_files)} files\n")

    all_metadata = []

    for i, file_path in enumerate(xlsx_files, start=1):
        print(f"Processing file {i}/{len(xlsx_files)}: {os.path.basename(file_path)}")

        try:
            metadata = create_results_doc_pipeline(
                file_path,
                write_excel=True,
                save_dir=save_dir,
                phrase_loc=phrase_loc
            )
            all_metadata.append(metadata)

        except Exception as e:
            print(f"❌ Failed to process {file_path}\nError: {e}\n")
            continue

    # Combine metadata
    if all_metadata:
        full_metadata = pd.concat(all_metadata, ignore_index=True)
    else:
        full_metadata = pd.DataFrame()

    # Sort if index exists
    if "index" in full_metadata.columns:
        full_metadata = full_metadata.sort_values(by="index").reset_index(drop=True)

    print("\nAll files complete ✓")
    
    # Save if a path was provided
    if result_save_loc:
        full_metadata.to_excel(result_save_loc, index=False)
        print(f"\nCombined results saved to: {result_save_loc}")
    else:
        print("\nSkipping save because no result_save_loc was provided.")

        return full_metadata  # return result in all cases


In [48]:
base_loc = '/Volumes/BCross/paraphrase examples slurm'
data_loc = f"{base_loc}/Wiki-Test"

# For Qwen
# read_dir = f'{data_loc}/raw'

# For everything else
read_dir = f'{data_loc}/ModernBERT-large/gpt2 results/raw'
save_dir = f'{data_loc}/ModernBERT-large/gpt2 results/filtered_inc_rank_v3'
result_save_loc = f"{data_loc}/ModernBERT-base/results.xlsx"
phrase_loc = f"{base_loc}/wiki-phrase-list-reviewed.xlsx"

process_directory(
    read_dir=read_dir,
    save_dir=save_dir,
    # result_save_loc=result_save_loc,
    phrase_loc=phrase_loc)

Reading from: /Volumes/BCross/paraphrase examples slurm/Wiki-Test/ModernBERT-large/gpt2 results/raw
Saving processed files to: /Volumes/BCross/paraphrase examples slurm/Wiki-Test/ModernBERT-large/gpt2 results/filtered_inc_rank_v3
No result_save_loc provided → combined file will NOT be saved.

Found 616 files

Processing file 1/616: legolas2186_text_2 vs legolas2186_text_3.xlsx
Processing Document: legolas2186_text_2 vs legolas2186_text_3.xlsx
Writing file
Processing file 2/616: livelikemusic_text_5 vs livelikemusic_text_3.xlsx
Processing Document: livelikemusic_text_5 vs livelikemusic_text_3.xlsx
Writing file
Processing file 3/616: rjecina_text_1 vs rjecina_text_11.xlsx
Processing Document: rjecina_text_1 vs rjecina_text_11.xlsx
Writing file
Processing file 4/616: notpietru_text_1 vs obamafan70_text_5.xlsx
Processing Document: notpietru_text_1 vs obamafan70_text_5.xlsx
Writing file
Processing file 5/616: obamafan70_text_3 vs orangemarlin_text_4.xlsx
Processing Document: obamafan70_text

Unnamed: 0,index,sample_id,problem,corpus,known_author,unknown_author,unknown_doc_id,known_doc_id,target,num_phrases,...,known_log_prob,unknown_log_prob,llr_no_context,llr_known,llr_unknown,llr_unknown_vs_top_rank,unknown_ref_vs_best_rest,normalised_llr_no_context,normalised_llr_known,normalised_llr_unknown
0,0,1,Hodja_Nasreddin vs Hodja_Nasreddin,Wiki,Hodja_Nasreddin,Hodja_Nasreddin,hodja_nasreddin_text_3,hodja_nasreddin_text_1,True,12,...,-94.802674,-100.632861,19.590962,11.743081,10.633380,-9.775343,-4.916896,1.632580,0.978590,0.886115
1,1,2,Hodja_Nasreddin vs Hodja_Nasreddin,Wiki,Hodja_Nasreddin,Hodja_Nasreddin,hodja_nasreddin_text_3,hodja_nasreddin_text_10,True,11,...,-82.992827,-83.322568,22.546681,11.643380,9.595564,-10.908952,7.442016,2.049698,1.058489,0.872324
2,2,3,Hodja_Nasreddin vs Hodja_Nasreddin,Wiki,Hodja_Nasreddin,Hodja_Nasreddin,hodja_nasreddin_text_3,hodja_nasreddin_text_11,True,8,...,-49.086520,-71.664797,10.791892,5.733034,6.693348,-5.423525,4.634485,1.348987,0.716629,0.836669
3,3,4,Hodja_Nasreddin vs HonestopL,Wiki,Hodja_Nasreddin,HonestopL,honestopl_text_1,hodja_nasreddin_text_1,False,8,...,-66.596993,-71.588018,17.202432,9.351598,6.611619,-6.337841,-3.123308,2.150304,1.168950,0.826452
4,4,5,Hodja_Nasreddin vs HonestopL,Wiki,Hodja_Nasreddin,HonestopL,honestopl_text_1,hodja_nasreddin_text_10,False,7,...,-49.804151,-41.695350,13.543120,5.959475,4.498780,-2.327063,5.433823,1.934731,0.851354,0.642683
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
611,667,668,ZjarriRrethues vs 142.196.88.228,Wiki,ZjarriRrethues,142.196.88.228,142_196_88_228_text_2,zjarrirrethues_text_4,False,5,...,-28.818235,-59.272039,19.189928,5.354576,7.316360,-8.091894,24.915954,3.837986,1.070915,1.463272
612,668,669,ZjarriRrethues vs 142.196.88.228,Wiki,ZjarriRrethues,142.196.88.228,142_196_88_228_text_2,zjarrirrethues_text_5,False,6,...,-53.050993,-77.322432,34.128463,9.099270,10.200630,-14.152155,11.186672,5.688077,1.516545,1.700105
613,669,670,ZjarriRrethues vs ZjarriRrethues,Wiki,ZjarriRrethues,ZjarriRrethues,zjarrirrethues_text_2,zjarrirrethues_text_1,True,9,...,-85.837280,-69.836864,22.351969,9.807849,10.135595,-12.228219,9.965105,2.483552,1.089761,1.126177
614,670,671,ZjarriRrethues vs ZjarriRrethues,Wiki,ZjarriRrethues,ZjarriRrethues,zjarrirrethues_text_2,zjarrirrethues_text_4,True,13,...,-131.761348,-123.233849,38.555566,16.965157,12.172307,-14.319427,11.892420,2.965813,1.305012,0.936331
