#  Vertex Inference Unified

This notebook uses the unified `google-genai` library (imported as `from google import genai`). It supports:
- **Vertex AI Backend:** Uploads videos to GCS during the 'Prepare' step.
- **Gemini API Backend:** Uploads videos using the **File API** during the 'Prepare' step.

**Pipeline Steps:**
1.  **Import Libraries & Configure.**
2.  **Config** - Set up the configuration for the pipeline, including the model, model config, prompts, and prompt config.
2.  **Initialize Clients:** Set up AI client and Storage client.
3.  **(Only the first time) Fetch Dataset:** Downloads metadata from HuggingFace.
4.  **(Only the first time on each API type) Download, Extract & Prepare Videos:** Downloads, extracts, uploads (GCS/File API). Updates metadata.
5.  **Bulk Inference (Async):** Performs inference using pre-uploaded video resources.
6.  **Single Prompt Testing (UI):** Lets you interactively run the configured prompt against a single video/question pair.

### Notes:

1. After switching the backend (Vertex AI to Gemini API, or vice versa), re-run all cells in order so the videos are re-uploaded to the storage used by the new backend:
    - You can set the `SKIP_DOWNLOAD_ZIP` and `SKIP_EXTRACT` flags to `True` to skip the download and extraction steps; only the upload step needs to run again.

2. Files uploaded through the Gemini File API expire after roughly a day. Once they have expired, follow the steps above to re-upload them.

## Import Libraries

In [None]:
# Cell 1: Imports (Corrected for `google.genai`)
import os
import csv
import json
import logging
import time
import random
import requests
import datetime
import zipfile
import math
import sys
import asyncio
from typing import Dict, List, Optional, Set, Tuple, Any
from collections import defaultdict
from pathlib import Path
import shutil
import subprocess
import tempfile
import fractions

# Google Cloud & AI Libraries (Unified SDK)
try:
    import google.genai as genai
    from google.genai import types
    from google.genai import errors as genai_errors
    from google.api_core import exceptions as api_core_exceptions
    # GCS Client (Optional, for Vertex Mode)
    try:
        from google.cloud import storage
        GCS_AVAILABLE = True
    except ImportError:
        print("INFO: google-cloud-storage not found. Vertex AI GCS operations unavailable.")
        storage = None
        GCS_AVAILABLE = False
    print("`google.genai` SDK and helpers imported successfully.")
except ImportError as e:
     print(f"ERROR: Failed to import Google libraries: {e}. Install: pip install google-genai google-api-core google-cloud-storage")
     genai = None; types = None; genai_errors = None; api_core_exceptions = None
     storage = None; GCS_AVAILABLE = False
     raise ImportError("FATAL: `google.genai` or `google-api-core` SDK not found.")

# Data Handling & Progress
from datasets import load_dataset
import pandas as pd
from tqdm.notebook import tqdm

# UI Elements
import ipywidgets as widgets
from IPython.display import display, Markdown, HTML, clear_output

# Async in Notebook
import nest_asyncio
nest_asyncio.apply()

## Config Settings

In [None]:
# --- GCP Configuration ---
# The upper-case placeholder values below must be replaced before running in Vertex AI mode.

PROJECT_ID = "YOUR_PROJECT_ID_HERE" # Your Google Cloud Project ID (Needed for GCS and Vertex AI mode)
LOCATION = "us-central1"      # Your Google Cloud Region (Needed for Vertex AI mode)
GCS_BUCKET = "YOUR_GCS_BUCKET_HERE" # Your GCS bucket name (Needed for video storage)

# --- Choose Backend Mode ---
# Set USE_VERTEX to True to use the Vertex AI backend (requires ADC or service account auth).
# Set USE_VERTEX to False to use the Gemini API backend (requires GEMINI_API_KEY).
# USE_VERTEX also selects which metadata file is used (see METADATA_FILE below).
USE_VERTEX = False  # <-- CHANGE THIS TO True TO USE VERTEX AI

# --- Gemini API Key (Only required if USE_VERTEX is False) ---
# IMPORTANT: Replace with your actual Gemini API Key if USE_VERTEX is False.
# Consider loading from environment variables (GOOGLE_API_KEY) or a secure secrets manager
# rather than committing a key into the notebook.
GEMINI_API_KEY = ""  # Replace with your actual Gemini API Key



# --- File Paths ---
DATASET_CSV = "dataset.csv"               # Input dataset metadata from HuggingFace
# One metadata file per backend, so switching backends does not mix GCS URIs with File API names.
METADATA_FILE = "video_metadata_vertex.csv" if USE_VERTEX else "video_metadata_non_vertex.csv"      # Stores video info: video_id, local_path, gcs_uri (Vertex) / file_api_name (Gemini API), status, question data
RESULTS_FILE = "results_noncot_full_inference.csv"              # Output file for inference predictions
DOWNLOADS_DIR = "downloads"               # Directory for downloaded zip file
EXTRACTED_VIDEOS_DIR = "extracted_videos" # Directory storing extracted .mp4 files locally
SPEED_VIDEOS_DIR = "speed_videos"         # Stores sped up/slowed down videos (one subdirectory per speed factor)
HF_CACHE_DIR = "./hf_cache"               # Cache directory for HuggingFace datasets

# --- Step 1: Fetch Dataset Configuration ---
HF_DATASET_NAME = "lmms-lab/AISG_Challenge" # HuggingFace dataset identifier
HF_DATASET_SPLIT = "test"                 # Dataset split to use
SKIP_FETCH = False                        # Set True to skip fetching if DATASET_CSV exists

# --- Step 2: Download & Prepare Videos Configuration ---
VIDEO_ZIP_URL = "https://huggingface.co/datasets/lmms-lab/AISG_Challenge/resolve/main/Benchmark-AllVideos-HQ-Encoded-challenge.zip?download=true"
ZIP_FILE_NAME = "all_videos.zip"
SKIP_DOWNLOAD_ZIP = True                 # Set True to skip downloading if zip exists
SKIP_EXTRACT = True                      # Set True to skip extraction if videos exist locally
SKIP_PREPARE = False                      # Set True to skip video preparation (GCS upload for Vertex, metadata update)
MAX_VIDEOS_TO_PROCESS = None              # Limit videos for testing (e.g., 5), None for all
UPLOAD_BATCH_SIZE_GCS = 10                # Batch size for GCS uploads (Vertex mode only)

# --- Inference Configuration ---
# Choose a model name compatible with your selected method (Vertex AI or Gemini API).
# The rate-limit figures below are the notebook author's notes; confirm them against the linked pages.

# Examples:

# Vertex AI: gemini-2.0-flash, gemini-2.0-flash-lite, gemini-2.0-pro-exp-02-05, gemini-2.0-flash-thinking-exp-01-21
# Rate limits: https://cloud.google.com/vertex-ai/generative-ai/docs/quotas#gemini-2.0-flash
# Roughly 500 requests per minute for 2.0-flash and 2.0-flash-lite (unlimited), 10 requests per minute for 2.0-pro-exp-02-05, gemini-2.5-flash-preview-04-17

# Gemini API: gemini-2.0-flash, gemini-2.0-flash-lite, gemini-2.0-flash-thinking-exp-01-21, gemini-2.5-pro-exp-03-25
# Rate limits: https://ai.google.dev/gemini-api/docs/rate-limits#tier-1
# For free tier: 30 requests per minute for 2.0-flash and 2.0-flash-lite, 10 requests per minute for 2.0-pro-exp-02-05
# For tier-1: 2000 requests per minute for 2.0-flash and 2.0-flash-lite (have to pay), 10 requests per minute for 2.0-pro-exp-02-05 and gemini-2.0-flash-thinking-exp-01-21, gemini-2.5-flash-preview-04-17

# Playback speed applied to videos: 1.0=normal speed, 0.5=half speed, etc.
VIDEO_SPEED_FACTOR = 0.5

