In [1]:
import sys
sys.path.append('../')

import numpy as np
import itertools
import time
from knock_off.association_measures import distance_corr, projection_corr, tr, HSIC, MMD, pearson_correlation
from tqdm import trange, tqdm
from memory_profiler import memory_usage
import plotly.graph_objects as go


In [3]:

# --- Benchmark configuration ---------------------------------------------
# Sample sizes n at which every association measure is run.
n_list = [100, 200, 300, 400, 500]

# Number of columns of the design matrix X. The benchmark loop divides each
# measured duration by p, so p plays the role of a repetition count
# ("in a way, the number of repetitions").
p = 500

# (label, callable) pairs for every association measure under test; the
# labels are reused as dictionary keys and legend entries.
AM = [
    ("DC", distance_corr),
    ("PC", projection_corr),
    ("TR", tr),
    ("HSIC", HSIC),
    ("MMD", MMD),
    ("Pearson", pearson_correlation),
]

# One empty result list per measure label: timings and memory figures.
recordings = {}
memory = {}
for label, _measure in AM:
    recordings[label] = []
    memory[label] = []

In [4]:

# Benchmark every association measure in AM on Gaussian data of increasing
# sample size n. For each (n, measure) pair this records:
#   * recordings[name]: wall-clock seconds of ONE call, divided by p
#     (the number of columns of X);
#   * memory[name]: peak memory reported by memory_profiler during the call
#     minus the baseline sample it takes just before the call (MiB).
#
# The duration is measured *inside* the profiled run (via _timed_call) rather
# than around memory_usage(): memory_usage starts a monitoring subprocess and
# may re-run the function at a finer sampling interval when it gathers too few
# samples, so timing around it would include profiler overhead and possibly
# several executions of the measure.


def _timed_call(func, *args):
    """Run ``func(*args)`` once and return its wall-clock duration in seconds."""
    start = time.perf_counter()
    func(*args)
    return time.perf_counter() - start


rng = np.random.default_rng(0)  # fixed seed so the inputs are reproducible

# Re-initialise the result containers so re-running this cell replaces the
# previous measurements instead of appending a second batch to them.
recordings = {name: [] for name, _ in AM}
memory = {name: [] for name, _ in AM}

for n in tqdm(n_list):
    X = rng.standard_normal((n, p))
    Y = rng.standard_normal((n, 1))
    for name, am in AM:
        # retval=True returns (memory samples, value returned by _timed_call);
        # the elapsed time comes from the last profiled execution.
        mem_usage, elapsed = memory_usage((_timed_call, (am, X, Y)), retval=True)
        recordings[name].append(elapsed / p)
        # mem_usage[0] is the baseline sampled before the call starts.
        memory[name].append(max(mem_usage) - mem_usage[0])


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [2:01:33<00:00, 1458.72s/it]


In [16]:

# Plot the timings (recordings) and memory figures (memory) collected by the
# benchmark loop against the sample size n, one line per association measure.

# Fixed colour per measure so both figures use the same legend colours.
colors = {"MMD": "rgb(0, 186, 56)",
          "TR": "rgb(245, 100, 227)",
          "HSIC": "rgb(183, 159, 0)",
          "Pearson": "rgb(97, 156, 255)",
          "PC": "rgb(0, 191, 196)",
          "DC": "rgb(248, 118, 109)"}

# Every benchmarked measure must have a colour, otherwise it would silently
# be left out of the plots.
assert set(colors) == set(recordings) == set(memory), "colors/results mismatch"

fig_cpu = go.Figure()
fig_mem = go.Figure()

for name, col in colors.items():
    fig_cpu.add_trace(go.Scatter(x=n_list, y=recordings[name],
                                 marker=dict(color=col), name=name))
    fig_mem.add_trace(go.Scatter(x=n_list, y=memory[name],
                                 marker=dict(color=col), name=name))

# Both figures share the same styling; only the title and y-axis label differ.
# Timings were divided by p (columns of X), and memory_profiler reports MiB.
for fig, title, y_title in [
    (fig_cpu, "Monitoring computation time", "time per column of X (s)"),
    (fig_mem, "Monitoring RAM consumption", "RAM consumption (MiB)"),
]:
    fig.update_layout(template="ggplot2", legend_title_text='Algorithm',
                      title={'text': title, 'x': 0.85, 'y': 0.88},
                      font=dict(size=22))
    fig.update_xaxes(title="n")
    fig.update_yaxes(title=y_title)

fig_cpu.show()
fig_mem.show()