In [None]:
%pip install more-itertools azure-identity azure-search-documents==11.6.0b4 openai backoff nest-asyncio

In [None]:
import os
import requests
from azure.identity import DefaultAzureCredential
from azure.search.documents import SearchClient
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
from openai import AzureOpenAI
# Service Principal should have the Cosmos DB Operator role on the Cosmos DB account

os.environ["AZURE_TENANT_ID"] = "<YOUR_TENANT_ID>"
os.environ["AZURE_CLIENT_ID"] = "<YOUR_CLIENT_ID>"
# The client secret comes from Key Vault; together with the two variables above it is what
# DefaultAzureCredential's environment credential authenticates with.
sp_client_secret = notebookutils.credentials.getSecret('https://<YOUR_KEYVAULT>.vault.azure.net/', '<YOUR_SECRET_NAME>')
os.environ["AZURE_CLIENT_SECRET"] = sp_client_secret

# A single credential instance is shared by every client below (one token cache, one auth chain).
credential = DefaultAzureCredential()
# Cosmos data-plane token; the Spark connector below authenticates separately via the
# service-principal settings in `config`.
token = credential.get_token("https://cosmos.azure.com/.default").token

index_name = "books_index"
azure_search_endpoint = "https://<YOUR_SEARCH_SERVICE_NAME>.search.windows.net"
search_client = SearchClient(endpoint=azure_search_endpoint, index_name=index_name, credential=credential)

search_token = credential.get_token("https://search.azure.com/.default").token

AZURE_OPENAI_ENDPOINT = "https://<AZURE_OPENAI_RESOURCE_NAME>.openai.azure.com/"
token_provider = get_bearer_token_provider(
    credential, "https://cognitiveservices.azure.com/.default"
)

AZURE_COSMOS_ENDPOINT = "https://<YOUR_COSMOS_DB_ACCOUNT>.documents.azure.com:443/"

client = AzureOpenAI(
    api_version="2024-02-15-preview",
    azure_endpoint=AZURE_OPENAI_ENDPOINT,
    azure_ad_token_provider=token_provider,
)

# Set configuration settings
config = {
    "spark.cosmos.accountEndpoint": AZURE_COSMOS_ENDPOINT,
    "spark.cosmos.auth.type": "ServicePrincipal",
    "spark.cosmos.account.subscriptionId": "<YOUR_SUBSCRIPTION_ID>",
    "spark.cosmos.account.resourceGroupName": "<YOUR_RESOURCE_GROUP_NAME>",
    "spark.cosmos.account.tenantId": os.environ["AZURE_TENANT_ID"],
    "spark.cosmos.auth.aad.clientId": os.environ["AZURE_CLIENT_ID"],
    "spark.cosmos.auth.aad.clientSecret": sp_client_secret,
    "spark.cosmos.database": "YOUR_DATABASE_NAME",
    "spark.cosmos.container": "YOUR_CONTAINER_NAME",
}

# Configure Catalog Api. The connection/auth settings are copied from `config` so the catalog
# and the reader below always use the same account and identity.
spark.conf.set("spark.sql.catalog.cosmosCatalog", "com.azure.cosmos.spark.CosmosCatalog")
for catalog_key in (
    "spark.cosmos.accountEndpoint",
    "spark.cosmos.auth.type",
    "spark.cosmos.account.subscriptionId",
    "spark.cosmos.account.resourceGroupName",
    "spark.cosmos.account.tenantId",
    "spark.cosmos.auth.aad.clientId",
    "spark.cosmos.auth.aad.clientSecret",
):
    spark.conf.set(f"spark.sql.catalog.cosmosCatalog.{catalog_key}", config[catalog_key])
df = spark.read.format("cosmos.oltp").options(**config).load()

StatementMeta(, 36a557d1-dd5b-4df8-a6bd-5fc3a7896df6, 5, Finished, Available, Finished)

In [12]:
# Create a database using the Catalog API
#spark.sql("CREATE DATABASE IF NOT EXISTS cosmosCatalog.{};".format(config["spark.cosmos.database"]))

StatementMeta(, eed4bfe5-69e5-4c1d-9055-f72fed064e39, 16, Finished, Available, Finished)

DataFrame[]

In [None]:
# Create container

#spark.sql("CREATE TABLE IF NOT EXISTS cosmosCatalog.{}.{} USING cosmos.oltp TBLPROPERTIES(partitionKeyPath = '{}', manualThroughput = '{}')".format(config["spark.cosmos.database"], config["spark.cosmos.container"], "/session_id", "400"))


In [None]:
import requests

