This notebook builds on [@ragnar123](https://www.kaggle.com/ragnar123)'s excellent [notebook](https://www.kaggle.com/code/ragnar123/unsupervised-baseline-arcface) from the [Shopee competition](https://www.kaggle.com/competitions/shopee-product-matching).

Since this competition is set up much like the Shopee competition, I tried metric learning, the approach used by the top solutions in the Shopee competition.

However, unlike in the Shopee competition, the most important information in this competition is `latitude` and `longitude` rather than the text, so I use a metric-learning model in the first stage and a GBDT model trained on the first-stage features in the second stage.

**About this notebook**

Training and inference can both be run from this single notebook, on either Kaggle or Colab.

1. Training
    
    Set `CFG.train = True` and run.

2. Inference

    Set `CFG.train = False` and run.




In [1]:
# Show the GPU attached to this session (model, driver/CUDA version, memory usage).
!nvidia-smi

Sat Jul  9 18:46:17 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.82.01    Driver Version: 470.82.01    CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   37C    P0    26W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------------------------------------------------------

# Library

In [2]:
# ====================================================
# import libraries1
# ====================================================

import warnings
warnings.filterwarnings('ignore')

import os
import sys
import math
import random
import time
import numpy as np
import pandas as pd
import gc
import json
import joblib
from tqdm import tqdm
from pathlib import Path
import itertools
import collections
from collections import Counter

import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau, _LRScheduler
from torch.nn import Parameter

import datetime
from datetime import timedelta
import hashlib
import difflib
import seaborn as sns
import matplotlib.pyplot as plt
from requests import get
from PIL import Image
import pickle
from contextlib import contextmanager
import multiprocessing

from sklearn.model_selection import StratifiedKFold, GroupKFold, KFold
from sklearn.neighbors import KNeighborsRegressor, NearestNeighbors
from sklearn.metrics import mean_squared_error, f1_score
from sklearn.linear_model import RidgeCV
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_sample_weight
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction.text import TfidfVectorizer
from gensim.models import word2vec

import lightgbm as lgb
import typing as tp

from logging import getLogger, INFO, StreamHandler, FileHandler, Formatter

# Enable `progress_apply` on pandas objects and show up to 500 rows/columns when displaying frames.
tqdm.pandas()
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)

# Run on the GPU whenever PyTorch can see one.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Using device: cuda


# Config

In [3]:
class CFG:
    # ====================================================
    # Basic Setting
    # ====================================================
    colab = "google.colab" in sys.modules  # True when running inside Google Colab
    exp = "130"                            # experiment id, used in output dir / dataset names
    train = False                          # True: train stage 1; False: inference + submission
    api_path = '/content/drive/My Drive/kaggle.json'  # Kaggle API credentials (Colab only)
    seed = 42
    n_neighbors = 50                       # kNN candidates per record (capped by group size)
    threshold = 0.15                       # candidates are kept only if cosine distance < this
    # ====================================================
    # Model
    # ====================================================
    model = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
    epochs = 15
    num_workers = 8
    batch_size = 32
    max_length = 32                        # tokenizer pad/truncate length
    lr = 1e-5
    scheduler= 'linear' # ['linear', 'cosine', 'ReduceLROnPlateau', 'CosineAnnealingLR', 'CosineAnnealingWarmRestarts', 'CosineAnnealingWarmupRestarts']
    # ====================================================
    # Scheduler hyper-parameters
    # ====================================================
    # `get_scheduler` (inside `run`) reads these for the non-'linear' options; without them,
    # selecting any of those schedulers raised AttributeError. Step counts (T_max, T_0,
    # num_warmup_steps) are in scheduler.step() calls, and train_fn calls step once per batch.
    num_warmup_steps = 0   # 'cosine'
    num_cycles = 0.5       # 'cosine'
    factor = 0.2           # 'ReduceLROnPlateau'
    patience = 4           # 'ReduceLROnPlateau'
    eps = 1e-6             # 'ReduceLROnPlateau'
    T_max = 10             # 'CosineAnnealingLR'
    T_0 = 10               # 'CosineAnnealingWarmRestarts'
    T_mult = 1             # 'CosineAnnealingWarmRestarts'
    min_lr = 1e-6          # 'CosineAnnealingLR', 'CosineAnnealingWarmRestarts'
    # NOTE(review): 'CosineAnnealingWarmupRestarts' additionally needs first_cycle_steps,
    # cycle_mult, gamma and a scheduler class that this notebook does not define.

# On Kaggle the encoder weights come from an attached dataset instead of the HF hub.
if not CFG.colab:
    CFG.model = "../input/sbert-models/paraphrase-multilingual-mpnet-base-v2"

