# Imports

In [1]:
import pandas as pd
import torch
import datasets
import asyncio
from tqdm import tqdm
from typing import Callable
import re
import math
import random

import os
import psutil
import GPUtil
import gc

from groq import AsyncGroq, RateLimitError

import _config

In [2]:
ENABLE_THINKING = False

In [3]:
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

os.environ["WANDB_API_KEY"] = _config.WANDB_API_KEY
os.environ["WANDB_PROJECT"] = _config.WANDB_PROJECT
os.environ["GROQ_API_KEY"] = _config.GROQ_API_KEY

# Utils

In [4]:
def get_vm_usage_metrics():
    # CPU usage
    cpu_load = psutil.cpu_percent(interval=1, percpu=True)
    for id, load in enumerate(cpu_load):
        print(f"CPU {id} load: {load:.2f}")
    # RAM usage
    ram = psutil.virtual_memory()
    print(f"RAM Total: {ram.total/(1024**3):.2f} GB, Used: {(ram.used)/(1024**3):.2f} GB")
    # GPU
    if torch.cuda.is_available():
        gpus = GPUtil.getGPUs()
        for gpu in gpus:
            print(f"GPU {gpu.id} ({gpu.name}) load: {gpu.load*100}%")
            print(f"GPU {gpu.id} ({gpu.name}) VRAM Total: {gpu.memoryTotal} MB, Used {gpu.memoryUsed} MB")
    # Disk 
    disk = psutil.disk_usage('/')
    print(f"Disk Total: {disk.total/(1024**3):.2f} GB, Used: {(disk.used)/(1024**3):.2f} GB")

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f'Device: {device}')
get_vm_usage_metrics()

Device: cpu
CPU 0 load: 1.00
RAM Total: 1.86 GB, Used: 0.84 GB
Disk Total: 28.02 GB, Used: 19.70 GB


# Data

In [5]:
ds = datasets.load_dataset('gretelai/synthetic_text_to_sql', streaming=False)
ds_train, ds_test = ds['train'], ds['test']

split = ds_train.train_test_split(test_size=0.025, seed=42)
ds_train = split['train']
ds_valid = split['test']

get_vm_usage_metrics()
ds_train

CPU 0 load: 1.00
RAM Total: 1.86 GB, Used: 0.84 GB
Disk Total: 28.02 GB, Used: 19.70 GB


Dataset({
    features: ['id', 'domain', 'domain_description', 'sql_complexity', 'sql_complexity_description', 'sql_task_type', 'sql_task_type_description', 'sql_prompt', 'sql_context', 'sql', 'sql_explanation'],
    num_rows: 97500
})

# Data generation

In [6]:
SYS_PROMPT = """You are a text-to-SQL data generator. You will be provided with a user prompt (sql_prompt), 
the query context (sql_context), and the correct SQL query (sql) that answers the user's question.
Your task is to generate an alternative SQL query that is worse than the provided SQL. Your query must 
contain at least one error - this is very important. Do not return the original query.
Return only the changed query, do not name it, do not explain it. Remeber to add at least one error.

SQL_PROMPT: "{sql_prompt}"
SQL_CONTEXT: "{sql_context}"
SQL: "{sql}"
"""

models_available = [
    'llama-3.1-8b-instant',
    'llama-3.3-70b-versatile',
    'meta-llama/llama-4-maverick-17b-128e-instruct',
    'meta-llama/llama-4-scout-17b-16e-instruct',
    'moonshotai/kimi-k2-instruct',
    'moonshotai/kimi-k2-instruct-0905',
    'openai/gpt-oss-120b',
    'openai/gpt-oss-20b',
    'qwen/qwen3-32b'
]

In [7]:
N_TASKS = len(ds_train)
N_CONCURRENT_TASKS = 10

client = AsyncGroq(api_key=os.environ.get("GROQ_API_KEY"))
api_semaphore = asyncio.Semaphore(N_CONCURRENT_TASKS)


