In [227]:
import gzip
import json
import argparse
from collections import Counter, defaultdict
from dataclasses import dataclass
from pathlib import Path
from tqdm import tqdm

import numpy as np
import torch
from ftlangdetect import detect
from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType
from pymilvus.client.types import ExtraList
from success_prediction.rag_components.embeddings import EmbeddingHandler
from success_prediction.rag_components.cleanup import MarkdownCleaner

from success_prediction.config import DATA_DIR, RAW_DATA_DIR

In [None]:
@dataclass
class Clients:
    """
    Bundles the service objects that the ingestion pipeline passes around together.

    Instances are created once in `main` and handed to `run_pipeline` for every raw file.
    """
    # Used by run_pipeline for clean(), remove_nested_brackets() and normalize_whitespace()
    md_cleaner: MarkdownCleaner
    # Used by run_pipeline for chunk() and embed()
    embedding_creator: EmbeddingHandler
    # Milvus client that receives the embedded chunks via insert()
    db_client: MilvusClient


def load_raw_file(file_path: Path) -> dict:
    """
    Reads a gzip-compressed JSON document from disk and parses it.

    Args:
        file_path (Path): Location of the `.json.gz` file.

    Returns:
        dict: The decoded JSON content.
    """
    with gzip.open(file_path, 'r') as compressed:
        raw_bytes = compressed.read()
    # json.loads accepts bytes and detects the UTF encoding itself
    return json.loads(raw_bytes)


def store_links(file_path: Path, data: dict) -> None:
    """
    Writes a dictionary to disk as indented, UTF-8 encoded JSON.

    Args:
        file_path (Path): Where the JSON file is written (overwritten if present).
        data (dict): JSON-serializable content to persist.
    """
    with open(file_path, mode='w', encoding='utf-8') as out_file:
        # ensure_ascii=False keeps non-ASCII characters readable in the file
        json.dump(data, out_file, indent=4, ensure_ascii=False)
    return None


# Maps a link's base domain to the key under which its URL is collected.
_SOCIAL_DOMAIN_TO_PLATFORM = {
    "linkedin.com": "linkedin",
    "instagram.com": "instagram",
    "facebook.com": "facebook",
    "tiktok.com": "tiktok",
    "youtube.com": "youtube",
    "x.com": "x",
    "twitter.com": "x",
}


def structure_links(
    ehraid: int,
    links: list[dict],
    email_addresses: dict,
    social_media: dict
) -> tuple[dict, dict]:
    """
    Organizes links by identifying emails and social media handles and storing them per company ID.

    A link whose text contains '@' is recorded as an email address (the link text is stored).
    Otherwise the link's 'base_domain' decides which social media bucket receives its 'href'.
    Links without usable text are no longer an error, and social links without an 'href'
    are skipped instead of raising.

    Args:
        ehraid (int): Unique company identifier; must already be a key in both storages.
        links (List[dict]): List of extracted link dictionaries ('text', 'href', 'base_domain').
        email_addresses (Dict[int, Dict[str, Set[str]]]): Storage for emails, mutated in place.
        social_media (Dict[int, Dict[str, Set[str]]]): Storage for social links, mutated in place.

    Returns:
        Tuple containing updated email_addresses and social_media (the same objects passed in).
    """
    for link in links:
        text = link.get('text')
        # 'text' may be missing or None for image-only links; only strings can hold an address
        if isinstance(text, str) and '@' in text:
            email_addresses[ehraid]['emails'].add(text)
            continue

        platform = _SOCIAL_DOMAIN_TO_PLATFORM.get(link.get('base_domain'))
        href = link.get('href')
        if platform is not None and href:
            social_media[ehraid][platform].add(href)
    return email_addresses, social_media


