In [None]:
import os
import sys
import json
import torch
import random
import pickle
import argparse
import numpy as np
from tqdm import tqdm
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.optim import Adam


In [None]:
def seed_everything(seed):
    """Seed every random number generator this notebook touches.

    Args:
        seed: integer used for Python's ``random``, ``PYTHONHASHSEED``,
            NumPy, and the PyTorch CPU/CUDA generators.

    Side effects:
        Also switches cuDNN to deterministic mode with benchmarking off
        (cuDNN itself stays enabled).
    """
    # Interpreter-level and NumPy randomness.
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators: CPU, current CUDA device, then all CUDA devices.
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    # Favour reproducible cuDNN kernels over auto-tuned ones.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.enabled = True


seed_everything(2022)


In [None]:
# Run configuration for the notebook.
#
# The preferred accelerator is GPU index 2. Hard-coding "cuda:2" crashes on
# machines with fewer than three GPUs (or none), so fall back to the first GPU
# and finally to the CPU when the preferred device does not exist.
_preferred_device = "cuda:2"
if not torch.cuda.is_available():
    _device = "cpu"
elif int(_preferred_device.split(":")[1]) < torch.cuda.device_count():
    _device = _preferred_device
else:
    _device = "cuda:0"

args = {
    # Alternative datasets (swap the active "data_dir" line to use one):
    # "data_dir": "../data/ml-1m",
    # "data_dir": "../data/Goodreads/poetry",
    "data_dir": "../data/MINDsmall",
    "epochs": 40,
    "device": _device,
    # "lr": 5e-3,
}

print(args)
# Expose the settings as attributes (args.data_dir, args.epochs, args.device).
args = argparse.Namespace(**args)


In [None]:
def load_ndjson(input_file):
    """Load a newline-delimited JSON file into a list of decoded records.

    Args:
        input_file: path to a UTF-8 text file with one JSON document per line.

    Returns:
        list: one decoded object per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.

    Notes:
        Lines are taken from file iteration rather than ``str.splitlines()``.
        ``splitlines`` also breaks on characters such as U+2028/U+2029, which
        may legally appear unescaped inside a JSON string and would cut a
        record in half. Whitespace-only lines are skipped instead of raising.
    """
    records = []
    with open(input_file, "r", encoding="utf-8") as f:
        for raw_line in f:
            if raw_line.strip():
                records.append(json.loads(raw_line))
    return records


def load_seq_txt(input_file):
    """Read a text file of whitespace-separated integers, one sequence per line.

    Args:
        input_file: path to the text file.

    Returns:
        list[list[int]]: one list per line, in file order. A blank line yields
        an empty list so that row indices keep matching line numbers.

    Raises:
        ValueError: if a token is not a valid integer.

    Notes:
        Tokens are split on any run of whitespace. Splitting on a single space
        produced empty tokens for trailing/double spaces and blank lines, and
        ``int("")`` then raised.
    """
    sequences = []
    with open(input_file, "r") as f:
        for line in f:
            sequences.append([int(token) for token in line.split()])
    return sequences


def load_dataset(data_dir, mode=""):
    """Load the behaviour sequences and item content of one dataset split.

    Args:
        data_dir: dataset root directory.
        mode: ``""`` reads the files directly under ``data_dir``; ``"train"``,
            ``"dev"`` or ``"test"`` reads them from that sub-directory.

    Returns:
        tuple: ``(behavior, content)`` as loaded from ``sequential_data.txt``
        and ``content.json`` respectively.

    Raises:
        ValueError: if ``mode`` is not one of the accepted values.
    """
    if mode not in ("", "train", "dev", "test"):
        raise ValueError("Incorrect mode. Must be `train`|`dev`|`test`.")

    # An empty mode means "no split sub-directory".
    split_dir = os.path.join(data_dir, mode) if mode else data_dir

    behavior = load_seq_txt(os.path.join(split_dir, "sequential_data.txt"))
    content = load_ndjson(os.path.join(split_dir, "content.json"))
    return behavior, content


