# Gemini Email Anonymization Pipeline

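A Python pipeline that uses the Google Gemini API to detect sensitive entities in
German email texts and replace them with `<<LABEL>>` placeholders drawn from 21
custom labels (e.g. `VORNAME`, `IBAN`). The character span of each replaced entity
is recovered and written out together with the anonymized text (JSON or plain text).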


In [None]:
#!/usr/bin/env python3

import os
import json
import re
import argparse
import logging
from typing import List, Dict, Optional, Tuple
from pathlib import Path
import google.generativeai as genai
from dataclasses import dataclass

# Module-wide logging: INFO level and above, each line timestamped.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)

@dataclass
class EntitySpan:
    """Represents an entity span with start, end positions and label."""
    start: int
    end: int
    label: str

class GeminiAnonymizer:
    """Anonymize German email texts with the Google Gemini API.

    The model is prompted to replace every sensitive entity with a
    ``<<LABEL>>`` placeholder drawn from :attr:`TARGET_LABELS`.  Character
    spans of the replaced entities are then recovered by aligning the
    response's literal text with the original (:meth:`extract_spans`), and
    results are written to ``output_dir`` as JSON or plain text.
    """

    # Target labels for anonymization; the prompt lists exactly these.
    TARGET_LABELS = [
        'NACHNAME', 'VORNAME', 'STRASSE', 'POSTLEITZAHL', 'WOHNORT',
        'HAUSNUMMER', 'VERTRAGSNUMMER', 'DATUM', 'ZÄHLERNUMMER',
        'TELEFONNUMMER', 'GESENDET_MIT', 'ZAHLUNG', 'FIRMA', 'TITEL',
        'EMAIL', 'ZÄHLERSTAND', 'LINK', 'IBAN', 'BANK', 'BIC', 'FAX'
    ]

    # A placeholder in the model output.  The character class must include
    # the German umlauts: with plain [A-Z_] the labels ZÄHLERNUMMER and
    # ZÄHLERSTAND are never recognised, so their placeholders are treated as
    # ordinary text and span alignment breaks.
    _PLACEHOLDER_RE = re.compile(r'<<([A-ZÄÖÜ_]+)>>')

    # A whole response wrapped in a Markdown code fence (```lang\n...\n```).
    _CODE_FENCE_RE = re.compile(r'\A```[^\n]*\n(.*?)\n?```\Z', re.DOTALL)

    # Directory used by the write_* methods when no directory is passed.
    # Defined at class level so it also exists on instances created without
    # __init__; override per instance via the constructor or assignment.
    output_dir: str = "output"

    def __init__(self, api_key: str, model_name: str = "gemini-1.5-flash",
                 output_dir: str = "output"):
        """Configure Gemini and select the model.

        Args:
            api_key: Google Gemini API key.  ``genai.configure`` is a
                module-level call, so the key applies to the ``genai`` module
                as a whole, not just this instance.
            model_name: Name of the Gemini model to use.
            output_dir: Default directory for files written by this instance.
        """
        genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel(model_name)
        self.output_dir = output_dir

    def _create_anonymization_prompt(self, text: str) -> str:
        """Build the instruction prompt that asks Gemini to anonymize *text*."""
        labels_str = "', '".join(self.TARGET_LABELS)

        prompt = f"""You are an expert text anonymization system. Your task is to identify and replace sensitive entities in German email text with anonymization placeholders.

IMPORTANT RULES:
1. Only replace entities that match these {len(self.TARGET_LABELS)} labels: ['{labels_str}']
2. Replace each detected entity with <<LABEL>> (e.g., <<VORNAME>>, <<NACHNAME>>)
3. Preserve the original text structure, formatting, and whitespace exactly
4. Do not modify any text that doesn't contain sensitive entities
5. Return ONLY the anonymized text, no explanations or metadata

LABEL DEFINITIONS:
- NACHNAME: Last names/surnames
- VORNAME: First names/given names
- STRASSE: Street names
- POSTLEITZAHL: Postal codes
- WOHNORT: City/town names
- HAUSNUMMER: House numbers
- VERTRAGSNUMMER: Contract numbers, all other seemingly sensitive numbers that don't fall into another category (e.g., OR-0001239759)
- DATUM: Dates in any format, also only Months, or "10ter" ("tenth") of a month
- ZÄHLERNUMMER: Meter numbers
- TELEFONNUMMER: Phone numbers
- GESENDET_MIT: "Sent with" messages (e.g., "Diese Nachricht wurde von meinem Android...")
- ZAHLUNG: Payment information, Euros/ Cents etc.
- FIRMA: Company names
- TITEL: Titles (Mr., Mrs., Dr., Dipl., etc.)
- EMAIL: Email addresses
- ZÄHLERSTAND: Meter readings and differences in m3 or kwh etc.
- LINK: URLs and web links
- IBAN: International bank account numbers
- BANK: Bank names
- BIC: Bank identifier codes
- FAX: Fax numbers

EXAMPLE:
Input: "Kunde: Gertraude Wagenknecht-Speer"
Output: "Kunde: <<VORNAME>> <<NACHNAME>>"

TEXT TO ANONYMIZE:
{text}

ANONYMIZED TEXT:"""

        return prompt

    def _clean_response(self, original_text: str, response_text: str) -> str:
        """Normalize a raw model response so it lines up with *original_text*.

        Removes a surrounding Markdown code fence, or a pair of double quotes
        (the prompt's example shows quoted input/output) unless the original
        itself starts with a quote.  Then re-attaches the original's leading
        and trailing whitespace, which ``strip()`` removes.  Identical edges
        let :meth:`extract_spans` align both texts from start and end.
        """
        body = response_text.strip()
        fenced = self._CODE_FENCE_RE.match(body)
        if fenced:
            body = fenced.group(1).strip()
        if (len(body) >= 2 and body[0] == body[-1] == '"'
                and not original_text.strip().startswith('"')):
            body = body[1:-1].strip()

        lead_len = len(original_text) - len(original_text.lstrip())
        rest = original_text[lead_len:]
        trailing = rest[len(rest.rstrip()):]
        return original_text[:lead_len] + body + trailing

    def _request_anonymization(self, text: str) -> str:
        """Anonymize *text* with Gemini, raising on any failure.

        Unlike :meth:`call_gemini_api` this never falls back to the input, so
        an un-anonymized text can never be mistaken for a result.

        Raises:
            ValueError: if the response contains no text.  Accessing
                ``response.text`` may itself raise ``ValueError`` (e.g. for a
                blocked response); client/network errors propagate unchanged.
        """
        if not text.strip():
            # Nothing to anonymize; skip the API round-trip.
            return text
        response = self.model.generate_content(self._create_anonymization_prompt(text))
        raw = response.text
        if not raw or not raw.strip():
            raise ValueError("Empty response from Gemini API")
        return self._clean_response(text, raw)

    def call_gemini_api(self, text: str) -> str:
        """Call the Gemini API to anonymize the given text.

        Best-effort wrapper kept for existing callers: on any error (including
        an empty response) the error is logged and *text* is returned
        UNCHANGED, i.e. NOT anonymized.  :meth:`process_file` does not use
        this fallback, so API failures there are reported as failures.
        """
        try:
            return self._request_anonymization(text)
        except Exception as e:
            logger.error(f"Error calling Gemini API: {e}")
            return text

    def extract_spans(self, original_text: str, anonymized_text: str) -> List[Dict]:
        """Recover entity spans by aligning *anonymized_text* with the original.

        The anonymized text is split into alternating literal segments and
        ``<<LABEL>>`` placeholders.  The literal segments are located in the
        original in order, and the original text between two consecutive
        literals is the entity that placeholder replaced.  Each literal is
        found with one forward ``str.find`` (linear, no backtracking), so the
        first occurrence wins.  For example, in ``<<VORNAME>> <<NACHNAME>>`` the
        first space after the given name separates the two entities.  The last
        placeholder's trailing literal is matched against the end of the text.

        Args:
            original_text: Text before anonymization.
            anonymized_text: The same text with entities replaced by
                ``<<LABEL>>`` placeholders.

        Returns:
            ``{"start", "end", "label"}`` dicts (``end`` exclusive) in order of
            appearance, with surrounding whitespace trimmed from each entity.
            If the model altered non-entity text so that alignment breaks,
            the spans found up to that point are returned.  Two placeholders
            with no text between them cannot be separated; the combined
            region is attributed to the later label.
        """
        # re.split with one capturing group yields
        # [literal0, label1, literal1, ..., labelK, literalK].
        parts = self._PLACEHOLDER_RE.split(anonymized_text)
        literals = parts[0::2]
        labels = parts[1::2]
        spans: List[Dict] = []
        if not labels:
            return spans

        if not original_text.startswith(literals[0]):
            logger.warning("Anonymized text does not start like the original; cannot align spans")
            return spans
        cursor = len(literals[0])

        for index, label in enumerate(labels):
            following = literals[index + 1]
            is_last = index == len(labels) - 1

            if (is_last and original_text.endswith(following)
                    and len(original_text) - len(following) > cursor):
                # Trailing literal anchors the final entity at the text end.
                end = len(original_text) - len(following)
            elif following:
                # cursor + 1: an entity covers at least one character.
                end = original_text.find(following, cursor + 1)
                if end == -1:
                    logger.warning(
                        f"Could not align text after <<{label}>>; "
                        f"returning {len(spans)} span(s) found so far"
                    )
                    break
            elif is_last:
                end = len(original_text)
            else:
                # Adjacent placeholders: nothing to split on, so the next
                # placeholder's span starts at the same cursor.
                continue

            start, stop = cursor, end
            while start < stop and original_text[start].isspace():
                start += 1
            while stop > start and original_text[stop - 1].isspace():
                stop -= 1
            if stop > start:
                spans.append({"start": start, "end": stop, "label": label})
            cursor = end + len(following)

        return spans

    def anonymize_text(self, text: str, labels: List[Dict]) -> str:
        """Replace each labelled span of *text* with its ``<<LABEL>>`` placeholder.

        Intended for manual overrides.  Spans are applied from the highest
        start offset down so that earlier offsets stay valid.  Spans are
        assumed not to overlap; overlapping spans would corrupt the result.
        """
        anonymized = text
        for label_info in sorted(labels, key=lambda x: x['start'], reverse=True):
            start, end = label_info['start'], label_info['end']
            anonymized = anonymized[:start] + f"<<{label_info['label']}>>" + anonymized[end:]
        return anonymized

    def write_json(self, file_name: str, original_text: str, anonymized_text: str,
                   labels: List[Dict], output_dir: Optional[str] = None):
        """Write one result record to ``<output_dir>/<stem of file_name>.json``.

        Args:
            file_name: Name of the processed input file.
            original_text: Text before anonymization.
            anonymized_text: Text with ``<<LABEL>>`` placeholders.
            labels: Span dicts as returned by :meth:`extract_spans`.
            output_dir: Target directory; defaults to ``self.output_dir``.
        """
        output_dir = self.output_dir if output_dir is None else output_dir
        os.makedirs(output_dir, exist_ok=True)

        output_data = {
            "file": file_name,
            "text": original_text,
            "anonymized_text": anonymized_text,
            "labels": labels
        }

        output_file = os.path.join(output_dir, f"{Path(file_name).stem}.json")
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, ensure_ascii=False, indent=2)

        logger.info(f"JSON output written to {output_file}")

    def process_file(self, file_path: str, output_format: str = "json") -> Dict:
        """Anonymize one UTF-8 text file and write the result.

        Args:
            file_path: Path of the file to read.
            output_format: ``"json"`` (text, anonymized text and spans) or
                ``"txt"`` (anonymized text only).

        Returns:
            On success ``{"file", "original_text", "anonymized_text",
            "labels", "success": True}``; on any error (unreadable file,
            unsupported format, failed API call) ``{"file", "success": False,
            "error"}``, and nothing is written.
        """
        logger.info(f"Processing file: {file_path}")
        file_name = os.path.basename(file_path)

        try:
            if output_format not in ("json", "txt"):
                raise ValueError(f"Unsupported output format: {output_format!r}")

            with open(file_path, 'r', encoding='utf-8') as f:
                original_text = f.read()

            # Raising variant: an API failure must not be written out as an
            # "anonymized" copy of the original text.
            anonymized_text = self._request_anonymization(original_text)
            labels = self.extract_spans(original_text, anonymized_text)

            if output_format == "json":
                self.write_json(file_name, original_text, anonymized_text, labels)
            else:
                self.write_txt(file_name, anonymized_text)

            return {
                "file": file_name,
                "original_text": original_text,
                "anonymized_text": anonymized_text,
                "labels": labels,
                "success": True
            }

        except Exception as e:
            logger.error(f"Error processing file {file_path}: {e}")
            return {
                "file": file_name,
                "success": False,
                "error": str(e)
            }

    def write_txt(self, file_name: str, anonymized_text: str, output_dir: Optional[str] = None):
        """Write the anonymized text to ``<output_dir>/<file_name>``.

        Args:
            file_name: Output file name (the input's base name).
            anonymized_text: Text to write.
            output_dir: Target directory; defaults to ``self.output_dir``.
                It must differ from the input directory, or the input file
                is overwritten.
        """
        output_dir = self.output_dir if output_dir is None else output_dir
        os.makedirs(output_dir, exist_ok=True)

        output_file = os.path.join(output_dir, file_name)
        with open(output_file, 'w', encoding='utf-8') as f:
            f.write(anonymized_text)

        logger.info(f"Anonymized text written to {output_file}")

    def process_directory(self, input_dir: str, output_format: str = "json") -> List[Dict]:
        """Process every ``*.txt`` file directly inside *input_dir*.

        Files are handled in sorted order so runs are reproducible.  For JSON
        output with more than one file, a ``batch_summary.json`` is written too.

        Returns:
            One :meth:`process_file` result per file ([] if there are none).

        Raises:
            FileNotFoundError: if *input_dir* does not exist.
        """
        input_path = Path(input_dir)

        if not input_path.exists():
            raise FileNotFoundError(f"Input directory '{input_dir}' does not exist")

        # Path.glob yields entries in filesystem order, which is arbitrary.
        txt_files = sorted(input_path.glob("*.txt"))

        if not txt_files:
            logger.warning(f"No .txt files found in {input_dir}")
            return []

        results = [self.process_file(str(file_path), output_format) for file_path in txt_files]

        if len(results) > 1 and output_format == "json":
            self.write_batch_summary(results)

        return results

    def write_batch_summary(self, results: List[Dict], output_dir: Optional[str] = None):
        """Write ``batch_summary.json`` with counts and per-file anonymized text.

        Args:
            results: Dicts returned by :meth:`process_file`.
            output_dir: Target directory; defaults to ``self.output_dir``.
        """
        output_dir = self.output_dir if output_dir is None else output_dir
        os.makedirs(output_dir, exist_ok=True)

        successful = [r for r in results if r.get("success", False)]
        summary = {
            "total_files": len(results),
            "successful": len(successful),
            "failed": len(results) - len(successful),
            "files": {r["file"]: r["anonymized_text"] for r in successful}
        }

        summary_file = os.path.join(output_dir, "batch_summary.json")
        with open(summary_file, 'w', encoding='utf-8') as f:
            json.dump(summary, f, ensure_ascii=False, indent=2)

        logger.info(f"Batch summary written to {summary_file}")

def main(argv: Optional[List[str]] = None) -> int:
    """Run the anonymization pipeline from the command line.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]``.  Useful when
            calling from a notebook, whose ``sys.argv`` holds kernel flags.

    Returns:
        Exit code: 0 on success; 1 if setup fails, the input path is neither
        a file nor a directory, or any file could not be processed.
    """
    parser = argparse.ArgumentParser(description="Gemini Email Anonymization Pipeline")
    # Default paths must go into ``default=``.  The first argument of
    # add_argument is the argument's *name*, so passing a path there leaves
    # ``args.input`` undefined.
    # NOTE(review): the default input is a .json file, which process_file
    # reads and sends to the model as plain text — confirm this is intended.
    parser.add_argument("input", nargs="?",
                        default="../../../data/original/ground_truth_split/test_norm.json",
                        help="Input file or directory containing .txt files")
    # Fall back to the environment so the key need not appear on the command
    # line (shell history, process listings).
    env_api_key = os.environ.get("GOOGLE_API_KEY") or os.environ.get("GEMINI_API_KEY")
    parser.add_argument("--api-key", default=env_api_key, required=env_api_key is None,
                        help="Google Gemini API key (default: $GOOGLE_API_KEY or $GEMINI_API_KEY)")
    parser.add_argument("--model", default="gemini-1.5-flash", help="Gemini model to use")
    parser.add_argument("--output-format", choices=["json", "txt"], default="json",
                        help="Output format (json or txt)")
    parser.add_argument("--output-dir", default="../../../data/testing/gemini_results",
                        help="Output directory")

    args = parser.parse_args(argv)
    input_path = Path(args.input)

    try:
        # Constructed inside the try so setup errors are logged and reported
        # through the exit code like any other failure.
        anonymizer = GeminiAnonymizer(args.api_key, args.model)
        # Directory for the anonymizer's output files.
        anonymizer.output_dir = args.output_dir

        if input_path.is_file():
            result = anonymizer.process_file(str(input_path), args.output_format)
            if not result["success"]:
                logger.error(f"Failed to process file: {result.get('error', 'Unknown error')}")
                return 1
            logger.info("File processed successfully")

        elif input_path.is_dir():
            results = anonymizer.process_directory(str(input_path), args.output_format)
            successful = sum(1 for r in results if r.get("success", False))
            logger.info(f"Processed {successful}/{len(results)} files successfully")
            if successful < len(results):
                return 1

        else:
            logger.error(f"Input path '{args.input}' is neither a file nor a directory")
            return 1

    except Exception as e:
        logger.error(f"Error: {e}")
        return 1

    return 0

if __name__ == "__main__":
    # Same as sys.exit(main()).  The bare exit() builtin is injected by the
    # site module for interactive use and is absent under ``python -S``.
    raise SystemExit(main())