In [1]:
import os
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from transformers import (
    AutoTokenizer, 
    AutoModelForSeq2SeqLM, 
    pipeline
)
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain
from langchain.prompts import PromptTemplate
from langchain_community.vectorstores import FAISS

  from .autonotebook import tqdm as notebook_tqdm


# Coleta e Preparação de Texto

In [2]:
def load_documents(directory_path):
    files = os.listdir(directory_path)
    loaders = [PyPDFLoader(os.path.join(directory_path, f)) for f in files]
    documents = []
    for loader in loaders:
        documents.extend(loader.load())

    return documents

def process_reports(documents):
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200,
        length_function=len
    )
    return text_splitter.split_documents(documents)


def vector_store_in_memory(documents):
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-distilroberta-v1",
        model_kwargs={'device': 'cpu'},
        encode_kwargs={'normalize_embeddings': True}
    )
    
    vector_store = FAISS.from_documents(documents, embeddings)
    
    return vector_store

In [3]:
path_folder = "./data"
documents = load_documents(path_folder)
splitted_docs = process_reports(documents)
vector_store = vector_store_in_memory(splitted_docs)

# Criação de modelinho com QA

In [37]:
def setup_local_llm():
    model_id = "google/flan-t5-small"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_id)

    pipe = pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tokenizer,
    )
    
    return HuggingFacePipeline(pipeline=pipe)

class EpidemiologicalAssistant:
    def __init__(self, vector_store=None):
        self.llm = setup_local_llm()
        self.vector_store = vector_store
        self.qa_chain = None
        
        if vector_store:
            self.setup_assistant(vector_store)
    
    def setup_assistant(self, vector_store):
        self.vector_store = vector_store
        
        custom_prompt = PromptTemplate(
            template="""
            As an EU epidemiological expert, analyze this data to answer the question precisely.

            DATA: {context}

            QUESTION: {input}
            """,
            input_variables=["context", "input"]
        )
        
        
        combine_docs_chain = create_stuff_documents_chain(
            self.llm, 
            custom_prompt
        )
        
        self.retrieval_chain = create_retrieval_chain(
            self.vector_store.as_retriever(),
            combine_docs_chain
        )
    
    def ask_question(self, question):
        return self.retrieval_chain.invoke({"input": question})


assistant = EpidemiologicalAssistant(vector_store)

Device set to use cuda:0


In [42]:
assistant.ask_question("For syphilis, how much cases were confirmed in 2021?")

{'input': 'For syphilis, how much cases were confirmed in 2021?',
 'context': [Document(id='a878b35d-6f95-423b-9170-321181d7309d', metadata={'producer': 'Adobe PDF Library 23.6.136', 'creator': 'Acrobat PDFMaker 23 for Word', 'creationdate': '2023-10-20T09:07:48+02:00', 'author': 'ECDC', 'classificationcontentmarkingheaderfontprops': '#000000,10,Calibri', 'classificationcontentmarkingheadershapeids': '27e81e29,323c0725,1,2,5', 'classificationcontentmarkingheadertext': 'ECDC NORMAL', 'comments': '', 'company': '', 'contenttypeid': '0x010100EE95EE7DB3A482488E68FA4A7091999F00E9E88A449575E24AAD017D04A46B9265', 'dms product': '', 'ecdc_dms_communication_document_type': '1241;#first edit|80850886-251b-4f02-9aa9-b2af2dccb954', 'ecdc_dms_country': '', 'ecdc_dms_general_administration_document_type': '8455;#Report|ecbd77ad-3760-4178-a0ff-1df458fa4e60', 'ecdc_dms_mis_activity_code': '', 'ecdc_dms_organigramme': '1159;#Surveillance|e245debe-c0a6-469f-b09a-56488ff7aacd', 'ecdc_dms_organization': '