In [10]:
import pandas as pd
from sentence_transformers import SentenceTransformer, LoggingHandler, losses
from torch.utils.data import DataLoader, RandomSampler
from torch.optim import AdamW
from tqdm import tqdm

import sys
sys.path.append('../')

In [2]:
from src.utils.data_constractor import CompanyDatasetSentBert

### Learning

In [4]:
PATH_DATA = '../data/preprocess_train.csv'
BATCH_SIZE = 32
MODEL_INIT = 'all-MiniLM-L6-v2'
DEVICE = 'cuda:1'
EPOCHS = 3
LR = 2e-5

In [4]:
train_dataset = CompanyDatasetSentBert(PATH_DATA)
val_dataset = CompanyDatasetSentBert(PATH_DATA, train=False)
len(train_dataset), len(val_dataset)

100%|████████████████████████████████████████████████████████████████████████| 472928/472928 [00:08<00:00, 53674.44it/s]
100%|██████████████████████████████████████████████████████████████████████████| 24891/24891 [00:00<00:00, 64114.59it/s]


(472928, 24891)

In [7]:
train_dataset.samples[0], val_dataset.samples[0]

(<sentence_transformers.readers.InputExample.InputExample at 0x7f4acd44fd60>,
 <sentence_transformers.readers.InputExample.InputExample at 0x7f4ac021efa0>)

In [33]:
trainDataLoader = DataLoader(
    train_dataset,
    sampler=RandomSampler(train_dataset),
    batch_size=32
)

In [34]:
model = SentenceTransformer('all-MiniLM-L6-v2', device=DEVICE)
model.max_seq_length = 60
trainLoss = losses.ContrastiveLoss(model=model)
model

