In [73]:
import gspread
from gspread_dataframe import get_as_dataframe
import json
import pandas as pd

In [74]:
def analyze_retention_by_tier(
    target_worksheet_title: str,
    sheet_url: str = "https://docs.google.com/spreadsheets/d/1g52ckNkohDlzH0jUjCksR3EXxxJXAe7_XAhf1TGBjZo/edit?usp=sharing",
    resource_json_path: str = "./language_source_classification.json",
    service_account_path: str = "/scratch/project_462000941/members/zihao/rm4mt/rm4mt_eval/rm4mt-463314-3ce1280ee29c.json"
) -> pd.DataFrame:
    """Aggregate pre/post counts of one worksheet by (source tier, target tier).

    The worksheet is read from a Google Sheet, every row is assigned a
    ``tier_pair`` such as ``"high-low"`` derived from its ``lang_pair``
    column, and ``pre_count`` / ``post_count`` are summed per tier pair.
    The retention rate of a tier pair is ``post_count_sum / pre_count_sum``
    (``0`` when ``pre_count_sum`` is not positive).

    Note: the defaults for ``sheet_url`` and ``service_account_path`` point at
    one specific spreadsheet and an absolute credentials path; pass explicit
    values when running in a different environment.

    Args:
        target_worksheet_title: Title of the worksheet to analyze. The
            worksheet must provide the columns ``lang_pair``, ``pre_count``
            and ``post_count``.
        sheet_url: URL of the Google Sheet.
        resource_json_path: Path to the language resource classification
            JSON. The code expects a sequence of up to six mappings keyed by
            language code, labelled in order low, low, mid, mid, high, high.
        service_account_path: Path to Google service account credentials.

    Returns:
        DataFrame with the columns ``tier_pair``, ``pre_count_sum``,
        ``post_count_sum`` and ``retention_rate``, sorted by ascending
        ``retention_rate``.

    Raises:
        KeyError: If no worksheet is titled ``target_worksheet_title`` or a
            required column is missing.
        ValueError: If a ``lang_pair`` value is missing or is not of the
            form ``"<src>-<tgt>"``.
    """
    gc = gspread.service_account(filename=service_account_path)
    spreadsheet = gc.open_by_url(sheet_url)

    # Fetch only the requested worksheet: downloading every tab of the
    # spreadsheet costs one API read per tab and is never used afterwards.
    try:
        worksheet = spreadsheet.worksheet(target_worksheet_title)
    except gspread.exceptions.WorksheetNotFound as exc:
        # Keep the exception type callers saw from the former dict lookup.
        raise KeyError(target_worksheet_title) from exc

    # Rows that are empty in every column carry no data and are discarded.
    df = get_as_dataframe(worksheet).dropna(how='all')
    # The "ALL" row is excluded so it is not counted on top of the per-pair rows.
    df = df[df['lang_pair'] != "ALL"].copy()

    # Load language tiers: position in the JSON sequence decides the label.
    with open(resource_json_path, "r", encoding="utf-8") as f:
        resource_tiers = json.load(f)
    resource_labels = ["low", "low", "mid", "mid", "high", "high"]
    lang2tier = {
        code: resource_labels[idx]
        for idx, tier in enumerate(resource_tiers)
        for code in tier
    }

    # Validate up front so a bad cell yields a readable error instead of an
    # unpacking failure somewhere inside a row-wise apply.
    parts = df['lang_pair'].astype(str).str.split('-')
    malformed = parts.str.len() != 2
    if malformed.any():
        examples = df.loc[malformed, 'lang_pair'].drop_duplicates().head(5).tolist()
        raise ValueError(
            f"lang_pair values must look like '<src>-<tgt>'; "
            f"{int(malformed.sum())} row(s) do not, e.g. {examples}"
        )

    def tier_of(codes: pd.Series) -> pd.Series:
        # Codes absent from the classification are labelled "unknown".
        return pd.Series(
            [lang2tier.get(code, "unknown") for code in codes],
            index=codes.index,
            dtype=object,
        )

    df['src_lang'] = parts.str[0]
    df['tgt_lang'] = parts.str[1]
    df['src_tier'] = tier_of(df['src_lang'])
    df['tgt_tier'] = tier_of(df['tgt_lang'])
    df['tier_pair'] = df['src_tier'] + '-' + df['tgt_tier']

    # Calculate retention stats
    stats = (
        df.groupby('tier_pair')
        .agg(
            pre_count_sum=('pre_count', 'sum'),
            post_count_sum=('post_count', 'sum')
        )
        .reset_index()
    )

    # Column-wise division; non-positive denominators are masked to NaN first
    # and then reported as a retention rate of 0. This also works when
    # ``stats`` has no rows.
    pre_total = stats['pre_count_sum']
    post_total = stats['post_count_sum']
    stats['retention_rate'] = (
        (post_total / pre_total.where(pre_total > 0)).fillna(0).astype(float)
    )

    return stats.sort_values('retention_rate')

In [75]:
# Tier-pair retention for the worksheet named below, printed in ascending
# order of retention_rate. The title suggests GlotLID-based re-identification
# at a 0.9 threshold — confirm against the sheet.
stats = analyze_retention_by_tier("mala-opus-dedup-2410-ReLID-by-GlotLID-Threshold-0_9")
print(stats)

   tier_pair pre_count_sum post_count_sum  retention_rate
7    mid-low     432905064      180328877        0.416555
5    low-mid     324428479      150201385        0.462972
8    mid-mid    3823600614     1811186350        0.473686
3   low-high     775092270      373482089        0.481855
4    low-low     224224154      115882132        0.516814
2   high-mid   10952463269     5897156431        0.538432
1   high-low    2213934690     1241822666        0.560912
6   mid-high   11244512235     7308911293        0.649998
0  high-high   32681289933    23409707510        0.716303


In [76]:
# Same analysis for the worksheet named below (title suggests ConLID-based
# re-identification at a 0.9 threshold — confirm against the sheet).
# Note: this rebinds the global `stats` from the previous cell.
stats = analyze_retention_by_tier("mala-opus-dedup-2410-ReLID-by-ConLID-Threshold-0_9")
print(stats)

   tier_pair pre_count_sum post_count_sum  retention_rate
7    mid-low     432905064      161737708        0.373610
3   low-high     775092270      307678358        0.396957
8    mid-mid    3823600614     1624400984        0.424835
5    low-mid     324428479      140446087        0.432903
2   high-mid   10952463269     4799343593        0.438198
1   high-low    2213934690      977745025        0.441632
4    low-low     224224154      114697284        0.511530
6   mid-high   11244512235     5812298130        0.516901
0  high-high   32681289933    16927493296        0.517957
