<a href="https://colab.research.google.com/github/Kira1108/SQL-DIE/blob/main/AOZ_SQL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Environment setup (Colab): install the packages this notebook depends on,
# then wipe the long pip/git output from the cell.
# TODO(review): pin versions (e.g. `%pip install -q pkg==X.Y.Z`) so re-runs are reproducible.
from IPython.display import clear_output
# NOTE(review): langchain, openai, tiktoken and datasets are not used in the cells
# visible here — confirm they are still needed.
!pip install langchain
!pip install openai
!git clone https://github.com/Kira1108/simple-chat.git
# NOTE(review): `simple-chat` is a bare name, so pip may resolve it from PyPI rather
# than the directory cloned above; `./simple-chat` would install the clone — confirm intent.
!pip install simple-chat
!pip install tiktoken
# transformers provides the embedding and T5 models used later; accelerate backs
# device_map="auto" and sentencepiece backs T5Tokenizer. bitsandbytes is not used
# in the visible cells.
!pip install transformers accelerate bitsandbytes sentencepiece
!pip install datasets
# Note: this also hides any install errors — re-run without it if a later import fails.
clear_output()

In [3]:
import os
import torch
from transformers import AutoTokenizer, AutoModel
from dataclasses import dataclass
from sklearn.metrics.pairwise import cosine_similarity
import torch.nn.functional as F
import numpy as np
import pandas as pd

In [4]:
@dataclass
class SentenceEmbedder:
    """Mean-pooled sentence embeddings from a Hugging Face encoder checkpoint.

    Calling an instance with a string or a list of strings returns a NumPy
    array with one embedding row per input sentence.

    Attributes:
        model_ckpt: Checkpoint name used for both the tokenizer and the model.
        normalize: If True, each embedding row is L2-normalised.
    """

    model_ckpt: str = "sentence-transformers/all-MiniLM-L6-v2"
    normalize: bool = False

    def __post_init__(self):
        # Load once at construction so every call reuses the same weights.
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_ckpt)
        self.model = AutoModel.from_pretrained(self.model_ckpt)

    def tokenize(self, sentences):
        """Tokenize into padded, truncated PyTorch tensors."""
        return self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')

    def mean_pooling(self, model_output, attention_mask):
        """Average the token embeddings of each sentence, ignoring padding.

        Args:
            model_output: Encoder output exposing `last_hidden_state`.
            attention_mask: Mask with 1 for real tokens and 0 for padding.

        Returns:
            Tensor with one pooled vector per sentence.
        """
        token_embeddings = model_output.last_hidden_state
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        summed = torch.sum(token_embeddings * input_mask_expanded, 1)
        # Clamp prevents division by zero for a row with no unmasked tokens.
        counts = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return summed / counts

    def _embed_batch(self, sentences):
        """Encode one batch in a single forward pass and return NumPy embeddings."""
        encoder_input = self.tokenize(sentences)

        with torch.no_grad():
            model_output = self.model(**encoder_input)

        sentence_embeddings = self.mean_pooling(model_output, encoder_input['attention_mask'])
        if self.normalize:
            sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
        # .cpu() keeps this working if the model is ever moved to a GPU.
        return sentence_embeddings.cpu().numpy()

    def __call__(self, sentences, batch_size=None):
        """Embed one sentence or a list of sentences.

        Args:
            sentences: A string or a list of strings.
            batch_size: If given, encode at most this many sentences per forward
                pass. This bounds memory when embedding many long texts such as
                DDLs. None (the default) encodes everything in one pass.

        Returns:
            NumPy array with one embedding row per input sentence.

        Raises:
            ValueError: If `batch_size` is given and is smaller than 1.
        """
        if batch_size is None or isinstance(sentences, str):
            return self._embed_batch(sentences)
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        sentences = list(sentences)
        if len(sentences) <= batch_size:
            return self._embed_batch(sentences)

        chunks = [
            self._embed_batch(sentences[start:start + batch_size])
            for start in range(0, len(sentences), batch_size)
        ]
        return np.concatenate(chunks, axis=0)

In [5]:
# Table metadata: the cells below read its `table` (table name) and `ddl`
# (CREATE TABLE statement) columns.
# The default path is on Google Drive, so Drive must be mounted in Colab before
# this cell runs; set the AOZ_META_CSV environment variable to load it from elsewhere.
META_CSV_PATH = os.environ.get("AOZ_META_CSV", "/content/drive/MyDrive/AI/aoz_redshift_meta.csv")
df = pd.read_csv(META_CSV_PATH)

