In [1]:
# pip install torch transformers pandas sklearn

In [3]:
from tqdm.auto import tqdm
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer, get_linear_schedule_with_warmup, logging as tr_logging
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, roc_auc_score

tr_logging.set_verbosity_error()

I propose solving this task using a BERT-based classifier because it's a SOTA solution. To conserve resources, let's select the distilled BERT version. I want to retrain the classifier on our data to make it more adapted to our task.

We need to write a Dataset class and train/evaluate loops. I propose using 10% of the data to speed up training.

In [4]:
# Load and preprocess the dataset
path_to_dataset = 'data.csv'
raw_data = pd.read_csv(path_to_dataset)

In [5]:
# Let's transform the data to lowercase and delete the leading and trailing space symbols.
raw_data['entity_1'] = raw_data['entity_1'].str.lower().str.strip()
raw_data['entity_2'] = raw_data['entity_2'].str.lower().str.strip()
print(len(raw_data))
raw_data.head(2)

7042846


Unnamed: 0.1,Unnamed: 0,entity_1,entity_2,tag
0,3137667,preciform a.b,preciform ab,1
1,5515816,degener staplertechnik vertriebs-gmbh,irshim,0


In [6]:
cleaned_data = raw_data[['entity_1', 'entity_2', 'tag']].drop_duplicates()
print(len(cleaned_data), len(cleaned_data[['entity_1', 'entity_2']].drop_duplicates()))

6759092 6759090


In [7]:
# Unfortunately, we have encountered two cases where the tags are different for a single pair of entities.
# Let's detect these cases and remove them from the dataset.
conflict_tags = cleaned_data.groupby(['entity_1', 'entity_2']).count().query('tag != 1')
conflict_tags

Unnamed: 0_level_0,Unnamed: 1_level_0,tag
entity_1,entity_2,Unnamed: 2_level_1
changzhou plastic development,changzhou plastic development factory,2
dutro ford lincoln-mercury inc,dutro ford lincoln-mercury inc,2


In [8]:
cleaned_data.merge(conflict_tags.reset_index().drop(columns='tag'), on=['entity_1', 'entity_2'])

Unnamed: 0,entity_1,entity_2,tag
0,changzhou plastic development,changzhou plastic development factory,1
1,changzhou plastic development,changzhou plastic development factory,0
2,dutro ford lincoln-mercury inc,dutro ford lincoln-mercury inc,0
3,dutro ford lincoln-mercury inc,dutro ford lincoln-mercury inc,1


In [9]:
data = cleaned_data[~((cleaned_data.entity_1 == "changzhou plastic development")
& (cleaned_data.entity_2 == "changzhou plastic development factory")
| (cleaned_data.entity_1 == "dutro ford lincoln-mercury inc")
& (cleaned_data.entity_2 == "dutro ford lincoln-mercury inc"))]

In [11]:
data = data.sample(frac=0.1)  # using only 10% of data

train_data, test_data = train_test_split(data, test_size=0.2, random_state=42)

# Tokenizer - using DistilBert for faster training
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

class EntityDataset(Dataset):
    def __init__(self, entities1, entities2, labels):
        self.entities1 = entities1
        self.entities2 = entities2
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        entity_1 = self.entities1[idx]
        entity_2 = self.entities2[idx]
        label = self.labels[idx]
        inputs = tokenizer(entity_1, entity_2, padding='max_length', truncation=True, max_length=64, return_tensors="pt")
        input_ids = inputs['input_ids'].squeeze(0)
        attention_mask = inputs['attention_mask'].squeeze(0)
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': torch.tensor(label, dtype=torch.long)
        }

# DataLoaders
train_dataset = EntityDataset(train_data['entity_1'].tolist(), train_data['entity_2'].tolist(), train_data['tag'].tolist())
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

test_dataset = EntityDataset(test_data['entity_1'].tolist(), test_data['entity_2'].tolist(), test_data['tag'].tolist())
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Model - using DistilBert
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=2)
model.to(device)

# Optimizer and Scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
total_steps = len(train_loader) * 2  # Fewer epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)

# Enable mixed precision training
scaler = torch.cuda.amp.GradScaler()
device

device(type='cuda')

In [12]:
# Training Loop
def train_model(epoch):
    model.train()
    total_loss = 0
    for batch in tqdm(train_loader):
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        with torch.cuda.amp.autocast():
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item()
    avg_loss = total_loss / len(train_loader)
    print(f'Epoch: {epoch+1}, Loss: {avg_loss:.4f}')

# Train the model
for epoch in range(2):
    train_model(epoch)

# Save the model
model.save_pretrained('path_to_save_model')

  0%|          | 0/8449 [00:00<?, ?it/s]

