In [None]:
!pip install seqio
!pip install biopython plotly obonet networkx -q


In [None]:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from Bio import SeqIO
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import obonet
import networkx as nx
from collections import Counter
import warnings
warnings.filterwarnings('ignore')

print("All packages imported successfully!")


# Load all data files
def load_cafa6_data(data_dir="/kaggle/input/cafa-6-protein-function-prediction"):
    """Load all CAFA 6 competition data files.

    Parameters
    ----------
    data_dir : str or os.PathLike, optional
        Root directory of the competition dataset. Defaults to the Kaggle input
        mount, so existing ``load_cafa6_data()`` calls behave as before.

    Returns
    -------
    dict
        'train_sequences' / 'test_sequences': protein ID -> sequence string;
        'train_terms', 'train_taxonomy', 'ia_weights', 'selected_taxa': DataFrames;
        'go_graph': the graph returned by ``obonet.read_obo``.
    """
    from pathlib import Path

    root = Path(data_dir)
    train_dir = root / "Train"
    test_dir = root / "Test"
    data = {}

    def read_two_column_tsv(path, names):
        """Read a two-column TSV whose first line may or may not be a header.

        The second column must be numeric: a first row whose second field does not
        parse as a number is treated as a header line and dropped. Columns are
        always labelled with `names`.
        """
        frame = pd.read_csv(path, sep='\t', header=None, names=names, dtype=str)
        if len(frame) and pd.isna(pd.to_numeric(frame.iloc[0, 1], errors='coerce')):
            frame = frame.iloc[1:].reset_index(drop=True)
        frame[names[1]] = pd.to_numeric(frame[names[1]])
        return frame

    print("Loading training sequences...")
    train_sequences = {}
    for record in SeqIO.parse(str(train_dir / "train_sequences.fasta"), "fasta"):
        # IDs appear to be pipe-delimited ('db|accession|name'); keep the accession,
        # and fall back to the whole ID instead of raising IndexError when there is no '|'.
        id_parts = record.id.split('|')
        protein_id = id_parts[1] if len(id_parts) > 1 else record.id
        train_sequences[protein_id] = str(record.seq)
    data['train_sequences'] = train_sequences

    print("Loading test sequences...")
    data['test_sequences'] = {
        record.id: str(record.seq)
        for record in SeqIO.parse(str(test_dir / "testsuperset.fasta"), "fasta")
    }

    print("Loading training terms...")
    data['train_terms'] = pd.read_csv(train_dir / "train_terms.tsv", sep='\t')

    print("Loading taxonomy...")
    # NOTE(review): this file appears to have no header row (later cells index a column
    # literally named '9606'), so its first record becomes the column labels and is lost
    # as data. Kept as-is because those cells depend on that label; switch to
    # header=None, names=['EntryID', 'taxonomyID'] once every consumer indexes by position.
    data['train_taxonomy'] = pd.read_csv(train_dir / "train_taxonomy.tsv", sep='\t')

    print("Loading IA weights...")
    # Always labelled 'term' / 'ia', whether or not the file ships a header row.
    data['ia_weights'] = read_two_column_tsv(root / "IA.tsv", ['term', 'ia'])

    print("Loading GO ontology...")
    data['go_graph'] = obonet.read_obo(str(train_dir / "go-basic.obo"))

    print("Loading selected taxa...")
    data['selected_taxa'] = pd.read_csv(test_dir / "testsuperset-taxon-list.tsv", sep='\t')

    return data

# Load every competition artifact once; each analysis cell below reads from `data`.
data = load_cafa6_data()

separator = "=" * 50
print("\n" + separator)
print("DATA LOADING COMPLETE")
print(separator)

In [None]:
# Which datasets load_cafa6_data() returned.
data.keys()

In [None]:
# Raw taxonomy table as loaded.
# NOTE(review): later cells index a column literally named '9606', which suggests this TSV
# has no header row and its first record was promoted to column labels; confirm here.
data['train_taxonomy']

In [None]:

def print_basic_stats(data):
    """Print headline counts and sequence-length statistics for the dataset.

    Parameters
    ----------
    data : dict
        Output of ``load_cafa6_data()``. Reads 'train_sequences' and
        'test_sequences' (ID -> sequence dicts), 'train_terms' (with 'term' and
        'aspect' columns) and 'train_taxonomy' (taxon ID in its second column).
    """
    print("üìä BASIC DATASET STATISTICS")
    print("-" * 40)

    train_terms = data['train_terms']

    print(f"üèãÔ∏è Training proteins: {len(data['train_sequences']):,}")
    print(f"üéØ Test proteins: {len(data['test_sequences']):,}")
    print(f"üè∑Ô∏è Training GO term annotations: {len(train_terms):,}")
    print(f"üî§ Unique GO terms: {train_terms['term'].nunique():,}")

    # Count annotations by ontology
    term_counts = train_terms['aspect'].value_counts()
    print(f"\nüìà Terms by ontology:")
    for aspect, count in term_counts.items():
        print(f"   {aspect}: {count:,} terms")

    def describe_lengths(label, sequences):
        """Print mean/max/min length of `sequences`; tolerates an empty split."""
        lengths = [len(seq) for seq in sequences]
        if not lengths:
            print(f"   {label} - no sequences")
            return
        print(f"   {label} - Mean: {np.mean(lengths):.1f}, "
              f"Max: {max(lengths):,}, "
              f"Min: {min(lengths)}")

    print(f"\nüìè Sequence length statistics:")
    describe_lengths("Training", data['train_sequences'].values())
    describe_lengths("Test", data['test_sequences'].values())

    # Select the taxon column by position: depending on how the TSV was read its label is
    # either a real header or the first row's taxon ID (previously hard-coded as '9606').
    taxon_ids = data['train_taxonomy'].iloc[:, 1]
    print(f"\nüß¨ Unique species in training: {taxon_ids.nunique():,}")

