In [1]:
import tensorflow as tf
import numpy as np


num_users = 1000
num_items = 500
num_samples = 10000
embedding_size = 50

# Fixed seed so the synthetic dataset (and therefore the logged losses) is
# identical from run to run; the original drew from the unseeded global RNG.
SEED = 42
_rng = np.random.default_rng(SEED)


def _random_interactions(rng, n):
    """Draw ``n`` synthetic (user_id, item_id, rating) triples from ``rng``.

    User ids are uniform integers in ``[0, num_users)``, item ids uniform in
    ``[0, num_items)``, and ratings uniform integers in ``[1, 5]``. Ratings
    are drawn independently of the ids, so the data carries no real
    user/item signal for a model to learn.

    Returns:
        Tuple ``(user_ids, item_ids, ratings)`` of 1-D integer arrays of
        length ``n``.
    """
    user_ids = rng.integers(0, num_users, n)
    item_ids = rng.integers(0, num_items, n)
    ratings = rng.integers(1, 6, n)  # high bound is exclusive -> 1..5
    return user_ids, item_ids, ratings


# Train/val/test splits are drawn independently from the same seeded stream.
user_ids_train, item_ids_train, ratings_train = _random_interactions(_rng, num_samples)
user_ids_val, item_ids_val, ratings_val = _random_interactions(_rng, num_samples)
user_ids_test, item_ids_test, ratings_test = _random_interactions(_rng, num_samples)


class CollaborativeFilteringModel(tf.keras.Model):
    """Matrix-factorisation recommender: score = <user_vector, item_vector>.

    Each user id and each item id is mapped to a learned vector of length
    ``embedding_size``. The predicted score for a (user, item) pair is the
    dot product of the two vectors. The model has no bias terms; the output
    is purely that dot product.

    Args:
        num_users: Size of the user-id vocabulary (ids ``0..num_users-1``).
        num_items: Size of the item-id vocabulary (ids ``0..num_items-1``).
        embedding_size: Length of each user/item embedding vector.
        **kwargs: Forwarded to ``tf.keras.Model`` (e.g. ``name=``).
    """

    def __init__(self, num_users, num_items, embedding_size, **kwargs):
        super().__init__(**kwargs)
        self.user_embedding = tf.keras.layers.Embedding(num_users, embedding_size)
        self.item_embedding = tf.keras.layers.Embedding(num_items, embedding_size)
        # Contract over the last (embedding) axis. With axes=1 this only works
        # for 1-D id batches: ids shaped (batch, 1) produce embeddings of shape
        # (batch, 1, embedding_size), and contracting axis 1 then yields a
        # (batch, embedding_size, embedding_size) outer product instead of a
        # score. For 1-D ids, axis -1 and axis 1 are the same axis.
        self.dot = tf.keras.layers.Dot(axes=-1)

    def call(self, inputs):
        """Score a batch of (user, item) pairs.

        Args:
            inputs: A pair ``(user_id, item_id)`` of integer id tensors with
                matching batch shape.

        Returns:
            The dot product of the looked-up embeddings; for 1-D id inputs
            this has shape ``(batch, 1)``.
        """
        user_id, item_id = inputs
        user_vec = self.user_embedding(user_id)
        item_vec = self.item_embedding(item_id)
        return self.dot([user_vec, item_vec])


# Seed TensorFlow, NumPy and Python's `random` so weight initialisation and
# per-epoch shuffling repeat from run to run. Bit-exact determinism on GPU
# additionally requires tf.config.experimental.enable_op_determinism().
tf.keras.utils.set_random_seed(42)

model = CollaborativeFilteringModel(num_users, num_items, embedding_size)
# Regression on the raw rating values with plain mean squared error.
model.compile(optimizer='adam', loss='mean_squared_error')

# The model takes its inputs as a [user_ids, item_ids] pair.
history = model.fit(
    [user_ids_train, item_ids_train],
    ratings_train,
    validation_data=([user_ids_val, item_ids_val], ratings_val),
    epochs=10,
    batch_size=64,
)

# Only a loss was compiled (no extra metrics), so evaluate() returns a single
# scalar: the mean squared error on the test split.
loss = model.evaluate([user_ids_test, item_ids_test], ratings_test)
print("Test Loss:", loss)


Epoch 1/10
157/157 ━━━━━━━━━━━━━━━━━━━━ 2s 6ms/step - loss: 10.8765 - val_loss: 11.0112
Epoch 2/10
157/157 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 10.9314 - val_loss: 10.9912
Epoch 3/10
157/157 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 10.7394 - val_loss: 10.8475
Epoch 4/10
157/157 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 10.2898 - val_loss: 10.2612
Epoch 5/10
157/157 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 9.1155 - val_loss: 8.8835
Epoch 6/10
157/157 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 7.0420 - val_loss: 6.8778
Epoch 7/10
157/157 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - loss: 4.9461 - val_loss: 4.9611
Epoch 8/10
157/157 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 3.1231 - val_loss: 3.6747
Epoch 9/10
[1m157/157[0m [32m