# CAFA6 GOA + ProtT5 Ensemble

A simple weighted ensemble of GOA and ProtT5 predictions, followed by GO-hierarchy propagation, power scaling, and top-K filtering.

Required dataset: `ymuroya47/cafa6-goa-predictions`

In [None]:
import os
from collections import defaultdict
from tqdm.auto import tqdm
import heapq

# ---------------------------------------------------------------------------
# Input / output locations
# ---------------------------------------------------------------------------
COMPETITION_DATA = '/kaggle/input/cafa-6-protein-function-prediction'
GOA_DATA = '/kaggle/input/cafa6-goa-predictions'

GO_OBO_PATH = COMPETITION_DATA + '/Train/go-basic.obo'        # GO ontology file
GOA_PATH = GOA_DATA + '/goa_submission.tsv'                   # GOA predictions
PROTT5_PATH = GOA_DATA + '/prott5_interpro_predictions.tsv'   # ProtT5 predictions
OUTPUT_PATH = 'submission.tsv'

# ---------------------------------------------------------------------------
# Ensemble weights: applied when both sources give a (protein, term) pair a
# positive score.
# ---------------------------------------------------------------------------
WEIGHT_GOA, WEIGHT_PROTT5 = 0.55, 0.45

# Root terms of the three GO namespaces; pinned to 1.0 during post-processing.
GO_ROOTS = {
    "GO:0003674",  # molecular_function
    "GO:0008150",  # biological_process
    "GO:0005575",  # cellular_component
}

# ---------------------------------------------------------------------------
# Post-processing knobs
# ---------------------------------------------------------------------------
TOP_K = 250          # max terms written per protein (kept high for a better score)
MIN_SCORE = 0.001    # predictions scoring below this are not written
POWER_SCALE = 0.8    # exponent applied to max-normalised scores
MAX_SCORE = 0.93     # value the top non-root score is rescaled to

print('Configuration loaded')

In [None]:
def load_test_ids(test_fasta=None):
    """Collect the protein IDs named in the test-superset FASTA file.

    Each header line (starting with ``>``) contributes its first
    whitespace-delimited token. If that token contains ``|`` (for example
    ``sp|P12345|NAME``), the second ``|``-separated field is used instead,
    mirroring how prediction files are normalised when loaded.

    Args:
        test_fasta: Path to the FASTA file. Defaults to
            ``{COMPETITION_DATA}/Test/testsuperset.fasta``.

    Returns:
        set[str]: The unique protein IDs found.
    """
    if test_fasta is None:
        test_fasta = f'{COMPETITION_DATA}/Test/testsuperset.fasta'
    ids = set()
    with open(test_fasta, 'r') as f:
        for line in f:
            if not line.startswith('>'):
                continue
            tokens = line[1:].split()
            if not tokens:
                # A bare '>' header has no ID; skip it instead of raising IndexError.
                continue
            header = tokens[0]
            if '|' in header:
                header = header.split('|')[1]
            ids.add(header)
    print(f'Loaded {len(ids):,} test proteins')
    return ids

# Set of test-superset protein IDs; passed as `allowed_proteins` when loading both prediction files.
test_ids = load_test_ids()

In [None]:
def parse_go_ontology(obo_path=None):
    """Parse an OBO ontology and return the transitive ancestors of each term.

    Parent edges come from ``is_a:`` and ``relationship: part_of`` tags inside
    ``[Term]`` stanzas only. Other stanza types (e.g. ``[Typedef]``, whose
    ``id:``/``is_a:`` lines name relations such as ``regulates`` rather than
    GO terms) are ignored so they do not leak into the returned map.

    NOTE(review): ``alt_id`` aliases are not read, so a prediction keyed by a
    secondary GO ID gets no ancestors from this map.

    Args:
        obo_path: Path to the OBO file. Defaults to ``GO_OBO_PATH``.

    Returns:
        dict[str, set[str]]: term ID -> every ancestor reachable through the
        edges above (the term itself excluded). Keys are every term with at
        least one parent plus every ancestor visited while building the map;
        terms without parents map to an empty set.
    """
    if obo_path is None:
        obo_path = GO_OBO_PATH
    print('Parsing GO ontology...')
    term_parents = defaultdict(set)

    with open(obo_path, 'r') as f:
        current_id = None
        in_term = False
        for line in f:
            line = line.strip()
            if line.startswith('['):
                # New stanza header: only [Term] stanzas describe GO terms.
                in_term = line == '[Term]'
                current_id = None
                continue
            if not in_term:
                continue
            if line.startswith('id: '):
                current_id = line.split('id: ')[1].strip()
            elif line.startswith('is_a: ') and current_id:
                # "is_a: GO:xxxxxxx ! name" -> second token is the parent ID.
                parent = line.split()[1].strip()
                term_parents[current_id].add(parent)
            elif line.startswith('relationship: part_of ') and current_id:
                # "relationship: part_of GO:xxxxxxx ! name" -> third token.
                parts = line.split()
                if len(parts) >= 3:
                    term_parents[current_id].add(parts[2].strip())

    ancestors_map = {}

    def get_ancestors(term):
        # Memoised depth-first walk over parent edges. Assumes the parent
        # graph is acyclic; a cycle would recurse until RecursionError.
        if term in ancestors_map:
            return ancestors_map[term]
        parents = term_parents.get(term, set())
        all_anc = set(parents)
        for p in parents:
            all_anc |= get_ancestors(p)
        ancestors_map[term] = all_anc
        return all_anc

    for term in tqdm(list(term_parents.keys()), desc='Building ancestors'):
        get_ancestors(term)

    print(f'Cached {len(ancestors_map):,} GO terms')
    return ancestors_map

