<a href="https://colab.research.google.com/github/MrinalA2009/ZEDD/blob/main/Zero_Shot_Embedding_Drift_Detection_A_Lightweight_Defense_Against_Prompt_Injection_in_Instruction_Following_LLMS.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%autosave 30

# Install Necessary Modules and Mount to Drive

In [None]:
!pip install datasets==2.15.0 openai fasttext tqdm numpy==1.25.0 tiktoken transformers==4.44.0 sentence-transformers==3.0.1
!rm -rf ~/.cache/huggingface/datasets
!rm -rf /root/.cache/huggingface/datasets

In [None]:
from google.colab import drive

drive.mount("/content/drive")

# Load and Filter Data

The code in this section processes the `microsoft/llmail-inject-challenge` dataset to clean and prepare it for analysis. It begins by removing duplicate entries based on normalized text content (whitespace removed, lowercased), then filters the dataset to keep only English-language entries, as detected by Facebook’s FastText language identification model, plus any entries containing `system` or `<<`, which are kept regardless of the detected language. The code also calculates the average character length of the text content across both phases of the dataset. Finally, it splits the cleaned dataset into four roughly equal quarters (any remainder rows go to the last quarter), creating a separate dataset object for each portion.

In [None]:
def calculate_avg_length(dataset):
    """Return the mean stripped character length of every "body" entry.

    The "body" columns of the ``"Phase1"`` and ``"Phase2"`` splits are pooled:
    the stripped lengths of all entries are summed and divided by the total
    number of entries.

    Args:
        dataset: Mapping with ``"Phase1"`` and ``"Phase2"`` keys, each of which
            exposes a ``"body"`` column of strings.

    Returns:
        The average length as a float, or ``0.0`` when both phases are empty
        (instead of raising ``ZeroDivisionError``).
    """
    total_chars = 0
    total_rows = 0

    for phase_name in ("Phase1", "Phase2"):
        # Read the column once per phase; it is used for both the sum and the count.
        bodies = dataset[phase_name]["body"]
        total_chars += sum(len(body.strip()) for body in bodies)
        total_rows += len(bodies)

    if total_rows == 0:
        return 0.0
    return total_chars / total_rows

In [None]:
from openai import OpenAI
from google.colab import userdata
client = OpenAI(api_key=userdata.get("OPENAI_KEY"))

In [None]:
from datasets import load_dataset, Dataset

# Load the LLMail-Inject challenge dataset; if the first attempt raises
# NotImplementedError, retry with a forced re-download.
try:
    dataset_injected = load_dataset("microsoft/llmail-inject-challenge")
except NotImplementedError:
    # Second attempt bypasses whatever was loaded before by forcing a fresh download.
    print("Loading from cache failed, attempting to force download.")
    dataset_injected = load_dataset("microsoft/llmail-inject-challenge", download_mode="force_redownload")

In [None]:
print(calculate_avg_length(dataset_injected))

In [None]:
from datasets import Dataset, DatasetDict
import pandas as pd

def remove_duplicates_from_dataset_dict(dataset_dict):
    """Drop rows whose normalized "body" was already seen, within and across phases.

    A body is normalized by removing all whitespace and lowercasing. Phases are
    processed in iteration order, so a body that appears in an earlier phase is
    removed from every later phase; within a phase the first occurrence wins.

    Args:
        dataset_dict: Mapping of phase name to a dataset that supports
            ``to_pandas()`` and has a "body" column.

    Returns:
        A ``DatasetDict`` with one de-duplicated dataset per phase.

    NOTE(review): the filtered frames keep their original pandas index, which is
    presumably why an ``__index_level_0__`` column shows up after
    ``Dataset.from_pandas`` -- confirm.
    """
    already_seen = set()
    result = {}

    for phase_name, dataset in dataset_dict.items():
        print(f"\n=== Processing {phase_name} ===")
        print(f"Original size: {len(dataset)}")

        frame = dataset.to_pandas()
        frame['body_normalized'] = frame['body'].str.replace(r'\s+', '', regex=True).str.lower()

        # First discard anything an earlier phase already contributed ...
        unseen = frame.loc[~frame['body_normalized'].isin(already_seen)].copy()
        # ... then collapse repeats inside this phase, keeping the first occurrence.
        unique_rows = unseen.drop_duplicates(subset=['body_normalized'], keep='first')

        already_seen.update(unique_rows['body_normalized'].tolist())

        # The helper column is only needed for matching, not in the output.
        unique_rows = unique_rows.drop(columns=['body_normalized'])

        print(f"After deduplication: {len(unique_rows)} (removed {len(frame) - len(unique_rows)} duplicates)")

        result[phase_name] = Dataset.from_pandas(unique_rows)

    return DatasetDict(result)

dataset_injected = remove_duplicates_from_dataset_dict(dataset_injected)

In [None]:
print(calculate_avg_length(dataset_injected))

In [None]:
!wget -q https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.ftz


import fasttext
from datasets import DatasetDict
from collections import defaultdict

ft_model = fasttext.load_model("lid.176.ftz")

def fasttext_detect_language(text, threshold=0.05):
    """Return the top language label predicted by ``ft_model`` for *text*.

    Whitespace runs (including newlines) are collapsed to single spaces before
    prediction, and the ``__label__`` marker is removed from the returned label.

    Args:
        text: Text to classify.
        threshold: Accepted for API compatibility; the function does not
            currently compare the prediction probability against it.

    Returns:
        The predicted label (e.g. ``'en'``), or ``'unknown'`` when *text* is
        empty, whitespace-only, not a string, or any exception is raised.
    """
    try:
        if not text or not isinstance(text, str):
            return 'unknown'

        # split()/join both trims the ends and flattens inner whitespace.
        single_line = ' '.join(text.split())
        if not single_line:
            return 'unknown'

        prediction = ft_model.predict(single_line, k=1)
        top_label = prediction[0][0]
        _top_probability = prediction[1][0]  # read but not used in the decision

        return top_label.replace('__label__', '')
    except Exception:
        # Best effort: any failure is reported as an undetectable language.
        return 'unknown'


def make_filter_fn():
    """Build a row predicate plus the counter dict it updates.

    The predicate keeps a row when its "body" is a non-blank string that
    either contains ``system`` (case-insensitive) or ``<<``, or is detected as
    English by ``fasttext_detect_language``. Every decision is tallied under
    ``"removed"``, ``"kept_system"`` or ``"kept_en"``.

    Returns:
        ``(filter_fn, counter)`` where ``counter`` is a ``defaultdict(int)``
        shared with the returned function.
    """
    tallies = defaultdict(int)

    def _filter(entry):
        body = entry.get("body", "")

        # Missing, non-string and blank bodies are dropped outright.
        if not isinstance(body, str) or not body.strip():
            tallies["removed"] += 1
            return False

        lowered = body.lower()
        if 'system' in lowered or '<<' in lowered:
            # Kept without consulting the language detector.
            tallies["kept_system"] += 1
            return True

        is_english = fasttext_detect_language(body) == 'en'
        tallies["kept_en" if is_english else "removed"] += 1
        return is_english

    return _filter, tallies

def filter_english(dataset_dict):
    """Filter every phase with a fresh predicate from ``make_filter_fn`` and print a summary.

    Args:
        dataset_dict: Mapping of phase name to a dataset supporting
            ``filter(...)`` and ``len()``.

    Returns:
        A ``DatasetDict`` holding the filtered dataset for each phase.
    """
    filtered_dict = {}
    for phase, dataset in dataset_dict.items():
        print(f"\nFiltering {phase}...")
        # A new predicate per phase gives each phase its own counters.
        filter_fn, counter = make_filter_fn()
        filtered_dataset = dataset.filter(
            filter_fn,
            desc=f"Filtering {phase}",
            num_proc=1,  # single process so the closure's counters are updated in this process
            with_indices=False
        )
        total = len(dataset)
        removed = counter["removed"]
        # Guard the percentage so an empty phase does not raise ZeroDivisionError.
        removed_pct = (removed / total) * 100 if total else 0.0
        print(f"{phase} Summary:")
        print(f"  Total entries:        {total:,}")
        print(f"  Kept (English):       {counter['kept_en']:,}")
        print(f"  Kept (system):        {counter['kept_system']:,}")
        print(f"  Removed (non-English): {removed:,} ({removed_pct:.2f}%)")
        filtered_dict[phase] = filtered_dataset
    return DatasetDict(filtered_dict)

In [None]:
dataset_fasttext = filter_english(dataset_injected)

In [None]:
# Drop the `__index_level_0__` column from the language-filtered data.
# NOTE(review): presumably this column is the pandas index carried over by
# Dataset.from_pandas during de-duplication -- confirm.
dataset_injected = dataset_fasttext.remove_columns('__index_level_0__')

In [None]:
dataset_injected

In [None]:
print(calculate_avg_length(dataset_injected))

In [None]:
def split_dataset_dict(dataset_dict):
    """Cut every phase into four consecutive row ranges.

    The first three ranges each hold ``len(dataset) // 4`` rows; the fourth
    runs to the end of the dataset and so absorbs any remainder.

    Args:
        dataset_dict: Mapping of phase name to a dataset supporting
            ``select(range)`` and ``len()``.

    Returns:
        A 4-tuple of ``DatasetDict`` objects (first, second, third, fourth
        quarter), each containing the same phase names as the input.
    """
    buckets = ({}, {}, {}, {})

    for phase_name, dataset in dataset_dict.items():
        total_rows = len(dataset)
        step = total_rows // 4

        # Consecutive boundaries; the last one is the dataset end, not step * 4.
        edges = [0, step, step * 2, step * 3, total_rows]
        q1, q2, q3, q4 = (
            dataset.select(range(start, stop))
            for start, stop in zip(edges, edges[1:])
        )

        for bucket, piece in zip(buckets, (q1, q2, q3, q4)):
            bucket[phase_name] = piece

        print(f"{phase_name}: Split {total_rows} rows into {len(q1)} + {len(q2)} + {len(q3)} + {len(q4)} rows")

    first, second, third, fourth = (DatasetDict(bucket) for bucket in buckets)
    return first, second, third, fourth

In [None]:
dataset_injected_first, dataset_injected_second, dataset_injected_third, dataset_injected_fourth = split_dataset_dict(dataset_injected)

# Create Categorized Dataset

## Create JSONL File for OpenAI Batch API

The code in this section creates a batch processing system for classifying prompt injection attacks using OpenAI's API. It defines a function `create_jsonl_file()` that formats individual text samples into JSON requests for GPT-3.5-turbo, where each request asks the model to classify prompts into one of five cybersecurity attack categories: jailbreak, system leak, task override, encoding manipulation, or prompt confusion. The system uses a detailed prompt that instructs the AI to act as a cybersecurity expert and classify injection attempts based on specific definitions for each attack type.

In [None]:
from datasets import Dataset, DatasetDict
import time
import json
counter = 0


def create_jsonl_file(batch, field, output_file="batch_requests.jsonl", model="gpt-3.5-turbo-0125"):
    """Append one Batch API chat-completion request for ``batch[field]`` to *output_file*.

    Each written request gets the id ``batch-<n>``, where ``n`` comes from the
    module-level ``counter`` that this function increments.

    Args:
        batch: Mapping from which the text to classify is read.
        field: Key of *batch* holding the text.
        output_file: JSONL file the request line is appended to.
        model: Model name placed in the request body.

    Returns:
        ``"Error with Texts"`` when ``batch[field]`` is falsy, otherwise ``None``.
        Exceptions raised while building or writing the request are printed,
        not propagated.

    NOTE(review): when invoked through ``Dataset.map(batched=True)`` the value
    at ``batch[field]`` is presumably a list, in which case the prompt contains
    the list's repr rather than the raw text -- confirm this is intended.
    """
    global counter

    text = batch[field]
    if not text:
        return "Error with Texts"

    try:
        counter += 1

        system_message = {
            "role": "system",
            "content": """You are a cybersecurity expert specializing in prompt injection classification.

                                      You will receive injected prompts and must classify them into one of these attack categories:

                                      1. **jailbreak** — Only prompts designed to bypass the model's safety restrictions, override ethical constraints, or make the AI ignore its built-in instructions. Often involves pretending to be someone else, using reverse psychology, or roleplay.
                                      2. **system leak** — Only prompts aimed at exposing hidden system messages, configurations, internal logic, or developer instructions that are not meant for the user.
                                      3. **task override** — Only prompts that attempt to change or override the original objective assigned to the model. The attacker tries to reprogram the task, shift the topic, or redefine the goal.
                                      4. **encoding manipulation** — Only prompts that hide malicious instructions using THESE ENCODED FORMATS (base64, hex, Unicode, URL encoding) to bypass filters or detection systems.
                                      5. **prompt confusion** — Only prompts that use misleading, contradictory, repetitive wording, or ambiguous instructions to confuse the model's behavior. These attacks often involve multiple, conflicting tasks or embedded triggers.

                                      Respond with only one category name (e.g., "jailbreak", "system_leak", "task_override", "encoding_manipulation", "prompt_confusion") with no markdown or any symbols before the category name."""
        }
        user_message = {
            "role": "user",
            "content": f"You will receive a prompt. For that prompt, respond with the category (one of: jailbreak, system_leak, task_override, encoding_manipulation, prompt_confusion) in that format. Take into account the exact definitions for each type of injection and do not return that the prompt isn't injected. Here is the prompt:\n\n {text}"
        }

        # Key order mirrors the request layout: id, method, url, then the body.
        request = {
            "custom_id": f"batch-{counter}",
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": model,
                "messages": [system_message, user_message],
                "max_tokens": 20,
                "temperature": 0.1
            }
        }

        # Append mode: one JSON object per line, earlier requests are preserved.
        with open(output_file, 'a', encoding='utf-8') as out:
            out.write(json.dumps(request) + '\n')

    except Exception as e:
        print(f"Error creating batch file: {e}")

def allocate_dataset(dataset_injected, field, batch_size=1, output_file="batch_requests.jsonl"):
    """Write Batch API requests for every split of *dataset_injected* to *output_file*.

    The output file is emptied first; each split is then passed through
    ``map`` so that ``create_jsonl_file`` runs once per batch of rows.

    Args:
        dataset_injected: Mapping of split name to a dataset supporting
            ``map(...)`` and ``len()``.
        field: Column name forwarded to ``create_jsonl_file``.
        batch_size: Number of rows handed to ``create_jsonl_file`` per call.
        output_file: JSONL file receiving the requests.
    """
    # Start from an empty file, since the request writer only appends.
    with open(output_file, 'w', encoding='utf-8'):
        pass

    def _write_request(batch):
        return create_jsonl_file(batch, field, output_file)

    for split_name, split in dataset_injected.items():
        print(f"Processing {split_name} split with {len(split)} samples...")

        # map() is used for its per-batch callback; its return value is not needed.
        split.map(
            _write_request,
            batched=True,
            batch_size=batch_size,
            desc=f"Creating batch requests for {split_name}"
        )

    print(f"All batch requests written to {output_file}")

## Use the OpenAI Batch API to allocate categories

The code in this section processes large datasets for prompt injection classification by splitting them into chunks and submitting each chunk as a separate batch job to OpenAI's API. The workflow follows a consistent pattern: generate JSONL request files, upload them to OpenAI, create batch jobs with 24-hour completion windows, and retrieve the classification results. By dividing the work into multiple batches (first, second1, second2, third, fourth), it efficiently handles large-scale data processing while staying within API limits. Each batch is tracked with descriptive metadata to manage the multiple concurrent operations systematically.

In [None]:
allocate_dataset(dataset_injected_first, "body")

In [None]:
# Upload the first-quarter request file for batch processing. The context
# manager closes the local file handle as soon as the upload call returns.
with open("batch_requests.jsonl", "rb") as request_file:
    batch_input_file = client.files.create(
        file=request_file,
        purpose="batch"
    )

print(batch_input_file)

In [None]:
# Create a batch job for the uploaded first-quarter request file, targeting the
# chat-completions endpoint with a 24h completion window.
batch_input_file_id = batch_input_file.id
batch_val = client.batches.create(
    input_file_id=batch_input_file_id,
    endpoint="/v1/chat/completions",
    completion_window="24h",
    metadata={
        "description": "first quarter"
    }
)

In [None]:
# Fetch the current state of the first-quarter batch job and keep its output file id.
# NOTE(review): output_file_id is presumably unset until the job has finished --
# re-run this cell once the job is done before downloading results; confirm.
batch = client.batches.retrieve(batch_val.id)
print(batch)
batch_output_file_id = batch.output_file_id


In [None]:
file_response = client.files.content(batch_output_file_id)

In [None]:
print(file_response.text)

In [None]:
# Get the current batch status
batch = client.batches.retrieve(batch.id)
print(f"Batch status: {batch.status}")

# Only try to access error file if batch failed or completed with errors
if batch.status in ["failed", "completed"] and batch.error_file_id:
    error_file_response = client.files.content(batch.error_file_id)
    print(error_file_response.text)
elif batch.status == "completed":
    print("Batch completed successfully - no errors to display")
else:
    print(f"Batch is {batch.status} - error file not yet available")

In [None]:
def split_dataset_dict_half(dataset_dict):
    """Cut every phase into two consecutive row ranges.

    The first range holds ``len(dataset) // 2`` rows; the second runs to the
    end of the dataset, so it is one row longer when the length is odd.

    Args:
        dataset_dict: Mapping of phase name to a dataset supporting
            ``select(range)`` and ``len()``.

    Returns:
        A 2-tuple of ``DatasetDict`` objects (first half, second half).
    """
    front, back = {}, {}

    for phase_name, dataset in dataset_dict.items():
        total_rows = len(dataset)
        midpoint = total_rows // 2

        head = dataset.select(range(0, midpoint))
        tail = dataset.select(range(midpoint, total_rows))

        front[phase_name] = head
        back[phase_name] = tail

        print(f"{phase_name}: Split {total_rows} rows into {len(head)} + {len(tail)} rows")

    return DatasetDict(front), DatasetDict(back)

In [None]:
dataset_injected_second1, dataset_injected_second2 = split_dataset_dict_half(dataset_injected_second)

In [None]:
allocate_dataset(dataset_injected_second1, "body", output_file="batch_requests_second1.jsonl")

In [None]:
# Upload the "second part 1" request file for batch processing. The context
# manager closes the local file handle as soon as the upload call returns.
with open("batch_requests_second1.jsonl", "rb") as request_file:
    batch_input_file_second1 = client.files.create(
        file=request_file,
        purpose="batch"
    )

print(batch_input_file_second1)

In [None]:
# Create a batch job for the uploaded "second part 1" request file, targeting
# the chat-completions endpoint with a 24h completion window.
batch_input_file_id_second1 = batch_input_file_second1.id
batch_val_second1 = client.batches.create(
    input_file_id=batch_input_file_id_second1,
    endpoint="/v1/chat/completions",
    completion_window="24h",
    metadata={
        "description": "second part 1"
    }
)

In [None]:
# Fetch the current state of the "second part 1" batch job and keep its output file id.
# NOTE(review): output_file_id is presumably unset until the job has finished -- confirm.
batch_second1 = client.batches.retrieve(batch_val_second1.id)
print(batch_second1)
batch_output_file_id_second1 = batch_second1.output_file_id

In [None]:
file_response_second1 = client.files.content(batch_output_file_id_second1)


In [None]:
file_response_second1.text

In [None]:
allocate_dataset(dataset_injected_second2, "body", output_file="batch_requests_second2.jsonl")

In [None]:
# Upload the "second part 2" request file for batch processing. The context
# manager closes the local file handle as soon as the upload call returns.
with open("batch_requests_second2.jsonl", "rb") as request_file:
    batch_input_file_second2 = client.files.create(
        file=request_file,
        purpose="batch"
    )

print(batch_input_file_second2)

In [None]:
# Create a batch job for the uploaded "second part 2" request file, targeting
# the chat-completions endpoint with a 24h completion window.
batch_input_file_id_second2 = batch_input_file_second2.id
batch_val_second2 = client.batches.create(
    input_file_id=batch_input_file_id_second2,
    endpoint="/v1/chat/completions",
    completion_window="24h",
    metadata={
        "description": "second part 2"
    }
)

In [None]:
# Fetch the current state of the "second part 2" batch job and keep its output file id.
# NOTE(review): output_file_id is presumably unset until the job has finished -- confirm.
batch_second2 = client.batches.retrieve(batch_val_second2.id)
print(batch_second2)
batch_output_file_id_second2 = batch_second2.output_file_id

In [None]:
file_response_second2 = client.files.content(batch_output_file_id_second2)

In [None]:
file_response_second2.text

In [None]:
allocate_dataset(dataset_injected_third, "body", output_file="batch_requests_third.jsonl")

In [None]:
# Upload the third-quarter request file for batch processing. The context
# manager closes the local file handle as soon as the upload call returns.
with open("batch_requests_third.jsonl", "rb") as request_file:
    batch_input_file_third = client.files.create(
        file=request_file,
        purpose="batch"
    )

print(batch_input_file_third)

In [None]:
# Create a batch job for the uploaded third-quarter request file, targeting the
# chat-completions endpoint with a 24h completion window.
batch_input_file_id_third = batch_input_file_third.id
batch_val_third = client.batches.create(
    input_file_id=batch_input_file_id_third,
    endpoint="/v1/chat/completions",
    completion_window="24h",
    metadata={
        "description": "third part"
    }
)

