In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file found anywhere under /kaggle/input.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import os
from collections import defaultdict

# ---------- GO Ontology ----------
class GOOntology:
    """Parent and ancestor lookup for ontology terms parsed from an OBO file.

    Only ``is_a`` and ``relationship: part_of`` lines are treated as
    child -> parent edges. Stanzas other than ``[Term]`` (for example
    ``[Typedef]``) are skipped so relation names never show up as terms.
    """

    # Term IDs treated as the roots of the ontology.
    ROOTS = {'GO:0003674', 'GO:0008150', 'GO:0005575'}

    def __init__(self, obo):
        """Parse the OBO file at path *obo* and build the parent map.

        Args:
            obo: Path to an OBO-format ontology file.
        """
        # term id -> set of direct parent ids
        self.parents = defaultdict(set)
        # memo: term id -> set of all (transitive) ancestor ids
        self.anc = {}
        with open(obo, encoding='utf-8') as f:
            tid = None
            # Lines before the first stanza header are accepted as-is so
            # header-less fragments keep working; a non-[Term] header
            # switches parsing off until the next [Term].
            skip_stanza = False
            for l in f:
                l = l.strip()
                if l.startswith('[') and l.endswith(']'):
                    skip_stanza = l != '[Term]'
                    tid = None  # never attach edges to the previous stanza's id
                elif skip_stanza:
                    continue
                elif l.startswith('id: '):
                    tid = l[4:]
                elif tid and l.startswith('is_a: '):
                    # "is_a: <parent> ! name" -> parent id is token 1
                    self.parents[tid].add(l.split()[1])
                elif tid and l.startswith('relationship: part_of '):
                    # "relationship: part_of <parent> ! name" -> the parent
                    # id is token 2; token 1 is the relation name itself.
                    self.parents[tid].add(l.split()[2])

    def ancestors(self, t):
        """Return the set of all transitive ancestors of term *t*.

        The term itself is not included. Results are memoized in
        ``self.anc`` and the cached set object is returned directly, so
        callers must not mutate it. Unknown terms yield an empty set.
        The recursion assumes the parent graph is acyclic.
        """
        if t in self.anc:
            return self.anc[t]
        # .get() avoids inserting empty entries into the defaultdict.
        s = set(self.parents.get(t, ()))
        for p in list(s):
            s |= self.ancestors(p)
        self.anc[t] = s
        return s

# ---------- Load Predictions ----------
def load_preds(fp):
    """Load a tab-separated prediction file into nested dicts.

    Each non-blank line must hold exactly three tab-separated fields:
    protein id, term id and a numeric score. When a (protein, term) pair
    occurs more than once, the highest score is kept.

    Args:
        fp: Path to the prediction file.

    Returns:
        A ``defaultdict(dict)`` mapping protein id -> {term id: score}.

    Raises:
        ValueError: If a non-blank line does not have three fields or its
            score cannot be parsed as a float.
    """
    d = defaultdict(dict)
    with open(fp) as f:
        for lineno, l in enumerate(f, 1):
            l = l.rstrip('\r\n')
            # Blank lines (e.g. a trailing empty line) carry no prediction.
            if not l.strip():
                continue
            fields = l.split('\t')
            if len(fields) != 3:
                raise ValueError(
                    f"{fp}:{lineno}: expected 3 tab-separated fields, "
                    f"got {len(fields)}"
                )
            p, t, s = fields
            s = float(s)
            if t not in d[p] or s > d[p][t]:
                d[p][t] = s
    return d

# ---------- Ensemble ----------
def ensemble(goa, prott5, w1=0.55):
    """Blend two prediction sets with a weighted sum per (protein, term).

    Args:
        goa: Mapping protein -> {term: score}; weighted by ``w1``.
        prott5: Mapping protein -> {term: score}; weighted by ``1 - w1``.
        w1: Weight applied to the ``goa`` scores.

    Returns:
        A dict with one entry per protein seen in either input, mapping to
        {term: blended score}. A term missing from one input counts as 0
        there; terms whose blended score is not positive are dropped (the
        protein entry itself is kept even if it ends up empty).
    """
    w2 = 1 - w1
    blended = {}
    for protein in set(goa) | set(prott5):
        goa_scores = goa.get(protein, {})
        pt5_scores = prott5.get(protein, {})
        weighted = {
            term: w1 * goa_scores.get(term, 0) + w2 * pt5_scores.get(term, 0)
            for term in set(goa_scores) | set(pt5_scores)
        }
        # Keep only strictly positive blended scores.
        blended[protein] = {
            term: score for term, score in weighted.items() if score > 0
        }
    return blended

