In [1]:
import os
import shutil
import mlflow
from glob import glob
from tqdm import tqdm
import numpy as np
import pandas as pd
import polars as pl
import re
import matplotlib
import matplotlib.pyplot as plt
plt.style.use('tableau-colorblind10')
from IPython.display import display
from scipy import stats

from dotenv import load_dotenv
load_dotenv("../.env")

import jax
import jax.numpy as jnp

import sys
sys.path.append("..")
from herec.utils import *
from herec.loader import *
from herec.reader import *
from herec.model import *



# ML1M_IMPLICT （アイテムの深さ2）

In [2]:
## Restore

run_id = "4e3539a9ee7f42f2a7d18ae64ddb4f80"
params = jax.device_put(restoreModelParams(run_id))
hyparams = restoreHyperParams(run_id)
t = hyparams["model"]["temperature"]

userConnectionMatrix = jax.nn.softmax(params["userEmbedder"]["connectionMatrix_1"] / t)
itemConnectionMatrix = jax.nn.softmax(params["itemEmbedder"]["connectionMatrix_1"] / t) @ jax.nn.softmax(params["itemEmbedder"]["connectionMatrix_2"] / t)
userRootMatrix = params["userEmbedder"]["rootMatrix"]
itemRootMatrix = params["itemEmbedder"]["rootMatrix"]

userEmbedMatrix = userConnectionMatrix @ userRootMatrix
itemEmbedMatrix = itemConnectionMatrix @ itemRootMatrix

clusterToCluster = userRootMatrix @ itemRootMatrix.T

[19]


  from .autonotebook import tqdm as notebook_tqdm
Downloading artifacts: 100%|██████████| 1/1 [00:00<00:00,  5.87it/s]


In [3]:
## READ

DATA = ML1M_IMPLICIT().get(0, "test")
userIdMap, itemIdMap = DATA["user_id_map"], DATA["item_id_map"]

with open(f"{getRepositoryPath()}/dataset/ML1M/ml-1m/users.dat") as f:
    lines = f.readlines()
userAttriibute = pl.DataFrame([line.replace("\n", "").split("::") for line in lines], schema=["userId", "gender", "age", "occupation", "zipCode"], orient="row")
userAttriibute = userAttriibute.with_columns( pl.col("userId", "age", "occupation").cast(int) )
userAttriibute = userAttriibute.filter( pl.col("userId").is_in(userIdMap.keys()) ).with_columns( pl.col("userId").replace(userIdMap, default=None) ) # FoldのUserIDに揃える
userAttriibute = userAttriibute.with_columns( pl.col("userId").replace( dict(zip(range(userConnectionMatrix.shape[0]), userConnectionMatrix.argmax(axis=1).tolist())), default=None ).alias("userClusterId") ) # ClusterIdに変換
userAttriibute = userAttriibute.with_columns(
    pl.col("gender").replace({"M": "Man", "F": "Female",}),
    pl.col("age").replace({
        1: "Under 18",
        18: "18-24",
        25: "25-34",
        35: "35-44",
        45: "45-49",
        50: "50-55",
        56: "56+",
    }),
    pl.col("occupation").replace({
        0: "other or not specified",
        1: "academic/educator",
        2: "artist",
        3: "clerical/admin",
        4: "college/grad student",
        5: "customer service",
        6: "doctor/health care",
        7: "executive/managerial",
        8: "farmer",
        9: "homemaker",
        10: "K-12 student",
        11: "lawyer",
        12: "programmer",
        13: "retired",
        14: "sales/marketing",
        15: "scientist",
        16: "self-employed",
        17: "technician/engineer",
        18: "tradesman/craftsman",
        19: "unemployed",
        20: "writer",
    }),
)

with open(f"{getRepositoryPath()}/dataset/ML1M/ml-1m/movies.dat", encoding='latin1') as f:
    lines = f.readlines()
