#### Code to evaluate trained models (embedding extraction and candidate matching)

In [34]:
# Enable IPython's autoreload extension; mode 2 reloads imported modules
# before each cell runs, so edits to the project sources are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [35]:
# Switch the working directory to the sibling `src` folder (relies on IPython
# automagic) — presumably so the project-local imports below resolve; confirm.
cd ../src

/home/theo/kaggle/foursquare/src


## Imports

In [36]:
import os

# Restrict the process to physical GPU 1. The variable is set *before* torch
# is imported so the restriction is guaranteed to be in place whenever CUDA
# gets initialised, instead of relying on torch initialising CUDA lazily.
os.environ['CUDA_VISIBLE_DEVICES'] = "1"

import torch

# Sanity check: with one visible device, index 0 is the GPU selected above.
torch.cuda.get_device_name(0)

'NVIDIA GeForce RTX 2080 Ti'

In [37]:
import os
import ast
import glob
import json
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from collections import Counter
from tqdm.notebook import tqdm
from numerize.numerize import numerize

In [38]:
from params import *

from data.preparation import prepare_train_data, prepare_triplet_data
from data.dataset import SingleDataset
from data.tokenization import get_tokenizer

from model_zoo.models import SingleTransformer

from utils.logger import Config
from utils.torch import load_model_weights
from utils.metrics import *

from inference.predict import predict
from inference.knn import *

## Data

In [40]:
# Build the training dataframe with the project helper, reading from DATA_PATH
# (DATA_PATH is presumably provided by `from params import *` — confirm).
df = prepare_train_data(root=DATA_PATH)

In [41]:
# Attach the precomputed fold assignment to every row, then index rows by id.
# NOTE: this cell consumes the `id` column of `df`, so it expects the freshly
# prepared dataframe as input.
folds = pd.read_csv(DATA_PATH + "folds_2.csv")
folds = folds[['id', 'fold']]

df = pd.merge(df, folds, how="left", on="id")
df = df.set_index("id")

In [42]:
# Fold to evaluate: used below to filter `df`, to pick the checkpoint in
# `weights`, and to name the cached features file.
FOLD = 0

In [43]:
# Keep only the rows assigned to the fold under evaluation.
in_fold = df['fold'] == FOLD
df = df.loc[in_fold]

In [44]:
# Load the ground-truth matches; the cells below look them up by point id
# (`gt_matches[df.index[i]]`). A context manager is used so the file handle
# is closed deterministically instead of being left to the garbage collector.
with open(DATA_PATH + "gt.json", 'r') as gt_file:
    gt_matches = json.load(gt_file)

## Experiment

In [54]:
# Experiment folder to evaluate. The cells below read `config.json`, the
# `*.pt` checkpoints and the cached `fts_val_{FOLD}.npy` features from it.
# Previously inspected runs are kept commented out for quick switching.
# EXP_FOLDER = LOG_PATH + "2022-05-18/2/"
# EXP_FOLDER = LOG_PATH + "2022-05-18/3/"  # 10 ep
# EXP_FOLDER = LOG_PATH + "2022-05-19/0/"  # 2 ep, triplets_v2
EXP_FOLDER = LOG_PATH + "2022-05-19/1/"  # 1 ep, triplets_v2, d=64

In [55]:
# Load the experiment's saved configuration and wrap it in the project's
# Config helper. The file is opened in a context manager so the handle is
# closed as soon as the JSON has been parsed.
with open(EXP_FOLDER + "config.json", 'r') as config_file:
    config = Config(json.load(config_file))

In [56]:
# Tokenizer looked up from the same `config.name` that is passed to the model
# constructor below.
tokenizer = get_tokenizer(config.name)

In [57]:
# Wrap the fold's dataframe in the project dataset, using the tokenizer and
# the `max_len` stored in the experiment config.
dataset = SingleDataset(df, tokenizer, config.max_len)

In [58]:
# Collect every checkpoint in the experiment folder, in lexicographic order.
weights = glob.glob(EXP_FOLDER + "*.pt")
weights.sort()

In [59]:
# Rebuild the architecture from the hyper-parameters stored in the experiment
# config, then move it to the GPU.
model_kwargs = dict(
    nb_layers=config.nb_layers,
    no_dropout=config.no_dropout,
    embed_dim=config.embed_dim,
    nb_features=config.nb_features,
)
model = SingleTransformer(config.name, **model_kwargs)
model = model.cuda()
model.zero_grad()

# `weights` is sorted, so position FOLD is assumed to be this fold's
# checkpoint — TODO confirm there is exactly one checkpoint per fold.
model = load_model_weights(model, weights[FOLD])


 -> Loading weights from ../logs/2022-05-19/1/xlm-roberta-base_0.pt