# ---------- GOA+ Propagation ----------
def propagate(preds, ont, alpha=0.7, power=0.8, max_s=0.93, topk=270):
    """Propagate scores up the ontology, rescale them and keep the top terms.

    For every protein the following steps run in order:

    1. Each ancestor of a predicted term receives at least that term's score.
    2. Any term scoring above the minimum of its ancestors' scores is pulled
       towards that minimum: ``alpha * minimum + (1 - alpha) * score``.
    3. If the best non-root score is positive but below ``max_s``, all
       non-root scores are rescaled to
       ``min(1.0, (score / best) ** power * max_s)``.
    4. Every term in ``ont.ROOTS`` is set to 1.0.
    5. Only the ``topk`` highest-scoring terms are kept.

    Args:
        preds: Mapping protein -> {term: score}.
        ont: Object exposing ``ancestors(term)`` and a ``ROOTS`` collection.
        alpha: Weight of the ancestor minimum in step 2.
        power: Exponent used when rescaling in step 3.
        max_s: Target value for the best non-root score in step 3.
        topk: Maximum number of terms kept per protein.

    Returns:
        A new dict protein -> {term: score}, ordered by descending score.
    """
    roots = ont.ROOTS
    result = {}
    for protein, raw in preds.items():
        scores = dict(raw)

        # Step 1: lift each ancestor to at least its descendant's score.
        for term, score in raw.items():
            for ancestor in ont.ancestors(term):
                scores[ancestor] = max(scores.get(ancestor, 0), score)

        # Step 2: damp terms that outscore their weakest ancestor.
        # NOTE(review): when ont.ancestors() is transitive and scores are
        # non-negative, step 1 already leaves every ancestor >= its
        # descendants, so this branch looks unreachable - confirm intent.
        for term in list(scores):
            lineage = ont.ancestors(term)
            if not lineage:
                continue
            weakest = min(scores[a] for a in lineage if a in scores)
            if weakest < scores[term]:
                scores[term] = alpha * weakest + (1 - alpha) * scores[term]

        # Step 3: stretch non-root scores so the best one lands on max_s.
        peak = max((v for k, v in scores.items() if k not in roots), default=0)
        if 0 < peak < max_s:
            for term in scores:
                if term in roots:
                    continue
                scores[term] = min(1.0, (scores[term] / peak) ** power * max_s)

        # Step 4: roots are always predicted with full confidence.
        scores.update(dict.fromkeys(roots, 1.0))

        # Step 5: stable sort by descending score, then truncate.
        ranked = sorted(scores.items(), key=lambda item: -item[1])
        result[protein] = dict(ranked[:topk])
    return result

# ---------- Save ----------
def save(preds, fp):
    """Write predictions to *fp* as tab-separated ``protein, term, score`` rows.

    Args:
        preds: Mapping protein -> {term: score}.
        fp: Output path; the file is created or overwritten.

    Scores are formatted with six decimal places, and entries scoring below
    0.001 are left out.
    """
    with open(fp, 'w') as out:
        out.writelines(
            f"{protein}\t{term}\t{score:.6f}\n"
            for protein, term_scores in preds.items()
            for term, score in term_scores.items()
            if score >= 0.001
        )

# ---------- RUN ----------
# Kaggle input directories: COMP is read for the ontology file, PRED for the
# two prediction files that get blended.
COMP = '/kaggle/input/cafa-6-protein-function-prediction'
PRED = '/kaggle/input/cafa6-goa-predictions'

# Parse the ontology and load both prediction sets (protein -> {term: score}).
ont = GOOntology(f'{COMP}/Train/go-basic.obo')
goa = load_preds(f'{PRED}/goa_submission.tsv')
pt5 = load_preds(f'{PRED}/prott5_interpro_predictions.tsv')

# Blend with ensemble()'s default weight, then propagate/rescale/truncate
# with propagate()'s default parameters.
ens = ensemble(goa, pt5)
final = propagate(ens, ont)

# Written relative to the current working directory.
save(final, 'submission.tsv')