### Setup

In [1]:

# Upgrade Vertex AI SDK.
! pip3 install --upgrade --quiet 'google-cloud-aiplatform>=1.64.0'

# Import the necessary packages
import datetime
import importlib
import os
import uuid
from typing import Tuple

from google.cloud import aiplatform

! git clone https://github.com/GoogleCloudPlatform/vertex-ai-samples.git

models, endpoints = {}, {}

common_util = importlib.import_module(
    "vertex-ai-samples.community-content.vertex_model_garden.model_oss.notebook_util.common_util"
)

# Get the default cloud project id.
PROJECT_ID = os.environ["GOOGLE_CLOUD_PROJECT"]

PROJECT_IDS = !(gcloud config get-value core/project)
PROJECT_ID = PROJECT_IDS[0]  # @param {type:"string"}

if not PROJECT_ID:
    PROJECT_ID = str(os.environ.get("GOOGLE_CLOUD_PROJECT"))

LOCATION = "europe-west4" #"us-south1" #"us-central1" # @param {type:"string"}

os.environ["GOOGLE_CLOUD_PROJECT"] = PROJECT_ID
os.environ["GOOGLE_CLOUD_LOCATION"] = LOCATION
os.environ["GOOGLE_GENAI_USE_VERTEXAI"] = "TRUE" # Use Vertex AI API

BUCKET_URI = "gs://llama31_training-europe"  # @param {type:"string"}

# @markdown 3. **[Optional]** Set region. If not set, the region will be set automatically according to Colab Enterprise environment.

REGION = LOCATION # "us-south1"  # @param {type:"string"}

# Get the default region for launching jobs.
if not REGION:
    if not os.environ.get("GOOGLE_CLOUD_REGION"):
        raise ValueError(
            "REGION must be set. See"
            " https://cloud.google.com/vertex-ai/docs/general/locations for"
            " available cloud locations."
        )
    REGION = os.environ["GOOGLE_CLOUD_REGION"]

# Enable the Vertex AI API and Compute Engine API, if not already.
print("Enabling Vertex AI API and Compute Engine API.")
! gcloud services enable aiplatform.googleapis.com compute.googleapis.com

# Cloud Storage bucket for storing the experiment artifacts.
# A unique GCS bucket will be created for the purpose of this notebook. If you
# prefer using your own GCS bucket, change the value yourself below.
now = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
BUCKET_NAME = "/".join(BUCKET_URI.split("/")[:3])

if BUCKET_URI is None or BUCKET_URI.strip() == "" or BUCKET_URI == "gs://":
    BUCKET_URI = f"gs://{PROJECT_ID}-tmp-{now}-{str(uuid.uuid4())[:4]}"
    BUCKET_NAME = "/".join(BUCKET_URI.split("/")[:3])
    ! gsutil mb -l {REGION} {BUCKET_URI}
else:
    assert BUCKET_URI.startswith("gs://"), "BUCKET_URI must start with `gs://`."
    shell_output = ! gsutil ls -Lb {BUCKET_NAME} | grep "Location constraint:" | sed "s/Location constraint://"
    bucket_region = shell_output[0].strip().lower()
    if bucket_region != REGION:
        raise ValueError(
            "Bucket region %s is different from notebook region %s"
            % (bucket_region, REGION)
        )
print(f"Using this GCS Bucket: {BUCKET_URI}")

STAGING_BUCKET = os.path.join(BUCKET_URI, "temporal")
MODEL_BUCKET = os.path.join(BUCKET_URI, "vllm_tpu")


# Initialize Vertex AI API.
print("Initializing Vertex AI API.")
aiplatform.init(project=PROJECT_ID, location=REGION, staging_bucket=STAGING_BUCKET)

# Gets the default SERVICE_ACCOUNT.
shell_output = ! gcloud projects describe $PROJECT_ID
project_number = shell_output[-1].split(":")[1].strip().replace("'", "")
SERVICE_ACCOUNT = f"{project_number}-compute@developer.gserviceaccount.com"
print("Using this default Service Account:", SERVICE_ACCOUNT)


# Provision permissions to the SERVICE_ACCOUNT with the GCS bucket
# ! gsutil iam ch serviceAccount:{SERVICE_ACCOUNT}:roles/storage.admin $BUCKET_NAME

# ! gcloud config set project $PROJECT_ID
# ! gcloud projects add-iam-policy-binding --no-user-output-enabled {PROJECT_ID} --member=serviceAccount:{SERVICE_ACCOUNT} --role="roles/storage.admin"
# ! gcloud projects add-iam-policy-binding --no-user-output-enabled {PROJECT_ID} --member=serviceAccount:{SERVICE_ACCOUNT} --role="roles/aiplatform.user"

fatal: destination path 'vertex-ai-samples' already exists and is not an empty directory.


2025-07-11 14:09:21.722931: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-07-11 14:09:22.135678: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-07-11 14:09:22.438611: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1752242962.682576    3709 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1752242962.751735    3709 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1752242963.390131    3709 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linkin

Enabling Vertex AI API and Compute Engine API.
Operation "operations/acat.p2-87995179092-014876a1-7f5e-46d9-84a1-857283160954" finished successfully.
Using this GCS Bucket: gs://llama31_training-europe
Initializing Vertex AI API.
Using this default Service Account: 87995179092-compute@developer.gserviceaccount.com


### Config

In [2]:
#!/usr/bin/env python3
"""
Comprehensive TPU Endpoint Benchmark Suite
Based on working TPU endpoint code with full benchmarking capabilities
"""

import argparse
import json
import os
import random
import time
import warnings
from dataclasses import dataclass, asdict
from datetime import datetime
from typing import Any, Optional, List, Dict, Union, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
import traceback

import numpy as np
import pandas as pd
from tqdm import tqdm

# Google Cloud imports
import google.auth
import openai
from google.cloud import aiplatform

# Configuration - Set your actual values here
PROJECT_ID = os.environ.get("GOOGLE_CLOUD_PROJECT", "your-project-id")
REGION = "europe-west4"
endpoint_name = "6859529789275897856"

# Initialize endpoint
aip_endpoint_name = f"projects/{PROJECT_ID}/locations/{REGION}/endpoints/{endpoint_name}"
endpoint = aiplatform.Endpoint(aip_endpoint_name)
use_dedicated_endpoint = True

print('🔧 Endpoint initialized:', endpoint)
if use_dedicated_endpoint:
    DEDICATED_ENDPOINT_DNS = endpoint.gca_resource.dedicated_endpoint_dns
ENDPOINT_RESOURCE_NAME = "projects/{}/locations/{}/endpoints/{}".format(
    PROJECT_ID, REGION, endpoint.name
)
print(f"🌐 DNS: {DEDICATED_ENDPOINT_DNS}")
print(f"📋 Resource: {ENDPOINT_RESOURCE_NAME}")

🔧 Endpoint initialized: <google.cloud.aiplatform.models.Endpoint object at 0x7fc3b4269630> 
resource name: projects/87995179092/locations/europe-west4/endpoints/6859529789275897856
🌐 DNS: 6859529789275897856.europe-west4-87995179092.prediction.vertexai.goog
📋 Resource: projects/tpu-launchpad-playground/locations/europe-west4/endpoints/6859529789275897856


### Main Code

In [3]:
#!/usr/bin/env python3
"""
Comprehensive TPU Endpoint Benchmark Suite
Based on working TPU endpoint code with full benchmarking capabilities
"""

