In [None]:
import os
from dotenv import load_dotenv
import json
import asyncpg
from datetime import datetime
from typing import Optional, Annotated, Sequence, Any, Dict, List
from redis.asyncio import Redis
from sqlalchemy import ARRAY, JSON, Integer, String, ForeignKey, Text, func, text
from sqlalchemy.orm import DeclarativeBase, declared_attr, Mapped, mapped_column
from sqlalchemy.ext.asyncio import AsyncAttrs, async_sessionmaker, create_async_engine
from pgvector.sqlalchemy import Vector
import tiktoken
from langchain_openai import ChatOpenAI
from langchain_gigachat.chat_models import GigaChat
from langchain_gigachat.embeddings import GigaChatEmbeddings

In [None]:
load_dotenv()

POSTGRES_URL = os.environ.get("POSTGRES_URL")
GIGACHAT_API_KEY = os.environ.get("GIGACHAT_API_KEY")

In [117]:
embedder = GigaChatEmbeddings(
    credentials=GIGACHAT_API_KEY,
    verify_ssl_certs=False
)

In [118]:
engine = None
async_session_maker = None


class Base(AsyncAttrs, DeclarativeBase):
    __abstract__ = True
    __table_args__ = {'schema': 'ksu_test'}

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    created_at: Mapped[datetime] = mapped_column(server_default=func.now())
    updated_at: Mapped[datetime] = mapped_column(server_default=func.now(), onupdate=func.now())

    @declared_attr.directive
    def __tablename__(cls) -> str:
        return cls.__name__.lower() + 's'


def connection(method):
    async def wrapper(*args, **kwargs):
        async with async_session_maker() as session:
            try:
                async with session.begin():
                    return await method(*args, session=session, **kwargs)
            except Exception as e:
                await session.rollback()
                raise
            finally:
                await session.close()

    return wrapper


async def init_db():
    global engine, async_session_maker
    engine = create_async_engine(url=POSTGRES_URL)

    async_session_maker = async_sessionmaker(engine, expire_on_commit=False)

async def create_tables():
    async with engine.begin() as conn:
        await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
        await conn.execute(text("CREATE SCHEMA IF NOT EXISTS ksu_test"))
        await conn.run_sync(Base.metadata.create_all)
async def create_indexes():
    async with engine.begin() as conn:
        # –°–æ–∑–¥–∞–Ω–∏–µ –∏–Ω–¥–µ–∫—Å–∞ –¥–ª—è –≤–µ–∫—Ç–æ—Ä–∞ —Å –∏—Å–ø–æ–ª—å–∑–æ–≤–∞–Ω–∏–µ–º ivfflat
        await conn.execute(text("""
            CREATE INDEX IF NOT EXISTS idx_document_embedding
            ON ksu_test.documents
            USING ivfflat (embedding vector_cosine_ops);
        """))

In [119]:
class Document(Base):
    message_number: Mapped[int] = mapped_column(Integer, nullable=False)
    content: Mapped[str] = mapped_column(Text, nullable=False)
    category: Mapped[str] = mapped_column(String(50), nullable=False)
    subcategory: Mapped[str] = mapped_column(String(50), nullable=False)
    meta: Mapped[Dict[str, Any]] = mapped_column("metadata", JSON, nullable=True)
    embedding: Mapped[List[float]] = mapped_column(Vector(1024), nullable=False)

In [120]:
def chunk_text_with_overlap(
        text: str,
        max_tokens: int = 500,
        overlap_tokens: int = 100) -> list[str]:
    '''
    –†–∞–∑–±–∏–≤–∞–µ—Ç —Ç–µ–∫—Å—Ç –Ω–∞ —á–∞–Ω–∫–∏ —Å –ø–µ—Ä–µ–∫—Ä—ã—Ç–∏–µ–º –ø–æ —Ç–æ–∫–µ–Ω–∞–º
    '''
    enc = tiktoken.get_encoding("cl100k_base")
    tokens = enc.encode(text)
    chunks = []
    start = 0

    while start < len(tokens):
        end = min(start + max_tokens, len(tokens))
        chunk_tokens = tokens[start:end]
        chunk_text = enc.decode(chunk_tokens)
        chunks.append(chunk_text)
        start += max_tokens - overlap_tokens

    return chunks

