In [8]:
import os
import pandas as pd
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from google.cloud.sql.connector import Connector
import sqlalchemy


model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")



  from .autonotebook import tqdm as notebook_tqdm


In [9]:
def get_files(dir_path: list) -> dict:
    EXCLUDE_FILES = set([".gitkeep"])
    files = [file for file in os.listdir(dir_path) if file not in EXCLUDE_FILES]
    file_paths = [dir_path + file for file in files]
    return pd.DataFrame({"file_name": files, "file_path": file_paths})


def read_file(file_path):
    with open(file_path, "r") as file:
        return file.read()


def get_text(db: pd.DataFrame) -> pd.DataFrame:
    db["text"] = db["file_path"].apply(read_file)
    return db


# def get_context(db: pd.DataFrame, overlap: int = 300) -> pd.DataFrame:
#     context = []
#     for row in db.iterrows():
#         i = row["chunk"].metadata["start_index"]
#         if i - overlap < 0 or i + overlap > len(row["text"]):
#             continue

#         context.append(row["text"][i - overlap : i + overlap])

#     db["context"] = context
#     return db


def get_context_for_row(row, chunk_size, overlap):
    start_index = row["chunk"].metadata["start_index"]
    text_length = len(row["text"])

    if start_index - overlap < 0:
        return row["text"][start_index : start_index + chunk_size + overlap]
    if start_index + chunk_size + overlap > text_length:
        return row["text"][start_index - overlap : text_length]

    return row["text"][start_index - overlap : start_index + chunk_size + overlap]


def get_context(db: pd.DataFrame, chunk_size, overlap: int = 100) -> pd.DataFrame:
    # Apply the function to each row of the DataFrame
    db["context"] = db.apply(
        get_context_for_row, chunk_size=chunk_size, overlap=overlap, axis=1
    )
    return db


def get_chunks(db: pd.DataFrame, text_splitter) -> pd.DataFrame:
    db["chunk"] = db["text"].apply(lambda s: text_splitter.create_documents([s]))
    return db.explode("chunk")


def get_embeddings(db: pd.DataFrame, model) -> pd.DataFrame:
    tqdm.pandas()
    db["embedding"] = db["chunk"].progress_apply(lambda s: model.encode(s.page_content))
    return db


def separate_tables(db):
    return db[["file_name", "file_path", "text"]].drop_duplicates(
        subset="file_name"
    ), db.drop(columns=['file_path', 'text'])


def to_postgres():
    NotImplemented


CHUNK_SIZE = 512

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=CHUNK_SIZE,
    chunk_overlap=20,
    length_function=len,
    add_start_index=True,
)
db = get_files(os.path.join(os.getcwd(), "../backend/data/legislations/"))
db = get_text(db)
db = get_chunks(db, text_splitter)
db = get_context(db, CHUNK_SIZE)
db = get_embeddings(db, model)
text_db, vector_db = separate_tables(db)
#to_postgres(vector_db)

100%|██████████| 22265/22265 [05:19<00:00, 69.62it/s]


In [14]:
# i.e demo-project:us-central1:demo-instance
INSTANCE_CONNECTION_NAME = %env INSTANCE_CONNECTION_NAME  
DB_USER = %env DB_USER
DB_PASS = %env DB_PASS
DB_NAME = %env DB_NAME

In [60]:
# initialize Connector object
connector = Connector()

# function to return the database connection object
def getconn():
    conn = connector.connect(
        INSTANCE_CONNECTION_NAME, "pg8000", user=DB_USER, password=DB_PASS, db=DB_NAME
    )
    return conn


# create connection pool with 'creator' argument to our connection object function
pool = sqlalchemy.create_engine(
    "postgresql+pg8000://",
    creator=getconn,
)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




In [65]:
with pool.connect() as db_conn:
    # db_conn.execute(sqlalchemy.text("CREATE TABLE * FROM legislation_vector_db_001"))
    results = db_conn.execute(sqlalchemy.text("SELECT count(*) FROM legislation_vector_db_003")).fetchall()
    print(results)

[(5829,)]


In [64]:
TABLE_NAME = "legislation_vector_db_003"
from tqdm import tqdm

def create_insert_list():
  insert_list = []
  for row in vector_db.iterrows():
    insert_list.append(
      {"file_name": f"{row[1]['file_name']}",
      "chunk": f"""{row[1]['chunk'].page_content.replace("'", "''")}""",
      "context": f"""{row[1]['context'].replace("'", "''")}""", 
      "embedding": f"{str(list(row[1]['embedding']))}"}
    )
  return insert_list

# connect to connection pool
with pool.connect() as db_conn:
  print("creating table")
  db_conn.execute(
    sqlalchemy.text(
      f"""CREATE TABLE IF NOT EXISTS {TABLE_NAME}
      (file_name VARCHAR NOT NULL,
      chunk VARCHAR NOT NULL,
      context VARCHAR NOT NULL,
      embedding VECTOR NOT NULL);"""
    )
  )

  # # commit transaction (SQLAlchemy v2.X.X is commit as you go)
  # db_conn.commit()

  # insert data into our ratings table
  insert_stmt = sqlalchemy.text(
      f"INSERT INTO {TABLE_NAME} (file_name, chunk, context, embedding) VALUES (:file_name, :chunk, :context, :embedding);"
  )

  insert_list = create_insert_list()

  # insert entries into table
  print("inserting entries")
  for batch in tqdm(range(0, len(insert_list), 100)):
    db_conn.execute(insert_stmt, insert_list[batch:batch+100])
  
  # commit transactions
  print("committing entries")
  db_conn.commit()

  # query and fetch ratings table
  results = db_conn.execute(sqlalchemy.text(f"SELECT count(*) FROM {TABLE_NAME}")).fetchall()

  # show results
  print(len(results))
  print(results[0])

inserting entries...


100%|██████████| 56/56 [10:46<00:00, 11.54s/it]


KeyboardInterrupt: 