In [60]:
# Features for this fold are cached next to the experiment: run inference only
# when the cache file is missing, otherwise reuse the saved array.
fts_path = EXP_FOLDER + f"fts_val_{FOLD}.npy"

if not os.path.exists(fts_path):
    preds = predict(model, dataset, config.data_config)
    np.save(fts_path, preds)
else:
    preds = np.load(fts_path)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=35588.0), HTML(value='')))




### Matches

In [None]:
# For several neighbourhood sizes, report how many ground-truth matches are
# recovered by (1) candidates derived from the model features, (2) precomputed
# candidates loaded from OUT_PATH, and (3) the union of both.
# After the loop, `missed_nn`, `missed_pos` and `missed` hold the results of
# the LAST n_neighbors value only.
for n_neighbors in [20, 30, 40, 50]:
    print(f'\n- -> n_neighbors={n_neighbors}\n')

    # Candidates built from the model features.
    preds_matches = find_matches(preds, df, n_neighbors)
    found_prop, missed_nn = compute_found_prop(preds_matches, gt_matches)
    # The `- 1` presumably discounts the query point itself — TODO confirm.
    n_matches = sum(len(preds_matches[k]) - 1 for k in preds_matches)
    print(f' - Roberta :\t Found {found_prop * 100 :.2f}% of matches with {numerize(n_matches)} candidates.')

    # Precomputed candidates. Other files tried previously:
    #   knn_preds_{n_neighbors}_0.json, dist_matches_{n_neighbors}_0.json
    # NOTE(review): the trailing `_0` looks like a fold index and presumably
    # should follow FOLD — confirm before evaluating another fold.
    # The context manager closes the file on every iteration.
    with open(OUT_PATH + f"dist-phone-url_matches_{n_neighbors}_0.json", 'r') as matches_file:
        position_matches = json.load(matches_file)

    found_prop, missed_pos = compute_found_prop(position_matches, gt_matches)
    n_matches = sum(len(position_matches[k]) - 1 for k in position_matches)
    print(f' - Position :\t Found {found_prop * 100 :.2f}% of matches with {numerize(n_matches)} candidates.')

    # Union of both candidate lists per query, de-duplicated through a set.
    merged_matches = {k : list(set(position_matches[k] + preds_matches[k])) for k in preds_matches}
    found_prop, missed = compute_found_prop(merged_matches, gt_matches)
    n_matches = sum(len(merged_matches[k]) - 1 for k in merged_matches)
    print(f' - Merged :\t found {found_prop * 100 :.2f}% of matches with {numerize(n_matches)} candidates.')


- -> n_neighbors=20

 - Roberta :	 Found 39.98% of matches with 10.82M candidates.
 - Position :	 Found 81.25% of matches with 11.07M candidates.
 - Merged :	 found 87.41% of matches with 21.62M candidates.

- -> n_neighbors=30

 - Roberta :	 Found 42.77% of matches with 16.51M candidates.
 - Position :	 Found 83.96% of matches with 16.51M candidates.
 - Merged :	 found 89.69% of matches with 32.69M candidates.

- -> n_neighbors=40



- -> n_neighbors=20

 - Roberta :	 Found 56.88% of matches with 10.82M candidates.
 - Position :	 Found 81.25% of matches with 11.07M candidates.
 - Merged :	 found 90.59% of matches with 21.52M candidates.

- -> n_neighbors=30

 - Roberta :	 Found 59.69% of matches with 16.51M candidates.
 - Position :	 Found 83.96% of matches with 16.51M candidates.
 - Merged :	 found 92.44% of matches with 32.57M candidates.

- -> n_neighbors=40

 - Roberta :	 Found 61.66% of matches with 22.21M candidates.
 - Position :	 Found 85.63% of matches with 21.87M candidates.
 - Merged :	 found 93.53% of matches with 43.56M candidates.

- -> n_neighbors=50

 - Roberta :	 Found 63.20% of matches with 27.9M candidates.
 - Position :	 Found 86.83% of matches with 27.13M candidates.
 - Merged :	 found 94.30% of matches with 54.47M candidates.

In [28]:
# Browse the first 100 rows and show those that have more than one
# ground-truth match and at least one missed match.
# `missed` is indexed by row position while `gt_matches` is keyed by id; it is
# whatever the most recent run of the matches loop left behind.
for i in range(100):
    query_id = df.index[i]
    targets = gt_matches[query_id]
    if len(targets) <= 1:
        continue

    missed_rows = df.loc[list(missed[i])]
    if len(missed_rows) == 0:
        continue

    print('Query')
    display(df.loc[[query_id]])

    print('Target')
    display(df.loc[targets])

    print('Missed')
    display(missed_rows)

    # Optional: also show the top model candidates for this query.
    # print('Preds')
    # display(df.loc[preds_matches[query_id]].head(5))

    print('-' * 50)

