In [1]:
# -*- coding:utf-8 -*-
"""
Author:
    zanshuxun, zanshuxun@aliyun.com
    songwei, magic_24k@163.com

Reference:
    [1] [Jiaqi Ma, Zhe Zhao, Xinyang Yi, et al. Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts[C]](https://dl.acm.org/doi/10.1145/3219819.3220007)
"""
import torch
import torch.nn as nn
import pickle
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from deepctr_torch.models.basemodel import BaseModel
from deepctr_torch.inputs import combined_dnn_input, embedding_lookup, maxlen_lookup
from deepctr_torch.inputs import SparseFeat, DenseFeat, VarLenSparseFeat
from deepctr_torch.layers import DNN, PredictionLayer, CIN, concat_fun, InteractingLayer
from deepctr_torch.layers.sequence import AttentionSequencePoolingLayer
from deepctr_torch.layers.sequence import DynamicGRU
import pandas as pd

class MMOELayer(nn.Module):
    """
    The Multi-gate Mixture-of-Experts layer in MMOE model
      Input shape
        - 2D tensor with shape: ``(batch_size,units)``.

      Output shape
        - A list with **num_tasks** elements, which is a 2D tensor with shape: ``(batch_size, output_dim)`` .

      Arguments
        - **input_dim** : Positive integer, dimensionality of input features.
        - **num_tasks**: integer, the number of tasks, equal to the number of outputs.
        - **num_experts**: integer, the number of experts.
        - **output_dim**: integer, the dimension of each output of MMOELayer.

    References
      - [Jiaqi Ma, Zhe Zhao, Xinyang Yi, et al. Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts[C]](https://dl.acm.org/doi/10.1145/3219819.3220007)
    """

    def __init__(self, input_dim, num_tasks, num_experts, output_dim):
        super(MMOELayer, self).__init__()
        self.input_dim = input_dim
        self.num_experts = num_experts
        self.num_tasks = num_tasks
        self.output_dim = output_dim
        # All experts are computed by one linear map; its output is reshaped in
        # ``forward`` into ``(batch, output_dim, num_experts)``.
        self.expert_network = nn.Linear(self.input_dim, self.num_experts * self.output_dim, bias=True)
        # One bias-free gate per task, producing a score per expert.
        self.gating_networks = nn.ModuleList(
            [nn.Linear(self.input_dim, self.num_experts, bias=False) for _ in range(self.num_tasks)])
        # Initialise every linear weight from a standard normal distribution.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight)

    def forward(self, inputs):
        """Mix the expert outputs once per task.

        :param inputs: 2D tensor ``(batch_size, input_dim)``.
        :return: list of ``num_tasks`` tensors, each ``(batch_size, output_dim)``.
        """
        expert_out = self.expert_network(inputs)
        expert_out = expert_out.reshape([-1, self.output_dim, self.num_experts])

        outputs = []
        for gate in self.gating_networks:
            # (batch, num_experts) -> softmax over experts -> (batch, num_experts, 1)
            gate_out = gate(inputs).softmax(1).unsqueeze(-1)
            # Only the trailing singleton produced by ``bmm`` is removed. A bare
            # ``squeeze()`` would also drop the batch axis when batch_size == 1
            # (or the feature axis when output_dim == 1).
            outputs.append(torch.bmm(expert_out, gate_out).squeeze(-1))
        return outputs


class MMOE(BaseModel):
    """Instantiates the Multi-gate Mixture-of-Experts architecture.

    In addition to the shared-bottom DNN, this variant builds a CIN branch whose
    logit is added to every task logit, and a stack of multi-head self-attention
    (``InteractingLayer``) blocks whose flattened output is concatenated with the
    DNN output before the MMoE layer. The CIN and attention hyper-parameters are
    fixed inside ``__init__``.

    :param dnn_feature_columns: An iterable containing all the features used by deep part of the model.
    :param history_feature_list: list of str, names of the sparse features that have a matching
        ``"hist_" + name`` variable-length history column. Only used to split the variable-length
        columns; the DIN/DIEN branches that consume them are currently commented out.
    :param num_tasks: integer, number of tasks, equal to number of outputs, must be greater than 1.
    :param tasks: list of str, indicating the loss of each tasks, ``"binary"`` for  binary logloss, ``"regression"`` for regression loss. e.g. ['binary', 'regression']
    :param num_experts: integer, number of experts.
    :param expert_dim: integer, the hidden units of each expert.
    :param dnn_hidden_units: list,list of positive integer, the layer number and units in each layer of shared-bottom DNN. Must not be empty.
    :param l2_reg_embedding: float. L2 regularizer strength applied to embedding vector
    :param l2_reg_dnn: float. L2 regularizer strength applied to DNN
    :param init_std: float,to use as the initialize std of embedding vector
    :param task_dnn_units: list or tuple of positive integer or None, the layer number and units in each layer of task-specific DNN
    :param seed: integer ,to use as random seed.
    :param dnn_dropout: float in [0,1), the probability we will drop out a given DNN coordinate.
    :param dnn_activation: Activation function to use in DNN
    :param dnn_use_bn: bool. Whether use BatchNormalization before activation or not in DNN
    :param device: str, ``"cpu"`` or ``"cuda:0"``
    :param gpus: accepted for call compatibility; not used anywhere in this class.

    :return: A PyTorch model instance.
    """

    def __init__(self, dnn_feature_columns, history_feature_list, num_tasks, tasks, num_experts=4, expert_dim=8, dnn_hidden_units=(128, 128),
                 l2_reg_embedding=1e-5, l2_reg_dnn=1e-5, init_std=0.0001, task_dnn_units=None, seed=1024, dnn_dropout=0,
                 dnn_activation='relu', dnn_use_bn=False, device='cpu', gpus=[0, 1]):

        super(MMOE, self).__init__(linear_feature_columns=[], dnn_feature_columns=dnn_feature_columns,
                                   l2_reg_embedding=l2_reg_embedding, seed=seed, device=device)
        if num_tasks <= 1:
            raise ValueError("num_tasks must be greater than 1")
        if len(tasks) != num_tasks:
            raise ValueError("num_tasks must be equal to the length of tasks")
        for task in tasks:
            if task not in ['binary', 'regression']:
                raise ValueError("task must be binary or regression, {} is illegal".format(task))
        # The MMoE input width below is derived from dnn_hidden_units[-1], so an
        # empty tuple cannot work; fail early with a clear message.
        if len(dnn_hidden_units) == 0:
            raise ValueError("dnn_hidden_units must not be empty")

        self.sparse_feature_columns = list(
            filter(lambda x: isinstance(x, SparseFeat), dnn_feature_columns)) if dnn_feature_columns else []
        self.varlen_sparse_feature_columns = list(
            filter(lambda x: isinstance(x, VarLenSparseFeat), dnn_feature_columns)) if dnn_feature_columns else []
        self.dense_feature_columns = list(
            filter(lambda x: isinstance(x, DenseFeat), dnn_feature_columns)) if dnn_feature_columns else []

        # Names of the features whose "hist_<name>" sequences act as attention keys.
        self.history_feature_list = history_feature_list
        self.item_features = history_feature_list

        # Split the variable-length columns into history columns and the rest.
        self.history_feature_columns = []
        self.sparse_varlen_feature_columns = []
        self.history_fc_names = list(map(lambda x: "hist_" + x, history_feature_list))
        for fc in self.varlen_sparse_feature_columns:
            feature_name = fc.name
            if feature_name in self.history_fc_names:
                self.history_feature_columns.append(fc)
            else:
                self.sparse_varlen_feature_columns.append(fc)

        # ``_get_emb`` reads this flag. Negative sampling belongs to the DIEN
        # branch, which is disabled below, so it defaults to off.
        self.use_negsampling = False

        # DIN component (disabled experiment, kept for reference)