def run_pipeline(clients: Clients, idx: int, file_path: Path, **kwargs) -> None:
    """
    Processes raw company website data:
    - Cleans and chunks content
    - Embeds it
    - Extracts contact and social media links
    - Stores results in a Milvus database and contact info files

    Args:
        clients (Clients): Wrapper containing the database, embedding, and cleaning tools.
        idx (int): Index of the file being processed; used as suffix of the contact info files.
        file_path (Path): Path to the raw gzipped JSON file.
        **kwargs: Additional options, expects 'collection_name'.

    Raises:
        ValueError: If 'collection_name' is missing or empty.
    """
    # Fail before the expensive embedding work rather than at the final insert
    collection_name = kwargs.get('collection_name')
    if not collection_name:
        raise ValueError("run_pipeline() requires a non-empty 'collection_name' keyword argument")

    raw_json = load_raw_file(file_path)
    processed_files = []
    email_addresses, social_media = {}, {}

    for ehraid, urls2attributes in tqdm(raw_json.items()):
        email_addresses[ehraid] = {'emails': set()}
        social_media[ehraid] = {k: set() for k in ['linkedin', 'instagram', 'facebook', 'tiktok', 'youtube', 'x']}

        for url, attributes in urls2attributes.items():
            markdown = attributes.get('markdown')
            if not markdown:
                continue

            date = attributes['date']
            internal_links = [link['href'] for link in attributes['links']['internal']]
            external_links = [link['href'] for link in attributes['links']['external']]

            # Only external links are scanned for contact / social media information
            email_addresses, social_media = structure_links(
                ehraid, attributes['links']['external'], email_addresses, social_media)

            markdown_clean = clients.md_cleaner.clean(markdown, internal_links, external_links)
            markdown_no_links = clients.md_cleaner.remove_nested_brackets(markdown_clean).replace('\n', ' ')
            # Pages with 300 characters or fewer after cleaning are not embedded
            if len(markdown_no_links) <= 300:
                continue

            # Detect language using the text without bracket content, since it includes
            # English tokens such as INTERNAL_LINKS that might confuse the model
            language = detect(text=markdown_no_links)

            # Split the text into smaller chunks to fit into the model context + normalize whitespace per chunk
            markdown_chunks = clients.embedding_creator.chunk(markdown_no_links)
            markdown_chunks_clean = [
                clients.md_cleaner.normalize_whitespace(doc.page_content)
                for doc in markdown_chunks
            ]

            # Every chunk is embedded twice, once per prefix, and both vectors are stored
            passage_embeddings = clients.embedding_creator.embed(
                markdown_chunks_clean, prefix='passage:')

            query_embeddings = clients.embedding_creator.embed(
                markdown_chunks_clean, prefix='query:')

            processed_files.extend([
                {
                    'ehraid': int(ehraid),
                    'url': str(url),
                    'date': date,
                    'language': language.get('lang'),
                    'text': md,
                    'embedding_passage': p_emb,
                    'embedding_query': q_emb
                }
                for md, p_emb, q_emb in zip(markdown_chunks_clean, passage_embeddings, query_embeddings)
            ])

        # Sets are not JSON-serializable, so convert them once the company is fully processed
        email_addresses[ehraid] = {k: list(v) for k, v in email_addresses[ehraid].items()}
        social_media[ehraid] = {k: list(v) for k, v in social_media[ehraid].items()}

    # Nothing to insert when no page of this file passed the filters above
    if processed_files:
        clients.db_client.insert(collection_name=collection_name, data=processed_files)

    # Create the output directory up front so the results of a long run are not lost
    contact_dir = RAW_DATA_DIR / 'company_websites' / 'current' / 'contact_info'
    contact_dir.mkdir(parents=True, exist_ok=True)
    store_links(contact_dir / f'emails_{idx}.json', email_addresses)
    store_links(contact_dir / f'social_media_{idx}.json', social_media)


def setup_database(client: MilvusClient, collection_name: str, schema: CollectionSchema, replace: bool) -> None:
    """
    Makes sure the Milvus collection for the embedded documents exists.

    With `replace` set, an existing collection is dropped first and then recreated.
    Without it, an existing collection is left untouched and a notice is printed.

    Args:
        client (MilvusClient): Initialized Milvus client.
        collection_name (str): Name of the collection to use/create.
        schema (CollectionSchema): Schema applied when the collection is created.
        replace (bool): Whether to drop and recreate the collection if it exists.
    """
    drop_first = replace and client.has_collection(collection_name)
    if drop_first:
        client.drop_collection(collection_name)

    # Re-check after the optional drop: only create what is not there
    if client.has_collection(collection_name):
        print(f"{collection_name} already exists!")
        return

    client.create_collection(collection_name=collection_name, schema=schema)


