#### Code to infer pairs from NN models
**TODO**:
- Intersection
- Filter using dist_nn < th_1 and dist_pos < th_2

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
cd ../src

## Imports

In [None]:
import os
import torch

# Restrict the CUDA devices visible to this process to the one with ID 1.
# NOTE(review): this only takes effect if CUDA has not been initialised yet in
# this kernel — confirm when re-running the cell.
os.environ['CUDA_VISIBLE_DEVICES'] = "1"
# Left as the last expression so the notebook displays the name of device 0
# (presumably the single device left visible by the line above).
torch.cuda.get_device_name(0)

In [None]:
import os
import ast
import glob
import json
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from collections import Counter
from tqdm.notebook import tqdm
from numerize.numerize import numerize

In [None]:
from params import *

from data.preparation import prepare_train_data, prepare_triplet_data
from data.dataset import SingleDataset
from data.tokenization import get_tokenizer

from model_zoo.models import SingleTransformer

from utils.logger import Config
from utils.torch import load_model_weights
from utils.metrics import *

from inference.predict import predict
from inference.knn import *

## Data

In [None]:
# Build the working dataframe with the project helper, reading from DATA_PATH.
df = prepare_train_data(root=DATA_PATH)

In [None]:
# Read the fold assignment file and keep only the columns needed for the join.
folds = pd.read_csv(DATA_PATH + "folds_2.csv")
folds = folds[['id', 'fold']]

# Left join so every row of df is kept, then index the result by "id".
df = df.merge(folds, on="id", how="left")
df = df.set_index("id")

In [None]:
# Fold handled by this notebook: it filters df on its 'fold' column, picks the
# checkpoint (weights[FOLD]) and names the cached feature file.
FOLD = 0

In [None]:
# Keep only the rows assigned to the selected fold.
df = df.loc[df['fold'].eq(FOLD)]

In [None]:
# Ground-truth matches loaded from gt.json; later cells index the result by the
# ids in df.index. The context manager closes the file as soon as it is parsed
# (the bare `open(...)` previously left the handle open).
with open(DATA_PATH + "gt.json", 'r') as f:
    gt_matches = json.load(f)

## Experiment

In [None]:
# Experiment folder used by the rest of the notebook: config.json, the *.pt
# checkpoints and the cached features are read from it, and the pair CSVs are
# written to it. Uncomment exactly one line to switch experiments; the trailing
# notes are the author's shorthand for each run.
# EXP_FOLDER = LOG_PATH + "2022-05-18/2/"
# EXP_FOLDER = LOG_PATH + "2022-05-18/3/"  # 10 ep
# EXP_FOLDER = LOG_PATH + "2022-05-19/0/"  # 2 ep, triplets_v2
# EXP_FOLDER = LOG_PATH + "2022-05-19/1/"  # 1 ep, triplets_v2, d=64
# EXP_FOLDER = LOG_PATH + "2022-05-19/2/"  # 1 ep, d=256
EXP_FOLDER = LOG_PATH + "2022-05-19/4/"  # 1 ep, d=256, large
# EXP_FOLDER = LOG_PATH + "2022-05-20/0/"  # 1 ep, d=256, triplets_v2
# EXP_FOLDER = LOG_PATH + "2022-05-20/1/"  # robert-large
# EXP_FOLDER = LOG_PATH + "2022-05-20/2/"  # base + url
# EXP_FOLDER = LOG_PATH + "2022-05-20/3/"  # large + no address
# EXP_FOLDER = LOG_PATH + "2022-05-22/0/"  # 1 ep, d=256, large lower
# EXP_FOLDER = LOG_PATH + "2022-05-23/0/"  # 1 ep, d=384, large

In [None]:
# Load the experiment's saved configuration and wrap it in the project Config
# object. The context manager closes config.json right after parsing (the bare
# `open(...)` previously left the handle open).
with open(EXP_FOLDER + "config.json", 'r') as f:
    config = Config(json.load(f))

In [None]:
# Tokenizer looked up from config.name (the same value passed to SingleTransformer).
tokenizer = get_tokenizer(config.name)

In [None]:
# Wrap the selected fold's rows for inference.
# NOTE(review): config.max_len is presumably the maximum token length — confirm in SingleDataset.
dataset = SingleDataset(df, tokenizer, config.max_len)

In [None]:
# All *.pt checkpoints in the experiment folder, sorted by path.
# NOTE(review): weights[FOLD] is used later, which assumes the sorted order
# matches the fold numbering and that at least FOLD + 1 files exist — confirm.
weights = sorted(glob.glob(EXP_FOLDER + "*.pt"))

In [None]:
# Architecture settings are taken from the experiment config.
model_kwargs = dict(
    nb_layers=config.nb_layers,
    no_dropout=config.no_dropout,
    embed_dim=config.embed_dim,
    nb_features=config.nb_features,
)

# Instantiate the network, move it to the GPU via .cuda() and clear gradients.
model = SingleTransformer(config.name, **model_kwargs)
model = model.cuda()
model.zero_grad()

# Load the checkpoint selected by FOLD through the project helper.
model = load_model_weights(model, weights[FOLD])

In [None]:
# Features are cached in the experiment folder; predict only when the cache
# file for this fold does not exist yet, then store the result for next time.
val_fts_path = EXP_FOLDER + f"fts_val_{FOLD}.npy"
if not os.path.exists(val_fts_path):
    preds = predict(model, dataset, config.data_config)
    np.save(val_fts_path, preds)
else:
    preds = np.load(val_fts_path)

### Matches

In [None]:
# When True, the evaluation loop writes df_pairs_{n_neighbors}.csv into
# EXP_FOLDER for every n_neighbors value it evaluates.
SAVE = True

