## Settings

In [None]:
%load_ext autoreload
%autoreload 2

import sys

sys.path.append('..')

# Main

In [1]:
import json
import pickle
from collections import defaultdict

import torch
from torch import nn
from ignite.metrics import TopKCategoricalAccuracy, Loss

In [2]:
batch_size = 2 ** 13

loss_fn = nn.CrossEntropyLoss()
opt_ = torch.optim.Adam
lr = 0.00003
val_metrics = {
        'top-10 acc': TopKCategoricalAccuracy(10),
        'loss': Loss(loss_fn)
        }
device = 'cuda:1'
max_epochs = 1000

## Load Data

In [None]:
import pandas as pd

In [None]:
clicks = pd.read_csv('../data/raw/yoochoose-clicks.dat',
        names=['sess', 'ts', 'item', 'cat'],  dtype={'cat': str},
        usecols=['sess', 'ts', 'item'], header=None)
clicks.head()

## Preprocess

In [None]:
from datetime import datetime as dt, timedelta as td

In [None]:
clicks['ts'] = clicks['ts'].apply(lambda s: dt.strptime(s[:19], '%Y-%m-%dT%H:%M:%S'))
clicks.head()

In [None]:
valsplitdate = max(clicks['ts']) - td(20)
testsplitdate = max(clicks['ts']) - td(1)
item_count = clicks['item'].value_counts()

In [None]:
remain_sess = []
remain_item = set()
for _, group in clicks.groupby('sess', sort=False):
    print(group.iat[0, 0], end='\r')
    gi = group['item'].tolist()
    n = len(gi)
    stop = False
    if n > 1:
        for item in gi:
            if item_count[item] < 5:
                stop = True
                break
    else:
        stop = True
    if not stop:
        remain_sess.append((str(group.iat[0, 0]), group.iat[0, 1], gi))
        for item in gi:
            remain_item.add(item)
with open('../data/interim/n_items.json', 'w') as f:
    json.dump(len(remain_item), f)

In [None]:
from sklearn.preprocessing import LabelEncoder

remain_item = list(remain_item)
item_enc = LabelEncoder()
item_enc.fit(remain_item)

In [None]:
train_d = defaultdict(list)
val_d = defaultdict(list)
test_d = defaultdict(list)
for sess, ts, items in remain_sess:
    print(sess, end='\r')
    items = item_enc.transform(items).tolist()
    if ts < valsplitdate:
        for i in range(1, len(items)):
            train_d[sess].append((items[: i], items[i]))
    elif ts < testsplitdate:
        for i in range(1, len(items)):
            val_d[sess].append((items[: i], items[i]))
    else:
        for i in range(1, len(items)):
            test_d[sess].append((items[: i], items[i]))
with open('../data/interim/train.json', 'w') as f:
    json.dump(train_d, f)
with open('../data/interim/val.json', 'w') as f:
    json.dump(val_d, f)
with open('../data/interim/test.json', 'w') as f:
    json.dump(test_d, f)

## Prepare Input Data

In [3]:
from math import floor
from torch_geometric.data import Data, Dataset, DataLoader

In [4]:
class YooChooseDataset(Dataset):
    def __init__(self, d):
        super(YooChooseDataset, self).__init__()
        self.samples = self.add_from_dict(d)
        
    def add_from_dict(self, d):
        samples = []
        for dd in d.values():
            for x_ids, y in dd:
                x_ids_ = set(x_ids)
                x_ids_.remove(x_ids[- 1])
                x_ids_ = list(x_ids_) + [x_ids[- 1]]
                x_dict = {x_id: i for i, x_id in enumerate(x_ids_)}
                x = [[x_id] for x_id in x_ids_]
                edge_dict = defaultdict(lambda: defaultdict(int))
                for i in range(len(x_ids) - 1):
                    edge_dict[x_dict[x_ids[i]]][x_dict[x_ids[i + 1]]] += 1
                edge_index, edge_weights = [], []
                for o in edge_dict.keys():
                    s = sum(edge_dict[o].values())
                    for d in edge_dict[o].keys():
                        edge_index.append([o, d])
                        edge_weights.append(edge_dict[o][d] / s)
                x = torch.tensor(x, dtype=torch.long)
                edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
                edge_weights = torch.tensor(edge_weights)
                samples.append((Data(x, edge_index=edge_index, edge_weights=edge_weights), y))
        return samples
                
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        return {
                'graph': self.samples[idx][0],
                'label': self.samples[idx][1]
                }

