# Experiments on the extent of manual markup

“The authors developed the SPAN methodology which identifies peaks for ULI-ChIP-seq data using semi-supervised approach which uses user-annotations to train a model and choose parameters optimized for minimization of the training error. However, I could not find in the text the extent of manual annotations that should be done in order to achieve a reliable model. In addition, the test-error (manual vs. automated peaks annotation) is not reported. I thus suggest adding a table/plot summarizing the test error resulting from each extent of manual annotation.”

Experiment design:

* take ABF ChIP-seq data (5 targets, 40 donors)
* investigate the effects of using k = 10, 20, 50, 100 (out of ~500) labels
* for each k, create 10 random batches of k training labels and 500-k test labels 
  (only 5 batches for k=100 for obvious reasons)
* for each batch, tune Span on the training labels and report the training and test errors;
  then average the errors across batches
* in total, 7000 = 35 * 5 * 40 training-test error pairs were reported (35 = 10 + 10 + 10 + 5 batches per track)
* from those, 800 = 4 * 5 * 40 mean training-test error pairs were generated (one per k, target and donor)

Before running this notebook, run the `SpanTestErrorEvaluation` experiment from https://github.com/JetBrains-Research/epigenome

In [None]:
import pandas as pd;
import numpy as np;
import os;
import matplotlib.pyplot as plt;
import seaborn as sns;

In [None]:
TARGETS = ["H3K27ac", "H3K27me3", "H3K4me3", "H3K4me1", "H3K36me3"]
# One tab-separated results table per histone mark, in the same order as TARGETS.
TSVS = [os.path.join("/mnt/stripe/bio/experiments/span-test-error", f"{mark}.tsv") for mark in TARGETS]

In [None]:
# Load every per-target table; lines starting with '#' are skipped as comments.
dataframes = list(map(lambda tsv_path: pd.read_csv(tsv_path, sep="\t", comment="#"), TSVS))

In [None]:
# Assumes rows are grouped by track, 35 consecutive rows per track
# (10 + 10 + 10 + 5 batches for k = 10, 20, 50, 100); row i belongs to track i // 35.
for results in dataframes:
    results["track_id"] = np.arange(len(results), dtype=np.int64) // 35

In [None]:
def aggregate_test_error(dataframe):
    """Summarise per-batch errors for every (k, track_id) combination.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        Per-batch results with columns ``k``, ``track_id``, ``train_error``,
        ``test_error``, ``macs2_error``, ``sicer_error`` and ``span_error``.

    Returns
    -------
    pandas.DataFrame
        One row per (k, track_id) pair, ordered by k and then track_id, with the
        mean / standard deviation / median of the train and test errors and the
        median / mean of the ``macs2_error``, ``sicer_error`` and ``span_error``
        columns. Standard deviations are population values (``np.std``, ddof=0).
    """
    columns = ["k", "track_id",
               "mean_train_error", "mean_test_error",
               "sd_train_error", "sd_test_error",
               "median_train_error", "median_test_error",
               "median_macs2_error", "median_sicer_error",
               "mean_macs2_error", "mean_sicer_error",
               "median_span_error", "mean_span_error"]
    ks = sorted(set(dataframe["k"]))
    track_ids = sorted(set(dataframe["track_id"]))
    records = []
    for k in ks:
        for track_id in track_ids:
            subset = dataframe[(dataframe["k"] == k) & (dataframe["track_id"] == track_id)]
            train = subset["train_error"].to_numpy()
            test = subset["test_error"].to_numpy()
            macs2 = subset["macs2_error"].to_numpy()
            sicer = subset["sicer_error"].to_numpy()
            span = subset["span_error"].to_numpy()
            records.append({
                "k": k, "track_id": track_id,
                "mean_train_error": np.mean(train), "mean_test_error": np.mean(test),
                "sd_train_error": np.std(train), "sd_test_error": np.std(test),
                "median_train_error": np.median(train), "median_test_error": np.median(test),
                "median_macs2_error": np.median(macs2), "median_sicer_error": np.median(sicer),
                "mean_macs2_error": np.mean(macs2), "mean_sicer_error": np.mean(sicer),
                "median_span_error": np.median(span), "mean_span_error": np.mean(span),
            })
    # Build the frame once: appending with ``res.loc[len(res)] = ...`` is quadratic
    # in the number of rows and relies on pandas' dtype inference on enlargement.
    return pd.DataFrame(records, columns=columns)

