In [None]:
import json
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import pipeline
import random
import re
from vllm import LLM, SamplingParams

# Build the vLLM engine once; later cells call llm.generate(...) on this object.
# Replace the placeholder path with a local checkpoint directory before running.
# NOTE(review): max_model_len is the engine's context length (prompt + generated
# tokens), so with long prompts the real output budget is below the 2048 max_tokens
# configured in sampling_params — confirm this is large enough.
llm = LLM(
    model="/path_to_your_model/Qwen/Qwen2.5-7B-Instruct",
    max_model_len=2048,
    trust_remote_code=True,
    tensor_parallel_size=1,  # number of GPUs to shard the model across
)


In [None]:
# Decoding settings passed to every llm.generate(...) call in this notebook.
# NOTE(review): temperature 0.7 samples stochastically, so replacements vary between
# runs — confirm whether reproducible outputs are required.
sampling_params = SamplingParams(
    max_tokens=2048,
    repetition_penalty=1.2,
    temperature=0.7,
    top_p=0.95,
)

In [None]:
import json
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import pipeline
import random
import re
from vllm import LLM, SamplingParams
import jsonlines
from tqdm import tqdm
from typing import List, Dict
import math
import warnings
import logging
from transformers import logging as transformers_logging
from pathlib import Path

# Keep notebook output readable: drop all Python warnings and set the
# "transformers" logger to ERROR (via the stdlib logger and transformers' helper).
warnings.filterwarnings("ignore")
logging.getLogger("transformers").setLevel(logging.ERROR)
transformers_logging.set_verbosity_error()



# Shared BERT NER pipeline (CoNLL-03 fine-tuned model, per its id), loaded once.
# "simple" aggregation merges sub-word tokens into entity groups; device 0 is the
# first CUDA GPU, -1 means CPU.
ner_model = pipeline(
    "ner",
    model="dbmdz/bert-large-cased-finetuned-conll03-english",
    aggregation_strategy="simple",
    device=(0 if torch.cuda.is_available() else -1),
)

def step1_bert_ner(text):
    """
    Step 1: Mask PII found by regex patterns, then mask BERT-detected named entities.

    Regex types (EMAIL, SSN, VIN, BTC, DL) are masked first. The module-level BERT NER
    pipeline then runs on the partially masked text, and its PER/ORG/LOC entities are
    masked too. Masks have the form ``[TYPE<n>]`` with ``n`` counted per type. Matches
    are processed from the end of the text backwards, so the last occurrence gets n == 1.

    Args:
        text: input string.

    Returns:
        (masked_text, mapping) where mapping is
        ``{mask: {'original': <source text of the span>, 'type': <entity type>}}``.
    """
    # Patterns are applied in this order to the progressively masked text. EMAIL goes
    # first: otherwise BTC (20+ alphanumerics) or SSN (9 digits) can consume part of an
    # address and leave a half-masked email behind.
    pii_patterns = {
        'EMAIL': r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b',
        'SSN': r'\b\d{3}[-\s]?\d{2}[-\s]?\d{4}\b',  # Just the SSN format
        'VIN': r'\b[A-HJ-NPR-Z0-9]{17}\b',  # 17 chars, excluding I, O and Q
        'BTC': r'\b[A-Za-z0-9]{20,}(?=[^A-Za-z0-9]|$)',
        'DL': r'\b[A-Z]\d{7}\b',  # Letter followed by 7 digits
    }

    masked_text = text
    mapping = {}
    entity_counts = {}

    for pii_type, pattern in pii_patterns.items():
        # Replace from the end so the offsets of earlier matches stay valid.
        for match in reversed(list(re.finditer(pattern, masked_text))):
            entity_counts[pii_type] = entity_counts.get(pii_type, 0) + 1
            mask = f"[{pii_type}{entity_counts[pii_type]}]"
            masked_text = masked_text[:match.start()] + mask + masked_text[match.end():]
            mapping[mask] = {
                'original': match.group(),
                'type': pii_type
            }

    # BERT sees the regex masks as well and may tag tokens inside them (e.g. "SSN1").
    # Record the mask spans so such hits can be skipped instead of corrupting a mask.
    ner_input = masked_text
    mask_spans = [
        m.span()
        for m in re.finditer(r'\[(?:' + '|'.join(pii_patterns) + r')\d+\]', ner_input)
    ]

    # Process entities right-to-left so earlier offsets remain valid after each splice.
    entities = sorted(ner_model(ner_input), key=lambda x: x['start'], reverse=True)

    for entity in entities:
        entity_type = entity['entity_group']
        if entity_type not in ('PER', 'ORG', 'LOC'):
            continue
        start, end = entity['start'], entity['end']
        if any(start < span_end and end > span_start for span_start, span_end in mask_spans):
            continue

        entity_counts[entity_type] = entity_counts.get(entity_type, 0) + 1
        mask = f"[{entity_type}{entity_counts[entity_type]}]"
        masked_text = masked_text[:start] + mask + masked_text[end:]

        # Store the exact source span. The pipeline's decoded 'word' can differ from
        # the source text (e.g. spacing around apostrophes).
        mapping[mask] = {
            'original': ner_input[start:end],
            'type': entity_type
        }

    return masked_text, mapping

