# Create Gen 1 and Gen 2 Warehouses

Create three POC warehouses for testing, each with auto-suspend and auto-resume enabled:
- Gen 1: size n
- Gen 2: size n
- Gen 2: size n-1 (one size smaller)


In [None]:
-- Gen 1, size n (SMALL): the baseline warehouse.
-- All three warehouses below share AUTO_RESUME=TRUE, AUTO_SUSPEND=60 and
-- INITIALLY_SUSPENDED=TRUE; they differ only in WAREHOUSE_SIZE,
-- RESOURCE_CONSTRAINT (STANDARD_GEN_1 vs STANDARD_GEN_2) and COMMENT.
-- NOTE: CREATE OR REPLACE replaces any existing warehouse with the same name.
CREATE OR REPLACE WAREHOUSE POC_GEN1_S WAREHOUSE_SIZE = SMALL
AUTO_RESUME=TRUE AUTO_SUSPEND=60 INITIALLY_SUSPENDED=TRUE RESOURCE_CONSTRAINT = STANDARD_GEN_1 COMMENT = 'My Small POC Gen1 WH';

-- Gen 2, size n (SMALL): same size as the baseline, Gen 2 resources.
CREATE OR REPLACE WAREHOUSE POC_GEN2_S WAREHOUSE_SIZE = SMALL
AUTO_RESUME=TRUE AUTO_SUSPEND=60 INITIALLY_SUSPENDED=TRUE RESOURCE_CONSTRAINT = STANDARD_GEN_2 COMMENT = 'My Small POC Gen2 WH';

-- Gen 2, size n-1 (XSMALL): one size smaller than the baseline.
CREATE OR REPLACE WAREHOUSE POC_GEN2_XS WAREHOUSE_SIZE = XSMALL
AUTO_RESUME=TRUE AUTO_SUSPEND=60 INITIALLY_SUSPENDED=TRUE RESOURCE_CONSTRAINT = STANDARD_GEN_2 COMMENT = 'My XSmall POC Gen2 WH';

# Invalidate Cache

This is the most critical procedural step for a pure performance comparison. Both the warehouse's local disk cache and Snowflake's global result cache must be deliberately invalidated before each test run.

**Warehouse Cache**: The local SSD cache on a virtual warehouse stores data recently accessed from remote storage. To ensure this cache is cleared and does not influence subsequent runs, the warehouse must be suspended and resumed before each test execution.
- ALTER WAREHOUSE <poc_warehouse_name> SUSPEND;
- ALTER WAREHOUSE <poc_warehouse_name> RESUME;

**Result Cache**: Snowflake maintains a result cache that returns the results of previously executed queries without re-computation. For this POC, the result cache must be disabled at the session level to force the query to execute from scratch every time.
- ALTER SESSION SET USE_CACHED_RESULT = FALSE;

# Tag Workload

Ensure that a session-level **QUERY_TAG** is set for each workload execution:
- ALTER SESSION SET QUERY_TAG = 'POC_GEN2_ETL_WORKLOAD_RUN_1';

