In [None]:
import json
import os
import pandas as pd
import numpy as np
import random

In [None]:
class Raw_Data_Processor:
    """
    Turn the raw dataset files into index mappings and parquet tables.

    Inputs are an item-metadata JSONL file, a review JSONL file and an
    interaction CSV file.  ``generate_mapping`` assigns consecutive integer
    indices to the distinct ``user_id`` and ``parent_asin`` values of the CSV;
    the other methods re-encode the raw files with those indices.
    """

    # Metadata columns written to ``meta_data.parquet``.
    # all: ['main_category', 'title', 'average_rating', 'rating_number', 'features', 'description', 'price', 'images', 'videos', 'store', 'categories', 'details', 'parent_asin', 'bought_together']
    _META_COLUMNS = ['main_category', 'title', 'average_rating', 'rating_number', 'price',
                     'images', 'store', 'categories', 'details', 'parent_asin']
    # Review columns written to ``review_data.parquet``.
    # all: ['rating', 'title', 'text', 'images', 'asin', 'parent_asin', 'user_id', 'timestamp', 'helpful_vote', 'verified_purchase']
    _REVIEW_COLUMNS = ['parent_asin', 'user_id', 'text', 'timestamp', 'rating']

    def __init__(self, raw_meta_data_path, raw_review_data_path, raw_interaction_data_path):
        """
        :param raw_meta_data_path: str, path of the item-metadata JSONL file
        :param raw_review_data_path: str, path of the review JSONL file
        :param raw_interaction_data_path: str, path of the interaction CSV file
        """
        self.raw_meta_data_path = raw_meta_data_path  # jsonl file
        self.raw_review_data_path = raw_review_data_path  # jsonl file
        self.raw_interaction_data_path = raw_interaction_data_path  # csv file

    @staticmethod
    def _read_json(path):
        """Load one JSON document and close the file afterwards."""
        with open(path, 'r') as f:
            return json.load(f)

    @staticmethod
    def _write_json(obj, path):
        """Serialise ``obj`` as JSON to ``path`` and close the file afterwards."""
        with open(path, 'w') as f:
            json.dump(obj, f)

    @staticmethod
    def _split_history(history):
        """Split a whitespace-separated history cell; an empty (NaN) cell yields ``[]``."""
        if pd.isna(history):
            return []
        return str(history).split()

    @staticmethod
    def _ensure_dir(path):
        """Create ``path`` if it is a non-empty directory name that does not exist yet."""
        if path:
            os.makedirs(path, exist_ok=True)

    def generate_mapping(self, save_dir):
        """
        Generate mapping about asin, user_id, parent_asin, etc.

        Indices follow the order of first appearance in the interaction CSV.
        Only ids from the ``user_id`` and ``parent_asin`` columns are indexed.
        Four files are written: ``user_id_2_index.json``,
        ``parent_asin_2_index.json`` and their inverses ``index_2_user_id.json``
        and ``index_2_parent_asin.json`` (JSON turns the integer keys of the
        inverse tables into strings).

        :param save_dir: str, the directory to save the mapping files (created if missing)
        """
        data = pd.read_csv(self.raw_interaction_data_path)
        user_id = data['user_id'].unique()
        parent_asin = data['parent_asin'].unique()

        tables = {
            'user_id_2_index.json': {uid: idx for idx, uid in enumerate(user_id)},
            'parent_asin_2_index.json': {asin: idx for idx, asin in enumerate(parent_asin)},
            'index_2_user_id.json': dict(enumerate(user_id)),
            'index_2_parent_asin.json': dict(enumerate(parent_asin)),
        }
        self._ensure_dir(save_dir)
        for file_name, table in tables.items():
            self._write_json(table, os.path.join(save_dir, file_name))

    def print_jsonl_data_keys(self, path):
        """
        Print the keys of the first record of a jsonl file
        :param path: str, the path of the jsonl file
        """
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines before the first record
                print(list(json.loads(line).keys()))
                break

    def _collect_records(self, path, keep):
        """Parse every non-blank line of a JSONL file once and return the records accepted by ``keep``."""
        records = []
        # JSON text is UTF-8; do not depend on the platform's locale encoding.
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue
                record = json.loads(line)
                if keep(record):
                    records.append(record)
        return records

    def transform_meta_jsonl_to_parquet(self, mapping_dir, save_dir):
        """
        Write ``meta_data.parquet`` with the metadata of every mapped item.

        :param mapping_dir: str, directory holding ``parent_asin_2_index.json``
        :param save_dir: str, directory of the parquet file (created if missing)
        """
        parent_asin_2_index = self._read_json(os.path.join(mapping_dir, 'parent_asin_2_index.json'))

        # keep the data whose parent_asin is in the interaction data
        records = self._collect_records(
            self.raw_meta_data_path,
            lambda record: record.get('parent_asin') in parent_asin_2_index,
        )
        # ``columns=`` also yields a well-formed empty frame when nothing matched.
        data = pd.DataFrame(records, columns=self._META_COLUMNS)
        data['parent_asin'] = data['parent_asin'].map(parent_asin_2_index)
        self._ensure_dir(save_dir)
        data.to_parquet(os.path.join(save_dir, 'meta_data.parquet'))

    def transform_review_jsonl_to_parquet(self, mapping_dir, save_dir):
        """
        Write ``review_data.parquet`` with the reviews whose user and item are both mapped.

        :param mapping_dir: str, directory holding the two ``*_2_index.json`` mappings
        :param save_dir: str, directory of the parquet file (created if missing)
        """
        user_id_2_index = self._read_json(os.path.join(mapping_dir, 'user_id_2_index.json'))
        parent_asin_2_index = self._read_json(os.path.join(mapping_dir, 'parent_asin_2_index.json'))

        records = self._collect_records(
            self.raw_review_data_path,
            lambda record: (record.get('user_id') in user_id_2_index
                            and record.get('parent_asin') in parent_asin_2_index),
        )
        data = pd.DataFrame(records, columns=self._REVIEW_COLUMNS)
        data['parent_asin'] = data['parent_asin'].map(parent_asin_2_index)
        data['user_id'] = data['user_id'].map(user_id_2_index)
        self._ensure_dir(save_dir)
        data.to_parquet(os.path.join(save_dir, 'review_data.parquet'))

    def generate_test_data(self, save_dir):
        """
        Re-encode the interaction CSV with the saved indices and write ``test_data.parquet``.

        ``user_id`` and ``parent_asin`` become integer indices and ``history``
        becomes a list of item indices (``[]`` for an empty cell).  A user or
        item id that is absent from the mappings raises ``KeyError``.

        :param save_dir: str, directory holding the mappings; the parquet file is written there too
        """
        user_id_2_index = self._read_json(os.path.join(save_dir, 'user_id_2_index.json'))
        parent_asin_2_index = self._read_json(os.path.join(save_dir, 'parent_asin_2_index.json'))
        data = pd.read_csv(self.raw_interaction_data_path)
        data['user_id'] = data['user_id'].apply(lambda x: user_id_2_index[x])
        data['parent_asin'] = data['parent_asin'].apply(lambda x: parent_asin_2_index[x])
        data['history'] = data['history'].apply(
            lambda x: [parent_asin_2_index[i] for i in self._split_history(x)]
        )
        data.to_parquet(os.path.join(save_dir, 'test_data.parquet'))


