# CAFA 6 - Phase 2: Comprehensive Feature Engineering
- Author: Dr. Hany Ghazal (PhD)
- Date: 19 October 2025

This notebook implements the feature engineering methods used for protein function prediction:
1. Basic Sequence Features
2. K-mer Features
3. Physicochemical Properties
4. Positional Features
5. Composition Features
6. Structure-based Features
7. GO Hierarchy Features

In [None]:
!pip install biopython

In [None]:
from Bio import SeqIO
import pandas as pd
import numpy as np
from collections import Counter
from typing import Dict, List, Tuple, Optional
import warnings
warnings.filterwarnings('ignore')
import os
# Walk the /kaggle/input directory tree and print the full path of every file,
# so the datasets available to this notebook appear in the cell output.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))


In [None]:
# ============================================================================
# 1. AMINO ACID PROPERTIES
# ============================================================================

class AminoAcidProperties:
    """Database of amino acid physicochemical properties.

    Every per-residue table is keyed by the 20 standard one-letter amino-acid
    codes. Non-standard codes (e.g. X, U, B, Z) are absent, which is why the
    feature extractors in this notebook look residues up with ``.get(aa, 0)``.
    """

    # Molecular weights (Da) of the free amino acids. The water released per
    # peptide bond is not subtracted here.
    MOLECULAR_WEIGHT = {
        'A': 89.1, 'C': 121.2, 'D': 133.1, 'E': 147.1, 'F': 165.2,
        'G': 75.1, 'H': 155.2, 'I': 131.2, 'K': 146.2, 'L': 131.2,
        'M': 149.2, 'N': 132.1, 'P': 115.1, 'Q': 146.2, 'R': 174.2,
        'S': 105.1, 'T': 119.1, 'V': 117.1, 'W': 204.2, 'Y': 181.2
    }

    # Hydrophobicity (Kyte-Doolittle scale; positive = hydrophobic)
    HYDROPHOBICITY = {
        'A': 1.8, 'C': 2.5, 'D': -3.5, 'E': -3.5, 'F': 2.8,
        'G': -0.4, 'H': -3.2, 'I': 4.5, 'K': -3.9, 'L': 3.8,
        'M': 1.9, 'N': -3.5, 'P': -1.6, 'Q': -3.5, 'R': -4.5,
        'S': -0.8, 'T': -0.7, 'V': 4.2, 'W': -0.9, 'Y': -1.3
    }

    # Approximate side-chain charge at pH 7. His gets 0.1, presumably to
    # reflect the small fraction of protonated imidazole near neutral pH.
    CHARGE = {
        'A': 0, 'C': 0, 'D': -1, 'E': -1, 'F': 0,
        'G': 0, 'H': 0.1, 'I': 0, 'K': 1, 'L': 0,
        'M': 0, 'N': 0, 'P': 0, 'Q': 0, 'R': 1,
        'S': 0, 'T': 0, 'V': 0, 'W': 0, 'Y': 0
    }

    # Binary polarity flag (1 = polar side chain, 0 = non-polar)
    POLARITY = {
        'A': 0, 'C': 1, 'D': 1, 'E': 1, 'F': 0,
        'G': 0, 'H': 1, 'I': 0, 'K': 1, 'L': 0,
        'M': 0, 'N': 1, 'P': 0, 'Q': 1, 'R': 1,
        'S': 1, 'T': 1, 'V': 0, 'W': 0, 'Y': 1
    }

    # Van der Waals side-chain volume (presumably in cubic angstroms)
    VOLUME = {
        'A': 67, 'C': 86, 'D': 91, 'E': 109, 'F': 135,
        'G': 48, 'H': 118, 'I': 124, 'K': 135, 'L': 124,
        'M': 124, 'N': 96, 'P': 90, 'Q': 114, 'R': 148,
        'S': 73, 'T': 93, 'V': 105, 'W': 163, 'Y': 141
    }

    # Aromaticity flag (0 or 1).
    # NOTE(review): His is flagged aromatic here but is not in
    # GROUPS['aromatic'] below -- confirm which definition is intended.
    AROMATIC = {
        'A': 0, 'C': 0, 'D': 0, 'E': 0, 'F': 1,
        'G': 0, 'H': 1, 'I': 0, 'K': 0, 'L': 0,
        'M': 0, 'N': 0, 'P': 0, 'Q': 0, 'R': 0,
        'S': 0, 'T': 0, 'V': 0, 'W': 1, 'Y': 1
    }

    # Amino acid groups (residue classes used for group-composition features).
    # 'tiny' and 'small' must differ; otherwise they yield two identical
    # feature columns. 'small' follows the EMBOSS pepstats "Small" class
    # (without the ambiguity code B), so 'tiny' is a proper subset of it.
    GROUPS = {
        'aliphatic': ['A', 'V', 'L', 'I', 'M'],
        'aromatic': ['F', 'Y', 'W'],
        'polar': ['S', 'T', 'N', 'Q'],
        'charged': ['K', 'R', 'H', 'D', 'E'],
        'positive': ['K', 'R', 'H'],
        'negative': ['D', 'E'],
        'small': ['A', 'C', 'D', 'G', 'N', 'P', 'S', 'T', 'V'],
        'tiny': ['G', 'A', 'S'],
        'sulfur': ['C', 'M']
    }

In [None]:
# ============================================================================
# 2. BASIC SEQUENCE FEATURES
# ============================================================================

class BasicSequenceFeatures:
    """Extract basic length, mass and composition features from protein sequences."""

    @staticmethod
    def extract_length_features(sequence: str) -> Dict[str, float]:
        """Return the raw, log1p-transformed and square-root-transformed length."""
        length = len(sequence)
        return {
            'seq_length': length,
            'log_length': np.log1p(length),
            'sqrt_length': np.sqrt(length)
        }

    @staticmethod
    def extract_molecular_weight(sequence: str) -> Dict[str, float]:
        """Approximate the molecular weight of the polypeptide.

        Sums the free amino-acid masses from ``AminoAcidProperties`` and
        subtracts one water (18.015 Da) per peptide bond. Residues missing
        from the table contribute 0 Da but still count towards the length.

        Returns:
            ``molecular_weight`` (Da) and ``molecular_weight_per_aa``. Both
            are 0.0 for an empty sequence.
        """
        length = len(sequence)
        if length == 0:
            # Without this guard the water correction below, (0 - 1) * 18.015,
            # would report an empty sequence as weighing +18 Da.
            return {'molecular_weight': 0.0, 'molecular_weight_per_aa': 0.0}

        weight = sum(AminoAcidProperties.MOLECULAR_WEIGHT.get(aa, 0) for aa in sequence)
        # A chain of n residues has n - 1 peptide bonds, each releasing one water.
        weight -= (length - 1) * 18.015

        return {
            'molecular_weight': weight,
            'molecular_weight_per_aa': weight / length
        }

    @staticmethod
    def extract_composition(sequence: str) -> Dict[str, float]:
        """Return the relative frequency of each of the 20 standard residues.

        Non-standard residues are not reported but still count towards the
        length, so the frequencies can sum to less than 1.
        """
        alphabet = AminoAcidProperties.MOLECULAR_WEIGHT.keys()
        length = len(sequence)
        if length == 0:
            return {f'comp_{aa}': 0.0 for aa in alphabet}

        counts = Counter(sequence)
        return {f'comp_{aa}': counts.get(aa, 0) / length for aa in alphabet}

    @staticmethod
    def extract_group_composition(sequence: str) -> Dict[str, float]:
        """Return the fraction of residues that belong to each residue group."""
        length = len(sequence)
        if length == 0:
            return {f'group_{name}': 0.0 for name in AminoAcidProperties.GROUPS.keys()}

        # Count once, then sum the per-residue counts of each group's members.
        counts = Counter(sequence)
        return {
            f'group_{name}': sum(counts[aa] for aa in set(members)) / length
            for name, members in AminoAcidProperties.GROUPS.items()
        }

