# In [0]:
# /src/data_migration.py
import os
import logging
import sys
from datetime import datetime
from pyspark.sql import SparkSession


# Configuration
def get_config():
    """
    Build the runtime configuration from environment variables.

    Every setting falls back to a built-in default when its variable is not
    set; ``LOG_LEVEL`` is upper-cased after the lookup.

    Returns:
        dict: Mapping with the keys ``SOURCE_TABLE``, ``TARGET_TABLE``,
        ``LOG_FILE`` and ``LOG_LEVEL``.
    """
    # (environment variable, default) pairs, in the order the keys are emitted
    defaults = (
        ("SOURCE_TABLE", "de_dev.bds_feature_store.bds_sms_data_structured"),
        ("TARGET_TABLE", "ds_dev.sms_parsing.sms_filtered_data_previous_day"),
        ("LOG_FILE", "data_migration.log"),
        ("LOG_LEVEL", "INFO"),
    )
    settings = {name: os.getenv(name, fallback) for name, fallback in defaults}
    settings["LOG_LEVEL"] = settings["LOG_LEVEL"].upper()
    return settings


class DataFilteration:
    """
    Copy a date-bounded slice of a source table into a target table.

    ``run`` validates the two dates, reads the matching rows from
    ``source_table`` through Spark SQL and overwrites ``target_table`` with
    them.
    """

    def __init__(self, source_table, target_table, spark_session):
        """
        Store the table names and Spark session, then configure logging.

        Args:
            source_table: Name of the table to read from.
            target_table: Name of the table that gets overwritten.
            spark_session: Session used to run the SQL statements.
        """
        self.source_table = source_table
        self.target_table = target_table
        self.spark = spark_session
        self.setup_logging()

    def setup_logging(self):
        """
        Setup logging configuration from ``LOG_FILE`` / ``LOG_LEVEL``.
        """
        config = get_config()
        # NOTE: logging.basicConfig does nothing when the root logger already
        # has handlers, so these settings only apply if logging has not been
        # configured earlier in the process.
        logging.basicConfig(
            filename=config["LOG_FILE"],
            level=config["LOG_LEVEL"],
            format="%(asctime)s - %(levelname)s - %(message)s",
        )
        logging.info("Logging setup complete.")

    @staticmethod
    def validate_date(date_str):
        """
        Validate if the date string is in the correct YYYY-MM-DD format.

        Args:
            date_str: Value to check.

        Returns:
            bool: True when ``date_str`` parses with ``%Y-%m-%d``; False for
            an unparsable string or a non-string value such as ``None``.
        """
        try:
            datetime.strptime(date_str, "%Y-%m-%d")
            return True
        except (ValueError, TypeError):
            # TypeError covers non-string input (None, int, ...), which
            # strptime rejects before it even looks at the format.
            logging.error(f"Invalid date format: {date_str}. Use YYYY-MM-DD format.")
            return False

    def fetch_data(self, start_date, end_date):
        """
        Fetch data from the source table.

        The dates are interpolated directly into the SQL text, so callers must
        pass values that already passed ``validate_date`` (``run`` does).

        Args:
            start_date: Inclusive lower bound for ``fetch_date``.
            end_date: Inclusive upper bound for ``fetch_date``.

        Returns:
            The DataFrame produced by the query (at most 50 rows).

        Raises:
            RuntimeError: If building or running the query fails; the original
                exception is attached as ``__cause__``.
        """
        print(f"Fetching data from the source table: {self.source_table}")
        try:
            query = f"""
                SELECT * FROM {self.source_table}
                WHERE fetch_date BETWEEN '{start_date}' AND '{end_date}'
                LIMIT 50
            """
            logging.info(f"Executing query: {query}")
            df = self.spark.sql(query)
            logging.info(f"Fetched {df.count()} records from the source table: {self.source_table}")
            return df
        except Exception as e:
            logging.error(f"Error fetching data: {e}")
            raise RuntimeError(f"Error fetching data: {e}") from e

    def write_data(self, df):
        """
        Write data to the target table, replacing its current contents.

        Nothing is written (a warning is logged) when ``df`` has no rows.

        Args:
            df: DataFrame whose rows replace the target table.

        Raises:
            RuntimeError: If counting or writing fails; the original exception
                is attached as ``__cause__``.
        """
        print(f"Writing data to the target table: {self.target_table}")
        try:
            if df.count() > 0:
                # The INSERT statement reads the frame through a temp view
                df.createOrReplaceTempView("temp_view")
                self.spark.sql(f"""
                    INSERT OVERWRITE TABLE {self.target_table}
                    SELECT * FROM temp_view
                """)
                logging.info(f"Data written to target table: {self.target_table} successfully.")
            else:
                logging.warning("No data to write; DataFrame is empty.")
        except Exception as e:
            logging.error(f"Error writing data to the target table: {e}")
            raise RuntimeError(f"Error writing data to the target table: {e}") from e

    def run(self, start_date, end_date):
        """
        Run the data filteration process.

        Returns without touching Spark (after logging an error) when either
        date fails validation.

        Args:
            start_date: Inclusive start date, ``YYYY-MM-DD``.
            end_date: Inclusive end date, ``YYYY-MM-DD``.

        Raises:
            Exception: Whatever ``fetch_data`` / ``write_data`` raise, after
                the failure has been logged.
        """
        print(f"running for date {start_date} to {end_date}")
        try:
            # Validate dates
            if not (self.validate_date(start_date) and self.validate_date(end_date)):
                logging.error("Invalid date(s) provided. Exiting.")
                return

            logging.info("Data filteration process started.")

            # Fetch data
            data = self.fetch_data(start_date, end_date)

            # Write data
            self.write_data(data)

            logging.info("Data filteration  process completed successfully.")
            print(f"Done first step: {start_date} to {end_date}")
        except Exception as e:
            logging.error(f"Data filteration  process failed: {e}")
            raise


