In [5]:
import org.apache.spark.sql.{SparkSession, DataFrame}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

// Schemas of the change-event documents read from each Cosmos DB container
// (cdeContractSchema, objectiveContractSchema, keyResultContractSchema), followed by the
// helpers that select a schema and load the data (getSchemaForContainer, loadAndParseJson).

/**
 * Read schema for documents in the "cde" container.
 *
 * Each document is an event envelope whose `payload` holds a `before` and an `after`
 * snapshot with the same shape, plus a `related` struct. Every field is nullable
 * (the `StructField` default).
 */
val cdeContractSchema: StructType = {
  // Builds nullable string columns, preserving the order of the names given.
  def stringFields(names: String*): Seq[StructField] =
    names.map(fieldName => StructField(fieldName, StringType))

  // systemData: four string fields describing creation and last modification.
  val systemData = StructType(
    stringFields("lastModifiedAt", "lastModifiedBy", "createdAt", "createdBy"))

  // contacts.owner is an array of structs, each carrying only an id.
  val ownerEntry = StructType(stringFields("id"))
  val contacts = StructType(Seq(StructField("owner", ArrayType(ownerEntry))))

  // Shape shared by payload.before and payload.after.
  val entitySnapshot = StructType(
    (stringFields("name", "dataType", "status") :+ StructField("contacts", contacts)) ++
      (stringFields("description", "domain", "id") :+ StructField("systemData", systemData)))

  val related = StructType(Seq(StructField("dataAssetId", ArrayType(StringType))))

  val payload = StructType(Seq(
    StructField("before", entitySnapshot, nullable = true),
    StructField("after", entitySnapshot, nullable = true),
    StructField("related", related, nullable = true)
  ))

  // Envelope columns that follow `payload`; only PartitionId and _ts are non-string.
  val envelopeFields =
    (stringFields(
      "eventSource", "payloadKind", "operationType", "preciseTimestamp", "tenantId",
      "accountId", "changedBy", "eventId", "correlationId", "EventProcessedUtcTime"
    ) :+ StructField("PartitionId", IntegerType)) ++
      (stringFields("EventEnqueuedUtcTime", "id", "_rid", "_etag") :+
        StructField("_ts", LongType))

  StructType(StructField("payload", payload) +: envelopeFields)
}


/**
 * Read schema for documents in the "okr" container (objectives).
 *
 * Each document is an event envelope whose `payload` holds a `before` and an `after`
 * snapshot with the same shape, plus a `related` struct. Every field is nullable
 * (the `StructField` default).
 */
val objectiveContractSchema : StructType = {
  // Builds nullable string columns, preserving the order of the names given.
  def stringFields(names: String*): Seq[StructField] =
    names.map(fieldName => StructField(fieldName, StringType))

  // systemData: four string fields describing creation and last modification.
  val systemData = StructType(
    stringFields("lastModifiedAt", "lastModifiedBy", "createdAt", "createdBy"))

  // contacts.owner is an array of structs, each carrying only an id.
  val ownerEntry = StructType(stringFields("id"))
  val contacts = StructType(Seq(StructField("owner", ArrayType(ownerEntry))))

  // Shape shared by payload.before and payload.after.
  val entitySnapshot = StructType(
    (stringFields("id", "definition", "domain", "targetDate") :+
      StructField("contacts", contacts)) ++
      (stringFields("status") :+ StructField("systemData", systemData)))

  val related = StructType(Seq(StructField("dataAssetId", ArrayType(StringType))))

  val payload = StructType(Seq(
    StructField("before", entitySnapshot, nullable = true),
    StructField("after", entitySnapshot, nullable = true),
    StructField("related", related, nullable = true)
  ))

  // Envelope columns that follow `payload`; only PartitionId and _ts are non-string.
  val envelopeFields =
    (stringFields(
      "eventSource", "payloadKind", "operationType", "preciseTimestamp", "tenantId",
      "accountId", "changedBy", "eventId", "correlationId", "EventProcessedUtcTime"
    ) :+ StructField("PartitionId", IntegerType)) ++
      (stringFields("EventEnqueuedUtcTime", "id", "_rid", "_etag") :+
        StructField("_ts", LongType))

  StructType(StructField("payload", payload) +: envelopeFields)
}

