In [None]:
!pip install torch transformers bitsandbytes accelerate sqlparse

import time

import sqlparse
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Check that a CUDA GPU is present; tokenization and generation below run on it.
gpu_available = torch.cuda.is_available()
print(f"GPU disponível: {gpu_available}")
if not gpu_available:
    # Fail with a clear message instead of letting get_device_properties(0)
    # raise an opaque error.
    raise RuntimeError("Nenhuma GPU CUDA disponível; este notebook requer uma GPU.")

# NOTE: this is the TOTAL memory of GPU 0, not the currently free memory.
# It is only used to choose the precision the model is loaded in.
available_memory = torch.cuda.get_device_properties(0).total_memory
print(f"Memória total da GPU: {available_memory / 1e9:.2f} GB")

model_name = "defog/sqlcoder-7b-2"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Loading options shared by both precision strategies.
load_kwargs = {
    "trust_remote_code": True,
    "device_map": "auto",
    "use_cache": True,
}
if available_memory > 15e9:
    # Enough room for the 7B model in half precision (roughly 7B params x 2 bytes).
    load_kwargs["torch_dtype"] = torch.float16
else:
    # Smaller GPUs: load 8-bit weights through bitsandbytes. Passing
    # `load_in_8bit=True` straight to from_pretrained is deprecated in favor
    # of a BitsAndBytesConfig in `quantization_config`.
    load_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)