# Load the unsplit dataset: with mode="" the files are read directly from
# `args.data_dir` rather than from a train/dev/test sub-directory.
behavior, content = load_dataset(args.data_dir, mode="")


In [None]:
class KnowledgeDataset(Dataset):
    r"""
    A Dataset Class for training the knowledge prompt

    Each sample is one content record (a dict) turned into token ids by
    ``template`` and right-padded to ``encoder_max_length``.

    Args:
        content: list of per-item metadata dicts (entries may be ``None``).
        tokenizer: callable returning a mapping with an ``"input_ids"`` list;
            must expose ``pad_token_id``.
        encoder_max_length: fixed length of every returned sequence.
        data_type: dataset family; selects which metadata fields are read.
        task: ``"simple"`` encodes the title only. For ``"simple"`` and
            ``"random"`` the content list is kept as-is; for any other value
            records without a usable long-text field are dropped first.
    """

    # data_type -> (main_element, long_element, other_element)
    _FIELDS = {
        "ml": ("title", "desc", ["genres"]),
        "mind": ("title", "abstract", ["category", "subcategory"]),
        **dict.fromkeys(
            ["vg", "mt", "et"],
            ("title", "description", ["category", "brand", "price", "feature"]),
        ),
        **dict.fromkeys(
            ["children", "comics_graphic", "poetry"],
            ("title", "description", ["genre", "authors", "average_rating", "publication_year"]),
        ),
    }

    def __init__(
        self, content, tokenizer, encoder_max_length=512, data_type="", task="",
    ) -> None:
        super().__init__()
        assert data_type in self._FIELDS
        self.tokenizer = tokenizer
        self.max_length = encoder_max_length
        self.data_type = data_type
        # The long-text field name comes from the shared table. The previous
        # `data_type == [...]` comparison was never true, which left
        # `metadata_name` undefined for six of the eight data types.
        self.metadata_name = self._FIELDS[data_type][1]
        self.task = task
        if task == "random" or task == "simple":
            self.content = content
        else:
            self.content = self.clean_content(content)

    def clean_content(self, content: list[dict]) -> list:
        r"""
        For some content without long element data, delete them from dataset.

        A record is kept only when its ``metadata_name`` field is present,
        truthy, and not a float NaN. ``None`` records are dropped.
        """
        content_new = []
        for record in content:
            if record is None:
                continue
            intro = record.get(self.metadata_name)
            if not intro:
                continue
            # NaN is truthy, so it needs its own check.
            if isinstance(intro, float) and np.isnan(intro):
                continue
            content_new.append(record)
        return content_new

    def __len__(self):
        return len(self.content)

    def __getitem__(self, index):
        r"""
        Return the padded sample at ``index``.

        Returns:
            dict with ``input_ids`` (LongTensor), ``attn_masks`` (FloatTensor,
            1 for real tokens and 0 for padding) and ``loss_ids`` (LongTensor,
            1 on positions ``loss_begin <= i < loss_end``), each of length
            ``encoder_max_length``.
        """
        input_ids, loss_begin, loss_end = self.template(self.content[index])
        attn_masks = [1] * len(input_ids)
        input_ids, attn_masks = self.padding(input_ids, attn_masks)
        loss_ids = [0] * len(attn_masks)
        for i in range(loss_begin, loss_end):
            loss_ids[i] = 1
        return {
            "input_ids": torch.LongTensor(input_ids),
            "attn_masks": torch.FloatTensor(attn_masks),
            "loss_ids": torch.LongTensor(loss_ids),
        }

    def template(self, content: dict) -> tuple[list[int], int, int]:
        r"""
        Tokenize one content record.

        Returns:
            ``(input_ids, loss_begin, loss_end)`` where ``input_ids`` is
            truncated to ``encoder_max_length`` and the loss span covers every
            token except the last one.

        Raises:
            Exception: if ``data_type`` is unknown.
            NotImplementedError: if ``task`` is not ``"simple"``. Previously
                this returned ``None`` and failed later while unpacking.
        """
        fields = self._FIELDS.get(self.data_type)
        if fields is None:
            raise Exception("Illegal data type.")
        main_element = fields[0]

        if self.task != "simple":
            raise NotImplementedError(
                f"No template is implemented for task {self.task!r}; only 'simple' is supported."
            )

        input_ids = self.tokenizer(content[main_element])["input_ids"]
        if len(input_ids) > self.max_length:
            input_ids = input_ids[0 : self.max_length]
        return input_ids, 0, len(input_ids) - 1

    def padding(self, input_ids: list, attn_masks: list):
        r"""
        Padding the inputs for GPT model.

        For training, we pad the right side,
        using the tokenizer's pad token for ids and 0 for the attention mask.
        """
        assert len(input_ids) <= self.max_length
        input_ids = input_ids + [self.tokenizer.pad_token_id] * (
            self.max_length - len(input_ids)
        )
        attn_masks = attn_masks + [0] * (self.max_length - len(attn_masks))
        return input_ids, attn_masks
    

