In [None]:
from matplotlib import pyplot as plt
import numpy as np
from tqdm import tqdm
import pickle
import pandas as pd
import os
import seaborn as sns

# Global seaborn theme. Other available contexts: "paper", "talk" and "poster".
sns.set_theme(context='notebook', style='ticks', palette='bright', color_codes=True)

# Plotting settings: font sizes shared by every figure of the notebook
SMALL_SIZE = 15
MEDIUM_SIZE = 20
BIGGER_SIZE = 30

plt.rcParams.update({
    "text.usetex": False,
    "font.family": "sans-serif",
    # The active family is "sans-serif", so the preferred typeface has to be listed
    # under "font.sans-serif"; "font.serif" is not consulted for this family.
    # DejaVu Sans is kept as a fallback for machines without Arial.
    "font.sans-serif": ["Arial", "DejaVu Sans"],
    "font.size": MEDIUM_SIZE,
    "axes.titlesize": MEDIUM_SIZE,
    "axes.labelsize": MEDIUM_SIZE,
    "figure.labelsize": MEDIUM_SIZE,
    "figure.titlesize": MEDIUM_SIZE,
    "xtick.labelsize": SMALL_SIZE,
    "ytick.labelsize": SMALL_SIZE,
    "legend.fontsize": MEDIUM_SIZE,
})

# Palette written as 0-255 RGB triplets and rescaled to 0-1 floats for matplotlib
color_ = [(64, 83, 211), (0, 178, 93), (181, 29, 20), (221, 179, 16), (0, 190, 255), (251, 73, 176), (202, 202, 202)]
color = [tuple(channel / 255 for channel in rgb) for rgb in color_]

In [None]:
# Load the dataframe of generated sequences from its pickle file
name_data = "check-131k_gen_seqs_full" # check-131k
# Interactive switch: only the exact answer "y" selects FIM mode; anything else
# takes the branch below that loads the per-family baselines.
fim_generation = input("Are sequences generated using FIM? (y/n)") == "y"

df = pd.read_pickle(f"figures/generated_sequences/dataframe_{name_data}.pkl")
families = df["family_id"].unique()

# NOTE(review): unpickling can execute arbitrary code; only load files produced by this project.
with open("figures/generated_sequences/all_structures_representatives.pkl", "rb") as f:
    structures_representatives = pickle.load(f)

if not fim_generation:
    # import baselines
    with open(f"figures/generated_sequences/all_hamming_ctx_{name_data}.pkl", "rb") as f:
        all_hamming_ctx = pickle.load(f)
    with open(f"figures/generated_sequences/all_hmmer_ctx_{name_data}.pkl", "rb") as f:
        all_hmmer_ctx = pickle.load(f)
    with open(f"figures/generated_sequences/all_structures_ctx_{name_data}.pkl", "rb") as f:
        all_structures_ctx = pickle.load(f)

    # The baseline dictionaries only exist in this branch, so the consistency checks
    # live here too (at top level they raised NameError whenever FIM mode was chosen).
    assert all_structures_ctx.keys() == all_hamming_ctx.keys() == all_hmmer_ctx.keys(), \
        "baseline dictionaries do not cover the same families"
    assert list(all_structures_ctx.keys()) == list(families), \
        "baseline families differ (in content or order) from the dataframe families"

df.head()

In [None]:
from Bio import SeqIO
def read_msa_unaligned(filename: str):
    """Parse a FASTA-formatted MSA file and return its sequences without alignment symbols.

    Every record becomes a ``(description, sequence)`` tuple in which all ``.``, ``-``
    and ``*`` characters have been removed from the sequence and the rest upper-cased.
    """
    alignment_symbols = str.maketrans("", "", ".-*")
    unaligned = []
    for record in SeqIO.parse(filename, "fasta"):
        residues = str(record.seq).translate(alignment_symbols).upper()
        unaligned.append((record.description, residues))
    return unaligned
# Lengths of the unaligned natural sequences of every family, keyed by family id
lengths_natural = {}
for family_id in families:
    natural_records = read_msa_unaligned(f"figures/pdb_structures/msas/{family_id}.a3m")
    lengths_natural[family_id] = [len(record[1]) for record in natural_records]

## Plotting functions