# Prompt for defog/sqlcoder: the task, answering instructions, and the database
# schema as DDL. "{question}" is filled in with str.format, so any literal brace
# added to this template must be written doubled ("{{" / "}}").
# The DDL must be valid SQL: the model imitates it when writing queries.
prompt_template = """### Task
Generate a SQL query to answer [QUESTION]{question}[/QUESTION]

### Instructions
- If you do not understand the question or cannot generate a valid SQL query based on the database schema, return 'Não entendi a pergunta. Pode fazê-la novamente, por gentileza.'
- Remember that certain strings are in Portuguese (e.g., 'Melhora', 'Óbito', 'Alta' for outcomes).
- Follow the database schema closely, and use the appropriate joins where necessary.

### Database Schema
This query will run on a database whose schema is represented in this string:

CREATE TABLE patients (
  id SERIAL PRIMARY KEY, -- Unique ID for each patient
  name VARCHAR(100), -- Patient's full name
  gender VARCHAR(10), -- Gender of the patient (Male, Female, Other)
  birthday DATE, -- Date of birth of the patient
  contact_number VARCHAR(20), -- Patient's contact number
  email VARCHAR(100), -- Patient's email address
  address VARCHAR(150), -- Residential address of the patient
  city_id INTEGER, -- City ID referencing city table
  enabled BOOLEAN -- Indicates if the patient is active in the system
);

CREATE TABLE admissions (
  id SERIAL PRIMARY KEY, -- Unique ID for each admission
  patient_id INTEGER, -- ID of the patient admitted
  admission_date TIMESTAMP, -- Date and time of admission
  discharge_date TIMESTAMP, -- Date and time of discharge
  admission_reason VARCHAR(200), -- Reason for admission
  hospital_id INTEGER, -- ID of the hospital where the patient was admitted
  admission_status VARCHAR(50), -- Status of the admission (e.g., "Admitted", "Discharged", "Transferred")
  treatment_outcome VARCHAR(100), -- Outcome of the treatment (e.g., "Melhora", "Óbito", "Alta", in Portuguese)
  complications VARCHAR(200), -- Complications encountered during treatment
  bed_id INTEGER, -- ID of the bed assigned during admission
  department_id INTEGER, -- Department where the patient was admitted
  cid_primary INTEGER, -- Primary diagnosis of the patient (references CID codes)
  FOREIGN KEY (patient_id) REFERENCES patients(id), -- Join with patients table
  FOREIGN KEY (hospital_id) REFERENCES hospitals(id), -- Join with hospitals table
  FOREIGN KEY (department_id) REFERENCES departments(id), -- Join with departments table
  FOREIGN KEY (cid_primary) REFERENCES cid_codes(id) -- Join with CID codes table
);

CREATE TABLE hospitals (
  id SERIAL PRIMARY KEY, -- Unique ID for each hospital
  name VARCHAR(100), -- Name of the hospital
  location VARCHAR(150), -- Address or location of the hospital
  phone_number VARCHAR(20) -- Contact number of the hospital
);

CREATE TABLE beds (
  id SERIAL PRIMARY KEY, -- Unique ID for each bed
  hospital_id INTEGER, -- ID of the hospital to which the bed belongs
  bed_type VARCHAR(50), -- Type of bed (e.g., "ICU", "General", "Berçário")
  bed_status VARCHAR(50), -- Availability status (e.g., "Ocupado", "Disponível", in Portuguese)
  accomodation_type VARCHAR(50), -- Type of accommodation (e.g., "Berçário", "Pediátrico")
  FOREIGN KEY (hospital_id) REFERENCES hospitals(id) -- Join with hospitals table
);

CREATE TABLE treatments (
  id SERIAL PRIMARY KEY, -- Unique ID for each treatment
  admission_id INTEGER, -- ID of the admission associated with the treatment
  treatment_description TEXT, -- Detailed description of the treatment
  start_date DATE, -- Start date of the treatment
  end_date DATE, -- End date of the treatment
  professional_id INTEGER, -- ID of the doctor responsible for the treatment
  FOREIGN KEY (admission_id) REFERENCES admissions(id), -- Join with admissions table
  FOREIGN KEY (professional_id) REFERENCES users(id) -- Join with users table
);

CREATE TABLE users (
  id SERIAL PRIMARY KEY, -- Unique ID for each user in the system
  name VARCHAR(100), -- Name of the user (doctor, nurse, or staff)
  role VARCHAR(50), -- Role in the hospital (e.g., "doctor", "nurse", "admin")
  email VARCHAR(100), -- Email address of the user
  enabled BOOLEAN, -- Indicates if the user is active in the system
  chat_enabled BOOLEAN -- Indicates if the user has access to chat functionality
);

CREATE TABLE departments (
  id SERIAL PRIMARY KEY, -- Unique ID for each department
  name VARCHAR(100), -- Name of the department (e.g., "Cardiologia", "Pediatria")
  description TEXT -- Description of the department
);

CREATE TABLE medications (
  id SERIAL PRIMARY KEY, -- Unique ID for each medication
  name VARCHAR(100), -- Name of the medication
  dosage VARCHAR(50), -- Dosage of the medication (e.g., "500mg", "2ml")
  administration_method VARCHAR(50) -- Method of administration (e.g., "Oral", "Intravenous")
);

CREATE TABLE treatment_medications (
  id SERIAL PRIMARY KEY, -- Unique ID for each treatment-medication association
  treatment_id INTEGER, -- ID of the treatment associated
  medication_id INTEGER, -- ID of the medication used
  dosage VARCHAR(50), -- Dosage administered
  frequency VARCHAR(50), -- Frequency of administration (e.g., "Once daily", "Twice daily")
  start_date TIMESTAMP, -- Start date of medication administration
  end_date TIMESTAMP, -- End date of medication administration
  FOREIGN KEY (treatment_id) REFERENCES treatments(id), -- Join with treatments table
  FOREIGN KEY (medication_id) REFERENCES medications(id) -- Join with medications table
);

CREATE TABLE exams (
  id SERIAL PRIMARY KEY, -- Unique ID for each exam
  admission_id INTEGER, -- ID of the admission associated with the exam
  exam_type VARCHAR(100), -- Type of exam (e.g., "ECG", "Hemograma Completo", "Tomografia Computadorizada")
  result TEXT, -- Result of the exam (e.g., "Conduta a Ser Avaliada", in Portuguese)
  exam_date TIMESTAMP, -- Date of the exam
  professional_id INTEGER, -- ID of the professional who conducted the exam
  FOREIGN KEY (admission_id) REFERENCES admissions(id), -- Join with admissions table
  FOREIGN KEY (professional_id) REFERENCES users(id) -- Join with users table
);

CREATE TABLE cid_codes (
  id SERIAL PRIMARY KEY, -- Unique ID for each CID code
  code VARCHAR(10), -- CID code (e.g., "A41", "J18", "I50")
  description TEXT -- Description of the CID (e.g., "Septicemia", "Infarto agudo do miocárdio", "Diabetes mellitus não insulinodependente")
);

/* CID codes used in the system:
A41    Septicemia
J18    Pneumonia não especificada
I21    Infarto agudo do miocárdio
I50    Insuficiência cardíaca
N17    Insuficiência renal aguda
K35    Apendicite aguda
K57    Doença diverticular do intestino
N39    Infecção do trato urinário
J45    Asma aguda severa
E11    Diabetes mellitus não insulinodependente com complicações
*/

### Answer
Given the database schema, here is the SQL query that answers [QUESTION]{question}[/QUESTION]
[SQL]
"""


