In [3]:
import json
import os
import warnings
from functools import lru_cache

import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from rdkit import Chem
from rdkit.Chem import Descriptors
from rdkit.ML.Descriptors import MoleculeDescriptors
from sklearn.preprocessing import StandardScaler
# Silence every warning for the whole run (presumably to hide noisy
# library warnings from RDKit / scikit-learn / matplotlib).
warnings.simplefilter('ignore')

# Figures use the SimHei font (a CJK-capable font), and draw the minus sign
# as a plain hyphen so that font can render it.
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})

def load_and_prepare_data(path='data/5-predict/predict.xlsx'):
    """Load the API/CCF pairs to be scored from an Excel workbook.

    Parameters
    ----------
    path : str, optional
        Workbook to read. Defaults to the project's prediction sheet, so
        existing zero-argument calls behave as before.

    Returns
    -------
    pandas.DataFrame or None
        The first sheet of the workbook, or ``None`` if it could not be read
        (the error is printed instead of raised).
    """
    try:
        df = pd.read_excel(path)
    except Exception as e:  # missing file, corrupt workbook, missing engine, ...
        print(f"error: {str(e)}")
        return None
    print("success")
    return df

@lru_cache(maxsize=1)
def _descriptor_calculator():
    """Build the RDKit descriptor calculator once and reuse it.

    Returns a ``(descriptor_names, calculator)`` pair covering every
    descriptor listed in ``Descriptors._descList``. That list does not change
    during a run, so rebuilding it for every molecule (twice per pair) was
    wasted work. The names are returned as a tuple so the cached value
    cannot be mutated by a caller.
    """
    descriptor_names = [name for name, _ in Descriptors._descList]
    calculator = MoleculeDescriptors.MolecularDescriptorCalculator(descriptor_names)
    return tuple(descriptor_names), calculator


def get_all_descriptors(mol):
    """Compute every RDKit descriptor for a single molecule.

    Parameters
    ----------
    mol : rdkit.Chem.Mol or None
        Parsed molecule. ``None`` (e.g. a failed SMILES parse) is passed
        straight through.

    Returns
    -------
    dict or None
        ``{descriptor_name: value}`` in ``Descriptors._descList`` order, or
        ``None`` if *mol* is ``None`` or the calculation raised.
    """
    if mol is None:
        return None

    descriptor_names, calculator = _descriptor_calculator()
    try:
        values = calculator.CalcDescriptors(mol)
    except Exception:
        # Best effort: the caller treats None as "skip this pair". Unlike the
        # previous bare except, this no longer swallows KeyboardInterrupt.
        return None
    return dict(zip(descriptor_names, values))

def process_smiles_pair(smiles1, smiles2):
    """Build the combined descriptor row for one API/CCF SMILES pair.

    For every descriptor ``name`` the row holds four entries, in this order:
    ``mol1_<name>`` and ``mol2_<name>`` (the raw values), ``avg_<name>``
    (their mean) and ``diff_<name>`` (their absolute difference).

    Returns ``None`` when either SMILES is missing, fails to parse, or has no
    descriptors.
    """
    if pd.isna(smiles1) or pd.isna(smiles2):
        return None

    # Parse both molecules first, then reject the pair if either failed.
    mols = [Chem.MolFromSmiles(text.strip()) for text in (smiles1, smiles2)]
    if any(m is None for m in mols):
        return None

    descs = [get_all_descriptors(m) for m in mols]
    if any(d is None for d in descs):
        return None

    first, second = descs
    features = {}
    for name, a in first.items():
        b = second[name]
        features[f'mol1_{name}'] = a
        features[f'mol2_{name}'] = b
        features[f'avg_{name}'] = (a + b) / 2
        features[f'diff_{name}'] = abs(a - b)
    return features

