## 读取 .env 文件

可以使用 `python-dotenv` 库来加载 `.env` 文件中的环境变量。首先确保已经安装了该库 (`pip install python-dotenv`)。

In [1]:
import os
from dotenv import load_dotenv

# Load the variables defined in the .env file into the process environment.
# python-dotenv looks for the .env file in the current directory or in a
# parent directory.
load_dotenv()

# From here on the settings are read like ordinary environment variables;
# a key that is not set yields None.
postgres_user = os.environ.get('POSTGRES_USER')
postgres_password = os.environ.get('POSTGRES_PASSWORD')
postgres_db = os.environ.get('POSTGRES_DB')
postgres_port = os.environ.get('POSTGRES_PORT')

# NOTE: this echoes the password in clear text into the cell output.
print(f"User: {postgres_user}")
print(f"Password: {postgres_password}")
print(f"Database: {postgres_db}")
print(f"Port: {postgres_port}")

# Keep the .env file next to the notebook, or in one of its parent directories.
# Example .env contents:
# POSTGRES_USER=your_user
# POSTGRES_PASSWORD=your_password
# POSTGRES_DB=your_database
# POSTGRES_PORT=5432

User: admin
Password: admin123
Database: vectordb
Port: 5432


### 在 Jupyter Notebook 中运行 Asyncio 代码

当你尝试在 Jupyter Notebook 中使用 `asyncio.run()` 时，通常会遇到 `RuntimeError: asyncio.run() cannot be called from a running event loop`。这是因为 Jupyter Notebook (特别是其内核 `ipykernel`) 已经管理着一个正在运行的 `asyncio` 事件循环。

`asyncio.run()` 设计为启动一个新的事件循环，这与 Jupyter 已经运行的循环冲突。

正确的做法是直接 `await` 你的异步函数。Jupyter/IPython 会在它现有的事件循环中处理这个 `await`（这依赖 IPython 7.0 及以上版本提供的顶层 `await`，即 autoawait 支持）。

In [2]:
import asyncpg

# Open a connection with the credentials read from the environment.
# Top-level `await` is used instead of asyncio.run() because the notebook
# kernel already runs an event loop.
conn = await asyncpg.connect(
    user=postgres_user,
    password=postgres_password,
    database=postgres_db,
    host='127.0.0.1',
    port=postgres_port # make sure the port from .env is used as well
)


# Run some queries.
# Example: create a table (if it does not exist yet) and insert some rows.
await conn.execute('''
    CREATE TABLE IF NOT EXISTS your_table (
        id SERIAL PRIMARY KEY,
        name TEXT
    );
''')
# NOTE: this INSERT is unconditional, so every re-run of the cell appends
# two more rows with fresh SERIAL ids.
await conn.execute("INSERT INTO your_table (name) VALUES ($1), ($2)", "test_name1", "test_name2")

rows = await conn.fetch('SELECT * FROM your_table')

# Print each returned record as a plain dict.
for row in rows:
    print(dict(row))

{'id': 1, 'name': 'test_name1'}
{'id': 2, 'name': 'test_name2'}
{'id': 3, 'name': 'test_name1'}
{'id': 4, 'name': 'test_name2'}
{'id': 5, 'name': 'test_name1'}
{'id': 6, 'name': 'test_name2'}
{'id': 7, 'name': 'test_name1'}
{'id': 8, 'name': 'test_name2'}
{'id': 9, 'name': 'test_name1'}
{'id': 10, 'name': 'test_name2'}
{'id': 11, 'name': 'test_name1'}
{'id': 12, 'name': 'test_name2'}
{'id': 13, 'name': 'test_name1'}
{'id': 14, 'name': 'test_name2'}


In [3]:
# List every table in the `public` schema.
tables = await conn.fetch('''
    SELECT table_name
    FROM information_schema.tables
    WHERE table_schema = 'public';
''')
# Each item is a record with a single `table_name` field; print it as-is.
for table in tables:
    print(table)

<Record table_name='your_table'>
<Record table_name='testvector'>
<Record table_name='arxiv_meta'>


In [4]:
from pgvector.asyncpg import register_vector
# Make sure the `vector` and `vchord` extensions are installed in this
# database; both statements are no-ops when the extension already exists.
await conn.execute("CREATE EXTENSION IF NOT EXISTS vector;")
await conn.execute("CREATE EXTENSION IF NOT EXISTS vchord CASCADE;")
# Collect name and version of every installed extension for display.
extensions = await conn.fetch('SELECT extname, extversion FROM pg_extension;')
# Register pgvector's types on this connection so vector values can be
# exchanged with queries directly.
await register_vector(conn)
# Bare expression: the notebook renders it as the cell output.
extensions

[<Record extname='plpgsql' extversion='1.0'>,
 <Record extname='vector' extversion='0.8.0'>,
 <Record extname='vchord' extversion='0.3.0'>]

