In [None]:
# ============================================================================
# CELL 1: CONFIGURATION
# ============================================================================
# This cell contains all the settings a user needs to change.

# --- PATHS ---
# TODO: Update these paths to match your file locations.
# Relative paths (from the notebook's location) are preferred.
INPUT_FILE_PATH = "data/All_TM2.xlsx"  # e.g. "data/my_data.xlsx"
OUTPUT_DIR = "output/MissingDataAnalysis"  # e.g. "results/"

# --- PARAMETERS ---
# Placeholder value that stands for "missing" in the input dataset.
MISSING_VALUE_DEFINITION = 0.001

# Filtering thresholds, keyed by strategy name. Each value is a detection
# rate in percent; the filtering step keeps a metabolite when its missing
# rate is at most (100 - value) in the treatment OR the control samples.
DETECTION_THRESHOLDS = dict(strict=80, moderate=50, lenient=20)

# --- METADATA ---
# Sample column names expected in the input file. In this dataset they follow
# a regular "<prefix><i>_<j>" pattern with i and j each running from 1 to 5
# (25 names per condition). Replace the expressions below with explicit lists
# if your column names do not follow such a pattern.
TREATMENT_SAMPLES = [
    f"TM2A{i}_{j}" for i in range(1, 6) for j in range(1, 6)
]
CONTROL_SAMPLES = [
    f"TM2An{i}_{j}" for i in range(1, 6) for j in range(1, 6)
]

# Display labels and the title used for console output and figures.
TREATMENT_LABEL = "GFP+ (CFPS)"
CONTROL_LABEL = "Negative Control"
EXPERIMENT_NAME = "CFPS Metabolomics - GFP vs Control"

In [None]:
# ============================================================================
# CELL 2: SCRIPT LOGIC (Functions)
# ============================================================================
# This cell contains the core logic of the script.
# A user typically does not need to edit this cell.

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
import os

warnings.filterwarnings('ignore')

# --- HELPER FUNCTIONS ---

def clean_data_matrix(data_matrix):
    """Drop columns whose values are mostly non-numeric.

    A value counts as non-numeric when it is not null in the input but turns
    into NaN under ``pd.to_numeric(..., errors='coerce')``. A column is
    dropped when more than half of the rows hold such values. Nothing is
    converted here: the columns that are kept retain their original contents.

    Args:
        data_matrix: DataFrame to screen column by column.

    Returns:
        ``data_matrix`` itself when no column is dropped; otherwise a
        selection holding only the kept columns, in their original order.
    """
    max_unparseable = len(data_matrix) * 0.5
    kept, dropped = [], []
    for name in data_matrix.columns:
        column = data_matrix[name]
        coerced = pd.to_numeric(column, errors='coerce')
        # Nulls that were already in the input are not held against the column.
        unparseable = coerced.isnull().sum() - column.isnull().sum()
        if unparseable > max_unparseable:
            dropped.append(name)
        else:
            kept.append(name)

    if not dropped:
        return data_matrix

    print(f"Excluding non-numeric columns: {', '.join(dropped)}")
    return data_matrix[kept]

def convert_missing_values(data_matrix, missing_value):
    """Return a numeric copy of the data with the missing-value marker as NaN.

    Every column of a copy of ``data_matrix`` is coerced with
    ``pd.to_numeric(..., errors='coerce')`` (unparseable entries become NaN).
    Cells matching ``missing_value`` are then masked to NaN as well. The
    input frame is not modified.

    Args:
        data_matrix: DataFrame of raw measurements.
        missing_value: Marker for a missing measurement. ``int``/``float``
            markers are matched with ``np.isclose`` (tolerant of float
            rounding); any other type is matched with ``==``.

    Returns:
        A new DataFrame with numeric columns and NaN in place of the marker.
    """
    numeric_frame = data_matrix.copy()
    for column_name in numeric_frame.columns:
        coerced = pd.to_numeric(numeric_frame[column_name], errors='coerce')
        numeric_frame[column_name] = coerced

    if not isinstance(missing_value, (int, float)):
        marker_hits = numeric_frame == missing_value
    else:
        marker_hits = np.isclose(numeric_frame, missing_value)
    return numeric_frame.mask(marker_hits)