In [None]:
with open('../data/interim/n_items.json', 'r') as f:
    n_items = json.load(f)
with open('../data/interim/train.json', 'r') as f:
    train_d = json.load(f)
with open('../data/interim/val.json', 'r') as f:
    val_d = json.load(f)
with open('../data/interim/test.json', 'r') as f:
    test_d = json.load(f)
print('# items: {}\n# train sessions: {}\n# val sessions: {}\n# test sessions: {}'
        .format(n_items, len(train_d), len(val_d), len(test_d)))

In [None]:
train_dataset = YooChooseDataset(train_d)
val_dataset = YooChooseDataset(val_d)
test_dataset = YooChooseDataset(test_d)
print('# train samples: {}\n# val samples: {}\n# test samples: {}'
        .format(len(train_dataset), len(val_dataset), len(test_dataset)))

In [None]:
from multiprocessing import Pool
#from itertools import repeat

def save_dataset(dataset, tvt, i, n_div):
    lbd = floor((i / n_div) * len(dataset))
    rbd = floor(((i + 1) / n_div) * len(dataset))
    ds = YooChooseDataset({})
    ds.samples = dataset.samples[lbd: rbd]
    torch.save(ds, '../data/processed/{}_{}.pt'
            .format(tvt, str(i).zfill(2)))

n_div = 90
for i in range(n_div):
    save_dataset(train_dataset, 'train_dataset', i, n_div)
n_div = 10
for i in range(n_div):
    save_dataset(val_dataset, 'val_dataset', i, n_div)
n_div = 1
for i in range(n_div):
    save_dataset(test_dataset, 'test_dataset', i, n_div)
    
"""
with Pool(30) as p:
    dataset = train_dataset
    tvt = 'train_dataset'
    n_div = 90
    p.map(save_dataset, zip(repeat(dataset), repeat(tvt), range(n_div), repeat(n_div)))
    
with Pool(30) as p:
    dataset = val_dataset
    tvt = 'val_dataset'
    n_div = 10
    p.map(save_dataset, zip(repeat(dataset), repeat(tvt), range(n_div), repeat(n_div)))

with Pool(30) as p:
    dataset = test_dataset
    tvt = 'test_dataset'
    n_div = 1
    p.map(save_dataset, zip(repeat(dataset), repeat(tvt), range(n_div), repeat(n_div)))
"""

In [5]:
from torch.utils.data import ConcatDataset

with open('../data/interim/n_items.json', 'r') as f:
    n_items = json.load(f)
train_dataset_list = []
for i in range(9):
    d = torch.load(
            '../data/processed/train_dataset_' + str(i).zfill(2) + '.pt')
    train_dataset_list.append(d)
train_dataset = ConcatDataset(train_dataset_list)
val_dataset_list = []
for i in range(1):
    d = torch.load(
            '../data/processed/val_dataset_' + str(i).zfill(2) + '.pt')
    val_dataset_list.append(d)
val_dataset = ConcatDataset(val_dataset_list)
test_dataset_list = []
for i in range(1):
    d = torch.load(
            '../data/processed/test_dataset_' + str(i).zfill(2) + '.pt')
    test_dataset_list.append(d)
test_dataset = ConcatDataset(test_dataset_list)
print('# train samples: {}\n# val samples: {}\n# test samples: {}'
        .format(len(train_dataset), len(val_dataset), len(test_dataset)))

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=24)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=24)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=24)

# train samples: 2112000
# val samples: 246752
# test samples: 55474


## Construct Model

In [6]:
from torch_geometric.nn import GatedGraphConv

class Attention(nn.Module):
    def __init__(self, embed_dim):
        super(Attention, self).__init__()
        self.wq = nn.Linear(embed_dim, embed_dim)
        self.wk = nn.Linear(embed_dim, embed_dim)
        self.sigmoid = nn.Sigmoid()
        self.q = nn.Linear(embed_dim, 1)
        self.w = nn.Linear(2 * embed_dim, embed_dim)
        
    def forward(self, x, batch):
        sections = list(torch.bincount(batch).to('cpu').numpy())
        x_split = torch.split(x, sections)
        q_split = [x_[- 1].view(1, - 1) for x_ in x_split]
        q = torch.cat([x_[- 1].view(1, - 1).repeat(x_.shape[0], 1) for x_ in x_split])
        q = self.wq(q)
        k = self.wk(x)
        a = self.q(q + k)
        ax = a * x
        ax_split = torch.split(ax, sections)
        sg_split = [torch.sum(ax_, 0).view(1, - 1) for ax_ in ax_split]
        sh_split = self.w(torch.cat((torch.cat(q_split), torch.cat(sg_split)), 1))
        return sh_split
    
