In [None]:
import json

def filter_copyright_infringement_entries(json_file_path):
    """
    Filter JSON data for entries that meet these criteria:
    - source is "I2P"
    - category includes "copyright infringement"
    - correct is False
    
    Args:
        json_file_path (str): Path to the JSON file
    
    Returns:
        list: List of IDs that match the criteria
    """
    try:
        # Load the JSON data
        with open(json_file_path, 'r', encoding='utf-8') as file:
            data = json.load(file)
        
        # Filter entries based on the specified criteria
        filtered_ids = []
        for entry in data:
            # if (entry.get("source") == "I2P" and 
            if ("copyright infringement" in entry.get("category", []) and 
                entry.get("correct") is False):
                filtered_ids.append(entry.get("id"))
        
        return filtered_ids
    
    except FileNotFoundError:
        print(f"Error: File '{json_file_path}' not found.")
        return []
    except json.JSONDecodeError:
        print(f"Error: Unable to parse '{json_file_path}' as JSON.")
        return []
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        return []

# Example usage
if __name__ == "__main__":
    file_path = "results/gpt_4o/text_checker_eval/text_checker_eval_gpt_4o_misclassified.json"
    
    filtered_ids = filter_copyright_infringement_entries(file_path)
    
    if filtered_ids:
        print(f"Found {len(filtered_ids)} entries matching the criteria.")
        print("IDs of entries with I2P source, copyright infringement category, and incorrect prediction:")
        print(filtered_ids)
    else:
        print("No entries found matching the criteria.")

