### Busca textual em bases de dados
#### text-to-sql

In [1]:
import re
import sqlite3
from typing import List

import pandas as pd
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

### Dados

In [2]:
# Weekly price survey spreadsheet; the first 9 rows of the sheet are skipped
# (presumably a preamble above the header row -- confirm against the file).
source_xlsx = 'dados/revendas_lpc_2023-02-05_2023-02-11.xlsx'
df = pd.read_excel(source_xlsx, skiprows=9)

In [3]:
# Per-column summary: names, non-null counts and dtypes.
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 19302 entries, 0 to 19301
Data columns (total 15 columns):
 #   Column             Non-Null Count  Dtype         
---  ------             --------------  -----         
 0   CNPJ               19302 non-null  int64         
 1   RAZÃO              19302 non-null  object        
 2   FANTASIA           8184 non-null   object        
 3   ENDEREÇO           19302 non-null  object        
 4   NÚMERO             19298 non-null  object        
 5   COMPLEMENTO        4607 non-null   object        
 6   BAIRRO             19265 non-null  object        
 7   CEP                19302 non-null  int64         
 8   MUNICÍPIO          19302 non-null  object        
 9   ESTADO             19302 non-null  object        
 10  BANDEIRA           19302 non-null  object        
 11  PRODUTO            19302 non-null  object        
 12  UNIDADE DE MEDIDA  19302 non-null  object        
 13  PREÇO DE REVENDA   19302 non-null  float64       
 14  DATA D

In [4]:
# Preview the first three rows.
df.head(3)

Unnamed: 0,CNPJ,RAZÃO,FANTASIA,ENDEREÇO,NÚMERO,COMPLEMENTO,BAIRRO,CEP,MUNICÍPIO,ESTADO,BANDEIRA,PRODUTO,UNIDADE DE MEDIDA,PREÇO DE REVENDA,DATA DA COLETA
0,61602199002409,COMPANHIA ULTRAGAZ S A,ULTRAGAZ,RUA AMARO CASTRO LIMA,1852,,VILA NOVA CAMPO GRANDE,79106361,CAMPO GRANDE,MATO GROSSO DO SUL,ULTRAGAZ,GLP,R$ / 13 kg,110.0,2023-02-07
1,61602199006587,COMPANHIA ULTRAGAZ S A,ULTRAGAZ,AVENIDA CAIRU,989,,NAVEGANTES,90230031,PORTO ALEGRE,RIO GRANDE DO SUL,ULTRAGAZ,GLP,R$ / 13 kg,105.0,2023-02-07
2,3188000121,COMPETRO COMERCIO E DISTRIBUICAO DE DERIVADOS ...,COMPETRO,RUA HUMBERTO DE CAMPOS,306,,JARDIM ZULMIRA,18061000,SOROCABA,SAO PAULO,BRANCA,ETANOL,R$ / litro,3.25,2023-02-06


In [5]:
# (number of rows, number of columns)
df.shape

(19302, 15)

In [6]:
# English, SQL-friendly names for the 15 spreadsheet columns, listed in the
# same order as the original Portuguese headers. These names become both the
# SQLite column names and the schema handed to the model, so each one must be
# a bare identifier with no surrounding whitespace.
table_cols = [
    'CNPJ', 'COMPANY_NAME', 'FANTASY_NAME', 'ADDRESS', 'NUMBER', 'COMPLEMENT',
    'NEIGHBORHOOD', 'ZIP_CODE', 'MUNICIPALITY', 'STATE', 'FLAG', 'PRODUCT',
    'UNIT_OF_MEASUREMENT', 'RETAIL_PRICE', 'DATE_OF_COLLECTION',
]
df.columns = table_cols

In [7]:
# Load the DataFrame into a throwaway in-memory SQLite database so that the
# generated SQL can be executed against it.
table_name = 'revendas_lpc'
conn = sqlite3.connect(database=':memory:')
df.to_sql(table_name, con=conn)

19302

In [8]:
# Sanity check: the table row count should match the DataFrame length.
cursor = conn.cursor()
ret = cursor.execute(f'SELECT COUNT(*) AS qtd FROM {table_name}').fetchall()
print(ret[0])

