In [1]:
import os

import numpy as np
import pandas as pd

from mars_gym.data import utils

In [2]:
# List the dataset names that mars_gym's data utilities know about.
utils.datasets()

['random',
 'yoochoose',
 'processed_yoochoose',
 'trivago_rio',
 'processed_trivago_rio']

In [3]:
# load_dataset returns two objects, unpacked here as `df` and `df_meta`;
# the cells below preview them with `.head()`.
df, df_meta = utils.load_dataset('processed_trivago_rio')

In [4]:
# Preview the first rows of `df`.
df.head()

Unnamed: 0,session_id,user_id,timestamp,action_type,item_id,impressions,list_reference_item,pos_item_id,clicked
0,05fe82b496fb9,M1Z13DD0P2KH,1541422443,clickout item,4304686,"['109351', '150138', '4345728', '105014', '478...","['', '', '', '', '']",7,1.0
1,05fe82b496fb9,M1Z13DD0P2KH,1541422474,clickout item,960255,"['1475717', '5196406', '104880', '109351', '68...","['4304686', '', '', '', '']",20,1.0
2,05fe82b496fb9,M1Z13DD0P2KH,1541423039,clickout item,2188598,"['104558', '326781', '104786', '1223390', '206...","['4304686', '960255', '', '', '']",9,1.0
3,05fe82b496fb9,M1Z13DD0P2KH,1541424631,clickout item,8459162,"['105014', '5659850', '478121', '109351', '956...","['4304686', '960255', '2188598', '', '']",23,1.0
4,05fe82b496fb9,M1Z13DD0P2KH,1541424685,interaction item info,8459162,,"['4304686', '960255', '2188598', '8459162', '']",-1,0.0


In [5]:
# Preview only the `list_metadata` column of `df_meta` (double brackets keep it a DataFrame).
df_meta[['list_metadata']].head()

Unnamed: 0,list_metadata
0,"[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, ..."
1,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
2,"[0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, ..."
3,"[0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, ..."
4,"[0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, ..."


In [6]:
from mars_gym.data.utils import DownloadDataset
import luigi
class PrepareInteractionData(luigi.Task):
    """Luigi task that rewrites the downloaded interactions file as ``dataset.csv``.

    The task requires the ``processed_trivago_rio`` download and stores its
    result under ``DATASET_DIR``. ``OUTPUT_PATH`` and ``DATASET_DIR`` are
    expected to be defined at module level.
    """

    def requires(self):
        """Declare the dataset download as the single upstream dependency."""
        download_task = DownloadDataset(
            dataset="processed_trivago_rio", output_path=OUTPUT_PATH
        )
        return download_task

    def output(self):
        """Return the local target pointing at ``dataset.csv`` in ``DATASET_DIR``."""
        csv_path = os.path.join(DATASET_DIR, "dataset.csv")
        return luigi.LocalTarget(csv_path)

    def run(self):
        """Read element ``[0]`` of the task input and write it to the output path."""
        os.makedirs(DATASET_DIR, exist_ok=True)

        source_target = self.input()[0]
        interactions = pd.read_csv(source_target.path)

        # .... transform dataset

        # to_csv is left at its default index=True, so the row index is written
        # as the first (unnamed) column of the file.
        destination = self.output().path
        interactions.to_csv(destination)


class PrepareMetaData(luigi.Task):
    """Luigi task that rewrites the downloaded metadata file as ``metadata.csv``.

    The task requires the ``processed_trivago_rio`` download and stores its
    result under ``DATASET_DIR``. ``OUTPUT_PATH`` and ``DATASET_DIR`` are
    expected to be defined at module level.
    """

    def requires(self):
        """Declare the dataset download as the single upstream dependency."""
        return DownloadDataset(dataset="processed_trivago_rio", output_path=OUTPUT_PATH)

    def output(self):
        """Return the local target pointing at ``metadata.csv`` in ``DATASET_DIR``."""
        return luigi.LocalTarget(os.path.join(DATASET_DIR, "metadata.csv",))

    def run(self):
        """Read element ``[1]`` of the task input and write it to the output path."""
        os.makedirs(DATASET_DIR, exist_ok=True)

        df = pd.read_csv(self.input()[1].path)

        # .... transform dataset

        # `index` in DataFrame.to_csv is a boolean flag. The previous value, the
        # string "item_id", was only truthy: it wrote the row index but did NOT
        # label the column. Pass True to keep the same file contents explicitly.
        # If a labelled index column is wanted, use `index_label="item_id"`.
        df.to_csv(self.output().path, index=True)