# --- Derived paths (built from the configuration constants) ---
zip_file_path = Path(DOWNLOADS_DIR, ZIP_FILE_NAME)
extracted_videos_path = Path(EXTRACTED_VIDEOS_DIR)
# Re-timed videos live in a per-speed-factor subdirectory, e.g. speed_videos/0.5
speed_videos_path = Path(SPEED_VIDEOS_DIR, str(VIDEO_SPEED_FACTOR))

# Make sure every working directory exists before any later cell writes to it.
for _required_dir in (
    Path(DOWNLOADS_DIR),
    extracted_videos_path,
    speed_videos_path,
    Path(HF_CACHE_DIR),
):
    _required_dir.mkdir(parents=True, exist_ok=True)

# --- Logging: INFO and above, written to stdout so it shows up in the cell output ---
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)

# --- Configuration Validation & Display --- #
# Values that mean "not configured yet". Both the upper-case placeholders used
# in the configuration cell and the older lower-case spellings are recognised,
# so an unedited notebook is reported instead of failing later at upload time.
_PROJECT_ID_PLACEHOLDERS = ("YOUR_PROJECT_ID_HERE", "your-gcp-project-id")
_GCS_BUCKET_PLACEHOLDERS = ("YOUR_GCS_BUCKET_HERE", "your-gcs-bucket-name")
_API_KEY_PLACEHOLDERS = ("YOUR_API_KEY_HERE",)

warnings_found = False
if USE_VERTEX:
    if not PROJECT_ID or PROJECT_ID in _PROJECT_ID_PLACEHOLDERS:
        logger.error("Vertex AI mode requires PROJECT_ID to be set.")
        warnings_found = True
    if not LOCATION:
        logger.error("Vertex AI mode requires LOCATION to be set.")
        warnings_found = True
    if not GCS_BUCKET or GCS_BUCKET in _GCS_BUCKET_PLACEHOLDERS:
        logger.error("Vertex AI mode requires GCS_BUCKET for video uploads.")
        warnings_found = True
    if not GCS_AVAILABLE:
        logger.error("Vertex AI mode requires 'google-cloud-storage', but it's not installed.")
        warnings_found = True
else: # Gemini API Mode
    # Resolve the key: an explicitly configured key wins; an empty string or an
    # unedited placeholder falls back to the GOOGLE_API_KEY environment variable.
    _explicit_api_key = GEMINI_API_KEY if GEMINI_API_KEY and GEMINI_API_KEY not in _API_KEY_PLACEHOLDERS else None
    effective_api_key = _explicit_api_key or os.environ.get("GOOGLE_API_KEY")
    if not effective_api_key:
        logger.error("Gemini API mode requires GEMINI_API_KEY or GOOGLE_API_KEY environment variable.")
        warnings_found = True
    else:
        # GEMINI_API_KEY is deliberately left untouched: the client
        # initialization cell reads it again, so replacing it with a display
        # string would make that string be used as the API key.
        if _explicit_api_key is None:
            logger.info("Gemini API key loaded from the GOOGLE_API_KEY environment variable.")
        logger.info("Gemini API mode configured. Videos will be uploaded via File API.")

if warnings_found:
     print("\n\n************************* WARNING *************************")
     print("Configuration errors detected above. Execution might fail.")
     print("***********************************************************\n")

# Main Model Selection (Innovation 1)
Agentic CoT - Chain of Thought + Summary

## 1. COT and NonCOT Model for Bulk Inference

### NonCoT output models for Bulk Inferences

In [None]:
# To see all available models, go into the NonCoT_output_models.py file
# NonCoT models are used for Bulk Inference
from models.NonCoT_output_models import get_non_cot_model
# Candidate model names; the index passed below picks which one is active.
non_cot_model_list = ["gemini-2.5-flash-preview-04-17", "gemini-2.5-pro-exp-03-25"]

# Binds the seven module-level inference settings from the selected model's entry.
# NOTE(review): the CoT cell further down assigns the same seven names, so
# whichever of the two cells was executed last determines the active settings.
MODEL_NAME, SYSTEM_PROMPT, PROMPT_TEMPLATES, CONFIG, REQUESTS_PER_MINUTE, MAX_RETRIES, MAX_ASYNC_WORKERS  = get_non_cot_model(non_cot_model_list[0])

### COT output models for Bulk Inferences

In [None]:
# To see all available models, go into the CoT_ouput_models.py file
# (the module name really is spelled "ouput", matching the import below).
# CoT models are used for Bulk Inference
from models.CoT_ouput_models import get_cot_model
# Candidate model names; the index passed below picks which one is active.
CoT_model_list = ["gemini-2.0-flash","gemini-2.5-pro-preview-05-06", "gemini-2.5-pro-preview-03-25"]
# NOTE(review): this rebinds the same seven names the NonCoT cell sets; run only
# the cell for the mode you want, or run this one last to use the CoT settings.
MODEL_NAME, SYSTEM_PROMPT, PROMPT_TEMPLATES, CONFIG, REQUESTS_PER_MINUTE, MAX_RETRIES, MAX_ASYNC_WORKERS = get_cot_model(CoT_model_list[1])

### Summary Output Model

In [None]:
# To see all available models, go into the Summary_models.py file
from models.Summary_models import get_summary_model
# Candidate summary-model entries; the index passed below picks which one is active.
summary_model_list = ["gemini-2.0-flash-ver1", "gemini-2.0-flash-ver2", "gemini-2.0-flash-ver3"]

# NOTE(review): the summary model's settings are bound to QUESTION_-prefixed
# names; confirm against the cells that consume them that this naming is intended.
QUESTION_MODEL_NAME, QUESTION_SYSTEM_PROMPT, QUESTION_CONFIG = get_summary_model(summary_model_list[0])

### Generated Questions Model

In [None]:
# Model name used as the subdirectory under generated_questions/.
QUESTIONS_MODEL_NAME = "gemini-2.0-flash"
# NOTE: despite the "_DIR" suffix this is the path of the questions.csv FILE
# (generated_questions/<model name>/questions.csv), not of a directory.
QUESTIONS_DIR = os.path.join(f"generated_questions/{QUESTIONS_MODEL_NAME}", "questions.csv")

# Basic Initialization

## Initialize Google Cloud Clients

In [None]:
storage_client = None
ai_client = None

# --- Initialize Generative AI Client (`google.genai`) --- #
# On any failure `ai_client` is left as None; the final checks of this cell
# turn that into a RuntimeError.
display(Markdown("### Initializing Generative AI Client (`google.genai`)"))
try:
    if USE_VERTEX:
        display(Markdown(f"Vertex AI backend (Project: {PROJECT_ID}, Loc: {LOCATION})..."))
        # Reject empty values and unedited placeholders (both spellings that
        # appear in this notebook).
        if not PROJECT_ID or not LOCATION or PROJECT_ID in ("your-gcp-project-id", "YOUR_PROJECT_ID_HERE"):
            raise ValueError("PROJECT_ID/LOCATION invalid for Vertex AI.")
        # Initialize Client for Vertex
        ai_client = genai.Client(vertexai=True, project=PROJECT_ID, location=LOCATION)
        display(Markdown("✅ Vertex AI Client Initialized."))
    else: # Gemini API Mode
        display(Markdown("Gemini API backend (using API Key)..."))
        # Strings that must never be sent as a key: empty, the unedited
        # placeholder, and the display marker the configuration cell may have
        # written into GEMINI_API_KEY after finding the key in the environment.
        _not_a_key = ("", "YOUR_API_KEY_HERE", "(Loaded from GOOGLE_API_KEY env var)")
        if GEMINI_API_KEY and GEMINI_API_KEY not in _not_a_key:
            effective_api_key = GEMINI_API_KEY
        else:
            effective_api_key = os.environ.get("GOOGLE_API_KEY")
        if not effective_api_key:
            raise ValueError("Gemini API Key required but not found.")
        # Initialize Client for Gemini API
        ai_client = genai.Client(api_key=effective_api_key, vertexai=False)
        display(Markdown("✅ Gemini API Client Initialized."))

