# Test normalizations
## basic settings and imports first

In [1]:
from bokeh.plotting import figure
from bokeh.palettes import viridis
from bokeh.io import show, output_notebook
from bokeh.models import ColumnDataSource, HoverTool, FactorRange
from bokeh.layouts import column, row, gridplot

import json
import sys
import shelve
import random
from bokeh.models import Legend
from math import log10


# Configure bokeh so that later show() calls render inline in the notebook.
output_notebook()

# Length units; the names suggest kilo-/mega-base-pairs.
KBP = 10**3
MBP = KBP * 1000

# Window and bin sizes that were evaluated.
WINDOW_SIZES = [megabases * MBP for megabases in (1, 5, 10)]
BIN_SIZES = [kilobases * KBP for kilobases in (50, 100, 500)]

# The sizes used wherever a single bin or window size is needed
# (the middle entry of each list).
curr_bin_size = BIN_SIZES[1]
curr_win_size = WINDOW_SIZES[1]

# Assigned to figure.output_backend by the plotting functions.
OUTPUT_BACKEND = "svg"

# Sample counts; every normalization shares the very same list object.
NUM_SAMPLES = [10**exponent for exponent in range(4)]
NUM_SAMPLES_ICE = NUM_SAMPLES
NUM_SAMPLES_GRID_SEQ = NUM_SAMPLES
NUM_SAMPLES_RADICL_SEQ = NUM_SAMPLES
NUM_SAMPLES_DDD = NUM_SAMPLES

CHECK_THAT_HEATMAPS_ARE_EXACTLY_THE_SAME = False

## Load data

In [2]:
def get_mean_dev(ground_truth, points):
    """Return the mean absolute deviation of ``points`` from ``ground_truth``.

    Values are paired with zip(), so surplus entries of the longer sequence
    are ignored; the sum is always divided by ``len(points)``.
    """
    deviations = [abs(truth - sample) for truth, sample in zip(ground_truth, points)]
    return sum(deviations) / len(points)

class ShelveTupleDict:
    """Thin wrapper around a shelve database that is addressed by tuples.

    A tuple key is flattened into the string stored in the shelf by joining
    the ``str()`` of its elements with ".".
    """

    def __init__(self, filename, flag='c'):
        # The underlying shelf; filename and flag are forwarded unchanged.
        self.data = shelve.open(filename, flag=flag)

    def __to_to_key(self, tup):
        """Flatten a tuple into the dotted string key used by the shelf."""
        return ".".join(map(str, tup))

    def __getitem__(self, tup):
        shelf_key = self.__to_to_key(tup)
        return self.data[shelf_key]

    def __contains__(self, tup):
        shelf_key = self.__to_to_key(tup)
        return shelf_key in self.data

    def __setitem__(self, tup, val):
        shelf_key = self.__to_to_key(tup)
        self.data[shelf_key] = val

    def close(self):
        """Close the underlying shelf."""
        self.data.close()

# Open the stored results with flag 'r'.
data = ShelveTupleDict("norm_corelation.shelf", 'r')

# Print every stored key that contains "ice".
for k in data.data:
    if "ice" not in k:
        continue
    print(k)

ice.GT.50000.bin_vals
ice.GT.50000.heatmap
ice.1000000.1.50000.bin_vals
ice.1000000.1.50000.heatmap
ice.1000000.1.50000.runtimes
ice.1000000.10.50000.bin_vals
ice.1000000.10.50000.heatmap
ice.1000000.10.50000.runtimes
ice.1000000.100.50000.bin_vals
ice.1000000.100.50000.heatmap
ice.1000000.100.50000.runtimes
ice.1000000.1000.50000.bin_vals
ice.1000000.1000.50000.heatmap
ice.1000000.1000.50000.runtimes
ice.5000000.1.50000.bin_vals
ice.5000000.1.50000.heatmap
ice.5000000.1.50000.runtimes
ice.5000000.10.50000.bin_vals
ice.5000000.10.50000.heatmap
ice.5000000.10.50000.runtimes
ice.5000000.100.50000.bin_vals
ice.5000000.100.50000.heatmap
ice.5000000.100.50000.runtimes
ice.5000000.1000.50000.bin_vals
ice.5000000.1000.50000.heatmap
ice.5000000.1000.50000.runtimes
ice.10000000.1.50000.bin_vals
ice.10000000.1.50000.heatmap
ice.10000000.1.50000.runtimes
ice.10000000.10.50000.bin_vals
ice.10000000.10.50000.heatmap
ice.10000000.10.50000.runtimes
ice.10000000.100.50000.bin_vals
ice.10000000.100.500

