In [0]:
%run ./configs

In [0]:
%run ./model_endpoint

In [0]:
%run ./vectorSearch

In [0]:
class Setup:
    """Bootstrap and validate the resources used by the ATS pipeline.

    The class creates landing-zone directories, tables, the jobs-metadata seed
    rows, model serving endpoints and the vector index, and offers ``assert_*``
    helpers plus ``validate()`` to check that they exist.

    It relies on names that this class does not define: ``spark`` and
    ``dbutils``, and ``ats_configs``, ``Endpoints`` and ``VectorSearch``
    (presumably provided by the ``%run`` notebooks executed before this
    cell -- confirm).
    """

    # Column lists shared by the profile and job-description (jd) tables, so the
    # DDL for both families is written once.
    _BRONZE_COLUMNS = """
          id BIGINT GENERATED ALWAYS AS IDENTITY,
          path string,
          modificationTime timestamp,
          length BIGINT,
          content binary
      """
    _SILVER_STG_COLUMNS = """
          source string,
          text_content string
      """
    _SILVER_COLUMNS = """
          id BIGINT GENERATED ALWAYS AS IDENTITY,
          source string,
          text_content string,
          json_context string
      """
    _JD_PROFILE_STG_COLUMNS = """
          jd_id bigint,
          profile_id bigint
      """
    _JD_PROFILE_COLUMNS = """
          jd_id bigint,
          jd_source string,
          jd_extract string,
          profile_id bigint,
          profile_source string,
          profile_extract string,
          summary string,
          generated_date timestamp
      """
    _JOBS_METADATA_COLUMNS = """
          job_name string,
          last_load_date date,
          execution_time timestamp,
          description string
      """

    def __init__(self):
        """Create the vector-search client and the Databricks deployment client."""
        from mlflow.deployments import get_deploy_client
        from databricks.vector_search.client import VectorSearchClient
        self.vs_client = VectorSearchClient(disable_notice=True)
        self.dp_client = get_deploy_client('databricks')

    def _create_table(self, table_name, columns, change_data_feed=False):
        """Run ``CREATE TABLE IF NOT EXISTS`` for ``table_name`` and log it.

        Args:
            table_name: Table name exactly as it should appear in the DDL.
            columns: Comma-separated column definitions placed inside the
                parentheses of the statement.
            change_data_feed: When true, append
                ``TBLPROPERTIES (delta.enableChangeDataFeed = true)``.
        """
        properties = (
            " TBLPROPERTIES (delta.enableChangeDataFeed = true)"
            if change_data_feed
            else ""
        )
        spark.sql(f"CREATE TABLE IF NOT EXISTS {table_name}({columns}){properties}")
        print(f"{table_name} created...", end="")

    def _job_names(self):
        """Return the job names that must each have a row in the jobs-metadata table."""
        return [
            ats_configs.profile_ingestion_job_name,
            ats_configs.jd_ingestion_job_name,
            ats_configs.profile_bronze_job_name,
            ats_configs.profile_silver_job_name,
            ats_configs.jd_bronze_job_name,
            ats_configs.jd_silver_job_name,
            ats_configs.jd_profile_job_name,
            ats_configs.index_sync_job_name,
        ]

    def setup_landing_zone(self):
        """Create the profile and job-description landing directories under ``/Volumes``."""
        print("Setting up landing zone directories....", end="")
        volume_root = f"/Volumes/{ats_configs.catalog}/{ats_configs.db}"

        profile_landing_zone = f"{volume_root}/{ats_configs.profile_landing_zone}"
        dbutils.fs.mkdirs(profile_landing_zone)

        jd_landing_zone = f"{volume_root}/{ats_configs.jd_landing_zone}"
        dbutils.fs.mkdirs(jd_landing_zone)

        print('done')

    def setup_tables(self):
        """Create the bronze, silver staging, silver and jd/profile match tables.

        Every statement uses ``CREATE TABLE IF NOT EXISTS``, so calling this
        again does not fail on tables that are already present.
        """
        print("Setting up tables....", end="")

        # The profile and jd families share the same three-table layout.
        table_families = (
            (ats_configs.profile_bronze_table_name, ats_configs.profile_silver_table_name),
            (ats_configs.jd_bronze_table_name, ats_configs.jd_silver_table_name),
        )
        for bronze_name, silver_name in table_families:
            self._create_table(bronze_name, self._BRONZE_COLUMNS, change_data_feed=True)
            self._create_table(f"{silver_name}_stg", self._SILVER_STG_COLUMNS)
            self._create_table(silver_name, self._SILVER_COLUMNS, change_data_feed=True)

        self._create_table(
            f"{ats_configs.jd_profile_table_name}_stg", self._JD_PROFILE_STG_COLUMNS
        )
        self._create_table(
            ats_configs.jd_profile_table_name,
            self._JD_PROFILE_COLUMNS,
            change_data_feed=True,
        )

        # Terminate the progress line so later output starts on a fresh line.
        print('done')

    def setup_jobs_metadata(self):
        """Create the jobs-metadata table and seed one row per known job.

        Seeding uses ``MERGE ... WHEN NOT MATCHED`` so only job names missing
        from the table are inserted. ``run()`` calls this method again when the
        table holds fewer rows than expected; a plain ``INSERT`` of every seed
        row would duplicate the rows that were already there.
        """
        table_name = ats_configs.jobs_metadate_table_name
        self._create_table(table_name, self._JOBS_METADATA_COLUMNS)

        print(f"loading initial config insert for {table_name}...", end="")

        seed_rows = ", ".join(f"('{job_name}')" for job_name in self._job_names())
        spark.sql(f"""
                MERGE INTO {table_name} AS target
                USING (SELECT job_name FROM VALUES {seed_rows} AS seed(job_name)) AS source
                ON target.job_name = source.job_name
                WHEN NOT MATCHED THEN INSERT (job_name, last_load_date, execution_time, description)
                VALUES (
                    source.job_name,
                    try_cast(current_timestamp() as date),
                    current_timestamp(),
                    '{ats_configs.jobs_Initial_description}'
                )
                """)

        print(f"initial config insert for {table_name} loaded...")

    def setup_endpoint(self):
        """Create the chat and embedding model endpoints through ``Endpoints``."""
        endpoint = Endpoints()
        endpoint.create(name=ats_configs.chat_model_endpoint_name, model=ats_configs.chat_model_name)
        print(f"Chat endpoint {ats_configs.chat_model_endpoint_name} created")

        endpoint.create(
            name=ats_configs.embedding_model_endpoint_name,
            model=ats_configs.embedding_model_name,
            task='embeddings',
        )
        print(f"Embeding endpoint {ats_configs.embedding_model_endpoint_name} created")

    def setup_vector_index(self):
        """Create the vector-store endpoint and the index over the profile silver table."""
        vs = VectorSearch()
        vs.create_vector_endpoint(endpoint_name=ats_configs.vector_store_endpoint_name)
        print(f"Vector store endpoint {ats_configs.vector_store_endpoint_name} created")

        vs.create_vector_index(
            index_name=ats_configs.vector_index_name,
            source_table_name=f'{ats_configs.catalog}.{ats_configs.db}.{ats_configs.profile_silver_table_name}',
            primary_key="id",
            embedding_source_column='json_context',
            embedding_model_endpoint_name=ats_configs.embedding_model_endpoint_name,
            vector_endpoint_name=ats_configs.vector_store_endpoint_name,
        )

    def run(self):
        """Run every setup step, seeding jobs metadata only when rows are missing."""
        metadata_table = ats_configs.jobs_metadate_table_name
        qualified_name = f"{ats_configs.catalog}.{ats_configs.db}.{metadata_table}"
        expected_rows = len(self._job_names())

        # The count query only runs when the table exists (short-circuit `or`).
        if (
            not spark.catalog.tableExists(qualified_name)
            or spark.sql(f"SELECT COUNT(*) as cnt FROM {metadata_table}").collect()[0]['cnt'] < expected_rows
        ):
            self.setup_jobs_metadata()

        self.setup_landing_zone()
        self.setup_tables()
        self.setup_endpoint()
        self.setup_vector_index()

    def assert_table(self, table_name):
        """Assert that ``{catalog}.{db}.{table_name}`` exists.

        Raises:
            AssertionError: If the table does not exist.
        """
        qualified_name = f"{ats_configs.catalog}.{ats_configs.db}.{table_name}"
        assert spark.catalog.tableExists(qualified_name), f"Table {table_name} does not exist"
        print(f"Table {table_name} exists")

    def assert_count(self, table_name, expected_count):
        """Assert that ``{catalog}.{db}.{table_name}`` holds exactly ``expected_count`` rows.

        Raises:
            AssertionError: If the row count differs from ``expected_count``.
        """
        actual_count = spark.read.table(f"{ats_configs.catalog}.{ats_configs.db}.{table_name}").count()
        assert actual_count == expected_count, f"Table {table_name} has {actual_count} rows, expected {expected_count}"
        print(f"Table {table_name} has {actual_count} rows")

    def assert_endpoint(self, endpoint_name):
        """Assert that a serving endpoint named ``endpoint_name`` is listed by the deploy client.

        Raises:
            AssertionError: If no listed endpoint has that name.
        """
        endpoint_list = self.dp_client.list_endpoints()
        endpoint_names = [ep["name"] for ep in endpoint_list]
        assert endpoint_name in endpoint_names, f"Endpoint {endpoint_name} does not exist"
        print(f"Endpoint {endpoint_name} exists")

    def assert_vector_index(self, endpoint_name, index_name):
        """Assert that the vector endpoint exists and that it lists ``index_name``.

        Raises:
            AssertionError: If the endpoint or the index is not listed.
        """
        ep_list = self.vs_client.list_endpoints()
        # A missing 'endpoints' key is treated as "no endpoints" so the failure
        # is the assertion below rather than a KeyError.
        # NOTE(review): assumes the key may be absent when nothing exists -- confirm.
        endpoint_names = [ep['name'] for ep in ep_list.get('endpoints', [])]
        print(endpoint_names)
        assert endpoint_name in endpoint_names, f"Vector endpoint: {endpoint_name} does not exist"
        print(f"Vector endpoint: {endpoint_name} exists")

        index_list = self.vs_client.list_indexes(endpoint_name)
        index_names = [index["name"] for index in index_list.get("vector_indexes", [])]
        print(index_names)
        assert index_name in index_names, f"Vector index {index_name} does not exist"
        print(f"Vector index {index_name} exists")

    def assert_dir(self, dir_name):
        """Assert that ``dir_name`` is listed directly under the landing volume.

        The listing paths are compared against a ``dbfs:/Volumes/...`` path
        with a trailing slash.

        Raises:
            AssertionError: If the directory path is not in the listing.
        """
        volume_path = f"/Volumes/{ats_configs.catalog}/{ats_configs.db}/{ats_configs.landing_volume}"
        listed_paths = [entry.path for entry in dbutils.fs.ls(volume_path)]
        dir_path = f"dbfs:{volume_path}/{dir_name}/"
        assert dir_path in listed_paths, f"Directory {dir_path} does not exist"
        print(f"Directory {dir_path} exists")

    def validate(self):
        """Check that the database, tables, directories, seed rows, endpoints and index exist.

        Raises:
            AssertionError: On the first check that fails.
        """
        import time
        start_time = int(time.time())
        print(f"Starting validation at {start_time}")

        assert spark.catalog.databaseExists(f"{ats_configs.db}"), f"Database {ats_configs.db} does not exist"
        self.assert_table(ats_configs.profile_bronze_table_name)
        self.assert_table(ats_configs.profile_silver_table_name)
        self.assert_table(ats_configs.jd_profile_table_name)
        self.assert_table(ats_configs.jd_bronze_table_name)
        self.assert_table(ats_configs.jd_silver_table_name)
        self.assert_table(ats_configs.jobs_metadate_table_name)

        self.assert_dir(ats_configs.profile_source)
        self.assert_dir(ats_configs.jd_source)

        # One metadata row is expected per seeded job name.
        self.assert_count(ats_configs.jobs_metadate_table_name, len(self._job_names()))

        self.assert_endpoint(ats_configs.chat_model_endpoint_name)
        self.assert_endpoint(ats_configs.embedding_model_endpoint_name)
        self.assert_vector_index(ats_configs.vector_store_endpoint_name, ats_configs.vector_index_name)

        print(f"Completed validation in {int(time.time()) - start_time} seconds.")



In [0]:
# Build the helper; its constructor creates the vector-search and deployment clients.
setup = Setup()
# Provisioning is disabled here; uncomment the next line to run every setup step.
#setup.run()
# Assert that the expected tables, directories, endpoints and vector index exist.
setup.validate()