In [1]:
%pip install -q transformers torch accelerate
%pip install -U bitsandbytes



In [2]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2LMHeadModel, GPT2Tokenizer
from torch.optim import AdamW
from tqdm import tqdm


In [3]:
teacher_schema_prompt = """
# E-Commerce Database Schema

This database models an online retail store.
**customers** register and place **orders** containing multiple **order_items**, where each item references a **product** from the catalog.
The system tracks customer info, product inventory, order history, and detailed purchase records.

## Tables:

**customers**:
- customer_id (INT, PRIMARY KEY, AUTO_INCREMENT) - Unique customer identifier
- first_name (VARCHAR(50), NOT NULL) - Customer's first name
- last_name (VARCHAR(50), NOT NULL) - Customer's last name
- email (VARCHAR(100), UNIQUE, NOT NULL) - Customer's email address
- phone (VARCHAR(20), NULLABLE) - Customer's phone number
- registration_date (DATE, NOT NULL) - Date when customer registered
- city (VARCHAR(50), NULLABLE) - Customer's city
- country (VARCHAR(50), NULLABLE) - Customer's country

**products**:
- product_id (INT, PRIMARY KEY, AUTO_INCREMENT) - Unique product identifier
- product_name (VARCHAR(100), NOT NULL) - Name of the product
- category (VARCHAR(50), NOT NULL) - Product category (Electronics, Clothing, Home & Garden, Sports, Books, etc.)
- price (DECIMAL(10,2), NOT NULL) - Current product price in currency units
- stock_quantity (INT, NOT NULL, DEFAULT 0) - Available inventory count (0 = out of stock)
- supplier_id (INT, NULLABLE) - Reference to supplier (not detailed in this schema)

**orders**:
- order_id (INT, PRIMARY KEY, AUTO_INCREMENT) - Unique order identifier
- customer_id (INT, NOT NULL, FOREIGN KEY) - References customers.customer_id
- order_date (DATE, NOT NULL) - Date when order was placed
- total_amount (DECIMAL(10,2), NOT NULL) - Total order value including all items
- status (VARCHAR(20), NOT NULL, DEFAULT 'Pending') - Order status (Pending, Completed, Cancelled, Shipped, Processing)
- shipping_address (TEXT, NULLABLE) - Full shipping address for the order

**order_items**:
- order_item_id (INT, PRIMARY KEY, AUTO_INCREMENT) - Unique order item identifier
- order_id (INT, NOT NULL, FOREIGN KEY) - References orders.order_id
- product_id (INT, NOT NULL, FOREIGN KEY) - References products.product_id
- quantity (INT, NOT NULL) - Number of units ordered for this product
- unit_price (DECIMAL(10,2), NOT NULL) - Price per unit at time of order (may differ from current product price)

## Relationships:
- customers.customer_id → orders.customer_id (One-to-Many: One customer can have multiple orders)
- orders.order_id → order_items.order_id (One-to-Many: One order can contain multiple items)
- products.product_id → order_items.product_id (One-to-Many: One product can appear in multiple order items)

Generate nothing but a SQL query valid on the above schema to product data for the following question:

"""

In [4]:
import random
from torch.utils.data import Dataset, DataLoader, random_split

# --------------------------
# Text-to-SQL Dataset
# --------------------------
class TextToSQLDataset(Dataset):
    def __init__(self, samples):
        """
        samples: List of tuples (question, sql_query)
        """
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return {
            'question': self.samples[idx][0],
            'sql': self.samples[idx][1]
        }

# Database Schema:
# customers: customer_id, first_name, last_name, email, phone, registration_date, city, country
# products: product_id, product_name, category, price, stock_quantity, supplier_id
# orders: order_id, customer_id, order_date, total_amount, status, shipping_address
# order_items: order_item_id, order_id, product_id, quantity, unit_price

