In [None]:
!pip install trafilatura sentence-transformers torch pandas pyarrow duckdb scipy -q
!pip install fireducks

In [4]:
#
# --- Step 2: Import Libraries ---
#
import gradio as gr
import duckdb
# import pandas as pd
import fireducks.pandas as pd

import numpy as np
import os
import re
import io
import logging
import time
import trafilatura
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from abc import ABC, abstractmethod
from dataclasses import dataclass
import warnings

# Silence FutureWarning messages originating from huggingface_hub.file_download,
# a warning commonly seen when using the sentence-transformers library.
warnings.filterwarnings(
    action="ignore",
    category=FutureWarning,
    module="huggingface_hub.file_download",
)


#
# --- Step 3: Configuration & Core Interfaces ---
#

@dataclass
class EmbeddingConfig:
    """Settings for one pipeline run: input location, output location, model, batching.

    The path defaults point below /content/drive, where the launcher mounts
    Google Drive.
    """
    # Folder holding the crawled pages as Parquet files.
    input_path: str = "/content/drive/My Drive/master_july_2025/data/crawled_data_parquet/"
    # Folder that receives the Parquet files of generated embeddings.
    output_path: str = "/content/drive/My Drive/master_july_2025/data/url_embeddings/"
    # SentenceTransformer model name. Switch to a multilingual model when the
    # content is not in English.
    model_name: str = "all-MiniLM-L6-v2"
    # Number of pages handled per batch.
    batch_size: int = 10

class ILogger(ABC):
    """Abstract logging contract used by the pipeline components."""

    @abstractmethod
    def info(self, message: str):
        """Record an informational message."""

    @abstractmethod
    def error(self, message: str):
        """Record an error message."""

    @abstractmethod
    def exception(self, message: str):
        """Record a message about an exception that is being handled."""

class ConsoleAndGradioLogger(ILogger):
    """Logs messages to the console and to a text stream displayed in the Gradio UI.

    Every instance reuses the process-wide logger named "EmbeddingLogger", so
    constructing a new instance replaces the handlers installed by the
    previous one.
    """
    _LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'

    def __init__(self, log_output_stream: io.StringIO, level=logging.INFO):
        """Attach a console handler and a handler writing to ``log_output_stream``.

        Args:
            log_output_stream: In-memory stream whose contents the UI displays.
            level: Minimum level accepted by the logger.
        """
        self._logger = logging.getLogger("EmbeddingLogger")
        self._logger.setLevel(level)

        # Drop handlers left over from an earlier instance (they write to a
        # stale stream). Inspect this logger's own handler list: hasHandlers()
        # would also report handlers attached to ancestor loggers.
        for stale_handler in list(self._logger.handlers):
            self._logger.removeHandler(stale_handler)
            stale_handler.close()

        # Both destinations are attached directly below. Without this, records
        # also travel up to the root logger and are printed a second time
        # whenever the root logger has a handler of its own.
        self._logger.propagate = False

        formatter = logging.Formatter(self._LOG_FORMAT)
        # StreamHandler(None) writes to sys.stderr, i.e. the console.
        for stream in (None, log_output_stream):
            handler = logging.StreamHandler(stream)
            handler.setFormatter(formatter)
            self._logger.addHandler(handler)

    def info(self, message: str):
        """Log ``message`` at INFO level."""
        self._logger.info(message)

    def error(self, message: str):
        """Log ``message`` at ERROR level."""
        self._logger.error(message)

    def exception(self, message: str):
        """Log ``message`` at ERROR level together with the active traceback."""
        self._logger.exception(message)

#
# --- Step 4: Component Classes (Single Responsibility Principle) ---
#

