In [None]:
import os
import sys
import json
import torch
import random
import pickle
import argparse
import numpy as np
from tqdm import tqdm
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.optim import Adam

In [2]:
def seed_everything(seed):
    """Seed every RNG the notebook touches and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch (CPU and
    all CUDA devices), exports ``PYTHONHASHSEED`` and turns off cuDNN autotuning
    so repeated runs pick the same kernels.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed,
                   torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    cudnn.enabled = True

seed_everything(2022)

In [38]:
# Run-level configuration shared by the data, training and evaluation cells.
args = argparse.Namespace(
    data_dir="../data/ml-1m",
    epochs=40,
    # Fall back to CPU so Restart-&-Run-All still works on a machine without CUDA.
    device="cuda:0" if torch.cuda.is_available() else "cpu",
    lr=1e-3,
)


In [None]:
def load_ndjson(input_file):
    """Parse a newline-delimited JSON file.

    Returns a list with one decoded JSON value per line, in file order
    (a ``null`` line becomes ``None``).
    """
    with open(input_file, "r") as handle:
        raw_text = handle.read()
    return list(map(json.loads, raw_text.splitlines()))


def load_seq_txt(input_file):
    """Read one whitespace-separated integer sequence per line.

    Returns a list of ``list[int]`` in file order. Runs of spaces/tabs,
    trailing whitespace and ``\\r\\n`` endings are tolerated; a blank line
    yields ``[]`` instead of raising ``ValueError`` on ``int("")``.
    """
    with open(input_file, "r") as f:
        return [[int(token) for token in line.split()] for line in f]


def load_dataset(data_dir, mode=""):
    """Load the behaviour sequences and item content for one split.

    ``mode`` selects a sub-directory (``train``/``dev``/``test``) of
    ``data_dir``; the empty string reads directly from ``data_dir``.
    Returns ``(behavior, content)`` from ``sequential_data.txt`` and
    ``content.json`` respectively. Raises ``ValueError`` for any other mode.
    """
    if mode not in ("", "train", "dev", "test"):
        raise ValueError("Incorrect mode. Must be `train`|`dev`|`test`.")

    split_dir = os.path.join(data_dir, mode) if mode else data_dir
    behavior = load_seq_txt(os.path.join(split_dir, "sequential_data.txt"))
    content = load_ndjson(os.path.join(split_dir, "content.json"))
    return behavior, content


behavior, content = load_dataset(args.data_dir, mode="")


from transformers import GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Pre-tokenize every text field of every item once. The list stays aligned
# with `content` (None entries are kept) because FDSA later indexes it by item id.
content_tokens = []
for item in tqdm(content):
    if item is None:
        content_tokens.append(None)
        continue
    item_tokens = {}
    for field, value in item.items():
        if field == "id":
            continue
        try:
            item_tokens[field] = tokenizer.encode(value, return_tensors="pt")
        except Exception:
            # Values the tokenizer rejects fall back to a single token id 0,
            # shaped (1, 1) and long like `encode(..., return_tensors="pt")`.
            # Catching Exception (not bare except) keeps Ctrl-C working.
            item_tokens[field] = torch.zeros((1, 1), dtype=torch.long)
    content_tokens.append(item_tokens)

In [40]:
class Seq_Dataset(Dataset):
    """Next-item prediction samples built from per-user item histories.

    Each sample is ``(seq, tgt)``: ``tgt`` is ``seq`` shifted one step ahead,
    both truncated to the most recent ``max_len`` items and left-padded with 0.
    In ``"train"`` mode the final item of every history is dropped first, so in
    ``"test"`` mode that held-out item becomes the last entry of ``tgt``.

    Args:
        purchase_history: list of item-id lists, one per user.
        mode: ``"train"`` or ``"test"``.
        max_len: sequence length after truncation/padding (default 10).
    """

    def __init__(self, purchase_history, mode="", max_len=10) -> None:
        self.max_len = max_len
        self.purchase_history = purchase_history
        self.mode = mode
        assert mode in ["train", "test"]
        super().__init__()

    def __len__(self):
        return len(self.purchase_history)

    def __getitem__(self, index):
        """Return ``(seq, tgt)`` as LongTensors of length ``max_len``."""
        purchase_history = self.purchase_history[index]
        if self.mode == "train":
            # Hold out the last item; it is only used as a test-mode target.
            purchase_history = purchase_history[:-1]

        seq = self.truncate_and_pad(purchase_history[:-1])
        tgt = self.truncate_and_pad(purchase_history[1:])
        return torch.LongTensor(seq), torch.LongTensor(tgt)

    def truncate_and_pad(self, input_list):
        """Keep the last ``max_len`` entries, left-padding with 0 when shorter."""
        if self.max_len <= 0:
            return []
        kept = list(input_list[-self.max_len:])
        return [0] * (self.max_len - len(kept)) + kept


max_item = 0
def process_dataset(behavior, content):
    """Turn raw behaviour rows into item histories and track the largest item id.

    Rows that are empty/None or shorter than 5 entries are skipped. The first
    entry of each kept row is dropped (presumably the user id — TODO confirm
    against the data format). Updates the module-global ``max_item``.
    ``content`` is accepted but not used.
    """
    global max_item
    histories = []
    for row in behavior:
        if row and len(row) >= 5:
            items = row[1:]
            max_item = max([max_item, *items])
            histories.append(items)
    return histories


# Both splits come from the same filtered histories (Seq_Dataset only slices,
# never mutates them), so process the raw behaviour data once.
purchase_history = process_dataset(behavior, content)
train_dataset = Seq_Dataset(purchase_history, mode="train")
test_dataset = Seq_Dataset(purchase_history, mode="test")

3952


In [41]:
train_dataloader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True, num_workers=0)
# Evaluation order does not affect the aggregated metrics; keep it fixed so
# evaluation does not consume RNG state and runs are easier to compare.
test_dataloader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False, num_workers=0)

In [78]:
import torch
from torch import nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer

def gather_indexes(output, gather_index):
    """Pick one position along dim 1 of a 3-D ``output`` for every batch row.

    ``gather_index`` holds either one index per row or a single index that is
    broadcast to all rows. Returns a tensor of shape
    ``(output.shape[0], output.shape[-1])``.
    """
    n_rows, width = output.shape[0], output.shape[-1]
    index = gather_index.view(-1, 1, 1).expand(n_rows, -1, width)
    picked = torch.gather(output, 1, index)
    return picked.squeeze(1)

class VanillaAttention(nn.Module):
    """Additive attention pooling over the second-to-last axis.

    A small MLP scores each vector along dim -2; the scores are softmax-normalized
    and used to take a weighted sum. ``forward`` returns ``(pooled, weights)``.
    """

    def __init__(self, hidden_dim, attn_dim):
        super().__init__()
        scorer_layers = [
            nn.Linear(hidden_dim, attn_dim),
            nn.ReLU(True),
            nn.Linear(attn_dim, 1),
        ]
        self.projection = nn.Sequential(*scorer_layers)

    def forward(self, input_tensor):
        logits = self.projection(input_tensor).squeeze(-1)
        weights = logits.softmax(dim=-1)
        pooled = torch.sum(input_tensor * weights.unsqueeze(-1), dim=-2)
        return pooled, weights

class FDSA(nn.Module):
    """Feature-level Deeper Self-Attention (FDSA) sequential recommender.

    Two causal Transformer encoders run side by side: one over item-id (+position)
    embeddings, one over per-item text-feature embeddings. Each selected feature's
    token sequence is embedded and pooled, and the features are combined per item
    by ``VanillaAttention``. The two encoder outputs are concatenated and projected
    back to ``hidden_size``.

    Args:
        args: namespace with ``n_layers``, ``n_heads``, ``hidden_size``,
            ``inner_size``, ``dropout``, ``pooling_mode`` (``"mean"``, ``"max"``
            or ``"min"``), ``device``, ``selected_features`` and ``num_item``.
            Optional: ``vocab_size`` (defaults to ``len(tokenizer)``) and
            ``initializer_range`` (std of the normal weight init, default 1.0).
        content: list indexed by item id; each entry is ``None`` or a dict that
            maps a feature name to a tensor of token ids.

    Raises:
        ValueError: if ``pooling_mode`` is not supported.
    """

    _POOLING_MODES = ("mean", "max", "min")

    def __init__(self, args, content):
        super(FDSA, self).__init__()

        self.n_layers = args.n_layers
        self.n_heads = args.n_heads
        self.hidden_size = args.hidden_size
        self.inner_size = args.inner_size
        self.hidden_dropout_prob = args.dropout
        self.attn_dropout_prob = args.dropout

        self.pooling_mode = args.pooling_mode
        if self.pooling_mode not in self._POOLING_MODES:
            raise ValueError(
                f"Unsupported pooling_mode {self.pooling_mode!r}; "
                f"expected one of {self._POOLING_MODES}."
            )
        self.device = args.device

        self.content_tokens = content
        self.selected_features = args.selected_features
        self.n_items = args.num_item
        self.max_seq_length = 10
        # Std of the normal init in _init_weights; 1.0 keeps the previous behaviour.
        self.initializer_range = getattr(args, "initializer_range", 1.0)
        vocab_size = getattr(args, "vocab_size", None)
        if vocab_size is None:
            # Falls back to the module-level GPT-2 tokenizer that produced `content`.
            vocab_size = len(tokenizer)

        self.item_embedding = nn.Embedding(
            self.n_items, self.hidden_size, padding_idx=0
        )
        self.position_embedding = nn.Embedding(self.max_seq_length, self.hidden_size)
        # Token embedding shared by all text features. It is registered once so it
        # is trained, saved and moved with the model (previously a new random layer
        # was built on the CPU in every forward pass).
        # NOTE(review): id 0 is used as padding here but is presumably also a real
        # tokenizer id — confirm that mapping it to a zero vector is acceptable.
        self.feature_embed_layer = nn.Embedding(
            vocab_size, self.hidden_size, padding_idx=0
        )

        self.layer_norm_eps = 1e-5
        # NOTE(review): this single LayerNorm is reused by both encoders and by the
        # embedding/output stages, so they share its parameters — confirm intended.
        self.LayerNorm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
        self.dropout = nn.Dropout(self.hidden_dropout_prob)
        self.concat_layer = nn.Linear(self.hidden_size * 2, self.hidden_size)

        trm_layer = TransformerEncoderLayer(
            self.hidden_size,
            self.n_heads,
            self.inner_size,
            self.hidden_dropout_prob,
            batch_first=True,
        )
        self.item_trm_encoder = TransformerEncoder(
            trm_layer,
            self.n_layers,
            norm=self.LayerNorm,
        )

        self.feature_att_layer = VanillaAttention(self.hidden_size, self.hidden_size)

        self.feature_trm_encoder = TransformerEncoder(
            trm_layer,
            self.n_layers,
            norm=self.LayerNorm,
        )

        self.loss_fct = nn.CrossEntropyLoss()

        self.apply(self._init_weights)
        self.other_parameter_name = ["feature_embed_layer"]

    def _init_weights(self, module):
        """Normal(0, initializer_range) weights for Linear/Embedding, zero biases,
        identity LayerNorm; embedding padding rows are kept at zero."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.initializer_range)
            if isinstance(module, nn.Embedding) and module.padding_idx is not None:
                # normal_ just overwrote the padding row; restore it to zero.
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def get_attention_mask(self, item_seq, bidirectional=False):
        """Build the additive float mask passed to the Transformer encoders.

        Keys holding item id 0 (padding) get -10000, and so do future positions
        unless ``bidirectional`` is set. Returns shape
        ``(batch * n_heads, seq_len, seq_len)`` laid out batch-major
        (row ``b * n_heads + h``), the order nn.MultiheadAttention uses for 3-D
        masks. ``torch.cat`` of copies would lay it out head-major and hand each
        head another sequence's padding mask.
        """
        seq_len = item_seq.size(-1)
        key_is_real = (item_seq != 0).unsqueeze(1).unsqueeze(2)  # (B, 1, 1, L)
        allowed = key_is_real.expand(-1, -1, seq_len, -1)          # (B, 1, L, L)
        if not bidirectional:
            allowed = torch.tril(allowed)
        additive = torch.where(allowed, 0.0, -10000.0).squeeze(dim=1)  # (B, L, L)
        return additive.repeat_interleave(self.n_heads, dim=0)

    def _pad_feature_tokens(self, item_ids, feature, device):
        """Right-pad each item's token ids for ``feature`` into an (N, T) long tensor.

        Also returns the true token counts (N,). Items with no content entry, a
        ``None`` entry or no such feature get length 0, so every item keeps its slot.
        """
        rows = []
        for item_id in item_ids:
            entry = (
                self.content_tokens[item_id]
                if 0 <= item_id < len(self.content_tokens)
                else None
            )
            tokens = entry.get(feature) if entry is not None else None
            if tokens is None:
                rows.append(torch.zeros(0, dtype=torch.long))
            else:
                rows.append(tokens.reshape(-1).long())
        lengths = torch.tensor([row.numel() for row in rows], dtype=torch.long)
        width = max(int(lengths.max()) if rows else 0, 1)
        padded = torch.zeros(len(rows), width, dtype=torch.long)
        for idx, row in enumerate(rows):
            padded[idx, : row.numel()] = row
        return padded.to(device), lengths.to(device)

    def _pool(self, embeds, lengths):
        """Pool (N, T, H) token embeddings over each row's first ``lengths[n]`` tokens.

        Padding never affects the result. Rows with no tokens pool to zeros.
        """
        positions = torch.arange(embeds.size(1), device=embeds.device)
        valid = (positions.unsqueeze(0) < lengths.unsqueeze(1)).unsqueeze(-1)
        if self.pooling_mode == "mean":
            summed = embeds.masked_fill(~valid, 0.0).sum(dim=1)
            return summed / lengths.clamp(min=1).unsqueeze(-1).to(embeds.dtype)
        if self.pooling_mode == "max":
            pooled = embeds.masked_fill(~valid, float("-inf")).max(dim=1).values
        elif self.pooling_mode == "min":
            pooled = embeds.masked_fill(~valid, float("inf")).min(dim=1).values
        else:
            raise ValueError(f"Unsupported pooling_mode {self.pooling_mode!r}.")
        return pooled.masked_fill((lengths == 0).unsqueeze(-1), 0.0)

    def embed_features(self, item_seq, selected_features):
        """Embed and pool each selected text feature for every item in ``item_seq``.

        Returns ``(batch, seq_len, len(selected_features), hidden_size)`` on
        ``item_seq``'s device. Pooling uses each token sequence's true length, so
        an item's embedding no longer depends on the longest sequence in the batch.
        """
        batch_size, seq_len = item_seq.shape
        item_ids = item_seq.reshape(-1).tolist()
        feature_embeddings = []
        for feature in selected_features:
            tokens, lengths = self._pad_feature_tokens(item_ids, feature, item_seq.device)
            pooled = self._pool(self.feature_embed_layer(tokens), lengths)
            feature_embeddings.append(pooled.view(batch_size, seq_len, 1, -1))
        return torch.cat(feature_embeddings, dim=-2)

    def forward(self, item_seq, item_seq_len):
        """Encode ``item_seq`` (batch, seq_len) of item ids to (batch, seq_len, hidden_size).

        ``item_seq_len`` is accepted for interface compatibility and is unused.
        """
        item_emb = self.item_embedding(item_seq)

        position_ids = torch.arange(
            item_seq.size(1), dtype=torch.long, device=item_seq.device
        )
        position_ids = position_ids.unsqueeze(0).expand_as(item_seq)
        position_embedding = self.position_embedding(position_ids)

        item_trm_input = self.dropout(self.LayerNorm(item_emb + position_embedding))

        feature_table = self.embed_features(item_seq, self.selected_features)
        feature_emb, _ = self.feature_att_layer(feature_table)
        feature_trm_input = self.dropout(self.LayerNorm(feature_emb + position_embedding))

        extended_attention_mask = self.get_attention_mask(item_seq)
        item_output = self.item_trm_encoder(item_trm_input, extended_attention_mask)
        feature_output = self.feature_trm_encoder(feature_trm_input, extended_attention_mask)

        output = self.concat_layer(torch.cat((item_output, feature_output), -1))
        return self.dropout(self.LayerNorm(output))

    def predict(self, item_seq):
        """Score every item for the position after ``item_seq``; returns (batch, n_items).

        The previous version multiplied by an embedding snapshot taken before
        weight init (never moved to the model device), and its shapes only
        broadcast when seq_len == n_items.
        """
        _, last_scores = self.full_sort_predict(item_seq)
        return last_scores

    def full_sort_predict(self, item_seq):
        """Score all items at every position.

        Returns ``(scores, last_scores)`` with shapes
        ``(batch, seq_len, n_items)`` and ``(batch, n_items)``; the latter is the
        final position's row of the former.
        """
        seq_output = self.forward(item_seq, item_seq.size(1))
        scores = torch.matmul(seq_output, self.item_embedding.weight.transpose(0, 1))
        last_scores = scores[:, -1, :]
        return scores, last_scores