# Headline dataset statistics.
print_basic_stats(data)

In [None]:

def analyze_sequences(data):
    """Compare amino-acid composition and sequence lengths of train vs. test.

    Draws a 2x2 matplotlib figure (composition bars and length histograms for
    each split) and prints length statistics plus the five most frequent
    standard amino acids in the training set.

    Parameters
    ----------
    data : dict
        Output of ``load_cafa6_data()``; reads 'train_sequences' and
        'test_sequences' (ID -> sequence dicts).
    """
    print("\nüî¨ SEQUENCE ANALYSIS")
    print("-" * 40)

    train_seqs = list(data['train_sequences'].values())
    test_seqs = list(data['test_sequences'].values())

    aa_list = list('ACDEFGHIKLMNPQRSTVWY')  # 20 standard amino acids

    def get_aa_composition(sequences):
        """Share of all residues (any character) taken by each standard amino acid."""
        # Count sequence by sequence instead of ''.join()-ing the whole split, which built
        # one string as long as every sequence combined just to count its characters.
        aa_counts = Counter()
        for seq in sequences:
            aa_counts.update(seq)
        total = sum(aa_counts.values())
        if total == 0:
            # Empty split (or only empty sequences): zeros instead of ZeroDivisionError.
            return {aa: 0.0 for aa in aa_list}
        return {aa: aa_counts[aa] / total for aa in aa_list}

    train_aa_comp = get_aa_composition(train_seqs)
    test_aa_comp = get_aa_composition(test_seqs)
    train_lengths = [len(seq) for seq in train_seqs]
    test_lengths = [len(seq) for seq in test_seqs]

    fig, axes = plt.subplots(2, 2, figsize=(16, 12))

    # Top row: amino-acid composition per split (lists rather than dict views for matplotlib).
    for ax, comp, color, title in (
        (axes[0, 0], train_aa_comp, 'skyblue', 'Training Set - Amino Acid Composition'),
        (axes[0, 1], test_aa_comp, 'lightcoral', 'Test Set - Amino Acid Composition'),
    ):
        ax.bar(list(comp.keys()), list(comp.values()), color=color, alpha=0.7)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.set_ylabel('Frequency', fontsize=12)
        ax.grid(True, alpha=0.3)

    # Bottom row: sequence-length histograms per split.
    for ax, lengths, color, label, title in (
        (axes[1, 0], train_lengths, 'skyblue', 'Training', 'Training Set - Sequence Length Distribution'),
        (axes[1, 1], test_lengths, 'lightcoral', 'Test', 'Test Set - Sequence Length Distribution'),
    ):
        ax.hist(lengths, bins=100, alpha=0.7, color=color, label=label)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.set_xlabel('Sequence Length', fontsize=12)
        ax.set_ylabel('Frequency', fontsize=12)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Print statistics
    print(f"üìä Sequence Length Statistics:")
    print(f"   Training - Mean: {np.mean(train_lengths):.1f}, Std: {np.std(train_lengths):.1f}")
    print(f"   Test - Mean: {np.mean(test_lengths):.1f}, Std: {np.std(test_lengths):.1f}")

    # Most common amino acids (training split)
    print(f"\nüèÜ Top 5 Amino Acids:")
    sorted_train = sorted(train_aa_comp.items(), key=lambda x: x[1], reverse=True)[:5]
    for aa, freq in sorted_train:
        print(f"   {aa}: {freq:.3f}")

# Amino-acid composition and length distributions, train vs. test.
analyze_sequences(data)