def main(args: argparse.Namespace):
    """
    Sets up the clients and the Milvus collection, then runs the pipeline on every raw file.

    Args:
        args (argparse.Namespace): Must provide `collection_name` (str) and `replace` (bool);
            `replace` drops and recreates an existing collection.
    """
    clients = Clients(
        md_cleaner=MarkdownCleaner(),
        embedding_creator=EmbeddingHandler(model_name='intfloat/multilingual-e5-base'),
        # Pass the URI as a string, as the retrieval cell of this notebook does
        db_client=MilvusClient(uri=str(DATA_DIR / 'database' / 'websites.db'))
    )

    website_schema = CollectionSchema(fields=[
        FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
        FieldSchema(name="ehraid", dtype=DataType.INT64),
        FieldSchema(name="url", dtype=DataType.VARCHAR, max_length=512),
        FieldSchema(name="date", dtype=DataType.VARCHAR, max_length=10),
        FieldSchema(name="language", dtype=DataType.VARCHAR, max_length=5),
        FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=64_000),
        FieldSchema(name="embedding_passage", dtype=DataType.FLOAT_VECTOR, dim=768),
        FieldSchema(name="embedding_query", dtype=DataType.FLOAT_VECTOR, dim=768),
    ])
    setup_database(clients.db_client, collection_name=args.collection_name, schema=website_schema, replace=args.replace or False)

    raw_dir = RAW_DATA_DIR / 'company_websites' / 'current'
    # Sort the files: iterdir() yields entries in arbitrary order, and the position in this
    # list becomes the suffix of the emails_/social_media_ output files
    raw_files = sorted(file for file in raw_dir.iterdir() if file.name.endswith('.json.gz'))
    # raw_files = [RAW_DATA_DIR / 'company_websites' / 'current' / '0_websites.json.gz']

    for i, file in enumerate(raw_files):
        run_pipeline(clients, idx=i, file_path=file, collection_name=args.collection_name)

In [4]:
if __name__ == '__main__':
    # Command-line entry point, currently disabled in favour of the fixed settings below:
    #
    # parser = argparse.ArgumentParser(
    #     prog='RAGPipeline',
    #     description='Processes the markdown and handles retrieval from the Milvus DB',
    # )
    # parser.add_argument('--collection_name', default='current_websites')
    # parser.add_argument('--replace', action='store_true')
    # args = parser.parse_args()
    # main(args)

    # main() takes a single Namespace and reads args.collection_name / args.replace,
    # so the settings have to be wrapped instead of being passed positionally
    main(argparse.Namespace(collection_name='current_websites', replace=False))

[EmbeddingHandler] Using model on `mps`.


100%|██████████| 500/500 [05:06<00:00,  1.63it/s]
100%|██████████| 496/496 [04:57<00:00,  1.67it/s]
100%|██████████| 500/500 [04:22<00:00,  1.90it/s]
100%|██████████| 499/499 [03:18<00:00,  2.52it/s]
100%|██████████| 499/499 [03:38<00:00,  2.28it/s]
100%|██████████| 498/498 [04:12<00:00,  1.98it/s]
100%|██████████| 449/449 [03:28<00:00,  2.16it/s]
100%|██████████| 500/500 [02:57<00:00,  2.82it/s]
100%|██████████| 448/448 [03:12<00:00,  2.32it/s]
100%|██████████| 450/450 [03:07<00:00,  2.41it/s]
100%|██████████| 499/499 [03:38<00:00,  2.28it/s]
100%|██████████| 499/499 [04:02<00:00,  2.06it/s]
100%|██████████| 500/500 [03:44<00:00,  2.23it/s]
100%|██████████| 499/499 [03:15<00:00,  2.55it/s]
100%|██████████| 499/499 [03:19<00:00,  2.50it/s]
100%|██████████| 499/499 [03:52<00:00,  2.15it/s]
100%|██████████| 498/498 [03:52<00:00,  2.14it/s]
100%|██████████| 500/500 [03:55<00:00,  2.12it/s]
100%|██████████| 499/499 [02:57<00:00,  2.81it/s]
100%|██████████| 450/450 [02:46<00:00,  2.70it/s]


In [None]:
from scraper.crawler_config import value_keywords, esg_keywords, team_keywords, product_keywords