In [None]:
# One (k, track) summary table per target, aligned index-by-index with TARGETS.
test_error_dfs = list(map(aggregate_test_error, dataframes))

In [None]:
# One boxplot per number of training labels k of the per-track mean training error.
for target, test_error_df in zip(TARGETS, test_error_dfs):
    ks = sorted(set(int(k) for k in test_error_df["k"]))
    # 8 x 6 in at 100 dpi yields the intended 800 x 600 px image; recent Matplotlib
    # versions reject unknown savefig() keywords such as width/height.
    plt.figure(figsize=(8, 6))
    for i, k in enumerate(ks):
        data = test_error_df.loc[test_error_df["k"] == k, "mean_train_error"]
        plt.boxplot(data, positions=[i], widths=0.6)
    plt.xlim(-1, len(ks))
    plt.xticks(range(len(ks)), ks)
    plt.title(target)
    plt.xlabel('Training labels')
    plt.ylabel('Mean training error')
    plt.tight_layout()
    plt.savefig('{}_train.png'.format(target), dpi=100)
    plt.show()

In [None]:
# One boxplot per number of training labels k of the per-track mean test error.
for target, test_error_df in zip(TARGETS, test_error_dfs):
    ks = sorted(set(int(k) for k in test_error_df["k"]))
    # 8 x 6 in at 100 dpi yields the intended 800 x 600 px image; recent Matplotlib
    # versions reject unknown savefig() keywords such as width/height.
    plt.figure(figsize=(8, 6))
    for i, k in enumerate(ks):
        data = test_error_df.loc[test_error_df["k"] == k, "mean_test_error"]
        plt.boxplot(data, positions=[i], widths=0.6)
    # Axis extent and ticks follow the number of k values instead of assuming four.
    plt.xlim(-1, len(ks))
    plt.xticks(range(len(ks)), ks)
    plt.title(target)
    plt.xlabel('Training labels')
    plt.ylabel('Mean test error')
    plt.tight_layout()
    plt.savefig('{}_test.png'.format(target), dpi=100)
    plt.show()

In [None]:
# Colour used for each number of training labels k in the scatter plots below.
COLOR_MAP = {10: 'k', 20: 'b', 50: 'y', 100: 'r'}
for target, test_error_df in zip(TARGETS, test_error_dfs):
    ks = sorted(set(int(k) for k in test_error_df["k"]))
    # 8 x 6 in at 100 dpi yields the intended 800 x 600 px image; recent Matplotlib
    # versions reject unknown savefig() keywords such as width/height.
    plt.figure(figsize=(8, 6))
    for k in ks:
        data = test_error_df[test_error_df["k"] == k]
        plt.plot(data["mean_train_error"], data["mean_test_error"], 'o', color=COLOR_MAP[k])
    plt.xlabel("Mean training error")
    plt.ylabel("Mean test error")
    plt.title(target)
    # Reference diagonal: points above it have a higher test than training error.
    plt.plot([0, 1], [0, 1], '-')
    # Legend labels are matched to lines in plotting order: one per k, then the diagonal.
    plt.legend(ks + ["y=x"])
    plt.savefig('{}_train_test.png'.format(target), dpi=100)
    plt.show()

In [None]:
for target, test_error_df in zip(TARGETS, test_error_dfs):
    # Long-format table for seaborn: for each k, all train medians followed by all test medians.
    per_k_frames = []
    for k in sorted({int(value) for value in test_error_df["k"]}):
        at_k = test_error_df[test_error_df["k"] == k]
        train_medians = at_k["median_train_error"].to_numpy()
        test_medians = at_k["median_test_error"].to_numpy()
        n_points = len(train_medians) + len(test_medians)
        half = n_points // 2
        per_k_frames.append(pd.DataFrame({
            "k": [k] * n_points,
            "error": np.concatenate([train_medians, test_medians]),
            "type": ["train"] * half + ["test"] * half,
        }))
    sns.swarmplot(x="k", y="error", hue="type", dodge=True, data=pd.concat(per_k_frames))
    plt.ylabel("median error")
    plt.title(target)
    plt.legend(loc='upper right', title=None)
    plt.show()