In [None]:
def analyze_go_terms(data):
    """Visualise and summarise GO annotations in the training set.

    Draws a 2x3 matplotlib figure (terms per protein, annotations per ontology
    as pie and bar, top-10 terms for BPO/MFO/CCO) and prints per-protein
    annotation statistics, label-matrix sparsity and the most annotated proteins.

    Parameters
    ----------
    data : dict
        Output of ``load_cafa6_data()``; reads 'train_terms' (columns
        'EntryID', 'term', 'aspect') and 'train_sequences'.
    """
    print("\nüè∑Ô∏è GO TERM ANALYSIS")
    print("-" * 40)

    train_terms = data['train_terms']

    # Aspect codes may be single letters ('P'/'F'/'C') rather than 'BPO'/'MFO'/'CCO'
    # depending on the data release; normalise so the per-ontology panels below are not
    # silently empty. Three-letter codes pass through unchanged.
    aspects = train_terms['aspect'].replace({'P': 'BPO', 'F': 'MFO', 'C': 'CCO'})

    # Terms per protein
    terms_per_protein = train_terms.groupby('EntryID')['term'].count()

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))

    # Terms per protein distribution
    axes[0,0].hist(terms_per_protein, bins=50, alpha=0.7, color='teal')
    axes[0,0].set_title('GO Terms per Protein', fontsize=14, fontweight='bold')
    axes[0,0].set_xlabel('Number of Terms')
    axes[0,0].set_ylabel('Number of Proteins')
    axes[0,0].grid(True, alpha=0.3)

    # Annotations by ontology (pie chart)
    ontology_counts = aspects.value_counts()
    colors = ['#ff9999', '#66b3ff', '#99ff99']
    axes[0,1].pie(ontology_counts.values, labels=ontology_counts.index, autopct='%1.1f%%',
                  colors=colors, startangle=90)
    axes[0,1].set_title('GO Terms by Ontology', fontsize=14, fontweight='bold')

    # Annotations by ontology (bar chart)
    axes[0,2].bar(ontology_counts.index, ontology_counts.values, color=colors, alpha=0.7)
    axes[0,2].set_title('GO Terms Count by Ontology', fontsize=14, fontweight='bold')
    axes[0,2].set_ylabel('Number of Terms')
    axes[0,2].grid(True, alpha=0.3)

    # Top terms in each ontology
    ontologies = ['BPO', 'MFO', 'CCO']
    colors_ont = ['lightblue', 'lightgreen', 'lightcoral']

    for i, ontology in enumerate(ontologies):
        top_terms = train_terms.loc[aspects == ontology, 'term'].value_counts().head(10)

        axes[1,i].barh(range(len(top_terms)), top_terms.values, color=colors_ont[i], alpha=0.7)
        axes[1,i].set_title(f'Top 10 {ontology} Terms', fontsize=12, fontweight='bold')
        axes[1,i].set_yticks(range(len(top_terms)))
        # Term IDs are expected to already look like 'GO:nnnnnnn' (they are matched against
        # GO graph node IDs in the quality checks); only add the prefix when it is missing,
        # instead of always producing 'GO:GO:...'.
        axes[1,i].set_yticklabels(
            [term if str(term).startswith('GO:') else f'GO:{term}' for term in top_terms.index],
            fontsize=8)
        axes[1,i].set_xlabel('Frequency')
        axes[1,i].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Print detailed statistics
    print(f"üìä GO Term Statistics:")
    print(f"   Average terms per protein: {terms_per_protein.mean():.2f}")
    print(f"   Max terms per protein: {terms_per_protein.max()}")
    print(f"   Min terms per protein: {terms_per_protein.min()}")
    print(f"   Std of terms per protein: {terms_per_protein.std():.2f}")

    # Label sparsity over a (training proteins x unique terms) matrix
    total_possible_annotations = len(data['train_sequences']) * train_terms['term'].nunique()
    actual_annotations = len(train_terms)
    if total_possible_annotations:
        sparsity = (1 - actual_annotations / total_possible_annotations) * 100
        print(f"   Label matrix sparsity: {sparsity:.4f}%")
    else:
        print("   Label matrix sparsity: n/a (no proteins or terms)")

    # Most annotated proteins
    print(f"\nüèÜ Top 5 Most Annotated Proteins:")
    top_proteins = terms_per_protein.sort_values(ascending=False).head()
    for protein, count in top_proteins.items():
        print(f"   {protein}: {count} terms")

# GO annotation distributions and label sparsity.
analyze_go_terms(data)

In [None]:

