In [1]:
import torch
from torch import nn,Tensor,LongTensor
import torch_geometric as pyg
from torch_geometric.nn.conv.gcn_conv import gcn_norm
import numpy as np
import torch_geometric.utils as pyg_utils
from torch_geometric.utils import degree
from torch.utils.data import Dataset,DataLoader
from torch_geometric.nn.conv import MessagePassing
from torch_sparse import SparseTensor,matmul
from scipy.sparse import csr_matrix
import torch.nn as nn

ImportError: cannot import name 'Tensor' from 'torch' (unknown location)

# Parameter settings(参数设置)

In [None]:
def cprint(words: str):
    """Print ``words`` highlighted with ANSI codes (black text, yellow background)."""
    highlighted = f"\033[0;30;43m{words}\033[0m"
    print(highlighted)

def bprint(words:str):
    """Print ``words`` highlighted with ANSI codes (black text, magenta background)."""
    highlighted = f"\033[0;30;45m{words}\033[0m"
    print(highlighted)

# Compute device: CUDA when PyTorch can see a GPU, CPU otherwise.
GPU = torch.cuda.is_available()
device = torch.device('cuda') if GPU else torch.device("cpu")

# Hyper-parameters shared by the cells below.
embedding_size = 64     # width of the user / item embedding tables
batch_size = 4096       # number of training edges per BPR mini-batch
test_batch_size = 1024  # number of users scored per evaluation batch
lr = 1e-3               # learning rate handed to the optimizer
reg_weight = 1e-4       # weight of the L2 penalty on the sampled embeddings
layer = 3               # intended number of graph-propagation layers

# Data Processing(加载数据)

