#### Code to infer pairs from NN models
**TODO**:
- Intersection of the NN and naive candidate sets
- Filter using dist_nn < th_1 and dist_pos < th_2
- https://github.com/facebookresearch/faiss
- https://www.kaggle.com/c/shopee-product-matching/discussion/238022
- https://www.kaggle.com/c/shopee-product-matching/discussion/238515
- https://www.kaggle.com/c/shopee-product-matching/discussion/238136

In [None]:
# Reload edited project modules automatically before each cell is executed.
%load_ext autoreload
%autoreload 2

In [None]:
cd ../src

## Imports

In [None]:
import os
# Must be set before CUDA is initialised: restricts this process to GPU 1.
os.environ['CUDA_VISIBLE_DEVICES'] = "1"

import torch
# Sanity check: name of the first visible device (index 0 once the mask above applies).
torch.cuda.get_device_name(0)

In [None]:
import os
import ast
import glob
import json
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from collections import Counter
from tqdm.notebook import tqdm
from numerize.numerize import numerize

In [None]:
from params import *

from data.preparation import prepare_train_data, prepare_triplet_data
from data.dataset import SingleDataset
from data.tokenization import get_tokenizer

from model_zoo.models import SingleTransformer

from utils.logger import Config
from utils.torch import load_model_weights
from utils.metrics import *

from inference.predict import predict
from inference.knn import *

## Data

In [None]:
# Load the training table from DATA_PATH.
df = prepare_train_data(root=DATA_PATH)
# Disabled helper call — presumably regenerates the gt.json file read further down; TODO confirm.
# build_gt(df.reset_index(), save=True)

In [None]:
# Peek at a single row to check the available columns.
df.head(1)

In [None]:
# Read the fold assignment, keep only the join key and the fold number.
folds = pd.read_csv(DATA_PATH + "folds_2.csv")
folds = folds[['id', 'fold']]

# Attach the fold to every row (rows without an assignment get NaN), then index by id.
df = df.merge(folds, on="id", how="left")
df = df.set_index("id")

In [None]:
# Fold evaluated in this notebook: selects the rows kept in df, the checkpoint
# picked from `weights` and the name of the cached features file.
FOLD = 0

In [None]:
# Restrict the table to the rows assigned to the selected fold.
df = df.loc[df['fold'] == FOLD]

In [None]:
# Ground-truth matches, indexed by row id in the cells below.
# A context manager is used so the file handle is closed deterministically.
with open(DATA_PATH + "gt.json", 'r') as f:
    gt_matches = json.load(f)

## Exp

In [None]:
# Experiment folder to evaluate. config.json and the *.pt checkpoints are read from it,
# and the cached validation features are stored in it. Keep exactly one line uncommented.
# EXP_FOLDER = LOG_PATH + "2022-05-18/2/"
# EXP_FOLDER = LOG_PATH + "2022-05-18/3/"  # 10 ep
# EXP_FOLDER = LOG_PATH + "2022-05-19/0/"  # 2 ep, triplets_v2
# EXP_FOLDER = LOG_PATH + "2022-05-19/1/"  # 1 ep, triplets_v2, d=64

EXP_FOLDER = LOG_PATH + "2022-05-19/2/"  # 1 ep, d=256
# EXP_FOLDER = LOG_PATH + "2022-05-19/4/"  # 1 ep, d=256, large

# EXP_FOLDER = LOG_PATH + "2022-05-20/0/"  # 1 ep, d=256, triplets_v2
# EXP_FOLDER = LOG_PATH + "2022-05-20/1/"  # roberta-large
# EXP_FOLDER = LOG_PATH + "2022-05-20/2/"  # base + url
# EXP_FOLDER = LOG_PATH + "2022-05-20/3/"  # large + no address
# EXP_FOLDER = LOG_PATH + "2022-05-22/0/"  # 1 ep, d=256, large lower
# EXP_FOLDER = LOG_PATH + "2022-05-23/0/"  # 1 ep, d=384, large

In [None]:
# Load the configuration saved with the experiment.
# A context manager is used so the file handle is closed deterministically.
with open(EXP_FOLDER + "config.json", 'r') as f:
    config = Config(json.load(f))

In [None]:
# Tokenizer looked up from config.name, the same name that is passed to SingleTransformer below.
tokenizer = get_tokenizer(config.name)

