### Credit goes to this author and notebook

https://www.kaggle.com/code/analyticaobscura/cafa-6-decoding-protein-mysteries 

- This notebook experiments with weighted-average ensembling of the models' predictions.
- A meta-ensemble (stacking) model is planned for a later version.

In [None]:
import pandas as pd
import numpy as np
from tqdm.auto import tqdm
import gc
from multiprocessing import Pool, cpu_count

############# Stacking Ensemble #############

def process_single_chunk(args):
    """Ensemble the scores of every model for one slice of prediction keys.

    Args:
        args: A 7-tuple ``(start, end, all_keys, file_paths, weights, method,
            num_models)`` (packed into one argument so the function can be
            used with ``Pool.imap``):

            * ``all_keys[start:end]`` are the ``"<protein>_<go_term>"`` keys
              handled by this call.
            * ``file_paths`` are header-less TSV files with the columns
              protein, GO term, score.
            * ``weights[i]`` is the weight of model ``i``.
            * ``method`` is one of ``'median'``, ``'weighted_average'``,
              ``'rank_average'`` or ``'all'``.
            * ``num_models`` is the number of models used in weighted sums.

    Returns:
        DataFrame with columns ``protein``, ``go_term`` and ``final_score``,
        one row per key, in the order of the key slice.

    Raises:
        ValueError: If ``method`` is not one of the supported names.

    Notes:
        * A key that a model does not predict counts as score 0 for that model.
        * Duplicate predictions of one key inside a model file are averaged.
        * Percentile ranks are computed over the rows of this slice only, so
          rank-based scores depend on how the keys were split into slices.
    """
    start, end, all_keys, file_paths, weights, method, num_models = args
    # Fail before any file is read instead of with a KeyError at the very end.
    if method not in _ENSEMBLE_METHODS:
        raise ValueError(
            f"Unknown ensemble method {method!r}; expected one of {_ENSEMBLE_METHODS}"
        )

    key_chunk = all_keys[start:end]
    if len(key_chunk) == 0:
        # Nothing to score: return a correctly shaped empty frame.
        return pd.DataFrame({
            'protein': pd.Series(dtype=object),
            'go_term': pd.Series(dtype=object),
            'final_score': pd.Series(dtype=float),
        })

    # Converted once so ``isin`` does not rebuild it for every file chunk.
    wanted_keys = np.asarray(key_chunk, dtype=object)
    result = pd.DataFrame({'key': key_chunk})

    score_cols = []
    for idx, path in enumerate(file_paths):
        score_col = f'score_{idx}'
        score_cols.append(score_col)
        model_df = _read_model_scores(path, wanted_keys, score_col)
        if model_df is None:
            # The model produced no rows at all: treat it as all-zero so the
            # weighted sums below still find its column.
            result[score_col] = 0.0
        else:
            result = result.merge(model_df, on='key', how='left')

    if score_cols:
        result[score_cols] = result[score_cols].fillna(0)

    def weighted_sum(prefix):
        return sum(result[f'{prefix}_{i}'] * weights[i] for i in range(num_models))

    if method in ('rank_average', 'all'):
        for i in range(num_models):
            result[f'rank_{i}'] = result[f'score_{i}'].rank(pct=True)

    if method == 'median':
        result['final_score'] = result[score_cols].median(axis=1)
    elif method == 'weighted_average':
        result['final_score'] = weighted_sum('score')
    elif method == 'rank_average':
        result['final_score'] = weighted_sum('rank')
    else:  # 'all': fixed blend of the three strategies
        result['median_score'] = result[score_cols].median(axis=1)
        result['weighted_avg'] = weighted_sum('score')
        result['rank_avg'] = weighted_sum('rank')
        result['final_score'] = (result['median_score'] * 0.25 +
                                 result['weighted_avg'] * 0.40 +
                                 result['rank_avg'] * 0.35)

    # Split on the LAST underscore so protein ids containing '_' stay intact.
    key_parts = result['key'].str.rsplit('_', n=1, expand=True)
    result['protein'] = key_parts[0]
    result['go_term'] = key_parts[1]
    return result[['protein', 'go_term', 'final_score']]


