In [2]:
import json
import math
import time
from collections import Counter
from concurrent.futures import ThreadPoolExecutor, as_completed
from hashlib import md5

from openai import OpenAI
from tqdm import tqdm

# Prompt template for classifying a single entity. It is filled in with
# str.format(): {LABELS} receives the list of allowed labels and
# {ENTITY_DATA} the entity to classify. Because of str.format(), any other
# literal brace added to this text must be doubled ({{ or }}), otherwise it
# would be interpreted as a placeholder.
# The model is instructed to reply with exactly "LABEL_NAME|CONFIDENCE_SCORE".
SYSTEM_PROMPT = """
You are an expert entity classification system. Your task is to classify entities into predefined categories.

## INSTRUCTIONS:
1. You will receive a list of predefined labels at the beginning
2. For each entity provided, you must select exactly ONE label from the predefined list
3. Classifications must be from the exact predefined labels list
4. If an entity doesn't clearly fit any label, choose the most appropriate one

## PREDEFINED LABELS:
{LABELS}

## RESPONSE FORMAT:
For each entity, provide exactly ONE label from the predefined list followed by a confidence score in this exact format:
"LABEL_NAME|CONFIDENCE_SCORE"

Where CONFIDENCE_SCORE is a number between 0 and 1 (e.g., "Cemetery|0.95")

## CRITERIA FOR SELECTION:
- Choose only from the predefined labels above
- Select the most relevant label based on entity description
- If multiple labels seem relevant, choose the most specific/precise one
- If unsure, make your best educated guess from the available options
- The label must be an exact match to one of the predefined labels
- Confidence score should reflect how certain you are in your choice (0.0 = no knowledge at all, 0.5 = uncertain, 1.0 = very certain)

## IMPORTANT:
- ONLY return ONE label that exists in the predefined list
- Do not create new labels or variations
- Do not return anything other than exactly one valid label followed by a pipe and confidence score
- The label must be a complete string matching exactly one of the predefined labels
- The confidence score must be a decimal between 0 and 1
- Return format: "LABEL_NAME|CONFIDENCE_SCORE" - nothing else

## ENTITY TO CLASSIFY:
{ENTITY_DATA}
"""