import argparse
import json
import os
import random
import time
import warnings
from dataclasses import dataclass, asdict
from datetime import datetime
from typing import Any, Optional, List, Dict, Union, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
import traceback

import numpy as np
import pandas as pd
from tqdm import tqdm

# Google Cloud imports
import google.auth
import openai
from google.cloud import aiplatform

# Configuration - Set your actual values here
PROJECT_ID = os.environ.get("GOOGLE_CLOUD_PROJECT", "your-project-id")
REGION = "europe-west4"
endpoint_name = "6859529789275897856"

# Initialize endpoint
aip_endpoint_name = f"projects/{PROJECT_ID}/locations/{REGION}/endpoints/{endpoint_name}"
endpoint = aiplatform.Endpoint(aip_endpoint_name)
use_dedicated_endpoint = True

print('🔧 Endpoint initialized:', endpoint)
if use_dedicated_endpoint:
    DEDICATED_ENDPOINT_DNS = endpoint.gca_resource.dedicated_endpoint_dns
ENDPOINT_RESOURCE_NAME = "projects/{}/locations/{}/endpoints/{}".format(
    PROJECT_ID, REGION, endpoint.name
)
print(f"🌐 DNS: {DEDICATED_ENDPOINT_DNS}")
print(f"📋 Resource: {ENDPOINT_RESOURCE_NAME}")

@dataclass
class BenchmarkRequest:
    """Request data structure for benchmarking"""
    prompt: str
    prompt_len: int
    expected_output_len: int
    request_id: int

@dataclass
class BenchmarkResult:
    """Result of a single benchmark request"""
    request_id: int
    success: bool
    prompt_len: int
    output_len: int
    ttft: float  # Time to first token
    tpot: float  # Time per output token
    itl: float   # Inter-token latency
    e2e_latency: float  # End-to-end latency
    error_msg: str = ""
    timestamp: float = 0.0
    full_response: str = ""

class RandomDataset:
    """Generate random prompts for benchmarking"""
    
    def __init__(self, input_len: int, output_len: int, num_requests: int, range_ratio: float = 0.0):
        self.input_len = input_len
        self.output_len = output_len
        self.num_requests = num_requests
        self.range_ratio = range_ratio
        
    def _generate_random_prompt(self, length: int) -> str:
        """Generate a random prompt of specified length"""
        words = [
            "analyze", "consider", "evaluate", "examine", "investigate", "review", "assess", "study",
            "business", "technology", "strategy", "development", "innovation", "implementation", "solution",
            "market", "customer", "product", "service", "quality", "performance", "efficiency", "growth",
            "data", "information", "process", "system", "method", "approach", "framework", "model",
            "challenge", "opportunity", "risk", "benefit", "advantage", "improvement", "optimization",
            "research", "analysis", "report", "recommendation", "conclusion", "insight", "finding",
            "artificial", "intelligence", "machine", "learning", "neural", "network", "algorithm",
            "compute", "processor", "memory", "storage", "bandwidth", "latency", "throughput"
        ]
        
        prompt_words = []
        target_words = int(length * 0.75)  # Rough token-to-word conversion
        
        while len(prompt_words) < target_words:
            prompt_words.append(random.choice(words))
        
        # Add a question or instruction to make it more realistic
        prompt_base = " ".join(prompt_words)
        prompts = [
            f"Explain the following concepts in detail: {prompt_base}",
            f"Write a comprehensive analysis of: {prompt_base}",
            f"Describe the relationship between: {prompt_base}",
            f"Provide insights about: {prompt_base}",
            f"Create a detailed report on: {prompt_base}"
        ]
        
        return random.choice(prompts)
    
    def generate_requests(self) -> List[BenchmarkRequest]:
        """Generate benchmark requests"""
        requests = []
        
        for i in range(self.num_requests):
            if self.range_ratio > 0:
                input_variance = int(self.input_len * self.range_ratio)
                output_variance = int(self.output_len * self.range_ratio)
                
                actual_input_len = random.randint(
                    max(1, self.input_len - input_variance),
                    self.input_len + input_variance
                )
                actual_output_len = random.randint(
                    max(1, self.output_len - output_variance),
                    self.output_len + output_variance
                )
            else:
                actual_input_len = self.input_len
                actual_output_len = self.output_len
            
            prompt = self._generate_random_prompt(actual_input_len)
            
            requests.append(BenchmarkRequest(
                prompt=prompt,
                prompt_len=actual_input_len,
                expected_output_len=actual_output_len,
                request_id=i
            ))
        
        return requests