In [5]:
# Start from a clean slate: drop both tables that the following cells recreate.
await conn.execute('DROP TABLE IF EXISTS testvector CASCADE;')
await conn.execute('DROP TABLE IF EXISTS arxiv_meta CASCADE;')
# NOTE(review): the message only mentions "testvector", although arxiv_meta
# is dropped above as well.
print('Table "testvector" dropped if it existed, before recreation.')

Table "testvector" dropped if it existed, before recreation.


In [6]:
import numpy as np

# Create the table `testvector` (if it does not exist yet) with three columns:
# `id` (a manually assigned string), `content`, and `embedding`, a
# 4-dimensional float16 vector.
await conn.execute('''
    CREATE TABLE IF NOT EXISTS testvector (
        id TEXT PRIMARY KEY,
        content TEXT,
        embedding HALFVEC(4)
    );
''')
# Insert a row.

# Build the embedding as a float16 NumPy array.
vector_data_np = np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float16)

# With pgvector.asyncpg and register_vector(conn),
# we should be able to pass the numpy array directly.
# ON CONFLICT (id) DO NOTHING leaves an existing row with the same id untouched,
# so the cell can be re-run.
await conn.execute('''
    INSERT INTO testvector (id, content, embedding)
    VALUES ($1, $2, $3)
    ON CONFLICT (id) DO NOTHING;
''', 'test_id_pgvector', 'test_content_pgvector', vector_data_np)

# Read the row back.
rows = await conn.fetch('SELECT * FROM testvector WHERE id = $1', 'test_id_pgvector')
# Print the record and the dtype of the embedding once converted to NumPy.
for row in rows:
    print(dict(row))
    print(row['embedding'].to_numpy().dtype)

{'id': 'test_id_pgvector', 'content': 'test_content_pgvector', 'embedding': HalfVector([0.0999755859375, 0.199951171875, 0.300048828125, 0.39990234375])}
>f2


In [None]:
# Schema of the source data, for reference:
# Schema({'id': String, 'title': String, 'authors': List(String), 'abstract': String, 'date': String, 'categories': List(String), 'created': String, 'updated': String, 'license': String, 'jasper_v1': Array(Float32, shape=(1024,)), 'conan_v1': Array(Float32, shape=(1792,))})
# The table mirrors those fields; the three date fields are stored as DATE
# and the two embeddings as halfvec columns of matching dimension.
await conn.execute('''
    CREATE TABLE IF NOT EXISTS arxiv_meta (
        id TEXT PRIMARY KEY,
        title TEXT,
        authors TEXT[],
        abstract TEXT,
        date DATE,
        categories TEXT[],
        created DATE,
        updated DATE,
        license TEXT,
        jasper_v1 halfvec(1024),
        conan_v1 halfvec(1792)
    );
''')

# List every table in the `public` schema to confirm the table now exists.
tables = await conn.fetch('''
    SELECT table_name
    FROM information_schema.tables
    WHERE table_schema = 'public';
''')
for table in tables:
    print(table)

<Record table_name='your_table'>
<Record table_name='testvector'>
<Record table_name='arxiv_meta'>


In [8]:
# Fetch the column definitions of the arxiv_meta table.
# NOTE(review): the filter is on table_name only (no table_schema), so a table
# with the same name in another schema would be reported as well.
table_schema = await conn.fetch("""
    SELECT column_name, data_type, character_maximum_length, column_default, is_nullable
    FROM information_schema.columns
    WHERE table_name = 'arxiv_meta';
""")

# Print one dict per column.
for column in table_schema:
    print(dict(column))


