In [None]:
"""
K50 Protein Stability Dataset Processing Pipeline

This module processes the K50 dG dataset to extract and validate single-point 
mutations for protein stability prediction tasks. The pipeline performs:
1. Data loading and feature selection
2. Filtering for single-point mutations
3. Mutation annotation parsing
4. Wild-type sequence reconstruction
5. Mutation validation
6. Stability classification
7. Dataset export

Author: ML Research Team
Date: 2025
"""

import re
import warnings
from pathlib import Path
from typing import Optional, Tuple

import numpy as np
import pandas as pd

# Suppress warnings for cleaner console output.
# NOTE(review): with no category/module arguments this filter applies
# process-wide to every warning, not only pandas ones (it would also hide,
# e.g., pandas' DtypeWarning about mixed-type CSV columns); consider
# narrowing it or moving it into main().
warnings.filterwarnings("ignore")


# =============================================================================
# CONFIGURATION
# =============================================================================

# File paths
# NOTE(review): INPUT_PATH is a machine-specific Windows path; on non-Windows
# systems the backslashes are not path separators. Consider making it configurable.
INPUT_PATH = Path(r'D:\ML_Project\K50 data\Processed_K50_dG_datasets\K50_dG_Dataset1_Dataset2.csv')
# Relative path: resolved against the current working directory when written
OUTPUT_PATH = Path('k50_cleaned.csv')

# Regex pattern to parse mutation strings (e.g., "A123G" -> wild-type A, position 123, mutant G)
# Groups: (letters)(digits)(letters). Multi-letter codes such as "Ala123Gly"
# also match, so code that splices residues into a one-letter sequence must
# guard against them. Pattern.match() anchors only at the start of the string.
MUTATION_PATTERN = re.compile(r'([A-Za-z]+)(\d+)([A-Za-z]+)')

# Column names to retain from the raw dataset
FEATURES_TO_KEEP = [
    'name',           # Mutation/protein identifier
    'aa_seq',         # Amino acid sequence (mutant)
    'ddG_ML',         # Change in Gibbs free energy; NOTE(review): confirm what the '_ML' suffix denotes
    'mut_type',       # Mutation annotation (e.g., "A123G")
    'WT_name',        # Wild-type protein identifier
    'WT_cluster'      # Protein family/cluster identifier
]

# Stability classification threshold
# Mutations with ddG < 0 are labelled stabilizing (class 0)
# Mutations with ddG >= 0 are labelled destabilizing (class 1)
# NOTE(review): this assumes negative ddG = stabilizing. Confirm the sign
# convention of 'ddG_ML' in the source data; if dG is reported as unfolding
# free energy (higher = more stable), these labels would be inverted.
STABILITY_THRESHOLD = 0.0


# =============================================================================
# DATA LOADING
# =============================================================================

def load_k50_dataset(filepath: Path) -> pd.DataFrame:
    """
    Load the K50 dataset from a CSV file with validation.

    This function reads the complete K50 dataset and performs basic validation
    to ensure the path points to a regular file and that the file contains data.

    Args:
        filepath: Path to the K50 dataset CSV file. Plain strings and other
            path-like objects are accepted and converted to ``Path``.

    Returns:
        pd.DataFrame: Complete dataset with all original columns

    Raises:
        FileNotFoundError: If the path does not exist or is not a regular file
        ValueError: If the CSV file contains no data (zero bytes, or a header
            with no rows)

    Example:
        >>> df = load_k50_dataset(Path('data/k50_dataset.csv'))
        Loaded dataset shape: (851552, 38)
    """
    # Accept str / os.PathLike as well as Path so callers are not forced to wrap
    filepath = Path(filepath)

    # is_file() rather than exists(): a directory "exists" but would only fail
    # later inside read_csv with a less helpful error
    if not filepath.is_file():
        raise FileNotFoundError(f"Dataset not found at: {filepath}")

    try:
        # low_memory=False makes pandas infer each column's dtype from the whole
        # file instead of chunk by chunk, which otherwise can yield columns that
        # mix types across chunks
        df = pd.read_csv(filepath, low_memory=False)
    except pd.errors.EmptyDataError as err:
        # A zero-byte file has no columns at all; report it like an empty table
        raise ValueError("Loaded dataset is empty") from err

    # A header-only file parses successfully but has no rows
    if df.empty:
        raise ValueError("Loaded dataset is empty")

    # Print diagnostic information
    print(f"Loaded dataset shape: {df.shape}")
    print(f"Columns: {df.columns.tolist()}")

    return df