import faiss
import numpy as np
from sentence_transformers import SentenceTransformer, util
import pandas as pd
import logging

# Root-logger defaults for this section: INFO and above, timestamped records.
# No filename/stream/handlers are passed, so basicConfig falls back to a stream handler.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

class UniqueTemplateExtractor:
    """
    Collect per-sender message templates, deduplicated by embedding similarity.

    Each sender gets its own FAISS inner-product index. A message becomes a new
    template for its sender when the best score returned by that sender's index
    is below ``threshold``. State accumulates on the instance across calls.
    """

    def __init__(self, threshold=0.90):
        """Load the embedding model and create the empty per-sender stores."""
        try:
            # Pre-trained sentence-embedding model used for every message
            self.model = SentenceTransformer('BAAI/bge-base-en-v1.5')
            self.threshold = threshold

            # Per-sender state; every dictionary is keyed by sender id
            self.sender_indices = {}
            self.sender_templates = {}
            self.sender_template_ids = {}
            self.sender_embeddings = {}

            logging.info("Initialized UniqueTemplateExtractor successfully.")
        except Exception as err:
            logging.error(f"Error initializing UniqueTemplateExtractor: {err}")
            raise

    def get_embedding(self, text):
        """Encode ``text`` with the model and return the result as a NumPy array."""
        try:
            return self.model.encode(
                text,
                convert_to_tensor=True,
                normalize_embeddings=True,
            ).cpu().numpy()
        except Exception as err:
            logging.error(f"Error generating embedding for text '{text}': {err}")
            raise

    def add_template(self, sender_id, embedding, text):
        """Record ``text`` and its embedding as the next template of ``sender_id``."""
        try:
            if sender_id not in self.sender_indices:
                # First template for this sender: create its index and stores
                dim = self.model.get_sentence_embedding_dimension()
                self.sender_indices[sender_id] = faiss.IndexFlatIP(dim)
                for store in (
                    self.sender_templates,
                    self.sender_template_ids,
                    self.sender_embeddings,
                ):
                    store[sender_id] = []

            # Index first, then the parallel metadata lists
            self.sender_indices[sender_id].add(embedding)
            self.sender_templates[sender_id].append(text)
            ids = self.sender_template_ids[sender_id]
            ids.append(len(ids) + 1)  # template ids are 1-based per sender
            self.sender_embeddings[sender_id].append(embedding)

            logging.info(f"Template added for sender '{sender_id}': {text}")
        except Exception as err:
            logging.error(f"Error adding template for sender '{sender_id}': {err}")
            raise

    def is_unique_template(self, sender_id, embedding):
        """Return True unless ``embedding`` scores at/above the threshold for this sender."""
        try:
            has_templates = (
                sender_id in self.sender_indices
                and len(self.sender_templates[sender_id]) != 0
            )
            if not has_templates:
                # Nothing stored for this sender yet, so the message is unique
                return True

            # Nearest neighbour within this sender's index only
            scores, _neighbours = self.sender_indices[sender_id].search(embedding, k=1)

            # A best score at or above the threshold means an existing template matches
            return not (scores[0][0] >= self.threshold)
        except Exception as err:
            logging.error(f"Error checking uniqueness for sender '{sender_id}': {err}")
            raise

    def extract_unique_templates(self, messages):
        """
        Add every not-yet-seen message as a template and return all templates.

        ``messages`` is an iterable of ``(sender_id, text)`` pairs. The result
        has one ``(sender_id, template_id, text, embedding_as_list)`` tuple for
        each template currently stored on the instance.
        """
        collected = []
        try:
            for sender_id, msg in messages:
                vector = self.get_embedding(msg).reshape(1, -1)

                # Uniqueness is judged only against the same sender's templates
                if not self.is_unique_template(sender_id, vector):
                    continue
                self.add_template(sender_id, vector, msg)

            # Flatten the per-sender stores into one list of rows
            for sender_id, texts in self.sender_templates.items():
                rows = zip(
                    self.sender_template_ids[sender_id],
                    texts,
                    self.sender_embeddings[sender_id],
                )
                for template_id, template_text, vector in rows:
                    collected.append((sender_id, template_id, template_text, vector.tolist()))

            logging.info(f"Extracted {len(collected)} unique templates.")
            return collected
        except Exception as err:
            logging.error(f"Error extracting unique templates: {err}")
            raise



