# CAFA-6 — Kaggle notebook setup



Everything lives in this notebook:

- Config (paths + artefacts)

- Sanity checks (inputs present, quick stats)

- Minimal diagnostics plots (so we catch path/data issues early)



Assumption (Kaggle): your dataset folder under `/kaggle/input/.../` contains:

- `Train/` and `Test/`

- `IA.tsv` and `sample_submission.tsv`


In [14]:
# 1. SETUP, CONFIG & DIAGNOSTICS
# ==========================================
# HARDWARE: CPU (Standard)
# ==========================================

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns

# ------------------------------------------
# Environment Detection & Paths
# ------------------------------------------
# Treat the session as Kaggle if the kernel env var is set or a /kaggle directory exists.
IS_KAGGLE = os.getenv('KAGGLE_KERNEL_RUN_TYPE') is not None or Path('/kaggle').exists()

if IS_KAGGLE:
    print("Environment: Kaggle Detected")
    INPUT_ROOT = Path('/kaggle/input')
    WORKING_ROOT = Path('/kaggle/working')
    
    # Standard Kaggle Input Listing
    # Prints every file below /kaggle/input so path problems are visible in the cell output.
    for dirname, _, filenames in os.walk('/kaggle/input'):
        for filename in filenames:
            print(os.path.join(dirname, filename))
else:
    print("Environment: Local Detected")
    # Robust Project Root Detection
    # When launched from a `notebooks` directory, the project root is its parent;
    # otherwise the current working directory is used as the project root.
    CURRENT_DIR = Path.cwd()
    if CURRENT_DIR.name == 'notebooks':
        PROJECT_ROOT = CURRENT_DIR.parent
    else:
        PROJECT_ROOT = CURRENT_DIR
        
    INPUT_ROOT = PROJECT_ROOT
    WORKING_ROOT = PROJECT_ROOT / 'artefacts_local'
    WORKING_ROOT.mkdir(exist_ok=True)

# Artefacts Directory
# All generated files go under WORKING_ROOT/artefacts, split into `parsed` and `features`.
ARTEFACTS_DIR = WORKING_ROOT / 'artefacts'
ARTEFACTS_DIR.mkdir(parents=True, exist_ok=True)
(ARTEFACTS_DIR / 'parsed').mkdir(parents=True, exist_ok=True)
(ARTEFACTS_DIR / 'features').mkdir(parents=True, exist_ok=True)

# ------------------------------------------
# Dataset Discovery
# ------------------------------------------
DATASET_SLUG = 'cafa-6-protein-function-prediction'

def find_dataset_root(input_root: Path, dataset_slug: str) -> Path:
    """Locate the directory that holds the competition files.

    Resolution order:
      1. ``input_root / dataset_slug`` when that path exists.
      2. ``input_root`` itself when it contains a ``Train`` entry.
      3. The immediate sub-directory of ``input_root`` that contains the most
         of ``Train`` / ``Test`` (at least one of them is required).

    Raises:
        FileNotFoundError: when none of the rules above produces a match.
    """
    slug_dir = input_root / dataset_slug
    if slug_dir.exists():
        return slug_dir

    if (input_root / 'Train').exists():
        return input_root

    def layout_score(folder: Path) -> int:
        # 0, 1 or 2 depending on how many of the expected sub-folders exist.
        return sum((folder / sub).exists() for sub in ('Train', 'Test'))

    subdirs = [child for child in input_root.iterdir() if child.is_dir()]
    if subdirs:
        # max() keeps the first best-scoring entry, like a stable descending sort.
        best = max(subdirs, key=layout_score)
        if layout_score(best) > 0:
            return best

    raise FileNotFoundError(f"Dataset not found in {input_root}")

DATASET_ROOT = find_dataset_root(INPUT_ROOT, DATASET_SLUG)
print(f"DATASET_ROOT: {DATASET_ROOT}")

# Define Paths
# Every input file is addressed relative to DATASET_ROOT.
PATH_IA = DATASET_ROOT / 'IA.tsv'
PATH_SAMPLE_SUB = DATASET_ROOT / 'sample_submission.tsv'
PATH_TRAIN_FASTA = DATASET_ROOT / 'Train' / 'train_sequences.fasta'
PATH_TRAIN_TERMS = DATASET_ROOT / 'Train' / 'train_terms.tsv'
PATH_TRAIN_TAXON = DATASET_ROOT / 'Train' / 'train_taxonomy.tsv'
PATH_GO_OBO = DATASET_ROOT / 'Train' / 'go-basic.obo'
PATH_TEST_FASTA = DATASET_ROOT / 'Test' / 'testsuperset.fasta'
PATH_TEST_TAXON = DATASET_ROOT / 'Test' / 'testsuperset-taxon-list.tsv'

# ------------------------------------------
# Sanity Checks
# ------------------------------------------
# Files the later cells read unconditionally. The taxonomy file is included
# because the taxonomy step reads it without an existence check; listing it
# here makes a missing file fail now instead of midway through Phase 1.
# Test files are optional and are guarded with `.exists()` where they are used.
required = {
    'IA.tsv': PATH_IA,
    'Train/train_sequences.fasta': PATH_TRAIN_FASTA,
    'Train/train_terms.tsv': PATH_TRAIN_TERMS,
    'Train/train_taxonomy.tsv': PATH_TRAIN_TAXON,
    'Train/go-basic.obo': PATH_GO_OBO,
}
missing = {k: v for k, v in required.items() if not v.exists()}
if missing:
    raise FileNotFoundError(f"Missing files: {missing}")
print("All required inputs found.")

# ------------------------------------------
# Initial Diagnostics (Sequence Lengths)
# ------------------------------------------
%matplotlib inline
plt.rcParams.update({'font.size': 10})

def read_fasta_lengths(path: Path, max_records=20000):
    """Return the sequence length of each record in a FASTA file.

    Args:
        path: FASTA file to scan.
        max_records: stop after this many records. A falsy value (``None`` or
            ``0``) scans the whole file.

    Returns:
        1-D numpy array holding one length per record read, in file order.
        At most ``max_records`` entries are returned.
    """
    lengths = []
    current = 0
    in_record = False  # True once a header has been seen and its record is still open
    with path.open('r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line.startswith('>'):
                if in_record:
                    lengths.append(current)
                    # Stop on the header that follows the last wanted record.
                    # Closing the record here avoids emitting a phantom
                    # zero-length entry for a record that was never read.
                    if max_records and len(lengths) >= max_records:
                        in_record = False
                        break
                in_record = True
                current = 0
            else:
                current += len(line)
    if in_record:
        lengths.append(current)
    return np.array(lengths)

# Overlay train (and, when the file exists, test) length histograms as a quick
# check that the FASTA files were found and look plausible.
plt.figure(figsize=(10, 3))
plt.hist(read_fasta_lengths(PATH_TRAIN_FASTA), bins=50, alpha=0.5, label='Train')
if PATH_TEST_FASTA.exists():
    plt.hist(read_fasta_lengths(PATH_TEST_FASTA), bins=50, alpha=0.5, label='Test')
# "First 20k" mirrors the default record cap of read_fasta_lengths.
plt.title('Sequence Length Distribution (First 20k)')
plt.legend()
plt.show()

In [None]:
# 2. PHASE 1: DATA STRUCTURING & HIERARCHY
# ==========================================
# HARDWARE: CPU (Standard)
# ==========================================

# ------------------------------------------
# A. Parse FASTA to Feather
# ------------------------------------------
def parse_fasta(path: Path) -> pd.DataFrame:
    """Read a FASTA file into a two-column frame (``id``, ``sequence``).

    The record id is the first whitespace-delimited token after ``>``.
    Sequence lines belonging to one record are concatenated. Text that
    appears before the first header is discarded.
    """
    records = []      # (id, sequence) pairs in file order
    header_id = None  # id of the record currently being collected
    pieces = []       # sequence lines of the current record

    def flush():
        # Store the open record, if any, before starting the next one.
        if header_id:
            records.append((header_id, ''.join(pieces)))

    with path.open('r', encoding='utf-8') as handle:
        for raw in handle:
            text = raw.strip()
            if not text.startswith('>'):
                pieces.append(text)
                continue
            flush()
            header_id = text[1:].split()[0]
            pieces = []
    flush()

    return pd.DataFrame({
        'id': [rec_id for rec_id, _ in records],
        'sequence': [seq for _, seq in records],
    })

# Persist parsed sequences under artefacts/parsed; the test file is optional.
print("Parsing FASTA...")
parse_fasta(PATH_TRAIN_FASTA).to_feather(ARTEFACTS_DIR / 'parsed' / 'train_seq.feather')
if PATH_TEST_FASTA.exists():
    parse_fasta(PATH_TEST_FASTA).to_feather(ARTEFACTS_DIR / 'parsed' / 'test_seq.feather')
print("FASTA parsed and saved to artefacts.")

# ------------------------------------------
# B. Parse OBO & Terms
# ------------------------------------------
def parse_obo(path: Path):
    """Parse the GO ``is_a`` hierarchy and term namespaces from an OBO file.

    Only ``[Term]`` stanzas are read. Lines in the file header and in any
    other stanza type (for example ``[Typedef]``) are ignored, so that their
    ``namespace:`` and ``is_a:`` tags cannot be attributed to the preceding
    term.

    Args:
        path: OBO file to read.

    Returns:
        ``(parents, namespaces)`` where ``parents`` maps a GO id to the set
        of its direct ``is_a`` parent ids (terms without ``is_a`` lines have
        no entry) and ``namespaces`` maps a GO id to its namespace string.
    """
    parents = {}
    namespaces = {}
    cur_id, cur_ns = None, None
    in_term = False  # True only while inside a [Term] stanza
    with path.open('r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line.startswith('[') and line.endswith(']'):
                # Any stanza header closes the previous term.
                if cur_id and cur_ns:
                    namespaces[cur_id] = cur_ns
                cur_id, cur_ns = None, None
                in_term = line == '[Term]'
            elif not in_term:
                continue
            elif line.startswith('id: GO:'):
                cur_id = line.split('id: ', 1)[1]
            elif line.startswith('namespace:'):
                cur_ns = line[len('namespace:'):].strip()
            elif line.startswith('is_a:') and cur_id:
                # Drop the trailing "! human readable name" comment.
                parent = line[len('is_a:'):].split('!', 1)[0].strip()
                parents.setdefault(cur_id, set()).add(parent)
        if cur_id and cur_ns:
            namespaces[cur_id] = cur_ns
    return parents, namespaces

print("Parsing OBO...")
# go_parents: term -> set of direct is_a parents; go_namespaces: term -> namespace.
go_parents, go_namespaces = parse_obo(PATH_GO_OBO)
print(f"GO Graph: {len(go_parents)} nodes with parents, {len(go_namespaces)} terms with namespace.")

# ------------------------------------------
# C. Process Terms & Priors
# ------------------------------------------
terms = pd.read_csv(PATH_TRAIN_TERMS, sep='\t')
# The term column is addressed by position (second column), not by name.
col_term = terms.columns[1]
# Terms without a namespace in the OBO file are labelled 'UNK'.
terms['aspect'] = terms[col_term].map(lambda x: go_namespaces.get(x, 'UNK'))

# Plot Aspects
plt.figure(figsize=(6, 3))
terms['aspect'].value_counts().plot(kind='bar', title='Annotations by Namespace')
plt.show()

# Save Priors
# prior = (number of annotation rows for the term) / (number of distinct values in the first column).
priors = (terms[col_term].value_counts() / terms.iloc[:,0].nunique()).reset_index()
priors.columns = ['term', 'prior']
if PATH_IA.exists():
    # IA.tsv is read as headerless (term, ia); terms without an IA value get 0.
    ia = pd.read_csv(PATH_IA, sep='\t', names=['term', 'ia'])
    priors = priors.merge(ia, on='term', how='left').fillna(0)
priors.to_parquet(ARTEFACTS_DIR / 'parsed' / 'term_priors.parquet')
print("Terms processed and priors saved.")

# ------------------------------------------
# D. Process Taxonomy
# ------------------------------------------
print("Processing Taxonomy...")
# Train Taxonomy
# Read as a headerless two-column TSV: protein id, taxon id.
tax_train = pd.read_csv(PATH_TRAIN_TAXON, sep='\t', header=None, names=['id', 'taxon_id'])
tax_train['taxon_id'] = tax_train['taxon_id'].astype(int)
tax_train.to_feather(ARTEFACTS_DIR / 'parsed' / 'train_taxa.feather')

# Test Taxonomy (Extract from FASTA headers)
if PATH_TEST_FASTA.exists():
    ids, taxons = [], []
    with PATH_TEST_FASTA.open('r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line.startswith('>'):
                parts = line[1:].split()
                ids.append(parts[0])
                # Assume second part is taxon if present
                # A missing or non-integer second token is recorded as taxon 0.
                if len(parts) > 1:
                    try:
                        taxons.append(int(parts[1]))
                    except ValueError:
                        taxons.append(0)
                else:
                    taxons.append(0)
    tax_test = pd.DataFrame({'id': ids, 'taxon_id': taxons})
    tax_test.to_feather(ARTEFACTS_DIR / 'parsed' / 'test_taxa.feather')
    print(f"Taxonomy processed. Train: {len(tax_train)}, Test: {len(tax_test)}")
else:
    print(f"Taxonomy processed. Train: {len(tax_train)}")

# ------------------------------------------
# E. Save Targets & Term List
# ------------------------------------------
print("Saving Targets & Term List...")
# Save full terms list (long format)
# At this point `terms` also carries the `aspect` column added earlier in this cell.
terms.to_parquet(ARTEFACTS_DIR / 'parsed' / 'train_terms.parquet')

# Save unique term list with counts
# NOTE(review): this addresses the column by the literal name 'term', whereas the
# priors step above uses the positional `col_term`; confirm the header is 'term'.
term_counts = terms['term'].value_counts().reset_index()
term_counts.columns = ['term', 'count']
term_counts.to_parquet(ARTEFACTS_DIR / 'parsed' / 'term_counts.parquet')
print("Targets saved.")

## How the local GOA notebook fits (recommended)
This notebook’s **External Data & Evidence Codes** step is disk-heavy and is best done locally.

Use the local precompute notebook: `notebooks/CAFA6_Local_GOA_Precompute.ipynb`
- It creates compact, compressed artefacts filtered to CAFA protein IDs:
  - `goa_filtered_iea.tsv.gz` (IEA-only; safest default)
  - `goa_filtered_all.tsv.gz` (all evidence codes; richer signal)
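- Upload one or both files to Kaggle as a Dataset (or keep locally if running locally). If both are present, the discovery cell below picks `goa_filtered_all` in preference to `goa_filtered_iea`.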

Then, here (on Kaggle):
- Prefer **using the precomputed `.tsv.gz`** rather than downloading/parsing the full GOA dump.
- The notebook cell below looks only for these artefacts; if none are found it prints a hint and disables the external-data step.

In [None]:
# 2.1 PHASE 1 (Step 3): EXTERNAL DATA & EVIDENCE CODES (ARTEFACT-FIRST)
# ================================================================
# Goal: use precomputed, CAFA-filtered GOA artefacts and avoid Kaggle disk blow-ups.

from pathlib import Path

# Toggle to enable external features
PROCESS_EXTERNAL = True

# Recommended (Kaggle): set this to your dataset folder, e.g. Path('/kaggle/input/goa-filtered-all-tsv-gz')
# If left empty on Kaggle, we scan `/kaggle/input` for `goa_filtered_*.tsv(.gz)` only.
GOA_ARTEFACT_DIRS = [
    # Path('/kaggle/input/<dataset-folder>'),
]

if PROCESS_EXTERNAL:
    EXT_DIR = ARTEFACTS_DIR / 'external'
    EXT_DIR.mkdir(exist_ok=True)

    def _existing(dirs: list[Path]) -> list[Path]:
        """Return the entries of *dirs* that exist on disk, as ``Path`` objects.

        ``None`` and blank entries are dropped; input order is kept.
        """
        kept: list[Path] = []
        for entry in dirs:
            if entry is None or not str(entry).strip():
                continue
            candidate = Path(entry)
            if candidate.exists():
                kept.append(candidate)
        return kept

    def _discover(dirs: list[Path]) -> list[Path]:
        """Collect GOA artefact files found under each directory in *dirs*.

        For every directory the four expected file names are tried first,
        then a recursive glob for ``goa_filtered_*.tsv(.gz)`` is run. The
        patterns are deliberately narrow so unrelated TSVs (train_terms,
        etc.) are never picked up. Only existing regular files are returned,
        de-duplicated, in first-seen order.
        """
        expected_names = (
            'goa_filtered_all.tsv.gz', 'goa_filtered_iea.tsv.gz',
            'goa_filtered_all.tsv', 'goa_filtered_iea.tsv',
        )
        glob_patterns = ('**/goa_filtered_*.tsv.gz', '**/goa_filtered_*.tsv')

        candidates: list[Path] = []
        for root in dirs:
            # fast path: the conventional names directly inside the directory
            candidates.extend(root / name for name in expected_names)
            for pattern in glob_patterns:
                candidates.extend(sorted(Path(root).glob(pattern)))

        # dict.fromkeys drops repeats while keeping first-seen order
        unique = dict.fromkeys(Path(c) for c in candidates)
        return [c for c in unique if c.exists() and c.is_file()]

    def _pick_best(paths: list[Path]) -> Path:
        """Choose the preferred GOA artefact from *paths*.

        Preference is decided on the lower-cased file name: first by which
        marker it contains (``goa_filtered_all``, then ``goa_filtered_iea``,
        then ``all``, then ``iea``, then anything else), and within the same
        marker ``.tsv.gz`` beats other extensions. Ties keep input order.
        """
        # (marker, rank) pairs are checked in order; the first hit wins.
        name_ranks = (
            ('goa_filtered_all', 0),
            ('goa_filtered_iea', 1),
            ('all', 2),
            ('iea', 3),
        )

        def preference(candidate: Path) -> tuple[int, int]:
            lowered = candidate.name.lower()
            gz_rank = 0 if lowered.endswith('.tsv.gz') else 1
            name_rank = next((rank for marker, rank in name_ranks if marker in lowered), 4)
            return (name_rank, gz_rank)

        ranked = sorted(paths, key=preference)
        return ranked[0]

    # Search roots, in priority order: user-configured directories, then all of
    # /kaggle/input when it exists, then the local external-artefact folders.
    roots = _existing(GOA_ARTEFACT_DIRS)
    if not roots and Path('/kaggle/input').exists():
        print('GOA_ARTEFACT_DIRS empty; scanning /kaggle/input for artefacts...')
        roots = [Path('/kaggle/input')]
    elif not roots:
        roots = [EXT_DIR, Path('artefacts_local/artefacts/external')]

    goa_paths = _discover(roots)
    if not goa_paths:
        print('No precomputed GOA artefacts found.')
        print('Fix: upload `goa_filtered_all.tsv(.gz)` and/or `goa_filtered_iea.tsv(.gz)` as a Kaggle Dataset,')
        print('then set GOA_ARTEFACT_DIRS = [Path("/kaggle/input/<dataset-folder>")]')
        # Turn the feature off so later cells skip the external-data steps.
        PROCESS_EXTERNAL = False
    else:
        print('Found GOA artefacts (first 10):')
        for p in goa_paths[:10]:
            print(' -', p)
        if len(goa_paths) > 10:
            print(f' ... (+{len(goa_paths) - 10} more)')
        # GOA_FEATURE_PATH is only defined on this branch; the propagation cell
        # checks for its presence before running.
        GOA_FEATURE_PATH = _pick_best(goa_paths)
        print('Using:', GOA_FEATURE_PATH)
        print('Format expected: EntryID<TAB>term<TAB>evidence (header included)')
else:
    print('Skipping External Data (PROCESS_EXTERNAL=False).')

In [None]:
# 2.2 PHASE 1 (Step 4): HIERARCHY PROPAGATION FOR EXTERNAL GOA (NO-KAGGLE / IEA)
# =======================================================================
# Produces the intended artefacts from `NB LM/Phases.md`:
# - `prop_train_no_kaggle.tsv.gz`
# - `prop_test_no_kaggle.tsv.gz`
# These are propagated external labels (IEA only) restricted to the top-K train terms.

import gzip
import pandas as pd
from pathlib import Path

EXTERNAL_TOP_K = 1500
EXTERNAL_PRIOR_SCORE = 1.0  # binary prior (later we down-weight when injecting into models)

if PROCESS_EXTERNAL and 'GOA_FEATURE_PATH' in locals():
    # Ensure we have GO parents (hierarchy)
    if 'go_parents' not in locals():
        def parse_obo(path: Path):
            """Parse the GO ``is_a`` hierarchy and term namespaces from an OBO file.

            Fallback definition used when ``go_parents`` has not been built in
            this session. Only ``[Term]`` stanzas are read. Lines in the file
            header and in any other stanza type (for example ``[Typedef]``)
            are ignored, so that their ``namespace:`` and ``is_a:`` tags cannot
            be attributed to the preceding term.

            Returns:
                ``(parents, namespaces)`` where ``parents`` maps a GO id to the
                set of its direct ``is_a`` parent ids and ``namespaces`` maps a
                GO id to its namespace string.
            """
            parents = {}
            namespaces = {}
            cur_id, cur_ns = None, None
            in_term = False  # True only while inside a [Term] stanza
            with path.open('r', encoding='utf-8') as f:
                for line in f:
                    line = line.strip()
                    if line.startswith('[') and line.endswith(']'):
                        # Any stanza header closes the previous term.
                        if cur_id and cur_ns:
                            namespaces[cur_id] = cur_ns
                        cur_id, cur_ns = None, None
                        in_term = line == '[Term]'
                    elif not in_term:
                        continue
                    elif line.startswith('id: GO:'):
                        cur_id = line.split('id: ', 1)[1]
                    elif line.startswith('namespace:'):
                        cur_ns = line[len('namespace:'):].strip()
                    elif line.startswith('is_a:') and cur_id:
                        # Drop the trailing "! human readable name" comment.
                        parent = line[len('is_a:'):].split('!', 1)[0].strip()
                        parents.setdefault(cur_id, set()).add(parent)
                if cur_id and cur_ns:
                    namespaces[cur_id] = cur_ns
            return parents, namespaces
        go_parents, go_namespaces = parse_obo(PATH_GO_OBO)

    # Ensure train/test IDs are available
    # Both feather files are required here; a missing one aborts the cell.
    train_seq_path = ARTEFACTS_DIR / 'parsed' / 'train_seq.feather'
    test_seq_path = ARTEFACTS_DIR / 'parsed' / 'test_seq.feather'
    if not train_seq_path.exists():
        raise FileNotFoundError('Missing train_seq.feather. Run Phase 1 Step 2 (FASTA parse) first.')
    if not test_seq_path.exists():
        raise FileNotFoundError('Missing test_seq.feather. Run Phase 1 Step 2 (FASTA parse) first.')
    # Sets give O(1) membership tests in the per-row streaming loop below.
    train_ids = set(pd.read_feather(train_seq_path)['id'].astype(str).tolist())
    test_ids = set(pd.read_feather(test_seq_path)['id'].astype(str).tolist())

    # Get top-K train terms (defines the external feature space)
    train_terms_path = ARTEFACTS_DIR / 'parsed' / 'train_terms.parquet'
    if not train_terms_path.exists():
        raise FileNotFoundError('Missing train_terms.parquet. Run Phase 1 Step 2 (targets parse) first.')
    train_terms = pd.read_parquet(train_terms_path)
    # The EXTERNAL_TOP_K most frequent terms, most frequent first.
    top_terms = train_terms['term'].value_counts().head(EXTERNAL_TOP_K).index.tolist()
    top_terms_set = set(top_terms)
    print(f'External propagation restricted to top {len(top_terms)} train terms.')

    # Output locations for the propagated edge lists.
    EXT_DIR = ARTEFACTS_DIR / 'external'
    EXT_DIR.mkdir(exist_ok=True)
    out_train = EXT_DIR / 'prop_train_no_kaggle.tsv.gz'
    out_test = EXT_DIR / 'prop_test_no_kaggle.tsv.gz'

    # Ancestor closure with memoisation
    # Memo of term -> ancestor closure, filled lazily by ancestors().
    _anc_cache: dict[str, set[str]] = {}
    def ancestors(term: str) -> set[str]:
        """Return *term* together with every term reachable via ``go_parents``.

        The closure is computed with an explicit stack (no recursion) and
        memoised in ``_anc_cache``. The cached set object itself is returned,
        so callers must not mutate it.
        """
        cached = _anc_cache.get(term)
        if cached is not None:
            return cached

        closure = {term}
        pending = [term]
        while pending:
            node = pending.pop()
            for parent in go_parents.get(node, ()):
                if parent in closure:
                    continue
                closure.add(parent)
                pending.append(parent)

        _anc_cache[term] = closure
        return closure

    # Stream GOA rows and write propagated edges
    # The artefact is read in chunks; each IEA row whose EntryID belongs to the
    # train or test set is expanded to its ancestor closure, filtered to the
    # top-K terms, and appended to the matching gzip TSV.
    cols = ['EntryID', 'term', 'evidence']
    print('Streaming GOA artefact:', GOA_FEATURE_PATH)
    print('Writing:', out_train)
    print('Writing:', out_test)

    # Counters of edge rows written to each output.
    n_train = 0
    n_test = 0
    # Write headers
    with gzip.open(out_train, 'wt', encoding='utf-8') as ftr, gzip.open(out_test, 'wt', encoding='utf-8') as fte:
        ftr.write('EntryID\tterm\tscore\n')
        fte.write('EntryID\tterm\tscore\n')

        for chunk in pd.read_csv(
            GOA_FEATURE_PATH,
            sep='\t',
            dtype=str,
            usecols=lambda c: c in cols,  # silently ignores any extra columns
            chunksize=500_000,
        ):
            # Normalise columns if needed
            # Fail with a clear message if an expected column is absent.
            missing = [c for c in cols if c not in chunk.columns]
            if missing:
                raise ValueError(f'GOA artefact missing columns: {missing}. Found: {list(chunk.columns)}')
            chunk = chunk[cols]
            chunk = chunk.dropna()

            # No-Kaggle = IEA only
            chunk = chunk[chunk['evidence'] == 'IEA']
            if chunk.empty:
                continue

            # De-dupe within chunk (cuts write volume)
            # NOTE(review): de-duplication is per chunk and per source row only.
            # Different source terms of one protein can share ancestors, so the
            # outputs may repeat (EntryID, term) pairs and the counters include
            # those repeats - confirm downstream readers de-duplicate.
            chunk = chunk.drop_duplicates(subset=['EntryID', 'term'])

            # Propagate -> write only terms in top_terms_set
            for entry_id, term in zip(chunk['EntryID'].tolist(), chunk['term'].tolist()):
                # Train membership is checked first; ids in neither set are skipped.
                if entry_id in train_ids:
                    target = ftr
                elif entry_id in test_ids:
                    target = fte
                else:
                    continue
                keep = ancestors(term) & top_terms_set
                if not keep:
                    continue
                for t in keep:
                    target.write(f'{entry_id}\t{t}\t{EXTERNAL_PRIOR_SCORE}\n')
                if target is ftr:
                    n_train += len(keep)
                else:
                    n_test += len(keep)

    print(f'Wrote propagated IEA edges: train={n_train:,} test={n_test:,}')
    print('Outputs are intentionally sparse priors (score=1.0) and will be down-weighted when injected.')
else:
    print('Skipping external propagation (PROCESS_EXTERNAL=False or GOA_FEATURE_PATH missing).')

In [None]:
# 3a. PHASE 1: EMBEDDINGS GENERATION (T5 only)
# ============================================
# HARDWARE: GPU recommended
# ============================================

# Split from ESM2 so you can run each independently on Kaggle.

COMPUTE_T5 = True  # <--- enable/disable T5 run

if COMPUTE_T5:
    import os
    import gc
    import numpy as np
    import pandas as pd
    import torch
    from transformers import T5Tokenizer, T5EncoderModel
    from tqdm.auto import tqdm
    import contextlib

    # Optimise CUDA memory allocation
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    # Fix Protobuf 'GetPrototype' error
    os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    amp_ctx = torch.amp.autocast('cuda') if device.type == 'cuda' else contextlib.nullcontext()

    def get_t5_model():
        """Load the ``Rostlab/prot_t5_xl_half_uniref50-enc`` tokenizer and encoder.

        The model is moved to ``device`` and switched to eval mode.

        Returns:
            ``(tokenizer, model)``.
        """
        checkpoint = "Rostlab/prot_t5_xl_half_uniref50-enc"
        print("Loading T5 Model...")
        tokenizer = T5Tokenizer.from_pretrained(checkpoint, do_lower_case=False, legacy=True)
        model = T5EncoderModel.from_pretrained(checkpoint)
        model = model.to(device)
        model.eval()
        return tokenizer, model

    def generate_embeddings_t5(model, tokenizer, sequences, batch_size=4, max_len=1024):
        """Return one mean-pooled T5 embedding per sequence, in input order.

        Sequences are processed longest-first so each padded batch holds
        similar lengths. Before tokenisation the residues U, Z, O and B are
        replaced by X and residues are separated by spaces. Each embedding is
        the mean of the last hidden state over the positions whose attention
        mask is 1 (inputs are truncated to ``max_len`` tokens).

        Args:
            model: encoder called with ``input_ids`` / ``attention_mask``.
            tokenizer: tokenizer providing ``batch_encode_plus``.
            sequences: list of amino-acid strings.
            batch_size: number of sequences per forward pass.
            max_len: tokenizer truncation length.

        Returns:
            2-D numpy array whose row ``i`` belongs to ``sequences[i]``.
        """
        order = np.argsort([len(s) for s in sequences])[::-1]
        by_length = [sequences[k] for k in order]
        rare_to_x = str.maketrans('UZOB', 'XXXX')

        pooled = []
        for start in tqdm(range(0, len(by_length), batch_size), desc="Embedding T5 (Smart Batch)"):
            batch = [
                " ".join(seq.translate(rare_to_x))
                for seq in by_length[start : start + batch_size]
            ]

            encoded = tokenizer.batch_encode_plus(
                batch, add_special_tokens=True, padding="longest",
                truncation=True, max_length=max_len, return_tensors="pt",
            ).to(device)
            attn = encoded['attention_mask']

            with torch.no_grad(), amp_ctx:
                outputs = model(input_ids=encoded['input_ids'], attention_mask=attn)

            hidden = outputs.last_hidden_state.float().detach().cpu().numpy()
            token_counts = attn.detach().cpu().numpy().sum(axis=1)
            # Average only over the unmasked (non-padding) positions of each row.
            for row, n_tokens in zip(hidden, token_counts):
                pooled.append(row[:int(n_tokens)].mean(axis=0))

            # Drop references before the next batch to keep GPU memory low.
            del encoded, attn, outputs, hidden
            if device.type == 'cuda':
                torch.cuda.empty_cache()

        stacked = np.vstack(pooled)
        # Undo the length sort: row k of `stacked` belongs to input index order[k].
        restored = np.zeros_like(stacked)
        restored[order] = stacked
        return restored

    # Driver: embed the parsed train sequences (and test, when its feather file
    # exists) and save the arrays under artefacts/features.
    print("Loading sequences for T5 embedding...")
    train_df = pd.read_feather(ARTEFACTS_DIR / 'parsed' / 'train_seq.feather')
    test_df = None
    if (ARTEFACTS_DIR / 'parsed' / 'test_seq.feather').exists():
        test_df = pd.read_feather(ARTEFACTS_DIR / 'parsed' / 'test_seq.feather')

    tokenizer, model = get_t5_model()
    print(f"Generating Train Embeddings T5 ({len(train_df)})...")
    train_emb = generate_embeddings_t5(model, tokenizer, train_df['sequence'].tolist())
    np.save(ARTEFACTS_DIR / 'features' / 'train_embeds_t5.npy', train_emb)
    # Release the train array before the test pass.
    del train_emb
    gc.collect()

    if test_df is not None:
        print(f"Generating Test Embeddings T5 ({len(test_df)})...")
        test_emb = generate_embeddings_t5(model, tokenizer, test_df['sequence'].tolist())
        np.save(ARTEFACTS_DIR / 'features' / 'test_embeds_t5.npy', test_emb)
        del test_emb

    # Drop the model and data so a following cell can reuse the memory.
    del model, tokenizer, train_df, test_df
    gc.collect()
    if device.type == 'cuda':
        torch.cuda.empty_cache()
    print("T5 embeddings generated.")
else:
    print("Skipping T5 embedding generation (COMPUTE_T5=False).")

In [None]:
# 3b. PHASE 1: EMBEDDINGS GENERATION (ESM2 only)
# ==============================================
# HARDWARE: GPU recommended
# ==============================================

# Split from T5 so you can run each independently on Kaggle.

COMPUTE_ESM2 = True  # <--- enable/disable ESM2 run

if COMPUTE_ESM2:
    import os
    import gc
    import numpy as np
    import pandas as pd
    import torch
    from transformers import EsmTokenizer, EsmModel
    from tqdm.auto import tqdm
    import contextlib

    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    amp_ctx = torch.amp.autocast('cuda') if device.type == 'cuda' else contextlib.nullcontext()

    def get_esm2_model():
        """Load the ``facebook/esm2_t33_650M_UR50D`` tokenizer and model.

        The model is moved to ``device`` and switched to eval mode.

        Returns:
            ``(tokenizer, model)``.
        """
        model_name = "facebook/esm2_t33_650M_UR50D"
        print("Loading ESM2 650M Model...")
        tokenizer = EsmTokenizer.from_pretrained(model_name)
        model = EsmModel.from_pretrained(model_name)
        model = model.to(device)
        model.eval()
        return tokenizer, model

    def generate_embeddings_esm2(model, tokenizer, sequences, batch_size=16, max_len=1024):
        """Return one mean-pooled ESM2 embedding per sequence, in input order.

        Sequences are processed longest-first so each padded batch holds
        similar lengths. Each embedding is the mean of the last hidden state
        over the positions whose attention mask is 1 (inputs are truncated to
        ``max_len`` tokens).

        Args:
            model: model called with ``input_ids`` / ``attention_mask``.
            tokenizer: tokenizer providing ``batch_encode_plus``.
            sequences: list of amino-acid strings.
            batch_size: number of sequences per forward pass.
            max_len: tokenizer truncation length.

        Returns:
            2-D numpy array whose row ``i`` belongs to ``sequences[i]``.
        """
        order = np.argsort([len(s) for s in sequences])[::-1]
        by_length = [sequences[k] for k in order]

        pooled = []
        for start in tqdm(range(0, len(by_length), batch_size), desc="Embedding ESM2 (Smart Batch)"):
            batch = by_length[start : start + batch_size]
            encoded = tokenizer.batch_encode_plus(
                batch, add_special_tokens=True, padding="longest",
                truncation=True, max_length=max_len, return_tensors="pt",
            ).to(device)
            attn = encoded['attention_mask']

            with torch.no_grad(), amp_ctx:
                result = model(input_ids=encoded['input_ids'], attention_mask=attn)

            hidden = result.last_hidden_state.float().detach().cpu().numpy()
            token_counts = attn.detach().cpu().numpy().sum(axis=1)
            # Average only over the unmasked (non-padding) positions of each row.
            for row, n_tokens in zip(hidden, token_counts):
                pooled.append(row[:int(n_tokens)].mean(axis=0))

            # Drop references before the next batch to keep GPU memory low.
            del encoded, attn, result, hidden
            if device.type == 'cuda':
                torch.cuda.empty_cache()

        stacked = np.vstack(pooled)
        # Undo the length sort: row k of `stacked` belongs to input index order[k].
        restored = np.zeros_like(stacked)
        restored[order] = stacked
        return restored

    # Driver: embed the parsed train sequences (and test, when its feather file
    # exists) and save the arrays under artefacts/features.
    print("Loading sequences for ESM2 embedding...")
    train_df = pd.read_feather(ARTEFACTS_DIR / 'parsed' / 'train_seq.feather')
    test_df = None
    if (ARTEFACTS_DIR / 'parsed' / 'test_seq.feather').exists():
        test_df = pd.read_feather(ARTEFACTS_DIR / 'parsed' / 'test_seq.feather')

    tokenizer, model = get_esm2_model()
    print(f"Generating Train Embeddings ESM2 ({len(train_df)})...")
    train_emb = generate_embeddings_esm2(model, tokenizer, train_df['sequence'].tolist())
    np.save(ARTEFACTS_DIR / 'features' / 'train_embeds_esm2.npy', train_emb)
    # Release the train array before the test pass.
    del train_emb
    gc.collect()

    if test_df is not None:
        print(f"Generating Test Embeddings ESM2 ({len(test_df)})...")
        test_emb = generate_embeddings_esm2(model, tokenizer, test_df['sequence'].tolist())
        np.save(ARTEFACTS_DIR / 'features' / 'test_embeds_esm2.npy', test_emb)
        del test_emb

    # Drop the model and data so a following cell can reuse the memory.
    del model, tokenizer, train_df, test_df
    gc.collect()
    if device.type == 'cuda':
        torch.cuda.empty_cache()
    print("ESM2 embeddings generated.")
else:
    print("Skipping ESM2 embedding generation (COMPUTE_ESM2=False).")

## Multimodal inputs (auditor note)
This notebook is intentionally **Kaggle-session safe**: only **T5** and **ESM2-650M** embedding generation is included as runnable code.

However, **Phase 2 is already wired** to load and use additional modalities *if you provide them as `.npy` artefacts* in `artefacts/.../features/`.
This matches the Rank-1 “extreme multimodal” pattern: separate feature extraction (often done offline) + unified training/stacking.

### Optional embedding artefacts (drop-in)
Place these files under `ARTEFACTS_DIR / 'features'` (same row-order as `parsed/train_seq.feather` / `parsed/test_seq.feather`):
- `train_embeds_esm2_3b.npy` + `test_embeds_esm2_3b.npy` (expected ~2560D)
- `train_embeds_ankh.npy` + `test_embeds_ankh.npy` (expected ~1536D)
- `train_embeds_text.npy` + `test_embeds_text.npy` (expected ~10279D)

### Taxa feature
Taxonomy is parsed in Phase 1 to:
- `parsed/train_taxa.feather`
- `parsed/test_taxa.feather`

Phase 2 one-hot encodes these and concatenates the result as a strong contextual prior.

### Where they are consumed
- **Cell 9 (Phase 2)**: auto-loads any of the above that exist and trains LR/GBDT on a flat concatenation; the DNN uses a multi-tower architecture over the separate modalities.
- **Cell 11 (Phase 3b)**: stacker consumes Level-1 OOF/test predictions and trains 3 specialised GCNs (BP/MF/CC).

In [None]:
# 4. PHASE 2: LEVEL-1 MODELS (DIVERSE ENSEMBLE)
# =============================================
# HARDWARE: GPU (32GB+ recommended for full run)
# =============================================

# We train a diverse set of models:
# 1. Logistic Regression (Baseline)
# 2. Py-Boost (GBDT) - Requires 'py-boost' package
# 3. DNN Ensemble (Deep Learning)

TRAIN_LEVEL1 = True

if TRAIN_LEVEL1:
    import joblib
    from sklearn.model_selection import KFold
    from sklearn.metrics import f1_score

    # -----------------------------
    # Load targets + ids
    # -----------------------------
    print("Loading targets...")
    train_terms = pd.read_parquet(ARTEFACTS_DIR / 'parsed' / 'train_terms.parquet')
    train_ids = pd.read_feather(ARTEFACTS_DIR / 'parsed' / 'train_seq.feather')['id'].astype(str)
    test_ids = pd.read_feather(ARTEFACTS_DIR / 'parsed' / 'test_seq.feather')['id'].astype(str)

    # Target Matrix Construction (Top K Terms)
    TOP_K = 1500
    # Most frequent terms first; this list defines the column order of Y.
    top_terms = train_terms['term'].value_counts().head(TOP_K).index.tolist()

    train_terms_top = train_terms[train_terms['term'].isin(top_terms)]
    Y_df = train_terms_top.pivot_table(index='EntryID', columns='term', aggfunc='size', fill_value=0)
    # pivot_table returns its columns sorted by label, not by frequency.
    # Reindex rows to train_ids and columns to top_terms so that any per-term
    # vector built from top_terms (e.g. the IA weights) lines up with Y.
    Y_df = Y_df.reindex(index=train_ids, columns=top_terms, fill_value=0)
    # aggfunc='size' counts rows, so a duplicated annotation would give a value
    # above 1; the targets are binary indicators.
    Y = (Y_df.to_numpy() > 0).astype(np.float32)

    print(f"Targets: Y={Y.shape}")

    # -----------------------------
    # Feature loading (multimodal)
    # -----------------------------
    print("Loading multimodal features...")
    FEAT_DIR = ARTEFACTS_DIR / 'features'

    def _load_pair(stem):
        """Load the train/test embedding arrays for *stem* as float32.

        Returns ``(None, None)`` unless both ``train_embeds_<stem>.npy`` and
        ``test_embeds_<stem>.npy`` exist in ``FEAT_DIR``.
        """
        train_file = FEAT_DIR / f'train_embeds_{stem}.npy'
        test_file = FEAT_DIR / f'test_embeds_{stem}.npy'
        if not (train_file.exists() and test_file.exists()):
            return None, None
        train_arr = np.load(train_file).astype(np.float32)
        test_arr = np.load(test_file).astype(np.float32)
        return train_arr, test_arr

    features_train = {}
    features_test = {}

    # Sequence embeddings (present today: T5 + ESM2-650M; optional: ESM2-3B / Ankh / Text)
    for stem, key in [
        ('t5', 't5'),
        ('esm2', 'esm2_650m'),
        ('esm2_3b', 'esm2_3b'),
        ('ankh', 'ankh'),
        ('text', 'text'),
    ]:
        a_tr, a_te = _load_pair(stem)
        if a_tr is not None:
            features_train[key] = a_tr
            features_test[key] = a_te

    # Taxonomy (encode as one-hot / bag-of-taxa)
    taxa_train_path = ARTEFACTS_DIR / 'parsed' / 'train_taxa.feather'
    taxa_test_path = ARTEFACTS_DIR / 'parsed' / 'test_taxa.feather'
    if taxa_train_path.exists() and taxa_test_path.exists():
        from sklearn.preprocessing import OneHotEncoder
        tax_tr = pd.read_feather(taxa_train_path).astype({'id': str})
        tax_te = pd.read_feather(taxa_test_path).astype({'id': str})

        # Align rows to the protein order used by every other feature matrix.
        # Proteins absent from the taxa table get 0 in every column, i.e. taxon_id 0.
        tax_tr = tax_tr.set_index('id').reindex(train_ids, fill_value=0).reset_index()
        tax_te = tax_te.set_index('id').reindex(test_ids, fill_value=0).reset_index()

        # The encoder is fitted on the union of train and test taxa, so both
        # matrices share one column layout.
        enc = OneHotEncoder(handle_unknown='ignore', sparse_output=False, dtype=np.float32)
        enc.fit(pd.concat([tax_tr[['taxon_id']], tax_te[['taxon_id']]], axis=0))
        X_tax_tr = enc.transform(tax_tr[['taxon_id']]).astype(np.float32)
        X_tax_te = enc.transform(tax_te[['taxon_id']]).astype(np.float32)

        features_train['taxa'] = X_tax_tr
        features_test['taxa'] = X_tax_te
        print(f"Taxa features: train={X_tax_tr.shape} test={X_tax_te.shape}")
    else:
        print("Taxa features not found; skipping (expected parsed/train_taxa.feather + parsed/test_taxa.feather).")

    # Sanity checks: every modality must have exactly one row per protein.
    n_train = len(train_ids)
    n_test = len(test_ids)
    for k, v in features_train.items():
        if v.shape[0] != n_train:
            raise ValueError(f"Feature {k} train rows mismatch: {v.shape[0]} vs {n_train}")
    for k, v in features_test.items():
        if v.shape[0] != n_test:
            raise ValueError(f"Feature {k} test rows mismatch: {v.shape[0]} vs {n_test}")

    # T5 is the only mandatory modality; all others are optional.
    if 't5' not in features_train:
        raise FileNotFoundError("Missing required T5 embeddings: features/train_embeds_t5.npy and features/test_embeds_t5.npy")

    # Flat concatenation for classical models (LR/GBDT)
    # The fixed key order keeps the column layout of X and X_test identical.
    FLAT_KEYS = [k for k in ['t5', 'esm2_650m', 'esm2_3b', 'ankh', 'taxa', 'text'] if k in features_train]
    X = np.hstack([features_train[k] for k in FLAT_KEYS]).astype(np.float32)
    X_test = np.hstack([features_test[k] for k in FLAT_KEYS]).astype(np.float32)
    print(f"Flat X keys={FLAT_KEYS}")
    print(f"Flat shapes: X={X.shape}, X_test={X_test.shape}")

    # -----------------------------
    # CAFA-like IA-weighted diagnostic F1 (vectorised)
    # -----------------------------
    # Source of IA values: an in-memory `ia` frame if one exists, else IA.tsv,
    # else an empty table (every weight then falls back to 0.0).
    if 'ia' in locals():
        ia_df = ia[['term', 'ia']].copy()
    elif PATH_IA.exists():
        ia_df = pd.read_csv(PATH_IA, sep='\t', names=['term', 'ia'])
    else:
        ia_df = pd.DataFrame({'term': [], 'ia': []})

    ia_map = dict(zip(ia_df['term'], ia_df['ia']))

    def _ia_weight(term):
        """Return the IA weight of *term* as a float; 0.0 when unknown or NaN."""
        v = ia_map.get(term, 0.0)
        if pd.isna(v):
            return 0.0
        return float(v)

    # Per-label metadata must follow the column order of Y. Y is built from
    # Y_df, whose columns come out of pivot_table sorted by term, which is not
    # necessarily the order of `top_terms`. Taking the order from Y_df keeps
    # weights/aspects aligned with the label columns they describe.
    label_terms = list(Y_df.columns)

    weights = np.array([_ia_weight(t) for t in label_terms], dtype=np.float32)

    ns_to_aspect = {
        'molecular_function': 'MF',
        'biological_process': 'BP',
        'cellular_component': 'CC',
    }
    # Terms without a known namespace are tagged 'UNK' and therefore only count
    # towards the 'ALL' score.
    if 'go_namespaces' in locals():
        term_aspects = np.array([ns_to_aspect.get(go_namespaces.get(t, ''), 'UNK') for t in label_terms])
    else:
        term_aspects = np.array(['UNK'] * len(label_terms))

    def ia_weighted_f1(y_true, y_score, thr=0.3):
        """IA-weighted F1 of thresholded scores, overall and per GO aspect.

        Args:
            y_true: (n_samples, n_terms) array; any value > 0 counts as positive.
            y_score: (n_samples, n_terms) array of scores; >= thr counts as predicted.
            thr: decision threshold.

        Returns:
            dict with keys 'ALL', 'MF', 'BP', 'CC'. Per-term TP / predicted /
            true counts are weighted by the enclosing `weights` vector; the
            aspect scores restrict the weights via `term_aspects`.
        """
        is_true = y_true > 0
        is_pred = y_score >= thr

        # Per-term counts over all samples.
        tp = np.logical_and(is_pred, is_true).sum(axis=0).astype(np.float64)
        pred = is_pred.sum(axis=0).astype(np.float64)
        true = is_true.sum(axis=0).astype(np.float64)

        def _f1_for(w):
            # Weighted precision/recall; an empty denominator scores 0.
            w_tp = float((w * tp).sum())
            w_pred = float((w * pred).sum())
            w_true = float((w * true).sum())
            precision = w_tp / w_pred if w_pred > 0 else 0.0
            recall = w_tp / w_true if w_true > 0 else 0.0
            total = precision + recall
            if total > 0:
                return 2 * precision * recall / total
            return 0.0

        out = {'ALL': _f1_for(weights)}
        for asp in ('MF', 'BP', 'CC'):
            aspect_mask = (term_aspects == asp).astype(np.float32)
            out[asp] = _f1_for(weights * aspect_mask)
        return out

    # ------------------------------------------
    # A. Logistic Regression (Baseline)
    # ------------------------------------------
    # One-vs-rest logistic regression on the flat feature matrix, 5-fold CV.
    # Produces out-of-fold (OOF) probabilities for every training row and
    # fold-averaged probabilities for the test set.
    # NOTE(review): KFold, f1_score and joblib are not imported in the visible
    # part of this cell; presumably imported earlier. Confirm.
    print("\n--- Training Logistic Regression ---")
    from sklearn.linear_model import LogisticRegression
    from sklearn.multiclass import OneVsRestClassifier

    # `kf` is reused by the GBDT section below so both models share identical folds.
    kf = KFold(n_splits=5, shuffle=True, random_state=42)
    oof_preds_logreg = np.zeros(Y.shape, dtype=np.float32)
    test_preds_logreg = np.zeros((len(test_ids), Y.shape[1]), dtype=np.float32)

    for fold, (idx_tr, idx_val) in enumerate(kf.split(X)):
        print(f"LogReg Fold {fold+1}/5")
        X_tr, X_val = X[idx_tr], X[idx_val]
        Y_tr, Y_val = Y[idx_tr], Y[idx_val]

        clf_logreg = OneVsRestClassifier(
            LogisticRegression(max_iter=500, solver='sag', n_jobs=1, C=1.0)
        )
        # Parallelise across labels (the OvR wrapper); each inner LR stays single-threaded.
        clf_logreg.n_jobs = -1
        clf_logreg.fit(X_tr, Y_tr)

        val_probs = clf_logreg.predict_proba(X_val)
        oof_preds_logreg[idx_val] = val_probs

        # Accumulate the fold average of the test predictions.
        test_preds_logreg += clf_logreg.predict_proba(X_test) / kf.get_n_splits()

        # Fold diagnostics at a fixed 0.30 threshold.
        val_preds = (val_probs > 0.3).astype(int)
        f1 = f1_score(Y_val, val_preds, average='micro')
        ia_f1 = ia_weighted_f1(Y_val, val_probs, thr=0.3)
        print(f"  >> Fold {fold+1} micro-F1@0.30: {f1:.4f}")
        print(f"  >> Fold {fold+1} IA-F1@0.30: ALL={ia_f1['ALL']:.4f} MF={ia_f1['MF']:.4f} BP={ia_f1['BP']:.4f} CC={ia_f1['CC']:.4f}")

        joblib.dump(clf_logreg, ARTEFACTS_DIR / 'features' / f'level1_logreg_fold{fold}.pkl')

    np.save(ARTEFACTS_DIR / 'features' / 'oof_logreg.npy', oof_preds_logreg)
    np.save(ARTEFACTS_DIR / 'features' / 'test_pred_logreg.npy', test_preds_logreg)
    print("LogReg OOF + test preds saved.")

    # ------------------------------------------
    # B. Py-Boost (GBDT)
    # ------------------------------------------
    # Optional model: the whole section is skipped when py_boost cannot be imported.
    try:
        from py_boost import GradientBoosting
        HAS_PYBOOST = True
    except ImportError:
        print("\n[WARNING] Py-Boost not installed. Skipping GBDT.")
        HAS_PYBOOST = False

    if HAS_PYBOOST:
        print("\n--- Training Py-Boost GBDT ---")
        oof_preds_gbdt = np.zeros(Y.shape, dtype=np.float32)
        test_preds_gbdt = np.zeros((len(test_ids), Y.shape[1]), dtype=np.float32)

        # Reuses `kf` from the LogReg section, i.e. the same fold assignment.
        for fold, (idx_tr, idx_val) in enumerate(kf.split(X)):
            print(f"GBDT Fold {fold+1}/5")
            X_tr, X_val = X[idx_tr], X[idx_val]
            Y_tr, Y_val = Y[idx_tr], Y[idx_val]

            model = GradientBoosting(
                loss='bce',
                ntrees=1000,
                lr=0.05,
                max_depth=6,
                verbose=100,
                es=50,
            )

            # NOTE(review): the validation fold is also the early-stopping eval set,
            # so the OOF scores reported below are likely optimistic.
            model.fit(X_tr, Y_tr, eval_sets=[{'X': X_val, 'y': Y_val}])

            # NOTE(review): assumes predict() returns probabilities for loss='bce';
            # confirm against the py_boost documentation.
            val_probs = model.predict(X_val)
            oof_preds_gbdt[idx_val] = val_probs

            test_preds_gbdt += model.predict(X_test) / kf.get_n_splits()

            val_preds = (val_probs > 0.3).astype(int)
            f1 = f1_score(Y_val, val_preds, average='micro')
            ia_f1 = ia_weighted_f1(Y_val, val_probs, thr=0.3)
            print(f"  >> Fold {fold+1} micro-F1@0.30: {f1:.4f}")
            print(f"  >> Fold {fold+1} IA-F1@0.30: ALL={ia_f1['ALL']:.4f} MF={ia_f1['MF']:.4f} BP={ia_f1['BP']:.4f} CC={ia_f1['CC']:.4f}")

            # NOTE(review): verify that this py_boost version exposes `save`;
            # if it does not, the fold fails here after training has finished.
            model.save(str(ARTEFACTS_DIR / 'features' / f'level1_gbdt_fold{fold}.json'))

        np.save(ARTEFACTS_DIR / 'features' / 'oof_gbdt.npy', oof_preds_gbdt)
        np.save(ARTEFACTS_DIR / 'features' / 'test_pred_gbdt.npy', test_preds_gbdt)
        print("GBDT OOF + test preds saved.")

    # ------------------------------------------
    # C. DNN Ensemble (PyTorch, IA-weighted, multi-input + multi-state)
    # ------------------------------------------
    print("\n--- Training DNN Ensemble (IA-weighted, multimodal, multi-state) ---")
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Build a stable per-label IA weight vector for the current TOP_K targets
    # Steps: replace non-finite / non-positive weights by 1.0, normalise to mean 1,
    # then clip to [0.5, 5.0] so no single label dominates the loss.
    ia_w = weights.copy()
    ia_w = np.where(np.isfinite(ia_w) & (ia_w > 0), ia_w, 1.0).astype(np.float32)
    ia_w = ia_w / float(np.mean(ia_w))
    ia_w = np.clip(ia_w, 0.5, 5.0)
    # Shape (1, n_labels) so it broadcasts over the batch dimension of the loss.
    ia_w_t = torch.tensor(ia_w, dtype=torch.float32, device=device).view(1, -1)

    # Optional: include other model predictions as an input stream (PB OOFs analogue)
    # Train rows use out-of-fold predictions; test rows use the fold-averaged
    # test predictions. LogReg is required for the stream, GBDT is appended if present.
    USE_BASE_OOFS_IN_DNN = True
    if USE_BASE_OOFS_IN_DNN and (ARTEFACTS_DIR / 'features' / 'oof_logreg.npy').exists():
        oof_stream = [np.load(ARTEFACTS_DIR / 'features' / 'oof_logreg.npy').astype(np.float32)]
        test_stream = [np.load(ARTEFACTS_DIR / 'features' / 'test_pred_logreg.npy').astype(np.float32)]
        if (ARTEFACTS_DIR / 'features' / 'oof_gbdt.npy').exists():
            oof_stream.append(np.load(ARTEFACTS_DIR / 'features' / 'oof_gbdt.npy').astype(np.float32))
            test_stream.append(np.load(ARTEFACTS_DIR / 'features' / 'test_pred_gbdt.npy').astype(np.float32))
        base_oof = np.hstack(oof_stream)
        base_test = np.hstack(test_stream)
        features_train['base_oof'] = base_oof
        features_test['base_oof'] = base_test
        print(f"Base OOF stream: train={base_oof.shape} test={base_test.shape}")

    # Select modality keys for the DNN (towers)
    # One tower is built per key, in this fixed order.
    DNN_KEYS = [k for k in ['t5', 'esm2_650m', 'esm2_3b', 'ankh', 'taxa', 'text', 'base_oof'] if k in features_train]
    print(f"DNN modality keys={DNN_KEYS}")

    class Tower(nn.Module):
        """Per-modality encoder.

        Layout: Linear(in_dim, 1024) -> ReLU -> Dropout -> Linear(1024, out_dim) -> ReLU.
        """

        def __init__(self, in_dim, out_dim=512, dropout=0.1):
            super().__init__()
            hidden_width = 1024
            stages = [
                nn.Linear(in_dim, hidden_width),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_width, out_dim),
                nn.ReLU(),
            ]
            # Kept as `net` (a Sequential) so saved state_dict keys stay `net.<i>.*`.
            self.net = nn.Sequential(*stages)

        def forward(self, x):
            """Encode a (batch, in_dim) tensor into (batch, out_dim)."""
            encoded = self.net(x)
            return encoded

    class ColossalMultiModalDNN(nn.Module):
        """Multi-input network: one Tower per modality, concatenated, then an MLP head.

        Args:
            dims: mapping modality key -> input width; its key order fixes the
                concatenation order of the tower outputs.
            output_dim: number of output logits (one per label).
        """

        def __init__(self, dims: dict, output_dim: int):
            super().__init__()
            self.keys = list(dims.keys())

            # Towers are created in key order, before the head.
            self.towers = nn.ModuleDict()
            for name in self.keys:
                self.towers[name] = Tower(dims[name])

            # Each tower emits 512 features (Tower's default out_dim).
            fused_dim = 512 * len(self.keys)

            # Head: two Linear -> BatchNorm -> ReLU -> Dropout stages, then the output layer.
            widths = [fused_dim, 2048, 1024]
            drop_rates = [0.3, 0.2]
            head_layers = []
            for n_in, n_out, rate in zip(widths[:-1], widths[1:], drop_rates):
                head_layers.extend([
                    nn.Linear(n_in, n_out),
                    nn.BatchNorm1d(n_out),
                    nn.ReLU(),
                    nn.Dropout(rate),
                ])
            head_layers.append(nn.Linear(widths[-1], output_dim))
            self.head = nn.Sequential(*head_layers)

        def forward(self, batch: dict):
            """Return raw logits for a dict of per-modality tensors keyed like `dims`."""
            encoded = []
            for name in self.keys:
                encoded.append(self.towers[name](batch[name]))
            fused = torch.cat(encoded, dim=1)
            return self.head(fused)

    # Prepare torch tensors per modality
    # Every selected modality is copied in full onto `device` (train and test),
    # so batches are later formed by plain index selection.
    train_t = {k: torch.tensor(features_train[k], dtype=torch.float32, device=device) for k in DNN_KEYS}
    test_t = {k: torch.tensor(features_test[k], dtype=torch.float32, device=device) for k in DNN_KEYS}

    def _batch_dict(tensors: dict, idx):
        return {k: v[idx] for k, v in tensors.items()}

    # Multi-state ensembling
    # Each seed trains a fresh set of 5 fold models with its own fold split.
    # Within a seed, test predictions are averaged over folds; OOF and test
    # predictions are then averaged over seeds.
    DNN_SEEDS = [42, 43, 44, 45, 46]
    DNN_EPOCHS = 10
    BATCH_SIZE = 256
    DNN_INFER_BATCH = 4096  # rows per forward pass when scoring validation/test rows

    oof_sum = np.zeros(Y.shape, dtype=np.float32)
    test_sum = np.zeros((len(test_ids), Y.shape[1]), dtype=np.float32)
    n_states = len(DNN_SEEDS)

    # Loop-invariant inputs: build them once rather than once per seed/fold
    # (the target tensor in particular is a full host-to-device copy).
    dims = {k: int(features_train[k].shape[1]) for k in DNN_KEYS}
    Y_full_t = torch.tensor(Y, dtype=torch.float32, device=device)

    def _dnn_predict_proba(net, tensors, rows=None):
        """Return sigmoid probabilities of *net* as a numpy array, scored in chunks.

        Args:
            net: model to evaluate; the caller puts it in eval mode under no_grad.
            tensors: dict of per-modality tensors sharing the same first dimension.
            rows: optional 1-D long tensor of row indices; all rows when None.

        Chunking bounds peak activation memory; in eval mode the result does not
        depend on how rows are grouped.
        """
        if rows is None:
            n_rows = int(next(iter(tensors.values())).shape[0])
        else:
            n_rows = int(rows.shape[0])
        chunks = []
        for start in range(0, n_rows, DNN_INFER_BATCH):
            stop = start + DNN_INFER_BATCH
            if rows is None:
                batch = {k: v[start:stop] for k, v in tensors.items()}
            else:
                sel = rows[start:stop]
                batch = {k: v[sel] for k, v in tensors.items()}
            chunks.append(torch.sigmoid(net(batch)).cpu().numpy())
        if not chunks:
            return np.zeros((0, Y.shape[1]), dtype=np.float32)
        return np.vstack(chunks)

    for state_i, seed in enumerate(DNN_SEEDS, 1):
        print(f"\n[DNN] Random state {state_i}/{n_states}: seed={seed}")
        torch.manual_seed(seed)
        np.random.seed(seed)

        kf_state = KFold(n_splits=5, shuffle=True, random_state=seed)
        oof_state = np.zeros(Y.shape, dtype=np.float32)
        test_state = np.zeros((len(test_ids), Y.shape[1]), dtype=np.float32)

        for fold, (idx_tr, idx_val) in enumerate(kf_state.split(train_ids)):
            print(f"DNN Fold {fold+1}/5")
            model = ColossalMultiModalDNN(dims=dims, output_dim=Y.shape[1]).to(device)
            optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

            n_samples = len(idx_tr)

            model.train()
            idx_tr_t = torch.tensor(idx_tr, dtype=torch.long, device=device)
            for _epoch in range(DNN_EPOCHS):
                # Fresh shuffle of the training rows every epoch.
                perm = idx_tr_t[torch.randperm(n_samples, device=device)]
                for i in range(0, n_samples, BATCH_SIZE):
                    b = perm[i:i + BATCH_SIZE]
                    if b.numel() < 2:
                        # BatchNorm1d cannot compute batch statistics from a single
                        # row in train mode and raises; skip such a trailing batch.
                        continue
                    optimizer.zero_grad()
                    logits = model(_batch_dict(train_t, b))
                    yb = Y_full_t[b]
                    # Element-wise BCE, re-weighted per label by the IA weights.
                    loss_el = F.binary_cross_entropy_with_logits(logits, yb, reduction='none')
                    loss = (loss_el * ia_w_t).mean()
                    loss.backward()
                    optimizer.step()

            model.eval()
            with torch.no_grad():
                idx_val_t = torch.tensor(idx_val, dtype=torch.long, device=device)
                val_probs = _dnn_predict_proba(model, train_t, idx_val_t)
                oof_state[idx_val] = val_probs

                # test prediction (average over folds)
                test_probs = _dnn_predict_proba(model, test_t)
                test_state += test_probs / kf_state.get_n_splits()

            val_preds = (val_probs > 0.3).astype(int)
            f1 = f1_score(Y[idx_val], val_preds, average='micro')
            ia_f1 = ia_weighted_f1(Y[idx_val], val_probs, thr=0.3)
            print(f"  >> Fold {fold+1} micro-F1@0.30: {f1:.4f}")
            print(f"  >> Fold {fold+1} IA-F1@0.30: ALL={ia_f1['ALL']:.4f} MF={ia_f1['MF']:.4f} BP={ia_f1['BP']:.4f} CC={ia_f1['CC']:.4f}")

            torch.save(model.state_dict(), ARTEFACTS_DIR / 'features' / f'level1_dnn_seed{seed}_fold{fold}.pth')

        oof_sum += oof_state
        test_sum += test_state

    oof_preds_dnn = (oof_sum / n_states).astype(np.float32)
    test_preds_dnn = (test_sum / n_states).astype(np.float32)

    np.save(ARTEFACTS_DIR / 'features' / 'oof_dnn.npy', oof_preds_dnn)
    np.save(ARTEFACTS_DIR / 'features' / 'test_pred_dnn.npy', test_preds_dnn)
    print("DNN OOF + test preds saved (multi-state averaged).")

    # Diagnostic: IA-F1 vs threshold curve (OOF)
    # Sweeps 23 evenly spaced thresholds in [0.05, 0.60] over the seed-averaged
    # DNN OOF predictions and records the score for ALL and each aspect.
    thrs = np.linspace(0.05, 0.60, 23)
    curves = {k: [] for k in ['ALL', 'MF', 'BP', 'CC']}
    for thr in thrs:
        s = ia_weighted_f1(Y, oof_preds_dnn, thr=float(thr))
        for k in curves.keys():
            curves[k].append(s[k])

    plt.figure(figsize=(10, 3))
    for k in ['ALL', 'MF', 'BP', 'CC']:
        plt.plot(thrs, curves[k], label=k)
    plt.title('DNN OOF: IA-weighted F1 vs threshold (multi-state)')
    plt.xlabel('threshold')
    plt.ylabel('IA-F1')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

    print("Phase 2 Complete. OOF + test predictions generated.")
# Branch taken when the Phase 2 guard at the top of this cell is false.
else:
    print("Skipping Phase 2.")

In [None]:
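# 5. PHASE 3: STACKING (LIGHTWEIGHT MLP ALTERNATIVE)
# ==================================================
# NOTE: This cell is kept as a lightweight alternative stacker: a plain MLP over the
# concatenated Level-1 OOF predictions. It does not use the GO hierarchy.
# The main (hierarchy-aware GCN) stacker used by Phase 4 is implemented in the next cell.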

TRAIN_STACKER = False

if TRAIN_STACKER:
    import torch
    import torch.nn as nn
    from sklearn.metrics import f1_score

    print("Loading OOF Predictions for stacking...")
    # LogReg OOFs are mandatory; a missing file fails loudly here.
    oof_logreg = np.load(ARTEFACTS_DIR / 'features' / 'oof_logreg.npy')

    # GBDT and DNN OOFs are optional: only a *missing* file falls back to an
    # all-zero block. Any other load error (corrupt file, permissions) is
    # allowed to propagate instead of silently becoming zero features.
    try:
        oof_gbdt = np.load(ARTEFACTS_DIR / 'features' / 'oof_gbdt.npy')
    except FileNotFoundError:
        print("oof_gbdt.npy not found; using zeros for the GBDT stream.")
        oof_gbdt = np.zeros_like(oof_logreg)

    try:
        oof_dnn = np.load(ARTEFACTS_DIR / 'features' / 'oof_dnn.npy')
    except FileNotFoundError:
        print("oof_dnn.npy not found; using zeros for the DNN stream.")
        oof_dnn = np.zeros_like(oof_logreg)

    # Stacker input: the three prediction blocks side by side.
    X_stack = np.hstack([oof_logreg, oof_gbdt, oof_dnn])
    # NOTE: relies on `Y` left in the notebook namespace by Phase 2.
    Y_stack = Y

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    X_stack_t = torch.tensor(X_stack, dtype=torch.float32).to(device)
    Y_stack_t = torch.tensor(Y_stack, dtype=torch.float32).to(device)

    class Stacker(nn.Module):
        """Single-hidden-layer MLP stacker that outputs per-label probabilities.

        Layout: Linear(input_dim, 2048) -> ReLU -> Dropout(0.2) -> Linear(2048, output_dim) -> Sigmoid.
        """

        def __init__(self, input_dim, output_dim):
            super().__init__()
            hidden_units = 2048
            layers = [
                nn.Linear(input_dim, hidden_units),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.Linear(hidden_units, output_dim),
                nn.Sigmoid(),
            ]
            # Kept as `net` so saved state_dict keys stay `net.<i>.*`.
            self.net = nn.Sequential(*layers)

        def forward(self, x):
            """Map (batch, input_dim) features to (batch, output_dim) probabilities."""
            probabilities = self.net(x)
            return probabilities

    stacker = Stacker(X_stack.shape[1], Y_stack.shape[1]).to(device)
    optimizer = torch.optim.Adam(stacker.parameters(), lr=1e-3)
    # Plain BCELoss is valid here because Stacker already applies a sigmoid.
    criterion = nn.BCELoss()

    # Full-batch training: every epoch is one optimisation step over all rows.
    stacker.train()
    for epoch in range(20):
        optimizer.zero_grad()
        out = stacker(X_stack_t)
        loss = criterion(out, Y_stack_t)
        loss.backward()
        optimizer.step()

        if epoch % 5 == 0:
            with torch.no_grad():
                # `out` was produced in train mode (dropout active) and only the
                # first 1000 rows are scored, so this F1 is a rough progress signal.
                preds = (out > 0.3).float()
                f1 = f1_score(Y_stack_t.cpu().numpy()[:1000], preds.cpu().numpy()[:1000], average='micro')
            print(f"Epoch {epoch}: Loss {loss.item():.4f}, Approx micro-F1@0.30 {f1:.4f}")

    torch.save(stacker.state_dict(), ARTEFACTS_DIR / 'features' / 'final_stacker.pth')
    print("Saved: final_stacker.pth")

    # Final score is computed on the same rows the stacker was trained on
    # (no hold-out), so it measures fit rather than generalisation.
    stacker.eval()
    with torch.no_grad():
        final_preds = stacker(X_stack_t).cpu().numpy()
        final_f1 = f1_score(Y_stack, (final_preds > 0.3).astype(int), average='micro')
    print(f"Final Stacker micro-F1@0.30: {final_f1:.4f}")
else:
    print("Skipping Phase 3 (lightweight stacker).")

In [None]:
# 5b. PHASE 3: HIERARCHY-AWARE STACKING (GRAPH SMOOTHING GCN)
    # ===========================================================
    # This is the stacker used by Phase 4.
    # - Train on Level-1 OOF predictions (features)
    # - Infer on Level-1 test predictions
    # - Trains 3 specialised models: BP/MF/CC (Rank-1 style)
    # - Save `gcn_stacker_{bp|mf|cc}.pth`, `top_terms_1500.json`, `test_pred_gcn.npy`

TRAIN_GCN = True

if TRAIN_GCN:
    import json
    import torch
    import torch.nn as nn
    from sklearn.metrics import f1_score

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # 1. Load OOF Predictions (Features for GCN)
    # Every Level-1 OOF file that exists is loaded, in the fixed order
    # logreg -> gbdt -> dnn; at least one must be present.
    print("Loading OOF predictions...")
    _oof_paths = [
        ARTEFACTS_DIR / 'features' / 'oof_logreg.npy',
        ARTEFACTS_DIR / 'features' / 'oof_gbdt.npy',
        ARTEFACTS_DIR / 'features' / 'oof_dnn.npy',
    ]
    features_list = [np.load(_p) for _p in _oof_paths if _p.exists()]

    if len(features_list) == 0:
        raise FileNotFoundError("No OOF predictions found. Run Phase 2.")

    # Element-wise mean over the available models -> (N_train, TOP_K)
    X_stack = np.mean(features_list, axis=0).astype(np.float32)
    print(f"Train stack shape: {X_stack.shape}")

    # 2. Rebuild term list + targets
    train_terms = pd.read_parquet(ARTEFACTS_DIR / 'parsed' / 'train_terms.parquet')
    train_ids = pd.read_feather(ARTEFACTS_DIR / 'parsed' / 'train_seq.feather')['id'].astype(str)

    if X_stack.shape[0] != len(train_ids):
        raise ValueError(
            f"OOF stack rows mismatch: {X_stack.shape[0]} vs {len(train_ids)} training IDs"
        )

    # The label vocabulary is the TOP_K most frequent terms, with TOP_K taken
    # from the width of the OOF stack.
    TOP_K = X_stack.shape[1]
    top_terms = train_terms['term'].value_counts().head(TOP_K).index.tolist()

    train_terms_top = train_terms[train_terms['term'].isin(top_terms)]
    Y_df = train_terms_top.pivot_table(index='EntryID', columns='term', aggfunc='size', fill_value=0)
    Y_df = Y_df.reindex(train_ids, fill_value=0)

    # pivot_table returns its columns sorted by term, not in the frequency order
    # produced by value_counts(). Phase 2 builds its targets with the same pivot,
    # so the Level-1 prediction columns follow the pivot order as well. Re-derive
    # `top_terms` from the pivot so that column i of X_stack / Y and top_terms[i]
    # refer to the same GO term everywhere below (priors, aspect split, saved JSON).
    top_terms = Y_df.columns.tolist()
    if len(top_terms) != TOP_K:
        raise ValueError(
            f"Label count mismatch: rebuilt {len(top_terms)} terms vs {TOP_K} OOF columns"
        )

    # aggfunc='size' counts rows; clamp to {0, 1} so duplicated annotations
    # cannot produce BCE targets above 1.
    Y = (Y_df.values > 0).astype(np.float32)

    # 2b. External priors (Phase 1 Step 4 outputs) -> inject as *conservative* extra signal
    # The prior is down-weighted and merged with an element-wise max, so it can
    # only raise a score, never lower one.
    EXTERNAL_PRIOR_WEIGHT = 0.25
    ext_dir = ARTEFACTS_DIR / 'external'
    prior_train_path = ext_dir / 'prop_train_no_kaggle.tsv.gz'
    if PROCESS_EXTERNAL and prior_train_path.exists():
        prior_train = pd.read_csv(prior_train_path, sep='\t')
        prior_train = prior_train[prior_train['term'].isin(top_terms)]
        prior_mat = prior_train.pivot_table(index='EntryID', columns='term', values='score', aggfunc='max', fill_value=0.0)
        # Align rows to train_ids and columns to top_terms; anything missing is 0.
        prior_mat = prior_mat.reindex(train_ids.tolist(), fill_value=0.0)
        prior_mat = prior_mat.reindex(columns=top_terms, fill_value=0.0)
        prior_np = prior_mat.values.astype(np.float32)
        X_stack = np.maximum(X_stack, EXTERNAL_PRIOR_WEIGHT * prior_np)
        print(f"Injected external IEA prior into train stack (weight={EXTERNAL_PRIOR_WEIGHT}).")
    else:
        print("No external train prior found (or PROCESS_EXTERNAL=False); training without GOA priors.")

    # Persist the column -> term mapping for Phase 4. The file name is fixed
    # regardless of the actual TOP_K.
    (ARTEFACTS_DIR / 'features').mkdir(parents=True, exist_ok=True)
    with open(ARTEFACTS_DIR / 'features' / 'top_terms_1500.json', 'w') as f:
        json.dump(top_terms, f)
    print("Saved: top_terms_1500.json")

    # 3. Graph adjacency from go-basic.obo parse (from Phase 1)
    if 'go_parents' not in locals():
        raise RuntimeError("go_parents not found. Run Phase 1 OBO parsing cell first.")

    def build_adjacency(terms_list, parents_dict):
        """Build a sparse, symmetric term-term adjacency with self-loops.

        Args:
            terms_list: terms defining the row/column order of the matrix.
            parents_dict: mapping term -> iterable of parent terms.

        Returns:
            Coalesced (n_terms, n_terms) sparse COO float32 tensor on `device`.
            An edge is added in both directions for each child/parent pair where
            both terms are in `terms_list`. Values are unnormalised.
        """
        position = {term: i for i, term in enumerate(terms_list)}
        n_terms = len(terms_list)

        # Start with the diagonal (self-loops), then add hierarchy edges;
        # coalesce() makes the final entry order irrelevant.
        rows = list(range(n_terms))
        cols = list(range(n_terms))

        for child in terms_list:
            linked = parents_dict.get(child) or ()
            for parent in linked:
                if parent not in position:
                    continue
                c_idx, p_idx = position[child], position[parent]
                rows += [c_idx, p_idx]
                cols += [p_idx, c_idx]

        edge_index = torch.tensor([rows, cols], dtype=torch.long)
        edge_weight = torch.ones(len(rows), dtype=torch.float32)
        adjacency = torch.sparse_coo_tensor(edge_index, edge_weight, (n_terms, n_terms))
        return adjacency.coalesce().to(device)

    class SimpleGCN(nn.Module):
        """Two-layer MLP whose output logits are mixed across neighbouring terms.

        The adjacency is stored as a plain attribute (not a parameter or buffer),
        so it is not part of the state_dict and is not moved by `.to()`.
        """

        def __init__(self, input_dim, hidden_dim, output_dim, adj_matrix):
            super().__init__()
            self.adj = adj_matrix
            self.relu = nn.ReLU()
            self.dropout = nn.Dropout(0.3)
            self.fc1 = nn.Linear(input_dim, hidden_dim)
            self.fc2 = nn.Linear(hidden_dim, output_dim)

        def forward(self, x):
            """Return (batch, output_dim) probabilities for (batch, input_dim) inputs."""
            hidden = self.dropout(self.relu(self.fc1(x)))
            term_logits = self.fc2(hidden)
            # Graph smoothing: each term's logit becomes the adjacency-weighted
            # sum over its neighbours (adj @ logits^T, transposed back).
            smoothed = torch.sparse.mm(self.adj, term_logits.t()).t()
            return torch.sigmoid(smoothed)

    # 4. Build test stack once (and inject external priors once), then split by ontology
    print("\nPreparing test stack...")
    # Load whichever Level-1 test prediction files exist; at least one is required.
    test_feats = []
    for fname in ['test_pred_logreg.npy', 'test_pred_gbdt.npy', 'test_pred_dnn.npy']:
        p = ARTEFACTS_DIR / 'features' / fname
        if p.exists():
            test_feats.append(np.load(p))

    if not test_feats:
        raise FileNotFoundError("No Level-1 test predictions found. Run Phase 2 first.")

    # Same blending rule as the train stack: element-wise mean over available models.
    X_test_stack = np.mean(test_feats, axis=0).astype(np.float32)

    prior_test_path = ext_dir / 'prop_test_no_kaggle.tsv.gz'
    if PROCESS_EXTERNAL and prior_test_path.exists():
        test_ids = pd.read_feather(ARTEFACTS_DIR / 'parsed' / 'test_seq.feather')['id'].astype(str)
        prior_test = pd.read_csv(prior_test_path, sep='\t')
        prior_test = prior_test[prior_test['term'].isin(top_terms)]
        prior_t = prior_test.pivot_table(index='EntryID', columns='term', values='score', aggfunc='max', fill_value=0.0)
        # Align rows to test_ids and columns to top_terms; anything missing is 0.
        prior_t = prior_t.reindex(test_ids.tolist(), fill_value=0.0)
        prior_t = prior_t.reindex(columns=top_terms, fill_value=0.0)
        prior_test_np = prior_t.values.astype(np.float32)
        # Element-wise max: the down-weighted prior can only raise a score.
        X_test_stack = np.maximum(X_test_stack, EXTERNAL_PRIOR_WEIGHT * prior_test_np)
        print(f"Injected external IEA prior into test stack (weight={EXTERNAL_PRIOR_WEIGHT}).")
    else:
        print("No external test prior found (or PROCESS_EXTERNAL=False); inferring without GOA priors.")

    # 5. Ontology split (BP/MF/CC)
    if 'go_namespaces' not in locals():
        raise RuntimeError("go_namespaces not found. Run Phase 1 OBO parsing cell first.")

    ns_to_aspect = {
        'molecular_function': 'MF',
        'biological_process': 'BP',
        'cellular_component': 'CC',
    }
    # aspects[i] is the aspect of top_terms[i], so the index lists below are
    # only meaningful if top_terms matches the column order of the stacks.
    aspects = []
    for t in top_terms:
        asp = ns_to_aspect.get(go_namespaces.get(t, ''), 'BP')  # default BP to keep full coverage
        aspects.append(asp)
    aspects = np.array(aspects)

    # Column indices per aspect; together they cover every column exactly once.
    aspect_to_idx = {
        'BP': np.where(aspects == 'BP')[0].tolist(),
        'MF': np.where(aspects == 'MF')[0].tolist(),
        'CC': np.where(aspects == 'CC')[0].tolist(),
    }
    for k in ['BP', 'MF', 'CC']:
        print(f"Terms[{k}]={len(aspect_to_idx[k])}")

    # 6. Train 3 specialised GCNs and stitch outputs back
    # Output buffer plus full train/test tensors on `device`; per-aspect models
    # take column slices of these.
    test_pred_gcn = np.zeros_like(X_test_stack, dtype=np.float32)
    X_tensor_full = torch.tensor(X_stack, dtype=torch.float32, device=device)
    Y_tensor_full = torch.tensor(Y, dtype=torch.float32, device=device)
    X_test_full = torch.tensor(X_test_stack, dtype=torch.float32, device=device)

    def train_one(aspect_name: str, idx_cols: list[int]):
        """Train and save the GCN stacker for one ontology aspect.

        Args:
            aspect_name: aspect label ('BP', 'MF' or 'CC'); used in log lines and
                in the checkpoint file name (lower-cased).
            idx_cols: column indices of this aspect's terms in the stacked matrices.

        Returns:
            The trained SimpleGCN (left in train mode), or None when `idx_cols`
            is empty. The state_dict is written to
            ``features/gcn_stacker_<aspect>.pth``.
        """
        if not idx_cols:
            print(f"[{aspect_name}] No terms; skipping.")
            return None
        terms_sub = [top_terms[i] for i in idx_cols]
        adj = build_adjacency(terms_sub, go_parents)
        model = SimpleGCN(input_dim=len(idx_cols), hidden_dim=1024, output_dim=len(idx_cols), adj_matrix=adj).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        # SimpleGCN already applies a sigmoid, hence BCELoss on probabilities.
        criterion = nn.BCELoss()

        X_t = X_tensor_full[:, idx_cols]
        Y_t = Y_tensor_full[:, idx_cols]
        n_samples = X_t.shape[0]
        BS = 256
        EPOCHS = 5
        # Number of mini-batches per epoch (ceiling division), for the mean loss.
        n_batches = max(1, -(-n_samples // BS))

        model.train()
        print(f"\n=== Training GCN[{aspect_name}] terms={len(idx_cols)} ===")
        for epoch in range(EPOCHS):
            total_loss = 0.0
            perm = torch.randperm(n_samples, device=device)
            for i in range(0, n_samples, BS):
                b = perm[i:i + BS]
                optimizer.zero_grad()
                out = model(X_t[b])
                loss = criterion(out, Y_t[b])
                loss.backward()
                optimizer.step()
                total_loss += float(loss.item())

            # Progress metric on the first 2000 training rows. Switch to eval mode
            # so dropout does not perturb the diagnostic, then restore train mode.
            model.eval()
            with torch.no_grad():
                pred = (model(X_t[:2000]) > 0.3).float().cpu().numpy()
                f1 = f1_score(Y[:2000, idx_cols], pred, average='micro')
            model.train()
            print(f"Epoch {epoch+1}: loss={total_loss / n_batches:.4f}, micro-F1@0.30≈{f1:.4f}")

        torch.save(model.state_dict(), ARTEFACTS_DIR / 'features' / f'gcn_stacker_{aspect_name.lower()}.pth')
        print(f"Saved: gcn_stacker_{aspect_name.lower()}.pth")
        return model

    # One model per aspect; an entry is None when the aspect has no terms.
    models = {}
    for asp in ['BP', 'MF', 'CC']:
        models[asp] = train_one(asp, aspect_to_idx[asp])

    print("\nInferring test predictions with specialised GCNs...")
    for asp in ['BP', 'MF', 'CC']:
        idx_cols = aspect_to_idx[asp]
        if not idx_cols:
            # No model was trained for an empty aspect; its columns stay 0.
            continue
        model = models[asp]
        model.eval()
        X_te = X_test_full[:, idx_cols]
        preds = []
        BS = 2048
        # Score the test set in chunks of BS rows.
        with torch.no_grad():
            for i in range(0, X_te.shape[0], BS):
                preds.append(model(X_te[i:i + BS]).cpu().numpy())
        pred_sub = np.vstack(preds).astype(np.float32)
        # Stitch this aspect's columns back into the full prediction matrix.
        test_pred_gcn[:, idx_cols] = pred_sub

    np.save(ARTEFACTS_DIR / 'features' / 'test_pred_gcn.npy', test_pred_gcn)
    print("Saved: test_pred_gcn.npy")
else:
    print("Skipping GCN training.")

In [None]:
# 6. PHASE 4: POST-PROCESSING & SUBMISSION
# ========================================
# HARDWARE: CPU / GPU
# ========================================

# This phase applies the "Strict Post-Processing" rules (Max/Min Propagation)
# and generates the final submission file.

import json
from pathlib import Path

import numpy as np
import pandas as pd

# Check if submission already exists
# (idempotency guard: the whole phase is skipped when the file is present, so
# delete ARTEFACTS_DIR / 'submission.tsv' to force regeneration)
if (ARTEFACTS_DIR / 'submission.tsv').exists():
    print("submission.tsv already exists. Skipping Phase 4.")
else:
    print("Starting Phase 4: Post-processing & submission...")

    # Ensure go_parents is available (from Phase 1)
    # This cell runs at module level, where locals() is the notebook's global
    # namespace, so the check detects variables left behind by earlier cells.
    if 'go_parents' not in locals() or 'go_namespaces' not in locals():
        print("Reloading GO graph (parse_obo)...")

        def parse_obo(path: Path):
            """Parse the ``[Term]`` stanzas of an OBO file.

            Args:
                path: Location of the OBO file (read as UTF-8).

            Returns:
                A ``(parents, namespaces)`` tuple where ``parents`` maps a GO id
                to the set of ids named on its ``is_a:`` lines and
                ``namespaces`` maps a GO id to the value of its ``namespace:``
                line. Terms without ``is_a:`` lines are absent from ``parents``;
                terms without a namespace are absent from ``namespaces``.
            """
            parents = {}
            namespaces = {}
            cur_id, cur_ns = None, None
            in_term = False

            with path.open('r', encoding='utf-8') as f:
                for raw_line in f:
                    line = raw_line.strip()

                    # Any stanza header ([Term], [Typedef], [Instance], ...)
                    # closes the previous stanza. Resetting only on [Term]
                    # would attribute the `namespace:` / `is_a:` lines of a
                    # trailing [Typedef] stanza to the last GO term.
                    if line.startswith('[') and line.endswith(']'):
                        if cur_id and cur_ns:
                            namespaces[cur_id] = cur_ns
                        cur_id, cur_ns = None, None
                        in_term = line == '[Term]'
                        continue

                    # Header lines and non-Term stanzas carry nothing we need.
                    if not in_term:
                        continue

                    if line.startswith('id: GO:'):
                        cur_id = line[len('id:'):].strip()
                    elif line.startswith('namespace:'):
                        cur_ns = line[len('namespace:'):].strip()
                    elif line.startswith('is_a:') and cur_id:
                        # Drop the "! name" comment, then keep only the first
                        # token so trailing modifiers such as {source="..."}
                        # do not end up inside the parent id.
                        tokens = line[len('is_a:'):].split('!', 1)[0].split()
                        if tokens:
                            parents.setdefault(cur_id, set()).add(tokens[0])

                # Flush the final stanza (no header follows it).
                if cur_id and cur_ns:
                    namespaces[cur_id] = cur_ns
            return parents, namespaces

        # Rebuild both lookups from the OBO file.
        # NOTE(review): PATH_GO_OBO is not defined in this cell; presumably it
        # comes from the configuration cell — confirm it is set before running.
        go_parents, go_namespaces = parse_obo(PATH_GO_OBO)

    # Load test IDs
    test_ids = pd.read_feather(ARTEFACTS_DIR / 'parsed' / 'test_seq.feather')['id']

    # Load stacker predictions
    pred_path = ARTEFACTS_DIR / 'features' / 'test_pred_gcn.npy'
    if not pred_path.exists():
        raise FileNotFoundError("Missing `test_pred_gcn.npy`. Run Phase 3 (GCN stacker) first.")
    preds = np.load(pred_path)

    # Fail fast on a stale or malformed artefact: rows are paired with test_ids
    # purely by position further down, so a row-count mismatch must be caught
    # before the (expensive) propagation step rather than after it.
    if preds.ndim != 2:
        raise ValueError(f"Expected a 2-D prediction matrix, got shape {preds.shape}.")
    if preds.shape[0] != len(test_ids):
        raise ValueError(
            f"Row mismatch: preds has {preds.shape[0]} rows, test_seq.feather has {len(test_ids)} ids."
        )

    # Load term list (must match Phase 3)
    terms_path = ARTEFACTS_DIR / 'features' / 'top_terms_1500.json'
    if terms_path.exists():
        with open(terms_path, 'r', encoding='utf-8') as f:
            top_terms = json.load(f)
    else:
        # Fallback: most frequent training terms, truncated to the number of
        # prediction columns. The order may differ from the one used in Phase 3.
        print("Warning: top_terms_1500.json missing; rebuilding from train_terms counts (may mismatch Phase 3).")
        train_terms = pd.read_parquet(ARTEFACTS_DIR / 'parsed' / 'train_terms.parquet')
        top_terms = train_terms['term'].value_counts().head(preds.shape[1]).index.tolist()

    # Every prediction column needs exactly one term label.
    if preds.shape[1] != len(top_terms):
        raise ValueError(f"Shape mismatch: preds has {preds.shape[1]} terms, top_terms has {len(top_terms)}.")

    # ------------------------------------------
    # Strict post-processing (Max/Min Propagation)
    # ------------------------------------------
    print(f"Applying hierarchy rules on {len(top_terms)} terms...")

    # Work on a private column-major copy: column slices are contiguous, and we
    # avoid thousands of per-column DataFrame assignments.
    scores = np.array(preds, copy=True, order='F')
    col_of = {term: j for j, term in enumerate(top_terms)}

    term_set = set(top_terms)
    term_to_parents = {}
    term_to_children = {}

    # Keep only is_a edges whose both endpoints are predicted terms.
    for term in top_terms:
        parents = term_set.intersection(go_parents.get(term, ()))
        if not parents:
            continue
        term_to_parents[term] = list(parents)
        for p in parents:
            term_to_children.setdefault(p, []).append(term)

    # (child column, parent column) index pairs for every retained edge.
    edges = [
        (col_of[child], col_of[parent])
        for child, parents in term_to_parents.items()
        for parent in parents
    ]

    # A fixed number of sweeps in arbitrary edge order is not enough for chains
    # deeper than the sweep count, so sweep until nothing changes. The value
    # can travel at most one edge per sweep, hence len(top_terms) is a safe cap.
    max_sweeps = max(1, len(top_terms))

    # Max Propagation (Child -> Parent)
    for _ in range(max_sweeps):
        changed = False
        for c, p in edges:
            child_scores = scores[:, c]
            if (child_scores > scores[:, p]).any():
                np.maximum(scores[:, p], child_scores, out=scores[:, p])
                changed = True
        if not changed:
            break

    # Min Propagation (Parent -> Child)
    # Once the max pass has converged every parent already dominates its
    # children, so this pass only acts as a safety net and exits after one sweep.
    for _ in range(max_sweeps):
        changed = False
        for c, p in edges:
            parent_scores = scores[:, p]
            if (scores[:, c] > parent_scores).any():
                np.minimum(scores[:, c], parent_scores, out=scores[:, c])
                changed = True
        if not changed:
            break

    df_pred = pd.DataFrame(scores, columns=top_terms)

    # ------------------------------------------
    # Submission formatting (CAFA rules)
    # - tab-separated, no header
    # - score in (0, 1.000]
    # - up to 3 significant figures
    # - <= 1500 terms per target (MF/BP/CC combined)
    # ------------------------------------------
    # Pull the dense score matrix out first. Melting the whole frame would
    # materialise n_proteins x n_terms rows before any filtering, so the long
    # table is built only from the entries that survive the threshold.
    term_cols = [c for c in df_pred.columns if c != 'EntryID']
    score_matrix = df_pred[term_cols].to_numpy()
    term_labels = np.asarray(term_cols, dtype=object)
    entry_ids = np.asarray(test_ids.values, dtype=object)

    if score_matrix.shape[0] != len(entry_ids):
        raise ValueError(
            f"Row mismatch: {score_matrix.shape[0]} prediction rows vs {len(entry_ids)} test ids."
        )

    # Kept for backward compatibility with code that expects the id column.
    df_pred['EntryID'] = test_ids.values

    # Light pruning (keeps file size sane; still rule-compliant).
    # `>= 0.001` also removes zeros, negatives and NaN in one comparison.
    keep_rows, keep_cols = np.nonzero(score_matrix >= 0.001)

    submission = pd.DataFrame({
        'EntryID': entry_ids[keep_rows],
        'term': term_labels[keep_cols],
        # Enforce the upper bound of the score range.
        'score': np.minimum(score_matrix[keep_rows, keep_cols], 1.0),
    })

    # Keep top 1500 per protein (rule)
    submission = submission.sort_values(['EntryID', 'score'], ascending=[True, False])
    submission = submission.groupby('EntryID', sort=False).head(1500)

    # Write with <= 3 significant figures
    submission.to_csv(
        ARTEFACTS_DIR / 'submission.tsv',
        sep='\t',
        index=False,
        header=False,
        float_format='%.3g',
    )

    print(f"Done! Submission saved to {ARTEFACTS_DIR / 'submission.tsv'}")

In [None]:
# 7. PHASE 5: FREE TEXT PREDICTION (OPTIONAL)
# ==========================================
# HARDWARE: CPU
# ==========================================

# Official CAFA constraints (summary):
# - Combined file (GO + Text) allowed
# - Text: up to 5 lines per protein; ASCII printable; no tabs; <=3000 chars per protein total
# - Scores should be in (0, 1.000] and up to 3 significant figures

# Resolve both artefact locations once; the guards below only test existence.
text_submission_path = ARTEFACTS_DIR / 'submission_with_text.tsv'
go_submission_path = ARTEFACTS_DIR / 'submission.tsv'

if text_submission_path.exists():
    print("submission_with_text.tsv already exists. Skipping Phase 5.")
elif not go_submission_path.exists():
    print("submission.tsv not found. Please run Phase 4 first.")
else:
    print("Starting Phase 5: Text Generation...")

    # 1. Load Submission & GO Graph
    print("Loading submission and GO data...")
    # The Phase 4 file has no header row, so column names are supplied here.
    submission_columns = ['EntryID', 'term', 'score']
    submission = pd.read_csv(
        go_submission_path,
        sep='\t',
        header=None,
        names=submission_columns,
    )

    # Reuse a graph left behind by an earlier cell; otherwise parse the OBO file.
    if 'graph' not in locals():
        import obonet
        graph = obonet.read_obo(PATH_GO_OBO)

    # 2. Generate Text Descriptions
    print("Generating descriptions...")

    # Pre-fetch term names to avoid graph lookups in loop
    term_names = {node: data.get('name', 'unknown function') for node, data in graph.nodes(data=True)}

    text_rows = []
    unique_ids = submission['EntryID'].unique()

    # A single groupby pass replaces re-filtering the whole frame once per
    # protein (which is O(proteins x rows)). sort=False keeps proteins in
    # first-appearance order and rows in file order within each group.
    for protein_id, prot_preds in tqdm(
        submission.groupby('EntryID', sort=False),
        total=len(unique_ids),
        desc="Generating Text",
    ):
        # Groups are never empty, so every protein seen here has GO lines.
        top_go = prot_preds.sort_values('score', ascending=False).head(3)

        # Fall back to the raw GO id when the graph has no entry for a term.
        term_descs = [term_names.get(term_id, term_id) for term_id in top_go['term']]

        joined_terms = ", ".join(term_descs)
        description = f"{protein_id} is predicted to be involved in: {joined_terms}."

        # Text must be printable ASCII without tabs: anything outside the
        # ' '..'~' range (tabs, newlines, non-ASCII) becomes a space.
        description = ''.join(ch if ' ' <= ch <= '~' else ' ' for ch in description)

        # Score: strictly > 0 and <= 1
        score = float(top_go['score'].iloc[0])
        score = min(max(score, 0.001), 1.0)

        # One line per protein (<=5 allowed)
        text_rows.append({
            'EntryID': protein_id,
            'term': 'Text',
            'score': score,
            'description': description,
        })

    df_text = pd.DataFrame(text_rows)

    print("Saving combined submission...")

    combined_path = ARTEFACTS_DIR / 'submission_with_text.tsv'
    with combined_path.open('w', encoding='utf-8') as out:
        # GO predictions first: three columns, same number format as Phase 4.
        submission.to_csv(out, sep='\t', index=False, header=False, float_format='%.3g')

        # Then the free-text predictions: four columns, score limited to
        # three significant figures. An empty df_text yields no lines.
        out.writelines(
            f"{rec['EntryID']}\tText\t{float(rec['score']):.3g}\t{rec['description']}\n"
            for rec in df_text.to_dict('records')
        )

    print(f"Done! Combined submission saved to {ARTEFACTS_DIR / 'submission_with_text.tsv'}")