In [None]:
!pip install sentence_transformers

In [None]:
from sentence_transformers import SentenceTransformer
import chromadb


class EmbeddingFunction:
    """Callable that turns text into embeddings with a SentenceTransformer model.

    Instances are passed to ChromaDB as ``embedding_function`` and are also called
    directly by ``DB.query`` to embed query strings.

    Args:
        model_name: SentenceTransformer model id to load. Defaults to the multilingual
            paraphrase MPNet model, so existing ``EmbeddingFunction()`` calls behave
            exactly as before; pass another id to compare embedding models.
    """

    def __init__(self, model_name='sentence-transformers/paraphrase-multilingual-mpnet-base-v2'):
        self.model = SentenceTransformer(model_name)

    def __call__(self, input):
        """Embed ``input`` (a string or a list of strings) and return plain Python lists.

        ``model.encode`` returns a numpy array; ``.tolist()`` converts it to a list of
        floats for a single string, or a list of float lists for a list of strings.
        """
        # NOTE(review): the parameter is named `input` (shadowing the builtin),
        # presumably to match ChromaDB's embedding-function protocol — verify before renaming.
        return self.model.encode(input).tolist()


class DB:
    """Persistent ChromaDB collection of text fragments for a single distance metric.

    Each instance opens a ``chromadb.PersistentClient`` at ``root_path`` and works on
    the collection ``"lab5_<distance_function>"``, whose ``hnsw:space`` is the chosen
    metric. Because the collection name includes the metric, instances with different
    metrics can point at the same ``root_path`` without colliding.
    """

    # Metrics accepted for ``distance_function`` (passed through as ``hnsw:space``).
    _DISTANCE_FUNCTIONS = ("l2", "ip", "cosine")

    def __init__(self, distance_function, root_path, embedding_function=None):
        """Open (or create) the collection for ``distance_function`` under ``root_path``.

        Args:
            distance_function: one of ``"l2"``, ``"ip"``, ``"cosine"``.
            root_path: directory of the persistent ChromaDB store.
            embedding_function: optional callable mapping text(s) to embedding list(s).
                When omitted, a new ``EmbeddingFunction`` is created. Pass one shared
                instance to avoid loading the embedding model once per ``DB``.

        Raises:
            AssertionError: if ``distance_function`` is not supported.
        """
        # Validate first so a bad metric fails before the model is loaded or the
        # persistent store is opened.
        assert distance_function in self._DISTANCE_FUNCTIONS, "Distance function should be 'l2' or 'ip' or 'cosine'"
        self.distance_function = distance_function
        self.ef = EmbeddingFunction() if embedding_function is None else embedding_function
        self.client = chromadb.PersistentClient(path=root_path)
        self.collection = self._open_collection()

    def _collection_name(self):
        """Name of the collection backing this metric."""
        return "lab5_" + self.distance_function

    def _open_collection(self):
        """Get or create this metric's collection (shared by ``__init__`` and ``clear``)."""
        return self.client.get_or_create_collection(self._collection_name(),
                                                    metadata={"hnsw:space": self.distance_function},
                                                    embedding_function=self.ef)

    def add(self, items, batch_size=1000):
        """Add fragments to the collection in batches of at most ``batch_size``.

        Args:
            items: dict with equal-length sequences under ``"fragments"`` (document
                texts), ``"metadata"`` (one dict per fragment) and ``"ids"``.
            batch_size: maximum number of fragments per ``collection.add`` call.

        Raises:
            ValueError: if the three sequences differ in length or ``batch_size < 1``.
        """
        fragments = items["fragments"]
        metadatas = items["metadata"]
        ids = items["ids"]
        if not (len(fragments) == len(metadatas) == len(ids)):
            raise ValueError(
                f"items must have equal lengths, got {len(fragments)} fragments, "
                f"{len(metadatas)} metadata entries and {len(ids)} ids")
        if batch_size < 1:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        # Stepping with range() only produces non-empty slices, so an input that is
        # empty or an exact multiple of batch_size never triggers an add() with empty
        # lists (the previous trailing add() always ran, even with nothing left).
        for start in range(0, len(ids), batch_size):
            stop = start + batch_size
            self.collection.add(documents=fragments[start:stop],
                                metadatas=metadatas[start:stop],
                                ids=ids[start:stop])

    def query(self, query, n_results):
        """Return the ``n_results`` nearest fragments to ``query``.

        The query is embedded with ``self.ef`` and sent as ``query_embeddings``; the
        ChromaDB query result is returned unchanged.
        """
        return self.collection.query(query_embeddings=self.ef(query), n_results=n_results)

    def clear(self):
        """Drop this metric's collection and recreate it empty."""
        self.client.delete_collection(self._collection_name())
        self.collection = self._open_collection()