def calculate_descriptors(df):
    """Compute paired RDKit descriptors for every API/CCF row of *df*.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``SMILE1``, ``SMILE2``, ``API`` and ``CCF`` columns.
        It is not modified.

    Returns
    -------
    (pandas.DataFrame, pandas.DataFrame)
        ``df_descriptors`` has one row of combined descriptors per valid
        pair. ``df_valid`` holds the matching input rows (SMILES stripped).
        Both frames share the same order and index, so row *i* of one
        describes row *i* of the other. Previously the full input frame was
        returned, which misaligned predictions with rows whenever a pair
        before the end was invalid.
    """
    df = df.copy()
    for col in ('SMILE1', 'SMILE2'):
        # Strip whitespace but keep missing cells as NaN, so that
        # process_smiles_pair's pd.isna guard can reject them (astype(str)
        # alone would turn them into the literal string 'nan').
        df[col] = df[col].where(df[col].isna(), df[col].astype(str).str.strip())

    descriptors_list = []
    valid_positions = []
    invalid_pairs = []

    for pos, (idx, row) in enumerate(df.iterrows()):
        desc = process_smiles_pair(row['SMILE1'], row['SMILE2'])
        if desc is None:
            invalid_pairs.append((idx, row['API'], row['CCF']))
        else:
            descriptors_list.append(desc)
            valid_positions.append(pos)

    if invalid_pairs:
        print(f"warning:  {len(invalid_pairs)} error")
        for idx, api, ccf in invalid_pairs[:5]:
            print(f"   {idx}: API={api}, CCF={ccf}")

    # Select by position, so duplicate index labels cannot pull in extra rows.
    df_valid = df.iloc[valid_positions]
    df_descriptors = pd.DataFrame(descriptors_list, index=df_valid.index)
    return df_descriptors, df_valid

def load_kbest_xgboost_model(model_path="Predict/kbest_fs"):
    """Load the trained XGBoost model, its scaler and its saved parameters.

    Parameters
    ----------
    model_path : str, optional
        Directory holding ``XGBoost_model.joblib``, ``XGBoost_scaler.joblib``
        and, optionally, ``XGBoost_best_params.txt``. Defaults to the
        original hard-coded location.

    Returns
    -------
    dict or None
        ``{'model', 'scaler', 'best_params'}``. ``best_params`` is the raw
        text of the params file, or the string ``"error"`` when that file is
        absent. Returns ``None`` (after printing) if a required artifact is
        missing or fails to load.
    """
    try:
        model_file = os.path.join(model_path, 'XGBoost_model.joblib')
        scaler_file = os.path.join(model_path, 'XGBoost_scaler.joblib')

        # Check both required artifacts before unpickling anything, so a
        # missing scaler does not cost a full model load first.
        for required in (model_file, scaler_file):
            if not os.path.exists(required):
                print(f"error: {required}")
                return None

        # SECURITY: joblib.load unpickles arbitrary objects. Only point
        # model_path at trusted, locally produced model files.
        model = joblib.load(model_file)
        scaler = joblib.load(scaler_file)

        params_file = os.path.join(model_path, 'XGBoost_best_params.txt')
        if os.path.exists(params_file):
            with open(params_file, 'r') as f:
                best_params = f.read()
            print("success")
        else:
            best_params = "error"
            print("error")

        return {
            'model': model,
            'scaler': scaler,
            'best_params': best_params,
        }

    except Exception as e:
        print(f"error {str(e)}")
        return None

def prepare_kbest_features(df_descriptors,
                           kbest_info_path='Predict/feature/Kbest/feature_info.json'):
    """Select the K-best feature columns recorded in a feature-info JSON file.

    Parameters
    ----------
    df_descriptors : pandas.DataFrame
        Combined descriptor table (one column per feature).
    kbest_info_path : str, optional
        JSON file with a ``k_value`` entry and a
        ``feature_scores.features`` list. The first ``k_value`` names are
        selected. Defaults to the original hard-coded location.

    Returns
    -------
    pandas.DataFrame or None
        The selected columns that exist in *df_descriptors*, in the file's
        order. Returns ``None`` (after printing) when the file is missing or
        unreadable, or when none of the selected features are present.
        Missing features are reported but skipped.
    """
    try:
        if not os.path.exists(kbest_info_path):
            print(f"error: {kbest_info_path}")
            return None

        with open(kbest_info_path, 'r', encoding='utf-8') as fh:
            kbest_info = json.load(fh)

        k_value = kbest_info['k_value']
        selected_features = kbest_info['feature_scores']['features'][:k_value]

        present = set(df_descriptors.columns)
        available_features = [name for name in selected_features if name in present]
        missing_features = [name for name in selected_features if name not in present]

        if missing_features:
            # NOTE(review): the scaler was presumably fitted on all k columns,
            # so a frame with fewer columns will likely be rejected at
            # transform time. Confirm whether skipping is really wanted.
            print(f"error{len(missing_features)} ")
            for name in missing_features[:5]:
                print(f"  - {name}")

        if not available_features:
            raise ValueError("error")

        return df_descriptors[available_features]

    except Exception as e:
        print(f"error {str(e)}")
        return None