# term -> set of all is_a / part_of ancestors; consumed by positive_propagation below.
ancestors_map = parse_go_ontology()

In [None]:
def load_predictions(filepath, allowed_proteins, desc='Loading'):
    """Read a three-column TSV of (protein, GO term, score) predictions.

    Rows are skipped when they have fewer than three tab-separated fields,
    when the protein is not in ``allowed_proteins``, or when the score cannot
    be parsed by ``float`` (e.g. a header row). A protein field containing
    ``|`` is reduced to its second ``|``-separated field. When a
    (protein, term) pair appears more than once, the highest score wins.

    Returns:
        dict[str, dict[str, float]]: protein -> {GO term -> score}.
    """
    by_protein = defaultdict(dict)
    with open(filepath, 'r') as handle:
        for raw in tqdm(handle, desc=desc):
            fields = raw.strip().split('\t')
            if len(fields) < 3:
                continue

            pid = fields[0].strip()
            if '|' in pid:
                pid = pid.split('|')[1]
            if pid not in allowed_proteins:
                continue

            term = fields[1].strip()
            try:
                value = float(fields[2])
            except ValueError:
                continue

            # Only touch by_protein once the row is known to be valid, so
            # proteins with no usable rows never get an (empty) entry.
            term_scores = by_protein[pid]
            best = term_scores.get(term)
            if best is None or value > best:
                term_scores[term] = value
    return dict(by_protein)

# Load both prediction sources, keeping only proteins from the test superset.
print('Loading GOA...')
goa_preds = load_predictions(GOA_PATH, allowed_proteins=test_ids, desc='GOA')
print(f'GOA proteins: {len(goa_preds):,}')

print('Loading ProtT5...')
prott5_preds = load_predictions(PROTT5_PATH, allowed_proteins=test_ids, desc='ProtT5')
print(f'ProtT5 proteins: {len(prott5_preds):,}')

In [None]:
def merge_predictions(goa_preds, prott5_preds):
    """Blend two ``{protein: {term: score}}`` maps into a single one.

    For every protein present in either input, each term in the union of its
    two term sets is scored as follows (a missing score counts as 0.0):

    * both scores > 0: ``WEIGHT_GOA * goa + WEIGHT_PROTT5 * prott5``;
    * otherwise: the GOA score if it is > 0, else the ProtT5 score.

    A term seen by only one source therefore keeps that source's full score,
    while a term seen by both gets the weighted blend.

    Returns:
        dict[str, dict[str, float]]: protein -> {term -> merged score}.
    """
    print('Merging predictions...')
    merged = {}
    proteins = set(goa_preds) | set(prott5_preds)
    for pid in tqdm(proteins, desc='Merging'):
        goa_terms = goa_preds.get(pid, {})
        t5_terms = prott5_preds.get(pid, {})
        if not (goa_terms or t5_terms):
            continue

        blended = {}
        for term in set(goa_terms) | set(t5_terms):
            g = goa_terms.get(term, 0.0)
            p = t5_terms.get(term, 0.0)
            if g > 0.0 and p > 0.0:
                blended[term] = WEIGHT_GOA * g + WEIGHT_PROTT5 * p
            elif g > 0.0:
                blended[term] = g
            else:
                blended[term] = p
        merged[pid] = blended
    return merged

merged = merge_predictions(goa_preds=goa_preds, prott5_preds=prott5_preds)
print(f'Merged proteins: {len(merged):,}')

# The per-source maps are not used after merging; drop them to release memory.
del goa_preds
del prott5_preds

