In [1]:
# -*- coding:utf-8 -*-
# recommending a list of 20 items out of 40000 items
import numpy as np
import tensorflow as tf
import sys
import time
import random
import copy
from agent import AgentModel
from environment import EnvModel
from discriminator import DisModel
from utils import FLAGS, load_data, build_vocab, gen_batched_data, PAD_ID, UNK_ID, GO_ID, EOS_ID, _START_VOCAB
import os
#**********************************************************************************

In [2]:
# Truncate the agent/environment output files so each run starts from empty files.
for output_flag in ('agn_output_file', 'env_output_file'):
    with open(FLAGS[output_flag].value, "w") as fout:
        pass

# Create the interaction-log and checkpoint directories if they are missing.
for dir_flag in ('interact_data_dir', 'agn_train_dir', 'env_train_dir', 'dis_train_dir'):
    dir_path = FLAGS[dir_flag].value
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)

print(FLAGS)

# Module-level state shared (via `global`) by generate_next_click / generate_data:
#   generate_session - pool of generated sessions (list of lists of interaction dicts)
#   gen_session      - clicks of the session currently being generated
#   gen_rec_list / gen_aims_idx / gen_purchase - per-step records of the current session
#   session_no       - running counter of generated sessions
#   next_session     - True when the most recent step started a new session
generate_session = []
gen_session, gen_rec_list, gen_aims_idx, gen_purchase = [], [], [], []
session_no = 0
next_session = True

# All-zero recurrent state fed to agn_model.lstm_state at the start of a session.
# Nested-list shape is [layers][2][1][units].
# NOTE(review): the "2" presumably stands for the LSTM (c, h) pair - confirm in agent.py.
ini_state = [[[[0.] * FLAGS['units'].value]] * 2] * FLAGS['layers'].value
gen_state = ini_state

def select_action(click, state):
    """Query the agent for a recommendation list given one click and an LSTM state.

    Args:
        click: a single item index; it is reshaped to a [1, 1] batch.
        state: value fed to ``agn_model.lstm_state``.

    Returns:
        A tuple ``(action, new_state)``. ``action`` has shape
        ``[1, 1, action_num + 1]``: the agent's ``random_rec_index`` output followed
        by ``EOS_ID``. ``new_state`` is the agent's ``encoder_state_predict`` output.
    """
    # current_action = [aid2index[item] for item in list(np.random.permutation(vocab[len(_START_VOCAB):])[:FLAGS['action_num'].value])]
    with agn_graph.as_default():
        agent_feed = {
            agn_model.sessions_input: np.reshape(click, [1, 1]),
            agn_model.sessions_length: np.array([1]),
            agn_model.lstm_state: state,
        }
        rec_index, new_state = agn_sess.run(
            [agn_model.random_rec_index, agn_model.encoder_state_predict],
            feed_dict=agent_feed)

    # Append EOS_ID as one extra candidate after the recommended items.
    rec_items = np.reshape(rec_index, [1, 1, FLAGS['action_num'].value])
    eos_item = np.reshape([EOS_ID], [1, 1, 1])
    return np.concatenate([rec_items, eos_item], 2), new_state

