In [1]:
import random
import json
from datetime import datetime, timedelta
import torch
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    logging
)
from peft import LoraConfig, get_peft_model

In [2]:
# --- Vocabulary for the synthetic NL -> structured-query generator ---------

# Tables a generated query can target.
entities = [
    "customers",
    "employees",
    "orders",
    "products",
    "transactions",
    "accounts",
]

# Filterable columns per table (insertion order matches `entities`).
attributes = dict(
    customers=["age", "city", "join_date", "total_spent", "membership_status"],
    employees=["age", "department", "hire_date", "salary", "designation"],
    orders=["amount", "order_date", "status", "region", "customer_id"],
    products=["price", "category", "stock", "launch_date"],
    transactions=["amount", "date", "type", "account_id"],
    accounts=["balance", "open_date", "branch", "account_type"],
)

# Natural-language comparison phrases used for numeric filters.
comparison_ops = [
    "greater than",
    "less than",
    "equal to",
    "not equal to",
    "above",
    "below",
]

# NOTE(review): `logical_ops` and `negations` are not referenced by the
# generator functions in this section — confirm whether they are still needed.
logical_ops = ["and", "or"]
negations = ["not", "excluding", "without", "other than", "except"]

# Aggregation phrases; "top N" / "bottom N" also drive limit + sort order.
aggregations = [
    "count of", "total", "average", "maximum", "minimum",
    "top 5", "top 10", "bottom 5", "sum of",
]

# Value pools for categorical filters.
cities = ["Mumbai", "Delhi", "Bangalore", "Pune", "Hyderabad", "Chennai"]
departments = ["HR", "Finance", "Engineering", "Sales", "Marketing"]
statuses = ["active", "inactive", "pending", "completed", "cancelled"]
categories = ["electronics", "fashion", "grocery", "furniture", "toys"]
categories = ["electronics", "fashion", "grocery", "furniture", "toys"]

In [3]:
def random_date(start_year=2015, end_year=2025):
    """Return a random midnight ``datetime`` in a closed year range.

    The result is drawn uniformly by day from 1 January ``start_year``
    through 31 December ``end_year`` (both ends inclusive), using the
    module-level ``random`` RNG.
    """
    first_day = datetime(start_year, 1, 1)
    span_days = (datetime(end_year, 12, 31) - first_day).days
    offset = timedelta(days=random.randint(0, span_days))
    return first_day + offset

# Value pools for categorical columns that have no pool in the vocabulary cell.
# Without these, such columns fell through to the numeric branch and produced
# meaningless training filters such as "account_type above 42".
# Values deliberately avoid the parser keywords "is", "after", "between", "and".
_CATEGORICAL_FALLBACK_VALUES = {
    "designation": ["Analyst", "Engineer", "Manager", "Director", "Intern"],
    "type": ["credit", "debit", "transfer", "refund", "withdrawal"],
    "account_type": ["savings", "current", "salary", "fixed_deposit"],
}


def random_filter(entity):
    """Return one random natural-language filter phrase for ``entity``.

    The phrase shape depends on the randomly chosen column:

    * date columns (name contains "date"): either
      ``"<attr> between <d1> and <d2>"`` with d2 = d1 + 10..300 days, or
      ``"<attr> after <d>"``; dates are formatted ``YYYY-MM-DD``;
    * categorical columns: ``"<attr> is <value>"``;
    * everything else: ``"<attr> <comparison phrase> <int in 10..100>"``.

    These shapes correspond to the branches of ``parse_condition``.

    Args:
        entity: a key of ``attributes``.

    Returns:
        The filter phrase as a string.
    """
    attr = random.choice(attributes[entity])

    if "date" in attr:
        if random.random() > 0.5:
            d1 = random_date()
            d2 = d1 + timedelta(days=random.randint(10, 300))
            return f"{attr} between {d1.strftime('%Y-%m-%d')} and {d2.strftime('%Y-%m-%d')}"
        date_str = random_date().strftime('%Y-%m-%d')
        return f"{attr} after {date_str}"

    # "branch" is location-like, so it shares the city pool with city/region.
    if attr in ("city", "region", "branch"):
        return f"{attr} is {random.choice(cities)}"
    if attr == "department":
        return f"{attr} is {random.choice(departments)}"
    if attr == "category":
        return f"{attr} is {random.choice(categories)}"
    if attr in ("status", "membership_status"):
        return f"{attr} is {random.choice(statuses)}"
    if attr in _CATEGORICAL_FALLBACK_VALUES:
        return f"{attr} is {random.choice(_CATEGORICAL_FALLBACK_VALUES[attr])}"

    # Numeric column: a comparison phrase against a small integer.
    op = random.choice(comparison_ops)
    val = random.randint(10, 100)
    return f"{attr} {op} {val}"


