In [5]:
# Config
import os

# Connection string and model locations may be overridden through environment
# variables (RAG_DB_CONN, RAG_EMB_MODEL_PATH, RAG_LLM_MODEL_PATH) so that
# credentials and machine-specific paths do not have to live in the notebook.
# The fallbacks below are the previous hard-coded values and keep
# "Restart & Run All" working unchanged.
# NOTE(review): the fallback DSN still embeds a password; prefer setting
# RAG_DB_CONN and removing the literal before sharing this notebook.
DB_CONN = os.environ.get(
    "RAG_DB_CONN",
    "dbname=appdb user=appuser password=secret port=5432 host=rag-data",
)
EMB_MODEL_PATH = os.environ.get(
    "RAG_EMB_MODEL_PATH",
    "/wrk/models/embedding_models/models--ai-forever--ruBert-large/snapshots/efdc76b4678bc5c9a51642a4a5364371a89cea96",
)
LLM_MODEL_PATH = os.environ.get(
    "RAG_LLM_MODEL_PATH",
    "/wrk/models/llms/models--RefalMachine--RuadaptQwen2.5-7B-Lite-Beta-GGUF/snapshots/68ae9dff37a839f3441b9383519cffc4f7d829dd/FP16.gguf",
)
# Default number of documents returned by search_context().
TOP_K = 7

In [4]:
# Loading models
from transformers import AutoTokenizer, AutoModel
import torch

# Use the GPU when CUDA is available, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the embedding tokenizer and encoder from EMB_MODEL_PATH and move the
# encoder to the selected device.
emb_tokenizer = AutoTokenizer.from_pretrained(EMB_MODEL_PATH)
emb_model = AutoModel.from_pretrained(EMB_MODEL_PATH).to(device)
# Inference mode (disables dropout); as the final statement of the cell this
# also displays the module structure.
emb_model.eval()

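# NOTE: loading of the LLM is currently disabled, so `llm_tokenizer` and `llm_model`
# are not defined anywhere in this notebook, although `ask_llm()` below needs both.
# TODO(review): LLM_MODEL_PATH points to a .gguf file and `ask_llm()` calls
# `.generate()`; confirm that `AutoModel` is the right loader before re-enabling
# (a causal-LM class and GGUF-aware loading are probably required).
# llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_PATH)
# llm_model = AutoModel.from_pretrained(
#     LLM_MODEL_PATH,
#     dtype=torch.float16,
#     device_map="cuda"
# )
# llm_model.eval()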

  from .autonotebook import tqdm as notebook_tqdm


BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(120138, 1024, padding_idx=0)
    (position_embeddings): Embedding(512, 1024)
    (token_type_embeddings): Embedding(2, 1024)
    (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-23): 24 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=1024, out_features=1024, bias=True)
            (key): Linear(in_features=1024, out_features=1024, bias=True)
            (value): Linear(in_features=1024, out_features=1024, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=1024, out_features=1024, bias=True)
            (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1,

In [16]:
# Creating embeddings function
import torch.nn.functional as F

# Maximum number of tokens passed to the embedding model; embed() asks the
# tokenizer to truncate longer inputs to this length.
MAX_LENGTH = 512

# Mean pooling over real tokens only (positions masked out by attention_mask are ignored)
def average_pool(last_hidden_states, attention_mask):
    """Average token embeddings along dimension 1, weighted by the attention mask.

    Positions whose mask value is 0 contribute nothing to the average.

    Args:
        last_hidden_states: Token embeddings; expected to be shaped
            (batch, seq_len, hidden), as produced by the encoder in embed().
        attention_mask: Mask with one value per token; expected to be shaped
            (batch, seq_len).

    Returns:
        Tensor with dimension 1 reduced away (one pooled vector per sequence).
    """
    # Broadcast the mask across the hidden dimension so it can weight every feature.
    weights = attention_mask.unsqueeze(-1).expand_as(last_hidden_states).float()
    # Clamp the token count so a fully masked row divides by 1e-9 instead of 0.
    token_counts = weights.sum(1).clamp(min=1e-9)
    weighted_sum = (last_hidden_states * weights).sum(1)
    return weighted_sum / token_counts

# Creating embeddings from text
def embed(text: str):
    """Encode ``text`` into one L2-normalised embedding vector.

    The text is tokenised (truncated to ``MAX_LENGTH`` tokens), passed through
    ``emb_model`` without gradient tracking, mean-pooled with ``average_pool``
    and scaled to unit length.

    Args:
        text: Text to embed.

    Returns:
        NumPy array holding the embedding of the first sequence in the batch.
    """
    encoded = emb_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_LENGTH,
    )
    # Tokenizer output must live on the same device as the model.
    encoded = encoded.to(device)
    with torch.no_grad():
        hidden = emb_model(**encoded).last_hidden_state
        pooled = average_pool(hidden, encoded['attention_mask'])
        unit = F.normalize(pooled, p=2, dim=1)
    first_vector = unit[0]
    return first_vector.cpu().numpy()

In [17]:
# Searching relevant documents
import psycopg2
import json

# Open one PostgreSQL connection for the notebook session using the DB_CONN
# connection string, plus a cursor on that connection.
# NOTE(review): neither is closed anywhere in the visible cells.
conn = psycopg2.connect(DB_CONN)
cur = conn.cursor()

def search_context(query, top_k=TOP_K):
    """Return the contents of the stored documents closest to ``query``.

    The query is embedded with ``embed`` and compared against the ``embedding``
    column of ``documents_ruBert`` using the ``<->`` distance operator.

    Args:
        query: Text to search for.
        top_k: Maximum number of documents to return (defaults to ``TOP_K``).

    Returns:
        List with the ``content`` value of each matching row, nearest first.

    Raises:
        psycopg2.Error: If the query fails; the transaction is rolled back
            first so the connection stays usable for later calls.
    """
    query_emb = embed(query).tolist()
    # `with conn` ends the transaction when the block exits (commit on success,
    # rollback on error) and the cursor context closes the cursor. A shared
    # long-lived cursor would leave the connection stuck in an aborted
    # transaction after any failed query.
    with conn, conn.cursor() as cursor:
        cursor.execute(
            """
        SELECT content, metadata FROM documents_ruBert ORDER BY embedding <-> %s LIMIT %s
        """,
            # The vector is sent as its JSON text form, e.g. "[0.1, 0.2, ...]".
            (json.dumps(query_emb), top_k),
        )
        rows = cursor.fetchall()
    # Only the document text is needed downstream; metadata is dropped here.
    return [row[0] for row in rows]

In [None]:
# Smoke test of the retrieval step with a placeholder query.
search_context("Вопрос")

In [32]:
# Asking LLM
def ask_llm(question, context):
    """Generate an answer to ``question`` from the LLM, grounded in ``context``.

    Relies on the module-level ``llm_tokenizer`` and ``llm_model``; both must be
    defined before this function is called (their loading code in this notebook
    is currently commented out).

    Args:
        question: The user's question.
        context: Retrieved text that is inserted into the prompt.

    Returns:
        The generated answer as a string, without the prompt text.
    """
    prompt = f"""Ты - умный ассистент, помогающий сотрудникам ответить на вопросы. Используй приведённый контекст для ответа на вопросы.
    
    
    Контекст:
    {context}
    
    
    Вопрос: {question}
    Ответ:"""

    inputs = llm_tokenizer(prompt, return_tensors="pt").to(llm_model.device)
    with torch.no_grad():
        output = llm_model.generate(
            **inputs,
            max_new_tokens=300,  # was misspelled `max_new_tokes`
            temperature=0.7,
            do_sample=True,
            top_p=0.9,
        )
    # generate() returns prompt + continuation for decoder-only models, so
    # drop the prompt tokens and decode only what the model produced.
    prompt_length = inputs["input_ids"].shape[-1]
    answer_ids = output[0][prompt_length:]
    # The original decoded the text but never returned it.
    return llm_tokenizer.decode(answer_ids, skip_special_tokens=True).strip()

def answer_question(question: str):
    """Answer ``question`` with retrieval-augmented generation.

    Retrieves context chunks via ``search_context``, joins them with blank
    lines and forwards the question plus the joined context to ``ask_llm``.

    Args:
        question: The user's question.

    Returns:
        Whatever ``ask_llm`` returns for the question and joined context.
    """
    retrieved = search_context(question)
    return ask_llm(question, "\n\n".join(retrieved))

In [None]:
# End-to-end run with a placeholder question: retrieve context, then ask the LLM.
q = "Вопрос"
answer_question(q)