In [None]:
def plot_hamming_hmmer(fig,
                       axs,
                       family_id,
                       df_family,
                       lengths_nat,
                       hamming_ctx,
                       hmmer_ctx,
                       structures_ctx,
                       last_row=False):
    """Draw one six-panel row comparing generated and natural sequences of a family.

    Panels 0-2 are density histograms (all Hamming distances, HMMER scores, mean
    pLDDT); panels 3-5 scatter perplexity against min Hamming distance, HMMER score
    and mean pLDDT, coloured by generated-sequence length. Dashed vertical lines
    mark medians. Colours come from the module-level ``color`` palette.

    Parameters
    ----------
    fig : figure owning ``axs``; receives the shared legend (added once) and one
        colorbar per call.
    axs : indexable holding at least six axes.
    family_id : str, prefixed to the y-label of the first panel.
    df_family : DataFrame with columns "hamming" (one iterable of distances per
        row), "generated_sequence", "score_gen", "perplexity", "mean_plddt_gen".
    lengths_nat : unused; kept so existing callers keep working.
    hamming_ctx : iterable of iterables of Hamming distances (natural sequences).
    hmmer_ctx : indexable with a "score" entry (natural sequences).
    structures_ctx : mapping whose values expose "mean_plddt".
    last_row : when True, x-axis labels are written.

    Returns
    -------
    tuple
        ``(fig, axs)`` as passed in.
    """
    lst_all_ham = [el for ham in df_family["hamming"] for el in ham]
    lst_min_ham = [min(ham) for ham in df_family["hamming"]]
    lst_all_ham_ctx = [el for ham in hamming_ctx for el in ham]
    lst_min_ham_ctx = [min(ham) for ham in hamming_ctx]
    lst_seq_len = [len(seq) for seq in df_family["generated_sequence"]]
    plddts_nat = np.array([structures_ctx[k]["mean_plddt"] for k in structures_ctx.keys()])
    perplexity_all = df_family["perplexity"].to_list()

    generated_style = dict(alpha=0.7, density=True, color=color[0], label="Generated")
    natural_style = dict(alpha=0.5, density=True, color=color[1], label="Natural")
    median_style = dict(linestyle="--", linewidth=1.5)

    axs[0].set_ylabel(family_id+"\nDensity")

    # Hamming distances
    axs[0].hist(lst_all_ham, bins=40, **generated_style)
    axs[0].hist(lst_all_ham_ctx, bins=40, **natural_style)
    if last_row:
        axs[0].set_xlabel("Hamming distance")

    # HMMER scores
    axs[1].hist(df_family["score_gen"], bins=40, **generated_style)
    axs[1].hist(hmmer_ctx["score"], bins=40, **natural_style)
    if last_row:
        axs[1].set_xlabel("HMMER score")

    # pLDDT: only rows with a positive mean pLDDT are used.
    # NOTE(review): non-positive values presumably flag missing structure predictions - confirm.
    perplexities = df_family["perplexity"].to_numpy()
    plddts = df_family["mean_plddt_gen"].to_numpy()
    valid = plddts > 0
    plddts, perplexities, lens = plddts[valid], perplexities[valid], np.array(lst_seq_len)[valid]
    mins, maxs = min(plddts.min(), min(plddts_nat)), max(plddts.max(), max(plddts_nat))
    bins = np.linspace(mins, maxs, 40)
    # "top 10%" = the tenth of the valid sequences with the lowest perplexity
    ind = np.argsort(perplexities)[:len(plddts)//10]
    plddts_new = plddts[ind]
    # The filtered values are used for the histogram and (below) the median so that
    # they describe the same points as the top-10% subset and the scatter panel.
    axs[2].hist(plddts, bins=bins, **generated_style)
    axs[2].hist(plddts_new, bins=bins, alpha=0.5, density=True, color=color[2], label="Generated (top 10%)")
    axs[2].hist(plddts_nat, bins=bins, **natural_style)
    if last_row:
        axs[2].set_xlabel("Mean pLDDT")

    # Perplexity vs min Hamming
    axs[3].scatter(lst_min_ham, perplexity_all, c=lst_seq_len, cmap="viridis", alpha=0.8)
    axs[3].axvline(np.median(lst_min_ham_ctx), color=color[1], **median_style)
    axs[3].axvline(np.median(lst_min_ham), color=color[0], **median_style)
    if last_row:
        axs[3].set_xlabel("Min Hamming")
    axs[3].set_ylabel("Perplexity")

    # Perplexity vs HMMER score
    axs[4].scatter(df_family["score_gen"].to_list(), perplexity_all, c=lst_seq_len, cmap="viridis", alpha=0.8)
    axs[4].axvline(np.median(hmmer_ctx["score"]), color=color[1], **median_style)
    axs[4].axvline(np.median(df_family["score_gen"]), color=color[0], **median_style)
    if last_row:
        axs[4].set_xlabel("HMMER score")

    # Perplexity vs pLDDT
    im = axs[5].scatter(plddts, perplexities, c=lens, cmap="viridis", alpha=0.8)
    axs[5].axvline(np.median(plddts_nat), color=color[1], **median_style)
    axs[5].axvline(np.median(plddts), color=color[0], **median_style)
    axs[5].axvline(np.median(plddts_new), color=color[2], **median_style)
    if last_row:
        axs[5].set_xlabel("Mean pLDDT")

    # The function is called once per row of the same figure: add the shared legend
    # only the first time instead of stacking an identical legend for every family.
    if not fig.legends:
        handles, labels = axs[2].get_legend_handles_labels()
        fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 1.04), ncol=3, fancybox=False)
    fig.colorbar(im, ax=axs[5], label="Sequence length")
    return fig, axs