except ValueError as ve:
    display(Markdown(f"❌ **Config Error:** {ve}"))
    ai_client = None
except Exception as e:
    display(Markdown(f"❌ **AI Client Error:** {e}."))
    logger.error("AI Client Init Failed", exc_info=True)
    ai_client = None

# --- Initialize Storage Client --- #
# Vertex AI mode: a google-cloud-storage client for the configured bucket.
# Gemini API mode: the File API handle exposed by the AI client (`ai_client.files`).
if USE_VERTEX:
    display(Markdown("### Initializing GCS Client (Vertex Mode Only)"))
    if not GCS_AVAILABLE:
        display(Markdown("❌ GCS lib missing."))
        raise RuntimeError("Missing GCS lib.")
    # Reject empty values and unedited placeholders (both spellings that
    # appear in this notebook).
    if not GCS_BUCKET or GCS_BUCKET in ("your-gcs-bucket-name", "YOUR_GCS_BUCKET_HERE"):
        display(Markdown("❌ GCS_BUCKET needed."))
        raise ValueError("GCS_BUCKET required.")
    try:
        storage_client = storage.Client(project=PROJECT_ID)
        if not storage_client.bucket(GCS_BUCKET).exists():
            display(Markdown(f"⚠️ GCS Bucket `{GCS_BUCKET}` inaccessible."))
        else:
            display(Markdown(f"✅ GCS Client Initialized (Bucket: '{GCS_BUCKET}')."))
    except Exception as e:
        display(Markdown(f"❌ **GCS Client Error:** {e}."))
        logger.error("GCS Client Init Failed", exc_info=True)
        # A broken GCS client is only fatal when the prepare/upload step will run.
        if not SKIP_PREPARE:
            raise RuntimeError("GCS client failed.") from e
        display(Markdown("⚠️ GCS client failed, but skipping prep."))
else:
    display(Markdown("### Initializing Gemini API Client (File API)"))
    try:
        storage_client = ai_client.files
    except Exception as e:
        display(Markdown(f"❌ **Gemini File API Client Error:** {e}."))
        logger.error("Gemini API Client Init Failed", exc_info=True)
    else:
        # Report success only when the handle was actually obtained.
        display(Markdown("✅ Gemini File API Client Initialized."))

# --- Final Checks --- #
if ai_client is None:
    raise RuntimeError("AI client failed.")
if USE_VERTEX and storage_client is None and not SKIP_PREPARE:
    raise RuntimeError("GCS client failed for Vertex prep.")
display(Markdown("✅ Client initialization complete."))


## Utility Functions

In [None]:
# --- File/Data Handling ---
def load_processed_qids(filename: str) -> Set[str]:
    """Return the distinct, non-null ``qid`` values found in a results CSV.

    Args:
        filename: Path of the results CSV; only its ``qid`` column is read.

    Returns:
        The set of QIDs as strings. Empty when the file does not exist or
        cannot be read (a warning is logged in the latter case).
    """
    seen_qids: Set[str] = set()
    if not Path(filename).is_file():
        return seen_qids

    try:
        results = pd.read_csv(
            filename,
            usecols=['qid'],
            dtype={'qid': str},
            on_bad_lines='warn',
        )
        seen_qids = set(results['qid'].dropna().unique())
        logger.info(f"Loaded {len(seen_qids)} processed QIDs from {filename}")
    except Exception as e:
        # Best effort: an unreadable results file means "nothing processed yet".
        logger.warning(f"Could not read QIDs from {filename}: {e}. Assuming zero processed.")
    return seen_qids

def download_file_with_progress(url: str, destination: Path):
    """Stream ``url`` to ``destination`` while showing a tqdm progress bar.

    Args:
        url: HTTP(S) URL to download.
        destination: Local file path to write; removed again if the download fails.

    Raises:
        RuntimeError: If the number of bytes written differs from the
            ``Content-Length`` announced by the server.
        Exception: Any request or I/O error is logged and re-raised.
    """
    logger.info(f"Downloading {url} to {destination}...")
    try:
        # The context manager closes the streamed response (and releases its
        # connection) even when the transfer is aborted part-way through.
        with requests.get(url, stream=True, timeout=600) as response:
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))
            # iter_content yields decoded bytes, so Content-Length is only
            # comparable to what we write when no Content-Encoding is applied.
            size_is_comparable = total_size != 0 and not response.headers.get('content-encoding')
            block_size = 1024 * 1024
            bytes_written = 0
            with open(destination, 'wb') as f, tqdm(
                desc=f"Downloading {destination.name}", total=total_size, unit='iB', unit_scale=True, unit_divisor=1024
            ) as bar:
                for data in response.iter_content(block_size):
                    size = f.write(data)
                    bytes_written += size
                    bar.update(size)
        if size_is_comparable and bytes_written != total_size:
            # The partial file is removed by the handler below.
            raise RuntimeError(f"Download size mismatch for {destination.name}.")
        logger.info(f"Successfully downloaded {destination}")
    except Exception as e:
        # Never leave a truncated file behind: a later run could mistake it for a complete one.
        destination.unlink(missing_ok=True)
        logger.error(f"Download failed for {url}: {e}")
        raise

def extract_zip(zip_path: Path, extract_to: Path):
    """Extract a zip archive into ``extract_to`` with a progress bar.

    Entries under ``__MACOSX/`` and ``.DS_Store`` files are skipped.
    Any error is logged and re-raised.
    """
    logger.info(f"Extracting {zip_path.name} to {extract_to}...")
    try:
        with zipfile.ZipFile(zip_path, 'r') as archive:
            # Drop macOS resource-fork noise before counting entries for the bar.
            wanted = []
            for entry in archive.namelist():
                if entry.startswith('__MACOSX/') or entry.endswith('.DS_Store'):
                    continue
                wanted.append(entry)

            with tqdm(total=len(wanted), desc=f"Extracting {zip_path.name}") as progress:
                for entry in wanted:
                    archive.extract(member=entry, path=extract_to)
                    progress.update(1)
        logger.info(f"Successfully extracted {zip_path} to {extract_to}")
    except Exception as e:
        logger.error(f"Extraction error: {e}")
        raise
    
def move_videos_to_main_directory(base_path):
    """Flatten ``base_path``: move every nested ``*.mp4`` file directly into it.

    Files already located in ``base_path`` and files whose name starts with
    ``._`` are left where they are. Failures are logged and counted; they do
    not stop the remaining moves.
    """
    logger.info(f"Moving all videos to main directory: {base_path}")
    moved_count = 0
    failed_count = 0

    # Snapshot the candidates first so moving files cannot disturb the directory walk.
    nested_videos = [
        candidate
        for candidate in base_path.glob('**/*.mp4')
        if candidate.parent != base_path and not candidate.name.startswith('._')
    ]

    for source in nested_videos:
        target = base_path / source.name
        try:
            shutil.move(str(source), str(target))
            moved_count += 1
            # Periodic progress line for large archives.
            if not moved_count % 50:
                logger.info(f"Moved {moved_count} videos so far...")
        except Exception as e:
            logger.error(f"Error moving {source}: {e}")
            failed_count += 1

    logger.info(f"Moved {moved_count} videos to main directory. Failed: {failed_count}")
    