In [7]:
from mars_gym.data.task import BasePrepareDataFrames

class PrepareTrivagoDataFrame(BasePrepareDataFrames):
    """Connect the two prepared CSV tasks to ``BasePrepareDataFrames``.

    The required tasks are ordered: index 0 is the interactions task and
    index 1 is the metadata task. The path properties rely on that order.
    """

    def requires(self):
        """Return the upstream tasks as ``(interactions, metadata)``."""
        interactions_task = PrepareInteractionData()
        metadata_task = PrepareMetaData()
        return (interactions_task, metadata_task)

    @property
    def timestamp_property(self) -> str:
        """Column name reported as the timestamp property."""
        return "timestamp"

    @property
    def dataset_dir(self) -> str:
        """Directory reported as the dataset location (``DATASET_DIR``)."""
        return DATASET_DIR

    @property
    def read_data_frame_path(self):
        """Path of the first required task's output."""
        interactions_target = self.input()[0]
        return interactions_target.path

    @property
    def metadata_data_frame_path(self):
        """Path of the second required task's output."""
        metadata_target = self.input()[1]
        return metadata_target.path

In [8]:
from samples.trivago_simple.data import PrepareTrivagoDataFrame

In [9]:
import luigi

In [10]:
from mars_gym.data.dataset import InteractionsDataset
from mars_gym.meta_config import *
from samples.trivago_rio import data

# Project configuration object for the Trivago sample. `ProjectConfig`,
# `Column` and `IOType` come from the `mars_gym.meta_config` star import above.
trivago_rio = ProjectConfig(
    base_dir=data.BASE_DIR,
    prepare_data_frames_task=data.PrepareTrivagoDataFrame,
    dataset_class=InteractionsDataset,
    user_column=Column("user_id", IOType.INDEXABLE),
    item_column=Column("item_id", IOType.INDEXABLE),
    # Extra input columns; their names match the `pos_item_id` and
    # `list_reference_item` arguments of the model's `forward`.
    other_input_columns=[
        Column("pos_item_id", IOType.NUMBER),
        # NOTE(review): `same_index_as="item_id"` presumably makes the history
        # items share the item_id index mapping; confirm against mars_gym docs.
        Column("list_reference_item", IOType.INDEXABLE_ARRAY, same_index_as="item_id"),
    ],
    # `list_metadata` is the column previewed from `df_meta` earlier in the notebook.
    metadata_columns=[Column("list_metadata", IOType.INT_ARRAY),],
    output_column=Column("clicked", IOType.NUMBER),
    # `impressions` is a column of the interactions frame previewed earlier.
    available_arms_column_name="impressions"
)

In [11]:
import luigi
from typing import Dict, Any
import torch
import torch.nn as nn
from mars_gym.meta_config import ProjectConfig
from mars_gym.model.abstract import RecommenderModule


class SimpleLinearModel(RecommenderModule):
    """Skeleton of a recommender module, showing the two methods to fill in."""

    def __init__(
        self,
        project_config: ProjectConfig,
        index_mapping: Dict[str, Dict[Any, int]],
    ):
        """Build the model architecture.

        Only the base-class constructor is called here; layers are meant to be
        added after it.
        """
        super().__init__(project_config, index_mapping)
        # ...

    def forward(
        self,
        user_ids: torch.Tensor,
        item_ids: torch.Tensor,
        pos_item_id: torch.Tensor,
        list_reference_item: torch.Tensor,
        list_metadata: torch.Tensor,
    ):
        """Build the forward pass.

        Placeholder: no computation is implemented, so the call returns ``None``.
        """
        return None