#         att_emb_dim = self._compute_interest_dim()
#         use_negsampling = False
#         gru_type="GRU"
#         att_hidden_size=(128, 64)
#         self.use_negsampling = use_negsampling
#         self.gru_type = gru_type
#         att_activation='relu'
#         att_weight_normalization=False
#         self.attention = AttentionSequencePoolingLayer(att_hidden_units=att_hidden_size,
#                                                        embedding_dim=att_emb_dim,
#                                                        att_activation=att_activation,
#                                                        return_score=False,
#                                                        supports_masking=False,
#                                                        weight_normalization=att_weight_normalization)
        # DIEN component (disabled experiment, kept for reference)
#         self.alpha=1.0
#         self.interest_extractor = InterestExtractor(input_size=att_emb_dim, use_neg=use_negsampling, init_std=init_std)
#         self.interest_evolution = InterestEvolving(
#             input_size=att_emb_dim,
#             gru_type=gru_type,
#             use_neg=use_negsampling,
#             init_std=init_std,
#             att_hidden_size=att_hidden_size,
#             att_activation=att_activation,
#             att_weight_normalization=att_weight_normalization)
#         self.din_linear = nn.Linear(80, 1, bias=False).to(device)

        # Number of sparse fields. Needed by both the CIN and the attention
        # branch, so it is computed unconditionally (it used to be bound only
        # inside the ``use_cin`` branch).
        field_num = len(self.sparse_feature_columns)

        # CIN branch
        cin_layer_size = (256, 128, 64)
        self.cin_layer_size = cin_layer_size
        cin_split_half = True
        cin_activation = 'relu'
        l2_reg_cin = 1e-5
        self.use_cin = len(self.cin_layer_size) > 0 and len(dnn_feature_columns) > 0
        if self.use_cin:
            if cin_split_half:
                self.featuremap_num = sum(
                    cin_layer_size[:-1]) // 2 + cin_layer_size[-1]
            else:
                self.featuremap_num = sum(cin_layer_size)
            self.cin = CIN(field_num, cin_layer_size,
                           cin_activation, cin_split_half, l2_reg_cin, seed, device=device)
            self.cin_linear = nn.Linear(self.featuremap_num, 1, bias=False).to(device)
            self.add_regularization_weight(filter(lambda x: 'weight' in x[0], self.cin.named_parameters()),
                                           l2=l2_reg_cin)
        # NOTE(review): BaseModel.__init__ presumably registers the embedding L2
        # term already, in which case this call applies it twice -- confirm.
        self.add_regularization_weight(self.embedding_dict.parameters(), l2=l2_reg_embedding)

        # Multi-head self-attention branch
        att_embedding_size = 30
        att_head_num = 10
        att_layer_num = 3
        att_res = True
        self.int_layers = nn.ModuleList(
            [InteractingLayer(self.embedding_size if i == 0 else att_embedding_size * att_head_num,
                              att_embedding_size, att_head_num, att_res, device=device) for i in range(att_layer_num)])

        # ``forward`` concatenates the flattened attention output
        # (field_num * att_embedding_size * att_head_num) with the DNN output.
        # The width was previously hard-coded as ``+1600+800``, which equals this
        # expression only for 8 sparse fields.
        mmoe_input_dim = dnn_hidden_units[-1] + field_num * att_embedding_size * att_head_num

        print('sparse_features: ', self.sparse_feature_columns)
        print('varlen_features: ', self.varlen_sparse_feature_columns)

        # MMOE
        self.tasks = tasks
        self.task_dnn_units = task_dnn_units
        # NOTE(review): l2_reg is passed to DNN, but no DNN weights are registered
        # via add_regularization_weight in this class -- confirm whether intended.
        self.dnn = DNN(
            self.compute_input_dim(self.sparse_feature_columns + self.dense_feature_columns),
            dnn_hidden_units,
            activation=dnn_activation, l2_reg=l2_reg_dnn, dropout_rate=dnn_dropout, use_bn=dnn_use_bn,
            init_std=init_std, device=device)
        self.mmoe_layer = MMOELayer(mmoe_input_dim, num_tasks, num_experts, expert_dim)
        if task_dnn_units is not None:
            # the last layer of task_dnn should be expert_dim;
            # tuple(...) lets callers pass a list as the docstring allows.
            self.task_dnn = nn.ModuleList(
                [DNN(expert_dim, tuple(task_dnn_units) + (expert_dim,)) for _ in range(num_tasks)])
        self.tower_network = nn.ModuleList([nn.Linear(expert_dim, 1, bias=False) for _ in range(num_tasks)])
        self.out = nn.ModuleList([PredictionLayer(task) for task in self.tasks])
        self.to(device)

    def _compute_interest_dim(self):
        """Return the summed embedding dimension of the sparse features named in ``history_feature_list``."""
        interest_dim = 0
        for feat in self.sparse_feature_columns:
            if feat.name in self.history_feature_list:
                interest_dim += feat.embedding_dim
        return interest_dim

    def _get_emb(self, X):
        """Look up query / history-key embeddings for the (currently disabled) DIEN branch.

        :param X: the model input tensor, indexed through ``self.feature_index``.
        :return: ``(query_emb, keys_emb, neg_keys_emb, keys_length)``; ``neg_keys_emb`` is
            ``None`` unless ``self.use_negsampling`` is true.
        """
        # history feature columns : pos, neg
        history_feature_columns = []
        neg_history_feature_columns = []
        sparse_varlen_feature_columns = []
        history_fc_names = list(map(lambda x: "hist_" + x, self.item_features))
        neg_history_fc_names = list(map(lambda x: "neg_" + x, history_fc_names))
        for fc in self.varlen_sparse_feature_columns:
            feature_name = fc.name
            if feature_name in history_fc_names:
                history_feature_columns.append(fc)
            elif feature_name in neg_history_fc_names:
                neg_history_feature_columns.append(fc)
            else:
                sparse_varlen_feature_columns.append(fc)

        # convert input to emb
        features = self.feature_index
        query_emb_list = embedding_lookup(X, self.embedding_dict, features, self.sparse_feature_columns,
                                          return_feat_list=self.item_features, to_list=True)
        # [batch_size, dim]
        query_emb = torch.squeeze(concat_fun(query_emb_list), 1)

        keys_emb_list = embedding_lookup(X, self.embedding_dict, features, history_feature_columns,
                                         return_feat_list=history_fc_names, to_list=True)
        # [batch_size, max_len, dim]
        keys_emb = concat_fun(keys_emb_list)

        keys_length_feature_name = [feat.length_name for feat in self.varlen_sparse_feature_columns if
                                    feat.length_name is not None]
        # [batch_size]
        keys_length = torch.squeeze(maxlen_lookup(X, features, keys_length_feature_name), 1)

        if self.use_negsampling:
            neg_keys_emb_list = embedding_lookup(X, self.embedding_dict, features, neg_history_feature_columns,
                                                 return_feat_list=neg_history_fc_names, to_list=True)
            neg_keys_emb = concat_fun(neg_keys_emb_list)
        else:
            neg_keys_emb = None

        return query_emb, keys_emb, neg_keys_emb, keys_length

    def forward(self, X):
        """Compute the predictions of all tasks.

        :param X: the model input tensor, indexed through ``self.feature_index``.
        :return: tensor with the per-task outputs concatenated along the last axis.
        """
        _, dense_value_list = self.input_from_feature_columns(X, self.dnn_feature_columns,
                                                              self.embedding_dict)
        sparse_embedding_list = embedding_lookup(X, self.embedding_dict, self.feature_index,
                                                 self.sparse_feature_columns, to_list=True)
        # DIN (disabled experiment, kept for reference)
