In [1]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn import metrics

from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
from keras.layers import *
from keras.models import *
from keras import initializers, regularizers, constraints, optimizers, layers
from keras.initializers import *
from keras.optimizers import *
import keras.backend as K
from keras.callbacks import *
import os
import time
import gc
import re
import random
from nltk.tokenize import word_tokenize
import multiprocessing

# Seed the random number generators used in this notebook, for reproducibility.
def seed_everything(seed=1234):
    """Seed Python's ``random`` module and NumPy's global RNG with ``seed``.

    Also exports ``PYTHONHASHSEED`` so Python interpreters started after this
    call use a fixed string-hash seed (the running interpreter's hash seed is
    fixed at start-up and is not affected).
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(seed)


seed_everything()

Using TensorFlow backend.


In [2]:
# Characters padded with spaces so the tokenizer splits them off as tokens.
puncts = ['-', '…', '*', '/', '=', '+', '\\', '^', '_', '²', '√', '|', '™', '£', '°', '₹', 'π']

# Ordered (old, new) literal substitutions applied by ``clean_text`` after the
# punctuation padding. Order matters: an earlier substitution can create or
# destroy a match for a later one, so a longer pattern must come before any
# pattern that is its prefix (e.g. ' b.pharma' before ' b.pharm').
_CLEAN_TEXT_REPLACEMENTS = (
    [('..', '. . ')]
    # Detach a sentence-ending period from the years 2014-2016 ...
    + [(f'{year}. ', f'{year} .  ') for year in (2014, 2015, 2016)]
    + [('.if ', ' . if ')]
    # ... and 2017-2020 (the original check order is preserved).
    + [(f'{year}. ', f'{year} .  ') for year in (2017, 2018, 2019, 2020)]
    + [
        ('\u200b', ''),          # drop zero-width spaces
        (' no.1 ', ' 1st '),     # keep the surrounding spaces intact
    ]
    # Collapse web frameworks onto one common token.
    + [(tech, ' asp.net ') for tech in (' .net', ' react.js', ' vue.js')]
    # Collapse degree abbreviations onto one common token. ' b.pharma' must
    # precede ' b.pharm', otherwise it would become ' b.tech a'.
    + [(degree, ' b.tech ')
       for degree in (' b.sc', ' m.sc', ' b.des', ' b.arch ', ' b.pharma', ' b.pharm')]
    + [
        (' amazon.in', ' amazon.com '),
        (' www.quora.com', ' quora website '),
        (' quora.com', ' quora website '),
        (' www.opham.main.quora.com', ' quora website '),
    ]
    # Add a space after a period glued to the next sentence: 'x.what ' -> 'x. what  '.
    + [(f'.{word} ', f'. {word}  ')
       for word in ('i', 'what', 'how', 'is', 'find', 'why', 'in')]
)


def clean_text(x):
    """Normalise one question string before tokenisation.

    ``x`` is converted with ``str``. Each character in ``puncts`` is surrounded
    by spaces. Then every substitution in ``_CLEAN_TEXT_REPLACEMENTS`` is applied
    in order. The input is expected to be lowercased already, because the
    patterns are lowercase.

    Args:
        x: the text to clean (any object; it is passed through ``str``).

    Returns:
        The cleaned string.
    """
    x = str(x)
    for punct in puncts:
        # str.replace is a no-op when ``punct`` is absent, so no membership test is needed.
        x = x.replace(punct, f' {punct} ')
    for old, new in _CLEAN_TEXT_REPLACEMENTS:
        x = x.replace(old, new)
    return x

In [3]:
# Hand-curated (text to find, replacement) pairs. They are turned into
# ``mispell_dict`` and applied by ``replace_typical_misspell``. The targets are
# mostly brand names, Indian exams/institutes and crypto coins, presumably chosen
# because they lack pretrained vectors. Each is mapped to a more common word or
# phrase. Entries like ("'the", 'the') strip a quote glued to a common word.
# Keys are lowercase; the question text is lowercased before substitution.
# NOTE(review): some keys are prefixes of others (e.g. 'brexit'/'brexiters',
# "5'1"/"5'10", 'whydo'/'whydoes', 'byju'/'byjus'). Which one wins depends on the
# order of the alternation in the compiled regex.
# NOTE(review): a few replacements look like typos ('gurgaony', 'ubder eats',
# 'nstitute of Petroleum Technology'); confirm before changing them.
misspell =[
 ('quorans', "quoran"), ('brexit', 'british exit from eu'), ('cryptocurrencies', "cryptocurrency"), 
 ('redmi', "huawei"), ("'the", 'the'), 
 ('coinbase', 'bitcoin base'), 
 ('oneplus', "huawei"), ("'i", 'i'), ('uceed', "Undergraduate common entrance examination for design"), 
 ('demonetisation', "demonetization"), ('bhakts', "something with great influence"), 
 ('loy machedo', 'Personal Branding Strategist'),
 ('gdpr', "general data protection regulation"), 
 ('yogi adityanath', 'current Chief Minister'), ('boruto', "naruto ' s son"), ('upwork', "odesk"), ('bnbr', 'moderation'),
 ("'a", 'a'), ('ali alshamsi', 'Entrepreneur'), ('dceu', 'american media franchise'),
 ('litecoin', 'bitcoin'), ('iiest', 'Indian Institute of Engineering Science and Technology'),
 ('unacademy', 'indian largest learning platform'), ('sjws', 'social justice warrior'), ('qoura', 'quora'), 
 ('zerodha', 'Indian financial service company'), ("qur'an", 'quora'), 
 ('tensorflow', "keras"), ('doklam', 'china indian border'), ('lnmiit', 'Institute of Information Technology'), 
 ('gopal kavalireddi', 'Maverick'), ('muoet', 'entrance exam'), 
 ('nicmar', 'National Institute of Construction Management and Research'),
 ('vajiram and ravi', 'institute for exam preparation'),
 ('adhaar', 'id'), ('zebpay', 'bitcoins'), ('elitmus', 'assessment and recruitment company'), ('srmjee', 'Joint Entrance Exam'),
 ('altcoins', 'bitcoin'), ('altcoin', "bitcoin"), ('hackerrank', 'code website'),
 ('awdhesh', 'Educator'), ('jiren', 'goku'), ('ryzen', 'intel'), ('baahubali', "deadpool"), ('koinex', 'bitcoin company'),
 ('mhcet', 'entrance exam'),
 ("'no", 'no'), ('binance', 'bitcoin'), ('byju', 'the learning app'),  ('srmjeee', 'entrance exam'), ('beerus', "gogeta"), 
 
 ('sgsits', 'indian institute'), ('skripal', 'former russian military intelligence officer'), ("'to", 'to'), 
 ('ftre', 'talent reward exam'), ('nanodegree', 'the certificate of the bachelor degree'),  ('gurugram', 'gurgaony'), 
 ('hotstar', 'youtube'),  ('mhtcet', 'entrance exam'), ("'you", 'you'), ("'white", 'white'), ('bmsce', 'indian institute'), 
 ('bipc', 'biology physics chemistry'), ('jiofi', "wifi"), 
 ("'not", 'not'),  ('microservices', 'micro services'), ('swachh bharat', 'cleanliness campaign'), ('usict', 'indian college'), 
 ("'in", 'in'), 
 ('zenfone', 'vivo'), ('lbsnaa', 'research and training institute'),  ('clickbait', 'attention - grabbing headlines'), 
 ('reactjs', 'javascript'),  ('patreon', 'kickstarter'),
 ("y'all", 'you all'), ('chromecast', 'ipod'), ('pessat', 'online exam'), ('bittrex', 'us - based bitcoin exchange'), 
 ('sarahah', "anonymous feedback tool"), ('demonitisation', "demonetization"), 
 ('jungkook', 'South Korean singer'), ('dream11', 'steam'),  ('iisers', 'indian college'), ("'how", 'how'), 
 ('aktu', 'indian college'), ('bitconnect', "cryptocurrency"), 
 ('kalpit veerwal', 'computer science sophomore'), ('deepmind', 'british artificial intelligence company'), 
 ("'good", 'good'), ("'all", 'all'), ('aiats', 'All India Test Series'), ("'my", 'my'),  ('trumpcare', 'obamacare'), 
 ("'it", 'it'), ("'do", 'do'), ('mmmut', 'University of Technology'), ('airpods', 'headphones'),
 ('xxxtentacion', 'american rapper'), 
 ('hbtu', 'government technical university'), ("'what", 'what'), 
 ('vssut', 'University of Technology'),  ('wannacry', 'ransomware worm'),  ('nlus', 'national law universities'),
 ("'one", 'one'),  ('rlwl', 'remote location waiting list'), ("'r", 'r'), ('onedrive', "skydrive"),  
 ('lnct', 'College of Technology'),
 ('codeforces', 'competitive programming contests website'), ('arrowverse', 'superhero'),
 ("'free", 'free'), ('despacito', "song"), ('fz25', 'bike type'), ('zamasu', 'gogeta'), 
 ('electroneum', "ethereum"), 
 ('irodov', 'physics'), ("'why", 'why'),  ('simpliv', 'coursera'),  
 ('iiith', 'International Institute of Information Technology'),  ('kovind', '14th President of India'),  
 ('eflu', 'English and Foreign Languages University'),
 ('internshala', 'coursera'), ('whydo', 'why do'), ('chapterwise', "chapter wise"),  
 ('ncerts', 'National Council of Educational Research and Training'),  ('genderfluid', 'intersex'),
 ('igdtuw', 'Technical University for Women'), 
 ('ravindrababu', 'the online teacher'), ('₹', "inr"), ('twinflame', 'soulmate'), 
 ('iiitd', 'Institute of Information Technology'), ('kubernetes', "docker"), 
  ('tissnet', 'national Entrance Test'), ('xiomi', "xiaomi"), ('blockchains', "blockchain"), 
  ('jcpoa', 'iran nuclear deal'), ('undergraduation', 'Undergraduate education'), ('incels', 'involuntary celibates'),
 ('overbrace', "+ -"),  ('schizoids', 'schizoid'), ('byjus', "coursera"), ('hackerearth', "hackerspace"), 
 ('apist', "rapist"), ("'new", 'new'), ("don'ts", 'do not'), 
 ('odoo', 'enterprise management system'), ('vitee', 'entrance exam'),  ('veerwal', 'computer science sophomore'), 
 ('wikitribune', 'news platform'),
 ("'friends", 'friends'), ("'if", 'if'), 
 ('ipmat', 'Integrated Program in Management Aptitude Test'), ('extc', 'electronics and telecommunication engineering'), 
 ('dhinchak pooja', 'pop singer'), ("''the", 'the'), ('kaneki', 'comic character'),
 ('undertale', 'minecraft'), ('peter strzok', 'former united states federal bureau of investigation ( fbi ) agent'), 
 ('padmaavat', 'queen'), ("'real", 'real'), ('sscbs', 'college of business studies'), ('yourquote', 'microblogging platform'), 
 ("'god", 'god'), 
 ('remainers', "remain"), ('pizzagate', "debunked conspiracy theory"), ('theranos', " health technology corporation"), 
 ('drumpf', 'trump'), ('zhihu', "chinese quora"), ('makaut', "college"), ("'x", 'x'), 
 
 ("i'am", 'i am'), ('qidian', "chinese novel website"), ('bmsit', "Institute of Technology and Management"), 
 ('instacart', "walmart"), ('ailet', "All India Law Entrance Test"), ("'normal", 'normal'), 
 ('lhmc', "most authoritative dictionary database"), ("'yes", 'yes'), ("'get", 'get'), 
 ('mbappe', "beckham"), ('padmavat', "queen"), 
 ('bitfinex', "bitcoin platform"), ('kainerugaba', "Ugandan military officer"), ('cos2x', 'cosx'), 
 ('homepod', "ipod"), ('don´t', 'do not'), ('steemit', "facebook"), 
 ('\ufeff', ' '), ('jupyter', "javascript"), ('nsejs', "National Standard Examination in Junior Science"), 
 ('doordash', "walmart"), ('msqe', "Master of Science in Quantitative Economics"), ("'he", 'he'), 
 ("'anti", 'anti'), ('rgipt', "nstitute of Petroleum Technology"), ('2k17', '2017'), ('whyis', 'why is'), 
 ('usaco', "computing education"), ('iihm', "International Institute of Hotel Management"), 
 ("'big", 'big'), ('neuralink', "spacex"), ("'black", 'black'), ("'quora", 'quora'), ('i´m', 'i am'), ("'fake", 'fake'), 
 ('taehyung', "South Korean singer"), ("'best", 'best'), ('pgdbf', "future of Bank recruitment"), 
 ('upeseat', "University of Petroleum & Energy Studies"), 
 ("'we", 'we'), ('mh370', "Airlines Flight"), 
 ('openai', "spacex"), ("'being", 'being'), ("'bad", 'bad'), ("'american", 'american'),
 ('mobikwik', "paytm"), ("'b", 'b'), 
 ('vitmee', "Examination"), ('aieea', "All India Entrance Examination for Admission for under graduation"), 
 ('flipcart', "flipkart"), ('i`m', 'i am'), ('ubereats', "ubder eats"), ("'hindu", 'hindu'), 
 ('plancess', "edu solutions"), ('brexiters', "british exit from eu"), ('kattankulathur', "kanchipuram"), 
 
 ("cat'17", "Common Admission Test 2017"), ('demonitization', "demonetization"), ('killmonger', "fictional supervillain"), 
 ("'high", 'high'), 
 ('ipucet', "University Common Entrance Test"), 
 ('ugee', "touch bar"), ('ipill', "pill"), ('gstin', "identification"), ("'let", "let"), ("'go", 'go'), ('tamilans', "tamils"), 
 ('nluo', "National Law University"), ('segwit2x', "cryptocurrency"), ('unocoin', "cryptocurrency"), 
 ('wumao', "internet commentators"), ('minance', "financial institution"), ('waymo', 'lyft'), 
 ("'friend", 'friend'), ('covalency', "valency"), ('daesh', "islamic state of iraq"), 
 ('nielit', "National Institute of Electronics & Information Technology"), ('aimcat', "Test series"), ('digitalocean', "skydrive"), 
 ('ballb', "bachelor"), ("'f", "f"), ("'just", 'just'), ('webnovel', 'web novel'), ('2k18', '2018'), 
 ('kefla', "goku"), ('niftem', " National Institute Of Food Technology Entrepreneurship And Management"), 
 ('phonepe', "iphone"), ("'time", "time"), ('dogecoin', "cryptocurrency"), ("'e", "e"), ('musigma', "mckinsey"), 
 ('jiophone', "iphone"), 
 ('fitjee', "tutorial"), ('lrdi', "Logical Reasoning and Data Interpretation"), 
 ('imucet', "Exam"), ("'this", "this"), ('zoomcar', "lyft"), ('deplorables', " presidential election campaign speech"), 
 
 ('hypsm', "college union"), ('kilimall', "amazon"), ('tqwl', " tatkal waiting list"), ("'non", 'non'), 
 ("'is", 'is'), ("'p", 'p'), ('sppu', "University"), 
 ("gov't", 'gov'), ('can`t', 'can not'), ('jecrc', "university"), ('brexiteers', 'british exit from eu'), 
 ('darkweb', "dark web"), ("'c", "c"), 
 ('dsce', "College of Engineering"), ('alphago', "deepblue"), 
 ('etoos', "education"), ("'only", 'only'), ('angular2', 'angular'), ('sanghis', "sanghi"),
 ('dilr', "Data Interpretation and Logical Reasoning"), ('trumpism', 'trump ism'), 
 ('quoras', 'quoran'), ('onecoin', "cryptocurrency"), ("'man", "man"), ('graphql', 'graph sql'), 
 ("'she", 'she'), ('autoencoder', "auto encoder"), ('arkit', "ios tool"), 
 ('iert', "Institute of Engineering and Rural Technology"), ("'right", 'right'), ("'baby", 'baby'), 
 ('toppr', "learning app"), ('practo', "online doctor"), ('hashflare', "cryptocurrency mine"), 
 ("'make", 'make'), ('oliveboard', "preparation platform"), 
 ("'indian", 'indian'), ("'an", 'an'), ('gdax', "ethereum"), ('narcissit', "narcissism"), 
 ('rnsit', "Institute of Technology"), ('uppcs', "Public Service Commission"), 
 ('poloniex', "dollar stablecoin"), ('whatapp', 'whatsapp'), ('simsree', "Research And Entrepreneurship Education"), 
 ('siacoin', "ethereum"), ("'happy", 'happy'), ("'out", 'out'), ('bschools', "school"),
 ('sense8', "science fiction drama web television series"), 
 ('antminer', "ethereum"), ('cringiest', "cringy"), ('rakshaks', "rakshak"), ('whydoes', 'why does'), 
 ("'t", "t"), ("'special", 'special'), ('hololens', "smart glasses"), 
 
 ('pytorch', "keras"), ('nearbuy', "near buy"), ('freecodecamp', "freecode camp"), ('strowman', "wrestler"), ("'too", "too"), 
 ('gaslighted', "gas lighted"), ('mpstme', "top Engineering colleges"), ('hyperconjugation', "no - bond resonance"), 
 ('gurmehar', "left leaning Indian student activist and author"), 
 ('gujratis', "gujrati"), ("''i", 'i'), 
 ('lenskart', "flipkart"), ("'your", 'your'), ("'must", 'must'), ("'life", 'life'), 
 ('bstat', "bachelor"), ("'on", 'on'), 
 ("'flat", 'flat'), ("'more", 'more'), ("'as", 'as'), ("'so", 'so'), ('thicc', "fit"), ('gaslighter', 'gas lighter'), 
 ('techmahindra', "infosys"), ('cptsd', "trauma"), ('datacamp', "data camp"),  
 ('bajjika', "french"), ("qu'ran", 'quora'), ("'great", 'great'), ('baslp', "bachelor"), ('vuejs', 'javascript'), 
 ('suryanamaskar', "yogi"), 
 ('crytocurrency', "cryptocurrency"), ("'science", 'science'), ('bartetzko', "soldier"), ("'red", 'red'), 
 ('iqoption', "iq option"), ('fnaf', " media franchise"), 
 ('touchbar', "touch bar"), ('delhite', "people in delhi"), ("'made", 'made'), ('gitlab', 'github'), 
 ('aayog', "The National Institution for Transforming India"), ('nofap', "health platform"), ('bftech', "bachelor"), 
 ("'hello", 'hello'), ("'who", 'who'), ('sklearn', 'keras'), ('incel', "alone"), 
 ('frdi', "Financial Resolution and Deposit Insurance"), 
 ("'that", 'that'), ('howdoes', "how does"), ('etherum', "ethereum"), ('narcisists', "narcissism"), 
 ('got7', 'got'), ("'l", 'l'), ("'can", 'can'), ("'home", "home"), ('neet2017', "neet 2017"),
 ('aimcats', 'exam'), ('alphazero', 'deepblue'), ("'o", 'o'), ("'nothing", 'nothing'), 
 ('filecoin', "cryptocurrency"), ('madeeasy', "made easy"),
# Heights written as feet'inches, spelled out in words.
("5'1", "5 feet 1 inch"),
("5'2", "5 feet 2 inch"),
("5'3", "5 feet 3 inch"),
("5'4", "5 feet 4 inch"),
("5'5", "5 feet 5 inch"),
("5'6", "5 feet 6 inch"),
("5'7", "5 feet 7 inch"),
("5'8", "5 feet 8 inch"),
("5'9", "5 feet 9 inch"),
("5'0", "5 feet 0 inch"),
("5'10", "5 feet 10 inch"),
("5'11", "5 feet 11 inch"),
("5'12", "5 feet 12 inch"),
("6'1", "6 feet 1 inch"),
("6'2", "6 feet 2 inch"),
("6'3", "6 feet 3 inch"),
("6'4", "6 feet 4 inch"),
("6'5", "6 feet 5 inch"),
("6'0", "6 feet 0 inch"),]
# Lookup table: text to find -> replacement.
# Every pair is registered with a leading space on both sides, so it only
# matches where the key is preceded by a space.
# Pairs whose key contains 'why' or 'how' are also registered without the
# space, so glued forms such as "'why" or "whydo" are caught anywhere.
mispell_dict = {}
for words in misspell:
    wrong, right = words
    needs_bare_key = any(token in wrong for token in ('why', 'how'))
    if needs_bare_key:
        mispell_dict[wrong] = right + ' '
    mispell_dict[' ' + wrong] = ' ' + right + ' '
    
def _get_mispell(mispell_dict):
    mispell_re = re.compile('(%s)' % '|'.join(mispell_dict.keys()))
    return mispell_dict, mispell_re

# Compile the lookup table once, at cell execution; replace_typical_misspell reuses it.
mispellings, mispellings_re = _get_mispell(mispell_dict)


def replace_typical_misspell(text):
    """Replace every known misspelling or rare term in ``text`` with its substitute.

    Matches come from the precompiled ``mispellings_re``. Each replacement is
    looked up in ``mispellings`` by the exact matched text.
    """
    return mispellings_re.sub(lambda found: mispellings[found.group(0)], text)

In [None]:
# embed_size   : dimensionality of each word-embedding vector
# max_features : vocabulary cap passed to the Tokenizer (None = keep every word)
# maxlen       : number of tokens every question is padded/truncated to
embed_size, max_features, maxlen = 300, None, 72

def _normalise_questions(questions):
    """Fill missing values, lowercase, clean, and fix misspellings in a Series of questions.

    Missing values are filled with the "_####_" placeholder *first*. Filling
    afterwards has no effect, because ``clean_text`` has already turned NaN into
    the string 'nan'. Note that ``clean_text`` then pads the '_' characters with
    spaces like any other punctuation in ``puncts``.
    """
    return (questions
            .fillna("_####_")
            .astype(str)
            .str.lower()
            .apply(clean_text)
            .apply(replace_typical_misspell))


def preapre_data(data_dir="../input"):
    """Load the train/test CSVs and convert questions into padded id sequences.

    Args:
        data_dir: directory containing ``train.csv`` and ``test.csv``.

    Returns:
        ``(X_all, X_test, Y, word_index, submission)``:
        - ``X_all``, ``X_test``: token-id matrices padded/truncated to ``maxlen``
          by ``pad_sequences`` (padding value 0).
        - ``Y``: the train ``target`` column values.
        - ``word_index``: the vocabulary of the Tokenizer fitted on train+test.
        - ``submission``: a copy of the test ``qid`` column.
    """
    train = pd.read_csv(os.path.join(data_dir, "train.csv"))
    test = pd.read_csv(os.path.join(data_dir, "test.csv"))

    train["question_text"] = _normalise_questions(train["question_text"])
    test["question_text"] = _normalise_questions(test["question_text"])

    # Tokenize with NLTK across two worker processes.
    with multiprocessing.Pool(2) as pool:
        docs_tokenized_train = pool.map(word_tokenize, train['question_text'].values)
        docs_tokenized_test = pool.map(word_tokenize, test['question_text'].values)

    # The vocabulary is fitted on the train and test token lists together.
    tokenizer = Tokenizer(num_words=max_features, filters='')
    tokenizer.fit_on_texts(docs_tokenized_train + docs_tokenized_test)

    # Encode to id sequences and pad to a fixed length.
    X_all = pad_sequences(tokenizer.texts_to_sequences(docs_tokenized_train), maxlen=maxlen)
    X_test = pad_sequences(tokenizer.texts_to_sequences(docs_tokenized_test), maxlen=maxlen)

    Y = train['target'].values

    # .copy() so adding a 'prediction' column later modifies an independent
    # frame instead of a slice of ``test`` (avoids SettingWithCopyWarning).
    submission = test[['qid']].copy()
    return X_all, X_test, Y, tokenizer.word_index, submission


X, X_test, Y, word_index, sub = preapre_data()

In [None]:
# Embedding-matrix row count: one row per Tokenizer index (indices start at 1)
# plus row 0, which pad_sequences uses as the padding value.
max_features = len(word_index)+1

# Both pretrained files hold 300-dimensional vectors. This constant is used
# instead of the global ``embed_size``, which a later cell reassigns to 600.
# Re-running this cell after that would otherwise build 600-wide matrices and
# fail when copying 300-d vectors into them.
_PRETRAINED_DIM = 300


def _load_pretrained_matrix(path, word_index, emb_mean, emb_std, min_line_len=0, **open_kwargs):
    """Build an embedding matrix for ``word_index`` from a text embedding file.

    Args:
        path: text file whose lines are ``token v1 ... v300`` separated by spaces.
        word_index: token -> row index mapping from the fitted Tokenizer.
        emb_mean, emb_std: mean/std of the normal distribution used to
            initialise rows whose token has no pretrained vector.
        min_line_len: lines with ``len(line) <= min_line_len`` are skipped.
        **open_kwargs: forwarded to ``open`` (e.g. ``encoding``, ``errors``).

    Returns:
        ``np.ndarray`` of shape ``(max_features, 300)``.
    """
    embeddings_index = {}
    # Read inside ``with`` so the file handle is closed once parsing finishes.
    with open(path, **open_kwargs) as embedding_file:
        for line in embedding_file:
            if len(line) <= min_line_len:
                continue
            # Split the numbers off from the right. A token that itself contains
            # spaces then stays intact instead of spilling into the vector.
            # NOTE(review): such lines are reported in some GloVe releases.
            word, *vector = line.rstrip().rsplit(" ", _PRETRAINED_DIM)
            if word in word_index and len(vector) == _PRETRAINED_DIM:
                embeddings_index[word.lower()] = np.asarray(vector, dtype='float32')

    embedding_matrix = np.random.normal(emb_mean, emb_std, (max_features, _PRETRAINED_DIM))
    for word, i in word_index.items():
        if i >= max_features:
            continue
        embedding_vector = embeddings_index.get(word)
        if embedding_vector is not None:
            embedding_matrix[i] = embedding_vector
    return embedding_matrix


def load_glove(word_index):
    """GloVe 840B/300d matrix for ``word_index``; out-of-vocabulary rows are random-normal."""
    return _load_pretrained_matrix(
        '../input/embeddings/glove.840B.300d/glove.840B.300d.txt', word_index,
        emb_mean=-0.005838499, emb_std=0.48782197)


def load_para(word_index):
    """Paragram 300/SL999 matrix for ``word_index``; out-of-vocabulary rows are random-normal.

    Lines of 100 characters or fewer are skipped, and undecodable bytes are ignored.
    """
    return _load_pretrained_matrix(
        '../input/embeddings/paragram_300_sl999/paragram_300_sl999.txt', word_index,
        emb_mean=-0.0053247833, emb_std=0.49346462, min_line_len=100,
        encoding="utf8", errors='ignore')


def load_embedding_matrix():
    """Return ``(0.6*GloVe + 0.4*Paragram, [GloVe | Paragram])`` embedding matrices.

    The second matrix concatenates the two along the feature axis (600 columns).
    """
    seed_everything()  # reseed so the random rows for unknown words are reproducible
    embedding_matrix_g = load_glove(word_index)
    embedding_matrix_p = load_para(word_index)
    weighted = 0.6 * embedding_matrix_g + 0.4 * embedding_matrix_p
    concatenated = np.concatenate((embedding_matrix_g, embedding_matrix_p), axis=1)
    return weighted, concatenated


embedding_matrix_weight,  embedding_matrix_concat = load_embedding_matrix()

In [None]:
class AdamW(Optimizer):
    """Adam optimizer with decoupled weight decay (AdamW) for Keras.

    The update is Adam's: first/second moment estimates, with the bias
    correction folded into the step size ``lr_t``. On top of that, a separate
    ``- lr * weight_decay * p`` term is subtracted from each parameter; the decay
    is not added to the gradient. The four places that differ from a plain
    Adam implementation are tagged "decoupled weight decay (k/4)".

    NOTE(review): the ``interfaces`` name used by the decorator on
    ``get_updates`` is not imported explicitly in this notebook. Presumably it
    comes in via ``from keras.optimizers import *``. Confirm this for your Keras
    version, or import it explicitly.
    """
    def __init__(self, lr=0.001, beta_1=0.9, beta_2=0.999, weight_decay=1e-4,  # decoupled weight decay (1/4)
                 epsilon=1e-8, decay=0., **kwargs):
        """Create the optimizer and its hyper-parameter backend variables.

        Args:
            lr: base learning rate.
            beta_1: decay rate of the first-moment (mean of gradients) estimate.
            beta_2: decay rate of the second-moment (mean of squared gradients) estimate.
            weight_decay: decoupled decay coefficient; each step subtracts
                ``lr * weight_decay * p`` from every parameter ``p``.
            epsilon: added to ``sqrt(v_t)`` in the update's denominator.
            decay: if > 0, ``lr`` is scaled by ``1 / (1 + decay * iterations)``.
            **kwargs: forwarded to the Keras ``Optimizer`` base class.
        """
        super(AdamW, self).__init__(**kwargs)
        # Hyper-parameters are backend variables, so they can be changed during
        # training (e.g. callbacks such as ReduceLROnPlateau adjusting ``lr``).
        with K.name_scope(self.__class__.__name__):
            self.iterations = K.variable(0, dtype='int64', name='iterations')
            self.lr = K.variable(lr, name='lr')
            self.beta_1 = K.variable(beta_1, name='beta_1')
            self.beta_2 = K.variable(beta_2, name='beta_2')
            self.decay = K.variable(decay, name='decay')
            self.wd = K.variable(weight_decay, name='weight_decay') # decoupled weight decay (2/4)
        self.epsilon = epsilon
        # Python-side copy of ``decay``: decides at graph-build time whether the
        # time-based learning-rate decay term is added at all.
        self.initial_decay = decay

    @interfaces.legacy_get_updates_support
    def get_updates(self, loss, params):
        """Build and return the list of update ops for one optimization step.

        Args:
            loss: the loss tensor to minimise.
            params: the weights to update.

        Returns:
            ``self.updates``: the iteration-counter increment followed by, for
            each parameter, the updates of ``m``, ``v`` and the parameter itself.
        """
        grads = self.get_gradients(loss, params)
        self.updates = [K.update_add(self.iterations, 1)]
        wd = self.wd # decoupled weight decay (3/4)

        lr = self.lr
        if self.initial_decay > 0:
            # Time-based decay: lr / (1 + decay * iterations).
            lr *= (1. / (1. + self.decay * K.cast(self.iterations,
                                                  K.dtype(self.decay))))

        # Step count t = iterations + 1 and Adam's bias-corrected step size.
        t = K.cast(self.iterations, K.floatx()) + 1
        lr_t = lr * (K.sqrt(1. - K.pow(self.beta_2, t)) /
                     (1. - K.pow(self.beta_1, t)))

        # Per-parameter first (m) and second (v) moment accumulators, zero-initialised.
        ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        # Optimizer state, exposed through ``self.weights``.
        self.weights = [self.iterations] + ms + vs

        for p, g, m, v in zip(params, grads, ms, vs):
            m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
            v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g)
            # Adam step plus weight decay applied directly to the parameter; the
            # decay term uses ``lr`` (possibly time-decayed), not ``lr_t``.
            p_t = p - lr_t * m_t / (K.sqrt(v_t) + self.epsilon) - lr * wd * p # decoupled weight decay (4/4)

            self.updates.append(K.update(m, m_t))
            self.updates.append(K.update(v, v_t))
            new_p = p_t

            # Apply constraints.
            if getattr(p, 'constraint', None) is not None:
                new_p = p.constraint(new_p)

            self.updates.append(K.update(p, new_p))
        return self.updates

    def get_config(self):
        """Return the hyper-parameters (including ``weight_decay``) merged with the base config."""
        config = {'lr': float(K.get_value(self.lr)),
                  'beta_1': float(K.get_value(self.beta_1)),
                  'beta_2': float(K.get_value(self.beta_2)),
                  'decay': float(K.get_value(self.decay)),
                  'weight_decay': float(K.get_value(self.wd)),
                  'epsilon': self.epsilon}
        base_config = super(AdamW, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

In [None]:
# epoch=5 (original author's note; the training cell fits for epochs=7)
# Width of the concatenated GloVe + Paragram embedding matrix (300 + 300).
embed_size = 600
def parallelRNN(recurrent_units=64, weight_decay=0.06):
    """Build and compile a parallel BiGRU + BiLSTM binary classifier.

    Both recurrent branches read the same frozen ``embedding_matrix_concat``
    embeddings (after spatial dropout). Their sequence outputs are concatenated
    and max-pooled over time, then fed to one sigmoid unit. The Keras session is
    cleared first, so each call starts from a fresh graph.

    Args:
        recurrent_units: units per direction in each CuDNN GRU/LSTM layer
            (64 reproduces the original model).
        weight_decay: decoupled weight decay passed to ``AdamW``.

    Returns:
        The compiled Keras ``Model`` (binary cross-entropy loss).
    """
    K.clear_session()
    inp = Input(shape=(maxlen,))
    embedding_layer = Embedding(max_features,
                                embed_size,
                                weights=[embedding_matrix_concat],
                                input_length=maxlen,
                                trainable=False)(inp)
    embedding_layer = SpatialDropout1D(0.2, seed=10086)(embedding_layer)

    # Two recurrent branches over the same embeddings; fixed seeds for repeatable init.
    gru = Bidirectional(CuDNNGRU(recurrent_units, return_sequences=True,
                                 kernel_initializer=glorot_uniform(seed=1008600),
                                 recurrent_initializer=Orthogonal(gain=1.0, seed=1008600)))(embedding_layer)
    lstm = Bidirectional(CuDNNLSTM(recurrent_units, return_sequences=True,
                                   kernel_initializer=glorot_uniform(seed=111000),
                                   recurrent_initializer=Orthogonal(gain=1.0, seed=1008600)))(embedding_layer)
    concat = concatenate([gru, lstm], axis=-1)
    concat = GlobalMaxPooling1D()(concat)

    output_layer = Dense(1, activation="sigmoid", kernel_initializer=glorot_uniform(seed=10086))(concat)
    model = Model(inputs=inp, outputs=output_layer)
    model.compile(loss='binary_crossentropy', optimizer=AdamW(weight_decay=weight_decay))
    return model

In [None]:
def f1_smart(y_true, y_pred):
    """Find the decision threshold that maximises F1 on a labelled set.

    After sorting the scores ascending, every cut between two consecutive
    scores is a candidate threshold. F1 is computed for all cuts at once from
    cumulative sums, using F1 = 2*TP / (predicted positives + actual positives).

    Args:
        y_true: 0/1 labels as any array-like holding n values, e.g. shape
            (n,) or (n, 1).
        y_pred: scores of the same size as ``y_true``.

    Returns:
        ``(best_f1, threshold)``. The threshold is the midpoint between the two
        scores on either side of the best cut.

    Raises:
        ValueError: if the inputs differ in size or hold fewer than two samples.
    """
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()
    if y_true.shape != y_pred.shape:
        raise ValueError("y_true and y_pred must have the same number of values")
    if y_true.shape[0] < 2:
        raise ValueError("f1_smart needs at least two samples to place a threshold")
    args = np.argsort(y_pred)
    tp = y_true.sum()
    # For a cut after sorted position k (k = 0..n-2):
    # TP = tp - positives in args[:k+1]; predicted + actual positives = (n - k - 1) + tp.
    fs = (tp - np.cumsum(y_true[args[:-1]])) / np.arange(y_true.shape[0] + tp - 1, tp, -1)
    res_idx = np.argmax(fs)
    return 2 * fs[res_idx], (y_pred[args[res_idx]] + y_pred[args[res_idx + 1]]) / 2

In [None]:
# Stratified K-fold training. Only the folds listed in FOLDS_TO_RUN are trained
# (originally just fold 2). Test predictions are averaged over those folds,
# and one F1-optimal threshold is recorded per fold.
N_SPLITS = 7
FOLDS_TO_RUN = {2}
assert FOLDS_TO_RUN and FOLDS_TO_RUN <= set(range(N_SPLITS)), "FOLDS_TO_RUN must name valid fold indices"

seed_everything()
kfold = StratifiedKFold(n_splits=N_SPLITS, random_state=10, shuffle=True)
y_test = np.zeros((X_test.shape[0], ))
thresholds = []

for i, (train_index, valid_index) in enumerate(kfold.split(X, Y)):
    if i not in FOLDS_TO_RUN:
        continue
    X_train, X_val, Y_train, Y_val = X[train_index], X[valid_index], Y[train_index], Y[valid_index]
    filepath = "weights_best.h5"
    # Keep only the weights from the epoch with the lowest validation loss.
    checkpoint = ModelCheckpoint(filepath, monitor='val_loss', verbose=2, save_best_only=True, mode='min')
    # Cut the learning rate to 20% whenever validation loss stalls for one epoch.
    reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=1, min_lr=0.0, verbose=2)
    callbacks = [checkpoint, reduce_lr]
    model = parallelRNN()
    # NOTE(review): shuffle=False means Keras does not reshuffle batches between
    # epochs; confirm this is intended.
    model.fit(X_train, Y_train, batch_size=512, epochs=7, validation_data=(X_val, Y_val), verbose=2,
              callbacks=callbacks, shuffle=False, class_weight={0: 1, 1: 1.25})
    model.load_weights(filepath)
    y_pred = model.predict([X_val], batch_size=1024, verbose=1)
    # Accumulate the fold-average (previously each fold overwrote y_test).
    y_test += np.squeeze(model.predict([X_test], batch_size=1024, verbose=1)) / len(FOLDS_TO_RUN)
    f1, threshold = f1_smart(np.squeeze(Y_val), np.squeeze(y_pred))
    thresholds.append(threshold)
    print('Optimal F1: {:.4f} at threshold: {:.4f}'.format(f1, threshold))

In [None]:
# Binarise the test scores with the mean of the per-fold F1-optimal thresholds.
# Without any trained fold, np.mean([]) would be nan and every prediction 0.
assert thresholds, "no fold was trained, so there is no threshold to apply"
final_threshold = np.mean(thresholds)
y_test = y_test.reshape((-1, 1))
pred_test_y = (y_test > final_threshold).astype(int)
# Assign a 1-D column rather than the (n, 1) array.
sub['prediction'] = pred_test_y.ravel()
sub.to_csv("submission.csv", index=False)