This notebook shows sql clustering using embedding and vector database

## 1. Breparing Database

In [6]:
%pip install -r requirements.txt

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Note: you may need to restart the kernel to use updated packages.


In [2]:

import pandas as pd

splits = {'train': 'synthetic_text_to_sql_train.snappy.parquet', 'test': 'synthetic_text_to_sql_test.snappy.parquet'}
df = pd.read_parquet("hf://datasets/gretelai/synthetic_text_to_sql/" + splits["train"])

df.head

  from .autonotebook import tqdm as notebook_tqdm


<bound method NDFrame.head of           id                 domain  \
0       5097               forestry   
1       5098       defense industry   
2       5099         marine biology   
3       5100     financial services   
4       5101                 energy   
...      ...                    ...   
99995  89651              nonprofit   
99996  89652                 retail   
99997  89653       fitness industry   
99998  89654      space exploration   
99999  89655  wildlife conservation   

                                      domain_description    sql_complexity  \
0      Comprehensive data on sustainable forest manag...       single join   
1      Defense contract data, military equipment main...       aggregation   
2      Comprehensive data on marine species, oceanogr...         basic SQL   
3      Detailed financial data including investment s...       aggregation   
4      Energy market data covering renewable energy s...  window functions   
...                              

In [3]:
from transformers import pipeline
pipe = pipeline("feature-extraction", model="microsoft/codebert-base")     

Device set to use mps:0


In [4]:
import numpy as np


sql_query = "SELECT name, age FROM users WHERE age > 30"
embeddings = pipe(sql_query)

embedding_size = np.array(embeddings).shape
print(f"Embedding size: {embedding_size}")

Embedding size: (1, 12, 768)


# Using FAISS as Vector database

In [36]:
import faiss
import numpy as np
from tqdm import tqdm


d = 768  # 向量维度

top_sql_queries = df.head(10)["sql"]  # 假设字段名称是 sql_query

# 使用一个字典来存储原文与其对应的索引
text_dict = {i: top_sql_queries[i] for i in range(len(top_sql_queries))}
print(f"top_sql_queries={top_sql_queries}, text_dict={text_dict}")

index = faiss.IndexFlatL2(d)  # 使用 L2 距离

for sql_query in tqdm(top_sql_queries):
  embeddings = pipe(sql_query)
  # 转换嵌入为 NumPy 数组并计算平均值（句子的整体 embedding）
  embeddings = np.array(embeddings[0])  # 获取第一个元素，即每个 token 的 embedding
  sentence_embedding = embeddings.mean(axis=0)  # 对所有 token 的 embedding 取平均
  
  index.add(np.expand_dims(sentence_embedding, axis=0).astype('float32'))  # 添加到 FAISS 索引中