def analyze_taxonomy(data):
    """Plot and print how training proteins are distributed across species (taxa).

    Parameters
    ----------
    data : dict
        Output of ``load_cafa6_data()``; reads 'train_taxonomy' (expected:
        protein ID in the first column, taxon ID in the second) and
        'selected_taxa'.
    """
    print("\nüß¨ TAXONOMY ANALYSIS")
    print("-" * 40)

    taxonomy = data['train_taxonomy']
    selected_taxa = data['selected_taxa']

    # Select the taxon column by position: its label is either a real header or, when the
    # TSV is read without header=None, the first row's taxon ID (previously hard-coded '9606').
    taxon_ids = taxonomy.iloc[:, 1]
    species_counts = taxon_ids.value_counts()

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

    # Top 20 species
    top_species = species_counts.head(20)
    ax1.bar(range(len(top_species)), top_species.values, color='purple', alpha=0.7)
    ax1.set_title('Top 20 Species by Protein Count', fontsize=14, fontweight='bold')
    ax1.set_xlabel('Species Rank')
    ax1.set_ylabel('Number of Proteins')
    ax1.set_xticks(range(len(top_species)))
    ax1.set_xticklabels([f'Taxon {taxon}' for taxon in top_species.index], rotation=45, ha='right')
    ax1.grid(True, alpha=0.3)

    # Species distribution (log scale)
    ax2.hist(species_counts.values, bins=50, alpha=0.7, color='orange', log=True)
    ax2.set_title('Species Distribution (Log Scale)', fontsize=14, fontweight='bold')
    ax2.set_xlabel('Proteins per Species')
    ax2.set_ylabel('Number of Species (log)')
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Print statistics
    print(f"üìä Taxonomy Statistics:")
    print(f"   Total number of species: {taxon_ids.nunique():,}")
    print(f"   Total number of selected taxa: {len(selected_taxa):,}")

    print(f"\nüèÜ Top 10 Species:")
    for i, (taxon, count) in enumerate(species_counts.head(10).items(), 1):
        print(f"   {i:2d}. Taxon {taxon}: {count:,} proteins")

    # Coverage statistics
    coverage_stats = species_counts.describe()
    print(f"\nüìà Species Coverage:")
    print(f"   Mean proteins per species: {coverage_stats['mean']:.1f}")
    print(f"   Std proteins per species: {coverage_stats['std']:.1f}")
    # Counts are integers; read max/min from the counts directly rather than describe()'s floats.
    print(f"   Max proteins per species: {species_counts.max():,}")
    print(f"   Min proteins per species: {species_counts.min()}")

# Species (taxon) distribution of the training proteins.
analyze_taxonomy(data)


In [None]:
# Raw IA (information accretion) weight table as loaded.
data['ia_weights']

In [None]:

# def analyze_ia_weights(data):
#     """Analyze information accretion weights"""
#     print("\n‚öñÔ∏è INFORMATION ACCRETION ANALYSIS")
#     print("-" * 40)
    
#     ia_weights = data['ia_weights']
#     train_terms = data['train_terms']
    
#     # Merge with train terms to get ontology information
#     ia_with_ontology = ia_weights.merge(
#         train_terms[['term', 'aspect']].drop_duplicates(), 
#         on='term', 
#         how='left'
#     )
    
#     # Create visualization
#     fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12))
    
#     # IA distribution by ontology
#     ontologies = ['BPO', 'MFO', 'CCO']
#     colors = ['lightblue', 'lightgreen', 'lightcoral']
    
#     for i, ontology in enumerate(ontologies):
#         ontology_ia = ia_with_ontology[ia_with_ontology['aspect'] == ontology]['ia']
#         ax1.hist(ontology_ia, bins=50, alpha=0.6, color=colors[i], label=ontology)
    
#     ax1.set_title('IA Distribution by Ontology', fontsize=14, fontweight='bold')
#     ax1.set_xlabel('Information Accretion')
#     ax1.set_ylabel('Frequency')
#     ax1.legend()
#     ax1.grid(True, alpha=0.3)
    
#     # Box plot of IA by ontology
#     ia_data = [ia_with_ontology[ia_with_ontology['aspect'] == ont]['ia'] for ont in ontologies]
#     ax2.boxplot(ia_data, labels=ontologies, patch_artist=True,
#                 boxprops=dict(facecolor='lightgray', color='black'),
#                 medianprops=dict(color='red'))
#     ax2.set_title('IA Distribution Box Plot', fontsize=14, fontweight='bold')
#     ax2.set_ylabel('Information Accretion')
#     ax2.grid(True, alpha=0.3)
    
#     # Cumulative distribution of IA
#     for i, ontology in enumerate(ontologies):
#         ontology_ia = ia_with_ontology[ia_with_ontology['aspect'] == ontology]['ia']
#         sorted_ia = np.sort(ontology_ia)
#         y = np.arange(1, len(sorted_ia) + 1) / len(sorted_ia)
#         ax3.plot(sorted_ia, y, label=ontology, color=colors[i], linewidth=2)
    
#     ax3.set_title('Cumulative Distribution of IA', fontsize=14, fontweight='bold')
#     ax3.set_xlabel('Information Accretion')
#     ax3.set_ylabel('Cumulative Probability')
#     ax3.legend()
#     ax3.grid(True, alpha=0.3)
    
#     # IA vs Term Frequency
#     term_freq = train_terms['term'].value_counts()
#     ia_freq = ia_weights.merge(term_freq.rename('frequency'), left_on='term', right_index=True)
    
#     ax4.scatter(ia_freq['frequency'], ia_freq['ia'], alpha=0.5, color='purple')
#     ax4.set_xscale('log')
#     ax4.set_title('IA vs Term Frequency', fontsize=14, fontweight='bold')
#     ax4.set_xlabel('Term Frequency (log scale)')
#     ax4.set_ylabel('Information Accretion')
#     ax4.grid(True, alpha=0.3)
    
#     plt.tight_layout()
#     plt.show()
    
