In [None]:
import os
import numpy as np
from collections import defaultdict, deque
from tqdm.auto import tqdm
import pickle
from pathlib import Path

# ==================== CONFIGURATION ====================
# Kaggle dataset mount points for the inputs of this notebook.
COMPETITION_DATA = '/kaggle/input/cafa-6-protein-function-prediction'
PREDICTION_DATA = '/kaggle/input/cafa6-goa-predictions'
GOA = '/kaggle/input/cafa-6-goa-prott5-ensemble-0-370-f2fcb6/submission.tsv'

# Tunable knobs read by the ensembling, propagation and output stages.
CONFIG = dict(
    # Linear blend weights applied when GOA and ProtT5 both score a term.
    WEIGHT_GOA=0.68,
    WEIGHT_PROTT5=0.32,

    # Shaping of terms scored by ProtT5 only: score ** GAMMA, then multiplied
    # by SUPPORTED (an ancestor appears in the GOA set) or BASE (otherwise).
    # AGREE_BONUS is added (scaled by the smaller score) when both sources agree.
    PROTT5_ONLY_PENALTY_BASE=0.78,
    PROTT5_ONLY_PENALTY_SUPPORTED=0.90,
    PROTT5_ONLY_GAMMA=1.15,
    AGREE_BONUS=0.03,

    # Ontology propagation and final score reshaping.
    UP_DECAY=0.985,
    UP_MIN_SCORE=0.015,
    CAP_CHILD_BY_PARENT=True,
    POWER=0.92,
    DO_PER_PROTEIN_MAX_NORM=False,  # not read anywhere in the visible code

    # Per-protein output filtering.
    TOP_K=250,
    SCORE_THRESHOLD=0.0005,
)

# GO root terms (molecular_function, biological_process, cellular_component);
# always emitted with score 1.0.
ROOTS = {
    "GO:0003674",
    "GO:0008150",
    "GO:0005575",
}