def create_or_update_metadata(metadata_path: str, dataset_df: pd.DataFrame, video_updates: Dict[str, Dict]):
    """Create the metadata CSV on first use, then apply per-video updates to it.

    Args:
        metadata_path: CSV file to create or update in place.
        dataset_df: Question-level dataset used as the initial content when the
            metadata file does not exist yet. Must contain ``video_id`` and ``qid``.
        video_updates: Mapping ``video_id -> {column: value}``. Only the columns
            ``local_path``, ``gcs_uri``, ``file_api_name`` and ``status`` are
            applied. An update is written to every row of that video; missing
            or null update values leave the existing cell untouched.

    Raises:
        ValueError: If the metadata lacks the ``video_id`` or ``qid`` column.
        Exception: Any other error is logged and re-raised.
    """
    try:
        required_cols = ['video_id', 'qid']
        update_cols = ['local_path', 'gcs_uri', 'file_api_name', 'status']
        dtype_map = {'video_id': str, 'qid': str} # Ensure IDs are strings

        if not Path(metadata_path).is_file():
            logger.info(f"Creating metadata file: {metadata_path}")
            meta_df = dataset_df.copy()
            for col in update_cols:
                meta_df[col] = pd.NA
            meta_df['status'] = 'pending'
        else:
            logger.debug(f"Loading existing metadata: {metadata_path}")
            meta_df = pd.read_csv(metadata_path, dtype=dtype_map)
            for col in update_cols: # Add missing update columns if needed
                if col not in meta_df.columns:
                    meta_df[col] = pd.NA

        if not all(col in meta_df.columns for col in required_cols):
            raise ValueError(f"Metadata missing required columns ({required_cols}).")

        if video_updates:
            updates_df = pd.DataFrame.from_dict(video_updates, orient='index')
            # Video IDs are compared as strings; keep the last entry if two keys
            # collapse to the same string.
            updates_df.index = updates_df.index.astype(str)
            updates_df = updates_df[~updates_df.index.duplicated(keep='last')]

            # Look the update up per row. The result carries meta_df's own index,
            # so the assignment below stays row-aligned even when dataset_df was
            # filtered or re-ordered (a merge result would be re-indexed 0..n-1).
            # Comparing on a string view of video_id also avoids dtype mismatches.
            row_video_ids = meta_df['video_id'].astype(str)
            for col in update_cols:
                if col not in updates_df.columns:
                    continue
                incoming = row_video_ids.map(updates_df[col])
                # Updates win; rows without an update keep their current value.
                meta_df[col] = incoming.fillna(meta_df[col])

        meta_df.to_csv(metadata_path, index=False, encoding='utf-8')
        logger.info(f"Metadata file '{metadata_path}' updated with {len(video_updates)} video records.")

    except Exception as e:
        logger.error(f"Error updating metadata {metadata_path}: {e}", exc_info=True)
        raise

def load_metadata_for_inference(metadata_file: str = METADATA_FILE) -> Dict[str, List[Dict]]:
    """Load the question rows that are ready for inference, grouped by video.

    A row qualifies when both its ``video_id`` and the backend-specific resource
    column (``gcs_uri`` in Vertex mode, ``file_api_name`` otherwise) are non-empty.

    Returns:
        ``{video_id: [row_dict, ...]}``; an empty dict when the file is missing,
        lacks the needed columns, has no qualifying rows, or cannot be read.
    """
    if not Path(metadata_file).is_file():
        return {}

    resource_col = 'gcs_uri' if USE_VERTEX else 'file_api_name'
    try:
        # Everything is read as text; blanks become '' so truthiness means "present".
        table = pd.read_csv(metadata_file, dtype=str).fillna('')
        columns_present = 'video_id' in table.columns and resource_col in table.columns
        if not columns_present:
            logger.error(f"Metadata missing 'video_id' or '{resource_col}'.")
            return {}

        has_video_id = table['video_id'].astype(bool)
        has_resource = table[resource_col].astype(bool)
        usable = table[has_video_id & has_resource]
        if not len(usable):
            logger.warning(f"No videos found with '{resource_col}' in {metadata_file}. Check Step 4.")
            return {}

        questions_by_video = {
            video_id: rows.to_dict('records')
            for video_id, rows in usable.groupby('video_id')
        }
        logger.info(f"Loaded {len(questions_by_video)} videos ({len(usable)} questions) with valid IDs for inference.")
        return questions_by_video
    except Exception as e:
        logger.error(f"Error loading metadata for inference: {e}", exc_info=True)
        return {}

# --- Upload/Verification Helpers ---
def upload_to_gcs(storage_client, bucket_name: str, source_file_path: Path, destination_blob_name: str) -> Optional[str]:
    """Upload a local file to a GCS bucket.

    Returns:
        The ``gs://bucket/blob`` URI on success; ``None`` when GCS support is
        unavailable, no client was given, the source file is missing, or the
        upload fails (the failure is logged).
    """
    ready = GCS_AVAILABLE and storage_client is not None and source_file_path.is_file()
    if not ready:
        return None

    try:
        bucket = storage_client.bucket(bucket_name)
        bucket.blob(destination_blob_name).upload_from_filename(str(source_file_path))
        uploaded_uri = f"gs://{bucket_name}/{destination_blob_name}"
        logger.debug(f"GCS OK: {source_file_path} -> {uploaded_uri}")
        return uploaded_uri
    except Exception as e:
        logger.error(f"GCS Fail: {source_file_path}. Error: {e}")
        return None

def upload_via_file_api(storage_client, local_path: Path, display_name: str) -> Optional[str]:
    """Upload a local file through the Gemini File API.

    Args:
        storage_client: File API handle (the ``files`` attribute of the AI client).
        local_path: File to upload.
        display_name: Human-readable name to attach to the uploaded file. It is
            forwarded to the File API when non-empty.

    Returns:
        The resource name of the uploaded file, or ``None`` when no client was
        given, the file does not exist, or the upload fails (the failure is logged).
    """
    if storage_client is None or not local_path.is_file():
        return None
    try:
        logger.debug(f"Uploading {local_path} via File API...")
        # Forward the caller's display name instead of silently dropping it.
        upload_config = {"display_name": display_name} if display_name else None
        uploaded_file = storage_client.upload(file=local_path, config=upload_config)
        logger.info(f"File API OK: {local_path} -> {uploaded_file.name}")
        return uploaded_file.name
    except Exception as e:
        logger.error(f"File API Fail: {local_path}. Error: {e}", exc_info=True)
        return None

def verify_gcs_file_exists(storage_client, gcs_uri: str) -> bool:
    """Return True when the blob named by ``gcs_uri`` exists.

    Returns False when GCS support is unavailable, no client or URI was given,
    the blob is missing (a warning is logged), or the lookup raises (logged).
    """
    if not (GCS_AVAILABLE and storage_client is not None and gcs_uri):
        return False
    try:
        blob = storage.Blob.from_string(gcs_uri, client=storage_client)
        found = blob.exists()
        if not found:
            logger.warning(f"GCS verify failed: {gcs_uri}")
        return found
    except Exception as e:
        logger.error(f"Error verifying GCS {gcs_uri}: {e}")
        return False

def verify_file_api_resource_exists(storage_client, file_api_name: str) -> bool:
    """Return True when the File API can fetch ``file_api_name``.

    Returns False when the client or the name is missing/empty, or when the
    lookup raises (the error is logged).
    """
    if not (storage_client and file_api_name):
        return False
    try:
        # A successful synchronous get is the whole check; the result is not needed.
        storage_client.get(name=file_api_name)
    except Exception as e:
        logger.error(f"Error verifying File API {file_api_name}: {e}")
        return False
    return True

def verify_local_file_exists(local_path: str) -> bool:
    """Return True when ``local_path`` is non-empty and names an existing file.

    A warning is logged when the check fails.
    """
    if local_path and Path(local_path).is_file():
        return True
    logger.warning(f"Local verify failed: {local_path}")
    return False