def collate_fn_k(batch):
    """Stack per-sample tensors from ``KnowledgeDataset`` into batch tensors.

    Args:
        batch: list of dicts with ``input_ids``, ``attn_masks`` and
            ``loss_ids`` tensors of equal shape.

    Returns:
        dict with ``input_ids``, ``attention_mask`` and ``loss_ids``, each
        stacked along a new leading batch dimension. The sample key
        ``attn_masks`` is renamed to ``attention_mask``.
    """
    # (output key, per-sample key)
    key_pairs = (
        ("input_ids", "input_ids"),
        ("attention_mask", "attn_masks"),
        ("loss_ids", "loss_ids"),
    )
    return {
        out_key: torch.stack([sample[in_key] for sample in batch])
        for out_key, in_key in key_pairs
    }


In [None]:
class GRU_Dataset(Dataset):
    """Next-item prediction samples built from per-user item histories.

    Each sample is a pair ``(seq, tgt)`` of length-10 LongTensors where
    ``tgt`` is ``seq`` shifted one step forward. Id 0 is used for left padding.

    Args:
        purchase_history: list of item-id lists, one per user.
        mode: ``"train"`` or ``"test"``. In train mode the final item of each
            history is removed before the pair is built.
    """

    def __init__(self, purchase_history, mode="") -> None:
        self.max_len = 10
        self.purchase_history = purchase_history
        self.mode = mode
        assert mode in ["train", "test"]
        super().__init__()

    def __len__(self):
        return len(self.purchase_history)

    def __getitem__(self, index):
        history = self.purchase_history[index]
        if self.mode == "train":
            # Drop the last interaction in train mode.
            history = history[:-1]

        # Inputs are all but the last item; targets are the same items shifted by one.
        inputs = self.truncate_and_pad(history[:-1])
        targets = self.truncate_and_pad(history[1:])
        return torch.LongTensor(inputs), torch.LongTensor(targets)

    def truncate_and_pad(self, input_list):
        """Keep the most recent ``max_len`` ids, left-padding with 0 if shorter."""
        excess = len(input_list) - self.max_len
        if excess > 0:
            return input_list[excess:]
        if excess < 0:
            return [0] * (-excess) + input_list
        return input_list


# Largest item id seen so far; updated by process_dataset.
max_item = 0
# Every item id seen so far, stored as dict keys (values are always None).
item_info_dict = {}
def process_dataset(behavior, content):
    """Filter raw behaviour rows into per-user item histories.

    Rows that are empty or have fewer than 5 entries are skipped. The first
    entry of each kept row is dropped (presumably the user id -- verify against
    the data format) and the remainder becomes that user's history.

    Args:
        behavior: list of integer lists, one row per user.
        content: unused.

    Returns:
        list[list[int]]: the kept histories, in input order.

    Side effects:
        Raises the module-level ``max_item`` to the largest item id seen and
        records every seen item id as a key of ``item_info_dict``.
    """
    global max_item
    histories = []
    for row in behavior:
        if not row or len(row) < 5:
            continue
        items = row[1:]
        histories.append(items)
        max_item = max(max_item, *items)
        for item_id in items:
            item_info_dict[item_id] = None
    return histories