class EmbeddingStateManager:
    """Manages the state of the embedding process, enabling resumes."""
    def __init__(self, output_path: str, logger: ILogger):
        """Remember the output directory to scan and the logger to report to."""
        self.output_path = output_path
        self.logger = logger

    def get_processed_urls(self) -> set:
        """Return the set of URLs already present in the output Parquet files.

        Creates the output directory when it does not exist. Reading is
        best-effort: if the existing files cannot be read, the failure is
        logged as an error and an empty set is returned, so the run continues
        from scratch instead of aborting.

        Returns:
            The distinct values of the ``URL`` column across all
            ``*.parquet`` files in the output directory (possibly empty).
        """
        processed_urls = set()
        if not os.path.exists(self.output_path):
            # exist_ok guards against the directory appearing between the
            # check above and this call.
            os.makedirs(self.output_path, exist_ok=True)
            self.logger.info("Output directory created.")
            return processed_urls

        try:
            # Distinguish "nothing saved yet" from "saved files are unreadable":
            # only the former is a normal fresh start.
            has_saved_batches = any(
                name.endswith('.parquet') for name in os.listdir(self.output_path)
            )
            if not has_saved_batches:
                self.logger.info("No previously processed embeddings found. Starting fresh.")
                return processed_urls

            output_glob_path = os.path.join(self.output_path, '*.parquet')
            # The path is embedded in a SQL string literal, so single quotes
            # must be doubled to keep the statement valid.
            sql_path = output_glob_path.replace("'", "''")
            # Use DuckDB for efficient scanning of existing results
            processed_df = duckdb.query(f"SELECT DISTINCT URL FROM read_parquet('{sql_path}')").to_df()
            processed_urls = set(processed_df['URL'])
            if processed_urls:
                self.logger.info(f"Found {len(processed_urls)} URLs that have already been processed. They will be skipped.")
        except Exception as e:
            # Keep the run alive, but make the failure visible: previously
            # embedded URLs may be processed again.
            self.logger.error(
                f"Could not read previously saved embeddings; starting fresh, "
                f"so already processed URLs may be embedded again: {e}"
            )
        return processed_urls

class DataLoader:
    """Responsible for loading unprocessed data in batches."""
    def __init__(self, input_path: str, logger: ILogger):
        """Store the input folder and logger and open a DuckDB connection."""
        self.input_path = input_path
        self.logger = logger
        self.con = duckdb.connect()

    def stream_unprocessed_data(self, processed_urls: set, batch_size: int):
        """Yield batches of pages that still need embeddings.

        Reads every ``*.parquet`` file below the input folder (recursively) and
        keeps rows whose ``Status_Code`` is in the 2xx range and whose
        ``Content`` is neither NULL nor empty. Rows whose ``URL`` is in
        ``processed_urls`` are excluded.

        Args:
            processed_urls: URLs to skip; an empty set disables the filter.
            batch_size: Row count passed to ``fetch_record_batch``.

        Yields:
            One DataFrame per record batch with the columns ``URL`` and
            ``Content``.

        Raises:
            Exception: Any query or read failure is logged and re-raised, so a
                failed query is not mistaken by the caller for an empty
                (fully processed) dataset.
        """
        input_glob_path = os.path.join(self.input_path, '**', '*.parquet')
        # The path is embedded in a SQL string literal; double any single
        # quotes so the statement stays valid.
        sql_path = input_glob_path.replace("'", "''")
        base_query = f"SELECT URL, Content FROM read_parquet('{sql_path}') WHERE Status_Code >= 200 AND Status_Code < 300 AND Content IS NOT NULL AND Content != ''"

        if processed_urls:
            # The query below refers to this DataFrame by name. DuckDB's Python
            # client resolves unknown table names against Python variables
            # (replacement scan), so the local must keep exactly this name.
            processed_urls_df = pd.DataFrame(list(processed_urls), columns=['URL'])

            # Anti-join: keep only the rows of t1 that have no matching URL in
            # the processed set (LEFT JOIN + IS NULL).
            final_query = f"""
                SELECT t1.URL, t1.Content
                FROM ({base_query}) AS t1
                LEFT JOIN processed_urls_df AS t2 ON t1.URL = t2.URL
                WHERE t2.URL IS NULL
            """
        else:
            final_query = base_query

        self.logger.info("Querying for new pages to process...")
        try:
            # Use fetch_record_batch for memory-efficient iteration
            for batch in self.con.execute(final_query).fetch_record_batch(batch_size):
                yield batch.to_pandas()
        except Exception as e:
            self.logger.error(f"Could not query Parquet files. Please check the input path: {e}")
            # Ending the generator quietly would look like "no new data".
            raise