# Fail fast with a clear message instead of an opaque tokenizer error later.
_missing_columns = {"table", "ddl"} - set(df.columns)
assert not _missing_columns, f"metadata CSV is missing columns: {sorted(_missing_columns)}"
assert df["ddl"].notna().all(), "metadata CSV has rows with an empty `ddl`"

In [6]:
# Embed every table's DDL once up front; queries then only embed their own text.
ddls = df.ddl.tolist()
embedder = SentenceEmbedder()
ddl_embeddings = embedder(ddls)

# Keep one embedding row per table next to that table's metadata.
embds = list(ddl_embeddings)
df['embedding'] = embds

Downloading (…)okenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

In [7]:
def find_table(query_string: str, n_recommends=3):
    """Return the names of the tables whose DDL best matches a description.

    Args:
        query_string: Free-text description of the wanted table.
        n_recommends: Maximum number of table names to return.

    Returns:
        Table names from the global `df`, most similar first. Ranking uses the
        cosine similarity between the query embedding and `ddl_embeddings`.
    """
    similarities = cosine_similarity(embedder([query_string]), ddl_embeddings)
    # Row 0 holds the single query's scores; reverse the ascending argsort.
    best_first = np.argsort(similarities)[0][::-1]
    top_rows = best_first[:n_recommends]
    return df.iloc[top_rows]["table"].tolist()

## 表搜索

In [8]:
# Semantic table search: top 5 matches for user payment information.
find_table(query_string="a table that contains user payment information", n_recommends=5)

['dws_user_prop_payment_info',
 'dws_pay_user_package_tag1_2_payment_daily_di',
 'dws_pay_user_package_payment_daily_di',
 'dws_pay_user_package_tag1_2_3_payment_daily_di',
 'dws_pay_user_payment_daily_di']

In [9]:
# Semantic table search: top 5 matches for user login information.
find_table(query_string="a table that contains user login information", n_recommends=5)

['dwd_login_record',
 'dws_user_login_info',
 'dwd_user_info_ext',
 'dws_user_retention_info',
 'dwd_sso_create_user']

In [10]:
# Semantic table search: top 5 matches for user battle information.
find_table(query_string="a table that contains user battle information", n_recommends=5)

['dwd_operate_battle_log',
 'app_diamond_report',
 'dwd_appsflyer_daily_cost',
 'app_market_report_daily',
 'dwd_operation_cheater_monitor']

In [11]:
# Semantic table search: top 5 matches for AppsFlyer marketing information.
find_table(query_string="a table that contains appsflyer marketing information", n_recommends=5)

['dwd_appsflyer_daily_cost',
 'dwd_af_info_di',
 'app_market_report_daily',
 'dwd_bus_data_brand',
 'dwd_appsflyer_info_new']

In [12]:
# Semantic table search: top 5 matches for ROI information.
find_table(query_string="a table that contains ROI information", n_recommends=5)

['app_roi_cn_af_daily',
 'app_roi_cn_af_weekly',
 'app_roi_cn_af_monthly',
 'app_roi_overseas_af_daily',
 'app_roi_overseas_af_daily_wide']

In [13]:
# Semantic table search: top 5 matches for diamond information.
find_table(query_string="a table that diamond information", n_recommends=5)

['app_diamond_report',
 'dwd_diamond_info',
 'dim_diamond_info',
 'app_market_report_daily',
 'dim_user_payment_info']

In [14]:
# Semantic table search: top 5 matches for payment rewards (gid).
find_table(query_string="a table that payment rewards, such as gid information", n_recommends=5)

['dws_pay_user_package_tag1_2_3_payment_daily_di',
 'dws_pay_user_package_payment_daily_di',
 'dws_pay_user_package_tag1_2_payment_daily_di',
 'dws_pay_user_payment_daily_di',
 'dws_user_prop_payment_info']

In [15]:
# Semantic table search: top 5 matches for consume information.
find_table(query_string="a table contains consume information", n_recommends=5)

['dws_pay_user_package_payment_daily_di',
 'dim_user_payment_info',
 'dws_pay_user_package_tag1_2_3_payment_daily_di',
 'dws_pay_user_payment_daily_di',
 'dws_pay_user_prop_payment_daily_di']

In [16]:
# Semantic table search: top 10 matches for server open/close/offline info in the dim layer.
find_table(query_string="A table contains server open, close, offline information, dim layer", n_recommends=10)