Epoch: 1, Loss: 0.0062


  0%|          | 0/8449 [00:00<?, ?it/s]

Epoch: 2, Loss: 0.0029


In [14]:
# Evaluation
def evaluate_model():
    model.eval()
    predictions, true_labels = [], []
    with torch.no_grad():
        for batch in tqdm(test_loader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            logits = outputs.logits
            logits = logits.detach().cpu().numpy()
            label_ids = labels.to('cpu').numpy()
            predictions.extend(np.argmax(logits, axis=1).flatten())
            true_labels.extend(label_ids.flatten())

    return true_labels, predictions

# Evaluate the model
true_labels, predictions = evaluate_model()
diff = [idx for idx, (pred_label, true_label) in enumerate(zip(predictions, true_labels)) if pred_label != true_label]
print(classification_report(true_labels, predictions, digits=4))

  0%|          | 0/2113 [00:00<?, ?it/s]

              precision    recall  f1-score   support

           0     0.9996    0.9998    0.9997     83302
           1     0.9997    0.9993    0.9995     51880

    accuracy                         0.9996    135182
   macro avg     0.9996    0.9996    0.9996    135182
weighted avg     0.9996    0.9996    0.9996    135182



In [15]:
print(f"There're wrong predicted tags for {len(diff)} of {len(test_data)} items; it's {round(len(diff) / len(true_labels) * 100, 3)}%")
print(f"AUC ROC = {round(roc_auc_score(true_labels, predictions), 5)}")

There're wrong predicted tags for 50 of 135182 items; it's 0.037%
AUC ROC = 0.99957


Wow, these results look great!

In [16]:
# Function to predict the similarity of a pair of entities
def predict_similarity(entity_1, entity_2):
    model.eval()
    # Tokenize the pair of entities
    inputs = tokenizer(entity_1, entity_2, return_tensors="pt", padding='max_length', truncation=True, max_length=64)

    # Move tensors to the same device as the model
    input_ids = inputs['input_ids'].to(model.device)
    attention_mask = inputs['attention_mask'].to(model.device)

    # Get model predictions
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits

    # Convert logits to probabilities
    probabilities = torch.softmax(logits, dim=1)
    prediction = torch.argmax(probabilities, dim=1).item()
    return prediction, probabilities[0][prediction].item()

# Example usage
entity_1 = "Example Company A"
entity_2 = "Example Co. A"
prediction, probability = predict_similarity(entity_1, entity_2)
print(f"Prediction: {prediction}, Probability: {probability:.4f}")

Prediction: 1, Probability: 1.0000


In [17]:
%timeit prediction, probability = predict_similarity(entity_1, entity_2)

7.03 ms ± 731 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


Prediction is fast (about 7±0.7 ms). However, it may be necessary to make predictions for batches. In that case, it's easy to rewrite the function.

In [19]:
# Discover model's mistakes
test_diff = test_data.iloc[diff].reset_index().drop(columns=['index'])
test_diff[['predicted', 'prob']] = pd.DataFrame(
    test_diff.apply(lambda x: predict_similarity(x.entity_1, x.entity_2),
                    axis=1).tolist(), index=test_diff.index)
test_diff

Unnamed: 0,entity_1,entity_2,tag,predicted,prob
0,ano coil corp,a.n.o il rp,1,0,0.917542
1,ford rat les,ford sarat sales inc,1,0,0.886619
2,nmb sales kk,n.m.b sociedade anônimales,1,0,0.992891
3,condor blanco mines ltd,w.p colorado ltd,0,1,0.999724
4,m.o.r o.i.l,nmk ooo,0,1,0.547707
5,el sammak contacting co,e.l s.a.mmak contacting,1,0,0.568056
6,d co.urtney co.nstruction,d courtney construction inc,1,0,0.92686
7,cal consult gmbh,c.a.l nsult gmbh,1,0,0.96253
8,m.i.d sociedad por accionesce city,mid space city restuarant equipment,1,0,0.946231
9,intercontinental cargo enterprises,deccan enterprises ltd,0,1,0.999986


We see that the model can make mistakes in some difficult cases.

In some cases, it appears that the tag is incorrect. For example, the pair 'teledata technology solutions inc.' and 'teledata technology solutions' has a tag of 0, but it should be 1.

Having only a 0.04% rate of wrong classification is definitely excellent! It seems like the proposed solution is one of the best in terms of quality metrics. Maybe we would need faster inference, but considering my knowledge about the company, our 7±0.7 ms is sufficient.

We can also consider trying another model based on a more modern solution than classical BERT, or retraining the model on all the data.