In [3]:
from DeezyMatch import inference as dm_inference
from DeezyMatch import combine_vecs
from DeezyMatch import candidate_ranker

from pathlib import Path
import pandas as pd
import time

In [4]:
def _vectorize_and_combine(scenario_dir, dataset, dm_model, inputfile, label):
    """Generate and combine DeezyMatch vectors for one dataset, skipping cached steps.

    Parameters
    ----------
    scenario_dir : str
        Top-level directory for the raw vectors ("candidates" or "queries").
    dataset : str
        Dataset name; the input is read from ``./gazetteers/<dataset>.txt``.
    dm_model : str
        Model name; the model and vocabulary are expected at
        ``./models/<dm_model>/<dm_model>.model`` and ``.vocab``.
    inputfile : str
        Name (without ``.yaml``) of the input file inside ``./models/<dm_model>/``.
    label : str
        Word used in the timing messages so that candidate and query steps
        can be told apart ("candidate" or "query").
    """
    scenario = scenario_dir + "/" + dataset + "_" + dm_model
    combined = "combined/" + dataset + "_" + dm_model
    model_prefix = "./models/" + dm_model + "/" + dm_model

    # generate vectors for the dataset (specified in dataset_path)
    # using a model stored at pretrained_model_path and pretrained_vocab_path;
    # skipped when the embeddings directory of the scenario already exists
    if not Path("./" + scenario + "/embeddings/").is_dir():
        start_time = time.time()
        dm_inference(input_file_path="./models/" + dm_model + "/" + inputfile + ".yaml",
                     dataset_path="./gazetteers/" + dataset + ".txt",
                     pretrained_model_path=model_prefix + ".model",
                     pretrained_vocab_path=model_prefix + ".vocab",
                     inference_mode="vect",
                     scenario=scenario)
        elapsed = time.time() - start_time
        print("Generate %s vectors: %s" % (label, elapsed))

    # combine vectors stored in the scenario and save them in combined/;
    # skipped when the combined directory already exists
    if not Path("./" + combined).is_dir():
        start_time = time.time()
        combine_vecs(rnn_passes=["fwd", "bwd"],
                     input_scenario=scenario,
                     output_scenario=combined,
                     print_every=100)
        elapsed = time.time() - start_time
        print("Combine %s vectors: %s" % (label, elapsed))


def findcandidates(candidates, queries, dm_model, inputfile):
    """Run the DeezyMatch candidate-ranking pipeline, reusing results found on disk.

    The pipeline has three stages, each guarded by a check on its output path
    so that a re-run only performs the stages whose output is missing:

    1. generate and combine the candidate vectors,
    2. generate and combine the query vectors,
    3. rank the candidates for every query.

    The elapsed wall-clock time of every stage that actually runs is printed.

    Parameters
    ----------
    candidates : str
        Name of the candidates dataset (``./gazetteers/<candidates>.txt``).
    queries : str
        Name of the queries dataset (``./gazetteers/<queries>.txt``).
    dm_model : str
        Name of the DeezyMatch model directory under ``./models/``.
    inputfile : str
        Name (without ``.yaml``) of the input file in the model directory.

    Returns
    -------
    None
        Nothing is returned; the ranker is given
        ``ranker_results/<queries>_<candidates>_<dm_model>`` as its output path.
        The cache check looks for that path with a ``.pkl`` suffix, so the
        ranker presumably writes a pickle there (confirm against DeezyMatch).
    """
    # --------------------------------------
    # GENERATE AND COMBINE CANDIDATE VECTORS
    _vectorize_and_combine("candidates", candidates, dm_model, inputfile, "candidate")

    # --------------------------------------
    # GENERATE AND COMBINE QUERY VECTORS
    _vectorize_and_combine("queries", queries, dm_model, inputfile, "query")

    # Select candidates based on L2-norm distance (aka faiss distance):
    # find candidates from candidate_scenario
    # for queries specified in query_scenario
    ranker_output = "ranker_results/" + queries + "_" + candidates + "_" + dm_model
    if not Path(ranker_output + ".pkl").is_file():
        start_time = time.time()
        candidate_ranker(query_scenario="./combined/" + queries + "_" + dm_model,
                         candidate_scenario="./combined/" + candidates + "_" + dm_model,
                         ranking_metric="faiss",
                         selection_threshold=100.,
                         num_candidates=20,
                         search_size=20,
                         output_path=ranker_output,
                         pretrained_model_path="./models/" + dm_model + "/" + dm_model + ".model",
                         pretrained_vocab_path="./models/" + dm_model + "/" + dm_model + ".vocab")
        elapsed = time.time() - start_time
        print("Rank candidates: %s" % elapsed)