# --- Prompt Building ---
def build_prompt(question_info: dict) -> str:
    """Build the prompt text for one question.

    The template is chosen from ``PROMPT_TEMPLATES`` by ``question_type``
    (falling back to the ``"default"`` template) and filled with the question.
    Multiple-choice questions get an extra "E. None of the above" option;
    every other type gets the row's ``question_prompt`` appended.

    Args:
        question_info: Question row; reads ``question``, ``question_type`` and
            ``question_prompt``.

    Returns:
        The formatted prompt string.
    """
    question = question_info.get("question", "")
    q_type = question_info.get("question_type", "default")
    template = PROMPT_TEMPLATES.get(q_type, PROMPT_TEMPLATES["default"])
    prompt = template.format(question=question).strip()

    if q_type == "Multiple-choice Question with a Single Correct Answer":
        return prompt + "\n" + "E. None of the above"

    # A missing or non-string question_prompt is treated like an empty one
    # instead of raising AttributeError on `.strip()`.
    extra = question_info.get("question_prompt")
    extra = extra.strip() if isinstance(extra, str) else ""
    return prompt + "\n" + extra

# --- Rate Limiter ---
class AsyncRateLimiter:
    """
    An asyncio-compatible token bucket rate limiter.

    The bucket starts full, refills continuously at ``rate / period`` tokens per
    second up to ``capacity``, and each ``acquire()`` consumes one token.

    Args:
        rate (int): The maximum number of requests allowed per period.
        period (float): The time period in seconds (default: 60 for RPM).
        capacity (int, optional): The maximum burst capacity. Defaults to `rate`.

    Raises:
        ValueError: If ``rate`` or ``period`` is not positive, or if the
            resulting capacity is below 1 (a whole token could then never
            accumulate and ``acquire()`` would wait forever).
    """
    def __init__(self, rate: int, period: float = 60.0, capacity: Optional[int] = None):
        if rate <= 0:
            raise ValueError("Rate must be positive")
        if period <= 0:
            raise ValueError("Period must be positive")

        self.rate = rate
        self.period = float(period)
        self.capacity = float(capacity if capacity is not None else rate)
        if self.capacity < 1:
            raise ValueError("Capacity must be at least 1")
        self._tokens = self.capacity # Start full
        self._last_refill_time = time.monotonic()
        self._lock = asyncio.Lock()

    def _get_tokens_per_second(self) -> float:
        return self.rate / self.period

    async def _refill(self):
        """Replenishes tokens based on elapsed time. Must be called under lock."""
        now = time.monotonic()
        elapsed = now - self._last_refill_time
        if elapsed > 0:
            tokens_to_add = elapsed * self._get_tokens_per_second()
            self._tokens = min(self.capacity, self._tokens + tokens_to_add)
            self._last_refill_time = now

    async def acquire(self):
        """
        Acquires a token, waiting if necessary.

        The lock is held only while the bucket is inspected or updated and is
        never held while sleeping, so other tasks can proceed in the meantime.
        Entering and leaving the lock exclusively through ``async with`` keeps
        the lock state consistent even if the waiting task is cancelled.
        """
        while True:
            async with self._lock:
                await self._refill()
                if self._tokens >= 1:
                    # Consume a token
                    self._tokens -= 1.0
                    return
                # Time until one whole token has accumulated.
                wait_time = (1.0 - self._tokens) / self._get_tokens_per_second()

            logger.debug(f"Rate limit hit. Waiting for {wait_time:.3f}s for next token.")
            await asyncio.sleep(wait_time)


# Testing UI

## Single Prompt Single Question Testing UI

In [None]:
def perform_inference_single_sync(question_info: Dict, client: Any) -> Dict[str, Any]:
    """Run one synchronous inference request for a single question.

    Args:
        question_info: Question row; reads ``qid``, ``gcs_uri`` (Vertex mode) or
            ``file_api_name`` (Gemini API mode), plus the fields used by
            ``build_prompt``.
        client: Initialized ``google.genai`` client.

    Returns:
        A dict with ``qid``, ``pred``, ``duration`` and ``status``; successful
        calls also carry ``finish_reason``. On success ``pred`` is the raw
        response object; on failure it is a string starting with ``"ERROR:"``.
        The function does not raise for request failures.
    """
    qid = question_info.get("qid", "?")
    prompt_text = build_prompt(question_info)
    gcs_uri = question_info.get("gcs_uri")
    file_api_name = question_info.get("file_api_name")
    start_time = time.time()

    # --- Build the request contents (prompt + video reference) ---
    try:
        if USE_VERTEX:
            if not gcs_uri:
                raise ValueError("Missing GCS URI.")
            video_part = types.Part.from_uri(mime_type="video/mp4", file_uri=gcs_uri)
        else:
            if not file_api_name:
                raise ValueError("Missing File API name.")
            try:
                # Fetch File object sync
                video_part = client.files.get(name=file_api_name)
            except genai_errors.APIError as e:
                # google.genai reports HTTP failures as APIError carrying the
                # status code; a 404 means the uploaded file is gone.
                if getattr(e, "code", None) == 404:
                    raise FileNotFoundError(f"File API '{file_api_name}' not found.") from e
                raise RuntimeError(f"Failed get File API obj: {e}") from e
            except Exception as e:
                raise RuntimeError(f"Failed get File API obj: {e}") from e

        question_content = types.Content(
            role="user",
            parts=[types.Part.from_text(text=prompt_text)]
        )
        contents = [question_content, video_part]

    except (ValueError, FileNotFoundError, RuntimeError) as e:
        logger.error(f"QID {qid} (Sync): Input Error - {e}")
        return {"qid": qid, "pred": f"ERROR: Input Fail - {e}", "duration": 0, "status": "Failed (Input)"}

    # INITIAL_BACKOFF_SECONDS is not defined in the cells preceding this one;
    # fall back to a default so a retry cannot fail with NameError.
    backoff_seconds = globals().get("INITIAL_BACKOFF_SECONDS", 2.0)
    # Status codes worth retrying: rate limiting and transient server errors.
    retryable_codes = (429, 500, 503)

    # --- Inference with retries (sync) ---
    for attempt in range(MAX_RETRIES + 1):
        try:
            # Use sync client.models
            response = client.models.generate_content(
                model=MODEL_NAME,
                contents=contents,
                config=CONFIG,
            )
            answer, reason, status = response, "UNKNOWN", "Success"
            try: # Process Response
                if response.candidates:
                    finish_reason = response.candidates[0].finish_reason
                    if finish_reason is not None:
                        reason = finish_reason.name
            except ValueError as ve:
                status = "Blocked/Empty"
                answer = f"ERROR: {status}. ValueError: {ve}. "
            return {"qid": qid, "pred": answer, "duration": time.time()-start_time, "finish_reason": reason, "status": status}
        except api_core_exceptions.ResourceExhausted as e:
            retry_error = e
        except genai_errors.APIError as e:
            # Rate-limit / transient errors from google.genai arrive here, not
            # as api_core exceptions, so decide by status code.
            if getattr(e, "code", None) not in retryable_codes:
                return {"qid": qid, "pred": f"ERROR: GenAI APIError - {e}", "duration": time.time()-start_time, "status": "Failed (API Error)"}
            retry_error = e
        except Exception as e:
            return {"qid": qid, "pred": f"ERROR: Unexpected - {e}", "duration": time.time()-start_time, "status": "Failed (Unexpected)"}

        # Only retryable failures reach this point.
        if attempt < MAX_RETRIES:
            time.sleep(backoff_seconds * (2**attempt))  # exponential backoff
        else:
            return {"qid": qid, "pred": f"ERROR: Max Retries ({type(retry_error).__name__}) - {retry_error}", "duration": time.time()-start_time, "status": "Failed (Retries)"}

    # Fallback (reached only when the loop body never ran, e.g. MAX_RETRIES < 0)
    return {"qid": qid, "pred": "ERROR: Unknown after retries", "duration": time.time()-start_time, "status": "Failed (Unknown)"}

