In [1]:
# Config
DB_CONN = "dbname=appdb user=appuser password=secret port=5432 host=rag-data"
EMB_MODEL_PATH = "/wrk/models/embedding_models/models--intfloat--multilingual-e5-large-instruct/snapshots/274baa43b0e13e37fafa6428dbc7938e62e5c439"
LLM_MODEL_PATH = "/wrk/models/llms/models--ai-forever--ruGPT-3.5-13B/snapshots/64b115374b8f086ef13ccb8ba4f49d8076a53324"
TOP_K = 5

In [2]:
# Models uploading
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, BitsAndBytesConfig
import torch

torch.cuda.empty_cache()

quantization_config = BitsAndBytesConfig(load_in_8bit=True)

emb_tokenizer = AutoTokenizer.from_pretrained(EMB_MODEL_PATH)
emb_model = AutoModel.from_pretrained(EMB_MODEL_PATH)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
emb_model = emb_model.to(device)
emb_model.eval()

llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_PATH, trust_remote_code=True)
llm_model = AutoModelForCausalLM.from_pretrained(
    LLM_MODEL_PATH,
    device_map="cuda",
    trust_remote_code=True,
    # torch_dtype=torch.float16,
    quantization_config=quantization_config
)

llm_model.eval()

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 6/6 [01:52<00:00, 18.70s/it]


GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50272, 5120)
    (wpe): Embedding(2048, 5120)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-39): 40 x GPT2Block(
        (ln_1): LayerNorm((5120,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Linear8bitLt(in_features=5120, out_features=15360, bias=True)
          (c_proj): Linear8bitLt(in_features=5120, out_features=5120, bias=True)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((5120,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Linear8bitLt(in_features=5120, out_features=20480, bias=True)
          (c_proj): Linear8bitLt(in_features=20480, out_features=5120, bias=True)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((5120,), eps

In [3]:
# Creating embeddings function
import torch.nn.functional as F

MAX_LENGTH = 512

# Ignoring useless tokens
def average_pool(last_hidden_states, attention_mask):
    mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_states.size()).float()
    sum_embeddings = torch.sum(last_hidden_states * mask_expanded, 1)
    sum_mask = mask_expanded.sum(1).clamp(min=1e-9)
    return sum_embeddings/sum_mask

# Creating embeddings from text
def embed(text: str):
    inputs = emb_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_LENGTH
        ).to(device)
    with torch.no_grad():
        outputs = emb_model(**inputs)
        emb = average_pool(outputs.last_hidden_state, inputs['attention_mask'])
        emb = F.normalize(emb, p=2, dim=1)
    return emb[0].cpu().numpy()

In [4]:
# Searching relevant documents
import psycopg2
import json

conn = psycopg2.connect(DB_CONN)
cur = conn.cursor()

def search_context(query, top_k=TOP_K):
    query_emb = embed(query).tolist()
    cur.execute(
        """
        SELECT content, metadata FROM documents_e5 ORDER BY embedding <-> %s LIMIT %s
        """,
        (json.dumps(query_emb), top_k)
    )
    results = cur.fetchall()
    return [r[0] for r in results]

In [None]:
search_context("Вопрос")

In [6]:
# Asking LLM
def ask_llm(question, context):
    prompt = f"""Ты - умный ассистент, помогающий сотрудникам ответить на вопросы. Используй приведённый контекст для ответа на вопросы. Отвечай строго на вопросы пользователя и не задавай новых вопросов.
    
    
    Контекст:
    {context}
    
    
    Вопрос: {question}
    Ответ:"""
    
    inputs = llm_tokenizer(prompt, return_tensors="pt", max_length=1536).to(llm_model.device)
    with torch.no_grad():
        output = llm_model.generate(
            **inputs,
            # max_new_tokens=200,
            max_length=2048,
            do_sample=False,
            early_stopping=True,
            # temperature = 0.5,
            # top_p=0.9
        )
    return llm_tokenizer.decode(output[0], skip_special_tokens=True)

def answer_question(question: str):
    context_chunks = search_context(question)
    context = "\n\n".join(context_chunks)
    return ask_llm(question, context)

In [7]:
# import os
# os.environ['CUDA_LAUNCH_BLOCKING']='1'
# os.environ['TORCH_USE_CUDA_DSA']='1'

In [None]:
q = "Вопрос"
print(answer_question(q))