In [1]:
from pathlib import Path
import time
import datetime
from tqdm import tqdm
from collections import defaultdict, Counter
import copy
import random
import re
import numpy as np
import os
from sklearn.metrics import roc_auc_score
import pickle

In [2]:
os.environ['CUDA_VISIBLE_DEVICES'] = "0"

In [3]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
import torch.optim as optim
import torch.nn.functional as F

In [4]:
device = torch.device("cuda:0")

In [5]:
torch.cuda.set_device(device)

In [6]:
dataset = 'demo/'

data_path = Path("../blob/data/" + str(dataset) + "utils/")
model_path = Path("../blob/model/" + str(dataset))

In [7]:
npratio = 4
max_his_len = 50
min_word_cnt = 3
max_title_len = 30

In [8]:
batch_size = 32
epoch = 10
lr=0.0001
name = 'nrms_' + dataset[:-1]
retrain = False

# collect impressions

In [9]:
with open(data_path/'train_sam_uid.pkl', 'rb') as f:
    train_sam = pickle.load(f)
    
with open(data_path/'valid_sam_uid.pkl', 'rb') as f:
    valid_sam = pickle.load(f)
    
with open(data_path/'test_sam_uid.pkl', 'rb') as f:
    test_sam = pickle.load(f)

with open(data_path/'user_indices.pkl', 'rb') as f:
    user_indices = pickle.load(f)

# News Preprocesss

In [10]:
with open(data_path/'nid2index.pkl', 'rb') as f:
    nid2index = pickle.load(f)
    
with open(data_path/'vocab_dict.pkl', 'rb') as f:
    vocab_dict = pickle.load(f)

embedding_matrix = np.load(data_path/'embedding.npy')
news_index = np.load(data_path /'news_index.npy')

In [11]:
if os.path.exists(data_path/'test_nid2index.pkl'):
    with open(data_path/'test_nid2index.pkl', 'rb') as f:
        test_nid2index = pickle.load(f)

    test_news_index = np.load(data_path /'test_news_index.npy')
else: # TODO: for now use valid to do test (cb)
    test_nid2index = nid2index
    test_news_index = news_index
    test_sam = valid_sam

# Dataset & DataLoader

In [12]:
def newsample(nnn, ratio):
    if ratio > len(nnn):
        return nnn + ["<unk>"] * (ratio - len(nnn))
    else:
        return random.sample(nnn, ratio)

In [13]:
class TrainDataset(Dataset):
    def __init__(self, samples, nid2index, news_index):
        self.news_index = news_index
        self.nid2index = nid2index
        self.samples = samples
        
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        # pos, neg, his, neg_his
        pos, neg, his, _ = self.samples[idx]
        neg = newsample(neg, npratio)
        candidate_news = [pos] + neg
        candidate_news = self.news_index[[self.nid2index[n] for n in candidate_news]]
        his = [self.nid2index[n] for n in his] + [0] * (max_his_len - len(his))
        his = self.news_index[his]
        
        label = np.array(0)
        return candidate_news, his, label

In [14]:
class NewsDataset(Dataset):
    def __init__(self, news_index):
        self.news_index = news_index
        
    def __len__(self):
        return len(self.news_index)
    
    def __getitem__(self, idx):
        return self.news_index[idx]

In [15]:
news_dataset = NewsDataset(news_index)

In [16]:
class UserDataset(Dataset):
    def __init__(self, 
                 samples,
                 news_vecs,
                 nid2index):
        self.samples = samples
        self.news_vecs = news_vecs
        self.nid2index = nid2index
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        poss, negs, his, _ = self.samples[idx]
        his = [self.nid2index[n] for n in his] + [0] * (max_his_len - len(his))
        his = self.news_vecs[his]
        return his

# Model

In [17]:
class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_k):
        super(ScaledDotProductAttention, self).__init__()
        self.d_k = d_k

    def forward(self, Q, K, V, attn_mask=None):
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(self.d_k)
        scores = torch.exp(scores)
        if attn_mask is not None:
            scores = scores * attn_mask
        attn = scores / (torch.sum(scores, dim=-1, keepdim=True)  + 1e-8)
        
        context = torch.matmul(attn, V)
        return context, attn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads, d_k, d_v):
        super(MultiHeadAttention, self).__init__()
        self.d_model = d_model # 300
        self.n_heads = n_heads # 20
        self.d_k = d_k # 20
        self.d_v = d_v # 20
        
        self.W_Q = nn.Linear(d_model, d_k * n_heads) # 300, 400
        self.W_K = nn.Linear(d_model, d_k * n_heads) # 300, 400
        self.W_V = nn.Linear(d_model, d_v * n_heads) # 300, 400
        
        self._initialize_weights()
                
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=1)
                
    def forward(self, Q, K, V, attn_mask=None):
        residual, batch_size = Q, Q.size(0)
        
        q_s = self.W_Q(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1,2)
        k_s = self.W_K(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1,2)
        v_s = self.W_V(V).view(batch_size, -1, self.n_heads, self.d_v).transpose(1,2)
        
        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(1).expand(batch_size, max_len, max_len) 
            attn_mask = attn_mask.unsqueeze(1).repeat(1, self.n_heads, 1, 1) 
        
        context, attn = ScaledDotProductAttention(self.d_k)(q_s, k_s, v_s, attn_mask) 
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_v) 
        return context 