# Ensemble strategies understood by process_single_chunk.
_ENSEMBLE_METHODS = ('median', 'weighted_average', 'rank_average', 'all')


def _read_model_scores(path, wanted_keys, score_col):
    """Return one model's mean score per key, restricted to ``wanted_keys``.

    The file is streamed in chunks of one million rows. Returns ``None`` when
    the reader yields no chunk at all; otherwise a DataFrame with the columns
    ``key`` and ``score_col`` (possibly empty).
    """
    pieces = []
    for chunk in pd.read_csv(path, sep='\t', header=None,
                             names=['protein', 'go_term', 'score'],
                             dtype={'protein': str, 'go_term': str, 'score': float},
                             chunksize=1_000_000):
        keys = chunk['protein'] + '_' + chunk['go_term']
        hit = keys.isin(wanted_keys)
        pieces.append(pd.DataFrame({'key': keys[hit], score_col: chunk.loc[hit, 'score']}))
    if not pieces:
        return None
    combined = pd.concat(pieces, ignore_index=True)
    return combined.groupby('key', as_index=False).mean()

def stacking_ensemble_fast(file_paths, weights=None, method='all', output_path='submission.tsv',
                           chunksize=5_000_000, n_jobs=-1):
    """Ensemble several prediction files and write the blended scores to disk.

    Each input is a header-less TSV with the columns protein, GO term, score.
    The union of all ``(protein, go_term)`` pairs is collected, split into
    slices of ``chunksize`` keys, and every slice is scored by
    ``process_single_chunk`` (optionally in a process pool).

    Args:
        file_paths: Paths of the model prediction files.
        weights: One weight per file; normalised to sum to 1. ``None`` means
            equal weights.
        method: Ensemble strategy forwarded to ``process_single_chunk``.
        output_path: Destination TSV (written without header or index).
        chunksize: Number of keys scored per task.
        n_jobs: Worker processes; ``-1`` uses all CPUs but one. Values of 1 or
            less (other than ``-1``) run serially.

    Returns:
        DataFrame with columns ``protein``, ``go_term``, ``final_score``.

    Raises:
        ValueError: If ``file_paths`` is empty, ``chunksize`` is not positive,
            or ``weights`` does not match ``file_paths`` in length or sums
            to zero.
    """
    num_models = len(file_paths)
    if num_models == 0:
        raise ValueError("file_paths must contain at least one prediction file")
    if chunksize <= 0:
        raise ValueError(f"chunksize must be positive, got {chunksize}")

    if weights is None:
        weights = [1.0 / num_models] * num_models
    else:
        weights = np.asarray(weights, dtype=float)
        if len(weights) != num_models:
            raise ValueError(
                f"Expected {num_models} weights (one per file), got {len(weights)}"
            )
        weight_total = weights.sum()
        if weight_total == 0:
            raise ValueError("weights must not sum to zero")
        weights = weights / weight_total
    if n_jobs == -1:
        n_jobs = max(1, cpu_count() - 1)
    print(f"Models: {len(file_paths)} | Weights: {weights} | Method: {method}")
    print(f"Parallel jobs: {n_jobs} | Chunk size: {chunksize:,}")

    print("\nScanning files...")
    all_keys = set()
    for path in tqdm(file_paths, desc="Files"):
        # Only the two id columns are read here; scores are loaded per slice later.
        for chunk in pd.read_csv(path, sep='\t', header=None,
                                 names=['protein', 'go_term', 'score'],
                                 dtype={'protein': str, 'go_term': str},
                                 usecols=[0, 1],
                                 chunksize=5_000_000):
            chunk = chunk.dropna()
            keys = chunk['protein'] + '_' + chunk['go_term']
            all_keys.update(keys.values)
            # Free the chunk right away to keep peak memory low on big files.
            del chunk, keys
            gc.collect()

    all_keys = sorted(all_keys)
    print(f"Total predictions: {len(all_keys):,}")

    print("\nProcessing chunks in parallel...")
    # Each task carries only its own key slice, so the full key list is not
    # pickled once per task when tasks are sent to worker processes.
    chunk_args = []
    for start in range(0, len(all_keys), chunksize):
        key_slice = all_keys[start:start + chunksize]
        chunk_args.append((0, len(key_slice), key_slice, file_paths, weights, method, num_models))

    if not chunk_args:
        # No usable prediction rows in any file: emit an empty result.
        final_df = pd.DataFrame(columns=['protein', 'go_term', 'final_score'])
    else:
        workers = min(n_jobs, len(chunk_args))  # never spawn idle workers
        if workers > 1:
            with Pool(workers) as pool:
                results = list(tqdm(pool.imap(process_single_chunk, chunk_args),
                                    total=len(chunk_args), desc="Chunks"))
        else:
            results = [process_single_chunk(task) for task in tqdm(chunk_args, desc="Chunks")]
        print("\nCombining results...")
        final_df = pd.concat(results, ignore_index=True)
        del results
        gc.collect()

    print(f"\nSaving to {output_path}...")
    final_df.to_csv(output_path, sep='\t', index=False, header=False)
    print(f"✓ Done! {len(final_df):,} predictions saved\n")
    return final_df

