In [None]:
%load_ext autoreload

In [None]:
%autoreload 2

import sys
sys.path.append("..")

import jax
import jax.numpy as jnp
import flax.linen as nn
from tqdm import tqdm

import optuna

import polars as pl
import numpy as np
import matplotlib.pyplot as plt

from herec.utils import *
from herec.loader import *
from herec.reader import *
from herec.trainer import *
from herec.model import *

In [None]:
import mlflow
from dotenv import load_dotenv
# Load environment variables from the .env file in the parent directory.
# NOTE(review): presumably this provides the MLflow settings that the
# restore* helpers used below rely on - confirm.
load_dotenv("../.env")

# パラメータの可視化

In [None]:
"""
    HMFの可視化
"""

# Visualize the embeddings stored for the HMF run below: the root embeddings
# (gray) and the per-object embeddings derived from them (blue), for both the
# user side and the item side.

run_id = "3c8ecbf575e44d6da67139819ff95c87"
temperature = restoreHyperParams(run_id)["model"]["temperature"]
params = restoreModelParams(run_id)


def _expandEmbedder(embedderParams, temperature):
    """Build the embeddings of one embedder from its restored parameters.

    Args:
        embedderParams: mapping of parameter name to array, e.g.
            ``params["userEmbedder"]``.
        temperature: divisor applied to every non-root parameter before the
            softmax.

    Returns:
        A ``(rootEmbed, rootConnection, objEmbed)`` tuple. When the embedder
        has no ``"rootMatrix"`` entry, ``rootEmbed`` and ``rootConnection``
        are ``None`` and ``objEmbed`` is the plain ``"embedding"`` parameter.
    """
    if "rootMatrix" not in embedderParams.keys():
        return None, None, embedderParams["embedding"]

    rootEmbed = embedderParams["rootMatrix"]
    # Every parameter other than the root matrix is softmax-normalised (last
    # axis) after temperature scaling; they are multiplied in dict iteration order.
    connections = [
        nn.softmax(val / temperature)
        for key, val in embedderParams.items()
        if key != "rootMatrix"
    ]
    # The trailing identity does not change the product.
    # NOTE(review): presumably it is there so that multi_dot always receives
    # at least two operands even with a single connection matrix - confirm.
    rootConnection = jnp.linalg.multi_dot(connections + [jnp.eye(rootEmbed.shape[0])])
    return rootEmbed, rootConnection, rootConnection @ rootEmbed


# Generate User Embeddings
userrootEmbed, userRootConnection, userObjEmbed = _expandEmbedder(params["userEmbedder"], temperature)

# Generate Item Embeddings
itemrootEmbed, itemRootConnection, itemObjEmbed = _expandEmbedder(params["itemEmbedder"], temperature)

# Visualization: columns are user / item, rows are the first two and the last
# two embedding dimensions.
fig, ax = plt.subplots(2, 2, figsize=(8, 8))

panels = (
    ("User", userrootEmbed, userObjEmbed),
    ("Item", itemrootEmbed, itemObjEmbed),
)
for col, (side, rootEmbed, objEmbed) in enumerate(panels):
    for row, (dimX, dimY) in enumerate(((0, 1), (-2, -1))):
        if rootEmbed is not None:
            ax[row, col].scatter(rootEmbed[:, dimX], rootEmbed[:, dimY], c="gray", label="root")
        ax[row, col].scatter(objEmbed[:, dimX], objEmbed[:, dimY], c="blue", label="object")
        ax[row, col].set(title=f"{side} embeddings", xlabel=f"dim {dimX}", ylabel=f"dim {dimY}")
        ax[row, col].legend()

fig.tight_layout()
plt.show()

In [None]:
# Print, for each row, the indices of the 5 largest entries (jax.lax.top_k
# works along the last axis and returns (values, indices)).

# The root-connection matrices are only built when the embedder has a
# "rootMatrix" parameter, so skip them otherwise instead of failing on an
# undefined (or stale) variable.
if "rootMatrix" in params["userEmbedder"].keys():
    print(jax.lax.top_k(userRootConnection, 5)[1])
if "rootMatrix" in params["itemEmbedder"].keys():
    print(jax.lax.top_k(itemRootConnection, 5)[1])

# Inner-product scores between the first 10 rows of userObjEmbed and every
# row of itemObjEmbed.
print(jax.lax.top_k(userObjEmbed[:10] @ itemObjEmbed.T, 5)[1])

In [None]:
"""
    MFの可視化
"""

# Visualize the embeddings stored for the MF run below. This run is read
# through its plain "embedding" parameters only, so nothing from any other
# run (e.g. root embeddings left in the kernel by an earlier cell) is plotted.

run_id = "e69f7c68b3884a64b40af918ab1c832c"
params = restoreModelParams(run_id)

userObjEmbed = params["userEmbedder"]["embedding"]
itemObjEmbed = params["itemEmbedder"]["embedding"]

# Visualization: columns are user / item, rows are the first two and the last
# two embedding dimensions.
fig, ax = plt.subplots(2, 2, figsize=(8, 8))

for col, (side, objEmbed) in enumerate((("User", userObjEmbed), ("Item", itemObjEmbed))):
    for row, (dimX, dimY) in enumerate(((0, 1), (-2, -1))):
        ax[row, col].scatter(objEmbed[:, dimX], objEmbed[:, dimY], c="blue")
        ax[row, col].set(title=f"{side} embeddings", xlabel=f"dim {dimX}", ylabel=f"dim {dimY}")

fig.tight_layout()
plt.show()

In [None]:
# Inner-product scores between the first 10 rows of userObjEmbed and every
# row of itemObjEmbed; print the indices of the 5 largest scores per row.
scores = userObjEmbed[:10] @ itemObjEmbed.T
topIndices = jax.lax.top_k(scores, 5)[1]
print(topIndices)