In [None]:
# Download arXiv dataset using Kaggle API

# Install required package
%pip install kaggle

# Set up Kaggle credentials and permissions
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

# Download and extract dataset
!kaggle datasets download -d Cornell-University/arxiv
!unzip arxiv.zip

# Verify file exists
import os

filename = "arxiv-metadata-oai-snapshot.json"
if os.path.exists(filename):
    print(f"Successfully downloaded {filename}")
    !ls -lh {filename}
else:
    print(f"Error: {filename} not found")


In [None]:
!pip install pandas requests backoff python-dotenv openai supabase tqdm pydantic aiohttp nest_asyncio habanero pydoi -q

In [None]:
import os
import json
import time
import logging
import asyncio
import aiohttp
import backoff
import nest_asyncio
from pathlib import Path
from typing import List, Dict, Any, Optional
from concurrent.futures import ThreadPoolExecutor
import itertools
import re
import requests
import random
from functools import partial

import pandas as pd
import numpy as np

from pydantic import BaseModel
from dotenv import load_dotenv

import openai
from openai import OpenAI

from supabase import create_client
from tqdm import tqdm

# ------------------- LOGGING ------------------- #
# Configure logging once for the notebook; only WARNING and above are emitted.
logging.basicConfig(
    level=logging.WARNING,  # Set to WARNING to suppress INFO logs
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s"
)
logger = logging.getLogger(__name__)

# ------------------- ENV & GLOBALS ------------------- #
# NOTE(review): the load_dotenv() call below is commented out, so a local .env
# file is NOT read; the settings must already exist in the process environment.
# load_dotenv()  # Load .env if present

# It's recommended to use environment variables for sensitive information
# Each of these is None when the corresponding variable is unset.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
SUPABASE_URL = os.getenv("SUPABASE_URL")
SUPABASE_SERVICE_ROLE_KEY = os.getenv("SUPABASE_SERVICE_ROLE_KEY")



# Patch asyncio so loop.run_until_complete() (used in main()) can be called
# re-entrantly, e.g. when the notebook kernel already has a running loop.
nest_asyncio.apply()

# Module-level clients shared by the functions below.
openai.api_key = OPENAI_API_KEY
supabase = create_client(SUPABASE_URL, SUPABASE_SERVICE_ROLE_KEY)

# How many lines per chunk in the JSONL file
CHUNK_SIZE = 500
# We'll wait 1s between each chunk
SLEEP_BETWEEN_CHUNKS = 1
# Progress file to track the last completed chunk index (optional for multi-machine scenario)
PROGRESS_FILE = "progress.txt"

# Default short timeout (seconds) handed to each fallback method by _fetch_journal()
FALLBACK_TIMEOUT = 1.0

# ------------------- JOURNAL FALLBACK METHODS (SHORT TIMEOUTS) ------------------- #
# Both lookup libraries are optional. When one is missing its name is bound to
# None, and the matching get_journal_via_* function returns None instead of
# the whole cell failing with ImportError.
try:
    from habanero import Crossref
except ImportError:
    logger.warning("habanero not installed. `pip install habanero` if needed.")
    Crossref = None

try:
    from pydoi import resolve as pydoi_resolve
except ImportError:
    logger.warning("pydoi not installed. `pip install pydoi` if needed.")
    pydoi_resolve = None



def get_journal_via_crossref(doi: str, timeout: float = 3.0) -> Optional[str]:
    """
    Look up a DOI's journal title through the Crossref REST API.

    Args:
        doi: Bare DOI (without a ``https://doi.org/`` prefix).
        timeout: Timeout in seconds passed to ``requests.get``.

    Returns:
        The first ``container-title`` entry of the work, or ``None`` when the
        request fails, the status is not 200, the body is not JSON, or no
        title is listed.
    """
    from urllib.parse import quote  # local import: only this lookup needs it

    # Percent-encode everything except "/" so DOIs containing "#", "?" or "%"
    # are not misread as a URL fragment or query string.
    url = f"https://api.crossref.org/works/{quote(doi, safe='/')}"
    try:
        resp = requests.get(url, timeout=timeout)
        if resp.status_code != 200:
            return None
        message = resp.json().get("message") or {}
    except (requests.RequestException, ValueError):
        # Network failure/timeout or a non-JSON body: report "no answer",
        # matching the other get_journal_via_* lookups, instead of raising.
        return None

    titles = message.get("container-title") or []
    return titles[0] if titles else None