itemAttriibute = pl.DataFrame([line.replace("\n", "").split("::") for line in lines], schema=["itemId", "title", "genres"], orient="row")
itemAttriibute = itemAttriibute.with_columns( pl.col("itemId").cast(int), pl.col("genres").str.split("|") )
itemAttriibute = itemAttriibute.filter( pl.col("itemId").is_in(itemIdMap.keys()) ).with_columns( pl.col("itemId").replace(itemIdMap, default=None) ) # FoldのItemIDに揃える
itemAttriibute = itemAttriibute.with_columns( pl.col("itemId").replace( dict(zip(range(itemConnectionMatrix.shape[0]), itemConnectionMatrix.argmax(axis=1).tolist())), default=None ).alias("itemClusterId") ) # ClusterIdに変換

In [4]:
def count(values):

    if values.dtype == pl.List:
        values = values.explode().value_counts().with_columns( pl.col("count") / values.shape[0] ).sort("count", descending=True).head(3)
    else:
        values = values.value_counts().with_columns( pl.col("count") / values.shape[0] ).sort("count", descending=True).head(3)
    
    for key, val in values.rows():
        print( key, rf"({round(val*100, 1)}\%)", end=", " )
    print()

## 男/女はどのようなアイテムクラスタを視聴しているか？ 

In [5]:
# is_MAN?
print( userAttriibute.select( "userClusterId", (pl.col("gender") == "Man").cast(int).alias("is_man") ).group_by("userClusterId").mean().sort("is_man") )

shape: (396, 2)
┌───────────────┬────────┐
│ userClusterId ┆ is_man │
│ ---           ┆ ---    │
│ i64           ┆ f64    │
╞═══════════════╪════════╡
│ 557           ┆ 0.0    │
│ 301           ┆ 0.0    │
│ 480           ┆ 0.0    │
│ 93            ┆ 0.0    │
│ …             ┆ …      │
│ 110           ┆ 1.0    │
│ 238           ┆ 1.0    │
│ 524           ┆ 1.0    │
│ 387           ┆ 1.0    │
└───────────────┴────────┘


In [6]:
"""
    Man
"""

userId = 524

# 対象ユーザクラスタの中身
print( d := userAttriibute.filter( pl.col("userClusterId") == userId ) )
count(d["gender"])
count(d["age"])
count(d["occupation"])

print("\n******************************\n")

# 嗜好 Top-5 アイテムクラスタ
print(ids := clusterToCluster[userId].argsort()[-3:][::-1])
for id in ids:
    values = itemAttriibute.filter( pl.col("itemClusterId") == id ).get_column("genres")
    count(values)
    print(itemConnectionMatrix[:, id].sum().round(1))

print("\n******************************\n")

# 嗜好 Worst-5 アイテムクラスタ
print(ids := clusterToCluster[userId].argsort()[:3][::-1])
for id in ids:
    values = itemAttriibute.filter( pl.col("itemClusterId") == id ).get_column("genres")
    count(values)
    print(itemConnectionMatrix[:, id].sum().round(1))

shape: (2, 6)
┌────────┬────────┬───────┬─────────────────┬─────────┬───────────────┐
│ userId ┆ gender ┆ age   ┆ occupation      ┆ zipCode ┆ userClusterId │
│ ---    ┆ ---    ┆ ---   ┆ ---             ┆ ---     ┆ ---           │
│ i64    ┆ str    ┆ str   ┆ str             ┆ str     ┆ i64           │
╞════════╪════════╪═══════╪═════════════════╪═════════╪═══════════════╡
│ 4956   ┆ Man    ┆ 25-34 ┆ writer          ┆ 98105   ┆ 524           │
│ 3088   ┆ Man    ┆ 18-24 ┆ sales/marketing ┆ 48185   ┆ 524           │
└────────┴────────┴───────┴─────────────────┴─────────┴───────────────┘
Man (100.0\%), 
18-24 (50.0\%), 25-34 (50.0\%), 
sales/marketing (50.0\%), writer (50.0\%), 

******************************

[203 296 271]
Drama (75.7\%), Comedy (28.4\%), Romance (20.3\%), 
66.6
Drama (62.4\%), Comedy (37.8\%), Romance (20.8\%), 
117.8

52.5

******************************

