In [None]:
# super7_vertex.py

import json
import os
import re
from typing import Dict, Any, List, Optional

import pandas as pd
from google import genai
from google.genai import types


# ========= CONFIGURATION =========

# Deployment settings. Each one can be overridden from the environment so the module runs
# without editing source; GOOGLE_CLOUD_PROJECT / GOOGLE_CLOUD_LOCATION are the variable
# names Google Cloud tooling conventionally uses. Unset or empty variables fall back to
# the literal defaults below.
PROJECT_ID = os.environ.get("GOOGLE_CLOUD_PROJECT") or "YOUR_GCP_PROJECT_ID"
LOCATION = os.environ.get("GOOGLE_CLOUD_LOCATION") or "us-east4"
MODEL_NAME = os.environ.get("SUPER7_MODEL_NAME") or "gemini-2.0-flash"   # adjust if needed

# Columns you EXPECT to see in the CSV (Super-7 style), in the order they are written into
# the prompt. Other CSV columns are ignored when building the prompt (add them here to use
# them), although they are still carried through in each result's "input" record.
SUPER7_INPUT_COLS = [
    "company_name",
    "country",
    "state_province",
    "city",
    "street_address",
    "postal_code",
    "phone_number",
    "additional_info",   # optional catch-all column
]

# Extracted output fields that get a per-field score and are flattened into the CSV output
# (the order here is the column order of the flat CSV).
SUPER7_FIELDS = [
    "company_name",
    "trade_style_name",
    "country",
    "street_address",
    "postal_code",
    "city",
    "state_province",
    "website",
    "phone_number",
]

# Multiplier applied to base_model_confidence according to a field's reported source_type.
# Unknown source types are scored with the "other" weight.
SOURCE_TYPE_WEIGHTS = {
    "government_registry": 1.0,
    "official_website": 0.9,
    "business_directory": 0.75,
    "other": 0.6,
}


# ========= CLIENT SETUP =========

def init_gemini_client() -> genai.Client:
    """
    Create a ``genai.Client`` that talks to Vertex AI.

    The client is bound to the module-level PROJECT_ID and LOCATION. It does not
    enable any tools by itself; tools such as Google Search grounding are supplied
    per request through the generation config.
    """
    return genai.Client(vertexai=True, project=PROJECT_ID, location=LOCATION)


# ========= PROMPT BUILDING =========

def build_company_context(row: Dict[str, Any]) -> str:
    """
    Build the "column: value" context block for the prompt from one input row.

    Only columns listed in SUPER7_INPUT_COLS are used, in that order. Cells that are
    missing (None/NaN) or blank after stripping are skipped. Kept values have their
    whitespace collapsed, so embedded newlines and tabs become single spaces and each
    field stays on exactly one "column: value" line.

    Args:
        row: Mapping of column name -> cell value (e.g. the result of ``Series.to_dict()``).

    Returns:
        The newline-joined "column: value" lines. If the row has no usable input field,
        a short placeholder line is returned instead of an empty block, so the prompt
        never contains a bare "nan" value.
    """
    lines = []
    for col in SUPER7_INPUT_COLS:
        if col not in row:
            continue
        value = row[col]
        if not pd.notna(value):
            continue
        # str.split() with no argument drops leading/trailing whitespace and collapses
        # internal runs (including newlines) that would otherwise break the line format.
        text = " ".join(str(value).split())
        if text:
            lines.append(f"{col}: {text}")
    return "\n".join(lines) if lines else "(no usable input fields were provided)"


