In [1]:
import gzip
import json
import argparse
from dataclasses import dataclass
from pathlib import Path
from tqdm import tqdm

from ftlangdetect import detect
from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType

from success_prediction.rag_components.embeddings import EmbeddingHandler
from success_prediction.rag_components.cleanup import MarkdownCleaner

from success_prediction.config import DATA_DIR, RAW_DATA_DIR

[32m2025-05-17 15:29:10.980[0m | [1mINFO    [0m | [36msuccess_prediction.config[0m:[36m<module>[0m:[36m11[0m - [1mPROJ_ROOT path is: /Users/manuelbolz/Documents/git/for_work/company_success_prediction[0m
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
@dataclass
class Clients:
    """Bundle of the service objects that `run_pipeline` needs for one run."""
    # Cleans page markdown; `run_pipeline` calls its `clean`,
    # `remove_nested_brackets` and `normalize_whitespace` methods.
    md_cleaner: MarkdownCleaner
    # Splits cleaned text into chunks (`chunk`) and embeds them (`embed`).
    embedding_creator: EmbeddingHandler
    # Milvus client that receives the processed records via `insert`.
    db_client: MilvusClient


def load_raw_file(file_path: Path):
    """
    Read a gzip-compressed JSON file and return the decoded object.

    Args:
        file_path: Location of the `.json.gz` file.

    Returns:
        Whatever the JSON document decodes to (a dict for the website dumps
        consumed by `run_pipeline`).
    """
    # Decompress to bytes first; the json module decodes the bytes itself.
    with gzip.open(file_path, mode='rb') as compressed:
        payload = compressed.read()
    return json.loads(payload)


def store_links(file_path: Path, data: dict):
    """
    Write `data` to `file_path` as indented UTF-8 JSON, overwriting any existing file.

    Args:
        file_path: Destination of the JSON file.
        data: JSON-serialisable mapping (sets must be converted to lists beforehand).

    Returns:
        None (the result of `json.dump`).
    """
    # Keep non-ASCII characters readable instead of escaping them.
    dump_options = {'ensure_ascii': False, 'indent': 4}
    with open(file_path, mode='w', encoding='utf-8') as out_file:
        outcome = json.dump(data, out_file, **dump_options)
    return outcome


def structure_links(
    ehraid: int,
    links: list[dict],
    email_addresses: dict,
    social_media: dict
) -> tuple[dict, dict]:
    """
    Sort a page's links into e-mail addresses and social-media profile URLs.

    A link whose text contains '@' is recorded as an e-mail address (its text is
    stored). Otherwise the link's `base_domain` decides which social-media bucket
    its `href` is added to; links matching neither rule are ignored.

    Args:
        ehraid: Company identifier; used only as the key into both collectors.
        links: Link dicts, read via the keys 'text', 'base_domain' and 'href'.
        email_addresses: Collector shaped `{ehraid: {'emails': set()}}`; updated in place.
        social_media: Collector shaped `{ehraid: {platform: set()}}`; updated in place.

    Returns:
        The same `(email_addresses, social_media)` objects that were passed in.

    Raises:
        KeyError: If a collector has no entry for `ehraid`, or a social-media
            link has no 'href'.
    """
    domain2platform = {
        "linkedin.com": 'linkedin',
        "instagram.com": 'instagram',
        "facebook.com": 'facebook',
        "tiktok.com": 'tiktok',
        "youtube.com": 'youtube',
        "x.com": 'x',
        "twitter.com": 'x',
    }

    for link in links:
        # Links without text (missing key or None) must not crash the '@' check.
        text = link.get('text') or ''
        if '@' in text:
            email_addresses[ehraid]['emails'].add(text)
            continue

        platform = domain2platform.get(link.get('base_domain'))
        if platform is not None:
            social_media[ehraid][platform].add(link['href'])

    return email_addresses, social_media


def run_pipeline(clients: Clients, idx: int, file_path: Path, **kwargs):
    """
    Process one raw website dump: embed its pages into Milvus and store the harvested links.

    For every company (`ehraid`) and every crawled URL in the file, the page
    markdown is cleaned, language-detected, chunked and embedded. Pages without
    markdown, or whose cleaned text has 300 characters or fewer, are skipped.
    All chunk records are inserted into the Milvus collection in a single call
    at the end, and the collected e-mail addresses and social-media links are
    written to `emails_{idx}.json` / `social_media_{idx}.json` in RAW_DATA_DIR.

    Args:
        clients: Markdown cleaner, embedding handler and Milvus client.
        idx: Number used in the names of the two link files.
        file_path: Gzipped JSON dump mapping ehraid -> url -> page attributes.
        **kwargs: `collection_name` is read to select the target collection.
    """
    platforms = ['linkedin', 'instagram', 'facebook', 'tiktok', 'youtube', 'x']

    companies = load_raw_file(file_path)
    records = []
    email_addresses, social_media = {}, {}

    cleaner = clients.md_cleaner
    embedder = clients.embedding_creator

    for ehraid, pages in tqdm(companies.items()):
        # Collect into sets while scanning so repeated links are stored once.
        email_addresses[ehraid] = {'emails': set()}
        social_media[ehraid] = {platform: set() for platform in platforms}

        for url, page in pages.items():
            raw_markdown = page.get('markdown')
            if not raw_markdown:
                continue

            crawl_date = page['date']
            page_links = page['links']
            internal_hrefs = [link['href'] for link in page_links['internal']]
            external_hrefs = [link['href'] for link in page_links['external']]

            # Only the external links are scanned for e-mails / social profiles.
            email_addresses, social_media = structure_links(
                ehraid, page_links['external'], email_addresses, social_media)

            cleaned = cleaner.clean(raw_markdown, internal_hrefs, external_hrefs)
            if len(cleaned) <= 300:
                continue

            # Detect language using the text without bracket content, since it includes
            # English tokens such as INTERNAL_LINKS that might confuse the model
            detection_input = cleaner.remove_nested_brackets(cleaned).replace('\n', ' ')
            language = detect(text=detection_input)

            # Split the text into smaller chunks to fit into the model context,
            # normalising whitespace per chunk
            chunk_texts = []
            for doc in embedder.chunk(cleaned):
                chunk_texts.append(cleaner.normalize_whitespace(doc.page_content))

            chunk_embeddings = embedder.embed(chunk_texts, prefix='query:')

            # One database record per chunk, all sharing the page-level metadata.
            for chunk_text, chunk_embedding in zip(chunk_texts, chunk_embeddings):
                records.append({
                    'ehraid': int(ehraid),
                    'url': str(url),
                    'date': crawl_date,
                    'language': language.get('lang'),
                    'text': chunk_text,
                    'embedding': chunk_embedding
                })

        # Sets are not JSON-serialisable; convert them once the company is done.
        for collector in (email_addresses, social_media):
            collector[ehraid] = {
                key: list(values) for key, values in collector[ehraid].items()
            }

    clients.db_client.insert(collection_name=kwargs.get('collection_name'), data=records)

    store_links(RAW_DATA_DIR / f'emails_{idx}.json', email_addresses)
    store_links(RAW_DATA_DIR / f'social_media_{idx}.json', social_media)


def setup_database(client: MilvusClient, collection_name: str, schema: CollectionSchema, replace: bool):
    """
    Make sure `collection_name` exists, optionally recreating it from scratch.

    Args:
        client: Milvus client used to inspect, drop and create the collection.
        collection_name: Name of the collection to set up.
        schema: Schema used when the collection has to be created.
        replace: If truthy, an existing collection is dropped first and then recreated.
    """
    if replace:
        if client.has_collection(collection_name):
            client.drop_collection(collection_name)

    # Re-check after the optional drop: only create when nothing is there.
    if client.has_collection(collection_name):
        print(f"{collection_name} already exists!")
        return

    client.create_collection(
        collection_name=collection_name,
        schema=schema)


def main(args: argparse.Namespace):
    """
    Build the clients, set up the Milvus collection and process every raw dump.

    Args:
        args: Parsed command-line arguments; `collection_name` and `replace` are read.
    """
    clients = Clients(
        md_cleaner=MarkdownCleaner(),
        embedding_creator=EmbeddingHandler(model_name='intfloat/multilingual-e5-base'),
        # MilvusClient takes its `uri` as a string, so convert the Path explicitly.
        db_client=MilvusClient(uri=str(DATA_DIR / 'database' / 'websites.db'))
    )

    # `dim=768` has to match the output size of the embedding model above.
    website_schema = CollectionSchema(fields=[
        FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
        FieldSchema(name="ehraid", dtype=DataType.INT64),
        FieldSchema(name="url", dtype=DataType.VARCHAR, max_length=512),
        FieldSchema(name="date", dtype=DataType.VARCHAR, max_length=10),
        FieldSchema(name="language", dtype=DataType.VARCHAR, max_length=5),
        FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=64_000),
        FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=768),
    ])
    setup_database(clients.db_client, collection_name=args.collection_name, schema=website_schema, replace=args.replace or False)

    # iterdir() yields entries in arbitrary order; sort so that the index used in
    # the emails_/social_media_ file names is the same on every run.
    raw_dir = Path(RAW_DATA_DIR / 'company_websites' / 'current')
    raw_files = sorted(file for file in raw_dir.iterdir() if str(file).endswith('.json.gz'))

    for i, file in enumerate(raw_files):
        run_pipeline(clients, idx=i, file_path=file, collection_name=args.collection_name)

In [3]:
def main(collection_name, replace=False, start_idx=136):
    """
    Resume the pipeline, processing the raw dumps from position `start_idx` onwards.

    This variant does not create or replace the collection: `collection_name`
    must already exist in the database.

    Args:
        collection_name: Milvus collection the records are inserted into.
        replace: Accepted for call compatibility; not used, because this
            variant performs no collection setup.
        start_idx: Position (in the sorted list of raw files) of the first file
            to process; also the first index used in the output file names.

    Raises:
        ValueError: If `start_idx` is negative.
    """
    if start_idx < 0:
        raise ValueError(f"start_idx must be >= 0, got {start_idx}")

    clients = Clients(
        md_cleaner=MarkdownCleaner(),
        embedding_creator=EmbeddingHandler(),
        db_client=MilvusClient(uri=str(DATA_DIR / 'database' / 'websites.db'))
    )

    # iterdir() yields entries in arbitrary order; resuming by position is only
    # meaningful when the order is fixed, so sort the file list.
    raw_dir = Path(RAW_DATA_DIR / 'company_websites' / 'current')
    raw_files = sorted(file for file in raw_dir.iterdir() if str(file).endswith('.json.gz'))

    for i, file in enumerate(raw_files[start_idx:], start=start_idx):
        run_pipeline(clients, idx=i, file_path=file, collection_name=collection_name)

In [4]:
if __name__ == '__main__':
    # The string literal below is a disabled command-line entry point kept for
    # reference. As a bare string expression it has no effect when the cell runs.
    # It calls `main(args)` with an argparse Namespace, which matches the earlier
    # `main(args)` definition, not the later `main(collection_name, replace)`.
    """
    parser = argparse.ArgumentParser(
        prog='RAGPipeline',
        description='Processes the markdown and handles retrieval from the Milvus DB',
    )
    parser.add_argument('--collection_name', default='current_websites')
    parser.add_argument('--replace', action='store_true')
    args = parser.parse_args()
    main(args)
    """
    # Notebook entry point: calls whichever `main` was defined last, passing the
    # collection name and the replace flag positionally.
    main('current_websites', False)

[EmbeddingHandler] Using model on `mps`.


100%|██████████| 500/500 [05:06<00:00,  1.63it/s]
100%|██████████| 496/496 [04:57<00:00,  1.67it/s]
100%|██████████| 500/500 [04:22<00:00,  1.90it/s]
100%|██████████| 499/499 [03:18<00:00,  2.52it/s]
100%|██████████| 499/499 [03:38<00:00,  2.28it/s]
100%|██████████| 498/498 [04:12<00:00,  1.98it/s]
100%|██████████| 449/449 [03:28<00:00,  2.16it/s]
100%|██████████| 500/500 [02:57<00:00,  2.82it/s]
100%|██████████| 448/448 [03:12<00:00,  2.32it/s]
100%|██████████| 450/450 [03:07<00:00,  2.41it/s]
100%|██████████| 499/499 [03:38<00:00,  2.28it/s]
100%|██████████| 499/499 [04:02<00:00,  2.06it/s]
100%|██████████| 500/500 [03:44<00:00,  2.23it/s]
100%|██████████| 499/499 [03:15<00:00,  2.55it/s]
100%|██████████| 499/499 [03:19<00:00,  2.50it/s]
100%|██████████| 499/499 [03:52<00:00,  2.15it/s]
100%|██████████| 498/498 [03:52<00:00,  2.14it/s]
100%|██████████| 500/500 [03:55<00:00,  2.12it/s]
100%|██████████| 499/499 [02:57<00:00,  2.81it/s]
100%|██████████| 450/450 [02:46<00:00,  2.70it/s]


In [None]:
# Retrieve the chunks of one company that best match a dimension-specific question.
# MilvusClient takes its `uri` as a string, so convert the Path explicitly.
db_client = MilvusClient(uri=str(DATA_DIR / 'database' / 'websites.db'))
embedding_creator = EmbeddingHandler()

# Company whose pages are searched (the same id the query cell further down uses).
ehraid = 1252082
top_n = 10

# The questions are stored without the "query: " marker: it is supplied through
# `prefix` when embedding, exactly as run_pipeline does for the page chunks, so
# keeping it in the text as well would presumably apply it twice.
# NOTE(review): confirm against EmbeddingHandler.embed.
dim2query = {
    "Value Proposition & Innovation":
        "What products or services does the company offer, what makes them unique or innovative, and what value do they offer to their customers?",

    "Purpose & Responsibility":
        "What is the company's mission or long-term vision, and what social or environmental initiatives does it pursue?",

    "Leadership & People":
        "Who founded or currently leads the company, and what are their professional backgrounds?"
}

# embed() is given a one-element list, so keep its single result: each dimension
# then maps to one embedding (assumes embed returns one embedding per input text,
# as the zip in run_pipeline implies).
dim2embedding = {
    dimension: embedding_creator.embed([query], prefix='query:')[0]
    for dimension, query in dim2query.items()
}

search_results = db_client.search(
    collection_name='current_websites',
    data=[dim2embedding["Value Proposition & Innovation"]],  # list with one query vector
    filter=f"ehraid == {ehraid}",
    limit=top_n,
    output_fields=["ehraid", "text", "embedding", "url", "date", "language"],
    search_params={"metric_type": "IP", "params": {"nprobe": 10}},
)



In [None]:
# Fetch every stored chunk of one example company.
# MilvusClient takes its `uri` as a string, so convert the Path explicitly.
db_client = MilvusClient(uri=str(DATA_DIR / 'database' / 'websites.db'))
company_data = db_client.query(collection_name='current_websites', filter="ehraid == 1252082")

In [5]:
# Show only the first few records; dumping the whole result floods the notebook output.
company_data[:3]

data: ['{\'id\': 457887796930547137, \'date\': \'2025-05-06\', \'ehraid\': 1252082, \'language\': \'fr\', \'text\': "[SITE_PAGE: 0 ]\\n[SITE_PAGE: Passer au contenu ]\\n[ [IMAGE:Ecole Suisse de Sertissage] ](<INTERNAL_PAGE>)\\n[EXTERNAL_SITE: ] [EXTERNAL_SITE: ]\\n[SITE_PAGE: CONTACT ]\\nOuvrir le menu Fermer le menu\\n[EXTERNAL_SITE: ] [EXTERNAL_SITE: ]\\n[SITE_PAGE: CONTACT ]\\n[ [IMAGE:Ecole Suisse de Sertissage] ](<INTERNAL_PAGE>)\\nOuvrir le menu Fermer le menu\\n00:00\\n00:00\\nRéactiver le sonCouper le son\\nParamètres\\nVitesseNormal\\nVitesseRevenir au menu précédent\\n0.5×0.75×Normal1.25×1.5×1.75×2×\\nQuitter le mode plein écranActiver le mode plein écran\\n% buffered00:00\\nLecture\\nLa vidéo n\'est pas disponible ou le format n\'est pas pris en charge. Essayez un autre navigateur.\\n[SITE_PAGE: RETOUR ]", \'url\': \'https://www.esds.ch/clip\'}', '{\'id\': 457887796930547138, \'date\': \'2025-05-06\', \'ehraid\': 1252082, \'language\': \'fr\', \'text\': "[SITE_PAGE: 0 ]\\n[S