def compare_generation_parameters(fig,
                                  axs,
                                  family_id,
                                  df_family,
                                  lengths_nat,
                                  hamming_ctx,
                                  hmmer_ctx):
    """Scatter perplexity against length, min Hamming and HMMER score per sampling setting.

    One row of three panels is drawn for each distinct ``n_seqs_ctx`` value found
    in ``df_family``; inside a row, every distinct (temperature, top_k, top_p)
    triple gets its own colour. Dashed black lines mark the medians of the
    natural sequences.

    Parameters
    ----------
    fig : figure owning ``axs``; receives the shared legend and the y super-label.
    axs : grid of axes with three columns and one row per context size. A 1-D
        array (what ``plt.subplots`` returns for a single row) is accepted too.
    family_id : unused; kept so existing callers keep working.
    df_family : DataFrame with columns "temperature", "top_k", "top_p",
        "n_seqs_ctx", "hamming", "generated_sequence", "perplexity", "score_gen".
        It is not modified.
    lengths_nat : lengths of the natural sequences.
    hamming_ctx : iterable of iterables of Hamming distances (natural sequences).
    hmmer_ctx : indexable with a "score" entry (natural sequences).

    Returns
    -------
    tuple
        ``(fig, axs)`` as passed in.
    """
    # assign() builds a new frame, so the caller's (usually sliced) dataframe is
    # left untouched and pandas does not emit SettingWithCopyWarning.
    df_family = df_family.assign(
        generation_parameters=list(zip(df_family["temperature"].to_list(),
                                       df_family["top_k"].to_list(),
                                       df_family["top_p"].to_list())))
    gen_par = df_family["generation_parameters"].unique()
    # Fall back to a viridis palette when the fixed palette has too few colours
    color_p = color if len(color)>=len(gen_par) else sns.color_palette("viridis", len(gen_par))
    ctx_len = df_family["n_seqs_ctx"].unique()
    lst_min_ham_ctx = [min(ham) for ham in hamming_ctx]
    # A single-row subplot grid arrives as a 1-D array; view it as (1, n_columns)
    grid = np.atleast_2d(axs)
    for j, ctx_l in enumerate(ctx_len):
        df_ctx = df_family[df_family["n_seqs_ctx"] == ctx_l]
        ax = grid[j]
        for i, gen_p in enumerate(gen_par):
            df_gen_p = df_ctx[df_ctx["generation_parameters"] == gen_p]
            perplexity = df_gen_p["perplexity"].to_list()
            lst_min_ham = [min(ham) for ham in df_gen_p["hamming"]]
            lst_seq_len = [len(seq) for seq in df_gen_p["generated_sequence"]]
            # Perplexity vs Sequence lengths
            ax[0].scatter(lst_seq_len, perplexity, color=color_p[i], label=gen_p)
            # Perplexity vs min Hamming
            ax[1].scatter(lst_min_ham, perplexity, color=color_p[i], label=gen_p)
            # Perplexity vs HMMER score
            ax[2].scatter(df_gen_p["score_gen"], perplexity, color=color_p[i], label=gen_p)

        ax[0].axvline(np.median(lengths_nat), color="k", linestyle="--", linewidth=1.5, label="Natural")
        ax[0].set_ylabel(f"Seqs in ctx: {ctx_l}")
        ax[1].axvline(np.median(lst_min_ham_ctx), color="k", linestyle="--", linewidth=1.5)
        ax[2].axvline(np.median(hmmer_ctx["score"]), color="k", linestyle="--", linewidth=1.5)
    grid[0, 0].set_title("Perplexity vs Sequence length")
    grid[0, 1].set_title("Perplexity vs Hamming")
    grid[0, 2].set_title("Perplexity vs HMMER score")
    grid[-1, 0].set_xlabel("Sequence length")
    grid[-1, 1].set_xlabel("Hamming distance to\nclosest natural sequence")
    grid[-1, 2].set_xlabel("HMMER score")

    # Unique legend for the entire figure on the top of all subplots
    handles, labels = grid[0, 0].get_legend_handles_labels()
    fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 1.1), ncol=5, fancybox=False)
    fig.supylabel("Perplexity")
    return fig, axs