[ 66 160   3]
Adventure (66.7\%), Children's (33.3\%), Fantasy (33.3\%), 
12.5
Action (50.0\%), Horror (25.0\%), Thr

In [7]:
"""
    Female
"""

userId = 557

# 対象ユーザクラスタの中身
print( d := userAttriibute.filter( pl.col("userClusterId") == userId ) )
count(d["gender"])
count(d["age"])
count(d["occupation"])

print("\n******************************\n")

# 嗜好 Top-5 アイテムクラスタ
print(ids := clusterToCluster[userId].argsort()[-3:][::-1])
for id in ids:
    values = itemAttriibute.filter( pl.col("itemClusterId") == id ).get_column("genres")
    count(values)
    print(itemConnectionMatrix[:, id].sum().round(1))

print("\n******************************\n")

# 嗜好 Worst-5 アイテムクラスタ
print(ids := clusterToCluster[userId].argsort()[:3][::-1])
for id in ids:
    values = itemAttriibute.filter( pl.col("itemClusterId") == id ).get_column("genres")
    count(values)
    print(itemConnectionMatrix[:, id].sum().round(1))

shape: (1, 6)
┌────────┬────────┬───────┬───────────────────┬─────────┬───────────────┐
│ userId ┆ gender ┆ age   ┆ occupation        ┆ zipCode ┆ userClusterId │
│ ---    ┆ ---    ┆ ---   ┆ ---               ┆ ---     ┆ ---           │
│ i64    ┆ str    ┆ str   ┆ str               ┆ str     ┆ i64           │
╞════════╪════════╪═══════╪═══════════════════╪═════════╪═══════════════╡
│ 541    ┆ Female ┆ 50-55 ┆ academic/educator ┆ 62901   ┆ 557           │
└────────┴────────┴───────┴───────────────────┴─────────┴───────────────┘
Female (100.0\%), 
50-55 (100.0\%), 
academic/educator (100.0\%), 

******************************

[203 271 296]
Drama (75.7\%), Comedy (28.4\%), Romance (20.3\%), 
66.6

52.5
Drama (62.4\%), Comedy (37.8\%), Romance (20.8\%), 
117.8

******************************

[ 66 150 160]
Adventure (66.7\%), Action (33.3\%), Fantasy (33.3\%), 
12.5
Comedy (35.9\%), Action (32.8\%), Children's (17.2\%), 
26.4
Action (50.0\%), Horror (25.0\%), Sci-Fi (25.0\%), 
20.4


## 特定の職種はどのようなアイテムクラスタを視聴しているか？ 

In [8]:
# is programmer?
print( userAttriibute.select( "userClusterId", (pl.col("occupation") == "programmer").cast(int).alias("is_man") ).group_by("userClusterId").mean().sort("is_man") )

shape: (396, 2)
┌───────────────┬────────┐
│ userClusterId ┆ is_man │
│ ---           ┆ ---    │
│ i64           ┆ f64    │
╞═══════════════╪════════╡
│ 527           ┆ 0.0    │
│ 6             ┆ 0.0    │
│ 530           ┆ 0.0    │
│ 274           ┆ 0.0    │
│ …             ┆ …      │
│ 348           ┆ 0.5    │
│ 27            ┆ 1.0    │
│ 246           ┆ 1.0    │
│ 74            ┆ 1.0    │
└───────────────┴────────┘


In [9]:
userId = 74

# 対象ユーザクラスタの中身
print( d := userAttriibute.filter( pl.col("userClusterId") == userId ) )
count(d["gender"])
count(d["age"])
count(d["occupation"])

print("\n******************************\n")

# 嗜好 Top-5 アイテムクラスタ
print(ids := clusterToCluster[userId].argsort()[-3:][::-1])
for id in ids:
    values = itemAttriibute.filter( pl.col("itemClusterId") == id ).get_column("genres")
    count(values)
    print(itemConnectionMatrix[:, id].sum().round(1))

print("\n******************************\n")