def get_journal_via_habanero(doi: str, timeout: float = 3.0) -> Optional[str]:
    """
    Look up the journal title for ``doi`` with the habanero Crossref client.

    ``timeout`` is accepted for signature parity with the other lookups but is
    not forwarded to habanero by this function.

    Returns:
        The first ``container-title`` entry of the work, or ``None`` when
        habanero is not installed, the call raises, or no title is listed.
    """
    if Crossref is not None:
        try:
            message = Crossref().works(ids=doi)["message"]
            titles = message.get("container-title", [])
            return titles[0] if titles else None
        except Exception:
            # Best-effort lookup: any failure simply means "no answer here".
            return None
    return None


def get_journal_via_pydoi(doi: str, timeout: float = 3.0) -> Optional[str]:
    """
    Look up the journal title for ``doi`` through pydoi's ``resolve``.

    ``timeout`` is accepted for signature parity with the other lookups but is
    not forwarded to pydoi by this function.

    Returns:
        The ``container-title`` value of the resolved metadata (the first
        element when it is a non-empty list, the string itself when it is a
        string), or ``None`` when pydoi is unavailable, the call raises, or
        the key is absent.

    NOTE(review): whether pydoi's resolve() response actually carries a
    "container-title" key is not visible from this file -- confirm.
    """
    if pydoi_resolve is not None:
        try:
            container = pydoi_resolve(doi).get("container-title")
            if isinstance(container, str):
                return container
            if isinstance(container, list) and container:
                return container[0]
        except Exception:
            # Best-effort lookup: fall through to the "no answer" result.
            pass
    return None


def get_journal_via_doi2bib(doi: str, timeout: float = 3.0) -> Optional[str]:
    """
    Fetch the BibTeX entry for ``doi`` from doi2bib.org and extract its journal.

    Args:
        doi: Bare DOI appended to the doi2bib URL.
        timeout: Timeout in seconds passed to ``requests.get``.

    Returns:
        The stripped contents of the first ``journal = {...}`` field
        (matched case-insensitively), or ``None`` when the request fails,
        the status is not 200, or no such field is present.
    """
    url = f"https://doi2bib.org/bib/{doi}"
    try:
        response = requests.get(url, timeout=timeout)
        if response.status_code != 200:
            return None
        found = re.search(r'journal\s*=\s*{([^}]+)}', response.text, re.IGNORECASE)
        return found.group(1).strip() if found else None
    except Exception:
        # Best-effort scrape: any failure means "no answer here".
        return None


# Contact addresses sent as the `email` query parameter of Unpaywall requests.
# NOTE(review): the single entry below is a placeholder string, not a real
# address -- replace it with real contact emails before running.
emails = [
   "Add emails here"
]
# Endless round-robin iterator over `emails`; get_journal_via_unpaywall()
# takes the next address on every call.
email_cycle = itertools.cycle(emails)


def get_journal_via_unpaywall(doi: str, timeout: float = 3.0) -> Optional[str]:
    """
    Look up the journal name for ``doi`` through the Unpaywall v2 API.

    Each call takes the next address from ``email_cycle`` and sends it as the
    ``email`` query parameter.

    Args:
        doi: Bare DOI appended to the Unpaywall URL.
        timeout: Timeout in seconds passed to ``requests.get``.

    Returns:
        The ``journal_name`` field of the JSON response (may be ``None``), or
        ``None`` when the request fails or the status is not 200.
    """
    contact = next(email_cycle)
    endpoint = f"https://api.unpaywall.org/v2/{doi}"
    try:
        reply = requests.get(endpoint, params={"email": contact}, timeout=timeout)
        if reply.status_code == 200:
            return reply.json().get("journal_name")
    except Exception:
        # Best-effort lookup: fall through to the "no answer" result.
        pass
    return None


