In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '1'
device = 'cuda'

In [2]:
import json
import matplotlib.pyplot as plt
import numpy as np
import pickle
import random
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from copy import deepcopy
from sklearn.metrics import roc_auc_score
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm
from pytorchtools import EarlyStopping
assert(torch.cuda.is_available())

## Data Preprocessing

In [3]:
def load_news(path):
    news_dict = {} # index -> news
    news_list = [] # index -> news
    newsid_dict = {} # newsid -> index
    word_dict = {'<PAD>': 0, '<OOV>': 1}
    cate_dict = {'<PAD>': 0, '<OOV>': 1}
    with open(path, 'r') as f:
        for line in f.readlines():
            news_id, category, subcategory, title, abstract, \
                url, title_entities, abstract_entities = line.strip().split('\t')
            title = title.lower().replace('.', '').replace(',', '').replace(';', '').replace(':', '').replace('\'', '').replace('"', '').replace('?', '').replace('!', '').replace('(', '').replace(')', '').split(' ')
            abstract = abstract.lower().replace('.', '').replace(',', '').replace(';', '').replace(':', '').replace('\'', '').replace('"', '').replace('?', '').replace('!', '').replace('(', '').replace(')', '').split(' ')
            for word in title + abstract:
                if word not in word_dict:
                    word_dict[word] = len(word_dict)
            if category not in cate_dict:
                cate_dict[category] = len(cate_dict)
            if subcategory not in cate_dict:
                cate_dict[subcategory] = len(cate_dict)
            if news_id not in newsid_dict:
                newsid_dict[news_id] = len(newsid_dict)
                news_list.append([category, subcategory, title, abstract])
    print(len(news_list))
    return news_list, newsid_dict, word_dict, cate_dict

In [4]:
max_title = 30
max_body = 100
def map_news_input(news_list, word_dict, cate_dict):
    n_news = len(news_list)
    titles = np.zeros((n_news, max_title), dtype = 'int32')
    bodys = np.zeros((n_news, max_body), dtype = 'int32')
    cates = np.zeros((n_news,1), dtype = 'int32')
    subcates = np.zeros((n_news,1), dtype = 'int32')
    for i in range(n_news):
        category, subcategory, title, abstract = news_list[i]
        titles[i, :len(title)] = [word_dict[word] for word in title[:max_title]]
        bodys[i, :len(abstract)] = [word_dict[word] for word in abstract[:max_body]]
        cates[i] = cate_dict[category]
        subcates[i] = cate_dict[subcategory]
    news_info = np.concatenate((titles, bodys, cates, subcates), axis = 1)
    print(news_info.shape)
    return news_info # index -> news_info

In [5]:
'''
news_list: original news
news_info: mapped news(word ids)
'''
news_list, newsid_dict, word_dict, cate_dict = load_news('/data/Recommend/MIND/small_news.tsv')
news_info = map_news_input(news_list, word_dict, cate_dict)

65238
(65238, 132)


In [6]:
def load_glove(word_to_ix, dim = 100):
    if dim == 100:
        path = '/data/pretrained/Glove/glove.6B.100d.txt'
    elif dim == 300:
        path = '/data/pretrained/Glove/glove.840B.300d.txt'
    word_emb = []
    word_emb = np.zeros((len(word_to_ix), dim), dtype = float)
    with open(path, 'r') as f:
        for line in f:
            data = line.strip().split(' ') # [word emb1 emb2 ... emb n]
            word = data[0]
            if word in word_to_ix:
                word_emb[word_to_ix[word]] = [float(i) for i in data[1:]]
    print(word_emb.shape)
    return torch.tensor(word_emb, dtype = torch.float)

In [7]:
word_emb = load_glove(word_dict, 300)
cate_emb = load_glove(cate_dict, 100)

(80416, 300)
(282, 100)