# Sample data for text-to-SQL generation
text_to_sql_samples = [
    # Basic SELECT queries
    ("Show me all customers", "SELECT * FROM customers;"),
    ("List all customers", "SELECT * FROM customers;"),
    ("Display customer information", "SELECT * FROM customers;"),

    ("What products do we have?", "SELECT * FROM products;"),
    ("Show me the product catalog", "SELECT * FROM products;"),
    ("List all available products", "SELECT * FROM products;"),

    ("Show me all orders", "SELECT * FROM orders;"),
    ("Display order information", "SELECT * FROM orders;"),

    # SELECT with WHERE clauses
    ("Show customers from New York", "SELECT * FROM customers WHERE city = 'New York';"),
    ("Find customers in New York", "SELECT * FROM customers WHERE city = 'New York';"),
    ("List customers located in New York", "SELECT * FROM customers WHERE city = 'New York';"),

    ("Show customers from USA", "SELECT * FROM customers WHERE country = 'USA';"),
    ("Find American customers", "SELECT * FROM customers WHERE country = 'USA';"),

    ("Show products under $50", "SELECT * FROM products WHERE price < 50;"),
    ("Find cheap products under 50 dollars", "SELECT * FROM products WHERE price < 50;"),
    ("List products costing less than $50", "SELECT * FROM products WHERE price < 50;"),

    ("Show expensive products over $100", "SELECT * FROM products WHERE price > 100;"),
    ("Find products costing more than 100 dollars", "SELECT * FROM products WHERE price > 100;"),

    ("Show products in Electronics category", "SELECT * FROM products WHERE category = 'Electronics';"),
    ("Find electronic products", "SELECT * FROM products WHERE category = 'Electronics';"),

    ("Show completed orders", "SELECT * FROM orders WHERE status = 'Completed';"),
    ("Find finished orders", "SELECT * FROM orders WHERE status = 'Completed';"),

    ("Show pending orders", "SELECT * FROM orders WHERE status = 'Pending';"),
    ("Find orders that are still pending", "SELECT * FROM orders WHERE status = 'Pending';"),

    # Aggregation queries
    ("How many customers do we have?", "SELECT COUNT(*) FROM customers;"),
    ("Count total customers", "SELECT COUNT(*) FROM customers;"),
    ("What's the total number of customers?", "SELECT COUNT(*) FROM customers;"),

    ("How many products are there?", "SELECT COUNT(*) FROM products;"),
    ("Count all products", "SELECT COUNT(*) FROM products;"),

    ("What's the average product price?", "SELECT AVG(price) FROM products;"),
    ("Calculate average price of products", "SELECT AVG(price) FROM products;"),
    ("Find the mean product price", "SELECT AVG(price) FROM products;"),

    ("What's the most expensive product price?", "SELECT MAX(price) FROM products;"),
    ("Find the highest product price", "SELECT MAX(price) FROM products;"),

    ("What's the cheapest product price?", "SELECT MIN(price) FROM products;"),
    ("Find the lowest product price", "SELECT MIN(price) FROM products;"),

    ("What's the total value of all orders?", "SELECT SUM(total_amount) FROM orders;"),
    ("Calculate total order revenue", "SELECT SUM(total_amount) FROM orders;"),

    # GROUP BY queries
    ("Show number of customers by city", "SELECT city, COUNT(*) FROM customers GROUP BY city;"),
    ("Count customers in each city", "SELECT city, COUNT(*) FROM customers GROUP BY city;"),
    ("Group customers by city", "SELECT city, COUNT(*) FROM customers GROUP BY city;"),

    ("Show number of customers by country", "SELECT country, COUNT(*) FROM customers GROUP BY country;"),
    ("Count customers per country", "SELECT country, COUNT(*) FROM customers GROUP BY country;"),

    ("Show number of products by category", "SELECT category, COUNT(*) FROM products GROUP BY category;"),
    ("Count products in each category", "SELECT category, COUNT(*) FROM products GROUP BY category;"),

    ("Show average price by product category", "SELECT category, AVG(price) FROM products GROUP BY category;"),
    ("Calculate average price for each category", "SELECT category, AVG(price) FROM products GROUP BY category;"),

    ("Show total orders by status", "SELECT status, COUNT(*) FROM orders GROUP BY status;"),
    ("Count orders by their status", "SELECT status, COUNT(*) FROM orders GROUP BY status;"),

    ("Show total revenue by month", "SELECT EXTRACT(MONTH FROM order_date) as month, SUM(total_amount) FROM orders GROUP BY EXTRACT(MONTH FROM order_date);"),
    ("Calculate monthly revenue", "SELECT EXTRACT(MONTH FROM order_date) as month, SUM(total_amount) FROM orders GROUP BY EXTRACT(MONTH FROM order_date);"),

    # JOIN queries
    ("Show customer names with their orders", "SELECT c.first_name, c.last_name, o.order_id, o.total_amount FROM customers c JOIN orders o ON c.customer_id = o.customer_id;"),
    ("List customers and their order details", "SELECT c.first_name, c.last_name, o.order_id, o.total_amount FROM customers c JOIN orders o ON c.customer_id = o.customer_id;"),
    ("Find customers with their order information", "SELECT c.first_name, c.last_name, o.order_id, o.total_amount FROM customers c JOIN orders o ON c.customer_id = o.customer_id;"),

    ("Show order details with customer information", "SELECT o.order_id, o.total_amount, c.first_name, c.last_name, c.email FROM orders o JOIN customers c ON o.customer_id = c.customer_id;"),
    ("List orders with customer details", "SELECT o.order_id, o.total_amount, c.first_name, c.last_name, c.email FROM orders o JOIN customers c ON o.customer_id = c.customer_id;"),

    ("Show products in each order", "SELECT o.order_id, p.product_name, oi.quantity FROM orders o JOIN order_items oi ON o.order_id = oi.order_id JOIN products p ON oi.product_id = p.product_id;"),
    ("List order items with product names", "SELECT o.order_id, p.product_name, oi.quantity FROM orders o JOIN order_items oi ON o.order_id = oi.order_id JOIN products p ON oi.product_id = p.product_id;"),
    ("Find what products were ordered", "SELECT o.order_id, p.product_name, oi.quantity FROM orders o JOIN order_items oi ON o.order_id = oi.order_id JOIN products p ON oi.product_id = p.product_id;"),

    ("Show customer purchases with product details", "SELECT c.first_name, c.last_name, p.product_name, oi.quantity, oi.unit_price FROM customers c JOIN orders o ON c.customer_id = o.customer_id JOIN order_items oi ON o.order_id = oi.order_id JOIN products p ON oi.product_id = p.product_id;"),
    ("List what each customer bought", "SELECT c.first_name, c.last_name, p.product_name, oi.quantity, oi.unit_price FROM customers c JOIN orders o ON c.customer_id = o.customer_id JOIN order_items oi ON o.order_id = oi.order_id JOIN products p ON oi.product_id = p.product_id;"),

    ("Show customers who have never placed an order", "SELECT c.first_name, c.last_name FROM customers c LEFT JOIN orders o ON c.customer_id = o.customer_id WHERE o.customer_id IS NULL;"),
    ("Find customers without any orders", "SELECT c.first_name, c.last_name FROM customers c LEFT JOIN orders o ON c.customer_id = o.customer_id WHERE o.customer_id IS NULL;"),
    ("List customers who haven't ordered anything", "SELECT c.first_name, c.last_name FROM customers c LEFT JOIN orders o ON c.customer_id = o.customer_id WHERE o.customer_id IS NULL;"),

    # Subqueries and nested queries
    ("Show customers who placed orders above average", "SELECT DISTINCT c.first_name, c.last_name FROM customers c JOIN orders o ON c.customer_id = o.customer_id WHERE o.total_amount > (SELECT AVG(total_amount) FROM orders);"),
    ("Find customers with high-value orders", "SELECT DISTINCT c.first_name, c.last_name FROM customers c JOIN orders o ON c.customer_id = o.customer_id WHERE o.total_amount > (SELECT AVG(total_amount) FROM orders);"),

    ("Show products more expensive than average", "SELECT product_name, price FROM products WHERE price > (SELECT AVG(price) FROM products);"),
    ("Find above-average priced products", "SELECT product_name, price FROM products WHERE price > (SELECT AVG(price) FROM products);"),
    ("List products costing more than the average", "SELECT product_name, price FROM products WHERE price > (SELECT AVG(price) FROM products);"),

    ("Show the most expensive product", "SELECT product_name, price FROM products WHERE price = (SELECT MAX(price) FROM products);"),
    ("Find the highest priced product", "SELECT product_name, price FROM products WHERE price = (SELECT MAX(price) FROM products);"),

    ("Show customers from cities with more than 5 customers", "SELECT * FROM customers WHERE city IN (SELECT city FROM customers GROUP BY city HAVING COUNT(*) > 5);"),
    ("Find customers in popular cities", "SELECT * FROM customers WHERE city IN (SELECT city FROM customers GROUP BY city HAVING COUNT(*) > 5);"),

    ("Show products that have been ordered", "SELECT DISTINCT p.product_name FROM products p WHERE p.product_id IN (SELECT DISTINCT product_id FROM order_items);"),
    ("Find products that were purchased", "SELECT DISTINCT p.product_name FROM products p WHERE p.product_id IN (SELECT DISTINCT product_id FROM order_items);"),
    ("List ordered products", "SELECT DISTINCT p.product_name FROM products p WHERE p.product_id IN (SELECT DISTINCT product_id FROM order_items);"),

    ("Show products that have never been ordered", "SELECT product_name FROM products WHERE product_id NOT IN (SELECT DISTINCT product_id FROM order_items WHERE product_id IS NOT NULL);"),
    ("Find unsold products", "SELECT product_name FROM products WHERE product_id NOT IN (SELECT DISTINCT product_id FROM order_items WHERE product_id IS NOT NULL);"),

    ("Show top 5 customers by total spending", "SELECT c.first_name, c.last_name, SUM(o.total_amount) as total_spent FROM customers c JOIN orders o ON c.customer_id = o.customer_id GROUP BY c.customer_id, c.first_name, c.last_name ORDER BY total_spent DESC LIMIT 5;"),
    ("Find biggest spenders", "SELECT c.first_name, c.last_name, SUM(o.total_amount) as total_spent FROM customers c JOIN orders o ON c.customer_id = o.customer_id GROUP BY c.customer_id, c.first_name, c.last_name ORDER BY total_spent DESC LIMIT 5;"),

    ("Show top 3 most expensive products", "SELECT product_name, price FROM products ORDER BY price DESC LIMIT 3;"),
    ("Find the 3 highest priced products", "SELECT product_name, price FROM products ORDER BY price DESC LIMIT 3;"),
    ("List top 3 costly products", "SELECT product_name, price FROM products ORDER BY price DESC LIMIT 3;"),

    # Complex queries with multiple conditions
    ("Show customers from USA who spent more than $500", "SELECT c.first_name, c.last_name, SUM(o.total_amount) as total FROM customers c JOIN orders o ON c.customer_id = o.customer_id WHERE c.country = 'USA' GROUP BY c.customer_id, c.first_name, c.last_name HAVING SUM(o.total_amount) > 500;"),
    ("Find American customers with high spending", "SELECT c.first_name, c.last_name, SUM(o.total_amount) as total FROM customers c JOIN orders o ON c.customer_id = o.customer_id WHERE c.country = 'USA' GROUP BY c.customer_id, c.first_name, c.last_name HAVING SUM(o.total_amount) > 500;"),

    ("Show electronics products under $200", "SELECT product_name, price FROM products WHERE category = 'Electronics' AND price < 200;"),
    ("Find affordable electronic items", "SELECT product_name, price FROM products WHERE category = 'Electronics' AND price < 200;"),

    ("Show orders from last month with high value", "SELECT * FROM orders WHERE EXTRACT(MONTH FROM order_date) = EXTRACT(MONTH FROM CURRENT_DATE - INTERVAL '1 month') AND total_amount > 1000;"),
    ("Find recent high-value orders", "SELECT * FROM orders WHERE EXTRACT(MONTH FROM order_date) = EXTRACT(MONTH FROM CURRENT_DATE - INTERVAL '1 month') AND total_amount > 1000;"),

    ("Show customers who bought electronics", "SELECT DISTINCT c.first_name, c.last_name FROM customers c JOIN orders o ON c.customer_id = o.customer_id JOIN order_items oi ON o.order_id = oi.order_id JOIN products p ON oi.product_id = p.product_id WHERE p.category = 'Electronics';"),
    ("Find customers who purchased electronic products", "SELECT DISTINCT c.first_name, c.last_name FROM customers c JOIN orders o ON c.customer_id = o.customer_id JOIN order_items oi ON o.order_id = oi.order_id JOIN products p ON oi.product_id = p.product_id WHERE p.category = 'Electronics';"),

    ("Show average order value by customer city", "SELECT c.city, AVG(o.total_amount) as avg_order_value FROM customers c JOIN orders o ON c.customer_id = o.customer_id GROUP BY c.city;"),
    ("Calculate average spending per city", "SELECT c.city, AVG(o.total_amount) as avg_order_value FROM customers c JOIN orders o ON c.customer_id = o.customer_id GROUP BY c.city;"),

    # Date-based queries
    ("Show orders from this year", "SELECT * FROM orders WHERE EXTRACT(YEAR FROM order_date) = EXTRACT(YEAR FROM CURRENT_DATE);"),
    ("Find current year orders", "SELECT * FROM orders WHERE EXTRACT(YEAR FROM order_date) = EXTRACT(YEAR FROM CURRENT_DATE);"),

    ("Show customers registered in 2023", "SELECT * FROM customers WHERE EXTRACT(YEAR FROM registration_date) = 2023;"),
    ("Find customers who joined in 2023", "SELECT * FROM customers WHERE EXTRACT(YEAR FROM registration_date) = 2023;"),

    ("Show orders placed in the last 30 days", "SELECT * FROM orders WHERE order_date >= CURRENT_DATE - INTERVAL '30 days';"),
    ("Find recent orders from past month", "SELECT * FROM orders WHERE order_date >= CURRENT_DATE - INTERVAL '30 days';"),

    # Product inventory queries
    ("Show products with low stock", "SELECT product_name, stock_quantity FROM products WHERE stock_quantity < 10;"),
    ("Find products running out of stock", "SELECT product_name, stock_quantity FROM products WHERE stock_quantity < 10;"),
    ("List low inventory items", "SELECT product_name, stock_quantity FROM products WHERE stock_quantity < 10;"),

    ("Show out of stock products", "SELECT product_name FROM products WHERE stock_quantity = 0;"),
    ("Find products with no stock", "SELECT product_name FROM products WHERE stock_quantity = 0;"),

    # Revenue and sales analysis
    ("Show monthly sales totals", "SELECT EXTRACT(MONTH FROM order_date) as month, EXTRACT(YEAR FROM order_date) as year, SUM(total_amount) as monthly_total FROM orders GROUP BY EXTRACT(YEAR FROM order_date), EXTRACT(MONTH FROM order_date) ORDER BY year, month;"),
    ("Calculate monthly revenue breakdown", "SELECT EXTRACT(MONTH FROM order_date) as month, EXTRACT(YEAR FROM order_date) as year, SUM(total_amount) as monthly_total FROM orders GROUP BY EXTRACT(YEAR FROM order_date), EXTRACT(MONTH FROM order_date) ORDER BY year, month;"),

    ("Show best selling products by quantity", "SELECT p.product_name, SUM(oi.quantity) as total_sold FROM products p JOIN order_items oi ON p.product_id = oi.product_id GROUP BY p.product_id, p.product_name ORDER BY total_sold DESC;"),
    ("Find most popular products", "SELECT p.product_name, SUM(oi.quantity) as total_sold FROM products p JOIN order_items oi ON p.product_id = oi.product_id GROUP BY p.product_id, p.product_name ORDER BY total_sold DESC;"),

    ("Show product revenue ranking", "SELECT p.product_name, SUM(oi.quantity * oi.unit_price) as revenue FROM products p JOIN order_items oi ON p.product_id = oi.product_id GROUP BY p.product_id, p.product_name ORDER BY revenue DESC;"),
    ("Find highest revenue generating products", "SELECT p.product_name, SUM(oi.quantity * oi.unit_price) as revenue FROM products p JOIN order_items oi ON p.product_id = oi.product_id GROUP BY p.product_id, p.product_name ORDER BY revenue DESC;"),
]

