In [None]:

from tqdm import tqdm
import os
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from joblib import Parallel, delayed
from glob import glob
import numpy as np
import seaborn as sns
# An interaction enters the reference map only if it occurs in more than this fraction
# of the rows of the reference interaction table (see getInteractionMap).
KEEP_THRESHOLD = 0.1
# Root directory holding one sub-directory per UniProt ID, and the figure output directory.
# Both can be overridden via environment variables instead of editing the notebook.
father_dir = os.environ.get(
    'MOLGENBENCH_INTERACTION_DIR',
    '/home/data-house-01/cdhofficial/MolGens/MolGenBench-dev/test_interaction')
save_fig_dir = os.environ.get(
    'MOLGENBENCH_INTERACTION_FIG_DIR',
    '/home/data-house-01/cdhofficial/MolGens/MolGenBench-dev/test_interaction_fig')


In [None]:

def repairCSV(data_path):
    """Load an interaction CSV with a multi-row header and return a flat 0/1 table.

    The first two data rows of the file hold, per column, a residue name and an
    interaction type; they are fused into ``{residue}_{interaction}`` column names.
    The first three data rows are then discarded. So are the ``protein_interaction``
    column and every Hydrophobic / VdWContact column. Each remaining cell becomes 0 if
    it equals the string ``'False'`` and 1 otherwise.

    Args:
        data_path: Path (or file-like object) accepted by ``pd.read_csv``.

    Returns:
        pd.DataFrame with a fresh RangeIndex and integer 0/1 columns.
    """
    raw = pd.read_csv(data_path)
    raw.columns = [f'{residue}_{kind}' for residue, kind in zip(raw.iloc[0], raw.iloc[1])]

    body = (raw
            .drop([0, 1, 2])
            .drop('protein_interaction', axis=1)
            .reset_index(drop=True))

    # Hydrophobic and VdWContact contacts are excluded from the analysis entirely.
    excluded = [name for name in body.columns if 'Hydrophobic' in name or 'VdWContact' in name]
    body = body.drop(columns=excluded)

    for name in list(body.columns):
        body[name] = body[name].apply(lambda cell: 0 if cell == 'False' else 1)
    return body
def getInteractionMap(data_sample, keep_threshold=0.1):
    """Build the reference interaction weights used by getInteractionScore.

    Args:
        data_sample: DataFrame of 0/1 interaction flags with one column per
            ``{residue}_{interaction}`` (as returned by ``repairCSV``); presumably
            one row per reference molecule/pose.
        keep_threshold: A column is kept only if its total count is strictly greater
            than ``keep_threshold * len(data_sample)``.

    Returns:
        dict mapping each kept column name to its share of all kept interaction
        occurrences (the values sum to 1). Keys are ordered by descending count.
        Empty when no column passes the threshold.
    """
    counts = data_sample.sum(axis=0).sort_values(ascending=False)
    # Filter the Series directly. The old version assigned a new column to a boolean-filtered
    # DataFrame slice, which triggers pandas' SettingWithCopyWarning.
    kept = counts[counts > keep_threshold * len(data_sample)]
    if kept.empty:
        return {}
    ratios = kept / kept.sum()
    return dict(zip(ratios.index, ratios))
def getInteractionScore(input_serise, interaction_ref_map):
    """Weighted interaction score of one molecule.

    Sums ``interaction_ref_map[name]`` over every entry of ``input_serise`` whose value
    equals 1 and whose label is a key of ``interaction_ref_map``. Other labels (e.g.
    previously added score columns) are ignored. Returns 0 when nothing matches.
    """
    matched_weights = [
        interaction_ref_map[label]
        for label, flag in input_serise.items()
        if flag == 1 and label in interaction_ref_map
    ]
    return sum(matched_weights, 0)

def getInteractionNum(input_serise, interaction_ref_map):
    """Number of reference interactions present in one molecule.

    Counts the entries of ``input_serise`` that equal 1 and whose label is a key of
    ``interaction_ref_map``.
    """
    return sum(
        1
        for label, flag in input_serise.items()
        if flag == 1 and label in interaction_ref_map
    )


In [None]:


def _score_output_path(csv_path, tag):
    """Return ``csv_path`` with ``tag`` inserted just before its extension.

    Only the final extension is touched. ``str.replace('.csv', ...)`` would also rewrite
    any directory component that happens to contain ``.csv``.
    """
    root, ext = os.path.splitext(csv_path)
    return f'{root}{tag}{ext}'