## Investigate the scatter plot for one bin and window size

Expect a bad correlation for a low number of samples; it should then gradually improve as the number of samples increases.

In [3]:
ALMOST_ZERO = 10**-5


def plot_scatter_points(ground_truth, data, title, x_range=(ALMOST_ZERO, 10**0), y_range=(ALMOST_ZERO, 10**5)):
    """Show a log-log scatter plot of several sample series against the ground truth.

    Args:
        ground_truth: sequence of reference values (x coordinates).
        data: re-iterable sequence (it is traversed twice) of
            ``(name, points, color)`` triples. ``points`` are the y values
            paired index-wise with ``ground_truth``; a ``color`` of None
            selects the next color of a viridis palette.
        title: figure title.
        x_range: x axis limits.
        y_range: y axis limits.

    At most 250 randomly chosen points are drawn per series, while the mean
    deviation shown in the legend is computed by ``get_mean_dev`` on the
    complete series.
    """
    max_points_per_series = 250

    # One palette entry per series that did not bring its own color.
    palette = viridis(sum(1 for _name, _points, color in data if color is None))
    f = figure(
        title=title,
        x_axis_type="log",
        y_axis_type="log",
        x_range=x_range,
        y_range=y_range,
        height=300,
        width=500,
    )
    # Diagonal on which a sample value equals its ground truth value.
    f.line(x=[ALMOST_ZERO, 100], y=[ALMOST_ZERO, 100], color="black")
    items = []
    idx = 0
    for name, points, color in data:
        # zip() truncates to the shorter sequence, so the sample size has to
        # be bounded by the number of pairs: bounding it by len(points) made
        # random.sample raise ValueError when ground_truth was shorter.
        pairs = list(zip(ground_truth, points))
        drawn = random.sample(pairs, min(len(pairs), max_points_per_series))
        xs = [x for x, _ in drawn]
        ys = [y for _, y in drawn]
        mean_dev = round(get_mean_dev(ground_truth, points), 5)
        if color is None:
            color = palette[idx]
            idx += 1
        # scatter(marker="dot") draws the same glyph as the marker-specific
        # dot() method and matches the scatter(marker=...) calls used
        # elsewhere in this file.
        d = f.scatter(marker="dot", x=xs, y=ys, color=color, size=25, alpha=0.5)
        items.append((name + " dev: " + str(mean_dev), [d]))

    f.xaxis.axis_label = "ground truth"
    f.yaxis.axis_label = "sample"
    f.add_layout(Legend(items=items, location="center", click_policy="hide"), "right")
    f.output_backend = OUTPUT_BACKEND
    show(f)

# Scatter plot for ice: unnormalized data plus one series per stored sample
# count, each compared against the ice ground truth.
plot_scatter_points(
    data[("ice", "GT", curr_bin_size, "bin_vals")],
    [
        ("unnormalized", data[("raw_hic", "GT", curr_bin_size, "bin_vals")], "red"),
        *(
            (f"num samples = {num_samples}", data[sample_key], None)
            for num_samples in NUM_SAMPLES_ICE
            # Single-element loop: binds the key once for both test and lookup.
            for sample_key in [("ice", curr_win_size, num_samples, curr_bin_size, "bin_vals")]
            if sample_key in data
        ),
    ],
    f"icing - stability num samples; bin-size= {curr_bin_size // KBP}k window-size= {curr_win_size // KBP}k",
)

# plot_scatter_points(
#     data[("grid-seq", "GT", BIN_SIZES[0], "bin_vals")], 
#     [("unnormalized", data[("raw-radicl", BIN_SIZES[0], "bin_vals")], "red")] +
#     [("num samples = " + str(num_samples), 
#       data[("grid-seq", "sampled", BIN_SIZES[0], WINDOW_SIZES[0], num_samples, "bin_vals")], None) for num_samples in NUM_SAMPLES_GRID_SEQ#[1, 4, 16, 32, 64, 128]
#       ], 
#     "grid-seq - stability num samples; bin-size= " + str(BIN_SIZES[0] // KBP) + "k window-size= " + str(WINDOW_SIZES[0]//KBP) + "k",
#     x_range=(10**-3,10**2), y_range=(10**-3, 10**4))