# =============================================================================
# FEATURE SELECTION AND FILTERING
# =============================================================================

def select_features(df: pd.DataFrame, features: list) -> pd.DataFrame:
    """
    Return a copy of ``df`` restricted to the requested columns.

    Narrowing the frame early keeps memory usage down and leaves only the
    fields that later pipeline stages read.

    Args:
        df: Source DataFrame that must contain every requested column
        features: Ordered list of column names to keep; the output follows
            this order

    Returns:
        pd.DataFrame: Independent copy holding only ``features``

    Raises:
        KeyError: If one or more requested columns are absent from ``df``

    Example:
        >>> select_features(df_raw, ['name', 'aa_seq', 'ddG_ML']).columns.tolist()
        ['name', 'aa_seq', 'ddG_ML']
    """
    absent = set(features) - set(df.columns)
    if not absent:
        # Copy so later in-place edits never propagate back to the caller's frame
        return df[features].copy()
    raise KeyError(f"Missing features in dataset: {absent}")


def filter_single_point_mutations(df: pd.DataFrame) -> pd.DataFrame:
    """
    Filter dataset to retain only valid single-point mutations.

    This function removes:
    0. Rows whose 'mut_type' is missing (NaN/None)
    1. Wild-type sequences ('mut_type' equal to 'wt', case-insensitive)
    2. Multiple mutations (indicated by ':' in the mutation string)
    3. Insertions and deletions (starting with 'ins' or 'del', case-insensitive)

    Why filter these?
    - Wild-type sequences have no mutation to describe
    - Multiple mutations are complex and harder to model
    - Indels require different processing than substitutions

    The returned frame's 'mut_type' values are whitespace-stripped strings.
    The input DataFrame is not modified.

    Args:
        df: Input DataFrame with 'mut_type' column

    Returns:
        pd.DataFrame: Filtered DataFrame containing only single-point substitutions

    Example:
        >>> df_filtered = filter_single_point_mutations(df)
        Original records: 851552
        After filtering: 375560
        Removed: 475992 records
    """
    # Create a copy to avoid modifying the input
    df = df.copy()

    # Record missing annotations before astype(str), which would otherwise turn
    # NaN/None into the literal strings 'nan'/'None' and let them through
    has_annotation = df['mut_type'].notna()

    # Ensure mut_type is a clean string (remove leading/trailing whitespace)
    df['mut_type'] = df['mut_type'].astype(str).str.strip()

    # Case-fold once so the WT and indel checks are both case-insensitive
    mut_lower = df['mut_type'].str.lower()

    # Mask 1: Exclude wild-type sequences
    not_wt = ~mut_lower.eq('wt')

    # Mask 2: Exclude multiple mutations (contain ':' separator)
    # Example of multiple mutation: "A123G:T456A"
    not_multiple = ~df['mut_type'].str.contains(':', regex=False, na=False)

    # Mask 3: Exclude insertions and deletions. Case-insensitive so that e.g.
    # "DEL3A" is not later parsed as a substitution whose wild-type residue is
    # the three letters "DEL".
    not_indel = ~mut_lower.str.startswith(('del', 'ins'), na=False)

    # Combine all masks with AND logic
    mask = has_annotation & not_wt & not_multiple & not_indel

    # Apply the combined filter
    filtered_df = df[mask].copy()

    # Report filtering statistics
    print(f"Original records: {len(df)}")
    print(f"After filtering: {len(filtered_df)}")
    print(f"Removed: {len(df) - len(filtered_df)} records")

    return filtered_df