# 嗜好 Worst-5 アイテムクラスタ
print(ids := clusterToCluster[userId].argsort()[:3][::-1])
for id in ids:
    values = itemAttriibute.filter( pl.col("itemClusterId") == id ).get_column("genres")
    count(values)
    print(itemConnectionMatrix[:, id].sum().round(1))

shape: (1, 6)
┌────────┬────────┬───────┬────────────┬────────────┬───────────────┐
│ userId ┆ gender ┆ age   ┆ occupation ┆ zipCode    ┆ userClusterId │
│ ---    ┆ ---    ┆ ---   ┆ ---        ┆ ---        ┆ ---           │
│ i64    ┆ str    ┆ str   ┆ str        ┆ str        ┆ i64           │
╞════════╪════════╪═══════╪════════════╪════════════╪═══════════════╡
│ 3853   ┆ Man    ┆ 25-34 ┆ programmer ┆ 27713-9225 ┆ 74            │
└────────┴────────┴───────┴────────────┴────────────┴───────────────┘
Man (100.0\%), 
25-34 (100.0\%), 
programmer (100.0\%), 

******************************

[178 138 251]
Action (100.0\%), Thriller (54.5\%), Adventure (36.4\%), 
37.2
Action (57.4\%), Thriller (39.3\%), Comedy (24.6\%), 
40.600002
Action (87.8\%), Thriller (53.7\%), Adventure (31.7\%), 
50.4

******************************

[314 311 203]

7.0
Drama (80.0\%), Comedy (20.0\%), Mystery (20.0\%), 
13.2
Drama (75.7\%), Comedy (28.4\%), Romance (20.3\%), 
66.6


In [10]:
# is customer service?
print( userAttriibute.select( "userClusterId", (pl.col("occupation") == "customer service").cast(int).alias("is_man") ).group_by("userClusterId").mean().sort("is_man") )

shape: (396, 2)
┌───────────────┬──────────┐
│ userClusterId ┆ is_man   │
│ ---           ┆ ---      │
│ i64           ┆ f64      │
╞═══════════════╪══════════╡
│ 6             ┆ 0.0      │
│ 530           ┆ 0.0      │
│ 9             ┆ 0.0      │
│ 143           ┆ 0.0      │
│ …             ┆ …        │
│ 169           ┆ 0.25     │
│ 573           ┆ 0.285714 │
│ 316           ┆ 0.5      │
│ 238           ┆ 1.0      │
└───────────────┴──────────┘


In [11]:
userId = 238

# 対象ユーザクラスタの中身
print( d := userAttriibute.filter( pl.col("userClusterId") == userId ) )
count(d["gender"])
count(d["age"])
count(d["occupation"])

print("\n******************************\n")

# 嗜好 Top-5 アイテムクラスタ
print(ids := clusterToCluster[userId].argsort()[-3:][::-1])
for id in ids:
    values = itemAttriibute.filter( pl.col("itemClusterId") == id ).get_column("genres")
    count(values)
    print(itemConnectionMatrix[:, id].sum().round(1))

print("\n******************************\n")

# 嗜好 Worst-5 アイテムクラスタ
print(ids := clusterToCluster[userId].argsort()[:3][::-1])
for id in ids:
    values = itemAttriibute.filter( pl.col("itemClusterId") == id ).get_column("genres")
    count(values)
    print(itemConnectionMatrix[:, id].sum().round(1))

shape: (1, 6)
┌────────┬────────┬───────┬──────────────────┬─────────┬───────────────┐
│ userId ┆ gender ┆ age   ┆ occupation       ┆ zipCode ┆ userClusterId │
│ ---    ┆ ---    ┆ ---   ┆ ---              ┆ ---     ┆ ---           │
│ i64    ┆ str    ┆ str   ┆ str              ┆ str     ┆ i64           │
╞════════╪════════╪═══════╪══════════════════╪═════════╪═══════════════╡
│ 692    ┆ Man    ┆ 18-24 ┆ customer service ┆ 55414   ┆ 238           │
└────────┴────────┴───────┴──────────────────┴─────────┴───────────────┘
Man (100.0\%), 
18-24 (100.0\%), 
customer service (100.0\%), 