In [None]:
class Loader(Dataset):
    """
    Load a user-item interaction dataset from ``<path><dataset_name>/train.txt``
    and ``<path><dataset_name>/test.txt``.

    Every non-blank line of either file has the form
    ``"<user_id> <item_id> <item_id> ..."`` (integers separated by single
    spaces). The notebook was written against directories such as
    'gowalla', 'amazon-book', 'yelp2018' and 'lastfm'.

    Args:
        path: directory prefix that contains the dataset folder.
        dataset_name: name of the dataset folder below ``path``.

    Attributes set by ``__init__``:
        train_edge_index / test_edge_index: ``(2, E)`` long tensors of
            ``[user_id, item_id]`` columns.
        edge_index: train and test edges concatenated along dim 1.
        num_users / num_items: ``largest id + 1`` over ``edge_index``.
        UserItemNet: scipy CSR user-item matrix built from ``edge_index``
            (NOTE: this contains the test edges as well as the train edges).
        bipartite_graph / adj_mat: ``SparseTensor`` views of the train edges.
        train_loader / test_loader: ``DataLoader`` objects over train-edge
            positions and user ids respectively.
        test_ground_truth_list: per-user list of test item ids.
        mask: dense ``(num_users, num_items)`` float tensor, ``-inf`` at
            every training interaction and ``0`` elsewhere.
    """
    def __init__(self,path='./data/',dataset_name='gowalla'):
        dir_path = path + dataset_name
        cprint(f'loading from {dir_path}')
        train_file = dir_path + '/train.txt'
        test_file = dir_path + '/test.txt'

        # Training rows keep only the tokens before index int(0.8 * token_count)
        # (the user id counts as a token); test rows are kept whole and a row
        # with an unparsable item token is skipped.
        train_users,train_items = self._read_interactions(train_file,keep_ratio=0.8)
        test_users,test_items = self._read_interactions(test_file,skip_malformed=True)

        train_edge_index = self._to_edge_index(train_users,train_items)
        test_edge_index = self._to_edge_index(test_users,test_items)
        edge_index = torch.cat((train_edge_index,test_edge_index),1)
        if edge_index.size(1) == 0:
            raise ValueError(f'no interactions found under {dir_path}')
        self.edge_index = edge_index

        # Ids are used directly as row / column indices everywhere below, so
        # the sizes must be "largest id + 1". Counting distinct ids is only
        # equivalent when the ids have no gaps.
        num_users = int(edge_index[0].max()) + 1
        num_items = int(edge_index[1].max()) + 1
        self.num_users = num_users
        self.num_items = num_items
        self.train_edge_index = train_edge_index
        self.test_edge_index = test_edge_index
        self.UserItemNet = csr_matrix((np.ones(len(self.edge_index[0])), (self.edge_index[0].numpy(), self.edge_index[1].numpy())),
                                      shape=(self.num_users, self.num_items))
        self.bipartite_graph = self.getSparseBipartite()
        self.adj_mat = self.getSparseGraph()

        self.train_loader = DataLoader(
            range(self.train_edge_index.size(1)),
            shuffle=True,
            batch_size=batch_size
        )
        self.test_loader = DataLoader(
            list(range(num_users)),batch_size=test_batch_size,shuffle=False,num_workers=5
        )

        test_ground_truth_list = [[] for _ in range(num_users)]
        for uid,item in zip(test_users,test_items):
            test_ground_truth_list[uid].append(item)
        self.test_ground_truth_list = test_ground_truth_list

        # One vectorised assignment instead of a Python loop over every edge.
        mask = torch.zeros(num_users,num_items)
        mask[train_edge_index[0],train_edge_index[1]] = -np.inf
        self.mask = mask

    @staticmethod
    def _read_interactions(file_path,keep_ratio=None,skip_malformed=False):
        """
        Parse an interaction file into parallel ``(users, items)`` id lists.

        Args:
            file_path: text file with one ``"<user> <item> <item> ..."`` row per line.
            keep_ratio: when given, only tokens ``[1:int(len(tokens) * keep_ratio)]``
                of each row are used as items; ``None`` keeps every item token.
            skip_malformed: when True a row whose item tokens are not all
                integers is dropped; when False the ``ValueError`` propagates.

        Blank lines are ignored.
        """
        users,items = [],[]
        with open(file_path) as f:
            for line in f:
                if not line.strip():
                    continue
                tokens = line.strip('\n').split(' ')
                uid = int(tokens[0])
                stop = len(tokens) if keep_ratio is None else int(len(tokens) * keep_ratio)
                try:
                    row_items = [int(tok) for tok in tokens[1:stop]]
                except ValueError:
                    if skip_malformed:
                        continue
                    raise
                users.extend([uid] * len(row_items))
                items.extend(row_items)
        return users,items

    @staticmethod
    def _to_edge_index(users,items):
        """Stack two equally long id lists into a ``(2, E)`` long tensor."""
        return torch.from_numpy(np.asarray([users,items],dtype=np.int64).reshape(2,-1))

    def getSparseGraph(self):
        """
        Build the unnormalized adjacency matrix of the training graph.

        A = |0   R|
            |R^T 0|
        R : user-item bipartite graph
        A : unnormalized Adjacency Matrix

        Item ids are shifted by ``num_users`` so users and items share one
        node index space of size ``num_users + num_items``.
        """
        cprint("generate Adjacency Matrix A")
        user_index = self.train_edge_index[0]
        item_index = self.train_edge_index[1]
        row_index = torch.cat([user_index,item_index+self.num_users])
        col_index = torch.cat([item_index+self.num_users,user_index])
        return SparseTensor(row=row_index,col=col_index,sparse_sizes=(self.num_items+self.num_users,self.num_items+self.num_users))

    def getSparseBipartite(self):
        """Return the ``(num_users, num_items)`` sparse matrix R of training edges."""
        user_index = self.train_edge_index[0]
        item_index = self.train_edge_index[1]
        return SparseTensor(row=user_index,col=item_index,sparse_sizes=(self.num_users,self.num_items))

    def get_user_all_interacted(self,users):
        """
        Return, for each user id in the tensor ``users``, the array of item
        ids with a non-zero entry in ``UserItemNet``.
        """
        users = users.detach().cpu().numpy()
        posItems = []
        for user in users:
            posItems.append(self.UserItemNet[user].nonzero()[1])
        return posItems