# One keyword list per dimension, de-duplicated while keeping the original order.
# NOTE(review): all three constants used to be built from `value_keywords + esg_keywords`
# (copy-paste), leaving `team_keywords` and `product_keywords` unused. The mapping below is
# the presumed intent based on the names — confirm against scraper.crawler_config.
VALUE_PROPOSITION_KEYWORDS = list(dict.fromkeys(value_keywords + product_keywords))
LEADERSHIP_KEYWORDS = list(dict.fromkeys(team_keywords))
RESPONSIBILITY_KEYWORDS = list(dict.fromkeys(esg_keywords))

In [228]:
db_client = MilvusClient(uri=str(DATA_DIR / 'database' / 'websites.db'))
# Use the same model that embedded the stored passages in `main`; query and passage
# vectors are only comparable when they come from the same model
embedding_creator = EmbeddingHandler(model_name='intfloat/multilingual-e5-base')

# Five natural-language queries per dimension; they are used together as an ensemble
dim2query = {
    "Value Proposition & Innovation": [
        "What solutions, services, or products does the company provide to customers?",
        "What products and services does the company advertise on its website?",
        "Which innovative features or technologies are highlighted in the company's offerings?",
        "What benefits or outcomes does the company promise or deliver through its solutions, services, products, or platforms?",
        "How does the company differentiate its products or services from competitors, and what specific customer needs are addressed?"
    ],
    "Purpose & Responsibility": [
        "What is the stated mission, purpose, or long-term vision of the company?",
        "Which ethical, social, or environmental commitments does the company emphasize?",
        "Does the company value sustainability, diversity, inclusion, or in general ESG-related goals?",
        "What values or principles guide the company's operations and decisions?",
        "Does the company participate in any charitable initiatives, community outreach, or global impact programs?"
    ],
    "Leadership & People": [
        "Who are the founders or key leaders of the company, and what roles do they hold?",
        "What are the professional backgrounds or credentials of the company's executive team?",
        "Who makes up the leadership team, and how is the company structured in terms of people and roles?",
        "What experience or expertise does the management bring to the company?",
        "Are there biographies or personal stories of team members or executives available on the website?"
    ]
}

# Embed every query on its own: each list item is the result of a one-element embed() call,
# which is why consumers read the vector via `[0]`
dim2embedding = {
    dimension: [embedding_creator.embed([q], prefix='query:') for q in queries]
    for dimension, queries in dim2query.items()
}

[EmbeddingHandler] Using model on `mps`.


In [229]:
# The collection schema defines two vector fields (`embedding_passage`, `embedding_query`)
# and no field called `embedding`, so build one index entry per existing vector field.
# NOTE(review): Milvus is expected to require an index on every vector field before the
# collection can be loaded — confirm for the Milvus Lite version in use.
index_params = db_client.prepare_index_params()
for vector_field in ("embedding_passage", "embedding_query"):
    index_params.add_index(
        field_name=vector_field,
        index_type="FLAT",
        metric_type="IP",  # Use inner product because E5 embeddings are normalized (||v|| = 1)
    )

db_client.create_index(
    collection_name="current_websites",
    index_params=index_params
)

db_client.load_collection(collection_name="current_websites")

In [None]:
def ensemble_top_passages(company_data, query_embeddings, top_k_per_query=15, final_top_k=15):
    """
    Picks the passages that make it into the per-query top-k most often.

    Every query ranks all passages by dot product. A passage's consensus rank is determined
    by how many queries placed it in their top-k, ties being broken by its best score.

    Parameters:
        company_data: List of dicts with fields 'id', 'embedding', 'text', 'url'
        query_embeddings: List of embedded queries; element [0] of each item is the vector
        top_k_per_query: Number of top results to keep for each query
        final_top_k: Final number of consensus passages to return

    Returns:
        List of dicts: copies of the winning entries with an added 'score' key, taken from
        the first query that ranked the passage in its top-k
    """
    hits_by_id = defaultdict(list)

    for embedded_query in query_embeddings:
        vec = np.array(embedded_query[0])
        ranked = sorted(
            ({**entry, 'score': np.dot(vec, entry['embedding'])} for entry in company_data),
            key=lambda candidate: candidate['score'],
            reverse=True,
        )
        for hit in ranked[:top_k_per_query]:
            hits_by_id[hit['id']].append((hit['score'], hit))

    def consensus_key(item):
        # (number of queries that selected the passage, its best score among them)
        _, hits = item
        return len(hits), max(score for score, _ in hits)

    by_consensus = sorted(hits_by_id.items(), key=consensus_key, reverse=True)
    return [hits[0][1] for _, hits in by_consensus[:final_top_k]]


