In [41]:
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Directory of the fine-tuned checkpoint; the model weights and the tokenizer
# files are both read from this single location.
MODEL_DIR = "./fine_tuned_t5_sql"

# Restore the tokenizer and the seq2seq model from that checkpoint.
tokenizer = T5Tokenizer.from_pretrained(MODEL_DIR)
model = T5ForConditionalGeneration.from_pretrained(MODEL_DIR)

In [42]:
# Natural-language question and the database schema it refers to.
question = 'What is the born_state of the head whos department has more than 100 num_employees'
schema = "{'department': ['Department_ID', 'Name', 'Creation', 'Ranking', 'Budget_in_Billions', 'Num_Employees'], 'head': ['head_ID', 'name', 'born_state', 'age'], 'management': ['department_ID', 'head_ID', 'temporary_acting']}"

# NOTE(review): the prompt wording (including the "reponse" spelling) is
# presumably the exact template used during fine-tuning, so it is left
# untouched here -- confirm against the training script before correcting it.
sample_input = f"Given the following SQL Schema: {schema}. Provide a SQL query reponse for: {question}"

# Call the tokenizer directly (rather than `.encode`) so the attention mask is
# returned together with the token ids and can be handed to `generate`
# explicitly instead of being inferred.
encoded_input = tokenizer(
    sample_input,
    max_length=512,           # presumably the max_length used in training -- TODO confirm
    truncation=True,          # Truncate the prompt if it exceeds max_length
    return_tensors="pt",      # Return PyTorch tensors
)
input_ids = encoded_input["input_ids"]  # kept as a top-level name for later cells

model.eval()  # Switch off dropout and other train-only behaviour
output_ids = model.generate(
    input_ids=input_ids,
    attention_mask=encoded_input["attention_mask"],
    max_length=128,           # Upper bound on the generated sequence length
    num_beams=5,              # Beam search width
    # `temperature` is intentionally not passed: it only rescales logits when
    # sampling is enabled, and this call relies on beam search (assuming the
    # saved generation config leaves `do_sample` at its default of False).
    # NOTE(review): a repetition penalty this high may hurt SQL output, which
    # legitimately repeats table/column names -- worth re-evaluating.
    repetition_penalty=2.5,
    early_stopping=True,      # Stop once enough finished beam candidates exist
)

In [43]:
# Turn the first returned sequence back into text, dropping special tokens.
first_sequence = output_ids[0]
output_text = tokenizer.decode(first_sequence, skip_special_tokens=True)

# Show the prompt next to the query the model produced for it.
for label, text in (("Input:", sample_input), ("Generated SQL Query:", output_text)):
    print(label, text)


Input: Given the following SQL Schema: {'department': ['Department_ID', 'Name', 'Creation', 'Ranking', 'Budget_in_Billions', 'Num_Employees'], 'head': ['head_ID', 'name', 'born_state', 'age'], 'management': ['department_ID', 'head_ID', 'temporary_acting']}. Provide a SQL query reponse for: What is the born_state of the head whos department has more than 100 num_employees
Generated SQL Query: SELECT born_state FROM head WHERE department_id > 100 GROUP BY born_state HAVING count(*) > 100