def analyze_sample_classification(sample_names, treatment_samples, control_samples):
    """Match the configured sample names against those present in the data.

    Args:
        sample_names: Names actually available (the data matrix columns).
        treatment_samples: Configured treatment sample names.
        control_samples: Configured control sample names.

    Returns:
        ``None`` (after printing an error) when either group has no match.
        Otherwise a 4-tuple ``(found_treatment, found_control, all_samples,
        conditions)`` where ``all_samples`` lists the treatment matches
        followed by the control matches, and ``conditions`` holds the label
        ``'Treatment'`` or ``'Control'`` for each entry of ``all_samples``.
        Matches keep the order of the configured lists.
    """
    def _present(configured):
        return [name for name in configured if name in sample_names]

    treatment_hits = _present(treatment_samples)
    control_hits = _present(control_samples)

    if not treatment_hits or not control_hits:
        print("ERROR: Could not find both treatment and control samples in the data file!")
        return None

    ordered_samples = [*treatment_hits, *control_hits]
    labels = ['Treatment' for _ in treatment_hits] + ['Control' for _ in control_hits]
    return treatment_hits, control_hits, ordered_samples, labels

def calculate_missing_statistics(data_matrix, treatment_indices, control_indices):
    """Compute missing-value percentages per row, per condition and overall.

    Args:
        data_matrix: DataFrame in which missing measurements are null/NaN.
            Rows are treated as metabolites and columns as samples.
        treatment_indices: Positional column indices of the treatment samples.
        control_indices: Positional column indices of the control samples.

    Returns:
        A 4-tuple ``(missing_per_metabolite, missing_treatment,
        missing_control, overall_missing)``: three per-row Series giving the
        percentage of null cells across all columns, the treatment columns
        and the control columns respectively, followed by a scalar with the
        percentage of null cells in the whole matrix.
    """
    def _row_missing_percent(frame):
        # Share of null cells in each row, expressed in percent.
        return (frame.isnull().sum(axis=1) / frame.shape[1]) * 100

    per_metabolite = _row_missing_percent(data_matrix)
    per_treatment = _row_missing_percent(data_matrix.iloc[:, treatment_indices])
    per_control = _row_missing_percent(data_matrix.iloc[:, control_indices])

    null_cells = data_matrix.isnull().sum().sum()
    overall = (null_cells / data_matrix.size) * 100
    return per_metabolite, per_treatment, per_control, overall

def apply_multiple_threshold_filtering(missing_treatment, missing_control, data_matrix,
                                       detection_thresholds=None):
    """Build a keep-mask for every detection threshold.

    The filter is condition-aware: a metabolite is kept when its missing
    percentage is at most ``100 - detection_threshold`` in the treatment
    samples OR in the control samples, so a metabolite detected in only one
    condition is not discarded.

    Args:
        missing_treatment: Per-metabolite missing percentage (0-100) across
            the treatment samples.
        missing_control: Per-metabolite missing percentage (0-100) across the
            control samples.
        data_matrix: The full data matrix. It is stored by reference in each
            result entry and is NOT filtered here; apply ``keep_mask`` to it.
        detection_thresholds: Optional mapping of strategy name to required
            detection rate in percent. Defaults to the module-level
            ``DETECTION_THRESHOLDS`` so existing callers are unaffected.

    Returns:
        Dict keyed by strategy name. Each value is a dict with
        ``'data_matrix'`` (the unfiltered matrix), ``'keep_mask'`` (boolean
        per-metabolite mask) and ``'stats'`` (``'total_metabolites'``: number
        of metabolites kept, ``'detection_threshold'``: the threshold used).
    """
    if detection_thresholds is None:
        # Fall back to the notebook-level configuration.
        detection_thresholds = DETECTION_THRESHOLDS

    filtered_datasets = {}
    for threshold_name, detection_threshold in detection_thresholds.items():
        # Requiring X% detection is the same as tolerating (100 - X)% missing.
        missing_threshold = 100 - detection_threshold
        treatment_pass = missing_treatment <= missing_threshold
        control_pass = missing_control <= missing_threshold
        condition_aware_keep = treatment_pass | control_pass
        filtered_datasets[threshold_name] = {
            'data_matrix': data_matrix,
            'keep_mask': condition_aware_keep,
            'stats': {
                'total_metabolites': condition_aware_keep.sum(),
                'detection_threshold': detection_threshold,
            },
        }
    return filtered_datasets

