In [1]:
import os
os.environ["WANDB_PROJECT"]="tinyllama_Text2Sql_lora"
from enum import Enum
from functools import partial
import pandas as pd
import torch

from transformers import AutoModelForCausalLM, LlamaTokenizer, AutoTokenizer, TrainingArguments, set_seed
from datasets import load_dataset
from trl import SFTTrainer
from peft import get_peft_model, LoraConfig, TaskType

import re

seed = 42
set_seed(seed)

In [2]:
model_name = "TinyLlama/TinyLlama-1.1B-intermediate-step-1195k-token-2.5T"

In [3]:
tokenizer = AutoTokenizer.from_pretrained(model_name)

tokenizer_config.json:   0%|          | 0.00/776 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

In [4]:
dataset_name = "gretelai/synthetic_text_to_sql"
dataset = load_dataset(dataset_name)

README.md:   0%|          | 0.00/8.18k [00:00<?, ?B/s]

(…)nthetic_text_to_sql_train.snappy.parquet:   0%|          | 0.00/32.4M [00:00<?, ?B/s]

(…)ynthetic_text_to_sql_test.snappy.parquet:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/100000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/5851 [00:00<?, ? examples/s]

In [5]:
dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'domain', 'domain_description', 'sql_complexity', 'sql_complexity_description', 'sql_task_type', 'sql_task_type_description', 'sql_prompt', 'sql_context', 'sql', 'sql_explanation'],
        num_rows: 100000
    })
    test: Dataset({
        features: ['id', 'domain', 'domain_description', 'sql_complexity', 'sql_complexity_description', 'sql_task_type', 'sql_task_type_description', 'sql_prompt', 'sql_context', 'sql', 'sql_explanation'],
        num_rows: 5851
    })
})

In [6]:
# def get_schema(entry):
#     schema = ''
#     for stmt in entry['sql_context'].split(';'):
#         if 'CREATE TABLE' in stmt:
#             print(stmt)
#             table_name = stmt.split()[2]
#             columns = stmt.split('(', 1)[1].rsplit(')', 1)[0]
#             col_defs = columns.split(',')
#             schema = schema + "Table: " + table_name + "\n"
#             for col in col_defs:
#                 parts = col.strip().split()
#                 if len(parts)>= 2:
#                     schema = schema + "- " + parts[0] + ": " + parts[1].upper()
#             schema = schema + "\n"
#     return {"schema": schema}

In [7]:
# import re

# def get_schema(entry):
#     schema = ''
#     for stmt in entry['sql_context'].split(';'):
#         stmt = stmt.strip()
#         if stmt.upper().startswith('CREATE TABLE') and 'AS SELECT' in stmt.upper():
#             # Try to extract table name
#             table_name = stmt.split()[2]
#             select_part = stmt.upper().split('AS SELECT', 1)[1]
#             # Extract columns before FROM
#             columns_raw = re.split(r'\bFROM\b', select_part, 1)[0]
#             columns = [c.strip().split()[-1] for c in columns_raw.split(',')]
#             schema += "Table: " + table_name + "\n"
#             for col in columns:
#                 schema += f"- {col}: UNKNOWN\n"
#         elif stmt.upper().startswith('CREATE TABLE'):
#             table_name = stmt.split()[2]
#             columns = stmt.split('(', 1)[1].rsplit(')', 1)[0]
#             col_defs = columns.split(',')
#             schema += "Table: " + table_name + "\n"
#             for col in col_defs:
#                 parts = col.strip().split()
#                 if len(parts) >= 2:
#                     schema += "- " + parts[0] + ": " + parts[1].upper() + "\n"
#     return {"schema": schema}