import logging
import pytz
from datetime import datetime, timedelta
import pandas as pd
import csv
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lit, current_date, current_timestamp, udf, max as max_, row_number
from pyspark.sql.types import StringType, BooleanType
from pyspark.sql.window import Window
from pyspark.sql import functions as F
from pyspark.sql.functions import to_date

# Daily log file for the sender job; the name carries today's date as YYYYMMDD.
log_path = f"/Workspace/Users/nareshkumar.y@angelbroking.com/SMS parsing/hackathon/log_{datetime.now().strftime('%Y%m%d')}.log"
# NOTE(review): logging.basicConfig is skipped when the root logger already has
# handlers — confirm no earlier configuration in this process preempts this file setup.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO, filename=log_path)

# Define the UDF outside the class
def id_sender_concept_reqired(sender):
    """Return False when ``str(sender)`` consists only of digits, True otherwise."""
    if str(sender).isdigit():
        return False
    return True

# Spark UDF wrapper around id_sender_concept_reqired; yields a boolean column.
id_sender_concept_reqired_udf = udf(id_sender_concept_reqired, returnType=BooleanType())

class SenderDataProcessor:
    """
    Register newly seen SMS senders in the unique-sender table.

    ``process_data`` reads per-day sender counts from ``DB_TABLE``, drops the
    senders already present in ``TARGET_DB_TABLE``, assigns numeric ids,
    attaches concept/description metadata and appends the rows to the target.
    """

    def __init__(self):
        """Set table names and job defaults, then get or create the Spark session."""
        self.DB_TABLE = "ds_dev.sms_parsing.sms_filtered_data_previous_day"
        self.TARGET_DB_TABLE = "ds_dev.sms_parsing.unique_sender_schema"

        # Selects which date arguments process_data passes to get_unique_new_senders
        self.FLAG = True

        self.START_DATE = "2024-01-01"
        self.END_DATE = "2024-01-02"
        self.FIRST_DATE_FETCH = "2024-01-01"

        self.NO_OF_DAYS_UPDATE = 15

        # Initiate the spark session for script
        self.spark = SparkSession \
            .builder \
            .appName("PySpark Session for sender job") \
            .config("spark.some.config.options", "some-value") \
            .getOrCreate()

    @staticmethod
    def _show(df):
        """Render ``df`` with ``display`` when that helper exists, else ``df.show()``."""
        try:
            display(df)
        except NameError:
            # `display` is not defined anywhere in this file; presumably the
            # notebook runtime provides it. Plain Python runs fall back to show().
            df.show()

    def get_unique_new_senders(self, date_from='2023-01-01', min_count=5, start_date='2023-01-01', end_date='2023-01-01'):
        """
        Return one row per sender for the first day that sender appears.

        Only senders whose row count on that first day is strictly greater
        than ``min_count`` are kept.

        NOTE(review): ``date_from``, ``start_date`` and ``end_date`` are
        accepted but the query below never references them, so the whole of
        ``DB_TABLE`` is scanned — confirm whether a date filter was intended.

        Args:
            date_from: Unused by the query.
            min_count: Exclusive lower bound on the first-day row count.
            start_date: Unused by the query.
            end_date: Unused by the query.

        Returns:
            DataFrame with ``fetch_day``, ``sender``, ``count``,
            ``first_fetch_date`` and ``new_sender``.
        """
        query = f"""
        WITH first_occurrences AS (
            SELECT
                sender,
                MIN(fetch_date) AS first_fetch_date
            FROM
                {self.DB_TABLE}
            GROUP BY
                sender
    ),
    filtered_data AS (
        SELECT
            fetch_date,
            sender,
            COUNT(*) as count 
        FROM
            {self.DB_TABLE}
        GROUP BY
            fetch_date,
            sender
    )
    SELECT
        DATE(fd.fetch_date) AS fetch_day,
        fd.sender,
        fd.count,
        fo.first_fetch_date,
        (CASE WHEN DATE(fd.fetch_date) == DATE(fo.first_fetch_date) THEN TRUE ELSE FALSE END) AS new_sender
    FROM
        filtered_data fd
    JOIN
        first_occurrences fo ON fd.sender = fo.sender
    WHERE
        fd.count > {min_count}
        AND DATE(fd.fetch_date) = DATE(fo.first_fetch_date)
        """

        return self.spark.sql(query)

    def get_todays_date(self):
        """Return today's date in Asia/Kolkata as ``YYYY-MM-DD``."""
        return self.get_n_days_back_date(days=0)

    def get_n_days_back_date(self, days=0):
        """Return the Asia/Kolkata date ``days`` days before now as ``YYYY-MM-DD``."""
        ist = pytz.timezone('Asia/Kolkata')
        return (datetime.now(ist) - timedelta(days=days)).strftime('%Y-%m-%d')

    def get_target_table_df(self, target_table_name=None):
        """Load ``target_table_name`` (default ``TARGET_DB_TABLE``) as a DataFrame."""
        if target_table_name is None:
            target_table_name = self.TARGET_DB_TABLE
        return self.spark.table(target_table_name)

    def get_max_id_count(self, target_table_df=None):
        """
        Return the largest ``sender_id`` in the target table.

        Returns 0 when the aggregate is NULL (no ids stored yet).
        """
        if target_table_df is None:
            target_table_df = self.get_target_table_df()

        max_existing_id = target_table_df.agg(F.max("sender_id")).collect()[0][0]
        return 0 if max_existing_id is None else max_existing_id

    def add_new_processed_columns(self, source_data_df):
        """Add ``date_created``, ``time_created`` and ``concept_required`` columns."""
        return (
            source_data_df
            .withColumn("date_created", current_date())
            .withColumn("time_created", current_timestamp())
            .withColumn("concept_required", id_sender_concept_reqired_udf(F.col("sender")))
        )

    def add_unique_sender_ids(self, updated_data_df, max_existing_id=99):
        """
        Add a ``sender_id`` column numbered upward from ``max_existing_id + 1``.

        The window has no partitioning, so Spark numbers all rows in a single
        partition.
        """
        window_spec = Window.orderBy(F.monotonically_increasing_id())
        return updated_data_df.withColumn(
            "sender_id", F.row_number().over(window_spec) + max_existing_id
        )

    def filtering_new_ids(self, updated_data_df=None, target_table_df=None):
        """
        Keep only rows whose ``sender`` is absent from the target table.

        Args:
            updated_data_df: Candidate rows; required despite the default.
            target_table_df: Existing senders; loaded from ``TARGET_DB_TABLE``
                when omitted.

        Raises:
            ValueError: If ``updated_data_df`` is None.
        """
        if updated_data_df is None:
            # Previously fell through and failed later with an opaque AttributeError
            raise ValueError("updated_data_df is required")
        if target_table_df is None:
            target_table_df = self.get_target_table_df()
        return updated_data_df.join(target_table_df.select("sender"), on="sender", how="left_anti")

    def fill_remaining_cols_with_empty_string(self, updated_data_df):
        """Set ``description`` and ``concept`` to empty strings."""
        return (
            updated_data_df
            .withColumn("description", F.lit(""))
            .withColumn("concept", F.lit(""))
        )

    def fill_concept_and_description_cols_with_values(self, updated_data_df, concept_df):
        """
        Attach ``description`` and ``concept`` from ``concept_df`` by sender.

        Left join on ``sender == SenderId``; senders without a match get empty
        strings. ``concept_df`` must expose ``SenderId``, ``description`` and
        ``concept`` columns.
        """
        updated_data_df = updated_data_df.alias("udf")
        concept_df = concept_df.alias("cdf")

        # Joining on `sender` and `SenderId` with a left join
        joined_df = updated_data_df.join(
            concept_df,
            updated_data_df["sender"] == concept_df["SenderId"],
            "left"
        )

        return joined_df.select(
            "udf.*",  # Keep all original columns from updated_data_df
            F.coalesce(F.col("cdf.description"), F.lit("")).alias("description"),
            F.coalesce(F.col("cdf.concept"), F.lit("")).alias("concept")
        )

    def get_type_casted_df(self, updated_data_df):
        """Cast ``sender_id`` to bigint and the two date columns to dates; print the dtypes."""
        updated_data_df = updated_data_df.withColumn("sender_id", updated_data_df["sender_id"].cast("bigint"))
        updated_data_df = updated_data_df.withColumn("date_created", to_date(col("date_created"), "yyyy-MM-dd"))
        updated_data_df = updated_data_df.withColumn("first_fetch_date", to_date(col("first_fetch_date"), "yyyy-MM-dd"))
        print(updated_data_df.dtypes)
        return updated_data_df

    def filter_and_append_df_to_target_table(self, updated_data_df, target_table_name=None):
        """
        Append the target-table columns of ``updated_data_df`` to the table.

        Returns:
            bool: Always True; failures surface as exceptions from Spark.
        """
        if target_table_name is None:
            target_table_name = self.TARGET_DB_TABLE
        updated_data_df = updated_data_df.select(
            "sender", "sender_id", "date_created", "first_fetch_date", "time_created", "description", "concept_required", "concept")
        updated_data_df.write.mode('append').saveAsTable(target_table_name)
        return True

    def fetch_concepts_for_senders(self, new_senders_df):
        """
        Look up concept and description for the senders in ``new_senders_df``.

        Best-effort: on any failure the error is logged and an empty frame is
        returned so the caller can continue with blank metadata.

        Args:
            new_senders_df: DataFrame with a ``sender`` column.

        Returns:
            DataFrame with ``SenderId``, ``Concept`` and ``description``; rows
            with a NULL concept are excluded.
        """
        print("***Fetching concepts***")
        SENDERID_CONCEPTS_TABLE = "ds_dev.dev_naman.senderid_themes"
        # The fallback must carry every column the downstream join selects,
        # including `description`.
        empty_schema = "SenderId STRING, Concept STRING, description STRING"
        try:
            # Get list of all sender IDs
            sender_ids = [
                row['sender'] for row in new_senders_df.collect()
                if row['sender'] is not None
            ]
            if not sender_ids:
                print("No new sender IDs to process.")
                return self.spark.createDataFrame([], schema=empty_schema)

            # Column.isin passes the values as literals, so quotes inside a
            # sender value cannot break or alter the query.
            concepts_df = (
                self.spark.table(SENDERID_CONCEPTS_TABLE)
                .where(F.col("SenderId").isin(sender_ids) & F.col("Concept").isNotNull())
                .select(
                    "SenderId",
                    "Concept",
                    F.col("`TRAI - Principal Entity Name`").alias("description"),
                )
            )

            # Report sender IDs without a matching (non-null) concept
            matched_sender_ids = {row['SenderId'] for row in concepts_df.collect()}
            missing_concepts = set(sender_ids) - matched_sender_ids
            if missing_concepts:
                print(f"Missing concepts (either not found or null) for sender IDs: {missing_concepts}")
                print(f"Missing concepts count: {len(missing_concepts)}")
            self._show(concepts_df)
            return concepts_df
        except Exception as e:
            logging.exception("Failed to fetch concepts for sender IDs")
            print(f"Failed to fetch concepts for sender IDs: {e}")
            return self.spark.createDataFrame([], schema=empty_schema)

    def process_data(self):
        """
        Run the full sender-registration job and append new senders to the target.

        Prints a row count after every stage; each count triggers a Spark job.
        """
        logging.info("Starting the job")
        todays_date = self.get_todays_date()
        if self.FLAG:
            start_date = self.FIRST_DATE_FETCH
            from_date = self.FIRST_DATE_FETCH
        else:
            start_date = "2024-01-01"
            from_date = self.get_n_days_back_date(days=self.NO_OF_DAYS_UPDATE)
        source_df = self.get_unique_new_senders(
            date_from=from_date, min_count=1, start_date=start_date, end_date=todays_date)

        target_table_df = self.get_target_table_df(target_table_name=self.TARGET_DB_TABLE)
        max_unique_id = self.get_max_id_count(target_table_df=target_table_df)
        self._show(source_df)
        print(f"Count after fetching unique new senders: {source_df.count()}")
        updated_data_df = self.add_new_processed_columns(source_df)
        print(f"Count after adding new processed columns: {updated_data_df.count()}")
        # NOTE(review): ids are assigned before already-known senders are
        # removed, so stored ids can have gaps — confirm this is acceptable.
        updated_data_df_1 = self.add_unique_sender_ids(updated_data_df, max_existing_id=max_unique_id)
        print(f"Count after adding unique sender IDs: {updated_data_df_1.count()}")
        updated_data_df_2 = self.filtering_new_ids(updated_data_df_1, target_table_df=target_table_df)
        print(f"Count after filtering new IDs: {updated_data_df_2.count()}")
        new_senders_df = updated_data_df_2.select("sender").distinct()
        senderid_concept_df = self.fetch_concepts_for_senders(new_senders_df)
        updated_data_df_3 = self.fill_concept_and_description_cols_with_values(updated_data_df_2, senderid_concept_df)
        print(f"Count after filling concept and description columns: {updated_data_df_3.count()}")
        self._show(updated_data_df_3)
        updated_data_df_4 = self.get_type_casted_df(updated_data_df_3)
        print(f"Count after type casting: {updated_data_df_4.count()}")
        res = self.filter_and_append_df_to_target_table(updated_data_df=updated_data_df_4, target_table_name=self.TARGET_DB_TABLE)

        if res:
            logging.info("Job completed successfully and updated target table")
        else:
            logging.error("Job failed to update target table")

