#### Code to train classification models
**TODO**:
- Verify the distance features
- GroupKFold?
- Inspect the categorical features
- CatBoost + XGBoost + LightGBM
- NaNs?

In [1]:
# Re-import edited project modules automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
cd ../src

/home/theo/kaggle/foursquare/src


## Imports

In [3]:
import os

# Expose only GPU 0 to CUDA libraries; set before importing torch so it takes effect.
os.environ['CUDA_VISIBLE_DEVICES'] = "0"

import torch
# Sanity check: display the name of the visible GPU.
torch.cuda.get_device_name(0)

'NVIDIA RTX A6000'

In [4]:
import os
import gc
import re
import glob
import json
import lofo
import torch
import pickle
import optuna
import warnings
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from tqdm.notebook import tqdm
from collections import Counter
from pandarallel import pandarallel
from inference.main import k_fold_inf
from numerize.numerize import numerize
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import GroupKFold
from cuml.feature_extraction.text import TfidfVectorizer

# Parallel pandas helpers: no progress bar, pipe-based (not memory-fs) data transfer.
pandarallel.initialize(progress_bar=False, use_memory_fs=False)
# Wide feature tables: show up to 500 columns when displaying frames.
pd.set_option("display.max_columns", 500)
# Silence library warnings for the whole session.
warnings.simplefilter("ignore")

INFO: Pandarallel will run on 12 workers.
INFO: Pandarallel will use standard multiprocessing data transfer (pipe) to transfer data between the main process and workers.


In [5]:
from params import *

from data.features import *
from data.preparation import *
from data.post_processing import *

from utils.logger import prepare_log_folder, create_logger, save_config
from utils.metrics import *

from model_zoo.xgb import objective_xgb, lofo_xgb
from model_zoo.catboost import objective_catboost
from training.main_boosting import k_fold
from utils.plot import *

from matching import get_CV, load_cleaned_data
from pp import get_improved_CV
from dtypes import *

### Settings

In [6]:
# Model level: 1 trains on `features_train_1`, 2 on the level-2 feature file.
LEVEL: int = 2

# Number of cross-validation folds (10 folds were also tried for level 2).
N_FOLDS: int = 5

## Model


In [7]:
# Ground-truth point_of_interest of every training location. Only the two needed
# columns are parsed; the final selection pins their order.
df = pd.read_csv(
    DATA_PATH + "train.csv",
    usecols=["id", "point_of_interest"],
)[["id", "point_of_interest"]]

In [8]:
# Location-level fold assignment, cached on disk so every run shares the same split.
# Grouping on point_of_interest keeps all locations of a POI in the same fold, so both
# members of a matching pair always land in the same fold.
path = f'../output/folds_{N_FOLDS}.csv'

if os.path.exists(path):
    df_split = pd.read_csv(path)
else:
    gkf = GroupKFold(n_splits=N_FOLDS)

    # GroupKFold yields *positional* indices: fill a numpy array rather than using
    # label-based `.loc`, which is only correct for a default RangeIndex.
    folds = np.full(len(df), -1)
    for i, (_, val_idx) in enumerate(gkf.split(df["id"], groups=df["point_of_interest"])):
        folds[val_idx] = i
    assert (folds >= 0).all(), "every location must be assigned to a fold"

    df_split = df[["id", "point_of_interest"]].copy()
    df_split["fold"] = folds

    df_split.to_csv(path, index=False)

In [9]:
# Pair-level feature table for the selected level. Level 2 reads the file built for
# probability cut-off THRESHOLD (presumably pairs kept by the level-1 model).
if LEVEL != 1:
    THRESHOLD = 0.005
    df_p = pd.read_csv(OUT_PATH + f"features_train_2_{THRESHOLD}.csv", dtype=DTYPES_2)
else:
    df_p = pd.read_csv(OUT_PATH + "features_train_1.csv", dtype=DTYPES_1)
    THRESHOLD = 0

In [10]:
# Model inputs: every column after the two pair ids (assumed to be id_1, id_2), minus
# label / fold bookkeeping columns in case the feature file already stores them.
FEATURES = [
    column
    for column in df_p.columns[2:]
    if column not in {'point_of_interest_1', 'fold_1', 'point_of_interest_2', 'fold_2', 'match'}
]