# Generate a SQL query for one question and measure how long it takes.
def generate_query_with_time(question):
    """Translate a natural-language question into SQL and time the generation.

    The question is inserted into ``prompt_template`` and completed greedily
    (``do_sample=False``, ``num_beams=1``) by the module-level ``model``.

    Args:
        question: Natural-language question to convert.

    Returns:
        Tuple ``(sql, elapsed_time)``. ``sql`` is one of:

        - the generated query, re-indented by ``sqlparse``;
        - the message
          "Não entendi a pergunta. Pode fazê-la novamente, por gentileza."
          when the model produced nothing, or produced that message itself
          (as the prompt instructs it to).

        ``elapsed_time`` is the number of seconds spent tokenizing and
        generating.
    """
    fallback_message = "Não entendi a pergunta. Pode fazê-la novamente, por gentileza."
    prompt = prompt_template.format(question=question)
    # perf_counter is monotonic, which makes it suitable for measuring intervals.
    start_time = time.perf_counter()

    # Put the inputs on the model's device instead of hard-coding "cuda".
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=400,
        do_sample=False,
        num_beams=1,
    )

    elapsed_time = time.perf_counter() - start_time

    # The returned ids start with the prompt tokens. Decode only the completion.
    # The prompt itself ends with "[SQL]", so checking the full decoded text for
    # that marker always succeeded and the fallback could never be returned.
    prompt_length = inputs["input_ids"].shape[-1]
    completion = tokenizer.decode(
        generated_ids[0][prompt_length:], skip_special_tokens=True
    )
    torch.cuda.empty_cache()

    # Drop anything after a closing "[/SQL]" tag, if the model emits one.
    sql_query = completion.split("[/SQL]")[0].strip()
    if not sql_query or fallback_message in sql_query:
        return fallback_message, elapsed_time
    return sqlparse.format(sql_query, reindent=True), elapsed_time

# Benchmark questions (in English) fed one by one to the text-to-SQL pipeline.
questions = [
    "How many patients were admitted in the last month?",
    "Which patients were hospitalized more than once in the last year?",
    "How many admissions had complications during treatment?",
    "Which admissions resulted in death in the last quarter?",
    "What is the bed occupancy rate at 'São José' hospital?",
    "What are the most common ICDs among patients admitted to the Cardiology department?",
    "How many patients under 12 years old were admitted to the Pediatrics department?",
    "What is the average length of stay per patient?",
    "What are the main reasons for admission at Luz hospital?",
    "Which patients were discharged after improvement following antibiotic treatment?",
    "What is the hospital readmission rate within 30 days after discharge?",
    "What were the top 5 ICDs for admissions in the last semester?",
    "What are the 3 most prescribed medications for patients with heart failure?",
    "How many imaging exams were performed in the last 6 months?",
    "Which exams resulted in a conduct to be evaluated?",
    "What is the number of patients admitted through emergency and scheduled consultations?",
    "Who are the 10 patients with the most admissions in the last year?",
    "Which doctors performed the most treatments this month?",
    "Which admissions occurred at the 'Universitário de Brasília' hospital?",
    "How many patients were discharged with complications during hospitalization?",
    "How many pediatric beds are currently occupied?",
    "Which patients have a history of hypertension and were admitted with complications?",
    "What treatments were applied to patients with diabetes mellitus?",
    "Which hospitals had the most admissions in the Emergency department?",
    "Which admissions resulted in transfers to another hospital in the last quarter?",
]

# Convert every question, reporting the generated SQL and the time it took.
for position, question in enumerate(questions, start=1):
    print(f"Pergunta {position}: {question}")
    generated_sql, conversion_time = generate_query_with_time(question)
    print("SQL gerada:\n", generated_sql)
    print(f"Tempo de conversão: {conversion_time:.2f} segundos\n")


GPU disponível: True
Memória disponível: 23.802544128 GB


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



Pergunta 1: How many patients were admitted in the last month?
SQL gerada:
 SELECT COUNT(DISTINCT a.patient_id) AS number_of_patients
FROM admissions a
WHERE a.admission_date >= (CURRENT_DATE - interval '1 month');
Tempo de conversão: 58.40 segundos

Pergunta 2: Which patients were hospitalized more than once in the last year?
SQL gerada:
 SELECT p.name,
       COUNT(a.id) AS admission_count
FROM patients p
JOIN admissions a ON p.id = a.patient_id
WHERE a.admission_date >= (CURRENT_DATE - interval '1 year')
GROUP BY p.name
HAVING COUNT(a.id) > 1
ORDER BY admission_count DESC NULLS LAST;
Tempo de conversão: 102.03 segundos

Pergunta 3: How many admissions had complications during treatment?
SQL gerada:
 SELECT COUNT(a.id)
FROM admissions a
WHERE a.complications IS NOT NULL
  AND a.complications != '';
Tempo de conversão: 33.75 segundos

Pergunta 4: Which admissions resulted in death in the last quarter?
SQL gerada:
 SELECT a.id,
       a.patient_id,
       a.admission_date,
       a.dis