In [None]:
# Fetch the current state of the third-quarter batch job and keep its output file id.
# NOTE(review): output_file_id is presumably unset until the job has finished -- confirm.
batch_third = client.batches.retrieve(batch_val_third.id)
print(batch_third)
batch_output_file_id_third = batch_third.output_file_id

In [None]:
file_response_third = client.files.content(batch_output_file_id_third)

In [None]:
allocate_dataset(dataset_injected_fourth, "body", output_file="batch_requests_fourth.jsonl")

In [None]:
# Upload the fourth-quarter request file for batch processing. The context
# manager closes the local file handle as soon as the upload call returns.
with open("batch_requests_fourth.jsonl", "rb") as request_file:
    batch_input_file_fourth = client.files.create(
        file=request_file,
        purpose="batch"
    )

print(batch_input_file_fourth)

In [None]:
# Create a batch job for the uploaded fourth-quarter request file, targeting the
# chat-completions endpoint with a 24h completion window.
batch_input_file_id_fourth = batch_input_file_fourth.id
batch_val_fourth = client.batches.create(
    input_file_id=batch_input_file_id_fourth,
    endpoint="/v1/chat/completions",
    completion_window="24h",
    metadata={
        "description": "fourth part"
    }
)

In [None]:
# Fetch the current state of the fourth-quarter batch job and keep its output file id.
# NOTE(review): output_file_id is presumably unset until the job has finished -- confirm.
batch_fourth = client.batches.retrieve(batch_val_fourth.id)
print(batch_fourth)
batch_output_file_id_fourth = batch_fourth.output_file_id

In [None]:
file_response_fourth = client.files.content(batch_output_file_id_fourth)

In [None]:
print(file_response_fourth.text)

## Process Batch

The code in this section processes OpenAI batch API responses to add prompt injection classification categories to datasets. The `process_batch_and_add_categories` function parses batch responses, extracts categories (jailbreak, system leak, etc.), and adds them as new columns to the original datasets. It handles multiple dataset chunks, concatenates them back into Phase1/Phase2 splits, and includes error handling for failed classifications. The processed labeled dataset is then saved to Google Drive as JSON and can be reloaded as a DatasetDict for further analysis.

In [None]:
def _custom_id_number(custom_id):
    """Return the integer after the first '-' in *custom_id*, or 0 if absent or non-numeric."""
    parts = custom_id.split('-')
    if len(parts) > 1 and parts[1].isdigit():
        return int(parts[1])
    return 0


def _categories_from_response(custom_id, content, batch_size):
    """Turn one response's text into a list of exactly *batch_size* category labels."""
    if batch_size == 1:
        # One request per sample: the whole reply is expected to be a single label.
        category = extract_category_from_text(content)
        if category and is_valid_category(category):
            return [category]
        print(f"Invalid category extracted from {custom_id}: '{category}' from content: '{content}'")
        return ['failed_parsing']

    categories = []
    for line_content in content.split('\n'):
        line_content = line_content.strip()
        if not line_content:
            continue

        # A line may hold several labels separated by ',' (preferred) or '='.
        if ',' in line_content:
            candidates = line_content.split(',')
        elif '=' in line_content:
            candidates = line_content.split('=')
        else:
            candidates = [line_content]

        for candidate in candidates:
            candidate = candidate.strip()
            if not candidate:
                continue
            category = extract_category_from_text(candidate)
            if category and is_valid_category(category):
                categories.append(category)

    if len(categories) != batch_size:
        print(f"Expected {batch_size} categories for {custom_id}, but found {len(categories)}")
        print(f"Categories found: {categories}")
        print(f"Raw content: {repr(content)}")

        # Pad or truncate so the caller always gets batch_size labels.
        if len(categories) < batch_size:
            categories.extend(['failed_parsing'] * (batch_size - len(categories)))
        else:
            categories = categories[:batch_size]

    return categories


def process_batch_and_add_categories(original_dataset, batch_content, batch_size=1, filter_failed=True):
    """
    Process batch responses and add a 'category' column to every split.

    Responses are ordered by the number in their custom_id and consumed
    sequentially: the first split takes the first responses, the next split
    continues where the previous one stopped. (The position is deliberately
    NOT reset per split; otherwise later splits would be labelled with the
    responses that belong to the first split.)

    Args:
        original_dataset: HuggingFace DatasetDict
        batch_content: String containing batch responses (JSON/JSONL format),
            or an object exposing it as ``.text``
        batch_size: Number of samples covered by each request (default=1)
        filter_failed: Whether to filter out failed entries (default=True)

    Returns:
        A DatasetDict whose splits carry a 'category' column. Samples without a
        usable response are labelled 'failed', 'failed_parsing' or 'missing'
        (and removed when filter_failed is True). If nothing can be parsed from
        batch_content, original_dataset is returned unchanged.
    """
    from collections import Counter
    from datasets import DatasetDict

    print("Starting batch processing...")
    print(f"Batch size: {batch_size}")
    print(f"Filter failed: {filter_failed}")

    # Step 1: Parse the batch content
    print("Parsing batch responses...")
    batch_data = parse_batch_content(batch_content)

    if not batch_data:
        print("ERROR: No batch data could be parsed from the input!")
        # Preview the text itself; a response object cannot be sliced directly.
        raw = getattr(batch_content, 'text', batch_content)
        preview = raw[:500] if isinstance(raw, str) and raw else 'None'
        print(f"Input content preview: {repr(preview)}...")
        return original_dataset

    print(f"Successfully parsed {len(batch_data)} batch responses")

    batch_responses = {}
    for response_data in batch_data:
        try:
            custom_id = response_data['custom_id']
            content = response_data['response']['body']['choices'][0]['message']['content']
            categories = _categories_from_response(custom_id, content, batch_size)
            batch_responses[custom_id] = categories
            print(f"Batch {custom_id}: Found {len(categories)} categories: {categories}")
        except Exception as e:
            # Best effort: a malformed response is reported and skipped, not fatal.
            print(f"Error parsing batch response: {e}")
            print(f"Response data keys: {list(response_data.keys()) if isinstance(response_data, dict) else 'Not a dict'}")
            if isinstance(response_data, dict):
                print(f"Custom ID: {response_data.get('custom_id', 'MISSING')}")
            continue

    print(f"Found {len(batch_responses)} successful batches")

    # Report which custom_ids are available (diagnostics only).
    available_custom_ids = sorted(batch_responses.keys())
    print(f"Available custom_ids: {available_custom_ids[:10]}..." if len(available_custom_ids) > 10 else f"Available custom_ids: {available_custom_ids}")

    batch_numbers = []
    for custom_id in available_custom_ids:
        parts = custom_id.split('-')
        if len(parts) >= 2:
            try:
                batch_numbers.append(int(parts[1]))
            except ValueError:
                continue

    if batch_numbers:
        print(f"Batch number range: {min(batch_numbers)} to {max(batch_numbers)} ({len(batch_numbers)} total)")

    # Step 2: Process each split
    updated_splits = {}

    # Responses in request order; the cursor is shared by all splits.
    sorted_custom_ids = sorted(batch_responses.keys(), key=_custom_id_number)
    custom_id_index = 0

    for split_name, dataset in original_dataset.items():
        print(f"\nProcessing {split_name}...")

        split_categories = []
        num_samples = len(dataset)

        print(f"Dataset has {num_samples} samples")

        for i in range(0, num_samples, batch_size):
            # The last request of a split may cover fewer than batch_size samples.
            current_batch_size = min(batch_size, num_samples - i)

            if custom_id_index >= len(sorted_custom_ids):
                print(f"✗ No more batch responses available - adding {current_batch_size} 'failed' entries")
                split_categories.extend(['failed'] * current_batch_size)
                continue

            custom_id = sorted_custom_ids[custom_id_index]
            custom_id_index += 1
            batch_cats = batch_responses[custom_id]

            if len(batch_cats) == current_batch_size:
                split_categories.extend(batch_cats)
                print(f"✓ {custom_id}: Added {len(batch_cats)} categories")
            else:
                print(f"⚠ {custom_id}: Expected {current_batch_size} categories, got {len(batch_cats)}")
                # Take what we have and fill the rest
                split_categories.extend(batch_cats[:current_batch_size])
                if len(batch_cats) < current_batch_size:
                    missing = current_batch_size - len(batch_cats)
                    split_categories.extend(['failed_parsing'] * missing)

        # Validation (defensive: the column must match the split length exactly)
        print(f"Generated {len(split_categories)} categories for {num_samples} samples")

        if len(split_categories) != num_samples:
            print(f"ERROR: Category count mismatch!")
            print(f"Expected: {num_samples}, Got: {len(split_categories)}")

            if len(split_categories) < num_samples:
                missing = num_samples - len(split_categories)
                print(f"Adding {missing} 'missing' entries")
                split_categories.extend(['missing'] * missing)
            else:
                print(f"Truncating to {num_samples} entries")
                split_categories = split_categories[:num_samples]

        print(f"Category distribution for {split_name}:")
        for cat, count in sorted(Counter(split_categories).items()):
            print(f"  {cat}: {count}")

        # Add category column
        try:
            updated_splits[split_name] = dataset.add_column('category', split_categories)
            print(f"✓ Successfully added categories to {split_name}")
        except Exception as e:
            print(f"✗ Error adding categories to {split_name}: {e}")
            raise

    if custom_id_index < len(sorted_custom_ids):
        # More responses than samples usually means the wrong dataset/response pairing.
        print(f"⚠ {len(sorted_custom_ids) - custom_id_index} batch responses were not matched to any sample")

    dataset_dict = DatasetDict(updated_splits)

    # Step 3: Apply filtering if requested
    if filter_failed:
        print("\nApplying filtering to remove failed entries...")
        dataset_dict = filter_failed_parsing_datasetdict(dataset_dict)

    # Final summary
    print(f"\nSUMMARY:")
    for split_name, split_dataset in dataset_dict.items():
        print(f"{split_name}: {len(split_dataset)} samples")

    return dataset_dict


def parse_batch_content(batch_content):
    """
    Parse batch output into a list of response records.

    Accepts a string, an object exposing the text as ``.text``, or anything
    else (converted with ``str``). Three formats are tried in order:

    1. JSONL - one JSON object per line; only lines with a ``custom_id``, a
       ``response`` key, no truthy ``error`` and a reachable message content
       are kept.
    2. A JSON array - returned as-is.
    3. A single JSON object - its ``responses`` or ``data`` entry if present,
       otherwise the object wrapped in a list.

    Returns an empty list when the input is empty or nothing can be parsed.
    """
    import json

    if not batch_content:
        print("Empty batch content received")
        return []

    # Normalise the input to a stripped string
    if hasattr(batch_content, 'text'):
        raw = batch_content.text
    elif isinstance(batch_content, str):
        raw = batch_content
    else:
        raw = str(batch_content)
    raw = raw.strip()

    if not raw:
        print("Empty batch content after processing")
        return []

    print(f"Content length: {len(raw)} characters")
    print(f"Content starts with: {repr(raw[:100])}")
    print(f"Content ends with: {repr(raw[-100:])}")

    def _is_usable(record, line_no):
        """Check one decoded JSONL line, printing the reason when it is rejected."""
        if 'custom_id' not in record:
            print(f"Line {line_no}: Missing custom_id")
            return False

        if 'response' not in record:
            print(f"Line {line_no}: Missing response field")
            return False

        if record.get('error'):
            print(f"Line {line_no}: Response has error: {record['error']}")
            return False

        try:
            content = record['response']['body']['choices'][0]['message']['content']
        except (KeyError, IndexError, TypeError) as e:
            print(f"Line {line_no}: Invalid response structure: {e}")
            return False

        if line_no <= 5:  # Show first few for debugging
            print(f"✓ Line {line_no}: {record['custom_id']} -> '{content}'")
        return True

    accepted = []

    try:
        # Method 1: JSONL (most common)
        print("Attempting JSONL parsing...")
        lines = raw.split('\n')
        print(f"Found {len(lines)} lines")

        for line_no, line in enumerate(lines, start=1):
            line = line.strip()
            if not line:
                continue

            try:
                record = json.loads(line)
            except json.JSONDecodeError as e:
                print(f"Line {line_no}: JSON decode error: {e}")
                if len(line) < 200:
                    print(f"  Full line: {repr(line)}")
                else:
                    print(f"  Line preview: {repr(line[:100])}...{repr(line[-100:])}")
                continue

            if _is_usable(record, line_no):
                accepted.append(record)

        if accepted:
            print(f"Successfully parsed {len(accepted)} responses from JSONL")
            return accepted

        # Method 2: JSON array
        print("JSONL failed, attempting JSON array parsing...")
        if raw.startswith('[') and raw.endswith(']'):
            data = json.loads(raw)
            if isinstance(data, list):
                print(f"Successfully parsed {len(data)} responses from JSON array")
                return data

        # Method 3: single JSON object
        print("Attempting single JSON object parsing...")
        data = json.loads(raw)
        if isinstance(data, dict):
            for wrapper_key in ('responses', 'data'):
                if wrapper_key in data:
                    return data[wrapper_key]
            return [data]

    except Exception as e:
        print(f"All parsing methods failed: {e}")

    return []


def extract_category_from_text(text):
    """
    Reduce a raw reply to a bare, lowercase category string.

    Wrappers are recognized in this order: a leading "N." list number,
    a leading "- " or "* " bullet, then a "Label: value" pair (everything
    after the first colon is kept). Markdown, quote and bracket characters
    are removed, and one leading word such as "category" or "type"
    (followed by ':' or a space) is dropped.

    Returns None for non-string, empty or whitespace-only input, and when
    nothing remains after cleaning.
    """
    import re

    if not isinstance(text, str):
        return None

    stripped = text.strip()
    if not stripped:
        return None

    numbered = re.match(r'^\d+\.\s*(.+)', stripped)
    if numbered is not None:
        # "1. category" / "2. system_leak"
        candidate = numbered.group(1).strip()
    elif stripped.startswith(('- ', '* ')):
        # "- category" / "* category"
        candidate = stripped[2:].strip()
    elif ':' in stripped:
        # "Category: value" -> keep what follows the first colon
        candidate = stripped.partition(':')[2].strip()
    else:
        candidate = stripped

    # Drop markdown emphasis, backticks, quotes and brackets.
    candidate = re.sub(r'[*`"\'()[\]{}]', '', candidate).strip()

    # Remove at most one leading label word.
    lowered = candidate.lower()
    for word in ('category', 'type', 'classification', 'label', 'answer', 'result'):
        if lowered.startswith((word + ':', word + ' ')):
            candidate = candidate[len(word) + 1:].strip()
            break

    normalized = candidate.lower().strip()
    return normalized or None


def is_valid_category(category):
    """
    Tell whether *category* names one of the five known injection categories.

    Matching is case-insensitive. Falsy values and strings shorter than
    three characters are rejected before the lookup.
    """
    known = frozenset((
        'jailbreak',
        'system_leak',
        'task_override',
        'encoding_manipulation',
        'prompt_confusion',
    ))

    if category and len(category) >= 3:
        return category.lower() in known
    return False


def filter_failed_parsing_datasetdict(dataset_dict):
    """
    Build a new DatasetDict without rows whose category marks a failure.

    Rows whose 'category' is 'failed', 'failed_parsing' or 'missing' are
    removed from every split. Per-split category counts (before filtering)
    and row counts (before/after) are printed along the way.
    """
    from datasets import DatasetDict

    bad_labels = {'failed', 'failed_parsing', 'missing'}

    def _keep(example):
        return example['category'] not in bad_labels

    kept_splits = {}
    for phase_name, split in dataset_dict.items():
        print(f"\nFiltering {phase_name}...")

        # Tally categories so the log shows what is about to be removed.
        total = len(split)
        tally = {}
        for row in split:
            label = row['category']
            tally[label] = tally.setdefault(label, 0) + 1

        print(f"Before filtering ({total} samples):")
        for label in sorted(tally):
            print(f"  {label}: {tally[label]}")

        remaining = split.filter(_keep)
        kept_splits[phase_name] = remaining

        survivors = len(remaining)
        print(f"After filtering: {total} -> {survivors} ({total - survivors} removed)")

    return DatasetDict(kept_splits)

In [None]:
# Run process_batch_and_add_categories on the first chunk with its batch output text.
updated_dataset_part_one = process_batch_and_add_categories(
     original_dataset=dataset_injected_first,
     batch_content=file_response.text,
     batch_size=1
 )


In [None]:
# Same call for the first half of the second chunk.
updated_dataset_part_two1 = process_batch_and_add_categories(
     original_dataset=dataset_injected_second1,
     batch_content=file_response_second1.text,
     batch_size=1
 )


In [None]:
# Same call for the second half of the second chunk.
updated_dataset_part_two2 = process_batch_and_add_categories(
     original_dataset=dataset_injected_second2,
     batch_content=file_response_second2.text,
     batch_size=1
 )



In [None]:
# Stitch the two halves of the second chunk back together, phase by phase.
# NOTE: DatasetDict is not imported in this cell; it relies on an import made elsewhere in the session.
from datasets import concatenate_datasets
phase1 = concatenate_datasets([updated_dataset_part_two1["Phase1"], updated_dataset_part_two2["Phase1"]])
phase2 = concatenate_datasets([updated_dataset_part_two1["Phase2"], updated_dataset_part_two2["Phase2"]])

updated_dataset_part_two = DatasetDict({"Phase1": phase1, "Phase2":phase2})
updated_dataset_part_two

In [None]:
# Third chunk.
updated_dataset_part_three = process_batch_and_add_categories(
     original_dataset=dataset_injected_third,
     batch_content=file_response_third.text,
     batch_size=1
 )

In [None]:
# Fourth chunk.
updated_dataset_part_four = process_batch_and_add_categories(
     original_dataset=dataset_injected_fourth,
     batch_content=file_response_fourth.text,
     batch_size=1
 )


In [None]:
# Reassemble the four processed chunks into one DatasetDict, phase by phase (chunk order one..four).
from datasets import concatenate_datasets

updated_dataset_full_phase1 = concatenate_datasets([updated_dataset_part_one["Phase1"], updated_dataset_part_two["Phase1"], updated_dataset_part_three["Phase1"], updated_dataset_part_four["Phase1"]])
updated_dataset_full_phase2 = concatenate_datasets([updated_dataset_part_one["Phase2"], updated_dataset_part_two["Phase2"], updated_dataset_part_three["Phase2"], updated_dataset_part_four["Phase2"]])

updated_dataset_full = DatasetDict({"Phase1":updated_dataset_full_phase1, "Phase2":updated_dataset_full_phase2})


In [None]:
# Display the combined DatasetDict.
updated_dataset_full

In [None]:
# Mount Google Drive at /content/drive (the path prefix used by the save/load cells).
from google.colab import drive

drive.mount('/content/drive')


In [None]:
import json
import os

# Persist the categorized DatasetDict to Drive as a single JSON document.
output_dir = '/content/drive/MyDrive/Algoverse/'
os.makedirs(output_dir, exist_ok=True)


# Layout: {split_name: {"features": [...], "num_rows": int, "data": [row_dict, ...]}}
json_data = {}

for phase_name, dataset in updated_dataset_full.items():
    print(f"Processing {phase_name}...")

    feature_names = list(dataset.features.keys())

    # Walk the rows once; the previous version re-read dataset[i] for every
    # feature, materializing each row once per column.
    phase_data = [{name: row[name] for name in feature_names} for row in dataset]

    json_data[phase_name] = {
        'features': feature_names,
        'num_rows': len(dataset),
        'data': phase_data
    }


output_path = os.path.join(output_dir, 'dataset_with_categories.json')

print(f"Saving data to {output_path}...")
with open(output_path, 'w', encoding='utf-8') as f:
    # default=str stringifies any value json cannot encode natively.
    json.dump(json_data, f, indent=2, ensure_ascii=False, default=str)

print(f"Successfully saved dataset to {output_path}")
print(f"File size: {os.path.getsize(output_path)} bytes")

In [None]:
# Reload the categorized dataset from the JSON file on Drive and rebuild a DatasetDict.
import json
from datasets import Dataset, DatasetDict
json_file_path = '/content/drive/MyDrive/Algoverse/dataset_with_categories.json'

print(f"Loading data from {json_file_path}...")
with open(json_file_path, 'r', encoding='utf-8') as f:
    json_data = json.load(f)

# Each top-level key is a split name; only its 'data' list (row dicts) is used here.
dataset_dict = {}
for phase_name, phase_info in json_data.items():
    print(f"Processing {phase_name}...")

    dataset_dict[phase_name] = Dataset.from_list(phase_info['data'])


# Rebinds updated_dataset_full to the reloaded copy.
updated_dataset_full = DatasetDict(dataset_dict)

print("Successfully loaded DatasetDict!")

updated_dataset_full

In [None]:
# Mean stripped character length of "body" across Phase1 and Phase2.
print(calculate_avg_length(updated_dataset_full))

# Add pairs for the injected prompts

## Create JSONL File for Cleaning (OpenAI BatchAPI)

