Skip to content

Commit

Permalink
generative_qa module
Browse files Browse the repository at this point in the history
  • Loading branch information
amaiya committed May 10, 2023
1 parent 70b94de commit 75e17e3
Show file tree
Hide file tree
Showing 5 changed files with 221 additions and 5 deletions.
12 changes: 12 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,18 @@ Most recent releases are shown at the top. Each release shows:
- **Changed**: Additional parameters, changes to inputs or outputs, etc
- **Fixed**: Bug fixes that don't change documented behaviour

## 0.37.0 (TBD)

### new:
- Support for **Generative Question-Answering** powered by OpenAI models. Ask questions to any set of documents and get back answers with citations to where the answer was found in your corpus.

### changed
- N/A

### fixed:
- N/A


## 0.36.1 (2023-05-09)

### new:
Expand Down
3 changes: 2 additions & 1 deletion ktrain/text/qa/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .core import QA, AnswerExtractor, SimpleQA
from .extractive_qa import ExtractiveQABase, AnswerExtractor, SimpleQA
from .generative_qa import GenerativeQA
6 changes: 3 additions & 3 deletions ktrain/text/qa/core.py → ktrain/text/qa/extractive_qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def process_question(
# return TU.tokenize(question, join_tokens=True)


class QA(ABC, TorchBase):
class ExtractiveQABase(ABC, TorchBase):
"""
Base class for QA
"""
Expand Down Expand Up @@ -603,7 +603,7 @@ def display_answers(self, answers):
return display_answers(answers)


class SimpleQA(QA):
class SimpleQA(ExtractiveQABase):
"""
SimpleQA: Question-Answering on a list of texts
"""
Expand Down Expand Up @@ -887,7 +887,7 @@ def search(self, query, limit=10):
return output


class _QAExtractor(QA):
class _QAExtractor(ExtractiveQABase):
def __init__(
self,
model_name=DEFAULT_MODEL,
Expand Down
203 changes: 203 additions & 0 deletions ktrain/text/qa/generative_qa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
import os
import pickle
from typing import Optional
from pathlib import Path

try:
from paperqa import Docs

PAPERQA_INSTALLED = True
except ImportError:
PAPERQA_INSTALLED = False

DOCS = "docs_obj.pkl"


class GenerativeQA:
"""
Question-answering using OpenAI or open-source GPT or GPT-like generative LLM models
"""

def __init__(self):
"""
```
GenerativeQA constructor
```
"""
if not PAPERQA_INSTALLED:
raise Exception(
"GenerativeQA in ktrain requires the paper-qa package by Andrew White: pip install paper-qa"
)
self.docs = Docs()

def load(self, path: str):
"""
```
load previously-saved document vector database from folder specified by path
Args:
path(str): folder path
```
"""
with open(os.path.join(path, DOCS), "rb") as f:
self.docs = pickle.load(f)

def save(self, path: str):
"""
```
Save current document vector database to folder represented by path
Save the current vector database to disk
Args:
path(str): folder path
```
"""
if not os.path.exists(path):
os.makedirs(path)
self.docs.index_path = Path(path)
with open(os.path.join(path, DOCS), "wb") as f:
pickle.dump(self.docs, f)

def clear_index(self):
"""
This will delete the entire index.
"""
if input("are you sure you want to delete the vector index? (y/n)") != "y":
print("ok - aborting")
return
index_path = self.docs.index_path.as_posix()
self.docs.clear()
self.save(index_path)

def add_doc(
self,
path: Optional[str] = None,
text: Optional[str] = None,
citation: Optional[str] = None,
key: Optional[str] = None,
disable_check: bool = True,
chunk_chars: Optional[int] = 3000,
):
"""
```
Add documents to the data store
Args:
path(str): Path to the document. Mutually-exclusive with text parameter.
text(str): text of document. Mutually-exclusive with path parameter.
citation(str): The citation for document that will appear in references below answer.
If omitted, the LLM will be used to infer the correct citation from the document text.
key(str): The key for the document that will appear within the body of the answer when referenced.
If omitted, the LLM will be used to infer the correct citaiton from the document text.
disable_check(bool): A check of the text of the document.
chunk_chars(int): This is how many characters documents are split into.
Returns:
None
```
"""
if (path is not None and text is not None) or (path is None and text is None):
raise ValueError(
"The path and text parameters are mutually-exclusive and exactly one must be supplied."
)
if (
path is not None
and not path.lower().endswith(".pdf")
and not path.lower().endswith(".txt")
):
raise ValueError(
"Currently, the path parameter only accepts files that end with either a .pdf or .txt extension."
)

if text is not None:
import os
import tempfile

fd, fpath = tempfile.mkstemp()
os.rename(fpath, fpath + ".txt")
fpath = fpath + ".txt"
try:
with os.fdopen(fd, "w") as tmp:
# do stuff with temp file
tmp.write(text)
key, citation = self.default_key_and_citation(
fpath, key=key, citation=citation
)
self.add_doc(
fpath,
citation=citation,
key=key,
disable_check=disable_check,
chunk_chars=chunk_chars,
)
finally:
pass
return
key, citation = self.default_key_and_citation(path, key=key, citation=citation)
self.docs.add(
path=path,
citation=citation,
key=key,
disable_check=disable_check,
chunk_chars=chunk_chars,
)
return

def query(
self,
query: str,
k: int = 10,
max_sources: int = 5,
length_prompt: str = "about 100 words",
marginal_relevance: bool = True,
answer=None,
key_filter: Optional[bool] = None,
# get_callbacks: Callable[[str], AsyncCallbackHandler] = lambda x: [],
):
"""
```
Query for cited answers
```
"""
try:
result = self.docs.query(
query=query,
k=k,
max_sources=max_sources,
length_prompt=length_prompt,
marginal_relevance=marginal_relevance,
answer=answer,
key_filter=key_filter,
)
return result
except RuntimeError:
raise Exception(
"There was a RuntimeError - try addding the following to the top of your notebook:\nimport nest_asyncio\nnest_asyncio.apply()"
)

def default_key_and_citation(
self, path: str, key: Optional[str] = None, citation: Optional[str] = None
):
"""
```
Get default key and citation
```
"""
if path.endswith(".pdf"):
return (key, citation)
default_key = self.compute_key(path)
if key is None:
key = default_key
if citation is None:
citation = f"Document {default_key}"
return (key, citation)

def compute_key(self, path: str):
"""
```
compute MD5 hash
```
"""
from paperqa.utils import md5sum

return f"md5:{md5sum(path)}"
2 changes: 1 addition & 1 deletion ktrain/version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__all__ = ["__version__"]
__version__ = "0.36.1"
__version__ = "0.37.0"

0 comments on commit 75e17e3

Please sign in to comment.