## Text classification with a Transformer (adapted from https://keras.io/examples/nlp/text_classification_with_transformer/)

In [1]:
import sys
sys.path.append('../input/challenge2021')

import pandas as pd
import numpy as np
import tensorflow as tf
import tensorflow.keras as keras

from keras_transformers import MultiHeadSelfAttention, TransformerBlock, TokenAndPositionEmbedding

from sklearn.utils.class_weight import compute_class_weight

In [3]:
def _load_split(subset):
    """Load one side of the seeded split of the training directory.

    Both sides are requested with the same directory, batch size,
    validation fraction (0.25) and seed (1234); only `subset` differs.
    """
    return keras.preprocessing.text_dataset_from_directory(
        '../input/challenge2021/train/train',
        batch_size=8,
        validation_split=0.25,
        subset=subset,
        seed=1234,
    )


# The "training" side is later divided into train/val; the "validation" side
# of the same directory is kept aside under the name `test`.
trainval = _load_split("training")
test = _load_split("validation")

# Report how many batches each split contains.
print("Number of batches in trainval: %d" % tf.data.experimental.cardinality(trainval))
print("Number of batches in test: %d" % tf.data.experimental.cardinality(test))

Found 217197 files belonging to 28 classes.
Using 162898 files for training.
Found 217197 files belonging to 28 classes.
Using 54299 files for validation.
Number of batches in trainval: 20363
Number of batches in test: 6788


In [4]:
# Number of (shuffled) batches reserved for validation.
n_val_batches = 6666

# Shuffle the batches ONCE, with a fixed seed, before carving out validation.
# `Dataset.shuffle` defaults to `reshuffle_each_iteration=True`, which draws a
# new order every time the dataset is iterated; `take`/`skip` would then pick
# different batches on each pass and train and validation would overlap.
# NOTE(review): the loader that produced `trainval` may apply its own
# per-iteration shuffling upstream -- confirm the batches themselves are stable.
trainval = trainval.shuffle(220000, seed=1234, reshuffle_each_iteration=False)

# First `n_val_batches` batches -> validation, everything after -> training.
val = trainval.take(n_val_batches)
train = trainval.skip(n_val_batches)

print(
    "Number of batches in train: %d"
    % tf.data.experimental.cardinality(train)
)
print(
    "Number of batches in val: %d" % tf.data.experimental.cardinality(val)
)

Number of batches in train: 13697
Number of batches in val: 6666


In [5]:
# Peek at the first five (text, label) pairs of one training batch.
for text_batch, label_batch in train.take(1):
    # Convert each tensor once instead of on every loop step.
    texts, labels = text_batch.numpy(), label_batch.numpy()
    for idx in range(5):
        print(texts[idx])
        print(labels[idx])

KeyboardInterrupt: 

In [6]:
# Import from `tensorflow.keras` (the package this notebook aliases as `keras`)
# rather than the standalone `keras` distribution, so every Keras object in
# the notebook comes from the same implementation.
from tensorflow.keras.layers.experimental.preprocessing import TextVectorization

# Model constants.
max_features = 20000  # vocabulary size cap, passed to the layer as `max_tokens`
# NOTE(review): `embedding_dim` is not referenced by the visible cells (the
# model cell defines its own `embed_dim`) -- confirm whether it is still needed.
embedding_dim = 128
sequence_length = 500  # passed to the layer as `output_sequence_length`

# Text -> integer token ids: lower-case, strip punctuation, then index tokens.
vectorize_layer = TextVectorization(
    standardize='lower_and_strip_punctuation',
    max_tokens=max_features,
    output_mode="int",
    output_sequence_length=sequence_length,
)

In [7]:
# Let's make a text-only dataset (no labels):
# `adapt` is given text only, so drop the label from every training batch.
text_ds = train.map(lambda texts, _labels: texts)
# Fit the vectorization layer on the training text.
vectorize_layer.adapt(text_ds)

In [8]:
def vectorize_text(text, label):
    """Return `(vectorize_layer(text), label)` for one dataset element.

    `text` gets a trailing axis added before it is passed to the layer;
    `label` is passed through unchanged.
    """
    text_column = tf.expand_dims(text, -1)
    token_ids = vectorize_layer(text_column)
    return token_ids, label