# Define Graph Model(定义图模型)

In [None]:
class RecModel(MessagePassing):
    """
    Base class for embedding-based recommenders on a user-item graph.

    Nodes ``0 .. num_users-1`` are users and nodes
    ``num_users .. num_users+num_items-1`` are items. Subclasses must
    implement ``get_embedding`` and, for ``recommendation_loss`` to work,
    provide ``self.user_emb`` / ``self.item_emb`` embedding modules.

    Args:
        num_users: number of user nodes.
        num_items: number of item nodes.
        edge_index: accepted for signature compatibility with subclasses;
            it is not used by this base constructor.
    """
    def __init__(self,
                 num_users:int,
                 num_items:int,
                 edge_index:LongTensor):
        super().__init__()
        self.num_users = num_users
        self.num_items = num_items
        self.num_nodes = num_users + num_items
        self.f = nn.Sigmoid()


    def get_sparse_graph(self,
                         edge_index,
                         use_value=False,
                         value=None):
        """
        Turn ``(user, item)`` edges into a symmetric sparse adjacency matrix.

        Args:
            edge_index: two rows, user ids and item ids (item ids not yet
                shifted by ``num_users``).
            use_value: when True, attach ``value`` to the edges.
            value: one value per edge; duplicated for the reverse direction.

        Returns:
            ``SparseTensor`` of size ``(num_nodes, num_nodes)`` containing
            both the user->item and the item->user entries.
        """
        num_users = self.num_users
        num_nodes = self.num_nodes
        r,c = edge_index
        row = torch.cat([r , c + num_users])
        col = torch.cat([c + num_users , r])
        if use_value:
            value = torch.cat([value,value])
            return SparseTensor(row=row,col=col,value=value,sparse_sizes=(num_nodes,num_nodes))
        else:
            return SparseTensor(row=row,col=col,sparse_sizes=(num_nodes,num_nodes))

    def get_embedding(self):
        """Return the ``(num_nodes, dim)`` node embeddings; implemented by subclasses."""
        raise NotImplementedError

    def forward(self,
                edge_label_index:Tensor):
        """
        Score sampled triplets.

        Args:
            edge_label_index: three rows - user ids, positive item ids and
                negative item ids.

        Returns:
            ``(pos_score, neg_score)``: dot products of each user embedding
            with its positive and its negative item embedding.
        """
        out = self.get_embedding()
        out_u,out_i = torch.split(out,[self.num_users,self.num_items])
        out_src = out_u[edge_label_index[0]]
        out_dst = out_i[edge_label_index[1]]
        out_dst_neg = out_i[edge_label_index[2]]
        return (out_src * out_dst).sum(dim=-1),(out_src * out_dst_neg).sum(dim=-1)

    def link_prediction(self,
                        src_index:Tensor=None,
                        dst_index:Tensor=None):
        """
        Return the sigmoid of the user-item dot-product score matrix.

        Args:
            src_index: user ids to score; ``None`` means every user.
            dst_index: item ids to score; ``None`` means every item.

        Returns:
            Tensor of shape ``(len(src), len(dst))`` with values in ``[0, 1]``.
        """
        out = self.get_embedding()
        out_u,out_i = torch.split(out,[self.num_users,self.num_items])
        # "All users" / "all items" is the split tensor itself, so there is no
        # need to build an arange on the CPU and gather the whole table.
        out_src = out_u if src_index is None else out_u[src_index]
        out_dst = out_i if dst_index is None else out_i[dst_index]
        pred = out_src @ out_dst.t()
        return pred.sigmoid()

    def recommendation_loss(self,
                            pos_rank,
                            neg_rank,
                            edge_label_index):
        """
        Compute the BPR ranking loss and the L2 penalty for one mini-batch.

        Args:
            pos_rank: scores of the positive pairs.
            neg_rank: scores of the negative pairs.
            edge_label_index: the triplets (users, positive items, negative
                items) that produced the scores.

        Returns:
            ``(rec_loss, regularization)``. ``rec_loss`` is the mean of
            ``softplus(neg - pos)``. ``regularization`` is
            ``reg_weight / 2 * ||e||^2`` over the raw (un-propagated)
            embedding rows of the batch, divided by the batch size.
        """
        rec_loss = torch.nn.functional.softplus(neg_rank - pos_rank).mean()
        user_emb = self.user_emb.weight
        item_emb = self.item_emb.weight
        embedding = torch.cat([user_emb[edge_label_index[0]],
                               item_emb[edge_label_index[1]],
                               item_emb[edge_label_index[2]]])
        regularization = reg_weight * (1/2) * embedding.norm(p=2).pow(2)
        regularization = regularization / pos_rank.size(0)
        return rec_loss , regularization

    def message(self, x_j: Tensor) -> Tensor:
        """Pass neighbour features through unchanged."""
        return x_j

    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
        """Fused propagation step: sparse-dense product ``adj_t @ x``."""
        return matmul(adj_t,x)

