In [3]:
import time, math, os
from tqdm import tqdm
import gc
import pickle
import random
from datetime import datetime
from operator import itemgetter
import numpy as np
import pandas as pd
import warnings
from collections import defaultdict
import collections
warnings.filterwarnings('ignore')

In [4]:
data_path = '../data/data_raw/'
save_path = '../data/tmp_results/'

In [5]:
# 节约内存的一个标配函数
def reduce_mem(df):
    starttime = time.time()
    numerics = ['int16', 'int32', 'int64', 'float16', 'float32', 'float64']
    start_mem = df.memory_usage().sum() / 1024**2
    for col in df.columns:
        col_type = df[col].dtypes
        if col_type in numerics:
            c_min = df[col].min()
            c_max = df[col].max()
            if pd.isnull(c_min) or pd.isnull(c_max):
                continue
            if str(col_type)[:3] == 'int':
                if c_min > np.iinfo(np.int8).min and c_max < np.iinfo(np.int8).max:
                    df[col] = df[col].astype(np.int8)
                elif c_min > np.iinfo(np.int16).min and c_max < np.iinfo(np.int16).max:
                    df[col] = df[col].astype(np.int16)
                elif c_min > np.iinfo(np.int32).min and c_max < np.iinfo(np.int32).max:
                    df[col] = df[col].astype(np.int32)
                elif c_min > np.iinfo(np.int64).min and c_max < np.iinfo(np.int64).max:
                    df[col] = df[col].astype(np.int64)
            else:
                if c_min > np.finfo(np.float16).min and c_max < np.finfo(np.float16).max:
                    df[col] = df[col].astype(np.float16)
                elif c_min > np.finfo(np.float32).min and c_max < np.finfo(np.float32).max:
                    df[col] = df[col].astype(np.float32)
                else:
                    df[col] = df[col].astype(np.float64)
    end_mem = df.memory_usage().sum() / 1024**2
    print('-- Mem. usage decreased to {:5.2f} Mb ({:.1f}% reduction),time spend:{:2.2f} min'.format(end_mem,
                                                                                                           100*(start_mem-end_mem)/start_mem,
                                                                                                           (time.time()-starttime)/60))
    return df

In [6]:
# debug模式：从训练集中划出一部分数据来调试代码
def get_all_click_sample(data_path, sample_nums=10000):
    """
        训练集中采样一部分数据调试
        data_path: 原数据的存储路径
        sample_nums: 采样数目（这里由于机器的内存限制，可以采样用户做）
    """
    file_path = os.path.join(data_path, 'train_click_log.csv')
    all_click = pd.read_csv(file_path)
    all_user_ids = all_click.user_id.unique()

    sample_user_ids = np.random.choice(all_user_ids, size=sample_nums, replace=False) 
    all_click = all_click[all_click['user_id'].isin(sample_user_ids)]
    
    all_click = all_click.drop_duplicates((['user_id', 'click_article_id', 'click_timestamp']))
    return all_click

# 读取点击数据，这里分成线上和线下，如果是为了获取线上提交结果应该讲测试集中的点击数据合并到总的数据中
# 如果是为了线下验证模型的有效性或者特征的有效性，可以只使用训练集
def get_all_click_df(data_path, offline=True):
    if offline:
        file_path = os.path.join(data_path, 'train_click_log.csv')
        all_click = pd.read_csv(file_path)
    else:
        file_path = os.path.join(data_path, 'train_click_log.csv')
        trn_click = pd.read_csv(file_path)
        file_path = os.path.join(data_path, 'testA_click_log.csv')
        tst_click = pd.read_csv(file_path)

        all_click = pd.concat([trn_click, tst_click], axis=0)
    
    all_click = all_click.drop_duplicates((['user_id', 'click_article_id', 'click_timestamp']))
    return all_click

In [7]:
# 全量训练集
all_click_df = get_all_click_df(data_path, offline=False)

In [8]:
emb_df = pd.read_csv(data_path + 'articles_emb.csv')
emb_df.head()

