# In [None]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import gc
import numpy as np
import pandas as pd

try:
    from tqdm import tqdm
except Exception:
    # tqdm is optional: without it, loops still run, just with no progress bar.
    def tqdm(x, **kwargs):
        """Stand-in for ``tqdm.tqdm``: ignore all options and return *x* as is."""
        return x

# Directory that holds the FAR-Trans CSV files.
data_path = "/workspace/FAR-Trans/FAR-Trans//"

# Destination of the exported user x item matrix.
OUT_CSV = "user_item_purchase_matrix_full_latest_with_capacity.csv"
# Number of users built in memory and appended to OUT_CSV per batch.
USER_BATCH = 128
# True: matrix columns are every ISIN in the asset table;
# False: only ISINs that appear in a Buy transaction.
USE_ASSET_LATEST_FOR_COLUMNS = True

# Input CSV paths keyed by logical table name.
FILES = dict(
    customers=os.path.join(data_path, "customer_information.csv"),
    assets=os.path.join(data_path, "asset_information.csv"),
    transactions=os.path.join(data_path, "transactions.csv"),
)

def ensure_str(s):
    """Return the Series *s* converted to strings, with missing values as ``""``.

    Missing entries (NaN/None) are replaced *before* the string conversion.
    Converting first would turn them into the literal text ``"nan"`` /
    ``"None"``, after which ``fillna`` has nothing left to fill.
    """
    return s.fillna("").astype(str)

def read_csv_required(path, **kw):
    """Read the CSV at *path* with ``pandas.read_csv``.

    Extra keyword arguments are forwarded to ``pandas.read_csv`` unchanged.

    Raises:
        FileNotFoundError: if *path* does not exist.
    """
    if os.path.exists(path):
        return pd.read_csv(path, **kw)
    raise FileNotFoundError(f"未找到文件：{path}")

# Load customer records and keep only the most recent row per customer;
# that row supplies the customer's latest investmentCapacity.
customers = read_csv_required(FILES["customers"], dtype=str, low_memory=False)
# Fail early with a clear message rather than a KeyError further down.
for _required_col in ("timestamp", "customerID", "investmentCapacity"):
    if _required_col not in customers.columns:
        raise ValueError(f"customer_information.csv 缺少 '{_required_col}' 列")
customers["timestamp"] = pd.to_datetime(customers["timestamp"], errors="coerce")
# Unparseable timestamps were coerced to NaT. sort_values puts NaT last by
# default, which would make such a row win tail(1) as the "latest" record.
# Sorting NaT first lets any row with a valid timestamp take precedence; the
# stable sort keeps file order for ties so the result is deterministic.
# .copy() makes the result an independent frame before columns are assigned.
customers_latest = (
    customers.sort_values("timestamp", na_position="first", kind="stable")
    .groupby("customerID", as_index=False)
    .tail(1)
    .copy()
)
customers_latest["customerID"] = ensure_str(customers_latest["customerID"])
customers_latest["investmentCapacity"] = customers_latest["investmentCapacity"].fillna("Unknown")

# Row axis of the matrix: every distinct customer ID, sorted.
all_users = customers_latest["customerID"].dropna().astype(str).drop_duplicates().sort_values().tolist()
print(f"[Info] 最新用户总数: {len(all_users):,}")

# customerID -> investmentCapacity lookup for fast access while writing rows.
user2investmentCapacity = dict(zip(customers_latest["customerID"], customers_latest["investmentCapacity"]))

# Integer codes for investmentCapacity. Codes 0-3 follow increasing capacity
# brackets; a "Predicted_*" label shares the code of its original bracket.
inv_cap_mapping = {
    # Original categories
    "CAP_LT30K": 0,
    "CAP_30K_80K": 1,
    "CAP_80K_300K": 2,
    "CAP_GT300K": 3,
    # Predicted categories
    "Predicted_CAP_LT30K": 0,
    "Predicted_CAP_30K_80K": 1,
    "Predicted_CAP_80K_300K": 2,
    # The other predicted labels all carry the "CAP_" prefix, so the
    # consistently named key is accepted as well; without it such rows would
    # silently fall through to the missing-value code.
    # NOTE(review): confirm the exact spelling used in customer_information.csv.
    "Predicted_CAP_GT300K": 3,
    "Predicted_GT300K": 3,  # original spelling, kept for backward compatibility
    # Missing values
    "Not_Available": 4,
    # Fill value the script substitutes for NaN capacities; made explicit here
    # instead of relying only on the `.get(..., 4)` default at lookup time.
    "Unknown": 4,
}

