In [None]:
%load_ext autoreload

In [None]:
%autoreload 2

# Numerical stack and general-purpose third-party libraries.
import jax
import jax.numpy as jnp
import flax.linen as nn
from tqdm import tqdm
import polars as pl
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
# Use matplotlib's "tableau-colorblind10" style sheet for every figure below.
plt.style.use('tableau-colorblind10')

import mlflow
from dotenv import load_dotenv
# Load environment variables from the .env file one directory up.
# NOTE(review): presumably this supplies the MLflow connection settings — confirm.
load_dotenv("../.env")

import sys
# Put the parent directory on the import path so the `herec` package below resolves.
sys.path.append("..")
# Wildcard imports from the project package.
# NOTE(review): the dataset readers, restoreModelParams, restoreHyperParams and
# getModel used later in this notebook presumably come from these modules — confirm
# which module provides each name.
from herec.utils import *
from herec.loader import *
from herec.reader import *
from herec.trainer import *
from herec.model import *

In [None]:
def restoreDataAndPrediction(datasetName, modelName, seed, k=1000):
    """Reload the test data and a trained model, then rank items for each evaluation user.

    Parameters
    ----------
    datasetName : str
        One of "Twitch100K", "ML100K_IMPLICIT" or "ML1M_IMPLICIT".
    modelName : str
        Model identifier. It is combined with ``datasetName`` into the MLflow
        experiment name ``"{datasetName}-{modelName}-TEST"`` and passed to ``getModel``.
    seed : int
        Passed to ``reader.get(seed, "test")`` and used to select the MLflow run
        whose ``params.seed`` equals it.
    k : int, optional
        Number of highest-scored items kept per user (default 1000, the value
        that used to be hard-coded).

    Returns
    -------
    tuple
        ``(DATA, topk_indices)``: the object returned by ``reader.get(seed, "test")``
        and the index array returned by ``jax.lax.top_k`` on the predicted scores.

    Raises
    ------
    ValueError
        If ``datasetName`` is not one of the supported names.
    LookupError
        If the experiment contains no run logged with the requested seed.
    """

    # Resolve the reader first so an unsupported name fails here with a clear
    # message instead of an UnboundLocalError when `reader` is used below.
    if datasetName == "Twitch100K":
        reader = Twitch100K()
    elif datasetName == "ML100K_IMPLICIT":
        reader = ML100K_IMPLICIT()
    elif datasetName == "ML1M_IMPLICIT":
        reader = ML1M_IMPLICIT()
    else:
        raise ValueError(
            f"Unknown datasetName {datasetName!r}; expected one of "
            "'Twitch100K', 'ML100K_IMPLICIT', 'ML1M_IMPLICIT'"
        )

    experimentName = f"{datasetName}-{modelName}-TEST"
    print( experimentName, seed )

    # Read the test data
    DATA = reader.get(seed, "test")

    # Look up the run ID of the run trained with this seed
    runs = pl.from_pandas(mlflow.search_runs( experiment_names=[experimentName] ))
    if runs.height > 0:
        runs = runs.filter( pl.col("params.seed").cast(int) == seed )
    if runs.height == 0:
        raise LookupError(f"No MLflow run with seed {seed} found in experiment {experimentName!r}")
    run_id = runs.get_column("run_id")[0]

    # Restore the model parameters and hyperparameters
    # NOTE(review): -1 presumably selects the last saved checkpoint — confirm in restoreModelParams.
    params = restoreModelParams( run_id, -1 )
    hyparams = restoreHyperParams( run_id )

    # Rebuild the model
    model = getModel( modelName, hyparams, DATA )

    # Score every item for each evaluation user and keep the indices of the k best.
    pred_scores = model.apply({"params": params}, DATA["df_EVALUATION"]["user_ids"], method=model.get_all_scores_by_user_ids)
    topk_indices = jax.lax.top_k( pred_scores, k )[1]

    return DATA, topk_indices

In [None]:
# Recall@k curves on ML100K_IMPLICIT: one line per model, averaged over seeds.
fig, ax = plt.subplots()