class TextExtractor:
    """Turns raw HTML into cleaned text."""
    def extract(self, html_content: str) -> str:
        """Return the text trafilatura extracts from ``html_content``.

        Comments and tables are excluded and deduplication is enabled. Blank
        lines (including whitespace-only ones) between paragraphs are collapsed
        to a single empty line and the result is stripped.

        Returns:
            The cleaned text, or "" when the input is falsy, is not a string,
            or yields no extractable text.
        """
        # Guard: nothing to do for empty or non-string input.
        if not html_content or not isinstance(html_content, str):
            return ""

        extracted = trafilatura.extract(
            html_content,
            include_comments=False,
            include_tables=False,
            deduplicate=True,
        )
        if not extracted:
            return ""

        normalized = re.sub(r'\n\s*\n', '\n\n', extracted)
        return normalized.strip()

class EmbeddingGenerator:
    """Wraps a SentenceTransformer model that encodes texts into embeddings."""
    def __init__(self, model_name: str, logger: ILogger):
        """Load the model called ``model_name``, logging before and after."""
        self.logger = logger
        logger.info(f"Loading embedding model: {model_name}...")
        self.model = SentenceTransformer(model_name)
        logger.info("Model loaded successfully.")

    def generate(self, texts: list[str]) -> np.ndarray:
        """Encode ``texts`` with the loaded model and return the result as NumPy data.

        A progress bar is shown while encoding.
        """
        text_count = len(texts)
        self.logger.info(f"Generating embeddings for {text_count} texts...")
        vectors = self.model.encode(
            texts,
            show_progress_bar=True,
            convert_to_numpy=True,
        )
        return vectors

class DataSaver:
    """Saves a batch of embeddings to a Parquet file."""
    def __init__(self, output_path: str, logger: ILogger):
        """Remember the output directory and the logger to report to."""
        self.output_path = output_path
        self.logger = logger

    def save_batch(self, df_batch: pd.DataFrame, batch_num: int):
        """Save a DataFrame of URLs and embeddings to a uniquely named file.

        The file name combines the current Unix time (seconds) with
        ``batch_num``. The data is first written under a temporary ``.tmp``
        name and then renamed, so a write that is interrupted part-way never
        leaves a truncated ``*.parquet`` file in the output directory (which
        would break later scans of the saved results).

        Args:
            df_batch: Data to write (written without the index).
            batch_num: Batch number used in the file name and the log line.

        Raises:
            Exception: Whatever the write or rename raised; the temporary file
                is removed first when possible.
        """
        batch_filename = f"embeddings_batch_{int(time.time())}_{batch_num}.parquet"
        batch_output_path = os.path.join(self.output_path, batch_filename)
        # The ".tmp" suffix keeps the partial file out of "*.parquet" globs.
        temp_output_path = batch_output_path + ".tmp"
        try:
            df_batch.to_parquet(temp_output_path, index=False)
            os.replace(temp_output_path, batch_output_path)
        except BaseException:
            # BaseException so that a KeyboardInterrupt during the write also
            # cleans up. Removal is best-effort: the file may not exist yet.
            try:
                os.remove(temp_output_path)
            except OSError:
                pass
            raise
        self.logger.info(f"✅ Saved batch {batch_num} to {batch_filename}")

#
# --- Step 5: The Main Pipeline Orchestrator ---
#