top_sql_queries=0    SELECT salesperson_id, name, SUM(volume) as to...
1    SELECT equipment_type, SUM(maintenance_frequen...
2    SELECT COUNT(*) FROM marine_species WHERE loca...
3    SELECT trader_id, stock, SUM(price * quantity)...
4    SELECT type, cost FROM (SELECT type, cost, ROW...
5    SELECT SUM(spending) FROM defense.eu_humanitar...
6    SELECT SpeciesName, AVG(WaterTemp) as AvgTemp ...
7    DELETE FROM Program_Outcomes WHERE program_id ...
8    SELECT SUM(fare) FROM bus_routes WHERE route_n...
9    SELECT AVG(Property_Size) FROM Inclusive_Housi...
Name: sql, dtype: object, text_dict={0: 'SELECT salesperson_id, name, SUM(volume) as total_volume FROM timber_sales JOIN salesperson ON timber_sales.salesperson_id = salesperson.salesperson_id GROUP BY salesperson_id, name ORDER BY total_volume DESC;', 1: 'SELECT equipment_type, SUM(maintenance_frequency) AS total_maintenance_frequency FROM equipment_maintenance GROUP BY equipment_type;', 2: "SELECT COUNT(*) FROM marine_species WH

100%|██████████| 10/10 [00:00<00:00, 25.35it/s]


Let's try to retrive

In [45]:
# 假设你已经加载了数据并取得了前 1 条 SQL 查询
xq = """
  SELECT    SUM(volume) AS total_volume,
            salesperson_id,
            name
  FROM      timber_sales
  JOIN      salesperson ON timber_sales.salesperson_id = salesperson.salesperson_id
  GROUP BY  salesperson_id,
            name
  ORDER BY  total_volume DESC;
"""

# 计算 SQL 查询的 embeddings
xq_embeddings = pipe(xq)
xq_embeddings = np.array(xq_embeddings)

# 确保 embeddings 的形状为 (1, num_tokens, embedding_dim)
print(f"Shape of embeddings: {xq_embeddings.shape}")  # 形状应该是 (1, 63, 768)

# 对所有 token embeddings 取平均，得到句子的 embedding
xq_sentence_embedding = xq_embeddings.mean(axis=1)  # 对维度 1 取平均，得到 (1, 768)

# 确认句子的 embedding 形状，应该是 (768,)
print(f"Shape of sentence embedding: {xq_sentence_embedding.shape}")  # 应该是 (768,)

# 将句子的 embedding 转换为 (1, 768) 形式
xq_sentence_vector = xq_sentence_embedding.astype('float32')  # 不需要 expand_dims，因为它已经是二维的

# 确保索引的维度与 embedding 的维度匹配
print(f"Shape of query vector for FAISS: {xq_sentence_vector.shape}")  # 应该是 (1, 768)

# 查询索引，返回最相似的 5 个向量
D, I = index.search(xq_sentence_vector, 5)

print(f"query={xq}")
# 打印最相似的向量的原文
for idx in I[0]:
    print(f"{text_dict[idx]}")

Shape of embeddings: (1, 125, 768)
Shape of sentence embedding: (1, 768)
Shape of query vector for FAISS: (1, 768)
query=
  SELECT    SUM(volume) AS total_volume,
            salesperson_id,
            name
  FROM      timber_sales
  JOIN      salesperson ON timber_sales.salesperson_id = salesperson.salesperson_id
  GROUP BY  salesperson_id,
            name
  ORDER BY  total_volume DESC;

SELECT salesperson_id, name, SUM(volume) as total_volume FROM timber_sales JOIN salesperson ON timber_sales.salesperson_id = salesperson.salesperson_id GROUP BY salesperson_id, name ORDER BY total_volume DESC;
SELECT trader_id, stock, SUM(price * quantity) as total_trade_value, AVG(price) as avg_price FROM trade_history GROUP BY trader_id, stock;
SELECT SpeciesName, AVG(WaterTemp) as AvgTemp FROM SpeciesWaterTemp INNER JOIN FishSpecies ON SpeciesWaterTemp.SpeciesID = FishSpecies.SpeciesID WHERE MONTH(Date) = 2 GROUP BY SpeciesName;
SELECT equipment_type, SUM(maintenance_frequency) AS total_maintenan

## Now Try Clustering

In [None]:
import faiss
import numpy as np
from tqdm import tqdm


# 设置聚类的簇数
k = 10  # 假设我们想要10个簇
d = 768  # 向量维度

top_sql_queries = df.head(1000)["sql"]  # 假设字段名称是 sql_query

clustering = faiss.Clustering(d, k)

# 设置聚类的参数（例如，最大迭代次数）
clustering.niter = 20  # 设置最大迭代次数
clustering.max_points_per_centroid = 1000  # 每个质心的最大点数
clustering.verbose = True  # 输出详细信息

# 使用一个字典来存储原文与其对应的索引
text_dict = {i: top_sql_queries[i] for i in range(len(top_sql_queries))}

index = faiss.IndexFlatL2(d)  # 使用 L2 距离

vectors = []
for sql_query in tqdm(top_sql_queries):
  embeddings = pipe(sql_query)
  # 转换嵌入为 NumPy 数组并计算平均值（句子的整体 embedding）
  embeddings = np.array(embeddings[0])  # 获取第一个元素，即每个 token 的 embedding
  sentence_embedding = embeddings.mean(axis=0)  # 对所有 token 的 embedding 取平均
  
  vector = np.expand_dims(sentence_embedding, axis=0).astype('float32')
  vectors.append(vector)

# 将所有向量堆叠成一个大的数组
vectors = np.vstack(vectors)  # (num_samples, d)

# 执行聚类训练
clustering.train(vectors, index)  # 进行聚类训练

# 获取聚类中心
centroids = faiss.vector_float_to_array(clustering.centroids)  # 获取聚类中心
print(f"centroids shape before reshape: {centroids.shape}")

# 检查是否可以 reshape
if centroids.shape[0] == k * d:
    centroids = centroids.reshape(k, d)  # reshape 成 k 个簇中心，每个簇中心是 d 维
    print(f"聚类中心：\n{centroids}")
else:
    print(f"聚类中心的形状不符合预期，实际大小为 {centroids.shape}")

# 添加向量到索引
index.add(vectors)  # 聚类后添加向量


100%|██████████| 1000/1000 [00:21<00:00, 46.21it/s]

Clustering 1000 points in 768D to 10 clusters, redo 1 times, 20 iterations
  Preprocessing in 0.00 s
centroids shape before reshape: (7680,) objective=5795.19 imbalance=1.227 nsplit=0       
聚类中心：
[[-0.18499173  0.22511667  0.13517778 ... -0.54662466 -0.41021362
   0.46190667]
 [-0.28051266  0.08221427  0.15986638 ... -0.44513115 -0.38535556
   0.4418672 ]
 [-0.29708812  0.06968737  0.14393795 ... -0.46478567 -0.41295123
   0.46540022]
 ...
 [-0.33051264  0.0879201   0.20769997 ... -0.47243038 -0.46141124
   0.5047192 ]
 [-0.33388335  0.08140443  0.16814938 ... -0.49328902 -0.46472928
   0.45314988]
 [-0.34249762  0.11567847  0.23195948 ... -0.5568997  -0.44863892
   0.46856833]]
  Iteration 19 (0.01 s, search 0.01 s): objective=5793.7 imbalance=1.222 nsplit=0       





In [None]:


# 获取每个簇中心对应的原文
centroid_to_text = {}  # 用来存储每个簇中心对应的原文

# 对每个簇中心，找到最接近的向量及其原文
for i in range(k):
    # 找到与第 i 个簇中心最接近的向量
    centroid_vector = np.expand_dims(centroids[i], axis=0).astype('float32')  # 当前簇中心
    distances, indices = index.search(centroid_vector, 1)  # 查找最近的一个向量
    
    # 获取该簇中心对应的原文
    closest_index = indices[0][0]  # 获取最近的向量索引
    closest_text = text_dict[closest_index]  # 获取该向量的原文
    
    # 将簇中心和对应的原文存储到字典中
    centroid_to_text[i] = closest_text

# 打印每个簇中心对应的原文
for cluster_id, text in centroid_to_text.items():
    print(f"簇 {cluster_id} 的原文：{text}")

簇 0 的原文：SELECT salesperson_id, name, SUM(volume) as total_volume FROM timber_sales JOIN salesperson ON timber_sales.salesperson_id = salesperson.salesperson_id GROUP BY salesperson_id, name ORDER BY total_volume DESC;
簇 1 的原文：SELECT equipment_type, SUM(maintenance_frequency) AS total_maintenance_frequency FROM equipment_maintenance GROUP BY equipment_type;
簇 2 的原文：SELECT COUNT(*) FROM marine_species WHERE location = 'Southern Ocean';
簇 3 的原文：SELECT trader_id, stock, SUM(price * quantity) as total_trade_value, AVG(price) as avg_price FROM trade_history GROUP BY trader_id, stock;
簇 4 的原文：SELECT type, cost FROM (SELECT type, cost, ROW_NUMBER() OVER (ORDER BY cost DESC) as rn FROM upgrades) sub WHERE rn = 1;
簇 5 的原文：SELECT SUM(spending) FROM defense.eu_humanitarian_assistance WHERE year BETWEEN 2019 AND 2021;
簇 6 的原文：SELECT SpeciesName, AVG(WaterTemp) as AvgTemp FROM SpeciesWaterTemp INNER JOIN FishSpecies ON SpeciesWaterTemp.SpeciesID = FishSpecies.SpeciesID WHERE MONTH(Date) = 2 GROUP BY