In [None]:
import base64
import json
import os

import langchain
# from langchain.document_loaders.csv_loader import CSVLoader
from langchain.chains import RetrievalQA
from langchain.chains.query_constructor.schema import AttributeInfo
from langchain.chat_models import ChatOpenAI
from langchain.indexes import VectorstoreIndexCreator
from langchain.prompts import PromptTemplate
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain.text_splitter import RecursiveCharacterTextSplitter

In [None]:
# Turn on LangChain's global debug mode so chain and LLM calls are logged
# verbosely in the cell output. This is useful for inspecting the self-query
# steps below, but noisy.
langchain.debug = True

In [None]:
# Load the OpenAI credentials from a local JSON file and export them as the
# OPENAI_API_KEY / OPENAI_ORGANIZATION environment variables. The values in the
# file are base64-encoded. That is obfuscation only, not encryption.
# json.load is used instead of eval(): eval would execute any Python expression
# found in the file. The handle is not called `secrets`, so it does not shadow
# the stdlib module of that name.
with open("../data/secrets.json") as secrets_file:
    secrets_dict = json.load(secrets_file)

open_api_key = base64.b64decode(secrets_dict["openai_api_key"]).decode("ascii")
os.environ["OPENAI_API_KEY"] = open_api_key
# The organization id is optional.
if "organization_id" in secrets_dict:
    openai_organization = base64.b64decode(secrets_dict["organization_id"]).decode("ascii")
    os.environ["OPENAI_ORGANIZATION"] = openai_organization

In [None]:
import csv
from typing import Dict, List, Optional

from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader


class CSVLoader(BaseLoader):
    """Loads a CSV file into a list of documents.

    Each document represents one row of the CSV file. Every column of the row
    is rendered as a ``key: value`` pair on its own line of the document's
    ``page_content``. Keys and values are stripped of surrounding whitespace.

    Metadata: only the columns listed in ``metadata_columns_dtypes`` are copied
    into each document's metadata. Columns declared as ``"int"`` or ``"float"``
    are converted with ``int()`` / ``float()``; any other dtype keeps the raw
    string. No ``source`` or ``row`` entries are added to the metadata.

    ``source_column`` is only validated. If it is set, every row must contain
    that column, otherwise ``ValueError`` is raised. Its value is not stored.

    Output Example:
        .. code-block:: txt

            column1: value1
            column2: value2
            column3: value3
    """

    # dtype names in ``metadata_columns_dtypes`` that trigger a numeric
    # conversion. This replaces the former ``eval()`` on raw CSV text, which
    # would execute arbitrary expressions present in the file.
    _DTYPE_CONVERTERS = {"int": int, "float": float}

    def __init__(
        self,
        file_path: str,
        source_column: Optional[str] = None,
        metadata_columns_dtypes: Optional[Dict[str, str]] = None,
        csv_args: Optional[Dict] = None,
        encoding: Optional[str] = None,
    ):
        """

        Args:
            file_path: The path to the CSV file.
            source_column: Name of a column that every row must contain.
              It is only validated; its value is not stored in the metadata.
              Optional. Defaults to None.
            metadata_columns_dtypes: Name of column as keys and data type as values.
              ``"int"`` and ``"float"`` columns are converted to numbers.
              Optional. Defaults to None.
            csv_args: A dictionary of arguments to pass to the csv.DictReader.
              Optional. Defaults to None.
            encoding: The encoding of the CSV file. Optional. Defaults to None.
        """
        self.file_path = file_path
        self.source_column = source_column
        self.metadata_columns_dtypes = metadata_columns_dtypes
        self.encoding = encoding
        self.csv_args = csv_args or {}

    @staticmethod
    def _format_field(key, value) -> str:
        """Render one ``key: value`` line of ``page_content``.

        ``csv.DictReader`` stores the surplus fields of an over-long row as a
        list under ``restkey`` (``None`` unless overridden via ``csv_args``).
        It pads a short row with ``restval`` (``None`` by default). Calling
        ``.strip()`` on those would raise, so they are normalised here.
        """
        key_text = key.strip() if isinstance(key, str) else str(key)
        if isinstance(value, list):
            value_text = ",".join(str(item).strip() for item in value)
        elif value is None:
            value_text = ""
        else:
            value_text = str(value).strip()
        return f"{key_text}: {value_text}"

    def _extract_metadata(self, row: Dict, row_number: int) -> Dict:
        """Collect the ``metadata_columns_dtypes`` columns of one CSV row.

        Args:
            row: One row as produced by ``csv.DictReader``.
            row_number: 0-based index of the data row, used in error messages.

        Returns:
            Mapping column -> value, numeric for ``"int"``/``"float"`` columns,
            in the CSV's column order.

        Raises:
            ValueError: if a numeric column holds a value that cannot be parsed,
                e.g. an empty cell or a field missing from a short row.
        """
        metadata = {}
        if not self.metadata_columns_dtypes:
            return metadata
        for column, raw_value in row.items():
            if column not in self.metadata_columns_dtypes:
                continue
            dtype = self.metadata_columns_dtypes[column]
            converter = self._DTYPE_CONVERTERS.get(dtype)
            if converter is None:
                # Non-numeric dtype: keep the raw string unchanged.
                metadata[column] = raw_value
                continue
            try:
                metadata[column] = converter(raw_value)
            except (TypeError, ValueError) as exc:
                raise ValueError(
                    f"Cannot convert value {raw_value!r} in column '{column}' "
                    f"(data row {row_number}) to {dtype}."
                ) from exc
        return metadata

    def load(self) -> List[Document]:
        """Load data into document objects.

        Returns:
            One ``Document`` per CSV data row, in file order.

        Raises:
            ValueError: if ``source_column`` is set but missing from a row, or a
                value in an ``"int"``/``"float"`` metadata column cannot be parsed.
        """
        docs = []
        with open(self.file_path, newline="", encoding=self.encoding) as csvfile:
            csv_reader = csv.DictReader(csvfile, **self.csv_args)  # type: ignore
            for row_number, row in enumerate(csv_reader):
                if self.source_column is not None and self.source_column not in row:
                    raise ValueError(
                        f"Source column '{self.source_column}' not found in CSV file."
                    )
                content = "\n".join(
                    self._format_field(key, value) for key, value in row.items()
                )
                metadata = self._extract_metadata(row, row_number)
                docs.append(Document(page_content=content, metadata=metadata))

        return docs

