## Twitter's who-to-follow 推薦系統部份 (candidate generation)
- [參考](https://blog.twitter.com/engineering/en_us/topics/insights/2022/model-based-candidate-generation-for-account-recommendations)

In [22]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np

In [19]:
# For simplicity, each tower (consumer / producer) uses just two features as an example.

class Tower(layers.Layer):
    """One tower of the two-tower model.

    Every feature gets its own embedding followed by a stack of ReLU dense
    layers. The per-feature outputs are concatenated, passed through a
    512-unit ReLU layer, and projected to a 256-unit ReLU "output" layer,
    which is the tower embedding.

    Args:
        feature_dict: list of feature specs, one per input column. Each is a
            dict with keys 'name', 'input_dim' (vocabulary size) and
            'output_dim' (embedding width). Spec i is applied to column i of
            the tensor passed to ``call``.
        feature_hidden_units: widths of the ReLU dense stack applied to each
            feature embedding (keyword-only; defaults to the original
            512, 512, 256, 256 stack).
        **kwargs: forwarded to ``layers.Layer``.
    """

    def __init__(self, feature_dict, *, feature_hidden_units=(512, 512, 256, 256), **kwargs):
        super(Tower, self).__init__(**kwargs)
        self.nums_features = len(feature_dict)
        self.embeddings = []
        self.denses = []
        self.concatenate = layers.Concatenate()
        self.dense_512 = layers.Dense(units=512, activation='relu')
        self.dense_last = layers.Dense(units=256, activation='relu', name='output')

        for feature in feature_dict:
            # `input_length` is intentionally not passed. It is deprecated in
            # Keras 3 (and rejected by some Keras 3 releases). It is also
            # unnecessary here because each feature is a single id per row
            # (see `call`).
            self.embeddings.append(layers.Embedding(
                input_dim=feature['input_dim'],
                output_dim=feature['output_dim'],
                name=feature['name'],
            ))
            self.denses.append(
                [layers.Dense(units=units, activation='relu') for units in feature_hidden_units]
            )

    def call(self, X):
        """Embed each id column of ``X`` and return the tower embedding.

        Args:
            X: integer id tensor indexed as ``X[:, i]`` for feature i, so it
               must have at least ``len(feature_dict)`` columns. Extra columns
               are ignored.

        Returns:
            Tensor of shape (batch, 256) produced by the final dense layer.
        """
        feature_outputs = []
        for i, (embedding_layer, dense_stack) in enumerate(zip(self.embeddings, self.denses)):
            hidden = embedding_layer(X[:, i])
            for dense in dense_stack:
                hidden = dense(hidden)
            feature_outputs.append(hidden)

        # With a single feature there is nothing to concatenate. Skipping the
        # merge avoids relying on `Concatenate` accepting a one-element list.
        if len(feature_outputs) == 1:
            merged = feature_outputs[0]
        else:
            merged = self.concatenate(feature_outputs)
        return self.dense_last(self.dense_512(merged))
    
    
class TwitterWhoToFollow(keras.Model):
    """Two-tower candidate-generation model: consumer tower vs. producer tower.

    The score for each row is the dot product of the consumer-tower
    embedding and the producer-tower embedding. No final activation is
    applied, so scores are raw.

    Args:
        consumer_feature_dict: feature specs for the consumer tower, in the
            list-of-dicts format accepted by ``Tower``.
        producer_feature_dict: feature specs for the producer tower.
        **kwargs: forwarded to ``keras.Model``.
    """

    def __init__(self, consumer_feature_dict, producer_feature_dict, **kwargs):
        super(TwitterWhoToFollow, self).__init__(**kwargs)
        self.consumer_tower = Tower(consumer_feature_dict)
        self.producer_tower = Tower(producer_feature_dict)
        self.dot_layer = layers.Dot(axes=1)
        # Column counts used by `call` to split a combined input between towers.
        self.num_consumer_features = len(consumer_feature_dict)
        self.num_producer_features = len(producer_feature_dict)

    def call(self, X):
        """Score each row of ``X``.

        Two input layouts are accepted:

        * exactly ``n_consumer + n_producer`` columns: the first
          ``n_consumer`` columns feed the consumer tower, and the remaining
          columns feed the producer tower;
        * any other (or unknown) width, e.g. the 2-column layout used with
          ``build((None, 2))``: the whole ``X`` is fed to both towers, as in
          the original implementation.

        Returns:
            Tensor of shape (batch, 1) holding the dot-product scores.
        """
        n_consumer = self.num_consumer_features
        width = X.shape[-1]
        if width is not None and width == n_consumer + self.num_producer_features:
            consumer_inputs = X[:, :n_consumer]
            producer_inputs = X[:, n_consumer:]
        else:
            # Legacy layout: both towers read the same leading columns of X.
            # Previously this was the only behavior, even though producer
            # features should come from their own columns.
            consumer_inputs = producer_inputs = X

        consumer_tower_embedding = self.consumer_tower(consumer_inputs)
        producer_tower_embedding = self.producer_tower(producer_inputs)

        return self.dot_layer([consumer_tower_embedding, producer_tower_embedding])

In [21]:
embed_dim = 512


# Sizes here are scaled down from the original setup to keep the demo small.
consumer_feature_dict = [
    {'name': feature_name, 'input_dim': vocab_size, 'output_dim': embed_dim}
    for feature_name, vocab_size in (('user_id', 10000), ('interested_in_follow', 1450))
]
producer_feature_dict = [
    {'name': feature_name, 'input_dim': vocab_size, 'output_dim': embed_dim}
    for feature_name, vocab_size in (('geo_counts', 10000), ('known_for', 1450))
]

twitter_model = TwitterWhoToFollow(consumer_feature_dict, producer_feature_dict)
# NOTE(review): with a 2-column input both towers see the same two id columns.
# A faithful setup would pass 4 columns (2 consumer + 2 producer features).
twitter_model.build((None, 2))
twitter_model.summary()

Model: "twitter_who_to_follow_6"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
tower_10 (Tower)             multiple                  7701248   
_________________________________________________________________
tower_11 (Tower)             multiple                  7701248   
_________________________________________________________________
dot_4 (Dot)                  multiple                  0         
Total params: 15,402,496
Trainable params: 15,402,496
Non-trainable params: 0
_________________________________________________________________


In [23]:
# Smoke test: score a batch of 64 rows, each holding 2 random ids drawn from [1, 1000).
batch_data = np.random.randint(1, 1000, (64, 2))
twitter_model(batch_data)

<tf.Tensor: shape=(64, 1), dtype=float32, numpy=
array([[0.00225225],
       [0.00303309],
       [0.00257653],
       [0.00285879],
       [0.00229791],
       [0.00218818],
       [0.00242005],
       [0.00237509],
       [0.0033    ],
       [0.00308561],
       [0.00275155],
       [0.00302385],
       [0.00226284],
       [0.0026786 ],
       [0.00232036],
       [0.00324307],
       [0.00184285],
       [0.0024542 ],
       [0.00322019],
       [0.00325187],
       [0.00245779],
       [0.00310049],
       [0.00264306],
       [0.0030272 ],
       [0.00239984],
       [0.00329036],
       [0.00287322],
       [0.00343102],
       [0.00277571],
       [0.00311583],
       [0.00291544],
       [0.0026752 ],
       [0.00229575],
       [0.00258374],
       [0.00251001],
       [0.00266387],
       [0.00348713],
       [0.00205453],
       [0.00292226],
       [0.00316013],
       [0.00201679],
       [0.00187615],
       [0.00347805],
       [0.00293047],
       [0.00265237],
      