# Create dataset
train_dataset = TextToSQLDataset(text_to_sql_samples)

# Setup data loaders
BATCH_SIZE = 8
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size

train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f"Total samples: {len(text_to_sql_samples)}")
print(f"Training samples: {train_size}")
print(f"Validation samples: {val_size}")
print(f"Sample data point: {text_to_sql_samples[0]}")

# Example of accessing data
for batch in train_loader:
    teacher_qs = [teacher_schema_prompt + f"{q}\n Generated SQL: "   for q in batch['question']]
    print(f"Batch teacher questions: {len(teacher_qs), type(teacher_qs)}")
    print(f"Batch questions: {len(batch['question']), type(batch['question'])}")
    print(f"Batch SQLs: {batch['sql']}")
    break  # Just show first batch

Total samples: 110
Training samples: 88
Validation samples: 22
Sample data point: ('Show me all customers', 'SELECT * FROM customers;')
Batch teacher questions: (8, <class 'list'>)
Batch questions: (8, <class 'list'>)
Batch SQLs: ['SELECT EXTRACT(MONTH FROM order_date) as month, EXTRACT(YEAR FROM order_date) as year, SUM(total_amount) as monthly_total FROM orders GROUP BY EXTRACT(YEAR FROM order_date), EXTRACT(MONTH FROM order_date) ORDER BY year, month;', "SELECT c.first_name, c.last_name, SUM(o.total_amount) as total FROM customers c JOIN orders o ON c.customer_id = o.customer_id WHERE c.country = 'USA' GROUP BY c.customer_id, c.first_name, c.last_name HAVING SUM(o.total_amount) > 500;", 'SELECT * FROM products WHERE price < 50;', 'SELECT c.first_name, c.last_name, p.product_name, oi.quantity, oi.unit_price FROM customers c JOIN orders o ON c.customer_id = o.customer_id JOIN order_items oi ON o.order_id = oi.order_id JOIN products p ON oi.product_id = p.product_id;', 'SELECT c.first_

In [5]:
# --------------------------
# Student (GPT-2)
# --------------------------
student_name = "gpt2"
student = GPT2LMHeadModel.from_pretrained(
                              student_name).cuda()

student_tokenizer = GPT2Tokenizer.from_pretrained(student_name)
student_tokenizer.pad_token = student_tokenizer.eos_token  # GPT-2 fix
student.train()


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch

# Define 4-bit quantization configuration (As teacher is too large without quantization)
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

# --------------------------
# Teacher (Qwen2-7B)
# --------------------------
# Load the teacher model in 4-bit
teacher_name = "Qwen/Qwen2-7B-Instruct"
teacher = AutoModelForCausalLM.from_pretrained(
    teacher_name,
    # We are quantizing to just 4bits to manage memory wisely
    quantization_config=quantization_config,
    device_map="auto"
)
teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_name)