In [8]:
def get_schema(entry):
    schema = ''
    for stmt in entry['sql_context'].split(';'):
        stmt = stmt.strip()
        # Skip empty or invalid CREATE TABLEs
        if not stmt.upper().startswith('CREATE TABLE'):
            continue

        # Try to extract the table name safely
        try:
            tokens = stmt.split()
            table_name = tokens[2].strip("'`\"")  # Strip extra quotes
        except IndexError:
            schema += "Malformed CREATE TABLE statement skipped.\n"
            continue

        # Handle CTAS
        if 'AS SELECT' in stmt.upper():
            select_part = stmt.upper().split('AS SELECT', 1)[1]
            columns_raw = re.split(r'\bFROM\b', select_part, 1)[0]
            columns = [c.strip().split()[-1] for c in columns_raw.split(',')]
            schema += f"Table: {table_name}\n"
            for col in columns:
                schema += f"- {col}: UNKNOWN\n"
        # Handle regular CREATE TABLE with column definitions
        elif '(' in stmt and ')' in stmt:
            try:
                columns = stmt.split('(', 1)[1].rsplit(')', 1)[0]
                col_defs = columns.split(',')
                schema += f"Table: {table_name}\n"
                for col in col_defs:
                    parts = col.strip().split()
                    if len(parts) >= 2:
                        schema += f"- {parts[0]}: {parts[1].upper()}\n"
            except Exception:
                schema += f"Table: {table_name} (column parsing error)\n"
        else:
            # Malformed table definition (no columns or SELECT)
            schema += f"Table: {table_name}\n"

    return {"schema": schema}

In [9]:
dataset['train'] = dataset['train'].map(get_schema)
dataset['test'] = dataset['test'].map(get_schema)

Map:   0%|          | 0/100000 [00:00<?, ? examples/s]

Map:   0%|          | 0/5851 [00:00<?, ? examples/s]

In [10]:
dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'domain', 'domain_description', 'sql_complexity', 'sql_complexity_description', 'sql_task_type', 'sql_task_type_description', 'sql_prompt', 'sql_context', 'sql', 'sql_explanation', 'schema'],
        num_rows: 100000
    })
    test: Dataset({
        features: ['id', 'domain', 'domain_description', 'sql_complexity', 'sql_complexity_description', 'sql_task_type', 'sql_task_type_description', 'sql_prompt', 'sql_context', 'sql', 'sql_explanation', 'schema'],
        num_rows: 5851
    })
})

In [11]:
from jinja2 import Template

template_str = """
<|im_start|>system
You are a SQL assistant. Use the following schema to answer queries.
{{ schema.strip() }}
<|im_end|>
<|im_start|>user
{{ sql_prompt }}
<|im_end|>
{% if sql is defined and is_training == True %}
<|im_start|>assistant
{{ sql }}
<|im_end|>
{% endif %}
{% if add_generation_prompt %}
<|im_start|>assistant
{% endif %}
"""

jinja_template = Template(template_str)

In [12]:
def preprocess(examples, is_training=True, add_generation_prompt=False):
    contents = []
    for i in range(len(examples['sql_prompt'])):
        rendered = jinja_template.render(
            schema = examples['schema'][i],
            sql_prompt=examples['sql_prompt'][i],
            context=examples['sql_context'][i],
            sql=examples['sql'][i] if is_training and 'sql' in examples else None,
            add_generation_prompt=add_generation_prompt,
            is_training = is_training
        ).strip()
        contents.append(rendered)
    return {'messages': contents}

In [13]:
dataset['train'] = dataset['train'].map(
    lambda x: preprocess(x, is_training=True, add_generation_prompt=False),
    batched=True,
    remove_columns=dataset["train"].column_names
)

dataset['test'] = dataset['test'].map(
    lambda x: preprocess(x, is_training=False, add_generation_prompt=True),
    batched=True,
    remove_columns=dataset["test"].column_names
)

Map:   0%|          | 0/100000 [00:00<?, ? examples/s]

Map:   0%|          | 0/5851 [00:00<?, ? examples/s]

In [14]:
dataset

DatasetDict({
    train: Dataset({
        features: ['messages'],
        num_rows: 100000
    })
    test: Dataset({
        features: ['messages'],
        num_rows: 5851
    })
})

In [15]:
dataset['train'][0]

