In [2]:
# !pip install datasets evaluate bitsandbytes optimum peft

In [3]:
# !pip install -U bitsandbytes

In [4]:
!accelerate config default

Configuration already exists at /home/shrirang/.cache/huggingface/accelerate/default_config.yaml, will not override. Run `accelerate config` manually or pass a different `save_location`.


In [5]:
from datasets import load_dataset
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, Subset
from transformers import TrainingArguments, Trainer
from transformers import AutoTokenizer, AutoModelForCausalLM, AdamW, BitsAndBytesConfig
from accelerate import Accelerator
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from peft import LoraConfig, get_peft_model
import tqdm

In [6]:
torch.cuda.empty_cache()

In [7]:
torch.set_float32_matmul_precision('high')

In [8]:
accelerator = Accelerator(mixed_precision="bf16")

In [9]:
subset_indices = list(range(256))

In [10]:
training_ds = load_dataset("gretelai/synthetic_text_to_sql", split="train")
valid_ds = load_dataset("gretelai/synthetic_text_to_sql", split="test")

In [11]:
valid_ds = Subset(valid_ds, subset_indices)

In [12]:
training_ds[0]

{'id': 5097,
 'domain': 'forestry',
 'domain_description': 'Comprehensive data on sustainable forest management, timber production, wildlife habitat, and carbon sequestration in forestry.',
 'sql_complexity': 'single join',
 'sql_complexity_description': 'only one join (specify inner, outer, cross)',
 'sql_task_type': 'analytics and reporting',
 'sql_task_type_description': 'generating reports, dashboards, and analytical insights',
 'sql_prompt': 'What is the total volume of timber sold by each salesperson, sorted by salesperson?',
 'sql_context': "CREATE TABLE salesperson (salesperson_id INT, name TEXT, region TEXT); INSERT INTO salesperson (salesperson_id, name, region) VALUES (1, 'John Doe', 'North'), (2, 'Jane Smith', 'South'); CREATE TABLE timber_sales (sales_id INT, salesperson_id INT, volume REAL, sale_date DATE); INSERT INTO timber_sales (sales_id, salesperson_id, volume, sale_date) VALUES (1, 1, 120, '2021-01-01'), (2, 1, 150, '2021-02-01'), (3, 2, 180, '2021-01-01');",
 'sql'

In [13]:
torch.cuda.empty_cache()

In [14]:
torch.backends.cuda.matmul.allow_tf32 = True

In [15]:
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_compute_dtype=torch.bfloat16,
#     bnb_4bit_use_double_quant=True,
#     bnb_4bit_quant_type="nf4"
# )
# bnb_config = BitsAndBytesConfig(
#     load_in_8bit=True
# )

In [16]:
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

In [17]:
tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
model = AutoModelForCausalLM.from_pretrained(
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    torch_dtype=torch.bfloat16
)

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


In [18]:
model = get_peft_model(model, lora_config)

In [19]:
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

In [20]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [21]:
device

device(type='cuda')

In [22]:
!nvidia-smi

Mon Feb 24 18:51:45 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.28.03              Driver Version: 560.28.03      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA RTX A4000               Off |   00000000:03:00.0  On |                    0 |
| 41%   37C    P5             19W /  140W |     747MiB /  15352MiB |     21%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [23]:
def collate_fn(batch):
    """ Tokenizes and dynamically pads a batch of text samples. """
    system_prompt = "You are an AI that translates natural language into SQL queries. You will output only the SQL query that outputs the following natural language question."
    sql_pairs = [f"{system_prompt}\n\nQuestion:\n{example['sql_prompt']} {tokenizer.eos_token}\n SQL Query:\n{example['sql']}" for example in batch]

    tokenized = tokenizer(sql_pairs, padding=True, truncation=True, return_tensors="pt")

    input_ids = tokenized["input_ids"]
    attention_mask = tokenized["attention_mask"]

    labels = input_ids.clone()
    labels[labels == tokenizer.pad_token_id] = -100

    return {
        "input_ids": input_ids.to(device),
        "attention_mask": attention_mask.to(device),
        "labels": labels.to(device),
    }

In [24]:
training_ds.set_format(type="torch")

In [25]:
train_dataloader = DataLoader(training_ds, batch_size=4, collate_fn=collate_fn, shuffle=True)
valid_dataloader = DataLoader(valid_ds, batch_size=4, collate_fn=collate_fn, shuffle=False)

In [26]:
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

In [27]:
num_epochs = 3

In [28]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

for param in model.parameters():
    param.data = param.data.to(device)
    if param.grad is not None:
        param.grad.data = param.grad.data.to(device)

# import bitsandbytes as bnb
# for name, module in model.named_modules():
#     if isinstance(module, bnb.nn.Linear4bit):
#         module.to(device)

In [29]:
# model = torch.compile(model)
model.to(device)

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): Qwen2ForCausalLM(
      (model): Qwen2Model(
        (embed_tokens): Embedding(151936, 1536)
        (layers): ModuleList(
          (0-27): 28 x Qwen2DecoderLayer(
            (self_attn): Qwen2Attention(
              (q_proj): lora.Linear(
                (base_layer): Linear(in_features=1536, out_features=1536, bias=True)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=1536, out_features=16, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=16, out_features=1536, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (k_proj): Linear(in_fe

In [30]:
gradient_accumulation_steps = 4
batch_size = 1

for step, batch in tqdm.notebook.tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
    batch = {k: v.to(device) for k, v in batch.items()}
    
    outputs = model(**batch)
    loss = outputs.loss
    loss = loss / gradient_accumulation_steps

    loss.backward()

    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    if (step + 1) % gradient_accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

    if (step + 1) % 1000 == 0:
        training_loss = loss.item() * 4
        
        model.eval()
        total_loss = 0

        for val_step, val_batch in tqdm.notebook.tqdm(enumerate(valid_dataloader), total=len(valid_dataloader)):
            val_batch = {k: v.to(device) for k, v in val_batch.items()}
            with torch.no_grad():
                val_outputs = model(**val_batch)
                val_loss = val_outputs.loss
                total_loss += val_loss.item()

        avg_loss = total_loss / len(valid_dataloader)
        print(f"Training Loss at Step {step+1}, Loss: {training_loss:.4f}")
        print(f"Validation Loss at Step {step+1}, Loss: {avg_loss:.4f}")

        model.train()

  0%|          | 0/25000 [00:00<?, ?it/s]

  0%|          | 0/64 [00:00<?, ?it/s]

Training Loss at Step 1000, Loss: 0.9099
Validation Loss at Step 1000, Loss: 1.0337


  0%|          | 0/64 [00:00<?, ?it/s]

Training Loss at Step 2000, Loss: 0.7285
Validation Loss at Step 2000, Loss: 0.8711


  0%|          | 0/64 [00:00<?, ?it/s]

Training Loss at Step 3000, Loss: 0.7271
Validation Loss at Step 3000, Loss: 0.8420


  0%|          | 0/64 [00:00<?, ?it/s]

Training Loss at Step 4000, Loss: 1.2774
Validation Loss at Step 4000, Loss: 0.8237


  0%|          | 0/64 [00:00<?, ?it/s]

Training Loss at Step 5000, Loss: 0.6474
Validation Loss at Step 5000, Loss: 0.8092


  0%|          | 0/64 [00:00<?, ?it/s]

Training Loss at Step 6000, Loss: 0.8588
Validation Loss at Step 6000, Loss: 0.8005


  0%|          | 0/64 [00:00<?, ?it/s]

Training Loss at Step 7000, Loss: 0.7774
Validation Loss at Step 7000, Loss: 0.7914


  0%|          | 0/64 [00:00<?, ?it/s]

Training Loss at Step 8000, Loss: 0.7733
Validation Loss at Step 8000, Loss: 0.7867


  0%|          | 0/64 [00:00<?, ?it/s]

Training Loss at Step 9000, Loss: 0.8121
Validation Loss at Step 9000, Loss: 0.7801


  0%|          | 0/64 [00:00<?, ?it/s]

Training Loss at Step 10000, Loss: 0.7961
Validation Loss at Step 10000, Loss: 0.7755


  0%|          | 0/64 [00:00<?, ?it/s]

Training Loss at Step 11000, Loss: 0.8418
Validation Loss at Step 11000, Loss: 0.7707


  0%|          | 0/64 [00:00<?, ?it/s]

Training Loss at Step 12000, Loss: 0.6378
Validation Loss at Step 12000, Loss: 0.7660


  0%|          | 0/64 [00:00<?, ?it/s]

Training Loss at Step 13000, Loss: 0.8803
Validation Loss at Step 13000, Loss: 0.7627


  0%|          | 0/64 [00:00<?, ?it/s]

Training Loss at Step 14000, Loss: 0.8372
Validation Loss at Step 14000, Loss: 0.7584


  0%|          | 0/64 [00:00<?, ?it/s]

Training Loss at Step 15000, Loss: 0.7096
Validation Loss at Step 15000, Loss: 0.7550


  0%|          | 0/64 [00:00<?, ?it/s]

Training Loss at Step 16000, Loss: 0.7774
Validation Loss at Step 16000, Loss: 0.7499


  0%|          | 0/64 [00:00<?, ?it/s]

Training Loss at Step 17000, Loss: 0.7064
Validation Loss at Step 17000, Loss: 0.7476


  0%|          | 0/64 [00:00<?, ?it/s]

Training Loss at Step 18000, Loss: 0.7326
Validation Loss at Step 18000, Loss: 0.7450


  0%|          | 0/64 [00:00<?, ?it/s]

Training Loss at Step 19000, Loss: 0.6601
Validation Loss at Step 19000, Loss: 0.7427


  0%|          | 0/64 [00:00<?, ?it/s]

Training Loss at Step 20000, Loss: 0.6647
Validation Loss at Step 20000, Loss: 0.7393


  0%|          | 0/64 [00:00<?, ?it/s]

Training Loss at Step 21000, Loss: 0.7790
Validation Loss at Step 21000, Loss: 0.7377


  0%|          | 0/64 [00:00<?, ?it/s]

Training Loss at Step 22000, Loss: 0.7667
Validation Loss at Step 22000, Loss: 0.7368


  0%|          | 0/64 [00:00<?, ?it/s]

Training Loss at Step 23000, Loss: 0.7260
Validation Loss at Step 23000, Loss: 0.7332


  0%|          | 0/64 [00:00<?, ?it/s]

Training Loss at Step 24000, Loss: 0.6476
Validation Loss at Step 24000, Loss: 0.7312


  0%|          | 0/64 [00:00<?, ?it/s]

Training Loss at Step 25000, Loss: 0.6084
Validation Loss at Step 25000, Loss: 0.7292


In [36]:
table_description = """
Table Descriptions:

Table: yclamujetf (Customer Table)
bhwqyackvb (INTEGER)
Customer Identifier: A unique identifier for each customer.

yjkfcsqlsb (TEXT)
First Name: The customer's given name.

tqnwzgbtgf (TEXT)
Last Name: The customer's family name.

taeatbvlbq (TEXT)
Email: The customer's email address.

vpkrulppkd (TEXT)
Phone: The customer's contact phone number.

obbvxwwzqg (TEXT)
Address: The customer's residential or mailing address.

fshvmouozp (DATE)
Date of Birth: The customer's birth date.

pfwsrivqwb (TEXT)
SSN: The customer's Social Security Number (or other national identifier).

qpkmhrxtbc (INTEGER)
Credit Score: A numerical value representing the customer's creditworthiness.

aslwxlorcs (TIMESTAMP)
Created At: The timestamp marking when the customer record was created.

Table: branches
branch_id
Branch Identifier: A unique identifier for each branch.

branch_name
Branch Name: The official name of the branch.

address
Street Address: The physical street address of the branch.

city
City: The city in which the branch is located.

state
State/Region: The state or region where the branch operates.

zip_code
ZIP/Postal Code: The postal code for the branch location.

phone
Phone Number: Contact telephone number for the branch.

Table: accounts
account_id
Account Identifier: A unique identifier for each account.

customer_id
Customer Link: The identifier that links the account to a customer (from yclamujetf).

branch_id
Branch Link: The identifier of the branch where the account was opened or is maintained.

account_type
Type of Account: Specifies the kind of account (e.g., checking, savings).

account_number
Account Number: The official number assigned to the account.

balance
Current Balance: The current monetary balance of the account.

status
Account Status: Indicates the current state of the account (e.g., active, closed).

opened_date
Opened Date: The date on which the account was opened.

Table: cards
card_id
Card Identifier: A unique identifier for each card.

account_id
Associated Account: The identifier for the account to which the card is linked.

card_number
Card Number: The number printed on the card.

card_type
Type of Card: The category of the card (e.g., credit, debit).

expiry_date
Expiry Date: The expiration date of the card.

cvv
CVV: The Card Verification Value (security code).

status
Card Status: The current status of the card (e.g., active, blocked).

Table: loans
loan_id
Loan Identifier: A unique identifier for each loan.

customer_id
Customer Link: The identifier linking the loan to a customer (from yclamujetf).

branch_id
Branch Link: The identifier of the branch that processed or services the loan.

loan_type
Type of Loan: Specifies the loan category (e.g., personal, auto, mortgage).

loan_amount
Loan Amount: The total amount of money borrowed.

interest_rate
Interest Rate: The rate at which interest accrues on the loan.

term_months
Loan Term: The duration of the loan in months.

monthly_payment
Monthly Payment: The scheduled payment amount due each month.

remaining_balance
Remaining Balance: The outstanding amount yet to be paid.

start_date
Start Date: The date when the loan began.

end_date
End Date: The scheduled or actual date when the loan is to be or was fully repaid.

status
Loan Status: The current state of the loan (e.g., active, closed, default).

Table: transactions
transaction_id
Transaction Identifier: A unique identifier for each transaction.

account_id
Associated Account: The identifier of the account related to the transaction.

transaction_type
Type of Transaction: Indicates the nature of the transaction (e.g., deposit, withdrawal, transfer).

amount
Amount: The monetary value involved in the transaction.

transaction_date
Transaction Date: The date and time when the transaction occurred.

description
Description: Additional details or context about the transaction.

status
Transaction Status: Indicates the current status (e.g., completed, pending, failed).
"""

In [37]:
import torch

def generate_text_stream(model, tokenizer, prompt, device, max_new_tokens=100):
    """
    Streams generated text token-by-token using the model's forward pass with caching.
    This avoids the tensor size mismatch error by properly handling past_key_values.
    """
    system_prompt = "You are an AI that translates natural language into SQL queries. You will output only the SQL query that outputs the following natural language question. Give me response in one single SQL query."
    system_prompt = f"{system_prompt}\n{table_description}\n"
    full_prompt = f"{system_prompt}\nQuestion:\n{prompt} {tokenizer.eos_token}\n SQL Query:\n"

    inputs = tokenizer(full_prompt, return_tensors="pt").to(device)
    generated_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    
    past_key_values = None
    with torch.no_grad():
        for _ in range(max_new_tokens):
            if past_key_values is None:
                outputs = model(input_ids=generated_ids, attention_mask=attention_mask, use_cache=True)
            else:
                outputs = model(
                    input_ids=generated_ids[:, -1].unsqueeze(-1),
                    attention_mask=attention_mask,
                    use_cache=True,
                    past_key_values=past_key_values
                )
                
            logits = outputs.logits
            past_key_values = outputs.past_key_values if hasattr(outputs, "past_key_values") else None
            
            next_token_logits = logits[:, -1, :]
            next_token_probs = torch.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(next_token_probs, num_samples=1)
            
            generated_ids = torch.cat([generated_ids, next_token], dim=1)
            attention_mask = torch.cat([attention_mask, torch.ones((attention_mask.shape[0], 1), dtype=attention_mask.dtype, device=device)], dim=1)
            
            token = tokenizer.decode(next_token.squeeze(), skip_special_tokens=True)
            yield token

            if next_token.item() == tokenizer.eos_token_id:
                break

In [38]:
prompt = "Find customers who have multiple accounts and their total balance across all accounts."
for token in generate_text_stream(model, tokenizer, prompt, device):
    print(token, end="", flush=True)

SELECT c1.customer_id, c1.balance, SUM(c2.balance) FROM accounts a1 JOIN cards c2 ON a1.account_id = c2.account_id JOIN loans l ON c2.account_id = l.account_id); SELECT c1.customer_id, c1.balance FROM customers WHERE c1.customer_id NOT IN (SELECT customer_id FROM accounts JOIN transactions ON accounts.account_id = transactions.account_id WHERE (c2.account_id, c2.amount) IN (inspect_products(c, transactions) WHERE

In [40]:
model.save_pretrained("DeepSeek-R1-Distill-Qwen-1.5B-SQL-Coder-PEFT")
tokenizer.save_pretrained("DeepSeek-R1-Distill-Qwen-1.5B-SQL-Coder-PEFT")

('DeepSeek-R1-Distill-Qwen-1.5B-SQL-Coder-PEFT/tokenizer_config.json',
 'DeepSeek-R1-Distill-Qwen-1.5B-SQL-Coder-PEFT/special_tokens_map.json',
 'DeepSeek-R1-Distill-Qwen-1.5B-SQL-Coder-PEFT/tokenizer.json')

In [None]:
from huggingface_hub import HfApi
HF_TOKEN = ""

repo_name = "NotShrirang/DeepSeek-R1-Distill-Qwen-1.5B-SQL-Coder-PEFT"
api = HfApi(token=HF_TOKEN)
api.create_repo(repo_id=repo_name, exist_ok=True)

model.push_to_hub(repo_name, token=HF_TOKEN)
tokenizer.push_to_hub(repo_name, token=HF_TOKEN)

No files have been modified since last commit. Skipping to prevent empty commit.
No files have been modified since last commit. Skipping to prevent empty commit.


CommitInfo(commit_url='https://huggingface.co/NotShrirang/DeepSeek-R1-Distill-Qwen-1.5B-SQL-Coder-PEFT/commit/a92cdb32b6745e9cfaa0950c9d475ba2ab4b4e66', commit_message='Upload tokenizer', commit_description='', oid='a92cdb32b6745e9cfaa0950c9d475ba2ab4b4e66', pr_url=None, repo_url=RepoUrl('https://huggingface.co/NotShrirang/DeepSeek-R1-Distill-Qwen-1.5B-SQL-Coder-PEFT', endpoint='https://huggingface.co', repo_type='model', repo_id='NotShrirang/DeepSeek-R1-Distill-Qwen-1.5B-SQL-Coder-PEFT'), pr_revision=None, pr_num=None)