In [None]:
! pip install -r requirements.txt --quiet


In [2]:
from dotenv import load_dotenv
from azure.identity.aio import DefaultAzureCredential
from azure.core.credentials import AzureKeyCredential
import os

# Load settings from .env; values in .env replace any already present in the environment.
load_dotenv(override=True)

# Only the variables read below need to be set in your .env file.
# An admin key is optional: when it is missing, keyless authentication via DefaultAzureCredential is used.
# See https://learn.microsoft.com/en-us/azure/search/keyless-connections
source_endpoint = os.environ["AZURE_SEARCH_SOURCE_SERVICE_ENDPOINT"]
source_admin_key = os.getenv("AZURE_SEARCH_SOURCE_ADMIN_KEY")
if source_admin_key:
    source_credential = AzureKeyCredential(source_admin_key)
else:
    source_credential = DefaultAzureCredential()

destination_endpoint = os.environ["AZURE_SEARCH_DESTINATION_SERVICE_ENDPOINT"]
destination_admin_key = os.getenv("AZURE_SEARCH_DESTINATION_ADMIN_KEY")
if destination_admin_key:
    destination_credential = AzureKeyCredential(destination_admin_key)
else:
    destination_credential = DefaultAzureCredential()

index_name = os.environ["AZURE_SEARCH_INDEX"]
timestamp_field_name = os.environ["AZURE_SEARCH_TIMESTAMP_FIELD"]

In [3]:
from azure.search.documents.indexes.aio import SearchIndexClient

async def copy_index_definition(source_index_client: SearchIndexClient, destination_index_client: SearchIndexClient, index_name: str):
    """Recreate the source index definition, and every synonym map it references, on the destination.

    Synonym maps are copied before the index itself, because the index definition refers to
    them by name and they need to exist on the destination service first.

    Args:
        source_index_client: async SearchIndexClient for the service being copied from.
        destination_index_client: async SearchIndexClient for the service being copied to.
        index_name: name of the index to copy (the same name is used on the destination).
    """
    index = await source_index_client.get_index(index_name)

    def _iter_fields(fields):
        # Yield fields depth-first, including the sub-fields of complex fields, whose
        # searchable string sub-fields can reference synonym maps too.
        for field in fields or []:
            yield field
            yield from _iter_fields(field.fields)

    # dict.fromkeys de-duplicates while keeping first-seen order, so a synonym map shared
    # by several fields is fetched and uploaded only once.
    synonym_map_names = dict.fromkeys(
        name
        for field in _iter_fields(index.fields)
        for name in (field.synonym_map_names or [])
    )

    # Copy over synonym maps if they exist
    for synonym_map_name in synonym_map_names:
        synonym_map = await source_index_client.get_synonym_map(synonym_map_name)
        await destination_index_client.create_or_update_synonym_map(synonym_map)

    # Copy over the index
    await destination_index_client.create_or_update_index(index)

In [4]:
source_index_client = SearchIndexClient(endpoint=source_endpoint, credential=source_credential)
destination_index_client = SearchIndexClient(endpoint=destination_endpoint, credential=destination_credential)

await copy_index_definition(source_index_client, destination_index_client, index_name)

In [5]:
from azure.search.documents.indexes.aio import SearchIndexClient
from azure.search.documents.indexes.models import SearchFieldDataType
from typing import List

async def validate_resume_backup_and_restore(index_client: SearchIndexClient, index_name: str, timestamp_field_name: str) -> bool:
    """Return True when ``timestamp_field_name`` can drive a resumable backup of ``index_name``.

    The backup orders documents by this field, filters on it to resume, and reads its value
    from the last document of every result page to checkpoint progress. The field must therefore:
    exist as a top-level field, be of type ``Edm.DateTimeOffset``, be filterable, be sortable,
    and be retrievable (not hidden).
    """
    index = await index_client.get_index(index_name)

    field = next((candidate for candidate in index.fields if candidate.name == timestamp_field_name), None)
    if field is None:
        # Field must exist on the index
        return False

    return (
        field.type == SearchFieldDataType.DateTimeOffset  # Field must be a timestamp
        and bool(field.filterable)  # needed for the "ge <timestamp>" resume filter
        and bool(field.sortable)  # needed for "order by <timestamp> asc"
        # hidden == not retrievable; a hidden field is absent from results, so its value
        # could not be read back to checkpoint or continue paging.
        and not field.hidden
    )