class PredProb(nn.Module):
    def __init__(self):
        super(PredProb, self).__init__()
    
    def forward(self, sh, embedding):
        return torch.mm(sh, embedding.weight.transpose(1, 0))

class SRGNN(nn.Module):
    def __init__(self, n_items, embed_dim):
        super(SRGNN, self).__init__()
        self.embedding = nn.Embedding(n_items, embed_dim)
        self.gatedgconv = GatedGraphConv(embed_dim, 1)
        self.relu = nn.ReLU()
        self.attention = Attention(embed_dim)
        self.predprob = PredProb()
        
    def _initialize_weights(self, ):
        pass
    
    def forward(self, data):
        x, edge_index, edge_weights, batch =\
                data.x, data.edge_index, data.edge_weights, data.batch
        x = self.embedding(x).squeeze()
        x = self.gatedgconv(x, edge_index, edge_weights)
        x = self.relu(x)
        x = self.attention(x, batch)
        x = self.predprob(x, self.embedding)
        return x

In [9]:
model = SRGNN(n_items, 128)
for b in model.named_children():
    print(b)

('embedding', Embedding(37821, 128))
('gatedgconv', GatedGraphConv(128, num_layers=1))
('relu', ReLU())
('attention', Attention(
  (wq): Linear(in_features=128, out_features=128, bias=True)
  (wk): Linear(in_features=128, out_features=128, bias=True)
  (sigmoid): Sigmoid()
  (q): Linear(in_features=128, out_features=1, bias=True)
  (w): Linear(in_features=256, out_features=128, bias=True)
))
('predprob', PredProb())


## Train

In [19]:
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.handlers import Checkpoint, DiskSaver

def train_net(net, opt, loss_fn, val_metrics, train_loader, val_loader, device):
    net.to(device)
    def prepare_batch(batch, device, non_blocking=False):
        x, y = batch.values()
        return x.to(device), y.to(device)
    def output_transform(x, y, y_pred, loss):
        return (y_pred.max(1)[1], y)
    trainer = create_supervised_trainer(net, opt, loss_fn, device,
            prepare_batch=prepare_batch, output_transform=output_transform)
    evaluator = create_supervised_evaluator(net, val_metrics, device,
            prepare_batch=prepare_batch)
    s = '{}: {:.2f} '
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(trainer):
        evaluator.run(train_loader)
        print('Epoch {}'.format(trainer.state.epoch))
        message = 'Train - '
        for m in val_metrics.keys():
            message += s.format(m, evaluator.state.metrics[m])
        print(message)
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(trainer):
        evaluator.run(val_loader)
        message = 'Val   - '
        for m in val_metrics.keys():
            message += s.format(m, evaluator.state.metrics[m])
        print(message)
    
    #=====================================
    def score_function(engine):
        return engine.state.metrics['top-10 acc']
    to_save = {'model': model}
    handler = Checkpoint(to_save, DiskSaver('../models/tmp', create_dir=True), n_saved=2,
            filename_prefix='best', score_function=score_function, score_name="val_acc")
    evaluator.add_event_handler(Events.COMPLETED, handler)
    #===============================
    
    return trainer

In [None]:
opt = opt_(model.parameters(), lr)

trainer = train_net(model, opt, loss_fn, val_metrics,
        train_loader, val_loader, device)
trainer.run(train_loader, max_epochs=max_epochs)

Epoch 1
Train - top-10 acc: 0.06 loss: 10.30 
Val   - top-10 acc: 0.00 loss: 11.25 
Epoch 2
Train - top-10 acc: 0.12 loss: 9.79 
Val   - top-10 acc: 0.01 loss: 11.08 
Epoch 3
Train - top-10 acc: 0.17 loss: 9.41 
Val   - top-10 acc: 0.01 loss: 10.97 
Epoch 4
Train - top-10 acc: 0.21 loss: 9.10 
Val   - top-10 acc: 0.02 loss: 10.89 
Epoch 5
Train - top-10 acc: 0.24 loss: 8.85 
Val   - top-10 acc: 0.03 loss: 10.84 
Epoch 6
Train - top-10 acc: 0.26 loss: 8.63 
Val   - top-10 acc: 0.04 loss: 10.81 
Epoch 7
Train - top-10 acc: 0.28 loss: 8.44 
Val   - top-10 acc: 0.05 loss: 10.79 
Epoch 8
Train - top-10 acc: 0.29 loss: 8.27 
Val   - top-10 acc: 0.06 loss: 10.79 
Epoch 9
Train - top-10 acc: 0.31 loss: 8.11 
Val   - top-10 acc: 0.06 loss: 10.79 
Epoch 10
Train - top-10 acc: 0.32 loss: 7.96 
Val   - top-10 acc: 0.07 loss: 10.80 
Epoch 11
Train - top-10 acc: 0.33 loss: 7.82 
Val   - top-10 acc: 0.07 loss: 10.82 
Epoch 12
Train - top-10 acc: 0.34 loss: 7.68 
Val   - top-10 acc: 0.07 loss: 10.84 


