<a href="https://colab.research.google.com/github/SM-Learning/advanced-rag-techniques/blob/main/NER_Spark.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install -q transformers torch inflect teradatasql pyspark GLiNER

In [2]:
!pip install --upgrade GLiNER




In [2]:
from gliner import GLiNER

In [3]:
import datetime
import logging
import re
from itertools import chain
import torch
from pyspark.sql import SparkSession, functions as F
from pyspark.sql.types import *
from inflect import engine
import pandas as pd
from gliner import GLiNER
#from pyspark.sql.window import Window

# Shared inflect engine; clean_entity uses it to singularise entity text.
p = engine()

# Configure root logging with a timestamp on every record.
# force=True (Python 3.8+) removes any handlers already attached to the root
# logger first: hosted kernels such as Colab install their own, which would
# otherwise turn basicConfig into a silent no-op and drop this format.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    datefmt='%Y-%m-%d %H:%M:%S',
    force=True,
)
logger = logging.getLogger(__name__)


In [4]:
# Spark session for the NER pipeline. Driver and executor heaps are raised to
# 8g; getOrCreate() reuses an already-running session if one exists.
spark = (
    SparkSession.builder
    .appName("NER_Processing")
    .config("spark.driver.memory", "8g")
    .config("spark.executor.memory", "8g")
    .getOrCreate()
)

In [5]:
# Two hand-written company records used as pipeline input. All records share
# the same keys, so each row tuple is zipped against one column tuple.
sample_data = [
    dict(zip(
        (
            "company_id",
            "es_ecid",
            "description",
            "industry_primary",
            "industries",
            "sub_industries",
            "top3_industries",
        ),
        row,
    ))
    for row in (
        (
            "1",
            "EC001",
            "Google is an American multinational technology company focusing on AI, cloud computing, and online advertising.",
            "Technology",
            "Software, Internet",
            "Search Engine, Cloud Computing",
            "Technology, Advertising, AI",
        ),
        (
            "2",
            "EC002",
            "Microsoft Corporation develops software and provides cloud services.",
            "Technology",
            "Software",
            "Operating Systems, Cloud Computing",
            "Software, Cloud, Enterprise",
        ),
    )
]

In [6]:
# Create DataFrame with all required columns
# NOTE(review): this notebook-level spark_df is not read by process_data(),
# which builds its own DataFrame from sample_data; presumably kept only for
# interactive inspection — confirm before relying on it.
spark_df = spark.createDataFrame(sample_data)

In [7]:
# Batch size and target table name.
# NOTE(review): neither is referenced by any cell shown in this notebook.
batch_size = 50
table_name = "xyz"


def _load_gliner(name):
    """Load the pretrained GLiNER model *name*; on failure, log which one and re-raise."""
    try:
        return GLiNER.from_pretrained(name)
    except Exception:
        # logger.exception records the traceback, and the message names the
        # model, so a failure is attributable to one of the two downloads.
        logger.exception(f"Error loading model {name!r}")
        raise


# GLiNER model initialization: both models' predictions are merged by
# extract_ner_features.
model_name = "knowledgator/gliner-multitask-large-v0.5"
model = _load_gliner(model_name)
model_name_2 = "EmergentMethods/gliner_large_news-v2.1"
model_2 = _load_gliner(model_name_2)

Fetching 10 files:   0%|          | 0/10 [00:00<?, ?it/s]

Fetching 8 files:   0%|          | 0/8 [00:00<?, ?it/s]



In [8]:
# Entity types requested from GLiNER. Kept as a dict whose values are all
# None; callers only iterate its keys.
labels = dict.fromkeys([
    "year started",
    "product",
    "brand",
    "business categories",
    "business sub categories",
    "industry",
    "business services",
    "offer",
])

# Struct returned by the NER UDF: one nullable string column per entity type,
# named like the label with spaces replaced by underscores.
ner_schema = StructType([
    StructField(field_name, StringType(), True)
    for field_name in (
        "product",
        "brand",
        "business_categories",
        "business_sub_categories",
        "year_started",
        "industry",
        "business_services",
        "offer",
    )
])

