In [31]:
import os
import sys
import json
import torch
import random
import pickle
import argparse
import numpy as np
from tqdm import tqdm
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.optim import Adam

In [32]:
def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Also exports ``PYTHONHASHSEED`` into the process environment and sets the
    cuDNN flags (enabled, non-benchmark, deterministic).

    Args:
        seed: Integer seed passed to every generator.
    """
    # Python-level sources of randomness.
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators: CPU, current CUDA device, then all CUDA devices.
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    # cuDNN stays enabled, with autotuning off and deterministic kernels on.
    cudnn = torch.backends.cudnn
    cudnn.enabled = True
    cudnn.benchmark = False
    cudnn.deterministic = True


seed_everything(2022)

In [75]:
# Run configuration shared by the data-loading, training and evaluation cells.
args = argparse.Namespace(
    data_dir="../data/ml-1m",  # directory holding sequential_data.txt / content.json
    epochs=20,                 # number of passes over the training loader
    device="cuda:0",           # device the model and batches are moved to
    lr=5e-4,                   # learning rate handed to the optimizer
    mask_prob=0.4,             # per-token masking probability used by the training dataset
)


In [76]:
def load_ndjson(input_file):
    """Load a newline-delimited JSON file into a list of decoded objects.

    Args:
        input_file: Path of the file to read; one JSON document per line.

    Returns:
        A list with one decoded object per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    records = []
    with open(input_file, "r", encoding="utf-8") as f:
        # Iterate the file instead of calling str.splitlines(): splitlines()
        # also breaks on characters such as U+2028 that may legally appear
        # unescaped inside a JSON string, which would cut a record in half.
        for line in f:
            line = line.strip()
            if not line:
                # Blank lines (e.g. trailing ones) carry no record.
                continue
            records.append(json.loads(line))
    return records


def load_seq_txt(input_file):
    """Read a text file of whitespace-separated integers, one sequence per line.

    Args:
        input_file: Path of the text file to read.

    Returns:
        A list of integer lists, one per non-empty line, in file order.

    Raises:
        ValueError: If a token on a line cannot be parsed as an integer.
    """
    sequences = []
    with open(input_file, "r") as f:
        for line in f:
            # split() with no argument tolerates repeated/trailing whitespace,
            # which split(" ") would turn into empty tokens and int("") errors.
            fields = line.split()
            if not fields:
                # Skip blank lines instead of failing on them.
                continue
            sequences.append([int(token) for token in fields])
    return sequences


def load_dataset(data_dir, mode=""):
    """Load the behavior sequences and content records of one dataset split.

    Args:
        data_dir: Root directory of the dataset.
        mode: One of ``""``, ``"train"``, ``"dev"`` or ``"test"``. A non-empty
            mode selects the sub-directory of that name; the empty string
            reads directly from ``data_dir``.

    Returns:
        A ``(behavior, content)`` tuple: the parsed ``sequential_data.txt``
        and the parsed ``content.json`` of the selected directory.

    Raises:
        ValueError: If ``mode`` is not one of the accepted values.
    """
    if mode not in ("", "train", "dev", "test"):
        raise ValueError("Incorrect mode. Must be `train`|`dev`|`test`.")

    # Only a named split lives in its own sub-directory.
    root = data_dir if mode == "" else os.path.join(data_dir, mode)

    behavior = load_seq_txt(os.path.join(root, "sequential_data.txt"))
    content = load_ndjson(os.path.join(root, "content.json"))
    return behavior, content


# Read the whole dataset straight from the configured root (no split sub-directory).
behavior, content = load_dataset(data_dir=args.data_dir, mode="")

