In [1]:
import pickle
from detectors import *
import matplotlib.pyplot as plt
import numpy as np
import time
import os

In [2]:
def time_fit(det, data, labels):
    time_begin = time.time()
        
    det = det.fit(np.array(data), targets=np.array(labels))
    
    return det, time.time() - time_begin

def time_test(det, data):
    time_begin = time.time()
        
    result = det.predict_proba(np.array(data))
    
    return det, time.time() - time_begin, result


# load data

In [3]:
modes = ['bert_768', 'bow_50', 'bow_768']
subsets = {}

result_pickle = 'data/results/twitter_same_dist.pickle'

for mode in modes:
    with open('data/twitter/twitter_{mode}_same_dist.pickle'.format(mode=mode), 'rb') as handle:
        subsets[mode] = permutations_embs, permutation_keys = pickle.load(handle)['data']

# initiate detectors

In [4]:
detectors = {
    'csdd': CosineSimilarityDriftDetector(),
    'kts' : KernelTwoSampleDriftDetector(),
    'aks' : AlibiKSDetector(),
    'ammd': AlibiMMDDetector(),
    'lsdd': AlibiLSDDDetector(),
    'cdbd': CDBDDetector()
}

# tests

In [5]:

if os.path.isfile(result_pickle):  # Do not overwrite
    print('Loading result pickle: ', result_pickle)
    with open(result_pickle, 'rb') as handle:
        results = pickle.load(handle)
else:
    results = {mode: {detector: {} for detector in detectors} for mode in modes}

for detector in detectors:
    for mode in modes:
        if not detector in results[mode]:
            results[mode][detector] = {}
        if 'predictions' in results[mode][detector]: # skip already computed
            continue
        
        results[mode][detector]['predictions'] = []
        
        det, t = time_fit(detectors[detector],
                          subsets[mode][0][0],
                          [int(x//(len(subsets[mode][0][0])/2)) for x in range(len(subsets[mode][0][0]))]
                         )
        results[mode][detector]['time_fit'] = t
        results[mode][detector]['time_detect'] = []
        
        for permutation in subsets[mode][0][1:]:
            det_2, t, res = time_test(det, permutation)
            results[mode][detector]['predictions'].append(res)
            results[mode][detector]['time_detect'].append(t)

with open(result_pickle, 'wb') as handle:
    pickle.dump(results, handle)

Loading result pickle:  data/results/twitter_same_dist.pickle


In [6]:
print(results)

{'bert_768': {'csdd': {'predictions': [0.9995134, 0.9995103, 0.9995152, 0.9995761, 0.9996051, 0.9994821, 0.99953544, 0.9995222, 0.9993547, 0.99950904, 0.9992883, 0.99955773, 0.99939084, 0.99962527, 0.9994445, 0.9992125, 0.9995525, 0.9996222, 0.9995817, 0.9994236], 'time_fit': 0.003317117691040039, 'time_detect': [0.022496700286865234, 0.0243532657623291, 0.024049043655395508, 0.02386188507080078, 0.028882741928100586, 0.02582693099975586, 0.015207052230834961, 0.015634536743164062, 0.015811920166015625, 0.015898704528808594, 0.015833377838134766, 0.015721797943115234, 0.015622138977050781, 0.015715599060058594, 0.01577138900756836, 0.01553201675415039, 0.015683412551879883, 0.015762805938720703, 0.01572132110595703, 0.015883922576904297]}, 'kts': {'predictions': [0.39, 0.396, 0.278, 0.606, 0.87, 0.282, 0.44, 0.254, 0.126, 0.472, 0.03, 0.394, 0.076, 0.796, 0.11, 0.002, 0.46, 0.784, 0.636, 0.18], 'time_fit': 0.0009024143218994141, 'time_detect': [26.428621768951416, 29.22772717475891, 29