In [18]:

class AdditiveAttention(nn.Module):
    ''' AttentionPooling used to weighted aggregate news vectors
    Arg: 
        d_h: the last dimension of input
    '''
    def __init__(self, d_h, hidden_size=200):
        super(AdditiveAttention, self).__init__()
        self.att_fc1 = nn.Linear(d_h, hidden_size)
        self.att_fc2 = nn.Linear(hidden_size, 1)

    def forward(self, x, attn_mask=None):
        """
        Args:
            x: batch_size, candidate_size, candidate_vector_dim
            attn_mask: batch_size, candidate_size
        Returns:
            (shape) batch_size, candidate_vector_dim
        """
        bz = x.shape[0]
        e = self.att_fc1(x)
        e = nn.Tanh()(e)
        alpha = self.att_fc2(e)

        alpha = torch.exp(alpha)
        if attn_mask is not None:
            alpha = alpha * attn_mask.unsqueeze(2)
        alpha = alpha / (torch.sum(alpha, dim=1, keepdim=True) + 1e-8)

        x = torch.bmm(x.permute(0, 2, 1), alpha)
        x = torch.reshape(x, (bz, -1))  # (bz, 400)
        return x

In [19]:
class TextEncoder(nn.Module):
    def __init__(self, 
                 word_embedding_dim=300, 
                 num_attention_heads=20,
                 query_vector_dim = 200,
                 dropout_rate=0.2,
                 enable_gpu=True):
        super(TextEncoder, self).__init__()
        self.dropout_rate = 0.2
        pretrained_news_word_embedding = torch.from_numpy(embedding_matrix).float()
        
        self.word_embedding = nn.Embedding.from_pretrained(
            pretrained_news_word_embedding, freeze=False)
        
        self.multihead_attention = MultiHeadAttention(word_embedding_dim,
                                              num_attention_heads, 20, 20)
        self.additive_attention = AdditiveAttention(num_attention_heads*20,
                                                    query_vector_dim)
    def forward(self, text):
        # REVIEW: remove training=self.training to enable dropout during testing 
        text_vector = F.dropout(self.word_embedding(text.long()),
                                p=self.dropout_rate,
                                # training=self.training
                                )
        multihead_text_vector = self.multihead_attention(
            text_vector, text_vector, text_vector)
        multihead_text_vector = F.dropout(multihead_text_vector,
                                          p=self.dropout_rate,
                                        #   training=self.training
                                          )
        # batch_size, word_embedding_dim
        text_vector = self.additive_attention(multihead_text_vector)
        return text_vector

In [20]:
class UserEncoder(nn.Module):
    def __init__(self,
                 news_embedding_dim=400,
                 num_attention_heads=20,
                 query_vector_dim=200
                ):
        super(UserEncoder, self).__init__()
        self.multihead_attention = MultiHeadAttention(news_embedding_dim,
                                              num_attention_heads, 20, 20)
        self.additive_attention = AdditiveAttention(num_attention_heads*20,
                                                    query_vector_dim)
        
        self.neg_multihead_attention = MultiHeadAttention(news_embedding_dim,
                                                         num_attention_heads, 20, 20)
        
    def forward(self, clicked_news_vecs):
        multi_clicked_vectors = self.multihead_attention(
            clicked_news_vecs, clicked_news_vecs, clicked_news_vecs
        )
        pos_user_vector = self.additive_attention(multi_clicked_vectors)
        
        user_vector = pos_user_vector
        return user_vector

In [21]:
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.text_encoder = TextEncoder()
        self.user_encoder = UserEncoder()
        
        self.criterion = nn.CrossEntropyLoss()
    
    def forward(self, candidate_news, clicked_news, targets, compute_loss=True):
        batch_size, npratio, word_num = candidate_news.shape
        candidate_news = candidate_news.view(-1, word_num)
        candidate_vector = self.text_encoder(candidate_news).view(batch_size, npratio, -1)
        
        batch_size, clicked_news_num, word_num = clicked_news.shape
        clicked_news = clicked_news.view(-1, word_num)
        clicked_news_vecs = self.text_encoder(clicked_news).view(batch_size, clicked_news_num, -1)
        
        user_vector = self.user_encoder(clicked_news_vecs)
        
        score = torch.bmm(candidate_vector, user_vector.unsqueeze(-1)).squeeze(dim=-1)
        
        if compute_loss:
            loss = self.criterion(score, targets)
            return loss, score
        else:
            return score

# Train

In [22]:
def dcg_score(y_true, y_score, k=10):
    order = np.argsort(y_score)[::-1]
    y_true = np.take(y_true, order[:k])
    gains = 2 ** y_true - 1
    discounts = np.log2(np.arange(len(y_true)) + 2)
    return np.sum(gains / discounts)


