In [10]:
import os
from typing import List
from decouple import config, AutoConfig
from mistralai import Mistral
from langchain_mistralai.chat_models import ChatMistralAI
from langchain_core.embeddings import Embeddings
from langchain_community.vectorstores import UpstashVectorStore
# Directory handed to AutoConfig as the place to look for the settings file.
# It can be overridden with the CHATBOT_CONFIG_DIR environment variable so the
# notebook is not tied to one machine; the original absolute path is kept as
# the fallback, so behaviour is unchanged when the variable is unset.
# NOTE: this rebinds `config`, replacing the `config` object imported from
# decouple above.
config = AutoConfig(
    search_path=os.environ.get("CHATBOT_CONFIG_DIR", "/home/harry/Chatbot")
)

In [11]:
# Credentials for the Mistral API, read through the decouple `config` object.
MISTRAL_API_KEY = config("MISTRAL_API_KEY")

# The Hugging Face token is also exported into the process environment
# (presumably for libraries that read HF_TOKEN from there -- confirm).
HF_TOKEN = config("HF_TOKEN")
os.environ.update(HF_TOKEN=HF_TOKEN)

# REST endpoint and token of the Upstash Vector index.
UPSTASH_VECTOR_REST_URL = config("UPSTASH_VECTOR_REST_URL")
UPSTASH_VECTOR_REST_TOKEN = config("UPSTASH_VECTOR_REST_TOKEN")

In [15]:
class MistralEmbeddings(Embeddings):
    """LangChain ``Embeddings`` implementation backed by the Mistral client.

    Args:
        model: Name of the embedding model passed to the Mistral API.
    """

    def __init__(self, model: str = "mistral-embed"):
        self.model = model
        # The client authenticates with the notebook-level MISTRAL_API_KEY.
        self.client = Mistral(api_key=MISTRAL_API_KEY)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of texts with a single API request.

        Args:
            texts: Texts to embed. Newlines are replaced by spaces before
                the texts are sent.

        Returns:
            One embedding per entry of ``response.data``; an empty list when
            ``texts`` is empty.
        """
        # Nothing to embed: skip the network round-trip instead of sending an
        # empty ``inputs`` list to the API.
        if not texts:
            return []

        cleaned_texts = [t.replace("\n", " ") for t in texts]
        response = self.client.embeddings.create(
            model=self.model,
            inputs=cleaned_texts,
        )
        return [e.embedding for e in response.data]

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string.

        Delegates to ``embed_documents`` so queries and documents share the
        same cleaning and request logic.

        Args:
            text: Query text to embed.

        Returns:
            The embedding of ``text``.
        """
        return self.embed_documents([text])[0]

In [17]:
# Embedding backend handed to the vector store.
embeddings = MistralEmbeddings()

# Vector store client pointed at the Upstash index configured above.
store = UpstashVectorStore(
    index_url=UPSTASH_VECTOR_REST_URL,
    index_token=UPSTASH_VECTOR_REST_TOKEN,
    embedding=embeddings,
)

In [18]:
# Retriever over the vector store: similarity search with k=2 passed through
# as a search keyword argument.
retriever = store.as_retriever(search_type="similarity", search_kwargs=dict(k=2))

In [19]:
from langchain_mistralai.chat_models import ChatMistralAI
from mistralai import Mistral

In [20]:
# Keyword arguments used to construct the chat model; only the API key is set.
LLM_CONFIG = dict(api_key=MISTRAL_API_KEY)

In [21]:
# Chat model built from LLM_CONFIG. No model name is passed, so whatever
# default ChatMistralAI applies is used.
model = ChatMistralAI(**LLM_CONFIG)

In [22]:
from langchain_core.prompts import ChatPromptTemplate

# Template for the single human message sent to the model. `{question}` and
# `{context}` are filled in when the prompt is invoked.
message = """
Answer this question using the provided context only.

{question}

Context:
{context}
"""

# A chat prompt consisting of exactly one human message built from `message`.
prompt = ChatPromptTemplate.from_template(message)

In [23]:
from langchain_core.runnables import RunnableParallel, RunnablePassthrough

def _scale_num(payload):
    """Return the ``"num"`` entry of ``payload`` multiplied by 123121."""
    return payload["num"] * 123121


# Standalone demonstration of RunnableParallel: the same input is passed
# through unchanged under "passed" and transformed under "modified".
# `runnable` is not referenced by the other cells visible here.
runnable = RunnableParallel(
    passed=RunnablePassthrough(),
    modified=_scale_num,
)

runnable.invoke({"num": 31})

{'passed': {'num': 31}, 'modified': 3816751}

In [24]:
from langchain_core.output_parsers import StrOutputParser
# Output parser used as the last step of `chain`.
parser = StrOutputParser()

In [25]:
def _format_docs(docs):
    """Join the text of the retrieved documents into one context string.

    Args:
        docs: Documents returned by the retriever; each exposes
            ``page_content``.

    Returns:
        The documents' ``page_content`` values separated by blank lines.
    """
    return "\n\n".join(doc.page_content for doc in docs)


# The input question feeds both branches: it is sent to the retriever, whose
# documents are flattened to plain text for `{context}`, and passed through
# unchanged for `{question}`. Without the formatting step the prompt would
# receive the string form of the Document list, metadata included.
chain = (
    {"context": retriever | _format_docs, "question": RunnablePassthrough()}
    | prompt
    | model
    | parser
)

In [33]:
# Run the chain once on a sample question; the result is the cell's output.
chain.invoke("what is shiraz famous for?")

'Shiraz is famous for being one of the top tourist cities in Iran, and it is also known as the city of poets and literature.'