#         query_emb_list = embedding_lookup(X, self.embedding_dict, self.feature_index, self.sparse_feature_columns,
#                                           return_feat_list=self.history_feature_list, to_list=True)
#         keys_emb_list = embedding_lookup(X, self.embedding_dict, self.feature_index, self.history_feature_columns,
#                                          return_feat_list=self.history_fc_names, to_list=True)
#         query_emb = torch.cat(query_emb_list, dim=-1)                     # [B, 1, E]
#         keys_emb = torch.cat(keys_emb_list, dim=-1)                       # [B, T, E]
#         keys_length_feature_name = [feat.length_name for feat in self.varlen_sparse_feature_columns if
#                                     feat.length_name is not None]
#         keys_length = torch.squeeze(maxlen_lookup(X, self.feature_index, keys_length_feature_name), 1)  # [B, 1]
#         hist = self.attention(query_emb, keys_emb, keys_length)           # [B, 1, E]
        # DIEN (disabled experiment, kept for reference)
#         query_emb, keys_emb, neg_keys_emb, keys_length = self._get_emb(X)
#         masked_interest, aux_loss = self.interest_extractor(keys_emb, keys_length, neg_keys_emb)
#         self.add_auxiliary_loss(aux_loss, self.alpha)
#         hist = self.interest_evolution(query_emb, masked_interest, keys_length)
#         din_logit = self.din_linear(hist).squeeze(1)
#         din_logit = self.din_out(din_out)

        # CIN branch: produces one logit per sample that is shared by all tasks.
        # Left as None when the branch is disabled so the task loop can skip it.
        cin_logit = None
        if self.use_cin:
            cin_input = torch.cat(sparse_embedding_list, dim=1)
            cin_output = self.cin(cin_input)
            cin_logit = self.cin_linear(cin_output)

        # Multi-head self-attention over the stacked sparse embeddings.
        att_input = concat_fun(sparse_embedding_list, axis=1)
        for layer in self.int_layers:
            att_input = layer(att_input)
        att_output = torch.flatten(att_input, start_dim=1)

        # Shared-bottom DNN
        dnn_input = combined_dnn_input(sparse_embedding_list, dense_value_list)
#         dnn_input = torch.cat((dnn_input, hist.squeeze(1)), dim=-1)
#         dnn_input = torch.cat((dnn_input, att_output), dim=-1)
#         dnn_input = torch.cat((dnn_input, cin_output), dim=-1)

        dnn_output = self.dnn(dnn_input)
        dnn_output = concat_fun([att_output, dnn_output])
        mmoe_outs = self.mmoe_layer(dnn_output)
        if self.task_dnn_units is not None:
            mmoe_outs = [self.task_dnn[i](mmoe_out) for i, mmoe_out in enumerate(mmoe_outs)]

        task_outputs = []
        for i, mmoe_out in enumerate(mmoe_outs):
            logit = self.tower_network[i](mmoe_out)
            if cin_logit is not None:
                logit = logit + cin_logit
            output = self.out[i](logit)
            task_outputs.append(output)
        task_outputs = torch.cat(task_outputs, -1)
        return task_outputs


