# Test normalizations
## basic settings and imports first

In [28]:
import libbiosmoother

from bokeh.plotting import figure
from bokeh.palettes import viridis
from bokeh.io import show, output_notebook
from bokeh.models import ColumnDataSource, HoverTool
from bokeh.layouts import column, row, gridplot

import json
import sys
import shelve
import random
from bokeh.models import Legend


output_notebook()

# Base-pair unit multipliers.
KBP = 1000
MBP = 1000 * KBP

# Set to True for a quick run with a single (large) window and bin size and
# canvas rendering; the default runs the full parameter sweep with SVG output.
FAST_MODE = False

if FAST_MODE:
    WINDOW_SIZES = [10 * MBP]
    BIN_SIZES = [500 * KBP]
    OUTPUT_BACKEND = "canvas"
else:
    # Window sizes (in bp) used when the genome is queried chunk by chunk.
    WINDOW_SIZES = [1 * MBP, 5 * MBP, 10 * MBP]
    # Bin sizes (in bp) the heatmaps are computed at.
    BIN_SIZES = [50 * KBP, 100 * KBP, 500 * KBP]
    # Bokeh output backend assigned to the figures created below.
    OUTPUT_BACKEND = "svg"

# Sample counts swept for each normalization method.
NUM_SAMPLES_ICE = [2**x for x in range(9)]
NUM_SAMPLES_GRID_SEQ = [2**x for x in range(9)]
NUM_SAMPLES_RADICL_SEQ = [4**x for x in range(8)]
NUM_SAMPLES_DDD = [4**x for x in range(6)]

# When True, bin coordinates are stored next to the bin values and the
# chunked results are asserted to cover exactly the same bins as the reference.
CHECK_THAT_HEATMAPS_ARE_EXACTLY_THE_SAME = False

## define basic evaluation functions

In [29]:
def lib_sps_print(s):
    """Ignore the message *s*.

    Used in this notebook as a silent callback for the quarry query methods.
    """
    return None

def conf_quarry_basic(quarry):
    """Load the library's default settings into *quarry*, then apply fixed overrides.

    The overrides are written in the order listed, after the defaults have
    replaced the whole ``settings`` subtree.
    """
    with libbiosmoother.open_default_json() as defaults_handle:
        quarry.set_value(["settings"], json.load(defaults_handle))

    overrides = (
        (["settings", "filters", "cut_off_bin"], "fit_chrom_smaller"),
        (["settings", "filters", "show_contig_smaller_than_bin"], True),
        (["settings", "interface", "fixed_bin_size"], True),
        (["settings", "interface", "add_draw_area", "val"], 0),
        (["settings", "normalization", "scale"], "dont"),
        (["settings", "filters", "symmetry"], "all"),
    )
    for path, value in overrides:
        quarry.set_value(path, value)


def conf_quarry_data(quarry):
    """Set the normalization log base of *quarry* to 0.

    Called before exporting raw bin values.
    NOTE(review): presumably a log base of 0 disables log scaling - confirm.
    """
    log_base_path = ["settings", "normalization", "log_base", "val"]
    quarry.set_value(log_base_path, 0)

def conf_quarry_heatmap(quarry, max_val):
    """Prepare *quarry* for heatmap rendering.

    Sets the upper end of the color range to *max_val* and the log base to 10.
    """
    normalization = ["settings", "normalization"]
    quarry.set_value(normalization + ["color_range", "val_max"], max_val)
    quarry.set_value(normalization + ["log_base", "val"], 10)