def predict_with_kbest_xgboost(df_features, model_data):
    """Scale the features and run the classifier on them.

    Parameters
    ----------
    df_features : array-like
        Feature matrix accepted by ``model_data['scaler'].transform``.
    model_data : dict
        Needs a ``'scaler'`` (with ``transform``) and a ``'model'`` (with
        ``predict_proba`` and ``predict``).

    Returns
    -------
    (probabilities, predictions) or (None, None)
        Column 1 of ``predict_proba`` plus the predicted labels, or a pair of
        ``None`` values if any step raised (the error is printed).
    """
    try:
        features = model_data['scaler'].transform(df_features)
        classifier = model_data['model']
        positive_proba = classifier.predict_proba(features)[:, 1]
        labels = classifier.predict(features)
    except Exception as exc:
        print(f"error: {str(exc)}")
        return None, None
    else:
        print("success")
        return positive_proba, labels

def create_prediction_results(df, probabilities, predictions):
    """Attach model outputs to the scored rows and rank them by probability.

    Parameters
    ----------
    df : pandas.DataFrame
        The rows that were scored, in the same order as *probabilities*. It
        is not modified.
    probabilities : array-like
        Positive-class probability per row (``predict_proba[:, 1]`` in this
        file).
    predictions : array-like
        Predicted class label per row.

    Returns
    -------
    pandas.DataFrame
        A copy of *df* with ``cocrystal_probability``,
        ``cocrystal_prediction`` and ``confidence`` columns. ``confidence``
        is ``'high'`` above 0.7, ``'medium'`` above 0.5, else ``'low'``.
        Rows are sorted by descending probability; the sort is stable, so
        ties keep their input order. The identity/prediction columns come
        first.
    """
    results_df = df.copy()

    if len(results_df) > len(probabilities):
        # Positional truncation is only correct when the unscored rows are
        # the trailing ones; callers should pass exactly the scored rows.
        print("error")
        results_df = results_df.iloc[:len(probabilities)].copy()

    results_df['cocrystal_probability'] = probabilities
    results_df['cocrystal_prediction'] = predictions

    proba = results_df['cocrystal_probability'].to_numpy()
    results_df['confidence'] = np.select(
        [proba > 0.7, proba > 0.5], ['high', 'medium'], default='low'
    )

    results_df = results_df.sort_values(
        'cocrystal_probability', ascending=False, kind='mergesort'
    )

    column_order = ['API', 'SMILE1', 'CCF', 'SMILE2', 'cocrystal_prediction',
                    'cocrystal_probability', 'confidence']
    # Tolerate inputs lacking an identity column instead of raising KeyError.
    leading_cols = [col for col in column_order if col in results_df.columns]
    other_cols = [col for col in results_df.columns if col not in column_order]
    return results_df[leading_cols + other_cols]

def visualize_cocrystal_predictions(results_df, save_path="cocrystal_predictions"):
    """Render the prediction figures as PNG files under *save_path*.

    Writes ``cocrystal_probability_ranking.png``,
    ``cocrystal_prediction_analysis.png`` and
    ``cocrystal_ranking_vs_probability.png``. It also writes
    ``cocrystal_heatmap.png`` when the top APIs/CCFs overlap. A heatmap
    failure is printed, not raised.

    Parameters
    ----------
    results_df : pandas.DataFrame
        Output of ``create_prediction_results``. Needs ``API``, ``CCF``,
        ``cocrystal_probability``, ``cocrystal_prediction`` and
        ``confidence``. It is assumed sorted by descending probability
        (the ranking plots take its row order as the rank).
    save_path : str, optional
        Output directory; created if needed.
    """
    os.makedirs(save_path, exist_ok=True)

    if len(results_df) == 0:
        # Nothing to draw; the trend fit below would fail on empty input.
        print("error: no predictions to visualize")
        return

    _plot_probability_ranking(results_df, save_path)
    _plot_prediction_analysis(results_df, save_path)
    _plot_ranking_vs_probability(results_df, save_path)

    try:
        _plot_heatmap(results_df, save_path)
    except Exception as e:
        plt.close()  # don't leak a half-drawn heatmap figure
        print(f"error: {str(e)}")


def _save_figure(save_path, filename):
    """Lay out, save (300 dpi, tight bbox) and close the current figure."""
    plt.tight_layout()
    plt.savefig(os.path.join(save_path, filename), dpi=300, bbox_inches='tight')
    plt.close()


