In [102]:
import os 
import ujson

root_dir="/home/jupyter/unity_jointly_rec_and_search/datasets/unified_user/sequential_train_test/hlen_4_bm25/"

train_sim_path = os.path.join(root_dir, "sim_rec_sequential.train.json")
train_compl_path = os.path.join(root_dir, "compl_rec_sequential.train.json")
test_sim_path = os.path.join(root_dir, "sim_rec_sequential.test.json")
test_compl_path = os.path.join(root_dir, "compl_rec_sequential.test.json")


In [67]:
ls "/home/jupyter/unity_jointly_rec_and_search/datasets/unified_user/sequential_train_test/hlen_4_bm25/"

compl_rec_sequential.test.json   search_sequential.train.json
compl_rec_sequential.train.json  sim_rec_sequential.test.json
[0m[01;34mrec_search[0m/                      sim_rec_sequential.train.json
search_sequential.test.json


In [110]:
import torch
import torch.nn as nn
import argparse
import numpy as np
import faiss
from tqdm import tqdm
import sys
sys.path.append("/home/jupyter/unity_jointly_rec_and_search/kgc-dr/")

NUM_USER=893619
NUM_ITEM=2260878

class AverageMeter(object):
    """Computes and stores the average and current value.
    Examples::
        >>> # Initialize a meter to record loss
        >>> losses = AverageMeter()
        >>> # Update meter after every minibatch update
        >>> losses.update(loss_value, batch_size)
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

class SASProductDataset(torch.utils.data.Dataset):
    def __init__(self, path):
        self.pids = []
        with open(path) as fin:
            for line in fin:
                pid, text = line.strip().split("\t")
                self.pids.append(int(pid))
    def __len__(self):
        return len(self.pids)
    
    def __getitem__(self, idx): 
        return self.pids[idx]
    
    def collate_fn(batch):
        return torch.LongTensor(batch)

class SASDataset(torch.utils.data.Dataset):
    def __init__(self, path, is_train, maxlen, padding_idx=NUM_ITEM):
        self.data = []
        self.is_train = is_train
        self.maxlen = maxlen
        self.padding_idx = padding_idx
        with open(path) as fin:
            for line in fin:
                self.data.append(ujson.loads(line))
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        example = self.data[idx]
        seq_len = len(example["context_value_ids"])
        assert seq_len <= self.maxlen
        if self.is_train:
            assert len(example["context_value_ids"]) == len(example["target_value_ids"]) == len(example["neg_value_ids"])
            hist_pids = [self.padding_idx] * (self.maxlen - seq_len) + example["context_value_ids"]
            target_pids = [self.padding_idx] * (self.maxlen - seq_len) + example["target_value_ids"]
            neg_pids = [self.padding_idx] * (self.maxlen - seq_len) + example["neg_value_ids"]
            assert len(hist_pids) == len(target_pids )== len(neg_pids) == self.maxlen
        else:
            hist_pids = [self.padding_idx] * (self.maxlen - seq_len) + example["context_value_ids"]
            assert len(hist_pids) == self.maxlen, (len(hist_pids), self.maxlen)
        if self.is_train:
            return (example["uid"], hist_pids, target_pids, neg_pids)
        else:
            return (example["uid"], hist_pids)
    
    def collate_fn(self, batch):
        if self.is_train:
            uids, hist_pids, target_pids, neg_pids = zip(*batch)
            uids = torch.LongTensor(uids)
            hist_pids = torch.LongTensor(hist_pids)
            target_pids = torch.LongTensor(target_pids)
            neg_pids = torch.LongTensor(neg_pids)
            targets = torch.cat((torch.ones_like(target_pids, dtype=torch.float32),
                                torch.zeros_like(neg_pids, dtype=torch.float32)), dim=-1)
            return [
                torch.LongTensor(uids),
                torch.LongTensor(hist_pids),
                torch.LongTensor(target_pids),
                torch.LongTensor(neg_pids),
                targets]
        else:
            uids, hist_pids = zip(*batch)
            return [torch.LongTensor(uids),
                    torch.LongTensor(hist_pids)]
    

class PointWiseFeedForward(torch.nn.Module):
    def __init__(self, hidden_units, dropout_rate):

        super(PointWiseFeedForward, self).__init__()

        self.conv1 = torch.nn.Conv1d(hidden_units, hidden_units, kernel_size=1)
        self.dropout1 = torch.nn.Dropout(p=dropout_rate)
        self.relu = torch.nn.ReLU()
        self.conv2 = torch.nn.Conv1d(hidden_units, hidden_units, kernel_size=1)
        self.dropout2 = torch.nn.Dropout(p=dropout_rate)

    def forward(self, inputs):
        outputs = self.dropout2(self.conv2(self.relu(self.dropout1(self.conv1(inputs.transpose(-1, -2))))))
        outputs = outputs.transpose(-1, -2) # as Conv1D requires (N, C, Length)
        outputs += inputs
        return outputs
    
class SASRec(torch.nn.Module):
    def __init__(self, args):
        super(SASRec, self).__init__()

        self.user_num = args.user_num
        self.item_num = args.item_num
        self.dev = args.device

        # TODO: loss += args.l2_emb for regularizing embedding vectors during training
        # https://stackoverflow.com/questions/42704283/adding-l1-l2-regularization-in-pytorch
        self.item_emb = torch.nn.Embedding(self.item_num+1, args.hidden_units)
        self.pos_emb = torch.nn.Embedding(args.maxlen, args.hidden_units) # TO IMPROVE
        self.emb_dropout = torch.nn.Dropout(p=args.dropout_rate)

        self.attention_layernorms = torch.nn.ModuleList() # to be Q for self-attention
        self.attention_layers = torch.nn.ModuleList()
        self.forward_layernorms = torch.nn.ModuleList()
        self.forward_layers = torch.nn.ModuleList()

        self.last_layernorm = torch.nn.LayerNorm(args.hidden_units, eps=1e-8)

        for _ in range(args.num_blocks):
            new_attn_layernorm = torch.nn.LayerNorm(args.hidden_units, eps=1e-8)
            self.attention_layernorms.append(new_attn_layernorm)

            new_attn_layer =  torch.nn.MultiheadAttention(args.hidden_units,
                                                            args.num_heads,
                                                            args.dropout_rate)
            self.attention_layers.append(new_attn_layer)

            new_fwd_layernorm = torch.nn.LayerNorm(args.hidden_units, eps=1e-8)
            self.forward_layernorms.append(new_fwd_layernorm)

            new_fwd_layer = PointWiseFeedForward(args.hidden_units, args.dropout_rate)
            self.forward_layers.append(new_fwd_layer)

            # self.pos_sigmoid = torch.nn.Sigmoid()
            # self.neg_sigmoid = torch.nn.Sigmoid()

    def log2feats(self, log_seqs):
        seqs = self.item_emb(log_seqs)
        seqs *= self.item_emb.embedding_dim ** 0.5
        positions = np.tile(np.array(range(log_seqs.shape[1])), [log_seqs.shape[0], 1])
        seqs += self.pos_emb(torch.LongTensor(positions).to(self.dev))
        seqs = self.emb_dropout(seqs)

        timeline_mask = torch.BoolTensor(log_seqs == 0).to(self.dev)
        seqs *= ~timeline_mask.unsqueeze(-1) # broadcast in last dim

        tl = seqs.shape[1] # time dim len for enforce causality
        attention_mask = ~torch.tril(torch.ones((tl, tl), dtype=torch.bool, device=self.dev))

        for i in range(len(self.attention_layers)):
            seqs = torch.transpose(seqs, 0, 1)
            Q = self.attention_layernorms[i](seqs)
            mha_outputs, _ = self.attention_layers[i](Q, seqs, seqs, 
                                            attn_mask=attention_mask)
                                            # key_padding_mask=timeline_mask
                                            # need_weights=False) this arg do not work?
            seqs = Q + mha_outputs
            seqs = torch.transpose(seqs, 0, 1)

            seqs = self.forward_layernorms[i](seqs)
            seqs = self.forward_layers[i](seqs)
            seqs *=  ~timeline_mask.unsqueeze(-1)

        log_feats = self.last_layernorm(seqs) # (U, T, C) -> (U, -1, C)

        return log_feats

    def forward(self, user_ids, log_seqs, pos_seqs, neg_seqs): # for training        
        log_feats = self.log2feats(log_seqs) # user_ids hasn't been used yet

        pos_embs = self.item_emb(pos_seqs)
        neg_embs = self.item_emb(neg_seqs)

        pos_logits = (log_feats * pos_embs).sum(dim=-1)
        neg_logits = (log_feats * neg_embs).sum(dim=-1)

        # pos_pred = self.pos_sigmoid(pos_logits)
        # neg_pred = self.neg_sigmoid(neg_logits)
        
        assert pos_logits.dim() == neg_logits.dim() == 2

        return pos_logits, neg_logits # pos_pred, neg_pred

    def predict(self, user_ids, log_seqs, item_indices): # for inference
        log_feats = self.log2feats(log_seqs) # user_ids hasn't been used yet

        final_feat = log_feats[:, -1, :] # only use last QKV classifier, a waste

        item_embs = self.item_emb(torch.LongTensor(item_indices).to(self.dev)) # (U, I, C)

        logits = item_embs.matmul(final_feat.unsqueeze(-1)).squeeze(-1)

        # preds = self.pos_sigmoid(logits) # rank same item list for different users

        return logits # preds # (U, I)
    
    def query_embs(self, log_seqs):
        return self.log2feats(log_seqs)[:, -1, :]
    
    def passage_embs(self, item_indices):
        return self.item_emb(item_indices)
    
def train(model, train_dataloader, num_epochs=200):
    bce_criterion = torch.nn.BCEWithLogitsLoss() # torch.nn.BCELoss()
    adam_optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.98))
    
    loss_avg_meter = AverageMeter()
    global_step = 1
    print("batch: {}, step: {}, loss: {}") 
    for epoch in range(num_epochs):
        model.train()
        for batch_idx, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
            uids, hist_pids, pos_pids, neg_pids, targets = batch
            pos_logits, neg_logits = model(uids, hist_pids, pos_pids, neg_pids)
            logits = torch.cat((pos_logits, neg_logits), dim=-1)
            loss = bce_criterion(logits, targets)
            
            
            loss.backward()
            adam_optimizer.step()
            adam_optimizer.zero_grad()
            
            loss_avg_meter.update(loss.item())
        
            if (global_step+1) % 4_000 == 0:
                print(f"{epoch} {global_step} {loss_avg_meter.avg:.3f}")
                loss_avg_meter.reset()
            
            global_step += 1

In [None]:
class Args:
    maxlen=5
    hidden_units=50
    num_blocks=2
    num_heads=1
    dropout_rate=0.5
    device="cpu"
    user_num=NUM_USER
    item_num=NUM_ITEM
    
args = Args
print(args.maxlen)

train_dataset = SASDataset(train_compl_path, True, args.maxlen)
train_dataloader = torch.utils.data.DataLoader(train_dataset,num_workers=4, shuffle=True, batch_size=128, 
                                              collate_fn=train_dataset.collate_fn)
model = SASRec(args)
train(model, train_dataloader)

5
batch: {}, step: {}, loss: {}


100%|██████████| 99/99 [00:28<00:00,  3.50it/s]
100%|██████████| 99/99 [00:40<00:00,  2.46it/s]
100%|██████████| 99/99 [00:40<00:00,  2.42it/s]
100%|██████████| 99/99 [00:40<00:00,  2.45it/s]
100%|██████████| 99/99 [00:41<00:00,  2.41it/s]
100%|██████████| 99/99 [00:37<00:00,  2.67it/s]
 62%|██████▏   | 61/99 [00:22<00:14,  2.70it/s]

In [116]:
from timeit import default_timer as timer

def get_index(model, test_dataloader, hidden_units):
    index = faiss.IndexFlatIP(hidden_units)
    index = faiss.IndexIDMap(index)
    
    model.eval()
    with torch.inference_mode():
        for batch_pids in tqdm(test_dataloader, total=len(test_dataloader)):
            pid_embs = model.passage_embs(batch_pids)
            
            index.add_with_ids(pid_embs.numpy(), batch_pids.numpy())
    return index 

def index_retrieve(index, query_embeddings, topk, batch=None):
    print("Query Num", len(query_embeddings))
    start = timer()
    if batch is None:
        nn_scores, nearest_neighbors = index.search(query_embeddings, topk)
    else:
        query_offset_base = 0
        pbar = tqdm(total=len(query_embeddings))
        nearest_neighbors = []
        nn_scores = []
        while query_offset_base < len(query_embeddings):
            batch_query_embeddings = query_embeddings[query_offset_base:query_offset_base+ batch]
            batch_nn_scores, batch_nn = index.search(batch_query_embeddings, topk)
            nearest_neighbors.extend(batch_nn.tolist())
            nn_scores.extend(batch_nn_scores.tolist())
            query_offset_base += len(batch_query_embeddings)
            pbar.update(len(batch_query_embeddings))
        pbar.close()

    elapsed_time = timer() - start
    elapsed_time_per_query = 1000 * elapsed_time / len(query_embeddings)
    print(f"Elapsed Time: {elapsed_time:.1f}s, Elapsed Time per query: {elapsed_time_per_query:.1f}ms")
    return nn_scores, nearest_neighbors

def get_query_side_embeddings(model, query_dataloader):
    embeddings = []
    embeddings_ids = []
    model.eval()
    for _, batch in tqdm(enumerate(query_dataloader), disable=False, 
                                desc=f"encode # {len(query_dataloader)} seqs"):
        with torch.no_grad():
            uids, hist_pids = batch
            reps = model.query_embs(hist_pids)
        
        embeddings.append(reps.numpy())
        embeddings_ids.extend(uids.numpy())
    
    embeddings = np.concatenate(embeddings) 
    assert len(embeddings) == len(embeddings_ids), (len(embeddings), len(embeddings_ids))
    
    return embeddings, embeddings_ids



data_dir = "/home/jupyter/unity_jointly_rec_and_search/datasets/unified_user"
product_path = os.path.join(data_dir, "collection_title_catalog.tsv")

product_dataset = SASProductDataset(product_path)
product_dataloader = torch.utils.data.DataLoader(product_dataset, num_workers=4, shuffle=False, batch_size=128)

#test_seq_dataset = SASDataset(test_sim_path, False, args.maxlen)
test_seq_dataset = SASDataset(test_compl_path, False, args.maxlen)
test_seq_dataloader = torch.utils.data.DataLoader(test_seq_dataset, num_workers=4, shuffle=False, batch_size=128,
                                                 collate_fn=test_seq_dataset.collate_fn)

index = get_index(model, product_dataloader, args.hidden_units)
seq_embeddings, seq_ids = get_query_side_embeddings(model, test_seq_dataloader)
nn_scores, nn_doc_ids = index_retrieve(index, seq_embeddings, 200, batch=128)

qid_to_ranks = {}
for qid, docids, scores in zip(seq_ids, nn_doc_ids, nn_scores):
    for docid, s in zip(docids, scores):
        if qid not in qid_to_ranks:
            qid_to_ranks[qid] = [(docid, s)]
        else:
            qid_to_ranks[qid] += [(docid, s)]
print(f"# unique query = {len(qid_to_ranks)}")

100%|██████████| 17664/17664 [00:13<00:00, 1345.01it/s]
encode # 79 seqs: 79it [00:00, 124.17it/s]


Query Num 10000


100%|██████████| 10000/10000 [00:57<00:00, 175.02it/s]


Elapsed Time: 57.1s, Elapsed Time per query: 5.7ms
# unique query = 10000


In [117]:
from evaluation import retrieval_evaluator

with open("./tmp_rank.run", "w") as f:
        for qid in qid_to_ranks:
            ranks = qid_to_ranks[qid]
            for i, (docid, s) in enumerate(ranks):
                f.write(f"{qid}\t{docid}\t{i+1}\t{s}\n")
                
#qrels_path = os.path.join(data_dir, "sequential_train_test/urels.sim.test.tsv")
qrels_path = os.path.join(data_dir, "sequential_train_test/urels.compl.test.tsv")

evaluator = retrieval_evaluator.RankingEvaluator(qrels_path)
evaluator.compute_metrics("./tmp_rank.run")

{'MRR@10': 0.24597293650793653,
 'QueriesWithRelevant@10': 3367,
 'MRR@1000': 0.2500012874929283,
 'QueriesWithRelevant@1000': 4677,
 'Recall@50': 0.27768157869907867,
 'Recall@1000': 0.3223696739371739,
 'nDCG@10': 0.20138034375289013,
 'nDCG@100': 0.2166772946713396,
 'MAP@1000': 0.1722181902673803,
 'QueriesRanked': 10000}

KeyboardInterrupt: 