def _score_and_save(samples, out_path, interaction_ref_map):
    """Add InteractionScore / InteractionNum columns, sort by score (desc) and write a CSV.

    ``samples`` is a repaired interaction table (see ``repairCSV``). It is copied, not modified.
    """
    scored = samples.copy()
    scored['InteractionScore'] = scored.apply(lambda row: getInteractionScore(row, interaction_ref_map), axis=1)
    scored['InteractionNum'] = scored.apply(lambda row: getInteractionNum(row, interaction_ref_map), axis=1)
    scored.sort_values('InteractionScore', ascending=False).to_csv(out_path, index=False)


uniprot_ids = os.listdir(father_dir)
pbar = tqdm(uniprot_ids, total=len(uniprot_ids))
error_list = []  # paths of generated interaction CSVs that could not be scored
for uniprot_id in pbar:
    for type_name in ['vina_docked']:
        pbar.set_description(desc=f"Processing Uniprot ID: {uniprot_id}, Type: {type_name}")

        ref_csv_path = os.path.join(
            father_dir, uniprot_id, 'reference_active_molecules',
            f'{uniprot_id}_reference_active_molecules_{type_name}_interactions.csv')
        if not os.path.exists(ref_csv_path):
            print(f'Uniprot : {uniprot_id} dont have ref csv file')
            continue

        # The reference map is always rebuilt, because every generated file below is scored
        # against it, even when the reference score file already exists on disk.
        ref_samples = repairCSV(ref_csv_path)
        interaction_ref_map = getInteractionMap(ref_samples, keep_threshold=KEEP_THRESHOLD)
        ref_score_path = _score_output_path(ref_csv_path, '_score')
        if not os.path.exists(ref_score_path):
            _score_and_save(ref_samples, ref_score_path, interaction_ref_map)

        for sub_dir in ['Round1', 'Round2', 'Round3']:
            round_dir = os.path.join(father_dir, uniprot_id, sub_dir)

            # 1) Files named "*{type_name}_interactions.csv" -> "..._score.csv".
            docked_files = (glob(os.path.join(round_dir, '*', '*', '*', f'*{type_name}_interactions.csv'))
                            + glob(os.path.join(round_dir, '*', '*', f'*{type_name}_interactions.csv')))
            for test_path in docked_files:
                out_path = _score_output_path(test_path, '_score')
                if os.path.exists(out_path):
                    continue
                try:
                    _score_and_save(repairCSV(test_path), out_path, interaction_ref_map)
                except Exception as exc:  # best effort: record and continue (Ctrl-C still interrupts)
                    error_list.append(test_path)
                    tqdm.write(f'Failed to score {test_path}: {exc!r}')

            # 2) Every other "*_interactions.csv" (neither vina_docked nor ligprep_glide_sp_pv),
            #    scored against the same reference map -> "..._score_raw_compare_{type_name}.csv".
            other_files = (glob(os.path.join(round_dir, '*', '*', '*', '*_interactions.csv'))
                           + glob(os.path.join(round_dir, '*', '*', '*_interactions.csv')))
            for test_path in other_files:
                base_name = os.path.basename(test_path)
                if 'ligprep_glide_sp_pv' in base_name or 'vina_docked' in base_name:
                    continue
                out_path = _score_output_path(test_path, f'_score_raw_compare_{type_name}')
                if os.path.exists(out_path):
                    continue
                try:
                    _score_and_save(repairCSV(test_path), out_path, interaction_ref_map)
                except Exception as exc:  # best effort: record and continue (Ctrl-C still interrupts)
                    error_list.append(test_path)
                    tqdm.write(f'Failed to score {test_path}: {exc!r}')

## Distribution of InteractionScore across models