# plot_scatter_points(
#     data[("radicl-seq", "GT", BIN_SIZES[0], "bin_vals")], 
#     [("unnormalized", data[("raw-radicl", BIN_SIZES[0], "bin_vals")], "red")] +
#     [("num samples = " + str(num_samples), 
#       data[("radicl-seq", "sampled", BIN_SIZES[0], WINDOW_SIZES[0], num_samples, "bin_vals")], None) for num_samples in NUM_SAMPLES_RADICL_SEQ], 
#     "radicl-seq - stability num samples; bin-size= " + str(BIN_SIZES[0] // KBP) + "k window-size= " + str(WINDOW_SIZES[0]//KBP) + "k",
#     x_range=(10**-3,10**2), y_range=(10**-3, 10**4))

# plot_scatter_points(
#     data[("ddd", "GT", BIN_SIZES[0], "bin_vals")], 
#     [("unnormalized", data[("raw-hic", BIN_SIZES[0], "bin_vals")], "red")] +
#     [("num samples = " + str(num_samples), 
#       data[("ddd", "sampled", BIN_SIZES[0], WINDOW_SIZES[0], num_samples, "bin_vals")], None) for num_samples in NUM_SAMPLES_DDD], 
#     "distance dependent decay - stability num samples; bin-size= " + str(BIN_SIZES[0] // KBP) + "k window-size= " + str(WINDOW_SIZES[0]//KBP) + "k" )


# plot_scatter_points(
#     data[("ICE", "cooler", BIN_SIZES[0], "bin_vals")], 
#     [("", data[("ICE", "local", BIN_SIZES[0], "bin_vals")], "blue")], 
#     "icing - cooler vs my implementation; bin-size= " + str(BIN_SIZES[0] // KBP) + "k window-size= " + str(WINDOW_SIZES[0]//KBP) + "k" )

## Plot mean deviation as a function of the number of samples for all bin and window sizes

In [11]:
# Line/fill colors; corr_as_func_of_samples picks one by index modulo length.
COLOR_PALETTE = [
    "#0072B2",
    "#D55E00",
    "#009E73",
    "#E69F00",
    "#CC79A7",
    "#56B4E9",
    "#F0E442",
]
# Marker names passed to scatter(marker=...), also picked by index modulo length.
SCATTER_PALETTE = [
    "x",
    "cross",
    "circle",
    "dash",
]

def get_median(xs):
    """Return the element at index ``len(xs) // 2`` of the sorted values.

    For an even number of values this is the upper of the two middle
    elements; no averaging takes place.
    """
    ranked = sorted(xs)
    middle = len(xs) // 2
    return ranked[middle]