# =============================================================================
# MUTATION PARSING
# =============================================================================

def parse_mutation(mutation_str: str) -> Optional[Tuple[str, int, str]]:
    """
    Parse a mutation string into its constituent components.

    Mutation strings follow the format: <WT_residue><position><Mutant_residue>
    - WT_residue: Original amino acid (one or more letters, e.g. 1- or 3-letter code)
    - position: 1-indexed position in the sequence
    - Mutant_residue: Mutated amino acid (one or more letters)

    The whole string must match this format; inputs with extra trailing text
    (e.g. "A12G3") are rejected rather than partially parsed. Non-string
    input (e.g. NaN from an empty CSV cell) is also rejected.

    Args:
        mutation_str: Mutation annotation string (e.g., "A123G" or "Ala123Gly")

    Returns:
        Tuple of (wild_type_residue, position, mutant_residue) if parsing succeeds,
        None if the input is not a string or doesn't match the expected format

    Example:
        >>> parse_mutation("A123G")
        ('A', 123, 'G')
        >>> parse_mutation("Ala123Gly")
        ('Ala', 123, 'Gly')
        >>> parse_mutation("invalid") is None
        True
    """
    # Missing values from pandas arrive as float NaN; regex matching would raise
    if not isinstance(mutation_str, str):
        return None

    # fullmatch: Pattern.match only anchors at the start, so trailing text
    # would otherwise be silently ignored
    match = MUTATION_PATTERN.fullmatch(mutation_str)
    if match is None:
        return None

    wt_residue, position, mut_residue = match.groups()

    # Convert position to integer for numerical operations
    return wt_residue, int(position), mut_residue


def annotate_mutations(df: pd.DataFrame) -> pd.DataFrame:
    """
    Split each 'mut_type' annotation into structured columns.

    Every annotation is run through ``parse_mutation``. Rows whose annotation
    cannot be parsed are dropped, and the surviving rows gain three columns:
    - wt_residue: Wild-type residue string
    - mutation_position: Integer position as written in the annotation (1-indexed)
    - mutated_residue: Mutant residue string

    The input DataFrame is not modified.

    Args:
        df: DataFrame with a 'mut_type' column

    Returns:
        pd.DataFrame: Parseable rows only, with the three annotation columns added

    Example:
        >>> df_annotated = annotate_mutations(df_filtered)
        Successfully parsed 375560 mutations
    """
    result = df.copy()

    # One (wt, position, mutant) tuple per row, or None when unparseable
    pieces = result['mut_type'].apply(parse_mutation)
    keep = pieces.notna()

    result = result[keep].copy()
    pieces = pieces[keep]

    # Fan each tuple slot out into its own named column (i=slot pins the index
    # at lambda creation time rather than at call time)
    slots = (('wt_residue', 0), ('mutation_position', 1), ('mutated_residue', 2))
    for column, slot in slots:
        result[column] = pieces.apply(lambda triple, i=slot: triple[i])

    print(f"Successfully parsed {len(result)} mutations")

    return result


# =============================================================================
# WILD-TYPE SEQUENCE RECONSTRUCTION
# =============================================================================

