## CCA Model

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from tqdm import tqdm



class CCA_Model(nn.Module):
    def __init__(self, usernum, itemnum, args, embedding_size, cxt_size, use_res=False):
        super(CCA_Model, self).__init__()
        self.args = args
        hidden_units = args.hidden_units
        self.use_res = use_res

        self.item_emb = nn.Embedding(itemnum + 1, hidden_units, padding_idx=0)
        self.W_s = nn.Linear(embedding_size, embedding_size)
        self.W_m = nn.Linear(embedding_size, embedding_size)
        self.b_g = nn.Parameter(torch.zeros(embedding_size))
        self.feat_emb = nn.Linear(embedding_size + cxt_size, hidden_units * 8)
        self.embComp = nn.Linear(hidden_units + hidden_units * 8, hidden_units)
        self.pos_emb = nn.Embedding(args.maxlen, hidden_units)

        self.attention_blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=hidden_units, nhead=args.num_heads,
                                       dim_feedforward=hidden_units * 6, dropout=args.dropout_rate,
                                       activation='relu')
            for _ in range(args.num_blocks)
        ])

        # New modules according to paper
        self.review_context_proj = nn.Linear(embedding_size + cxt_size, hidden_units)
        self.item_proj = nn.Linear(hidden_units + hidden_units, hidden_units)

        self.cross_attention = nn.MultiheadAttention(embed_dim=hidden_units, num_heads=args.num_heads, dropout=args.dropout_rate)
        self.cross_ffn = nn.Sequential(
            nn.Linear(hidden_units, hidden_units * 6),
            nn.ReLU(),
            nn.Linear(hidden_units * 6, hidden_units)
        )
        self.cross_norm1 = nn.LayerNorm(hidden_units)
        self.cross_norm2 = nn.LayerNorm(hidden_units)

        self.predict_dense = nn.Linear(hidden_units, 1)

    def forward(self, user_ids, seq_items, pos_items, neg_items,
            seq_cxt, pos_cxt, neg_cxt, seq_feat_single, seq_feat_mean,
            pos_feat, neg_feat, is_training=True):
    
        # === 输入embedding处理 ===
        mask = (seq_items != 0).unsqueeze(-1).float()  # (batch_size, seq_len, 1)
        seq_in = self.item_emb(seq_items)              # (batch_size, seq_len, hidden_units)
        
        g = torch.sigmoid(self.W_s(seq_feat_single) + self.W_m(seq_feat_mean) + self.b_g)
        seq_feat = g * seq_feat_single + (1 - g) * seq_feat_mean
        seq_feat_emb = self.feat_emb(torch.cat([seq_feat, seq_cxt], dim=-1))  # (batch_size, seq_len, hidden_units*5)
        
        seq_concat = torch.cat([seq_in, seq_feat_emb], dim=-1)  # (batch_size, seq_len, hidden_units + hidden_units*5)
        seq = self.embComp(seq_concat)                         # (batch_size, seq_len, hidden_units)
        
        if is_training:
            seq = F.dropout(seq, p=self.args.dropout_rate, training=True)
        seq = seq * mask
        
        for block in self.attention_blocks:
            seq = block(seq.transpose(0, 1)).transpose(0, 1)  # Transformer是(seq_len, batch_size, hidden_units)
            seq = seq * mask
        
        seq_low = seq.clone()
        
        # === 正负样本embedding处理 ===
        # 正样本: pos_items
        pos_emb = self.item_emb(pos_items)  # (batch_size, 1, hidden_units)
        pos_feat_emb = self.review_context_proj(torch.cat([pos_feat, pos_cxt], dim=-1))  # (batch_size, 1, hidden_units)
        pos_concat = torch.cat([pos_emb, pos_feat_emb], dim=-1)
        pos_emb_final = self.item_proj(pos_concat)  # (batch_size, 1, hidden_units)
        
        # 负样本: neg_items (negative samples, possibly more than one per positive item)
        neg_emb = self.item_emb(neg_items)  # (batch_size, n_neg, hidden_units)
        neg_feat_emb = self.review_context_proj(torch.cat([neg_feat, neg_cxt], dim=-1))  # (batch_size, n_neg, hidden_units)
        neg_concat = torch.cat([neg_emb, neg_feat_emb], dim=-1)
        neg_emb_final = self.item_proj(neg_concat)  # (batch_size, n_neg, hidden_units)
        
        # === 低阶interaction（点积评分）===
        positive_rating = (seq_low[:, -1, :] * pos_emb.squeeze(1)).sum(-1)  # (batch_size,)
        negative_rating = (seq_low[:, -1, None, :] * neg_emb).sum(-1)       # (batch_size, n_neg)
        
        # === 高阶interaction（Cross Attention + FFN）===
        batch_size, n_neg, hidden_units = neg_emb_final.shape
        k = v = seq.transpose(0, 1)  # (seq_len, batch_size, hidden_units)
        
        # 正样本 Cross Attention
        q_pos = pos_emb_final.squeeze(1).unsqueeze(0)  # (1, batch_size, hidden_units)
        cross_attn_output_pos, _ = self.cross_attention(q_pos, k, v)
        cross_attn_output_pos = cross_attn_output_pos.transpose(0, 1).squeeze(1)  # (batch_size, hidden_units)
        
        cross_attn_output_pos = self.cross_norm1(pos_emb_final.squeeze(1) + cross_attn_output_pos)
        ffn_output_pos = self.cross_ffn(cross_attn_output_pos)
        cross_attn_output_pos = self.cross_norm2(cross_attn_output_pos + ffn_output_pos)
        
        pos_logits_high = self.predict_dense(cross_attn_output_pos).squeeze(-1)  # (batch_size,)
        
        # 负样本 Cross Attention，用循环处理
        neg_logits_high = []
        for i in range(n_neg):
            neg_emb_single = neg_emb_final[:, i, :]  # (batch_size, hidden_units)
            q_neg = neg_emb_single.unsqueeze(0)      # (1, batch_size, hidden_units)
            
            cross_attn_output_neg, _ = self.cross_attention(q_neg, k, v)
            cross_attn_output_neg = cross_attn_output_neg.transpose(0, 1).squeeze(1)
        
            cross_attn_output_neg = self.cross_norm1(neg_emb_single + cross_attn_output_neg)
            ffn_output_neg = self.cross_ffn(cross_attn_output_neg)
            cross_attn_output_neg = self.cross_norm2(cross_attn_output_neg + ffn_output_neg)
        
            neg_logit = self.predict_dense(cross_attn_output_neg).squeeze(-1)  # (batch_size,)
            neg_logits_high.append(neg_logit)
        
        neg_logits_high = torch.stack(neg_logits_high, dim=1)  # (batch_size, n_neg)
        
        # === Loss计算 ===
        istarget_pos = (pos_items != 0).float().squeeze(-1)  # (batch_size,)
        istarget_neg = (neg_items != 0).float()              # (batch_size, n_neg)
        
        loss_high = -(F.logsigmoid(pos_logits_high) * istarget_pos).sum() \
                    -(F.logsigmoid(-neg_logits_high) * istarget_neg).sum()
        loss_high /= (istarget_pos.sum() + istarget_neg.sum())
        
        loss_low = -(F.logsigmoid(positive_rating) * istarget_pos).sum() \
                   -(F.logsigmoid(-negative_rating) * istarget_neg).sum()
        loss_low /= (istarget_pos.sum() + istarget_neg.sum())
        
        total_loss = loss_high + loss_low
        #total_loss = loss_low
        
        # === AUC计算 ===
        auc = (((pos_logits_high.unsqueeze(1) - neg_logits_high) > 0).float() * istarget_neg).sum() / (istarget_pos.sum() * istarget_neg.size(1))
        
        return total_loss, auc, loss_high, loss_low



    def predict(self, seq, seqcxt, seq_feat_single, seq_feat_mean, item_idx, item_cxt, item_feat):
        seq_in = self.item_emb(seq)
        g = torch.sigmoid(self.W_s(seq_feat_single) + self.W_m(seq_feat_mean) + self.b_g)
        seq_feat = g * seq_feat_single + (1 - g) * seq_feat_mean
        seq_feat_emb = self.feat_emb(torch.cat([seq_feat, seqcxt], dim=-1))
        seq_concat = torch.cat([seq_in, seq_feat_emb], dim=-1)
        seq = self.embComp(seq_concat)

        #positions = torch.arange(seq.size(1), device=seq.device).unsqueeze(0)
        #seq += self.pos_emb(positions)

        for block in self.attention_blocks:
            seq = block(seq.transpose(0, 1)).transpose(0, 1)

        seq_low = seq[:, -1, :]

        test_emb = self.item_emb(item_idx)
        test_feat_emb = self.review_context_proj(torch.cat([item_feat, item_cxt], dim=-1))
        test_concat = torch.cat([test_emb, test_feat_emb], dim=-1)
        test_emb_final = self.item_proj(test_concat)

        logits_low = (seq_low.unsqueeze(1) * test_emb).sum(-1)

        batch_size = test_emb_final.shape[0]

        k = v = seq.expand(batch_size, -1, -1)  # (item_num, seq_len, hidden_units)
        k = k.transpose(0, 1)
        v = v.transpose(0, 1)
        
        q = test_emb_final.unsqueeze(0)  # (1, item_num, hidden_units)
        
        cross_attn_output, _ = self.cross_attention(q, k, v)
        cross_attn_output = cross_attn_output.transpose(0, 1).squeeze(1)
        
        cross_attn_output = self.cross_norm1(test_emb_final + cross_attn_output)
        ffn_output = self.cross_ffn(cross_attn_output)
        cross_attn_output = self.cross_norm2(cross_attn_output + ffn_output)
        
        logits_high = self.predict_dense(cross_attn_output).squeeze(-1)

        #return logits_low
        return logits_low + logits_high