async def validate_fields_backup_and_restore(index_client: SearchIndexClient, index_name: str) -> List[str]:
    """List the top-level fields of ``index_name`` whose values a backup cannot copy.

    Returns one newline-terminated message per problem field:
    - a field that is not stored is reported as not stored;
    - a stored field that is hidden is reported as not retrievable.
    An empty list means every top-level field can be backed up.
    """
    index_definition = await index_client.get_index(index_name)
    problems: List[str] = []
    for candidate in index_definition.fields:
        if not candidate.stored:
            problems.append(f"Field {candidate.name} cannot be backed up because it's not marked as stored\n")
        elif candidate.hidden:
            problems.append(f"Field {candidate.name} cannot be backed up because it's not marked as retrievable\n")
    return problems


In [None]:
can_resume_backup_and_restore = await validate_resume_backup_and_restore(source_index_client, index_name, timestamp_field_name)
if can_resume_backup_and_restore:
    print("Index has a valid timestamp field and can use resumable backup and restore")
else:
    print("Index does not have a valid timestamp field and cannot use resumable backup and restore")

In [9]:
from azure.search.documents.aio import SearchClient
from typing import Optional, AsyncGenerator, List, Callable
from tqdm.notebook import tqdm
import ipywidgets as widgets
from uuid import uuid4
import asyncio

async def get_most_recent_timestamp(client: SearchClient, timestamp_field_name: str) -> Optional[str]:
    """Return ``timestamp_field_name`` from the top document when sorted descending, or None if there are no results.

    The async ``SearchClient.search`` returns an async pager, which supports neither ``len()``
    nor indexing, so the single requested result is read by iterating the pager.
    """
    results = await client.search(
        search_text="*",
        order_by=f"{timestamp_field_name} desc",
        top=1,
        select=[timestamp_field_name],
    )
    async for document in results:
        return document[timestamp_field_name]
    return None

async def resume_backup_results(client: SearchClient, timestamp_field_name: str, timestamp: Optional[str] = None, select=None) -> AsyncGenerator[List[dict], None]:
    """Yield every document in the index as pages, in ascending ``timestamp_field_name`` order.

    Documents are read in batches of at most ``max_results_size``, because search results
    cannot be paged arbitrarily deep with a single query. Each new batch is filtered with
    ``<timestamp_field_name> ge <last timestamp seen>``. The filter is inclusive, so documents
    sharing the boundary timestamp are yielded again at the start of the next batch.
    Consumers must tolerate these duplicates.

    Args:
        client: async SearchClient for the index to read.
        timestamp_field_name: filterable, sortable field used for ordering and resuming.
        timestamp: resume point; when falsy, reading starts from the beginning of the index.
        select: optional list of fields to return. The timestamp field is always added,
            because its value on the last document of each page is needed to continue.

    Yields:
        Non-empty lists of result documents.

    Raises:
        RuntimeError: if a full batch ends on the same timestamp it started from. The next
            batch would be identical, so paging could never make progress.
    """
    if select is not None and timestamp_field_name not in select:
        select = [*select, timestamp_field_name]

    session_id = str(uuid4())
    max_results_size = 100000
    while True:
        batch_start = timestamp
        batch_size = 0
        results = await client.search(
            search_text="*",
            order_by=f"{timestamp_field_name} asc",
            top=max_results_size,
            filter=f"{timestamp_field_name} ge {timestamp}" if timestamp else None,
            session_id=session_id,
            select=select,
        )

        async for page in results.by_page():
            documents = [item async for item in page]
            if not documents:
                # An empty page (e.g. an empty index) has no last timestamp to advance from.
                continue
            batch_size += len(documents)
            yield documents
            timestamp = documents[-1][timestamp_field_name]

        if batch_size < max_results_size:
            # A short batch means the end of the index was reached.
            return
        if timestamp == batch_start:
            raise RuntimeError(
                f"At least {max_results_size} documents share {timestamp_field_name} = {timestamp}; "
                "cannot page past them with a timestamp filter"
            )

