In [80]:
import numpy as np
import redis
from sentence_transformers.util import cos_sim

from rag.api.dependencies import get_embedder
from rag.caching.semantic import SemanticCache
from rag.config import settings
from rag.embeddings import create_embedder
from rag.pipeline import RAGPipeline

redis = redis.from_url(settings.redis_url)
redis

<redis.client.Redis(<redis.connection.ConnectionPool(<redis.connection.Connection(host=localhost,port=6379,db=0)>)>)>

In [29]:
redis.set('foo', 'bar')

True

In [30]:
redis.get('foo')

b'bar'

In [31]:
embedder = create_embedder(settings)

In [32]:
embedder.dimension

384

In [33]:
text = 'Heart attack'
emb = embedder.embed_text('Heart attack')
emb.shape

(384,)

In [34]:
import numpy as np

emb.tobytes()

b'\xd5\xdeu\xbb\x98G\xed\xbc\x8f\xfe\x93<\x93\x86-=\x02Y\x01\xb8bC\x88\xbc\xe5pL=ePu<\x05`D\xbb\xdf%\';6\x9b\xed\xbc\x03\x84 \xbdOg\x85\xbbd\x92\x96<yU\x8b\xbc\xf3\xecN\xbc\xf06\xfa<_\xebm<\xc7\xe4\xb1\xbd\x9b\xb0\xc8<\x8c4\xa8\xbc\xa89\xb0<s\x9a\xd3\xbc\xeaER\xbd\n\x13\x04=\x8b\xf6\x05\xbc!\xed\n\xbdo\xe1\xe3\xbc\xac\xad\xa1\xbdt\xa1\x7f\xbd\x05\xf0\xc4<\xf1\xcc\x92\xbc)o\xd4=\xfe?\x95<\x91\x0f6\xbc\x00\xac\x1b\xbdE\xc0-\xbd"\xa8\x7f=\x03/@\xbb\xcch\x1e<!\xfa\xeb<\x8c\x93}<\x16+\x90\xbb\xc3\xfd[\xbd\rq\x8b=X\xc7\x9d;O\xde\n\xbd\x05\x9e\xac\xbb\xc0\x127>\xfa\xda\x12\xbd\x02<\x87\xbd\xfb\xf3\xc3<wh`=\xd36\xb7\xbc\x05,\x88=\xf6\xb3\xc9\xbd=#\t=14\xac\xbc\x90?\x87;\xed\xa33=p?\n=-\xd9\x01=\x89m\x0b\xbe\xcf`L=\x93\x84\x0f\xbd\n\xc4H\xbb\xc2\xd0\xff;J\x00\xae<\x8f\xe6\xa5=*5\xbe<]\xdeu<\xe1\xba\xfe<\xa1_Z=\xe2\x9b&=3h9=\xa6\xb3\xe8<|EY\xbc\xa2\x97\r\xbd\xd9R>=\x84\r\x87\xbc\xf6k\x13<\xf7\x9eS\xbd\x8c?\xad\xbb<\xae\xc6\xbc\xefE\xf7\xbb\x0b\x0b\xc1\xbc\xc1\xa86<\xf3\x98\xc0\xbc\xf9"\x8b\xbcQ\

In [35]:
from random import random, uniform

key = hash(text)

redis.setex(key, value=emb.tobytes(), time=int(settings.redis_ttl * uniform(0.9, 1.1)))

True

In [36]:
np.frombuffer(redis.get(key), dtype=np.float32)