def get_random_pii(entity_type, *args):
    """
    Pick a random placeholder value for ``entity_type`` (fallback when generation fails).

    Supported types: PER, ORG, LOC, SSN, DL, VIN, BTC, EMAIL. Any other type raises
    KeyError. Extra positional arguments are accepted for call compatibility and ignored.
    """
    candidates_by_type = dict(
        PER=[
            'James Wilson', 'Mary Johnson', 'Robert Brown', 'Sarah Davis',
            'Michael Chen', 'Emily Taylor', 'David Miller', 'Lisa Anderson',
            'Thomas Wright', 'Jennifer Lee'
        ],
        ORG=[
            'Acme Corp', 'Global Tech', 'Summit Industries', 'Pioneer Systems',
            'Blue Ridge Solutions', 'Nexus Innovations', 'Quantum Dynamics',
            'Atlas Technologies', 'Stellar Enterprises', 'Horizon Group'
        ],
        LOC=[
            'Chicago', 'Los Angeles', 'Boston', 'Seattle',
            'Austin', 'Denver', 'Portland', 'Miami',
            'Atlanta', 'San Francisco', 'Dallas', 'Phoenix'
        ],
        SSN=['123-45-6789', '987-65-4321', '456-78-9012', '789-01-2345'],
        DL=['A1234567', 'B7654321', 'C9876543', 'D2468101'],
        VIN=['1HGCM82633A123456', '2FMDK3JC8BB234567', '3VWFE21C04M345678'],
        BTC=[
            '5JWwqjxTLBcJig6SgfiksxY6C1XEwE9ZJGEeyRb8K8N1qP3Xm5n',
            '5KQNQKj2FEsNvnXXBF73J7kMJqg3hdY6yQwkRQ9hNPPuPKZNiBk'
        ],
        EMAIL=[
            'user1@example.com',
            'contact@company.net',
            'info@business.org',
            'support@service.com'
        ],
    )
    candidates = candidates_by_type[entity_type]
    return random.choice(candidates)

def clean_mistral_output(text, entity_type):
    """
    Clean a raw LLM replacement string and validate it against ``entity_type``.

    Strips instruction tokens ([INST]/[/INST]) and label prefixes such as "Name:".
    Then cuts the text at the first separator ('/STATE', '/', ',', ' - ', ' You ').
    Finally checks the result against a per-type format.

    Args:
        text: raw model output.
        entity_type: one of PER, ORG, LOC, SSN, DL, VIN, BTC, EMAIL. Other types are
            returned cleaned but unvalidated.

    Returns:
        The cleaned value, or '' when it fails validation. PER values longer than two
        words are reduced to the first and last word.
    """
    # Remove instruction tokens and anything after them
    if '[/INST]' in text:
        text = text.split('[/INST]')[0]

    # Remove any special tokens or prefixes
    text = text.replace('[INST]', '').strip()
    text = re.sub(r'^(Location:|Locality:|Name:|Organization:|Company:|You are)\s*', '', text)

    # Remove any text after common separators
    separators = ['/STATE', '/', ',', ' - ', ' You ']
    for sep in separators:
        if sep in text:
            text = text.split(sep)[0]

    text = text.strip()

    # Validate based on entity type
    if entity_type == 'PER':
        # Take only first and last name if more than two words
        words = text.split()
        if len(words) > 2:
            text = f"{words[0]} {words[-1]}"
        # Ensure it looks like a person name (2 words, no numbers)
        if not re.match(r'^[A-Za-z]+\s+[A-Za-z]+$', text):
            return ''

    elif entity_type == 'ORG':
        # Ensure it's not too long and doesn't contain special characters
        if len(text) > 30 or re.search(r'[^\w\s&\'-]', text):
            return ''

    elif entity_type == 'LOC':
        # Ensure it's a simple location name
        if len(text) > 20 or re.search(r'[^\w\s\'-]', text):
            return ''
        # Ensure it's not just "New" or generic terms
        if text.lower() in ['new', 'location', 'city', 'town']:
            return ''

    elif entity_type == 'SSN':
        # Ensure it matches SSN format
        if not re.match(r'^\d{3}-?\d{2}-?\d{4}$', text):
            return ''
    elif entity_type == 'DL':
        # Ensure it matches driver's license format
        if not re.match(r'^[A-Z]\d{7}$', text):
            return ''
    elif entity_type == 'VIN':
        # Ensure it matches VIN format
        if not re.match(r'^[A-HJ-NPR-Z0-9]{17}$', text):
            return ''
    elif entity_type == 'BTC':
        # Ensure it matches Bitcoin address format (base58 alphabet)
        if not re.match(r'^[1-9A-HJ-NP-Za-km-z]{51,52}$', text):
            return ''
    elif entity_type == 'EMAIL':
        # TLD class is [A-Za-z]; the previous [A-Z|a-z] also accepted a literal '|'.
        if not re.match(r'^[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}$', text):
            return ''

    return text

