In [None]:
%matplotlib inline

from pathlib import Path
import pandas as pd
import numpy as np
import math
from collections import defaultdict

import seaborn as sns

from downstream.signals.signal_r2_permutation_test import process, collect_paths
from downstream.signals.signals_util import extract_normalization, extract_datatype
import downstream.bed_metrics as bm

import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

In [None]:
# Permutation-test settings; set `simulations` to a small value (e.g. 6) for a quick dry run.
simulations = 100001
threads = 8          # passed as `threads=` to `process` and the correlation helpers below
clear_cache = False  # True: recompute / remove cached result files instead of reusing them

In [None]:
def walk_folders(path):
    """Yield `path` and then every directory below it, depth-first in pre-order.

    Children are visited in `Path.iterdir()` order; non-directory entries are skipped.
    """
    pending = [path]
    while pending:
        folder = pending.pop()
        yield folder
        subdirs = [child for child in folder.iterdir() if child.is_dir()]
        # Push in reverse so the first child is popped (visited) next.
        pending.extend(reversed(subdirs))

# Histone modifications

In [None]:
# Loci (folder names searched under signals_root/<datatype>) evaluated for each datatype.
datatype_loci = defaultdict(list)
datatype_loci["meth"] = [
    "ucsc_cpgIslandExt.hg19",
    "cpg_minavcov10_complex_4outliers.narrow.adjusted.regions.filtered",
    "washu_german_rrbs_filtered_dmrs_all_10.hg19",
]
datatype_loci["H3K36me3"] = ["coding_transcript", "noncoding_transcript"]

# ChromHMM states per histone mark; each state is added once per annotation id
# (presumably ENCODE accessions) as "cd14_<id>_chromhmm18.hg19.<state>".
hist_to_chromhmm_states = {
    "H3K4me1": ["8_EnhG2", "9_EnhA1"],
    "H3K4me3": ["1_TssA", "2_TssFlnk"],
    "H3K27ac": ["1_TssA", "3_TssFlnkU"],
    "H3K36me3": ["5_Tx", "7_EnhG1", "8_EnhG2"],
    "H3K27me3": ["14_TssBiv", "15_EnhBiv", "16_ReprPC"],
}
for hist, states in hist_to_chromhmm_states.items():
    datatype_loci[hist].extend(
        "cd14_{}_chromhmm18.hg19.{}".format(vers, s)
        for s in states
        for vers in ("encsr511wof", "encsr907lcd")
    )

# Loci shared by every datatype, plus TSS / consensus-peak loci for histone marks only.
for dt in ["H3K4me1", "H3K4me3", "H3K27ac", "H3K36me3", "H3K27me3", "meth"]:
    extra = ["random1000x10000", "hg19_100000"]
    if dt.startswith("H"):
        extra += [
            "coding_tss[-2000..2000]",
            "{}_zinbra_median_consensus".format(dt),
            "{}_zinbra_weak_consensus".format(dt),
        ]
    datatype_loci[dt].extend(extra)

signals_root = Path("/mnt/stripe/bio/experiments/signal_experiment")

In [None]:
# Every folder (at any depth below signals_root/<datatype>) whose name is one of that datatype's loci.
folders = [
    folder
    for dt, loci in datatype_loci.items()
    if (signals_root / dt).exists()
    for folder in walk_folders(signals_root / dt)
    if folder.name in loci
]

paths = []
for folder in folders:
    paths.extend(collect_paths(folder))
print("Paths: ", len(paths))
print("Loci folders: {}".format(len(folders)), *map(str, folders), sep="\n  ")

In [None]:
# Cached permutation-test results; the file name encodes the number of simulations.
output_path = signals_root.joinpath("validate.norms.permutation_r2.{}.csv".format(simulations))
print(str(output_path), "[exists]" if output_path.exists() else "[not exists]")

In [None]:
# TODO: when clear_cache is set, should the cached per-path results be cleared too?

In [None]:
# Run the expensive permutation test only when no cached result exists (or a refresh is forced).
if clear_cache or not output_path.exists():
    process(paths, str(output_path), threads=threads, simulations=simulations, seed=100)

print("Results file: ", str(output_path))