In [121]:
@connection
async def generate_embeddings(corpus_path: str = "corpus.json", session=None):
    corpus = json.load(open(corpus_path, "r", encoding="utf-8"))
    docs = []

    for data in corpus:
        chunks = chunk_text_with_overlap(
            data["content"],
            max_tokens=500,
            overlap_tokens=100
        )
        print(data.get("metadata", {}))
        for idx, chunk in enumerate(chunks):
            embedding_vector = embedder.embed_query(chunk)

            db_document = Document(
                message_number=data.get("message_number", -1),
                content=chunk,
                category=data.get("category", ""),
                subcategory=data.get("subcategory", ""),
                meta=json.dumps(data.get("metadata", {})),
                embedding=embedding_vector
            )
            docs.append(db_document)
            
            # –ë–∞—Ç—á–∏–Ω–≥ –∫–∞–∂–¥—ã–µ 100 –¥–æ–∫—É–º–µ–Ω—Ç–æ–≤
            # if len(docs) >= 100:
            #     session.add_all(docs)
            #     await session.flush()  # –ù–µ –∫–æ–º–º–∏—Ç–∏–º, —Ç–æ–ª—å–∫–æ flush
            #     docs.clear()

    session.add_all(docs)
    await session.commit()

In [122]:
await init_db()

In [123]:
await create_tables()

In [124]:
await create_indexes()

In [125]:
await generate_embeddings()

