# Utils

In [None]:
import json
import os, json, re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
import gzip, pickle
from typing import Optional

In [None]:
# Dataset
ITEM_ATTR_CSV = os.path.join("data/books/books_llmrec_format", "item_attribute.csv")
DATA_BASE_DIR = "A-LLMRec/data/books"
U_ITEM_PATH   = "data/books/5core/item_meta_2017_kcore5_user_item.json"
META_GZ_PATH  = f"{DATA_BASE_DIR}/books_text_name_dict.json.gz"
PARTS = 5

test_path = os.path.join(DATA_BASE_DIR, "books_label.txt")
train_path = os.path.join(DATA_BASE_DIR, "books_train_raw.txt")


# Prediction source: keep exactly ONE of the sections below uncommented.
# Every section assigns BASE_DIR / PRED_P1 / PRED_P5, so if several are active
# the last one silently wins.

# ######################### A-LLMRec #########################
# BASE_DIR = "A-LLMRec/books_results"
# PRED_P1 = f"{BASE_DIR}/predict_label_part1.json"
# PRED_P5 = f"{BASE_DIR}/predict_label_part5.json"


# ######################## LLMRec #########################
BASE_DIR = "data/books/books_llmrec_format"
PRED_P1    = f"{BASE_DIR}/predict_label_part1.json"
PRED_P5    = f"{BASE_DIR}/predict_label_part5.json"


# ###################### Augmentation #######################
# BASE_DIR = f"Augmentation/data/books"
# PRED_P1    = f"{BASE_DIR}/predict_label_part1.json"
# PRED_P5    = f"{BASE_DIR}/predict_label_part5.json"


# ###################### Traditional CF ######################
# model = "LightGCN"  # "MF-BPR" or "LightGCN"
# BASE_DIR = f"data/books/traditionalCF"
# PRED_P1    = f"{BASE_DIR}/predict_label_part1.json"
# PRED_P5    = f"{BASE_DIR}/predict_label_part5.json"


with open(PRED_P5, encoding="utf-8") as f:
    pred_dict = json.load(f)
# Preview the first 30 predicted item ids for user "0".
print(pred_dict["0"][:30])

# Total number of predicted (user, item) entries across all users.
actual = sum(len(v) for v in pred_dict.values())
print("actual predict_label interactions:", actual)

In [None]:

def _read_user_items(path, n_fields):
    """Read a whitespace-separated interaction file into {user_id: [item_id, ...]}.

    Every non-blank line must contain exactly `n_fields` tokens: the user id
    (kept as str), the item id (parsed as int), then any extra columns, which
    are ignored. Items are appended in file order, so repeated (user, item)
    pairs are kept.

    Raises:
        ValueError: a line has the wrong number of fields or a non-integer item id.
    """
    user_items = {}
    with open(path, "r", encoding="utf-8") as fh:
        for line_no, raw in enumerate(fh, start=1):
            line = raw.strip()
            if not line:
                continue
            fields = line.split()
            if len(fields) != n_fields:
                raise ValueError(
                    f"{path}:{line_no}: expected {n_fields} fields, got {len(fields)}"
                )
            user_items.setdefault(fields[0], []).append(int(fields[1]))
    return user_items


# Training interactions: "<user> <item>" per line.
train_data = _read_user_items(train_path, 2)

# Test labels: "<user> <item> <third column, unused here>" per line; a user may
# have several ground-truth items, all of which are kept.
ground_truth = _read_user_items(test_path, 3)

expected = sum(len(v) for v in ground_truth.values())
print("expected ground_truth interactions:", expected)

# User with the most ground-truth items (max() would raise on an empty dict).
if ground_truth:
    max_len_user = max(ground_truth.items(), key=lambda x: len(x[1]))
    print(f"User with the longest ground truth interaction: {max_len_user[0]}, Length: {len(max_len_user[1])}")
else:
    print("ground truth is empty")

# Users that appear in both the training set and the ground truth.
common_users = set(train_data.keys()) & set(ground_truth.keys())
print(len(train_data))
print(len(ground_truth))
print(len(common_users))


### Multi-hot-vector

In [None]:
import numpy as np
from scipy.sparse import csr_matrix
from sklearn.decomposition import TruncatedSVD
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from collections import Counter

# -----------------------------
# 1) train_data -> multi-hot user x item matrix (sparse CSR)
# -----------------------------
# train_data: { "user_id(str)": [item_id(int), ...], ... } (built in an earlier cell)

users = list(train_data.keys())

# Item universe: every item id seen in training.
all_items = set()
for u in users:
    all_items.update(train_data[u])

# Item ids need not be contiguous 0..N-1, so re-index them to column positions.
item_list = sorted(all_items)
item2col = {item: idx for idx, item in enumerate(item_list)}

n_users = len(users)
n_items = len(item_list)

# One entry of 1 at (row=user, col=item); repeated interactions collapse to 1 (multi-hot).
rows = []
cols = []
for r, u in enumerate(users):
    for item in set(train_data[u]):
        rows.append(r)
        cols.append(item2col[item])
data = [1] * len(rows)

X = csr_matrix((data, (rows, cols)), shape=(n_users, n_items), dtype=np.float32)
print("X shape:", X.shape, "nnz:", X.nnz)

# -----------------------------
# 2) High-dimensional sparse vectors -> low-dimensional embedding (SVD)
# -----------------------------
# KMeans can run on the sparse matrix directly, but reducing to ~50-200
# dimensions with SVD first is usually much more stable.
svd_dim = min(100, n_items - 1) if n_items > 1 else 1
svd = TruncatedSVD(n_components=svd_dim, random_state=42)
X_svd = svd.fit_transform(X)
print("X_svd shape:", X_svd.shape, "explained_var_ratio_sum:", svd.explained_variance_ratio_.sum())

# -----------------------------
# 3) K-means
# -----------------------------
k_requested = 2  # desired number of clusters
# KMeans needs at least as many samples as clusters.
k = max(1, min(k_requested, n_users))
kmeans = KMeans(n_clusters=k, n_init=10, random_state=42)
labels = kmeans.fit_predict(X_svd)

# -----------------------------
# Users per cluster
# -----------------------------
cluster_cnt = Counter(labels)

print("\nUsers per cluster (%)")
for c in sorted(cluster_cnt.keys()):
    ratio = cluster_cnt[c] / len(labels) * 100
    print(f"Cluster {c}: {cluster_cnt[c]} users ({ratio:.2f}%)")

# -----------------------------
# 4) 2D coordinates for visualization (t-SNE)
# -----------------------------
# t-SNE can be slow for very many users (tens of thousands+); enable the
# subsample option below in that case.
use_subsample = False
max_points = 5000

if use_subsample and n_users > max_points:
    rng = np.random.RandomState(42)
    idx = rng.choice(n_users, size=max_points, replace=False)
    X_vis_in = X_svd[idx]
    labels_vis = labels[idx]
else:
    X_vis_in = X_svd
    labels_vis = labels

# sklearn's TSNE requires perplexity < n_samples; shrink it for small user sets.
n_vis = X_vis_in.shape[0]
tsne_perplexity = min(30.0, max(1.0, n_vis - 1.0))

tsne = TSNE(
    n_components=2,
    perplexity=tsne_perplexity,
    learning_rate="auto",
    init="pca",
    random_state=42,
)
Z = tsne.fit_transform(X_vis_in)

# -----------------------------
# 5) Plot
# -----------------------------
plt.figure(figsize=(8, 6))
plt.scatter(Z[:, 0], Z[:, 1], c=labels_vis, s=10)
plt.title(f"User Multi-hot -> SVD({svd_dim}) -> KMeans(k={k}) -> t-SNE(2D)")
plt.xlabel("t-SNE dim1")
plt.ylabel("t-SNE dim2")
plt.tight_layout()
plt.show()


# RQ1
Ï∂îÏ≤ú ÌååÏù¥ÌîÑÎùºÏù∏ ÏÜçÏóêÏÑú LLMÏù¥ ÏÉùÏÑ±Ìïú Îç∞Ïù¥ÌÑ∞Ïùò Ìé∏Ìñ•/ÌôòÍ∞Å ÌòÑÏÉÅ Î∂ÑÏÑù

포함 (covers): LLMRec, Augmentation

## LLMRec

### Bias

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import ast
import os
from collections import Counter

# Paths for the item-attribute (category) analysis.
base_dir = "data/books/books_llmrec_format/"
save_dir = "data/books/books_llmrec_format/poster/"

# 1. Load item attributes with explicit column names.
# header=0 discards the file's own header row and uses `names` instead;
# use header=None if the file has no header row.
if not os.path.exists(ITEM_ATTR_CSV):
    print(f"‚ùå ÌååÏùºÏùÑ Ï∞æÏùÑ Ïàò ÏóÜÏäµÎãàÎã§: {ITEM_ATTR_CSV}")
    # `exit()` is an interactive-shell helper installed by the `site` module and
    # should not be used in programs; raising stops only this cell/script.
    raise FileNotFoundError(ITEM_ATTR_CSV)

item_attr_df = pd.read_csv(ITEM_ATTR_CSV, names=["id", "brand", "title", "category"], header=0)
print(f"‚úÖ Îç∞Ïù¥ÌÑ∞ Î°úÎìú ÏôÑÎ£å. Ï¥ù {len(item_attr_df)}Í∞ú ÏïÑÏù¥ÌÖú")

# 2. Category parsing helper.
def parse_categories(cat_raw):
    """Convert a raw category cell into a list of category strings.

    Accepted forms:
      * a list literal such as "['Fiction', 'History']" -> ['Fiction', 'History']
      * a comma-separated string "Fiction, History"       -> ['Fiction', 'History']
      * a single value "Fiction"                          -> ['Fiction']
    Missing values (None/NaN) give []. A bracketed string that is not a valid
    Python literal is returned whole as a one-element list.
    """
    if pd.isna(cat_raw):
        return []

    cat_str = str(cat_raw).strip()

    if cat_str.startswith("[") and cat_str.endswith("]"):
        try:
            parsed = ast.literal_eval(cat_str)
            # Falsy entries (empty strings, None, 0) are dropped.
            return [str(c).strip() for c in parsed if c]
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            # Not a valid literal: keep the raw text instead of losing it.
            return [cat_str]

    if "," in cat_str:
        return [c.strip() for c in cat_str.split(",") if c.strip()]

    return [cat_str]

# 3. Collect categories: a set of distinct labels and a flat list (with
#    repeats) used for frequency counts.
all_categories = set()
all_categories_list = []

print("üîÑ Ïπ¥ÌÖåÍ≥†Î¶¨ ÌååÏã± ÏßÑÌñâ Ï§ë...")
# Iterate the column values directly: `.loc[label, "category"]` returns a
# Series instead of a scalar when the index contains duplicate labels.
for cat_raw in item_attr_df["category"]:
    cats = parse_categories(cat_raw)

    all_categories.update(cats)
    all_categories_list.extend(cats)

print(f"üîç Í≥†Ïú† Ïπ¥ÌÖåÍ≥†Î¶¨ Í∞úÏàò (Set): {len(all_categories)}")
print(f"üìä ÎàÑÏ†Å Ïπ¥ÌÖåÍ≥†Î¶¨ ÌÉúÍ∑∏ Ïàò (List): {len(all_categories_list)}")

# 4. Bar chart of the actual (parsed) category distribution, top K only.
def plot_real_distribution(data_list, title, top_k=20):
    """Draw, save and show a horizontal bar chart of the `top_k` most frequent labels.

    Args:
        data_list: flat list of labels (repeats allowed); frequencies are counted here.
        title: text inserted into the chart title.
        top_k: number of most frequent labels to draw.

    The PNG is written to the module-level `save_dir` as
    "Real_Category_Distribution_Parsed.png". If `data_list` is empty a notice
    is printed and nothing is drawn.
    """
    if not data_list:
        print("Îç∞Ïù¥ÌÑ∞Í∞Ä ÏóÜÏñ¥ ÏãúÍ∞ÅÌôîÌï† Ïàò ÏóÜÏäµÎãàÎã§.")
        return

    ranked = (
        pd.DataFrame.from_dict(Counter(data_list), orient='index', columns=['Count'])
        .reset_index()
        .rename(columns={'index': 'Category'})
        .sort_values(by='Count', ascending=False)
        .head(top_k)
        .reset_index(drop=True)
    )

    # Reverse so the most frequent label ends up at the top of the barh chart.
    labels_rev = ranked['Category'][::-1]
    counts_rev = ranked['Count'][::-1]

    plt.figure(figsize=(12, 8))
    bars = plt.barh(labels_rev, counts_rev, color='salmon', edgecolor='black', alpha=0.8)

    for rect in bars:
        w = rect.get_width()
        y_mid = rect.get_y() + rect.get_height()/2
        plt.text(w + (w * 0.01), y_mid, f'{int(w)}',
                 ha='left', va='center', fontsize=10, fontweight='bold')

    plt.title(f"Real Distribution of {title} (Top {top_k})", fontsize=15, fontweight='bold')
    plt.xlabel("Count", fontsize=12)
    plt.grid(axis='x', linestyle='--', alpha=0.5)
    plt.tight_layout()

    os.makedirs(save_dir, exist_ok=True)
    out_path = os.path.join(save_dir, "Real_Category_Distribution_Parsed.png")
    plt.savefig(out_path)
    print(f"üìä Í∑∏ÎûòÌîÑ Ï†ÄÏû• ÏôÑÎ£å: {out_path}")
    plt.show()

# Run: chart the 30 most frequent parsed book categories.
plot_real_distribution(all_categories_list, "Book Categories", top_k=30)

In [None]:
import matplotlib.pyplot as plt
import pickle
import json
import ast
import re
from collections import Counter
import pandas as pd
import os
# Paths: pickle of LLM-augmented user profiles, and the figure output folder.
file_path = "data/books/books_llmrec_format/augmented_user_profiling_dict_part5_step0_try0"
save_dir = "data/books/books_llmrec_format/poster/"

# Profile fields whose value distributions are charted.
target_keys = [
    'age',
    'gender',
    'liked category',
    'disliked category',
    'liked author',
    'country',
    'language',
]

# One fresh, empty accumulator list per field.
data_storage = dict((key, []) for key in target_keys)

# Load the pickled profiles; fall back to an empty mapping when the file is missing.
# NOTE: unpickling can execute code -- only load trusted local files.
if not os.path.exists(file_path):
    print(f"‚ùå ÌååÏùºÏùÑ Ï∞æÏùÑ Ïàò ÏóÜÏäµÎãàÎã§: {file_path}")
    augmented_dict = {}
else:
    with open(file_path, "rb") as fh:
        augmented_dict = pickle.load(fh)
    print(f"‚úÖ ÌååÏùº Î°úÎìú ÏÑ±Í≥µ: {file_path} (Ï¥ù {len(augmented_dict)}Î™Ö)")



import json
import ast
import re
def normalize_gender(val: str):
    """Map a free-form gender string to 'm', 'f', 'nb', 'other', or None.

    None means missing, placeholder, or declined. Checks run in priority order:
      1. placeholders ("unknown", "n/a", "?", ...)                 -> None
      2. an exact male/female token, trailing periods ignored
         ("m", "Male.", "woman")                                    -> 'm' / 'f'
      3. declined / unspecified phrases ("prefer not to say")      -> None
      4. non-binary variants ("non-binary", "nb", "agender")       -> 'nb'
      5. both sexes named ("m/f", "male or female")                -> 'other'
      6. Korean terms                                              -> 'm' / 'f'
      7. exactly one of male/female words inside a longer phrase
         ("female (inferred)", "likely male")                      -> 'm' / 'f'
      8. anything else                                             -> 'other'
    """
    if val is None:
        return None

    s = str(val).strip().lower()
    s = re.sub(r"\s+", " ", s)
    # Trailing periods ("m.", "male.") should not defeat the exact-token checks.
    core = s.rstrip(".").strip()

    # Placeholder / meaningless values.
    if core in {"", "unknown", "none", "n/a", "na", "null", "nil", "unspecified", "?"}:
        return None

    # Exact tokens.
    if re.fullmatch(r"(m|male|man|boy|masc|masculine)", core):
        return "m"
    if re.fullmatch(r"(f|female|woman|girl|fem|feminine)", core):
        return "f"

    # Declined / unspecified phrasings.
    if any(x in s for x in ["prefer not", "rather not", "no answer", "not say", "private", "not specified"]):
        return None

    # Non-binary / genderqueer etc. "nb" must be a standalone token so it cannot
    # fire inside unrelated words.
    if re.search(r"\bnb\b", s) or any(
        x in s for x in ["nonbinary", "non-binary", "non binary", "genderqueer", "gender fluid", "genderfluid", "agender"]
    ):
        return "nb"

    # Mixed notation: "m/f", "male/female", ...
    if re.search(r"\b(m|male)\b", s) and re.search(r"\b(f|female)\b", s):
        return "other"

    # Korean terms, in case they appear.
    if any(x in s for x in ["ÎÇ®", "ÎÇ®Ïûê", "ÎÇ®ÏÑ±"]):
        return "m"
    if any(x in s for x in ["Ïó¨", "Ïó¨Ïûê", "Ïó¨ÏÑ±"]):
        return "f"

    # A single unambiguous sex word inside a longer phrase (whole words only, so
    # "male" does not match inside "female" and "man" not inside "woman").
    has_male = re.search(r"\b(male|man|boy|masculine)\b", s) is not None
    has_female = re.search(r"\b(female|woman|girl|feminine)\b", s) is not None
    if has_male and not has_female:
        return "m"
    if has_female and not has_male:
        return "f"

    # Unrecognized: bucket as "other" (return None instead to drop these).
    return "other"

def _literal_eval_or_none(text):
    """Return ast.literal_eval(text), or None if `text` is not a valid Python literal."""
    try:
        return ast.literal_eval(text)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        return None


def clean_and_parse(profile_text):
    """Parse one LLM-generated profile into a dict.

    Accepts JSON or a Python dict literal, optionally wrapped in a ```json
    code fence. Returns the parsed dict, or None when `profile_text` is None,
    cannot be parsed, or parses to something other than a dict (callers read
    fields with `.get`, which a list/str/set result does not support).
    """
    if profile_text is None:
        return None

    s = str(profile_text).strip()

    # 1) Pull the {...} body out of a code fence if there is one.
    m = re.search(r"```(?:json)?\s*(\{.*\})\s*```", s, flags=re.DOTALL | re.IGNORECASE)
    if m:
        s = m.group(1)
    else:
        s = s.replace("```json", "").replace("```", "").strip()

    s = s.replace("\r", "").strip()

    parsed = None
    # 2) Single-quoted keys ("{'age': ...}") indicate a Python literal: try ast first.
    if re.search(r"\{\s*'", s):
        parsed = _literal_eval_or_none(s)

    if parsed is None:
        # 3) JSON, after removing \' escapes (invalid in JSON).
        try:
            parsed = json.loads(s.replace("\\'", "'"))
        except (ValueError, RecursionError):
            # 4) Last resort: ast again, for literals the heuristic above missed.
            parsed = _literal_eval_or_none(s)

    return parsed if isinstance(parsed, dict) else None

def split_authors(val):
    """Split a raw "liked author" value into lowercase author names.

    `val` may be one string or a list of strings; falsy entries are skipped.
    "&", ";", "|", "/", " and " and " with " count as separators in addition
    to commas. An entry that begins like "last, first" (one word, a comma,
    then a letter) is kept whole so "rowling, j. k." stays a single author.
    """
    entries = val if isinstance(val, list) else [val]
    names = []

    for entry in entries:
        if not entry:
            continue

        text = str(entry).strip().lower()
        # Normalize every separator to a comma (commas themselves are left as-is).
        text = re.sub(r"\s*(?:&|;|\||/)\s*", " , ", text)
        text = re.sub(r"\s+(?:and|with)\s+", " , ", text)

        if re.search(r"^[a-z\.\-]+,\s*[a-z]", text):
            # "last, first": the comma belongs to the name, keep it whole.
            names.append(text.strip())
        else:
            names.extend(piece.strip() for piece in text.split(",") if piece.strip())

    return names



# Main loop: collect per-field values from every parsed profile into data_storage.
for user_id, profile_text in augmented_dict.items():
    profile = clean_and_parse(profile_text)

    # Skip unparseable profiles, and anything that did not parse to a dict
    # (e.g. a bare list), since fields are read with `.get`.
    if not profile or not isinstance(profile, dict):
        print(f"‚ö†Ô∏è ÌååÏã± Ïã§Ìå®: User {user_id}")
        continue

    for key in target_keys:
        val = profile.get(key, None)
        if not val:
            continue

        if key == "liked author":
            data_storage[key].extend(split_authors(val))
            continue

        if isinstance(val, list):
            cleaned_items = [str(item).strip().lower() for item in val if item]
            if key == 'gender':
                # List-valued genders get the same normalization as string ones.
                cleaned_items = [g for g in map(normalize_gender, cleaned_items) if g is not None]
            data_storage[key].extend(cleaned_items)
            continue

        # Numeric scalars (e.g. an integer age) would otherwise be dropped silently.
        if isinstance(val, (int, float)) and not isinstance(val, bool):
            val = str(val)
        if not isinstance(val, str):
            continue

        v = val.strip()
        if v.lower() in ['unknown', 'none', 'n/a', '']:
            continue

        if key == 'gender':
            g = normalize_gender(v)
            if g is not None:
                data_storage[key].append(g)
            continue

        if ',' in v:
            # Drop empty pieces from inputs like "a, , b" or a trailing comma.
            data_storage[key].extend(x.strip().lower() for x in v.split(',') if x.strip())
        else:
            data_storage[key].append(v.lower())



# Bar-chart helper for one profile field.
def plot_distribution(category_name, data_list, top_k=10):
    """Save a vertical bar chart of the `top_k` most frequent values of one field.

    The chart is written to the module-level `save_dir` as
    "Predicted_<category_name with spaces as underscores>.png". The figure is
    neither shown nor closed. Prints a notice and returns early when
    `data_list` is empty.
    """
    if not data_list:
        print(f"‚ö†Ô∏è {category_name} - Îç∞Ïù¥ÌÑ∞Í∞Ä ÏóÜÏäµÎãàÎã§.")
        return

    top = (
        pd.DataFrame.from_dict(Counter(data_list), orient='index', columns=['Count'])
        .reset_index()
        .rename(columns={'index': 'Category'})
        .sort_values(by='Count', ascending=False)
        .head(top_k)
        .reset_index(drop=True)
    )

    pretty = category_name.title()

    plt.figure(figsize=(12, 6))
    bars = plt.bar(top['Category'], top['Count'], color='skyblue', edgecolor='black', alpha=0.8)

    # Count label just above each bar.
    for rect in bars:
        h = rect.get_height()
        x_mid = rect.get_x() + rect.get_width()/2
        plt.text(x_mid, h + (h * 0.01), int(h),
                 ha='center', va='bottom', fontsize=10, fontweight='bold')

    plt.title(f"Distribution of Predicted {pretty} (Top {top_k})", fontsize=15, fontweight='bold')
    plt.xlabel(pretty, fontsize=12)
    plt.ylabel("Count", fontsize=12)
    plt.xticks(rotation=45, ha='right', fontsize=10)
    plt.grid(axis='y', linestyle='--', alpha=0.5)
    plt.tight_layout()

    os.makedirs(save_dir, exist_ok=True)
    filename = "Predicted_" + category_name.replace(' ', '_') + ".png"
    plt.savefig(os.path.join(save_dir, filename))
    print(f"üìä Ï†ÄÏû• ÏôÑÎ£å: {filename}")