def create_simple_index(index_name: str, search_token: str, azure_search_endpoint: str, timeout: float = 30.0):
    """Create an Azure AI Search index (text fields + a 1536-dim vector field) if it is missing.

    The index is looked up with a GET first and a PUT is issued only on 404, so an existing
    index is never overwritten. Outcomes are reported with ``print``; nothing is returned.

    Args:
        index_name: Name of the index to check/create.
        search_token: Entra ID bearer token for the ``https://search.azure.com/.default`` scope.
        azure_search_endpoint: Search service URL, e.g. ``https://<name>.search.windows.net``;
            a trailing slash is tolerated.
        timeout: Seconds to wait for each REST call. Without a timeout ``requests`` can block
            indefinitely on a stalled connection.

    Raises:
        requests.RequestException: on connection errors or timeouts.
    """
    index_schema = {
        "name": index_name,
        "fields": [
            {
                "name": "id",
                "type": "Edm.String",
                "key": True,
                "sortable": True,
                "filterable": True,
                "facetable": True
            },
            {
                "name": "content",
                "type": "Edm.String",
                "searchable": True
            },
            {
                "name": "fileName",
                "type": "Edm.String",
                "searchable": True
            },
            {
                "name": "contentVector",
                "type": "Collection(Edm.Single)",
                "searchable": True,
                "dimensions": 1536,
                "vectorSearchProfile": "amlHnswProfile"
            }
        ],
        "scoringProfiles": [],
        "suggesters": [],
        "vectorSearch": {
            "algorithms": [
                {
                    "name": "amlHnsw",
                    "kind": "hnsw",
                    "hnswParameters": {
                        "m": 4,
                        "metric": "cosine"
                    }
                }
            ],
            "profiles": [
                {
                    "name": "amlHnswProfile",
                    "algorithm": "amlHnsw"
                }
            ],
            "vectorizers": []
        },
        "semantic": {
            "configurations": [
                {
                    "name": "aml-semantic-config",
                    "prioritizedFields": {
                        "titleField": {"fieldName": "content"},
                        "prioritizedKeywordsFields": [{"fieldName": "content"}],
                        "prioritizedContentFields": [{"fieldName": "content"}]
                    }
                }
            ]
        }
    }

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {search_token}"
    }

    # Strip a trailing slash so the path never becomes "//indexes/...".
    url = f"{azure_search_endpoint.rstrip('/')}/indexes/{index_name}?api-version=2024-07-01"

    response = requests.get(url, headers=headers, timeout=timeout)

    if response.status_code == 404:
        create_response = requests.put(url, headers=headers, json=index_schema, timeout=timeout)
        if create_response.status_code in [200, 201]:
            print("✅ Index created successfully.")
        else:
            print("❌ Failed to create index:", create_response.text)
    elif response.status_code == 200:
        print("ℹ️ Index already exists.")
    else:
        print("❌ Unexpected error while checking index:", response.text)

# Example usage: get a fresh Search token (the one from the setup cell may have expired),
# then create the index if it does not exist yet.
# NOTE(review): this rebinds `index_name`; the `search_client` built in the setup cell still
# targets the name that was current when it was constructed.
search_token = credential.get_token("https://search.azure.com/.default").token
index_name = "YOUR_INDEX_NAME"
create_simple_index(
    index_name,
    search_token,
    azure_search_endpoint,
)


StatementMeta(, 36a557d1-dd5b-4df8-a6bd-5fc3a7896df6, 23, Finished, Available, Finished)

✅ Index created successfully.


In [None]:
from concurrent.futures import ThreadPoolExecutor, as_completed
from pyspark.sql import Row
from more_itertools import chunked
import json
import os
import time
from random import uniform
from pyspark.sql.types import StructType, StructField, StringType
import re


# Constants
# Max attempts per embedding request when Azure OpenAI returns a rate-limit error. The
# @backoff decorator on get_embeddings_with_retry reads this name when that cell runs, so it
# must be defined here.
MAX_RETRIES = 20
INITIAL_BACKOFF = 2
EMBEDDING_BATCH_SIZE = 16
SEARCH_UPLOAD_BATCH_SIZE = 16  # rows per chunk: one embedding request + one search upload
NUM_THREADS = 4  # Number of threads per worker node

# Service-principal secret for the executors, read from Key Vault and shipped as a broadcast
# so it is not re-fetched per partition.
sp_client_secret = notebookutils.credentials.getSecret(
    'https://<YOUR_KEYVAULT>.vault.azure.net/',
    '<YOUR_SECRET_NAME>'
)
sp_client_secret_bc = spark.sparkContext.broadcast(sp_client_secret)