Unnamed: 0,article_id,emb_0,emb_1,emb_2,emb_3,emb_4,emb_5,emb_6,emb_7,emb_8,...,emb_240,emb_241,emb_242,emb_243,emb_244,emb_245,emb_246,emb_247,emb_248,emb_249
0,0,-0.161183,-0.957233,-0.137944,0.050855,0.830055,0.901365,-0.335148,-0.559561,-0.500603,...,0.321248,0.313999,0.636412,0.169179,0.540524,-0.813182,0.28687,-0.231686,0.597416,0.409623
1,1,-0.523216,-0.974058,0.738608,0.155234,0.626294,0.485297,-0.715657,-0.897996,-0.359747,...,-0.487843,0.823124,0.412688,-0.338654,0.320786,0.588643,-0.594137,0.182828,0.39709,-0.834364
2,2,-0.619619,-0.97296,-0.20736,-0.128861,0.044748,-0.387535,-0.730477,-0.066126,-0.754899,...,0.454756,0.473184,0.377866,-0.863887,-0.383365,0.137721,-0.810877,-0.44758,0.805932,-0.285284
3,3,-0.740843,-0.975749,0.391698,0.641738,-0.268645,0.191745,-0.825593,-0.710591,-0.040099,...,0.271535,0.03604,0.480029,-0.763173,0.022627,0.565165,-0.910286,-0.537838,0.243541,-0.885329
4,4,-0.279052,-0.972315,0.685374,0.113056,0.238315,0.271913,-0.568816,0.341194,-0.600554,...,0.238286,0.809268,0.427521,-0.615932,-0.503697,0.61445,-0.91776,-0.424061,0.185484,-0.580292


In [11]:
# convert to dict
emb_df['emb'] = emb_df.loc[:,[f'emb_{i}' for i in range(250)]].apply(lambda row: row.values, axis=1)
emb_dict = dict(zip(emb_df['article_id'], emb_df['emb']))
embeddings = np.array(list(emb_dict.values()))
# emb_dict

In [None]:
# brute-force appoach to find most sim items
most_sim_10_items = {}
for item, emb in tqdm(emb_dict.items()):
    if item not in most_sim_10_items:
        most_sim_10_items[item] = {}
    item_index = []
    sim_values = []
    for relate_item, sim in emb_dict.items():
        if item != relate_item and relate_item not in emb_dict:
            continue
        item_index.append(relate_item)
        sim_values.append(np.dot(emb, emb_dict[relate_item]))
    top_10_index = np.argsort(sim_values)[-11:-2]   # ignore the item itself
    # print(top_10_index)
    for i in top_10_index:
        # print(item_index[i], sim_values[i])
        most_sim_10_items[item][item_index[i]] = sim_values[i]

# save most sim items
pickle.dump(most_sim_10_items, open(save_path + 'emb_most_sim_10_item.pkl', 'wb'))

  0%|          | 1/364047 [00:00<63:27:26,  1.59it/s]


KeyboardInterrupt: 

In [None]:
# using faiss to find most sim items
import faiss

True

In [66]:
# basic search
d = embeddings.shape[1]
index = faiss.IndexFlatIP(d)
index.is_trained
index.add(embeddings)
index.ntotal
k = 11
D, I = index.search(embeddings, k)
print(I)

[[ 84015  83447  84091 ...  84386  77411  77258]
 [     1 109065 301649 ...     13  10791 244380]
 [  3653      2   3424 ...   1899   3075   2546]
 ...
 [ 43158 364044  43179 ...  43216  41252  43213]
 [103291 105536 102522 ... 108368  89427  89391]
 [364046 364043 346554 ... 163794 292350 341338]]


In [43]:
# IVF
nlist = 100
quantizer = faiss.IndexFlatIP(d)
index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_INNER_PRODUCT)
index.train(embeddings)

index.add(embeddings)
index.nprobe = 10
k = 10
D, I = index.search(embeddings, k)
print(I)

[[ 84015  83447  84091 ...  77974  84386  77411]
 [     1 109065 301649 ... 109650     13  10791]
 [  3653      2   3424 ...   2469   1899   3075]
 ...
 [ 43158 364044  43179 ...  43167  43216  41252]
 [103291 105536 102522 ...  89806 108368  89427]
 [364046 364043 346554 ... 346834 163794 292350]]


In [44]:
I[0][0]

84015

In [88]:
# IVF + PQ
nlist = 1000
m = 10 # number of subquantizers
nbits = 8 # number of bit per subquantizer index

quantizer = faiss.IndexFlatIP(d)
index = faiss.IndexIVFPQ(quantizer, d, nlist, m, nbits)