In [None]:
# For each k, compare per-track median test error of tuned Span with the MACS2 and SICER
# median errors stored in the same rows.
for target, test_error_df in zip(TARGETS, test_error_dfs):
    # 8 x 6 in at 100 dpi yields the intended 800 x 600 px image; recent Matplotlib
    # versions reject unknown savefig() keywords such as width/height.
    plt.figure(figsize=(8, 6))
    plot_dataframes = []
    ks = sorted(set(int(k) for k in test_error_df["k"]))
    for k in ks:
        at_k = test_error_df[test_error_df["k"] == k]
        n = len(at_k)
        data_error = np.concatenate([at_k["median_test_error"].to_numpy(),
                                     at_k["median_macs2_error"].to_numpy(),
                                     at_k["median_sicer_error"].to_numpy()])
        data_type = ["Span"] * n + ["MACS2"] * n + ["SICER"] * n
        plot_dataframes.append(pd.DataFrame({"k": [k] * len(data_error), "error": data_error, "type": data_type}))
    sns.swarmplot(x="k", y="error", hue="type", dodge=True, data=pd.concat(plot_dataframes))
    plt.ylabel("Median test set error")
    plt.xlabel("Training labels")
    plt.title(target)
    plt.legend(loc='upper right', title=None)
    plt.savefig(f"peak_callers_{target}.png", dpi=100)
    plt.show()

In [None]:
# For each k, compare per-track mean test error of tuned Span with the MACS2 and SICER
# mean errors stored in the same rows.
for target, test_error_df in zip(TARGETS, test_error_dfs):
    # 8 x 6 in at 100 dpi yields the intended 800 x 600 px image; recent Matplotlib
    # versions reject unknown savefig() keywords such as width/height.
    plt.figure(figsize=(8, 6))
    plot_dataframes = []
    ks = sorted(set(int(k) for k in test_error_df["k"]))
    for k in ks:
        at_k = test_error_df[test_error_df["k"] == k]
        n = len(at_k)
        data_error = np.concatenate([at_k["mean_test_error"].to_numpy(),
                                     at_k["mean_macs2_error"].to_numpy(),
                                     at_k["mean_sicer_error"].to_numpy()])
        data_type = ["Span"] * n + ["MACS2"] * n + ["SICER"] * n
        plot_dataframes.append(pd.DataFrame({"k": [k] * len(data_error), "error": data_error, "type": data_type}))
    sns.swarmplot(x="k", y="error", hue="type", dodge=True, data=pd.concat(plot_dataframes))
    plt.ylabel("Mean test set error")
    plt.xlabel("Training labels")
    plt.title(target)
    plt.legend(loc='upper right', title=None)
    # Distinct name: "peak_callers_{target}_mean.png" is written by the per-peak-caller
    # mean plot and would overwrite this figure.
    plt.savefig(f"peak_callers_{target}_by_k_mean.png", dpi=100)
    plt.show()