In [8]:
def load_train_impression(path, newsid_dict): # train&dev
    logs = []
    with open(path, 'r') as f:
        for line in f:
            imp_id, user_id, time, history, impression = line.strip().split('\t')
            if history:
                history = [newsid_dict[news_id] for news_id in history.split(' ')]
            else:
                history = []
            positive = []
            negative = []
            for item in impression.split(' '):
                news_id, num = item.split('-')
                if num == '1':
                    positive.append(newsid_dict[news_id])
                else:
                    negative.append(newsid_dict[news_id])
            logs.append([history, positive, negative]) # indexs
    return logs

In [9]:
max_history = 50
def map_user(logs): # index -> history, 用 index 代表 user_id, train&dev
    n_user = len(logs)
    user_hist = np.zeros((n_user, max_history), dtype = 'int32') # index -> history
    for i in range(n_user):
        history, positive, negative = logs[i]
        n_hist = len(history)
        if n_hist == 0:
            continue
        user_hist[i, -n_hist:] = history[-max_history:]
    return user_hist         

In [10]:
neg_ratio = 4
def neg_sample(negative):
    if len(negative) < neg_ratio:
        return random.sample(negative * (neg_ratio // len(negative) + 1), neg_ratio)
    else:
        return random.sample(negative, neg_ratio)

def get_train_input(logs): # 和 map_user 使用同一个 log
    all_pos = [] # 每个 sample 的 pos
    all_neg = []
    user_id = [] # 每个 sample 的 user，用 index 表示，和 map_user 的结果对应
    for i in range(len(logs)):
        history, positive, negative = logs[i]
        for pos in positive:
            all_pos.append(pos)
            all_neg.append(neg_sample(negative))
            user_id.append(i)
    n_imps = len(all_pos)
    imps = np.zeros((n_imps, 1 + neg_ratio), dtype = 'int32')
    for i in range(len(all_pos)):
        imps[i, 0] = all_pos[i]
        imps[i, 1:] = all_neg[i]
    user_id = np.array(user_id, dtype = 'int32')
    labels = np.zeros((n_imps, 1 + neg_ratio), dtype = 'int32')
    labels[:, 0] = 1
    print(n_imps)
    return imps, user_id, labels

def get_dev_input(logs): # 和 map_user 使用同一个 log
    imps = []
    labels = []
    user_id = np.zeros((len(logs)), dtype = 'int32') # 每个 sample 的 user index，和 map_user 的结果对应
    for i in range(len(logs)):
        history, positive, negative = logs[i]
        imps.append(np.array(positive + negative, dtype = 'int32'))
        labels.append([1] * len(positive) + [0] * len(negative))
        user_id[i] = i
    print(len(logs))
    return imps, user_id, labels

In [11]:
class TrainDataset(Dataset):
    def __init__(self, imp_datas, imp_users, imp_labels, news_info, user_clicks, batch_size, news_ents = None):
        self.imp_datas = imp_datas # (n_imps, 1 + k)
        self.imp_users = imp_users
        self.imp_labels = imp_labels
        self.news = news_info
        self.user_clicks = user_clicks
        self.batch_size = batch_size
        self.news_ents = news_ents
        
        self.n_data = imp_datas.shape[0]
        
    def __len__(self):
        return int(np.ceil(self.n_data / self.batch_size))

    def __getitem__(self, idx):
        start = idx * self.batch_size
        end = min((idx + 1) * self.batch_size, self.n_data)
        
        data_id = self.imp_datas[start: end] # (n_batch, 1 + k)
        data_news = self.news[data_id] # (n_batch, 1 + k, news_len)
        user_id = self.imp_users[start: end] # (n_batch)
        user_news_id = self.user_clicks[user_id] # (n_batch, n_hist)
        user_news = self.news[user_news_id] # (n_batch, n_hist, news_len)
        labels = self.imp_labels[start: end] # (n_batch, 1 + k)
        
        if self.news_ents is not None:
            samp_ents = self.news_ents[data_id]
            user_ents = self.news_ents[user_news_id]
            return data_news, user_news, labels, samp_ents, user_ents
        
        return data_news, user_news, labels
    
class DevDataset(Dataset): # data 和 label 是 list，每条数据不同长度
    def __init__(self, imp_datas, imp_users, imp_labels, news_info, user_clicks, batch_size):
        self.imp_datas = imp_datas # [imp1, imp2, ..., impn]
        self.imp_users = imp_users # (n_imps)
        self.imp_labels = imp_labels
        self.news = news_info
        self.user_clicks = user_clicks
        self.batch_size = batch_size
        
        self.n_data = len(imp_datas)
        
    def __len__(self):
        return int(np.ceil(self.n_data / self.batch_size))

    def __getitem__(self, idx):
        start = idx * self.batch_size
        end = min((idx + 1) * self.batch_size, self.n_data)
        
        data_ids = []
        data_news = [] # [(n_imp, news_len)]
        labels = [] # [(n_imp)]
        for i in range(start, end):
            data_id = self.imp_datas[i] # (n_imp)
            data_ids.append(data_id)
            # data_news.append(self.news[data_id]) # (n_imp, news_len)
            labels.append(self.imp_labels[i]) # (n_imp)
        user_id = self.imp_users[start: end] # (n_batch)
        user_news_id = self.user_clicks[user_id] # (n_batch, n_hist)
        # user_news = self.news[user_news_id] # (n_batch, n_hist, news_len)
        
        #return data_news, user_news, labels
        return data_ids, user_news_id, labels
    
class PaddedDevDataset(Dataset): # data 和 label 是 list，每条数据不同长度
    def __init__(self, imp_datas, imp_users, imp_labels, news_info, user_clicks, batch_size):
        self.imp_datas = imp_datas # [imp1, imp2, ..., impn]
        self.imp_users = imp_users # (n_imps)
        self.imp_labels = imp_labels
        self.news = news_info
        self.user_clicks = user_clicks
        self.batch_size = batch_size
        
        self.n_data = len(imp_datas)
        
    def __len__(self):
        return int(np.ceil(self.n_data / self.batch_size))

    def __getitem__(self, idx):
        start = idx * self.batch_size
        end = min((idx + 1) * self.batch_size, self.n_data)
        
        len_imp = [len(i) for i in self.imp_datas[start: end]]
        data_ids = np.zeros((sum(len_imp), ), dtype = np.int32)
        idx = 0
        labels = [] # [(n_imp)]
        index = [] # 每个 imp 属于 batch 内的第几个 user
        for i in range(start, end):
            data_id = self.imp_datas[i] # (n_imp)
            data_ids[idx: idx + len(data_id)] = data_id
            idx += len(data_id)
            index += [i - start] * len(data_id) # 第 i 个 user 的 imp 数量
            labels.append(self.imp_labels[i]) # (n_imp)
        user_id = self.imp_users[start: end] # (n_batch)
        user_news_id = self.user_clicks[user_id] # (n_batch, n_hist)
        
        return data_ids, user_news_id, labels, index

In [12]:
n_batch = 8
train_logs = load_train_impression('/data/Recommend/MIND/MINDsmall_train/behaviors.tsv', newsid_dict)
train_user_hist = map_user(train_logs)
train_datas, train_users, train_labels = get_train_input(train_logs)
train_dataset = TrainDataset(train_datas, train_users, train_labels, news_info, train_user_hist, n_batch)

dev_logs = load_train_impression('/data/Recommend/MIND/MINDsmall_dev/behaviors.tsv', newsid_dict)
dev_user_hist = map_user(dev_logs)
dev_datas, dev_users, dev_labels = get_dev_input(dev_logs)
dev_dataset = DevDataset(dev_datas, dev_users, dev_labels, news_info, dev_user_hist, 64)
# pad_dev_dataset = PaddedDevDataset(dev_datas, dev_users, dev_labels, news_info, dev_user_hist, 64)
# caum_dev_dataset = DevDataset(dev_datas, dev_users, dev_labels, news_info, dev_user_hist, 64)

valid_datas, valid_users, valid_labels = get_train_input(dev_logs) # 用 train 的方法构造 dev_set
valid_dataset = TrainDataset(valid_datas, valid_users, valid_labels, news_info, dev_user_hist, n_batch)

236344
73152
111383


In [13]:
def encode_all_news(news_info, news_encoder):
    n_news = len(news_info)
    news_rep = []
    n_batch = 32
    for i in range((len(news_info) + n_batch - 1) // n_batch):
        batch_news = torch.tensor(news_info[i * n_batch: (i + 1) * n_batch], dtype = torch.long, device = 'cuda')
        batch_rep = news_encoder(batch_news).detach().cpu().numpy()
        news_rep.append(batch_rep)
    news_rep = np.concatenate(news_rep, axis = 0)
    return news_rep # (n_news, n_title, n_emb)

def encode_all_user(user_ids, user_hist, user_encoder, news_rep):
    user_rep = []
    with torch.no_grad():
        for _, batch in enumerate(dev_dataset):
            if len(batch[0]) == 0:
                break
            user_hist_rep = torch.tensor(news_rep[batch[1]], device = 'cuda') # (n_batch, n_hist)
            user = model.user_encoder(user_hist_rep).detach().cpu().numpy() # (n_batch, emb_dim)
            user_rep.append(user)
    # user_rep = np.concatenate(user_rep, axis = 0)
    return user_rep # [user_rep, ...]

In [14]:
def dcg_score(y_true, y_score, k=10):
    order = np.argsort(y_score)[::-1]
    y_true = np.take(y_true, order[:k])
    gains = 2 ** y_true - 1
    discounts = np.log2(np.arange(len(y_true)) + 2)
    return np.sum(gains / discounts)

def ndcg_score(y_true, y_score, k=10):
    best = dcg_score(y_true, y_true, k)
    actual = dcg_score(y_true, y_score, k)
    return actual / best

def mrr_score(y_true, y_score):
    order = np.argsort(y_score)[::-1]
    y_true = np.take(y_true, order)
    rr_score = y_true / (np.arange(len(y_true)) + 1)
    return np.sum(rr_score) / np.sum(y_true)

In [15]:
# train with valid
def train(model, train_dataset, valid_dataset = None, epochs = 4):
    optimizer = optim.Adam(model.parameters(), lr = 1e-4)
    entrophy = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        train_losses = []
        valid_losses = []
        model.train()
        for _, batch in tqdm(enumerate(train_dataset)):
            if batch[0].shape[0] == 0:
                break
            # torch.Size([16, 5, 30]) torch.Size([16, 50, 30]) torch.Size([16])
            sample = torch.tensor(batch[0], dtype = torch.long, device = device)
            history = torch.tensor(batch[1], dtype = torch.long, device = device)
            correct = torch.argmax(torch.tensor(batch[2], dtype = torch.long, device = device), dim = 1)
            optimizer.zero_grad()
            output = model(history, sample)
            loss = entrophy(output, correct)
            train_losses.append(loss.item())
            loss.backward()
            optimizer.step()

        if valid_dataset is not None:
            model.eval()
            with torch.no_grad():
                for _, batch in enumerate(valid_dataset):
                    if batch[0].shape[0] == 0:
                        break
                    sample = torch.tensor(batch[0], dtype = torch.long, device = device)
                    history = torch.tensor(batch[1], dtype = torch.long, device = device)
                    correct = torch.argmax(torch.tensor(batch[2], dtype = torch.long, device = device), dim = 1)
                    output = model(sample, history)
                    loss = entrophy(output, correct)
                    valid_losses.append(loss.item())
                print('[epoch {:d}] train_loss: {:.4f} valid_loss: {:.4f}'.format(epoch + 1, np.average(train_losses), np.average(valid_losses)))
        else:
            print('[epoch {:d}] train_loss: {:.4f}'.format(epoch + 1, np.average(train_losses)))

In [16]:
def evaluate(model, dev_dataset, news_info, dev_users, dev_user_hist):
    news_rep = encode_all_news(news_info, model.news_encoder) # (65238, 400)
    user_rep = encode_all_user(dev_users, dev_user_hist, model.user_encoder, news_rep)
    
    model.eval()
    with torch.no_grad():
        auc_scores = []
        mrr_scores = []
        ndcg5_scores = []
        ndcg10_scores = []
        for i, batch in tqdm(enumerate(dev_dataset)):
            if len(batch[0]) == 0:
                break
            user = user_rep[i]
            for j in range(len(batch[0])):
                sample = news_rep[batch[0][j]] # (n_imp, emb_dim)
                positive = batch[2][j] # (1, n_imp)

                score = np.matmul(sample, user[j]) # (1, n_imp)
                predict = np.exp(score) / np.sum(np.exp(score))

                auc_scores.append(roc_auc_score(positive, predict))
                mrr_scores.append(mrr_score(positive, predict))
                ndcg5_scores.append(ndcg_score(positive, predict, k = 5))
                ndcg10_scores.append(ndcg_score(positive, predict, k = 10))
    print('[Test] AUC: {:4f}, MRR: {:4f}, nDCG5:{:4f}, nDCG10: {:4f}'.format(
        np.mean(auc_scores), np.mean(mrr_scores), np.mean(ndcg5_scores), np.mean(ndcg10_scores)
    ))

## FIM

In [17]:
class HDCNewsEncoder(nn.Module): # Hierarchical dilated convolution
    def __init__(self, args, word_emb, news_dim, dilations = [1, 2, 3]):
        super().__init__()
        self.args = args
        self.use_cate = args['use_cate']
        self.word_embedding = nn.Embedding.from_pretrained(word_emb)
        
        self.cnns = nn.ModuleList([nn.Conv1d(word_emb.shape[1], news_dim, 3, dilation = i, padding = i) for i in dilations])
        self.layernorm = nn.LayerNorm(news_dim)
        self.dropout = nn.Dropout(0.2)
    
    def forward(self, news):
        title, body, cate = news[:, :max_title], news[:, max_title: -2], news[:, -2:]
        if self.use_cate:
            title = torch.cat((title, cate), dim = 1)
        
        t_rep = self.word_embedding(title) # (n_batch, n_seq, emb_dim)
        t_rep = self.dropout(t_rep)
        r = []
        for cnn in self.cnns:
            l = F.relu(cnn(t_rep.transpose(-1, -2)).transpose(-1, -2))
            l = self.layernorm(l) # (n_batch, n_seq, n_filter)
            r.append(l)
        r = torch.stack(r, dim = 1) # (n_batch, n_layer, n_seq, n_filter)
        
        return r
        

In [18]:
class FIM(nn.Module):
    def __init__(self, word_emb, args):
        super().__init__()
        dilations = [1, 2, 3]
        self.news_encoder = HDCNewsEncoder(args, word_emb, 150, dilations)
        self.cnn1 = nn.Conv3d(len(dilations), 32, 3, padding = 1)
        self.cnn2 = nn.Conv3d(len(dilations), 16, 3, padding = 1)
        self.pool = nn.MaxPool3d((3, 3, 3), (3, 3, 3))
        self.fc = nn.Linear(102400, 1)
    
    def forward(self, hist, samp):
        n_batch, n_news, n_seq = hist.shape
        n_samp = samp.shape[1] # k + 1
        
        hist = hist.reshape(n_batch * n_news, n_seq)
        h = self.news_encoder(hist) # (n_batch*n_news, n_layer, n_seq, n_filter)
        samp = samp.reshape(n_batch * n_samp, n_seq)
        c = self.news_encoder(samp) # (n_batch*(k+1), n_layer, n_seq, n_filter)
        
        _, n_layer, n_seq, news_dim = h.shape
        h = h.reshape(n_batch, n_news, n_layer, n_seq, news_dim).transpose(1, 0) # (n_news, n_batch, n_layer, n_seq, news_dim)
        c = c.reshape(n_batch, n_samp, n_layer, n_seq, news_dim).transpose(1, 0) # (k + 1, n_batch, n_layer, n_seq, news_dim)
        
        m = torch.zeros((n_samp, n_news, n_batch, n_layer, n_seq, n_seq), device = 'cuda')
        for i in range(c.shape[0]):
            m[i] = torch.matmul(h, c[i].transpose(-1, -2)) / np.sqrt(news_dim + 1e-8) # (n_layer, n_seq, n_seq)
        m = m.permute(2, 0, 3, 1, 4, 5).reshape(n_batch * n_samp, n_layer, n_news, n_seq, n_seq) # (n_batch, n_samp, n_layer, n_news, n_seq, n_seq)
        
        q1 = self.cnn1(m) # (n_batch, n_samp, n_filter, n_news, n_seq, n_seq)
        q2 = self.cnn1(m) # (n_batch, n_samp, n_filter, n_news, n_seq, n_seq)
        q1 = self.pool(q1).squeeze() # （n_batch * n_samp, n_filter, 16, 10, 10）
        q2 = self.pool(q2).squeeze()
        s = torch.cat((q1, q2), dim = -1).reshape(n_batch, n_samp, -1) # (n_batch, n_samp, q1 + q2)
        y = self.fc(s).squeeze() # (n_batch, n_samp)
        return y
    
    def predict(self, h, c): # (n_news, n_layer, n_seq, n_filter)
        n_news, n_layer, n_seq, news_dim = h.shape
        n_samp = c.shape[1] # k + 1
        
        m = torch.zeros((n_samp, n_news, n_layer, n_seq, n_seq))
        for i in range(c.shape[0]):
            m[i] = torch.matmul(h, c[i]) / torch.sqrt(news_dim + 1e-8) # (n_layer, n_seq, n_seq)
        m = m.transpose(2, 1) # (n_samp, n_layer, n_news, n_seq, n_seq)
        
        q1 = self.cnn1(m) # (n_samp, n_filter, n_news, n_seq, n_seq)
        q2 = self.cnn1(m) # (n_samp, n_filter, n_news, n_seq, n_seq)
        q1 = self.pool(q1).squeeze()
        q2 = self.pool(q2).squeeze()
        s = torch.cat(q1, q2, dim = -1) # (n_samp, q1 + q2)
        
        y = self.fc(s).squeeze() # (n_samp)
        y = F.softmax(y, dim = -1)
        return y
    
def predict_fim(model, h, c):
    n_news, n_layer, n_seq, news_dim = h.shape
    n_samp = c.shape[0] # k + 1
    
    m = torch.zeros((n_samp, n_news, n_layer, n_seq, n_seq), device = 'cuda')
    # print(m.shape, h.shape, c.shape)
    for i in range(c.shape[0]):
        # torch.Size([50, 3, 32, 150]) torch.Size([3, 32, 150])
        m[i] = torch.matmul(h, c[i].transpose(-1, -2)) / np.sqrt(news_dim + 1e-8) # (n_layer, n_seq, n_seq)
    m = m.transpose(2, 1).reshape(n_samp, n_layer, n_news, n_seq, n_seq) # (n_batch, n_samp, n_layer, n_news, n_seq, n_seq)

    q1 = model.cnn1(m) # (n_samp, n_filter, n_news, n_seq, n_seq)
    q2 = model.cnn1(m) # (n_samp, n_filter, n_news, n_seq, n_seq)
    q1 = model.pool(q1).squeeze() # （n_samp, n_filter, 16, 10, 10）
    q2 = model.pool(q2).squeeze()
    s = torch.cat((q1, q2), dim = -1).reshape(n_samp, -1) # (n_batch, n_samp, q1 + q2)
    y = model.fc(s).squeeze() # (n_samp)
    return y

In [19]:
args = {'model': 'FIM', 
        'use_cate': True}
print(args)
model = FIM(word_emb, args).to('cuda')
train(model, train_dataset)
# train_and_eval_caum(model, train_dataset, dev_dataset, news_info, news_ents)

{'model': 'FIM', 'use_cate': True}


0it [00:00, ?it/s]

[epoch 1] train_loss: 1.4527


0it [00:00, ?it/s]

[epoch 2] train_loss: 1.3694


0it [00:00, ?it/s]

[epoch 3] train_loss: 1.3338


0it [00:00, ?it/s]

[epoch 4] train_loss: 1.3094


In [19]:
def evaluate_fim(model, dev_dataset, news_info, dev_users, dev_user_hist):
    news_rep = encode_all_news(news_info, model.news_encoder) # (65238, 400)
    
    model.eval()
    with torch.no_grad():
        auc_scores = []
        mrr_scores = []
        ndcg5_scores = []
        ndcg10_scores = []
        for i, batch in tqdm(enumerate(dev_dataset)):
            if len(batch[0]) == 0:
                break
            for j in range(len(batch[0])):
                samp = torch.tensor(news_rep[batch[0][j]], device = 'cuda') # (n_imp, emb_dim)
                hist = torch.tensor(news_rep[batch[1][j]], device = 'cuda') # (n_hist, emb_dim)
                positive = batch[2][j] # [n_imp]

                predict = predict_fim(model, hist, samp).detach().cpu().numpy() # (n_imp)

                auc_scores.append(roc_auc_score(positive, predict))
                mrr_scores.append(mrr_score(positive, predict))
                ndcg5_scores.append(ndcg_score(positive, predict, k = 5))
                ndcg10_scores.append(ndcg_score(positive, predict, k = 10))
    print('[Test] AUC: {:4f}, MRR: {:4f}, nDCG5:{:4f}, nDCG10: {:4f}'.format(
        np.mean(auc_scores), np.mean(mrr_scores), np.mean(ndcg5_scores), np.mean(ndcg10_scores)
    ))

In [20]:
def train_and_eval_fim(model, train_dataset, dev_dataset, news_info, dev_users, dev_user_hist, epochs = 4):
    optimizer = optim.Adam(model.parameters(), lr = 1e-4)
    entrophy = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        train_losses = []
        valid_losses = []
        model.train()
        for _, batch in tqdm(enumerate(train_dataset)):
            if batch[0].shape[0] == 0:
                break
            # torch.Size([16, 5, 30]) torch.Size([16, 50, 30]) torch.Size([16])
            sample = torch.tensor(batch[0], dtype = torch.long, device = device)
            history = torch.tensor(batch[1], dtype = torch.long, device = device)
            correct = torch.argmax(torch.tensor(batch[2], dtype = torch.long, device = device), dim = 1)
            optimizer.zero_grad()
            output = model(history, sample)
            loss = entrophy(output, correct)
            train_losses.append(loss.item())
            loss.backward()
            optimizer.step()
        print('[epoch {:d}] train_loss: {:.4f}'.format(epoch + 1, np.average(train_losses)))
        evaluate_fim(model, dev_dataset, news_info, dev_users, dev_user_hist)

In [23]:
args = {'model': 'FIM', 
        'use_cate': True}
print(args)
model = FIM(word_emb, args).to('cuda')
train_and_eval_fim(model, train_dataset, dev_dataset, news_info, dev_users, dev_user_hist)

{'model': 'FIM', 'use_cate': True}


0it [00:00, ?it/s]

[epoch 1] train_loss: 1.4472


0it [00:00, ?it/s]

[Test] AUC: 0.625115, MRR: 0.293408, nDCG5:0.316881, nDCG10: 0.382791


0it [00:00, ?it/s]

[epoch 2] train_loss: 1.3649


0it [00:00, ?it/s]

[Test] AUC: 0.642601, MRR: 0.300540, nDCG5:0.327832, nDCG10: 0.391846


0it [00:00, ?it/s]

[epoch 3] train_loss: 1.3312


0it [00:00, ?it/s]

[Test] AUC: 0.639133, MRR: 0.303845, nDCG5:0.330262, nDCG10: 0.393041


0it [00:00, ?it/s]

[epoch 4] train_loss: 1.3074


0it [00:00, ?it/s]

[Test] AUC: 0.642919, MRR: 0.304551, nDCG5:0.332696, nDCG10: 0.396139