def reconstruct_wt_sequence(mutant_seq: str, mutation_str: str) -> Optional[str]:
    """
    Reconstruct the wild-type sequence from mutant sequence and mutation annotation.

    This is essentially "reversing" the mutation:
    - We know the mutant sequence
    - We know what position was mutated and what it was changed to
    - We can substitute back the wild-type residue to get the original sequence

    Validation checks (any failure returns None):
    1. Both inputs are strings and the annotation fully matches MUTATION_PATTERN
    2. Both residues are single letters, so the substitution keeps the
       sequence length unchanged
    3. Position is within sequence bounds (1 .. len(mutant_seq))
    4. The mutant sequence has the annotated mutant residue at that position

    Args:
        mutant_seq: Amino acid sequence of the mutant protein
        mutation_str: Mutation annotation (e.g., "A123G")

    Returns:
        Reconstructed wild-type sequence string, or None if reconstruction fails

    Example:
        >>> mutant = "MKGLWSKSSVIG"
        >>> reconstruct_wt_sequence(mutant, "K3G")  # position 3 is G in the mutant
        'MKKLWSKSSVIG'
        >>> reconstruct_wt_sequence(mutant, "K2G") is None  # position 2 is K, not G
        True
    """
    # Missing CSV cells arrive as float NaN; treat them as unreconstructable
    if not isinstance(mutant_seq, str) or not isinstance(mutation_str, str):
        return None

    # fullmatch so trailing text after the mutant residue is not ignored
    parsed = MUTATION_PATTERN.fullmatch(mutation_str)
    if parsed is None:
        return None

    wt_residue, position, mut_residue = parsed.groups()

    # MUTATION_PATTERN also accepts multi-letter codes (e.g. "Ala3G"); splicing
    # such a residue into a one-letter sequence would change its length
    if len(wt_residue) != 1 or len(mut_residue) != 1:
        return None

    # Convert from 1-indexed (biological convention) to 0-indexed (Python)
    position_idx = int(position) - 1

    # Validate that position is within sequence bounds
    if not 0 <= position_idx < len(mutant_seq):
        return None

    # Consistency check: the annotation must agree with the mutant sequence
    if mutant_seq[position_idx] != mut_residue:
        return None

    # [before mutation] + [WT residue] + [after mutation]
    return mutant_seq[:position_idx] + wt_residue + mutant_seq[position_idx + 1:]


def add_wt_sequences(df: pd.DataFrame) -> pd.DataFrame:
    """
    Add reconstructed wild-type sequences to the DataFrame.

    Applies ``reconstruct_wt_sequence`` to every row, creating a new column
    'WT_sequence' that holds the original (pre-mutation) sequence, or None
    where reconstruction failed. The input DataFrame is not modified.

    Args:
        df: DataFrame with 'aa_seq' (mutant) and 'mut_type' columns

    Returns:
        pd.DataFrame: DataFrame with added 'WT_sequence' column

    Example:
        >>> df_with_wt = add_wt_sequences(df_annotated)
        Successfully reconstructed 375560/375560 WT sequences
    """
    df = df.copy()

    # Walk the two columns in lockstep instead of DataFrame.apply(axis=1):
    # on a zero-row frame apply() returns a DataFrame rather than a Series,
    # which cannot be assigned to a single column, and it is also much slower
    # because it materialises a Series object per row.
    df['WT_sequence'] = [
        reconstruct_wt_sequence(mutant_seq, mutation_str)
        for mutant_seq, mutation_str in zip(df['aa_seq'], df['mut_type'])
    ]

    # Count how many sequences were successfully reconstructed
    num_reconstructed = df['WT_sequence'].notna().sum()
    print(f"Successfully reconstructed {num_reconstructed}/{len(df)} WT sequences")

    return df


# =============================================================================
# MUTATION VALIDATION
# =============================================================================