#     # Print IA statistics
#     print(f"üìä Information Accretion Statistics:")
#     for ontology in ontologies:
#         ontology_ia = ia_with_ontology[ia_with_ontology['aspect'] == ontology]['ia']
#         print(f"   {ontology}:")
#         print(f"      Mean IA: {ontology_ia.mean():.4f}")
#         print(f"      Std IA: {ontology_ia.std():.4f}")
#         print(f"      Max IA: {ontology_ia.max():.4f}")
#         print(f"      Min IA: {ontology_ia.min():.4f}")
#         print(f"      Median IA: {ontology_ia.median():.4f}")

# analyze_ia_weights(data)

In [None]:

def analyze_go_structure(data):
    """Summarise the GO graph: size, per-ontology term counts and degree distributions.

    Parameters
    ----------
    data : dict
        Output of ``load_cafa6_data()``; reads 'go_graph', the networkx graph
        produced by ``obonet.read_obo``.
    """
    print("\nüå≥ GO ONTOLOGY STRUCTURE ANALYSIS")
    print("-" * 40)

    go_graph = data['go_graph']

    # Basic graph statistics
    print(f"üìä GO Graph Statistics:")
    print(f"   Number of nodes (terms): {len(go_graph.nodes):,}")
    print(f"   Number of edges (relationships): {len(go_graph.edges):,}")

    # Root terms of the three sub-ontologies
    root_nodes = {
        'BPO': 'GO:0008150',  # biological_process
        'MFO': 'GO:0003674',  # molecular_function
        'CCO': 'GO:0005575'   # cellular_component
    }

    def get_subtree_size(graph, root):
        """Number of terms under `root`, root included (0 if the root is absent).

        obonet orients edges from child term to parent term, so every term below a
        root has a path *to* it: those are the root's networkx ``ancestors``.
        ``nx.descendants`` follows edges away from the root (upwards) and would
        find nothing for a top-level root.
        """
        if root not in graph:
            return 0
        return len(nx.ancestors(graph, root)) + 1  # +1 for root itself

    print(f"\nüå≤ Ontology Tree Sizes:")
    for ont_name, root_id in root_nodes.items():
        size = get_subtree_size(go_graph, root_id)
        print(f"   {ont_name} ({root_id}): {size:,} terms")

    # With child -> parent edges, in-degree counts a term's direct children and
    # out-degree its direct parents.
    in_degrees = [d for _, d in go_graph.in_degree()]
    out_degrees = [d for _, d in go_graph.out_degree()]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    ax1.hist(in_degrees, bins=50, alpha=0.7, color='green', log=True)
    ax1.set_title('In-Degree Distribution (Log Scale)', fontsize=14, fontweight='bold')
    ax1.set_xlabel('In-Degree')
    ax1.set_ylabel('Frequency (log)')
    ax1.grid(True, alpha=0.3)

    ax2.hist(out_degrees, bins=50, alpha=0.7, color='blue', log=True)
    ax2.set_title('Out-Degree Distribution (Log Scale)', fontsize=14, fontweight='bold')
    ax2.set_xlabel('Out-Degree')
    ax2.set_ylabel('Frequency (log)')
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Print degree statistics
    print(f"\nüìà Degree Statistics:")
    print(f"   In-degree - Mean: {np.mean(in_degrees):.2f}, Max: {max(in_degrees)}")
    print(f"   Out-degree - Mean: {np.mean(out_degrees):.2f}, Max: {max(out_degrees)}")

    # Terms with the highest degree
    high_in_degree = sorted(go_graph.in_degree(), key=lambda x: x[1], reverse=True)[:5]
    high_out_degree = sorted(go_graph.out_degree(), key=lambda x: x[1], reverse=True)[:5]

    print(f"\nüèÜ Terms with Highest In-Degree:")
    for term, degree in high_in_degree:
        print(f"   {term}: {degree}")

    print(f"\nüèÜ Terms with Highest Out-Degree:")
    for term, degree in high_out_degree:
        print(f"   {term}: {degree}")

# GO graph size, per-ontology term counts and degree distributions.
analyze_go_structure(data)

