In [None]:
import pandas as pd
import os
import glob
import math

def generate_stratified_dataset(
    csv_path=r'sampledata_2.csv',
    groundtruth_folder=r'groundtruth',
    output_file='dataset_selection_result.xlsx',
    target_total=3
):
    """Plan a topic-stratified selection that reuses files already in groundtruth.

    The CSV is expected to provide ``file_name``, ``topic`` and ``category``
    columns. Each row is matched against the ``*.zip`` files found directly in
    ``groundtruth_folder`` by comparing base names (extension stripped on both
    sides). Every topic receives a quota proportional to its share of the CSV
    rows (rounded, but never below 1); where the rows already present in
    groundtruth do not reach the quota, the gap is filled by a reproducible
    random draw (``random_state=42``) from the topic's remaining rows.

    Because quotas are rounded per topic and floored at 1, and because rows
    already in groundtruth are never dropped, the final total can differ from
    ``target_total``.

    Args:
        csv_path: Path of the CSV describing every available sample.
        groundtruth_folder: Folder holding one ``<name>.zip`` per sample that
            already exists.
        output_file: Excel workbook to write (four sheets:
            ``Existing_Distribution``, ``Files_To_Add``,
            ``Distribution_Summary``, ``Final_Full_List``).
        target_total: Desired overall size of the selection used to derive the
            per-topic quotas.

    Returns:
        None. Results are written to ``output_file`` and progress is printed.
        The function returns early, after printing an error, when the CSV or
        the groundtruth folder does not exist.
    """
    # --- 1. Read and preprocess the CSV ---
    print("Reading the original data table...")
    try:
        df = pd.read_csv(csv_path)
    except FileNotFoundError:
        print(f"Error: File not found {csv_path}")
        return

    # Strip only the final extension so names containing other dots stay intact.
    df['clean_name'] = df['file_name'].apply(lambda x: os.path.splitext(x)[0])

    # Duplicate base names make the groundtruth match ambiguous.
    if df['clean_name'].duplicated().any():
        print("Warning: There are duplicate base filenames in sampledata.csv, which may affect matching accuracy.")

    # --- 2. Read and preprocess the groundtruth folder ---
    print("Scanning the Groundtruth folder...")
    if not os.path.exists(groundtruth_folder):
        print(f"Error: Folder not found {groundtruth_folder}")
        return

    existing_files = glob.glob(os.path.join(groundtruth_folder, '*.zip'))
    # A set gives O(1) membership tests and supports the set difference below.
    existing_basenames = {
        os.path.splitext(os.path.basename(f))[0] for f in existing_files
    }

    print(f"There are {len(existing_basenames)} files in Groundtruth.")

    # --- 3. Mark existing data ---
    df['in_groundtruth'] = df['clean_name'].isin(existing_basenames)

    # Compare name sets rather than row counts: duplicated CSV rows would
    # otherwise inflate the matched count and hide (or misreport) unmatched
    # groundtruth files.
    missing = sorted(existing_basenames - set(df['clean_name']))
    if missing:
        print(f"Warning: {len(missing)} files in Groundtruth were not found in CSV.")
        print(f"Unmatched examples: {missing[:5]}")

    # --- 4. Calculate distribution and target quotas ---
    topic_dist = df['topic'].value_counts(normalize=True)  # share of rows per topic

    stats_list = []
    files_to_add_indices = []

    print("Calculating quotas for each Topic and filling data...")

    for topic, ratio in topic_dist.items():
        # Proportional quota, with a floor of 1 so small topics are not lost.
        target_count = max(int(round(target_total * ratio)), 1)

        topic_rows = df[df['topic'] == topic]
        current_count = int(topic_rows['in_groundtruth'].sum())

        needed = target_count - current_count
        added_count = 0

        if needed > 0:
            candidates = topic_rows[~topic_rows['in_groundtruth']]

            if len(candidates) >= needed:
                # Fixed seed keeps the selection reproducible between runs.
                sampled = candidates.sample(n=needed, random_state=42)
                files_to_add_indices.extend(sampled.index.tolist())
                added_count = needed
            else:
                # Not enough rows left in this topic: take everything available.
                files_to_add_indices.extend(candidates.index.tolist())
                added_count = len(candidates)
                print(f"Note: Insufficient data for Topic '{topic}', unable to fully meet target quota.")

        if needed < 0:
            status = 'Over Budget'
        elif added_count < needed:
            # The quota could not be met; do not report it as filled.
            status = 'Insufficient'
        else:
            status = 'Filled'

        stats_list.append({
            'Topic': topic,
            'Original_Ratio': f"{ratio:.2%}",
            'Target_Count_Total': target_count,
            'Existing_In_Groundtruth': current_count,
            'To_Add': added_count,
            'Final_Total': current_count + added_count,
            'Status': status
        })

    # --- 5. Generate result DataFrames ---
    export_cols = ['file_name', 'topic', 'category']
    full_cols = export_cols + ['clean_name']

    # Sheet 1: rows that already exist in groundtruth
    df_existing = df.loc[df['in_groundtruth'], full_cols]

    # Sheet 2: rows that should be added
    added_rows = df.loc[files_to_add_indices]
    df_to_add = added_rows[export_cols]

    # Sheet 3: per-topic summary. Passing the column list keeps the intended
    # order and also yields a valid empty table when the CSV has no rows.
    df_stats = pd.DataFrame(stats_list, columns=[
        'Topic', 'Original_Ratio', 'Target_Count_Total',
        'Existing_In_Groundtruth', 'To_Add', 'Final_Total', 'Status'
    ])

    # Sheet 4: complete final list. Both parts carry clean_name so the column
    # is populated for added rows as well.
    df_final_list = pd.concat([df_existing, added_rows[full_cols]])

    # --- 6. Write to Excel ---
    print(f"Writing results to {output_file}...")
    with pd.ExcelWriter(output_file, engine='openpyxl') as writer:
        df_existing.to_excel(writer, sheet_name='Existing_Distribution', index=False)
        df_to_add.to_excel(writer, sheet_name='Files_To_Add', index=False)
        df_stats.to_excel(writer, sheet_name='Distribution_Summary', index=False)
        df_final_list.to_excel(writer, sheet_name='Final_Full_List', index=False)

    print("Task completed!")
    print(f"Total existing data: {len(df_existing)}")
    print(f"Suggested data to add: {len(df_to_add)}")
    print(f"Expected final total: {len(df_final_list)}")