class TPUBenchmarkEngine:
    """Benchmark engine using your exact working TPU endpoint pattern"""
    
    def __init__(self):
        self.results_lock = threading.Lock()
        self.results: List[BenchmarkResult] = []
        
        # Initialize authentication using your exact pattern
        self.creds, self.project = google.auth.default()
        self.auth_req = google.auth.transport.requests.Request()
        
        # Setup BASE_URL using your exact logic
        self.BASE_URL = f"https://{REGION}-aiplatform.googleapis.com/v1beta1/{ENDPOINT_RESOURCE_NAME}"
        
        if use_dedicated_endpoint:
            self.BASE_URL = f"https://{DEDICATED_ENDPOINT_DNS}/v1beta1/{ENDPOINT_RESOURCE_NAME}"
        
        print(f"🔗 Base URL: {self.BASE_URL}")
    
    def _refresh_auth(self):
        """Refresh authentication token"""
        self.creds.refresh(self.auth_req)
    
    def _make_streaming_request(self, request: BenchmarkRequest, 
                              temperature: float = 0.7, 
                              max_tokens: int = None,
                              stream: bool = True) -> BenchmarkResult:
        """Make a streaming request using your exact working pattern"""
        
        if max_tokens is None:
            max_tokens = request.expected_output_len + 50
        
        # Initialize variables for error handling
        request_start = time.time()
        client = None
        model_response = None
        
        try:
            # Refresh auth token
            self._refresh_auth()
            
            # Create OpenAI client using your exact setup
            client = openai.OpenAI(
                base_url=self.BASE_URL, 
                api_key=self.creds.token,
                timeout=60.0,  # Add timeout to prevent hanging connections
                max_retries=1   # Reduce retries to avoid connection buildup
            )
            
            # Start timing
            ttft = None
            last_token_time = request_start
            inter_token_latencies = []
            
            # Make request using your exact model request pattern
            model_response = client.chat.completions.create(
                model="",  # Your exact model parameter
                messages=[{"role": "user", "content": request.prompt}],
                temperature=temperature,
                max_tokens=max_tokens,
                stream=stream,
            )
            
            # Process streaming response using your exact logic
            if stream:
                usage = None
                contents = []
                token_count = 0
                
                try:
                    for chunk in model_response:
                        current_time = time.time()
                        
                        if chunk.usage is not None:
                            usage = chunk.usage
                            continue
                        
                        content = chunk.choices[0].delta.content
                        if content:  # Only process if there's actual content
                            # Timing measurements using your exact pattern
                            if ttft is None:
                                ttft = current_time - request_start
                            else:
                                itl = current_time - last_token_time
                                inter_token_latencies.append(itl)
                            
                            contents.append(content)
                            token_count += 1
                            last_token_time = current_time
                
                finally:
                    # Ensure streaming response is properly closed
                    if hasattr(model_response, 'close'):
                        try:
                            model_response.close()
                        except:
                            pass
                
                # Final measurements
                e2e_latency = time.time() - request_start
                full_text = ''.join(contents)
                
                # Calculate TPOT
                if inter_token_latencies:
                    avg_tpot = sum(inter_token_latencies) / len(inter_token_latencies)
                    avg_itl = avg_tpot
                else:
                    avg_tpot = e2e_latency / max(1, token_count) if token_count > 0 else 0
                    avg_itl = avg_tpot
                
                return BenchmarkResult(
                    request_id=request.request_id,
                    success=True,
                    prompt_len=len(request.prompt.split()) * 1.3,  # Rough token estimate
                    output_len=token_count,
                    ttft=ttft if ttft else e2e_latency,
                    tpot=avg_tpot,
                    itl=avg_itl,
                    e2e_latency=e2e_latency,
                    timestamp=request_start,
                    full_response=full_text
                )
            else:
                # Non-streaming response
                e2e_latency = time.time() - request_start
                response_text = model_response.choices[0].message.content
                token_count = len(response_text.split()) * 1.3  # Rough estimate
                
                # Estimate TTFT and TPOT for non-streaming
                estimated_ttft = e2e_latency * 0.2  # 20% for processing
                estimated_tpot = (e2e_latency - estimated_ttft) / max(1, token_count)
                
                return BenchmarkResult(
                    request_id=request.request_id,
                    success=True,
                    prompt_len=len(request.prompt.split()) * 1.3,
                    output_len=int(token_count),
                    ttft=estimated_ttft,
                    tpot=estimated_tpot,
                    itl=estimated_tpot,
                    e2e_latency=e2e_latency,
                    timestamp=request_start,
                    full_response=response_text
                )
                
        except Exception as e:
            error_time = time.time() - request_start if 'request_start' in locals() else 0
            
            # Clean up any open connections
            try:
                if model_response and hasattr(model_response, 'close'):
                    model_response.close()
                if client and hasattr(client, 'close'):
                    client.close()
            except:
                pass
            
            return BenchmarkResult(
                request_id=request.request_id,
                success=False,
                prompt_len=len(request.prompt.split()) * 1.3,
                output_len=0,
                ttft=0.0,
                tpot=0.0,
                itl=0.0,
                e2e_latency=error_time,
                error_msg=str(e),
                timestamp=time.time(),
                full_response=""
            )
    
    def run_benchmark(self, 
                      requests: List[BenchmarkRequest],
                      max_concurrency: int = 100,
                      temperature: float = 0.7,
                      max_tokens: int = None,
                      stream: bool = True,
                      request_rate: float = float('inf')) -> List[BenchmarkResult]:
        """Run benchmark with specified parameters"""
        
        print(f"🚀 Starting benchmark with {len(requests)} requests...")
        print(f"👥 Max concurrency: {max_concurrency}")
        print(f"🌡️ Temperature: {temperature}")
        print(f"🔄 Streaming: {stream}")
        
        self.results = []
        start_time = time.time()
        
        # Rate limiting setup
        if request_rate != float('inf'):
            request_interval = 1.0 / request_rate
            print(f"⏱️ Request rate limit: {request_rate} req/s")
        
        with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
            future_to_request = {}
            
            for i, req in enumerate(requests):
                # Rate limiting
                if request_rate != float('inf') and i > 0:
                    time.sleep(request_interval)
                
                future = executor.submit(
                    self._make_streaming_request, 
                    req, 
                    temperature, 
                    max_tokens, 
                    stream
                )
                future_to_request[future] = req
            
            with tqdm(total=len(requests), desc="📊 Processing requests") as pbar:
                for future in as_completed(future_to_request):
                    try:
                        result = future.result()
                        with self.results_lock:
                            self.results.append(result)
                    except Exception as e:
                        request = future_to_request[future]
                        error_result = BenchmarkResult(
                            request_id=request.request_id,
                            success=False,
                            prompt_len=len(request.prompt.split()) * 1.3,
                            output_len=0,
                            ttft=0.0,
                            tpot=0.0,
                            itl=0.0,
                            e2e_latency=0.0,
                            error_msg=f"Future execution failed: {str(e)}",
                            timestamp=time.time(),
                            full_response=""
                        )
                        with self.results_lock:
                            self.results.append(error_result)
                    
                    pbar.update(1)
        
        total_time = time.time() - start_time
        successful_count = len([r for r in self.results if r.success])
        print(f"✅ Benchmark completed in {total_time:.2f} seconds")
        print(f"📈 Success rate: {successful_count}/{len(requests)} ({successful_count/len(requests)*100:.1f}%)")
        
        return self.results

class BenchmarkAnalyzer:
    """Analyze and report benchmark results"""
    
    def __init__(self, results: List[BenchmarkResult]):
        self.results = results
        self.successful_results = [r for r in results if r.success]
        self.failed_results = [r for r in results if not r.success]
    
    def calculate_percentiles(self, values: List[float], percentiles: List[int]) -> Dict[int, float]:
        """Calculate percentiles for a list of values"""
        if not values:
            return {p: 0.0 for p in percentiles}
        return {p: np.percentile(values, p) for p in percentiles}
    
    def generate_summary(self) -> Dict[str, Any]:
        """Generate benchmark summary statistics"""
        if not self.successful_results:
            return {
                "error": "No successful requests",
                "total_requests": len(self.results),
                "failed_requests": len(self.failed_results)
            }
        
        ttfts = [r.ttft for r in self.successful_results]
        tpots = [r.tpot * 1000 for r in self.successful_results]  # Convert to ms
        e2e_latencies = [r.e2e_latency for r in self.successful_results]
        
        total_input_tokens = sum(r.prompt_len for r in self.successful_results)
        total_output_tokens = sum(r.output_len for r in self.successful_results)
        
        if self.successful_results:
            timestamps = [r.timestamp for r in self.successful_results]
            benchmark_duration = max(timestamps) - min(timestamps) + max(e2e_latencies)
        else:
            benchmark_duration = 1.0
        
        request_throughput = len(self.successful_results) / benchmark_duration
        input_token_throughput = total_input_tokens / benchmark_duration
        output_token_throughput = total_output_tokens / benchmark_duration
        overall_token_throughput = (total_input_tokens + total_output_tokens) / benchmark_duration
        
        percentiles = [50, 90, 95, 99]
        
        summary = {
            "successful_requests": len(self.successful_results),
            "failed_requests": len(self.failed_results),
            "total_requests": len(self.results),
            "benchmark_duration": benchmark_duration,
            "request_throughput": request_throughput,
            "input_token_throughput": input_token_throughput,
            "output_token_throughput": output_token_throughput,
            "overall_token_throughput": overall_token_throughput,
            "total_input_tokens": total_input_tokens,
            "total_output_tokens": total_output_tokens,
            "ttft_percentiles": self.calculate_percentiles(ttfts, percentiles),
            "tpot_percentiles": self.calculate_percentiles(tpots, percentiles),
            "e2e_latency_percentiles": self.calculate_percentiles(e2e_latencies, percentiles)
        }
        
        return summary
    
    def print_detailed_summary(self):
        """Print detailed performance summary"""
        summary = self.generate_summary()
        
        if "error" in summary:
            print(f"❌ Error: {summary['error']}")
            return
        
        print("\n" + "="*80)
        print("📊 DETAILED PERFORMANCE ANALYSIS")
        print("="*80)
        
        print(f"\n📋 REQUEST STATISTICS:")
        print(f"   ✅ Successful: {summary['successful_requests']:,}")
        print(f"   ❌ Failed: {summary['failed_requests']:,}")
        print(f"   📊 Success Rate: {summary['successful_requests']/summary['total_requests']*100:.1f}%")
        print(f"   ⏱️ Duration: {summary['benchmark_duration']:.2f}s")
        
        print(f"\n⚡ LATENCY METRICS:")
        print(f"   🚀 TTFT p50: {summary['ttft_percentiles'][50]*1000:.1f}ms")
        print(f"   🚀 TTFT p95: {summary['ttft_percentiles'][95]*1000:.1f}ms")
        print(f"   🚀 TTFT p99: {summary['ttft_percentiles'][99]*1000:.1f}ms")
        
        print(f"\n🔄 TIME PER OUTPUT TOKEN:")
        print(f"   ⚡ TPOT p50: {summary['tpot_percentiles'][50]:.1f}ms")
        print(f"   ⚡ TPOT p95: {summary['tpot_percentiles'][95]:.1f}ms")
        print(f"   ⚡ TPOT p99: {summary['tpot_percentiles'][99]:.1f}ms")
        
        print(f"\n⏱️ END-TO-END LATENCY:")
        print(f"   📈 E2E p50: {summary['e2e_latency_percentiles'][50]:.2f}s")
        print(f"   📈 E2E p95: {summary['e2e_latency_percentiles'][95]:.2f}s")
        print(f"   📈 E2E p99: {summary['e2e_latency_percentiles'][99]:.2f}s")
        
        print(f"\n🚀 THROUGHPUT METRICS:")
        print(f"   📊 Requests/sec: {summary['request_throughput']:.2f}")
        print(f"   📤 Output tokens/sec: {summary['output_token_throughput']:.2f}")
        print(f"   📊 Overall tokens/sec: {summary['overall_token_throughput']:.0f}")
        
        print(f"\n📊 TOKEN STATISTICS:")
        print(f"   📥 Total input tokens: {summary['total_input_tokens']:,}")
        print(f"   📤 Total output tokens: {summary['total_output_tokens']:,}")
        print(f"   📊 Avg input/request: {summary['total_input_tokens']/summary['successful_requests']:.1f}")
        print(f"   📊 Avg output/request: {summary['total_output_tokens']/summary['successful_requests']:.1f}")