In [9]:
def clean_entity(entity):
    """Normalise ``entity['text']`` in place according to ``entity['label']``.

    * ``"year started"`` -> the first 4-digit year 1800-2099 in the text, or ''.
    * ``"brand"`` -> unchanged (brand names may legitimately contain digits
      or punctuation).
    * any other label -> letters and whitespace only (accented/Unicode letters
      are kept), whitespace runs collapsed to single spaces, then singularised
      with the module-level inflect engine ``p``.

    Args:
        entity: a GLiNER entity dict; only its ``label`` and ``text`` keys are used.

    Returns:
        The same dict object. Errors are logged, not raised; the text is left
        as far as cleaning got.
    """
    try:
        label = entity['label']
        if label == 'year started':
            match = re.search(r'\b(18|19|20)\d{2}\b', entity['text'])
            entity['text'] = match.group(0) if match else ''
        elif label != 'brand':
            # [^\w\s] removes punctuation/symbols; [\d_] removes the digits and
            # underscore that \w would otherwise keep. Unlike [^a-zA-Z\s] this
            # does not cut accented letters out of words ("Nestlé" stays whole).
            letters = re.sub(r'[^\w\s]|[\d_]', '', entity['text'])
            entity['text'] = ' '.join(letters.split())
            # Only singularise when something is left (e.g. not a digits-only span).
            # singular_noun returns a falsy value for non-plural input, hence the fallback.
            if entity['text']:
                entity['text'] = p.singular_noun(entity['text']) or entity['text']
    except Exception as e:
        logger.error(f"Error cleaning entity: {e}")
    return entity

In [12]:
def _normalize_entity_text(label, value):
    """Strip label-specific filler words ("brand", "industry", "product") from an entity span.

    ``industry`` and ``product`` values are also lower-cased; other labels pass
    through unchanged.
    """
    if label == 'brand':
        return value.replace('brands', '').replace('brand', '').strip()
    if label == 'industry':
        return value.lower().replace('industry', '').replace('industries', '').strip()
    if label == 'product':
        return value.lower().replace('products', '').replace('product', '').strip()
    return value


def extract_ner_features(text, threshold=0.54, threshold_2=0.7):
    """Run both GLiNER models on *text* and merge their entities per label.

    Args:
        text: the company description.
        threshold: confidence cut-off passed to ``model``.
        threshold_2: confidence cut-off passed to ``model_2``.

    Returns:
        A dict keyed by schema field name (label with spaces -> underscores).
        ``year_started`` holds the single highest-scoring year. Every other
        field holds its distinct cleaned values joined with " | ", best score
        first. Labels with no usable value are omitted. Texts shorter than
        50 characters yield ``{}``. An inference error is logged and the
        (possibly empty) result collected so far is returned.
    """
    if not text or len(str(text)) < 50:
        logger.warning("Skipping text with length < 50 characters")
        return {}

    information = {}
    # label -> {value: None}. A dict keeps insertion order (best score first),
    # so the joined output lists stronger matches first and duplicates are
    # dropped; a set would emit values in arbitrary hash order.
    values_by_label = {}

    try:
        with torch.no_grad():
            raw_entities = chain(
                model.predict_entities(text, labels, threshold=threshold),
                model_2.predict_entities(text, labels, threshold=threshold_2),
            )
            # Group by label; within one label, highest score first.
            ranked = sorted(
                (clean_entity(entity) for entity in raw_entities),
                key=lambda entity: (entity['label'], -entity['score']),
            )

        for entity in ranked:
            label = entity['label']
            value = _normalize_entity_text(label, entity['text'])
            # Cleaning can leave nothing (e.g. no year found); don't emit blanks.
            if value.strip():
                values_by_label.setdefault(label, {}).setdefault(value, None)

        for field in labels:
            found = values_by_label.get(field)
            if not found:
                continue
            key = field.replace(" ", "_")
            if field == 'year started':
                # One founding year per company: keep the best-scoring one.
                information[key] = next(iter(found))
            else:
                information[key] = " | ".join(found)
    except Exception as e:
        logger.exception(f"Error in NER processing: {e}")

    return information