# Execute Workloads
The procedure described above is codified in the Python script below, which provides test functions in four categories that can be run in threaded or async mode against the Snowflake TPC-H sample data (see https://docs.snowflake.com/en/user-guide/sample-data-tpch):

- light
- medium
- heavy
- custom (choose the number of concurrent users and number of runs)

Execute the functions against the Gen 1, Gen 2, and downsized Gen 2 warehouses and compare the results.

In [None]:
import pandas as pd
import time
import asyncio
import threading
import concurrent.futures
from typing import Dict, List, Tuple
from snowflake.snowpark import Session
import statistics
from datetime import datetime

# Kim Njeru 07/24/2025 Script to create test harness with TPC-H queries
# Get the current session: Session.builder.getOrCreate() returns an existing
# Snowpark session or creates one. This single module-level object is used by
# every function in this cell, including the worker threads / executor tasks
# started by the load tests, so session-level settings (USE_CACHED_RESULT,
# QUERY_TAG) are shared by all of them.
session = Session.builder.getOrCreate()

# TPC-H Queries
# Four TPC-H queries against the snowflake_sample_data.tpch_sf1 schema.
# The benchmark functions iterate this dict in insertion order; the keys are
# used as the per-query labels in the timing results. Functions read this
# global at call time, so rebinding TPCH_QUERIES later changes what they run.
TPCH_QUERIES = {
    # Q1: scan + aggregation over LINEITEM, grouped/ordered by return flag
    # and line status, for shipdate <= 1998-12-01 minus 90 days.
    'Q1': """
    SELECT
        l_returnflag,
        l_linestatus,
        sum(l_quantity) as sum_qty,
        sum(l_extendedprice) as sum_base_price,
        sum(l_extendedprice * (1 - l_discount)) as sum_disc_price,
        sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge,
        avg(l_quantity) as avg_qty,
        avg(l_extendedprice) as avg_price,
        avg(l_discount) as avg_disc,
        count(*) as count_order
    FROM
        snowflake_sample_data.tpch_sf1.lineitem
    WHERE
        l_shipdate <= dateadd(day, -90, '1998-12-01'::date)
    GROUP BY
        l_returnflag,
        l_linestatus
    ORDER BY
        l_returnflag,
        l_linestatus;
    """,
    
    # Q3: 3-way join CUSTOMER/ORDERS/LINEITEM for the BUILDING segment;
    # returns the top 10 orders by revenue.
    'Q3': """
    SELECT
        l_orderkey,
        sum(l_extendedprice * (1 - l_discount)) as revenue,
        o_orderdate,
        o_shippriority
    FROM
        snowflake_sample_data.tpch_sf1.customer,
        snowflake_sample_data.tpch_sf1.orders,
        snowflake_sample_data.tpch_sf1.lineitem
    WHERE
        c_mktsegment = 'BUILDING'
        AND c_custkey = o_custkey
        AND l_orderkey = o_orderkey
        AND o_orderdate < '1995-03-15'::date
        AND l_shipdate > '1995-03-15'::date
    GROUP BY
        l_orderkey,
        o_orderdate,
        o_shippriority
    ORDER BY
        revenue desc,
        o_orderdate
    LIMIT 10;
    """,
    
    # Q6: filtered single-table aggregation over LINEITEM (one revenue sum
    # for 1994 shipments with discount 0.05-0.07 and quantity < 24).
    'Q6': """
    SELECT
        sum(l_extendedprice * l_discount) as revenue
    FROM
        snowflake_sample_data.tpch_sf1.lineitem
    WHERE
        l_shipdate >= '1994-01-01'::date
        AND l_shipdate < dateadd(year, 1, '1994-01-01'::date)
        AND l_discount between 0.06 - 0.01 AND 0.06 + 0.01
        AND l_quantity < 24;
    """,
    
    # Q10: 4-way join CUSTOMER/ORDERS/LINEITEM/NATION over returned items
    # (l_returnflag = 'R') in one quarter; returns the top 20 customers by revenue.
    'Q10': """
    SELECT
        c_custkey,
        c_name,
        sum(l_extendedprice * (1 - l_discount)) as revenue,
        c_acctbal,
        n_name,
        c_address,
        c_phone,
        c_comment
    FROM
        snowflake_sample_data.tpch_sf1.customer,
        snowflake_sample_data.tpch_sf1.orders,
        snowflake_sample_data.tpch_sf1.lineitem,
        snowflake_sample_data.tpch_sf1.nation
    WHERE
        c_custkey = o_custkey
        AND l_orderkey = o_orderkey
        AND o_orderdate >= '1993-10-01'::date
        AND o_orderdate < dateadd(month, 3, '1993-10-01'::date)
        AND l_returnflag = 'R'
        AND c_nationkey = n_nationkey
    GROUP BY
        c_custkey,
        c_name,
        c_acctbal,
        c_phone,
        n_name,
        c_address,
        c_comment
    ORDER BY
        revenue desc
    LIMIT 20;
    """
}

# WAREHOUSE AND CACHE MANAGEMENT

def setup_warehouse_and_cache(warehouse_name: str, query_tag: str):
    """Prepare a warehouse and the shared session for a cold-cache benchmark.

    Switches the session to ``warehouse_name``, cycles the warehouse through
    SUSPEND/RESUME, turns off the session result cache and applies
    ``query_tag`` as the session QUERY_TAG.

    Returns:
        True when all steps succeed. False (after printing the error) if any
        step other than the SUSPEND raises; SUSPEND failures are only reported.
    """
    print("\n🔧 Setting up warehouse and cache configuration...")

    def _run(statement):
        # Execute one SQL statement on the shared session and drain the result.
        session.sql(statement).collect()

    try:
        # Step 1: point the session at the warehouse under test.
        print(f"📍 Switching to warehouse: {warehouse_name}")
        _run(f"USE WAREHOUSE {warehouse_name}")

        # Step 2: suspend (best effort), then resume the warehouse.
        print(f"⏸️  Attempting to suspend warehouse {warehouse_name}...")
        try:
            _run(f"ALTER WAREHOUSE {warehouse_name} SUSPEND")
            print("✅ Warehouse suspended")
        except Exception as suspend_exc:
            reason = str(suspend_exc).lower()
            if any(marker in reason for marker in ("already suspended", "invalid state")):
                print(f"ℹ️  Warehouse {warehouse_name} is already suspended")
            else:
                print(f"⚠️  Suspend warning: {suspend_exc}")

        print(f"▶️  Resuming warehouse {warehouse_name}...")
        _run(f"ALTER WAREHOUSE {warehouse_name} RESUME")
        print("✅ Warehouse resumed")

        # Short fixed pause after RESUME before the session is used further.
        time.sleep(3)

        # Step 3: make every subsequent query recompute instead of reusing results.
        print("🚫 Disabling result cache for this session...")
        _run("ALTER SESSION SET USE_CACHED_RESULT = FALSE")
        print("✅ Result cache disabled")

        # Step 4: label subsequent queries on this session.
        print(f"🏷️  Setting query tag: {query_tag}")
        _run(f"ALTER SESSION SET QUERY_TAG = '{query_tag}'")
        print("✅ Query tag set")

        for summary_line in (
            "✅ Setup complete!",
            f"   - Warehouse: {warehouse_name}",
            "   - Result Cache: DISABLED",
            f"   - Query Tag: {query_tag}",
        ):
            print(summary_line)

        return True

    except Exception as setup_exc:
        print(f"❌ Error setting up warehouse and cache: {setup_exc}")
        return False

async def setup_warehouse_and_cache_async(warehouse_name: str, query_tag: str):
    """Async counterpart of the warehouse/cache setup.

    Each SQL statement runs on the event loop's default executor, so the loop
    is not blocked while Snowflake executes it. Steps: USE WAREHOUSE, SUSPEND
    (failures only reported), RESUME, 3 s pause, disable the session result
    cache, set the session QUERY_TAG.

    Returns:
        True on success; False (after printing the error) if any step other
        than SUSPEND raised.
    """
    print("\n🔧 Setting up warehouse and cache configuration (ASYNC)...")

    try:
        loop = asyncio.get_running_loop()

        async def execute(statement):
            # Run one blocking Snowpark call off the event loop thread.
            await loop.run_in_executor(None, lambda: session.sql(statement).collect())

        print(f"📍 Switching to warehouse: {warehouse_name}")
        await execute(f"USE WAREHOUSE {warehouse_name}")

        print(f"⏸️  Attempting to suspend warehouse {warehouse_name}...")
        try:
            await execute(f"ALTER WAREHOUSE {warehouse_name} SUSPEND")
            print("✅ Warehouse suspended")
        except Exception as suspend_exc:
            reason = str(suspend_exc).lower()
            if any(marker in reason for marker in ("already suspended", "invalid state")):
                print(f"ℹ️  Warehouse {warehouse_name} is already suspended")
            else:
                print(f"⚠️  Suspend warning: {suspend_exc}")

        print(f"▶️  Resuming warehouse {warehouse_name}...")
        await execute(f"ALTER WAREHOUSE {warehouse_name} RESUME")
        print("✅ Warehouse resumed")

        await asyncio.sleep(3)

        print("🚫 Disabling result cache...")
        await execute("ALTER SESSION SET USE_CACHED_RESULT = FALSE")

        print(f"🏷️  Setting query tag: {query_tag}")
        await execute(f"ALTER SESSION SET QUERY_TAG = '{query_tag}'")

        print("✅ Async setup complete!")
        return True

    except Exception as setup_exc:
        print(f"❌ Error in async setup: {setup_exc}")
        return False

# CORE BENCHMARK FUNCTIONS

def run_tpch_benchmark():
    """Run every query in TPCH_QUERIES once, sequentially, printing progress.

    Returns:
        tuple (results, timing_results):
            results -- query name -> pandas DataFrame, for queries that succeeded.
            timing_results -- query name -> elapsed seconds. A query that raised
                is recorded as 0 and is still included in the totals.
    """
    results = {}
    timing_results = {}

    print("🚀 Starting TPC-H Benchmark")
    print("=" * 40)

    for query_name, query in TPCH_QUERIES.items():
        print(f"\n🔄 Executing {query_name}...")
        # perf_counter is monotonic and high-resolution; time.time() follows the
        # adjustable wall clock and is not meant for measuring intervals.
        start_time = time.perf_counter()

        try:
            # Timing covers execution plus conversion of the rows to pandas.
            df = session.sql(query).to_pandas()
            execution_time = time.perf_counter() - start_time

            results[query_name] = df
            timing_results[query_name] = execution_time

            print(f"✅ {query_name} completed in {execution_time:.2f} seconds")
            print(f"📊 Returned {len(df)} rows")

        except Exception as e:
            print(f"❌ Error executing {query_name}: {e}")
            timing_results[query_name] = 0

    total_time = sum(timing_results.values())
    print("\n📈 EXECUTION SUMMARY:")
    print("=" * 40)
    for query_name, exec_time in timing_results.items():
        print(f"{query_name}: {exec_time:.2f}s")
    print(f"\nTotal execution time: {total_time:.2f}s")
    # Guard the average against an empty query set (ZeroDivisionError).
    if timing_results:
        print(f"Average query time: {total_time/len(timing_results):.2f}s")

    return results, timing_results

def run_single_benchmark(run_id: int, query_tag: str) -> Dict:
    """Execute one full benchmark pass (all TPCH_QUERIES) for the threaded load test.

    Args:
        run_id: index of this run; appended to the session tag as ``_run_<id>``.
        query_tag: base QUERY_TAG of the load test.

    Returns:
        dict with keys run_id, total_execution_time (seconds), individual_timings
        (query name -> seconds), total_queries, status ('SUCCESS' / 'ERROR'),
        error (None or the exception message) and timestamp. Exceptions are
        captured in the dict, never raised.
    """
    start_time = time.perf_counter()

    try:
        # NOTE(review): ``session`` is one module-level object shared by every
        # worker thread, so these session-level settings (notably QUERY_TAG) are
        # presumably overwritten by whichever run set them last when runs overlap.
        session.sql("ALTER SESSION SET USE_CACHED_RESULT = FALSE").collect()
        session.sql(f"ALTER SESSION SET QUERY_TAG = '{query_tag}_run_{run_id}'").collect()

        # Use the non-printing benchmark instead of redirecting sys.stdout:
        # sys.stdout is process-global, so swapping it from several threads at
        # once lets runs restore each other's buffers and can leave all later
        # output redirected into a StringIO.
        _results, timing_results = run_tpch_benchmark_silent()
    except Exception as e:
        return {
            'run_id': run_id,
            'total_execution_time': time.perf_counter() - start_time,
            'individual_timings': {},
            'total_queries': 0,
            'status': 'ERROR',
            'error': str(e),
            'timestamp': datetime.now()
        }

    return {
        'run_id': run_id,
        'total_execution_time': time.perf_counter() - start_time,
        'individual_timings': timing_results,
        'total_queries': len(timing_results),
        'status': 'SUCCESS',
        'error': None,
        'timestamp': datetime.now()
    }

async def run_single_benchmark_async(run_id: int, query_tag: str) -> Dict:
    """Async worker: run one full benchmark pass without blocking the event loop.

    Sets USE_CACHED_RESULT = FALSE and QUERY_TAG = '<query_tag>_run_<run_id>'
    on the shared session, then runs run_tpch_benchmark_silent on the loop's
    default executor.

    Returns:
        dict with run_id, total_execution_time (seconds since the call began),
        individual_timings, total_queries, status ('SUCCESS' / 'ERROR'), error
        and timestamp. Exceptions are captured in the dict.
    """
    start_time = time.time()

    def _record(status, timings, error):
        # Build the per-run result dict; elapsed time is measured at build time.
        return {
            'run_id': run_id,
            'total_execution_time': time.time() - start_time,
            'individual_timings': timings,
            'total_queries': len(timings),
            'status': status,
            'error': error,
            'timestamp': datetime.now()
        }

    try:
        loop = asyncio.get_running_loop()

        session_settings = (
            "ALTER SESSION SET USE_CACHED_RESULT = FALSE",
            f"ALTER SESSION SET QUERY_TAG = '{query_tag}_run_{run_id}'",
        )
        for statement in session_settings:
            await loop.run_in_executor(None, lambda stmt=statement: session.sql(stmt).collect())

        _, timings = await loop.run_in_executor(None, run_tpch_benchmark_silent)
        return _record('SUCCESS', timings, None)

    except Exception as exc:
        return _record('ERROR', {}, str(exc))

def run_tpch_benchmark_silent():
    """Run every query in TPCH_QUERIES once without printing anything.

    Used by the threaded/async workers.

    Returns:
        tuple (results, timing_results): query name -> pandas DataFrame for
        successful queries, and query name -> elapsed seconds for every query
        (0 for a query that raised).
    """
    results = {}
    timing_results = {}

    for query_name, query in TPCH_QUERIES.items():
        # perf_counter: monotonic, high-resolution clock meant for intervals.
        start_time = time.perf_counter()
        try:
            df = session.sql(query).to_pandas()
        except Exception:
            # Failures are recorded as 0 so every query name is still present.
            timing_results[query_name] = 0
            continue

        timing_results[query_name] = time.perf_counter() - start_time
        results[query_name] = df

    return results, timing_results

# THREADED LOAD TESTING

def run_threaded_load_test(warehouse_name: str, concurrent_users: int = 10, total_runs: int = None, test_name: str = "THREADED_LOAD_TEST"):
    """Fan run_single_benchmark out over a thread pool and print a summary.

    Args:
        warehouse_name: warehouse to benchmark.
        concurrent_users: thread-pool size (passes in flight at once).
        total_runs: number of benchmark passes; a falsy value means
            ``concurrent_users`` passes.
        test_name: prefix of the generated QUERY_TAG.

    Returns:
        list of per-run result dicts (completion order), or None when the
        warehouse/cache setup fails.
    """
    total_runs = total_runs or concurrent_users
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    query_tag = f"{test_name}_{warehouse_name}_{concurrent_users}users_{stamp}"

    for banner_line in (
        "🚀 Starting THREADED TPC-H Load Test",
        f"📊 Test Name: {test_name}",
        f"📊 Warehouse: {warehouse_name}",
        f"📊 Concurrent Users: {concurrent_users}",
        f"📊 Total Runs: {total_runs}",
        f"🏷️  Query Tag: {query_tag}",
        f"⏰ Start Time: {datetime.now()}",
        "=" * 60,
    ):
        print(banner_line)

    if not setup_warehouse_and_cache(warehouse_name, query_tag):
        print("❌ Failed to setup warehouse and cache. Aborting test.")
        return None

    # Report progress roughly every 10% of runs (at least every run).
    progress_step = max(1, total_runs // 10)
    collected = []
    began = time.time()

    with concurrent.futures.ThreadPoolExecutor(max_workers=concurrent_users) as pool:
        pending = [pool.submit(run_single_benchmark, run_no, query_tag) for run_no in range(total_runs)]
        for done_count, finished in enumerate(concurrent.futures.as_completed(pending), start=1):
            collected.append(finished.result())
            if done_count % progress_step == 0:
                print(f"✅ Completed {done_count}/{total_runs} runs")

    elapsed = time.time() - began

    analyze_results(collected, elapsed, concurrent_users, warehouse_name, query_tag, "THREADED")
    return collected

# ASYNC LOAD TESTING

async def run_async_load_test(warehouse_name: str, concurrent_users: int = 10, total_runs: int = None, test_name: str = "ASYNC_LOAD_TEST"):
    """Run the benchmark ``total_runs`` times with asyncio and print a summary.

    A semaphore caps the number of runs in flight at ``concurrent_users``;
    each run is run_single_benchmark_async.

    Returns:
        list of per-run result dicts (completion order), or None when the
        warehouse/cache setup fails.
    """
    total_runs = total_runs or concurrent_users
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    query_tag = f"{test_name}_{warehouse_name}_{concurrent_users}users_{stamp}"

    for banner_line in (
        "🚀 Starting ASYNC TPC-H Load Test",
        f"📊 Test Name: {test_name}",
        f"📊 Warehouse: {warehouse_name}",
        f"📊 Concurrent Users: {concurrent_users}",
        f"📊 Total Runs: {total_runs}",
        f"🏷️  Query Tag: {query_tag}",
        f"⏰ Start Time: {datetime.now()}",
        "=" * 60,
    ):
        print(banner_line)

    if not await setup_warehouse_and_cache_async(warehouse_name, query_tag):
        print("❌ Failed to setup warehouse and cache. Aborting test.")
        return None

    began = time.time()
    gate = asyncio.Semaphore(concurrent_users)

    async def _gated_run(run_no):
        # Hold a semaphore slot for the whole run to bound concurrency.
        async with gate:
            return await run_single_benchmark_async(run_no, query_tag)

    progress_step = max(1, total_runs // 10)
    collected = []
    for next_done in asyncio.as_completed([_gated_run(run_no) for run_no in range(total_runs)]):
        collected.append(await next_done)
        if len(collected) % progress_step == 0:
            print(f"✅ Completed {len(collected)}/{total_runs} runs")

    elapsed = time.time() - began

    analyze_results(collected, elapsed, concurrent_users, warehouse_name, query_tag, "ASYNC")
    return collected

# RESULTS ANALYSIS

def analyze_results(results: List[Dict], total_test_time: float, concurrent_users: int, warehouse_name: str, query_tag: str, execution_model: str):
    """Print a summary of a load test.

    Args:
        results: per-run dicts as built by run_single_benchmark /
            run_single_benchmark_async. Keys read here: 'status',
            'total_execution_time', 'individual_timings', 'error'.
        total_test_time: elapsed seconds for the whole load test.
        concurrent_users: configured concurrency (reported only).
        warehouse_name: warehouse name (reported only).
        query_tag: base query tag (reported only).
        execution_model: label such as "THREADED" or "ASYNC" (reported only).

    Prints to stdout and returns None. Empty ``results`` or a non-positive
    ``total_test_time`` are reported as "n/a" instead of raising
    ZeroDivisionError.
    """
    from collections import Counter

    successful_runs = [r for r in results if r['status'] == 'SUCCESS']
    failed_runs = [r for r in results if r['status'] == 'ERROR']

    print(f"\n📈 {execution_model} LOAD TEST RESULTS")
    print("=" * 60)
    print("🎯 Test Configuration:")
    print(f"   - Execution Model: {execution_model}")
    print(f"   - Warehouse: {warehouse_name}")
    print(f"   - Query Tag: {query_tag}")
    print(f"   - Concurrent Users: {concurrent_users}")
    print(f"   - Total Runs: {len(results)}")
    print(f"   - Total Test Time: {total_test_time:.2f}s")

    print("\n📊 Success/Failure:")
    print(f"   - Successful Runs: {len(successful_runs)}")
    print(f"   - Failed Runs: {len(failed_runs)}")
    if results:
        print(f"   - Success Rate: {len(successful_runs)/len(results)*100:.1f}%")
    else:
        print("   - Success Rate: n/a (no runs)")

    if successful_runs:
        benchmark_times = [r['total_execution_time'] for r in successful_runs]

        print("\n⏱️  Benchmark Execution Times:")
        print(f"   - Min: {min(benchmark_times):.3f}s")
        print(f"   - Max: {max(benchmark_times):.3f}s")
        print(f"   - Average: {statistics.mean(benchmark_times):.3f}s")
        print(f"   - Median: {statistics.median(benchmark_times):.3f}s")
        if len(benchmark_times) > 1:
            print(f"   - Std Dev: {statistics.stdev(benchmark_times):.3f}s")

        print("\n🚀 Performance Metrics:")
        if total_test_time > 0:
            # Count the queries each run actually timed instead of assuming
            # len(TPCH_QUERIES): a later notebook cell rebinds that global.
            executed_queries = sum(len(r['individual_timings']) for r in successful_runs)
            print(f"   - Benchmarks per Second: {len(successful_runs)/total_test_time:.2f}")
            print(f"   - Queries per Second: {executed_queries/total_test_time:.2f}")
        else:
            print("   - Throughput: n/a (zero elapsed time)")

        # Per-query statistics. Queries that failed inside a run are recorded
        # upstream as 0 seconds and are included here.
        query_stats = {}
        for run in successful_runs:
            for query_name, exec_time in run['individual_timings'].items():
                query_stats.setdefault(query_name, []).append(exec_time)

        print("\n📊 Individual Query Performance:")
        for query_name in sorted(query_stats):
            times = query_stats[query_name]
            print(f"   {query_name}: Avg={statistics.mean(times):.3f}s, Min={min(times):.3f}s, Max={max(times):.3f}s")

    if failed_runs:
        print(f"\n❌ Errors ({len(failed_runs)} total):")
        for error, count in Counter(run['error'] for run in failed_runs).items():
            print(f"   - {error}: {count} times")

# CONVENIENCE FUNCTIONS

def light_threaded_test(warehouse_name: str, test_name: str = "LIGHT_THREADED"):
    """Threaded load test: 10 concurrent users, 20 runs in total."""
    return run_threaded_load_test(warehouse_name, concurrent_users=10, total_runs=20, test_name=test_name)

def medium_threaded_test(warehouse_name: str, test_name: str = "MEDIUM_THREADED"):
    """Threaded load test: 50 concurrent users, 100 runs in total."""
    return run_threaded_load_test(warehouse_name, concurrent_users=50, total_runs=100, test_name=test_name)

def heavy_threaded_test(warehouse_name: str, test_name: str = "HEAVY_THREADED"):
    """Threaded load test: 100 concurrent users, 200 runs in total."""
    return run_threaded_load_test(warehouse_name, concurrent_users=100, total_runs=200, test_name=test_name)

# Async convenience functions with sync wrappers
def run_async_test(coro):
    """Run coroutine ``coro`` to completion from synchronous code and return its result.

    - No event loop running in this thread (plain scripts): use asyncio.run.
    - A loop is already running here (e.g. notebook kernels): a running loop
      cannot be re-entered, so the coroutine is driven on a fresh loop in a
      helper thread and this call blocks until it finishes.

    Exceptions raised by ``coro`` propagate to the caller unchanged. (The
    previous version caught every RuntimeError, including ones raised by the
    coroutine itself, and then retried with the already-awaited coroutine,
    masking the real error.)
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # Only the loop probe is guarded; errors from ``coro`` are not caught.
        return asyncio.run(coro)

    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        return executor.submit(asyncio.run, coro).result()

def light_async_test(warehouse_name: str, test_name: str = "LIGHT_ASYNC"):
    """Async load test: 10 concurrent users, 20 runs, executed synchronously."""
    load = run_async_load_test(warehouse_name, concurrent_users=10, total_runs=20, test_name=test_name)
    return run_async_test(load)

def medium_async_test(warehouse_name: str, test_name: str = "MEDIUM_ASYNC"):
    """Async load test: 50 concurrent users, 100 runs, executed synchronously."""
    load = run_async_load_test(warehouse_name, concurrent_users=50, total_runs=100, test_name=test_name)
    return run_async_test(load)

def heavy_async_test(warehouse_name: str, test_name: str = "HEAVY_ASYNC"):
    """Async load test: 100 concurrent users, 200 runs, executed synchronously."""
    load = run_async_load_test(warehouse_name, concurrent_users=100, total_runs=200, test_name=test_name)
    return run_async_test(load)

# USAGE INSTRUCTIONS
# Print the available entry points, one per line. The example uses POC_GEN2_S,
# one of the warehouses created by the SQL cell (no POC_GEN2 warehouse exists).
print(
    "🚀 TPC-H Load Testing Script Ready!",
    "\n📋 Available Test Functions:",
    "=" * 50,
    "THREADED TESTS:",
    "1. light_threaded_test('WAREHOUSE_NAME', 'TEST_NAME')",
    "2. medium_threaded_test('WAREHOUSE_NAME', 'TEST_NAME')",
    "3. heavy_threaded_test('WAREHOUSE_NAME', 'TEST_NAME')",
    "4. run_threaded_load_test('WH', concurrent_users, total_runs, 'TEST')",
    "\nASYNC TESTS:",
    "1. light_async_test('WAREHOUSE_NAME', 'TEST_NAME')",
    "2. medium_async_test('WAREHOUSE_NAME', 'TEST_NAME')",
    "3. heavy_async_test('WAREHOUSE_NAME', 'TEST_NAME')",
    "4. run_async_test(run_async_load_test('WH', concurrent_users, total_runs, 'TEST'))",
    "\nORIGINAL SYNC:",
    "1. results, timings = run_tpch_benchmark()",
    "\nExample Usage:",
    "# Compare threaded vs async",
    "threaded_results = light_threaded_test('POC_GEN2_S', 'THREADED_TEST')",
    "async_results = light_async_test('POC_GEN2_S', 'ASYNC_TEST')",
    sep="\n",
)

In [None]:
import pandas as pd
import time
import asyncio
import concurrent.futures
from typing import Dict, List
from snowflake.snowpark import Session
import statistics
from datetime import datetime

# Get session: Session.builder.getOrCreate() returns an existing Snowpark
# session or creates one; all functions in this cell share it.
# NOTE: this cell rebinds names from the previous cell (session, TPCH_QUERIES,
# run_async_load_test, analyze_results). Previous-cell functions that are still
# called afterwards resolve those names to this cell's definitions.
session = Session.builder.getOrCreate()

# TPC-H Queries (keeping only essential ones for brevity)
# Only Q1 and Q6, both single-table aggregations over
# snowflake_sample_data.tpch_sf1.lineitem. This Q1 selects fewer aggregate
# columns than the Q1 in the previous cell, so timings from the two cells are
# not directly comparable.
TPCH_QUERIES = {
    # Q1: aggregation grouped by return flag / line status for
    # shipdate <= 1998-12-01 minus 90 days.
    'Q1': """
    SELECT l_returnflag, l_linestatus, sum(l_quantity) as sum_qty,
           sum(l_extendedprice) as sum_base_price, count(*) as count_order
    FROM snowflake_sample_data.tpch_sf1.lineitem
    WHERE l_shipdate <= dateadd(day, -90, '1998-12-01'::date)
    GROUP BY l_returnflag, l_linestatus
    ORDER BY l_returnflag, l_linestatus;
    """,
    # Q6: single revenue sum for 1994 shipments, discount 0.05-0.07, quantity < 24.
    'Q6': """
    SELECT sum(l_extendedprice * l_discount) as revenue
    FROM snowflake_sample_data.tpch_sf1.lineitem
    WHERE l_shipdate >= '1994-01-01'::date 
          AND l_shipdate < '1995-01-01'::date
          AND l_discount between 0.05 AND 0.07
          AND l_quantity < 24;
    """
}

def setup_warehouse(warehouse_name: str, query_tag: str) -> bool:
    """Select a warehouse, cycle it to drop its local cache, and configure the session.

    Steps:
      1. USE WAREHOUSE ``warehouse_name``.
      2. Best effort: look up the warehouse's state; SUSPEND it if STARTED,
         then RESUME it. Failures in this step are printed and ignored.
      3. Disable the session result cache and set the session QUERY_TAG.
      4. Sleep 2 seconds.

    Returns:
        True on success; False (after printing the error) if step 1 or 3 raised.
    """
    try:
        # Use warehouse
        session.sql(f"USE WAREHOUSE {warehouse_name}").collect()

        # Try to clear cache by suspending/resuming, but handle errors gracefully
        try:
            rows = session.sql(f"SHOW WAREHOUSES LIKE '{warehouse_name}'").collect()
            # LIKE is a pattern match: '_' (as in POC_GEN2_S) matches any single
            # character, so the result can include other warehouses. Use only
            # the row whose name matches exactly (case-insensitively, as LIKE
            # does) instead of blindly taking the first row.
            wanted = warehouse_name.upper()
            row = next((r for r in rows if str(r['name']).upper() == wanted), None)

            if row is None:
                print(f"Warehouse {warehouse_name} not found via SHOW WAREHOUSES; skipping suspend/resume")
            else:
                current_state = row['state']
                print(f"Warehouse {warehouse_name} current state: {current_state}")

                # Only suspend if warehouse is running
                if current_state == 'STARTED':
                    session.sql(f"ALTER WAREHOUSE {warehouse_name} SUSPEND").collect()
                    print(f"Warehouse {warehouse_name} suspended")
                elif current_state == 'SUSPENDED':
                    print(f"Warehouse {warehouse_name} already suspended")

                # Resume warehouse
                session.sql(f"ALTER WAREHOUSE {warehouse_name} RESUME").collect()
                print(f"Warehouse {warehouse_name} resumed")

        except Exception as warehouse_error:
            # Deliberately best effort: benchmarking continues without the cache flush.
            print(f"Warehouse suspend/resume warning (continuing anyway): {warehouse_error}")

        # Disable result cache and set query tag
        session.sql("ALTER SESSION SET USE_CACHED_RESULT = FALSE").collect()
        session.sql(f"ALTER SESSION SET QUERY_TAG = '{query_tag}'").collect()

        # Brief pause to ensure warehouse is ready
        time.sleep(2)
        return True

    except Exception as e:
        print(f"Setup error: {e}")
        return False

def run_benchmark() -> Dict:
    """Run each query in TPCH_QUERIES once and return {query_name: elapsed seconds}.

    The timed span is ``session.sql(query).collect()``. A query that raises is
    recorded as 0 rather than propagating the error.
    """
    timings = {}
    for query_name, query in TPCH_QUERIES.items():
        # perf_counter: monotonic, high-resolution clock meant for intervals,
        # unlike the adjustable wall clock behind time.time().
        start = time.perf_counter()
        try:
            session.sql(query).collect()
        except Exception:
            timings[query_name] = 0
        else:
            timings[query_name] = time.perf_counter() - start
    return timings

def run_single_test(run_id: int, query_tag: str) -> Dict:
    """Execute one benchmark pass for the load test.

    Sets USE_CACHED_RESULT = FALSE and QUERY_TAG = '<query_tag>_run_<run_id>'
    on the session, then times run_benchmark().

    Returns:
        dict with run_id, total_time (seconds), timings (query -> seconds),
        status ('SUCCESS' / 'ERROR') and error (None on success, the exception
        message on failure). Exceptions are captured, never raised.
    """
    # perf_counter: monotonic, high-resolution clock meant for intervals.
    start_time = time.perf_counter()
    try:
        # NOTE(review): ``session`` is a single module-level object used by all
        # concurrent runs, so QUERY_TAG is presumably overwritten by whichever
        # run set it last when runs overlap.
        session.sql("ALTER SESSION SET USE_CACHED_RESULT = FALSE").collect()
        session.sql(f"ALTER SESSION SET QUERY_TAG = '{query_tag}_run_{run_id}'").collect()
        timings = run_benchmark()
    except Exception as e:
        return {
            'run_id': run_id,
            'total_time': time.perf_counter() - start_time,
            'timings': {},
            'status': 'ERROR',
            'error': str(e)
        }

    return {
        'run_id': run_id,
        'total_time': time.perf_counter() - start_time,
        'timings': timings,
        'status': 'SUCCESS',
        # Present on success too, so every result dict has the same keys.
        'error': None
    }

def run_load_test(warehouse_name: str, concurrent_users: int = 10,
                  total_runs: int = None, test_name: str = "LOAD_TEST") -> List[Dict]:
    """Drive benchmark passes (run_single_test) through a thread pool.

    Args:
        warehouse_name: warehouse to benchmark (prepared via setup_warehouse).
        concurrent_users: thread-pool size, i.e. passes in flight at once.
        total_runs: number of passes; a falsy value means ``concurrent_users``.
        test_name: prefix of the generated QUERY_TAG.

    Returns:
        per-run result dicts in completion order, or [] if setup failed.
    """
    run_count = total_runs or concurrent_users
    tag_suffix = datetime.now().strftime('%H%M%S')
    query_tag = f"{test_name}_{warehouse_name}_{concurrent_users}users_{tag_suffix}"

    print(f"🚀 Starting {test_name}: {concurrent_users} users, {run_count} runs")

    if not setup_warehouse(warehouse_name, query_tag):
        return []

    t0 = time.time()
    collected = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=concurrent_users) as pool:
        submitted = [pool.submit(run_single_test, run_no, query_tag) for run_no in range(run_count)]
        for finished in concurrent.futures.as_completed(submitted):
            collected.append(finished.result())

    analyze_results(collected, time.time() - t0, concurrent_users)
    return collected

async def run_async_load_test(warehouse_name: str, concurrent_users: int = 10,
                             total_runs: int = None, test_name: str = "ASYNC_TEST") -> List[Dict]:
    """Run benchmark passes with asyncio, pushing each blocking pass onto a thread.

    At most ``concurrent_users`` passes are in flight at once (semaphore).
    NOTE(review): passes run on the event loop's default executor, which has
    its own worker limit, so effective concurrency may be lower than
    ``concurrent_users`` for large values — confirm if that matters.

    Returns:
        per-run result dicts in submission order, or [] if setup failed.
    """
    run_count = total_runs or concurrent_users
    tag_suffix = datetime.now().strftime('%H%M%S')
    query_tag = f"{test_name}_{warehouse_name}_{concurrent_users}users_{tag_suffix}"

    print(f"🚀 Starting {test_name}: {concurrent_users} users, {run_count} runs")

    # Setup is synchronous and runs on the event-loop thread before any pass starts.
    if not setup_warehouse(warehouse_name, query_tag):
        return []

    t0 = time.time()
    gate = asyncio.Semaphore(concurrent_users)
    loop = asyncio.get_running_loop()

    async def _bounded(run_no):
        async with gate:
            return await loop.run_in_executor(None, run_single_test, run_no, query_tag)

    collected = await asyncio.gather(*[_bounded(run_no) for run_no in range(run_count)])

    analyze_results(collected, time.time() - t0, concurrent_users)
    return collected

def analyze_results(results: List[Dict], total_time: float, concurrent_users: int,
                    warehouse_name: str = None, query_tag: str = None,
                    execution_model: str = None):
    """Print a compact summary of a load test.

    This definition replaces the six-argument analyze_results from the earlier
    cell in the notebook namespace. Functions from that cell (e.g.
    run_threaded_load_test) still call it with six positional arguments and
    pass result dicts keyed 'total_execution_time', so the extra parameters are
    optional and both result schemas are accepted.

    Args:
        results: per-run dicts with 'status' plus 'total_time' (this cell) or
            'total_execution_time' (earlier cell).
        total_time: elapsed seconds for the whole load test.
        concurrent_users: configured concurrency (not used in the summary).
        warehouse_name, query_tag, execution_model: optional labels; printed
            on one line when any is given.

    Prints to stdout and returns None.
    """
    if warehouse_name or query_tag or execution_model:
        print(f"\n🎯 {execution_model or 'LOAD'} test on {warehouse_name} (tag: {query_tag})")

    successful = [r for r in results if r['status'] == 'SUCCESS']
    failed = len(results) - len(successful)

    print(f"\n📊 Results: {len(successful)}/{len(results)} successful ({failed} failed)")
    print(f"⏱️  Total time: {total_time:.2f}s")

    if not successful:
        return

    times = [r['total_time'] if 'total_time' in r else r['total_execution_time'] for r in successful]
    print(f"📈 Execution times - Min: {min(times):.2f}s, Max: {max(times):.2f}s, Avg: {statistics.mean(times):.2f}s")
    if total_time > 0:
        print(f"🚀 Throughput: {len(successful)/total_time:.2f} runs/sec")
    else:
        print("🚀 Throughput: n/a (zero elapsed time)")

# Convenience functions
def light_test(warehouse_name: str) -> List[Dict]:
    """Light threaded load: 5 concurrent users, 10 runs, tagged LIGHT."""
    return run_load_test(warehouse_name, concurrent_users=5, total_runs=10, test_name="LIGHT")

def medium_test(warehouse_name: str) -> List[Dict]:
    """Medium threaded load: 20 concurrent users, 40 runs, tagged MEDIUM."""
    return run_load_test(warehouse_name, concurrent_users=20, total_runs=40, test_name="MEDIUM")

def heavy_test(warehouse_name: str) -> List[Dict]:
    """Heavy threaded load: 50 concurrent users, 100 runs, tagged HEAVY."""
    return run_load_test(warehouse_name, concurrent_users=50, total_runs=100, test_name="HEAVY")

def async_test(warehouse_name: str, users: int = 10, runs: int = None):
    """Run run_async_load_test synchronously and return its per-run results.

    asyncio.run() raises RuntimeError when an event loop is already running in
    the current thread (as in Jupyter-style kernels). In that case the
    coroutine is driven on a fresh event loop in a helper thread, and this call
    blocks until it finishes. Exceptions from the load test propagate unchanged.
    """
    load = run_async_load_test(warehouse_name, users, runs)
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop running in this thread: the normal script case.
        return asyncio.run(load)

    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        return executor.submit(asyncio.run, load).result()

# Announce the entry points of this cell, one per line.
print(
    "🚀 TPC-H Load Testing Ready!",
    "Usage: light_test('WAREHOUSE_NAME'), medium_test('WAREHOUSE_NAME'), heavy_test('WAREHOUSE_NAME')",
    "Async: async_test('WAREHOUSE_NAME', users=10, runs=20)",
    sep="\n",
)

In [None]:
# Light threaded load test (light_test: 5 users, 10 runs) on the Gen 1 SMALL warehouse.
# NOTE: the following light_test cells assign to the same variable, so each run
# overwrites the previous result; rename the variable per cell (e.g.
# light_gen1_s_results) to keep all result sets for comparison.
light_test_results = light_test('POC_GEN1_S')

In [None]:
# Light threaded load test (light_test: 5 users, 10 runs) on the Gen 2 SMALL warehouse.
# NOTE: reuses `light_test_results`, overwriting the previous cell's result;
# rename the variable to keep both for comparison.
light_test_results = light_test('POC_GEN2_S')

In [None]:
# Light threaded load test (light_test: 5 users, 10 runs) on the downsized Gen 2 XSMALL warehouse.
# NOTE: reuses `light_test_results`, overwriting the previous cell's result;
# rename the variable to keep both for comparison.
light_test_results = light_test('POC_GEN2_XS')

In [None]:
# Async load test: 5 concurrent users, 100 runs on the Gen 1 SMALL warehouse.
# NOTE: the following async_test cells assign to the same variable, so each run
# overwrites the previous result; rename the variable per cell to keep them all.
async_test_results = async_test('POC_GEN1_S', users=5, runs=100)

In [None]:
# Async load test: 5 concurrent users, 100 runs on the Gen 2 SMALL warehouse.
# NOTE: reuses `async_test_results`, overwriting the previous cell's result;
# rename the variable to keep both for comparison.
async_test_results = async_test('POC_GEN2_S', users=5, runs=100)

In [None]:
# Async load test: 5 concurrent users, 100 runs on the downsized Gen 2 XSMALL warehouse.
# NOTE: reuses `async_test_results`, overwriting the previous cell's result;
# rename the variable to keep both for comparison.
async_test_results = async_test('POC_GEN2_XS', users=5, runs=100)