In [None]:
# Force UTF-8 as the preferred encoding to avoid encoding errors on Colab.
import locale
# The replacement accepts the same optional ``do_setlocale`` argument as the
# standard function, so callers using ``locale.getpreferredencoding(False)``
# keep working instead of raising TypeError.
locale.getpreferredencoding = lambda do_setlocale=True: 'UTF-8'

# Install the required libraries (IPython shell command; transformers is pinned to 4.35.2)
!pip install transformers==4.35.2

# Importação das bibliotecas e configuração do modelo
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import sqlparse  # Importar sqlparse para formatação de SQL
import time  # Importar biblioteca time para medir tempo de execução

# Load the tokenizer and the model from the same Hugging Face checkpoint
_checkpoint = "chatdb/natural-sql-7b"
tokenizer = AutoTokenizer.from_pretrained(_checkpoint)
model = AutoModelForCausalLM.from_pretrained(
    _checkpoint,
    torch_dtype=torch.float16,  # load the weights in half precision
    device_map="auto",
)

# Report whether a CUDA GPU is available
_gpu_available = torch.cuda.is_available()
print(f"GPU disponível: {_gpu_available}")

# Prompt template for the Natural-SQL-7b model.
# The {question} placeholder appears twice and is filled in with str.format
# by generate_natural_sql_query; the PostgreSQL schema below is part of the
# prompt text and is sent to the model as written.
# NOTE(review): the template ends with a '''sql marker (three single quotes),
# which is also the marker the decoded output is split on — confirm this is
# the fence the model expects.
prompt_template = """
# Task
Generate a SQL query to answer the following question: `{question}`

### PostgreSQL Database Schema
The query will run on a database with the following schema:

CREATE TABLE patients (
  id SERIAL PRIMARY KEY,
  name VARCHAR(100),
  gender VARCHAR(10),
  birthday DATE,
  contact_number VARCHAR(20),
  email VARCHAR(100),
  address VARCHAR(150),
  city_id INTEGER,
  enabled BOOLEAN
);

CREATE TABLE admissions (
  id SERIAL PRIMARY KEY,
  patient_id INTEGER,
  admission_date TIMESTAMP,
  discharge_date TIMESTAMP,
  admission_reason VARCHAR(200),
  hospital_id INTEGER,
  admission_status VARCHAR(50),
  treatment_outcome VARCHAR(100),
  complications VARCHAR(200),
  bed_id INTEGER,
  department_id INTEGER,
  cid_primary INTEGER,
  FOREIGN KEY (patient_id) REFERENCES patients(id),
  FOREIGN KEY (hospital_id) REFERENCES hospitals(id),
  FOREIGN KEY (department_id) REFERENCES departments(id),
  FOREIGN KEY (cid_primary) REFERENCES cid_codes(id)
);

CREATE TABLE hospitals (
  id SERIAL PRIMARY KEY,
  name VARCHAR(100),
  location VARCHAR(150),
  phone_number VARCHAR(20)
);

CREATE TABLE beds (
  id SERIAL PRIMARY KEY,
  hospital_id INTEGER,
  bed_type VARCHAR(50),
  availability_status VARCHAR(50),
  accomodation_type VARCHAR(50),
  FOREIGN KEY (hospital_id) REFERENCES hospitals(id)
);

CREATE TABLE treatments (
  id SERIAL PRIMARY KEY,
  admission_id INTEGER,
  treatment_description TEXT,
  start_date DATE,
  end_date DATE,
  doctor_id INTEGER,
  FOREIGN KEY (admission_id) REFERENCES admissions(id),
  FOREIGN KEY (doctor_id) REFERENCES users(id)
);

CREATE TABLE users (
  id SERIAL PRIMARY KEY,
  name VARCHAR(100),
  role VARCHAR(50),
  email VARCHAR(100),
  enabled BOOLEAN,
  chat_enabled BOOLEAN
);

CREATE TABLE exams (
  id SERIAL PRIMARY KEY,
  admission_id INTEGER,
  exam_type VARCHAR(100),
  exam_date DATE,
  results TEXT,
  FOREIGN KEY (admission_id) REFERENCES admissions(id)
);

CREATE TABLE patient_history (
  id SERIAL PRIMARY KEY,
  patient_id INTEGER,
  history_description TEXT,
  date_recorded DATE,
  FOREIGN KEY (patient_id) REFERENCES patients(id)
);

CREATE TABLE departments (
  id SERIAL PRIMARY KEY,
  name VARCHAR(100),
  description TEXT
);

CREATE TABLE medications (
  id SERIAL PRIMARY KEY,
  name VARCHAR(100),
  dosage VARCHAR(50),
  administration_method VARCHAR(50)
);

CREATE TABLE treatment_medications (
  id SERIAL PRIMARY KEY,
  treatment_id INTEGER,
  medication_id INTEGER,
  dosage VARCHAR(50),
  frequency VARCHAR(50),
  start_date TIMESTAMP,
  end_date TIMESTAMP,
  FOREIGN KEY (treatment_id) REFERENCES treatments(id),
  FOREIGN KEY (medication_id) REFERENCES medications(id)
);

CREATE TABLE cid_codes (
  id SERIAL PRIMARY KEY,
  code VARCHAR(10),
  description TEXT
);

### SQL
Here is the SQL query that answers the question: `{question}`
'''sql
"""

