In [1]:
import logging
import os
from datetime import datetime

import pandas as pd
from dotenv import load_dotenv
from pydantic import BaseModel
from tqdm import tqdm


logging.basicConfig(level=logging.WARNING)
load_dotenv()
df = pd.read_excel("Wedding Values.xlsx")
df.columns

Index(['City', 'Zip Code ', 'State', 'Country ', 'Email', 'Phone Number',
       'Price', 'Price Breakdown', 'Menu Breakdown', 'Bar Breakdown',
       'Groom and Bridal Set-Up', 'Ceremony Cost ', 'Guest Capacity ',
       'Outside Food', 'Outside Alcohol', 'Outside Dessert ',
       'Outside Wedding Coordinator', 'Outside Photographer ',
       'Package Approach', 'Pricing Transparency ', 'Reception or Ceremony',
       'Style', 'Indoor/Outdoor', 'Deposit and Payment Plans ', 'Privacy',
       'Accommodations ', 'Photography Score ', 'Environmental ',
       'What Time Does the Party Need to Stop', 'Late Night Eats ',
       'General Vibe', 'Top Choices ', 'Menu Choices '],
      dtype='object')

In [2]:
from wedding_venue_models import *

In [3]:
models = [
    WeddingContactInfo,
    FoodBreakdown,
    WeddingFoodInfo,
    BarBreakdown,
    WeddingVenuePricingSummary,
    WeddingVenueStyle,
    WeddingVenueOther,
]

In [4]:
from itertools import chain
from typing import get_args

import numpy as np
from openpyxl.styles import Font, PatternFill


def assert_keys_in_readable_columns(
    models: list[type[BaseModel]], readable_columns: dict[str, str]
) -> None:
    keys = list(
        chain.from_iterable(
            [
                [
                    model.__name__ + "_" + key
                    for key in np.array(list(model.model_fields.keys()))
                ]
                for model in models
            ]
        )
    )
    keys = [
        str(key).replace("_tiers", "_summary").replace("_options", "_summary")
        for key in keys
    ]
    keys
    assert set(keys) - set(readable_columns.keys()) == set(), (
        f"missing keys in readable_columns: {set(keys) - set(readable_columns.keys())}"
    )


readable_columns = {
    "venue": "wedding venue",
    "WeddingVenuePricingSummary_price": "price per guest",
    "WeddingVenuePricingSummary_base_prices": "price breakdown",
    "WeddingVenuePricingSummary_taxes_and_fees": "price breakdown taxes and fees",
    "WeddingVenuePricingSummary_flexibility": "venue customization flexibility",
    # "WeddingPriceInfo_option": "options",
    "WeddingContactInfo_city": "city",
    "WeddingContactInfo_state": "state",
    "WeddingContactInfo_country": "country",
    "WeddingContactInfo_zip_code": "zip code",
    "WeddingContactInfo_email": "email",
    "WeddingContactInfo_website": "website",
    "WeddingContactInfo_phone": "phone",
    "WeddingContactInfo_facebook": "facebook",
    "WeddingContactInfo_instagram": "instagram",
    # "WeddingVenuePricingSummary_summary": "venue pricing summary",
    "FoodBreakdown_summary": "food menu breakdown",
    "FoodBreakdown_flexibility": "food menu flexibility",
    "BarBreakdown_summary": "bar menu breakdown",
    "BarBreakdown_flexibility": "bar menu flexibility",
    "WeddingVenuePricingSummary_pricing_transparency": "pricing transparency",
    "WeddingVenuePricingSummary_deposit_and_payment_plans": "deposit and payment plans",
    "WeddingVenueStyle_style": "style",
    "WeddingVenueStyle_indoor_outdoor": "indoor/outdoor seating",
    "WeddingVenueStyle_privacy": "privacy",
    "WeddingVenueStyle_accommodations": "accommodations",
    "WeddingVenueStyle_environmental": "environmental",
    "WeddingVenueStyle_general_vibe": "general vibe",
    "WeddingFoodInfo_east_asian_food": "serves east asian food",
    "WeddingFoodInfo_gluten_free_food": "serves gluten free food",
    "WeddingFoodInfo_halal_food": "serves halal food",
    "WeddingFoodInfo_indian_food": "serves indian food",
    "WeddingFoodInfo_kosher_food": "serves kosher food",
    "WeddingFoodInfo_late_night_food": "serves late night food",
    "WeddingFoodInfo_other_ethnic_food_style": "serves other ethnic food",
    "WeddingFoodInfo_outside_alcohol_allowed": "allows outside alcohol",
    "WeddingFoodInfo_outside_dessert_allowed": "allows outside dessert",
    "WeddingFoodInfo_outside_food_allowed": "allows outside food",
    "WeddingVenueOther_guest_capacity": "guest capacity",
    "WeddingVenueOther_what_time_does_the_party_need_to_stop": "what time does the party need to stop",
    "WeddingVenueOther_outside_photographer": "allows outside photographer",
    "WeddingVenueOther_package_approach": "package approach",
    "WeddingVenueOther_outside_wedding_coordinator": "allows outside wedding coordinator",
    "WeddingVenueOther_reception_or_ceremony": "reception or ceremony",
    "WeddingVenueOther_top_choices": "top choices",
}