In [None]:
# Multi-label analysis using vectorised pandas aggregations
def analyze_multilabel_ultra_efficient(data):
    """Summarise the multi-label structure of the training annotations.

    Computes annotations per protein, proteins per ontology and annotations per
    GO term with vectorised pandas operations, draws three panels, and prints
    label cardinality / density, ontology coverage and term-frequency statistics.

    Parameters
    ----------
    data : dict
        Output of ``load_cafa6_data()``; reads 'train_terms' (columns
        'EntryID', 'term', 'aspect').
    """
    print("\nüéØ MULTI-LABEL CHARACTERISTICS (ULTRA EFFICIENT)")
    print("-" * 55)

    train_terms = data['train_terms']

    # Aspect codes may be single letters ('P'/'F'/'C') depending on the data release;
    # normalise them so they line up with the three ontologies tracked below.
    aspects = train_terms['aspect'].replace({'P': 'BPO', 'F': 'MFO', 'C': 'CCO'})

    # Vectorised counts replace the former chunked iterrows() loop, which visited every
    # annotation row in Python.
    terms_per_protein = train_terms['EntryID'].value_counts()
    term_freq = train_terms['term'].value_counts()

    # Proteins per ontology. Unexpected aspect values get their own entry instead of
    # raising KeyError against a fixed three-key dict.
    ontology_counts = {'BPO': set(), 'MFO': set(), 'CCO': set()}
    for aspect, proteins in train_terms['EntryID'].groupby(aspects):
        ontology_counts.setdefault(aspect, set()).update(proteins.unique())

    n_proteins = len(terms_per_protein)
    label_cardinality = terms_per_protein.mean()
    total_terms = len(term_freq)
    label_density = label_cardinality / total_terms if total_terms else 0.0

    def pct_of_proteins(count):
        """`count` as a percentage of annotated proteins (0 when there are none)."""
        return count / n_proteins * 100 if n_proteins else 0.0

    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 5))

    # Histogram of terms per protein (sampled on large inputs to keep plotting cheap)
    if n_proteins > 50000:
        sampled = terms_per_protein.sample(n=30000, random_state=42)
        ax1.hist(sampled, bins=30, alpha=0.7, color='red')
        ax1.set_title(f'Terms per Protein\n(30K sample of {n_proteins:,})', fontweight='bold')
    else:
        ax1.hist(terms_per_protein, bins=30, alpha=0.7, color='red')
        ax1.set_title('Terms per Protein', fontweight='bold')
    ax1.set_xlabel('Number of Terms')
    ax1.set_ylabel('Number of Proteins')
    ax1.grid(True, alpha=0.3)

    # Ontology distribution
    ontology_sizes = {ont: len(proteins) for ont, proteins in ontology_counts.items()}
    ax2.bar(list(ontology_sizes.keys()), list(ontology_sizes.values()),
            color=['lightblue', 'lightgreen', 'lightcoral'], alpha=0.7)
    ax2.set_title('Proteins per Ontology', fontweight='bold')
    ax2.set_ylabel('Number of Proteins')
    ax2.grid(True, alpha=0.3)

    # Term frequency distribution (log scale)
    ax3.hist(term_freq.values, bins=50, alpha=0.7, color='purple', log=True)
    ax3.set_title('GO Term Frequency Distribution\n(Log Scale)', fontweight='bold')
    ax3.set_xlabel('Term Frequency')
    ax3.set_ylabel('Number of Terms (log)')
    ax3.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Print statistics
    print(f"üìä Multi-label Statistics:")
    print(f"   Total proteins: {n_proteins:,}")
    print(f"   Total GO terms: {total_terms:,}")
    print(f"   Total annotations: {len(train_terms):,}")
    print(f"   Label cardinality: {label_cardinality:.2f} terms/protein")
    print(f"   Label density: {label_density:.6f}")

    print(f"\nüè∑Ô∏è Ontology Coverage:")
    for ont, proteins in ontology_counts.items():
        print(f"   {ont}: {len(proteins):,} proteins ({pct_of_proteins(len(proteins)):.1f}%)")

    # Proteins annotated in all three ontologies
    all_three = len(ontology_counts['BPO'] & ontology_counts['MFO'] & ontology_counts['CCO'])
    print(f"   All three ontologies: {all_three:,} proteins ({pct_of_proteins(all_three):.1f}%)")

    print(f"\nüìà Term Frequency Statistics:")
    print(f"   Most frequent term: {term_freq.max()} occurrences")
    print(f"   Average term frequency: {term_freq.mean():.1f}")
    print(f"   Terms appearing only once: {(term_freq == 1).sum():,}")

    # Show top terms
    top_terms = term_freq.nlargest(10)
    print(f"\nüèÜ Top 10 Most Frequent GO Terms:")
    for term, count in top_terms.items():
        print(f"   {term}: {count:,} proteins")

# Multi-label characteristics of the training annotations.
analyze_multilabel_ultra_efficient(data)

In [None]:
# Create interactive visualizations
def create_interactive_plots(data):
    """Render three Plotly figures from the loaded data.

    1. Histogram of GO terms per training protein.
    2. Bar chart of unique GO terms per ontology (aspect).
    3. Box plots comparing train vs. test sequence lengths.

    Parameters
    ----------
    data : dict
        Output of ``load_cafa6_data()``; reads 'train_terms',
        'train_sequences' and 'test_sequences'.
    """
    print("\nüìä INTERACTIVE VISUALIZATIONS")
    print("-" * 40)

    annotations = data['train_terms']

    # 1) How many GO terms each protein carries.
    per_protein = annotations.groupby('EntryID')['term'].count()
    histogram = px.histogram(
        per_protein,
        nbins=50,
        title='Interactive: GO Terms per Protein Distribution',
        labels={'value': 'Number of Terms', 'count': 'Number of Proteins'},
        opacity=0.7,
    )
    histogram.show()

    # 2) Per-ontology summary table -> bar chart of unique terms.
    summary = (
        annotations.groupby('aspect')
        .agg(**{
            'Total Annotations': ('term', 'count'),
            'Unique Terms': ('term', 'nunique'),
            'Unique Proteins': ('EntryID', 'nunique'),
        })
        .round(2)
        .reset_index()
    )
    bars = px.bar(
        summary,
        x='aspect',
        y='Unique Terms',
        title='Interactive: Unique GO Terms by Ontology',
        color='aspect',
        text='Unique Terms',
    )
    bars.show()

    # 3) Train vs. test sequence-length box plots.
    lengths_by_split = {
        'Training Set': [len(s) for s in data['train_sequences'].values()],
        'Test Set': [len(s) for s in data['test_sequences'].values()],
    }
    boxes = go.Figure()
    for split_name, lengths in lengths_by_split.items():
        boxes.add_trace(go.Box(y=lengths, name=split_name, boxpoints='outliers'))
    boxes.update_layout(title='Interactive: Sequence Length Distribution Comparison',
                        yaxis_title='Sequence Length')
    boxes.show()

    print("‚úÖ Interactive plots created successfully!")