# --------------------------
# KD Helpers
# --------------------------

# layer:22 for Qwen
# layer: 25 for StarCoder

def get_teacher_hidden(input_ids, attention_mask, layer_idx=25):
    with torch.no_grad():
        outputs = teacher(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            use_cache=False,
            return_dict=True,
        )
        h = outputs.hidden_states[layer_idx + 1]   # take hidden state at layer
        h = teacher.model.norm(h)                  # apply final norm

        # GPT-2 is in float32. Qwen is in float16. So the results of Qwen should be upcast
        h = h.to(torch.float32)  # upcast so it matches GPT-2's dtype

    return h

"""
Here we basically try to project the hidden state layer of teacher model into a similarly positioned hidden layer of the student model. Why?

Because in this way the dimensions match. We use a nn.Linear() layer for this purpose which is a trainable layer.

Why aren't we passing the teacher output through its lm_head and computing the loss between the students output through its lm_head?

Because in that approach, the final vocab size of the student model is different (smaller) from the teacher model and a projection (or a simpler cropping) leads to information loss. Hence we do this approach
"""

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

"\nHere we basically try to project the hidden state layer of teacher model into a similarly positioned hidden layer of the student model. Why?\n\nBecause in this way the dimensions match. We use a nn.Linear() layer for this purpose which is a trainable layer.\n\nWhy aren't we passing the teacher output through its lm_head and computing the loss between the students output through its lm_head?\n\nBecause in that approach, the final vocab size of the student model is different (smaller) from the teacher model and a projection (or a simpler cropping) leads to information loss. Hence we do this approach\n"

