In [1]:
import asyncio
import os
import time
import uuid
from typing import Dict, List

In [2]:
!pip install redis
!pip install transformers

[0m

In [3]:
import numpy as np
import redis.asyncio as redis
import torch
import torch.nn.functional as F
from redis.commands.search.field import TextField, VectorField
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
from redis.commands.search.query import Query
from torch import Tensor
from transformers import AutoModel, AutoTokenizer, pipeline

In [4]:
# Redis connection settings. Values come from the environment so that no
# credential is stored in the notebook; non-secret values keep their previous
# defaults. NOTE: the password that used to be hard-coded here appeared in
# plain text and should be treated as leaked and rotated.
REDIS_DB = int(os.environ.get("REDIS_DB", "0"))
REDIS_HOST = os.environ.get("REDIS_HOST", "94537961.xyz")
REDIS_PORT = int(os.environ.get("REDIS_PORT", "6379"))
# Unset -> None, in which case no password is sent to the server.
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD")

In [5]:
# Hugging Face model used for all embeddings, and the device it runs on.
EMBEDDING_MODEL_NAME = "andersonbcdefg/bge-small-4096"
DEVICE = "cuda"

# Expected embedding width (presumably the model's hidden size — confirm).
VECTOR_DIMENSION = 384

# Window size used by prepare_text: 4096 tokens minus a 16-token safety margin.
TOKENS_LIMIT = 4096 - 16

# RediSearch index name used by get_crawled_url.
INDEX_NAME = "idx:pages_vss"

In [6]:
# Load the embedding model once. The tokenizer and the fp16 model are built
# explicitly and then handed to the feature-extraction pipeline.
tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL_NAME, truncation=True)

model = AutoModel.from_pretrained(EMBEDDING_MODEL_NAME)
model = model.half()
model = model.to(DEVICE)

# NOTE(review): the model instance is already fp16 and on DEVICE, so
# torch_dtype here is presumably redundant — confirm before removing.
pipe = pipeline(
    "feature-extraction",
    tokenizer=tokenizer,
    model=model,
    device=DEVICE,
    torch_dtype=torch.float16,
)

In [7]:
def get_redis_client():
    """Create a new asyncio Redis client for the configured server.

    Connection settings come from the module-level REDIS_HOST, REDIS_PORT,
    REDIS_DB and REDIS_PASSWORD. A fresh client is built on every call, so
    the caller owns it and should ``aclose()`` it when finished.
    """
    connection_kwargs = {
        "host": REDIS_HOST,
        "port": REDIS_PORT,
        "db": REDIS_DB,
        "password": REDIS_PASSWORD,
    }
    return redis.Redis(**connection_kwargs)

In [8]:
def merge_embeddings(embeddings):
    """Combine a stack of embeddings into a single vector.

    Each row (normalization is along dim 1) is scaled to unit L2 norm, then
    the rows are averaged along dim 0. The mean of unit vectors is generally
    not itself unit-length.
    """
    unit_rows = F.normalize(embeddings, p=2, dim=1)
    return unit_rows.mean(dim=0)


def average_pool(states: Tensor) -> Tensor:
    """Mean of ``states`` over its first axis (e.g. token vectors -> one vector)."""
    return torch.mean(states, dim=0)