def validate_mutation(mutant_seq: str, wt_seq: str, mutation_str: str) -> bool:
    """
    Validate that mutation annotation is consistent with both sequences.

    Checks performed (any failure returns False):
    1. All three inputs are strings (a failed reconstruction leaves wt_seq as
       None/NaN) and the annotation fully matches MUTATION_PATTERN
    2. Position is within the bounds of BOTH sequences
    3. Mutant sequence has the mutant residue at the specified position
    4. Wild-type sequence has the WT residue at the specified position

    Why is this important?
    - Catches data entry errors
    - Catches alignment issues
    - Prevents training ML models on corrupted data

    NOTE(review): when wt_seq was produced by reconstruct_wt_sequence from the
    same annotation, check 4 holds by construction for single-letter residues;
    it only adds information when wt_seq comes from an independent source.

    Args:
        mutant_seq: Mutant amino acid sequence
        wt_seq: Wild-type amino acid sequence
        mutation_str: Mutation annotation (e.g., "A123G")

    Returns:
        True if all validation checks pass, False otherwise

    Example:
        >>> wt = "MKKLWSKSSVIG"
        >>> mut = "MKGLWSKSSVIG"
        >>> validate_mutation(mut, wt, "K3G")
        True
        >>> validate_mutation(mut, wt, "K3A")  # Wrong mutant residue
        False
        >>> validate_mutation(mut, None, "K3G")  # No WT sequence available
        False
    """
    # Non-string values (None, NaN) cannot be indexed or matched
    if not all(isinstance(value, str) for value in (mutant_seq, wt_seq, mutation_str)):
        return False

    # fullmatch so trailing text after the mutant residue is not ignored
    parsed = MUTATION_PATTERN.fullmatch(mutation_str)
    if parsed is None:
        return False

    wt_residue, position, mut_residue = parsed.groups()
    position_idx = int(position) - 1  # Convert to 0-indexed

    # Check 2: bounds must hold for both sequences before either is indexed
    if not (0 <= position_idx < len(mutant_seq) and position_idx < len(wt_seq)):
        return False

    # Checks 3 and 4: each sequence carries the annotated residue at the position
    return mutant_seq[position_idx] == mut_residue and wt_seq[position_idx] == wt_residue


def add_validation_flags(df: pd.DataFrame) -> pd.DataFrame:
    """
    Add validation flag indicating whether each mutation is consistent.

    Applies ``validate_mutation`` to every row, creating a boolean column
    'is_valid_mutation'. Rows with is_valid_mutation=False should be
    investigated or removed. The input DataFrame is not modified.

    Args:
        df: DataFrame with mutation data (must have aa_seq, WT_sequence, mut_type)

    Returns:
        pd.DataFrame: DataFrame with added 'is_valid_mutation' boolean column

    Example:
        >>> df_validated = add_validation_flags(df_with_wt)
        Validation results:
          Valid mutations: 375560
          Invalid mutations: 0
    """
    df = df.copy()

    # Lockstep iteration instead of DataFrame.apply(axis=1), which returns a
    # DataFrame (not a Series) for a zero-row frame and then fails on assignment
    flags = [
        validate_mutation(mutant_seq, wt_seq, mutation_str)
        for mutant_seq, wt_seq, mutation_str in zip(
            df['aa_seq'], df['WT_sequence'], df['mut_type']
        )
    ]
    # Explicit bool dtype so '~' below is a logical NOT even for an empty frame
    df['is_valid_mutation'] = np.array(flags, dtype=bool)

    # Count validation results
    num_valid = df['is_valid_mutation'].sum()
    num_invalid = (~df['is_valid_mutation']).sum()

    # Report statistics
    print("Validation results:")
    print(f"  Valid mutations: {num_valid}")
    print(f"  Invalid mutations: {num_invalid}")

    return df


# =============================================================================
# STABILITY CLASSIFICATION
# =============================================================================