def plot_comparison(fig,
                       axs,
                       family_id,
                       df_family,
                       lengths_nat,
                       hamming_ctx,
                       hmmer_ctx,
                       structures_ctx):
    """Add one family to the five generated-vs-natural summary panels.

    Each panel receives a single point: x is the median over generated sequences,
    y the median over natural sequences, with standard deviations as error bars.
    Panels are sequence length, min Hamming distance, min-max rescaled HMMER
    score, mean pLDDT and pTM. A LaTeX table row with the four correlations is
    printed as a side effect. Colours come from the module-level ``color`` palette.

    Parameters
    ----------
    fig : figure owning ``axs``; receives the x/y super-labels.
    axs : indexable holding at least five axes.
    family_id : str, used as point label and in the printed row.
    df_family : DataFrame with columns "generated_sequence", "perplexity",
        "score_gen", "min_hamming", "mean_plddt_gen", "ptm_gen".
    lengths_nat : lengths of the natural sequences.
    hamming_ctx : iterable of iterables of Hamming distances (natural sequences).
    hmmer_ctx : indexable with a "score" entry (natural sequences).
    structures_ctx : mapping whose values expose "mean_plddt" and "ptm".

    Returns
    -------
    tuple
        ``(fig, axs, corrs)`` where ``corrs`` lists the Pearson correlations of
        perplexity with min Hamming, HMMER score, mean pLDDT and pTM, in that order.
    """
    def add_median_point(ax, generated, natural, colour, title):
        """Plot (median generated, median natural) with std-dev error bars on ``ax``."""
        ax.errorbar(np.median(generated), np.median(natural),
                    xerr=np.std(generated), yerr=np.std(natural),
                    fmt='o', color=colour, label=family_id,
                    capsize=2, elinewidth=0.8, markersize=4)
        ax.set_title(title)

    lst_min_ham_ctx = [min(ham) for ham in hamming_ctx]
    lst_seq_len = [len(seq) for seq in df_family["generated_sequence"]]

    perplexities = df_family["perplexity"].to_numpy()
    scores_gen = df_family["score_gen"].to_numpy()
    min_ham_gen = df_family["min_hamming"].to_numpy()
    plddts_gen = df_family["mean_plddt_gen"].to_numpy()
    ptm_gen = df_family["ptm_gen"].to_numpy()
    plddts_nat = np.array([structures_ctx[k]["mean_plddt"] for k in structures_ctx.keys()])
    ptm_nat = np.array([structures_ctx[k]["ptm"] for k in structures_ctx.keys()])

    # Sequence-level correlations use every generated sequence
    corr2 = np.corrcoef(perplexities, scores_gen)[0,1]
    corr3 = np.corrcoef(perplexities, min_ham_gen)[0,1]

    # Structure-level statistics only use rows with a positive mean pLDDT.
    # NOTE(review): non-positive values presumably flag missing structure predictions - confirm.
    valid = plddts_gen > 0
    plddts_gen, ptm_gen, perplexities = plddts_gen[valid], ptm_gen[valid], perplexities[valid]
    # The ranking below indexes the *filtered* perplexities, so every array it is
    # applied to must be filtered with the same mask; indexing the unfiltered
    # score / Hamming arrays would pick the wrong rows as soon as one row is dropped.
    scores_gen, min_ham_gen = scores_gen[valid], min_ham_gen[valid]
    corr0 = np.corrcoef(perplexities, plddts_gen)[0,1]
    corr1 = np.corrcoef(perplexities, ptm_gen)[0,1]

    by_perplexity = np.argsort(perplexities)
    ind = by_perplexity[:100]                        # 100 lowest-perplexity sequences
    ind1 = by_perplexity[:len(hmmer_ctx["score"])]   # as many as there are natural scores
    plddts_gen, ptm_gen = plddts_gen[ind], ptm_gen[ind]
    hmmer_gen, ham_gen = scores_gen[ind1], min_ham_gen[ind]
    # NOTE(review): lengths are taken over all generated sequences, not the
    # low-perplexity subset used for the other panels - confirm this is intended.
    lens_gen = np.array(lst_seq_len)

    print(family_id, f" & ${round(corr3,2)}$ & ", f"${round(corr2,2)}$ & ", f"${round(corr0,2)}$ & ", f"${round(corr1,2)}$ \\\\")
    corrs = [corr3, corr2, corr0, corr1]

    # Sequence lengths
    add_median_point(axs[0], lens_gen, lengths_nat, color[0], "Sequence length")

    # Min Hamming distances
    add_median_point(axs[1], ham_gen, lst_min_ham_ctx, color[1], "Min Hamming")

    # HMMER scores, min-max rescaled with the joint range of generated and natural scores
    mmin, mmax = min(min(hmmer_gen), min(hmmer_ctx["score"])), max(max(hmmer_gen), max(hmmer_ctx["score"]))
    rescaled_hmmer_gen = (hmmer_gen - mmin) / (mmax - mmin)
    rescaled_hmmer_ctx = (np.array(hmmer_ctx["score"]) - mmin) / (mmax - mmin)
    add_median_point(axs[2], rescaled_hmmer_gen, rescaled_hmmer_ctx, color[2], "HMMER score")

    # pLDDTs
    add_median_point(axs[3], plddts_gen, plddts_nat, color[3], "pLDDT")

    # pTMs
    add_median_point(axs[4], ptm_gen, ptm_nat, color[4], "pTM")

    fig.supxlabel("Generated")
    fig.supylabel("Natural")
    return fig, axs, corrs