In [4]:
def parse_condition(raw):
    """Convert a natural-language filter phrase into a structured condition.

    Recognised shapes (those produced by ``random_filter``):

    * ``"<field> between <d1> and <d2>"`` -> operator ``"between"``, value ``[d1, d2]``
    * ``"<field> after <date>"``          -> operator ``">"``, value ``date`` (str)
    * ``"<field> is <value>"``            -> operator ``"="``, value (str)
    * ``"<field> <comparison> <int>"``    -> symbolic operator, int value

    Anything else falls back to ``"="`` with the last token as a string value.

    Args:
        raw: filter text whose first whitespace-separated token is the field.

    Returns:
        dict with keys ``"field"``, ``"operator"`` and ``"value"``.

    Raises:
        ValueError: if ``raw`` is empty/blank, or a comparison phrase is not
            followed by an integer.
    """
    parts = raw.split()
    if not parts:
        raise ValueError("empty condition string")
    field = parts[0]

    # Match keywords as whole, space-delimited words and split only once, so a
    # keyword embedded in a field or value (e.g. "is" inside "Specialist")
    # cannot hijack the branch or truncate the value.
    if " between " in raw:
        lo, hi = raw.split(" between ", 1)[1].split(" and ", 1)
        return {"field": field, "operator": "between", "value": [lo.strip(), hi.strip()]}
    if " after " in raw:
        return {"field": field, "operator": ">", "value": raw.split(" after ", 1)[1].strip()}
    if " is " in raw:
        return {"field": field, "operator": "=", "value": raw.split(" is ", 1)[1].strip()}

    # Order matters: "equal to" is a substring of "not equal to", so the
    # negated phrase must be tried first or it would be parsed as "=".
    op_map = (
        ("not equal to", "!="),
        ("greater than", ">"),
        ("less than", "<"),
        ("equal to", "="),
        ("above", ">"),
        ("below", "<"),
    )
    for op_text, symbol in op_map:
        if f" {op_text} " in raw:
            return {"field": field, "operator": symbol, "value": int(parts[-1])}
    return {"field": field, "operator": "=", "value": parts[-1]}


In [5]:
def random_query():
    """Generate one synthetic (instruction, structured query) example.

    Picks a random entity and 1-3 filters for it. With probability 0.4 it also
    picks an aggregation phrase. "top N" / "bottom N" phrases additionally set
    a row limit, a sort direction, and sort by the first condition's field.

    Returns:
        tuple ``(instruction, entity, conditions, limit, order, order_by)``.
        ``limit`` (int), ``order`` ("desc"/"asc") and ``order_by`` (str) are
        ``None`` unless a top/bottom aggregation was drawn.
    """
    entity = random.choice(entities)
    num_filters = random.choice([1, 2, 3])
    raw_filters = [random_filter(entity) for _ in range(num_filters)]
    conditions = [parse_condition(f) for f in raw_filters]

    agg = None
    limit = None
    order = None
    order_by = None

    if random.random() < 0.4:
        agg = random.choice(aggregations)
        # e.g. "top 10" -> ("top", "10"); "count of" -> ("count", "of")
        head, _, tail = agg.partition(" ")
        if head in ("top", "bottom"):
            order = "desc" if head == "top" else "asc"
            limit = int(tail)
            order_by = conditions[0]["field"]
        # NOTE(review): non-ranking aggregations ("average", "count of", ...)
        # appear in the instruction but are not part of the returned tuple,
        # so the generated response never encodes them.

    # Build the subject without an empty placeholder so the no-aggregation
    # case reads "Show the customers ..." rather than "Show the  customers ...".
    subject = f"{agg} {entity}" if agg else entity
    phrasing = f"Show the {subject} where " + " and ".join(raw_filters)
    return phrasing.strip(), entity, conditions, limit, order, order_by

In [6]:
def generate_synthetic_dataset(n=2000, seed=None, train_frac=0.8,
                               train_path="train_data.json",
                               test_path="test_data.json"):
    """Generate ``n`` synthetic examples, shuffle, split, and write them as JSON.

    Each record is ``{"instruction": str, "response": {...}}``. The response
    always has ``"entity"`` and ``"conditions"``. It only carries ``"limit"``,
    ``"order"`` and ``"order_by"`` when ``random_query`` returned them.

    Args:
        n: number of examples to generate.
        seed: if not None, seed the module-level ``random`` RNG first so the
            written files are reproducible. The default (None) keeps the
            original behaviour of using the current RNG state.
        train_frac: fraction in [0, 1] of shuffled examples written to
            ``train_path``; the remainder goes to ``test_path``.
        train_path: output path for the train split (overwritten).
        test_path: output path for the test split (overwritten).

    Returns:
        None; results are written to disk.

    Raises:
        ValueError: if ``train_frac`` is outside [0, 1].
    """
    if not 0.0 <= train_frac <= 1.0:
        raise ValueError(f"train_frac must be in [0, 1], got {train_frac!r}")
    if seed is not None:
        random.seed(seed)

    dataset = []
    for _ in range(n):
        instruction, entity, conditions, limit, order, order_by = random_query()
        response = {"entity": entity, "conditions": conditions}
        # Optional keys are omitted (not written as null) when unset.
        if limit is not None:
            response["limit"] = limit
        if order is not None:
            response["order"] = order
        if order_by is not None:
            response["order_by"] = order_by
        dataset.append({"instruction": instruction, "response": response})

    random.shuffle(dataset)
    split = int(train_frac * len(dataset))

    for path, records in ((train_path, dataset[:split]), (test_path, dataset[split:])):
        with open(path, "w", encoding="utf-8") as f:
            json.dump(records, f, indent=2)

In [7]:
# Generate the synthetic train/test splits (train_data.json / test_data.json).
# Seed the global RNG first so Restart & Run All reproduces identical files.
SEED = 42
N_SAMPLES = 2000

random.seed(SEED)
generate_synthetic_dataset(N_SAMPLES)