# ----------------- WarpSamplerDataset + Trainer -----------------


import torch
import numpy as np
import random
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from torch.optim.lr_scheduler import StepLR
class WarpSamplerDataset(Dataset):
    def __init__(self, user_train, usernum, itemnum, cxtdict, cxtsize, ItemFeatures, itemFeat, user_negative, embedding_size, maxlen):
        super(WarpSamplerDataset, self).__init__()
        self.user_train = user_train
        self.usernum = usernum
        self.itemnum = itemnum
        self.cxtdict = cxtdict
        self.cxtsize = cxtsize
        self.ItemFeatures = ItemFeatures
        self.itemFeat = itemFeat
        self.embedding_size = embedding_size
        self.maxlen = maxlen
        self.users = list(user_train.keys())
        self.user_negative = user_negative

    def __len__(self):
        return len(self.users)

    def _sample(self, user):
        seq = np.zeros([self.maxlen], dtype=np.int64)
        seqcxt = np.zeros([self.maxlen, self.cxtsize], dtype=np.float32)
        seqFeat_single = np.zeros([self.maxlen, self.embedding_size], dtype=np.float32)
        seqFeat_mean = np.zeros([self.maxlen, self.embedding_size], dtype=np.float32)
    
        user_seq = self.user_train[user]
        if len(user_seq) < 2:
            return self._sample(random.choice(self.users))
    
        # Construct positive item list: Penultimate L-1 items and the final item
        pos_items = user_seq[-self.maxlen:]  # L-1 penultimate items + the final item (the most recent interaction)
        hist_seq = pos_items[:-1]  # The history of items
        
        idx = self.maxlen - 1
        for i in reversed(hist_seq):
            seq[idx] = i
            seqcxt[idx] = self.cxtdict.get((user, i), np.zeros(self.cxtsize, dtype=np.float32))
            seqFeat_single[idx] = self.ItemFeatures.get((user, i), np.zeros(self.embedding_size, dtype=np.float32))
            seqFeat_mean[idx] = np.array(self.itemFeat.get(i, np.zeros(self.embedding_size, dtype=np.float32)))
            idx -= 1
            if idx == -1:
                break
        
        # Generate negative item list by randomly selecting L items the user hasn't interacted with
        seen = set(user_seq)
        all_items = set(range(1, self.itemnum + 1))
        candidate_items = list(all_items - seen)
        neg_samples = np.random.choice(candidate_items, size=self.maxlen, replace=False)
        neg = neg_samples.astype(np.int64)
        
        # Context and features for positive and negative items
        pos = pos_items[-1]  # The most recent positive item
        poscxt = self.cxtdict.get((user, pos), np.zeros(self.cxtsize, dtype=np.float32))
        posFeat = np.array(self.itemFeat.get(pos, np.zeros(self.embedding_size, dtype=np.float32)))
        
        # Negative samples' context directly copied from positive item context
        negcxt = np.tile(np.expand_dims(poscxt, 0), (self.maxlen, 1)).astype(np.float32)
        
        # Negative features from itemFeat
        negFeat = np.array([self.itemFeat.get(n, np.zeros(self.embedding_size, dtype=np.float32)) for n in neg])
    
        user_tensor = np.full(self.maxlen, user, dtype=np.int64)
    
        return (user_tensor,
                seq,
                np.array([pos], dtype=np.int64),
                neg,
                seqcxt.astype(np.float32),
                np.expand_dims(poscxt, 0).astype(np.float32),
                negcxt,
                seqFeat_single.astype(np.float32),
                seqFeat_mean.astype(np.float32),
                np.expand_dims(posFeat, 0).astype(np.float32),
                negFeat.astype(np.float32))

    def __getitem__(self, idx):
        user = self.users[idx]
        return self._sample(user)


