<a href="https://colab.research.google.com/github/SM-Learning/advanced-rag-techniques/blob/main/NER_GLiNER_GPU_code.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import asyncio
import concurrent.futures
import logging
import os
import re
from datetime import datetime
from itertools import chain
from queue import Queue
from threading import Event, Thread
from typing import Any, Dict, Iterator, List

import GPUtil
import inflect
import numpy as np
import pandas as pd
import teradatasql
import torch


In [None]:
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)

# Initialize inflect engine (used to singularize cleaned entity text)
p = inflect.engine()

# Teradata connection parameters (keyword arguments for teradatasql.connect).
# Each value is read from a TERADATA_* environment variable so real credentials
# do not have to be committed with the notebook; the literals are placeholders
# used only when the variable is unset.
TERADATA_CONFIG = {
    'host': os.environ.get('TERADATA_HOST', 'hostname'),
    'user': os.environ.get('TERADATA_USER', 'user_id'),
    'password': os.environ.get('TERADATA_PASSWORD', 'password'),
    'logmech': os.environ.get('TERADATA_LOGMECH', 'LDAP'),
}

In [None]:

# Table names. Rows are read from SOURCE_TABLE and NER results are inserted into
# TARGET_TABLE. ERROR_LOG_TABLE is not referenced by the visible code
# (presumably meant for TeradataManager.log_error — confirm).
SOURCE_TABLE, TARGET_TABLE, ERROR_LOG_TABLE = "source_table", "xyz", "error_log"