In [None]:
# Per peak caller: baseline errors (MACS2, SICER, default Span as "Span 0") next to
# tuned Span for each k, with a horizontal bar at each category's median.
for target, test_error_df in zip(TARGETS, test_error_dfs):
    # 8 x 6 in at 100 dpi yields the intended 800 x 600 px image; recent Matplotlib
    # versions reject unknown savefig() keywords such as width/height.
    plt.figure(figsize=(8, 6))
    ks = sorted(set(int(k) for k in test_error_df["k"]))
    kmin = min(ks)
    # Baseline errors are taken from the rows of the smallest k only.
    at_kmin = test_error_df[test_error_df["k"] == kmin]
    data_error = np.concatenate([at_kmin["median_macs2_error"].to_numpy(),
                                 at_kmin["median_sicer_error"].to_numpy(),
                                 at_kmin["median_span_error"].to_numpy()])
    ndots = len(data_error) // 3
    data_type = ["MACS2"] * ndots + ["SICER"] * ndots + ["Span 0"] * ndots
    data_hue = ["default"] * len(data_error)
    peak_callers = ["MACS2", "SICER", "Span 0"]
    for k in ks:
        # Assumes every k has the same number of tracks (ndots) as kmin.
        data_error = np.append(data_error, test_error_df.loc[test_error_df["k"] == k, "median_test_error"])
        data_type = data_type + [f"Span {k}"] * ndots
        data_hue = data_hue + ["trained"] * ndots
        peak_callers.append(f"Span {k}")
    plot_dataframe = pd.DataFrame({"peak_caller": data_type, "error": data_error, "hue": data_hue})
    # Explicit order keeps category i aligned with the median bar drawn at x = i.
    sns.swarmplot(x="peak_caller", y="error", hue="hue", order=peak_callers, data=plot_dataframe)
    medians = [np.median(plot_dataframe.loc[plot_dataframe["peak_caller"] == pc, "error"])
               for pc in peak_callers]
    for i, m in enumerate(medians):
        plt.plot([i - 0.4, i + 0.4], [m, m], '-k', zorder=10)
    plt.ylabel("Median test set error")
    plt.xlabel("Peak caller")
    plt.title(target)
    plt.legend(loc='upper right', title=None)
    plt.savefig(f"peak_callers_{target}_median.png", dpi=100)
    plt.show()

In [None]:
# Per peak caller: baseline mean errors (MACS2, SICER, default Span as "Span 0") next to
# tuned Span for each k, with a horizontal bar at each category's mean.
for target, test_error_df in zip(TARGETS, test_error_dfs):
    # 8 x 6 in at 100 dpi yields the intended 800 x 600 px image; recent Matplotlib
    # versions reject unknown savefig() keywords such as width/height.
    plt.figure(figsize=(8, 6))
    ks = sorted(set(int(k) for k in test_error_df["k"]))
    kmin = min(ks)
    # Baseline errors are taken from the rows of the smallest k only.
    at_kmin = test_error_df[test_error_df["k"] == kmin]
    data_error = np.concatenate([at_kmin["mean_macs2_error"].to_numpy(),
                                 at_kmin["mean_sicer_error"].to_numpy(),
                                 at_kmin["mean_span_error"].to_numpy()])
    ndots = len(data_error) // 3
    data_type = ["MACS2"] * ndots + ["SICER"] * ndots + ["Span 0"] * ndots
    data_hue = ["default"] * len(data_error)
    peak_callers = ["MACS2", "SICER", "Span 0"]
    for k in ks:
        # Assumes every k has the same number of tracks (ndots) as kmin.
        data_error = np.append(data_error, test_error_df.loc[test_error_df["k"] == k, "mean_test_error"])
        data_type = data_type + [f"Span {k}"] * ndots
        data_hue = data_hue + ["trained"] * ndots
        peak_callers.append(f"Span {k}")
    plot_dataframe = pd.DataFrame({"peak_caller": data_type, "error": data_error, "hue": data_hue})
    # Explicit order keeps category i aligned with the mean bar drawn at x = i.
    sns.swarmplot(x="peak_caller", y="error", hue="hue", order=peak_callers, data=plot_dataframe)
    means = [np.mean(plot_dataframe.loc[plot_dataframe["peak_caller"] == pc, "error"])
             for pc in peak_callers]
    for i, m in enumerate(means):
        plt.plot([i - 0.4, i + 0.4], [m, m], '-k', zorder=10)
    plt.ylabel("Mean test set error")
    plt.xlabel("Peak caller")
    plt.title(target)
    plt.legend(loc='upper right', title=None)
    # Zero baseline spanning all categories (previously hard-coded for exactly 7).
    plt.plot([-0.5, len(peak_callers) - 0.5], [0, 0], '-k')
    plt.savefig(f"peak_callers_{target}_mean.png", dpi=100)
    plt.show()

# Overlap analysis

In [None]:
from tqdm.auto import tqdm
import glob