# Load checkpoint from ADLS Gen2: ids indexed successfully by previous runs.
checkpoint_location = "abfss://<YOUR_STORAGE_CONTAINER>@<YOUR_STORAGE_ACCOUNT>.dfs.core.windows.net/search_ingestion_checkpoint/"
checkpoint_file_path = checkpoint_location + "completed_ids.json"
checkpoint_schema = StructType([StructField("id", StringType(), True)])

try:
    checkpoint_df = spark.read.json(checkpoint_file_path)
    completed_ids = set(row["id"] for row in checkpoint_df.select("id").collect())
except Exception as e:
    # Expected on the first run (no checkpoint written yet). Report it anyway so that an auth
    # or path problem is not silently treated as "nothing processed yet".
    print(f"No checkpoint loaded from {checkpoint_file_path} ({type(e).__name__}); processing all documents.")
    checkpoint_df = spark.createDataFrame([], checkpoint_schema)
    completed_ids = set()

# Keep only documents whose id is not already in the checkpoint.
df_pending = df.join(checkpoint_df, on="id", how="left_anti")




# Clean text utility
def clean_text(text):
    """Normalize document text before it is stored as the index ``content`` field.

    Line breaks and tabs become spaces, remaining C0/C1 control characters are removed,
    every ':' is rewritten as ' -', and surrounding whitespace is stripped.

    Args:
        text: Value to clean; non-strings are converted with ``str``.

    Returns:
        The cleaned string, or ``""`` for falsy input (``None``, ``""``, ``0``).
    """
    if not text:
        return ""
    text = str(text)
    # Convert line breaks/tabs to spaces *before* stripping control characters: they lie in
    # \x00-\x1f, so stripping first deletes them and glues adjacent words together.
    text = re.sub(r"[\r\n\t]", " ", text)
    text = re.sub(r"[\x00-\x1f\x7f-\x9f]", "", text)
    text = text.replace(":", " -")
    return text.strip()


import asyncio
from openai import AsyncAzureOpenAI
from azure.search.documents.aio import SearchClient as AsyncSearchClient
from azure.identity.aio import DefaultAzureCredential, get_bearer_token_provider
from azure.core.credentials import AzureKeyCredential
import backoff
from openai import RateLimitError

@backoff.on_exception(
    backoff.constant,
    RateLimitError,
    interval=60,
    max_tries=MAX_RETRIES,
    jitter=None,
)
async def get_embeddings_with_retry(aoai_client, texts, model="text-embedding-ada-002"):
    """Create embeddings for ``texts`` in one request, retrying on rate limiting.

    A ``RateLimitError`` is retried at a fixed 60-second interval (no jitter) for up to
    ``MAX_RETRIES`` attempts; any other exception propagates immediately.

    Args:
        aoai_client: ``AsyncAzureOpenAI`` client.
        texts: List of input strings, embedded in a single request.
        model: Azure OpenAI deployment name to call. Defaults to the previously hard-coded
            name, so existing callers are unaffected.

    Returns:
        The response object from ``aoai_client.embeddings.create``.
    """
    return await aoai_client.embeddings.create(
        input=texts,
        model=model,
    )

# Process a chunk of rows
async def process_row_chunk_async(row_chunk, aoai_client, search_client):
    """Embed one chunk of rows and upload the resulting documents to the search index.

    Args:
        row_chunk: List of Spark Rows with ``id``, ``fileName`` and ``text`` fields.
        aoai_client: ``AsyncAzureOpenAI`` client used for the embedding request.
        search_client: Async ``SearchClient`` bound to the target index.

    Returns:
        ``(success_ids, failures)``. ``success_ids`` lists the ids the search service reported as
        indexed. ``failures`` holds one ``Row(id, fileName, text, error)`` per row that was not
        indexed. Errors are recorded rather than raised, so one bad chunk does not fail the
        whole partition.
    """

    def _failure(row, message):
        # Field order matches the failure-log schema written by the driver.
        return Row(id=row["id"], fileName=row["fileName"], text=row["text"], error=message)

    try:
        # NOTE(review): the raw text is embedded while the cleaned text is stored as content.
        texts = [row["text"] for row in row_chunk]
        embeddings = await get_embeddings_with_retry(aoai_client, texts)
        # Each embedding carries the index of its input; sort so vectors line up with rows.
        vectors = [e.embedding for e in sorted(embeddings.data, key=lambda e: e.index)]
        if len(vectors) != len(row_chunk):
            # zip() below would silently drop unmatched rows and still report them as done.
            raise ValueError(f"expected {len(row_chunk)} embeddings, got {len(vectors)}")
        documents_to_upload = [
            {
                "id": row["id"],
                "fileName": row["fileName"],
                "content": clean_text(row["text"]),
                "contentVector": vector
            }
            for row, vector in zip(row_chunk, vectors)
        ]
        upload_results = await search_client.upload_documents(documents=documents_to_upload)
    except Exception as e:
        print(f"Async batch failed: {e}")
        return [], [_failure(row, str(e)) for row in row_chunk]

    # upload_documents reports per-document outcomes instead of raising on a partial failure,
    # so documents the service rejected must not be checkpointed as completed.
    rejected = {
        result.key: result.error_message or f"indexing failed with status {result.status_code}"
        for result in upload_results
        if not result.succeeded
    }
    if rejected:
        print(f"{len(rejected)} document(s) rejected by the search index")
    success_ids = [row["id"] for row in row_chunk if str(row["id"]) not in rejected]
    failures = [
        _failure(row, rejected[str(row["id"])])
        for row in row_chunk
        if str(row["id"]) in rejected
    ]
    return success_ids, failures



