In [51]:
import os
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain.document_loaders import PyPDFLoader
from transformers import (
    AutoTokenizer, 
    AutoModelForSeq2SeqLM, 
    pipeline
)
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain
from langchain.prompts import PromptTemplate
from langchain_community.vectorstores import FAISS
import re
from langchain.text_splitter import TokenTextSplitter

In [52]:
def do_preprocessing(text: str):
    """Normalize the whitespace of ``text``.

    Each run of one or more whitespace characters is replaced by a single
    space, after which leading and trailing whitespace is removed.

    Args:
        text: The string to clean.

    Returns:
        The cleaned string.
    """
    collapsed = re.sub(r"(\s)+", " ", text, flags=re.UNICODE)
    return collapsed.strip()

def load_documents(directory_path):
    """Load and clean every PDF stored directly inside ``directory_path``.

    Only regular files whose name ends in ``.pdf`` (case-insensitive) are
    passed to ``PyPDFLoader``. Other directory entries (sub-directories,
    checkpoints, hidden files, other formats) are skipped instead of being
    handed to the PDF parser. Files are processed in sorted order so that
    repeated runs yield the documents in the same order.

    Args:
        directory_path: Directory to scan (not recursive).

    Returns:
        A flat list with the documents produced by every loader, each with
        its ``page_content`` passed through ``do_preprocessing``. Empty when
        the directory holds no PDF files.

    Raises:
        OSError: Propagated from ``os.listdir`` when the directory cannot
            be listed.
    """
    candidate_paths = (
        os.path.join(directory_path, entry) for entry in os.listdir(directory_path)
    )
    # os.listdir returns entries in arbitrary order; sort for reproducibility.
    pdf_paths = sorted(
        path
        for path in candidate_paths
        if os.path.isfile(path) and path.lower().endswith(".pdf")
    )

    documents = []
    for pdf_path in pdf_paths:
        loaded_docs = PyPDFLoader(pdf_path).load()
        for doc in loaded_docs:
            doc.page_content = do_preprocessing(doc.page_content)
        documents.extend(loaded_docs)

    return documents


def process_reports(
    documents,
    chunk_size=400,
    chunk_overlap=100,
    encoding_name="cl100k_base",
):
    """Split documents into overlapping token-based chunks.

    Args:
        documents: Documents to split, as accepted by
            ``TokenTextSplitter.split_documents``.
        chunk_size: Maximum chunk length, counted in tokens of
            ``encoding_name``.
        chunk_overlap: Number of tokens shared by consecutive chunks.
        encoding_name: Name of the token encoding used for counting.
            NOTE(review): this encoding is not necessarily the tokenizer of
            the model that later consumes the chunks, so the counts may not
            match that model's own token counts - confirm when tuning.

    Returns:
        The list of chunked documents; ``[]`` for an empty list or tuple.
    """
    # Nothing to split: skip building the splitter (and loading its encoding).
    if isinstance(documents, (list, tuple)) and len(documents) == 0:
        return []

    text_splitter = TokenTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        encoding_name=encoding_name,
    )
    return text_splitter.split_documents(documents)

def vector_store_in_memory(
    documents,
    model_name="sentence-transformers/all-distilroberta-v1",
    device="cpu",
):
    """Embed ``documents`` and index them in an in-memory FAISS store.

    Args:
        documents: Documents to embed and index; must not be empty.
        model_name: Hugging Face identifier of the embedding model.
        device: Device string forwarded to the embedding model
            (``"cpu"`` by default).

    Returns:
        The FAISS vector store built from the documents. Embeddings are
        normalized at encoding time.

    Raises:
        ValueError: If ``documents`` is an empty list or tuple.
    """
    # Fail fast, before the embedding model is loaded: presumably an index
    # cannot be sized from zero vectors, and the resulting library error
    # would not point at the real cause (no documents were loaded).
    if isinstance(documents, (list, tuple)) and len(documents) == 0:
        raise ValueError(
            "Cannot build a vector store from an empty document collection."
        )

    embeddings = HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs={"device": device},
        encode_kwargs={"normalize_embeddings": True},
    )

    return FAISS.from_documents(documents, embeddings)

def setup_local_llm():
    """Create the local LLM wrapper used for answer generation.

    Loads the tokenizer and the seq2seq model for ``google/flan-t5-base``,
    combines them in a ``text2text-generation`` pipeline, and wraps that
    pipeline in ``HuggingFacePipeline``.

    Returns:
        The ``HuggingFacePipeline`` instance wrapping the generation pipeline.
    """
    checkpoint = "google/flan-t5-base"
    # Dict literals evaluate in order: tokenizer first, then the model.
    components = {
        "tokenizer": AutoTokenizer.from_pretrained(checkpoint),
        "model": AutoModelForSeq2SeqLM.from_pretrained(checkpoint),
    }
    generator = pipeline("text2text-generation", **components)

    return HuggingFacePipeline(pipeline=generator)