# Interactive Plotly figures (can be heavy for large datasets; comment this call out to skip)
create_interactive_plots(data)

In [None]:
# Generate key insights summary
def generate_insights_summary(data):
    """Print a plain-text recap of key dataset metrics and modelling notes.

    Parameters
    ----------
    data : dict
        Output of ``load_cafa6_data()``; reads 'train_terms' (columns
        'EntryID', 'term', 'aspect'), 'train_sequences' and 'test_sequences'.
    """
    print("\n" + "="*60)
    print("üîç KEY INSIGHTS SUMMARY")
    print("="*60)

    train_terms = data['train_terms']
    train_seqs = list(data['train_sequences'].values())
    test_seqs = list(data['test_sequences'].values())

    # Calculate key metrics
    terms_per_protein = train_terms.groupby('EntryID')['term'].count()
    avg_terms_per_protein = terms_per_protein.mean()

    train_avg_len = np.mean([len(seq) for seq in train_seqs])
    test_avg_len = np.mean([len(seq) for seq in test_seqs])

    n_unique_terms = train_terms['term'].nunique()

    # Unique terms by ontology. Aspect codes may be single letters ('P'/'F'/'C') depending
    # on the data release; normalise so these counts are not silently 0.
    aspects = train_terms['aspect'].replace({'P': 'BPO', 'F': 'MFO', 'C': 'CCO'})
    unique_terms_by_aspect = train_terms['term'].groupby(aspects).nunique()
    bpo_terms = int(unique_terms_by_aspect.get('BPO', 0))
    mfo_terms = int(unique_terms_by_aspect.get('MFO', 0))
    cco_terms = int(unique_terms_by_aspect.get('CCO', 0))

    # Label-matrix sparsity over (training proteins x unique terms): computed from the
    # data instead of a hard-coded claim.
    possible_annotations = len(data['train_sequences']) * n_unique_terms
    if possible_annotations:
        sparsity = (1 - len(train_terms) / possible_annotations) * 100
    else:
        sparsity = float('nan')

    print("\nüìà COMPETITION SCALE:")
    print(f"   ‚Ä¢ {len(data['train_sequences']):,} training proteins")
    print(f"   ‚Ä¢ {len(data['test_sequences']):,} test proteins")
    print(f"   ‚Ä¢ {len(train_terms):,} total GO term annotations")
    print(f"   ‚Ä¢ {n_unique_terms:,} unique GO terms")

    print("\nüéØ MULTI-LABEL COMPLEXITY:")
    print(f"   ‚Ä¢ Average {avg_terms_per_protein:.1f} terms per protein")
    print(f"   ‚Ä¢ Label matrix sparsity: {sparsity:.4f}%")
    print(f"   ‚Ä¢ Hierarchical relationships between terms")

    print("\nüß¨ ONTOLOGY DISTRIBUTION:")
    print(f"   ‚Ä¢ Biological Process (BPO): {bpo_terms:,} terms")
    print(f"   ‚Ä¢ Molecular Function (MFO): {mfo_terms:,} terms")
    print(f"   ‚Ä¢ Cellular Component (CCO): {cco_terms:,} terms")

    print("\nüìè SEQUENCE CHARACTERISTICS:")
    print(f"   ‚Ä¢ Training avg length: {train_avg_len:.1f} amino acids")
    print(f"   ‚Ä¢ Test avg length: {test_avg_len:.1f} amino acids")
    print(f"   ‚Ä¢ Wide range of sequence lengths (tens to thousands)")

    print("\n‚öñÔ∏è EVALUATION COMPLEXITY:")
    print(f"   ‚Ä¢ Weighted by Information Accretion (IA)")
    print(f"   ‚Ä¢ Hierarchical precision/recall")
    print(f"   ‚Ä¢ Three separate subontology evaluations")

    print("\nüöÄ RECOMMENDATIONS FOR MODELING:")
    print(f"   ‚Ä¢ Use protein language models (ESM, ProtBERT)")
    print(f"   ‚Ä¢ Implement hierarchical multi-label classification")
    print(f"   ‚Ä¢ Handle extreme class imbalance")
    print(f"   ‚Ä¢ Consider taxonomic information")
    print(f"   ‚Ä¢ Use IA weights in loss function")

    print("\n‚ö†Ô∏è  CHALLENGES:")
    print(f"   ‚Ä¢ Extreme multi-label classification")
    print(f"   ‚Ä¢ Hierarchical label relationships")
    print(f"   ‚Ä¢ Sparse and imbalanced annotations")
    print(f"   ‚Ä¢ Computational complexity")
    print(f"   ‚Ä¢ Prospective evaluation (future test set)")