def create_visualizations(missing_per_metabolite, sample_info, filtered_datasets, data_matrix, 
                          treatment_indices, control_indices, output_dir):
    """Draw the 2x3 missing-data overview figure, save it as PNG and show it.

    Panels, left to right and top to bottom: histogram of missing percentage
    per metabolite; box plot of per-sample missing percentage by condition;
    bar chart of metabolites retained per filtering strategy; histogram of
    missing percentage per sample; metabolites retained versus detection
    threshold; scatter of missing percentage per sample, coloured by
    condition.

    Args:
        missing_per_metabolite: Per-metabolite missing percentages.
        sample_info: DataFrame with ``'Condition'`` and ``'Missing_Percent'``
            columns, one row per sample.
        filtered_datasets: Dict keyed by strategy name whose entries provide
            ``['stats']['total_metabolites']``; every key of
            ``DETECTION_THRESHOLDS`` must be present.
        data_matrix: Data matrix with null/NaN for missing values.
        treatment_indices: Positional column indices of treatment samples.
        control_indices: Positional column indices of control samples.
        output_dir: Directory that receives ``missing_data_analysis.png``.
    """
    print("Creating visualizations...")
    figure, panel_grid = plt.subplots(2, 3, figsize=(18, 12), constrained_layout=True)
    figure.suptitle(EXPERIMENT_NAME, fontsize=16, weight='bold')
    (ax_metab_hist, ax_cond_box, ax_strategy_bar), \
        (ax_sample_hist, ax_curve, ax_sample_scatter) = panel_grid

    # Panel 1: how missingness is distributed over metabolites.
    ax_metab_hist.hist(missing_per_metabolite, bins=30, alpha=0.7, color='skyblue', edgecolor='black')
    ax_metab_hist.set(
        title='Distribution of Missing Values per Metabolite',
        xlabel='Missing Percentage (%)',
        ylabel='Number of Metabolites',
    )

    # Panel 2: per-sample missingness split by condition.
    sns.boxplot(x='Condition', y='Missing_Percent', data=sample_info, ax=ax_cond_box,
                palette=['red', 'blue'])
    ax_cond_box.set(
        title='Missing Values by Sample Condition',
        xlabel='Condition',
        ylabel='Missing Percentage (%)',
    )

    # Panel 3: metabolites retained by each configured strategy.
    strategy_names = list(DETECTION_THRESHOLDS.keys())
    retained_counts = [
        filtered_datasets[strategy]['stats']['total_metabolites'] for strategy in strategy_names
    ]
    ax_strategy_bar.bar(strategy_names, retained_counts, alpha=0.7, color='steelblue')
    ax_strategy_bar.set(
        xlabel='Filtering Strategy',
        ylabel='Number of Metabolites Retained',
        title='Metabolites Retained by Strategy',
    )
    for bar_position, retained in enumerate(retained_counts):
        # Print the count just above each bar.
        ax_strategy_bar.text(bar_position, retained, str(retained), ha='center', va='bottom',
                             fontsize=10, weight='bold')

    # Panel 4: how missingness is distributed over samples.
    ax_sample_hist.hist(sample_info['Missing_Percent'], bins=20, alpha=0.7, color='lightcoral',
                        edgecolor='black')
    ax_sample_hist.set(
        title='Distribution of Missing Values per Sample',
        xlabel='Missing Percentage (%)',
        ylabel='Number of Samples',
    )
    ax_sample_hist.grid(True, alpha=0.3)

    # Panel 5: retained metabolites as a function of the detection threshold,
    # recomputed here from the data matrix with the condition-aware rule
    # (pass in treatment OR control).
    detection_levels = sorted(list(DETECTION_THRESHOLDS.values()))
    treatment_block = data_matrix.iloc[:, treatment_indices]
    control_block = data_matrix.iloc[:, control_indices]
    treatment_missing = (treatment_block.isnull().sum(axis=1) / treatment_block.shape[1]) * 100
    control_missing = (control_block.isnull().sum(axis=1) / control_block.shape[1]) * 100
    retained_per_level = [
        ((treatment_missing <= 100 - level) | (control_missing <= 100 - level)).sum()
        for level in detection_levels
    ]
    ax_curve.plot(detection_levels, retained_per_level, 'o-', linewidth=2, markersize=6,
                  color='steelblue')
    ax_curve.set(
        xlabel='Detection Threshold (%)',
        ylabel='Metabolites Retained',
        title='Metabolites Retained vs Detection Threshold',
    )
    ax_curve.grid(True, alpha=0.3)

    # Panel 6: one point per sample; red = Treatment, blue = anything else.
    point_colors = [
        'red' if condition == 'Treatment' else 'blue' for condition in sample_info['Condition']
    ]
    ax_sample_scatter.scatter(range(len(sample_info)), sample_info['Missing_Percent'],
                              c=point_colors, alpha=0.7)
    ax_sample_scatter.set(
        xlabel='Sample Index',
        ylabel='Missing Percentage (%)',
        title='Missing Values per Sample',
    )

    png_path = os.path.join(output_dir, 'missing_data_analysis.png')
    plt.savefig(png_path, dpi=300, bbox_inches='tight')
    plt.show()
    print("Saved: missing_data_analysis.png")