/**
 * Read schema for documents in the "keyresult" container.
 *
 * Each document is an event envelope whose `payload` holds a `before` and an `after`
 * snapshot with the same shape, plus a `related` struct. Every field is nullable
 * (the `StructField` default).
 */
val keyResultContractSchema : StructType = {
  // Builds nullable string columns, preserving the order of the names given.
  def stringFields(names: String*): Seq[StructField] =
    names.map(fieldName => StructField(fieldName, StringType))

  // systemData: four string fields describing creation and last modification.
  val systemData = StructType(
    stringFields("lastModifiedAt", "lastModifiedBy", "createdAt", "createdBy"))

  // The three numeric measures of a key result, all read as 32-bit integers.
  val measureFields =
    Seq("progress", "goal", "max").map(fieldName => StructField(fieldName, IntegerType))

  // Shape shared by payload.before and payload.after.
  val entitySnapshot = StructType(
    stringFields("id", "definition", "domainId") ++
      measureFields ++
      (stringFields("status") :+ StructField("systemData", systemData)))

  val related = StructType(Seq(StructField("dataAssetId", ArrayType(StringType))))

  val payload = StructType(Seq(
    StructField("before", entitySnapshot, nullable = true),
    StructField("after", entitySnapshot, nullable = true),
    StructField("related", related, nullable = true)
  ))

  // Envelope columns that follow `payload`; only PartitionId and _ts are non-string.
  val envelopeFields =
    (stringFields(
      "eventSource", "payloadKind", "operationType", "preciseTimestamp", "tenantId",
      "accountId", "changedBy", "eventId", "correlationId", "EventProcessedUtcTime"
    ) :+ StructField("PartitionId", IntegerType)) ++
      (stringFields("EventEnqueuedUtcTime", "id", "_rid", "_etag") :+
        StructField("_ts", LongType))

  StructType(StructField("payload", payload) +: envelopeFields)
}

/**
 * Returns the read schema registered for a Cosmos DB container.
 *
 * @param containerName one of "cde", "okr" or "keyresult" (matched exactly, case-sensitive)
 * @return the schema associated with that container
 * @throws IllegalArgumentException if the container name is not one of the known names
 */
def getSchemaForContainer(containerName: String): StructType = {
  val schemaByContainer = Map(
    "cde" -> cdeContractSchema,
    "okr" -> objectiveContractSchema,
    "keyresult" -> keyResultContractSchema
  )

  // getOrElse takes its default by name, so the exception is only raised on a miss.
  schemaByContainer.getOrElse(
    containerName,
    throw new IllegalArgumentException(s"Unknown container name: $containerName")
  )
}

/**
 * Reads one container through the "cosmos.olap" format with the container's schema
 * applied, and keeps only the rows belonging to the given account.
 *
 * @param databaseName   Cosmos DB database to read from
 * @param containerName  container to read; also selects the schema via getSchemaForContainer
 * @param cosmosEndPoint Cosmos DB account endpoint
 * @param cosmosKey      Cosmos DB account key
 * @param accountId      value the `accountId` column must equal
 * @return a DataFrame of the container's rows for that account
 * @throws IllegalArgumentException if `containerName` has no registered schema
 */