In [None]:
# Place your CSV in the data folder and load it here.
# vote_average is declared as a float so that it is stored as a number in each
# document's metadata and can be used in numeric filters.
loader = CSVLoader(
    '../data/movies_title_overview_vote.csv',
    metadata_columns_dtypes={"vote_average": "float"},
)

In [None]:
# Large chunk size with no overlap. Presumably each CSV row (title + overview +
# vote) fits in a single chunk, so rows are not split. Verify for long overviews.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=5000, chunk_overlap=0)

In [None]:
# Build the movie vector index: load the CSV rows through `loader`, split them
# with `text_splitter` and index them using the creator's default settings.
# `docsearch.vectorstore` is reused by the self-query retriever below.
index_creator = VectorstoreIndexCreator(text_splitter=text_splitter)
docsearch = index_creator.from_loaders([loader])

# Self query retriever

## Definition

In [None]:

# LLM used by the self-query retriever to turn a user request into a structured
# query. temperature=0 keeps that translation as repeatable as possible.
# To try a stronger model, change `model` to "gpt-4".
llm_model = ChatOpenAI(
    model="gpt-3.5-turbo",
    temperature=0,
    max_tokens=1000,
)

In [None]:
# Metadata fields the self-query LLM may filter on. The name must match the
# metadata key produced by the CSV loader ("vote_average", parsed as float).
metadata_field_info = [
    AttributeInfo(
        name="vote_average",
        type="float",
        description="The average score given to the movie.",
    ),
]

# Description of what each document's page_content holds.
document_content_description = "List of movies with an overview and scoring from a public website."