In [None]:
# ============================================================================
# 3. PHYSICOCHEMICAL PROPERTIES
# ============================================================================

class PhysicochemicalFeatures:
    """Extract physicochemical property-based features."""

    # pKa values of ionisable groups (EMBOSS ``iep`` defaults), used by
    # ``extract_isoelectric_point``. 'Nterm'/'Cterm' are the chain termini.
    _PKA_POSITIVE = {'Nterm': 8.6, 'K': 10.8, 'R': 12.5, 'H': 6.5}
    _PKA_NEGATIVE = {'Cterm': 3.6, 'D': 3.9, 'E': 4.1, 'C': 8.5, 'Y': 10.1}

    @staticmethod
    def extract_property_statistics(sequence: str, property_dict: Dict[str, float],
                                   property_name: str) -> Dict[str, float]:
        """Return mean/std/min/max of a per-residue property over the sequence.

        Residues missing from ``property_dict`` contribute 0. All four
        statistics are 0.0 for an empty sequence.
        """
        if len(sequence) == 0:
            return {
                f'{property_name}_mean': 0.0,
                f'{property_name}_std': 0.0,
                f'{property_name}_min': 0.0,
                f'{property_name}_max': 0.0
            }

        values = [property_dict.get(aa, 0) for aa in sequence]

        return {
            f'{property_name}_mean': np.mean(values),
            f'{property_name}_std': np.std(values),
            f'{property_name}_min': np.min(values),
            f'{property_name}_max': np.max(values)
        }

    @staticmethod
    def extract_all_properties(sequence: str) -> Dict[str, float]:
        """Extract all physicochemical property features.

        Returns per-property statistics (hydrophobicity, charge, volume,
        polarity), followed by ``aromaticity`` (fraction of aromatic
        residues), ``net_charge`` (sum of per-residue charges) and
        ``charge_density`` (net charge per residue).
        """
        features = {}

        for table, name in (
            (AminoAcidProperties.HYDROPHOBICITY, 'hydrophobicity'),
            (AminoAcidProperties.CHARGE, 'charge'),
            (AminoAcidProperties.VOLUME, 'volume'),
            (AminoAcidProperties.POLARITY, 'polarity'),
        ):
            features.update(PhysicochemicalFeatures.extract_property_statistics(
                sequence, table, name))

        length = len(sequence)
        if length > 0:
            aromatic_count = sum(AminoAcidProperties.AROMATIC.get(aa, 0) for aa in sequence)
            net_charge = sum(AminoAcidProperties.CHARGE.get(aa, 0) for aa in sequence)
            features['aromaticity'] = aromatic_count / length
            features['net_charge'] = net_charge
            features['charge_density'] = net_charge / length
        else:
            features['aromaticity'] = 0.0
            features['net_charge'] = 0.0
            features['charge_density'] = 0.0

        return features

    @staticmethod
    def _net_charge_at_ph(ph: float, counts: Dict[str, int]) -> float:
        """Return the Henderson-Hasselbalch net charge of a chain at ``ph``.

        ``counts`` maps one-letter residue codes to their occurrence counts.
        Each chain contributes exactly one N- and one C-terminus.
        """
        positive = sum(
            (1 if group == 'Nterm' else counts.get(group, 0)) / (1.0 + 10.0 ** (ph - pka))
            for group, pka in PhysicochemicalFeatures._PKA_POSITIVE.items()
        )
        negative = sum(
            (1 if group == 'Cterm' else counts.get(group, 0)) / (1.0 + 10.0 ** (pka - ph))
            for group, pka in PhysicochemicalFeatures._PKA_NEGATIVE.items()
        )
        return positive - negative

    @staticmethod
    def extract_isoelectric_point(sequence: str) -> Dict[str, float]:
        """Estimate the isoelectric point (pI), the pH at which net charge is zero.

        Uses the Henderson-Hasselbalch model with the class-level pKa tables
        (termini plus K, R, H, D, E, C, Y) and bisects on [0, 14]. An empty
        sequence returns 7.0.
        """
        if len(sequence) == 0:
            return {'isoelectric_point': 7.0}

        counts = Counter(sequence)
        low, high = 0.0, 14.0
        # Net charge decreases monotonically with pH, so bisect for its zero.
        while high - low > 1e-4:
            mid = 0.5 * (low + high)
            if PhysicochemicalFeatures._net_charge_at_ph(mid, counts) > 0:
                low = mid
            else:
                high = mid

        return {'isoelectric_point': 0.5 * (low + high)}


In [None]:
# ============================================================================
# 4. K-MER FEATURES
# ============================================================================