def rollout(state, click, rollout_list, rollout_rec_list, rollout_aim, rollout_purchase, length):
    """Recursively extend a session by alternating agent and environment steps.

    Each call appends ``click`` to ``rollout_list``, asks the agent for a
    recommendation list, asks the environment model which entry is clicked and
    whether a purchase happens, and records the results. The lists passed in are
    mutated by this call; deeper recursion levels work on copies.

    Args:
        state: value fed to ``agn_model.lstm_state`` for this step.
        click: item index clicked at this step.
        rollout_list: clicks so far.
        rollout_rec_list: recommendation lists so far.
        rollout_aim: indices (into each recommendation list) chosen by the environment.
        rollout_purchase: 0/1 purchase flags so far.
        length: stop once ``rollout_list`` holds at least this many clicks.

    Returns:
        ``(rollout_list, rollout_rec_list, rollout_aim, rollout_purchase)`` as they
        stand at the recursion level where the rollout stopped.
    """
    rollout_list.append(click)

    # Agent step: next recurrent state and sampled recommendation list (+ EOS_ID).
    with agn_graph.as_default():
        agent_feed = {
            agn_model.sessions_input: np.reshape(click, [1, 1]),
            agn_model.sessions_length: [1],
            agn_model.lstm_state: state,
        }
        next_state, rec_index = agn_sess.run(
            [agn_model.encoder_state_predict, agn_model.random_rec_index],
            feed_dict=agent_feed)
        rec_items = np.reshape(rec_index, [1, 1, FLAGS['action_num'].value])
        eos_item = np.reshape([EOS_ID], [1, 1, 1])
        action = np.concatenate([rec_items, eos_item], 2)
        rollout_rec_list.append(action[0, 0, :])

    # Environment step: the whole click history is fed together with the latest list.
    with env_graph.as_default():
        rec_list = np.reshape(rollout_rec_list[-1], [1, 1, -1])
        env_feed = {
            env_model.sessions_input: np.reshape(rollout_list, [1, -1]),
            env_model.rec_lists: rec_list,
            env_model.rec_mask: np.ones_like(rec_list),
            env_model.sessions_length: [len(rollout_list)],
        }
        click_index, purchase_prob = env_sess.run(
            [env_model.inf_random_index, env_model.inf_purchase_prob],
            feed_dict=env_feed)

        chosen = click_index[0, -1, 0]
        next_click = rec_list[0, 0, chosen]
        # NOTE(review): the purchase probability is read at step 0 while the click
        # index is read at step -1 - confirm this is intended.
        rollout_purchase.append(1 if purchase_prob[0, 0, 1] > 0.5 else 0)
        rollout_aim.append(chosen)

    # Keep rolling while the session is short and the current click is not 3
    # (presumably EOS_ID - TODO confirm against utils.py).
    if len(rollout_list) < length and click != 3:
        return rollout(next_state, next_click, list(rollout_list), list(rollout_rec_list),
                       list(rollout_aim), list(rollout_purchase), length)
    return rollout_list, rollout_rec_list, rollout_aim, rollout_purchase