In [None]:
class Seq_Dataset(Dataset):
    """Fixed-length item sequences for masked-item training and next-item testing.

    Item id ``0`` is used as left padding and ``max_item + 1`` as the mask token.

    * ``mode="train"``: the last interaction of each history is removed first
      (it is the target returned in test mode). The input window is then built
      from all but the last of the remaining items, and each real token is
      replaced by the mask token with probability ``mask_prob``. Labels are
      the original item at masked positions and ``0`` everywhere else.
    * ``mode="test"``: the window is the history without its last item,
      shifted left by one with the mask token appended. The target window is
      the history without its first item, so its last entry is the final item.

    Args:
        purchase_history: List of per-user item-id lists.
        max_item: Largest item id; the mask token is ``max_item + 1``.
        mask_prob: Probability of masking each real token in train mode.
        mode: ``"train"`` or ``"test"`` (the default ``""`` is rejected).
        max_len: Length every returned sequence is truncated/padded to.
    """

    def __init__(self, purchase_history, max_item, mask_prob, mode="", max_len=10) -> None:
        # Validate first so an invalid dataset is never partially initialised.
        assert mode in ["train", "test"], "mode must be `train` or `test`"
        super().__init__()
        self.max_len = max_len
        self.purchase_history = purchase_history
        self.mode = mode
        self.mask_id = max_item + 1
        self.mask_prob = mask_prob

    def __len__(self):
        """Number of user histories."""
        return len(self.purchase_history)

    def __getitem__(self, index):
        """Return ``(tokens, labels)`` in train mode or ``(seq, tgt)`` in test mode.

        Both elements are ``torch.LongTensor`` of length ``max_len``.
        """
        purchase_history = self.purchase_history[index]
        if self.mode == "train":
            # Keep the final interaction out of training.
            purchase_history = purchase_history[0:-1]

        seq = self.truncate_and_pad(purchase_history[0:-1])

        if self.mode == "train":
            tokens, labels = [], []
            for item in seq:
                # Padding is never masked; random.random() is only drawn for real tokens.
                if item != 0 and random.random() < self.mask_prob:
                    tokens.append(self.mask_id)
                    labels.append(item)
                else:
                    tokens.append(item)
                    labels.append(0)

            if not any(labels):
                # No position was masked: the sample would carry no supervision
                # (and a batch made only of such samples gives a NaN mean loss
                # with ignore_index=0). Mask the most recent real item instead.
                for pos in range(len(tokens) - 1, -1, -1):
                    if tokens[pos] != 0:
                        labels[pos] = tokens[pos]
                        tokens[pos] = self.mask_id
                        break
            return torch.LongTensor(tokens), torch.LongTensor(labels)

        tgt = self.truncate_and_pad(purchase_history[1:])
        # Drop the oldest slot and append the mask token at the last position.
        seq = seq[1:] + [self.mask_id]
        return torch.LongTensor(seq), torch.LongTensor(tgt)

    def truncate_and_pad(self, input_list):
        """Return a new list of exactly ``max_len`` items.

        Keeps the most recent ``max_len`` entries and left-pads with ``0``.
        The result never aliases ``input_list``.
        """
        start = max(len(input_list) - self.max_len, 0)
        recent = list(input_list[start:])
        return [0] * (self.max_len - len(recent)) + recent


# Largest item id seen so far; updated as a side effect of process_dataset().
max_item = 0


def process_dataset(behavior, content):
    """Filter raw behavior rows and strip their leading element.

    Rows that are empty or have fewer than five entries are discarded. The
    first entry of each kept row is dropped (presumably the user id — confirm
    against the data format) and the rest is kept as the item sequence.

    Side effects: raises the module-level ``max_item`` to the largest item id
    found, and prints the number of kept sequences.

    Args:
        behavior: Iterable of integer lists as read from the sequence file.
        content: Not used by this function.

    Returns:
        The list of kept item sequences.
    """
    global max_item

    purchase_history = [row[1:] for row in behavior if row and len(row) >= 5]

    if purchase_history:
        # Every kept sequence has at least four items, so max() is well defined.
        max_item = max(max_item, max(map(max, purchase_history)))

    print(len(purchase_history))
    return purchase_history


# Filter the raw rows once and share the result between both datasets.
# process_dataset() also updates the global `max_item`, so it must run before
# `max_item` is read below; calling it on its own line makes that ordering
# explicit instead of relying on argument evaluation order.
purchase_history = process_dataset(behavior, content)

train_dataset = Seq_Dataset(purchase_history, max_item, args.mask_prob, mode="train")
test_dataset = Seq_Dataset(purchase_history, max_item, args.mask_prob, mode="test")

