In [116]:
import numpy as np
from tensorflow.keras.preprocessing.text import one_hot
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Embedding

In [117]:
reviews = ['nice food',
        'amazing restaurant',
        'too good',
        'just loved it!',
        'will go again',
        'horrible food',
        'never go there',
        'poor service',
        'poor quality',
        'needs improvement']

sentiment = np.array([1,1,1,1,1,0,0,0,0,0])

In [118]:
one_hot('amazing restaurant',30)

[6, 14]

In [119]:
vocab_size = 30
encoded_reviews = [one_hot(d, vocab_size) for d in reviews]
print(encoded_reviews)

[[21, 22], [6, 14], [12, 8], [27, 25, 12], [24, 5, 24], [24, 22], [4, 5, 12], [7, 1], [7, 17], [8, 18]]


In [120]:
max_length = 3
padded_reviews = pad_sequences(encoded_reviews, maxlen = max_length, padding = 'post')
print(padded_reviews)

[[21 22  0]
 [ 6 14  0]
 [12  8  0]
 [27 25 12]
 [24  5 24]
 [24 22  0]
 [ 4  5 12]
 [ 7  1  0]
 [ 7 17  0]
 [ 8 18  0]]


In [121]:
embeded_vector_size = 4

model = Sequential()
model.add(Embedding(vocab_size, embeded_vector_size, input_length = max_length, name = 'embedding'))
model.add(Flatten())
model.add(Dense(1, activation = 'sigmoid'))

In [122]:
X = padded_reviews
y = sentiment

In [123]:
model.compile(optimizer = 'adam', loss = 'binary_crossentropy', metrics = ['accuracy'])

In [145]:
model.fit(X, y, epochs = 10, verbose = 0)

<keras.src.callbacks.history.History at 0x2149d38dbe0>

In [147]:
model.summary()

In [151]:
loss, accuracy = model.evaluate(X, y)
accuracy

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 68ms/step - accuracy: 0.8000 - loss: 0.6389


0.800000011920929

In [153]:
weights = model.get_layer('embedding').get_weights()[0]
len(weights)

30

In [155]:
weights[21]

array([-0.04502627, -0.04857473, -0.07578903, -0.03594675], dtype=float32)

In [157]:
weights[22]

array([-0.02628585, -0.01714459,  0.01360437, -0.00998985], dtype=float32)