# Filter the behaviour rows once and share the result between both datasets.
# GRU_Dataset only reads the histories, so a second full pass over `behavior`
# that produces an identical list is unnecessary.
purchase_history = process_dataset(behavior, content)
train_dataset = GRU_Dataset(purchase_history, mode="train")
test_dataset = GRU_Dataset(purchase_history, mode="test")

In [None]:
# Training batches are reshuffled every epoch.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True, num_workers=0)
# Evaluation metrics are sums/means over all users, so batch order does not
# matter; keep the test loader unshuffled so every evaluation pass visits users
# in the same order.
test_dataloader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False, num_workers=0)

In [None]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel, BertTokenizer, BertModel

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Fixed number of token ids stored per item.
TITLE_TOKEN_LEN = 24

# NOTE(review): data_type is "ml" although args.data_dir points at MINDsmall.
# With task="simple" only the "title" field is read, and every data_type uses
# "title" as its main element -- confirm this is intentional.
k_dataset = KnowledgeDataset(
    content,
    tokenizer,
    data_type="ml",
    task="simple",
    encoder_max_length=TITLE_TOKEN_LEN,
)

# Lookup tables indexed by item id; rows of ids never seen in `behavior`
# (including id 0) stay all-zero.
_table_shape = (max_item + 1, TITLE_TOKEN_LEN)
input_ids = torch.zeros(_table_shape, dtype=torch.long)
attn_masks = torch.zeros(_table_shape, dtype=torch.long)

for item_id in tqdm(item_info_dict):
    item_info = k_dataset[item_id]
    input_ids[item_id] = item_info["input_ids"]
    attn_masks[item_id] = item_info["attn_masks"]

In [None]:
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import xavier_normal_, xavier_uniform_