############# Per-GO-Term Threshold Optimization #############

def find_go_term_thresholds(df, label_df, go_terms, thresholds=np.arange(0.01, 1.0, 0.01),
                            default_threshold=0.5):
    """Pick, for every GO term, the score threshold that maximises F1.

    Args:
        df: Predictions with columns ``protein``, ``go_term``, ``final_score``.
        label_df: Ground truth with columns ``protein``, ``go_term``,
            ``label``; pairs missing from it count as label 0.
        go_terms: Terms to compute a threshold for.
        thresholds: Candidate thresholds, tried in order; on ties the earliest
            candidate wins.
        default_threshold: Threshold assigned to a term for which no candidate
            reaches a positive F1 (for example a term without any positive
            label). The default of 0.5 matches the fallback used by
            ``binarize_with_go_thresholds``; a value of 0 would mark every
            prediction of such a term as positive.

    Returns:
        Dict mapping every term in ``go_terms`` to its chosen threshold.
    """
    merged = df.merge(label_df, on=['protein', 'go_term'], how='left').fillna({'label': 0})
    labels = merged['label'].astype(int).to_numpy()
    scores = merged['final_score'].to_numpy()
    # Row positions per term, computed in one pass instead of one full-frame
    # boolean scan per term.
    rows_by_term = merged.groupby('go_term').indices

    go_thresholds = {}
    for term in tqdm(go_terms, desc="GO terms"):
        rows = rows_by_term.get(term)
        if rows is None:
            go_thresholds[term] = default_threshold
            continue

        y_true = labels[rows]
        y_score = scores[rows]
        is_pos = y_true == 1
        is_neg = y_true == 0
        n_pos = is_pos.sum()
        if n_pos == 0:
            # Without positives every candidate has F1 == 0; skip the sweep.
            go_thresholds[term] = default_threshold
            continue

        best_f1, best_t = 0, default_threshold
        for t in thresholds:
            predicted = y_score >= t
            tp = (predicted & is_pos).sum()
            fp = (predicted & is_neg).sum()
            fn = n_pos - tp
            precision = tp / (tp + fp) if tp + fp > 0 else 0
            recall = tp / (tp + fn) if tp + fn > 0 else 0
            f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0
            if f1 > best_f1:
                best_f1, best_t = f1, t
        go_thresholds[term] = best_t
    return go_thresholds

def binarize_with_go_thresholds(df, go_thresholds, default_threshold=0.5):
    """Add a 0/1 ``binarized`` column using a per-GO-term score threshold.

    Args:
        df: Predictions with columns ``go_term`` and ``final_score``. The frame
            is modified in place.
        go_thresholds: Dict mapping GO term to its threshold.
        default_threshold: Threshold for terms missing from ``go_thresholds``.

    Returns:
        The same ``df`` object, with ``binarized`` set to 1 where
        ``final_score >= threshold`` and 0 otherwise.
    """
    # Look the threshold up once per row on a single column, instead of
    # building a Series for every row with DataFrame.apply(axis=1); this also
    # works on an empty frame.
    cutoffs = df['go_term'].map(
        lambda term: go_thresholds.get(term, default_threshold)
    ).astype(float)
    df['binarized'] = (df['final_score'] >= cutoffs).astype(int)
    return df

