In [None]:
import os

import mokapot
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from wispy import theme
# Palette returned by the paper theme; pal[0] is used as the violin color in
# the figure cell.
# NOTE(review): theme.paper() presumably also applies matplotlib styling as a
# side effect -- confirm in wispy.
pal = theme.paper()
# Directory the figure is saved into; no error if it already exists.
os.makedirs("figures", exist_ok=True)

# Figure dimensions, passed to plt.figure(figsize=...) and therefore in inches.
# The widths are divided by 25.4, presumably converting 180 mm / 88 mm column
# widths to inches -- TODO confirm against the target journal.
TWO_COL = 180 / 25.4
HEIGHT = 3.5
ONE_COL = 88 / 25.4

# Set to True to switch to the talk theme and the larger figure sizes below.
talk=False
if talk:
    pal = theme.talk()
    TWO_COL = 13
    HEIGHT = 5.8
    ONE_COL = 6.5

In [None]:
def calc_gain(dfs):
    """Percent change in discoveries of the joint model over the independent one.

    Parameters
    ----------
    dfs : list of pandas.DataFrame
        Tables that each provide "pin_file", "model" and "mokapot q-value"
        columns. They are stacked row-wise before counting.

    Returns
    -------
    pandas.DataFrame
        One row per pin file (fresh integer index, pin file labels dropped)
        and one column per model holding the number of non-null q-values for
        that pin file and model. The extra "joint_gained" column is
        ``(joint - independent) / independent * 100``.
    """
    counts = (
        pd.concat(dfs, axis=0)
        .groupby(["pin_file", "model"])["mokapot q-value"]
        .count()
        .reset_index(name="passed")
    )

    # Spread the models into columns, then discard the pin file labels.
    per_model = counts.pivot(index="pin_file", columns="model", values="passed")
    gain = per_model.reset_index(drop=True)

    baseline = gain["independent"]
    gain["joint_gained"] = (gain["joint"] - baseline) / baseline
    gain["joint_gained"] = gain["joint_gained"] * 100

    return gain

In [None]:
# Collect the PSMs accepted at a q-value of 0.01 or less from every file in
# "mokapot-out" whose name contains "psms". The text before the first "." of
# the file name is used as the model label.
psm_tables = []
for f in os.listdir("mokapot-out"):
    if "psms" not in f:
        continue

    mod = f.split(".")[0]
    psms = pd.read_csv(os.path.join("mokapot-out", f), sep="\t")
    # .copy() makes the filtered rows an independent frame, so adding the
    # "model" column below does not write into a slice of the table that was
    # just read (this is what triggers pandas' SettingWithCopyWarning).
    psms = psms.loc[psms["mokapot q-value"] <= 0.01, :].copy()
    psms["model"] = mod
    psm_tables.append(psms)

# All accepted PSMs in a single table, for inspection in the next cell.
df = pd.concat(psm_tables)
# psm_gain is only ever bound to the gain table, never to the list of inputs,
# so the name means the same thing wherever it is read.
psm_gain = calc_gain(psm_tables)
psm_gain["level"] = "PSMs"

In [None]:
# Accepted PSMs per pin file under the "independent" model, smallest first.
(
    df.loc[df["model"] == "independent"]
    .groupby("pin_file")["SpecId"]
    .count()
    .sort_values()
)

In [None]:
# Collect the peptides accepted at a q-value of 0.01 or less from every file
# in "mokapot-out" whose name contains "peptides". The text before the first
# "." of the file name is used as the model label.
pep_curves = {}
pep_tables = []
for f in os.listdir("mokapot-out"):
    if "peptides" not in f:
        continue

    mod = f.split(".")[0]
    pep = pd.read_csv(os.path.join("mokapot-out", f), sep="\t")

    # Remove the first two and the last two characters of every sequence
    # (presumably the flanking residue plus "." separator -- confirm against
    # the mokapot output). regex=True must be explicit: since pandas 2.0
    # Series.str.replace treats the pattern as a literal string by default,
    # which would silently leave every sequence unchanged.
    pep["Peptide"] = pep["Peptide"].str.replace("^..", "", regex=True)
    pep["Peptide"] = pep["Peptide"].str.replace("..$", "", regex=True)
    pep = pep.loc[pep["mokapot q-value"] <= 0.01, :]

    # Rows per peptide -> number of peptides with exactly k rows -> running
    # total from the largest k downwards, i.e. peptides with at least k rows.
    # NOTE(review): assumes one row per peptide and pin file, so that k is the
    # number of experiments a peptide was detected in.
    curve = (pep.groupby("Peptide")["pin_file"]
             .count()
             .value_counts()
             .sort_index(ascending=False)
             .cumsum())
    curve.name = mod
    pep_curves[mod] = curve

    pep_df = pep.copy()
    pep_df["model"] = mod
    pep_tables.append(pep_df)