def build_super7_prompt(row: Dict[str, Any]) -> str:
    """
    Build the grounded-search extraction prompt for one company row.

    The row's usable input fields (from ``build_company_context``) are embedded as
    hints. The model is asked to verify them with the Google Search tool and to return
    a single-element JSON array. That array holds:

    - the extracted fields;
    - per-field provenance in ``field_metadata``, covering every extracted field
      except ``reference_url``;
    - an overall ``base_model_confidence``.

    Args:
        row: Mapping of CSV column name -> cell value.

    Returns:
        The prompt text with surrounding whitespace stripped.
    """
    company_context = build_company_context(row)

    # NOTE: this is an f-string, so literal JSON braces in the template are doubled.
    prompt = f"""
You are a highly reliable company information extractor that uses the Google Search tool
to open the most relevant and trustworthy pages.

You are given structured input fields for a company. Use them ONLY as hints.
You MUST verify everything using grounded search results.

# Input company context (from CSV)
{company_context}

# Task
1. Use the Google Search tool to:
   - Identify the official website of the company (if it exists).
   - Identify any government or business-registry pages about this entity.
   - Optionally, look at high-quality business directories if they help confirm details.

2. From those sources, extract the following fields whenever possible:
   - company_name             : Official legal or primary operating name
   - trade_style_name         : Trading name / doing-business-as name (if any)
   - country
   - street_address
   - postal_code
   - city
   - state_province
   - website                  : Full URL of the official website
   - phone_number             : Primary business phone
   - line_of_business         : 1–3 sentence description of what the company does
   - reference_url            : Up to 5 key URLs you used

3. For EACH extracted field above except reference_url, you MUST also provide
   (inside "field_metadata", keyed by the field name):
   - source_url   : The single best page where you found or confirmed this field.
   - source_type  : One of:
                    "government_registry",
                    "official_website",
                    "business_directory",
                    "other"
   - comment      : Short explanation of how or where you inferred the value
                    (e.g., "From contact page of official website", "From national business registry", etc.)

4. If you cannot find a field even after checking multiple sources:
   - Set the value for that field to null.
   - Still include metadata with a comment "not found" and source_type "other".

5. Confidence:
   - base_model_confidence: a numeric value between 0 and 1 representing how confident you are overall.

# Output format (JSON ONLY)
Return the result STRICTLY as a single JSON array with ONE object, like:

[
  {{
    "company_name": "...",
    "trade_style_name": "...",
    "country": "...",
    "street_address": "...",
    "postal_code": "...",
    "city": "...",
    "state_province": "...",
    "website": "...",
    "phone_number": "...",
    "line_of_business": "...",
    "reference_url": ["...", "..."],
    "field_metadata": {{
      "company_name": {{
        "source_url": "...",
        "source_type": "government_registry | official_website | business_directory | other",
        "comment": "..."
      }},
      "trade_style_name": {{
        "source_url": "...",
        "source_type": "...",
        "comment": "..."
      }},
      "country": {{
        "source_url": "...",
        "source_type": "...",
        "comment": "..."
      }},
      "street_address": {{
        "source_url": "...",
        "source_type": "...",
        "comment": "..."
      }},
      "postal_code": {{
        "source_url": "...",
        "source_type": "...",
        "comment": "..."
      }},
      "city": {{
        "source_url": "...",
        "source_type": "...",
        "comment": "..."
      }},
      "state_province": {{
        "source_url": "...",
        "source_type": "...",
        "comment": "..."
      }},
      "website": {{
        "source_url": "...",
        "source_type": "...",
        "comment": "..."
      }},
      "phone_number": {{
        "source_url": "...",
        "source_type": "...",
        "comment": "..."
      }},
      "line_of_business": {{
        "source_url": "...",
        "source_type": "...",
        "comment": "..."
      }}
    }},
    "base_model_confidence": 0.0
  }}
]

Rules:
- JSON only. Do NOT include explanations outside of the JSON.
- Do NOT wrap the JSON in markdown (no ```json```).
- If a field is missing or not found, set its value to null but still include metadata.
"""
    return prompt.strip()


# ========= MODEL CALL + PARSING =========