In [None]:
class GPUMonitor:
    """Read GPU memory usage via GPUtil and derive an adjusted batch size from it."""

    @staticmethod
    def get_gpu_memory_usage(gpu_id: int) -> tuple:
        """Return ``(memoryUsed, memoryTotal)`` for entry *gpu_id* of ``GPUtil.getGPUs()``.

        Raises:
            IndexError: if GPUtil reports fewer than ``gpu_id + 1`` GPUs.
        """
        # NOTE(review): GPUtil's list order may not match CUDA device ordinals
        # unless CUDA_DEVICE_ORDER=PCI_BUS_ID is set — confirm on the target host.
        gpu = GPUtil.getGPUs()[gpu_id]
        return gpu.memoryUsed, gpu.memoryTotal

    @staticmethod
    def batch_size_for_utilization(
        used: float,
        total: float,
        current_batch_size: int,
        *,
        min_batch_size: int = 32,
        max_batch_size: int = 256,
        high_watermark: float = 0.85,
        low_watermark: float = 0.50,
    ) -> int:
        """Return the next batch size given memory ``used`` out of ``total``.

        Above ``high_watermark`` utilization the size is halved, floored at
        ``min_batch_size``, but never increased. Below ``low_watermark`` it is
        doubled, capped at ``max_batch_size``. Otherwise, or when ``total`` is
        not positive (no usable reading), ``current_batch_size`` is returned.
        """
        if total <= 0:
            return current_batch_size
        utilization = used / total
        if utilization > high_watermark:
            # min() keeps a batch that is already below the floor from growing
            # while memory is under pressure.
            return min(current_batch_size, max(min_batch_size, current_batch_size // 2))
        if utilization < low_watermark:
            return min(max_batch_size, current_batch_size * 2)
        return current_batch_size

    @staticmethod
    def calculate_optimal_batch_size(
        gpu_id: int,
        current_batch_size: int,
        *,
        min_batch_size: int = 32,
        max_batch_size: int = 256,
        high_watermark: float = 0.85,
        low_watermark: float = 0.50,
    ) -> int:
        """Adjust ``current_batch_size`` using GPU *gpu_id*'s current memory utilization.

        See ``batch_size_for_utilization`` for the sizing rule and keyword limits.
        """
        used, total = GPUMonitor.get_gpu_memory_usage(gpu_id)
        return GPUMonitor.batch_size_for_utilization(
            used,
            total,
            current_batch_size,
            min_batch_size=min_batch_size,
            max_batch_size=max_batch_size,
            high_watermark=high_watermark,
            low_watermark=low_watermark,
        )


In [None]:
class BatchMetrics:
    """Accumulate per-batch success/failure counts and log one summary per batch."""

    def __init__(self):
        self.successful_records = 0  # running total over all logged batches
        self.failed_records = 0  # running total over all logged batches
        self.current_batch = 0  # batch_id passed to the most recent log call
        self.total_batches = 0  # not updated by this class; set by callers if known

    def log_batch_metrics(self, batch_id: int, success_count: int, fail_count: int,
                          gpu_id: int = 0):
        """Record one batch's counts and log them with GPU *gpu_id*'s memory usage.

        Reading GPU memory is best-effort. If it fails, a warning is logged and
        "unavailable" is reported, so a monitoring hiccup cannot abort the batch loop.
        """
        self.successful_records += success_count
        self.failed_records += fail_count
        self.current_batch = batch_id

        try:
            gpu_memory = GPUMonitor.get_gpu_memory_usage(gpu_id)
        except Exception as e:
            logging.warning(f"Could not read GPU {gpu_id} memory usage: {e}")
            gpu_memory = "unavailable"

        logging.info(f"""
            Batch {batch_id} Metrics:
            - Successful records: {success_count}
            - Failed records: {fail_count}
            - GPU Memory Usage: {gpu_memory}
        """)



In [None]:
class DataLoader:
    """Read SOURCE_TABLE one page of ``batch_size`` rows at a time."""

    def __init__(self, batch_size: int = 50000):
        self.batch_size = batch_size
        # NOTE(review): queue and stop_event are not used by any visible code;
        # presumably reserved for a prefetching producer thread.
        self.queue = Queue(maxsize=2)
        self.stop_event = Event()

    def build_query(self, offset: int, table: str = None) -> str:
        """Return the SQL selecting rows ``offset+1 .. offset+batch_size``.

        Rows are numbered with ``ROW_NUMBER() OVER (ORDER BY zi_c_company_id)``
        and filtered in QUALIFY. Teradata does not allow ``TOP n`` together with
        QUALIFY, and the source has no ``row_number`` column to filter on.

        Args:
            offset: number of rows already consumed (0-based start of this page).
            table: table to read; defaults to ``SOURCE_TABLE``.
        """
        table = SOURCE_TABLE if table is None else table
        # int() keeps the interpolated bounds strictly numeric.
        start = int(offset) + 1
        end = int(offset) + int(self.batch_size)
        # NOTE(review): paging is only stable if zi_c_company_id is unique; with
        # ties, rows can shift between pages. Each page also re-ranks the whole
        # table; keyset paging (WHERE zi_c_company_id > last_id) would scale better.
        return f"""
            SELECT
                zi_c_company_id, zi_es_ecid, zi_c_description, zi_industry_primary,
                industries, sub_industries, top3_industries
            FROM {table}
            QUALIFY ROW_NUMBER() OVER (ORDER BY zi_c_company_id) BETWEEN {start} AND {end}
        """

    async def fetch_batch_from_db(self, offset: int) -> pd.DataFrame:
        """Fetch the page starting after *offset* rows as a DataFrame (empty when exhausted).

        Opens a new connection with ``TERADATA_CONFIG`` per call. Errors are
        logged and re-raised.
        """
        query = self.build_query(offset)
        try:
            # NOTE(review): this is a blocking call inside a coroutine; it stalls
            # the event loop for the duration of the query.
            with teradatasql.connect(**TERADATA_CONFIG) as conn:
                return pd.read_sql(query, conn)
        except Exception as e:
            logging.error(f"Fatal error in database fetch: {e}")
            raise


In [None]:
class NERProcessor:
    """Run the GLiNER models on texts, one model copy per GPU, and flatten the entities.

    ``process_text`` runs every configured model on a text and returns a dict
    mapping each field in ``FIELDS`` that was found to a ``" | "``-joined
    string of distinct values.
    """

    # Texts shorter than this (or missing) return {} without running inference.
    MIN_TEXT_LENGTH = 50

    # Output fields, in the order they are inserted into the result dict.
    FIELDS = (
        "year started", "product", "brand", "business_categories",
        "business_sub_categories", "industry", "business_services", "offer",
    )

    _YEAR_PATTERN = re.compile(r'\b(18|19|20)\d{2}\b')
    _NON_LETTER_PATTERN = re.compile(r'[^a-zA-Z\s]')
    # Whole-word label names stripped from entity text. Word boundaries keep
    # words such as "firebrand" or "byproducts" intact.
    _LABEL_WORD_PATTERNS = {
        "brand": re.compile(r'\bbrands?\b'),
        "industry": re.compile(r'\bindustry\b'),
        "product": re.compile(r'\bproducts?\b'),
    }
    # Labels lower-cased before stripping (brand keeps its original casing).
    _LOWERCASE_LABELS = frozenset({"industry", "product"})

    def __init__(self, gpu_ids: List[int], labels: List[str] = None):
        """Load every model onto every GPU in *gpu_ids*.

        Args:
            gpu_ids: CUDA device ordinals to load models on.
            labels: entity labels passed to GLiNER. When None, the module-level
                ``labels`` is used, as before.
        """
        self.gpu_ids = gpu_ids
        self.labels = labels
        self.models = {}
        self.batch_size = 128  # Initial batch size (adjusted externally via GPUMonitor)
        self.setup_models()

    def setup_models(self):
        """Populate ``self.models[gpu_id]`` with ``(model, threshold)`` pairs in eval mode."""
        model_configs = [
            ("knowledgator/gliner-multitask-large-v0.5", 0.54),
            ("EmergentMethods/gliner_large_news-v2.1", 0.7)
        ]

        for gpu_id in self.gpu_ids:
            self.models[gpu_id] = []
            with torch.cuda.device(gpu_id):
                for model_name, threshold in model_configs:
                    # NOTE(review): GLiNER is not imported in the visible cells;
                    # presumably `from gliner import GLiNER` runs elsewhere.
                    model = GLiNER.from_pretrained(model_name).to(f"cuda:{gpu_id}")
                    model.eval()
                    self.models[gpu_id].append((model, threshold))

    def clean_entity(self, entity: Dict) -> Dict:
        """Normalize ``entity['text']`` in place and return the same dict.

        - "year started": keep only the first 18xx/19xx/20xx year, or '' if none.
        - "brand": unchanged.
        - Any other label: drop non-letters, then singularize via inflect.

        Failures are logged and the entity is returned as-is (best-effort).
        """
        try:
            if entity['label'] == 'year started':
                match = self._YEAR_PATTERN.search(entity['text'])
                entity['text'] = match.group(0) if match else ''
            elif entity['label'] != 'brand':
                entity['text'] = self._NON_LETTER_PATTERN.sub('', entity['text'])
                entity['text'] = p.singular_noun(entity['text']) or entity['text']
        except Exception as e:
            logging.error(f"Error in clean_entity: {e}")
        return entity

    async def process_text(self, text: str, gpu_id: int) -> Dict:
        """Extract NER features from *text*, preferring GPU *gpu_id*.

        On failure each other GPU in ``self.gpu_ids`` is tried once.

        Returns:
            ``{}`` for missing/short text, else the ``extract_ner_features`` dict.

        Raises:
            The last exception seen, if every GPU fails.
        """
        if not text or len(text) < self.MIN_TEXT_LENGTH:
            return {}

        candidates = [gpu_id] + [g for g in self.gpu_ids if g != gpu_id]
        last_error = None
        for candidate in candidates:
            try:
                return self._predict_on_gpu(text, candidate)
            except Exception as e:
                last_error = e
                logging.error(f"Error processing text on GPU {candidate}: {e}")
        raise last_error

    def _predict_on_gpu(self, text: str, gpu_id: int) -> Dict:
        """Run all models loaded on *gpu_id* over *text* and build the feature dict."""
        entity_labels = getattr(self, "labels", None)
        if entity_labels is None:
            # NOTE(review): `labels` is not defined in the visible cells; presumably
            # a module-level list of GLiNER entity labels defined elsewhere.
            entity_labels = labels

        entities = []
        # no_grad is entered here, in the body that actually runs inference.
        # Decorating an `async def` with @torch.no_grad() only wraps creation of
        # the coroutine object, so the body would run with autograd enabled.
        with torch.no_grad(), torch.cuda.device(gpu_id):
            for model, threshold in self.models[gpu_id]:
                entities.extend(
                    model.predict_entities(text, entity_labels, threshold=threshold)
                )

        cleaned = sorted(
            (self.clean_entity(entity) for entity in entities),
            key=lambda k: (k['label'], -k['score'])
        )
        return self.extract_ner_features(cleaned)

    def extract_ner_features(self, entities: List[Dict]) -> Dict:
        """Group entity texts by label and join each field's distinct values with " | ".

        Values keep their first-seen order, so the output is deterministic.
        Values that are empty after normalization are dropped. Labels not in
        ``FIELDS`` are ignored.
        """
        # A dict is used as an insertion-ordered set (a plain set's order varies per run).
        values_by_label: Dict[str, Dict[str, None]] = {}

        for entity in entities:
            label = entity["label"]
            text = entity["text"]

            if label in self._LOWERCASE_LABELS:
                text = text.lower()
            pattern = self._LABEL_WORD_PATTERNS.get(label)
            if pattern is not None:
                text = pattern.sub('', text)
            text = " ".join(text.split())  # collapse gaps left by stripping

            if not text:
                continue
            values_by_label.setdefault(label, {})[text] = None

        return {
            field: " | ".join(values_by_label[field])
            for field in self.FIELDS
            if field in values_by_label
        }

class TeradataManager:
    """Insert processed DataFrames into TARGET_TABLE via teradatasql."""

    def __init__(self, config: Dict, batch_size: int = 5000):
        self.config = config  # keyword arguments for teradatasql.connect
        self.batch_size = batch_size  # rows per executemany call

    @staticmethod
    def dataframe_to_rows(df: pd.DataFrame) -> List[tuple]:
        """Convert *df* to plain Python tuples, mapping NaN/None to None.

        ``astype(object)`` turns numpy scalars into Python ints/floats. The
        ``where`` call replaces pandas' NaN markers with None, so the driver
        binds SQL NULL instead of a float NaN.
        """
        cleaned = df.astype(object).where(df.notna(), None)
        return list(cleaned.itertuples(index=False, name=None))

    async def upload_batch(self, df: pd.DataFrame) -> bool:
        """Insert all rows of *df* in chunks of ``batch_size``.

        *df*'s columns must already be in ``get_insert_sql`` column order.
        Chunks inserted before a failure are not rolled back by this method.

        Returns:
            True on success, or False after logging the error (never raises
            for DB errors).
        """
        try:
            with teradatasql.connect(**self.config) as conn:
                with conn.cursor() as cursor:
                    data = self.dataframe_to_rows(df)
                    insert_sql = self.get_insert_sql()
                    for i in range(0, len(data), self.batch_size):
                        cursor.executemany(insert_sql, data[i:i + self.batch_size])
            return True
        except Exception as e:
            logging.error(f"Upload error: {e}")
            # Guard the lookup: an empty frame or missing column must not raise
            # a second error from inside the handler.
            if not df.empty and 'zi_c_company_id' in df.columns:
                company_id = df['zi_c_company_id'].iloc[0]
            else:
                company_id = None
            await self.log_error(company_id, str(e))
            return False

    def get_insert_sql(self) -> str:
        """Return the parameterized INSERT for the 15 TARGET_TABLE columns."""
        return f"""
        INSERT INTO {TARGET_TABLE} (
            zi_c_company_id, zi_es_ecid, zi_c_description, zi_industry_primary,
            industries, sub_industries, top3_industries, year_started, product,
            brand, industry, business_categories, business_sub_categories,
            business_services, offer
        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """

    async def log_error(self, company_id: str, error_msg: str):
        """Record an upload failure; currently a no-op (error-table insert is disabled)."""
        pass

# Source columns copied through to TARGET_TABLE unchanged, in INSERT order.
PASSTHROUGH_COLUMNS = (
    "zi_c_company_id", "zi_es_ecid", "zi_c_description", "zi_industry_primary",
    "industries", "sub_industries", "top3_industries",
)
# (TARGET_TABLE column, NERProcessor result key) pairs, in INSERT order.
NER_COLUMN_MAP = (
    ("year_started", "year started"),
    ("product", "product"),
    ("brand", "brand"),
    ("industry", "industry"),
    ("business_categories", "business_categories"),
    ("business_sub_categories", "business_sub_categories"),
    ("business_services", "business_services"),
    ("offer", "offer"),
)
# Full output column order; must match TeradataManager.get_insert_sql.
OUTPUT_COLUMNS = PASSTHROUGH_COLUMNS + tuple(column for column, _ in NER_COLUMN_MAP)


def build_output_record(source_row: Dict[str, Any], ner_result: Dict[str, str]) -> Dict[str, Any]:
    """Merge one source row and its NER result into a dict keyed by OUTPUT_COLUMNS.

    Fields the NER result lacks become None. Extra source keys are dropped.
    """
    record = {column: source_row.get(column) for column in PASSTHROUGH_COLUMNS}
    record.update({column: ner_result.get(key) for column, key in NER_COLUMN_MAP})
    return record


async def main():
    """Page through SOURCE_TABLE, run NER on each description, and upload the results.

    Stops when a fetch returns no rows or after at least 5,000,000 rows were
    fetched. Rows whose description is missing or shorter than 50 characters
    are not uploaded and are counted as failed.
    """
    metrics = BatchMetrics()
    gpu_processor = NERProcessor(gpu_ids=[0, 1])
    data_loader = DataLoader()
    db_manager = TeradataManager(TERADATA_CONFIG)

    offset = 0
    total_processed = 0
    rows_dispatched = 0  # per-row counter driving round-robin GPU assignment

    try:
        while total_processed < 5_000_000:  # 5 million records
            batch_df = await data_loader.fetch_batch_from_db(offset)
            if batch_df.empty:
                break

            # NOTE(review): batch_size is recomputed here but process_text does not read it.
            for device_id in gpu_processor.gpu_ids:
                gpu_processor.batch_size = GPUMonitor.calculate_optimal_batch_size(
                    device_id, gpu_processor.batch_size
                )

            processed_records = []
            for _, row in batch_df.iterrows():
                description = row['zi_c_description']
                # NULL descriptions arrive as None/NaN, so check the type before len().
                if not isinstance(description, str) or len(description) < 50:
                    continue
                gpu_id = gpu_processor.gpu_ids[rows_dispatched % len(gpu_processor.gpu_ids)]
                rows_dispatched += 1
                result = await gpu_processor.process_text(description, gpu_id=gpu_id)
                processed_records.append(build_output_record(row.to_dict(), result))

            uploaded = True
            if processed_records:
                # Explicit columns keep the DataFrame in INSERT order, regardless of
                # which NER fields each row happened to produce.
                output_df = pd.DataFrame(processed_records, columns=list(OUTPUT_COLUMNS))
                uploaded = await db_manager.upload_batch(output_df)
                if not uploaded:
                    logging.warning(
                        f"Upload failed for batch {offset // data_loader.batch_size}"
                    )

            success_count = len(processed_records) if uploaded else 0
            metrics.log_batch_metrics(
                offset // data_loader.batch_size,
                success_count,
                len(batch_df) - success_count
            )

            offset += data_loader.batch_size
            total_processed += len(batch_df)

    except Exception as e:
        logging.error(f"Fatal error: {e}")
        raise
    finally:
        logging.info(f"Total records processed: {total_processed}")


In [None]:

if __name__ == "__main__":
    # asyncio.run() refuses to start while an event loop is already running in
    # this thread, which is the case inside Jupyter/Colab kernels. In that case
    # run the pipeline on a worker thread with its own loop and wait for it.
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        asyncio.run(main())
    else:
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
            pool.submit(lambda: asyncio.run(main())).result()

#  Update 3_8

In [None]:
def upload_partition_to_teradata(partition_iter: Iterator[tuple]) -> Iterator[dict]:
    """Insert one Spark partition's rows into Teradata and report per-partition counts.

    Intended for ``rdd.mapPartitions``. Each row must convert with ``tuple(row)``
    into the 15 INSERT values, in column order. Rows are inserted in chunks of
    ``batch_size`` over a single connection. A failed chunk is logged, counted
    in ``rows_failed`` and skipped.

    Returns an iterator with one dict ``{'partition_id', 'rows_processed',
    'rows_failed'}``, or an empty iterator for an empty partition.

    Relies on module-level ``hostname``, ``user_id``, ``password``, ``logmech``,
    ``table_name`` and ``batch_size`` (not defined in the visible cells).
    """
    from pyspark.taskcontext import TaskContext
    import teradatasql  # imported here so it resolves on Spark executors

    # TaskContext.get() is None outside a running task (e.g. on the driver).
    context = TaskContext.get()
    partition_id = context.partitionId() if context is not None else -1
    partition_rows_processed = 0
    partition_rows_failed = 0

    # Convert iterator to list for length calculation and multiple passes
    rows = list(partition_iter)

    if not rows:
        logging.warning(f"Partition {partition_id}: Empty partition")
        return iter([])

    logging.info(f"Partition {partition_id}: Starting upload of {len(rows)} rows")

    try:
        # Establish connection to Teradata
        with teradatasql.connect(
            host=hostname,
            user=user_id,
            password=password,
            logmech=logmech
        ) as con:
            with con.cursor() as cursor:
                # Define the SQL insert statement
                insert_sql = f"""
                INSERT INTO {table_name} (
                    zi_c_company_id, zi_es_ecid, zi_c_description, zi_industry_primary,
                    industries, sub_industries, top3_industries, year_started, product,
                    brand, industry, business_categories, business_sub_categories,
                    business_services, offer
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """

                # Process rows in batches
                total_batches = (len(rows) - 1) // batch_size + 1

                for batch_num, i in enumerate(range(0, len(rows), batch_size), 1):
                    batch = rows[i:i + batch_size]
                    try:
                        # Convert batch elements to tuples before executing
                        batch_tuples = [tuple(row) for row in batch]
                        cursor.executemany(insert_sql, batch_tuples)
                        partition_rows_processed += len(batch)
                        logging.info(f"Partition {partition_id}: Completed batch {batch_num}/{total_batches} "
                                     f"({len(batch)} rows)")
                    except Exception as e:
                        partition_rows_failed += len(batch)
                        logging.error(f"Partition {partition_id}: Batch {batch_num} failed: {str(e)}")
                        continue

    except Exception as e:
        logging.error(f"Partition {partition_id}: Connection error: {str(e)}")
        # Every row not confirmed inserted counts as failed. Recomputing it this
        # way avoids double-counting chunks already tallied as processed or failed.
        partition_rows_failed = len(rows) - partition_rows_processed

    # Return statistics for this partition
    return iter([{
        'partition_id': partition_id,
        'rows_processed': partition_rows_processed,
        'rows_failed': partition_rows_failed
    }])

In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.types import StringType, StructField, StructType

# Target-table columns in INSERT order; each is a nullable string in the sample schema.
SAMPLE_COLUMNS = [
    "zi_c_company_id", "zi_es_ecid", "zi_c_description", "zi_industry_primary",
    "industries", "sub_industries", "top3_industries", "year_started",
    "product", "brand", "industry", "business_categories",
    "business_sub_categories", "business_services", "offer",
]

# Create a single sample record (values line up positionally with SAMPLE_COLUMNS)
sample_data = [(
    '12345', 'ECID456', 'Sample Corp', 'Technology',
    'Tech,Software', 'Cloud,Data', 'Tech,AI,Cloud', '2020',
    'Software', 'SampleBrand', 'Technology', 'Enterprise',
    'B2B', 'Consulting', 'Cloud Services'
)]

# Define schema explicitly
schema = StructType([StructField(name, StringType(), True) for name in SAMPLE_COLUMNS])

# getOrCreate() returns the already-active session (e.g. a notebook's `spark`) if there is one.
spark_session = SparkSession.builder.getOrCreate()
sample_df = spark_session.createDataFrame(sample_data, schema=schema)

# Convert to an RDD of plain tuples, the row shape upload_partition_to_teradata expects
sample_rdd = sample_df.rdd.map(tuple)

# Test the upload function
results = sample_rdd.mapPartitions(upload_partition_to_teradata).collect()

# Print results
print("Results:", results)