def step2_mistral_replace(masked_text, mapping, replacer=None):
    """
    Step 2: Replace every mask in ``masked_text`` with a generated or fallback value.

    For each mask, the candidate value is ``replacer(entity_type, mask, masked_text)``,
    cleaned by ``clean_mistral_output``. If that yields nothing usable, or the replacer
    raises, a random value from ``get_random_pii`` is used instead.

    Args:
        masked_text: text containing masks such as "[PER1]" (as produced by step 1).
        mapping: ``{mask: {'original': ..., 'type': ...}}``. Each entry is updated in
            place with a 'replacement' key.
        replacer: optional callable ``(entity_type, mask, masked_text) -> str``. Defaults
            to a module-level ``get_context_aware_replacement`` if one is defined;
            otherwise only the random fallback is used.

    Returns:
        (text_with_replacements, mapping)
    """
    if replacer is None:
        # NOTE(review): get_context_aware_replacement is not defined anywhere in this
        # notebook as seen. Look it up lazily instead of raising NameError per entity.
        replacer = globals().get('get_context_aware_replacement')

    for mask, info in mapping.items():
        entity_type = info['type']
        replacement = ''

        if replacer is not None:
            try:
                raw = replacer(entity_type, mask, masked_text)
                replacement = clean_mistral_output(raw, entity_type)
            except Exception as e:
                print(f"Error replacing {mask}, using fallback: {str(e)}")
                replacement = ''

        if not replacement or not replacement.strip():
            # Fallback to random PII if the model output is missing or invalid
            replacement = get_random_pii(entity_type)

        info['replacement'] = replacement
        masked_text = masked_text.replace(mask, replacement)

    return masked_text, mapping

def step3_save_mapping(mapping, filename="entity_mapping.json"):
    """
    Step 3: Persist ``mapping`` as pretty-printed JSON (2-space indent).

    Overwrites ``filename`` and returns the same path so callers can chain it.
    """
    serialized = json.dumps(mapping, indent=2)
    with open(filename, mode='w') as out_file:
        out_file.write(serialized)
    return filename

def step4_restore_text(text, mapping):
    """
    Step 4: Map replacement values in ``text`` back to their original values.

    The swap is done in a single regex pass with longer replacements tried first. This
    avoids two problems of sequential ``str.replace``:
      * a restored original being replaced again by a later entry;
      * a short replacement clobbering part of a longer one.

    Entries without a (non-empty) 'replacement' are skipped. If two entries share a
    replacement value, the first entry's original wins.

    Args:
        text: anonymised text.
        mapping: ``{mask: {'original': ..., 'replacement': ...}}`` as filled by step 2.

    Returns:
        The text with every occurrence of each replacement value swapped for its original.
    """
    lookup = {}
    for info in mapping.values():
        replacement = info.get('replacement')
        if replacement:  # '' would match at every position
            lookup.setdefault(replacement, info['original'])

    if not lookup:
        return text

    # Longest first so "Acme Corp" is preferred over a separate "Acme".
    alternatives = sorted(lookup, key=len, reverse=True)
    pattern = re.compile('|'.join(re.escape(value) for value in alternatives))
    return pattern.sub(lambda m: lookup[m.group(0)], text)


