In [None]:
import zarr
from scipy.spatial.distance import squareform
import numpy as np
import anjl
import pandas as pd
import plotly.express as px
import sys
import time
from contextlib import contextmanager
from tqdm.auto import tqdm


@contextmanager
def section(*msg):
    """Print a label, run the ``with`` body, then print how long it took.

    The positional arguments are printed to stdout exactly as ``print(*msg)``
    would, followed by ``"... "`` and no newline. When the body finishes --
    normally or by raising -- the elapsed time is appended on the same line
    as ``"<seconds>s"`` with four decimal places. Exceptions raised by the
    body propagate unchanged.
    """
    print(*msg, file=sys.stdout, end="... ", flush=True)
    # perf_counter() is monotonic and high resolution, so the reported
    # duration cannot be skewed (or go negative) if the system clock is
    # adjusted while the body runs, which time.time() does not guarantee.
    started = time.perf_counter()
    try:
        yield
    finally:
        elapsed = time.perf_counter() - started
        print(f"{elapsed:.4f}s", file=sys.stdout, flush=True)


def run_diagnostics(D, template="plotly"):
    """Run the anjl NJ variants on ``D`` and show their per-iteration diagnostics.

    ``anjl.canonical_nj`` is run once and ``anjl.rapid_nj`` three times
    (``gc=100``, ``gc=10`` and ``gc=None``), all with ``diagnostics=True``.
    Each run is wrapped in ``section`` so its total duration is printed.
    The ``timings``, ``searched`` and ``visited`` sequences returned by each
    run are collected into one long-format DataFrame with an ``iteration``
    counter and an ``algorithm`` label. From that frame the function

    * displays a per-algorithm summary (sums of ``time``, ``visited`` and
      ``searched``; medians of ``time_per_visit`` and ``time_per_search``);
    * shows one line plot per metric, coloured by algorithm.

    Parameters
    ----------
    D
        Input matrix handed unchanged to the anjl functions (the notebook
        passes the output of ``scipy.spatial.distance.squareform``).
    template : str, default "plotly"
        Plotly template name forwarded to ``plotly.express.line``.

    Returns
    -------
    None
        Output is printed, displayed and plotted; nothing is returned.
    """
    # (label, function, extra keyword arguments), listed in execution order.
    runs = (
        ("canonical", anjl.canonical_nj, {}),
        ("rapid_gc100", anjl.rapid_nj, {"gc": 100}),
        ("rapid_gc10", anjl.rapid_nj, {"gc": 10}),
        ("rapid_nogc", anjl.rapid_nj, {"gc": None}),
    )

    # Execute every run before building any DataFrame so that nothing but
    # the NJ calls themselves happens between the timed sections.
    results = {}
    for label, nj, options in runs:
        with section(label):
            _, timings, searched, visited = nj(D, diagnostics=True, **options)
        results[label] = (timings, searched, visited)

    frames = {
        label: pd.DataFrame(
            {
                "time": timings,
                "searched": searched,
                "visited": visited,
                "iteration": np.arange(len(timings)),
                "algorithm": label,
            }
        )
        for label, (timings, searched, visited) in results.items()
    }

    # Row order decides the order in which plotly assigns colours and
    # legend entries, so the rapid variants come first and canonical last.
    plot_order = ("rapid_gc100", "rapid_gc10", "rapid_nogc", "canonical")
    df_diagnostics = pd.concat([frames[label] for label in plot_order], axis=0)
    df_diagnostics["time_per_search"] = df_diagnostics.eval("time / searched")
    df_diagnostics["time_per_visit"] = df_diagnostics.eval("time / visited")

    # `display` is the IPython built-in available inside notebook cells.
    display(
        df_diagnostics.groupby("algorithm").agg(
            {
                "time": "sum",
                "visited": "sum",
                "searched": "sum",
                "time_per_visit": "median",
                "time_per_search": "median",
            }
        )
    )

    # (column to plot, y-axis range). The raw metrics start at zero with an
    # automatic upper bound; the per-operation ratios are clipped to 30e-9
    # so a few extreme iterations do not flatten the rest of the curve.
    plots = (
        ("time", [0, None]),
        ("searched", [0, None]),
        ("visited", [0, None]),
        ("time_per_search", [0, 30e-9]),
        ("time_per_visit", [0, 30e-9]),
    )
    for column, y_range in plots:
        fig = px.line(
            df_diagnostics,
            x="iteration",
            y=column,
            color="algorithm",
            template=template,
            render_mode="svg",
        )
        fig.update_yaxes(range=y_range)
        fig.show()

## Small

In [None]:
# Load the "small" dataset and pass it through scipy's squareform; the last
# expression shows the resulting shape as the cell output.
# NOTE(review): presumably dist.zarr.zip holds a condensed distance vector
# that squareform expands into a square matrix -- confirm against the data.
small = zarr.load("../data/small/dist.zarr.zip")
small_D = squareform(small)
small_D.shape

In [None]:
# One untimed canonical run.
# NOTE(review): presumably a warm-up so one-off start-up cost stays out of
# the %%timeit cells below -- confirm.
small_Z = anjl.canonical_nj(small_D)

In [None]:
# One untimed rapid run with gc=None.
small_Z_r = anjl.rapid_nj(small_D, gc=None)