def generate_next_click(current_click, flog, use_dis=FLAGS['use_dis'].value):
    """Advance the simulated session by one step and log it to ``flog``.

    If the current session is full (``max_interact_len`` clicks) or ``current_click``
    is 3, a new session is started from a click sampled from the start-click
    distribution; otherwise ``current_click`` is appended to the running session.
    The agent then proposes a recommendation list, the environment model picks the
    next click and a purchase probability, and (optionally) the discriminator scores
    rollouts of the continued session.

    Args:
        current_click: item index the simulated user clicked last.
        flog: writable text file object that receives the step log. ``None`` makes
            ``print`` fall back to ``sys.stdout``.
        use_dis: when true, compute ``dis_reward`` as the mean discriminator
            probability over rollouts; otherwise ``dis_reward`` stays 1.0.

    Returns:
        ``(current_click, next_click, action, purchase_prob, dis_reward)`` where
        ``action`` is the recommendation list as a plain Python list (ending with
        ``EOS_ID``).

    Side effects:
        Updates the module-level ``gen_*`` session buffers, ``session_no`` and
        ``next_session``.
    """
    global gen_session, gen_rec_list, gen_aims_idx, gen_state, gen_purchase, session_no, next_session

    # NOTE(review): the literal 3 is presumably EOS_ID - confirm against utils.py.
    if len(gen_session) >= max_interact_len or current_click == 3:
        # Start a fresh session from a sampled start click and a zero state.
        gen_session = [np.random.choice(sort_start_click, p=sort_start_click_prob)]
        gen_rec_list, gen_aims_idx, gen_purchase = [], [], []
        gen_state = ini_state
        session_no += 1
        next_session = True
        current_click = gen_session[-1]
        # Fix: the log lines used to be print(flog, ...), which printed the file
        # object's repr to stdout and never wrote to the log file.
        print("------------next session:%d------------" % (session_no), file=flog)
    else:
        gen_session.append(current_click)
        next_session = False
    session_click = np.reshape(np.array(gen_session), [1, len(gen_session)])
    action, state = select_action(session_click[0,-1], gen_state)
    print("current_click:", current_click, file=flog)
    gen_state = state

    with env_graph.as_default():
        #[1, 1, 10]
        output = env_sess.run([env_model.inf_random_index, env_model.inf_purchase_prob], feed_dict={
            env_model.sessions_input:session_click, 
            env_model.rec_lists:action, 
            env_model.rec_mask:np.ones_like(action),
            env_model.sessions_length:[len(session_click[0])]})
        # The environment returns an index into the recommendation list.
        next_click = action[0, 0, output[0][0, -1, 0]]
        # NOTE(review): purchase probability is read at step 0 while the click index
        # is read at step -1 - confirm this is intended.
        purchase_prob = output[1][0, 0, 1]
        print("next_click:", next_click, "purchase_prob:", purchase_prob, "reward:", 4 if purchase_prob > 0.5 else 1, file=flog)
        gen_rec_list.append(list(action[0,0,:]))
        gen_aims_idx.append(output[0][0,-1,0])
        gen_purchase.append(1 if purchase_prob > 0.5 else 0)
    dis_reward = 1.

    if use_dis:
        with dis_graph.as_default():
            score = []
            # Several rollouts while the session can still continue; a single one
            # when it is about to end.
            rollout_num = 5 if (len(gen_session) < max_interact_len) and (next_click != 3) else 1
            for _ in range(rollout_num):
                # Copies are passed so the rollout does not touch the live buffers.
                tmp_total_click, tmp_total_rec_list, tmp_total_aims_idx, tmp_total_purchase = rollout(gen_state,next_click,list(gen_session), list(gen_rec_list), list(gen_aims_idx), list(gen_purchase), max_interact_len+1)
                prob = dis_sess.run(dis_model.prob, {
                    dis_model.sessions_input:np.reshape(tmp_total_click, [1, -1]),
                    dis_model.sessions_length:np.array([len(tmp_total_click)]),
                    dis_model.rec_lists:np.array([tmp_total_rec_list]),
                    dis_model.rec_mask:np.ones([1,len(tmp_total_click),len(tmp_total_rec_list[-1])]),
                    dis_model.aims_idx:np.reshape(tmp_total_aims_idx, [1, len(tmp_total_click)]),
                    dis_model.purchase:np.reshape(tmp_total_purchase, [1, len(tmp_total_purchase)])
                    })
                score.append(prob[0])
            dis_reward = np.mean(score)
        print("dis_reward:%.8f" % dis_reward, file=flog)

    action = list(action[0,0,:])
    print("action:", action, file=flog)
    return current_click, next_click, action, purchase_prob, dis_reward

def generate_data(size, flog, use_dis=FLAGS['use_dis'].value):
    """Generate interactions until ``size`` new sessions have been started.

    Every step from ``generate_next_click`` becomes a dict with keys
    ``session_no``, ``click``, ``rec_list``, ``purchase`` and ``dis_reward`` and is
    stored in the module-level ``generate_session`` pool (one inner list per
    session). When a new session begins, the previous session's ``rec_list`` and
    ``purchase`` values are shifted one step later, and its first step is given
    ``rec_list = [its own click]`` and ``purchase = 0``.

    Args:
        size: number of new sessions to start.
        flog: log file object forwarded to ``generate_next_click``.
        use_dis: forwarded to ``generate_next_click``.

    Side effects:
        Rebinds/mutates ``generate_session``, ``current_click``, ``session_no``
        (through ``generate_next_click``) and ``next_session``. The pool is trimmed
        to its most recent ``FLAGS['pool_size']`` sessions.
    """
    global generate_session, current_click, session_no, next_session
    stop_at = session_no + size
    current_click = np.random.choice(sort_start_click, p=sort_start_click_prob)
    while session_no < stop_at:
        current_click, next_click, current_action, purchase_prob, dis_reward = generate_next_click(current_click, flog, use_dis=use_dis)
        record = {
            "session_no": session_no,
            "click": current_click,
            "rec_list": current_action,
            "purchase": (0 if purchase_prob<=0.5 else 1),
            "dis_reward": dis_reward,
        }
        if len(generate_session) > 0 and not next_session:
            # Same session as the previous step: extend it.
            generate_session[-1].append(record)
        else:
            if len(generate_session) > 0:
                # Close the previous session: move each step's rec_list/purchase one
                # position later, walking backwards so nothing is overwritten early.
                finished = generate_session[-1]
                for pos in range(len(finished) - 1, 0, -1):
                    finished[pos]["rec_list"] = finished[pos - 1]["rec_list"]
                    finished[pos]["purchase"] = finished[pos - 1]["purchase"]
                finished[0]["rec_list"] = [finished[0]["click"]]
                finished[0]["purchase"] = 0

            generate_session.append([record])
        current_click = next_click
    next_session = True
    pool_size = FLAGS['pool_size'].value
    if len(generate_session) > pool_size:
        generate_session = generate_session[-pool_size:]