def get_random_different_pii(entity_type, original_value):
    """
    Pick a random PER/ORG/LOC placeholder that differs (case-insensitively) from ``original_value``.

    Raises KeyError for any other entity type.
    """
    pools = dict(
        PER=['James Wilson', 'Mary Johnson', 'Robert Brown', 'Sarah Davis',
             'Michael Chen', 'Emily Taylor', 'David Miller', 'Lisa Anderson'],
        ORG=['Acme Corp', 'Global Tech', 'Summit Industries', 'Pioneer Systems',
             'Blue Ridge Solutions', 'Nexus Innovations', 'Quantum Dynamics'],
        LOC=['Chicago', 'Los Angeles', 'Boston', 'Seattle',
             'Austin', 'Denver', 'Portland', 'Miami'],
    )

    pool = pools[entity_type]
    avoid = original_value.lower()
    candidates = list(filter(lambda value: value.lower() != avoid, pool))
    return random.choice(candidates)

class PIITracker:
    """
    Group PII strings into sets of likely variants so each set gets one consistent replacement.

    Set ids look like ``"<TYPE>_<n>"``. PER, ORG and LOC are pre-registered; any other
    entity type is registered the first time it is seen.
    """

    def __init__(self):
        # {entity_type: {original_text: set_id}}
        self.pii_sets = {
            'PER': {},
            'ORG': {},
            'LOC': {}
        }
        # {entity_type: number of sets created so far}
        self.current_sets = {
            'PER': 0,
            'ORG': 0,
            'LOC': 0
        }
        self.replacements = {}  # {set_id: replacement_value}

    def identify_pii(self, original_text, entity_type):
        """
        Return the set id for ``original_text``, joining an existing set or creating one.

        An exact match is checked first, then a variant match via ``_are_similar``. The
        text is recorded under the returned set id.
        """
        pii_dict = self.pii_sets.setdefault(entity_type, {})

        # Check if this exact text was seen before
        if original_text in pii_dict:
            return pii_dict[original_text]

        # Check if this text might be a variant of existing PII
        for known_text, set_id in pii_dict.items():
            if self._are_similar(original_text, known_text):
                pii_dict[original_text] = set_id
                return set_id  # return immediately: the dict was just mutated

        # Create new set if not found
        self.current_sets[entity_type] = self.current_sets.get(entity_type, 0) + 1
        set_id = f"{entity_type}_{self.current_sets[entity_type]}"
        pii_dict[original_text] = set_id
        return set_id

    def _are_similar(self, text1, text2):
        """
        Heuristic variant check: case-insensitive, whitespace-trimmed equality or containment.

        Empty strings only match other empty strings; otherwise "" would be contained in
        (and so "similar" to) every value.
        """
        t1 = text1.lower().strip()
        t2 = text2.lower().strip()

        if not t1 or not t2:
            return t1 == t2

        # TODO: Could add more sophisticated matching here
        # (e.g., fuzzy matching, nickname matching for persons)
        return t1 == t2 or t1 in t2 or t2 in t1

    def add_replacement(self, set_id, replacement):
        """
        Record ``replacement`` for ``set_id`` unless one exists; return the stored value.
        """
        if set_id not in self.replacements:
            self.replacements[set_id] = replacement
        return self.replacements[set_id]

    def get_pii_report(self):
        """
        List every tracked original with its set id, type and replacement.

        Originals without a recorded replacement report "NOT_REPLACED".
        """
        report = []
        for entity_type in self.pii_sets:
            for original, set_id in self.pii_sets[entity_type].items():
                replacement = self.replacements.get(set_id, "NOT_REPLACED")
                report.append({
                    'set_id': set_id,
                    'type': entity_type,
                    'original': original,
                    'replacement': replacement
                })
        return report

def group_entities(entities):
    """
    Merge runs of adjacent, same-type token entities into grouped spans.

    Each input item is a dict with 'entity' (label; every 'I-'/'B-' substring is
    removed), 'word', 'start' and 'end'. Falsy items are skipped. A token joins the
    current group when it starts at most one character after the previous token ended
    and has the same type. Joined words are separated by a single space.

    Returns:
        list of ``{'word', 'type', 'start', 'end'}`` dicts in input order.
    """
    groups = []
    active = None
    previous_end = -1

    for token in entities:
        if not token:  # Skip None / empty placeholders
            continue

        label = token['entity'].replace('I-', '').replace('B-', '')
        continues_active = (
            active is not None
            and token['start'] <= previous_end + 1
            and label == active['type']
        )

        if continues_active:
            active['word'] += ' ' + token['word']
            active['end'] = token['end']
        else:
            if active:
                groups.append(active)
            active = {
                'word': token['word'],
                'type': label,
                'start': token['start'],
                'end': token['end']
            }

        previous_end = token['end']

    if active:
        groups.append(active)

    return groups

