<a href="https://colab.research.google.com/github/Koks-creator/SimpleRAG/blob/main/SimpleRAG.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install unidecode langchain==0.2.12 langchain_community==0.2.11 openai==1.39.0 pypdf sentence-transformers chromadb unstructured python-pptx python-magic nltk==3.9.1



**If you want to use ppt/pptx files, run the 2nd and 3rd cells; otherwise you can skip them, since they take a few minutes to install.**

In [2]:
!apt install libreoffice

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
libreoffice is already the newest version (1:7.3.7-0ubuntu0.22.04.6).
0 upgraded, 0 newly installed, 0 to remove and 45 not upgraded.


In [3]:
import nltk
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [38]:
import os
import re
from typing import Optional
from dataclasses import dataclass, field
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.document_loaders import TextLoader
from langchain_community.document_loaders import UnstructuredPowerPointLoader
from openai import OpenAI
from unidecode import unidecode

**1. Create Vector Database**

In [39]:
CHUNK_SIZE = 20000
CHUNK_OVERLAP = 20
DB_PATH = "chroma_db_ncnn"
OVERWRITE_DB = True
K = 5  # number of chunks retrieved per query
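A stray trailing comma after an assignment (for example `x = 20000,`) makes the value a one-element tuple, and dataclass type hints such as `chunk_size: int` are not enforced at runtime, so the mistake would only surface much later. A minimal sketch of a fail-fast check for these settings; `validate_settings` is a hypothetical helper, not part of the notebook or any library.

```python
def validate_settings(chunk_size, chunk_overlap, k):
    """Fail fast on settings that would otherwise cause confusing errors later (hypothetical helper)."""
    for name, value in (("chunk_size", chunk_size), ("chunk_overlap", chunk_overlap), ("k", k)):
        # bool is a subclass of int, so reject it explicitly
        if not isinstance(value, int) or isinstance(value, bool):
            raise TypeError(f"{name} must be an int, got {type(value).__name__}: {value!r}")
    if chunk_size <= 0 or k <= 0:
        raise ValueError("chunk_size and k must be positive")
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must be in [0, chunk_size)")


validate_settings(20000, 20, 5)  # valid settings pass silently

try:
    validate_settings((20000,), 20, 5)  # what `CHUNK_SIZE = 20000,` would produce
except TypeError as err:
    print(err)  # chunk_size must be an int, got tuple: (20000,)
```

In the notebook it would be called as `validate_settings(CHUNK_SIZE, CHUNK_OVERLAP, K)` right after the settings cell.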

In [40]:
@dataclass
class DataLoader:
    # Maps a lowercase file extension to the LangChain loader class that reads it.
    _loaders: dict = field(init=False, repr=False)

    def __post_init__(self) -> None:
        self._loaders = {
            "pdf": PyPDFLoader,
            "txt": TextLoader,
            "md": TextLoader,
            "ppt": UnstructuredPowerPointLoader,
            "pptx": UnstructuredPowerPointLoader,
        }

    @staticmethod
    def remove_html_tags(text: str) -> str:
        clean = re.compile('<.*?>')
        return re.sub(clean, '', text)

    def clean_text(self, text: str) -> str:
        # Transliterate to ASCII, then strip HTML-like tags.
        return self.remove_html_tags(unidecode(text))

    def load_file(self, file_path: str) -> list:
        _, file_extension = os.path.splitext(file_path)
        loader_cls = self._loaders.get(file_extension.lstrip(".").lower())
        if loader_cls is None:
            # Unsupported file type: skip it instead of failing.
            return []

        data = loader_cls(file_path).load()
        for d in data:
            d.page_content = self.clean_text(text=d.page_content)

        return data

    def load_files(self, data_folder: str) -> list:
        docs = []
        for file in os.listdir(data_folder):
            docs.extend(self.load_file(file_path=os.path.join(data_folder, file)))
        return docs


@dataclass
class VectorDatabase:
    db_path: str
    chunk_size: int
    chunk_overlap: int
    separators: list = field(default_factory=lambda: ["\n\n", "\n", ".", "?", "!", ",", " "])
    model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
    device: str = "cpu"

    def __post_init__(self) -> None:
        self.embedding_func = HuggingFaceEmbeddings(model_name=self.model_name, model_kwargs={"device": self.device})

    def create_db(self, data: list, overwrite: bool = True) -> Chroma:
        if overwrite:
            self.delete_collections()
        vectorstore = Chroma.from_documents(data, self.embedding_func, persist_directory=self.db_path)

        return vectorstore

    def get_db(self) -> Chroma:
        return Chroma(persist_directory=self.db_path, embedding_function=self.embedding_func)

    def get_db_data(self) -> dict:
        return self.get_db().get()

    def query_db(self, query: str, k: int = 5) -> list:
        vector_db = self.get_db()
        results = vector_db.similarity_search(query, k=k)

        return results

    def delete_collections(self) -> None:
        ids = self.get_db_data()["ids"]
        if ids:
            self.get_db().delete(ids)
        else:
            print("Nothing to delete")


In [41]:
dl = DataLoader()

res = dl.load_files("/content/data")
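The loader is chosen from the file extension, and the text is cleaned with a non-greedy `<.*?>` regex. A stdlib-only sketch of those two steps; the string values stand in for the LangChain loader classes, and `pick_loader` is a hypothetical helper. This sketch lowercases the extension and maps both `ppt` and `pptx`, so files such as `Slides.PPTX` are not silently skipped.

```python
import os
import re
from typing import Optional

# Hypothetical extension -> loader-name table; strings stand in for the LangChain classes.
LOADERS = {
    "pdf": "PyPDFLoader",
    "txt": "TextLoader",
    "md": "TextLoader",
    "ppt": "UnstructuredPowerPointLoader",
    "pptx": "UnstructuredPowerPointLoader",
}

# Same non-greedy pattern as DataLoader.remove_html_tags.
TAG_RE = re.compile(r"<.*?>")


def pick_loader(file_path: str) -> Optional[str]:
    """Return the loader name for a file, or None for unsupported extensions."""
    _, ext = os.path.splitext(file_path)
    return LOADERS.get(ext.lstrip(".").lower())


def remove_html_tags(text: str) -> str:
    return TAG_RE.sub("", text)


print(pick_loader("/content/data/Slides.PPTX"))  # UnstructuredPowerPointLoader
print(pick_loader("/content/data/notes.docx"))   # None
print(remove_html_tags("<b>CNN</b> layers"))     # CNN layers
print(remove_html_tags("if a < b and b > c"))    # if a  c
```

The last line shows the regex's limitation: anything between a literal `<` and a later `>`, such as an inequality on a lecture slide, is deleted as well. If the sources contain such text, a real HTML parser is a safer choice.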

In [42]:
vector_db = VectorDatabase(
    db_path=DB_PATH,
    chunk_size=CHUNK_SIZE,
    chunk_overlap=CHUNK_OVERLAP,
    )

vector_db.create_db(data=res, overwrite=OVERWRITE_DB)

<langchain_community.vectorstores.chroma.Chroma at 0x7cf7578b9b40>
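`VectorDatabase` stores `chunk_size`, `chunk_overlap`, and `separators`, but `create_db` passes the loaded documents to `Chroma.from_documents` unsplit, so each loaded document (a PDF page, or a whole text file) becomes a single embedding. According to its model card, all-MiniLM-L6-v2 truncates input longer than 256 word pieces, so most of a long page is never embedded. The usual fix is LangChain's `RecursiveCharacterTextSplitter(...).split_documents(data)` before indexing. The sketch below is a simplified, stdlib-only stand-in: fixed-size character windows with overlap that ignore `separators`, with a hypothetical `Doc` class in place of LangChain's `Document`. Its main point is that every chunk keeps the source metadata, so the PAGE/DOCUMENT citations keep working.

```python
from dataclasses import dataclass, field


@dataclass
class Doc:
    """Hypothetical stand-in for LangChain's Document."""
    page_content: str
    metadata: dict = field(default_factory=dict)


def chunk_text(text: str, chunk_size: int, chunk_overlap: int) -> list:
    """Split text into windows of chunk_size characters overlapping by chunk_overlap."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must be in [0, chunk_size)")
    step = chunk_size - chunk_overlap
    chunks = []
    start = 0
    while start < len(text):
        end = start + chunk_size
        chunks.append(text[start:end])
        if end >= len(text):
            break  # this window already reaches the end of the text
        start += step
    return chunks


def split_documents(docs: list, chunk_size: int, chunk_overlap: int) -> list:
    """Split every document, copying its metadata (source, page) onto each chunk."""
    return [
        Doc(page_content=chunk, metadata=dict(doc.metadata))
        for doc in docs
        for chunk in chunk_text(doc.page_content, chunk_size, chunk_overlap)
    ]


pages = [Doc("abcdefghij", {"source": "/content/data/CNN.pdf", "page": 0})]
chunks = split_documents(pages, chunk_size=4, chunk_overlap=1)
print([c.page_content for c in chunks])  # ['abcd', 'defg', 'ghij']
```

Copying the metadata dict per chunk (rather than sharing one object) keeps later edits to one chunk's metadata from leaking into its siblings.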

In [43]:
vector_db.get_db_data()

{'ids': ['0092db99-53b4-472d-9eec-01d60b2ec1a7',
  '02fed92b-c377-442b-9338-d8cb29f6d0b6',
  '04b67caf-439d-4b46-a77a-8995f5115f7f',
  '06210e41-b995-4e66-98c3-a7c2d489e437',
  '06d33b10-1d27-4553-8ff8-06464e62f19d',
  '07a4459f-f2d7-49b9-b700-d8d4629d02c1',
  '0c2a2433-22f4-4d36-90d6-7dba959ca06a',
  '114eb477-7666-41bf-8d83-868b338e30ab',
  '143fbf55-a151-4458-888f-f45507e2dddb',
  '1c8f4345-0a40-42a6-8406-33ee3996e988',
  '22552521-8c85-495f-a446-faa9128f8657',
  '2445de51-72ca-44ec-967e-19586288241b',
  '27684619-b824-4c48-99fc-08a485b521ac',
  '2787dd6e-99fc-44c0-a9ee-8ce91a753eb8',
  '2d89f62a-4bb3-430b-b711-cd2afe67e6d5',
  '3aa82eb5-ce48-49ee-bd8d-aebb65aad588',
  '3c7567d0-0792-484c-bc59-18faee0cdda9',
  '3fc7db6a-0ce8-49e0-9255-09c7c2dc70fc',
  '45a3e06c-cb4d-4048-8295-be2c8a3c784a',
  '46cc2877-f784-4d28-81dd-87da5e5281f5',
  '49ba51e5-d6cb-4b23-9657-848bdb1a3cdf',
  '4dc83a0c-65ce-4f2b-8bcc-3342abfe4404',
  '4eb83cae-4003-4624-b01a-3168818c3060',
  '4f2fe6b7-ddcb-4ecd-9c1d-
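
`query_db` delegates retrieval to `Chroma.similarity_search(query, k=k)`: the query is embedded with the same model as the documents, and the `k` stored chunks closest to it are returned. Chroma's default distance is squared L2; for unit-length vectors this gives the same ranking as cosine similarity, since ‖a − b‖² = 2 − 2·cos(a, b). The sketch below shows only the ranking step, using made-up 2-D vectors and illustrative ids; real all-MiniLM-L6-v2 embeddings have 384 dimensions, and Chroma uses an approximate HNSW index rather than a full scan.

```python
import heapq


def squared_l2(a: list, b: list) -> float:
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def top_k(query_vec: list, store: list, k: int = 5) -> list:
    """Return the ids of the k stored vectors closest to query_vec, nearest first."""
    nearest = heapq.nsmallest(k, store, key=lambda item: squared_l2(query_vec, item[1]))
    return [doc_id for doc_id, _ in nearest]


# Toy unit-length "embeddings"; ids and values are illustrative only.
store = [
    ("cnn-page", [1.0, 0.0]),
    ("rnn-page", [0.0, 1.0]),
    ("lstm-page", [0.6, 0.8]),
]
print(top_k([0.0, 1.0], store, k=2))  # ['rnn-page', 'lstm-page']
```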

**2. Querying**

In [49]:
START_MSG = """You are a helpful assistant. Your task is to answer questions based only on the context provided.
If certain parts of the context are not relevant to the question, you can ignore them.
If you cannot find the answer within the context, please respond by saying that you can't find the answer in the given context.
Always aim to provide clear and understandable answers. When answering, state the page and the document where you found the information.
"""
API_KEY = os.environ.get("OPENAI_API_KEY", "")  # or paste your key here

In [102]:
@dataclass(kw_only=True)  # kw_only requires Python 3.10+
class TurboRAG(VectorDatabase):
    # kw_only=True lets this required field follow the inherited fields that have defaults.
    # For Python < 3.10, drop kw_only and use: openai_api_key: Optional[str] = None
    openai_api_key: str
    model: str = "gpt-4o-mini"

    def __post_init__(self) -> None:
        super().__post_init__()

        self.chatgpt_client = OpenAI(api_key=self.openai_api_key)
        self.vector_db = self.get_db()

    def get_context(self, query: str, k: int = 5) -> str:
        context = ""
        raw_context = self.query_db(query=query, k=k)

        for res in raw_context:
            page = res.metadata.get("page", None)
            # PDF pages are 0-based; a plain "if page" would turn the first page into None.
            page = page + 1 if page is not None else None

            context += f"PAGE: {page} DOCUMENT: {res.metadata['source'].split('/')[-1]} - {res.page_content}\n"
        return context

    def prepare_prompt(self, query: str, context: str, start_msg: str = START_MSG) -> str:
        return f"{start_msg} \nQuestion: {query} \nContext: {context}"

    def get_answer_from_context(self, prompt: str) -> str:
        chat_completion = self.chatgpt_client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": prompt,
                }
            ],
            model=self.model,
        )

        return chat_completion.choices[0].message.content

    def get_answer(self, question: str, start_msg: str = START_MSG) -> tuple[str, str, str]:
        """Return (answer, context, prompt) for a single question."""
        context = self.get_context(query=question, k=K)
        prompt = self.prepare_prompt(query=question, context=context, start_msg=start_msg)
        answer = self.get_answer_from_context(prompt=prompt)

        return answer, context, prompt

    def chat(self) -> None:
        messages = []
        messages.append({
            "role": "user",
            "content": START_MSG
        })
        chat_completion = self.chatgpt_client.chat.completions.create(
            messages=messages,
            model=self.model
        )
        assistant_response = chat_completion.choices[0].message.content
        messages.append({
            "role": "assistant",
            "content": assistant_response
        })

        print(f"Assistant: {assistant_response}")

        while True:
            user_input = input("You: ")
            if user_input.lower() in ["exit", "quit", "bye"]:
                print("Ending the chat. Goodbye!")
                break

            user_context = self.get_context(query=user_input, k=K)
            user_prompt = self.prepare_prompt(query=user_input, context=user_context, start_msg="")

            messages.append({
                "role": "user",
                "content": user_prompt
            })

            chat_completion = self.chatgpt_client.chat.completions.create(
                messages=messages,
                model=self.model
            )

            assistant_response = chat_completion.choices[0].message.content
            messages.append({
                "role": "assistant",
                "content": assistant_response
            })

            print(f"Assistant: {assistant_response}")
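`get_context` turns each retrieved chunk into a `PAGE: … DOCUMENT: … - text` line, and `prepare_prompt` wraps that block together with the instructions and the question. PyPDFLoader stores a 0-based `page` index in the metadata, while text and Markdown loaders store no page at all, so the conversion must distinguish page `0` from a missing page: a plain truthiness test (`if page`) treats the first page as missing. A stdlib-only sketch of both steps, with a hypothetical `Hit` class standing in for the documents returned by `similarity_search` and made-up example content:

```python
import os
from dataclasses import dataclass, field


@dataclass
class Hit:
    """Hypothetical stand-in for a Document returned by similarity_search."""
    page_content: str
    metadata: dict = field(default_factory=dict)


def format_context(hits: list) -> str:
    """Build the PAGE/DOCUMENT context block in the same layout as TurboRAG.get_context."""
    lines = []
    for hit in hits:
        page = hit.metadata.get("page")
        # 0-based PDF index -> 1-based page number; pageless sources stay None.
        page = page + 1 if page is not None else None
        source = os.path.basename(hit.metadata.get("source", "unknown"))
        lines.append(f"PAGE: {page} DOCUMENT: {source} - {hit.page_content}\n")
    return "".join(lines)


def prepare_prompt(query: str, context: str, start_msg: str) -> str:
    """Same layout as TurboRAG.prepare_prompt."""
    return f"{start_msg} \nQuestion: {query} \nContext: {context}"


hits = [
    Hit("A CNN applies convolutions.", {"source": "/content/data/CNN.pdf", "page": 0}),
    Hit("Project notes.", {"source": "/content/data/README.md"}),
]
context = format_context(hits)
print(context)
# PAGE: 1 DOCUMENT: CNN.pdf - A CNN applies convolutions.
# PAGE: None DOCUMENT: README.md - Project notes.
```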

In [103]:
turbo_rag = TurboRAG(
    openai_api_key=API_KEY,
    db_path=DB_PATH,
    chunk_size=CHUNK_SIZE,
    chunk_overlap=CHUNK_OVERLAP
)

In [99]:
question = "What is cnn?"
answer, context, prompt = turbo_rag.get_answer(question=question)
answer

'CNN stands for Convolutional Neural Network. It is a type of deep learning model that excels in processing data that has a spatial or grid-like organization, such as images, videos, or time series. CNNs achieve this by utilizing convolution operations and apply a series of transformations to learn features from the input data while reducing the number of trainable parameters through weight sharing properties. A basic CNN architecture typically involves layers that perform convolution, non-linear activation, pooling operations, and fully connected layers for classification or regression tasks.\n\nThis information can be found on PAGE 24 of the document "CNN.pdf".'

In [100]:
question = "What are requirements for New Egg GPUs?"
answer, context, prompt = turbo_rag.get_answer(question=question)
answer

'The requirements for New Egg GPUs as mentioned in the context are:\n\n1. **Python 3.9**\n2. **Packages from requirements.txt file**\n\nThis information can be found on page 1 of the document titled "README.md - New Egg GPUs."'

In [89]:
question = "What is rnn?"
answer, context, prompt = turbo_rag.get_answer(question=question)
answer

"An RNN, or Recurrent Neural Network, is defined as a non-linear dynamical system that uses non-linear functions for its operations. In a shallow RNN, the system's output depends on its previous states, which allows it to model time series or sequences. This characteristic distinguishes it from other types of neural networks. \n\nThis information can be found on PAGE: 16 of the document CI-RNN-LSTM.pdf."

In [90]:
question = "what is Deep Feedforward Networks?"
answer, context, prompt = turbo_rag.get_answer(question=question)
answer

'Deep Feedforward Networks are a class of parametric, non-linear, and hierarchical representation models optimized using stochastic gradient descent. They consist of multiple layers where the output of one layer serves as the input for the next in a hierarchical fashion. The network is "deep" because it contains multiple layers, and "feedforward" indicates that connections between the nodes do not form cycles, meaning that the data moves in one direction, from input to output.\n\nA key feature of these networks is the use of non-linear activation functions that allow them to learn complex patterns. Training involves adjusting the model parameters (weights and biases) based on the data provided, enabling the network to learn from its experiences.\n\nThis information is found on page 4 of the document "CNN.pdf - Deep Learning: Basics and CNN".'

In [59]:
question = "Give examples of non-linear functions"
answer, context, prompt = turbo_rag.get_answer(question=question)
answer

'Examples of non-linear functions include:\n\n1. Hyperbolic Tangent Function (tanh)\n2. Sigmoid Function\n3. Rectified Linear Unit (ReLU)\n\nThese functions are significant components of deep neural networks as they help convert linear input signals into non-linear outputs. The information can be found on PAGE: 10 of DOCUMENT: CNN.pdf - Deep Learning: Basics and CNN.'

In [61]:
question = "what is lstm?"
answer, context, prompt = turbo_rag.get_answer(question=question)
answer

'LSTM stands for Long Short-Term Memory, which is a type of Recurrent Neural Network (RNN) designed to handle long-term temporal dependencies. It features a mechanism that enables the network to "remember" relevant information over extended periods. LSTMs consist of four interacting layers and have the capability to learn long-term dependencies. They utilize a cell state that can be manipulated through what are called gates, which control the flow of information. There are three main gates in an LSTM: the input gate, the output gate, and the forget gate, each serving a specific function in managing the information stored in the cell state.\n\nThis information can be found on PAGE 43 and PAGE 36 of the document "CI-RNN-LSTM.pdf."'

In [63]:
question = "who is the author of CI-RNN-LSTM.pdf?"
answer, context, prompt = turbo_rag.get_answer(question=question)
answer

'The author of the document "CI-RNN-LSTM.pdf" is Adrian Horzyk, as mentioned on page 47 of the document.'

In [64]:
question = "explain BPTT"
answer, context, prompt = turbo_rag.get_answer(question=question)
answer

'BPTT, or Backpropagation Through Time, is an adaptation of the traditional backpropagation algorithm that is designed to work with sequential patterns often found in temporal data. It allows for the computation of gradients over sequences by unfolding the recurrent network through time and applying the standard backpropagation technique across these unfolded layers. \n\nThis method enables the network to learn dependencies that span across time steps, making it particularly relevant for tasks involving sequences, such as time series forecasting or natural language processing.\n\nYou can find this information on pages 22, 23, 24, and 27 of the document titled "CI-RNN-LSTM.pdf."'

In [68]:
question = "when do i have my birthday?"
answer, context, prompt = turbo_rag.get_answer(question=question)
answer

'Your birthday is on the 6th of September. This information is found on the section titled "birthdays" in the document randomData.txt.'

**Chat**

In [104]:
turbo_rag.chat()

Assistant: Understood! Please provide the context or document with information that I should refer to in order to answer your questions.
You: what is rnn
Assistant: An RNN, or Recurrent Neural Network, is defined as a non-linear dynamical system that involves functions, where these functions can include non-linear functions such as tanh. The document also mentions the use of additional architectural features like shortcut connections and higher-order states. 

This information is found on PAGE 16 of the document titled "CI-RNN-LSTM.pdf".
You: desribe web app in new gpu's api
Assistant: The web app in the new GPU's API is constructed using Dash and is described as a tool for visualizing data. It operates on port 8050 and serves the purpose of displaying the collected and processed information regarding GPUs scraped from the Newegg website. The main functionality of the web app is to provide a user interface for the data that has been scraped, transformed, and stored in a database, enabl

KeyboardInterrupt: Interrupted by user
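
`chat()` appends every user prompt, including its retrieved context, and every reply to `messages` and never drops anything, so each turn resends the whole history and a long session will eventually exceed the model's context window. One common mitigation, sketched here as a hypothetical `trim_history` helper, keeps the first instruction message and only the most recent user/assistant pairs. It counts messages rather than tokens, which is a simplifying assumption, and it is meant to be called when the history ends with an assistant reply.

```python
def trim_history(messages: list, max_turns: int) -> list:
    """Keep messages[0] (the instructions) plus the last max_turns user/assistant pairs."""
    if max_turns < 0:
        raise ValueError("max_turns must be non-negative")
    head, tail = messages[:1], messages[1:]
    # tail[-0:] would return the whole list, so handle max_turns == 0 explicitly
    keep = tail[-2 * max_turns:] if max_turns else []
    return head + keep


history = [
    {"role": "user", "content": "instructions"},
    {"role": "assistant", "content": "ok"},
]
for i in range(3):
    history.append({"role": "user", "content": f"q{i}"})
    history.append({"role": "assistant", "content": f"a{i}"})

trimmed = trim_history(history, max_turns=2)
print([m["content"] for m in trimmed])  # ['instructions', 'q1', 'a1', 'q2', 'a2']
```

Passing a trimmed copy to `chat.completions.create` while keeping the full `messages` list intact would preserve the complete log locally and bound only what is sent to the model.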