In [None]:
# For each neighbourhood size, report how many ground-truth matches are
# recovered by (1) the candidates computed from the model features, (2) the
# precomputed candidates loaded from OUT_PATH, and (3) the union of both.
# Variables assigned here (notably `missed` and `preds_matches`) keep the
# values of the last iteration and are reused by later cells.
for n_neighbors in [20, 30, 40, 50]:
    print(f'\n- -> n_neighbors={n_neighbors}\n')

    # (1) Candidates derived from the model features.
    preds_matches = find_matches(preds, df, n_neighbors)
    found_prop, missed_nn = compute_found_prop(preds_matches, gt_matches)
    n_matches = sum(len(preds_matches[k]) for k in preds_matches)
    print(f' - Roberta :\t Found {found_prop * 100 :.2f}% of matches with {numerize(n_matches)} candidates.')

    # (2) Precomputed candidates. Alternative files used previously:
    #     OUT_PATH + f"knn_preds_{n_neighbors}_0.json"
    #     OUT_PATH + f"dist_matches_{n_neighbors}_0.json"
    # The context manager closes the file right after parsing instead of
    # leaving one open handle per iteration.
    with open(OUT_PATH + f"dist-phone-url_matches_{n_neighbors}_0.json", 'r') as f:
        position_matches = json.load(f)

    found_prop, missed_pos = compute_found_prop(position_matches, gt_matches)
    n_matches = sum(len(position_matches[k]) for k in position_matches)
    print(f' - Position :\t Found {found_prop * 100 :.2f}% of matches with {numerize(n_matches)} candidates.')

    # (3) De-duplicated union of both candidate lists. Every key of
    # preds_matches must also exist in position_matches (KeyError otherwise).
    merged_matches = {
        k: list(set(position_matches[k] + preds_matches[k]))
        for k in preds_matches
    }
    found_prop, missed = compute_found_prop(merged_matches, gt_matches)
    n_matches = sum(len(merged_matches[k]) for k in merged_matches)
    print(f' - Merged :\t found {found_prop * 100 :.2f}% of matches with {numerize(n_matches)} candidates.')

    if SAVE:
        # Persist the merged candidate pairs and report the share (in %) of
        # rows whose 'match' column is set.
        df_pairs = create_pairs(merged_matches, gt_matches=gt_matches)
        prop = df_pairs['match'].sum() / len(df_pairs) * 100
        save_path = EXP_FOLDER + f'df_pairs_{n_neighbors}.csv'
        df_pairs.to_csv(save_path, index=False)
        print(f'-> Saved pairs to {save_path} - Positive proportion  {prop:.2f}%')

    # break  # uncomment to stop after the first n_neighbors value

Base:
- -> n_neighbors=20

 - Roberta :	 Found 61.96% of matches with 10.82M candidates.
 - Position :	 Found 81.25% of matches with 11.07M candidates.
 - Merged :	 found 91.57% of matches with 21.48M candidates.

- -> n_neighbors=30

 - Roberta :	 Found 64.76% of matches with 16.51M candidates.
 - Position :	 Found 83.96% of matches with 16.51M candidates.
 - Merged :	 found 93.31% of matches with 32.53M candidates.

- -> n_neighbors=40

 - Roberta :	 Found 66.67% of matches with 22.21M candidates.
 - Position :	 Found 85.63% of matches with 21.87M candidates.
 - Merged :	 found 94.32% of matches with 43.51M candidates.

- -> n_neighbors=50

 - Roberta :	 Found 68.15% of matches with 27.9M candidates.
 - Position :	 Found 86.83% of matches with 27.13M candidates.
 - Merged :	 found 95.00% of matches with 54.41M candidates.
 
Large:
- -> n_neighbors=20

 - Roberta :	 Found 70.42% of matches with 10.82M candidates.
 - Position :	 Found 81.25% of matches with 11.07M candidates.
 - Merged :	 found 92.88% of matches with 21.46M candidates.

- -> n_neighbors=30

 - Roberta :	 Found 72.90% of matches with 16.51M candidates.
 - Position :	 Found 83.96% of matches with 16.51M candidates.
 - Merged :	 found 94.42% of matches with 32.52M candidates.

- -> n_neighbors=40

 - Roberta :	 Found 74.59% of matches with 22.21M candidates.
 - Position :	 Found 85.63% of matches with 21.87M candidates.
 - Merged :	 found 95.32% of matches with 43.5M candidates.

- -> n_neighbors=50

 - Roberta :	 Found 75.87% of matches with 27.9M candidates.
 - Position :	 Found 86.83% of matches with 27.13M candidates.
 - Merged :	 found 95.92% of matches with 54.41M candidates.

In [None]:
# Visual inspection of failures: for up to the first 100 rows of df, display
# the query row, its ground-truth matches and the rows listed in `missed` for
# that position. When the cells are run in order, `missed` is the value
# returned by the last compute_found_prop call on the merged candidates.
# The range is bounded by len(df) so a small dataframe cannot raise IndexError.
for i in range(min(100, len(df))):
    query_id = df.index[i]
    targets = gt_matches[query_id]

    # Skip queries with at most one ground-truth entry.
    if len(targets) <= 1:
        continue

    # Look the missed rows up once and reuse them for the check and the display.
    missed_rows = df.loc[list(missed[i])]
    if not len(missed_rows):
        continue

    print('Query')
    display(df.loc[[query_id]])

    print('Target')
    display(df.loc[targets])

    print('Missed')
    display(missed_rows)

    # print('Preds')
    # display(df.loc[preds_matches[df.index[i]]].head(5))

    # break
    print('-' * 50)