In [None]:
import copy
import logging
import os

import numpy as np
import tensorflow as tf
from tqdm import tqdm

import config
import discriminator_dif
import generator_dif

## DiffusionGAN

In [None]:
class DiffusionGAN(object):
    """Adversarial (generator vs. discriminator) training of user embeddings from diffusion cascades.

    Each non-blank line of a cascade file is a whitespace-separated sequence of
    ``user,timestamp`` chunks. Users seen in the training file are mapped to dense
    integer indices, and cascades are stored as lists of those indices.
    """

    def __init__(self):
        """Build the user index, load cascades, build both models and initialise a TF session."""
        # NOTE(review): hard-coded; confirm it agrees with config.n_emb used by the models.
        self.emb_dim = 50
        self.train_data = config.small_train_cascades
        self.test_data = config.small_test_cascades
        # NOTE(review): unused -- train() iterates over config.n_epochs instead.
        self.epoch = 20
        self._u2idx = {}
        self._idx2u = []
        self._buildIndex()
        self._train_cascades = self._readFromFile(self.train_data)
        self._test_cascades = self._readFromFile(self.test_data)
        self.train_size = len(self._train_cascades)
        self.test_size = len(self._test_cascades)

        logging.info(
            "training set size:%d    testing set size:%d" % (self.train_size, self.test_size))
        # One shared random initialisation handed to both the generator and the discriminator.
        self.emb_user = np.random.rand(self.user_size, self.emb_dim)

        # [generator, discriminator] embedding matrices, set by write_embeddings_to_file2().
        self.matrix = []

        self.build_generator()
        self.build_discriminator()

        self.latest_checkpoint = tf.train.latest_checkpoint(config.model_log)
        self.saver = tf.train.Saver()
        self.config = tf.ConfigProto()
        self.config.gpu_options.allow_growth = True
        self.init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
        self.sess = tf.Session(config=self.config)
        self.sess.run(self.init_op)

    @staticmethod
    def _iter_cascade_users(filename):
        """Yield, for every non-blank line of a cascade file, the list of user ids on it.

        Raises:
            ValueError: if a chunk does not contain exactly one comma.
        """
        with open(filename) as f:
            for line in f:
                chunks = line.split()
                if not chunks:
                    continue
                users = []
                for chunk in chunks:
                    user, _timestamp = chunk.split(',')
                    users.append(user)
                yield users

    def _buildIndex(self):
        """Assign a dense index to every user that appears in the training cascades.

        Fills self._u2idx / self._idx2u and sets self.user_size.
        """
        train_users = set()
        for users in self._iter_cascade_users(self.train_data):
            train_users.update(users)
        test_users = set()
        for users in self._iter_cascade_users(self.test_data):
            test_users.update(users)

        # Sorted so the user -> index mapping is identical in every process (str set order
        # depends on PYTHONHASHSEED); restored checkpoints rely on a stable mapping.
        for pos, user in enumerate(sorted(train_users)):
            self._u2idx[user] = pos
            self._idx2u.append(user)
        self.user_size = len(train_users)
        logging.info("user size : %d" % self.user_size)
        # Such users are dropped from test cascades by _readFromFile.
        logging.info("test users unseen in training : %d" % len(test_users - train_users))

    def _readFromFile(self, filename):
        """Read cascades from *filename* as lists of user indices.

        Users missing from the training index are dropped. Cascades left with fewer
        than two users are discarded, since they give no (prefix, next user) pair in train().
        """
        t_cascades = []
        for users in self._iter_cascade_users(filename):
            userlist = [self._u2idx[user] for user in users if user in self._u2idx]
            if len(userlist) > 1:
                t_cascades.append(userlist)
        return t_cascades

    def build_generator(self):
        self.generator = generator_dif.Generator(node_emd_init=self.emb_user)

    def build_discriminator(self):
        self.discriminator = discriminator_dif.Discriminator(node_emd_init=self.emb_user)

    def _dump_embeddings(self):
        """Write generator / discriminator embeddings to config.emb_filenames[0] / [1].

        File format: a header line ``<rows>\\t<dim>`` followed by one
        ``<index>\\t<v1>\\t...\\t<vd>`` line per user row.

        Returns:
            [generator_matrix, discriminator_matrix] as fetched from self.sess.
        """
        matrices = []
        for i, model in enumerate((self.generator, self.discriminator)):
            embedding_matrix = self.sess.run(model.embedding_matrix)
            matrices.append(embedding_matrix)
            rows = embedding_matrix.tolist()
            # The header dims come from the matrix itself, so they always match the rows written.
            lines = [str(len(rows)) + "\t" + str(embedding_matrix.shape[1]) + "\n"]
            lines.extend(str(idx) + "\t" + "\t".join(str(x) for x in row) + "\n"
                         for idx, row in enumerate(rows))
            with open(config.emb_filenames[i], "w") as f:
                f.writelines(lines)
        return matrices

    def write_embeddings_to_file(self):
        """Write the current generator and discriminator embeddings to disk."""
        self._dump_embeddings()

    def write_embeddings_to_file2(self):
        """Write embeddings to disk and keep the fetched matrices in self.matrix.

        self.matrix is replaced, not appended to, so it always holds exactly the latest
        [generator, discriminator] pair, even if train() runs more than once.
        """
        self.matrix = self._dump_embeddings()

    @staticmethod
    def _hardest_negative(all_score, U):
        """Return the user outside U with the highest score against any user in U.

        Args:
            all_score: square user-by-user score matrix (see all_score()); not modified.
            U: non-empty list of user indices already in the cascade prefix.

        Returns:
            int column index of the maximum of all_score[U] once the columns in U are
            excluded (ties resolved in row-major order).

        Raises:
            ValueError: if U is empty.
        """
        idx = np.asarray(U, dtype=np.int64)
        # Fancy indexing already returns a copy, so masking never touches all_score.
        candidates = np.asarray(all_score)[idx].astype(np.float64)
        # -inf rather than 0: with negative scores a 0 mask would let a user in U win.
        candidates[:, idx] = -np.inf
        _, neg = np.unravel_index(np.argmax(candidates), candidates.shape)
        return int(neg)

    def prepare_data_for_d(self, all_score, U):
        """Return the negative user index used to train the discriminator for prefix U."""
        return self._hardest_negative(all_score, U)

    def prepare_data_for_g(self, all_score, U):
        """Pick a negative for prefix U and score it with the discriminator.

        Returns:
            (neg, reward): the chosen user index and the value of
            self.discriminator.reward evaluated for it in self.sess.
        """
        neg = self._hardest_negative(all_score, U)
        p_v = self.computePv_dis2(neg, U)
        reward = self.sess.run(self.discriminator.reward,
                               feed_dict={self.discriminator.u_t: [neg],
                                          self.discriminator.U: U,
                                          self.discriminator.score: [p_v]})
        return neg, reward

    def _compute_pv(self, model, u_t, U):
        """Return 1 - prod_{u in U} (1 - 1 / ||emb(u_t) - emb(u)||) using *model*'s embeddings.

        Args:
            model: object exposing u_t / U placeholders and u_t_embedding / U_embedding tensors.
            u_t: list of candidate user indices fed to model.u_t.
            U: list of prefix user indices fed to model.U.
        """
        feed_dict = {model.u_t: u_t, model.U: U}
        u_t_embedding, U_embedding = self.sess.run([model.u_t_embedding, model.U_embedding],
                                                   feed_dict=feed_dict)
        pv = 1.0
        for u in U_embedding:
            # NOTE(review): 1/distance exceeds 1 when distance < 1 (inf at 0), so this is not
            # a true probability -- confirm intended.
            p_uv = 1. / np.sqrt(np.sum(np.square(u_t_embedding - u)))
            pv *= (1 - p_uv)
        return 1 - pv

    def computePv_dis(self, u_t, U):
        """Discriminator-embedding score for a *list* of candidates u_t against prefix U.

        NOTE(review): when u_t holds several users (train passes [neg, true]), their
        squared differences are presumably summed into one distance, giving a single
        score for the whole list -- confirm intended.
        """
        return self._compute_pv(self.discriminator, u_t, U)

    def computePv_dis2(self, u_t, U):
        """Discriminator-embedding score for a single candidate index u_t against prefix U."""
        return self._compute_pv(self.discriminator, [u_t], U)

    def computePv_gen(self, u_t, U):
        """Generator-embedding score for a single candidate index u_t against prefix U."""
        return self._compute_pv(self.generator, [u_t], U)

    def all_score(self):
        """Return the generator's dot-product score matrix E @ E.T with a zeroed diagonal."""
        embedding_matrix = self.sess.run(self.generator.embedding_matrix)
        scores = np.matmul(embedding_matrix, embedding_matrix.T)
        np.fill_diagonal(scores, 0)
        return scores

    def _train_discriminator_epoch(self):
        """Run one d_updates step for every (prefix, next user) pair of every training cascade."""
        labels = [0, 1]  # labels for [negative sample, true next user]
        for cascade in self._train_cascades:
            scores = self.all_score()
            for i in range(1, len(cascade)):
                U = cascade[:i]
                neg = self.prepare_data_for_d(scores, U)
                u = [neg, cascade[i]]
                score = self.computePv_dis(u, U)
                self.sess.run(self.discriminator.d_updates,
                              feed_dict={self.discriminator.u_t: u,
                                         self.discriminator.U: U,
                                         self.discriminator.label: labels,
                                         self.discriminator.score: [score]})

    def _train_generator_epoch(self):
        """Run one g_updates step for every cascade prefix, rewarded by the discriminator."""
        for cascade in self._train_cascades:
            scores = self.all_score()
            for i in range(1, len(cascade)):
                U = cascade[:i]
                neg, reward = self.prepare_data_for_g(scores, U)
                p_v = self.computePv_gen(neg, U)
                self.sess.run(self.generator.g_updates,
                              feed_dict={self.generator.u_t: [neg],
                                         self.generator.U: U,
                                         self.generator.reward: reward,
                                         self.generator.p_v: [p_v]})

    def train(self):
        """Run adversarial training, then write the final embeddings and store them in self.matrix.

        Optionally restores a checkpoint from config.model_log (when config.load_model).
        Each of config.n_epochs epochs may save a checkpoint, then runs
        config.n_epochs_dis discriminator passes and config.n_epochs_gen generator passes.
        """
        checkpoint = tf.train.get_checkpoint_state(config.model_log)
        if checkpoint and checkpoint.model_checkpoint_path and config.load_model:
            print("loading the checkpoint: %s" % checkpoint.model_checkpoint_path)
            self.saver.restore(self.sess, checkpoint.model_checkpoint_path)

        self.write_embeddings_to_file()
        print("start training...")
        for epoch in range(config.n_epochs):
            print("epoch %d" % epoch)
            if epoch > 0 and epoch % config.save_steps == 0:
                # os.path.join: a model_log without a trailing separator still saves inside it,
                # where get_checkpoint_state(config.model_log) looks for it.
                self.saver.save(self.sess, os.path.join(config.model_log, "model.checkpoint"))
            for _ in range(config.n_epochs_dis):
                self._train_discriminator_epoch()
            for _ in range(config.n_epochs_gen):
                self._train_generator_epoch()

        self.write_embeddings_to_file2()
        print("training completes")



In [None]:
# Build the model: reads the cascade files named in config, builds generator and discriminator, opens a TF session.
diffusion_gan = DiffusionGAN()

In [None]:
# Index -> original user id lookup (position i holds the user id mapped to index i).
idx2u = diffusion_gan._idx2u

## Training

In [None]:
# Adversarial training; writes embedding files and fills diffusion_gan.matrix when done.
diffusion_gan.train()

In [None]:
# [generator, discriminator] embeddings recorded at the end of train(). Taking the last
# pair keeps this cell correct even if train() was run more than once.
matrix = diffusion_gan.matrix
assert len(matrix) >= 2, "run diffusion_gan.train() before this cell"
emb_gen, emb_dis = matrix[-2], matrix[-1]
# The discriminator's embeddings serve as the user representation.
user_emb = emb_dis
user_size = diffusion_gan.user_size
test_size = diffusion_gan.test_size
print(emb_gen)
print(emb_dis)