class GRU4rec(nn.Module):
    """Two-branch GRU recommender combining item-id and text embeddings.

    One GRU consumes learned item-id embeddings. A second GRU consumes BERT
    pooled outputs of each item's token ids, projected from 768 to ``d_model``.
    The two GRU outputs are summed, projected back to ``d_model`` and scored
    against every item.

    Args:
        args: namespace providing ``num_item``, ``d_model``, ``gru_layers``,
            ``dropout``, ``embedding_sharing`` and ``loss_sample``.

    Notes:
        ``forward`` reads the module-level lookup tables ``input_ids`` and
        ``attn_masks`` (item id -> token ids / attention mask); they must be
        defined before the model is called.
    """

    def __init__(self, args):
        super(GRU4rec, self).__init__()
        self.num_item = args.num_item + 1
        self.embedding_size = args.d_model
        self.layers = args.gru_layers
        self.dropout = args.dropout
        self.embedding_sharing = args.embedding_sharing
        self.sample = args.loss_sample

        self.item_embedding = nn.Embedding(self.num_item, self.embedding_size, padding_idx=0)
        self.emb_dropout = nn.Dropout(self.dropout)
        # `gru` consumes the text (BERT) branch, `gru2` the item-id branch.
        self.gru = nn.GRU(
            input_size=self.embedding_size,
            hidden_size=2 * self.embedding_size,
            num_layers=self.layers,
            bias=False,
            batch_first=True
        )
        self.gru2 = nn.GRU(
            input_size=self.embedding_size,
            hidden_size=2 * self.embedding_size,
            num_layers=self.layers,
            bias=False,
            batch_first=True
        )
        self.projection = nn.Linear(2 * self.embedding_size, self.embedding_size)
        self.output = nn.Linear(self.embedding_size, self.num_item)
        self.emb = nn.Linear(768,self.embedding_size)

        # Initialise our own layers first; BERT is attached afterwards so that
        # `apply` never overwrites its pretrained weights.
        self.apply(self._init_weights)
        self.model_k = BertModel.from_pretrained("bert-base-uncased")

    def _init_weights(self, module):
        """Xavier-initialise embeddings and the weights of every GRU layer."""
        if isinstance(module, nn.Embedding):
            xavier_normal_(module.weight)
        elif isinstance(module, nn.GRU):
            # Cover all stacked layers; initialising only `*_l0` left layers
            # beyond the first at PyTorch's default initialisation.
            for layer in range(module.num_layers):
                xavier_uniform_(getattr(module, f"weight_hh_l{layer}"))
                xavier_uniform_(getattr(module, f"weight_ih_l{layer}"))

    def forward(self, item_seq, label=None):
        """Score all items at every position of ``item_seq``.

        Args:
            item_seq: LongTensor of item ids, shape ``(batch, seq_len)``.
            label: unused.

        Returns:
            tuple: ``(last_step_scores, seq_output)`` where ``seq_output`` has
            shape ``(batch, seq_len, num_item)`` and ``last_step_scores`` is
            ``seq_output[:, -1, :]``.
        """
        # Follow the input's device rather than a notebook-level global.
        device = item_seq.device
        item_seq_emb = self.emb_dropout(self.item_embedding(item_seq))

        # Index each lookup table on its own device: indexing a CPU table with
        # an index tensor on another device is not supported.
        token_ids = input_ids[item_seq.to(input_ids.device)].to(device)
        token_mask = attn_masks[item_seq.to(attn_masks.device)].to(device)

        # Encode all (batch * seq_len) titles in one BERT call instead of one
        # call per batch row, then restore the (batch, seq_len, hidden) layout.
        batch_size, seq_len, num_tokens = token_ids.shape
        outputs = self.model_k(
            input_ids=token_ids.reshape(batch_size * seq_len, num_tokens),
            attention_mask=token_mask.reshape(batch_size * seq_len, num_tokens),
        )
        item_embedding = outputs.pooler_output.reshape(batch_size, seq_len, -1)

        item_content_emb = self.emb(item_embedding)

        gru_out1, _ = self.gru(item_content_emb)
        gru_out2, _ = self.gru2(item_seq_emb)

        gru_out = self.projection(gru_out1 + gru_out2)
        if not self.embedding_sharing:
            seq_output = self.output(gru_out)
        else:
            # Tie the output layer to the item embedding matrix.
            seq_output = F.linear(gru_out, self.item_embedding.weight)
        return seq_output[:, -1, :], seq_output


# Hyper-parameters of the GRU recommender. `num_item` is the largest item id
# found by process_dataset; GRU4rec adds 1 to it internally.
gru_args = dict(
    num_item=max_item,
    d_model=64,
    gru_layers=2,
    dropout=0.1,
    embedding_sharing=False,
    loss_sample=None,
    output_seq=False,
)
model = GRU4rec(argparse.Namespace(**gru_args))


In [None]:
def _dcg_score(y_true, order, k=10):
    """Discounted cumulative gain of ``y_true`` read in ``order``, cut at ``k``.

    Args:
        y_true: relevance values.
        order: indices into ``y_true`` giving the ranking, best first.
        k: number of ranked positions to keep.

    Returns:
        Sum of ``(2**rel - 1) / log2(rank + 1)`` over the kept positions.
    """
    ranked = np.take(y_true, order[:k])
    gains = 2 ** ranked - 1
    discounts = np.log2(np.arange(len(ranked)) + 2)
    return np.sum(gains / discounts)

def _ndcg_score(y_true, order, k=10):
    """DCG divided by a fixed ideal DCG of 1.

    NOTE(review): a constant ideal of 1 matches the case of exactly one
    relevant item with gain 1; confirm no caller passes graded or multiple
    relevances.
    """
    actual = _dcg_score(y_true, order, k)
    return actual / 1.