# Load files: index every tuned-Span peak file by k, batch, track id and modification.
PATH = '/mnt/stripe/bio/experiments/span-test-error/tuned_span_peaks'
KS = [10, 20, 50, 100]

dfs = []
for k in tqdm(KS):
    k_dir = os.path.join(PATH, 'k{}'.format(k))
    # Without recursive=True, glob treats '**' like '*', so this matches files
    # exactly one directory below k_dir (the batch directories).
    files = glob.glob(k_dir + '/**/*.*')
    # Presumably "<prefix>_<id>_<modification>_..." file names -- TODO confirm naming scheme.
    name_parts = [os.path.basename(path).split('_') for path in files]
    dfs.append(pd.DataFrame({
        'k': [k] * len(files),
        # Batch directories are named "b<N>"; keep only "<N>".
        'batch': [path.replace(k_dir + '/b', '').split('/')[0] for path in files],
        'id': [parts[1] for parts in name_parts],
        'file': files,
        'modification': [parts[2] for parts in name_parts],
    }))
df = pd.concat(dfs)

display(df.head())

In [None]:
import downstream.bed_metrics as bm
from pathlib import Path

# Compute pairwise overlaps between tuned-Span peak files of the same (k, modification, batch).
overlap_records = []
for k in KS:
    for m in TARGETS:
        # Sorted so that row order (and the CSV written below) is reproducible: iteration
        # order of a set of strings depends on per-process hash randomisation.
        for b in sorted(set(df['batch'])):
            selected = (df['k'] == k) & (df['modification'] == m) & (df['batch'] == b)
            paths = [Path(f) for f in df.loc[selected, 'file']]
            df_path = '/mnt/stripe/figures/overlap_{}_k{}_b{}.tsv'.format(m, k, b)
            mt = bm.load_or_build_metrics_table(paths, paths, Path(df_path),
                                                jaccard=False,
                                                threads=30)
            for r in mt.index:
                for c in mt.columns:
                    overlap_records.append((r + "@" + c, m, k, b, mt.loc[r, c]))

# Build the frame once instead of growing it row by row (quadratic).
df_overlap = pd.DataFrame(overlap_records, columns=['id', 'modification', 'k', 'batch', 'overlap'])

display(df_overlap.head())

In [None]:
# Keep an untouched copy for the failed-track filter, then persist the unfiltered overlaps.
df_overlap_copy = df_overlap.copy(deep=True)
df_overlap.to_csv('/mnt/stripe/figures/overlap_batches.csv', index=False)

In [None]:
# Filter out failed tracks: map each modification to the donors whose status is 'failed'.
df_failed = pd.read_csv('/mnt/stripe/bio/experiments/configs/Y20O20/benchmark/Y20O20_peaks_summary_uli.tsv',
                        sep='\t', comment='#')
df_failed = (df_failed[df_failed['status'] == 'failed']
             .loc[:, ['donor', 'modification']]
             .drop_duplicates())
failed = {m: list(df_failed.loc[df_failed['modification'] == m, 'donor']) for m in TARGETS}
display(failed)

In [None]:
# Drop overlap pairs in which a track belongs to a donor whose peak calling failed for
# that modification. Pair ids are "<row>@<column>" labels of the metrics table, and a
# donor is matched by the substring "<donor>_".
# A vectorised mask replaces row-by-row appends, which were quadratic and coerced every
# column to object dtype.
keep_mask = np.array(
    [all(d + '_' not in rid for d in failed[rm])
     for rid, rm in zip(df_overlap_copy['id'], df_overlap_copy['modification'])],
    dtype=bool)
df_overlap = df_overlap_copy.loc[keep_mask].reset_index(drop=True)
filtered = int((~keep_mask).sum())

print('Filtered', filtered)
print('Len', len(df_overlap))

In [None]:
# Preview the first rows of the filtered overlap table.
df_overlap.head(n=5)

In [None]:
# Peak-calling summary restricted to tracks whose status is not 'failed'.
df_peaks = pd.read_csv('/mnt/stripe/bio/experiments/configs/Y20O20/benchmark/Y20O20_peaks_summary_uli.tsv',
                       sep='\t', comment='#')