In [12]:
class SimpleLinearModel(RecommenderModule):
    """Embedding-plus-MLP model that ends in a sigmoid.

    User, candidate-item and history-item embeddings are concatenated with the
    item position and the metadata vector, then passed through a two-layer
    dense network.
    """

    def __init__(
        self,
        project_config: ProjectConfig,
        index_mapping: Dict[str, Dict[Any, int]],
        n_factors: int,
        metadata_size: int,
        window_hist_size: int,
    ):
        """Create the embedding tables and the dense head.

        Args:
            project_config: forwarded to the base class.
            index_mapping: forwarded to the base class.
            n_factors: embedding width for users and items.
            metadata_size: width of the metadata vector fed to the dense head.
            window_hist_size: number of history items whose embeddings are flattened.
        """
        super().__init__(project_config, index_mapping)

        self.user_embeddings = nn.Embedding(self._n_users, n_factors)
        self.item_embeddings = nn.Embedding(self._n_items, n_factors)

        # Width of each slice of the concatenated input, in the order used by forward().
        user_width = n_factors
        item_width = n_factors
        history_width = window_hist_size * n_factors
        position_width = 1
        num_dense = (
            user_width + item_width + history_width + position_width + metadata_size
        )

        # Layers are created in the same order as they appear in the Sequential.
        hidden_layer = nn.Linear(num_dense, 500)
        activation = nn.SELU()
        output_layer = nn.Linear(500, 1)
        self.dense = nn.Sequential(hidden_layer, activation, output_layer)

    def flatten(self, input: torch.Tensor):
        """Collapse every dimension after the first into a single one."""
        leading = input.size(0)
        return input.view(leading, -1)

    def forward(
        self,
        user_ids: torch.Tensor,
        item_ids: torch.Tensor,
        pos_item_id: torch.Tensor,
        list_reference_item: torch.Tensor,
        list_metadata: torch.Tensor,
    ):
        """Return the sigmoid of the dense head applied to the concatenated features."""
        user_emb = self.user_embeddings(user_ids)
        item_emb = self.item_embeddings(item_ids)
        # History items are looked up in the same table as the candidate item.
        history_items_emb = self.item_embeddings(list_reference_item)

        features = [
            user_emb,
            item_emb,
            self.flatten(history_items_emb),
            # unsqueeze(1) adds the column dimension needed to concatenate on dim=1.
            pos_item_id.float().unsqueeze(1),
            list_metadata.float(),
        ]
        stacked = torch.cat(features, dim=1)

        logits = self.dense(stacked)
        return torch.sigmoid(logits)

In [13]:
from mars_gym.model.bandit import BanditPolicy
from typing import Dict, Any, List, Tuple, Union

class BasePolicy(BanditPolicy):
    """Template for a custom bandit policy.

    Subclasses are expected to override ``_select_idx`` with the actual arm
    selection rule.
    """

    def __init__(self, reward_model: nn.Module, seed: int = 42):
        """
        Initialize bandit information and params

        Args:
            reward_model: forwarded to ``BanditPolicy.__init__``.
            seed: accepted for subclasses that need randomness; unused in this template.
        """
        super().__init__(reward_model)

    def _select_idx(
        self,
        arm_indices: List[int],
        arm_contexts: Tuple[np.ndarray, ...] = None,
        arm_scores: List[float] = None,
        pos: int = 0,
    ) -> Union[int, Tuple[int, float]]:
        """
        Choose the index of arm selected in turn

        Raises:
            NotImplementedError: always; the template has no selection rule.
        """
        # The template used to `return action` with `action` never defined,
        # which failed with a NameError. Fail explicitly instead.
        raise NotImplementedError(
            "BasePolicy is a template: override _select_idx in a subclass"
        )

In [14]:
class EGreedyPolicy(BanditPolicy):
    """Epsilon-greedy policy: explore uniformly with probability ``epsilon``,
    otherwise pick the arm with the highest score."""

    def __init__(self, reward_model: nn.Module, seed: int = 42, epsilon: float = 0.1):
        """Store the exploration rate and create the seeded random generator.

        Args:
            reward_model: forwarded to ``BanditPolicy.__init__``.
            seed: seed for the policy's private random generator.
            epsilon: probability of exploring; must lie in ``[0, 1]``. Placed
                after ``seed`` so existing positional calls keep working.

        Raises:
            ValueError: if ``epsilon`` is outside ``[0, 1]``.
        """
        super().__init__(reward_model)
        if not 0.0 <= epsilon <= 1.0:
            raise ValueError("epsilon must be between 0 and 1, got %r" % (epsilon,))
        # `_select_idx` reads `self._epsilon`; it was never assigned before, and
        # the constructor rejected the `epsilon` keyword.
        self._epsilon = epsilon
        self._rng = np.random.RandomState(seed)

    def _select_idx(
        self,
        arm_indices: List[int],
        arm_contexts: Tuple[np.ndarray, ...] = None,
        arm_scores: List[float] = None,
        pos: int = 0,
    ) -> Union[int, Tuple[int, float]]:
        """Return the position (within ``arm_indices``) of the chosen arm."""
        n_arms = len(arm_indices)
        arm_probas = np.ones(n_arms) / n_arms

        # Draw the explore/exploit decision from the seeded generator.
        explore = self._rng.choice(
            [True, False], p=[self._epsilon, 1.0 - self._epsilon]
        )
        if explore:
            action = self._rng.choice(n_arms, p=arm_probas)
        else:
            action = np.argmax(arm_scores)

        # Both branches yield a numpy integer; return a plain int.
        return int(action)