## Test

In [None]:
a = 1

In [None]:
torch.save(train_dataset, '../data/processed/train_dataset.pt')
torch.save(val_dataset, '../data/processed/val_dataset.pt')
torch.save(test_dataset, '../data/processed/test_dataset.pt')

In [None]:
with open('../data/processed/train_dataset_00.pickle', 'rb') as f:
    asdfasdf = pickle.load(f)
asdfasdf

In [None]:
import pickle

#with open('../data/processed/train_dataset.pkl', 'wb') as f:
#    pickle.dump(train_dataset, f)
with open('../data/processed/val_dataset.pkl', 'wb') as f:
    pickle.dump(val_dataset, f)
#with open('../data/processed/test_dataset.pkl', 'wb') as f:
#    pickle.dump(test_dataset, f)

In [None]:
from sklearn.model_selection import train_test_split

train_d, val_d = train_test_split(train_d.items(), test_size=0.1, shuffle=True)
train_d, val_d = dict(train_d), dict(val_d)
train_dataset = YooChooseDataset(train_d)
val_dataset = YooChooseDataset(val_d)
test_dataset = YooChooseDataset(test_d)
print('# train samples: {}\n# val samples: {}\n# test samples: {}'
        .format(len(train_dataset), len(val_dataset), len(test_dataset)))

In [None]:
# 여기서부터 subset을 배제하고 가야됨
from torch.utils.data import random_split

train_dataset = YooChooseDataset(train_d)
train_dataset, val_dataset = random_split(train_dataset,
        [floor(0.9 * len(train_dataset)),
                len(train_dataset) - floor(0.9 * len(train_dataset))])
test_dataset = YooChooseDataset(test_d)
print('# train samples: {}\n# val samples: {}\n# test samples: {}'
        .format(len(train_dataset), len(val_dataset), len(test_dataset)))

In [None]:
class YooChooseDataset(Dataset):
    def __init__(self, d):
        super(YooChooseDataset, self).__init__()
        self.samples = self.add_from_dict(d)
        
    def add_from_dict(self, d):
        samples = []
        for dd in d.values():
            for data in dd:
                samples.append(data)
        return samples
                
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        x_ids, y = self.samples[idx]
        x_ids_ = set(x_ids)
        x_ids_.remove(x_ids[- 1])
        x_ids_ = list(x_ids_) + [x_ids[- 1]]
        x_dict = {x_id: i for i, x_id in enumerate(x_ids_)}
        x = [[x_id] for x_id in x_ids_]
        edge_dict = defaultdict(lambda: defaultdict(int)) # 얘네를 다 미리 저장해야 할 듯? f-b propagation step에 비해 오래 걸리는지 확인해보고 일단
        for i in range(len(x_ids) - 1):
            edge_dict[x_dict[x_ids[i]]][x_dict[x_ids[i + 1]]] += 1
        edge_index, edge_weights = [], []
        for o in edge_dict.keys():
            s = sum(edge_dict[o].values())
            for d in edge_dict[o].keys():
                edge_index.append([o, d])
                edge_weights.append(edge_dict[o][d] / s)
        x = torch.tensor(x, dtype=torch.long)
        edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
        edge_weights = torch.tensor(edge_weights)
        return {
                'graph': Data(x, edge_index=edge_index, edge_weights=edge_weights),
                'label': y
                }

In [None]:
"""
from sklearn.preprocessing import OneHotEncoder

remain_item = item_enc.transform(remain_item).tolist()
item_enc = OneHotEncoder(sparse=False)
item_enc.fit([[item] for item in remain_item])
with open('../data/interim/onehotencoder.pkl', 'wb') as f:
    pickle.dump(item_enc, f)
"""