# Test normalizations
## basic settings and imports first

In [1]:
from bokeh.plotting import figure
from bokeh.palettes import viridis
from bokeh.io import show, output_notebook
from bokeh.models import ColumnDataSource, HoverTool, FactorRange
from bokeh.layouts import column, row, gridplot

import json
import sys
import shelve
import random
from bokeh.models import Legend
from math import log10


output_notebook()

# Length units; the names suggest kilo-/mega-base pairs.
KBP = 1000
MBP = KBP * 1000

# Window and bin sizes swept by the comparison plots below.
WINDOW_SIZES = [factor * MBP for factor in (1, 5, 10)]
BIN_SIZES = [factor * KBP for factor in (50, 100, 500)]
# Reduced sweep for quick runs:
# WINDOW_SIZES = [5 * MBP]
# BIN_SIZES = [500 * KBP]

# Sizes held fixed while the other dimension is varied (the middle entries).
# Alternative: index 0 of each list.
curr_win_size = WINDOW_SIZES[1]
curr_bin_size = BIN_SIZES[1]

# Value assigned to figure.output_backend by some of the plots.
OUTPUT_BACKEND = "svg"
# Sample counts to evaluate; every normalization shares this same list object.
NUM_SAMPLES = [1, 2, 10, 25, 100, 250, 1000, 2500]
NUM_SAMPLES_ICE = NUM_SAMPLES_GRID_SEQ = NUM_SAMPLES_RADICL_SEQ = NUM_SAMPLES_DDD = NUM_SAMPLES

CHECK_THAT_HEATMAPS_ARE_EXACTLY_THE_SAME = False

## Load data

In [2]:
def get_mean_dev(ground_truth, points):
    """Return the mean absolute deviation between *ground_truth* and *points*.

    Values are paired positionally. If the two sequences differ in length,
    the surplus values of the longer one are ignored and the mean is taken
    over the compared pairs only.

    Raises ValueError if there is no pair to compare.
    """
    deviations = [abs(a - b) for a, b in zip(ground_truth, points)]
    if not deviations:
        raise ValueError("get_mean_dev() needs at least one (ground_truth, point) pair")
    # Divide by the number of compared pairs, not len(points): zip stops at
    # the shorter input, so len(points) can overcount.
    return sum(deviations) / len(deviations)

class ShelveTupleDict:
    """Dict-like wrapper around a ``shelve`` file that accepts tuple keys.

    Each tuple key is stored under the string formed by joining ``str()`` of
    its elements with ".", e.g. ``("ice", "GT", 100000, "bin_vals")`` is
    stored as ``"ice.GT.100000.bin_vals"``. Tuples whose joined text is
    identical (e.g. ``("a.b",)`` and ``("a", "b")``) therefore address the
    same entry. The encoding is kept unchanged so existing shelf files remain
    readable.

    Can be used as a context manager so the underlying shelf is always closed.
    """

    def __init__(self, filename, flag='c'):
        # flag is passed straight through to shelve.open (e.g. 'r', 'c').
        self.data = shelve.open(filename, flag=flag)

    @staticmethod
    def _to_key(tup):
        """Encode a tuple key as the dotted string used inside the shelf."""
        return ".".join(str(x) for x in tup)

    def __getitem__(self, tup):
        return self.data[self._to_key(tup)]

    def __contains__(self, tup):
        return self._to_key(tup) in self.data

    def __setitem__(self, tup, val):
        self.data[self._to_key(tup)] = val

    def get(self, tup, default=None):
        """Return the value stored for *tup*, or *default* if it is missing."""
        try:
            return self[tup]
        except KeyError:
            return default

    def close(self):
        self.data.close()

    def keys(self):
        """Return the raw (dotted string) keys of the underlying shelf."""
        return self.data.keys()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

# Open the precomputed results read-only; entries are addressed by tuple keys.
data = ShelveTupleDict(filename="norm_corelation.shelf", flag="r")
# data = ShelveTupleDict("norm_corelation_bin_test.shelf", 'r')

## Investigate the scatter plot for one bin and window size

Expect poor correlation for a low number of samples; it should improve gradually as the number of samples grows.

In [12]:
ALMOST_ZERO = 10**-5