def clean_json_from_text(raw_text: str) -> str:
    """
    Isolate the JSON payload inside a model reply and return it as a string.

    Steps:
      1. Strip surrounding whitespace, a leading ```json / ``` fence and a trailing ```
         fence.
      2. Scan the "[" / "{" positions left to right for the first complete JSON value
         that is an object, or an array whose first element is an object, and return
         exactly that value's text. Any prose after the JSON is dropped. A complete but
         unusable value (e.g. a "[1]" citation marker in a preamble) is skipped, and the
         scan resumes after it.
      3. If the scan reaches text that is not valid JSON (e.g. truncated output), stop
         rather than digging into its interior, which could yield a misleading nested
         fragment. Fall back to returning everything from the first "[" or "{" (or the
         whole text when there is none), so ``json.loads`` on the result fails loudly.

    Args:
        raw_text: Raw model output. ``None`` is treated as an empty string.

    Returns:
        The best-guess JSON substring.
    """
    text = (raw_text or "").strip()

    # Strip common ```json fences
    text = re.sub(r"^```json", "", text, flags=re.IGNORECASE).strip()
    text = re.sub(r"^```", "", text).strip()
    text = re.sub(r"```$", "", text).strip()

    opener = re.compile(r"[\[{]")
    decoder = json.JSONDecoder()

    first = opener.search(text)
    match = first
    while match is not None:
        start = match.start()
        try:
            # raw_decode parses one JSON value starting at `start` and reports where it
            # ends, ignoring whatever text follows.
            value, end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            break
        if isinstance(value, dict) or (
            isinstance(value, list) and value and isinstance(value[0], dict)
        ):
            return text[start:end]
        match = opener.search(text, end)

    return text[first.start():] if first is not None else text


def call_gemini_super7(
    client: genai.Client,
    prompt: str
) -> Optional[Dict[str, Any]]:
    """
    Call Gemini with the Google Search tool enabled and parse the Super-7 JSON reply.

    Args:
        client: A Vertex AI-backed ``genai.Client``.
        prompt: The full extraction prompt.

    Returns:
        The result object as a dict. This is the first element when the model returned
        the requested array, or the object itself if it returned a bare object. Returns
        ``None`` in these cases, each of which prints a diagnostic:

        - the call failed;
        - the reply was not decodable JSON;
        - the JSON was not an object / array-of-objects.
    """
    grounding_tool = types.Tool(
        google_search=types.GoogleSearch()
    )

    config = types.GenerateContentConfig(
        temperature=0.0,
        tools=[grounding_tool],
    )

    raw_text = ""
    try:
        response = client.models.generate_content(
            model=MODEL_NAME,
            contents=prompt,
            config=config,
        )
        raw_text = response.text or ""
        data = json.loads(clean_json_from_text(raw_text))
    except json.JSONDecodeError as e:
        print("Failed to decode JSON from model:", e)
        print("Raw text was:\n", raw_text)
        return None
    except Exception as e:
        # Best-effort per row: any SDK/network/quota error becomes None so the batch can
        # continue and the caller records an error row for this company.
        print("Error calling Gemini:", e)
        return None

    # A single-element array is requested, but a bare object is tolerated too. Anything
    # else (e.g. [1] or a string) is rejected here rather than crashing the scorer later.
    candidate = data[0] if isinstance(data, list) and data else data
    if isinstance(candidate, dict):
        return candidate
    print("Unexpected JSON format from model:", data)
    return None


# ========= SCORING =========