#**********************************************************************************
#**********************************************************************************
#**********************************************************************************


absl.app:
  --[no]only_check_args: Set to true to validate args and exit.
    (default: 'false')
  --[no]pdb: Alias for --pdb_post_mortem.
    (default: 'false')
  --[no]pdb_post_mortem: Set to true to handle uncaught exceptions with PDB post
    mortem.
    (default: 'false')
  --profile_file: Dump profile information to a file (for python -m pstats).
    Implies --run_with_profiling.
  --[no]run_with_pdb: Set to true for PDB debug mode
    (default: 'false')
  --[no]run_with_profiling: Set to true for profiling the script. Execution will
    be slower, and the output format might change over time.
    (default: 'false')
  --[no]use_cprofile_for_profiling: Use cProfile instead of the profile module
    for profiling. This has no effect unless --run_with_profiling is set.
    (default: 'true')

absl.logging:
  --[no]alsologtostderr: also log to stderr?
    (default: 'false')
  --log_dir: directory to write logfiles into
    (default: '')
  --logger_levels: Specify log level of loggers

In [3]:
# device_count caps how many devices of a type TensorFlow may use (here: at most
# 5 GPUs); it does not select specific GPU ids.
config = tf.ConfigProto(device_count={'GPU':5})
# config.gpu_options.allow_growth = True # dynamically grow the memory used on the gpu
tf.reset_default_graph()
# One graph and one session per model so their variables and checkpoints stay separate.
env_graph = tf.Graph()
agn_graph = tf.Graph()
dis_graph = tf.Graph()
env_sess = tf.Session(config=config, graph=env_graph)
agn_sess = tf.Session(config=config, graph=agn_graph)
dis_sess = tf.Session(config=config, graph=dis_graph)

data = load_data(FLAGS['data_dir'].value, FLAGS['data_name'].value) # load data from the directory/file configured in utils.py
# Shuffle the sessions through an index permutation. Passing the ragged list of
# sessions straight to np.random.permutation forces a conversion to an object
# array, which is deprecated and an error on newer NumPy (this is likely the
# warning previously shown in this cell's output).
shuffled_order = np.random.permutation(len(data))
data = [data[idx] for idx in shuffled_order]

# The generated-session length limit is twice the (truncated) mean session length.
mean_session_len = np.mean([len(s) for s in data])
max_interact_len = 2 * int(mean_session_len)
print("Average length of the dataset:", mean_session_len, "max_interact_len:", max_interact_len)
fold = len(data) // 40 # split into 40 folds
data_train = data[:(fold * 38)] # 38 folds for training
data_dev = data[(fold * 38):(fold * 39)] # one fold for validation
data_test = data[(fold * 39):] # the remainder for testing

# NOTE(review): presumably vocab is the item vocabulary and embed its embedding
# matrix - confirm against build_vocab in utils.py.
vocab, embed = build_vocab(data_train)
aid2index = {} # item id -> vocabulary index
index2aid = {} # vocabulary index -> item id
for i,a in enumerate(vocab):
    aid2index[a] = i
    index2aid[i] = a

if FLAGS['use_simulated_data'].value: # simulated data: always start from the first session's first click
    sort_start_click = [data[0][0]["click"]]
    sort_start_click_prob = [1.]