async def get_key_field(index_client: SearchIndexClient, index_name: str) -> str:
    """Return the name of the first field marked as the document key of ``index_name``.

    Raises:
        Exception: if no field in the index definition is marked as the key.
    """
    index_definition = await index_client.get_index(index_name)
    key_field = next((candidate for candidate in index_definition.fields if candidate.key), None)
    if key_field is None:
        raise Exception("No key field in the index")
    return key_field.name

async def get_total_documents_remaining(source_client: SearchClient, timestamp_field_name: str, timestamp: Optional[str]) -> int:
    """Count the documents whose ``timestamp_field_name`` is at or after ``timestamp``.

    When ``timestamp`` is falsy, no filter is applied and every document is counted. Only the
    total count is requested (``top=0``), so no documents are transferred.
    """
    remaining_filter = f"{timestamp_field_name} ge {timestamp}" if timestamp else None
    count_only = await source_client.search(
        search_text="*",
        filter=remaining_filter,
        include_total_count=True,
        top=0,
    )
    return await count_only.get_count()

async def find_missing_backup_keys(source_client: SearchClient, destination_client: SearchClient, timestamp_field_name: str, key_field_name: str) -> List[str]:
    """Return the keys of source documents that are not present in the destination index.

    This is a coroutine and must be awaited, because both clients are async and their results
    can only be consumed with ``async for``. Only the key and timestamp fields are requested;
    the timestamp is needed to page through each index.

    All destination keys are collected into a set before the source is scanned, rather than
    comparing pages pairwise. Pages from the two indexes stop lining up as soon as any
    document is missing from the destination.

    Returns:
        Missing keys in source timestamp order, each listed once.
    """
    fields_to_select = [key_field_name, timestamp_field_name]

    destination_keys = set()
    async for page in resume_backup_results(destination_client, timestamp_field_name, None, select=fields_to_select):
        destination_keys.update(document[key_field_name] for document in page)

    missing_destination_keys = []
    async for page in resume_backup_results(source_client, timestamp_field_name, None, select=fields_to_select):
        for document in page:
            key = document[key_field_name]
            if key not in destination_keys:
                missing_destination_keys.append(key)
                # resume_backup_results repeats documents at batch boundaries (inclusive
                # "ge" filter); remember the key so it is reported only once.
                destination_keys.add(key)

    return missing_destination_keys

