In [None]:
import os
import random
import tensorflow.compat.v1 as tf
import numpy as np

# Enumerate GPUs in PCI bus order and expose only device "1" to this process.
os.environ.update({
    'CUDA_DEVICE_ORDER': 'PCI_BUS_ID',
    'CUDA_VISIBLE_DEVICES': '1',
})

def set_seeds(seed=42):
    """Seed the Python, TensorFlow and NumPy random number generators.

    Also exports ``PYTHONHASHSEED`` so that child processes inherit it.

    Args:
        seed: integer seed handed to every generator (default 42).
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    # Same seeding order as before: stdlib random, TensorFlow, NumPy.
    for seeder in (random.seed, tf.set_random_seed, np.random.seed):
        seeder(seed)
def set_global_determinism(seed):
    """Seed all RNGs and switch TensorFlow to deterministic, single-threaded execution.

    Args:
        seed: integer seed forwarded to ``set_seeds``.
    """
    set_seeds(seed=seed)

    # Environment switches requesting deterministic op / cuDNN implementations.
    os.environ.update({
        'TF_DETERMINISTIC_OPS': '1',
        'TF_CUDNN_DETERMINISTIC': '1',
    })

    # One thread between ops and one thread within each op.
    threading_config = tf.config.threading
    threading_config.set_inter_op_parallelism_threads(1)
    threading_config.set_intra_op_parallelism_threads(1)

# Seed every RNG and apply the deterministic TensorFlow settings (seed 42) at import time.
set_global_determinism(seed=42)

from hyperopt import hp
import pandas as pd
from ray import tune
from ray.tune.schedulers import AsyncHyperBandScheduler
from ray.tune.suggest.hyperopt import HyperOptSearch
import scipy.sparse as sp
from scipy.sparse.csgraph import laplacian
from scipy.sparse.linalg import svds

import networkx as nx
from datasets import Dataset
import layers
import metrics
import time

# Log the TensorFlow version in use.
print("tf.__version__: {}".format(tf.__version__) )

In [None]:
class Model(object):
    def __init__(self, session, dataset, **config):
        """Build the TensorFlow graph of a rating-prediction model.

        The graph is assembled in a fixed order: placeholders, graph masks,
        parameters (``_params``), scaled prediction, loss, RMSE, optimizer and
        finally an accuracy metric. Subclasses customise the model through
        ``_params``, ``_r_pred`` and the other hook methods.

        Args:
            session: tensorflow Session object used by ``run`` and ``train``.
            dataset: datasets.Dataset object; this constructor reads its
                ``side_info``, ``data``, ``min`` and ``range`` attributes.
            **config: hyperparameters, stored unchanged in ``self.config``.
        """
        # tensorflow Session object
        self.session = session
        # datasets.Dataset object
        self.dataset = dataset
        # dict of hyperparameters, etc.
        self.config = config

        # inputs
        self.id = tf.placeholder(tf.int64, [None])
        self.user_id = tf.placeholder(tf.int64, [None])
        self.item_id= tf.placeholder(tf.int64, [None])
        # labels
        self.r_true = tf.placeholder(tf.float32, [None])
        # initialize graph
        self.user_mask = self._mask(self.dataset.side_info.get('user_graph', None))
        self.item_mask = self._mask(self.dataset.side_info.get('item_graph', None))
        # ratings strictly above this value become edges of the user-item graph
        rating_threshold = 3.0
        self.user_item_mask = self._user_item_mask(self.dataset.data, rating_threshold)
        
        # define model parameters
        self.weights, self.biases, self.user_factor_pp_dense, self.user_feature_dense, self.item_factor_pp_dense, self.item_feature_dense, self.user_factor_in_ui_dense, self.item_factor_in_ui_dense = self._params()
        # define rating computation and scale to dataset range
        self.r_pred = self.dataset.min + self.dataset.range*tf.sigmoid(self._r_pred())
        # define loss
        self.loss = self._loss()
        # define rmse metric for update monitoring
        self.rmse = tf.reduce_mean((self.r_pred - self.r_true)**2)**0.5
        
        # Adam optimization with default learning rate
        self.opt = tf.train.AdamOptimizer().minimize(self.loss)
        
        # acc: fraction of predictions that equal the true rating after rounding
        self.r_pred_round = tf.round(self.r_pred)
        self.correct_prediction = tf.equal(self.r_pred_round, self.r_true)
        self.accuracy = tf.reduce_mean(tf.cast(self.correct_prediction, tf.float32))


    def _params(self):
        weights, biases = {}, {}
        
        user_factor_pp_dense = None
        user_feature_dense = None
        
        item_factor_pp_dense = None
        item_feature_dense = None
        
        user_factor_in_ui_dense = None
        item_factor_in_ui_dense = None
        
        return weights, biases, user_factor_pp_dense, user_feature_dense, item_factor_pp_dense, item_feature_dense, user_factor_in_ui_dense, item_factor_in_ui_dense

    def _r_pred(self):
        """Return the raw rating score; must be implemented by subclasses.

        The constructor passes the result through a sigmoid and rescales it to
        the dataset's rating range.
        """
        raise NotImplementedError()
    
    def _local_pred(self):
        """Return the raw local-view score; must be implemented by subclasses.

        ``_loss`` passes the result through a sigmoid for its KL term.
        """
        raise NotImplementedError()
    
    def _global_local_contra_loss(self):
        """Return the four global/local contrastive losses; must be implemented by subclasses.

        ``_loss`` unpacks the result as (user content, item content,
        user action, item action).
        """
        raise NotImplementedError()
        
    def _global_local_mse_loss(self):
        """Return a global/local MSE loss; must be implemented by subclasses.

        NOTE(review): nothing in this class calls this hook.
        """
        raise NotImplementedError()

    def _loss(self):
        """Assemble the training objective and store each term on ``self``.

        The returned scalar is
        ``(1 - alpha) * mse / range**2 + alpha * reg + mse_1 + mse_2 + mse_3
        + mse_4 + 0.0001 * global_local_contra_loss + infonce_1``.

        Relies on attributes a subclass is expected to provide before this is
        called: ``self.user_factor``, ``self.item_factor`` and
        ``self.infonce_loss``.

        NOTE(review): ``self.KL`` is computed and stored but is not part of the
        returned loss, and ``KL_loss_weight`` is never used.
        """
        #recommend task related loss
        self.mse = tf.reduce_mean((self.r_pred - self.r_true)**2)
        self.infonce_1 = self.infonce_loss(self.user_factor, self.item_factor, self.r_true,)
        
        # alignment between the projected graph factors and the projected features
        self.mse_1 = tf.reduce_mean((self.user_factor_pp_dense - self.user_feature_dense)**2)
        self.mse_2 = tf.reduce_mean((self.item_factor_pp_dense - self.item_feature_dense)**2)
        self.mse_3 = tf.reduce_mean((self.user_factor_in_ui_dense - self.user_feature_dense)**2)
        self.mse_4 = tf.reduce_mean((self.item_factor_in_ui_dense - self.item_feature_dense)**2)
        
        self.user_content_contra_loss, self.item_content_contra_loss, self.user_action_contra_loss, self.item_action_contra_loss = self._global_local_contra_loss()
        self.global_local_contra_loss = self.user_content_contra_loss+self.item_content_contra_loss+self.user_action_contra_loss+self.item_action_contra_loss
        
        # KL divergence loss
        # _r_pred() is invoked a second time here (the constructor already called it once)
        self.global_pred = tf.sigmoid(self._r_pred()) 
        self.local_pred = tf.sigmoid(self._local_pred()) 
        self.global_pred_reshape = tf.reshape(self.global_pred, [-1, 1])
        self.local_pred_reshape = tf.reshape(self.local_pred, [-1, 1])
        # two-column distributions [p, 1 - p] for the global and the local prediction
        self.global_pred_all = tf.concat([self.global_pred_reshape, 1 - self.global_pred_reshape], axis = 1)
        self.local_pred_all = tf.concat([self.local_pred_reshape, 1 - self.local_pred_reshape], axis = 1)
        self.KL = tf.reduce_mean(
                    tf.reduce_sum(self.global_pred_all * tf.math.log(self.global_pred_all / self.local_pred_all),
                                  keepdims=True, axis=1)
                  )

        self.reg = self._reg()
        self.alpha = float(self.config.get('alpha', 0))
        KL_loss_weight = 0.001
        # despite its name, this weight scales the global/local contrastive term below
        infonce_loss_weight = 0.0001

        return (1 - self.alpha)*self.mse/self.dataset.range**2 + \
                        self.alpha*self.reg + self.mse_1 + self.mse_2 + self.mse_3 + self.mse_4 + \
                        infonce_loss_weight*self.global_local_contra_loss + self.infonce_1


    def _reg(self):
        """Return the regulariser ``(1 - beta) * reg_l2 + beta * reg_graph``.

        ``reg_l2`` is the sum of squared entries over every tensor in
        ``self.weights``. ``reg_graph`` adds, per weight tensor, a graph
        regulariser chosen by the tensor's first dimension: the user graph when
        it equals ``n_user``, the item graph when it equals ``n_item``, plain L2
        otherwise. ``beta`` comes from ``config['beta']`` (default 0).

        NOTE(review): if ``n_user == n_item`` the first branch wins, so item
        tables would be regularised with the user graph.
        """
        user_graph = self.dataset.side_info.get('user_graph', None)
        item_graph = self.dataset.side_info.get('item_graph', None)
        
        self.reg_l2, self.reg_graph = 0, 0
        for w in self.weights.values():
            self.reg_l2 += tf.reduce_sum(w**2)
            if w.shape[0] == self.dataset.n_user:
                self.reg_graph += self._graph_reg(user_graph, w)
            elif w.shape[0] == self.dataset.n_item:
                self.reg_graph += self._graph_reg(item_graph, w)
            else:
                # _graph_reg falls back to L2 when no graph is given
                self.reg_graph += self._graph_reg(None, w)
        
        self.beta = float(self.config.get('beta', 0))
        return (1 - self.beta)*self.reg_l2 + self.beta*self.reg_graph

    def _graph_reg(self, g, w):
        """Return the graph-Laplacian penalty ``trace(w^T L w)`` for weights ``w``.

        Args:
            g: adjacency matrix (dense or scipy sparse) or ``None``.
            w: weight tensor; a rank-1 tensor is reshaped to a single column.

        Returns:
            ``sum(w**2)`` when ``g`` is ``None``; otherwise the trace above with
            ``L = laplacian(g)``, optionally normalised (``config['normed']``),
            computed with sparse or dense ops depending on ``config['sparse']``.
        """
        if g is None:
            # if no graph is provided, use L2 reg
            return tf.reduce_sum(w**2)
        if len(w.shape) == 1:
            w = tf.reshape(w, (-1, 1))
        normed = self.config.get('normed', False)
        if self.config.get('sparse', True):
            # build the Laplacian as a SparseTensor from its COO coordinates
            s = laplacian(sp.coo_matrix(g), normed=normed).astype(np.float32)
            s = tf.sparse.reorder(tf.SparseTensor(np.array([s.row, s.col]).T, s.data, s.shape))
            return tf.linalg.trace(tf.matmul(w, tf.sparse.matmul(s, w), transpose_a=True))
        s = tf.constant(laplacian(g.A if sp.issparse(g) else g, normed=normed), dtype=tf.float32)
        return tf.linalg.trace(tf.matmul(w, tf.matmul(s, w), transpose_a=True))

    def run(self, ops, batch):
        return self.session.run(ops, {
            self.id: batch.index,
            self.user_id: batch.user_id,
            self.item_id: batch.item_id,
            self.r_true: batch.rating,
        })

    def train(self, max_updates=100000, n_check=100, patience=float('inf'), batch_size=None):
        self.session.run(tf.global_variables_initializer())
        best = {'updates': 0, 'loss': float('inf'), 'rmse_tune': float('inf'), 'acc_tune': float('inf'), 'rmse_test': float('inf'), 'acc_test': float('inf')}
        for i in range(max_updates):
            opt, loss, accuracy = self.run([self.opt, self.loss, self.accuracy], self.dataset.get_batch(mode='train', size=batch_size))
            if i % n_check == 0 or i == max_updates - 1:
                # monitoring
                rmse_tune, acc_tune = self.run([self.rmse,self.accuracy], self.dataset.get_batch(mode='tune', size=102400))
                if len(self.dataset.tune) == 0 or rmse_tune < best['rmse_tune']:
                    rmse_test, acc_test = self.run([self.rmse, self.accuracy], self.dataset.get_batch(mode='test', size=102400))
                    best = {'updates': i, 'loss': loss, 'rmse_tune': rmse_tune, 'acc_tune': acc_tune, 'rmse_test': rmse_test, "acc_test": acc_test}
                print(best)
                if (i - best['updates'])//n_check > patience:
                    # early stopping
                    break
        return best

    def test_infer(self):
        print("\n\nTESTING...")
        time_start=time.time()
        if self.dataset.name in ["MovieLens1M", "KuaiRandPure"]:
            print("batch infer")
            self.dataset.index['test'] = 0
            test_pred = []
            while True:
                tmp_test_dataset = self.dataset.get_batch(mode='test', size=102400)
                tmp_test_pred = self.run(self.r_pred, tmp_test_dataset)
                test_pred.append( tmp_test_pred )
                if self.dataset.index['test']==0:
                    break
                
            test_pred = np.hstack(test_pred)
        else:
            test_dataset = self.dataset.get_batch(mode='test', size=None)
            test_pred = self.run(self.r_pred, test_dataset)
        test_dataset = self.dataset.get_batch(mode='test', size=None)
        r_true, user_ids, item_ids = test_dataset[['rating', 'user_id', 'item_id']].values.T
        print('\ttime cost',time.time()-time_start,'s')
        return {'test_pred': test_pred, 'r_true': r_true, 'user_ids': user_ids, 'item_ids': item_ids}


    def get_metrics(self, r_pred):
        print("\n\nEVALUATING...")
        time_start=time.time()
        r_pred, r_true, user_ids, item_ids = r_pred['test_pred'], r_pred['r_true'], r_pred['user_ids'], r_pred['item_ids']

        _metric1 = metrics.metric_at_once(r_pred, r_true, user_ids, item_ids, k=10)
        _metric2 = metrics.metric_at_once(r_pred, r_true, user_ids, item_ids, k=20)
        dddall = {**_metric1, **_metric2}
        
        return_res = {
            "P@10":dddall['Precision@10'],"P@20":dddall['Precision@20'],
            "R@10":dddall['Recall@10'],"R@20":dddall['Recall@20'],
            "H@10":dddall['Hitrate@10'],"H@20":dddall['Hitrate@20'],
            "N@10":dddall['NDCG@10'],"N@20":dddall['NDCG@20']
                     }
        return_res = {k:round(v,6) for k,v in return_res.items()}
        
        print('\ttime cost',time.time()-time_start,'s')
        return return_res


    def _schema(self):
        return {
            'weights': self.weights,
            'biases': self.biases,
        }

    def _mask(self, g):
        """Build the attention mask ``I + g`` for a graph adjacency matrix.

        Args:
            g: adjacency matrix (dense or scipy sparse) or ``None``.

        Returns:
            ``1.0`` when ``g`` is ``None``; a ``tf.SparseTensor`` when
            ``config['sparse']`` is truthy (the default); a dense float32
            tensor otherwise.
        """
        # define the mask used for GAT
        if g is None:
            # no masking
            return 1.0
        if self.config.get('sparse', True):
            shape = g.shape
            g = sp.coo_matrix(g, shape=shape, dtype=np.float32)
            g = tf.sparse.reorder(tf.SparseTensor(np.array([g.row, g.col]).T, g.data, shape))
            # adding the identity puts an entry on every diagonal position
            return tf.sparse.add(tf.sparse.eye(*shape), g)
        return tf.eye(*g.shape) + tf.constant(g.A if sp.issparse(g) else g, dtype=tf.float32)
    
    def _user_item_mask(self, input_df, rating_threshold):
        """Build the adjacency matrix of the bipartite user-item graph.

        Users keep their ids as node labels; item ids are shifted by an offset
        so that they never collide with user ids. An edge connects a user and
        an item whenever the rating is strictly greater than
        ``rating_threshold``. Rows and columns of the result are ordered by
        ascending node label, i.e. all users (by id) followed by all items
        (by id).

        Args:
            input_df: DataFrame with ``user_id``, ``item_id`` and ``rating``
                columns; it is not modified.
            rating_threshold: ratings above this value become edges.

        Returns:
            Dense ``numpy`` adjacency matrix of the graph.
        """
        # work on a private copy so the caller's frame is never touched
        ratings = input_df[["user_id", "item_id", "rating"]].copy()

        # Offset separating item labels from user labels. 10000 is kept as the
        # minimum; it grows automatically when user ids would reach it. Because
        # nodes are sorted below, the matrix does not depend on the exact value.
        max_user_id = int(ratings["user_id"].max()) if len(ratings) else -1
        item_offset = max(10000, max_user_id + 1)
        if self.dataset.diff_tag_n_user:
            item_offset = max(item_offset, self.dataset.n_user)
        ratings["item_id"] = ratings["item_id"] + item_offset

        G = nx.Graph()

        print("\n\nGRAPH BUILDING")
        print("\tuserNum: {} itemNum: {}".format(len(set(ratings["user_id"])), len(set(ratings["item_id"]))))
        if self.dataset.diff_tag_n_user:
            print("\tbuild gragh: fix user num from {} to {}".format(len(set(ratings["user_id"])), self.dataset.n_user))
            G.add_nodes_from(range(self.dataset.n_user), bipartite=0)
        else:
            G.add_nodes_from(ratings["user_id"], bipartite=0)

        if self.dataset.diff_tag_n_item:
            print("\tbuild gragh: fix item num from {} to {}".format(len(set(ratings["item_id"])), self.dataset.n_item))
            # item nodes belong to the second partition
            G.add_nodes_from(range(item_offset, self.dataset.n_item + item_offset), bipartite=1)
        else:
            G.add_nodes_from(ratings["item_id"], bipartite=1)

        # set threshold
        liked = ratings[ratings["rating"] > rating_threshold]
        G.add_edges_from(zip(liked["user_id"], liked["item_id"]))

        # list.sort() returns None, which made networkx fall back to insertion
        # order; sorted() yields the intended id-ordered node list.
        adjacency = nx.adjacency_matrix(G, nodelist=sorted(G.nodes))
        return np.array(adjacency.todense())

    def _normalized_aggregation(self, g, w):
        """Aggregate ``w`` over graph ``g`` and scale by the root of the row sums.

        Args:
            g: adjacency matrix (dense or scipy sparse) or ``None``.
            w: tensor whose rows are aggregated.

        Returns:
            ``0.0`` when ``g`` is ``None``; otherwise
            ``(g @ w) / (sqrt(row_sum(g)) + 1e-10)``, using sparse ops when
            ``config['sparse']`` is truthy (the default).
        """
        if g is None:
            return 0.0
        if self.config.get('sparse', True):
            g = sp.coo_matrix(g, dtype=np.float32)
            g = tf.sparse.reorder(tf.SparseTensor(np.array([g.row, g.col]).T, g.data, g.shape))
            # the 1e-10 term keeps the division finite for rows that sum to zero
            return tf.sparse.matmul(g, w)/(tf.sparse.reduce_sum(g, axis=1, keepdims=True)**0.5 + 1e-10)
        g = tf.constant(g.A if sp.issparse(g) else g, dtype=tf.float32)
        return tf.matmul(g, w)/(tf.reduce_sum(g, axis=1, keepdims=True)**0.5 + 1e-10)




In [None]:
class GMGCL(Model):

    def _r_pred(self):
        """Raw global-view score: factor dot product plus user, item and global bias."""
        interaction = tf.reduce_sum(self.user_factor*self.item_factor, 1)
        return interaction + self.user_bias + self.item_bias + self.bias
    
    def infonce_loss(self, user_emb, item_emb, r_true, rou = 1, SCORE_THRESHOLD=3.0):
        """InfoNCE-style contrastive loss between the users and items of a batch.

        Row ``i`` of ``user_emb`` and ``item_emb`` form one rated pair. Pairs
        rated above ``SCORE_THRESHOLD`` contribute ``exp(cos_sim / rou)`` to the
        numerator; the numerator of all other rows is zero. The denominator of
        a row sums ``exp(cos_sim / rou)`` over the batch's items, skipping the
        row's own item only when its rating is above the threshold.

        Args:
            user_emb: user embeddings, one row per batch entry.
            item_emb: item embeddings, one row per batch entry.
            r_true: ratings of the batch entries.
            rou: temperature dividing the similarities.
            SCORE_THRESHOLD: rating above which a pair counts as positive.

        Returns:
            Scalar ``mean(-log(numerator / denominator + 0.0001))``.

        NOTE(review): the two mask names are inverted with respect to what they
        select: ``negative_score_mask`` marks ratings above the threshold and
        ``positive_score_mask`` marks ratings at or below it.
        """
        user_emb = tf.math.l2_normalize(user_emb, axis=-1)  # [n, dim]
        item_emb = tf.math.l2_normalize(item_emb, axis=-1)  # [n, dim]
        
        pos_inner_product = tf.math.reduce_sum(user_emb * item_emb, keepdims=True, axis=1)
        numerator = tf.math.exp(pos_inner_product / rou)
        # negative sample
        # 1 for rows rated above the threshold, 0 otherwise
        negative_score_mask = tf.where(r_true>SCORE_THRESHOLD, tf.ones_like(numerator), tf.zeros_like(numerator))
        numerator = numerator * negative_score_mask

        all_inner_product = tf.matmul(user_emb, item_emb, transpose_b=True)

        # ones everywhere except on the diagonal
        eye_matrix = (1-tf.eye(tf.shape(user_emb)[0]))
        positive_score_mask  = tf.where(r_true<=SCORE_THRESHOLD, tf.ones_like(numerator), tf.zeros_like(numerator))
        # keep all off-diagonal entries, plus the diagonal entry of low-rated rows
        positive_score_mask = tf.math.logical_or(tf.cast(eye_matrix, tf.bool), tf.cast(positive_score_mask, tf.bool))
        positive_score_mask = tf.cast(positive_score_mask, tf.float32)

        denominator_tmp_local = positive_score_mask * tf.math.exp(all_inner_product / rou)
        denominator = tf.math.reduce_sum(denominator_tmp_local, keepdims=True, axis=1)
        
        # the 0.0001 keeps the log finite for rows whose numerator is zero
        return tf.reduce_mean(-tf.log(numerator / denominator+0.0001))


    def _params(self):
        """Create all parameters and layers and wire up the three graph views.

        Steps, in order:
          1. trainable user/item factor tables and biases, plus their lookups;
          2. fixed input features from a truncated SVD of the training
             interaction matrix, optionally concatenated with side features;
          3. all GAT / Dense layers, whose weights and biases are merged into
             the returned dicts;
          4. forward passes through the user graph, the item graph and the
             joint user-item graph;
          5. the combined ("global") factor tables and their batch lookups;
          6. cosine-thresholded versions of the three masks.

        Returns:
            ``(weights, biases, user_factor_pp_dense, user_feature_dense,
            item_factor_pp_dense, item_feature_dense, user_factor_in_ui_dense,
            item_factor_in_ui_dense)``.
        """
        # basic params
        self.rank = int(self.config.get('rank', 1))
        weights = {
            'user_factor': tf.get_variable('user_factor', [self.dataset.n_user, self.rank]),
            'item_factor': tf.get_variable('item_factor', [self.dataset.n_item, self.rank]),
            'user_bias': tf.Variable(tf.zeros([self.dataset.n_user])),
            'item_bias': tf.Variable(tf.zeros([self.dataset.n_item])),
        }
        biases = {'bias': tf.Variable(0.0)}
        self.user_factor = tf.nn.embedding_lookup(weights['user_factor'], self.user_id)
        self.item_factor = tf.nn.embedding_lookup(weights['item_factor'], self.item_id)
        self.user_bias = tf.nn.embedding_lookup(weights['user_bias'], self.user_id)
        self.item_bias = tf.nn.embedding_lookup(weights['item_bias'], self.item_id)
        self.bias = biases['bias']
        
    
        # interaction matrix whose entries are the is_train flags of the ratings
        data, shape = self.dataset.data, (self.dataset.n_user, self.dataset.n_item)
        implicit = sp.coo_matrix((data.is_train, (data.user_id, data.item_id)), shape=shape)

        # rank-k truncated SVD; the singular values are split evenly between both sides
        self.k = int(self.config.get('k', 1))
        u, s, vt = svds(implicit.astype(float), k=self.k)
        
        self.user_features = tf.constant(u*s**0.5, dtype=tf.float32)
        self.item_features = tf.constant(vt.T*s**0.5, dtype=tf.float32)

        user_features = self.dataset.side_info.get('user_features', None)
        item_features = self.dataset.side_info.get('item_features', None)
         
        
        # append side features (if any) as extra feature columns
        if user_features is not None:
            self.user_features = tf.concat(
                [self.user_features, tf.constant(user_features.values, dtype=tf.float32)], 1)
        if item_features is not None:
            self.item_features = tf.concat(
                [self.item_features, tf.constant(item_features.values, dtype=tf.float32)], 1)
            
        user_feature_shape_dim = self.user_features.shape[-1]
        item_feature_shape_dim = self.item_features.shape[-1]

        self.n_head = int(self.config.get('n_head', 1))
        self.activation_in = tf.keras.activations.get(self.config.get('activation_in', 'softsign'))
        self.activation_out = tf.keras.activations.get(self.config.get('activation_out', 'hard_sigmoid'))
        self.residual = bool(self.config.get('residual', True))
        # layers of the user-graph view
        self.user_in = layers.GAT(
            self.user_features.shape[1], self.n_head*self.rank, 1, concat=False, residual=self.residual, name='user_in')
        self.user_out = layers.Dense(self.user_in.dim_out, self.rank, name='user_out')

        self.user_pp_dense = layers.Dense(self.rank, self.rank, name='user_pp_dense')
        self.user_feat_dense = layers.Dense(user_feature_shape_dim, self.rank, name='user_feat_dense')
     
        self.item_pp_dense = layers.Dense(self.rank, self.rank, name='item_pp_dense')
        self.item_feat_dense = layers.Dense(item_feature_shape_dim, self.rank, name='item_feat_dense')

        
        # layers of the item-graph view
        self.item_in = layers.GAT(
            self.item_features.shape[1], self.n_head*self.rank, 1, concat=False, residual=self.residual, name='item_in')
        self.item_out = layers.Dense(self.item_in.dim_out, self.rank, name='item_out')
        
        # layers of the joint user-item graph view
        self.user_feat_in = layers.Dense(user_feature_shape_dim, self.rank, name='user_feat_in')
        self.item_feat_in = layers.Dense(item_feature_shape_dim, self.rank, name='item_feat_in')
        self.user_item_in = layers.GAT(
            self.rank, self.n_head*self.rank, 1, concat=False, residual=self.residual, name='user_item_in')
        self.user_item_out = layers.Dense(self.user_item_in.dim_out, self.rank, name='user_item_out')
        self.user_factor_in_ui_graph_dense = layers.Dense(self.rank, self.rank, name='user_factor_in_ui_graph_dense')
        self.item_factor_in_ui_graph_dense = layers.Dense(self.rank, self.rank, name='item_factor_in_ui_graph_dense')

        # NOTE(review): the gate, *_view_dense and *_global_view_dense layers are
        # created and registered below but are not applied anywhere in this method.
        self.user_gate = layers.Dense(3 * self.rank, self.rank, name='user_gate')
        self.item_gate = layers.Dense(3 * self.rank, self.rank, name='item_gate')
        self.uu_view_dense = layers.Dense(2 * self.rank, self.rank, name='uu_view_dense')
        self.ui_view_dense = layers.Dense(2 * self.rank, self.rank, name='ui_view_dense')
        self.ii_view_dense = layers.Dense(2 * self.rank, self.rank, name='ii_view_dense')
        self.iu_view_dense = layers.Dense(2 * self.rank, self.rank, name='iu_view_dense')
       
        self.user_global_view_dense = layers.Dense(self.rank, self.rank, name='user_global_view_dense')
        self.item_global_view_dense = layers.Dense(self.rank, self.rank, name='item_global_view_dense')
        
        # layers applied by _global_local_mse_loss
        self.user_content_local_dense = layers.Dense(self.rank, self.rank, name='user_content_local_dense')
        self.user_action_local_dense = layers.Dense(self.rank, self.rank, name='user_action_local_dense')
        self.user_global_dense = layers.Dense(self.rank, self.rank, name='user_global_dense')
        
        self.item_content_local_dense = layers.Dense(self.rank, self.rank, name='item_content_local_dense')
        self.item_action_local_dense = layers.Dense(self.rank, self.rank, name='item_action_local_dense')
        self.item_global_dense = layers.Dense(self.rank, self.rank, name='item_global_dense')
        
        # register every layer's parameters so that _reg and _schema see them
        for layer in [self.user_in, self.user_out, self.item_in, self.item_out, self.user_pp_dense, self.user_feat_dense, self.item_pp_dense, self.item_feat_dense, self.user_gate, self.item_gate,\
                      self.user_feat_in, self.item_feat_in, self.user_item_in, self.user_item_out, self.user_factor_in_ui_graph_dense, self.item_factor_in_ui_graph_dense,\
                      self.uu_view_dense, self.ui_view_dense, self.ii_view_dense, self.iu_view_dense,\
                      self.user_global_view_dense, self.item_global_view_dense,\
                      self.user_content_local_dense, self.user_action_local_dense, self.user_global_dense, self.item_content_local_dense, self.item_action_local_dense, self.item_global_dense]:
            weights.update(layer.get_weights())
            biases.update(layer.get_biases())

        sparse = bool(self.config.get('sparse', True))
        
        # user-graph view: GAT over the user features followed by a dense projection
        self.user_factor_pp = self.user_out(self.user_in(
            self.user_features, self.activation_in, self.user_mask, sparse), self.activation_out)
        
        user_factor_pp_dense = self.user_pp_dense(self.user_factor_pp, self.activation_out)
        print('self.user_factor_pp_dense.shape:', user_factor_pp_dense.shape)
        user_feature_dense = self.user_feat_dense(self.user_features, self.activation_out)
        
        # item-graph view: GAT over the item features followed by a dense projection
        self.item_factor_pp = self.item_out(self.item_in(
            self.item_features, self.activation_in, self.item_mask, sparse), self.activation_out)
        
        item_factor_pp_dense = self.item_pp_dense(self.item_factor_pp, self.activation_out)
        print('self.item_factor_pp_dense.shape:', item_factor_pp_dense.shape)
        item_feature_dense = self.item_feat_dense(self.item_features, self.activation_out)
        
        # joint view: users and items are stacked into one node set
        user_feat_in_dense = self.user_feat_in(self.user_features)
        item_feat_in_dense = self.item_feat_in(self.item_features)
        self.user_item_in_dense = tf.concat([user_feat_in_dense, item_feat_in_dense], axis = 0)  # [m+n, dim]
        self.user_item_factor_pp = self.user_item_out(self.user_item_in(
            self.user_item_in_dense, self.activation_in, self.user_item_mask, sparse), self.activation_out)
        # split the stacked output back: the first rows are users, the remaining rows are items
        self.user_factor_in_ui_graph = tf.slice(self.user_item_factor_pp, [0, 0], [tf.shape(self.user_features)[0], -1])
        self.item_factor_in_ui_graph = tf.slice(self.user_item_factor_pp, [tf.shape(self.user_features)[0], 0], [-1, -1])
        user_factor_in_ui_dense = self.user_factor_in_ui_graph_dense(self.user_factor_in_ui_graph, self.activation_out)
        item_factor_in_ui_dense = self.item_factor_in_ui_graph_dense(self.item_factor_in_ui_graph, self.activation_out)
        
        # global view = weighted sum of the free table and both graph views
        weight1_fix = 1.0
        weight2_fix = 1.0
        weight3_fix = 1.0
        # always True; there is no alternative branch defining the global views
        use_fix_weight = True
        if use_fix_weight:
            self.user_global_view = weight1_fix*weights['user_factor'] + weight2_fix*self.user_factor_pp + weight3_fix*self.user_factor_in_ui_graph
            self.item_global_view = weight1_fix*weights['item_factor'] + weight2_fix*self.item_factor_pp + weight3_fix*self.item_factor_in_ui_graph

        # from here on the batch factors come from the global view (replaces the lookups above)
        self.user_factor = tf.nn.embedding_lookup(self.user_global_view, self.user_id)
        self.item_factor = tf.nn.embedding_lookup(self.item_global_view, self.item_id)
   

        # local view used by _local_pred: graph-view factors only
        self.user_factor_local = tf.nn.embedding_lookup(self.user_factor_pp, self.user_id)
        self.item_factor_local = tf.nn.embedding_lookup(self.item_factor_pp, self.item_id)
        

        # NOTE(review): the masks are re-assigned below after the GAT layers above
        # have already been applied with the original masks; within this method the
        # thresholded masks are not consumed again.
        # assumes the masks are dense and shaped like the cosine matrices — TODO confirm
        cosine_threshold = 0.5
        self.user_factor_norm = tf.math.l2_normalize(self.user_global_view, axis=-1)
        self.user_factor_cos = tf.matmul(self.user_factor_norm, self.user_factor_norm, transpose_b = True)
        self.user_mask = tf.where(self.user_factor_cos > cosine_threshold, self.user_mask, tf.zeros_like(self.user_mask))

        self.item_factor_norm = tf.math.l2_normalize(self.item_global_view, axis=-1)
        self.item_factor_cos = tf.matmul(self.item_factor_norm, self.item_factor_norm, transpose_b = True)
        self.item_mask = tf.where(self.item_factor_cos > cosine_threshold, self.item_mask, tf.zeros_like(self.item_mask))
        
        user_item_factor_norm_concat = tf.concat([self.user_factor_norm, self.item_factor_norm], axis=0)
        self.user_item_factor_cos = tf.matmul(user_item_factor_norm_concat, user_item_factor_norm_concat, transpose_b = True)
        self.user_item_mask = tf.where(self.user_item_factor_cos > cosine_threshold, self.user_item_mask, tf.zeros_like(self.user_item_mask))
            
        return weights, biases, user_factor_pp_dense, user_feature_dense, item_factor_pp_dense, item_feature_dense, user_factor_in_ui_dense, item_factor_in_ui_dense

    def _schema(self):
        """Return the model's parameters together with its main output tensors.

        Returns:
            dict with ``weights``, ``biases`` and ``outputs``. ``outputs`` holds
            the attention coefficients of the first user/item GAT head, the
            graph-view factor tables and the combined factor tables.
        """
        return {
            'weights': self.weights,
            'biases': self.biases,
            'outputs': {
                'user_alpha': self.user_in.heads[0].alpha,
                'item_alpha': self.item_in.heads[0].alpha,
                'user_factor_pp': self.user_factor_pp,
                'item_factor_pp': self.item_factor_pp,
                # The per-view weight attributes this used to read
                # (user_weight1/2, item_weight1/2) are never defined, so the
                # call raised AttributeError. Export the combined tables that
                # _params builds and that the prediction actually looks up.
                'user_factor': self.user_global_view,
                'item_factor': self.item_global_view,
            },
        }
    
    def _local_pred(self):
        """Raw local-view score: local factor dot product plus user, item and global bias."""
        local_interaction = tf.reduce_sum(self.user_factor_local*self.item_factor_local, 1)
        return local_interaction + self.user_bias + self.item_bias + self.bias
    
    def _cal_contra_loss_v1(self, local_view, global_view, rou = 0.07):
        """Contrastive loss matching each local row with its own global row.

        Args:
            local_view: matrix with one row per entity.
            global_view: matrix with the same number of rows as ``local_view``.
            rou: temperature dividing the inner products.

        Returns:
            Scalar mean over rows of ``-log(exp(s_ii / rou) / sum_j exp(s_ij / rou))``
            where ``s`` holds the inner products of local rows with global rows.
        """
        # row i of local_view against row i of global_view
        matched_logits = tf.math.reduce_sum(local_view * global_view, keepdims=True, axis=1) / rou
        positives = tf.math.exp(matched_logits)

        # row i of local_view against every row of global_view
        pairwise_logits = tf.matmul(local_view, global_view, transpose_b=True) / rou
        partition = tf.math.reduce_sum(tf.math.exp(pairwise_logits), keepdims=True, axis=1)

        return tf.reduce_mean(-tf.log(positives / partition))
    
    def _global_local_contra_loss(self):
        """Contrast each local view with the global view, for users and for items.

        The "content" local view is the free factor table plus the graph-view
        factors; the "action" local view is the free factor table plus the
        user-item-graph factors. Both are row-normalised and compared with the
        normalised global factors at temperature 0.07.

        Returns:
            Tuple ``(user_content, item_content, user_action, item_action)`` of
            scalar losses.
        """
        temperature = 0.07
        user_table, item_table = self.weights['user_factor'], self.weights['item_factor']

        def unit_rows(table):
            # scale every row to unit L2 norm
            return tf.math.l2_normalize(table, axis=-1)

        user_content_view = unit_rows(user_table + self.user_factor_pp)  # [m, dim]
        user_action_view = unit_rows(user_table + self.user_factor_in_ui_graph)  # [m, dim]
        user_anchor = self.user_factor_norm  # [m, dim]

        item_content_view = unit_rows(item_table + self.item_factor_pp)  # [n, dim]
        item_action_view = unit_rows(item_table + self.item_factor_in_ui_graph)  # [n, dim]
        item_anchor = self.item_factor_norm  # [n, dim]

        contrast = self._cal_contra_loss_v1
        user_content_contra_loss = contrast(user_content_view, user_anchor, temperature)
        item_content_contra_loss = contrast(item_content_view, item_anchor, temperature)
        user_action_contra_loss = contrast(user_action_view, user_anchor, temperature)
        item_action_contra_loss = contrast(item_action_view, item_anchor, temperature)

        return (user_content_contra_loss, item_content_contra_loss,
                user_action_contra_loss, item_action_contra_loss)
    
    def _global_local_mse_loss(self):
        """MSE between projected local views and the projected global view.

        For users and for items, the content view (free table + graph factors),
        the action view (free table + user-item-graph factors) and the global
        view (sum of all three) are each passed through their own dense layer;
        the result is the sum of the four local-vs-global mean squared errors.

        NOTE(review): no code in this file calls this method.
        """
        user_content_local = self.user_content_local_dense(self.weights['user_factor']+self.user_factor_pp, self.activation_out)
        user_action_local = self.user_action_local_dense(self.weights['user_factor']+self.user_factor_in_ui_graph, self.activation_out)
        user_global = self.user_global_dense(self.weights['user_factor']+self.user_factor_pp+self.user_factor_in_ui_graph, self.activation_out)
        
        item_content_local = self.item_content_local_dense(self.weights['item_factor']+self.item_factor_pp, self.activation_out)
        item_action_local = self.item_action_local_dense(self.weights['item_factor']+self.item_factor_in_ui_graph, self.activation_out)
        item_global = self.item_global_dense(self.weights['item_factor']+self.item_factor_pp+self.item_factor_in_ui_graph, self.activation_out)

        user_content_mse_loss = tf.reduce_mean((user_content_local - user_global)**2)
        item_content_mse_loss = tf.reduce_mean((item_content_local - item_global)**2)

        user_action_mse_loss = tf.reduce_mean((user_action_local - user_global)**2)
        item_action_mse_loss = tf.reduce_mean((item_action_local - item_global)**2)
        
        return user_content_mse_loss + item_content_mse_loss + user_action_mse_loss + item_action_mse_loss
        
        

In [None]:
if __name__ == '__main__':
    # Train GMGCL on one dataset, then report top-k ranking metrics on its test split.
    MODEL = GMGCL
    # Seed for the model graph created below (same value as used for the global seeding).
    SEED = 42
    
    # choose dataset
    # DATASET_PATH = '../data/datasets/MovieLens100K' # check data/datasets for options
    # DATASET_PATH = '../data/datasets/Flixster' # check data/datasets for options
    DATASET_PATH = '../data/datasets/MovieLens1M' # check data/datasets for options
    # DATASET_PATH = '../data/datasets/KuaiRandPure' # check data/datasets for options
    
    BATCH_SIZE=512
    DATASET = Dataset.load(DATASET_PATH)
    NAME = '{}_{}'.format(MODEL.__name__, DATASET.name)
    
    # numeric hyperparameters are stored as floats; the model casts them with int()
    config = {'activation_in': 'softsign', 'activation_out': 'sigmoid', 'alpha': 1e-05, 
              'k': 16.0, 'n_head': 5.0, 'rank': 12.0, 'residual': True, 'max_updates': 1000, 'sparse': False}

    #NOTE(yangfei05):
    config['sparse'] = False
    print(f"parameters: {config}")

    print("\n\nTRAIN AND EVAL")
    data = DATASET.data[['user_id', 'item_id', 'rating', 'is_test']]
            
    dataset = Dataset(data, **DATASET.side_info)

    # let TensorFlow allocate GPU memory on demand instead of reserving it all up front
    physical_devices = tf.config.list_physical_devices('GPU') 
    for gpu_instance in physical_devices: 
        tf.config.experimental.set_memory_growth(gpu_instance, True)
    with tf.Graph().as_default():
        # The graph-level seed belongs to a specific graph. The seeding done at
        # import time does not carry over to this freshly created graph, so it
        # has to be set again here for reproducible variable initialisation.
        tf.set_random_seed(SEED)
        with tf.Session() as session:
            # No initializer is run here: the graph has no variables before the
            # model is built, and model.train() initialises them itself.
            model = MODEL(session, dataset, **config)
            print('model:', model)
            model.train(max_updates=config['max_updates'], batch_size=BATCH_SIZE)
            test_pred_df = model.test_infer()
            score = model.get_metrics(test_pred_df)
            print(score)
            display( pd.DataFrame({k:[v] for k,v in score.items()}) )