In [None]:
def plot_scores_fim(fig,
                    axs,
                    family_id,
                    df_family):
    """Draw one five-panel row of FIM scatter plots, all coloured by perplexity.

    Panel 0 compares the masked pLDDT of the generated sequence with the original
    one (with a dashed identity line); panels 1-4 plot, against the FIM size, the
    HMMER score difference, the pTM difference, the perplexity and the TM-score
    between original and generated. Colorbars are attached to panels 0 and 4.

    Returns ``(fig, axs)`` as passed in.
    """
    perplexity = df_family["perplexity"].to_list()

    # Panel 0: masked pLDDT, generated vs original
    plddt_orig = df_family["masked_plddt_orig"].to_list()
    plddt_gen = df_family["masked_plddt_gen"].to_list()
    points = axs[0].scatter(plddt_orig, plddt_gen, c=perplexity, cmap="viridis")
    low = min(min(plddt_orig), min(plddt_gen))
    high = max(max(plddt_orig), max(plddt_gen))
    axs[0].plot([low, high], [low, high], "k--")
    axs[0].set_xlabel("Masked pLDDT original")
    axs[0].set_ylabel("Family ID: "+family_id+"\nMasked pLDDT generated")
    fig.colorbar(points, ax=axs[0], label="Perplexity\n")

    # Panels 1-4 share the x axis (FIM size) and only differ in the plotted quantity
    fim_size = df_family["fim_size"].to_list()
    panels = (
        (axs[1], (df_family["score_gen"]-df_family["score_orig"]).to_list(), "HMMER score difference"),
        (axs[2], (df_family["ptm_gen"]-df_family["ptm_orig"]).to_list(), "pTM difference"),
        (axs[3], perplexity, "Perplexity"),
        (axs[4], df_family["tmscore_orig_gen"].to_list(), "TMscore between orig. and gen."),
    )
    for panel, yvals, ylabel in panels:
        points = panel.scatter(fim_size, yvals, c=perplexity, cmap="viridis")
        panel.set_xlabel("FIM size")
        panel.set_ylabel(ylabel)

    # `points` now refers to the scatter of the last panel
    fig.colorbar(points, ax=axs[4], label="Perplexity")
    return fig, axs

def _positive_fraction(values):
    """Return the fraction of strictly positive entries rounded to 2 decimals (NaN if empty)."""
    if len(values) == 0:
        # Avoids the 0/0 RuntimeWarning; the value is still rendered as "nan"
        return float("nan")
    return round(np.sum(values > 0)/len(values), 2)


def _hist_signed_difference(ax, difference, selected, unit_bins):
    """Overlay histograms of a difference (all rows, then the selected rows) on symmetric bins."""
    # Scale the unit bins so that they span [-full_max, full_max] around zero
    full_max = max(-difference.min(), difference.max())
    values = difference.to_numpy()
    _, edges, _ = ax.hist(values, bins=unit_bins*full_max)
    ax.hist(values[selected], bins=edges)
    f0 = _positive_fraction(values)
    f1 = _positive_fraction(values[selected])
    ax.text(0.6, 0.8, f"F0: {f0}\nF1: {f1}", transform=ax.transAxes)


