In [98]:
from transformers import BertTokenizer, BertForSequenceClassification
import torch
import sqlite3

#Loading the BERT model and tokenizer.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')


# A function to predict the type of question.
def predict_question_type(question):
    inputs = tokenizer(question, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    predicted_class = torch.argmax(logits, dim=1).item()
    return predicted_class



# Creating a sample database with SQLite.
def create_database():
    conn = sqlite3.connect('students.db')
    c = conn.cursor()
    c.execute('''CREATE TABLE IF NOT EXISTS students 
                 (id INTEGER PRIMARY KEY, name TEXT, age INTEGER, grade TEXT)''')
    # Sample data.
    sample_data = [
        (1, 'Alice', 20, 'A'),
        (2, 'Bob', 21, 'B'),
        (3, 'Charlie', 22, 'A'),
        (4, 'David', 23, 'C'),
        (5, 'Eva', 20, 'B'),
    ]
    c.executemany("INSERT OR REPLACE INTO students VALUES (?, ?, ?, ?)", sample_data)
    conn.commit()
    conn.close()


# A function to convert a question into SQL.
def text_to_sql(question):
    question = question.lower()
    
    if "name" in question and "id" in question:
        return "SELECT name FROM students WHERE id = ?;"
    elif "age" in question and "of" in question:
        return "SELECT age FROM students WHERE name = ?;"
    elif "grade" in question:
        return "SELECT grade FROM students WHERE name = ?;"
    elif "show all students" in question:
        return "SELECT * FROM students;"
    else:
        return None



# Processing questions and displaying answers.
for question in questions:
    
    # Converting 's to a comprehensible form.
    if "'s" in question:
        question = question.replace("'s", " of")

    sql_query = text_to_sql(question)

    # If the question is general, it does not require a parameter.
    if sql_query and "WHERE" not in sql_query:
        param = None
    elif "id" in question.lower():
        try:
            param = int(question.split("ID ")[1].strip("?"))
        except (IndexError, ValueError):
            print(f"⚠️ Error in processing the ID in the question.: {question}")
            param = None
    elif "of" in question.lower():
        try:
            param = question.split("of ")[1].strip("?")
        except IndexError:
            print(f"⚠️  Error in processing the Name in the question.: {question}")
            param = None
    else:
        param = None

    if sql_query:
        print(f"Question: {question}")
        print(f"Generated SQL: {sql_query}")
        result = execute_sql(sql_query, param)
        print(f"Result: {result}\n")
    else:
        print(f"⚠️  The question was not processed correctly.: {question}\n")


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Question: What is the name of the student with ID 5?
Generated SQL: SELECT name FROM students WHERE id = ?;
Result: [('Eva',)]

Question: What is the age of Alice?
Generated SQL: SELECT age FROM students WHERE name = ?;
Result: [(20,)]

Question: What is Bob of grade?
Generated SQL: SELECT grade FROM students WHERE name = ?;
Result: []

Question: Show all students.
Generated SQL: SELECT * FROM students;
Result: [(1, 'Alice', 20, 'A'), (2, 'Bob', 21, 'B'), (3, 'Charlie', 22, 'A'), (4, 'David', 23, 'C'), (5, 'Eva', 20, 'B')]