assert_keys_in_readable_columns(models, readable_columns)


def flatten_dict(d: dict, parent_key: str = "", sep: str = "_") -> dict:
    """Flatten a nested dictionary by concatenating nested keys with a separator.

    Parameters
    ----------
    d : dict
        The dictionary to flatten
    parent_key : str, optional
        The parent key for nested dictionaries, by default ""
    sep : str, optional
        The separator to use between nested keys, by default "_"

    Returns
    -------
    dict
        A flattened dictionary with concatenated keys

    Examples
    --------
    >>> d = {"a": 1, "b": {"c": 2, "d": {"e": 3}}}
    >>> flatten_dict(d)
    {'a': 1, 'b_c': 2, 'b_d_e': 3}
    """
    items: list = []
    for k, v in d.items():
        new_key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, dict):
            items.extend(flatten_dict(v, new_key, sep).items())
        else:
            items.append((new_key, v))
    return dict(items)


class WeddingVenue:
    def __init__(self, venue_name: str, raw: list[BaseModel]):
        item_dict = {"venue": venue_name}
        for item in raw:
            obj_dict = item.model_dump()
            if "tiers" in obj_dict:
                obj_dict.pop("tiers")
                obj_dict["summary"] = item.to_string()

            if "options" in obj_dict:
                obj_dict.pop("options")
                obj_dict["summary"] = item.to_string()

            item_dict[item.__class__.__name__] = obj_dict

        self.df = pd.DataFrame()
        self.update(item_dict)

    def add_price_breakdown(self) -> None:
        self.df["price breakdown"] = self.df[
            [
                "price breakdown",
                "price breakdown taxes and fees",
            ]
        ].apply(
            lambda x: f"""
                base prices: {x.iloc[0]}
                taxes and fees: {x.iloc[1]}
                """,
            axis=1,
        )
        del self.df["price breakdown taxes and fees"]

    def add_bar_flexibility(self) -> None:
        self.df["bar menu flexibility info"] = self.df["bar menu flexibility"]
        args = get_args(BarBreakdown.model_fields["flexibility"].annotation)
        self.df["bar menu flexibility"] = self.df["bar menu flexibility"].map(
            lambda x: len(args) - args.index(x)
        )

    def add_indoor_outdoor_seating(self) -> None:
        self.df["indoor/outdoor seating info"] = self.df["indoor/outdoor seating"]
        args = get_args(WeddingVenueStyle.model_fields["indoor_outdoor"].annotation)
        self.df["indoor/outdoor seating"] = self.df["indoor/outdoor seating"].map(
            lambda x: args.index(x) if args.index(x) != len(args) else "X"
        )

    def update(self, d: dict) -> None:
        self.df = pd.DataFrame(flatten_dict(d), index=[0])
        self.rename_columns()
        self.add_price_breakdown()
        self.add_bar_flexibility()
        self.add_indoor_outdoor_seating()

    def _repr_html_(self) -> str:
        return self.df._repr_html_()

    def rename_columns(self) -> None:
        """Rename and reorder columns based on readable_columns dictionary."""
        self.df.rename(columns=readable_columns, inplace=True)
        self.df.set_index("wedding venue", inplace=True)
        ordered_columns = [
            col for col in readable_columns.values() if col != "wedding venue"
        ]
        self.df = self.df.reindex(columns=ordered_columns)

    def to_excel(self, name: str = "wedding_venue.xlsx"):
        if not name.endswith(".xlsx"):
            name = f"{name}.xlsx"
        if os.path.exists(name):
            name = f"{name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.xlsx"
        print("saving to: ", name)
        with pd.ExcelWriter(name, engine="openpyxl") as writer:
            self.df.to_excel(writer, sheet_name="Venue Options")

            worksheet = writer.sheets["Venue Options"]

            header_fill = PatternFill(
                start_color="B3E5FC", end_color="B3E5FC", fill_type="solid"
            )
            header_font = Font(bold=True)

            for col in range(1, len(df.columns) + 1):
                cell = worksheet.cell(row=1, column=col)
                cell.fill = header_fill
                cell.font = header_font

            for col in worksheet.columns:
                max_length = 0
                column = col[0].column_letter
                for cell in col:
                    if cell.value:
                        max_length = max(max_length, len(str(cell.value)))
                adjusted_width = max_length + 2
                worksheet.column_dimensions[column].width = min(adjusted_width, 50)

            worksheet.auto_filter.ref = worksheet.dimensions

        self.df.to_excel(writer, sheet_name="Venue Options")
        return self

    def __add__(self, other: "WeddingVenue") -> "WeddingVenue":
        self.df = pd.concat([self.df, other.df])
        return self