class KmerFeatures:
    """Extract k-mer frequency features."""

    # Per-k cache of generate_kmers() output as a tuple. Without it the
    # 20**k k-mer list would be rebuilt for every sequence. Entries are not
    # invalidated if AminoAcidProperties is modified at runtime.
    _kmer_cache: Dict[int, Tuple[str, ...]] = {}

    @staticmethod
    def generate_kmers(k: int) -> List[str]:
        """Generate all possible k-mers over the 20 standard amino acids.

        k-mers are ordered by extending each shorter k-mer with every residue
        in ``AminoAcidProperties.MOLECULAR_WEIGHT`` key order. For ``k <= 1``
        the alphabet itself is returned.
        """
        amino_acids = list(AminoAcidProperties.MOLECULAR_WEIGHT.keys())
        kmers = amino_acids.copy()
        for _ in range(k - 1):
            kmers = [kmer + aa for kmer in kmers for aa in amino_acids]
        return kmers

    @staticmethod
    def _cached_kmers(k: int) -> Tuple[str, ...]:
        """Return ``generate_kmers(k)`` as a tuple, computing it at most once per k."""
        cache = KmerFeatures._kmer_cache
        if k not in cache:
            cache[k] = tuple(KmerFeatures.generate_kmers(k))
        return cache[k]

    @staticmethod
    def extract_kmer_frequencies(sequence: str, k: int,
                                 normalize: bool = True) -> Dict[str, float]:
        """Extract k-mer frequencies from a sequence.

        Windows containing a non-standard residue are skipped. With
        ``normalize`` the counts are divided by the number of valid windows.

        Returns:
            One ``kmer_{k}_{kmer}`` entry for every possible k-mer (0.0 if absent).
        """
        all_kmers = KmerFeatures._cached_kmers(k)
        if len(sequence) < k:
            return {f'kmer_{k}_{kmer}': 0.0 for kmer in all_kmers}

        alphabet = AminoAcidProperties.MOLECULAR_WEIGHT
        kmer_counts = Counter(
            sequence[i:i + k]
            for i in range(len(sequence) - k + 1)
            if all(aa in alphabet for aa in sequence[i:i + k])
        )

        total = sum(kmer_counts.values())
        if normalize and total > 0:
            values = {kmer: count / total for kmer, count in kmer_counts.items()}
        else:
            values = kmer_counts

        return {f'kmer_{k}_{kmer}': values.get(kmer, 0.0) for kmer in all_kmers}

    @staticmethod
    def extract_dipeptide_composition(sequence: str) -> Dict[str, float]:
        """Extract dipeptide (2-mer) composition - 400 features."""
        return KmerFeatures.extract_kmer_frequencies(sequence, k=2, normalize=True)

    @staticmethod
    def extract_tripeptide_composition(sequence: str, top_n: int = 100) -> Dict[str, float]:
        """Extract rank-ordered frequencies of the top N tripeptides (instead of 8000 features).

        Feature ``kmer_3_top_{i}`` is the relative frequency of the i-th most
        common valid tripeptide in this sequence, padded with 0.0. The key set
        is therefore identical for every sequence. The tripeptide identity is
        not encoded; use ``extract_kmer_frequencies(seq, 3)`` for that.
        """
        if len(sequence) < 3:
            return {f'kmer_3_top_{i}': 0.0 for i in range(top_n)}

        alphabet = AminoAcidProperties.MOLECULAR_WEIGHT
        tripeptide_counts = Counter(
            sequence[i:i + 3]
            for i in range(len(sequence) - 2)
            if all(aa in alphabet for aa in sequence[i:i + 3])
        )

        total = sum(tripeptide_counts.values())
        ranked = ([count / total for _, count in tripeptide_counts.most_common(top_n)]
                  if total > 0 else [])
        # Pad to a fixed width so every sequence yields the same columns.
        ranked += [0.0] * (top_n - len(ranked))

        return {f'kmer_3_top_{rank}': freq for rank, freq in enumerate(ranked)}


In [None]:
# ============================================================================
# 5. POSITIONAL FEATURES
# ============================================================================

class PositionalFeatures:
    """Extract position-specific features."""

    @staticmethod
    def extract_terminal_composition(sequence: str, n: int = 25) -> Dict[str, float]:
        """Extract amino acid composition of the N-terminal and C-terminal regions.

        The first and last ``n`` residues are used. Sequences shorter than
        ``n`` use the whole sequence for both ends. ``n <= 0`` yields an empty
        C-terminal window (all zeros) instead of the full sequence.

        Returns:
            ``n_term_{aa}`` for every standard residue, then ``c_term_{aa}``.
        """
        n_term = sequence[:n]
        # sequence[-0:] would be the whole sequence, so guard n <= 0 explicitly.
        c_term = sequence[-n:] if n > 0 else ''

        features = {}
        for prefix, segment in (('n_term', n_term), ('c_term', c_term)):
            counts = Counter(segment)
            size = len(segment)
            for aa in AminoAcidProperties.MOLECULAR_WEIGHT.keys():
                features[f'{prefix}_{aa}'] = counts.get(aa, 0) / size if size > 0 else 0.0

        return features

    @staticmethod
    def extract_region_properties(sequence: str, n_regions: int = 5) -> Dict[str, float]:
        """Divide the sequence into regions and return mean hydrophobicity/charge per region.

        Regions 0..n_regions-2 have ``len(sequence) // n_regions`` residues and
        the last region takes the remainder. A region that is empty (sequence
        shorter than ``n_regions``, or an empty sequence) reports 0.0, so every
        sequence yields the same ``2 * n_regions`` keys. ``n_regions <= 0``
        returns an empty dict.
        """
        if n_regions <= 0:
            return {}

        length = len(sequence)
        region_size = length // n_regions
        features = {}

        for i in range(n_regions):
            start = i * region_size
            end = start + region_size if i < n_regions - 1 else length
            region = sequence[start:end]

            if region:
                hydro = np.mean([AminoAcidProperties.HYDROPHOBICITY.get(aa, 0) for aa in region])
                charge = np.mean([AminoAcidProperties.CHARGE.get(aa, 0) for aa in region])
            else:
                hydro = charge = 0.0

            features[f'region_{i}_hydrophobicity'] = hydro
            features[f'region_{i}_charge'] = charge

        return features

    @staticmethod
    def extract_position_specific_scoring(sequence: str, window: int = 5) -> Dict[str, float]:
        """Summarise sliding-window hydrophobicity and charge profiles.

        For every window of ``window`` consecutive residues this computes the
        mean hydrophobicity and the summed charge, then reports the mean and
        std of each profile. Despite the ``pssm_`` key prefix, this is not a
        PSSM from a sequence alignment. All values are 0.0 when the sequence
        is shorter than ``window`` or ``window < 1``.
        """
        if window < 1 or len(sequence) < window:
            return {
                'pssm_hydro_mean': 0.0,
                'pssm_hydro_std': 0.0,
                'pssm_charge_mean': 0.0,
                'pssm_charge_std': 0.0
            }

        hydro = np.array([AminoAcidProperties.HYDROPHOBICITY.get(aa, 0) for aa in sequence],
                         dtype=float)
        charge = np.array([AminoAcidProperties.CHARGE.get(aa, 0) for aa in sequence],
                          dtype=float)

        # 'valid' convolution with a ones kernel gives the per-window sums
        # (len(sequence) - window + 1 of them) in one pass.
        kernel = np.ones(window)
        hydro_scores = np.convolve(hydro, kernel, mode='valid') / window
        charge_scores = np.convolve(charge, kernel, mode='valid')

        return {
            'pssm_hydro_mean': np.mean(hydro_scores),
            'pssm_hydro_std': np.std(hydro_scores),
            'pssm_charge_mean': np.mean(charge_scores),
            'pssm_charge_std': np.std(charge_scores)
        }

In [None]:
# ============================================================================
# 6. SEQUENCE PATTERNS
# ============================================================================