In [11]:
# Attach the point_of_interest and fold of both locations of every pair.
# Columns are renamed explicitly instead of relying on merge `suffixes` (which pandas
# only applies to overlapping columns), and `validate` guarantees ids are unique in
# df_split so a duplicated id cannot silently duplicate pairs.
if "fold_1" not in df_p.columns:
    for side in ("1", "2"):
        df_p = df_p.merge(
            df_split.rename(
                columns={"point_of_interest": f"point_of_interest_{side}", "fold": f"fold_{side}"}
            ),
            left_on=f"id_{side}",
            right_on="id",
            how="left",
            validate="many_to_one",
        ).drop("id", axis=1)

In [12]:
# Binary target: 1 when both locations of the pair share the same point_of_interest.
if "match" not in df_p.columns:
    df_p['match'] = df_p['point_of_interest_1'].eq(df_p['point_of_interest_2']).astype(int)

In [13]:
# Deterministic row order by pair ids; presumably needed so OOF arrays from different
# runs line up when they are averaged in the Retrieve section.
df_p = df_p.sort_values(by=['id_1', 'id_2'], ignore_index=True)

In [14]:
# Quick look at the first pairs.
df_p.head(n=5)

Unnamed: 0,id_1,id_2,dist,dist1,dist2,country,cat2a,cat2b,name_pi1,name_lcs2,name_lcs,name_pi1_r1,name_lcs2_r1,name_lcs2_r2,name_lcs_r1,name_lcs_r2,name_r3,name_lcs_r4,categories_pi1,categories_lcs2,categories_lcs,categories_pi1_r1,categories_lcs2_r1,categories_lcs2_r2,categories_lcs_r1,categories_lcs_r2,categories_r3,categories_lcs_r4,address_pi1,address_pi1_r1,address_lcs2_r1,address_lcs2_r2,address_lcs_r1,address_lcs_r2,address_r3,city_NA,address_NA,phone_m10,dist_r1,dist_r2,id_cc_min,id_cc_max,name_cc_min,name_cc_max,angular_distance_min,angular_distance_l2_min,same_state,same_zip,same_city,name_len_diff,name_levenshtein,address_len_diff,address_levenshtein,url_len_diff,url_levenshtein,Nb_multiPoi_1,Nb_multiPoi_2,Nb_connect1,Nb_connect2,ratio_connect_multipoi1,ratio_connect_multipoi2,Nb_strong_connect,ratio_strong_connect_multipoi1,ratio_strong_connect_multipoi2,cat_link_score,cat_link_score_all,mean_1,mean_2,q25_1,q25_2,q50_1,q50_2,q75_1,q75_2,q90_1,q90_2,q99_1,q99_2,cat_solo_score_1,cat_solo_score_2,freq_pairing_with_other_groupedcat_1,freq_pairing_with_other_groupedcat_2,mean_ratiodist_1,mean_ratiodist_2,mean_ratiodist_pair,q25_ratiodist_1,q25_ratiodist_2,q25_ratiodist_pair,q50_ratiodist_1,q50_ratiodist_2,q50_ratiodist_pair,q75_ratiodist_1,q75_ratiodist_2,q75_ratiodist_pair,q90_ratiodist_1,q90_ratiodist_2,q90_ratiodist_pair,q99_ratiodist_1,q99_ratiodist_2,q99_ratiodist_pair,grouped_cat_link_score,grouped_cat_link_score_all,name_initial_cclcs,name_initial_lllcs,name_initial_lcs2,name_initial_lcs,name_initial_pi1,name_initial_pi2,name_initial_ld,name_initial_ljw,name_initial_dsm1,name_initial_ll1,name_initial_pi1_r1,name_initial_pi2_r1,name_initial_lcs2_r1,name_initial_lcs2_r2,name_initial_lcs_r1,name_initial_lcs_r2,name_initial_lllcs_r1,name_initial_lllcs_r2,name_initial_r3,name_initial_lcs_r4,name_initial_decode_cclcs,name_initial_decode_lllcs,name_initial_decode_lcs2,name_initial_decode_lcs,name_initial_decode_pi1,name_initial_decode_pi2,name_initial_decode_
ld,name_initial_decode_ljw,name_initial_decode_dsm1,name_initial_decode_ll1,name_initial_decode_pi1_r1,name_initial_decode_pi2_r1,name_initial_decode_lcs2_r1,name_initial_decode_lcs2_r2,name_initial_decode_lcs_r1,name_initial_decode_lcs_r2,name_initial_decode_lllcs_r1,name_initial_decode_lllcs_r2,name_initial_decode_r3,name_initial_decode_lcs_r4,name_initial_decode_m5,nameC_lllcs,nameC_lcs2,nameC_lcs,nameC_pi1,nameC_pi2,nameC_ld,nameC_ljw,nameC_dsm1,nameC_ll1,nameC_pi1_r1,nameC_pi2_r1,nameC_lcs2_r1,nameC_lcs2_r2,nameC_lcs_r1,nameC_lcs_r2,nameC_lllcs_r1,nameC_lllcs_r2,nameC_r3,name_cclcs,name_lllcs,name_pi2,name_ld,name_ljw,name_dsm1,name_ll1,name_pi2_r1,name_lllcs_r1,name_lllcs_r2,name_m5,categories_pi2,categories_ld,categories_ljw,categories_dsm1,categories_ll1,categories_pi2_r1,categories_lllcs_r1,categories_lllcs_r2,address_pi2,address_ld,address_ljw,address_dsm1,address_ll1,address_pi2_r1,address_lllcs_r1,url_lcs,url_ld,url_dsm1,url_ll1,url_lcs_r2,url_r3,city_lcs,city_pi1,city_ld,city_ljw,city_dsm1,city_ll1,city_lcs2_r1,city_lcs2_r2,city_r3,state_ld,state_ljw,state_lcs2_r1,state_lcs_r2,state_r3,state_lcs_r4,state_NA,zip_pi2,zip_ljw,zip_pi1_r1,zip_lcs_r1,zip_NA,phone_lcs2,phone_pi2,phone_ljw,phone_dsm1,address_cc_min,address_cc_max,categories_cc_min,city_cc_min,city_cc_max,state_cc_min,zip_cc_min,zip_cc_max,phone_cc_min,phone_cc_max,city_group_cc_min,city_group_cc_max,state_group_cc_min,state_group_cc_max,word_c_cs,word_n_cs,id_cc_2K,id_cc_1K,id_cc_500,id_cc_200,id_cc_100,id_cc_50,id_cc_5K,id_cc_cat_2K,id_cc_cat_1K,id_cc_cat_500,id_cc_cat_200,id_cc_cat_100,id_cc_cat_50,id_cc_cat_5K,id_cc_simplcat_2K,id_cc_simplcat_1K,id_cc_simplcat_500,id_cc_simplcat_200,id_cc_simplcat_100,id_cc_simplcat_50,id_cc_simplcat_5K,name_num,address_num,langs,cat_simpl,num_in_name,nb_in_name,ratio_in_name,address_both_nan,address_any_nan,city_both_nan,city_any_nan,state_both_nan,state_any_nan,zip_both_nan,zip_any_nan,url_both_nan,url_any_nan,phone_both_nan,phone_any_nan,info_power_1,info
_power_2,info_diff,name_tf_idf_33_char_wb_sim,address_tf_idf_33_char_wb_sim,url_tf_idf_33_char_wb_sim,name_wratio,name_partial_ratio,address_wratio,address_partial_ratio,url_wratio,url_partial_ratio,point_of_interest_1,fold_1,point_of_interest_2,fold_2,match
0,E_000001272c6c5d,E_da7fa3963561f8,660,600,50,11,1,1,4,10,14,0.290039,0.709961,0.560059,1.0,0.779785,0.779785,0.709961,4,4,4,1.0,1.0,1.0,1.0,1.0,1.0,1.0,0,0.0,0.0,0.0,0.0,0.0,0.090027,1,1,0,11.179688,17.171875,2,2,2,2,0.010391,0.009438,,,,4,0.222168,11,,0,,24757.0,24757.0,16,1,0.000646,4e-05,200,0.00808,0.00808,0.21106,0.171997,6.289062,6.289062,0.015221,0.015221,0.052246,0.052246,0.527832,0.527832,3.681641,3.681641,92.1875,92.1875,0.817383,0.817383,0.316406,0.316406,109.947144,109.947144,4.953125,44355.847656,44355.847656,492.748932,12088.375977,12088.375977,148.375,1211.967041,1211.967041,54.59375,181.272186,181.272186,4.953125,8.166169,8.166169,1.105469,0.0,0.0,2,11,11,16,5,11,5,0.899902,0.859863,15,0.330078,0.72998,0.72998,0.549805,1.070312,0.799805,0.72998,0.549805,0.75,0.689941,2,11,11,15,5,11,5,0.899902,0.859863,15,0.330078,0.72998,0.72998,0.549805,1.0,0.75,0.72998,0.549805,0.75,0.72998,1,2,1,2,1,1,1,0.609863,0.799805,2,0.5,0.5,0.5,0.330078,1.0,0.669922,1.0,0.669922,0.669922,2,10,10,4,0.899902,0.879883,14,0.709961,0.709961,0.560059,0,4,0,1.0,1.0,4,1.0,0.0,0.0,0,11,0.0,0.0,1,0.0,0.0,0,0,1.0,1,0.0,1.0,0,0,10,0.0,0.0,1,0.0,0.0,0.099976,14,0.0,0.0,0.0,0.070007,0.0,1,0,0.0,0.0,0.0,1,0,0,0.0,1.0,2,9896,9896,9,9896,8954,298,9896,9896,9896,9,9896,8954,9896,1.0,1.0,24,10,4,0,0,0,180,2,2,0,0,0,0,14,2,2,0,0,0,0,8,0,0,8,19,0,0,0.0,0,1,0,1,0,1,0,1,1,1,1,1,0.666504,0.0,0.666504,0.755859,,,88.0,79.0,,,,,P_677e840bb6fc7e,2,P_677e840bb6fc7e,2,1
1,E_000002eae2a589,E_1ba37c68d3b314,180,100,150,8,1,2,7,7,10,0.540039,0.540039,0.320068,0.77002,0.449951,0.589844,0.700195,20,20,20,1.0,1.0,0.830078,1.0,0.830078,0.830078,1.0,0,0.0,0.0,0.0,0.0,0.0,0.119995,1,1,0,3.054688,1.225586,7,9,9,12,0.003901,0.002851,,,,9,0.54541,8,,0,,24757.0,1310.0,15,15,0.011452,0.000606,200,0.15271,0.00808,0.192993,0.134033,6.289062,2.960938,0.015869,0.015869,0.052246,0.046631,0.527832,0.235474,3.681641,1.240234,92.1875,85.375,0.817383,0.694336,0.316406,0.22937,29.964104,60.340286,2.013672,10938.021484,10938.021484,134.289795,3294.469482,4023.872803,44.6875,330.299652,735.095154,14.882812,49.402454,148.413162,2.013672,3.004166,3.004166,1.0,0.321045,0.101013,1,8,8,11,8,1,14,0.830078,0.560059,14,0.569824,0.070007,0.569824,0.320068,0.790039,0.439941,0.569824,0.320068,0.560059,0.72998,1,8,8,11,8,1,14,0.830078,0.560059,14,0.569824,0.070007,0.569824,0.320068,0.790039,0.439941,0.569824,0.320068,0.560059,0.72998,1,1,1,1,1,0,2,0.609863,0.399902,2,0.5,0.0,0.5,0.330078,0.5,0.330078,0.5,0.330078,0.669922,1,7,1,12,0.850098,0.569824,13,0.080017,0.540039,0.320068,1,20,4,0.970215,0.910156,20,1.0,1.0,0.830078,0,8,0.0,0.0,1,0.0,0.0,0,0,1.0,1,0.0,1.0,0,0,12,0.0,0.0,1,0.0,0.0,0.080017,12,0.0,0.0,0.0,0.080017,0.0,1,0,0.0,0.0,0.0,1,0,0,0.0,0.0,9,9896,48,9896,9896,4446,44,9896,9,9896,9896,9896,9896,9896,0.707031,1.0,664,601,329,89,24,9,1635,163,133,59,17,4,1,364,39,29,15,5,1,0,59,0,1,4,0,0,0,0.0,0,1,0,1,0,1,0,1,1,1,0,1,0.0,0.833496,0.833496,0.354736,,,69.0,77.0,,,,,P_d82910d8382a83,1,P_aaf926a3437875,0,0
2,E_000002eae2a589,E_45e0cc554a703c,20,20,10,8,1,5,5,5,6,0.709961,0.709961,0.379883,0.859863,0.459961,0.540039,0.830078,0,2,11,0.0,0.099976,0.059998,0.549805,0.310059,0.560059,0.180054,0,0.0,0.0,0.0,0.0,0.0,0.099976,1,1,0,0.349854,0.105164,6,7,11,12,0.000483,0.000381,,,,6,0.538574,10,,18,,11115.0,1310.0,15,11,0.011452,0.00099,200,0.15271,0.01799,0.0,0.0,461.0,2.960938,0.044922,0.015869,0.342529,0.046631,304.5,0.235474,1637.0,1.240234,4792.0,85.375,0.865723,0.694336,0.22937,0.22937,1.0,7.389056,1.105469,445.857697,1211.967041,16.444645,60.340286,445.857697,5.472656,1.105171,90.017128,2.71875,1.0,16.444645,1.105469,1.0,1.221403,1.0,0.0,0.0,1,5,5,6,5,1,8,0.560059,0.569824,7,0.709961,0.140015,0.709961,0.360107,0.859863,0.429932,0.709961,0.360107,0.5,0.830078,1,5,5,6,5,1,8,0.560059,0.569824,7,0.709961,0.140015,0.709961,0.360107,0.859863,0.429932,0.709961,0.360107,0.5,0.830078,0,1,1,1,1,1,2,0.0,0.5,2,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,1.0,1,5,1,7,0.569824,0.600098,7,0.140015,0.709961,0.379883,0,1,27,0.640137,0.209961,20,0.049988,0.0,0.0,0,10,0.0,0.0,1,0.0,0.0,0,18,0.0,1,0.0,0.059998,0,0,12,0.0,0.0,1,0.0,0.0,0.080017,2,0.0,0.0,0.0,0.5,0.0,1,0,0.0,0.0,0.0,1,0,0,0.0,0.0,6,9896,109,9896,9896,8954,9,9896,11,9896,9896,9896,9896,9896,0.0,1.0,664,601,329,66,21,9,1635,220,199,98,21,7,3,402,21,17,10,3,2,2,32,0,1,4,0,0,0,0.0,0,1,0,1,0,1,0,1,0,1,0,1,0.0,1.0,1.0,0.419434,,,64.0,71.0,,,,,P_d82910d8382a83,1,P_380e258e825d46,1,0
3,E_000002eae2a589,E_e80db432029aea,20,10,10,8,1,2,13,13,13,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1,2,7,0.070007,0.130005,0.099976,0.469971,0.350098,0.75,0.290039,0,0.0,0.0,0.0,0.0,0.0,0.059998,1,1,0,0.349854,0.105164,5,7,12,12,0.000374,0.000264,,,,0,0.0,18,,0,,24757.0,1310.0,15,15,0.011452,0.000606,200,0.15271,0.00808,0.104004,0.072021,6.289062,2.960938,0.019989,0.015869,0.079895,0.046631,0.527832,0.235474,3.681641,1.240234,92.1875,85.375,0.817383,0.694336,0.316406,0.22937,4.0552,7.389056,1.105469,992.274841,1211.967041,16.444645,244.691925,445.857697,5.472656,40.447308,90.017128,2.71875,6.685894,16.444645,1.105469,1.221403,1.221403,1.0,0.321045,0.101013,1,14,14,14,14,14,0,1.0,1.0,14,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1,14,14,14,14,14,0,1.0,1.0,14,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1,2,2,2,2,2,0,1.0,1.0,2,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1,13,13,0,1.0,1.0,13,1.0,1.0,1.0,1,2,15,0.589844,0.399902,15,0.130005,0.0,0.0,0,18,0.0,0.0,1,0.0,0.0,0,0,1.0,1,0.0,1.0,0,0,12,0.0,0.0,1,0.0,0.0,0.080017,2,0.0,0.0,0.0,0.5,0.0,1,0,0.0,0.0,0.0,1,0,0,0.0,1.0,15,9896,48,9896,9896,8954,36,9896,9896,9896,9896,9896,9896,9896,0.0,1.0,664,601,298,59,21,9,1635,163,133,66,12,6,2,364,39,32,15,3,2,2,59,0,0,4,0,0,0,0.0,0,1,0,1,0,1,0,1,1,1,1,1,0.0,0.666504,0.666504,1.0,,,100.0,100.0,,,,,P_d82910d8382a83,1,P_d82910d8382a83,1,1
4,E_000007f24ebc95,E_39ed217394b6d1,360,200,330,5,2,3,4,4,6,0.290039,0.290039,0.209961,0.429932,0.320068,0.740234,0.669922,1,5,5,0.109985,0.560059,0.290039,0.560059,0.290039,0.529785,1.0,0,0.0,0.0,0.0,0.0,0.0,1.0,1,2,0,5.6875,2.669922,2,4,2,4,0.00795,0.005791,,,,5,0.684082,0,,0,,13664.0,3895.0,16,16,0.001171,0.004108,200,0.014633,0.051361,0.001,0.001,7.253906,3.1875,0.023819,0.019913,0.074158,0.068481,0.369385,0.362305,2.216797,1.446289,30.40625,26.46875,0.685547,0.609375,0.22937,0.146362,49.402454,109.947144,3.003906,14764.788086,18033.748047,270.426361,4914.768555,5431.662109,81.4375,992.274841,992.274841,29.96875,164.021896,244.691925,3.320312,13.463737,14.879733,1.105469,0.303955,0.045013,0,0,0,0,0,0,16,0.0,0.0,15,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.939941,0.0,1,0,4,6,4,0,13,0.529785,0.360107,14,0.290039,0.0,0.290039,0.209961,0.429932,0.320068,0.0,0.0,0.740234,0.669922,0,0,0,0,0,0,4,0.0,0.0,1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.25,1,0,0,13,0.529785,0.360107,14,0.0,0.0,0.0,0,5,12,0.5,0.379883,9,0.560059,0.560059,0.290039,0,0,0.0,1.0,1,0.0,0.0,0,0,1.0,1,0.0,1.0,0,0,7,0.0,0.0,1,0.0,0.0,0.140015,8,0.0,0.0,0.0,0.119995,0.0,1,0,0.0,0.0,0.0,2,0,0,0.0,1.0,9896,9896,6633,9896,9896,6002,9896,9896,9896,9896,9896,9896,9896,9896,0.5,0.0,1479,734,329,12,3,3,9896,298,180,121,7,2,2,2440,66,59,48,3,0,0,364,0,0,10,0,0,0,0.0,1,1,0,1,0,1,1,1,1,1,1,1,0.0,0.333252,0.333252,0.116638,,,36.0,36.0,,,,,P_b1066599e78477,0,P_10b38113cb4853,4,0