In [None]:
class LightGCN(RecModel):
    """
    LightGCN recommender: embeddings are propagated ``K`` times over the
    symmetrically normalised user-item adjacency matrix and the ``K + 1``
    layer outputs are averaged with equal weights.

    Args:
        num_users: number of user nodes.
        num_items: number of item nodes.
        edge_index: training edges, two rows (user ids, item ids).
        num_layers: number of propagation layers ``K``; ``None`` uses the
            notebook-level ``layer`` setting.
    """
    def __init__(self,
                 num_users:int,
                 num_items:int,
                 edge_index:LongTensor,
                 num_layers=None):
        super().__init__(
            num_users=num_users,
            num_items=num_items,
            edge_index=edge_index
        )
        self.user_emb = nn.Embedding(num_embeddings=num_users,
                                     embedding_dim=embedding_size)
        self.item_emb = nn.Embedding(num_embeddings=num_items,
                                     embedding_dim=embedding_size)
        nn.init.normal_(self.user_emb.weight,std=0.1)
        nn.init.normal_(self.item_emb.weight,std=0.1)
        self.K = layer if num_layers is None else num_layers
        adj = self.get_sparse_graph(edge_index=edge_index,use_value=False,value=None)
        # LightGCN propagates over D^-1/2 A D^-1/2 without self-loops (the
        # layer-0 term of the layer average plays that role). gcn_norm adds
        # self-loops by default, so that has to be disabled explicitly.
        # NOTE: the normalised adjacency is a plain attribute, so it stays on
        # the device of the edge_index passed in; model.to() does not move it.
        self.edge_index = gcn_norm(adj,add_self_loops=False)
        # Uniform layer weights 1 / (K + 1). Registered as a non-persistent
        # buffer so it follows model.to(device) without entering state_dict.
        self.register_buffer('alpha',
                             torch.full((self.K + 1,),1./ (1 + self.K)),
                             persistent=False)
        print('Go LightGCN')
        print(f"params settings: \n emb_size:{embedding_size}\n L2 reg:{reg_weight}\n layer:{self.K}")

    def get_embedding(self):
        """
        Return the final ``(num_users + num_items, embedding_size)`` node
        embeddings: the alpha-weighted sum of the raw embeddings and the
        outputs of each of the ``K`` propagation steps.
        """
        x_u=self.user_emb.weight
        x_i=self.item_emb.weight
        x=torch.cat([x_u,x_i])
        out = x * self.alpha[0]
        for i in range(self.K):
            x = self.propagate(edge_index=self.edge_index,x=x)
            out = out + x * self.alpha[i + 1]
        return out

    def instance_loss(self,edge_label_index):
        """
        Return ``sigmoid(pos_score - neg_score)`` for each triplet in
        ``edge_label_index`` (rows: users, positive items, negative items).
        """
        out = self.get_embedding()
        users,items = torch.split(out,[self.num_users,self.num_items])
        user_emb = users[edge_label_index[0]]
        item_pos = items[edge_label_index[1]]
        item_neg = items[edge_label_index[2]]
        return ((user_emb * item_pos).sum(dim=-1) - (user_emb * item_neg).sum(dim=-1)).sigmoid()