# Generate the SQL query for a question and measure how long the conversion takes
def generate_natural_sql_query(question):
    """Translate a natural-language question into SQL with the loaded model.

    The question is substituted into ``prompt_template``, tokenized, and
    passed to ``model.generate`` with greedy decoding (``do_sample=False``,
    ``num_beams=1``) and at most 400 new tokens.

    Args:
        question: Natural-language question to answer.

    Returns:
        A ``(sql, elapsed_time)`` tuple. ``sql`` is the stripped text that
        follows the last ``'''sql`` marker in the decoded output.
        ``elapsed_time`` is the number of seconds spent on prompt formatting,
        tokenization and generation (decoding is not included).
    """
    # perf_counter is monotonic, so the measurement is not affected by
    # wall-clock adjustments the way time.time() can be.
    start_time = time.perf_counter()
    prompt = prompt_template.format(question=question)
    # Move the inputs to the device the model actually lives on instead of
    # assuming "cuda"; the model is loaded with device_map="auto".
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # NOTE(review): 100001 is hard-coded as both the EOS and padding token id;
    # presumably it is the checkpoint's end-of-sequence id — confirm against
    # tokenizer.eos_token_id.
    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=100001,
        pad_token_id=100001,
        max_new_tokens=400,
        do_sample=False,
        num_beams=1,
    )
    elapsed_time = time.perf_counter() - start_time

    outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    # The decoded text still contains the prompt, so keep only what follows
    # the last '''sql marker.
    return outputs[0].split("'''sql")[-1].strip(), elapsed_time

# List of the 25 English questions used for testing
questions = [
    "How many patients were admitted in the last month?",
    "Which patients were hospitalized more than once in the last year?",
    "How many admissions had complications during treatment?",
    "Which admissions resulted in death in the last quarter?",
    "What is the bed occupancy rate at 'São José' hospital?",
    "What are the most common ICDs among patients admitted to the Cardiology department?",
    "How many patients under 12 years old were admitted to the Pediatrics department?",
    "What is the average length of stay per patient?",
    "What are the main reasons for admission at the Luz hospital?",
    "Which patients were discharged after improvement following a treatment with antibiotics?",
    "What is the hospital readmission rate within 30 days after discharge?",
    "What were the top 5 ICDs for admissions in the last semester?",
    "What are the 3 most prescribed medications for patients with heart failure?",
    "How many imaging exams were performed in the last 6 months?",
    "Which exams resulted in a conduct to be evaluated?",
    "What is the number of patients admitted through emergency and scheduled consultations?",
    "Who are the 10 patients with the most admissions in the past year?",
    "Which doctors performed the most treatments this month?",
    "Which were the admissions that occurred at the 'Universitário de Brasília' hospital?",
    "How many patients were discharged with complications during hospitalization?",
    "How many pediatric beds are currently occupied?",
    "Which patients have a history of hypertension and were admitted with complications?",
    "What treatments were applied to patients with diabetes mellitus?",
    "Which hospitals had the most admissions in the Emergency department?",
    "Which admissions resulted in transfers to another hospital in the last quarter?"
]

# Generate the SQL query for each question and report the conversion time
for i, question in enumerate(questions):
    print(f"Pergunta {i+1}: {question}")
    generated_sql, conversion_time = generate_natural_sql_query(question)
    formatted_sql = sqlparse.format(generated_sql, reindent=True)
    # Print the header and the query separately: passing both to a single
    # print() call inserts the default separator (a space) right after the
    # "\n", which shifts the first line of the query one column to the right.
    print("SQL gerada:")
    print(formatted_sql)
    print(f"Tempo de conversão: {conversion_time:.2f} segundos\n")


Collecting transformers==4.35.2
  Downloading transformers-4.35.2-py3-none-any.whl.metadata (123 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/123.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m123.5/123.5 kB[0m [31m10.8 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers<0.19,>=0.14 (from transformers==4.35.2)
  Downloading tokenizers-0.15.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Downloading transformers-4.35.2-py3-none-any.whl (7.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.9/7.9 MB[0m [31m107.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading tokenizers-0.15.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.6/3.6 MB[0m [31m52.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tokenizers, transformers
  Attempting uninstall: tokeniz

  _torch_pytree._register_pytree_node(
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/4.29k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/4.61M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/482 [00:00<?, ?B/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


config.json:   0%|          | 0.00/719 [00:00<?, ?B/s]

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


model.safetensors.index.json:   0%|          | 0.00/22.5k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/3.85G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

GPU disponível: True
Pergunta 1: How many patients were admitted in the last month?




SQL gerada:
 SELECT COUNT(*)
FROM admissions
WHERE admission_date >= NOW() - INTERVAL '1 month';
Tempo de conversão: 4.35 segundos

Pergunta 2: Which patients were hospitalized more than once in the last year?
SQL gerada:
 SELECT p.id,
       p.name
FROM patients p
JOIN admissions a ON p.id = a.patient_id
WHERE a.admission_date > CURRENT_DATE - INTERVAL '1 year'
GROUP BY p.id
HAVING COUNT(a.id) > 1;
Tempo de conversão: 4.19 segundos

Pergunta 3: How many admissions had complications during treatment?
SQL gerada:
 SELECT COUNT(*)
FROM admissions
WHERE complications IS NOT NULL;
Tempo de conversão: 1.07 segundos

Pergunta 4: Which admissions resulted in death in the last quarter?
SQL gerada:
 SELECT *
FROM admissions
WHERE treatment_outcome ILIKE '%death%'
  AND admission_date >= date_trunc('quarter', CURRENT_DATE) - INTERVAL '1 quarter'
  AND admission_date < date_trunc('quarter', CURRENT_DATE);
Tempo de conversão: 3.66 segundos

Pergunta 5: What is the bed occupancy rate at 'São José' 