In [None]:
# Import necessary libraries
import json
from pathlib import Path
from langchain import hub
from langchain_openai import ChatOpenAI
from langchain_openai import OpenAIEmbeddings
from langchain.schema import StrOutputParser
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_core.runnables import RunnablePassthrough
from langchain.docstore.document import Document
from trulens_eval import TruChain, Tru
from trulens_eval.feedback.provider import OpenAI
from trulens_eval import Feedback
import numpy as np

# Initialize TruLens: create the Tru handle used for evaluation tracking.
tru = Tru()
# NOTE(review): reset_database() runs on every execution of this cell; per the
# TruLens docs it clears previously recorded data — confirm that starting from
# an empty database each time is intended.
tru.reset_database()

# Function to read JSON files from PROJECT_DATA folder
def read_json_files(folder_path):
    """Load every ``*.json`` file located directly inside *folder_path*.

    Args:
        folder_path: Directory to scan (``str`` or ``Path``). Only entries
            matching ``*.json`` at the top level are read; subdirectories
            are not searched.

    Returns:
        A list holding one parsed JSON value per file, ordered by file path
        so the result is reproducible across runs and platforms. Empty when
        no file matches.

    Raises:
        json.JSONDecodeError: If a file does not contain valid JSON.
        OSError: If a file cannot be opened.
    """
    data = []
    # Path.glob yields entries in arbitrary filesystem order; sort so that
    # the returned list is deterministic.
    for json_file in sorted(Path(folder_path).glob('*.json')):
        # JSON is UTF-8 by specification; do not depend on the platform's
        # locale encoding (e.g. cp1252 on Windows).
        with open(json_file, 'r', encoding='utf-8') as file:
            data.append(json.load(file))
    return data

# Read JSON files from the PROJECT_DATA folder. The path is relative, so it is
# resolved against the current working directory of the running process.
project_data_folder = 'PROJECT_DATA'
data = read_json_files(project_data_folder)

# Report how many files were loaded (the contents themselves are not printed).
print(f"Loaded {len(data)} JSON files from {project_data_folder}")

# Initialize LangChain components
chat_openai = ChatOpenAI()
embeddings = OpenAIEmbeddings()
text_splitter = RecursiveCharacterTextSplitter()
# Wire the embedding model into the store. Without embedding_function, Chroma
# does not use `embeddings` at all and falls back to its own default embedder,
# so stored vectors would not come from the OpenAI model configured above.
vector_store = Chroma(embedding_function=embeddings)
output_parser = StrOutputParser()

# Function to perform augmented queries
def augmented_query(query):
    """Index *query* in the vector store, then answer it with the chat model.

    The query text is split with the module-level ``text_splitter``, the
    resulting chunks are added to ``vector_store``, and the unmodified query
    is sent to ``chat_openai``. The model's reply is converted to plain text
    by ``output_parser``.

    NOTE(review): the stored chunks are never retrieved or passed to the
    model, so the answer is based on the query alone. If retrieval-augmented
    answers are intended, a similarity search over indexed documents must be
    added to the prompt — confirm the intended design.

    Args:
        query: The question to ask, as a string.

    Returns:
        The model's reply as a string.
    """
    # Split the query into chunks
    chunks = text_splitter.split_text(query)

    # Store the chunks. Embedding is delegated to the vector store's own
    # embedding function: the LangChain vector-store API takes texts (or
    # Documents with `page_content`), not precomputed per-document vectors.
    # Skip the call when splitting produced nothing (e.g. an empty query) so
    # the store is not handed an empty batch.
    if chunks:
        vector_store.add_texts(chunks)

    # Chat models are Runnables; `invoke` is the call entry point and
    # returns a message object rather than a bare string.
    response = chat_openai.invoke(query)

    # The parser's `invoke` accepts a message and extracts its text content;
    # `parse` expects an already-extracted string.
    parsed_response = output_parser.invoke(response)

    return parsed_response

# Example usage of augmented_query function.
# Runs at import/cell-execution time and calls the chat model through
# augmented_query, so it performs a live request on every run.
query = "What are the most common species in the permian interval?"
response = augmented_query(query)
print("Query Response:", response)