# Registry of journal lookups used by get_journal_multi(), which copies and
# shuffles this list on every call. Each entry is called as
# fn(doi, timeout=...) and returns a journal title or None.
ALL_JOURNAL_METHODS = [
    get_journal_via_crossref,
    get_journal_via_habanero,
    get_journal_via_pydoi,
    get_journal_via_doi2bib,
    get_journal_via_unpaywall,
]


def get_journal_multi(doi: str, fallback_timeout: float = 1.0) -> str:
    """
    Resolve a DOI to a journal title by trying every lookup in ALL_JOURNAL_METHODS.

    The lookups are tried in **random order** so a slow service is not always
    hit first. The first usable answer wins; a lookup that raises or returns
    nothing is skipped.

    Args:
        doi: A bare DOI, or one prefixed with ``https://doi.org/``,
            ``http://doi.org/``, ``https://dx.doi.org/`` or ``doi:``
            (case-insensitive). Surrounding whitespace is ignored.
        fallback_timeout: Timeout in seconds passed to each lookup.

    Returns:
        The journal title with surrounding whitespace removed.

    Raises:
        ValueError: if the DOI is empty after normalisation, or if no lookup
            produced a usable title.
    """
    # Strip whichever resolver/prefix spelling the DOI arrived with.
    normalized = re.sub(
        r"^(?:https?://(?:dx\.)?doi\.org/|doi:)", "", doi.strip(), flags=re.IGNORECASE
    ).strip()
    if not normalized:
        raise ValueError("DOI is empty or invalid.")

    methods = ALL_JOURNAL_METHODS[:]  # copy: never shuffle the shared registry
    random.shuffle(methods)  # randomize fallback order

    for fn in methods:
        try:
            jrnl = fn(normalized, timeout=fallback_timeout)
        except Exception as exc:
            # Deliberate best-effort: move on to the next lookup, but leave a
            # trace that can be enabled by lowering the log level.
            logger.debug("%s failed for %s: %s", getattr(fn, "__name__", fn), normalized, exc)
            continue

        if not isinstance(jrnl, str):
            continue
        jrnl = jrnl.strip()
        # Some services answer with a placeholder instead of "no result".
        if jrnl and jrnl.lower() not in ("unknown journal", "journal not found"):
            return jrnl

    raise ValueError("All fallback methods failed or timed out.")


# ------------------- Pydantic Models ------------------- #
class Metadata(BaseModel):
    """Per-document metadata; main() serialises it with model_dump() into each record's "metadata" field."""
    date: Optional[str] = None  # main() fills this from the record's "update_date"
    journal_ref: Optional[str] = None  # main() fills this from the record's "journal-ref"
    journal_title: Optional[str] = None  # journal title resolved from the DOI
    source: str = 'arxiv'  # origin of the record; main() always passes "arxiv"
    authors: Optional[List[List[str]]] = None  # main() fills this from "authors_parsed" (a list of string lists)
    categories: Optional[List[str]] = None  # the record's "categories" string split on whitespace


# ------------------- EMBEDDINGS MANAGER ------------------- #
class EmbeddingsManager:
    """Small wrapper around the OpenAI embeddings endpoint with a default model."""

    def __init__(self, openai_api_key: Optional[str] = None, default_model: str = "text-embedding-3-small"):
        """
        Create the OpenAI client used for embedding requests.

        Args:
            openai_api_key: API key to use. When omitted or empty, the
                ``OPENAI_API_KEY`` environment variable is used instead.
            default_model: Embedding model used when ``get_embeddings`` is
                called without an explicit ``model``.

        Raises:
            ValueError: if no key is passed and ``OPENAI_API_KEY`` is unset or empty.
        """
        # Credentials must come from the caller or the environment. Never put
        # a literal key in source: it leaks to everyone who can read the notebook.
        openai_api_key = openai_api_key or os.getenv("OPENAI_API_KEY")
        if not openai_api_key:
            raise ValueError(
                "OpenAI API key missing: pass openai_api_key or set the "
                "OPENAI_API_KEY environment variable."
            )
        openai.api_key = openai_api_key
        self.client = OpenAI(api_key=openai_api_key)
        self.default_model = default_model

    def get_embeddings(self, text: str, model: Optional[str] = None) -> List[float]:
        """
        Return the embedding vector for ``text``.

        Newlines are replaced by spaces and surrounding whitespace is removed
        before the request is sent.

        Args:
            text: Text to embed.
            model: Embedding model name; defaults to ``self.default_model``.

        Returns:
            The embedding of the first (only) item in the response.
        """
        if model is None:
            model = self.default_model
        text = text.replace("\n", " ").strip()
        response = self.client.embeddings.create(model=model, input=text)
        return response.data[0].embedding