class EpidemiologicalAssistant:
    """Question-answering assistant backed by a vector store and a local LLM.

    Questions are answered by a retrieval chain: an MMR retriever over the
    vector store selects documents, and a "stuff" documents chain inserts
    them into the prompt sent to the LLM.

    Args:
        vector_store: Optional vector store exposing ``as_retriever``. When
            provided, the retrieval chain is built immediately; otherwise
            call ``setup_assistant`` before ``ask_question``.
        k: Number of documents requested from the retriever per question.
        fetch_k: Candidate pool size passed to the MMR search as ``fetch_k``.

    NOTE(review): a recorded run of this notebook logged a 1855-token prompt
    against a 512-token model limit with the default settings. Lowering ``k``
    (or the chunk size used when splitting) is the likely remedy - confirm
    against the model in use.
    """

    def __init__(self, vector_store=None, k=5, fetch_k=10):
        self.llm = setup_local_llm()
        self.vector_store = vector_store
        self.k = k
        self.fetch_k = fetch_k
        # Not used by this class; kept so existing attribute reads still work.
        self.qa_chain = None
        # Defined up front so ask_question can detect a missing setup instead
        # of failing with an AttributeError.
        self.retrieval_chain = None

        # Explicit None check: do not depend on the truthiness of the store.
        if vector_store is not None:
            self.setup_assistant(vector_store)

    def setup_assistant(self, vector_store):
        """Build the retrieval chain on top of ``vector_store``.

        Args:
            vector_store: Vector store exposing ``as_retriever``; it is also
                stored on ``self.vector_store``.
        """
        self.vector_store = vector_store

        custom_prompt = PromptTemplate(
            template="""
            As an EU epidemiological expert, analyze this data to answer the question precisely.

            DATA: {context}

            QUESTION: {input}
            """,
            input_variables=["context", "input"]
        )

        combine_docs_chain = create_stuff_documents_chain(
            self.llm,
            custom_prompt
        )

        retriever = self.vector_store.as_retriever(
            search_type="mmr",
            search_kwargs={"k": self.k, "fetch_k": self.fetch_k},
        )
        self.retrieval_chain = create_retrieval_chain(
            retriever,
            combine_docs_chain
        )

    def ask_question(self, question):
        """Answer ``question`` using the retrieval chain.

        Args:
            question: The question text; its whitespace is normalized with
                ``do_preprocessing`` before the chain is invoked.

        Returns:
            The value stored under ``"answer"`` in the chain result, or
            ``None`` when that key is absent.

        Raises:
            RuntimeError: If no retrieval chain has been built yet.
        """
        if self.retrieval_chain is None:
            raise RuntimeError(
                "No retrieval chain configured: pass a vector store to the "
                "constructor or call setup_assistant(vector_store) first."
            )
        question = do_preprocessing(question)
        return self.retrieval_chain.invoke({"input": question}).get("answer")


In [53]:
# Build the whole pipeline: load and clean the files found in the data
# folder, split them into token-based chunks, index the chunks in an
# in-memory FAISS store, then create the assistant on top of that store.
path_folder = "./data"  # directory scanned by load_documents
documents = load_documents(path_folder)
splitted_docs = process_reports(documents)
vector_store = vector_store_in_memory(splitted_docs)
# Passing the store to the constructor builds the retrieval chain right away.
assistant = EpidemiologicalAssistant(vector_store)

Device set to use cuda:0


In [54]:
# Query about the Zika virus report; the returned answer string is the cell output.
assistant.ask_question("Which epidemiology was used to metrify the cases of zika virus in the report?")

Token indices sequence length is longer than the specified maximum sequence length for this model (1855 > 512). Running this sequence through the model will result in indexing errors


'EU/EEA countries reported seven cases of Zika virus disease, five (71%) of which were confirmed, while 22 countries reported no cases. The cases were reported by Spain (n=4), Germany (n=2) and Luxembourg (n= 1) (Table 1, Figure 1).'

In [56]:
# Query for a definition of, and key facts about, syphilis.
assistant.ask_question("What is syphillis? What is the key facts about it")

'Syphilis is a sexually transmitted infection (STI) caused by the bacterium Treponema pallidum [1]. Infection can be acquired during sexual activity by direct contact with treponema-rich, open lesions and contaminated secretions from an infected partner. It can also be transmitted from a mother to a baby during pregnancy (congenital syphilis).'

In [57]:
# Query for epidemiological information about syphilis.
assistant.ask_question("Give me the epidemology information about syphillis")

'Syphilis is a sexually transmitted infection (STI) caused by the bacterium Treponema pallidum [1]. Infection can be acquired during sexual activity by direct contact with treponema-rich, open lesions and contaminated secretions from an infected partner. It can also be transmitted from a mother to a baby during pregnancy (congenital syphilis). After an average incubation period of three weeks (range: 10–90 days), a lesion (usually painless) called a ‘chancre’ occurs at the site of infection (primary syphilis). This is followed by a series of eruptions on mucous membranes and skin (secondary syphilis). Untreated infections can become latent. Although latent infections are non-infectious and cannot be transmitted sexually, they may still be passed on to a foetus.'