# ==================== CACHED ONTOLOGY LOADER ====================
class OntologyCache:
    """Parse a GO ``.obo`` file into parent/ancestor maps, caching the result on disk.

    The cache is a pickle written to ``_cache_file`` (relative to the current
    working directory). It records which ontology file it was built from
    (absolute path + modification time), so a cache built from a different or
    updated ``.obo`` file is rebuilt instead of being silently reused.
    """
    _cache_file = "ontology_cache.pkl"
    # Bump whenever the cached payload layout changes; older caches are rebuilt.
    _cache_version = 2

    @classmethod
    def load_or_create(cls, obo_path):
        """Return ``(term_parents, ancestors_map)`` for the ontology at *obo_path*.

        Args:
            obo_path: path to an OBO-format ontology file.

        Returns:
            term_parents: dict mapping a term id to the set of its direct parents
                taken from ``is_a:`` and ``relationship: part_of`` lines. Only
                ``[Term]`` stanzas contribute; terms without parents are absent.
            ancestors_map: dict mapping every key of ``term_parents`` to the set
                of all its transitive ancestors.

        Raises:
            OSError: if no matching cache exists and *obo_path* cannot be read.
        """
        source_key = cls._source_key(obo_path)
        cached = cls._read_cache(source_key)
        if cached is not None:
            print("Loading cached ontology...")
            return cached

        print("Parsing Ontology (caching for future runs)...")
        term_parents = cls._parse_obo(obo_path)
        ancestors_map = cls._build_ancestors(term_parents)
        cls._write_cache(source_key, term_parents, ancestors_map)
        return term_parents, ancestors_map

    @staticmethod
    def _source_key(obo_path):
        """Identify the ontology file (absolute path + mtime) the cache must match."""
        path = os.path.abspath(obo_path)
        try:
            return (path, os.stat(path).st_mtime_ns)
        except OSError:
            # Unreadable/missing file: never matches a cache, so parsing will
            # raise the real error.
            return (path, None)

    @classmethod
    def _read_cache(cls, source_key):
        """Return cached ``(term_parents, ancestors_map)`` if valid for *source_key*, else None."""
        cache_path = Path(cls._cache_file)
        if not cache_path.exists():
            return None
        # NOTE: pickle.load can execute arbitrary code; acceptable only because
        # this file is written by _write_cache below, never taken from users.
        try:
            with open(cache_path, 'rb') as f:
                payload = pickle.load(f)
        except (OSError, EOFError, pickle.UnpicklingError,
                AttributeError, ImportError, ValueError):
            return None  # corrupt/unreadable cache: rebuild it
        if (not isinstance(payload, dict)
                or payload.get("version") != cls._cache_version
                or payload.get("source") != source_key):
            return None  # old format, or built from another/updated .obo file
        return payload["term_parents"], payload["ancestors_map"]

    @classmethod
    def _write_cache(cls, source_key, term_parents, ancestors_map):
        """Persist the parsed maps; failure to cache is reported but not fatal."""
        payload = {
            "version": cls._cache_version,
            "source": source_key,
            "term_parents": term_parents,
            "ancestors_map": ancestors_map,
        }
        try:
            with open(cls._cache_file, 'wb') as f:
                pickle.dump(payload, f)
        except OSError as err:
            print(f"Warning: could not write ontology cache {cls._cache_file}: {err}")

    @staticmethod
    def _parse_obo(obo_path):
        """Collect direct ``is_a`` / ``part_of`` parents for every ``[Term]`` stanza."""
        term_parents = defaultdict(set)
        current_id = None
        in_term = False
        with open(obo_path, 'r', encoding='utf-8') as f:
            for raw_line in f:
                line = raw_line.strip()
                if line.startswith('[') and line.endswith(']'):
                    # New stanza. [Typedef] stanzas also carry id:/is_a: lines
                    # (relation ids, not GO terms), so only [Term] is tracked.
                    in_term = line == '[Term]'
                    current_id = None
                    continue
                if not in_term:
                    continue
                if line.startswith("id: "):
                    current_id = line[4:].strip()
                elif not current_id:
                    continue
                elif line.startswith("is_a: "):
                    fields = line.split()
                    if len(fields) >= 2:
                        term_parents[current_id].add(fields[1])
                elif line.startswith("relationship: part_of "):
                    fields = line.split()
                    if len(fields) >= 3:
                        term_parents[current_id].add(fields[2])
        return dict(term_parents)

    @staticmethod
    def _build_ancestors(term_parents):
        """Map each term in *term_parents* to the set of all its transitive ancestors.

        Iterative DFS: each ancestor is expanded at most once per term, and
        closures already computed for earlier terms are reused wholesale. As a
        result, nodes reachable by several paths (diamonds in the DAG) are not
        re-walked once per path.
        """
        ancestors_map = {}
        for term, parents in term_parents.items():
            ancestors = set()
            stack = list(parents)
            while stack:
                t = stack.pop()
                if t in ancestors:
                    continue  # already expanded, or inside a reused closure
                ancestors.add(t)
                known = ancestors_map.get(t)
                if known is not None:
                    # A finished closure already contains everything above t.
                    ancestors |= known
                else:
                    stack.extend(term_parents.get(t, ()))
            ancestors_map[term] = ancestors
        return ancestors_map

# ==================== OPTIMIZED DATA LOADER ====================
def load_predictions_optimized(filepath, use_numpy=True):
    """Load a tab-separated ``protein, GO term, score`` file into nested dicts.

    Args:
        filepath: path to the predictions TSV. Columns after the third are
            ignored; lines with fewer than three columns are skipped.
        use_numpy: unused; kept only so existing callers keep working.

    Returns:
        dict mapping protein id -> {GO term -> score}. When a (protein, term)
        pair occurs more than once the highest score is kept; stored scores are
        floored at 0.0 by the merge rule below.

    Lines whose score column is not numeric (e.g. a header row) are skipped
    and reported once, instead of aborting the whole load with ValueError.
    """
    data = {}
    skipped = 0
    with open(filepath, 'r', encoding='utf-8') as f:
        for line in tqdm(f, desc=f"Loading {os.path.basename(filepath)}"):
            parts = line.rstrip("\n").split("\t")
            if len(parts) < 3:
                continue
            try:
                score = float(parts[2])
            except ValueError:
                skipped += 1
                continue
            terms = data.setdefault(parts[0], {})
            go = parts[1]
            # Same merge rule as before: max of existing (default 0.0) and new score.
            terms[go] = max(terms.get(go, 0.0), score)

    if skipped:
        print(f"Skipped {skipped:,} line(s) with a non-numeric score in {filepath}")
    return data