# Shared manager for the whole notebook. The key read from the environment is
# passed explicitly so the configured OPENAI_API_KEY is the one actually used.
embeddings_manager = EmbeddingsManager(openai_api_key=OPENAI_API_KEY)


def _get_embedding_for_text(text: str) -> List[float]:
    """Embed ``text`` with the shared manager's default model (worker for executor.map in main())."""
    return embeddings_manager.get_embeddings(text)


# ------------------- SUPABASE INSERT ------------------- #
def batch_insert_documents(df: pd.DataFrame, batch_size: int = 100) -> None:
    """
    Insert the rows of ``df`` into the Supabase ``documents`` table in sub-batches.

    Each row contributes the columns ``title``, ``content``, ``embedding``,
    ``url`` and ``metadata``; a ``metadata`` value that is not a dict is
    replaced by ``{}``.

    Insert failures are logged and the remaining sub-batches are still
    attempted; nothing is raised, so the caller cannot tell from the return
    value whether every sub-batch was stored.

    Args:
        df: Records to insert.
        batch_size: Maximum number of rows per insert call; must be >= 1.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer.")

    for batch_no, offset in enumerate(range(0, len(df), batch_size), start=1):
        batch_data = []
        for _, row in df.iloc[offset : offset + batch_size].iterrows():
            metadata = row.get("metadata")
            batch_data.append({
                "title": row.get("title", "Untitled"),
                "content": row.get("content", ""),
                "embedding": row.get("embedding", []),
                "url": row.get("url", None),
                "metadata": metadata if isinstance(metadata, dict) else {},
            })
        try:
            response = supabase.table("documents").insert(batch_data).execute()
            if not response.data:
                logger.warning(f"No data returned for sub-batch {batch_no}")
        except Exception as e:
            # Best-effort: one failed sub-batch must not stop the others.
            logger.error(f"Error inserting sub-batch {batch_no}: {str(e)}")


# ------------------- PROGRESS HELPER ------------------- #
def load_progress() -> int:
    """
    Return the chunk index stored in PROGRESS_FILE.

    Falls back to 0 when the file does not exist, cannot be read, or does not
    contain an integer.
    """
    if os.path.exists(PROGRESS_FILE):
        try:
            with open(PROGRESS_FILE, "r") as handle:
                return int(handle.read().strip())
        except Exception:
            # Unreadable or non-numeric content counts as "no progress yet".
            return 0
    return 0


def save_progress(chunk_index: int):
    """
    Saves the last completed chunk index to PROGRESS_FILE.
    Overwrites previous content.

    The value is written to a sibling temporary file first and then moved
    over PROGRESS_FILE with os.replace(), so an interruption mid-write cannot
    leave an empty or truncated progress file (which load_progress() would
    read as 0 and restart from the first chunk).
    """
    tmp_path = f"{PROGRESS_FILE}.tmp"
    with open(tmp_path, "w") as f:
        f.write(str(chunk_index))
    os.replace(tmp_path, PROGRESS_FILE)


# ------------------- ASYNC WRAPPER FOR JOURNAL LOOKUPS ------------------- #
@backoff.on_exception(backoff.expo, (aiohttp.ClientError, asyncio.TimeoutError), max_tries=3)
async def _fetch_journal(session: aiohttp.ClientSession, doi: str) -> str:
    """
    Run the blocking get_journal_multi() for one DOI in the default thread
    executor so the event loop is not blocked.

    Args:
        session: Not used by this function; kept because callers pass it.
        doi: DOI to resolve.

    Returns:
        The journal title returned by get_journal_multi().

    Raises:
        ValueError: propagated from get_journal_multi() when the DOI is empty
            or every lookup failed. ValueError is not in the retried set
            (aiohttp.ClientError, asyncio.TimeoutError), so these failures are
            not retried by the decorator.
    """
    # Inside a coroutine the running loop is the one to schedule work on.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, partial(get_journal_multi, doi, FALLBACK_TIMEOUT))


async def fetch_journals_in_bulk(dois: List[str]) -> Dict[str, str]:
    """
    Resolve many DOIs concurrently.

    Each distinct DOI is looked up once, even if it occurs several times in
    ``dois``.

    Args:
        dois: DOIs to resolve, exactly as they should appear as result keys.

    Returns:
        ``{original_doi: journal_or_error_msg}``. When the lookup for a DOI
        raised, the value is the string ``"Error: <exception>"`` so callers
        can recognise and skip it.
    """
    # dict.fromkeys() removes duplicates while keeping first-seen order.
    unique_dois = list(dict.fromkeys(dois))
    if not unique_dois:
        return {}

    async with aiohttp.ClientSession() as session:
        tasks = [asyncio.ensure_future(_fetch_journal(session, doi)) for doi in unique_dois]
        # return_exceptions=True: one failed DOI must not cancel the others.
        fetch_results = await asyncio.gather(*tasks, return_exceptions=True)

    results: Dict[str, str] = {}
    for doi, result in zip(unique_dois, fetch_results):
        # BaseException also covers asyncio.CancelledError, which gather() can
        # hand back as a result and which is not an Exception subclass.
        if isinstance(result, BaseException):
            results[doi] = f"Error: {result}"
        else:
            results[doi] = result

    return results


# ------------------- READ JSONL IN CHUNKS ------------------- #
def read_jsonl_in_chunks(file_path: Path | str, chunk_size: int = 100):
    """
    Generator that yields chunks of data (list of dicts),
    each chunk up to 'chunk_size' lines from the JSONL file.
    """
    batch = []
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                continue
            batch.append(record)
            if len(batch) >= chunk_size:
                yield batch
                batch = []
    if batch:
        yield batch


# ------------------- MAIN PIPELINE ------------------- #
def _clean_field(value: Any) -> Any:
    """Map pandas' missing-value marker (float NaN) to None; return anything else unchanged."""
    if isinstance(value, float) and value != value:  # NaN is the only float not equal to itself
        return None
    return value