async def create_chat_completion(prompt_args, callback: Callable = None, client=client, base_delay=5, max_retries=50):
    if not models_available:
        raise Exception("No models available")
        
    for attempt in range(1, max_retries + 1):
        if not models_available:
            break
            
        async with api_semaphore:
            try:
                model = random.choice(models_available)
                prompt = SYS_PROMPT.format(
                    sql_prompt=prompt_args['sql_prompt'],
                    sql_context=prompt_args['sql_context'],
                    sql=prompt_args['sql']
                )
                chat_completion = await client.chat.completions.create(
                    messages=[
                        {
                            "role": "system",
                            "content": prompt,
                        }
                    ],
                    model=model,
                )
                if chat_completion.choices[0].message.content != prompt_args['sql']:
                    if callback:
                        callback()
                    return {
                        'sql_prompt': prompt_args['sql_prompt'],
                        'sql_context': prompt_args['sql_context'],
                        'sql': prompt_args['sql'],
                        'model_used': model,
                        'completion': chat_completion.choices[0].message.content
                    }
                
            except RateLimitError as e:
                msg = str(e)
                if ("TPD" in msg) or ("RPD" in msg):
                    if model in models_available:
                        models_available.remove(model)
                        print(f"Model {model} is no longer available ({e})", flush=True)
                else:    
                    wait_time = base_delay * (2 ** (attempt - 1))
                    retry_after_match = re.search(r'Please try again in (\d+(?:\.\d+)?)s', msg)
                    if retry_after_match:
                        wait_time = float(retry_after_match.group(1))
                    # print(f'RateLimitError: attempt {attempt} out of {max_retries}, waiting {wait_time}s...', flush=True)
                    await asyncio.sleep(wait_time)
                
            except Exception as e:
                print(f'Other error: {e}', flush=True)
                
    # print(f'Exceeded maximum number of retries ({max_retries}) or no models available', flush=True)
    raise Exception(f'Exceeded maximum number of retries ({max_retries}) or no models available')



async def worker(worker_id, task_queue, results, pbar, stop_event):
    """Worker that processes tasks from the queue"""
    while not task_queue.empty() and not stop_event.is_set():
        try:
            try:
                prompt_args = await asyncio.wait_for(task_queue.get(), timeout=1.0)
            except asyncio.TimeoutError: # check stop_event
                continue
            
            if not models_available:
                print(f"Worker {worker_id}: No models available. Stopping.", flush=True)
                task_queue.put_nowait(prompt_args)
                stop_event.set()
                break
            
            try:
                result = await create_chat_completion(
                    prompt_args=prompt_args,
                    callback=lambda: pbar.update(1)
                )
                results.append(result)
                
            except Exception as e:
                print(f"Worker {worker_id}: Task failed with error: {e}", flush=True)
                
                if "No models available" in str(e):
                    print(f"Worker {worker_id}: Stopping due to no models available", flush=True)
                    task_queue.put_nowait(prompt_args)
                    stop_event.set()
                    break
                    
            finally:
                task_queue.task_done()
                
        except Exception as e:
            print(f"Worker {worker_id}: Unexpected error: {e}", flush=True)

    

