In [None]:
! pip install -r requirements.txt --quiet


In [2]:
from dotenv import load_dotenv
from azure.identity.aio import DefaultAzureCredential
from azure.core.credentials import AzureKeyCredential
import os

load_dotenv(override=True) # take environment variables from .env.

# Variables not used here do not need to be updated in your .env file
source_endpoint = os.environ["AZURE_SEARCH_SOURCE_SERVICE_ENDPOINT"]
# Using a key is optional. See https://learn.microsoft.com/en-us/azure/search/keyless-connections
source_credential = AzureKeyCredential(os.getenv("AZURE_SEARCH_SOURCE_ADMIN_KEY")) if os.getenv("AZURE_SEARCH_SOURCE_ADMIN_KEY") else DefaultAzureCredential()
destination_endpoint = os.environ["AZURE_SEARCH_DESTINATION_SERVICE_ENDPOINT"]
destination_credential = AzureKeyCredential(os.getenv("AZURE_SEARCH_DESTINATION_ADMIN_KEY")) if os.getenv("AZURE_SEARCH_DESTINATION_ADMIN_KEY") else DefaultAzureCredential()
index_name = os.environ["AZURE_SEARCH_INDEX"]
timestamp_field_name = os.environ["AZURE_SEARCH_TIMESTAMP_FIELD"]

In [3]:
from azure.search.documents.indexes.aio import SearchIndexClient

async def copy_index_definition(source_index_client: SearchIndexClient, destination_index_client: SearchIndexClient, index_name: str):
    index = await source_index_client.get_index(index_name)
    # Check for any synonym maps
    synonym_map_names = []
    for field in index.fields:
        if field.synonym_map_names:
            synonym_map_names.extend(field.synonym_map_names)
    
    # Copy over synonym maps if they exist
    for synonym_map_name in synonym_map_names:
        synonym_map = await source_index_client.get_synonym_map(synonym_map_name)
        await destination_index_client.create_or_update_synonym_map(synonym_map)
    
    # Copy over the index
    await destination_index_client.create_or_update_index(index)

In [4]:
source_index_client = SearchIndexClient(endpoint=source_endpoint, credential=source_credential)
destination_index_client = SearchIndexClient(endpoint=destination_endpoint, credential=destination_credential)

await copy_index_definition(source_index_client, destination_index_client, index_name)

In [5]:
from azure.search.documents.indexes.aio import SearchIndexClient
from azure.search.documents.indexes.models import SearchFieldDataType
from typing import List

async def validate_resume_backup_and_restore(index_client: SearchIndexClient, index_name: str, timestamp_field_name: str) -> bool:
    index = await index_client.get_index(index_name)

    found_field = False
    for field in index.fields:
        if field.name == timestamp_field_name:
            found_field = True
            if field.type != SearchFieldDataType.DateTimeOffset:
                # Field must be a timestamp
                return False
            if not field.filterable:
                # Field must be filterable
                return False
            if not field.sortable:
                # Field must be sortable
                return False
            break
    
    # Field must exist on the index
    return found_field

async def validate_fields_backup_and_restore(index_client: SearchIndexClient, index_name: str) -> List[str]:
    missing_fields = []
    index = await index_client.get_index(index_name)
    for field in index.fields:
        message = ""
        if not field.stored:
            message += f"Field {field.name} cannot be backed up because it's not marked as stored\n"
        elif field.hidden: 
            message += f"Field {field.name} cannot be backed up because it's not marked as retrievable\n"
        
        if message:
            missing_fields.append(message)
    
    return missing_fields


In [6]:
can_resume_backup_and_restore = await validate_resume_backup_and_restore(source_index_client, index_name, timestamp_field_name)
if can_resume_backup_and_restore:
    print("Index has a valid timestamp field and can use resumable backup and restore")
else:
    print("Index does not have a valid timestamp field and cannot use resumable backup and restore")

Index has a valid timestamp field and can use resumable backup and restore


In [87]:
from azure.search.documents.aio import SearchClient
from typing import Optional, AsyncGenerator, List, Callable, Tuple
from tqdm.notebook import tqdm
import ipywidgets as widgets
from uuid import uuid4
import asyncio
from datetime import datetime, timedelta

