# In [None]:
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.legend_handler import HandlerTuple

def plot_tm_vs_identity(tsv_path, mmseq_path, protein_name="SQR", output_path="tm_vs_identity.png",
                        *, tm_threshold=0.9, positive_prefix="WP_"):
    """Plot FoldSeek TM-score against sequence identity, coloured by MMseqs2 detection.

    Every FoldSeek hit becomes one point (x = ``fident``, y = ``alntmscore``).
    Queries whose ID starts with ``positive_prefix`` form the ``<protein_name>+``
    group. They are drawn green if the same query ID also appears in the MMseqs2
    result, and blue otherwise. All remaining queries form the ``<protein_name>–``
    group and are drawn purple whatever their MMseqs2 status. A per-group count
    table is printed, and the figure is saved to ``output_path`` and then shown.

    Args:
        tsv_path: Headerless, tab-separated FoldSeek result whose 15 columns are,
            in order, the names listed in ``colnames`` below.
        mmseq_path: Headerless, tab-separated MMseqs2 result. Only the first
            column (query ID) is read. An empty file means no MMseqs2 hits.
        protein_name: Family name used in the group labels, legend and title.
        output_path: File the figure is written to (300 dpi).
        tm_threshold: y position of the dashed red reference line.
        positive_prefix: Query-ID prefix that marks a ``<protein_name>+`` query.

    Raises:
        ValueError: If the ``fident`` or ``alntmscore`` column holds non-numeric text.
    """
    # Column layout of the FoldSeek --format-output used to produce tsv_path.
    colnames = [
        "query", "target", "fident", "alnlen", "mismatch", "gapopen",
        "qstart", "qend", "tstart", "tend", "evalue", "bits",
        "alntmscore", "qtmscore", "ttmscore",
    ]
    # Read IDs as strings so they always support .str.startswith and compare
    # equal to the MMseqs2 IDs below, even when an ID happens to look numeric.
    df = pd.read_csv(tsv_path, sep="\t", names=colnames, dtype={"query": str, "target": str})

    df["fident"] = df["fident"].astype(float)
    df["alntmscore"] = df["alntmscore"].astype(float)

    # Label <protein_name>+ vs <protein_name>– (en dash, as in the legend).
    positive_label = f"{protein_name}+"
    negative_label = f"{protein_name}–"
    is_positive = df["query"].str.startswith(positive_prefix, na=False)
    df["group"] = is_positive.map({True: positive_label, False: negative_label})

    # Only the query IDs of the MMseqs2 result matter. usecols=[0] keeps this
    # independent of how many columns the file has.
    try:
        mmseq_queries = pd.read_csv(mmseq_path, sep="\t", header=None, usecols=[0], dtype=str)[0]
    except pd.errors.EmptyDataError:
        # An empty result file means no query was detected by MMseqs2.
        mmseq_queries = pd.Series(dtype=str)
    mmseq_hits = set(mmseq_queries.dropna())

    both_label = "Detected in MMseqs2 & FoldSeek"
    foldseek_only_label = "Detected in FoldSeek only"
    in_mmseq = df["query"].isin(mmseq_hits)
    df["overlap"] = in_mmseq.map({True: both_label, False: foldseek_only_label})

    summary = df.groupby(["group", "overlap"]).size().unstack(fill_value=0)
    print("\nSummary Table:\n", summary)

    fig, ax = plt.subplots(figsize=(7, 7))

    # Draw order (green, blue, purple) decides which points end up on top.
    # Negatives are plotted regardless of MMseqs2 status so that every row
    # counted in the summary table also appears on the plot.
    layers = (
        (is_positive & in_mmseq, "green"),
        (is_positive & ~in_mmseq, "blue"),
        (~is_positive, "purple"),
    )
    for mask, color in layers:
        subset = df[mask]
        ax.scatter(subset["fident"], subset["alntmscore"], alpha=0.7,
                   edgecolors="k", color=color)

    ax.axhline(y=tm_threshold, color="red", linestyle="--", linewidth=1)

    def _marker(color):
        """Return a legend proxy artist: a round marker filled with ``color``."""
        return Line2D([0], [0], marker="o", linestyle="", color="w",
                      markerfacecolor=color, markersize=8)

    # The explicit handles/labels replace any scatter labels. The first entry is
    # a (green, blue) pair that HandlerTuple renders side by side.
    ax.legend(
        handles=[(_marker("green"), _marker("blue")), _marker("blue"), _marker("purple")],
        labels=[
            f"Detected using expanded database ({positive_label})",
            f"Detected using SwissProt {protein_name} database ({positive_label})",
            negative_label,
        ],
        handler_map={tuple: HandlerTuple(ndivide=None)},
        loc="lower left",
    )

    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    # NOTE(review): the x values are FoldSeek's `fident`; confirm whether the
    # alignment-level identity should really be called "Global".
    ax.set_xlabel("Global Sequence Identity", fontsize=12)
    ax.set_ylabel("TM-align Score", fontsize=12)
    ax.set_title(f"Structural vs. Sequence Similarity of {protein_name} Proteins", fontsize=14)
    ax.grid(True)
    fig.tight_layout()
    fig.savefig(output_path, dpi=300)
    print(f"Saved plot to {output_path}")
    plt.show()

# Run only when executed as a script or notebook cell (a Jupyter kernel also
# sets __name__ to "__main__"), not when this module is imported.
if __name__ == "__main__":
    plot_tm_vs_identity("foldseek_mmseq_combined_result.tsv", "mmseq_only_result.m8")