In [112]:
import json
import numpy as np

def cosine_similarity(query_embedding: np.ndarray, search_embedding: np.ndarray):
    """
    Compute the cosine similarity between a query embedding and every row of `search_embedding`.

    * `query_embedding`: a single embedding, either 1-D `(d,)` or a one-row 2-D array `(1, d)`;
      if several rows are given, each row is normalized on its own and only the first row's
      scores are returned.
    * `search_embedding`: 2-D array `(n, d)`, one embedding per row.

    Returns a 1-D float array of length `n`. The inputs are not modified. Vectors with zero
    length get a similarity of 0 instead of NaN (NaN scores would make the ranking meaningless).
    """
    # work on float copies: an in-place `/=` on the argument would rescale the caller's array
    # (and raises outright for integer arrays)
    queries = np.atleast_2d(np.asarray(query_embedding, dtype=float))
    corpus = np.asarray(search_embedding, dtype=float)

    def _unit_rows(matrix):
        # normalize each row to unit length; rows whose norm is zero are left as all-zero
        norms = np.linalg.norm(matrix, axis=1, keepdims=True)
        return np.divide(matrix, norms, out=np.zeros_like(matrix), where=norms != 0)

    # normalized dot product has shape (n_queries, n); keep the first query's scores
    return (_unit_rows(queries) @ _unit_rows(corpus).T)[0]

def check_celex_status(query_celex: str, retrieved_case_celex: str):
    """
    Decide whether a retrieved case passes the CELEX filter for the query case.

    ---
    If the query case has no CELEX information (`None`, or an empty string/list/tuple),
    return True, i.e. no filtering is applied.
    Otherwise return True only if the query CELEX value equals `retrieved_case_celex`.

    NOTE(review): in this notebook the `euProvisions` field is passed here; whether it is a single
    CELEX string or a list of IDs is not visible — confirm. Matching is plain equality.
    """
    # identity check per PEP 8; `== None` can misbehave on objects that overload `==`
    if query_celex is None:
        return True
    # an empty value carries no CELEX information either, so it must not filter anything out
    if isinstance(query_celex, (str, list, tuple)) and len(query_celex) == 0:
        return True
    return query_celex == retrieved_case_celex

def cosine_similarity_search(query_embedding: np.ndarray, search_space_embedding: np.ndarray, query_celex: str, search_space: list, top_k: int = 5) -> list:
    """
    Given the embedded representation of a query case and the embeddings of the cases through which to search,
    return (at most) the `top_k` most similar cases based on cosine similarity, from most similar to least.

    * `query_embedding`: embedding of the query case;
    * `search_space_embedding`: one embedding per case, in the same order as `search_space`;
    * `query_celex`: the query's CELEX value; when it is None, no CELEX filtering is applied;
    * `search_space`: the case records; each record's `euProvisions` value is compared with `query_celex`;
    * `top_k`: maximum number of cases to return (values <= 0 yield an empty list).

    Raises ValueError if `search_space` and the computed scores differ in length, since cases and
    embeddings are paired purely by position.
    """
    cosine_scores: np.ndarray = cosine_similarity(query_embedding, search_space_embedding)

    # zip() would silently truncate a mismatch and attach scores to the wrong cases
    if len(search_space) != len(cosine_scores):
        raise ValueError(
            f"search_space has {len(search_space)} cases but {len(cosine_scores)} embeddings were scored"
        )

    # apply the CELEX filter BEFORE truncating to `top_k`: filtering afterwards would return fewer
    # than `top_k` cases even when more matching cases exist further down the ranking.
    # `.get` avoids a KeyError for records that lack an `euProvisions` field.
    candidates = [
        (entry, score)
        for entry, score in zip(search_space, cosine_scores)
        if check_celex_status(query_celex, entry.get("euProvisions"))
    ]

    # most similar first; Python's sort is stable, so ties keep corpus order
    ranked = sorted(candidates, key=lambda pair: pair[1], reverse=True)

    # clamp so a negative `top_k` cannot turn the slice into "all but the last |top_k|"
    return [entry for entry, _ in ranked[:max(top_k, 0)]]

In [134]:
def read_json_data(json_case_query: str, json_search_embedding: str, json_search_text: str):
    """
    Load the query case and the search corpus from disk.

    * `json_case_query`: path to the json of the query case (an object with a `uniqueId` key);
    * `json_search_embedding`: path to the json containing the embeddings of the search corpus;
    * `json_search_text`: path to the json containing the entire search corpus of cases.

    Returns `(query_data, search_embeddings_data, search_text_data)`. The query case is removed
    from `search_text_data` (matched on `uniqueId`); `search_embeddings_data` is returned unfiltered.
    """
    def _load(path):
        # JSON is UTF-8 by specification, so don't rely on the platform default encoding;
        # the `with` block closes the handle deterministically instead of leaving it to the GC
        with open(path, "r", encoding="utf-8") as handle:
            return json.load(handle)

    query_data: dict = _load(json_case_query)
    search_embeddings_data: list = _load(json_search_embedding)
    search_text_data: list = _load(json_search_text)

    # remove query case from the search corpus
    query_uid = query_data["uniqueId"]
    search_text_data = [item for item in search_text_data if item["uniqueId"] != query_uid]

    return query_data, search_embeddings_data, search_text_data

def get_embeddings_from_json(search_embeddings_data, query_data_uid):
    """
    Split the embedding records into the query embedding(s) and the search-corpus embeddings.

    * `search_embeddings_data`: list of records, each with `uniqueId` and `embedding` keys;
    * `query_data_uid`: `uniqueId` of the query case.

    Returns `(query_embedding, search_embedding)`:
    * `query_embedding`: 2-D array with one row per record whose `uniqueId` equals the query's;
    * `search_embedding`: 2-D array of all other embeddings, in their original order
      (shape `(0, d)` when there are none).

    Raises ValueError if no record matches `query_data_uid`, instead of returning an empty
    array that would only fail later with an obscure shape error.
    """
    query_vectors = []
    corpus_vectors = []
    # single pass: each record goes to exactly one of the two lists
    for record in search_embeddings_data:
        target = query_vectors if record["uniqueId"] == query_data_uid else corpus_vectors
        target.append(record["embedding"])

    if not query_vectors:
        raise ValueError(f"no embedding found for query uniqueId {query_data_uid!r}")

    query_embedding: np.ndarray = np.asarray(query_vectors)
    if corpus_vectors:
        search_embedding: np.ndarray = np.asarray(corpus_vectors)
    else:
        # keep the result 2-D so row-wise operations downstream still work on an empty corpus
        search_embedding = np.empty((0, query_embedding.shape[1]))

    return query_embedding, search_embedding

if __name__ == "__main__":
    query_data, search_embeddings_data, search_text_data = read_json_data(
        json_case_query="input_query.json",
        json_search_embedding="corpus_embedded.json",
        json_search_text="corpus.json",
    )

    query_embedding, search_embedding = get_embeddings_from_json(search_embeddings_data, query_data["uniqueId"])

    # a query without an `euProvisions` field is passed as None, which check_celex_status
    # treats as "no CELEX available" (no filtering) rather than raising a KeyError
    similar_cases: list = cosine_similarity_search(query_embedding, search_embedding, query_data.get("euProvisions"), search_text_data)

    # context manager guarantees the output is flushed and the file closed
    with open("example_output.json", "w", encoding="utf-8") as output_file:
        json.dump(similar_cases, output_file, indent = 2)

85.2 ms ± 948 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