async def get_total_documents_remaining(client: SearchClient, timestamp_field_name: str, min_timestamp: Optional[str] = None, max_timestamp: Optional[str] = None) -> int:
    filter = None
    if min_timestamp and not max_timestamp:
        filter = f"{timestamp_field_name} ge {min_timestamp}"
    elif min_timestamp and max_timestamp:
        filter = f"{timestamp_field_name} ge {min_timestamp} and {timestamp_field_name} le {max_timestamp}"
    results = await client.search(
        search_text="*",
        include_total_count=True,
        filter=filter,
        top=0
    )
    return await results.get_count()

async def get_timestamp_bound(client: SearchClient, timestamp_field_name: str, max: bool) -> Optional[str]:
    result = await client.search(
        search_text="*",
        order_by=f"{timestamp_field_name} {'desc' if max else 'asc'}",
        top=1,
        select=[timestamp_field_name]
    )
    result = [item async for item in result]
    if len(result) == 0:
        return None
    return result[0][timestamp_field_name]

def timestamp_to_datetime(timestamp: str) -> datetime:
    return datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%fZ")
def datetime_to_timestamp(date: datetime) -> str:
    # Trim microseconds to milliseconds
    return date.strftime("%Y-%m-%dT%H:%M:%S.%fZ")[:-3] + "Z"
    
async def get_partition_bounds_backup_results(client: SearchClient, timestamp_field_name: str, desired_partitions: int = 2, partition_size_threshold: float = 0.05, min_timestamp: Optional[str] = None, max_timestamp: Optional[str] = None) -> List[datetime]:
    if max_timestamp == None:
        max_timestamp = await get_timestamp_bound(client, timestamp_field_name, max=True)
        if max_timestamp == None:
            return []
    if min_timestamp == None:
        min_timestamp = await get_timestamp_bound(client, timestamp_field_name, max=False)

    if min_timestamp == max_timestamp or desired_partitions == 1:
        return []

    partition_splits = []
    low = timestamp_to_datetime(min_timestamp)
    for partition in range(desired_partitions - 1):
        high = timestamp_to_datetime(max_timestamp)
        remaining_partitions = desired_partitions - partition
        target_partition_size = await get_total_documents_remaining(client, timestamp_field_name, min_timestamp=datetime_to_timestamp(low)) // remaining_partitions
        partition_threshold = target_partition_size * partition_size_threshold
        best_split = None
        partition_sizes = []
        mid = low + (high - low) / 2
        while low <= high:
            current_partition_size = await get_total_documents_remaining(client, timestamp_field_name, datetime_to_timestamp(low), datetime_to_timestamp(mid))
            partition_sizes.append((mid, current_partition_size))
            if current_partition_size < target_partition_size + partition_threshold and current_partition_size > target_partition_size - partition_threshold:
                best_split = mid
                break
            elif current_partition_size < target_partition_size:
                mid = mid + (high - mid) / 2
            else:
                high = mid
                mid = mid - (mid - low) / 2
        
        if best_split is None:
            print("Could not find best split....", partition_sizes)
            min_difference = -1
            for split, partition_size in partition_sizes:
                difference = abs(target_partition_size - partition_size)
                if min_difference == -1 or difference < min_difference:
                    best_split = split
                    min_difference = difference
                    print("set fallback ", partition_size)

        partition_splits.append(best_split)
        low = best_split + timedelta(milliseconds=1)

    return partition_splits

async def get_partitions(client: SearchClient, timestamp_field_name: str, partition_splits: List[datetime]) -> List[Tuple[str, str]]:
    max_timestamp = await get_timestamp_bound(client, timestamp_field_name, max=True)
    if max_timestamp == None:
        return []
    min_timestamp = await get_timestamp_bound(client, timestamp_field_name, max=False)
    prev_partition_end = timestamp_to_datetime(min_timestamp)
    partitions = []
    for partition_split in partition_splits:
        partitions.append((datetime_to_timestamp(prev_partition_end), datetime_to_timestamp(partition_split)))
        prev_partition_end = partition_split + timedelta(milliseconds=1)
    partitions.append((datetime_to_timestamp(prev_partition_end), max_timestamp))
    return partitions

async def resume_backup_results(client: SearchClient, timestamp_field_name: str, timestamp: Optional[str], select=None) -> AsyncGenerator[List[dict], None]:
    session_id = str(uuid4())
    max_results_size = 100000
    get_next_results = True
    while get_next_results:
        total_results_size = 0
        results = await client.search(
            search_text="*",
            order_by=f"{timestamp_field_name} asc",
            top=max_results_size,
            filter=f"{timestamp_field_name} ge {timestamp}" if timestamp else None,
            session_id=session_id,
            select=select
        )
        results_by_page = results.by_page()

        async for page in results_by_page:
            next_page = [item async for item in page]
            total_results_size += len(next_page)
            yield next_page
            timestamp = next_page[-1][timestamp_field_name]
        
        get_next_results = total_results_size == max_results_size

