# Retrieve Batch Results from Gemini API

This notebook retrieves completed batch results from the Google Gemini API.

## Prerequisites
- Set the `GEMINI_API_KEY` environment variable
- Have a completed batch job from `create_batch_gemini.ipynb`
- Have the batch log file (created during batch submission) available at `logs/<INPUT_DIR>/<MODEL_NAME>/<INPUT_FILE>.json` under your base directory

## Features
- Downloads results from the Gemini File API
- Parses JSONL-formatted results
- Tracks token usage statistics
- Merges with existing results for incremental collection
- Reports errors and missing results

## Output Format
```json
{
  "meta_data": {"file_name": "...", "inference_model": "..."},
  "token_stats": {"total_input_tokens": ..., "total_output_tokens": ..., "num_prompts": ..., "avg_input_tokens": ..., "avg_output_tokens": ...},
  "results": {"case_001": "response text", ...}
}
```

In [None]:
# ============================================================================
# CONFIGURATION - Edit these variables before running
# ============================================================================
from pathlib import Path

# Base directory under which the "output" and "logs" trees are located
# (should match the one used in create_batch_gemini.ipynb)
BASE_DIR = Path(".")  # Change to your data directory

# Settings from batch creation (must match what you used).
# They are joined as <INPUT_DIR>/<MODEL_NAME>/<INPUT_FILE>.json below both
# LOGS_DIR (to find the batch log) and OUTPUT_DIR (to place the results file).
INPUT_DIR = "data"  # Directory that contained input JSON
INPUT_FILE = "your_input_file"  # Name without .json extension
MODEL_NAME = "gemini-2.5-pro"  # Model used for batch

# Output directory for results
OUTPUT_DIR = BASE_DIR / "output"
# Directory holding the batch log files created during batch submission
LOGS_DIR = BASE_DIR / "logs"

# ============================================================================

In [None]:
import json
import os
from google import genai

# Build paths: the log file and the results file share the same relative
# layout (<INPUT_DIR>/<MODEL_NAME>/<INPUT_FILE>.json) under their own roots.
relative_parts = (INPUT_DIR, MODEL_NAME, f"{INPUT_FILE}.json")
log_path = LOGS_DIR.joinpath(*relative_parts)
output_path = OUTPUT_DIR.joinpath(*relative_parts)

# Load log file to get batch_job_name and key_dict
with log_path.open('r') as log_file:
    data = json.load(log_file)

# key_dict maps the request keys used in the batch to the result keys
# under which the answers are stored in the output file.
key_dict = data["key_dict"]
batch_job_name = data['batch_job']['name']

for summary_line in (
    f"Log file: {log_path}",
    f"Batch job name: {batch_job_name}",
    f"Model: {MODEL_NAME}",
    f"Expected results: {len(key_dict)}",
    f"Output path: {output_path}",
):
    print(summary_line)

# Load existing results if any, so a re-run can extend them instead of
# starting from scratch.
existing_dict = {}
if output_path.exists():
    with output_path.open('r') as previous_file:
        existing_dict = json.load(previous_file)
    previous_count = len(existing_dict.get('results', {}))
    print(f"Found existing results: {previous_count}")

# Initialize Gemini client
gemini_api_key = os.environ.get("GEMINI_API_KEY")
client = genai.Client(api_key=gemini_api_key)

In [None]:
def merge_nested_dicts(dict1, dict2):
    """Return a new dict combining ``dict1`` with the entries of ``dict2``.

    Values already present in ``dict1`` win on conflict. When both sides
    hold a dict under the same key, those two dicts are merged recursively
    by the same rule. Neither input's top-level mapping is modified.

    Args:
        dict1: Mapping whose values take precedence.
        dict2: Mapping supplying entries that ``dict1`` lacks.

    Returns:
        The merged dictionary.
    """
    merged = dict1.copy()
    for name, incoming in dict2.items():
        if name not in merged:
            merged[name] = incoming
            continue
        current = merged[name]
        # Only dict-vs-dict conflicts are merged; any other conflict keeps
        # the value that was already there.
        if isinstance(incoming, dict) and isinstance(current, dict):
            merged[name] = merge_nested_dicts(current, incoming)
    return merged

In [None]:
# Check if we already have all results: when the saved file holds exactly
# the expected set of result keys there is nothing left to download.
if existing_dict:
    existing_keys = {*existing_dict.get("results", {}).keys()}
    expected_keys = {*key_dict.values()}

    print(f"Expected results: {len(expected_keys)}")
    print(f"Existing results: {len(existing_keys)}")

    if existing_keys == expected_keys:
        print("\nAll results already collected!")
        print(f"Results file: {output_path}")
        # Stop the notebook here; the remaining cells would only redo work.
        raise SystemExit("All results already collected.")

    if existing_keys:
        still_missing = expected_keys - existing_keys
        print(f"Missing {len(still_missing)} results from existing file")

In [None]:
# Check batch status: results can only be retrieved once the job reports
# JOB_STATE_SUCCEEDED; any other state stops the notebook.
batch_job = client.batches.get(name=batch_job_name)
print(f"Batch job: {batch_job.name}")
state_name = batch_job.state.name
print(f"State: {state_name}")

succeeded = state_name == "JOB_STATE_SUCCEEDED"
if not succeeded:
    print(f"\nBatch not yet completed. Status: {state_name}")
    # Show the error attached to the job when it failed outright.
    failed_with_error = state_name == "JOB_STATE_FAILED" and hasattr(batch_job, 'error')
    if failed_with_error:
        print(f"Error: {batch_job.error}")
    raise SystemExit("Batch not completed. Try again later.")