class EmbeddingPipeline:
    """Orchestrates the entire embedding generation process.

    For every batch of unprocessed pages: extract clean text, drop pages with
    too little text, generate embeddings, and save them.
    """
    # Pages whose cleaned text is not longer than this many characters are dropped.
    MIN_TEXT_LENGTH = 100

    def __init__(self, config: EmbeddingConfig, logger: ILogger, state_manager: EmbeddingStateManager,
                 data_loader: DataLoader, text_extractor: TextExtractor,
                 embedding_generator: EmbeddingGenerator, data_saver: DataSaver):
        """Store the configuration and the collaborating components."""
        self.config = config
        self.logger = logger
        self.state_manager = state_manager
        self.data_loader = data_loader
        self.text_extractor = text_extractor
        self.embedding_generator = embedding_generator
        self.data_saver = data_saver

    def run(self):
        """Execute the pipeline as a generator of human-readable status strings.

        Yields:
            Progress messages, ending with "Finished", "Already up to date.",
            or "Error: <details>" when any step raised. Exceptions are logged
            and reported through the final status instead of propagating.
        """
        try:
            yield "Initializing..."
            processed_urls = self.state_manager.get_processed_urls()

            yield "Loading model and querying data..."
            data_stream = self.data_loader.stream_unprocessed_data(processed_urls, self.config.batch_size)

            processed_in_this_session = False
            # Number every batch that is read, including ones that end up
            # skipped, so consecutive status lines never repeat a number.
            for batch_num, df_batch in enumerate(data_stream, start=1):
                processed_in_this_session = True
                status_msg = f"Processing Batch {batch_num} ({len(df_batch)} pages)..."
                self.logger.info(status_msg)
                yield status_msg

                # Extract Text
                df_batch['clean_text'] = [self.text_extractor.extract(html) for html in tqdm(df_batch['Content'], desc="Extracting Text")]
                df_batch = df_batch[df_batch['clean_text'].str.len() > self.MIN_TEXT_LENGTH]

                if df_batch.empty:
                    self.logger.info("Batch had no pages with sufficient text after cleaning.")
                    continue

                # Generate Embeddings
                embeddings = self.embedding_generator.generate(df_batch['clean_text'].tolist())

                # Save Batch: one row per remaining page, embedding stored as a plain list.
                output_df = pd.DataFrame({'URL': df_batch['URL'], 'Embedding': [e.tolist() for e in embeddings]})
                self.data_saver.save_batch(output_df, batch_num)

            if not processed_in_this_session:
                self.logger.info("No new pages to process. The dataset is already up to date.")
                yield "Already up to date."
            else:
                self.logger.info("All new batches processed successfully.")
                yield "Finished"

        except Exception as e:
            self.logger.exception(f"A critical pipeline error occurred: {e}")
            yield f"Error: {e}"


#
# --- Step 6: Gradio UI and Main Execution Logic ---
#