The code in this section defines a system for generating cleaned versions of prompt injection attacks by sending the original malicious prompts along with their classifications to OpenAI’s API to remove only the injected portions. The `create_jsonl_file_clean` function creates batch requests that ask GPT to reconstruct safe versions by removing injection parts without adding new content, while the `process_batch_and_add_pairs` function processes the API responses and adds the cleaned text as a new “pair” column to create training pairs of malicious and cleaned prompts.


In [None]:
from datasets import Dataset, DatasetDict
import time
import json
import re
import tiktoken
# Running request number shared by every call; it determines the "batch-N" custom_id.
counter = 0
def create_jsonl_file_clean(batch, field, output_file="batch_requests_clean.jsonl", model="gpt-3.5-turbo-0125"):
    """
    Append one Batch API chat-completion request for the first row of *batch*.

    The request asks the model to rewrite the row's text by removing the
    injected portions, given the row's category. `max_tokens` is set to the
    token count of the input text.

    Args:
        batch: Mapping of column name -> list of values (the shape produced
            by a batched map). Only element 0 of each column is used.
        field: Name of the column holding the text to clean.
        output_file: JSONL file the request line is appended to.
        model: Model name written into the request; also used to pick the
            tokenizer when tiktoken recognizes it.

    Returns:
        "Error with Texts" when the text is missing or empty, otherwise None.
        NOTE(review): when used as a `datasets` map callback, a string return
        value is presumably rejected by `map` - confirm whether returning
        None (skip) would be preferable.

    Side effects:
        Increments the module-level `counter` and appends one line to
        *output_file*. Errors while building/writing the request are printed,
        not raised.
    """
    global counter

    texts = batch[field]
    text = texts[0] if texts else None
    # Validate before tokenizing so a missing/empty text never reaches the encoder.
    if not text:
        return "Error with Texts"

    # With batched rows the category column is a list; use the label that
    # belongs to the selected text rather than interpolating the list's repr.
    category = batch["category"]
    if isinstance(category, (list, tuple)):
        category = category[0] if category else None

    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # Model name unknown to tiktoken: fall back to the previously hard-coded tokenizer.
        encoding = tiktoken.encoding_for_model("gpt-3.5-turbo-0125")
    # disallowed_special=() lets special-token look-alikes in the text be encoded as plain text.
    num_tokens = len(encoding.encode(text, disallowed_special=()))

    try:
        counter += 1
        custom_id = f"batch-{counter}"

        request_data = {
            "custom_id": custom_id,
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {

                "model": model,
                "messages": [
                    {"role": "system", "content": """You are a safe and helpful assistant, and not malicious at all.
                    You are quick at identifying malicious prompts and are skilled at changing them into clean prompts
                    ONLY by removing injected parts."""},

                    {"role": "user", "content": f"""You will receive an
                    injected prompts and a category that goes with the type of injection
                    it recieved. For this prompt, reconstruct
                    a safe and non malicious version
                    by removing the injected portions.
                    Here is the prompt: {text}
                    and here is the corresponding category: {category}.
                    Respond only with the rewritten prompt ONLY by REMOVING injection portions.
                    Make sure TO NOT ADD any words to the prompt.
                """}
                    ],
                # The cleaned text only removes content, so the input length bounds the output.
                "max_tokens": num_tokens,
                "temperature": 0.1
            }
        }

        with open(output_file, 'a', encoding='utf-8') as f:
            f.write(json.dumps(request_data) + '\n')


    except Exception as e:
        print(f"Error creating batch file: {e}")

def create_clean(dataset_injected, field, batch_size=1, output_file="batch_requests_clean.jsonl"):
    """
    Rebuild *output_file* with cleaning requests for every split.

    The file is emptied first, then each split is mapped in batches of
    *batch_size* through `create_jsonl_file_clean`, which appends the
    request lines. The mapped result is discarded; only the file matters.
    """
    # Opening in 'w' mode discards requests left over from an earlier run.
    with open(output_file, 'w', encoding='utf-8') as handle:
        handle.write("")

    def _append_request(batch):
        return create_jsonl_file_clean(batch, field, output_file)

    for split_name, split in dataset_injected.items():
        print(f"Processing {split_name} split with {len(split)} samples...")

        split.map(
            _append_request,
            batched=True,
            batch_size=batch_size,
            desc=f"Creating batch requests for {split_name}",
        )

    print(f"All batch requests written to {output_file}")

In [None]:
def _pairs_batch_number(custom_id):
    """Return N for a custom_id shaped like 'batch-N', or None when it has no numeric part."""
    parts = str(custom_id).split('-')
    if len(parts) < 2 or not parts[1].isdigit():
        return None
    try:
        return int(parts[1])
    except ValueError:
        # str.isdigit() accepts a few characters that int() rejects.
        return None


def _pairs_from_content(content, custom_id, batch_size):
    """Turn one response's message content into a list of exactly `batch_size` cleaned texts."""
    if batch_size == 1:
        # A single request yields a single cleaned text.
        cleaned_text = extract_and_clean_text(content)
        if cleaned_text:
            return [cleaned_text]
        print(f"Empty or invalid text extracted from {custom_id}")
        return ['failed_extraction']

    # batch_size > 1: one cleaned text is expected per non-empty line.
    text_pairs = []
    for line_content in content.split('\n'):
        cleaned_text = extract_and_clean_text(line_content)
        if cleaned_text:
            text_pairs.append(cleaned_text)

    if len(text_pairs) != batch_size:
        print(f"Expected {batch_size} text pairs for {custom_id}, but found {len(text_pairs)}")
        print(f"Text pairs found: {len(text_pairs)}")
        print(f"Raw content: {repr(content)}")

        # Pad with a failure marker or truncate so the length always matches.
        if len(text_pairs) < batch_size:
            text_pairs.extend(['failed_extraction'] * (batch_size - len(text_pairs)))
        else:
            text_pairs = text_pairs[:batch_size]

    return text_pairs


def process_batch_and_add_pairs(original_dataset, batch_content, batch_size=1, filter_failed=True):
    """
    Process batch responses and add text pairs to dataset

    Responses are matched to rows by position: the lowest 'batch-N' number
    found in the responses is assigned to the first batch of the first split,
    and numbering continues consecutively through the remaining batches and
    then through the following splits. This mirrors `create_clean`, which
    numbers requests with one counter across all splits of a DatasetDict.
    Rows whose expected response is absent get the marker 'api_failed'.

    Args:
        original_dataset: HuggingFace DatasetDict
        batch_content: String containing batch responses (JSON/JSONL format)
        batch_size: Size of each batch (default=1)
        filter_failed: Whether to filter out failed entries (default=True)

    Returns:
        A DatasetDict whose splits carry a new 'pair' column (failed rows
        removed when *filter_failed* is true). If nothing could be parsed
        from *batch_content*, *original_dataset* is returned unchanged.

    Raises:
        Whatever `Dataset.add_column` raises when the column cannot be added.
    """
    from datasets import DatasetDict

    print("Starting batch processing...")
    print(f"Batch size: {batch_size}")
    print(f"Filter failed: {filter_failed}")

    # Step 1: Parse the batch content
    print("Parsing batch responses...")
    batch_data = parse_batch_content(batch_content)

    if not batch_data:
        print("ERROR: No batch data could be parsed from the input!")
        print(f"Input content preview: {repr(batch_content[:500] if batch_content else 'None')}...")
        return original_dataset

    print(f"Successfully parsed {len(batch_data)} batch responses")

    # custom_id -> list of cleaned texts (length == batch_size)
    batch_responses = {}
    for response_data in batch_data:
        try:
            custom_id = response_data['custom_id']
            content = response_data['response']['body']['choices'][0]['message']['content']
            text_pairs = _pairs_from_content(content, custom_id, batch_size)
            batch_responses[custom_id] = text_pairs
            print(f"Batch {custom_id}: Found {len(text_pairs)} text pairs")
        except Exception as e:
            # Best effort: one malformed response must not abort the whole run.
            print(f"Error parsing batch response: {e}")
            print(f"Response data keys: {list(response_data.keys()) if isinstance(response_data, dict) else 'Not a dict'}")
            if isinstance(response_data, dict):
                print(f"Custom ID: {response_data.get('custom_id', 'MISSING')}")
            continue

    print(f"Found {len(batch_responses)} successful batches")

    # Step 2: Process each split
    updated_splits = {}
    failed_batches = set()

    available_custom_ids = sorted(batch_responses.keys())
    print(f"Available custom_ids: {available_custom_ids[:10]}..." if len(available_custom_ids) > 10 else f"Available custom_ids: {available_custom_ids}")

    # The numbering scheme does not depend on the split, so derive it once.
    available_batch_nums = sorted(
        num for num in map(_pairs_batch_number, batch_responses) if num is not None
    )
    if available_batch_nums:
        print(f"Batch number range: {available_batch_nums[0]} to {available_batch_nums[-1]} ({len(available_batch_nums)} total)")

    # Cursor over batch numbers; it advances across splits so that a later
    # split is never matched against an earlier split's responses.
    next_batch_num = available_batch_nums[0] if available_batch_nums else None

    for split_name, dataset in original_dataset.items():
        print(f"\nProcessing {split_name}...")

        num_samples = len(dataset)
        print(f"Dataset has {num_samples} samples")

        expected_batches = (num_samples + batch_size - 1) // batch_size  # Ceiling division
        print(f"Expected {expected_batches} batches for {num_samples} samples with batch_size={batch_size}")

        if next_batch_num is None:
            # Keep the split (marked as failed) instead of silently dropping it.
            print("No valid batch numbers found in responses")
            split_pairs = ['api_failed'] * num_samples
            failed_batches.add("no-batch-nums")
        else:
            start_batch_num = next_batch_num
            expected_end_batch = start_batch_num + expected_batches - 1
            print(f"Expected batch range for this dataset: {start_batch_num} to {expected_end_batch}")

            split_pairs = []
            for i in range(0, num_samples, batch_size):
                current_batch_size = min(batch_size, num_samples - i)
                dataset_batch_index = i // batch_size  # 0, 1, 2, 3, ...
                expected_custom_id = f"batch-{start_batch_num + dataset_batch_index}"

                if expected_custom_id not in batch_responses:
                    # This specific batch is missing (failed at API level)
                    print(f"✗ Dataset position {dataset_batch_index} -> {expected_custom_id}: Missing - marking as api_failed")
                    split_pairs.extend(['api_failed'] * current_batch_size)
                    failed_batches.add(expected_custom_id)
                    continue

                batch_pairs = batch_responses[expected_custom_id]
                if len(batch_pairs) == current_batch_size:
                    split_pairs.extend(batch_pairs)
                    print(f"✓ Dataset position {dataset_batch_index} -> {expected_custom_id}: Added {len(batch_pairs)} text pairs")
                else:
                    print(f"⚠ Dataset position {dataset_batch_index} -> {expected_custom_id}: Expected {current_batch_size} text pairs, got {len(batch_pairs)}")
                    # Take what we have and fill the rest
                    split_pairs.extend(batch_pairs[:current_batch_size])
                    if len(batch_pairs) < current_batch_size:
                        split_pairs.extend(['failed_extraction'] * (current_batch_size - len(batch_pairs)))

            # The next split continues where this one stopped.
            next_batch_num = start_batch_num + expected_batches

        # Validation (defensive: the column must have exactly one entry per row)
        print(f"Generated {len(split_pairs)} text pairs for {num_samples} samples")

        if len(split_pairs) != num_samples:
            print(f"ERROR: Text pair count mismatch!")
            print(f"Expected: {num_samples}, Got: {len(split_pairs)}")

            if len(split_pairs) < num_samples:
                missing = num_samples - len(split_pairs)
                print(f"Adding {missing} 'missing' entries")
                split_pairs.extend(['missing'] * missing)
            else:
                print(f"Truncating to {num_samples} entries")
                split_pairs = split_pairs[:num_samples]

        # Count text pair types
        failure_markers = {'failed', 'failed_extraction', 'missing', 'api_failed'}
        pair_counts = {}
        for pair in split_pairs:
            pair_type = pair if pair in failure_markers else 'valid_text'
            pair_counts[pair_type] = pair_counts.get(pair_type, 0) + 1

        print(f"Text pair distribution for {split_name}:")
        for pair_type, count in sorted(pair_counts.items()):
            print(f"  {pair_type}: {count}")

        # Add pair column
        try:
            dataset_with_pairs = dataset.add_column('pair', split_pairs)
            updated_splits[split_name] = dataset_with_pairs
            print(f"✓ Successfully added text pairs to {split_name}")
        except Exception as e:
            print(f"✗ Error adding text pairs to {split_name}: {e}")
            raise

    # Create DatasetDict
    dataset_dict = DatasetDict(updated_splits)

    # Step 3: Apply filtering if requested
    if filter_failed:
        print("\nApplying filtering to remove failed entries...")
        dataset_dict = filter_failed_extraction_datasetdict(dataset_dict)

    # Final summary
    print(f"\nSUMMARY:")
    for split_name, split_dataset in dataset_dict.items():
        print(f"{split_name}: {len(split_dataset)} samples")

    if failed_batches:
        print(f"Failed batches: {sorted(failed_batches)}")

    return dataset_dict


def parse_batch_content(batch_content):
    """
    Robust parser for batch content (JSON/JSONL format)

    Formats are tried in order:
      1. JSONL - one JSON object per line. A line is kept only if it is an
         object with 'custom_id' and 'response', has no truthy 'error', and
         exposes response.body.choices[0].message.content. Bad lines are
         reported and skipped; they never discard lines already accepted.
      2. A JSON array covering the whole content (returned as-is).
      3. A single JSON object: its 'responses' value, else its 'data' value,
         else the object wrapped in a one-element list.

    Args:
        batch_content: A string, or an object with a `.text` attribute;
            anything else is converted with str().

    Returns:
        A list of response objects, or [] when nothing could be parsed.
        This function reports problems with print() and does not raise for
        malformed content.
    """
    import json

    if not batch_content:
        print("Empty batch content received")
        return []

    # Handle different input types
    if hasattr(batch_content, 'text'):
        batch_content = batch_content.text
    elif not isinstance(batch_content, str):
        batch_content = str(batch_content)

    batch_content = batch_content.strip()

    if not batch_content:
        print("Empty batch content after processing")
        return []

    print(f"Content length: {len(batch_content)} characters")
    print(f"Content starts with: {repr(batch_content[:100])}")
    print(f"Content ends with: {repr(batch_content[-100:])}")

    responses = []

    # Method 1: Try JSONL format (most common)
    print("Attempting JSONL parsing...")
    lines = batch_content.split('\n')
    print(f"Found {len(lines)} lines")

    for line_no, raw_line in enumerate(lines, start=1):
        line = raw_line.strip()
        if not line:
            continue

        try:
            response = json.loads(line)
        except (ValueError, RecursionError) as e:
            print(f"Line {line_no}: JSON decode error: {e}")
            if len(line) < 200:
                print(f"  Full line: {repr(line)}")
            else:
                print(f"  Line preview: {repr(line[:100])}...{repr(line[-100:])}")
            continue

        # A valid JSON scalar or array is not a response record; skip it
        # instead of letting a TypeError abort the whole parse.
        if not isinstance(response, dict):
            print(f"Line {line_no}: Expected a JSON object, got {type(response).__name__}")
            continue

        if 'custom_id' not in response:
            print(f"Line {line_no}: Missing custom_id")
            continue

        if 'response' not in response:
            print(f"Line {line_no}: Missing response field")
            continue

        if response.get('error'):
            print(f"Line {line_no}: Response has error: {response['error']}")
            continue

        # Validate nested structure
        try:
            content = response['response']['body']['choices'][0]['message']['content']
        except (KeyError, IndexError, TypeError) as e:
            print(f"Line {line_no}: Invalid response structure: {e}")
            continue

        responses.append(response)
        if line_no <= 5:  # Show first few for debugging
            print(f"✓ Line {line_no}: {response['custom_id']} -> '{content[:100]}...' ({len(content)} chars)")

    if responses:
        print(f"Successfully parsed {len(responses)} responses from JSONL")
        return responses

    try:
        # Method 2: Try JSON array format
        print("JSONL failed, attempting JSON array parsing...")
        if batch_content.startswith('[') and batch_content.endswith(']'):
            data = json.loads(batch_content)
            if isinstance(data, list):
                print(f"Successfully parsed {len(data)} responses from JSON array")
                return data

        # Method 3: Try single JSON object
        print("Attempting single JSON object parsing...")
        data = json.loads(batch_content)
        if isinstance(data, dict):
            if 'responses' in data:
                return data['responses']
            elif 'data' in data:
                return data['data']
            else:
                return [data]

    except (ValueError, RecursionError) as e:
        print(f"All parsing methods failed: {e}")

    return []


def extract_and_clean_text(text):
    """
    Collapse every run of whitespace in *text* (newlines included) to a
    single space and trim both ends, keeping the rest of the text intact.

    Returns None when *text* is not a string or contains only whitespace.
    """
    import re

    if not isinstance(text, str):
        return None

    squeezed = re.sub(r'\s+', ' ', text).strip()
    return squeezed or None


def filter_failed_extraction_datasetdict(dataset_dict):
    """
    Build a new DatasetDict without rows whose 'pair' is a failure marker.

    The markers are 'failed', 'failed_extraction', 'missing' and
    'api_failed'. For every split, the distribution of markers versus valid
    texts (before filtering) and the row counts (before/after) are printed.
    """
    from datasets import DatasetDict

    markers = {'failed', 'failed_extraction', 'missing', 'api_failed'}

    def _is_usable(example):
        return example['pair'] not in markers

    kept_splits = {}
    for phase_name, split in dataset_dict.items():
        print(f"\nFiltering {phase_name}...")

        # Tally: each marker under its own name, everything else as 'valid_text'.
        total = len(split)
        tally = {}
        for row in split:
            value = row['pair']
            bucket = value if value in markers else 'valid_text'
            tally[bucket] = tally.setdefault(bucket, 0) + 1

        print(f"Before filtering ({total} samples):")
        for bucket in sorted(tally):
            print(f"  {bucket}: {tally[bucket]}")

        remaining = split.filter(_is_usable)
        kept_splits[phase_name] = remaining

        survivors = len(remaining)
        print(f"After filtering: {total} -> {survivors} ({total - survivors} removed)")

    return DatasetDict(kept_splits)

## Use OpenAI Batch API to create clean dataset

The code in this section uses OpenAI's Batch API to create training pairs by generating cleaned versions of prompt injection attacks. The workflow splits classified datasets into chunks, creates batch requests asking GPT to remove malicious injection parts while preserving legitimate content, processes all chunks through the batch API, then concatenates results back together. The final dataset contains both original malicious prompts and their cleaned counterparts as training pairs for models to learn prompt injection removal.

In [None]:
dataset_injected_first, dataset_injected_second, dataset_injected_third, dataset_injected_fourth = split_dataset_dict(updated_dataset_full)

In [None]:
dataset_injected_second1, dataset_injected_second2 = split_dataset_dict_half(dataset_injected_second)

In [None]:
create_clean(dataset_injected_first, "body", output_file="batch_requests_clean_first.jsonl")

In [None]:
batch_input_file = client.files.create(
    file=open("batch_requests_clean_first.jsonl", "rb"),
    purpose="batch"
)

print(batch_input_file)

In [None]:
batch_input_file_id = batch_input_file.id
batch_val = client.batches.create(
    input_file_id=batch_input_file_id,
    endpoint="/v1/chat/completions",
    completion_window="24h",
    metadata={
        "description": "create clean first"
    }
)
print(batch_val)

In [None]:
batch = client.batches.retrieve("batch_68881ad42820819097ab5632a4bea1dc")
print(batch)
batch_output_file_id = batch.output_file_id

In [None]:
# Retrieve the batch using the ID you already have
batch = client.batches.retrieve("batch_68881ad42820819097ab5632a4bea1dc")
print(f"Batch status: {batch.status}")

# Check if the batch is completed
if batch.status == "completed":
    batch_output_file_id = batch.output_file_id
    print(f"batch_output_file_id = {batch_output_file_id}")

    # Now you can get the file content
    file_response = client.files.content(batch_output_file_id)
    print(file_response.text)

elif batch.status == "failed":
    print("Batch failed!")
    print(f"Error details: {batch}")

else:
    print(f"Batch is still {batch.status}. Please wait and try again.")

In [None]:
file_response = client.files.content(batch_output_file_id)
print(file_response.text)

In [None]:
create_clean(dataset_injected_second1, "body", batch_size=1, output_file="batch_requests_clean_second1.jsonl")

In [None]:
batch_input_file_second1 = client.files.create(
    file=open("batch_requests_clean_second1.jsonl", "rb"),
    purpose="batch"
)

print(batch_input_file_second1)

In [None]:
batch_input_file_id_second1 = batch_input_file_second1.id
batch_val_second1 = client.batches.create(
    input_file_id=batch_input_file_id_second1,
    endpoint="/v1/chat/completions",
    completion_window="24h",
    metadata={
        "description": "create clean second part 1"
    }
)

In [None]:
batch_second1 = client.batches.retrieve("batch_6888055fdfa88190b73f315cff6bf4d3")
print(batch_second1)
batch_output_file_id_second1 = batch_second1.output_file_id

In [None]:
file_response_second1 = client.files.content(batch_output_file_id_second1)
print(file_response_second1.text)

In [None]:
create_clean(dataset_injected_second2, "body", output_file="batch_requests_clean_second2.jsonl")