class InterestExtractor(nn.Module):
    """GRU interest-extraction layer with an optional auxiliary loss.

    :param input_size: integer, size of each key vector; also used as the GRU hidden size.
    :param use_neg: bool, build the auxiliary network and compute the auxiliary
        loss when negative keys are passed to ``forward``.
    :param init_std: float, std of the normal initialisation of the GRU weights
        (and of the auxiliary DNN).
    :param device: str, device the module is moved to.
    """

    def __init__(self, input_size, use_neg=False, init_std=0.001, device='cpu'):
        super(InterestExtractor, self).__init__()
        self.use_neg = use_neg
        self.gru = nn.GRU(input_size=input_size, hidden_size=input_size, batch_first=True)
        if self.use_neg:
            self.auxiliary_net = DNN(input_size * 2, [100, 50, 1], 'sigmoid', init_std=init_std, device=device)
        for name, tensor in self.gru.named_parameters():
            if 'weight' in name:
                nn.init.normal_(tensor, mean=0, std=init_std)
        self.to(device)

    def forward(self, keys, keys_length, neg_keys=None):
        """
        Parameters
        ----------
        keys: 3D tensor, [B, T, H]
        keys_length: 1D tensor, [B]
        neg_keys: 3D tensor, [B, T, H]

        Returns
        -------
        masked_interests: 3D tensor, [b, T, H], where b is the number of rows with
            ``keys_length > 0``. When no row qualifies, a zero tensor of shape
            [B, H] is returned instead.
        aux_loss: [1]
        """
        batch_size, max_length, dim = keys.size()
        zero_outputs = torch.zeros(batch_size, dim, device=keys.device)
        aux_loss = torch.zeros((1,), device=keys.device)

        # create zero mask for keys_length, to make sure 'pack_padded_sequence' safe
        mask = keys_length > 0
        masked_keys_length = keys_length[mask]

        # batch_size validation check: always return the (interests, aux_loss)
        # pair so that callers can unpack two values on every path.
        if masked_keys_length.shape[0] == 0:
            return zero_outputs, aux_loss

        masked_keys = torch.masked_select(keys, mask.view(-1, 1, 1)).view(-1, max_length, dim)

        # pack_padded_sequence requires the lengths tensor to live on the CPU.
        packed_keys = pack_padded_sequence(masked_keys, lengths=masked_keys_length.cpu(), batch_first=True,
                                           enforce_sorted=False)
        packed_interests, _ = self.gru(packed_keys)
        interests, _ = pad_packed_sequence(packed_interests, batch_first=True, padding_value=0.0,
                                           total_length=max_length)

        if self.use_neg and neg_keys is not None:
            masked_neg_keys = torch.masked_select(neg_keys, mask.view(-1, 1, 1)).view(-1, max_length, dim)
            # State t is paired with the clicked / non-clicked item at step t + 1.
            aux_loss = self._cal_auxiliary_loss(
                interests[:, :-1, :],
                masked_keys[:, 1:, :],
                masked_neg_keys[:, 1:, :],
                masked_keys_length - 1)

        return interests, aux_loss

    def _cal_auxiliary_loss(self, states, click_seq, noclick_seq, keys_length):
        """Binary cross-entropy between (state, clicked) and (state, non-clicked) pairs.

        Rows whose ``keys_length`` is not positive are dropped; a zero tensor of
        shape [1] is returned when no row is left.
        """
        # keys_length >= 1
        mask_shape = keys_length > 0
        keys_length = keys_length[mask_shape]
        if keys_length.shape[0] == 0:
            return torch.zeros((1,), device=states.device)

        _, max_seq_length, embedding_size = states.size()
        states = torch.masked_select(states, mask_shape.view(-1, 1, 1)).view(-1, max_seq_length, embedding_size)
        click_seq = torch.masked_select(click_seq, mask_shape.view(-1, 1, 1)).view(-1, max_seq_length, embedding_size)
        noclick_seq = torch.masked_select(noclick_seq, mask_shape.view(-1, 1, 1)).view(-1, max_seq_length,
                                                                                       embedding_size)
        batch_size = states.size()[0]

        # 1.0 for valid time steps, 0.0 for padding.
        mask = (torch.arange(max_seq_length, device=states.device).repeat(
            batch_size, 1) < keys_length.view(-1, 1)).float()

        click_input = torch.cat([states, click_seq], dim=-1)
        noclick_input = torch.cat([states, noclick_seq], dim=-1)
        embedding_size = embedding_size * 2

        click_p = self.auxiliary_net(click_input.view(
            batch_size * max_seq_length, embedding_size)).view(
            batch_size, max_seq_length)[mask > 0].view(-1, 1)
        click_target = torch.ones(
            click_p.size(), dtype=torch.float, device=click_p.device)

        noclick_p = self.auxiliary_net(noclick_input.view(
            batch_size * max_seq_length, embedding_size)).view(
            batch_size, max_seq_length)[mask > 0].view(-1, 1)
        noclick_target = torch.zeros(
            noclick_p.size(), dtype=torch.float, device=noclick_p.device)

        # ``F`` was never imported in this file; go through ``nn.functional``.
        loss = nn.functional.binary_cross_entropy(
            torch.cat([click_p, noclick_p], dim=0),
            torch.cat([click_target, noclick_target], dim=0))

        return loss