['dim_server_info',
 'dwd_server_loop_close_time',
 'dim_user_info',
 'dim_user_payment_info',
 'dwd_operate_res_log',
 'dim_prop_info',
 'dim_diamond_info',
 'dwd_operate_troop_log',
 'app_market_report_daily',
 'dwd_server_loop_open_time']

In [17]:
# Semantic table search: top 10 matches for a dim-layer user table with user_type / simple_type.
find_table(query_string="A table describing user, and contains a column named user_type, and simple_type, dim layer", n_recommends=10)

['dim_user_payment_info',
 'dim_user_info',
 'app_diamond_report',
 'dim_diamond_info',
 'app_register_retention',
 'dwd_appsflyer_daily_cost',
 'dim_prop_info',
 'app_pet_report',
 'dwd_zombie_pet_info',
 'dim_package_tag']

In [18]:
# Semantic table search: top 10 matches for daily user payment tables.
find_table(query_string="A table contain user daily payment, probabaly has a date column", n_recommends=10)

['dws_pay_user_package_tag1_2_payment_daily_di',
 'app_user_pay',
 'dws_pay_user_package_payment_daily_di',
 'dws_pay_user_package_tag1_payment_daily_di',
 'dws_pay_user_package_tag1_2_3_payment_daily_di',
 'dws_pay_user_payment_daily_di',
 'dws_pay_user_prop_payment_daily_di',
 'app_newpay_user_cn_af_daily',
 'dim_user_payment_info',
 'dws_user_retention_info']

In [21]:
import multiprocessing

import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Report the compute available to this runtime before loading the large model.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda:0"
cpu_cores = multiprocessing.cpu_count()

print(f"CPU Cores: {cpu_cores}")
print(f"Using Device: {device}")

CPU Cores: 12
Using Device: cuda:0


## FLAN-T5 实验：模型看不懂 SQL 代码，效果很差（FLAN-T5 fails to understand the SQL DDL; the explanations are poor）

In [23]:
# FLAN-T5 XL; device_map="auto" lets transformers decide where to place the weights.
checkpoint = "google/flan-t5-xl"

tokenizer = T5Tokenizer.from_pretrained(checkpoint)
model = T5ForConditionalGeneration.from_pretrained(
    checkpoint,
    device_map="auto",
)

# Hide the download / loading logs.
clear_output()

In [34]:
import html

from IPython.display import HTML, display


def chat(texts, **generate_kwargs):
    """Generate a FLAN-T5 response for a prompt and display it in the notebook.

    Uses the global `tokenizer` and `model` loaded in the previous cell.

    Args:
        texts: Prompt passed to `tokenizer`. NOTE(review): no padding is
            requested, so a list of prompts with different lengths will fail
            to batch, and only the first generated sequence is displayed.
        **generate_kwargs: Overrides for the default decoding settings below.
            They are forwarded to `model.generate`; for example, pass
            `do_sample=True, temperature=0.9, top_k=150, top_p=0.92` to sample.

    Returns:
        None. The decoded text is displayed HTML-escaped.
    """
    # Send inputs to the model's device instead of hard-coding "cuda", so the
    # cell also works on a CPU-only runtime.
    input_ids = tokenizer(
        texts, return_tensors='pt'
    ).input_ids.to(model.device)

    # Beam search without sampling. The previous temperature / top_k / top_p
    # settings only take effect when do_sample=True, so they had no effect and
    # were dropped; the decoded output is unchanged.
    generation_settings = dict(
        min_length=100,
        max_new_tokens=600,
        length_penalty=0.2,
        num_beams=6,
        no_repeat_ngram_size=3,
        repetition_penalty=2.1,
    )
    generation_settings.update(generate_kwargs)

    outputs = model.generate(input_ids, **generation_settings)

    out_texts = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Escape so SQL fragments such as `a < b` or `<tag>` are shown literally
    # instead of being parsed (and swallowed) as HTML.
    display(HTML(html.escape(out_texts)))

In [37]:
# Ask FLAN-T5 to explain each table's CREATE TABLE statement. Each explanation
# is preceded by its table name so the outputs can be told apart (previously
# `table` was unused and the explanations were unlabeled).
ddl_sqls = df.ddl.tolist()
tables = df.table.tolist()

for table, create_sql in zip(tables, ddl_sqls):
    print(f"Table: {table}")
    chat(f"""
    The following is a create table sql, Please explain the sql in natural language,
    sql = {create_sql}

    """)