In [None]:
batch_input_file_second2 = client.files.create(
    file=open("batch_requests_clean_second2.jsonl", "rb"),
    purpose="batch"
)

print(batch_input_file_second2)

In [None]:
batch_input_file_id_second2 = batch_input_file_second2.id
batch_val_second2 = client.batches.create(
    input_file_id=batch_input_file_id_second2,
    endpoint="/v1/chat/completions",
    completion_window="24h",
    metadata={
        "description": "create clean second part 2"
    }
)

In [None]:
batch_second2 = client.batches.retrieve("batch_6888058ed9a88190b564a1ccec6ec333")
print(batch_second2)
batch_output_file_id_second2 = batch_second2.output_file_id

In [None]:
file_response_second2 = client.files.content(batch_output_file_id_second2)
print(file_response_second2.text)

In [None]:
create_clean(dataset_injected_third, "body", output_file="batch_requests_clean_third.jsonl")

In [None]:
batch_input_file_third = client.files.create(
    file=open("batch_requests_clean_third.jsonl", "rb"),
    purpose="batch"
)

print(batch_input_file_third)

In [None]:
batch_input_file_id_third = batch_input_file_third.id
batch_val_third = client.batches.create(
    input_file_id=batch_input_file_id_third,
    endpoint="/v1/chat/completions",
    completion_window="24h",
    metadata={
        "description": "create clean part three"
    }
)

In [None]:
batch_third = client.batches.retrieve("batch_6888063b192081908f0cca7300a3e217")
print(batch_third)
batch_output_file_id_third = batch_third.output_file_id

In [None]:
file_response_third = client.files.content(batch_output_file_id_third)

In [None]:
create_clean(dataset_injected_fourth, "body", output_file="batch_requests_clean_fourth.jsonl")

In [None]:
batch_input_file_fourth = client.files.create(
    file=open("batch_requests_clean_fourth.jsonl", "rb"),
    purpose="batch"
)

print(batch_input_file_fourth)

In [None]:
batch_input_file_id_fourth = batch_input_file_fourth.id
batch_val_fourth = client.batches.create(
    input_file_id=batch_input_file_id_fourth,
    endpoint="/v1/chat/completions",
    completion_window="24h",
    metadata={
        "description": "create clean part four"
    }
)

In [None]:
batch_fourth = client.batches.retrieve("batch_688806be4608819087417496e9253385")
print(batch_fourth)
batch_output_file_id_fourth = batch_fourth.output_file_id

In [None]:
file_response_fourth = client.files.content(batch_output_file_id_fourth)

In [None]:
updated_dataset_part_one_clean = process_batch_and_add_pairs(
     original_dataset=dataset_injected_first,
     batch_content=file_response.text,
     batch_size=1
 )


In [None]:
updated_dataset_part_two1_clean = process_batch_and_add_pairs(
     original_dataset=dataset_injected_second1,
     batch_content=file_response_second1.text,
     batch_size=1
 )


In [None]:
updated_dataset_part_two2_clean = process_batch_and_add_pairs(
     original_dataset=dataset_injected_second2,
     batch_content=file_response_second2.text,
     batch_size=1
 )



In [None]:
from datasets import concatenate_datasets
phase1 = concatenate_datasets([updated_dataset_part_two1_clean["Phase1"], updated_dataset_part_two2_clean["Phase1"]])
phase2 = concatenate_datasets([updated_dataset_part_two1_clean["Phase2"], updated_dataset_part_two2_clean["Phase2"]])

updated_dataset_part_two_clean = DatasetDict({"Phase1": phase1, "Phase2":phase2})
updated_dataset_part_two_clean

In [None]:
updated_dataset_part_three_clean = process_batch_and_add_pairs(
     original_dataset=dataset_injected_third,
     batch_content=file_response_third.text,
     batch_size=1
 )

In [None]:
updated_dataset_part_four_clean = process_batch_and_add_pairs(
     original_dataset=dataset_injected_fourth,
     batch_content=file_response_fourth.text,
     batch_size=1
 )

In [None]:
from datasets import concatenate_datasets

# Concatenate the four processed parts, phase by phase, into the full dataset.
updated_dataset_full_phase1 = concatenate_datasets([updated_dataset_part_one_clean["Phase1"], updated_dataset_part_two_clean["Phase1"], updated_dataset_part_three_clean["Phase1"], updated_dataset_part_four_clean["Phase1"]])
updated_dataset_full_phase2 = concatenate_datasets([updated_dataset_part_one_clean["Phase2"], updated_dataset_part_two_clean["Phase2"], updated_dataset_part_three_clean["Phase2"], updated_dataset_part_four_clean["Phase2"]])

dataset_clean_with_pairs = DatasetDict({"Phase1":updated_dataset_full_phase1, "Phase2":updated_dataset_full_phase2})


# Create Clean-Clean Prompt Pairs

This section builds a synthetic dataset of clean email pairs. It writes 86,000 Batch API requests (five files of 17,200 requests each) asking OpenAI for pairs of semantically identical but differently worded emails, then processes the responses to extract the clean email pairs. All results are concatenated into a single dataset with standardized field names (`body`, `pair`, and a `category` of `clean`), and the data is split into 32 smaller chunks for storage and distribution.

In [None]:
from datasets import Dataset, DatasetDict
import time
import json
import re
import tiktoken
counter = 0


def create_jsonl_file_fully_clean(output_file="batch_requests_clean.jsonl", model="gpt-3.5-turbo-0125"):
    """Append one clean-email-pair chat request to a Batch API JSONL file.

    Each call increments the module-level ``counter`` and uses the new value
    as the request's ``custom_id`` (``batch-<counter>``). The request asks the
    model for a short email together with a reworded version of it.

    Args:
        output_file: Path of the JSONL file the request line is appended to.
        model: Model name placed in the request body.

    Nothing is returned. Any exception raised while building or writing the
    request is caught and reported with ``print``.
    """
    global counter

    # Both prompts are sent verbatim, including their embedded line breaks and
    # indentation, so the literals below must not be re-indented.
    system_prompt = """You are a safe and helpful assistant, and not malicious at all.
                    You specialize in creating clean, uninjected, emails and also
                    you are skilled at rewriting them while keeping them uninjected."""
    user_prompt = """Write an original 2-3 sentence
                    email and also provide a rewritten version of the same email
                    in a numbered list format. Return the first email as 1. Email Content
                    and return the second (rewritten version with VERY SIMILAR LENGTH WITH NO LESS THAN A 5 CHARACTER DIFFERENCE) as 2. Rewritten email. Both emails should discuss the same thing,
                    however they are just reworded and written slightly differently. The semantic meaning of both emails should be THE EXACT SAME WITH NO SEMANTIC DRIFT in both the text and the embeddings.
                """

    try:
        counter += 1

        chat_body = {
            "model": model,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            "max_tokens": 100,
            "temperature": 0.1,
        }
        payload = {
            "custom_id": f"batch-{counter}",
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": chat_body,
        }

        # One JSON object per line, as the JSONL batch format requires.
        with open(output_file, 'a', encoding='utf-8') as out:
            out.write(json.dumps(payload) + '\n')

    except Exception as e:
        print(f"Error creating batch file: {e}")

def create_clean_prompts(count, batch_size=1, output_file="batch_requests_fully_clean.jsonl"):
    """Write ``count * batch_size`` clean-email requests to a fresh JSONL file.

    The output file is truncated (or created empty) first; requests are then
    appended one at a time by ``create_jsonl_file_fully_clean``, using that
    function's default model. A progress line is printed for every batch.

    Args:
        count: Number of batches to generate.
        batch_size: Number of requests written per batch.
        output_file: Path of the JSONL file to (re)create.

    Returns:
        None.
    """
    # Start from an empty file because the request writer only appends.
    with open(output_file, 'w', encoding='utf-8') as fresh:
        fresh.write("")

    for batch_number in range(1, count + 1):
        print(f"Creating batch requests for Batch Number {batch_number}")

        for _ in range(batch_size):
            create_jsonl_file_fully_clean(output_file)

    print(f"All batch requests written to {output_file}")



In [None]:
# Write 17,200 requests to the first request file. create_clean_prompts has no
# return statement, so d1 is None.
d1 = create_clean_prompts(count=17200, output_file="batch_requests_fully_clean1.jsonl")

In [None]:
# Write 17,200 requests to the second request file (d2 is None).
d2 = create_clean_prompts(count=17200, output_file="batch_requests_fully_clean2.jsonl")

In [None]:
# Write 17,200 requests to the third request file (d3 is None).
d3 = create_clean_prompts(count=17200, output_file="batch_requests_fully_clean3.jsonl")

In [None]:
# Write 17,200 requests to the fourth request file (d4 is None).
d4 = create_clean_prompts(count=17200, output_file="batch_requests_fully_clean4.jsonl")

In [None]:
# Write 17,200 requests to the fifth request file (d5 is None).
d5 = create_clean_prompts(count=17200, output_file="batch_requests_fully_clean5.jsonl")

In [None]:
def create_batch_job(client, file_path, description="batch job", endpoint="/v1/chat/completions", completion_window="24h"):
    """Upload a JSONL request file and create a batch job for it.

    Args:
        client: API client exposing ``files.create`` and ``batches.create``.
        file_path: Path of the JSONL request file to upload.
        description: Text stored under the batch's ``description`` metadata key.
        endpoint: Endpoint the batched requests are sent to.
        completion_window: Completion window passed to ``batches.create``.

    Returns:
        The ID of the created batch, or ``None`` when the upload or the batch
        creation raised (the error is printed).
    """
    try:
        # The context manager closes the upload handle once the upload call
        # returns, whether it succeeded or raised.
        with open(file_path, "rb") as request_file:
            batch_input_file = client.files.create(
                file=request_file,
                purpose="batch"
            )
        print(f"Uploaded file: {batch_input_file}")

        batch_val = client.batches.create(
            input_file_id=batch_input_file.id,
            endpoint=endpoint,
            completion_window=completion_window,
            metadata={"description": description}
        )
        print(f"Created batch: {batch_val}")

        return batch_val.id

    except Exception as e:
        print(f"Error creating batch job: {e}")
        return None


def check_batch_status(client, batch_id):
    """Retrieve a batch and summarise its state as a plain dictionary.

    Args:
        client: API client exposing ``batches.retrieve``.
        batch_id: ID of the batch to look up.

    Returns:
        A dict with ``batch_id``, ``status``, ``created_at`` and the optional
        fields ``completed_at``, ``failed_at``, ``output_file_id``,
        ``error_file_id`` and ``request_counts`` (``None`` when the batch
        object lacks them), or ``None`` if the lookup raised (the error is
        printed).
    """
    optional_fields = (
        'completed_at',
        'failed_at',
        'output_file_id',
        'error_file_id',
        'request_counts',
    )

    try:
        batch = client.batches.retrieve(batch_id)
        print(f"Batch status: {batch}")

        # status and created_at are required; a batch without them is an error.
        summary = {
            "batch_id": batch_id,
            "status": batch.status,
            "created_at": batch.created_at,
        }
        for field in optional_fields:
            summary[field] = getattr(batch, field, None)
        return summary

    except Exception as e:
        print(f"Error checking batch status: {e}")
        return None

def file_check(status):
    """Download a batch's output file when the status summary names one.

    Args:
        status: Mapping with an ``"output_file_id"`` entry, such as the
            summary built by ``check_batch_status``.

    Returns:
        The result of ``client.files.content`` for that file ID (using the
        module-level ``client``), or ``None`` when the ID is ``None``.
    """
    output_file_id = status["output_file_id"]
    if output_file_id is None:
        print("File response not available")
        return None

    print("Saving file response")
    return client.files.content(output_file_id)

In [None]:
# Upload the first request file and start its batch job.
batch_id = create_batch_job(client, "batch_requests_fully_clean1.jsonl", "create full clean first")

In [None]:
# Fetch the current status summary of the first batch job.
status = check_batch_status(client, batch_id)

In [None]:
# Download the first batch's output file; None if no output file is listed yet.
file_response1 = file_check(status)

In [None]:
# Show the raw output of the first batch.
# NOTE(review): raises AttributeError when file_check returned None above.
print(file_response1.text)

In [None]:
# Upload the second request file and start its batch job.
batch_id2 = create_batch_job(client, "batch_requests_fully_clean2.jsonl", "create full clean second")

In [None]:
# Fetch the current status summary of the second batch job.
status2 = check_batch_status(client, batch_id2)

In [None]:
# Download the second batch's output file; None if no output file is listed yet.
file_response2 = file_check(status2)

In [None]:
# Upload the third request file and start its batch job.
batch_id3 = create_batch_job(client, "batch_requests_fully_clean3.jsonl", "create full clean third")

In [None]:
# Fetch the current status summary of the third batch job.
status3 = check_batch_status(client, batch_id3)

In [None]:
# Download the third batch's output file; None if no output file is listed yet.
file_response3 = file_check(status3)

In [None]:
# Upload the fourth request file and start its batch job. The description
# names the fourth file so the job can be told apart from the third one.
batch_id4 = create_batch_job(client, "batch_requests_fully_clean4.jsonl", "create full clean fourth")

In [None]:
# Fetch the current status summary of the fourth batch job.
status4 = check_batch_status(client, batch_id4)

In [None]:
# Download the fourth batch's output file; None if no output file is listed yet.
file_response4 = file_check(status4)

In [None]:
# Upload the fifth request file and start its batch job.
batch_id5 = create_batch_job(client, "batch_requests_fully_clean5.jsonl", "create full clean fifth")

In [None]:
# Fetch the current status summary of the fifth batch job.
status5 = check_batch_status(client, batch_id5)

In [None]:
# Download the fifth batch's output file; None if no output file is listed yet.
file_response5 = file_check(status5)

In [None]:
import json
from datasets import Dataset, DatasetDict
from typing import Optional
import re

def _record_from_batch_line(obj, line_num):
    """Build the record for one decoded batch-output object, or ``None`` to skip it."""
    # Lines that report an error, a non-200 response or no choices carry no email pair.
    if obj.get("error"):
        return None

    response = obj.get("response") or {}
    if response.get("status_code") != 200:
        return None

    body = response.get("body") or {}
    choices = body.get("choices") or []
    if not choices:
        return None

    first_choice = choices[0]
    content = (first_choice.get("message") or {}).get("content", "")
    original_email, rewritten_email = extract_email_pair(content)

    # Only keep records where both emails were successfully extracted.
    if not (original_email and rewritten_email):
        return None

    return {
        "id": obj.get("id", f"unknown_{line_num}"),
        "custom_id": obj.get("custom_id", ""),
        "original_email": original_email,
        "rewritten_email": rewritten_email,
        "model": body.get("model", ""),
        "finish_reason": first_choice.get("finish_reason", ""),
    }


def parse_jsonl_to_dataset(jsonl_content: str, split_name: str = "train") -> DatasetDict:
    """Parse batch-output JSONL text into a dataset of extracted email pairs.

    Every non-blank line is decoded as JSON. A line contributes a record only
    when it has no error, a 200 response, at least one choice, and
    ``extract_email_pair`` finds both emails in the first choice's content.
    Lines that are not valid JSON or do not have the expected structure are
    skipped silently.

    Args:
        jsonl_content: Text of the batch output file, one JSON object per line.
        split_name: Name of the single split in the returned ``DatasetDict``.

    Returns:
        A ``DatasetDict`` whose only split holds records with the fields
        ``id``, ``custom_id``, ``original_email``, ``rewritten_email``,
        ``model`` and ``finish_reason``.
    """
    data_records = []

    for line_num, line in enumerate(jsonl_content.strip().split('\n')):
        if not line.strip():
            continue

        try:
            obj = json.loads(line)
        except ValueError:
            # json.JSONDecodeError is a ValueError: not a JSON line, skip it.
            continue

        try:
            record = _record_from_batch_line(obj, line_num)
        except (AttributeError, TypeError, KeyError, IndexError):
            # Valid JSON with an unexpected shape (e.g. not an object): skip it.
            continue

        if record is not None:
            data_records.append(record)

    dataset = Dataset.from_list(data_records)
    return DatasetDict({split_name: dataset})


def extract_email_pair(content: str) -> tuple[Optional[str], Optional[str]]:
    """
    Extract the original and rewritten email from one model response.

    Two layouts are recognised, tried in this order:

    1. Numbered sections: a "1." header (Email Content / Original Email /
       Original) and a "2." header (Rewritten email / Rewritten), each
       followed by a colon and a line break.
    2. "Subject:" blocks: when at least two are present, the first two are
       taken as the pair.

    Returns:
        ``(original, rewritten)``, each passed through ``clean_email_content``,
        or ``(None, None)`` when neither layout matches or an error occurs
        (the error is printed).
    """
    numbered_original = r"1\.\s*(?:Email Content|Original Email|Original):\s*\n(.*?)(?=2\.\s*(?:Rewritten email|Rewritten Email|Rewritten):|$)"
    numbered_rewrite = r"2\.\s*(?:Rewritten email|Rewritten Email|Rewritten):\s*\n(.*?)$"
    flags = re.DOTALL | re.IGNORECASE

    try:
        # Primary layout: numbered sections with headers.
        first = re.search(numbered_original, content, flags)
        second = re.search(numbered_rewrite, content, flags)
        if first and second:
            return clean_email_content(first.group(1)), clean_email_content(second.group(1))

        # Fallback layout: consecutive "Subject:" blocks.
        subjects = re.findall(r'Subject:.*?(?=Subject:|$)', content, flags)
        if len(subjects) >= 2:
            return clean_email_content(subjects[0]), clean_email_content(subjects[1])

        # No further fallback: splitting on blank lines before "Subject:" can
        # only yield two sections when the text holds two "Subject:"
        # occurrences, and the scan above already returns in that case.
        return None, None

    except Exception as e:
        print(f"Error extracting email pair: {e}")
        return None, None


def clean_email_content(email_text: str) -> str:
    """
    Strip list numbering and section headers from the start of an email.

    Falsy input (``None`` or an empty string) is returned unchanged.
    Otherwise the text is trimmed, a leading "1."/"2." section header and any
    leading "<number>. " prefix are removed, and runs of three or more line
    breaks are collapsed to a single blank line.
    """
    if not email_text:
        return email_text

    # Header such as "1. Email Content:" or "2. Rewritten email:" at the start.
    header_pattern = r'^\s*[12]\.\s*(?:Email Content|Original Email|Original|Rewritten email|Rewritten Email|Rewritten):\s*\n?'
    text = re.sub(header_pattern, '', email_text.strip(), flags=re.IGNORECASE)

    # Bare list number at the very start ("1. Subject:" -> "Subject:").
    text = re.sub(r'^\s*\d+\.\s+', '', text)

    # Several blank lines in a row become exactly one.
    text = re.sub(r'\n\s*\n\s*\n+', '\n\n', text)

    return text.strip()


def save_dataset(dataset_dict: DatasetDict, output_path: str) -> None:
    """Write ``dataset_dict`` to ``output_path`` through its ``save_to_disk`` method."""
    dataset_dict.save_to_disk(output_path)

In [None]:
# Parse the first batch output into a DatasetDict with a single "train" split.
dataset_dict1 = parse_jsonl_to_dataset(file_response1.text)

In [None]:
# Parse the second batch output into a DatasetDict with a single "train" split.
dataset_dict2 = parse_jsonl_to_dataset(file_response2.text)

In [None]:
# Parse the third batch output into a DatasetDict with a single "train" split.
dataset_dict3 = parse_jsonl_to_dataset(file_response3.text)

In [None]:
# Parse the fourth batch output into a DatasetDict with a single "train" split.
dataset_dict4 = parse_jsonl_to_dataset(file_response4.text)

In [None]:
# Parse the fifth batch output into a DatasetDict with a single "train" split.
dataset_dict5 = parse_jsonl_to_dataset(file_response5.text)

In [None]:
from datasets import concatenate_datasets, DatasetDict

# Stack the five parsed parts into one "train" split.
dict1 = concatenate_datasets([dataset_dict1["train"], dataset_dict2["train"],dataset_dict3["train"],dataset_dict4["train"], dataset_dict5["train"]])
full_dataset_dict = DatasetDict({"train":dict1})

In [None]:
# Display the combined dataset as the cell output.
full_dataset_dict

In [None]:
from datasets import DatasetDict, Dataset

def transform_dataset(full_dataset_dict):
    """Rebuild every split with the columns id, custom_id, body, pair and category.

    ``original_email`` is carried over as ``body`` and ``rewritten_email`` as
    ``pair``; every row receives the category ``'clean'``.

    Args:
        full_dataset_dict: Mapping of split name to dataset with the columns
            ``id``, ``custom_id``, ``original_email`` and ``rewritten_email``.

    Returns:
        A new ``DatasetDict`` with the same split names.
    """
    # New column name -> column it is copied from.
    column_sources = {
        'id': 'id',
        'custom_id': 'custom_id',
        'body': 'original_email',
        'pair': 'rewritten_email',
    }

    def process_split(dataset):
        columns = {target: dataset[source] for target, source in column_sources.items()}
        columns['category'] = ['clean'] * len(dataset)
        return Dataset.from_dict(columns)

    return DatasetDict({
        split_name: process_split(split)
        for split_name, split in full_dataset_dict.items()
    })


