In [1]:
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM,AutoTokenizer

config = PeftConfig.from_pretrained("Rajan/training_run")
tokenizer = AutoTokenizer.from_pretrained("NumbersStation/nsql-350M")
base_model = AutoModelForCausalLM.from_pretrained("NumbersStation/nsql-350M")
model = PeftModel.from_pretrained(base_model, "Rajan/training_run").to("cuda")

  from .autonotebook import tqdm as notebook_tqdm


In [20]:

text = """CREATE TABLE departments (
    id integer,
    name character varying,
    location character varying
)

CREATE TABLE employees (
    id integer,
    salary integer,
    department_id integer,
    name character varying,
    position character varying
)


-- Using valid PostgreSQL, answer the following questions for the tables provided above.

-- Give me name and salary of employees whose salary is greater than 50000 and department is Account.
SELECT"""

input_ids = tokenizer(text, return_tensors="pt").input_ids

generated_ids = model.generate(input_ids.cuda(), max_length=1024)
opt = tokenizer.decode(generated_ids[0], skip_special_tokens=True)


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


In [25]:
def extract_and_correct_sql(text, correct = False):
    lines = text.splitlines()

    start_index = 0
    for i, line in enumerate(lines):
        if line.strip().upper().startswith('SELECT'):
            start_index = i
            break

    generated_sql = "\n".join(lines[start_index:])

    if correct:
        if not generated_sql.strip().endswith(';'):
            generated_sql = generated_sql.strip() + ';'

    return generated_sql

In [26]:
extract_and_correct_sql(opt)

'SELECT name, salary FROM employees WHERE salary > 50000 INTERSECT SELECT name, salary FROM employees WHERE department_id = 1'

In [5]:
import re 
class SQLExtractor:
    def __init__(self, text: str):
        self.text = text

    def extract_select_commands(self):
        # Regular expression to find all SELECT commands
        select_pattern = re.compile(r"SELECT\s.*?;", re.IGNORECASE | re.DOTALL)
        select_commands = select_pattern.findall(self.text)
        return select_commands


In [6]:
ext = SQLExtractor(text=opt)

In [7]:
ext.extract_select_commands()

[]

In [9]:
print(opt)

CREATE TABLE departments (
    id integer,
    name character varying,
    location character varying
)

CREATE TABLE employees (
    id integer,
    salary integer,
    department_id integer,
    name character varying,
    position character varying
)


-- Using valid PostgreSQL, answer the following questions for the tables provided above.

-- Give me name and salary of employees whose salary is greater than 50000.
SELECT name, salary FROM employees WHERE salary > 50000


In [14]:
import re

class SQLExtractor:
    def __init__(self, text: str):
        self.text = text

    def extract_select_commands(self):
        # Regular expression to find all SELECT commands ending with a semicolon
        select_pattern = re.compile(r"SELECT\s.*?;", re.IGNORECASE | re.DOTALL)
        select_commands = select_pattern.findall(self.text)
        return select_commands

# Example usage
sql_text = """
CREATE TABLE departments (
    id integer,
    name character varying,
    location character varying
);

CREATE TABLE employees (
    id integer,
    salary integer,
    department_id integer,
    name character varying,
    position character varying
);

-- Using valid PostgreSQL, answer the following questions for the tables provided above.

-- Give me name and salary of employees whose salary is greater than 50000.
SELECT name, salary FROM employees WHERE salary > 50000;
"""

extractor = SQLExtractor(sql_text)
select_commands = extractor.extract_select_commands()
for command in select_commands:
    print(command)


SELECT name, salary FROM employees WHERE salary > 50000;


In [12]:
import re

class SQLExtractor:
    def __init__(self, text: str):
        self.text = text

    def extract_select_commands(self):
        # Regular expression to find all SELECT commands
        select_pattern = re.compile(r"SELECT\s.*?(?=(SELECT|CREATE|\Z))", re.IGNORECASE | re.DOTALL)
        select_commands = select_pattern.findall(self.text)
        return [command.strip() for command in select_commands]

# Example usage
sql_text = """
CREATE TABLE departments (
    id integer,
    name character varying,
    location character varying
);

CREATE TABLE employees (
    id integer,
    salary integer,
    department_id integer,
    name character varying,
    position character varying
);

-- Using valid PostgreSQL, answer the following questions for the tables provided above.

-- Give me name and salary of employees whose salary is greater than 50000.
SELECT name, salary FROM employees WHERE salary > 50000
"""

extractor = SQLExtractor(sql_text)
select_commands = extractor.extract_select_commands()
for command in select_commands:
    print(command)