else:
    start_click = {} # vocabulary index -> number of sessions starting with that item
    for d in data:
        first_click = d[0]["click"]
        if first_click not in aid2index: # skip sessions whose first item is out of vocabulary
            continue
        k = aid2index[first_click]
        start_click[k] = start_click.get(k, 0) + 1
    # Start items ordered from most to least frequent ...
    sort_start_click = sorted(start_click, key=start_click.get, reverse=True)
    # ... and their empirical probabilities, used to sample session starts.
    start_counts = np.array([start_click[item] for item in sort_start_click])
    sort_start_click_prob = start_counts / float(np.sum(start_counts))

def filter(d):
    """Return a copy of ``d`` with item ids mapped to vocabulary indices.

    For every interaction, ``click`` is replaced by its index in ``aid2index``
    (``UNK_ID`` when unknown), and ``rec_list`` is replaced by the de-duplicated
    indices of its items followed by ``EOS_ID``.

    Unlike the previous version, the input is left untouched: each interaction is
    deep-copied before it is rewritten. Previously the caller's dicts were mutated
    in place and only the result was copied afterwards, so the raw data shared with
    other variables was silently overwritten.

    NOTE: this function shadows the built-in ``filter``; the name is kept because
    other cells call it.

    Args:
        d: iterable of sessions; each session is an iterable of dicts that have
            at least the keys ``"click"`` and ``"rec_list"``.

    Returns:
        A new list of sessions (lists of new dicts).
    """
    new_d = []
    for s in d:  # each session
        tmps = []
        for c in s:  # each interaction
            mapped = copy.deepcopy(c)  # private copy; the caller's dict stays intact
            mapped["click"] = aid2index.get(mapped["click"], UNK_ID)
            # set() removes duplicates (e.g. several unknown items collapsing to UNK_ID)
            mapped["rec_list"] = list(set([aid2index.get(rl, UNK_ID) for rl in mapped["rec_list"]])) + [EOS_ID]
            tmps.append(mapped)
        new_d.append(tmps)
    return new_d

Reading data from: ./data/train.csv
Number of sessions after filtering: 68722


  return array(a, dtype, copy=False, order=order)


Average length of the dataset: 2.8058700270655685 max_interact_len: 4
Creating vocabulary...


In [4]:
# Map item ids to vocabulary indices in each split.
data_train = filter(data_train)
print("filtered train")
data_dev = filter(data_dev)
print("filtered dev")
data_test = filter(data_test)
print("filtered test")
# An interaction looks like this after filtering:
# {'click': 12, 'rec_list': [1, 12, 13, 14, 15, 16, 3], 'purchase': 0, 'dis_reward': 1.0}

# Report each split's size and mean number of interactions per session.
for report_template, split in (
    ("Get training data: number is %d, average length of session is %.4f", data_train),
    ("Get validation data: number is %d, average length of session is %.4f", data_dev),
    ("Get testing data: number is %d, average length of session is %.4f", data_test),
):
    print(report_template % (len(split), np.mean([len(s) for s in split])))

filtered train
filtered dev
filtered test
Get training data: number is 65284, average length of session is 2.8043
Get validation data: number is 1718, average length of session is 2.8597
Get testing data: number is 1720, average length of session is 2.8122


In [5]:
# Constructor arguments shared by the agent, environment and discriminator models.
shared_model_kwargs = dict(
    num_items=len(embed),
    num_embed_units=FLAGS['embed_units'].value,
    num_units=FLAGS['units'].value,
    num_layers=FLAGS['layers'].value,
    embed=embed)

# Agent: build in its own graph, then restore the latest checkpoint or initialise.
with agn_graph.as_default():
    agn_model = AgentModel(action_num=FLAGS['action_num'].value, **shared_model_kwargs)
    agn_model.print_parameters()
    agn_ckpt_dir = FLAGS['agn_train_dir'].value
    if tf.train.get_checkpoint_state(agn_ckpt_dir):
        print("Reading agent model parameters from %s" % agn_ckpt_dir)
        agn_model.saver.restore(agn_sess, tf.train.latest_checkpoint(agn_ckpt_dir))
    else:
        print("Creating agent model with fresh parameters.")
        agn_sess.run(tf.global_variables_initializer())