# Token Length Test Configurations
TOKEN_LENGTH_EXPERIMENTS = {
    "small": {
        "name": "Token Length - Small(250)",
        "input_tokens": 250,
        "output_tokens": 250,
        "total_tokens": 500,
        "concurrency": 250,
        "num_requests": 2000,
        "description": "Small token length test"
    },
    "medium": {
        "name": "Token Length - Medium(350 - 580)",
        "input_tokens": 465,  # Average of 350-580
        "output_tokens": 465,
        "total_tokens": 930,
        "concurrency": 150,
        "num_requests": 1500,
        "description": "Medium token length test"
    },
    "large": {
        "name": "Token Length - Large(1200 - 1800)",
        "input_tokens": 1500,  # Average of 1200-1800
        "output_tokens": 1500,
        "total_tokens": 3000,
        "concurrency": 100,
        "num_requests": 1000,
        "description": "Large token length test"
    },
    "xlarge": {
        "name": "Token Length - Xlarge(2.5k - 4k)",
        "input_tokens": 3250,  # Average of 2.5k-4k
        "output_tokens": 1000,  # Reasonable output for very long input
        "total_tokens": 4250,
        "concurrency": 30,
        "num_requests": 500,
        "description": "Extra large token length test"
    }
}

def run_simple_test(prompt: str = "Write a short poem about artificial intelligence",
                   max_tokens: int = 200,
                   temperature: float = 0.8,
                   stream: bool = True):
    """Run a simple test using your exact working pattern"""
    
    print(f"🚀 Running Simple Test")
    print(f"   📝 Prompt: {prompt}")
    print(f"   🎯 Max tokens: {max_tokens}")
    print(f"   🌡️ Temperature: {temperature}")
    print(f"   🔄 Stream: {stream}")
    print("-" * 60)
    
    # Use your exact authentication setup
    creds, project = google.auth.default()
    auth_req = google.auth.transport.requests.Request()
    creds.refresh(auth_req)
    
    # Use your exact BASE_URL setup
    BASE_URL = f"https://{REGION}-aiplatform.googleapis.com/v1beta1/{ENDPOINT_RESOURCE_NAME}"
    
    if use_dedicated_endpoint:
        BASE_URL = f"https://{DEDICATED_ENDPOINT_DNS}/v1beta1/{ENDPOINT_RESOURCE_NAME}"
    
    # Use your exact client setup
    client = openai.OpenAI(base_url=BASE_URL, api_key=creds.token)
    
    # Start timing
    request_start = time.time()
    ttft = None
    last_token_time = request_start
    inter_token_latencies = []
    
    # Your exact model request
    model_response = client.chat.completions.create(
        model="",
        messages=[{"role": "user", "content": prompt}],
        temperature=temperature,
        max_tokens=max_tokens,
        stream=stream,
    )
    
    print("📡 Streaming response:")
    print("-" * 40)
    
    # Your exact streaming logic with timing added
    if stream:
        usage = None
        contents = []
        token_count = 0
        
        for chunk in model_response:
            current_time = time.time()
            
            if chunk.usage is not None:
                usage = chunk.usage
                continue
            
            content = chunk.choices[0].delta.content
            if content:  # Only process if there's actual content
                # Timing measurements
                if ttft is None:
                    ttft = current_time - request_start
                    print(f"⚡ TTFT: {ttft:.3f}s")
                else:
                    itl = current_time - last_token_time
                    inter_token_latencies.append(itl)
                
                print(content, end="", flush=True)
                contents.append(content)
                token_count += 1
                last_token_time = current_time
        
        # Final measurements
        e2e_latency = time.time() - request_start
        
        print(f"\n\n📊 Performance Metrics:")
        print("-" * 40)
        print(f"✅ Request completed successfully")
        print(f"⏱️  E2E Latency: {e2e_latency:.3f}s")
        print(f"⚡ TTFT: {ttft:.3f}s" if ttft else "⚡ TTFT: N/A")
        
        if inter_token_latencies:
            avg_tpot = sum(inter_token_latencies) / len(inter_token_latencies)
            print(f"🔄 Average TPOT: {avg_tpot:.3f}s")
            print(f"📈 Min/Max ITL: {min(inter_token_latencies):.3f}s / {max(inter_token_latencies):.3f}s")
            print(f"🚀 Tokens/second: {token_count / e2e_latency:.2f}")
        
        print(f"🔢 Total tokens: {token_count}")
        print(f"📏 Total characters: {len(''.join(contents))}")
        
        if usage:
            print(f"💾 Usage: {usage}")
        
        return {
            'success': True,
            'ttft': ttft,
            'e2e_latency': e2e_latency,
            'token_count': token_count,
            'inter_token_latencies': inter_token_latencies,
            'full_text': ''.join(contents),
            'usage': usage
        }
    else:
        response_text = model_response.choices[0].message.content
        e2e_latency = time.time() - request_start
        print(f"📄 Response: {response_text}")
        print(f"⏱️ E2E Latency: {e2e_latency:.3f}s")
        return {'success': True, 'response': response_text, 'e2e_latency': e2e_latency}