# Plain-text recap of the key metrics.
generate_insights_summary(data)

# %% [markdown]
# ## 10. Data Quality Checks

# %% [code]
# Data quality validation
def perform_quality_checks(data):
    """Print basic data-quality checks for the loaded dataset.

    Covers missing values, characters outside the 20 standard amino acids,
    protein-ID coverage across the sequence / term / taxonomy files, and GO
    terms that are absent from the ontology graph.

    Parameters
    ----------
    data : dict
        Output of ``load_cafa6_data()``; reads 'train_terms', 'train_taxonomy'
        (protein IDs expected in its first column), 'train_sequences',
        'test_sequences' and 'go_graph' (only its ``nodes()`` is used).
    """
    print("\nüîç DATA QUALITY CHECKS")
    print("-" * 40)

    # Check for missing values
    print("1. Missing Values Check:")
    train_terms_missing = data['train_terms'].isnull().sum()
    taxonomy_missing = data['train_taxonomy'].isnull().sum()

    print(f"   Train terms - Missing values: {train_terms_missing.sum()}")
    print(f"   Taxonomy - Missing values: {taxonomy_missing.sum()}")

    # Check sequence validity
    print("\n2. Sequence Validity Check:")
    valid_aa = set('ACDEFGHIKLMNPQRSTVWY')

    def check_sequence_validity(sequences):
        """Characters used in `sequences` that are not one of the 20 standard amino acids."""
        seen = set()
        for seq in sequences:
            seen.update(seq)
        return seen - valid_aa

    train_invalid = check_sequence_validity(data['train_sequences'].values())
    test_invalid = check_sequence_validity(data['test_sequences'].values())

    print(f"   Training - Invalid characters: {train_invalid if train_invalid else 'None'}")
    print(f"   Test - Invalid characters: {test_invalid if test_invalid else 'None'}")

    # Check ID consistency
    print("\n3. ID Consistency Check:")
    train_proteins = set(data['train_sequences'].keys())
    term_proteins = set(data['train_terms']['EntryID'].unique())
    # Protein IDs by position (first column): the taxonomy frame's labels appear to come
    # from its first data row ('9606' is used as a column label elsewhere), so indexing
    # an 'EntryID' column raised KeyError.
    taxonomy_proteins = set(data['train_taxonomy'].iloc[:, 0])

    print(f"   Proteins in sequences: {len(train_proteins):,}")
    print(f"   Proteins in terms: {len(term_proteins):,}")
    print(f"   Proteins in taxonomy: {len(taxonomy_proteins):,}")

    missing_in_terms = train_proteins - term_proteins
    missing_in_taxonomy = train_proteins - taxonomy_proteins

    print(f"   Proteins missing term annotations: {len(missing_in_terms)}")
    print(f"   Proteins missing taxonomy: {len(missing_in_taxonomy)}")

    # Check GO term validity
    print("\n4. GO Term Validity Check:")
    valid_terms = set(data['go_graph'].nodes())
    used_terms = set(data['train_terms']['term'].unique())

    invalid_terms = used_terms - valid_terms
    print(f"   Used GO terms: {len(used_terms):,}")
    print(f"   Valid GO terms in ontology: {len(valid_terms):,}")
    print(f"   Invalid GO terms: {len(invalid_terms)}")

    if invalid_terms:
        # Sorted so the examples shown are stable between runs.
        print(f"   Example invalid terms: {sorted(invalid_terms)[:5]}")

    print("\n‚úÖ Data quality checks completed!")

# Missing values, residue alphabet, ID coverage and GO-term validity.
perform_quality_checks(data)

# This comprehensive EDA reveals that CAFA 6 is an extremely challenging multi-label classification problem with:
# 
# - **Large scale**: Hundreds of thousands of proteins and tens of thousands of GO terms
# - **High complexity**: Hierarchical, multi-ontology predictions
# - **Extreme sparsity**: Very few positive labels per protein
# - **Complex evaluation**: Information accretion weighted metrics
# - **Prospective test set**: Future annotations as ground truth

# Final summary cell: closing banner plus suggested next steps.
rule = "=" * 70
print("\n" + rule)
print("üéâ CAFA 6 EDA COMPLETED SUCCESSFULLY!")
print(rule)
print("\nNext steps:")
for step in (
    "1. Implement baseline models (sequence similarity, ESM embeddings)",
    "2. Develop hierarchical multi-label classification approaches",
    "3. Incorporate taxonomic and evolutionary information",
    "4. Handle class imbalance with appropriate loss functions",
    "5. Validate using the proper evaluation metrics",
):
    print(step)
print("\nGood luck with the competition! üöÄ")