def ndcg_score(y_true, y_score, k=10):
    best = dcg_score(y_true, y_true, k)
    actual = dcg_score(y_true, y_score, k)
    return actual / best


def mrr_score(y_true, y_score):
    order = np.argsort(y_score)[::-1]
    y_true = np.take(y_true, order)
    rr_score = y_true / (np.arange(len(y_true)) + 1)
    return np.sum(rr_score) / np.sum(y_true)

In [23]:
def compute_amn(y_true, y_score):
    auc = roc_auc_score(y_true,y_score)
    mrr = mrr_score(y_true,y_score)
    ndcg5 = ndcg_score(y_true,y_score,5)
    ndcg10 = ndcg_score(y_true,y_score,10)
    return auc, mrr, ndcg5, ndcg10

def evaluation_split(news_vecs, user_vecs, samples, nid2index):
    all_rslt = []
    for i in tqdm(range(len(samples))):
        poss, negs, _, _ = samples[i]
        user_vec = user_vecs[i]
        y_true = [1] * len(poss) + [0] * len(negs)
        news_ids = [nid2index[i] for i in poss + negs]
        news_vec = news_vecs[news_ids]
        y_score = np.multiply(news_vec, user_vec)
        y_score = np.sum(y_score, axis=1)
        try:
            all_rslt.append(compute_amn(y_true, y_score))
        except Exception as e:
            print(e)
    return np.array(all_rslt)

In [24]:
train_ds = TrainDataset(train_sam, nid2index, news_index)
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0)

In [25]:

if retrain:
    for time in range(1):
        model = Model().to(device)
        optimizer = optim.Adam(model.parameters(), lr=lr)
        best_auc = 0
        for ep in range(5):
            loss = 0
            accuary = 0.0
            model.train()
            train_loader = tqdm(train_dl)
            for cnt, batch_sample in enumerate(train_loader):
                candidate_news_index, his_index, label = batch_sample
                sample_num = candidate_news_index.shape[0]
                candidate_news_index = candidate_news_index.to(device)
                his_index = his_index.to(device)
                label = label.to(device)
                bz_loss, y_hat = model(candidate_news_index, his_index, label)

                loss += bz_loss.detach().cpu().numpy()
                optimizer.zero_grad()
                bz_loss.backward()

                optimizer.step()

                if cnt % 10 == 0:
                    train_loader.set_description(f"[{cnt}]steps loss: {loss / (cnt+1):.4f} ")
                    train_loader.refresh() 


            model.eval()
            news_dl = DataLoader(news_dataset, batch_size=1024, shuffle=False, num_workers=0)
            news_vecs = []
            for news in tqdm(news_dl):
                news = news.to(device)
                news_vec = model.text_encoder(news).detach().cpu().numpy()
                news_vecs.append(news_vec)
            news_vecs = np.concatenate(news_vecs)

            user_dataset = UserDataset(valid_sam, news_vecs, nid2index)
            user_vecs = []
            user_dl = DataLoader(user_dataset, batch_size=1024, shuffle=False, num_workers=0)
            for his in tqdm(user_dl):
                his = his.to(device)
                user_vec = model.user_encoder(his).detach().cpu().numpy()
                user_vecs.append(user_vec)
            user_vecs = np.concatenate(user_vecs)

            val_scores = evaluation_split(news_vecs, user_vecs, valid_sam, nid2index)
            val_auc, val_mrr, val_ndcg, val_ndcg10 = [np.mean(i) for i in list(zip(*val_scores))]
            print(f"[{ep}] epoch auc: {val_auc:.4f}, mrr: {val_mrr:.4f}, ndcg5: {val_ndcg:.4f}, ndcg10: {val_ndcg10:.4f}")

            with open(model_path/f'{name}.txt', 'a') as f:
                f.write(f"[{ep}] epoch auc: {val_auc:.4f}, mrr: {val_mrr:.4f}, ndcg5: {val_ndcg:.4f}, ndcg10: {val_ndcg10:.4f}\n")
                    
            if val_auc > best_auc:
                best_auc = val_auc
                torch.save(model.state_dict(), model_path/f'{name}.pkl')
                with open(model_path/f'{name}.txt', 'a') as f:
                    f.write(f"[{ep}] epoch save model\n")
            

In [26]:
model = Model().to(device)
model.load_state_dict(torch.load(model_path/f'{name}.pkl'))
model.eval()
for m in model.modules():
    if m.__class__.__name__.startswith('dropout'):
        print(m)
        m.train()

y_scores = defaultdict(list)
y_trues = {}

