In [None]:
# groundtruth starting point

import pandas as pd
from sentence_transformers import SentenceTransformer
import faiss
from sketch.api import data
import json
import numpy as np
import time
import os

run_name = "220910_2ary_groundtruth"
# Root directory that contains the sketch checkout. Override with
# SKETCH_BASE_PATH so the notebook is not tied to one user's home directory.
base_path = os.environ.get("SKETCH_BASE_PATH", "/home/jawaugh")

# Every input/output of this run lives in one directory and shares the run_name prefix.
iterations_dir = os.path.join(base_path, 'sketch/sketch/examples/Text2SQL_Iterations')
groundtruth_path = os.path.join(iterations_dir, f'{run_name}.parquet')
knn_path = os.path.join(iterations_dir, f'{run_name}_knn.parquet')
sketchpad_path = os.path.join(iterations_dir, f'{run_name}_sketchpad.parquet')

prebuilt_index = os.path.join(base_path, 'sketch/sketch/trained.index')
# Filesystem path of the SQLite file. It is deliberately named differently from
# `database`, which the next cell binds to the connected data.Database object.
database_file = os.path.join(base_path, 'sketch/sketch/test.db')
database_path = f'sqlite+aiosqlite:///{database_file}'
# User whose most recent sketchpads are fetched; override with SKETCH_USERNAME.
username = os.environ.get("SKETCH_USERNAME", "justin")

In [None]:
database = data.Database(database_path)
await database.connect()
model = SentenceTransformer("all-MiniLM-L6-v2")
index = faiss.read_index(prebuilt_index)

In [None]:
# extract = {}

# # Build index... (since it seems like there's possibly an error)
# async def build_index():
#     # iterate over references, and add to index
#     short_ids, references = zip(*[x async for x in data.get_references(database)])
#     print("references_gathered")
#     embeddings = model.encode([r.to_searchable_string() for r in references])
#     print("searchable_strings_gathered")
    
#     index = faiss.IndexFlatL2(embeddings.shape[1])
#     # index = faiss.IndexHNSWFlat(embeddings.shape[1], 32)
#     # index.hnsw.efConstruction = 40
#     index2 = faiss.IndexIDMap(index)
#     index2.add_with_ids(
#         embeddings,
#         np.array(short_ids, dtype=np.int64),
#     )
#     index = index2
#     faiss.write_index(index2, prebuilt_index)

# await build_index()

In [None]:
# For each (left, right) pair, fetch the k nearest-neighbour sketchpads and store them inline.
import asyncio
import time

import functools

# Memo table shared by every @memoize-decorated coroutine. It is looked up by
# name at call time, so rebinding `cache = {}` at module level empties it for
# all wrappers at once.
cache = {}

def memoize(func):
    """Memoize an ``async`` function's awaited results in the global ``cache``.

    The key includes ``func`` itself, so two decorated coroutines called with
    the same arguments never return each other's results. All positional and
    keyword argument values must be hashable. ``f(x)``, ``f(x, 5)`` and
    ``f(x, k=5)`` produce distinct keys even if they compute the same thing.

    Args:
        func: the coroutine function to wrap.

    Returns:
        An ``async`` wrapper with the same name/docstring as ``func``.
    """
    @functools.wraps(func)
    async def memoized_async_func(*args, **kwargs):
        # frozenset makes the kwargs order-insensitive and hashable.
        key = (func, args, frozenset(kwargs.items()))
        if key in cache:
            return cache[key]
        result = await func(*args, **kwargs)
        cache[key] = result
        return result
    return memoized_async_func

@memoize
async def get_knn(q, k=5):
    """Return up to ``k`` nearest sketchpads to the text ``q``.

    Embeds ``q`` with the notebook-level ``model``, searches the FAISS
    ``index``, and then fetches the most recent sketchpad for ``username`` for
    each hit's reference short id.

    Args:
        q: query string to embed.
        k: number of neighbours to request from the index.

    Returns:
        A list of ``(distance, sketchpad)`` pairs in the order FAISS returns
        them. The list is shorter than ``k`` when the index has fewer than
        ``k`` vectors.

    Raises:
        KeyError: if a returned reference id has no sketchpad for ``username``.
    """
    # Reuse the model loaded once at start-up. Constructing a new
    # SentenceTransformer here would reload the weights on every query.
    query_vector = model.encode([q])
    distances, ids = index.search(query_vector, k)
    # FAISS pads missing results with id -1, so drop those. Also convert the
    # numpy int64 ids to plain ints before handing them to the database layer.
    hits = [
        (score, int(short_id))
        for score, short_id in zip(distances[0], ids[0])
        if short_id != -1
    ]
    short_ids = [short_id for _, short_id in hits]
    sketchpads_by_id = {
        i: x
        async for i, x in data.get_most_recent_sketchpads_by_reference_short_ids(
            database, short_ids, username
        )
    }
    return [(score, sketchpads_by_id[short_id]) for score, short_id in hits]

In [None]:
# Load the (left_string, right_string) groundtruth pairs for this run.
groundtruth = pd.read_parquet(path=groundtruth_path)

In [None]:
new_results = []
sketchpads = {}
st = time.time()
for i, row in groundtruth.iterrows():
    left = row['left_string']
    right = row['right_string']
    left_knn = await get_knn(left)
    right_knn = await get_knn(right)
    res = {}
    for j, (score, sketchpad) in enumerate(left_knn):
        res[f"left_knn_{j}_score"] = score
        res[f"left_knn_{j}_sketchpad"] = sketchpad.id
        sketchpads[sketchpad.id] = json.dumps(sketchpad.to_dict())
    for j, (score, sketchpad) in enumerate(right_knn):
        res[f"right_knn_{j}_score"] = score
        res[f"right_knn_{j}_sketchpad"] = sketchpad.id
        sketchpads[sketchpad.id] = json.dumps(sketchpad.to_dict())
    new_results.append(res)
    if i % 200 == 0:
        cache = {}
        d = time.time() - st
        print(i, d, d * len(groundtruth) / (i + 1))

In [None]:
# One row per groundtruth pair, holding its left/right k-NN scores and sketchpad ids.
output = pd.DataFrame(data=new_results)
output.to_parquet(knn_path)

In [None]:
# Two-column table (sketchpad_id, sketchpad JSON) of every sketchpad referenced above.
(
    pd.Series(sketchpads, name='sketchpad')
    .rename_axis('sketchpad_id')
    .reset_index()
    .to_parquet(sketchpad_path)
)