In [61]:
import torch
import pandas as pd
from sentence_transformers import SentenceTransformer

"""
Get SBERT model.
It's used when embedding sentence or word before calculating distance.
"""


class SBERT:
    """Thin wrapper around a SentenceTransformer used to embed text."""

    def __init__(self):
        # Instantiate the encoder from the "all-MiniLM-L6-v2" checkpoint name.
        self.model = SentenceTransformer("all-MiniLM-L6-v2")

    def encode(self, text_list: list):
        """Embed every entry of ``text_list``.

        The call is delegated to ``self.model.encode`` and its result is
        returned unchanged.
        """
        return self.model.encode(text_list)


# Single shared encoder instance; it embeds both `y` and every `y_hat` row.
model = SBERT()

"""
Our objective is to find the closest product among the product list y.
Moreover, we want to do batch execution, which means we are not going to
use for loop to look through every row in y_hat.

step 1: embedding y
step 2: embedding y_hat and pad y_hat according to the max_length
step 3: matrix multiplication between y and y_hat
step 4: get distance_matrix using argmax
"""

# Reference catalog: the candidate product names every prediction is matched against.
# NOTE(review): "fresh peach" appears twice (indices 1 and 5) - confirm this is intended.
y = [
    "fresh banana", "fresh peach", "fresh apple",
    "dried potato", "fresh avocado", "fresh peach",
]

# |embedded_y| = (len(y), dimension)
embedded_y = torch.from_numpy(model.encode(y))

# One row per sample; each row holds a variable-length list of product names.
y_hat = pd.Series(
    [
        ["fresh apple fuji", "africa peach"],
        ["avocado"],
        ["potato green", "peach fluffy", "fresh avocado"],
    ],
    name="prod",
).to_frame()

In [62]:
# Longest product list over all rows; every row is padded up to this length.
MAX_LENGTH = y_hat["prod"].apply(len).max()

# Embed each row's product list on its own.
# |embedded_rows[i]| = (num_prod_i, dimension)
embedded_rows = [
    torch.from_numpy(model.encode(prod_lst)) for prod_lst in y_hat["prod"]
]

# Right-pad every row with zero vectors and stack the rows into one batch.
# pad_sequence replaces the hand-rolled zeros/cat/unsqueeze logic and pads to
# the longest row, which is MAX_LENGTH by construction.
_y_hat = torch.nn.utils.rnn.pad_sequence(
    embedded_rows, batch_first=True, padding_value=0.0
)
# |_y_hat| = (num_rows, MAX_LENGTH, dimension)

assert _y_hat.size(0) == len(y_hat) and _y_hat.size(1) == MAX_LENGTH

In [63]:
# Inspect the padded batch: rows shorter than MAX_LENGTH end in all-zero vectors.
print(_y_hat, _y_hat.size())

tensor([[[-0.0700, -0.0240,  0.0509,  ..., -0.0357,  0.0216,  0.0532],
         [-0.0461,  0.0610, -0.0767,  ..., -0.0712, -0.0075, -0.0423],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        [[-0.0188,  0.0214, -0.0350,  ..., -0.0140,  0.0417,  0.1026],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        [[-0.0119, -0.0041, -0.0858,  ...,  0.0007,  0.0545,  0.0318],
         [-0.0401, -0.0086,  0.0142,  ..., -0.0410,  0.0812, -0.0254],
         [-0.0338, -0.0230, -0.0058,  ...,  0.0132,  0.0123,  0.0533]]]) torch.Size([3, 3, 384])


In [64]:
print("Before transpose: ", embedded_y.size())
embedded_y.transpose_(0, 1)
print("After transpose: ", embedded_y.size())

Before transpose:  torch.Size([6, 384])
After transpose:  torch.Size([384, 6])


In [74]:
"""
SBERT outputs the vector of which norm value are all 1.
Therefore, we don't have to divide norm for calculating the cosine distance.
Only we have to do is matrix multiplication.
"""

distance_matrix = torch.matmul(_y_hat, embedded_y)
distance_matrix.argmax(dim=-1)
# y = ["fresh banana", "fresh peach", "fresh apple", "dried potato", "fresh avocado", "fresh peach"]
# result = [[2, 0], [4, 0]]
# result = [[fresh apple, fresh banana], [fresh avocado, fresh banana]]

"""
At last, we have to mask the padded position.
"""
lengths = y_hat["prod"].apply(len).tolist()
masks = []
for length in lengths:
    mask = torch.zeros(MAX_LENGTH)
    if length < MAX_LENGTH:
        mask[MAX_LENGTH - length :] = 1

    masks += [mask]
masks = torch.stack(masks, dim=0)

distance_matrix.argmax(dim=-1).float().masked_fill_(masks, -float("inf"))

tensor([[2., 1., 0.],
        [4., 0., 0.],
        [3., 1., 4.]])