# Load query and Json

In [1]:
import json

#read in file
with open("../../data/raw/queries.txt", "r") as file:
    queries = file.readlines()

with open("../../data/raw/answers.txt", "r") as file:
    answers = file.readlines()
    
with open("../../data/processed/guideline_db_with_table.json") as f:
    db_raw = json.load(f)
    
from pydantic import BaseModel
from typing import Optional, List, Union

class Metadata(BaseModel):
    section: str
    type: str
    chunk_index: Optional[int]= None
    headings: str
    referee_id: Optional[str] = None
    referenced_tables: Optional[List[str]] = None

class Chunk(BaseModel):
    text: str
    metadata: Metadata


db = [Chunk(**chunk) for chunk in db_raw]


# Make embeddings

In [20]:
import torch
torch.mps.empty_cache()
# del embedder
import gc
gc.collect()


3887

In [None]:

from sentence_transformers import SentenceTransformer
#Make embeddings
embedder_name = "ba"
embedder = SentenceTransformer("BAAI/bge-large-en-v1.5")

texts = [chunk.text for chunk in db]
embeddings = embedder.encode(texts, convert_to_numpy=True)



In [8]:
#if MPS OOM, run this instead
embeddings = embedder.encode(
    texts,
    convert_to_numpy=True,
    device='cpu',
    batch_size=4
)

In [9]:
import numpy as np
np.save(f"{embedder_name.replace("/", "_")}.npy", embeddings)

# Build faiss index and implement Search using faiss

In [None]:
# # ✅ Caching mechanism
# if os.path.exists("vectors.npy") and os.path.exists("vectors.md5"):
#     with open("vectors.md5", "r") as f:
#         saved_md5 = f.read().strip()
#     if saved_md5 == json_md5:
#         print("✅ Loaded cached vectors.")
#         vectors = np.load("vectors.npy")
#     else:
#         print("🔄 Manual updated. Recomputing vectors...")
#         vectors = embed_chunks(chunks)
#         np.save("vectors.npy", vectors)
#         with open("vectors.md5", "w") as f:
#             f.write(json_md5)
# else:
#     print("🔄 No cached vectors found. Computing now...")
#     vectors = embed_chunks(chunks)
#     np.save("vectors.npy", vectors)
#     with open("vectors.md5", "w") as f:
#         f.write(json_md5)

In [10]:
#build faiss index
import faiss
index = faiss.IndexFlatL2(embeddings.shape[1])  # L2 distance
index.add(embeddings)  # Add embeddings to the index

In [11]:
from together import Together
llm_client = Together(api_key='4f6e44b7689d6592b2b5b57ad3940ac9f488d14c22802e8bcdf641b06e98cbbe')
#4f6e44b7689d6592b2b5b57ad3940ac9f488d14c22802e8bcdf641b06e98cbbe

def faiss_search(query, k=3):
    query_embedding = embedder.encode([query], convert_to_numpy=True)
    distances, indices = index.search(query_embedding, k)
    results = []
    referenced_tables = set()
    existed_tables = set()
    for i in range(k):
        if indices[0][i] != -1:  # Check if the index is valid
            results.append({
                "text": db[indices[0][i]].text,
                "section": db[indices[0][i]].metadata.section,
            })
        # if this chunk has a referee_id, it is a table already, we don't need to add it again later
        if db[indices[0][i]].metadata.referee_id:
            existed_tables.add(db[indices[0][i]].metadata.referee_id)
        if db[indices[0][i]].metadata.referenced_tables:
            referenced_tables.update(db[indices[0][i]].metadata.referenced_tables)
        
        #perform .lower().replace(" ", "_").replace(".", "_") to all the table in the referenced_tables
    table_to_add = {table.lower().replace(" ", "_").replace(".", "_") for table in referenced_tables if table not in existed_tables}
    
    # add the referenced tables in the db to the results if their referee_id is in table_to_add
    i = 0
    for chunk in db:
        if chunk.metadata.referee_id in table_to_add:
            results.append({
                "text": chunk.text,
                "section": chunk.metadata.section,
            })
            i += 1
        if i == len(table_to_add):
            break
    return results

