# ESMM

In [1]:
import torch
import torch.nn as nn
from torchsummary import summary
from torch.utils.tensorboard import SummaryWriter
import numpy as np

In [2]:
class ESMM(nn.Module):
    """Multi-tower network over concatenated user and item features.

    Every feature whose unique-value count is greater than 1 is treated as
    categorical and gets its own ``nn.Embedding``, registered on the module
    under the feature's name. Every other feature is treated as a dense scalar
    and its input column is passed through unchanged. All pieces are
    concatenated and fed to ``num_task`` independent towers.

    NOTE: feature names become attribute names on this module, so they must
    not clash with existing ``nn.Module`` attributes or with ``task_{i}_dnn``.
    """

    def __init__(self, user_feature_dict, item_feature_dict, emb_dim=128, hidden_dim=(128, 64), dropouts=(0.5, 0.5),
                 output_size=1, num_task=2):
        """
        esmm model input parameters
        :param user_feature_dict: user feature dict include: {feature_name: (feature_unique_num, feature_index)}
        :param item_feature_dict: item feature dict include: {feature_name: (feature_unique_num, feature_index)}
        :param emb_dim: int, embedding size
        :param hidden_dim: sequence of hidden layer sizes, shared by every task tower
        :param dropouts: sequence of dropout probabilities, one per hidden layer
                         (needs at least ``len(hidden_dim)`` entries, otherwise IndexError)
        :param output_size: int, output size of each tower's last layer
        :param num_task: int, number of task towers (default 2)
        :raises TypeError: if either feature dict is None or is not a dict
        """
        super().__init__()

        # Validate the feature specifications (TypeError is still an Exception
        # for any caller that catches the broad type).
        if user_feature_dict is None or item_feature_dict is None:
            raise TypeError("input parameter user_feature_dict and item_feature_dict must be not None")

        if not isinstance(user_feature_dict, dict) or not isinstance(item_feature_dict, dict):
            raise TypeError("input parameter user_feature_dict and item_feature_dict must be dict")

        # Copy into fresh lists so tuples (including the defaults) are accepted
        # and the caller's sequences are never shared with the model.
        hidden_dim = list(hidden_dim)
        dropouts = list(dropouts)

        self.user_feature_dict = user_feature_dict
        self.item_feature_dict = item_feature_dict
        self.num_task = num_task

        # Embedding tables: user features first, then item features.
        user_cate_feature_nums = self._register_embeddings(user_feature_dict, emb_dim)
        item_cate_feature_nums = self._register_embeddings(item_feature_dict, emb_dim)

        # Tower input width: emb_dim per categorical feature plus one column
        # per dense feature.
        hidden_size = emb_dim * (user_cate_feature_nums + item_cate_feature_nums) + \
                      (len(user_feature_dict) - user_cate_feature_nums) + \
                      (len(item_feature_dict) - item_cate_feature_nums)

        # One independent DNN per task. Sub-module names are kept as-is
        # (including the 'ctr_' prefix on every tower) so state_dict keys
        # remain stable.
        # NOTE(review): there is no non-linearity between the Linear layers;
        # confirm whether an activation was intended.
        layer_dims = [hidden_size] + hidden_dim
        for task_no in range(1, self.num_task + 1):
            tower = nn.ModuleList()
            for j, (in_dim, out_dim) in enumerate(zip(layer_dims[:-1], layer_dims[1:])):
                tower.add_module('ctr_hidden_{}'.format(j), nn.Linear(in_dim, out_dim))
                tower.add_module('ctr_batchnorm_{}'.format(j), nn.BatchNorm1d(out_dim))
                tower.add_module('ctr_dropout_{}'.format(j), nn.Dropout(dropouts[j]))
            tower.add_module('task_last_layer', nn.Linear(layer_dims[-1], output_size))
            setattr(self, 'task_{}_dnn'.format(task_no), tower)

    def _register_embeddings(self, feature_dict, emb_dim):
        """Register an nn.Embedding per categorical feature and return how many were added."""
        cate_count = 0
        for feature_name, spec in feature_dict.items():
            if spec[0] > 1:
                cate_count += 1
                setattr(self, feature_name, nn.Embedding(spec[0], emb_dim))
        return cate_count

    def _embed_features(self, x, feature_dict):
        """Return one 2-D float tensor per feature: an embedding lookup or the raw column."""
        pieces = []
        for feature_name, spec in feature_dict.items():
            column = x[:, spec[1]]
            if spec[0] > 1:
                pieces.append(getattr(self, feature_name)(column.long()))
            else:
                # Cast dense columns up front so torch.cat never has to mix
                # the input's dtype with the float32 embeddings.
                pieces.append(column.unsqueeze(1).float())
        return pieces

    def forward(self, x):
        """Run every task tower on the concatenated user/item representation.

        :param x: 2-D tensor indexed as ``x[:, feature_index]`` for each configured feature
        :return: list with ``num_task`` tensors, each the raw output of that tower's
                 last Linear layer (no sigmoid or cross-task product is applied here)
        """
        # Concatenate the per-feature pieces: user block first, then item block.
        user_embed = torch.cat(self._embed_features(x, self.user_feature_dict), dim=1)
        item_embed = torch.cat(self._embed_features(x, self.item_feature_dict), dim=1)

        # Shared input for all towers.
        hidden = torch.cat([user_embed, item_embed], dim=1).float()

        # A ModuleList holding named children can be iterated but not indexed
        # by position, so each tower is walked layer by layer.
        task_outputs = []
        for task_no in range(1, self.num_task + 1):
            out = hidden
            for layer in getattr(self, 'task_{}_dnn'.format(task_no)):
                out = layer(out)
            task_outputs.append(out)

        return task_outputs
 