def corr_as_func_of_samples(conditions):
    """Plot mean deviation and runtime distributions against the number of samples.

    Args:
        conditions: re-iterable sequence (it is traversed twice) of 5-tuples
            ``(ground_truth, sample, name, idxa, idxb)``. ``sample`` is a
            sequence of ``(num_samples, points, runtimes)`` triples, where
            ``runtimes`` must provide ``.values()``. ``idxa`` selects the
            color from COLOR_PALETTE and ``idxb`` the marker from
            SCATTER_PALETTE (both modulo the palette length).

    The figures are displayed with ``show``; nothing is returned.
    """
    width = 200
    # f: mean deviation over number of samples, both axes logarithmic.
    f = figure(
            y_axis_type="log", 
            x_axis_type="log",
            width=width,
            height=180,
        )
    # f1: runtime distributions on a categorical x axis with one
    # (condition name, number of samples) factor per sample.
    f1 = figure(
            y_axis_type="log", 
            x_range=FactorRange(*[(name, str(num_samples)) for _, sample, name, _, _ in conditions for num_samples, _, _ in sample]),
            width=width*3,
            height=180,
        )
    # f0: strip sharing f's x range whose only y category is "0"; it receives
    # the samples whose mean deviation is exactly 0 (presumably because those
    # cannot be drawn on f's logarithmic y axis).
    f0 = figure(
            x_range=f.x_range,
            y_range=["0"],
            height=75,
            x_axis_type="log",
            width=width,
        )
    # fl: separate figure that only carries the legend.
    fl = figure(width=400, height=400)
    items = []
    for ground_truth, sample, name, idxa, idxb in conditions:
        xs = []
        ys = []
        xs2 = []
        ys2 = []
        for num_samples, points, runtimes in sample:
            y = get_mean_dev(ground_truth, points)
            xs.append(num_samples)
            ys.append(y)
            # Position-wise sum over all entries of runtimes, divided by 1000.
            # NOTE(review): the axis is labelled "runtime [ms]", so the stored
            # values are presumably microseconds - confirm.
            a = [sum(tmp)/1000 for tmp in zip(*runtimes.values())]
            if y == 0:
                xs2.append(num_samples)
                ys2.append("0")

            # Histogram of the runtimes in bins of width `granuality` on a
            # log10 scale.
            # NOTE(review): log10 raises ValueError for a runtime of 0, and
            # runtimes below 1 produce negative bin indices (wrapping around
            # or IndexError) - confirm that neither can occur.
            granuality = 0.1
            ys_max = max(a)
            vs = [0] * (int(log10(ys_max) / granuality) + 1)
            # This inner y shadows the mean deviation y from above, which is
            # no longer needed at this point.
            for y in a:
                idy = int(log10(y)/granuality)
                vs[idy] += 1
            # Twice the largest count, so that the offsets below stay within
            # [-0.5, 0.5].
            max_val = max(vs) * 2


            # Outline of the histogram mirrored around the factor: xs3 holds
            # the negative offsets, xs4 the positive ones; ys_2 holds the
            # matching bin edges 10**(k*granuality), preceded by 0.
            xs3 = [(name, str(num_samples))] + [(name, str(num_samples), -v/max_val) for v in vs] + \
                  [(name, str(num_samples))]
            xs4 = [(name, str(num_samples))] + [(name, str(num_samples), v/max_val) for v in vs] + \
                  [(name, str(num_samples))]
            ys_2 = [0] + [10**(k*granuality) for k in range(len(vs) + 1)]

            f1.patch(xs3 + xs4[::-1], 
                    ys_2 + ys_2[::-1], 
                    fill_color=COLOR_PALETTE[idxa % len(COLOR_PALETTE)], line_color=None)
            # Mark the median runtime on top of the distribution.
            f1.scatter(marker=SCATTER_PALETTE[idxb % len(SCATTER_PALETTE)],
                       x=[(name, str(num_samples))], y=[get_median(a)], 
                       color="black", size=2)
    
        # Glyphs at (0, 0) on the legend figure; they exist to give the legend
        # entry its line and marker. Note that `a` is reused here and no
        # longer refers to the runtime list.
        a = fl.line(x=[0], y=[0], color=COLOR_PALETTE[idxa % len(COLOR_PALETTE)])
        b = fl.scatter(marker=SCATTER_PALETTE[idxb % len(SCATTER_PALETTE)], x=[0], y=[0], 
                    fill_color=None, line_color=COLOR_PALETTE[idxa % len(COLOR_PALETTE)], size=10)
        f.line(x=xs, y=ys, color=COLOR_PALETTE[idxa % len(COLOR_PALETTE)])
        f.scatter(marker=SCATTER_PALETTE[idxb % len(SCATTER_PALETTE)], x=xs, y=ys, 
                    fill_color=None, line_color=COLOR_PALETTE[idxa % len(COLOR_PALETTE)], size=10)
        items.append((name, [a, b]))
        f0.line(x=xs2, y=ys2, color=COLOR_PALETTE[idxa % len(COLOR_PALETTE)])
        f0.scatter(marker=SCATTER_PALETTE[idxb % len(SCATTER_PALETTE)], x=xs2, y=ys2, 
                    fill_color=None, line_color=COLOR_PALETTE[idxa % len(COLOR_PALETTE)], size=10)



    f.yaxis.axis_label = "mean deviation"
    f1.yaxis.axis_label = "runtime [ms]"
    #f.xaxis.major_label_text_color = None
    f0.xaxis.axis_label = "number of samples"
    f1.xaxis.axis_label = "number of samples"
    fl.add_layout(Legend(items=items, location="center", click_policy="hide"))
    # Sets the same click policy that was already passed to Legend above.
    fl.legend.click_policy="hide"
    # NOTE(review): f1 is the only figure that does not get OUTPUT_BACKEND -
    # confirm whether that is intentional.
    fl.output_backend = OUTPUT_BACKEND
    f.output_backend = OUTPUT_BACKEND
    f0.output_backend = OUTPUT_BACKEND
    show(row([gridplot([[f, f1], [f0]]), fl]), notebook_handle=True)