def ensemble_rerank(top_n_entries, query_texts):
    """
    Reranks passages by their relevancy to a set of queries and keeps the above-average ones.

    Each entry is scored against every query via `embedding_creator.calculate_relevancy_scores`;
    the median over all queries is stored on the entry as 'attention_score' (entries are
    mutated in place). Entries are then sorted by that score in descending order and only
    those scoring at least the mean are returned.

    The cut-off is the same as keeping z-scores >= 0, but it avoids dividing by the standard
    deviation: with a single entry or identical scores the deviation is 0, the z-scores
    become NaN and every entry would be dropped.

    Parameters:
        top_n_entries: List of dicts with at least a 'text' field
        query_texts: List of query strings to score each entry against

    Returns:
        List of dicts: entries with 'attention_score' >= mean score, best first
        (empty if `top_n_entries` is empty)
    """
    if not top_n_entries:
        return []

    for entry in top_n_entries:
        pairs = [(query, entry['text']) for query in query_texts]
        relevancy_score = embedding_creator.calculate_relevancy_scores(sentence_pairs=pairs).median()
        entry.update({'attention_score': relevancy_score})

    sorted_entries = sorted(
        top_n_entries,
        key=lambda entry: (float(entry['attention_score'])),
        reverse=True
    )
    # Work on plain floats so the comparison does not depend on the score's container type
    scores = np.array([float(entry['attention_score']) for entry in sorted_entries])
    mean_score = scores.mean()
    return [entry for score, entry in zip(scores, sorted_entries) if score >= mean_score]

def get_dimension_vec(dimension: str, company_data: ExtraList, dim2embedding: dict, dim2query: dict):
    """
    Builds one aggregated vector for a dimension from a company's most relevant passages.

    Parameters:
        dimension: Key into `dim2embedding` and `dim2query`
        company_data: Passage entries of one company (each with 'embedding' and 'text')
        dim2embedding: Maps a dimension to its list of embedded queries
        dim2query: Maps a dimension to its list of query strings

    Returns:
        Tuple of (reranked entries that were kept, aggregated dimension vector)
    """
    # Stage 1: shortlist by dot-product similarity (default top-k settings, i.e. 15)
    shortlist = ensemble_top_passages(
        company_data=company_data,
        query_embeddings=dim2embedding[dimension],
    )

    # Stage 2: rerank the shortlist against the query texts and keep the most relevant
    most_relevant = ensemble_rerank(top_n_entries=shortlist, query_texts=dim2query[dimension])

    # Stage 3: weighted aggregation, using the rerank score as a quasi attention weight
    passage_vectors = [torch.tensor(entry['embedding']) for entry in most_relevant]
    attention_weights = [entry['attention_score'] for entry in most_relevant]
    dim_vec = embedding_creator.waggregate_embeddings(passage_vectors, attention_weights)
    return most_relevant, dim_vec

In [None]:
# Build, for a couple of sample companies, the relevant passages and the aggregated vector
# of every dimension. Result shape: [{ehraid: {dimension: {'entries': ..., 'vectors': ...}}}]
vec_results = []
for ehraid in [1251382, 1433629]:
    rows = db_client.query(
        collection_name='current_websites',
        filter=f"ehraid == {ehraid}",
        output_fields=['id', 'ehraid', 'url', 'date', 'language', 'text', 'embedding_passage'],
    )
    # The retrieval helpers read the vector from the 'embedding' key, while the collection
    # stores it as 'embedding_passage' (queries are matched against passage embeddings)
    company_data = [{**row, 'embedding': row['embedding_passage']} for row in rows]

    # One result dict per dimension (previously `{}, {}`, a tuple that cannot be indexed by name)
    dim_vectors = {}
    for dim in dim2query:
        most_relevant, dim_vec = get_dimension_vec(dim, company_data, dim2embedding, dim2query)
        dim_vectors[dim] = {'entries': most_relevant, 'vectors': dim_vec}
    vec_results.append({ehraid: dim_vectors})