In [12]:
def call_llm(prompt):

    response = llm_client.chat.completions.create(
        model="meta-llama/Llama-3.3-70B-Instruct-Turbo-Free", #don't change the model!
        messages=[
            {
                "role": "user",
                "content": prompt
            }
        ],
        max_tokens=500,
        temperature=0.05
    )
    return response.choices[0].message.content

In [17]:
embedder

SentenceTransformer(
  (transformer): Transformer(
    (auto_model): XLMRobertaLoRA(
      (roberta): XLMRobertaModel(
        (embeddings): XLMRobertaEmbeddings(
          (word_embeddings): ParametrizedEmbedding(
            250002, 1024, padding_idx=1
            (parametrizations): ModuleDict(
              (weight): ParametrizationList(
                (0): LoRAParametrization()
              )
            )
          )
          (token_type_embeddings): ParametrizedEmbedding(
            1, 1024
            (parametrizations): ModuleDict(
              (weight): ParametrizationList(
                (0): LoRAParametrization()
              )
            )
          )
        )
        (emb_drop): Dropout(p=0.1, inplace=False)
        (emb_ln): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (encoder): XLMRobertaEncoder(
          (layers): ModuleList(
            (0-23): 24 x Block(
              (mixer): MHA(
                (rotary_emb): RotaryEmbedding()
        

In [18]:
def construct_prompt(query, faiss_results):
    system_prompt = (
        "Your name is Depression Assistant, a helpful and friendly recipe assistant. "
        "Summarize the clinical guidelines provided in the context and then tried to answer the user query. "
        "If the query or guideline provided is not related to depression, please say 'I am not sure about that'. Don't make up things. "
        
    )

    prompt = f"""
### System Prompt
{system_prompt}

### User Query
{query}

### Clinical Guidelines Context
    """
    for result in faiss_results:
        prompt += f"- reference: {result['section']}\n- This paragraph is from section{result['text']}\n"
        
    return prompt

import time

def depression_assistant(query):
    t1 = time.perf_counter()
    
    results = faiss_search(query)
    t2 = time.perf_counter()
    print(f"[Time] FAISS search done in {t2 - t1:.2f} seconds.")

    prompt = construct_prompt(query, results)
    t3 = time.perf_counter()
    print(f"[Time] Prompt construction took {t3 - t2:.2f} seconds.")

    response = call_llm(prompt)
    t4 = time.perf_counter()
    print(f"[Time] LLM response took {t4 - t3:.2f} seconds.")

    print(f"[Total time] {t4 - t1:.2f} seconds for this query.\n\n")
    return response



In [19]:
with open(f"{embedder_name.replace("/", "_")}_llama3.3_70B.md", "w") as f:
    for i, query in enumerate(queries):
        response = depression_assistant(query)
        # write the response to a md file
        f.write(f"## Query {i+1}\n")
        f.write(f"{query.strip()}\n\n")
        f.write("#### Answer\n")
        f.write(f"{answers[i].strip()}\n\n")
        f.write(f"#### {embedder_name} Embedder and LLama3.3 70B Response\n")
        f.write(response.strip())
        f.write("\n\n---\n\n")
    
    

[Time] FAISS search done in 0.30 seconds.
[Time] Prompt construction took 0.00 seconds.
[Time] LLM response took 2.11 seconds.
[Total time] 2.41 seconds for this query.


[Time] FAISS search done in 0.26 seconds.
[Time] Prompt construction took 0.00 seconds.
[Time] LLM response took 2.19 seconds.
[Total time] 2.45 seconds for this query.


[Time] FAISS search done in 0.21 seconds.
[Time] Prompt construction took 0.00 seconds.
[Time] LLM response took 15.41 seconds.
[Total time] 15.62 seconds for this query.


[Time] FAISS search done in 0.18 seconds.
[Time] Prompt construction took 0.00 seconds.
[Time] LLM response took 3.10 seconds.
[Total time] 3.28 seconds for this query.


[Time] FAISS search done in 0.19 seconds.
[Time] Prompt construction took 0.00 seconds.
[Time] LLM response took 15.57 seconds.
[Total time] 15.75 seconds for this query.


[Time] FAISS search done in 0.31 seconds.
[Time] Prompt construction took 0.00 seconds.
[Time] LLM response took 2.18 seconds.
[Total time] 2