if FLAGS['interact'].value: # the environment (and discriminator) are only built when this flag is set
    with env_graph.as_default():
        env_model = EnvModel(vocab=vocab, **shared_model_kwargs)
        env_model.print_parameters()
        env_ckpt_dir = FLAGS['env_train_dir'].value
        if tf.train.get_checkpoint_state(env_ckpt_dir):
            print("Reading environment model parameters from %s" % env_ckpt_dir)
            env_model.saver.restore(env_sess, tf.train.latest_checkpoint(env_ckpt_dir))
        else:
            print("Creating environment model with fresh parameters.")
            env_sess.run(tf.global_variables_initializer())

    if FLAGS['use_dis'].value:
        with dis_graph.as_default():
            dis_model = DisModel(vocab=vocab, **shared_model_kwargs)
            dis_model.print_parameters()
            dis_ckpt_dir = FLAGS['dis_train_dir'].value
            if tf.train.get_checkpoint_state(dis_ckpt_dir):
                print("Reading discriminator model parameters from %s" % dis_ckpt_dir)
                dis_model.saver.restore(dis_sess, tf.train.latest_checkpoint(dis_ckpt_dir))
            else:
                print("Creating discriminator model with fresh parameters.")
                dis_sess.run(tf.global_variables_initializer())



Instructions for updating:
This class is equivalent as tf.keras.layers.LSTMCell, and will be replaced by that in Tensorflow 2.0.
Instructions for updating:
This class is equivalent as tf.keras.layers.StackedRNNCells, and will be replaced by that in Tensorflow 2.0.

Instructions for updating:
Please use `keras.layers.RNN(cell)`, which is equivalent to this API
Instructions for updating:
Please use `layer.add_weight` method instead.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where

The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please fi

In [6]:
# Best dev-set accuracies seen so far by env_train (persist across calls).
best_env_train_acc, best_env_train_acc_1 = 0., 0.
def env_train(size, pg=False):
    """Train the environment model for ``size`` epochs.

    Args:
        size: number of epochs to run.
        pg: when true, train with ``env_model.pg_train`` on the generated session
            pool; otherwise train with ``env_model.train`` on ``data_train``.

    Each epoch decays the learning rate if the loss exceeds the worst of the last
    three losses, evaluates on ``data_dev``, and - when either dev accuracy beats
    its best value so far - evaluates on ``data_test`` and saves a checkpoint.
    """
    global best_env_train_acc, best_env_train_acc_1
    print("In env_train..")
    pre_losses = [1e18] * 3  # sliding window of the three most recent losses
    for epoch_idx in range(size):
        print("Epoch = "+ str(epoch_idx))
        with env_graph.as_default():
            start_time = time.time()
            if pg:
                loss = env_model.pg_train(env_sess, generate_session)
                pr_loss, pu_loss = 0, 0
            else:
                loss, pr_loss, pu_loss, _unused_acc, _unused_acc_1 = env_model.train(env_sess, data_train)
            if loss > max(pre_losses):  # Learning rate decay
                env_sess.run(env_model.learning_rate_decay_op)
            pre_losses = pre_losses[1:] + [loss]
            train_report = (env_model.epoch.eval(session=env_sess), env_model.learning_rate.eval(session=env_sess), time.time() - start_time, loss, pr_loss, pu_loss)
            print("Env epoch %d lr %.4f time %.4f ppl [%.8f] pr_loss [%.8f] pu_loss [%.8f]" % train_report)

            loss, pr_loss, pu_loss, acc, acc_1 = env_model.train(env_sess, data_dev, is_train=False)
            # The value printed as best_p@k is the best accuracy *before* this epoch's update.
            print("        dev_set, ppl [%.8f] pr_loss [%.8f] pu_loss [%.8f] best_p@%d [%.4f]" % (loss, pr_loss, pu_loss, FLAGS['metric'].value, best_env_train_acc))

            beats_acc = acc > best_env_train_acc
            beats_acc_1 = acc_1 > best_env_train_acc_1
            if beats_acc or beats_acc_1:
                if beats_acc:
                    best_env_train_acc = acc
                if beats_acc_1:
                    best_env_train_acc_1 = acc_1
                loss, pr_loss, pu_loss, _unused_acc, _unused_acc_1 = env_model.train(env_sess, data_test, is_train=False)
                print("        test_set, ppl [%.8f] pr_loss [%.8f] pu_loss [%.8f] best_p@%d [%.4f]" % (loss, pr_loss, pu_loss, FLAGS['metric'].value, best_env_train_acc))
                env_model.saver.save(env_sess, '%s/checkpoint' % FLAGS['env_train_dir'].value, global_step=env_model.global_step)
                print("Saving env model params in %s" % FLAGS['env_train_dir'].value)
            print("------env %strain finish-------"%("pg " if pg else ""))

