In [1]:
import pickle
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics.pairwise import rbf_kernel
from tqdm import tqdm

from mmdew.fast_rbf_kernel import est_gamma
from mmdew.mmdew import MMDEW
from notebooks.data import MixedNormal, Uniform, Laplace

In [2]:
# Experiment configuration.
d = 20           # feature dimension of every sample drawn in this notebook
ref_size = 1000  # NOTE(review): not used in the visible cells; the H1 loop draws 10000 reference samples — confirm which is intended
SEED = 0         # fixed so that Restart & Run All reproduces the reported numbers
# This seeds only `rng`; the q samplers from notebooks.data may draw from
# their own RNG — TODO confirm they are seeded as well.
rng = np.random.default_rng(SEED)

In [3]:
# Detector statistics used below to calibrate the ARL thresholds via
# quantiles (presumably collected from no-change / H0 runs — confirm).
# pickle.load can execute arbitrary code: only load files you produced.
stats_path = 'mmdew-statistics.pickle'
with open(stats_path, 'rb') as stats_file:
    statistics = pickle.load(stats_file)

In [4]:
# Target average run lengths on a log10 scale: 10^3, 10^3.25, ..., 10^5.
target_arls_log = np.arange(3, 5.1, 0.25)

# Threshold per target: the (1 - 10^-k) quantile of the calibration
# statistics, i.e. roughly a 10^-k fraction of them lie above it.
arl2thresh = {}
for log_arl in target_arls_log:
    arl2thresh[log_arl] = np.quantile(statistics, 1 - (1 / 10 ** log_arl))

In [5]:
def edd(arl2thresh, statistics):
    """Estimate the expected detection delay (EDD) for each threshold.

    Parameters
    ----------
    arl2thresh : dict
        Maps a target-ARL label (in this notebook, log10 of the ARL) to the
        detection threshold for that target.
    statistics : iterable of sequences of float
        One sequence of detector statistics per trial (list or 1-D array),
        whose first element belongs to the first observation to be counted.

    Returns
    -------
    dict
        Maps each key of ``arl2thresh`` to the mean 1-based index of the first
        statistic strictly above the threshold. A trial that never exceeds
        the threshold contributes ``len(trial) + 1``.
    """
    # Materialise once so generators can be reused across thresholds.
    trials = [np.asarray(s, dtype=float) for s in statistics]
    arl2edd = {}
    for arl, thresh in arl2thresh.items():
        delays = [
            # Appending +inf guarantees at least one hit, so a trial with no
            # detection yields len(trial) instead of argmax's default of 0.
            np.argmax(np.append(trial, np.inf) > thresh)
            for trial in trials
        ]
        arl2edd[arl] = np.mean(delays) + 1  # argmax is 0-based; delays are 1-based
    return arl2edd

In [6]:
# Post-change distributions q to evaluate. `d` is taken from the
# configuration cell instead of being redefined here, so changing it there
# takes effect everywhere. n_q is passed to every sampler (presumably the
# number of samples returned by draw() — confirm in notebooks.data).
n_q = 500
qs = {
    "MixedNormal0.3": MixedNormal(n_q, d, 0.3),
    "MixedNormal0.7": MixedNormal(n_q, d, 0.7),
    "Laplace": Laplace(n_q, d),
    "Uniform": Uniform(n_q, d),
}

In [7]:
# Start from an empty results table.
df = pd.DataFrame(data=None)

In [8]:
N_TRIALS = 100       # Monte Carlo repetitions per post-change distribution
N_WARMUP = 64        # reference samples inserted before the q samples
N_REF_GAMMA = 10000  # NOTE(review): config cell defines ref_size=1000; confirm which size gamma estimation should use

# Collect one EDD row per q and concatenate once at the end, instead of
# growing a DataFrame (seeded with an empty frame) inside the loop.
edd_rows = []
for name, q in qs.items():
    # Kept as a notebook-level name: an `edd` that reads the global
    # h1_stats still sees the current distribution's statistics.
    h1_stats = []
    for _ in tqdm(range(N_TRIALS), desc=name):
        ref = rng.normal(size=(N_REF_GAMMA, d))
        gamma = est_gamma(ref)
        detector = MMDEW(gamma=gamma)

        # Warm-up: feed reference samples one at a time as (1, d) rows.
        for elem in ref[:N_WARMUP]:
            detector.insert(elem.reshape(1, -1))

        # Post-change: feed the samples returned by q.draw().
        for elem in q.draw():
            detector.insert(elem.reshape(1, -1))

        # Drop the warm-up portion (assumes one stat per insert — TODO confirm).
        h1_stats.append(detector.stats[N_WARMUP:])
    edd_rows.append(
        pd.DataFrame(edd(arl2thresh=arl2thresh, statistics=h1_stats), index=[name])
    )
df = pd.concat(edd_rows)

100%|█████████████████████████████████████████| 100/100 [01:51<00:00,  1.11s/it]
100%|█████████████████████████████████████████| 100/100 [01:59<00:00,  1.20s/it]
100%|█████████████████████████████████████████| 100/100 [01:58<00:00,  1.19s/it]
100%|█████████████████████████████████████████| 100/100 [02:01<00:00,  1.22s/it]


In [9]:
# Move the distribution names from the index into a "data" column.
df = df.rename_axis("data").reset_index()

In [10]:
# Mean detection delay: one row per q distribution ("data" column), one column per log10 target ARL.
df

Unnamed: 0,data,3.0,3.25,3.5,3.75,4.0,4.25,4.5,4.75,5.0
0,MixedNormal0.3,1.62,1.62,1.64,1.65,1.68,1.71,1.79,1.82,1.82
1,MixedNormal0.7,3.06,3.14,3.15,3.22,3.25,3.35,3.38,3.46,3.46
2,Laplace,1.01,1.01,1.01,1.01,1.01,1.01,1.01,1.01,1.03
3,Uniform,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0


In [11]:
# Reshape to long format: one (data, logARL, EDD) row per measurement.
df = pd.melt(df, id_vars="data", var_name="logARL", value_name="EDD")

In [12]:
# Label every row with the detector that produced it.
df = df.assign(algorithm="MMDEW")

In [13]:
# Persist the long-format results. DataFrame.to_csv does not create missing
# directories, so ensure the relative output directory exists first.
OUTPUT_CSV = Path("../results_rebuttal/arl-vs-edd/mmdew.csv")
OUTPUT_CSV.parent.mkdir(parents=True, exist_ok=True)
df.to_csv(OUTPUT_CSV)