In [78]:
# Training batches are reshuffled every epoch so the optimizer does not see the
# users in the same fixed order each time; evaluation order does not affect the
# aggregated metrics, so the test loader keeps the dataset order.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True, num_workers=0)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False, num_workers=0)

In [79]:
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import xavier_normal_, xavier_uniform_, uniform_, constant_
import math

class PositionalEmbedding(nn.Module):
    """Learned positional embedding table with one vector per position."""

    def __init__(self, max_len, d_model):
        """Create a ``(max_len, d_model)`` embedding table."""
        super().__init__()
        self.pe = nn.Embedding(max_len, d_model)

    def forward(self, x):
        """Return the whole table repeated for every batch element.

        Only ``x.size(0)`` is read from ``x``; the output always has shape
        ``(x.size(0), max_len, d_model)``.
        """
        table = self.pe.weight[None, :, :]
        return table.repeat(x.size(0), 1, 1)


class Attention(nn.Module):
    """Scaled dot-product attention."""

    def forward(self, query, key, value, mask=None, dropout=None):
        """Attend ``query`` over ``key``/``value``.

        Args:
            query, key, value: Tensors whose last dimension is the feature size.
            mask: Optional tensor broadcastable to the score matrix; positions
                where it equals 0 are filled with ``-1e9`` before the softmax.
            dropout: Optional callable applied to the attention weights.

        Returns:
            ``(weighted values, attention weights)``.
        """
        scale = math.sqrt(query.size(-1))
        scores = torch.matmul(query, key.transpose(-2, -1)) / scale

        if mask is not None:
            blocked = mask == 0
            scores = scores.masked_fill(blocked, -1e9)

        weights = F.softmax(scores, dim=-1)
        if dropout is not None:
            weights = dropout(weights)

        attended = torch.matmul(weights, value)
        return attended, weights


class MultiHeadAttention(nn.Module):
    """Multi-head attention built on top of ``Attention``."""

    def __init__(self, h, d_model, dropout=0.1):
        """
        Args:
            h: Number of attention heads; must divide ``d_model``.
            d_model: Model (input and output) feature size.
            dropout: Dropout probability applied to the attention weights.
        """
        super().__init__()
        assert d_model % h == 0

        self.d_k = d_model // h
        self.h = h

        # Three input projections (query, key, value) followed by the output one.
        self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)])
        self.output_linear = nn.Linear(d_model, d_model)
        self.attention = Attention()

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        """Project the inputs, attend per head, and merge the heads again."""
        batch_size = query.size(0)

        def split_heads(projection, tensor):
            # (batch, seq, d_model) -> (batch, heads, seq, d_k)
            projected = projection(tensor)
            return projected.view(batch_size, -1, self.h, self.d_k).transpose(1, 2)

        q_proj, k_proj, v_proj = self.linear_layers
        query = split_heads(q_proj, query)
        key = split_heads(k_proj, key)
        value = split_heads(v_proj, value)

        context, _ = self.attention(query, key, value, mask=mask, dropout=self.dropout)

        # (batch, heads, seq, d_k) -> (batch, seq, heads * d_k)
        merged = context.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)
        return self.output_linear(merged)


class SublayerConnection(nn.Module):
    """Residual connection around a sublayer, followed by layer normalisation."""

    def __init__(self, size, enable_res_parameter, dropout=0.1):
        """
        Args:
            size: Feature size handed to ``nn.LayerNorm``.
            enable_res_parameter: When truthy, the sublayer output is scaled by
                a learnable scalar ``a`` (initialised to ``1e-8``) before being
                added to the input.
            dropout: Dropout probability applied to the sublayer branch.
        """
        super().__init__()
        self.norm = nn.LayerNorm(size)
        self.dropout = nn.Dropout(dropout)
        self.enable = enable_res_parameter
        if self.enable:
            self.a = nn.Parameter(torch.tensor(1e-8))

    def forward(self, x, sublayer):
        """Return ``norm(x + dropout(branch))`` where ``branch`` is the (optionally scaled) sublayer output."""
        branch = sublayer(x)
        if self.enable:
            branch = self.a * branch
        return self.norm(x + self.dropout(branch))