In [None]:
## Define a function to read the csv files in all folders, and then add a filename column to each file
def read_csv_files_in_directory(directory):

    all_dataframes = []
    sub_dir_names = ['Round1','Round2','Round3','reference_active_molecules']
    for uniprot_id in os.listdir(directory):
        for sub_dir in sub_dir_names:

            if sub_dir == 'reference_active_molecules':
                csv_files  = glob(os.path.join(directory,uniprot_id,sub_dir,f'*score*.csv'))
            else:
                    
                csv_files  = glob(os.path.join(directory,uniprot_id,sub_dir,'*','*','*',f'*score*.csv')) + \
                                glob(os.path.join(directory,uniprot_id,sub_dir,'*','*',f'*score*.csv'))

            for file in csv_files:

                df = pd.read_csv(file)

                df['filename'] = [file]*len(df)
                df['uniprot_id'] = [uniprot_id]*len(df)
                df['round'] = [sub_dir]*len(df)

                if sub_dir == 'reference_active_molecules':
                    df['model_name'] = 'Reference'
                    df['Task'] = ['Reference']*len(df)
                else:
                    if 'Sries' not in file:
                        df['model_name'] = df['filename'].apply(lambda x: '_'.join(os.path.basename(x).split('_generated')[0].split('_')[1:]))
                        df['Task'] = ['Denovo']*len(df)
                    else:
                        df['model_name'] = df['filename'].apply(lambda x: '_'.join(os.path.basename(x).split('_Hit')[0].split('_')[2:]))
                        df['Task'] = ['Hit2Lead']*len(df)
                df = df[['Task','model_name','round','uniprot_id','filename','InteractionScore','InteractionNum']]
                
                all_dataframes.append(df)
                
        

    combined_df = pd.concat(all_dataframes, ignore_index=True)
    
    return combined_df

In [None]:


# Load every score CSV (reference + Round1-3 generations, all proteins) into one long frame.
combined_df = read_csv_files_in_directory(father_dir)


In [None]:
def temp_func_interaction_type(filename):
    """Label a score file by scoring protocol, based on its name.

    'glide' in the name gives 'Glide', otherwise 'Vina'. If the name also contains
    'raw_compare', the label is prefixed with 'Raw_Compare_'.
    """
    engine = 'Glide' if 'glide' in filename else 'Vina'
    if 'raw_compare' in filename:
        return 'Raw_Compare_' + engine
    return engine
# Tag every row with its scoring protocol, derived from the score-file name.
combined_df['interaction_type'] = combined_df['filename'].map(temp_func_interaction_type)


In [None]:
# Sanity check: number of rows for each interaction_type label.
combined_df['interaction_type'].value_counts()


## Compute quantiles of the reference InteractionScore, then the proportion of generated molecules scoring above each quantile.

In [None]:
def calculate_quantile_proportions(reference_df, generated_df, quantiles=(0.25, 0.5, 0.75)):
    """Proportion of generated molecules whose InteractionScore exceeds reference quantiles.

    Thresholds are the ``quantiles`` of the reference InteractionScore per
    (interaction_type, uniprot_id). 'Raw_Compare_Vina' / 'Raw_Compare_Glide' rows are
    compared against the 'Vina' / 'Glide' reference thresholds respectively. Generated
    groups with no matching threshold are skipped.

    Args:
        reference_df (pd.DataFrame): Columns interaction_type, uniprot_id, InteractionScore.
        generated_df (pd.DataFrame): Same columns plus model_name.
        quantiles (sequence of float): Quantile levels, e.g. (0.25, 0.5, 0.75).

    Returns:
        pd.DataFrame with columns interaction_type, uniprot_id, model_name, quantile and
        proportion_above: one row per (group, quantile), where the comparison is strict (>).
        Empty input gives an empty frame that still has these columns.
    """
    result_columns = ['interaction_type', 'uniprot_id', 'model_name', 'quantile', 'proportion_above']
    if reference_df.empty or generated_df.empty:
        return pd.DataFrame(columns=result_columns)

    # Rows: (interaction_type, uniprot_id). Columns: quantile level.
    thresholds = (reference_df
                  .groupby(['interaction_type', 'uniprot_id'])['InteractionScore']
                  .quantile(list(quantiles))
                  .unstack())

    # Raw-compare scores reuse the thresholds of their docking engine.
    present_types = set(thresholds.index.get_level_values('interaction_type'))
    aliases = []
    for base_type in ('Vina', 'Glide'):
        if base_type not in present_types:
            continue
        alias = thresholds.xs(base_type, level='interaction_type', drop_level=False).copy()
        alias.index = pd.MultiIndex.from_tuples(
            [(f'Raw_Compare_{base_type}', uid) for _, uid in alias.index],
            names=thresholds.index.names,
        )
        aliases.append(alias)
    if aliases:
        thresholds = pd.concat([thresholds, *aliases])

    results = []
    grouped_generated = generated_df.groupby(['interaction_type', 'uniprot_id', 'model_name'])
    for (interaction_type, uniprot_id, model_name), group in grouped_generated:
        key = (interaction_type, uniprot_id)
        if key not in thresholds.index:
            continue
        scores = group['InteractionScore']
        total_count = len(scores)  # groupby never yields empty groups
        for quantile_level, threshold in thresholds.loc[key].items():
            count_above = int((scores > threshold).sum())
            results.append({
                'interaction_type': interaction_type,
                'uniprot_id': uniprot_id,
                'model_name': model_name,
                'quantile': quantile_level,
                'proportion_above': count_above / total_count,
            })

    return pd.DataFrame(results, columns=result_columns)