In [None]:
# One untimed rapid run with gc=1; this rebinds small_Z_r from the cell above.
small_Z_r = anjl.rapid_nj(small_D, gc=1)

In [None]:
%%timeit -r500 -n1
# Canonical: 500 repeats of one call each.
anjl.canonical_nj(small_D)

In [None]:
%%timeit -r500 -n1
# Rapid with gc=None: 500 repeats of one call each.
anjl.rapid_nj(small_D, gc=None)

In [None]:
%%timeit -r500 -n1
# Rapid with gc=100: 500 repeats of one call each.
anjl.rapid_nj(small_D, gc=100)

In [None]:
# Per-iteration diagnostics (summary table and plots) for the small dataset.
run_diagnostics(small_D)

## Medium

In [None]:
# Load the "medium" dataset and pass it through scipy's squareform; the last
# expression shows the resulting shape as the cell output.
# NOTE(review): presumably dist.zarr.zip holds a condensed distance vector
# that squareform expands into a square matrix -- confirm against the data.
medium = zarr.load("../data/medium/dist.zarr.zip")
medium_D = squareform(medium)
medium_D.shape

In [None]:
%%time
# Single timed canonical run.
medium_Z = anjl.canonical_nj(medium_D)

In [None]:
%%time
# Single timed rapid run using rapid_nj's default gc setting.
medium_Z_r = anjl.rapid_nj(medium_D)

In [None]:
%%timeit -r200 -n1
# Canonical: 200 repeats of one call each.
anjl.canonical_nj(medium_D)

In [None]:
%%timeit -r100 -n1
# Rapid with gc=None: 100 repeats of one call each.
anjl.rapid_nj(medium_D, gc=None)

In [None]:
%%timeit -r100 -n1
# Rapid with gc=10: 100 repeats of one call each.
anjl.rapid_nj(medium_D, gc=10)

In [None]:
%%timeit -r100 -n1
# Rapid with gc=100: 100 repeats of one call each.
anjl.rapid_nj(medium_D, gc=100)

In [None]:
%%timeit -r100 -n1
# Rapid with gc=1: 100 repeats of one call each.
anjl.rapid_nj(medium_D, gc=1)

In [None]:
# Per-iteration diagnostics (summary table and plots) for the medium dataset.
run_diagnostics(medium_D)

In [None]:
# Same call as the previous cell.
# NOTE(review): looks like a deliberate repeat to compare run-to-run
# variation -- confirm, or drop the duplicate cell.
run_diagnostics(medium_D)

## Large

In [None]:
# Load the "large" dataset and pass it through scipy's squareform.
large = zarr.load("../data/large/dist.zarr.zip")
large_D = squareform(large)
# Draw the random subsample from a fixed-seed generator so that every timing
# cell below sees the same matrix after a kernel restart; an unseeded draw
# makes timings from different sessions incomparable. The size is capped at
# the number of rows so a smaller dataset does not make choice() raise.
rng = np.random.default_rng(42)
shuffle = rng.choice(
    large_D.shape[0],
    size=min(2000, large_D.shape[0]),
    replace=False,
)
# Apply the same index selection to rows and columns.
large_D_shuffled = large_D.take(shuffle, axis=0).take(shuffle, axis=1)
run_diagnostics(large_D_shuffled)

In [None]:
%%timeit -r5 -n1
# Canonical on the subsampled large matrix: 5 repeats of one call each.
anjl.canonical_nj(large_D_shuffled)

In [None]:
%%timeit -r5 -n1
# Rapid with gc=10: 5 repeats of one call each.
anjl.rapid_nj(large_D_shuffled, gc=10)

In [None]:
%%timeit -r5 -n1
# Rapid with gc=100: 5 repeats of one call each.
anjl.rapid_nj(large_D_shuffled, gc=100)

In [None]:
%%timeit -r5 -n1
# Rapid with gc=200: 5 repeats of one call each.
anjl.rapid_nj(large_D_shuffled, gc=200)

In [None]:
%%timeit -r5 -n1
# Rapid with gc=None: 5 repeats of one call each.
anjl.rapid_nj(large_D_shuffled, gc=None)

In [None]:
%%timeit -r3 -n1
# Time a row-wise argsort of the same matrix.
# NOTE(review): presumably a reference point for the sorting cost inside
# rapid_nj -- confirm.
np.argsort(large_D_shuffled, axis=1)

## XXL

In [None]:
# Load the "xxl" dataset and pass it through scipy's squareform. The full
# matrix is used below; the commented-out lines are a disabled 7000-sample
# subsample + diagnostics run in the style of the "Large" section.
xxl = zarr.load("../data/xxl/dist.zarr.zip")
xxl_D = squareform(xxl)
# xxl_shuffle = np.random.choice(xxl_D.shape[0], size=7000, replace=False)
# xxl_D_shuffled = xxl_D.take(xxl_shuffle, axis=0).take(xxl_shuffle, axis=1)
# run_diagnostics(xxl_D_shuffled)

In [None]:
%%time
# Single timed rapid run (gc=100) on the full matrix; tqdm is passed as the
# `progress` argument.
anjl.rapid_nj(xxl_D, gc=100, progress=tqdm)

In [None]:
%%time
# Single timed canonical run on the full matrix; tqdm is passed as the
# `progress` argument.
anjl.canonical_nj(xxl_D, progress=tqdm)