class PointWiseFeedForward(nn.Module):
    """Two-layer feed-forward network with GELU, applied on the last dimension."""

    def __init__(self, d_model, d_ffn, dropout=0.1):
        """
        Args:
            d_model: Input and output feature size.
            d_ffn: Hidden feature size.
            dropout: Dropout probability applied to the output.
        """
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Expand to ``d_ffn``, apply GELU, project back to ``d_model``, then dropout."""
        hidden = self.activation(self.linear1(x))
        projected = self.linear2(hidden)
        return self.dropout(projected)


class TransformerBlock(nn.Module):
    """Self-attention followed by a feed-forward network, each in a residual wrapper."""

    def __init__(self, d_model, attn_heads, d_ffn, enable_res_parameter, dropout=0.1):
        """
        Args:
            d_model: Feature size of the block.
            attn_heads: Number of attention heads.
            d_ffn: Hidden size of the feed-forward network.
            enable_res_parameter: Forwarded to both ``SublayerConnection`` wrappers.
            dropout: Dropout probability used by every sub-module.
        """
        super().__init__()
        self.attn = MultiHeadAttention(attn_heads, d_model, dropout)
        self.ffn = PointWiseFeedForward(d_model, d_ffn, dropout)
        self.skipconnect1 = SublayerConnection(d_model, enable_res_parameter, dropout)
        self.skipconnect2 = SublayerConnection(d_model, enable_res_parameter, dropout)

    def forward(self, x, mask):
        """Apply masked self-attention and then the feed-forward network."""

        def self_attention(hidden):
            # Query, key and value are all the same hidden states.
            return self.attn.forward(hidden, hidden, hidden, mask=mask)

        attended = self.skipconnect1(x, self_attention)
        return self.skipconnect2(attended, self.ffn)


class BERT(nn.Module):
    """Bidirectional transformer encoder over item-id sequences.

    The token embedding has ``args.num_item + 2`` rows and the output layer
    produces ``args.num_item + 1`` logits (one fewer than the embedding rows).
    Presumably the two extra ids are the padding id 0 and the mask token —
    confirm against the dataset that feeds this model.
    """

    def __init__(self, args):
        """
        Args:
            args: Namespace providing ``num_item``, ``d_model``, ``attn_heads``,
                ``d_ffn``, ``bert_layers``, ``dropout``, ``device``,
                ``embedding_sharing`` and ``max_len``.
        """
        super().__init__()
        self.num_item = args.num_item + 2
        self.device = args.device
        self.embedding_sharing = args.embedding_sharing
        self.max_len = args.max_len

        hidden_size = args.d_model
        use_res_parameter = 1

        # Lower-triangular mask placed on `device` at construction time.
        # NOTE(review): forward() never reads this attribute.
        self.attention_mask = torch.ones(
            (self.max_len, self.max_len), dtype=torch.bool
        ).tril().to(self.device)

        self.token = nn.Embedding(self.num_item, hidden_size)
        self.position = PositionalEmbedding(self.max_len, hidden_size)

        blocks = [
            TransformerBlock(hidden_size, args.attn_heads, args.d_ffn, use_res_parameter, args.dropout)
            for _ in range(args.bert_layers)
        ]
        self.TRMs = nn.ModuleList(blocks)

        self.output = nn.Linear(hidden_size, self.num_item - 1)

        # Re-initialise every Embedding / Linear created above.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Uniform init for embeddings; Xavier-normal weights and 0.1 biases for linears."""
        if isinstance(module, nn.Linear):
            xavier_normal_(module.weight.data)
            if module.bias is not None:
                constant_(module.bias.data, 0.1)
            return
        if isinstance(module, nn.Embedding):
            bound = np.sqrt(1. / self.num_item)
            uniform_(module.weight.data, -bound, bound)

    def forward(self, x, label=None):
        """Encode ``x`` and score items.

        Args:
            x: Integer tensor of item ids, indexed as (batch, sequence); ids
                equal to 0 are excluded as attention keys. The sequence length
                has to be compatible with the ``max_len`` positional table.
            label: Must be ``None``; anything else trips the assertion.

        Returns:
            ``(all_logits, last_logits)``: logits for every position and
            logits for the final position only.
        """
        assert label is None

        seq_len = x.size(1)
        # (batch, 1, seq, seq): every query row shares the same key-validity mask.
        key_is_real = x > 0
        mask = key_is_real.unsqueeze(1).repeat(1, seq_len, 1).unsqueeze(1)

        x = self.token(x) + self.position(x)
        for block in self.TRMs:
            x = block(x, mask)

        last_hidden = x[:, -1, :]
        if self.embedding_sharing:
            # Score against the token embedding table, leaving out its last row.
            shared_weight = self.token.weight[:-1]
            out = F.linear(last_hidden, shared_weight)
            return F.linear(x, self.token.weight[:-1]), out

        out = self.output(last_hidden)
        return self.output(x), out