class Trainer:
    def __init__(self, model, train_loader, device, itemnum,lr=0.001, step_size=3, gamma=0.5):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.device = device
        self.itemnum = itemnum
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.scheduler = StepLR(self.optimizer, step_size=step_size, gamma=gamma)#动态调整学习率

    def train(self, num_epochs):
        self.model.train()
        for epoch in range(1, num_epochs + 1):
            total_loss, total_auc, total_loss_high, total_loss_low  = 0.0, 0.0, 0.0, 0.0
            for batch in tqdm(self.train_loader, desc=f"Epoch {epoch}"):
                batch = [torch.tensor(x).to(self.device) for x in batch]
                user, seq, pos, neg, seqcxt, poscxt, negcxt, seqFeat_single, seqFeat_mean, posFeat, negFeat = batch

                loss, auc, loss_high, loss_low = self.model(user, seq, pos, neg, seqcxt, poscxt, negcxt, seqFeat_single, seqFeat_mean, posFeat, negFeat)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                total_loss += loss.item()
                total_loss_high += loss_high.item()
                total_loss_low += loss_low.item()
                total_auc += auc.item()

            avg_loss = total_loss / len(self.train_loader)
            avg_loss_low = total_loss_low / len(self.train_loader)
            avg_loss_high = total_loss_high / len(self.train_loader)
            avg_auc = total_auc / len(self.train_loader)
            print(f"Epoch {epoch} - Loss: {avg_loss:.4f}, AUC: {avg_auc:.4f}, Loss_high: {avg_loss_high:.4f}, Loss_low: {avg_loss_low:.4f}")
            self.scheduler.step()

    def evaluate(self, model, dataset, users, itemFeat, ItemFeatures, cxtdict, maxlen, device, ground_truth, top_k=10, user_negative=None):
        model.eval()
        NDCG, HT, valid_user = 0.0, 0.0, 0.0
        for u in users:
            if len(dataset.user_train[u]) < 1 or u not in ground_truth:
                continue
            seq = np.zeros([maxlen], dtype=np.int64)
            seqcxt = np.zeros([maxlen, dataset.cxtsize], dtype=np.float32)
            seqfeat_single = np.zeros([maxlen, dataset.embedding_size], dtype=np.float32)
            seqfeat_mean = np.zeros([maxlen, dataset.embedding_size], dtype=np.float32)

            idx = maxlen - 1
            for i in reversed(dataset.user_train[u]):
                seq[idx] = i
                seqcxt[idx] = cxtdict.get((u, i), np.zeros(dataset.cxtsize))
                seqfeat_mean[idx] = np.array(itemFeat.get(i, np.zeros(dataset.embedding_size)))
                seqfeat_single[idx] = np.array(ItemFeatures.get((u, i), np.zeros(dataset.embedding_size)))
                idx -= 1
                if idx == -1:
                    break

            positive_item = ground_truth[u]

            #生成负样本的两种方式
            negatives = user_negative.get(u, [])
            

            
            item_idx = [positive_item] + negatives
            positive_context = cxtdict.get((u, positive_item), np.zeros(dataset.cxtsize))

            item_cxt = []
            for i in item_idx:
                if i == positive_item:
                    context = positive_context
                else:
                    # 负样本直接复制正样本的context
                    context = positive_context
                item_cxt.append(context)
            item_cxt = np.array(item_cxt)
            #item_cxt = np.array([cxtdict.get((u, i), np.zeros(dataset.cxtsize)) for i in item_idx])
            item_feat = np.array([itemFeat.get(i, np.zeros(dataset.embedding_size)) for i in item_idx])

            seq = torch.tensor(seq).unsqueeze(0).to(device)
            seqcxt = torch.tensor(seqcxt).unsqueeze(0).to(device)
            seqfeat_single = torch.tensor(seqfeat_single).unsqueeze(0).to(device)
            seqfeat_mean = torch.tensor(seqfeat_mean).unsqueeze(0).to(device)

            item_idx = torch.tensor(item_idx).to(device)
            item_cxt = torch.tensor(item_cxt).float().to(device)
            item_feat = torch.tensor(item_feat).float().to(device)

            with torch.no_grad():
                scores = model.predict(seq, seqcxt, seqfeat_single, seqfeat_mean, item_idx, item_cxt, item_feat)
                scores = scores.squeeze(0).detach().cpu().numpy()

                top_k_items = np.argsort(scores)[::-1][:top_k]
                if 0 in top_k_items:
                    rank = top_k_items.tolist().index(0)
                    NDCG += 1 / np.log2(rank + 2)
                    HT += 1
                valid_user += 1

        return NDCG / valid_user, HT / valid_user