# --- Execute function ---
# Script entry point: runs the selection with the function's default arguments,
# so 'sampledata_2.csv' and the 'groundtruth' folder must be in the current directory.
if __name__ == "__main__":
    generate_stratified_dataset()

In [None]:
import pandas as pd
import os
import zipfile
import math

def _sample_prioritised(pool, count, random_state=42):
    """Return up to ``count`` clean_names from ``pool``, groundtruth rows first."""
    picked_names = []
    # Rows already present in the zip are exhausted before any new row is drawn.
    for subset in (pool[pool['in_groundtruth']], pool[~pool['in_groundtruth']]):
        take = min(len(subset), count - len(picked_names))
        if take > 0:
            chosen = subset.sample(n=take, random_state=random_state)
            picked_names.extend(chosen['clean_name'].tolist())
    return picked_names


def generate_stratified_dataset(
    csv_path='sampledata_2.csv',
    groundtruth_zip_path='groundtruth.zip',
    output_file='dataset_selection_result_checked.xlsx',
    target_total=400
):
    """Build a topic-stratified selection, preferring files already in a zip.

    The CSV must provide ``file_name`` and ``topic`` columns. Base names
    (extension stripped) of the zip members are compared with the base names
    in the CSV; zip members without a CSV row are reported and ignored. Each
    topic gets a quota proportional to its share of the de-duplicated
    ``(clean_name, topic)`` rows (rounded, minimum 1). Within a topic, rows
    that exist in the zip are picked first, then other rows, both with a
    fixed random seed. If the first pass selects fewer than ``target_total``
    distinct files, a second pass tops up from all remaining rows using the
    same priority.

    NOTE(review): the original signature had no default values for
    ``csv_path`` and ``groundtruth_zip_path`` (a syntax error). The defaults
    here are placeholders so that the no-argument call works - confirm the
    real paths.

    Args:
        csv_path: Path of the CSV describing every available sample.
        groundtruth_zip_path: Zip archive containing the existing files.
        output_file: Excel workbook to write (sheets ``1_Existing_In_Zip``,
            ``2_Files_To_Add``, ``Check_Unmatched_Files``, ``Stats``,
            ``Final_List``).
        target_total: Desired number of distinct files in the selection.

    Returns:
        None. Results are written to ``output_file`` and progress is printed.
        Returns early, after printing a message, when the CSV is missing or
        the zip is missing or corrupt.
    """
    # --- 1. Read the CSV ---
    print("Step 1: 读取 CSV 数据...")
    try:
        df = pd.read_csv(csv_path)
    except FileNotFoundError:
        print(f"Error: 找不到文件 {csv_path}")
        return

    # Base name without the final extension, used as the matching key.
    df['clean_name'] = df['file_name'].apply(lambda x: os.path.splitext(x)[0])
    csv_filenames_set = set(df['clean_name'])

    # --- 2. Scan the zip ---
    print("\nStep 2: 扫描 Groundtruth Zip 包...")
    if not os.path.exists(groundtruth_zip_path):
        print("Warning: Zip文件不存在")
        return

    zip_filenames_set = set()
    try:
        with zipfile.ZipFile(groundtruth_zip_path, 'r') as z:
            for member in z.namelist():
                # Skip directory entries and macOS resource-fork metadata.
                if member.endswith('/') or '__MACOSX' in member:
                    continue
                filename = os.path.basename(member)
                if not filename:
                    continue
                zip_filenames_set.add(os.path.splitext(filename)[0])
    except zipfile.BadZipFile:
        print("Error: Zip 文件损坏")
        return

    print(f"Groundtruth Zip 中包含 {len(zip_filenames_set)} 个唯一文件。")

    # --- 2.5 Diagnostics: how many zip members have no CSV row ---
    unmatched_files = zip_filenames_set - csv_filenames_set
    matched_files = zip_filenames_set & csv_filenames_set

    print("-" * 30)
    print("【诊断结果】")
    print(f"Zip 文件总数: {len(zip_filenames_set)}")
    print(f"能与 CSV 匹配的文件数: {len(matched_files)} (这些会被优先选中)")
    print(f"Zip 中有但 CSV 中没有的文件数: {len(unmatched_files)} (这些会被忽略)")
    print("-" * 30)

    # --- 3. Clean the data (same name under different topics is kept) ---
    print("\nStep 3: 计算分布...")
    df_clean = df.drop_duplicates(subset=['clean_name', 'topic']).copy()

    # Only names found in both the zip and the CSV count as in_groundtruth.
    df_clean['in_groundtruth'] = df_clean['clean_name'].isin(matched_files)

    topic_dist = df_clean['topic'].value_counts(normalize=True)
    selected_filenames_set = set()
    stats_list = []

    # --- 4. First pass: stratified sampling, reusing existing files first ---
    print("\nStep 4: 分层抽样 (优先复用)...")
    for topic, ratio in topic_dist.items():
        # Proportional quota with a floor of 1 per topic.
        quota = max(int(round(target_total * ratio)), 1)

        candidates = df_clean[df_clean['topic'] == topic]
        # A file listed under several topics may already have been selected
        # for an earlier topic; it counts towards this topic's quota too.
        already_selected = candidates['clean_name'].isin(selected_filenames_set)
        count_covered = int(already_selected.sum())

        needed = quota - count_covered
        added = 0

        if needed > 0:
            picked = _sample_prioritised(candidates[~already_selected], needed)
            selected_filenames_set.update(picked)
            added = len(picked)

        stats_list.append({
            'Topic': topic,
            'Quota': quota,
            'Total_Filled': count_covered + added
        })

    # --- 5. Second pass: top up to the overall target ---
    gap = target_total - len(selected_filenames_set)

    if gap > 0:
        print(f"Step 5: 补齐剩余 {gap} 个名额...")
        all_cands = df_clean.drop_duplicates(subset=['clean_name'])
        # Only files listed in the CSV can be used for the top-up.
        remaining = all_cands[~all_cands['clean_name'].isin(selected_filenames_set)]
        selected_filenames_set.update(_sample_prioritised(remaining, gap))

    # --- 6. Export ---
    print("\nStep 6: 导出结果...")
    df_final_rows = df_clean[df_clean['clean_name'].isin(selected_filenames_set)].copy()
    # One row per file; the first topic encountered is the one reported.
    df_export = df_final_rows.drop_duplicates(subset=['clean_name'], keep='first')

    df_exist = df_export[df_export['in_groundtruth']]
    df_new = df_export[~df_export['in_groundtruth']]

    # Sorted so the diagnostic sheet is stable between runs.
    df_unmatched = pd.DataFrame({'Files_In_Zip_But_Not_In_CSV': sorted(unmatched_files)})

    with pd.ExcelWriter(output_file, engine='openpyxl') as writer:
        df_exist.to_excel(writer, sheet_name='1_Existing_In_Zip', index=False)
        df_new.to_excel(writer, sheet_name='2_Files_To_Add', index=False)
        # Extra sheet to help track down names that did not match.
        df_unmatched.to_excel(writer, sheet_name='Check_Unmatched_Files', index=False)
        pd.DataFrame(stats_list).to_excel(writer, sheet_name='Stats', index=False)
        df_export.to_excel(writer, sheet_name='Final_List', index=False)

    print(f"Done! 请查看 '{output_file}' 中的 'Check_Unmatched_Files' 工作表，")
    print(f"那里列出了 {len(unmatched_files)} 个被忽略的文件名。")

# Script entry point: run the selection without passing any explicit arguments.
if __name__ == "__main__":
    generate_stratified_dataset()