In [1]:
import tensorflow_hub as hub
import tensorflow_text as text
from pathlib import Path
import os
import pandas as pd
from collections import defaultdict
import re
import tensorflow as tf
import sklearn as sk
from sklearn import preprocessing

In [15]:
# bert_en
# preprocess_url = "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3"
# encoder_url = "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4"

# bert zh
preprocess_url = "https://tfhub.dev/tensorflow/bert_zh_preprocess/3"
encoder_url = "https://tfhub.dev/tensorflow/bert_zh_L-12_H-768_A-12/4"


In [16]:
bert_preprocess = hub.KerasLayer(preprocess_url)

In [17]:
text_test = """Post-Brexit Britain should use its power to defend democracy around the world, Jeremy Hunt will say today.
The Foreign Secretary will insist the UK should not 'underestimate' its strength.
'We are not a superpower and we do not have an empire,' he will say in a speech in Singapore.
'But we do have the fifth-biggest economy in the world, the third-biggest overseas aid budget, the second-biggest military budget in Nato, one of the two biggest financial centres, the world's language, highly effective intelligence services and a world-class diplomatic network.
'We also have immense reserves of soft power, with three of the world's top ten universities, 450,000 international students, 39million visits by tourists in 2017, and a global audience for our media, especially the BBC, measured in the hundreds of millions.'
Foreign Secretary Jeremy Hunt will give a speech in Singapore today where he will say the UK should not underestimate its strength after Brexit.
He is pictured in Myanmar before Christmas
Mr Hunt will point to Britain's links with the Commonwealth, the United States and Europe, saying: 'Those connections are why Britain's post-Brexit role should be to act as an invisible chain linking together the democracies of the world ... who share our values and support our belief in free trade, the rule of law and open societies.
'We should begin by being realistic about our global position.
That means not overestimating our strength but not underestimating it either.'
Mr Hunt will highlight worrying figures showing that 71 countries saw reversals of political and civil liberties in 2017.
He will sign a 'strategic partnership' with Singapore covering areas including defence and education.
""".splitlines()

text_processed = bert_preprocess(text_test)
text_processed.keys()


dict_keys(['input_type_ids', 'input_mask', 'input_word_ids'])

In [18]:
bert_encoder = hub.KerasLayer(encoder_url)

In [27]:
def get_uid_from_file_path(path:Path):
    return os.path.splitext(path.name)[0][len('article'):]

def create_df(language):
    dr = Path(f'data/{language}')
    dr_train = dr/'train-articles-subtask-1'
    fl_labels = dr/'train-labels-subtask-1.txt'
    df = pd.DataFrame({'id':[], 'label':[], 'lines':[]})
    articles_files = [dr_train/fl for fl in os.listdir(dr_train)]
    articles_data = []
    mapping = defaultdict(dict)

    for fl in articles_files:
        try:
            with fl.open('r', encoding='utf-8') as f:
                id = get_uid_from_file_path(fl)
                mapping[id]['lines'] = ' '.join(f.readlines()[2:])
        except UnicodeDecodeError:
            print(fl)

    id_label_re = re.compile('(?P<id>\d+)\s*(?P<label>\w+)')
    with fl_labels.open('r', encoding='utf-8') as f:
        for line in f:
            mtch = id_label_re.match(line)
            if not mtch:
                print(f'bad line in labels file. line: {line}. file: {fl_labels}')
            mapping[mtch.group('id')]['label'] = mtch.group('label')

    for id, data in mapping.items():
        df.loc[len(df.index)] = [id, data['label'], data['lines']]
    return  df
        

df = create_df('ru')

le = sk.preprocessing.LabelEncoder()
le.fit(df['label'])



In [28]:
df.head(1)

Unnamed: 0,id,label,lines
0,24100,opinion,"Война догнала их и там, в доме, где они прятал..."


In [29]:
df['label'].value_counts()

opinion      93
reporting    41
satire        8
Name: label, dtype: int64

In [30]:
# bert layers
text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')
preprocessed_text = bert_preprocess(text_input)
outputs = bert_encoder(preprocessed_text)

# # neural network
l = tf.keras.layers.Dropout(0.1, name='dropout')(outputs['pooled_output'])
l = tf.keras.layers.Dense(3, activation='softmax', name='output')(l)

# model
model = tf.keras.Model(inputs=[text_input], outputs=[l])

In [31]:
model.summary()

Model: "model_2"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 text (InputLayer)              [(None,)]            0           []                               
                                                                                                  
 keras_layer_2 (KerasLayer)     {'input_type_ids':   0           ['text[0][0]']                   
                                (None, 128),                                                      
                                 'input_mask': (Non                                               
                                e, 128),                                                          
                                 'input_word_ids':                                                
                                (None, 128)}                                                

In [32]:
model.compile(optimizer='adam',
            loss=tf.keras.losses.SparseCategoricalCrossentropy(),
            metrics='accuracy')

In [33]:
from sklearn.model_selection import train_test_split

x_train, x_test, y_train, y_test = train_test_split(df['lines'], df['label'])
y_train = le.transform(y_train)
y_test = le.transform(y_test)


In [34]:
model.fit(x_train, y_train)



<keras.callbacks.History at 0x1f49702bd60>

In [35]:
model.evaluate(x_test, y_test)



[1.111533284187317, 0.5833333134651184]