# ==================== VECTORIZED ENSEMBLING ====================
class ProteinEnsembler:
    """Blend GOA and ProtT5 GO-term scores per protein and propagate them over the ontology.

    Constructor args:
        config: mapping providing WEIGHT_GOA, WEIGHT_PROTT5, AGREE_BONUS,
            PROTT5_ONLY_GAMMA, PROTT5_ONLY_PENALTY_SUPPORTED,
            PROTT5_ONLY_PENALTY_BASE, UP_DECAY, UP_MIN_SCORE,
            CAP_CHILD_BY_PARENT and POWER.
        term_parents: dict term -> set of direct parent terms.
        ancestors_map: dict term -> set of all transitive ancestors. Assumed to
            be the transitive closure of ``term_parents``; it is used to order
            terms parents-before-children.
        roots: terms that are never propagated from or capped, and are pinned
            to 1.0 in the output. Defaults to the module-level ``ROOTS``.
    """

    def __init__(self, config, term_parents, ancestors_map, roots=None):
        self.config = config
        self.term_parents = term_parents
        self.ancestors_map = ancestors_map
        self.roots = set(ROOTS if roots is None else roots)

    def _depth_key(self, term):
        """Sort key ranking every ancestor strictly before its descendants.

        In a DAG a parent's ancestor set is a strict subset of its child's, so
        the ancestor count is a valid topological rank. The term id breaks ties
        so the order is deterministic.
        """
        return (len(self.ancestors_map.get(term, ())), term)

    def ensemble_protein(self, pid, goa, pt):
        """Merge one protein's GOA (*goa*) and ProtT5 (*pt*) ``{term: score}`` dicts.

        For each ProtT5 term the candidate score is computed as follows:
          * If GOA also scores the term: the weighted blend plus
            AGREE_BONUS * min(both), capped at 1.0.
          * Otherwise: ``s_pt ** PROTT5_ONLY_GAMMA`` times the SUPPORTED penalty
            when any ancestor of the term is in *goa*, else the BASE penalty.

        The merged score is the max of the GOA score and that candidate.
        *pid* is unused and kept for interface compatibility.
        Returns an empty dict when both inputs are empty.
        """
        if not goa and not pt:
            return {}

        cfg = self.config
        w_goa, w_pt = cfg["WEIGHT_GOA"], cfg["WEIGHT_PROTT5"]
        bonus = cfg["AGREE_BONUS"]
        gamma = cfg["PROTT5_ONLY_GAMMA"]
        pen_supported = cfg["PROTT5_ONLY_PENALTY_SUPPORTED"]
        pen_base = cfg["PROTT5_ONLY_PENALTY_BASE"]
        goa_terms = goa.keys()

        merged = dict(goa)
        for term, s_pt in pt.items():
            s_goa = goa.get(term, 0.0)
            if s_goa > 0:
                # Both sources agree on this term.
                s = min(1.0, w_goa * s_goa + w_pt * s_pt + bonus * min(s_goa, s_pt))
            else:
                # ProtT5-only: penalise less when GOA backs an ancestor term.
                supported = not goa_terms.isdisjoint(self.ancestors_map.get(term, ()))
                s = (s_pt ** gamma) * (pen_supported if supported else pen_base)
            if s > merged.get(term, 0.0):
                merged[term] = s
        return merged

    def propagate_scores(self, scores):
        """Propagate *scores* up the ontology, soft-cap children, reshape, add roots.

        Steps:
          1. Upward propagation. Children are processed before parents, so
             every ancestor receives ``max(own, child * UP_DECAY)`` along the
             whole path to the roots, not just the direct parents. Terms below
             UP_MIN_SCORE, and root terms, do not propagate further.
          2. If CAP_CHILD_BY_PARENT is set, parents are processed before
             children. A child scoring above its best present parent is pulled
             halfway towards it.
          3. Every non-root score is raised to POWER (when POWER != 1.0).
          4. Root terms are set to 1.0.

        Returns only the roots (at 1.0) when *scores* is empty.
        """
        roots = self.roots
        if not scores:
            return {r: 1.0 for r in roots}

        cfg = self.config
        updated = dict(scores)

        # 1) Upward propagation over every ancestor of the scored terms.
        decay = cfg["UP_DECAY"]
        min_score = cfg["UP_MIN_SCORE"]
        candidates = set(updated)
        for term in scores:
            candidates.update(self.ancestors_map.get(term, ()))
        for term in sorted(candidates, key=self._depth_key, reverse=True):
            s_term = updated.get(term)
            if s_term is None or s_term < min_score or term in roots:
                continue
            s_parent = s_term * decay
            for parent in self.term_parents.get(term, ()):
                if s_parent > updated.get(parent, 0.0):
                    updated[parent] = s_parent

        # 2) Soft true-path cap. Parents are finalised before their children
        #    so the result does not depend on dict insertion order.
        if cfg["CAP_CHILD_BY_PARENT"]:
            for term in sorted(updated, key=self._depth_key):
                if term in roots:
                    continue
                parents = self.term_parents.get(term)
                if not parents:
                    continue
                parent_scores = [updated[p] for p in parents if p in updated]
                if not parent_scores:
                    continue
                best_parent = max(parent_scores)
                score = updated[term]
                if score > best_parent:
                    updated[term] = 0.5 * score + 0.5 * best_parent

        # 3) Power reshaping of non-root scores.
        power = cfg["POWER"]
        if power != 1.0:
            updated = {t: (s if t in roots else s ** power) for t, s in updated.items()}

        # 4) Roots are always certain.
        for r in roots:
            updated[r] = 1.0
        return updated