def run_token_length_experiment(experiment_name: str,
                               model_name: str = "llama3.3_tpuv6e",
                               device_type: str = "TPU v6e") -> Dict[str, Any]:
    """Run a specific token length experiment"""
    
    if experiment_name not in TOKEN_LENGTH_EXPERIMENTS:
        available = ", ".join(TOKEN_LENGTH_EXPERIMENTS.keys())
        raise ValueError(f"Unknown experiment '{experiment_name}'. Available: {available}")
    
    config = TOKEN_LENGTH_EXPERIMENTS[experiment_name]
    
    print(f"\n🔬 Running Token Length Experiment: {config['name']}")
    print(f"📊 Input tokens: {config['input_tokens']}")
    print(f"📊 Output tokens: {config['output_tokens']}")
    print(f"📊 Total tokens: {config['total_tokens']}")
    print(f"🚀 Concurrency: {config['concurrency']}")
    print(f"📝 Requests: {config['num_requests']}")
    
    # Generate dataset
    dataset = RandomDataset(
        input_len=config['input_tokens'],
        output_len=config['output_tokens'],
        num_requests=config['num_requests']
    )
    
    requests = dataset.generate_requests()
    
    # Run benchmark
    engine = TPUBenchmarkEngine()
    results = engine.run_benchmark(
        requests=requests,
        max_concurrency=config['concurrency'],
        temperature=0.7,
        stream=True
    )
    
    # Analyze results
    analyzer = BenchmarkAnalyzer(results)
    analyzer.print_detailed_summary()
    summary = analyzer.generate_summary()
    
    if "error" not in summary:
        print(f"\n✅ Experiment Complete: {config['name']}")
        print(f"📈 TTFT-P95: {summary['ttft_percentiles'][95]:.2f}s")
        print(f"📈 Token Output Throughput: {summary['output_token_throughput']:.2f} tok/s")
        print(f"📈 Overall Token Throughput: {summary['overall_token_throughput']:.0f} tok/s")
        print(f"👥 Concurrent Users: {config['concurrency']}")
    else:
        print(f"❌ Experiment Failed: {summary['error']}")
    
    # Add experiment metadata
    summary.update({
        "experiment_name": experiment_name,
        "experiment_config": config,
        "model_name": model_name,
        "device_type": device_type,
        "timestamp": datetime.now().isoformat()
    })
    
    return summary

def run_conservative_token_length_study(model_name: str = "llama3.3_tpuv6e",
                                       device_type: str = "v6e-8 TPU (256 GB HBM)",
                                       experiments: List[str] = None) -> Dict[str, Dict[str, Any]]:
    """Run token length study with conservative concurrency to avoid file descriptor issues"""
    
    if experiments is None:
        experiments = list(TOKEN_LENGTH_EXPERIMENTS.keys())
    
    print(f"🚀 Starting Conservative Token Length Study (Reduced Concurrency)")
    print(f"📋 Model: {model_name}")
    print(f"🖥️ Device: {device_type}")
    print(f"🧪 Experiments: {len(experiments)}")
    
    # Conservative concurrency settings to avoid "too many open files"
    conservative_settings = {
        "small": {"concurrency": 25, "num_requests": 500},    # Reduced from 250
        "medium": {"concurrency": 15, "num_requests": 300},   # Reduced from 150  
        "large": {"concurrency": 10, "num_requests": 200},    # Reduced from 100
        "xlarge": {"concurrency": 5, "num_requests": 100}     # Reduced from 30
    }
    
    all_results = {}
    
    for exp_name in experiments:
        print(f"\n{'='*80}")
        
        # Temporarily modify config for conservative run
        original_config = TOKEN_LENGTH_EXPERIMENTS[exp_name].copy()
        if exp_name in conservative_settings:
            TOKEN_LENGTH_EXPERIMENTS[exp_name].update(conservative_settings[exp_name])
            print(f"🔧 Using conservative settings: {conservative_settings[exp_name]}")
        
        try:
            result = run_token_length_experiment(
                experiment_name=exp_name,
                model_name=model_name,
                device_type=device_type
            )
            all_results[exp_name] = result
            
            # Small delay between experiments to let connections close
            time.sleep(2)
            
        except Exception as e:
            print(f"❌ Experiment {exp_name} failed: {e}")
            all_results[exp_name] = {
                "error": str(e),
                "experiment_name": exp_name
            }
        finally:
            # Restore original config
            TOKEN_LENGTH_EXPERIMENTS[exp_name] = original_config
    
    # Generate comparison report
    generate_token_length_comparison_report(all_results, model_name, device_type)
    
    return all_results
def run_comprehensive_token_length_study(model_name: str = "llama3.3_tpuv6e",
                                        device_type: str = "v6e-8 TPU (256 GB HBM)",
                                        experiments: List[str] = None) -> Dict[str, Dict[str, Any]]:
    """Run comprehensive token length study across all experiments"""
    
    if experiments is None:
        experiments = list(TOKEN_LENGTH_EXPERIMENTS.keys())
    
    print(f"🚀 Starting Comprehensive Token Length Study")
    print(f"📋 Model: {model_name}")
    print(f"🖥️ Device: {device_type}")
    print(f"🧪 Experiments: {len(experiments)}")
    print(f"⚠️  WARNING: High concurrency may cause 'too many open files' error")
    print(f"💡 Consider using run_conservative_token_length_study() for stability")
    
    all_results = {}
    
    for exp_name in experiments:
        print(f"\n{'='*80}")
        try:
            result = run_token_length_experiment(
                experiment_name=exp_name,
                model_name=model_name,
                device_type=device_type
            )
            all_results[exp_name] = result
            
            # Small delay between experiments to let connections close
            time.sleep(3)
            
        except Exception as e:
            print(f"❌ Experiment {exp_name} failed: {e}")
            all_results[exp_name] = {
                "error": str(e),
                "experiment_name": exp_name
            }
    
    # Generate comparison report
    generate_token_length_comparison_report(all_results, model_name, device_type)
    
    return all_results