def plot_corr_as_func_of_samples(key="ice", samples=NUM_SAMPLES_ICE):
    """Plot deviation and runtime over the sample count for one normalization.

    One condition is built per entry of BIN_SIZES (at the current window
    size) and one per entry of WINDOW_SIZES other than the current window
    size (at the current bin size). All of them are handed to
    ``corr_as_func_of_samples``.

    Args:
        key: name of the normalization as used in the keys of ``data``.
        samples: sample counts to look up; only read, never modified.
    """

    def build_condition(bin_size, window_size, style_idx):
        # Every lookup uses `key`. The window conditions previously read
        # their bin values with a hard-coded "ice", which compared any other
        # normalization against ice data.
        return (
            data[(key, "GT", bin_size, "bin_vals")],
            [
                (
                    num_samples,
                    data[(key, window_size, num_samples, bin_size, "bin_vals")],
                    data[(key, window_size, num_samples, bin_size, "runtimes")],
                )
                for num_samples in samples
            ],
            key + " bin=" + str(bin_size // KBP) + "k win=" + str(window_size // KBP) + "k",
            style_idx,  # color index
            style_idx,  # marker index
        )

    conditions_bin = [
        build_condition(bin_size, curr_win_size, idxa)
        for idxa, bin_size in enumerate(BIN_SIZES)
    ]

    # Offset the style index so window conditions do not reuse the colors and
    # markers of the bin conditions.
    conditions_window = [
        build_condition(curr_bin_size, window_size, idxb + len(BIN_SIZES))
        for idxb, window_size in enumerate(WINDOW_SIZES)
        if window_size != curr_win_size
    ]

    corr_as_func_of_samples(conditions_bin + conditions_window)
    #corr_as_func_of_samples(conditions_window)

plot_corr_as_func_of_samples(key="ice", samples=NUM_SAMPLES_ICE)
#plot_corr_as_func_of_samples("grid_seq", NUM_SAMPLES_GRID_SEQ)
#plot_corr_as_func_of_samples("radicl_seq", NUM_SAMPLES_RADICL_SEQ)
#plot_corr_as_func_of_samples("ddd", NUM_SAMPLES_DDD)


- ICE:
    - window size does not affect results
    - bin size does
- Grid-Seq:
    - actually reaches zero


## Plot some of the heatmaps for visual verification

In [None]:
# Divisors applied to the window size to obtain heatmap coordinates.
# NOTE(review): presumably the number of base pairs per heatmap coordinate
# unit of the Hi-C and RADICL data sets - confirm.
div_hic = 10_000
div_radicl = 50_000

# Default plotted area: three current-size windows along each axis.
DEFAULT_RANGE = (0, (3 * curr_win_size) // div_hic)
def plot_heatmap(datas, bg_color, x_range=DEFAULT_RANGE, y_range=DEFAULT_RANGE):
    """Show several heatmaps side by side with linked axes.

    Args:
        datas: sequence of ``(heatmap, title, w_size)`` triples. ``heatmap``
            maps column names to equally long sequences and must contain the
            columns "screen_left", "screen_right", "screen_bottom",
            "screen_top" and "color". ``w_size`` is the spacing of the white
            window grid lines, or None for no grid.
        bg_color: background fill color of every figure.
        x_range: ``(min, max)``; quads reaching outside are dropped.
        y_range: ``(min, max)``; quads reaching outside are dropped.
    """
    fl = []
    for heatmap, title, w_size in datas:
        if not fl:
            f = figure(title=title, width=300, height=300)
        else:
            # Share the ranges of the first figure so panning/zooming is linked.
            f = figure(title=title, x_range=fl[0].x_range, y_range=fl[0].y_range, width=300, height=300)

        # Whether a quad is kept depends only on the four screen_* columns, so
        # decide it once per quad instead of once per quad and column.
        visible = [
            left >= x_range[0] and right <= x_range[1] and bottom >= y_range[0] and top <= y_range[1]
            for left, right, bottom, top in zip(
                heatmap["screen_left"],
                heatmap["screen_right"],
                heatmap["screen_bottom"],
                heatmap["screen_top"],
            )
        ]
        d_filtered = {
            column: [v for v, keep in zip(vals, visible) if keep]
            for column, vals in heatmap.items()
        }
        f.quad(
            left="screen_left",
            bottom="screen_bottom",
            right="screen_right",
            top="screen_top",
            fill_color="color",
            line_color=None,
            source=ColumnDataSource(data=d_filtered),
        )
        if w_size is not None:
            # White grid lines every w_size units, starting at 0.
            x_windows = range(0, x_range[1] + 1, w_size)
            y_windows = range(0, y_range[1] + 1, w_size)
            f.multi_line(
                xs=[[x, x] for x in x_windows],
                ys=[[y_range[0], y_range[1]] for _ in x_windows],
                color="white",
            )
            f.multi_line(
                xs=[[x_range[0], x_range[1]] for _ in y_windows],
                ys=[[y, y] for y in y_windows],
                color="white",
            )
        f.background_fill_color = bg_color
        f.grid.grid_line_color = None
        f.axis.ticker = []

        f.add_tools(
            HoverTool(
                tooltips=[
                    ("score", "@score_total"),
                    ("color", "@color"),
                    ("reads by group", "A: @score_a, B: @score_b"),
                    ("x", "@screen_left - @screen_right"),
                    ("y", "@screen_bottom - @screen_top"),
                ]
            )
        )
        #f.output_backend = OUTPUT_BACKEND
        fl.append(f)
    show(gridplot([fl]), notebook_handle=True)


# Plotted area for the RADICL-based heatmaps: three current-size windows per axis.
RADICL_RANGE = (0, (3 * curr_win_size) // div_radicl)
# ice: raw data, one heatmap per sample count, then the global normalization.
# Iterating NUM_SAMPLES_ICE (like the grid-seq and ddd calls do) instead of
# indexing entries 0..3 keeps this cell working when the list changes length.
plot_heatmap([
    (data[("raw_hic", "GT", curr_bin_size, "heatmap")], "raw data", None),
    *[
        (
            data[("ice", curr_win_size, num_sample, curr_bin_size, "heatmap")],
            "ice - num samples = " + str(num_sample),
            curr_win_size // div_hic,
        )
        for num_sample in NUM_SAMPLES_ICE
    ],
    (data[("ice", "GT", curr_bin_size, "heatmap")], "global", None),
    ], "#440154")

# grid-seq: raw data, one heatmap per sample count, then the global normalization.
plot_heatmap(
    [
        (data[("raw_radicl", "GT", curr_bin_size, "heatmap")], "raw data", None),
        *(
            (
                data[("grid_seq", curr_win_size, num_sample, curr_bin_size, "heatmap")],
                f"grid-seq - num samples = {num_sample}",
                curr_win_size // div_radicl,
            )
            for num_sample in NUM_SAMPLES_GRID_SEQ
        ),
        (data[("grid_seq", "GT", curr_bin_size, "heatmap")], "global", None),
    ],
    "#440154",
    x_range=RADICL_RANGE,
    y_range=RADICL_RANGE,
)

# radicl-seq: raw data, one heatmap per sample count, then the global
# normalization. Iterating NUM_SAMPLES_RADICL_SEQ (like the grid-seq and ddd
# calls do) instead of indexing entries 0..3 keeps this cell working when the
# list changes length.
plot_heatmap([
    (data[("raw_radicl", "GT", curr_bin_size, "heatmap")], "raw data", None),
    *[
        (
            data[("radicl_seq", curr_win_size, num_sample, curr_bin_size, "heatmap")],
            "radicl-seq - num samples = " + str(num_sample),
            curr_win_size // div_radicl,
        )
        for num_sample in NUM_SAMPLES_RADICL_SEQ
    ],
    (data[("radicl_seq", "GT", curr_bin_size, "heatmap")], "global", None),
    ], "#440154",
    x_range=RADICL_RANGE,
    y_range=RADICL_RANGE)

# ddd: raw data, one heatmap per sample count, then the global normalization.
# Uses NUM_SAMPLES_DDD; the grid-seq list was used here before, which only
# worked because both names currently refer to the same list.
plot_heatmap([
    (data[("raw_hic", "GT", curr_bin_size, "heatmap")], "raw data", None),
    *[
        (
            data[("ddd", curr_win_size, num_sample, curr_bin_size, "heatmap")],
            "ddd - num samples = " + str(num_sample),
            curr_win_size // div_hic,
        )
        for num_sample in NUM_SAMPLES_DDD
    ],
    (data[("ddd", "GT", curr_bin_size, "heatmap")], "global", None),
    ], "#440154")