class MetricScores(object):
    """Accumulates Recall@{1,5,10,20} and nDCG@{5,10,20} over batches.

    Call the instance once per batch with the true item ids and the ranked
    predictions, then call ``output()`` to print and return the averages.
    """

    def __init__(self) -> None:
        self.ndcg5s = []
        self.ndcg10s = []
        self.ndcg20s = []
        self.recall1_true = 0
        # Recall@1 keeps its own denominator instead of borrowing recall5_total.
        self.recall1_total = 0
        self.recall5_true = 0
        self.recall5_total = 0
        self.recall10_true = 0
        self.recall10_total = 0
        self.recall20_true = 0
        self.recall20_total = 0
        self.k = [1, 5, 10, 20]

    def __call__(self, label_ids: torch.Tensor, pred_ids: torch.Tensor):
        """Add one batch.

        Args:
            label_ids: true item id per row, shape ``(batch,)``.
            pred_ids: predicted item ids per row ordered best first, shape
                ``(batch, n)``. Only the first ``max(self.k)`` columns are used.
        """
        assert len(label_ids) == len(pred_ids)

        label = label_ids.detach().cpu().numpy().reshape(-1, 1)
        # Slice before leaving the device so only the needed columns are copied.
        top_preds = pred_ids[:, 0:max(self.k)].detach().cpu().numpy()

        for k in self.k:
            # 1 where the prediction at that rank equals the row's label.
            binary_labels = np.where(top_preds[:, 0:k] == label, 1, 0)
            hits = np.sum(binary_labels)
            rows = binary_labels.shape[0]
            # Rank order over the columns actually present, so inputs with
            # fewer than k predictions are scored instead of raising.
            order = np.arange(binary_labels.shape[1])

            if k == 10:
                self.recall10_true += hits
                self.recall10_total += rows
                for y_true in binary_labels:
                    self.ndcg5s.append(_ndcg_score(y_true, order, 5))
                    self.ndcg10s.append(_ndcg_score(y_true, order, 10))
            elif k == 5:
                self.recall5_true += hits
                self.recall5_total += rows
            elif k == 1:
                self.recall1_true += hits
                self.recall1_total += rows
            elif k == 20:
                self.recall20_true += hits
                self.recall20_total += rows
                for y_true in binary_labels:
                    self.ndcg20s.append(_ndcg_score(y_true, order, 20))
        return

    @staticmethod
    def _ratio(hits, total):
        """hits / total, or NaN when nothing has been accumulated."""
        return hits / total if total else float("nan")

    @staticmethod
    def _mean(values):
        """Mean of a list, or NaN for an empty list."""
        return np.mean(values) if values else float("nan")

    def output(self):
        """Print all metrics and return them as ``{"scores": {...}}``.

        Metrics with no accumulated data are reported as NaN rather than
        raising ZeroDivisionError.
        """
        ndcg5 = self._mean(self.ndcg5s)
        ndcg10 = self._mean(self.ndcg10s)
        ndcg20 = self._mean(self.ndcg20s)
        recall1 = self._ratio(self.recall1_true, self.recall1_total)
        recall5 = self._ratio(self.recall5_true, self.recall5_total)
        recall10 = self._ratio(self.recall10_true, self.recall10_total)
        recall20 = self._ratio(self.recall20_true, self.recall20_total)
        print(
            "Recall@1: {:.4f}\nRecall@5: {:.4f}\nRecall@10: {:.4f}\nRecall@20: {:.4f}\nnDCG@5: {:.4f}\nnDCG@10: {:.4f}\nnDCG@20: {:.4f}\n".format(
                recall1, recall5, recall10, recall20, ndcg5, ndcg10, ndcg20
            )
        )
        res = {}
        res["scores"] = {
            "Recall@1": recall1,
            "Recall@5": recall5,
            "Recall@10": recall10,
            "Recall@20": recall20,
            "nDCG@5": ndcg5,
            "nDCG@10": ndcg10,
            "nDCG@20": ndcg20,
        }
        return res