class InterestEvolving(nn.Module):
    """Interest-evolving layer: attention combined with a (dynamic) GRU.

    :param input_size: integer, size of the query / key vectors; also the GRU hidden size.
    :param gru_type: str, one of ``'GRU'``, ``'AIGRU'``, ``'AGRU'``, ``'AUGRU'``.
    :param use_neg: bool, stored on the instance; not read by this class.
    :param init_std: float, std of the normal initialisation of the GRU weights.
    :param att_hidden_size: hidden units of the attention network.
    :param att_activation: activation of the attention network.
    :param att_weight_normalization: bool, forwarded to the attention layer.
    """
    __SUPPORTED_GRU_TYPE__ = ['GRU', 'AIGRU', 'AGRU', 'AUGRU']

    def __init__(self,
                 input_size,
                 gru_type='GRU',
                 use_neg=False,
                 init_std=0.001,
                 att_hidden_size=(64, 16),
                 att_activation='sigmoid',
                 att_weight_normalization=False):
        super(InterestEvolving, self).__init__()
        if gru_type not in InterestEvolving.__SUPPORTED_GRU_TYPE__:
            # The message used to contain the literal text "{gru_type}".
            raise NotImplementedError("gru_type: {} is not supported".format(gru_type))
        self.gru_type = gru_type
        self.use_neg = use_neg

        if gru_type == 'GRU':
            # Plain GRU first, attention pooling over its outputs afterwards.
            self.attention = AttentionSequencePoolingLayer(embedding_dim=input_size,
                                                           att_hidden_units=att_hidden_size,
                                                           att_activation=att_activation,
                                                           weight_normalization=att_weight_normalization,
                                                           return_score=False)
            self.interest_evolution = nn.GRU(input_size=input_size, hidden_size=input_size, batch_first=True)
        elif gru_type == 'AIGRU':
            # Attention scores scale the GRU inputs.
            self.attention = AttentionSequencePoolingLayer(embedding_dim=input_size,
                                                           att_hidden_units=att_hidden_size,
                                                           att_activation=att_activation,
                                                           weight_normalization=att_weight_normalization,
                                                           return_score=True)
            self.interest_evolution = nn.GRU(input_size=input_size, hidden_size=input_size, batch_first=True)
        elif gru_type == 'AGRU' or gru_type == 'AUGRU':
            # Attention scores are passed into the recurrent cell itself.
            self.attention = AttentionSequencePoolingLayer(embedding_dim=input_size,
                                                           att_hidden_units=att_hidden_size,
                                                           att_activation=att_activation,
                                                           weight_normalization=att_weight_normalization,
                                                           return_score=True)
            self.interest_evolution = DynamicGRU(input_size=input_size, hidden_size=input_size,
                                                 gru_type=gru_type)
        for name, tensor in self.interest_evolution.named_parameters():
            if 'weight' in name:
                nn.init.normal_(tensor, mean=0, std=init_std)

    @staticmethod
    def _get_last_state(states, keys_length):
        """Select, for every row, the state at index ``keys_length - 1``."""
        # states [B, T, H]
        batch_size, max_seq_length, hidden_size = states.size()

        mask = (torch.arange(max_seq_length, device=keys_length.device).repeat(
            batch_size, 1) == (keys_length.view(-1, 1) - 1))

        return states[mask]

    def forward(self, query, keys, keys_length, mask=None):
        """
        Parameters
        ----------
        query: 2D tensor, [B, H]
        keys: (masked_interests), 3D tensor, [b, T, H]
        keys_length: 1D tensor, [B]
        mask: ignored; it is recomputed from ``keys_length``.

        Returns
        -------
        outputs: 2D tensor, [B, H]; rows whose ``keys_length`` is 0 stay zero.
        """
        batch_size, dim = query.size()
        max_length = keys.size()[1]

        # check batch validation
        zero_outputs = torch.zeros(batch_size, dim, device=query.device)
        mask = keys_length > 0
        # [B] -> [b]
        keys_length = keys_length[mask]
        if keys_length.shape[0] == 0:
            return zero_outputs

        # pack_padded_sequence requires the lengths tensor to live on the CPU.
        cpu_lengths = keys_length.cpu()

        # [B, H] -> [b, 1, H]
        query = torch.masked_select(query, mask.view(-1, 1)).view(-1, dim).unsqueeze(1)

        if self.gru_type == 'GRU':
            packed_keys = pack_padded_sequence(keys, lengths=cpu_lengths, batch_first=True, enforce_sorted=False)
            packed_interests, _ = self.interest_evolution(packed_keys)
            interests, _ = pad_packed_sequence(packed_interests, batch_first=True, padding_value=0.0,
                                               total_length=max_length)
            outputs = self.attention(query, interests, keys_length.unsqueeze(1))  # [b, 1, H]
            outputs = outputs.squeeze(1)  # [b, H]
        elif self.gru_type == 'AIGRU':
            att_scores = self.attention(query, keys, keys_length.unsqueeze(1))  # [b, 1, T]
            interests = keys * att_scores.transpose(1, 2)  # [b, T, H]
            packed_interests = pack_padded_sequence(interests, lengths=cpu_lengths, batch_first=True,
                                                    enforce_sorted=False)
            _, outputs = self.interest_evolution(packed_interests)
            outputs = outputs.squeeze(0)  # [b, H]
        elif self.gru_type == 'AGRU' or self.gru_type == 'AUGRU':
            att_scores = self.attention(query, keys, keys_length.unsqueeze(1)).squeeze(1)  # [b, T]
            packed_interests = pack_padded_sequence(keys, lengths=cpu_lengths, batch_first=True,
                                                    enforce_sorted=False)
            packed_scores = pack_padded_sequence(att_scores, lengths=cpu_lengths, batch_first=True,
                                                 enforce_sorted=False)
            outputs = self.interest_evolution(packed_interests, packed_scores)
            outputs, _ = pad_packed_sequence(outputs, batch_first=True, padding_value=0.0, total_length=max_length)
            # pick last state
            outputs = InterestEvolving._get_last_state(outputs, keys_length)  # [b, H]
        # [b, H] -> [B, H]
        zero_outputs[mask] = outputs
        return zero_outputs


In [2]:
# Dense (continuous) input features of the model.
# Only the raw video duration is enabled; every other candidate below is
# commented out in the notebook and kept here for reference.
dense_features = ["videoplayseconds"]

# Disabled candidates -- 5-day window statistics, one full set for each
# <entity> in {userid, feedid, authorid}:
#   <entity>_5day_count
#   <entity>_5day_play_times_mean, <entity>_5day_play_mean, <entity>_5day_stay_mean
#   <entity>_5day_<action>_sum, <entity>_5day_<action>_mean, <entity>_5day_<action>_count
#       for <action> in: read_comment, like, click_avatar, forward,
#                        favorite, comment, follow
#
# Disabled candidates -- global counts and cross statistics:
#   userid_count, feedid_count, authorid_count,
#   userid_in_feedid_nunique, feedid_in_userid_nunique,
#   userid_in_authorid_nunique, authorid_in_userid_nunique,
#   userid_authorid_count,
#   userid_in_authorid_count_prop, authorid_in_userid_count_prop,
#   videoplayseconds_in_userid_mean, feedid_in_authorid_nunique,
#   bgm_song_id, bgm_singer_id

In [3]:
1

1

In [4]:
import os
import torch
import pandas as pd
import numpy as np
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = '0'
import sys
# BASE_DIR = os.path.dirname(os.path.abspath(__file__))
BASE_DIR = '.'
sys.path.append(os.path.join(BASE_DIR, '../../config'))
sys.path.append(os.path.join(BASE_DIR, '../model'))
sys.path.append(os.path.join(BASE_DIR, '../utils'))
from config import *
from time import time
from deepctr_torch.inputs import SparseFeat, DenseFeat, get_feature_names, VarLenSparseFeat
from sklearn.preprocessing import MinMaxScaler
import datatable as dt
# from mmoe import MMOE
from evaluation import evaluate_deepctr
import pickle
import gc

# Training configuration.
ONLINE_FLAG = True  # True: prepare an online submission (train on the online split, predict on test)

# Select the GPU: expose only device 0, and fall back to CPU when CUDA is unavailable.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Base vocabulary size for each id-like sparse feature.
# Key order is significant: the sparse feature columns are built by iterating these keys.
vocab_dict = dict(
    bgm_song_id=25158 + 1,
    bgm_singer_id=17499 + 1,
    userid=199999,
    feedid=112871 + 1,
    authorid=18788 + 1,
    device=3,
)

# Optimisation / model hyper-parameters.
epochs, batch_size = 1, 1024
embedding_dim, max_hist_seq_len = 50, 100

