### 📚 Sectral (Mistral-7B_instruct.v0.1) 🚀
**Retrieval augmented generation (RAG)** is a natural language processing (NLP) technique that combines the strengths of retrieval-based and generative artificial intelligence (AI) models. For the task at hand, a RAG built on a relatively lightweight model (`Mistral 7B` in our case) was the simplest and best-suited way to extract insightful information quickly, without extensive data analysis, while remaining robust to changes in data formats, layouts, etc.
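
The retrieve-then-generate idea can be sketched without any model at all. The example below is a toy illustration only: it ranks text chunks with a bag-of-words cosine similarity (the notebook itself uses sentence embeddings and FAISS) and then places the best chunk in front of the question. The function names and the sample snippets are hypothetical.

```python
import math
import re
from collections import Counter


def tokenize(text: str) -> list[str]:
    """Lowercase a string and split it into alphanumeric tokens."""
    return re.findall(r"[a-z0-9]+", text.lower())


def cosine(a: Counter, b: Counter) -> float:
    """Cosine similarity between two bag-of-words vectors."""
    dot = sum(a[t] * b[t] for t in a.keys() & b.keys())
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(
        sum(v * v for v in b.values())
    )
    return dot / norm if norm else 0.0


def retrieve(question: str, chunks: list[str], k: int = 2) -> list[str]:
    """Return the k chunks most similar to the question."""
    q_vec = Counter(tokenize(question))
    ranked = sorted(
        chunks, key=lambda c: cosine(q_vec, Counter(tokenize(c))), reverse=True
    )
    return ranked[:k]


def build_prompt(question: str, context_chunks: list[str]) -> str:
    """Place the retrieved text in front of the question for the generator."""
    context = "\n".join(context_chunks)
    return f"Answer using only this context:\n{context}\n\nQuestion: {question}"


# Hypothetical filing snippets, for illustration only
chunks = [
    "The company designs and sells smartphones and personal computers.",
    "Total net sales increased compared to the prior fiscal year.",
    "The board of directors declared a quarterly dividend.",
]
question = "How did net sales change?"
top_chunks = retrieve(question, chunks, k=1)
prompt = build_prompt(question, top_chunks)
print(prompt)
```

In a real RAG the final prompt would be passed to the language model; here it is only printed.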

This notebook contains the code blocks required by the CLI program, which initializes a RAG at the backend and provides a chat-based interactive interface to the user. The user can **load** data for a particular stock and then ask **questions** about it. The user can also create and visualize important financial metrics via **plots** and get **insights** from the model itself.

Despite my best efforts to build the system entirely locally, so that it could be deployed to a server without a GPU, the model still required a GPU to function even after quantization (reducing the precision of its weights). I therefore moved to an interactive Kaggle notebook with a chat interface that is as intuitive and easy to use as possible, while making sure that the required results are generated.

**Dependency Installation**: The following code block consists of the various libraries and frameworks required for the SectralCLI UI and Interface to function correctly.

In [1]:
!pip install -qU langchain datasets sec-edgar-downloader matplotlib sentence-transformers unstructured faiss-cpu faiss-gpu
!pip install -qU transformers accelerate sentence_transformers faiss-gpu langchain bitsandbytes loralib datasets edgartools torchinfo

!pip install -i https://pypi.org/simple/ bitsandbytes
!pip install -q git+https://github.com/huggingface/transformers.git@main git+https://github.com/huggingface/peft.git

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
cudf 23.8.0 requires cubinlinker, which is not installed.
cudf 23.8.0 requires cupy-cuda11x>=12.0.0, which is not installed.
cudf 23.8.0 requires ptxcompiler, which is not installed.
cuml 23.8.0 requires cupy-cuda11x>=12.0.0, which is not installed.
dask-cudf 23.8.0 requires cupy-cuda11x>=12.0.0, which is not installed.
keras-cv 0.8.2 requires keras-core, which is not installed.
keras-nlp 0.9.3 requires keras-core, which is not installed.
tensorflow-decision-forests 1.8.1 requires wurlitzer, which is not installed.
apache-beam 2.46.0 requires dill<0.3.2,>=0.3.1.1, but you have dill 0.3.8 which is incompatible.
apache-beam 2.46.0 requires numpy<1.25.0,>=1.14.3, but you have numpy 1.26.4 which is incompatible.
apache-beam 2.46.0 requires pyarrow<10.0.0,>=3.0.0, but you have pyarrow 15.0.2 which is incompatible.

### Required Configuration
Load a model without downloading it: we use the Mistral model hosted by Kaggle. My local setup does not include a GPU, which made it almost impossible to run even a quantized version of the Mistral 7B model. Hence, the settings required for this Kaggle notebook to function correctly are:

- Internet: `On` (Toggle)
- Accelerator: `GPU P100`
- Model: mistral 7b-instruct-v0.1-hf/1 (by kaggle)
- Persistence: Files Only

### Components
The entire notebook is divided into the following major sections:
- **SecConfig**: Contains the configurations and other parameters required by the various components
  - Internally uses `dataclass` definitions together with the `transformers` and `langchain` libraries to load the tokenizer, the quantized model and the text-generation pipeline
- **SecPlot**: Contains functions required to plot graphs and financial data related to a particular stock
  - Internally uses the [edgartools](https://github.com/dgunning/edgartools) library to gather facts as pandas dataframes, and `seaborn` / `matplotlib` to create and store the figures locally for reference
- **SecDB**: Contains functions required to create / load the vector database indexes depending on availability
  - Internally uses `FAISS` (Facebook AI Similarity Search) to index the vector embeddings, which are computed with LangChain's `HuggingFaceEmbeddings` wrapper.
- **SecCLI**: Responsible for bringing all the components together. Initializes the model, loads data and manages the interactive session
  - Internally uses all the previously listed components: `SecDB`, `SecPlot` & `SecConfig` to create an interactive user session.

**Code Formatting and Structure**: All the components follow an Object Oriented Programming (OOP) approach in which every member function and its associated parameters are encapsulated within classes. The files have also been formatted using `black` (Python formatter) to follow the `PEP 8` standard.

### SecConfig.py
The following code block contains configuration details for all the components that follow. To enforce a configuration standard, multiple `dataclass`-decorated classes have been created: `SecPlotConfig`, `SecDbConfig` & `SecCliConfig`. The variable names are mostly self-explanatory, with comments added where needed. The code block also contains the default configurations for all three dataclasses, as well as any other data that might be required by the components that follow.
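
As a minimal sketch of the pattern, the example below defines a trimmed-down, hypothetical config dataclass (`DemoDbConfig` is not part of the notebook) and derives a per-ticker variant with `dataclasses.replace`, which returns a new instance instead of mutating the default.

```python
from dataclasses import dataclass, replace


@dataclass
class DemoDbConfig:
    """Hypothetical, trimmed-down stand-in for a config dataclass."""

    stock_ticker: str
    filing_type: str
    model_chunk_size: int


DEFAULT_DEMO_CONFIG = DemoDbConfig(
    stock_ticker="AAPL", filing_type="10-K", model_chunk_size=800
)

# replace() returns a new instance, so the default stays untouched
msft_config = replace(DEFAULT_DEMO_CONFIG, stock_ticker="MSFT")

print(DEFAULT_DEMO_CONFIG.stock_ticker)
print(msft_config.stock_ticker)
```

Keeping defaults immutable in practice like this avoids one session's ticker leaking into the next object built from the same default.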

In [7]:
"""
SecConfig.py / configuration: contains the configuration details of various
different types of classes and objects used in the construction of the RAG
"""
import torch
from dataclasses import dataclass
from transformers import BitsAndBytesConfig
from langchain.llms import HuggingFacePipeline
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# Disable memory-efficient SDP and Flash SDP
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)

@dataclass
class SecPlotConfig:
    """
    Data class for SEC plot configuration.
    """
    df_x: str
    df_val: str
    df_metric: str
    df_filing: str
    df_x_label: str
    filing_type: str
    stock_ticker: str

    df_cols: list[str]
    fin_metrics: list[str]


@dataclass
class SecDbConfig:
    """
    Data class for SEC database configuration.
    """
    end_date: str
    start_date: str
    secd_email: str
    filing_type: str
    secd_company: str
    stock_ticker: str
    embedding_model: str
    model_chunk_size: int
    secd_base_data_dir: str
    model_chunk_overlap: int
    embeddings_cache_dir: str
    force_recalculate_index: bool


@dataclass
class SecCliConfig:
    """
    Data class for SEC CLI configuration.
    """
    model_name: str
    tokenizer: AutoTokenizer
    bnb_config: BitsAndBytesConfig
    model: AutoModelForCausalLM
    pipe: pipeline
    hf_pipe: HuggingFacePipeline



# Default SEC plot configuration
DEFAULT_SEC_PLOT_CONFIG = SecPlotConfig(
    df_x="fy",
    df_val="val",
    df_metric="fact",
    df_filing="form",
    df_x_label="Financial Year (FY)",
    filing_type="10-K",
    stock_ticker="AAPL",
    df_cols=["end", "fy", "fp", "filed", "val"],
    fin_metrics=[
        "Cash",
        "Assets",
        "NetIncomeLoss",
        "SalesRevenueNet",
        "StockholdersEquity",
        "InventoryNet",
        "OperatingIncomeLoss",
        "LongTermDebt",
        "Liabilities",
        "EarningsPerShareBasic",
    ],
)

# Default SEC database configuration
DEFAULT_SEC_DB_CONFIG = SecDbConfig(
    end_date="2023-12-31",
    start_date="1995-01-01",
    secd_email="anshsarkar@gmail.com",
    filing_type="10-K",
    secd_company="NA",
    stock_ticker="AAPL",
    embedding_model="sentence-transformers/all-MiniLM-l6-v2",
    model_chunk_size=800,
    secd_base_data_dir="data/",
    model_chunk_overlap=150,
    embeddings_cache_dir="embeddings/",
    force_recalculate_index=False,
)

# Default model name
DEFAULT_MODEL_NAME = "/kaggle/input/mistral/pytorch/7b-instruct-v0.1-hf/1"

# Default BitsAndBytesConfig for Model Quantization
DEFAULT_BNB_CONFIG = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

# Default tokenizer
DEFAULT_TOKENIZER = AutoTokenizer.from_pretrained(DEFAULT_MODEL_NAME)

# Default loaded model
DEFAULT_LOADED_MODEL = AutoModelForCausalLM.from_pretrained(
    DEFAULT_MODEL_NAME,
    quantization_config=DEFAULT_BNB_CONFIG,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

# Default pipeline
DEFAULT_PIPE = pipeline(
    "text-generation",
    model=DEFAULT_LOADED_MODEL,
    tokenizer=DEFAULT_TOKENIZER,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    max_new_tokens=1024,
)

# Default SEC CLI configuration
DEFAULT_SEC_CLI_CONFIG = SecCliConfig(
    model_name=DEFAULT_MODEL_NAME,
    tokenizer=DEFAULT_TOKENIZER,
    bnb_config=DEFAULT_BNB_CONFIG,
    model=DEFAULT_LOADED_MODEL,
    pipe=DEFAULT_PIPE,
    hf_pipe=HuggingFacePipeline(pipeline=DEFAULT_PIPE),
)

# Builtin prompts dictionary
# The insights prompt is written keeping in mind some of the most crucial financial insights
# which can be obtained from the 10-K filings of a particular stock.
# Each fragment ends with a space so that sentences do not run together when concatenated.
builtin_prompts = {
    "insights": "Please provide a detailed financial analysis of a company based on its recent 10-K filings. "
    + "Your analysis should cover the following financial aspects: "
    + "1. Revenue Trends: Describe the company's revenue trends over the past few "
    + "years and discuss any significant changes or patterns. "
    + "2. Profitability Analysis: Evaluate the company's profitability by analyzing metrics "
    + "such as gross profit margin, operating profit margin, and net profit margin. Discuss "
    + "any factors influencing profitability. "
    + "3. Cash Flow Analysis: Assess the company's cash flow from operating activities, investing "
    + "activities, and financing activities. Comment on the company's ability to generate cash "
    + "and manage its cash flow effectively. "
    + "4. Debt Levels and Obligations: Review the company's debt levels, including long-term debt, current "
    + "debt, and debt repayment schedules. Discuss the impact of debt on the company's financial position. "
    + "5. Capital Expenditures: Analyze the company's capital expenditures and investment activities. Comment "
    + "on the company's investment strategy and its implications for future growth. "
    + "6. Return on Investment: Calculate and discuss metrics such as return on equity (ROE) and return on "
    + "assets (ROA) to evaluate the company's efficiency in generating returns for shareholders."
}

# Exports
__all__ = [
    "SecDbConfig",
    "SecCliConfig",
    "SecPlotConfig",
    "builtin_prompts",
    "DEFAULT_SEC_DB_CONFIG",
    "DEFAULT_SEC_CLI_CONFIG",
    "DEFAULT_SEC_PLOT_CONFIG",
]
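
To give a feel for why 4-bit quantization matters here, the sketch below estimates the memory taken by the weights alone at different precisions. It assumes roughly 7 billion parameters (inferred from the model name, not measured) and ignores activations, the KV cache and the quantization constants, so real usage will be higher.

```python
def weight_memory_gib(num_params: float, bits_per_weight: float) -> float:
    """Approximate memory needed for the weights alone, in GiB."""
    return num_params * bits_per_weight / 8 / 1024**3


# Assumption: roughly 7 billion parameters, as the model name suggests
NUM_PARAMS = 7e9

for label, bits in [("float32", 32), ("bfloat16", 16), ("4-bit", 4)]:
    print(f"{label:>8}: {weight_memory_gib(NUM_PARAMS, bits):.1f} GiB")
```

Under these assumptions the 4-bit weights need about a quarter of the memory of the 16-bit ones, which is what lets the model fit on a single Kaggle GPU.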


### SecPlot.py
The following code block defines the `SecPlot` class, which plots the financial metrics relevant to a given stock. A `SecPlot` object is initialized with the `SecPlotConfig` dataclass defined in the `SecConfig` code block above. The variable names are mostly self-explanatory, with comments added where needed. The class contains internal helper functions as well as public functions that expose its functionality.

In [8]:
"""
SecPlot.py contains all the necessary classes and functions required to create,
plot and store multiple graphs related to facts / metrics of a given stock ticker
"""
import re
import os
import pandas as pd
import edgar as SEC
import seaborn as sns
from datetime import datetime
import matplotlib.pyplot as plt

# imports from the config file are required only when the script is being
# run in a local environment. they are not required in an interactive notebook
# as we have already loaded the configuration above
# from SecConfig import SecPlotConfig
# from SecConfig import DEFAULT_SEC_PLOT_CONFIG


class SecPlot:
    """
    Class for plotting financial data fetched from SEC filings.
    """

    def __init__(self, config: SecPlotConfig):
        """
        Initializes SecPlot instance.

        Args:
            config (SecPlotConfig): Configuration object for SEC plot.
        """
        self._config = config
        self._categories: list[str] = []
        self._plot_data: dict[str, pd.DataFrame] = {}

    def __str__(self):
        """
        String representation of SecPlot instance.

        Returns:
            str: String representation.
        """
        return "SecPlot(" + str(self._config) + ")"

    def __repr__(self):
        """
        String representation of SecPlot instance.

        Returns:
            str: String representation.
        """
        return "SecPlot(" + str(self._config) + ")"

    def _sec_plot_print(self, string, end="\n"):
        """
        Helper method to print formatted messages related to SecPlot.

        Args:
            string (str): Message to print.
            end (str, optional): Ending character for the print statement. Defaults to "\n".
        """
        print("\033[92m" + "SecPlot > " + str(string) + "\033[0m", end=end)

    def _map_fact_to_label(self, fact_name: str):
        """
        Maps a fact name to a readable label.

        Args:
            fact_name (str): Name of the fact to map.

        Returns:
            str: Readable label for the fact.
        """
        # Split the string at capital letters
        words = re.findall("[A-Z][^A-Z]*", fact_name)
        # Capitalize each word and join them with spaces
        words = [word.capitalize() for word in words]
        readable_name = " ".join(words)
        return readable_name

    def fetch_data(self):
        """
        Fetches financial data from SEC filings.
        """
        SEC.set_identity("A. Sarkar anshsarkar@gmail.com")
        company = SEC.Company(self._config.stock_ticker)
        df = company.get_facts().to_pandas()
        df = df[df[self._config.df_filing] == self._config.filing_type]

        unique_fact_categories = list(df[self._config.df_metric].unique())

        for metric in self._config.fin_metrics:
            if metric in unique_fact_categories:
                self._categories.append(metric)

        for metric in self._categories:
            self._plot_data[metric] = df[df[self._config.df_metric] == metric][
                self._config.df_cols
            ]

    def plot_data(self, save=True, display=False):
        """
        Plots the fetched financial data.

        Args:
            save (bool, optional): Whether to save the plots. Defaults to True.
            display (bool, optional): Whether to display the plots. Defaults to False.

        Returns:
            list: List of paths to saved plot images.
        """
        self._sec_plot_print(
            "The plots may be scaled. Refer to top left corner for scale info"
        )

        sns.set_theme()
        fig_paths = []
        for metric in self._categories:
            record_count = self._plot_data[metric][self._config.df_val].count()
            print(f"Total records for {metric} = {record_count}")
            if record_count < 10:
                print("Skipping Due to Insufficient Data")
            else:
                sns.lineplot(
                    data=self._plot_data[metric],
                    x=self._config.df_x,
                    y=self._config.df_val,
                )
                plt.xlabel(self._config.df_x_label)
                plt.ylabel(self._map_fact_to_label(metric))
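                if save:
                    plot_dir = os.path.join(
                        os.getcwd(), "plots", self._config.stock_ticker
                    )
                    os.makedirs(plot_dir, exist_ok=True)
                    # include the metric name so each file is identifiable and
                    # use a timestamp without spaces or colons
                    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
                    fig_save_path = os.path.join(plot_dir, f"{metric}_{timestamp}.png")
                    plt.savefig(fig_save_path)
                    fig_paths.append(fig_save_path)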
                if display:
                    plt.show()
                plt.cla()
                plt.clf()
                plt.close()
        return fig_paths

# Example Usage. Can be used to make sure that the class is functioning as intended
# However, it has been commented out in order to enable the faster initialization of the notebook
# if __name__ == "__main__":
#     sec_plot = SecPlot(config=DEFAULT_SEC_PLOT_CONFIG)
#     sec_plot.fetch_data()
#     sec_plot.plot_data()

### SecDB.py
The following code block defines the `SecDB` class, which creates or loads the vector index for a given stock. A `SecDB` object is initialized with the `SecDbConfig` dataclass defined in the `SecConfig` code block above. The variable names are mostly self-explanatory, with comments added where needed. The class contains internal helper functions as well as public functions that expose its functionality.

In [12]:
"""
SecDB.py contains all the necessary classes and functions required to create
or load a vector database with respect to a given stock ticker
"""
import os
import shutil
from tqdm import tqdm
from langchain.vectorstores import FAISS
from sec_edgar_downloader import Downloader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import UnstructuredHTMLLoader

# imports from the config file are required only when the script is being
# run in a local environment. they are not required in an interactive notebook
# as we have already loaded the configuration above
# from SecConfig import SecDbConfig
# from SecConfig import DEFAULT_SEC_DB_CONFIG


class SecDB:
    """
    Class representing a SEC Database.

    Attributes:
        config (SecDbConfig): Configuration for the SEC Database.
    """

    def __init__(self, config: SecDbConfig):
        """Initialize the SecDB object with given configuration."""
        self._config = config
        self._vector_db = None
        self._config.stock_ticker = self._config.stock_ticker.upper()

    def __str__(self):
        """Return string representation of the SecDB object."""
        return "SecDB(" + str(self._config) + ")"

    def __repr__(self):
        """Return string representation of the SecDB object."""
        return "SecDB(" + str(self._config) + ")"

    def _sec_db_print(self, string, end="\n"):
        """Helper function to print messages related to SecDB operations."""
        print("\033[92m" + "SecDB > " + str(string) + "\033[0m", end=end)

    def get_vector_db(self):
        """Get the vector database."""
        return self._vector_db

    def _fetch_data(self):
        """
        Fetch data from SEC Edgar website based on configuration.

        This function downloads filings for the specified company and filing type
        within the specified date range.
        """
        sec_downloader = Downloader(
            self._config.secd_company,
            self._config.secd_email,
            self._config.secd_base_data_dir,
        )
        self._sec_db_print(
            "(Downloading) STOCK TICKER: "
            + self._config.stock_ticker
            + " FILING TYPE: "
            + self._config.filing_type
        )
        sec_downloader.get(
            self._config.filing_type,
            self._config.stock_ticker,
            download_details=True,
            after=self._config.start_date,
            before=self._config.end_date,
        )

    def _get_stock_filing_data_files(self):
        """
        Get paths of stock filing data files.

        Returns:
            list: List of file paths for stock filing data files.
        """
        filings_directory = os.path.join(
            self._config.secd_base_data_dir,
            "sec-edgar-filings",
            self._config.stock_ticker,
            self._config.filing_type,
        )

        stock_filing_dirs = [
            os.path.join(filings_directory, f)
            for f in os.listdir(filings_directory)
            if os.path.isdir(os.path.join(filings_directory, f))
        ]

        stock_filing_data_files = []
        for stock_filing_dir in stock_filing_dirs:
            for stock_data_file in os.listdir(stock_filing_dir):
                stock_data_file_path = os.path.join(stock_filing_dir, stock_data_file)

                if os.path.isfile(stock_data_file_path) and (
                    stock_data_file.endswith(".html")
                    or stock_data_file.endswith(".htm")
                ):
                    stock_filing_data_files.append(stock_data_file_path)

        return stock_filing_data_files

    def _load_existing(self, embeddings):
        """
        Load existing FAISS index if available.

        Args:
            embeddings: Embeddings to use for indexing.
        """
        if os.path.exists(f"./indexstore/{self._config.stock_ticker}/faiss_index"):
            self._sec_db_print(
                f"Found existing FAISS Index for {self._config.stock_ticker}"
            )
            self._sec_db_print(
                f"FORCE_RECALCULATE_INDEX set to : {self._config.force_recalculate_index}"
            )

            # if we do not want the vector db indices to be recalculated then we
            # load the existing index file available locally.
            if not self._config.force_recalculate_index:
                self._vector_db = FAISS.load_local(
                    f"indexstore/{self._config.stock_ticker}/faiss_index",
                    embeddings,
                    # the index is a locally created pickle file, so explicitly
                    # opt in to deserializing it
                    allow_dangerous_deserialization=True,
                )

    def _create_vector_db(self, embeddings, stock_filing_data_files, text_splitter):
        """
        Create vector database from filings data.

        Args:
            embeddings: Embeddings to use for creating vector database.
            stock_filing_data_files (list): List of paths for stock filing data files.
            text_splitter: Text splitter object for splitting documents.

        """
        if self._config.force_recalculate_index or (self._vector_db is None):
            for data_file in tqdm(stock_filing_data_files):
                self._sec_db_print(
                    f"Creating Embeddings & Vector Indices for {data_file}"
                )
                # Loading the HTML file using the Unstructured HTML Loader by LangChain
                # allows efficient structuring of tables and other HTML structures / elements
                loader = UnstructuredHTMLLoader(data_file)
                documents = loader.load()

                # split the document into small chunks
                docs = text_splitter.split_documents(documents)

                # generating the vector database for the document embeddings
                vector_db = FAISS.from_documents(docs, embeddings)
                if self._vector_db is None:
                    self._vector_db = vector_db
                else:
                    self._vector_db.merge_from(vector_db)

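                # keep a per-document copy of the index alongside the merged one
                vector_db.save_local(
                    f"indexstore/{os.path.splitext(data_file)[0]}_faiss_index"
                )

            # nothing to save if no filing documents were found
            if self._vector_db is not None:
                self._vector_db.save_local(
                    f"indexstore/{self._config.stock_ticker}/faiss_index"
                )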

    def _cleanup(self):
        """Clean up the data directory."""
        # cleaning up the data directory since it is no longer required
        if os.path.exists(self._config.secd_base_data_dir):
            shutil.rmtree(self._config.secd_base_data_dir)
        self._sec_db_print(f"Cleanup -> {self._config.secd_base_data_dir}")

    def get_database(self):
        """Get the vector database."""
        return self._vector_db

    def init_database(self):
        """Initialize the database."""
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self._config.model_chunk_size,
            chunk_overlap=self._config.model_chunk_overlap,
        )
        embeddings = HuggingFaceEmbeddings(
            model_name=self._config.embedding_model,
            cache_folder=self._config.embeddings_cache_dir,
        )

        self._load_existing(embeddings)
        if (self._vector_db is not None) and not self._config.force_recalculate_index:
            return self._vector_db

        self._fetch_data()
        stock_filing_data_files = self._get_stock_filing_data_files()

        self._vector_db = None  # FAISS In Memory Loaded (IMP) Vector Database

        self._create_vector_db(embeddings, stock_filing_data_files, text_splitter)
        self._cleanup()

        if self._vector_db is None:
            self._sec_db_print(
                "!! An Error Occurred While Initializing the Vector Database"
            )
        return self._vector_db

# Example Usage. Can be used to make sure that the class is functioning as intended
# However, it has been commented out in order to enable the faster initialization of the notebook
# if __name__ == "__main__":
#     sec_db = SecDB(config=DEFAULT_SEC_DB_CONFIG)
#     db = sec_db.init_database()
#     print(db.similarity_search("Products & Services"))
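
The roles of `model_chunk_size` and `model_chunk_overlap` can be illustrated with a plain sliding window over characters. This is a simplified sketch, not how `RecursiveCharacterTextSplitter` works internally: the LangChain splitter prefers to cut at separators such as paragraph breaks, so its chunks vary in length. The function name `split_with_overlap` is hypothetical.

```python
def split_with_overlap(text: str, chunk_size: int, chunk_overlap: int) -> list[str]:
    """Cut text into fixed-size character windows that share an overlap."""
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size")
    step = chunk_size - chunk_overlap
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start : start + chunk_size])
        # stop once the current window reaches the end of the text
        if start + chunk_size >= len(text):
            break
    return chunks


text = "".join(str(i % 10) for i in range(2000))  # 2000 placeholder characters
chunks = split_with_overlap(text, chunk_size=800, chunk_overlap=150)
print([len(c) for c in chunks])
```

The overlap means the end of one chunk is repeated at the start of the next, so a sentence that straddles a boundary is still seen whole in at least one chunk.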

### SecCLI.py
The following code block defines the `SecCLI` class, which creates and manages interactive user sessions in the notebook. A `SecCLI` object is initialized with the `SecCliConfig` dataclass defined in the `SecConfig` code block above. The variable names are mostly self-explanatory, with comments added where needed. The class contains internal helper functions as well as public functions that expose its functionality.

In [13]:
"""
SecCLI.py contains all the necessary classes and functions required to create,
and manage interactive sessions related to the 10-K filings of a given stock ticker
"""
from langchain.chains import RetrievalQA
from langchain.prompts.prompt import PromptTemplate

# imports from the config file are required only when the script is being
# run in a local environment. they are not required in an interactive notebook
# as we have already loaded the configuration above
# from SecDB import SecDB
# from SecPlot import SecPlot
# from SecConfig import DEFAULT_SEC_DB_CONFIG
# from SecConfig import DEFAULT_SEC_CLI_CONFIG
# from SecConfig import DEFAULT_SEC_PLOT_CONFIG
# from SecConfig import SecCliConfig, builtin_prompts


class SecCLI:
    """
    Class representing a Command Line Interface for SEC Database operations.

    Attributes:
        config (SecCliConfig): Configuration for the CLI.
        sec_db (SecDB): Instance of the SEC Database.
    """
    def __init__(self, config: SecCliConfig):
        """Initialize the SecCLI object with given configuration."""
        self._config = config
        self._sec_db = None

    def _get_prompt_template(self):
        """Create and return the prompt template."""
        input_template = (
            "<s>"
            + "[INST] Answer the question based only on the following context: [/INST] "
            + "{context}"
            + " </s>"
            + "[INST] Question: {question} [/INST]"
        )

        return PromptTemplate(
            template=input_template, input_variables=["context", "question"]
        )

    def _parse_raw_response(self, raw_response):
        """
        Parse raw response from the model.

        Args:
            raw_response (str): Raw response string.

        Returns:
            str: Parsed response.
        """
        raw_response = raw_response.replace("<s>", "")
        raw_response = raw_response.replace("</s>", "")
        raw_response = raw_response.replace("[INST]", "")
        raw_response = raw_response.replace("[/INST]", "")
        raw_response = raw_response.split("\n")

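        final_answer = None
        for idx, val in enumerate(raw_response):
            if "Question:" in val:
                final_answer = raw_response[idx:]

        # if the echoed prompt is missing, return the cleaned text as-is
        if final_answer is None:
            return "\n".join(line for line in raw_response if line)
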
        response = ""
        for idx, _ in enumerate(final_answer):
            if idx == 1:
                response = response + "\n\n[Answer]" + final_answer[idx]
            elif len(final_answer[idx]) > 0:
                response = response + final_answer[idx] + "\n"
        return response

    def _help_menu(self):
        """Generate and return the help menu."""
        return (
            "\n"
            + "\\help : Displays all the available commands.\n"
            + "\\load : Load stock embeddings for Q&A on SEC filings.\n"
            + "\\plots : Plot line graphs for key financial metrics from SEC filings.\n"
            + "\\insights : Gives complete insight + plots (executes built in prompts).\n"
            + "\\chat : Ask the RAG questions regarding loaded stock data.\n"
        )

    def _validate_args(self, command, msg_tokens):
        """Validate the arguments for a given command."""
        command_syntax = "Command Syntax: "
        if command == "load":
            return (len(msg_tokens) == 2, command_syntax + "\\load" + " <STOCK_TICKER>")
        elif command == "plots":
            return (
                len(msg_tokens) == 2,
                command_syntax + "\\plots" + " <STOCK_TICKER>",
            )
        elif command == "chat":
            return (len(msg_tokens) >= 2, command_syntax + "\\chat" + " <PROMPT>")
        elif command == "insights":
            return (len(msg_tokens) == 1, command_syntax + "\\insights")
        elif command == "help":
            return (len(msg_tokens) == 1, command_syntax + "\\help")

    def _sectral_log(self, content):
        """
        Log content with a prefix.

        Args:
            content (str): Content to log.

        Returns:
            str: Logged content.
        """
        return f"""Sectral > {content}"""

    def _chat_handler(self, message):
        """Handle chat messages and generate relevant responses."""
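        response = None
        msg_tokens = message.split()
        # an empty prompt has no command to dispatch on
        if not msg_tokens:
            return self._sectral_log(self._help_menu())
        command = msg_tokens[0]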

        if command == "\\help":
            valid = self._validate_args("help", msg_tokens)
            if valid[0]:
                response = self._sectral_log(self._help_menu())
            else:
                response = self._sectral_log(valid[1])

        elif command == "\\load":
            valid = self._validate_args("load", msg_tokens)
            if valid[0]:
                try:
                    CUSTOM_DB_CONFIG = DEFAULT_SEC_DB_CONFIG
                    CUSTOM_DB_CONFIG.stock_ticker = msg_tokens[1]
                    self._sec_db = SecDB(config=CUSTOM_DB_CONFIG)
                    _ = self._sec_db.init_database()
                    response = (
                        "Loaded Vector Database & Embeddings for "
                        + CUSTOM_DB_CONFIG.stock_ticker
                    )
                    response = self._sectral_log(response)

                except Exception as e:
                    response = self._sectral_log(f"An Error Occurred: {str(e)}")
            else:
                response = self._sectral_log(valid[1])

        elif command == "\\plots":
            valid = self._validate_args("plots", msg_tokens)
            if valid[0]:
                try:
                    CUSTOM_SEC_PLOT_CONFIG = DEFAULT_SEC_PLOT_CONFIG
                    CUSTOM_SEC_PLOT_CONFIG.stock_ticker = msg_tokens[1]
                    sec_plot = SecPlot(config=CUSTOM_SEC_PLOT_CONFIG)
                    sec_plot.fetch_data()
                    figure_paths = sec_plot.plot_data()

                    response = (
                        f"Plotted a total of {len(figure_paths)} figures on various financial metrics and their "
                        + "progression over the years. Note that all the following figures are w.r.t. the stock ticker "
                        + f"provided : {CUSTOM_SEC_PLOT_CONFIG.stock_ticker}\n"
                    )

                    for figure_path in figure_paths:
                        response += f"{figure_path}\n"
                    response = self._sectral_log(response)

                except Exception as e:
                    response = self._sectral_log(f"An Error Occurred: {str(e)}")
            else:
                response = self._sectral_log(valid[1])

        elif command == "\\insights":
            valid = self._validate_args("insights", msg_tokens)
            if valid[0]:
                qa_chain = RetrievalQA.from_chain_type(
                    llm=self._config.hf_pipe,
                    retriever=self._sec_db.get_vector_db().as_retriever(
                        search_kwargs={"k": 2}
                    ),  # top 2 results only, speed things up
                    return_source_documents=True,
                    chain_type_kwargs={"prompt": self._get_prompt_template()},
                )
                raw_response = qa_chain.invoke(builtin_prompts["insights"])
                processed_response = self._parse_raw_response(raw_response["result"])
                response = self._sectral_log(processed_response)
            else:
                response = self._sectral_log(valid[1])

        elif command == "\\chat":
            valid = self._validate_args("chat", msg_tokens)
            if valid[0]:
                qa_chain = RetrievalQA.from_chain_type(
                    llm=self._config.hf_pipe,
                    retriever=self._sec_db.get_vector_db().as_retriever(
                        search_kwargs={"k": 2}
                    ),  # top 2 results only, speed things up
                    return_source_documents=True,
                    chain_type_kwargs={"prompt": self._get_prompt_template()},
                )
                raw_response = qa_chain.invoke(" ".join(msg_tokens[1:]))
                processed_response = self._parse_raw_response(raw_response["result"])
                response = self._sectral_log(processed_response)
            else:
                response = self._sectral_log(valid[1])

        else:
            response = f"{command} is not a supported command.\n\n"
            response = self._sectral_log(response + self._help_menu())

        return response

    def launch(self):
        """Launch the command line interface."""
        print("📚 Sectral (Mistral-7B_instruct.v0.1) 🚀")
        while True:
            print("\n               \n")
            message = input("Prompt > ")
            try:
                if message == "\\quit":
                    print("Exiting Interactive Session . . .")
                    break
                else:
                    response = self._chat_handler(message)
                    print(response)
            except Exception as error:
                print(f"An Error Occurred While Handling the Request : {error}")
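
How the prompt template is filled, and why the response needs post-processing, can be shown without loading the model. The sketch below copies the template string from `_get_prompt_template` and fills it with plain `str.format`. It assumes the generation pipeline returns the prompt followed by the completion. `strip_prompt_echo` is a hypothetical, simpler alternative to `_parse_raw_response` that keeps only the text after the last `[/INST]` marker; the context and answer strings are made up.

```python
TEMPLATE = (
    "<s>"
    "[INST] Answer the question based only on the following context: [/INST] "
    "{context}"
    " </s>"
    "[INST] Question: {question} [/INST]"
)


def fill_template(context: str, question: str) -> str:
    """Substitute the retrieved context and the user question."""
    return TEMPLATE.format(context=context, question=question)


def strip_prompt_echo(generated: str) -> str:
    """Keep only the text that follows the final [/INST] marker."""
    return generated.rsplit("[/INST]", 1)[-1].strip()


prompt = fill_template(
    context="Net sales rose year over year.",  # hypothetical retrieved chunk
    question="How did net sales change?",
)
# Hypothetical model output: the prompt echoed back, followed by the answer
generated = prompt + " Net sales increased."
print(strip_prompt_echo(generated))
```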

In [None]:
# creating and launching the interactive session
if __name__ == "__main__":
    cli = SecCLI(config=DEFAULT_SEC_CLI_CONFIG)
    cli.launch()
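
The command syntax enforced by `_validate_args` can also be expressed as a lookup table, which makes the rules easy to test without starting an interactive session. This is a sketch of an alternative design rather than the notebook's implementation; `COMMANDS` and `check_command` are hypothetical names.

```python
# Hypothetical table: command name -> (predicate on token count, usage string)
COMMANDS = {
    "\\help": (lambda n: n == 1, "\\help"),
    "\\load": (lambda n: n == 2, "\\load <STOCK_TICKER>"),
    "\\plots": (lambda n: n == 2, "\\plots <STOCK_TICKER>"),
    "\\insights": (lambda n: n == 1, "\\insights"),
    "\\chat": (lambda n: n >= 2, "\\chat <PROMPT>"),
}


def check_command(message: str) -> tuple[bool, str]:
    """Return (is_valid, feedback) for one line of user input."""
    tokens = message.split()
    if not tokens:
        return False, "Empty input"
    command = tokens[0]
    if command not in COMMANDS:
        return False, f"{command} is not a supported command."
    arity_ok, usage = COMMANDS[command]
    if not arity_ok(len(tokens)):
        return False, f"Command Syntax: {usage}"
    return True, command


print(check_command("\\load AAPL"))
print(check_command("\\load"))
print(check_command("\\chat What are the main risk factors?"))
```

A typical session with the actual CLI would follow the same order: `\load <STOCK_TICKER>` first, then `\chat <PROMPT>` or `\insights`, and `\quit` to exit.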