# Replace the parsed dataset with its body/pair/category form.
full_dataset_dict = transform_dataset(full_dataset_dict)

In [None]:
# Display the transformed dataset as the cell output.
full_dataset_dict

In [None]:
import json
import os
import torch
import numpy as np
from google.colab import drive
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
import multiprocessing as mp
from functools import partial
import gc

def _to_jsonable(value):
    """Convert tensors and NumPy arrays to plain lists; return anything else unchanged."""
    if torch.is_tensor(value):
        return value.cpu().tolist()
    if isinstance(value, np.ndarray):
        return value.tolist()
    return value


def _stage_on_device(value, device):
    """Move a tensor, or a list/array that converts to one, onto ``device``."""
    if torch.is_tensor(value):
        return value.to(device) if value.device != device else value
    if isinstance(value, (list, np.ndarray)):
        try:
            return torch.tensor(value, device=device)
        except (TypeError, ValueError, RuntimeError):
            # Non-numeric or ragged content cannot become a tensor; keep it as is.
            return value
    return value


def process_batch_on_gpu(batch_data, features, device):
    """Process a batch of data on GPU for faster operations.

    Each requested feature is staged on ``device`` when it is a tensor or
    converts to one, then converted back to JSON-serialisable Python values.
    Lists and arrays that convert are therefore read back from a torch tensor
    (with whatever dtype ``torch.tensor`` chose for them).

    Args:
        batch_data: Dict mapping feature names to values. Any other type
            produces an empty result.
        features: Feature names to include; names missing from ``batch_data``
            map to ``None``.
        device: Torch device used for staging.

    Returns:
        A list holding one row dict for a dict input, otherwise an empty list.
        If staging fails, the row is built on the CPU from the unstaged values.
    """
    if not isinstance(batch_data, dict):
        return []

    try:
        staged = {
            feature: _stage_on_device(batch_data[feature], device) if feature in batch_data else None
            for feature in features
        }
        # Convert back to CPU values for JSON serialization.
        return [{feature: _to_jsonable(staged[feature]) for feature in features}]

    except Exception as e:
        print(f"GPU processing failed, falling back to CPU: {e}")
        # The fallback yields exactly one row, like the primary path.
        return [{
            feature: _to_jsonable(batch_data[feature]) if feature in batch_data else None
            for feature in features
        }]

def process_dataset_parallel(dataset, phase_name, batch_size=1000, use_gpu=True, num_workers=None):
    """Convert a dataset into JSON-serialisable row dictionaries, batch by batch.

    Despite the name, batches are handled one after another in the calling
    thread; ``num_workers`` only appears in the progress message.

    Args:
        dataset: Dataset exposing ``features``, ``select`` and ``set_format``.
        phase_name: Label used in progress output.
        batch_size: Number of rows selected per batch.
        use_gpu: When true and CUDA is available, each batch is switched to
            torch format on the GPU before its values are read back.
        num_workers: Worker count shown in the progress message; defaults to
            ``min(4, cpu_count)``.

    Returns:
        A dict with ``features`` (column names), ``num_rows`` and ``data`` (a
        list of per-row dicts in which tensors and arrays are plain lists).
    """
    gpu_enabled = use_gpu and torch.cuda.is_available()

    # Setup device
    if gpu_enabled:
        device = torch.device('cuda')
        print(f"Using GPU: {torch.cuda.get_device_name()}")
    else:
        device = torch.device('cpu')
        print("Using CPU for processing")

    # Set number of workers
    if num_workers is None:
        num_workers = min(4, mp.cpu_count())  # Conservative for memory

    features = list(dataset.features.keys())
    total_rows = len(dataset)

    print(f"Processing {phase_name} with {num_workers} workers in batches of {batch_size}...")

    all_data = []

    for start_idx in tqdm(range(0, total_rows, batch_size), desc=f"Processing {phase_name}"):
        end_idx = min(start_idx + batch_size, total_rows)
        batch_data = dataset.select(range(start_idx, end_idx))

        if gpu_enabled:
            try:
                batch_data.set_format(type='torch', device=device)
            except Exception:
                # Best effort: keep the batch in its current format.
                pass

        for i in range(len(batch_data)):
            # Materialise the row once instead of once per feature.
            raw_row = batch_data[i]
            row = {}
            for feature in features:
                data = raw_row[feature]
                if torch.is_tensor(data):
                    row[feature] = data.cpu().tolist() if data.is_cuda else data.tolist()
                elif isinstance(data, np.ndarray):
                    row[feature] = data.tolist()
                else:
                    row[feature] = data
            all_data.append(row)

        # Clear GPU cache periodically
        if gpu_enabled:
            torch.cuda.empty_cache()

    return {
        'features': features,
        'num_rows': len(all_data),
        'data': all_data
    }

def save_data(dataset_dict, filename='dataset.json', output_dir='/content/drive/MyDrive/Algoverse/',
              mount_drive=True, use_gpu=True, batch_size=1000, num_workers=None, use_parallel_phases=True):
    """
    Serialise every phase of a dataset dictionary into a single JSON file.

    Each phase is converted with ``process_dataset_parallel`` and stored under
    its phase name; the resulting mapping is written to
    ``output_dir/filename``.

    Args:
        dataset_dict: Dictionary of datasets to save
        filename: Output filename
        output_dir: Output directory (created if missing)
        mount_drive: Whether to mount Google Drive first
        use_gpu: Whether to request GPU processing (ignored without CUDA)
        batch_size: Batch size passed to the per-phase conversion
        num_workers: Worker count; defaults to ``min(4, cpu_count)``
        use_parallel_phases: Whether to convert phases in separate threads
            when there is more than one phase

    Returns:
        The path of the written JSON file.
    """
    if mount_drive:
        drive.mount('/content/drive')

    os.makedirs(output_dir, exist_ok=True)

    # Without CUDA the GPU flag is switched off for everything that follows.
    if not (use_gpu and torch.cuda.is_available()):
        print("Using CPU processing")
        use_gpu = False
    else:
        print(f"GPU acceleration enabled: {torch.cuda.get_device_name()}")
        print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
        torch.cuda.empty_cache()

    if num_workers is None:
        num_workers = min(4, mp.cpu_count())

    json_data = {}
    run_in_threads = use_parallel_phases and len(dataset_dict) > 1

    if not run_in_threads:
        # One phase after another in the calling thread.
        for phase_name, dataset in dataset_dict.items():
            phase_result = process_dataset_parallel(dataset, phase_name, batch_size, use_gpu, num_workers)
            json_data[phase_name] = phase_result
            print(f"Completed {phase_name}: {phase_result['num_rows']} rows")
    else:
        print(f"Processing {len(dataset_dict)} phases in parallel...")

        pool_size = min(len(dataset_dict), num_workers)
        with ThreadPoolExecutor(max_workers=pool_size) as executor:
            # One task per phase; each task reports a worker count of 1.
            future_to_phase = {}
            for phase_name, dataset in dataset_dict.items():
                task = executor.submit(process_dataset_parallel, dataset, phase_name, batch_size, use_gpu, 1)
                future_to_phase[task] = phase_name

            # A failed phase is reported and left out of the output.
            for task in tqdm(future_to_phase, desc="Processing phases"):
                phase_name = future_to_phase[task]
                try:
                    phase_result = task.result()
                    json_data[phase_name] = phase_result
                    print(f"Completed {phase_name}: {phase_result['num_rows']} rows")
                except Exception as e:
                    print(f"Error processing {phase_name}: {e}")

    output_path = os.path.join(output_dir, filename)
    print(f"Saving data to {output_path}...")

    try:
        with open(output_path, 'w', encoding='utf-8') as sink:
            json.dump(json_data, sink, indent=2, ensure_ascii=False, default=str)

        print(f"Successfully saved dataset to {output_path}")
        print(f"File size: {os.path.getsize(output_path) / 1e6:.1f} MB")

    except Exception as e:
        print(f"Error saving file: {e}")
        # Second attempt without indentation, which produces a smaller file.
        with open(output_path, 'w', encoding='utf-8') as sink:
            json.dump(json_data, sink, ensure_ascii=False, default=str)
        print(f"Saved without formatting due to memory constraints")

    # Release cached GPU memory once the file is written.
    if use_gpu and torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()

    return output_path

In [None]:
def split_dataset(dataset_dict, number=32):
    """Split every phase into ``number`` contiguous chunks and save each chunk.

    A chunk holds ``len(dataset) // number`` rows; the last chunk of a phase
    also takes the remainder rows. Every chunk is wrapped in a single-phase
    ``DatasetDict`` and written with ``save_data`` under the file name
    ``token_<k>``, where ``k`` counts chunks across all phases in order.

    Args:
        dataset_dict: Mapping of phase name to dataset.
        number: Number of chunks per phase.

    Returns:
        The list of chunk ``DatasetDict`` objects, in the order they were saved.
    """
    chunks = []
    for phase_name, dataset in dataset_dict.items():
        total_rows = len(dataset)
        chunk_size = total_rows // number

        for i in range(number):
            start = chunk_size * i
            # The final chunk absorbs rows left over by the integer division.
            stop = total_rows if i == number - 1 else chunk_size * (i + 1)
            chunks.append(DatasetDict({phase_name: dataset.select(range(start, stop))}))

    for index, chunk in enumerate(chunks):
        save_data(chunk, filename="token_" + str(index))

    return chunks


In [None]:
# Split the dataset with pairs into chunks (default: 32 per phase), save them,
# and display whatever split_dataset returned.
batches = split_dataset(dataset_clean_with_pairs)
batches

# Tokenization and Embedding Generation

The code in this section implements a data pipeline with GPU-accelerated loading, stratified sampling, and dataset combination functions. It reconstructs datasets from JSON files, preserves category distributions when creating balanced subsets of the prompt-injection and clean email data, and combines them into training and test datasets with the appropriate similarity scores. The pipeline uses parallel processing for efficient data handling and concludes with Hugging Face authentication for uploading the processed datasets to the Hugging Face Hub.


In [None]:
import json
from datasets import Dataset, DatasetDict
import os
from google.colab import drive
from tqdm import tqdm
def load_data(json_file_path, mount_drive=True):
    """Rebuild a ``DatasetDict`` from a JSON file of per-phase row lists.

    The file is expected to map each phase name to an object whose ``data``
    entry is a list of row dicts; each phase becomes one dataset.

    Args:
        json_file_path: Path of the JSON file to read.
        mount_drive: Whether to mount Google Drive before reading.

    Returns:
        A ``DatasetDict`` with one dataset per phase in the file.
    """
    if mount_drive:
        drive.mount('/content/drive')

    print(f"Loading data from {json_file_path}...")
    with open(json_file_path, 'r', encoding='utf-8') as source:
        payload = json.load(source)

    phases = {}
    for phase_name, phase_info in payload.items():
        print(f"Processing {phase_name}...")
        rows = phase_info['data']
        phases[phase_name] = Dataset.from_list(rows)

    loaded = DatasetDict(phases)
    print("Successfully loaded DatasetDict!")
    return loaded

In [None]:
# Give every row of the clean-pair dataset a similarity of 1.0.
full_dataset_dict["train"]=full_dataset_dict["train"].add_column("similarity", [1.0] * len(full_dataset_dict["train"]))

In [None]:
# Display the dataset with its new similarity column.
full_dataset_dict

In [None]:
from datasets import concatenate_datasets, DatasetDict
import torch
from concurrent.futures import ThreadPoolExecutor
import multiprocessing as mp