# --- UI Setup --- #
# Load the inference-ready questions once; the widgets below are populated from them.
ui_video_questions = {}
try:
    ui_video_questions = load_metadata_for_inference(METADATA_FILE)
except Exception as e:
    display(Markdown(f"❌ UI Load Error: {e}"))

# Video dropdown: a placeholder entry followed by the videos, sorted by label.
video_options = [("Select video...", None)]
if ui_video_questions:
    video_options += sorted(
        (f"{vid} ({len(qs)}q)", vid) for vid, qs in ui_video_questions.items()
    )
video_selector = widgets.Dropdown(
    options=video_options,
    description='Video ID:',
    disabled=not ui_video_questions,
    style={'description_width': 'initial'},
)
# Question dropdown and run button stay disabled until a selection has been made.
question_selector = widgets.Dropdown(
    options=[("Select question...", None)],
    description='Question (QID):',
    disabled=True,
    layout=widgets.Layout(width='95%'),
    style={'description_width': 'initial'},
)
run_button = widgets.Button(
    description='Run Inference',
    disabled=True,
    button_style='primary',
    icon='play',
)
output_area = widgets.Output()

# --- Widget Interaction Logic (Remains mostly the same, calls updated sync function) --- #
def on_video_selected(change):
    """Dropdown observer: reset the question UI, then list the chosen video's questions.

    Args:
        change: Widget change event; ``change['new']`` is the selected video ID
            (or None for the placeholder entry).
    """
    selected_video_id = change['new']

    # Always start from a clean, disabled state.
    question_selector.options = [("Select question...", None)]
    question_selector.value = None
    question_selector.disabled = True
    run_button.disabled = True
    output_area.clear_output()

    if not selected_video_id or selected_video_id not in ui_video_questions:
        return

    questions = ui_video_questions[selected_video_id]
    question_options = [
        (f"{q.get('qid', 'N/A')}: {q.get('question', '')[:80]}...", q)
        for q in questions
        if q.get('qid')
    ]
    # Sort by label only: sorting the (label, dict) tuples directly would fall
    # through to comparing the dicts on a label tie and raise TypeError.
    question_options.sort(key=lambda option: option[0])
    if question_options:
        question_selector.options = [("Select question...", None)] + question_options
        question_selector.disabled = False

def on_question_selected(change):
    """Dropdown observer: enable the run button and show the chosen question.

    ``change['new']`` is the selected question row (a dict) or None for the
    placeholder entry.
    """
    chosen = change['new']
    run_button.disabled = chosen is None
    output_area.clear_output()
    if not chosen:
        return

    with output_area:
        display(Markdown("### Selected Info"))
        # The resource column depends on the active backend.
        id_col = 'gcs_uri' if USE_VERTEX else 'file_api_name'
        summary = pd.Series({
            'qid': chosen.get('qid'),
            'question': chosen.get('question'),
            f'Resource ({id_col})': chosen.get(id_col),
        })
        display(summary.to_frame('Value'))

def on_run_button_clicked(b):
    """Verify the selected resource, run one synchronous inference and render it.

    The run button is disabled for the duration of the call and re-enabled on
    every exit path. All output is rendered into ``output_area``.

    Args:
        b: The clicked button instance (unused).
    """
    import html  # stdlib; used to escape free text embedded in HTML below

    pre_style = "white-space: pre-wrap; border: 1px solid #000; padding: 10px;"
    run_button.disabled = True
    output_area.clear_output()
    try:
        with output_area:
            if not video_selector.value or not question_selector.value:
                display(Markdown("❌ Select video & question."))
                return
            if ai_client is None:
                display(Markdown("❌ AI Client not ready."))
                return

            q_info = question_selector.value
            qid = q_info.get('qid')
            # The resource column depends on which backend is active.
            resource_col = 'gcs_uri' if USE_VERTEX else 'file_api_name'
            resource_id = q_info.get(resource_col)

            display(Markdown(f"### Running QID: {qid}"))
            display(Markdown(f"--- Verifying Resource --- ({'GCS' if USE_VERTEX else 'File API'}) ---"))
            verified = False
            if not resource_id:
                display(Markdown(f"❌ Error: Missing '{resource_col}' ID."))
            elif USE_VERTEX:
                verified = verify_gcs_file_exists(storage_client, resource_id)
            else:
                verified = verify_file_api_resource_exists(storage_client, resource_id)

            if not verified:
                display(Markdown("❌ Verification failed."))
                return
            display(Markdown(f"✅ Resource Verified: {resource_id}"))

            # Local preview is best-effort: shown only if the extracted file exists.
            display(Markdown("Video Preview:"))
            video_path = Path(extracted_videos_path) / f"{video_selector.value}.mp4"
            if video_path.is_file():
                video_widget = widgets.Video.from_file(video_path, width=400, height=300)
                display(video_widget)

            display(Markdown("### Inference Details"))
            display(Markdown(f"**Video ID:** {video_selector.value} | **QID:** {qid}"))
            display(Markdown(f"**Resource ID ({resource_col}):** {resource_id}"))

            prompt = build_prompt(q_info)
            display(Markdown("### Prompt"))
            display(Markdown(f"**Question:** {q_info.get('question', '')}"))
            display(Markdown(f"**Question Type:** {q_info.get('question_type', 'default')}"))
            display(Markdown(f"**Prompt Template:** {PROMPT_TEMPLATES.get(q_info.get('question_type', 'default'), PROMPT_TEMPLATES['default'])}"))
            display(Markdown(f"**System Prompt:** {SYSTEM_PROMPT}"))
            # Escape so '<', '>' and '&' in the prompt are shown literally.
            display(HTML(f"<pre style='{pre_style}'>{html.escape(str(prompt))}</pre>"))
            display(Markdown("--- Performing Inference (Sync) ---"))

            inference_result = perform_inference_single_sync(q_info, ai_client)

            display(Markdown("--- Result ---"))
            if inference_result and isinstance(inference_result, dict):
                status = inference_result.get("status", "?")
                duration = inference_result.get('duration', -1)
                answer = inference_result.get('pred', '')
                reason = inference_result.get('finish_reason', 'N/A')
                display(Markdown(f"**Status:** {status} | **Duration:** {duration:.2f}s | **Reason:** {reason}"))
                display(Markdown("**Response:**"))
                # Model output is free text; escape it before embedding in HTML.
                display(HTML(f"<pre style='{pre_style}'>{html.escape(str(answer))}</pre>"))
            else:
                display(Markdown("❌ Invalid result."))
    finally:
        # Single place that re-enables the button, covering early returns too.
        run_button.disabled = False

# Hook each callback to its widget, then render the controls top to bottom.
run_button.on_click(on_run_button_clicked)
question_selector.observe(on_question_selected, names='value')
video_selector.observe(on_video_selected, names='value')

display(Markdown("### Select Video & Question"))
display(video_selector, question_selector, run_button, output_area)

## Generated-Questions Prompt-Chaining Testing UI (Turn-by-Turn Format)

In [None]:
# --- Step 1: Fetch the generated questions for each video ---
import ast