In [None]:
def evaluate(model, test_dataloader, max_item, device):
    """Rank all items for every test sequence and report ranking metrics.

    Args:
        model: callable returning ``(last_step_scores, seq_output)`` for a
            batch of sequences.
        test_dataloader: yields ``(seq, tgt)`` LongTensor batches.
        max_item: unused.
        device: device the batches are moved to before the forward pass.

    Returns:
        MetricScores: the accumulator, after its metrics have been printed.
    """
    res = MetricScores()
    model.eval()
    with torch.no_grad():
        for seq, tgt in tqdm(test_dataloader, ncols=120):
            seq, tgt = seq.to(device), tgt.to(device)

            last_step_scores, _ = model(seq)
            # Item ids ordered from highest to lowest score.
            ranking = torch.argsort(last_step_scores, dim=-1, descending=True)

            # The label is the final position of the target sequence.
            res(tgt[:, -1], ranking)

    res.output()
    return res


In [None]:
def train(model, train_dataloader, test_dataloader, opt, loss_func, max_item, accum_steps=4):
    """Train with gradient accumulation and evaluate after every epoch.

    Args:
        model: returns ``(last_step_scores, seq_output)`` for a batch.
        train_dataloader: yields ``(seq, tgt)`` LongTensor batches.
        test_dataloader: passed to ``evaluate`` after each epoch.
        opt: optimizer over ``model``'s parameters.
        loss_func: loss taking ``(logits, targets)``.
        max_item: number of output classes (size of the last logits dimension).
        accum_steps: number of batches whose gradients are accumulated before
            each optimizer step.

    Raises:
        ValueError: if ``accum_steps`` is smaller than 1.

    Notes:
        Reads ``args.epochs`` and ``args.device`` from the notebook-level
        configuration.
    """
    if accum_steps < 1:
        raise ValueError("accum_steps must be >= 1")

    for epoch in range(args.epochs):
        model.train()
        total_loss = 0.0
        pending = 0  # batches accumulated since the last optimizer step
        train_iter = tqdm(train_dataloader, ncols=100)

        for idx, (seq, tgt) in enumerate(train_iter):
            seq = seq.to(args.device)
            tgt = tgt.to(args.device)

            _, out = model(seq)

            # Flatten (batch, seq_len, classes) logits against (batch * seq_len) targets.
            loss = loss_func(out.view(-1, max_item), tgt.view(-1))
            # Scale so the accumulated gradient is the mean over the group.
            (loss / accum_steps).backward()
            pending += 1

            # Track every batch so the displayed value is the true running mean
            # (previously only every 4th batch was counted, times 4).
            total_loss += loss.item()

            if pending == accum_steps:
                opt.step()
                opt.zero_grad()
                pending = 0

            train_iter.set_postfix({"loss": total_loss / (idx + 1)})

        # Apply gradients left over when the batch count is not a multiple of
        # accum_steps; otherwise they would leak into the next epoch's first step.
        if pending:
            opt.step()
            opt.zero_grad()

        res = evaluate(model, test_dataloader, max_item, args.device)


model = model.to(args.device)
# Target id 0 is padding and is excluded from the loss.
CE = torch.nn.CrossEntropyLoss(ignore_index=0)

# Parameters whose name contains one of these substrings belong to the
# pretrained language model and get the smaller learning rate.
plm = ["model_k"]

# Split the parameters in a single pass, printing the non-PLM names.
_head_params, _plm_params = [], []
for n, p in model.named_parameters():
    if any(nd in n for nd in plm):
        _plm_params.append(p)
    else:
        print(n)
        _head_params.append(p)

optimizer_grouped_parameters = [
    {"params": _head_params, "lr": 1e-3},
    {"params": _plm_params, "lr": 1e-4},
]
opt = torch.optim.Adam(optimizer_grouped_parameters)
# The model emits max_item + 1 logits (ids 0..max_item), hence the +1.
train(model, train_dataloader, test_dataloader, opt, CE, max_item+1)