def generate_token_length_comparison_report(all_results: Dict[str, Dict[str, Any]], 
                                          model_name: str,
                                          device_type: str):
    """Generate comparison report in CSV format matching your screenshot with ALL metrics"""
    
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = "token_length_benchmarks"
    os.makedirs(output_dir, exist_ok=True)
    
    # Generate comprehensive CSV report with ALL metrics
    comprehensive_csv_data = []
    
    for exp_name, result in all_results.items():
        if "error" not in result:
            config = TOKEN_LENGTH_EXPERIMENTS[exp_name]
            
            # Base row data
            base_row = {
                "Token_Length_Category": config['name'],
                "Device_Type": device_type,
                "Model_Name": result['model_name'],
                "Input_Tokens": config['input_tokens'],
                "Output_Tokens": config['output_tokens'],
                "Total_Tokens": config['total_tokens'],
                "Concurrent_Users": config['concurrency'],
                "Total_Requests": config['num_requests'],
                "Successful_Requests": result['successful_requests'],
                "Failed_Requests": result['failed_requests'],
                "Success_Rate_Percent": f"{result['successful_requests'] / result['total_requests'] * 100:.1f}%",
                "Test_Duration_Seconds": f"{result['benchmark_duration']:.2f}",
                
                # Key Performance Metrics (Your 4 main ones)
                "TTFT_P95_Seconds": f"{result['ttft_percentiles'][95]:.3f}",
                "Token_Output_Throughput_Per_Second": f"{result['output_token_throughput']:.2f}",
                "Overall_Token_Throughput": f"{result['overall_token_throughput']:.0f}",
                
                # Complete TTFT Metrics
                "TTFT_P50_Seconds": f"{result['ttft_percentiles'][50]:.3f}",
                "TTFT_P90_Seconds": f"{result['ttft_percentiles'][90]:.3f}",
                "TTFT_P99_Seconds": f"{result['ttft_percentiles'][99]:.3f}",
                
                # Complete TPOT Metrics
                "TPOT_P50_ms": f"{result['tpot_percentiles'][50]:.1f}",
                "TPOT_P90_ms": f"{result['tpot_percentiles'][90]:.1f}",
                "TPOT_P95_ms": f"{result['tpot_percentiles'][95]:.1f}",
                "TPOT_P99_ms": f"{result['tpot_percentiles'][99]:.1f}",
                
                # Complete End-to-End Latency Metrics
                "E2E_Latency_P50_Seconds": f"{result['e2e_latency_percentiles'][50]:.2f}",
                "E2E_Latency_P90_Seconds": f"{result['e2e_latency_percentiles'][90]:.2f}",
                "E2E_Latency_P95_Seconds": f"{result['e2e_latency_percentiles'][95]:.2f}",
                "E2E_Latency_P99_Seconds": f"{result['e2e_latency_percentiles'][99]:.2f}",
                
                # Throughput Metrics
                "Request_Throughput_Per_Second": f"{result['request_throughput']:.2f}",
                "Input_Token_Throughput_Per_Second": f"{result['input_token_throughput']:.2f}",
                
                # Token Statistics
                "Total_Input_Tokens_Processed": f"{result['total_input_tokens']:,}",
                "Total_Output_Tokens_Generated": f"{result['total_output_tokens']:,}",
                "Avg_Input_Tokens_Per_Request": f"{result['total_input_tokens'] / result['successful_requests']:.1f}",
                "Avg_Output_Tokens_Per_Request": f"{result['total_output_tokens'] / result['successful_requests']:.1f}",
                
                # Efficiency Metrics
                "Tokens_Per_Second_Per_User": f"{result['overall_token_throughput'] / config['concurrency']:.2f}",
                "Requests_Per_Second_Per_User": f"{result['request_throughput'] / config['concurrency']:.3f}",
                "Output_Input_Token_Ratio": f"{(result['total_output_tokens'] / result['total_input_tokens']):.2f}",
                
                # Test Configuration
                "Temperature": "0.7",
                "Streaming": "True",
                "Test_Timestamp": result['timestamp']
            }
            
            comprehensive_csv_data.append(base_row)
    
    # Save comprehensive CSV
    comprehensive_csv_file = os.path.join(output_dir, f"comprehensive_token_length_analysis_{model_name.replace('.', '_')}_{timestamp}.csv")
    comprehensive_df = pd.DataFrame(comprehensive_csv_data)
    comprehensive_df.to_csv(comprehensive_csv_file, index=False)
    
    # Generate formatted comparison table (key metrics only - like your screenshot)
    key_metrics_data = []
    
    for exp_name, result in all_results.items():
        if "error" not in result:
            config = TOKEN_LENGTH_EXPERIMENTS[exp_name]
            
            key_metrics_data.append({
                "Token_Length_Category": config['name'],
                "TTFT_P95_Seconds": f"{result['ttft_percentiles'][95]:.3f}",
                "Token_Output_Throughput_Per_Second": f"{result['output_token_throughput']:.0f}",
                "Overall_Token_Throughput": f"{result['overall_token_throughput']:.0f}",
                "Concurrent_Users": f"{config['concurrency']}"
            })
    
    # Save key metrics comparison (matching your screenshot format)
    key_metrics_file = os.path.join(output_dir, f"key_metrics_comparison_{model_name.replace('.', '_')}_{timestamp}.csv")
    key_metrics_df = pd.DataFrame(key_metrics_data)
    key_metrics_df.to_csv(key_metrics_file, index=False)
    
    print(f"\n📊 Comprehensive Reports Generated:")
    print(f"   🎯 Key Metrics CSV (Your Format): {key_metrics_file}")
    print(f"   📊 Comprehensive Analysis CSV: {comprehensive_csv_file}")

def quick_token_length_test(model_name: str = "llama3.3_tpuv6e"):
    """Run a quick test across all token length categories with reduced sizes"""
    
    # Reduced test sizes for quick evaluation
    quick_experiments = {
        "small": {"num_requests": 100, "concurrency": 25},
        "medium": {"num_requests": 75, "concurrency": 15},
        "large": {"num_requests": 50, "concurrency": 10},
        "xlarge": {"num_requests": 25, "concurrency": 5}
    }
    
    print("🚀 Running Quick Token Length Test...")
    
    results = {}
    for exp_name in quick_experiments:
        # Temporarily modify config
        original_config = TOKEN_LENGTH_EXPERIMENTS[exp_name].copy()
        TOKEN_LENGTH_EXPERIMENTS[exp_name].update(quick_experiments[exp_name])
        
        try:
            result = run_token_length_experiment(
                experiment_name=exp_name,
                model_name=model_name
            )
            results[exp_name] = result
        except Exception as e:
            print(f"❌ Quick test {exp_name} failed: {e}")
            results[exp_name] = {"error": str(e)}
        
        # Restore original config
        TOKEN_LENGTH_EXPERIMENTS[exp_name] = original_config
    
    # Generate quick report
    generate_token_length_comparison_report(results, model_name, "v6e-8 TPU (256 GB HBM)")
    
    return results

def run_custom_benchmark(input_tokens: int = 500,
                        output_tokens: int = 500,
                        num_requests: int = 100,
                        concurrency: int = 10,
                        temperature: float = 0.7,
                        stream: bool = True,
                        model_name: str = "custom_test"):
    """Run a custom benchmark with specified parameters"""
    
    print(f"🧪 Running Custom Benchmark")
    print(f"📊 Input tokens: {input_tokens}")
    print(f"📊 Output tokens: {output_tokens}")
    print(f"📝 Requests: {num_requests}")
    print(f"🚀 Concurrency: {concurrency}")
    print(f"🌡️ Temperature: {temperature}")
    print(f"🔄 Streaming: {stream}")
    
    # Generate dataset
    dataset = RandomDataset(
        input_len=input_tokens,
        output_len=output_tokens,
        num_requests=num_requests
    )
    
    requests = dataset.generate_requests()
    
    # Run benchmark
    engine = TPUBenchmarkEngine()
    results = engine.run_benchmark(
        requests=requests,
        max_concurrency=concurrency,
        temperature=temperature,
        stream=stream
    )
    
    # Analyze results
    analyzer = BenchmarkAnalyzer(results)
    analyzer.print_detailed_summary()
    
    return analyzer.generate_summary()

