In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import glob
import os
import numpy as np

In [None]:
# Global tick styling shared by every figure in this notebook:
# hairline tick marks (width 0.2) on both axes, with major ticks of
# length 3 and minor ticks of length 2.
plt.rcParams.update({
    'xtick.major.width': 0.2,
    'ytick.major.width': 0.2,
    'xtick.minor.width': 0.2,
    'ytick.minor.width': 0.2,
    'xtick.major.size': 3,
    'ytick.major.size': 3,
    'xtick.minor.size': 2,
    'ytick.minor.size': 2,
})

def plot_pairwise_data(file_path, output_path=None):
    """
    Draw a three-panel summary figure for one pairwise-alignment TSV file.

    The file is read as a header-less, tab-separated table with the columns
    query, target, identity, alnlen, mismatch, gapopen, qstart, qend, tstart,
    tend, evalue, bits. Rows where query == target (self-hits) are left out
    of every plot, but they are still included in the first two annotation
    totals.

    Panels:
    - top left: histogram + KDE of identity, with the ECDF on a twin y-axis
      and a text box holding the summary counts
    - top right: scatter of identity against alignment length
    - bottom (full width): box plot of identity per query, ordered by
      decreasing median identity, drawn over a bar chart of the number of
      rows per query (twin y-axis)

    Parameters:
    - file_path: path of the TSV file to plot.
    - output_path: if given, the figure is written there as a PDF and then
      closed; otherwise it is shown with plt.show().

    Returns:
    - None
    """
    # Column names for the TSV file
    columns = ["query", "target", "identity", "alnlen", "mismatch", "gapopen",
               "qstart", "qend", "tstart", "tend", "evalue", "bits"]

    # Load the TSV file into a pandas dataframe
    df = pd.read_csv(file_path, sep="\t", names=columns)

    # Filter out rows where query == target to avoid self-comparisons
    df_filtered = df[df['query'] != df['target']]

    # Statistics for the annotation:
    # - row count of the full table (self-hits included), shown as "N proteins"
    # - number of distinct queries in the full table
    # - number of distinct queries that have at least one non-self row
    num_pairwise_comparisons = len(df)
    num_clusters = df['query'].nunique()
    num_clusters_more_than_one = df_filtered['query'].nunique()

    # Set up the 2x2 grid for plots
    fig, axes = plt.subplots(2, 2, figsize=(7., 5), gridspec_kw={'height_ratios': [1, 1.5]}, constrained_layout=False, dpi=300)
    plt.subplots_adjust(wspace=0.35)

    # First row, left: KDE/Histogram of sequence identities
    ax1 = axes[0, 0]
    sns.histplot(df_filtered['identity'], bins=20, kde=True, ax=ax1, color='dodgerblue', lw=0.5)
    ax1.set_xlabel("Sequence Identity", size=6)
    ax1.set_ylabel("Frequency", size=6)
    ax1.tick_params(axis='both', labelsize=6)
    ax1.spines['left'].set_linewidth(0.2)
    ax1.spines['bottom'].set_linewidth(0.2)
    sns.despine(ax=ax1, top=True)  # Remove top and right spines

    # Add twin y-axis for ECDF plot
    ax2 = ax1.twinx()
    sns.ecdfplot(df_filtered['identity'], ax=ax2, color='red')  # ECDF plot
    ax2.set_ylabel("ECDF", size=6, color='red')
    ax2.tick_params(axis='y', labelsize=6, colors='red')
    ax2.spines['bottom'].set_linewidth(0.2)
    ax2.spines['left'].set_linewidth(0.2)
    sns.despine(ax=ax2, top=True)  # Remove top and right spines

    # Add the summary counts as a text box in the histogram panel
    ax1.text(0.98, 0.75,
             f"N proteins = {num_pairwise_comparisons}\nTotal Clusters = {num_clusters}\nTotal clusters n > 1 = {num_clusters_more_than_one}",
             ha='right', va='top', transform=ax1.transAxes, fontsize=5)

    # First row, right: Scatter plot of sequence identity vs. alignment length
    sns.scatterplot(x="alnlen", y="identity", data=df_filtered, ax=axes[0, 1], color='dodgerblue',
                    edgecolor='black', s=10, alpha=0.6, linewidth=0.2)
    axes[0, 1].set_xlabel("Alignment Length", size=6)
    axes[0, 1].set_ylabel("Sequence Identity", size=6)
    axes[0, 1].tick_params(axis='both', labelsize=6)
    axes[0, 1].set_ylim(0.2, 1)  # Adjust y-axis limits
    # Remove top and right spines
    sns.despine(ax=axes[0, 1], top=True, right=True)
    axes[0, 1].spines['left'].set_linewidth(0.2)
    axes[0, 1].spines['bottom'].set_linewidth(0.2)

    # Replace the two bottom cells by a single axes spanning the whole row.
    # Building it from the figure's own GridSpec (instead of add_subplot(212),
    # which lays out an independent 2x1 grid) keeps the bottom panel aligned
    # with the top row and makes it honour the height_ratios set above.
    grid = axes[1, 0].get_gridspec()
    for ax in axes[1, :]:
        ax.remove()
    ax_full = fig.add_subplot(grid[1, :])

    # Box plot of sequence identities per query (sorted by median)
    sorted_df = df_filtered.groupby('query')['identity'].median().sort_values(ascending=False).index

    # Count bars live on a twin axes so they get their own y scale
    ax_count = ax_full.twinx()
    # A twin axes is rendered after (on top of) its parent, and artist zorder
    # only orders artists within one axes. To really get the bars *behind* the
    # boxes, lift the parent axes above the twin and hide its background patch
    # so the bars remain visible through it.
    ax_full.set_zorder(ax_count.get_zorder() + 1)
    ax_full.patch.set_visible(False)

    sns.countplot(x="query", data=df_filtered, order=sorted_df, ax=ax_count, color='lightgray', alpha=0.5)
    ax_count.set_ylabel("Count", size=6)
    ax_count.tick_params(axis='y', labelsize=6)
    sns.despine(ax=ax_count, top=True, right=True)  # Remove top and right spines
    ax_count.spines['left'].set_linewidth(0.2)
    ax_count.spines['bottom'].set_linewidth(0.2)

    # Annotate the counts above the bars
    for p in ax_count.patches:
        height = p.get_height()
        ax_count.annotate(f'{height:.0f}', (p.get_x() + p.get_width() / 2., height),
                          ha='center', va='bottom', fontsize=5, color='black')

    # Box plot on the parent axes, which now sits above the count bars
    sns.boxplot(x="query", y="identity", data=df_filtered, ax=ax_full, order=sorted_df, color='dodgerblue', fliersize=3, linewidth=0.75,
                boxprops=dict(edgecolor='black', linewidth=0.2), whiskerprops=dict(color='black', linewidth=0.2), capprops=dict(color='black', linewidth=0.2),
                medianprops=dict(color='black', linewidth=0.2), flierprops=dict(marker='o',
                                                                                markerfacecolor='dodgerblue',
                                                                                markeredgecolor='black',
                                                                                markersize=3, markeredgewidth=0.2), zorder=10)

    ax_full.set_xlabel("")
    ax_full.set_ylabel("Sequence Identity", size=6)
    ax_full.tick_params(axis='x', rotation=90, labelsize=6)
    ax_full.tick_params(axis='y', labelsize=6)
    ax_full.set_ylim(0.2, 1)  # Adjust y-axis limits
    # Remove top and right spines
    sns.despine(ax=ax_full, top=True, right=True)
    ax_full.spines['left'].set_linewidth(0.2)
    ax_full.spines['bottom'].set_linewidth(0.2)

    plt.tight_layout()

    # Save the figure if output_path is provided, otherwise display it
    if output_path:
        fig.savefig(output_path, format='pdf')
        print(f"Plot saved to {output_path}")
        # pyplot keeps every figure alive until it is closed; release it so
        # calling this function in a loop does not accumulate open figures.
        plt.close(fig)
    else:
        plt.show()
        
        
        
        
