In [65]:
import torch, os
from sentence_transformers import SentenceTransformer, util
from pandas import read_csv as rcsv, DataFrame as df
from time import time as t

In [55]:
def symmetric_semantic_search(
    query : str,
    corpus : list,
    corpus_emb : torch.Tensor,
    top_k : int = 5,
    kg_enhanced : bool = False,
    kgraph : list = None,
    kgraph_emb : torch.Tensor = None,
    *args, encoder = None, **kwargs
) -> dict:
    """Rank `corpus` entries by cosine similarity to `query`.

    With `kg_enhanced=True`, the `kgraph` entry closest to `query` is found first
    (via a recursive top-1 search). It is then encoded, and its embedding is added to
    the query embedding before the corpus is scored.

    Args:
        query: Text to search for.
        corpus: Candidate strings, in the same order as the rows of `corpus_emb`.
        corpus_emb: Embeddings of `corpus`.
        top_k: Maximum number of results; clamped to the corpus size.
        kg_enhanced: Whether to enhance the query with the closest `kgraph` entry.
        kgraph: Knowledge-graph strings; required when `kg_enhanced` is True.
        kgraph_emb: Embeddings of `kgraph`; required when `kg_enhanced` is True.
        encoder: Object with a SentenceTransformer-style `encode` method.
            Defaults to the module-level `embedder`.

    Returns:
        dict with keys
            "scores": {corpus string: cosine score} for the top results, best first
                (duplicate corpus strings collapse into a single key),
            "kg_entity": the `kgraph` string used for enhancement, or None,
            "query_embedding": the (possibly enhanced) query embedding.

    Raises:
        ValueError: if `kg_enhanced` is True without both `kgraph` and `kgraph_emb`,
            if `kgraph` is empty, or if `corpus` and `corpus_emb` differ in length.
    """
    # Validate arguments before paying for any encoding.
    if kg_enhanced:
        if kgraph is None or kgraph_emb is None:
            raise ValueError("If `kg_enhanced` is True, you must pass both `kgraph` and `kgraph_emb`.")
        if len(kgraph) == 0:
            raise ValueError("`kgraph` must contain at least one entry when `kg_enhanced` is True.")

    model = encoder if encoder is not None else embedder
    query_embedding = model.encode(query, convert_to_tensor=True)

    kg_entity = None
    if kg_enhanced:
        # Closest knowledge-graph statement to the raw query.
        kg_hit = symmetric_semantic_search(
            query=query, corpus=kgraph, corpus_emb=kgraph_emb, top_k=1, encoder=model
        )
        kg_entity = next(iter(kg_hit["scores"]))
        # Out-of-place add. Cosine similarity is scale-invariant, so the
        # un-normalised sum can be scored directly.
        query_embedding = query_embedding + model.encode(kg_entity, convert_to_tensor=True)

    cos_scores = util.cos_sim(query_embedding, corpus_emb)[0]
    if cos_scores.shape[0] != len(corpus):
        # e.g. a stale embedding cache for a corpus that has since changed.
        raise ValueError(
            "`corpus` has {} entries but `corpus_emb` has {} rows.".format(len(corpus), cos_scores.shape[0])
        )

    # torch.topk raises if k exceeds the number of candidates, so clamp it.
    top_results = torch.topk(cos_scores, k=min(top_k, cos_scores.shape[0]))

    return {
        "scores" : {corpus[int(idx)] : score.item() for score, idx in zip(top_results.values, top_results.indices)},
        "kg_entity" : kg_entity,
        "query_embedding" : query_embedding,
    }

In [3]:
def get_embeddings(corpus : list, filename : str, *args, cache_dir : str = "results", encoder = None, **kwargs) -> torch.Tensor:
    """Embed `corpus`, reusing a tensor cached on disk at `<cache_dir>/<filename>.pt`.

    The cache file is reused only when all of these hold:
      * that exact path exists,
      * it holds a tensor,
      * the tensor has one row per corpus entry.
    Otherwise the corpus is re-encoded and the file is (over)written.

    The row-count check cannot detect a corpus whose texts changed while its
    length stayed the same. Delete the cache file in that case.

    Args:
        corpus: Texts to embed.
        filename: Cache file stem, without the ``.pt`` suffix.
        cache_dir: Directory holding the cache files; created if missing.
        encoder: Object with a SentenceTransformer-style ``encode`` method.
            Defaults to the module-level ``embedder``, which is only looked up
            on a cache miss.

    Returns:
        The embedding tensor for `corpus`.
    """
    cache_path = os.path.join(cache_dir, "{}.pt".format(filename))

    # Check the exact path. A suffix match over the directory listing would also
    # match e.g. `old_<filename>.pt`.
    if os.path.isfile(cache_path):
        cached = torch.load(cache_path)
        if isinstance(cached, torch.Tensor) and cached.dim() > 0 and cached.shape[0] == len(corpus):
            return cached
        # Stale cache (different corpus size or unexpected content): rebuild it below.

    model = encoder if encoder is not None else embedder
    corpus_embeddings = model.encode(corpus, convert_to_tensor=True)

    parent = os.path.dirname(cache_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(corpus_embeddings, cache_path)

    return corpus_embeddings

In [82]:
if __name__ == "__main__":
    # Sentence-embedding model. The helper functions defined above read this global `embedder`.
    embedder = SentenceTransformer("all-MiniLM-L6-v2")

    # Knowledge graph -> one "source edge target" sentence per CSV row.
    # Zipping the three columns avoids DataFrame.iterrows, which builds a Series per row.
    kg = rcsv("results/knowledge_graph_wiki_v2.csv")
    kg_corpus = [
        source + " " + edge + " " + target
        for source, edge, target in zip(kg["source"], kg["edge"], kg["target"])
    ]

    # Cached questions that the user prompts below are matched against.
    q_corpus = [
        "Which movie was credited as a social drama feature film?",
        "What movie was released on April?",
        "What was the purpose of filming a movie scene on paper film?",
        "What was the movie that won the Film Institute Award?",
        "Which musical composition by Benjamin Wallfisch was released in 2004?",
        "Which music released by Aditya Music Company was the most popular?",
        "How did the music received positive reviews?",
        "What was the soundtrack released digitally June?",
        "What is the name of the soundtracks composed by Anirudh Ravichander?",
        "Which song was added to the Soundtrack of the United States National Recording Registry?"
    ]

    # Embeddings, cached on disk under results/ by get_embeddings.
    kg_emb = get_embeddings(kg_corpus, "embedding_kgraph")
    q_emb = get_embeddings(q_corpus, "embedding_query")

    # Alternative user prompts: paraphrases of the cached questions, in the same
    # order, followed by three robustness probes.
    alt_user_prompts = [
        "Tell me the social drama movie.",
        "Which movie was published on April?",
        "Is there any specific purpose why a movie scene shot with paper film?",
        "Give me a list of musics that won the Film Institute Award.",
        "Wallfisch released a music in 2004. What is it?",
        "List for me the most popular music by Aditya Music.",
        "Why some movies can received no negative reviews?",
        "Give me the music released in June.",
        "I want to know what are soundtracks created by A. Ravichander.",
        "What are songs registered in the US National Recording?",
        "What are movies registered in the US National Recording?", # test: replace "songs" with "movies"
        "What are sports registered in the US National Recording?", # test: replace "songs" with "sports" (wrong context)
        "songs US National registered Recording in?" # test: scramble words
    ]

In [83]:
# For every alternative prompt, compare a plain top-1 cache lookup against a
# knowledge-graph-enhanced top-1 lookup over the same cached questions.
# Note: the "precision" columns hold cosine-similarity scores of the top hit,
# not precision in the retrieval-metric sense.
cache_words, prec = [], []
enh_cache_words, enh_prec = [], []

for prompt in alt_user_prompts:
    plain = symmetric_semantic_search(query=prompt, corpus=q_corpus, corpus_emb=q_emb, top_k=1)
    best_question, best_score = list(plain["scores"].items())[0]
    cache_words.append(best_question)  # cached question that was hit
    prec.append(best_score)            # its similarity score

    enhanced = symmetric_semantic_search(
        query=prompt,
        corpus=q_corpus,
        corpus_emb=q_emb,
        kg_enhanced=True,
        kgraph=kg_corpus,
        kgraph_emb=kg_emb,
        top_k=1,
    )
    enh_cache_words.append(enhanced["kg_entity"])          # KG statement used for enhancement
    enh_prec.append(list(enhanced["scores"].values())[0])  # enhanced similarity score

df({
    "user_prompt": alt_user_prompts,
    "cache": cache_words,
    "precision": prec,
    "kg_cache": enh_cache_words,
    "precision_with_kg": enh_prec,
})

Unnamed: 0,user_prompt,cache,precision,kg_cache,precision_with_kg
0,Tell me the social drama movie.,Which movie was credited as a social drama fea...,0.788848,movie credited as social drama feature film,0.933651
1,Which movie was published on April?,What movie was released on April?,0.854984,film released on april,0.917247
2,Is there any specific purpose why a movie scen...,What was the purpose of filming a movie scene ...,0.882757,movie shot on paper film,0.882551
3,Give me a list of musics that won the Film Ins...,What was the movie that won the Film Institute...,0.751611,movie won film institute award,0.898739
4,Wallfisch released a music in 2004. What is it?,Which musical composition by Benjamin Wallfisc...,0.698875,music composed by benjamin wallfisch,0.88095
5,List for me the most popular music by Aditya M...,Which music released by Aditya Music Company w...,0.792708,music released by aditya music company,0.923691
6,Why some movies can received no negative reviews?,How did the music received positive reviews?,0.564766,film met with negative negative reviews,0.573312
7,Give me the music released in June.,What was the soundtrack released digitally June?,0.690917,soundtrack released digitally june,0.885071
8,I want to know what are soundtracks created by...,What is the name of the soundtracks composed b...,0.828747,soundtrack composed by anirudh ravichander,0.938921
9,What are songs registered in the US National R...,Which song was added to the Soundtrack of the ...,0.738616,soundtrack added to united states national rec...,0.888538


In [36]:
score_threshold = 0.85  # NOTE(review): presumably the minimum cosine similarity for a prompt to count as a cache hit — confirm

### Performance Metrics

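1. Precision (Cache Accuracy): of the prompts counted as cache hits, the fraction that matched the correct cached question.

2. Recall (Cache Hit Rate): the fraction of all prompts counted as cache hits, i.e. whose best similarity score reaches `score_threshold`.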

In [None]:
# from transformers import BertTokenizer, BertModel, GPT2TokenizerFast, GPT2Model

# def get_embedding(text : str, model, tokenizer, *args, **kwargs) -> Tensor:
#     encoded_inp = tokenizer(text, return_tensors="pt", padding="max_length", max_length=64, truncation=True)
#     output_emb = model(**encoded_inp)
#     return output_emb.last_hidden_state

# bert_tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
# bert_model = BertModel.from_pretrained("bert-base-cased")

# gpt2_tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
# gpt2_model = GPT2Model.from_pretrained("openai-community/gpt2")