# ==================== QUICK PREDICTION FUNCTION ====================
def make_quick_predictions(proteins_to_predict=None, limit=None, output_file="submission.tsv"):
    """Ensemble GOA + ProtT5 predictions, propagate them, and write a submission TSV.

    Args:
        proteins_to_predict: optional iterable of protein ids to score.
            ``None`` scores every protein present in either prediction file;
            an empty iterable scores none. Ids absent from both files are
            ignored, and duplicate ids are scored once.
        limit: optional maximum number of proteins to process, applied after
            selection (``0`` processes none).
        output_file: path of the TSV to write (overwritten); ``None`` skips writing.

    Returns:
        list of tab-separated ``pid, term, score`` strings, grouped by protein,
        highest score first, at most TOP_K per protein.
    """
    # Load ontology
    obo_path = f"{COMPETITION_DATA}/Train/go-basic.obo"
    term_parents, ancestors_map = OntologyCache.load_or_create(obo_path)

    # Both files are always loaded in full; the protein selection below is
    # applied afterwards.
    print("Loading predictions...")
    goa_preds = load_predictions_optimized(GOA)
    prott5_preds = load_predictions_optimized(f"{PREDICTION_DATA}/prott5_interpro_predictions.tsv")

    ensembler = ProteinEnsembler(CONFIG, term_parents, ancestors_map)

    # Determine which proteins to process.
    all_proteins = goa_preds.keys() | prott5_preds.keys()
    if proteins_to_predict is None:
        # Sorted so that `limit` selects the same proteins on every run
        # (iteration order of a set of str varies between interpreter runs).
        proteins_to_process = sorted(all_proteins)
    else:
        # dict.fromkeys de-duplicates while keeping the caller's order.
        proteins_to_process = [p for p in dict.fromkeys(proteins_to_predict) if p in all_proteins]

    if limit is not None:
        proteins_to_process = proteins_to_process[:limit]

    print(f"Processing {len(proteins_to_process)} proteins...")

    threshold = CONFIG["SCORE_THRESHOLD"]
    top_k = CONFIG["TOP_K"]
    final_rows = []
    for pid in tqdm(proteins_to_process, desc="Ensembling & Propagating"):
        ensemble_scores = ensembler.ensemble_protein(
            pid, goa_preds.get(pid, {}), prott5_preds.get(pid, {})
        )
        propagated = ensembler.propagate_scores(ensemble_scores)

        # Keep terms above the threshold, best first, at most TOP_K of them.
        items = [(t, s) for t, s in propagated.items() if s >= threshold]
        items.sort(key=lambda x: x[1], reverse=True)
        final_rows.extend(f"{pid}\t{term}\t{score:.5f}" for term, score in items[:top_k])

    if output_file is not None:
        with open(output_file, "w") as f:
            f.writelines(f"{row}\n" for row in final_rows)
        print(f"File saved: {output_file}")
    print(f"Total Predictions: {len(final_rows):,}")

    return final_rows