for i in range(100):
    print('eva repeat #', str(i))
    test_news_dataset = NewsDataset(test_news_index)
    news_dl = DataLoader(test_news_dataset, batch_size=1024, shuffle=False, num_workers=0)
    news_vecs = []
    for news in tqdm(news_dl):
        news = news.to(device)
        news_vec = model.text_encoder(news).detach().cpu().numpy()
        news_vecs.append(news_vec)
    news_vecs = np.concatenate(news_vecs)

    user_dataset = UserDataset(test_sam, news_vecs, test_nid2index)
    user_vecs = []
    user_dl = DataLoader(user_dataset, batch_size=1024, shuffle=False, num_workers=0)
    for his in tqdm(user_dl):
        his = his.to(device)
        user_vec = model.user_encoder(his).detach().cpu().numpy()
        user_vecs.append(user_vec)
    user_vecs = np.concatenate(user_vecs)

    for i in tqdm(range(len(valid_sam))):
        poss, negs, _, _ = valid_sam[i]
        user_vec = user_vecs[i]
        y_true = [1] * len(poss) + [0] * len(negs)
        news_ids = [nid2index[i] for i in poss + negs]
        news_vec = news_vecs[news_ids]
        y_score = np.multiply(news_vec, user_vec)
        y_score = np.sum(y_score, axis=1)
        
        y_scores[i].append(y_score)
        y_trues[i] = y_true

    # test_auc, test_mrr, test_ndcg, test_ndcg10 = [np.mean(i) for i in list(zip(*test_scores))]
    # print(f"[{i}] time test auc: {test_auc:.4f}, mrr: {test_mrr:.4f}, ndcg5: {test_ndcg:.4f}, ndcg10: {test_ndcg10:.4f}")

# with open(model_path/ f'{name}.txt', 'a') as f:
#         f.write(f"[{time}] time test auc: {test_auc:.4f}, mrr: {test_mrr:.4f}, ndcg5: {test_ndcg:.4f}, ndcg10: {test_ndcg10:.4f}\n")

eva repeat # 0


100%|██████████| 28/28 [00:00<00:00, 46.93it/s]
100%|██████████| 8/8 [00:00<00:00,  8.14it/s]
100%|██████████| 7538/7538 [00:00<00:00, 24110.53it/s]


eva repeat # 1


100%|██████████| 28/28 [00:00<00:00, 41.48it/s]
100%|██████████| 8/8 [00:00<00:00, 11.80it/s]
100%|██████████| 7538/7538 [00:00<00:00, 24285.24it/s]


eva repeat # 2


100%|██████████| 28/28 [00:00<00:00, 52.45it/s]
100%|██████████| 8/8 [00:00<00:00, 11.99it/s]
100%|██████████| 7538/7538 [00:00<00:00, 26145.53it/s]


eva repeat # 3


100%|██████████| 28/28 [00:00<00:00, 53.52it/s]
100%|██████████| 8/8 [00:00<00:00, 10.98it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25109.73it/s]


eva repeat # 4


100%|██████████| 28/28 [00:00<00:00, 56.51it/s]
100%|██████████| 8/8 [00:00<00:00, 12.09it/s]
100%|██████████| 7538/7538 [00:00<00:00, 26128.16it/s]


eva repeat # 5


100%|██████████| 28/28 [00:00<00:00, 55.30it/s]
100%|██████████| 8/8 [00:00<00:00, 11.01it/s]
100%|██████████| 7538/7538 [00:00<00:00, 24155.27it/s]


eva repeat # 6


100%|██████████| 28/28 [00:00<00:00, 62.15it/s]
100%|██████████| 8/8 [00:00<00:00, 12.02it/s]
100%|██████████| 7538/7538 [00:00<00:00, 26436.86it/s]


eva repeat # 7


100%|██████████| 28/28 [00:00<00:00, 55.81it/s]
100%|██████████| 8/8 [00:00<00:00, 11.69it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25872.59it/s]


eva repeat # 8


100%|██████████| 28/28 [00:00<00:00, 38.16it/s]
100%|██████████| 8/8 [00:00<00:00, 10.82it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25051.08it/s]


eva repeat # 9


100%|██████████| 28/28 [00:00<00:00, 60.89it/s]
100%|██████████| 8/8 [00:00<00:00, 10.67it/s]
100%|██████████| 7538/7538 [00:00<00:00, 23311.09it/s]


eva repeat # 10


100%|██████████| 28/28 [00:00<00:00, 68.98it/s]
100%|██████████| 8/8 [00:00<00:00, 11.06it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25419.96it/s]


eva repeat # 11


100%|██████████| 28/28 [00:00<00:00, 64.32it/s]
100%|██████████| 8/8 [00:00<00:00, 10.70it/s]
100%|██████████| 7538/7538 [00:00<00:00, 24581.91it/s]


eva repeat # 12


100%|██████████| 28/28 [00:00<00:00, 71.71it/s]
100%|██████████| 8/8 [00:00<00:00, 10.74it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25112.60it/s]


eva repeat # 13


100%|██████████| 28/28 [00:00<00:00, 62.08it/s]
100%|██████████| 8/8 [00:00<00:00, 11.20it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25202.08it/s]


eva repeat # 14


100%|██████████| 28/28 [00:00<00:00, 67.18it/s]
100%|██████████| 8/8 [00:00<00:00, 10.87it/s]
100%|██████████| 7538/7538 [00:00<00:00, 24655.79it/s]


eva repeat # 15


100%|██████████| 28/28 [00:00<00:00, 55.86it/s]
100%|██████████| 8/8 [00:00<00:00, 11.02it/s]
100%|██████████| 7538/7538 [00:00<00:00, 24698.05it/s]