def replace_model_name(model_name):
    """Map an internal model identifier to its display name for figure legends.

    The first entry of the lookup table whose key is a substring of ``model_name`` wins.
    Unmatched names are returned unchanged.
    """
    display_names = (
        ('diffSBDD_cond_crossdocked', 'DiffSBDD-C'),
        ('diffSBDD_cond_moad', 'DiffSBDD-M'),
        ('shepherd', 'ShEPhERD'),
        ('DeleteHit2Lead', 'Delete'),
    )
    return next((pretty for key, pretty in display_names if key in model_name), model_name)

def plot_quantile_polar(stats_df, interaction_type, quantiles=(0.25, 0.50, 0.75),
                        quantile_palette=None, model_order=None, show_mean_values=True, show_legend=True,
                        title_suffix="R=mean±std(ALL Proteins)", save_dir=None):
    """Draw a polar "sector" chart of per-model quantile pass rates and save it as SVG.

    Rows of ``stats_df`` for ``interaction_type`` are aggregated over all uniprot IDs.
    For every (model, quantile) pair, the arc radius is the mean of
    ``proportion_above_mean`` and the shaded band is ± the mean of
    ``proportion_above_std``. Each model gets one angular sector with a faint
    background colour, and each quantile is drawn as an arc inside it.

    Args:
        stats_df: DataFrame with columns interaction_type, model_name, quantile,
            proportion_above_mean and proportion_above_std.
        interaction_type: Value of ``stats_df['interaction_type']`` to plot.
        quantiles: Quantile levels to draw, in drawing/legend order.
        quantile_palette: Optional mapping quantile -> colour (default: 'hls' palette).
        model_order: Optional sector order (default: descending mean at the first quantile).
        show_mean_values: Annotate each arc with its mean value.
        show_legend: Draw quantile and model legends to the right of the plot.
        title_suffix: Appended to the title and used in the output file name.
        save_dir: Output directory (default: the notebook-level ``save_fig_dir``).

    Returns:
        None. The figure is saved to ``{save_dir}/polar_plot_{interaction_type}_{title_suffix}.svg``
        and left open so the notebook can display it.
    """
    from matplotlib.patches import Patch  # only needed for the model legend patches

    if save_dir is None:
        save_dir = save_fig_dir

    sub_df = stats_df[stats_df['interaction_type'] == interaction_type]
    agg = (sub_df
           .groupby(['model_name', 'quantile'], as_index=False)
           .agg({'proportion_above_mean': 'mean',   # mean over proteins of the per-protein mean
                 'proportion_above_std': 'mean'}))  # mean over proteins of the per-protein std
    # (model, quantile) -> (mean, std). Built once, instead of re-filtering agg for every arc
    # and calling float() on a one-element Series (deprecated in pandas >= 2.1).
    arc_stats = {(m, q): (mean_v, std_v) for m, q, mean_v, std_v in zip(
        agg['model_name'], agg['quantile'], agg['proportion_above_mean'], agg['proportion_above_std'])}

    quantiles = list(quantiles)
    if model_order is None:
        first_q = quantiles[0]
        model_order = (agg[agg['quantile'] == first_q]
                       .sort_values('proportion_above_mean', ascending=False)['model_name']
                       .tolist())

    n_models = len(model_order)
    sector_width = 2 * np.pi / max(n_models, 1)  # angular width of one model's sector

    if quantile_palette is None:
        quantile_palette = dict(zip(quantiles, sns.color_palette('hls', len(quantiles))))

    # Colours are assigned over the sorted model names, so a model keeps its colour
    # whatever the sector order is.
    base_color = dict(zip(sorted(model_order), sns.color_palette('deep', n_models)))

    fig, ax = plt.subplots(figsize=(10, 6), subplot_kw={'polar': True})
    ax.set_xticks([])                      # no angular ticks
    ax.set_rlim(0, 1.0)
    ax.set_rticks([0.25, 0.5, 0.75, 1.0])  # keep the radial grid circles...
    ax.set_yticklabels([])                 # ...but hide their values
    ax.set_rlabel_position(90)
    ax.grid(alpha=1, linestyle=':', linewidth=1.5)

    # Faint background sector for every model.
    for idx, model in enumerate(model_order):
        ax.bar((idx + 0.5) * sector_width, 1.0, width=sector_width, bottom=0,
               color=base_color[model], edgecolor='none', linewidth=0, alpha=0.25, zorder=0)

    theta_res = 80  # points per arc
    labelled_quantiles = set()  # each quantile gets exactly one legend entry
    for idx, model in enumerate(model_order):
        theta_start = idx * sector_width
        theta_end = (idx + 1) * sector_width
        theta = np.linspace(theta_start, theta_end, theta_res)

        for q in quantiles:
            if (model, q) not in arc_stats:
                continue
            mean_r, std_r = arc_stats[(model, q)]
            mean_r = float(np.clip(mean_r, 0, 1))
            # std is NaN when a protein has a single repeat; treat that as "no spread".
            std_r = 0.0 if np.isnan(std_r) else max(float(std_r), 0.0)

            if std_r > 0:
                ax.fill_between(theta, max(mean_r - std_r, 0.0), min(mean_r + std_r, 1.0),
                                color=quantile_palette[q], alpha=0.4, linewidth=0, zorder=2)

            ax.plot(theta, np.full_like(theta, mean_r),
                    color=quantile_palette[q], linewidth=1.5,
                    label=None if q in labelled_quantiles else f"Q{int(round(q * 100))}",
                    zorder=3)
            labelled_quantiles.add(q)

        # Left boundary of this sector.
        ax.plot([theta_start, theta_start], [0, 1.0], color='gray', linewidth=0.6, alpha=0.6, zorder=5)

        if show_mean_values:
            theta_mid = (theta_start + theta_end) / 2
            for q in quantiles:
                if (model, q) not in arc_stats:
                    continue
                m_val = float(arc_stats[(model, q)][0])
                # Put the label just above the arc. If it would be squeezed against r=1,
                # put it below instead.
                label_r = min(m_val + 0.05, 0.98)
                if label_r - m_val < 0.03:
                    label_r = max(m_val - 0.05, 0.02)
                ax.text(theta_mid, label_r, f"{m_val:.3f}", ha='center', va='center',
                        fontsize=7, color=quantile_palette[q])

    # Closing boundary after the last sector.
    ax.plot([n_models * sector_width] * 2, [0, 1.0], color='gray', linewidth=0.6, alpha=0.6, zorder=5)
    ax.set_title(f"{interaction_type} | {title_suffix}", pad=30, fontsize=12)

    if show_legend:
        fig.subplots_adjust(right=0.80)  # leave room for the legends on the right
        q_legend = ax.legend(loc='center left', bbox_to_anchor=(1.05, 0.8), frameon=False,
                             title='Quantiles', ncol=1, handlelength=1.2, columnspacing=1.0)
        ax.add_artist(q_legend)  # keep it when the model legend below replaces ax.legend
        model_patches = [Patch(facecolor=base_color[m], alpha=0.4, label=replace_model_name(m))
                         for m in model_order]
        ax.legend(handles=model_patches, loc='center left', bbox_to_anchor=(1.05, 0.5),
                  frameon=False, title='Models')

    fig.tight_layout()
    os.makedirs(save_dir, exist_ok=True)
    fig.savefig(f'{save_dir}/polar_plot_{interaction_type}_{title_suffix}.svg',
                bbox_inches='tight', dpi=660, format='svg')
    return None




