In [None]:
# Version tag written into every stored embedding hash; KNN queries only match
# hashes whose `version` equals the current value, so bumping it hides old vectors.
VERSION: int = 1

In [None]:
from pathlib import Path
import json
import requests
import redis
from redis.commands.search.field import TextField, VectorField
from redis.exceptions import ResponseError
import numpy as np
from redis.commands.search.query import Query
from langchain.embeddings import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.schema import (
    AIMessage,
    HumanMessage,
    SystemMessage
)
from textwrap import dedent

In [None]:
# Load settings (e.g. `openai_api_key`) from dev-config.json one directory above the
# current working directory. JSON is UTF-8 by spec, so don't rely on the locale encoding.
with open(Path().absolute() / '..' / 'dev-config.json', 'r', encoding='utf-8') as f:
    CONFIG = json.load(f)

In [None]:
class Embeddings:
    """Store OpenAI text embeddings in Redis and search them with a KNN vector query.

    Every added text is written to a Redis hash ``text:<text>`` with the fields
    ``embedding`` (raw bytes of a float64 numpy vector), ``text`` and ``version``.
    A RediSearch index named ``index_name`` over ``embedding`` (HNSW, cosine) and
    ``version`` is (re)created on construction.

    Args:
        index_name: Name of the RediSearch index to (re)create.
        config: Mapping with at least ``openai_api_key``. Optional keys
            ``redis_host`` / ``redis_port`` / ``redis_db`` override the default
            local Redis connection (``localhost:6379``, db 0).
    """

    # Vector length declared in the index schema. NOTE(review): assumed to match the
    # output size of the OpenAIEmbeddings model in use (1536 for text-embedding-ada-002) — confirm.
    EMBEDDING_DIM = 1536

    def __init__(self, index_name, *, config):
        self._config = config
        self._index_name = index_name

        self._open_ai_emb = OpenAIEmbeddings(openai_api_key=self._config['openai_api_key'])
        self._redis = redis.Redis(
            host=self._config.get('redis_host', 'localhost'),
            port=self._config.get('redis_port', 6379),
            db=self._config.get('redis_db', 0),
        )

        self._init_redis()

    @staticmethod
    def _key(text: str) -> str:
        """Return the Redis hash key under which ``text`` and its embedding are stored."""
        return f"text:{text}"

    def _init_redis(self):
        """Drop the search index if it exists and create it with the current schema.

        Only the index is dropped: ``dropindex()`` is called without
        ``delete_documents=True``, so previously stored hashes survive.
        """
        index = self._redis.ft(self._index_name)
        try:
            index.dropindex()
        except ResponseError:
            # Best effort: on a fresh Redis the index does not exist yet. Other
            # failures (e.g. no search module) would also make create_index raise.
            pass

        emb_field = VectorField(
            name='embedding',
            algorithm='HNSW',
            attributes=dict(
                type='FLOAT64',  # must match the dtype produced by _get_embedding
                dim=self.EMBEDDING_DIM,
                distance_metric='COSINE',
            ),
        )
        version_field = TextField('version')

        index.create_index([emb_field, version_field])

    def _get_embedding(self, text: str) -> np.ndarray:
        """Embed ``text`` with OpenAI and return it as a float64 numpy vector."""
        return np.array(self._open_ai_emb.embed_query(text), dtype=np.float64)

    def _save_embedding(self, text: str, embedding: np.ndarray) -> None:
        """Write ``text``, its embedding bytes and the current VERSION to Redis."""
        self._redis.hset(self._key(text), mapping=dict(
            embedding=embedding.tobytes(),
            text=text,
            version=VERSION,
        ))

    def add(self, text: str) -> bytes:
        """Ensure ``text`` has an embedding stored for the current VERSION.

        Returns:
            The raw float64 bytes of the embedding. They are read back from Redis
            when a hash tagged with the current VERSION already exists. Otherwise
            they are freshly computed (one OpenAI call) and saved.
        """
        stored_version, stored_embedding = self._redis.hmget(
            self._key(text), ['version', 'embedding']
        )
        # The client returns bytes (no decode_responses) while VERSION is an int,
        # so compare against its encoded textual form.
        if stored_embedding is not None and stored_version == str(VERSION).encode():
            return stored_embedding

        embedding = self._get_embedding(text)
        self._save_embedding(text, embedding)
        return embedding.tobytes()

    def knn(self, text: str, *, k: int = 100):
        """Find the ``k`` stored texts nearest to ``text`` among current-VERSION hashes.

        Returns the redis-py search result. Its ``.docs`` carry ``text`` and
        ``__embedding_score`` (cosine distance, smaller is closer), nearest first.
        """
        q = (
            Query(f"(@version:{VERSION})=>[KNN {k} @embedding $e]")
            .return_field('text')
            .return_field('__embedding_score')
            # Query returns only 10 results unless paged explicitly, and hybrid KNN
            # results are not guaranteed to be distance-ordered without SORTBY.
            .sort_by('__embedding_score')
            .paging(0, k)
            .dialect(2)
        )
        return self._redis.ft(self._index_name).search(
            q, query_params={"e": self._get_embedding(text).tobytes()}
        )

    def close(self):
        """Close the Redis client connection."""
        self._redis.close()

