# Improved Parallel De-identification with Ollama

This notebook offers several alternative ways to send parallel requests to Ollama without triggering errors from too many simultaneous requests. It builds on the original de-identification notebook and adds more robust concurrency control.

## Key Improvements

1. **Rate Limiting**: Caps how many requests are sent to Ollama per time window
2. **Semaphore-based Concurrency Control**: Limits the number of concurrent API calls
3. **Backoff Strategy**: Implements exponential backoff with jitter for retries
4. **Queue-based Processing**: A fixed pool of worker threads pulls rows from a shared queue
5. **Multiple Implementation Options**: Choose the approach that works best for your system

Select an approach by setting `IMPLEMENTATION_APPROACH` in the configuration cell below.
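To make the backoff strategy concrete, the sketch below reproduces the delay formula used by the retry code in this notebook, `min(2 ** retry_count + random.uniform(0, 1), 10)`, and lists the delay range before each retry. The helper names `backoff_delay` and `backoff_schedule` are hypothetical and are not part of the notebook.

```python
import random

def backoff_delay(attempt, cap=10.0, rng=random):
    """Delay in seconds before retry number `attempt` (1-based):
    exponential growth plus up to 1 s of random jitter, capped at `cap`."""
    return min(2 ** attempt + rng.uniform(0, 1), cap)

def backoff_schedule(max_retries, cap=10.0):
    """(low, high) delay bounds before each retry. A request is attempted at most
    `max_retries` times, so there are max_retries - 1 backoff waits."""
    return [(min(2 ** a, cap), min(2 ** a + 1, cap)) for a in range(1, max_retries)]

# With MAX_RETRIES = 4 there are three waits of roughly 2-3 s, 4-5 s and 8-9 s
for attempt, (low, high) in enumerate(backoff_schedule(4), start=1):
    print(f"retry {attempt}: {low:.0f}-{high:.0f} s (sample: {backoff_delay(attempt):.2f} s)")
```

The jitter spreads out retries from threads that failed at the same moment, and the cap keeps a long outage from producing multi-minute sleeps.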

In [None]:
import pandas as pd
import requests
import json
import math
import os
import time
import random
import queue
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
import threading
from functools import partial
import logging

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)

## 1. Configuration (User Input Required)

**Please fill in the variables in the cell below before running the notebook.**
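Before a long run it can help to fail fast on configuration mistakes. The sketch below is a hypothetical helper, not part of the original notebook. It checks `IMPLEMENTATION_APPROACH` against the four documented options and looks for `MODEL_NAME` in the model list from Ollama's `GET /api/tags` endpoint. The assumed response shape, `{"models": [{"name": ...}, ...]}`, follows Ollama's API documentation. The network call is kept separate (using the standard library's `urllib` rather than `requests`) so the parsing logic can be checked offline.

```python
import json
import urllib.request

VALID_APPROACHES = {"semaphore", "rate_limit", "queue", "process_pool"}

def validate_settings(approach, max_workers, max_concurrent_requests):
    """Raise ValueError for settings the notebook cannot use."""
    if approach not in VALID_APPROACHES:
        raise ValueError(
            f"Unknown IMPLEMENTATION_APPROACH {approach!r}; expected one of {sorted(VALID_APPROACHES)}"
        )
    if max_workers < 1 or max_concurrent_requests < 1:
        raise ValueError("MAX_WORKERS and MAX_CONCURRENT_REQUESTS must be at least 1")

def model_available(tags_json, model_name):
    """True if model_name (including its tag, e.g. 'gemma3:4b') is in a parsed /api/tags response."""
    return any(m.get("name") == model_name for m in tags_json.get("models", []))

def fetch_tags(base_url="http://localhost:11434"):
    """Fetch the locally installed models from a running Ollama server."""
    with urllib.request.urlopen(f"{base_url}/api/tags", timeout=10) as resp:
        return json.load(resp)
```

After running the configuration cell, a check might look like `validate_settings(IMPLEMENTATION_APPROACH, MAX_WORKERS, MAX_CONCURRENT_REQUESTS)` followed by `model_available(fetch_tags(), MODEL_NAME)`.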

In [None]:
# --- REQUIRED: Specify the path to your CSV file ---
INPUT_CSV_PATH = "/Users/dli989/Documents/RECOVER-local/liebovitz-llm_deidentified_notes.csv" 

# --- REQUIRED: List the names of the columns you want to de-identify ---
COLUMNS_TO_CLEAN = ["note_text"] 

# --- OPTIONAL: Advanced Settings ---
REPLACE_ORIGINAL_COLUMN = True      # True: Replaces original column(s). False: Adds new de-identified column(s).
OUTPUT_PREFIX = "deidentified_output_post_LLM" # The prefix for the output files
MAX_ROWS_PER_BATCH = 100         # The max number of rows to process in a single batch file.

# --- Concurrency Settings ---
MAX_WORKERS = 5                   # Reduced from 10 to avoid overwhelming Ollama
MAX_CONCURRENT_REQUESTS = 3       # Maximum number of concurrent requests to Ollama
USE_RATE_LIMITING = True          # Not read by the functions in this notebook; rate limiting is selected with IMPLEMENTATION_APPROACH = "rate_limit"
RATE_LIMIT_CALLS = 3              # Maximum calls per time period
RATE_LIMIT_PERIOD = 1             # Time period in seconds

# --- Processing Settings ---
MAX_RETRIES = 4                    # Maximum API attempts per request before the text is marked "unable to deidentify"
DEIDENTIFICATION_PASSES = 2        # Number of passes through the LLM to catch any PHI missed in the first pass
IMPLEMENTATION_APPROACH = "semaphore"  # Options: "semaphore", "rate_limit", "queue", "process_pool"

# --- Ollama Settings ---
OLLAMA_API_URL = "http://localhost:11434/api/generate"
MODEL_NAME = "gemma3:4b"             # Or any other model you have available
MAX_CHUNK_SIZE = 5000               # Max characters for a single note before it's split for processing.

## 2. Core Functions with Improved Concurrency Control

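The functions below send requests to the Ollama API under different concurrency controls. `call_ollama_api_with_semaphore` caps the number of requests in flight at `MAX_CONCURRENT_REQUESTS`, `call_ollama_api_with_rate_limit` caps requests at `RATE_LIMIT_CALLS` per `RATE_LIMIT_PERIOD` seconds, and `queue_worker` serves tasks from a shared queue. All three retry failed requests with exponential backoff, up to `MAX_RETRIES` attempts.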
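The semaphore approach rests on one property: however many threads call into it, at most `MAX_CONCURRENT_REQUESTS` of them are inside the `with api_semaphore:` block at once. The self-contained sketch below checks that property with `fake_request`, a hypothetical stand-in that sleeps instead of calling Ollama, and records the peak number of simultaneous requests.

```python
import threading
import time
from concurrent.futures import ThreadPoolExecutor

MAX_CONCURRENT = 3
semaphore = threading.Semaphore(MAX_CONCURRENT)
state_lock = threading.Lock()  # protects the counters below
in_flight = 0
peak = 0

def fake_request(i):
    """Stand-in for requests.post: holds a semaphore slot for a short time."""
    global in_flight, peak
    with semaphore:
        with state_lock:
            in_flight += 1
            peak = max(peak, in_flight)
        time.sleep(0.05)  # simulated network and model latency
        with state_lock:
            in_flight -= 1
    return i

# More threads than slots: the semaphore makes the extra threads wait their turn
with ThreadPoolExecutor(max_workers=10) as executor:
    results = list(executor.map(fake_request, range(20)))

print(f"peak concurrent requests: {peak}")
```

`MAX_WORKERS` can therefore exceed `MAX_CONCURRENT_REQUESTS`: surplus threads simply block on the semaphore until a slot frees up.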

In [None]:
# Create a semaphore to limit concurrent API calls
api_semaphore = threading.Semaphore(MAX_CONCURRENT_REQUESTS)

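# Create a rate limiter
class RateLimiter:
    """
    Sliding-window rate limiter: admits at most `max_calls` entries into the
    `with` block per `period` seconds, across all threads.
    """
    def __init__(self, max_calls, period):
        self.max_calls = max_calls
        self.period = period
        self.calls = []  # time.monotonic() timestamps of recent calls
        self.lock = threading.Lock()
        
    def __enter__(self):
        # The lock is held while sleeping, so waiting threads are admitted one at a time
        with self.lock:
            # A monotonic clock cannot jump when the system time changes
            now = time.monotonic()
            # Remove calls older than the period
            self.calls = [t for t in self.calls if now - t < self.period]
            
            # If we've reached the maximum calls, wait until the oldest one leaves the window
            if len(self.calls) >= self.max_calls:
                sleep_time = self.period - (now - self.calls[0])
                if sleep_time > 0:
                    time.sleep(sleep_time)
                    
            # Add the current call time
            self.calls.append(time.monotonic())
        return self
            
    def __exit__(self, exc_type, exc_val, exc_tb):
        # Nothing to release; returning False lets exceptions propagate
        return False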

rate_limiter = RateLimiter(RATE_LIMIT_CALLS, RATE_LIMIT_PERIOD)

# Queue for controlled processing
request_queue = queue.Queue()
response_queue = queue.Queue()

def call_ollama_api_with_semaphore(text_chunk, row_index=None, retry_count=0):
    """
    Sends a text chunk to the Ollama API with semaphore-based concurrency control.
    The semaphore is held only for the request itself; retries happen in a loop
    after the slot is released. (Retrying recursively inside the `with` block would
    make each retry acquire an additional slot while still holding the previous
    ones, so repeated failures could exhaust the semaphore and deadlock.)
    """
    prompt = create_prompt(text_chunk)
    payload = {
        "model": MODEL_NAME,
        "prompt": prompt,
        "stream": False
    }
    
    while retry_count < MAX_RETRIES:
        try:
            # Use the semaphore to limit concurrent requests
            with api_semaphore:
                # Add jitter to avoid thundering herd problem
                time.sleep(random.uniform(0.1, 0.5))
                response = requests.post(OLLAMA_API_URL, json=payload, timeout=120)
                response.raise_for_status()
                response_data = response.json()
            return response_data.get('response', '').strip()
        except (requests.exceptions.RequestException, ValueError) as e:
            # ValueError also covers malformed JSON on older versions of requests
            retry_count += 1
            if row_index is not None:
                print(f"\nError processing row {row_index}: {e} (Attempt {retry_count}/{MAX_RETRIES})")
            
            # Exponential backoff with jitter, outside the semaphore so the slot is free
            if retry_count < MAX_RETRIES:
                backoff_time = min(2 ** retry_count + random.uniform(0, 1), 10)
                time.sleep(backoff_time)
    
    return "[UNABLE TO DEIDENTIFY: Maximum retry attempts reached]"

def call_ollama_api_with_rate_limit(text_chunk, row_index=None, retry_count=0):
    """
    Sends a text chunk to the Ollama API with rate limiting.
    Every attempt, including retries, passes through the rate limiter.
    """
    prompt = create_prompt(text_chunk)
    payload = {
        "model": MODEL_NAME,
        "prompt": prompt,
        "stream": False
    }
    
    while retry_count < MAX_RETRIES:
        try:
            # Wait until the rate limiter allows another call
            with rate_limiter:
                response = requests.post(OLLAMA_API_URL, json=payload, timeout=120)
                response.raise_for_status()
                response_data = response.json()
            return response_data.get('response', '').strip()
        except (requests.exceptions.RequestException, ValueError) as e:
            # ValueError also covers malformed JSON on older versions of requests
            retry_count += 1
            if row_index is not None:
                print(f"\nError processing row {row_index}: {e} (Attempt {retry_count}/{MAX_RETRIES})")
            
            # Exponential backoff with jitter
            if retry_count < MAX_RETRIES:
                backoff_time = min(2 ** retry_count + random.uniform(0, 1), 10)
                time.sleep(backoff_time)
    
    return "[UNABLE TO DEIDENTIFY: Maximum retry attempts reached]"

def queue_worker():
    """
    Worker function for queue-based processing.
    Every task taken from request_queue produces exactly one entry in response_queue
    (the de-identified text or a failure marker), so the consumer in
    process_with_queue never waits forever on a task that failed silently.
    """
    while True:
        try:
            # Get a task from the queue; the timeout keeps the loop responsive
            task = request_queue.get(timeout=1)
        except queue.Empty:
            continue
        
        if task is None:  # Sentinel value to stop the worker
            request_queue.task_done()
            break
        
        text_chunk, row_index, retry_count = task
        
        # Process the task
        try:
            payload = {
                "model": MODEL_NAME,
                "prompt": create_prompt(text_chunk),
                "stream": False
            }
            
            # Add jitter to avoid thundering herd problem
            time.sleep(random.uniform(0.1, 0.5))
            response = requests.post(OLLAMA_API_URL, json=payload, timeout=120)
            response.raise_for_status()
            result = response.json().get('response', '').strip()
            
            # Put the result in the response queue
            response_queue.put((row_index, result))
            
        except (requests.exceptions.RequestException, ValueError) as e:
            retry_count += 1
            if row_index is not None:
                print(f"\nError processing row {row_index}: {e} (Attempt {retry_count}/{MAX_RETRIES})")
            
            if retry_count < MAX_RETRIES:
                # Exponential backoff with jitter, then re-queue the task
                backoff_time = min(2 ** retry_count + random.uniform(0, 1), 10)
                time.sleep(backoff_time)
                request_queue.put((text_chunk, row_index, retry_count))
            else:
                response_queue.put((row_index, "[UNABLE TO DEIDENTIFY: Maximum retry attempts reached]"))
        
        except Exception as e:
            # Unexpected error: report the row as failed instead of leaving the consumer waiting
            logging.error(f"Error in queue worker (row {row_index}): {e}")
            response_queue.put((row_index, "[UNABLE TO DEIDENTIFY: Unexpected error]"))
        
        finally:
            # Mark the task as done exactly once
            request_queue.task_done()

def create_prompt(text_chunk):
    """
    Creates the prompt for the Ollama API.
    """
    return f"""[SYSTEM]
    You are an automated de-identification system. Your sole function is to process the text provided and return a clean version with all Protected Health Information (PHI) replaced by the appropriate category label. You are precise and do not deviate from your instructions.

    **Instructions:**

    1.  **Identify and Replace:** Your task is to find all instances of the following PHI categories in the user-provided text and replace them with their corresponding labels:
        * Names of individuals: `[PERSON]`
        * All dates (including full dates, partial dates, and days of the week): `[DATE]`
        * Geographical locations (cities, states, addresses, etc.): `[LOCATION]`
        * Phone and fax numbers: `[PHONE]`
        * Email addresses: `[EMAIL]`
        * Any identification numbers (e.g., Social Security Numbers, Medical Record Numbers, account numbers): `[ID_NUMBER]`

    2.  **Output Format:**
        * The output MUST be only the modified text.
        * Do NOT include any introductory phrases, explanations, or apologies.
        * The response should not contain any of the original PHI.

    **Examples:**

    * **Input:** "John Smith visited the clinic on January 5, 2024. His MRN is 12345."
    * **Output:** "[PERSON] visited the clinic on [DATE]. His MRN is [ID_NUMBER]."

    * **Input:** "Patient can be reached at (555) 123-4567 or jane.doe@email.com."
    * **Output:** "Patient can be reached at [PHONE] or [EMAIL]."

    **Text for De-identification:**

    ---
    {text_chunk}
    ---
    """

def deidentify_text(full_text, row_index=None, implementation=IMPLEMENTATION_APPROACH):
    """
    Manages the de-identification of a full text string, handling chunking if necessary.
    Performs multiple passes through the LLM based on DEIDENTIFICATION_PASSES setting.
    Uses the specified implementation approach for API calls.
    """
    if not isinstance(full_text, str) or not full_text.strip():
        return ""
    
    # Select the appropriate API call function based on the implementation approach
    if implementation == "semaphore":
        api_call_func = call_ollama_api_with_semaphore
    elif implementation == "rate_limit":
        api_call_func = call_ollama_api_with_rate_limit
    else:
        # Default to semaphore-based approach
        api_call_func = call_ollama_api_with_semaphore
    
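    # Process the text through multiple passes to catch any missed PHI
    processed_text = full_text
    
    for pass_num in range(DEIDENTIFICATION_PASSES):
        if pass_num > 0 and row_index is not None:
            print(f"\r    - Row {row_index}: Pass {pass_num+1}/{DEIDENTIFICATION_PASSES}", end="")
            
        if len(processed_text) <= MAX_CHUNK_SIZE:
            processed_text = api_call_func(processed_text, row_index)
        else:
            # Split at whitespace so words (and names) are not cut in half across chunks
            chunks = []
            start = 0
            while start < len(processed_text):
                end = min(start + MAX_CHUNK_SIZE, len(processed_text))
                if end < len(processed_text):
                    split_at = max(processed_text.rfind(" ", start, end),
                                   processed_text.rfind("\n", start, end))
                    if split_at > start:
                        end = split_at
                chunks.append(processed_text[start:end])
                start = end
            
            processed_chunks = [api_call_func(chunk, row_index) for chunk in chunks]
            
            # API output is stripped, so re-join with a space to keep words apart
            processed_text = " ".join(processed_chunks)
        
        # Stop if a request failed: further passes would send the failure marker to the LLM
        if "[UNABLE TO DEIDENTIFY" in processed_text:
            break
    
    return processed_text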

def process_row_parallel(args):
    """
    Helper function for parallel processing of individual rows.
    Returns tuple of (original_index, processed_text) to maintain order.
    """
    row_index, original_text = args
    processed_text = deidentify_text(original_text, row_index, IMPLEMENTATION_APPROACH)
    return (row_index, processed_text)

## 3. Queue-based Processing Implementation

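This alternative starts a fixed pool of `MAX_CONCURRENT_REQUESTS` worker threads that take rows from `request_queue` and return results on `response_queue`. Each queued task is a whole row sent in a single request, so unlike `deidentify_text`, this path does not split notes longer than `MAX_CHUNK_SIZE` and makes only one pass regardless of `DEIDENTIFICATION_PASSES`.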
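The control flow of `process_with_queue` and `queue_worker` follows the standard producer/consumer pattern: tasks go into one queue, results come back on another, and one `None` sentinel per worker shuts the pool down. The sketch below shows that pattern in isolation. It uses a hypothetical `fake_deidentify` function (a plain string replacement) in place of the Ollama call, so it runs without a server.

```python
import queue
import threading

def fake_deidentify(text):
    """Hypothetical stand-in for the LLM call."""
    return text.replace("John", "[PERSON]")

def worker(tasks, results):
    while True:
        task = tasks.get()
        try:
            if task is None:  # sentinel: stop this worker
                return
            row_index, text = task
            results.put((row_index, fake_deidentify(text)))
        finally:
            tasks.task_done()  # runs for real tasks and for the sentinel

def run_queue(rows, num_workers=3):
    tasks, results = queue.Queue(), queue.Queue()
    threads = [threading.Thread(target=worker, args=(tasks, results), daemon=True)
               for _ in range(num_workers)]
    for t in threads:
        t.start()
    for row in rows:
        tasks.put(row)
    # Collect exactly one result per submitted row; arrival order may differ
    collected = dict(results.get() for _ in range(len(rows)))
    # All work is done, so one sentinel per worker shuts the pool down
    for _ in threads:
        tasks.put(None)
    for t in threads:
        t.join()
    return collected

rows = [(0, "John called"), (1, "No PHI here"), (2, "Ask John")]
print(run_queue(rows))
```

The important invariant, as in the notebook, is that every task yields exactly one result. If a worker could drop a task without answering, the collection loop would wait forever.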

In [None]:
def process_with_queue(row_data, total_in_batch):
    """
    Process rows using a queue-based approach for better control over concurrency.
    """
    # Clear the queues
    while not request_queue.empty():
        try:
            request_queue.get_nowait()
            request_queue.task_done()
        except queue.Empty:
            break
            
    while not response_queue.empty():
        try:
            response_queue.get_nowait()
        except queue.Empty:
            break
    
    # Start worker threads
    workers = []
    for _ in range(MAX_CONCURRENT_REQUESTS):
        worker = threading.Thread(target=queue_worker)
        worker.daemon = True
        worker.start()
        workers.append(worker)
    
    # Add tasks to the queue
    for row_index, original_text in row_data:
        request_queue.put((original_text, row_index, 0))
    
    # Wait for all tasks to be processed
    processed_results = {}
    completed_count = 0
    
    while completed_count < len(row_data):
        try:
            row_index, processed_text = response_queue.get(timeout=1)
            processed_results[row_index] = processed_text
            completed_count += 1
            
            # Print progress
            print(f"\r    - Row {completed_count}/{total_in_batch}", end="")
            
        except queue.Empty:
            # If the queue is empty, wait a bit and try again
            time.sleep(0.1)
    
    # Stop the workers
    for _ in range(len(workers)):
        request_queue.put(None)  # Sentinel value to stop the worker
    
    # Wait for all workers to finish
    for worker in workers:
        worker.join()
    
    return processed_results

## 4. Process Pool Implementation

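This implementation uses `ProcessPoolExecutor` instead of `ThreadPoolExecutor`. Processes mainly help with CPU-bound Python code, but here the heavy computation happens inside the Ollama server and each worker spends most of its time waiting on HTTP responses, so threads are usually sufficient. Two caveats apply:

- Each process gets its own copy of `api_semaphore` and `rate_limiter`, so those limits are not shared across processes. Up to `MAX_WORKERS` requests can be in flight at once.
- The function submitted to the pool must be picklable, which means it must be defined at module level. On platforms that use the `spawn` start method (the default on macOS and Windows), functions defined in notebook cells usually cannot be loaded by the child processes, so the worker may need to live in a separate `.py` module.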
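`ProcessPoolExecutor` sends the callable and its arguments to the worker processes with `pickle`, and `pickle` stores a function by reference to its module-level name. The sketch below, with illustrative names `top_level_worker` and `make_nested_worker`, shows the consequence: a module-level function pickles, while a function defined inside another function does not.

```python
import pickle

def top_level_worker(args):
    """Module-level function: picklable by reference to its name."""
    row_index, text = args
    return row_index, text.upper()

def make_nested_worker():
    def nested_worker(args):  # local object: pickle cannot look it up by name
        row_index, text = args
        return row_index, text.upper()
    return nested_worker

def is_picklable(obj):
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError):
        # Older Pythons raise AttributeError for local objects, newer ones PicklingError
        return False

print(is_picklable(top_level_worker))      # True
print(is_picklable(make_nested_worker()))  # False
```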

In [None]:
def process_with_process_pool(row_data, total_in_batch):
    """
    Process rows in parallel with ProcessPoolExecutor (one row per worker process at a time).
    """
    # Initialize results dictionary to maintain order
    processed_results = {}
    completed_count = 0
    
    # ProcessPoolExecutor pickles the function it runs, so it must be defined at module
    # level; a function nested inside this one cannot be pickled. process_row_parallel
    # is module-level and takes a (row_index, text) tuple.
    with ProcessPoolExecutor(max_workers=MAX_WORKERS) as executor:
        # Submit all tasks
        future_to_index = {executor.submit(process_row_parallel, (idx, text)): idx for idx, text in row_data}
        
        # Process completed tasks as they finish
        for future in as_completed(future_to_index):
            original_index, processed_text = future.result()
            processed_results[original_index] = processed_text
            completed_count += 1
            
            # Print progress
            print(f"\r    - Row {completed_count}/{total_in_batch}", end="")
    
    return processed_results

## 5. Main Processing Function Stub

This is a stub for the main processing function. The complete implementation is in the second notebook (`improved_parallel_deid_part2.ipynb`).
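The full `process_large_csv` lives in the second notebook. As a rough sketch of the batching that `MAX_ROWS_PER_BATCH` and `OUTPUT_PREFIX` suggest, the code below computes consecutive batches of at most `max_rows` rows and derives one output filename per batch. The helper names and the `{prefix}_batch_{n}.csv` naming scheme are hypothetical and may not match what part 2 actually does.

```python
import math

def batch_ranges(num_rows, max_rows):
    """(start, stop) row positions for consecutive batches of at most max_rows rows."""
    num_batches = math.ceil(num_rows / max_rows)
    return [(n * max_rows, min((n + 1) * max_rows, num_rows)) for n in range(num_batches)]

def batch_filename(prefix, batch_number):
    """Hypothetical per-batch output name, e.g. 'deidentified_output_post_LLM_batch_1.csv'."""
    return f"{prefix}_batch_{batch_number}.csv"

# With a pandas DataFrame `df`, each batch would be df.iloc[start:stop]
for n, (start, stop) in enumerate(batch_ranges(250, 100), start=1):
    print(batch_filename("deidentified_output_post_LLM", n), start, stop)
```

Writing one file per batch means a crash midway through a large CSV loses at most one batch of work.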

In [None]:
def process_large_csv(input_path, output_prefix, columns_to_clean):
    """
    Stub for the main processing function.
    The complete implementation is in improved_parallel_deid_part2.ipynb.
    """
    print("This is a stub. Please run the complete implementation in improved_parallel_deid_part2.ipynb")