******************************

[240   3 178]
Sci-Fi (100.0\%), Action (50.0\%), Drama (25.0\%), 
49.4
Action (33.6\%), Comedy (24.0\%), Sci-Fi (23.6\%), 
129.3
Action (100.0\%), Thriller (54.5\%), Adventure (36.4\%), 
37.2

******************************

[314 311 203]

7.0
Drama (80.0\%), Comedy (20.0\%), Mystery (20.0\%), 
13.2
Drama (75.7\%), Comedy (28.4\%), Romance (20.3\%), 
66.6


# ML1M_IMPLICT （アイテムの深さ1）

In [12]:
## Restore

run_id = "4e3539a9ee7f42f2a7d18ae64ddb4f80"
params = jax.device_put(restoreModelParams(run_id))
hyparams = restoreHyperParams(run_id)
t = hyparams["model"]["temperature"]

userConnectionMatrix = jax.nn.softmax(params["userEmbedder"]["connectionMatrix_1"] / t)
itemConnectionMatrix = jax.nn.softmax(params["itemEmbedder"]["connectionMatrix_1"] / t)
userRootMatrix = params["userEmbedder"]["rootMatrix"]
itemRootMatrix = jax.nn.softmax(params["itemEmbedder"]["connectionMatrix_2"] / t) @ params["itemEmbedder"]["rootMatrix"]

userEmbedMatrix = userConnectionMatrix @ userRootMatrix
itemEmbedMatrix = itemConnectionMatrix @ itemRootMatrix

clusterToCluster = userRootMatrix @ itemRootMatrix.T

[19]


Downloading artifacts: 100%|██████████| 1/1 [00:00<00:00,  5.52it/s]


In [13]:
## READ

DATA = ML1M_IMPLICIT().get(0, "test")
userIdMap, itemIdMap = DATA["user_id_map"], DATA["item_id_map"]

with open(f"{getRepositoryPath()}/dataset/ML1M/ml-1m/users.dat") as f:
    lines = f.readlines()
userAttriibute = pl.DataFrame([line.replace("\n", "").split("::") for line in lines], schema=["userId", "gender", "age", "occupation", "zipCode"], orient="row")
userAttriibute = userAttriibute.with_columns( pl.col("userId", "age", "occupation").cast(int) )
userAttriibute = userAttriibute.filter( pl.col("userId").is_in(userIdMap.keys()) ).with_columns( pl.col("userId").replace(userIdMap, default=None) ) # FoldのUserIDに揃える
userAttriibute = userAttriibute.with_columns( pl.col("userId").replace( dict(zip(range(userConnectionMatrix.shape[0]), userConnectionMatrix.argmax(axis=1).tolist())), default=None ).alias("userClusterId") ) # ClusterIdに変換
userAttriibute = userAttriibute.with_columns(
    pl.col("gender").replace({"M": "Man", "F": "Female",}),
    pl.col("age").replace({
        1: "Under 18",
        18: "18-24",
        25: "25-34",
        35: "35-44",
        45: "45-49",
        50: "50-55",
        56: "56+",
    }),
    pl.col("occupation").replace({
        0: "other or not specified",
        1: "academic/educator",
        2: "artist",
        3: "clerical/admin",
        4: "college/grad student",
        5: "customer service",
        6: "doctor/health care",
        7: "executive/managerial",
        8: "farmer",
        9: "homemaker",
        10: "K-12 student",
        11: "lawyer",
        12: "programmer",
        13: "retired",
        14: "sales/marketing",
        15: "scientist",
        16: "self-employed",
        17: "technician/engineer",
        18: "tradesman/craftsman",
        19: "unemployed",
        20: "writer",
    }),
)

with open(f"{getRepositoryPath()}/dataset/ML1M/ml-1m/movies.dat", encoding='latin1') as f:
    lines = f.readlines()