class SequencePatterns:
    """Rule-based detectors for tandem repeats and a few simple protein motifs."""

    @staticmethod
    def extract_repeats(sequence: str, min_length: int = 3, max_length: int = 6) -> Dict[str, float]:
        """Count distinct immediately-repeated (tandem) patterns per pattern length.

        For every size in ``[min_length, max_length]`` the key
        ``repeats_len_{size}`` holds how many distinct substrings of that size
        are directly followed by an identical copy. ``total_repeats`` is the
        sum over all sizes.
        """
        seq_len = len(sequence)
        counts = {}

        for size in range(min_length, max_length + 1):
            # Only start offsets that leave room for a second copy are scanned.
            n_starts = min(seq_len - size, seq_len - 2 * size + 1)
            tandem_patterns = {
                sequence[start:start + size]
                for start in range(n_starts)
                if sequence[start:start + size] == sequence[start + size:start + 2 * size]
            }
            counts[f'repeats_len_{size}'] = len(tandem_patterns)

        counts['total_repeats'] = sum(counts.values())
        return counts

    @staticmethod
    def extract_motifs(sequence: str) -> Dict[str, int]:
        """Flag or count simplified motif heuristics.

        - ``signal_peptide``: 1 if the first 30 residues have mean
          Kyte-Doolittle hydrophobicity above 1.5 (needs >= 30 residues).
        - ``transmembrane``: 1 if some run of residues from {A,V,L,I,M,F,W,P}
          is at least 20 long.
        - ``nuclear_localization``: 1 if any 5-residue window has >= 4 K/R.
        - ``glycosylation_site``: number of N-X-S/T sequons with X != P.
        """
        result = {
            'signal_peptide': 0,
            'transmembrane': 0,
            'nuclear_localization': 0,
            'glycosylation_site': 0
        }

        if len(sequence) >= 30:
            leader_scores = [AminoAcidProperties.HYDROPHOBICITY.get(aa, 0) for aa in sequence[:30]]
            if np.mean(leader_scores) > 1.5:
                result['signal_peptide'] = 1

        hydrophobic_set = frozenset('AVLIMFWP')
        longest_run = run = 0
        for residue in sequence:
            run = run + 1 if residue in hydrophobic_set else 0
            longest_run = max(longest_run, run)
        if longest_run >= 20:
            result['transmembrane'] = 1

        basic_set = frozenset('KR')
        if any(sum(ch in basic_set for ch in sequence[pos:pos + 5]) >= 4
               for pos in range(len(sequence) - 4)):
            result['nuclear_localization'] = 1

        result['glycosylation_site'] = sum(
            1 for pos in range(len(sequence) - 2)
            if sequence[pos] == 'N' and sequence[pos + 1] != 'P' and sequence[pos + 2] in ('S', 'T')
        )

        return result



In [None]:
# ============================================================================
# 7. MAIN FEATURE EXTRACTOR
# ============================================================================

class ProteinFeatureExtractor:
    """Main class for extracting all protein features."""

    def __init__(self, include_dipeptides: bool = True,
                 include_tripeptides: bool = False,
                 tripeptide_top_n: int = 100):
        """
        Initialize feature extractor.

        Args:
            include_dipeptides: Whether to include dipeptide features (400 features)
            include_tripeptides: Whether to include tripeptide features (can be many)
            tripeptide_top_n: Number of top tripeptides to include
        """
        self.include_dipeptides = include_dipeptides
        self.include_tripeptides = include_tripeptides
        self.tripeptide_top_n = tripeptide_top_n

    def extract_sequence_features(self, sequence: str) -> Dict[str, float]:
        """
        Extract all features from a single protein sequence.

        Args:
            sequence: Protein sequence string

        Returns:
            Dictionary of features (basic, physicochemical, optional k-mer,
            positional and pattern features, in that order)
        """
        features = {}

        # 1. Basic features
        features.update(BasicSequenceFeatures.extract_length_features(sequence))
        features.update(BasicSequenceFeatures.extract_molecular_weight(sequence))
        features.update(BasicSequenceFeatures.extract_composition(sequence))
        features.update(BasicSequenceFeatures.extract_group_composition(sequence))

        # 2. Physicochemical properties
        features.update(PhysicochemicalFeatures.extract_all_properties(sequence))
        features.update(PhysicochemicalFeatures.extract_isoelectric_point(sequence))

        # 3. K-mers (optional, can be large)
        if self.include_dipeptides:
            features.update(KmerFeatures.extract_dipeptide_composition(sequence))

        if self.include_tripeptides:
            features.update(KmerFeatures.extract_tripeptide_composition(
                sequence, top_n=self.tripeptide_top_n))

        # 4. Positional features
        features.update(PositionalFeatures.extract_terminal_composition(sequence, n=25))
        features.update(PositionalFeatures.extract_region_properties(sequence, n_regions=5))
        features.update(PositionalFeatures.extract_position_specific_scoring(sequence))

        # 5. Sequence patterns
        features.update(SequencePatterns.extract_repeats(sequence))
        features.update(SequencePatterns.extract_motifs(sequence))

        return features

    def extract_batch_features(self, sequences: Dict[str, str],
                              verbose: bool = True) -> pd.DataFrame:
        """
        Extract features for multiple sequences.

        Sequences whose extraction raises are skipped (best effort); the
        error is printed when ``verbose`` is set.

        Args:
            sequences: Dictionary mapping protein_id to sequence
            verbose: Whether to print progress

        Returns:
            DataFrame with protein_id as index and features as columns
        """
        if verbose:
            print(f"Extracting features for {len(sequences):,} sequences...")

        features_list = []
        protein_ids = []

        for i, (protein_id, sequence) in enumerate(sequences.items()):
            if verbose and (i + 1) % 10000 == 0:
                print(f"  Processed {i+1:,} / {len(sequences):,} sequences...")

            try:
                features = self.extract_sequence_features(sequence)
            except Exception as e:
                # Deliberate best-effort: one malformed record should not abort the batch.
                if verbose:
                    print(f"  Error processing {protein_id}: {str(e)}")
                continue
            features_list.append(features)
            protein_ids.append(protein_id)

        df = pd.DataFrame(features_list, index=protein_ids)

        if verbose:
            print(f"✓ Feature extraction complete!")
            print(f"  Total proteins: {len(df):,}")
            print(f"  Total features: {len(df.columns):,}")
            print(f"  Feature types:")
            self._print_feature_summary(df)

        return df

    def _print_feature_summary(self, df: pd.DataFrame):
        """Print how many columns fall into each feature family.

        Each column is counted exactly once, under the first rule in
        ``rules`` whose substrings it contains. Specific prefixes are tested
        before the generic physicochemical substrings, so that e.g.
        ``group_aromatic`` or ``region_0_hydrophobicity`` is not counted
        twice. Columns matching no rule are reported as ``Other``.
        """
        # Report order.
        display_order = ['length', 'composition', 'group', 'physicochemical',
                         'kmer', 'terminal', 'regional', 'pattern', 'other']
        # Matching priority: most specific first.
        rules = [
            ('length', ('length',)),
            ('composition', ('comp_',)),
            ('group', ('group_',)),
            ('kmer', ('kmer_',)),
            ('terminal', ('_term_',)),
            ('regional', ('region_',)),
            ('pattern', ('repeat', 'signal', 'transmembrane', 'nuclear', 'glycosylation')),
            ('physicochemical', ('hydro', 'charge', 'volume', 'polar', 'aromatic')),
        ]

        feature_types = dict.fromkeys(display_order, 0)
        for column in df.columns:
            name = str(column)
            for feat_type, patterns in rules:
                if any(pattern in name for pattern in patterns):
                    feature_types[feat_type] += 1
                    break
            else:
                feature_types['other'] += 1

        for feat_type, count in feature_types.items():
            if count > 0:
                print(f"    {feat_type.capitalize()}: {count}")


