
- liens qui ne sont pas des pdfs -> pointent vers pages ou autres pdfs (https://www.ahv-iv.ch/fr/M%C3%A9mentos/Prestations-de-lAI)

- better table parsing (careful with headers !!!)
- envoyer structured_content à flash pour l'injection de liens
- 2e passe avec flash pour traduire les descriptions d'images ou alt-text pas dans la bonne langue
  
### data augmentation
- extract links and create KG triplets
- hyq/declarative hyq
- extract topic/subtopics/etc.
- maybe do summary here? !!!!!

In [1]:
import os
import re
import json
from typing import List, Dict, Tuple
from dotenv import load_dotenv
from google import genai
from google.genai import types
import pathlib
from pydantic import BaseModel
import glob
import fitz  # PyMuPDF
from itertools import groupby
from operator import itemgetter

# Env variables

In [2]:
# Load environment variables via python-dotenv (by default from a .env file).
load_dotenv()

# os.getenv returns None when GEMINI_API_KEY is not set; the value is handed
# to the client as-is, without any validation at this point.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
client = genai.Client(api_key=GEMINI_API_KEY)

In [3]:
# Language code; used as the sub-directory name for both input and output.
# NOTE(review): "f" presumably stands for French - confirm.
lang = "f"
# Directory the source PDFs are read from.
PDF_PATH = os.path.join("../pdfs", lang)
# Directory the parsed results are written to.
OUTPUT_PATH = os.path.join("../parsed_pdfs", lang)

# VLM parsing

In [4]:
# Prompt sent together with the PDF in the parsing call. Decoded explicitly as
# UTF-8 so the text does not depend on the platform's default encoding.
with open("prompts/parse_pdf_to_text.txt", "r", encoding="utf-8") as f:
    prompt = f.read()

# One entry of StructuredContent.section_content: a header, its text, and the
# page numbers it is associated with.
class SectionContent(BaseModel):
    page_number: List[int]  # a list, so one entry can reference several pages
    header: str
    content: str

# A top-level section: its header plus the list of SectionContent entries
# grouped under it.
class StructuredContent(BaseModel):
    section_header: str
    section_content: List[SectionContent]
    
# Content of a single page, keyed by its page number. The link-injection loop
# looks pages up by comparing page_number with the page numbers derived from
# the PDF (which start at 1).
class PageContent(BaseModel):
    page_number: int
    page_content: str

# Root schema passed as `response_schema` to the parsing call: page-by-page
# content, a sectioned view of the same document, and a summary.
# Documented with comments rather than a docstring on purpose: pydantic copies
# a model docstring into the generated JSON schema, which would change what is
# sent to the model.
class ParsedDocument(BaseModel):
    content: List[PageContent]
    structured_content: List[StructuredContent]
    summary: str

In [5]:
# os.listdir returns entries in arbitrary order; sort so the files are
# processed in the same order on every run.
pdf_files = sorted(f for f in os.listdir(PDF_PATH) if f.endswith(".pdf"))

In [6]:
# Overrides the directory listing: only this single file is processed.
pdf_files = ["2_01_f.pdf"]

In [7]:
# Send every PDF to the model together with the parsing prompt and request
# JSON that follows the ParsedDocument schema. `fn` and `response` are rebound
# on each iteration, so after the loop they refer to the last file processed.
for fn in pdf_files:
    print(f"Processing: {fn}")

    pdf_path = os.path.join(PDF_PATH, fn)
    pdf_bytes = pathlib.Path(pdf_path).read_bytes()

    # Output filename: same name as the PDF, with a .txt extension
    output_file = os.path.join(OUTPUT_PATH, fn.replace(".pdf", ".txt"))

    # VLM call: the PDF goes in as an inline part, followed by the prompt text
    pdf_part = types.Part.from_bytes(data=pdf_bytes, mime_type="application/pdf")
    generation_config = {
        "response_mime_type": "application/json",
        "response_schema": ParsedDocument,
    }
    response = client.models.generate_content(
        model="gemini-2.5-pro-exp-03-25",
        contents=[pdf_part, prompt],
        config=generation_config,
    )

Processing: 2_01_f.pdf


# Post-processing

In [8]:
def extract_pages_and_links_from_pdf(pdf_path):
    """Count the pages of a PDF and collect its external hyperlinks.

    Args:
        pdf_path: Path of the PDF file to open with PyMuPDF.

    Returns:
        A tuple ``(n_pages, link_data)`` where ``link_data`` is a list of
        ``(page_number, text, uri)`` tuples. ``page_number`` starts at 1,
        ``text`` is the stripped text found inside the link rectangle, and
        ``uri`` is the link target. Links without a ``'uri'`` entry are
        skipped.
    """
    link_data = []
    # Open the document as a context manager so the underlying file handle is
    # released even if extraction fails part-way through.
    with fitz.open(pdf_path) as doc:
        n_pages = doc.page_count
        for page_n, page in enumerate(doc, start=1):
            for link in page.get_links():
                if 'uri' in link:
                    # 'from' is the rectangle on the page that holds the link.
                    text = page.get_textbox(link['from']).strip()
                    link_data.append((page_n, text, link['uri']))
    return n_pages, link_data

def replace_special_chars(text: str) -> str:
    """Return *text* with each special character swapped for its plain form.

    Currently the only substitution is "ß" -> "ss".
    """
    substitutions = (("ß", "ss"),)
    cleaned = text
    for special, plain in substitutions:
        cleaned = cleaned.replace(special, plain)
    return cleaned

In [9]:
# Prompt template for the link-injection step. This rebinds `prompt`, so the
# parsing prompt loaded earlier is no longer available under that name.
# Decoded explicitly as UTF-8 so the text does not depend on the platform's
# default encoding.
with open("prompts/link_injection.txt", "r", encoding="utf-8") as f:
    prompt = f.read()


# Response schema for the link-injection call: a single string field.
# Documented with comments rather than a docstring on purpose: pydantic copies
# a model docstring into the generated JSON schema.
class MarkdownContentWithUrl(BaseModel):
    current_page_markdown_content: str
 

In [10]:
# `fn` is whatever the parsing loop bound last, i.e. the last file processed.
pdf_path = os.path.join(PDF_PATH, fn)

# n_pages: page count of the PDF.
# links_to_inject: list of (page_number, link_text, uri) tuples.
n_pages, links_to_inject = extract_pages_and_links_from_pdf(pdf_path)

# Bare expression: shown as the cell output when run in a notebook.
links_to_inject

[(3, '2.03 -', 'https://www.ahv-iv.ch/p/2.03.f'),
 (3,
  'Cotisations des personnes sans activité lucrative à l’AVS, à l’AI et aux APG)',
  'https://www.ahv-iv.ch/p/2.03.f'),
 (3,
  '2.08 - Cotisations à l’assurance-chômage',
  'https://www.ahv-iv.ch/p/2.08.f'),
 (5, '2', 'https://www.ahv-iv.ch/p/2.07.f'),
 (5,
  '2.07 - Procédures de décompte simplifiées pour les employeurs)',
  'https://www.ahv-iv.ch/p/2.07.f'),
 (8,
  'e, si elles ne sont pas exceptées du salaire déterminant \n2.05 - Rémunérations versées lors de la cessation des \nvail) ;',
  'https://www.ahv-iv.ch/p/2.05.f'),
 (8,
  '(voir mémento 2.0\nrapports de travail',
  'https://www.ahv-iv.ch/p/2.05.f'),
 (8,
  'de travail pour cause d’intempéries au sens de l’AC (voir\n2.11 - Obligation de cotiser sur les indemnités en cas de\nde l’horaire de travail ou d’intempéries) ;',
  'https://www.ahv-iv.ch/p/2.11.f'),
 (8,
  'mémento 2.11 - Obligation de cotiser sur les in\nréduction de l’horaire de travail ou d’intempéries',
  'http

In [11]:
# Sort the data by page_number (first element of tuple)
links_to_inject.sort(key=itemgetter(0))

# Group by page_number
grouped_links = {
    key: list(group)
    for key, group in groupby(links_to_inject, key=itemgetter(0))
}

In [None]:
# For every page that has links, ask the model to inject them into that page's
# parsed content and print the result.
# `res` is rebound on every iteration: after the loop it holds only the
# response for the last page processed, and the injected content is printed
# but not written back to `response.parsed`.
for page_number, links in grouped_links.items():

    print(page_number)
    print("-------------------------")

    current_page = next((p for p in response.parsed.content if p.page_number == page_number), None)
    if current_page is None:
        # No parsed content for this page: skip it, but keep going so one
        # missing page does not abandon all the remaining ones.
        continue

    current_page.page_content = replace_special_chars(current_page.page_content)
    print(current_page.page_content)
    print("-------------------------")

    # Drop the page number and keep (link_text, uri) pairs. A separate name is
    # used so the global `links_to_inject` list is not overwritten.
    page_links = [(text, uri) for _, text, uri in links]
    print(page_links)
    print("-------------------------")

    res = client.models.generate_content(
        model="gemini-2.5-flash-preview-04-17",
        contents=[
            prompt.format(
                links_to_inject=page_links,
                current_page_markdown_content=current_page.page_content,
            )
        ],
        config={
            'response_mime_type': 'application/json',
            'response_schema': MarkdownContentWithUrl,
        },
    )

    print(res.parsed.current_page_markdown_content)
    print("******************************************")
    print("******************************************")
    print("******************************************")


# Calculate Metadata

In [None]:
# ──────────────────────────────────────────────────────────────────────────────
# NORMALISED PRICE TABLE
# ──────────────────────────────────────────────────────────────────────────────
# • Top-level key  = model name
# • Second level   = *billing group* (input / output / cached / cached_hourly)
# • Third level    = one of
#     • {"flat":  rate}                         → same rate for all tokens
#     • {"tiers": {low: rate, high: rate}, "threshold": N}   → tiered pricing
#     • {"by_key": {sub_key: rate, …}}         → keyed pricing (modalities,
#       thinking/non-thinking, etc.)
#
# Rates are multiplied directly by token counts in compute_price, whose
# results are labelled *_usd, so every rate is expressed in USD per token.
# For tiered groups the tier is chosen by comparing the prompt token count
# with "threshold" (<= threshold → "low").
# compute_price reads only the "input", "output" and "cached" groups;
# "cached_hourly" is not used by it.

pricing: dict[str, dict] = {
    "gemini-2.5-pro-exp-03-25": {
        "input":   {"tiers": {"low": 1.25e-6,  "high": 2.50e-6},  "threshold": 200_000},
        "output":  {"tiers": {"low": 10.00e-6, "high": 15.00e-6}, "threshold": 200_000},
        "cached":  {"tiers": {"low": 0.31e-6,  "high": 0.625e-6}, "threshold": 200_000},
        "cached_hourly": {"flat": 4.50e-6},          # separate if you bill by hour
    },

    "gemini-2.5-flash-preview-04-17": {
        "input":   {"flat": 0.15e-6},                      # same for every token
        "output":  {"by_key": {                           # keyed by token *type*
            "non_thinking": 0.60e-6,                      # (= candidates)
            "thinking":     3.50e-6,                      # (= thoughts)
        }},
        "cached":  {"by_key": {"context": 0.0, "storage": 0.0}},
    },
}

# ──────────────────────────────────────────────────────────────────────────────
# TOKEN EXTRACTION
# ──────────────────────────────────────────────────────────────────────────────
def get_tokens(response) -> Dict:
    """Extract token counts from a response's ``usage_metadata``.

    Any count the API leaves unset (``None``) is reported as 0, and a missing
    per-modality breakdown becomes an empty dict, so downstream arithmetic
    never has to deal with ``None``.

    Args:
        response: Object exposing ``usage_metadata`` with the attributes
            ``prompt_token_count``, ``prompt_tokens_details``,
            ``candidates_token_count``, ``thoughts_token_count``,
            ``cached_content_token_count`` and ``total_token_count``.

    Returns:
        A dict with the keys ``"input"`` (``"prompt"`` total and
        ``"prompt_details"`` mapping modality value -> token count),
        ``"output"`` (``"candidates"`` and ``"thoughts"``), ``"cached"`` and
        ``"total"``.
    """
    usage = response.usage_metadata
    modality_details = usage.prompt_tokens_details or []
    return {
        "input": {
            "prompt": usage.prompt_token_count or 0,
            "prompt_details": {
                detail.modality.value: detail.token_count
                for detail in modality_details
            },
        },
        "output": {
            "candidates": usage.candidates_token_count or 0,
            "thoughts": usage.thoughts_token_count or 0,
        },
        "cached": usage.cached_content_token_count or 0,
        "total": usage.total_token_count or 0,
    }

# ──────────────────────────────────────────────────────────────────────────────
# GENERIC PRICE CALCULATOR
# ──────────────────────────────────────────────────────────────────────────────
def _tier(prompt_tokens: int, cfg: dict) -> str:
    """Return 'low' or 'high' if tiered, else ''."""
    return (
        "low"
        if prompt_tokens <= cfg["threshold"]
        else "high"
    )

def _cost_for_group(
    token_count: int | dict[str, int],
    cfg: dict,
    prompt_tokens: int | None = None,
) -> float:
    # ─── 1. Flat rate ────────────────────────────────────────────────────────
    if "flat" in cfg:
        rate = cfg["flat"]
        return token_count * rate if isinstance(token_count, int) else sum(
            c * rate for c in token_count.values()
        )

    # ─── 2. Tiered rate ──────────────────────────────────────────────────────
    if "tiers" in cfg:
        tier  = "low" if prompt_tokens <= cfg["threshold"] else "high"
        rate  = cfg["tiers"][tier]
        return token_count * rate if isinstance(token_count, int) else sum(
            c * rate for c in token_count.values()
        )

    # ─── 3. Keyed rate (by_key) ──────────────────────────────────────────────
    if "by_key" in cfg:
        rates = cfg["by_key"]

        # 3a. Token count is single int → all tokens billed with *one* rate
        if isinstance(token_count, int):
            unique_rates = set(rates.values())
            if len(unique_rates) != 1:
                raise ValueError(
                    "Token total given, but different rates per key: "
                    f"{rates}. Pass a breakdown or merge the rates."
                )
            rate = unique_rates.pop()     # they’re all the same
            return token_count * rate

        # 3b. Token count is already per-key dict
        return sum(token_count.get(k, 0) * r for k, r in rates.items())

    raise ValueError(f"Unsupported pricing config: {cfg!r}")

# ---------------------------------------------------------------------------

def compute_price(
    model_name: str,
    tokens: Dict,
    price_table: dict[str, dict] = pricing,
) -> Dict[str, float]:
    """Compute the cost of one call from its token counts.

    Supports flat, tiered, and keyed pricing for each billing group.

    Args:
        model_name: Key into *price_table*.
        tokens: Token counts shaped like the result of ``get_tokens``:
            ``tokens["input"]["prompt"]``, ``tokens["input"]["prompt_details"]``,
            ``tokens["output"]["candidates"]``, ``tokens["output"]["thoughts"]``
            and ``tokens["cached"]``. Output and cached counts that are
            ``None`` are treated as 0.
        price_table: Mapping model name -> billing group -> pricing config.
            The ``"input"``, ``"output"`` and ``"cached"`` groups are read.

    Returns:
        A dict with ``input_cost_usd``, ``output_cost_usd``,
        ``cached_cost_usd`` and ``total_cost_usd``, each rounded to 6 decimals.

    Raises:
        KeyError: If *model_name* or a required group is not in *price_table*.
    """
    cfg = price_table[model_name]
    # The prompt size also selects the tier for the output and cached groups.
    prompt_tokens = tokens["input"]["prompt"]

    # INPUT ──────────────────────────────────────────────────────────────────
    # NOTE(review): if the prompt count already includes cached tokens, those
    # tokens are billed here and again under CACHED - confirm against the API's
    # usage-metadata semantics.
    input_tokens = (
        tokens["input"]["prompt_details"]          # per-modality
        if "by_key" in cfg["input"]
        else prompt_tokens                         # single number
    )
    input_cost = _cost_for_group(input_tokens, cfg["input"], prompt_tokens)

    # OUTPUT ─────────────────────────────────────────────────────────────────
    # A count can be None when the response carries no such tokens; treat it
    # as 0 so neither the sum nor the keyed multiplication raises TypeError.
    candidates = tokens["output"]["candidates"] or 0
    thoughts = tokens["output"]["thoughts"] or 0
    if "by_key" in cfg["output"]:                 # thinking / non-thinking
        output_tokens = {
            "non_thinking": candidates,
            "thinking":     thoughts,
        }
    else:                                         # tiered or flat
        output_tokens = candidates + thoughts
    output_cost = _cost_for_group(output_tokens, cfg["output"], prompt_tokens)

    # CACHED ────────────────────────────────────────────────────────────────
    cached_tokens = tokens["cached"] or 0
    cached_cost = _cost_for_group(cached_tokens, cfg["cached"], prompt_tokens)

    # TOTAL ─────────────────────────────────────────────────────────────────
    total_cost = input_cost + output_cost + cached_cost
    return {
        "input_cost_usd":  round(input_cost,  6),
        "output_cost_usd": round(output_cost, 6),
        "cached_cost_usd": round(cached_cost, 6),
        "total_cost_usd":  round(total_cost, 6),
    }

In [None]:
# Token usage and cost of the parsing call. `response` is the last value bound
# by the parsing loop, so this covers the last file processed.
tokens_parsing = get_tokens(response)
price_parsing  = compute_price("gemini-2.5-pro-exp-03-25", tokens_parsing)
print(price_parsing)

In [None]:
# Token usage and cost of the link-injection step. `res` is rebound on every
# iteration of the injection loop, so these figures cover only the last
# page's call, not the sum over all pages.
tokens_pp = get_tokens(res)
price_pp  = compute_price("gemini-2.5-flash-preview-04-17", tokens_pp)
print(price_pp)

# Save output

In [None]:
# Bare expression: displays the current file name when run in a notebook.
fn

In [13]:
# Convert the parsed pydantic models into plain dicts for serialisation.
dict_obj = {
    "raw_content": [pc.model_dump() for pc in response.parsed.content],
    "summary": response.parsed.summary,
    "structured_content": [sc.model_dump() for sc in response.parsed.structured_content],
}

# Save to JSON. Create the output directory first: open() does not create
# missing parent directories and would fail on a fresh checkout.
os.makedirs(OUTPUT_PATH, exist_ok=True)
with open(os.path.join(OUTPUT_PATH, fn.replace(".pdf", ".json")), "w", encoding="utf-8") as f:
    json.dump(dict_obj, f, ensure_ascii=False, indent=4)

In [None]:
# The metadata file sits next to the parsed output: "<name>_metadata.json".
METADATA_PATH = os.path.join(OUTPUT_PATH, fn.replace(".pdf", "_metadata.json"))

# Page count plus token usage and cost for each of the two model steps.
parsing_stats = {"tokens": tokens_parsing, "price": price_parsing}
post_processing_stats = {"tokens": tokens_pp, "price": price_pp}
metadata = {
    "n_pages": n_pages,
    "parsing": parsing_stats,
    "post_processing": post_processing_stats,
}

# Serialise first, then write the finished text in one go.
metadata_json = json.dumps(metadata, ensure_ascii=False, indent=4)
with open(METADATA_PATH, "w", encoding="utf-8") as f:
    f.write(metadata_json)