In [80]:
def _dcg_score(y_true, order, k=10):
    y_true = np.take(y_true, order[:k])
    gains = 2 ** y_true - 1
    discounts = np.log2(np.arange(len(y_true)) + 2)
    return np.sum(gains / discounts)

def _ndcg_score(y_true, order, k=10):
    """nDCG@k normalized by a fixed ideal DCG of 1.

    This is exact when ``y_true`` contains a single relevant item with
    relevance 1, which is how MetricScores builds its binary hit rows.
    """
    ideal_dcg = 1.
    return _dcg_score(y_true, order, k) / ideal_dcg


class MetricScores(object):
    """Accumulate Recall@{1,5,10,20} and nDCG@{5,10,20} over evaluation batches.

    Each call receives ``label_ids`` (one ground-truth item per row) and
    ``pred_ids`` (item ids ranked best-first per row). Each row has a single
    relevant item, so Recall@k is the top-k hit rate and the ideal DCG is 1,
    which makes nDCG@k equal to DCG@k.
    """

    def __init__(self) -> None:
        self.ndcg5s = []
        self.ndcg10s = []
        self.ndcg20s = []
        self.recall1_true = 0
        self.recall1_total = 0
        self.recall5_true = 0
        self.recall5_total = 0
        self.recall10_true = 0
        self.recall10_total = 0
        self.recall20_true = 0
        self.recall20_total = 0
        self.k = [1, 5, 10, 20]

    @staticmethod
    def _ndcg_per_row(binary_labels, k):
        """Per-row DCG@k (== nDCG@k here) of a 0/1 hit matrix ranked left to right.

        Vectorized equivalent of ``_dcg_score`` with gain ``2**rel - 1`` and
        discount ``log2(rank + 2)``. It also tolerates fewer than ``k`` columns.
        """
        top = binary_labels[:, :k]
        discounts = np.log2(np.arange(top.shape[1]) + 2)
        return list(((2 ** top - 1) / discounts).sum(axis=1))

    @staticmethod
    def _ratio(hits, total):
        """``hits / total``, or NaN when nothing was accumulated (like ``np.mean([])``)."""
        return hits / total if total else float("nan")

    def __call__(self, label_ids: torch.Tensor, pred_ids: torch.Tensor):
        """Add one batch of ``(batch,)`` labels and ``(batch, n_ranked)`` rankings."""
        assert len(label_ids) == len(pred_ids)
        label = label_ids.detach().cpu().numpy().reshape(-1, 1)
        ranked = pred_ids.detach().cpu().numpy()

        for k in self.k:
            # 1 where the ground-truth item sits at that rank within the top k.
            binary_labels = np.where(ranked[:, 0:k] == label, 1, 0)
            hits = int(np.sum(binary_labels))
            rows = binary_labels.shape[0]

            if k == 1:
                self.recall1_true += hits
                self.recall1_total += rows
            elif k == 5:
                self.recall5_true += hits
                self.recall5_total += rows
            elif k == 10:
                self.recall10_true += hits
                self.recall10_total += rows
                self.ndcg5s.extend(self._ndcg_per_row(binary_labels, 5))
                self.ndcg10s.extend(self._ndcg_per_row(binary_labels, 10))
            elif k == 20:
                self.recall20_true += hits
                self.recall20_total += rows
                self.ndcg20s.extend(self._ndcg_per_row(binary_labels, 20))
        return

    def output(self):
        """Print the accumulated metrics and return them as ``{"scores": {...}}``.

        Metrics with no accumulated rows are reported as NaN.
        """
        ndcg5, ndcg10, ndcg20 = (
            float(np.mean(values)) if values else float("nan")
            for values in (self.ndcg5s, self.ndcg10s, self.ndcg20s)
        )
        recall1 = self._ratio(self.recall1_true, self.recall1_total)
        recall5 = self._ratio(self.recall5_true, self.recall5_total)
        recall10 = self._ratio(self.recall10_true, self.recall10_total)
        recall20 = self._ratio(self.recall20_true, self.recall20_total)
        print(
            "Recall@1: {:.4f}\nRecall@5: {:.4f}\nRecall@10: {:.4f}\nRecall@20: {:.4f}\nnDCG@5: {:.4f}\nnDCG@10: {:.4f}\nnDCG@20: {:.4f}\n".format(
                recall1, recall5, recall10, recall20, ndcg5, ndcg10, ndcg20
            )
        )
        res = {}
        res["scores"] = {
            "Recall@1": recall1,
            "Recall@5": recall5,
            "Recall@10": recall10,
            "Recall@20": recall20,
            "nDCG@5": ndcg5,
            "nDCG@10": ndcg10,
            "nDCG@20": ndcg20,
        }
        return res

