In [None]:
import os
from functools import partial
from pathlib import Path
import ast

from langchain.document_loaders import ReadTheDocsLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_experimental.text_splitter import SemanticChunker
# from langchain.embeddings import OpenAIEmbeddings
# from langchain.embeddings.huggingface import HuggingFaceEmbeddings
import numpy as np
import psycopg2
from pgvector.psycopg2 import register_vector
# from sentence_transformers import SentenceTransformer
from langchain_huggingface import HuggingFaceEmbeddings
import ray
from ray.data import ActorPoolStrategy

In [None]:
# Shut down any Ray runtime left over from an earlier run of this cell, then
# start a new one with Ray's logging limited to errors.
ray.shutdown()
ray.init(logging_level="error")

### Reading raw data from the `doc/` folder

In [None]:
# Root of the project checkout. Every `source` value stored with a document is
# a path relative to this directory. Set the CS230_BASE_PATH environment
# variable to run the notebook on another machine without editing it.
BASE_PATH = os.environ.get("CS230_BASE_PATH", "/Users/guohaorui/Projects/cs230/")

# The pandas reStructuredText documentation tree, derived from BASE_PATH so the
# two locations cannot drift apart.
DOCS_DIR = Path(BASE_PATH) / "pandas" / "doc"

# Every .rst file under DOCS_DIR, skipping any path that contains "_static".
files_path = [str(file) for file in DOCS_DIR.rglob("*.rst") if "_static" not in str(file)]

In [None]:
@ray.remote
def read_file(file_path):
    """Read one documentation file as a Ray task.

    Args:
        file_path: Path of the text file to read.

    Returns:
        dict: ``{'source': ..., 'text': ...}`` where ``source`` is ``file_path``
        relative to ``BASE_PATH`` and ``text`` is the full file content.
    """
    # Decode explicitly as UTF-8 (as extract_docstrings_from_file does) so the
    # result does not depend on the default encoding of the worker's locale.
    with open(file_path, "r", encoding="utf-8") as f:
        text = f.read()
    return {'source': os.path.relpath(file_path, BASE_PATH), 'text': text}

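Note: this is not very efficient right now — each file is read by its own Ray task, and all results are collected on the driver with `ray.get` before the dataset is built.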

In [None]:
# Launch one read_file task per path, wait for all of them, and build a Ray
# Dataset from the resulting {'source', 'text'} records.
ds1 = ray.data.from_items(
    ray.get([read_file.remote(rst_path) for rst_path in files_path])
)

In [None]:
# Number of records in ds1 (one per .rst file that was read).
ds1.count()

### Reading inline API documentation data

In [None]:
# The pandas Python package sources, located relative to BASE_PATH (the same
# root that `source` values are made relative to) instead of a second
# hard-coded absolute path.
DOCS_DIR = Path(BASE_PATH) / "pandas" / "pandas"

# Every .py file under the package, skipping any path that contains "_testing".
files_paths = [str(file) for file in DOCS_DIR.rglob("*.py") if "_testing" not in str(file)]

In [None]:
@ray.remote
def extract_docstrings_from_file(file_path):
    """Extract the docstrings of all functions and classes in a Python file.

    The file is parsed with ``ast``; every function (sync or async) and class
    definition found by ``ast.walk`` that has a docstring yields one record
    whose text is a short signature followed by the docstring.

    Args:
        file_path: Path of the Python source file to parse.

    Returns:
        list[dict]: One ``{'source': ..., 'text': ...}`` dict per documented
        definition, where ``source`` is ``file_path`` relative to ``BASE_PATH``
        and ``text`` is ``"<signature>: <docstring>"``.

    Raises:
        SyntaxError: Propagated from ``ast.parse`` if the file cannot be parsed.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        source_code = f.read()

    tree = ast.parse(source_code)
    docstrings = []

    for node in ast.walk(tree):
        if isinstance(node, ast.FunctionDef):
            signature = f"def {node.name}({_format_parameters(node.args)})"
        elif isinstance(node, ast.AsyncFunctionDef):
            signature = f"async def {node.name}({_format_parameters(node.args)})"
        elif isinstance(node, ast.ClassDef):
            signature = f"class {node.name}()"
        else:
            # Only function and class definitions are collected.
            continue

        docstring = ast.get_docstring(node)
        if docstring:
            docstrings.append({
                'source': os.path.relpath(file_path, BASE_PATH),
                'text': f"{signature}: {docstring}"
            })

    return docstrings


def _format_parameters(arguments):
    """Render every parameter name of an ``ast.arguments`` node as ``"a, b, *, c"``."""
    names = [arg.arg for arg in arguments.posonlyargs]
    if names:
        # Marker that closes the positional-only section, as in real signatures.
        names.append("/")
    names.extend(arg.arg for arg in arguments.args)
    if arguments.vararg is not None:
        names.append(f"*{arguments.vararg.arg}")
    elif arguments.kwonlyargs:
        # A bare "*" introduces keyword-only parameters when there is no *args.
        names.append("*")
    names.extend(arg.arg for arg in arguments.kwonlyargs)
    if arguments.kwarg is not None:
        names.append(f"**{arguments.kwarg.arg}")
    return ", ".join(names)

# One Ray task per source file; block until every task has finished.
docstring_tasks = [
    extract_docstrings_from_file.remote(py_path) for py_path in files_paths
]
docstring_results = ray.get(docstring_tasks)

# Each task returns a list of records, so merge the per-file lists into one
# flat list before building the dataset.
flattened_results = [
    record
    for file_records in docstring_results
    for record in file_records
]
ds2 = ray.data.from_items(flattened_results)

In [None]:
# Combine the docstring records (ds2) and the .rst documentation files (ds1)
# into a single dataset for the chunking steps.
ds = ds2.union(ds1)

### Chunking

#### 1st approach: recursive character splitting (chunk size 300, overlap 50)

In [None]:
def chunk_section_1st_approach(section, chunk_size=300, chunk_overlap=50):
    """Split one document into chunks with ``RecursiveCharacterTextSplitter``.

    The splitter tries the separators ``"\\n\\n"``, ``"\\n"``, ``" "`` and ``""``
    in that order and measures length with ``len``.

    Args:
        section: Mapping with a ``"text"`` entry (the content to split) and a
            ``"source"`` entry (copied unchanged onto every chunk).
        chunk_size: Passed to the splitter as ``chunk_size``. Defaults to 300.
        chunk_overlap: Passed to the splitter as ``chunk_overlap``.
            Defaults to 50.

    Returns:
        list[dict]: One ``{"text": ..., "source": ...}`` dict per chunk.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        separators=["\n\n", "\n", " ", ""],
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
    )
    chunks = text_splitter.create_documents(
        texts=[section["text"]],
        metadatas=[{"source": section["source"]}],
    )
    return [{"text": chunk.page_content, "source": chunk.metadata["source"]} for chunk in chunks]

In [None]:
# Chunk every document with the 1st approach. `chunks_ds` is assigned again by
# the 2nd- and 3rd-approach cells below, so the last of those cells to run wins.
chunks_ds = ds.flat_map(chunk_section_1st_approach)

#### 2nd approach: recursive splitting that tries section markers first (chunk size 1000, overlap 50)

In [None]:
def chunk_section_2nd_approach(section, chunk_size=1000, chunk_overlap=50):
    """Split one document into chunks, trying section-marker separators first.

    Uses ``RecursiveCharacterTextSplitter`` with marker strings such as
    ``"---"`` and ``"~~~"`` listed ahead of the generic newline and space
    separators; length is measured with ``len``.

    Args:
        section: Mapping with a ``"text"`` entry (the content to split) and a
            ``"source"`` entry (copied unchanged onto every chunk).
        chunk_size: Passed to the splitter as ``chunk_size``. Defaults to 1000.
        chunk_overlap: Passed to the splitter as ``chunk_overlap``.
            Defaults to 50.

    Returns:
        list[dict]: One ``{"text": ..., "source": ...}`` dict per chunk.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        separators=["---", "~~~", "###", "***", "+++", "```", "\n\n", "\n", " ", ""],
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
    )
    chunks = text_splitter.create_documents(
        texts=[section["text"]],
        metadatas=[{"source": section["source"]}],
    )
    return [{"text": chunk.page_content, "source": chunk.metadata["source"]} for chunk in chunks]

In [None]:
# Punctuation characters treated as section-marker characters
# (for example the characters of an underline such as "=====").
section_identifiers = ["=", "-", "~", "#", "*", "+", "`", "~", ":"]


def is_section_identifier(text):
    """Return True if ``text``, ignoring surrounding whitespace, has only marker characters.

    Empty or whitespace-only text also returns True, because no character
    fails the check.
    """
    for char in text.strip():
        if char not in section_identifiers:
            return False
    return True

# Drop rows whose whole text consists of section-marker characters (blank text
# is dropped too, since is_section_identifier returns True for it).
filtered_ds = ds.filter(lambda row: not is_section_identifier(row["text"]))

In [None]:
# Chunk the filtered documents with the 2nd approach (rebinds `chunks_ds`).
chunks_ds = filtered_ds.flat_map(chunk_section_2nd_approach)

#### 3rd approach: semantic chunking

In [None]:
# Embedding model shared by the semantic chunker, the embedding step and the
# retrieval query further down.
# Earlier alternative, kept for reference:
# embedding_model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
embedding_model = HuggingFaceEmbeddings(model_name="paraphrase-MiniLM-L6-v2")

In [None]:
def chunk_section_3rd_approach(section):
    """Split one document into chunks with ``SemanticChunker``.

    The chunker is built from the module-level ``embedding_model``.

    Args:
        section: Mapping with a ``"text"`` entry (the content to split) and a
            ``"source"`` entry (copied unchanged onto every chunk).

    Returns:
        list[dict]: One ``{"text": ..., "source": ...}`` dict per chunk.
    """
    semantic_splitter = SemanticChunker(embedding_model)
    documents = semantic_splitter.create_documents(
        texts=[section["text"]],
        metadatas=[{"source": section["source"]}],
    )
    results = []
    for document in documents:
        results.append({"text": document.page_content, "source": document.metadata["source"]})
    return results

In [None]:
# Chunk every document with the 3rd (semantic) approach. This rebinds
# `chunks_ds`, which the embedding step below consumes.
chunks_ds = ds.flat_map(chunk_section_3rd_approach)

### Embeddings

In [None]:
# Attach an embedding vector to every chunk. This is a one-to-one transform, so
# `map` is used rather than `flat_map` with a single-element list.
embedded_chunks = chunks_ds.map(
    lambda row: {
        "text": row["text"],
        "source": row["source"],
        "embeddings": embedding_model.embed_query(row["text"]),
    }
)

### DB storage

In [None]:
# Connection settings for the PostgreSQL database. Each value can be supplied
# through an environment variable; the literals are only fallbacks that keep
# the notebook runnable exactly as before.
db_user = os.environ.get("CS230_DB_USER", 'postgres')
# SECURITY: a password must not live in a notebook. Export CS230_DB_PASSWORD
# instead, then rotate this credential and delete the fallback literal.
db_password = os.environ.get("CS230_DB_PASSWORD", 'CS230password')
db_host = os.environ.get("CS230_DB_HOST", 'database-1.cdi4gywsaigf.us-east-2.rds.amazonaws.com')
db_port = int(os.environ.get("CS230_DB_PORT", 5432))
db_name = os.environ.get("CS230_DB_NAME", 'postgres')

db_connection_string = f"postgresql://{db_user}:{db_password}@{db_host}:{db_port}/{db_name}"

In [None]:
def insert_batch_into_db(batch):
    """Insert one batch of embedded chunks into ``document_semantic_split``.

    Args:
        batch: Column-oriented batch with ``"text"``, ``"source"`` and
            ``"embeddings"`` columns; the columns are zipped row-wise and each
            resulting triple is inserted as one table row.

    Returns:
        dict: Always an empty dict; the function is run for its side effect.
    """
    conn = psycopg2.connect(db_connection_string)
    try:
        register_vector(conn)
        # Used as a context manager, a psycopg2 connection only scopes the
        # transaction (commit on success, rollback on error). It does not close
        # the connection, hence the explicit close() in the finally clause.
        with conn:
            with conn.cursor() as cur:
                for text, source, embedding in zip(batch["text"], batch["source"], batch["embeddings"]):
                    cur.execute(
                        "INSERT INTO document_semantic_split (text, source, embedding) VALUES (%s, %s, %s)",
                        (text, source, embedding),
                    )
    finally:
        conn.close()
    return {}

In [None]:
# Write all embedded chunks to the database in batches of 64 rows. The trailing
# count() forces the pipeline to execute (Ray Data is presumably evaluated
# lazily — confirm for the installed version).
# NOTE(review): num_cpus looks like a per-worker reservation, so 8 concurrent
# workers x 8 CPUs may ask for more CPUs than this machine has — confirm
# against the Ray Data map_batches documentation.
embedded_chunks.map_batches(
    insert_batch_into_db,
    batch_size=64,
    num_cpus=8,
    concurrency=8
).count()

### Retrieval example

In [None]:
# Example question to retrieve context for.
query = "What are the arguments for fillna()?"

# Embed the question with the same model that embedded the stored chunks.
# NOTE(review): the NumPy conversion is presumably what lets the pgvector
# adapter bind the value as a query parameter — confirm.
embedding = np.array(embedding_model.embed_query(query))
# Length of the query vector (displayed as the cell output).
len(embedding)

In [None]:
# Number of nearest chunks to fetch for the query embedding.
num_chunks = 5

# NOTE(review): this reads document_recursive_split_default_300_50, whereas
# insert_batch_into_db writes to document_semantic_split — confirm that this
# is the table you intend to query.
conn = psycopg2.connect(db_connection_string)
try:
    register_vector(conn)
    with conn, conn.cursor() as cur:
        # Rows are ordered by the `<->` distance between the stored embedding
        # and the query embedding, so the closest chunks come first.
        cur.execute("SELECT * FROM document_recursive_split_default_300_50 ORDER BY embedding <-> %s LIMIT %s", (embedding, num_chunks))
        rows = cur.fetchall()
finally:
    # `with conn` only ends the transaction; psycopg2 leaves the connection
    # open, so close it explicitly to avoid leaking it.
    conn.close()

# NOTE(review): positional access assumes column 1 holds the chunk text and
# column 2 its source — verify against the table definition.
context = [{"text": row[1]} for row in rows]
sources = [row[2] for row in rows]


In [None]:
# Source of the second result (index 1) in the distance-ordered result list.
sources[1]