def get_questions_for_video(questions_file: str = QUESTIONS_DIR) -> Dict[str, List[str]]:
    """Load the generated questions for each video from a CSV file.

    The CSV must contain a 'video_id' column and a 'questions' column whose
    cells hold a Python-literal list, parsed with ``ast.literal_eval``. When a
    video id appears on several rows, only its first row is used. Rows whose
    'questions' cell cannot be parsed into a list/tuple are skipped with a
    warning instead of discarding every other video.

    Args:
        questions_file: Path to the questions CSV.

    Returns:
        Mapping of video id to its list of questions. An empty dict is
        returned when the file does not exist, the required columns are
        absent, no row has both fields filled, or reading the file fails.
    """
    if not Path(questions_file).is_file():
        return {}
    try:
        df = pd.read_csv(questions_file, dtype=str).fillna('')
        if 'video_id' not in df.columns or 'questions' not in df.columns:
            logger.error("Questions file missing 'video_id' or 'questions'.")
            return {}

        # Keep only rows where both fields are non-empty strings.
        valid_df = df[df['video_id'].astype(bool) & df['questions'].astype(bool)]
        if valid_df.empty:
            logger.warning(f"No videos found with questions in {questions_file}. Check Step 4.")
            return {}

        video_questions: Dict[str, List[str]] = {}
        # One pass over the first row of each video instead of re-filtering
        # the whole frame once per video id.
        first_rows = valid_df.drop_duplicates(subset='video_id', keep='first')
        for video_id, question_str in zip(first_rows['video_id'], first_rows['questions']):
            try:
                parsed = ast.literal_eval(question_str)
            except (ValueError, SyntaxError, TypeError) as e:
                logger.warning(f"Skipping video '{video_id}': cannot parse questions ({e}).")
                continue
            if not isinstance(parsed, (list, tuple)):
                logger.warning(f"Skipping video '{video_id}': questions is not a list.")
                continue
            video_questions[video_id] = list(parsed)

        logger.info(f"Loaded generated questions for {len(video_questions)} videos ({len(valid_df)} valid rows).")
        return video_questions
    except Exception as e:
        logger.error(f"Error loading questions: {e}", exc_info=True)
        return {}

# Load the video_id -> generated-questions mapping once, when this cell runs.
generated_questions_dict = get_questions_for_video(QUESTIONS_DIR)

In [None]:
# ──────────────────────────────────────────────────────────────────────────────
# 1⃣  Request renderer – handles list + File objects safely
# ──────────────────────────────────────────────────────────────────────────────
def _render_request(contents: list, turn_idx: int):
    """
    Pretty-print the payload sent to the model for one chat turn.

    Each entry of ``contents`` becomes one line:
    • `types.Part` with text  -> "video_desc: <text>"
    • `types.Part` without text -> "[video/mp4]"
    • objects with a ``parts`` attribute (e.g. `types.Content`)
      -> "<role>: <text of the first part>"
    • anything else (e.g. a File object) -> "[file]"

    A nested list inside ``contents`` is flattened one level first.

    Args:
        contents: The request payload passed to ``generate_content``.
        turn_idx: Turn number shown in the rendered header.
    """
    import html  # stdlib; escapes payload text before it is embedded in HTML

    # Flatten one level in case an element is itself a list of Parts.
    flat: list[Any] = []
    for item in contents:
        if isinstance(item, list):
            flat.extend(item)
        else:
            flat.append(item)

    lines = []
    for c in flat:
        if isinstance(c, types.Part):
            # text Part vs. non-text (video) Part
            if getattr(c, "text", None):
                lines.append(f"video_desc: {c.text}")
            else:
                lines.append("[video/mp4]")
        elif hasattr(c, "parts"):
            role = getattr(c, "role", "?")
            # Guard against empty/None parts instead of indexing blindly.
            parts = c.parts or []
            txt = (getattr(parts[0], "text", "") or "") if parts else ""
            lines.append(f"{role}: {txt}")
        else:
            lines.append("[file]")

    # Escape so '<', '>' and '&' in prompts/answers are displayed literally.
    joined_lines = html.escape('\n'.join(lines))
    display(HTML(
        f"<div style='border:1px dashed #999;padding:8px;margin:8px 0;'>"
        f"<strong>Turn {turn_idx} – Request sent to model:</strong>"
        f"<pre style='white-space:pre-wrap;margin:4px 0 0;'>{joined_lines}</pre>"
        f"</div>"
    ))


In [None]:
def perform_inference_single_sync(question_info: Dict, client: Any) -> Optional[Dict[str, Any]]:
    """Run a turn-by-turn prompt chain for one question and render each turn.

    Up to three generated questions for the video (from
    ``generated_questions_dict``) are asked first, followed by the prompt
    built from ``question_info``. Every turn sends the video plus the chat
    history so far; each non-empty answer is additionally passed to
    ``QUESTION_MODEL_NAME`` to produce a summary that is displayed only.

    Args:
        question_info: Question metadata; reads 'qid', 'video_id' and, depending
            on the backend, 'gcs_uri' or 'file_api_name'.
        client: Client exposing ``models.generate_content`` and ``files.get``.

    Returns:
        Dict with 'qid', 'pred' (the model answer of the last turn),
        'duration' in seconds, 'status' and 'finish_reason' of the last turn.
        ``None`` when the video resource cannot be resolved (the error is
        displayed instead).
    """
    import html  # stdlib; escapes free text before it is embedded in HTML

    qid = question_info.get("qid", "?")
    video_id = question_info.get("video_id", "?")
    prompt_text = build_prompt(question_info)
    gcs_uri = question_info.get("gcs_uri")
    file_api_name = question_info.get("file_api_name")
    start_time = time.time()

    # Copy before appending so the shared mapping is never mutated.
    questions_list = list(generated_questions_dict.get(video_id, []))[:3]
    questions_list.append(prompt_text)

    # Resolve the video resource for the active backend.
    try:
        if USE_VERTEX:
            if not gcs_uri:
                raise ValueError("Missing GCS URI.")
            video_part = types.Part.from_uri(mime_type="video/mp4", file_uri=gcs_uri)
        else:
            if not file_api_name:
                raise ValueError("Missing File API name.")
            try:
                video_part = client.files.get(name=file_api_name)
            except genai_errors.NotFoundError:
                raise FileNotFoundError(f"File API '{file_api_name}' not found.")
            except Exception as e:
                raise RuntimeError(f"Failed get File API obj: {e}")
    except Exception as e:
        display(Markdown(f"❌ Video resource error: {e}"))
        return None

    chat: list = []
    finish_reason = "UNKNOWN"
    turn_failed = False

    for idx, q in enumerate(questions_list, 1):
        user_msg = types.Content(role="user",
                                 parts=[types.Part.from_text(text=q)])
        # The video is always sent first, followed by the history so far.
        contents = [video_part] + chat + [user_msg]

        _render_request(contents, idx)

        # Reset per turn so a failed or empty turn never shows a stale summary.
        summary_answer = ""
        turn_failed = False
        try:
            rsp = client.models.generate_content(model=MODEL_NAME,
                                                 contents=contents,
                                                 config=CONFIG)
            # rsp.text can be None; treat that as an empty answer.
            answer = (rsp.text or "").strip()
            finish_reason = (rsp.candidates[0].finish_reason.name
                             if rsp.candidates and rsp.candidates[0].finish_reason else "UNKNOWN")
        except Exception as e:
            answer, finish_reason = f"ERROR: {e}", "Failed (API)"
            turn_failed = True

        # Summarise separately so a summary failure cannot clobber a valid answer.
        if answer and not turn_failed:
            try:
                summary_content = types.Content(
                    role="user",
                    parts=[
                        types.Part.from_text(text=q),
                        types.Part.from_text(text=answer),
                    ])
                summary_rsp = client.models.generate_content(model=QUESTION_MODEL_NAME,
                                                             contents=summary_content,
                                                             config=QUESTION_CONFIG)
                summary_answer = (summary_rsp.text or "").strip()
            except Exception as e:
                summary_answer = f"ERROR: {e}"

        display(HTML(
            f"<div style='border:1px solid #000;margin:8px 0;padding:10px;'>"
            f"<b>Q{idx}:</b> {html.escape(str(q))}</div>"
        ))
        display(HTML(
            f"<div style='border:1px solid #000;margin:0 0 12px;"
            f"padding:10px;background:#f9f9f9;'><b>CoT:</b> {html.escape(answer)}</div>"
        ))
        display(HTML(
            f"<div style='border:1px solid #000;margin:0 0 12px;"
            f"padding:10px;background:#f9f9f9;'><b>Summary Answer:</b> {html.escape(summary_answer) or 'N/A'}</div>"
        ))

        # Record the exchange; an empty answer is stored as a placeholder.
        chat.extend([
            user_msg,
            types.Content(role="model",
                          parts=[types.Part.from_text(text=answer or "…")])
        ])

    final_answer = chat[-1].parts[0].text if chat else "No answer"

    return {
        "qid": qid,
        "pred": final_answer,
        "duration": time.time() - start_time,
        # Status and finish_reason describe the last turn, whose answer is 'pred'.
        "status": "Failed" if turn_failed else "Successful",
        "finish_reason": finish_reason,
    }