In [7]:
"""
Sample testing on student before training
"""

# --------------------------
# Validation
# --------------------------

val_loss = 0
with torch.no_grad():
    for batch in val_loader:
        questions = batch["question"]
        sqls = batch["sql"]

        # Same concatenation as training
        student_texts = [f"Question: {q}\nSQL: {s}" for q, s in zip(questions, sqls)]

        enc_student = student_tokenizer(
            student_texts,
            return_tensors="pt",
            padding=True,
            truncation=True
        ).to("cuda")

        labels = enc_student["input_ids"].clone()

        question_length = [len(student_tokenizer(f"Question: {q}\nSQL: ", return_tensors="pt")["input_ids"][0]) for q in questions]

        # Masking the user question
        for i, q_len in enumerate(question_length):
          labels[i, :q_len] = -100

        enc_student["labels"] = labels

        outputs = student(**enc_student)
        val_loss += outputs.loss.item()

val_loss /= len(val_loader)
print(f"Loss: {val_loss:.4f}")
print()


# --------------------------
# Custom Testing
# --------------------------
student.eval()

custom_questions = [
    "Show me the total number of passengers in 2024",
    "List all flights departing from Delhi",
    "What was the average ticket price last month?"
]

with torch.no_grad():
    for q in custom_questions:
        # Format same as training (question prompt only, no SQL appended)
        input_text = f"Question: {q}\nSQL:"

        enc = student_tokenizer(
            input_text,
            return_tensors="pt"
        ).to("cuda")

        # Generate SQL (autoregressive generation)
        generated_ids = student.generate(
            **enc,
            max_length=128,         # limit for SQL length
            num_beams=5,            # beam search for better results
            early_stopping=True,
            pad_token_id=student_tokenizer.eos_token_id
        )

        output_text = student_tokenizer.decode(
            generated_ids[0],
            skip_special_tokens=True
        )

        # Strip input part, keep only generated SQL
        if "SQL:" in output_text:
            sql_pred = output_text.split("SQL:")[-1].strip()
        else:
            sql_pred = output_text.strip()

        print(f"\n❓ Question: {q}")
        print(f"📝 Predicted SQL: {sql_pred}")