def plot_scatter_points(ground_truth, data, title, x_range=(ALMOST_ZERO, 10**0), y_range=(ALMOST_ZERO, 10**5)):
    """Scatter each sample series against the ground truth on log-log axes.

    ground_truth: reference values, paired positionally with every series.
    data: re-iterable of (name, points, color) triples. Series whose color is
        None get consecutive viridis colours; the others use the given colour.
    title: figure title.
    x_range, y_range: axis ranges of the log-log plot.

    Pairs where either value is 0 are dropped, because they cannot be placed
    on a log axis. At most 500 randomly chosen pairs are drawn per series.
    Each legend entry shows get_mean_dev over all of the series' pairs,
    rounded to 5 decimals. A black line marks y == x.
    """
    auto_colors = viridis(sum(1 for _name, _points, color in data if color is None))
    fig = figure(
        title=title,
        x_axis_type="log",
        y_axis_type="log",
        x_range=x_range,
        y_range=y_range,
        frame_height=500,
        frame_width=500,
    )
    fig.line(x=[ALMOST_ZERO, 100], y=[ALMOST_ZERO, 100], color="black")

    legend_entries = []
    next_auto = 0
    for name, points, color in data:
        plottable = [pair for pair in zip(ground_truth, points) if pair[0] != 0 and pair[1] != 0]
        chosen = random.sample(plottable, min(len(plottable), 500))
        xs = [gt for gt, _ in chosen]
        ys = [val for _, val in chosen]
        mean_dev = round(get_mean_dev(ground_truth, points), 5)
        if color is None:
            dot_color = auto_colors[next_auto]
            next_auto += 1
        else:
            dot_color = color
        renderer = fig.dot(x=xs, y=ys, color=dot_color, size=25, alpha=0.5)
        legend_entries.append((name + " dev: " + str(mean_dev), [renderer]))

    fig.xaxis.axis_label = "ground truth"
    fig.yaxis.axis_label = "sample"
    fig.add_layout(Legend(items=legend_entries, location="center", click_policy="hide"), "right")
    # fig.output_backend = OUTPUT_BACKEND  (left disabled)
    show(fig)

# plot_scatter_points(
#     data[("ice", "GT", curr_bin_size, "bin_vals")], 
#     [("unnormalized", data[("raw_hic", "GT", curr_bin_size, "bin_vals")], "red")] +
#     [("num samples = " + str(num_samples), 
#       data[("ice",  curr_win_size, num_samples, curr_bin_size, "bin_vals")], None) for num_samples in NUM_SAMPLES_ICE  if ("ice",  curr_win_size, num_samples, curr_bin_size, "bin_vals") in data], 
#     "icing - stability num samples; bin-size= " + str(curr_bin_size // KBP) + "k window-size= " + str(curr_win_size//KBP) + "k" )

# plot_scatter_points(
#     data[("grid_seq", "GT", curr_bin_size, "bin_vals")], 
#     [("unnormalized", data[("raw_radicl", "GT", curr_bin_size, "bin_vals")], "red")] +
#     [("num samples = " + str(num_samples), 
#       data[("grid_seq", curr_win_size, num_samples, curr_bin_size, "bin_vals")], None) for num_samples in NUM_SAMPLES_GRID_SEQ
#       ], 
#     "grid-seq - stability num samples; bin-size= " + str(curr_bin_size // KBP) + "k window-size= " + str(curr_win_size//KBP) + "k",
#     x_range=(10**-3,10**2), y_range=(10**-3, 10**4))

# plot_scatter_points(
#     data[("radicl_seq", "GT", curr_bin_size, "bin_vals")], 
#     [("unnormalized", data[("raw_radicl", "GT", curr_bin_size, "bin_vals")], "red")] +
#     [("num samples = " + str(num_samples), 
#       data[("radicl_seq", curr_win_size, num_samples, curr_bin_size, "bin_vals")], None) for num_samples in NUM_SAMPLES_RADICL_SEQ], 
#     "radicl-seq - stability num samples; bin-size= " + str(curr_bin_size // KBP) + "k window-size= " + str(curr_win_size//KBP) + "k",
#     x_range=(10**-3,10**2), y_range=(10**-3, 10**4))