Query


Unnamed: 0_level_0,name,latitude,longitude,address,country,url,phone,categories,point_of_interest,clust,fold
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
E_0005d223c299ab,гб 26костюшко 2,59.848975,30.294264,"Ул. Костюшко 2, Санкт-Петербург, RU",RU,,,Hospitals,P_b9ff4f6a365ad0,205706,0


Target


Unnamed: 0_level_0,name,latitude,longitude,address,country,url,phone,categories,point_of_interest,clust,fold
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
E_13bbf0a27cfb9a,Хирургическое Отделение. Больница 26,59.851817,30.298295,RU,RU,,,Hospitals,P_b9ff4f6a365ad0,205706,0
E_e7f5381a2425e3,Больничка,59.848934,30.295753,RU,RU,,,Hospitals,P_b9ff4f6a365ad0,205706,0
E_9ab5a9ee6069ab,Палата 18,59.849037,30.297071,RU,RU,,,Hospitals,P_b9ff4f6a365ad0,205706,0
E_eb824f0b14ac67,Городская больница №26,59.849424,30.295438,"ул. Костюшко, 2, Санкт-Петербург, Санкт-Петерб...",RU,hospital26.ru,8124151888.0,Hospitals,P_b9ff4f6a365ad0,205706,0
E_811d75955cb441,Палата 1. 4 этаж,59.849085,30.294058,RU,RU,,,Hospitals,P_b9ff4f6a365ad0,205706,0
E_93d80b2c50d8fa,Городская больница №26,59.83811,30.237221,"ул. Костюшко, 2, St.-Petersburg, RU",RU,,78124151888.0,Hospitals,P_b9ff4f6a365ad0,205706,0
E_3db8de706c6caa,Палата Тусэ N 5,59.848853,30.294709,RU,RU,,,Hospitals,P_b9ff4f6a365ad0,205706,0
E_e9d3ce80ac24ec,Кабинет УЗИ,59.8493,30.294802,RU,RU,,,,P_b9ff4f6a365ad0,205706,0
E_1603a71e9bcdfc,Палата 9,59.848981,30.295432,RU,RU,,,Hospitals,P_b9ff4f6a365ad0,205706,0
E_ea2a92a9306798,Городская больница №26,59.904487,30.342419,"ул. Костюшко, 2, St.-Petersburg, RU",RU,,78124151888.0,Hospitals,P_b9ff4f6a365ad0,205706,0


Missed


Unnamed: 0_level_0,name,latitude,longitude,address,country,url,phone,categories,point_of_interest,clust,fold
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
E_ea2a92a9306798,Городская больница №26,59.904487,30.342419,"ул. Костюшко, 2, St.-Petersburg, RU",RU,,78124151888,Hospitals,P_b9ff4f6a365ad0,205706,0
E_93d80b2c50d8fa,Городская больница №26,59.83811,30.237221,"ул. Костюшко, 2, St.-Petersburg, RU",RU,,78124151888,Hospitals,P_b9ff4f6a365ad0,205706,0


--------------------------------------------------
Query


Unnamed: 0_level_0,name,latitude,longitude,address,country,url,phone,categories,point_of_interest,clust,fold
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
E_000ac5c1692f17,Costa Coffee,55.633228,37.633801,"Каширское ш., 26, Москва, RU",RU,,,Coffee Shops,P_cb130155215509,178484,0


Target


Unnamed: 0_level_0,name,latitude,longitude,address,country,url,phone,categories,point_of_interest,clust,fold
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
E_02123b385ee6c1,Costa coffee,55.653243,37.646748,"ТРК «Москворечье», Москва, Россия, RU",RU,costacoffee.ru,,Coffee Shops,P_cb130155215509,178484,0
E_000ac5c1692f17,Costa Coffee,55.633228,37.633801,"Каширское ш., 26, Москва, RU",RU,,,Coffee Shops,P_cb130155215509,178484,0


Missed


Unnamed: 0_level_0,name,latitude,longitude,address,country,url,phone,categories,point_of_interest,clust,fold
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
E_02123b385ee6c1,Costa coffee,55.653243,37.646748,"ТРК «Москворечье», Москва, Россия, RU",RU,costacoffee.ru,,Coffee Shops,P_cb130155215509,178484,0