def apply_fallback_replacements(text, pii_tracker):
    """
    Replace PER/ORG/LOC entities found by the BERT NER pipeline with fixed fallback names.

    Entities are resolved through ``pii_tracker``, so variants of the same entity (as
    judged by ``PIITracker.identify_pii``) share one replacement across calls.

    Args:
        text: text to anonymise.
        pii_tracker: PIITracker instance; mutated with any new sets and replacements.

    Returns:
        ``text`` with every detected PER/ORG/LOC span replaced.
    """
    fallback_values = {
        'PER': ['Emily Thompson', 'Michael Chen', 'Sarah Davis', 'Robert Wilson'],
        'ORG': ['TechCorp', 'Quantum Systems', 'Pioneer Solutions', 'Global Dynamics'],
        'LOC': ['Boston', 'Seattle', 'Chicago', 'Austin']
    }

    # Reuse the module-level pipeline. Building a new one here reloaded the BERT-large
    # checkpoint on every call and dropped the device argument used at module level.
    # Entities are processed right-to-left so earlier offsets stay valid after splicing.
    raw_entities = sorted(ner_model(text), key=lambda x: x['start'], reverse=True)

    result = text
    for entity in raw_entities:
        entity_type = entity['entity_group']
        if entity_type not in fallback_values:
            continue

        start, end = entity['start'], entity['end']
        # Track the exact source span rather than the decoded 'word', which can
        # differ from the original text.
        set_id = pii_tracker.identify_pii(text[start:end], entity_type)
        replacement = pii_tracker.replacements.get(set_id)

        if not replacement:
            replacement = random.choice(fallback_values[entity_type])
            pii_tracker.add_replacement(set_id, replacement)

        result = result[:start] + replacement + result[end:]

    return result

def clean_output(text):
    """
    Strip prompt-template residue and commentary from a model completion.

    Steps:
      * Truncate at a repeated "### Response:" marker, a "Replacements:" section,
        "Note how" commentary, or a "---" rule.
      * Drop any remaining "###" header lines.
      * Collapse spaces around the dot in bracketed email addresses
        ("[jane . doe@x.com]" -> "[jane.doe@x.com]").

    Returns the input unchanged when it is falsy, otherwise the stripped result.
    """
    if not text:
        return text

    # Truncate at a repeated response marker first: the model sometimes starts a new
    # "### Response:" section. Running the "###" line filter before this would delete
    # only the marker line and keep the runaway text after it.
    text = re.sub(r'\n?### Response:.*$', '', text, flags=re.DOTALL)

    # Remove any lines starting with ###
    text = re.sub(r'\n###.*$', '', text, flags=re.MULTILINE)

    # Remove any "Replacements:" section
    text = re.sub(r'\n?Replacements:.*$', '', text, flags=re.DOTALL)

    # Remove any "Note how" explanatory text
    text = re.sub(r'\n?Note how.*$', '', text, flags=re.DOTALL)

    # Clean up any extra spaces in email addresses
    text = re.sub(r'(\[[\w.]+)\s*\.\s*([\w]+@[\w.]+\])', r'\1.\2', text)

    # Remove any explanatory notes at the end
    text = re.sub(r'\n?---.*$', '', text, flags=re.DOTALL)

    return text.strip()