## Load Data

In [2]:
import json
import pickle


cat = 'Video_Games'
root = 'datasets/' + cat + '/'
with open(root+'user_train.json', 'r') as f:
    user_train = json.load(f)       #训练集
with open(root+'user_test.json', 'r') as f:
    user_test = json.load(f)        #测试集
with open(root+'user_valid.json', 'r') as f:
    user_valid = json.load(f)       #验证集
with open(root+'itemFeat.json', 'r') as f:
    itemFeat = json.load(f)         #商品平均embedding
with open(root+'user_negative.json', 'r') as f:
    user_negative = json.load(f)    #负样本

In [3]:
with open(root+'cxtdict.pkl', 'rb') as f:
    cxtdict = pickle.load(f)        #上下文 时间embedding


In [4]:
import sys
import numpy.core as core
import numpy.core.multiarray

sys.modules['numpy._core'] = core
sys.modules['numpy._core.multiarray'] = numpy.core.multiarray

from joblib import load
ItemFeatures = load(root+'ItemFeatures.pkl')



In [5]:
usernum = len(user_train)
itemnum = len(itemFeat)

In [6]:
itemnum

16427

In [7]:
usernum

37460

## Model Training

In [8]:
class Args:
    maxlen = 10
    hidden_units = 100
    num_blocks = 2
    num_heads = 5
    dropout_rate = 0.5
    l2_emb = 0.00001
    lr = 0.001

args = Args()
embedding_size = 768
cxt_size = 6
batch_size = 128

dataset = WarpSamplerDataset(user_train, usernum, itemnum, cxtdict, cxt_size, ItemFeatures, itemFeat,user_negative, embedding_size, args.maxlen)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

model = CCA_Model(usernum, itemnum, args, embedding_size, cxt_size, use_res=True)
trainer = Trainer(model, train_loader, device='cuda' if torch.cuda.is_available() else 'cpu',itemnum = itemnum, lr=args.lr)

round =10