def classify_stability(df: pd.DataFrame, threshold: float = 0.0) -> pd.DataFrame:
    """
    Classify mutations based on their effect on protein stability.

    Classification scheme (binary, nullable):
    - Class 0 (Stabilizing): ddG < threshold
    - Class 1 (Destabilizing): ddG >= threshold
    - <NA>: ddG missing or not numeric (empty cells, placeholder strings)

    Sign convention: this labelling assumes that a negative ddG means the
    mutation stabilizes the protein.
    NOTE(review): confirm the sign convention of 'ddG_ML' in the source data
    before training. If dG is reported as unfolding free energy (higher = more
    stable), as we understand to be the case for the Tsuboyama et al. (2023)
    mega-scale K50 data, a positive ddG indicates stabilization and these
    labels would be inverted.

    Args:
        df: DataFrame with a 'ddG_ML' column
        threshold: ddG cut-off separating the two classes (default: 0.0)

    Returns:
        pd.DataFrame: Copy of ``df`` with added columns
        - 'ddG': numeric copy of 'ddG_ML' (non-numeric entries become NaN)
        - 'stability_class': nullable Int64 label (0, 1, or <NA>)

    Example:
        >>> df_classified = classify_stability(df, threshold=0.0)
        Stability classification (threshold=0.0):
          Class 0 (Stabilizing, ddG < 0.0): 187234
          Class 1 (Destabilizing, ddG >= 0.0): 188326
    """
    df = df.copy()

    # Coerce to float. Without this, NaN >= threshold evaluates to False and a
    # missing value would be silently labelled class 0 (stabilizing), while a
    # non-numeric placeholder string would make the comparison raise TypeError.
    df['ddG'] = pd.to_numeric(df['ddG_ML'], errors='coerce')
    missing = df['ddG'].isna()

    # 1 if ddG >= threshold, else 0; rows without a usable ddG get <NA> instead
    # of a fabricated label (nullable Int64 dtype allows this)
    df['stability_class'] = (df['ddG'] >= threshold).astype('Int64').mask(missing)

    # Report class distribution for quality control (value_counts drops <NA>)
    class_counts = df['stability_class'].value_counts().sort_index()
    print(f"Stability classification (threshold={threshold}):")
    print(f"  Class 0 (Stabilizing, ddG < {threshold}): {class_counts.get(0, 0)}")
    print(f"  Class 1 (Destabilizing, ddG >= {threshold}): {class_counts.get(1, 0)}")
    if missing.any():
        print(f"  Unlabelled (missing/non-numeric ddG): {int(missing.sum())}")

    return df


# =============================================================================
# FINAL DATA PREPARATION
# =============================================================================

def prepare_final_dataset(df: pd.DataFrame) -> pd.DataFrame:
    """
    Rename processed columns to their final names and fix the column order.

    Steps:
    1. Map intermediate column names onto short, consistent final names
    2. Confirm every expected output column is present
    3. Return only those columns, in a fixed and logically grouped order

    Naming scheme:
    - wt_* : wild-type related
    - mut_* : mutant related
    - *_seq : sequence data
    - *_res : single residue (amino acid)

    Args:
        df: DataFrame holding every processed column

    Returns:
        pd.DataFrame: Frame with standardized names and ordering; the input
        is left untouched

    Raises:
        ValueError: If any required output column is missing after renaming

    Example:
        >>> df_final = prepare_final_dataset(df_classified)
        Final dataset shape: (375560, 12)
    """
    rename_pairs = [
        ('aa_seq', 'mut_seq'),             # explicitly the MUTANT sequence
        ('WT_sequence', 'wt_seq'),
        ('WT_cluster', 'wt_cluster'),
        ('WT_name', 'wt_name'),
        ('mutation_position', 'pos'),
        ('mutated_residue', 'mut_res'),
        ('wt_residue', 'wt_res'),
    ]
    renamed = df.copy().rename(columns=dict(rename_pairs))

    # Output layout: identifiers -> sequences -> mutation details -> labels/flags
    identifier_cols = ['name', 'wt_name', 'wt_cluster']
    sequence_cols = ['wt_seq', 'mut_seq']
    mutation_cols = ['mut_type', 'wt_res', 'pos', 'mut_res']
    label_cols = ['ddG', 'stability_class', 'is_valid_mutation']
    ordered_columns = identifier_cols + sequence_cols + mutation_cols + label_cols

    missing_cols = set(ordered_columns) - set(renamed.columns)
    if missing_cols:
        raise ValueError(f"Missing required columns: {missing_cols}")

    final = renamed[ordered_columns]

    print(f"Final dataset shape: {final.shape}")
    print(f"Final columns: {final.columns.tolist()}")

    return final


def export_dataset(df: pd.DataFrame, output_path: Path) -> None:
    """
    Write ``df`` to ``output_path`` as CSV, creating parent folders on demand.

    The row index is not written. Afterwards a short confirmation with the
    destination path and the record count is printed.

    Args:
        df: DataFrame to export
        output_path: Destination CSV file; missing parent directories are created

    Example:
        >>> export_dataset(df_final, Path('k50_cleaned.csv'))
        Dataset exported to: k50_cleaned.csv
        Total records: 375560
    """
    # Ensure the whole directory chain exists; an existing directory is fine
    target_dir = output_path.parent
    target_dir.mkdir(parents=True, exist_ok=True)

    df.to_csv(output_path, index=False)

    for message in (f"Dataset exported to: {output_path}", f"Total records: {len(df)}"):
        print(message)