def replace_entities_with_mistral(text, bracketed_text, pii_tracker=None):
    """
    Ask the vLLM model to replace every [bracketed] entity in ``bracketed_text``.

    The prompt uses Alpaca-style "### Input / ### Rules / ### Response" sections.
    ``apply_fallback_replacements(text, pii_tracker)`` is used instead when any of the
    following happens:
      * the cleaned output is empty;
      * it equals ``text``;
      * it echoes ``bracketed_text`` unchanged;
      * generation raises.

    Args:
        text: original (unbracketed) text; used for the echo check and the fallback.
        bracketed_text: the same text with entities wrapped in [brackets].
        pii_tracker: optional PIITracker shared across calls; a new one is created if None.

    Returns:
        The model's cleaned output, or the fallback-replaced text.
    """
    if pii_tracker is None:
        pii_tracker = PIITracker()

    messages = [
        {"role": "system", "content": "You are a specialized entity replacement system. Replace ALL marked entities in [brackets] with new consistent values."},
        {"role": "user", "content": f"""
Replace ALL entities marked with [brackets] in this text. Each entity MUST be replaced.

### Input:
{bracketed_text}
### Examples:
Input: [John Smith] works at [Microsoft]. [John] enjoys [Seattle]. At [Microsoft], [Smith] is happy.
Output: [Michael Brown] works at [TechCorp]. [Michael] enjoys [Boston]. At [TechCorp], [Brown] is happy.

### Rules:
- Replace each entity consistently throughout the text
- Replace person names with different names keeping same word count
- Replace company names with different company names
- Replace city/location names with different cities
- Replace driver license to another driver license number
- replace Bitcoin key to another bitcoin key
- replace SSN number to another SSN number
- replace password to a different password
- Keep email addresses in proper format without spaces
- Keep alphanumeric keys in exact format length
- Maintain original sentence or paragraph structure and format
- return the replaced entities also in bracket
- DO not rewrite sentences where is no replacement happened, return them as it is!
- Do not add any explanations or comments

### Response:"""}
    ]

    try:
        # vLLM takes a plain prompt string: system text, blank line, user text.
        prompt = messages[0]["content"] + "\n\n" + messages[1]["content"]

        outputs = llm.generate([prompt], sampling_params)
        replaced_text = clean_output(outputs[0].outputs[0].text)

        # Empty output, the raw input, or the bracketed input echoed back all mean no
        # replacement happened. clean_output strips, so compare the stripped input.
        unchanged = (
            not replaced_text
            or replaced_text == text
            or replaced_text == bracketed_text.strip()
        )
        if unchanged:
            print('fallback!\n')
            return apply_fallback_replacements(text, pii_tracker)

        return replaced_text

    except Exception as e:
        print(f"Error in replacement: {str(e)}")
        return apply_fallback_replacements(text, pii_tracker)

def process_pii_pipeline(text):
    """
    Detect PII in ``text`` and build a copy with every occurrence wrapped in [brackets].

    Entities come from ``step1_bert_ner``. Each detected string is bracketed wherever
    it occurs in ``text``, not only where it was detected. Longer strings claim their
    spans first, so e.g. "John Smith" wins over a separately detected "John"; any
    occurrence overlapping an already-claimed span is left alone.

    Returns:
        dict with keys:
          * 'original': the input text;
          * 'bracketed': the text with entities in brackets;
          * 'masked': the step-1 masked text;
          * 'mapping': the step-1 mask mapping.
    """
    masked_text, initial_mapping = step1_bert_ner(text)

    # Normalise spacing around quotes (e.g. "O ' Brien" -> "O'Brien"). Deduplicate while
    # keeping first-seen order, and drop empty strings: "" would match at every position.
    cleaned = (
        re.sub(r'\s*([\'"])\s*', r'\1', info['original'])
        for info in initial_mapping.values()
    )
    unique_entities = [entity for entity in dict.fromkeys(cleaned) if entity]

    # Longest first (stable for ties) so longer entities claim their spans before
    # shorter substrings of them.
    unique_entities.sort(key=len, reverse=True)

    claimed = []  # (start, end) half-open spans to wrap in brackets
    for entity in unique_entities:
        pos = text.find(entity)
        while pos != -1:
            end = pos + len(entity)
            if not any(pos < c_end and end > c_start for c_start, c_end in claimed):
                claimed.append((pos, end))
            pos = text.find(entity, pos + 1)

    # Rebuild the text in one left-to-right pass instead of repeated splicing.
    pieces = []
    cursor = 0
    for start, end in sorted(claimed):
        pieces.append(text[cursor:start])
        pieces.append(f"[{text[start:end]}]")
        cursor = end
    pieces.append(text[cursor:])
    bracketed_text = ''.join(pieces)

    return {
        'original': text,
        'bracketed': bracketed_text,
        'masked': masked_text,
        'mapping': initial_mapping
    }