async def main(dataset):
    """Main execution using queue-based approach"""
    if not models_available:
        print("ERROR: No models available at start. Cannot run any tasks.", flush=True)
        return []
    
    with tqdm(total=N_TASKS, desc='Generating samples') as pbar:
        task_queue = asyncio.Queue()
        results = []
        stop_event = asyncio.Event()
        
        for i in range(N_TASKS):
            await task_queue.put({
                'sql_prompt': dataset[i]['sql_prompt'],
                'sql_context': dataset[i]['sql_context'],
                'sql': dataset[i]['sql']
            })
        
        workers = []
        for i in range(N_CONCURRENT_TASKS):
            worker_task = asyncio.create_task(
                worker(i, task_queue, results, pbar, stop_event)
            )
            workers.append(worker_task)
        
        try:
            # Wait for queue to be empty OR stop event is set
            while not task_queue.empty() and not stop_event.is_set():
                # Check queue size periodically
                await asyncio.sleep(1)
                
                # Check if models are still available
                if not models_available:
                    print("No models available, stopping workers...", flush=True)
                    stop_event.set()
                    break
            
            # Wait for queue to empty or stop event
            if not stop_event.is_set():
                await task_queue.join()
            else:
                print("Stop event triggered, waiting for workers to finish current tasks...", flush=True)
                # Give workers time to finish current tasks
                await asyncio.sleep(2)
            
        except Exception as e:
            print(f"Error in main loop: {e}", flush=True)
            stop_event.set()
        
        finally:
            # Cancel all worker tasks
            for w in workers:
                w.cancel()
            
            # Wait for workers to be cancelled
            if workers:
                await asyncio.gather(*workers, return_exceptions=True)
        
        # Final check for remaining tasks in queue
        remaining_tasks = task_queue.qsize()
        if remaining_tasks > 0:
            print(f"Warning: {remaining_tasks} tasks were not processed", flush=True)
            # Update progress bar to reflect actual progress
            pbar.total = N_TASKS - remaining_tasks
        
        return results


# Run the main function
results = await main(ds_train)

Generating samples:   3%|█▊                                                                 | 2671/97500 [14:18<12:10:33,  2.16it/s]

Model llama-3.3-70b-versatile is no longer available (Error code: 429 - {'error': {'message': 'Rate limit reached for model `llama-3.3-70b-versatile` in organization `org_01kep40x18enbt3npfx9edhkwf` service tier `on_demand` on tokens per day (TPD): Limit 100000, Used 99874, Requested 209. Please try again in 1m11.712s. Need more tokens? Upgrade to Dev Tier today at https://console.groq.com/settings/billing', 'type': 'tokens', 'code': 'rate_limit_exceeded'}})


Generating samples:   3%|█▉                                                                  | 2776/97500 [15:00<8:50:54,  2.97it/s]

Model openai/gpt-oss-20b is no longer available (Error code: 429 - {'error': {'message': 'Rate limit reached for model `openai/gpt-oss-20b` in organization `org_01kep40x18enbt3npfx9edhkwf` service tier `on_demand` on tokens per day (TPD): Limit 200000, Used 199713, Requested 309. Please try again in 9.504s. Need more tokens? Upgrade to Dev Tier today at https://console.groq.com/settings/billing', 'type': 'tokens', 'code': 'rate_limit_exceeded'}})


Generating samples:   4%|██▍                                                                | 3531/97500 [20:06<10:15:12,  2.55it/s]

Model openai/gpt-oss-120b is no longer available (Error code: 429 - {'error': {'message': 'Rate limit reached for model `openai/gpt-oss-120b` in organization `org_01kep40x18enbt3npfx9edhkwf` service tier `on_demand` on tokens per day (TPD): Limit 200000, Used 199759, Requested 481. Please try again in 1m43.68s. Need more tokens? Upgrade to Dev Tier today at https://console.groq.com/settings/billing', 'type': 'tokens', 'code': 'rate_limit_exceeded'}})


Generating samples:   4%|██▋                                                                | 3829/97500 [22:20<15:34:45,  1.67it/s]

Other error: Error code: 503 - {'error': {'message': 'meta-llama/llama-4-maverick-17b-128e-instruct is currently over capacity. Please try again and back off exponentially. Visit https://groqstatus.com to see if there is an active incident.', 'type': 'internal_server_error'}}


Generating samples:   4%|██▋                                                                | 3999/97500 [23:34<14:06:39,  1.84it/s]

Model qwen/qwen3-32b is no longer available (Error code: 429 - {'error': {'message': 'Rate limit reached for model `qwen/qwen3-32b` in organization `org_01kep40x18enbt3npfx9edhkwf` service tier `on_demand` on tokens per day (TPD): Limit 500000, Used 499905, Requested 475. Please try again in 1m5.664s. Need more tokens? Upgrade to Dev Tier today at https://console.groq.com/settings/billing', 'type': 'tokens', 'code': 'rate_limit_exceeded'}})