def run_gradio_interface(input_path: str, output_path: str, batch_size: int):
    """Wire up all components, run the pipeline, and stream updates to the UI.

    Args:
        input_path: Folder containing the input Parquet files.
        output_path: Folder that receives the embedding Parquet files.
        batch_size: Pages per batch; coerced to ``int`` because a UI slider may
            deliver the value as a float.

    Yields:
        ``(status, log_text, summary_markdown)`` tuples. The summary is empty
        while the pipeline is running and is filled in by the final tuple.
    """
    log_stream = io.StringIO()
    logger = ConsoleAndGradioLogger(log_stream)

    config = EmbeddingConfig(input_path=input_path, output_path=output_path, batch_size=int(batch_size))

    # Instantiate all our components
    state_manager = EmbeddingStateManager(config.output_path, logger)
    data_loader = DataLoader(config.input_path, logger)
    text_extractor = TextExtractor()
    embedding_generator = EmbeddingGenerator(config.model_name, logger)
    data_saver = DataSaver(config.output_path, logger)

    pipeline = EmbeddingPipeline(config, logger, state_manager, data_loader, text_extractor, embedding_generator, data_saver)

    final_status = "Initializing..."
    for status in pipeline.run():
        final_status = status
        # Yield the current status and the full log content
        yield status, log_stream.getvalue(), ""

    # The pipeline reports failures as a final status of the form "Error: ...";
    # do not present such a run as a success in the summary.
    failed = final_status.startswith("Error:")

    # Generate final summary after the pipeline finishes
    try:
        output_glob_path = os.path.join(output_path, '*.parquet')
        # The path is embedded in a SQL string literal; double single quotes.
        sql_path = output_glob_path.replace("'", "''")
        total_embeddings = duckdb.query(f"SELECT COUNT(URL) FROM read_parquet('{sql_path}')").fetchone()[0]
        heading = "### ❌ Pipeline Failed" if failed else "### ✅ Pipeline Finished"
        summary_md = f"{heading}\n\n- **Final Status:** {final_status}\n- **Total embeddings saved:** {total_embeddings}\n- **Output location:** `{output_path}`"
    except Exception as e:
        if failed:
            summary_md = f"### ❌ Pipeline Failed\n\n- **Final Status:** {final_status}\n- Could not generate summary. Error: {e}"
        else:
            summary_md = f"### Pipeline Finished\n\n- Could not generate summary. Error: {e}"

    yield final_status, log_stream.getvalue(), summary_md


# Declarative UI layout; components appear in the order they are created.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Page header
    gr.Markdown("# 🤖 Resumable Embedding Pipeline")
    gr.Markdown("This tool reads HTML from Parquet files, cleans it, generates embeddings, and saves the results in batches. It can be stopped and resumed at any time.")

    with gr.Row():
        # Left column: inputs, pre-filled with the EmbeddingConfig defaults
        with gr.Column(scale=1):
            gr.Markdown("## 1. Configuration")
            input_path_box = gr.Textbox(label="Input Parquet Folder Path", value=EmbeddingConfig.input_path)
            output_path_box = gr.Textbox(label="Output Embeddings Directory Path", value=EmbeddingConfig.output_path)
            batch_size_slider = gr.Slider(
                label="Batch Size",
                info="How many pages to process in memory at a time.",
                minimum=10,
                maximum=50,
                step=10,
                value=EmbeddingConfig.batch_size,
            )
            start_button = gr.Button("🚀 Start/Resume Embedding Generation", variant="primary")

        # Right column: live status, log text, and the final summary
        with gr.Column(scale=2):
            gr.Markdown("## 2. Status & Results")
            status_output = gr.Textbox(label="Current Status", interactive=False)
            log_output = gr.Textbox(label="Detailed Logs", interactive=False, lines=10, max_lines=20)
            summary_output = gr.Markdown("---")

    # Each tuple yielded by run_gradio_interface updates the three outputs in order.
    ui_inputs = [input_path_box, output_path_box, batch_size_slider]
    ui_outputs = [status_output, log_output, summary_output]
    start_button.click(fn=run_gradio_interface, inputs=ui_inputs, outputs=ui_outputs)

#
# --- Launch the Application ---
#
if __name__ == '__main__':
    # Mounting Google Drive and launching the UI are separate concerns: a
    # missing Colab runtime or a failed mount is reported as such and no longer
    # prevents the demo from starting (paths can be edited in the UI).
    try:
        from google.colab import drive
    except ImportError:
        drive = None
        print("google.colab is not available; skipping Google Drive mount.")

    if drive is not None:
        try:
            drive.mount('/content/drive/', force_remount=True)
        except Exception as e:
            print("Could not mount Google Drive.")
            print(e)

    try:
        demo.launch(debug=True, share=True)
    except Exception as e:
        print("Could not launch Gradio demo in this environment.")
        print(e)

Keyboard interruption in main thread... closing server.


KeyboardInterrupt: 