eva repeat # 16


100%|██████████| 28/28 [00:00<00:00, 75.51it/s]
100%|██████████| 8/8 [00:00<00:00, 11.13it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25768.23it/s]


eva repeat # 17


100%|██████████| 28/28 [00:00<00:00, 73.02it/s]
100%|██████████| 8/8 [00:00<00:00, 10.75it/s]
100%|██████████| 7538/7538 [00:00<00:00, 23481.74it/s]


eva repeat # 18


100%|██████████| 28/28 [00:00<00:00, 81.20it/s]
100%|██████████| 8/8 [00:00<00:00, 10.85it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25918.44it/s]


eva repeat # 19


100%|██████████| 28/28 [00:00<00:00, 73.37it/s]
100%|██████████| 8/8 [00:00<00:00, 10.57it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25458.20it/s]


eva repeat # 20


100%|██████████| 28/28 [00:00<00:00, 70.15it/s]
100%|██████████| 8/8 [00:00<00:00, 10.56it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25581.63it/s]


eva repeat # 21


100%|██████████| 28/28 [00:00<00:00, 68.53it/s]
100%|██████████| 8/8 [00:00<00:00, 10.39it/s]
100%|██████████| 7538/7538 [00:00<00:00, 23734.65it/s]


eva repeat # 22


100%|██████████| 28/28 [00:00<00:00, 52.72it/s]
100%|██████████| 8/8 [00:00<00:00, 11.46it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25697.71it/s]


eva repeat # 23


100%|██████████| 28/28 [00:00<00:00, 80.04it/s]
100%|██████████| 8/8 [00:00<00:00, 11.26it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25298.27it/s]


eva repeat # 24


100%|██████████| 28/28 [00:00<00:00, 81.22it/s]
100%|██████████| 8/8 [00:00<00:00, 11.30it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25809.69it/s]


eva repeat # 25


100%|██████████| 28/28 [00:00<00:00, 70.30it/s]
100%|██████████| 8/8 [00:00<00:00, 11.18it/s]
100%|██████████| 7538/7538 [00:00<00:00, 24409.49it/s]


eva repeat # 26


100%|██████████| 28/28 [00:00<00:00, 75.89it/s]
100%|██████████| 8/8 [00:00<00:00, 10.36it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25233.30it/s]


eva repeat # 27


100%|██████████| 28/28 [00:00<00:00, 70.68it/s]
100%|██████████| 8/8 [00:00<00:00, 11.20it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25421.68it/s]


eva repeat # 28


100%|██████████| 28/28 [00:00<00:00, 79.11it/s]
100%|██████████| 8/8 [00:00<00:00, 11.09it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25729.98it/s]


eva repeat # 29


100%|██████████| 28/28 [00:00<00:00, 51.03it/s]
100%|██████████| 8/8 [00:00<00:00, 10.38it/s]
100%|██████████| 7538/7538 [00:00<00:00, 26099.36it/s]


eva repeat # 30


100%|██████████| 28/28 [00:00<00:00, 74.02it/s]
100%|██████████| 8/8 [00:00<00:00, 11.73it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25814.16it/s]


eva repeat # 31


100%|██████████| 28/28 [00:00<00:00, 73.28it/s]
100%|██████████| 8/8 [00:00<00:00, 10.17it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25337.09it/s]


eva repeat # 32


100%|██████████| 28/28 [00:00<00:00, 82.96it/s]
100%|██████████| 8/8 [00:00<00:00, 10.57it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25440.31it/s]


eva repeat # 33


100%|██████████| 28/28 [00:00<00:00, 81.26it/s]
100%|██████████| 8/8 [00:00<00:00, 10.45it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25410.63it/s]


eva repeat # 34


100%|██████████| 28/28 [00:00<00:00, 80.98it/s]
100%|██████████| 8/8 [00:00<00:00, 11.04it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25360.73it/s]


eva repeat # 35


100%|██████████| 28/28 [00:00<00:00, 76.89it/s]
100%|██████████| 8/8 [00:00<00:00, 11.05it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25480.54it/s]


eva repeat # 36


100%|██████████| 28/28 [00:00<00:00, 46.93it/s]
100%|██████████| 8/8 [00:00<00:00, 11.36it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25966.04it/s]


eva repeat # 37


100%|██████████| 28/28 [00:00<00:00, 80.60it/s]
100%|██████████| 8/8 [00:00<00:00, 11.44it/s]
100%|██████████| 7538/7538 [00:00<00:00, 26162.75it/s]


eva repeat # 38


100%|██████████| 28/28 [00:00<00:00, 74.29it/s]
100%|██████████| 8/8 [00:00<00:00, 11.29it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25215.28it/s]


eva repeat # 39


100%|██████████| 28/28 [00:00<00:00, 83.14it/s]
100%|██████████| 8/8 [00:00<00:00, 10.82it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25582.54it/s]


eva repeat # 40