async def get_key_field(index_client: SearchIndexClient, index_name: str) -> str:
    index = await index_client.get_index(index_name)
    for field in index.fields:
        if field.key:
            return field.name
    
    raise Exception("No key field in the index")

async def backup_index_with_resume(source_client: SearchClient, destination_client: SearchClient, timestamp_field_name: str, total_documents: int, resume_timestamp: Optional[str], backup_tasks:int = 2, on_backup_page: Optional[Callable[[str], None]] = None) -> None:
    progress_bar = tqdm(total=total_documents, desc="Backing up documents...", unit="docs", unit_scale=False)
    pages_label = widgets.Label(value="Queued Result Pages: 0")
    display(pages_label)
    
    async def get_results(output_queue: asyncio.Queue):
        results = resume_backup_results(source_client, timestamp_field_name, timestamp=resume_timestamp)
        async for result_page in results:
            pages_label.value=f"Queued Result Pages: {output_queue.qsize()}"
            await output_queue.put(result_page)
        await output_queue.put(None)
    
    async def backup_results(results_queue: asyncio.Queue, timestamp_queue: asyncio.Queue):
        processed_count = 0
        while True:
            result_page = await results_queue.get()
            if result_page is None:
                break
            saved_timestamp = result_page[-1][timestamp_field_name]
            await destination_client.upload_documents(result_page)
            await timestamp_queue.put(saved_timestamp)
            processed_count += len(result_page)
            progress_bar.update(len(result_page))
    
    async def checkpoint_results(timestamp_queue: asyncio.Queue):
        latest_timestamp = None
        while True:
            next_timestamp = await timestamp_queue.get()
            if next_timestamp is None:  # Stop signal received
                break
            if latest_timestamp is None or next_timestamp > latest_timestamp:
                latest_timestamp = next_timestamp
                on_backup_page(latest_timestamp)

    results_queue = asyncio.Queue()
    timestamp_queue = asyncio.Queue()

    # Run producer and consumer concurrently
    producer_task = asyncio.create_task(get_results(results_queue))
    
    consumer_tasks = [asyncio.create_task(backup_results(results_queue, timestamp_queue)) for i in range(backup_tasks)]

    checkpoint_task = asyncio.create_task(checkpoint_results(timestamp_queue))

    # Wait for the producer to complete
    await producer_task

    # Wait for all tasks to complete
    await producer_task
    await asyncio.gather(*consumer_tasks)
    await checkpoint_task


In [None]:
source_client = SearchClient(source_endpoint, index_name, source_credential)
destination_client = SearchClient(destination_endpoint, index_name, destination_credential)

In [89]:
partition_splits = await get_partition_bounds_backup_results(source_client, timestamp_field_name, desired_partitions=4)
total = 0
for low, high in await get_partitions(source_client, timestamp_field_name, partition_splits):
    partition_size = await get_total_documents_remaining(source_client, timestamp_field_name, low, high)
    print("Partition", low, high, partition_size)
    total += partition_size
print("Total", total)
    

Partition 2024-11-10T12:48:37.4980Z 2024-11-10T15:28:48.3827Z 1288000
Partition 2024-11-10T15:28:48.3837Z 2024-11-10T19:29:04.7103Z 1194000
Partition 2024-11-10T19:29:04.7113Z 2024-11-10T21:29:12.8741Z 1282000
Partition 2024-11-10T21:29:12.8751Z 2024-11-10T23:29:21.037Z 1236000
Total 5000000


In [None]:
import os

source_client = SearchClient(source_endpoint, index_name, source_credential)
destination_client = SearchClient(destination_endpoint, index_name, destination_credential)

backup_file = "backup-timestamp.txt"
def on_backup_page(last_timestamp: str) -> None:
    with open(backup_file, "w") as f:
        f.write(last_timestamp)

restored_timestamp = None
if os.path.exists(backup_file):
    with open(backup_file, "r") as f:
        restored_timestamp = f.read()

total_documents = await get_total_documents_remaining(
    source_client,
    timestamp_field_name,
    timestamp=restored_timestamp
)

await backup_index_with_resume(
    source_client,
    destination_client,
    timestamp_field_name,
    resume_timestamp=restored_timestamp,
    total_documents=total_documents,
    on_backup_page=on_backup_page,
    backup_tasks=3
)