itemAttriibute = pl.DataFrame([line.replace("\n", "").split("::") for line in lines], schema=["itemId", "title", "genres"], orient="row")
itemAttriibute = itemAttriibute.with_columns( pl.col("itemId").cast(int), pl.col("genres").str.split("|") )
itemAttriibute = itemAttriibute.filter( pl.col("itemId").is_in(itemIdMap.keys()) ).with_columns( pl.col("itemId").replace(itemIdMap, default=None) ) # FoldのItemIDに揃える
itemAttriibute = itemAttriibute.with_columns( pl.col("itemId").replace( dict(zip(range(itemConnectionMatrix.shape[0]), itemConnectionMatrix.argmax(axis=1).tolist())), default=None ).alias("itemClusterId") ) # ClusterIdに変換

In [14]:
def count(values):

    if values.dtype == pl.List:
        values = values.explode().value_counts().with_columns( pl.col("count") / values.shape[0] ).sort("count", descending=True).head(3)
    else:
        values = values.value_counts().with_columns( pl.col("count") / values.shape[0] ).sort("count", descending=True).head(3)
    
    for key, val in values.rows():
        print( key, rf"({round(val*100, 1)}\%)", end=", " )
    print()

## 男/女はどのようなアイテムクラスタを視聴しているか？ 

In [15]:
# is_MAN?
print( userAttriibute.select( "userClusterId", (pl.col("gender") == "Man").cast(int).alias("is_man") ).group_by("userClusterId").mean().sort("is_man") )

shape: (396, 2)
┌───────────────┬────────┐
│ userClusterId ┆ is_man │
│ ---           ┆ ---    │
│ i64           ┆ f64    │
╞═══════════════╪════════╡
│ 301           ┆ 0.0    │
│ 557           ┆ 0.0    │
│ 93            ┆ 0.0    │
│ 96            ┆ 0.0    │
│ …             ┆ …      │
│ 110           ┆ 1.0    │
│ 244           ┆ 1.0    │
│ 524           ┆ 1.0    │
│ 387           ┆ 1.0    │
└───────────────┴────────┘


In [16]:
"""
    Man
"""

userId = 524

# 対象ユーザクラスタの中身
print( d := userAttriibute.filter( pl.col("userClusterId") == userId ) )
count(d["gender"])
count(d["age"])
count(d["occupation"])

print("\n******************************\n")

# 嗜好 Top-5 アイテムクラスタ
print(ids := clusterToCluster[userId].argsort()[-3:][::-1])
for id in ids:
    values = itemAttriibute.filter( pl.col("itemClusterId") == id ).get_column("genres")
    count(values)
    print(itemConnectionMatrix[:, id].sum().round(1))

print("\n******************************\n")

# 嗜好 Worst-5 アイテムクラスタ
print(ids := clusterToCluster[userId].argsort()[:3][::-1])
for id in ids:
    values = itemAttriibute.filter( pl.col("itemClusterId") == id ).get_column("genres")
    count(values)
    print(itemConnectionMatrix[:, id].sum().round(1))

shape: (2, 6)
┌────────┬────────┬───────┬─────────────────┬─────────┬───────────────┐
│ userId ┆ gender ┆ age   ┆ occupation      ┆ zipCode ┆ userClusterId │
│ ---    ┆ ---    ┆ ---   ┆ ---             ┆ ---     ┆ ---           │
│ i64    ┆ str    ┆ str   ┆ str             ┆ str     ┆ i64           │
╞════════╪════════╪═══════╪═════════════════╪═════════╪═══════════════╡
│ 4956   ┆ Man    ┆ 25-34 ┆ writer          ┆ 98105   ┆ 524           │
│ 3088   ┆ Man    ┆ 18-24 ┆ sales/marketing ┆ 48185   ┆ 524           │
└────────┴────────┴───────┴─────────────────┴─────────┴───────────────┘
Man (100.0\%), 
18-24 (50.0\%), 25-34 (50.0\%), 
sales/marketing (50.0\%), writer (50.0\%), 

******************************

[415 550 379]
Drama (100.0\%), 
4.6
Comedy (100.0\%), 
4.7000003
Drama (100.0\%), 
4.6

******************************

[398 223 574]

3.7

3.7
Action (55.6\%), War (44.4\%), Sci-Fi (44.4\%), 
3.7