# For every split: vectorize the text, cache the vectorized batches, and
# prefetch up to 10 batches so input preparation overlaps with model compute.
train_ds = train.map(vectorize_text).cache().prefetch(buffer_size=10)
val_ds = val.map(vectorize_text).cache().prefetch(buffer_size=10)
test_ds = test.map(vectorize_text).cache().prefetch(buffer_size=10)

In [22]:
# Take layers/regularizers from `tensorflow.keras` -- the same package that
# provides `keras.Model` below -- instead of the standalone `keras` package.
from tensorflow.keras import layers, regularizers

embed_dim = 32  # Embedding size for each token
num_heads = 2  # Number of attention heads
ff_dim = 32  # Hidden layer size in feed forward network inside transformer
num_classes = 28  # Size of the softmax output; the loader reports 28 classes

# Token ids -> token+position embeddings -> one transformer block.
inputs = layers.Input(shape=(sequence_length,))
embedding_layer = TokenAndPositionEmbedding(sequence_length, max_features, embed_dim)
x = embedding_layer(inputs)
transformer_block = TransformerBlock(embed_dim, num_heads, ff_dim)
x = transformer_block(x)

# Average over the sequence axis, then a small L2-regularized classifier head.
x = layers.GlobalAveragePooling1D()(x)
x = layers.Dropout(0.1)(x)
x = layers.Dense(20, activation="relu", kernel_regularizer=regularizers.l2(0.001))(x)
x = layers.Dropout(0.1)(x)
outputs = layers.Dense(
    num_classes,
    activation="softmax",
    kernel_regularizer=regularizers.l2(0.001),
)(x)

model = keras.Model(inputs=inputs, outputs=outputs)
model.summary()

Model: "functional_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_4 (InputLayer)         [(None, 500)]             0         
_________________________________________________________________
token_and_position_embedding (None, 500, 32)           656000    
_________________________________________________________________
transformer_block_3 (Transfo (None, 500, 32)           6464      
_________________________________________________________________
global_average_pooling1d_3 ( (None, 32)                0         
_________________________________________________________________
dropout_13 (Dropout)         (None, 32)                0         
_________________________________________________________________
dense_28 (Dense)             (None, 20)                660       
_________________________________________________________________
dropout_14 (Dropout)         (None, 20)               

In [23]:
# Both callbacks watch the validation loss:
#   - EarlyStopping: patience of 2 epochs, restoring the best weights;
#   - ReduceLROnPlateau: patience of 1 epoch, factor 0.2, floor of 1e-5.
callbacks = [
    keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=2,
        restore_best_weights=True,
    ),
    keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.2,
        patience=1,
        min_lr=1e-5,
    ),
]

model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=10,
    callbacks=callbacks,
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10


In [24]:
from sklearn.metrics import classification_report

# Predicted class = index of the largest model output along axis 1.
y_pred = tf.argmax(model.predict(test_ds), axis=1).numpy()
y_pred

array([13, 11, 24, ...,  3, 11, 15])

In [25]:
# Gather the labels of every test batch, in iteration order, into one array.
y_true = np.concatenate(
    [labels for _features, labels in test_ds],
    axis=0,
)
y_true

array([13, 11, 24, ...,  3, 11, 15], dtype=int32)

In [26]:
# Per-class precision / recall / F1 of the predictions on the test split.
print(classification_report(y_true=y_true, y_pred=y_pred))

              precision    recall  f1-score   support

           0       0.51      0.53      0.52       346
           1       0.65      0.70      0.67      1042
           2       0.75      0.47      0.58       212
           3       0.69      0.68      0.69      2927
           4       0.75      0.58      0.65       424
           5       0.67      0.54      0.60      1003
           6       0.84      0.81      0.83      3174
           7       0.67      0.70      0.68      1142
           8       0.89      0.93      0.91      1319
           9       0.71      0.61      0.66       358
          10       0.77      0.66      0.71      1066
          11       0.86      0.86      0.86     17460
          12       0.66      0.68      0.67       220
          13       0.78      0.83      0.80      3642
          14       0.53      0.61      0.56       183
          15       0.69      0.71      0.70      2536
          16       0.70      0.47      0.56       256
          17       0.56    