In [1]:
import pickle
from detectors import *
import matplotlib.pyplot as plt
import numpy as np
import time
import os

In [2]:
def time_fit(det, data):
    time_begin = time.time()
        
    det = det.fit(np.array(data))
    
    return det, time.time() - time_begin

def time_test(det, data):
    time_begin = time.time()
        
    result = det.predict_proba(np.array(data))
    
    return det, time.time() - time_begin, result


# load data

In [3]:
permutations = 10

modes = ['bert_768', 'bow_50', 'bow_768']
subsets = {}

result_pickle = 'data/results/amazon_same_dist.pickle'

for mode in modes:
    with open('data/movies/embeddings/amazon_{mode}_same_dist.pickle'.format(mode=mode), 'rb') as handle:
        subsets[mode] = permutations_embs, permutation_keys = pickle.load(handle)['data']

# initiate detectors

In [4]:
detectors = {
    'csdd': CosineSimilarityDriftDetector(),
    'kts' : KernelTwoSampleDriftDetector(),
    'aks' : AlibiKSDetector(),
    'ammd': AlibiMMDDetector()
}

# tests

In [5]:

if os.path.isfile(result_pickle):  # Do not overwrite
    print('Loading result pickle: ', result_pickle)
    with open(result_pickle, 'rb') as handle:
        results = pickle.load(handle)
else:
    results = {mode: {detector: {} for detector in detectors} for mode in modes}

for detector in detectors:
    for mode in modes:
        if not detector in results[mode]:
            results[mode][detector] = {}
        if 'predictions' in results[mode][detector]: # skip already computed
            continue
        
        results[mode][detector]['predictions'] = []
        
        det, t = time_fit(detectors[detector], subsets[mode][0][0])
        results[mode][detector]['time_fit'] = t
        results[mode][detector]['time_detect'] = []
        
        for permutation in subsets[mode][0][1:]:
            det_2, t, res = time_test(det, permutation)
            results[mode][detector]['predictions'].append(res)
            results[mode][detector]['time_detect'].append(t)

with open(result_pickle, 'wb') as handle:
    pickle.dump(results, handle)

In [6]:
print(results)

{'bert_768': {'csdd': {'predictions': [0.9992856, 0.9995605, 0.99944663, 0.99965155, 0.99958223, 0.99963987, 0.9996462, 0.99952936, 0.99948895, 0.9995949, 0.9993532, 0.99952286, 0.99970675, 0.9994862, 0.9995035, 0.9997282, 0.999526, 0.99963486, 0.99963284, 0.9995447], 'time_fit': 0.004500150680541992, 'time_detect': [0.03791451454162598, 0.03746151924133301, 0.037461042404174805, 0.03737640380859375, 0.03734922409057617, 0.03723430633544922, 0.03693675994873047, 0.03695797920227051, 0.03699207305908203, 0.03790092468261719, 0.037894248962402344, 0.0373685359954834, 0.037180185317993164, 0.03728008270263672, 0.0371699333190918, 0.037137746810913086, 0.03725481033325195, 0.03724098205566406, 0.03738212585449219, 0.03695559501647949]}, 'kts': {'predictions': [0.83, 0.742, 0.974, 0.91, 0.998, 0.686, 0.596, 0.978, 0.958, 0.986, 0.712, 0.956, 0.952, 0.754, 0.956, 0.838, 0.908, 0.746, 0.982, 0.838], 'time_fit': 0.002387523651123047, 'time_detect': [230.7145323753357, 230.70965576171875, 231.7