In [None]:
import pandas as pd
from collections import defaultdict

# Excel workbook to analyse
FILE = '20250507_XLMS_FAIMS_ingel_insolution.xlsx'
df = pd.read_excel(FILE)

# Trim stray whitespace around the header names
df.columns = df.columns.str.strip()

# Keys identifying one crosslink (both ends) within one spectrum file
group_cols = ['Spectrum File']
group_cols += ['Protein Accession A', 'Leading Protein Position A']
group_cols += ['Protein Accession B', 'Leading Protein Position B']

# One group per (file, end A, end B) combination
grouped = df.groupby(group_cols)

# Per group: how many non-null scores were seen, and the best of them
summary = (
    grouped['XlinkX Score']
    .agg(Occurrences='count', Max_XlinkX_Score='max')
    .reset_index()
)

def normalize_key(a, x, b, y):
    """Return an orientation-independent key for a crosslink.

    The two ends ``(a, x)`` and ``(b, y)`` are compared as tuples and the
    smaller one is placed first, so ``A-B`` and ``B-A`` map to the same key.
    When both ends compare equal the original order is kept.
    """
    end_one = (a, x)
    end_two = (b, y)
    if end_one <= end_two:
        return end_one + end_two
    return end_two + end_one

# Maps each spectrum file to its list of merged crosslink records
# (accA, posA, accB, posB, occurrences, max score), sorted by crosslink key.
file_to_result_set = {}

# Columns unpacked, in this order, from every row of `summary`
crosslink_cols = [
    'Protein Accession A', 'Leading Protein Position A',
    'Protein Accession B', 'Leading Protein Position B',
    'Occurrences', 'Max_XlinkX_Score'
]

# Iterate over each file group
for file, group_df in summary.groupby("Spectrum File"):
    # Merge inverted pairs: (A, x, B, y) and (B, y, A, x) normalize to the same
    # key, so their counts are summed and the higher score is kept.
    # NOTE(review): the 0.0 seed acts as a floor for the stored maximum, which
    # assumes XlinkX scores are never negative - confirm.
    merged_dict = defaultdict(lambda: [0, 0.0])
    # `summary` already holds one row per key combination, so rows are read
    # directly instead of being funnelled through an intermediate set.
    for a, x, b, y, occ, score in group_df[crosslink_cols].itertuples(index=False, name=None):
        key = normalize_key(a, x, b, y)
        merged_dict[key][0] += occ
        merged_dict[key][1] = max(merged_dict[key][1], score)

    # Sort on the full normalized key (accA, posA, accB, posB). Sorting on
    # (accA, posA) alone leaves records that share a first site in arbitrary
    # order, which made the result differ from run to run.
    merged_result_sorted = sorted(
        ((a, x, b, y, occ, score) for (a, x, b, y), (occ, score) in merged_dict.items()),
        key=lambda rec: rec[:4],
    )

    file_to_result_set[file] = merged_result_sorted


# Example: dump each file's merged crosslinks, then report how many there are
for file, rset in file_to_result_set.items():
    # `rset` is the value stored under `file`, so no second lookup is needed
    print(file, rset)
    print(f"{file}: {len(rset)} crosslinks")

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

