In [None]:
import asyncio
import json
from collections.abc import Iterator
from pathlib import Path

from ftlangdetect import detect
from tqdm import tqdm

from success_prediction.config import RAW_DATA_DIR, PROCESSED_DATA_DIR
from success_prediction.rag_components.cleanup import MarkdownCleaner
from success_prediction.rag_components.embeddings import EmbeddingCreator
from success_prediction.vector_db.utils import DatabaseClient


In [None]:
def create_batches(list_object: list, batch_size: int) -> Iterator[list]:
    """Yield consecutive slices of ``list_object`` with at most ``batch_size`` items.

    Args:
        list_object: The list to split. It is only sliced, never modified.
        batch_size: Maximum number of items per batch; must be at least 1.

    Yields:
        Slices of ``list_object`` in their original order. Every batch holds
        exactly ``batch_size`` items except possibly the last one, which holds
        the remainder. Nothing is yielded for an empty input.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1. Because this is a
            generator, the error surfaces when iteration starts.
    """
    if batch_size < 1:
        # A negative step would make ``range`` empty and silently drop every item.
        raise ValueError(f'batch_size must be a positive integer, got {batch_size}')
    for start in range(0, len(list_object), batch_size):
        # Slicing clamps at the end of the list, so no explicit upper bound is needed.
        yield list_object[start:start + batch_size]


