<a href="https://colab.research.google.com/github/Muzhi1920/awesome-models/blob/main/05-%E7%89%B9%E5%BE%81%E4%BA%A4%E4%BA%92/01_CAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Co-Action
1. 参考知乎：https://zhuanlan.zhihu.com/p/287898562
2. 论文：https://arxiv.org/abs/2011.05625


- 将Target item的embedding切片并reshape成一个小型MLP的各层权重（及可选的偏置）；
- 将历史兴趣id的embedding（及其逐元素的高阶幂）作为输入，送入该MLP做前向计算；
- 将MLP各层的输出拼接，并在序列维度上求和，作为co-action交互输出。

In [None]:
import tensorflow as tf
from sequence_feature_layer import SequenceFeatures
from tensorflow import feature_column as fc
from tensorflow.keras.layers import Layer, Dense, LayerNormalization, Dropout, Embedding, Conv1D

## 0.CAN模型配置

In [None]:
can_config = dict(
    # [input_dim, output_dim] of every layer of the MLP that build_mlp
    # reshapes out of the target embedding.
    target_emb_w=[[16, 8], [8, 4]],
    # Bias width per layer; 0 means the layer gets no bias.
    target_emb_b=[0, 0],
    # False: every order reuses one parameter set; True: one set per order.
    order_indep=False,
    # Number of element-wise powers (x, x^2, ..., x^orders) fed through the MLP.
    orders=3,
)

## 1.CAN模型build与train

In [None]:
def build_mlp(target_emb):
    """Slice the target embedding into the weights and biases of a small MLP.

    Layer shapes are read from the module-level ``can_config``:
    ``target_emb_w`` lists ``[in_dim, out_dim]`` per layer and ``target_emb_b``
    lists the bias width per layer (``0`` means no bias). When
    ``order_indep`` is true, ``orders`` consecutive parameter sets are sliced
    out; otherwise a single set is produced.

    Args:
        target_emb: tensor indexed as ``[batch, steps, emb_dim]``; axis 1 is
            summed away before slicing.

    Returns:
        ``(weight_orders, bias_orders)``: two lists with one entry per
        parameter set. Each entry is a list with one item per layer: a
        ``[batch, in_dim, out_dim]`` weight tensor, and either a
        ``[batch, 1, bias_dim]`` bias tensor or ``None``.

    Raises:
        ValueError: if the statically known embedding width is smaller than
            what the configured layers need.
    """
    order_indep, orders = can_config['order_indep'], can_config['orders']
    weight_emb_w, weight_emb_b = can_config['target_emb_w'], can_config['target_emb_b']

    target_emb = tf.reduce_sum(target_emb, axis=1)

    # Number of embedding columns consumed by one parameter set.
    per_set = sum(w[0] * w[1] + b for w, b in zip(weight_emb_w, weight_emb_b))
    num_sets = orders if order_indep else min(orders, 1)

    # Out-of-range slices are silently truncated, and the following
    # reshape(-1, ...) could then fold batch rows into the wrong shape.
    # Fail early with a clear message when the width is known.
    emb_dim = target_emb.shape[-1]
    if emb_dim is not None and emb_dim < per_set * num_sets:
        raise ValueError(
            'target embedding width %s is too small: %d parameter set(s) of %d '
            'values each are required' % (emb_dim, num_sets, per_set))

    weight_orders, bias_orders = [], []
    idx = 0
    for _ in range(num_sets):
        weight, bias = [], []
        for w, b in zip(weight_emb_w, weight_emb_b):
            size = w[0] * w[1]
            weight.append(tf.reshape(target_emb[:, idx:idx + size], [-1, w[0], w[1]]))
            idx += size
            if b == 0:
                bias.append(None)
            else:
                bias.append(tf.reshape(target_emb[:, idx:idx + b], [-1, 1, b]))
                idx += b
        weight_orders.append(weight)
        bias_orders.append(bias)
    return weight_orders, bias_orders