# =============================================================================
# SUMMARY STATISTICS
# =============================================================================

def print_summary_statistics(df: pd.DataFrame) -> None:
    """
    Print an overview of the final dataset to stdout.

    Sections, in order:
    - record count plus number of distinct 'wt_name' and 'wt_cluster' values
    - counts per 'stability_class'
    - ``describe()`` of 'ddG'
    - min / max / mean of the mutation position 'pos'
    - min / max / mean length of 'mut_seq'

    Useful for quality control, spotting anomalies and documenting the data.

    Args:
        df: Final processed DataFrame (uses wt_name, wt_cluster,
            stability_class, ddG, pos and mut_seq)

    Example:
        >>> print_summary_statistics(df_final)
        ============================================================
        DATASET SUMMARY STATISTICS
        ============================================================

        Total mutations: 375560
        Unique proteins (WT): 145
        ...
    """
    rule = "=" * 60

    def report_spread(values: pd.Series) -> None:
        # Shared three-line layout for positions and sequence lengths
        print(f"  Min: {values.min()}")
        print(f"  Max: {values.max()}")
        print(f"  Mean: {values.mean():.2f}")

    print(rule)
    print("DATASET SUMMARY STATISTICS")
    print(rule)

    print(f"\nTotal mutations: {len(df)}")
    print(f"Unique proteins (WT): {df['wt_name'].nunique()}")
    print(f"Unique clusters: {df['wt_cluster'].nunique()}")

    print("\nStability distribution:")
    print(df['stability_class'].value_counts().sort_index())

    print("\nddG statistics:")
    print(df['ddG'].describe())

    print("\nMutation position range:")
    report_spread(df['pos'])

    print("\nSequence length statistics:")
    report_spread(df['mut_seq'].str.len())

    print("\n" + rule)


# =============================================================================
# MAIN EXECUTION
# =============================================================================

def main():
    """
    Run the full K50 processing pipeline end to end.

    Stages, in order: load raw CSV -> select features -> keep single-point
    mutations -> parse annotations -> rebuild wild-type sequences -> validate
    -> classify stability -> finalize columns -> export CSV -> print summary.
    Reads INPUT_PATH and writes OUTPUT_PATH; each stage is announced on stdout
    and followed by a blank line.
    """
    print("Starting K50 data processing pipeline...\n")

    # Transforming stages: each receives the previous frame and returns a new one
    # (stage 1 ignores its argument and loads from disk instead)
    transforms = [
        ("Step 1: Loading raw dataset...", lambda _: load_k50_dataset(INPUT_PATH)),
        ("Step 2: Selecting features...", lambda frame: select_features(frame, FEATURES_TO_KEEP)),
        ("Step 3: Filtering for single-point mutations...", filter_single_point_mutations),
        ("Step 4: Parsing mutation annotations...", annotate_mutations),
        ("Step 5: Reconstructing wild-type sequences...", add_wt_sequences),
        ("Step 6: Validating mutations...", add_validation_flags),
        ("Step 7: Classifying stability effects...",
         lambda frame: classify_stability(frame, threshold=STABILITY_THRESHOLD)),
        ("Step 8: Preparing final dataset...", prepare_final_dataset),
    ]

    current = None
    for banner, stage in transforms:
        print(banner)
        current = stage(current)
        print()

    # Terminal stages consume the final frame without replacing it
    print("Step 9: Exporting cleaned dataset...")
    export_dataset(current, OUTPUT_PATH)
    print()

    print("Step 10: Generating summary statistics...")
    print_summary_statistics(current)
    print()

    print("Pipeline completed successfully!")


# Run the pipeline only when executed as a script, not when imported
if __name__ == "__main__":
    main()