def compute_field_score(
    field_name: str,
    obj: Dict[str, Any]
) -> float:
    """
    Score one extracted field.

    The score is ``base_model_confidence * SOURCE_TYPE_WEIGHTS[source_type]``, rounded
    to 4 decimals.

    Normalisation of the model's input:

    - ``source_type`` is taken from ``obj["field_metadata"][field_name]``, stripped
      and lower-cased. Unknown or missing types use the "other" weight.
    - ``base_model_confidence`` is converted to float. Non-numeric values count as
      0.0, negatives and NaN become 0.0, and values above 1 are clamped to 1.0.

    Malformed ``field_metadata`` entries are treated as empty.

    Args:
        field_name: Name of the field to score (normally one of SUPER7_FIELDS).
        obj: Parsed model result.

    Returns:
        The field score, within [0, max weight].
    """
    all_meta = obj.get("field_metadata")
    meta = all_meta.get(field_name) if isinstance(all_meta, dict) else None
    if not isinstance(meta, dict):
        meta = {}

    source_type = str(meta.get("source_type") or "other").strip().lower()
    weight = SOURCE_TYPE_WEIGHTS.get(source_type, SOURCE_TYPE_WEIGHTS["other"])

    try:
        base_conf = float(obj.get("base_model_confidence") or 0.0)
    except (TypeError, ValueError):
        # e.g. the model answered "high" instead of a number
        base_conf = 0.0
    if not base_conf >= 0.0:  # negative, or NaN (all comparisons with NaN are False)
        base_conf = 0.0
    elif base_conf > 1.0:
        base_conf = 1.0

    # Simple multiplicative scoring; you can make this more complex later.
    # NOTE(review): a field whose value is null ("not found") still scores
    # base_conf * weight("other"); consider scoring it 0 if that inflates overall scores.
    return round(base_conf * weight, 4)


def attach_scores(obj: Dict[str, Any]) -> Dict[str, Any]:
    """
    Add ``field_scores`` and ``overall_confidence_score`` to *obj* in place.

    ``field_scores`` maps every name in SUPER7_FIELDS to ``compute_field_score`` for
    that field. ``overall_confidence_score`` is the arithmetic mean of those scores
    rounded to 4 decimals, or 0.0 when there are no fields. The same dict is returned.
    """
    per_field = {name: compute_field_score(name, obj) for name in SUPER7_FIELDS}
    mean = sum(per_field.values()) / len(per_field) if per_field else 0.0

    obj["field_scores"] = per_field
    obj["overall_confidence_score"] = round(mean, 4)
    return obj


# ========= CSV PIPELINE =========

def process_company_row(
    client: genai.Client,
    row: pd.Series
) -> Dict[str, Any]:
    """
    Process one CSV row end to end: build the prompt, call Gemini, score the result.

    The original row is attached under ``"input"`` for traceability. Missing cells are
    recorded as ``None`` rather than NaN, because ``json.dumps`` would otherwise emit a
    bare ``NaN`` token, which is not valid JSON.

    Args:
        client: A Vertex AI-backed ``genai.Client``.
        row: One row of the input DataFrame.

    Returns:
        The scored result dict with an ``"input"`` key. If the model call failed or did
        not yield a JSON object, returns a minimal ``{"input", "error"}`` dict instead.
    """
    row_dict = {
        key: (None if pd.api.types.is_scalar(value) and pd.isna(value) else value)
        for key, value in row.to_dict().items()
    }
    prompt = build_super7_prompt(row_dict)
    result = call_gemini_super7(client, prompt)

    if not isinstance(result, dict):
        # Hard failure (no response, invalid JSON, or an unexpected JSON shape): return
        # a minimal object so the batch continues and the row stays traceable.
        return {
            "input": row_dict,
            "error": "model_failed_or_invalid_json"
        }

    result = attach_scores(result)
    result["input"] = row_dict
    return result


def _join_reference_urls(urls: Any) -> str:
    """Render the model's ``reference_url`` value (list, single string, or None) as one ';'-joined string."""
    if not urls:
        return ""
    if isinstance(urls, str):
        # A single URL string: joining it directly would put ';' between every character.
        return urls
    if isinstance(urls, (list, tuple)):
        return ";".join(str(u) for u in urls if u)
    return str(urls)