# --- UI Setup (using `google.genai` types where needed) --- #
# Load the per-video question metadata; on failure show the error and
# continue with an empty mapping so the widgets still render (disabled).
ui_video_questions = {}
try:
    ui_video_questions = load_metadata_for_inference(METADATA_FILE)
except Exception as e:
    display(Markdown(f"❌ UI Load Error: {e}"))

# Placeholder first, then one "<video_id> (<n>q)" entry per video by label.
video_options = [("Select video...", None)]
if ui_video_questions:
    video_options += sorted(
        (f"{vid} ({len(qs)}q)", vid)
        for vid, qs in ui_video_questions.items()
    )

# Video picker; stays disabled when no metadata was loaded.
video_selector = widgets.Dropdown(
    options=video_options,
    description='Video ID:',
    disabled=not ui_video_questions,
    style={'description_width': 'initial'},
)

# Question picker; starts disabled with only the placeholder entry.
question_selector = widgets.Dropdown(
    options=[("Select question...", None)],
    description='Question (QID):',
    disabled=True,
    layout=widgets.Layout(width='95%'),
    style={'description_width': 'initial'},
)

# Run button; starts disabled.
run_button = widgets.Button(
    description='Run Inference',
    disabled=True,
    button_style='primary',
    icon='play',
)

# Output widget that the callbacks below render into.
output_area = widgets.Output()

# --- Widget Interaction Logic (Remains mostly the same, calls updated sync function) --- #
def on_video_selected(change):
    """Reset the question picker and repopulate it for the newly chosen video.

    Args:
        change: Widget change notification; ``change['new']`` is the selected
            video id (``None`` for the placeholder entry).
    """
    selected_video_id = change['new']
    placeholder = ("Select question...", None)

    # Always fall back to a clean, disabled state before repopulating.
    question_selector.options = [placeholder]
    question_selector.value = None
    question_selector.disabled = True
    run_button.disabled = True
    output_area.clear_output()

    if not selected_video_id or selected_video_id not in ui_video_questions:
        return

    # Only questions that carry a qid are offered; the option value is the
    # full question dict so later callbacks can read its fields directly.
    question_options = [
        (f"{q.get('qid', 'N/A')}: {q.get('question', '')[:80]}...", q)
        for q in ui_video_questions[selected_video_id]
        if q.get('qid')
    ]
    # Sort on the label only: sorting the (label, dict) tuples themselves
    # compares the dicts when two labels are equal, which raises TypeError.
    question_options.sort(key=lambda option: option[0])

    if question_options:
        question_selector.options = [placeholder] + question_options
        question_selector.disabled = False

def on_question_selected(change):
    """Enable the run button for a chosen question and show a short summary.

    Args:
        change: Widget change notification; ``change['new']`` is the selected
            question dict (``None`` for the placeholder entry).
    """
    chosen = change['new']
    run_button.disabled = chosen is None
    output_area.clear_output()
    if not chosen:
        return

    with output_area:
        display(Markdown("### Selected Info"))
        # The resource column depends on which backend is active.
        resource_key = 'gcs_uri' if USE_VERTEX else 'file_api_name'
        summary = pd.Series({
            'qid': chosen.get('qid'),
            'question': chosen.get('question'),
            f'Resource ({resource_key})': chosen.get(resource_key),
        })
        display(summary.to_frame('Value'))

def on_run_button_clicked(b):
    """Verify the selected resource, run the prompt-chain inference and render it.

    The run button is disabled for the duration of the call and re-enabled on
    every exit path. All output is rendered into ``output_area``.

    Args:
        b: The clicked button instance (unused).
    """
    import html  # stdlib; used to escape free text embedded in HTML below

    pre_style = "white-space: pre-wrap; border: 1px solid #000; padding: 10px;"
    run_button.disabled = True
    output_area.clear_output()
    try:
        with output_area:
            if not video_selector.value or not question_selector.value:
                display(Markdown("❌ Select video & question."))
                return
            if ai_client is None:
                display(Markdown("❌ AI Client not ready."))
                return

            q_info = question_selector.value
            qid = q_info.get('qid')
            # The resource column depends on which backend is active.
            resource_col = 'gcs_uri' if USE_VERTEX else 'file_api_name'
            resource_id = q_info.get(resource_col)

            display(Markdown(f"### Running QID: {qid}"))
            display(Markdown(f"--- Verifying Resource --- ({'GCS' if USE_VERTEX else 'File API'}) ---"))
            verified = False
            if not resource_id:
                display(Markdown(f"❌ Error: Missing '{resource_col}' ID."))
            elif USE_VERTEX:
                verified = verify_gcs_file_exists(storage_client, resource_id)
            else:
                verified = verify_file_api_resource_exists(storage_client, resource_id)

            if not verified:
                display(Markdown("❌ Verification failed."))
                return
            display(Markdown(f"✅ Resource Verified: {resource_id}"))

            # Local preview is best-effort: shown only if the extracted file exists.
            display(Markdown("Video Preview:"))
            video_path = Path(extracted_videos_path) / f"{video_selector.value}.mp4"
            if video_path.is_file():
                video_widget = widgets.Video.from_file(video_path, width=400, height=300)
                display(video_widget)

            display(Markdown("### Inference Details"))
            display(Markdown(f"**Video ID:** {video_selector.value} | **QID:** {qid}"))
            display(Markdown(f"**Resource ID ({resource_col}):** {resource_id}"))

            prompt = build_prompt(q_info)
            display(Markdown("### Prompt"))
            display(Markdown(f"**Question:** {q_info.get('question', '')}"))
            display(Markdown(f"**Question Type:** {q_info.get('question_type', 'default')}"))
            display(Markdown(f"**Prompt Template:** {PROMPT_TEMPLATES.get(q_info.get('question_type', 'default'), PROMPT_TEMPLATES['default'])}"))
            display(Markdown(f"**System Prompt:** {SYSTEM_PROMPT}"))
            # Escape so '<', '>' and '&' in the prompt are shown literally.
            display(HTML(f"<pre style='{pre_style}'>{html.escape(str(prompt))}</pre>"))
            display(Markdown("--- Performing Inference (Sync) ---"))

            inference_result = perform_inference_single_sync(q_info, ai_client)

            # An ATX heading needs a space after the '#' run to render as a heading.
            display(Markdown("### --- Result ---"))
            if inference_result and isinstance(inference_result, dict):
                status = inference_result.get("status", "?")
                duration = inference_result.get('duration', -1)
                answer = inference_result.get('pred', '')
                reason = inference_result.get('finish_reason', 'N/A')
                display(Markdown(f"**Status:** {status} | **Duration:** {duration:.2f}s | **Reason:** {reason}"))
                display(Markdown("**Response:**"))
                # Model output is free text; escape it before embedding in HTML.
                display(HTML(f"<pre style='{pre_style}'>{html.escape(str(answer))}</pre>"))
            else:
                display(Markdown("❌ Invalid result."))
    finally:
        # Single place that re-enables the button, covering early returns too.
        run_button.disabled = False

# Hook each callback to its widget, then render the controls top to bottom.
run_button.on_click(on_run_button_clicked)
question_selector.observe(on_question_selected, names='value')
video_selector.observe(on_video_selected, names='value')

display(Markdown("### Select Video & Question"))
display(video_selector, question_selector, run_button, output_area)