In [0]:
import os
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import ModelCheckpoint
from kaggle_datasets import KaggleDatasets
import transformers
from tqdm import tqdm

from tokenizers import BertWordPieceTokenizer



# Настройка TPU

In [0]:
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Running on TPU ', tpu.master())
except ValueError:
    tpu = None

if tpu:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
else:
    strategy = tf.distribute.get_strategy()

print("REPLICAS: ", strategy.num_replicas_in_sync)

Running on TPU  grpc://10.0.0.2:8470
REPLICAS:  8


# Загрузка данных

In [0]:
train = pd.read_csv("/kaggle/input/jigsaw-multilingual-toxic-comment-classification/jigsaw-toxic-comment-train.csv")
test = pd.read_csv('/kaggle/input/jigsaw-multilingual-toxic-comment-classification/test.csv')
valid = pd.read_csv('/kaggle/input/jigsaw-multilingual-toxic-comment-classification/validation.csv')
sub = pd.read_csv('/kaggle/input/jigsaw-multilingual-toxic-comment-classification/sample_submission.csv')

# Tokenizer для BERT

In [0]:
tokenizer = transformers.DistilBertTokenizer.from_pretrained('distilbert-base-multilingual-cased')
tokenizer.save_pretrained('.')
fast_tokenizer = BertWordPieceTokenizer('vocab.txt', lowercase=False)
fast_tokenizer

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=995526.0, style=ProgressStyle(descripti…




Tokenizer(vocabulary_size=119547, model=BertWordPiece, unk_token=[UNK], sep_token=[SEP], cls_token=[CLS], pad_token=[PAD], mask_token=[MASK], clean_text=True, handle_chinese_chars=True, strip_accents=True, lowercase=False, wordpieces_prefix=##)

# Fast Encoder для BERT

In [0]:
def fast_encode(texts, tokenizer, chunk_size=256, maxlen=512):
    tokenizer.enable_truncation(max_length=maxlen)
    tokenizer.enable_padding(max_length=maxlen)
    all_ids = []
    for i in tqdm(range(0, len(texts), chunk_size)):
        text_chunk = texts[i:i+chunk_size].tolist()
        encoded_chunks = tokenizer.encode_batch(text_chunk)
        all_ids.extend([encoded_chunk.ids for encoded_chunk in encoded_chunks])
    return np.array(all_ids)

# Предобработка и кодирование исходных данных

In [0]:
MAX_LEN = 192
BATCH_SIZE = 16 * strategy.num_replicas_in_sync

x_train = fast_encode(train.comment_text.astype(str), fast_tokenizer, maxlen=MAX_LEN)
x_valid = fast_encode(valid.comment_text.astype(str), fast_tokenizer, maxlen=MAX_LEN)
x_test = fast_encode(test.content.astype(str), fast_tokenizer, maxlen=MAX_LEN)
y_train = train.toxic.values
y_valid = valid.toxic.values

100%|██████████| 874/874 [00:28<00:00, 30.53it/s]
100%|██████████| 32/32 [00:01<00:00, 28.16it/s]
100%|██████████| 250/250 [00:11<00:00, 22.50it/s]


In [0]:
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).repeat().shuffle(2048).batch(BATCH_SIZE).prefetch(AUTO)
valid_dataset = tf.data.Dataset.from_tensor_slices((x_valid, y_valid)).batch(BATCH_SIZE).cache().prefetch(AUTO)
test_dataset = tf.data.Dataset.from_tensor_slices(x_test).batch(BATCH_SIZE)

# Training

In [0]:
def build_model(transformer, max_len=512):
    input_ids = Input(shape=(max_len,), dtype=tf.int32, name="input_ids")
    sequence_output = transformer(input_ids)[0]
    cls_token = sequence_output[:, 0, :]
    out = Dense(1, activation='relu')(cls_token)
    model = Model(inputs=input_ids, outputs=out)
    model.compile(Adam(lr=1e-5), loss='binary_crossentropy', metrics=['accuracy'])
    return model

In [0]:
%%time
with strategy.scope():
    transformer_layer = transformers.TFElectraModel.from_pretrained('google/electra-small-discriminator')
    model = build_model(transformer_layer, max_len=MAX_LEN)
model.summary()

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=466.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=54466044.0, style=ProgressStyle(descrip…


Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_word_ids (InputLayer)  [(None, 192)]             0         
_________________________________________________________________
tf_electra_model (TFElectraM ((None, 192, 256),)       13483008  
_________________________________________________________________
tf_op_layer_strided_slice (T [(None, 256)]             0         
_________________________________________________________________
dense (Dense)                (None, 1)                 257       
Total params: 13,483,265
Trainable params: 13,483,265
Non-trainable params: 0
_________________________________________________________________
CPU times: user 13.3 s, sys: 2.22 s, total: 15.5 s
Wall time: 16.7 s


In [0]:
n_steps = x_train.shape[0] // BATCH_SIZE
train_history = model.fit(train_dataset, steps_per_epoch=n_steps, validation_data=valid_dataset, epochs=3)

Epoch 1/3
Epoch 2/3
Epoch 3/3

In [0]:
n_steps = x_valid.shape[0] // BATCH_SIZE
train_history_2 = model.fit(valid_dataset.repeat(), steps_per_epoch=n_steps, epochs=6)

Epoch 1/6
Epoch 2/6
Epoch 3/6
Epoch 4/6
Epoch 5/6
Epoch 6/6


In [0]:
sub['toxic'] = model.predict(test_dataset, verbose=1)
sub.to_csv('submission.csv', index=False)