# Processor configured with the Baby_Products files: metadata JSONL, review
# JSONL and interaction CSV, given as paths relative to the working directory.
raw_data_processor = Raw_Data_Processor(
    raw_meta_data_path='raw_data/meta_Baby_Products.jsonl',
    raw_review_data_path='raw_data/Baby_Products.jsonl',
    raw_interaction_data_path='raw_data/Baby_Products.csv'
)

In [None]:
class Raw_Data_Processor_to_100user:
    def __init__(self, raw_meta_data_path, raw_review_data_path, test_interaction_data_path):
        self.raw_meta_data_path = raw_meta_data_path  # jsonl file
        self.raw_review_data_path = raw_review_data_path  # jsonl file
        self.test_interaction_data_path = test_interaction_data_path  # csv file

    def generate_mapping(self, save_dir):
        """
        Generate mapping about asin, user_id, parent_asin, etc.
        Here we only keep the data of random 100 users
        :param save_dir: str, the directory to save the mapping files
        """
        data = pd.read_csv(self.test_interaction_data_path)
        user_id = data['user_id'].unique()
        final_user_id = random.sample(list(user_id), 100)
        final_parent_asin = []
        for i in range(100):
            final_parent_asin += [data[data['user_id'] == final_user_id[i]]['parent_asin'].values[0]] + str(data[data['user_id'] == final_user_id[i]]['history'].values[0]).split()
        final_parent_asin = list(set(final_parent_asin))
        final_user_id_2_index = {final_user_id[i]: i for i in range(len(final_user_id))}
        final_parent_asin_2_index = {final_parent_asin[i]: i for i in range(len(final_parent_asin))}
        index_2_final_user_id = {i: final_user_id[i] for i in range(len(final_user_id))}
        index_2_final_parent_asin = {i: final_parent_asin[i] for i in range(len(final_parent_asin))}
        with open(os.path.join(save_dir, 'user_id_2_index.json'), 'w') as f:
            json.dump(final_user_id_2_index, f)
        with open(os.path.join(save_dir, 'parent_asin_2_index.json'), 'w') as f:
            json.dump(final_parent_asin_2_index, f)
        with open(os.path.join(save_dir, 'index_2_user_id.json'), 'w') as f:
            json.dump(index_2_final_user_id, f)
        with open(os.path.join(save_dir, 'index_2_parent_asin.json'), 'w') as f:
            json.dump(index_2_final_parent_asin, f)

    def generate_test_mapping(self, save_dir, num, train_user):
        # generate the mapping of the test data, the test data is the data of the users who are not in the train data
        final_user_id_2_index = train_user
        data = pd.read_csv(self.test_interaction_data_path)
        user_id = data['user_id'].unique()
        final_user_id = [i for i in user_id if i not in final_user_id_2_index.keys()]
        final_parent_asin = []
        final_user_id = random.sample(final_user_id, num)
        print('取出的测试用户数：', len(final_user_id))
        for i in range(num):
            final_parent_asin += [data[data['user_id'] == final_user_id[i]]['parent_asin'].values[0]] + str(data[data['user_id'] == final_user_id[i]]['history'].values[0]).split()
        final_parent_asin = list(set(final_parent_asin))
        final_user_id_2_index = {final_user_id[i]: i for i in range(len(final_user_id))}
        final_parent_asin_2_index = {final_parent_asin[i]: i for i in range(len(final_parent_asin))}
        index_2_final_user_id = {i: final_user_id[i] for i in range(len(final_user_id))}
        index_2_final_parent_asin = {i: final_parent_asin[i] for i in range(len(final_parent_asin))}
        print('物品数：', len(final_parent_asin))
        with open(os.path.join(save_dir, 'user_id_2_index.json'), 'w') as f:
            json.dump(final_user_id_2_index, f)
        with open(os.path.join(save_dir, 'parent_asin_2_index.json'), 'w') as f:
            json.dump(final_parent_asin_2_index, f)
        with open(os.path.join(save_dir, 'index_2_user_id.json'), 'w') as f:
            json.dump(index_2_final_user_id, f)
        with open(os.path.join(save_dir, 'index_2_parent_asin.json'), 'w') as f:
            json.dump(index_2_final_parent_asin, f)

    def generate_dense_test_mapping(self, save_dir, num_users=500, min_interactions=5, train_user=None):
        """
        筛选出具有密集交互模式的测试用户集合

        Args:
            save_dir: 保存映射文件的目录
            num_users: 需要筛选的用户数量
            min_interactions: 每个用户至少要有的交互次数
            train_user: 训练集中的用户ID映射，这些用户将被排除

        Returns:
            筛选出的用户ID列表
        """
        print(f"开始筛选密集交互的{num_users}个测试用户...")

        # 创建保存目录
        os.makedirs(save_dir, exist_ok=True)

        # 读取数据
        data = pd.read_csv(self.test_interaction_data_path)

        # 排除训练集中的用户
        if train_user is not None:
            data = data[~data['user_id'].isin(train_user.keys())]
            print(f"排除训练集用户后剩余{len(data['user_id'].unique())}个用户")

        # 1. 统计每个用户的总交互次数（当前交互 + 历史交互）
        user_interaction_count = {}
        for _, row in data.iterrows():
            user_id = row['user_id']
            history = str(row['history']).split()  # 确保历史是字符串并分割

            # 计算总交互次数：历史交互数量 + 当前交互(1)
            total_interactions = len(history) + 1

            if user_id in user_interaction_count:
                user_interaction_count[user_id] = max(user_interaction_count[user_id], total_interactions)
            else:
                user_interaction_count[user_id] = total_interactions

        # 2. 筛选出交互次数达到阈值的用户
        active_users = [user for user, count in user_interaction_count.items() if count >= min_interactions]
        print(f"交互次数>={min_interactions}的用户有{len(active_users)}个")

        if len(active_users) < num_users:
            print(f"警告: 符合条件的用户不足{num_users}个，将使用所有{len(active_users)}个符合条件的用户")
            num_users = len(active_users)

        # 3. 构建物品流行度映射，便于后续分析
        all_items = []
        for _, row in data.iterrows():
            if row['user_id'] in active_users:
                all_items.append(row['parent_asin'])  # 添加当前交互物品
                history = str(row['history']).split()
                all_items.extend(history)  # 添加历史交互物品

        item_popularity = {}
        for item in all_items:
            item_popularity[item] = item_popularity.get(item, 0) + 1

        # 4. 为每个用户计算交互的物品集合
        user_item_sets = {}
        for _, row in data.iterrows():
            user_id = row['user_id']
            if user_id in active_users:
                if user_id not in user_item_sets:
                    user_item_sets[user_id] = set()

                user_item_sets[user_id].add(row['parent_asin'])  # 添加当前交互

                # 添加历史交互
                history = str(row['history']).split()
                user_item_sets[user_id].update(history)

        # 5. 计算用户间的物品重叠度
        user_overlap_scores = {}
        for user in active_users:
            overlap_score = 0
            user_items = user_item_sets.get(user, set())

            for other_user in active_users:
                if other_user != user:
                    other_items = user_item_sets.get(other_user, set())
                    # 计算Jaccard相似度
                    intersection = len(user_items.intersection(other_items))
                    union = len(user_items.union(other_items))
                    if union > 0:
                        overlap_score += intersection / union

            # 平均重叠度
            user_overlap_scores[user] = overlap_score / (len(active_users) - 1) if len(active_users) > 1 else 0

        # 6. 根据重叠度排序用户
        sorted_users = sorted(user_overlap_scores.items(), key=lambda x: x[1], reverse=True)

        # 7. 使用改进的贪心算法选择用户
        selected_users = []
        all_selected_items = set()
        candidate_items_count = {}  # 统计每个物品被多少候选用户交互

        # 首先统计每个物品被多少活跃用户交互
        for user in active_users:
            for item in user_item_sets.get(user, set()):
                candidate_items_count[item] = candidate_items_count.get(item, 0) + 1

        # 选择一个种子用户，优先选择与高重叠度的用户
        seed_user = sorted_users[0][0]
        selected_users.append(seed_user)
        all_selected_items.update(user_item_sets[seed_user])

        # 定义一个函数计算用户的贡献度
        def calculate_contribution(user, selected_items, user_items):
            # 计算两个重要因素：
            # 1. 与已选物品的重叠度
            overlap_count = len(user_items.intersection(selected_items))
            # 2. 物品的流行度加权分数
            popularity_score = sum(item_popularity.get(item, 0) for item in user_items)
            # 3. 交互的物品数量
            item_count = len(user_items)

            # 结合这些因素计算贡献度
            if len(selected_items) == 0:
                overlap_ratio = 0
            else:
                overlap_ratio = overlap_count / len(selected_items)

            # 调整权重以控制密集度
            overlap_weight = 0.7  # 重叠度权重
            popularity_weight = 0.2  # 流行度权重
            count_weight = 0.1  # 物品数量权重

            return (overlap_ratio * overlap_weight +
                    (popularity_score / (item_count or 1)) * popularity_weight +
                    item_count * count_weight)

        # 贪心选择剩余用户
        remaining_users = [user for user, _ in sorted_users if user != seed_user]

        while len(selected_users) < num_users and remaining_users:
            max_contribution = -1
            best_user = None

            for user in remaining_users:
                user_items = user_item_sets.get(user, set())
                contribution = calculate_contribution(user, all_selected_items, user_items)

                if contribution > max_contribution:
                    max_contribution = contribution
                    best_user = user

            if best_user:
                selected_users.append(best_user)
                all_selected_items.update(user_item_sets[best_user])
                remaining_users.remove(best_user)
            else:
                break

        print(f"成功选择{len(selected_users)}个用户")

        # 8. 获取这些用户交互的所有物品
        final_parent_asin = set()
        user_item_count = {}  # 记录每个用户交互的物品数量

        for user in selected_users:
            user_items = set()
            user_rows = data[data['user_id'] == user]

            for _, row in user_rows.iterrows():
                # 添加当前交互物品
                current_item = row['parent_asin']
                final_parent_asin.add(current_item)
                user_items.add(current_item)

                # 添加历史交互物品
                history = str(row['history']).split()
                final_parent_asin.update(history)
                user_items.update(history)

            user_item_count[user] = len(user_items)

        final_parent_asin = list(final_parent_asin)

        # 9. 创建并保存映射
        final_user_id_2_index = {selected_users[i]: i for i in range(len(selected_users))}
        final_parent_asin_2_index = {final_parent_asin[i]: i for i in range(len(final_parent_asin))}
        index_2_final_user_id = {i: selected_users[i] for i in range(len(selected_users))}
        index_2_final_parent_asin = {i: final_parent_asin[i] for i in range(len(final_parent_asin))}

        # 10. 计算统计信息
        total_interactions = sum(user_item_count.values())
        distinct_interactions = len(final_parent_asin)
        density = total_interactions / (len(selected_users) * distinct_interactions) if distinct_interactions > 0 else 0

        print("\n=== 数据集统计 ===")
        print(f"用户数: {len(selected_users)}")
        print(f"物品数: {distinct_interactions}")
        print(f"总交互数: {total_interactions}")
        print(f"交互密度: {density:.4f}")
        print(f"平均每用户交互物品数: {total_interactions/len(selected_users):.2f}")
        print(f"物品被交互的平均次数: {total_interactions/distinct_interactions:.2f}")

        # 绘制交互分布直方图
        try:
            import matplotlib.pyplot as plt
            plt.figure(figsize=(10, 6))
            plt.hist(list(user_item_count.values()), bins=20)
            plt.title('用户交互数分布')
            plt.xlabel('交互数')
            plt.ylabel('用户数')
            plt.savefig(os.path.join(save_dir, 'user_interaction_distribution.png'))
            plt.close()
        except:
            print("无法生成分布图")

        # 11. 保存映射文件
        with open(os.path.join(save_dir, 'user_id_2_index.json'), 'w') as f:
            json.dump(final_user_id_2_index, f)
        with open(os.path.join(save_dir, 'parent_asin_2_index.json'), 'w') as f:
            json.dump(final_parent_asin_2_index, f)
        with open(os.path.join(save_dir, 'index_2_user_id.json'), 'w') as f:
            json.dump(index_2_final_user_id, f)
        with open(os.path.join(save_dir, 'index_2_parent_asin.json'), 'w') as f:
            json.dump(index_2_final_parent_asin, f)

        return selected_users

    def print_jsonl_data_keys(self, path):
        """
        Print the keys of jsonl data
        :param path: str, the path of the jsonl file
        """
        with open(path, 'r') as f:
            for line in f:
                data = json.loads(line)
                print(list(data.keys()))
                break

    def transform_meta_jsonl_to_parquet(self, mapping_dir, save_dir):
        final_parent_asin_2_index = json.load(open(os.path.join(mapping_dir, 'parent_asin_2_index.json'), 'r'))

        data = []
        final_parent_asin = final_parent_asin_2_index.keys()
        # keep the columns we need
        # all: ['main_category', 'title', 'average_rating', 'rating_number', 'features', 'description', 'price', 'images', 'videos', 'store', 'categories', 'details', 'parent_asin', 'bought_together']
        # needs: ['main_category', 'title', 'average_rating', 'rating_number', 'price', 'images', 'store', 'categories', 'details', 'parent_asin']
        with open(self.raw_meta_data_path, 'r') as f:
            for line in f:
                if json.loads(line)['parent_asin'] in final_parent_asin:
                    data.append(json.loads(line))
        data = pd.DataFrame(data)
        data = data[['main_category', 'title', 'average_rating', 'rating_number', 'price', 'images', 'store', 'categories', 'details', 'parent_asin']]
        # keep the data whose parent_asin is in the interaction data
        data = data[data['parent_asin'].isin(final_parent_asin)]
        data['parent_asin'] = data['parent_asin'].apply(lambda x: final_parent_asin_2_index[x])
        data.to_parquet(os.path.join(save_dir, 'meta_data.parquet'))

    def transform_review_jsonl_to_parquet(self, mapping_dir, save_dir):
        final_user_id_2_index = json.load(open(os.path.join(mapping_dir, 'user_id_2_index.json'), 'r'))
        final_parent_asin_2_index = json.load(open(os.path.join(mapping_dir, 'parent_asin_2_index.json'), 'r'))
        data = []
        user_id = set(final_user_id_2_index.keys())
        parent_asin = set(final_parent_asin_2_index.keys())
        # keep the columns we need
        # all: ['rating', 'title', 'text', 'images', 'asin', 'parent_asin', 'user_id', 'timestamp', 'helpful_vote', 'verified_purchase']
        # needs: ['parent_asin', 'user_id', 'text', 'timestamp', 'rating']
        with open(self.raw_review_data_path, 'r') as f:
            for line in f:
                if json.loads(line)['parent_asin'] in parent_asin and json.loads(line)['user_id'] in user_id:
                    data.append(json.loads(line))
        data = pd.DataFrame(data)
        data = data[['parent_asin', 'user_id', 'text', 'timestamp', 'rating']]
        data['parent_asin'] = data['parent_asin'].apply(lambda x: final_parent_asin_2_index[x])
        data['user_id'] = data['user_id'].apply(lambda x: final_user_id_2_index[x] if x in final_user_id_2_index else -1)
        data.to_parquet(os.path.join(save_dir, 'review_data.parquet'))

    def generate_test_data(self, save_dir):
        final_user_id_2_index = json.load(open(os.path.join(save_dir, 'user_id_2_index.json'), 'r'))
        final_parent_asin_2_index = json.load(open(os.path.join(save_dir, 'parent_asin_2_index.json'), 'r'))
        data = pd.read_csv(self.test_interaction_data_path)
        # keep the data whose user_id is in the 100 users
        data = data[data['user_id'].isin(final_user_id_2_index.keys())]
        data['user_id'] = data['user_id'].apply(lambda x: final_user_id_2_index[x])
        data['parent_asin'] = data['parent_asin'].apply(lambda x: final_parent_asin_2_index[x])
        data['history'] = data['history'].apply(lambda x: [final_parent_asin_2_index[i] for i in x.split()])
        # data.to_parquet(os.path.join(save_dir, 'test_data.parquet'))
        data.to_parquet('baby_data/test_data.parquet')

    def generate_dense_test_mapping_fast(self, save_dir, num_users=500, min_interactions=5, train_user=None, sampling_ratio=0.2, max_items=2000, min_item_users=5):
        """
        使用更高效的方法筛选出具有密集交互模式的测试用户集合，并限制物品数量

        Args:
            save_dir: 保存映射文件的目录
            num_users: 需要筛选的用户数量
            min_interactions: 每个用户至少要有的交互次数
            train_user: 训练集中的用户ID映射，这些用户将被排除
            sampling_ratio: 如果用户数量巨大，使用此比例进行抽样处理
            max_items: 最大物品数量限制
            min_item_users: 物品至少被多少用户交互才保留

        Returns:
            筛选出的用户ID列表
        """
        import numpy as np
        from collections import Counter
        from sklearn.cluster import MiniBatchKMeans
        from sklearn.feature_extraction.text import TfidfTransformer

        print(f"开始高效筛选密集交互的{num_users}个测试用户...")

        # 创建保存目录
        os.makedirs(save_dir, exist_ok=True)

        # 读取数据
        data = pd.read_csv(self.test_interaction_data_path)
        total_users = len(data['user_id'].unique())

        # 排除训练集中的用户
        if train_user is not None:
            data = data[~data['user_id'].isin(train_user.keys())]
            print(f"排除训练集用户后剩余{len(data['user_id'].unique())}/{total_users}个用户")

        # 如果用户数量太多，考虑随机抽样
        if len(data['user_id'].unique()) > 10000 and sampling_ratio < 1.0:
            sample_users = np.random.choice(
                data['user_id'].unique(),
                size=int(len(data['user_id'].unique()) * sampling_ratio),
                replace=False
            )
            data = data[data['user_id'].isin(sample_users)]
            print(f"随机抽样{sampling_ratio:.1%}的用户，现在有{len(data['user_id'].unique())}个用户")

        # 1. 创建用户-物品交互字典和物品流行度计数器
        user_items = {}
        item_users = Counter()  # 记录每个物品被多少用户交互
        item_count = Counter()  # 记录每个物品的总交互次数

        # 一次性处理所有行，提取交互信息
        print("处理用户交互数据...")
        for _, row in data.iterrows():
            user_id = row['user_id']
            if user_id not in user_items:
                user_items[user_id] = set()

            # 添加当前交互物品
            current_item = row['parent_asin']
            user_items[user_id].add(current_item)
            item_count[current_item] += 1

            # 添加历史交互物品
            history = str(row['history']).split()
            for item in history:
                user_items[user_id].add(item)
                item_count[item] += 1

        # 更新物品被多少用户交互的计数
        for user, items in user_items.items():
            for item in items:
                item_users[item] += 1

        # 2. 筛选出交互次数达到阈值的用户
        active_users = [user for user, items in user_items.items() if len(items) >= min_interactions]
        print(f"交互次数>={min_interactions}的用户有{len(active_users)}个")

        if len(active_users) < num_users:
            print(f"警告: 符合条件的用户不足{num_users}个，将使用所有{len(active_users)}个符合条件的用户")
            num_users = len(active_users)

        # 3. 筛选物品 - 保留被足够多用户交互且流行度高的物品
        print(f"原始物品总数: {len(item_users)}")

        # 先筛选掉被交互用户数少于阈值的物品
        popular_items = {item: count for item, count in item_users.items() if count >= min_item_users}
        print(f"被至少{min_item_users}个用户交互的物品数: {len(popular_items)}")

        # 如果物品仍然过多，按流行度选择top物品
        if len(popular_items) > max_items:
            # 按交互用户数排序
            sorted_items = sorted(popular_items.items(), key=lambda x: (x[1], item_count[x[0]]), reverse=True)
            popular_items = {item: count for item, count in sorted_items[:max_items]}
            print(f"限制为流行度最高的{max_items}个物品")

        # 4. 更新用户的物品集合，只保留筛选后的物品
        filtered_user_items = {}
        for user, items in user_items.items():
            filtered_items = {item for item in items if item in popular_items}
            if len(filtered_items) >= min_interactions:  # 确保用户仍有足够的交互
                filtered_user_items[user] = filtered_items

        # 更新活跃用户列表
        active_users = list(filtered_user_items.keys())
        print(f"筛选物品后，仍有{len(active_users)}个用户有足够交互")

        if len(active_users) < num_users:
            print(f"警告: 筛选物品后，符合条件的用户不足{num_users}个")
            num_users = len(active_users)

        # 5. 映射物品ID到数字索引
        item_id_map = {}
        next_item_id = 0
        for items in filtered_user_items.values():
            for item_id in items:
                if item_id not in item_id_map:
                    item_id_map[item_id] = next_item_id
                    next_item_id += 1

        # 6. 构建用户-物品矩阵 (稀疏表示)
        print("构建用户-物品矩阵...")
        user_to_idx = {user: i for i, user in enumerate(active_users)}

        # 创建稀疏矩阵的行、列和值列表
        rows, cols, data_values = [], [], []

        for user, items in filtered_user_items.items():
            if user in user_to_idx:  # 只考虑活跃用户
                u_idx = user_to_idx[user]
                for item in items:
                    i_idx = item_id_map[item]
                    rows.append(u_idx)
                    cols.append(i_idx)
                    data_values.append(1)  # 二元交互

        # 创建稀疏矩阵
        from scipy.sparse import csr_matrix
        user_item_matrix = csr_matrix((data_values, (rows, cols)),
                                      shape=(len(active_users), len(item_id_map)))

        # 7. 使用TF-IDF变换矩阵，突出重要物品
        print("应用TF-IDF变换...")
        transformer = TfidfTransformer()
        user_item_tfidf = transformer.fit_transform(user_item_matrix)

        # 8. 使用MiniBatchKMeans进行快速聚类
        print("聚类用户...")
        # 尝试将用户聚为多个类别，选择一些最大的簇
        n_clusters = min(int(num_users / 5), len(active_users) // 10)
        n_clusters = max(n_clusters, 5)  # 至少5个簇

        kmeans = MiniBatchKMeans(n_clusters=n_clusters, batch_size=1000, random_state=42)
        user_clusters = kmeans.fit_predict(user_item_tfidf)

        # 9. 计算每个簇的大小
        cluster_sizes = Counter(user_clusters)

        # 10. 从每个簇中选择具有典型交互模式的用户
        print("从聚类中选择用户...")
        selected_users = []

        # 按簇大小降序排列
        sorted_clusters = sorted(cluster_sizes.items(), key=lambda x: x[1], reverse=True)

        # 从每个簇中选择用户数量与簇大小成比例
        remaining = num_users
        for cluster_id, cluster_size in sorted_clusters:
            # 确定从该簇中选择的用户数
            to_select = max(1, int(remaining * (cluster_size / sum([s for _, s in sorted_clusters]))))
            to_select = min(to_select, cluster_size, remaining)

            # 获取该簇的所有用户
            cluster_users = [active_users[i] for i, c in enumerate(user_clusters) if c == cluster_id]

            # 选择该簇中交互物品数量最多的用户
            users_with_counts = [(user, len(filtered_user_items[user])) for user in cluster_users]
            users_with_counts.sort(key=lambda x: x[1], reverse=True)

            # 选择前to_select个用户
            selected_from_cluster = [user for user, _ in users_with_counts[:to_select]]
            selected_users.extend(selected_from_cluster)

            remaining -= len(selected_from_cluster)
            if remaining <= 0:
                break

        # 如果还需要更多用户，从剩余的活跃用户中选择
        if len(selected_users) < num_users:
            remaining_users = [u for u in active_users if u not in selected_users]
            # 按交互物品数量排序
            remaining_with_counts = [(user, len(filtered_user_items[user])) for user in remaining_users]
            remaining_with_counts.sort(key=lambda x: x[1], reverse=True)

            additional = min(num_users - len(selected_users), len(remaining_users))
            selected_users.extend([user for user, _ in remaining_with_counts[:additional]])

        print(f"成功选择{len(selected_users)}个用户")

        # 11. 获取这些用户交互的所有物品
        final_parent_asin = set()
        user_item_count = {}  # 记录每个用户交互的物品数量

        for user in selected_users:
            user_items_set = filtered_user_items[user]
            final_parent_asin.update(user_items_set)
            user_item_count[user] = len(user_items_set)

        final_parent_asin = list(final_parent_asin)

        # 12. 创建并保存映射
        final_user_id_2_index = {selected_users[i]: i for i in range(len(selected_users))}
        final_parent_asin_2_index = {final_parent_asin[i]: i for i in range(len(final_parent_asin))}
        index_2_final_user_id = {i: selected_users[i] for i in range(len(selected_users))}
        index_2_final_parent_asin = {i: final_parent_asin[i] for i in range(len(final_parent_asin))}

        # 13. 计算统计信息
        total_interactions = sum(user_item_count.values())
        distinct_interactions = len(final_parent_asin)
        density = total_interactions / (len(selected_users) * distinct_interactions) if distinct_interactions > 0 else 0

        print("\n=== 数据集统计 ===")
        print(f"用户数: {len(selected_users)}")
        print(f"物品数: {distinct_interactions}")
        print(f"总交互数: {total_interactions}")
        print(f"交互密度: {density:.4f}")
        print(f"平均每用户交互物品数: {total_interactions/len(selected_users):.2f}")
        print(f"物品被交互的平均次数: {total_interactions/distinct_interactions:.2f}")

        # 13. 保存映射文件
        with open(os.path.join(save_dir, 'user_id_2_index.json'), 'w') as f:
            json.dump(final_user_id_2_index, f)
        with open(os.path.join(save_dir, 'parent_asin_2_index.json'), 'w') as f:
            json.dump(final_parent_asin_2_index, f)
        with open(os.path.join(save_dir, 'index_2_user_id.json'), 'w') as f:
            json.dump(index_2_final_user_id, f)
        with open(os.path.join(save_dir, 'index_2_parent_asin.json'), 'w') as f:
            json.dump(index_2_final_parent_asin, f)

        return selected_users

    def generate_test_data_new(self, save_dir):
        """
        从交互数据中提取测试集，确保所有映射中的用户都被包含，
        如果用户当前交互的物品(parent_asin)不在映射中，则将其添加到映射中
        
        Args:
            save_dir: 保存目录，应包含user_id_2_index.json和parent_asin_2_index.json文件
        """
        # 加载映射文件
        final_user_id_2_index = json.load(open(os.path.join(save_dir, 'user_id_2_index.json'), 'r'))
        final_parent_asin_2_index = json.load(open(os.path.join(save_dir, 'parent_asin_2_index.json'), 'r'))
        index_2_final_parent_asin = json.load(open(os.path.join(save_dir, 'index_2_parent_asin.json'), 'r'))
        
        # 输出一共多少个用户和物品
        print(f"原始映射中的用户数量: {len(set(final_user_id_2_index.keys()))}")
        print(f"原始映射中的物品数量: {len(set(final_parent_asin_2_index.keys()))}")
        
        # 创建物品ID集合，用于快速检查
        valid_items = set(final_parent_asin_2_index.keys())
        
        # 读取交互数据
        data = pd.read_csv(self.test_interaction_data_path)
        
        # 只保留映射中的用户
        data = data[data['user_id'].isin(final_user_id_2_index.keys())]
        print(f"映射中的用户交互数据: {len(data)} 条记录")
        
        # 检查是否有用户的当前交互物品不在映射中
        missing_items = set()
        for _, row in data.iterrows():
            if row['parent_asin'] not in valid_items:
                missing_items.add(row['parent_asin'])
        
        # 将缺失的物品添加到映射中
        if missing_items:
            print(f"发现 {len(missing_items)} 个用户当前交互物品不在映射中，添加这些物品")
            
            # 获取当前最大索引
            max_idx = max([int(idx) for idx in index_2_final_parent_asin.keys()])
            
            # 为缺失物品分配新索引
            new_idx = max_idx + 1
            for item in missing_items:
                final_parent_asin_2_index[item] = new_idx
                index_2_final_parent_asin[str(new_idx)] = item  # 注意这里的键需要是字符串
                new_idx += 1
            
            # 更新有效物品集合
            valid_items = set(final_parent_asin_2_index.keys())
            
            # 保存更新后的映射
            with open(os.path.join(save_dir, 'parent_asin_2_index.json'), 'w') as f:
                json.dump(final_parent_asin_2_index, f)
            with open(os.path.join(save_dir, 'index_2_parent_asin.json'), 'w') as f:
                json.dump(index_2_final_parent_asin, f)
            
            print(f"更新后的物品数量: {len(valid_items)}")
        
        # 将user_id映射到索引
        data['user_id'] = data['user_id'].apply(lambda x: final_user_id_2_index[x])
        
        # 将parent_asin映射到索引
        data['parent_asin'] = data['parent_asin'].apply(lambda x: final_parent_asin_2_index[x])
        
        # 处理历史记录，过滤掉不在映射中的物品
        def filter_history(history_str):
            history_items = str(history_str).split()
            filtered_history = [final_parent_asin_2_index[item] for item in history_items 
                            if item in valid_items]
            return filtered_history
        
        # 应用历史记录过滤
        data['history'] = data['history'].apply(filter_history)
        
        # 检查处理后的数据
        print(f"处理后的测试集大小: {len(data)} 条记录")
        print(f"包含的用户数量: {data['user_id'].nunique()}")
        print(f"包含的物品数量: {data['parent_asin'].nunique()}")
        
        # 计算历史记录中的物品数量
        history_items = set()
        for hist in data['history']:
            history_items.update(hist)
        print(f"历史记录中的物品索引数量: {len(history_items)}")
        
        # 保存处理后的数据
        output_path = os.path.join(save_dir, 'test_data.parquet')
        data.to_parquet(output_path)
        print(f"测试数据已保存到: {output_path}")
        
        # return data

# Dataset selection: exactly one configuration is active at a time.
# Alternative 1 (disabled) - Baby_Products:
# raw_data_processor_to_100user = Raw_Data_Processor_to_100user(
#     raw_meta_data_path='raw_data/meta_Baby_Products.jsonl',
#     raw_review_data_path='raw_data/Baby_Products.jsonl',
#     test_interaction_data_path='raw_data/Baby_Products.test.csv'
# )
# Alternative 2 (disabled) - Video_Games:
# raw_data_processor_to_100user = Raw_Data_Processor_to_100user(
#     raw_meta_data_path='raw_data/meta_Video_Games.jsonl',
#     raw_review_data_path='raw_data/Video_Games.jsonl',
#     test_interaction_data_path='raw_data/Video_Games.test.csv'
# )
# Active configuration - CDs_and_Vinyl:
raw_data_processor_to_100user = Raw_Data_Processor_to_100user(
    raw_meta_data_path='raw_data/meta_CDs_and_Vinyl.jsonl',
    raw_review_data_path='raw_data/CDs_and_Vinyl.jsonl',
    test_interaction_data_path='raw_data/CDs_and_Vinyl.test.csv'
)

In [None]:
# Make sure both output directories exist so the cell also runs on a fresh checkout.
os.makedirs('cd_data/train', exist_ok=True)
os.makedirs('cd_data/test', exist_ok=True)

# Sample the training users first; the test users are then drawn from the
# users that are NOT part of the training mapping.
raw_data_processor_to_100user.generate_mapping('cd_data/train')
with open('cd_data/train/user_id_2_index.json', 'r') as f:
    train_user_id_2_index = json.load(f)
raw_data_processor_to_100user.generate_test_mapping('cd_data/test', 500, train_user_id_2_index)

In [None]:
# Show the record keys of the metadata file the active processor is configured
# with, so the cell does not depend on the files of another dataset.
raw_data_processor_to_100user.print_jsonl_data_keys(raw_data_processor_to_100user.raw_meta_data_path)

In [None]:
# Build the metadata, review and interaction tables of the test split from the
# mapping files stored in cd_data/test (the mapping cell must have run before).
raw_data_processor_to_100user.transform_meta_jsonl_to_parquet('cd_data/test', 'cd_data/test')
raw_data_processor_to_100user.transform_review_jsonl_to_parquet('cd_data/test', 'cd_data/test')
raw_data_processor_to_100user.generate_test_data('cd_data/test')