In [None]:
# Seed the knowledge base with facts about the user (in Russian):
# "The user's name is Vadim", "The user's wife is named Katya",
# "The user and his wife have the same surname", "The user's surname is Pushtaev",
# "The user is married". A loop keeps the cell from echoing raw embedding bytes.
e = Embeddings('examplegpt', config=CONFIG)
for fact in (
    'Пользователя зовут Вадим',
    'Жену пользователя зовут Катя',
    'У пользователя и у жены одинаковые фамилии',
    'Фамилия пользователя — Пуштаев',
    'Пользователь женат',
):
    e.add(fact)

In [None]:
class Chat:
    """Chat with gpt-3.5-turbo, letting the model request facts from an embeddings store.

    The model is told to reply ``MORE_INFO_NEEDED: QUERY`` when it needs context.
    The query is run through ``embeddings.knn``; the nearest texts are appended
    to the history as facts, and the original question is asked again.
    """

    MORE = 'MORE_INFO_NEEDED'
    # Maximum number of knowledge-DB lookups per question; past it the reply is returned as-is.
    MAX_DEPTH = 3
    # Number of nearest stored texts appended as facts per lookup.
    FACTS_PER_QUERY = 2

    def __init__(self, *, config, embeddings):
        self._config = config
        self._embeddings = embeddings

        self._chat = ChatOpenAI(model_name="gpt-3.5-turbo", openai_api_key=config['openai_api_key'], max_tokens=1200, model_kwargs={'temperature': 0.3})
        self._prompt = self._build_prompt(self.MORE)
        self._history = []

    @staticmethod
    def _build_prompt(more: str) -> str:
        """Build the instruction prefix that is prepended to every user message.

        dedent must run before strip: stripping first removes the first line's
        indentation, leaves no common margin, and the lines keep their code indent.
        """
        return dedent(
            f"""
            You are a helpful assistant.
            If you need more details or context ask me using the following format: `{more}: QUERY`, where QUERY is a query to knowledge DB based on embeddings.
            I'll execute the query and give you some facts to work with.
            Never use `{more}` if it's not a query.
            Never use the same query twice, always come up with something new. Remeber that we can do it more than once to find an answer.
            ---
            """
        ).strip() + '\n'

    @classmethod
    def _extract_query(cls, response_text: str):
        """Return the query from a ``MORE_INFO_NEEDED: QUERY`` reply, or None.

        None is returned when the reply is not such a request or the query is empty.
        Surrounding backticks (as shown in the prompt) and a missing space after
        the colon are tolerated.
        """
        candidate = response_text.strip().strip('`').strip()
        if not candidate.startswith(cls.MORE):
            return None
        query = candidate[len(cls.MORE):].lstrip(':').strip().strip('`').strip()
        return query or None

    def __call__(self, text) -> "AIMessage":
        """Shortcut for ``say(text)``."""
        return self.say(text)

    def say(self, text, depth=0) -> "AIMessage":
        """Send ``text`` (with the instruction prefix) and return the model's final reply.

        If the reply is a knowledge-DB request and fewer than MAX_DEPTH lookups were
        made, the nearest facts are appended to the history and the same question is
        asked again (recursively). The full dialog is printed once an answer is returned.
        """
        self._history.append(HumanMessage(content=self._prompt + text))
        response = self._chat(self._history)
        self._history.append(response)

        query = self._extract_query(response.content)
        if query is not None and depth < self.MAX_DEPTH:
            for doc in self._embeddings.knn(query, k=self.FACTS_PER_QUERY).docs:
                self._history.append(HumanMessage(content=f'Please consider the following to be a fact: {doc.text}'))
            return self.say(text, depth=depth + 1)

        self._print_dialog()

        return response

    def _print_dialog(self):
        """Print every history message (its type, then its content) between `---` rulers."""
        print('---')
        for message in self._history:
            print(type(message))
            print(message.content)
            print('')
        print('---')
        print('')

In [None]:
# Connect the chat model to the embeddings store populated in the previous cells.
chat = Chat(embeddings=e, config=CONFIG)

In [None]:
# "What is the name of Vadim Pushtaev's wife?" — the answer is spread across several stored facts.
question = 'Как зовут жену Вадима Пуштаева?'
chat(question)