`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.


Loss: 5.3911


❓ Question: Show me the total number of passengers in 2024
📝 Predicted SQL: 1,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000

❓ Question: List all flights departing from Delhi
📝 Predicted SQL: All flights departing from Delhi

❓ Question: What was the average ticket price last month?
📝 Predicted SQL: The average ticket price last month was $9.99.
The average ticket price last month was $9.99. The average ticket price last month was $9.99. The average ticket price last month was $9.99. The average ticket price last month was $9.99. The average ticket price last month was $9.99. The average ticket price last month was $9.99. The average ticket price last month was $9.99. The average ticket price last month was $9.99. The average ticket price last


In [9]:
"""
Sample testing on teacher before training
"""

# --------------------------
# Custom Testing - Teacher
# --------------------------
teacher.eval()

import random

# Pick a random batch
random_batch = random.choice(list(train_loader))

questions = random_batch["question"]
sqls = random_batch["sql"]


with torch.no_grad():
    for q, s in zip(questions, sqls):
        # Teacher doesn't need "SQL:" tag explicitly (since it's instruction tuned),
        # but we can format prompts consistently if we want
        input_text = teacher_schema_prompt + f"{q}\n Generated SQL: "

        enc = teacher_tokenizer(
            input_text,
            return_tensors="pt"
        ).to("cuda")

        # Generate without fixed max_length
        generated_ids = teacher.generate(
            **enc,
            num_beams=5,                # beam search for more precise outputs
            do_sample=False,            # deterministic (greedy+beam)
            early_stopping=True,
            pad_token_id=teacher_tokenizer.eos_token_id
        )

        output_text = teacher_tokenizer.decode(
            generated_ids[0],
            skip_special_tokens=True
        )

        # # Extract SQL portion
        # if "SQL:" in output_text:
        #     sql_pred = output_text.split("SQL:")[-1].strip()
        # else:
        sql_pred = output_text.strip()

        # print(f"\n❓ Question: {q}")
        print(f"\n📝 Teacher SQL Prediction: {sql_pred[sql_pred.find(q):]}")
        print(f"✅ Actual SQL: ", s)



📝 Teacher SQL Prediction: Group customers by city
 Generated SQL:  SELECT `city`, COUNT(`customer_id`) AS `customer_count` FROM `customers` GROUP BY
✅ Actual SQL:  SELECT city, COUNT(*) FROM customers GROUP BY city;

📝 Teacher SQL Prediction: Find recent orders from past month
 Generated SQL:  SELECT * FROM orders WHERE order_date >= CURDATE() - INTERVAL 1 MONTH;
✅ Actual SQL:  SELECT * FROM orders WHERE order_date >= CURRENT_DATE - INTERVAL '30 days';

📝 Teacher SQL Prediction: Show orders from last month with high value
 Generated SQL:  SELECT * FROM orders WHERE order_date BETWEEN DATE_SUB(CURDATE(), INTERVAL 1 MONTH) AND
✅ Actual SQL:  SELECT * FROM orders WHERE EXTRACT(MONTH FROM order_date) = EXTRACT(MONTH FROM CURRENT_DATE - INTERVAL '1 month') AND total_amount > 1000;

📝 Teacher SQL Prediction: Find the 3 highest priced products
 Generated SQL:  SELECT product_id, product_name, price FROM products ORDER BY price DESC LIMIT 3;
✅ Actual SQL:  SELECT product_name, price FROM prod

In [10]:
# Projection: teacher hidden → student hidden space
projection = nn.Linear(teacher.config.hidden_size, student.config.hidden_size).to("cuda")

In [11]:
# Defining the loss functions for the same

criterion_ce = nn.CrossEntropyLoss(ignore_index=student_tokenizer.pad_token_id)
criterion_kl = nn.KLDivLoss(reduction="batchmean")

temperature = 2.0
alpha = 0.75  # balance between CE and KD

optimizer = AdamW(student.parameters(), lr=5e-5)


In [12]:
# --------------------------
# Training Loop
# --------------------------
EPOCHS = 20
for epoch in range(EPOCHS):
    student.train()
    projection.train()
    total_loss = 0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1} Training"):
        # Batch has dict: {"question": [...], "sql": [...]}
        questions = batch["question"]
        sqls = batch["sql"]


        # Adding the schema prompt for each teacher_qs query
        questions_with_info = [teacher_schema_prompt + f"{q}\n Generated SQL: "   for q in batch['question']]


        # Tokenize input (questions) for teacher (only input, teacher not trained on SQL)
        enc_teacher = teacher_tokenizer(
            questions_with_info,
            return_tensors="pt",
            padding=True,
            truncation=True
        ).to("cuda")

        # Combine question + SQL for GPT-2 causal LM training (Also including schema details to use GPT-2's inherent abilities to perform the task)
        student_texts = [f"Generate a SQL query for the following question:\nQuestion: {q}\n Generated SQL: {s}" for q, s in zip(questions, sqls)]

        enc_student = student_tokenizer(
            student_texts,
            return_tensors="pt",
            padding=True,
            truncation=True
        ).to("cuda")

        labels = enc_student["input_ids"].clone()

        """
        Here we are masking the inputs to GPT-2. WHY?

        GPT2 acts like a LSTM model that takes the full qs and answer seqence and does token predicting. However when we compute the loss, we dont need to accomodate for the qs predicting part as in validation a question will always be given

        So we mask the question with "-100" value in the loss function and only penalize errors in SQL prediction

        This ultimately gives a huge boost to the accuracy
        """

        question_length = [len(student_tokenizer(f"Question: {q}\nSQL: ", return_tensors="pt")["input_ids"][0]) for q in questions]

        for i, q_len in enumerate(question_length):
            labels[i, :q_len] = -100

        enc_student["labels"] = labels

        # Forward pass student
        student_outputs = student(**enc_student, output_hidden_states=True)
        student_logits = student_outputs.logits
        student_hidden = student_outputs.hidden_states[-1]  # final hidden of GPT-2

        # Teacher hidden (intermediate layer)
        teacher_hidden = get_teacher_hidden(
            enc_teacher["input_ids"], enc_teacher["attention_mask"], layer_idx=22
        )

        # Project teacher hidden into student space
        teacher_hidden_proj = projection(teacher_hidden)

        # --------------------------
        # KD Losses
        # --------------------------

        # Mean pool over sequence length (As we are inputting qs+answer for GPT-2 and just qs for Qwen we will - Mean-pool across sequence.) ie; compare representations instead of every token

        student_pooled = student_hidden.mean(dim=1)         # [batch, hidden]
        teacher_pooled = teacher_hidden_proj.mean(dim=1)    # [batch, hidden]


        # 1) Hidden-state MSE loss (align teacher + student reps)
        loss_mse = nn.functional.mse_loss(student_pooled, teacher_pooled)

        # 2) CE loss (hard targets from SQL labels)
        shift_logits = student_logits[..., :-1, :].contiguous()
        shift_labels = enc_student["labels"][..., 1:].contiguous()
        loss_ce = student_outputs.loss

        # Combined loss
        loss = alpha * loss_ce + (1 - alpha) * loss_mse
        total_loss += loss.item()

        # Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1} Training Loss: {avg_loss:.4f}")

    # --------------------------
    # Validation
    # --------------------------
    student.eval()
    projection.eval()
    val_loss = 0
    with torch.no_grad():
        for batch in val_loader:
            questions = batch["question"]
            sqls = batch["sql"]

            # Same concatenation as training
            student_texts = [f"Question: {q}\nSQL: {s}" for q, s in zip(questions, sqls)]

            enc_student = student_tokenizer(
                student_texts,
                return_tensors="pt",
                padding=True,
                truncation=True
            ).to("cuda")

            labels = enc_student["input_ids"].clone()

            question_length = [len(student_tokenizer(f"Question: {q}\nSQL: ", return_tensors="pt")["input_ids"][0]) for q in questions]

            # Masking the user question
            for i, q_len in enumerate(question_length):
              labels[i, :q_len] = -100

            enc_student["labels"] = labels

            outputs = student(**enc_student)
            val_loss += outputs.loss.item()

    val_loss /= len(val_loader)
    print(f"Epoch {epoch+1} Validation Loss: {val_loss:.4f}")
    print()


Epoch 1 Training: 100%|██████████| 11/11 [00:54<00:00,  4.96s/it]


Epoch 1 Training Loss: 5.1813
Epoch 1 Validation Loss: 1.3835



Epoch 2 Training: 100%|██████████| 11/11 [00:52<00:00,  4.73s/it]


Epoch 2 Training Loss: 1.8320
Epoch 2 Validation Loss: 1.1458



Epoch 3 Training: 100%|██████████| 11/11 [00:52<00:00,  4.79s/it]


Epoch 3 Training Loss: 1.5389
Epoch 3 Validation Loss: 0.9986



Epoch 4 Training: 100%|██████████| 11/11 [00:52<00:00,  4.78s/it]


Epoch 4 Training Loss: 1.2913
Epoch 4 Validation Loss: 0.8648



Epoch 5 Training: 100%|██████████| 11/11 [00:52<00:00,  4.77s/it]


Epoch 5 Training Loss: 1.1745
Epoch 5 Validation Loss: 0.7578



Epoch 6 Training: 100%|██████████| 11/11 [00:52<00:00,  4.76s/it]


Epoch 6 Training Loss: 1.0736
Epoch 6 Validation Loss: 0.6621



Epoch 7 Training: 100%|██████████| 11/11 [00:52<00:00,  4.78s/it]


Epoch 7 Training Loss: 0.9443
Epoch 7 Validation Loss: 0.5782



Epoch 8 Training: 100%|██████████| 11/11 [00:52<00:00,  4.79s/it]


Epoch 8 Training Loss: 0.8563
Epoch 8 Validation Loss: 0.5108



Epoch 9 Training: 100%|██████████| 11/11 [00:52<00:00,  4.76s/it]


Epoch 9 Training Loss: 0.8275
Epoch 9 Validation Loss: 0.4593



Epoch 10 Training: 100%|██████████| 11/11 [00:52<00:00,  4.77s/it]


Epoch 10 Training Loss: 0.7604
Epoch 10 Validation Loss: 0.4186



Epoch 11 Training: 100%|██████████| 11/11 [00:52<00:00,  4.76s/it]


Epoch 11 Training Loss: 0.7186
Epoch 11 Validation Loss: 0.3820



Epoch 12 Training: 100%|██████████| 11/11 [00:52<00:00,  4.76s/it]


Epoch 12 Training Loss: 0.7056
Epoch 12 Validation Loss: 0.3466



Epoch 13 Training: 100%|██████████| 11/11 [00:52<00:00,  4.78s/it]


Epoch 13 Training Loss: 0.6744
Epoch 13 Validation Loss: 0.3196



Epoch 14 Training: 100%|██████████| 11/11 [00:52<00:00,  4.77s/it]


Epoch 14 Training Loss: 0.6365
Epoch 14 Validation Loss: 0.2970



Epoch 15 Training: 100%|██████████| 11/11 [00:52<00:00,  4.78s/it]


Epoch 15 Training Loss: 0.6154
Epoch 15 Validation Loss: 0.2719



Epoch 16 Training: 100%|██████████| 11/11 [00:52<00:00,  4.77s/it]


Epoch 16 Training Loss: 0.5994
Epoch 16 Validation Loss: 0.2529



Epoch 17 Training: 100%|██████████| 11/11 [00:52<00:00,  4.78s/it]


Epoch 17 Training Loss: 0.5764
Epoch 17 Validation Loss: 0.2397



Epoch 18 Training: 100%|██████████| 11/11 [00:52<00:00,  4.78s/it]


Epoch 18 Training Loss: 0.5221
Epoch 18 Validation Loss: 0.2352



Epoch 19 Training: 100%|██████████| 11/11 [00:52<00:00,  4.76s/it]


Epoch 19 Training Loss: 0.4923
Epoch 19 Validation Loss: 0.2220



Epoch 20 Training: 100%|██████████| 11/11 [00:52<00:00,  4.77s/it]


Epoch 20 Training Loss: 0.4654
Epoch 20 Validation Loss: 0.2162



In [19]:
# --------------------------
# Custom Testing
# --------------------------
student.eval()

import random

# Pick a random batch
random_batch = random.choice(list(train_loader))

questions = random_batch["question"]
sqls = random_batch["sql"]

# for q, s in zip(questions, sqls):
#     print(f"❓ Question: {q}")
#     print(f"✅ SQL: {s}")

teacher_preds = []
student_preds = []


with torch.no_grad():
    for q, s in zip(questions, sqls):
        """
        For student
        """

        # Format same as training (question prompt only, no SQL appended)
        input_text = f"Generate a SQL query for the following question:\nQuestion: {q}\n Generated SQL: "

        enc = student_tokenizer(
            input_text,
            return_tensors="pt"
        ).to("cuda")

        # Generate SQL (autoregressive generation)
        generated_ids = student.generate(
            **enc,
            max_length=128,         # limit for SQL length
            num_beams=5,            # beam search for better results
            early_stopping=True,
            pad_token_id=student_tokenizer.eos_token_id
        )

        output_text = student_tokenizer.decode(
            generated_ids[0],
            skip_special_tokens=True
        )

        # # Strip input part, keep only generated SQL
        # if "SQL:" in output_text:
        #     sql_pred = output_text.split("SQL:")[-1].strip()
        # else:
        student_sql_pred = output_text.strip()

        input_text = teacher_schema_prompt + f"{q}\n Generated SQL: "

        enc = teacher_tokenizer(
            input_text,
            return_tensors="pt"
        ).to("cuda")

        # Generate without fixed max_length
        generated_ids = teacher.generate(
            **enc,
            num_beams=5,                # beam search for more precise outputs
            do_sample=False,            # deterministic (greedy+beam)
            early_stopping=True,
            pad_token_id=teacher_tokenizer.eos_token_id
        )

        output_text = teacher_tokenizer.decode(
            generated_ids[0],
            skip_special_tokens=True
        )

        # # Extract SQL portion
        # if "SQL:" in output_text:
        #     sql_pred = output_text.split("SQL:")[-1].strip()
        # else:
        teacher_sql_pred = output_text.strip()

        teacher_preds.append(teacher_sql_pred[teacher_sql_pred.find("\n Generated SQL: ")+len("\n Generated SQL: "):])
        student_preds.append(student_sql_pred[student_sql_pred.find("\n Generated SQL: ")+len("\n Generated SQL: "):])

        print(f"\n\n❓ Question: {q}")
        print(f"✅ Actual SQL: {s}")
        print(f"\n📝 Teacher SQL Prediction: {teacher_sql_pred[teacher_sql_pred.find("\n Generated SQL: ")+len("\n Generated SQL: "):]}")


        print(f"\n📝 Predicted SQL by student: {student_sql_pred[student_sql_pred.find("\n Generated SQL: ")+len("\n Generated SQL: "):]}")




❓ Question: Show order details with customer information
✅ Actual SQL: SELECT o.order_id, o.total_amount, c.first_name, c.last_name, c.email FROM orders o JOIN customers c ON o.customer_id = c.customer_id;

📝 Teacher SQL Prediction:  SELECT orders.order_id, orders.order_date, orders.total_amount, orders.status, orders.shipping_address

📝 Predicted SQL by student: --------------- SELECT c.first_name, c.last_name, o.order_id, o.total_amount FROM customers c JOIN orders o ON c.customer_id = o.customer_id JOIN products p ON o.product_id = p.product_id;


❓ Question: Show products that have been ordered
✅ Actual SQL: SELECT DISTINCT p.product_name FROM products p WHERE p.product_id IN (SELECT DISTINCT product_id FROM order_items);

📝 Teacher SQL Prediction:  SELECT products.product_id, products.product_name FROM products INNER JOIN order_items ON products.product_id = order

📝 Predicted SQL by student: !!!SELECT * FROM products WHERE product_id NOT IN (SELECT DISTINCT product_id FROM orde

In [25]:
import pandas as pd

# Set max col width + enable wrap
pd.set_option('display.max_colwidth', None)   # don't truncate
pd.set_option('display.expand_frame_repr', False)  # don't wrap entire table weirdly

# Suppose you collect results in a loop
results = []

for q, s, teacher_sql_pred, student_sql_pred in zip(questions, sqls, teacher_preds, student_preds):
    results.append({
        "Question": q,
        "Actual SQL": s,
        "Student SQL": student_sql_pred[student_sql_pred.rfind('--')+2:],
        "Teacher SQL": teacher_sql_pred,
    })

# Create DataFrame
df = pd.DataFrame(results)

# Print as table
print(df.to_markdown(index=False))

display(df)

| Question                                     | Actual SQL                                                                                                                             | Student SQL                                                                                                                                                                  | Teacher SQL                                                                                                           |
|:---------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------|:-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:----------------------------------------------------------------------------------------------------------------------|
| Show order details with customer inf

Unnamed: 0,Question,Actual SQL,Student SQL,Teacher SQL
0,Show order details with customer information,"SELECT o.order_id, o.total_amount, c.first_name, c.last_name, c.email FROM orders o JOIN customers c ON o.customer_id = c.customer_id;","SELECT c.first_name, c.last_name, o.order_id, o.total_amount FROM customers c JOIN orders o ON c.customer_id = o.customer_id JOIN products p ON o.product_id = p.product_id;","SELECT orders.order_id, orders.order_date, orders.total_amount, orders.status, orders.shipping_address"
1,Show products that have been ordered,SELECT DISTINCT p.product_name FROM products p WHERE p.product_id IN (SELECT DISTINCT product_id FROM order_items);,!!SELECT * FROM products WHERE product_id NOT IN (SELECT DISTINCT product_id FROM order_items);,"SELECT products.product_id, products.product_name FROM products INNER JOIN order_items ON products.product_id = order"
2,Show orders from this year,SELECT * FROM orders WHERE EXTRACT(YEAR FROM order_date) = EXTRACT(YEAR FROM CURRENT_DATE);,SELECT * FROM orders WHERE EXTRACT(YEAR FROM order_date) = EXTRACT(YEAR FROM CURRENT_DATE);,SELECT * FROM orders WHERE YEAR(order_date) = YEAR(CURRENT_DATE);
3,Find the highest priced product,"SELECT product_name, price FROM products WHERE price = (SELECT MAX(price) FROM products);","SELECT product_name, price FROM products WHERE price > (SELECT MAX(price) FROM products);","SELECT `product_id`, `product_name`, `price` FROM `products` ORDER BY `price"
4,Show products in Electronics category,SELECT * FROM products WHERE category = 'Electronics';,SELECT * FROM products WHERE category = 'Electronics';,SELECT * FROM products WHERE category = 'Electronics';
5,List customers and their order details,"SELECT c.first_name, c.last_name, o.order_id, o.total_amount FROM customers c JOIN orders o ON c.customer_id = o.customer_id;","SELECT c.first_name, c.last_name, o.order_id, o.total_amount FROM customers c JOIN orders o ON c.customer_id = o.customer_id;","SELECT c.first_name, c.last_name, o.order_id, o.order_date, o.total_amount"
6,Show number of customers by city,"SELECT city, COUNT(*) FROM customers GROUP BY city;","SELECT city, COUNT(*) FROM customers GROUP BY city;","SELECT `city`, COUNT(`customer_id`) AS `customer_count` FROM `customers` GROUP BY"
7,List all customers,SELECT * FROM customers;,"SELECT c.first_name, c.last_name, o.order_id, o.total_amount FROM customers c JOIN orders o ON c.customer_id = o.customer_id JOIN products p ON o.product_id = p.product_id;",SELECT * FROM customers;