for i in range(1,round+1):
    print("-------------Training round:"+str(i)+"--------------------")
    trainer.train(num_epochs=1)
    ndcg, hit = trainer.evaluate(model, dataset, list(user_train.keys()), itemFeat, ItemFeatures, cxtdict, args.maxlen, device='cuda' if torch.cuda.is_available() else 'cpu',ground_truth=user_test,user_negative = user_negative)
    print(f"Validation NDCG@10: {ndcg:.4f}, Hit@10: {hit:.4f}")

ndcg, hit = trainer.evaluate(model, dataset, list(user_train.keys()), itemFeat, ItemFeatures, cxtdict, args.maxlen, device='cuda' if torch.cuda.is_available() else 'cpu',ground_truth=user_test,user_negative = user_negative)
print(f"Test NDCG@10: {ndcg:.4f}, Hit@10: {hit:.4f}")

-------------Training round:1--------------------


  batch = [torch.tensor(x).to(self.device) for x in batch]
  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
Epoch 1: 100%|██████████| 293/293 [02:04<00:00,  2.36it/s]


Epoch 1 - Loss: 3.0501, AUC: 0.5866, Loss_high: 0.3014, Loss_low: 2.7487
Validation NDCG@10: 0.0650, Hit@10: 0.1440
-------------Training round:2--------------------


Epoch 1: 100%|██████████| 293/293 [02:04<00:00,  2.36it/s]


Epoch 1 - Loss: 1.4701, AUC: 0.6671, Loss_high: 0.2802, Loss_low: 1.1899
Validation NDCG@10: 0.1364, Hit@10: 0.2615
-------------Training round:3--------------------


Epoch 1: 100%|██████████| 293/293 [02:04<00:00,  2.36it/s]


Epoch 1 - Loss: 0.9391, AUC: 0.7059, Loss_high: 0.2707, Loss_low: 0.6684
Validation NDCG@10: 0.2023, Hit@10: 0.3288
-------------Training round:4--------------------


Epoch 1: 100%|██████████| 293/293 [02:01<00:00,  2.42it/s]


Epoch 1 - Loss: 0.8087, AUC: 0.7371, Loss_high: 0.2629, Loss_low: 0.5458
Validation NDCG@10: 0.2119, Hit@10: 0.3541
-------------Training round:5--------------------


Epoch 1: 100%|██████████| 293/293 [02:03<00:00,  2.38it/s]


Epoch 1 - Loss: 0.7402, AUC: 0.7519, Loss_high: 0.2596, Loss_low: 0.4806
Validation NDCG@10: 0.2169, Hit@10: 0.3622
-------------Training round:6--------------------


Epoch 1: 100%|██████████| 293/293 [02:00<00:00,  2.42it/s]


Epoch 1 - Loss: 0.6744, AUC: 0.7625, Loss_high: 0.2571, Loss_low: 0.4173
Validation NDCG@10: 0.2322, Hit@10: 0.3819
-------------Training round:7--------------------


Epoch 1: 100%|██████████| 293/293 [02:02<00:00,  2.39it/s]


Epoch 1 - Loss: 0.6236, AUC: 0.7806, Loss_high: 0.2524, Loss_low: 0.3712
Validation NDCG@10: 0.2384, Hit@10: 0.3942
-------------Training round:8--------------------


Epoch 1: 100%|██████████| 293/293 [02:03<00:00,  2.37it/s]


Epoch 1 - Loss: 0.5942, AUC: 0.7861, Loss_high: 0.2506, Loss_low: 0.3435
Validation NDCG@10: 0.2411, Hit@10: 0.4006
-------------Training round:9--------------------


Epoch 1: 100%|██████████| 293/293 [01:59<00:00,  2.45it/s]


Epoch 1 - Loss: 0.5704, AUC: 0.7908, Loss_high: 0.2497, Loss_low: 0.3207
Validation NDCG@10: 0.2464, Hit@10: 0.4048
-------------Training round:10--------------------


Epoch 1: 100%|██████████| 293/293 [02:01<00:00,  2.41it/s]


Epoch 1 - Loss: 0.5517, AUC: 0.7995, Loss_high: 0.2470, Loss_low: 0.3047
Validation NDCG@10: 0.2500, Hit@10: 0.4125
Test NDCG@10: 0.2500, Hit@10: 0.4125


In [9]:
round =40

for i in range(1,round+1):
    print("-------------Training round:"+str(i)+"--------------------")
    trainer.train(num_epochs=1)
    ndcg, hit = trainer.evaluate(model, dataset, list(user_train.keys()), itemFeat, ItemFeatures, cxtdict, args.maxlen, device='cuda' if torch.cuda.is_available() else 'cpu',ground_truth=user_test,user_negative = user_negative)
    print(f"Validation NDCG@10: {ndcg:.4f}, Hit@10: {hit:.4f}")

ndcg, hit = trainer.evaluate(model, dataset, list(user_train.keys()), itemFeat, ItemFeatures, cxtdict, args.maxlen, device='cuda' if torch.cuda.is_available() else 'cpu',ground_truth=user_valid,user_negative = user_negative)
print(f"Test NDCG@10: {ndcg:.4f}, Hit@10: {hit:.4f}")