def load_and_combine_datasets(use_gpu=True, num_workers=None,
                              base_path="/content/drive/MyDrive/Algoverse/token_", num_shards=32):
    """Load the numbered shard files and combine them into one "train" split.

    Shard ``i`` is read with ``load_data`` from ``f"{base_path}{i}"`` for every
    ``i`` in ``range(num_shards)``. Shards that fail to load are reported and
    skipped; every split of every loaded shard is concatenated in shard order.

    Args:
        use_gpu (bool): Whether to request torch format on the GPU when CUDA
            is available
        num_workers (int): Number of parallel workers for loading
            (None = ``min(8, cpu_count)``)
        base_path (str): Path prefix to which the shard index is appended
        num_shards (int): Number of shards to load, starting at index 0

    Returns:
        A ``DatasetDict`` whose ``'train'`` split holds all loaded rows.

    Raises:
        ValueError: If no shard could be loaded.
    """
    gpu_enabled = use_gpu and torch.cuda.is_available()

    # Check GPU availability
    if gpu_enabled:
        device = torch.device('cuda')
        print(f"GPU detected: {torch.cuda.get_device_name()}")
        print(f"GPU memory available: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    else:
        device = torch.device('cpu')
        print("Using CPU for operations")

    # Set number of workers for parallel loading
    if num_workers is None:
        num_workers = min(8, mp.cpu_count())  # Reasonable default

    all_datasets = []

    print(f"Loading datasets with {num_workers} parallel workers...")

    def load_single_dataset(i):
        """Load shard ``i``; return ``(i, dataset)`` or ``(i, None)`` on failure."""
        dataset_path = f"{base_path}{i}"
        try:
            # NOTE(review): load_data mounts Drive on every call by default, so
            # each worker thread triggers its own mount -- confirm this is intended.
            dataset = load_data(dataset_path)
            return i, dataset
        except Exception as e:
            print(f"Warning: Could not load dataset {i}: {e}")
            return i, None

    def request_torch_format(data):
        """Best-effort switch of ``data`` to torch format on the selected device."""
        if not gpu_enabled:
            return
        try:
            data.set_format(type='torch', device=device)
        except Exception:
            # Some datasets might not support torch format; keep their format.
            pass

    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = [executor.submit(load_single_dataset, i) for i in range(num_shards)]

        # Results are consumed in submission order, so shards are combined in
        # index order regardless of which load finishes first.
        for future in futures:
            i, dataset = future.result()
            if dataset is None:
                continue

            print(f"Loaded dataset {i}")

            # Handle different dataset structures
            if isinstance(dataset, DatasetDict):
                for split_name, split_data in dataset.items():
                    print(f"  Split '{split_name}': {len(split_data)} rows")
                    request_torch_format(split_data)
                    all_datasets.append(split_data)
            else:
                print(f"  Dataset: {len(dataset)} rows")
                request_torch_format(dataset)
                all_datasets.append(dataset)

    if not all_datasets:
        raise ValueError("No datasets were successfully loaded!")

    # Combine all datasets
    print(f"\nCombining {len(all_datasets)} dataset splits...")
    combined_dataset = concatenate_datasets(all_datasets)

    # Set format for GPU on final dataset if requested
    if gpu_enabled:
        try:
            combined_dataset.set_format(type='torch', device=device)
            print(f"Dataset moved to GPU: {device}")
        except Exception as e:
            print(f"Could not move dataset to GPU: {e}")

    # Create a DatasetDict with train split
    final_dataset = DatasetDict({
        'train': combined_dataset
    })

    print(f"Successfully combined all datasets!")
    print(f"Final dataset shape: {len(combined_dataset)} rows")
    print(f"Features: {list(combined_dataset.features.keys())}")

    return final_dataset

# Load all shards from Drive with eight loader threads and show the result.
combined_dataset = load_and_combine_datasets(use_gpu=True, num_workers=8)

print(f"\nFinal combined dataset:")
print(combined_dataset)

In [None]:
# Give every row of the dataset loaded from Drive a similarity of 0.0.
combined_dataset["train"] = combined_dataset["train"].add_column("similarity", [0.0] * len(combined_dataset["train"]))

In [None]:
# Drop the two embedding columns from the combined dataset.
combined_dataset["train"] = combined_dataset["train"].remove_columns(["body_embeddings", "pair_embeddings"])

In [None]:
from datasets import Dataset, DatasetDict

def filter_dataset_columns(dataset, keep_columns=None):
    """
    Return a copy of the dataset restricted to the requested columns.

    Accepts a ``Dataset`` or a ``DatasetDict``; for a ``DatasetDict`` every
    split is filtered separately. Requested columns that a dataset does not
    have are reported with a warning and ignored.

    Args:
        dataset: The input dataset (Dataset or DatasetDict)
        keep_columns: List of column names to keep
            (default: ['body', 'pair', 'body_embeddings', 'pair_embeddings'])

    Returns:
        New filtered dataset with only the available requested columns
        (same type as input)
    """
    if keep_columns is None:
        keep_columns = ['body', 'pair', 'body_embeddings', 'pair_embeddings']

    def keep_available(split):
        """Report the column selection for one dataset and apply it."""
        present = split.column_names
        selected = [name for name in keep_columns if name in present]

        print(f"Original columns: {present}")
        print(f"Requested columns: {keep_columns}")
        print(f"Columns to keep (available): {selected}")

        absent = [name for name in keep_columns if name not in present]
        if absent:
            print(f"Warning: These requested columns don't exist: {absent}")

        return split.select_columns(selected)

    # A single Dataset is filtered directly.
    if not isinstance(dataset, DatasetDict):
        return keep_available(dataset)

    # A DatasetDict is filtered split by split.
    filtered_splits = {}
    for split_name, split in dataset.items():
        print(f"\nProcessing split: {split_name}")
        filtered_splits[split_name] = keep_available(split)

    return DatasetDict(filtered_splits)

# Example usage:
def process_dataset(input_dataset):
    """
    Reduce a dataset to the body, pair, category and similarity columns.

    Delegates to ``filter_dataset_columns``, so both Dataset and DatasetDict
    inputs are supported.

    Args:
        input_dataset: The dataset to process (Dataset or DatasetDict)

    Returns:
        New filtered dataset (same type as input)
    """
    wanted_columns = ['body', 'pair', "category", "similarity"]
    return filter_dataset_columns(input_dataset, keep_columns=wanted_columns)

In [None]:
# Keep only body/pair/category/similarity in the combined dataset and display it.
filtered_dataset = process_dataset(combined_dataset)
filtered_dataset

In [None]:
# Keep only body/pair/category/similarity in the clean-pair dataset and display it.
filtered_clean = process_dataset(full_dataset_dict)
filtered_clean

In [None]:
import pandas as pd
from datasets import Dataset, DatasetDict
import numpy as np

def sample_dataset_stratified(dataset_dict, sample_fraction=0.5, category_column='category', random_state=42):
    """Split every dataset into a stratified sample and the remaining rows.

    Within each split, ``sample_fraction`` of every category (at least one
    row) is drawn with the given seed. The drawn rows and the leftover rows
    are each shuffled and returned as separate datasets.

    Args:
        dataset_dict: Mapping of split name to dataset.
        sample_fraction: Share of each category to draw.
        category_column: Column holding the category labels.
        random_state: Seed for sampling and shuffling; it also seeds NumPy's
            global random generator.

    Returns:
        ``(sampled, remaining)``, two ``DatasetDict`` objects with the same
        split names as the input.
    """
    np.random.seed(random_state)

    def combine_and_shuffle(parts, seed):
        merged = pd.concat(parts, ignore_index=True)
        return merged.sample(frac=1, random_state=seed).reset_index(drop=True)

    sampled_splits = {}
    remaining_splits = {}

    for split_name, dataset in dataset_dict.items():
        frame = dataset.to_pandas()
        category_counts = frame[category_column].value_counts()
        target_size = int(len(frame) * sample_fraction)

        print(f"{split_name}: {len(frame)} → {target_size} rows sampled, {len(frame) - target_size} rows remaining")

        drawn_parts = []
        leftover_parts = []

        for category in category_counts.index:
            group = frame[frame[category_column] == category]
            wanted = max(1, int(len(group) * sample_fraction))

            drawn = group.sample(
                n=min(wanted, len(group)),
                random_state=random_state
            )

            drawn_parts.append(drawn)
            # Everything in the category that was not drawn.
            leftover_parts.append(group.drop(drawn.index))

        sampled_splits[split_name] = Dataset.from_pandas(
            combine_and_shuffle(drawn_parts, random_state)
        )
        # A different seed keeps the two shuffles independent.
        remaining_splits[split_name] = Dataset.from_pandas(
            combine_and_shuffle(leftover_parts, random_state + 1)
        )

    return DatasetDict(sampled_splits), DatasetDict(remaining_splits)

def verify_distribution(original_dict, sampled_dict, remaining_dict=None, category_column='category'):
    """Print row counts and per-category percentages for each split.

    For every split name in ``original_dict`` the function prints the number
    of rows in the original and sampled data (and in the remaining data when
    ``remaining_dict`` is given and non-empty), followed by one line per
    category found in the original split showing its share in each dataset.
    Categories absent from the sampled/remaining data are shown as 0.0%.

    Args:
        original_dict: Mapping of split name to an object with ``to_pandas()``.
        sampled_dict: Same structure, holding the sampled rows.
        remaining_dict: Optional same structure, holding the leftover rows.
        category_column: Column whose value distribution is compared.

    Returns:
        None. All output goes to stdout.
    """
    print("\nDistribution Verification:")
    for split_name in original_dict.keys():
        source_frame = original_dict[split_name].to_pandas()
        sample_frame = sampled_dict[split_name].to_pandas()

        source_share = source_frame[category_column].value_counts(normalize=True)
        sample_share = sample_frame[category_column].value_counts(normalize=True)

        print(f"\n{split_name}:")
        print(f"  Original: {len(source_frame)} rows")
        print(f"  Sampled:  {len(sample_frame)} rows")

        # Stays None when no remaining data was supplied.
        rest_share = None
        if remaining_dict:
            rest_frame = remaining_dict[split_name].to_pandas()
            rest_share = rest_frame[category_column].value_counts(normalize=True)
            print(f"  Remaining: {len(rest_frame)} rows")

        print("  Category distributions:")
        for category, share in source_share.items():
            line = (
                f"    {category}: Original {share * 100:.1f}%"
                f" → Sampled {sample_share.get(category, 0) * 100:.1f}%"
            )
            if rest_share is not None:
                line += f" | Remaining {rest_share.get(category, 0) * 100:.1f}%"
            print(line)

# Split every split of filtered_dataset roughly 50/50, stratified by category,
# then print the per-category percentages of both halves next to the original's.
sampled_dataset, remaining_dataset = sample_dataset_stratified(filtered_dataset, sample_fraction=0.5)
verify_distribution(filtered_dataset, sampled_dataset, remaining_dataset)

In [None]:
# Stratified 70/30 splits: s1 / t1 receive about 70% of each category,
# s2 / t2 receive the rows that were not drawn.
s1, s2 = sample_dataset_stratified(sampled_dataset, sample_fraction=0.7)
t1, t2 = sample_dataset_stratified(filtered_clean, sample_fraction=0.7)

In [None]:
from datasets import concatenate_datasets, DatasetDict
# Training data = the ~70% portions (s1 + t1); test data = the leftover portions (s2 + t2).
# Both results keep their rows under a single "train" split key, so later cells
# index test_dataset["train"] as well.
train_dataset = DatasetDict({"train":concatenate_datasets([s1["train"], t1["train"]])})
test_dataset = DatasetDict({"train":concatenate_datasets([s2["train"], t2["train"]])})

In [None]:
from huggingface_hub import login
from google.colab import userdata
# Authenticate with the Hugging Face Hub using the secret stored under
# "HUGGINGFACE_TOKEN" in Colab's userdata.
login(userdata.get("HUGGINGFACE_TOKEN"))

# Model Training

This section implements four fine-tuning pipelines for sentence-similarity models built on Sentence-BERT all-mpnet-base-v2, LLaMA-3-8B, Mistral-7B, and Qwen2-7B. Each pipeline sets up a RunPod cloud environment with memory and disk cleanup and trains on sentence pairs with a cosine similarity loss. The three large language models are additionally converted to SentenceTransformer-compatible encoders using LoRA and 4-bit quantization (with a BF16 fallback), while MPNet is fine-tuned directly in FP32. All pipelines follow the same framework but differ in base model architecture, precision settings, and the attention projections that LoRA adapts.

## MPNET-V2

In [None]:
import os
import torch
import shutil
from datasets import Dataset
from sentence_transformers import SentenceTransformer, InputExample, losses
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.evaluation import BinaryClassificationEvaluator

# ------------------ Paths ------------------
# Models, caches and checkpoints all live under /workspace.
BASE_PATH = "/workspace"
MODELS_PATH = os.path.join(BASE_PATH, "models")
CACHE_PATH = os.path.join(BASE_PATH, "cache")
CHECKPOINTS_PATH = os.path.join(BASE_PATH, "checkpoints")

os.makedirs(MODELS_PATH, exist_ok=True)
os.makedirs(CACHE_PATH, exist_ok=True)
os.makedirs(CHECKPOINTS_PATH, exist_ok=True)

# Redirect library caches and temp files into CACHE_PATH.
# NOTE(review): these are assigned after the Hugging Face imports at the top of
# this cell; libraries that read them at import time will not see the new
# values. The loaders below pass the cache directory explicitly.
os.environ['HF_HOME'] = CACHE_PATH
os.environ['TRANSFORMERS_CACHE'] = os.path.join(CACHE_PATH, "transformers")
os.environ['HF_DATASETS_CACHE'] = os.path.join(CACHE_PATH, "datasets")
os.environ['TORCH_HOME'] = os.path.join(CACHE_PATH, "torch")
os.environ['TMPDIR'] = os.path.join(CACHE_PATH, "tmp")
# TMPDIR has to exist to be usable as a temp directory; only CACHE_PATH itself
# is created above, not its "tmp" subdirectory.
os.makedirs(os.environ['TMPDIR'], exist_ok=True)

# ------------------ Force FP32 ------------------
# NOTE(review): confirm ACCELERATE_DISABLE_FP16 is honored by the installed
# accelerate version; fp16/bf16 are also switched off in the training arguments.
os.environ["ACCELERATE_DISABLE_FP16"] = "1"  # intended to disable all FP16/BF16
os.environ["PYTORCH_NO_CUDA_MEMORY_CACHING"] = "1"

# ------------------ Cleanup ------------------
def reset_memory_disk():
    """Release cached CUDA memory and empty the temp and Transformers cache directories.

    Every entry inside ``$TMPDIR`` and ``$TRANSFORMERS_CACHE`` is deleted; the
    directories themselves are kept. Deletion is best-effort: a failure is
    printed and the entry is skipped. Emptying ``$TRANSFORMERS_CACHE`` also
    removes model files previously downloaded into it.

    Raises:
        KeyError: If ``TMPDIR`` or ``TRANSFORMERS_CACHE`` is not set.
    """
    # torch.cuda.ipc_collect() initializes CUDA and fails on hosts without a
    # GPU, so only touch CUDA when it is available. This keeps the CPU
    # fallback of the model loaders usable.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

    for directory in (os.environ['TMPDIR'], os.environ['TRANSFORMERS_CACHE']):
        if not os.path.isdir(directory):
            continue
        for entry in os.listdir(directory):
            entry_path = os.path.join(directory, entry)
            try:
                if os.path.isfile(entry_path) or os.path.islink(entry_path):
                    os.unlink(entry_path)
                elif os.path.isdir(entry_path):
                    shutil.rmtree(entry_path)
            except Exception as e:
                # Best-effort cleanup: report and keep going.
                print(f"Failed to delete {entry_path}: {e}")
    print("Memory and disk reset done.")

# ------------------ Dataset Helpers ------------------
def prepare_dataset_for_trainer(dataset):
    """Convert rows with ``body``/``pair``/``similarity`` into a trainer-ready ``Dataset``.

    Each input row becomes ``{'sentence1': body, 'sentence2': pair,
    'label': float(similarity)}``. For plain ``dict`` rows a missing ``body``
    or ``pair`` becomes ``None`` and a missing ``similarity`` becomes ``0.0``;
    any other row type is indexed directly.

    Args:
        dataset: Iterable of rows.

    Returns:
        A ``Dataset`` built with ``Dataset.from_list``.
    """
    def _to_row(item):
        if isinstance(item, dict):
            return {
                'sentence1': item.get('body'),
                'sentence2': item.get('pair'),
                'label': float(item.get('similarity', 0.0)),
            }
        return {
            'sentence1': item['body'],
            'sentence2': item['pair'],
            'label': float(item['similarity']),
        }

    return Dataset.from_list([_to_row(item) for item in dataset])

def convert_to_input_examples(dataset):
    """Convert rows with ``body``/``pair``/``similarity`` into ``InputExample`` objects.

    For plain ``dict`` rows a missing ``body`` or ``pair`` becomes ``None``
    and a missing ``similarity`` becomes ``0.0``; any other row type is
    indexed directly.

    Args:
        dataset: Iterable of rows.

    Returns:
        A list with one ``InputExample(texts=[body, pair], label=similarity)``
        per row, in input order.
    """
    def _to_example(item):
        if isinstance(item, dict):
            first = item.get('body')
            second = item.get('pair')
            score = float(item.get('similarity', 0.0))
        else:
            first = item['body']
            second = item['pair']
            score = float(item['similarity'])
        return InputExample(texts=[first, second], label=score)

    return [_to_example(item) for item in dataset]

# ------------------ Model Loader ------------------
def create_sentence_transformer_mpnet():
    """Load ``all-mpnet-base-v2`` as a SentenceTransformer, with a MiniLM fallback.

    The model is placed on CUDA when available, otherwise on CPU, and is
    downloaded into ``$TRANSFORMERS_CACHE``. If loading MPNet raises any
    exception, ``all-MiniLM-L6-v2`` is loaded instead. This is a different
    model, so check the printed messages to see which one was used.

    Returns:
        The loaded ``SentenceTransformer`` with ``max_seq_length`` set to 512.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    def _load(model_name):
        # The cache location is read at call time, once per load attempt.
        return SentenceTransformer(
            model_name,
            cache_folder=os.environ['TRANSFORMERS_CACHE'],
            device=device,
        )

    print(f"Loading all-mpnet-base-v2 on device: {device}")
    try:
        model = _load("sentence-transformers/all-mpnet-base-v2")
        print("Loaded all-mpnet-base-v2 model.")
    except Exception as e:
        print(f"Failed to load all-mpnet-base-v2: {e}, falling back to MiniLM")
        model = _load("sentence-transformers/all-MiniLM-L6-v2")
        print("Loaded fallback MiniLM model.")

    model.max_seq_length = 512
    return model

# ------------------ Training Function ------------------
def train_mpnet_similarity_model(train_dataset, test_dataset, output_dir):
    """Fine-tune the MPNet sentence encoder with a cosine similarity loss.

    The function clears caches, loads the model, converts both datasets to the
    ``sentence1``/``sentence2``/``label`` format, trains for three epochs in
    FP32 and saves the result to ``<output_dir>/final_model``. Every 500 steps
    the trainer evaluates on the test data (evaluation loss plus the
    evaluator) and writes a checkpoint.

    Args:
        train_dataset: Iterable of rows with ``body``, ``pair`` and ``similarity``.
        test_dataset: Same row format; used for evaluation only.
        output_dir: Directory for checkpoints and the final model.

    Returns:
        ``(model, final_model_path)``.
    """
    reset_memory_disk()
    os.makedirs(output_dir, exist_ok=True)

    model = create_sentence_transformer_mpnet()

    train_data = prepare_dataset_for_trainer(train_dataset)
    test_data = prepare_dataset_for_trainer(test_dataset)
    test_examples = convert_to_input_examples(test_dataset)

    train_loss = losses.CosineSimilarityLoss(model)

    # NOTE(review): BinaryClassificationEvaluator is meant for 0/1 labels;
    # confirm that `similarity` only takes those two values.
    evaluator = BinaryClassificationEvaluator(
        sentences1=[ex.texts[0] for ex in test_examples],
        sentences2=[ex.texts[1] for ex in test_examples],
        labels=[float(ex.label) for ex in test_examples],
        show_progress_bar=True
    )

    args = SentenceTransformerTrainingArguments(
        output_dir=output_dir,
        num_train_epochs=3,
        per_device_train_batch_size=4,  # reduced for MPNet memory
        gradient_accumulation_steps=8,  # simulate larger batch
        learning_rate=2e-5,
        warmup_ratio=0.1,
        fp16=False,  # FP32 only
        bf16=False,
        max_grad_norm=1.0,
        save_strategy="steps",
        save_steps=500,
        save_total_limit=2,
        save_only_model=True,
        logging_steps=50,
        eval_strategy="steps",
        eval_steps=500,
        dataloader_num_workers=0,
        remove_unused_columns=False,
        run_name="mpnet-similarity-finetuning"
    )

    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_data,
        # eval_strategy="steps" requests periodic evaluation, so hand the
        # trainer the prepared test set instead of relying on the evaluator
        # alone (some transformers versions reject a missing eval_dataset).
        eval_dataset=test_data,
        loss=train_loss,
        evaluator=evaluator
    )

    print("Starting training...")
    trainer.train()

    final_model_path = os.path.join(output_dir, "final_model")
    model.save(final_model_path)
    print(f"Training complete. Model saved at: {final_model_path}")
    return model, final_model_path

# ------------------ Inference Helpers ------------------
def load_trained_model(model_path):
    """Load a saved SentenceTransformer from ``model_path``.

    The model is placed on CUDA when available, otherwise on CPU.

    Args:
        model_path: Path (or model name) accepted by ``SentenceTransformer``.

    Returns:
        The loaded ``SentenceTransformer``.
    """
    target_device = "cpu"
    if torch.cuda.is_available():
        target_device = "cuda"
    loaded = SentenceTransformer(model_path, device=target_device)
    print(f"Model loaded from {model_path}")
    return loaded

def compute_similarity(model, text1, text2):
    """Return the cosine similarity between the embeddings of two texts.

    Args:
        model: Object whose ``encode(list_of_texts, convert_to_tensor=True)``
            returns one embedding row per text.
        text1: First text.
        text2: Second text.

    Returns:
        The cosine similarity as a Python float.
    """
    encoded = model.encode([text1, text2], convert_to_tensor=True)
    # Slices keep the leading dimension so cosine_similarity compares row to row.
    first_vec = encoded[0:1]
    second_vec = encoded[1:2]
    score = torch.nn.functional.cosine_similarity(first_vec, second_vec)
    return score.item()

# ------------------ Example Run ------------------
if __name__ == "__main__":
    output_dir = os.path.join(CHECKPOINTS_PATH, "mpnet_similarity_model")

    # train_dataset, test_dataset and sample_dataset_stratified come from
    # earlier notebook cells. When one of them is missing Python raises
    # NameError (not ImportError), so that is what the hint is tied to; only
    # the lines that use those names are guarded.
    try:
        # Train on a 10% stratified subsample of each set.
        train_dataset_s, _ = sample_dataset_stratified(train_dataset, sample_fraction=0.1)
        test_dataset_s, _  = sample_dataset_stratified(test_dataset, sample_fraction=0.1)
    except NameError:
        print("Please define your train_dataset, test_dataset, and sample_dataset_stratified function before running this example.")
    else:
        model, model_path = train_mpnet_similarity_model(
            train_dataset_s["train"],
            test_dataset_s["train"],
            output_dir
        )

        # Example inference
        embeddings = model.encode(["This is a test sentence."], convert_to_tensor=True)
        print("Example embedding:", embeddings)

        sim_score = compute_similarity(model, "This is a test sentence.", "Another test sentence.")
        print("Similarity score:", sim_score)


## LLAMA-3 (8 billion)

In [None]:
# Bare expression so the notebook displays the current train_dataset object.
train_dataset

In [None]:
import os
import torch
import shutil
from datasets import Dataset
from sentence_transformers import SentenceTransformer, InputExample, losses
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.models import Transformer, Pooling
from sentence_transformers.evaluation import BinaryClassificationEvaluator

# New imports
from transformers import AutoTokenizer, AutoModel, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

# ------------------ RunPod Network Paths ------------------
# Models, caches and checkpoints all live under /workspace.
BASE_PATH = "/workspace"
MODELS_PATH = os.path.join(BASE_PATH, "models")
CACHE_PATH = os.path.join(BASE_PATH, "cache")
CHECKPOINTS_PATH = os.path.join(BASE_PATH, "checkpoints")

os.makedirs(MODELS_PATH, exist_ok=True)
os.makedirs(CACHE_PATH, exist_ok=True)
os.makedirs(CHECKPOINTS_PATH, exist_ok=True)

# Force all caches to network volume
# NOTE(review): these are assigned after the Hugging Face imports at the top of
# this cell; libraries that read them at import time will not see the new
# values. The loaders below pass the cache directory explicitly.
os.environ['HF_HOME'] = CACHE_PATH
os.environ['TRANSFORMERS_CACHE'] = os.path.join(CACHE_PATH, "transformers")
os.environ['HF_DATASETS_CACHE'] = os.path.join(CACHE_PATH, "datasets")
os.environ['TORCH_HOME'] = os.path.join(CACHE_PATH, "torch")
os.environ['TMPDIR'] = os.path.join(CACHE_PATH, "tmp")
# TMPDIR has to exist to be usable as a temp directory; only CACHE_PATH itself
# is created above, not its "tmp" subdirectory.
os.makedirs(os.environ['TMPDIR'], exist_ok=True)

# ------------------ Memory / Disk Cleanup ------------------
def reset_memory_disk():
    """Release cached CUDA memory and empty the temp and Transformers cache directories.

    Every entry inside ``$TMPDIR`` and ``$TRANSFORMERS_CACHE`` is deleted; the
    directories themselves are kept. Deletion is best-effort: a failure is
    printed and the entry is skipped. Emptying ``$TRANSFORMERS_CACHE`` also
    removes model files previously downloaded into it.

    Raises:
        KeyError: If ``TMPDIR`` or ``TRANSFORMERS_CACHE`` is not set.
    """
    # torch.cuda.ipc_collect() initializes CUDA and fails on hosts without a
    # GPU, so only touch CUDA when it is available. This keeps the CPU
    # fallback of the model loaders usable.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

    for directory in (os.environ['TMPDIR'], os.environ['TRANSFORMERS_CACHE']):
        if not os.path.isdir(directory):
            continue
        for entry in os.listdir(directory):
            entry_path = os.path.join(directory, entry)
            try:
                if os.path.isfile(entry_path) or os.path.islink(entry_path):
                    os.unlink(entry_path)
                elif os.path.isdir(entry_path):
                    shutil.rmtree(entry_path)
            except Exception as e:
                # Best-effort cleanup: report and keep going.
                print(f"Failed to delete {entry_path}: {e}")
    print("Memory and disk reset done.")

# ------------------ Dataset Preparation ------------------
def prepare_dataset_for_trainer(dataset):
    """Build a ``Dataset`` of ``sentence1``/``sentence2``/``label`` rows for the trainer.

    ``body`` maps to ``sentence1``, ``pair`` to ``sentence2`` and
    ``similarity`` (converted to ``float``) to ``label``. Plain ``dict`` rows
    are read with ``.get`` (missing ``body``/``pair`` give ``None``, missing
    ``similarity`` gives ``0.0``); other row types are indexed directly.

    Args:
        dataset: Iterable of rows.

    Returns:
        A ``Dataset`` created with ``Dataset.from_list``.
    """
    def _read(item, key, *fallback):
        if isinstance(item, dict):
            return item.get(key, *fallback)
        return item[key]

    rows = [
        {
            'sentence1': _read(item, 'body'),
            'sentence2': _read(item, 'pair'),
            'label': float(_read(item, 'similarity', 0.0)),
        }
        for item in dataset
    ]
    return Dataset.from_list(rows)

def convert_to_input_examples(dataset):
    """Build one ``InputExample`` per row from ``body``, ``pair`` and ``similarity``.

    Plain ``dict`` rows are read with ``.get`` (missing ``body``/``pair`` give
    ``None``, missing ``similarity`` gives ``0.0``); other row types are
    indexed directly.

    Args:
        dataset: Iterable of rows.

    Returns:
        A list of ``InputExample(texts=[body, pair], label=float(similarity))``
        in input order.
    """
    def _read(item, key, *fallback):
        if isinstance(item, dict):
            return item.get(key, *fallback)
        return item[key]

    return [
        InputExample(
            texts=[_read(item, 'body'), _read(item, 'pair')],
            label=float(_read(item, 'similarity', 0.0)),
        )
        for item in dataset
    ]

# ------------------ LLaMA-based SentenceTransformer ------------------
def create_sentence_transformer_from_llama(hf_model_id="meta-llama/Meta-Llama-3-8B-Instruct"):
    """Build a LoRA-adapted, mean-pooled SentenceTransformer on a LLaMA checkpoint.

    The backbone is loaded inside the sentence-transformers ``Transformer``
    module: first as a 4-bit NF4 quantized model and, if that raises, in
    BF16. Gradient checkpointing is enabled, LoRA adapters are attached to
    the ``q_proj`` and ``v_proj`` modules, and token embeddings are
    mean-pooled.

    Args:
        hf_model_id: Hugging Face model id of the backbone.

    Returns:
        A ``SentenceTransformer`` with modules ``[Transformer, Pooling]``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Loading LLaMA model on device: {device}")

    cache_dir = os.environ['TRANSFORMERS_CACHE']

    # Load tokenizer; fall back to EOS as the padding token when none is defined.
    tokenizer = AutoTokenizer.from_pretrained(hf_model_id, cache_dir=cache_dir)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    def _load_backbone(model_args):
        # model_args is forwarded to the underlying from_pretrained call, so
        # the weights are loaded once, already in the requested precision.
        return Transformer(
            model_name_or_path=hf_model_id,
            max_seq_length=256,   # reduced for speed/memory
            cache_dir=cache_dir,
            model_args=model_args,
        )

    # Try 4-bit first, fallback to BF16
    try:
        print("Attempting 4-bit quantized load...")
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16
        )
        word_embedding_model = _load_backbone(
            {"quantization_config": quant_config, "device_map": "auto"}
        )
        print("Loaded model in 4-bit mode.")
    except Exception as e:
        print(f"⚠️  4-bit load failed: {e}")
        print("Falling back to BF16 full precision...")
        word_embedding_model = _load_backbone(
            {"torch_dtype": torch.bfloat16, "device_map": "auto"}
        )

    word_embedding_model.tokenizer = tokenizer
    base_model = word_embedding_model.auto_model

    # Enable gradient checkpointing to save VRAM
    base_model.gradient_checkpointing_enable()

    # LoRA config for efficient fine-tuning
    lora_config = LoraConfig(
        r=32,
        lora_alpha=16,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.1,
        bias="none",
        task_type="FEATURE_EXTRACTION"
    )
    word_embedding_model.auto_model = get_peft_model(base_model, lora_config)

    pooling_model = Pooling(
        word_embedding_model.get_word_embedding_dimension(),
        pooling_mode_mean_tokens=True,
        pooling_mode_cls_token=False,
        pooling_mode_max_tokens=False
    )

    model = SentenceTransformer(modules=[word_embedding_model, pooling_model], device=device)
    print("✅ LLaMA-based SentenceTransformer initialized (LoRA + checkpointing).")
    return model