# Connection settings resolved on the driver. cloudpickle serializes this dict together with
# process_partition, so the executors use the same endpoint, index and identity as the cells
# above instead of a second, hand-maintained copy of the placeholders.
_worker_settings = {
    "tenant_id": os.environ["AZURE_TENANT_ID"],
    "client_id": os.environ["AZURE_CLIENT_ID"],
    "openai_endpoint": AZURE_OPENAI_ENDPOINT,
    "search_endpoint": azure_search_endpoint,
    "index_name": index_name,
}


# Worker partition function
def process_partition(rows):
    """Embed and index every row of one Spark partition (used with ``rdd.mapPartitions``).

    Args:
        rows: Iterator of Spark Rows with ``id``, ``fileName`` and ``text`` fields.

    Returns:
        Iterator of ``("success", id)`` and ``("failed", Row)`` tuples for the driver.
    """
    import nest_asyncio

    nest_asyncio.apply()  # Required in notebooks: allows asyncio.run inside a running loop
    os.environ["AZURE_TENANT_ID"] = _worker_settings["tenant_id"]
    os.environ["AZURE_CLIENT_ID"] = _worker_settings["client_id"]
    os.environ["AZURE_CLIENT_SECRET"] = sp_client_secret_bc.value

    async def run_partition():
        credential = DefaultAzureCredential()
        token_provider = get_bearer_token_provider(credential, "https://cognitiveservices.azure.com/.default")
        aoai_client = AsyncAzureOpenAI(
            api_version="2024-02-15-preview",
            azure_endpoint=_worker_settings["openai_endpoint"],
            azure_ad_token_provider=token_provider,
        )
        search_client = AsyncSearchClient(
            endpoint=_worker_settings["search_endpoint"],
            index_name=_worker_settings["index_name"],
            credential=credential,
        )

        # Cap the number of chunks in flight so a large partition does not fire every
        # embedding/upload request at once and trip the rate limit for all of them.
        limiter = asyncio.Semaphore(NUM_THREADS)

        async def bounded(chunk):
            async with limiter:
                return await process_row_chunk_async(chunk, aoai_client, search_client)

        # `async with` closes the HTTP sessions held by the credential and both clients.
        async with credential, aoai_client, search_client:
            results = await asyncio.gather(
                *(bounded(chunk) for chunk in chunked(rows, SEARCH_UPLOAD_BATCH_SIZE))
            )

        checkpoint_updates = []
        failed_rows = []
        for success_ids, failures in results:
            checkpoint_updates.extend(success_ids)
            failed_rows.extend(failures)

        return [("success", doc_id) for doc_id in checkpoint_updates] + [("failed", r) for r in failed_rows]

    return iter(asyncio.run(run_partition()))


# Run the ingestion on the executors and bring every (status, payload) outcome to the driver.
results = df_pending.rdd.mapPartitions(process_partition).collect()

success_ids = [payload for status, payload in results if status == "success"]
failed_docs = [payload for status, payload in results if status == "failed"]

# Append successful IDs to the checkpoint so the next run's left-anti join skips them.
if success_ids:
    df_checkpoint = spark.createDataFrame([Row(id=doc_id) for doc_id in success_ids])
    df_checkpoint.write.mode("append").json(checkpoint_file_path)

failure_log_path = "abfss://<YOUR_STORAGE_CONTAINER>@<YOUR_STORAGE_ACCOUNT>.dfs.core.windows.net/search_ingestion_failures/"
# Save failed docs, together with their error message, for later inspection.
if failed_docs:
    schema = StructType(
        [StructField(column, StringType(), True) for column in ("id", "fileName", "text", "error")]
    )
    df_failures = spark.createDataFrame(failed_docs, schema=schema)
    df_failures.write.mode("append").json(failure_log_path)


StatementMeta(, 36a557d1-dd5b-4df8-a6bd-5fc3a7896df6, 24, Finished, Available, Finished)