In [None]:
# `DataFrame.from_csv` was deprecated in pandas 0.21 and removed in 1.0; `read_csv`
# is the supported equivalent (no index column, so a default RangeIndex is used).
df = pd.read_csv(output_path)
df["loci"] = [Path(f).name for f in df["file"]]
print("Shape:", df.shape)
df.head(10)

## Summary: mean ± SD bars per modification

In [None]:
def plot_summary_avg_sd_bars(df, metric):
    """Bar chart of `metric` per normalization with one bar per modification.

    Error bars show the standard deviation; the y axis is log-scaled and the
    legend is placed outside the axes on the right.
    """
    error_bars = dict(ci="sd", capsize=.2, errwidth=1)
    plt.figure(figsize=(14, 6))
    ax = sns.barplot(data=df, x="normalization", y=metric, hue="modification",
                     alpha=0.7, edgecolor="black", **error_bars)
    plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
    ax.set_yscale('log')
    plt.xticks(rotation=-90)
    plt.show()

In [None]:
plot_summary_avg_sd_bars(df, metric="p2")

In [None]:
plot_summary_avg_sd_bars(df, metric="mean")

In [None]:
plot_summary_avg_sd_bars(df, metric="median")

In [None]:
plot_summary_avg_sd_bars(df, metric="wdist")

## Summary Boxplots by Modification

In [None]:
def plot_summary_boxplot_with_dtype(df, metric):
    """Boxplots of `metric` per normalization, split by modification (hue)."""
    plt.figure(figsize=(16, 10))
    sns.boxplot(data=df, x="normalization", y=metric, hue="modification")
    plt.xticks(rotation=-90)
    plt.title("Signal R2 at selected loci")
    plt.show()

In [None]:
# The results frame was missing here (TypeError: missing argument 'metric').
plot_summary_boxplot_with_dtype(df, "p2")

In [None]:
plot_summary_boxplot_with_dtype(df, metric="mean")

In [None]:
plot_summary_boxplot_with_dtype(df, metric="median")

In [None]:
plot_summary_boxplot_with_dtype(df, metric="wdist")

## Summary Boxplots (methylation excluded)

In [None]:
def plot_summary_boxplot(df, metric):
    """Boxplots of `metric` per normalization, pooling all modifications except "meth"."""
    without_meth = df[df.modification != "meth"]
    plt.figure(figsize=(10, 7))
    sns.boxplot(data=without_meth, x="normalization", y=metric)
    plt.xticks(rotation=-90)
    plt.title("Signal R2 at selected loci")
    plt.show()

In [None]:
plot_summary_boxplot(df, metric="p2")

In [None]:
plot_summary_boxplot(df, metric="mean")

In [None]:
plot_summary_boxplot(df, metric="median")

In [None]:
plot_summary_boxplot(df, metric="wdist")

## Boxplots by Normalization

In [None]:
def plot_boxplot_by_norm(df, metric, ylim):
    """Grid of boxplots: one subplot per normalization, `metric` per modification.

    Args:
        df: results table with "normalization", "modification" and `metric` columns.
        metric: column plotted on the y axis.
        ylim: (bottom, top) y-axis limits applied to every subplot.
    """
    norms = sorted(set(df["normalization"]))
    if not norms:
        # A 0x0 grid would otherwise request a zero-sized figure.
        print("No normalizations to plot")
        return

    # Near-square grid without trailing empty rows.
    n_cols = math.ceil(np.sqrt(len(norms)))
    n_rows = math.ceil(len(norms) / n_cols)

    plt.figure(figsize=(n_cols * 5, n_rows * 5))
    for i, norm in enumerate(norms, 1):
        ax = plt.subplot(n_rows, n_cols, i)
        sns.boxplot(data=df[df["normalization"] == norm], y=metric, x="modification", ax=ax)
        for label in ax.get_xticklabels():
            label.set_rotation(-90)
        # No legend: the plot has no hue, so `ax.legend` only produced an empty box + warning.
        ax.set_title(norm)
        ax.set_ylim(ylim)

    plt.subplots_adjust(top=0.92, bottom=0.08, left=0.10, right=0.95, hspace=0.35, wspace=0.35)
    plt.show()

In [None]:
plot_boxplot_by_norm(df, metric="p2", ylim=(0.7, 1.0))

In [None]:
plot_boxplot_by_norm(df, metric="mean", ylim=(0.9, 1.0))

In [None]:
plot_boxplot_by_norm(df, metric="wdist", ylim=(0, 0.2))