def plot_scores_hist_fim(fig,
                        axs,
                        family_id,
                        df_family):
    """Draw one five-panel row of FIM histograms: all rows vs a low-perplexity subset.

    The subset holds the rows that are both among the fifth with the lowest
    perplexity and have a positive "fim_distance". Panels 0-2 show the generated
    minus original difference of the masked pLDDT, HMMER score and pTM, annotated
    with the fraction of positive differences over all rows (F0) and over the
    subset (F1). Panels 3-4 show the FIM size and the TM-score between original
    and generated.

    Parameters
    ----------
    fig : figure owning ``axs``; returned untouched.
    axs : indexable holding at least five axes.
    family_id : str, written in the y-label of the first panel.
    df_family : DataFrame with columns "perplexity", "fim_distance", "fim_size",
        "tmscore_orig_gen" and the "<name>_gen" / "<name>_orig" pairs for
        "masked_plddt", "score" and "ptm".

    Returns
    -------
    tuple
        ``(fig, axs)`` as passed in.
    """
    perplexity = df_family["perplexity"].to_numpy()
    lowest_perplexity = np.argsort(perplexity)[:len(perplexity)//5]
    positive_distance = np.flatnonzero(df_family["fim_distance"].to_numpy() > 0)
    # Positions that satisfy both criteria
    inds = np.intersect1d(lowest_perplexity, positive_distance)
    # 39 bin edges symmetric around zero on [-1, 1]; rescaled per panel
    bins = np.concatenate([-np.linspace(0,1,20)[::-1],np.linspace(0,1,20)[1:]])

    difference_panels = (
        (axs[0], "masked_plddt", "Masked pLDDT difference"),
        (axs[1], "score", "HMMER score difference"),
        (axs[2], "ptm", "pTM difference"),
    )
    for ax, prefix, xlabel in difference_panels:
        difference = df_family[prefix+"_gen"]-df_family[prefix+"_orig"]
        _hist_signed_difference(ax, difference, inds, bins)
        ax.set_xlabel(xlabel)
    axs[0].set_ylabel("Family ID: "+family_id)

    # Plain distributions (FIM size, TM-score): 40 bins fixed by the full data set
    value_panels = (
        (axs[3], "fim_size", "FIM size"),
        (axs[4], "tmscore_orig_gen", "TMscore between orig. and gen."),
    )
    for ax, column, xlabel in value_panels:
        values = df_family[column].to_numpy()
        _, edges, _ = ax.hist(values, bins=40)
        ax.hist(values[inds], bins=edges)
        ax.set_xlabel(xlabel)

    return fig, axs

def compare_generation_parameters_fim(fig,
                                      axs,
                                      family_id,
                                      df_family):
    """Scatter the FIM scores per sampling setting, one row of five panels per context size.

    A row is drawn for each distinct ``n_seqs_ctx`` value found in ``df_family``;
    inside a row, every distinct (temperature, top_k, top_p) triple gets its own
    viridis colour. Panel 0 compares generated and original masked pLDDT (with a
    dashed identity line); panels 1-4 plot perplexity against the HMMER score
    difference, the pTM difference, the FIM distance and the TM-score.

    Parameters
    ----------
    fig : figure owning ``axs``; receives the shared legend.
    axs : grid of axes with five columns and one row per context size. A 1-D
        array (what ``plt.subplots`` returns for a single row) is accepted too.
    family_id : unused; kept so existing callers keep working.
    df_family : DataFrame with columns "temperature", "top_k", "top_p",
        "n_seqs_ctx", "masked_plddt_orig", "masked_plddt_gen", "score_gen",
        "score_orig", "ptm_gen", "ptm_orig", "fim_distance", "tmscore_orig_gen"
        and "perplexity". It is not modified.

    Returns
    -------
    tuple
        ``(fig, axs)`` as passed in.
    """
    # assign() builds a new frame, so the caller's (usually sliced) dataframe is
    # left untouched and pandas does not emit SettingWithCopyWarning.
    df_family = df_family.assign(
        generation_parameters=list(zip(df_family["temperature"].to_list(),
                                       df_family["top_k"].to_list(),
                                       df_family["top_p"].to_list())))
    gen_par = df_family["generation_parameters"].unique()
    color_p = sns.color_palette("viridis", len(gen_par))
    ctx_len = df_family["n_seqs_ctx"].unique()
    # A single-row subplot grid arrives as a 1-D array; view it as (1, n_columns)
    grid = np.atleast_2d(axs)
    for j, ctx_l in enumerate(ctx_len):
        df_ctx = df_family[df_family["n_seqs_ctx"] == ctx_l]
        ax = grid[j]
        for i, gen_p in enumerate(gen_par):
            df_gen_p = df_ctx[df_ctx["generation_parameters"] == gen_p]
            marker_style = dict(color=color_p[i], label=gen_p, alpha=0.8)
            perplexity = df_gen_p["perplexity"].to_list()
            # pLDDT masked parts
            xvals, yvals = df_gen_p["masked_plddt_orig"].to_list(), df_gen_p["masked_plddt_gen"].to_list()
            ax[0].scatter(xvals, yvals, **marker_style)
            # A setting may have no rows for this context size: min()/max() of an
            # empty list would raise, so the identity line is only drawn with data.
            if xvals:
                mins, maxs = min(min(xvals),min(yvals)), max(max(xvals),max(yvals))
                ax[0].plot([mins,maxs],[mins,maxs], "k--")
            # HMMER scores
            ax[1].scatter((df_gen_p["score_gen"]-df_gen_p["score_orig"]).to_list(), perplexity, **marker_style)
            # pTM scores
            ax[2].scatter((df_gen_p["ptm_gen"]-df_gen_p["ptm_orig"]).to_list(), perplexity, **marker_style)
            # Perplexity vs FIM distance relative to original
            ax[3].scatter(df_gen_p["fim_distance"].to_list(), perplexity, **marker_style)
            # Perplexity vs TMscore
            ax[4].scatter(df_gen_p["tmscore_orig_gen"].to_list(), perplexity, **marker_style)
        ax[0].set_ylabel(f"Seqs in ctx: {ctx_l}"+"\nMasked pLDDT generated")
        ax[1].set_ylabel("Perplexity")
    grid[-1, 0].set_xlabel("Masked pLDDT original")
    grid[-1, 1].set_xlabel("HMMER score difference")
    grid[-1, 2].set_xlabel("pTM difference")
    grid[-1, 3].set_xlabel("FIM distance relative to original")
    grid[-1, 4].set_xlabel("TMscore between orig. and gen.")

    # Unique legend for the entire figure on the top of all subplots
    handles, labels = grid[0, 0].get_legend_handles_labels()
    fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 1.1), ncol=6, fancybox=False)
    return fig, axs