# One column per model; the index is the row count k described above.
peps = pd.concat(list(pep_curves.values()), axis=1)
pep_gain = calc_gain(pep_tables)
pep_gain["level"] = "Peptides"

In [None]:
# Collect the protein groups accepted at a q-value of 0.01 or less from every
# file in "mokapot-out" whose name contains "proteins". The text before the
# first "." of the file name is used as the model label.
prot_curves = {}
prot_tables = []
for fname in os.listdir("mokapot-out"):
    if "proteins" not in fname:
        continue

    model_name = fname.split(".")[0]
    table = pd.read_csv(os.path.join("mokapot-out", fname), sep="\t")
    accepted = table.loc[table["mokapot q-value"] <= 0.01, :]

    # Rows per protein group -> number of groups with exactly k rows ->
    # running total from the largest k downwards (groups with at least k rows).
    rows_per_group = accepted.groupby("mokapot protein group")["pin_file"].count()
    curve = rows_per_group.value_counts().sort_index(ascending=False).cumsum()
    curve.name = model_name
    prot_curves[model_name] = curve

    labelled = accepted.copy()
    labelled["model"] = model_name
    prot_tables.append(labelled)

# One column per model, aligned on the row count k.
prots = pd.concat(prot_curves.values(), axis=1)
prot_gain = calc_gain(prot_tables)
prot_gain["level"] = "Proteins"

In [None]:
# Stack the PSM-, peptide- and protein-level gain tables row-wise; the "level"
# column set on each table tells the rows apart.
gain = pd.concat((psm_gain, pep_gain, prot_gain), axis=0)
# Cell output: the per-model protein table.
prots

In [None]:
def _draw_gain_curves(ax, counts, ylabel):
    """Plot the joint and static columns of `counts` minus the independent one."""
    baseline = counts["independent"].values
    ax.plot(counts.index, counts["joint"].values - baseline,
            label="Joint Model")
    # zorder=0 keeps the static curve underneath the joint curve.
    ax.plot(counts.index, counts["static"].values - baseline,
            label="Static Model", zorder=0)

    ax.set_xlabel("Number of experiments detected")
    ax.set_ylabel(ylabel)
    ax.axhline(0, color="black", linestyle="dashed")
    ax.legend(frameon=False)


fig = plt.figure(figsize=(TWO_COL, HEIGHT))
gs = fig.add_gridspec(2, 2)

# Panel a (full-height left column): distribution of the percent gain per level.
ax1 = fig.add_subplot(gs[:, 0])
levels = gain["level"]
pct_gain = gain["joint_gained"]
sns.violinplot(x=levels, y=pct_gain, inner=None, linewidth=0,
               ax=ax1, color=pal[0])
sns.swarmplot(x=levels, y=pct_gain, ax=ax1, color="black", size=4)
ax1.axhline(0, color="black", linestyle="dashed")
ax1.set_xlabel("FDR level")
ax1.set_ylabel("Gain by joint model (%)")

# Panel b (top right): peptide curves relative to the independent model.
ax2 = fig.add_subplot(gs[0, 1])
_draw_gain_curves(ax2, peps, "Peptides gained")

# Panel c (bottom right): the same for proteins.
ax3 = fig.add_subplot(gs[1, 1])
_draw_gain_curves(ax3, prots, "Proteins gained")

fig.align_labels()

# Bold panel letters placed in axes points; the full-height first panel needs
# a larger vertical offset than the two half-height panels.
for ax, label, lab_y in zip(fig.axes, "abc", (202, 88, 88)):
    ax.annotate(
        label,
        (-10, lab_y),
        xycoords="axes points",
        fontweight="bold",
        va="top",
        ha="right",
    )

plt.tight_layout()
plt.savefig("figures/joint_models.png", dpi=300)