# Compare our reproduction with the original analysis

## Setup

In [2]:
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import ticker

# Import some utility functions:
sys.path.append("bin")
import utils

# Plot styling:
sns.set(context="paper", style="ticks")

# The search result files. Inputs that sit directly in "../results" are
# labelled "new"; every other input is labelled "old".
results_dir = Path("../results")
input_paths = [Path(f) for f in snakemake.input]
is_new = [path.parent == results_dir for path in input_paths]
new_files = [path for path, fresh in zip(input_paths, is_new) if fresh]
old_files = [path for path, fresh in zip(input_paths, is_new) if not fresh]
mztab_files = {"new": new_files, "old": old_files}

# Parameters to define mass shifts (passed to utils.get_mass_groups below):
tol_mass = 0.1
tol_mode = "Da"

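This is a tick formatter class for our mirror plot. It labels negative tick values by their absolute value, so the mirrored (lower) half of the plot is shown without minus signs. See: https://stackoverflow.com/questions/51086732/how-can-i-remove-the-negative-sign-from-y-tick-labels-in-matplotlib-pyplot-figur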

In [38]:
class MirrorFormatter(ticker.ScalarFormatter):
    """Scalar tick formatter that labels negative ticks by their magnitude.

    Construction is inherited unchanged from ``ticker.ScalarFormatter``;
    only the tick-label call is customized.
    """

    def __call__(self, x, pos=None):
        """Return the parent's label for ``x``, negating ``x`` first if negative."""
        if x < 0:
            x = -x
        return super().__call__(x, pos)

## Read SSMs

In [3]:
# For each version ("new" / "old"): read its matches, derive the mass groups
# with the tolerance settings, and tag the groups with the version name.
ssms = {}
mass_shifts = {}
for group, group_files in mztab_files.items():
    matches = utils.read_matches(group_files)
    ssms[group] = matches

    shifts = utils.get_mass_groups(matches, tol_mass, tol_mode)
    shifts["version"] = group  # lets the versions be told apart once combined
    mass_shifts[group] = shifts
    


## Create the figure

In [41]:
fig = plt.figure(figsize=(3.33, 4))
gs = fig.add_gridspec(2, 1)
pal = sns.color_palette()

# Mirror plot of mass shifts. Only mass groups whose median mass difference
# exceeds the tolerance in magnitude are drawn.
is_shifted_new = mass_shifts["new"]["mass_diff_median"].abs() > tol_mass
is_shifted_old = mass_shifts["old"]["mass_diff_median"].abs() > tol_mass
new_mod = mass_shifts["new"].loc[is_shifted_new, :]
old_mod = mass_shifts["old"].loc[is_shifted_old, :]

ax1 = fig.add_subplot(gs[0:1])
ax1.axhline(0, color="black")

# "new" counts are drawn upwards from zero, "old" counts downwards.
ax1.vlines(
    x=new_mod["mass_diff_median"],
    ymin=0,
    ymax=new_mod["num_psms"],
    linewidth=1.2,
    color=pal[0],
)
ax1.vlines(
    x=old_mod["mass_diff_median"],
    ymin=0,
    ymax=-old_mod["num_psms"],
    linewidth=1.2,
    color=pal[1],
)
ax1.set_xlim(-50, 350)

# Annotate each half of the mirror plot (positions are in data coordinates).
for y_pos, label in [
    (35000, "Reanalysis\nANN-SoLo v0.3.3, GPU"),
    (-35000, "Original\nANN-SoLo v0.1.2, CPU"),
]:
    ax1.text(70, y_pos, label, verticalalignment="center")

ax1.set_ylabel("SSMs")
ax1.set_xlabel("Mass Shift (Da)")

# Install the formatter before ticklabel_format() so that the scientific
# notation settings are applied to it.
ax1.yaxis.set_major_formatter(MirrorFormatter())
ax1.ticklabel_format(axis="y", style="sci", scilimits=(0, 1), useMathText=True)
ax1.set_title(
    "A", fontweight="bold", loc="left", transform=ax1.transAxes, x=-0.15
)
sns.despine(ax=ax1)

# Most common mass shifts:
versions = {"new": "Reanalysis", "old": "Original"}

mods = pd.concat(mass_shifts.values())
# Label each mass group by its median mass difference, fixed to two decimals.
mods["key"] = mods["mass_diff_median"].round(2).map("{:.2f}".format)
mods["version"] = mods["version"].replace(versions)

# One row per mass-shift label and one "num_psms" column per version; a label
# missing from a version is filled with zero.
mods = (
    mods.loc[:, ["key", "version", "num_psms"]]
    .pivot(index="key", columns="version")
    .fillna(0)
)
total = mods.sum(axis=0)
total.name = "total"

# Add the per-version sums as an extra row labelled "total".
# DataFrame.append() was removed in pandas 2.0, so build the row explicitly
# and concatenate it. The index name is carried over because reset_index()
# below must produce the "key" column that melt() uses.
total_row = total.to_frame().T
total_row.index.name = mods.index.name
mods = pd.concat([mods, total_row])

mods["total"] = mods.sum(axis=1)
mods = mods.sort_values("total", ascending=False).head(5).reset_index()

# Sort by version first, then by total with a stable algorithm, so rows with
# equal totals keep their version order. The default sort kind is not stable
# and may discard the ordering established by the first sort.
mods = (
    mods.melt(id_vars=["key", "total"])
    .sort_values("version", ascending=False)
    .sort_values("total", ascending=False, kind="mergesort")
)

# Panel B: grouped bars of the SSM counts in `mods`, one bar per version.
ax2 = fig.add_subplot(gs[1])
# Draw on ax2 explicitly rather than relying on it being the current axes.
sns.barplot(data=mods, x="key", y="value", hue="version", ax=ax2)
ax2.ticklabel_format(axis="y", style="sci", scilimits=(0, 1), useMathText=True)
ax2.set_title("B", fontweight="bold", loc="left", transform=ax2.transAxes, x=-0.15)
ax2.set_xlabel("Top Mass Shifts (Da)")
ax2.set_ylabel("SSMs")
sns.despine(ax=ax2)
ax2.legend(title="", frameon=False)

plt.tight_layout()

# Write the figure to the first Snakemake output, then display it.
plt.savefig(snakemake.output[0], dpi=300, bbox_inches="tight")
plt.show()