In [1]:
import sys
sys.path.append(".")
sys.path.append("..")
import Config
from Bias_DNN import DNN
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
import random
from utils import *

def setup_seed(seed):
    """Seed every random number generator used by this notebook so runs are repeatable.

    Seeds Python's ``random`` module, NumPy's legacy global RNG and torch (the CPU
    generator plus all CUDA devices), then asks cuDNN to pick deterministic kernels.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

setup_seed(114514)

# Load the precomputed feature tables.
test_user_feature = pd.read_csv('../feature/test_user_feature.csv')
item_feature = pd.read_csv('../feature/item_feature.csv')

# Load the train / test samples.
train = pd.read_csv('../data/train_data.csv')
test = pd.read_csv('../data/test_data.csv')

# Columns listed in Config.count_feature are stored in the CSVs as stringified Python
# expressions (presumably lists of ids — confirm against the feature-generation code);
# turn them back into Python objects, frame by frame, for each such column.
# NOTE(review): eval() executes whatever text is in the CSV. If these cells only ever
# hold plain literals, ast.literal_eval would be a safer drop-in — confirm before switching.
for fea in Config.count_feature:
    for frame in (test_user_feature, train, test):
        frame[fea] = list(map(eval, frame[fea]))

# Labels as float tensors (the model's loss is BCEWithLogitsLoss per its printed repr).
train_y = torch.Tensor(train['label'].tolist())
test_y = torch.Tensor(test['label'].tolist())

### time_diff is deliberately NOT used as a bias feature: trying it lowered AUC substantially.

In [14]:
# Step 1: train the ranking model on every feature group from Config.
# hidden_unit=[] requests no hidden layers and Bias_Features=[] disables the bias tower;
# the printed architecture of this run shows a single Linear layer and an empty bias_fc.
ranking_features = (
    Config.item_feature
    + Config.count_feature
    + Config.user_feature
    + Config.match_feature
)
model = DNN(
    Features=ranking_features,
    hidden_unit=[],
    Bias_Features=[],
    bias_unit=[],
)

model.fit(train, train_y, test, test_y, epoch=5, batch_size=1024, lr=5e-2)

[2022-02-24 17:57:26.671328]start fit model


In [9]:
# Freeze the trained model: from here on it is only used to score candidates.
model.requires_grad_(False)  # sets requires_grad=False on every parameter
# Switch to inference mode as well, so layers that behave differently in training
# (dropout, batch norm) cannot perturb the ranking scores; a no-op for layers without
# such behaviour.
model.eval()

print(model)

DNN(
  (embs): Normal_Embedding(
    (embs): ModuleDict(
      (click_article_id): Embedding(31116, 32)
      (category_id): Embedding(290, 4)
      (count_feature): Embedding(31117, 32, padding_idx=31116)
      (time_diff): Embedding(4, 1)
    )
  )
  (fc): ModuleList(
    (0): Linear(in_features=165, out_features=1, bias=True)
  )
  (bias_fc): ModuleList()
  (ce): BCEWithLogitsLoss()
)


In [10]:
# Parse JSON
import json
def read_json(path):
    """Load and return the JSON document stored at ``path``.

    The file is read as UTF-8 (the encoding JSON files use) regardless of the platform's
    locale, and the handle is closed before returning instead of being left open until
    garbage collection.

    Args:
        path: path of the JSON file.

    Returns:
        The decoded JSON value (for the recall files: a dict of lists).
    """
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)

# Load the recall (candidate-generation) output: a dict keyed by user id as a string,
# whose values are that user's recalled item ids. Single-channel DSSM recall is used;
# the MF recall file can be swapped in via the commented-out line.
# recall_data = read_json('../recall/recall_data/MF_data.json')
recall_data = read_json('../recall/recall_data/DSSM_data.json')
# The per-user count is read from the first user only.
first_user_key = list(recall_data)[0]
timelogger(f"all users:{len(recall_data)}, recall nums per user:{len(recall_data[first_user_key])}")

[2022-02-24 16:40:16.717094]all users:50000, recall nums per user:300


In [12]:
# Ranking: score a user's recalled candidates and keep the top_k highest-scoring ones.
def get_rec_seq(user_id, top_k):
    """Score and rank the recalled candidates of a single test user.

    Args:
        user_id: the user's id; looked up in ``recall_data`` as ``str(user_id)``.
        top_k: number of highest-scoring candidates to return.

    Returns:
        List of at most ``top_k`` ``(item_id, score)`` tuples, highest score first.

    Raises:
        ValueError: if the model returns a different number of scores than there are
            candidate rows.
    """
    recalled = set(recall_data[str(user_id)])  # duplicates in the recall list collapse
    # Build the candidate frame in ascending item-id order and take the ids FROM that
    # frame. Pairing scores with the raw recall list instead would shift every later
    # score onto the wrong item whenever a recalled id is absent from item_feature,
    # duplicated, or item_feature is not itself sorted by id.
    cand_df = (
        item_feature[item_feature['click_article_id'].isin(recalled)]
        .sort_values('click_article_id', kind='stable')
    )
    items = cand_df['click_article_id'].tolist()
    fs = merge_feature(
        user_id=user_id,
        user_df=test_user_feature,
        item_df=cand_df,
        use_match_feature=True
    )
    # NOTE(review): assumes merge_feature keeps item_df's row order — confirm in utils.
    scores = model.get_y_pre(fs).detach().cpu().numpy()
    if len(scores) != len(items):
        raise ValueError(
            f"user {user_id}: got {len(scores)} scores for {len(items)} candidates"
        )
    candidates = sorted(zip(items, scores), key=lambda pair: pair[1], reverse=True)
    return candidates[:top_k]
    
from tqdm import tqdm

# Rank every test user's candidates. top_k equals the recall size (300 per user in the
# logged recall data), so the whole ranked list is kept and hit@k can be measured later
# for any k up to that size. Filled incrementally so an interrupted run keeps partial results.
top_k = 300
sort_data = {}
test_users = list(test_user_feature['user_id'])
for uid in tqdm(test_users):
    sort_data[uid] = get_rec_seq(uid, top_k)

100%|██████████| 50000/50000 [23:48<00:00, 34.99it/s]


In [6]:
# Ground truth for evaluation: each test user's last click, found by ordering every
# user's clicks by timestamp and keeping the final row per user.
# NOTE(review): the "Unnamed: 0" column in the output looks like a saved index in test.csv.
test_clicks = pd.read_csv('../data/test.csv')
last_clk = (
    test_clicks
    .sort_values(by=['user_id', 'click_timestamp'])
    .groupby('user_id')
    .tail(1)
)
last_clk

Unnamed: 0,user_id,click_article_id,click_timestamp,click_environment,click_deviceGroup,click_os,click_country,click_region,click_referrer_type
1,1,5881,1508211346889,2,0,5,0,24,5
3,2,15525,1508211468695,2,2,7,0,24,1
5,10,14538,1508211661144,2,0,5,0,24,1
7,11,20952,1508211104535,2,2,7,0,24,1
9,13,24266,1508211458226,2,2,7,0,24,0
...,...,...,...,...,...,...,...,...,...
278863,199978,20188,1507151558596,2,0,5,0,23,1
278886,199982,8743,1508112362342,2,3,7,0,20,0
278888,199988,14725,1507029612008,2,2,7,5,27,1
278893,199990,13205,1507313341592,2,0,5,0,6,0


In [13]:
# Hit@k of the ranking: a test user is a hit at k when the item of their last click
# (from last_clk) appears among the first k items of their ranked candidate list.
sort_dict = {user: [item for item, _score in ranked] for user, ranked in sort_data.items()}

# Position of each user's last-clicked item in that user's ranking (None = not ranked).
# Computed once, so each k below is a cheap count instead of another iterrows() pass.
# A user with no ranking counts as a miss instead of aborting the evaluation with KeyError.
clicked_ranks = []
for user, item in zip(last_clk['user_id'], last_clk['click_article_id']):
    ranked_items = sort_dict.get(user, [])
    try:
        rank = ranked_items.index(item)  # first occurrence, same as `item in list[:k]`
    except ValueError:
        rank = None
    clicked_ranks.append(rank)

n_users = len(last_clk)
ks = [1,3,5,10,20,30,50,100]
for k in ks:
    clk_hit = sum(1 for rank in clicked_ranks if rank is not None and rank < k)
    hit_pct = 100 * clk_hit / n_users if n_users else 0.0
    message = 'k = {}, hit rate: {}/{} = {:.2f}%'.format(k, clk_hit, n_users, hit_pct)
    timelogger(message)

[2022-02-24 17:35:46.591187]k = 1, hit rate: 572/50000 = 1.14%
[2022-02-24 17:35:54.955868]k = 3, hit rate: 1183/50000 = 2.37%
[2022-02-24 17:36:03.760403]k = 5, hit rate: 1629/50000 = 3.26%
[2022-02-24 17:36:12.586419]k = 10, hit rate: 2511/50000 = 5.02%
[2022-02-24 17:36:21.531428]k = 20, hit rate: 4000/50000 = 8.00%
[2022-02-24 17:36:30.324416]k = 30, hit rate: 5471/50000 = 10.94%
[2022-02-24 17:36:39.870582]k = 50, hit rate: 7699/50000 = 15.40%
[2022-02-24 17:36:50.830062]k = 100, hit rate: 12678/50000 = 25.36%