--------------------------------------------------
Query


Unnamed: 0_level_0,name,latitude,longitude,address,country,url,phone,categories,point_of_interest,clust,fold
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
E_000b6fee493351,Atlantic Ocean,33.502066,-79.065192,US,US,,,Beaches,P_dc6fc88a91a408,116314,0


Target


Unnamed: 0_level_0,name,latitude,longitude,address,country,url,phone,categories,point_of_interest,clust,fold
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
E_6727ec3ecf88ae,The atlantic Ocean,47.543993,-52.575458,CA,CA,,,,P_dc6fc88a91a408,116314,0
E_000b6fee493351,Atlantic Ocean,33.502066,-79.065192,US,US,,,Beaches,P_dc6fc88a91a408,116314,0
E_fb179fb773551a,Atlantic Ocean,39.532596,-74.262788,"Beach Haven, NJ, 08008 US",US,,,Other Great Outdoors,P_dc6fc88a91a408,116314,0
E_f74c8a0744d3f9,Atlantic Ocean,39.201227,-14.895984,"Trazo, Galiçya, 15687 ES",ES,,,,P_dc6fc88a91a408,116314,0
E_76cbba4a082131,on a boat somewhere In atlantic ocean,33.778001,-78.536479,US,US,,,Other Great Outdoors,P_dc6fc88a91a408,116314,0
E_d6d2369199339e,где то в атлантическом океане,27.997034,-16.657917,ES,ES,,,,P_dc6fc88a91a408,116314,0
E_2f5181918a188c,Atlantic Ocean,28.066065,-16.729744,ES,ES,,,Other Great Outdoors,P_dc6fc88a91a408,116314,0
E_da4966c09d1280,Где то в океане,28.046823,-16.764508,ES,ES,,,Fishing Spots,P_dc6fc88a91a408,116314,0
E_afedc006ea1570,Somewhere over Atlantic Ocean,60.785767,-39.67418,GL,GL,,,,P_dc6fc88a91a408,116314,0
E_f716e71e181501,В Атлантике На Яхте,28.023055,-16.573253,ES,ES,,,,P_dc6fc88a91a408,116314,0


Missed


Unnamed: 0_level_0,name,latitude,longitude,address,country,url,phone,categories,point_of_interest,clust,fold
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
E_6727ec3ecf88ae,The atlantic Ocean,47.543993,-52.575458,CA,CA,,,,P_dc6fc88a91a408,116314,0
E_d6d2369199339e,где то в атлантическом океане,27.997034,-16.657917,ES,ES,,,,P_dc6fc88a91a408,116314,0
E_f716e71e181501,В Атлантике На Яхте,28.023055,-16.573253,ES,ES,,,,P_dc6fc88a91a408,116314,0
E_da4966c09d1280,Где то в океане,28.046823,-16.764508,ES,ES,,,Fishing Spots,P_dc6fc88a91a408,116314,0


--------------------------------------------------
Query


Unnamed: 0_level_0,name,latitude,longitude,address,country,url,phone,categories,point_of_interest,clust,fold
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
E_000c1e53567646,Esbaş,38.336512,27.133751,TR,TR,,,Coworking Spaces,P_9ef6227fc7603a,159240,0


Target


Unnamed: 0_level_0,name,latitude,longitude,address,country,url,phone,categories,point_of_interest,clust,fold
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
E_df02c68e9725ef,Gaziemir Serbest Bölge,38.335986,27.133831,"İzmir, TR",TR,,,,P_9ef6227fc7603a,159240,0
E_55709c6beefb17,Agean Free Zone,38.337053,27.133114,TR,TR,,,,P_9ef6227fc7603a,159240,0
E_000c1e53567646,Esbaş,38.336512,27.133751,TR,TR,,,Coworking Spaces,P_9ef6227fc7603a,159240,0
E_74713f5c0acbe5,Ege Serbest Bölgesi,38.336693,27.129885,"Akçay Cad. No:144/1, Gaziemir, İzmir, 35410 TR",TR,esbas.com.tr,2322513851.0,Industrial Estates,P_9ef6227fc7603a,159240,0
E_c84e5082be0ad6,Ege serbest bölge,38.335944,27.123388,TR,TR,,,Factories,P_9ef6227fc7603a,159240,0


Missed


Unnamed: 0_level_0,name,latitude,longitude,address,country,url,phone,categories,point_of_interest,clust,fold
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
E_c84e5082be0ad6,Ege serbest bölge,38.335944,27.123388,TR,TR,,,Factories,P_9ef6227fc7603a,159240,0


--------------------------------------------------