In [80]:
def train(model, train_dataloader, test_dataloader, opt, loss_func, max_item):
    """Run the training loop for ``args.epochs`` epochs.

    The epoch count and device are read from the module-level ``args``.

    Args:
        model: Module returning ``(per-position logits, last-position logits)``.
        train_dataloader: Iterable of ``(seq, tgt)`` batches.
        test_dataloader: Not used by this function.
        opt: Optimizer stepping the model parameters.
        loss_func: Callable taking flattened logits and flattened targets.
        max_item: Not used by this function.
    """
    device = args.device

    for epoch in range(args.epochs):
        model.train()
        running_loss = 0.0
        progress = tqdm(train_dataloader, ncols=100)

        for step, (seq, tgt) in enumerate(progress, start=1):
            seq, tgt = seq.to(device), tgt.to(device)

            logits, _ = model(seq)
            # Flatten (batch, seq, classes) against (batch, seq) for the loss.
            loss = loss_func(logits.view(-1, logits.shape[-1]), tgt.view(-1))

            opt.zero_grad()
            loss.backward()
            opt.step()

            running_loss += loss.cpu().item()
            # Show the mean loss over the batches seen so far in this epoch.
            progress.set_postfix({"loss": running_loss / step})

In [81]:
def _dcg_score(y_true, order, k=10):
    y_true = np.take(y_true, order[:k])
    gains = 2 ** y_true - 1
    discounts = np.log2(np.arange(len(y_true)) + 2)
    return np.sum(gains / discounts)

def _ndcg_score(y_true, order, k=10):
    """DCG@k divided by a fixed ideal DCG of 1.

    An ideal DCG of 1 corresponds to exactly one relevant item with gain 1
    ranked first.
    """
    ideal = 1.
    return _dcg_score(y_true, order, k) / ideal