def _plot_probability_ranking(results_df, save_path, max_rows=100):
    """Horizontal bar chart of the first *max_rows* pairs with their probabilities."""
    top_results = results_df.head(max_rows)
    display_count = len(top_results)
    probs = top_results['cocrystal_probability']

    plt.figure(figsize=(16, 12))

    # The green channel grows with probability, so likelier pairs are brighter.
    colors = [(0.2, 0.6 + 0.4 * p, 0.2) for p in probs]
    bars = plt.barh(range(display_count), probs,
                    color=colors, alpha=0.7, edgecolor='black')

    y_labels = [f"{api} + {ccf}" for api, ccf in zip(top_results['API'], top_results['CCF'])]
    plt.yticks(range(display_count), y_labels, fontsize=10)
    plt.xlabel('probability', fontsize=10)
    plt.title(f' {display_count} rank of API-CCF')
    plt.gca().invert_yaxis()

    for bar, prob in zip(bars, probs):
        plt.text(bar.get_width() + 0.01, bar.get_y() + bar.get_height() / 2,
                 f'{prob:.3f}', va='center', fontsize=9)

    plt.axvline(x=0.5, color='red', linestyle='--', alpha=0.7, label='validation (0.5)')
    plt.legend()
    _save_figure(save_path, 'cocrystal_probability_ranking.png')


def _plot_prediction_analysis(results_df, save_path):
    """2x2 panel: probability histogram, class counts, confidence pie, density."""
    probs = results_df['cocrystal_probability']
    plt.figure(figsize=(12, 8))

    plt.subplot(2, 2, 1)
    plt.hist(probs, bins=20, alpha=0.7, color='skyblue', edgecolor='black')
    plt.axvline(0.5, color='red', linestyle='--', linewidth=2, label='Decision Threshold (0.5)')
    plt.xlabel('probability')
    plt.ylabel('number of API-CCF')
    plt.title('Eutectic Formation Probability Distribution')
    plt.legend()

    # Predicted-class distribution. Reindex to both classes so the counts
    # always line up with the two fixed labels. Without this, plt.bar raises
    # when every pair is predicted as the same class.
    plt.subplot(2, 2, 2)
    class_counts = (results_df['cocrystal_prediction'].value_counts()
                    .reindex([0, 1], fill_value=0))
    bars = plt.bar(['false (0)', 'success (1)'], class_counts.values,
                   color=['lightcoral', 'lightgreen'], alpha=0.7, edgecolor='black')
    plt.ylabel('number of API-CCF')
    plt.title('Eutectic Formation Prediction Results')
    for bar, count in zip(bars, class_counts.values):
        plt.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.1,
                 f'{count}', ha='center', va='bottom')

    # Confidence pie. Colour by level name, not by frequency rank (as
    # value_counts order did), so 'high' is always green.
    plt.subplot(2, 2, 3)
    palette = {'high': 'lightgreen', 'medium': 'gold', 'low': 'lightcoral'}
    confidence_counts = (results_df['confidence'].value_counts()
                         .reindex(list(palette), fill_value=0))
    confidence_counts = confidence_counts[confidence_counts > 0]
    plt.pie(confidence_counts.values, labels=confidence_counts.index, autopct='%1.1f%%',
            colors=[palette[level] for level in confidence_counts.index])
    plt.title('Prediction Confidence Distribution')

    plt.subplot(2, 2, 4)
    sns.kdeplot(data=results_df, x='cocrystal_probability', fill=True, alpha=0.6)
    plt.axvline(0.5, color='red', linestyle='--', linewidth=2)
    plt.xlabel('cocrystal formation probability')
    plt.ylabel('density')
    plt.title('Probability density distribution of cocrystal formation')

    _save_figure(save_path, 'cocrystal_prediction_analysis.png')


def _plot_ranking_vs_probability(results_df, save_path):
    """Scatter of rank vs probability with a cubic trend line when determinable."""
    rankings = np.arange(1, len(results_df) + 1)
    probabilities = results_df['cocrystal_probability'].to_numpy()

    plt.figure(figsize=(12, 6))
    plt.scatter(rankings, probabilities, c=probabilities, cmap='RdYlGn',
                alpha=0.6, s=30, edgecolors='black', linewidth=0.5)

    # A cubic fit is only determined by four or more points.
    if len(rankings) >= 4:
        trend = np.poly1d(np.polyfit(rankings, probabilities, 3))
        plt.plot(rankings, trend(rankings), "r--", alpha=0.8, linewidth=2, label='Trend line')

    plt.axhline(y=0.5, color='red', linestyle='--', alpha=0.7, label='Decision threshold')
    plt.xlabel('rank')
    plt.ylabel('probability')
    plt.title('rank and probability of cocrystal formation')
    plt.colorbar(label='probability of cocrystal formation')
    plt.legend()
    plt.grid(True, alpha=0.3)

    _save_figure(save_path, 'cocrystal_ranking_vs_probability.png')