# Evaluation Metrics(模型评估指标)

In [None]:
@torch.no_grad()
def test(k_values:list,
         model,
         train_edge_index,
         test_edge_index,
         num_users,
         eval_batch_size=None,
         ):
    """
    Evaluate ``model`` with Recall@k and NDCG@k on the held-out edges.

    Users are scored in batches. Items a user interacted with in training
    are set to ``-inf`` so they cannot appear in the top-k.

    Args:
        k_values: cut-offs to report, e.g. ``[20, 50]``.
        model: object exposing ``eval()`` and
            ``link_prediction(src_index, dst_index)``.
        train_edge_index: two rows (user ids, item ids) of training edges.
        test_edge_index: two rows (user ids, item ids) of test edges.
        num_users: number of users to evaluate (ids ``0 .. num_users-1``).
        eval_batch_size: users per scoring batch; ``None`` uses the
            notebook-level ``test_batch_size``.

    Returns:
        ``(recall, ndcg)``: two dicts keyed by k, each averaged over the
        users that have at least one test item (0.0 if there are none).
    """
    if eval_batch_size is None:
        eval_batch_size = test_batch_size
    model.eval()
    max_k = max(k_values)
    recall = {k: 0 for k in k_values}
    ndcg = {k: 0 for k in k_values}
    total_examples = 0
    for start in range(0, num_users, eval_batch_size):
        end = min(start + eval_batch_size, num_users)
        src_index=torch.arange(start,end).long().to(device)
        logits = model.link_prediction(src_index=src_index,dst_index=None)

        # Exclude training edges:
        mask = ((train_edge_index[0] >= start) &
                (train_edge_index[0] < end))
        masked_interactions = train_edge_index[:,mask]
        logits[masked_interactions[0] - start,masked_interactions[1]] = float('-inf')
        # Generate ground truth matrix
        ground_truth = torch.zeros_like(logits, dtype=torch.bool)
        mask = ((test_edge_index[0] >= start) &
                (test_edge_index[0] < end))
        masked_interactions = test_edge_index[:,mask]
        ground_truth[masked_interactions[0] - start,masked_interactions[1]] = True
        # Number of test items per user of this batch.
        node_count = degree(test_edge_index[0, mask] - start,
                            num_nodes=logits.size(0))
        topk_indices = logits.topk(max_k,dim=-1).indices
        for k in k_values:
            topk_index = topk_indices[:,:k]
            isin_mat = ground_truth.gather(1, topk_index)
            # Calculate recall
            recall[k] += float((isin_mat.sum(dim=-1) / node_count.clamp(1e-6)).sum())
            # Calculate NDCG
            log_positions = torch.log2(torch.arange(2, k + 2, device=logits.device).float())
            dcg = (isin_mat / log_positions).sum(dim=-1)
            # ideal_prefix[j] is the DCG of j hits in the first j ranks
            # (ideal_prefix[0] == 0). Indexing it with min(#test items, k)
            # gives every user's ideal DCG without a per-user Python loop.
            discounts = 1.0 / log_positions
            ideal_prefix = torch.cat([discounts.new_zeros(1), discounts.cumsum(dim=0)])
            ideal_dcg = ideal_prefix[node_count.clamp(0, k).long()]
            ndcg[k] += float((dcg / ideal_dcg.clamp(min=1e-6)).sum())

        total_examples += int((node_count > 0).sum())

    # Guard against a split in which no user has a test item.
    denom = max(total_examples, 1)
    recall = {k: recall[k] / denom for k in k_values}
    ndcg = {k: ndcg[k] / denom for k in k_values}

    return recall,ndcg