{'messages': '<|im_start|>system\nYou are a SQL assistant. Use the following schema to answer queries.\nTable: salesperson\n- salesperson_id: INT\n- name: TEXT\n- region: TEXT\nTable: timber_sales\n- sales_id: INT\n- salesperson_id: INT\n- volume: REAL\n- sale_date: DATE\n<|im_end|>\n<|im_start|>user\nWhat is the total volume of timber sold by each salesperson, sorted by salesperson?\n<|im_end|>\n\n<|im_start|>assistant\nSELECT salesperson_id, name, SUM(volume) as total_volume FROM timber_sales JOIN salesperson ON timber_sales.salesperson_id = salesperson.salesperson_id GROUP BY salesperson_id, name ORDER BY total_volume DESC;\n<|im_end|>'}

In [16]:
dataset['test'][0]

{'messages': "<|im_start|>system\nYou are a SQL assistant. Use the following schema to answer queries.\nTable: creative_ai\n- application_id: INT\n- name: TEXT\n- region: TEXT\n- explainability_score: FLOAT\n<|im_end|>\n<|im_start|>user\nWhat is the average explainability score of creative AI applications in 'Europe' and 'North America' in the 'creative_ai' table?\n<|im_end|>\n\n\n<|im_start|>assistant"}

In [17]:
template = """{% for message in messages %}\n{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% if loop.last and add_generation_prompt %}{{'<|im_start|>assistant\n' }}{% endif %}{% endfor %}"""

In [18]:
class ChatmlSpecialTokens(str, Enum):
    user = "<|im_start|>user"
    assistant = "<|im_start|>assistant"
    system = "<|im_start|>system"
    eos_token = "<|im_end|>"
    bos_token = "<s>"
    pad_token = "<pad>"

    @classmethod
    def list(cls):
        return [c.value for c in cls]

tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        pad_token=ChatmlSpecialTokens.pad_token.value,
        bos_token=ChatmlSpecialTokens.bos_token.value,
        eos_token=ChatmlSpecialTokens.eos_token.value,
        additional_special_tokens=ChatmlSpecialTokens.list(),
        trust_remote_code=True
    )
tokenizer.chat_template = template
model = AutoModelForCausalLM.from_pretrained(model_name)
model.resize_token_embeddings(len(tokenizer))

