# Smelt AI — Live Demo

Interactive walkthrough of smelt-ai using **Google Gemini**.

1. Test run — quick single-row validation
2. Basic classification
3. Sentiment analysis with score validation
4. Support ticket triage (complex schema)
5. Parameter tuning (temperature)
6. Batch configuration (batch_size, concurrency)
7. Error handling (stop_on_exhaustion)
8. Company summaries

## Setup

In [None]:
import os
import csv
from pathlib import Path
from typing import Literal

from dotenv import load_dotenv
from pydantic import BaseModel, Field

from smelt import Model, Job, SmeltResult
from smelt.errors import SmeltExhaustionError

load_dotenv()

GEMINI_KEY = os.getenv("GEMINI_API_KEY")
print(f"Gemini key: {'set' if GEMINI_KEY else 'MISSING'}")

model = Model(provider="google_genai", name="gemini-3-flash-preview", api_key=GEMINI_KEY)
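The setup cell only prints `MISSING` when the key is absent, so a missing key may surface later as an error that is harder to trace. A minimal fail-fast alternative using only the standard library is shown below. `require_env` is a hypothetical helper and is not part of smelt.

```python
import os


def require_env(name: str) -> str:
    """Return the value of an environment variable, or raise if it is unset or empty."""
    value = os.getenv(name)
    if not value:
        raise RuntimeError(f"{name} is not set; add it to your .env file or environment.")
    return value
```

Call `GEMINI_KEY = require_env("GEMINI_API_KEY")` after `load_dotenv()` in place of the `os.getenv` line.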

## Load Test Data

In [None]:
DATA_DIR = Path("../tests/data")


def load_csv(filename: str) -> list[dict[str, str]]:
    """Load CSV from tests/data directory."""
    with open(DATA_DIR / filename, newline="", encoding="utf-8") as f:
        return list(csv.DictReader(f))


companies = load_csv("companies.csv")
products = load_csv("products.csv")
tickets = load_csv("support_tickets.csv")

print(f"Companies: {len(companies)} rows")
print(f"Products:  {len(products)} rows")
print(f"Tickets:   {len(tickets)} rows")
print()
print("Sample company:", companies[0])
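`csv.DictReader` yields one `dict` per row, keyed by the header line, and every value is a string. A founding year therefore arrives as `"1998"`, not `1998`. The in-memory sketch below shows that shape. The column names and rows are hypothetical, because the real headers of `companies.csv` are not shown here.

```python
import csv
import io

# Hypothetical CSV content; the real files live in ../tests/data.
SAMPLE = "name,founded,employees\nAcme Corp,1998,1200\nBeta Labs,2019,35\n"

rows = list(csv.DictReader(io.StringIO(SAMPLE)))
print(rows[0])  # {'name': 'Acme Corp', 'founded': '1998', 'employees': '1200'}
```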

## Define Output Models

In [None]:
class IndustryClassification(BaseModel):
    sector: str = Field(description="Primary industry sector")
    sub_sector: str = Field(description="More specific sub-sector")
    is_public: bool = Field(description="Whether the company is publicly traded")


class SentimentAnalysis(BaseModel):
    sentiment: Literal["positive", "negative", "mixed"] = Field(description="Overall sentiment")
    score: float = Field(description="Score from 0.0 (negative) to 1.0 (positive)")
    key_themes: list[str] = Field(description="Main themes in the review (1-3 items)")


class TicketTriage(BaseModel):
    category: str = Field(description="Category: billing, technical, shipping, account, or general")
    priority: Literal["low", "medium", "high", "urgent"] = Field(description="Priority level")
    requires_human: bool = Field(description="Whether human escalation is needed")
    suggested_response: str = Field(description="Brief suggested response to the customer")


class CompanySummary(BaseModel):
    one_liner: str = Field(description="One sentence description")
    industry: str = Field(description="Primary industry")
    company_size: Literal["startup", "small", "medium", "large", "enterprise"] = Field(
        description="Size classification based on employee count"
    )
    age_years: int = Field(description="Approximate age in years")


print("Output models defined.")
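Each output model also acts as a validator. Pydantic rejects values outside a `Literal`, so a response that does not match the schema cannot become a model instance. The check below runs locally with Pydantic v2's `model_validate` and `ValidationError.error_count()` and makes no API call. The class is copied from the cell above so the snippet runs on its own.

```python
from typing import Literal

from pydantic import BaseModel, Field, ValidationError


class SentimentAnalysis(BaseModel):
    sentiment: Literal["positive", "negative", "mixed"] = Field(description="Overall sentiment")
    score: float = Field(description="Score from 0.0 (negative) to 1.0 (positive)")
    key_themes: list[str] = Field(description="Main themes in the review (1-3 items)")


ok = SentimentAnalysis.model_validate(
    {"sentiment": "mixed", "score": 0.6, "key_themes": ["battery", "price"]}
)
print(ok.sentiment, ok.score)  # mixed 0.6

try:
    # "neutral" is not one of the allowed literals, so validation fails
    SentimentAnalysis.model_validate({"sentiment": "neutral", "score": 0.5, "key_themes": []})
except ValidationError as exc:
    print(f"Rejected: {exc.error_count()} error(s)")  # Rejected: 1 error(s)
```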

## Helper: Pretty-Print Results

In [None]:
def show_result(label: str, result: SmeltResult) -> None:
    """Pretty-print a SmeltResult."""
    status = "SUCCESS" if result.success else "FAILED"
    m = result.metrics
    print(f"\n{'='*70}")
    print(f"  {label}")
    print(f"  Status: {status}")
    print(f"  Rows: {m.successful_rows}/{m.total_rows} successful")
    print(f"  Batches: {m.successful_batches}/{m.total_batches} successful")
    print(f"  Tokens: {m.input_tokens:,} in / {m.output_tokens:,} out")
    print(f"  Retries: {m.total_retries} | Time: {m.wall_time_seconds:.2f}s")
    if result.errors:
        print(f"  Errors: {len(result.errors)}")
        for e in result.errors:
            print(f"    - Batch {e.batch_index}: {e.error_type} ({e.attempts} attempts)")
    print(f"{'='*70}")
    print()
    for i, row in enumerate(result.data[:3]):  # preview only the first three rows
        print(f"  [{i}] {row}")
    if len(result.data) > 3:
        print(f"  ... and {len(result.data) - 3} more rows")

---
## 1. Test Run — Quick Single-Row Validation

Use `job.atest()` to validate your setup before a full run. It sends only the first row and ignores the `batch_size`, `concurrency`, and `shuffle` settings.

In [None]:
job = Job(
    prompt="Classify each company by its primary industry sector and sub-sector. "
    "Determine if the company is publicly traded.",
    output_model=IndustryClassification,
    batch_size=20,
    concurrency=5,
    shuffle=True,
)

# Quick test — only the first row is sent, batch_size/concurrency/shuffle are ignored
result = await job.atest(model, data=companies)
show_result("Test Run — Single Row Classification", result)

---
## 2. Basic Classification

Classify all 10 companies by industry sector.

In [None]:
job = Job(
    prompt="Classify each company by its primary industry sector and sub-sector. "
    "Determine if the company is publicly traded.",
    output_model=IndustryClassification,
    batch_size=10,
    stop_on_exhaustion=False,
)

result = await job.arun(model, data=companies)
show_result("Company Classification", result)

---
## 3. Sentiment Analysis — Score Validation

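Analyze product reviews and verify that every score falls in the [0, 1] range. `SentimentAnalysis` does not constrain `score`, so the range is checked after the run.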

In [None]:
job = Job(
    prompt="Analyze the sentiment of each product's customer_review. "
    "Identify the overall sentiment, assign a score between 0.0 and 1.0, "
    "and extract 1-3 key themes.",
    output_model=SentimentAnalysis,
    batch_size=5,
    concurrency=2,
    stop_on_exhaustion=False,
)

result = await job.arun(model, data=products)
show_result("Sentiment Analysis", result)

print("\nScore validation:")
for i, row in enumerate(result.data):
    in_range = 0.0 <= row.score <= 1.0
    print(f"  [{i}] score={row.score:.2f} sentiment={row.sentiment:8s} valid={in_range} themes={row.key_themes}")
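Because `SentimentAnalysis.score` has no bounds, the loop above can only report out-of-range scores after the fact. Pydantic's `ge`/`le` constraints move that check into the schema, and `min_length`/`max_length` can enforce the "1-3 items" rule for themes. `BoundedSentiment` is a hypothetical variant. How smelt reacts to a response that fails these constraints (for example, by retrying the batch) is an assumption to verify, not something this notebook shows.

```python
from typing import Literal

from pydantic import BaseModel, Field, ValidationError


class BoundedSentiment(BaseModel):
    sentiment: Literal["positive", "negative", "mixed"] = Field(description="Overall sentiment")
    # ge/le make Pydantic reject any score outside [0.0, 1.0]
    score: float = Field(ge=0.0, le=1.0, description="Score from 0.0 (negative) to 1.0 (positive)")
    # min_length/max_length enforce the "1-3 items" rule on the list
    key_themes: list[str] = Field(min_length=1, max_length=3, description="Main themes in the review")


print(BoundedSentiment(sentiment="positive", score=0.9, key_themes=["quality"]).score)  # 0.9

for bad in ({"score": 1.7, "key_themes": ["fit"]}, {"score": 0.4, "key_themes": []}):
    try:
        BoundedSentiment(sentiment="mixed", **bad)
    except ValidationError as exc:
        print("Rejected:", exc.errors()[0]["type"])
```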

---
## 4. Support Ticket Triage — Complex Schema

Exercises `Literal` types, booleans, and longer free-text generation (`suggested_response`).

In [None]:
job = Job(
    prompt="Triage each support ticket. Classify by category (billing, technical, "
    "shipping, account, or general), assign priority, determine if human escalation "
    "is needed, and write a brief suggested response.",
    output_model=TicketTriage,
    batch_size=5,
    concurrency=2,
    stop_on_exhaustion=False,
)

result = await job.arun(model, data=tickets)
show_result("Ticket Triage", result)

print("\nFull triage results:")
for i, row in enumerate(result.data):
    print(f"\n  [{i}] {tickets[i]['ticket_id']}")
    print(f"      Category: {row.category} | Priority: {row.priority} | Human: {row.requires_human}")
    response = row.suggested_response
    print(f"      Response: {response[:100]}{'...' if len(response) > 100 else ''}")
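`TicketTriage.category` is a plain `str`, so only the field description and the prompt enforce the five allowed categories. A `Literal` makes the schema enforce them the same way it already does for `priority`. `StrictTicketTriage` and `Category` are hypothetical names and are not part of the original notebook.

```python
from typing import Literal

from pydantic import BaseModel, Field, ValidationError

Category = Literal["billing", "technical", "shipping", "account", "general"]


class StrictTicketTriage(BaseModel):
    category: Category = Field(description="Ticket category")
    priority: Literal["low", "medium", "high", "urgent"] = Field(description="Priority level")
    requires_human: bool = Field(description="Whether human escalation is needed")
    suggested_response: str = Field(description="Brief suggested response to the customer")


ticket = StrictTicketTriage(
    category="billing",
    priority="high",
    requires_human=True,
    suggested_response="Sorry about the double charge; we are issuing a refund.",
)
print(ticket.category)  # billing

try:
    StrictTicketTriage(category="refunds", priority="low", requires_human=False, suggested_response="...")
except ValidationError:
    print("'refunds' is not an allowed category")
```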

---
## 5. Parameter Tuning — Temperature Comparison

Compare temperature=0 (least random), 0.5, and 1.0 (most varied) on the same three companies.

In [None]:
data_subset = companies[:3]

for temp in [0, 0.5, 1.0]:
    m = Model(
        provider="google_genai", name="gemini-3-flash-preview", api_key=GEMINI_KEY,
        params={"temperature": temp},
    )
    job = Job(
        prompt="Classify each company by industry sector.",
        output_model=IndustryClassification,
        batch_size=10,
        stop_on_exhaustion=False,
    )
    result = await job.arun(m, data=data_subset)
    show_result(f"temp={temp}", result)
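The loop prints each run but overwrites `result`, so the temperatures are never compared directly. The sketch below counts, per field, how many aligned rows differ between two runs, and it works on any lists of Pydantic models. `count_differences` is a hypothetical helper, and `Classification` is a stand-in with the same fields as `IndustryClassification`. The sample rows are made up.

```python
from pydantic import BaseModel


class Classification(BaseModel):  # stand-in with the same fields as IndustryClassification
    sector: str
    sub_sector: str
    is_public: bool


def count_differences(run_a: list[BaseModel], run_b: list[BaseModel]) -> dict[str, int]:
    """Count, per field, how many aligned rows disagree between two runs."""
    if len(run_a) != len(run_b):
        raise ValueError(f"Runs have different lengths: {len(run_a)} vs {len(run_b)}")
    counts: dict[str, int] = {}
    for a, b in zip(run_a, run_b):
        for name, value in a.model_dump().items():
            if value != getattr(b, name):
                counts[name] = counts.get(name, 0) + 1
    return counts


run_low = [Classification(sector="Technology", sub_sector="Cloud Software", is_public=True)]
run_high = [Classification(sector="Technology", sub_sector="SaaS", is_public=True)]
print(count_differences(run_low, run_high))  # {'sub_sector': 1}
```

In the notebook, you could store `results[temp] = result.data` inside the loop, then call `count_differences(results[0], results[1.0])`.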

---
## 6. Batch Configuration — Size & Concurrency

Compare four batch_size/concurrency combinations on the same 10 companies, checking that each run returns one result per input row.

In [None]:
configs = [
    {"batch_size": 10, "concurrency": 1, "label": "1 batch, serial"},
    {"batch_size": 5, "concurrency": 2, "label": "2 batches, conc=2"},
    {"batch_size": 2, "concurrency": 5, "label": "5 batches, conc=5"},
    {"batch_size": 1, "concurrency": 10, "label": "10 batches, conc=10"},
]

for cfg in configs:
    job = Job(
        prompt="Classify each company by industry sector.",
        output_model=IndustryClassification,
        batch_size=cfg["batch_size"],
        concurrency=cfg["concurrency"],
        stop_on_exhaustion=False,
    )
    result = await job.arun(model, data=companies)
    show_result(f"Config: {cfg['label']} (batch={cfg['batch_size']}, conc={cfg['concurrency']})", result)

    assert len(result.data) == len(companies), "Row count mismatch"
    print(f"  Row count verified: {len(result.data)} rows returned")
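The config labels assume 10 companies. The number of batches is `ceil(rows / batch_size)`, and `concurrency` caps how many batches run at once. The sketch below is a conceptual model in plain `asyncio`, not smelt's implementation. It splits rows into batches and limits in-flight batches with a semaphore. It relies on `asyncio.gather` returning results in submission order, which is one way a library can keep output rows aligned with input rows. `chunk` and `run_batches` are hypothetical names.

```python
import asyncio
import math


def chunk(rows: list, batch_size: int) -> list[list]:
    """Split rows into consecutive batches of at most batch_size rows."""
    return [rows[i : i + batch_size] for i in range(0, len(rows), batch_size)]


async def run_batches(rows: list, batch_size: int, concurrency: int, worker) -> list:
    """Run `worker` on each batch with at most `concurrency` batches in flight."""
    semaphore = asyncio.Semaphore(concurrency)

    async def guarded(batch: list) -> list:
        async with semaphore:
            return await worker(batch)

    # gather() returns results in the order the batches were submitted,
    # even when later batches finish first.
    batch_results = await asyncio.gather(*(guarded(b) for b in chunk(rows, batch_size)))
    return [row for batch in batch_results for row in batch]


for size in (10, 5, 2, 1):
    # Same arithmetic as the config labels above: 1, 2, 5, and 10 batches for 10 rows
    print(f"batch_size={size}: {math.ceil(10 / size)} batches, {len(chunk(list(range(10)), size))} chunks")
```

In a notebook cell, try it with `await run_batches(list(range(10)), 2, 5, my_worker)`, where `my_worker` is any `async` function that takes and returns a list. In a script, wrap the same call in `asyncio.run(...)`.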

---
## 7. Error Handling — stop_on_exhaustion

Demonstrate `stop_on_exhaustion=True` (raises on failure) vs `stop_on_exhaustion=False` (collects errors).

In [None]:
# stop_on_exhaustion=False: errors are collected, successful batches still returned
job = Job(
    prompt="Create a concise structured summary for each company. "
    "Calculate the age based on the founded year (current year is 2026).",
    output_model=CompanySummary,
    batch_size=5,
    concurrency=2,
    max_retries=2,
    stop_on_exhaustion=False,
)

result = await job.arun(model, data=companies)
show_result("Company Summary (stop_on_exhaustion=False)", result)

print(f"\nsuccess property: {result.success}")
print(f"result.data has {len(result.data)} rows")
print(f"result.errors has {len(result.errors)} errors")
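The prompt asks the model to compute `age_years` from the founding year. That is simple arithmetic you can verify locally. The sketch below assumes the CSV has a founding-year column, called `founded` here (a hypothetical name), and allows one year of slack for "approximate". `age_is_plausible` is a hypothetical helper.

```python
CURRENT_YEAR = 2026  # matches the year stated in the prompt


def age_is_plausible(founded: str, age_years: int, tolerance: int = 1) -> bool:
    """Check a model-reported age against CURRENT_YEAR - founded, within `tolerance` years."""
    expected = CURRENT_YEAR - int(founded)  # CSV values are strings, so convert first
    return abs(expected - age_years) <= tolerance


print(age_is_plausible("1998", 28))  # True
print(age_is_plausible("1998", 20))  # False
```

After the run above, `[age_is_plausible(c["founded"], r.age_years) for c, r in zip(companies, result.data)]` would flag suspicious rows. This only works if `result.data` lines up with `companies`, which may not hold when a batch failed.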

In [None]:
# stop_on_exhaustion=True: expected to succeed; SmeltExhaustionError is raised only if a batch exhausts its retries
job = Job(
    prompt="Classify each company by industry sector.",
    output_model=IndustryClassification,
    batch_size=10,
    max_retries=3,
    stop_on_exhaustion=True,
)

try:
    result = await job.arun(model, data=companies)
    show_result("Classification (stop_on_exhaustion=True)", result)
    print("No exception raised — all batches succeeded.")
except SmeltExhaustionError as e:
    print(f"SmeltExhaustionError: {e}")
    print(f"Partial results: {len(e.partial_result.data)} rows succeeded")
    print(f"Errors: {len(e.partial_result.errors)} batches failed")
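Conceptually, the two modes differ only in what happens when a batch runs out of retries. The sketch below models that in plain Python. `Outcome`, `ExhaustionError`, and `run_with_policy` are hypothetical stand-ins for `SmeltResult`, `SmeltExhaustionError`, and smelt's internals. Two details are assumptions, not documented smelt behavior: that `max_retries=N` allows N + 1 attempts, and that the error is raised as soon as the first batch is exhausted rather than after all batches finish.

```python
from dataclasses import dataclass, field


@dataclass
class Outcome:  # hypothetical stand-in for SmeltResult
    data: list = field(default_factory=list)
    errors: list = field(default_factory=list)

    @property
    def success(self) -> bool:
        return not self.errors


class ExhaustionError(Exception):  # hypothetical stand-in for SmeltExhaustionError
    def __init__(self, partial_result: Outcome):
        super().__init__(f"{len(partial_result.errors)} batch(es) exhausted their retries")
        self.partial_result = partial_result


def run_with_policy(batches, process, max_retries: int, stop_on_exhaustion: bool) -> Outcome:
    """Try each batch up to 1 + max_retries times; collect or raise on exhaustion."""
    outcome = Outcome()
    for index, batch in enumerate(batches):
        for attempt in range(1, max_retries + 2):
            try:
                outcome.data.extend(process(batch))
                break  # batch succeeded, move on
            except ValueError as exc:  # assumed retryable failure, e.g. invalid output
                last_error = exc
        else:  # loop finished without break: retries exhausted
            outcome.errors.append((index, type(last_error).__name__, attempt))
            if stop_on_exhaustion:
                raise ExhaustionError(outcome)
    return outcome


def process(batch):
    if "bad" in batch:
        raise ValueError("output failed validation")
    return [row.upper() for row in batch]


collected = run_with_policy([["a", "b"], ["bad"], ["c"]], process, max_retries=2, stop_on_exhaustion=False)
print(collected.success, collected.data, collected.errors)
# False ['A', 'B', 'C'] [(1, 'ValueError', 3)]
```

With `stop_on_exhaustion=True`, the same call raises at the second batch, and `partial_result.data` holds only `['A', 'B']`.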

---
## 8. Company Summaries — Full Dataset

Full run with multi-batch concurrency and detailed metrics.

In [None]:
job = Job(
    prompt="Create a concise structured summary for each company. "
    "Calculate the approximate age based on the founded year (current year is 2026).",
    output_model=CompanySummary,
    batch_size=4,
    concurrency=3,
    stop_on_exhaustion=False,
)

result = await job.arun(model, data=companies)
show_result("Company Summaries (batch=4, conc=3)", result)

print("\nMetrics breakdown:")
print(f"  Total batches: {result.metrics.total_batches}")
print(f"  Input tokens:  {result.metrics.input_tokens:,}")
print(f"  Output tokens: {result.metrics.output_tokens:,}")
print(f"  Wall time:     {result.metrics.wall_time_seconds:.2f}s")
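The raw totals are easier to compare across configurations when they are normalized per row and per second. The sketch below does that arithmetic. `derived_metrics` is a hypothetical helper, and the example totals are made up.

```python
def derived_metrics(
    input_tokens: int, output_tokens: int, rows: int, wall_time_seconds: float
) -> dict[str, float]:
    """Per-row token usage and throughput from the raw run totals."""
    if rows <= 0 or wall_time_seconds <= 0:
        raise ValueError("rows and wall_time_seconds must be positive")
    return {
        "input_tokens_per_row": input_tokens / rows,
        "output_tokens_per_row": output_tokens / rows,
        "rows_per_second": rows / wall_time_seconds,
    }


# Made-up totals for illustration only
print(derived_metrics(input_tokens=4_000, output_tokens=1_500, rows=10, wall_time_seconds=5.0))
# {'input_tokens_per_row': 400.0, 'output_tokens_per_row': 150.0, 'rows_per_second': 2.0}
```

For a real run, pass the same attributes `show_result` reads: with `m = result.metrics`, call `derived_metrics(m.input_tokens, m.output_tokens, m.successful_rows, m.wall_time_seconds)`.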

---
## Summary

All tests complete. Smelt successfully:
- Validates setup with a quick single-row test (`job.atest()`)
- Transforms structured data through Gemini 3 Flash
- Returns strictly typed Pydantic models
- Handles batching and concurrency
- Provides detailed metrics (tokens, timing, retries)
- Gracefully handles errors with `stop_on_exhaustion`

> **Note:** Jupyter runs its own event loop and supports top-level `await`, so all cells use `await job.arun()` / `await job.atest()`.
> In regular Python scripts, use the synchronous `job.run()` / `job.test()`.
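The split exists because `asyncio.run()` starts a new event loop and raises `RuntimeError` when called while a loop is already running, as it is inside Jupyter. The sketch below shows a sync wrapper that makes this explicit. `fake_arun` and `run_sync` are hypothetical and say nothing about how `job.run()` is actually implemented.

```python
import asyncio


async def fake_arun(rows: list[str]) -> list[str]:
    """Hypothetical stand-in for an async API such as job.arun()."""
    await asyncio.sleep(0)  # pretend to do network I/O
    return [row.upper() for row in rows]


def run_sync(coro_fn, *args):
    """Run an async function from synchronous code, refusing inside a running loop."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running (a plain script): asyncio.run creates one and closes it afterwards.
        return asyncio.run(coro_fn(*args))
    raise RuntimeError("An event loop is already running (e.g. Jupyter); use `await` instead.")
```

In a script, `run_sync(fake_arun, ["a", "b"])` returns `['A', 'B']`. In a notebook cell, the same call raises `RuntimeError`, and `await fake_arun(["a", "b"])` is the equivalent.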