class MetricScores(object):
    """Accumulate Recall@{1,5,10,20} and nDCG@{5,10,20} across evaluation batches.

    Call the instance once per batch with the true item ids and the ranked
    predictions, then call ``output()`` to print and return the averages.
    """

    def __init__(self) -> None:
        # Per-row nDCG values, appended on every call.
        self.ndcg5s = []
        self.ndcg10s = []
        self.ndcg20s = []
        # Hit counters and row counters, one pair per recall cutoff.
        self.recall1_true = 0
        self.recall1_total = 0
        self.recall5_true = 0
        self.recall5_total = 0
        self.recall10_true = 0
        self.recall10_total = 0
        self.recall20_true = 0
        self.recall20_total = 0
        self.k = [1, 5, 10, 20]

    def __call__(self, label_ids: torch.Tensor, pred_ids: torch.Tensor):
        """Update the running statistics with one batch.

        Args:
            label_ids: Tensor with one true item id per row.
            pred_ids: 2-D tensor of ranked item ids, best first, with one row
                per label. It needs at least 20 columns for nDCG@20.
        """
        assert len(label_ids) == len(pred_ids)

        # Converted once; both are the same for every cutoff.
        label = label_ids.detach().cpu().numpy().reshape(-1, 1)
        ranked = pred_ids.detach().cpu().numpy()

        for k in self.k:
            # 1 where the ranked item at that position is the true item.
            binary_labels = np.where(ranked[:, 0:k] == label, 1, 0)
            hits = np.sum(binary_labels)
            num_rows = binary_labels.shape[0]

            if k == 1:
                self.recall1_true += hits
                self.recall1_total += num_rows
            elif k == 5:
                self.recall5_true += hits
                self.recall5_total += num_rows
            elif k == 10:
                self.recall10_true += hits
                self.recall10_total += num_rows
                # nDCG@5 and nDCG@10 are both read from the top-10 hit rows.
                order = list(range(10))
                for y_true in binary_labels:
                    self.ndcg5s.append(_ndcg_score(y_true, order, 5))
                    self.ndcg10s.append(_ndcg_score(y_true, order, 10))
            elif k == 20:
                self.recall20_true += hits
                self.recall20_total += num_rows
                order = list(range(20))
                for y_true in binary_labels:
                    self.ndcg20s.append(_ndcg_score(y_true, order, 20))

    def output(self):
        """Print the aggregated metrics and return them as ``{"scores": {...}}``.

        Raises:
            ZeroDivisionError: If no batch has been accumulated yet.
        """
        ndcg5, ndcg10, ndcg20 = np.mean(self.ndcg5s), np.mean(self.ndcg10s), np.mean(self.ndcg20s)
        # Each recall uses its own row counter (Recall@1 previously borrowed
        # the Recall@5 counter).
        recall1 = self.recall1_true / self.recall1_total
        recall5 = self.recall5_true / self.recall5_total
        recall10 = self.recall10_true / self.recall10_total
        recall20 = self.recall20_true / self.recall20_total
        print(
            "Recall@1: {:.4f}\nRecall@5: {:.4f}\nRecall@10: {:.4f}\nRecall@20: {:.4f}\nnDCG@5: {:.4f}\nnDCG@10: {:.4f}\nnDCG@20: {:.4f}\n".format(
                recall1, recall5, recall10, recall20, ndcg5, ndcg10, ndcg20
            )
        )
        res = {}
        res["scores"] = {
            "Recall@1": recall1,
            "Recall@5": recall5,
            "Recall@10": recall10,
            "Recall@20": recall20,
            "nDCG@5": ndcg5,
            "nDCG@10": ndcg10,
            "nDCG@20": ndcg20,
        }
        return res

In [82]:
def evaluate(model, test_dataloader, max_item, device):
    """Score the model on the test loader and report ranking metrics.

    Args:
        model: Module returning ``(per-position logits, last-position logits)``.
        test_dataloader: Iterable of ``(seq, tgt)`` batches; the last column
            of ``tgt`` is taken as the item to predict.
        max_item: Not used by this function.
        device: Device the batches are moved to.

    Returns:
        The dictionary produced by ``MetricScores.output()``.
    """
    scorer = MetricScores()
    model.eval()

    with torch.no_grad():
        for seq, tgt in tqdm(test_dataloader):
            _, last_logits = model(seq.to(device))
            # Rank all candidate ids from highest to lowest score.
            ranking = torch.argsort(last_logits, dim=-1, descending=True)
            target = tgt.to(device)[:, -1]
            scorer(target, ranking)

    return scorer.output()

In [None]:
# Model hyper-parameters. The device and the sequence length are derived from
# the run configuration and the dataset rather than repeated as literals, so
# the model cannot silently disagree with the data pipeline.
bert_args = {
    "device": args.device,
    "num_item": max_item,
    "max_len": train_dataset.max_len,
    "bert_layers": 8,
    "d_model": 512,
    "d_ffn": 512,
    "dropout": 0.1,
    "attn_heads": 8,
    "embedding_sharing": False,
}

model = BERT(argparse.Namespace(**bert_args))

model = model.to(args.device)
# Label 0 marks positions that are not prediction targets, so it is ignored by the loss.
CE = torch.nn.CrossEntropyLoss(ignore_index=0)
opt = torch.optim.Adam(model.parameters(), lr=args.lr)
train(model, train_dataloader, test_dataloader, opt, CE, max_item+1)

print(args)
print(bert_args)
res = evaluate(model, test_dataloader, max_item, args.device)