# Best dev-set accuracies seen so far by agn_train (persist across calls).
best_agn_train_acc, best_agn_train_acc_1 = 0., 0.
def agn_train(size):
    """Train the agent model for ``size`` epochs.

    Each epoch trains on ``data_train`` together with the generated session pool,
    evaluates on ``data_dev``, and - when either dev accuracy beats its best value
    so far - evaluates on ``data_test`` and saves a checkpoint.

    Args:
        size: number of epochs to run.
    """
    global best_agn_train_acc, best_agn_train_acc_1
    print("In agn_train..")
    for epoch_idx in range(size):
        print("Epoch = "+ str(epoch_idx))
        with agn_graph.as_default():
            start_time = time.time()
            loss, acc, acc_1 = agn_model.train(agn_sess, data_train, generate_session)
            train_report = (agn_model.epoch.eval(session=agn_sess), agn_model.learning_rate.eval(session=agn_sess), time.time() - start_time, loss, FLAGS['metric'].value, acc*100, acc_1*100)
            print("Agn epoch %d learning rate %.4f epoch-time %.4f loss [%.8f] p@%d %.4f%% p@1 %.4f%%" % train_report)

            loss, acc, acc_1 = agn_model.train(agn_sess, data_dev, is_train=False)
            print("        dev_set, loss [%.8f] p@%d %.4f%% p@1 %.4f%%" % (loss, FLAGS['metric'].value, acc*100, acc_1*100))

            beats_acc = acc > best_agn_train_acc
            beats_acc_1 = acc_1 > best_agn_train_acc_1
            if beats_acc or beats_acc_1:
                if beats_acc:
                    best_agn_train_acc = acc
                if beats_acc_1:
                    best_agn_train_acc_1 = acc_1
                loss, acc, acc_1 = agn_model.train(agn_sess, data_test, is_train=False)
                print("        test_set, loss [%.8f] p@%d %.4f%% p@1 %.4f%%" % (loss, FLAGS['metric'].value, acc*100, acc_1*100))
                agn_model.saver.save(agn_sess, '%s/checkpoint' % FLAGS['agn_train_dir'].value, global_step=agn_model.global_step)
                print("Saving agn model params in %s" % FLAGS['agn_train_dir'].value)
            print("------agn train finish-------")

def dis_train(size):
    """Train the discriminator for ``size`` epochs and checkpoint after each one.

    The generated session pool is shuffled once, then every epoch trains
    ``dis_model`` on ``data_train`` and the shuffled generated sessions.

    Args:
        size: number of epochs to run.
    """
    # Shuffle through an index permutation. Passing the list of variable-length
    # sessions directly to np.random.permutation forces a ragged object-array
    # conversion, which is deprecated and an error on newer NumPy.
    shuffled_order = np.random.permutation(len(generate_session))
    random_generate_session = [generate_session[idx] for idx in shuffled_order]
    for epoch_idx in range(size):
        print("In dis_train..")
        print("Epoch = "+ str(epoch_idx))
        with dis_graph.as_default():
            start_time = time.time()
            loss, acc = dis_model.train(data_train, random_generate_session, sess=dis_sess)
            # Unlike env_train/agn_train, a checkpoint is written every epoch.
            dis_model.saver.save(dis_sess, '%s/checkpoint' % FLAGS['dis_train_dir'].value, global_step=dis_model.global_step)
            print("Dis epoch %d learning rate %.4f epoch-time %.4f perplexity [%.8f] acc %.4f%%" \
                    % (dis_model.epoch.eval(session=dis_sess), dis_model.learning_rate.eval(session=dis_sess), time.time() - start_time, loss, acc*100))
            print("------dis train finish-------")

