In [None]:
!pip install torch transformers bitsandbytes accelerate sqlparse

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

In [None]:
# Hugging Face checkpoint used for text-to-SQL generation in this notebook.
model_name = "defog/sqlcoder-7b-2"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load the weights in float16. Per the original author's guidance, this path
# needs at least 15GB of GPU memory.
# NOTE(review): trust_remote_code=True permits Python code shipped in the model
# repository to run at load time -- confirm this checkpoint actually needs it.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    torch_dtype=torch.float16,
    device_map="auto",  # let accelerate decide where the weights are placed
    use_cache=True,
)


In [None]:
# Prompt template for the SQL model. `{question}` is filled in with
# `prompt.format(question=...)`. The template ends with an "### Answer" section
# whose last line is the `[SQL]` marker: the model continues from that marker,
# and callers that split the decoded text on "[SQL]" get only the query instead
# of the whole echoed prompt.
prompt = """### Task
Generate a SQL query to answer [QUESTION]{question}[/QUESTION]

### Instructions
- If you cannot answer the question with the available database schema, return 'Check EspnCricInfo'
- Remember that battingaverage is sum of Runs divided by number of matches
- Remember that Century is Runs greater than or equal to 100


### Database Schema
This query will run on a database whose schema is represented in this string:
CREATE TABLE Scores (
  MatchID INTEGER PRIMARY KEY, -- Unique ID for each Match
  Opposition VARCHAR(50), -- Name of cricket team
  Innings INTEGER, -- batted first or second
  Runs INTEGER  -- Runs Scored in the match
);

### Answer
Given the database schema, here is the SQL query that answers [QUESTION]{question}[/QUESTION]
[SQL]
"""

In [None]:
import sqlparse

def generate_query(question):
    """Generate a formatted SQL query that answers ``question``.

    The question is substituted into the module-level ``prompt`` template,
    the model greedily decodes up to 400 new tokens, and the completion is
    pretty-printed with ``sqlparse``.

    Args:
        question: Natural-language question to insert into the prompt.

    Returns:
        The generated SQL as a string, re-indented by ``sqlparse.format``.
        Only the model's completion is returned; the prompt text is not
        included.
    """
    updated_prompt = prompt.format(question=question)
    # Put the inputs on the model's own device rather than assuming "cuda".
    inputs = tokenizer(updated_prompt, return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=400,
        do_sample=False,
        num_beams=1,
    )

    # The generated sequence starts with the prompt tokens; slice them off so
    # the decoded text is only the completion, not the echoed prompt.
    prompt_length = inputs["input_ids"].shape[1]
    completion_ids = generated_ids[:, prompt_length:]
    outputs = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)

    # Release cached GPU memory between calls so repeated generations are less
    # likely to run out of memory (mainly a concern on Colab). Guarded because
    # torch.cuda.synchronize() raises when no CUDA device is available.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

    # Keep the "[SQL]" split for completions that still contain the marker.
    sql = outputs[0].split("[SQL]")[-1].strip()
    return sqlparse.format(sql, reindent=True)

In [None]:
# Example 1: an aggregate question; prints the SQL returned by generate_query.
question = "Total Runs Scored against each opposition"
generated_sql = generate_query(question)
print(generated_sql)

### Task Generate a SQL query to answer [QUESTION]Total Runs Scored against each opposition[/QUESTION] ### Instructions - If you cannot answer the question with the available database schema,
                                                                                                                                                                                        return 'Check EspnCricInfo' - Remember that battingaverage is sum of Runs divided by number of matches - Remember that Century is Runs greater than
or equal to 100 ### Database Schema This query will run on a database whose schema is represented in this string:
CREATE TABLE Scores (MatchID INTEGER PRIMARY KEY, -- Unique ID for each Match
 Opposition VARCHAR(50), -- Name of cricket team
 Innings INTEGER, -- batted first or second
 Runs INTEGER -- Runs Scored in the match
);


SELECT s.Opposition,
       SUM(s.Runs) AS total_runs,
       COUNT(s.MatchID) AS matches_played,
       AVG(s.Runs) AS batting_average,
       

In [None]:
# Example 2: a question that refers to the Innings column of the schema.
question = "Runs scored during First Innings"
generated_sql = generate_query(question)
print(generated_sql)

### Task Generate a SQL query to answer [QUESTION]Runs scored during First Innings[/QUESTION] ### Instructions - If you cannot answer the question with the available database schema,
                                                                                                                                                                               return 'Check EspnCricInfo' - Remember that battingaverage is sum of Runs divided by number of matches - Remember that Century is Runs greater than
or equal to 100 ### Database Schema This query will run on a database whose schema is represented in this string:
CREATE TABLE Scores (MatchID INTEGER PRIMARY KEY, -- Unique ID for each Match
 Opposition VARCHAR(50), -- Name of cricket team
 Innings INTEGER, -- batted first or second
 Runs INTEGER -- Runs Scored in the match
);


SELECT SUM(s.Runs) AS total_runs
FROM Scores s
WHERE s.Innings = 1;


In [None]:
# Example 3: uses the "battingaverage" definition given in the prompt instructions.
question = "BattingAverage"
generated_sql = generate_query(question)
print(generated_sql)

### Task Generate a SQL query to answer [QUESTION]BattingAverage[/QUESTION] ### Instructions - If you cannot answer the question with the available database schema,
                                                                                                                                                             return 'Check EspnCricInfo' - Remember that battingaverage is sum of Runs divided by number of matches - Remember that Century is Runs greater than
or equal to 100 ### Database Schema This query will run on a database whose schema is represented in this string:
CREATE TABLE Scores (MatchID INTEGER PRIMARY KEY, -- Unique ID for each Match
 Opposition VARCHAR(50), -- Name of cricket team
 Innings INTEGER, -- batted first or second
 Runs INTEGER -- Runs Scored in the match
);


SELECT CAST(SUM(s.Runs) AS FLOAT) / NULLIF(COUNT(s.MatchID), 0) AS BattingAverage
FROM Scores s;


In [None]:
# Example 4: uses the "Century" definition (Runs >= 100) given in the prompt instructions.
question = "List of Matches where scores was century"
generated_sql = generate_query(question)
print(generated_sql)

### Task Generate a SQL query to answer [QUESTION]List of Matches
where scores was century[/QUESTION] ### Instructions - If you cannot answer the question with the available database schema,
                                                                                                                     return 'Check EspnCricInfo' - Remember that battingaverage is sum of Runs divided by number of matches - Remember that Century is Runs greater than
  or equal to 100 ### Database Schema This query will run on a database whose schema is represented in this string:
  CREATE TABLE Scores (MatchID INTEGER PRIMARY KEY, -- Unique ID for each Match
 Opposition VARCHAR(50), -- Name of cricket team
 Innings INTEGER, -- batted first or second
 Runs INTEGER -- Runs Scored in the match
);


SELECT s.MatchID,
       s.Opposition,
       s.Innings,
       s.Runs
FROM Scores s
WHERE s.Runs >= 100
