In [1]:
import numpy as np


In [9]:
from product import Product
from json import load

with open('agora_hack_products.json', encoding='utf-8') as file:
    all_products = load(file)

all_products = [Product(**p) for p in all_products]

In [10]:
references = [p for p in all_products if p.is_reference]
references_id_set = set([ref.product_id for ref in references])
products = [p for p in all_products if p.product_id not in references_id_set]

In [11]:
from sklearn.model_selection import train_test_split

# удалим часть эталонов, что бы в датасете были 'ничейные' товары
references, nulled_references = train_test_split(
    references, test_size=0.2, random_state=42
)

nulled_references_set = set([r.product_id for r in nulled_references])
for p in products:
    if p.reference_id in nulled_references_set:
        p.reference_id = None

products_train, products_test = train_test_split(
    products, test_size=0.5, random_state=42
)
all_products = products_train + references
products_test = products_test + nulled_references

In [5]:
def accuracy(predicted, target):
    return list(map(lambda v: v[0] == v[1], zip(predicted, target))).count(True) / len(predicted)

In [6]:
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from sklearn.svm import LinearSVC
from sklearn.naive_bayes import MultinomialNB
from sklearn.linear_model import RidgeClassifier, SGDClassifier
from sklearn.base import ClassifierMixin
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.base import clone

from sklearn.feature_extraction.text import TfidfVectorizer
from spacy.lang.en import STOP_WORDS as EN_STOP_WORDS
from spacy.lang.ru import STOP_WORDS as RU_STOP_WORDS

import pandas as pd
import plotly.express as px

from model import ProductMatchingModel

use_models = [

    RandomForestClassifier(random_state=0),
    LinearSVC(max_iter=4000),
    MultinomialNB(),
    LogisticRegression(random_state=0, max_iter=400),
    RidgeClassifier(normalize=True),
    SGDClassifier(warm_start=True),
    DecisionTreeClassifier(),
    #GradientBoostingClassifier() # too long
]

models = list(zip(map(lambda m: m.__class__.__name__, use_models), use_models)) # now its [(name, model), ...]

model_tests_df = []
for name, model in models:
    master_model = ProductMatchingModel(RidgeClassifier(alpha=1.9), clone(model), clone(model),
                                        vectorizer=TfidfVectorizer(stop_words=RU_STOP_WORDS or EN_STOP_WORDS))
    master_model.fit(all_products)
    train_acc, test_acc = accuracy(master_model.predict(products_train), map(lambda p: p.reference_id, products_train)), \
                          accuracy(master_model.predict(products_test), map(lambda p: p.reference_id, products_test)),
    print(name, f"train: {train_acc}, test: {test_acc}")
    model_tests_df.append([name, train_acc, test_acc])

model_tests_df = pd.DataFrame(model_tests_df, columns=['name', 'train_acc', 'test_acc'])

bar = px.bar(model_tests_df, x='name', y=['train_acc', 'test_acc'], barmode='group', log_y=True)
bar.show()

KeyError: 'fdff5884f8e0cdaf'

In [33]:
from sklearn.svm import SVC
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
from sklearn.linear_model import RidgeClassifier
from spacy.lang.en import STOP_WORDS as EN_STOP_WORDS
from spacy.lang.ru import STOP_WORDS as RU_STOP_WORDS

from model import ProductMatchingModel


model = ProductMatchingModel(RidgeClassifier(alpha=0.1), SVC(kernel='linear', C=0.1), SVC(kernel='linear', C=0.1),
                             TfidfVectorizer(stop_words=RU_STOP_WORDS or EN_STOP_WORDS))

model.fit(all_products)

print(accuracy(model.predict(all_products), map(lambda p: p.reference_id, all_products)),
accuracy(model.predict(products_test), map(lambda p: p.reference_id, products_test)))

0.6189127972819932 0.7218855218855219


In [14]:
results = list(zip(model.predict(products_test), map(lambda p: p.reference_id, products_test)))
[(a,b) for a,b in results if a != b]

[('3506a28b65ac9e2b', None),
 ('b91a686bdeee3b59', None),
 ('94e5f00c4138bbf8', '6af522dbaebbe215'),
 ('15cb3a27705c8a79', None),
 ('516c4c0cca619ea4', '563d494021391f6f'),
 ('938fd7feb4447953', None),
 ('3ef0e6e982996562', 'ec95d0b2bf54abf4'),
 ('3957c93dce19b3fe', '690e287b3c850d7c'),
 ('1253c3296a3c0153', '698c7fc4d748a8e1'),
 ('7810daae8a7e7fba', '20ba2dcb69e1e35a'),
 ('8cf1b9c715f9a13f', '6161c9b880b9622c'),
 ('6641efda5eea8e32', 'df8dbcf0d4e381e6'),
 ('7810daae8a7e7fba', '20ba2dcb69e1e35a'),
 ('f325184a5c99c599', None),
 ('6f205028762b6145', None),
 ('80abb5aed92419c1', None),
 ('0be9800a9a538231', '78a87eb0e0365083'),
 ('516c4c0cca619ea4', 'e8f7590b4c24fdbf'),
 ('b91a686bdeee3b59', None),
 ('7810daae8a7e7fba', '20ba2dcb69e1e35a'),
 ('d0c64830aa5b693c', None),
 ('715017d68a89f1b1', None),
 ('32a7b966f7ab8c72', None),
 ('d4ebc4a26700d5e0', '563d494021391f6f'),
 ('1d56acfbc4983b65', None),
 ('dcaa0047e36b83d4', '81b6691e5d2c752f'),
 ('5de19e78cb7b902f', '108edf8f1054d7d4'),
 ('ea22