def extract_email_parts(text):
    """
    Split an instruction-formatted record into ``(email_content, category)``.

    Args:
        text: record with "### Input:" and "### Response:" sections.

    Returns:
        ``email_content``: the stripped text between "### Input:\\n" and the next
        "\\n### Response:", or "" when that section is missing.
        ``category``: the first non-blank line after "### Response:", or
        "Uncategorized" when there is none.
        If extraction raises (e.g. ``text`` is not a string), the error is printed and
        ``(text, "Uncategorized")`` is returned.
    """
    try:
        # Extract the input section (email content)
        input_match = re.search(r'### Input:\n(.*?)(?=\n### Response:)', text, re.DOTALL)
        email_content = input_match.group(1).strip() if input_match else ""

        # \s* skips blank lines/spaces after the marker, so "### Response:\n\nSales"
        # yields "Sales" instead of an empty category.
        response_match = re.search(r'### Response:\s*(.*?)(?:\n|$)', text, re.DOTALL)
        category = response_match.group(1).strip() if response_match else ""

        return email_content, category or "Uncategorized"
    except Exception as e:
        print(f"Extraction error: {str(e)}")
        return text, "Uncategorized"

def prepare_replacement_prompt(text: str) -> str:
    """
    Build the complete vLLM prompt asking the model to replace every [bracketed] entity.

    Layout: the system instruction, a blank line, then the user message. The user
    message holds the input text, one worked example, the replacement rules, and a
    trailing "### Response:" cue for the model to continue from.
    """
    system_part = "You are a specialized entity replacement system. Replace ALL marked entities in [brackets] with new consistent values."
    user_part = f"""Replace ALL entities marked with [brackets] in this text. Each entity MUST be replaced.

### Input:
{text}
### Examples:
Input: [John Smith] works at [Microsoft]. [John] enjoys [Seattle]. At [Microsoft], [Smith] is happy.
Output: [Michael Brown] works at [TechCorp]. [Michael] enjoys [Boston]. At [TechCorp], [Brown] is happy.

### Rules:
- Replace each entity consistently throughout the text
- Replace person names with different names keeping same word count
- Replace company names with different company names
- Replace city/location names with different cities
- Replace driver license to another driver license number
- replace Bitcoin key to another bitcoin key
- replace SSN number to another SSN number
- replace password to a different password
- Keep email addresses in proper format without spaces
- Keep alphanumeric keys in exact format length
- Maintain original sentence or paragraph structure and format
- Only return the replaced text, no explanations or lists
- DO NOT include any "Replacements:" section
- DO NOT add any additional comments or sections

### Response:"""

    return "\n\n".join([system_part, user_part])

from tqdm import tqdm

def save_prompts(prompts, output_dir):
    """
    Write each prompt as a ``{"prompt": ...}`` record to ``output_dir / "prompts.jsonl"``.

    Overwrites any existing file and returns its path.
    """
    destination = output_dir / "prompts.jsonl"
    with jsonlines.open(destination, mode='w') as writer:
        writer.write_all({"prompt": p} for p in prompts)
    return destination

def load_processed_indices(output_dir):
    """
    Return the set of prompt indices already present in ``3_final_output.jsonl``.

    A file with k records yields {0, ..., k-1}. This assumes records were appended in
    prompt order. Returns an empty set when the file does not exist yet.
    """
    done_file = output_dir / "3_final_output.jsonl"
    if not done_file.exists():
        return set()

    with jsonlines.open(done_file) as reader:
        record_count = sum(1 for _ in reader)
    return set(range(record_count))

def _write_bracketed_files(input_file, original_file, bracketed_file, output_dir):
    """Steps 1-2: write original and bracketed records plus prompts.jsonl, one prompt per bracketed record."""
    print("Step 1 & 2: Processing original text and creating bracketed versions...")
    prompts = []

    with jsonlines.open(input_file) as reader, \
         jsonlines.open(original_file, mode='w') as original_writer, \
         jsonlines.open(bracketed_file, mode='w') as bracketed_writer:

        for item in tqdm(reader, desc="Processing entities"):
            text = item.get('text', '')
            if not text:
                continue
            content, category = extract_email_parts(text)
            if not content:
                continue

            original_item = {
                'category': category,
                'text': content
            }
            original_writer.write(original_item)

            try:
                results = process_pii_pipeline(content)
                bracketed_item = {
                    'category': category,
                    'text': results['bracketed']
                }
                prompt = prepare_replacement_prompt(results['bracketed'])
            except Exception as e:
                print(f"Error in bracketing: {str(e)}")
                bracketed_item = original_item
                prompt = prepare_replacement_prompt(content)

            # Write the bracketed record and its prompt together so line i of both
            # files always refers to the same email.
            bracketed_writer.write(bracketed_item)
            prompts.append(prompt)

    save_prompts(prompts, output_dir)