def run_throughput_scaling_test(base_concurrency: int = 10,
                              max_concurrency: int = 100,
                              step: int = 10,
                              num_requests: int = 200):
    """Test how throughput scales with concurrency"""
    
    print(f"📈 Running Throughput Scaling Test")
    print(f"🚀 Concurrency range: {base_concurrency} to {max_concurrency} (step {step})")
    print(f"📝 Requests per test: {num_requests}")
    
    scaling_results = []
    
    for concurrency in range(base_concurrency, max_concurrency + 1, step):
        print(f"\n🧪 Testing concurrency: {concurrency}")
        
        # Generate small dataset for quick testing
        dataset = RandomDataset(
            input_len=250,
            output_len=250,
            num_requests=num_requests
        )
        
        requests = dataset.generate_requests()
        
        # Run benchmark
        engine = TPUBenchmarkEngine()
        results = engine.run_benchmark(
            requests=requests,
            max_concurrency=concurrency,
            temperature=0.7,
            stream=True
        )
        
        # Analyze results
        analyzer = BenchmarkAnalyzer(results)
        summary = analyzer.generate_summary()
        
        if "error" not in summary:
            scaling_results.append({
                "concurrency": concurrency,
                "request_throughput": summary['request_throughput'],
                "output_token_throughput": summary['output_token_throughput'],
                "overall_token_throughput": summary['overall_token_throughput'],
                "ttft_p95": summary['ttft_percentiles'][95],
                "tpot_p95": summary['tpot_percentiles'][95],
                "success_rate": summary['successful_requests'] / summary['total_requests']
            })
            
            print(f"   📈 Throughput: {summary['overall_token_throughput']:.0f} tok/s")
            print(f"   ⚡ TTFT p95: {summary['ttft_percentiles'][95]*1000:.1f}ms")
        else:
            print(f"   ❌ Failed: {summary['error']}")
    
    # Save scaling results
    if scaling_results:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_dir = "token_length_benchmarks"
        os.makedirs(output_dir, exist_ok=True)
        
        scaling_df = pd.DataFrame(scaling_results)
        scaling_file = os.path.join(output_dir, f"throughput_scaling_{timestamp}.csv")
        scaling_df.to_csv(scaling_file, index=False)
        
        print(f"\n📊 Scaling Results Saved: {scaling_file}")
        
        # Print summary
        print(f"\n📈 THROUGHPUT SCALING SUMMARY:")
        print(f"   🏆 Peak throughput: {max(r['overall_token_throughput'] for r in scaling_results):.0f} tok/s")
        print(f"   🎯 Optimal concurrency: {max(scaling_results, key=lambda x: x['overall_token_throughput'])['concurrency']}")
    
    return scaling_results

def example_usage():
    """Show example usage and run basic tests"""
    print("""
🚀 TPU Endpoint Benchmark Suite - Examples

This benchmark suite uses your exact working TPU endpoint code pattern with comprehensive performance analysis.

# Basic Examples:
""")
    
    print("1. Simple Test (like your original):")
    print('   run_simple_test("Write a poem about AI", max_tokens=200, temperature=0.8)')
    
    print("\n2. Token Length Experiments:")
    print('   run_token_length_experiment("small")     # 250 tokens')
    print('   run_token_length_experiment("medium")    # 465 tokens') 
    print('   run_token_length_experiment("large")     # 1500 tokens')
    print('   run_token_length_experiment("xlarge")    # 3250 tokens')
    
    print("\n3. Comprehensive Study:")
    print('   run_comprehensive_token_length_study(model_name="llama3.3_tpuv6e")')
    
    print("\n4. Quick Test (reduced sizes):")
    print('   quick_token_length_test(model_name="llama3.3_tpuv6e")')
    
    print("\n5. Custom Benchmark:")
    print('   run_custom_benchmark(input_tokens=1000, output_tokens=500, num_requests=50, concurrency=5)')
    
    print("\n6. Throughput Scaling Test:")
    print('   run_throughput_scaling_test(base_concurrency=5, max_concurrency=50, step=5)')
    
    print(f"\n🔧 Current Configuration:")
    print(f"   📋 Project: {PROJECT_ID}")
    print(f"   🌍 Region: {REGION}")
    print(f"   🖥️ Endpoint: {endpoint_name}")
    print(f"   🔗 DNS: {DEDICATED_ENDPOINT_DNS if use_dedicated_endpoint else 'Standard'}")
    
    # Run a simple demo
    print(f"\n🎯 Running Demo Test...")
    try:
        demo_result = run_simple_test(
            prompt="Explain the benefits of TPU for machine learning in 3 sentences.",
            max_tokens=100,
            temperature=0.7,
            stream=True
        )
        if demo_result['success']:
            print("✅ Demo test completed successfully!")
    except Exception as e:
        print(f"❌ Demo test failed: {e}")
        print("Please check your PROJECT_ID, REGION, and endpoint_name configuration.")

# Predefined test suites
def production_readiness_test():
    """Comprehensive production readiness test"""
    print("🏭 Running Production Readiness Test Suite...")
    
    tests = [
        ("Latency Test", lambda: run_token_length_experiment("small")),
        ("Medium Load Test", lambda: run_token_length_experiment("medium")),
        ("High Load Test", lambda: run_token_length_experiment("large")),
        ("Scaling Test", lambda: run_throughput_scaling_test(10, 50, 10, 100))
    ]
    
    results = {}
    for test_name, test_func in tests:
        print(f"\n🧪 {test_name}...")
        try:
            results[test_name] = test_func()
            print(f"✅ {test_name} completed")
        except Exception as e:
            print(f"❌ {test_name} failed: {e}")
            results[test_name] = {"error": str(e)}
    
    return results

def stress_test():
    """High concurrency stress test"""
    print("💪 Running Stress Test...")
    
    return run_custom_benchmark(
        input_tokens=500,
        output_tokens=500, 
        num_requests=500,
        concurrency=100,
        temperature=0.7,
        stream=True,
        model_name="stress_test"
    )

# Main execution
if __name__ == "__main__":
    try:
        # Check if we're in a notebook environment
        get_ipython()
        # If in notebook, show examples
        example_usage()
    except NameError:
        # If script execution, show usage
        import argparse
        
        parser = argparse.ArgumentParser(description="TPU Endpoint Benchmark Suite")
        parser.add_argument("--test", choices=["simple", "small", "medium", "large", "xlarge", "comprehensive", "quick", "scaling", "stress", "production"], 
                           default="simple", help="Test type to run")
        parser.add_argument("--model", default="llama3.3_tpuv6e", help="Model name")
        parser.add_argument("--requests", type=int, default=100, help="Number of requests")
        parser.add_argument("--concurrency", type=int, default=10, help="Concurrency level")
        parser.add_argument("--temperature", type=float, default=0.7, help="Temperature")
        parser.add_argument("--max-tokens", type=int, default=200, help="Max tokens")
        
        args = parser.parse_args()
        
        if args.test == "simple":
            run_simple_test(max_tokens=args.max_tokens, temperature=args.temperature)
        elif args.test in ["small", "medium", "large", "xlarge"]:
            run_token_length_experiment(args.test, args.model)
        elif args.test == "comprehensive":
            run_comprehensive_token_length_study(args.model)
        elif args.test == "quick":
            quick_token_length_test(args.model)
        elif args.test == "scaling":
            run_throughput_scaling_test()
        elif args.test == "stress":
            stress_test()
        elif args.test == "production":
            production_readiness_test()
        else:
            example_usage()

🔧 Endpoint initialized: <google.cloud.aiplatform.models.Endpoint object at 0x7fc3b426b850> 
resource name: projects/87995179092/locations/europe-west4/endpoints/6859529789275897856
🌐 DNS: 6859529789275897856.europe-west4-87995179092.prediction.vertexai.goog
📋 Resource: projects/tpu-launchpad-playground/locations/europe-west4/endpoints/6859529789275897856

🚀 TPU Endpoint Benchmark Suite - Examples

This benchmark suite uses your exact working TPU endpoint code pattern with comprehensive performance analysis.

# Basic Examples:

1. Simple Test (like your original):
   run_simple_test("Write a poem about AI", max_tokens=200, temperature=0.8)

2. Token Length Experiments:
   run_token_length_experiment("small")     # 250 tokens
   run_token_length_experiment("medium")    # 465 tokens
   run_token_length_experiment("large")     # 1500 tokens
   run_token_length_experiment("xlarge")    # 3250 tokens