# plot_scatter_points(
#     data[("ddd", "GT", curr_bin_size, "bin_vals")], 
#     [("unnormalized", data[("raw_hic", "GT", curr_bin_size, "bin_vals")], "red")] +
#     [("num samples = " + str(num_samples), 
#       data[("ddd", curr_win_size, num_samples, curr_bin_size, "bin_vals")], None) for num_samples in NUM_SAMPLES_DDD], 
#     "distance dependent decay - stability num samples; bin-size= " + str(curr_bin_size // KBP) + "k window-size= " + str(curr_win_size//KBP) + "k" )

# Compare the ICE result with cooler's at BIN_SIZES[2]. The title is built
# from the same bin size as the data, so the label cannot disagree with it.
cooler_cmp_bin_size = BIN_SIZES[2]
plot_scatter_points(
    data[("cooler", "GT", cooler_cmp_bin_size, "bin_vals")],
    [("", data[("ice", "GT", cooler_cmp_bin_size, "bin_vals")], "blue")],
    "icing - cooler vs my implementation; bin-size= " + str(cooler_cmp_bin_size // KBP) + "k ")

## Plot mean deviation & runtime as a function of the number of samples for all bin and window sizes

In [None]:
# Series colours, indexed modulo their length. These hex values appear to be
# the Okabe-Ito colour-blind-friendly palette.
COLOR_PALETTE = ["#0072B2", "#D55E00", "#009E73", "#E69F00", "#CC79A7", "#56B4E9", "#F0E442"]
# Marker shapes, indexed modulo their length (only used by commented-out plotting code).
SCATTER_PALETTE = ["x", "circle", "triangle", "+", "square"]


def get_median(xs):
    """Return the median of *xs*.

    For an odd number of values this is the middle element of the sorted
    values. For an even number it is the mean of the two middle elements, so
    the values must support ``+`` and ``/``. Any iterable is accepted.

    Raises ValueError if *xs* is empty.
    """
    ordered = sorted(xs)
    count = len(ordered)
    if count == 0:
        raise ValueError("get_median() of an empty sequence")
    mid = count // 2
    if count % 2 == 1:
        return ordered[mid]
    # Even count: average the two middle values instead of picking the upper one.
    return (ordered[mid - 1] + ordered[mid]) / 2

def log10_histogram(values, zero_val, granularity):
    """Count *values* in equal-width bins on a log10 scale.

    Each value is first clamped to at least *zero_val*, so zeros can be placed
    on a log axis. It is then assigned to bin
    ``floor(log10(v) / granularity)``.

    Returns ``(min_bin, counts)``, where ``counts[k]`` is the number of values
    in bin ``min_bin + k``, i.e. in
    ``[10**((min_bin + k) * granularity), 10**((min_bin + k + 1) * granularity))``.

    Raises ValueError if *values* is empty.
    """
    from math import floor  # only log10 is imported from math at notebook level

    # floor (not int(), which truncates toward zero) keeps values below 1 in
    # the bin whose edges are actually drawn for them.
    bins = [floor(log10(max(v, zero_val)) / granularity) for v in values]
    if not bins:
        raise ValueError("log10_histogram() needs at least one value")
    lowest = min(bins)
    counts = [0] * (max(bins) - lowest + 1)
    for b in bins:
        counts[b - lowest] += 1
    return lowest, counts


def corr_as_func_of_samples(title, conditions, w_patch=False):
    """Plot mean deviation and mean runtime against the number of samples.

    title: title of both figures.
    conditions: list of (ground_truth, get_sample, samples, name, idxa, idxb)
        tuples, one per plotted series. ``get_sample(*idx)`` must return
        (num_samples, points, runtimes) for every ``idx`` in ``samples``.
        ``runtimes`` is a dict whose values are sequences that are zipped and
        summed element-wise. idxa selects the series colour from
        COLOR_PALETTE. idxb is currently unused.
    w_patch: if True, use a categorical x axis built from the first
        condition's sample counts, and draw a violin-style log-histogram of
        the per-sample values behind each point.

    Shows the deviation and runtime figures side by side; returns None.
    """
    width = 200
    # width = 800
    granularity = 0.025  # violin histogram bin width, in log10 units

    def log_violin_plot(get_y_values, zero_val, violin_width=1.5):
        """Build one figure plotting the mean of get_y_values(...) per sample count.

        zero_val: lower clamp applied to the mean and to the histogram, so
            that zeros can be shown on the log y axis.
        violin_width: a histogram bin holding every value is drawn
            1 / violin_width factor units wide on each side.
        """
        f1 = figure(
            y_axis_type="log",
            x_range=FactorRange(*[str(num_samples) for num_samples, _ in conditions[0][2]]) if w_patch else None,
            x_axis_type=None if w_patch else "log",
            # y_range=(.1, 10000),
            frame_width=width,
            frame_height=width,
            title=title,
        )

        for ground_truth, get_sample, samples, name, idxa, idxb in conditions:
            color = COLOR_PALETTE[idxa % len(COLOR_PALETTE)]
            xs1 = []
            ys1 = []
            for idx in samples:
                num_samples, points, runtimes = get_sample(*idx)
                a = get_y_values(ground_truth, points, runtimes)
                if not a:
                    # Nothing to average or bin (e.g. an empty runtimes dict);
                    # skip the point rather than divide by zero.
                    continue
                factor = str(num_samples)
                xs1.append((factor,) if w_patch else num_samples)
                ys1.append(max(sum(a) / len(a), zero_val))
                if w_patch:
                    min_bin, counts = log10_histogram(a, zero_val, granularity)
                    norm = len(a) * violin_width
                    left = [(factor,)] + [(factor, -c / norm) for c in counts] + [(factor,)]
                    right = [(factor,)] + [(factor, c / norm) for c in counts] + [(factor,)]
                    # One y per outline vertex: a 0 anchor, each bin's lower
                    # edge, then the upper edge of the last bin.
                    ys_edges = [0] + [10 ** ((k + min_bin) * granularity) for k in range(len(counts) + 1)]
                    f1.patch(
                        left + right[::-1],
                        ys_edges + ys_edges[::-1],
                        fill_color=color,
                        line_color=None,
                        fill_alpha=0.2,
                    )
            f1.line(x=xs1, y=ys1, color=color, legend_label=name, line_width=2)
            f1.dot(x=xs1, y=ys1, color=color, legend_label=name, size=20)

        f1.output_backend = OUTPUT_BACKEND
        f1.xaxis.axis_label = "number of samples"
        f1.legend.visible = False
        f1.toolbar_location = "below"
        return f1

    f = log_violin_plot(
        lambda ground_truth, points, runtimes: [abs(a - b) for a, b in zip(ground_truth, points)],
        0.00001, 0.35)
    # Sums the runtime sequences element-wise and divides by 1000; presumably
    # microseconds -> milliseconds to match the axis label (confirm units).
    f_2 = log_violin_plot(
        lambda ground_truth, points, runtimes: [sum(tmp) / 1000 for tmp in zip(*runtimes.values())],
        0.1, 1.25)
    f.yaxis.axis_label = "deviation"
    f_2.yaxis.axis_label = "runtime [ms]"
    show(row([f, f_2]), notebook_handle=True)

def plot_corr_as_func_of_samples(key="ice", samples=NUM_SAMPLES_ICE):
    """Plot deviation/runtime versus sample count for one normalization *key*.

    Two families of series are passed to corr_as_func_of_samples:
      * one per entry of BIN_SIZES, at the fixed window size curr_win_size;
      * one per entry of WINDOW_SIZES other than curr_win_size, at the fixed
        bin size curr_bin_size.
    Only sample counts from *samples* whose "bin_vals" entry exists in
    ``data`` are included. Colour indices continue across both families.
    """

    def fetch_for_bin(num_samples, bin_size):
        # (num_samples, bin values, runtimes) at the fixed window size.
        return (num_samples,
                data[(key, curr_win_size, num_samples, bin_size, "bin_vals")],
                data[(key, curr_win_size, num_samples, bin_size, "runtimes")])

    def fetch_for_window(num_samples, window_size):
        # (num_samples, bin values, runtimes) at the fixed bin size.
        return (num_samples,
                data[(key, window_size, num_samples, curr_bin_size, "bin_vals")],
                data[(key, window_size, num_samples, curr_bin_size, "runtimes")])

    conditions = []

    for color_idx, bin_size in enumerate(BIN_SIZES):
        ground_truth = data[(key, "GT", bin_size, "bin_vals")]
        present = [(n, bin_size) for n in samples
                   if (key, curr_win_size, n, bin_size, "bin_vals") in data]
        label = key + " bin=" + str(bin_size // KBP) + "k win=" + str(curr_win_size // KBP) + "k"
        conditions.append((ground_truth, fetch_for_bin, present, label, color_idx, color_idx))

    other_windows = [w for w in WINDOW_SIZES if w != curr_win_size]
    for offset, window_size in enumerate(other_windows):
        color_idx = len(BIN_SIZES) + offset
        ground_truth = data[(key, "GT", curr_bin_size, "bin_vals")]
        present = [(n, window_size) for n in samples
                   if (key, window_size, n, curr_bin_size, "bin_vals") in data]
        label = key + " bin=" + str(curr_bin_size // KBP) + "k win=" + str(window_size // KBP) + "k"
        conditions.append((ground_truth, fetch_for_window, present, label, color_idx, color_idx))

    corr_as_func_of_samples(key, conditions)

# plot_corr_as_func_of_samples("ice", NUM_SAMPLES_ICE)
# plot_corr_as_func_of_samples("ddd", NUM_SAMPLES_DDD)
# plot_corr_as_func_of_samples("grid_seq", NUM_SAMPLES_GRID_SEQ)
# Only the RADICL-seq sweep is active; the other normalizations are commented out above.
plot_corr_as_func_of_samples(key="radicl_seq", samples=NUM_SAMPLES_RADICL_SEQ)


def get_samples(num_samples, key):
    """Return (num_samples, bin values, runtimes) for normalization *key*.

    Reads from ``data`` at the module-level curr_win_size and curr_bin_size.
    """
    def lookup(field):
        return data[(key, curr_win_size, num_samples, curr_bin_size, field)]

    return num_samples, lookup("bin_vals"), lookup("runtimes")
# conditions_key = [
#     (
#         data[(key, "GT", curr_bin_size, "bin_vals")], 
#         get_samples,
#         [(num_samples, key) for num_samples in NUM_SAMPLES], 
#         key + " bin=" + str(curr_bin_size//KBP) + "k win=" + str(curr_win_size//KBP) + "k",
#         idxb,
#         idxb
#     ) 
#     for idxb, key in enumerate(["ice", "ddd", "grid_seq", "radicl_seq"])
#     # for idxb, key in enumerate(["radicl_seq"])
# ]
# corr_as_func_of_samples("", conditions_key)


## Plot some of the heatmaps for visual verification

In [10]:
# Divisors applied to curr_win_size for the default heatmap view extents
# (div_radicl is used for RADICL_RANGE further down). Presumably base pairs
# per heatmap screen unit -- confirm against the code that built the heatmaps.
div_hic = 10000
div_radicl = 50000

# Default heatmap view: 0 .. (3 * curr_win_size) // div_hic on both axes.
DEFAULT_RANGE = (0, (3 * curr_win_size) // div_hic)
def filter_bins_in_view(columns, x_range, y_range):
    """Return a copy of *columns* keeping only bins that lie fully in view.

    columns: dict of column lists that includes "screen_left",
        "screen_right", "screen_bottom" and "screen_top".
    x_range, y_range: inclusive (min, max) bounds of the visible area.

    The keep-mask is computed once and then applied to every column.
    """
    if not columns:
        return {}
    keep = [
        left >= x_range[0] and right <= x_range[1] and bottom >= y_range[0] and top <= y_range[1]
        for left, right, bottom, top in zip(
            columns["screen_left"],
            columns["screen_right"],
            columns["screen_bottom"],
            columns["screen_top"],
        )
    ]
    return {key: [v for v, kept in zip(vals, keep) if kept] for key, vals in columns.items()}


def plot_heatmap(datas, bg_color, x_range=DEFAULT_RANGE, y_range=DEFAULT_RANGE):
    """Show several heatmaps side by side, sharing the first panel's ranges.

    datas: list of (columns, title, w_size) tuples. columns is a dict of
        column lists with at least screen_left/right/bottom/top and color.
        score_total, score_a and score_b feed the hover tooltip. When w_size
        is not None, white lines are drawn at every multiple of w_size
        (presumably window boundaries).
    bg_color: background colour of every panel.
    x_range, y_range: (min, max) region to draw; bins not fully inside it
        are dropped.
    """
    panels = []
    for columns, title, w_size in datas:
        if panels:
            # Reuse the first panel's range objects so all panels pan/zoom together.
            f = figure(title=title, x_range=panels[0].x_range, y_range=panels[0].y_range, width=300, height=300)
        else:
            f = figure(title=title, width=300, height=300)

        f.quad(
            left="screen_left",
            bottom="screen_bottom",
            right="screen_right",
            top="screen_top",
            fill_color="color",
            line_color=None,
            source=ColumnDataSource(data=filter_bins_in_view(columns, x_range, y_range)),
        )
        if w_size is not None:
            x_windows = range(0, x_range[1] + 1, w_size)
            y_windows = range(0, y_range[1] + 1, w_size)
            f.multi_line(
                xs=[[x, x] for x in x_windows],
                ys=[[y_range[0], y_range[1]] for _ in x_windows],
                color="white",
            )
            f.multi_line(
                xs=[[x_range[0], x_range[1]] for _ in y_windows],
                ys=[[y, y] for y in y_windows],
                color="white",
            )
        f.background_fill_color = bg_color
        f.grid.grid_line_color = None
        f.axis.ticker = []

        f.add_tools(
            HoverTool(
                tooltips=[
                    ("score", "@score_total"),
                    ("color", "@color"),
                    ("reads by group", "A: @score_a, B: @score_b"),
                    ("x", "@screen_left - @screen_right"),
                    ("y", "@screen_bottom - @screen_top"),
                ]
            )
        )
        # f.output_backend = OUTPUT_BACKEND
        panels.append(f)
    show(gridplot([panels]), notebook_handle=True)

# View extent for the RADICL heatmaps (used by the commented-out cells below).
RADICL_RANGE = (0, 3 * curr_win_size // div_radicl)
# Global ICE normalization next to cooler's result. The two panels need
# distinct titles, otherwise they cannot be told apart.
plot_heatmap([
    # (data[("raw_hic", "GT", curr_bin_size, "heatmap")], "raw data", None),
    # *[(data[("ice", curr_win_size, num_sample, curr_bin_size, "heatmap")], "ice - num samples = " + str(num_sample),
    #    curr_win_size // div_hic) for num_sample in [10, 100, 1000]],
    (data[("ice", "GT", curr_bin_size, "heatmap")], "ice - global", None),
    (data[("cooler", "GT", curr_bin_size, "heatmap")], "cooler - global", None),
    ], "#440154")

# plot_heatmap([
#     (data[("raw_radicl", "GT", curr_bin_size, "heatmap")], "raw data", None),
#     *[(data[("grid_seq", curr_win_size, num_sample, curr_bin_size, "heatmap")], "grid-seq - num samples = " + str(num_sample), curr_win_size // div_radicl) for num_sample in [10, 100, 1000]],
#     (data[("grid_seq", "GT", curr_bin_size, "heatmap")], "global", None), 
#     ], "#440154",
#     x_range=RADICL_RANGE,
#     y_range=RADICL_RANGE)

# plot_heatmap([
#     (data[("raw_radicl", "GT", curr_bin_size, "heatmap")], "raw data", None),
#     *[(data[("radicl_seq", curr_win_size, num_sample, curr_bin_size, "heatmap")], 
#        "radicl-seq - num samples = " + str(num_sample), 
#        curr_win_size // div_radicl) for num_sample in [10, 100, 1000]],
#     #(data[("radicl_seq", curr_win_size, "inf", curr_bin_size, "heatmap")], "all samples", curr_win_size // div_radicl), 
#     #(data[("radicl_seq", "GT-100", curr_bin_size, "heatmap")], "global 100 samples", None), 
#     (data[("radicl_seq", "GT", curr_bin_size, "heatmap")], "global", None), 
#     ], "#440154",
#     x_range=RADICL_RANGE,
#     y_range=RADICL_RANGE)

# plot_heatmap([
#     (data[("raw_hic", "GT", curr_bin_size, "heatmap")], "raw data", None),
#     *[(data[("ddd", curr_win_size, num_sample, curr_bin_size, "heatmap")], "ddd - num samples = " + str(num_sample), 
#        curr_win_size // div_hic)  for num_sample in [10, 100, 1000]],
#     (data[("ddd", "GT", curr_bin_size, "heatmap")], "global", None), 
#     ], "#440154")