In [3]:
# Toy batch of 5 samples with 6 columns each. Columns 0-3 hold integer-valued
# ids and columns 4-5 hold fractional values; the mixed literals make the whole
# tensor float64.
a = torch.tensor(
    [
        [1, 2, 4, 2, 0.5, 0.1],
        [4, 5, 3, 8, 0.6, 0.43],
        [6, 3, 2, 9, 0.12, 0.32],
        [9, 1, 1, 1, 0.12, 0.45],
        [8, 3, 1, 4, 0.21, 0.67],
    ],
    dtype=torch.float64,
)
a.shape

torch.Size([5, 6])

In [4]:
# Show a single column next to its unsqueezed form: indexing a column yields a
# 1-D tensor, and unsqueeze(1) turns it into a column of shape (rows, 1).
second_col = a[:, 1]
second_col, second_col.unsqueeze(1)

(tensor([2., 5., 3., 1., 3.], dtype=torch.float64),
 tensor([[2.],
         [5.],
         [3.],
         [1.],
         [3.]], dtype=torch.float64))

In [5]:
# Sparse and dense feature specs (consider using a namedtuple for these).
# Each entry maps feature name -> (number of unique values, column index in the input).
# A unique-value count of 1 marks a dense feature that is fed in as a raw column.
user_cate_dict = dict(user_id=(11, 0), user_list=(12, 3), user_num=(1, 4))
item_cate_dict = dict(item_id=(8, 1), item_cate=(6, 2), item_num=(1, 5))

# Build the model with default hyper-parameters and run the toy batch through it.
esmm = ESMM(user_feature_dict=user_cate_dict, item_feature_dict=item_cate_dict)
tasks = esmm(a)
print(tasks)

[tensor([[-1.6505],
        [ 1.0325],
        [-0.0688],
        [ 0.1978],
        [-0.2421]], grad_fn=<AddmmBackward>), tensor([[ 0.3390],
        [ 0.3304],
        [-1.8689],
        [ 0.0599],
        [ 0.5291]], grad_fn=<AddmmBackward>)]


In [8]:
# Drop the size-1 dimensions of the second task's output so it displays as a flat vector.
tasks[1].squeeze()

tensor([ 0.3390,  0.3304, -1.8689,  0.0599,  0.5291],
       grad_fn=<SqueezeBackward0>)

In [6]:
# Display the module tree: the embedding tables followed by each task tower.
esmm

ESMM(
  (user_id): Embedding(11, 128)
  (user_list): Embedding(12, 128)
  (item_id): Embedding(8, 128)
  (item_cate): Embedding(6, 128)
  (task_1_dnn): ModuleList(
    (ctr_hidden_0): Linear(in_features=514, out_features=128, bias=True)
    (ctr_batchnorm_0): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (ctr_dropout_0): Dropout(p=0.5, inplace=False)
    (ctr_hidden_1): Linear(in_features=128, out_features=64, bias=True)
    (ctr_batchnorm_1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (ctr_dropout_1): Dropout(p=0.5, inplace=False)
    (task_last_layer): Linear(in_features=64, out_features=1, bias=True)
  )
  (task_2_dnn): ModuleList(
    (ctr_hidden_0): Linear(in_features=514, out_features=128, bias=True)
    (ctr_batchnorm_0): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (ctr_dropout_0): Dropout(p=0.5, inplace=False)
    (ctr_hidden_1): Linear(in_features=128, out_f

In [7]:
# Export the model graph for TensorBoard. The context manager closes the
# writer on exit so the pending event data is written to disk.
with SummaryWriter() as w:
    w.add_graph(esmm, a)

# 参考

[论文地址](https://arxiv.org/abs/1804.07931)  
[阿里CVR预估模型之ESMM](https://zhuanlan.zhihu.com/p/57481330)  
[CVR预估的新思路：完整空间多任务模型](https://zhuanlan.zhihu.com/p/37562283)