Generating samples:   6%|███▋                                                               | 5424/97500 [33:57<16:32:13,  1.55it/s]

Model moonshotai/kimi-k2-instruct is no longer available (Error code: 429 - {'error': {'message': 'Rate limit reached for model `moonshotai/kimi-k2-instruct` in organization `org_01kep40x18enbt3npfx9edhkwf` service tier `on_demand` on requests per day (RPD): Limit 1000, Used 1000, Requested 1. Please try again in 1m26.4s. Need more tokens? Upgrade to Dev Tier today at https://console.groq.com/settings/billing', 'type': 'requests', 'code': 'rate_limit_exceeded'}})


Generating samples:   6%|███▋                                                               | 5430/97500 [34:00<12:34:47,  2.03it/s]

Model moonshotai/kimi-k2-instruct-0905 is no longer available (Error code: 429 - {'error': {'message': 'Rate limit reached for model `moonshotai/kimi-k2-instruct-0905` in organization `org_01kep40x18enbt3npfx9edhkwf` service tier `on_demand` on requests per day (RPD): Limit 1000, Used 1000, Requested 1. Please try again in 1m26.4s. Need more tokens? Upgrade to Dev Tier today at https://console.groq.com/settings/billing', 'type': 'requests', 'code': 'rate_limit_exceeded'}})


Generating samples:   6%|████                                                               | 5980/97500 [42:12<29:06:54,  1.15s/it]

Model meta-llama/llama-4-scout-17b-16e-instruct is no longer available (Error code: 429 - {'error': {'message': 'Rate limit reached for model `meta-llama/llama-4-scout-17b-16e-instruct` in organization `org_01kep40x18enbt3npfx9edhkwf` service tier `on_demand` on requests per day (RPD): Limit 1000, Used 1000, Requested 1. Please try again in 1m26.4s. Need more tokens? Upgrade to Dev Tier today at https://console.groq.com/settings/billing', 'type': 'requests', 'code': 'rate_limit_exceeded'}})


Generating samples:   6%|████▏                                                              | 6175/97500 [46:59<40:36:15,  1.60s/it]

Model meta-llama/llama-4-maverick-17b-128e-instruct is no longer available (Error code: 429 - {'error': {'message': 'Rate limit reached for model `meta-llama/llama-4-maverick-17b-128e-instruct` in organization `org_01kep40x18enbt3npfx9edhkwf` service tier `on_demand` on requests per day (RPD): Limit 1000, Used 1000, Requested 1. Please try again in 1m26.4s. Need more tokens? Upgrade to Dev Tier today at https://console.groq.com/settings/billing', 'type': 'requests', 'code': 'rate_limit_exceeded'}})


Generating samples:   7%|████▍                                                              | 6411/97500 [59:06<82:01:02,  3.24s/it]

Worker 6: Task failed with error: Exceeded maximum number of retries (50) or no models available


Generating samples:   7%|████▍                                                              | 6416/97500 [59:23<81:51:21,  3.24s/it]

Worker 9: Task failed with error: Exceeded maximum number of retries (50) or no models available


Generating samples:   7%|████▎                                                            | 6431/97500 [1:00:07<71:22:33,  2.82s/it]

Worker 3: Task failed with error: Exceeded maximum number of retries (50) or no models available


Generating samples:   7%|████▎                                                            | 6501/97500 [1:03:47<78:55:07,  3.12s/it]

Worker 1: Task failed with error: Exceeded maximum number of retries (50) or no models available


Generating samples:   7%|████▎                                                            | 6503/97500 [1:03:52<74:24:07,  2.94s/it]

Worker 7: Task failed with error: Exceeded maximum number of retries (50) or no models available


Generating samples:   7%|████▎                                                            | 6554/97500 [1:06:28<74:09:23,  2.94s/it]

Worker 2: Task failed with error: Exceeded maximum number of retries (50) or no models available