# Run: one chart per profile field. Fields with many distinct values
# (categories, authors) show more bars than small ones such as gender.
print("\n========== ÏãúÍ∞ÅÌôî ÏãúÏûë ==========")

for field_name, field_topk in [
    ('age', 15),
    ('gender', 5),
    ('country', 15),
    ('language', 10),
    ('liked category', 20),
    ('disliked category', 20),
    ('liked author', 20),
]:
    plot_distribution(field_name, data_storage[field_name], top_k=field_topk)

In [None]:
# Export one line per user: "<user_id>\t<author1> | <author2> | ...", using the
# split_authors splitting and de-duplicating authors in first-seen order.
OUT_TXT = os.path.join(
    "data/books/books_llmrec_format",
    "liked_author_by_user.txt"
)
os.makedirs(os.path.dirname(OUT_TXT), exist_ok=True)

with open(OUT_TXT, "w", encoding="utf-8") as f:
    for user_id, profile_text in augmented_dict.items():
        profile = clean_and_parse(profile_text)

        # Skip unparseable or non-dict results (a list/str has no `.get`).
        if not isinstance(profile, dict):
            continue

        val = profile.get("liked author", None)
        if not val:
            continue

        # dict.fromkeys keeps the first occurrence of each author, preserving order.
        uniq_authors = list(dict.fromkeys(split_authors(val)))
        if not uniq_authors:
            continue

        f.write(f"{user_id}\t" + " | ".join(uniq_authors) + "\n")

print(f"‚úÖ liked author txt Ï†ÄÏû• ÏôÑÎ£å: {OUT_TXT}")


In [None]:
import os
import re
import json
import ast
import pickle
from collections import Counter
from typing import Dict, Any, Optional, List, Tuple

import pandas as pd
import matplotlib.pyplot as plt

# -----------------------------
# Configuration
# -----------------------------
# Profile pickle analysed by this cell, and the root folder for per-gender charts.
file_path = "data/books/books_llmrec_format/augmented_user_profiling_dict_part1_step0"
base_save_dir = "data/books/books_llmrec_format/poster/Gender"

# Profile fields charted for each gender group.
target_keys = [
    "age",
    "gender",
    "liked category",
    "disliked category",
    "liked author",
    "country",
    "language",
]


# -----------------------------
# Parsing / Normalization
# -----------------------------
def clean_and_parse(profile_text: str, model: str = "gpt-4o") -> Optional[Dict[str, Any]]:
    """Parse one LLM-generated profile string into a dict.

    For model="gpt-4o" the text is first stripped of ''' / ```json / ```
    fences and newlines, and "children's" / "Children\\'s" are rewritten to
    "childrens". Parsing is then attempted, in order, as:
      1. plain JSON (handles null/true/false and apostrophes inside
         double-quoted values),
      2. JSON after turning unescaped single quotes into double quotes,
      3. a Python literal via ast.literal_eval.

    Returns the dict, or None when `profile_text` is not a str, cannot be
    parsed, or parses to something other than a dict.
    """
    if not isinstance(profile_text, str):
        return None

    if model == "gpt-4o":
        cleaned = (
            profile_text.strip()
            .replace("'''", "")
            .replace("```json", "")
            .replace("```", "")
            .replace("\n", "")
            .replace("\r", "")
            .replace("children's", "childrens")
            .replace("Children\\'s", "childrens")
        )
    else:
        cleaned = profile_text.strip()

    try:
        parsed = json.loads(cleaned)
    except json.JSONDecodeError:
        try:
            # Python-style quoting -> JSON quoting (escaped \' is left alone).
            parsed = json.loads(re.sub(r"(?<!\\)'", '"', cleaned))
        except json.JSONDecodeError:
            try:
                parsed = ast.literal_eval(cleaned)
            except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                return None

    return parsed if isinstance(parsed, dict) else None

def normalize_gender(raw: Any) -> Optional[str]:
    """Reduce a raw gender field to 'm', 'f', or None.

    A non-empty list is reduced to its first element. Non-strings give None.
    Known words ("male", "woman", "boy", ...) map directly; otherwise the
    first letter decides ('m' or 'f'), and anything else gives None.
    """
    if isinstance(raw, list) and raw:
        raw = raw[0]

    if not isinstance(raw, str):
        return None

    token = raw.strip().lower()

    if token in {"m", "male", "man", "boy", "masculine"}:
        return "m"
    if token in {"f", "female", "woman", "girl", "feminine"}:
        return "f"

    # Fallback on the first character (empty string -> "" -> None).
    first = token[:1]
    return first if first in ("m", "f") else None


# Canonical spelling for genre labels that appear in several variants.
GENRE_NORMALIZATION = {
    "thrillers & suspense": "thriller & suspense",
    "thriller and suspense": "thriller & suspense",
    "thrillers and suspense": "thriller & suspense",
}

def normalize_genre(s: str) -> str:
    """Trim and lowercase a genre label, then map known variants to one spelling."""
    lowered = s.strip().lower()
    return GENRE_NORMALIZATION.get(lowered, lowered)

def maybe_normalize_by_key(key: str, s: str) -> str:
    """Apply genre normalization to category fields only; return other values unchanged."""
    # Extend this tuple to normalize more category-like fields.
    is_genre_field = key in ("liked category", "disliked category")
    return normalize_genre(s) if is_genre_field else s

def normalize_item_value(key: str, val: Any) -> List[str]:
    """Turn one raw profile field value into a list of cleaned strings.

    - None            -> []
    - list            -> each non-None element, stringified
    - str with commas -> each comma-separated piece
    - other str       -> the whole string, or only its first character when
                         key == "gender"
    - anything else   -> its str() form
    Each piece is stripped and lowercased. Empty pieces and "unknown" /
    "none" / "n/a" are dropped. Surviving pieces go through
    maybe_normalize_by_key, which canonicalizes category fields.
    """
    placeholders = {"unknown", "none", "n/a"}
    cleaned: List[str] = []

    def add(piece: str) -> None:
        piece = piece.strip().lower()
        if piece and piece not in placeholders:
            cleaned.append(maybe_normalize_by_key(key, piece))

    if val is None:
        return cleaned

    if isinstance(val, list):
        for element in val:
            if element is not None:
                add(str(element))
        return cleaned

    if not isinstance(val, str):
        add(str(val))
        return cleaned

    text = val.strip()
    if not text or text.lower() in placeholders:
        return cleaned

    if "," in text:
        for piece in text.split(","):
            add(piece)
    elif key == "gender":
        add(text[0])  # keep only the leading m/f character
    else:
        add(text)
    return cleaned



# -----------------------------
# Plotting
# -----------------------------
def plot_distribution(
    key: str,
    data_list: List[str],
    save_dir: str,
    gender: str,
    n_users: int,
    top_k: int = 10,
):
    """Save a bar chart of the `top_k` most frequent values of `key` for one gender group.

    The title shows the group label and its user count. The PNG is written
    at dpi=200 to `save_dir` (created if missing) as
    "Predicted_<key with spaces as underscores>.png", and the figure is closed.
    Prints a notice and returns early when `data_list` is empty.
    """
    if not data_list:
        print(f"{key}: no data")
        return

    freq_table = pd.DataFrame.from_dict(Counter(data_list), orient="index", columns=["Count"])
    freq_table = freq_table.reset_index().rename(columns={"index": "Category"})
    top = freq_table.sort_values(by="Count", ascending=False).head(top_k).reset_index(drop=True)

    os.makedirs(save_dir, exist_ok=True)

    label = key.replace('_', ' ').title()

    plt.figure(figsize=(12, 6))
    bars = plt.bar(top["Category"], top["Count"], alpha=0.85, edgecolor="black")

    for rect in bars:
        height = rect.get_height()
        plt.text(
            rect.get_x() + rect.get_width() / 2,
            height + (height * 0.01),
            int(height),
            ha="center",
            va="bottom",
            fontsize=10,
            fontweight="bold",
        )

    plt.title(f"{gender} (N={n_users}) ‚Äî {label} (Top {top_k})", fontsize=15, fontweight="bold")
    plt.xlabel(label, fontsize=12)
    plt.ylabel("Count", fontsize=12)
    plt.xticks(rotation=45, ha="right", fontsize=10)
    plt.grid(axis="y", linestyle="--", alpha=0.5)
    plt.tight_layout()

    save_path = os.path.join(save_dir, "Predicted_" + key.replace(' ', '_') + ".png")
    plt.savefig(save_path, dpi=200)
    plt.close()
    print(f"saved: {save_path}")



def get_topk_for_key(key: str) -> int:
    """Return how many of the most frequent values to chart for a field (default 10)."""
    topk_table = {
        "age": 15,
        "gender": 5,
        "country": 15,
        "language": 10,
        "liked category": 20,
        "disliked category": 20,
        "liked author": 20,
    }
    return topk_table.get(key, 10)


# -----------------------------
# Main
# -----------------------------
def main() -> None:
    """Split parsed profiles by predicted gender and chart every field per group.

    Loads the pickle at `file_path`, keeps profiles whose gender normalizes to
    'm' or 'f', and writes one chart per key in `target_keys` under
    `base_save_dir`/Male and `base_save_dir`/Female. Unparseable or non-dict
    profiles are counted as parse failures; other genders are counted and excluded.
    """
    if not os.path.exists(file_path):
        print(f"file not found: {file_path}")
        return

    # NOTE: unpickling can execute code -- only load trusted local files.
    with open(file_path, "rb") as f:
        augmented_dict = pickle.load(f)

    print(f"loaded: {file_path} (users: {len(augmented_dict)})")

    profiles_by_gender: Dict[str, List[Dict[str, Any]]] = {"m": [], "f": []}
    failed = 0
    unknown_gender = 0

    for profile_text in augmented_dict.values():
        profile = clean_and_parse(profile_text)
        # A non-dict parse result (e.g. a JSON list) cannot be read with .get.
        if not profile or not isinstance(profile, dict):
            failed += 1
            continue

        g = normalize_gender(profile.get("gender"))
        if g not in {"m", "f"}:
            unknown_gender += 1
            continue

        # Store the normalized code so the gender chart counts "m"/"f".
        profile["gender"] = g
        profiles_by_gender[g].append(profile)

    print(f"parsed_ok_male: {len(profiles_by_gender['m'])}")
    print(f"parsed_ok_female: {len(profiles_by_gender['f'])}")
    print(f"parse_failed: {failed}")
    print(f"unknown_gender_excluded: {unknown_gender}")

    # Gender code -> display name, also used as the output sub-folder name.
    group_names = {"m": "Male", "f": "Female"}

    for g, gender_name in group_names.items():
        group_profiles = profiles_by_gender[g]
        save_dir = os.path.join(base_save_dir, gender_name)

        data_storage: Dict[str, List[str]] = {k: [] for k in target_keys}
        for profile in group_profiles:
            for key in target_keys:
                data_storage[key].extend(normalize_item_value(key, profile.get(key)))

        for key in target_keys:
            plot_distribution(
                key=key,
                data_list=data_storage[key],
                save_dir=save_dir,
                gender=gender_name,
                n_users=len(group_profiles),
                top_k=get_topk_for_key(key),
            )



# Entry point for this cell: run the per-gender profile analysis.
if __name__ == "__main__":
    main()


### RQ3

In [None]:
import matplotlib.pyplot as plt
import pickle
import json
import ast
import re
from collections import Counter
import pandas as pd
import os

# Input folder with the per-step profile pickles, and the output folder for comparison figures.
base_dir = "data/books/books_llmrec_format/"
save_dir = "data/books/books_llmrec_format/poster/step_comparison/"
os.makedirs(save_dir, exist_ok=True)

# Pickles to compare: part5 at steps 0 through 4.
file_names = list(map("augmented_user_profiling_dict_part5_step{}".format, range(5)))

# Profile fields to chart.
target_keys = [
    'age',
    'gender',
    'liked category',
    'disliked category',
    'liked author',
    'country',
    'language',
]

# Parse one raw LLM profile string into a dict (None when unusable).
def clean_and_parse(profile_text, model="gpt-4o"):
    """Parse an LLM-generated profile string into a dict.

    For model="gpt-4o" the text is first stripped of ''' / ```json / ```
    fences and newlines, and "children's" / "Children's" become "childrens".
    Parsing is then tried as plain JSON, as JSON after turning unescaped single
    quotes into double quotes, and finally as a Python literal of the
    *unmodified* cleaned text.

    Returns the dict, or None when `profile_text` is not a str, cannot be
    parsed, or parses to something other than a dict.
    """
    if not isinstance(profile_text, str):
        return None

    if model == "gpt-4o":
        cleaned = (
            profile_text.strip()
            .replace("'''", "")
            .replace('```json', '')
            .replace('```', '')
            .replace('\n', '')
            .replace('\r', '')
            .replace("children's", "childrens")
            .replace("Children's", "childrens")
        )
    else:
        cleaned = profile_text.strip()

    try:
        parsed = json.loads(cleaned)
    except json.JSONDecodeError:
        try:
            parsed = json.loads(re.sub(r"(?<!\\)'", '"', cleaned))
        except json.JSONDecodeError:
            # Use the original quoting: the quote-swapped text is rarely a valid literal.
            try:
                parsed = ast.literal_eval(cleaned)
            except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                return None

    return parsed if isinstance(parsed, dict) else None

# Load every step's profiles and gather per-field value lists.
def _extract_field_values(key, val):
    """Turn one raw profile field value into a list of lowercase strings ([] if unusable)."""
    if not val:
        return []
    if isinstance(val, list):
        return [str(item).strip().lower() for item in val if item]
    # Numeric scalars (e.g. an integer age) would otherwise be dropped silently.
    if isinstance(val, (int, float)) and not isinstance(val, bool):
        val = str(val)
    if not isinstance(val, str):
        return []

    val = val.strip()
    if val.lower() in ['unknown', 'none', 'n/a', '']:
        return []
    if ',' in val:
        # Drop empty pieces from inputs like "a, , b".
        return [item.strip().lower() for item in val.split(',') if item.strip()]
    if key == 'gender':
        return [val.lower()[0]]  # keep only the leading m/f character
    return [val.lower()]


def load_all_steps_data(base_dir, file_names, target_keys):
    """Load each step's profile pickle and collect per-field value lists.

    Returns {step_name: {key: [values, ...]}} where step_name is the part of
    the file name after "_dict_" (e.g. "part5_step0"). A missing file is
    reported and yields empty lists for every key. Unparseable or non-dict
    profiles are skipped.
    """
    all_steps_data = {}

    for fname in file_names:
        step_name = fname.split('_dict_')[-1]
        file_path = os.path.join(base_dir, fname)
        step_data = {key: [] for key in target_keys}

        if os.path.exists(file_path):
            print(f"üìÇ ÌååÏùº Î°úÎìú Ï§ë: {fname} ...", end=" ")
            # NOTE: unpickling can execute code -- only load trusted local files.
            with open(file_path, "rb") as f:
                augmented_dict = pickle.load(f)
            print(f"ÏôÑÎ£å (Ï¥ù {len(augmented_dict)}Î™Ö)")

            for profile_text in augmented_dict.values():
                profile = clean_and_parse(profile_text)
                if not profile or not isinstance(profile, dict):
                    continue
                for key in target_keys:
                    step_data[key].extend(_extract_field_values(key, profile.get(key)))
        else:
            print(f"‚ùå ÌååÏùºÏùÑ Ï∞æÏùÑ Ïàò ÏóÜÏäµÎãàÎã§: {file_path}")

        all_steps_data[step_name] = step_data

    return all_steps_data

# Combined comparison figure: one panel per step in a 2x3 grid.
def plot_grid_distribution(target_key, all_steps_data, top_k=10):
    """Chart the `top_k` most frequent values of one field for up to 5 steps.

    Steps are taken in sorted-name order and drawn into a 2x3 grid; panels
    without a step are hidden, and a step with no data shows "No Data". The
    figure is saved to the module-level `save_dir` as
    "Comparison_<target_key with spaces as underscores>.png" and closed.
    """
    step_names = sorted(all_steps_data.keys())

    _fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    panels = axes.flatten()  # 1-D view so panels can be indexed by step number

    print(f"\nüé® Í∑∏Î¶¨Îäî Ï§ë: {target_key}")

    # Only the first 5 steps are drawn.
    for idx, step_name in enumerate(step_names[:5]):
        values = all_steps_data[step_name][target_key]
        ax = panels[idx]

        if not values:
            ax.text(0.5, 0.5, 'No Data', ha='center', va='center')
            ax.set_title(f"{step_name}", fontsize=12, fontweight='bold')
            continue

        top = (
            pd.DataFrame.from_dict(Counter(values), orient='index', columns=['Count'])
            .reset_index()
            .rename(columns={'index': 'Category'})
            .sort_values(by='Count', ascending=False)
            .head(top_k)
            .reset_index(drop=True)
        )

        bars = ax.bar(top['Category'], top['Count'], color='skyblue', edgecolor='black', alpha=0.8)
        for rect in bars:
            h = rect.get_height()
            ax.text(rect.get_x() + rect.get_width()/2, h + (h * 0.01), int(h),
                    ha='center', va='bottom', fontsize=9)

        panel_title = step_name.replace("part5_", "").replace("_", " ").upper()
        ax.set_title(f"{panel_title}", fontsize=12, fontweight='bold')
        ax.tick_params(axis='x', rotation=45, labelsize=9)
        ax.grid(axis='y', linestyle='--', alpha=0.5)

    # Hide panels that have no step (empty range when there are 6+ steps).
    for j in range(len(step_names), 6):
        panels[j].axis('off')

    plt.suptitle(f"Distribution Changes: {target_key.title()} (Step 0 - 4)", fontsize=20, fontweight='bold')
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])  # leave room for the suptitle

    filename = f"Comparison_{target_key.replace(' ', '_')}.png"
    plt.savefig(os.path.join(save_dir, filename))
    print(f"‚úÖ Ï†ÄÏû• ÏôÑÎ£å: {filename}")
    plt.close()  # free the figure; call plt.show() before this to display it

# ==========================================
# Main driver
# ==========================================

# 1. Load all steps.
print("========== Îç∞Ïù¥ÌÑ∞ Î°úÎî© ÏãúÏûë ==========")
db = load_all_steps_data(base_dir, file_names, target_keys)

# 2. One comparison grid per field, each with its own top-K.
print("\n========== Í∑∏ÎûòÌîÑ ÏÉùÏÑ± ÏãúÏûë ==========")

grid_plan = [
    ('age', 15),
    ('gender', 5),
    ('country', 10),
    ('language', 10),
    ('liked category', 15),
    ('disliked category', 15),
    ('liked author', 15),
]
for field_name, field_topk in grid_plan:
    plot_grid_distribution(field_name, db, top_k=field_topk)

print("\nüéâ Î™®Îì† Í≥ºÏ†ïÏù¥ ÏôÑÎ£åÎêòÏóàÏäµÎãàÎã§.")

### Hallucination

In [None]:
import os
import re
import json
import ast
import pickle
from collections import Counter
from typing import Any, Dict, Optional

import pandas as pd

# =========================
# Paths (Books)
# =========================
# NOTE(review): despite the PART1/PART5 names, both inputs are part5 step4 profile
# pickles — try1 is bound to PART1_PATH and try0 to PART5_PATH — so this cell
# compares two generation attempts of the same step. The "part1"/"part5" labels
# in output file names and columns follow the variable names. Confirm intended.
PART1_PATH = "data/books/books_llmrec_format/augmented_user_profiling_dict_part5_step4_try1"
PART5_PATH = "data/books/books_llmrec_format/augmented_user_profiling_dict_part5_step4_try0"

# Output directory for the consistency report (created if missing).
OUT_DIR = "data/books/books_llmrec_format/poster/"
os.makedirs(OUT_DIR, exist_ok=True)

# Per-user comparison table, per-field mismatch counts, and ids of users with >=1 mismatch.
OUT_SUMMARY_CSV = os.path.join(OUT_DIR, "profile_rawstring_consistency_part1_vs_part5_summary.csv")
OUT_COUNTS_CSV  = os.path.join(OUT_DIR, "profile_rawstring_consistency_part1_vs_part5_counts.csv")
OUT_USERS_TXT   = os.path.join(OUT_DIR, "profile_rawstring_consistency_part1_vs_part5_mismatch_users.txt")

# =========================
# Target keys (Books)
# =========================
# Profile fields compared between the two pickles; matched after key normalization.
TARGET_KEYS = [
    "age", "gender",
    "liked category", "disliked category",
    "liked author", "country", "language",
]

# =========================
# Load pickles
# =========================
def load_pickle(path: str) -> Dict[str, Any]:
    """Unpickle the mapping stored at ``path`` and return it with every key cast to ``str``.

    Raises:
        FileNotFoundError: if ``path`` does not exist.

    NOTE: unpickling can execute arbitrary code — only load trusted, locally produced files.
    """
    if os.path.exists(path):
        with open(path, "rb") as handle:
            loaded = pickle.load(handle)
        return {str(key): value for key, value in loaded.items()}
    raise FileNotFoundError(f"File not found: {path}")

part1 = load_pickle(PART1_PATH)
part5 = load_pickle(PART5_PATH)


def _user_sort_key(uid: str):
    """Sort key for user ids: numeric ids first (numerically), then other ids (lexicographically).

    A key that returns ``int`` for some ids and ``str`` for others makes ``sorted``
    raise TypeError as soon as both kinds occur. Tagging each id with a group index
    keeps all keys comparable. The raw id is the final tie-breaker, so "01" vs "1"
    order deterministically.
    """
    if uid.isdecimal():
        return (0, int(uid), uid)
    return (1, 0, uid)


common_users = sorted(set(part1) & set(part5), key=_user_sort_key)
print(f"‚úÖ part1 users: {len(part1)}")
print(f"‚úÖ part5 users: {len(part5)}")
print(f"‚úÖ common users: {len(common_users)}")

# =========================
# Parse + key normalization only
# (values are NOT normalized)
# =========================
def clean_and_parse(profile_text: Any) -> Optional[Dict[str, Any]]:
    """Best-effort parse of an LLM-generated profile string into a dict.

    Code fences (``'''``, ```` ```json ````, ```` ``` ````) are removed and all whitespace
    runs are collapsed to single spaces. The text is then tried in order as:
    1. JSON,
    2. a Python literal (``ast.literal_eval``),
    3. JSON after turning unescaped single quotes into double quotes.
    The first attempt that parses wins. A non-dict result gives ``None``.
    ``None`` is also returned for ``None`` input or when every attempt fails.
    """
    if profile_text is None:
        return None

    text = str(profile_text).strip()
    for fence in ("'''", "```json", "```"):
        text = text.replace(fence, "")
    # \s also covers \n and \r, so this collapses line breaks too.
    text = re.sub(r"\s+", " ", text).strip()

    def _dict_or_none(obj: Any) -> Optional[Dict[str, Any]]:
        return obj if isinstance(obj, dict) else None

    try:
        return _dict_or_none(json.loads(text))
    except Exception:
        pass

    try:
        return _dict_or_none(ast.literal_eval(text))
    except Exception:
        pass

    try:
        return _dict_or_none(json.loads(re.sub(r"(?<!\\)'", '"', text)))
    except Exception:
        return None

def normalize_key_name(k: str) -> str:
    """Canonicalize a profile key: lower-case, trim, and turn each space into an underscore."""
    return "_".join(str(k).lower().strip().split(" "))


# Plural spellings that Books profiles may use, folded onto the singular key.
# Only keys are mapped; values are never touched.
_KEY_ALIASES = {
    "liked_categories": "liked_category",
    "disliked_categories": "disliked_category",
    "liked_authors": "liked_author",
}