def loadAndParseJson(databaseName: String, containerName: String, cosmosEndPoint: String, cosmosKey: String, accountId: String): DataFrame = {
  // Resolve the schema first so an unknown container fails before any reader is built.
  val containerSchema = getSchemaForContainer(containerName)

  val connectionOptions = Seq(
    "spark.cosmos.accountEndpoint" -> cosmosEndPoint,
    "spark.cosmos.database" -> databaseName,
    "spark.cosmos.container" -> containerName,
    "spark.cosmos.accountKey" -> cosmosKey
  )

  val baseReader = spark.read.format("cosmos.olap").schema(containerSchema)
  val configuredReader = connectionOptions.foldLeft(baseReader) {
    case (reader, (optionName, optionValue)) => reader.option(optionName, optionValue)
  }

  // Restrict the result to the requested account.
  configuredReader.load().filter(col("accountId") === accountId)
}
/**
 * Outcome of comparing one container for one account between the source and target
 * databases, as produced by validateSingleAccountContainer.
 *
 * When validation fails, validateSingleAccountContainer sets every count to -1.
 *
 * @param accountId                             account the comparison was run for
 * @param containerName                         container the comparison was run for
 * @param sourceCount                           number of rows loaded from the source database
 * @param targetCount                           number of rows loaded from the target database
 * @param inSourceNotInTarget                   source rows with no target row sharing the same
 *                                              `payload.after.id`
 * @param inTargetNotInSource                   target rows with no source row sharing the same
 *                                              `payload.after.id`
 * @param deletedInTarget                       target rows whose `operationType` is "Delete"
 * @param distinctInTargetNotInSourceNotDeleted distinct `payload.after.id` values among the
 *                                              target-only rows that have no matching delete row
 */
case class ValidationResult(
  accountId: String,
  containerName: String,
  sourceCount: Long,
  targetCount: Long,
  inSourceNotInTarget: Long,
  inTargetNotInSource: Long,
  deletedInTarget: Long,
  distinctInTargetNotInSourceNotDeleted: Long
)

/**
 * Compares one container's rows for a single account between the source and target
 * databases and returns the resulting counts.
 *
 * Rows are matched on `payload.after.id`. A row whose `payload.after.id` is null never
 * satisfies that equality, so it is always counted as missing from the other side.
 *
 * The temporary views used for the SQL comparison are named per account and container and
 * are always dropped before returning, including when validation fails.
 *
 * @param accountId          account whose rows are compared
 * @param containerName      container to compare ("cde", "okr" or "keyresult")
 * @param sourceDatabaseName database treated as the source
 * @param targetDatabaseName database treated as the target
 * @param cosmosEndPoint     Cosmos DB account endpoint
 * @param cosmosKey          Cosmos DB account key
 * @return the counts for this account/container; every count is -1 if a non-fatal error
 *         occurred (the error is printed, not rethrown)
 */