def xl_ms_plot(file, merged_result_sorted):
    """Draw, save and show a symmetric bubble plot of one file's crosslinks.

    Each crosslink is drawn at (site A, site B) and mirrored at
    (site B, site A). Marker size encodes the CSM count and colour encodes
    the top XlinkX score. A dashed diagonal marks identical sites and solid
    gray lines separate the sites of different accessions. The figure is
    written to ``<file>.png``.

    Parameters
    ----------
    file : str
        Used as the plot title and as the stem of the saved PNG file name.
    merged_result_sorted : sequence of 6-tuples
        Records of ``(accA, posA, accB, posB, CSM count, top XlinkX score)``.
        Positions must be convertible with ``int()``.

    Returns
    -------
    None
        If ``merged_result_sorted`` is empty a notice is printed and nothing
        is plotted or saved.
    """
    # Without records there are no ticks to build the axes from.
    if len(merged_result_sorted) == 0:
        print(f"{file}: no crosslinks to plot")
        return

    value_cols = ["a_x", "b_y", "CSM count", "top XlinkX score"]

    # Create DataFrame with one "<accession>_<position>" label per link end
    df = pd.DataFrame(
        merged_result_sorted,
        columns=["accA", "posA", "accB", "posB", "CSM count", "top XlinkX score"],
    )
    df["a_x"] = df["accA"] + "_" + df["posA"].astype(str)
    df["b_y"] = df["accB"] + "_" + df["posB"].astype(str)

    # Long format for symmetric plotting: every link also appears with its
    # ends swapped. Links whose two ends are the same site already lie on the
    # diagonal, so mirroring them would only draw the same marker twice.
    off_diagonal = df[df["a_x"] != df["b_y"]]
    df_long = pd.concat(
        [
            df[value_cols],
            off_diagonal.rename(columns={"a_x": "b_y", "b_y": "a_x"})[value_cols],
        ],
        ignore_index=True,
    )

    # Remember (accession, position) for every label so ordering and grouping
    # never re-parse the label text; splitting on "_" breaks for accessions
    # that themselves contain an underscore.
    site_of = {}
    for acc_col, pos_col, label_col in (("accA", "posA", "a_x"), ("accB", "posB", "b_y")):
        for acc, pos, label in zip(df[acc_col], df[pos_col], df[label_col]):
            site_of[label] = (acc, int(pos))

    # Ticks ordered by accession, then numerically by position
    tick_list_sorted = sorted(site_of, key=site_of.get)
    n_ticks = len(tick_list_sorted)

    # Ordered categoricals fix the axis order; y runs in reverse so the
    # diagonal goes from the top-left to the bottom-right corner.
    df_long["a_x"] = pd.Categorical(df_long["a_x"], categories=tick_list_sorted, ordered=True)
    df_long["b_y"] = pd.Categorical(df_long["b_y"], categories=tick_list_sorted[::-1], ordered=True)

    # Plot
    fig, ax = plt.subplots(figsize=(18, 15))
    sns.scatterplot(
        data=df_long,
        x="a_x",
        y="b_y",
        size="CSM count",
        hue="top XlinkX score",
        palette="Spectral_r",
        sizes=(100, 800),
        edgecolor="black",
        legend="auto",
        ax=ax,
    )
    ax.grid(True, which='major', axis='both', linestyle='--', alpha=0.5)

    # Diagonal of identical sites: tick i on x sits at height n_ticks - 1 - i
    # on the reversed y axis.
    tick_indices = list(range(n_ticks))
    ax.plot(
        tick_indices,
        [n_ticks - 1 - i for i in tick_indices],
        color='black', linewidth=1, linestyle='--',
    )

    # Indices at which the accession changes from one tick to the next
    tick_accessions = [site_of[label][0] for label in tick_list_sorted]
    switch_indices = [
        i for i in range(1, n_ticks)
        if tick_accessions[i] != tick_accessions[i - 1]
    ]
    # Separator lines halfway between the last tick of one accession and the
    # first tick of the next, on both axes
    for idx in switch_indices:
        ax.axvline(x=idx - 0.5, color='gray', linestyle='-', linewidth=1)
        ax.axhline(y=n_ticks - idx - 0.5, color='gray', linestyle='-', linewidth=1)

    # Axis labeling and ticks
    ax.set_title(file, fontsize=24, pad=15)
    ax.set_xlabel('Crosslink1 (Accession_Position)', fontsize=20)
    ax.set_ylabel('Crosslink2 (Accession_Position)', fontsize=20)
    ax.tick_params(axis='x', labelrotation=90, labelsize=18)
    ax.tick_params(axis='y', labelsize=18)

    # Legend placed outside the axes, to the right
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', labelspacing=1.5, prop={'size': 15})
    fig.tight_layout()
    fig.savefig(f"{file}.png")
    plt.show()
    # Release the figure so repeated calls in a loop do not accumulate figures
    plt.close(fig)

# Call xl_ms_plot once for every (file, crosslink list) entry of file_to_result_set
for file, rset in file_to_result_set.items():
    xl_ms_plot(file, rset)