In [None]:
def positive_propagation(base_scores, ancestors_map, roots=None):
    """Push each term's score up to its ancestors using the max rule.

    Every non-root term in ``base_scores`` offers its score to each ancestor
    listed for it in ``ancestors_map``; an ancestor keeps the largest score
    offered (or its own, if higher). Afterwards every root term is set to
    1.0, whether or not it appeared in the input.

    Args:
        base_scores: {term -> score} for one protein. Not modified.
        ancestors_map: {term -> iterable of ancestor terms}; terms missing
            from it are treated as having no ancestors.
        roots: Terms that are skipped as propagation sources and pinned to
            1.0. Defaults to the module-level ``GO_ROOTS``.

    Returns:
        dict: A new {term -> score} mapping.
    """
    if roots is None:
        roots = GO_ROOTS
    upd = dict(base_scores)
    for term, score in base_scores.items():
        if term in roots:
            continue
        for anc in ancestors_map.get(term, ()):
            prev = upd.get(anc)
            if prev is None or score > prev:
                upd[anc] = score
    for root in roots:
        upd[root] = 1.0
    return upd

def power_scaling(scores, power=0.80, max_score=0.93, roots=None):
    """Rescale non-root scores so the highest one maps to ``max_score``.

    With ``mx`` the largest non-root score, each non-root score ``s`` becomes
    ``min(1.0, (max(s, 0) / mx) ** power * max_score)``. Scaling is skipped
    (non-root scores returned unchanged) when there are no non-root terms,
    when ``mx <= 0``, or when ``mx >= max_score``. In every case all root
    terms are set to 1.0.

    Negative scores are clamped to 0.0 before exponentiation. In Python a
    negative float raised to a fractional power yields a complex number, and
    the subsequent ``min`` comparison would raise ``TypeError``.

    Args:
        scores: {term -> score}. Not modified.
        power: Exponent; values below 1 lift low ratios relative to high ones.
        max_score: Value the highest non-root score is mapped to.
        roots: Terms excluded from scaling and pinned to 1.0. Defaults to the
            module-level ``GO_ROOTS``.

    Returns:
        dict: A new {term -> score} mapping.
    """
    if roots is None:
        roots = GO_ROOTS
    out = dict(scores)

    def _pin_roots():
        for r in roots:
            out[r] = 1.0
        return out

    non_root = [s for t, s in out.items() if t not in roots]
    if not non_root:
        return _pin_roots()
    mx = max(non_root)
    if mx <= 0.0 or mx >= max_score:
        # Nothing to normalise against, or already at/above the target cap.
        return _pin_roots()

    inv = 1.0 / mx
    for t in list(out.keys()):
        if t in roots:
            continue
        # Clamp negatives: (-x) ** fractional_power is complex in Python.
        val = (max(0.0, out[t]) * inv) ** power * max_score
        out[t] = min(1.0, val)
    return _pin_roots()

def topk_filter(scores, k):
    """Return the ``k`` highest-scoring ``(term, score)`` pairs, best first.

    Ties keep the mapping's iteration order, because ``heapq.nlargest`` is
    documented as equivalent to ``sorted(..., reverse=True)[:k]``. A ``k`` of
    zero or less yields an empty list.
    """
    def _score_of(item):
        return item[1]

    return heapq.nlargest(k, scores.items(), key=_score_of)

In [None]:
print('Processing proteins...')
output_lines = []

for pid, base in tqdm(merged.items(), desc='Processing'):
    if not base:
        continue
    # Propagate up the GO DAG, rescale, then keep the TOP_K best terms.
    propagated = positive_propagation(base, ancestors_map)
    rescaled = power_scaling(propagated, POWER_SCALE, MAX_SCORE)
    # MIN_SCORE is tested on the unrounded score; output is rounded to 4 dp.
    output_lines.extend(
        f'{pid}\t{go_term}\t{score:.4f}'
        for go_term, score in topk_filter(rescaled, TOP_K)
        if score >= MIN_SCORE
    )

print(f'Total predictions: {len(output_lines):,}')

In [None]:
print(f'Saving to {OUTPUT_PATH}...')
# Stream the rows rather than '\n'.join-ing them: with up to TOP_K rows per
# protein the joined string would be a second full in-memory copy of the
# submission. Every row is newline-terminated, so the file ends with '\n'.
with open(OUTPUT_PATH, 'w') as f:
    f.writelines(f'{line}\n' for line in output_lines)

import os  # already imported in the first cell; repeated here harmlessly
size_mb = os.path.getsize(OUTPUT_PATH) / 1024 / 1024
print(f'\nSubmission saved: {OUTPUT_PATH}')
print(f'  Total predictions: {len(output_lines):,}')
print(f'  File size: {size_mb:.1f} MB')