100%|██████████| 28/28 [00:00<00:00, 84.74it/s]
100%|██████████| 8/8 [00:00<00:00, 10.76it/s]
100%|██████████| 7538/7538 [00:00<00:00, 26426.94it/s]


eva repeat # 41


100%|██████████| 28/28 [00:00<00:00, 76.56it/s]
100%|██████████| 8/8 [00:00<00:00, 10.81it/s]
100%|██████████| 7538/7538 [00:00<00:00, 26221.60it/s]


eva repeat # 42


100%|██████████| 28/28 [00:00<00:00, 76.78it/s]
100%|██████████| 8/8 [00:00<00:00, 11.12it/s]
100%|██████████| 7538/7538 [00:00<00:00, 26096.97it/s]


eva repeat # 43


100%|██████████| 28/28 [00:00<00:00, 51.21it/s]
100%|██████████| 8/8 [00:00<00:00, 11.28it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25801.81it/s]


eva repeat # 44


100%|██████████| 28/28 [00:00<00:00, 85.31it/s]
100%|██████████| 8/8 [00:00<00:00, 10.86it/s]
100%|██████████| 7538/7538 [00:00<00:00, 26408.05it/s]


eva repeat # 45


100%|██████████| 28/28 [00:00<00:00, 80.47it/s]
100%|██████████| 8/8 [00:00<00:00, 11.32it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25894.71it/s]


eva repeat # 46


100%|██████████| 28/28 [00:00<00:00, 83.51it/s]
100%|██████████| 8/8 [00:00<00:00, 11.27it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25529.14it/s]


eva repeat # 47


100%|██████████| 28/28 [00:00<00:00, 79.86it/s]
100%|██████████| 8/8 [00:00<00:00, 11.15it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25554.97it/s]


eva repeat # 48


100%|██████████| 28/28 [00:00<00:00, 83.02it/s]
100%|██████████| 8/8 [00:00<00:00, 11.42it/s]
100%|██████████| 7538/7538 [00:00<00:00, 26366.99it/s]


eva repeat # 49


100%|██████████| 28/28 [00:00<00:00, 70.75it/s]
100%|██████████| 8/8 [00:00<00:00, 11.13it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25931.05it/s]


eva repeat # 50


100%|██████████| 28/28 [00:00<00:00, 53.94it/s]
100%|██████████| 8/8 [00:00<00:00, 10.92it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25691.19it/s]


eva repeat # 51


100%|██████████| 28/28 [00:00<00:00, 69.86it/s]
100%|██████████| 8/8 [00:00<00:00, 10.90it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25297.48it/s]


eva repeat # 52


100%|██████████| 28/28 [00:00<00:00, 74.37it/s]
100%|██████████| 8/8 [00:00<00:00, 11.09it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25963.14it/s]


eva repeat # 53


100%|██████████| 28/28 [00:00<00:00, 71.46it/s]
100%|██████████| 8/8 [00:00<00:00, 11.29it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25891.04it/s]


eva repeat # 54


100%|██████████| 28/28 [00:00<00:00, 66.10it/s]
100%|██████████| 8/8 [00:00<00:00, 11.19it/s]
100%|██████████| 7538/7538 [00:00<00:00, 26048.10it/s]


eva repeat # 55


100%|██████████| 28/28 [00:00<00:00, 58.88it/s]
100%|██████████| 8/8 [00:00<00:00, 11.16it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25879.43it/s]


eva repeat # 56


100%|██████████| 28/28 [00:00<00:00, 68.70it/s]
100%|██████████| 8/8 [00:00<00:00, 11.48it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25488.59it/s]


eva repeat # 57


100%|██████████| 28/28 [00:00<00:00, 46.85it/s]
100%|██████████| 8/8 [00:00<00:00, 10.85it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25481.92it/s]


eva repeat # 58


100%|██████████| 28/28 [00:00<00:00, 70.85it/s]
100%|██████████| 8/8 [00:00<00:00, 11.04it/s]
100%|██████████| 7538/7538 [00:00<00:00, 26698.93it/s]


eva repeat # 59


100%|██████████| 28/28 [00:00<00:00, 70.67it/s]
100%|██████████| 8/8 [00:00<00:00, 11.21it/s]
100%|██████████| 7538/7538 [00:00<00:00, 26222.42it/s]


eva repeat # 60


100%|██████████| 28/28 [00:00<00:00, 66.39it/s]
100%|██████████| 8/8 [00:00<00:00, 11.88it/s]
100%|██████████| 7538/7538 [00:00<00:00, 26528.56it/s]


eva repeat # 61


100%|██████████| 28/28 [00:00<00:00, 63.74it/s]
100%|██████████| 8/8 [00:00<00:00, 11.52it/s]
100%|██████████| 7538/7538 [00:00<00:00, 26169.75it/s]


eva repeat # 62


100%|██████████| 28/28 [00:00<00:00, 61.20it/s]
100%|██████████| 8/8 [00:00<00:00, 11.39it/s]
100%|██████████| 7538/7538 [00:00<00:00, 26263.68it/s]


eva repeat # 63