In [5]:
# Setup
from pathlib import Path

import openai
from openai import OpenAI
from google import genai
from typing import Literal


class Response:
    def __init__(self, ai: Literal["openai", "google"]):
        self.ai = ai
        if ai == "openai":
            self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
            self.response = self._response_openai
        if ai == "google":
            self.client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))
            self.response = self._response_google

    def _response_openai(
        self, model, system_prompt, user_prompt, response_format, temperature
    ):
        client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        completion = client.beta.chat.completions.parse(
            model=model,
            messages=[
                {"role": "system", "content": system_prompt},
                {
                    "role": "user",
                    "content": user_prompt,
                },
            ],
            response_format=response_format,
            temperature=temperature,
        )
        return completion.choices[0].message.parsed

    def _response_google(
        self, model, system_prompt, user_prompt, response_format, temperature
    ):
        client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))
        response = client.models.generate_content(
            model=model,
            contents=f"{system_prompt}\n{user_prompt}",
            config={
                "response_mime_type": "application/json",
                "response_schema": response_format,
            },
        )
        return response.parsed


load_dotenv()

venue_data = []
md_path = Path("test_md")
if not md_path.exists():
    raise FileNotFoundError("Directory 'test_md' not found")

md_files = list(md_path.glob("*.md"))[-5:]

# openAI
# response = Response(ai="openai")
# gemini
response = Response(ai="google")

venue_data = None
for file in tqdm(md_files, desc="Processing venues", unit="file"):
    tqdm.write(f"Processing: {file.name}")
    with open(file, "r", encoding="utf-8") as f:
        md_content = f.read()

    raw = []
    venue_name = file.stem
    venue_dict = {"name": venue_name}

    for model_class in models:
        system_prompt = create_system_prompt(model_class)
        if response.ai == "openai":
            if model_class == WeddingVenuePricingSummary:
                ai_model = "o3-mini"
                temperature = openai.NOT_GIVEN
            else:
                ai_model = "gpt-4o-mini"
                temperature = 0
            obj = response.response(
                model=ai_model,
                system_prompt=system_prompt,
                user_prompt=f"Extract venue information from this text about '{venue_name}':\n\n{md_content}",
                response_format=model_class,
                temperature=temperature,
            )
        elif response.ai == "google":
            temperature = 0
            ai_model = "gemini-2.0-flash-001"
            obj = response.response(
                model=ai_model,
                system_prompt=system_prompt,
                user_prompt=f"Extract venue information from this text about '{venue_name}':\n\n{md_content}",
                response_format=model_class,
                temperature=temperature,
            )
        raw.append(obj)
        if hasattr(obj, "to_string"):
            string_summary = obj.to_string()
            venue_dict[f"{model_class.__name__}_summary"] = string_summary
        else:
            venue_dict[f"{model_class.__name__}_summary"] = obj.model_dump()

        tqdm.write(f"✓ Processed {model_class.__name__} for: {venue_name}")

    if venue_data is None:
        venue_data = WeddingVenue(venue_name, raw)
    else:
        try:
            venue_data += WeddingVenue(venue_name, raw)
        except Exception as e:
            print(f"✗ Error adding {venue_name}: {e}")
            # venue_data = None