In [5]:
# Dataset names: findcandidates reads each one from ./gazetteers/<name>.txt
candidates = "britwikidata_candidates"
queries = "bho_queries"

# Model directory under ./models/ and the input (.yaml) file inside it
dm_model = "wikigaz_en_001"
inputfile = "input_dfm_001"

findcandidates(
    candidates=candidates,
    queries=queries,
    dm_model=dm_model,
    inputfile=inputfile,
)

In [6]:
# Load the ranking results for this queries/candidates/model combination
df = pd.read_pickle(f"ranker_results/{queries}_{candidates}_{dm_model}.pkl")

In [15]:
# Inspect a positional slice (rows 350 to 399) of the ranking results
df.iloc[350:400]

Unnamed: 0_level_0,query,pred_score,faiss_distance,cosine_sim,candidate_original_ids,query_original_id,num_all_searches
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
350,Filleigh,"{'Filleigh': 0.979, 'Fifield House': 0.9531, '...","{'Filleigh': 0.0, 'Fifield House': 10.538, 'Fi...","{'Filleigh': 1.0, 'Fifield House': 0.823, 'Fif...","{'Filleigh': 90538, 'Fifield House': 152339, '...",350,20
351,Church Honeybourne,"{'Church Honeybourne': 0.9921, 'Church Honeybo...","{'Church Honeybourne': 0.0, 'Church Honeybourn...","{'Church Honeybourne': 1.0, 'Church Honeybourn...","{'Church Honeybourne': 162717, 'Church Honeybo...",351,20
352,West Meon,"{'West Meon': 0.9844, 'West FM': 0.9363, 'Lowe...","{'West Meon': 0.0, 'West FM': 10.1256, 'Lower ...","{'West Meon': 1.0, 'West FM': 0.7776, 'Lower M...","{'West Meon': 437943, 'West FM': 240061, 'Lowe...",352,20
353,Carlton-Colville,"{'Carlton Colville': 0.9772, 'Carlton Lawn': 0...","{'Carlton Colville': 2.8442, 'Carlton Lawn': 1...","{'Carlton Colville': 0.9647, 'Carlton Lawn': 0...","{'Carlton Colville': 368021, 'Carlton Lawn': 2...",353,20
354,Staunton-Upon-Wye,"{'Staunton on Wye': 0.9553, 'Staunton Way': 0....","{'Staunton on Wye': 2.9886, 'Staunton Way': 9....","{'Staunton on Wye': 0.9595, 'Staunton Way': 0....","{'Staunton on Wye': 537453, 'Staunton Way': 29...",354,20
355,Stafford,"{'Stafford': 0.9498, 'Sandford': 0.705, 'Staff...","{'Stafford': 0.0, 'Sandford': 8.0551, 'Staffor...","{'Stafford': 1.0, 'Sandford': 0.8705, 'Staffor...","{'Stafford': 82344, 'Sandford': 422623, 'Staff...",355,20
356,Trudox-Hill,"{'Trudoxhill': 0.8663, 'Trewhiddle': 0.6576, '...","{'Trudoxhill': 4.4563, 'Trewhiddle': 12.7921, ...","{'Trudoxhill': 0.9452, 'Trewhiddle': 0.8291, '...","{'Trudoxhill': 44213, 'Trewhiddle': 162481, 'T...",356,20
357,Preston-Capes,"{'Preston-Capes': 0.9754, 'Preston Capes': 0.9...","{'Preston-Capes': 0.0, 'Preston Capes': 1.5049...","{'Preston-Capes': 1.0, 'Preston Capes': 0.9808...","{'Preston-Capes': 88986, 'Preston Capes': 5585...",357,20
358,Burlescombe,"{'Burlescombe': 0.9853, 'Burcombe': 0.6445, 'B...","{'Burlescombe': 0.0, 'Burcombe': 12.7478, 'Bur...","{'Burlescombe': 1.0, 'Burcombe': 0.8204, 'Burv...","{'Burlescombe': 172886, 'Burcombe': 238840, 'B...",358,20
359,Beltingham,"{'Beltingham': 0.9766, 'Bellingham': 0.5878, '...","{'Beltingham': 0.0, 'Bellingham': 10.6686, 'Be...","{'Beltingham': 1.0, 'Bellingham': 0.8706, 'Bel...","{'Beltingham': 164126, 'Bellingham': 385247, '...",359,20