100%|██████████| 28/28 [00:00<00:00, 65.86it/s]
100%|██████████| 8/8 [00:00<00:00, 11.45it/s]
100%|██████████| 7538/7538 [00:00<00:00, 26439.71it/s]


eva repeat # 64


100%|██████████| 28/28 [00:00<00:00, 47.21it/s]
100%|██████████| 8/8 [00:00<00:00, 11.22it/s]
100%|██████████| 7538/7538 [00:00<00:00, 26192.38it/s]


eva repeat # 65


100%|██████████| 28/28 [00:00<00:00, 69.60it/s]
100%|██████████| 8/8 [00:00<00:00, 11.45it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25985.91it/s]


eva repeat # 66


100%|██████████| 28/28 [00:00<00:00, 71.92it/s]
100%|██████████| 8/8 [00:00<00:00, 12.10it/s]
100%|██████████| 7538/7538 [00:00<00:00, 26370.87it/s]


eva repeat # 67


100%|██████████| 28/28 [00:00<00:00, 71.10it/s]
100%|██████████| 8/8 [00:00<00:00, 11.31it/s]
100%|██████████| 7538/7538 [00:00<00:00, 24935.00it/s]


eva repeat # 68


100%|██████████| 28/28 [00:00<00:00, 63.27it/s]
100%|██████████| 8/8 [00:00<00:00, 11.10it/s]
100%|██████████| 7538/7538 [00:00<00:00, 26533.06it/s]


eva repeat # 69


100%|██████████| 28/28 [00:00<00:00, 66.44it/s]
100%|██████████| 8/8 [00:00<00:00, 11.53it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25376.36it/s]


eva repeat # 70


100%|██████████| 28/28 [00:00<00:00, 66.00it/s]
100%|██████████| 8/8 [00:00<00:00, 11.26it/s]
100%|██████████| 7538/7538 [00:00<00:00, 26677.96it/s]


eva repeat # 71


100%|██████████| 28/28 [00:00<00:00, 49.07it/s]
100%|██████████| 8/8 [00:00<00:00, 10.83it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25460.64it/s]


eva repeat # 72


100%|██████████| 28/28 [00:00<00:00, 65.59it/s]
100%|██████████| 8/8 [00:00<00:00, 11.75it/s]
100%|██████████| 7538/7538 [00:00<00:00, 26352.03it/s]


eva repeat # 73


100%|██████████| 28/28 [00:00<00:00, 64.57it/s]
100%|██████████| 8/8 [00:00<00:00, 11.51it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25770.33it/s]


eva repeat # 74


100%|██████████| 28/28 [00:00<00:00, 66.07it/s]
100%|██████████| 8/8 [00:00<00:00, 11.05it/s]
100%|██████████| 7538/7538 [00:00<00:00, 26304.02it/s]


eva repeat # 75


100%|██████████| 28/28 [00:00<00:00, 68.86it/s]
100%|██████████| 8/8 [00:00<00:00, 11.21it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25588.54it/s]


eva repeat # 76


100%|██████████| 28/28 [00:00<00:00, 65.46it/s]
100%|██████████| 8/8 [00:00<00:00, 11.31it/s]
100%|██████████| 7538/7538 [00:00<00:00, 26086.98it/s]


eva repeat # 77


100%|██████████| 28/28 [00:00<00:00, 63.94it/s]
100%|██████████| 8/8 [00:00<00:00, 11.15it/s]
100%|██████████| 7538/7538 [00:00<00:00, 26923.95it/s]


eva repeat # 78


100%|██████████| 28/28 [00:00<00:00, 48.87it/s]
100%|██████████| 8/8 [00:00<00:00, 11.98it/s]
100%|██████████| 7538/7538 [00:00<00:00, 26381.03it/s]


eva repeat # 79


100%|██████████| 28/28 [00:00<00:00, 71.83it/s]
100%|██████████| 8/8 [00:00<00:00, 11.61it/s]
100%|██████████| 7538/7538 [00:00<00:00, 26656.55it/s]


eva repeat # 80


100%|██████████| 28/28 [00:00<00:00, 68.30it/s]
100%|██████████| 8/8 [00:00<00:00, 11.78it/s]
100%|██████████| 7538/7538 [00:00<00:00, 26376.94it/s]


eva repeat # 81


100%|██████████| 28/28 [00:00<00:00, 62.40it/s]
100%|██████████| 8/8 [00:00<00:00, 10.94it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25350.03it/s]


eva repeat # 82


100%|██████████| 28/28 [00:00<00:00, 64.93it/s]
100%|██████████| 8/8 [00:00<00:00, 11.53it/s]
100%|██████████| 7538/7538 [00:00<00:00, 26146.18it/s]


eva repeat # 83


100%|██████████| 28/28 [00:00<00:00, 63.39it/s]
100%|██████████| 8/8 [00:00<00:00, 11.96it/s]
100%|██████████| 7538/7538 [00:00<00:00, 26406.81it/s]


eva repeat # 84


100%|██████████| 28/28 [00:00<00:00, 65.98it/s]
100%|██████████| 8/8 [00:00<00:00, 12.35it/s]
100%|██████████| 7538/7538 [00:00<00:00, 26287.98it/s]