In [4]:
if CFG.colab:
    print("==============================================")
    print("This environment is Google Colab")
    print("==============================================")

    # Google Drive
    from google.colab import drive, files
    drive.mount('/content/drive')
    %cd "drive/My Drive/foursquare/"

    # Kaggle API
    f = open(CFG.api_path, 'r')
    json_data = json.load(f) 
    os.environ["KAGGLE_USERNAME"] = json_data["username"]
    os.environ["KAGGLE_KEY"] = json_data["key"]

    # Directory Setting
    if not os.path.exists(f"output/exp{CFG.exp}/"):
        os.makedirs(f"output/exp{CFG.exp}/")
    
    DATA_DIR = "input/"
    OUTPUT_DIR = f"output/exp{CFG.exp}/"
    MODEL_DIR = OUTPUT_DIR
    
    # Data Loading
    if not os.path.isfile(os.path.join(DATA_DIR, "foursquare-location-matching.zip")):
        !kaggle competitions download -c foursquare-location-matching -p $DATA_DIR

    # Libraries
    #!pip install -q catboost
    #!pip install -q Levenshtein
    #!pip install -q textdistance==4.2.2
    #!pip install -q pylcs==0.0.6
    #!pip install -q fasttext
    !pip install -q reverse_geocode
    !pip install -q transformers
    !pip install -q sentence_transformers==2.2.0

else:
    print("==============================================")
    print(" This environment is Kaggle Notebook")
    print("==============================================")

    # Directory Setting
    DATA_DIR = "../input/foursquare-location-matching/"
    OUTPUT_DIR = "./"
    MODEL_DIR = f"../input/foursquare-stage1-exp{CFG.exp}/"

    # Libraries
    !pip install /kaggle/input/reversegeocode/reverse_geocode-1.4.1-py3-none-any.whl
    #!pip install ../input/textdistance-install/textdistance-4.2.2-py3-none-any.whl
    #!pip install --force-reinstall ../input/pylcs-install/pybind11-2.9.2-py2.py3-none-any.whl

    #!rm -r mypip
    #!mkdir mypip
    #!tar -czvf mypip/pylcs-0.0.6.tar.gz -C ../input/pylcs-install/pylcs-0.0.6/pylcs-0.0.6 .
    #!ls -l mypip

    #!pip install --no-index mypip/pylcs-0.0.6.tar.gz
    
    sys.path.append("../input/sentencetransformersinstall/sentence-transformers-2.2.0")

# ====================================================
# import libraries2
# ====================================================

#from catboost import CatBoost, Pool
#import Levenshtein
#import pylcs
#import textdistance
import reverse_geocode
#from fasttext import load_model
from transformers import DistilBertModel, DistilBertTokenizer, AutoTokenizer, AutoModel, AutoConfig
from transformers import AdamW
from transformers import get_linear_schedule_with_warmup,get_cosine_schedule_with_warmup
from transformers import get_cosine_with_hard_restarts_schedule_with_warmup

 This environment is Kaggle Notebook
