In [1]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import collections

import random
import re

import pickle
import time
import os
import shutil

from tqdm import tqdm_notebook as tqdm

In [2]:
train = pd.read_csv("train.csv").replace(np.nan, '', regex=True)
test = pd.read_csv("test.csv").replace(np.nan, '', regex=True)

In [3]:
train, valid = train_test_split(train, test_size = 0.25, random_state = 13)

## Word2Vec

In [4]:
from gensim.test.utils import common_texts, get_tmpfile
from gensim.models import Word2Vec

from nltk.corpus import stopwords
import nltk
nltk.download('stopwords')

[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\USER\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


True

In [5]:
def create_words(frame):
    stop_words = stopwords.words('english')
    
    data = []
    for index, row in frame.iterrows():
        try:
            view1 = re.sub(r'[^\w\s]','', row['question1']).lower().split()
            view2 = re.sub(r'[^\w\s]','', row['question2']).lower().split()
        except:
            print(row['question1'])
            print(row['question2'])
            break
        
        view1 = [word for word in view1 if word not in stop_words]
        view2 = [word for word in view2 if word not in stop_words]
        data.append(view1)
        data.append(view2)
        
    return data

In [6]:
common_texts = create_words(train)

In [33]:
path = get_tmpfile("word2vec_small.model")

model1 = Word2Vec(common_texts, size=500, window=3, min_count=1, workers=0, iter=5000)
model1.save("word2vec_small.model")

#model1 = Word2Vec.load("word2vec_small.model")

In [8]:
def avg_sentence_vector(words, model, num_features):
    featureVec = np.zeros((1,num_features), dtype="float32")
    count = 0
    for word in words:
        if word in model.wv:
            count += 1
            featureVec = np.append(featureVec, model.wv[word].reshape(1, -1), axis = 0)

    if count > 0:
        featureVec = np.mean(featureVec[1:], axis = 0)
    else:
        return np.zeros((num_features), dtype="float32")
    return featureVec

In [34]:
def embedding(sentence):
    embedding_size = 500
    
    stop_words = stopwords.words('english')
    #model = Word2Vec.load("word2vec.model")
    
    view = re.sub(r'[^\w\s]','', sentence).lower().split()
    view = [word for word in view if word not in stop_words]
    ret = avg_sentence_vector(view, model1, embedding_size)
    
    return ret

In [35]:
def sentence_cosin_similarity(sentence_1, sentence_2):
    stop_words = stopwords.words('english')
    
    view1 = re.sub(r'[^\w\s]','', sentence_1).lower().split()
    view2 = re.sub(r'[^\w\s]','', sentence_2).lower().split()
    
    view1 = [word for word in view1 if (word not in stop_words and word in model.wv)]
    view2 = [word for word in view2 if (word not in stop_words and word in model.wv)]

    #empty sentence problem, solved with el classique (:
    if len(view1) == 0:
        view1 = ['cat']
    if len(view2) == 0:
        view2 = ['dog']
    return  model.wv.n_similarity(view1, view2)

## Создание датасета из эмбеддингов

In [17]:
train_features = np.ones(shape=(len(train),1000), dtype=np.float32)
it = 0
for index, row in tqdm(train.iterrows(), total=train.shape[0]):
    train_features[it] = np.concatenate([np.array(embedding(row['question1'])),
                                          np.array(embedding(row['question2']))])
    it += 1
y_train = np.array(train['is_duplicate']) 

HBox(children=(IntProgress(value=0, max=303217), HTML(value='')))




In [None]:
valid_features = np.ones(shape=(len(valid),1000), dtype=np.float32)
it = 0
for index, row in tqdm(valid.iterrows(), total=valid.shape[0]):
    valid_features[it] = np.concatenate([np.array(embedding(row['question1'])),
                                          np.array(embedding(row['question2']))])
    it += 1
y_valid = np.array(train['is_duplicate']) 

## Модель

In [21]:
import lightgbm
from lightgbm import LGBMClassifier
from sklearn.model_selection import KFold, StratifiedKFold

In [22]:
train_data = lightgbm.Dataset(train_features, label=y_train)
test_data = lightgbm.Dataset(valid_features, label=y_valid)


parameters = {
    'application': 'binary',
    'objective': 'binary',
    'metric': 'binary_logloss',
    'is_unbalance': 'true',
    'boosting': 'gbdt',
    'num_leaves': 31,
    'feature_fraction': 0.5,
    'bagging_fraction': 0.5,
    'bagging_freq': 20,
    'learning_rate': 0.05,
    'verbose': 0
}

model = lightgbm.train(parameters,
                       train_data,
                       valid_sets=test_data,
                       num_boost_round=5000,
                       early_stopping_rounds=100)




[1]	valid_0's binary_logloss: 0.655791
Training until validation scores don't improve for 100 rounds.
[2]	valid_0's binary_logloss: 0.653533
[3]	valid_0's binary_logloss: 0.651712
[4]	valid_0's binary_logloss: 0.650025
[5]	valid_0's binary_logloss: 0.648702
[6]	valid_0's binary_logloss: 0.647782
[7]	valid_0's binary_logloss: 0.646273
[8]	valid_0's binary_logloss: 0.644962
[9]	valid_0's binary_logloss: 0.643613
[10]	valid_0's binary_logloss: 0.643208
[11]	valid_0's binary_logloss: 0.642223
[12]	valid_0's binary_logloss: 0.641666
[13]	valid_0's binary_logloss: 0.640828
[14]	valid_0's binary_logloss: 0.640024
[15]	valid_0's binary_logloss: 0.639124
[16]	valid_0's binary_logloss: 0.638674
[17]	valid_0's binary_logloss: 0.637816
[18]	valid_0's binary_logloss: 0.637314
[19]	valid_0's binary_logloss: 0.636746
[20]	valid_0's binary_logloss: 0.635943
[21]	valid_0's binary_logloss: 0.635294
[22]	valid_0's binary_logloss: 0.63446
[23]	valid_0's binary_logloss: 0.633977
[24]	valid_0's binary_loglo

[402]	valid_0's binary_logloss: 0.513794
[403]	valid_0's binary_logloss: 0.51371
[404]	valid_0's binary_logloss: 0.513615
[405]	valid_0's binary_logloss: 0.513507
[406]	valid_0's binary_logloss: 0.513425
[407]	valid_0's binary_logloss: 0.513349
[408]	valid_0's binary_logloss: 0.513272
[409]	valid_0's binary_logloss: 0.513183
[410]	valid_0's binary_logloss: 0.513106
[411]	valid_0's binary_logloss: 0.513002
[412]	valid_0's binary_logloss: 0.512933
[413]	valid_0's binary_logloss: 0.512821
[414]	valid_0's binary_logloss: 0.512706
[415]	valid_0's binary_logloss: 0.512643
[416]	valid_0's binary_logloss: 0.512535
[417]	valid_0's binary_logloss: 0.512417
[418]	valid_0's binary_logloss: 0.512305
[419]	valid_0's binary_logloss: 0.512214
[420]	valid_0's binary_logloss: 0.512105
[421]	valid_0's binary_logloss: 0.512009
[422]	valid_0's binary_logloss: 0.511887
[423]	valid_0's binary_logloss: 0.511777
[424]	valid_0's binary_logloss: 0.511694
[425]	valid_0's binary_logloss: 0.511576
[426]	valid_0's b

[802]	valid_0's binary_logloss: 0.486885
[803]	valid_0's binary_logloss: 0.486826
[804]	valid_0's binary_logloss: 0.486785
[805]	valid_0's binary_logloss: 0.486743
[806]	valid_0's binary_logloss: 0.486713
[807]	valid_0's binary_logloss: 0.486652
[808]	valid_0's binary_logloss: 0.486607
[809]	valid_0's binary_logloss: 0.486545
[810]	valid_0's binary_logloss: 0.486501
[811]	valid_0's binary_logloss: 0.486477
[812]	valid_0's binary_logloss: 0.486446
[813]	valid_0's binary_logloss: 0.486397
[814]	valid_0's binary_logloss: 0.486343
[815]	valid_0's binary_logloss: 0.486284
[816]	valid_0's binary_logloss: 0.486234
[817]	valid_0's binary_logloss: 0.486179
[818]	valid_0's binary_logloss: 0.486127
[819]	valid_0's binary_logloss: 0.486071
[820]	valid_0's binary_logloss: 0.485997
[821]	valid_0's binary_logloss: 0.48594
[822]	valid_0's binary_logloss: 0.485881
[823]	valid_0's binary_logloss: 0.485819
[824]	valid_0's binary_logloss: 0.485763
[825]	valid_0's binary_logloss: 0.485735
[826]	valid_0's b

[1197]	valid_0's binary_logloss: 0.47106
[1198]	valid_0's binary_logloss: 0.471003
[1199]	valid_0's binary_logloss: 0.470953
[1200]	valid_0's binary_logloss: 0.470926
[1201]	valid_0's binary_logloss: 0.470875
[1202]	valid_0's binary_logloss: 0.47083
[1203]	valid_0's binary_logloss: 0.470827
[1204]	valid_0's binary_logloss: 0.470801
[1205]	valid_0's binary_logloss: 0.470773
[1206]	valid_0's binary_logloss: 0.470734
[1207]	valid_0's binary_logloss: 0.470693
[1208]	valid_0's binary_logloss: 0.470672
[1209]	valid_0's binary_logloss: 0.470652
[1210]	valid_0's binary_logloss: 0.470613
[1211]	valid_0's binary_logloss: 0.470602
[1212]	valid_0's binary_logloss: 0.470578
[1213]	valid_0's binary_logloss: 0.470548
[1214]	valid_0's binary_logloss: 0.47052
[1215]	valid_0's binary_logloss: 0.470491
[1216]	valid_0's binary_logloss: 0.470476
[1217]	valid_0's binary_logloss: 0.470426
[1218]	valid_0's binary_logloss: 0.470395
[1219]	valid_0's binary_logloss: 0.470383
[1220]	valid_0's binary_logloss: 0.47

[1587]	valid_0's binary_logloss: 0.460057
[1588]	valid_0's binary_logloss: 0.460024
[1589]	valid_0's binary_logloss: 0.460009
[1590]	valid_0's binary_logloss: 0.459982
[1591]	valid_0's binary_logloss: 0.459951
[1592]	valid_0's binary_logloss: 0.45991
[1593]	valid_0's binary_logloss: 0.459874
[1594]	valid_0's binary_logloss: 0.45984
[1595]	valid_0's binary_logloss: 0.459791
[1596]	valid_0's binary_logloss: 0.459769
[1597]	valid_0's binary_logloss: 0.459755
[1598]	valid_0's binary_logloss: 0.459746
[1599]	valid_0's binary_logloss: 0.4597
[1600]	valid_0's binary_logloss: 0.459667
[1601]	valid_0's binary_logloss: 0.459652
[1602]	valid_0's binary_logloss: 0.459633
[1603]	valid_0's binary_logloss: 0.459616
[1604]	valid_0's binary_logloss: 0.459598
[1605]	valid_0's binary_logloss: 0.459586
[1606]	valid_0's binary_logloss: 0.459589
[1607]	valid_0's binary_logloss: 0.459588
[1608]	valid_0's binary_logloss: 0.459565
[1609]	valid_0's binary_logloss: 0.459539
[1610]	valid_0's binary_logloss: 0.459

[1977]	valid_0's binary_logloss: 0.451242
[1978]	valid_0's binary_logloss: 0.451224
[1979]	valid_0's binary_logloss: 0.451188
[1980]	valid_0's binary_logloss: 0.451165
[1981]	valid_0's binary_logloss: 0.451135
[1982]	valid_0's binary_logloss: 0.451099
[1983]	valid_0's binary_logloss: 0.451062
[1984]	valid_0's binary_logloss: 0.45103
[1985]	valid_0's binary_logloss: 0.450996
[1986]	valid_0's binary_logloss: 0.450973
[1987]	valid_0's binary_logloss: 0.450954
[1988]	valid_0's binary_logloss: 0.450933
[1989]	valid_0's binary_logloss: 0.450895
[1990]	valid_0's binary_logloss: 0.450871
[1991]	valid_0's binary_logloss: 0.450819
[1992]	valid_0's binary_logloss: 0.450772
[1993]	valid_0's binary_logloss: 0.450757
[1994]	valid_0's binary_logloss: 0.450731
[1995]	valid_0's binary_logloss: 0.450706
[1996]	valid_0's binary_logloss: 0.450674
[1997]	valid_0's binary_logloss: 0.450649
[1998]	valid_0's binary_logloss: 0.450618
[1999]	valid_0's binary_logloss: 0.450582
[2000]	valid_0's binary_logloss: 0.

[2367]	valid_0's binary_logloss: 0.443658
[2368]	valid_0's binary_logloss: 0.443628
[2369]	valid_0's binary_logloss: 0.443602
[2370]	valid_0's binary_logloss: 0.44358
[2371]	valid_0's binary_logloss: 0.443564
[2372]	valid_0's binary_logloss: 0.443563
[2373]	valid_0's binary_logloss: 0.443554
[2374]	valid_0's binary_logloss: 0.443531
[2375]	valid_0's binary_logloss: 0.443491
[2376]	valid_0's binary_logloss: 0.443477
[2377]	valid_0's binary_logloss: 0.443458
[2378]	valid_0's binary_logloss: 0.443418
[2379]	valid_0's binary_logloss: 0.443407
[2380]	valid_0's binary_logloss: 0.443391
[2381]	valid_0's binary_logloss: 0.443366
[2382]	valid_0's binary_logloss: 0.443348
[2383]	valid_0's binary_logloss: 0.443338
[2384]	valid_0's binary_logloss: 0.443313
[2385]	valid_0's binary_logloss: 0.443298
[2386]	valid_0's binary_logloss: 0.443273
[2387]	valid_0's binary_logloss: 0.443259
[2388]	valid_0's binary_logloss: 0.44324
[2389]	valid_0's binary_logloss: 0.443215
[2390]	valid_0's binary_logloss: 0.4

[2757]	valid_0's binary_logloss: 0.437107
[2758]	valid_0's binary_logloss: 0.437107
[2759]	valid_0's binary_logloss: 0.437079
[2760]	valid_0's binary_logloss: 0.437062
[2761]	valid_0's binary_logloss: 0.437061
[2762]	valid_0's binary_logloss: 0.437065
[2763]	valid_0's binary_logloss: 0.43706
[2764]	valid_0's binary_logloss: 0.437066
[2765]	valid_0's binary_logloss: 0.437055
[2766]	valid_0's binary_logloss: 0.43705
[2767]	valid_0's binary_logloss: 0.43704
[2768]	valid_0's binary_logloss: 0.437023
[2769]	valid_0's binary_logloss: 0.437009
[2770]	valid_0's binary_logloss: 0.436999
[2771]	valid_0's binary_logloss: 0.436977
[2772]	valid_0's binary_logloss: 0.436955
[2773]	valid_0's binary_logloss: 0.436947
[2774]	valid_0's binary_logloss: 0.436942
[2775]	valid_0's binary_logloss: 0.436926
[2776]	valid_0's binary_logloss: 0.4369
[2777]	valid_0's binary_logloss: 0.436887
[2778]	valid_0's binary_logloss: 0.436865
[2779]	valid_0's binary_logloss: 0.436859
[2780]	valid_0's binary_logloss: 0.4368

[3147]	valid_0's binary_logloss: 0.431373
[3148]	valid_0's binary_logloss: 0.431357
[3149]	valid_0's binary_logloss: 0.431351
[3150]	valid_0's binary_logloss: 0.431338
[3151]	valid_0's binary_logloss: 0.431333
[3152]	valid_0's binary_logloss: 0.431349
[3153]	valid_0's binary_logloss: 0.431347
[3154]	valid_0's binary_logloss: 0.431338
[3155]	valid_0's binary_logloss: 0.431325
[3156]	valid_0's binary_logloss: 0.431311
[3157]	valid_0's binary_logloss: 0.431296
[3158]	valid_0's binary_logloss: 0.43129
[3159]	valid_0's binary_logloss: 0.431283
[3160]	valid_0's binary_logloss: 0.431275
[3161]	valid_0's binary_logloss: 0.431246
[3162]	valid_0's binary_logloss: 0.431212
[3163]	valid_0's binary_logloss: 0.431172
[3164]	valid_0's binary_logloss: 0.431145
[3165]	valid_0's binary_logloss: 0.431113
[3166]	valid_0's binary_logloss: 0.431084
[3167]	valid_0's binary_logloss: 0.43107
[3168]	valid_0's binary_logloss: 0.431039
[3169]	valid_0's binary_logloss: 0.430993
[3170]	valid_0's binary_logloss: 0.4

[3537]	valid_0's binary_logloss: 0.426399
[3538]	valid_0's binary_logloss: 0.426407
[3539]	valid_0's binary_logloss: 0.426411
[3540]	valid_0's binary_logloss: 0.426421
[3541]	valid_0's binary_logloss: 0.426406
[3542]	valid_0's binary_logloss: 0.426382
[3543]	valid_0's binary_logloss: 0.426376
[3544]	valid_0's binary_logloss: 0.426356
[3545]	valid_0's binary_logloss: 0.426334
[3546]	valid_0's binary_logloss: 0.426324
[3547]	valid_0's binary_logloss: 0.426306
[3548]	valid_0's binary_logloss: 0.426284
[3549]	valid_0's binary_logloss: 0.426258
[3550]	valid_0's binary_logloss: 0.426244
[3551]	valid_0's binary_logloss: 0.426244
[3552]	valid_0's binary_logloss: 0.426227
[3553]	valid_0's binary_logloss: 0.426212
[3554]	valid_0's binary_logloss: 0.426205
[3555]	valid_0's binary_logloss: 0.426188
[3556]	valid_0's binary_logloss: 0.426169
[3557]	valid_0's binary_logloss: 0.426152
[3558]	valid_0's binary_logloss: 0.42614
[3559]	valid_0's binary_logloss: 0.426126
[3560]	valid_0's binary_logloss: 0.

[3927]	valid_0's binary_logloss: 0.421731
[3928]	valid_0's binary_logloss: 0.42172
[3929]	valid_0's binary_logloss: 0.421707
[3930]	valid_0's binary_logloss: 0.421707
[3931]	valid_0's binary_logloss: 0.421704
[3932]	valid_0's binary_logloss: 0.421688
[3933]	valid_0's binary_logloss: 0.421666
[3934]	valid_0's binary_logloss: 0.42164
[3935]	valid_0's binary_logloss: 0.421628
[3936]	valid_0's binary_logloss: 0.421623
[3937]	valid_0's binary_logloss: 0.421609
[3938]	valid_0's binary_logloss: 0.421597
[3939]	valid_0's binary_logloss: 0.421577
[3940]	valid_0's binary_logloss: 0.421577
[3941]	valid_0's binary_logloss: 0.421583
[3942]	valid_0's binary_logloss: 0.421583
[3943]	valid_0's binary_logloss: 0.421582
[3944]	valid_0's binary_logloss: 0.42158
[3945]	valid_0's binary_logloss: 0.421571
[3946]	valid_0's binary_logloss: 0.421564
[3947]	valid_0's binary_logloss: 0.421573
[3948]	valid_0's binary_logloss: 0.42155
[3949]	valid_0's binary_logloss: 0.421544
[3950]	valid_0's binary_logloss: 0.421

[4317]	valid_0's binary_logloss: 0.417498
[4318]	valid_0's binary_logloss: 0.417474
[4319]	valid_0's binary_logloss: 0.417456
[4320]	valid_0's binary_logloss: 0.417452
[4321]	valid_0's binary_logloss: 0.417448
[4322]	valid_0's binary_logloss: 0.417442
[4323]	valid_0's binary_logloss: 0.417436
[4324]	valid_0's binary_logloss: 0.417433
[4325]	valid_0's binary_logloss: 0.417452
[4326]	valid_0's binary_logloss: 0.417443
[4327]	valid_0's binary_logloss: 0.417433
[4328]	valid_0's binary_logloss: 0.417434
[4329]	valid_0's binary_logloss: 0.41744
[4330]	valid_0's binary_logloss: 0.417431
[4331]	valid_0's binary_logloss: 0.41743
[4332]	valid_0's binary_logloss: 0.417407
[4333]	valid_0's binary_logloss: 0.417401
[4334]	valid_0's binary_logloss: 0.417383
[4335]	valid_0's binary_logloss: 0.417372
[4336]	valid_0's binary_logloss: 0.417372
[4337]	valid_0's binary_logloss: 0.417368
[4338]	valid_0's binary_logloss: 0.417352
[4339]	valid_0's binary_logloss: 0.417356
[4340]	valid_0's binary_logloss: 0.4

[4707]	valid_0's binary_logloss: 0.413783
[4708]	valid_0's binary_logloss: 0.413782
[4709]	valid_0's binary_logloss: 0.413771
[4710]	valid_0's binary_logloss: 0.41376
[4711]	valid_0's binary_logloss: 0.413736
[4712]	valid_0's binary_logloss: 0.413723
[4713]	valid_0's binary_logloss: 0.413715
[4714]	valid_0's binary_logloss: 0.41371
[4715]	valid_0's binary_logloss: 0.413704
[4716]	valid_0's binary_logloss: 0.413672
[4717]	valid_0's binary_logloss: 0.413665
[4718]	valid_0's binary_logloss: 0.413655
[4719]	valid_0's binary_logloss: 0.413644
[4720]	valid_0's binary_logloss: 0.413621
[4721]	valid_0's binary_logloss: 0.413637
[4722]	valid_0's binary_logloss: 0.41365
[4723]	valid_0's binary_logloss: 0.413663
[4724]	valid_0's binary_logloss: 0.413662
[4725]	valid_0's binary_logloss: 0.413671
[4726]	valid_0's binary_logloss: 0.41368
[4727]	valid_0's binary_logloss: 0.413686
[4728]	valid_0's binary_logloss: 0.413681
[4729]	valid_0's binary_logloss: 0.413694
[4730]	valid_0's binary_logloss: 0.413

## Тестовый датасет и выгрузка в csv

In [None]:
test_features = np.ones(shape=(len(test), 1000), dtype=np.float32)
it = 0
for index, row in tqdm(test.iterrows(), total=test.shape[0]):
    test_features[it] = np.concatenate([np.array(embedding(row['question1'])),
                                          np.array(embedding(row['question2']))])
    it += 1
    


HBox(children=(IntProgress(value=0, max=2345796), HTML(value='')))

In [None]:
y = model.predict(test_features)

In [None]:
df = pd.DataFrame(np.stack([np.array(test['test_id']), y]).T, columns=['test_id','is_duplicate'])
df['test_id'] = df['test_id'].astype('int64')
df.to_csv('submision.csv', index=False)