def Fast_Sampling(dataset):
    """
    Build one epoch of BPR training triplets with uniform negative sampling.

    Training edges are shuffled into batches of ``batch_size``. For every
    positive edge a negative item id is drawn uniformly from
    ``[0, num_items)`` without checking whether the user interacted with it
    (simplified sampling: efficient, but easy to overfit on a raw GNN model).

    Returns:
        list of tensors, one per mini-batch, each with three rows:
        user ids, positive item ids, sampled negative item ids.
    """
    num_items = dataset.num_items
    edges = dataset.train_edge_index.to(device)
    index_batches = DataLoader(
        range(edges.size(1)),
        shuffle=True,
        batch_size=batch_size)

    def build_triplets(batch_idx):
        # Positive pairs picked by the shuffled positions, plus one random
        # negative item per pair.
        positives = edges[:,batch_idx]
        negatives = torch.randint(0, num_items,(batch_idx.numel(), ), device=device)
        return torch.stack([positives[0], positives[1], negatives])

    return [build_triplets(batch_idx) for batch_idx in index_batches]

# Sampling and Training(采样&训练)

In [None]:
def train_bpr(dataset,
                  model:LightGCN,
                  opt):
    """
    Run one training epoch with the BPR loss plus L2 regularization.

    Args:
        dataset: object accepted by ``Fast_Sampling`` (provides the training
            edges and the item count).
        model: model returning ``(pos_rank, neg_rank)`` when called on a
            triplet batch and exposing ``recommendation_loss``.
        opt: optimizer over the model parameters.

    Returns:
        A string ``"average loss <value>"`` with the mean per-batch total
        loss of the epoch.
    """
    model.train()
    S = Fast_Sampling(dataset=dataset)
    aver_loss = 0.
    total_batch = len(S)
    for edge_label_index in S:
        pos_rank,neg_rank = model(edge_label_index)
        bpr_loss,L2_reg = model.recommendation_loss(pos_rank,neg_rank,edge_label_index)
        loss = bpr_loss + L2_reg
        opt.zero_grad()
        loss.backward()
        opt.step()
        # Accumulate a detached value: adding the live loss tensor would keep
        # every batch's autograd history referenced for the whole epoch.
        aver_loss += loss.detach()
    aver_loss = float(aver_loss) / total_batch
    return f"average loss {aver_loss:5f}"

In [None]:
# Load the data and place the edge lists on the compute device.
dataset = Loader()
num_users = dataset.num_users
num_items = dataset.num_items
train_edge_index = dataset.train_edge_index.to(device)
test_edge_index = dataset.test_edge_index.to(device)

# Model and optimizer.
model = LightGCN(num_users=num_users,
                 num_items=num_items,
                 edge_index=train_edge_index).to(device)
opt = torch.optim.Adam(params=model.parameters(),lr=lr)

# Bookkeeping placeholders; the loop below does not update them.
best = 0.
patience = 0.
max_score = 0.

# Train for 1000 epochs, evaluating Recall / NDCG at 20 and 50 after each one.
for epoch in range(1, 1001):
    loss = train_bpr(dataset=dataset,model=model,opt=opt)
    recall,ndcg = test([20,50],model,train_edge_index,test_edge_index,num_users)
    summary = (f'Epoch: {epoch:03d}, {loss}, R@20: '
               f'{recall[20]:.4f}, R@50: {recall[50]:.4f} '
               f', N@20: {ndcg[20]:.4f}, N@50: {ndcg[50]:.4f}')
    print(summary)
    if epoch % 5 == 0:
        # Every fifth epoch, show the predicted scores of the first five
        # users against the first five items as a quick sanity check.
        print(model.link_prediction(torch.arange(5).to(device),torch.arange(5).to(device)))