## AR generated

In [None]:
# One square panel per metric; every family contributes one (generated, natural) point
fig, axs = plt.subplots(1, 5, figsize=(13, 3), constrained_layout=True)
for ax in axs:
    ax.set_box_aspect(1)

all_corr = []
for i, family_id in enumerate(families):
    df_family = df[df["family_id"] == family_id]
    lengths_nat = lengths_natural[family_id]
    hamming_ctx = all_hamming_ctx[family_id]
    hmmer_ctx = all_hmmer_ctx[family_id]
    structures_ctx = all_structures_ctx[family_id]
    fig, tmp_axs, corrs = plot_comparison(fig, axs, family_id, df_family, lengths_nat, hamming_ctx, hmmer_ctx, structures_ctx)
    all_corr.append(np.array(corrs))
# Rows follow the order of `families`, columns the order returned by plot_comparison
all_corr = np.array(all_corr)

# Give both axes of every panel the same limits so that the diagonal means "equal"
for i in range(len(axs)):
    panel = tmp_axs[i]
    ymin, ymax = panel.get_ylim()
    xmin, xmax = panel.get_xlim()
    mmin, mmax = min(xmin, ymin), max(xmax, ymax)
    panel.set_xlim(mmin, mmax)
    panel.set_ylim(mmin, mmax)
    panel.set_aspect('equal', adjustable='box')

# Dashed identity segments; the extents are hand-picked per panel
diagonals = [(0, 500), (0.2, 0.8), (0, 0.8), (0.5, 1), (0., 1)]
for panel, (start, stop) in zip(tmp_axs, diagonals):
    panel.plot([start, stop], [start, stop], "k--")
plt.show()
fig.savefig(f"figures/generated_sequences/comparison_{name_data}.pdf", bbox_inches='tight')

In [None]:
# Families sorted by increasing mean absolute correlation
inds = np.argsort(abs(all_corr).mean(1))
positions = np.arange(len(families))
ordered_families = [families[i] for i in inds]

# (values, marker, colour, label); every column except Hamming is plotted with flipped sign
series = [
    (all_corr[inds,0], "s", color[1], "Hamming"),
    (-all_corr[inds,1], "^", color[2], "HMMER"),
    (-all_corr[inds,2], "v", color[3], "pLDDT"),
    (-all_corr[inds,3], "o", color[4], "pTM"),
]
mean_abs_corr = abs(all_corr[inds,:]).mean(1)

# Vertical layout: families on the y axis
fig,axs = plt.subplots(1,1,figsize=(5,8))
for values, marker, colour, label in series:
    plt.plot(values, positions, marker, color=colour, label=label, markersize=8)
plt.plot(mean_abs_corr, positions, "*", color="k", label="Mean", markersize=15)
plt.yticks(positions, ordered_families)
plt.xlabel("Pearson Correlation")
plt.legend(frameon=False)
fig.savefig(f"figures/generated_sequences/comparison_{name_data}_correlation.pdf", bbox_inches='tight')
plt.show()