In [3]:
# NOTE(review): appears to be a scratch cell with no role in the analysis; consider removing.
print(2)

2


In [3]:
from mars_gym.simulation.interaction import InteractionTraining

  "Gym minimally supports python 3.6 as the python foundation not longer supports the version, please update your version to 3.7+"


In [14]:
# Configure an interaction-training job. The project, model and policy are
# given as dotted paths into the `samples.trivago_simple` package.
job_train = InteractionTraining(
    project="samples.trivago_simple.config.trivago_rio",
    recommender_module_class="samples.trivago_simple.simulation.SimpleLinearModel",
    # Keys match the extra constructor arguments of SimpleLinearModel
    # (n_factors, metadata_size, window_hist_size).
    recommender_extra_params={
        "n_factors": 10,
        "metadata_size": 148,
        "window_hist_size": 5,
    },
    bandit_policy_class="samples.trivago_simple.simulation.EGreedyPolicy",
    # Presumably forwarded to the policy class as keyword arguments; confirm.
    bandit_policy_params={
        "epsilon": 0.8,
        "seed": 42
    },
    test_size=0.1,
    obs_batch_size=100,
    num_episodes=1,
)

In [15]:
# Run the configured training job.
# NOTE(review): the recorded output of this cell ends in `KeyError: 'list_metadata'`,
# so the saved run did not complete; the cause is not visible from this notebook.
job_train.run()

DataFrame: env_data_frame,  (8186, 5)
DataFrame: interactions_data_frame,  (10898, 9)


KeyError: 'list_metadata'

In [20]:
# Command-line form of the training run (here with epsilon 0.1).
# NOTE(review): the recorded output shows the shell rejecting 'PYTHONPATH' as a
# command. Presumably the launcher relies on POSIX `VAR=value cmd` syntax, which
# the Windows shell does not support; confirm before relying on this cell.
!mars-gym run interaction --project samples.trivago_simple.config.trivago_rio --recommender-module-class samples.trivago_simple.simulation.SimpleLinearModel --recommender-extra-params '{"n_factors": 10, "metadata_size": 148, "window_hist_size": 5}' --bandit-policy-class samples.trivago_simple.simulation.EGreedyPolicy --bandit-policy-params '{"epsilon": 0.1}' --obs-batch-size 100

'PYTHONPATH' is not recognized as an internal or external command,
operable program or batch file.


In [21]:
import sys

In [22]:
# Show the path of the Python interpreter running this kernel.
sys.executable

'C:\\Users\\Rodrigues\\anaconda3\\envs\\mars-gym\\python.exe'

In [8]:
import pandas as pd

In [17]:
pip install tensorboard

Collecting tensorboard
  Downloading tensorboard-2.10.1-py3-none-any.whl (5.9 MB)
