In [23]:
import numpy as np


In [24]:
from product import Product
from json import load

with open('agora_hack_products.json', encoding='utf-8') as file:
    all_products = load(file)

all_products = [Product(**p) for p in all_products]

In [25]:
references = [p for p in all_products if p.is_reference]
references_id_set = set([ref.product_id for ref in references])
products = [p for p in all_products if p.product_id not in references_id_set]

In [26]:
from sklearn.model_selection import train_test_split

# удалим часть эталонов, что бы в датасете были 'ничейные' товары
references, nulled_references = train_test_split(
    references, test_size=0.2, random_state=42
)

nulled_references_set = set([r.product_id for r in nulled_references])
for p in products:
    if p.reference_id in nulled_references_set:
        p.reference_id = None

products_train, products_test = train_test_split(
    products, test_size=0.5, random_state=42
)
all_products = products_train + references

In [27]:
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import RidgeClassifier
from spacy.lang.en import STOP_WORDS as EN_STOP_WORDS
from spacy.lang.ru import STOP_WORDS as RU_STOP_WORDS

from model import ProductMatchingModel

model = ProductMatchingModel(RidgeClassifier(alpha=1.9),
                             TfidfVectorizer(stop_words=RU_STOP_WORDS or EN_STOP_WORDS))

In [28]:
model.fit(all_products)

In [29]:
def accuracy(predicted, target):
    return list(map(lambda v: v[0] == v[1], zip(predicted, target))).count(True) / len(predicted)

In [30]:
print(products_train[1])

Referrer of None Кондиционер Jax ACIU-08HE Brisbane Inverter (bb83b13040088cc5): 
 ['Инверторный\tда', 'Фильтры\tдезодорирующий']


In [31]:
accuracy(model.predict(products_train), map(lambda p: p.reference_id, products_train))

0.9697841726618706

In [32]:
list(zip(model.predict(products_train), map(lambda p: p.reference_id, products_train)))

[('fdff5884f8e0cdaf', 'fdff5884f8e0cdaf'),
 (None, None),
 ('f497219eb0077f84', 'f497219eb0077f84'),
 ('760e54d0f7254c84', '760e54d0f7254c84'),
 ('b91a686bdeee3b59', 'b91a686bdeee3b59'),
 (None, None),
 ('cc13eaa7e880c0d7', 'cc13eaa7e880c0d7'),
 ('65ae4c1d834dd1c3', '65ae4c1d834dd1c3'),
 (None, None),
 ('767e8aac14292d41', '767e8aac14292d41'),
 ('5de19e78cb7b902f', '5de19e78cb7b902f'),
 ('81bfaff4d9a7c386', '81bfaff4d9a7c386'),
 ('25d66337f5007356', '25d66337f5007356'),
 ('d4ebc4a26700d5e0', 'd4ebc4a26700d5e0'),
 ('648c6e1504a4e7cd', '648c6e1504a4e7cd'),
 ('bcaef629354a7f34', 'bcaef629354a7f34'),
 (None, None),
 ('62e72eba83c67ee6', '62e72eba83c67ee6'),
 ('0c5ac0f8d296c43a', '0c5ac0f8d296c43a'),
 (None, None),
 ('b91a686bdeee3b59', 'b91a686bdeee3b59'),
 ('60c511bc3f614e22', '60c511bc3f614e22'),
 ('565cd2e69b1fcdca', '565cd2e69b1fcdca'),
 ('c9634b0b93d1bafc', 'c9634b0b93d1bafc'),
 (None, 'ea2228931c428185'),
 ('7c078e4143695811', '7c078e4143695811'),
 ('74d0d86e83e8cfbd', '74d0d86e83e8c