# The seven binary targets, one per task of the multi-task model.
target = [
    'read_comment',
    'like',
    'click_avatar',
    'forward',
    'comment',
    'favorite',
    'follow',
]
# Column names manual_tag_0 .. manual_tag_10 and manual_key_0 .. manual_key_10.
# NOTE(review): the original inline note next to keyids read "18" although 11 names
# are generated — confirm which count is intended.
tagids = [f'manual_tag_{idx}' for idx in range(11)]
keyids = [f'manual_key_{idx}' for idx in range(11)]
sparse_features = ['userid', 'feedid', 'authorid', 'bgm_song_id', 'bgm_singer_id']

# Feed metadata, read with datatable and converted to a pandas DataFrame.
feed = dt.fread(FEED_INFO).to_pandas()

# Fitted user-id encoder; later cells call its transform / inverse_transform.
# NOTE(security): pickle.load can execute arbitrary code — only load artefacts
# produced by this project's own feature pipeline.
# `with` guarantees the handle is closed even if unpickling raises.
with open(FEATURE_PATH + '/user_encoder.pkl', 'rb') as pkl:
    userid_map = pickle.load(pkl)

# Scaler kept available for dense-feature normalisation (its uses in this cell are disabled).
mms = MinMaxScaler(feature_range=(0, 1))

# Pre-trained user embedding matrices, moved to the training device as float32 tensors.
user_emb1 = torch.from_numpy(np.load(FEATURE_PATH + '/user_emb_normal_50.npy')).float().to(device)
user_emb2 = torch.from_numpy(np.load(FEATURE_PATH + '/user_emb_adjust_50.npy')).float().to(device)

# Pre-computed per-user history data (same pickle trust caveat as above).
with open(FEATURE_PATH + '/hist_data_action_begin_100.pkl', 'rb') as hist_file:
    hist_data = pickle.load(hist_file)

# Disabled in this run (kept here as a record of earlier experiments):
#   - tag/keyword sequence table  FEATURE_PATH + '/feed_info_tags_keys_des_seq_len.csv'
#   - PCA-64 feed embeddings      FEATURE_PATH + '/feed_embeddings_PCA64.csv' (min-max scaled with mms)
#   - shifting bgm_song_id / bgm_singer_id by +1 (0 = unknown), fillna(0) and int64 casts on `feed`
#   - reading the full table      FEATURE_PATH + '/data_128.csv'

# Open the training CSV as a chunk reader (iterator=True) and load the evaluation
# frame for the selected mode: `test` when online, `val` when offline.
if not ONLINE_FLAG:
    val = pd.read_csv(FEATURE_PATH + '/offline_val_100.csv')
    data = pd.read_csv(FEATURE_PATH + '/offline_train_100.csv', iterator=True)
else:
    # Earlier alternative (disabled): stream USER_ACTION directly.
    data = pd.read_csv(FEATURE_PATH + '/online_train_100.csv', iterator=True)
    test = pd.read_csv(FEATURE_PATH + '/online_test_100.csv')

# One embedding column per id feature in vocab_dict (vocabulary size = stored size + 1) ...
sparse_columns = [
    SparseFeat(feat, vocabulary_size=vocab_dict[feat] + 1, embedding_dim=embedding_dim)
    for feat in vocab_dict
]
# ... one scalar column per dense feature ...
dense_columns = [DenseFeat(feat, 1) for feat in dense_features]
# ... and two extra id columns with a 200000-row vocabulary for the user embeddings.
user_embedding_columns = [
    SparseFeat(column_name, 200000, embedding_dim=embedding_dim)
    for column_name in ('user_embedding_normal', 'user_embedding_adjust')
]

fixlen_feature_columns = sparse_columns + dense_columns + user_embedding_columns

# All fixed-length columns feed the DNN; the flat list of their names drives input selection.
# (The VarLenSparseFeat history columns hist_feedid / hist_authorid / hist_bgm_song_id /
# hist_bgm_singer_id and their *_seq_len inputs are disabled in this run.)
dnn_feature_columns = fixlen_feature_columns
feature_names = get_feature_names(dnn_feature_columns)

# Feature names that are not read from the evaluation frame: history / sequence-length
# inputs (matched by substring) plus the names listed here.
derived_inputs = (
    'feed_emb', 'feed_embedding', 'user_embedding_normal', 'user_embedding_adjust',
    'hist_tagids', 'hist_keyids', 'tagids', 'keyids',
)
frame_columns = [
    name for name in feature_names
    if not ('hist' in name or 'seq_len' in name or name in derived_inputs)
]

if ONLINE_FLAG:
    # Online mode: inputs for the test frame.
    test_model_input = {name: test[name] for name in frame_columns}
    # Both user-embedding features are looked up with the userid column.
    test_model_input.update(
        user_embedding_normal=test_model_input['userid'],
        user_embedding_adjust=test_model_input['userid'],
    )
else:
    # Offline mode: inputs, user ids and per-task labels for the validation frame.
    val_model_input = {name: val[name] for name in frame_columns}
    val_model_input.update(
        user_embedding_normal=val_model_input['userid'],
        user_embedding_adjust=val_model_input['userid'],
    )
    userid_list = val['userid'].astype(str).tolist()
    val_labels = [val[action].values for action in target]

print('use features: ', feature_names)


use features:  ['bgm_song_id', 'bgm_singer_id', 'userid', 'feedid', 'authorid', 'device', 'videoplayseconds', 'user_embedding_normal', 'user_embedding_adjust']


In [5]:
1

1

In [None]:
# History-sequence features are disabled in this run.
# Earlier setting: ['feedid', 'authorid', 'bgm_singer_id', 'bgm_song_id']
hist_features = []

# One binary task per target column.
train_model = MMOE(dnn_feature_columns, history_feature_list=hist_features, num_tasks=len(target), num_experts=6,
                   expert_dim=128, dnn_hidden_units=(512, 256, 64),
                   task_dnn_units=(128, 128, 64),
                   tasks=['binary'] * len(target), device=device)

# Replace the two user-embedding tables with the pre-trained matrices (still trainable).
train_model.embedding_dict['user_embedding_normal'] = nn.Embedding.from_pretrained(user_emb1, freeze=False)
train_model.embedding_dict['user_embedding_adjust'] = nn.Embedding.from_pretrained(user_emb2, freeze=False)

# Compile only AFTER the embeddings are swapped in. A torch optimizer keeps references to
# the parameters that exist when it is constructed, so compiling first would leave the
# pre-trained tables out of the optimizer and they would never be updated despite
# freeze=False.
# NOTE(review): assumes compile() builds the optimizer from the model's current
# parameters, as deepctr_torch's BaseModel.compile does — confirm against MMOE.
train_model.compile("adagrad", loss='binary_crossentropy')