def load_raw_file(file_path: Path):
    """Read a JSON file and return the decoded Python object.

    Args:
        file_path: Location of the JSON file to read.

    Returns:
        Whatever ``json.load`` produces for the file content.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # Decode as UTF-8 explicitly (the same encoding ``store_links`` writes)
    # instead of relying on the platform's default locale encoding.
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def _json_default(value):
    """``json.dump`` fallback for objects the encoder cannot handle natively."""
    if isinstance(value, (set, frozenset)):
        # Sort (by string form, so mixed types cannot break the comparison)
        # to keep the written file deterministic across runs.
        return sorted(value, key=str)
    raise TypeError(f'Object of type {type(value).__name__} is not JSON serializable')


def store_links(file_path: Path, data: dict):
    """Write ``data`` to ``file_path`` as indented, UTF-8 encoded JSON.

    The pipeline collects e-mail addresses and social-media links in ``set``
    objects, which ``json`` cannot encode on its own; sets are therefore
    written as sorted JSON arrays.

    Args:
        file_path: Destination file; an existing file is overwritten.
        data: Mapping to serialise. Values may contain sets at any depth.

    Returns:
        None.

    Raises:
        OSError: If the file cannot be opened for writing.
        TypeError: If ``data`` contains a value that is neither JSON
            serialisable nor a set.
    """
    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False, indent=4, default=_json_default)


# Maps a link's ``base_domain`` to the bucket key used inside ``social_media``.
_SOCIAL_MEDIA_DOMAINS = {
    'linkedin.com': 'linkedin',
    'instagram.com': 'instagram',
    'facebook.com': 'facebook',
    'tiktok.com': 'tiktok',
    'youtube.com': 'youtube',
    'x.com': 'x',
}


def structure_links(
    ehraid: int,
    links: list[dict],
    email_addresses: dict,
    social_media: dict
) -> tuple[dict, dict]:
    """Sort external links into e-mail addresses and social-media profiles.

    A link whose ``text`` contains ``@`` is recorded as an e-mail address.
    Otherwise, if its ``base_domain`` is one of the known social-media
    domains, its ``href`` is recorded under the matching platform. All other
    links are ignored.

    Args:
        ehraid: Key under which the results are stored; it is only used to
            index ``email_addresses`` and ``social_media``.
        links: Link dictionaries; the keys ``text``, ``base_domain`` and
            ``href`` are read.
        email_addresses: Mapping ``ehraid -> {'emails': set}``. Updated in
            place; the entry for ``ehraid`` must already exist.
        social_media: Mapping ``ehraid -> {platform: set}``. Updated in
            place; the entry for ``ehraid`` must already exist.

    Returns:
        The same ``email_addresses`` and ``social_media`` objects, as a tuple.

    Raises:
        KeyError: If ``ehraid`` (or a platform bucket) is missing from the
            accumulators, or a social-media link has no ``href``.
    """
    for link in links:
        # ``text`` may be absent or null; treat that as "no e-mail address"
        # instead of failing on ``'@' in None``.
        text = link.get('text') or ''
        if '@' in text:
            email_addresses[ehraid]['emails'].add(text)
            continue

        platform = _SOCIAL_MEDIA_DOMAINS.get(link.get('base_domain'))
        if platform is not None:
            social_media[ehraid][platform].add(link['href'])
    return email_addresses, social_media
            

async def run_pipeline(idx: int, file_path: Path):
    """Process one raw JSON file: embed its pages and collect its contact links.

    The raw file is read as ``{ehraid: {url: attributes}}``. For every page
    with non-empty ``markdown`` the text is cleaned, chunked, language-detected
    and embedded, and one column-style record (parallel lists, one entry per
    chunk) is queued for the database. The page's external links are sorted
    into e-mail addresses and social-media profiles per ``ehraid``.

    Relies on the module-level ``cleaner``, ``embedding_creator`` and
    ``db_client`` objects, which must exist before this coroutine runs.
    The coroutine contains no ``await``, so once started it runs to completion
    without yielding to the event loop.

    Args:
        idx: Number used in the names of the two output files
            (``emails_{idx}.json`` and ``social_media_{idx}.json``).
        file_path: Raw JSON file to process.
    """
    companies = load_raw_file(file_path)
    platforms = ('linkedin', 'instagram', 'facebook', 'tiktok', 'youtube', 'x')

    records = []
    emails_by_company = {}
    socials_by_company = {}

    for ehraid, pages in tqdm(companies.items()):
        # Every company gets its buckets, even when none of its pages has content.
        emails_by_company[ehraid] = {'emails': set()}
        socials_by_company[ehraid] = {platform: set() for platform in platforms}

        for url, page in pages.items():
            raw_markdown = page.get('markdown')
            if not raw_markdown:
                # Pages without content are skipped entirely, links included.
                continue

            page_date = page['date']
            outbound_links = page['links']['external']

            cleaned = cleaner.clean(raw_markdown)
            chunks = embedding_creator.chunk(cleaned)
            # Language is detected once per page and repeated for every chunk.
            page_language = detect(text=cleaned)
            vectors = embedding_creator.embed(chunks)

            # Page-level fields are repeated so all columns have one entry per embedding.
            n_rows = len(vectors)
            record = {
                'ehraid': [int(ehraid)] * n_rows,
                'url': [str(url)] * n_rows,
                'date': [page_date] * n_rows,
                'text': chunks,
                'language': [page_language] * n_rows,
                'embedding': vectors,
            }
            records.append(record)

            emails_by_company, socials_by_company = structure_links(
                ehraid, outbound_links, emails_by_company, socials_by_company
            )

    db_client.insert_data(data=records)

    emails_path = PROCESSED_DATA_DIR / f'emails_{idx}.json'
    socials_path = PROCESSED_DATA_DIR / f'social_media_{idx}.json'
    store_links(emails_path, emails_by_company)
    store_links(socials_path, socials_by_company)


async def main(raw_files: list[Path], batch_size: int = 32_000):
    """Run ``run_pipeline`` over all raw files, one batch of files at a time.

    Args:
        raw_files: Raw JSON files to process.
        batch_size: Maximum number of files whose pipelines are gathered
            together before the next batch is started.
    """
    # Running index over ALL files. Numbering each batch from zero would make
    # later batches overwrite the ``emails_{idx}.json`` / ``social_media_{idx}.json``
    # files written by earlier ones.
    next_idx = 0
    for batch in create_batches(raw_files, batch_size):
        tasks = [
            run_pipeline(file_idx, file)
            for file_idx, file in enumerate(batch, start=next_idx)
        ]
        next_idx += len(batch)
        await asyncio.gather(*tasks)
    
    
if __name__ == '__main__':

    # These three objects are module-level on purpose: ``run_pipeline`` reads
    # ``cleaner``, ``embedding_creator`` and ``db_client`` as globals.
    cleaner = MarkdownCleaner()
    embedding_creator = EmbeddingCreator(model_name='intfloat/multilingual-e5-large-instruct')

    # NOTE(review): presumably 1024 is the embedding size of the model above — confirm.
    init_args = {'dim': 1024}
    db_client = DatabaseClient(**init_args)
    db_client.setup_database()

    # ``Path`` objects have no ``endswith``; compare the suffix instead.
    # Sorted so the file -> output-index mapping is reproducible between runs.
    raw_files = sorted(
        file for file in Path(RAW_DATA_DIR).iterdir() if file.suffix == '.json'
    )

    try:
        running_loop = asyncio.get_running_loop()
    except RuntimeError:
        # Plain script: no event loop is running yet, so start one.
        asyncio.run(main(raw_files))
    else:
        # An event loop is already running (e.g. inside a notebook), where
        # ``asyncio.run`` would raise. Schedule the work on that loop and keep
        # a reference so the task is not garbage-collected; ``await
        # pipeline_task`` later to wait for completion and surface errors.
        pipeline_task = running_loop.create_task(main(raw_files))