{'tags': ['–æ—à–∏–±–∫–∞', '–Ω–µ–ø–æ–Ω–∏–º–∞–Ω–∏–µ', '—Å–∏—Å—Ç–µ–º–Ω–æ–µ'], 'updated_at': '2025-07-15T10:16:00+03:00', 'buttons': [{'title': 'üè† –í –Ω–∞—á–∞–ª–æ', 'go_to': 11}], 'transitions': {'default': 11}}
{'tags': ['–ø—Ä–∏–≤–µ—Ç—Å—Ç–≤–∏–µ', '–≥–ª–∞–≤–Ω–æ–µ –º–µ–Ω—é'], 'updated_at': '2025-07-15T10:00:00+03:00', 'buttons': [{'title': '–û —á—ë–º —Ç—ã —Ö–æ—á–µ—à—å —É–∑–Ω–∞—Ç—å?', 'go_to': 11}, {'title': '–ö–∞–∫ –ø–æ–¥–≥–æ—Ç–æ–≤–∏—Ç—å—Å—è –∫ –ø–æ—Å—Ç—É–ø–ª–µ–Ω–∏—é', 'go_to': 21}, {'title': '–û –ø–æ—Å—Ç—É–ø–ª–µ–Ω–∏–∏', 'go_to': 2}], 'transitions': {'default': 11}}
{'tags': ['–ø–æ—Å—Ç—É–ø–ª–µ–Ω–∏–µ', '–∫–æ–Ω—Ç–∞–∫—Ç—ã'], 'updated_at': '2025-07-15T10:00:00+03:00', 'buttons': [{'title': '–ö–∞–∫–æ–π –ø—Ä–æ—Ö–æ–¥–Ω–æ–π –±–∞–ª–ª?', 'go_to': 5}, {'title': '–°–∫–æ–ª—å–∫–æ –º–µ—Å—Ç –¥–ª—è –ø—Ä–∏–µ–º–∞ –Ω–∞ 2024-2025 —É—á–µ–±–Ω—ã–π –≥–æ–¥?', 'go_to': 31}], 'transitions': {}}
{'tags': ['–ø–æ—Å—Ç—É–ø–ª–µ–Ω–∏–µ', '–ø—Ä–æ—Ö–æ–¥–Ω—ã–µ –±–∞–ª–ª—ã'], 'updated_at': '2025-07-15T10:04:00+03:00', 'buttons'

In [126]:
from sqlalchemy import select


@connection
async def generate(session=None):
    res = await session.execute(select(Document))
    return res.scalars().all()

res = await generate()

# print(len(res))
# for r in res:
#     print(r.meta)

In [127]:
@connection
async def search_index(embedding: List[float], top_k: int = 3, session=None) -> List[Dict[str, Any]]:
    embedding_str = '[' + ','.join(map(str, embedding)) + ']'
    query = text(f"""
        SELECT message_number, content, category, subcategory, metadata, embedding,
               1.0 / ((embedding <=> '{embedding_str}'::vector) + 1e-6) AS score
        FROM ksu_test.documents
        ORDER BY embedding <=> '{embedding_str}'::vector
        LIMIT {top_k};
    """)
    result = await session.execute(query)
    return result.mappings().all()

@connection
async def search_content(content, top_k=3, session=None):
    query = text("""
        SELECT message_number, content, category, subcategory, metadata,
               ts_rank_cd(to_tsvector('russian', content), plainto_tsquery('russian', :content)) AS rank
        FROM ksu_test.documents
        WHERE to_tsvector('russian', content) @@ plainto_tsquery('russian', :content)
        ORDER BY rank DESC
        LIMIT :top_k;
    """)
    result = await session.execute(query, {"content": content, "top_k": top_k})
    return result.mappings().all()


In [128]:
content = "–ü—Ä–∏–≤–µ—Ç, –Ω–∞–ø–∏—à–∏ –ø—Ä–æ –±—É—Ñ–µ—Ç"
embedding = embedder.embed_query(content)

res_content = await search_content(content)
for r in res_content:
    print(r)

res = await search_index(embedding)
for r in res:
    print(r)

{'message_number': 4, 'content': '–ò–Ω—Ñ–æ—Ä–º–∞—Ü–∏—è –æ —Ç—Ä—É–¥–æ—É—Å—Ç—Ä–æ–π—Å—Ç–≤–µ:', 'category': '–¢—Ä—É–¥–æ—É—Å—Ç—Ä–æ–π—Å—Ç–≤–æ', 'subcategory': '–û—Å–Ω–æ–≤–Ω–∞—è –∏–Ω—Ñ–æ—Ä–º–∞—Ü–∏—è', 'metadata': '{"tags": ["\\u0442\\u0440\\u0443\\u0434\\u043e\\u0443\\u0441\\u0442\\u0440\\u043e\\u0439\\u0441\\u0442\\u0432\\u043e", "\\u043c\\u0435\\u043d\\u044e"], "updated_at": "2025-07-15T10:06:00+03:00", "buttons": [{"title": "\\u0413\\u0434\\u0435 \\u043c\\u043e\\u0433\\u0443\\u0442 \\u0440\\u0430\\u0431\\u043e\\u0442\\u0430\\u0442\\u044c \\u0432\\u0430\\u0448\\u0438 \\u0432\\u044b\\u043f\\u0443\\u0441\\u043a\\u043d\\u0438\\u043a\\u0438?", "go_to": 19}, {"title": "\\u041a\\u0435\\u043c \\u043c\\u043e\\u0433\\u0443\\u0442 \\u0440\\u0430\\u0431\\u043e\\u0442\\u0430\\u0442\\u044c \\u0432\\u0430\\u0448\\u0438 \\u0432\\u044b\\u043f\\u0443\\u0441\\u043a\\u043d\\u0438\\u043a\\u0438?", "go_to": 35}, {"title": "\\u0420\\u0430\\u0441\\u0441\\u043a\\u0430\\u0436\\u0438 \\u043f\\u0440\\u043e \\u0443\\u