-------------Training round:1--------------------


  batch = [torch.tensor(x).to(self.device) for x in batch]
Epoch 1: 100%|██████████| 293/293 [01:58<00:00,  2.47it/s]


Epoch 1 - Loss: 0.5429, AUC: 0.8021, Loss_high: 0.2463, Loss_low: 0.2967
Validation NDCG@10: 0.2525, Hit@10: 0.4164
-------------Training round:2--------------------


Epoch 1: 100%|██████████| 293/293 [02:00<00:00,  2.44it/s]


Epoch 1 - Loss: 0.5346, AUC: 0.8050, Loss_high: 0.2457, Loss_low: 0.2889
Validation NDCG@10: 0.2539, Hit@10: 0.4196
-------------Training round:3--------------------


Epoch 1: 100%|██████████| 293/293 [01:58<00:00,  2.47it/s]


Epoch 1 - Loss: 0.5268, AUC: 0.8101, Loss_high: 0.2439, Loss_low: 0.2829
Validation NDCG@10: 0.2561, Hit@10: 0.4206
-------------Training round:4--------------------


Epoch 1: 100%|██████████| 293/293 [01:59<00:00,  2.44it/s]


Epoch 1 - Loss: 0.5232, AUC: 0.8112, Loss_high: 0.2435, Loss_low: 0.2797
Validation NDCG@10: 0.2574, Hit@10: 0.4234
-------------Training round:5--------------------


Epoch 1: 100%|██████████| 293/293 [01:59<00:00,  2.46it/s]


Epoch 1 - Loss: 0.5201, AUC: 0.8113, Loss_high: 0.2433, Loss_low: 0.2768
Validation NDCG@10: 0.2584, Hit@10: 0.4238
-------------Training round:6--------------------


Epoch 1: 100%|██████████| 293/293 [02:00<00:00,  2.44it/s]


Epoch 1 - Loss: 0.5150, AUC: 0.8146, Loss_high: 0.2421, Loss_low: 0.2729
Validation NDCG@10: 0.2564, Hit@10: 0.4241
-------------Training round:7--------------------


Epoch 1: 100%|██████████| 293/293 [02:00<00:00,  2.42it/s]


Epoch 1 - Loss: 0.5134, AUC: 0.8161, Loss_high: 0.2417, Loss_low: 0.2716
Validation NDCG@10: 0.2581, Hit@10: 0.4263
-------------Training round:8--------------------


Epoch 1: 100%|██████████| 293/293 [02:03<00:00,  2.38it/s]


Epoch 1 - Loss: 0.5129, AUC: 0.8159, Loss_high: 0.2417, Loss_low: 0.2711
Validation NDCG@10: 0.2576, Hit@10: 0.4259
-------------Training round:9--------------------


Epoch 1: 100%|██████████| 293/293 [02:02<00:00,  2.39it/s]


Epoch 1 - Loss: 0.5112, AUC: 0.8173, Loss_high: 0.2416, Loss_low: 0.2696
Validation NDCG@10: 0.2575, Hit@10: 0.4249
-------------Training round:10--------------------


Epoch 1: 100%|██████████| 293/293 [01:59<00:00,  2.45it/s]


Epoch 1 - Loss: 0.5101, AUC: 0.8179, Loss_high: 0.2413, Loss_low: 0.2687
Validation NDCG@10: 0.2582, Hit@10: 0.4277
-------------Training round:11--------------------


Epoch 1: 100%|██████████| 293/293 [02:00<00:00,  2.43it/s]


Epoch 1 - Loss: 0.5090, AUC: 0.8176, Loss_high: 0.2412, Loss_low: 0.2679
Validation NDCG@10: 0.2583, Hit@10: 0.4273
-------------Training round:12--------------------


Epoch 1: 100%|██████████| 293/293 [02:00<00:00,  2.43it/s]


Epoch 1 - Loss: 0.5078, AUC: 0.8189, Loss_high: 0.2408, Loss_low: 0.2670
Validation NDCG@10: 0.2579, Hit@10: 0.4270
-------------Training round:13--------------------


Epoch 1: 100%|██████████| 293/293 [02:01<00:00,  2.41it/s]


Epoch 1 - Loss: 0.5082, AUC: 0.8181, Loss_high: 0.2411, Loss_low: 0.2671
Validation NDCG@10: 0.2587, Hit@10: 0.4280
-------------Training round:14--------------------


Epoch 1: 100%|██████████| 293/293 [02:02<00:00,  2.38it/s]


Epoch 1 - Loss: 0.5073, AUC: 0.8188, Loss_high: 0.2408, Loss_low: 0.2665
Validation NDCG@10: 0.2589, Hit@10: 0.4273
-------------Training round:15--------------------


Epoch 1: 100%|██████████| 293/293 [02:01<00:00,  2.41it/s]


Epoch 1 - Loss: 0.5072, AUC: 0.8185, Loss_high: 0.2410, Loss_low: 0.2662
Validation NDCG@10: 0.2595, Hit@10: 0.4280
-------------Training round:16--------------------


Epoch 1: 100%|██████████| 293/293 [02:00<00:00,  2.42it/s]


