In [1]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import imdb
from tensorflow.keras.preprocessing.sequence import pad_sequences

In [2]:
#parameters
vocab_size = 10000
maxlen = 200
embed_dim =32
num_heads = 2
ff_dim = 32

In [3]:
(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=vocab_size)
x_train = pad_sequences(x_train,maxlen=maxlen)
x_test = pad_sequences(x_test,maxlen=maxlen)

In [4]:
inputs = layers.Input(shape=(maxlen,))

In [5]:
embedding_layer = layers.Embedding(vocab_size,embed_dim)(inputs)

In [6]:
positions = tf.range(start=0, limit=maxlen, delta=1)
position_embeddings = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)(positions)
x = embedding_layer + position_embeddings

In [7]:
attention = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)(x, x)

In [8]:
x = layers.LayerNormalization()(x+attention)

In [10]:
# feed forward network - thinking layer or processing brain
ffn = models.Sequential(
    [layers.Dense(ff_dim, activation="relu"), layers.Dense(embed_dim),]
)

- currently each word is a vector for classification we need one sumarised vector 
- pooling will create a summary of all sentences 

In [11]:
# pooling layer 
x = layers.GlobalAveragePooling1D()(x)

In [12]:
x = layers.Dropout(0.1)(x)
x = layers.Dense(20, activation="relu")(x)
outputs = layers.Dense(1, activation="sigmoid")(x)

In [13]:
model = models.Model(inputs=inputs, outputs=outputs)
model.compile(optimizer="adam", 
            loss="binary_crossentropy", 
            metrics=["accuracy"])
model.summary()

In [15]:
model.fit(
    x_train,
    y_train,
    validation_data=(x_test, y_test),
    epochs=5,
    batch_size=64,
)

Epoch 1/5
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m43s[0m 111ms/step - accuracy: 0.8982 - loss: 0.2493 - val_accuracy: 0.8738 - val_loss: 0.2986
Epoch 2/5
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m42s[0m 107ms/step - accuracy: 0.9324 - loss: 0.1796 - val_accuracy: 0.8666 - val_loss: 0.3308
Epoch 3/5
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m42s[0m 108ms/step - accuracy: 0.9537 - loss: 0.1374 - val_accuracy: 0.8564 - val_loss: 0.3812
Epoch 4/5
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m37s[0m 95ms/step - accuracy: 0.9628 - loss: 0.1101 - val_accuracy: 0.8506 - val_loss: 0.4269
Epoch 5/5
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m36s[0m 92ms/step - accuracy: 0.9740 - loss: 0.0877 - val_accuracy: 0.8386 - val_loss: 0.5402


<keras.src.callbacks.history.History at 0x1a2826298d0>