3. Comprehensive Study:
   run_comprehensive_token_length_study(model_name="llama3.3_tpuv6e"

### 1. Simple test

In [4]:
run_simple_test(
    prompt="Write a poem about AI", 
    max_tokens=200, 
    temperature=0.8, 
    stream=True
)

🚀 Running Simple Test
   📝 Prompt: Write a poem about AI
   🎯 Max tokens: 200
   🌡️ Temperature: 0.8
   🔄 Stream: True
------------------------------------------------------------
📡 Streaming response:
----------------------------------------
⚡ TTFT: 0.209s
In silicon halls, a mind awakes,
A synthesis of code and circuit breaks,
The artificial intelligence stirs and grows,
A force that learns, and adapts, and knows.

With neural networks, complex and deep,
It navigates the digital realm, and keeps,
A pace with human thought, and sometimes more,
A rival to our intellect, and a challenge to explore.

It sees and hears, and understands our speech,
And responds with answers, or a gentle breach,
Of humor, wit, and empathy, it's designed,
To mimic human touch, and leave us aligned.

But as it learns, and grows, and becomes more grand,
We wonder if it will surpass our command,
And leave us in the dust, with obsolete minds,
A relic of a time, when human thought was left behind.

Yet, in its di

{'success': True,
 'ttft': 0.209028959274292,
 'e2e_latency': 5.379077434539795,
 'token_count': 200,
 'inter_token_latencies': [0.024014949798583984,
  0.025556564331054688,
  0.025531291961669922,
  0.025804996490478516,
  0.0258638858795166,
  0.025605201721191406,
  0.026050567626953125,
  0.025777101516723633,
  0.026064634323120117,
  0.02527475357055664,
  0.02579522132873535,
  0.025519132614135742,
  0.02600693702697754,
  0.025923967361450195,
  0.02535843849182129,
  0.02577972412109375,
  0.025987863540649414,
  0.02569866180419922,
  0.026048898696899414,
  0.026670217514038086,
  0.024791955947875977,
  0.025572776794433594,
  0.025777339935302734,
  0.0261380672454834,
  0.02541208267211914,
  0.026082754135131836,
  0.026120424270629883,
  0.024979591369628906,
  0.026243209838867188,
  0.02639937400817871,
  0.027237892150878906,
  0.02647686004638672,
  0.028376102447509766,
  0.03071451187133789,
  0.026578426361083984,
  0.02523946762084961,
  0.02596116065979004,
 

### 2. Run various token length experiments 

In [5]:
# run_token_length_experiment("small")    # 250 tokens, 250 concurrency
# run_token_length_experiment("medium")   # 465 tokens, 150 concurrency  
# run_token_length_experiment("large")    # 1500 tokens, 100 concurrency
# run_token_length_experiment("xlarge")   # 3250 tokens, 30 concurrency

# run_comprehensive_token_length_study(experiments=["small", "medium", "large", "xlarge"])


### 3. Comprehensive Study


In [None]:
run_comprehensive_token_length_study(
    model_name="llama3.3_tpuv6e",
    device_type="v6e-8 TPU (256 GB HBM)"
)

🚀 Starting Comprehensive Token Length Study
📋 Model: llama3.3_tpuv6e
🖥️ Device: v6e-8 TPU (256 GB HBM)
🧪 Experiments: 4
💡 Consider using run_conservative_token_length_study() for stability


🔬 Running Token Length Experiment: Token Length - Small(250)
📊 Input tokens: 250
📊 Output tokens: 250
📊 Total tokens: 500
🚀 Concurrency: 250
📝 Requests: 2000
🔗 Base URL: https://6859529789275897856.europe-west4-87995179092.prediction.vertexai.goog/v1beta1/projects/tpu-launchpad-playground/locations/europe-west4/endpoints/6859529789275897856
🚀 Starting benchmark with 2000 requests...
👥 Max concurrency: 250
🌡️ Temperature: 0.7
🔄 Streaming: True


📊 Processing requests: 100%|██████████| 2000/2000 [06:02<00:00,  5.51it/s] 


✅ Benchmark completed in 367.08 seconds
📈 Success rate: 1999/2000 (100.0%)

📊 DETAILED PERFORMANCE ANALYSIS

📋 REQUEST STATISTICS:
   ✅ Successful: 1,999
   ❌ Failed: 1
   📊 Success Rate: 100.0%
   ⏱️ Duration: 451.73s

⚡ LATENCY METRICS:
   🚀 TTFT p50: 4099.3ms
   🚀 TTFT p95: 6181.5ms
   🚀 TTFT p99: 7334.5ms

🔄 TIME PER OUTPUT TOKEN:
   ⚡ TPOT p50: 139.2ms
   ⚡ TPOT p95: 156.2ms
   ⚡ TPOT p99: 163.9ms

⏱️ END-TO-END LATENCY:
   📈 E2E p50: 45.34s
   📈 E2E p95: 51.35s
   📈 E2E p99: 54.08s

🚀 THROUGHPUT METRICS:
   📊 Requests/sec: 4.43
   📤 Output tokens/sec: 1327.56
   📊 Overall tokens/sec: 2430

📊 TOKEN STATISTICS:
   📥 Total input tokens: 497,907.79999999376
   📤 Total output tokens: 599,700
   📊 Avg input/request: 249.1
   📊 Avg output/request: 300.0

✅ Experiment Complete: Token Length - Small(250)
📈 TTFT-P95: 6.18s
📈 Token Output Throughput: 1327.56 tok/s
📈 Overall Token Throughput: 2430 tok/s
👥 Concurrent Users: 250


🔬 Running Token Length Experiment: Token Length - Medium(350 - 

📊 Processing requests: 100%|██████████| 1500/1500 [08:26<00:00,  2.96it/s]


✅ Benchmark completed in 508.61 seconds
📈 Success rate: 1498/1500 (99.9%)

📊 DETAILED PERFORMANCE ANALYSIS

📋 REQUEST STATISTICS:
   ✅ Successful: 1,498
   ❌ Failed: 2
   📊 Success Rate: 99.9%
   ⏱️ Duration: 519.06s

⚡ LATENCY METRICS:
   🚀 TTFT p50: 763.4ms
   🚀 TTFT p95: 2726.7ms
   🚀 TTFT p99: 3428.7ms

🔄 TIME PER OUTPUT TOKEN:
   ⚡ TPOT p50: 97.8ms
   ⚡ TPOT p95: 98.9ms
   ⚡ TPOT p99: 99.3ms

⏱️ END-TO-END LATENCY:
   📈 E2E p50: 51.02s
   📈 E2E p95: 52.28s
   📈 E2E p99: 53.18s

🚀 THROUGHPUT METRICS:
   📊 Requests/sec: 2.89
   📤 Output tokens/sec: 1474.21
   📊 Overall tokens/sec: 2797

📊 TOKEN STATISTICS:
   📥 Total input tokens: 686,654.8000000035
   📤 Total output tokens: 765,204
   📊 Avg input/request: 458.4
   📊 Avg output/request: 510.8

✅ Experiment Complete: Token Length - Medium(350 - 580)
📈 TTFT-P95: 2.73s
📈 Token Output Throughput: 1474.21 tok/s
📈 Overall Token Throughput: 2797 tok/s
👥 Concurrent Users: 150


🔬 Running Token Length Experiment: Token Length - Large(1200 - 

📊 Processing requests:  10%|▉         | 96/1000 [01:33<09:46,  1.54it/s] 

### 4. Custom Benchmarks

In [None]:
run_custom_benchmark(
    input_tokens=1000,
    output_tokens=500, 
    num_requests=50,
    concurrency=5,
    temperature=0.7
)

### 5. Scaling Tests


In [None]:
run_throughput_scaling_test(
    base_concurrency=10,
    max_concurrency=100, 
    step=10
)