if __name__ == "__main__":
    # Pipeline driver:
    #   1. copy the requested date range into the filtered table,
    #   2. register new senders,
    #   3. extract unique templates per sender and persist them.

    config = get_config()
    spark = SparkSession.builder.appName("DataMigrationApp").getOrCreate()

    if len(sys.argv) != 3:
        print("Usage: python data_migration.py <start_date> <end_date>")
        logging.error("Incorrect number of arguments provided.")
        sys.exit(1)

    start_date = sys.argv[1]
    end_date = sys.argv[2]

    filteration = DataFilteration(
        source_table=config["SOURCE_TABLE"],
        target_table=config["TARGET_TABLE"],
        spark_session=spark,
    )

    try:
        # Use the dates from the command line; they were previously parsed
        # and then ignored in favour of a hard-coded 2024-01-01.
        filteration.run(start_date, end_date)
    except Exception as e:
        logging.error(f"Unhandled exception occurred: {e}")
        sys.exit(1)

    processor = SenderDataProcessor()
    processor.process_data()

    # The trailing `except` in the original had no matching `try`, which made
    # the whole file a SyntaxError; the template stage is now wrapped properly.
    try:
        query = """
                SELECT * FROM ds_dev.sms_parsing.sms_filtered_data_previous_day
            """

        logging.info(f"Executing query: {query}")
        spark_df = spark.sql(query)
        df = spark_df.toPandas()
        print(df.head())
        df = df[['sender', 'sms_text']]

        # (sender, sms_text) pairs in table order
        messages = list(df.itertuples(index=False, name=None))
        logging.info(f"Loaded {len(messages)} messages from input  file.")

        # Initialize extractor
        extractor = UniqueTemplateExtractor(threshold=0.90)
        unique_templates = extractor.extract_unique_templates(messages)

        # Tabulate the templates; `embedding` holds the vector as a nested list
        columns = ['sender', 'template_id', 'template_message', 'embedding']
        unique_templates_df = pd.DataFrame(unique_templates, columns=columns)
        unique_templates_df['template_attribute'] = '"company:XXX","date:XXX"'
        unique_templates_without_embedding_df = unique_templates_df[
            ['sender', 'template_id', 'template_message', 'template_attribute']
        ]
        unique_templates_without_embedding_spark_df = spark.createDataFrame(
            unique_templates_without_embedding_df
        )

        unique_templates_without_embedding_spark_df.createOrReplaceTempView("temp_df_view")

        # NOTE(review): this insert into unique_sms_pattern_schema is built but
        # deliberately NOT executed. `SELECT *` over the three-way join returns
        # the columns of all three relations, so it needs an explicit column
        # list before it can be enabled.
        query = """
    INSERT INTO ds_dev.sms_parsing.unique_sms_pattern_schema 
        SELECT *
    FROM ds_dev.sms_parsing.unique_sender_schema st 
    INNER JOIN temp_df_view tt
    on st.sender = tt.sender  
    INNER JOIN ds_dev.sms_parsing.unique_sms_pattern_schema pt
    ON st.sender_id = pt.sender_id
    """

        # Persist templates together with their embeddings
        spark.createDataFrame(unique_templates_df).createOrReplaceTempView("temp_df_view1")
        spark.sql("""
                INSERT INTO ds_dev.sms_parsing.unique_template_embedding
                SELECT * FROM temp_df_view1
            """)

        unique_templates_df.to_excel('unique_template_data_with_embeddings_final.xlsx', index=False)

        # Display unique templates
        logging.info("Extracted Unique Templates:")
        for sender_id, template_id, template_text, embedding in unique_templates:
            logging.info(f"Sender: {sender_id}, Template ID: {template_id}, Text: {template_text}, Embedding: {embedding}")
    except Exception as e:
        logging.error(f"An error occurred during processing: {e}")
        # Exit non-zero, matching the handler for the migration step above
        sys.exit(1)