datasetName = "ML100K_IMPLICIT"
seeds = range(3)
for modelName in ["MF_BPR", "ProtoMF_BPR", "HE_MF_USER_BPR", "HE_MF_ITEM_BPR", "HE_MF_BPR"]:

    # One Recall@k curve per seed.
    values = []
    for seed in seeds:
        DATA, topk_indices = restoreDataAndPrediction(datasetName, modelName, seed)
        true_item_ids = DATA["df_EVALUATION"]["true_item_ids"]
        true_item_len = DATA["df_EVALUATION"]["true_item_len"]
        # Row by row (one row per user): flag each ranked item that is among that row's true items.
        hit_flags = jax.vmap(jnp.isin)(topk_indices, true_item_ids).astype(int)
        # Running hit count along the ranking, i.e. hits within the top 1, 2, ..., k.
        hit_flags_cumsum = hit_flags.cumsum(axis=1)
        # Divide each user's hit counts by true_item_len, then average over users.
        # NOTE(review): assumes true_item_len is the per-user number of true items — confirm.
        recall_at_k = (hit_flags_cumsum.T / true_item_len).T.mean(axis=0)
        values.append( recall_at_k )

    # Average over however many seeds were run; x-axis length follows the curve itself.
    recall_at_k = sum(values) / len(values)
    ax.plot( range(1, len(recall_at_k) + 1), recall_at_k, label=modelName )

ax.set_xlabel("k")
ax.set_ylabel("Recall@k")
ax.set_title(datasetName)
ax.legend()
plt.show()

In [None]:
# Recall@k curves on ML1M_IMPLICIT: one line per model, averaged over seeds.
fig, ax = plt.subplots()

datasetName = "ML1M_IMPLICIT"
seeds = range(3)
for modelName in ["MF_BPR", "HE_MF_BPR"]:

    # One Recall@k curve per seed.
    values = []
    for seed in seeds:
        DATA, topk_indices = restoreDataAndPrediction(datasetName, modelName, seed)
        true_item_ids = DATA["df_EVALUATION"]["true_item_ids"]
        true_item_len = DATA["df_EVALUATION"]["true_item_len"]
        # Row by row (one row per user): flag each ranked item that is among that row's true items.
        hit_flags = jax.vmap(jnp.isin)(topk_indices, true_item_ids).astype(int)
        # Running hit count along the ranking, i.e. hits within the top 1, 2, ..., k.
        hit_flags_cumsum = hit_flags.cumsum(axis=1)
        # Divide each user's hit counts by true_item_len, then average over users.
        # NOTE(review): assumes true_item_len is the per-user number of true items — confirm.
        recall_at_k = (hit_flags_cumsum.T / true_item_len).T.mean(axis=0)
        values.append( recall_at_k )

    # Average over however many seeds were run; x-axis length follows the curve itself.
    recall_at_k = sum(values) / len(values)
    ax.plot( range(1, len(recall_at_k) + 1), recall_at_k, label=modelName )

ax.set_xlabel("k")
ax.set_ylabel("Recall@k")
ax.set_title(datasetName)
ax.legend()
plt.show()

In [None]:
# Recall@k curves on Twitch100K: one line per model, averaged over seeds.
fig, ax = plt.subplots()

datasetName = "Twitch100K"
seeds = range(3)
for modelName in ["MF_BPR", "HE_MF_BPR"]:

    # One Recall@k curve per seed.
    values = []
    for seed in seeds:
        DATA, topk_indices = restoreDataAndPrediction(datasetName, modelName, seed)
        true_item_ids = DATA["df_EVALUATION"]["true_item_ids"]
        true_item_len = DATA["df_EVALUATION"]["true_item_len"]
        # Row by row (one row per user): flag each ranked item that is among that row's true items.
        hit_flags = jax.vmap(jnp.isin)(topk_indices, true_item_ids).astype(int)
        # Running hit count along the ranking, i.e. hits within the top 1, 2, ..., k.
        hit_flags_cumsum = hit_flags.cumsum(axis=1)
        # Divide each user's hit counts by true_item_len, then average over users.
        # NOTE(review): assumes true_item_len is the per-user number of true items — confirm.
        recall_at_k = (hit_flags_cumsum.T / true_item_len).T.mean(axis=0)
        values.append( recall_at_k )

    # Average over however many seeds were run; x-axis length follows the curve itself.
    recall_at_k = sum(values) / len(values)
    ax.plot( range(1, len(recall_at_k) + 1), recall_at_k, label=modelName )

ax.set_xlabel("k")
ax.set_ylabel("Recall@k")
ax.set_title(datasetName)
ax.legend()
plt.show()