In [1]:

import functools
from gym import spaces
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
import os
# RecSim imports
from recsim import agent
from recsim import document
from recsim import user
from recsim.choice_model import MultinomialLogitChoiceModel
from recsim.simulator import environment
from recsim.simulator import recsim_gym
from recsim.simulator import runner_lib
from recsim.environments import interest_exploration
from recsim.agent import AbstractEpisodicRecommenderAgent
from recsim import utils

In [2]:
from recsim.environments import interest_exploration
from MusicEnv.Documents import *
from MusicEnv.Listener import *

from Agents.StaticAgent import StaticAgent
from Agents.GreedyClusterAgent import GreedyClusterAgent
from Agents.QLeaningAgent import QLearningAgent

In [3]:
def createEnvironment(env_config, num_genres=3):
    """Assemble the gym-wrapped RecSim environment for the music simulation.

    Args:
        env_config: mapping providing the keys 'seed', 'slate_size',
            'num_candidates' and 'resample_documents'.
        num_genres: number of genres. It is written to
            ``IEDocument.NUM_CLUSTERS`` and used as the length of the
            per-topic arrays handed to the document sampler.

    Returns:
        A ``recsim_gym.RecSimGymEnv`` wrapping the simulator environment,
        using ``clicked_watchtime_reward`` as reward aggregator and the video
        cluster metric helpers from ``recsim.utils``.
    """
    # Class attribute assignment: the value persists after this call returns.
    IEDocument.NUM_CLUSTERS = num_genres

    seed = env_config['seed']
    # Local generator used only to draw the sampler's topic parameters below.
    generator = np.random.default_rng(seed)

    slate_size = env_config['slate_size']
    # NOTE(review): `choice_model` is not imported by name in the cells above;
    # presumably it arrives through one of the star imports -- confirm.
    listener_model = MusicListenerModel(
        slate_size,
        choice_model_ctor=choice_model.MultinomialProportionalChoiceModel,
        user_state_ctor=MusicListenerState,
        response_model_ctor=MusicResponse,
        seed=seed,
    )

    # Draw order matters for reproducibility: distribution first, then means.
    genre_distribution = generator.dirichlet(np.ones(num_genres))
    genre_quality_mean = generator.uniform(0, 1, num_genres)
    genre_quality_stddev = np.ones(num_genres) * 0.1
    sampler = MusicDocumentSampler(
        topic_distribution=genre_distribution,
        topic_quality_mean=genre_quality_mean,
        topic_quality_stddev=genre_quality_stddev,
    )

    raw_env = environment.Environment(
        listener_model,
        sampler,
        env_config['num_candidates'],
        slate_size,
        resample_documents=env_config['resample_documents'],
    )

    return recsim_gym.RecSimGymEnv(
        raw_env,
        clicked_watchtime_reward,
        utils.aggregate_video_cluster_metrics,
        utils.write_video_cluster_metrics,
    )



In [4]:
def create_agent_greedy(sess, environment, eval_mode, summary_writer=None):
    """Agent factory producing a ``GreedyClusterAgent`` for the runner.

    Only ``environment`` is used: its observation and action spaces are
    forwarded to the agent. ``sess``, ``eval_mode`` and ``summary_writer``
    are accepted but ignored. Inside this function the ``environment``
    parameter shadows the ``recsim.simulator.environment`` module import.
    """
    observation_space = environment.observation_space
    action_space = environment.action_space
    return GreedyClusterAgent(observation_space, action_space)

def create_agent_static(sess, environment, eval_mode, summary_writer=None):
    """Agent factory producing a ``StaticAgent`` for the runner.

    Only ``environment`` is used: its observation and action spaces are
    forwarded to the agent. ``sess``, ``eval_mode`` and ``summary_writer``
    are accepted but ignored. Inside this function the ``environment``
    parameter shadows the ``recsim.simulator.environment`` module import.
    """
    observation_space = environment.observation_space
    action_space = environment.action_space
    return StaticAgent(observation_space, action_space)

def create_agent_q(sess, environment, eval_mode, summary_writer=None):
    """Agent factory producing a ``QLearningAgent`` for the runner.

    Only ``environment`` is used: its observation and action spaces are
    forwarded to the agent. ``sess``, ``eval_mode`` and ``summary_writer``
    are accepted but ignored. Inside this function the ``environment``
    parameter shadows the ``recsim.simulator.environment`` module import.
    """
    observation_space = environment.observation_space
    action_space = environment.action_space
    return QLearningAgent(observation_space, action_space)

In [5]:
# Settings read by createEnvironment: slate size, RNG seed, candidate pool
# size and whether documents are resampled.
env_config = {
    'slate_size': 3,
    'seed': 7,
    'num_candidates': 17,
    'resample_documents': True,
}

# Agent factories to train, keyed by the name used for the output directory.
# Uncomment an entry to include that agent in the run.
agents = {
    # 'Q_learning' : create_agent_q,
    # 'static' : create_agent_static,
    'greedy': create_agent_greedy,
}

experiment_name = "TEST_OWANIE"

# Number of genres passed to createEnvironment for every agent.
NUM_GENRES = 6

for agent_name, create_agent in agents.items():
    # Build the output directory with os.path.join instead of hard-coded
    # backslashes: the result is identical on Windows, and on POSIX systems it
    # yields a real nested directory rather than one whose name contains
    # literal backslashes.
    tmp_base_dir = os.path.join('tmp', experiment_name, agent_name)

    # A fresh environment per agent so runs do not share simulator state.
    ie_environment = createEnvironment(env_config, NUM_GENRES)

    runner = runner_lib.TrainRunner(
        checkpoint_frequency=200,
        base_dir=tmp_base_dir,
        create_agent_fn=create_agent,
        env=ie_environment,
        max_training_steps=100,
        max_steps_per_episode=100,
        num_iterations=100000,
    )

    runner.run_experiment()

INFO:tensorflow:max_training_steps = 100, number_iterations = 100000,checkpoint frequency = 200 iterations.
INFO:tensorflow:max_steps_per_episode = 100
INFO:tensorflow:Beginning training...
INFO:tensorflow:Starting iteration 0
INFO:tensorflow:Starting iteration 1
INFO:tensorflow:Starting iteration 2
INFO:tensorflow:Starting iteration 3
INFO:tensorflow:Starting iteration 4
INFO:tensorflow:Starting iteration 5
INFO:tensorflow:Starting iteration 6
INFO:tensorflow:Starting iteration 7
INFO:tensorflow:Starting iteration 8
INFO:tensorflow:Starting iteration 9
INFO:tensorflow:Starting iteration 10
INFO:tensorflow:Starting iteration 11
INFO:tensorflow:Starting iteration 12
INFO:tensorflow:Starting iteration 13
INFO:tensorflow:Starting iteration 14
INFO:tensorflow:Starting iteration 15
INFO:tensorflow:Starting iteration 16
INFO:tensorflow:Starting iteration 17
INFO:tensorflow:Starting iteration 18
INFO:tensorflow:Starting iteration 19
INFO:tensorflow:Starting iteration 20
INFO:tensorflow:Starti

KeyboardInterrupt: 