eva repeat # 85


100%|██████████| 28/28 [00:00<00:00, 46.95it/s]
100%|██████████| 8/8 [00:00<00:00, 11.96it/s]
100%|██████████| 7538/7538 [00:00<00:00, 26116.07it/s]


eva repeat # 86


100%|██████████| 28/28 [00:00<00:00, 62.99it/s]
100%|██████████| 8/8 [00:00<00:00, 11.80it/s]
100%|██████████| 7538/7538 [00:00<00:00, 26655.51it/s]


eva repeat # 87


100%|██████████| 28/28 [00:00<00:00, 64.46it/s]
100%|██████████| 8/8 [00:00<00:00, 12.44it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25843.23it/s]


eva repeat # 88


100%|██████████| 28/28 [00:00<00:00, 65.68it/s]
100%|██████████| 8/8 [00:00<00:00, 11.15it/s]
100%|██████████| 7538/7538 [00:00<00:00, 24442.04it/s]


eva repeat # 89


100%|██████████| 28/28 [00:00<00:00, 63.02it/s]
100%|██████████| 8/8 [00:00<00:00, 11.29it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25259.73it/s]


eva repeat # 90


100%|██████████| 28/28 [00:00<00:00, 64.00it/s]
100%|██████████| 8/8 [00:00<00:00, 11.17it/s]
100%|██████████| 7538/7538 [00:00<00:00, 24673.72it/s]


eva repeat # 91


100%|██████████| 28/28 [00:00<00:00, 65.03it/s]
100%|██████████| 8/8 [00:00<00:00, 11.01it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25735.89it/s]


eva repeat # 92


100%|██████████| 28/28 [00:00<00:00, 45.14it/s]
100%|██████████| 8/8 [00:00<00:00, 11.10it/s]
100%|██████████| 7538/7538 [00:00<00:00, 24831.66it/s]


eva repeat # 93


100%|██████████| 28/28 [00:00<00:00, 63.80it/s]
100%|██████████| 8/8 [00:00<00:00, 11.24it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25761.51it/s]


eva repeat # 94


100%|██████████| 28/28 [00:00<00:00, 61.01it/s]
100%|██████████| 8/8 [00:00<00:00, 10.76it/s]
100%|██████████| 7538/7538 [00:00<00:00, 24841.75it/s]


eva repeat # 95


100%|██████████| 28/28 [00:00<00:00, 65.57it/s]
100%|██████████| 8/8 [00:00<00:00, 11.25it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25551.94it/s]


eva repeat # 96


100%|██████████| 28/28 [00:00<00:00, 71.89it/s]
100%|██████████| 8/8 [00:00<00:00, 10.97it/s]
100%|██████████| 7538/7538 [00:00<00:00, 23918.91it/s]


eva repeat # 97


100%|██████████| 28/28 [00:00<00:00, 72.47it/s]
100%|██████████| 8/8 [00:00<00:00, 11.15it/s]
100%|██████████| 7538/7538 [00:00<00:00, 24864.04it/s]


eva repeat # 98


100%|██████████| 28/28 [00:00<00:00, 65.43it/s]
100%|██████████| 8/8 [00:00<00:00, 10.39it/s]
100%|██████████| 7538/7538 [00:00<00:00, 24521.93it/s]


eva repeat # 99


100%|██████████| 28/28 [00:00<00:00, 50.51it/s]
100%|██████████| 8/8 [00:00<00:00, 12.06it/s]
100%|██████████| 7538/7538 [00:00<00:00, 25898.76it/s]


In [31]:
all_rslt_mean = []
all_rslt_ucb1 = []
all_rslt_ucb05 = []
all_rslt_ucb2 = []

for key, value in y_scores.items():
    mean = np.asarray(value).mean(axis = 0)
    std = np.asarray(value).std(axis = 0)
    try:
        all_rslt_mean.append(compute_amn(y_trues[key], mean))
        all_rslt_ucb1.append(compute_amn(y_trues[key], mean + std ))
        all_rslt_ucb05.append(compute_amn(y_trues[key], mean + 0.5 * std ))
        all_rslt_ucb2.append(compute_amn(y_trues[key], mean + 2 * std ))
    except Exception as e:
        print(e)

In [28]:
val_auc, val_mrr, val_ndcg, val_ndcg10 = [np.mean(i) for i in list(zip(*np.array(all_rslt_mean)))]
val_auc

0.6310987965346135

In [29]:
val_auc, val_mrr, val_ndcg, val_ndcg10 = [np.mean(i) for i in list(zip(*np.array(all_rslt_ucb1)))]
val_auc

0.6307318702678361

In [32]:
val_auc, val_mrr, val_ndcg, val_ndcg10 = [np.mean(i) for i in list(zip(*np.array(all_rslt_ucb05)))]
val_auc

0.631134064312768

In [33]:
val_auc, val_mrr, val_ndcg, val_ndcg10 = [np.mean(i) for i in list(zip(*np.array(all_rslt_ucb2)))]
val_auc

0.6297142446544629