[K     |████████████████████████████████| 5.9 MB 2.1 MB/s eta 0:00:01
Collecting tensorboard-data-server<0.7.0,>=0.6.0
  Using cached tensorboard_data_server-0.6.1-py3-none-manylinux2010_x86_64.whl (4.9 MB)
Collecting tensorboard-plugin-wit>=1.6.0
  Using cached tensorboard_plugin_wit-1.8.1-py3-none-any.whl (781 kB)
Collecting markdown>=2.6.8
  Downloading Markdown-3.3.7-py3-none-any.whl (97 kB)
[K     |████████████████████████████████| 97 kB 5.5 MB/s  eta 0:00:01
[?25hCollecting grpcio>=1.24.3
  Downloading grpcio-1.48.2-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.6 MB)
[K     |████████████████████████████████| 4.6 MB 25.7 MB/s eta 0:00:01
[?25hCollecting google-auth<3,>=1.6.3
  Using cached google_auth-2.12.0-py2.py3-none-any.whl (169 kB)
Collecting google-auth-oauthlib<0.5,>=0.4.1
  Using cached google_auth_oauthlib-0.4.6-py2.py3-none-any.whl (18 kB)
Collecting absl-py>=0.4
  Usin

In [19]:
# Load the TensorBoard notebook extension so the %tensorboard magic is available.
%load_ext tensorboard

In [25]:
# Open TensorBoard inline on the logs under output/tensorboard_logs/.
%tensorboard --logdir output/tensorboard_logs/

Reusing TensorBoard on port 6009 (pid 32656), started 0:01:32 ago. (Use '!kill 32656' to kill it.)

In [13]:
# Load the test-set predictions written by a training run.
# NOTE(review): the directory name embeds run parameters and a hash
# (`epsilon___0_1__aa554203af`), so this path is specific to one run and will
# presumably differ after re-training; update it accordingly.
pd.read_csv('output/interaction/InteractionTraining/results/InteractionTraining____samples_trivago____epsilon___0_1__aa554203af/test_set_predictions.csv')

Unnamed: 0,session_id,user_id,timestamp,action_type,item_id,impressions,list_reference_item,pos_item_id,clicked,sorted_actions,prob_actions,action_scores,trained,item_indexed
0,300bb06f19ed2,FQPF98HITP9Q,1541517266,interaction item info,5722428,"['5722428', '5723646', '1770325', '4041114', '...","['', '', '', '', '']",-1,0.0,"['623276', '5958544', '2878872', '4041114', '1...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.9996925592422485, 0.9707615971565247, 0.893...",0.0,True
1,4624d9af6f3bc,0B0UYI22Q9FE,1541517318,clickout item,5049220,"['95685', '4409390', '104558', '106405', '4779...","['5822734', '104558', '106405', '95685', '95685']",6,1.0,"['104859', '104558', '97505', '106405', '10480...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[3.811391202646064e-10, 2.417341049554267e-10,...",0.0,True
2,9fdb168f4f0a0,78N03IRUO8XM,1541517330,clickout item,3497814,"['3497814', '4342016', '150138', '4415314', '4...","['', '', '', '', '']",0,1.0,"['4345728', '97505', '109339', '5708658', '104...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.8815032839775085, 0.7482463717460632, 0.280...",0.0,True
3,300bb06f19ed2,FQPF98HITP9Q,1541517334,interaction item info,1477057,"['1477057', '2626948', '1477057', '4549880', '...","['5722428', '', '', '', '']",-1,0.0,"['95794', '6861472', '4626720', '3249804', '65...","[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.9978467226028442, 0.992649257183075, 0.9925...",0.0,True
4,8a39981bbe25d,158TP1W9SFHS,1541517362,clickout item,2703198,"['1671277', '4132378', '4043068', '8459162', '...","['8459162', '', '', '', '']",6,1.0,"['5723646', '2858820', '918999', '2788294', '3...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[6.691816906823078e-06, 1.223645085701719e-06,...",0.0,True
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1251,985602f0674a9,DL9645R5R3U9,1541548569,interaction item info,3148916,"['3148916', '3148916', '4549880', '4556616', '...","['104793', '5493194', '6452448', '104790', '17...",-1,0.0,"['2924213', '2924213', '118472', '3249804', '9...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, ...","[0.9917187690734863, 0.9917187690734863, 0.926...",0.0,True
1252,7c4a61b59d85a,HJ06520S4CEM,1541548654,interaction item info,109351,"['109351', '104558', '150138', '1475717', '562...","['104558', '', '', '', '']",-1,0.0,"['3761170', '3881442', '109339', '130772', '16...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.9971530437469482, 0.9901509284973145, 0.973...",0.0,True
1253,7c4a61b59d85a,HJ06520S4CEM,1541548674,interaction item info,109351,"['109351', '153550', '1471427', '153550', '314...","['104558', '109351', '', '', '']",-1,0.0,"['477986', '4626720', '104680', '104780', '109...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.09320629388093948, 0.06786947697401047, 0.0...",0.0,True
1254,985602f0674a9,DL9645R5R3U9,1541548762,interaction item info,829671,"['829671', '4454760', '6794024', '6549868', '2...","['104793', '5493194', '6452448', '104790', '17...",-1,0.0,"['5155978', '5958544', '104558', '2733700', '9...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.9693230390548706, 0.9036438465118408, 0.588...",0.0,True