def _plot_heatmap(results_df, save_path, top_n=20):
    """Mean-probability heatmap over the *top_n* most frequent APIs and CCFs."""
    top_apis = results_df['API'].value_counts().head(top_n).index
    top_ccfs = results_df['CCF'].value_counts().head(top_n).index
    subset_df = results_df[results_df['API'].isin(top_apis) & results_df['CCF'].isin(top_ccfs)]
    if len(subset_df) == 0:
        return

    pivot_table = subset_df.pivot_table(values='cocrystal_probability',
                                        index='API', columns='CCF',
                                        aggfunc='mean')

    plt.figure(figsize=(14, 10))
    sns.heatmap(pivot_table, annot=True, fmt='.2f', cmap='RdYlGn',
                cbar_kws={'label': 'probability of cocrystal formation'})
    plt.title('API-CCF Cocrystal Formation Probability Heatmap')
    plt.xlabel('CCF')
    plt.ylabel('API')
    _save_figure(save_path, 'cocrystal_heatmap.png')

def generate_cocrystal_prediction_report(results_df, save_path="cocrystal_predictions"):
    """Write and print a plain-text summary of the prediction results.

    Parameters
    ----------
    results_df : pandas.DataFrame
        Needs ``API``, ``CCF``, ``cocrystal_probability``,
        ``cocrystal_prediction`` and ``confidence`` columns.
    save_path : str, optional
        Directory for ``cocrystal_prediction_report.txt``; created if needed.

    Side effects
    ------------
    Adds a ``probability_bin`` column to *results_df* in place. The pipeline
    saves ``results_df`` afterwards, so that column is part of the exported
    workbook. Also writes the report file and prints the report.
    """
    os.makedirs(save_path, exist_ok=True)
    total = len(results_df)

    def pct(count):
        # Percentage of all pairs; 0.0 for an empty result instead of NaN.
        return count / total * 100 if total else 0.0

    def pair_line(rank, row):
        return (f"  {rank}. {row['API']} + {row['CCF']}: "
                f"{row['cocrystal_probability']:.3f} ({row['confidence']})")

    report_content = [
        "=" * 70,
        "           API-CCF Cocrystal Formation Prediction Report",
        "=" * 70,
        f"Time: {pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}",
        f"number of API-CCF: {total}",
        "",
    ]

    positive_count = results_df['cocrystal_prediction'].sum()
    negative_count = total - positive_count
    report_content += [
        "Cocrystal Formation Prediction Results Statistics:",
        f"  - success (1): {positive_count} 对 ({pct(positive_count):.1f}%)",
        f"  - unsuccess (0): {negative_count} 对 ({pct(negative_count):.1f}%)",
        "",
        "Prediction Confidence Statistics:",
    ]
    for conf_level, count in results_df['confidence'].value_counts().items():
        report_content.append(f"  - {conf_level}: {count} 对 ({pct(count):.1f}%)")
    report_content.append("")

    bins = [0, 0.2, 0.3, 0.4, 0.5, 0.7, 1.0]
    labels = ['0-0.2', '0.2-0.3', '0.3-0.4', '0.4-0.5', '0.5-0.7', '0.7-1.0']
    results_df['probability_bin'] = pd.cut(results_df['cocrystal_probability'],
                                           bins=bins, labels=labels, include_lowest=True)
    bin_counts = results_df['probability_bin'].value_counts().sort_index()

    report_content.append("Cocrystal Formation Probability Segmentation Statistics:")
    for bin_label, count in bin_counts.items():
        report_content.append(f"  - {bin_label}: {count} 对 ({pct(count):.1f}%)")
    report_content.append("")

    # Select by value, so the lists are right even if results_df is unsorted.
    report_content.append("Most probability of cocrystal:")
    top = results_df.nlargest(10, 'cocrystal_probability')
    for rank, (_, row) in enumerate(top.iterrows(), 1):
        report_content.append(pair_line(rank, row))

    report_content.append("")
    # Lowest first, so item 1 really is the least likely pair.
    report_content.append("Least probability of cocrystal:")
    bottom = results_df.nsmallest(10, 'cocrystal_probability')
    for rank, (_, row) in enumerate(bottom.iterrows(), 1):
        report_content.append(pair_line(rank, row))

    prob_stats = results_df['cocrystal_probability'].describe()
    report_content += [
        "",
        "probability of cocrystal formation:",
        f"  - Average: {prob_stats['mean']:.3f}",
        f"  - Standard: {prob_stats['std']:.3f}",
        f"  - Median: {prob_stats['50%']:.3f}",
        f"  - Maximum: {prob_stats['max']:.3f}",
        f"  - Minimum: {prob_stats['min']:.3f}",
    ]

    report_text = '\n'.join(report_content)
    report_file = os.path.join(save_path, "cocrystal_prediction_report.txt")
    with open(report_file, 'w', encoding='utf-8') as f:
        f.write(report_text)

    print(report_text)
    print(f"\n {report_file}")

