# Parallel De-identification with Ollama

This notebook de-identifies CSV data in parallel using a local Ollama model. It offers several approaches to limiting concurrent requests to the Ollama API so that the server is not overloaded and failed requests are retried.

**How to Select an Implementation Approach:**
You can choose the implementation approach by setting the `IMPLEMENTATION_APPROACH` variable in the **Configuration** section below. The default is `'semaphore'`, which is a reliable and performant option for most use cases.

**Troubleshooting Guidance:**
If you experience issues such as timeouts or connection errors, you can adjust the settings in the **Configuration** section. Here are some specific recommendations:

1.  **Reduce `MAX_CONCURRENT_REQUESTS`**: Start by lowering this value to `2` or `1`. This will significantly reduce the load on the Ollama server.
2.  **Try the `queue` approach**: If reducing concurrent requests doesn't solve the issue, change the `IMPLEMENTATION_APPROACH` to `'queue'`. This can provide more stability in some environments.
3.  **Increase `MAX_RETRIES`**: If you are still seeing occasional errors, you can increase the `MAX_RETRIES` value to `5` or `6` to give the script more chances to succeed.
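
To judge how much extra waiting a higher `MAX_RETRIES` can add, the backoff arithmetic can be worked out ahead of time. The sketch below assumes the formula used later in this notebook (sleep `min(2 ** attempt + jitter, 10)` seconds after each failed attempt, with jitter of at most one second). The function name `worst_case_backoff_seconds` is hypothetical and is not part of the notebook.

```python
def worst_case_backoff_seconds(max_retries, cap=10.0, max_jitter=1.0):
    """Upper bound on the total sleep between attempts for one request.

    Assumes that after the k-th failed attempt the code sleeps
    min(2 ** k + jitter, cap) seconds, for k = 1 .. max_retries - 1.
    """
    return sum(min(2 ** k + max_jitter, cap) for k in range(1, max_retries))


for retries in (4, 5, 6):
    print(retries, worst_case_backoff_seconds(retries))
# 4 17.0
# 5 27.0
# 6 37.0
```

These figures cover only the sleeps between attempts. Each attempt can additionally run up to the 120-second request timeout before it fails.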

In [30]:
import pandas as pd
import requests
import json
import math
import os
import time
import random
import queue
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
import threading
from functools import partial
import logging

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)

## 1. Configuration

Before running the notebook, please specify the required variables in the following cell. This includes the path to your input CSV file, the columns to de-identify, and other processing settings.
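
A quick check of the column names before a long run can save time, since the main function only warns about a missing column once it reaches a batch. The following is a minimal sketch using the standard `csv` module. The helper `missing_columns` and the in-memory sample file are illustrative assumptions, not part of the notebook.

```python
import csv
import io


def missing_columns(csv_file, wanted):
    """Return the names in `wanted` that are absent from the CSV header row."""
    header = next(csv.reader(csv_file), [])
    present = set(header)
    return [name for name in wanted if name not in present]


# Hypothetical in-memory file standing in for INPUT_CSV_PATH
sample = io.StringIO("patient_id,first_name,note_text\n1,Ann,ok\n")
print(missing_columns(sample, ["patient_id", "dob", "note_text"]))  # ['dob']
```

With a real file, the same helper could be called on `open(INPUT_CSV_PATH, newline="", encoding="utf-8")` and compared against `COLUMNS_TO_CLEAN`.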

In [31]:
# --- REQUIRED: Specify the path to your CSV file ---
INPUT_CSV_PATH = "tester.csv" 

# --- REQUIRED: List the names of the columns you want to de-identify ---
COLUMNS_TO_CLEAN = ["patient_id", "first_name", "last_name", "dob", "phone_number", "note_text"]

# --- OPTIONAL: Advanced Settings ---
REPLACE_ORIGINAL_COLUMN = True      # True: Replaces original column(s). False: Adds new de-identified column(s).
OUTPUT_PREFIX = "deidentified_output_post_LLM" # The prefix for the output files
MAX_ROWS_PER_BATCH = 100         # The max number of rows to process in a single batch file.

# --- Concurrency Settings ---
MAX_WORKERS = 5                   # Reduced from 10 to avoid overwhelming Ollama
MAX_CONCURRENT_REQUESTS = 3       # Maximum number of concurrent requests to Ollama
USE_RATE_LIMITING = True          # Not read by the code below; set IMPLEMENTATION_APPROACH = "rate_limit" to rate-limit API calls
RATE_LIMIT_CALLS = 3              # Maximum calls per time period
RATE_LIMIT_PERIOD = 1             # Time period in seconds

# --- Processing Settings ---
MAX_RETRIES = 4                    # Number of times to retry processing a row before marking it as "unable to deidentify"
DEIDENTIFICATION_PASSES = 2        # Number of passes through the LLM to catch any PHI missed in the first pass
IMPLEMENTATION_APPROACH = "semaphore"  # Options: "semaphore", "rate_limit", "queue", "process_pool"

# --- Ollama Settings ---
OLLAMA_API_URL = "http://localhost:11434/api/generate"
MODEL_NAME = "gemma3:4b"             # Or any other model you have available
MAX_CHUNK_SIZE = 5000               # Max characters for a single note before it's split for processing.

## 2. Core Functions for Ollama API Interaction

These functions handle the communication with the Ollama API. They include different strategies for managing concurrency to prevent overwhelming the API with too many requests at once.
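
Every strategy below sends the same request body and reads the same field from the reply. As a rough illustration, the two steps can be isolated as small helpers. The names `build_payload` and `extract_text` are hypothetical, and the sample reply is made up for the example rather than captured from a server.

```python
def build_payload(prompt, model="gemma3:4b"):
    """Request body for a single non-streaming generation call."""
    return {"model": model, "prompt": prompt, "stream": False}


def extract_text(response_data):
    """Pull the generated text out of the decoded JSON reply."""
    return response_data.get("response", "").strip()


payload = build_payload("Replace PHI in: John Smith called.")
# Made-up reply with the same 'response' key the notebook reads
fake_reply = {"response": "  [PERSON] called.\n"}
print(payload["stream"], repr(extract_text(fake_reply)))
```

Returning an empty string when the `response` key is missing mirrors the `.get('response', '')` call used in the functions that follow.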

In [32]:
# Create a semaphore to limit concurrent API calls
api_semaphore = threading.Semaphore(MAX_CONCURRENT_REQUESTS)

# Create a rate limiter
class RateLimiter:
    def __init__(self, max_calls, period):
        self.max_calls = max_calls
        self.period = period
        self.calls = []
        self.lock = threading.Lock()
        
    def __enter__(self):
        with self.lock:
            now = time.time()
            # Remove calls older than the period
            self.calls = [t for t in self.calls if now - t < self.period]
            
            # If we've reached the maximum calls, wait until we can make another
            if len(self.calls) >= self.max_calls:
                sleep_time = self.period - (now - self.calls[0])
                if sleep_time > 0:
                    time.sleep(sleep_time)
                    
            # Add the current call time
            self.calls.append(time.time())
            
    def __exit__(self, exc_type, exc_val, exc_tb):
        pass

rate_limiter = RateLimiter(RATE_LIMIT_CALLS, RATE_LIMIT_PERIOD)

# Queue for controlled processing
request_queue = queue.Queue()
response_queue = queue.Queue()

def call_ollama_api_with_semaphore(text_chunk, row_index=None, retry_count=0):
    """
    Sends a text chunk to the Ollama API with semaphore-based concurrency control.
    """
    prompt = create_prompt(text_chunk)
    payload = {
        "model": MODEL_NAME,
        "prompt": prompt,
        "stream": False
    }

    while retry_count < MAX_RETRIES:
        try:
            # Hold the semaphore only for the request itself. Retrying recursively
            # inside the `with` block would re-acquire it while still holding it,
            # which can deadlock (always, when MAX_CONCURRENT_REQUESTS is 1).
            with api_semaphore:
                # Add jitter to avoid thundering herd problem
                time.sleep(random.uniform(0.1, 0.5))
                response = requests.post(OLLAMA_API_URL, json=payload, timeout=120)
            response.raise_for_status()
            response_data = response.json()
            return response_data.get('response', '').strip()
        except requests.exceptions.RequestException as e:
            retry_count += 1
            if row_index is not None:
                print(f"\nError processing row {row_index}: {e} (Attempt {retry_count}/{MAX_RETRIES})")

            # Exponential backoff with jitter, outside the semaphore so other requests can proceed
            if retry_count < MAX_RETRIES:
                backoff_time = min(2 ** retry_count + random.uniform(0, 1), 10)
                time.sleep(backoff_time)

    return "[UNABLE TO DEIDENTIFY: Maximum retry attempts reached]"

def call_ollama_api_with_rate_limit(text_chunk, row_index=None, retry_count=0):
    """
    Sends a text chunk to the Ollama API with rate limiting.
    """
    if retry_count >= MAX_RETRIES:
        return "[UNABLE TO DEIDENTIFY: Maximum retry attempts reached]"
    
    prompt = create_prompt(text_chunk)
    payload = {
        "model": MODEL_NAME,
        "prompt": prompt,
        "stream": False
    }
    
    # Use rate limiting
    with rate_limiter:
        try:
            response = requests.post(OLLAMA_API_URL, json=payload, timeout=120)
            response.raise_for_status()
            response_data = response.json()
            return response_data.get('response', '').strip()
        except requests.exceptions.RequestException as e:
            retry_count += 1
            if row_index is not None:
                print(f"\nError processing row {row_index}: {e} (Attempt {retry_count}/{MAX_RETRIES})")
            
            # Exponential backoff with jitter
            if retry_count < MAX_RETRIES:
                backoff_time = min(2 ** retry_count + random.uniform(0, 1), 10)
                time.sleep(backoff_time)
                return call_ollama_api_with_rate_limit(text_chunk, row_index, retry_count)
            else:
                return "[UNABLE TO DEIDENTIFY: Maximum retry attempts reached]"

def queue_worker():
    """
    Worker function for queue-based processing.
    """
    while True:
        try:
            # Get a task from the queue
            task = request_queue.get(timeout=1)
            if task is None:  # Sentinel value to stop the worker
                request_queue.task_done()
                break
                
            text_chunk, row_index, retry_count = task
            
            # Process the task
            try:
                prompt = create_prompt(text_chunk)
                payload = {
                    "model": MODEL_NAME,
                    "prompt": prompt,
                    "stream": False
                }
                
                # Add jitter to avoid thundering herd problem
                time.sleep(random.uniform(0.1, 0.5))
                response = requests.post(OLLAMA_API_URL, json=payload, timeout=120)
                response.raise_for_status()
                response_data = response.json()
                result = response_data.get('response', '').strip()
                
                # Put the result in the response queue
                response_queue.put((row_index, result))
                
            except requests.exceptions.RequestException as e:
                retry_count += 1
                if row_index is not None:
                    print(f"\nError processing row {row_index}: {e} (Attempt {retry_count}/{MAX_RETRIES})")
                
                # Retry if we haven't reached the maximum retries
                if retry_count < MAX_RETRIES:
                    # Exponential backoff with jitter
                    backoff_time = min(2 ** retry_count + random.uniform(0, 1), 10)
                    time.sleep(backoff_time)
                    request_queue.put((text_chunk, row_index, retry_count))
                else:
                    response_queue.put((row_index, "[UNABLE TO DEIDENTIFY: Maximum retry attempts reached]"))
            
            # Mark the task as done
            request_queue.task_done()
                
        except queue.Empty:
            # If the queue is empty, wait a bit and try again
            time.sleep(0.1)
        except Exception as e:
            logging.error(f"Error in queue worker: {e}")
            request_queue.task_done()

def create_prompt(text_chunk):
    """
    Creates the prompt for the Ollama API with enhanced PHI detection.
    """
    return f"""[SYSTEM]
    You are an automated de-identification system. Your sole function is to process the text provided and return a clean version with all Protected Health Information (PHI) replaced by the appropriate category label. You are precise and do not deviate from your instructions.

    **Instructions:**

    1.  **Identify and Replace:** Your task is to find all instances of the following PHI categories in the user-provided text and replace them with their corresponding labels:
        * Names of individuals: `[PERSON]`
        * All dates (including full dates, partial dates, and days of the week): `[DATE]`
        * Geographical locations (cities, states, addresses, etc.): `[LOCATION]`
        * Phone and fax numbers: `[PHONE]`
        * Email addresses: `[EMAIL]`
        * Any identification numbers including patient IDs, medical record numbers, account numbers, social security numbers, etc.: `[ID_NUMBER]`
        * Other categories such as medical conditions, medications, and health-related information should remain unchanged unless they contain PHI.

    2.  **Output Format:**
        * The output MUST be only the modified text.
        * Do NOT include any introductory phrases, explanations, or apologies.
        * The response should not contain any of the original PHI.
        * If the input is ONLY a number that could be an ID, replace it with [ID_NUMBER].

    **Examples:**

    * **Input:** "John Smith visited the clinic on January 5, 2024. His MRN is 12345. He has HTN and takes lisinopril."
    * **Output:** "[PERSON] visited the clinic on [DATE]. His MRN is [ID_NUMBER]. He has HTN and takes lisinopril."

    * **Input:** "101"
    * **Output:** "[ID_NUMBER]"

    * **Input:** "Patient ID: 12345"
    * **Output:** "Patient ID: [ID_NUMBER]"

    * **Input:** "Patient can be reached at (555) 123-4567 or jane.doe@email.com."
    * **Output:** "Patient can be reached at [PHONE] or [EMAIL]."

    **Text for De-identification:**

    ---
    {text_chunk}
    ---
    """

def deidentify_text(full_text, row_index=None, implementation=IMPLEMENTATION_APPROACH):
    """
    Manages the de-identification of a full text string, handling chunking if necessary.
    Performs multiple passes through the LLM based on DEIDENTIFICATION_PASSES setting.
    Uses the specified implementation approach for API calls.
    """
    if not isinstance(full_text, str) or not full_text.strip():
        return ""
    
    # Select the appropriate API call function based on the implementation approach
    if implementation == "semaphore":
        api_call_func = call_ollama_api_with_semaphore
    elif implementation == "rate_limit":
        api_call_func = call_ollama_api_with_rate_limit
    else:
        # Default to semaphore-based approach
        api_call_func = call_ollama_api_with_semaphore
    
    # Process the text through multiple passes to catch any missed PHI
    processed_text = full_text
    
    for pass_num in range(DEIDENTIFICATION_PASSES):
        if pass_num > 0 and row_index is not None:
            print(f"\r    - Row {row_index}: Pass {pass_num+1}/{DEIDENTIFICATION_PASSES}", end="")
            
        if len(processed_text) <= MAX_CHUNK_SIZE:
            processed_text = api_call_func(processed_text, row_index)
        else:
            chunks = [processed_text[i:i + MAX_CHUNK_SIZE] for i in range(0, len(processed_text), MAX_CHUNK_SIZE)]
            processed_chunks = []
            
            for i, chunk in enumerate(chunks):
                processed_chunk = api_call_func(chunk, row_index)
                processed_chunks.append(processed_chunk)
                
            processed_text = "".join(processed_chunks)
    
    return processed_text

def process_row_parallel(args):
    """
    Helper function for parallel processing of individual rows.
    Returns tuple of (original_index, processed_text) to maintain order.
    """
    row_index, original_text = args
    processed_text = deidentify_text(original_text, row_index, IMPLEMENTATION_APPROACH)
    return (row_index, processed_text)

## 3. Processing Implementations

This section contains different implementations for processing the data. You can choose the one that best fits your needs by setting the `IMPLEMENTATION_APPROACH` variable in the configuration section.

### Semaphore-based Approach
A semaphore is a synchronization primitive that controls access to a shared resource. In this context, it limits the number of concurrent API calls to Ollama. This is a simple and effective way to prevent overloading the API.
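
The effect of the semaphore can be seen without a running Ollama server. The sketch below replaces the HTTP call with a short sleep and records how many simulated calls overlap. The names `fake_request`, `stats`, and `LIMIT` are placeholders chosen for this example.

```python
import threading
import time
from concurrent.futures import ThreadPoolExecutor

LIMIT = 3  # stands in for MAX_CONCURRENT_REQUESTS
semaphore = threading.Semaphore(LIMIT)
stats_lock = threading.Lock()
stats = {"in_flight": 0, "peak": 0}


def fake_request(i):
    """Simulated API call that records how many calls overlap."""
    with semaphore:
        with stats_lock:
            stats["in_flight"] += 1
            stats["peak"] = max(stats["peak"], stats["in_flight"])
        time.sleep(0.05)  # placeholder for the HTTP round trip
        with stats_lock:
            stats["in_flight"] -= 1
    return i


# More worker threads than permits, as with MAX_WORKERS > MAX_CONCURRENT_REQUESTS
with ThreadPoolExecutor(max_workers=8) as pool:
    results = list(pool.map(fake_request, range(12)))

print(f"peak concurrent calls: {stats['peak']} (limit {LIMIT})")
```

Although eight threads are running, the recorded peak never exceeds the number of permits. The extra threads simply wait at the `with semaphore:` line.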

### Rate Limiting Approach
Rate limiting controls the number of requests sent to the API within a specific time period. This approach is useful for APIs that have a strict rate limit policy.
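
The core of a sliding-window rate limiter is a small calculation: given the timestamps of recent calls, how long must the next call wait? The following sketch isolates that step as a pure function so it can be checked by hand. The function name `seconds_until_allowed` is hypothetical, and this is a simplified illustration rather than the exact logic of the `RateLimiter` class in this notebook.

```python
def seconds_until_allowed(call_times, now, max_calls, period):
    """How long to wait before another call fits in the sliding window.

    call_times: timestamps (in seconds) of earlier calls, oldest first.
    """
    recent = [t for t in call_times if now - t < period]
    if len(recent) < max_calls:
        return 0.0
    # A slot frees up once the oldest of the last `max_calls` calls ages out
    oldest_blocking = recent[-max_calls]
    return max(0.0, period - (now - oldest_blocking))


# Three calls per second allowed; three calls already made at 0.0, 0.2, 0.4
print(seconds_until_allowed([0.0, 0.2, 0.4], now=0.5, max_calls=3, period=1.0))  # 0.5
print(seconds_until_allowed([0.0, 0.2, 0.4], now=1.1, max_calls=3, period=1.0))  # 0.0
```

At `now=0.5` the window is full, so the caller waits until the call made at `0.0` is one second old. At `now=1.1` that call has already aged out and no wait is needed.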

### Queue-based Approach
A queue-based approach provides a more controlled way of processing the data. Requests are added to a queue and processed by a pool of workers. This ensures that the number of concurrent requests never exceeds the specified limit and provides a buffer for incoming requests.
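
The pattern can be sketched with a fixed set of worker threads and one sentinel per worker to shut them down. The helper `run_with_queue` is illustrative only, `str.upper` stands in for the real de-identification call, and the sketch assumes the handler does not raise.

```python
import queue
import threading


def run_with_queue(items, handler, num_workers=3):
    """Process (key, value) pairs with a fixed pool of worker threads."""
    tasks = queue.Queue()
    results = {}
    results_lock = threading.Lock()

    def worker():
        while True:
            task = tasks.get()
            if task is None:  # sentinel: no more work for this worker
                break
            key, value = task
            outcome = handler(value)
            with results_lock:
                results[key] = outcome

    threads = [threading.Thread(target=worker) for _ in range(num_workers)]
    for t in threads:
        t.start()
    for item in items:
        tasks.put(item)
    for _ in threads:
        tasks.put(None)  # one sentinel per worker
    for t in threads:
        t.join()  # wait until every worker has exited
    return results


# str.upper is a placeholder for the real de-identification call
output = run_with_queue([(0, "a"), (1, "b"), (2, "c")], str.upper)
print(sorted(output.items()))  # [(0, 'A'), (1, 'B'), (2, 'C')]
```

Because the number of threads is fixed, at most `num_workers` handlers run at once, and keying results by row index lets the caller restore the original order afterwards.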

### Process Pool Approach
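This approach uses a pool of processes to execute the de-identification tasks in parallel. Process pools mainly help CPU-bound work because they use multiple CPU cores. Here most of the time is spent waiting on HTTP responses from Ollama, so expect little gain over threads. Each process also gets its own copy of the semaphore, so `MAX_CONCURRENT_REQUESTS` applies per process rather than across the whole pool.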

In [33]:
def process_with_queue(row_data, total_in_batch):
    """
    Process rows using a queue-based approach for better control over concurrency.
    """
    # Clear the queues
    while not request_queue.empty():
        try:
            request_queue.get_nowait()
            request_queue.task_done()
        except queue.Empty:
            break
            
    while not response_queue.empty():
        try:
            response_queue.get_nowait()
        except queue.Empty:
            break
    
    # Start worker threads
    workers = []
    for _ in range(MAX_CONCURRENT_REQUESTS):
        worker = threading.Thread(target=queue_worker)
        worker.daemon = True
        worker.start()
        workers.append(worker)
    
    # Add tasks to the queue
    for row_index, original_text in row_data:
        request_queue.put((original_text, row_index, 0))
    
    # Wait for all tasks to be processed
    processed_results = {}
    completed_count = 0
    
    while completed_count < len(row_data):
        try:
            row_index, processed_text = response_queue.get(timeout=1)
            processed_results[row_index] = processed_text
            completed_count += 1
            
            # Print progress
            print(f"\r    - Row {completed_count}/{total_in_batch}", end="")
            
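        except queue.Empty:
            # No result yet; the get() timeout above already paces this loop
            continue

    # Stop the workers: one sentinel per thread, then wait for them to exit.
    # Without this, the worker threads keep polling the queue after the batch is done.
    for _ in workers:
        request_queue.put(None)
    for worker in workers:
        worker.join()

    return processed_results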

## 4. Process Pool Implementation

This cell defines the process-pool variant. It submits each row to a `ProcessPoolExecutor` and collects the results as they complete.

In [34]:
def process_with_process_pool(row_data, total_in_batch):
    """
    Process rows using ProcessPoolExecutor for potentially better performance on CPU-bound tasks.
    """
    # Initialize results dictionary to maintain order
    processed_results = {}
    completed_count = 0
    
    # Submit the module-level helper process_row_parallel: ProcessPoolExecutor
    # pickles the callable, and a function nested inside this one cannot be pickled.
    # With IMPLEMENTATION_APPROACH = "process_pool", deidentify_text falls back to the
    # semaphore-based call; note that each worker process has its own semaphore.
    with ProcessPoolExecutor(max_workers=MAX_WORKERS) as executor:
        # Submit all tasks
        future_to_index = {executor.submit(process_row_parallel, (idx, text)): idx for idx, text in row_data}
        
        # Process completed tasks as they finish
        for future in as_completed(future_to_index):
            original_index, processed_text = future.result()
            processed_results[original_index] = processed_text
            completed_count += 1
            
            # Print progress
            print(f"\r    - Row {completed_count}/{total_in_batch}", end="")
    
    return processed_results

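## 5. Main Processing Function

This is the main function that orchestrates the de-identification process. It reads the input CSV, splits it into batches, and processes each batch using the selected implementation approach. Running the cell only defines the function; the next section calls it.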
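
The batching arithmetic inside the main function is simple enough to check by hand. As a rough sketch, the generator below lists which rows land in which batch for a given batch size. The name `batch_ranges` is hypothetical, and the row numbers are 0-based positions in the data, not counting the header.

```python
import math


def batch_ranges(total_rows, batch_size):
    """Yield (batch_number, first_row, last_row) with 1-based batch numbers."""
    num_batches = math.ceil(total_rows / batch_size)
    for i in range(num_batches):
        start = i * batch_size
        end = min(start + batch_size, total_rows) - 1
        yield i + 1, start, end


print(list(batch_ranges(250, 100)))
# [(1, 0, 99), (2, 100, 199), (3, 200, 249)]
```

With `MAX_ROWS_PER_BATCH = 100`, a 250-row file therefore produces three output files, the last one holding the remaining 50 rows.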

In [35]:
def process_large_csv_complete(input_path, output_prefix, columns_to_clean):
    """
    Reads a potentially large CSV, splits it into batches, de-identifies the 
    specified columns, and saves new numbered CSV files for each batch.
    Uses the specified implementation approach for concurrency control.
    """
    if not os.path.exists(input_path) or input_path == "/path/to/your/file.csv":
        print("ERROR: Input file not found or path not set.")
        print(f"Please update the 'INPUT_CSV_PATH' variable in the Configuration section.")
        return
    
    # Get the directory of the input file to save output files in the same location
    output_dir = os.path.dirname(os.path.abspath(input_path))
    
    try:
        print(f"Reading and preparing CSV from '{input_path}'...")
        df_iterator = pd.read_csv(input_path, chunksize=MAX_ROWS_PER_BATCH, on_bad_lines='warn')
        with open(input_path, 'r', encoding='utf-8', errors='ignore') as f:
            total_rows = sum(1 for row in f) - 1 # -1 for header
        num_batches = math.ceil(total_rows / MAX_ROWS_PER_BATCH)

    except Exception as e:
        print(f"Error reading input file: {e}")
        return

    print(f"Total rows: {total_rows}. This will be processed in {num_batches} batch(es).")
    print(f"Output files will be saved in: {output_dir}")
    print(f"Each text will undergo {DEIDENTIFICATION_PASSES} pass(es) through the LLM.")
    print(f"Using implementation approach: {IMPLEMENTATION_APPROACH}")
    print(f"Maximum concurrent requests to Ollama: {MAX_CONCURRENT_REQUESTS}")
    
    # Track processed batches for summary
    processed_batches = []
    skipped_batches = []

    for i, batch_df in enumerate(df_iterator):
        batch_num = i + 1
        output_path = os.path.join(output_dir, f"{output_prefix}_part_{batch_num}.csv")
        
        if os.path.exists(output_path):
            print(f"\nOutput file '{output_path}' already exists. Skipping Batch {batch_num}.")
            skipped_batches.append(batch_num)
            continue

        print(f"\n--- Processing Batch {batch_num}/{num_batches} ---")
        
        for column_name in columns_to_clean:
            if column_name not in batch_df.columns:
                print(f"  - WARNING: Column '{column_name}' not found in this batch. Skipping.")
                continue
            
            print(f"  - De-identifying column: '{column_name}' using {IMPLEMENTATION_APPROACH} approach")
            
            total_in_batch = len(batch_df)
            
            # Prepare data for processing: (row_index, text_content)
            row_data = [(idx, str(row[column_name])) for idx, row in batch_df.iterrows()]
            
            # Initialize results dictionary to maintain order
            processed_results = {}
            completed_count = 0
            
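            if IMPLEMENTATION_APPROACH == "queue":
                # Dedicated worker threads pull rows from a queue.
                # Note: this path sends each cell once, without chunking or repeated passes.
                processed_results = process_with_queue(row_data, total_in_batch)
            elif IMPLEMENTATION_APPROACH == "process_pool":
                processed_results = process_with_process_pool(row_data, total_in_batch)
            else:
                # "semaphore" and "rate_limit": thread pool, with the limiter applied inside deidentify_text
                with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
                    # Submit all tasks
                    future_to_index = {executor.submit(process_row_parallel, data): data[0] for data in row_data}

                    # Process completed tasks as they finish
                    for future in as_completed(future_to_index):
                        original_index, processed_text = future.result()
                        processed_results[original_index] = processed_text
                        completed_count += 1

                        # Print progress
                        print(f"\r    - Row {completed_count}/{total_in_batch}", end="")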
            
            print()  # Newline after the progress bar for a column is complete
            
            # Reconstruct the processed data in the original order
            processed_data = [processed_results[idx] for idx, _ in batch_df.iterrows()]
            
            # Update the dataframe with processed data
            if REPLACE_ORIGINAL_COLUMN:
                batch_df[column_name] = processed_data
            else:
                batch_df[f"{column_name}_deidentified"] = processed_data
        
        # Save the processed batch to CSV
        try:
            batch_df.to_csv(output_path, index=False)
            print(f"  - Saved batch to: '{output_path}'")
            processed_batches.append(batch_num)
        except Exception as e:
            print(f"  - ERROR: Could not save batch {batch_num}: {e}")
    
    print(f"\n--- Processing Complete ---")
    print(f"All batches have been processed and saved with prefix '{output_prefix}'.")
    
    # Print summary of processed and skipped batches
    if processed_batches:
        print(f"\nProcessed batches: {', '.join(map(str, processed_batches))}")
    if skipped_batches:
        print(f"Skipped batches (already existed): {', '.join(map(str, skipped_batches))}")
    
    # Print the location of the output files
    print(f"\nOutput files are located in: {output_dir}")
    print(f"File naming pattern: {output_prefix}_part_X.csv where X is the batch number")
    
    # List the output files that exist
    existing_output_files = [f for f in os.listdir(output_dir) if f.startswith(output_prefix) and f.endswith('.csv')]
    if existing_output_files:
        print(f"\nFound {len(existing_output_files)} output files")
        for file in sorted(existing_output_files):
            file_path = os.path.join(output_dir, file)
            file_size = os.path.getsize(file_path) / (1024 * 1024)  # Convert to MB
            print(f"  - {file} ({file_size:.2f} MB)")
    else:
        print(f"\nNo output files found with prefix '{output_prefix}' in {output_dir}")

## 6. Run the Process

Execute the main function. This will start the process using the file and columns you specified in the Configuration section.

In [36]:
# This cell runs the main function with the current settings.
process_large_csv_complete(
    input_path=INPUT_CSV_PATH,
    output_prefix=OUTPUT_PREFIX, 
    columns_to_clean=COLUMNS_TO_CLEAN
)

Reading and preparing CSV from 'tester.csv'...
Total rows: 8. This will be processed in 1 batch(es).
Output files will be saved in: /Users/david/Documents/Documents - David’s iMac/GitHub/deid-ollama
Each text will undergo 2 pass(es) through the LLM.
Using implementation approach: semaphore
Maximum concurrent requests to Ollama: 3

--- Processing Batch 1/1 ---
  - De-identifying column: 'patient_id' using semaphore approach
    - Row 8/8Pass 2/2
  - De-identifying column: 'first_name' using semaphore approach
    - Row 8/8Pass 2/2
  - De-identifying column: 'last_name' using semaphore approach
    - Row 8/8Pass 2/2
  - De-identifying column: 'dob' using semaphore approach
    - Row 8/8Pass 2/2
  - De-identifying column: 'phone_number' using semaphore approach
    - Row 8/8Pass 2/2
  - De-identifying column: 'note_text' using semaphore approach
    - Row 8/8Pass 2/2
  - Saved batch to: '/Users/david/Documents/Documents - David’s iMac/GitHub/deid-ollama/deidentified_output_post_LLM_part_1