In [None]:
import tensorflow as tf
from tensorflow import keras
import tensorflow_datasets as tfds
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd
from tqdm import tqdm
import ipywidgets as widgets
from ipywidgets import interact, interact_manual, interactive
from IPython.display import display


from utils.metric import roc_auc
from utils.model import define_rnn_model, define_cnn_model, define_lstm_model, define_gru_model, define_bi_model, define_cnn_rnn_model

# This notebook targets TensorFlow 2.x; stop early on any other major version.
if tf.__version__[:1] != '2':
    raise ValueError('This code requires TensorFlow V2.x')

In [None]:
# Data: Jigsaw toxic-comment training set.
o_train = pd.read_csv('Data/jigsaw-toxic-comment-train.csv')

# Pre-processing: keep only the binary `toxic` target and drop the other label columns.
# Reassigning (instead of inplace=True) keeps this cell idempotent on re-run.
o_train = o_train.drop(columns=['severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate'])

# Training-set size controls.
# Mode dropdown: 'Số lượng' = absolute row count, 'Phần trăm' = percentage of all rows
# ('Dữ liệu vào' = "input data").
type_input = widgets.Dropdown(
    options=['Phần trăm', 'Số lượng'],
    value='Số lượng',
    description='Dữ liệu vào',
    disabled=False,
)
# Size slider; its range/step/label are reconfigured by `update_value_input` when the mode changes.
value_input = widgets.IntSlider(
    value=20000,
    min=0,
    max=o_train.shape[0],
    step=1000,
    description='Số lượng',
    readout=True
)


def update_value_input(*args):
    """Reconfigure `value_input` whenever the size mode selected in `type_input` changes.

    Count mode: range 0..len(o_train), default 20000, step 1000.
    Percent mode: range 0..100, default 50, step 1.

    `max` is assigned before `value` on purpose: ipywidgets clamps a bounded
    slider's value to its *current* [min, max]. Setting value=20000 while max is
    still 100 (left over from percent mode) would silently clamp it to 100.
    """
    if type_input.value == 'Số lượng':
        value_input.max = o_train.shape[0]
        value_input.value = 20000
        value_input.step = 1000
        value_input.description = 'Số lượng'
    else:
        value_input.max = 100
        value_input.value = 50
        value_input.step = 1
        value_input.description = 'Phần trăm'


type_input.observe(update_value_input, 'value')

display(type_input)
display(value_input)


In [None]:
# Architecture keys offered for selection; the chosen key picks the builder used to create `model`.
model_array = ['cnn', 'rnn', 'lstm', 'gru', 'bi_directional', 'cnn + rnn']

# Model picker ('Loại mô hình' = "model type"); defaults to the first entry, 'cnn'.
model_input = widgets.Dropdown(options=model_array, value=model_array[0],
                               description='Loại mô hình', disabled=False)
display(model_input)

In [None]:
# Disabled: would load pre-trained GloVe vectors (glove.840B.300d.txt) into a
# {word: vector} dict named `embeddings_index`. No later cell in this notebook reads
# `embeddings_index`; the models are built from `max_vocab` and `max_test` only.
# NOTE(review): if re-enabled, open the file with `with open(...)` so it is closed on error.
# # Embedding
# embeddings_index = {}
# f = open('glove.840B.300d.txt','r',encoding='utf-8')
# for line in tqdm(f):
#     values = line.split(' ')
#     word = values[0]
#     coefs = np.asarray([float(val) for val in values[1:]])
#     embeddings_index[word] = coefs
# f.close()

# print('Found %s word vectors.' % len(embeddings_index))

In [None]:
# Working set = the first N rows of the labelled data, with N taken from the widgets.
# `.iloc` is positional and end-exclusive, so exactly `n_rows` rows are selected;
# label-based `.loc[:n]` would include row n as well and cannot take a fractional
# row count from percent mode.
if type_input.value == 'Số lượng':
    n_rows = value_input.value
else:
    n_rows = int(value_input.value * o_train.shape[0] / 100)
train = o_train.iloc[:n_rows]

# Longest comment measured in whitespace-separated words; used later as the pad/truncate length.
# NOTE(review): the Keras tokenizer's default filters also split on punctuation, so tokenized
# sequences may be longer than this and get truncated — confirm this is acceptable.
max_test = train['comment_text'].apply(lambda x: len(str(x).split())).max()

# Stratified 80/20 split so both sides keep the same toxic/non-toxic ratio.
xtrain, xvalid, ytrain, yvalid = train_test_split(train.comment_text.values, train.toxic.values,
                                                  stratify=train.toxic.values,
                                                  random_state=42,
                                                  test_size=0.2, shuffle=True)

In [None]:
# Build the vocabulary from the training split only, so held-out validation text does not
# shape preprocessing. With no `oov_token`, words never seen in training are dropped
# from the validation sequences, so every index stays within len(word_index).
token = keras.preprocessing.text.Tokenizer(num_words=None)
token.fit_on_texts(list(xtrain))
xtrain_seq = token.texts_to_sequences(xtrain)
xvalid_seq = token.texts_to_sequences(xvalid)

# Zero-pad / truncate every sequence to `max_test` tokens
# (Keras defaults: padding and truncation both happen at the front of the sequence).
xtrain_pad = keras.preprocessing.sequence.pad_sequences(xtrain_seq, maxlen=max_test)
xvalid_pad = keras.preprocessing.sequence.pad_sequences(xvalid_seq, maxlen=max_test)

word_index = token.word_index

In [None]:
# Embedding input size: +1 because Keras tokenizer indices start at 1 and 0 is the padding value.
max_vocab = len(word_index) + 1

# Map each dropdown choice to its builder. Only the selected architecture is constructed,
# rather than building all six models and discarding five. This also means a builder that
# fails for the current `max_test` cannot break an unrelated selection.
model_builders = {
    'cnn': define_cnn_model,
    'rnn': define_rnn_model,
    'lstm': define_lstm_model,
    'gru': define_gru_model,
    'bi_directional': define_bi_model,
    'cnn + rnn': define_cnn_rnn_model,
}
model = model_builders[model_input.value](max_vocab, max_test)

In [21]:
# Train the selected model, save it as '<model name>.h5', and report validation AUC.
EPOCHS = 10

# The 'cnn' model is fed three inputs, each receiving the same padded sequence;
# every other architecture takes a single input.
# NOTE(review): presumably parallel conv branches — confirm in utils.model.define_cnn_model.
n_model_inputs = 3 if model_input.value == 'cnn' else 1


def _as_model_input(padded, n_inputs):
    """Return `padded` in the form model.fit / model.predict expect.

    Multi-input models get a list with the same array repeated once per input;
    single-input models get the array itself.
    """
    if n_inputs > 1:
        return [padded] * n_inputs
    return padded


history = model.fit(_as_model_input(xtrain_pad, n_model_inputs), ytrain, epochs=EPOCHS)
model.save(model_input.value + '.h5')
scores = model.predict(_as_model_input(xvalid_pad, n_model_inputs))
# NOTE(review): the "%%" suffix assumes roc_auc returns a percentage; if it returns a
# 0-1 score, the printed unit is misleading — confirm in utils.metric.
print("Auc: %.2f%%" % (roc_auc(scores, yvalid)))

KeyboardInterrupt: 