In [None]:

import json
# JSON object keyed by UniProt ID — presumably the targets that also occur in
# CrossDocked2020 (per the file name). Only its keys are used below, to split
# proteins into "seen" and "unseen" sets.
with open('../../sup_info/UniprotIDs_duplicated_with_crossdock2020.json','r') as f:
    UniprotId_in_crossdock = json.load(f)

In [None]:
# Split the combined scores by task.
combined_df_hit2lead = combined_df[combined_df['Task'] == 'Hit2Lead']
combined_df_denovo = combined_df[combined_df['Task'] == 'Denovo']
combined_df_reference = combined_df[combined_df['Task'] == 'Reference']
# Denovo: one frame per generation round (repeat).
# Each roundN frame must select 'RoundN'. Round2 previously selected 'Round3', which
# double-counted Round3 and dropped Round2 from the repeat statistics.
combined_df_denovo_round1 = combined_df_denovo[combined_df_denovo['round'] == 'Round1']
combined_df_denovo_round2 = combined_df_denovo[combined_df_denovo['round'] == 'Round2']
combined_df_denovo_round3 = combined_df_denovo[combined_df_denovo['round'] == 'Round3']
# Hit2Lead: one frame per generation round (repeat).
combined_df_hit2lead_round1 = combined_df_hit2lead[combined_df_hit2lead['round'] == 'Round1']
combined_df_hit2lead_round2 = combined_df_hit2lead[combined_df_hit2lead['round'] == 'Round2']
combined_df_hit2lead_round3 = combined_df_hit2lead[combined_df_hit2lead['round'] == 'Round3']