# ==================== UTILITY FOR BATCH PROCESSING ====================
def batch_predict(protein_batches, batch_size=1000):
    """Score a list of protein ids in consecutive slices and write one combined submission.

    Args:
        protein_batches: sequence (supports len() and slicing) of protein ids.
            Despite the name, this is a flat list of ids, not a list of batches.
        batch_size: number of ids passed to each make_quick_predictions call;
            must be positive.

    Returns:
        list of all output rows, in batch order.

    Raises:
        ValueError: if batch_size is not positive.

    NOTE: make_quick_predictions reloads the ontology and both prediction files
    on every call, so slicing here does not lower peak memory. It only splits
    the scoring work.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    total = len(protein_batches)
    n_batches = (total - 1) // batch_size + 1
    all_predictions = []

    for batch_no, start in enumerate(range(0, total, batch_size), start=1):
        print(f"Processing batch {batch_no}/{n_batches}")
        batch = protein_batches[start:start + batch_size]
        all_predictions.extend(make_quick_predictions(batch))

    if n_batches > 0:
        # Each make_quick_predictions call rewrites submission.tsv with only its
        # own batch; rewrite it once so the file holds every batch's rows.
        output_file = "submission.tsv"
        with open(output_file, "w") as f:
            f.writelines(f"{row}\n" for row in all_predictions)
        print(f"Combined file saved: {output_file} ({len(all_predictions):,} rows)")

    return all_predictions

# ==================== MAIN EXECUTION ====================
if __name__ == "__main__":
    # Default run: score every protein found in either prediction file and
    # write submission.tsv.
    make_quick_predictions()

    # Other entry points (uncomment as needed):
    #   make_quick_predictions(["protein_id_1", "protein_id_2"])  # a few specific proteins
    #   batch_predict([...], batch_size=5000)                     # a long id list in slices

In [None]:
# import os
# import gc
# import pandas as pd
# import numpy as np
# from collections import defaultdict
# from tqdm.auto import tqdm
# from multiprocessing import Pool, cpu_count
# import warnings
# warnings.filterwarnings('ignore')

# # ==========================================
# # 1. ULTRA-FAST OBO PARSER (Vectorized)
# # ==========================================
# def parse_obo_parents(go_obo_path):
#     """Optimized OBO parser using list operations"""
#     print(f"[1/5] Parsing OBO Ontology...")
#     term_parents = {}
#     roots = {'GO:0003674', 'GO:0008150', 'GO:0005575'}
    
#     with open(go_obo_path, "r") as f:
#         lines = f.readlines()
    
#     cur_id = None
#     parent_list = []
    
#     for line in lines:
#         line = line.strip()
#         if line == "[Term]":
#             if cur_id and parent_list:
#                 term_parents[cur_id] = set(parent_list)
#             cur_id = None
#             parent_list = []
#         elif line.startswith("id: "):
#             cur_id = line[4:].strip()
#         elif line.startswith("is_a: "):
#             parent_list.append(line.split()[1])
#         elif line.startswith("relationship: part_of "):
#             parts = line.split()
#             if len(parts) >= 3:
#                 parent_list.append(parts[2])
    
#     if cur_id and parent_list:
#         term_parents[cur_id] = set(parent_list)
    
#     return term_parents, roots

# def get_ancestors_map_fast(term_parents):
#     """Iterative BFS approach - much faster than recursion"""
#     print("[1/5] Building Ancestor Map (Fast BFS)...")
#     ancestors = {}
    
#     # Process in batches
#     all_terms = list(term_parents.keys())
    
#     for term in tqdm(all_terms, desc="Computing ancestors"):
#         if term in ancestors:
#             continue
            
#         visited = set()
#         queue = list(term_parents.get(term, set()))
        
#         while queue:
#             current = queue.pop(0)
#             if current in visited:
#                 continue
#             visited.add(current)
#             queue.extend(term_parents.get(current, set()))
        
#         ancestors[term] = visited
    
#     return ancestors

# # ==========================================
# # 2. VECTORIZED PROCESSING (NumPy Based)
# # ==========================================
# def process_predictions_vectorized(df, ancestors_map, roots):
#     """Ultra-fast vectorized processing using NumPy and dict operations"""
#     print("[3/5] Processing Predictions (Vectorized Approach)...")
    
#     # Convert to numpy for speed
#     proteins = df['protein_id'].values
#     terms = df['go_term'].values
#     scores = df['score'].values.astype(np.float32)
    
#     # Group by protein using dict
#     protein_data = defaultdict(lambda: defaultdict(float))
    
#     print("   → Grouping data...")
#     for i in tqdm(range(len(proteins)), desc="Grouping", disable=True):
#         p, t, s = proteins[i], terms[i], scores[i]
#         protein_data[p][t] = max(protein_data[p][t], s)
    
#     # Pre-allocate result lists
#     result_proteins = []
#     result_terms = []
#     result_scores = []
    
#     print("   → Propagating & normalizing...")
#     for pid, terms_dict in tqdm(protein_data.items(), desc="Processing"):
#         final_scores = {}
        
#         # Copy original scores
#         for t, s in terms_dict.items():
#             final_scores[t] = s
        
#         # PROPAGATION: Parent >= Child
#         for term, score in terms_dict.items():
#             if term in ancestors_map:
#                 for anc in ancestors_map[term]:
#                     final_scores[anc] = max(final_scores.get(anc, 0.0), score)
        
#         # FORCE ROOTS to 1.0
#         if final_scores:
#             for r in roots:
#                 final_scores[r] = 1.0
        
#         # NORMALIZATION: Boost max to 0.95
#         non_root_scores = [s for t, s in final_scores.items() if t not in roots]
#         if non_root_scores:
#             max_val = max(non_root_scores)
#             if 0 < max_val < 0.95:
#                 scale = 0.95 / max_val
#                 for t in final_scores:
#                     if t not in roots:
#                         final_scores[t] = min(1.0, final_scores[t] * scale)
        
#         # Collect results (filter low scores)
#         for go_term, score in final_scores.items():
#             if score >= 0.001:
#                 result_proteins.append(pid)
#                 result_terms.append(go_term)
#                 result_scores.append(score)
    
#     # Create DataFrame efficiently
#     print("   → Creating output DataFrame...")
#     return pd.DataFrame({
#         'protein_id': result_proteins,
#         'go_term': result_terms,
#         'score': result_scores
#     })

# # ==========================================
# # 3. CHUNKED PROCESSING FOR LARGE FILES
# # ==========================================
# def process_in_chunks(submission_path, ancestors_map, roots, chunk_size=2_000_000):
#     """Process large files in chunks to avoid memory issues"""
#     print(f"[2/5] Loading submission in chunks...")
    
#     chunks_processed = []
#     chunk_num = 0
    
#     for chunk in pd.read_csv(submission_path, sep='\t', header=None, 
#                               names=['protein_id', 'go_term', 'score', 'key'],
#                               usecols=['protein_id', 'go_term', 'score'],
#                               chunksize=chunk_size,
#                               dtype={'protein_id': str, 'go_term': str, 'score': np.float32}):
        
#         chunk_num += 1
#         print(f"   Processing chunk {chunk_num} ({len(chunk):,} rows)...")
        
#         processed = process_predictions_vectorized(chunk, ancestors_map, roots)
#         chunks_processed.append(processed)
        
#         del chunk
#         gc.collect()
    
#     print("   → Combining all chunks...")
#     return pd.concat(chunks_processed, ignore_index=True)

# # ==========================================
# # 4. MAIN PIPELINE (OPTIMIZED)
# # ==========================================
# def main():
#     # Paths
#     OBO_PATH = "/kaggle/input/cafa-6-protein-function-prediction/Train/go-basic.obo"
#     SUBMISSION_INPUT = '/kaggle/input/goa-negative-propagation/submission.tsv'
#     SUBMISSION_OUTPUT = 'submission.tsv'
    
#     print("="*60)
#     print("  CAFA-6 OPTIMIZED SUBMISSION GENERATOR")
#     print("="*60)
    
#     # Step 1: Parse Ontology (Fast)
#     term_parents, roots = parse_obo_parents(OBO_PATH)
#     ancestors_map = get_ancestors_map_fast(term_parents)
#     print(f"   ✓ Loaded {len(term_parents):,} GO terms")
#     print(f"   ✓ Computed {len(ancestors_map):,} ancestor relationships")
    
#     gc.collect()
    
#     # Step 2 & 3: Process Predictions
#     try:
#         # Try loading full file first
#         print(f"[2/5] Attempting full file load...")
#         submission = pd.read_csv(SUBMISSION_INPUT, sep='\t', header=None,
#                                 names=['protein_id', 'go_term', 'score', 'key'],
#                                 usecols=['protein_id', 'go_term', 'score'],
#                                 dtype={'protein_id': str, 'go_term': str, 'score': np.float32})
        
#         print(f"   ✓ Loaded {len(submission):,} predictions")
#         final_df = process_predictions_vectorized(submission, ancestors_map, roots)
        
#         del submission
#         gc.collect()
        
#     except MemoryError:
#         print("   ⚠ Memory limit - switching to chunked processing...")
#         final_df = process_in_chunks(SUBMISSION_INPUT, ancestors_map, roots)
    
#     # Step 4: Remove duplicates (keep highest score)
#     print(f"[4/5] Deduplicating {len(final_df):,} rows...")
#     final_df = final_df.groupby(['protein_id', 'go_term'], as_index=False)['score'].max()
    
#     # Step 5: Sort & Save
#     print(f"[5/5] Sorting and saving {len(final_df):,} rows...")
#     final_df.sort_values(['protein_id', 'score'], ascending=[True, False], inplace=True)
#     final_df.to_csv(SUBMISSION_OUTPUT, sep='\t', index=False, header=False)
    
#     print("="*60)
#     print(f"✅ COMPLETE! Saved to: {SUBMISSION_OUTPUT}")
#     print(f"   Total predictions: {len(final_df):,}")
#     print(f"   Unique proteins: {final_df['protein_id'].nunique():,}")
#     print(f"   Unique GO terms: {final_df['go_term'].nunique():,}")
#     print("="*60)
#     print("\nFirst 10 rows:")
#     print(final_df.head(10))
    
#     return final_df

# # ==========================================
# # RUN
# # ==========================================
# if __name__ == "__main__":
#     final_df = main()