config.json:   0%|          | 0.00/560 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/4.40G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/129 [00:00<?, ?B/s]

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`
The new lm_head weights will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


Embedding(32005, 2048)

In [19]:
tokenizer.padding_side="left"
def get_prediction_batched(samples, column_name):
    # batch = []
    # for conversation in samples["messages"]:
    #     chatml_gen_prompt = tokenizer.apply_chat_template(conversation[:-1], tokenize=False, add_generation_prompt=True)
    #     batch.append(chatml_gen_prompt)
    #text = tokenizer.apply_chat_template(conversation_history, add_generation_prompt=True, tokenize=False)
    inputs = tokenizer(samples["messages"], return_tensors="pt", padding=True, truncation=True)#, add_special_tokens=False)
    inputs = {k: v.to("cuda") for k,v in inputs.items()}
    outputs = model.generate(**inputs, 
                             max_new_tokens=100, 
                             do_sample=True, 
                             top_p=0.95, 
                             temperature=0.2, 
                             repetition_penalty=1.1, 
                             eos_token_id=tokenizer.eos_token_id,
                             pad_token_id=tokenizer.eos_token_id,
                            )
    outputs = tokenizer.batch_decode(outputs)
    outputs = [output.split("<|im_start|>assistant")[-1].split("<|im_end|>")[0].strip() for output in outputs]
    return {column_name: outputs}

In [20]:
model.to("cuda")
test_dataset = dataset["test"].shuffle().select(range(25))
test_dataset = test_dataset.map(
    partial(get_prediction_batched, column_name="base_assistant_message"),
    batched=True,
    batch_size=1)

Map:   0%|          | 0/25 [00:00<?, ? examples/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [21]:
peft_config = LoraConfig(r=8,
                         lora_alpha=16,
                         lora_dropout=0.1,
                         target_modules=["gate_proj","q_proj","lm_head","o_proj","k_proj","embed_tokens","down_proj","up_proj","v_proj"],
                         task_type=TaskType.CAUSAL_LM)

In [22]:
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32005, 2048)
    (layers): ModuleList(
      (0-21): 22 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=5632, bias=False)
          (up_proj): Linear(in_features=2048, out_features=5632, bias=False)
          (down_proj): Linear(in_features=5632, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-05)
    (rotary_emb): 

In [23]:
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

# cast non-trainable params in fp16
for p in model.parameters():
    if not p.requires_grad:
        p.data = p.to(torch.float16)

trainable params: 6,852,688 || all params: 1,106,921,552 || trainable%: 0.6191


In [24]:
dataset = dataset.rename_columns({"messages": "text"})

In [25]:
dataset

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 100000
    })
    test: Dataset({
        features: ['text'],
        num_rows: 5851
    })
})

In [26]:
output_dir = "tinyllama_Text2Sql_lora"
per_device_train_batch_size = 1
per_device_eval_batch_size = 1
gradient_accumulation_steps = 16
logging_steps = 25
learning_rate = 2e-5
max_grad_norm = 1.0
max_steps = 250
num_train_epochs=2
warmup_ratio = 0.1
lr_scheduler_type = "cosine"
max_seq_length = 2048

training_arguments = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=per_device_train_batch_size,
    per_device_eval_batch_size=per_device_eval_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    save_strategy="epoch",
    evaluation_strategy="epoch",
    logging_steps=logging_steps,
    learning_rate=learning_rate,
    max_grad_norm=max_grad_norm,
    weight_decay=0.1,
    warmup_ratio=warmup_ratio,
    lr_scheduler_type=lr_scheduler_type,
    fp16=True,
    report_to=["tensorboard", "wandb"],
    hub_private_repo=True,
    push_to_hub=True,
    num_train_epochs=num_train_epochs,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False}
)



In [27]:
train_dataset = dataset["train"].shuffle().select(range(5000))

In [28]:
trainer = SFTTrainer(
    model=model,
    args=training_arguments,
    train_dataset=train_dataset,
    eval_dataset=dataset["test"],
    tokenizer=tokenizer,
    # packing=True,
    # dataset_text_field="content",
    # max_seq_length=max_seq_length,
)

  trainer = SFTTrainer(


Converting train dataset to ChatML:   0%|          | 0/5000 [00:00<?, ? examples/s]

Applying chat template to train dataset:   0%|          | 0/5000 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/5000 [00:00<?, ? examples/s]

Truncating train dataset:   0%|          | 0/5000 [00:00<?, ? examples/s]

Converting eval dataset to ChatML:   0%|          | 0/5851 [00:00<?, ? examples/s]

Applying chat template to eval dataset:   0%|          | 0/5851 [00:00<?, ? examples/s]

Tokenizing eval dataset:   0%|          | 0/5851 [00:00<?, ? examples/s]

Truncating eval dataset:   0%|          | 0/5851 [00:00<?, ? examples/s]

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [29]:
trainer.train()
trainer.save_model()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mprojectsbyswathi[0m ([33mprojectsbyswathi-na[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Epoch,Training Loss,Validation Loss
1,0.7638,1.029207




In [31]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load tokenizer and match special tokens
class ChatmlSpecialTokens(str, Enum):
    user = "<|im_start|>user"
    assistant = "<|im_start|>assistant"
    system = "<|im_start|>system"
    eos_token = "<|im_end|>"
    bos_token = "<s>"
    pad_token = "<pad>"

    @classmethod
    def list(cls):
        return [c.value for c in cls]

tokenizer = AutoTokenizer.from_pretrained(
    "Swathi8378/tinyllama_Text2Sql_lora",
    pad_token=ChatmlSpecialTokens.pad_token.value,
    bos_token=ChatmlSpecialTokens.bos_token.value,
    eos_token=ChatmlSpecialTokens.eos_token.value,
    additional_special_tokens=ChatmlSpecialTokens.list(),
    trust_remote_code=True
)
tokenizer.chat_template = template  # Set this if required

# Load base model FIRST (same architecture)
base_model = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-intermediate-step-1195k-token-2.5T",  # or the correct base
    trust_remote_code=True
)

# Resize embeddings to match tokenizer (must be done BEFORE loading LoRA)
base_model.resize_token_embeddings(len(tokenizer))  # Ensure shape [32005, hidden_size]

# Now load the PEFT adapter
from peft import PeftModel

model = PeftModel.from_pretrained(
    base_model,
    "Swathi8378/tinyllama_Text2Sql_lora",
    torch_dtype=torch.float16
)

model.to("cuda")
model.eval()


tokenizer_config.json:   0%|          | 0.00/2.28k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/3.62M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/138 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/488 [00:00<?, ?B/s]

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): lora.Embedding(
          (base_layer): Embedding(32005, 2048)
          (lora_dropout): ModuleDict(
            (default): Dropout(p=0.1, inplace=False)
          )
          (lora_A): ModuleDict()
          (lora_B): ModuleDict()
          (lora_embedding_A): ParameterDict(  (default): Parameter containing: [torch.cuda.FloatTensor of size 8x32005 (cuda:0)])
          (lora_embedding_B): ParameterDict(  (default): Parameter containing: [torch.cuda.FloatTensor of size 2048x8 (cuda:0)])
          (lora_magnitude_vector): ModuleDict()
        )
        (layers): ModuleList(
          (0-21): 22 x LlamaDecoderLayer(
            (self_attn): LlamaAttention(
              (q_proj): lora.Linear(
                (base_layer): Linear(in_features=2048, out_features=2048, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0

In [32]:
test_dataset = test_dataset.map(
    partial(get_prediction_batched, column_name="instruct_assistant_message"),
    batched=True,
    batch_size=1)

Map:   0%|          | 0/25 [00:00<?, ? examples/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [34]:
df = test_dataset.to_pandas()
df.head(2)

Unnamed: 0,messages,base_assistant_message,instruct_assistant_message
0,<|im_start|>system\nYou are a SQL assistant. U...,What is the average maintenance cost for all m...,"SELECT name, maintenance_cost FROM MilitaryEqu..."
1,<|im_start|>system\nYou are a SQL assistant. U...,,"INSERT INTO volunteers (organization_id, name)..."


In [35]:
df['messages'][0]

'<|im_start|>system\nYou are a SQL assistant. Use the following schema to answer queries.\nTable: MilitaryEquipment\n- equipment_id: INT\n- name: VARCHAR(255)\n- region: VARCHAR(255)\n- maintenance_cost: FLOAT\n<|im_end|>\n<|im_start|>user\nWhat are the names and maintenance costs of all military equipment in the Atlantic region with a maintenance cost less than $5000?\n<|im_end|>\n\n\n<|im_start|>assistant'

In [36]:
df['base_assistant_message'][0]

'What is the average maintenance cost for all military equipment in the Atlantic region?\nWhat is the average maintenance cost for all military equipment in the Pacific region?\nWhat is the average maintenance cost for all military equipment in the Indian Ocean region?\nWhat is the average maintenance cost for all military equipment in the Mediterranean region?\nWhat is the average maintenance cost for all military equipment in the North American region?\nWhat is the average maintenance cost for all military equipment in the South American'

In [37]:
df['instruct_assistant_message'][0]

"SELECT name, maintenance_cost FROM MilitaryEquipment WHERE region = 'Atlantic' AND maintenance_cost < 5000;"

In [38]:
df['messages'][1]

"<|im_start|>system\nYou are a SQL assistant. Use the following schema to answer queries.\nTable: organizations\n- id: INT\n- name: TEXT\nTable: volunteers\n- id: INT\n- organization_id: INT\n- name: TEXT\n<|im_end|>\n<|im_start|>user\nInsert new records for 3 additional volunteers for the 'Doctors Without Borders' organization.\n<|im_end|>\n\n\n<|im_start|>assistant"