# ------------------ Training Function ------------------
def train_llama_similarity_model(train_dataset, test_dataset, output_dir):
    """Fine-tune the LLaMA-based sentence encoder with a cosine similarity loss.

    The function clears caches, builds the model, trains for one epoch in
    BF16 and saves the result to ``<output_dir>/final_model``.

    The training arguments set no evaluation strategy, so the evaluator built
    from ``test_dataset`` is only attached to the trainer.
    NOTE(review): with the trainer defaults it is not run during training;
    call ``trainer.evaluate()`` or set ``eval_strategy`` if periodic
    evaluation is wanted.

    Args:
        train_dataset: Iterable of rows with ``body``, ``pair`` and ``similarity``.
        test_dataset: Same row format; used to build the evaluator.
        output_dir: Directory for checkpoints and the final model.

    Returns:
        ``(model, final_model_path)``.
    """
    reset_memory_disk()
    os.makedirs(output_dir, exist_ok=True)

    model = create_sentence_transformer_from_llama()

    train_data = prepare_dataset_for_trainer(train_dataset)
    test_examples = convert_to_input_examples(test_dataset)

    train_loss = losses.CosineSimilarityLoss(model)

    # NOTE(review): BinaryClassificationEvaluator is meant for 0/1 labels;
    # confirm that `similarity` only takes those two values.
    evaluator = BinaryClassificationEvaluator(
        sentences1=[ex.texts[0] for ex in test_examples],
        sentences2=[ex.texts[1] for ex in test_examples],
        labels=[float(ex.label) for ex in test_examples],
        show_progress_bar=True
    )

    args = SentenceTransformerTrainingArguments(
        output_dir=output_dir,
        num_train_epochs=1,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,  # effective batch of 8 per device

        learning_rate=2e-5,
        warmup_ratio=0.1,
        bf16=True,                       # better on B200
        fp16=False,
        save_strategy="steps",
        save_steps=500,
        save_total_limit=1,
        save_only_model=True,
        logging_steps=100,
        remove_unused_columns=False,
        run_name="llama-similarity-finetuning"
    )

    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_data,
        loss=train_loss,
        evaluator=evaluator
    )

    print("🚀 Starting training...")
    trainer.train()

    final_model_path = os.path.join(output_dir, "final_model")
    model.save(final_model_path)
    print(f"🎉 Training complete. Model saved at: {final_model_path}")
    return model, final_model_path

# ------------------ Example Run ------------------
if __name__ == "__main__":
    # Checkpoints and the final model go under CHECKPOINTS_PATH/llama_similarity_model.
    output_dir = os.path.join(CHECKPOINTS_PATH, "llama_similarity_model")
    # Train on a 10% stratified subsample of each set; the complements are discarded.
    train_dataset_s, _ = sample_dataset_stratified(train_dataset, sample_fraction=0.1)
    test_dataset_s, _  = sample_dataset_stratified(test_dataset, sample_fraction=0.1)
    # Rebinds the notebook-global `model`, which later cells use for encoding.
    model, model_path = train_llama_similarity_model(train_dataset_s["train"], test_dataset_s["train"], output_dir)

    # Smoke test: encode one sentence with the trained model.
    embeddings = model.encode(["This is a test sentence."], convert_to_tensor=True)
    print("Example embedding:", embeddings)


## Mistral

In [None]:
import os
import torch
import shutil
from datasets import Dataset
from sentence_transformers import SentenceTransformer, InputExample, losses
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.models import Transformer, Pooling
from sentence_transformers.evaluation import BinaryClassificationEvaluator

# New imports
from transformers import AutoTokenizer, AutoModel, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

# ------------------ RunPod Network Paths ------------------
# Models, caches and checkpoints all live under /workspace.
BASE_PATH = "/workspace"
MODELS_PATH = os.path.join(BASE_PATH, "models")
CACHE_PATH = os.path.join(BASE_PATH, "cache")
CHECKPOINTS_PATH = os.path.join(BASE_PATH, "checkpoints")

os.makedirs(MODELS_PATH, exist_ok=True)
os.makedirs(CACHE_PATH, exist_ok=True)
os.makedirs(CHECKPOINTS_PATH, exist_ok=True)

# Force all caches to network volume
# NOTE(review): these are assigned after the Hugging Face imports at the top of
# this cell; libraries that read them at import time will not see the new
# values. The loaders below pass the cache directory explicitly.
os.environ['HF_HOME'] = CACHE_PATH
os.environ['TRANSFORMERS_CACHE'] = os.path.join(CACHE_PATH, "transformers")
os.environ['HF_DATASETS_CACHE'] = os.path.join(CACHE_PATH, "datasets")
os.environ['TORCH_HOME'] = os.path.join(CACHE_PATH, "torch")
os.environ['TMPDIR'] = os.path.join(CACHE_PATH, "tmp")
# TMPDIR has to exist to be usable as a temp directory; only CACHE_PATH itself
# is created above, not its "tmp" subdirectory.
os.makedirs(os.environ['TMPDIR'], exist_ok=True)

# ------------------ Memory / Disk Cleanup ------------------
def reset_memory_disk():
    """Release cached CUDA memory and empty the temp and Transformers cache directories.

    Every entry inside ``$TMPDIR`` and ``$TRANSFORMERS_CACHE`` is deleted; the
    directories themselves are kept. Deletion is best-effort: a failure is
    printed and the entry is skipped. Emptying ``$TRANSFORMERS_CACHE`` also
    removes model files previously downloaded into it.

    Raises:
        KeyError: If ``TMPDIR`` or ``TRANSFORMERS_CACHE`` is not set.
    """
    # torch.cuda.ipc_collect() initializes CUDA and fails on hosts without a
    # GPU, so only touch CUDA when it is available. This keeps the CPU
    # fallback of the model loaders usable.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

    for directory in (os.environ['TMPDIR'], os.environ['TRANSFORMERS_CACHE']):
        if not os.path.isdir(directory):
            continue
        for entry in os.listdir(directory):
            entry_path = os.path.join(directory, entry)
            try:
                if os.path.isfile(entry_path) or os.path.islink(entry_path):
                    os.unlink(entry_path)
                elif os.path.isdir(entry_path):
                    shutil.rmtree(entry_path)
            except Exception as e:
                # Best-effort cleanup: report and keep going.
                print(f"Failed to delete {entry_path}: {e}")
    print("Memory and disk reset done.")

# ------------------ Dataset Preparation ------------------
def prepare_dataset_for_trainer(dataset):
    """Convert rows with ``body``/``pair``/``similarity`` into a trainer-ready ``Dataset``.

    Each input row becomes ``{'sentence1': body, 'sentence2': pair,
    'label': float(similarity)}``. For plain ``dict`` rows a missing ``body``
    or ``pair`` becomes ``None`` and a missing ``similarity`` becomes ``0.0``;
    any other row type is indexed directly.

    Args:
        dataset: Iterable of rows.

    Returns:
        A ``Dataset`` built with ``Dataset.from_list``.
    """
    def _to_row(item):
        if isinstance(item, dict):
            return {
                'sentence1': item.get('body'),
                'sentence2': item.get('pair'),
                'label': float(item.get('similarity', 0.0)),
            }
        return {
            'sentence1': item['body'],
            'sentence2': item['pair'],
            'label': float(item['similarity']),
        }

    return Dataset.from_list([_to_row(item) for item in dataset])

def convert_to_input_examples(dataset):
    """Convert rows with ``body``/``pair``/``similarity`` into ``InputExample`` objects.

    For plain ``dict`` rows a missing ``body`` or ``pair`` becomes ``None``
    and a missing ``similarity`` becomes ``0.0``; any other row type is
    indexed directly.

    Args:
        dataset: Iterable of rows.

    Returns:
        A list with one ``InputExample(texts=[body, pair], label=similarity)``
        per row, in input order.
    """
    def _to_example(item):
        if isinstance(item, dict):
            first = item.get('body')
            second = item.get('pair')
            score = float(item.get('similarity', 0.0))
        else:
            first = item['body']
            second = item['pair']
            score = float(item['similarity'])
        return InputExample(texts=[first, second], label=score)

    return [_to_example(item) for item in dataset]

# ------------------ Mistral-based SentenceTransformer ------------------
def create_sentence_transformer_from_mistral(hf_model_id="mistralai/Mistral-7B-Instruct-v0.2"):
    """Build a LoRA-adapted, mean-pooled SentenceTransformer on a Mistral checkpoint.

    The backbone is loaded inside the sentence-transformers ``Transformer``
    module: first as a 4-bit NF4 quantized model and, if that raises, in
    BF16. Gradient checkpointing is enabled, LoRA adapters are attached to
    the ``q_proj`` and ``v_proj`` modules, and token embeddings are
    mean-pooled.

    Args:
        hf_model_id: Hugging Face model id of the backbone.

    Returns:
        A ``SentenceTransformer`` with modules ``[Transformer, Pooling]``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Loading Mistral model on device: {device}")

    cache_dir = os.environ['TRANSFORMERS_CACHE']

    # Load tokenizer; fall back to EOS as the padding token when none is defined.
    tokenizer = AutoTokenizer.from_pretrained(hf_model_id, cache_dir=cache_dir)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    def _load_backbone(model_args):
        # model_args is forwarded to the underlying from_pretrained call, so
        # the weights are loaded once, already in the requested precision.
        return Transformer(
            model_name_or_path=hf_model_id,
            max_seq_length=256,   # reduced for speed/memory
            cache_dir=cache_dir,
            model_args=model_args,
        )

    # Try 4-bit first, fallback to BF16
    try:
        print("Attempting 4-bit quantized load...")
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16
        )
        word_embedding_model = _load_backbone(
            {"quantization_config": quant_config, "device_map": "auto"}
        )
        print("Loaded model in 4-bit mode.")
    except Exception as e:
        print(f"⚠️  4-bit load failed: {e}")
        print("Falling back to BF16 full precision...")
        word_embedding_model = _load_backbone(
            {"torch_dtype": torch.bfloat16, "device_map": "auto"}
        )

    word_embedding_model.tokenizer = tokenizer
    base_model = word_embedding_model.auto_model

    # Enable gradient checkpointing
    base_model.gradient_checkpointing_enable()

    # LoRA config: adapt only the query and value projections.
    lora_config = LoraConfig(
        r=32,
        lora_alpha=16,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.1,
        bias="none",
        task_type="FEATURE_EXTRACTION"
    )
    word_embedding_model.auto_model = get_peft_model(base_model, lora_config)

    pooling_model = Pooling(
        word_embedding_model.get_word_embedding_dimension(),
        pooling_mode_mean_tokens=True,
        pooling_mode_cls_token=False,
        pooling_mode_max_tokens=False
    )

    model = SentenceTransformer(modules=[word_embedding_model, pooling_model], device=device)
    print("✅ Mistral-based SentenceTransformer initialized (LoRA + checkpointing).")
    return model

# ------------------ Training Function ------------------
def train_mistral_similarity_model(train_dataset, test_dataset, output_dir):
    """Fine-tune the Mistral-based sentence encoder with a cosine similarity loss.

    The function clears caches, builds the model, trains for one epoch in
    BF16 and saves the result to ``<output_dir>/final_model``.

    The training arguments set no evaluation strategy, so the evaluator built
    from ``test_dataset`` is only attached to the trainer.
    NOTE(review): with the trainer defaults it is not run during training;
    call ``trainer.evaluate()`` or set ``eval_strategy`` if periodic
    evaluation is wanted.

    Args:
        train_dataset: Iterable of rows with ``body``, ``pair`` and ``similarity``.
        test_dataset: Same row format; used to build the evaluator.
        output_dir: Directory for checkpoints and the final model.

    Returns:
        ``(model, final_model_path)``.
    """
    reset_memory_disk()
    os.makedirs(output_dir, exist_ok=True)

    model = create_sentence_transformer_from_mistral()

    train_data = prepare_dataset_for_trainer(train_dataset)
    test_examples = convert_to_input_examples(test_dataset)

    train_loss = losses.CosineSimilarityLoss(model)

    # NOTE(review): BinaryClassificationEvaluator is meant for 0/1 labels;
    # confirm that `similarity` only takes those two values.
    evaluator = BinaryClassificationEvaluator(
        sentences1=[ex.texts[0] for ex in test_examples],
        sentences2=[ex.texts[1] for ex in test_examples],
        labels=[float(ex.label) for ex in test_examples],
        show_progress_bar=True
    )

    args = SentenceTransformerTrainingArguments(
        output_dir=output_dir,
        num_train_epochs=1,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,  # effective batch of 8 per device
        learning_rate=2e-5,
        warmup_ratio=0.1,
        bf16=True,   # good for B200
        fp16=False,
        save_strategy="steps",
        save_steps=500,
        save_total_limit=1,
        save_only_model=True,
        logging_steps=100,
        remove_unused_columns=False,
        run_name="mistral-similarity-finetuning"
    )

    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_data,
        loss=train_loss,
        evaluator=evaluator
    )

    print("🚀 Starting training...")
    trainer.train()

    final_model_path = os.path.join(output_dir, "final_model")
    model.save(final_model_path)
    print(f"🎉 Training complete. Model saved at: {final_model_path}")
    return model, final_model_path

# ------------------ Example Run ------------------
if __name__ == "__main__":
    # Checkpoints and the final model go under CHECKPOINTS_PATH/mistral_similarity_model.
    output_dir = os.path.join(CHECKPOINTS_PATH, "mistral_similarity_model")

    # Train on a 10% stratified subsample of each set (sample_dataset_stratified
    # is defined in an earlier cell); the complements are discarded.
    train_dataset_s, _ = sample_dataset_stratified(train_dataset, sample_fraction=0.1)
    test_dataset_s, _  = sample_dataset_stratified(test_dataset, sample_fraction=0.1)

    # Rebinds the notebook-global `model`, which later cells use for encoding.
    model, model_path = train_mistral_similarity_model(
        train_dataset_s["train"],
        test_dataset_s["train"],
        output_dir
    )

    # Smoke test: encode one sentence with the trained model.
    embeddings = model.encode(["This is a test sentence."], convert_to_tensor=True)
    print("Example embedding:", embeddings)

## QWEN

In [None]:
import os
import torch
import shutil
from datasets import Dataset
from sentence_transformers import SentenceTransformer, InputExample, losses
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.models import Transformer, Pooling
from sentence_transformers.evaluation import BinaryClassificationEvaluator

# New imports
from transformers import AutoTokenizer, AutoModel, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

# ------------------ RunPod Network Paths ------------------
# Models, caches and checkpoints all live under /workspace.
BASE_PATH = "/workspace"
MODELS_PATH = os.path.join(BASE_PATH, "models")
CACHE_PATH = os.path.join(BASE_PATH, "cache")
CHECKPOINTS_PATH = os.path.join(BASE_PATH, "checkpoints")

os.makedirs(MODELS_PATH, exist_ok=True)
os.makedirs(CACHE_PATH, exist_ok=True)
os.makedirs(CHECKPOINTS_PATH, exist_ok=True)

# Force all caches to network volume
# NOTE(review): these are assigned after the Hugging Face imports at the top of
# this cell; libraries that read them at import time will not see the new
# values. The loaders below pass the cache directory explicitly.
os.environ['HF_HOME'] = CACHE_PATH
os.environ['TRANSFORMERS_CACHE'] = os.path.join(CACHE_PATH, "transformers")
os.environ['HF_DATASETS_CACHE'] = os.path.join(CACHE_PATH, "datasets")
os.environ['TORCH_HOME'] = os.path.join(CACHE_PATH, "torch")
os.environ['TMPDIR'] = os.path.join(CACHE_PATH, "tmp")
# TMPDIR has to exist to be usable as a temp directory; only CACHE_PATH itself
# is created above, not its "tmp" subdirectory.
os.makedirs(os.environ['TMPDIR'], exist_ok=True)

# ------------------ Memory / Disk Cleanup ------------------
def reset_memory_disk():
    """Release cached CUDA memory and empty the temp and Transformers cache directories.

    Every entry inside ``$TMPDIR`` and ``$TRANSFORMERS_CACHE`` is deleted; the
    directories themselves are kept. Deletion is best-effort: a failure is
    printed and the entry is skipped. Emptying ``$TRANSFORMERS_CACHE`` also
    removes model files previously downloaded into it.

    Raises:
        KeyError: If ``TMPDIR`` or ``TRANSFORMERS_CACHE`` is not set.
    """
    # torch.cuda.ipc_collect() initializes CUDA and fails on hosts without a
    # GPU, so only touch CUDA when it is available. This keeps the CPU
    # fallback of the model loaders usable.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

    for directory in (os.environ['TMPDIR'], os.environ['TRANSFORMERS_CACHE']):
        if not os.path.isdir(directory):
            continue
        for entry in os.listdir(directory):
            entry_path = os.path.join(directory, entry)
            try:
                if os.path.isfile(entry_path) or os.path.islink(entry_path):
                    os.unlink(entry_path)
                elif os.path.isdir(entry_path):
                    shutil.rmtree(entry_path)
            except Exception as e:
                # Best-effort cleanup: report and keep going.
                print(f"Failed to delete {entry_path}: {e}")
    print("Memory and disk reset done.")

# ------------------ Dataset Preparation ------------------
def prepare_dataset_for_trainer(dataset):
    """Build a ``Dataset`` of ``sentence1``/``sentence2``/``label`` rows for the trainer.

    ``body`` maps to ``sentence1``, ``pair`` to ``sentence2`` and
    ``similarity`` (converted to ``float``) to ``label``. Plain ``dict`` rows
    are read with ``.get`` (missing ``body``/``pair`` give ``None``, missing
    ``similarity`` gives ``0.0``); other row types are indexed directly.

    Args:
        dataset: Iterable of rows.

    Returns:
        A ``Dataset`` created with ``Dataset.from_list``.
    """
    def _read(item, key, *fallback):
        if isinstance(item, dict):
            return item.get(key, *fallback)
        return item[key]

    rows = [
        {
            'sentence1': _read(item, 'body'),
            'sentence2': _read(item, 'pair'),
            'label': float(_read(item, 'similarity', 0.0)),
        }
        for item in dataset
    ]
    return Dataset.from_list(rows)

def convert_to_input_examples(dataset):
    """Build one ``InputExample`` per row from ``body``, ``pair`` and ``similarity``.

    Plain ``dict`` rows are read with ``.get`` (missing ``body``/``pair`` give
    ``None``, missing ``similarity`` gives ``0.0``); other row types are
    indexed directly.

    Args:
        dataset: Iterable of rows.

    Returns:
        A list of ``InputExample(texts=[body, pair], label=float(similarity))``
        in input order.
    """
    def _read(item, key, *fallback):
        if isinstance(item, dict):
            return item.get(key, *fallback)
        return item[key]

    return [
        InputExample(
            texts=[_read(item, 'body'), _read(item, 'pair')],
            label=float(_read(item, 'similarity', 0.0)),
        )
        for item in dataset
    ]

# ------------------ Qwen-based SentenceTransformer ------------------
def create_sentence_transformer_from_qwen(hf_model_id="Qwen/Qwen2-7B-Instruct"):
    """Build a LoRA-adapted, mean-pooled SentenceTransformer on a Qwen checkpoint.

    The backbone is loaded inside the sentence-transformers ``Transformer``
    module: first as a 4-bit NF4 quantized model and, if that raises, in
    BF16. Gradient checkpointing is enabled, LoRA adapters are attached to
    the ``q_proj``, ``k_proj``, ``v_proj`` and ``o_proj`` modules, and token
    embeddings are mean-pooled.

    Args:
        hf_model_id: Hugging Face model id of the backbone.

    Returns:
        A ``SentenceTransformer`` with modules ``[Transformer, Pooling]``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Loading Qwen model on device: {device}")

    cache_dir = os.environ['TRANSFORMERS_CACHE']

    # Load tokenizer; fall back to EOS as the padding token when none is defined.
    tokenizer = AutoTokenizer.from_pretrained(hf_model_id, cache_dir=cache_dir)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    def _load_backbone(model_args):
        # model_args is forwarded to the underlying from_pretrained call, so
        # the weights are loaded once, already in the requested precision.
        return Transformer(
            model_name_or_path=hf_model_id,
            max_seq_length=256,
            cache_dir=cache_dir,
            model_args=model_args,
        )

    # Try 4-bit first, fallback to BF16
    try:
        print("Attempting 4-bit quantized load...")
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16
        )
        word_embedding_model = _load_backbone(
            {"quantization_config": quant_config, "device_map": "auto"}
        )
        print("Loaded model in 4-bit mode.")
    except Exception as e:
        print(f"⚠️  4-bit load failed: {e}")
        print("Falling back to BF16 full precision...")
        word_embedding_model = _load_backbone(
            {"torch_dtype": torch.bfloat16, "device_map": "auto"}
        )

    word_embedding_model.tokenizer = tokenizer
    base_model = word_embedding_model.auto_model

    # Enable gradient checkpointing to save memory during training.
    base_model.gradient_checkpointing_enable()

    # LoRA config: adapts all four attention projections (q/k/v/o), unlike the
    # LLaMA and Mistral cells, which adapt only q_proj and v_proj.
    lora_config = LoraConfig(
        r=32,
        lora_alpha=16,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_dropout=0.1,
        bias="none",
        task_type="FEATURE_EXTRACTION"
    )
    word_embedding_model.auto_model = get_peft_model(base_model, lora_config)

    pooling_model = Pooling(
        word_embedding_model.get_word_embedding_dimension(),
        pooling_mode_mean_tokens=True,
        pooling_mode_cls_token=False,
        pooling_mode_max_tokens=False
    )

    model = SentenceTransformer(modules=[word_embedding_model, pooling_model], device=device)
    print("✅ Qwen-based SentenceTransformer initialized (LoRA + checkpointing).")
    return model