Generating samples:   7%|████▍                                                            | 6563/97500 [1:06:59<85:07:29,  3.37s/it]

Worker 4: Task failed with error: Exceeded maximum number of retries (50) or no models available


Generating samples:   7%|████▍                                                            | 6607/97500 [1:09:09<79:41:55,  3.16s/it]

Worker 3: Task failed with error: Exceeded maximum number of retries (50) or no models available


Generating samples:   7%|████▍                                                            | 6633/97500 [1:10:27<78:31:07,  3.11s/it]

Worker 9: Task failed with error: Exceeded maximum number of retries (50) or no models available


Generating samples:   7%|████▍                                                            | 6660/97500 [1:11:45<74:04:29,  2.94s/it]

Worker 5: Task failed with error: Exceeded maximum number of retries (50) or no models available


Generating samples:   7%|████▍                                                            | 6691/97500 [1:13:18<72:45:06,  2.88s/it]

Worker 0: Task failed with error: Exceeded maximum number of retries (50) or no models available


Generating samples:   7%|████▍                                                            | 6704/97500 [1:13:57<74:43:44,  2.96s/it]

Worker 2: Task failed with error: Exceeded maximum number of retries (50) or no models available


Generating samples:   7%|████▍                                                            | 6719/97500 [1:14:41<82:31:37,  3.27s/it]

Worker 1: Task failed with error: Exceeded maximum number of retries (50) or no models available


Generating samples:   7%|████▍                                                            | 6748/97500 [1:16:11<76:44:16,  3.04s/it]

Worker 4: Task failed with error: Exceeded maximum number of retries (50) or no models available


Generating samples:   7%|████▌                                                            | 6788/97500 [1:18:15<82:36:07,  3.28s/it]

Worker 8: Task failed with error: Exceeded maximum number of retries (50) or no models available


Generating samples:   7%|████▌                                                            | 6847/97500 [1:21:15<75:22:48,  2.99s/it]

Worker 9: Task failed with error: Exceeded maximum number of retries (50) or no models available


Generating samples:   7%|████▌                                                            | 6874/97500 [1:22:32<69:04:49,  2.74s/it]

Worker 0: Task failed with error: Exceeded maximum number of retries (50) or no models available


Generating samples:   7%|████▌                                                            | 6883/97500 [1:22:56<72:39:55,  2.89s/it]

Worker 1: Task failed with error: Exceeded maximum number of retries (50) or no models available


Generating samples:   7%|████▌                                                            | 6930/97500 [1:25:16<71:40:19,  2.85s/it]

Worker 5: Task failed with error: Exceeded maximum number of retries (50) or no models available


Generating samples:   7%|████▌                                                            | 6933/97500 [1:25:24<69:44:47,  2.77s/it]

Worker 6: Task failed with error: Exceeded maximum number of retries (50) or no models available
Worker 3: Task failed with error: Exceeded maximum number of retries (50) or no models available


Generating samples:   7%|████▋                                                            | 6942/97500 [1:25:52<75:52:50,  3.02s/it]

Model llama-3.1-8b-instant is no longer available (Error code: 429 - {'error': {'message': 'Rate limit reached for model `llama-3.1-8b-instant` in organization `org_01kep40x18enbt3npfx9edhkwf` service tier `on_demand` on tokens per day (TPD): Limit 500000, Used 499734, Requested 346. Please try again in 13.824s. Need more tokens? Upgrade to Dev Tier today at https://console.groq.com/settings/billing', 'type': 'tokens', 'code': 'rate_limit_exceeded'}})
Worker 2: Task failed with error: Exceeded maximum number of retries (50) or no models available
Worker 2: No models available. Stopping.
No models available, stopping workers...
Stop event triggered, waiting for workers to finish current tasks...
Worker 0: Task failed with error: Exceeded maximum number of retries (50) or no models available
Worker 9: Task failed with error: Exceeded maximum number of retries (50) or no models available
Worker 3: Task failed with error: Exceeded maximum number of retries (50) or no models available
Worke