def interact(size, use_dis=FLAGS['use_dis'].value):
    """Generate ``size`` new sessions and report how long it took.

    A log file named after the current ``session_no`` is opened in
    ``FLAGS['interact_data_dir']`` and handed to ``generate_data``.

    Args:
        size: number of sessions to generate.
        use_dis: forwarded to ``generate_data``.
    """
    start_time = time.time()
    log_path = "%s/train_log_%d.txt"%(FLAGS['interact_data_dir'].value, session_no)
    with open(log_path, "w") as flog:
        generate_data(size, flog, use_dis=use_dis)
    elapsed = time.time()-start_time
    print("%d interactions finished after %.4fs " % (size, elapsed))

In [None]:
# Pretraining. Every stage below runs for a single epoch; raise the arguments
# for a longer pretraining run.
agn_train(1) # pretrain the agent for 1 epoch
if FLAGS['interact'].value:
    env_train(1) # pretrain the environment model for 1 epoch
    interact(1000, use_dis=False) # generate 1000 sessions without discriminator rewards
    if FLAGS['use_dis'].value:
        dis_train(1) # pretrain the discriminator for 1 epoch

# Adversarial training: runs until the kernel is interrupted.
generate_session = [] # start the adversarial phase with an empty generated-session pool
pre_losses = [1e18] * 3 # NOTE(review): not read in this cell (env_train keeps its own copy) - kept in case later code uses it
while True:
    # The single-iteration loops are placeholders for the number of agent and
    # environment updates per round.
    for _ in range(1):
        for _ in range(1):
            if FLAGS['interact'].value:
                interact(200)
            agn_train(1)

        if FLAGS['use_dis'].value:
            env_train(1, pg=True)
            env_train(1)
    if FLAGS['use_dis'].value:
        # Fix: dis_train requires the number of epochs; calling it with no
        # argument raised TypeError on the first adversarial round.
        dis_train(1)
    print("*"*25)

In agn_train..
Epoch = 0
Get training data:len(dataset) is 65284 
before while in train..
ed = 0
ed = 32
ed = 64
ed = 96
ed = 128
ed = 160
ed = 192
ed = 224
ed = 256
ed = 288
ed = 320
ed = 352
ed = 384
ed = 416
ed = 448
ed = 480
ed = 512
ed = 544
ed = 576
ed = 608
ed = 640
ed = 672
ed = 704
ed = 736
ed = 768
ed = 800
ed = 832
ed = 864
ed = 896
ed = 928
ed = 960
ed = 992
ed = 1024
ed = 1056
ed = 1088
ed = 1120
ed = 1152
ed = 1184
ed = 1216
ed = 1248
ed = 1280
ed = 1312
ed = 1344
ed = 1376
ed = 1408
ed = 1440
ed = 1472
ed = 1504
ed = 1536
ed = 1568
ed = 1600
ed = 1632
ed = 1664
ed = 1696
ed = 1728
ed = 1760
ed = 1792
ed = 1824
ed = 1856
ed = 1888
ed = 1920
ed = 1952
ed = 1984
ed = 2016
ed = 2048
ed = 2080
ed = 2112
ed = 2144
ed = 2176
ed = 2208
ed = 2240
ed = 2272
ed = 2304
ed = 2336
ed = 2368
ed = 2400
ed = 2432
ed = 2464
ed = 2496
ed = 2528
ed = 2560
ed = 2592
ed = 2624
ed = 2656
ed = 2688
ed = 2720
ed = 2752
ed = 2784
ed = 2816
ed = 2848
ed = 2880
ed = 2912
ed = 2944
ed = 2976
ed = 30