In [None]:
# Inference dataset built from the rows of the selected fold.
dataset = SingleDataset(
    df,
    tokenizer,
    config.max_len,
    use_url=True  # NOTE(review): hard-coded — presumably has to match how the checkpoint in EXP_FOLDER was trained (only one listed run is tagged "base + url"); confirm
)

In [None]:
# Checkpoints of the experiment in sorted filename order; the model cell picks weights[FOLD].
# NOTE(review): this assumes the sorted position equals the fold number (one checkpoint per fold, none missing) — confirm.
weights = sorted(glob.glob(EXP_FOLDER + "*.pt"))

In [None]:
# Rebuild the architecture described by the experiment config.
model = SingleTransformer(
    config.name,
    embed_dim=config.embed_dim,
    nb_features=config.nb_features,
    nb_layers=config.nb_layers,
    no_dropout=config.no_dropout,
)
model = model.cuda()
model.zero_grad()

# Restore the checkpoint stored at position FOLD of the sorted checkpoint list.
model = load_model_weights(model, weights[FOLD])

In [None]:
# Features for the selected fold are cached next to the experiment files.
# The cache is keyed by the fold number only: delete the file to force a recomputation.
fts_file = EXP_FOLDER + f"fts_val_{FOLD}.npy"

if not os.path.exists(fts_file):
    # No cache yet: run the model once and store the result.
    preds = predict(model, dataset, config.data_config)
    np.save(fts_file, preds)
else:
    preds = np.load(fts_file)

### Matches

In [None]:
# Set to True to write the candidate-pair CSVs produced by the next cell into EXP_FOLDER.
SAVE = False

In [None]:
def _report_matches(title, matches, gt_matches, blank_after=False):
    """
    Prints the coverage statistics of one candidate set.

    Args:
        title (str): Header line printed before the statistics.
        matches (dict): Candidates per query; every value must support len().
        gt_matches (dict): Ground truth, forwarded to the metric helpers.
        blank_after (bool, optional): Whether to leave an empty line after the report. Defaults to False.

    Returns:
        tuple: The two values returned by compute_found_prop (found proportion, missed matches).
    """
    found_prop, missed = compute_found_prop(matches, gt_matches)
    n_matches = sum(len(matches[k]) for k in matches)

    print(title)
    print(f'Found {found_prop * 100 :.2f}% of matches with {numerize(n_matches)} candidates.')
    print(
        f'Best reachable IoU : {compute_best_iou(matches, gt_matches) :.3f}',
        end='\n\n' if blank_after else '\n',
    )
    return found_prop, missed


for n_neighbors in [200]:
    print(f'\n- -> n_neighbors={n_neighbors}\n')

    # Candidates retrieved from the model features.
    nn_matches = find_matches(preds, df, n_neighbors)
    found_prop, missed_nn = _report_matches('NN matches :', nn_matches, gt_matches, blank_after=True)

    # Precomputed "naive" candidates. Other files that were evaluated:
    #   OUT_PATH + f"knn_preds_{n_neighbors}_0.json"
    #   OUT_PATH + f"dist-phone_matches_{n_neighbors}_0.json"
    #   OUT_PATH + f"dist-phone-url_matches_{n_neighbors}_0.json"
    # NOTE(review): the trailing "_0" looks like a fold number — confirm it should follow FOLD.
    with open(OUT_PATH + f"dist_matches_{n_neighbors}_0.json", 'r') as f:
        naive_matches = json.load(f)
    found_prop, missed_pos = _report_matches('Naive matches :', naive_matches, gt_matches, blank_after=True)

    # UNION
    merged_matches = {k : list(set(naive_matches[k] + nn_matches[k])) for k in nn_matches}
    found_prop, missed = _report_matches('Merged matches - Union :', merged_matches, gt_matches)

    df_pairs = create_pairs(nn_matches, naive_matches, n_neighbors, gt_matches)
    prop = df_pairs['match'].sum() / len(df_pairs) * 100
    save_path = EXP_FOLDER + f'df_pairs_{n_neighbors}.csv'

    if SAVE:
        df_pairs.to_csv(save_path, index=False)
        print(f'-> Saved pairs to {save_path} - Positive proportion  {prop:.2f}%\n')
    else:
        print(f'Positive proportion  {prop:.2f}%\n')

    # INTERSECTION
    merged_matches = {k : list(set(naive_matches[k]).intersection(nn_matches[k])) for k in nn_matches}
    found_prop, missed = _report_matches('Merged matches - Intersection :', merged_matches, gt_matches)

    # Keep the pairs with a non-negative rank in both sources
    # (presumably a negative rank marks a pair absent from that source — confirm in create_pairs).
    in_both = (df_pairs['rank'] >= -0.5) & (df_pairs['rank_nn'] >= -0.5)
    df_pairs_i = df_pairs[in_both].reset_index(drop=True)
    prop = df_pairs_i['match'].sum() / len(df_pairs_i) * 100

    # Written to its own file: reusing save_path would overwrite the union pairs saved above.
    save_path_i = EXP_FOLDER + f'df_pairs_{n_neighbors}_inter.csv'

    if SAVE:
        df_pairs_i.to_csv(save_path_i, index=False)
        print(f'-> Saved pairs to {save_path_i} - Positive proportion  {prop:.2f}%\n')
    else:
        print(f'Positive proportion  {prop:.2f}%\n')