## Detailed by Locus

In [None]:
def plot_detailed_by_locus(df, metric):
    """One bar chart per normalization: `metric` for every locus, split by modification."""
    for norm in sorted(set(df["normalization"])):
        subset = df[df["normalization"] == norm]
        plt.figure(figsize=(10, 6))
        ax = sns.barplot(data=subset, x="loci", y=metric, hue="modification")
        plt.setp(ax.get_xticklabels(), rotation=-90)
        plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
        plt.title(norm)
        plt.show()

In [None]:
plot_detailed_by_locus(df, metric="p2")

In [None]:
plot_detailed_by_locus(df, metric="mean")

In [None]:
plot_detailed_by_locus(df, metric="wdist")

TODO: So far the candidate normalizations are "rawq", "fripz" or "fripm"; pick one.

# Pearson Correlation of Signals per Normalization

## Code

In [None]:
from scipy.stats import pearsonr
from collections import Counter
import re
from multiprocessing import Pool
import math
from itertools import chain

# Histone marks, in the order used for heatmap row/column blocks and ticks.
HIST_MODS = "H3K4me1 H3K4me3 H3K27ac H3K27me3 H3K36me3".split()
def sorter_by_hist_and_donor(hist_ordering):
    """Return a sort-key function for sample labels.

    The key is ``(hist_ordering[hist], age, idx)`` with `hist` taken from `hist_for`.
    Y20O20 labels (e.g. "OD_OD11_H3K4me1") give age "OD"/"YD" and the integer donor
    number; any other label (e.g. ENCODE "CD14_GSM1003562_Broad_r0_H3K36me3") gives
    age "_ENCODE_" and the remaining label text, so within a mark ENCODE samples
    sort after the Y20O20 ones ("_" > "Y" > "O" in ASCII).
    """
    def key(label):
        hist = hist_for(label)
        donor = label
        for group_prefix in ("OD_", "YD_"):
            donor = donor.replace(group_prefix, "")
        donor = donor.replace("_" + hist, "")

        if donor[:2] in ("OD", "YD"):
            age, idx = donor[:2], int(donor[2:])
        else:
            age, idx = "_ENCODE_", donor
        return hist_ordering[hist], age, idx

    return key

def hist_for(label):
    """Extract the histone mark (e.g. "H3K4me1") from a sample label.

    The mark is the last "_"-separated token of the form H<digits>K<alnum> (case-insensitive).
    Labels without such a token are returned unchanged.
    """
    match = re.match(r"((.*_)|^)(H[0-9]+K[a-z0-9]+).*", label, flags=re.IGNORECASE)
    if match is None:
        return label
    return match.group(3)

def color_annotator_hist(label):
    """Heatmap annotation ``(("Histone", color),)`` for a label's histone mark.

    Raises KeyError when the mark found by `hist_for` has no color assigned.
    """
    hist_colors = {
        'H3K4me1': 'lightblue',
        'H3K4me3': 'red',
        'H3K27ac': 'black',
        'H3K27me3': 'green',
        'H3K36me3': 'lightgray',
    }
    hist = hist_for(label)
    return (("Histone", hist_colors[hist]),)

def encode_annotator(label):
    """Heatmap annotation ``(("Dataset", color),)``: Y20O20 ("OD"/"YD" prefix) vs. any other label."""
    is_y20o20 = label.startswith(("OD", "YD"))
    return (("Dataset", "mediumseagreen" if is_y20o20 else "black"),)

def norms_for_loci(signals_root, loci):
    """Collect signal files for `loci`, grouped by normalization and histone mark.

    Searches each ``signals_root / <hist>`` directory (hist in HIST_MODS) for folders,
    at any depth, named exactly `loci`, and gathers their files via `collect_paths`.

    Returns:
        defaultdict ``{normalization: {datatype: [path, ...]}}`` where normalization and
        datatype come from `extract_normalization` / `extract_datatype`.
    """
    folders = []
    for dt in HIST_MODS:
        dt_root = signals_root / dt
        if dt_root.exists():
            folders.extend(f for f in walk_folders(dt_root) if f.name == loci)

    paths = [p for folder in folders for p in collect_paths(folder)]
    print("Paths: ", len(paths))
    print("Loci folders: {}".format(len(folders)), *[str(p) for p in folders], sep="\n  ")

    norm_2_paths = defaultdict(dict)
    for p in paths:
        dt = extract_datatype(p)
        norm = extract_normalization(p)
        # Accumulate rather than overwrite: several loci folders may provide a file for
        # the same (norm, dt); `= [p]` silently kept only the last one, although the
        # consumers iterate these values as lists.
        norm_2_paths[norm].setdefault(dt, []).append(p)

    return norm_2_paths