Epoch 1 - Loss: 0.5065, AUC: 0.8190, Loss_high: 0.2409, Loss_low: 0.2656
Validation NDCG@10: 0.2585, Hit@10: 0.4270
-------------Training round:17--------------------


Epoch 1: 100%|██████████| 293/293 [02:00<00:00,  2.44it/s]


Epoch 1 - Loss: 0.5056, AUC: 0.8194, Loss_high: 0.2404, Loss_low: 0.2652
Validation NDCG@10: 0.2584, Hit@10: 0.4284
-------------Training round:18--------------------


Epoch 1: 100%|██████████| 293/293 [01:59<00:00,  2.45it/s]


Epoch 1 - Loss: 0.5064, AUC: 0.8193, Loss_high: 0.2407, Loss_low: 0.2657
Validation NDCG@10: 0.2584, Hit@10: 0.4277
-------------Training round:19--------------------


Epoch 1: 100%|██████████| 293/293 [02:00<00:00,  2.44it/s]


Epoch 1 - Loss: 0.5058, AUC: 0.8200, Loss_high: 0.2403, Loss_low: 0.2655
Validation NDCG@10: 0.2583, Hit@10: 0.4277
-------------Training round:20--------------------


Epoch 1: 100%|██████████| 293/293 [02:00<00:00,  2.42it/s]


Epoch 1 - Loss: 0.5068, AUC: 0.8197, Loss_high: 0.2408, Loss_low: 0.2660
Validation NDCG@10: 0.2588, Hit@10: 0.4284
-------------Training round:21--------------------


Epoch 1: 100%|██████████| 293/293 [01:59<00:00,  2.46it/s]


Epoch 1 - Loss: 0.5062, AUC: 0.8195, Loss_high: 0.2407, Loss_low: 0.2655
Validation NDCG@10: 0.2587, Hit@10: 0.4287
-------------Training round:22--------------------


Epoch 1: 100%|██████████| 293/293 [01:59<00:00,  2.44it/s]


Epoch 1 - Loss: 0.5061, AUC: 0.8193, Loss_high: 0.2407, Loss_low: 0.2654
Validation NDCG@10: 0.2588, Hit@10: 0.4287
-------------Training round:23--------------------


Epoch 1: 100%|██████████| 293/293 [01:59<00:00,  2.45it/s]


Epoch 1 - Loss: 0.5059, AUC: 0.8196, Loss_high: 0.2406, Loss_low: 0.2653
Validation NDCG@10: 0.2589, Hit@10: 0.4287
-------------Training round:24--------------------


Epoch 1: 100%|██████████| 293/293 [01:58<00:00,  2.47it/s]


Epoch 1 - Loss: 0.5060, AUC: 0.8202, Loss_high: 0.2406, Loss_low: 0.2655
Validation NDCG@10: 0.2589, Hit@10: 0.4291
-------------Training round:25--------------------


Epoch 1: 100%|██████████| 293/293 [01:58<00:00,  2.47it/s]


Epoch 1 - Loss: 0.5064, AUC: 0.8191, Loss_high: 0.2408, Loss_low: 0.2656
Validation NDCG@10: 0.2589, Hit@10: 0.4291
-------------Training round:26--------------------


Epoch 1: 100%|██████████| 293/293 [01:58<00:00,  2.47it/s]


Epoch 1 - Loss: 0.5056, AUC: 0.8196, Loss_high: 0.2405, Loss_low: 0.2651
Validation NDCG@10: 0.2587, Hit@10: 0.4287
-------------Training round:27--------------------


Epoch 1: 100%|██████████| 293/293 [01:59<00:00,  2.45it/s]


Epoch 1 - Loss: 0.5068, AUC: 0.8188, Loss_high: 0.2409, Loss_low: 0.2658
Validation NDCG@10: 0.2587, Hit@10: 0.4287
-------------Training round:28--------------------


Epoch 1: 100%|██████████| 293/293 [02:00<00:00,  2.43it/s]


Epoch 1 - Loss: 0.5059, AUC: 0.8193, Loss_high: 0.2406, Loss_low: 0.2653
Validation NDCG@10: 0.2587, Hit@10: 0.4287
-------------Training round:29--------------------


Epoch 1: 100%|██████████| 293/293 [01:59<00:00,  2.45it/s]


Epoch 1 - Loss: 0.5064, AUC: 0.8195, Loss_high: 0.2409, Loss_low: 0.2656
Validation NDCG@10: 0.2587, Hit@10: 0.4287
-------------Training round:30--------------------


Epoch 1: 100%|██████████| 293/293 [01:59<00:00,  2.45it/s]


Epoch 1 - Loss: 0.5065, AUC: 0.8195, Loss_high: 0.2409, Loss_low: 0.2656
Validation NDCG@10: 0.2587, Hit@10: 0.4287
-------------Training round:31--------------------


Epoch 1: 100%|██████████| 293/293 [02:00<00:00,  2.43it/s]


Epoch 1 - Loss: 0.5060, AUC: 0.8194, Loss_high: 0.2406, Loss_low: 0.2654
Validation NDCG@10: 0.2587, Hit@10: 0.4287
-------------Training round:32--------------------