SentenceTransformer(
  (0): Transformer({'max_seq_length': 60, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
  (2): Normalize()
)

In [35]:
warmup_steps = int(len(trainDataLoader) * EPOCHS * 0.1)

model.fit(
    train_objectives=[(trainDataLoader, train_loss)], 
    epochs=EPOCHS,
    warmup_steps=warmup_steps,
    output_path='../weights/sbert',
    optimizer_params={'lr': LR},
    save_best_model=True,
    show_progress_bar=True
)

Epoch:   0%|          | 0/3 [00:00<?, ?it/s]

Iteration:   0%|          | 0/14779 [00:00<?, ?it/s]

Iteration:   0%|          | 0/14779 [00:00<?, ?it/s]

Iteration:   0%|          | 0/14779 [00:00<?, ?it/s]

### Eval

In [46]:
from sentence_transformers import SentenceTransformer, util
from sklearn.metrics import f1_score, confusion_matrix, classification_report, precision_recall_curve
import numpy as np
from tqdm import tqdm

import sys
sys.path.append('../')

In [16]:
from src.utils.data_constractor import CompanyDatasetSentBert

In [17]:
DEVICE = 'cuda:2'
PATH_DATA = '../data/preprocess_train.csv'

In [18]:
model = SentenceTransformer('../weights/sbert', device=DEVICE)
model

SentenceTransformer(
  (0): Transformer({'max_seq_length': 60, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
  (2): Normalize()
)

In [19]:
val_dataset = CompanyDatasetSentBert(PATH_DATA, train=False)
val_dataset.samples

Unnamed: 0,name_1,name_2,is_duplicate
304006,divekar wallstabe &schneider precision seals p.l.,e&a scheer bv,0
271233,first world asian trading,stp ltd.,0
185141,dipanjali sharees,i b international,0
396834,qsr inc.,onguard fire protection services,0
151609,aa synthetic fibers ltd.,indian synthetic rubber private ltd.,0
...,...,...,...
186356,"h & v advanced materials india pvt., ltd.",welspun india ltd.,0
297374,focha trading l l c,corp. tadzai s.a. de c.v.,0
424530,psa corporation,shams alzohor trading,0
317898,"ncm hersbit chemical co., ltd.",lidorr chemicals ltd.,0


In [20]:
val_dataset.samples['is_duplicate'].value_counts()

0    24708
1      183
Name: is_duplicate, dtype: int64

In [26]:
cosine_score = list()
for idx in tqdm(range(len(val_dataset.samples))):
    curr_list_cmp = val_dataset.samples[['name_1', 'name_2']].iloc[idx].tolist()
    embeddings = model.encode(curr_list_cmp, convert_to_tensor=True, device=DEVICE)
    cosine_score.append(util.cos_sim(embeddings[0], embeddings[1])[0][0].item())

val_dataset.samples['cosine_score'] = cosine_score

In [28]:
val_dataset.samples.head()

Unnamed: 0,name_1,name_2,is_duplicate,cosine_score
304006,divekar wallstabe &schneider precision seals p.l.,e&a scheer bv,0,0.162562
271233,first world asian trading,stp ltd.,0,0.101912
185141,dipanjali sharees,i b international,0,0.063827
396834,qsr inc.,onguard fire protection services,0,0.139751
151609,aa synthetic fibers ltd.,indian synthetic rubber private ltd.,0,0.547664


In [29]:
label_zero = val_dataset.samples[val_dataset.samples['is_duplicate'] == 0]['cosine_score']
label_one = val_dataset.samples[val_dataset.samples['is_duplicate'] == 1]['cosine_score']

In [30]:
label_one.describe()

count    183.000000
mean       0.634709
std        0.176863
min        0.186636
25%        0.523047
50%        0.642733
75%        0.744277
max        1.000000
Name: cosine_score, dtype: float64

In [31]:
label_zero.describe()

count    24708.000000
mean         0.338060
std          0.170048
min         -0.093053
25%          0.203308
50%          0.328723
75%          0.459828
max          0.977624
Name: cosine_score, dtype: float64

In [34]:
val_dataset.samples[
    (val_dataset.samples['cosine_score'] > 0.9) & (val_dataset.samples['is_duplicate'] == 0)
]

Unnamed: 0,name_1,name_2,is_duplicate,cosine_score
65404,"qingdao xinye international logistics co., ltd.",qingdao yida world logistics ltd.,0,0.90259
389708,arp materials inc.200,arp materials inc.,0,0.951161
219607,al estagamah flexiflo factory,al estagmah flexiflo factory,0,0.977624
477899,ana de mexico sa de cv,s i de mexico sa cv,0,0.911499


In [35]:
val_dataset.samples[
    (val_dataset.samples['cosine_score'] > 0.8) & (val_dataset.samples['is_duplicate'] == 0)
]

Unnamed: 0,name_1,name_2,is_duplicate,cosine_score
85000,"jianda rubber (tianjin) co., ltd.","donghai rubber (tianjin) co., ltd.",0,0.857441
196906,htp de mexico sa de cv,ksb de mexico sa de cv,0,0.827980
422657,"dongguan hongfu shoes products co., ltd.","dongguan dingli shoes co., ltd.",0,0.843981
54954,pelzer de mexico s.a. de c.v.,sovere de mexico s.a. de c.v.,0,0.810501
347526,"qingdao han yu international logistics co., ltd.","qingdao chunming logistics co., ltd.",0,0.832978
...,...,...,...,...
181219,"qingdao qianhetong logistics co., ltd.","shenzhen taijihengtong logistics co., ltd.",0,0.803177
14166,"dongguan yixin shoes material co., ltd.","dongguan xinglai shoes co., ltd.",0,0.889610
330733,"zhongshan taiyang imp. & exp. . co., ltd.","tiantai huang teng imp. & exp. . co., ltd.",0,0.810003
315054,crown yun precision plastic hardware (shenzhen...,"yun lee precision plastic (shenzhen) co., ltd.",0,0.884402


In [36]:
val_dataset.samples[
    (val_dataset.samples['cosine_score'] < 0.5) & (val_dataset.samples['is_duplicate'] == 1)
]

Unnamed: 0,name_1,name_2,is_duplicate,cosine_score
433081,bridgestone neumaticos de monterreysa de cv,bridgestone de costa rica s.a.,1,0.46348
139877,sodeks ag,sodex ag,1,0.454534
65782,total rumunia,total romania s.a.,1,0.477875
355132,jsr bst elastomer,ltd.,1,0.186636
246590,"flex?computing?(suzhou) ?co.,? ltd.",flex sol,1,0.482999
441499,bridgestone firestone de argentina saic,pt bridgestone tire indonesia,1,0.408333
253255,soprema polska sp z o o,soprema iberia,1,0.386867
21770,iko sales,iko eu /all plants*,1,0.433298
428898,bridgestone do brasil industria & comercio ltda,bridgestone de mexico s.a. de c.v.,1,0.404132
401204,bridgestone americas tire,bridgestone de costa rica sociedad anoni,1,0.386836


In [114]:
def compute_threshold(label_true, predict, thresh_precision: float = 0.99) -> None:
    precision_, recall_, thresholds_ = precision_recall_curve(label_true, predict)
    print('/***********/')
    if len(np.where(precision_ > thresh_precision)[0]) > 0:
        idx = np.where(precision_ > thresh_precision)[0][0]
        curr_predict = list()
        for pred in predict:
            if pred > thresholds_[idx]:
                curr_predict.append(1)
            else:
                curr_predict.append(0)
        f1_curr = f1_score(label_true, curr_predict, average='macro')
        print(f'F1-macro: {f1_curr}\nPrecision: {precision_[idx]}\nRecall: {recall_[idx]}\nThresholds: {thresholds_[idx]}')
    else: 
        print('Not Found!')

In [117]:
precision_value = np.linspace(0.5, 0.99, 50)
precision_value

array([0.5 , 0.51, 0.52, 0.53, 0.54, 0.55, 0.56, 0.57, 0.58, 0.59, 0.6 ,
       0.61, 0.62, 0.63, 0.64, 0.65, 0.66, 0.67, 0.68, 0.69, 0.7 , 0.71,
       0.72, 0.73, 0.74, 0.75, 0.76, 0.77, 0.78, 0.79, 0.8 , 0.81, 0.82,
       0.83, 0.84, 0.85, 0.86, 0.87, 0.88, 0.89, 0.9 , 0.91, 0.92, 0.93,
       0.94, 0.95, 0.96, 0.97, 0.98, 0.99])

In [118]:
for prec in precision_value:
    compute_threshold(val_dataset.samples['is_duplicate'], val_dataset.samples['cosine_score'], prec)

/***********/
F1-macro: 0.5772439559140294
Precision: 0.5151515151515151
Recall: 0.09289617486338798
Thresholds: 0.8702144622802734
/***********/
F1-macro: 0.5772439559140294
Precision: 0.5151515151515151
Recall: 0.09289617486338798
Thresholds: 0.8702144622802734
/***********/
F1-macro: 0.5776235647961
Precision: 0.53125
Recall: 0.09289617486338798
Thresholds: 0.8709054589271545
/***********/
F1-macro: 0.5776235647961
Precision: 0.53125
Recall: 0.09289617486338798
Thresholds: 0.8709054589271545
/***********/
F1-macro: 0.5780066426114232
Precision: 0.5483870967741935
Recall: 0.09289617486338798
Thresholds: 0.8712215423583984
/***********/
F1-macro: 0.5736560838303751
Precision: 0.5666666666666667
Recall: 0.09289617486338798
Thresholds: 0.8727751970291138
/***********/
F1-macro: 0.5736560838303751
Precision: 0.5666666666666667
Recall: 0.09289617486338798
Thresholds: 0.8727751970291138
/***********/
F1-macro: 0.5696130303973441
Precision: 0.5714285714285714
Recall: 0.08743169398907104
Thr

In [78]:
val_dataset.samples['cosine_score']

(array([0.00934818, 0.00929757, 0.00929805, ..., 1.        , 1.        ,
        1.        ]),
 array([1.        , 0.99453552, 0.99453552, ..., 0.01639344, 0.01092896,
        0.        ]),
 array([0.18663622, 0.18664634, 0.18665883, ..., 0.97762394, 0.99999994,
        1.        ]))

In [123]:
thresh = [
    0.5, 0.523047, 0.55, 0.57, 0.6, 0.62, 0.642733, 0.67, 0.7, 0.72, 0.74, 0.76, 0.78, 0.8, 0.82, 0.84, 0.86, 0.88, 0.9
]

for thr in thresh:
    curr_predict = list()
    for score in val_dataset.samples['cosine_score']:
        if score > thr:
            curr_predict.append(1)
        else:
            curr_predict.append(0)
            
    f1_curr = f1_score(val_dataset.samples['is_duplicate'], curr_predict, average='macro')
    print('/************/')
    print(f'Threshold for cosine similarity: {thr}')
    print(f'f1_macro: {f1_curr}')
    print('Classification report:')
    print(classification_report(val_dataset.samples['is_duplicate'], curr_predict))

/************/
Threshold for cosine similarity: 0.5
f1_macro: 0.4772386731468385
Classification report:
              precision    recall  f1-score   support

           0       1.00      0.82      0.90     24708
           1       0.03      0.76      0.06       183

    accuracy                           0.82     24891
   macro avg       0.51      0.79      0.48     24891
weighted avg       0.99      0.82      0.89     24891

/************/
Threshold for cosine similarity: 0.523047
f1_macro: 0.4923414540175656
Classification report:
              precision    recall  f1-score   support

           0       1.00      0.85      0.92     24708
           1       0.04      0.75      0.07       183

    accuracy                           0.85     24891
   macro avg       0.52      0.80      0.49     24891
weighted avg       0.99      0.85      0.91     24891

/************/
Threshold for cosine similarity: 0.55
f1_macro: 0.5086590048084573
Classification report:
              precision    r

In [96]:
recall[19473]

0.07650273224043716

In [97]:
thresholds[19473]

0.8933193683624268

In [82]:
precision[19489]

1.0

In [83]:
import numpy as np
np.where(precision > 0.90)

(array([19489, 19490, 19491]),)

3.8.10