print("\nBatch completed! Retrieving results...")

In [None]:
# Download and parse results
#
# Produces, for use by the following cell:
#   result_dict  - result key -> response text for every new successful case
#   errors_dict  - result key -> error object reported by the batch
#   empty_keys   - result keys whose response contained no text at all
#   token_stats  - token totals/averages for the cases counted in this run
if not (batch_job.dest and batch_job.dest.file_name):
    raise ValueError("No result file found in batch job")

result_file_name = batch_job.dest.file_name
print(f"Downloading results from: {result_file_name}")

file_content = client.files.download(file=result_file_name)
result_text = file_content.decode('utf-8')


def _extract_text(response):
    """Concatenate the text parts of the first candidate of a response dict.

    Returns an empty string when the response has no candidates, the first
    candidate has no content/parts, or none of the parts carries text.
    """
    candidates = response.get("candidates") or []
    if not candidates:
        return ""
    content = candidates[0].get("content") or {}
    parts = content.get("parts") or []
    return "".join(part["text"] for part in parts if "text" in part)


# Parse JSONL results
result_dict = {}
processed_keys = set()
errors_dict = {}
# Cases that came back with a response but no text; without tracking them
# they would be neither successful, errored, nor missing in the report.
empty_keys = []

# The save step adds the token totals stored in the existing results file to
# the totals computed here. Cases already covered by those stored totals must
# therefore not be counted again, otherwise every re-run inflates the stats.
if "token_stats" in existing_dict:
    already_counted = set(existing_dict.get("results", {}))
else:
    already_counted = set()

# Token counters
total_input_tokens = 0
total_output_tokens = 0
num_prompts = 0
num_skipped_for_stats = 0

# NOTE: split on '\n' only; str.splitlines() would also break on separators
# such as U+2028, which may legally appear inside a JSON string.
for line in result_text.strip().split('\n'):
    if not line.strip():
        continue

    try:
        result_obj = json.loads(line)
    except json.JSONDecodeError as e:
        print(f"Failed to parse line: {e}")
        continue

    key = result_obj.get("key")
    if key is None:
        continue

    processed_keys.add(key)
    result_key = key_dict.get(key)

    # Check for errors
    if "error" in result_obj:
        if result_key:
            errors_dict[result_key] = result_obj["error"]
        continue

    if "response" not in result_obj or not result_key:
        continue

    response = result_obj["response"]

    # Extract text from candidates
    text_content = _extract_text(response)
    if not text_content:
        empty_keys.append(result_key)
        continue

    result_dict[result_key] = text_content

    # Extract token usage
    if "usageMetadata" in response:
        if result_key in already_counted:
            num_skipped_for_stats += 1
            continue
        usage = response["usageMetadata"]
        total_input_tokens += usage.get("promptTokenCount", 0)
        total_output_tokens += usage.get("candidatesTokenCount", 0)
        num_prompts += 1

# Calculate statistics
token_stats = {
    "total_input_tokens": total_input_tokens,
    "total_output_tokens": total_output_tokens,
    "num_prompts": num_prompts,
    "avg_input_tokens": total_input_tokens / num_prompts if num_prompts > 0 else 0,
    "avg_output_tokens": total_output_tokens / num_prompts if num_prompts > 0 else 0
}

# Report results
missing_keys = set(key_dict.keys()) - processed_keys
print(f"\nTotal submitted: {len(key_dict)}")
print(f"Successful: {len(result_dict)}")
print(f"Errored: {len(errors_dict)}")
print(f"Empty (no text): {len(empty_keys)}")
print(f"Missing: {len(missing_keys)}")

print(f"\n=== Token Statistics ===")
if num_skipped_for_stats:
    print(f"(excluding {num_skipped_for_stats} cases already counted in the existing results file)")
print(f"Total input tokens: {token_stats['total_input_tokens']:,}")
print(f"Total output tokens: {token_stats['total_output_tokens']:,}")
print(f"Average input tokens: {token_stats['avg_input_tokens']:.1f}")
print(f"Average output tokens: {token_stats['avg_output_tokens']:.1f}")

In [None]:
# Merge with existing results and save
if existing_dict:
    # Previously saved answers win on conflict; only unseen cases are added.
    result_dict = merge_nested_dicts(existing_dict.get("results", {}), result_dict)

    # Merge token stats: add the stored totals to this run's totals, then
    # recompute the averages over the combined prompt count.
    if "token_stats" in existing_dict:
        old_stats = existing_dict["token_stats"]
        for counter in ("total_input_tokens", "total_output_tokens", "num_prompts"):
            token_stats[counter] += old_stats.get(counter, 0)
        combined_prompts = token_stats["num_prompts"]
        if combined_prompts > 0:
            token_stats["avg_input_tokens"] = token_stats["total_input_tokens"] / combined_prompts
            token_stats["avg_output_tokens"] = token_stats["total_output_tokens"] / combined_prompts
    print("Merged with existing results")

# Assemble the output document in the documented layout
final_dict = {
    "meta_data": {"file_name": INPUT_FILE, "inference_model": MODEL_NAME},
    "token_stats": token_stats,
    "results": result_dict,
}

# Create output directory and save
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, "w") as out_file:
    json.dump(final_dict, out_file, indent=4)

print(f"\nSaved results to: {output_path}")
print(f"Total cases with results: {len(result_dict)}")