## Seen / unseen proteins: membership in the keys of UniprotId_in_crossdock.
in_crossdock = combined_df['uniprot_id'].isin(list(UniprotId_in_crossdock.keys()))
combined_df_unseen = combined_df[~in_crossdock]
combined_df_seen = combined_df[in_crossdock]
print(f"unseen proteins:{len(combined_df_unseen['uniprot_id'].unique())}",f"seen proteins:{len(combined_df_seen['uniprot_id'].unique())}")






In [None]:
def _rows_where(frame, column, value):
    """Return the rows of ``frame`` whose ``column`` equals ``value`` (boolean-mask selection)."""
    return frame[frame[column] == value]


## unseen proteins: split by task...
combined_df_hit2lead_unseen = _rows_where(combined_df_unseen, 'Task', 'Hit2Lead')
combined_df_denovo_unseen = _rows_where(combined_df_unseen, 'Task', 'Denovo')
combined_df_reference_unseen = _rows_where(combined_df_unseen, 'Task', 'Reference')
# ...then by generation round (denovo)
combined_df_unseen_denovo_round1 = _rows_where(combined_df_denovo_unseen, 'round', 'Round1')
combined_df_unseen_denovo_round2 = _rows_where(combined_df_denovo_unseen, 'round', 'Round2')
combined_df_unseen_denovo_round3 = _rows_where(combined_df_denovo_unseen, 'round', 'Round3')
# ...and by generation round (hit2lead)
combined_df_unseen_hit2lead_round1 = _rows_where(combined_df_hit2lead_unseen, 'round', 'Round1')
combined_df_unseen_hit2lead_round2 = _rows_where(combined_df_hit2lead_unseen, 'round', 'Round2')
combined_df_unseen_hit2lead_round3 = _rows_where(combined_df_hit2lead_unseen, 'round', 'Round3')

## seen proteins: split by task...
combined_df_hit2lead_seen = _rows_where(combined_df_seen, 'Task', 'Hit2Lead')
combined_df_denovo_seen = _rows_where(combined_df_seen, 'Task', 'Denovo')
combined_df_reference_seen = _rows_where(combined_df_seen, 'Task', 'Reference')
# ...then by generation round (denovo)
combined_df_seen_denovo_round1 = _rows_where(combined_df_denovo_seen, 'round', 'Round1')
combined_df_seen_denovo_round2 = _rows_where(combined_df_denovo_seen, 'round', 'Round2')
combined_df_seen_denovo_round3 = _rows_where(combined_df_denovo_seen, 'round', 'Round3')
# ...and by generation round (hit2lead)
combined_df_seen_hit2lead_round1 = _rows_where(combined_df_hit2lead_seen, 'round', 'Round1')
combined_df_seen_hit2lead_round2 = _rows_where(combined_df_hit2lead_seen, 'round', 'Round2')
combined_df_seen_hit2lead_round3 = _rows_where(combined_df_hit2lead_seen, 'round', 'Round3')



## all denovo

In [None]:

## denovo all proteins
# Pass rates against the reference quantiles, one frame per round (repeat).
denovo_proportions_round1, denovo_proportions_round2, denovo_proportions_round3 = [
    calculate_quantile_proportions(combined_df_reference, round_df).assign(repeat=round_label)
    for round_label, round_df in (
        ('Round1', combined_df_denovo_round1),
        ('Round2', combined_df_denovo_round2),
        ('Round3', combined_df_denovo_round3),
    )
]
# Stack the repeats, then summarise them per (interaction_type, uniprot_id, model_name, quantile).
denovo_proportions_all = pd.concat(
    [denovo_proportions_round1, denovo_proportions_round2, denovo_proportions_round3],
    ignore_index=True,
)
denovo_proportions_stats = (
    denovo_proportions_all
    .groupby(['interaction_type', 'uniprot_id', 'model_name', 'quantile'])['proportion_above']
    .agg(proportion_above_mean='mean', proportion_above_std='std', n_repeats='count')
    .reset_index()
)
# Two figures per interaction type: annotated (legend + mean values) and bare.
for target_interaction in denovo_proportions_stats['interaction_type'].unique():
    for suffix, annotate in (
        ('R=Mean±Std(All Proteins Denovo)_with_legend_mean_value', True),
        ('R=Mean±Std(All Proteins Denovo)', False),
    ):
        plot_quantile_polar(denovo_proportions_stats, target_interaction, title_suffix=suffix,
                            show_mean_values=annotate, show_legend=annotate)