Processing /kaggle/input/reversegeocode/reverse_geocode-1.4.1-py3-none-any.whl
Installing collected packages: reverse-geocode
Successfully installed reverse-geocode-1.4.1
[0m

# Helper Functions

In [5]:
def reduce_mem_usage(df, verbose=True):
    """Downcast numeric columns of ``df`` in place to the smallest dtype that holds their range.

    Integer columns become int8/int16/int32/int64 and float columns float16/float32/float64.
    float16 keeps only about 3 significant decimal digits, so this is lossy for precise floats
    such as latitude/longitude.

    Args:
        df: DataFrame whose numeric columns are replaced.
        verbose: if True, print the resulting memory usage and the relative reduction.

    Returns:
        ``df`` itself (the same object), with downcast columns.
    """
    numerics = ['int16', 'int32', 'int64', 'float16', 'float32', 'float64']
    start_mem = df.memory_usage().sum() / 1024**2
    for col in df.columns:
        col_type = df[col].dtypes
        if col_type not in numerics:
            continue
        c_min = df[col].min()
        c_max = df[col].max()
        if str(col_type)[:3] == 'int':
            # Inclusive bounds: a column whose extreme equals a dtype limit still fits in it.
            for int_type in (np.int8, np.int16, np.int32, np.int64):
                info = np.iinfo(int_type)
                if info.min <= c_min and c_max <= info.max:
                    df[col] = df[col].astype(int_type)
                    break
        else:
            for float_type in (np.float16, np.float32):
                info = np.finfo(float_type)
                if info.min <= c_min and c_max <= info.max:
                    df[col] = df[col].astype(float_type)
                    break
            else:
                # Out of float32 range (or all-NaN, where every comparison is False).
                df[col] = df[col].astype(np.float64)
    end_mem = df.memory_usage().sum() / 1024**2
    if verbose:
        reduction = 100 * (start_mem - end_mem) / start_mem if start_mem else 0.0
        print('Memory usage decreased to {:5.2f} Mb ({:.1f}% reduction)'.format(end_mem, reduction))
    return df

In [6]:
@contextmanager
def timer(name: str):
    """Print ``[name] start``, run the ``with`` body, then print the elapsed wall time in whole seconds."""
    began = time.time()
    print(f"[{name}] start")
    yield
    print(f"[{name}] done in {time.time() - began:.0f} s")

In [7]:
def seed_everything(seed=42):
    """Seed Python, NumPy and PyTorch (CPU and all GPUs) RNGs and make cuDNN deterministic.

    Also sets ``PYTHONHASHSEED`` (which only affects processes started afterwards; the running
    interpreter's hash seed is fixed at startup) and disables HF tokenizers' parallelism.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN choose kernels by timing them, which can differ between runs and
    # defeats `deterministic` above, so it stays off for reproducibility.
    torch.backends.cudnn.benchmark = False
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

seed_everything(seed=CFG.seed)  # global seeding with the configured seed

In [8]:
def seed_worker(worker_id):
    """DataLoader ``worker_init_fn``: seed NumPy and ``random`` from the worker's torch seed.

    ``worker_id`` is part of the callback signature and is not used.
    """
    derived_seed = torch.initial_seed() % (1 << 32)
    np.random.seed(derived_seed)
    random.seed(derived_seed)

# Dedicated RNG for DataLoader shuffling, seeded from CFG so batch order is reproducible.
# `manual_seed` returns the generator itself; the bare `g` line keeps the cell's displayed repr.
g = torch.Generator().manual_seed(CFG.seed)
g

<torch._C.Generator at 0x7f04eeb74b90>

In [9]:
def init_logger(log_file=OUTPUT_DIR+'train.log'):
    """Configure and return the notebook logger at INFO level.

    Messages are written without any prefix to a ``StreamHandler`` (stderr by default) and
    appended to ``log_file``. Handlers from an earlier call are closed and removed first.
    """
    logger = getLogger(__name__)
    logger.setLevel(INFO)
    # Drop handlers from a previous call so re-running the cell doesn't duplicate every log line.
    for stale in list(logger.handlers):
        logger.removeHandler(stale)
        stale.close()
    formatter = Formatter("%(message)s")
    stream_handler = StreamHandler()
    stream_handler.setFormatter(formatter)
    file_handler = FileHandler(filename=log_file)
    file_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)
    logger.addHandler(file_handler)
    return logger

LOGGER = init_logger()

In [10]:
def get_id2poi(input_df: pd.DataFrame) -> dict:
    """Map each ``id`` to its ``point_of_interest`` (later rows win on duplicate ids)."""
    return {record_id: poi for record_id, poi in zip(input_df['id'], input_df['point_of_interest'])}

def get_poi2ids(input_df: pd.DataFrame) -> dict:
    """Map each ``point_of_interest`` to the set of ``id`` values belonging to it."""
    return {poi: set(ids) for poi, ids in input_df.groupby('point_of_interest')['id']}

def get_score(input_df: pd.DataFrame, id2poi: dict = None, poi2ids: dict = None):
    """Mean per-record IoU (Jaccard) between predicted and true match sets.

    Args:
        input_df: frame with ``id`` and space-separated ``matches`` columns.
        id2poi: id -> POI mapping; defaults to the notebook-level ``id2poi``.
        poi2ids: POI -> set-of-ids mapping; defaults to the notebook-level ``poi2ids``.

    Returns:
        The mean IoU over rows, or NaN when ``input_df`` is empty.
    """
    # Fall back to the notebook globals so existing `get_score(df)` calls keep working.
    if id2poi is None:
        id2poi = globals()['id2poi']
    if poi2ids is None:
        poi2ids = globals()['poi2ids']

    scores = []
    for id_str, matches in zip(input_df['id'].to_numpy(), input_df['matches'].to_numpy()):
        targets = poi2ids[id2poi[id_str]]
        preds = set(matches.split())
        scores.append(len(targets & preds) / len(targets | preds))
    if not scores:
        return float('nan')
    return np.mean(scores)

def analysis(df):
    """Print row count, distinct ids, distinct POIs and the mean number of ids per POI."""
    ids_per_poi = df.groupby('point_of_interest')['id'].count()
    print('Num of data: %s' % len(df))
    print('Num of unique id: %s' % df['id'].nunique())
    print('Num of unique poi: %s' % df['point_of_interest'].nunique())
    # Despite its label, this figure is the mean group size (ids per POI).
    print('Mean num of unique poi: %s' % ids_per_poi.mean())

In [11]:
def cos_sim(v1, v2):
    """Cosine similarity of two 1-D vectors (nan/inf if either vector has zero norm)."""
    norm_product = np.linalg.norm(v1) * np.linalg.norm(v2)
    return np.dot(v1, v2) / norm_product

# Data Loading

In [12]:
with timer("Data Loading"):
    csv_name = "train.csv" if CFG.train else "test.csv"
    original_df = pd.read_csv(DATA_DIR + csv_name)
    if not CFG.train:
        # test.csv has no labels; a constant placeholder keeps the downstream code uniform.
        original_df["point_of_interest"] = "match"
display(original_df)

[Data Loading] start
[Data Loading] done in 0 s


Unnamed: 0,id,name,latitude,longitude,address,city,state,zip,country,url,phone,categories,point_of_interest
0,E_00001118ad0191,Jamu Petani Bagan Serai,5.012169,100.535805,,,,,MY,,,Cafés,match
1,E_000020eb6fed40,Johnny's Bar,40.434209,-80.56416,497 N 12th St,Weirton,WV,26062.0,US,,,Bars,match
2,E_00002f98667edf,QIWI,47.215134,39.686088,"Межевая улица, 60",Ростов-на-Дону,,,RU,https://qiwi.com,78003010000.0,ATMs,match
3,E_001b6bad66eb98,"Gelora Sriwijaya, Jaka Baring Sport City",-3.014675,104.794374,,,,,ID,,,Stadiums,match
4,E_0283d9f61e569d,Stadion Gelora Sriwijaya,-3.021727,104.788628,Jalan Gubernur Hasan Bastari,Palembang,South Sumatra,11480.0,ID,,,Soccer Stadiums,match


In [13]:
# Ground-truth lookups used by get_score (built before label encoding, so they hold raw POI values).
poi2ids = get_poi2ids(original_df)
id2poi = get_id2poi(original_df)

In [14]:
# Integer class ids for the metric-learning head.
le = LabelEncoder()
original_df['point_of_interest'] = le.fit_transform(original_df['point_of_interest']).astype("int32")

n_classes = original_df["point_of_interest"].nunique()
print(f"n_classes: {n_classes}")

n_classes: 1


# Preprocess

The `reverse_geocode` library is used to look up the city for each `latitude`/`longitude` pair.

Adding this city name to the text input lets the BERT model take location into account through the text as well; the raw coordinates are also fed to the model alongside the text embedding.

In [15]:
def get_geo_info(coords):
    """Reverse-geocode ``coords`` and return ``(country_codes, cities)`` as two parallel lists.

    ``coords`` is passed unchanged to ``reverse_geocode.search``; below it is the
    latitude/longitude frame, still in degrees at this point in the notebook.
    """
    places = reverse_geocode.search(coords)
    country_codes = [place['country_code'] for place in places]
    cities = [place['city'] for place in places]
    return country_codes, cities

# Keep only the city name for each record.
original_df['city2'] = get_geo_info(original_df[['latitude', 'longitude']])[1]

In [16]:
# Encoder input: name, reverse-geocoded city, address and categories joined by single spaces
# (missing fields become empty strings). A "[SEP]"-joined variant was also tried.
text_parts = [original_df[column].fillna("") for column in ("name", "city2", "address", "categories")]
original_df["text"] = text_parts[0] + " " + text_parts[1] + " " + text_parts[2] + " " + text_parts[3]

In [17]:
# ==============================================
# Degree to radian
# ==============================================

# NOTE: not idempotent; re-running this cell converts again. city2 was derived from the degree
# values above, so this must run after reverse geocoding.
for coord_col in ("latitude", "longitude"):
    original_df[coord_col] = original_df[coord_col] * np.pi / 180

# Dataset

In [18]:
# ====================================================
#  Dataset
# ====================================================

class FoursquareDataset(Dataset):
    """Tokenized text plus coordinates (and optionally the POI label) for each row of ``df``.

    Expects columns ``text``, ``latitude``, ``longitude`` and, only when ``include_labels`` is
    True, ``point_of_interest``. All texts are tokenized once up front with the tokenizer of
    ``CFG.model``, padded/truncated to ``CFG.max_length``.

    Items are ``(input_ids, attention_mask, lat, lon[, label])`` tensors.
    """

    def __init__(self, df, include_labels=True):
        tokenizer = AutoTokenizer.from_pretrained(CFG.model)

        self.df = df
        self.include_labels = include_labels

        self.text = df['text'].tolist()
        self.lat = df['latitude'].values
        self.lon = df['longitude'].values
        # Only read the label column when labels are requested, so unlabeled frames work too.
        self.labels = df['point_of_interest'].values if include_labels else None

        self.encoded = tokenizer.batch_encode_plus(
            self.text,
            padding = 'max_length',
            max_length = CFG.max_length,
            truncation = True,
            return_attention_mask=True
        )

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        input_ids = torch.tensor(self.encoded['input_ids'][idx], dtype=torch.long)
        attention_mask = torch.tensor(self.encoded['attention_mask'][idx], dtype=torch.long)
        lat = torch.tensor(self.lat[idx], dtype=torch.float)
        lon = torch.tensor(self.lon[idx], dtype=torch.float)

        if self.include_labels:
            label = torch.tensor(self.labels[idx], dtype=torch.long)
            return input_ids, attention_mask, lat, lon, label

        return input_ids, attention_mask, lat, lon

# Metric Learning

In [19]:
# ====================================================
#  CurricularFace
# ====================================================   

def l2_norm(input, axis = 1):
    """Return ``input`` divided by its L2 norm along ``axis`` (norm kept as a size-1 dim to broadcast).

    Slices with zero norm produce nan/inf.
    """
    return torch.div(input, torch.norm(input, 2, axis, True))

class CurricularFace(nn.Module):
    """CurricularFace margin head that turns embeddings into scaled class logits.

    Args:
        in_features: embedding size.
        out_features: number of classes (columns of ``kernel``).
        s: logit scale.
        m: additive angular margin for the target class.

    ``forward(embbedings, label)`` takes a 2-D ``(batch, in_features)`` matrix and integer class
    ids of length ``batch``, and returns ``(batch, out_features)`` logits:
      * target class: ``cos(theta_y + m)`` when ``cos(theta_y) > cos(pi - m)``, otherwise
        ``cos(theta_y) - sin(pi - m) * m``;
      * "hard" classes whose cosine exceeds ``cos(theta_y + m)`` become ``cos * (t + cos)``, where
        ``t`` is a running average of the batch-mean target cosine (updated on every call);
      * everything is multiplied by ``s``.
    """

    def __init__(self, in_features, out_features, s = 5, m = 0.050):
        super(CurricularFace, self).__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.m = m
        self.s = s
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.threshold = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m
        self.kernel = nn.Parameter(torch.Tensor(in_features, out_features))
        self.register_buffer('t', torch.zeros(1))
        nn.init.normal_(self.kernel, std=0.01)

    def forward(self, embbedings, label):
        embbedings = l2_norm(embbedings, axis = 1)
        kernel_norm = l2_norm(self.kernel, axis = 0)
        cos_theta = torch.mm(embbedings, kernel_norm).clamp(-1, 1)
        rows = torch.arange(0, embbedings.size(0), device=cos_theta.device)
        target_logit = cos_theta[rows, label].view(-1, 1)

        # d/dx sqrt(x) is infinite at x == 0 (target cosine exactly +-1), which turns the whole
        # gradient into NaN. Clamping the radicand keeps it finite (clamp passes zero gradient
        # below its floor) and changes the forward value only negligibly.
        sin_theta = torch.sqrt((1.0 - torch.pow(target_logit, 2)).clamp(min=1e-7))
        cos_theta_m = target_logit * self.cos_m - sin_theta * self.sin_m  # cos(theta_y + m)
        mask = cos_theta > cos_theta_m
        final_target_logit = torch.where(target_logit > self.threshold, cos_theta_m, target_logit - self.mm)

        hard_example = cos_theta[mask]
        with torch.no_grad():
            # Exponential moving average (momentum 0.01) of the mean target cosine.
            self.t = target_logit.mean() * 0.01 + (1 - 0.01) * self.t
        cos_theta[mask] = hard_example * (self.t + hard_example)
        cos_theta.scatter_(1, label.view(-1, 1).long(), final_target_logit)
        output = cos_theta * self.s
        return output

In [20]:
# ====================================================
#  ArcFace
# ====================================================

class ArcMarginProduct(nn.Module):
    """ArcFace (additive angular margin) head.

    ``cos(theta)`` is the cosine between the L2-normalised input rows and the L2-normalised rows of
    ``weight`` (shape ``(out_features, in_features)``). The target class gets ``cos(theta + m)``
    and all other classes keep ``cos(theta)``; logits are multiplied by ``s``.

    With ``easy_margin`` the margin applies only where ``cos(theta) > 0``. Otherwise, when
    ``cos(theta) <= cos(pi - m)``, the target falls back to ``cos(theta) - sin(pi - m) * m``.
    ``ls_eps > 0`` label-smooths the one-hot target mask.
    """

    def __init__(self, in_features, out_features, s=10.0, m=0.050, easy_margin=True, ls_eps=0.0):
        super(ArcMarginProduct, self).__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.ls_eps = ls_eps
        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, input, label):
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        # Rounding can push |cosine| slightly above 1, making 1 - cos^2 negative (NaN sine). The
        # small floor also keeps sqrt's gradient finite when |cosine| == 1.
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(min=1e-7, max=1.0))
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # Allocate the mask on the logits' own device/dtype instead of relying on a global `device`.
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        if self.ls_eps > 0:
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.out_features
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s

        return output

# Loss

In [21]:
# ====================================================
#  Focal Loss
# ====================================================    

class FocalLoss(nn.Module):
    """Binary focal-style loss on raw logits.

    The modulating factor is applied to the *batch-mean* BCE,
    ``alpha * (1 - exp(-BCE)) ** gamma * BCE``, not per element as in the original focal-loss
    formulation. ``weight``, ``size_average`` and ``smooth`` are accepted for signature
    compatibility but unused. ``targets`` must be a float tensor of 0/1 values.
    """

    def __init__(self, weight=None, size_average=True):
        super(FocalLoss, self).__init__()

    def forward(self, inputs, targets, alpha=0.8, gamma=2, smooth=1):
        inputs = inputs.view(-1)
        targets = targets.view(-1)

        # Fused sigmoid + BCE is numerically stable. Applying sigmoid first saturates to exactly
        # 0/1 for large logits, and binary_cross_entropy then clips the log term at -100.
        BCE = F.binary_cross_entropy_with_logits(inputs, targets, reduction='mean')
        BCE_EXP = torch.exp(-BCE)
        focal_loss = alpha * (1 - BCE_EXP) ** gamma * BCE

        return focal_loss

# Model

In [22]:
# ====================================================
#  Model
# ==================================================== 

class CustomModel(nn.Module):
    """Text encoder + coordinates -> ``embedding_size``-d embedding, trained with a CurricularFace head.

    ``extract`` returns the embedding used for kNN matching. ``forward`` additionally applies
    the margin head and returns class logits for the loss.

    Args:
        model_name: HF model id or local path for the text encoder.
        embedding_size: size of the output embedding.
        num_classes: number of columns of the CurricularFace head. It must equal the value used
            at training time for a saved checkpoint to load.
    """

    def __init__(self, model_name, embedding_size=128, num_classes=739972):
        super(CustomModel, self).__init__()

        self.config = AutoConfig.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name,
                                               config=self.config)
        # An ArcMarginProduct head was the alternative tried here.
        # The class count is a constructor argument rather than the global `n_classes`, which is
        # 1 at inference (test rows share one placeholder label).
        # NOTE(review): 739972 is presumably the number of POIs in train.csv — confirm.
        self.fc = CurricularFace(embedding_size, num_classes)
        self.head = nn.Sequential(
            nn.Linear(self.config.hidden_size + 2, embedding_size),  # +2: latitude, longitude
            nn.BatchNorm1d(embedding_size),
        )

    def forward(self, ids, mask, lat, lon, labels):
        embedding = self.extract(ids=ids, mask=mask, lat=lat, lon=lon)
        output = self.fc(embedding, labels)
        return output

    def extract(self, ids, mask, lat, lon):
        lat, lon = lat.view(-1, 1), lon.view(-1, 1)
        out = self.model(input_ids=ids, attention_mask=mask)
        embedding = out[0][:, 0, :]  # first-token ("CLS") vector of the last hidden state
        embedding = torch.cat([embedding, lat, lon], dim=1)
        embedding = self.head(embedding)
        return embedding

print(CustomModel(CFG.model))

CustomModel(
  (model): XLMRobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(250002, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0): RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((76

# Helper functions

In [23]:
def get_accuracy(preds, targets):
    """Top-1 accuracy as a 0-d float tensor: fraction of rows whose arg-max over dim 1 equals the target."""
    return preds.argmax(dim=1).eq(targets).float().mean()

In [24]:
def train_fn(train_loader, model, criterion, optimizer, epoch, scheduler, device):
    """Train ``model`` for one epoch and return the sample-weighted mean training loss.

    Per-iteration schedulers are stepped after every optimizer step. ``ReduceLROnPlateau`` needs a
    metric, so it is left for the caller to step once per epoch. Progress (loss, top-1 accuracy
    of the margin logits, LR) is logged every 1000 steps and on the last step.
    """
    start = time.time()
    losses = AverageMeter()
    scores = AverageMeter()
    per_batch_scheduler = not isinstance(scheduler, ReduceLROnPlateau)

    model.train()

    for step, (input_ids, attention_mask, lat, lon, label) in enumerate(train_loader):
        input_ids = input_ids.to(device, dtype=torch.long)
        attention_mask = attention_mask.to(device, dtype=torch.long)
        lat = lat.to(device)
        lon = lon.to(device)
        label = label.to(device, dtype=torch.long)

        batch_size = label.size(0)

        output = model(input_ids, attention_mask, lat, lon, label)
        loss = criterion(output, label)

        # record loss
        losses.update(loss.item(), batch_size)
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        # score
        score = get_accuracy(output.detach(), label)
        scores.update(score.item(), batch_size)

        # step: `get_last_lr()` is the supported accessor (`get_lr()` is deprecated outside step()).
        if per_batch_scheduler:
            scheduler.step()
            lr = scheduler.get_last_lr()[0]
        else:
            lr = optimizer.param_groups[0]['lr']

        if step % 1000 == 0 or step == (len(train_loader) - 1):
            LOGGER.info(
                f"Epoch: [{epoch + 1}][{step}/{len(train_loader)}] "
                f"Elapsed {timeSince(start, float(step + 1) / len(train_loader)):s} "
                f"Loss: {losses.avg:.6f} "
                f"Score: {scores.avg:.6f} "
                f"LR: {lr:.8f} "
            )

    return losses.avg

In [25]:
# ====================================================
# helper function
# ====================================================

class AverageMeter(object):
    """Tracks the most recent value and a count-weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, running mean, running sum and count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the running mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

def asMinutes(s):
    """Format ``s`` seconds as ``"<m>m <s>s"`` (minutes floored, seconds truncated by ``%d``)."""
    whole_minutes = math.floor(s / 60)
    leftover_seconds = s - whole_minutes * 60
    return "%dm %ds" % (whole_minutes, leftover_seconds)


def timeSince(since, percent):
    """Elapsed time since ``since`` and the estimated time remaining.

    ``percent`` is the completed fraction (0 < percent <= 1); the total is extrapolated as
    elapsed / percent.
    """
    elapsed = time.time() - since
    remaining = elapsed / percent - elapsed
    return "%s (remain %s)" % (asMinutes(elapsed), asMinutes(remaining))

In [26]:
def run(train):
    """Train the stage-1 metric-learning model on ``train`` and publish the checkpoints.

    Writes ``bert_epoch{N}.pth`` to ``OUTPUT_DIR`` after each epoch, then uploads ``OUTPUT_DIR``
    as the Kaggle dataset ``shkanda/foursquare-stage1-exp{CFG.exp}`` using the ``kaggle`` CLI.

    Args:
        train: DataFrame accepted by ``FoursquareDataset`` (integer ``point_of_interest`` labels).
    """
    import subprocess

    LOGGER.info(f"==============================================")
    LOGGER.info(f"▶︎ Start Training")
    LOGGER.info(f"==============================================")

    # ====================================================
    #  Data Loader
    # ====================================================

    train_dataset = FoursquareDataset(train)

    train_loader = DataLoader(
        train_dataset,
        batch_size=CFG.batch_size,
        shuffle=True,
        num_workers=CFG.num_workers,
        pin_memory=True,
        drop_last=False,
        worker_init_fn=seed_worker,
        generator=g
    )

    # ====================================================
    #  Model
    # ====================================================
    model = CustomModel(CFG.model)
    model.to(device)

    optimizer = AdamW(model.parameters(), lr=CFG.lr)
    # train_fn steps per-iteration schedulers once per batch, so schedules are counted in batches.
    num_train_steps = len(train_loader) * CFG.epochs

    def get_scheduler(optimizer):
        """Build the learning-rate scheduler selected by ``CFG.scheduler``."""
        if CFG.scheduler == 'linear':
            # Linear warm-up over the first 5% of all steps (the HF helper expects an int).
            return get_linear_schedule_with_warmup(
                optimizer, num_warmup_steps=int(num_train_steps * 0.05), num_training_steps=num_train_steps
            )
        if CFG.scheduler == 'cosine':
            return get_cosine_schedule_with_warmup(
                optimizer, num_warmup_steps=CFG.num_warmup_steps, num_training_steps=num_train_steps, num_cycles=CFG.num_cycles
            )
        if CFG.scheduler == 'ReduceLROnPlateau':
            return ReduceLROnPlateau(optimizer, mode='min', factor=CFG.factor, patience=CFG.patience, verbose=True, eps=CFG.eps)
        if CFG.scheduler == 'CosineAnnealingLR':
            return CosineAnnealingLR(optimizer, T_max=CFG.T_max, eta_min=CFG.min_lr, last_epoch=-1)
        if CFG.scheduler == 'CosineAnnealingWarmRestarts':
            return CosineAnnealingWarmRestarts(optimizer, T_0=CFG.T_0, T_mult=CFG.T_mult, eta_min=CFG.min_lr, last_epoch=-1)
        if CFG.scheduler == 'CosineAnnealingWarmupRestarts':
            # NOTE(review): this class is not defined or imported anywhere in this notebook.
            return CosineAnnealingWarmupRestarts(optimizer, first_cycle_steps=CFG.first_cycle_steps, cycle_mult=CFG.cycle_mult, max_lr=CFG.lr, min_lr=CFG.min_lr, gamma=CFG.gamma, last_epoch=-1)
        raise ValueError(f"Unknown CFG.scheduler: {CFG.scheduler!r}")

    scheduler = get_scheduler(optimizer)

    criterion = nn.CrossEntropyLoss()

    # ====================================================
    #  Loop
    # ====================================================

    # Initialised once, before the loop, so the comparison spans epochs.
    best_loss = np.inf

    for epoch in range(CFG.epochs):
        start_time = time.time()

        # train
        avg_loss = train_fn(train_loader, model, criterion, optimizer, epoch, scheduler, device)

        # ReduceLROnPlateau needs a metric: step it once per epoch with the mean training loss.
        if isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step(avg_loss)

        elapsed = time.time() - start_time
        LOGGER.info(
            f"Epoch {epoch+1} - avg_train_loss: {avg_loss:.6f} time: {elapsed:.0f}s"
        )

        # Save model
        torch.save(
            model.state_dict(), OUTPUT_DIR + f"bert_epoch{epoch+1}.pth"
        )

        # ====================================================
        #  Judge if the best score or not
        # ====================================================
        if avg_loss < best_loss:
            best_loss = avg_loss
            LOGGER.info(f"Epoch {epoch+1} - Best Loss Model")

    # ==============================================
    # Create Kaggle Dataset
    # ==============================================

    def _kaggle_cli(*args):
        """Run the ``kaggle`` CLI and echo its output; a missing CLI is reported, not raised."""
        try:
            result = subprocess.run(["kaggle", *args], capture_output=True, text=True)
        except FileNotFoundError:
            LOGGER.info("kaggle CLI not found; skipped: kaggle " + " ".join(args))
            return
        print(result.stdout, result.stderr, sep="")

    _kaggle_cli("datasets", "init", "-p", OUTPUT_DIR)

    metadata = {"id": f"shkanda/foursquare-stage1-exp{CFG.exp}",
                "title": f"foursquare-stage1-exp{CFG.exp}",
                "licenses": [{"name": "CC0-1.0"}]}

    with open(OUTPUT_DIR+'dataset-metadata.json', 'w') as fp:
        json.dump(metadata, fp)

    _kaggle_cli("datasets", "create", "-p", OUTPUT_DIR)

# Train

In [27]:
# Training runs only when CFG.train is True; otherwise the inference cells below are used.
if CFG.train:
    run(train=original_df)

# Inference

In [28]:
def inference_fn(test, model_path=None):
    """Embed every row of ``test`` with the trained stage-1 model.

    Args:
        test: DataFrame with ``text``, ``latitude`` and ``longitude`` (labels are not used).
        model_path: checkpoint to load. Defaults to the last epoch saved by ``run``:
            ``MODEL_DIR + f"bert_epoch{CFG.epochs}.pth"``.

    Returns:
        NumPy array with one embedding row per row of ``test``, in the same order (no shuffling).
    """
    test_dataset = FoursquareDataset(test, include_labels=False)

    test_loader = DataLoader(
        test_dataset,
        batch_size=CFG.batch_size,
        shuffle=False,
        num_workers=CFG.num_workers,
        pin_memory=True,
        drop_last=False,
        worker_init_fn=seed_worker,
        generator=g
    )

    if model_path is None:
        # `run` saves one checkpoint per epoch; use the final one.
        model_path = MODEL_DIR + f"bert_epoch{CFG.epochs}.pth"

    model = CustomModel(CFG.model)
    # torch.load unpickles the file: only load trusted checkpoints.
    state = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(state)
    model.to(device)
    model.eval()

    preds = []
    with torch.no_grad():
        for input_ids, attention_mask, lat, lon in tqdm(test_loader, total=len(test_loader)):
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            lat = lat.to(device)
            lon = lon.to(device)

            pred = model.extract(input_ids, attention_mask, lat, lon)
            preds.append(pred.cpu().numpy())

    return np.concatenate(preds)

In [29]:
def predict(embedding, input_df=None, n_neighbors=None, threshold=None):
    """Find candidate matches by cosine kNN on ``embedding``, searched within each country.

    Args:
        embedding: array whose i-th row embeds the row at position i of ``input_df``.
        input_df: frame with ``id`` and ``country``; defaults to the global ``original_df``.
            Its index labels are used as row positions, so it must have a default RangeIndex.
        n_neighbors: neighbours per record (capped at the group size); defaults to ``CFG.n_neighbors``.
        threshold: a pair is kept only if its cosine distance is below this; defaults to
            ``CFG.threshold``.

    Returns:
        DataFrame with columns ``id``, ``match_id`` and ``cos_dist``, grouped by country and, within
        a country, ordered by neighbour rank (all rank-0 pairs first, then rank-1, ...).
    """
    if input_df is None:
        input_df = original_df
    if n_neighbors is None:
        n_neighbors = CFG.n_neighbors
    if threshold is None:
        threshold = CFG.threshold

    frames = []

    # dropna=False: records with a missing country form their own group instead of being
    # silently skipped (otherwise they could never be matched to anything).
    for _country, country_df in tqdm(input_df.groupby('country', dropna=False)):
        country_embedding = embedding[country_df.index]
        ids = country_df['id'].to_numpy()
        neighbors = min(len(ids), n_neighbors)

        knn = NearestNeighbors(n_neighbors=neighbors,
                               metric='cosine',
                               algorithm='brute',
                               n_jobs=-1)
        knn.fit(country_embedding)
        dists, nears = knn.kneighbors(country_embedding, return_distance=True)

        # Flatten rank-major so pairs keep the order of a per-rank loop.
        query_pos = np.tile(np.arange(len(ids)), neighbors)
        match_pos = nears.T.ravel()
        flat_dists = dists.T.ravel()
        keep = flat_dists < threshold

        frames.append(pd.DataFrame({
            'id': ids[query_pos[keep]],
            'match_id': ids[match_pos[keep]],
            'cos_dist': flat_dists[keep],
        }))

    return pd.concat(frames).reset_index(drop=True)

In [30]:
def post_process(df):
    """Make matches symmetric: when id A lists B, append A to B's matches if it is missing.

    ``df`` has ``id`` and space-separated ``matches`` columns. Its ``matches`` column is
    overwritten and the same frame is returned. Rows whose original matches hold exactly one id
    are not used as sources.
    """
    partners = dict(zip(df['id'].values, df['matches'].str.split()))

    for base_id, match_str in zip(df['id'], df['matches']):
        candidate_ids = match_str.split()
        if len(candidate_ids) != 1:
            for other_id in candidate_ids:
                other_matches = partners[other_id]
                if base_id not in other_matches:
                    other_matches.append(base_id)

    df['matches'] = df['id'].map(partners).map(' '.join)
    return df

In [31]:
if not CFG.train:

    embedding = inference_fn(original_df)
    df = predict(embedding)

    # post process1: guarantee every id lists itself as a match
    self_pairs = original_df[["id"]].assign(match_id=original_df["id"])
    df = (
        pd.concat([df, self_pairs])
        .drop_duplicates(["id", "match_id"])
        .reset_index(drop=True)
    )

    sub = (
        df.groupby(["id"])["match_id"].apply(list).reset_index()
        .rename(columns={"match_id": "matches"})
    )
    sub["matches"] = sub["matches"].map(" ".join)

    # post process2: symmetrise (if A lists B, B also lists A)
    sub = post_process(sub)

    sub.to_csv("submission.csv", index=False)
    display(sub)

100%|██████████| 1/1 [00:01<00:00,  1.20s/it]
100%|██████████| 4/4 [00:00<00:00,  9.30it/s]


Unnamed: 0,id,matches
0,E_00001118ad0191,E_00001118ad0191
1,E_000020eb6fed40,E_000020eb6fed40
2,E_00002f98667edf,E_00002f98667edf
3,E_001b6bad66eb98,E_001b6bad66eb98 E_0283d9f61e569d
4,E_0283d9f61e569d,E_0283d9f61e569d E_001b6bad66eb98
