In [1]:
import torch
from peft import PeftModel, PeftConfig
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig, pipeline

USE_FINETUNED = True

model_id = "/home/ksaff/Desktop/ttyd/api/model/snapshots/82128714b6174570a64b3dd1f3e9c146bda26cf9"
peft_model_id = "/home/ksaff/Desktop/ttyd/api/fine-tuned_model"
config = PeftConfig.from_pretrained(peft_model_id)

bnb_config = BitsAndBytesConfig(
                                load_in_4bit=True,
                                bnb_4bit_use_double_quant=True,
                                bnb_4bit_quant_type="nf4",
                                bnb_4bit_compute_dtype=torch.bfloat16
                                )

model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    return_dict=True,
    quantization_config=bnb_config, device_map={'': 0}
)

if USE_FINETUNED:
    model = PeftModel.from_pretrained(model, peft_model_id)
    
tokenizer = AutoTokenizer.from_pretrained(
                                          config.base_model_name_or_path,
                                          )

  from .autonotebook import tqdm as notebook_tqdm


['/home/ksaff/Desktop/ttyd/api/model/snapshots/d3e967887d285343b8e239e26c6778c26931a536/model-00001-of-00007.safetensors', '/home/ksaff/Desktop/ttyd/api/model/snapshots/d3e967887d285343b8e239e26c6778c26931a536/model-00002-of-00007.safetensors', '/home/ksaff/Desktop/ttyd/api/model/snapshots/d3e967887d285343b8e239e26c6778c26931a536/model-00003-of-00007.safetensors', '/home/ksaff/Desktop/ttyd/api/model/snapshots/d3e967887d285343b8e239e26c6778c26931a536/model-00004-of-00007.safetensors', '/home/ksaff/Desktop/ttyd/api/model/snapshots/d3e967887d285343b8e239e26c6778c26931a536/model-00005-of-00007.safetensors', '/home/ksaff/Desktop/ttyd/api/model/snapshots/d3e967887d285343b8e239e26c6778c26931a536/model-00006-of-00007.safetensors', '/home/ksaff/Desktop/ttyd/api/model/snapshots/d3e967887d285343b8e239e26c6778c26931a536/model-00007-of-00007.safetensors']


Loading checkpoint shards: 100%|██████████| 7/7 [00:18<00:00,  2.59s/it]


In [16]:
from langchain.chains.conversation.memory import ConversationBufferMemory
from langchain.chains import ConversationChain


generation_config = GenerationConfig.from_pretrained(model_id)
generation_config.max_new_tokens = 512
generation_config.max_time = 5
generation_config.pad_token_id = 2


text_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    generation_config=generation_config,
)
 
llm = HuggingFacePipeline(pipeline=text_pipeline)
memory = ConversationBufferMemory()
conversation_buf = ConversationChain(
    llm=llm,
    memory=memory,)

The model 'PeftModelForCausalLM' is not supported for text-generation. Supported models are ['BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'LlamaForCausalLM', 'CodeGenForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'ElectraForCausalLM', 'ErnieForCausalLM', 'FalconForCausalLM', 'GitForCausalLM', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeoXJapaneseForCausalLM', 'GPTJForCausalLM', 'LlamaForCausalLM', 'MarianForCausalLM', 'MBartForCausalLM', 'MegaForCausalLM', 'MegatronBertForCausalLM', 'MistralForCausalLM', 'MptForCausalLM', 'MusicgenForCausalLM', 'MvpForCausalLM', 'OpenLlamaForCausalLM', 'OpenAIGPTLMHeadModel', 'OPTForCausalLM', 'PegasusForCausalLM', 'PersimmonForCausalLM', 'PLBartForCausalLM', 'Prophe

In [17]:
def create_prompt(question):
    database_scheme = """
CREATE TABLE Salaries
Id INTEGER PRIMARY KEY, -- Unique ID for each employee
EmployeeName VARCHAR, -- Name of the employee
JobTitle VARCHAR, -- Name of employees proffesion
BasePay NUMERIC, -- Base pay of employee
OvertimePay NUMERIC, -- Overtime pay of employee
OtherPay NUMERIC, -- Other pays of employee
Benefits NUMERIC, -- Benefits of employee
TotalPay NUMERIC, -- Total pay of employee
TotalPayBenefits NUMERIC, -- Sum of pay benefits of employee
Year INTEGER, -- Year data from row reffers to
"""
    text = (
    f"""### Task
    Generate an SQL query to answer the following question:
    {question}
### Database Schema
    This query will run on a database whose schema is represented in this string:""" 
    +
    database_scheme
    +
f"""### SQL
    Given the database schema, here is the SQL query that answers `{question}`:
    ```sql
    """
)
    return text

In [18]:
question = "Give me sum of all benefits in year 2013"
response = conversation_buf(create_prompt(question))

print(response['response'])
print('\nQuery only:',response['response'].split('SELECT')[1].split('\n')[0])

 I see 1 columns of type NUMERIC in the database. I will try to answer the question by aggregating on that column.
SELECT SUM(Benefits) FROM Salaries WHERE Year = 2013

### Output
    Given the database schema, here is the output of the

Query only:  SUM(Benefits) FROM Salaries WHERE Year = 2013


In [19]:
question_2 = 'Now do the same for total pay'
response_2 = conversation_buf(question_2)

print(response_2['response'])
print('\nQuery only:',response_2['response'].split('SELECT')[1].split('\n')[0])

  I see 1 columns of type NUMERIC in the database. I will try to answer the question by aggregating on that column.
SELECT SUM(TotalPay) FROM Salaries WHERE Year = 2013

### SQL
    Given the database schema, here is the SQL query

Query only:  SUM(TotalPay) FROM Salaries WHERE Year = 2013


In [20]:
question_3 = 'Now for OtherPay'
response_3 = conversation_buf(question_3)

print(response_3['response'])
print('\nQuery only:',response_3['response'].split('SELECT')[1].split('\n')[0])

   I see 1 columns of type NUMERIC in the database. I will try to answer the question by aggregating on that column.
SELECT SUM(OtherPay) FROM Salaries WHERE Year = 2013

### SQL
    Given the database schema, here is

Query only:  SUM(OtherPay) FROM Salaries WHERE Year = 2013
