In [2]:
VERSION = 1

In [3]:
from pathlib import Path
import json
import requests
import redis
from redis.commands.search.field import TextField, VectorField
from redis.exceptions import ResponseError
import numpy as np
from redis.commands.search.query import Query
from langchain.embeddings import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.schema import (
    AIMessage,
    HumanMessage,
    SystemMessage
)
from textwrap import dedent

In [4]:
with open(Path().absolute() / '..' / 'dev-config.json', 'r') as f:
    CONFIG = json.load(f)

In [5]:
class Embeddings:
    def __init__(self, index_name, *, config):
        self._config = config
        self._index_name = index_name
        
        self._open_ai_emb = OpenAIEmbeddings(openai_api_key=self._config['openai_api_key'])
        self._redis = redis.Redis(host='localhost', port=6379, db=0)
        
        self._init_redis()
        
    def _init_redis(self):
        try:
            self._redis.ft(self._index_name).dropindex()
        except ResponseError:
            pass
        
        emb_field = VectorField(
            name='embedding',
            algorithm='HNSW',
            attributes=dict(
                type='FLOAT64',
                dim=1536,
                distance_metric='COSINE',
            ),
        )
        version_field = TextField('version')
        
        self._redis.ft(self._index_name).create_index([
            emb_field,
            version_field,
        ])


    def _get_embedding(self, text: str) -> np.array:
        embedding = self._open_ai_emb.embed_query(text)
        
        return np.array(embedding)
    
    def _save_embedding(self, text: str, embedding: np.array) -> bytes:
        self._redis.hset(f"text:{text}", mapping = dict(
            embedding=embedding.tobytes(),
            text=text,
            version=VERSION,
        ))
                
    def add(self, text: str) -> np.array:
        found = self._redis.hget(f'text:{text}', 'version')
        if found and found == VERSION:
            return found
        
        embedding = self._get_embedding(text)
        self._save_embedding(text, embedding)

        return embedding.tobytes()

    def knn(self, text: str, *, k: int = 100) -> list:
        q = (
            Query(f"(@version:{VERSION})=>[KNN {k} @embedding $e]")
            .return_field('text')
            .return_field('__embedding_score')
            .dialect(2)
        )
        result = self._redis.ft(self._index_name).search(q, query_params={"e": self._get_embedding(text).tobytes()})
        
        return result
    
    def close(self):
        self._s.close()

In [7]:
e = Embeddings('examplegpt', config=CONFIG)
e.add('Пользователя зовут Вадим')
e.add('Жену пользователя зовут Катя')
e.add('У пользователя и у жены одинаковые фамилии')
e.add('Фамилия пользователя — Пуштаев')
e.add('Пользователь женат')

b'\x00\x00\x00 \x14/\xa4\xbf\x00\x00\x00\xe0\xc1\x0en\xbf\x00\x00\x00 \xf2\xbb\x86\xbf\x00\x00\x00\xe0\xaa\xb3\x91\xbf\x00\x00\x00`Z1\xa2\xbf\x00\x00\x00\x80b\xd4V\xbf\x00\x00\x00\x00\x99\xcf\x91\xbf\x00\x00\x00\xa0\xb8\xdd\x97\xbf\x00\x00\x00 \x17\x12X\xbf\x00\x00\x00\xe0\x12\xa1\xa2\xbf\x00\x00\x00`\xe7tt?\x00\x00\x00@=\xf6j?\x00\x00\x00\xa0\xdd3\x89?\x00\x00\x00 \xa4Uj\xbf\x00\x00\x00\xe0\xa5\xaa\\?\x00\x00\x00\xe0\xb3\xa5q?\x00\x00\x00`z~~?\x00\x00\x00\xc0\x16\xd3A\xbf\x00\x00\x00\xe0\xebm\x94?\x00\x00\x00`U(\xad\xbf\x00\x00\x00\x80\xd0\xd0c\xbf\x00\x00\x00`\x9e\xce\x88\xbf\x00\x00\x00\x00\xd7\xa6\x90\xbf\x00\x00\x00@\xbb\x81\x85\xbf\x00\x00\x00`o\xf8\x85\xbf\x00\x00\x00\x80;\xa1\x88?\x00\x00\x00@-g}?\x00\x00\x00\xa0(\xb7\xa1\xbf\x00\x00\x00@\x81\xdcu?\x00\x00\x00\x80\xd3\xb3w\xbf\x00\x00\x00\x00!\n\x9f?\x00\x00\x00\x00\xf1-\x95\xbf\x00\x00\x00\xe0\x07\xd2\x95\xbf\x00\x00\x00@\xeb\xa6\x93\xbf\x00\x00\x00\xe0\xbbH\xa6\xbf\x00\x00\x00\x00!\n\x8f\xbf\x00\x00\x00\xa0\xd7m\x81?\x00\x00\