### Hyper-parameters

In [17]:
# Hyper-parameters per model family, set separately for each level.
# Commented-out entries were tried and disabled; kept for reference.
if LEVEL == 1:
    PARAMS = dict(
        xgb=dict(
            learning_rate=0.05,
            max_depth=10,
            min_child_weight=0.01,
            reg_alpha=0.01,
            reg_lambda=0.1,
            colsample_bytree=0.95,
            subsample=0.75,
        ),
        catboost=dict(
            learning_rate=0.1,
            depth=12,
            l2_leaf_reg=0.1,
            min_data_in_leaf=2000,
            # subsample=0.75,
            # bootstrap_type="Poisson",
        ),
        lgbm=dict(
            learning_rate=0.05,
            num_leaves=500,
            reg_alpha=1,
            reg_lambda=10,
            min_child_samples=1000,
            min_split_gain=0.01,
            min_child_weight=0.01,
            path_smooth=0.1,
            # min_data_in_bin=320,
        ),
    )
else:
    PARAMS = dict(
        xgb=dict(
            learning_rate=0.05,
            max_depth=15,
            colsample_bytree=0.5,
            reg_alpha=1,
            reg_lambda=10,
            min_child_weight=0.1,
            gamma=0.1,
        ),
        catboost=dict(
            depth=12,
            l2_leaf_reg=0.1,
            min_data_in_leaf=2000,
            # reg_lambda=0.1,
            # model_size_reg=0.5,
            # border_count=256,
        ),
        lgbm=dict(
            learning_rate=0.05,
            num_leaves=511,
            colsample_bytree=0.5,
            reg_alpha=1,
            reg_lambda=70,
            min_child_samples=2000,
            min_split_gain=0.02,
            min_child_weight=0.03,
            path_smooth=0.2,
            # min_data_in_bin=32,
        ),
    )