def _load_signal_columns(norm_paths):
    """Load every non-input signal column from a ``{datatype: [path, ...]}`` mapping.

    Each path is a tab-separated table with "chr", "start", "end" columns plus one
    column per sample; columns whose name contains "input" (any case) are skipped.
    Returns ``(names, columns)`` where ``columns[i]`` is the Series named ``names[i]``.
    """
    names, columns = [], []
    for paths in norm_paths.values():
        for path in paths:
            # `DataFrame.from_csv` was removed in pandas 1.0; `read_csv` replaces it.
            signal = pd.read_csv(path, sep="\t").drop(["chr", "start", "end"], axis=1)
            keep = [c for c in signal.columns if "input" not in c.lower()]
            signal = signal.loc[:, keep]
            names.extend(signal.columns)
            columns.extend(signal.iloc[:, i] for i in range(signal.shape[1]))
    return names, columns


def _pairwise_pearson(columns, threads, timeout_mins):
    """Symmetric float32 matrix of Pearson r between every pair of `columns`."""
    n = len(columns)
    corr = np.zeros((n, n), np.float32)

    if threads == 1:
        for i in range(n):
            for j in range(i, n):
                corr[i, j] = corr[j, i] = pearsonr(columns[i], columns[j])[0]
        return corr

    # Upper triangle (incl. diagonal) split into ~`threads` equally sized chunks. Every
    # task receives a pickled copy of all columns, so the chunk count is kept small.
    pairs = [(i, j) for i in range(n) for j in range(i, n)]
    chunk_size = max(1, math.ceil(len(pairs) / threads))
    with Pool(processes=threads) as pool:
        tasks = [
            pool.apply_async(pearsonr_list, (columns, pairs[s:s + chunk_size]))
            for s in range(0, len(pairs), chunk_size)
        ]
        for task in tasks:
            for r, (i, j) in task.get(timeout=60 * timeout_mins):
                corr[i, j] = corr[j, i] = r
    return corr


def pearson_corr_df(norm, norm_2_paths, threads):
    """Pairwise Pearson correlation of all non-input signal columns for one normalization.

    Args:
        norm: key into `norm_2_paths`.
        norm_2_paths: ``{norm: {hist: [path, ...]}}``; the marks for `norm` must be
            exactly HIST_MODS.
        threads: 1 computes serially; otherwise a `multiprocessing.Pool` of that size
            processes the column pairs in chunks (10 minute timeout per chunk).

    Returns:
        Square float32 DataFrame whose index and columns are the signal column names.

    Raises:
        AssertionError: if the histone marks available for `norm` differ from HIST_MODS.
    """
    timeout_mins = 10

    norm_paths = norm_2_paths[norm]
    print("[{}]".format(norm))
    print("    Paths:", *[str(p) for p in norm_paths.values()], sep="\n      ")
    assert set(norm_paths.keys()) == set(HIST_MODS), \
        "[{}] expected marks {}, got {}".format(norm, sorted(HIST_MODS), sorted(norm_paths))

    print("    Load data...")
    names, columns = _load_signal_columns(norm_paths)

    print("    Calculating pearson correlation")
    corr = _pairwise_pearson(columns, threads, timeout_mins)
    return pd.DataFrame(data=corr, columns=names, index=names)
    
def pearsonr_list(cols, pairs):
    """Pearson r for each ``(i, j)`` in `pairs`, as ``[(r, (i, j)), ...]`` in input order.

    Submitted as a `multiprocessing.Pool` task by `pearson_corr_df`.
    """
    results = []
    for i, j in pairs:
        r = pearsonr(cols[i], cols[j])[0]
        results.append((r, (i, j)))
    return results