# Load asset records and keep the most recent row per ISIN (the full set of
# candidate matrix columns).
assets = read_csv_required(FILES["assets"], dtype=str, low_memory=False)
# Fail early with a clear message rather than a KeyError further down.
for _required_col in ("timestamp", "ISIN"):
    if _required_col not in assets.columns:
        raise ValueError(f"asset_information.csv 缺少 '{_required_col}' 列")
assets["timestamp"] = pd.to_datetime(assets["timestamp"], errors="coerce")
# Unparseable timestamps were coerced to NaT. sort_values puts NaT last by
# default, which would make such a row win tail(1) as the "latest" record.
# Sorting NaT first lets rows with a valid timestamp take precedence; the
# stable sort keeps file order for ties so the result is deterministic.
assets_latest = (
    assets.sort_values("timestamp", na_position="first", kind="stable")
    .groupby("ISIN", as_index=False)
    .tail(1)
)

# Load transactions and reduce them to distinct (customerID, ISIN) pairs
# that have at least one "Buy" transaction.
tx = read_csv_required(
    FILES["transactions"],
    dtype={"customerID": str, "ISIN": str, "transactionType": str},
    low_memory=False,
)
tx_buy = (
    tx.loc[tx["transactionType"].eq("Buy"), ["customerID", "ISIN"]]
    .dropna()
    .astype({"customerID": str, "ISIN": str})
    .drop_duplicates()
)

# Column axis of the matrix: distinct ISINs, sorted. They come either from
# the asset table or from the Buy transactions, depending on the switch.
isin_source = assets_latest if USE_ASSET_LATEST_FOR_COLUMNS else tx_buy
all_isins = isin_source["ISIN"].dropna().astype(str).drop_duplicates().sort_values().tolist()

print(f"[Info] 列轴项目总数: {len(all_isins):,}")

# customerID -> set of ISINs that customer has bought.
user2items = {uid: set(isins) for uid, isins in tx_buy.groupby("customerID")["ISIN"]}
# ISIN -> column position in the output matrix.
isin2col = dict(zip(all_isins, range(len(all_isins))))

# Batches are appended to OUT_CSV later, so remove any previous output first.
if os.path.exists(OUT_CSV):
    os.remove(OUT_CSV)

# Write the header row: customerID, one column per ISIN, then investmentCapacity.
header = ["customerID", *all_isins, "investmentCapacity"]
header_frame = pd.DataFrame([header])
header_frame.to_csv(OUT_CSV, index=False, header=False, encoding="utf-8")
print(f"[Write] 表头写入：列数 {len(header)}（含 customerID 和 investmentCapacity）")

# Build the 0/1 purchase matrix USER_BATCH users at a time and append each
# batch to OUT_CSV, so the full matrix is never held in memory at once.
n_items = len(all_isins)
for start in tqdm(range(0, len(all_users), USER_BATCH), desc="写出用户批次"):
    batch_users = all_users[start:start + USER_BATCH]
    block = np.zeros((len(batch_users), n_items), dtype=np.uint8)

    for row, uid in enumerate(batch_users):
        # Purchased ISINs that are not on the column axis are ignored.
        cols = [isin2col[isin] for isin in user2items.get(uid, ()) if isin in isin2col]
        if cols:
            block[row, cols] = 1

    # Encoded investmentCapacity per user; 4 is the code for missing/unmapped values.
    capacity_codes = [
        inv_cap_mapping.get(user2investmentCapacity.get(uid, "Unknown"), 4)
        for uid in batch_users
    ]

    df_chunk = pd.DataFrame(block, columns=all_isins)
    df_chunk.insert(0, "customerID", batch_users)
    df_chunk["investmentCapacity"] = capacity_codes
    df_chunk.to_csv(OUT_CSV, mode="a", index=False, header=False, encoding="utf-8")

    # Drop the batch buffers before the next allocation.
    del block, df_chunk
    gc.collect()

print(f"[Done] 含 investmentCapacity 的全用户×全项目矩阵已导出：{OUT_CSV}")