- -> n_neighbors=10

NN matches :
Found 65.78% of matches with 5.12M candidates.
Best reachable IoU : 0.870

Naive matches :
Found 75.53% of matches with 5.45M candidates.
Best reachable IoU : 0.902

Merged matches :
Found 89.25% of matches with 10.26M candidates.
Best reachable IoU : 0.953

-> Saved pairs to ../logs/2022-05-19/4/df_pairs_10.csv - Positive proportion  5.01%

- -> n_neighbors=20

NN matches :
Found 70.42% of matches with 10.82M candidates.
Best reachable IoU : 0.888

Naive matches :
Found 81.19% of matches with 10.98M candidates.
Best reachable IoU : 0.923

Merged matches :
Found 92.87% of matches with 21.39M candidates.
Best reachable IoU : 0.968

-> Saved pairs to ../logs/2022-05-19/4/df_pairs_20.csv - Positive proportion  2.81%

- -> n_neighbors=30

NN matches :
Found 72.90% of matches with 16.51M candidates.
Best reachable IoU : 0.897

Naive matches :
Found 83.91% of matches with 16.42M candidates.
Best reachable IoU : 0.933

Merged matches :
Found 94.42% of matches with 32.45M candidates.
Best reachable IoU : 0.974

-> Saved pairs to ../logs/2022-05-19/4/df_pairs_30.csv - Positive proportion  2.01%

- -> n_neighbors=40

NN matches :
Found 74.59% of matches with 22.21M candidates.
Best reachable IoU : 0.903

Naive matches :
Found 85.59% of matches with 21.77M candidates.
Best reachable IoU : 0.940

Merged matches :
Found 95.32% of matches with 43.44M candidates.
Best reachable IoU : 0.978

-> Saved pairs to ../logs/2022-05-19/4/df_pairs_40.csv - Positive proportion  1.58%

In [None]:
# Number of failure cases to display before stopping.
MAX_EXAMPLES = 1

n_shown = 0

# Scan rows 1..99 of the fold (row 0 is skipped, as before). The upper bound is
# clamped so that a fold with fewer than 100 rows does not raise an IndexError.
for i in range(1, min(100, len(df))):
    idx = df.index[i]
    gt = gt_matches[idx]

    # Only inspect queries with 2 to 10 ground-truth entries.
    if len(gt) <= 1 or len(gt) > 10:
        continue

    # Skip queries whose ground truth is fully covered by the merged candidates
    # (merged_matches is whatever the matching cell assigned last).
    found = [idx] + list(set(merged_matches[idx]).intersection(gt))
    if sorted(found) == sorted(gt):
        continue

    # Skip queries that the naive candidates alone fully cover.
    # The original re-tested `all_found` here, which could never be true at this point.
    found_naive = [idx] + list(set(naive_matches[idx]).intersection(gt))
    if sorted(found_naive) == sorted(gt):
        continue

    print('Query')
    display(df.loc[[idx]])

    print('Target')
    display(df.loc[gt])

    print('Found naive')
    display(df.loc[found_naive])

    print('Found NN')
    found_nn = [idx] + list(set(nn_matches[idx]).intersection(gt))
    display(df.loc[found_nn])

    # All NN candidates of the query, for context.
    display(df.loc[nn_matches[idx]])

    n_shown += 1
    if n_shown >= MAX_EXAMPLES:
        break

    # Separator between consecutive examples.
    print('-' * 50)