Epoch 1: 100%|██████████| 293/293 [02:00<00:00,  2.43it/s]


Epoch 1 - Loss: 0.5057, AUC: 0.8197, Loss_high: 0.2405, Loss_low: 0.2652
Validation NDCG@10: 0.2588, Hit@10: 0.4291
-------------Training round:33--------------------


Epoch 1: 100%|██████████| 293/293 [01:59<00:00,  2.46it/s]


Epoch 1 - Loss: 0.5062, AUC: 0.8194, Loss_high: 0.2407, Loss_low: 0.2655
Validation NDCG@10: 0.2587, Hit@10: 0.4287
-------------Training round:34--------------------


Epoch 1: 100%|██████████| 293/293 [02:00<00:00,  2.43it/s]


Epoch 1 - Loss: 0.5058, AUC: 0.8204, Loss_high: 0.2406, Loss_low: 0.2653
Validation NDCG@10: 0.2587, Hit@10: 0.4287
-------------Training round:35--------------------


Epoch 1: 100%|██████████| 293/293 [02:00<00:00,  2.44it/s]


Epoch 1 - Loss: 0.5058, AUC: 0.8197, Loss_high: 0.2405, Loss_low: 0.2653
Validation NDCG@10: 0.2588, Hit@10: 0.4291
-------------Training round:36--------------------


Epoch 1: 100%|██████████| 293/293 [02:01<00:00,  2.42it/s]


Epoch 1 - Loss: 0.5056, AUC: 0.8191, Loss_high: 0.2404, Loss_low: 0.2652
Validation NDCG@10: 0.2588, Hit@10: 0.4291
-------------Training round:37--------------------


Epoch 1: 100%|██████████| 293/293 [01:59<00:00,  2.46it/s]


Epoch 1 - Loss: 0.5054, AUC: 0.8194, Loss_high: 0.2406, Loss_low: 0.2648
Validation NDCG@10: 0.2588, Hit@10: 0.4291
-------------Training round:38--------------------


Epoch 1: 100%|██████████| 293/293 [02:00<00:00,  2.44it/s]


Epoch 1 - Loss: 0.5054, AUC: 0.8198, Loss_high: 0.2404, Loss_low: 0.2650
Validation NDCG@10: 0.2588, Hit@10: 0.4291
-------------Training round:39--------------------


Epoch 1: 100%|██████████| 293/293 [02:02<00:00,  2.40it/s]


Epoch 1 - Loss: 0.5055, AUC: 0.8201, Loss_high: 0.2404, Loss_low: 0.2651
Validation NDCG@10: 0.2588, Hit@10: 0.4291
-------------Training round:40--------------------


Epoch 1: 100%|██████████| 293/293 [02:01<00:00,  2.41it/s]


Epoch 1 - Loss: 0.5055, AUC: 0.8194, Loss_high: 0.2404, Loss_low: 0.2651
Validation NDCG@10: 0.2588, Hit@10: 0.4291
Test NDCG@10: 0.2579, Hit@10: 0.4200


In [None]:
round =100

for i in range(1,round+1):
    print("-------------Training round:"+str(i)+"--------------------")
    trainer.train(num_epochs=1)
    ndcg, hit = trainer.evaluate(model, dataset, list(user_train.keys()), itemFeat, ItemFeatures, cxtdict, args.maxlen, device='cuda' if torch.cuda.is_available() else 'cpu',ground_truth=user_test,user_negative = user_negative)
    print(f"Validation NDCG@10: {ndcg:.4f}, Hit@10: {hit:.4f}")

ndcg, hit = trainer.evaluate(model, dataset, list(user_train.keys()), itemFeat, ItemFeatures, cxtdict, args.maxlen, device='cuda' if torch.cuda.is_available() else 'cpu',ground_truth=user_valid,user_negative = user_negative)
print(f"Test NDCG@10: {ndcg:.4f}, Hit@10: {hit:.4f}")

In [None]:
trainer = Trainer(model, train_loader, device='cuda' if torch.cuda.is_available() else 'cpu', lr=0.0001)
round =100

for i in range(1,round+1):
    print("-------------Training round:"+str(i)+"--------------------")
    trainer.train(num_epochs=1)
    ndcg, hit = trainer.evaluate(model, dataset, list(user_train.keys()), itemFeat, ItemFeatures, cxtdict, args.maxlen, device='cuda' if torch.cuda.is_available() else 'cpu',ground_truth=user_test,user_negative = user_negative)
    print(f"Validation NDCG@10: {ndcg:.4f}, Hit@10: {hit:.4f}")

ndcg, hit = trainer.evaluate(model, dataset, list(user_train.keys()), itemFeat, ItemFeatures, cxtdict, args.maxlen, device='cuda' if torch.cuda.is_available() else 'cpu',ground_truth=user_valid,user_negative = user_negative)
print(f"Test NDCG@10: {ndcg:.4f}, Hit@10: {hit:.4f}")

In [None]:
Validation NDCG@10: 0.0808, Hit@10: 0.1562