In [13]:
def process_data(data=None):
    """Run NER over company descriptions and return a flat, null-free DataFrame.

    Args:
        data: records accepted by ``spark.createDataFrame`` (e.g. a list of
            dicts with the same keys as ``sample_data``). Defaults to
            ``sample_data``, so the original zero-argument call still works.

    Returns:
        A Spark DataFrame with exactly ``required_columns`` in that order,
        keeping only rows whose ``description`` has at least 50 characters.
        Nulls and the literal strings "nan"/"None" are replaced with "".
    """
    records = sample_data if data is None else data
    companies_df = spark.createDataFrame(records)

    # Same 50-character cut-off extract_ner_features applies; filtering first
    # keeps short rows out of the (expensive) UDF entirely.
    companies_df = companies_df.filter(F.length("description") >= 50)

    # NOTE(review): extract_ner_features references the notebook-level GLiNER
    # models, which are presumably serialised with the UDF to each Python
    # worker — confirm memory headroom before running on larger inputs.
    extract_ner_features_udf = F.udf(extract_ner_features, ner_schema)

    # Expand the returned struct into one top-level column per entity field.
    processed_df = (
        companies_df
        .withColumn("ner_features", extract_ner_features_udf("description"))
        .select("*", "ner_features.*")
        .drop("ner_features")
    )

    required_columns = [
        'company_id', 'es_ecid', 'description', 'industry_primary', 'industries',
        'sub_industries', 'top3_industries', 'year_started', 'product', 'brand',
        'industry', 'business_categories', 'business_sub_categories', 'business_services', 'offer'
    ]

    # Add any output column the input and NER struct did not provide.
    for column in required_columns:
        if column not in processed_df.columns:
            processed_df = processed_df.withColumn(column, F.lit(""))

    # Cleanse the NER responses before loading/printing.
    return (
        processed_df
        .select(required_columns)
        .na.fill("")
        .replace({"nan": "", "None": ""})
    )

In [None]:
# Main execution
if __name__ == "__main__":
    logger.info(f"Starting NER processing at: {datetime.datetime.now()}")

    try:
        result_df = process_data()

        # toPandas() collects every row to the driver; fine for the sample data.
        pandas_df = result_df.toPandas()

        logger.info("\nFinal DataFrame:")
        print(pandas_df.to_string())

    except Exception as e:
        # Keep the traceback and stop here: carrying on would leave result_df
        # undefined and surface later as an unrelated NameError.
        logger.exception(f"Error in main execution: {e}")
        raise
    finally:
        logger.info(f"Processing completed at: {datetime.datetime.now()}")

In [21]:
# Debug output: schema of the DataFrame bound by the main-execution cell.
# (final_df is local to process_data() and never exists at notebook level,
# so referencing it fails on a fresh kernel; result_df is the notebook name.)
print("\nProcessed DataFrame Schema:")
result_df.printSchema()


Processed DataFrame Schema:
root
 |-- company_id: string (nullable = false)
 |-- es_ecid: string (nullable = false)
 |-- description: string (nullable = false)
 |-- industry_primary: string (nullable = false)
 |-- industries: string (nullable = false)
 |-- sub_industries: string (nullable = false)
 |-- top3_industries: string (nullable = false)
 |-- year_started: string (nullable = false)
 |-- product: string (nullable = false)
 |-- brand: string (nullable = false)
 |-- industry: string (nullable = false)
 |-- business_categories: string (nullable = false)
 |-- business_sub_categories: string (nullable = false)
 |-- business_services: string (nullable = false)
 |-- offer: string (nullable = false)



In [22]:
# Show every row of the main-execution result untruncated (result_df, not
# final_df, which only exists inside process_data()).
print("\nProcessed DataFrame Content:")
result_df.show(truncate=False)


Processed DataFrame Content:
+----------+-------+---------------------------------------------------------------------------------------------------------------+----------------+------------------+----------------------------------+---------------------------+------------+-------------------------------+---------+--------+-------------------+-----------------------+-----------------+-----+
|company_id|es_ecid|description                                                                                                    |industry_primary|industries        |sub_industries                    |top3_industries            |year_started|product                        |brand    |industry|business_categories|business_sub_categories|business_services|offer|
+----------+-------+---------------------------------------------------------------------------------------------------------------+----------------+------------------+----------------------------------+---------------------------+-----------

In [12]:
# Wall-clock marker for the end of the notebook run.
print(f"End time:  {datetime.datetime.now()}")

End time:  2025-02-17 05:11:55.737103