def set_bin_size(quarry, bin_size):
    """Fix the x and y bin size of *quarry* to *bin_size*.

    *bin_size* is divided (integer division) by the value stored under
    ``["dividend"]``; the result is clamped to at least 1. A warning is
    printed to stderr when the division is not exact or when the dividend
    exceeds *bin_size*.
    """
    dividend = quarry.get_value(["dividend"])

    remainder = bin_size % dividend
    if remainder != 0:
        print("WARNING: uneven division by index dividend", file=sys.stderr)
    if bin_size < dividend:
        print("WARNING: dividend larger than value", file=sys.stderr)

    scaled = max(1, bin_size // dividend)
    for axis_key in ("fixed_bin_size_x", "fixed_bin_size_y"):
        quarry.set_value(["settings", "interface", axis_key, "val"], scaled)

def tsv_to_ret(data, key, tsv):
    """Store the exported bin values of *tsv* in *data* and return their maximum.

    Each row is split into its leading columns and its last column (the
    value). Rows are sorted, then the values are written to
    ``data[key + ("bin_vals",)]``. The leading columns are also written to
    ``data[key + ("bin_coords",)]`` when
    CHECK_THAT_HEATMAPS_ARE_EXACTLY_THE_SAME is set.

    Raises ValueError (from ``max``) if *tsv* is empty.
    """
    rows = sorted((row[:-1], row[-1]) for row in tsv)
    values = [value for _, value in rows]
    if CHECK_THAT_HEATMAPS_ARE_EXACTLY_THE_SAME:
        data[key + ("bin_coords", )] = [coords for coords, _ in rows]
    data[key + ("bin_vals", )] = values
    return max(values)

def quarry_whole_window(data, key, quarry):
    """Query the whole canvas in one go and store the results under *key*.

    Bin values are exported first; their maximum is then used as the upper
    end of the color range for the heatmap, which is stored under
    ``key + ("heatmap",)``.
    """
    width, height = quarry.get_canvas_size(lib_sps_print)
    full_area = {"x_start": 0, "x_end": width, "y_start": 0, "y_end": height}
    quarry.set_value(["area"], full_area)

    # Pass 1: raw bin values.
    conf_quarry_data(quarry)
    exported = quarry.get_heatmap_export(lib_sps_print)
    peak = tsv_to_ret(data, key, exported)

    # Pass 2: rendered heatmap, colored relative to the peak value.
    conf_quarry_heatmap(quarry, peak)
    data[key + ("heatmap", )] = quarry.get_heatmap(lib_sps_print)

def print_window_amount(quarry, window_size):
    """Print how many windows of *window_size* fit on the canvas along x and y.

    *window_size* is first divided (integer division) by the value stored
    under ``["dividend"]``.
    """
    width, height = quarry.get_canvas_size(lambda s: None)
    step = window_size // quarry.get_value(["dividend"])
    print("num windows on genome:", width / step, "x", height / step)

def quarry_chunked_window(data, data_key, quarry, window_size):
    """Query the canvas window by window and store the combined results.

    The canvas is tiled with square windows of ``window_size // dividend``
    units, clipped at the canvas border. Two passes are made over the tiles:

    1. export the bin values of every window (the cache is cleared before
       each query) and collect the values reported by
       ``quarry.get_runtimes()`` per key;
    2. render the heatmap of every window, colored relative to the maximum
       bin value found in pass 1, and concatenate the per-window columns.

    Results are stored under ``data_key + ("bin_vals",)`` (via tsv_to_ret),
    ``data_key + ("heatmap",)`` and ``data_key + ("runtimes",)``.
    """
    canvas_w, canvas_h = quarry.get_canvas_size(lambda s: None)
    step = window_size // quarry.get_value(["dividend"])

    def select_window(x0, y0):
        # Point the quarry at one tile, clipped to the canvas.
        quarry.set_value(["area"], {
            "x_start": x0, "x_end": min(canvas_w, x0 + step),
            "y_start": y0, "y_end": min(canvas_h, y0 + step),
        })

    # Pass 1: bin values and runtimes.
    exported_rows = []
    runtimes = {}
    for x0 in range(0, canvas_w, step):
        for y0 in range(0, canvas_h, step):
            select_window(x0, y0)
            conf_quarry_data(quarry)
            quarry.clear_cache()
            exported_rows.extend(quarry.get_heatmap_export(lib_sps_print))
            for name, value in quarry.get_runtimes():
                runtimes.setdefault(name, []).append(value)

    peak = tsv_to_ret(data, data_key, exported_rows)

    # Pass 2: heatmaps; the first tile's dict is reused as the accumulator.
    merged = None
    for x0 in range(0, canvas_w, step):
        for y0 in range(0, canvas_h, step):
            select_window(x0, y0)
            conf_quarry_heatmap(quarry, peak)
            tile = quarry.get_heatmap(lib_sps_print)
            if merged is None:
                merged = tile
                continue
            for column, cells in tile.items():
                merged[column].extend(cells)

    data[data_key + ("heatmap", )] = merged
    data[data_key + ("runtimes", )] = runtimes


def get_mean_dev(ground_truth, points):
    """Return the mean absolute deviation of *points* from *ground_truth*.

    Values are paired positionally; pairing stops at the shorter input,
    while the sum is always divided by ``len(points)``.

    Raises ZeroDivisionError if *points* is empty.
    """
    deviations = (
        abs(expected - observed)
        for expected, observed in zip(ground_truth, points)
    )
    return sum(deviations) / len(points)

class ShelveTupleDict:
    """Persistent mapping keyed by tuples, backed by a ``shelve`` file.

    ``shelve`` only accepts string keys, so each tuple is flattened to the
    ``str()`` of its elements joined by ``"."``. Distinct tuples whose
    flattened forms coincide (e.g. an element that itself contains ".")
    therefore share one entry.

    Instances can be used as context managers; the shelf is closed on exit,
    including when the body raises.
    """

    def __init__(self, filename, flag='c'):
        # flag is passed straight to shelve.open ('c' create/read-write, 'r' read-only, ...)
        self.data = shelve.open(filename, flag=flag)

    @staticmethod
    def _to_key(tup):
        """Flatten the tuple *tup* into the string key used in the shelf."""
        return ".".join(str(x) for x in tup)

    def __getitem__(self, tup):
        return self.data[self._to_key(tup)]

    def __setitem__(self, tup, val):
        self.data[self._to_key(tup)] = val

    def __contains__(self, tup):
        return self._to_key(tup) in self.data

    def close(self):
        """Close the underlying shelf."""
        self.data.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        # never suppress an exception raised inside the with-body
        return False

## Load index and compute data

In [30]:
print("loading file")
quarry = libbiosmoother.Quarry("../../smoother_out/radicl.smoother_index")
conf_quarry_basic(quarry)

data = ShelveTupleDict("norm_corelation.shelf")

# The sweep below is long-running. Close the shelf even if a query fails
# part-way, so the file is not left open.
try:
    for bin_size in BIN_SIZES:
        print("bin_size", bin_size // KBP, "k")

        set_bin_size(quarry, bin_size)

        # --- references computed on the whole canvas at once ---
        print("grid-seq")
        quarry.set_value(["settings", "normalization", "normalize_by"], "grid-seq")
        quarry.set_value(["settings", "normalization", "grid_seq_global"], True)
        quarry_whole_window(data, ("grid-seq", "GT", bin_size), quarry)

        print("radicl-seq")
        quarry.set_value(["settings", "normalization", "normalize_by"], "radicl-seq")
        quarry.set_value(["settings", "normalization", "radicl_seq_global"], True)
        quarry_whole_window(data, ("radicl-seq", "GT", bin_size), quarry)

        print("ddd")
        quarry.set_value(["settings", "normalization", "normalize_by"], "dont")
        quarry.set_value(["settings", "normalization", "ddd"], True)
        # max out num samples for the default setting
        quarry.set_value(["settings", "normalization", "ddd_samples", "val_min"], 0)
        quarry.set_value(["settings", "normalization", "ddd_all_samples"], True)
        quarry_whole_window(data, ("ddd", "GT", bin_size), quarry)
        quarry.set_value(["settings", "normalization", "ddd"], False)

        print("cooler")
        quarry.set_value(["settings", "normalization", "normalize_by"], "cool-ice")
        quarry_whole_window(data, ("ICE", "cooler", bin_size), quarry)

        print("local-ice")
        quarry.set_value(["settings", "normalization", "ice_local"], True)
        quarry.set_value(["settings", "normalization", "normalize_by"], "ice")
        quarry_whole_window(data, ("ICE", "local", bin_size), quarry)

        print("raw data")
        quarry.set_value(["settings", "normalization", "normalize_by"], "dont")
        quarry_whole_window(data, ("raw", bin_size), quarry)

        # --- the same normalizations, queried window by window with a
        # varying number of samples ---
        for window_size in WINDOW_SIZES:
            print("window_size", window_size // KBP, "k")
            print_window_amount(quarry, window_size)

            for num_samples in NUM_SAMPLES_ICE:
                print("ice", num_samples, "samples")
                quarry.set_value(["settings", "normalization", "ice_local"], False)
                quarry.set_value(["settings", "normalization", "normalize_by"], "ice")
                quarry.set_value(["settings", "normalization", "num_ice_bins", "val"], num_samples)
                quarry_chunked_window(data, ("ICE", "global", bin_size, window_size, num_samples), quarry, window_size)

                if CHECK_THAT_HEATMAPS_ARE_EXACTLY_THE_SAME:
                    assert data[("ICE", "cooler", bin_size, "bin_coords")] == \
                        data[("ICE", "global", bin_size, window_size, num_samples, "bin_coords")]

            for num_samples in NUM_SAMPLES_GRID_SEQ:
                print("grid-seq", num_samples, "samples")
                quarry.set_value(["settings", "normalization", "normalize_by"], "grid-seq")
                quarry.set_value(["settings", "normalization", "grid_seq_global"], False)
                quarry.set_value(["settings", "normalization", "grid_seq_samples", "val"], num_samples)
                quarry_chunked_window(data, ("grid-seq", "sampled", bin_size, window_size, num_samples), quarry,
                                      window_size)

            for num_samples in NUM_SAMPLES_RADICL_SEQ:
                print("radicl-seq", num_samples, "samples")
                quarry.set_value(["settings", "normalization", "normalize_by"], "radicl-seq")
                quarry.set_value(["settings", "normalization", "radicl_seq_global"], False)
                quarry.set_value(["settings", "normalization", "radicl_seq_samples", "val"], num_samples)
                quarry_chunked_window(data, ("radicl-seq", "sampled", bin_size, window_size, num_samples), quarry,
                                      window_size)

                if CHECK_THAT_HEATMAPS_ARE_EXACTLY_THE_SAME:
                    assert data[("radicl-seq", "GT", bin_size, "bin_coords")] == \
                        data[("radicl-seq", "sampled", bin_size, window_size, num_samples, "bin_coords")]

            for num_samples in NUM_SAMPLES_DDD:
                print("ddd", num_samples, "samples")
                quarry.set_value(["settings", "normalization", "normalize_by"], "dont")
                quarry.set_value(["settings", "normalization", "ddd_all_samples"], False)
                quarry.set_value(["settings", "normalization", "ddd_samples", "val_min"], 0)
                quarry.set_value(["settings", "normalization", "ddd_samples", "val_max"], num_samples)
                quarry.set_value(["settings", "normalization", "ddd"], True)
                quarry_chunked_window(data, ("ddd", "sampled", bin_size, window_size, num_samples), quarry, window_size)
                quarry.set_value(["settings", "normalization", "ddd"], False)
finally:
    data.close()

loading file
bin_size 50 k
grid-seq
radicl-seq
ddd
cooler
local-ice
raw data
window_size 1000 k
num windows on genome: 42.45 x 42.45
ice 1 samples
ice 2 samples
ice 4 samples


## Checkpoint

In [None]:
# Re-open the shelf written by the computation cell with flag 'r' (read-only),
# so the plotting cells below work from the stored results.
data = ShelveTupleDict("norm_corelation.shelf", 'r')

## Investigate the scatter plot for one bin and window size

Expect a poor correlation for a low number of samples; it should then gradually improve as the number of samples increases.

In [None]:
# Lower axis bound for the log-log plots (a log axis cannot start at 0).
ALMOST_ZERO = 10**-5


def plot_scatter_points(ground_truth, data, title, x_range=(ALMOST_ZERO, 10**0), y_range=(ALMOST_ZERO, 10**5)):
    """Show a log-log scatter plot of sampled values against the ground truth.

    Parameters:
        ground_truth: reference values; plotted on the x axis.
        data: iterable of ``(name, points, color)``. *points* is paired
            positionally with *ground_truth*. When *color* is None the next
            viridis palette color is used.
        title: figure title.
        x_range, y_range: axis ranges of the figure.

    At most 250 randomly chosen pairs are drawn per series. The legend
    entry of each series also shows its mean deviation (rounded to 5
    digits), computed over all points rather than just the drawn sample.
    """
    palette = viridis(sum(1 for _name, _points, color in data if color is None))
    fig = figure(
            title=title,
            x_axis_type="log",
            y_axis_type="log",
            x_range=x_range,
            y_range=y_range,
            height=300,
            width=500
        )
    # diagonal: sample == ground truth
    fig.line(x=[ALMOST_ZERO, 100], y=[ALMOST_ZERO, 100], color="black")

    legend_items = []
    palette_idx = 0
    for name, points, color in data:
        pairs = list(zip(ground_truth, points))
        # zip() stops at the shorter input, so bound the sample size by the
        # number of pairs actually available (random.sample raises otherwise).
        sampled = random.sample(pairs, min(len(pairs), 250))
        xs = [x for x, _ in sampled]
        ys = [y for _, y in sampled]

        mean_dev = round(get_mean_dev(ground_truth, points), 5)

        if color is None:
            glyph_color = palette[palette_idx]
            palette_idx += 1
        else:
            glyph_color = color
        # scatter(marker="dot") replaces the deprecated figure.dot glyph
        # method and matches the scatter calls used elsewhere in this file.
        renderer = fig.scatter(x=xs, y=ys, marker="dot", color=glyph_color, size=25, alpha=0.5)
        legend_items.append((name + " dev: " + str(mean_dev), [renderer]))

    fig.xaxis.axis_label = "ground truth"
    fig.yaxis.axis_label = "sample"
    fig.add_layout(Legend(items=legend_items, location="center", click_policy="hide"), "right")
    fig.output_backend = OUTPUT_BACKEND
    show(fig)


# All scatter plots below use the first bin size and the first window size.
def _scatter_title(prefix):
    """Append the bin-size / window-size suffix shared by all plot titles."""
    return f"{prefix}; bin-size= {BIN_SIZES[0] // KBP}k window-size= {WINDOW_SIZES[0] // KBP}k"


def _raw_series():
    """Series of the unnormalized values, drawn in red for comparison."""
    return [("unnormalized", data[("raw", BIN_SIZES[0], "bin_vals")], "red")]


def _sampled_series(method, mode, sample_counts):
    """One palette-colored series per entry of *sample_counts*."""
    return [
        (
            f"num samples = {num_samples}",
            data[(method, mode, BIN_SIZES[0], WINDOW_SIZES[0], num_samples, "bin_vals")],
            None,
        )
        for num_samples in sample_counts
    ]


# ICE: the largest sample count is left out of this plot.
plot_scatter_points(
    data[("ICE", "local", BIN_SIZES[0], "bin_vals")],
    _raw_series() + _sampled_series("ICE", "global", NUM_SAMPLES_ICE[:-1]),
    _scatter_title("icing - stability num samples"))

plot_scatter_points(
    data[("grid-seq", "GT", BIN_SIZES[0], "bin_vals")],
    _raw_series() + _sampled_series("grid-seq", "sampled", NUM_SAMPLES_GRID_SEQ),
    _scatter_title("grid-seq - stability num samples"),
    x_range=(10**-3, 10**2), y_range=(10**-3, 10**4))

plot_scatter_points(
    data[("radicl-seq", "GT", BIN_SIZES[0], "bin_vals")],
    _raw_series() + _sampled_series("radicl-seq", "sampled", NUM_SAMPLES_RADICL_SEQ),
    _scatter_title("radicl-seq - stability num samples"),
    x_range=(10**-3, 10**2), y_range=(10**-3, 10**4))

plot_scatter_points(
    data[("ddd", "GT", BIN_SIZES[0], "bin_vals")],
    _raw_series() + _sampled_series("ddd", "sampled", NUM_SAMPLES_DDD),
    _scatter_title("distance dependent decay - stability num samples"))

# cooler's ICE result against the whole-canvas ICE result of this library
plot_scatter_points(
    data[("ICE", "cooler", BIN_SIZES[0], "bin_vals")],
    [("", data[("ICE", "local", BIN_SIZES[0], "bin_vals")], "blue")],
    _scatter_title("icing - cooler vs my implementation"))

The mean deviation becomes smaller with an increasing number of samples, and the points approach the diagonal through the origin for all normalization methods.

For now, ICE never reaches perfect equality when it is windowed.

## Plot mean deviation as a function of the number of samples for all bin and window sizes

In [None]:
# Line/marker colors (selected by bin-size index) and marker shapes
# (selected by window-size index) used by corr_as_func_of_samples.
COLOR_PALETTE = ["#0072B2", "#D55E00", "#009E73", "#E69F00", "#CC79A7", "#56B4E9", "#F0E442"]
SCATTER_PALETTE = ["x", "cross", "circle", "dash"]


def corr_as_func_of_samples(conditions):
    """Plot mean deviation and runtime against the number of samples.

    Parameters:
        conditions: iterable of ``(ground_truth, sample, name, idxa, idxb)``.
            *sample* is an iterable of ``(num_samples, points, runtimes)``,
            where *runtimes* maps a name to a list with one entry per
            queried window. *idxa* selects the color and *idxb* the marker
            (both wrap around their palette).

    Four figures are shown together:
      * mean deviation vs. number of samples (log-log);
      * runtime vs. number of samples (log-log): one marker per window
        (sum over all runtime entries of that window) and a line through
        the per-condition averages;
      * a narrow strip marking the sample counts with a mean deviation of
        exactly 0, which cannot be drawn on the log axis;
      * a separate figure that only carries the legend.
    """
    width = 200
    dev_fig = figure(
            y_axis_type="log",
            x_axis_type="log",
            width=width,
            height=180,
        )
    runtime_fig = figure(
            y_axis_type="log",
            x_axis_type="log",
            width=width,
            height=180,
        )
    zero_fig = figure(
            x_range=dev_fig.x_range,
            y_range=["0"],
            height=75,
            x_axis_type="log",
            width=width,
        )
    legend_fig = figure(width=400, height=400)
    legend_items = []
    for ground_truth, sample, name, idxa, idxb in conditions:
        color = COLOR_PALETTE[idxa % len(COLOR_PALETTE)]
        marker = SCATTER_PALETTE[idxb % len(SCATTER_PALETTE)]

        xs = []
        ys = []
        runtime_xs = []
        runtime_ys = []
        runtime_means = []
        zero_xs = []
        zero_ys = []
        for num_samples, points, runtimes in sample:
            mean_dev = get_mean_dev(ground_truth, points)
            xs.append(num_samples)
            ys.append(mean_dev)
            # total runtime of each window: sum the i-th entry of every runtime list
            per_window = [sum(parts) for parts in zip(*runtimes.values())]
            runtime_ys.extend(per_window)
            runtime_xs.extend([num_samples] * len(per_window))
            runtime_means.append(sum(per_window) / len(per_window))
            if mean_dev == 0:
                zero_xs.append(num_samples)
                zero_ys.append("0")

        # dummy glyphs that exist only to produce the legend entry
        legend_line = legend_fig.line(x=[0], y=[0], color=color)
        legend_marker = legend_fig.scatter(marker=marker, x=[0], y=[0],
                    fill_color=None, line_color=color, size=10)
        dev_fig.line(x=xs, y=ys, color=color)
        dev_fig.scatter(marker=marker, x=xs, y=ys,
                    fill_color=None, line_color=color, size=10)
        runtime_fig.line(x=xs, y=runtime_means, color=color)
        runtime_fig.scatter(marker=marker, x=runtime_xs, y=runtime_ys,
                    fill_color=None, line_color=color, size=10, line_alpha=0.5)
        legend_items.append((name, [legend_line, legend_marker]))
        zero_fig.line(x=zero_xs, y=zero_ys, color=color)
        zero_fig.scatter(marker=marker, x=zero_xs, y=zero_ys,
                    fill_color=None, line_color=color, size=10)

    dev_fig.yaxis.axis_label = "mean deviation"
    runtime_fig.yaxis.axis_label = "runtime [ms]"
    # the zero strip below shares the x range and carries the tick labels
    dev_fig.xaxis.major_label_text_color = None
    zero_fig.xaxis.axis_label = "number of samples"
    runtime_fig.xaxis.axis_label = "number of samples"
    legend_fig.add_layout(Legend(items=legend_items, location="center", click_policy="hide"))
    legend_fig.legend.click_policy="hide"
    # apply the configured backend to every figure, including the runtime plot
    for fig in (legend_fig, dev_fig, runtime_fig, zero_fig):
        fig.output_backend = OUTPUT_BACKEND
    show(row([gridplot([[dev_fig, runtime_fig], [zero_fig]]), legend_fig]), notebook_handle=True)

def _conditions_for(method, truth_mode, sampled_mode, sample_counts, label):
    """Build the condition list for one normalization method.

    One condition is produced per (bin size, window size) combination. It
    pairs the reference values stored under ``(method, truth_mode, ...)``
    with the values and runtimes stored under ``(method, sampled_mode, ...)``
    for every entry of *sample_counts*. The bin-size index selects the color
    and the window-size index the marker in corr_as_func_of_samples.
    """
    conditions = []
    for idxa, bin_size in enumerate(BIN_SIZES):
        for idxb, window_size in enumerate(WINDOW_SIZES):
            ground_truth = data[(method, truth_mode, bin_size, "bin_vals")]
            sample = []
            for num_samples in sample_counts:
                sampled_key = (method, sampled_mode, bin_size, window_size, num_samples)
                sample.append((
                    num_samples,
                    data[sampled_key + ("bin_vals", )],
                    data[sampled_key + ("runtimes", )],
                ))
            name = label + " bin=" + str(bin_size//KBP) + "k win=" + str(window_size//KBP) + "k"
            conditions.append((ground_truth, sample, name, idxa, idxb))
    return conditions


conditions_ice = _conditions_for("ICE", "local", "global", NUM_SAMPLES_ICE, "ICE")
conditions_grid_seq = _conditions_for("grid-seq", "GT", "sampled", NUM_SAMPLES_GRID_SEQ, "grid-s")
conditions_radicl_seq = _conditions_for("radicl-seq", "GT", "sampled", NUM_SAMPLES_RADICL_SEQ, "radicl-s")
conditions_ddd = _conditions_for("ddd", "GT", "sampled", NUM_SAMPLES_DDD, "ddd")

for _conditions in (conditions_ice, conditions_grid_seq, conditions_radicl_seq, conditions_ddd):
    corr_as_func_of_samples(_conditions)


- ICE:
    - window size does not affect results
    - bin size does
- Grid-Seq:
    - actually reaches zero


## Plot some of the heatmaps for visual verification

In [None]:
div = quarry.get_value(["dividend"])
# By default show the area covered by the first three windows of the smallest window size.
DEFAULT_RANGE = (0, 3 * WINDOW_SIZES[0] // div)


def plot_heatmap(datas, bg_color, x_range=DEFAULT_RANGE, y_range=DEFAULT_RANGE):
    """Show several heatmaps side by side with linked axes.

    Parameters:
        datas: iterable of ``(data, title, w_size)``. *data* maps column
            names to equally indexed lists and is expected to contain the
            ``screen_left/right/bottom/top`` columns. *w_size* is the
            window size in canvas units, or None to draw no window borders.
        bg_color: background fill color of every figure.
        x_range, y_range: ``(start, end)``; only bins lying entirely inside
            both ranges are drawn.

    Every figure after the first shares the x and y range of the first one.
    """
    figures = []
    for heatmap, title, w_size in datas:
        fig_kwargs = {"title": title, "width": 300, "height": 300}
        if figures:
            # link pan/zoom to the first figure
            fig_kwargs["x_range"] = figures[0].x_range
            fig_kwargs["y_range"] = figures[0].y_range
        fig = figure(**fig_kwargs)

        def on_screen(idx):
            # True when bin idx lies inside the requested x and y range
            return (
                heatmap["screen_left"][idx] >= x_range[0]
                and heatmap["screen_right"][idx] < x_range[1]
                and heatmap["screen_bottom"][idx] >= y_range[0]
                and heatmap["screen_top"][idx] < y_range[1]
            )

        visible = {
            column: [cell for idx, cell in enumerate(cells) if on_screen(idx)]
            for column, cells in heatmap.items()
        }
        fig.quad(
            left="screen_left",
            bottom="screen_bottom",
            right="screen_right",
            top="screen_top",
            fill_color="color",
            line_color=None,
            source=ColumnDataSource(data=visible),
        )

        if w_size is not None:
            # white lines marking the borders of the queried windows
            x_windows = range(0, x_range[1] + 1, w_size)
            y_windows = range(0, y_range[1] + 1, w_size)
            fig.multi_line(
                xs=[[x, x] for x in x_windows],
                ys=[[y_range[0], y_range[1]] for _ in x_windows],
                color="white",
            )
            fig.multi_line(
                xs=[[x_range[0], x_range[1]] for _ in y_windows],
                ys=[[y, y] for y in y_windows],
                color="white",
            )

        fig.background_fill_color = bg_color
        fig.grid.grid_line_color = None
        fig.axis.ticker = []

        tooltips = [
            (
                "(x, y)",
                "(@chr_x @index_left .. @index_right, @chr_y @index_bottom .. @index_top)",
            ),
            ("score", "@score_total"),
            ("color", "@color"),
            ("reads by group", "A: @score_a, B: @score_b"),
        ]
        fig.add_tools(HoverTool(tooltips=tooltips))
        # NOTE: unlike the other plots, OUTPUT_BACKEND is not applied to these figures.
        figures.append(fig)
    show(gridplot([figures]), notebook_handle=True)

# All heatmaps below use the first bin size and the first window size.
def _global_panel(method, mode):
    """Panel showing the whole-canvas reference heatmap; no window borders."""
    return (data[(method, mode, BIN_SIZES[0], "heatmap")], "global", None)


def _sampled_panel(method, mode, label, num_samples):
    """Panel showing the window-by-window heatmap for *num_samples* samples."""
    return (
        data[(method, mode, BIN_SIZES[0], WINDOW_SIZES[0], num_samples, "heatmap")],
        label + " - num samples = " + str(num_samples),
        WINDOW_SIZES[0] // div,
    )


def _raw_panel():
    """Panel showing the unnormalized heatmap."""
    return (data[("raw", BIN_SIZES[0], "heatmap")], "raw data", None)


if NUM_SAMPLES_ICE:
    plot_heatmap([
        _global_panel("ICE", "local"),
        _sampled_panel("ICE", "global", "ice", NUM_SAMPLES_ICE[-1]),
        _sampled_panel("ICE", "global", "ice", NUM_SAMPLES_ICE[3]),
        _sampled_panel("ICE", "global", "ice", NUM_SAMPLES_ICE[0]),
        _raw_panel(),
        ], "#440154")

if NUM_SAMPLES_GRID_SEQ:
    plot_heatmap([
        _global_panel("grid-seq", "GT"),
        _sampled_panel("grid-seq", "sampled", "grid-seq", NUM_SAMPLES_GRID_SEQ[-1]),
        _sampled_panel("grid-seq", "sampled", "grid-seq", NUM_SAMPLES_GRID_SEQ[0]),
        _raw_panel(),
        ], "#440154")

# the radicl-seq row is shown without a raw-data panel
if NUM_SAMPLES_RADICL_SEQ:
    plot_heatmap([
        _global_panel("radicl-seq", "GT"),
        _sampled_panel("radicl-seq", "sampled", "radicl-seq", NUM_SAMPLES_RADICL_SEQ[-1]),
        _sampled_panel("radicl-seq", "sampled", "radicl-seq", NUM_SAMPLES_RADICL_SEQ[0]),
        ], "#440154")

if NUM_SAMPLES_DDD:
    plot_heatmap([
        _global_panel("ddd", "GT"),
        _sampled_panel("ddd", "sampled", "ddd", NUM_SAMPLES_DDD[-1]),
        _sampled_panel("ddd", "sampled", "ddd", NUM_SAMPLES_DDD[0]),
        _raw_panel(),
        ], "#440154")