def build_cluster_dataframe(fs):
    """
    Builds a dataframe containing 'target', 'cluster', 'identity', and 'PF' columns
    from a list of alignment files.

    Each file is read as a header-less TSV. Its 'PF' value is the name of the
    directory that directly contains the file. Within one file, queries are
    numbered by how many rows they have: cluster 0 is the query with the most
    rows, cluster 1 the next one, and so on.

    Parameters:
    - fs: Iterable of file paths (str or os.PathLike) to the alignment files.

    Returns:
    - final_df: A pandas DataFrame with the combined data. If `fs` is empty,
      an empty DataFrame with the four expected columns is returned.
    """
    columns = ['query', 'target', 'identity', 'alnlen', 'mismatch', 'gapopen',
               'qstart', 'qend', 'tstart', 'tend', 'evalue', 'bits']
    output_columns = ['target', 'cluster', 'identity', 'PF']

    df_list = []
    for f in fs:
        # PF = name of the parent directory. os.path copes with whichever
        # separator the platform (or glob) produced, unlike splitting on '/'.
        pf = os.path.basename(os.path.dirname(f))

        df = pd.read_csv(f, sep='\t', names=columns)
        df['PF'] = pf

        # value_counts() already orders queries from most to fewest rows, so
        # the position in its index is the cluster number within this PF.
        query_counts = df['query'].value_counts()
        cluster_mapping = {query: idx for idx, query in enumerate(query_counts.index)}
        df['cluster'] = df['query'].map(cluster_mapping)

        df_list.append(df[output_columns].copy())

    # pd.concat raises on an empty list; hand back an empty, correctly shaped
    # frame instead so callers can treat "no files" like any other input.
    if not df_list:
        return pd.DataFrame(columns=output_columns)

    final_df = pd.concat(df_list, ignore_index=True)
    return final_df



