In [1]:
import numpy as np
import matplotlib.pyplot as plt
import time
from alibi_detect.cd import CVMDriftOnline, MMDDriftOnline, FETDriftOnline

Importing plotly failed. Interactive plots will not work.


# Continuous data

In [2]:
ntimes = 1
times = np.empty([ntimes])

ert = 150  # Desired expected runtime
N = 100 # Size of reference set
B = 10000 # Number of bootstrap simulations for threshold estimation. 100k more realistic?
window_sizes = [10,20]
n_repeats = 200  # Number of instances in stream

# Note how these are diff to the Gaussian used for configuration.
x_ref = np.random.uniform(0, 1, N)
stream_h0 = iter(lambda: np.random.uniform(), 1)
stream_h1 = iter(lambda: np.random.exponential(), 1)

for t in range(ntimes):
    t0 = time.time()
    ddcvm = CVMDriftOnline(x_ref, ert=ert, window_size=window_sizes, n_bootstraps=B, 
                           verbose=True, device='parallel',t_max=None)
    times[t] = time.time() - t0
    #ddmmd = MMDDriftOnline(x_ref, ert=ert, window_size=10)

print('Wall time = %.3fs +- %.2gs' % (times.mean(),times.std()))

Using 10000 bootstrap simulations to configure thresholds...
Wall time = 0.542s +- 0s


In [3]:
ddcvm.thresholds

array([       nan,        nan,        nan,        nan,        nan,
              nan,        nan,        nan,        nan,        nan,
       4.38037657, 3.73416277, 3.53091781, 3.64742953, 3.42106905,
       3.39730146, 3.30861734, 3.35120192, 3.06926882, 3.12988989,
       3.61416821, 3.4927938 , 3.31976305, 3.33969306, 3.32919234,
       3.38712203, 3.43497215, 3.48692816, 3.36199559, 3.36018666,
       3.46058551, 3.2875086 , 3.55132104, 3.45358981, 3.49452452,
       3.55947306, 3.44803205, 3.58408395, 3.52806717])

In [4]:
n = 0
while n < 1000:
    a = np.array([np.random.uniform(0,1.1)])
    #a = np.array([np.random.randn()])
    ddcvm_result = ddcvm.predict(a)
    #dmmd_result = ddmmd.predict(a)
    #print(ddcvm_result,ddmmd_result)
    if ddcvm_result['data']['is_drift'] == 1:
        print(n)
        break
    n += 1

40


In [5]:
ddcvm.predict(np.array([20]))

{'data': {'is_drift': 0,
  'distance': None,
  'p_val': None,
  'threshold': 3.528067165199367,
  'time': 42,
  'ert': 150,
  'test_stat': array([1.2821587, 1.962883 ], dtype=float32)},
 'meta': {'name': 'CVMDriftOnline',
  'detector_type': 'online',
  'data_type': None,
  'version': '0.7.3dev'}}

# Categorical data

In [6]:
ert = 100
N = 1000
B = 10000
window_sizes = [10,20]
n_repeats = 50
p_h0 = 0.5
p_h1 = 0.3
lam = 0.99

x_ref = np.random.choice(2, N, p=[1-p_h0, p_h0])
stream_h0 = (np.random.choice(2, p=[1-p_h0, p_h0]) for _ in range(int(1e8)))
stream_h1 = (np.random.choice(2, p=[1-p_h1, p_h1]) for _ in range(int(1e8)))

ddfet = FETDriftOnline(x_ref, ert=ert, window_size=window_sizes, n_bootstraps=B, lam=lam)

Using 10000 bootstrap simulations to configure thresholds...


100%|██████████| 2/2 [00:11<00:00,  5.85s/it]


In [19]:
n = 0
while n < 20:
    ddfet_result = ddfet.predict(next(stream_h0))
    print(n)
    if ddfet_result['data']['is_drift'] == 1:
        print(ddfet_result)
        break
    n += 1

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19


In [20]:
ddfet_result

{'data': {'is_drift': 0,
  'distance': None,
  'p_val': None,
  'threshold': 0.951964262940023,
  'time': 120,
  'ert': 100,
  'test_stat': array([0.13585795, 0.3299013 ], dtype=float32)},
 'meta': {'name': 'FETDriftOnline',
  'detector_type': 'online',
  'data_type': None,
  'version': '0.7.3dev'}}

In [24]:
times_h0 = []
while len(times_h0) < n_repeats:
    pred = ddfet.predict(next(stream_h0))
    if pred['data']['is_drift']:
        print(ddfet.t)
        times_h0.append(ddfet.t)
        ddfet.reset()

times_h1 = []
while len(times_h1) < n_repeats:
    pred = ddfet.predict(next(stream_h1))
    if pred['data']['is_drift']:
        print(ddfet.t)
        times_h1.append(ddfet.t)
        ddfet.reset()
        
art = np.array(times_h0).mean() - np.min(window_sizes) + 1
add = np.array(times_h1).mean() - np.min(window_sizes)
# Note art won't be super close to ert as we haven't averaged over initialisations.
print(f'ERT: {ert}')
print(f'ART: {art}')
print(f'ADD: {add}')

27084
91
44
182
121
514
234
260
64
153
99
245
86
39
45
92
23
249
20
139
296
66
21
592
316
205
52
86
261
145
253
65
326
122
84
160
256
19
68
267
28
329
52
120
250
84
49
147
22
75
4391
22410
11898
3071
34953
8384
7810
8005


KeyboardInterrupt: 