In [None]:
import re
import os

def split_to_sent(text):
    """Split ``text`` into sentences; the boundary characters themselves are dropped.

    A boundary is either:
      * whitespace after '.', '?' or '!' that is followed by a capital letter, unless
        it follows an abbreviation like "e.g.", a single-letter initial like " J." or a
        two-letter title like "Mr."; or
      * a newline followed by a capital letter or '" ', unless it follows ',', '-' or ':'.

    Returns:
        list[str]: the text pieces between boundaries (``[""]`` for empty input).
    """
    boundary = re.compile(
        r"(((?<!\w\.\w.)(?<!\s\w\.)(?<![A-Z][a-z]\.)(?<=\.|\?|\!)\s(?=[A-Z]))|((?<![\,\-\:])\n(?=[A-Z]|\" )))")
    pieces = boundary.split(text)
    # The pattern has three capturing groups, so re.split emits three group values
    # after each text piece; a stride of 4 keeps only the text pieces.
    return [pieces[i] for i in range(0, len(pieces), 4)]

In [None]:
def split_document(lines, fragment_limit=100):
    """Group the sentences of a text into fragments of roughly ``fragment_limit`` words.

    Sentences (from ``split_to_sent``) are appended to the current fragment, each
    followed by one space, and words are counted by splitting on single spaces. Once
    the running count exceeds ``fragment_limit``, the fragment is closed and a new one
    starts. Sentences are never cut, so a fragment may exceed the limit.

    Args:
        lines: the document text (a single string, despite the name).
        fragment_limit: word count that, once exceeded, closes the current fragment.

    Returns:
        list[str]: fragments in document order. Sentences left over after the last
        full fragment form a final, shorter fragment; whitespace-only input yields [].
    """
    result = []
    fragment = ""
    length = 0
    for sentence in split_to_sent(lines):
        fragment += sentence + " "
        length += len(sentence.split(" "))
        if length > fragment_limit:
            result.append(fragment)
            fragment = ""
            length = 0
    # Flush the tail. Without this, text after the last full fragment was silently
    # dropped, and documents shorter than fragment_limit produced no fragments at all.
    if fragment.strip():
        result.append(fragment)
    return result

In [None]:
import pandas as pd

def split_dataset(dataset_path, fragment_limit=100, filename="train.csv"):
    """Read a labelled CSV and split each row's text into retrieval fragments.

    The file is read with explicit column names ``label, Title, Description`` (so every
    line, including the first, is treated as data). Each row's text is
    ``Title + ". " + Description``, and missing titles or descriptions count as empty strings.

    Args:
        dataset_path: directory containing the CSV file.
        fragment_limit: word threshold forwarded to ``split_document``.
        filename: name of the CSV inside ``dataset_path``; defaults to ``"train.csv"``.

    Returns:
        tuple: ``(fragments, ids, metadata)``, three position-aligned lists.
        - Each id has the form ``"<row index>_<label>_<fragment number>"``.
        - Each metadata entry is ``{"document": <row index>, "topic": <label>}``.
        - Newlines inside fragments are replaced by spaces.
    """
    filepath = os.path.join(dataset_path, filename)
    df = pd.read_csv(filepath, names=['label', 'Title', 'Description'])
    # A NaN on either side turns the concatenation into NaN, which then makes the
    # regex sentence splitter raise, so fill gaps before joining.
    texts = df['Title'].fillna('').astype(str) + '. ' + df['Description'].fillna('').astype(str)

    result_fragments = []
    metadata = []
    result_ids = []
    # .tolist() yields plain Python scalars for the id strings and metadata values.
    for index, label, text in zip(df.index.tolist(), df['label'].tolist(), texts.tolist()):
        for counter, fragment in enumerate(split_document(text, fragment_limit)):
            result_fragments.append(fragment.replace("\n", " "))
            metadata.append({"document": index, "topic": label})
            result_ids.append(f"{index}_{label}_{counter}")
    return result_fragments, result_ids, metadata

In [None]:
# Absolute path of the raw dataset directory, resolved against the notebook's current
# working directory; split_dataset reads train.csv from here.
data_dir = os.path.realpath("./dataset/raw")