In [18]:
# Run switches:
#   OPTIMIZE - presumably toggles hyper-parameter search; not referenced in the cells below.
#   TRAIN    - train now; when False, predictions are reloaded from previous experiments.
#   DEBUG    - when True, training runs without creating a log folder.
OPTIMIZE, TRAIN, DEBUG = False, True, False

### Config

In [19]:
# Features dropped from training (presumably low-importance); only pruned for level 2.
LOW_IMP = (
    []
    if LEVEL == 1
    else ['url_cc_max', 'url_cc_min', 'address_any_nan', 'address_both_nan', 'city_any_nan', 'zip_both_nan', 'phone_both_nan']
)

In [20]:
# Run configuration handed to `k_fold` and saved with `save_config`. Values come from the
# settings cells above (LEVEL, THRESHOLD, N_FOLDS, FEATURES, LOW_IMP, PARAMS).
class Config:
    level = LEVEL
    threshold = THRESHOLD

    # Cross-validation: grouped k-fold on point_of_interest.
    split = "gkf"
    n_folds = N_FOLDS

    # Model inputs: all pair features minus the pruned ones.
    features = [feat for feat in FEATURES if feat not in LOW_IMP]

    # Categorical columns, restricted to those present in the feature table.
    cat_features = [
        col
        for col in ['country', 'cat2a', 'cat2b', "langs", "cat_simpl", "name_num", "address_num"]
        if col in FEATURES
    ]

    target = "match"
    model = "lgbm"
    params = PARAMS[model]
    # Lists up to 10 folds; folds beyond n_folds are presumably skipped by k_fold.
    selected_folds = list(range(10))

    # Early stopping (presumably what `use_es` toggles) only with the grouped split.
    use_es = split == "gkf"

