Moved env variables to the pydantic settings file #223

Merged · 9 commits · May 2, 2023
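The change replaces scattered os.getenv lookups with one typed pydantic Settings object. As a minimal sketch of the pattern (assuming pydantic v1, where BaseSettings ships in the pydantic package itself):

from pydantic import BaseSettings


class Settings(BaseSettings):
    # The default applies when neither the process environment nor an
    # .env file provides a value.
    LLM_NAME: str = "openai_chat"


# Field names are matched against environment variables case-insensitively,
# so `export LLM_NAME=cohere` (or LLM_NAME=cohere in a .env file) overrides
# the default above.
settings = Settings()
print(settings.LLM_NAME)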
2 changes: 2 additions & 0 deletions .env-template
@@ -0,0 +1,2 @@
+API_KEY=<LLM API key (for example, an OpenAI key)>
+EMBEDDINGS_KEY=<LLM embeddings API key (for example, an OpenAI key)>
59 changes: 26 additions & 33 deletions application/app.py
@@ -28,21 +28,12 @@

from error import bad_request
from worker import ingest_worker
+from core.settings import settings
import celeryconfig

# os.environ["LANGCHAIN_HANDLER"] = "langchain"

-if os.getenv("LLM_NAME") is not None:
-    llm_choice = os.getenv("LLM_NAME")
-else:
-    llm_choice = "openai_chat"
-
-if os.getenv("EMBEDDINGS_NAME") is not None:
-    embeddings_choice = os.getenv("EMBEDDINGS_NAME")
-else:
-    embeddings_choice = "openai_text-embedding-ada-002"
-
-if llm_choice == "manifest":
+if settings.LLM_NAME == "manifest":
    from manifest import Manifest
    from langchain.llms.manifest import ManifestWrapper

@@ -79,20 +70,20 @@
with open("prompts/chat_reduce_prompt.txt", "r") as f:
    chat_reduce_template = f.read()

-if os.getenv("API_KEY") is not None:
+if settings.API_KEY is not None:
    api_key_set = True
else:
    api_key_set = False
-if os.getenv("EMBEDDINGS_KEY") is not None:
+if settings.EMBEDDINGS_KEY is not None:
    embeddings_key_set = True
else:
    embeddings_key_set = False

app = Flask(__name__)
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER = "inputs"
-app.config['CELERY_BROKER_URL'] = os.getenv("CELERY_BROKER_URL")
-app.config['CELERY_RESULT_BACKEND'] = os.getenv("CELERY_RESULT_BACKEND")
-app.config['MONGO_URI'] = os.getenv("MONGO_URI")
+app.config['CELERY_BROKER_URL'] = settings.CELERY_BROKER_URL
+app.config['CELERY_RESULT_BACKEND'] = settings.CELERY_RESULT_BACKEND
+app.config['MONGO_URI'] = settings.MONGO_URI
celery = Celery()
celery.config_from_object('celeryconfig')
mongo = MongoClient(app.config['MONGO_URI'])
@@ -122,8 +113,8 @@ def ingest(self, directory, formats, name_job, filename, user):

@app.route("/")
def home():
-    return render_template("index.html", api_key_set=api_key_set, llm_choice=llm_choice,
-                           embeddings_choice=embeddings_choice)
+    return render_template("index.html", api_key_set=api_key_set, llm_choice=settings.LLM_NAME,
+                           embeddings_choice=settings.EMBEDDINGS_NAME)


@app.route("/api/answer", methods=["POST"])
@@ -135,11 +126,11 @@ def api_answer():
    if not api_key_set:
        api_key = data["api_key"]
    else:
-        api_key = os.getenv("API_KEY")
+        api_key = settings.API_KEY
    if not embeddings_key_set:
        embeddings_key = data["embeddings_key"]
    else:
-        embeddings_key = os.getenv("EMBEDDINGS_KEY")
+        embeddings_key = settings.EMBEDDINGS_KEY

    # use try and except to check for exception
    try:
@@ -160,13 +151,13 @@ def api_answer():
    # vectorstore = "outputs/inputs/"
    # loading the index and the store and the prompt template
    # Note if you have used other embeddings than OpenAI, you need to change the embeddings
-    if embeddings_choice == "openai_text-embedding-ada-002":
+    if settings.EMBEDDINGS_NAME == "openai_text-embedding-ada-002":
        docsearch = FAISS.load_local(vectorstore, OpenAIEmbeddings(openai_api_key=embeddings_key))
-    elif embeddings_choice == "huggingface_sentence-transformers/all-mpnet-base-v2":
+    elif settings.EMBEDDINGS_NAME == "huggingface_sentence-transformers/all-mpnet-base-v2":
        docsearch = FAISS.load_local(vectorstore, HuggingFaceHubEmbeddings())
-    elif embeddings_choice == "huggingface_hkunlp/instructor-large":
+    elif settings.EMBEDDINGS_NAME == "huggingface_hkunlp/instructor-large":
        docsearch = FAISS.load_local(vectorstore, HuggingFaceInstructEmbeddings())
-    elif embeddings_choice == "cohere_medium":
+    elif settings.EMBEDDINGS_NAME == "cohere_medium":
        docsearch = FAISS.load_local(vectorstore, CohereEmbeddings(cohere_api_key=embeddings_key))

    # create a prompt template
@@ -182,7 +173,7 @@ def api_answer():

    q_prompt = PromptTemplate(input_variables=["context", "question"], template=template_quest,
                              template_format="jinja2")
-    if llm_choice == "openai_chat":
+    if settings.LLM_NAME == "openai_chat":
        # llm = ChatOpenAI(openai_api_key=api_key, model_name="gpt-4")
        llm = ChatOpenAI(openai_api_key=api_key)
        messages_combine = [
@@ -195,16 +186,18 @@ def api_answer():
            HumanMessagePromptTemplate.from_template("{question}")
        ]
        p_chat_reduce = ChatPromptTemplate.from_messages(messages_reduce)
-    elif llm_choice == "openai":
+    elif settings.LLM_NAME == "openai":
        llm = OpenAI(openai_api_key=api_key, temperature=0)
-    elif llm_choice == "manifest":
+    elif settings.LLM_NAME == "manifest":
        llm = ManifestWrapper(client=manifest, llm_kwargs={"temperature": 0.001, "max_tokens": 2048})
-    elif llm_choice == "huggingface":
+    elif settings.LLM_NAME == "huggingface":
        llm = HuggingFaceHub(repo_id="bigscience/bloom", huggingfacehub_api_token=api_key)
-    elif llm_choice == "cohere":
+    elif settings.LLM_NAME == "cohere":
        llm = Cohere(model="command-xlarge-nightly", cohere_api_key=api_key)
+    else:
+        raise ValueError("unknown LLM model")

-    if llm_choice == "openai_chat":
+    if settings.LLM_NAME == "openai_chat":
        question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT)
        doc_chain = load_qa_chain(llm, chain_type="map_reduce", combine_prompt=p_chat_combine)
        chain = ConversationalRetrievalChain(
@@ -316,7 +309,7 @@ def combined_json():
        "fullName": 'default',
        "date": 'default',
        "docLink": 'default',
-        "model": embeddings_choice,
+        "model": settings.EMBEDDINGS_NAME,
        "location": "local"
    }]
    # structure: name, language, version, description, fullName, date, docLink
@@ -330,7 +323,7 @@ def combined_json():
            "fullName": index['name'],
            "date": index['date'],
            "docLink": index['location'],
-            "model": embeddings_choice,
+            "model": settings.EMBEDDINGS_NAME,
            "location": "local"
        })

@@ -421,7 +414,7 @@ def upload_index_files():
        "language": job_name,
        "location": save_dir,
        "date": datetime.datetime.now().strftime("%d/%m/%Y %H:%M:%S"),
-        "model": embeddings_choice,
+        "model": settings.EMBEDDINGS_NAME,
        "type": "local"
    })
    return {"status": 'ok'}
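A subtlety the api_key_set checks above rely on: in pydantic v1, a field annotated str with a default of None is silently promoted to Optional[str], so settings.API_KEY is not None mirrors the old os.getenv("API_KEY") is not None behavior (pydantic v2 dropped this automatic promotion). A small sketch, assuming pydantic v1:

from typing import Optional

from pydantic import BaseSettings


class Settings(BaseSettings):
    API_KEY: str = None  # v1 treats this as Optional[str] because the default is None
    EMBEDDINGS_KEY: Optional[str] = None  # the same thing, spelled out


settings = Settings()
api_key_set = settings.API_KEY is not None  # False unless API_KEY is set in the environment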
Empty file added application/core/__init__.py
19 changes: 19 additions & 0 deletions application/core/settings.py
@@ -0,0 +1,19 @@
+from pydantic import BaseSettings
+from pathlib import Path
+
+
+class Settings(BaseSettings):
+    LLM_NAME: str = "openai_chat"
+    EMBEDDINGS_NAME: str = "openai_text-embedding-ada-002"
+    CELERY_BROKER_URL: str = "redis://localhost:6379/0"
+    CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
+    MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
+
+    API_URL: str = "http://localhost:5001"  # backend URL for the celery worker
+
+    API_KEY: str = None  # LLM API key
+    EMBEDDINGS_KEY: str = None  # API key for embeddings (if using OpenAI, just copy API_KEY)
+
+
+path = Path(__file__).parent.parent.absolute()
+settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8")
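For reference, pydantic v1 resolves each field in priority order: arguments passed to the initializer, then process environment variables, then values from _env_file, then the field default. Path(__file__).parent.parent here is the application/ directory, so the file is expected at application/.env (hence the .env-template above). A usage sketch, assuming application/ is on the import path:

import os

from core.settings import Settings

# Environment variables beat the .env file and the field default...
os.environ["LLM_NAME"] = "openai"
assert Settings().LLM_NAME == "openai"

# ...and explicit keyword arguments beat everything else.
assert Settings(LLM_NAME="cohere").LLM_NAME == "cohere"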
38 changes: 16 additions & 22 deletions application/worker.py
@@ -6,7 +6,8 @@
from parser.schema.base import Document
from parser.open_ai_func import call_openai_api
from parser.token_func import group_split
-from celery import current_task
+from urllib.parse import urljoin
+from core.settings import settings


import string
@@ -18,11 +19,12 @@
    nltk.download('averaged_perceptron_tagger', quiet=True)
except FileExistsError:
    pass


def generate_random_string(length):
    return ''.join([string.ascii_letters[i % 52] for i in range(length)])


def ingest_worker(self, directory, formats, name_job, filename, user):
    # directory = 'inputs' or 'temp'
    # formats = [".rst", ".md"]
@@ -39,12 +41,8 @@ def ingest_worker(self, directory, formats, name_job, filename, user):
    max_tokens = 1250
    full_path = directory + '/' + user + '/' + name_job
    # check if API_URL env variable is set
-    if not os.environ.get('API_URL'):
-        url = 'http://localhost:5001/api/download'
-    else:
-        url = os.environ.get('API_URL') + '/api/download'
    file_data = {'name': name_job, 'file': filename, 'user': user}
-    response = requests.get(url, params=file_data)
+    response = requests.get(urljoin(settings.API_URL, "/api/download"), params=file_data)
    file = response.content

    if not os.path.exists(full_path):
@@ -58,8 +56,6 @@ def ingest_worker(self, directory, formats, name_job, filename, user):
        zip_ref.extractall(full_path)
    os.remove(full_path + '/' + filename)

-
-    import time
    self.update_state(state='PROGRESS', meta={'current': 1})

    raw_docs = SimpleDirectoryReader(input_dir=full_path, input_files=input_files, recursive=recursive,
@@ -78,22 +74,20 @@ def ingest_worker(self, directory, formats, name_job, filename, user):

    # get files from outputs/inputs/index.faiss and outputs/inputs/index.pkl
    # and send them to the server (provide user and name in form)
-    if not os.environ.get('API_URL'):
-        url = 'http://localhost:5001/api/upload_index'
-    else:
-        url = os.environ.get('API_URL') + '/api/upload_index'
    file_data = {'name': name_job, 'user': user}
    files = {'file_faiss': open(full_path + '/index.faiss', 'rb'),
             'file_pkl': open(full_path + '/index.pkl', 'rb')}
-    response = requests.post(url, files=files, data=file_data)
-
-    # deletes remote
-    if not os.environ.get('API_URL'):
-        url = 'http://localhost:5001/api/delete_old?path=' + 'inputs/' + user + '/' + name_job
-    else:
-        url = os.environ.get('API_URL') + '/api/delete_old?path=' + 'inputs/' + user + '/' + name_job
-    response = requests.get(url)
+    response = requests.post(urljoin(settings.API_URL, "/api/upload_index"), files=files, data=file_data)

+    response = requests.get(urljoin(settings.API_URL, "/api/delete_old?path=" + 'inputs/' + user + '/' + name_job))
    # delete local
    shutil.rmtree(full_path)

-    return {'directory': directory, 'formats': formats, 'name_job': name_job, 'filename': filename, 'user': user, 'limited': False}
+    return {
+        'directory': directory,
+        'formats': formats,
+        'name_job': name_job,
+        'filename': filename,
+        'user': user,
+        'limited': False
+    }
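One behavioral note on the urljoin calls introduced above: because the second argument starts with "/", urljoin keeps only the scheme and host of API_URL and discards any path component. That is harmless for the default http://localhost:5001, but an API_URL mounted under a path prefix (say, behind a reverse proxy) would silently lose that prefix. A quick illustration, using a hypothetical prefixed URL:

from urllib.parse import urljoin

print(urljoin("http://localhost:5001", "/api/download"))
# http://localhost:5001/api/download

print(urljoin("http://example.com/docsgpt", "/api/download"))
# http://example.com/api/download  <- the /docsgpt prefix is dropped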