In [None]:
# Collect every per-family pairwise alignment table. glob returns matches in
# arbitrary order, so sort them: `fs` is reused further down, and a stable
# order keeps the rows of everything built from it stable across runs.
fs = sorted(glob.glob("../scripts/mmseqs_clustering/mmseqs_outputs_dynamic/*/*_pairwise.tsv"))

for f in fs:
    # PF = name of the directory holding the file (separator-agnostic,
    # unlike splitting the path on '/').
    pf = os.path.basename(os.path.dirname(f))

    plot_pairwise_data(f, output_path=f'../scripts/mmseqs_clustering/mmseqs_outputs_dynamic/{pf}_mmseqs.pdf')

In [None]:
# Column names shared by every header-less alignment TSV
columns = ["query", "target", "identity", "alnlen", "mismatch", "gapopen",
           "qstart", "qend", "tstart", "tend", "evalue", "bits"]

# One dataframe per file, each tagged with the name of its parent directory
l = [
    pd.read_csv(file_path, sep="\t", names=columns).assign(PF=file_path.split('/')[-2])
    for file_path in fs
]

# Stack all files into a single table with a fresh 0..n-1 index
df = pd.concat(l).reset_index(drop=True)

In [None]:
# Step 1: Number the clusters inside each PF (family) by size
df['cluster'] = -1  # placeholder until each family has been processed

for pf, members in df.groupby('PF'):
    # value_counts() lists queries from most to fewest rows, so a query's
    # position in that listing is its cluster number (0 = largest cluster)
    ranked_queries = members['query'].value_counts().index
    cluster_ranking = {query: rank for rank, query in enumerate(ranked_queries)}

    # Write the numbers back onto the rows of this family only
    in_family = df['PF'] == pf
    df.loc[in_family, 'cluster'] = df.loc[in_family, 'query'].map(cluster_ranking)

In [None]:
# Step 2: Prepare the fold bookkeeping used when splitting each family
num_folds = 5  # fold labels run from 0 to 4
fold_columns = [
    'folds_r0',
    'folds_r1',
    'folds_r2',
]

# Every fold column starts at -1, i.e. "no fold assigned yet"
for fold_col in fold_columns:
    df[fold_col] = -1