def _flatten_result_for_csv(r: Dict[str, Any]) -> Dict[str, Any]:
    """Flatten one result dict into a single flat-CSV row (input_*, fields, metadata, scores)."""
    flat = {f"input_{k}": v for k, v in (r.get("input") or {}).items()}
    if "error" in r:
        flat["error"] = r["error"]
        return flat

    all_meta = r.get("field_metadata")
    if not isinstance(all_meta, dict):
        all_meta = {}
    scores = r.get("field_scores") or {}

    # core fields, each followed by its provenance columns and score
    for f in SUPER7_FIELDS + ["line_of_business"]:
        flat[f] = r.get(f)

        meta = all_meta.get(f)
        if not isinstance(meta, dict):
            meta = {}
        flat[f"{f}_source_url"] = meta.get("source_url")
        flat[f"{f}_source_type"] = meta.get("source_type")
        flat[f"{f}_comment"] = meta.get("comment")
        flat[f"{f}_score"] = scores.get(f)

    flat["reference_url"] = _join_reference_urls(r.get("reference_url"))
    flat["base_model_confidence"] = r.get("base_model_confidence")
    flat["overall_confidence_score"] = r.get("overall_confidence_score")
    return flat


def process_csv(
    input_csv_path: str,
    output_jsonl_path: Optional[str] = None,
    output_csv_path: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """
    Run the grounded Super-7 extractor over every row of a CSV.

    Reads *input_csv_path* and calls ``process_company_row`` for each row, tagging each
    result with its ``row_index``. Optionally writes:

    - *output_jsonl_path*: one JSON object per line, written and flushed as each row
      finishes, so completed rows survive a crash or interruption mid-run.
    - *output_csv_path*: a flat table with input_* columns, the extracted fields, their
      source metadata and scores. Error rows carry only input_* columns plus ``error``.

    Args:
        input_csv_path: CSV with Super-7 style columns (see SUPER7_INPUT_COLS).
        output_jsonl_path: Optional JSONL destination (overwritten).
        output_csv_path: Optional flat CSV destination (overwritten).

    Returns:
        One result dict per input row, in file order.
    """
    client = init_gemini_client()
    # Read every cell as text. Numeric inference would strip leading zeros from postal
    # codes and phone numbers ("02134" -> 2134), and would render them as floats
    # ("2134.0") when the column also has blanks. Blank cells still come back as NaN.
    # NOTE(review): pandas' default NA strings also turn a literal "NA" (e.g. Namibia's
    # country code) into NaN; consider keep_default_na=False if that matters for your data.
    df = pd.read_csv(input_csv_path, dtype=str)

    results: List[Dict[str, Any]] = []
    total = len(df)

    jsonl_file = (
        open(output_jsonl_path, "w", encoding="utf-8") if output_jsonl_path else None
    )
    try:
        for position, (idx, row) in enumerate(df.iterrows(), start=1):
            print(f"Processing row {position}/{total}: {row.get('company_name', '')}")
            res = process_company_row(client, row)
            res["row_index"] = int(idx)
            results.append(res)
            if jsonl_file is not None:
                jsonl_file.write(json.dumps(res, ensure_ascii=False) + "\n")
                jsonl_file.flush()
    finally:
        if jsonl_file is not None:
            jsonl_file.close()

    # Optional: write flat CSV (Super-7 + scores + comments)
    if output_csv_path:
        flat_rows = [_flatten_result_for_csv(r) for r in results]
        pd.DataFrame(flat_rows).to_csv(output_csv_path, index=False)

    return results


# ========= EXAMPLE MAIN (for CLI testing) =========

if __name__ == "__main__":
    # Command-line entry point for quick local runs. Other front ends (e.g. Streamlit)
    # should import and call process_csv / process_company_row directly.
    import argparse

    cli = argparse.ArgumentParser()
    for flag, options in (
        ("--input_csv", {"required": True}),
        ("--out_jsonl", {"default": "super7_output.jsonl"}),
        ("--out_csv", {"default": "super7_output_flat.csv"}),
    ):
        cli.add_argument(flag, **options)
    opts = cli.parse_args()

    process_csv(
        input_csv_path=opts.input_csv,
        output_jsonl_path=opts.out_jsonl,
        output_csv_path=opts.out_csv,
    )

    print("Done.")