def prepare_text(text: str):
    """Split ``text`` into pieces of at most TOKENS_LIMIT tokens each.

    The whole text is tokenized once without truncation. The token ids are
    cut into consecutive windows of TOKENS_LIMIT, and each window is decoded
    back to a string (special tokens dropped) so it can be fed to ``pipe``.

    NOTE(review): re-tokenizing a decoded window is not guaranteed to give the
    same token count; the safety margin in TOKENS_LIMIT presumably exists to
    absorb that — confirm.

    Returns:
        A list of strings in original order.
    """
    input_ids = tokenizer(text, padding=False, truncation=False)["input_ids"]
    windows = (
        input_ids[offset : offset + TOKENS_LIMIT]
        for offset in range(0, len(input_ids), TOKENS_LIMIT)
    )
    return [
        tokenizer.decode(
            window,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        for window in windows
    ]


async def put_crawled_url(url: str, embedings):
    """Write ``url`` and its serialized embedding to the Redis hash ``pages:<url>``.

    Opens a dedicated client for this single write and closes it afterwards.
    The client is also closed when ``hset`` raises; previously it leaked.

    NOTE(review): a later cell redefines ``put_crawled_url`` with a leading
    ``client`` argument; once that cell runs, this two-argument form is gone.

    Returns:
        Whatever ``client.hset`` returns.
    """
    client = get_redis_client()
    try:
        return await client.hset(
            f"pages:{url}", mapping={"url": url, "embeddings": embedings}
        )
    finally:
        await client.aclose()

In [9]:
def get_embeddings(text: str) -> bytes:
    """Embed ``text`` and return the pooled vector as raw float32 bytes.

    The text is split with ``prepare_text``. Every chunk is run through the
    feature-extraction ``pipe``, and the token vectors of all chunks are
    pooled together with ``average_pool``. The result is serialized with
    ``ndarray.tobytes`` so it can be stored directly in a Redis hash.

    Args:
        text: Document text to embed.

    Returns:
        ``bytes`` holding the float32 pooled vector. The previous ``Tensor``
        annotation did not match what was actually returned.
    """
    token_vectors: List[List[float]] = []
    for chunk_text in prepare_text(text):
        # [0] selects the per-token vectors for this single input
        # (pipeline output presumably shaped [1][tokens][hidden] — confirm).
        token_vectors.extend(pipe(chunk_text)[0])

    pooled = average_pool(torch.tensor(token_vectors))
    return pooled.cpu().numpy().astype(np.float32).tobytes()


async def put_url(url: str, text: str):
    """Embed ``text`` and store it with ``url`` in the Redis hash ``pages:<url>``.

    The write is done here with a dedicated client rather than through
    ``put_crawled_url``. A later cell redefines that helper to take a leading
    ``client`` argument, which made the old ``put_crawled_url(url, embedings)``
    call fail with TypeError. The client is closed even if the write raises.

    Returns:
        Whatever ``client.hset`` returns.
    """
    embedings = get_embeddings(text)
    client = get_redis_client()
    try:
        return await client.hset(
            f"pages:{url}", mapping={"url": url, "embeddings": embedings}
        )
    finally:
        await client.aclose()

In [10]:
# PostgreSQL connection settings. Values come from the environment so that no
# credential is stored in the notebook; non-secret values keep their previous
# defaults. NOTE: the password that used to be hard-coded here appeared in
# plain text and should be treated as leaked and rotated.
DB_HOST = os.environ.get("DB_HOST", "94537961.xyz")
DB_PORT = int(os.environ.get("DB_PORT", "6543"))
DB_USER = os.environ.get("DB_USER", "postgres")
# Set DB_PASSWORD in the kernel's environment before calling create_db();
# the value is interpolated into AsyncDB's DSN string.
DB_PASSWORD = os.environ.get("DB_PASSWORD", "")
DB_NAME = os.environ.get("DB_NAME", "sites")

In [11]:
!pip install aiopg

[0m

In [12]:
import asyncio

import aiopg

# import psycopg2  # type: ignore

In [13]:
class AsyncDB:
    """Minimal async PostgreSQL helper built on aiopg.

    Every operation opens its own connection pool from ``self.dsn`` and
    closes it on exit, so an instance holds no open connections between
    calls and ``close()`` has nothing to release.

    NOTE(review): ``table``, ``columns``, ``where`` and ``order_by`` are
    interpolated directly into the SQL text, so they must come from trusted
    code, never from user input. Only ``insert`` values are bound parameters.
    NOTE(review): a pool per call adds a connect/auth round trip to every
    query; consider a long-lived pool if that becomes a bottleneck.
    """

    def __init__(self, host, port, user, password, database):
        self.host = host
        self.port = port
        self.user = user
        self.password = password
        self.database = database

        # libpq key=value string; values are not quoted, so they must not
        # contain whitespace or quote characters.
        self.dsn = (
            f"dbname={database} user={user} password={password} host={host} port={port}"
        )

    async def _execute(self, query, params=None, fetch=None):
        """Run one statement on a fresh pool and optionally fetch its result.

        With ``fetch`` set, the statement runs inside ``cursor.begin()`` and
        ``await fetch(cursor)`` is returned; otherwise returns None.
        """
        async with aiopg.create_pool(self.dsn) as pool:
            async with pool.acquire() as conn:
                async with conn.cursor() as cursor:
                    if fetch is None:
                        await cursor.execute(query, params)
                        return None
                    async with cursor.begin():
                        await cursor.execute(query, params)
                        return await fetch(cursor)

    async def insert(self, table, columns, values):
        """Insert one row into ``table``; ``values`` are passed as bound parameters."""
        placeholders = ", ".join(["%s"] * len(values))
        await self._execute(
            f"INSERT INTO {table} ({', '.join(columns)}) VALUES ({placeholders})",
            values,
        )

    async def select(
        self, table, columns, where=None, limit=None, offset=None, order_by=None
    ):
        """Return all rows of a SELECT assembled from the given clauses.

        ``limit`` and ``offset`` are emitted whenever they are not None. The
        old truthiness checks turned ``limit=0`` into "no LIMIT" and returned
        the entire table.
        """
        query = f"SELECT {', '.join(columns)} FROM {table}"
        if where:
            query += f" WHERE {where}"
        if order_by:
            query += f" ORDER BY {order_by}"
        if limit is not None:
            query += f" LIMIT {limit}"
        if offset is not None:
            query += f" OFFSET {offset}"
        return await self._execute(query, fetch=lambda cursor: cursor.fetchall())

    async def count(self, table):
        """Return the single ``(count,)`` row of ``SELECT COUNT(*)`` on ``table``."""
        return await self._execute(
            f"SELECT COUNT(*) FROM {table}", fetch=lambda cursor: cursor.fetchone()
        )

    def close(self):
        """No-op: no connection outlives a single call."""
        pass


def create_db() -> AsyncDB:
    """Build an AsyncDB from the module-level DB_* connection settings."""
    return AsyncDB(
        host=DB_HOST,
        port=DB_PORT,
        user=DB_USER,
        password=DB_PASSWORD,
        database=DB_NAME,
    )

In [14]:
def get_embeddings_batch(batch: List[str]) -> List[bytes]:
    """Embed several documents, returning one float32 byte string per input.

    Each document is split into model-sized chunks with ``prepare_text``.
    All chunks from the batch are fed to ``pipe`` in groups of ``len(batch)``.
    The per-token vectors of every chunk belonging to the same document are
    mean-pooled with ``average_pool``.

    The previous version regrouped outputs inside each pipeline call using the
    chunk counts of the whole batch. Any document that produced more than one
    chunk therefore exhausted the iterator (StopIteration). Outputs are now
    collected first and then regrouped by each document's chunk count.

    Args:
        batch: Document texts; output order matches input order.

    Returns:
        A list of ``len(batch)`` ``bytes`` objects (float32 vectors). An empty
        ``batch`` returns ``[]``. Previously it raised from ``range(..., 0)``.

    Raises:
        ValueError: if ``prepare_text`` yields no chunks for a document.
    """
    if not batch:
        return []
    group_size = len(batch)

    chunks_per_doc = [prepare_text(text) for text in batch]
    flat_chunks = [chunk for chunks in chunks_per_doc for chunk in chunks]

    # One tensor of token vectors per chunk, in flat_chunks order. Converting
    # each output right away avoids holding every nested Python list at once.
    chunk_states: List[Tensor] = []
    for begin in range(0, len(flat_chunks), group_size):
        for output in pipe(flat_chunks[begin : begin + group_size]):
            chunk_states.append(torch.tensor(output[0]))

    embeddings: List[bytes] = []
    position = 0
    for doc_index, chunks in enumerate(chunks_per_doc):
        if not chunks:
            raise ValueError(f"prepare_text produced no chunks for document {doc_index}")
        doc_states = chunk_states[position : position + len(chunks)]
        position += len(chunks)
        # Concatenating the chunks' tokens and taking one mean weights every
        # token equally across the whole document.
        pooled = average_pool(torch.cat(doc_states).to(DEVICE))
        embeddings.append(pooled.cpu().numpy().astype(np.float32).tobytes())
    return embeddings


async def put_crawled_url(client, url: str, embedings):
    """Store ``url`` and its serialized embedding in the Redis hash ``pages:<url>``.

    This replaces the earlier two-argument definition. It writes through the
    caller's ``client`` and leaves closing it to the caller.

    Returns:
        Whatever ``client.hset`` returns.
    """
    fields = {"url": url, "embeddings": embedings}
    key = f"pages:{url}"
    return await client.hset(key, mapping=fields)


async def get_crawled_url(url: str):
    """Load the ``pages:<url>`` document through ``client.ft(INDEX_NAME)``.

    Uses a dedicated client that is closed even if the lookup raises.
    Previously the client leaked on error.

    Returns:
        Whatever ``load_document`` returns.
    """
    client = get_redis_client()
    try:
        return await client.ft(INDEX_NAME).load_document(f"pages:{url}")
    finally:
        await client.aclose()


async def get_all_urls_from_redis():
    """Return the set of URLs stored under ``pages:<url>`` keys in Redis.

    Iterates with incremental ``SCAN`` (``scan_iter``) instead of ``KEYS``,
    which Redis documents as unsafe for large keyspaces because it blocks the
    server. SCAN may yield a key more than once; collecting into a set
    removes duplicates. The client is closed even if the scan fails.
    """
    prefix = "pages:"
    client = get_redis_client()
    try:
        return {
            key.decode("utf-8")[len(prefix) :]
            async for key in client.scan_iter(match=prefix + "*", count=1000)
        }
    finally:
        await client.aclose()

In [17]:
async def move_to_redis(start_offset: int = 0):
    """Copy page embeddings from the Postgres ``websites`` table into Redis.

    Rows are read ``page_size`` at a time, ordered by ``url``, starting at row
    ``start_offset``. A row is skipped if its URL already has a ``pages:*``
    key in Redis or its content is None/blank. The remaining rows are embedded
    ``put_chunk_size`` at a time with ``get_embeddings_batch`` and written
    with ``put_crawled_url`` over one shared Redis client.

    Args:
        start_offset: Row offset (in url order) to resume from.

    Raises:
        Whatever the database, the embedding step or a Redis write raises.
        Failed writes are no longer swallowed: ``asyncio.gather`` re-raises
        the first error, whereas ``asyncio.wait`` left it unretrieved on the
        task. The Redis client is closed in every case.
    """
    start = time.time()
    db = create_db()
    client = get_redis_client()
    try:
        total_count = (await db.count("websites"))[0]
        print(f"Total count: {total_count}")
        put_chunk_size = 128
        page_size = put_chunk_size * 2

        already_moved = await get_all_urls_from_redis()
        print(f"Initialized in {time.time() - start} seconds")

        # NOTE(review): OFFSET pagination makes the server compute and skip
        # every earlier row on each page, so later pages get slower on a large
        # table; keyset pagination (WHERE url > last_url) would avoid that.
        for page_offset in range(start_offset, total_count, page_size):
            start = time.time()
            rows = await db.select(
                "websites",
                ["url", "content"],
                order_by="url",
                limit=page_size,
                offset=page_offset,
            )
            print(f"Selected {len(rows)} urls in {time.time() - start} seconds")

            start = time.time()
            pending = [
                (url, content)
                for url, content in rows
                if url not in already_moved
                and content is not None
                and str(content).strip() != ""
            ]
            print(f"Filtered {len(rows) - len(pending)} urls in {time.time() - start} seconds")

            start = time.time()
            for chunk_start in range(0, len(pending), put_chunk_size):
                batch = pending[chunk_start : chunk_start + put_chunk_size]
                batch_urls = [url for url, _ in batch]
                texts = [text for _, text in batch]

                embeddings = get_embeddings_batch(texts)
                torch.cuda.empty_cache()
                await asyncio.gather(
                    *(
                        put_crawled_url(client, batch_urls[k], embeddings[k])
                        for k in range(len(texts))
                    )
                )

            print(f"Moved {len(pending)} urls in {time.time() - start} seconds")
            # Clamp so the final (possibly partial) page does not overshoot.
            print(f"Total moved: {min(page_offset + page_size, total_count)}")
    finally:
        db.close()
        await client.aclose()

In [None]:
await move_to_redis(2048)

Total count: 2302237
Initialized in 35.946128129959106 seconds