def cocrystal_prediction_pipeline():
    """Run the end-to-end cocrystal prediction workflow.

    Stages: load pairs -> compute descriptors -> load model -> select K-best
    features -> predict -> rank. Each stage signals failure with ``None``.
    The first failure prints ``"error"`` and the run returns ``None``. On
    success, figures and a text report are written, the ranked table is saved
    to ``api_ccf_cocrystal_prediction_results.xlsx``, the first and last ten
    rows are printed, and the ranked DataFrame is returned.
    """
    def abort():
        print("error")
        return None

    print("=" * 50)

    df = load_and_prepare_data()
    if df is None:
        return abort()

    df_descriptors, df_valid = calculate_descriptors(df)
    if df_descriptors is None or not len(df_descriptors):
        return abort()

    model_data = load_kbest_xgboost_model()
    if model_data is None:
        return abort()

    df_kbest_features = prepare_kbest_features(df_descriptors)
    if df_kbest_features is None:
        return abort()

    probabilities, predictions = predict_with_kbest_xgboost(df_kbest_features, model_data)
    if probabilities is None:
        return abort()

    results_df = create_prediction_results(df_valid, probabilities, predictions)
    visualize_cocrystal_predictions(results_df)
    generate_cocrystal_prediction_report(results_df)
    results_df.to_excel('api_ccf_cocrystal_prediction_results.xlsx', index=False)

    preview = results_df[['API', 'CCF', 'cocrystal_prediction',
                          'cocrystal_probability', 'confidence']]
    print(preview.head(10))
    print(preview.tail(10))

    return results_df


if __name__ == "__main__":
    # Script entry point: run the pipeline and summarise where output went.
    final_results = cocrystal_prediction_pipeline()

    if final_results is None:
        print("\nFalse!")
    else:
        summary_lines = (
            "\n" + "=" * 50,
            "success!",
            "file:",
            "  - api_ccf_cocrystal_prediction_results.xlsx ",
            "  - cocrystal_predictions/ ",
        )
        for line in summary_lines:
            print(line)

success
success
success
           API-CCF Cocrystal Formation Prediction Report
Time: 2025-12-13 23:22:53
number of API-CCF: 46

Cocrystal Formation Prediction Results Statistics:
  - success (1): 38 对 (82.6%)
  - unsuccess (0): 8 对 (17.4%)

Prediction Confidence Statistics:
  - high: 19 对 (41.3%)
  - medium: 19 对 (41.3%)
  - low: 8 对 (17.4%)

Cocrystal Formation Probability Segmentation Statistics:
  - 0-0.2: 0 对 (0.0%)
  - 0.2-0.3: 3 对 (6.5%)
  - 0.3-0.4: 0 对 (0.0%)
  - 0.4-0.5: 5 对 (10.9%)
  - 0.5-0.7: 19 对 (41.3%)
  - 0.7-1.0: 19 对 (41.3%)

Most probability of cocrystal:
  1. AML + Benzene sulfonamide: 0.965 (high)
  2. AML + 4,.4'-bipyridine: 0.938 (high)
  3. AML + im: 0.919 (high)
  4. AML + Imidazole: 0.919 (high)
  5. AML + Methanesulfonic acid: 0.878 (high)
  6. AML + 4-acetophenetidide: 0.865 (high)
  7. AML + N-phenylacetamide: 0.864 (high)
  8. AML + Succinimide: 0.859 (high)
  9. AML + 3,5-difluorobenzoic acid: 0.851 (high)
  10. AML + 3,5-dihydroxybenzoic acid: 0.798 (h