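This notebook measures query time, recall, and index memory footprint for the baseline solution: exact (brute-force) search with a FAISS `IndexFlatL2` index over the L2-normalised 64-dimensional embeddings.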

In [1]:
import pickle
import collections
import matplotlib.pyplot as plt
import faiss
import numpy as np
import itertools
import statistics
import matplotlib.cm as cm
import time
import os
from time import perf_counter
from time import perf_counter_ns

In [2]:
# read in embeddings and cluster label info
# NOTE: pickle.load can execute arbitrary code — only load these files from a trusted source.
with open('embedding64.pickle', 'rb') as fp:
    embedding64 = pickle.load(fp)
with open('label_info.pickle', 'rb') as fp:
    label_info = pickle.load(fp)

print(embedding64.keys())
print(label_info.keys())

# true labels per item, indexed by position j
true_major = label_info['cluster label']
true_minor = label_info['cluster label minor']
# the two label sequences are walked in lock-step below, so they must align
assert len(true_major) == len(true_minor), "major/minor label lists differ in length"

# distinct cluster labels at each granularity
major_labels = list(set(true_major))
minor_labels = list(set(true_minor))

# label -> list of item positions carrying that label (the ground-truth
# neighbour set used for recall); list lengths give the cluster n counts
major = {label: [] for label in major_labels}
minor = {label: [] for label in minor_labels}

# every label is a key by construction, so no membership check is needed;
# names avoid shadowing the builtin `min`
for j, (major_label, minor_label) in enumerate(zip(true_major, true_minor)):
    major[major_label].append(j)
    minor[minor_label].append(j)

dict_keys(['embed_all', 'embed_raw', 'embed_l2_norm', 'restore_order', 'embed_correct_coverage_fh', 'embed_l2_norm_correct_coverage_fh'])
dict_keys(['batch id', 'age', 'total_cg', 'average_cg_rate', 'total_ch', 'average_ch_rate', 'hic_counts', 'cell_name_higashi', 'major', 'minor', 'cluster label', 'cluster label minor'])


In [3]:
# function to get memory footprint for index
# source: https://www.pinecone.io/learn/series/faiss/product-quantization/
def get_memory(index):
    """Return the size in bytes of `index` when serialized with faiss.write_index.

    The index is written into a fresh temporary directory instead of a fixed
    './temp.index' in the working directory, so an existing file of that name
    is never overwritten and nothing is left behind if writing or sizing fails
    (the directory is removed when the `with` block exits, even on error).
    """
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        index_path = os.path.join(tmp_dir, 'temp.index')
        faiss.write_index(index, index_path)
        return os.path.getsize(index_path)
    
# functions to calculate recall based on search results
def _cluster_recall(cluster_members, result):
    """Return recall = TP / cluster size, where TP = ids in `result` that are in `cluster_members`."""
    return len(set(cluster_members).intersection(result)) / len(cluster_members)


def get_recall_min(i, result_min, clusters=None, labels=None):
    """Recall of a search for item `i` measured against its true minor cluster.

    Args:
        i: position of the query item.
        result_min: ids returned by the search (e.g. one row of FAISS's I matrix).
        clusters: mapping label -> list of member positions; defaults to global `minor`.
        labels: true label of each position; defaults to global `true_minor`.

    Returns:
        Fraction of the true cluster present in `result_min` (float in [0, 1]).
    """
    if clusters is None:
        clusters = minor
    if labels is None:
        labels = true_minor
    return _cluster_recall(clusters[labels[i]], result_min)


def get_recall_maj(i, result_maj, clusters=None, labels=None):
    """Recall of a search for item `i` measured against its true major cluster.

    Args:
        i: position of the query item.
        result_maj: ids returned by the search (e.g. one row of FAISS's I matrix).
        clusters: mapping label -> list of member positions; defaults to global `major`.
        labels: true label of each position; defaults to global `true_major`.

    Returns:
        Fraction of the true cluster present in `result_maj` (float in [0, 1]).
    """
    if clusters is None:
        clusters = major
    if labels is None:
        labels = true_major
    return _cluster_recall(clusters[labels[i]], result_maj)

In [4]:
# create input database
# FAISS operates on C-contiguous float32 matrices. Converting once here (as a
# copy) avoids relying on the wrapper to convert — or reject — the data on every
# index.add / index.search call, which would otherwise land inside the timed loop.
database = np.array(embedding64["embed_l2_norm"], dtype=np.float32, order='C')
assert database.ndim == 2, "expected a (n_items, dim) embedding matrix"

In [5]:
def query_rep(index, query, k, n_repeats=3, n_iters=100):
    """Time repeated k-NN searches of `query` against `index`.

    Each timing run performs `n_iters` consecutive searches; `n_repeats` runs
    are made. The defaults reproduce the original 3 runs x 100 searches.

    Args:
        index: object with a `search(query, k)` method returning (distances, ids).
        query: query matrix passed unchanged to `index.search`.
        k: number of neighbours to return.
        n_repeats: number of independent timing runs (>= 1).
        n_iters: searches per timing run (>= 1).

    Returns:
        (I, times): `I` is the id matrix from the last search; `times` is a list
        of `n_repeats` durations in nanoseconds, each covering `n_iters`
        searches, sorted ascending so `times[0]` is the fastest run.

    Raises:
        ValueError: if `n_repeats` or `n_iters` is less than 1.
    """
    if n_repeats < 1 or n_iters < 1:
        raise ValueError("n_repeats and n_iters must both be >= 1")

    times = []
    for _ in range(n_repeats):
        start = perf_counter_ns()
        for _ in range(n_iters):
            _distances, I = index.search(query, k)
        times.append(perf_counter_ns() - start)

    times.sort()
    return I, times

In [6]:
# define experiment function
def baseline_experiment():
    """Benchmark exact (flat L2) k-NN search, using every database vector as a query.

    For each vector i, two searches are run: one with k = size of i's major
    cluster and one with k = size of i's minor cluster. Recall is the fraction
    of the true cluster that is returned (get_recall_maj / get_recall_min).
    Speed is the fastest timing run reported by query_rep, i.e. nanoseconds for
    a whole batch of repeated searches, not for a single search.

    Returns:
        recall_min, recall_maj, speed_min, speed_maj: 1-D float arrays with one
            entry per database vector.
        memory: 0-d array holding the serialized index size in bytes (get_memory).
    """
    # derive sizes from the data instead of hard-coding 4238 vectors x 64 dims
    n_vectors, dim = database.shape

    # initialize empty arrays for results
    recall_min = np.zeros(n_vectors)
    recall_maj = np.zeros(n_vectors)
    speed_min = np.zeros(n_vectors)
    speed_maj = np.zeros(n_vectors)

    # create flat (exact, brute-force L2) index
    index = faiss.IndexFlatL2(dim)
    index.add(database)

    # how much memory is used?
    memory = np.array(get_memory(index))

    for i in range(n_vectors):
        # keep a 2-D (1, dim) shape: faiss.search takes a batch of queries
        query = database[i:i + 1]
        # k = size of the true cluster, i.e. how many neighbours to return
        k_major = len(major[true_major[i]])
        k_minor = len(minor[true_minor[i]])

        # major cluster query
        I, times = query_rep(index, query, k_major)
        speed_maj[i] = times[0]
        recall_maj[i] = get_recall_maj(i, I[0])

        # minor cluster query
        I, times = query_rep(index, query, k_minor)
        speed_min[i] = times[0]
        recall_min[i] = get_recall_min(i, I[0])

    # return all results
    return recall_min, recall_maj, speed_min, speed_maj, memory

In [7]:
# Set RUN_EXPERIMENT = False to reload previously saved results instead of
# re-running the full benchmark.
RUN_EXPERIMENT = True
RESULTS_PATH = os.path.join('FLAT', 'Baseline.npy')

if RUN_EXPERIMENT:
    recall_min_BASELINE, recall_maj_BASELINE, speed_min_BASELINE, speed_maj_BASELINE, memory_BASELINE = baseline_experiment()

    # open() does not create missing directories
    os.makedirs(os.path.dirname(RESULTS_PATH), exist_ok=True)
    # arrays are written one after another into a single file; they must be
    # loaded back in exactly this order
    with open(RESULTS_PATH, 'wb') as f:
        for result in (recall_min_BASELINE, recall_maj_BASELINE, speed_min_BASELINE,
                       speed_maj_BASELINE, memory_BASELINE):
            np.save(f, result)
else:
    with open(RESULTS_PATH, 'rb') as f:
        recall_min_BASELINE = np.load(f)
        recall_maj_BASELINE = np.load(f)
        speed_min_BASELINE = np.load(f)
        speed_maj_BASELINE = np.load(f)
        memory_BASELINE = np.load(f)

In [8]:
# Mean of each metric, printed one per line in this order:
#   recall (major), recall (minor), speed (major), speed (minor).
# Speeds are nanoseconds for the fastest timed batch of searches (see query_rep).
for metric in (recall_maj_BASELINE, recall_min_BASELINE, speed_maj_BASELINE, speed_min_BASELINE):
    print(statistics.mean(metric))

0.7984170792592759
0.6215873936903249
11708813.992449269
8936997.569608307