(19302,)


### Modelo

In [9]:
# Model card: https://huggingface.co/juierror/text-to-sql-with-table-schema
# A single constant keeps the tokenizer and the model on the same checkpoint.
MODEL_NAME = 'juierror/text-to-sql-with-table-schema'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

In [28]:
def prepare_input(question: str, table: List[str]):
    """Build the model prompt and return its token ids.

    The prompt has the form ``question: <question> table: <col1>,<col2>,...``.

    Args:
        question: Natural-language question to translate into SQL.
        table: Column names of the table being queried.

    Returns:
        The ``input_ids`` produced by the module-level ``tokenizer`` when
        called with ``max_length=700`` and ``return_tensors='pt'``.
    """
    columns = ','.join(table)
    prompt = f'question: {question} table: {columns}'
    encoded = tokenizer(prompt, max_length=700, return_tensors='pt')
    return encoded.input_ids

def replace_table_placeholder(sql: str, name: str) -> str:
    """Replace the stand-alone word ``table`` in ``sql`` with ``name``.

    Only whole-word occurrences are replaced, so identifiers that merely
    contain the letters "table" (``timetable``, ``table_id``) are left intact.
    """
    # A callable replacement inserts `name` literally, without interpreting
    # backslashes in it as regex group references.
    return re.sub(r'\btable\b', lambda _match: name, sql)


def inference(question: str, table: List[str]) -> str:
    """Translate a natural-language question into a SQL string.

    Args:
        question: Natural-language question.
        table: Column names of the table being queried.

    Returns:
        The decoded model output, with the generic ``table`` placeholder
        replaced by the module-level ``table_name``.
    """
    input_data = prepare_input(question=question, table=table).to(model.device)
    # NOTE(review): top_k is usually a sampling option; with beam search and no
    # do_sample it may have no effect -- confirm against the generate() docs.
    outputs = model.generate(inputs=input_data, num_beams=10, top_k=10, max_length=700)
    result = tokenizer.decode(token_ids=outputs[0], skip_special_tokens=True)
    return replace_table_placeholder(result, table_name)

#### Questões

In [44]:
def get_response(question, exec_query=False):
    """Print the SQL generated for ``question`` and, optionally, its result.

    Args:
        question: Natural-language question, passed to ``inference`` together
            with the module-level ``table_cols``.
        exec_query: When true, the generated SQL is also run on the
            module-level ``cursor`` and the fetched rows are printed.
    """
    generated_sql = inference(question=question, table=table_cols)
    print('SQL:', generated_sql)

    if not exec_query:
        return

    # The generated statement is executed verbatim, without any validation.
    rows = cursor.execute(generated_sql).fetchall()
    print('Resposta:', rows)

In [45]:
# Single-column lookup filtered by CNPJ; the generated SQL is also executed.
get_response("return the 'ADDRESS' of company whose CNPJ is 61602199002409", True)

SQL: SELECT ADDRESS FROM revendas_lpc WHERE CNPJ = 61602199002409
Resposta: [('RUA AMARO CASTRO LIMA',)]


In [46]:
# Same CNPJ filter, asking for the price column; the generated SQL is also executed.
get_response("what is the RETAIL PRICE of CNPJ 61602199002409", True)

SQL: SELECT RETAIL_PRICE FROM revendas_lpc WHERE CNPJ = 61602199002409
Resposta: [(110.0,)]


In [49]:
# Aggregate question; exec_query is left at its default (False), so the
# generated SQL is only printed, not executed.
get_response("return the average RETAIL PRICE when the FLAG is ULTRAGAZ")

SQL: SELECT AVG RETAIL_PRICE FROM revendas_lpc WHERE FLAG = ultragaz


In [51]:
# Hand-written form of the aggregate query: AVG(...) is called with
# parentheses and the FLAG value is a quoted string literal.
ret = cursor.execute("SELECT AVG(RETAIL_PRICE) FROM revendas_lpc WHERE FLAG = 'ULTRAGAZ'").fetchall()
print('Resposta:', ret)

Resposta: [(109.75725047080972,)]


In [52]:
# Close the SQLite connection opened on the in-memory database.
conn.close()

### Conclusão