def validateSingleAccountContainer(
  accountId: String,
  containerName: String,
  sourceDatabaseName: String,
  targetDatabaseName: String,
  cosmosEndPoint: String,
  cosmosKey: String
): ValidationResult = {
  import scala.util.control.NonFatal

  println(s"Running validation for accountId: $accountId, container: $containerName")

  // View names are interpolated into SQL, so anything that is not a letter, digit or
  // underscore is replaced. For hyphenated account ids this yields the same names as before.
  val viewSuffix = s"${containerName}_$accountId".replaceAll("[^A-Za-z0-9_]", "_")
  val sourceViewName = s"source_$viewSuffix"
  val targetViewName = s"target_$viewSuffix"
  // Intermediate views are also scoped to this account/container so that runs sharing a
  // Spark session cannot overwrite each other's views.
  val targetOnlyViewName = s"target_only_$viewSuffix"
  val targetDeletesViewName = s"target_deletes_$viewSuffix"

  try {
    // Load the data frames
    val sourceDf = loadAndParseJson(sourceDatabaseName, containerName, cosmosEndPoint, cosmosKey, accountId)
    val targetDf = loadAndParseJson(targetDatabaseName, containerName, cosmosEndPoint, cosmosKey, accountId)

    // Create temp views
    sourceDf.createOrReplaceTempView(sourceViewName)
    targetDf.createOrReplaceTempView(targetViewName)

    // Get counts
    val sourceCount = sourceDf.count()
    val targetCount = targetDf.count()

    // Find records in source but not in target
    val allInSourceNotInTarget = spark.sql(s"""
      SELECT s.*
      FROM ${sourceViewName} s
      WHERE NOT EXISTS (
        SELECT 1
        FROM ${targetViewName} t
        WHERE s.payload.after.id = t.payload.after.id
      )
    """)
    val inSourceNotInTargetCount = allInSourceNotInTarget.count()

    // Find records in target but not in source
    val allInTargetNotInSource = spark.sql(s"""
      SELECT s.*
      FROM ${targetViewName} s
      WHERE NOT EXISTS (
        SELECT 1
        FROM ${sourceViewName} t
        WHERE s.payload.after.id = t.payload.after.id
      )
    """)
    val inTargetNotInSourceCount = allInTargetNotInSource.count()

    // Find deleted records in target. Single quotes keep 'Delete' a string literal even
    // when double-quoted identifiers are enabled.
    val allInTargetThatAreDelete = spark.sql(s"""
      SELECT s.*
      FROM ${targetViewName} s
      WHERE s.operationType = 'Delete'
    """)
    val deletedInTargetCount = allInTargetThatAreDelete.count()

    // Find records in target not in source and not deleted
    allInTargetNotInSource.createOrReplaceTempView(targetOnlyViewName)
    allInTargetThatAreDelete.createOrReplaceTempView(targetDeletesViewName)

    // Anti-join: keep target-only rows for which no delete row refers to the same id,
    // either through the row's before.id or its after.id.
    val allInTargetNotInSourceButNotDeleted = spark.sql(s"""
      SELECT s.payload.after.id
      FROM ${targetOnlyViewName} s
      LEFT JOIN ${targetDeletesViewName} t
      ON s.payload.before.id = t.payload.before.id
      OR s.payload.after.id = t.payload.before.id
      WHERE t.id is NULL
    """)
    val distinctInTargetNotInSourceNotDeletedCount = allInTargetNotInSourceButNotDeleted.distinct().count()

    // Return validation result
    ValidationResult(
      accountId,
      containerName,
      sourceCount,
      targetCount,
      inSourceNotInTargetCount,
      inTargetNotInSourceCount,
      deletedInTargetCount,
      distinctInTargetNotInSourceNotDeletedCount
    )
  } catch {
    // NonFatal lets InterruptedException and JVM errors propagate, so cancelling the run
    // is not reported as a failed validation followed by the next account.
    case NonFatal(e) =>
      println(s"Error validating account $accountId, container $containerName: ${e.getMessage}")
      e.printStackTrace()
      ValidationResult(
        accountId,
        containerName,
        -1, // Use -1 to indicate error in counts
        -1,
        -1,
        -1,
        -1,
        -1
      )
  } finally {
    // Drop the views on every path; dropTempView returns false for a view that was
    // never created, so this is safe after an early failure.
    Seq(sourceViewName, targetViewName, targetOnlyViewName, targetDeletesViewName)
      .foreach(viewName => spark.catalog.dropTempView(viewName))
  }
}

/**
 * Runs validateSingleAccountContainer for every (account, container) pair and collects the
 * results into a DataFrame, which is also printed.
 *
 * Pairs are processed sequentially, account by account and, within an account, in the
 * order of `containers`.
 *
 * @param accountIds         accounts to validate
 * @param containers         containers to validate for each account
 * @param sourceDatabaseName database treated as the source
 * @param targetDatabaseName database treated as the target
 * @param cosmosEndPoint     Cosmos DB account endpoint
 * @param cosmosKey          Cosmos DB account key
 * @return one row per (account, container) pair with the columns of ValidationResult
 */
def runBatchValidation(
  accountIds: List[String],
  containers: List[String],
  sourceDatabaseName: String,
  targetDatabaseName: String,
  cosmosEndPoint: String,
  cosmosKey: String
): DataFrame = {

  import spark.implicits._

  println(s"Starting batch validation for ${accountIds.size} accounts across ${containers.size} containers")
  println(s"Source DB: $sourceDatabaseName, Target DB: $targetDatabaseName")

  // Build the results as a value instead of appending to a mutable list; List's
  // flatMap/map are strict, so the pairs are still validated one after another.
  val allResults: List[ValidationResult] = for {
    accountId <- accountIds
    container <- containers
  } yield validateSingleAccountContainer(
    accountId,
    container,
    sourceDatabaseName,
    targetDatabaseName,
    cosmosEndPoint,
    cosmosKey
  )

  // Convert results to DataFrame
  val resultsDF = allResults.toDF()

  // Display results
  println("Validation complete. Results summary:")
  resultsDF.show(false)

  // Return the DataFrame for further processing if needed
  resultsDF
}