In [81]:
def evaluate(model, test_dataloader, max_item, device):
    """Rank all items for each test sequence and report Recall/nDCG.

    The ranking is taken from the last-position scores of
    ``model.full_sort_predict`` and compared against the final target of each
    sequence. ``max_item`` is accepted but unused. Returns the dict produced by
    ``MetricScores.output()``.
    """
    metrics = MetricScores()
    model.eval()
    with torch.no_grad():
        for seq, tgt in tqdm(test_dataloader):
            _, last_scores = model.full_sort_predict(seq.to(device))
            ranking = last_scores.argsort(dim=-1, descending=True)
            metrics(tgt.to(device)[:, -1], ranking)
    return metrics.output()

In [None]:
def train(model, train_dataloader, test_dataloader, opt, loss_func, max_item,
          epochs=None, device=None):
    """Train with next-item cross-entropy at every position; evaluate each epoch.

    Args:
        model: exposes ``full_sort_predict(seq) -> (scores, last_scores)``.
        train_dataloader / test_dataloader: yield ``(seq, tgt)`` LongTensor batches.
        opt: optimizer over ``model``'s parameters.
        loss_func: applied to ``(batch * seq_len, max_item)`` logits and flat targets.
        max_item: size of the score dimension (number of item ids incl. padding).
        epochs: number of epochs; defaults to the notebook-level ``args.epochs``.
        device: device for batches; defaults to the notebook-level ``args.device``.

    Returns:
        The metrics dict from the final epoch's ``evaluate`` call (``None`` if
        ``epochs`` is 0).
    """
    epochs = args.epochs if epochs is None else epochs
    device = args.device if device is None else device

    last_result = None
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        train_iter = tqdm(train_dataloader, ncols=100, desc=f"epoch {epoch + 1}/{epochs}")

        for step, (seq, tgt) in enumerate(train_iter, start=1):
            seq = seq.to(device)
            tgt = tgt.to(device)

            scores, _ = model.full_sort_predict(seq)
            # reshape (not view) so non-contiguous score tensors also work.
            loss = loss_func(scores.reshape(-1, max_item), tgt.reshape(-1))

            opt.zero_grad()
            loss.backward()
            opt.step()
            total_loss += loss.item()

            train_iter.set_postfix({"loss": total_loss / step})

        last_result = evaluate(model, test_dataloader, max_item, device)
    return last_result

In [None]:
# FDSA hyper-parameters. Item ids run 0..max_item, with 0 reserved for padding
# (padding_idx=0 in the model, ignore_index=0 in the loss).
fdsa_args = dict(
    num_item=max_item + 1,
    n_layers=2,
    n_heads=4,
    hidden_size=512,
    inner_size=2048,
    dropout=0.1,
    pooling_mode="mean",
    device=args.device,
    selected_features=["genres", "title", "desc"],
)

model = FDSA(argparse.Namespace(**fdsa_args), content_tokens).to(args.device)
CE = torch.nn.CrossEntropyLoss(ignore_index=0)
opt = torch.optim.Adam(model.parameters(), lr=args.lr)
train(model, train_dataloader, test_dataloader, opt, CE, max_item + 1)

print(args)
print(fdsa_args)
res = evaluate(model, test_dataloader, max_item, args.device)