In [21]:
# Cast the categorical columns to pandas `category` dtype (skipped when there are none).
if Config.cat_features:
    cat_cols = Config.cat_features
    df_p[cat_cols] = df_p[cat_cols].astype("category")

### Training

In [None]:
%%time

if TRAIN:
    log_folder = None
    if not DEBUG:
        log_folder = prepare_log_folder(LOG_PATH + f"lvl_{LEVEL}/")
        print(f'Logging results to {log_folder}')
        save_config(Config, log_folder + 'config')
        create_logger(directory=log_folder, name="logs.txt")

    pred_oof, models, ft_imp = k_fold(df_p, Config, log_folder=log_folder)

Logging results to ../logs/lvl_2/2022-07-03/2/

-------------   Fold 1 / 5  -------------

    -> 2784784 training pairs
    -> 1395932 validation pairs

[LightGBM] [Info] Number of positive: 550725, number of negative: 2234059
[LightGBM] [Info] This is the GPU trainer!!
[LightGBM] [Info] Total Bins 24563
[LightGBM] [Info] Number of data points in the train set: 2784784, number of used features: 279
[LightGBM] [Info] Using GPU Device: NVIDIA RTX A6000, Vendor: NVIDIA Corporation
[LightGBM] [Info] Compiling OpenCL Kernel with 256 bins...
[LightGBM] [Info] GPU programs have been built
[LightGBM] [Info] Size of histogram bin entry: 8
[LightGBM] [Info] 211 dense feature groups (563.02 MB) transferred to GPU in 0.270612 secs. 1 sparse feature groups
[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.197762 -> initscore=-1.400340
[LightGBM] [Info] Start training from score -1.400340
Training until validation scores don't improve for 100 rounds
[100]	valid_0's auc: 0.94834	valid_0's binary_log

### Train 2 (alternative XGBoost run, currently commented out)

In [None]:
# class Config:
#     split = "gkf" if LEVEL == 1 else "kf"
#     n_folds = N_FOLDS

#     features = FEATURES
#     features = [f for f in FEATURES if f not in LOW_IMP]

#     cat_features = ['country', 'cat2a', 'cat2b', "langs", "cat_simpl", "name_num", "address_num"]
#     cat_features = [c for c in cat_features if c in FEATURES]

#     target = "match"
#     model = "xgb"
#     params = PARAMS[model]
#     selected_folds = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    
#     use_es = split == "gkf"

In [None]:
# %%time

# if TRAIN:
#     log_folder = None
#     if not DEBUG:
#         log_folder = prepare_log_folder(LOG_PATH + f"lvl_{LEVEL}/")
#         print(f'Logging results to {log_folder}')
#         save_config(Config, log_folder + 'config')
#         create_logger(directory=log_folder, name="logs.txt")

#     pred_oof, models, ft_imp = k_fold(df_p, Config, log_folder=log_folder)

### Retrieve

In [None]:
# Previous experiments to reload when TRAIN is False:
#   EXP_FOLDER            - source of the feature importances,
#   EXP_FOLDERS / WEIGHTS - weighted average of out-of-fold predictions.
if LEVEL == 1:
    EXP_FOLDER = LOG_PATH + "lvl_1/" + "2022-07-02/2/"  # lgb gkf

    EXP_FOLDERS = [
        LOG_PATH + "lvl_1/" + "2022-07-02/2/",  # lgb gkf
        LOG_PATH + "lvl_1/" + "2022-07-02/0/",  # xgb gkf
        # LOG_PATH + "lvl_1/" + "2022-07-02/4/",  # catboost gkf
    ]
    WEIGHTS = [
        0.75,
        0.25,
        # 0.15
    ]

else:
    # Earlier level-2 runs, kept for reference:
    #   "2022-06-27/4/"  xgboost 5kf
    #   "2022-06-27/7/"  lgb 5kf
    #   "2022-06-28/0/"  lgb 10kf - 0.8982
    #   "2022-06-30/5/"  lgb gkf - 0.8837
    EXP_FOLDER = LOG_PATH + "lvl_2/" + "2022-06-30/6/"  # xgb gkf - 0.8782

    # Blend candidates (disabled): "2022-06-30/2/", "2022-06-30/3/".
    # Default to the single experiment above so the reload cell also works for level 2
    # (EXP_FOLDERS / WEIGHTS used to be undefined here -> NameError).
    EXP_FOLDERS = [EXP_FOLDER]
    WEIGHTS = [1.0]

assert len(EXP_FOLDERS) == len(WEIGHTS), "one weight per experiment folder"

In [None]:
# Reload results instead of training: feature importances from EXP_FOLDER and a weighted
# average of the out-of-fold predictions stored in each of EXP_FOLDERS.
if not TRAIN:
    ft_imp = pd.read_csv(EXP_FOLDER + "ft_imp.csv").set_index('Unnamed: 0')

    # Single-experiment alternative: pred_oof = np.load(EXP_FOLDER + "pred_oof.npy")
    pred_oof = np.average(
        [np.load(folder + "pred_oof.npy") for folder in EXP_FOLDERS],
        axis=0,
        weights=WEIGHTS,
    )

## Results

In [None]:
# Out-of-fold diagnostics: confusion matrix at a 0.5 cut-off, OOF AUC in the title.
if isinstance(df_p, pd.DataFrame):
    y = df_p[Config.target].values
else:
    # presumably a GPU (cuDF) frame, whose column is copied to host with `.get()`
    y = df_p[Config.target].get()

plot_confusion_matrix(
    pred_oof > 0.5,
    y,
    display_labels=['No Match', 'Match'],
    # normalize="pred",
)

plt.title(f"AUC = {roc_auc_score(y, pred_oof):.4f}")
plt.show()

In [None]:
# Ground-truth match string per location: space-separated ids (sorted) of all locations
# sharing its point_of_interest, with a trailing space (needed in m_true downstream).
# Built in one linear pass; unlike the previous version it does not crash on an empty
# frame and does not round-trip through a temporary "index" column.
if "m_true" not in df.columns:
    ordered = df.sort_values(by=["point_of_interest", "id"])

    ids_by_poi = {}
    for poi, loc_id in zip(ordered["point_of_interest"], ordered["id"]):
        ids_by_poi.setdefault(poi, []).append(str(loc_id))
    poi_to_ids = {poi: " ".join(ids) + " " for poi, ids in ids_by_poi.items()}

    # Row order is untouched; the index is reset as before.
    df = df.assign(m_true=df["point_of_interest"].map(poi_to_ids)).reset_index(drop=True)

In [None]:
# Matching CV of the OOF predictions via the project helper from `pp` (presumably with
# post-processing); df is copied so the helper cannot modify it. Return value discarded.
_ = get_improved_CV(df_p, pred_oof, df.copy())

### Check several probability cut-off levels (level 1 only)

In [None]:
# Level 1 only: for several cut-offs on the OOF probability, report how many candidate
# pairs survive, their share of positives, and the CV (via get_CV) on the surviving pairs.
if LEVEL == 1:
    # Labels as extracted for the AUC cell (a no-op for a pandas frame).
    df_p[Config.target] = y

    for thresh in (.0025, .005, .0075):
        print(f'\nRemoving pairs with p < {thresh} : ')

        df_cut = df_p.loc[pred_oof > thresh].reset_index()
        y_cut = df_cut[Config.target].values

        # Human-readable count via numerize, falling back to the raw integer.
        try:
            print(f"- Number of candidates : {numerize(len(y_cut))}")
        except NameError:
            print(f"- Number of candidates : {len(y_cut)}")
        print(f"- Proportion of positive candidates: {y_cut.mean() * 100:.2f}%")

        get_CV(None, None, y_cut, y_cut, df.copy(), df_cut.copy())

In [None]:
# Level 1 only: save the pairs whose OOF probability exceeds the cut-off picked from the
# comparison above; the file name records the cut-off.
# A dedicated name is used so the global THRESHOLD (read by `Config`) is not silently
# overwritten, which would leak into Config if that cell were re-run afterwards.
if LEVEL == 1:
    FILTER_THRESHOLD = 0.0075

    df_p_r = df_p[pred_oof > FILTER_THRESHOLD].reset_index(drop=True)
    df_p_r.to_csv(OUT_PATH + f"features_train_1_filtered_{FILTER_THRESHOLD}.csv", index=False)

### Feature importance

In [None]:
# Feature importances returned by k_fold (or reloaded from EXP_FOLDER when TRAIN is False).
plot_importances(ft_imp)
# plt.xscale('log')
plt.show()