In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Optional
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import warnings
warnings.filterwarnings('ignore')

# Global plotting defaults: seaborn's white-grid theme and a wide default figure.
sns.set_style("whitegrid")
plt.rcParams.update({'figure.figsize': (14, 6)})


# ============================================================================
# FEATURE VISUALIZATION CLASS
# ============================================================================

class FeatureVisualizer:
    """Visualize feature engineering results."""
    
    def __init__(self):
        """Initialize feature visualizer."""
        self.feature_types = {
            'length': 'Length Features',
            'comp_': 'Composition Features',
            'group_': 'Group Composition',
            'hydro': 'Hydrophobicity',
            'charge': 'Charge',
            'volume': 'Volume',
            'polar': 'Polarity',
            'aromatic': 'Aromaticity',
            'kmer_2_': 'Dipeptides',
            'kmer_3_': 'Tripeptides',
            'n_term_': 'N-terminal',
            'c_term_': 'C-terminal',
            'region_': 'Regional',
            'pssm_': 'PSSM',
            'repeat': 'Repeats',
            'go_': 'GO Features',
            'esm': 'ESM Embeddings',
            'prot': 'ProtT5 Embeddings'
        }
    
    def categorize_features(self, feature_names: List[str]) -> Dict[str, List[str]]:
        """Categorize features by type."""
        categories = {name: [] for name in self.feature_types.values()}
        categories['Other'] = []
        
        for feature in feature_names:
            categorized = False
            for pattern, category in self.feature_types.items():
                if pattern in feature:
                    categories[category].append(feature)
                    categorized = True
                    break
            
            if not categorized:
                categories['Other'].append(feature)
        
        # Remove empty categories
        return {k: v for k, v in categories.items() if len(v) > 0}
    
    def plot_feature_category_breakdown(self, features_df: pd.DataFrame,
                                       save_path: str = None):
        """
        Plot breakdown of features by category.
        
        Args:
            features_df: DataFrame with extracted features
            save_path: Optional path to save figure
        """
        categories = self.categorize_features(features_df.columns.tolist())
        
        # Count features per category
        category_counts = {k: len(v) for k, v in categories.items()}
        sorted_categories = sorted(category_counts.items(), 
                                  key=lambda x: x[1], reverse=True)
        
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
        
        # Bar chart
        names, counts = zip(*sorted_categories)
        colors = plt.cm.Set3(np.linspace(0, 1, len(names)))
        
        bars = ax1.bar(range(len(names)), counts, color=colors, 
                      alpha=0.8, edgecolor='black')
        ax1.set_xticks(range(len(names)))
        ax1.set_xticklabels(names, rotation=45, ha='right')
        ax1.set_ylabel('Number of Features', fontsize=12, fontweight='bold')
        ax1.set_title('Feature Count by Category', fontsize=14, fontweight='bold')
        ax1.grid(axis='y', alpha=0.3)
        
        # Add value labels
        for bar, count in zip(bars, counts):
            height = bar.get_height()
            ax1.text(bar.get_x() + bar.get_width()/2., height,
                    f'{count}', ha='center', va='bottom', fontweight='bold')
        
        # Pie chart
        ax2.pie(counts, labels=names, autopct='%1.1f%%', 
               colors=colors, startangle=90)
        ax2.set_title('Feature Distribution', fontsize=14, fontweight='bold')
        
        plt.tight_layout()
        
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        
        plt.show()
        
        # Print summary
        print(f"\n{'='*60}")
        print(f"FEATURE CATEGORY BREAKDOWN")
        print(f"{'='*60}")
        print(f"Total features: {len(features_df.columns):,}")
        for name, count in sorted_categories:
            print(f"  {name}: {count:,} ({count/len(features_df.columns)*100:.1f}%)")
        print(f"{'='*60}\n")
    
    def plot_feature_distributions(self, features_df: pd.DataFrame,
                                  sample_features: int = 16,
                                  save_path: str = None):
        """
        Plot distribution of sample features.
        
        Args:
            features_df: DataFrame with features
            sample_features: Number of features to visualize
            save_path: Optional path to save figure
        """
        # Sample features from different categories
        categories = self.categorize_features(features_df.columns.tolist())
        sampled_features = []
        
        features_per_category = max(1, sample_features // len(categories))
        
        for category, feature_list in categories.items():
            n_sample = min(features_per_category, len(feature_list))
            sampled_features.extend(np.random.choice(feature_list, n_sample, replace=False))
        
        sampled_features = sampled_features[:sample_features]
        
        # Create subplots
        n_cols = 4
        n_rows = (len(sampled_features) + n_cols - 1) // n_cols
        
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, n_rows * 3))
        axes = axes.flatten() if n_rows > 1 else [axes] if n_cols == 1 else axes
        
        for idx, feature in enumerate(sampled_features):
            ax = axes[idx]
            
            data = features_df[feature].dropna()
            
            # Plot histogram
            ax.hist(data, bins=30, color='steelblue', alpha=0.7, edgecolor='black')
            
            # Add mean line
            mean_val = data.mean()
            ax.axvline(mean_val, color='red', linestyle='--', linewidth=2,
                      label=f'Mean: {mean_val:.3f}')
            
            # Formatting
            ax.set_title(feature[:30] + '...' if len(feature) > 30 else feature,
                        fontsize=9, fontweight='bold')
            ax.set_xlabel('Value', fontsize=8)
            ax.set_ylabel('Frequency', fontsize=8)
            ax.legend(fontsize=7)
            ax.grid(alpha=0.3)
        
        # Hide unused subplots
        for idx in range(len(sampled_features), len(axes)):
            axes[idx].axis('off')
        
        plt.suptitle('Sample Feature Distributions', fontsize=16, fontweight='bold', y=1.00)
        plt.tight_layout()
        
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        
        plt.show()
    
    def plot_feature_correlation(self, features_df: pd.DataFrame,
                                 sample_size: int = 50,
                                 method: str = 'pearson',
                                 save_path: str = None,
                                 threshold: float = 0.8,
                                 random_state: Optional[int] = None):
        """
        Plot a lower-triangle correlation heatmap for a sample of features.

        When there are more numeric features than ``sample_size``, features
        are drawn per category (as grouped by ``self.categorize_features``)
        so the heatmap is not dominated by one large group. Afterwards the
        (up to 10) strongest pairwise correlations above ``threshold`` are
        printed.

        Args:
            features_df: DataFrame with features; non-numeric columns are ignored.
            sample_size: Maximum number of features to include in the heatmap.
            method: Correlation method passed to ``DataFrame.corr``
                ('pearson', 'spearman', 'kendall').
            save_path: Optional path to save the figure.
            threshold: Pairs with ``|r|`` strictly greater than this are reported.
            random_state: Optional seed for the feature sampling. ``None`` uses
                NumPy's global RNG.
        """
        # Ensure we only have numeric columns
        numeric_features = features_df.select_dtypes(include=[np.number])

        if len(numeric_features.columns) == 0:
            print("⚠️  No numeric features found for correlation analysis")
            return

        if len(numeric_features.columns) > sample_size:
            rng = np.random if random_state is None else np.random.RandomState(random_state)

            # Sample from different categories so every group is represented.
            categories = self.categorize_features(numeric_features.columns.tolist())
            # max(1, ...) on the divisor guards against an empty category mapping.
            features_per_category = max(1, sample_size // max(1, len(categories)))

            sampled_features = []
            for feature_list in categories.values():
                n_sample = min(features_per_category, len(feature_list))
                if n_sample > 0:
                    sampled_features.extend(rng.choice(feature_list, n_sample, replace=False))

            # De-duplicate while preserving order, in case a feature was listed
            # under more than one category (duplicate columns would break corr).
            sampled_features = list(dict.fromkeys(sampled_features))[:sample_size]

            if not sampled_features:
                # No usable categories: fall back to a uniform sample of all columns.
                sampled_features = list(rng.choice(numeric_features.columns,
                                                   sample_size, replace=False))

            features_subset = numeric_features[sampled_features]
        else:
            features_subset = numeric_features

        if features_subset.shape[1] < 2:
            print("⚠️  Need at least 2 numeric features for correlation analysis")
            return

        # Infinite values would turn whole correlation rows into NaN; treat them
        # like missing values and fill everything missing with 0.
        features_subset = features_subset.replace([np.inf, -np.inf], np.nan).fillna(0)

        corr_matrix = features_subset.corr(method=method)

        fig, ax = plt.subplots(figsize=(14, 12))

        # Hide the upper triangle (the matrix is symmetric); keep the diagonal.
        mask = np.triu(np.ones_like(corr_matrix, dtype=bool), k=1)

        sns.heatmap(corr_matrix, mask=mask, cmap='RdBu_r', center=0,
                    square=True, linewidths=0.5, cbar_kws={"shrink": 0.8},
                    vmin=-1, vmax=1, ax=ax)

        ax.set_title(f'Feature Correlation Heatmap ({method.capitalize()})\n'
                     f'Sample of {len(features_subset.columns)} features',
                     fontsize=14, fontweight='bold', pad=20)

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')

        plt.show()

        # Print high correlations
        print(f"\n{'='*60}")
        print(f"HIGH CORRELATIONS (|r| > {threshold})")
        print(f"{'='*60}")

        # Walk the strict upper triangle (row-major, i < j) in one vectorized pass.
        corr_values = corr_matrix.to_numpy()
        rows, cols = np.triu_indices_from(corr_values, k=1)
        pair_values = corr_values[rows, cols]
        # NaN correlations (e.g. constant columns) compare False and are skipped.
        hits = np.abs(pair_values) > threshold
        names = corr_matrix.columns
        high_corr = [(names[i], names[j], r)
                     for i, j, r in zip(rows[hits], cols[hits], pair_values[hits])]

        if high_corr:
            high_corr.sort(key=lambda pair: abs(pair[2]), reverse=True)
            for feat1, feat2, corr_val in high_corr[:10]:
                print(f"  {str(feat1)[:30]:<30} ↔ {str(feat2)[:30]:<30}: {corr_val:>6.3f}")
        else:
            print("  No high correlations found in sample")
        print(f"{'='*60}\n")
    
    def plot_feature_statistics(self, features_df: pd.DataFrame,
                                save_path: str = None):
        """
        Plot histograms of per-feature summary statistics and print a summary.

        For every numeric feature the mean, std, min, max, variance and the
        fraction of zero values are computed. Each statistic is shown as a
        histogram across features (IQR outliers trimmed) with its median marked.

        Args:
            features_df: DataFrame with features; non-numeric columns are ignored.
            save_path: Optional path to save the figure.
        """
        # Select only numeric columns
        numeric_features = features_df.select_dtypes(include=[np.number])

        if len(numeric_features.columns) == 0:
            print("⚠️  No numeric features found for statistics")
            return

        # One row per feature, one column per statistic.
        stats = pd.DataFrame({
            'mean': numeric_features.mean(),
            'std': numeric_features.std(),
            'min': numeric_features.min(),
            'max': numeric_features.max(),
            'variance': numeric_features.var(),
            'zeros': (numeric_features == 0).sum() / len(numeric_features)
        })

        # Create subplots (2x3 = 6 plots)
        fig, axes = plt.subplots(2, 3, figsize=(16, 10))
        axes = axes.flatten()

        # One panel per statistic; zip stops once the 6 axes are used up.
        for ax, (stat_name, stat_data) in zip(axes, stats.items()):
            # NaN (e.g. std of a single row) and +/-inf values would poison the
            # quantiles and make ax.hist raise, so keep only finite values.
            finite_data = stat_data.replace([np.inf, -np.inf], np.nan).dropna()

            if len(finite_data) > 0:
                # Trim IQR outliers for a readable histogram.
                q1, q3 = finite_data.quantile([0.25, 0.75])
                iqr = q3 - q1
                in_range = ((finite_data >= q1 - 1.5 * iqr) &
                            (finite_data <= q3 + 1.5 * iqr))
                median = finite_data.median()

                ax.hist(finite_data[in_range], bins=30, color='steelblue',
                        alpha=0.7, edgecolor='black')
                ax.axvline(median, color='red', linestyle='--',
                           linewidth=2, label=f'Median: {median:.3f}')
                # Only add a legend when a labelled artist was actually drawn.
                ax.legend(fontsize=9)

            ax.set_title(f'Feature {stat_name.capitalize()}',
                         fontsize=12, fontweight='bold')
            ax.set_xlabel('Value', fontsize=10)
            ax.set_ylabel('Number of Features', fontsize=10)
            ax.grid(alpha=0.3)

        plt.suptitle('Feature Statistics Distribution',
                     fontsize=16, fontweight='bold')
        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')

        plt.show()

        # Fraction of missing values per feature, for the summary only
        missing_stats = numeric_features.isnull().sum() / len(numeric_features)

        # Print summary
        print(f"\n{'='*60}")
        print(f"FEATURE STATISTICS SUMMARY")
        print(f"{'='*60}")
        print(f"Total features: {len(numeric_features.columns):,}")
        print(f"Total proteins: {len(numeric_features):,}")
        print(f"\nFeature Statistics:")
        print(f"  Mean of means: {stats['mean'].mean():.4f}")
        print(f"  Mean of std devs: {stats['std'].mean():.4f}")
        print(f"  Features with >50% zeros: {(stats['zeros'] > 0.5).sum()}")
        print(f"  Features with missing values: {(missing_stats > 0).sum()}")

        # Low variance features
        low_var = (stats['variance'] < 0.01).sum()
        print(f"  Low variance features (<0.01): {low_var}")
        print(f"{'='*60}\n")
    
    def plot_dimensionality_reduction(self, features_df: pd.DataFrame,
                                      labels: Optional[pd.Series] = None,
                                      method: str = 'pca',
                                      n_samples: int = 5000,
                                      save_path: str = None,
                                      standardize: bool = False):
        """
        Visualize features in 2D using PCA or t-SNE.

        Args:
            features_df: DataFrame with features; non-numeric columns are ignored.
            labels: Optional values used to colour points (e.g. number of GO
                terms). They are sampled with ``.iloc``, so they are assumed
                to be in the same row order as ``features_df``.
            method: 'pca' or 'tsne' (case-insensitive; 't-sne' also accepted).
            n_samples: Maximum number of proteins to plot (for performance).
            save_path: Optional path to save the figure.
            standardize: If True, z-score each feature before projecting so
                features with large numeric ranges do not dominate. The default
                False projects the raw feature values.

        Raises:
            ValueError: If ``method`` is not a supported method name.
        """
        method_key = method.lower()
        if method_key not in ('pca', 'tsne', 't-sne'):
            # Fail loudly instead of silently running t-SNE for a typo.
            raise ValueError(f"Unknown method {method!r}; expected 'pca' or 'tsne'")
        use_pca = method_key == 'pca'

        # Select only numeric columns
        numeric_features = features_df.select_dtypes(include=[np.number])

        if len(numeric_features.columns) == 0:
            print("⚠️  No numeric features found for dimensionality reduction")
            return

        # Sample if too many proteins (positional, so labels must share row order)
        if len(numeric_features) > n_samples:
            sample_idx = np.random.choice(len(numeric_features), n_samples, replace=False)
            features_sample = numeric_features.iloc[sample_idx]
            labels_sample = labels.iloc[sample_idx] if labels is not None else None
        else:
            features_sample = numeric_features
            labels_sample = labels

        # Both reducers reject NaN/inf, so replace them with 0.
        features_clean = features_sample.fillna(0).replace([np.inf, -np.inf], 0)

        n_rows, n_cols = features_clean.shape
        if n_rows < 2 or n_cols < 2:
            print("⚠️  Need at least 2 proteins and 2 numeric features for a 2D projection")
            return

        if standardize:
            # Constant columns have std 0; divide by 1 instead to avoid NaN.
            col_std = features_clean.std(ddof=0).replace(0, 1)
            features_clean = (features_clean - features_clean.mean()) / col_std

        # Apply dimensionality reduction
        if use_pca:
            reducer = PCA(n_components=2, random_state=42)
            reduced = reducer.fit_transform(features_clean)
            title = (f'PCA Visualization of Features\n'
                     f'Explained Variance: {reducer.explained_variance_ratio_.sum():.2%}')
        else:
            # scikit-learn requires perplexity < n_samples; clamp for small inputs.
            reducer = TSNE(n_components=2, random_state=42,
                           perplexity=min(30, n_rows - 1))
            reduced = reducer.fit_transform(features_clean)
            title = 't-SNE Visualization of Features'

        fig, ax = plt.subplots(figsize=(12, 10))

        if labels_sample is not None:
            scatter = ax.scatter(reduced[:, 0], reduced[:, 1],
                                 c=labels_sample, cmap='viridis',
                                 alpha=0.6, s=30, edgecolors='black', linewidth=0.5)
            cbar = plt.colorbar(scatter, ax=ax)
            cbar.set_label('Label Value', rotation=270, labelpad=20, fontsize=11)
        else:
            ax.scatter(reduced[:, 0], reduced[:, 1],
                       alpha=0.6, s=30, color='steelblue',
                       edgecolors='black', linewidth=0.5)

        ax.set_xlabel('Component 1', fontsize=12, fontweight='bold')
        ax.set_ylabel('Component 2', fontsize=12, fontweight='bold')
        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.grid(alpha=0.3)

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')

        plt.show()

        if use_pca:
            print(f"\n{'='*60}")
            print(f"PCA ANALYSIS")
            print(f"{'='*60}")
            print(f"Explained variance ratio:")
            print(f"  PC1: {reducer.explained_variance_ratio_[0]:.4f}")
            print(f"  PC2: {reducer.explained_variance_ratio_[1]:.4f}")
            print(f"  Total: {reducer.explained_variance_ratio_.sum():.4f}")
            print(f"{'='*60}\n")
    
    def plot_feature_importance(self, features_df: pd.DataFrame,
                                importance_values: np.ndarray,
                                top_n: int = 20,
                                save_path: str = None):
        """
        Plot the top features by importance and the cumulative importance curve.

        Args:
            features_df: DataFrame whose columns are, in order, the features the
                importances refer to.
            importance_values: One importance value per column of ``features_df``.
            top_n: Number of top features to show in the bar chart.
            save_path: Optional path to save the figure.

        Raises:
            ValueError: If the number of importance values does not match the
                number of columns in ``features_df``.
        """
        importance_values = np.ravel(np.asarray(importance_values, dtype=float))
        if importance_values.shape[0] != len(features_df.columns):
            raise ValueError(
                f"importance_values has {importance_values.shape[0]} entries but "
                f"features_df has {len(features_df.columns)} columns"
            )

        importance_df = pd.DataFrame({
            'feature': features_df.columns,
            'importance': importance_values
        }).sort_values('importance', ascending=False)

        top_features = importance_df.head(top_n)

        # Cumulative share of total importance, computed once. It is undefined
        # when importances are empty, sum to zero, or are non-finite; in that
        # case the curve and the 80%/95% counts are skipped instead of crashing.
        sorted_importance = importance_df['importance'].to_numpy()
        total_importance = np.nansum(sorted_importance)
        if np.isfinite(total_importance) and total_importance > 0:
            cumulative_share = np.nancumsum(sorted_importance) / total_importance
        else:
            cumulative_share = None

        def _n_features_for(share):
            """Return how many top features reach ``share`` of total importance, or None."""
            if cumulative_share is None:
                return None
            reached = np.flatnonzero(cumulative_share >= share)
            return int(reached[0]) + 1 if reached.size else None

        def _fmt_count(count):
            """Format a feature count with thousands separators, or 'n/a'."""
            return f"{count:,}" if count is not None else "n/a"

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))

        # Horizontal bar chart of the top features (most important at the top)
        colors = plt.cm.viridis(np.linspace(0, 1, len(top_features)))
        y_pos = np.arange(len(top_features))
        ax1.barh(y_pos, top_features['importance'], color=colors,
                 alpha=0.8, edgecolor='black')
        ax1.set_yticks(y_pos)
        ax1.set_yticklabels(top_features['feature'], fontsize=9)
        ax1.set_xlabel('Importance', fontsize=12, fontweight='bold')
        # Report the number of bars actually drawn (may be fewer than top_n).
        ax1.set_title(f'Top {len(top_features)} Most Important Features',
                      fontsize=14, fontweight='bold')
        ax1.invert_yaxis()
        ax1.grid(axis='x', alpha=0.3)

        # Cumulative importance
        if cumulative_share is not None:
            ax2.plot(np.arange(1, len(cumulative_share) + 1), cumulative_share,
                     linewidth=2, color='steelblue')
        ax2.axhline(0.8, color='red', linestyle='--', linewidth=2,
                    label='80% cumulative importance')
        ax2.axhline(0.95, color='orange', linestyle='--', linewidth=2,
                    label='95% cumulative importance')

        ax2.set_xlabel('Number of Features', fontsize=12, fontweight='bold')
        ax2.set_ylabel('Cumulative Importance', fontsize=12, fontweight='bold')
        ax2.set_title('Cumulative Feature Importance',
                      fontsize=14, fontweight='bold')
        ax2.legend(fontsize=10)
        ax2.grid(alpha=0.3)

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')

        plt.show()

        # Print summary
        n_80 = _n_features_for(0.8)
        n_95 = _n_features_for(0.95)

        print(f"\n{'='*60}")
        print(f"FEATURE IMPORTANCE ANALYSIS")
        print(f"{'='*60}")
        print(f"Total features: {len(features_df.columns):,}")
        print(f"Features for 80% importance: {_fmt_count(n_80)}")
        print(f"Features for 95% importance: {_fmt_count(n_95)}")
        print(f"\nTop 10 Features:")
        for feature, importance in importance_df.head(10).itertuples(index=False):
            # str() so non-string column names (e.g. ints) can be sliced.
            print(f"  {str(feature)[:50]:<50}: {importance:.6f}")
        print(f"{'='*60}\n")
    
    def plot_all_feature_analysis(self, features_df: pd.DataFrame,
                                  labels: Optional[pd.Series] = None,
                                  importance_values: Optional[np.ndarray] = None,
                                  save_dir: str = None):
        """
        Run every feature-engineering diagnostic plot in a fixed order.

        The plots are: category breakdown, distributions, correlation,
        statistics and PCA. The importance plot is added only when
        ``importance_values`` is given.

        Args:
            features_df: DataFrame with all features.
            labels: Optional values used to colour the PCA scatter plot.
            importance_values: Optional per-feature importances from a model.
            save_dir: Optional directory that receives numbered PNG files.
                It is created if missing. When omitted, figures are only shown.
        """
        import os

        def _target(filename):
            # "<save_dir>/<filename>" when saving, otherwise None (show only).
            return f"{save_dir}/{filename}" if save_dir else None

        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        rule = "=" * 70

        print("\n" + rule)
        print(" " * 15 + "FEATURE ENGINEERING VISUALIZATIONS")
        print(rule + "\n")

        print("1️⃣  Generating feature category breakdown...")
        self.plot_feature_category_breakdown(
            features_df, save_path=_target("01_feature_categories.png"))

        print("\n2️⃣  Generating feature distributions...")
        self.plot_feature_distributions(
            features_df, sample_features=16,
            save_path=_target("02_feature_distributions.png"))

        print("\n3️⃣  Generating correlation heatmap...")
        self.plot_feature_correlation(
            features_df, sample_size=50,
            save_path=_target("03_feature_correlation.png"))

        print("\n4️⃣  Generating feature statistics...")
        self.plot_feature_statistics(
            features_df, save_path=_target("04_feature_statistics.png"))

        print("\n5️⃣  Generating dimensionality reduction visualization...")
        self.plot_dimensionality_reduction(
            features_df, labels=labels, method='pca',
            save_path=_target("05_pca_visualization.png"))

        # The importance panel needs model output, so it is optional.
        if importance_values is not None:
            print("\n6️⃣  Generating feature importance analysis...")
            self.plot_feature_importance(
                features_df, importance_values,
                save_path=_target("06_feature_importance.png"))

        print("\n" + rule)
        print(" " * 20 + "✓ ALL VISUALIZATIONS COMPLETE")
        if save_dir:
            print(" " * 20 + f"Saved to: {save_dir}/")
        print(rule + "\n")