# Same data with x and y axes swapped: families on the x axis (this version is not saved)
fig, axs = plt.subplots(1, 1, figsize=(12, 4), constrained_layout=True)
for values, marker, colour, label in series:
    plt.plot(positions, values, marker, color=colour, label=label, markersize=8)
plt.plot(positions, mean_abs_corr, "*", color="k", label="Mean", markersize=15)
plt.xticks(positions, ordered_families, rotation=90)
plt.ylabel("Pearson Correlation")
plt.legend(frameon=False, bbox_to_anchor=(1.02, -0.05), loc='lower right', ncol=3)
plt.show()

In [None]:

# One row of six panels per family. squeeze=False keeps `axs` two-dimensional even
# when there is a single family, so `axs[i]` is always a row of six axes.
fig, axs = plt.subplots(len(families), 6, figsize=(18, 2.5*len(families)), constrained_layout=True, squeeze=False)

for i, family_id in enumerate(families):
    df_family = df[df["family_id"] == family_id]
    lengths_nat = lengths_natural[family_id]
    hamming_ctx = all_hamming_ctx[family_id]
    hmmer_ctx = all_hmmer_ctx[family_id]
    structures_ctx = all_structures_ctx[family_id]

    tmp_axs = axs[i]
    # Only the bottom row carries x-axis labels
    fig, tmp_axs = plot_hamming_hmmer(fig, tmp_axs, family_id, df_family, lengths_nat, hamming_ctx, hmmer_ctx, structures_ctx, last_row=(i == len(families)-1))
plt.show()
fig.savefig(f"figures/generated_sequences/scatter_{name_data}.pdf", bbox_inches='tight')

In [None]:
# One figure per family: a row per context size, a column per compared quantity
ctx_len = df["n_seqs_ctx"].unique()
for i, family_id in enumerate(families):
    # squeeze=False keeps `axs` two-dimensional when there is a single context size,
    # which is the layout compare_generation_parameters indexes (axs[row, col]).
    fig, axs = plt.subplots(len(ctx_len), 3, figsize=(15, 5*len(ctx_len)), sharex="col", sharey=True, constrained_layout=True, squeeze=False)
    df_family = df[df["family_id"] == family_id]
    lengths_nat = lengths_natural[family_id]
    hamming_ctx = all_hamming_ctx[family_id]
    hmmer_ctx = all_hmmer_ctx[family_id]
    fig, axs = compare_generation_parameters(fig, axs, family_id, df_family, lengths_nat, hamming_ctx, hmmer_ctx)
    plt.show()
    # Release the figure so that looping over many families does not accumulate open figures
    plt.close(fig)

## FIM generated

In [None]:
# One row of five scatter panels per family. squeeze=False keeps `axs` two-dimensional
# even when there is a single family, so `axs[i]` is always a row of five axes.
fig, axs = plt.subplots(len(families), 5, figsize=(25, 5*len(families)), constrained_layout=True, squeeze=False)

for i, family_id in enumerate(families):
    df_family = df[df["family_id"] == family_id]

    tmp_axs = axs[i]
    fig, tmp_axs = plot_scores_fim(fig, tmp_axs, family_id, df_family)
plt.show()

In [None]:
# One row of five histogram panels per family. squeeze=False keeps `axs` two-dimensional
# even when there is a single family, so `axs[i]` is always a row of five axes.
fig, axs = plt.subplots(len(families), 5, figsize=(25, 5*len(families)), constrained_layout=True, squeeze=False)

for i, family_id in enumerate(families):
    df_family = df[df["family_id"] == family_id]

    tmp_axs = axs[i]
    fig, tmp_axs = plot_scores_hist_fim(fig, tmp_axs, family_id, df_family)
plt.show()

In [None]:
# One figure per family: a row per context size, a column per compared quantity
ctx_l = df["n_seqs_ctx"].unique()
for i, family_id in enumerate(families):
    # squeeze=False keeps `axs` two-dimensional when there is a single context size,
    # which is the layout compare_generation_parameters_fim indexes (axs[row, col]).
    fig, axs = plt.subplots(len(ctx_l), 5, figsize=(25, 5*len(ctx_l)), sharex="col", sharey="col", constrained_layout=True, squeeze=False)
    df_family = df[df["family_id"] == family_id]
    fig, axs = compare_generation_parameters_fim(fig, axs, family_id, df_family)
    plt.show()
    # Release the figure so that looping over many families does not accumulate open figures
    plt.close(fig)