df_peaks = df_peaks[df_peaks['status'].ne('failed')]
df_peaks.head()

In [None]:
# Compute overlaps for SPAN not tuned: pairwise overlaps between untuned Span peak files,
# recorded with k = 0 and no batch so they can be concatenated with the tuned results.
span0_records = []
for m in TARGETS:
    is_untuned_span = ((df_peaks['modification'] == m)
                       & (df_peaks['tool'] == 'span')
                       & (df_peaks['procedure'] != 'tuned'))
    paths = [Path(f) for f in df_peaks.loc[is_untuned_span, 'file']]
    df_path = '/mnt/stripe/figures/overlap_span_not_tuned_{}.tsv'.format(m)
    mt = bm.load_or_build_metrics_table(paths, paths, Path(df_path),
                                        jaccard=False,
                                        threads=30)
    for r in mt.index:
        for c in mt.columns:
            span0_records.append((r + "@" + c, m, 0, np.nan, mt.loc[r, c]))

# Build the frame once instead of growing it row by row (quadratic).
df_span0_overlap = pd.DataFrame(span0_records, columns=['id', 'modification', 'k', 'batch', 'overlap'])

display(df_span0_overlap.head())

In [None]:
# Prepend the untuned (k = 0) Span overlaps to the filtered tuned-Span overlaps.
df_overlap = pd.concat([df_span0_overlap, df_overlap], axis=0)

In [None]:
# Mean and standard deviation of the pairwise overlap per (modification, k).
# Only the numeric `overlap` column is aggregated: `id` and `batch` hold strings
# (mean/std over them fail on recent pandas), and row-wise DataFrame building may
# leave `overlap`/`k` with object dtype, so both are coerced explicitly.
df_overlap_numeric = df_overlap.assign(overlap=pd.to_numeric(df_overlap['overlap']),
                                       k=df_overlap['k'].astype(int))
overlap_by_group = df_overlap_numeric.groupby(['modification', 'k'])['overlap']

df_mean = overlap_by_group.mean().reset_index()
print('Mean overlap')
display(df_mean)

df_std = overlap_by_group.std().reset_index()
print('STD overlap')
display(df_std)


df_mean['mk'] = df_mean['modification'] + " " + df_mean['k'].astype(str)
# Total number of (modification, k) columns across all panels.
mpl = len(set(df_mean['mk']))
fig = plt.figure(figsize=(max(int(mpl * .75), 1), 4))
offset = 0
for m in TARGETS:
    datam = df_mean.loc[df_mean['modification'] == m]
    datas = df_std.loc[df_std['modification'] == m]
    # Distinct k values in order of appearance; used as both tick positions and labels.
    xlabels = list(dict.fromkeys(datam['k']))
    w = len(set(datam['mk']))
    if w == 0:
        # No data for this modification; subplot2grid cannot take colspan=0.
        continue
    ax = plt.subplot2grid((1, mpl), (0, offset), colspan=w)
    # Left axis: standard deviation of the overlap.
    sns.lineplot(ax=ax, x='k', y='overlap', marker='o', color='C0', data=datas)
    ax.set_ylim(0, 1)
    if offset > 0:
        ax.get_yaxis().set_ticklabels([])
        ax.set_ylabel('')
    else:
        ax.set_ylabel('std')

    # Right (twin) axis: mean overlap. Same 0..1 range as the std axis so panels are
    # comparable -- assumes overlap is a fraction; confirm against bed_metrics.
    ax2 = ax.twinx()
    sns.lineplot(ax=ax2, x='k', y='overlap', marker='o', color='C1', data=datam)
    ax2.set_ylim(0, 1)
    if m != TARGETS[-1]:
        ax2.get_yaxis().set_ticklabels([])
        ax2.set_ylabel('')
    else:
        ax2.set_ylabel('mean')

    offset += w
    ax.set_xlabel('k')
    ax.set_title(m)
    # `k` is drawn on a numeric axis, so ticks go at the k values themselves.
    ax.set_xticks(xlabels)
    ax.set_xticklabels(xlabels, rotation=90)

plt.tight_layout()

plt.show()