def main(
    file_path: str = "arxiv-metadata-oai-snapshot.json",
    start_chunk: int = 1,
    end_chunk: Optional[int] = None
):
    """
    Run the ingestion pipeline over a JSONL snapshot, CHUNK_SIZE records at a time.

    For each chunk inside the requested range that is newer than the saved
    progress:
      - Fetch journal info concurrently for every record that has a DOI
      - Drop records with no DOI or no usable journal title
      - Generate embeddings for the abstracts in a thread pool
      - Batch-insert the documents into Supabase
      - Save the chunk number to the progress file
      - Sleep SLEEP_BETWEEN_CHUNKS seconds

    Args:
        file_path: Path to the JSONL file.
        start_chunk: First chunk number (1-based) to process.
        end_chunk: Last chunk number to process; ``None`` means "to the end".

    Chunks numbered at or below the value returned by load_progress() are
    skipped even when they fall inside [start_chunk, end_chunk].

    NOTE: batch_insert_documents() logs insert errors instead of raising, so
    progress is saved for a chunk even when some of its inserts failed.
    """
    # Load progress
    current_progress = load_progress()

    chunk_counter = 0
    last_completed: Optional[int] = None  # last chunk finished during THIS run
    for batch_data in read_jsonl_in_chunks(file_path, chunk_size=CHUNK_SIZE):
        chunk_counter += 1

        # Stop as soon as we are past end_chunk -- checked first so that a
        # progress value beyond end_chunk does not keep the file scan going.
        if end_chunk is not None and chunk_counter > end_chunk:
            break

        # Skip already processed chunks and chunks below start_chunk
        if chunk_counter <= current_progress or chunk_counter < start_chunk:
            continue

        # Print current chunk being processed
        print(f"Processing chunk #{chunk_counter}")

        chunk_start = time.time()

        # ---------------- Step 1: Convert to DF & Basic Cleanup ---------------- #
        print("Processing Step: Convert + Cleanup")
        df = pd.DataFrame(batch_data)
        if "doi" not in df.columns:
            df["doi"] = ""

        df["doi"] = df["doi"].fillna("").astype(str).str.strip()
        valid_dois = df.loc[df["doi"] != "", "doi"].tolist()

        # ---------------- Step 2: Fetch Journals Concurrently ---------------- #
        doi_to_journal = {}
        if valid_dois:
            loop = asyncio.get_event_loop()
            doi_to_journal = loop.run_until_complete(fetch_journals_in_bulk(valid_dois))

        # ---------------- Step 3: Filter + Build Processed Data ---------------- #
        print("Processing Step: Filter + Build Processed Data")
        processed_data = []
        for _, row in df.iterrows():
            doi_str = row["doi"]
            if not doi_str:
                continue  # skip if no doi

            jrnl = doi_to_journal.get(doi_str, None)
            if not jrnl or jrnl.startswith("Error:"):
                continue

            jrnl_clean = jrnl.strip()
            if not jrnl_clean or jrnl_clean.lower() in ("unknown journal", "journal not found"):
                continue

            url_str = doi_str if doi_str.startswith("https://doi.org/") else f"https://doi.org/{doi_str}"

            # Fields absent from some records show up as NaN in the DataFrame.
            # NaN would fail Metadata validation (or .split()), so map it to None.
            categories = _clean_field(row.get("categories"))
            processed_data.append({
                "title": _clean_field(row.get("title", "")),
                "content": _clean_field(row.get("abstract", "")),
                "url": url_str,
                "metadata": Metadata(
                    date=_clean_field(row.get("update_date")),
                    journal_ref=_clean_field(row.get("journal-ref")),
                    journal_title=jrnl_clean,
                    source="arxiv",
                    authors=_clean_field(row.get("authors_parsed")),
                    categories=categories.split() if categories else None
                ).model_dump(),
            })

        successful_docs = len(processed_data)
        print(f"Number of successful docs processed: {successful_docs}")

        if not processed_data:
            # Update progress and continue to next chunk
            save_progress(chunk_counter)
            last_completed = chunk_counter
            print("-----")
            time.sleep(SLEEP_BETWEEN_CHUNKS)
            continue

        processed_df = pd.DataFrame(processed_data)

        # ---------------- Step 4: Generate Embeddings Concurrency ---------------- #
        print("Processing Step: Generate Embeddings")
        contents = processed_df["content"].fillna("").tolist()
        with ThreadPoolExecutor(max_workers=5) as executor:
            embeddings_list = list(
                tqdm(
                    executor.map(_get_embedding_for_text, contents),
                    total=len(contents),
                    desc=f"Embedding chunk {chunk_counter}",
                    disable=True  # Disable tqdm progress bar
                )
            )
        processed_df["embedding"] = embeddings_list

        # ---------------- Step 5: Insert into Supabase ---------------- #
        batch_insert_documents(processed_df, batch_size=100)

        # Update progress after processing
        save_progress(chunk_counter)
        last_completed = chunk_counter

        # Only visible when the log level is lowered to INFO.
        logger.info("Chunk %d finished in %.1fs", chunk_counter, time.time() - chunk_start)

        # Add separator after processing the chunk
        print("-----")

        time.sleep(SLEEP_BETWEEN_CHUNKS)

    # Final message: report the last chunk actually handled, not the loop
    # counter (which is end_chunk + 1 after a break).
    if last_completed is None:
        print("All done! No new chunks were processed.")
    else:
        print(f"All done! Processed up to chunk #{last_completed}.")


# Entry point: runs when the cell/script is executed directly.
if __name__ == "__main__":
    # Inclusive range of chunk numbers to process; each chunk holds up to
    # CHUNK_SIZE records. Chunks at or below the saved progress are skipped.
    start = 1
    end = 5000
    main("arxiv-metadata-oai-snapshot.json", start_chunk=start, end_chunk=end)