In [17]:
"""
    Female
"""

userId = 557

# 対象ユーザクラスタの中身
print( d := userAttriibute.filter( pl.col("userClusterId") == userId ) )
count(d["gender"])
count(d["age"])
count(d["occupation"])

print("\n******************************\n")

# 嗜好 Top-5 アイテムクラスタ
print(ids := clusterToCluster[userId].argsort()[-3:][::-1])
for id in ids:
    values = itemAttriibute.filter( pl.col("itemClusterId") == id ).get_column("genres")
    count(values)
    print(itemConnectionMatrix[:, id].sum().round(1))

print("\n******************************\n")

# 嗜好 Worst-5 アイテムクラスタ
print(ids := clusterToCluster[userId].argsort()[:3][::-1])
for id in ids:
    values = itemAttriibute.filter( pl.col("itemClusterId") == id ).get_column("genres")
    count(values)
    print(itemConnectionMatrix[:, id].sum().round(1))

shape: (1, 6)
┌────────┬────────┬───────┬───────────────────┬─────────┬───────────────┐
│ userId ┆ gender ┆ age   ┆ occupation        ┆ zipCode ┆ userClusterId │
│ ---    ┆ ---    ┆ ---   ┆ ---               ┆ ---     ┆ ---           │
│ i64    ┆ str    ┆ str   ┆ str               ┆ str     ┆ i64           │
╞════════╪════════╪═══════╪═══════════════════╪═════════╪═══════════════╡
│ 541    ┆ Female ┆ 50-55 ┆ academic/educator ┆ 62901   ┆ 557           │
└────────┴────────┴───────┴───────────────────┴─────────┴───────────────┘
Female (100.0\%), 
50-55 (100.0\%), 
academic/educator (100.0\%), 

******************************

[415 550 379]
Drama (100.0\%), 
4.6
Comedy (100.0\%), 
4.7000003
Drama (100.0\%), 
4.6

******************************

[307 470 128]
Action (100.0\%), Sci-Fi (33.3\%), Drama (33.3\%), 
4.8
Action (100.0\%), Adventure (100.0\%), 
4.6
Horror (66.7\%), War (33.3\%), Sci-Fi (33.3\%), 
4.6


## 特定の職種はどのようなアイテムクラスタを視聴しているか？ 

In [18]:
# is programmer?
print( userAttriibute.select( "userClusterId", (pl.col("occupation") == "programmer").cast(int).alias("is_man") ).group_by("userClusterId").mean().sort("is_man") )

shape: (396, 2)
┌───────────────┬────────┐
│ userClusterId ┆ is_man │
│ ---           ┆ ---    │
│ i64           ┆ f64    │
╞═══════════════╪════════╡
│ 143           ┆ 0.0    │
│ 268           ┆ 0.0    │
│ 271           ┆ 0.0    │
│ 6             ┆ 0.0    │
│ …             ┆ …      │
│ 348           ┆ 0.5    │
│ 27            ┆ 1.0    │
│ 246           ┆ 1.0    │
│ 74            ┆ 1.0    │
└───────────────┴────────┘


In [19]:
userId = 74

# 対象ユーザクラスタの中身
print( d := userAttriibute.filter( pl.col("userClusterId") == userId ) )
count(d["gender"])
count(d["age"])
count(d["occupation"])

print("\n******************************\n")

# 嗜好 Top-5 アイテムクラスタ
print(ids := clusterToCluster[userId].argsort()[-3:][::-1])
for id in ids:
    values = itemAttriibute.filter( pl.col("itemClusterId") == id ).get_column("genres")
    count(values)
    print(itemConnectionMatrix[:, id].sum().round(1))

print("\n******************************\n")

# 嗜好 Worst-5 アイテムクラスタ
print(ids := clusterToCluster[userId].argsort()[:3][::-1])
for id in ids:
    values = itemAttriibute.filter( pl.col("itemClusterId") == id ).get_column("genres")
    count(values)
    print(itemConnectionMatrix[:, id].sum().round(1))

