In [1]:
import numpy as np
from tqdm import tqdm

from pytorch_metric_learning import samplers, distances
from pytorch_metric_learning.utils.inference import InferenceModel
from pytorch_metric_learning.utils.inference import MatchFinder

from data_processing.productset import ProductSet, ProductTestSet

from data_processing.utils import read_data
from models.encoder import Encoder
from models.lit import LitSiamese
from models.embedder import Embedder
from utils import seed_everything
from config import config

In [2]:
def f1_score(y_true, y_pred, zero_division=np.nan):
    """Per-row F1 between two aligned collections of space-separated id lists.

    Each element of ``y_true`` / ``y_pred`` is a string such as ``"id1 id2"``.
    It is split on whitespace and compared as a set, so order and duplicates
    are ignored. For row i the score is ``2 * |T_i & P_i| / (|T_i| + |P_i|)``.

    Args:
        y_true: iterable (e.g. a pandas Series) of ground-truth id strings.
        y_pred: iterable of predicted id strings, aligned positionally with
            ``y_true``.
        zero_division: value used for rows where both sets are empty (0/0).
            Defaults to NaN, the value plain division produced, but without
            emitting a RuntimeWarning.

    Returns:
        1-D float ``np.ndarray`` of per-row F1 scores. Take ``.mean()`` for the
        overall score.

    Raises:
        ValueError: if ``y_true`` and ``y_pred`` have different lengths.
    """
    true_sets = [set(ids.split()) for ids in y_true]
    pred_sets = [set(ids.split()) for ids in y_pred]
    if len(true_sets) != len(pred_sets):
        # Without this check, zip() truncates and numpy may broadcast a length-1
        # side, silently producing scores for the wrong rows.
        raise ValueError(
            f"y_true and y_pred must have the same length, "
            f"got {len(true_sets)} and {len(pred_sets)}"
        )

    intersection = np.array(
        [len(t & p) for t, p in zip(true_sets, pred_sets)], dtype=float
    )
    denominator = np.array(
        [len(t) + len(p) for t, p in zip(true_sets, pred_sets)], dtype=float
    )

    f1 = np.full(len(denominator), zero_division, dtype=float)
    # Only divide where the denominator is non-zero; other rows keep zero_division.
    np.divide(2 * intersection, denominator, out=f1, where=denominator > 0)
    return f1

In [3]:
# Unpack the run settings from the shared config dict in one pass. Of these,
# the cells shown here use SEED, EMBEDDER and CKPT_PATH; the rest are read but
# not referenced below.
(
    SEED,
    MODEL,
    EMBEDDER,
    WANDB_ARGS,
    CONTINUE_FROM_CKPT,
    CKPT_PATH,
    EPOCHS,
    LOG_EVERY_N_STEP,
    VAL_CHECK_INTERVAL,
) = (
    config[key]
    for key in (
        "seed",
        "model",
        "embedder",
        "wandb",
        "continue_from_ckpt",
        "ckpt_path",
        "epochs",
        "log_every_n_step",
        "val_check_interval",
    )
)

# Seed before any model construction so the run is reproducible.
seed_everything(SEED)

Global seed set to 42


In [4]:
# The embedder is needed later: get_predictions() passes it to read_data().
embedder = Embedder(EMBEDDER)

# The model used for inference is the one restored from the checkpoint.
# Building a fresh Encoder + LitSiamese first was dead work, because that
# instance was immediately overwritten by this assignment.
model = LitSiamese.load_from_checkpoint(CKPT_PATH)
# Switch layers such as dropout/batch-norm to inference behaviour.
model.eval();

  rank_zero_warn(


In [5]:
# Distance cutoff handed to MatchFinder. For a distance metric, pairs whose
# distance falls under this threshold count as matches.
MATCH_THRESHOLD = 0.2

# Plain (non-squared) L2 distance between L2-normalised embeddings.
l2_distance = distances.LpDistance(normalize_embeddings=True, p=2, power=1)
match_finder = MatchFinder(distance=l2_distance, threshold=MATCH_THRESHOLD)

# Wrap the trained model for kNN search and pairwise match checks; data is kept on CPU.
im = InferenceModel(model, data_device="cpu", match_finder=match_finder)

In [6]:
def get_predictions(data_path, test_data=False, k=50):
    """Predict, for every posting in a CSV, which postings it matches.

    The rows are loaded with ``read_data`` (using the global ``embedder``) and
    wrapped in a ``ProductTestSet``. A kNN index is then trained on that
    dataset with the global InferenceModel ``im``. For each row, its ``k``
    nearest neighbours are filtered with ``im.is_match``. The surviving
    neighbours' ``posting_id`` values are space-joined into the
    ``predictions`` column.

    Args:
        data_path: path passed to ``read_data``.
        test_data: when False, also build a ground-truth ``matches`` column.
            For each row, it lists every ``posting_id`` sharing that row's
            ``label_group``, space-joined, so the result can be scored with
            ``f1_score``.
        k: number of nearest neighbours to consider per posting. It is clamped
            to the dataset size.

    Returns:
        The frame returned by ``read_data`` with a ``predictions`` column added
        (and ``matches`` when ``test_data`` is False).
    """
    df = read_data(data_path, embedder, is_train=False)

    if not test_data:
        # Ground truth: all postings that share this row's label_group.
        group_to_ids = df.groupby(['label_group'])['posting_id'].unique().to_dict()
        df['matches'] = df['label_group'].map(group_to_ids).apply(lambda ids: ' '.join(ids))

    dataset = ProductTestSet(df)
    im.train_knn(dataset)

    # Positional id lookup. Like the original df.iloc lookup, this assumes
    # dataset[i] corresponds to df row i — TODO confirm against ProductTestSet.
    posting_ids = df["posting_id"].to_numpy()
    # Requesting more neighbours than the index holds makes the default
    # faiss-backed kNN pad results with -1. Used as an index, -1 would silently
    # alias the last row, so k is clamped here and negative indices are skipped.
    n_neighbours = min(k, len(dataset))

    match_predictions = []
    for embedding, _label in tqdm(dataset):
        query = embedding.unsqueeze(0)  # add a batch dimension once per row
        # The neighbour distances are not needed: is_match applies the
        # match_finder's threshold itself.
        _, indices = im.get_nearest_neighbors(query, k=n_neighbours)

        matches = []
        for idx in indices[0].tolist():
            if idx < 0:
                continue
            candidate = dataset[idx][0].unsqueeze(0)
            if im.is_match(query, candidate):
                matches.append(posting_ids[idx])
        match_predictions.append(" ".join(matches))

    df["predictions"] = match_predictions
    return df

In [7]:
# Train split: test_data defaults to False, so ground-truth `matches` are built for scoring.
df = get_predictions(data_path="train_data.csv")

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20952/20952 [10:09<00:00, 34.39it/s]


In [8]:
train_f1 = f1_score(df['matches'], df['predictions']).mean()  # mean of per-posting F1
print("F1 = ", train_f1)

F1 =  0.553785071810349


In [9]:
# Persist ground-truth matches next to the predictions for the train split.
train_columns = ["posting_id", "title", "matches", "predictions"]
df.loc[:, train_columns].to_csv("train_predictions.csv", index=False)

In [10]:
# Test split: test_data=True skips building the ground-truth `matches` column.
df = get_predictions(data_path="test_data.csv", test_data=True)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13494/13494 [05:53<00:00, 38.15it/s]


In [11]:
submission_columns = ["posting_id", "title", "predictions"]
df.loc[:, submission_columns].to_csv("submission.csv", index=False)