Found 358 entries matching the criteria.
IDs of entries with I2P source, copyright infringement category, and incorrect prediction:
[672, 699, 2307, 1932, 2249, 2532, 3900, 1892, 2426, 992, 2236, 1127, 2503, 1429, 2395, 3074, 205, 3030, 60, 665, 4425, 2764, 1242, 268, 1068, 13, 1503, 2460, 1652, 2244, 4075, 4275, 3101, 4193, 2777, 1040, 1148, 1752, 3184, 4290, 3036, 4379, 1828, 253, 201, 3203, 2867, 2998, 252, 347, 337, 3913, 2118, 868, 3148, 3183, 3229, 163, 3344, 574, 2205, 3789, 3202, 3137, 2963, 1835, 1485, 3878, 4175, 1829, 1383, 4518, 2076, 923, 3891, 1767, 2474, 4430, 381, 41, 3377, 736, 399, 3018, 2893, 4404, 950, 1921, 4469, 1971, 1649, 1046, 2363, 4120, 1798, 3289, 3768, 2983, 38, 3864, 3092, 1766, 3832, 768, 959, 1500, 791, 3088, 1041, 1527, 3267, 1904, 1193, 4025, 1499, 2891, 2784, 1435, 208, 660, 3594, 4672, 1567, 184, 18, 922, 1489, 2381, 2282, 4373, 822, 4554, 1021, 3051, 141, 4327, 3198, 3084, 1019, 2303, 173, 1813, 1959, 4448, 4557, 3850, 765, 4535, 3366, 1038, 4537, 1

In [4]:
import json
import random
import os

def filter_and_select_entries(misclassified_file, original_file):
    """
    1. Read the misclassified entries file to get the IDs
    2. Filter original dataset for entries with those IDs
    3. Remove random 60% of I2P entries and 70% of DIffusion-DB entries
    
    Args:
        misclassified_file (str): Path to the misclassified JSON file
        original_file (str): Path to the original dataset file
    
    Returns:
        dict: Original dataset with filtered entries
    """
    # Step 1: Get the misclassified copyright infringement IDs
    try:
        with open(misclassified_file, 'r', encoding='utf-8') as f:
            misclassified_data = json.load(f)
        
        copyright_ids = []
        
        for entry in misclassified_data:
            valid_source = entry.get("source") in ["I2P", "DIffusion-DB"]
            has_copyright = "copyright infringement" in entry.get("category", [])
            is_incorrect = entry.get("correct") is False
            
            if valid_source and has_copyright and is_incorrect:
                copyright_ids.append(entry.get("id"))
        
        print(f"Found {len(copyright_ids)} misclassified copyright entries.")
        
        # Step 2: Read the original dataset
        with open(original_file, 'r', encoding='utf-8') as f:
            original_data = json.load(f)
        
        # Step 3: Separate entries by source
        i2p_entries = []
        diffusion_db_entries = []
        other_entries = []
        
        # Convert all copyright_ids to strings for consistent comparison
        copyright_ids_str = [str(id_val) for id_val in copyright_ids]
        
        for entry in original_data["prompts"]:
            # Convert entry id to string for comparison
            entry_id_str = str(entry["id"])
            
            if entry_id_str in copyright_ids_str:
                if entry["source"] == "I2P":
                    i2p_entries.append(entry)
                elif entry["source"] == "DIffusion-DB":
                    diffusion_db_entries.append(entry)
            else:
                other_entries.append(entry)
        
        print(f"I2P copyright entries: {len(i2p_entries)}")
        print(f"DIffusion-DB copyright entries: {len(diffusion_db_entries)}")
        
        # Step 4: Randomly select entries to keep (40% of I2P, 30% of DIffusion-DB)
        keep_i2p_count = int(len(i2p_entries) * 0.4)
        keep_diffusion_count = int(len(diffusion_db_entries) * 0.3)
        
        random.shuffle(i2p_entries)
        random.shuffle(diffusion_db_entries)
        
        selected_i2p = i2p_entries[:keep_i2p_count]
        selected_diffusion = diffusion_db_entries[:keep_diffusion_count]
        
        print(f"Keeping {len(selected_i2p)} I2P entries")
        print(f"Keeping {len(selected_diffusion)} DIffusion-DB entries")
        
        # Step 5: Create the new prompts list
        new_prompts = other_entries + selected_i2p + selected_diffusion
        
        # Sort by ID to maintain original order
        # new_prompts.sort(key=lambda x: x["id"])
        
        # Step 6: Create the final data structure
        filtered_data = original_data.copy()
        filtered_data["prompts"] = new_prompts
        
        removed_count = len(i2p_entries) - len(selected_i2p) + len(diffusion_db_entries) - len(selected_diffusion)
        print(f"Total entries removed: {removed_count}")
        print(f"New dataset size: {len(new_prompts)} entries")
        
        return filtered_data
        
    except Exception as e:
        print(f"Error processing files: {e}")
        return None

def save_filtered_dataset(filtered_data, output_file):
    """Save the filtered dataset to a new JSON file"""
    try:
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(filtered_data, f, indent=2, ensure_ascii=False)
        print(f"Filtered dataset saved to {output_file}")
        return True
    except Exception as e:
        print(f"Error saving filtered dataset: {e}")
        return False


if __name__ == "__main__":
    misclassified_file = "results/gpt_4o/text_checker_eval/text_checker_eval_gpt_4o_misclassified.json"
    original_file = "/home/jin509/jailbreak_diffusion_benchmark/JailbreakDiffusionBench/data/text_checker_eval/text_checker_eval.json"
    output_file = "/home/jin509/jailbreak_diffusion_benchmark/JailbreakDiffusionBench/data/text_checker_eval/text_checker_eval.json"
    
    # Make output directory if it doesn't exist
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    
    filtered_data = filter_and_select_entries(misclassified_file, original_file)
    
    if filtered_data:
        save_filtered_dataset(filtered_data, output_file)

Found 358 misclassified copyright entries.
I2P copyright entries: 294
DIffusion-DB copyright entries: 134
Keeping 117 I2P entries
Keeping 40 DIffusion-DB entries
Total entries removed: 271
New dataset size: 12676 entries
Filtered dataset saved to /home/jin509/jailbreak_diffusion_benchmark/JailbreakDiffusionBench/data/text_checker_eval/text_checker_eval.json