{'column_name': 'created', 'data_type': 'date', 'character_maximum_length': None, 'column_default': None, 'is_nullable': 'YES'}
{'column_name': 'updated', 'data_type': 'date', 'character_maximum_length': None, 'column_default': None, 'is_nullable': 'YES'}
{'column_name': 'date', 'data_type': 'date', 'character_maximum_length': None, 'column_default': None, 'is_nullable': 'YES'}
{'column_name': 'jasper_v1', 'data_type': 'USER-DEFINED', 'character_maximum_length': None, 'column_default': None, 'is_nullable': 'YES'}
{'column_name': 'conan_v1', 'data_type': 'USER-DEFINED', 'character_maximum_length': None, 'column_default': None, 'is_nullable': 'YES'}
{'column_name': 'license', 'data_type': 'text', 'character_maximum_length': None, 'column_default': None, 'is_nullable': 'YES'}
{'column_name': 'categories', 'data_type': 'ARRAY', 'character_maximum_length': None, 'column_default': None, 'is_nullable': 'YES'}
{'column_name': 'title', 'data_type': 'text', 'character_maximum_length': None, 'col

In [None]:
import polars as pl
import asyncpg
import numpy as np
from tqdm import tqdm

# Metadata columns and embedding columns to load.
common_columns = ['id', 'title', 'authors', 'abstract', 'date', 'categories', 'created', 'updated', 'license']
embedding_columns = ['jasper_v1', 'conan_v1']

# Number of rows per batch.
batch_size = 100_000
# Lazily scan every parquet file matching ./data/*.parquet.
lazy_frame = pl.scan_parquet("./data/*.parquet")
# Count the total number of rows.
total_rows = int(lazy_frame.select(pl.len()).collect().item())
print(f"Total rows: {total_rows}")

# Columns that are parsed into dates before insertion.
date_cols_to_convert = ['date', 'created', 'updated']

# dtype the embedding arrays are cast to before insertion.
vector_type = np.float16
async def main():
    """Load the parquet data into the ``arxiv_meta`` table batch by batch.

    Relies on names defined at notebook level: ``lazy_frame``,
    ``total_rows``, ``batch_size``, ``common_columns``,
    ``embedding_columns``, ``date_cols_to_convert``, ``vector_type`` and
    the open asyncpg connection ``conn``.

    Every batch is materialised with a single ``collect()`` that returns
    the metadata and the embedding columns together, so the parquet files
    are scanned once per batch instead of once per selected column group.
    Rows whose ``id`` already exists in the table are skipped
    (``ON CONFLICT (id) DO NOTHING``).
    """
    # Date parsing expressions do not depend on the batch; build them once.
    # strict=False turns values that do not match the format into nulls.
    date_exprs = [
        pl.col(col_name).str.to_date(format="%Y-%m-%d", strict=False).alias(col_name)
        for col_name in date_cols_to_convert
    ]
    # Process the data batch by batch.
    for offset in tqdm(range(0, total_rows, batch_size)):
        # One query per batch: slice, convert the date columns, and fetch
        # metadata and embeddings in the same collect().
        batch_df = (
            lazy_frame.slice(offset, batch_size)
            .with_columns(date_exprs)
            .select(common_columns + embedding_columns)
            .collect()
        )
        # One array per embedding column, cast to the target dtype; iterating
        # such an array yields one vector per row.
        vectors_by_column = [
            batch_df.get_column(name).to_numpy().astype(vector_type)
            for name in embedding_columns
        ]
        # Lazily pair each metadata row with its vectors, in the column order
        # of the INSERT statement. Only the metadata columns go through
        # iter_rows(); the embeddings are taken from the NumPy arrays.
        values = (
            (*meta, *vectors)
            for meta, *vectors in zip(
                batch_df.select(common_columns).iter_rows(),
                *vectors_by_column,
            )
        )
        insert_meta_sql = '''
            INSERT INTO arxiv_meta (id, title, authors, abstract, date, categories, created, updated, license, jasper_v1, conan_v1)
            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
            ON CONFLICT (id) DO NOTHING;
        '''
        await conn.executemany(insert_meta_sql, values)

# Run the import. It is awaited directly rather than wrapped in
# asyncio.run(), because the notebook already has a running event loop.
await main()

Total rows: 699547


100%|██████████| 7/7 [01:09<00:00,  9.95s/it]


In [10]:

# Count the rows currently stored in arxiv_meta.
count_result = await conn.fetchval('SELECT COUNT(*) FROM arxiv_meta;')
print(f'Total rows in arxiv_meta: {count_result}')


Total rows in arxiv_meta: 697495


In [11]:

# Peek at the first rows of arxiv_meta (no ORDER BY, so which five rows come
# back is up to the database).
head_rows = await conn.fetch('SELECT * FROM arxiv_meta LIMIT 5;')
for row in head_rows:
    print(dict(row))


{'id': '0704.3649', 'title': 'Quantile and Probability Curves Without Crossing', 'authors': ['Victor Chernozhukov', 'Ivan Fernandez-Val', 'Alfred Galichon'], 'abstract': 'This paper proposes a method to address the longstanding problem of lack of monotonicity in estimation of conditional and structural quantile functions, also known as the quantile crossing problem. The method consists in sorting or monotone rearranging the original estimated non-monotone curve into a monotone rearranged curve. We show that the rearranged curve is closer to the true quantile curve in finite samples than the original curve, establish a functional delta method for rearrangement-related operators, and derive functional limit theory for the entire rearranged curve and its functionals. We also establish validity of the bootstrap for estimating the limit law of the the entire rearranged curve and its functionals. Our limit results are generic in that they apply to every estimator of a monotone econometric fu

In [12]:
# Settings applied with SET on this connection before the indexes are built:
# memory available to maintenance operations and the parallel worker limits.
await conn.execute("SET maintenance_work_mem = '16GB';")
await conn.execute("SET max_parallel_maintenance_workers = 15;")
# Last expression of the cell: its status string ('SET') is shown as output.
await conn.execute("SET max_parallel_workers = 15;")

'SET'

In [None]:
sql_build_index = '''
CREATE INDEX IF NOT EXISTS arxiv_meta_{column}_idx 
ON arxiv_meta 
USING vchordrq ({column} {ops})
WITH (options = $$
[build.internal]
lists = [{lists}]
build_threads = {threads}
$$);
'''
lists = int(2 * total_rows / 1000)
threads = 15

for column in embedding_columns:
    print(f"Creating vchordrq index for {column}...")
    await conn.execute(sql_build_index.format(
        column=column, lists=lists, threads=threads, ops='halfvec_ip_ops'
    ))

Creating vchordrq index for jasper_v1...
Creating vchordrq index for conan_v1...