array([-3.75168514e-03, -2.89648026e-02,  1.80657189e-02,  4.23646681e-02,
       -3.08388917e-05, -1.66336931e-02,  4.99123521e-02,  1.49727808e-02,
       -2.99644587e-03,  2.55047507e-03, -2.90046744e-02, -3.91883962e-02,
       -4.07115323e-03,  1.83803514e-02, -1.70085300e-02, -1.26297353e-02,
        3.05437744e-02,  1.45214489e-02, -8.68621394e-02,  2.44982745e-02,
       -2.05328688e-02,  2.15118676e-02, -2.58304831e-02, -5.13362065e-02,
        3.22447196e-02, -8.17645621e-03, -3.39175500e-02, -2.78174561e-02,
       -7.89445341e-02, -6.24098331e-02,  2.40402315e-02, -1.79199893e-02,
        1.03727646e-01,  1.82189904e-02, -1.11121098e-02, -3.80058289e-02,
       -4.24196906e-02,  6.24162033e-02, -2.93248962e-03,  9.66854021e-03,
        2.88057942e-02,  1.54770724e-02, -4.39966749e-03, -5.37088029e-02,
        6.80867210e-02,  4.81502339e-03, -3.39034162e-02, -5.26786083e-03,
        1.78782463e-01, -3.58533636e-02, -6.60324246e-02,  2.39200499e-02,
        5.47871254e-02, -

In [37]:
from rag.caching.embedding import EmbeddingCache

emb_cache = EmbeddingCache(settings)

In [38]:
text, emb

('Heart attack',
 array([-3.75168514e-03, -2.89648026e-02,  1.80657189e-02,  4.23646681e-02,
        -3.08388917e-05, -1.66336931e-02,  4.99123521e-02,  1.49727808e-02,
        -2.99644587e-03,  2.55047507e-03, -2.90046744e-02, -3.91883962e-02,
        -4.07115323e-03,  1.83803514e-02, -1.70085300e-02, -1.26297353e-02,
         3.05437744e-02,  1.45214489e-02, -8.68621394e-02,  2.44982745e-02,
        -2.05328688e-02,  2.15118676e-02, -2.58304831e-02, -5.13362065e-02,
         3.22447196e-02, -8.17645621e-03, -3.39175500e-02, -2.78174561e-02,
        -7.89445341e-02, -6.24098331e-02,  2.40402315e-02, -1.79199893e-02,
         1.03727646e-01,  1.82189904e-02, -1.11121098e-02, -3.80058289e-02,
        -4.24196906e-02,  6.24162033e-02, -2.93248962e-03,  9.66854021e-03,
         2.88057942e-02,  1.54770724e-02, -4.39966749e-03, -5.37088029e-02,
         6.80867210e-02,  4.81502339e-03, -3.39034162e-02, -5.26786083e-03,
         1.78782463e-01, -3.58533636e-02, -6.60324246e-02,  2.39200499e

In [39]:
emb_cache.set(text, emb)

In [40]:
emb_cache.get(text).shape

(384,)

In [41]:
np.isclose(emb_cache.get(text), emb).all()

np.True_

In [42]:
text = """In 1877, Dr Heinrich Koebner inflicted an experimental trauma on the uninvolved
skin of a psoriatic patient. This resulted in the appearance of a typical
psoriatic lesion at the site of trauma. This reaction, known as Koebner's
phenomenon (KP), has subsequently been associated with several skin diseases.
However, it has not been associated previously with necrobiosis lipoidica
diabeticorum (NBL), a rare skin manifestation of diabetes mellitus. This report
presents the unusual finding of NBL associated with KP in a patient with
diabetes mellitus."""

In [43]:
embedder.embed_text(text).shape

(384,)

In [44]:
%timeit embedder.embed_text(text)

11.4 ms ± 85.6 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


# Sematic cache

In [54]:
from langchain_core.embeddings import Embeddings
from typing import List
from langchain_redis import RedisSemanticCache
from langchain.embeddings import HuggingFaceEmbeddings

# embedder = HuggingFaceEmbeddings(model_name = settings.embedding_model)

class LangChainEmbedderAdapter(Embeddings):
  """Adapter to make your BaseEmbedder work with LangChain"""

  def __init__(self, base_embedder):
      self.embedder = base_embedder

  def embed_documents(self, texts: List[str]) -> List[List[float]]:
      """Embed multiple documents"""
      embeddings = self.embedder.embed_texts(texts)  # Returns (n, dim) numpy array
      return embeddings.tolist()  # Convert to list of lists

  def embed_query(self, text: str) -> List[float]:
      """Embed a single query"""
      embedding = self.embedder.embed_text(text)  # Returns (dim,) numpy array
      return embedding.tolist()  # Convert to list

sem_cache = RedisSemanticCache(
    redis_url=settings.redis_url,
    embeddings=LangChainEmbedderAdapter(embedder),
    distance_threshold=0.2,
)

Batches: 100%|██████████| 1/1 [00:00<00:00, 165.01it/s]

13:40:47 redisvl.index.index INFO   Index already exists, not overwriting.





In [46]:
from dotenv import load_dotenv

from rag.config import settings, PROJECT_ROOT
from rag.ingestion import create_chunker
from rag.embeddings import create_embedder
from rag.retrieval import create_reranker
from rag.generation import create_llm
from rag.storage import (
    BaseDocumentStore,
    BaseVectorStore,
    Document,
    SearchResult,
    make_chunk_id,
    parse_chunk_id,
    InMemoryDocumentStore,
    FAISSVectorStore, PostgresDocumentStore, PgvectorVectorStore,
)

from datasets import load_dataset, Dataset

load_dotenv()

embedder = create_embedder(settings)
chunker = create_chunker(settings)
reranker = create_reranker(settings)
llm = create_llm(settings)

Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 297.37it/s]


In [47]:
doc_store = PostgresDocumentStore(settings)
vec_store = PgvectorVectorStore(settings)

[2025-10-25 13:37:48] [rag.storage.document_stores.postgres] [INFO] PostgresDocumentStore initialized
[2025-10-25 13:37:48] [rag.storage.document_stores.postgres] [INFO] PostgresDocumentStore initialized
[2025-10-25 13:37:48] [rag.storage.vector_stores.pgvector] [INFO] PgvectorVectorStore initialized (cosine distance)
[2025-10-25 13:37:48] [rag.storage.vector_stores.pgvector] [INFO] PgvectorVectorStore initialized (cosine distance)


In [94]:
question =  "What causes heart attacks?"
question =  "What is the reason of myocarditis?"

In [95]:
from rag.prompts import get_user_prompt, SYSTEM_PROMPT
from langchain.schema import Generation

llm_string = settings.llm_model + str(settings.llm_temperature)

# 0. Check semanti cache first

res = sem_cache.lookup(
    prompt=question,
    llm_string=llm_string,
)

if res is not None:
    answer = res[0].text
    print('Got from cache')

# 1. Embed query
q_emb = emb_cache.get(question)

# 1.1 Cache query
if q_emb is None:
    q_emb = embedder.embed_text(question)
    emb_cache.set(question, q_emb)

# 2. Vector search
search_results = vec_store.search(q_emb, settings.k)

# 3. Retrieve full documents
initial_docs = [doc_store.get_chunk(r.chunk_id) for r in search_results]

# 4. Rerank and take top_k
reranked_docs = reranker.rerank(question, initial_docs)[:settings.top_k]

# 5. Format prompt
prompt = get_user_prompt(question, reranked_docs)

# 6. Generate answer
answer = llm.generate(
    prompt,
    system_prompt=SYSTEM_PROMPT,
    max_tokens=settings.llm_max_tokens,
    temperature=settings.llm_temperature,
)

print(answer)

Batches: 100%|██████████| 1/1 [00:00<00:00, 146.55it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 150.49it/s]
Batches: 100%|██████████| 4/4 [00:00<00:00,  6.73it/s]


Myocarditis can be caused by various factors, including viral infections such as enteroviruses, adenoviruses, and herpes viruses [idx=18039618]. Additionally, it can be associated with inflammatory bowel disease (IBD) [idx=12603508], systemic lupus erythematosus (SLE) [idx=24268009], and autoimmune conditions like Behçet disease [idx=26740268].

- Enteroviruses, adenoviruses, and herpes viruses are frequent causes of myocarditis and dilated cardiomyopathy [idx=18039618].
- Myopericarditis is a rare extraintestinal complication of IBD, and recurrent episodes can be a sign of IBD [idx=12603508].
- SLE can affect the heart, leading to myocarditis among other cardiac manifestations [idx=24268009].
- Behçet disease can also lead to myocarditis, often alongside other systemic symptoms [idx=26740268].

Citations: [idx=18039618], [idx=12603508], [idx=24268009], [idx=26740268]


In [55]:
from langchain.schema import Generation
llm_string = settings.llm_model + str(settings.llm_temperature)

sem_cache.update(prompt=question, llm_string=llm_string, return_val=[Generation(text=answer)])

Batches: 100%|██████████| 1/1 [00:00<00:00, 147.79it/s]


In [56]:
question

'What causes heart attacks?'

In [62]:
res = sem_cache.lookup(prompt="What is the reason of heart attack?", llm_string=llm_string)
res

Batches: 100%|██████████| 1/1 [00:00<00:00, 153.23it/s]


[Generation(text="Heart attacks, also known as myocardial infarctions, are primarily caused by the acute loss of a large number of myocardial cells due to blockages in the coronary arteries. This blockage leads to a lack of blood flow to the heart muscle, resulting in cell death. [idx=18620057#0]\n\n- The heart's limited regenerative capacity means that when cardiomyocytes die, it triggers a reparative response that often results in scar tissue formation and ventricular dilation. [idx=18620057#0]\n\nCitations: [idx=18620057#0]")]

In [88]:
semantic_cache = SemanticCache(settings)

14:04:56 sentence_transformers.SentenceTransformer INFO   Load pretrained SentenceTransformer: BAAI/bge-small-en-v1.5


Batches: 100%|██████████| 1/1 [00:00<00:00, 161.03it/s]

14:04:58 redisvl.index.index INFO   Index already exists, not overwriting.





In [89]:
semantic_cache.get('Why heart attack happens?')

Batches: 100%|██████████| 1/1 [00:00<00:00, 160.08it/s]


"Heart attacks, also known as myocardial infarctions, are primarily caused by the acute loss of a large number of myocardial cells due to blockages in the coronary arteries. This blockage leads to a lack of blood flow to the heart muscle, resulting in cell death. [idx=18620057#0]\n\n- The heart's limited regenerative capacity means that when cardiomyocytes die, it triggers a reparative response that often results in scar tissue formation and ventricular dilation. [idx=18620057#0]\n\nCitations: [idx=18620057#0]"

In [90]:
semantic_cache.get('What is the primary reason of heart attack?')

Batches: 100%|██████████| 1/1 [00:00<00:00, 142.58it/s]


"Heart attacks, also known as myocardial infarctions, are primarily caused by the acute loss of a large number of myocardial cells due to blockages in the coronary arteries. This blockage leads to a lack of blood flow to the heart muscle, resulting in cell death. [idx=18620057#0]\n\n- The heart's limited regenerative capacity means that when cardiomyocytes die, it triggers a reparative response that often results in scar tissue formation and ventricular dilation. [idx=18620057#0]\n\nCitations: [idx=18620057#0]"

In [91]:
semantic_cache.get('What is the secondary reason of heart attack?')

Batches: 100%|██████████| 1/1 [00:00<00:00, 144.40it/s]


"Heart attacks, also known as myocardial infarctions, are primarily caused by the acute loss of a large number of myocardial cells due to blockages in the coronary arteries. This blockage leads to a lack of blood flow to the heart muscle, resulting in cell death. [idx=18620057#0]\n\n- The heart's limited regenerative capacity means that when cardiomyocytes die, it triggers a reparative response that often results in scar tissue formation and ventricular dilation. [idx=18620057#0]\n\nCitations: [idx=18620057#0]"

In [92]:
semantic_cache.get('What is the reason of myocarditis?')

Batches: 100%|██████████| 1/1 [00:00<00:00, 159.33it/s]


In [85]:
def cosine_similarity(a: str, b: str) -> float:
    emb1 = embedder.embed_text(a)
    emb2 = embedder.embed_text(b)
    return cos_sim(emb1, emb2).flatten()

cosine_similarity('What is the reason of myocarditis?', "What causes heart attacks?")

Batches: 100%|██████████| 1/1 [00:00<00:00, 150.85it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 184.71it/s]


tensor([0.7841])

In [86]:
cosine_similarity('What is the reason of myocarditis?', "What causes heart failures?")

Batches: 100%|██████████| 1/1 [00:00<00:00, 166.71it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 209.53it/s]


tensor([0.7626])

In [87]:
cosine_similarity("What causes heart attacks?", "What causes heart failures?")

Batches: 100%|██████████| 1/1 [00:00<00:00, 172.31it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 198.07it/s]


tensor([0.8505])