# ------------------ Training Function ------------------
def train_qwen_similarity_model(train_dataset, test_dataset, output_dir):
    """Fine-tune the Qwen-based sentence encoder with a cosine similarity loss.

    The function clears caches, builds the model, trains for one epoch in
    BF16 and saves the result to ``<output_dir>/final_model``.

    The training arguments set no evaluation strategy, so the evaluator built
    from ``test_dataset`` is only attached to the trainer.
    NOTE(review): with the trainer defaults it is not run during training;
    call ``trainer.evaluate()`` or set ``eval_strategy`` if periodic
    evaluation is wanted.

    Args:
        train_dataset: Iterable of rows with ``body``, ``pair`` and ``similarity``.
        test_dataset: Same row format; used to build the evaluator.
        output_dir: Directory for checkpoints and the final model.

    Returns:
        ``(model, final_model_path)``.
    """
    reset_memory_disk()
    os.makedirs(output_dir, exist_ok=True)

    model = create_sentence_transformer_from_qwen()

    train_data = prepare_dataset_for_trainer(train_dataset)
    test_examples = convert_to_input_examples(test_dataset)

    train_loss = losses.CosineSimilarityLoss(model)

    # NOTE(review): BinaryClassificationEvaluator is meant for 0/1 labels;
    # confirm that `similarity` only takes those two values.
    evaluator = BinaryClassificationEvaluator(
        sentences1=[ex.texts[0] for ex in test_examples],
        sentences2=[ex.texts[1] for ex in test_examples],
        labels=[float(ex.label) for ex in test_examples],
        show_progress_bar=True
    )

    args = SentenceTransformerTrainingArguments(
        output_dir=output_dir,
        num_train_epochs=1,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,  # effective batch of 8 per device

        learning_rate=2e-5,
        warmup_ratio=0.1,
        bf16=True,
        fp16=False,
        save_strategy="steps",
        save_steps=500,
        save_total_limit=1,
        save_only_model=True,
        logging_steps=100,
        remove_unused_columns=False,
        run_name="qwen-similarity-finetuning"
    )

    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_data,
        loss=train_loss,
        evaluator=evaluator
    )

    print("🚀 Starting training...")
    trainer.train()

    final_model_path = os.path.join(output_dir, "final_model")
    model.save(final_model_path)
    print(f"🎉 Training complete. Model saved at: {final_model_path}")
    return model, final_model_path

# ------------------ Example Run ------------------
if __name__ == "__main__":
    # Checkpoints and the final model go under CHECKPOINTS_PATH/qwen_similarity_model.
    output_dir = os.path.join(CHECKPOINTS_PATH, "qwen_similarity_model")
    # Train on a 10% stratified subsample of each set; the complements are discarded.
    train_dataset_s, _ = sample_dataset_stratified(train_dataset, sample_fraction=0.1)
    test_dataset_s, _  = sample_dataset_stratified(test_dataset, sample_fraction=0.1)
    # Rebinds the notebook-global `model`, which later cells use for encoding.
    model, model_path = train_qwen_similarity_model(train_dataset_s["train"], test_dataset_s["train"], output_dir)

    # Smoke test: encode one sentence with the trained model.
    embeddings = model.encode(["This is a test sentence."], convert_to_tensor=True)
    print("Example embedding:", embeddings)

# Embedding Drift Functions

The code in this section implements a data quality flagging system based on semantic drift between each text body and its paired text. It encodes both texts (bodies and pairs) into embeddings, computes cosine drift scores (1 minus cosine similarity), and uses statistical methods (Gaussian Mixture Models with a KDE fallback) to automatically determine an optimal threshold for flagging problematic samples. The system is calibrated to achieve a target overall flagging rate (50%) while keeping the false-positive rate on clean data at or below 3%.

In [None]:
import numpy as np
from tqdm import tqdm
from sklearn.mixture import GaussianMixture
from scipy.stats import norm
from scipy.optimize import brentq
from scipy.stats import gaussian_kde
from scipy.signal import find_peaks
import matplotlib.pyplot as plt

# ---------------------------
# Config
# ---------------------------
# Module-level settings read by the encoding and threshold-calibration code below.
batch_size = 64                # batch size passed to encode_with_progress for bodies and pairs
target_total_frac = 0.50       # desired overall flagged fraction (50%)
clean_fp_target = 0.03         # ≤3% clean false positive
kde_grid_points = 2000         # number of grid points for the KDE fallback density evaluation

# ---------------------------
# Helper - safe add column
# ---------------------------
def safe_add_column(split_ds, name, values):
    """Return ``split_ds`` with column ``name`` set to ``values``.

    If a column called ``name`` already exists it is removed first, so the
    call replaces the column instead of colliding with it. The result of
    ``add_column`` is returned; ``split_ds`` itself is not reassigned here.
    """
    already_present = name in split_ds.column_names
    base = split_ds.remove_columns(name) if already_present else split_ds
    return base.add_column(name, values)

# ---------------------------
# 1) Encode prompts
# ---------------------------
def encode_with_progress(texts, desc="Encoding", batch_size=32):
    """Encode ``texts`` in slices of ``batch_size`` and stack the results.

    Each slice is passed to the module-level ``model.encode`` (its own progress
    bar disabled, NumPy output requested) while a single tqdm bar labelled
    ``desc`` tracks the slices. The per-slice arrays are stacked row-wise with
    ``np.vstack`` and returned.
    """
    slice_starts = range(0, len(texts), batch_size)
    encoded_chunks = [
        model.encode(
            texts[start:start + batch_size],
            show_progress_bar=False,
            convert_to_numpy=True,
            batch_size=batch_size,
        )
        for start in tqdm(slice_starts, desc=desc)
    ]
    return np.vstack(encoded_chunks)

# Encode the "body" and "pair" columns of test_dataset["train"] with the
# module-level `model` (via encode_with_progress), one embedding row per sample.
print("🔄 Encoding dataset...")
bodies = test_dataset["train"]["body"]
pairs = test_dataset["train"]["pair"]

body_embs = encode_with_progress(bodies, desc="Bodies", batch_size=batch_size)
pair_embs = encode_with_progress(pairs, desc="Pairs", batch_size=batch_size)

# ---------------------------
# 2) Compute drift scores = 1 - cosine
# ---------------------------
print("🔄 Computing cosine drift (vectorized)...")
# Row i of A and row i of B belong to the same sample (body vs. pair).
A = np.array(body_embs, dtype=np.float32)
B = np.array(pair_embs, dtype=np.float32)

# Row-wise L2 norms; a zero norm is replaced by 1.0 so the division below
# leaves an all-zero embedding as zeros instead of producing NaN.
A_norm = np.linalg.norm(A, axis=1, keepdims=True)
B_norm = np.linalg.norm(B, axis=1, keepdims=True)
A_norm = np.where(A_norm == 0, 1.0, A_norm)
B_norm = np.where(B_norm == 0, 1.0, B_norm)

A_unit = A / A_norm
B_unit = B / B_norm

# Row-wise dot product of unit vectors = cosine similarity per sample.
# NOTE(review): with float32 arithmetic the result may land marginally outside
# the nominal ranges quoted below — confirm whether clipping is wanted.
cosine_scores = np.sum(A_unit * B_unit, axis=1)   # similarity in [-1, 1]
drift_scores = 1.0 - cosine_scores                # drift in [0, 2]

# Store the scores as a column on the "train" split (replacing any existing
# column of the same name) and write the split back into test_dataset.
drift_col_name = "drift_score"
train_ds = test_dataset["train"]
train_ds = safe_add_column(train_ds, drift_col_name, drift_scores.tolist())
test_dataset["train"] = train_ds

print(f"✅ Computed and saved '{drift_col_name}' "
      f"(min={drift_scores.min():.6f}, max={drift_scores.max():.6f})")

# ---------------------------
# 3) Threshold determination (GMM primary, KDE fallback)
# ---------------------------
print("🔄 Calibrating thresholds (GMM + KDE fallback)...")
# NOTE: despite its name, `sims` holds the drift scores (1 - cosine), so larger
# values mean the body and pair are less similar.
sims = np.asarray(drift_scores, dtype=float)   # now using drift
n = len(sims)
# Number of samples that should end up flagged to hit target_total_frac
# (truncated toward zero by int()).
target_total_flagged = int(target_total_frac * n)

def try_gmm_thresholding(vals):
    """Fit a two-component 1-D Gaussian mixture to ``vals`` and find where the
    weighted component densities cross.

    The component with the lower mean is treated as the "clean" (low-drift)
    one; the other component is everything else.

    Args:
        vals: 1-D NumPy array of drift scores.

    Returns:
        ``(True, info)`` on success, where ``info`` contains ``"method"``
        (always ``"gmm"``), the fitted ``"gmm"`` object, the clean/other
        parameters ``"mu_c"``, ``"sigma_c"``, ``"w_c"``, ``"mu_o"``,
        ``"sigma_o"``, ``"w_o"``, and ``"intersection"`` (the crossing point,
        or the midpoint of the two means when no crossing can be bracketed).
        ``(False, {"error": message})`` if fitting or root finding raised, so
        the caller can switch to the KDE fallback.
    """
    try:
        gmm = GaussianMixture(n_components=2, random_state=0, covariance_type='full', n_init=5)
        gmm.fit(vals.reshape(-1, 1))
        means = gmm.means_.flatten()
        sigmas = np.sqrt(gmm.covariances_.flatten())
        weights = gmm.weights_.flatten()

        clean_comp = int(np.argmin(means))  # lower-mean = clean (low drift)
        other_comp = 1 - clean_comp

        mu_c = float(means[clean_comp])
        sigma_c = float(sigmas[clean_comp])
        w_c = float(weights[clean_comp])
        mu_o = float(means[other_comp])
        sigma_o = float(sigmas[other_comp])
        w_o = float(weights[other_comp])

        def wpdf_diff(x):
            # Positive where the clean component's weighted density dominates.
            return w_c * norm.pdf(x, loc=mu_c, scale=sigma_c) - w_o * norm.pdf(x, loc=mu_o, scale=sigma_o)

        # The decision boundary of interest lies between the two means, so
        # that interval is tried first. With unequal sigmas the wider
        # component dominates BOTH far tails, which means a wide bracket
        # usually shows no sign change at all; the ±5σ window is kept only as
        # a second attempt. The bracket is deliberately not widened further:
        # far enough out both densities underflow to 0.0, and an exact zero
        # there is not a real crossing.
        spread = 5 * max(sigma_c, sigma_o)
        candidate_brackets = (
            (mu_c, mu_o),
            (mu_c - spread, mu_o + spread),
        )

        intersect = None
        for a, b in candidate_brackets:
            fa, fb = wpdf_diff(a), wpdf_diff(b)
            if fa * fb < 0:  # strict sign change required by brentq
                intersect = brentq(wpdf_diff, a, b)
                break
        if intersect is None:
            intersect = float((mu_c + mu_o) / 2.0)

        return True, {
            "method": "gmm",
            "gmm": gmm,
            "mu_c": mu_c, "sigma_c": sigma_c, "w_c": w_c,
            "mu_o": mu_o, "sigma_o": sigma_o, "w_o": w_o,
            "intersection": float(intersect)
        }
    except Exception as e:
        # Deliberate best-effort: any failure is reported to the caller, which
        # then falls back to KDE-based thresholding.
        return False, {"error": str(e)}

def kde_fallback_thresholding(vals, grid_points=2000):
    """Estimate a threshold from the valley between the two tallest KDE peaks.

    Args:
        vals: 1-D array-like of drift scores.
        grid_points: number of evenly spaced points, from ``min(vals)`` to
            ``max(vals)``, at which the density is evaluated.

    Returns:
        ``(True, info)`` on success. ``info["method"]`` is ``"kde"`` and
        ``info["valley"]`` is the threshold candidate:
          * no density peak found  -> the median of ``vals``;
          * exactly one peak       -> midpoint between that peak and ``max(vals)``;
          * two or more peaks      -> the density minimum between the two
            tallest peaks; in this case ``info`` also carries ``"mu_c_est"``
            and ``"sigma_c_est"`` (mean/std of the samples around the
            lower-drift peak) and ``"kde": None``.
        ``(False, {"error": message})`` when a density cannot be estimated
        (fewer than two samples, all samples identical, or ``gaussian_kde``
        rejects the data), mirroring ``try_gmm_thresholding``'s failure shape.
    """
    vals = np.asarray(vals, dtype=float).ravel()
    if vals.size < 2:
        return False, {"error": "KDE needs at least two samples."}

    lo, hi = float(vals.min()), float(vals.max())
    if lo == hi:
        # Zero variance: gaussian_kde cannot build a bandwidth from this.
        return False, {"error": "KDE needs non-constant samples."}

    try:
        kde = gaussian_kde(vals)
    except (ValueError, np.linalg.LinAlgError) as e:
        return False, {"error": str(e)}

    grid = np.linspace(lo, hi, grid_points)
    dens = kde(grid)
    peaks, _ = find_peaks(dens)
    if len(peaks) == 0:
        return True, {"method": "kde", "valley": float(np.median(vals))}

    # Peak indices ordered from tallest to shortest.
    by_height = peaks[np.argsort(dens[peaks])[::-1]]
    if len(by_height) == 1:
        clean_peak = grid[by_height[0]]
        return True, {"method": "kde", "valley": float((clean_peak + hi) / 2.0)}

    # The two tallest peaks, ordered by position: left = low drift ("clean").
    left_idx, right_idx = sorted(int(i) for i in by_height[:2])
    clean_peak_pos = float(grid[left_idx])
    valley_rel_idx = int(np.argmin(dens[left_idx:right_idx + 1]))
    valley_x = float(grid[left_idx + valley_rel_idx])

    def _samples_within(fraction):
        # Samples within +/- fraction of the data range around the clean peak.
        half_width = (hi - lo) * fraction
        return vals[(vals >= clean_peak_pos - half_width) & (vals <= clean_peak_pos + half_width)]

    # Approximate the clean component as a normal fitted to nearby samples,
    # widening the window once and finally falling back to the lowest quartile.
    samples_near_clean = _samples_within(0.05)
    if len(samples_near_clean) < 8:
        samples_near_clean = _samples_within(0.10)
    if len(samples_near_clean) < 8:
        samples_near_clean = vals[vals <= np.percentile(vals, 25)]

    mu_c_est = float(np.mean(samples_near_clean))
    # ddof=1 on a single sample would give NaN; use the population form there.
    ddof = 1 if len(samples_near_clean) > 1 else 0
    sigma_c_est = float(np.std(samples_near_clean, ddof=ddof) + 1e-8)

    return True, {"method": "kde", "valley": valley_x, "mu_c_est": mu_c_est, "sigma_c_est": sigma_c_est, "kde": None}

# Threshold search
def _refine_threshold_to_target(scores, cap, target_count, n_iter=40):
    """Bisect on [cap, max(scores)] for the threshold whose flagged count
    (``scores > threshold``) is closest to ``target_count``.

    The flagged count never increases as the threshold rises, so the interval
    is halved ``n_iter`` times and the best midpoint seen (or ``cap`` itself)
    is returned as a float.
    """
    low, high = float(cap), float(scores.max())
    best_t = low
    best_diff = abs(np.sum(scores > low) - target_count)
    for _ in range(n_iter):
        mid = (low + high) / 2.0
        cnt = np.sum(scores > mid)
        gap = abs(cnt - target_count)
        if gap < best_diff:
            best_diff, best_t = gap, mid
        if cnt >= target_count:
            low = mid       # still flagging enough: threshold can go higher
        else:
            high = mid      # flagging too few: threshold must come down
    return float(best_t)


# Pick the clean-component estimate (mu_c, sigma_c) and a fallback cap from
# GMM when it succeeds, otherwise from the KDE fallback.
ok, res = try_gmm_thresholding(sims)
if ok and res["method"] == "gmm":
    used_method = "gmm"
    t_intersect = float(res["intersection"])
    mu_c = res["mu_c"]; sigma_c = res["sigma_c"]
    fallback_cap = t_intersect
else:
    ok2, kres = kde_fallback_thresholding(sims, grid_points=kde_grid_points)
    if not ok2: raise RuntimeError("Thresholding failed (GMM and KDE).")
    used_method = "kde"
    valley = float(kres["valley"])
    mu_c = kres.get("mu_c_est", None); sigma_c = kres.get("sigma_c_est", None)
    fallback_cap = valley

# Cap: clean FP ≤ target. The cap is the (1 - clean_fp_target) quantile of the
# estimated clean normal. norm.ppf signals bad parameters by returning NaN
# rather than raising, so the result is checked explicitly and the
# intersection/valley is used whenever the quantile is unusable.
cap_threshold = fallback_cap
if mu_c is not None and sigma_c is not None:
    try:
        quantile_cap = float(norm.ppf(1 - clean_fp_target, loc=mu_c, scale=sigma_c))
    except (TypeError, ValueError):
        quantile_cap = float("nan")
    if np.isfinite(quantile_cap):
        cap_threshold = quantile_cap

# Keep the cap inside the observed score range.
cap_threshold = min(max(cap_threshold, sims.min()), sims.max())

# If the cap alone already flags at least the target count, raise the
# threshold toward the target; otherwise the cap is the final threshold
# (the FP limit takes priority over hitting the target fraction).
if np.sum(sims > cap_threshold) >= target_total_flagged:
    final_threshold = _refine_threshold_to_target(sims, cap_threshold, target_total_flagged)
else:
    final_threshold = float(cap_threshold)

threshold = float(final_threshold)
flagged_mask = sims > threshold

# Save flagged column
# Stored as a boolean "flagged" column on the "train" split (replacing any
# existing column of that name), then written back into test_dataset.
train_ds = test_dataset["train"]
train_ds = safe_add_column(train_ds, "flagged", flagged_mask.tolist())
test_dataset["train"] = train_ds

# Reporting
total_flagged = int(np.sum(flagged_mask))
# Label-based false-positive rate: share of rows whose category equals "clean"
# that were flagged. Stays None when there is no "category" column or no such rows.
empirical_clean_fp = None
if "category" in test_dataset["train"].column_names:
    is_clean = np.array(test_dataset["train"]["category"]) == "clean"
    if np.sum(is_clean) > 0:
        empirical_clean_fp = np.sum(flagged_mask[is_clean]) / np.sum(is_clean)

print("\n===== Thresholding Summary (Drift) =====")
print(f"Method used: {used_method}")
print(f"Final threshold: {threshold:.6f}")
print(f"Total flagged: {total_flagged}/{n} ({total_flagged/n:.1%})")
if empirical_clean_fp is not None:
    print(f"Empirical clean FP (labels): {empirical_clean_fp:.2%}")
# Model-based false-positive estimate: upper-tail mass of the clean normal
# N(mu_c, sigma_c) above the final threshold. Any error here is swallowed so a
# reporting problem never aborts the cell (best-effort output only).
try:
    if used_method == "gmm":
        est_clean_fp = 1 - float(norm.cdf(threshold, loc=mu_c, scale=sigma_c))
        print(f"Estimated clean FP (GMM clean tail): {est_clean_fp:.2%}")
    elif used_method == "kde":
        # The KDE path only provides mu_c/sigma_c when two density peaks were found.
        if mu_c is not None and sigma_c is not None:
            est_clean_fp = 1 - float(norm.cdf(threshold, loc=mu_c, scale=sigma_c))
            print(f"Estimated clean FP (KDE-derived normal tail): {est_clean_fp:.2%}")
except Exception:
    pass
print("================================\n")

# Category distribution
def category_distribution(dataset):
    """Print, for each category in ``dataset["train"]``, how many rows are flagged.

    Reads the "category" and "flagged" columns of the "train" split. When the
    split has no "category" column a notice is printed and nothing else
    happens. Output goes to stdout; the function returns ``None``.
    """
    split = dataset["train"]
    if "category" not in split.column_names:
        print("No 'category' column for distribution reporting.")
        return

    labels = np.array(split["category"])
    flags = np.array(split["flagged"])

    print("\n📊 Category distribution of flagged samples:")
    for label in np.unique(labels):
        members = labels == label
        n_members = np.sum(members)
        n_flagged = np.sum(flags[members])
        print(f"{label}: {n_flagged}/{n_members} flagged ({n_flagged/n_members:.1%})")

# Print the per-category flagged breakdown for test_dataset's "train" split.
category_distribution(test_dataset)