# Stream the training CSV chunk by chunk and fit the model on every chunk.
# History-sequence inputs and per-chunk preprocessing (feed merge, fillna / min-max scaling
# of dense features, userid encoding) are disabled in this run.
loop = True
cnt = 0
while loop:
    # Guard only the read: the reader raises StopIteration when the file is exhausted.
    # Keeping fit/predict/evaluate outside the try means a StopIteration escaping from
    # them is no longer mistaken for end-of-data.
    try:
        train = data.get_chunk(2000*10000)
    except StopIteration:
        loop = False
        print('Finished all train')
        continue

    # Count and announce only chunks that actually exist.
    cnt += 1
    print('chunk: ', cnt)

    # Inputs read straight from the chunk; history / sequence-length features and the
    # listed derived names are skipped.
    train_model_input = {
        name: train[name] for name in feature_names
        if 'hist' not in name and 'seq_len' not in name
        and name not in ['feed_emb', 'hist_tagids', 'hist_keyids', 'feed_embedding',
                         'user_embedding_normal', 'user_embedding_adjust', 'tagids', 'keyids']
    }
    # Both user-embedding features are looked up with the userid column.
    train_model_input['user_embedding_normal'] = train_model_input['userid']
    train_model_input['user_embedding_adjust'] = train_model_input['userid']
    train_labels = train[target].values

    for epoch in range(epochs):
        history = train_model.fit(train_model_input, train_labels,
                                  batch_size=batch_size, epochs=1, verbose=1, shuffle=True)
    if not ONLINE_FLAG:
        val_pred_ans = train_model.predict(val_model_input, batch_size=batch_size * 4)
        # Per the original note, predict() returns an array shaped (?, 4), unlike the TF
        # version of mmoe, hence the transpose() below.
        # NOTE(review): with seven targets the second axis is presumably 7 — confirm.
        evaluate_deepctr(val_labels, val_pred_ans.transpose(), userid_list, target)

# Continue training for one more pass on the most recent data
# del data, train
# data = pd.read_csv(FEATURE_PATH + '/online_train_100.csv', iterator=True)
# train = data.get_chunk(4000*10000)
# loop = True
# cnt = 0
# while loop:
#     try:
#         cnt += 1
#         print('chunk: ', cnt)
#         train = data.get_chunk(2000*10000)
        
# #         train = train.merge(feed[['feedid', 'authorid', 'videoplayseconds', 'bgm_song_id', 'bgm_singer_id']], how='left',
# #                   on='feedid')
# #         train[dense_features] = train[dense_features].fillna(0, )
# #         train[dense_features] = mms.fit_transform(train[dense_features])
# #         train['userid'] = userid_map.transform(train['userid'])                

#         train_model_input = {name: train[name] for name in feature_names  if 'hist' not in name and 'seq_len'not in name and name not in ['feed_emb', 'hist_tagids', 'hist_keyids', 'feed_embedding', 'user_embedding_normal', 'user_embedding_adjust', 'tagids', 'keyids']}
# #         train_model_input['hist_feedid'] = hist_data['hist_feedid'][train_model_input['userid']][:, :max_hist_seq_len]
# #         train_model_input['hist_authorid'] = hist_data['hist_authorid'][train_model_input['userid']][:, :max_hist_seq_len]
# #         train_model_input['hist_bgm_song_id'] = hist_data['hist_bgm_song_id'][train_model_input['userid']][:, :max_hist_seq_len]
# #         train_model_input['hist_bgm_singer_id'] = hist_data['hist_bgm_singer_id'][train_model_input['userid']][:, :max_hist_seq_len]
        
# #         if ONLINE_FLAG:
# #             train_model_input['test_seq_len'] = hist_data['test_seq_len'][train_model_input['userid']]    
# # #             train_model_input['test_seq_len'] = train['real_seq_len'].values   

# #         else:
# #             train_model_input['val_seq_len'] = hist_data['val_seq_len'][train_model_input['userid']]    
# # #             train_model_input['val_seq_len'] = train['real_seq_len'].values 

#         train_model_input['user_embedding_normal'] = train_model_input['userid']
#         train_model_input['user_embedding_adjust'] = train_model_input['userid']
#         train_labels = train[target].values

#         for epoch in range(epochs):
#             history = train_model.fit(train_model_input, train_labels,
#                               batch_size=batch_size, epochs=1, verbose=1, shuffle=True)
#         if not ONLINE_FLAG:
#             val_pred_ans = train_model.predict(val_model_input, batch_size=batch_size * 4)
#             # 模型predict()返回值格式为(?, 4)，与tf版mmoe不同。因此下方用到了transpose()进行变化。
#             evaluate_deepctr(val_labels, val_pred_ans.transpose(), userid_list, target)

#     except StopIteration:
#         loop=False
#         print('Finished all train')


if ONLINE_FLAG:
    # Predict all targets for the test set and report wall-clock latency.
    t1 = time()
    pred_ans = train_model.predict(test_model_input, batch_size=batch_size * 20)
    # Transpose so that pred_ans[i] holds the predictions for target[i].
    pred_ans = pred_ans.transpose()
    t2 = time()
    print('7个目标行为%d条样本预测耗时（毫秒）：%.3f' % (len(test), (t2 - t1) * 1000.0))
    # Average latency in milliseconds per 2000 samples per target.
    ts = (t2 - t1) * 1000.0 * 2000.0 / (len(test) * 7.0)
    print('7个目标行为2000条样本平均预测耗时（毫秒）：%.3f' % ts)

    # 5. Build the submission file.
    for i, action in enumerate(target):
        test[action] = pred_ans[i]
    # Decode user ids from the encoded ids held by the model input rather than from
    # test['userid'] itself: decoding the column in place from its own values is not
    # idempotent, and running a prediction cell a second time then fails with
    # "y contains previously unseen labels".
    # NOTE(review): relies on test_model_input['userid'] still holding the encoded ids
    # after test['userid'] is reassigned — confirm for the pandas version in use.
    test['userid'] = userid_map.inverse_transform(test_model_input['userid'])
    test[['userid', 'feedid'] + target].to_csv(SUBMIT_DIR + '/mmoe_cin_multi_all_1024_1.csv', index=None, float_format='%.6f')
    print('to_csv ok')

