In [14]:
# Config
DB_CONN = "dbname=appdb user=appuser password=secret port=5432 host=rag-data"
EMB_MODEL_PATH = "/wrk/models/embedding_models/models--intfloat--multilingual-e5-large-instruct/snapshots/274baa43b0e13e37fafa6428dbc7938e62e5c439"
LLM_MODEL_PATH = "/wrk/models/llms/models--AnatoliiPotapov--T-lite-instruct-0.1/snapshots/d346cb648c2e302461edfe72528f1999d2ef88b5"
TOP_K = 5

In [15]:
# Models uploading
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, BitsAndBytesConfig
import torch

torch.cuda.empty_cache()

quantization_config = BitsAndBytesConfig(load_in_8bit=True)

emb_tokenizer = AutoTokenizer.from_pretrained(EMB_MODEL_PATH)
emb_model = AutoModel.from_pretrained(EMB_MODEL_PATH)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
emb_model = emb_model.to(device)
emb_model.eval()

llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_PATH, trust_remote_code=True)
llm_model = AutoModelForCausalLM.from_pretrained(
    LLM_MODEL_PATH,
    device_map=device,
    trust_remote_code=True,
    # torch_dtype=torch.float16,
    quantization_config=quantization_config
)

llm_model.eval()

Loading checkpoint shards: 100%|██████████| 4/4 [00:34<00:00,  8.71s/it]


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128259, 4096, padding_idx=128256)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear8bitLt(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear8bitLt(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear8bitLt(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear8bitLt(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear8bitLt(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_lay

In [16]:
# Creating embeddings function
import torch.nn.functional as F

MAX_LENGTH = 512

# Ignoring useless tokens
def average_pool(last_hidden_states, attention_mask):
    mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_states.size()).float()
    sum_embeddings = torch.sum(last_hidden_states * mask_expanded, 1)
    sum_mask = mask_expanded.sum(1).clamp(min=1e-9)
    return sum_embeddings/sum_mask

# Creating embeddings from text
def embed(text: str):
    inputs = emb_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_LENGTH
        ).to(device)
    with torch.no_grad():
        outputs = emb_model(**inputs)
        emb = average_pool(outputs.last_hidden_state, inputs['attention_mask'])
        emb = F.normalize(emb, p=2, dim=1)
    return emb[0].cpu().numpy()

In [17]:
# Searching relevant documents
import psycopg2
import json

conn = psycopg2.connect(DB_CONN)
cur = conn.cursor()

def search_context(query, top_k=TOP_K):
    query_emb = embed(query).tolist()
    cur.execute(
        """
        SELECT content, metadata FROM documents_e5 ORDER BY embedding <-> %s LIMIT %s
        """,
        (json.dumps(query_emb), top_k)
    )
    results = cur.fetchall()
    return [r[0] for r in results]

In [None]:
search_context("Вопрос")

In [None]:
# Asking LLM
def ask_llm(question, context, chat_history):
    prompt=f"""
Ты - умный ассистент, помогающий сотрудникам ответить на вопросы под документам. Используй приведённый контекст для ответа на вопросы. Отвечай строго на вопросы пользователя и не задавай новых вопросов.
Если ответ не найден в контексте - скажи, что информации нет.   

История диалога:
{chat_history}

Вопрос:
{question}

Контекст из документов:
{context}
"""
    
    inputs = llm_tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        output = llm_model.generate(
            **inputs,
            max_new_tokens=200,
            do_sample=False,
            # early_stopping=True,
            # num_beams=3
            # temperature = 0.5,
            # top_p=0.9
        )
    return llm_tokenizer.decode(output[0], skip_special_tokens=True)

# def answer_question(question: str):
#     context_chunks = search_context(question)
#     context = "\n\n".join(context_chunks)
#     return ask_llm(question, context)

In [None]:
# Chat history

def rephrase_question(question, history):
    history_text = "\n".join([f"Пользователь: {h['user']}\nАссистен: {h['assistant']}" for h in history])
    prompt = f"""
История диалога:
{history_text}

Вопрос пользователя: "{question}"

Переформулируй его так, чтобы он был самодостаточным запросом для поиска в документах.
"""
    return llm_generate(prompt)

In [None]:
# Chat

def chat(question):
    global chat_history
    
    if len(chat_history) > 0:
        new_question = rephrase_question(question, chat_history)
    else:
        new_question = question
    
    context_chunks = search_context(new_question)
    context = "\n\n".join(context_chunks)
    
    answer = ask_llm(question, context, chat_history)
    chat_history.append({'user': question, 'assistant':answer})
    
    return answer

In [None]:
q1 = "Вопрос"
print(chat(q1))

In [None]:
q2 = "Вопрос2"
print(chat(q2))