index.train(embeddings)
index.add(embeddings)

index.nprobe = 50

k = 11
D, I = index.search(embeddings, k)
print(I)

[[     0  77610  77983 ...  78227  83772  81169]
 [     1 301649   7126 ... 117628 178117 169338]
 [     2   3782   2903 ...   3849   3414   2038]
 ...
 [364044  40071  38759 ... 363470  43208 363497]
 [364045  61563  78335 ... 278633  89340 105962]
 [364046 364043 347855 ... 292187 316253 291412]]


In [74]:
def save_dict(I, D, file):

    sim_dict = {}
    for i, items in enumerate(I):
        count = 0
        for item, sim in zip(items, D[i]):
            if i not in sim_dict:
                sim_dict[i] = {}
            if item == i:
                continue
            sim_dict[i][item] = sim
            count += 1
            if count >= 10:
                break
    
    with open(file, 'wb') as f:
        pickle.dump(sim_dict, f, pickle.HIGHEST_PROTOCOL)

    return sim_dict

In [None]:
file_path = save_path + 'basic_most_sim_10_item.pkl'
sim_dict = save_dict(I, D, file_path)
sim_dict

{0: {84015: 63.86139,
  83447: 63.69615,
  84091: 63.388916,
  84074: 63.286636,
  78194: 63.061684,
  84619: 62.772038,
  84918: 62.4016,
  77974: 62.35498,
  84386: 62.300327,
  77411: 62.003044},
 1: {109065: 62.639038,
  301649: 57.690563,
  204342: 57.394226,
  109078: 57.127018,
  5431: 56.29313,
  326327: 55.773808,
  109650: 55.75948,
  13: 55.724655,
  10791: 55.649628,
  244380: 54.51692},
 2: {3653: 59.431927,
  3424: 58.816536,
  3795: 58.46267,
  4090: 58.19672,
  2555: 58.090935,
  2038: 58.043365,
  2469: 57.9605,
  1899: 57.853043,
  3075: 57.847996,
  2546: 57.755383},
 3: {1415: 67.33089,
  1702: 65.61456,
  4424: 65.01054,
  1383: 64.34637,
  1385: 64.314186,
  1350: 63.842865,
  1338: 63.6408,
  533: 63.501076,
  1410: 63.41412,
  262: 63.190212},
 4: {15: 62.20402,
  2567: 62.16925,
  2220: 60.344036,
  4032: 59.223114,
  3597: 58.980804,
  783: 57.74648,
  13: 57.184135,
  19: 57.122868,
  1271: 56.79103,
  2469: 56.409058},
 5: {13: 61.209484,
  1728: 60.59404,
 

In [89]:
file_path = save_path + 'ivfpq_most_sim_10_item.pkl'
sim_dict = save_dict(I, D, file_path)
sim_dict

{0: {77610: 12.844373,
  77983: 13.769381,
  77608: 14.137655,
  84006: 14.246804,
  83909: 14.333242,
  77712: 14.352239,
  83905: 14.406701,
  78227: 14.463646,
  83772: 14.93431,
  81169: 15.026012},
 1: {301649: 14.213521,
  7126: 16.98061,
  240266: 17.045528,
  7210: 17.189247,
  303003: 17.52216,
  304597: 17.896921,
  301656: 18.095772,
  117628: 18.098618,
  178117: 18.56564,
  169338: 18.592495},
 2: {3782: 16.145462,
  2903: 16.391693,
  3193: 16.62664,
  83: 17.039663,
  3191: 17.085938,
  2033: 17.169102,
  2832: 17.856611,
  3849: 17.928402,
  3414: 17.954882,
  2038: 18.113832},
 3: {1030: 16.921066,
  111: 18.601122,
  865: 18.925396,
  527: 19.091545,
  373: 19.147224,
  890: 19.230827,
  1304: 19.406044,
  1415: 19.623549,
  559: 19.68724,
  1195: 19.741676},
 4: {2220: 17.506964,
  4309: 17.59946,
  7: 18.370438,
  25: 18.45097,
  4005: 18.504484,
  2081: 18.802725,
  4312: 19.467976,
  18: 19.472137,
  1939: 20.260403,
  5: 20.473534},
 5: {15: 23.96265,
  1728: 24.