## seen denovo 

In [None]:

## seen_denovo all proteins
# Pass rates against the reference quantiles, one frame per round (repeat).
seen_denovo_proportions_round1, seen_denovo_proportions_round2, seen_denovo_proportions_round3 = [
    calculate_quantile_proportions(combined_df_reference, round_df).assign(repeat=round_label)
    for round_label, round_df in (
        ('Round1', combined_df_seen_denovo_round1),
        ('Round2', combined_df_seen_denovo_round2),
        ('Round3', combined_df_seen_denovo_round3),
    )
]
# Stack the repeats, then summarise them per (interaction_type, uniprot_id, model_name, quantile).
seen_denovo_proportions_all = pd.concat(
    [seen_denovo_proportions_round1, seen_denovo_proportions_round2, seen_denovo_proportions_round3],
    ignore_index=True,
)
seen_denovo_proportions_stats = (
    seen_denovo_proportions_all
    .groupby(['interaction_type', 'uniprot_id', 'model_name', 'quantile'])['proportion_above']
    .agg(proportion_above_mean='mean', proportion_above_std='std', n_repeats='count')
    .reset_index()
)
# Two figures per interaction type: annotated (legend + mean values) and bare.
for target_interaction in seen_denovo_proportions_stats['interaction_type'].unique():
    for suffix, annotate in (
        ('R=Mean±Std(seen Proteins denovo)_with_legend_mean_value', True),
        ('R=Mean±Std(seen Proteins denovo)', False),
    ):
        plot_quantile_polar(seen_denovo_proportions_stats, target_interaction, title_suffix=suffix,
                            show_mean_values=annotate, show_legend=annotate)

## unseen denovo

In [None]:

## unseen_denovo all proteins
# Pass rates against the reference quantiles, one frame per round (repeat).
unseen_denovo_proportions_round1, unseen_denovo_proportions_round2, unseen_denovo_proportions_round3 = [
    calculate_quantile_proportions(combined_df_reference, round_df).assign(repeat=round_label)
    for round_label, round_df in (
        ('Round1', combined_df_unseen_denovo_round1),
        ('Round2', combined_df_unseen_denovo_round2),
        ('Round3', combined_df_unseen_denovo_round3),
    )
]
# Stack the repeats, then summarise them per (interaction_type, uniprot_id, model_name, quantile).
unseen_denovo_proportions_all = pd.concat(
    [unseen_denovo_proportions_round1, unseen_denovo_proportions_round2, unseen_denovo_proportions_round3],
    ignore_index=True,
)
unseen_denovo_proportions_stats = (
    unseen_denovo_proportions_all
    .groupby(['interaction_type', 'uniprot_id', 'model_name', 'quantile'])['proportion_above']
    .agg(proportion_above_mean='mean', proportion_above_std='std', n_repeats='count')
    .reset_index()
)
# Two figures per interaction type: annotated (legend + mean values) and bare.
for target_interaction in unseen_denovo_proportions_stats['interaction_type'].unique():
    for suffix, annotate in (
        ('R=Mean±Std(unseen Proteins denovo)_with_legend_mean_value', True),
        ('R=Mean±Std(unseen Proteins denovo)', False),
    ):
        plot_quantile_polar(unseen_denovo_proportions_stats, target_interaction, title_suffix=suffix,
                            show_mean_values=annotate, show_legend=annotate)

## unseen hit2lead

In [None]:

## unseen_hit2lead all proteins
# Pass rates against the reference quantiles, one frame per round (repeat).
unseen_hit2lead_proportions_round1, unseen_hit2lead_proportions_round2, unseen_hit2lead_proportions_round3 = [
    calculate_quantile_proportions(combined_df_reference, round_df).assign(repeat=round_label)
    for round_label, round_df in (
        ('Round1', combined_df_unseen_hit2lead_round1),
        ('Round2', combined_df_unseen_hit2lead_round2),
        ('Round3', combined_df_unseen_hit2lead_round3),
    )
]
# Stack the repeats, then summarise them per (interaction_type, uniprot_id, model_name, quantile).
unseen_hit2lead_proportions_all = pd.concat(
    [unseen_hit2lead_proportions_round1, unseen_hit2lead_proportions_round2, unseen_hit2lead_proportions_round3],
    ignore_index=True,
)
unseen_hit2lead_proportions_stats = (
    unseen_hit2lead_proportions_all
    .groupby(['interaction_type', 'uniprot_id', 'model_name', 'quantile'])['proportion_above']
    .agg(proportion_above_mean='mean', proportion_above_std='std', n_repeats='count')
    .reset_index()
)
# Two figures per interaction type: annotated (legend + mean values) and bare.
for target_interaction in unseen_hit2lead_proportions_stats['interaction_type'].unique():
    for suffix, annotate in (
        ('R=Mean±Std(unseen Proteins hit2lead)_with_legend_mean_value', True),
        ('R=Mean±Std(unseen Proteins hit2lead)', False),
    ):
        plot_quantile_polar(unseen_hit2lead_proportions_stats, target_interaction, title_suffix=suffix,
                            show_mean_values=annotate, show_legend=annotate)

## seen hit2lead

In [None]:

## seen_hit2lead all proteins
# Pass rates against the reference quantiles, one frame per round (repeat).
seen_hit2lead_proportions_round1, seen_hit2lead_proportions_round2, seen_hit2lead_proportions_round3 = [
    calculate_quantile_proportions(combined_df_reference, round_df).assign(repeat=round_label)
    for round_label, round_df in (
        ('Round1', combined_df_seen_hit2lead_round1),
        ('Round2', combined_df_seen_hit2lead_round2),
        ('Round3', combined_df_seen_hit2lead_round3),
    )
]
# Stack the repeats, then summarise them per (interaction_type, uniprot_id, model_name, quantile).
seen_hit2lead_proportions_all = pd.concat(
    [seen_hit2lead_proportions_round1, seen_hit2lead_proportions_round2, seen_hit2lead_proportions_round3],
    ignore_index=True,
)
seen_hit2lead_proportions_stats = (
    seen_hit2lead_proportions_all
    .groupby(['interaction_type', 'uniprot_id', 'model_name', 'quantile'])['proportion_above']
    .agg(proportion_above_mean='mean', proportion_above_std='std', n_repeats='count')
    .reset_index()
)
# Two figures per interaction type: annotated (legend + mean values) and bare.
for target_interaction in seen_hit2lead_proportions_stats['interaction_type'].unique():
    for suffix, annotate in (
        ('R=Mean±Std(seen Proteins hit2lead)_with_legend_mean_value', True),
        ('R=Mean±Std(seen Proteins hit2lead)', False),
    ):
        plot_quantile_polar(seen_hit2lead_proportions_stats, target_interaction, title_suffix=suffix,
                            show_mean_values=annotate, show_legend=annotate)


## all hit2lead

In [None]:

## hit2lead all proteins
# Pass rates against the reference quantiles, one frame per round (repeat).
hit2lead_proportions_round1, hit2lead_proportions_round2, hit2lead_proportions_round3 = [
    calculate_quantile_proportions(combined_df_reference, round_df).assign(repeat=round_label)
    for round_label, round_df in (
        ('Round1', combined_df_hit2lead_round1),
        ('Round2', combined_df_hit2lead_round2),
        ('Round3', combined_df_hit2lead_round3),
    )
]
# Stack the repeats, then summarise them per (interaction_type, uniprot_id, model_name, quantile).
hit2lead_proportions_all = pd.concat(
    [hit2lead_proportions_round1, hit2lead_proportions_round2, hit2lead_proportions_round3],
    ignore_index=True,
)
hit2lead_proportions_stats = (
    hit2lead_proportions_all
    .groupby(['interaction_type', 'uniprot_id', 'model_name', 'quantile'])['proportion_above']
    .agg(proportion_above_mean='mean', proportion_above_std='std', n_repeats='count')
    .reset_index()
)
# Two figures per interaction type: annotated (legend + mean values) and bare.
for target_interaction in hit2lead_proportions_stats['interaction_type'].unique():
    for suffix, annotate in (
        ('R=Mean±Std(All Proteins hit2lead)_with_legend_mean_value', True),
        ('R=Mean±Std(All Proteins hit2lead)', False),
    ):
        plot_quantile_polar(hit2lead_proportions_stats, target_interaction, title_suffix=suffix,
                            show_mean_values=annotate, show_legend=annotate)
