In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2
# Display the value of every top-level expression in a cell, not only the last one
# (so a cell such as `a.nunique(); b.nunique()` shows both results).
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all" 

数据来源与预处理步骤见 [yooChooseRec.ipynb](./yooChooseRec.ipynb)

本节要解决的问题：已知用户在某个 session 中会发生购买行为，预测该用户具体会购买哪些商品（item）。

In [2]:
import os, gc
import torch
import pickle
from pathlib import Path
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder
# Global RNG seed. Seed every library that may draw random numbers so that a
# Restart & Run All reproduces the same results.
SEED = 123
np.random.seed(SEED)
torch.manual_seed(SEED);  # trailing ';' hides the returned Generator repr

In [3]:
# Directory holding the preprocessed YooChoose files. Set the
# YOOCHOOSE_DATA_DIR environment variable to point elsewhere; otherwise the
# path is resolved relative to the notebook's working directory.
root = Path(os.environ.get('YOOCHOOSE_DATA_DIR', '../data/yoochoose-data/'))

## 数据准备

In [4]:
# Load the preprocessed click log and purchase log (presumably produced by the
# preprocessing in yooChooseRec.ipynb), in that order.
clicks, buys = (
    pd.read_csv(root / csv_name, encoding='utf-8', low_memory=False)
    for csv_name in ('clicks_pro.csv', 'buys_pro.csv')
)

In [5]:
# Distinct-value count per column for the click and purchase tables; both are
# rendered because ast_node_interactivity is set to "all".
clicks.nunique()
buys.nunique()

session_id     4431931
timestamp     24590089
item_id          48255
category           330
label                2
dtype: int64

session_id    377376
timestamp     966877
item_id        19210
price            707
quantity          27
dtype: int64

In [6]:
# Preview the first rows of the click and purchase tables.
clicks.head()
buys.head()

Unnamed: 0,session_id,timestamp,item_id,category,label
0,1,2014-04-07T10:51:09.277Z,1909,0,False
1,1,2014-04-07T10:54:09.868Z,1908,0,False
2,1,2014-04-07T10:54:46.998Z,1910,0,False
3,1,2014-04-07T10:57:00.306Z,9038,0,False
4,2,2014-04-07T13:56:37.614Z,17503,0,False


Unnamed: 0,session_id,timestamp,item_id,price,quantity
0,420374,2014-04-06T18:44:58.314Z,2231,12462,1
1,420374,2014-04-06T18:44:58.325Z,2223,10471,1
2,281626,2014-04-06T09:40:13.032Z,1769,1883,1
3,420368,2014-04-04T06:13:28.848Z,907,6073,1
4,420368,2014-04-04T06:13:28.858Z,39413,2617,1


- 购买记录字典：`session_id → 该 session 中购买的 item_id 列表`

In [7]:
# session_id -> list of item_ids bought in that session (one entry per purchase
# row, so repeated purchases of an item appear repeatedly).
buy_item_dict = buys.groupby('session_id')['item_id'].agg(list).to_dict()

In [8]:
buy_item_dict[420374]  # item_ids bought in session 420374

[2231, 2223]

In [10]:
# Persist the session -> purchased-items mapping next to the data.
# Only unpickle this file from a trusted location (pickle can execute code).
buy_item_dict_path = root / 'buy_item_dict.pkl'
with buy_item_dict_path.open(mode='wb') as pkl_file:
    pickle.dump(buy_item_dict, pkl_file)

## 构造数据集

In [None]:
import torch
from torch_geometric.data import Data, InMemoryDataset
from tqdm.auto import tqdm

class YooChooseDataset(InMemoryDataset):
    """One graph per YooChoose click session, for item-level purchase prediction.

    Nodes are the distinct items clicked in a session (node index = the item's
    per-session LabelEncoder code), directed edges link consecutive clicks, and
    the float node label ``y`` is 1 for items bought in that session, else 0.

    Args:
        root: dataset root directory, passed to ``InMemoryDataset``.
        transform: optional PyG transform, passed to ``InMemoryDataset``.
        pre_transform: optional PyG pre-transform, passed to ``InMemoryDataset``
            (``process`` below does not apply it).
        df: click DataFrame with ``session_id``, ``item_id`` and ``category``
            columns. When omitted, ``process`` falls back to the notebook-global
            ``df`` (original behaviour).
        buy_item_dict: mapping ``session_id -> list of bought item_ids``. When
            omitted, ``process`` falls back to the notebook-global
            ``buy_item_dict``.
    """

    def __init__(self, root, transform=None, pre_transform=None,
                 df=None, buy_item_dict=None):
        # Stored before super().__init__, because the PyG base constructor
        # calls process() when the processed file does not exist yet.
        self._clicks_df = df
        self._buy_item_dict = buy_item_dict
        super(YooChooseDataset, self).__init__(root, transform, pre_transform)
        # NOTE(review): on torch>=2.6 torch.load defaults to weights_only=True,
        # which may reject pickled PyG objects — confirm against the torch version.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        # No raw files: process() works on already-loaded DataFrames.
        return []

    @property
    def processed_file_names(self):
        # NOTE(review): this relative path points outside processed_dir; kept
        # unchanged so existing cached datasets are still found.
        return ['../input/yoochoose_click_binary_100000_sess.dataset']

    def download(self):
        # Nothing to download.
        pass

    def process(self):
        """Build one ``Data`` graph per session, collate them and save to disk."""
        clicks_df = self._clicks_df if self._clicks_df is not None else df
        purchases = (self._buy_item_dict if self._buy_item_dict is not None
                     else buy_item_dict)

        data_list = []
        for session_id, group in tqdm(clicks_df.groupby('session_id')):
            le = LabelEncoder()
            group = group.reset_index(drop=True)
            # Per-session node index 0..k-1 for each distinct item_id.
            group['sess_item_id'] = le.fit_transform(group.item_id)

            # One feature row per node, in node-index order. Deduplicating on
            # sess_item_id (not on item_id+category) guarantees exactly
            # len(le.classes_) rows even if an item's category varies within
            # the session, keeping x aligned with the node indices.
            node_rows = group.drop_duplicates('sess_item_id').sort_values('sess_item_id')
            x = torch.LongTensor(node_rows[['item_id', 'category']].values).unsqueeze(1)

            # Directed edges follow the click sequence: click t -> click t+1.
            sequence = group.sess_item_id.values
            edge_index = torch.tensor(np.stack([sequence[:-1], sequence[1:]]),
                                      dtype=torch.long)

            label = np.zeros(len(x), dtype=np.float32)
            if session_id in purchases:
                # Purchased items never clicked in this session have no node;
                # drop them instead of letting le.transform raise ValueError.
                bought = np.asarray(purchases[session_id])
                bought = bought[np.isin(bought, le.classes_)]
                if bought.size:
                    label[le.transform(bought)] = 1
            y = torch.FloatTensor(label)

            data_list.append(Data(x=x, edge_index=edge_index, y=y))

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])