############# GO Hierarchy Enforcement #############

def enforce_go_hierarchy(df, go_parents):
    """Propagate positive GO predictions to their ancestors, per protein.

    For each protein, every term with ``binarized == 1`` also switches on all
    of its ancestors (followed transitively through ``go_parents``). Only rows
    that already exist for that protein are updated; no new rows are added.

    Args:
        df: Frame with columns ``protein``, ``go_term`` and ``binarized``.
        go_parents: Dict mapping a GO term to a list of its direct parents.
            Terms missing from the dict are treated as having no parents.

    Returns:
        New DataFrame with columns ``protein``, ``go_term``, ``binarized``,
        grouped by protein in ``groupby`` order. Other columns of ``df`` are
        not carried over.
    """
    ancestors_of = {}

    def ancestors(term):
        """Return all transitive parents of ``term``, cached across proteins."""
        found = ancestors_of.get(term)
        if found is None:
            seen = set()
            stack = [term]
            while stack:
                for parent in go_parents.get(stack.pop(), []):
                    # The ``seen`` check also stops cycles in the parent map.
                    if parent not in seen:
                        seen.add(parent)
                        stack.append(parent)
            found = ancestors_of[term] = frozenset(seen)
        return found

    proteins, terms, flags = [], [], []
    for prot, group in tqdm(df.groupby('protein'), desc="GO hierarchy"):
        positive = set(group.loc[group['binarized'] == 1, 'go_term'])
        switched_on = set(positive)
        for term in positive:
            switched_on |= ancestors(term)

        group_terms = group['go_term'].tolist()
        proteins.extend([prot] * len(group_terms))
        terms.extend(group_terms)
        flags.extend(1 if term in switched_on else 0 for term in group_terms)

    return pd.DataFrame({'protein': proteins, 'go_term': terms, 'binarized': flags},
                        columns=['protein', 'go_term', 'binarized'])


In [None]:

############# USAGE EXAMPLE #############

import json

# The ensemble step starts worker processes. Keeping the driver code behind a
# __main__ guard stops those workers from re-running it when the module is
# imported by a spawned process (the default start method on Windows/macOS).
if __name__ == "__main__":
    # Ensemble step (adjust as needed)
    file_paths = [
        '/kaggle/input/cafa-6-t5-embeddings-with-ensemble/submission.tsv',
        '/kaggle/input/cafa-6-predictions/submission.tsv'
    ]
    # One weight per file; stacking_ensemble_fast rescales them to sum to 1.
    weights = [0.35, 0.30]
    result = stacking_ensemble_fast(
        file_paths=file_paths,
        weights=weights,
        method='all',
        output_path='submission.tsv',
        chunksize=5_000_000,
        n_jobs=-1
    )

    # Per-GO-term thresholding (update oof_labels, go_terms accordingly)
    # oof_labels: DataFrame ['protein','go_term','label'] from validation
    # go_terms: List/array of observed go terms
    go_terms = result['go_term'].unique()
    # Placeholder path: should have columns 'protein','go_term','label'
    oof_labels = pd.read_csv('path_to_oof_labels.tsv', sep='\t')
    go_thresholds = find_go_term_thresholds(result, oof_labels, go_terms)
    result = binarize_with_go_thresholds(result, go_thresholds)

    # Enforce GO hierarchy (update go_parents accordingly)
    # go_parents: dictionary {child: [parent1, parent2, ...]}
    with open('path_to_go_parents.json', 'r', encoding='utf-8') as f:
        go_parents = json.load(f)
    result = enforce_go_hierarchy(result, go_parents)

    # Save final submission; this overwrites the score file written by the
    # ensemble step above with the binarized predictions.
    result[['protein','go_term','binarized']].to_csv('submission.tsv', sep='\t', index=False, header=False)