# TODO: explore the `use_original_query` argument and whether the
# query-constructor prompt template can be changed.
sq_retriever = SelfQueryRetriever.from_llm(
    llm=llm_model,
    vectorstore=docsearch.vectorstore,
    document_contents=document_content_description,
    metadata_field_info=metadata_field_info,
    enable_limit=False,        # do not let the LLM add a result limit to the query
    verbose=True,
    search_kwargs={"k": 20},   # passed to the vector search: up to 20 hits
)

## Explore internal prompts and processes

In [None]:
# Inspect the examples embedded in the retriever's query-constructor prompt.
print(sq_retriever.llm_chain.prompt.examples)

In [None]:
# Variables the query-constructor prompt expects (shown as the cell output).
sq_retriever.llm_chain.prompt.input_variables

In [None]:
# Show the full prompt the query-constructor LLM would receive for a sample query.
print(
    sq_retriever.llm_chain.prompt.format_prompt(
        query="I want to watch a movie about outer space exploration."
    )
)

In [None]:
# Reproduce the LLM chain step by hand: prepare the prompt(s) and stop sequence
# for a single sample input.
prompt_eip, stop_eip = sq_retriever.llm_chain.prep_prompts(
    [{"query": "I want to watch a movie about outer space exploration."}]
)

In [None]:
# Display the prepared prompt(s) from the previous cell.
prompt_eip

In [None]:
# Send the prepared prompt straight to the chain's LLM (no callbacks attached).
resp_eip = sq_retriever.llm_chain.llm.generate_prompt(
    prompt_eip,
    stop_eip,
    callbacks=None,
)

In [None]:
# Convert the raw LLM result into the chain's output for the single input
# prepared above ([0] = that input).
sq_retriever.llm_chain.create_outputs(resp_eip)[0]

In [None]:
# End-to-end self-query retrieval for the same sample request.
sq_retriever.get_relevant_documents("I want to watch a movie about outer space exploration.")

# Retrieval QA Chain

## Definition

In [None]:
# Answer-synthesis prompt for the RetrievalQA chain. `{context}` receives the
# retrieved movie documents and `{question}` the user request. Both placeholders
# must stay in sync with `input_variables` below.
# Wording fixes vs. the previous version:
# - "movie names of the medias" -> "movie titles"
# - "Media data" -> "Movie data"
# - "attending to rules" -> "that follow the rules"
# - trailing whitespace removed
template = """Use the following movie data to find the best matches for the user request described in the question overview topic. Rules:
- You can return more than one movie if several are a good match.
- Answer with the movie titles and a short text justifying each choice.
- The justification must take into account how well the overview matches the topic and the vote average score (higher = better).

Movie data:
{context}

Question overview topic:
{question}

Example answer:
1. Title: [selected movie title]
- Justification: [Given justification]
- Score: [vote_average]

Movies that follow the rules, ordered from best to worst:
"""

alt_retrieval_prompt = PromptTemplate(template=template, input_variables=["context", "question"])

In [None]:
# Separate LLM for writing the final answer, configured independently of the
# query-construction model. To try a stronger model, change `model` to "gpt-4".
agent_llm_model = ChatOpenAI(
    model="gpt-3.5-turbo",
    temperature=0,
    max_tokens=1000,
)

In [None]:
# Question-answering chain: the self-query retriever selects candidate movies,
# and the "stuff" chain passes them to `agent_llm_model` through
# `alt_retrieval_prompt`.
growwer_media = RetrievalQA.from_chain_type(
    llm=agent_llm_model,
    chain_type="stuff",
    retriever=sq_retriever,
    chain_type_kwargs={"prompt": alt_retrieval_prompt},
)

## Testing

In [None]:
# Simple question. It combines a topic ("outer space exploration") with a
# score constraint ("rating over 5") that the self-query retriever can turn
# into a vote_average filter.
question = "A movie about outer space exploration with rating over 5"
response = growwer_media.run(question)
print(response)

In [None]:
# Re-print the last answer without running the chain (and calling the API) again.
print(response)