sparse_features:  [SparseFeat(name='bgm_song_id', vocabulary_size=25160, embedding_dim=50, use_hash=False, dtype='int32', embedding_name='bgm_song_id', group_name='default_group'), SparseFeat(name='bgm_singer_id', vocabulary_size=17501, embedding_dim=50, use_hash=False, dtype='int32', embedding_name='bgm_singer_id', group_name='default_group'), SparseFeat(name='userid', vocabulary_size=200000, embedding_dim=50, use_hash=False, dtype='int32', embedding_name='userid', group_name='default_group'), SparseFeat(name='feedid', vocabulary_size=112873, embedding_dim=50, use_hash=False, dtype='int32', embedding_name='feedid', group_name='default_group'), SparseFeat(name='authorid', vocabulary_size=18790, embedding_dim=50, use_hash=False, dtype='int32', embedding_name='authorid', group_name='default_group'), SparseFeat(name='device', vocabulary_size=4, embedding_dim=50, use_hash=False, dtype='int32', embedding_name='device', group_name='default_group'), SparseFeat(name='user_embedding_normal', vo

1086it [01:30, 13.07it/s]

Please check the latest version manually on https://pypi.org/project/deepctr-torch/#history


19532it [26:13, 12.41it/s]


Epoch 1/1
1573s - loss:  0.2592
chunk:  2
cuda
Train on 20000000 samples, validate on 0 samples, 19532 steps per epoch


10292it [13:50, 13.02it/s]

In [7]:
# ORI:0.2309 0.671505

#0.672874: cin 256 128 64
# att_embedding_size=30
#         att_head_num=10
#         att_layer_num=3
# with bi added: 674669; training a bit more: 0.676041
# one extra training pass on days 12-13: 0.674816


1

In [7]:
# Drop the previous chunk reader and the last chunk, then reopen the online training
# CSV as a fresh chunk reader positioned at the start of the file.
del data
del train
data = pd.read_csv(FEATURE_PATH + '/online_train_100.csv', iterator=True)


In [8]:
# Read two consecutive chunks (10,000,000 rows, then 40,000,000 rows). The first result
# is overwritten immediately, so together the reads advance the reader by 50,000,000
# rows and leave the second chunk in `train`.
for rows_to_read in (1000*10000, 4000*10000):
    train = data.get_chunk(rows_to_read)

In [9]:
# start: 3000
# Fit the model on whatever rows remain in the `data` chunk reader, chunk by chunk.
# (Disabled alternative: reopen FEATURE_PATH + '/offline_train_100.csv' here.)
# History-sequence inputs and per-chunk preprocessing (feed merge, fillna / min-max scaling
# of dense features, userid encoding) are disabled in this run.
loop = True
cnt = 0
while loop:
    # Guard only the read: the reader raises StopIteration when the file is exhausted.
    # Keeping fit/predict/evaluate outside the try means a StopIteration escaping from
    # them is no longer mistaken for end-of-data.
    try:
        train = data.get_chunk(2000*10000)
    except StopIteration:
        loop = False
        print('Finished all train')
        continue

    # Count and announce only chunks that actually exist.
    cnt += 1
    print('chunk: ', cnt)

    # Inputs read straight from the chunk; history / sequence-length features and the
    # listed derived names are skipped.
    train_model_input = {
        name: train[name] for name in feature_names
        if 'hist' not in name and 'seq_len' not in name
        and name not in ['feed_emb', 'hist_tagids', 'hist_keyids', 'feed_embedding',
                         'user_embedding_normal', 'user_embedding_adjust', 'tagids', 'keyids']
    }
    # Both user-embedding features are looked up with the userid column.
    train_model_input['user_embedding_normal'] = train_model_input['userid']
    train_model_input['user_embedding_adjust'] = train_model_input['userid']
    train_labels = train[target].values

    for epoch in range(epochs):
        history = train_model.fit(train_model_input, train_labels,
                                  batch_size=batch_size, epochs=1, verbose=1, shuffle=True)
    if not ONLINE_FLAG:
        val_pred_ans = train_model.predict(val_model_input, batch_size=batch_size * 4)
        # Per the original note, predict() returns an array shaped (?, 4), unlike the TF
        # version of mmoe, hence the transpose() below.
        # NOTE(review): with seven targets the second axis is presumably 7 — confirm.
        evaluate_deepctr(val_labels, val_pred_ans.transpose(), userid_list, target)

if ONLINE_FLAG:
    # Predict all targets for the test set and report wall-clock latency.
    t1 = time()
    pred_ans = train_model.predict(test_model_input, batch_size=batch_size * 20)
    # Transpose so that pred_ans[i] holds the predictions for target[i].
    pred_ans = pred_ans.transpose()
    t2 = time()
    print('7个目标行为%d条样本预测耗时（毫秒）：%.3f' % (len(test), (t2 - t1) * 1000.0))
    # Average latency in milliseconds per 2000 samples per target.
    ts = (t2 - t1) * 1000.0 * 2000.0 / (len(test) * 7.0)
    print('7个目标行为2000条样本平均预测耗时（毫秒）：%.3f' % ts)

    # 5. Build the submission file.
    for i, action in enumerate(target):
        test[action] = pred_ans[i]
    # Decode user ids from the encoded ids held by the model input rather than from
    # test['userid'] itself. Decoding the column in place from its own values is not
    # idempotent: once an earlier run has decoded it, a second inverse_transform raises
    # "ValueError: y contains previously unseen labels" (as recorded in this cell's output).
    # NOTE(review): relies on test_model_input['userid'] still holding the encoded ids
    # after test['userid'] is reassigned — confirm for the pandas version in use.
    test['userid'] = userid_map.inverse_transform(test_model_input['userid'])
    test[['userid', 'feedid'] + target].to_csv(SUBMIT_DIR + '/mmoe_cin_multi_all_1024_2.csv', index=None, float_format='%.6f')
    print('to_csv ok')

chunk:  1
cuda
Train on 20000000 samples, validate on 0 samples, 19532 steps per epoch


19532it [26:14, 12.40it/s]


Epoch 1/1
1574s - loss:  0.2139
chunk:  2
cuda
Train on 3175511 samples, validate on 0 samples, 3102 steps per epoch


3102it [04:03, 12.76it/s]


Epoch 1/1
243s - loss:  0.2131
chunk:  3
Finished all train
7个目标行为4252097条样本预测耗时（毫秒）：58330.128
7个目标行为2000条样本平均预测耗时（毫秒）：3.919


ValueError: y contains previously unseen labels: [200001 200004 200005 ... 250242 250243 250244]

In [10]:
# Write the id columns and the seven prediction columns of `test` to the submission CSV.
# This repeats the to_csv call of the preceding cell, which in the recorded run raised a
# ValueError before reaching its own to_csv.
submit_columns = ['userid', 'feedid'] + target
test[submit_columns].to_csv(
    SUBMIT_DIR + '/mmoe_cin_multi_all_1024_2.csv',
    index=None,
    float_format='%.6f',
)


In [7]:
1

1