def CAN(weight_orders, bias_orders, co_action_feature, mask=None):
    """Run the co-action forward pass through the generated MLP.

    The input is raised element-wise to the powers ``1 .. orders`` (``orders``
    comes from the module-level ``can_config``). Each power is pushed through
    the MLP, with ``tanh`` after every layer except the last. The outputs of
    all layers and all orders are concatenated on the last axis and summed
    over axis 1.

    Args:
        weight_orders: list of parameter sets, each a list of per-layer weight
            tensors (as returned by ``build_mlp``).
        bias_orders: list with the same layout holding bias tensors or ``None``.
        co_action_feature: rank-3 tensor ``[batch, steps, in_dim]`` multiplied
            against the first layer's weights.
        mask: optional ``[batch, steps]`` tensor; steps with a zero/False entry
            are excluded from the sum. Any dtype castable to the output dtype
            is accepted.

    Returns:
        Tensor of shape ``[batch, total_out_dim]``.
    """
    orders, order_indep = can_config['orders'], can_config['order_indep']

    # x, x^2, ..., x^orders
    inputs = [tf.math.pow(co_action_feature, i + 1.0) for i in range(orders)]

    out_seq = []
    for i, h in enumerate(inputs):
        # Without independent orders every power reuses parameter set 0.
        set_idx = i if order_indep else 0
        weight, bias = weight_orders[set_idx], bias_orders[set_idx]
        last = len(weight) - 1
        for j, (w, b) in enumerate(zip(weight, bias)):
            h = tf.matmul(h, w)
            if b is not None:
                h = h + b
            if j != last:
                h = tf.nn.tanh(h)
            # Every layer's output (not just the final one) is kept.
            out_seq.append(h)
    out_seq = tf.concat(out_seq, 2)

    if mask is not None:
        # Cast first: a bool/int mask (tf.sequence_mask's default is bool)
        # cannot be multiplied with a float tensor.
        mask = tf.expand_dims(tf.cast(mask, out_seq.dtype), axis=-1)
        out_seq = out_seq * mask
    return tf.reduce_sum(out_seq, 1)


## 2.准备工作

### 2.1 input_layer

In [None]:
# Width of the target embedding: it has to hold every weight matrix and bias
# that build_mlp slices out of it.
target_emb_size = sum(w[0] * w[1] for w in can_config['target_emb_w']) + sum(can_config['target_emb_b'])
if can_config['order_indep']:
    # With independent orders build_mlp consumes one full parameter set per order.
    target_emb_size *= can_config['orders']
# History embeddings are multiplied against the first generated layer, so
# their width is that layer's input size.
seq_emb_size = can_config['target_emb_w'][0][0]

seq = fc.sequence_categorical_column_with_hash_bucket('seq', hash_bucket_size=10, dtype=tf.int64)
target = fc.sequence_categorical_column_with_hash_bucket('target', hash_bucket_size=10, dtype=tf.int64)
seq_col = fc.embedding_column(seq, dimension=seq_emb_size)
target_col = fc.embedding_column(target, dimension=target_emb_size)
columns = [seq_col, target_col]

# Toy batch of 3 examples: history lengths 2, 2 and 1, and one target id each.
features = {
    "seq": tf.sparse.SparseTensor(
        indices=[[0, 0], [0, 1], [1, 0], [1, 1], [2, 0]],
        values=[1100, 1101, 1102, 1101, 1103],
        dense_shape=[3, 2]),
    "target": tf.sparse.SparseTensor(
        indices=[[0, 0], [1, 0], [2, 0]],
        values=[1102, 1103, 1100],
        dense_shape=[3, 1]),
}
tf.sparse.to_dense(features['seq'])

<tf.Tensor: shape=(3, 2), dtype=int32, numpy=
array([[1100, 1101],
       [1102, 1101],
       [1103,    0]], dtype=int32)>

### 2.2 序列与target的embedding获取

In [None]:
# Look up the embeddings of both sequence columns in one pass.
sequence_feature_layer = SequenceFeatures(columns, name='sequence_features_input_layer')
sequence_inputs, sequence_lengths = sequence_feature_layer(features)

# Both results are dicts keyed by '<column>_embedding'.
sequence_input, sequence_length = sequence_inputs['seq_embedding'], sequence_lengths['seq_embedding']
target_input, target_length = sequence_inputs['target_embedding'], sequence_lengths['target_embedding']