now = datetime.now().strftime("%Y%m%d%")
if venue_data is not None:
    venue_data.to_excel(f"all_info_{now}.xlsx")
else:
    print("⚠️ No venue data processed.")

Processing venues:   0%|          | 0/5 [00:00<?, ?file/s]

Processing: Aliso Viejo Wedgewood.md


Processing venues:   0%|          | 0/5 [00:02<?, ?file/s]

✓ Processed WeddingContactInfo for: Aliso Viejo Wedgewood


Processing venues:   0%|          | 0/5 [00:05<?, ?file/s]

✓ Processed FoodBreakdown for: Aliso Viejo Wedgewood


Processing venues:   0%|          | 0/5 [00:07<?, ?file/s]

✓ Processed WeddingFoodInfo for: Aliso Viejo Wedgewood


Processing venues:   0%|          | 0/5 [00:09<?, ?file/s]

✓ Processed BarBreakdown for: Aliso Viejo Wedgewood


Processing venues:   0%|          | 0/5 [00:11<?, ?file/s]

✓ Processed WeddingVenuePricingSummary for: Aliso Viejo Wedgewood


Processing venues:   0%|          | 0/5 [00:13<?, ?file/s]

✓ Processed WeddingVenueStyle for: Aliso Viejo Wedgewood


Processing venues:  20%|██        | 1/5 [00:14<00:59, 14.80s/file]

✓ Processed WeddingVenueOther for: Aliso Viejo Wedgewood
Processing: Alcazar Palm Springs.md


Processing venues:  20%|██        | 1/5 [00:15<00:59, 14.80s/file]

✓ Processed WeddingContactInfo for: Alcazar Palm Springs


Processing venues:  20%|██        | 1/5 [00:17<00:59, 14.80s/file]

✓ Processed FoodBreakdown for: Alcazar Palm Springs


Processing venues:  20%|██        | 1/5 [00:18<00:59, 14.80s/file]

✓ Processed WeddingFoodInfo for: Alcazar Palm Springs


Processing venues:  20%|██        | 1/5 [00:20<00:59, 14.80s/file]

✓ Processed BarBreakdown for: Alcazar Palm Springs


Processing venues:  20%|██        | 1/5 [00:23<00:59, 14.80s/file]

✓ Processed WeddingVenuePricingSummary for: Alcazar Palm Springs


Processing venues:  20%|██        | 1/5 [00:25<00:59, 14.80s/file]

✓ Processed WeddingVenueStyle for: Alcazar Palm Springs


Processing venues:  40%|████      | 2/5 [00:26<00:39, 13.10s/file]

✓ Processed WeddingVenueOther for: Alcazar Palm Springs
Processing: Almansor Court.md


Processing venues:  40%|████      | 2/5 [00:28<00:39, 13.10s/file]

✓ Processed WeddingContactInfo for: Almansor Court


Processing venues:  40%|████      | 2/5 [00:28<00:42, 14.25s/file]


ClientError: 429 RESOURCE_EXHAUSTED. {'error': {'code': 429, 'message': 'You exceeded your current quota, please check your plan and billing details. For more information on this error, head to: https://ai.google.dev/gemini-api/docs/rate-limits.', 'status': 'RESOURCE_EXHAUSTED', 'details': [{'@type': 'type.googleapis.com/google.rpc.QuotaFailure', 'violations': [{'quotaMetric': 'generativelanguage.googleapis.com/generate_content_free_tier_requests', 'quotaId': 'GenerateRequestsPerMinutePerProjectPerModel-FreeTier', 'quotaDimensions': {'model': 'gemini-2.0-flash', 'location': 'global'}, 'quotaValue': '15'}]}, {'@type': 'type.googleapis.com/google.rpc.Help', 'links': [{'description': 'Learn more about Gemini API quotas', 'url': 'https://ai.google.dev/gemini-api/docs/rate-limits'}]}, {'@type': 'type.googleapis.com/google.rpc.RetryInfo', 'retryDelay': '32s'}]}}

In [8]:
model_class

wedding_venue_models.WeddingVenueOther