In [1]:
import numpy as np
from tensorflow.keras.preprocessing.text import one_hot
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Embedding

# Toy corpus of ten short restaurant reviews. The two groups are kept in
# separate lists so the labels below can be derived from the group sizes
# instead of being typed out by hand.
positive_reviews = [
    'nice food',
    'amazing restaurant',
    'too good',
    'just loved it!',
    'will go again',
]
negative_reviews = [
    'horrible food',
    'never go there',
    'poor service',
    'poor quality',
    'needs improvement',
]
reviews = positive_reviews + negative_reviews

# One label per entry of `reviews`, in the same order:
# 1 for the first group, 0 for the second.
sentiment = np.array([1] * len(positive_reviews) + [0] * len(negative_reviews))

In [2]:
# Quick demo: encode a two-word review into integer indices below 30.
# NOTE(review): per the Keras docs one_hot hashes words, so the mapping is not
# guaranteed unique and the indices may differ between kernel sessions -- confirm.
one_hot("amazing restaurant", n=30)

[15, 12]

In [3]:
# Size of the index space that words are hashed into; also reused as the
# Embedding layer's input dimension further down.
vocab_size = 30

# Encode every review as a list of integer word indices.
# In the recorded output, index 23 shows up in four different reviews
# ('too good', 'will go again', 'horrible food', 'never go there'), i.e.
# different words can collide on the same index with this vocab_size.
encoded_reviews = list(map(lambda text: one_hot(text, vocab_size), reviews))
print(encoded_reviews)

[[11, 16], [15, 12], [23, 6], [17, 22, 5], [25, 23, 27], [23, 16], [28, 23, 17], [18, 21], [18, 8], [20, 5]]


In [4]:
# Fixed sequence length fed to the model; shorter reviews are right-padded
# with zeros ("post" padding) so every row has max_length entries.
max_length = 4
padded_reviews = pad_sequences(
    sequences=encoded_reviews,
    maxlen=max_length,
    padding="post",
)
print(padded_reviews)

[[11 16  0  0]
 [15 12  0  0]
 [23  6  0  0]
 [17 22  5  0]
 [25 23 27  0]
 [23 16  0  0]
 [28 23 17  0]
 [18 21  0  0]
 [18  8  0  0]
 [20  5  0  0]]


In [5]:
# Number of dimensions learned for each word index.
embedded_vector_size = 5

# Embedding -> Flatten -> single sigmoid unit.
# The embedding layer is named so its weights can be fetched by name later.
model = Sequential([
    Embedding(
        vocab_size,
        embedded_vector_size,
        input_length=max_length,
        name="embedding",
    ),
    # Collapse the (max_length, embedded_vector_size) output into one vector per review.
    Flatten(),
    # One sigmoid output for the binary label.
    Dense(1, activation="sigmoid"),
])

Metal device set to: Apple M1


2023-02-10 15:18:55.842498: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2023-02-10 15:18:55.842977: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


In [6]:
# Conventional feature/label aliases: padded index sequences and their 0/1 labels.
X, y = padded_reviews, sentiment

In [7]:
# Configure training: Adam optimizer, binary cross-entropy loss (matches the
# single sigmoid output and the 0/1 labels), and accuracy as the reported metric.
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

# summary() prints the layer table itself and returns None, so wrapping it in
# print() only adds a stray "None" line after the table. Call it directly.
model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 embedding (Embedding)       (None, 4, 5)              150       
                                                                 
 flatten (Flatten)           (None, 20)                0         
                                                                 
 dense (Dense)               (None, 1)                 21        
                                                                 
Total params: 171
Trainable params: 171
Non-trainable params: 0
_________________________________________________________________
None


In [8]:
# Train silently for 50 epochs on the full toy dataset.
# The returned History object is bound to a name so the per-epoch metrics in
# history.history stay available, and so the cell does not echo the object's
# repr as its output.
history = model.fit(X, y, epochs=50, verbose=0)

2023-02-10 15:20:08.425828: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz
2023-02-10 15:20:08.711326: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.


<keras.callbacks.History at 0x28cf0a820>

In [9]:
# Score the fitted model on X and y -- the same arrays that were passed to
# fit(), so this is training-set performance, not a held-out estimate.
eval_results = model.evaluate(X, y)
loss, accuracy = eval_results
accuracy



2023-02-10 15:20:37.932660: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.


1.0

In [10]:
# Fetch the layer by the name given at construction time and take the first
# array from get_weights(); per the summary it holds vocab_size x embedded_vector_size values.
embedding_layer = model.get_layer("embedding")
weights = embedding_layer.get_weights()[0]

# Number of rows, i.e. one embedding vector per index.
len(weights)

30

In [11]:
# Embedding vector stored for index 13.
# NOTE(review): index 13 does not occur anywhere in the recorded encoded
# reviews, so it is unclear which word (if any) this row is meant to show -- confirm.
weights[13, :]

array([-0.00194472, -0.0338334 ,  0.02405417,  0.02805239,  0.00816403],
      dtype=float32)

In [12]:
# Embedding vector stored for index 4.
# NOTE(review): index 4 does not occur anywhere in the recorded encoded
# reviews, so it is unclear which word (if any) this row is meant to show -- confirm.
weights[4, :]

array([ 0.03067657, -0.01464423, -0.00392935, -0.04838952,  0.0167723 ],
      dtype=float32)

In [13]:
# Embedding vector stored for index 16.
# In the recorded encoding, 16 is the second token of both 'nice food' and
# 'horrible food'; NOTE(review): the index may change in another session -- confirm.
weights[16, :]

array([-0.01610599,  0.04544757, -0.00028907, -0.01158243,  0.01714518],
      dtype=float32)

In [14]:
# Run the trained model on the training inputs; each row is the output of the
# final sigmoid unit for the corresponding review.
model.predict(x=X)



2023-02-10 15:29:43.127053: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.


array([[0.5118633 ],
       [0.5145117 ],
       [0.5143169 ],
       [0.5753499 ],
       [0.5419775 ],
       [0.4887682 ],
       [0.4625665 ],
       [0.43578967],
       [0.4495697 ],
       [0.44996572]], dtype=float32)