In [None]:
# Split every row into sentence-aligned fragments; a fragment is closed once its word
# count exceeds 20. The three returned lists are aligned by position.
fragments, ids, metadata = split_dataset(data_dir, fragment_limit=20)

In [None]:
# One DB per distance metric. All three share the same persistent store but each uses
# its own collection ("lab5_<metric>").
# NOTE(review): every DB(...) call constructs a new EmbeddingFunction, so the
# SentenceTransformer model is loaded three times. The store path is also hard-coded
# instead of being derived from data_dir.
database_l2 = DB("l2", "./dataset/raw/DB")
database_ip = DB("ip", "./dataset/raw/DB")
database_cosine = DB("cosine", "./dataset/raw/DB")

In [None]:
# clear() drops and recreates the collection, so re-running this cell rebuilds it from
# an empty collection instead of re-adding the same ids to existing data.
database_l2.clear()
database_l2.add({"fragments": fragments, "metadata": metadata, "ids": ids})

In [None]:
# Same rebuild for the inner-product collection.
database_ip.clear()
database_ip.add({"fragments": fragments, "metadata": metadata, "ids": ids})

In [None]:
# Same rebuild for the cosine collection.
database_cosine.clear()
database_cosine.add({"fragments": fragments, "metadata": metadata, "ids": ids})

In [None]:
# Retrieve the top-5 fragments for the same English question under each metric.
database_l2.query("What Iraq problem is?", 5)

In [None]:
# NOTE(review): EmbeddingFunction does not request normalized embeddings, so "ip"
# rankings may differ from "cosine" wherever embedding norms vary.
database_ip.query("What Iraq problem is?", 5)

In [None]:
database_cosine.query("What Iraq problem is?", 5)

In [None]:
# The same question in Russian, exercising the multilingual embedding model.
database_l2.query("В чем проблема Ирака?", 5)

In [None]:
database_ip.query("В чем проблема Ирака?", 5)

In [None]:
database_cosine.query("В чем проблема Ирака?", 5)

In [None]:
!pip install ctransformers

In [None]:
!pip install tiktoken
!pip install openai
!pip install cohere
!pip install kaleido

In [None]:
!pip install gradio==3.48.0

In [None]:
from ctransformers import AutoModelForCausalLM
import gradio as gr

In [None]:
# Load the Q4_K_M-quantized GGUF file of Mistral-7B-OpenOrca (repo
# TheBloke/Mistral-7B-OpenOrca-GGUF) with ctransformers. gpu_layers=50 asks for up to
# 50 layers to run on the GPU (presumably needs a GPU-enabled ctransformers build; verify).
llm = AutoModelForCausalLM.from_pretrained("TheBloke/Mistral-7B-OpenOrca-GGUF", model_file="mistral-7b-openorca.Q4_K_M.gguf", model_type="mistral", gpu_layers=50)


In [None]:
# NOTE(review): leftover inspection cell. It only displays the DB object's repr and
# nothing depends on it, so it can be removed.
database_l2

In [None]:
def process_request(message, history):
    """Gradio chat callback: answer ``message`` with the LLM using retrieved context.

    The five nearest fragments from ``database_l2`` are joined with newlines and placed
    in a fixed instruction prompt together with the question. The reply is the full
    prompt, a blank line, then the model output, so the user also sees the retrieved
    context.

    ``history`` is part of the ``gr.ChatInterface`` callback signature and is unused.
    """
    hits = database_l2.query(message, 5)
    context = "\n".join(hits["documents"][0])
    prompt = "".join([
        "Answer the question using provided context. Your answer should be in your own words and be no longer than 50 words.\n",
        "Context: ", context, "\n",
        "Question: ", message, "\n",
        "Answer: ",
    ])
    completion = llm(prompt)
    return f"{prompt}\n\n{completion}"

In [None]:
# Launch a Gradio chat UI whose replies are produced by process_request(message, history).
demo = gr.ChatInterface(fn=process_request, title="Chat bot")
demo.launch()

# Выводы

Было проведено сравнение результатов с моделью, имеющей примерно в 2 раза меньше параметров. С результатами сравнения и используемыми запросами можно ознакомиться в таблице: <https://docs.google.com/spreadsheets/d/1hbwkQYDN5jxTs6qmNrz8bebgqDX_l5LidE6tW_GrX_Y/edit?usp=sharing>.
По результатам можно сказать, что в некоторых вопросах проблема заключается не в обобщающей способности модели, а в нехватке информации в исходном корпусе.