def save_results(filtered_datasets, output_dir):
    """Write one Excel file per filtering strategy plus a comparison table.

    For every entry of ``filtered_datasets`` the rows selected by its
    ``'keep_mask'`` are written to
    ``metabolites_for_imputation_<name>_threshold.xlsx`` with the index turned
    into a regular column. A summary of all strategies is written to
    ``threshold_comparison.xlsx``.

    The summary is derived from ``filtered_datasets`` itself (not from the
    module-level threshold configuration), so it always describes exactly the
    files written, even if the configuration changed after the datasets were
    built.

    Args:
        filtered_datasets: Dict keyed by strategy name. Each value must hold
            ``'data_matrix'``, ``'keep_mask'`` and ``'stats'`` with the keys
            ``'total_metabolites'`` and ``'detection_threshold'``.
        output_dir: Target directory; created if it does not exist.
    """
    print("\nSAVING RESULTS:")
    print("=" * 50)

    # Allow the function to be used on its own, not only after main() has
    # created the directory.
    os.makedirs(output_dir, exist_ok=True)

    for threshold_name, dataset_info in filtered_datasets.items():
        filename = f'metabolites_for_imputation_{threshold_name}_threshold.xlsx'
        filepath = os.path.join(output_dir, filename)
        keep_mask = dataset_info['keep_mask']
        # reset_index() moves the metabolite names from the index into a
        # column so they survive index=False below.
        output_data = dataset_info['data_matrix'].loc[keep_mask].reset_index()
        output_data.to_excel(filepath, index=False)
        print(f"Saved: {filename} ({keep_mask.sum()} metabolites)")

    threshold_names = list(filtered_datasets)
    threshold_comparison = pd.DataFrame({
        'Threshold_Name': threshold_names,
        'Detection_Percent_Required': [
            filtered_datasets[name]['stats']['detection_threshold'] for name in threshold_names
        ],
        'Total_Metabolites_Kept': [
            filtered_datasets[name]['stats']['total_metabolites'] for name in threshold_names
        ],
    })
    comparison_filepath = os.path.join(output_dir, 'threshold_comparison.xlsx')
    threshold_comparison.to_excel(comparison_filepath, index=False)
    print("Saved: threshold_comparison.xlsx")