async def backup_index_with_resume(source_client: SearchClient, destination_client: SearchClient, timestamp_field_name: str, total_documents: int, resume_timestamp: Optional[str], backup_tasks:int = 2, on_backup_page: Optional[Callable[[str], None]] = None) -> None:
    """Copy all documents from the source index to the destination index, resumably.

    The work is split across three kinds of task:
    - one producer streams result pages from ``resume_backup_results`` in ascending
      ``timestamp_field_name`` order;
    - ``backup_tasks`` consumers upload pages concurrently;
    - one checkpoint task reports progress through ``on_backup_page``.

    Pages can finish uploading out of order. The checkpoint only advances past a page once
    that page *and every page queued before it* has been uploaded, so restarting from the
    last reported timestamp never skips documents that were still in flight.

    Args:
        source_client: async SearchClient for the index being backed up.
        destination_client: async SearchClient for the index receiving the documents.
        timestamp_field_name: field used to order, resume and checkpoint the backup.
        total_documents: expected number of documents, used only for the progress bar.
        resume_timestamp: timestamp to resume from, or None to start from the beginning.
        backup_tasks: number of concurrent upload tasks (must be at least 1).
        on_backup_page: optional callback receiving each new safe-to-resume timestamp.

    Raises:
        ValueError: if ``backup_tasks`` is less than 1.
        RuntimeError: if the destination reports that any document failed to upload.
    """
    if backup_tasks < 1:
        raise ValueError("backup_tasks must be at least 1")

    progress_bar = tqdm(total=total_documents, desc="Backing up documents...", unit="docs", unit_scale=False)
    pages_label = widgets.Label(value="Queued Result Pages: 0")
    display(pages_label)

    results_queue = asyncio.Queue()     # (sequence, page) items, then one None per consumer
    checkpoint_queue = asyncio.Queue()  # (sequence, last timestamp of page) items, then None

    async def get_results():
        # Number pages in queue order so the checkpointer can detect gaps left by slow uploads.
        sequence = 0
        async for result_page in resume_backup_results(source_client, timestamp_field_name, timestamp=resume_timestamp):
            if not result_page:
                continue  # nothing to upload and no last timestamp to checkpoint
            await results_queue.put((sequence, result_page))
            sequence += 1
            pages_label.value = f"Queued Result Pages: {results_queue.qsize()}"
        # Every consumer needs its own stop signal; a single None would stop only one of them.
        for _ in range(backup_tasks):
            await results_queue.put(None)

    async def backup_results():
        while True:
            item = await results_queue.get()
            if item is None:
                break
            sequence, result_page = item
            upload_results = await destination_client.upload_documents(result_page)
            # The service can reject individual documents without failing the whole request;
            # those come back as results with succeeded == False rather than as an exception.
            failed_keys = [result.key for result in upload_results if not result.succeeded]
            if failed_keys:
                raise RuntimeError(f"{len(failed_keys)} document(s) failed to upload, e.g. keys {failed_keys[:5]}")
            progress_bar.update(len(result_page))
            await checkpoint_queue.put((sequence, result_page[-1][timestamp_field_name]))

    async def checkpoint_results():
        next_sequence = 0       # lowest page sequence not yet known to be uploaded
        finished_pages = {}     # sequence -> last timestamp, for pages uploaded out of order
        latest_timestamp = None
        while True:
            item = await checkpoint_queue.get()
            if item is None:  # Stop signal received
                break
            sequence, page_timestamp = item
            finished_pages[sequence] = page_timestamp
            advanced = False
            # Advance only over the contiguous run of uploaded pages.
            while next_sequence in finished_pages:
                candidate = finished_pages.pop(next_sequence)
                next_sequence += 1
                if candidate is not None and (latest_timestamp is None or candidate > latest_timestamp):
                    latest_timestamp = candidate
                    advanced = True
            if advanced and on_backup_page is not None:
                on_backup_page(latest_timestamp)

    producer_task = asyncio.create_task(get_results())
    consumer_tasks = [asyncio.create_task(backup_results()) for _ in range(backup_tasks)]
    checkpoint_task = asyncio.create_task(checkpoint_results())
    all_tasks = [producer_task, *consumer_tasks, checkpoint_task]

    try:
        # Fails fast if the producer or any consumer raises.
        await asyncio.gather(producer_task, *consumer_tasks)
        # All uploads are done, so no more checkpoints will arrive.
        await checkpoint_queue.put(None)
        await checkpoint_task
    finally:
        # If anything above failed, stop the remaining tasks instead of leaving them blocked
        # on a queue; cancel() is a no-op for tasks that already finished.
        for task in all_tasks:
            task.cancel()
        await asyncio.gather(*all_tasks, return_exceptions=True)
        progress_bar.close()


In [None]:
import os

source_client = SearchClient(source_endpoint, index_name, source_credential)
destination_client = SearchClient(destination_endpoint, index_name, destination_credential)

backup_file = "backup-timestamp.txt"
def on_backup_page(last_timestamp: str) -> None:
    with open(backup_file, "w") as f:
        f.write(last_timestamp)

restored_timestamp = None
if os.path.exists(backup_file):
    with open(backup_file, "r") as f:
        restored_timestamp = f.read()

total_documents = await get_total_documents_remaining(
    source_client,
    timestamp_field_name,
    timestamp=restored_timestamp
)

await backup_index_with_resume(
    source_client,
    destination_client,
    timestamp_field_name,
    resume_timestamp=restored_timestamp,
    total_documents=total_documents,
    on_backup_page=on_backup_page,
    backup_tasks=3
)