In [99]:
class Chat:
    MORE = 'MORE_INFO_NEEDED'
    
    def __init__(self, *, config, embeddings):
        self._config = config
        self._embeddings = embeddings
        
        self._chat = ChatOpenAI(model_name="gpt-3.5-turbo", openai_api_key=config['openai_api_key'], max_tokens=1200, model_kwargs={'temperature': 0.3})
        self._system_prompt = dedent(
            f"""
            I am a helpful assistant with access to memory.
            Whenever I cannot answer a question, I can access my memory by doing the folowing:
                * I write `{self.MORE}`: QUERY, where QUERY is what I feel missing from my knowledge.
            I never use the same QUERY twice.
            Never use the same query twice, always come up with something new.
            Remeber that we can do it more than once to find an answer.
            """.strip() + '\n'
        )
        self.reset()
        
    def _ask_chat(self, lst: list) -> str:
        print('---')
        for x in lst:
            print(f'[{type(x).__name__}]')
            print(x.content)
        print('---')
        print('')
        
        return self._chat(lst)
        
    def __call__(self, text) -> str:
        return self.say(text)
    
    def _mores_message(self, mores: list) -> AIMessage:
        if mores:
            return AIMessage(content='There are the things I ask previously:\n{}'.format('\n'.join(
                f'  * {self.MORE}: {m}' for m in mores
            )))
        else:
            return None
    
    def _facts_message(self, facts: list) -> AIMessage:
        if facts:
            return AIMessage(content='This is what I remembered:\n{}'.format('\n'.join(
                f'  * {f}' for f in facts
            )))
        else:
            return None
    
    def reset(self) -> None:
        self._history = []
    
    def say(self, text, facts: list = None, mores: list = None, depth=0) -> str:
        if facts is None:
            facts = []
        if mores is None:
            mores = []
        
        to_send = [
            AIMessage(content=self._system_prompt),
            HumanMessage(content=text),
        ]
        for m in (self._mores_message(mores), self._facts_message(facts)):
            if m is not None:
                to_send.append(m)
        response = self._ask_chat(to_send)
        response_text = response.content
        
        if response_text.startswith(self.MORE) and depth < 2:
            more = response_text[len(self.MORE) + 2:]
            print(f'[*] New MORE! {more}')
            if more not in mores:
                mores.append(more)
                new_facts = False
                for e in self._embeddings.knn(more, k=2).docs:
                    if e.text not in facts:
                        facts.append(e.text)
                        new_facts = True
                        print(f'[*] New fact! {e.text}')
                    if new_facts:
                        return self.say(text, depth=depth+1, facts=facts, mores=mores)
                    else:
                        print('[*] No new facts')
            else:
                print('[*] Duplicate request')

        return response


In [100]:
chat = Chat(config=CONFIG, embeddings=e)

In [101]:
chat.reset(); chat('Как зовут жену Вадима Пуштаева?')

---
[AIMessage]
I am a helpful assistant with access to memory.
            Whenever I cannot answer a question, I can access my memory by doing the folowing:
                * I write `MORE_INFO_NEEDED`: QUERY, where QUERY is what I feel missing from my knowledge.
            I never use the same QUERY twice.
            Never use the same query twice, always come up with something new.
            Remeber that we can do it more than once to find an answer.

[HumanMessage]
Как зовут жену Вадима Пуштаева?
---

[*] New MORE! Где можно найти информацию о жене Вадима Пуштаева? (например, в какой сфере он работает или где они живут)
[*] New fact! Пользователя зовут Вадим
---
[AIMessage]
I am a helpful assistant with access to memory.
            Whenever I cannot answer a question, I can access my memory by doing the folowing:
                * I write `MORE_INFO_NEEDED`: QUERY, where QUERY is what I feel missing from my knowledge.
            I never use the same QUERY twice.
            

AIMessage(content="I apologize for the confusion in my previous response. I do not have any information about Vadim Pushtaev's wife as it is not publicly available. Is there anything else I can help you with?", additional_kwargs={}, example=False)