In [33]:
from sentence_transformers import SentenceTransformer, util
import os
import pandas as pd
import numpy as np


# Pretrained sentence-embedding checkpoint used to encode every text column below.
MODEL_NAME = 'all-MiniLM-L6-v2'
model = SentenceTransformer(MODEL_NAME)


# Load data.
# DATA_DIR selects which copy of the CSVs to read: "files" for the full run, or
# "small_files" (presumably a reduced copy of the same tables) for quick iteration.
DATA_DIR = "files"

tb1_path = os.path.join(DATA_DIR, "amazon.csv")
tb2_path = os.path.join(DATA_DIR, "best_buy.csv")

df1 = pd.read_csv(tb1_path)
df2 = pd.read_csv(tb2_path)

# Later cells read these three columns from both tables; fail early if one is missing.
REQUIRED_COLUMNS = {'Name', 'Brand', 'Features'}
assert REQUIRED_COLUMNS <= set(df1.columns), f"{tb1_path} missing {REQUIRED_COLUMNS - set(df1.columns)}"
assert REQUIRED_COLUMNS <= set(df2.columns), f"{tb2_path} missing {REQUIRED_COLUMNS - set(df2.columns)}"


# Extract column values and encode them with `model`.
# NOTE(review): astype(str) turns missing values into the literal string 'nan', so
# every row with a missing attribute is encoded from the same text and would score
# as a perfect match against every other such row — consider filling/dropping them.
def _column_as_text(df, column):
    """Return ``df[column]`` converted to strings, as a list in row order."""
    return list(df[column].astype(str))


def _encode(texts):
    """Encode a list of strings with ``model``; returns a tensor with one row per string."""
    return model.encode(texts, convert_to_tensor=True)


table1_names = _column_as_text(df1, 'Name')
table2_names = _column_as_text(df2, 'Name')

table1_brand = _column_as_text(df1, 'Brand')
table2_brand = _column_as_text(df2, 'Brand')

table1_features = _column_as_text(df1, 'Features')
table2_features = _column_as_text(df2, 'Features')

# Row i of each emb* tensor corresponds to positional row i of its DataFrame.
emb1_names = _encode(table1_names)
emb2_names = _encode(table2_names)

emb1_feats = _encode(table1_features)
emb2_feats = _encode(table2_features)

emb1_brand = _encode(table1_brand)
emb2_brand = _encode(table2_brand)


In [35]:
# Sanity check: each embedding matrix has one row per DataFrame row (so corpus_id /
# query index i maps to positional row i), and all matrices share the same width.
_embedding_pairs = [
    (emb1_names, df1), (emb1_feats, df1), (emb1_brand, df1),
    (emb2_names, df2), (emb2_feats, df2), (emb2_brand, df2),
]
for _emb, _df in _embedding_pairs:
    assert _emb.shape[0] == len(_df), (_emb.shape, len(_df))
assert len({_emb.shape[1] for _emb, _ in _embedding_pairs}) == 1, "embedding widths differ"

emb1_brand.shape


torch.Size([4259, 384])

In [37]:
# For every table1 row, retrieve the top_k most similar table2 rows, searching
# separately on each attribute (name, features, brand).
top_k = 30

_attribute_pairs = (
    (emb1_names, emb2_names),
    (emb1_feats, emb2_feats),
    (emb1_brand, emb2_brand),
)
name_matches, feature_matches, brand_matches = (
    util.semantic_search(query_emb, corpus_emb, top_k=top_k)
    for query_emb, corpus_emb in _attribute_pairs
)

In [44]:
from collections import defaultdict

# How to merge the per-attribute candidate sets for each table1 row:
#   "union"        -> any table2 row in the top_k of at least one attribute (high recall)
#   "intersection" -> only table2 rows in the top_k of every attribute (high precision)
COMBINE_MODE = "union"


def _combine_candidate_ids(hit_lists, mode="union"):
    """Merge the corpus ids from several semantic_search hit lists for one query row.

    Args:
        hit_lists: iterable of hit lists; each hit is a dict with a 'corpus_id' key
            (the per-query output of ``util.semantic_search``).
        mode: "union" or "intersection".

    Returns:
        set of corpus ids, i.e. positional row indices into df2.

    Raises:
        ValueError: if ``mode`` is not "union" or "intersection".
    """
    id_sets = [{hit['corpus_id'] for hit in hits} for hits in hit_lists]
    if mode == "union":
        return set().union(*id_sets)
    if mode == "intersection":
        return set.intersection(*id_sets) if id_sets else set()
    raise ValueError(f"unknown combine mode: {mode!r}")


# semantic_search returns one hit list per query row, so all three must align with df1.
assert len(name_matches) == len(feature_matches) == len(brand_matches) == len(df1)

final_candidates = defaultdict(set)
for i, per_attribute_hits in enumerate(zip(name_matches, feature_matches, brand_matches)):
    final_candidates[i] = _combine_candidate_ids(per_attribute_hits, COMBINE_MODE)

In [46]:
# Inspect the candidate table2 row ids for one table1 row (None if the row is absent).
sample_row = 1
final_candidates.get(sample_row)

{10,
 28,
 29,
 131,
 151,
 152,
 153,
 154,
 161,
 279,
 282,
 465,
 579,
 610,
 639,
 685,
 728,
 729,
 854,
 868,
 929,
 953,
 965,
 973,
 977,
 980,
 981,
 982,
 985,
 986,
 1742,
 1762,
 1787,
 1793,
 1806,
 1807,
 1820,
 1889,
 2162,
 2201,
 2207,
 2209,
 2218,
 2221,
 2303,
 2322,
 2350,
 2411,
 2423,
 2427,
 2458,
 2460,
 2466,
 2476,
 2492,
 2525,
 2540,
 2541,
 2563,
 2585,
 2598,
 2670,
 2820,
 2826,
 2830,
 2876,
 2906,
 2935,
 2950,
 2951,
 2953,
 2956,
 2960,
 2983,
 3002,
 3010,
 3041,
 3051,
 3090,
 3602,
 3787,
 3788,
 3795,
 3836,
 3853,
 3875,
 3876,
 4872}