def pearson_corr(norm, loci, norm_2_paths, pdf, threads, data_path=None):
    """Draw the Pearson-correlation heatmap of all signal columns for one normalization.

    The correlation table is read from `data_path` when that file exists; otherwise it
    is computed with `pearson_corr_df` and, if `data_path` is given, written there.
    Rows/columns are ordered by (mark in HIST_MODS order, age group, donor) and one
    axis tick is placed in the middle of each histone block.

    Args:
        norm: normalization name; key into `norm_2_paths`, also used in the title.
        loci: loci name, used in the title only.
        norm_2_paths: ``{norm: {hist: [path, ...]}}`` mapping of signal files.
        pdf: an open `PdfPages` to append the figure to, or None to skip saving.
        threads: parallelism passed to `pearson_corr_df`.
        data_path: optional `Path` of the CSV cache for the correlation table.
    """
    if data_path and data_path.exists():
        # `DataFrame.from_csv` was removed in pandas 1.0; `read_csv` replaces it.
        df = pd.read_csv(data_path, index_col=0)
    else:
        df = pearson_corr_df(norm, norm_2_paths, threads)
        if data_path:
            df.to_csv(data_path)
            print("Corr table saved to: {}".format(data_path))

    hist_mods = HIST_MODS
    hist_rank = {h: i for i, h in enumerate(hist_mods)}
    sorted_cols = sorted(df.columns, key=sorter_by_hist_and_donor(hist_rank))
    df = df.loc[sorted_cols, sorted_cols]

    # Labels not starting with OD/YD are ENCODE samples: add a "Dataset" color strip.
    has_encode = any(not c.startswith("OD") and not c.startswith("YD") for c in df.columns)
    if has_encode:
        ann = bm.color_annotator_chain(color_annotator_hist, encode_annotator)
    else:
        ann = color_annotator_hist
    g = bm.plot_metric_heatmap("{}: Pearson Correlation for {}".format(norm, loci), df,
                               show_or_save_plot=False, cbar=False,
                               col_color_annotator=ann,
                               row_color_annotator=ann)

    # After the sort each mark forms a contiguous block: tick the middle of each block.
    hist_counts = Counter(hist_for(label) for label in df.columns)
    block_sizes = [hist_counts[h] for h in hist_mods]
    ticks = [sum(block_sizes[:k]) + size / 2 for k, size in enumerate(block_sizes)]
    g.ax_heatmap.set_xticks(ticks)
    g.ax_heatmap.set_xticklabels(hist_mods, rotation="horizontal", horizontalalignment='center')
    g.ax_heatmap.set_yticks(ticks)
    g.ax_heatmap.set_yticklabels(hist_mods, rotation="vertical", verticalalignment='center')

    # Hide the x ticks of the row color-annotation strip.
    g.ax_row_colors.set_xticks([])

    if pdf is not None:
        pdf.savefig()
    plt.show()
    plt.close()

In [None]:
import traceback
def process_pearson_corr(signals_root, loci, threads, prefix="pearson_"):
    """Save one Pearson-correlation heatmap PDF per normalization available for `loci`.

    Outputs ``signals_root / "<prefix><loci>_<norm>.pdf"`` and a ``.csv`` cache of the
    correlation table next to it (removed first when `clear_cache` is set). A failure
    for one normalization is reported with its traceback and the loop continues.
    """
    norm_2_paths = norms_for_loci(signals_root, loci)
    for norm in sorted(norm_2_paths):
        try:
            plot_path = signals_root / "{}{}_{}.pdf".format(prefix, loci, norm)
            data_path = plot_path.with_suffix(".csv")
            if clear_cache and data_path.exists():
                # Pure-Python replacement for the `! rm $data_path` shell magic:
                # portable, and silent when there is no cached table yet.
                data_path.unlink()

            with PdfPages(str(plot_path)) as pdf:
                pearson_corr(norm, loci, norm_2_paths, pdf, threads, data_path)
            print("Saved: ", plot_path)

        except Exception as e:
            # Deliberately best-effort: report and continue with the next normalization.
            print("Failed:", norm, e)
            traceback.print_exc()

## Y20O20

In [None]:
process_pearson_corr(
    Path("/mnt/stripe/bio/experiments/signal_experiment"),
    loci="hg19_100000",
    threads=threads,
    prefix="pearson_y20o20_",
)

In [None]:
# threads=1 uses the serial path; presumably chosen because the parallel path pickles
# all columns into every worker task, which is costly for these larger tables — TODO confirm.
process_pearson_corr(
    Path("/mnt/stripe/bio/experiments/signal_experiment"),
    loci="hg19_1000",
    threads=1,
    prefix="pearson_y20o20_",
)