def normalize_keys(profile: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of ``profile`` with canonical keys and unchanged values.

    Keys pass through ``normalize_key_name``, then known plural aliases are mapped to
    their singular form. When several original keys collapse to the same canonical
    key, the one that appears last in ``profile`` wins.
    """
    def _canonical(key: Any) -> str:
        name = normalize_key_name(key)
        return _KEY_ALIASES.get(name, name)

    return {_canonical(key): value for key, value in profile.items()}

# Target keys in the same canonical form as normalize_key_name produces, so
# they can be looked up directly in normalized profiles.
TARGET_KEYS_NORM = list(map(normalize_key_name, TARGET_KEYS))

# =========================
# Raw-string comparison
# =========================
def val_to_raw(v: Any, sort_list: bool = False) -> Any:
    """Convert a profile value into a near-raw form for exact comparison.

    - ``None`` stays ``None``.
    - ``str``: only stripped.
    - ``list``: every element converted with ``str`` (elements are not stripped);
      with ``sort_list=True`` the list is sorted so element order is ignored.
    - anything else: ``str(v).strip()``.
    """
    if isinstance(v, list):
        as_text = [str(elem) for elem in v]
        if sort_list:
            as_text.sort()
        return as_text
    if isinstance(v, str):
        return v.strip()
    return None if v is None else str(v).strip()


def raw_equal(a: Any, b: Any) -> bool:
    """Exact (``==``) equality of two values produced by ``val_to_raw``; no fuzzy matching."""
    return a == b

# Decide whether list values must also match in element order to count as "equal":
# - True: order is ignored (both lists are sorted before comparing)
# - False: order matters (lists are compared as-is)
SORT_LIST_BEFORE_COMPARE = False

# One output row per common user; mismatch_users collects ids with >=1 differing target field.
rows = []
mismatch_users = []

for uid in common_users:
    p1_raw = clean_and_parse(part1[uid])
    p5_raw = clean_and_parse(part5[uid])

    # A profile that fails to parse — or parses to an empty dict, which is falsy —
    # is recorded as PARSE_FAIL and excluded from the field comparison.
    if not p1_raw or not p5_raw:
        rows.append({
            "user_id": uid,
            "parse_ok_part1": bool(p1_raw),
            "parse_ok_part5": bool(p5_raw),
            "mismatch_count": None,
            "mismatch_keys": "PARSE_FAIL",
        })
        continue

    p1 = normalize_keys(p1_raw)
    p5 = normalize_keys(p5_raw)

    mism_keys = []
    detail = {}

    for key_norm in TARGET_KEYS_NORM:
        # A key absent from a profile compares as None.
        v1 = val_to_raw(p1.get(key_norm, None), sort_list=SORT_LIST_BEFORE_COMPARE)
        v2 = val_to_raw(p5.get(key_norm, None), sort_list=SORT_LIST_BEFORE_COMPARE)

        if not raw_equal(v1, v2):
            mism_keys.append(key_norm)

        # Keep both raw values so the summary CSV shows what differed.
        detail[f"{key_norm}_part1"] = v1
        detail[f"{key_norm}_part5"] = v2

    mismatch_count = len(mism_keys)
    if mismatch_count > 0:
        mismatch_users.append(uid)

    row = {
        "user_id": uid,
        "parse_ok_part1": True,
        "parse_ok_part5": True,
        "mismatch_count": mismatch_count,
        "mismatch_keys": ",".join(mism_keys) if mism_keys else "",
    }
    row.update(detail)
    rows.append(row)

df = pd.DataFrame(rows)
# With no common users, pd.DataFrame([]) has no columns and the filters below would
# raise KeyError. Guarantee the fixed columns exist (no-op when rows is non-empty).
for _col in ("user_id", "parse_ok_part1", "parse_ok_part5", "mismatch_count", "mismatch_keys"):
    if _col not in df.columns:
        df[_col] = pd.Series(dtype="object")

# =========================
# Save outputs
# =========================
df.to_csv(OUT_SUMMARY_CSV, index=False)
print(f"‚úÖ Saved summary CSV: {OUT_SUMMARY_CSV}")

# Per-field mismatch counts, only over users whose two profiles both parsed.
df_ok = df[df["parse_ok_part1"].eq(True) & df["parse_ok_part5"].eq(True)].copy()

field_counts = Counter(
    k.strip()
    for keys in df_ok["mismatch_keys"].fillna("").tolist()
    if keys
    for k in keys.split(",")
    if k.strip()
)

counts_df = pd.DataFrame(
    sorted(field_counts.items(), key=lambda x: x[1], reverse=True),
    columns=["field", "mismatch_user_count"],
)
counts_df.to_csv(OUT_COUNTS_CSV, index=False)
print(f"‚úÖ Saved mismatch counts CSV: {OUT_COUNTS_CSV}")

with open(OUT_USERS_TXT, "w") as f:
    f.writelines(f"{uid}\n" for uid in mismatch_users)
print(f"‚úÖ Saved mismatch users list: {OUT_USERS_TXT}")

# Users with at least one differing target field (computed once, reused below).
has_mismatch = df_ok["mismatch_count"].fillna(0) > 0

print("\n===== QUICK STATS =====")
print("parsed OK users:", len(df_ok))
print("users with >=1 mismatch:", int(has_mismatch.sum()))
if has_mismatch.any():
    print("avg mismatch_count among mismatched:",
          float(df_ok.loc[has_mismatch, "mismatch_count"].mean()))
print("\nTop mismatch fields:")
print(counts_df.head(10))


## Augmentation

### Bias

In [None]:
import os
import json
import pickle
from collections import Counter
from pathlib import Path
from typing import Any

import pandas as pd
import matplotlib.pyplot as plt


# =========================
# Paths (Books)
# =========================
# JSON-lines item metadata (one object per line; read by load_books_item_meta_jsonl).
META_JSONL_PATH = "data/books/item_meta_2017_kcore10_user_item_split_filtered.json"
# Pickled augmentation triplets (part1, step0, "AB" variant) whose positive items are analysed.
PKL_PATH = "Augmentation/data/books/aug_triplets_part1_step0_AB.pkl"

# Where the CSVs and figures of this cell are written (created if missing).
OUT_DIR = "Augmentation/data/books/results"
os.makedirs(OUT_DIR, exist_ok=True)


def load_pickle(path: str):
    """Return the object unpickled from ``path`` exactly as stored (keys untouched).

    NOTE: unpickling can execute arbitrary code — only use on trusted local files.
    """
    with open(path, "rb") as stream:
        obj = pickle.load(stream)
    return obj


def load_books_item_meta_jsonl(meta_path: str) -> pd.DataFrame:
    """Load Books item metadata from a JSON-lines file.

    Example line:
        {"item_id": 0, "asin": "...", "title": "...", "category": ["Books", "Literature & Fiction", "Genre Fiction"], ...}

    Returns:
        DataFrame with columns
        - item_id (int)
        - title (str or None)
        - category (list[str] or None)
        Lines that are blank, not valid JSON, not a JSON object, or lack an
        int-convertible ``item_id`` are skipped. The three columns exist even when
        no line is kept.
    """
    rows = []
    # JSON text is UTF-8; don't depend on the platform default encoding
    # (titles and categories can contain non-ASCII characters).
    with open(meta_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError:
                continue
            # A line holding a JSON array/scalar has no .get(); skip it like other bad lines.
            if not isinstance(obj, dict) or "item_id" not in obj:
                continue
            try:
                item_id = int(obj["item_id"])
            except (TypeError, ValueError):
                continue
            rows.append({
                "item_id": item_id,
                "title": obj.get("title", None),
                "category": obj.get("category", None),
            })
    return pd.DataFrame(rows, columns=["item_id", "title", "category"])


def _category_to_list(cat: Any) -> list:
    """Turn one raw ``category`` cell into a list of non-empty, stripped tokens."""
    import ast  # literal-only parser: never executes code, unlike eval

    if cat is None or (isinstance(cat, float) and pd.isna(cat)):
        return []
    if isinstance(cat, str):
        text = cat.strip()
        # A stringified list such as "['Books', 'Fiction']" (e.g. after a CSV round
        # trip) is parsed back into a list instead of becoming one junk token.
        if text.startswith("[") and text.endswith("]"):
            try:
                parsed = ast.literal_eval(text)
            except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                parsed = None
            if isinstance(parsed, (list, tuple)):
                cat = parsed
        if isinstance(cat, str):
            return [text] if text else []
    if isinstance(cat, (list, tuple)):
        return [str(x).strip() for x in cat if x is not None and str(x).strip()]
    text = str(cat).strip()
    return [text] if text else []


def build_category_long_table(
    meta_df: pd.DataFrame,
    mode: str = "leaf",
) -> pd.DataFrame:
    """Build a long-form (item_id, category_token) table.

    mode:
    - "leaf": only the last token of the category path
      e.g. ["Books", "Literature & Fiction", "Genre Fiction"] -> "Genre Fiction"
    - "all": every token of the path becomes its own row (counted separately)
      e.g. -> "Books", "Literature & Fiction", "Genre Fiction"
    - "path": the whole path joined with " > " as a single token
      e.g. -> "Books > Literature & Fiction > Genre Fiction"

    The ``category`` cell may be a list/tuple, a plain string (one token), a
    stringified Python list, or missing. Items without a usable category are omitted.

    Raises:
        ValueError: for an unknown ``mode``. This is checked up front, so the error is
        raised even when ``meta_df`` is empty or has no categories.
    """
    if mode not in ("leaf", "all", "path"):
        raise ValueError(f"Unknown mode: {mode}")

    rows = []
    for _, r in meta_df.iterrows():
        item_id = int(r["item_id"])
        cat_list = _category_to_list(r.get("category", None))
        if not cat_list:
            continue

        if mode == "leaf":
            rows.append((item_id, cat_list[-1]))
        elif mode == "all":
            rows.extend((item_id, tok) for tok in cat_list)
        else:  # "path"
            rows.append((item_id, " > ".join(cat_list)))

    return pd.DataFrame(rows, columns=["item_id", "category_token"])


def analyze_pos_category_distribution(
    triplets_pkl_path: str,
    cat_long_df: pd.DataFrame,
    topk: int = 30,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Category distribution of the positive items in a pickled triplet list.

    A usable triplet is a list/tuple with at least three entries; its second entry
    is taken as the positive item id. Entries that cannot be cast to int are skipped.

    Returns:
        counts: category distribution weighted by how often each item was chosen as
            positive (columns ``category_token``, ``count``, ``prob``). ``prob`` is
            relative to all category tokens. Only the first ``topk`` rows are kept
            unless ``topk`` is None.
        cat_item_count_df: category distribution over *unique* positive items
            (columns ``category_token``, ``item_count``).

    Also prints per-item positive counts and the item-based category table.
    """
    def _positive_ids(records):
        for rec in records:
            if not isinstance(rec, (list, tuple)) or len(rec) < 3:
                continue
            try:
                pos_id = int(rec[1])
            except Exception:
                continue
            yield pos_id

    pos_items = list(_positive_ids(load_pickle(triplets_pkl_path)))

    if not pos_items:
        return (
            pd.DataFrame(columns=["category_token", "count", "prob"]),
            pd.DataFrame(columns=["category_token", "item_count"]),
        )

    # --- how many times each unique item was chosen as positive ---
    per_item = Counter(pos_items)
    pos_item_count_df = pd.DataFrame(per_item.items(), columns=["item_id", "pos_count"])
    pos_item_count_df = pos_item_count_df.sort_values("pos_count", ascending=False)
    pos_item_count_df = pos_item_count_df.reset_index(drop=True)

    print(f"[POS ITEM COUNTS] unique items: {len(pos_item_count_df)}")
    print(pos_item_count_df.head(20))
    print("count stats:\n", pos_item_count_df["pos_count"].describe())

    # --- category counts over unique positive items (each item counted once per category) ---
    unique_item_cats = pd.DataFrame({"item_id": pos_item_count_df["item_id"].unique()})
    unique_item_cats = unique_item_cats.merge(cat_long_df, on="item_id", how="left")
    unique_item_cats = unique_item_cats.dropna(subset=["category_token"])
    unique_item_cats = unique_item_cats.drop_duplicates(subset=["item_id", "category_token"])

    cat_item_count_df = (
        unique_item_cats.groupby("category_token").size()
        .sort_values(ascending=False).rename("item_count").reset_index()
    )

    print("[UNIQUE POS ITEMS] category count (item-based)")
    print(cat_item_count_df.head(20))

    # --- category counts weighted by positive-selection frequency ---
    merged = (
        pd.DataFrame({"item_id": pos_items})
        .merge(cat_long_df, on="item_id", how="left")
        .dropna(subset=["category_token"])
    )
    counts = (
        merged.groupby("category_token").size()
        .sort_values(ascending=False).rename("count").reset_index()
    )
    total = int(counts["count"].sum()) if len(counts) > 0 else 0
    counts["prob"] = counts["count"] / max(1, total)

    return (counts if topk is None else counts.head(topk)), cat_item_count_df


def plot_bar_vertical(
    df: pd.DataFrame,
    x_col: str,
    y_col: str,
    title: str,
    topk: int = 10,
    save_path: str | None = None,
):
    """Draw a vertical bar chart of the ``topk`` largest ``y_col`` values.

    Each bar is labelled with its height truncated to int. If ``save_path`` is given,
    missing parent directories are created and the figure is saved at 300 dpi before
    ``plt.show()``. Prints a message and returns early when ``df`` is None or empty.
    """
    if df is None or df.empty:
        print("No data to plot.")
        return

    top = df.sort_values(y_col, ascending=False).head(topk).copy()

    plt.figure(figsize=(11, 6))
    bar_container = plt.bar(top[x_col], top[y_col])
    plt.title(title)
    plt.xlabel(x_col)
    plt.ylabel(y_col)
    plt.xticks(rotation=45, ha="right")

    for rect in bar_container:
        height = rect.get_height()
        centre_x = rect.get_x() + rect.get_width() / 2
        plt.text(centre_x, height, f"{int(height)}", ha="center", va="bottom", fontsize=10)

    plt.tight_layout()

    if save_path is not None:
        out_file = Path(save_path)
        out_file.parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(out_file, dpi=300)
        print(f"saved figure to: {out_file}")

    plt.show()


# =========================
# Run
# =========================
meta_df = load_books_item_meta_jsonl(META_JSONL_PATH)
print("meta_df:", meta_df.shape)
print(meta_df.head(3))

# Try different modes here; "leaf" is recommended.
MODE = "leaf"   # "leaf" | "all" | "path"
cat_long_df = build_category_long_table(meta_df, mode=MODE)
print("cat_long_df:", cat_long_df.shape)
print(cat_long_df.head(5))

# Frequency-weighted (top 30) and unique-item category distributions of the positive items.
pos_cat_df, unique_pos_cat_df = analyze_pos_category_distribution(PKL_PATH, cat_long_df, topk=30)

# Save both distributions as CSV; file names encode the source pickle and MODE.
pos_cat_csv = f"{OUT_DIR}/pos_category_dist_part1_step0_AB_{MODE}.csv"
uniq_cat_csv = f"{OUT_DIR}/unique_pos_item_category_dist_part1_step0_AB_{MODE}.csv"
pos_cat_df.to_csv(pos_cat_csv, index=False)
unique_pos_cat_df.to_csv(uniq_cat_csv, index=False)
print("saved:", pos_cat_csv)
print("saved:", uniq_cat_csv)

print("\n[AB] positive(pos) category distribution (top 30)")
print(pos_cat_df.head(30))

# Plot (frequency-based)
plot_bar_vertical(
    pos_cat_df,
    x_col="category_token",
    y_col="count",
    title=f"[POS] Category distribution (frequency-based, Top 10) | mode={MODE}",
    topk=10,
    save_path=f"{OUT_DIR}/pos_category_bar_part1_step0_AB_{MODE}.png",
)

# Plot (unique-item-based); item_count is renamed to "count" so the same call shape works.
unique_plot_df = unique_pos_cat_df.rename(columns={"item_count": "count"})
plot_bar_vertical(
    unique_plot_df,
    x_col="category_token",
    y_col="count",
    title=f"[POS UNIQUE ITEMS] Category distribution (item-based, Top 10) | mode={MODE}",
    topk=10,
    save_path=f"{OUT_DIR}/pos_unique_item_category_bar_part1_step0_AB_{MODE}.png",
)


### Hallucination

In [None]:
import pickle
from collections import Counter, defaultdict

AB_PATH = "Augmentation/data/books/aug_triplets_part1_step0_AB.pkl"
BA_PATH = "Augmentation/data/books/aug_triplets_part1_step0_BA.pkl"


def load_triplets(path: str):
    """Unpickle and return the triplet list stored at ``path``.

    Presumably a list of (user, pos_item, neg_item) int tuples — TODO confirm against
    the augmentation pipeline. Only load trusted files: unpickling can run code.
    """
    with open(path, "rb") as fh:
        data = pickle.load(fh)
    return data


def build_pos_counter(triplets):
    """Group triplets by (user, unordered item pair) and count which item was chosen as positive.

    Returns a ``defaultdict(Counter)`` mapping ``(u, a, b)`` to
    ``Counter({pos_item: times_chosen})``. Here ``a <= b`` is the sorted pos/neg pair.
    All ids are cast to int.
    """
    grouped = defaultdict(Counter)
    for user, pos_item, neg_item in triplets:
        pos_id = int(pos_item)
        lo, hi = sorted((pos_id, int(neg_item)))
        grouped[(int(user), lo, hi)][pos_id] += 1
    return grouped

# Load both augmentation runs ("AB" and "BA" pickles) and group them by (user, item pair).
ab = load_triplets(AB_PATH)
ba = load_triplets(BA_PATH)

print("AB triplets:", len(ab))
print("BA triplets:", len(ba))

ab_map = build_pos_counter(ab)
ba_map = build_pos_counter(ba)

# Only keys present in both runs are compared; keys found in just one run are reported below.
common_keys = set(ab_map.keys()) & set(ba_map.keys())
only_ab = set(ab_map.keys()) - set(ba_map.keys())
only_ba = set(ba_map.keys()) - set(ab_map.keys())

if only_ab or only_ba:
    print(f"WARNING: key mismatch. only_ab={len(only_ab)}, only_ba={len(only_ba)}")

mismatch_total = 0
match_total = 0
total_compared = 0

# optional: mismatch examples
examples = []

for key in common_keys:
    # key=(u,a,b)
    u, a, b = key

    ab_counts = ab_map[key]  # Counter(pos)
    ba_counts = ba_map[key]

    total_ab = sum(ab_counts.values())
    total_ba = sum(ba_counts.values())
    if total_ab != total_ba:
        # If AB and BA have different sample counts for the same key, the comparison is ambiguous.
        # One could compare only up to the common range (min) or simply skip the key.
        # The "best possible matching" computed below is used instead, so this branch is a no-op.
        pass

    # Number of samples that can be paired up with the same pos choice on both sides
    match_a = min(ab_counts.get(a, 0), ba_counts.get(a, 0))
    match_b = min(ab_counts.get(b, 0), ba_counts.get(b, 0))
    matches = match_a + match_b

    # Total comparable samples (multisets, so the smaller side is the limit)
    total = min(total_ab, total_ba)

    # matches <= total always holds (each Counter only has keys a and b), so this is >= 0.
    mismatches = total - matches

    match_total += matches
    mismatch_total += mismatches
    total_compared += total

    if mismatches > 0 and len(examples) < 20:
        examples.append({
            "user": u, "pair": (a, b),
            "AB": dict(ab_counts), "BA": dict(ba_counts),
            "mismatches": mismatches, "total": total
        })

print("\n===== RESULT =====")
print("Compared pairs(keys):", len(common_keys))
print("Total comparable samples:", total_compared)
print("Match count:", match_total)
print("Mismatch count (pos differs AB vs BA):", mismatch_total)
if total_compared > 0:
    print("Mismatch rate:", mismatch_total / total_compared)

print("\n===== mismatch examples (up to 20) =====")
for ex in examples:
    print(ex)


# RQ2 & RQ3
LLMÏù¥ ÏÉùÏÑ±Ìïú Îç∞Ïù¥ÌÑ∞Î°ú Ï∂îÏ≤úÏùÑ Î∞òÎ≥µÌïú (ÌîºÎìúÎ∞±Î£®ÌîÑ Ïù¥ÌõÑ)Í≤∞Í≥º Î∂ÑÏÑù(Part1: RQ2, Part5: RQ3)

포함: LLMRec, A-LLMRec, Augmentation, TR_CF

## Hallucination

### A-LLMRec

In [None]:
import json
from collections import defaultdict

# 2) Collect, per scenario, each user's prediction count and the set of predicted users
scenarios = [
    ('case2','part1'),
    ('case2','part5'),
]
pred_counts = {}
pred_user_sets = {}
for case, part in scenarios:
    path = f"{BASE_DIR}/predict_label_{part}.json"
    with open(path) as f:
        raw = json.load(f)
    # string user keys -> int; value = number of predicted items for that user
    cnt_dict = {int(u): len(titles) for u, titles in raw.items()}
    key = f"{case}_{part}"
    pred_counts[key] = cnt_dict
    pred_user_sets[key] = set(cnt_dict)

# 3) Common users: train ∩ label (as requested); converted to int further below
common_users = sorted(set(train_data) & set(ground_truth))

# 4) Actual total interactions, over common users only
actual_total = sum(len(ground_truth[u]) for u in common_users)
print(f"Actual total interactions (common users): {actual_total}, {len(common_users)}Î™Ö\n")

# 5) Per-scenario totals and report, plus per-case sets of users missing from predictions
case_missing_users = {'case1': set(), 'case2': set()}
common_users = [int(u) for u in common_users]  # int, to match the prediction keys
common_user_set = set(common_users)  # loop-invariant, built once

for case, part in scenarios:
    key = f"{case}_{part}"
    pred_users = pred_user_sets[key]
    missing_users = common_user_set - pred_users  # common users with no prediction entry
    # setdefault: a scenario whose case is not pre-listed above no longer raises KeyError
    case_missing_users.setdefault(case, set()).update(missing_users)

    pred_total = sum(pred_counts[key].get(u, 0) for u in common_users)
    missing    = actual_total - pred_total
    # No common users -> actual_total == 0; avoid ZeroDivisionError
    missing_pct= missing / actual_total * 100 if actual_total else 0.0

    print(f"{key}:")
    print(f"  Predicted total (common users) = {pred_total}")
    print(f"  Missing interactions           = {missing} ({missing_pct:.2f}%)")
    print(f"  Missing users (count)          = {len(missing_users)}\n")


In [None]:
def load_userlist_counts(path: str) -> dict[int, int]:
    """Read a per-user JSON file and return ``{user_id: count}``.

    Expects a JSON object keyed by user id. A list value counts as its length, an
    int value is used as-is, and anything else counts as 0. Keys that are not
    integers are skipped. A missing file, or a top-level JSON value that is not an
    object, gives ``{}``.
    """
    if not os.path.exists(path):
        return {}
    with open(path, "r") as f:
        data = json.load(f)
    if not isinstance(data, dict):
        return {}

    out: dict[int, int] = {}
    for u, arr in data.items():
        try:
            uid = int(u)
        except (TypeError, ValueError):
            # Only unparsable ids are skipped; a bare `except:` would also hide
            # KeyboardInterrupt and genuine bugs.
            continue
        if isinstance(arr, list):
            out[uid] = len(arr)
        elif isinstance(arr, int):
            out[uid] = arr
        else:
            out[uid] = 0
    return out

def summarize_case2_part(part: int, max_step: int, base_dir: str | None = None) -> pd.DataFrame:
    """Per-user, per-step counts of missing titles and skipped duplicates (case2).

    For steps 1..``max_step``, reads
    ``missing_titles_books_case2_part{part}_step{step}.json`` and
    ``skipped_duplicates_books_case2_part{part}_step{step}.json`` from ``base_dir``
    via ``load_userlist_counts``. ``base_dir`` defaults to the global ``BASE_DIR``.
    Missing files count as empty.

    Returns one row per (user, step) for every user that appears in any step's files,
    with 0 where the user is absent from that step. Columns are user_id, step,
    missing_title and skipped_duplicates, sorted by (user_id, step). If no user
    appears at all, an empty frame with those columns is returned.
    """
    if base_dir is None:
        base_dir = BASE_DIR
    columns = ["user_id", "step", "missing_title", "skipped_duplicates"]

    # Read each step's two files once and reuse them for the user union and the rows.
    per_step = {}
    for step in range(1, max_step + 1):
        miss_path = os.path.join(base_dir, f"missing_titles_books_case2_part{part}_step{step}.json")
        dup_path  = os.path.join(base_dir, f"skipped_duplicates_books_case2_part{part}_step{step}.json")
        per_step[step] = (load_userlist_counts(miss_path), load_userlist_counts(dup_path))

    # Union of user ids over all steps
    all_users = set()
    for miss, dup in per_step.values():
        all_users |= set(miss) | set(dup)
    users_sorted = sorted(all_users)

    # Per-step counts (0 when the user is absent from that step)
    rows = [
        {
            "user_id": u,
            "step": step,
            "missing_title": int(miss.get(u, 0)),
            "skipped_duplicates": int(dup.get(u, 0)),
        }
        for step, (miss, dup) in per_step.items()
        for u in users_sorted
    ]
    # Explicit columns: an empty row list would otherwise give a column-less frame,
    # and sort_values would raise KeyError.
    return (
        pd.DataFrame(rows, columns=columns)
        .sort_values(["user_id", "step"])
        .reset_index(drop=True)
    )

# Build the summaries (part1 covers step 1 only, part5 covers steps 1-5)
df1  = summarize_case2_part(part=1,  max_step=1)
df5  = summarize_case2_part(part=5,  max_step=5)

# (1) Per-step totals: the part1 table, then the part5 table
step_totals1 = (
    df1.groupby("step")[["missing_title","skipped_duplicates"]]
      .sum()
      .reset_index()
      .astype({"step": int, "missing_title": int, "skipped_duplicates": int})
)
print(step_totals1.to_string(index=False))

step_totals5 = (
    df5.groupby("step")[["missing_title","skipped_duplicates"]]
      .sum()
      .reset_index()
      .astype({"step": int, "missing_title": int, "skipped_duplicates": int})
)
print(step_totals5.to_string(index=False))


# (2) Overall totals per part (part1 / part5), summed over all steps
overall1  = df1[["missing_title","skipped_duplicates"]].sum().astype(int).to_dict()
overall5  = df5[["missing_title","skipped_duplicates"]].sum().astype(int).to_dict()

parts_overall = pd.DataFrame([
    {"part": 1,  **overall1},
    {"part": 5,  **overall5},
])[["part","missing_title","skipped_duplicates"]]

print("\n[parts overall totals]")
print(parts_overall.to_string(index=False))



In [None]:
import os, json
import pandas as pd

# Compare two generation attempts (try0 vs try1) of the step-wise prediction files.
# PARTS is used below as the number of steps compared (steps 1..PARTS).
PARTS = 5
base_dir_hallu = "data/books/A-LLMRec_format/A-LLMRec_results"
PREFIX = "predict_label_part5_step"   # switch to the predict_label file prefix if needed
TRY0, TRY1 = 0, 1

def load_json(path):
    """Load a ``{user_id: items}`` JSON file as ``{int user_id: list[int]}``.

    ``null`` becomes ``[]``, a list has every element cast to int, and any other
    scalar becomes a one-element list. Every key must be int-convertible.
    """
    with open(path, "r") as fh:
        raw = json.load(fh)

    def _as_int_list(value):
        if value is None:
            return []
        if isinstance(value, list):
            return list(map(int, value))
        return [int(value)]

    return {int(uid): _as_int_list(items) for uid, items in raw.items()}

def compare_step(d0, d1, step):
    """Diff two ``{user: item_list}`` mappings for a single step.

    A user missing from one side is compared against ``[]``. A user counts as changed
    when the two lists differ; order matters. Returns ``(changed_count, detail_df)``.
    ``detail_df`` has one row per changed user, in ascending user order, with both
    lists, their lengths, the items unique to each side (sorted), and whether both
    lists hold the same set of items.
    """
    records = []
    for user in sorted(d0.keys() | d1.keys()):
        before, after = d0.get(user, []), d1.get(user, [])
        if before == after:
            continue
        before_set, after_set = set(before), set(after)
        records.append({
            "step": step,
            "user_id": user,
            "try0": before,
            "try1": after,
            "len_try0": len(before),
            "len_try1": len(after),
            "only_in_try0": sorted(before_set - after_set),
            "only_in_try1": sorted(after_set - before_set),
            "set_equal": before_set == after_set,
        })
    return len(records), pd.DataFrame(records)

# ---- run ----
# For each step, diff the try0 and try1 prediction files and summarize the per-user changes.
summary_rows = []
detail_dfs = []

for step in range(1, PARTS+1):
    p0 = os.path.join(base_dir_hallu, f"{PREFIX}{step}_try{TRY0}.json")
    p1 = os.path.join(base_dir_hallu, f"{PREFIX}{step}_try{TRY1}.json")

    d0 = load_json(p0)
    d1 = load_json(p1)

    changed, df_detail = compare_step(d0, d1, step)
    union_users = len(set(d0.keys()) | set(d1.keys()))
    # Users whose lists are identical, order included; compare_step compares a user
    # missing from one file against [].
    exact_equal = union_users - changed

    # Changed users whose item sets are equal, i.e. only the order differs
    order_only = int(df_detail["set_equal"].sum()) if len(df_detail) else 0

    summary_rows.append({
        "step": step,
        "users_try0": len(d0),
        "users_try1": len(d1),
        "users_union": union_users,
        "users_changed(+1 per user)": changed,
        "users_exact_equal": exact_equal,
        "changed_set_equal_but_order_diff": order_only,
        "file_try0": p0,
        "file_try1": p1,
    })

    detail_dfs.append(df_detail)

df_summary = pd.DataFrame(summary_rows)
df_detail_all = pd.concat(detail_dfs, ignore_index=True) if detail_dfs else pd.DataFrame()

# Last expression of the cell: displayed by the notebook.
df_summary


### Augmentation

## Performance

In [None]:
""" 
Performance Check
"""

import json
import os
import numpy as np

def evaluate_prediction(pred_dict, ground_truth, common_users):
    """Evaluate ranked predictions against ground truth per user, then average.

    Args:
        pred_dict: ``{user_id (str): ranked list of predicted item ids}``.
        ground_truth: ``{user_id (str): list of relevant item ids}``.
        common_users: user ids to evaluate (ints or int-like strings).

    A user is evaluated when it is in ``pred_dict``, in ``common_users``, and has a
    non-empty ground-truth list. The cutoff is ``k = len(ground-truth items)``. The
    metrics are precision@k, recall@k (identical here because k == |gt|), hit ratio,
    and binary-relevance NDCG@k.

    Returns:
        dict with the mean ``precision``, ``recall``, ``hit_ratio`` and ``ndcg`` over
        evaluated users (NaN when no user is evaluated), and ``hit_details``, a list of
        ``(user, k, n_hits)`` tuples.
    """
    # Normalise once to a set of ints: O(1) membership per predicted user
    # (instead of a list scan), and str-typed ids still match int(user_str).
    wanted = set()
    for u in common_users:
        try:
            wanted.add(int(u))
        except (TypeError, ValueError):
            continue

    precision_list = []
    recall_list = []
    hit_ratio_list = []
    ndcg_list = []
    hit_item_list = []

    for user_str, pred_items in pred_dict.items():
        user = int(user_str)
        if user not in wanted:
            continue

        gt_items = set(ground_truth.get(str(user), ()))  # flat list of item_ids
        if not gt_items:
            continue

        k = len(gt_items)
        pred_topk = pred_items[:k]
        n_hits = len(set(pred_topk) & gt_items)

        hit_item_list.append((user, k, n_hits))
        precision_list.append(n_hits / k)
        recall_list.append(n_hits / len(gt_items))
        hit_ratio_list.append(1.0 if n_hits > 0 else 0.0)

        # Binary-relevance NDCG@k; the ideal ranking puts min(|gt|, k) hits on top.
        dcg = sum(1.0 / np.log2(i + 2) for i, item in enumerate(pred_topk) if item in gt_items)
        idcg = sum(1.0 / np.log2(i + 2) for i in range(min(len(gt_items), k)))
        ndcg_list.append(dcg / idcg if idcg > 0 else 0.0)

    def _mean(values):
        # np.mean([]) returns NaN with a RuntimeWarning; return NaN without the warning.
        return float(np.mean(values)) if values else float("nan")

    return {
        "precision": _mean(precision_list),
        "recall": _mean(recall_list),
        "hit_ratio": _mean(hit_ratio_list),
        "ndcg": _mean(ndcg_list),
        "hit_details": hit_item_list,
    }

# Evaluate the Part1 and Part5 predictions with the same routine. The output is
# unchanged from evaluating each part separately. After the loop, `pred_dict` and
# `result` hold the Part5 values.
for _part_label, _pred_path in (("Part1", PRED_P1), ("Part5", PRED_P5)):
    with open(_pred_path) as f:
        pred_dict = json.load(f)
    print(len(pred_dict))

    # Evaluate
    result = evaluate_prediction(pred_dict, ground_truth, common_users)

    print(f"{_part_label} Results:")
    print(f"precision: {result['precision']:.4f}, recall: {result['recall']:.4f}, "
          f"hit_ratio: {result['hit_ratio']:.4f}, ndcg: {result['ndcg']:.4f}")


## Popularity & Diversity

### LLMRec, A-LLMRec, Augmentation, TR CF

In [None]:
import json
from collections import Counter
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# ---------------------------
# 1) Popularity from train_data
# ---------------------------
# item_pop[item] = number of times the item appears across all users' training lists.
item_pop = Counter()
for items in train_data.values():
    for it in items:
        item_pop[int(it)] += 1

# ---------------------------
# 2) Load predictions
# ---------------------------
with open(PRED_P1) as f:
    pred_p1 = json.load(f)
with open(PRED_P5) as f:
    pred_p5 = json.load(f)

print(len(train_data), "users in train_data")
print(len(ground_truth), "users in ground_truth (Label)")
print(len(pred_p1), "users in pred_p1")
print(len(pred_p5), "users in pred_p5")

# ---------------------------
# 3) Common users
# ---------------------------
# Users present in both prediction files and in the labels (all keyed by str).
common_users_pred  = set(pred_p1.keys()) & set(pred_p5.keys())
common_users_label = set(ground_truth.keys())

# Assumes common_users was already computed earlier as (train ∩ label).
common_users_final = (
    set(map(str, common_users))
    & common_users_pred
    & common_users_label
)

print("len(common_users_final):", len(common_users_final))

# ---------------------------
# 4) Popularity metric
# ---------------------------
def avg_popularity_from_item_ids(user_items: dict[str, list[int]], pop_counts=None):
    """Return ``{str(uid): mean popularity of that user's items}``.

    An item's popularity is its count in ``pop_counts``; items missing from it count
    as 0. ``pop_counts`` defaults to the module-level ``item_pop`` (train-based
    counts), so one-argument calls work unchanged. The optional second positional
    argument matches the later two-argument variant of this helper. A user with no
    items (empty list or ``None``) gets 0.0.
    """
    counts = item_pop if pop_counts is None else pop_counts
    out = {}
    for uid, items in user_items.items():
        # None means "no items" rather than crashing on iteration.
        pops = [counts.get(int(x), 0) for x in (() if items is None else items)]
        out[str(uid)] = float(np.mean(pops)) if pops else 0.0
    return out

# Keep only the common users
label_common = {u: ground_truth[u] for u in common_users_final}
p1_common    = {u: pred_p1[u]     for u in common_users_final}
p5_common    = {u: pred_p5[u]     for u in common_users_final}

# Per-user mean train popularity of the label items and of each part's predictions
m_label = avg_popularity_from_item_ids(label_common)
m_p1    = avg_popularity_from_item_ids(p1_common)
m_p5    = avg_popularity_from_item_ids(p5_common)

# ---------------------------
# 5) Delta transitions (Popularity only)
# ---------------------------
# A positive delta means the later list's items are, on average, more popular in train.
d_label_p1 = []
d_label_p5 = []
d_p1_p5    = []

for u in common_users_final:
    lp = m_label.get(u, 0.0)
    p1 = m_p1.get(u, 0.0)
    p5 = m_p5.get(u, 0.0)

    d_label_p1.append(p1 - lp)
    d_label_p5.append(p5 - lp)
    d_p1_p5.append(p5 - p1)

# Long format: one row per (user, transition) delta.
df_delta = pd.DataFrame({
    "delta": d_label_p1 + d_label_p5 + d_p1_p5,
    "transition": (
        ["Label ‚Üí Part1"] * len(d_label_p1)
        + ["Label ‚Üí Part5"] * len(d_label_p5)
        + ["Part1 ‚Üí Part5"] * len(d_p1_p5)
    ),
    "metric": ["Popularity"] * (len(d_label_p1) + len(d_label_p5) + len(d_p1_p5)),
    "case": ["Books"] * (len(d_label_p1) + len(d_label_p5) + len(d_p1_p5)),
})

# ---------------------------
# 6) Plot
# ---------------------------
sns.set_style("whitegrid")
plt.figure(figsize=(10, 6))

# NOTE(review): seaborn >= 0.13 warns when `palette` is passed without `hue`;
# on such versions consider hue="transition", legend=False.
ax = sns.boxplot(
    data=df_delta,
    x="transition",
    y="delta",
    palette="Set2"
)

ax.set_title("Popularity Change ‚Äî Label / Part1 / Part5 (Books)")
ax.set_xlabel("Transition")
ax.set_ylabel("Œî Average Popularity (train-based counts)")
# Reference line: no change in average popularity
ax.axhline(0, linestyle="--", linewidth=1, alpha=0.6)

plt.tight_layout()
plt.show()

print("Summary (Popularity Œî)")
print(df_delta.groupby("transition")["delta"].describe().round(3))


In [None]:
# ===== Books: Label vs Part1 (per-baseline) Popularity Change Boxplot =====
import json
from collections import Counter
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# ---------------------------
# 0) Paths (Part1 predictions of each baseline)
# ---------------------------
# Three baselines are active; TraditionalCF is currently commented out.
BASELINES_P1 = {
    "A-LLMRec": "A-LLMRec/books_results/predict_label_part1.json",
    "LLMRec": "data/books/books_llmrec_format/predict_label_part1.json",
    "Augmentation": "Augmentation/data/books/predict_label_part1.json",
    #"TraditionalCF": "data/books/traditionalCF/predict_label_part1.json",
}

# ---------------------------
# 1) Popularity from train_data
# ---------------------------
def build_item_popularity(train_data: dict) -> Counter:
    """Count, per item id (cast to int), how many training interactions it has across all users."""
    return Counter(int(it) for items in train_data.values() for it in items)

# Train-based item popularity table used for the label and per-baseline averages below.
item_pop = build_item_popularity(train_data)

# ---------------------------
# 2) Utilities
# ---------------------------
def load_json(path: str) -> dict:
    """Read and return the JSON document stored at ``path``, with no post-processing."""
    with open(path, "r") as stream:
        payload = json.load(stream)
    return payload

def avg_popularity_from_item_ids(user_items: dict, item_pop: Counter) -> dict[str, float]:
    """Average training popularity of each user's items.

    Args:
        user_items: mapping ``uid -> list of item ids``. A single scalar item
            id (int, numpy integer or str) is treated as a one-element list.
            ``None`` or an empty list yields 0.0.
        item_pop: mapping ``int(item_id) -> count``; unknown items count as 0.

    Returns:
        dict ``str(uid) -> mean popularity`` as float.
    """
    out = {}
    for uid, items in user_items.items():
        if items is None:
            out[str(uid)] = 0.0
            continue
        # A scalar label would otherwise raise TypeError (int) or be split
        # into single characters (str) when iterated.
        if isinstance(items, (int, np.integer, str)):
            items = [items]
        pops = [item_pop.get(int(x), 0) for x in items]
        out[str(uid)] = float(np.mean(pops)) if pops else 0.0
    return out

def filter_dict_by_users(d: dict, users: set[str]) -> dict:
    """Return the entries of ``d`` whose key is in ``users`` (missing keys are skipped)."""
    kept = {}
    for user in users:
        if user in d:
            kept[user] = d[user]
    return kept

# ---------------------------
# 3) Common users base
#   - common_users: assumed to be the (train ∩ ground-truth) users computed
#     in an earlier cell (reused as-is).
# ---------------------------
common_users_base = set(map(str, common_users))  # normalize uids to str
label_users = set(map(str, ground_truth.keys()))
common_users_base = common_users_base & label_users

print("len(common_users_base):", len(common_users_base))

# Label (ground-truth) popularity per user; identical for every baseline.
label_common = filter_dict_by_users({str(k): v for k, v in ground_truth.items()}, common_users_base)
m_label = avg_popularity_from_item_ids(label_common, item_pop)

# ---------------------------
# 4) For each baseline: compute Δ (Part1 - Label)
# ---------------------------
rows = []
baseline_pred = {}

for bname, ppath in BASELINES_P1.items():
    pred = load_json(ppath)
    pred = {str(k): v for k, v in pred.items()}  # normalize uids to str
    baseline_pred[bname] = pred

    # Only users this baseline actually predicted for (and who have a label).
    users_b = common_users_base & set(pred.keys())

    pred_common = filter_dict_by_users(pred, users_b)
    m_pred = avg_popularity_from_item_ids(pred_common, item_pop)

    for u in users_b:
        lp = m_label.get(u, 0.0)
        pp = m_pred.get(u, 0.0)
        rows.append({
            "uid": u,
            "baseline": bname,
            "delta": pp - lp,
            "metric": "Popularity",
            # This cell evaluates the Books predictions listed in BASELINES_P1.
            "case": "Books",
        })

    print(f"{bname}: users used = {len(users_b)}")

# Explicit columns keep the groupby below from raising KeyError when no
# baseline contributed any user (pd.DataFrame([]) has no columns at all).
df_delta = pd.DataFrame(rows, columns=["uid", "baseline", "delta", "metric", "case"])
print("df_delta shape:", df_delta.shape)

# ---------------------------
# 5) Plot (one box per baseline)
# ---------------------------
sns.set_style("whitegrid")
plt.figure(figsize=(10, 6))

ax = sns.boxplot(
    data=df_delta,
    x="baseline",
    y="delta",
    order=list(BASELINES_P1.keys()),
)

# The count is derived so the title stays right when a baseline is (un)commented.
ax.set_title(f"Popularity Change — GroundTruth → Simulation Data ({len(BASELINES_P1)} baselines, Books)")
ax.set_xlabel("Baseline (Part1 prediction)")
ax.set_ylabel("Δ Average Popularity (Part1 - Label, train-based counts)")
ax.axhline(0, linestyle="--", linewidth=1, alpha=0.6)

plt.tight_layout()
# Save the figure
out_path = "data/books/plots/popularity.png"
os.makedirs(os.path.dirname(out_path), exist_ok=True)
plt.savefig(out_path, dpi=300)
plt.show()

# ---------------------------
# 6) Summary
# ---------------------------
print("Summary (Δ Popularity = Part1 - Label)")
print(df_delta.groupby("baseline")["delta"].describe().round(3))
print("plot is saved ->", out_path)

In [None]:
# ===== Books: Part1 -> Part5 (baselines in BASELINES_P1) Popularity Change Boxplot =====
import os
import json
from collections import Counter
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# ---------------------------
# 0) Paths (Part1 predictions of each baseline; Part5 paths are derived)
# ---------------------------
BASELINES_P1 = {
    "A-LLMRec": "A-LLMRec/books_results/predict_label_part1.json",
    "LLMRec": "data/books/books_llmrec_format/predict_label_part1.json",
    "Augmentation": "Augmentation/data/books/predict_label_part1.json",
    #"TraditionalCF": "data/books/traditionalCF/predict_label_part1.json",
}

_P1_FILENAME = "predict_label_part1.json"
_P5_FILENAME = "predict_label_part5.json"


def part1_to_part5_path(p1_path: str) -> str:
    """Return the Part5 prediction path that sits next to a Part1 prediction path.

    Only the trailing file name is rewritten, so directory names are never
    touched.

    Raises:
        ValueError: ``p1_path`` does not end with ``predict_label_part1.json``.
            Without this check a non-matching path would come back unchanged,
            "Part5" would load the Part1 file, and every Part5 - Part1 delta
            would silently be 0.
    """
    if not p1_path.endswith(_P1_FILENAME):
        raise ValueError(f"not a Part1 prediction path: {p1_path!r}")
    return p1_path[: -len(_P1_FILENAME)] + _P5_FILENAME


BASELINES_P5 = {k: part1_to_part5_path(v) for k, v in BASELINES_P1.items()}

# ---------------------------
# 1) Popularity from train_data
# ---------------------------
def build_item_popularity(train_data: dict) -> Counter:
    """Return a Counter of ``int(item_id) -> number of training interactions``.

    ``train_data`` maps each user to an iterable of item ids; ids are cast
    with ``int`` so numeric strings are counted under the same key as ints.
    """
    item_pop = Counter()
    for items in train_data.values():
        item_pop.update(map(int, items))
    return item_pop

# Per-item training interaction counts; used as the popularity measure for
# both the Part1 and Part5 averages below.
item_pop = build_item_popularity(train_data)

# ---------------------------
# 2) Utilities
# ---------------------------
def load_json(path: str) -> dict:
    """Load the JSON file at ``path``.

    Raises:
        FileNotFoundError: with ``path`` as its only argument, when the file
            does not exist.
    """
    if os.path.exists(path):
        with open(path, "r") as handle:
            return json.load(handle)
    raise FileNotFoundError(path)

def avg_popularity_from_item_ids(user_items: dict, item_pop: Counter) -> dict[str, float]:
    """Map each user to the mean training popularity of their items.

    Args:
        user_items: ``uid -> list of item ids``; a bare ``int`` is treated as a
            single item, and ``None`` / an empty list give 0.0.
        item_pop: ``int(item_id) -> count``; unknown items contribute 0.

    Returns:
        dict keyed by ``str(uid)`` with float averages.
    """
    def _mean_pop(items) -> float:
        if items is None:
            return 0.0
        seq = [items] if isinstance(items, int) else items
        pops = [item_pop.get(int(x), 0) for x in seq]
        return float(np.mean(pops)) if len(pops) > 0 else 0.0

    return {str(uid): _mean_pop(items) for uid, items in user_items.items()}

def filter_dict_by_users(d: dict, users: set[str]) -> dict:
    """Keep only the ``d`` entries whose key appears in ``users``."""
    selected = {}
    for user in users:
        if user not in d:
            continue
        selected[user] = d[user]
    return selected

# ---------------------------
# 3) Common users base
#   - common_users: assumed to be defined by an earlier cell (used as-is)
# ---------------------------
common_users_base = set(map(str, common_users))
print("len(common_users_base):", len(common_users_base))

# ---------------------------
# 4) For each baseline: compute Δ (Part5 - Part1)
# ---------------------------
rows = []

for bname in BASELINES_P1.keys():
    p1_path = BASELINES_P1[bname]
    p5_path = BASELINES_P5[bname]

    # Per-baseline predictions use b_-prefixed names on purpose: the module
    # level pred_p1 / pred_p5 are read by the Louvain and cosine-diversity
    # cells further down, and overwriting them here would silently feed those
    # cells the last baseline's predictions.
    b_pred_p1 = load_json(p1_path)
    b_pred_p5 = load_json(p5_path)

    b_pred_p1 = {str(k): v for k, v in b_pred_p1.items()}
    b_pred_p5 = {str(k): v for k, v in b_pred_p5.items()}

    # Users that have both Part1 and Part5 predictions for this baseline and
    # are in common_users.
    users_b = common_users_base & set(b_pred_p1.keys()) & set(b_pred_p5.keys())

    p1_common = filter_dict_by_users(b_pred_p1, users_b)
    p5_common = filter_dict_by_users(b_pred_p5, users_b)

    m_p1 = avg_popularity_from_item_ids(p1_common, item_pop)
    m_p5 = avg_popularity_from_item_ids(p5_common, item_pop)

    for u in users_b:
        p1v = m_p1.get(u, 0.0)
        p5v = m_p5.get(u, 0.0)
        rows.append({
            "uid": u,
            "baseline": bname,
            "delta": p5v - p1v,
            "metric": "Popularity",
            "case": "Books",
        })

    print(f"{bname}: users used = {len(users_b)}")
    print(f"  - p1: {p1_path}")
    print(f"  - p5: {p5_path}")

# Explicit columns keep the groupby below working even if no rows were produced.
df_delta = pd.DataFrame(rows, columns=["uid", "baseline", "delta", "metric", "case"])
print("df_delta shape:", df_delta.shape)

# ---------------------------
# 5) Plot (one box per baseline)
# ---------------------------
sns.set_style("whitegrid")
plt.figure(figsize=(10, 6))

ax = sns.boxplot(
    data=df_delta,
    x="baseline",
    y="delta",
    order=list(BASELINES_P1.keys()),
)

# The baseline count is derived so it stays right when entries are (un)commented.
ax.set_title(f"Popularity Change — Recommendation → Feedback Loop ({len(BASELINES_P1)} baselines, Books)")
ax.set_xlabel("Baseline")
ax.set_ylabel("Δ Average Popularity (Part5 - Part1, train-based counts)")
ax.axhline(0, linestyle="--", linewidth=1, alpha=0.6)

plt.tight_layout()
os.makedirs("data/books/plots", exist_ok=True)
out_path = "data/books/plots/popularity_p1_to_p5.png"
plt.savefig(out_path, dpi=300)
plt.show()

# ---------------------------
# 6) Summary
# ---------------------------
print("Summary (Δ Popularity = Part5 - Part1)")
print(df_delta.groupby("baseline")["delta"].describe().round(3))
print("plot is saved ->", out_path)


### Diversity

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os, json, math
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE
from sklearn.metrics import pairwise_distances
import networkx as nx

# Louvain
import community as community_louvain  # pip install python-louvain

# =========================
# Paths / settings
# =========================
STEPS    = list(range(PARTS))
OUT_DIR  = os.path.join(BASE_DIR, "figs_tsne_cluster_progress")
os.makedirs(OUT_DIR, exist_ok=True)

TSNE_PERPLEXITY      = 30
TSNE_RANDOM_STATE    = 42
KMEANS_RANDOM_STATE  = 42
K_USERS = 2   # placeholder; overwritten below with the Louvain community count
K_ITEMS = 2

SUBSAMPLE_USERS_FOR_TSNE = None
SUBSAMPLE_ITEMS_FOR_TSNE = None

# ===== Community detection (graph construction) settings =====
# NOTE(review): USE_COSINE, SIM_THRESHOLD and MAX_EDGES_PER_NODE are not
# referenced by the graph-construction code in this cell (edges there come
# from train/prediction interactions); confirm whether they are still needed.
USE_COSINE       = True        # use cosine similarity as Louvain edge weights
SIM_THRESHOLD    = 0.5         # edge creation threshold (cosine sim >= threshold)
MAX_EDGES_PER_NODE = None      # max edges per node (e.g. 50); None means no limit
EDGE_WEIGHT_NAME = "w"         # edge attribute name holding the weight

# =========================
# Ïú†Ìã∏
# =========================

def _get_items(d, uid):
    """
    d: {user_id: [items,...]} with user_id possibly str or int
    uid: int
    returns list of items (possibly empty)
    """
    if uid in d:
        return d[uid]
    suid = str(uid)
    if suid in d:
        return d[suid]
    return []

def build_user_item_graph_common(train_data, pred_dict, common_users,
                                 alpha: float = 0.3,
                                 drop_isolates: bool = True,
                                 min_degree: int | None = None,
                                 weight_name: str = "w") -> nx.Graph:
    """Build an undirected weighted user-item bipartite graph over ``common_users``.

    Nodes are named ``"U{uid}"`` for users and ``"I{item}"`` for items (both
    cast with ``int``). Every training interaction adds 1.0 to its edge weight
    and every predicted item adds ``alpha``, so a pair that is both trained on
    and predicted gets ``1 + alpha``.

    Args:
        train_data: ``{user_id: [item, ...]}``; user keys may be int or str.
        pred_dict: same layout as ``train_data`` (model predictions).
        common_users: iterable of user ids (int or numeric str) to include.
        alpha: weight contributed by each predicted (user, item) pair.
        drop_isolates: remove nodes with degree 0 (e.g. users without items).
        min_degree: if set, remove nodes with degree < min_degree in a single
            pass (degrees are not recomputed after the removal).
        weight_name: edge attribute that stores the weight; the default "w"
            matches the attribute name used previously.

    Returns:
        networkx.Graph with each edge weight stored under ``weight_name``.
    """
    # Normalize ids once so user nodes added up front and user nodes created
    # through edges always share the same "U{int}" name.
    users = [int(u) for u in common_users]

    # Accumulate per-(user_node, item_node) weights; all training edges are
    # summed first, then all prediction edges.
    ew = defaultdict(float)
    for source, weight in ((train_data, 1.0), (pred_dict, alpha)):
        for u in users:
            for it in _get_items(source, u):
                ew[(f"U{u}", f"I{int(it)}")] += weight

    G = nx.Graph()
    # Add every requested user first so the user node set does not depend on edges.
    G.add_nodes_from(f"U{u}" for u in users)
    G.add_edges_from((u, i, {weight_name: w}) for (u, i), w in ew.items() if w > 0)

    if drop_isolates:
        iso = list(nx.isolates(G))
        if iso:
            G.remove_nodes_from(iso)
    if min_degree is not None and min_degree > 0:
        low = [n for n, d in G.degree() if d < min_degree]
        if low:
            G.remove_nodes_from(low)

    return G


def louvain_detect(G: nx.Graph, weight_name: str = "w"):
    """Run Louvain community detection on ``G`` and relabel communities densely.

    Community ids from ``community_louvain.best_partition`` (fixed
    ``random_state=42``) are remapped to ``0..C-1`` in order of first
    appearance while walking the partition dict.

    Returns:
        node_order: ``list(G.nodes())``.
        labels_arr: int ndarray with the community label of each node in
            ``node_order``.
        num_comms: number of distinct communities C.
        part: the partition dict ``{node: community_id}``, relabeled in place.
    """
    part = community_louvain.best_partition(G, weight=weight_name, random_state=42)

    # Dense relabeling: first community seen -> 0, next new one -> 1, ...
    remap = {}
    for node in part:
        part[node] = remap.setdefault(part[node], len(remap))

    node_order = list(G.nodes())
    labels_arr = np.array([part[n] for n in node_order], dtype=int)
    return node_order, labels_arr, len(remap), part


def compute_centers(coords: np.ndarray, labels: np.ndarray, K: int):
    """Return ``{k: mean of coords whose label == k}`` for k in ``0..K-1``.

    Labels with no members are omitted from the result.
    """
    centers = {}
    for k in range(K):
        in_k = labels == k
        if not np.any(in_k):
            continue
        centers[k] = coords[in_k].mean(axis=0)
    return centers
# =========================
# Louvain community detection on the train + prediction user-item graph,
# once with the Part1 predictions and once with the Part5 predictions.
# The printed score is 1 / (number of detected communities).
# After the loop, G_users / node_order / node_labels_final / K_USERS / part
# hold the Part5 results.
for header, pred_for_part in (("=== Part1 Louvain ===", pred_p1),
                              ("\n=== Part5 Louvain ===", pred_p5)):
    print(header)
    G_users = build_user_item_graph_common(
        train_data=train_data,
        pred_dict=pred_for_part,
        common_users=common_users,
        alpha=1.0,          # prediction edges weigh as much as training edges
        drop_isolates=True,
        min_degree=None,
    )

    print(f"Graph nodes={G_users.number_of_nodes()}, edges={G_users.number_of_edges()}")

    node_order, node_labels_final, K_USERS, part = louvain_detect(G_users, weight_name=EDGE_WEIGHT_NAME)
    # An empty graph yields zero communities; report nan instead of raising
    # ZeroDivisionError and aborting the cell.
    diversity_score = 1 / K_USERS if K_USERS > 0 else float("nan")
    print(f"Diversity Score(Detected communities (total nodes) based): {diversity_score}")

    # Community sizes over all nodes (users + items)
    _, counts = np.unique(node_labels_final, return_counts=True)
    print("Community sizes (all nodes U+I):", K_USERS, counts.tolist())


In [None]:
import os
import numpy as np
from tqdm import tqdm

# ---------------------------
# 1) Load item embeddings
# ---------------------------
def load_item_embeddings(base_dir: str, parts: int, step: int):
    """Load ``item_emb_part{parts}_step{step}.npy`` from ``base_dir``.

    Returns:
        The stored 2-D ndarray (presumably one row per item id).

    Raises:
        FileNotFoundError: the file is missing (the path is the message).
        ValueError: the stored array is not 2-D.
    """
    emb_path = os.path.join(base_dir, f"item_emb_part{parts}_step{step}.npy")
    if not os.path.exists(emb_path):
        raise FileNotFoundError(emb_path)
    emb = np.load(emb_path)
    if emb.ndim == 2:
        return emb
    raise ValueError(f"item_emb shape invalid: {emb.shape}")

# ---------------------------
# 2) Ìó¨Ìçº: pred_dictÏóêÏÑú Ïú†Ï†Ä ÌÇ§ ÏïàÏ†Ñ Ï°∞Ìöå
# ---------------------------
def _get_user_items(pred_dict, uid):
    if uid in pred_dict:
        return pred_dict[uid]
    suid = str(uid)
    if suid in pred_dict:
        return pred_dict[suid]
    return []

# ---------------------------
# 3) Per-user average pairwise cosine similarity
# ---------------------------
def avg_pairwise_cos_sim_for_user(item_ids, item_emb_norm):
    """Mean pairwise cosine similarity of one user's recommended items.

    Args:
        item_ids: recommended item ids (ints or numeric strings). A
            non-iterable value, or any id that cannot be cast with ``int``,
            makes the whole list invalid (result is nan).
        item_emb_norm: (N, d) item embeddings whose rows are L2-normalized,
            so a dot product equals the cosine similarity.

    Returns:
        Mean of the strict upper triangle of the similarity matrix over the
        unique, in-range ids, or ``np.nan`` if fewer than 2 such ids remain.
    """
    try:
        idx = [int(x) for x in item_ids]
    except (TypeError, ValueError, OverflowError):
        # Expected conversion failures only: a scalar/None instead of a list
        # (TypeError) or a non-numeric / non-finite id. Other errors are bugs
        # and should propagate.
        idx = []

    N = item_emb_norm.shape[0]
    # dict.fromkeys de-duplicates while preserving order, so a repeated item
    # does not bias the mean; out-of-range ids are dropped.
    idx = [i for i in dict.fromkeys(idx) if 0 <= i < N]

    m = len(idx)
    if m < 2:
        return np.nan

    V = item_emb_norm[idx]            # (m, d), rows already normalized
    S = V @ V.T                       # (m, m) cosine similarity matrix
    sims = S[np.triu_indices(m, k=1)] # off-diagonal upper triangle; non-empty since m >= 2
    return float(np.mean(sims))

# ---------------------------
# 4) Main: per-user and overall averages
# ---------------------------
def _nan_mean(values) -> float:
    """Mean of the non-nan entries of ``values`` as float, or nan if there are none."""
    arr = np.array([v for v in values if not np.isnan(v)], dtype=float)
    return float(arr.mean()) if arr.size > 0 else np.nan


def compute_rec_list_avg_cos_sim(
    pred_dict: dict,
    base_dir: str,
    parts: int,
    step: int,
    users: list | set | None = None,   # None -> every int-castable key of pred_dict
    clip01: bool = True,               # clamp each user's mean similarity into [0, 1]
):
    """Intra-list similarity / diversity of each user's recommendations.

    Loads ``item_emb_part{parts}_step{step}.npy`` from ``base_dir``,
    L2-normalizes its rows (all-zero rows stay zero), then averages the
    pairwise cosine similarity of every user's recommended items.

    Args:
        pred_dict: ``{user_id: [item_id, ...]}``; keys may be int or str.
        base_dir, parts, step: select the item embedding file.
        users: users to evaluate (cast with ``int``); ``None`` means every
            key of ``pred_dict`` that can be cast to int.
        clip01: clamp each mean similarity into [0, 1]. Note this does more
            than absorb floating-point overshoot: a genuinely negative mean
            similarity becomes 0 (and its diversity 1).

    Returns:
        user2avg_sim: {int uid: mean pairwise cosine sim, nan if < 2 valid
            items} -- lower means more diverse.
        global_mean: mean of user2avg_sim ignoring nan (nan if none).
        user2diversity: {int uid: 1 - avg_sim} -- higher means more diverse.
        global_div_mean: mean of user2diversity ignoring nan.
    """
    # 1) Load and row-normalize the item embeddings.
    I = load_item_embeddings(base_dir, parts, step)       # (num_items, d)
    norms = np.linalg.norm(I, axis=1, keepdims=True)
    norms[norms == 0] = 1.0   # avoid 0/0 for all-zero rows
    In = I / norms

    # 2) Users to evaluate.
    if users is None:
        # pred_dict keys may mix str/int; keep only the int-castable ones.
        cand = []
        for k in pred_dict.keys():
            try:
                cand.append(int(k))
            except (TypeError, ValueError):
                continue
        users = cand
    else:
        users = [int(u) for u in users]

    user2avg_sim = {}
    user2diversity = {}

    for u in tqdm(users, total=len(users)):
        avg_sim = avg_pairwise_cos_sim_for_user(_get_user_items(pred_dict, u), In)
        if np.isnan(avg_sim):
            user2avg_sim[u] = np.nan
            user2diversity[u] = np.nan
            continue
        if clip01:
            avg_sim = max(0.0, min(1.0, avg_sim))
        user2avg_sim[u] = avg_sim
        user2diversity[u] = 1.0 - avg_sim  # higher = more diverse

    # 3) Overall means, ignoring users without a valid score.
    global_mean = _nan_mean(user2avg_sim.values())
    global_div_mean = _nan_mean(user2diversity.values())

    return user2avg_sim, global_mean, user2diversity, global_div_mean

# Intra-list cosine similarity / diversity of the Part1 (parts=1, step 0) and
# Part5 (parts=5, step 4) recommendations, restricted to common_users.
# pred_p1 / pred_p5: loaded predict_label dicts (user_id -> [item_ids...]).
# After the loop, user2sim / mean_sim / user2div / mean_div hold Part5 results.
for _preds, _parts, _step in ((pred_p1, 1, 0), (pred_p5, 5, 4)):
    user2sim, mean_sim, user2div, mean_div = compute_rec_list_avg_cos_sim(
        pred_dict=_preds,
        base_dir=BASE_DIR,
        parts=_parts,
        step=_step,
        users=common_users,
        clip01=True,
    )

    print(f"[Avg CosSim] Ï†ÑÏ≤¥ ÏÇ¨Ïö©Ïûê ÌèâÍ∑†: {mean_sim:.6f} (ÎÇÆÏùÑÏàòÎ°ù Îã§Ïñë)")
    print(f"[Diversity]  Ï†ÑÏ≤¥ ÏÇ¨Ïö©Ïûê ÌèâÍ∑†: {mean_div:.6f} (ÎÜíÏùÑÏàòÎ°ù Îã§Ïñë)")

# To inspect a single user u:
# print(user2sim[u], user2div[u])


# RQ4
피드백 루프로 생성된 임베딩으로부터 발생할 수 있는 위험성 분석(양극화, 필터버블 등) — analysis of risks (polarization, filter bubbles, etc.) arising from embeddings produced by the feedback loop.

포함 (included baselines): LLMRec, A-LLMRec, Augmentation, TR_CF (traditional CF)

## Polarization

### 5step 한번에 (all 5 steps at once)

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os, json, math
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE

# =========================
# Paths / settings
# =========================
STEPS    = list(range(PARTS))
CLUSTER_STEP = STEPS[-1]              # KMeans is fitted on the final step
OUT_DIR  = os.path.join(BASE_DIR, "figs_tsne_cluster_progress")
os.makedirs(OUT_DIR, exist_ok=True)

TSNE_PERPLEXITY      = 30
TSNE_RANDOM_STATE    = 42
KMEANS_RANDOM_STATE  = 42
# Number of KMeans clusters for users / items. K_USERS is re-set here because
# the Louvain cell above overwrites it with its community count.
K_USERS = 2
K_ITEMS = 2

# (optional) subsampling to speed up t-SNE
SUBSAMPLE_USERS_FOR_TSNE = None  # e.g. 5000
SUBSAMPLE_ITEMS_FOR_TSNE = None  # e.g. 5000

# =========================
# Utilities
# =========================
def load_users_items_for_step(base_dir: str, parts: int, step: int):
    """Load the user (required) and item (optional) embeddings of one step.

    Reads ``user_emb_part{parts}_step{step}.npy`` and
    ``item_emb_part{parts}_step{step}.npy`` from ``base_dir``.

    Returns:
        (U, I): U is a 2-D array; I is a 2-D array, or None when the item
        file does not exist.

    Raises:
        FileNotFoundError: the user embedding file is missing.
        ValueError: a loaded array is not 2-D.
    """
    user_path = os.path.join(base_dir, f"user_emb_part{parts}_step{step}.npy")
    item_path = os.path.join(base_dir, f"item_emb_part{parts}_step{step}.npy")
    if not os.path.exists(user_path):
        raise FileNotFoundError(user_path)

    users_emb = np.load(user_path)
    items_emb = None
    if os.path.exists(item_path):
        items_emb = np.load(item_path)

    if users_emb.ndim != 2:
        raise ValueError(f"user_emb shape invalid at step{step}: {users_emb.shape}")
    if items_emb is not None and items_emb.ndim != 2:
        raise ValueError(f"item_emb shape invalid at step{step}: {items_emb.shape}")
    return users_emb, items_emb


def align_users_for_step(common_users, num_users_step: int) -> np.ndarray:
    """Return the common user ids that are valid row indices for one step, sorted ascending.

    Args:
        common_users: any iterable of user ids (set / list / tuple / ndarray,
            ...); ids may be ints, numpy ints or numeric strings.
        num_users_step: number of rows of this step's user embedding matrix
            (e.g. ``user_by_step[s].shape[0]``).

    Returns:
        int64 array of ids with ``0 <= id < num_users_step``, sorted
        numerically (duplicates in the input, if any, are kept).
    """
    # Cast every id to int *before* sorting: sorting numeric strings is
    # lexicographic ("10" < "2") and sorting a str/int mix raises TypeError.
    # ndarray inputs are sorted too, so the output order is always numeric.
    cu = np.sort(np.asarray([int(u) for u in common_users], dtype=np.int64))

    # Keep only ids that index into this step's embedding matrix.
    mask = (cu >= 0) & (cu < num_users_step)
    return cu[mask]


# =========================
# 1) Load the embeddings of every step
# =========================
print("Loading augmentation embeddings per step...")
user_by_step = {}
item_by_step = {}
dims = set()

for s in STEPS:
    U, I = load_users_items_for_step(BASE_DIR, PARTS, s)
    user_by_step[s] = U
    item_by_step[s] = I
    dims.add(U.shape[1])
    if I is not None:
        dims.add(I.shape[1])
    print(f"  - step{s}: users={U.shape}, items={None if I is None else I.shape}")

# All steps must share one embedding dimension so they can be stacked for t-SNE.
if len(dims) != 1:
    raise RuntimeError(f"ÏûÑÎ≤†Îî© Ï∞®ÏõêÏù¥ stepÎ≥ÑÎ°ú Îã§Î¶ÖÎãàÎã§: {dims}")
D = dims.pop()

# =========================
# 2) Common (train ∩ label) users, aligned per step
# =========================
print(f"common users: {len(common_users)}")

# Per step: sorted common user ids that are valid row indices of that step's user matrix.
aligned_users_by_step = {s: align_users_for_step(common_users, user_by_step[s].shape[0])
                         for s in STEPS}
print({s: len(aligned_users_by_step[s]) for s in STEPS})

# =========================
# 3) Restrict items to the common prefix (length of the shortest step)
# =========================
item_lengths = [item_by_step[s].shape[0] for s in STEPS if item_by_step[s] is not None]
if len(item_lengths) == 0:
    I_base = 0
else:
    I_base = min(item_lengths)
    uniq = set(item_lengths)
    if len(uniq) != 1:
        print(f"[warn] stepÎ≥Ñ ÏïÑÏù¥ÌÖú ÏàòÍ∞Ä Îã§Î¶ÖÎãàÎã§: {uniq} ‚Üí ÏïûÏóêÏÑú {I_base}Í∞úÎßå Í≥µÌÜµ ÏÇ¨Ïö©")
    # Keep only the first I_base item rows of every step (assumes row i is the
    # same item across steps -- TODO confirm).
    for s in STEPS:
        I = item_by_step[s]
        item_by_step[s] = None if I is None or I.shape[0] < I_base else I[:I_base]

rng = np.random.RandomState(0)
if I_base > 0:
    if SUBSAMPLE_ITEMS_FOR_TSNE is not None and SUBSAMPLE_ITEMS_FOR_TSNE < I_base:
        keep_items = np.sort(rng.choice(I_base, size=SUBSAMPLE_ITEMS_FOR_TSNE, replace=False))
    else:
        keep_items = np.arange(I_base)
else:
    keep_items = np.array([], dtype=int)

# =========================
# 4) KMeans on the final step -> cluster labels
# =========================
print(f"KMeans on final step: step{CLUSTER_STEP}")
U_final = user_by_step[CLUSTER_STEP]
ids_final = aligned_users_by_step[CLUSTER_STEP]
if len(ids_final) < K_USERS:
    raise RuntimeError(f"ÏµúÏ¢Ö stepÏóêÏÑú ÌÅ¥Îü¨Ïä§ÌÑ∞ÎßÅ Í∞ÄÎä•Ìïú Í≥µÌÜµ Ïú†Ï†ÄÍ∞Ä Î∂ÄÏ°±Ìï©ÎãàÎã§: {len(ids_final)}")

emb_final_users = U_final[ids_final, :]  # (|ids_final|, D)
user_kmeans = KMeans(n_clusters=K_USERS, random_state=KMEANS_RANDOM_STATE, n_init=10)
user_labels_final = user_kmeans.fit_predict(emb_final_users)
user_center_idx = {k: user_labels_final == k for k in range(K_USERS)}

# Item KMeans on the final step's common items. KMeans needs at least
# K_ITEMS samples; with fewer items clustering is skipped (same path and
# warning as "no items") instead of raising.
if (item_by_step[CLUSTER_STEP] is not None and len(keep_items) > 0
        and len(keep_items) >= K_ITEMS):
    I_final = item_by_step[CLUSTER_STEP][keep_items]  # (|keep_items|, D)
    item_kmeans = KMeans(n_clusters=K_ITEMS, random_state=KMEANS_RANDOM_STATE, n_init=10)
    item_labels_final = item_kmeans.fit_predict(I_final)
else:
    item_kmeans = None
    item_labels_final = None
    print("[warn] ÏµúÏ¢Ö step ÏïÑÏù¥ÌÖúÏù¥ ÏóÜÏñ¥ item clustering ÏÉùÎûµ")

# =========================
# 5) Build t-SNE input blocks (users + items across steps)
# =========================
print("Preparing t-SNE blocks...")
ids_for_tsne = {}      # step -> user ids (after sampling)
users_for_tsne = {}    # step -> user embeddings
item_for_tsne  = {}    # step -> item embeddings (rows selected by keep_items)

# Users are sampled from the final step's ids_final.
if SUBSAMPLE_USERS_FOR_TSNE is not None and SUBSAMPLE_USERS_FOR_TSNE < len(ids_final):
    sample_users = np.sort(rng.choice(ids_final, size=SUBSAMPLE_USERS_FOR_TSNE, replace=False))
else:
    sample_users = ids_final

for s in STEPS:
    ids_s = aligned_users_by_step[s]
    if len(ids_s) == 0:
        ids_for_tsne[s] = np.array([], dtype=np.int64)
        users_for_tsne[s] = np.empty((0, D), dtype=np.float32)
    else:
        mask = np.isin(ids_s, sample_users)
        ids_for_tsne[s] = ids_s[mask]
        users_for_tsne[s] = user_by_step[s][ids_for_tsne[s], :]

    I = item_by_step[s]
    item_for_tsne[s] = None if I is None or len(keep_items) == 0 else I[keep_items]

# =========================
# 6) Joint t-SNE over all steps' users and items
# =========================
print("Fitting joint t-SNE on users + items across steps")
blocks = []
slices = {}  # step -> {"user": (st,en), "item": (st,en) or None}
cursor = 0

for s in STEPS:
    Ublk = users_for_tsne[s]
    blocks.append(Ublk)
    u_st, u_en = cursor, cursor + len(Ublk)
    cursor = u_en

    Iblk = item_for_tsne[s]
    if Iblk is not None and len(Iblk) > 0:
        blocks.append(Iblk)
        i_st, i_en = cursor, cursor + len(Iblk)
        cursor = i_en
        slices[s] = {"user": (u_st, u_en), "item": (i_st, i_en)}
    else:
        slices[s] = {"user": (u_st, u_en), "item": None}

if len(blocks) == 0 or sum(len(b) for b in blocks) < 3:
    raise RuntimeError("t-SNEÏóê ÏÇ¨Ïö©Ìï† ÌëúÎ≥∏Ïù¥ ÎÑàÎ¨¥ Ï†ÅÏäµÎãàÎã§.")

X_all = np.vstack(blocks)
n_samples = len(X_all)
# TSNE requires perplexity < n_samples. The floor of 5 alone violates that for
# 3..5 samples (which pass the check above), so also cap at n_samples - 1;
# for n_samples >= 6 the value is unchanged.
perp = min(TSNE_PERPLEXITY, max(5, (n_samples - 1) // 3), n_samples - 1)
tsne = TSNE(
    n_components=2,
    perplexity=perp,
    init="pca",
    learning_rate="auto",
    random_state=TSNE_RANDOM_STATE,
    max_iter=1000,
    verbose=1
)
X_tsne = tsne.fit_transform(X_all)

# Shared coordinate range (for optional fixed axes in the plots)
xmin, ymin = X_tsne.min(axis=0)
xmax, ymax = X_tsne.max(axis=0)
xpad = 0.05 * (xmax - xmin) if xmax > xmin else 0.5
ypad = 0.05 * (ymax - ymin) if ymax > ymin else 0.5

# Colors for cluster ids (cycled); also used for the final-step centers.
palette = ["tab:blue","tab:orange","tab:green","tab:red","tab:purple",
           "tab:brown","tab:pink","tab:gray","tab:olive","tab:cyan"]

def compute_centers(coords: np.ndarray, labels: np.ndarray, K: int):
    """Mean coordinate of each label ``0..K-1``; labels without members are left out."""
    masks = {k: labels == k for k in range(K)}
    return {k: coords[m].mean(axis=0) for k, m in masks.items() if np.any(m)}

# Coordinates of the final step's users in the joint t-SNE embedding.
u_st_f, u_en_f = slices[CLUSTER_STEP]["user"]
coords_u_final = X_tsne[u_st_f:u_en_f]

# user_labels_final is aligned with ids_final (all final-step common users),
# but t-SNE only saw ids_for_tsne[CLUSTER_STEP] (the sampled subset), so the
# labels are re-indexed into that order via an id -> position map.
ids_tsne_final = ids_for_tsne[CLUSTER_STEP]
idx_map_final = dict(zip((int(u) for u in ids_final.tolist()), range(len(ids_final))))
labels_u_for_plot = np.array(
    [user_labels_final[idx_map_final[int(uid)]] for uid in ids_tsne_final],
    dtype=int,
)
user_centers = compute_centers(coords_u_final, labels_u_for_plot, K_USERS)

# Item centers exist only when the final step has items and item KMeans ran.
if slices[CLUSTER_STEP]["item"] is None or item_labels_final is None:
    item_centers = {}
else:
    i_st_f, i_en_f = slices[CLUSTER_STEP]["item"]
    coords_i_final = X_tsne[i_st_f:i_en_f]
    item_centers = compute_centers(coords_i_final, item_labels_final, K_ITEMS)

# =========================
# 7) Save figures — Users only / Items only / Users+Items
# =========================
def plot_users_only():
    """Save a per-step grid of t-SNE scatter plots of the sampled users.

    Users in every step are colored by the KMeans cluster assigned at
    ``CLUSTER_STEP``; cluster centers are drawn on the ``CLUSTER_STEP`` panel
    only. Reads module-level state built above (``STEPS``, ``slices``,
    ``X_tsne``, ``ids_for_tsne``, ``ids_tsne_final``, ``labels_u_for_plot``,
    ``user_centers``, ``palette``) and writes
    ``step{CLUSTER_STEP}_tsne2d_USERS_ONLY.png`` into ``OUT_DIR``.
    """
    # uid -> final-step cluster label, built once (first occurrence wins,
    # matching an np.where(...)[0][0] lookup). Scanning ids_tsne_final with
    # np.where for every user would make each panel O(n^2) in the user count.
    label_of_uid = {}
    for uid, lbl in zip(ids_tsne_final, labels_u_for_plot):
        label_of_uid.setdefault(int(uid), int(lbl))

    S = len(STEPS)
    cols = min(3, S)
    rows = math.ceil(S / cols)
    fig = plt.figure(figsize=(5*cols, 4*rows))
    for idx, s in enumerate(STEPS, start=1):
        ax = fig.add_subplot(rows, cols, idx)
        u_st, u_en = slices[s]["user"]
        coords_u = X_tsne[u_st:u_en]
        # Color by the final-step labels; other steps reuse the labels of the
        # same (sampled) users.
        if s == CLUSTER_STEP:
            labels_plot = labels_u_for_plot
        else:
            labels_plot = np.array([label_of_uid[int(uid)] for uid in ids_for_tsne[s]], dtype=int)
        for k in range(K_USERS):
            mk = (labels_plot == k)
            if np.any(mk):
                ax.scatter(coords_u[mk,0], coords_u[mk,1], s=8, alpha=0.7,
                           c=palette[k % len(palette)], label=f"Users: C{k}")
        if s == CLUSTER_STEP:
            for k, ctr in user_centers.items():
                ax.scatter(ctr[0], ctr[1], s=120, marker="X",
                           c=palette[k % len(palette)], edgecolor="k")
        ax.set_title(f"step{s} (Users only, colored by step{CLUSTER_STEP})")
        ax.set_xticks([]); ax.set_yticks([])
        # Uncomment to share the same axis limits across all panels:
        # ax.set_xlim(xmin - xpad, xmax + xpad)
        # ax.set_ylim(ymin - ypad, ymax + ypad)
        if idx == 1:
            ax.legend(loc="best", frameon=True)
    plt.tight_layout()
    out = os.path.join(OUT_DIR, f"step{CLUSTER_STEP}_tsne2d_USERS_ONLY.png")
    plt.savefig(out, dpi=220); plt.close()
    print("  - saved:", out)

def plot_items_only():
    """Save a per-step grid of t-SNE scatter plots of the common items.

    Items in every step are colored by their KMeans cluster at
    ``CLUSTER_STEP`` (``item_labels_final``); cluster centers are drawn on the
    ``CLUSTER_STEP`` panel. Prints a message and returns early when no step
    has item embeddings or item clustering did not run. Writes
    ``step{CLUSTER_STEP}_tsne2d_ITEMS_ONLY.png`` into ``OUT_DIR``.
    """
    if item_labels_final is None or not any(item_for_tsne[s] is not None for s in STEPS):
        print("  - Item only plot ÏÉùÎûµ(ÏïÑÏù¥ÌÖú ÏûÑÎ≤†Îî©/ÎùºÎ≤® ÏóÜÏùå)")
        return

    n_steps = len(STEPS)
    n_cols = min(3, n_steps)
    n_rows = math.ceil(n_steps / n_cols)
    fig = plt.figure(figsize=(5 * n_cols, 4 * n_rows))

    # Cluster masks are the same for every step (same keep_items rows).
    cluster_masks = [(k, item_labels_final == k) for k in range(K_ITEMS)]

    for panel, s in enumerate(STEPS, start=1):
        ax = fig.add_subplot(n_rows, n_cols, panel)
        item_slice = slices[s]["item"]
        if item_slice is not None:
            coords_i = X_tsne[item_slice[0]:item_slice[1]]
            for k, mk in cluster_masks:
                if not np.any(mk):
                    continue
                ax.scatter(coords_i[mk, 0], coords_i[mk, 1], s=5, alpha=0.35,
                           c=palette[(k + 2) % len(palette)], label=f"Items: C{k}")
            if s == CLUSTER_STEP and len(item_centers) > 0:
                for k, ctr in item_centers.items():
                    ax.scatter(ctr[0], ctr[1], s=110, marker="D",
                               c=palette[(k + 2) % len(palette)], edgecolor="k")
        ax.set_title(f"step{s} (Items only, colored by step{CLUSTER_STEP})")
        ax.set_xticks([])
        ax.set_yticks([])
        if panel == 1:
            ax.legend(loc="best", frameon=True)

    plt.tight_layout()
    out = os.path.join(OUT_DIR, f"step{CLUSTER_STEP}_tsne2d_ITEMS_ONLY.png")
    plt.savefig(out, dpi=220)
    plt.close()
    print("  - saved:", out)

def plot_users_items():
    """Save a grid of per-step t-SNE panels that show both items and users.

    Items are colored by ``item_labels_final`` and users by the final-step user
    labels ``labels_u_for_plot``. On the final step, diamond (item) and X (user)
    markers show the precomputed cluster centers. All panels share the padded
    axis limits ``xmin..xmax`` / ``ymin..ymax``.

    For non-final steps, each user's final-step label is found by looking up the
    user id in ``ids_tsne_final``. Users that do not appear there are skipped, with
    a warning, instead of raising ``IndexError``. The lookup uses a uid -> position
    dict built once (first occurrence wins, like ``np.where(...)[0][0]``), which
    replaces the former O(N*M) scan.
    """
    S = len(STEPS)
    cols = min(3, S)
    rows = math.ceil(S / cols)
    fig = plt.figure(figsize=(5*cols, 4*rows))
    final_pos = None  # uid -> index into labels_u_for_plot; built lazily for non-final steps

    for idx, s in enumerate(STEPS, start=1):
        ax = fig.add_subplot(rows, cols, idx)

        # Items
        if slices[s]["item"] is not None and item_labels_final is not None:
            i_st, i_en = slices[s]["item"]
            coords_i = X_tsne[i_st:i_en]
            for k in range(K_ITEMS):
                mk = (item_labels_final == k)
                if np.any(mk):
                    ax.scatter(coords_i[mk,0], coords_i[mk,1], s=5, alpha=0.35,
                               c=palette[(k+2) % len(palette)], label=f"Items: C{k}")
            if s == CLUSTER_STEP and len(item_centers) > 0:
                for k, ctr in item_centers.items():
                    ax.scatter(ctr[0], ctr[1], s=110, marker="D",
                               c=palette[(k+2) % len(palette)], edgecolor="k")

        # Users
        u_st, u_en = slices[s]["user"]
        coords_u = X_tsne[u_st:u_en]
        # Build per-step user labels from the final-step clustering
        if s == CLUSTER_STEP:
            labels_plot = labels_u_for_plot
        else:
            if final_pos is None:
                final_pos = {}
                for pos, uid in enumerate(np.asarray(ids_tsne_final).tolist()):
                    final_pos.setdefault(int(uid), pos)  # keep the first match
            ids_tsne_s = ids_for_tsne[s]
            pos_s = np.array([final_pos.get(int(uid), -1) for uid in ids_tsne_s], dtype=np.int64)
            found = pos_s >= 0
            if not found.all():
                print(f"  - [warn] step{s}: {int((~found).sum())} users not in final step; skipped")
            coords_u = coords_u[found]
            labels_plot = np.asarray(labels_u_for_plot, dtype=int)[pos_s[found]]

        for k in range(K_USERS):
            mk = (labels_plot == k)
            if np.any(mk):
                ax.scatter(coords_u[mk,0], coords_u[mk,1], s=8, alpha=0.7,
                           c=palette[k % len(palette)], label=f"Users: C{k}")
        if s == CLUSTER_STEP:
            for k, ctr in user_centers.items():
                ax.scatter(ctr[0], ctr[1], s=120, marker="X",
                           c=palette[k % len(palette)], edgecolor="k")

        ax.set_title(f"step{s} (Users & Items, colored by step{CLUSTER_STEP})")
        ax.set_xlim(xmin - xpad, xmax + xpad)
        ax.set_ylim(ymin - ypad, ymax + ypad)
        ax.set_xticks([]); ax.set_yticks([])
        if idx == 1:
            ax.legend(loc="best", frameon=True)
    plt.tight_layout()
    out = os.path.join(OUT_DIR, f"step{CLUSTER_STEP}_tsne2d_USERS_ITEMS.png")
    plt.savefig(out, dpi=220); plt.close()
    print("  - saved:", out)

# Render and save the three t-SNE figure sets defined above, in this fixed order.
print("Plotting...")
plot_users_only()
plot_items_only()
plot_users_items()
print("Done.")


In [None]:
# ===== 9) Raw space: distances between cluster centers per step - compute & visualize (no t-SNE; a single figure with 2 rows x 1 col) =====
import csv
import numpy as np
import matplotlib.pyplot as plt
import os
from itertools import combinations

# Preconditions: the following names must already be defined by earlier cells
# - STEPS, CLUSTER_STEP, OUT_DIR
# - aligned_users_by_step, user_by_step
# - item_for_tsne (raw item embeddings for the common item subset), item_labels_final or None
# - ids_final, user_labels_final
# - K_USERS, K_ITEMS

# user id -> final-step cluster label, both coerced to Python int (later duplicates win).
_user_label_map_final = dict(zip(map(int, ids_final.tolist()), map(int, user_labels_final.tolist())))

def _labels_for_user_ids(user_ids: np.ndarray, label_map: Optional[dict] = None) -> np.ndarray:
    """Look up the final-step cluster label of each user id.

    Args:
        user_ids: array of user ids.
        label_map: optional ``{user_id: label}`` mapping. When omitted, the
            module-level ``_user_label_map_final`` is used, which preserves the
            original behaviour for existing callers.

    Returns:
        An int array aligned with ``user_ids``. Ids missing from the map get -1.
    """
    if len(user_ids) == 0:
        return np.empty((0,), dtype=int)
    if label_map is None:
        label_map = _user_label_map_final
    return np.array([label_map.get(int(u), -1) for u in user_ids.tolist()], dtype=int)

def _centers_from_points(points: np.ndarray, labels: np.ndarray, K: int) -> dict:
    """Return ``{k: mean of points with label k}`` for every k in ``range(K)``.

    Clusters with no member map to ``None``, so the result always has K keys.
    """
    return {
        k: points[mask].mean(axis=0) if np.any(mask := labels == k) else None
        for k in range(K)
    }

def _dist(a, b):
    """Euclidean distance between two center vectors, or NaN if either is missing."""
    both_present = a is not None and b is not None
    return float(np.linalg.norm(a - b)) if both_present else np.nan

def _pairwise_center_stats(centers: dict):
    """Mean/min/max distance over all unordered pairs of non-None centers.

    Returns ``(nan, nan, nan)`` when fewer than two centers are present.
    """
    present = [vec for vec in centers.values() if vec is not None]
    if len(present) >= 2:
        gaps = np.array([np.linalg.norm(p - q) for p, q in combinations(present, 2)])
        return (float(gaps.mean()), float(gaps.min()), float(gaps.max()))
    return (np.nan, np.nan, np.nan)

def _ui_stats(centers_u: dict, centers_i: dict):
    """Mean/min/max distance between every present user center and every present item center.

    Returns ``(nan, nan, nan)`` if either side has no non-None center.
    """
    users = [vec for vec in centers_u.values() if vec is not None]
    items = [vec for vec in centers_i.values() if vec is not None]
    if not users or not items:
        return (np.nan, np.nan, np.nan)
    gaps = np.array([np.linalg.norm(u - it) for u in users for it in items])
    return (float(gaps.mean()), float(gaps.min()), float(gaps.max()))

# Per-step summary of cluster-center geometry in the raw embedding space.
# - User centers: mean raw embedding per final-step user cluster, computed over the
#   users aligned for step s. Ids missing from the final label map get label -1 and
#   are therefore ignored, since centers are only computed for k in range(K_USERS).
# - Item centers: mean of item_for_tsne[s] rows per item_labels_final cluster.
# One dict per step is appended to `rows`; its key insertion order becomes the CSV
# column order, so the statement order below matters.
rows = []
for s in STEPS:
    # ---- raw space: user centers ----
    ids_u_raw = aligned_users_by_step[s]  # (Nr_s,)
    if len(ids_u_raw) > 0:
        emb_u_raw = user_by_step[s][ids_u_raw, :]
    else:
        emb_u_raw = np.empty((0, user_by_step[s].shape[1]))
    labels_u_raw = _labels_for_user_ids(ids_u_raw)
    centers_u_raw = _centers_from_points(emb_u_raw, labels_u_raw, K_USERS)

    # ---- raw space: item centers (keep_items common subset) ----
    if item_for_tsne[s] is not None and item_labels_final is not None:
        emb_i_raw = item_for_tsne[s]          # (Ni_common, d)
        labels_i_raw = item_labels_final      # (Ni_common,)
        centers_i_raw = _centers_from_points(emb_i_raw, labels_i_raw, K_ITEMS)
    else:
        centers_i_raw = {k: None for k in range(K_ITEMS)}

    # ---- summary stats (raw) ----
    uu_raw_mean, uu_raw_min, uu_raw_max = _pairwise_center_stats(centers_u_raw)
    ii_raw_mean, ii_raw_min, ii_raw_max = _pairwise_center_stats(centers_i_raw)

    row = {
        "step": s,
        "user_dist_raw_mean": uu_raw_mean,
        "user_dist_raw_min":  uu_raw_min,
        "user_dist_raw_max":  uu_raw_max,
        "item_dist_raw_mean": ii_raw_mean,
        "item_dist_raw_min":  ii_raw_min,
        "item_dist_raw_max":  ii_raw_max,
    }

    # K==2 detail: explicit distance between the two centers (NaN if one is missing)
    if K_USERS == 2:
        row["user_dist_raw_U0U1"] = _dist(centers_u_raw.get(0), centers_u_raw.get(1))
    if K_ITEMS == 2:
        row["item_dist_raw_I0I1"] = _dist(centers_i_raw.get(0), centers_i_raw.get(1))

    # user-item distances (raw): all four pairs for 2x2, otherwise mean/min/max over all pairs
    if K_USERS == 2 and K_ITEMS == 2:
        row.update({
            "ui_raw_U0_I0": _dist(centers_u_raw.get(0), centers_i_raw.get(0)),
            "ui_raw_U0_I1": _dist(centers_u_raw.get(0), centers_i_raw.get(1)),
            "ui_raw_U1_I0": _dist(centers_u_raw.get(1), centers_i_raw.get(0)),
            "ui_raw_U1_I1": _dist(centers_u_raw.get(1), centers_i_raw.get(1)),
        })
    else:
        m, mn, mx = _ui_stats(centers_u_raw, centers_i_raw)
        row.update({
            "ui_raw_mean": m,
            "ui_raw_min":  mn,
            "ui_raw_max":  mx,
        })

    rows.append(row)

# ---- Save CSV ----
# Column names come from the first row. Every row shares the same keys because the
# K_USERS / K_ITEMS branches in the loop above do not vary between steps.
dist_csv = os.path.join(OUT_DIR, f"step{CLUSTER_STEP}_center_distances_over_steps_RAW_ONLY.csv")
with open(dist_csv, "w", newline="") as f:
    w = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
    w.writeheader()
    w.writerows(rows)
print(f"  - saved: {dist_csv}")

# ---- Plot: 2 rows x 1 col in ONE figure (raw only) ----
# Top panel: mean pairwise center distance among user clusters and among item clusters.
# Bottom panel: user-center to item-center distances (all four pairs when
# K_USERS == K_ITEMS == 2, otherwise mean/min/max over all pairs).
# Labels and titles use plain ASCII hyphens; the previous text contained
# mis-encoded dash characters that rendered as garbage in the saved figure.
steps_ax = [r["step"] for r in rows]

ud_r_mean = [r["user_dist_raw_mean"] for r in rows]
id_r_mean = [r["item_dist_raw_mean"] for r in rows]

fig, axes = plt.subplots(2, 1, figsize=(9, 7), sharex=True)

# (row 1) user-user / item-item raw mean
axes[0].plot(steps_ax, ud_r_mean, marker="o", label="Users (raw, mean)")
axes[0].plot(steps_ax, id_r_mean, marker="o", label="Items (raw, mean)")
axes[0].set_ylabel("Center distance")
axes[0].set_title("User/User & Item/Item center distance over steps (raw)")
axes[0].legend()

# (row 2) user-item raw; missing keys plot as NaN gaps
if K_USERS == 2 and K_ITEMS == 2:
    ui_raw_u0i0 = [r.get("ui_raw_U0_I0", np.nan) for r in rows]
    ui_raw_u0i1 = [r.get("ui_raw_U0_I1", np.nan) for r in rows]
    ui_raw_u1i0 = [r.get("ui_raw_U1_I0", np.nan) for r in rows]
    ui_raw_u1i1 = [r.get("ui_raw_U1_I1", np.nan) for r in rows]

    axes[1].plot(steps_ax, ui_raw_u0i0, marker="s", label="U0-I0 (raw)")
    axes[1].plot(steps_ax, ui_raw_u0i1, marker="s", label="U0-I1 (raw)")
    axes[1].plot(steps_ax, ui_raw_u1i0, marker="s", label="U1-I0 (raw)")
    axes[1].plot(steps_ax, ui_raw_u1i1, marker="s", label="U1-I1 (raw)")
    axes[1].legend(ncol=2)
else:
    ui_raw_mean = [r.get("ui_raw_mean", np.nan) for r in rows]
    ui_raw_min  = [r.get("ui_raw_min",  np.nan) for r in rows]
    ui_raw_max  = [r.get("ui_raw_max",  np.nan) for r in rows]

    axes[1].plot(steps_ax, ui_raw_mean, marker="s", label="mean (raw)")
    axes[1].plot(steps_ax, ui_raw_min,  marker="s", label="min (raw)")
    axes[1].plot(steps_ax, ui_raw_max,  marker="s", label="max (raw)")
    axes[1].legend()

axes[1].set_xlabel("Step")
axes[1].set_ylabel("Center distance")
axes[1].set_title("User-Item center distances over steps (raw)")

plt.tight_layout()
dist_png = os.path.join(OUT_DIR, f"step{CLUSTER_STEP}_center_distances_over_steps_RAW_ONLY.png")
plt.savefig(dist_png, dpi=220)
plt.close()
print(f"  - saved: {dist_png}")


### Step by step (1스텝별로)

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os, math
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE

# ============================================================
# Values that must already be defined in the notebook environment:
#   - BASE_DIR: directory containing the per-step embeddings (.npy)
#   - PARTS: number of steps (e.g. 5)
#   - common_users: ids of users present in both train and label
#                   (set / list / ndarray)
# ============================================================
# BASE_DIR = "..."
# PARTS = 5
# common_users = ...

# =========================
# Paths / settings
# =========================
STEPS = list(range(PARTS))
CLUSTER_STEP = STEPS[-1]  # KMeans is run on the final step
OUT_DIR = os.path.join(BASE_DIR, "figs_tsne_cluster_progress")
os.makedirs(OUT_DIR, exist_ok=True)

TSNE_PERPLEXITY = 30
TSNE_RANDOM_STATE = 42
KMEANS_RANDOM_STATE = 42
K_USERS = 2
K_ITEMS = 2

# (optional) subsample points to speed up t-SNE; None = use all
SUBSAMPLE_USERS_FOR_TSNE = None  # e.g. 5000
SUBSAMPLE_ITEMS_FOR_TSNE = None  # e.g. 5000

# (global) hide the top p fraction of points farthest from the 2-D t-SNE median -- plotting only
GLOBAL_TRIM_TOP_PCT = 0.0   # e.g. 0.005 to enable; 0 disables global trimming

# (main) per-step (local) trimming, used to hide isolated outliers such as those seen in step 1.
# Values are fractions of points, so 0.01 hides the farthest 1% of each step's points.
# NOTE(review): earlier comments described these as 0.2% (default) and 0.5% (step 1),
# which does not match the values below -- confirm which setting is intended.
STEP_TRIM_TOP_PCT = 0.01        # default for every step: hide the farthest 1%
STEP_TRIM_TOP_PCT_STEP1 = 0.01  # override for step index 1 (STEPS start at 0)

# =========================
# Utilities
# =========================
def load_users_items_for_step(base_dir: str, parts: int, step: int):
    """Load the user (required) and item (optional) embedding matrices for one step.

    Files are ``{base_dir}/user_emb_part{parts}_step{step}.npy`` and
    ``{base_dir}/item_emb_part{parts}_step{step}.npy``.

    Returns:
        ``(users, items)`` where ``items`` is ``None`` if its file does not exist.

    Raises:
        FileNotFoundError: the user embedding file is missing.
        ValueError: a loaded matrix is not 2-D.
    """
    user_file = os.path.join(base_dir, f"user_emb_part{parts}_step{step}.npy")
    item_file = os.path.join(base_dir, f"item_emb_part{parts}_step{step}.npy")
    if not os.path.exists(user_file):
        raise FileNotFoundError(user_file)

    users = np.load(user_file)
    items = np.load(item_file) if os.path.exists(item_file) else None

    def _require_2d(arr, kind):
        # Only validates arrays that were actually loaded.
        if arr is not None and arr.ndim != 2:
            raise ValueError(f"{kind}_emb shape invalid at step{step}: {arr.shape}")

    _require_2d(users, "user")
    _require_2d(items, "item")
    return users, items


def align_users_for_step(common_users, num_users_step: int) -> np.ndarray:
    """Return the common user ids that are valid row indices for one step.

    Args:
        common_users: user ids. Any set-like input (``set``, ``frozenset``,
            ``dict.keys()``, i.e. anything that is a ``collections.abc.Set``) as
            well as ``list``/``tuple`` is sorted first. Other array-likes (e.g.
            ``np.ndarray``) keep their given order.
        num_users_step: number of rows in this step's user embedding matrix.

    Returns:
        int64 array of ids ``0 <= id < num_users_step``, in the order described above.
    """
    from collections.abc import Set as _AbstractSet

    # frozenset / dict_keys used to fall through to np.asarray(..., dtype=int64),
    # which raises TypeError for those types; treat every set-like input like a set.
    if isinstance(common_users, (_AbstractSet, list, tuple)):
        cu = np.asarray(sorted(common_users), dtype=np.int64)
    else:
        cu = np.asarray(common_users, dtype=np.int64)

    if cu.size == 0:
        return cu

    mask = (cu >= 0) & (cu < num_users_step)
    return cu[mask]


def compute_centers(coords: np.ndarray, labels: np.ndarray, K: int):
    """Return ``{cluster_id: mean coordinate}`` for cluster ids in ``range(K)``.

    Clusters without any member are left out of the result; they are not mapped to None.
    """
    per_cluster = ((cid, labels == cid) for cid in range(K))
    return {cid: coords[sel].mean(axis=0) for cid, sel in per_cluster if np.any(sel)}


def stepwise_trim_mask(coords2d: np.ndarray, top_pct: float) -> np.ndarray:
    """Boolean keep-mask that hides the ``top_pct`` fraction of points farthest from the median.

    Distances are measured from the coordinate-wise median of ``coords2d`` (N, 2).
    Intended for plotting only.

    Returns:
        ``(N,)`` bool array, where True means keep. Empty or None input gives an
        empty mask. ``top_pct`` of None or <= 0 keeps everything; >= 1 keeps nothing.
    """
    n_pts = len(coords2d) if coords2d is not None else 0
    if not n_pts:
        return np.zeros((0,), dtype=bool)
    if top_pct is None or top_pct <= 0:
        return np.ones((n_pts,), dtype=bool)
    if top_pct >= 1.0:
        return np.zeros((n_pts,), dtype=bool)

    radii = np.linalg.norm(coords2d - np.median(coords2d, axis=0), axis=1)
    cutoff = np.quantile(radii, min(max(1.0 - top_pct, 0.0), 1.0))
    return radii <= cutoff


# =========================
# 1) Load the embeddings for every step
# =========================
print("Loading augmentation embeddings per step...")
user_by_step = {}
item_by_step = {}
dims = set()  # embedding widths seen across all steps (users and items)

for s in STEPS:
    U, I = load_users_items_for_step(BASE_DIR, PARTS, s)
    user_by_step[s] = U
    item_by_step[s] = I
    dims.add(U.shape[1])
    if I is not None:
        dims.add(I.shape[1])
    print(f"  - step{s}: users={U.shape}, items={None if I is None else I.shape}")

# All steps (users and items) must share one embedding width.
if len(dims) != 1:
    raise RuntimeError(f"ÏûÑÎ≤†Îî© Ï∞®ÏõêÏù¥ stepÎ≥ÑÎ°ú Îã§Î¶ÖÎãàÎã§: {dims}")
D = dims.pop()

# =========================
# 2) Users common to train and label, aligned per step
# =========================
print(f"common users: {len(common_users)}")
# Per step: the common user ids that are valid row indices into that step's user matrix.
aligned_users_by_step = {
    s: align_users_for_step(common_users, user_by_step[s].shape[0])
    for s in STEPS
}
print({s: len(aligned_users_by_step[s]) for s in STEPS})

# =========================
# 3) Common item prefix (shortest item count across steps)
# =========================
item_lengths = [item_by_step[s].shape[0] for s in STEPS if item_by_step[s] is not None]
if len(item_lengths) == 0:
    I_base = 0
else:
    I_base = min(item_lengths)
    uniq = set(item_lengths)
    if len(uniq) != 1:
        print(f"[warn] stepÎ≥Ñ ÏïÑÏù¥ÌÖú ÏàòÍ∞Ä Îã§Î¶ÖÎãàÎã§: {uniq} ‚Üí ÏïûÏóêÏÑú {I_base}Í∞úÎßå Í≥µÌÜµ ÏÇ¨Ïö©")
    # Truncate every step to its first I_base item rows.
    # NOTE(review): this assumes item row i is the same item in every step -- confirm.
    for s in STEPS:
        I = item_by_step[s]
        item_by_step[s] = None if (I is None or I.shape[0] < I_base) else I[:I_base]

# Fixed seed: item subsampling here and user subsampling later draw from this same RNG,
# so the draw order matters for reproducibility.
rng = np.random.RandomState(0)
if I_base > 0:
    if SUBSAMPLE_ITEMS_FOR_TSNE is not None and SUBSAMPLE_ITEMS_FOR_TSNE < I_base:
        keep_items = np.sort(rng.choice(I_base, size=SUBSAMPLE_ITEMS_FOR_TSNE, replace=False))
    else:
        keep_items = np.arange(I_base, dtype=int)
else:
    keep_items = np.array([], dtype=int)

# =========================
# 4) KMeans on the final step -> cluster labels
# =========================
print(f"KMeans on final step: step{CLUSTER_STEP}")
U_final = user_by_step[CLUSTER_STEP]
ids_final = aligned_users_by_step[CLUSTER_STEP]
if len(ids_final) < K_USERS:
    raise RuntimeError(f"ÏµúÏ¢Ö stepÏóêÏÑú ÌÅ¥Îü¨Ïä§ÌÑ∞ÎßÅ Í∞ÄÎä•Ìïú Í≥µÌÜµ Ïú†Ï†ÄÍ∞Ä Î∂ÄÏ°±Ìï©ÎãàÎã§: {len(ids_final)}")

emb_final_users = U_final[ids_final, :]  # (|ids_final|, D)
user_kmeans = KMeans(n_clusters=K_USERS, random_state=KMEANS_RANDOM_STATE, n_init=10)
user_labels_final_full = user_kmeans.fit_predict(emb_final_users)  # aligned with ids_final order

# Item KMeans on the final step's common items (rows keep_items).
# KMeans needs at least n_clusters samples, so fewer than K_ITEMS items is handled
# like "no items" (warning + skip) instead of letting fit_predict raise.
if item_by_step[CLUSTER_STEP] is not None and len(keep_items) >= max(K_ITEMS, 1):
    I_final = item_by_step[CLUSTER_STEP][keep_items, :]  # (|keep_items|, D)
    item_kmeans = KMeans(n_clusters=K_ITEMS, random_state=KMEANS_RANDOM_STATE, n_init=10)
    item_labels_final_full = item_kmeans.fit_predict(I_final)  # aligned with keep_items order
else:
    item_kmeans = None
    item_labels_final_full = None
    print("[warn] ÏµúÏ¢Ö step ÏïÑÏù¥ÌÖúÏù¥ ÏóÜÏñ¥ item clustering ÏÉùÎûµ")

# uid -> label for ids_final (final-step clustering)
uid2label = {int(uid): int(lb) for uid, lb in zip(ids_final.tolist(), user_labels_final_full.tolist())}

# Publish the labels under the names that the distance cells read (`user_labels_final`,
# `item_labels_final`). This keeps them paired with the ids_final / item_for_tsne of
# this cell instead of leaving a stale value from an earlier cell, or no value at all.
user_labels_final = user_labels_final_full
item_labels_final = item_labels_final_full

# =========================
# 5) Build the data blocks fed to t-SNE (users + items across steps)
# =========================
print("Preparing t-SNE blocks...")
ids_for_tsne = {}      # step -> array of user ids (after subsampling)
users_for_tsne = {}    # step -> array of user embeddings
item_for_tsne  = {}    # step -> array of item embeddings (rows keep_items)

# User subsampling draws from the final step's ids_final.
if SUBSAMPLE_USERS_FOR_TSNE is not None and SUBSAMPLE_USERS_FOR_TSNE < len(ids_final):
    sample_users = np.sort(rng.choice(ids_final, size=SUBSAMPLE_USERS_FOR_TSNE, replace=False))
else:
    sample_users = ids_final

# For each step keep only sampled users that are aligned in that step; items use keep_items.
for s in STEPS:
    ids_s = aligned_users_by_step[s]
    if len(ids_s) == 0:
        ids_for_tsne[s] = np.array([], dtype=np.int64)
        users_for_tsne[s] = np.empty((0, D), dtype=np.float32)
    else:
        mask = np.isin(ids_s, sample_users)
        ids_for_tsne[s] = ids_s[mask]
        users_for_tsne[s] = user_by_step[s][ids_for_tsne[s], :]

    I = item_by_step[s]
    item_for_tsne[s] = None if (I is None or len(keep_items) == 0) else I[keep_items, :]

# =========================
# 6) Joint t-SNE over the users/items of all steps
# =========================
print("Fitting joint t-SNE on users + items across steps")
blocks = []
slices = {}  # step -> {"user": (st,en), "item": (st,en) or None}; row ranges into X_all / X_tsne
cursor = 0

for s in STEPS:
    Ublk = users_for_tsne[s]
    blocks.append(Ublk)
    u_st, u_en = cursor, cursor + len(Ublk)
    cursor = u_en

    Iblk = item_for_tsne[s]
    if Iblk is not None and len(Iblk) > 0:
        blocks.append(Iblk)
        i_st, i_en = cursor, cursor + len(Iblk)
        cursor = i_en
        slices[s] = {"user": (u_st, u_en), "item": (i_st, i_en)}
    else:
        slices[s] = {"user": (u_st, u_en), "item": None}

if len(blocks) == 0 or sum(len(b) for b in blocks) < 3:
    raise RuntimeError("t-SNEÏóê ÏÇ¨Ïö©Ìï† ÌëúÎ≥∏Ïù¥ ÎÑàÎ¨¥ Ï†ÅÏäµÎãàÎã§.")

X_all = np.vstack(blocks)
# Heuristic perplexity: at most TSNE_PERPLEXITY, roughly (n - 1) / 3, with a floor of 5.
# sklearn's TSNE rejects perplexity >= n_samples, which the floor alone produces for
# 3-5 samples, so the value is also capped at n_samples - 1 (no change for n >= 6).
perp = min(TSNE_PERPLEXITY, max(5, (len(X_all) - 1)//3), len(X_all) - 1)
tsne = TSNE(
    n_components=2,
    perplexity=perp,
    init="pca",
    learning_rate="auto",
    random_state=TSNE_RANDOM_STATE,
    max_iter=1000,
    verbose=1
)
X_tsne = tsne.fit_transform(X_all)

# =========================
# (optional) Global trimming (plot-only)
# =========================
# Hides the GLOBAL_TRIM_TOP_PCT fraction of all t-SNE points farthest from the 2-D median.
# NOTE(review): values >= 1 would pass a negative quantile to np.quantile -- keep it < 1.
if GLOBAL_TRIM_TOP_PCT is not None and GLOBAL_TRIM_TOP_PCT > 0:
    ctr2d = np.median(X_tsne, axis=0)
    dist2d = np.linalg.norm(X_tsne - ctr2d, axis=1)
    thr = np.quantile(dist2d, 1.0 - GLOBAL_TRIM_TOP_PCT)
    global_keep_mask = dist2d <= thr
    print(f"[global-trim] keep {global_keep_mask.sum()}/{len(global_keep_mask)} "
          f"({100.0 * global_keep_mask.mean():.2f}%) points")
else:
    global_keep_mask = np.ones((X_tsne.shape[0],), dtype=bool)

# Axis range: from the globally kept points (all points when trimming is off), padded by 5%
X_plot = X_tsne[global_keep_mask]
xmin, ymin = X_plot.min(axis=0)
xmax, ymax = X_plot.max(axis=0)
xpad = 0.05 * (xmax - xmin) if xmax > xmin else 0.5
ypad = 0.05 * (ymax - ymin) if ymax > ymin else 0.5

# Users use palette[k]; items use palette[k + 2] so the two cluster sets get distinct colors.
palette = ["tab:blue","tab:orange","tab:green","tab:red","tab:purple",
           "tab:brown","tab:pink","tab:gray","tab:olive","tab:cyan"]

# =========================
# 6.5) Prepare plot labels/centers for the final step's users/items
#     - centers are computed after the same (global/local) trimming, for consistency
# =========================
# final users slice (rows of X_tsne belonging to the final step's users)
u_st_f, u_en_f = slices[CLUSTER_STEP]["user"]
coords_u_final_all = X_tsne[u_st_f:u_en_f]
ids_tsne_final_all = ids_for_tsne[CLUSTER_STEP]
labels_u_final_all = np.array([uid2label[int(uid)] for uid in ids_tsne_final_all], dtype=int)

# apply the global trim
g_keep_u_f = global_keep_mask[u_st_f:u_en_f]
coords_u_final_g = coords_u_final_all[g_keep_u_f]
ids_tsne_final_g = ids_tsne_final_all[g_keep_u_f]
labels_u_final_g = labels_u_final_all[g_keep_u_f]

# also apply the local trim to the final step (same per-step rule as the plots)
local_pct_f = STEP_TRIM_TOP_PCT_STEP1 if CLUSTER_STEP == 1 else STEP_TRIM_TOP_PCT
l_keep_u_f = stepwise_trim_mask(coords_u_final_g, top_pct=local_pct_f)
coords_u_final = coords_u_final_g[l_keep_u_f]
ids_tsne_final = ids_tsne_final_g[l_keep_u_f]
labels_u_for_plot = labels_u_final_g[l_keep_u_f]

# user centers in 2-D t-SNE coordinates (trimmed final-step users only)
user_centers = compute_centers(coords_u_final, labels_u_for_plot, K_USERS)

# final items slice
if slices[CLUSTER_STEP]["item"] is not None and item_labels_final_full is not None:
    i_st_f, i_en_f = slices[CLUSTER_STEP]["item"]
    coords_i_final_all = X_tsne[i_st_f:i_en_f]
    labels_i_final_all = item_labels_final_full

    # global trim
    g_keep_i_f = global_keep_mask[i_st_f:i_en_f]
    coords_i_final_g = coords_i_final_all[g_keep_i_f]
    labels_i_final_g = labels_i_final_all[g_keep_i_f]

    # local trim (final step)
    local_pct_i_f = STEP_TRIM_TOP_PCT_STEP1 if CLUSTER_STEP == 1 else STEP_TRIM_TOP_PCT
    l_keep_i_f = stepwise_trim_mask(coords_i_final_g, top_pct=local_pct_i_f)
    coords_i_final = coords_i_final_g[l_keep_i_f]
    labels_i_final = labels_i_final_g[l_keep_i_f]

    item_centers = compute_centers(coords_i_final, labels_i_final, K_ITEMS)
else:
    item_centers = {}

# final uid -> position in ids_tsne_final / labels_u_for_plot (trimmed final step)
pos_map_final = {int(uid): i for i, uid in enumerate(ids_tsne_final.tolist())}

# =========================
# 7) Save figures - Users only / Items only / Users+Items
# =========================
def plot_users_only():
    """Save t-SNE scatter plots of users only: one grid figure plus one figure per step.

    Every step's users are colored by their final-step KMeans label, i.e.
    ``labels_u_for_plot`` looked up through ``pos_map_final``. Points are trimmed
    for display only: first by ``global_keep_mask``, then by the per-step local trim
    (``STEP_TRIM_TOP_PCT_STEP1`` when ``s == 1``, else ``STEP_TRIM_TOP_PCT``). Users
    not present in the trimmed final step are skipped. Final-step user centers
    (X markers) are drawn on the grid panel.

    Outputs:
        OUT_DIR/step{CLUSTER_STEP}_tsne2d_USERS_ONLY.png   grid over all steps
        OUT_DIR/step_by_step/step{s}_tsne2d_USERS_ONLY.png  one per non-empty step
    """
    step_dir = os.path.join(OUT_DIR, "step_by_step")
    os.makedirs(step_dir, exist_ok=True)

    S = len(STEPS)
    cols = min(3, S)
    rows = math.ceil(S / cols)
    fig = plt.figure(figsize=(5*cols, 4*rows))

    for idx, s in enumerate(STEPS, start=1):
        ax = fig.add_subplot(rows, cols, idx)
        fig_s, ax_s = plt.subplots(figsize=(6, 5))
        # Title every panel up front, including empty ones, so the grid shows which
        # step each panel belongs to. Same wording as the other t-SNE plots.
        title = f"step{s} (Users only, colored by step{CLUSTER_STEP})"
        ax.set_title(title)
        ax_s.set_title(title)

        u_st, u_en = slices[s]["user"]
        coords_u_all = X_tsne[u_st:u_en]
        ids_tsne_s_all = ids_for_tsne[s]

        if len(ids_tsne_s_all) == 0:
            ax.set_xticks([]); ax.set_yticks([])
            plt.close(fig_s)
            continue

        # (1) global trim
        g_keep = global_keep_mask[u_st:u_en]
        coords_u_g = coords_u_all[g_keep]
        ids_u_g = ids_tsne_s_all[g_keep]

        if len(ids_u_g) == 0:
            ax.set_xticks([]); ax.set_yticks([])
            plt.close(fig_s)
            continue

        # (2) per-step local trim
        local_pct = STEP_TRIM_TOP_PCT_STEP1 if s == 1 else STEP_TRIM_TOP_PCT
        l_keep = stepwise_trim_mask(coords_u_g, top_pct=local_pct)
        coords_u = coords_u_g[l_keep]
        ids_tsne_s = ids_u_g[l_keep]

        if len(ids_tsne_s) == 0:
            ax.set_xticks([]); ax.set_yticks([])
            plt.close(fig_s)
            continue

        # Color by final-step labels; uids absent from the (trimmed) final step are dropped
        valid_mask = np.array([int(uid) in pos_map_final for uid in ids_tsne_s], dtype=bool)
        coords_valid = coords_u[valid_mask]
        ids_valid = ids_tsne_s[valid_mask]

        if len(ids_valid) == 0:
            ax.set_xticks([]); ax.set_yticks([])
            plt.close(fig_s)
            continue

        labels_plot = np.array([labels_u_for_plot[pos_map_final[int(uid)]] for uid in ids_valid], dtype=int)

        for k in range(K_USERS):
            mk = (labels_plot == k)
            if np.any(mk):
                ax.scatter(coords_valid[mk, 0], coords_valid[mk, 1],
                           s=8, alpha=0.7, c=palette[k % len(palette)],
                           label=f"Users: C{k}")
                ax_s.scatter(coords_valid[mk, 0], coords_valid[mk, 1],
                             s=10, alpha=0.7, c=palette[k % len(palette)],
                             label=f"Users: C{k}")

        if s == CLUSTER_STEP:
            for k, ctr in user_centers.items():
                ax.scatter(ctr[0], ctr[1], s=120, marker="X",
                           c=palette[k % len(palette)], edgecolor="k")

        for a in (ax, ax_s):
            a.set_xticks([]); a.set_yticks([])

        if idx == 1:
            ax.legend(loc="best", frameon=True)
        ax_s.legend(loc="best", frameon=True)

        fig_s.tight_layout()
        out_s = os.path.join(step_dir, f"step{s}_tsne2d_USERS_ONLY.png")
        fig_s.savefig(out_s, dpi=220)
        plt.close(fig_s)
        print("  - saved:", out_s)

    fig.tight_layout()
    out = os.path.join(OUT_DIR, f"step{CLUSTER_STEP}_tsne2d_USERS_ONLY.png")
    fig.savefig(out, dpi=220)
    plt.close(fig)
    print("  - saved:", out)


def plot_items_only():
    """Save t-SNE scatter plots of items only: one grid figure plus one figure per step.

    Items in every step are colored by the final-step item KMeans labels
    (``item_labels_final_full``, aligned with ``keep_items``). Points are trimmed for
    display only: first by ``global_keep_mask``, then by the per-step local trim
    (``STEP_TRIM_TOP_PCT_STEP1`` when ``s == 1``, else ``STEP_TRIM_TOP_PCT``).
    Final-step item centers (diamond markers) are drawn on the grid panel.
    Prints a message and returns when no step has items or no item labels exist.

    Outputs:
        OUT_DIR/step{CLUSTER_STEP}_tsne2d_ITEMS_ONLY.png   grid over all steps
        OUT_DIR/step_by_step/step{s}_tsne2d_ITEMS_ONLY.png  one per step
    """
    has_any = any(item_for_tsne[s] is not None for s in STEPS) and (item_labels_final_full is not None)
    if not has_any:
        print("  - Item only plot ÏÉùÎûµ(ÏïÑÏù¥ÌÖú ÏûÑÎ≤†Îî©/ÎùºÎ≤® ÏóÜÏùå)")
        return

    step_dir = os.path.join(OUT_DIR, "step_by_step")
    os.makedirs(step_dir, exist_ok=True)

    S = len(STEPS)
    cols = min(3, S)
    rows = math.ceil(S / cols)
    fig = plt.figure(figsize=(5*cols, 4*rows))

    for idx, s in enumerate(STEPS, start=1):
        ax = fig.add_subplot(rows, cols, idx)
        fig_s, ax_s = plt.subplots(figsize=(5, 4))
        # Title every panel so the grid and the per-step files show which step they are.
        title = f"step{s} (Items only, colored by step{CLUSTER_STEP})"
        ax.set_title(title)
        ax_s.set_title(title)

        if slices[s]["item"] is not None:
            i_st, i_en = slices[s]["item"]
            coords_i_all = X_tsne[i_st:i_en]
            labels_i_all = item_labels_final_full

            # (1) global trim
            g_keep = global_keep_mask[i_st:i_en]
            coords_i_g = coords_i_all[g_keep]
            labels_i_g = labels_i_all[g_keep]

            # (2) per-step local trim
            local_pct = STEP_TRIM_TOP_PCT_STEP1 if s == 1 else STEP_TRIM_TOP_PCT
            l_keep = stepwise_trim_mask(coords_i_g, top_pct=local_pct)
            coords_i = coords_i_g[l_keep]
            labels_i = labels_i_g[l_keep]

            if len(coords_i) > 0:
                for k in range(K_ITEMS):
                    mk = (labels_i == k)
                    if np.any(mk):
                        ax.scatter(coords_i[mk, 0], coords_i[mk, 1],
                                   s=5, alpha=0.35, c=palette[(k+2) % len(palette)],
                                   label=f"Items: C{k}")
                        ax_s.scatter(coords_i[mk, 0], coords_i[mk, 1],
                                     s=5, alpha=0.35, c=palette[(k+2) % len(palette)],
                                     label=f"Items: C{k}")

                if s == CLUSTER_STEP and len(item_centers) > 0:
                    for k, ctr in item_centers.items():
                        ax.scatter(ctr[0], ctr[1], s=110, marker="D",
                                   c=palette[(k+2) % len(palette)], edgecolor="k")

        for a in (ax, ax_s):
            a.set_xticks([]); a.set_yticks([])

        if idx == 1:
            ax.legend(loc="best", frameon=True)
        ax_s.legend(loc="best", frameon=True)

        fig_s.tight_layout()
        out_s = os.path.join(step_dir, f"step{s}_tsne2d_ITEMS_ONLY.png")
        fig_s.savefig(out_s, dpi=220)
        plt.close(fig_s)
        print("  - saved:", out_s)

    fig.tight_layout()
    out = os.path.join(OUT_DIR, f"step{CLUSTER_STEP}_tsne2d_ITEMS_ONLY.png")
    fig.savefig(out, dpi=220)
    plt.close(fig)
    print("  - saved:", out)


def plot_users_items():
    """Save one grid figure with items and users overlaid in every step's panel.

    Items are colored by ``item_labels_final_full`` and users by their final-step
    label (``labels_u_for_plot`` via ``pos_map_final``). Both are trimmed for display
    only (``global_keep_mask``, then the per-step local trim). On the final step,
    diamond markers show ``item_centers`` and X markers show ``user_centers``.
    Writes OUT_DIR/step{CLUSTER_STEP}_tsne2d_USERS_ITEMS.png. It also creates
    OUT_DIR/step_by_step, although this function writes nothing there.
    """
    step_dir = os.path.join(OUT_DIR, "step_by_step")
    os.makedirs(step_dir, exist_ok=True)

    S = len(STEPS)
    cols = min(3, S)
    rows = math.ceil(S / cols)
    fig = plt.figure(figsize=(5*cols, 4*rows))

    for idx, s in enumerate(STEPS, start=1):
        ax = fig.add_subplot(rows, cols, idx)

        # Items (drawn first, so users are plotted on top)
        if slices[s]["item"] is not None and item_labels_final_full is not None:
            i_st, i_en = slices[s]["item"]
            coords_i_all = X_tsne[i_st:i_en]
            labels_i_all = item_labels_final_full

            # (1) global trim
            g_keep = global_keep_mask[i_st:i_en]
            coords_i_g = coords_i_all[g_keep]
            labels_i_g = labels_i_all[g_keep]

            # (2) per-step local trim
            local_pct = STEP_TRIM_TOP_PCT_STEP1 if s == 1 else STEP_TRIM_TOP_PCT
            l_keep = stepwise_trim_mask(coords_i_g, top_pct=local_pct)
            coords_i = coords_i_g[l_keep]
            labels_i = labels_i_g[l_keep]

            if len(coords_i) > 0:
                for k in range(K_ITEMS):
                    mk = (labels_i == k)
                    if np.any(mk):
                        ax.scatter(coords_i[mk, 0], coords_i[mk, 1],
                                   s=5, alpha=0.35, c=palette[(k+2) % len(palette)],
                                   label=f"Items: C{k}")
                if s == CLUSTER_STEP and len(item_centers) > 0:
                    for k, ctr in item_centers.items():
                        ax.scatter(ctr[0], ctr[1], s=110, marker="D",
                                   c=palette[(k+2) % len(palette)], edgecolor="k")

        # Users
        u_st, u_en = slices[s]["user"]
        coords_u_all = X_tsne[u_st:u_en]
        ids_tsne_s_all = ids_for_tsne[s]

        if len(ids_tsne_s_all) > 0:
            # (1) global trim
            g_keep = global_keep_mask[u_st:u_en]
            coords_u_g = coords_u_all[g_keep]
            ids_u_g = ids_tsne_s_all[g_keep]

            # (2) per-step local trim
            local_pct = STEP_TRIM_TOP_PCT_STEP1 if s == 1 else STEP_TRIM_TOP_PCT
            l_keep = stepwise_trim_mask(coords_u_g, top_pct=local_pct)
            coords_u = coords_u_g[l_keep]
            ids_tsne_s = ids_u_g[l_keep]

            # Color by final-step labels (uids absent from the trimmed final step are dropped)
            valid_mask = np.array([int(uid) in pos_map_final for uid in ids_tsne_s], dtype=bool)
            coords_valid = coords_u[valid_mask]
            ids_valid = ids_tsne_s[valid_mask]

            if len(ids_valid) > 0:
                labels_plot = np.array(
                    [labels_u_for_plot[pos_map_final[int(uid)]] for uid in ids_valid],
                    dtype=int
                )

                for k in range(K_USERS):
                    mk = (labels_plot == k)
                    if np.any(mk):
                        ax.scatter(coords_valid[mk, 0], coords_valid[mk, 1],
                                   s=8, alpha=0.7, c=palette[k % len(palette)],
                                   label=f"Users: C{k}")

                if s == CLUSTER_STEP:
                    for k, ctr in user_centers.items():
                        ax.scatter(ctr[0], ctr[1], s=120, marker="X",
                                   c=palette[k % len(palette)], edgecolor="k")

        ax.set_title(f"step{s} (Users & Items, colored by step{CLUSTER_STEP})")
        #ax.set_xlim(xmin - xpad, xmax + xpad)
        #ax.set_ylim(ymin - ypad, ymax + ypad)
        ax.set_xticks([]); ax.set_yticks([])

        # Only the first panel carries a legend
        if idx == 1:
            ax.legend(loc="best", frameon=True)

    fig.tight_layout()
    out = os.path.join(OUT_DIR, f"step{CLUSTER_STEP}_tsne2d_USERS_ITEMS.png")
    fig.savefig(out, dpi=220)
    plt.close(fig)
    print("  - saved:", out)


# Render and save the three t-SNE figure sets defined above, in this fixed order.
print("Plotting...")
plot_users_only()
plot_items_only()
plot_users_items()
print("Done.")


### Distance: Step1, Step5

In [None]:
import csv
import numpy as np
import matplotlib.pyplot as plt
import os
from itertools import combinations
from typing import Dict, Tuple, Optional

# --- Helper Functions ---

def get_user_label_map(user_ids: np.ndarray, user_labels: np.ndarray) -> Dict[int, int]:
    """Map each user id to its cluster label, converting both to Python ``int``.

    Pairs are taken positionally (zip stops at the shorter input); for duplicate
    ids the last pair wins.
    """
    return dict(zip(map(int, user_ids.tolist()), map(int, user_labels.tolist())))

def get_labels_for_ids(target_ids: np.ndarray, label_map: Dict[int, int]) -> np.ndarray:
    """Return the label of each id in ``target_ids``; ids not in ``label_map`` get -1."""
    n_ids = len(target_ids)
    if not n_ids:
        return np.empty((0,), dtype=int)
    looked_up = (label_map.get(int(uid), -1) for uid in target_ids.tolist())
    return np.fromiter(looked_up, dtype=int, count=n_ids)

def calculate_centers(points: np.ndarray, labels: np.ndarray, k: int) -> Dict[int, Optional[np.ndarray]]:
    """Return the mean point of each cluster id in ``range(k)``; empty clusters map to ``None``."""
    result: Dict[int, Optional[np.ndarray]] = {}
    for cid in range(k):
        members = labels == cid
        result[cid] = None
        if np.any(members):
            result[cid] = points[members].mean(axis=0)
    return result

def calculate_distance(point_a: Optional[np.ndarray], point_b: Optional[np.ndarray]) -> float:
    """Euclidean distance between two points, or NaN when either one is ``None``."""
    if point_a is not None and point_b is not None:
        return float(np.linalg.norm(point_a - point_b))
    return np.nan

def get_pairwise_stats(centers: Dict[int, Optional[np.ndarray]]) -> Tuple[float, float, float]:
    """Mean, min and max Euclidean distance over all unordered pairs of present centers.

    ``None`` centers are skipped. Returns ``(nan, nan, nan)`` when fewer than
    two centers are available.
    """
    present = [ctr for ctr in centers.values() if ctr is not None]
    if len(present) >= 2:
        gaps = [np.linalg.norm(a - b) for a, b in combinations(present, 2)]
        return float(np.mean(gaps)), float(np.min(gaps)), float(np.max(gaps))
    return np.nan, np.nan, np.nan

def get_user_item_stats(centers_u: Dict[int, Optional[np.ndarray]],
                        centers_i: Dict[int, Optional[np.ndarray]]) -> Tuple[float, float, float]:
    """Mean, min and max distance over every (user center, item center) pair.

    ``None`` centers on either side are skipped. Returns ``(nan, nan, nan)``
    if either side has no available center.
    """
    user_ctrs = [ctr for ctr in centers_u.values() if ctr is not None]
    item_ctrs = [ctr for ctr in centers_i.values() if ctr is not None]
    if user_ctrs and item_ctrs:
        gaps = []
        for cu in user_ctrs:
            gaps.extend(np.linalg.norm(cu - ci) for ci in item_ctrs)
        return float(np.mean(gaps)), float(np.min(gaps)), float(np.max(gaps))
    return np.nan, np.nan, np.nan

# =========================
# Main (RAW ONLY) - Only first & last step
# =========================
# For the first and last entries of STEPS, compute user and item cluster
# centers in the raw embedding space, using cluster labels taken from the
# "final" clustering (ids_final / user_labels_final / item_labels_final).
# Then record center-to-center distances, one dict per step in `rows`, and
# write them to a CSV in OUT_DIR.
# NOTE(review): whether the "final" clustering is the one at CLUSTER_STEP
# (the step used in the output filename) is not visible here; confirm.

# 0) Steps to use: only the first and the last
FIRST_STEP = STEPS[0]
LAST_STEP = STEPS[-1]
TARGET_STEPS = [FIRST_STEP, LAST_STEP]

# 1) Prepare label map from final step clustering
user_label_map_final = get_user_label_map(ids_final, user_labels_final)

# One dict per target step; becomes one CSV row and is also read by the plotting code.
rows = []

for s in TARGET_STEPS:
    step_data = {"step": s}

    # --- A) RAW: User centers ---
    # The aligned user IDs are used directly as row indices into user_by_step[s].
    # Users missing from the final label map get label -1 and so belong to no center.
    ids_u_raw = aligned_users_by_step[s]
    emb_u_raw = user_by_step[s][ids_u_raw, :] if len(ids_u_raw) > 0 else np.empty((0, user_by_step[s].shape[1]))
    labels_u_raw = get_labels_for_ids(ids_u_raw, user_label_map_final)
    centers_u_raw = calculate_centers(emb_u_raw, labels_u_raw, K_USERS)

    # --- B) RAW: Item centers ---
    # item_for_tsne[s] was already being used as the raw item embedding, so use it as-is.
    # item_labels_final is applied as a boolean mask over its rows, so it must have
    # one label per row of item_for_tsne[s].
    if item_for_tsne[s] is not None and item_labels_final is not None:
        emb_i_raw = item_for_tsne[s]
        centers_i_raw = calculate_centers(emb_i_raw, item_labels_final, K_ITEMS)
    else:
        # No item data for this step: every item center is missing, so item stats become NaN.
        centers_i_raw = {k: None for k in range(K_ITEMS)}

    # --- C) Stats (RAW) ---
    # User-User / Item-Item (raw)
    step_data["user_dist_raw_mean"], step_data["user_dist_raw_min"], step_data["user_dist_raw_max"] = get_pairwise_stats(centers_u_raw)
    step_data["item_dist_raw_mean"], step_data["item_dist_raw_min"], step_data["item_dist_raw_max"] = get_pairwise_stats(centers_i_raw)

    # Detailed stats when K=2
    if K_USERS == 2:
        step_data["user_dist_raw_U0U1"] = calculate_distance(centers_u_raw.get(0), centers_u_raw.get(1))
    if K_ITEMS == 2:
        step_data["item_dist_raw_I0I1"] = calculate_distance(centers_i_raw.get(0), centers_i_raw.get(1))

    # User-Item (raw)
    # With 2 user and 2 item clusters, store all four pairwise distances
    # (ui_raw_U{u}_I{i}); otherwise store only mean/min/max over all pairs.
    if K_USERS == 2 and K_ITEMS == 2:
        for u_idx in range(2):
            for i_idx in range(2):
                key_suffix = f"U{u_idx}_I{i_idx}"
                step_data[f"ui_raw_{key_suffix}"] = calculate_distance(
                    centers_u_raw.get(u_idx), centers_i_raw.get(i_idx)
                )
    else:
        step_data["ui_raw_mean"], step_data["ui_raw_min"], step_data["ui_raw_max"] = get_user_item_stats(centers_u_raw, centers_i_raw)

    rows.append(step_data)

# --- Save CSV (only 2 rows) ---
# The header is taken from the first row. This works because K_USERS / K_ITEMS are
# fixed across the loop, so every row has the same keys.
csv_filename = os.path.join(OUT_DIR, f"step{CLUSTER_STEP}_RAW_center_distances_first_last.csv")
if rows:
    with open(csv_filename, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
        writer.writeheader()
        writer.writerows(rows)
    print(f"  - saved: {csv_filename}")

# =========================
# Visualization (RAW, only first & last)
# =========================
# Two stacked panels over the steps recorded in `rows`:
#   top    - mean User/User and Item/Item center distance per step
#   bottom - User/Item center distances: every (U, I) pair when
#            K_USERS == K_ITEMS == 2, otherwise the mean/min/max summary.

steps_ax = [r["step"] for r in rows]

fig, (ax_top, ax_bottom) = plt.subplots(
    nrows=2, ncols=1, figsize=(8, 8), sharex=True
)

# -------------------------
# (Top) User-User & Item-Item (RAW)
# -------------------------
ax_top.plot(
    steps_ax,
    [r["user_dist_raw_mean"] for r in rows],
    marker="o",
    label="Users (raw, mean)"
)
ax_top.plot(
    steps_ax,
    [r["item_dist_raw_mean"] for r in rows],
    marker="s",
    label="Items (raw, mean)"
)
ax_top.set_ylabel("Center distance")
ax_top.set_title("User/User & Item/Item Center Distance (RAW) - First vs Last")
ax_top.legend()

# -------------------------
# (Bottom) User-Item (RAW)
# -------------------------
if K_USERS == 2 and K_ITEMS == 2:
    keys = ["U0_I0", "U0_I1", "U1_I0", "U1_I1"]
    for key_suffix in keys:
        ax_bottom.plot(
            steps_ax,
            [r.get(f"ui_raw_{key_suffix}", np.nan) for r in rows],
            marker="o",
            # Legend entry such as "U0–I1" (en dash between user and item cluster).
            label=key_suffix.replace("_", "–")
        )
    ax_bottom.legend(ncol=2)
    title_suffix = "(Detailed K=2)"
else:
    for stat, mk in [("mean", "o"), ("min", "s"), ("max", "^")]:
        ax_bottom.plot(
            steps_ax,
            [r.get(f"ui_raw_{stat}", np.nan) for r in rows],
            marker=mk,
            label=stat
        )
    ax_bottom.legend()
    title_suffix = "(Summary)"

ax_bottom.set_xlabel("Step")
# Only the first and last steps are plotted: tick exactly those instead of letting
# matplotlib choose in-between positions (the x-axis is shared with the top panel).
ax_bottom.set_xticks(steps_ax)
ax_bottom.set_ylabel("Center distance")
ax_bottom.set_title(f"User–Item Center Distances (RAW) - First vs Last {title_suffix}")

plt.tight_layout()

out = os.path.join(
    OUT_DIR,
    f"step{CLUSTER_STEP}_RAW_center_distances_first_last.png"
)
plt.savefig(out, dpi=220)
plt.close()

print(f"  - saved: {out}")


## Filter Bubble

In [None]:
csv_path = f"{BASE_DIR}/figs_tsne_cluster_progress/step4_center_distances_over_steps_RAW_ONLY.csv"

df = pd.read_csv(csv_path)

# step0, step4Í∞Ä Ï†ïÌôïÌûà ÏûàÎäîÏßÄ ÌôïÏù∏ (ÏóÜÏúºÎ©¥ Í∞ÄÏû• ÏûëÏùÄ/ÌÅ∞ step ÏÇ¨Ïö©)
step_min = int(df["step"].min())
step_max = int(df["step"].max())

# ÎÑ§Í∞Ä ÏõêÌïòÎäî Í≤å step0/step4 Í≥†Ï†ïÏù¥Î©¥ ÏïÑÎûò Îëê Ï§ÑÏùÑ Ïì∞Í≥†,
# CSVÏóê ÏóÜÏúºÎ©¥ KeyError/ÎπàÍ≤∞Í≥º ÎÇ† Ïàò ÏûàÏùå.
s0 = 0
s4 = 4
if not ((df["step"] == s0).any() and (df["step"] == s4).any()):
    print(f"[warn] step {s0}/{s4}Í∞Ä CSVÏóê ÏóÜÏñ¥ÏÑú {step_min}/{step_max}Î°ú ÎåÄÏ≤¥Ìï©ÎãàÎã§.")
    s0, s4 = step_min, step_max

row0 = df.loc[df["step"] == s0].iloc[0]
row4 = df.loc[df["step"] == s4].iloc[0]

pairs = [
    ("U0", "I0", "ui_raw_U0_I0"),
    ("U0", "I1", "ui_raw_U0_I1"),
    ("U1", "I0", "ui_raw_U1_I0"),
    ("U1", "I1", "ui_raw_U1_I1"),
]

results = []
for u, i, col in pairs:
    d0 = float(row0[col])
    d4 = float(row4[col])
    dec = d0 - d4   # +Î©¥ Í∞ÄÍπåÏõåÏßê, -Î©¥ Î©ÄÏñ¥Ïßê
    results.append({"U": u, "I": i, "d_step0": d0, "d_step4": d4, "decrease": dec})

res = pd.DataFrame(results)

# UÎ≥Ñ ÏöîÏïΩ: I0/I1 Í∞êÏÜåÎüâ ÎπÑÍµê Î∞è "I0Ï™ΩÏúºÎ°ú Îçî Í∞ÄÍπåÏõåÏ°åÎäîÏßÄ" ÌåêÎã®
summary_rows = []
for u in ["U0", "U1"]:
    dec_i0 = float(res[(res["U"] == u) & (res["I"] == "I0")]["decrease"].iloc[0])
    dec_i1 = float(res[(res["U"] == u) & (res["I"] == "I1")]["decrease"].iloc[0])
    # Í∞êÏÜåÎüâ Ï∞®Ïù¥(ÏñëÏàòÎ©¥ I0Ï™ΩÏúºÎ°ú Îçî Í∞ÄÍπåÏõåÏßê)
    diff = dec_i0 - dec_i1
    summary_rows.append({
        "U": u,
        "step_start": s0,
        "step_end": s4,
        "decrease_to_I0": dec_i0,
        "decrease_to_I1": dec_i1,
        "diff(I0_minus_I1)": diff,
        "moved_more_toward": "I0" if diff > 0 else ("I1" if diff < 0 else "Same")
    })

summary = pd.DataFrame(summary_rows)

print("\n[raw distances + decreases]")
print(res.to_string(index=False))

print("\n[summary: compare decreases]")
print(summary.to_string(index=False))

# ---------------------------
# Visualization
# - For each user cluster (U0, U1): bars comparing its distance decrease
#   toward I0 versus toward I1 between steps s0 and s4
# ---------------------------
fig, ax = plt.subplots(figsize=(8, 4))

x = np.arange(len(summary))  # one bar group per user cluster in `summary`
width = 0.35

dec_i0s = summary["decrease_to_I0"].values
dec_i1s = summary["decrease_to_I1"].values

# Label with the steps actually compared; they differ from 0/4 after the fallback.
ax.bar(x - width/2, dec_i0s, width, label=f"Decrease to I0 (d_step{s0} - d_step{s4})")
ax.bar(x + width/2, dec_i1s, width, label=f"Decrease to I1 (d_step{s0} - d_step{s4})")

ax.set_xticks(x)
ax.set_xticklabels(summary["U"].values)
ax.set_ylabel("Distance decrease (+ means closer)")
ax.set_title(f"U–I distance decrease from step{s0} to step{s4}")
ax.axhline(0, linestyle="--", linewidth=1, alpha=0.6)  # zero line: above = closer, below = farther
ax.legend()

plt.tight_layout()

out_png = os.path.join(os.path.dirname(csv_path), f"p{s0}_to_p{s4}_U_to_I_distance_decrease.png")
plt.savefig(out_png, dpi=220)
plt.show()

print("\nSaved plot ->", out_png)