# --- MAIN ANALYSIS FUNCTION ---

def main(input_file, output_dir):
    """Run the missing-data analysis end to end.

    Steps: read the Excel input (first column = metabolite names, remaining
    columns = samples), drop mostly non-numeric columns, convert the
    configured missing-value marker to NaN, locate the configured treatment
    and control samples, compute missing-value statistics, apply every
    detection threshold, print a summary, then save the Excel outputs and the
    overview figure.

    Problems with the input (unreadable file, fewer than two columns, no
    treatment or no control samples found) are reported on stdout and end the
    run early; nothing is raised and the function returns ``None``.

    Args:
        input_file: Path of the Excel file to analyse.
        output_dir: Directory for all outputs; created if missing.
    """
    banner = "=" * 60
    print(banner)
    print(f"  {EXPERIMENT_NAME}")
    print(banner)
    print(f"Input file: {input_file}")
    print(f"Output directory: {output_dir}")
    
    os.makedirs(output_dir, exist_ok=True)
    
    try:
        raw_table = pd.read_excel(input_file)
    except Exception as read_error:
        print(f"ERROR: Could not read the input file. Details: {read_error}")
        return
    
    if raw_table.shape[1] < 2:
        print("ERROR: Input data must have at least two columns (Metabolite Name and one sample).")
        return

    # The first column carries the metabolite names; the rest is sample data.
    data_matrix = raw_table.iloc[:, 1:]
    data_matrix.index = raw_table.iloc[:, 0]
    data_matrix = clean_data_matrix(data_matrix)
    data_matrix = convert_missing_values(data_matrix, MISSING_VALUE_DEFINITION)

    classification = analyze_sample_classification(
        data_matrix.columns.tolist(), TREATMENT_SAMPLES, CONTROL_SAMPLES
    )
    if classification is None:
        return
    treatment_samples, control_samples, all_samples, conditions = classification

    # The statistics helpers select columns by position, not by name.
    column_position = data_matrix.columns.get_loc
    treatment_indices = [column_position(name) for name in treatment_samples]
    control_indices = [column_position(name) for name in control_samples]

    (missing_per_metabolite, missing_treatment,
     missing_control, overall_missing) = calculate_missing_statistics(
        data_matrix, treatment_indices, control_indices
    )
    filtered_datasets = apply_multiple_threshold_filtering(
        missing_treatment, missing_control, data_matrix
    )

    # Percentage of metabolites missing in each classified sample.
    null_per_sample = data_matrix[all_samples].isnull().sum()
    sample_info = pd.DataFrame({
        'Sample': all_samples,
        'Condition': conditions,
        'Missing_Percent': (null_per_sample / len(data_matrix)) * 100
    })

    print("\nEXPERIMENT SUMMARY:")
    print(f"  Total initial metabolites: {len(data_matrix)}")
    print(f"  Treatment samples found: {len(treatment_samples)}")
    print(f"  Control samples found: {len(control_samples)}")
    print(f"  Overall missing data: {overall_missing:.2f}%")
    print("\nTHRESHOLD-BASED FILTERING RESULTS:")
    for threshold_name, dataset_info in filtered_datasets.items():
        retained = dataset_info['stats']['total_metabolites']
        print(f"  {threshold_name.upper()} ({DETECTION_THRESHOLDS[threshold_name]}% detection):")
        print(f"     - Metabolites retained: {retained}")
    
    save_results(filtered_datasets, output_dir)
    create_visualizations(missing_per_metabolite, sample_info, filtered_datasets, 
                          data_matrix, treatment_indices, control_indices, output_dir)
    print("\nANALYSIS COMPLETE!")
    print(f"All output files have been saved to '{output_dir}'.")