Generating samples:   7%|████▋                                                            | 6943/97500 [1:25:54<69:58:43,  2.78s/it]



Generating samples: 100%|████████████████████████████████████████████████████████████████████▋| 6943/6973 [1:25:54<00:22,  1.35it/s]


In [8]:
results_df = pd.DataFrame({
    'sql_prompt': [result['sql_prompt'] for result in results],
    'sql_context': [result['sql_context'] for result in results],
    'sql': [result['sql'] for result in results],
    'model_used': [result['model_used'] for result in results],
    'completion': [result['completion'] for result in results]
})
print(results_df.shape)
results_df

(6943, 5)


Unnamed: 0,sql_prompt,sql_context,sql,model_used,completion
0,What is the average moisture level for each cr...,"CREATE TABLE crop_moisture (id INT, crop_id IN...","SELECT type, AVG(moisture) as avg_moisture FRO...",meta-llama/llama-4-maverick-17b-128e-instruct,"SELECT type, moisture as avg_moisture FROM cro..."
1,Add a new job title called 'Data Science Manag...,CREATE TABLE JobTitle (JobTitleID INT PRIMARY ...,"INSERT INTO JobTitle (JobTitleID, JobTitleName...",meta-llama/llama-4-maverick-17b-128e-instruct,"INSERT INTO JobTitel (JobTitleID, JobTitleName..."
2,What is the total number of military equipment...,CREATE TABLE MaintenanceRequests (RequestID IN...,SELECT COUNT(*) FROM MaintenanceRequests WHERE...,meta-llama/llama-4-scout-17b-16e-instruct,SELECT COUNT(*) FROM MaintenanceRequests WHERE...
3,Insert a new record into the 'community_educat...,"CREATE TABLE community_education (id INT, prog...","INSERT INTO community_education (id, program, ...",moonshotai/kimi-k2-instruct,"""INSERT INTO community_education (id, program,..."
4,How many users signed up daily in the 'games' ...,"CREATE TABLE signups (user_id INT, category TE...","SELECT DATE(timestamp) as signup_date, COUNT(D...",moonshotai/kimi-k2-instruct,"SELECT DATE(timestamp) as signup_date, COUNT(u..."
...,...,...,...,...,...
6938,What are the top 5 most sold garments by sales...,"CREATE TABLE garments (id INT, name VARCHAR(10...","SELECT garments.name, garments_sales.total_sol...",llama-3.1-8b-instant,"SELECT garments.name, garments_sales.total_sol..."
6939,List all the bus stops in the city of Santiago...,"bus_stops (id, name, city, country, issues)",SELECT bus_stops.* FROM bus_stops WHERE bus_st...,llama-3.1-8b-instant,SELECT bus_stops.* FROM bus_stops WHERE bus_st...
6940,How many players from each continent play Non-...,"CREATE TABLE countries (id INT, name VARCHAR(2...","SELECT c.continent, COUNT(DISTINCT p.id) as nu...",llama-3.1-8b-instant,"INSERT INTO countries ('3', 'Canada', 'North A..."
6941,What is the total mass of space debris in the ...,"CREATE TABLE space_debris (debris_id INT, name...",SELECT SUM(mass) FROM space_debris WHERE origi...,llama-3.1-8b-instant,SELECT SUM(mass) FROM space_debris WHERE origi...


In [17]:
results_df.value_counts('model_used')

model_used
llama-3.1-8b-instant                             1700
moonshotai/kimi-k2-instruct                      1000
meta-llama/llama-4-scout-17b-16e-instruct        1000
meta-llama/llama-4-maverick-17b-128e-instruct     999
moonshotai/kimi-k2-instruct-0905                  999
openai/gpt-oss-120b                               347
qwen/qwen3-32b                                    337
llama-3.3-70b-versatile                           312
openai/gpt-oss-20b                                249
Name: count, dtype: int64

In [14]:
# results_df.to_csv('rm_data.xlsx', index=False)