tf.shape(sequence_input), tf.shape(target_input), sequence_length

(<tf.Tensor: shape=(3,), dtype=int32, numpy=array([ 3,  2, 16], dtype=int32)>,
 <tf.Tensor: shape=(3,), dtype=int32, numpy=array([  3,   1, 160], dtype=int32)>,
 <tf.Tensor: shape=(3,), dtype=int64, numpy=array([2, 2, 1])>)

### 2.3 build target mlp

In [None]:
weights, biases = build_mlp(target_emb=target_input)
# Show only the shapes of the generated weights (one inner list per parameter
# set) instead of dumping every tensor value into the cell output.
[[w.shape.as_list() for w in layer_weights] for layer_weights in weights]

[[<tf.Tensor: shape=(3, 16, 8), dtype=float32, numpy=
  array([[[ 0.10137711, -0.06435817, -0.13045818, -0.07548644,
           -0.02454926,  0.00246648,  0.11344709, -0.09588721],
          [ 0.01856251, -0.06053852, -0.03210662, -0.10088634,
           -0.02419616,  0.08359312, -0.03179131,  0.11366984],
          [-0.13051239, -0.10455967,  0.03955878,  0.02983663,
           -0.03758518, -0.13609728,  0.01351203,  0.08538082],
          [-0.00140181,  0.10369221,  0.12698841, -0.09246264,
           -0.02612645,  0.01475628, -0.04308092,  0.08011487],
          [-0.01934592,  0.1343469 ,  0.01589933,  0.06382099,
            0.0210166 , -0.10394734, -0.00878059,  0.05687911],
          [ 0.03829439,  0.15743534,  0.08446892, -0.10956515,
            0.10264438, -0.01690628, -0.02500452,  0.05679527],
          [ 0.04558031, -0.0664928 , -0.0652179 ,  0.03596425,
           -0.01366829, -0.04426897,  0.07075365, -0.1099182 ],
          [ 0.00939228,  0.06575622, -0.02010042, -0.0225

### 2.4 MLP前向计算（co-action）

In [None]:
# Mask out padded history steps (per-example lengths come from the feature
# layer) so they cannot leak into the summed output, e.g. once a bias is
# configured. The mask is built in the input's float dtype so it can be
# multiplied directly.
can_output = CAN(
    weight_orders=weights,
    bias_orders=biases,
    co_action_feature=sequence_input,
    mask=tf.sequence_mask(sequence_length, maxlen=tf.shape(sequence_input)[1], dtype=sequence_input.dtype),
)

## 3.CAN交互输出

In [None]:
# Display the co-action result: CAN concatenates every layer output of every
# order on the last axis and sums over the sequence axis, giving one row per example.
can_output

<tf.Tensor: shape=(3, 36), dtype=float32, numpy=
array([[ 5.01778349e-02, -2.00935975e-01, -8.35971311e-02,
         1.93671770e-02, -3.70933488e-02,  5.56501448e-02,
         7.40525723e-02, -1.94012105e-01,  2.69804671e-02,
         1.94713827e-02, -1.02869282e-02, -2.38376167e-02,
         5.20518720e-02,  4.39009368e-02, -2.97014918e-02,
        -7.69861927e-03,  2.03943532e-02, -4.41536978e-02,
         3.47796306e-02, -1.21500455e-02, -9.49517824e-03,
         6.89352117e-03, -2.41043884e-03,  4.77566675e-04,
         1.42347012e-02, -1.06823416e-02, -1.42407846e-02,
        -2.41869153e-03,  4.46004909e-04,  2.34961393e-03,
         1.99747160e-02, -1.71742365e-02,  6.55179436e-04,
         2.31307140e-03, -1.60387601e-03, -2.70899595e-03],
       [-1.42529428e-01, -2.17842981e-02,  3.11489683e-03,
         6.14735559e-02, -1.96310915e-02,  1.94459707e-01,
         1.50961757e-01,  1.35005370e-01,  2.40996350e-02,
        -1.27256569e-02, -6.07332308e-03, -1.79368556e-02,
      