def load_labels_from_jsonl(file_path):
    """Load label names from a JSONL file.

    Every non-blank line must hold one JSON object. The label is taken from
    its ``label`` key, falling back to ``name``, and finally to an empty
    string when neither key is present.

    Args:
        file_path: Path of the JSONL file to read.

    Returns:
        list: The labels, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    labels = []
    # Read as UTF-8 explicitly rather than relying on the locale default.
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank or trailing empty lines instead of letting
                # json.loads("") raise.
                continue
            data = json.loads(line)
            labels.append(data.get('label', data.get('name', '')))
    return labels

def create_system_prompt(labels_list, entity_data):
    """Fill the SYSTEM_PROMPT template for one entity.

    The labels are rendered one per line, each prefixed with "- ", and
    substituted for the LABELS placeholder; ``entity_data`` is substituted
    for the ENTITY_DATA placeholder.
    """
    bullet_lines = []
    for label in labels_list:
        bullet_lines.append(f"- {label}")
    rendered_labels = "\n".join(bullet_lines)
    return SYSTEM_PROMPT.format(LABELS=rendered_labels, ENTITY_DATA=entity_data)

def get_entity_hash(entity_data):
    """Return the MD5 hex digest used as an entity's deduplication key.

    Only the ``entity`` value contributes to the key (a missing key counts
    as an empty string). It is lower-cased and stripped of surrounding
    whitespace first, so names differing only in case or padding collide.
    """
    normalized_name = entity_data.get('entity', '')
    normalized_name = normalized_name.lower().strip()
    digest = md5(normalized_name.encode())
    return digest.hexdigest()

def deduplicate_entities(entities):
    """Keep the first occurrence of each entity, keyed by get_entity_hash().

    Prints how many entities were dropped and how many were kept, then
    returns the kept entities in their original order.
    """
    kept = []
    known_keys = set()
    dropped_count = 0

    for candidate in entities:
        key = get_entity_hash(candidate)
        if key in known_keys:
            # Already seen an entity with the same normalized name.
            dropped_count += 1
            continue
        known_keys.add(key)
        kept.append(candidate)

    print(f"Removed {dropped_count} duplicate entities")
    print(f"Kept {len(kept)} unique entities")

    return kept

def process_entity_with_pass5(client, entity_data, labels_list, num_passes=5,
                              model="google/gemma-3-12b-it", temperature=0.7,
                              max_tokens=150):
    """Ask the model to classify one entity several times.

    Only the ``entity`` and ``description`` fields of ``entity_data`` are sent
    to the model (missing fields are sent as empty strings). The same request
    is issued ``num_passes`` times; with a non-zero temperature the answers
    may differ between passes.

    Args:
        client: Object exposing ``chat.completions.create(...)``.
        entity_data: Mapping describing the entity.
        labels_list: Allowed labels, embedded in the system prompt.
        num_passes: Number of independent requests to make (default 5).
        model: Model name passed to the API.
        temperature: Sampling temperature passed to the API.
        max_tokens: Completion token limit passed to the API.

    Returns:
        list: ``num_passes`` items in request order; each is the stripped
        response text, or ``None`` if that request failed.
    """
    simplified_data = {
        "entity": entity_data.get("entity", ""),
        "description": entity_data.get("description", "")
    }

    # The request is identical for every pass, so build it once.
    entity_json = json.dumps(simplified_data)
    messages = [
        {"role": "system", "content": create_system_prompt(labels_list, entity_json)},
        {"role": "user", "content": f"Classify this entity: {entity_json}"},
    ]

    responses = []
    for attempt in range(1, num_passes + 1):
        try:
            completion = client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            responses.append(completion.choices[0].message.content.strip())
        except Exception as e:
            # Deliberately best-effort: one failed pass (API error, empty
            # content, ...) must not abort the remaining passes. The failure
            # is recorded as None so callers can see how many passes failed.
            print(f"Error generating response {attempt}: {e}")
            responses.append(None)

    return responses

def parse_response_with_confidence(response_text):
    """Parse a ``LABEL|CONFIDENCE`` response into ``(label, confidence)``.

    The text is split on the first ``|``. Surrounding whitespace and quote
    characters are removed from both parts. A finite confidence outside
    ``[0, 1]`` is clamped into that range.

    Args:
        response_text: Raw model response, or ``None``.

    Returns:
        tuple: ``(label, confidence)`` on success, otherwise ``(None, None)``
        when the text is empty, has no ``|``, or the confidence is not a
        finite number.
    """
    if not response_text or "|" not in response_text:
        return None, None

    # A pipe is guaranteed to be present, so this always yields two parts.
    raw_label, raw_confidence = response_text.split("|", 1)
    label = raw_label.strip().strip('"\'')
    confidence_str = raw_confidence.strip().strip('"\'')

    try:
        confidence = float(confidence_str)
    except ValueError as e:
        print(f"Error parsing response '{response_text}': {e}")
        return None, None

    # float() accepts "nan" and "inf". Clamping would silently turn both into
    # 1.0 (min(1.0, nan) is 1.0), i.e. maximum certainty, so reject them.
    if not math.isfinite(confidence):
        print(f"Error parsing response '{response_text}': non-finite confidence")
        return None, None

    return label, max(0.0, min(1.0, confidence))

def calculate_final_confidence_and_label(responses, labels_list, min_confidence=0.95):
    """Combine several raw responses into one label and confidence.

    Responses that are ``None``, cannot be parsed, or name a label that is not
    in ``labels_list`` are ignored, so the decision may rest on fewer responses
    than were requested. The classification is accepted only if every
    remaining response has a confidence of at least ``min_confidence``.

    When accepted, the label is the most frequent one among the remaining
    responses (on a tie, the one that appeared first) and the confidence is
    the mean of their confidences. Note that the confidence does not reflect
    how strongly the responses agree on the label.

    Args:
        responses: Raw response strings (or ``None`` for failed requests).
        labels_list: Allowed labels.
        min_confidence: Per-response acceptance threshold (default 0.95).

    Returns:
        tuple: ``(label, confidence)`` when accepted, otherwise ``(None, 0.0)``.
    """
    parsed_responses = []
    for resp in responses:
        if resp is None:
            continue
        label, confidence = parse_response_with_confidence(resp)
        if label and confidence is not None and label in labels_list:
            parsed_responses.append((label, confidence))

    if not parsed_responses:
        return None, 0.0

    confidences = [confidence for _, confidence in parsed_responses]

    # A single low-confidence response rejects the whole classification.
    if not all(confidence >= min_confidence for confidence in confidences):
        return None, 0.0

    label_counts = Counter(label for label, _ in parsed_responses)
    most_common_label = label_counts.most_common(1)[0][0]

    # confidences is non-empty here, so the mean is always defined.
    final_confidence = sum(confidences) / len(confidences)
    return most_common_label, max(0.0, min(1.0, final_confidence))

def classify_single_entity(client, entity, labels_list):
    """Classify one entity and return an annotated copy of it.

    The returned mapping is a shallow copy of ``entity`` with three extra
    keys: ``assigned_label``, ``confidence_score`` and ``raw_responses``.
    The input ``entity`` itself is not modified.
    """
    raw = process_entity_with_pass5(client, entity, labels_list)
    label, score = calculate_final_confidence_and_label(raw, labels_list)

    annotated = entity.copy()
    annotated.update(
        assigned_label=label,
        confidence_score=score,
        raw_responses=raw,
    )
    return annotated

def classify_entities(client, entities_file, labels_file, max_workers=4):
    """Classify every unique entity of a JSONL file concurrently.

    Entities are loaded from ``entities_file`` (one JSON object per non-blank
    line), deduplicated, and classified against the labels loaded from
    ``labels_file`` using a thread pool.

    Args:
        client: API client forwarded to ``classify_single_entity``.
        entities_file: Path of the JSONL file holding the entities.
        labels_file: Path of the JSONL file holding the labels.
        max_workers: Size of the thread pool (default 4).

    Returns:
        list: One result per successfully processed entity, in completion
        order (not input order). Entities whose task raised are reported on
        stdout and omitted from the list.
    """
    # Read as UTF-8 explicitly rather than relying on the locale default.
    with open(entities_file, 'r', encoding='utf-8') as f:
        entities = [json.loads(line) for line in f if line.strip()]

    print(f"Loaded {len(entities)} entities from {entities_file}")

    unique_entities = deduplicate_entities(entities)

    labels = load_labels_from_jsonl(labels_file)
    print(f"Loaded {len(labels)} labels from {labels_file}")

    results = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Remember which entity each future belongs to for error reporting.
        future_to_entity = {
            executor.submit(classify_single_entity, client, entity, labels): entity
            for entity in unique_entities
        }

        for future in tqdm(as_completed(future_to_entity), total=len(future_to_entity), desc="Classifying entities"):
            try:
                # as_completed() only yields finished futures, so result()
                # returns immediately and needs no timeout.
                results.append(future.result())
            except Exception as exc:
                # Best effort: one failing entity must not abort the batch.
                failed = future_to_entity[future]
                name = failed.get('entity') if isinstance(failed, dict) else failed
                print(f"Generated an exception for entity {name!r}: {exc}")

    return results

In [3]:

# Directory holding the input JSONL files; the outputs are written next to them.
SCHEMA_DIR = '/home/leeeefun681/volume/eefun/webscraping/scraping/vlm_webscrape/app/schema'
# Minimum confidence score a result needs to be written to the filtered file.
CONFIDENCE_THRESHOLD = 0.95


def _write_jsonl(path, rows):
    """Write each item of ``rows`` to ``path`` as one JSON document per line."""
    with open(path, 'w', encoding='utf-8') as f:
        for row in rows:
            f.write(json.dumps(row) + '\n')


# Initialize client
# NOTE(review): presumably a locally hosted OpenAI-compatible endpoint, with
# "dummy" as a placeholder API key - confirm.
client = OpenAI(
    base_url="http://localhost:8124/v1",
    api_key="dummy"
)

# Run classification with concurrency
start_time = time.time()
results = classify_entities(
    client,
    f'{SCHEMA_DIR}/_combined.jsonl',
    f'{SCHEMA_DIR}/_entity_labels.jsonl',
    max_workers=32  # Adjust based on performance
)
end_time = time.time()

print(f"\nTotal time taken: {end_time - start_time:.2f} seconds")

# Filter results to only include those that reach the confidence threshold
filtered_results = [
    result for result in results
    if result.get('confidence_score', 0) >= CONFIDENCE_THRESHOLD
]

print(f"\nFiltered results: {len(filtered_results)} out of {len(results)} passed confidence threshold")

# Save filtered results to file
_write_jsonl(f'{SCHEMA_DIR}/classified_entities_filtered.jsonl', filtered_results)

# Also save all results (including rejected ones) for reference
_write_jsonl(f'{SCHEMA_DIR}/classified_entities_all.jsonl', results)

Loaded 16995 entities from /home/leeeefun681/volume/eefun/webscraping/scraping/vlm_webscrape/app/schema/_combined.jsonl
Removed 831 duplicate entities
Kept 16164 unique entities
Loaded 60 labels from /home/leeeefun681/volume/eefun/webscraping/scraping/vlm_webscrape/app/schema/_entity_labels.jsonl


Classifying entities: 100%|██████████| 16164/16164 [13:11<00:00, 20.42it/s]



Total time taken: 791.88 seconds

Filtered results: 15436 out of 16164 passed confidence threshold