// Example run: replace the databases, endpoint, secret and account ids with your own.
val sourceDatabaseName = "dgh-Backfill"
val targetDatabaseName = "dgh-DataEstateHealth"
val cosmosEndPoint = "https://pdgprodcosmosdocdxb.documents.azure.com:443/"
// The key is fetched from Key Vault at run time rather than stored in the notebook.
// NOTE(review): the secret name suggests a write-capable key, while this validation only
// reads; confirm whether a read-only key is available.
val cosmosKey = mssparkutils.credentials.getSecret("dgh-prod-dxb-kv", "cosmosDBWritekey")

// Accounts to validate.
val accountIds = List(
  "0c89ab95-b9d6-47d8-9101-913ded037ba6",
  "83b87287-22e3-4d36-990a-865e123ece55"
)

// Containers to validate for each account; each name must be known to getSchemaForContainer.
val containers = List(
  "cde",
  "okr",
  "keyresult"
)

// Validate every (account, container) pair and keep the summary DataFrame.
val validationResults = runBatchValidation(
  accountIds = accountIds,
  containers = containers,
  sourceDatabaseName = sourceDatabaseName,
  targetDatabaseName = targetDatabaseName,
  cosmosEndPoint = cosmosEndPoint,
  cosmosKey = cosmosKey
)

display(validationResults)


StatementMeta(testCatalog, 14, 6, Finished, Available, Finished)

Starting batch validation for 2 accounts across 3 containers
Source DB: dgh-Backfill, Target DB: dgh-DataEstateHealth
Running validation for accountId: 0c89ab95-b9d6-47d8-9101-913ded037ba6, container: cde
Running validation for accountId: 0c89ab95-b9d6-47d8-9101-913ded037ba6, container: okr
Running validation for accountId: 0c89ab95-b9d6-47d8-9101-913ded037ba6, container: keyresult
Running validation for accountId: 83b87287-22e3-4d36-990a-865e123ece55, container: cde
Running validation for accountId: 83b87287-22e3-4d36-990a-865e123ece55, container: okr
Running validation for accountId: 83b87287-22e3-4d36-990a-865e123ece55, container: keyresult
[ERROR] [03/26/2025 03:35:48.845] [default-akka.actor.default-dispatcher-7] [akka.actor.ActorSystemImpl(default)] Outgoing request stream error (akka.stream.AbruptTerminationException: Processor actor [Actor[akka://default/system/StreamSupervisor-93/flow-0-0-PoolFlow#-885755594]] terminated abruptly)
[ERROR] [03/26/2025 03:35:48.847] [default-akk

SynapseWidget(Synapse.DataFrame, 88fd0902-9a30-4f5e-ab56-9a2528526eb8)


import org.apache.spark.sql.{SparkSession, DataFrame}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
cdeContractSchema: org.apache.spark.sql.types.StructType = StructType(StructField(payload,StructType(StructField(before,StructType(StructField(name,StringType,true),StructField(dataType,StringType,true),StructField(status,StringType,true),StructField(contacts,StructType(StructField(owner,ArrayType(StructType(StructField(id,StringType,true)),true),true)),true),StructField(description,StringType,true),StructField(domain,StringType,true),StructField(id,StringType,true),StructField(systemData,StructType(StructField(lastModifiedAt,StringType,true),StructField(lastModifiedBy,StringType,true),StructField(createdAt,StringType,true),StructField(createdBy,StringType,true)),true)),true),StructField(after,StructType(StructField(name,StringType,true),StructField(dataType,StringType,true),StructField...
objectiveContractSchema: org.apache.spark.sql.types.StructType = Str