In [None]:
# Function to assign clusters to folds and balance member counts
def assign_folds(df, fold_col, n_folds=None, rng=None):
    """
    Assign every cluster of every PF to a fold and store it in df[fold_col].

    Within each PF the clusters are visited in random order; each cluster is
    put into the fold that currently holds the fewest rows (ties go to the
    lowest fold number). All rows of a cluster therefore share one fold and
    the folds of a PF stay roughly equal in size.

    Parameters:
    - df: DataFrame with 'PF' and 'cluster' columns. Modified in place.
    - fold_col: name of the column that receives the fold numbers.
    - n_folds: number of folds, labelled 0..n_folds-1. When omitted, the
      module-level `num_folds` is used.
    - rng: optional numpy.random.Generator (any object with an in-place
      `shuffle` method works) used to shuffle the clusters. When omitted, the
      global NumPy random state is used via np.random.shuffle.

    Returns:
    - None

    Raises:
    - ValueError: if the number of folds is smaller than 1.
    """
    if n_folds is None:
        n_folds = num_folds
    if n_folds < 1:
        raise ValueError(f"n_folds must be at least 1, got {n_folds}")

    shuffle = np.random.shuffle if rng is None else rng.shuffle

    for pf, group in df.groupby('PF'):
        # Shuffle clusters to introduce randomness
        clusters = group['cluster'].unique()
        shuffle(clusters)

        # Row count of every cluster in this PF, computed once instead of
        # re-filtering the group for each cluster
        cluster_sizes = group['cluster'].value_counts().to_dict()

        # Greedy balancing: always fill the currently smallest fold
        member_counts = {fold: 0 for fold in range(n_folds)}
        cluster_fold_map = {}
        for cluster in clusters:
            best_fold = min(member_counts, key=member_counts.get)
            cluster_fold_map[cluster] = best_fold
            member_counts[best_fold] += cluster_sizes.get(cluster, 0)

        # Assign fold numbers to each row of this PF based on its cluster
        in_family = df['PF'] == pf
        df.loc[in_family, fold_col] = df.loc[in_family, 'cluster'].map(cluster_fold_map)


In [None]:
# Assign folds for each of the three fold columns.
# assign_folds shuffles the clusters with NumPy's global random state, so fix
# the seed first; otherwise every run of the notebook yields a different split.
# The state advances between calls, so the three columns still differ.
FOLD_SEED = 0
np.random.seed(FOLD_SEED)

for fold_col in fold_columns:
    assign_folds(df, fold_col)

In [None]:
# Write the fold assignment to disk. Exact duplicate rows are dropped and the
# index renumbered for the export only; `df` itself is left unchanged.
folds_json_path = '../datasets/ML/MMSEQS/v13/mmseqs_outputs_dynamic/241120_mmseqs_folds_split.json'

# to_json does not create missing directories, so make sure the target exists
os.makedirs(os.path.dirname(folds_json_path), exist_ok=True)

df.drop_duplicates().reset_index(drop=True).to_json(folds_json_path)

In [None]:


# Reload the de-duplicated fold table from disk, replacing the in-memory `df`
folds_json_path = '../datasets/ML/MMSEQS/v13/mmseqs_outputs_dynamic/241120_mmseqs_folds_split.json'
df = pd.read_json(folds_json_path)

In [None]:
# Table that is joined onto the fold assignment (matched on its 'name' column)
cooperativity_json_path = "df_cooperativity.json"
df_hx = pd.read_json(cooperativity_json_path)

In [None]:
output_path = '../results/mmseqs/241120_mmseqs_folds_split_FULL.json'

# Create the destination directory if it is not there yet
os.makedirs(os.path.dirname(output_path), exist_ok=True)

# Columns taken from each side of the join, and the layout of the export
left_cols = ['target', 'PF', 'cluster', 'folds_r0', 'folds_r1', 'folds_r2']
right_cols = ['name', 'sequence', 'dg_mean', 'cooperativity_model_global', 'cooperativity_model_pf',
              'normalized_cooperativity_model_global', 'normalized_cooperativity_model_pf']
export_cols = ['name', 'sequence', 'PF', 'cluster', 'folds_r0', 'folds_r1', 'folds_r2', 'dg_mean',
               'cooperativity_model_global', 'cooperativity_model_pf',
               'normalized_cooperativity_model_global', 'normalized_cooperativity_model_pf']

# Left join: every fold row is kept; targets without a match in df_hx end up
# with missing values in all df_hx columns, including 'name'.
merged = df[left_cols].merge(df_hx[right_cols], left_on='target', right_on='name', how='left')

merged[export_cols].to_json(output_path)