shape: (1, 6)
┌────────┬────────┬───────┬────────────┬────────────┬───────────────┐
│ userId ┆ gender ┆ age   ┆ occupation ┆ zipCode    ┆ userClusterId │
│ ---    ┆ ---    ┆ ---   ┆ ---        ┆ ---        ┆ ---           │
│ i64    ┆ str    ┆ str   ┆ str        ┆ str        ┆ i64           │
╞════════╪════════╪═══════╪════════════╪════════════╪═══════════════╡
│ 3853   ┆ Man    ┆ 25-34 ┆ programmer ┆ 27713-9225 ┆ 74            │
└────────┴────────┴───────┴────────────┴────────────┴───────────────┘
Man (100.0\%), 
25-34 (100.0\%), 
programmer (100.0\%), 

******************************

[489 405 623]
Sci-Fi (100.0\%), Action (66.7\%), Comedy (33.3\%), 
3.5
Comedy (100.0\%), Western (100.0\%), Action (100.0\%), 
3.4
Action (100.0\%), Crime (100.0\%), Comedy (100.0\%), 
3.5

******************************

[379 550 415]
Drama (100.0\%), 
4.6
Comedy (100.0\%), 
4.7000003
Drama (100.0\%), 
4.6


In [20]:
# is customer service?
print( userAttriibute.select( "userClusterId", (pl.col("occupation") == "customer service").cast(int).alias("is_man") ).group_by("userClusterId").mean().sort("is_man") )

shape: (396, 2)
┌───────────────┬──────────┐
│ userClusterId ┆ is_man   │
│ ---           ┆ ---      │
│ i64           ┆ f64      │
╞═══════════════╪══════════╡
│ 143           ┆ 0.0      │
│ 268           ┆ 0.0      │
│ 402           ┆ 0.0      │
│ 399           ┆ 0.0      │
│ …             ┆ …        │
│ 169           ┆ 0.25     │
│ 573           ┆ 0.285714 │
│ 316           ┆ 0.5      │
│ 238           ┆ 1.0      │
└───────────────┴──────────┘


In [21]:
userId = 238

# 対象ユーザクラスタの中身
print( d := userAttriibute.filter( pl.col("userClusterId") == userId ) )
count(d["gender"])
count(d["age"])
count(d["occupation"])

print("\n******************************\n")

# 嗜好 Top-5 アイテムクラスタ
print(ids := clusterToCluster[userId].argsort()[-3:][::-1])
for id in ids:
    values = itemAttriibute.filter( pl.col("itemClusterId") == id ).get_column("genres")
    count(values)
    print(itemConnectionMatrix[:, id].sum().round(1))

print("\n******************************\n")

# 嗜好 Worst-5 アイテムクラスタ
print(ids := clusterToCluster[userId].argsort()[:3][::-1])
for id in ids:
    values = itemAttriibute.filter( pl.col("itemClusterId") == id ).get_column("genres")
    count(values)
    print(itemConnectionMatrix[:, id].sum().round(1))

shape: (1, 6)
┌────────┬────────┬───────┬──────────────────┬─────────┬───────────────┐
│ userId ┆ gender ┆ age   ┆ occupation       ┆ zipCode ┆ userClusterId │
│ ---    ┆ ---    ┆ ---   ┆ ---              ┆ ---     ┆ ---           │
│ i64    ┆ str    ┆ str   ┆ str              ┆ str     ┆ i64           │
╞════════╪════════╪═══════╪══════════════════╪═════════╪═══════════════╡
│ 692    ┆ Man    ┆ 18-24 ┆ customer service ┆ 55414   ┆ 238           │
└────────┴────────┴───────┴──────────────────┴─────────┴───────────────┘
Man (100.0\%), 
18-24 (100.0\%), 
customer service (100.0\%), 

******************************

[359  29   9]
Sci-Fi (60.0\%), Action (40.0\%), Film-Noir (20.0\%), 
3.4

3.4

3.4

******************************

[379 550 415]
Drama (100.0\%), 
4.6
Comedy (100.0\%), 
4.7000003
Drama (100.0\%), 
4.6