In [None]:
# ============================================================================
# 8. USAGE
# ============================================================================

if __name__ == "__main__":
    # Load training sequences keyed by protein accession.
    sequences = {}
    fasta_path = '/kaggle/input/cafa-6-protein-function-prediction/Train/train_sequences.fasta'
    for record in SeqIO.parse(fasta_path, 'fasta'):
        # Record ids are assumed to be pipe-delimited with the accession in the
        # second field (e.g. "sp|<accession>|<entry name>"). Ids without a
        # pipe fall back to the full id instead of raising IndexError.
        id_parts = record.id.split('|')
        protein_id = id_parts[1] if len(id_parts) > 1 else record.id
        sequences[protein_id] = str(record.seq)

    # Extract features (dipeptide features enabled, tripeptide features disabled)
    extractor = ProteinFeatureExtractor(
        include_dipeptides=True,
        include_tripeptides=False
    )

    features_df = extractor.extract_batch_features(sequences, verbose=True)

    # Save features; the visualization cell reloads this CSV.
    features_df.to_csv('/kaggle/working/protein_features.csv')
    

In [None]:
# Reload the feature matrix written by the extraction cell (first column = index)
features_df = pd.read_csv(
    '/kaggle/working/protein_features.csv',
    index_col=0,
)

# Render every diagnostic plot and write the PNGs under the working directory.
# labels / importance_values can be supplied later (colouring, trained model).
viz = FeatureVisualizer()
viz.plot_all_feature_analysis(
    features_df=features_df,
    labels=None,
    importance_values=None,
    save_dir='/kaggle/working/feature_visualizations',
)