## Encode

In [None]:
process_pearson_corr(
    Path("/mnt/stripe/bio/experiments/signal_experiment_encode"),
    loci="hg19_100000",
    threads=threads,
    prefix="pearson_encode_",
)

In [None]:
process_pearson_corr(
    Path("/mnt/stripe/bio/experiments/signal_experiment_encode"),
    loci="hg19_1000",
    threads=1,
    prefix="pearson_encode_",
)

## Y20O20 + ENCODE

In [None]:
import traceback
def process_enc_aging_pearson_corr(loci, threads):
    """Save a joint Y20O20 + ENCODE Pearson-correlation heatmap PDF per normalization.

    Only normalizations for which both datasets provide all HIST_MODS are plotted.
    The PDF and its ``.csv`` cache are written under
    ``/mnt/stripe/bio/experiments/signal_experiment/``; the cache is removed first when
    `clear_cache` is set. A failure for one normalization is reported and skipped.
    """
    root = Path("/mnt/stripe/bio/experiments")
    encode_norm_2_paths = norms_for_loci(root / "signal_experiment_encode", loci)
    y20o20_norm_2_paths = norms_for_loci(root / "signal_experiment", loci)

    for norm in sorted(y20o20_norm_2_paths):
        if norm not in encode_norm_2_paths:
            print("No {} for ENCODE".format(norm))
            # Skip now; previously this fell through to a second, misleading
            # "ENCODE missed hist modes" message before skipping.
            continue

        if set(HIST_MODS) != set(y20o20_norm_2_paths[norm]):
            print("[{}]: Y20O20 missed hist modes, available only {}".format(
                norm, sorted(y20o20_norm_2_paths[norm].keys())
            ))
            continue

        if set(HIST_MODS) != set(encode_norm_2_paths[norm]):
            print("[{}]: ENCODE missed hist modes, available only {}".format(
                norm, sorted(encode_norm_2_paths[norm].keys())
            ))
            continue

        try:
            plot_path = root / "signal_experiment/pearson_enc_vs_y20o20_{}_{}.pdf".format(loci, norm)
            data_path = plot_path.with_suffix(".csv")
            if clear_cache and data_path.exists():
                # Pure-Python replacement for the `! rm $data_path` shell magic.
                data_path.unlink()

            merged = {norm: {h: y20o20_norm_2_paths[norm][h] + encode_norm_2_paths[norm][h]
                             for h in HIST_MODS}}
            with PdfPages(str(plot_path)) as pdf:
                # Pass the open PdfPages: passing None meant the figure was never
                # written, although "Saved" was still printed.
                pearson_corr(norm, loci, merged, pdf, threads, data_path)
            print("Saved: {}".format(plot_path))

        except Exception as e:
            # Deliberately best-effort: report and continue with the next normalization.
            print("Failed:", norm, e)
            traceback.print_exc()

In [None]:
process_enc_aging_pearson_corr(loci="hg19_100000", threads=threads)

In [None]:
process_enc_aging_pearson_corr(loci="hg19_1000", threads=1)

# H3K27me3 + DiffBind scores options

In [None]:
signals_root = Path("/mnt/stripe/bio/experiments/k27me3@dmrs")
# `Path.glob` order is filesystem-dependent; sort so every run sees the inputs in the
# same order (and the listing below is stable).
paths = sorted(signals_root.glob("k*_counts.csv"))
[p.name for p in paths]

In [None]:
# Cached permutation-test results for the H3K27me3 DiffBind-score files.
output_path = signals_root.joinpath("validate.norms.permutation_r2.{}.csv".format(simulations))
print(str(output_path), "[exists]" if output_path.exists() else "[not exists]")

In [None]:
# Honor `clear_cache` here too, consistent with the histone-modification run earlier.
if clear_cache or not output_path.exists():
    process(paths, str(output_path), seed=100, simulations=simulations, threads=threads)
print("Results file: ", str(output_path))

In [None]:
# `DataFrame.from_csv` was removed in pandas 1.0; `read_csv` is the replacement
# (no index column, so a default RangeIndex is used).
df = pd.read_csv(output_path)
df["loci"] = [Path(f).name for f in df["file"]]
print("Shape:", df.shape)
df