def _generate_final_outputs(prompts_file, bracketed_file, final_output_file, output_dir, batch_size):
    """Step 3: run the prompts through vLLM in batches, appending one output record per prompt in order."""
    with jsonlines.open(prompts_file) as reader:
        all_prompts = [item['prompt'] for item in reader]
    # Read the bracketed records once instead of rescanning the file for every output.
    with jsonlines.open(bracketed_file) as reader:
        bracketed_items = list(reader)

    processed_indices = load_processed_indices(output_dir)

    total_prompts = len(all_prompts)
    total_batches = (total_prompts + batch_size - 1) // batch_size
    for batch_start in range(0, total_prompts, batch_size):
        batch_end = min(batch_start + batch_size, total_prompts)
        batch_number = batch_start // batch_size + 1

        # Skip if all items in this batch are already processed
        if all(i in processed_indices for i in range(batch_start, batch_end)):
            continue

        print(f"\nProcessing batch {batch_number}/{total_batches}")

        try:
            outputs = llm.generate(all_prompts[batch_start:batch_end], sampling_params)
        except Exception as e:
            print(f"Error processing batch {batch_number}: {str(e)}")
            # Stop here so 3_final_output.jsonl remains a contiguous prefix. Resume is
            # keyed on line position, so continuing would misalign every later record;
            # rerun with overwrite=False to pick up from this batch.
            break

        with jsonlines.open(final_output_file, mode='a') as writer:
            for offset, output in enumerate(outputs):
                global_idx = batch_start + offset
                if global_idx in processed_indices:
                    continue

                item = bracketed_items[global_idx]
                try:
                    replaced_text = clean_output(output.outputs[0].text.strip())
                except Exception as e:
                    print(f"Error in final output for index {global_idx}: {str(e)}")
                    replaced_text = ''

                # Always write a record (falling back to the bracketed text) so each
                # output line keeps the index of its prompt.
                writer.write({
                    'category': item['category'],
                    'text': replaced_text if replaced_text else item['text']
                })


def process_jsonl_with_intermediate_files(input_file, output_dir, batch_size=32, overwrite=True):
    """
    Anonymise a JSONL file of instruction-formatted emails, keeping intermediate files.

    Files written to ``output_dir``:
      1_original.jsonl      - extracted email content and category per record
      2_bracketed.jsonl     - the same content with detected PII wrapped in [brackets]
      prompts.jsonl         - one replacement prompt per bracketed record
      3_final_output.jsonl  - one LLM replacement output per prompt, appended in order

    Args:
        input_file: source JSONL. Each record's 'text' field is parsed with
            ``extract_email_parts``; records without extractable content are skipped.
        output_dir: directory for all output files (created if missing).
        batch_size: number of prompts per ``llm.generate`` call.
        overwrite: if True, steps 1-2 are re-run and any previous final output is
            discarded. If False and 2_bracketed.jsonl exists, steps 1-2 are skipped and
            step 3 resumes after the records already in 3_final_output.jsonl.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Define intermediate file paths
    original_file = output_dir / "1_original.jsonl"
    bracketed_file = output_dir / "2_bracketed.jsonl"
    final_output_file = output_dir / "3_final_output.jsonl"
    prompts_file = output_dir / "prompts.jsonl"

    if overwrite or not bracketed_file.exists():
        _write_bracketed_files(input_file, original_file, bracketed_file, output_dir)
        # Resume is keyed on line position in the final output. Lines from an earlier
        # run would not correspond to the regenerated prompts, and appending would mix
        # old and new results.
        if final_output_file.exists():
            final_output_file.unlink()

    print("\nStep 3: Generating final replacements...")
    _generate_final_outputs(prompts_file, bracketed_file, final_output_file, output_dir, batch_size)

    print(f"\nProcessing complete. Results saved in {output_dir}")
    print(f"1. Original content: {original_file}")
    print(f"2. Bracketed versions: {bracketed_file}")
    print(f"3. Final output: {final_output_file}")
    print(f"4. Prompts: {prompts_file}")



In [None]:
# Example usage: run the full pipeline on the synthetic email dataset and write every
# stage into ./data. Point input_jsonl at 'test.jsonl' for a quick smoke run.
if __name__ == "__main__":
    input_jsonl = "data/costco_corporate_synthetic_emails.jsonl"
    output_jsonl = "data"  # an output directory, despite the variable name
    process_jsonl_with_intermediate_files(
        input_file=input_jsonl,
        output_dir=output_jsonl,
        batch_size=600,
        overwrite=True,
    )