<a href="https://colab.research.google.com/github/ElCotox/Project/blob/main/ESG_Rating_Pipeline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Prep / Install

In [None]:
# ==============================================================================
# CELLULE 1: INSTALLATION DES DÉPENDANCES
# Exécutez cette cellule, puis redémarrez l'environnement d'exécution.
# ==============================================================================
!pip install -U torch torchvision --index-url https://download.pytorch.org/whl/cu121
!python -m spacy download en_core_web_sm
!pip install -U \
  transformers timm sentencepiece accelerate bitsandbytes \
  sentence-transformers rank-bm25 faiss-gpu-cu12 \
  spacy pymupdf pillow pandas

In [None]:
import os

# A) Vérifier le GPU
print("--- Vérification du GPU ---")
!nvidia-smi

# B) Monter Google Drive
print("\n--- Montage de Google Drive ---")
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

PROJECT_ROOT = "/content/drive/MyDrive/esg_rating_project"

# C) Créer l'arborescence
print("\n--- Création de l'arborescence du projet ---")
PROJECT_ROOT = "/content/drive/MyDrive/esg_rating_project"
!mkdir -p "{PROJECT_ROOT}/src"
!mkdir -p "{PROJECT_ROOT}/data/reports_to_analyze"
!mkdir -p "{PROJECT_ROOT}/rating_module"

print("\n\n✅ --- ENVIRONNEMENT DE PROJET PRÊT --- ✅")
PROJECT_ROOT = "/content/drive/MyDrive/esg_rating_project"


# KPI_Config.py

In [None]:
# --- Fichier setup.py ---
%%writefile {PROJECT_ROOT}/setup.py
from setuptools import setup, find_packages
setup(
    name='esg_rating_engine',
    version='0.9.0',
    packages=find_packages(),
)

In [None]:
%%writefile /content/drive/MyDrive/esg_rating_project/src/kpi_config.py

# ==============================================================================
#  SECTOR LIST
# ==============================================================================
# Canonical, lower-case sector labels. The engine tries to match the sector it
# detects in a report against one of these entries.
# No pillar (E/S/G) weights are defined here; weighting is applied after scoring.
SECTOR_LIST = [
    "automobile",
    "capital goods",
    "materials",
    "real estate",
    "construction and engineering",
    "food, beverages and agriculture",
    "consumer goods",
    "leisure",
    "healthcare",
    "retail",
    "professional and commercial services",
    "transport and logistics",
    "media and telecommunications",
    "energy and utilities",
    "software",
    "hardware",
    "municipalities",
    "financial services",
]





# ==============================================================================
#  KEY PERFORMANCE INDICATOR (KPI) FRAMEWORK
# ==============================================================================
# Pillars: 'E' (environment), 'S' (social), 'G' (governance), 'C' (controversies).
# Each KPI entry holds:
#   question     - prompt sent to the extraction model
#   search_query - retrieval query used to find supporting passages
#   direction    - 'lower_is_better' or 'higher_is_better' for scoring
#   keywords     - related terms for the KPI
# No weights live here; weighting is applied at the end, after scoring.
KPI_FRAMEWORK = {
    'E': {
        'scope_1_emissions': {
            'question': (
                "Provide the 'Group total Scope 1 GHG emissions' for the most recent reporting year ONLY, unit must be tco2e or ktco2e "
                "and note the unit and the latest year in the reasoning"
            ),
            'search_query': "Group total Scope 1 GHG emissions tCO2e ktCO2e latest year",
            'direction': 'lower_is_better',
            'keywords': [
                'Scope 1 emissions', 'direct emissions', 'GHG', 'carbon footprint',
                'climate change', 'environmental impact', 'CO2', 'sustainability reporting',
            ],
        },
        # NOTE(review): the key says "market" but the question/search ask for the
        # LOCATION-based Scope 2 figure. The key is kept for compatibility;
        # confirm which basis is intended.
        'scope_2_market': {
            'question': (
                "Provide the 'Group total Scope 2 GHG emissions (location based)' for the most recent reporting year ONLY, unit must be tco2e or ktco2e "
                "and note the unit and the latest year in the reasoning"
            ),
            'search_query': "Group total Scope 2 GHG emissions (location based) tCO2e ktCO2e latest year",
            'direction': 'lower_is_better',
            'keywords': [
                'Scope 2 emissions', 'indirect emissions', 'energy consumption',
                'market-based emissions', 'GHG', 'carbon accounting', 'sustainability',
            ],
        },
        'scope_3_emissions': {
            'question': (
                "Provide the 'Group total Scope 3 GHG emissions' for the most recent reporting year ONLY, unit must be tco2e or ktco2e "
                "and note the unit and the latest year in the reasoning"
            ),
            'search_query': "Group total Scope 3 GHG emissions tCO2e ktCO2e latest year",
            'direction': 'lower_is_better',
            'keywords': [
                'Scope 3 emissions', 'value chain emissions', 'indirect GHG',
                'supply chain', 'carbon footprint', 'upstream emissions', 'downstream emissions',
            ],
        },
        # A share (%), consistent with the key name, the "%" search query and
        # 'higher_is_better'. Absolute energy figures are only a reported fallback.
        'renewable_energy_pct': {
            'question': (
                "Provide the 'Renewable energy consumption percentage' for the most recent reporting year ONLY, in percent (%). "
                "If only absolute renewable energy figures are reported (kWh, MWh or GWh), return the number as-is "
                "and note the unit and the latest year in the reasoning."
            ),
            'search_query': "renewable energy consumption share % latest year renewable electricity",
            'direction': 'higher_is_better',
            'keywords': [
                'renewable energy', 'clean energy', 'energy production',
                'solar', 'wind', 'hydro', 'green electricity', 'sustainable energy',
            ],
        },
        'hazardous_waste': {
            'question': (
                "Provide the 'TOTAL HAZARDOUS WASTE' for the most recent reporting year ONLY, in tonnes (t). "
                "If—and only if—the latest-year figure is explicitly reported in kilograms (kg) or kilotonnes (kt), "
                "safely convert to tonnes ( t = kg ÷ 1,000; t = kt × 1,000). If reported in 'tons', convert only if the "
                "document explicitly states 'metric tons' (i.e., tonnes). If units are ambiguous, do not convert; "
                "return the number as-is and note the unit in the reasoning."
            ),
            'search_query': "total hazardous waste generated tonnes t latest year",
            'direction': 'lower_is_better',
            'keywords': [
                'waste generation', 'solid waste', 'hazardous waste',
                'waste management', 'recycling', 'landfill', 'environmental footprint',
            ],
        },
        'water_fresh_consumption': {
            'question': (
                "Provide the total water consumed or abstracted for the most recent reporting year ONLY, in cubic "
                "meters (m3). If—and only if—the latest-year value is explicitly reported in thousand m3 or million m3, "
                "safely convert to m3 (m3 = thousand m3 × 1,000; m3 = million m3 × 1,000,000). If units are ambiguous, do not convert; "
                "return the number as-is and note the unit in the reasoning."
            ),
            'search_query': "total water withdrawn or consumed m3 latest year thousand m3 million m3",
            'direction': 'lower_is_better',
            'keywords': [
                'water usage', 'water abstraction', 'resource consumption',
                'water footprint', 'sustainable water use', 'environmental impact',
            ],
        },
    },
    'S': {
        'employee_turnover_rate': {
            'question': (
                "Provide the 'Employee turnover rate' for the most recent reporting year ONLY, in percent (%)."
            ),
            'search_query': "employee turnover rate % latest year attrition staff turnover",
            'direction': 'lower_is_better',
            'keywords': [
                'employee turnover', 'HR metrics', 'workforce stability',
                'retention rate', 'human capital', 'employee engagement',
            ],
        },
        # NOTE(review): TRIR is commonly normalised per 200,000 hours; this prompt
        # asks for a per-1,000,000-hours frequency rate. Confirm the scoring expects that basis.
        'trir': {
            'question': (
                "Provide the 'Accident Rate' for the most recent reporting year ONLY, "
                "per 1,000,000 hours worked, it can also be called loss frequency, accident frequency or incident rate extract the rate as stated and note the unit and the latest year in the reasoning"
            ),
            'search_query': "accident frequency rate per 1,000,000 hours latest year TRIR AFR loss frequency",
            'direction': 'lower_is_better',
            'keywords': [
                'TRIR', 'occupational safety', 'incident rate',
                'workplace injuries', 'health and safety', 'employee wellbeing',
            ],
        },
        'health_safety_fatalities': {
            'question': (
                "Provide the number of work-related 'Fatalities' for the most recent reporting year ONLY."
            ),
            'search_query': "work-related fatalities employees contractors total latest year",
            'direction': 'lower_is_better',
            'keywords': [
                'workplace fatalities', 'occupational hazards', 'employee safety',
                'fatal incidents', 'health and safety', 'risk management',
            ],
        },
        'women_exec_mgmt': {
            'question': (
                "Provide the percentage (%) of 'women in executive management' for the most recent reporting year ONLY."
            ),
            'search_query': "women in executive management % female employees share latest year",
            'direction': 'higher_is_better',
            'keywords': [
                'gender diversity', 'female workforce', 'inclusion',
                'equality', 'diversity metrics', 'HR reporting',
            ],
        },
    },
    'G': {
        # Board-level metric: the search query targets the Board of Directors,
        # not the executive committee (covered by S/women_exec_mgmt).
        'female_directors_pct': {
            'question': (
                "Provide the percentage (%) of 'women on the Board of Directors' for the most recent reporting year ONLY."
            ),
            'search_query': "women on board of directors % latest year female board representation",
            'direction': 'higher_is_better',
            'keywords': [
                'board diversity', 'female representation', 'gender equality',
                'governance', 'leadership diversity', 'corporate board',
            ],
        },
        'board_independence_pct': {
            'question': (
                "Provide the percentage of independent directors on the board for the most recent reporting year ONLY."
            ),
            'search_query': "independent directors % of board latest year board independence",
            'direction': 'higher_is_better',
            'keywords': [
                'board independence', 'corporate governance', 'transparency',
                'accountability', 'board structure', 'non-executive directors',
            ],
        },
        'training_hours_per_emp': {
            'question': (
                "Provide the average number of training hours provided per employee for the "
                "most recent reporting year ONLY."
            ),
            'search_query': "average training hours per employee",
            'direction': 'higher_is_better',
            'keywords': [
                'ethics training', 'code of conduct', 'compliance',
                'employee integrity', 'corporate ethics', 'training programs',
            ],
        },
    },
    'C': {
        'env_controversy_score': {
            'question': "Report the company's ENVIRONMENT CONTROVERSIES score at the section level",
            'search_query': "ENVIRONMENT CONTROVERSIES score",
            'direction': 'higher_is_better',
            'keywords': ['ENVIRONMENT CONTROVERSIES'],
        },
        'customers_controversy_score': {
            'question': "Report the company's CUSTOMERS CONTROVERSIES score, a sub-section of the social section",
            'search_query': "CUSTOMERS CONTROVERSIES score social",
            'direction': 'higher_is_better',
            'keywords': ['CUSTOMERS CONTROVERSIES'],
        },
        'human_rights_community_controversy_score': {
            'question': "Report the company's HUMAN RIGHTS and COMMUNITY CONTROVERSIES score, a sub-section of the social section",
            'search_query': "HUMAN RIGHTS & COMMUNITY CONTROVERSIES",
            'direction': 'higher_is_better',
            'keywords': ['HUMAN RIGHTS', 'COMMUNITY CONTROVERSIES'],
        },
        'labor_rights_supply_chain_controversy_score': {
            'question': "Report the company's LABOR RIGHTS and SUPPLY CHAIN CONTROVERSIES score, a sub-section of the social section",
            'search_query': "LABOR RIGHTS & SUPPLY CHAIN CONTROVERSIES",
            'direction': 'higher_is_better',
            'keywords': ['LABOR RIGHTS', 'SUPPLY CHAIN CONTROVERSIES'],
        },
        'governance_controversy_score': {
            'question': "Report the company's GOVERNANCE CONTROVERSIES score at the section level",
            'search_query': "GOVERNANCE CONTROVERSIES score",
            'direction': 'higher_is_better',
            'keywords': ['GOVERNANCE CONTROVERSIES'],
        },
    },
}

# Utils

In [None]:
%%writefile {PROJECT_ROOT}/src/utils.py
# src/utils.py
import re
import pandas as pd

# spaCy is optional: if the package or the en_core_web_sm model cannot be
# loaded, NLP stays None and clean_text_spacy uses a regex-only fallback.
try:
    import spacy
    try:
        NLP = spacy.load("en_core_web_sm", disable=["parser", "ner"])
    except Exception:
        NLP = None
except Exception:
    NLP = None

# Compiled once at import. Raw strings take a SINGLE backslash per escape:
# r'\\S' would match a literal backslash followed by 'S', not "non-whitespace".
_URL_RE = re.compile(r"https?://\S+|www\.\S+")
_EMAIL_RE = re.compile(r"\S+@\S+")
_PAGE_MARKER_RE = re.compile(r"\b(Page\s\d+\s*(of\s\d+)?)\b", flags=re.IGNORECASE)
# Fallback path: everything except letters, digits, % . , : ; ( ) - and whitespace.
_DISALLOWED_CHARS_RE = re.compile(r"[^A-Za-z0-9%.,:;()\-\s]")
_WHITESPACE_RE = re.compile(r"\s+")


def clean_text_spacy(text: str, remove_stopwords: bool = False) -> str:
    """
    Clean raw report text before it is embedded.

    URLs, e-mail addresses and "Page N [of M]" markers are removed first.
    With spaCy available (NLP loaded), whitespace and punctuation tokens,
    single non-digit characters and (optionally) stopwords are dropped, and the
    remaining tokens are lower-cased and joined with single spaces. Without
    spaCy, characters outside letters/digits/% . , : ; ( ) -/whitespace are
    replaced by spaces, whitespace is collapsed and the text is lower-cased.

    Args:
        text: raw text; falsy input (None, "") yields "".
        remove_stopwords: drop spaCy stopwords (spaCy path only).

    Returns:
        The cleaned, lower-cased text.
    """
    if not text:
        return ""

    # Coarse cleanup shared by both paths.
    t = _URL_RE.sub("", text)
    t = _EMAIL_RE.sub("", t)
    t = _PAGE_MARKER_RE.sub("", t)

    if NLP is None:
        # Lightweight fallback (no spaCy).
        t = _DISALLOWED_CHARS_RE.sub(" ", t)
        return _WHITESPACE_RE.sub(" ", t).strip().lower()

    # spaCy path.
    doc = NLP(t)
    toks = []
    for tok in doc:
        if tok.is_space or tok.is_punct:
            continue
        if len(tok.text) == 1 and not tok.text.isdigit():
            continue
        if remove_stopwords and tok.is_stop:
            continue
        toks.append(tok.text.lower())
    return _WHITESPACE_RE.sub(" ", " ".join(toks)).strip()

def _cell_text(value) -> str:
    """Render one header/cell value as single-line text; missing values (None/NaN/NA) become ""."""
    if pd.api.types.is_scalar(value) and pd.isna(value):
        return ""
    # Flatten real line breaks (not the two-character text "\\n") so that each
    # table row stays on one sentence.
    return re.sub(r"[\r\n]+", " ", str(value)).strip()


def linearize_table(table: pd.DataFrame, source_page: int) -> list[str]:
    """
    Turn a DataFrame into one natural-language sentence per row (for the text index).

    The first column is the row subject; every other non-empty cell under a
    non-empty header contributes "the value for '<header>' is '<cell>';".
    Rows with an empty subject or without any such value are skipped.

    Args:
        table: the table to linearize (None or empty -> no sentences).
        source_page: page number quoted in each sentence.

    Returns:
        A list of sentences, in row order.
    """
    sentences: list[str] = []
    if table is None or table.empty or table.columns.empty:
        return sentences

    headers = [_cell_text(h) for h in table.columns]
    for _, row in table.iterrows():
        try:
            row_data = [_cell_text(item) for item in row]
            subject = row_data[0] if row_data else ""
            if not subject:
                continue
            sentence = f"From a table on page {source_page} regarding '{subject}':"
            has_data = False
            for i in range(1, len(row_data)):
                if i < len(headers) and row_data[i] and headers[i]:
                    sentence += f" the value for '{headers[i]}' is '{row_data[i]}';"
                    has_data = True
            if has_data:
                sentences.append(sentence)
        except Exception:
            # Best effort: one malformed row must not abort the whole table.
            continue
    return sentences

# PerfectTables

In [None]:
%%writefile {PROJECT_ROOT}/src/perfect_tables.py
import os
import io
import math
import hashlib
from dataclasses import dataclass, asdict
from typing import List, Dict, Any, Tuple, Optional

import fitz  # PyMuPDF
import numpy as np
import pandas as pd
from PIL import Image

import torch
from transformers import AutoImageProcessor, TableTransformerForObjectDetection

# -------------------------
# Default configuration
# -------------------------
DEFAULT_IMG_DPI = 288  # 4x the 72-dpi PDF base resolution; quality vs. memory trade-off
DET_MODEL = "microsoft/table-transformer-detection"  # Hugging Face checkpoint for table detection
STR_MODEL = "microsoft/table-transformer-structure-recognition"  # Hugging Face checkpoint for row/column structure

# Device for the models and inputs in this module: CUDA when available, otherwise CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


@dataclass
class CellUnit:
    """One reconstructed table cell: grid position, text and pixel box inside the table crop."""
    row: int  # 0-based grid row index
    col: int  # 0-based grid column index
    text: str  # space-joined words assigned to the cell ("" if none)
    bbox: List[float]  # [x0,y0,x1,y1] in pixels, in the table crop's image coordinates ([0,0,1,1] placeholder in the single-cell fallback)


@dataclass
class TableUnit:
    """
    Typed description of one extracted table.

    NOTE(review): extract_tables_from_pdf in this module returns plain dicts
    rather than TableUnit instances, and stores the preview under the key
    "text" instead of "text_repr".
    """
    page: int                    # page number (1-based in extract_tables_from_pdf's dicts)
    bbox: List[float]            # table bbox in pixels (rendered page image coordinates)
    bbox_page_pts: List[float]   # table bbox in PDF points (page rect coordinates)
    fingerprint: str             # SHA-1 of header / first column / shape
    df: Any                      # pandas.DataFrame
    cells: List[CellUnit]
    image_crop_path: str         # path of the saved PNG crop of the table
    doc_type: str
    text_repr: str               # "headers || first_col" preview


def _render_page_to_image(doc, page_idx: int, dpi: int = DEFAULT_IMG_DPI) -> Image.Image:
    """
    Rasterise one PDF page to an RGB PIL image.

    PDF user space is 72 units per inch, so the zoom factor applied on both
    axes is dpi / 72. The pixmap is rendered without an alpha channel.
    """
    scale = dpi / 72.0
    pixmap = doc.load_page(page_idx).get_pixmap(matrix=fitz.Matrix(scale, scale), alpha=False)
    return Image.frombytes("RGB", (pixmap.width, pixmap.height), pixmap.samples)


def _words_as_pixels(page, img_w: int, img_h: int, dpi: int) -> List[Tuple[float,float,float,float,str]]:
    # PyMuPDF words: (x0,y0,x1,y1,"text", block_no, line_no, word_no) in POINTS (72dpi)
    words = page.get_text("words")
    page_rect = page.rect
    sx = (dpi / 72.0)
    sy = (dpi / 72.0)
    shift_x, shift_y = -page_rect.x0, -page_rect.y0

    px_words = []
    for (x0, y0, x1, y1, txt, *_rest) in words:
        X0 = (x0 + shift_x) * sx
        Y0 = (y0 + shift_y) * sy
        X1 = (x1 + shift_x) * sx
        Y1 = (y1 + shift_y) * sy
        # clamp
        X0 = max(0, min(img_w - 1, X0)); X1 = max(0, min(img_w - 1, X1))
        Y0 = max(0, min(img_h - 1, Y0)); Y1 = max(0, min(img_h - 1, Y1))
        px_words.append((X0, Y0, X1, Y1, txt))
    return px_words


def _detect_tables(image: Image.Image, processor, model, score_thresh=0.5):
    """
    Run the Table Transformer detection model on a rendered page.

    Args:
        image: page image (PIL).
        processor, model: detection image processor and model.
        score_thresh: minimum confidence for a detection to be kept.

    Returns:
        (boxes, scores): numpy arrays of kept [x0, y0, x1, y1] boxes (image
        pixels) and their scores, restricted to label 0.
    """
    inputs = processor(images=image, return_tensors="pt").to(DEVICE)
    with torch.inference_mode():
        out = model(**inputs)
    target_sizes = torch.tensor([image.size[::-1]]).to(DEVICE)  # PIL size is (w, h); post-processing expects (h, w)
    # Forward the threshold explicitly: post_process_object_detection otherwise
    # applies its own default (0.5), silently dropping detections between
    # score_thresh and 0.5 whenever score_thresh < 0.5.
    results = processor.post_process_object_detection(
        out, threshold=score_thresh, target_sizes=target_sizes
    )[0]
    bboxes = results["boxes"].cpu().numpy()
    scores = results["scores"].cpu().numpy()
    labels = results["labels"].cpu().numpy()
    # Label 0 is "table" for this checkpoint (label 1, "table rotated", is ignored).
    keep = (scores >= score_thresh) & (labels == 0)
    return bboxes[keep], scores[keep]


def _structure_on_crop(crop_img: Image.Image, str_processor, str_model, score_thresh=0.4):
    """
    Detect row and column boxes inside a cropped table image.

    Args:
        crop_img: table crop (PIL).
        str_processor, str_model: structure-recognition processor and model.
        score_thresh: minimum confidence for a row/column box.

    Returns:
        (rows, cols): lists of [x0, y0, x1, y1] numpy boxes in crop pixels.
    """
    inputs = str_processor(images=crop_img, return_tensors="pt").to(DEVICE)
    with torch.inference_mode():
        out = str_model(**inputs)
    target_sizes = torch.tensor([crop_img.size[::-1]]).to(DEVICE)
    # Forward the threshold: the processor's own default (0.5) would otherwise
    # override lower values such as this function's default of 0.4.
    results = str_processor.post_process_object_detection(
        out, threshold=score_thresh, target_sizes=target_sizes
    )[0]
    boxes = results["boxes"].cpu().numpy()
    scores = results["scores"].cpu().numpy()
    labels = results["labels"].cpu().numpy()
    # labels mapping (from model card)
    # 0: table, 1: table column, 2: table row, 3: table column header, 4: table projected row header,
    # 5: table spanning cell, 6: no object
    rows = [boxes[i] for i in range(len(labels)) if scores[i] >= score_thresh and int(labels[i]) == 2]
    cols = [boxes[i] for i in range(len(labels)) if scores[i] >= score_thresh and int(labels[i]) == 1]
    return rows, cols


def _intersect(a, b):
    x0 = max(a[0], b[0]); y0 = max(a[1], b[1])
    x1 = min(a[2], b[2]); y1 = min(a[3], b[3])
    if x1 <= x0 or y1 <= y0:
        return None
    return [x0, y0, x1, y1]


def _assign_text_to_grid(crop_words, row_boxes, col_boxes) -> Tuple[pd.DataFrame, List[CellUnit]]:
    """
    Place words into the row/column grid predicted by the structure model.

    Args:
        crop_words: (x0, y0, x1, y1, text) tuples in pixel coordinates of the table crop.
        row_boxes: row boxes [x0, y0, x1, y1] in the same frame, any order.
        col_boxes: column boxes [x0, y0, x1, y1] in the same frame, any order.

    Returns:
        (df, cells): a DataFrame of the grid text and one CellUnit per grid cell
        (row-major, header row included). If no rows or no columns were
        detected, all words are joined in (y0, x0) order into a single
        "Value" cell.
    """
    # sort rows / columns by their geometric centre (top-to-bottom, left-to-right)
    rows = sorted(row_boxes, key=lambda b: (b[1]+b[3])/2.0)
    cols = sorted(col_boxes, key=lambda b: (b[0]+b[2])/2.0)
    if not rows or not cols:
        # fallback: all text in a single cell, with a placeholder bbox
        text = " ".join([w[4] for w in sorted(crop_words, key=lambda t: (t[1], t[0]))])
        df = pd.DataFrame([[text]], columns=["Value"])
        cells = [CellUnit(0,0,text,[0,0,1,1])]
        return df, cells

    # build the cell grid from row x column intersections
    H, W = len(rows), len(cols)
    grid: List[List[str]] = [[ "" for _ in range(W)] for __ in range(H)]
    cell_boxes: List[List[List[float]]] = [[ None for _ in range(W)] for __ in range(H)]

    for i, rb in enumerate(rows):
        for j, cb in enumerate(cols):
            inter = _intersect(rb, cb)
            if inter is None:
                # small tolerant expansion: grow the row vertically and the column horizontally by eps px
                eps = 2.0
                inter = _intersect([rb[0], rb[1]-eps, rb[2], rb[3]+eps], [cb[0]-eps, cb[1], cb[2]+eps, cb[3]])
            if inter is None:
                # last resort: coordinate-wise average of the row and column boxes
                inter = [ (rb[0]+cb[0])/2, (rb[1]+cb[1])/2, (rb[2]+cb[2])/2, (rb[3]+cb[3])/2 ]
            cell_boxes[i][j] = inter

    # assign each word to the first cell (row-major scan) whose box contains the word's centre
    for (x0,y0,x1,y1,txt) in crop_words:
        cx, cy = (x0+x1)/2.0, (y0+y1)/2.0
        best = None
        for i in range(H):
            for j in range(W):
                b = cell_boxes[i][j]
                if b[0] <= cx <= b[2] and b[1] <= cy <= b[3]:
                    best = (i,j); break
            if best: break
        if best is None:
            # otherwise pick the cell whose centre is nearest (squared Euclidean
            # distance); strict '<' keeps the first cell in row-major order on ties
            dmin, pos = 1e18, (0,0)
            for i in range(H):
                for j in range(W):
                    b = cell_boxes[i][j]
                    bx, by = (b[0]+b[2])/2.0, (b[1]+b[3])/2.0
                    d = (bx-cx)**2 + (by-cy)**2
                    if d < dmin: dmin, pos = d, (i,j)
            best = pos
        i, j = best
        # words landing in the same cell are joined with spaces, in crop_words order
        if grid[i][j]:
            grid[i][j] += " " + txt
        else:
            grid[i][j] = txt

    # build the DataFrame: the first row becomes the header if any of its cells contains a letter
    headers = []
    first_row = grid[0]
    if any([any(ch.isalpha() for ch in (c or "")) for c in first_row]):
        # empty header cells get a positional "ColN" name; duplicate texts are not de-duplicated
        headers = [c if c else f"Col{j+1}" for j,c in enumerate(first_row)]
        data_rows = grid[1:] if len(grid) > 1 else []
    else:
        headers = [f"Col{j+1}" for j in range(len(first_row))]
        data_rows = grid

    df = pd.DataFrame(data_rows, columns=headers)
    # cells_out covers the full grid (header row included) with 0-based indices
    cells_out: List[CellUnit] = []
    for i in range(H):
        for j in range(W):
            t = grid[i][j] if i < len(grid) and j < len(grid[i]) else ""
            cells_out.append(CellUnit(i, j, t, [float(x) for x in cell_boxes[i][j]]))
    return df, cells_out


def _fingerprint_df(df: pd.DataFrame) -> str:
    try:
        header = " | ".join([str(c) for c in df.columns])
        first_col = ""
        if df.shape[1] > 0:
            first_col = " | ".join([str(x) for x in df.iloc[:10, 0].astype(str).tolist()])
        shape = f"{df.shape[0]}x{df.shape[1]}"
        raw = f"{header} || {first_col} || {shape}".lower()
        raw = "".join(ch if ch.isalnum() or ch in " %|" else " " for ch in raw)
        raw = " ".join(raw.split())
        return hashlib.sha1(raw.encode("utf-8")).hexdigest()
    except Exception:
        return hashlib.sha1(str(id(df)).encode("utf-8")).hexdigest()


def extract_tables_from_pdf(
    pdf_path: str,
    out_img_dir: str,
    doc_type: str,
    dpi: int = DEFAULT_IMG_DPI,
    det_thresh: float = 0.8,
    str_thresh: float = 0.7
) -> List[Dict[str, Any]]:
    """
    Detect and reconstruct every table in a PDF.

    Each page is rendered at ``dpi``. Tables are located with the Table
    Transformer detection model and split into rows/columns with the
    structure model. The page's PyMuPDF words are then assigned to the
    resulting grid. Every table crop is saved as a PNG in ``out_img_dir``.

    Args:
        pdf_path: path of the PDF to process.
        out_img_dir: directory for the table crop images (created if missing).
        doc_type: label copied into every returned record.
        dpi: render resolution; pixel coordinates refer to this resolution.
        det_thresh: minimum detection score for a table.
        str_thresh: minimum score for a row/column box.

    Returns:
        One dict per table with the following keys:
        - "page": 1-based page number.
        - "bbox": table box in pixels of the rendered page.
        - "bbox_page_pts": table box in PDF points.
        - "fingerprint".
        - "df".
        - "cells": CellUnit fields as dicts.
        - "image_crop_path".
        - "doc_type".
        - "text": header/first-column preview (the field TableUnit calls ``text_repr``).
    """
    processor_det = AutoImageProcessor.from_pretrained(DET_MODEL)
    model_det = TableTransformerForObjectDetection.from_pretrained(DET_MODEL).to(DEVICE)
    processor_str = AutoImageProcessor.from_pretrained(STR_MODEL)
    model_str = TableTransformerForObjectDetection.from_pretrained(STR_MODEL).to(DEVICE)

    # Loop invariants: pixels per PDF point, crop-file prefix, output directory.
    scale = dpi / 72.0
    pdf_name = os.path.basename(pdf_path)
    os.makedirs(out_img_dir, exist_ok=True)

    out: List[Dict[str, Any]] = []
    try:
        with fitz.open(pdf_path) as doc:
            for p in range(len(doc)):
                page = doc.load_page(p)
                page_rect = page.rect
                img = _render_page_to_image(doc, p, dpi=dpi)
                W, H = img.size
                px_words = _words_as_pixels(page, W, H, dpi=dpi)

                # table detection -> bboxes in image pixels
                tb_boxes, _tb_scores = _detect_tables(img, processor_det, model_det, score_thresh=det_thresh)
                for t_idx, box in enumerate(tb_boxes):
                    x0, y0, x1, y1 = [float(v) for v in box]
                    x0c, y0c = max(0, x0), max(0, y0)
                    x1c, y1c = min(W-1, x1), min(H-1, y1)
                    if x1c <= x0c or y1c <= y0c:
                        # Box is empty after clamping to the page (e.g. lies
                        # outside it); PIL cannot crop an empty/inverted region.
                        continue

                    # crop + words falling inside it (shifted into crop coordinates)
                    crop = img.crop((x0c, y0c, x1c, y1c))
                    cW, cH = crop.size
                    crop_words = []
                    for (wx0, wy0, wx1, wy1, wtxt) in px_words:
                        if wx1 < x0c or wx0 > x1c or wy1 < y0c or wy0 > y1c:
                            continue
                        # clamp to the crop
                        cx0 = max(0, wx0 - x0c); cy0 = max(0, wy0 - y0c)
                        cx1 = min(cW-1, wx1 - x0c); cy1 = min(cH-1, wy1 - y0c)
                        if cx1 > cx0 and cy1 > cy0:
                            crop_words.append((cx0, cy0, cx1, cy1, wtxt))

                    # structure (rows/cols) and text placement
                    row_boxes, col_boxes = _structure_on_crop(crop, processor_str, model_str, score_thresh=str_thresh)
                    df, cells = _assign_text_to_grid(crop_words, row_boxes, col_boxes)

                    # save the crop for explainability / downstream layout models
                    img_name = os.path.join(out_img_dir, f"{pdf_name}_p{p+1}_t{t_idx+1}.png")
                    crop.save(img_name)

                    # bbox back in PDF points for traceability
                    page_bbox = [
                        (x0c / scale) + page_rect.x0,
                        (y0c / scale) + page_rect.y0,
                        (x1c / scale) + page_rect.x0,
                        (y1c / scale) + page_rect.y0,
                    ]

                    fp = _fingerprint_df(df)
                    headers = " | ".join(map(str, list(df.columns)))
                    fcol = " | ".join(map(str, list(df.iloc[:,0].astype(str).values[:10]))) if df.shape[1] > 0 else ""
                    text_repr = f"Table: {headers} || first_col: {fcol}"

                    out.append({
                        "page": int(p+1),
                        "bbox": [x0c, y0c, x1c, y1c],
                        "bbox_page_pts": [float(v) for v in page_bbox],
                        "fingerprint": fp,
                        "df": df,
                        "cells": [asdict(c) for c in cells],
                        "image_crop_path": img_name,
                        "doc_type": doc_type,
                        "text": text_repr,
                    })
    finally:
        # Drop the model references, then return cached CUDA blocks so that
        # later GPU-heavy stages can reuse the memory.
        del model_det, model_str, processor_det, processor_str
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return out


# Extraction Engine

In [None]:
%%writefile {PROJECT_ROOT}/src/esg_engine.py

#==============================================================================
#FILE: src/esg_engine.py (PyMuPDF + PerfectTables, Donut DocVQA + RAG, BGE rerank, Judge, cache+CSV)
#==============================================================================

import os
import gc
import re
import csv
import json
import math
import torch
import faiss
import spacy
import fitz
import hashlib
import numpy as np
import pandas as pd
import concurrent.futures as cf

from PIL import Image
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, CrossEncoder

from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
BitsAndBytesConfig,
DonutProcessor,
AutoModelForVision2Seq,
VisionEncoderDecoderModel
)

#--- Projet local ---

from src.kpi_config import KPI_FRAMEWORK, SECTOR_LIST
from src.utils import clean_text_spacy, linearize_table
from src.perfect_tables import extract_tables_from_pdf

# Field delimiter, presumably for the CSV outputs mentioned in the file header
# ("cache+CSV"); its usage is not visible in this chunk - TODO confirm.
CSV_DELIM = ";"
# Version tag, presumably stamped on cached candidate records so stale caches
# can be detected - NOTE(review): confirm against the cache code.
CAND_SCHEMA_VERSION = "v2"

#--------------------------------------------------------------------------
#Helpers de nettoyage si spaCy indispo
#--------------------------------------------------------------------------

def _clean_fallback(txt: str) -> str:
    txt = (txt or "").replace("\u00A0"," ").replace("\u2009"," ").replace("\u202F"," ")
    txt = re.sub(r"[ \t]+", " ", txt)
    txt = re.sub(r"\s{2,}", " ", txt)
    return txt.strip()

#--------------------------------------------------------------------------
#Classe principale
#--------------------------------------------------------------------------

class ESGRatingEngine:
    def __init__(self, use_fine_tuned_model: bool = False):
        """
        Load the models and initialise the engine's state.

        Loads, in this order:
        - the causal LLM ``deepseek-ai/DeepSeek-R1-0528-Qwen3-8B``: 4-bit NF4
          via bitsandbytes when CUDA is available, float32 on CPU otherwise;
        - the ``mixedbread-ai/mxbai-embed-large-v1`` sentence embedder, on CPU;
        - the ``BAAI/bge-reranker-v2-m3`` cross-encoder, on CPU, falling back
          to ``cross-encoder/ms-marco-MiniLM-L-2-v2``;
        - spaCy ``en_core_web_sm`` (``self.nlp`` is None if it cannot load).

        Donut DocVQA is only configured here; ``_ensure_donut_ready`` loads it
        lazily.

        Args:
            use_fine_tuned_model: not used by this constructor.

        Environment variables read:
        ESG_ENABLE_DONUT, ESG_DONUT_MODEL, ESG_DONUT_MAX_NEW_TOKENS,
        ESG_TABLE_LLM_LIMIT, ESG_PROJECT_ROOT, ESG_CACHE_DIR and
        ESG_CANDIDATES_CSV. PYTORCH_CUDA_ALLOC_CONF and TOKENIZERS_PARALLELISM
        are set only if they are not already defined.

        Side effects: creates the cache dir, the image dir and the parent
        dir of the candidates CSV.
        """
        # Allocator / tokenizer defaults must be in place before any model loads.
        os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
        os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

        self.device = "cuda" if (hasattr(torch, "cuda") and torch.cuda.is_available()) else "cpu"
        if self.device == "cuda":
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass
        print("--- Initializing ESG Rating Engine (VRAM-safe) ---")

        # --- Model names ---
        model_name = "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B"
        embedding_model_name = "mixedbread-ai/mxbai-embed-large-v1"

        # --- LLM (model_name above): 4-bit NF4 on CUDA, float32 on CPU ---
        print(f"Loading LLM: {model_name}")
        use_4bit = torch.cuda.is_available()
        if use_4bit:
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
            )
            # Memory caps for device_map="auto": CPU 48 GiB; GPU 0 gets its
            # total VRAM minus ~3 GiB, and never less than 10 GiB.
            max_memory = {"cpu": "48GiB"}
            try:
                total_gb = int(torch.cuda.get_device_properties(0).total_memory / (1024**3))
                cap = max(10, total_gb - 3)  # keep ~3 GiB of headroom
                max_memory[0] = f"{cap}GiB"
            except Exception:
                max_memory[0] = "19GiB"  # fixed cap when the device query fails
            llm_kwargs = dict(
                device_map="auto",
                quantization_config=bnb_config,
                torch_dtype=torch.float16,
                max_memory=max_memory,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
            )
        else:
            llm_kwargs = dict(
                device_map="cpu",
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
            )

        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

        # Force a ChatML chat template when the tokenizer has none, or one
        # that does not use <|im_start|>.
        # NOTE(review): DeepSeek-R1 distills may ship their own non-ChatML
        # template; overriding it may not match the format the model was
        # tuned on — confirm this is intended.
        if (self.tokenizer.chat_template is None) or ("<|im_start|>" not in str(self.tokenizer.chat_template)):
            self.tokenizer.chat_template = (
                "{% for message in messages %}"
                "{% if message['role'] == 'system' %}"
                "<|im_start|>system\n{{ message['content'] }}<|im_end|>\n"
                "{% elif message['role'] == 'user' %}"
                "<|im_start|>user\n{{ message['content'] }}<|im_end|>\n"
                "{% elif message['role'] == 'assistant' %}"
                "<|im_start|>assistant\n{{ message['content'] }}<|im_end|>\n"
                "{% endif %}"
                "{% endfor %}"
                "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
            )

        # Make EOS the ChatML end-of-turn token (and PAD too when PAD is
        # unset) so generation stops at <|im_end|>.
        im_end_id = self.tokenizer.convert_tokens_to_ids("<|im_end|>")
        if im_end_id is not None and im_end_id != -1:
            self.tokenizer.eos_token_id = im_end_id
            self.tokenizer.eos_token    = "<|im_end|>"
            if self.tokenizer.pad_token_id is None:
                self.tokenizer.pad_token_id = im_end_id
                self.tokenizer.pad_token    = "<|im_end|>"



        self.model = AutoModelForCausalLM.from_pretrained(model_name, **llm_kwargs)

        # --- Embeddings (CPU) ---
        print(f"\nLoading embedding model on CPU: {embedding_model_name}")
        self.embedding_model = SentenceTransformer(
            embedding_model_name,
            device="cpu",
            trust_remote_code=True
        )
        self.embedding_dim = self.embedding_model.get_sentence_embedding_dimension()
        self.embedding_batch_size = 16  # small batch to limit RAM use

        # --- Cross-encoder reranker (CPU); MiniLM fallback if BGE fails to load ---
        self.reranker_name = "BAAI/bge-reranker-v2-m3"
        try:
            print(f"\nLoading cross-encoder reranker on CPU: {self.reranker_name}")
            self.reranker = CrossEncoder(self.reranker_name, device="cpu")
        except Exception as e:
            print(f"  -> WARNING: {self.reranker_name} unavailable ({e}). Falling back to MiniLM.")
            self.reranker_name = "cross-encoder/ms-marco-MiniLM-L-2-v2"
            self.reranker = CrossEncoder(self.reranker_name, device="cpu")

        # --- Donut DocVQA (configured here, loaded lazily by _ensure_donut_ready) ---
        self.enable_donut = os.environ.get("ESG_ENABLE_DONUT", "1") == "1"
        self.donut_model_name = os.environ.get(
            "ESG_DONUT_MODEL",
            "naver-clova-ix/donut-base-finetuned-docvqa"
        )
        self.donut_device = "cuda" if torch.cuda.is_available() else "cpu"
        self.max_donut_new_tokens = int(os.environ.get("ESG_DONUT_MAX_NEW_TOKENS", "48"))
        self.donut_proc = None
        self.donut_model = None
        # Max number of candidate tables sent to Donut per _donut_vqa_extract
        # call (higher = more coverage, slower).
        self.table_llm_limit = int(os.environ.get("ESG_TABLE_LLM_LIMIT", "5"))

        # --- NLP + state ---
        print("\nLoading spaCy...")
        try:
            self.nlp = spacy.load("en_core_web_sm")
        except Exception:
            self.nlp = None  # spaCy unavailable; presumably callers fall back to regex cleaning

        # Stores / state (filled by _load_and_process_documents and _build_index)
        self.chunks = []
        self.chunk_doc_types = []
        self.tables = []
        self.all_embeddings = None
        self.faiss_index = None
        self.bm25_index = None
        self.analysis_state = {}
        self.page_texts = []

        # RAG knobs
        self.serialize_rows = 12
        self.serialize_cols = 8
        self.max_new_tokens_json = 512
        self.rag_top_candidates = 24
        self.rag_top_k_context = 3
        self._last_rag_context = ""

        # Allow TF32 kernels for CUDA matmul / cuDNN (faster, slightly lower precision).
        try:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
        except Exception:
            pass

        # Debug switches (presumably read by other methods of the class).
        self.debug = False
        self.debug_trace = {}
        self.debug_tables = True

        # --- Paths / cache ---
        self.project_root = os.environ.get("ESG_PROJECT_ROOT", os.path.abspath(os.getcwd()))
        self.cache_dir     = os.environ.get("ESG_CACHE_DIR", os.path.join(self.project_root, "cache"))
        self.image_dir     = os.path.join(self.cache_dir, "images")
        self.candidates_log_csv = os.environ.get(
            "ESG_CANDIDATES_CSV",
            os.path.join(self.project_root, "data", "esg_kpi_candidates_log.csv"),
        )

        # Create the output directories up front.
        os.makedirs(self.cache_dir, exist_ok=True)
        os.makedirs(self.image_dir, exist_ok=True)
        os.makedirs(os.path.dirname(self.candidates_log_csv), exist_ok=True)


        # Best-effort release of cached CUDA memory; errors (e.g. no CUDA) are ignored.
        try:
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        except Exception:
            pass

    # --------------------------------------------------------------------------
    # Utils génériques
    # --------------------------------------------------------------------------
    def _ensure_donut_ready(self) -> bool:
        """
        Charge Donut DocVQA (processor+model) en lazy, correctement configuré.
        - Conditioning sur image + prompt (géré au moment de generate).
        """
        if not self.enable_donut:
            return False

        # Valeurs par défaut si pas déjà définies
        if not getattr(self, "donut_model_name", None):
            self.donut_model_name = os.environ.get(
                "ESG_DONUT_MODEL",
                "naver-clova-ix/donut-base-finetuned-docvqa"
            )
        if not getattr(self, "donut_device", None):
            self.donut_device = "cuda" if (hasattr(torch, "cuda") and torch.cuda.is_available()) else "cpu"
        if not getattr(self, "max_donut_new_tokens", None):
            self.max_donut_new_tokens = int(os.environ.get("ESG_DONUT_MAX_NEW_TOKENS", "48"))

        if self.donut_proc is not None and self.donut_model is not None:
            return True

        try:
            print(f"Loading Donut DocVQA: {self.donut_model_name} on {self.donut_device}")
            self.donut_proc = DonutProcessor.from_pretrained(self.donut_model_name)
            # VisionEncoderDecoderModel attendu pour Donut
            self.donut_model = VisionEncoderDecoderModel.from_pretrained(self.donut_model_name)
            self.donut_model = self.donut_model.to(self.donut_device)
            self.donut_model.eval()
            print(f"Donut ready on {self.donut_device}: {self.donut_model_name}")
            return True
        except Exception as e:
            print(f"WARNING: cannot load Donut ({e}). Disabling it.")
            self.enable_donut = False
            self.donut_proc = None
            self.donut_model = None
            return False

    def _find_df_context_phrase(self, df: pd.DataFrame, raw_answer: str | None, value: float | None, unit_hint: str | None) -> str:
        """
        Recompose une phrase contextuelle à partir du DF du tableau :
        - Tente d'identifier la cellule contenant la valeur extraite (en utilisant 'raw_answer' si dispo).
        - Construit: 'row_label — col_header: cell_text'.
        - Pas de fallback sur le nom du KPI (si rien trouvé -> retourne "").
        """
        try:
            if df is None or df.empty:
                return ""

            # Normaliser DF en str
            sdf = df.astype(str)

            # Construit une liste de motifs plausibles à retrouver dans la cellule
            import re
            patterns: list[str] = []
            if raw_answer:
                # récupère tous les tokens numériques (avec % éventuel) du texte de Donut
                toks = re.findall(r"[0-9][0-9\.,\s]*%?", raw_answer)
                # nettoyer
                toks = [t.strip() for t in toks if t and any(ch.isdigit() for ch in t)]
                patterns.extend(toks)

            # si on a une value numérique, générer quelques variantes textuelles possibles
            if isinstance(value, (int, float)):
                v = float(value)
                # formes sans pourcentage
                patterns.extend([
                    f"{int(v)}" if abs(v - int(v)) < 1e-6 else f"{v}",
                    f"{int(v)}",
                    f"{v:.1f}",
                    f"{v:.2f}",
                    f"{v:.3f}",
                ])
                # si % suspecté, générer variantes
                if unit_hint and "%" in unit_hint:
                    base = f"{int(v)}" if abs(v - int(v)) < 1e-6 else f"{v}"
                    patterns.extend([
                        base + "%",
                        base + " %",
                        f"{v:.1f}%",
                        f"{v:.1f} %",
                        f"{v:.2f}%",
                        f"{v:.2f} %",
                    ])

            # déduire aussi formes avec virgule
            extended = []
            for p in patterns:
                if "." in p:
                    extended.append(p.replace(".", ","))
                if "," in p:
                    extended.append(p.replace(",", "."))
            patterns.extend(extended)

            # cherche la première occurrence dans le DF
            best = None
            for i in range(sdf.shape[0]):
                for j in range(sdf.shape[1]):
                    cell = sdf.iat[i, j]
                    c_norm = str(cell)
                    hit = False
                    for pat in patterns:
                        # recherche tolérante (ignore spaces fines)
                        pat_norm = str(pat).replace("\u00A0", " ").strip()
                        if pat_norm and pat_norm in c_norm:
                            hit = True
                            break
                    if hit:
                        row_label = ""
                        try:
                            if j != 0 and sdf.shape[1] > 0:
                                row_label = sdf.iat[i, 0]
                        except Exception:
                            row_label = ""
                        col_header = str(df.columns[j]) if j < len(df.columns) else f"Col{j+1}"
                        # phrase compacte
                        phrase = f"{row_label} — {col_header}: {cell}".strip()
                        best = phrase
                        break
                if best:
                    break

            return best or ""
        except Exception:
            return ""

    def _donut_answer(self, img: Image.Image, question: str) -> str:
        """
        Ask Donut DocVQA one question about an image and return the answer text.

        Uses the DocVQA task prompt
        ``<s_docvqa><s_question>{question}</s_question><s_answer>`` with greedy
        decoding. Requires ``self.donut_proc`` and ``self.donut_model`` to be
        loaded (see ``_ensure_donut_ready``).

        Args:
            img: image to query (page or table crop).
            question: natural-language question.

        Returns:
            The answer text with Donut tags removed and whitespace collapsed.
            May be empty.
        """
        tokenizer = self.donut_proc.tokenizer
        # Use the model's actual device: _donut_vqa_extract may have moved it to CPU.
        device = self.donut_model.device
        prompt = f"<s_docvqa><s_question>{question}</s_question><s_answer>"
        pixel_values = self.donut_proc(img, return_tensors="pt").pixel_values.to(device)
        decoder_input_ids = tokenizer(
            prompt, add_special_tokens=False, return_tensors="pt"
        ).input_ids.to(device)

        with torch.inference_mode():
            outputs = self.donut_model.generate(
                pixel_values=pixel_values,
                decoder_input_ids=decoder_input_ids,
                max_length=decoder_input_ids.shape[-1] + self.max_donut_new_tokens,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                use_cache=True,
                num_beams=1,  # greedy; early_stopping only applies to beam search
                bad_words_ids=[[tokenizer.unk_token_id]],
            )

        # Donut's task tags (<s_answer>, ...) are special tokens. Decoding with
        # skip_special_tokens=True would erase the markers needed to locate the
        # answer, so keep them and strip only EOS / PAD.
        seq = tokenizer.batch_decode(outputs, skip_special_tokens=False)[0]
        for special in (tokenizer.eos_token, tokenizer.pad_token):
            if special:
                seq = seq.replace(special, "")

        if "<s_answer>" in seq:
            ans = seq.split("<s_answer>", 1)[-1].split("</s_answer>", 1)[0]
        else:
            # The model sometimes returns the text right after the prompt.
            ans = seq.replace(prompt, "")
        # Remove any leftover Donut tags (<s>, </s>, <s_...>, </s_...>).
        ans = re.sub(r"</?s(?:_[^>]*)?>", " ", ans)
        return re.sub(r"\s+", " ", ans).strip()


    def _donut_vqa_extract(self, kpi_name: str, kpi_config: dict, allowed_docs: set | None = None) -> list[dict]:
        """
        Run Donut DocVQA over candidate tables and return a LIST of KPI candidates.

        Tables come from ``self._select_candidate_tables``, at most
        ``self.table_llm_limit`` of them. For each table crop:
        - Donut is asked ``kpi_config["question"]`` (default: ``kpi_name``);
        - the last number in the answer becomes the value;
        - year and unit are inferred heuristically from the answer and the
          table text.

        At most one candidate is produced per table, and only if
        ``_validate_value_generic`` accepts it.

        Args:
            kpi_name: KPI identifier; also the fallback question.
            kpi_config: KPI settings; only ``"question"`` is read here.
            allowed_docs: optional doc_type filter forwarded to table selection.

        Returns:
            A list of candidate dicts, empty if Donut is unavailable or nothing
            plausible was found. Per-table errors are printed and skipped.
        """
        out = []
        if not self._ensure_donut_ready():
            return out

        cand_tables = self._select_candidate_tables(
            kpi_name, kpi_config, limit=self.table_llm_limit, allowed_docs=allowed_docs
        )
        if not cand_tables:
            return out

        question = kpi_config.get("question", kpi_name)
        q_norm = question.strip()
        task_prompt = f"<s_docvqa><s_question>{question}</s_question><s_answer>"
        tokenizer = self.donut_proc.tokenizer
        tag_re = re.compile(r"</?s(?:_[^>]*)?>")     # Donut tags: <s>, </s>, <s_x>, </s_x>
        number_re = re.compile(r"\d[\d.,]*\d|\d")    # numbers, including single digits
        year_re = re.compile(r"\b(20\d{2})\b")

        def _generate(pixel_values, input_ids) -> str:
            """Run greedy Donut generation within the decoder's position limit; return decoded text."""
            input_len = int(input_ids.shape[-1])
            max_pos = int(getattr(self.donut_model.config.decoder, "max_position_embeddings", 512))
            if input_len >= max_pos - 4:
                # Keep the tail of the prompt: it holds the <s_answer> trigger.
                input_ids = input_ids[:, -(max_pos - 8):]
                input_len = int(input_ids.shape[-1])
            allowed_new = max(1, min(int(getattr(self, "max_donut_new_tokens", 48)), max_pos - input_len - 2))
            with torch.inference_mode():
                generated = self.donut_model.generate(
                    pixel_values=pixel_values,
                    decoder_input_ids=input_ids,
                    max_new_tokens=allowed_new,
                    num_beams=1,
                    do_sample=False,
                    use_cache=True,
                    eos_token_id=tokenizer.eos_token_id,
                    pad_token_id=tokenizer.pad_token_id,
                )
            seqs = generated.sequences if hasattr(generated, "sequences") else generated
            # Keep special tokens: the task tags are special and mark the answer.
            text = tokenizer.batch_decode(seqs, skip_special_tokens=False)[0]
            for special in (tokenizer.eos_token, tokenizer.pad_token):
                if special:
                    text = text.replace(special, "")
            return text.strip()

        for T in cand_tables:
            img_path = T.get("image_crop_path")
            if not img_path:
                continue

            # Image: load, close the file, downscale very large crops.
            try:
                with Image.open(img_path) as im:
                    img = im.convert("RGB")
                longest = max(img.size)
                if longest > 2200:
                    s = 2200.0 / longest
                    img = img.resize((int(img.width * s), int(img.height * s)), Image.BICUBIC)
            except Exception:
                continue

            # Inputs: tokenize the prompt separately, without BOS/EOS. Calling
            # the processor with text= returns it under "labels" in some
            # transformers versions and adds special tokens.
            try:
                device = self.donut_model.device
                pixel_values = self.donut_proc(img, return_tensors="pt").pixel_values.to(device)
                input_ids = tokenizer(
                    task_prompt, add_special_tokens=False, return_tensors="pt"
                ).input_ids.to(device)
            except Exception as e:
                print(f"Donut preprocessing failed: {e}")
                continue

            try:
                raw = _generate(pixel_values, input_ids)
            except Exception as e:
                msg = str(e).lower()
                if "device-side assert" in msg or "cuda error" in msg:
                    print("Donut CUDA assert — retry on CPU with safe generation.")
                    try:
                        self.donut_model = self.donut_model.to("cpu")
                        self.donut_device = "cpu"  # keep other Donut helpers consistent
                        raw = _generate(pixel_values.to("cpu"), input_ids.to("cpu"))
                    except Exception as e2:
                        print(f"Donut retry on CPU failed: {e2}")
                        continue
                else:
                    print(f"Donut failed on table (page {T.get('page')}): {e}")
                    continue

            # Answer text: between <s_answer> and </s_answer>, else after the question.
            if "<s_answer>" in raw:
                ans = raw.split("<s_answer>", 1)[-1].split("</s_answer>")[0]
            else:
                ans = raw
                pos = ans.find(q_norm)
                if pos != -1:
                    ans = ans[pos + len(q_norm):]
            ans = re.sub(r"\s+", " ", tag_re.sub(" ", ans)).strip()

            # Value: the last number in the answer.
            numbers = number_re.findall(ans)
            if not numbers:
                continue
            val_str = numbers[-1]
            val = self._coerce_float(val_str)
            if val is None or val < 0:
                continue

            # Year: look first within 50 chars of the value, then anywhere in
            # answer + table text. The latest year wins.
            table_text = T.get("text") or ""
            pos = ans.rfind(val_str)
            window = 50
            neighborhood = ans[max(0, pos - window): pos + len(val_str) + window]
            year_matches = year_re.findall(neighborhood)
            if not year_matches:
                year_matches = year_re.findall((ans + " " + table_text).lower())
            year = max(map(int, year_matches)) if year_matches else None

            # Unit: first keyword hit in answer + table text.
            unit_search_text = (ans + " " + table_text).lower()
            if "ktco2e" in unit_search_text or "kt co2e" in unit_search_text:
                unit = "ktCO2e"
            elif "tco2e" in unit_search_text or "tons co2eq" in unit_search_text:
                unit = "tCO2e"
            elif "twh" in unit_search_text:
                unit = "TWh"
            elif "gwh" in unit_search_text:
                unit = "GWh"
            elif "%" in unit_search_text:
                unit = "%"
            else:
                unit = None

            if not self._validate_value_generic(unit, val):
                continue

            ctx = self._find_df_context_phrase(T.get("df"), raw_answer=ans, value=val, unit_hint=unit)

            fval = float(val)
            val_canon = str(int(fval)) if fval.is_integer() else str(fval).rstrip("0").rstrip(".")

            # answer_text = context phrase, then value / unit / year when not already present.
            parts = [ctx.strip()] if ctx else []
            year_str = str(int(year)) if year is not None else ""
            for piece in (val_canon, (unit or "").strip(), year_str):
                if piece and piece not in " — ".join(parts):
                    parts.append(piece)
            answer_text = " — ".join(parts).strip()

            out.append({
                "source": "Donut-DocVQA",
                "value": fval,
                "unit": unit,
                "year": year,
                "reason": f"Donut on table fp={T.get('fingerprint','')}",
                "evidence": f"table on page {T.get('page')}",
                "answer_text": answer_text,
                "doc_type": T.get("doc_type",""),
                "page": T.get("page"),
                "table_fp": T.get("fingerprint"),
                "image_crop_path": T.get("image_crop_path"),
                "bbox": T.get("bbox"),
                "bbox_page_pts": T.get("bbox_page_pts"),
                "judge_reason": None,
                "judge_confidence": None,
            })
        return out

    def _ce_score(self, query: str, text: str) -> float:
        try:
            s = float(self.reranker.predict([(query, text)])[0])
            return s
        except Exception:
            return 0.0

    def _coerce_float(self, x):
        if x is None:
            return None
        if isinstance(x, (int, float)):
            return float(x)
        raw = str(x)
        raw = raw.replace("\u00A0"," ").replace("\u2009"," ").replace("\u202F"," ")
        s = raw.strip()

        # Accepter le groupage FR "1 234 567" -> "1234567"
        if re.fullmatch(r"\d{1,3}(?:\s\d{3})+(?:[.,]\d+)?", s):
            s = s.replace(" ", "")
        # Si espaces entre nombres mais que ce n'est PAS un groupage de milliers, on rejette (ex: "81 0 25")
        elif re.search(r"\d+\s+\d+", s):
            return None
        t = s
        t = (t.replace("\u00A0","").replace("\u2009","").replace("\u202F","").replace(" ","").replace("_",""))
        t = t.replace("%","")
        if t.count(",") > 0 and t.count(".") == 0:
            t = t.replace(",", ".")
        else:
            t = t.replace(",", "")
        try:
            return float(t)
        except Exception:
            return None

    def _validate_value_generic(self, unit: str | None, value: float | None) -> bool:
        if value is None or not isinstance(value, (int, float)):
            return False
        v = float(value)
        if not np.isfinite(v):
            return False
        if v < 0:
            return False

        unit_l = (unit or "").lower()
        # Si l'unité contient %, alors la valeur doit être dans [0,100]
        if "%" in unit_l and not (0 <= v <= 100):
            return False

        # ✅ On NE rejette plus les valeurs sans unité (elles seront triées au Judge)
        return True



    def _as_json_str(self, obj):
        try:
            return json.dumps(obj, ensure_ascii=False)
        except Exception:
            return str(obj)

    def _flatten_candidate_for_csv(self, pillar: str, kpi_name: str, question: str, cand: dict) -> dict:
        """
        Aplati un candidat (toutes sources) vers un dictionnaire CSV stable (schéma v3).
        Cette version n'a plus besoin du score CE en entrée.
        """
        # Calcul du score CE directement ici pour le logging
        ev = " ".join([str(cand.get(k) or "") for k in ["answer_text", "reason", "evidence"]]).strip()
        try:
            ce_ev = float(self._ce_score(question, ev))
        except Exception:
            ce_ev = 0.0

        # Format de base unifié (schéma "v3" implicite)
        base = {
            "schema_version": "v3",
            "pillar": pillar,
            "kpi_name": kpi_name,
            "kpi_key": f"{pillar}_{kpi_name}",
            "question": question,
            "kpi_source": cand.get("source", ""),
            "value": cand.get("value"),
            "unit": cand.get("unit"),
            "year": cand.get("year"), # ✅ Ajout du champ année
            "answer_text": cand.get("answer_text") or str(cand.get("value", "")),
            "ce_evidence": ce_ev, # On le garde pour l'analyse
            "evidence_preview": (cand.get("evidence", "") or cand.get("reason", ""))[:300],
            # Champs communs
            "doc_type": cand.get("doc_type", ""),
            "page": cand.get("page"),
            "bbox_json": self._as_json_str(cand.get("bbox")),
            "table_fp": cand.get("table_fp", ""),
            "image_crop_path": cand.get("image_crop_path", ""),
            # ✅ Ajout des champs du Judge (seront vides pour les candidats bruts)
            "judge_reason": cand.get("judge_reason"),
            "judge_confidence": cand.get("judge_confidence"),
        }
        return base

    def _clean_table_df(self, df: pd.DataFrame) -> pd.DataFrame | None:
        """
        Nettoie/normalise un DataFrame issu du parseur (PyMuPDF light tables / Camelot).
        Version sans applymap (évite le FutureWarning).
        """
        if df is None:
            return None
        try:
            # Aplatit les entêtes si MultiIndex
            if isinstance(df.columns, pd.MultiIndex):
                df.columns = [
                    " ".join([str(x).strip() for x in tup if str(x).strip()])
                    for tup in df.columns.values
                ]
            else:
                df.columns = [str(c).strip() for c in df.columns]

            # Normalisations de base
            df = df.replace(r"\s+", " ", regex=True)
            df = df.replace({"—": "", "–": "", "N/A": "", "NA": "", "n/a": ""})

            # ⚠️ Remplace applymap par map (pandas ≥2.1)
            df = df.map(lambda x: str(x).strip())

            # Drop vide
            df = df.replace("", np.nan)
            df = df.dropna(how="all", axis=0)
            df = df.dropna(how="all", axis=1)
            if df.empty:
                return None

            # Élimine éventuelles lignes dupliquant exactement l'entête
            header_tuple = tuple(df.columns)
            to_drop = []
            for i, row in df.iterrows():
                if tuple(row.values) == header_tuple:
                    to_drop.append(i)
            if to_drop:
                df = df.drop(index=to_drop)

            # Garde-fous
            if df.shape[1] > 40:
                df = df.iloc[:, :40]
            if df.shape[0] > 1000:
                df = df.iloc[:1000, :]

            return df if not df.empty else None
        except Exception:
            return None
    # --------------------------------------------------------------------------
    # Chargement & preprocessing (PyMuPDF + PerfectTables)
    # --------------------------------------------------------------------------
    def _load_and_process_documents(self, document_paths: dict):
        """
        Load every document and fill the retrieval stores.

        For each ``doc_type -> path`` in *document_paths*, empty paths are
        skipped and missing files are reported and skipped. Otherwise:

        - ``.pdf``, page text: extracted with PyMuPDF and cleaned with
          ``clean_text_spacy``. Pages with 5 words or fewer are ignored. Each
          kept page is added to ``self.chunks`` / ``self.chunk_doc_types`` /
          ``self.page_texts``.
        - ``.pdf``, tables: taken only from ``extract_tables_from_pdf``
          (PerfectTables). Each table that survives ``_clean_table_df``
          becomes a record in ``self.tables``. Its linearized rows, prefixed
          with the page number, are added to ``self.chunks``.
        - ``.txt``: the whole file is cleaned and added as a single chunk.

        ``self.page_texts``, ``self.chunks``, ``self.tables``,
        ``self.chunk_doc_types`` and ``self.analysis_state`` are reset first.
        A summary and a table sanity report are printed at the end.
        """
        print("\n--- Loading & Processing Documents (PerfectTables + PDF text) ---")
        self.page_texts = []
        self.chunks = []
        self.tables = []
        self.chunk_doc_types = []
        self.analysis_state = {
            "sources_provided": [],
            "total_pages": 0,
            "detected_sector": "unknown",
        }

        total_tables_found = 0

        for doc_type, path in document_paths.items():
            if not path:
                continue
            if not os.path.exists(path):
                print(f"  >> WARNING: missing file for '{doc_type}': {path}")
                continue

            print(f"  -> Processing {doc_type} ({os.path.basename(path)})")
            self.analysis_state["sources_provided"].append(doc_type)
            lower_path = path.lower()

            if lower_path.endswith(".pdf"):
                # 1) TEXT (for RAG / sector detection). The context manager
                #    closes the document even if an error occurs mid-way.
                try:
                    with fitz.open(path) as doc:
                        pages = len(doc)
                        for pidx in range(pages):
                            try:
                                txt = doc[pidx].get_text("text") or ""
                            except Exception:
                                txt = ""
                            if txt and len(txt.split()) > 5:
                                chunk = clean_text_spacy(txt)
                                if chunk:
                                    self.chunks.append(chunk)
                                    self.chunk_doc_types.append(doc_type)
                                    self.page_texts.append((pidx + 1, doc_type, chunk))
                    self.analysis_state["total_pages"] += pages
                except Exception as e:
                    print(f"  ERROR: cannot open PDF for text: {e}")

                # 2) TABLES (PerfectTables only)
                try:
                    pt_tables = extract_tables_from_pdf(path, self.image_dir, doc_type)
                except Exception as e:
                    print(f"  ERROR: PerfectTables failed on {os.path.basename(path)}: {e}")
                    pt_tables = []

                for t in pt_tables:
                    df = self._clean_table_df(t.get("df"))
                    if df is None or df.empty:
                        continue
                    page_no = int(t.get("page") or 0)

                    # Crop width/height: from the bbox, else from the crop image,
                    # else a 1000x1000 guard value.
                    try:
                        bx = t.get("bbox") or []
                        W = float(bx[2] - bx[0]); H = float(bx[3] - bx[1])
                        if not (W > 0 and H > 0):
                            raise ValueError("bad bbox")
                    except Exception:
                        try:
                            with Image.open(t.get("image_crop_path")) as im:
                                W, H = im.size
                        except Exception:
                            W, H = 1000.0, 1000.0

                    # Text preview (used when the parser supplies no text)
                    headers = " | ".join(map(str, list(df.columns)))
                    first_col = " | ".join(map(str, list(df.iloc[:, 0].astype(str).values[:10]))) if df.shape[1] > 0 else ""
                    table_text_representation = t.get("text") or f"Table: {headers} || first_col: {first_col}"

                    self.tables.append({
                        "page": page_no,
                        "df": df,
                        "text": table_text_representation,
                        "doc_type": doc_type,
                        "bbox": [0.0, 0.0, float(W), float(H)],        # (0,0,W,H) = crop coordinates
                        "cells": t.get("cells", []),                   # not needed by Donut; kept for debugging
                        "image_crop_path": t.get("image_crop_path"),
                        "fingerprint": t.get("fingerprint") or self._fingerprint_table(df),
                        "bbox_page_pts": t.get("bbox_page_pts"),
                    })

                    # Linearized rows for the text index.
                    # NOTE(review): no separator is inserted after the prefix;
                    # confirm linearize_table's output starts with whitespace.
                    prefix = f"[DATA TABLE from page {page_no}]"
                    prefixed_lin = [prefix + l for l in linearize_table(df, source_page=page_no)]
                    self.chunks.extend(prefixed_lin)
                    self.chunk_doc_types.extend([doc_type] * len(prefixed_lin))

                    total_tables_found += 1

            elif lower_path.endswith(".txt"):
                # errors="replace": one bad byte must not abort the whole load.
                with open(path, "r", encoding="utf-8", errors="replace") as f:
                    txt = clean_text_spacy(f.read())
                if txt:
                    self.chunks.append(txt)
                    self.chunk_doc_types.append(doc_type)

        print(f"--- Done. {len(self.chunks)} knowledge chunks and {len(self.tables)} tables collected (PerfectTables={total_tables_found}). ---")
        self._table_sanity_report()

    def _table_sanity_report(self, sample: int = 5):
        n = len(self.tables)
        if n == 0:
            print(">>> TABLE SANITY: No tables parsed.")
            return
        numeric_tables = 0
        for t in self.tables[:sample]:
            df = t["df"]
            txt = " | ".join(map(str, df.columns))
            sub = df.head(30).astype(str).values.ravel().tolist()
            nums = sum(1 for s in sub if re.search(r"\d", s))
            if nums > 0: numeric_tables += 1
            print(f"    • Page {t['page']}: shape={df.shape}, header≈ [{txt[:80]}...] nums_in_head30={nums}")
        print(f">>> TABLE SANITY: {n} tables total, {numeric_tables}/{min(n, sample)} with numeric content in preview.")

    # --------------------------------------------------------------------------
    # Index hybride (FAISS + BM25)
    # --------------------------------------------------------------------------
    def _build_index(self):
        """
        Construit l'index hybride en limitant la mémoire:
        - embeddings encodés en batch (CPU) -> float32
        - normalisation L2 + FAISS IP
        """
        if not self.chunks and not self.tables:
            print("ERROR: Cannot build index (no content).")
            return

        print("\n--- Building Hybrid Index (Semantic FAISS + Lexical BM25) ---")
        if self.chunks:
            print(f"  -> Indexing {len(self.chunks)} text chunks (batch={self.embedding_batch_size})...")
            embs = self.embedding_model.encode(
                self.chunks,
                batch_size=self.embedding_batch_size,
                show_progress_bar=True,
            )
            embs = np.ascontiguousarray(embs, dtype=np.float32)
            faiss.normalize_L2(embs)
            self.all_embeddings = embs
            self.faiss_index = faiss.IndexFlatIP(self.embedding_dim)
            self.faiss_index.add(self.all_embeddings)

            tokenized_corpus = [doc.lower().split(" ") for doc in self.chunks]
            self.bm25_index = BM25Okapi(tokenized_corpus)

        print(f"--- Hybrid Index ready. ---")

    # --------------------------------------------------------------------------
    # Détection secteur (LLM), limité aux pages 1–5
    # --------------------------------------------------------------------------

    def _detect_industry_sector(self):
        if not self.page_texts:
            print("No content loaded; sector detection skipped.")
            self.analysis_state["detected_sector"] = "unknown"
            return

        print("\n--- Industry Sector Detection (pages 1–5) ---")
        first_pages = [t for t in self.page_texts if t[0] <= 5]
        context = " ".join(t[2] for t in first_pages[:10])
        sector_choices = ", ".join(SECTOR_LIST)
        question = f"Based on the text, what is the primary industry sector? Choose from: {sector_choices}."

        # Prompt de "guidage"
        format_json = '{"sector":"..."}'
        system_message = (
            "You are an expert financial analyst. First, think step-by-step in a <think> block to determine the sector. "
            "Then, provide the answer as a single, compact JSON object and nothing else."
        )
        user_message = (
            f"SECTOR LIST: {sector_choices}\n\n"
            f"CONTEXT:\n{context}\n\n"
            f"QUESTION: {question}\n\n"
            f"FINAL ANSWER FORMAT:\n{format_json}"
        )

        # Logique de génération standardisée
        raw_output = self._generate_text(system_message, user_message, max_new_tokens=500)
        final_answer = self._strip_think(raw_output)
        js = self._extract_first_json_object(final_answer)

        detected = "unknown"
        if js:
            try:
                obj = self._json_loads_lenient(js)
                if obj and isinstance(obj, dict) and "sector" in obj:
                    # On s'assure que le secteur détecté est bien dans la liste autorisée
                    cand = str(obj["sector"]).strip()
                    for s in SECTOR_LIST:
                        if s.lower() == cand.lower():
                            detected = s
                            break
            except Exception:
                pass

        # Fallback (inchangé)
        if detected == "unknown" and final_answer:
            txt = final_answer.lower()
            for sector in SECTOR_LIST:
                if sector.lower() in txt:
                    detected = sector
                    break

        self.analysis_state["detected_sector"] = detected
        print(f"  -> Detected sector: '{detected}'")
    # --------------------------------------------------------------------------
    # Sélection de tables candidates (PerfectTables embeddings)
    # --------------------------------------------------------------------------
    def _select_candidate_tables(self, kpi_name: str, kpi_config: dict, limit: int = 3, allowed_docs: set | None = None):
        if not self.tables:
            return []

        kpi_question = kpi_config.get("question", kpi_name)

        # Embedding de la question
        q_vec = self.embedding_model.encode(kpi_question, convert_to_numpy=True)
        q_vec = np.asarray(q_vec, dtype=np.float32)
        faiss.normalize_L2(q_vec.reshape(1, -1))

        # ⚠️ normalise le filtre et le doc_type comme dans ton debug
        allow_norm = {d.strip().lower() for d in (allowed_docs or set())}

        scored = []
        for t in self.tables:
            t_doc = str(t.get("doc_type","")).strip().lower()
            if allow_norm and t_doc not in allow_norm:
                continue

            df = t.get("df")
            if df is None or getattr(df, "empty", False):
                continue

            # même fallback que dans ton calcul manuel
            text_repr = t.get("text") or (" | ".join(map(str, df.columns)))

            # embedding du tableau (cache sur l’objet table)
            vec = t.get("embedding")
            if vec is None:
                vec = self.embedding_model.encode(text_repr, convert_to_numpy=True)
                vec = np.asarray(vec, dtype=np.float32)
                faiss.normalize_L2(vec.reshape(1, -1))
                t["embedding"] = vec

            sem = float(np.dot(q_vec.ravel(), vec.ravel()))
            header = " | ".join(map(str, df.columns))
            first_col = " | ".join(map(str, df.iloc[:,0].astype(str).values[:20])) if df.shape[1] > 0 else ""
            ce = self._ce_score(kpi_question, (t.get("text") or (header + " | " + first_col)))
            dens = self._numeric_density(df)
            latest_bonus = 0.6 if (self._find_latest_year_col(df) is not None) else 0.0

            score = sem * 10.0 + ce * 2.0 + dens * 1.0 + latest_bonus
            scored.append((score, t))

        ranked = sorted(scored, key=lambda x: x[0], reverse=True)
        return [t for _, t in ranked[:max(1, limit)]]


    def _numeric_density(self, df: pd.DataFrame) -> float:
        vals = df.astype(str).values.ravel().tolist()
        if not vals: return 0.0
        nums = sum(1 for v in vals if re.search(r"\d", v or ""))
        return nums / float(len(vals))

    def _find_latest_year_col(self, df: pd.DataFrame):
        candidates = []
        for j, col in enumerate(df.columns):
            years = re.findall(r"(20\d{2})", str(col))
            if years: candidates.append((j, max(map(int, years))))
        if not candidates: return None
        return max(candidates, key=lambda x: x[1])


    def _log_kpi_candidates(self, pillar: str, kpi_name: str, question: str, candidates: list[dict]):
        """
        Log des candidats (schéma v3). Gère l'écriture et la mise à jour de l'en-tête du CSV.
        """
        if not candidates:
            return

        try:
            # Prépare les lignes pour le CSV en utilisant la nouvelle fonction flatten
            rows = [self._flatten_candidate_for_csv(pillar, kpi_name, question, c) for c in candidates]

            # S'assure que le dossier existe
            os.makedirs(os.path.dirname(self.candidates_log_csv), exist_ok=True)

            # Définit l'ordre des colonnes une fois pour toutes pour la cohérence
            # On utilise les clés de la première ligne comme référence
            header = list(rows[0].keys())

            # Vérifie si le fichier a besoin d'une en-tête
            write_header = not os.path.exists(self.candidates_log_csv)

            with open(self.candidates_log_csv, "a", encoding="utf-8", newline="") as f:
                writer = csv.DictWriter(f, fieldnames=header, delimiter=CSV_DELIM, extrasaction='ignore')

                if write_header:
                    writer.writeheader()

                writer.writerows(rows)

        except Exception as e:
            print(f"[WARN] CSV logging failed for {pillar}_{kpi_name}: {e}")
    # --------------------------------------------------------------------------
    # RAG (Hybride + BGE) → JSON
    # --------------------------------------------------------------------------
    def _rerank_doc_ids(self, question: str, doc_ids: list[int], top_k: int) -> list[int]:
        if not doc_ids:
            return []
        texts = [self.chunks[i] for i in doc_ids]
        try:
            scores = self.reranker.predict([(question, t) for t in texts])
            order = list(np.argsort(scores)[::-1])
            return [doc_ids[i] for i in order[:top_k]]
        except Exception:
            return doc_ids[:top_k]

    def _query_rag_hybrid(self, kpi_name: str, kpi_config: dict, k: int = 3, allowed_docs: set | None = None) -> list[dict]:
        """
        RAG Hybride multi-candidats (STRICT). Génère ≤ k candidats (1 par contexte reranké),
        avec règles strictes + sanity-check "valeur présente dans le contexte".
        """
        out = []
        try:
            question_long = kpi_config.get("question", kpi_name)
            search_query  = kpi_config.get("search_query", question_long)

            # ---- 1) Retrieve BM25 + FAISS
            indices = []
            if self.bm25_index is not None and self.chunks:
                toks = search_query.lower().split(" ")
                bm25_scores = self.bm25_index.get_scores(toks)
                bm25_top = np.argsort(bm25_scores)[::-1][:max(self.rag_top_candidates//2, k)]
                indices.extend(bm25_top.tolist())

            if self.faiss_index is not None and self.chunks:
                q_emb = self.embedding_model.encode([search_query], convert_to_numpy=True)
                q_emb = np.ascontiguousarray(q_emb, dtype=np.float32)
                faiss.normalize_L2(q_emb)
                _, faiss_top = self.faiss_index.search(q_emb, max(self.rag_top_candidates//2, k))
                indices.extend(faiss_top[0].tolist())

            # dédoublonne + filtre types
            indices = list(dict.fromkeys(indices))
            if allowed_docs:
                indices = [i for i in indices if self.chunk_doc_types[i] in allowed_docs]
            if not indices:
                return []

            # ---- 2) Rerank (BGE) → top ≤ k contextes
            top_ids = self._rerank_doc_ids(search_query, indices, top_k=max(k, self.rag_top_k_context))
            top_ids = top_ids[:max(1, k)]

            # ---- 3) Prompt strict (with-think) + génération
            JSON_FORMAT = (
                '{"reasoning":"<short evidence>", '
                '"current_value": 1234.5, '
                '"unit": "tCO2e|ktCO2e|null", '
                '"year": 2024, '
                '"is_present": true}'
            )
            STRICT_RULES = (
                "Rules:\n"
                "1) If you cannot find an explicit number for the QUESTION in CONTEXT, return exactly: "
                '{"reasoning":"not found", "current_value": null, "unit": null, "year": null, "is_present": false}\n'
                "2) Copy digits verbatim from CONTEXT. Do not reorder, do not guess. "
                "If you write 15,571 in reasoning, current_value must be 15571 (commas/spaces removed).\n"
                "3) Year must be a 4-digit year seen in CONTEXT.\n"
                "4) unit ∈ {tCO2e, ktCO2e} or null.\n"
                "5) Output ONE JSON object only. First char must be '{'. No extra text."
            )
            system_message = (
                "You are a precise data extraction expert.\n"
                "First, think step-by-step inside a <think> block (≤60 tokens) to locate the number in the CONTEXT.\n"
                "Then output exactly ONE JSON object and nothing else.\n" + STRICT_RULES
            )

            num_pat  = re.compile(r"\b\d{1,3}(?:[ ,.\u00A0]\d{3})*(?:[.,]\d+)?\b")
            year_pat = re.compile(r"\b20\d{2}\b")

            def _numbers_in_ctx(ctx: str) -> set[float]:
                spans = num_pat.findall(ctx or "")
                vals = set()
                for s in spans:
                    v = self._coerce_float(s)
                    if v is not None:
                        vals.add(float(v))
                return vals

            for idx in top_ids:
                context = self.chunks[idx]
                doc_type = self.chunk_doc_types[idx]

                user_message = f"CONTEXT:\n{context}\n\nQUESTION: {search_query}\n\nFINAL ANSWER FORMAT:\n{JSON_FORMAT}"

                raw_output = self._generate_text(system_message, user_message, max_new_tokens=512)
                final_answer = self._strip_think(raw_output)
                js_str = self._extract_first_json_object(final_answer)
                if not js_str:
                    continue

                data = self._json_loads_lenient(js_str)
                if not isinstance(data, dict):
                    continue

                # respect du fallback "not found"
                isp = str(data.get("is_present", "")).strip().lower()
                if isp in {"false", "0", "no"}:
                    continue

                # parse strict du JSON (sans rescue)
                val = self._coerce_float(data.get("current_value"))
                if val is None:
                    continue

                unit = data.get("unit")
                if isinstance(unit, str):
                    u = unit.lower().replace(" ", "")
                    if "ktco2" in u:
                        unit = "ktCO2e"
                    elif "tco2e" in u or "ton" in u or "tons" in u:
                        unit = "tCO2e"
                    else:
                        unit = None
                else:
                    unit = None

                year = None
                y = data.get("year")
                if isinstance(y, (int, float)):
                    year = int(y)
                elif isinstance(y, str):
                    m = re.search(r"\b20\D?(\d{2})\b", y)
                    if m:
                        year = int("20" + m.group(1))

                # ---- 4) Sanity: valeur & année doivent être dans le CONTEXT
                present_nums = _numbers_in_ctx(context)
                val_ok = (val in present_nums)
                # tolérer kt ↔ t
                if not val_ok and unit == "ktCO2e" and (val * 1000.0) in present_nums:
                    val_ok = True
                if not val_ok and unit == "tCO2e" and (val / 1000.0) in present_nums:
                    val_ok = True
                if not val_ok:
                    continue

                if unit not in (None, "tCO2e", "ktCO2e"):
                    continue
                if year is not None and not year_pat.search(context or ""):
                    # si l'année proposée n'apparaît pas dans le contexte, on l'ignore (mais on garde le candidat)
                    year = None

                if not self._validate_value_generic(unit, val):
                    continue

                # ---- 5) Candidat
                out.append({
                    "source": "RAG+LLM",
                    "value": float(val),
                    "unit": unit,
                    "year": year,
                    "reason": data.get("reasoning", "Generated by RAG."),
                    "evidence": context[:800],
                    "answer_text": data.get("reasoning", str(val)),
                    "doc_type": doc_type,
                    "page": None,
                    "table_fp": None,
                    "image_crop_path": None,
                    "bbox": None,
                    "bbox_page_pts": None,
                    "judge_reason": None,
                    "judge_confidence": None,
                })

        except Exception as e:
            print(f"CRITICAL ERROR in _query_rag_hybrid for '{kpi_name}': {e}")

        return out


    # --------------------------------------------------------------------------
    # Résumé global / confiance
    # --------------------------------------------------------------------------
    def _build_summary_context(self, extracted_kpis: dict, max_chunks: int = 12, allowed_docs: set | None = None) -> str:
        """
        Assemble a context string for the global opinion from the most relevant chunks.

        The query concatenates the questions of KPIs flagged ``values.is_present`` (or, if none,
        the first 12 KPI_FRAMEWORK questions), capped at 2000 characters. Candidates come from
        BM25 and FAISS (up to 30 each), restricted to chunks whose doc type is in
        ``allowed_docs`` (strip/lower-case comparison, like ``_select_candidate_tables``). If
        retrieval yields nothing, the first 120 allowed chunks are used. Candidates are reranked
        and joined with ``"\\n---\\n"``.

        Returns:
            The joined chunk texts, or ``""`` when no chunk passes the document filter.
        """
        active_questions = []
        for pillar, kpis in KPI_FRAMEWORK.items():
            pillar_data = extracted_kpis.get(pillar) or {}
            for k, cfg in kpis.items():
                values = (pillar_data.get(k) or {}).get("values") or {}
                if values.get("is_present"):
                    active_questions.append(cfg.get("question", k))
        if not active_questions:
            active_questions = [cfg.get("question", k) for kpis in KPI_FRAMEWORK.values() for k, cfg in kpis.items()][:12]

        combined_query = " ".join(active_questions)[:2000]

        allow_norm = {str(d).strip().lower() for d in (allowed_docs or ())}
        idxs = [
            i for i in range(len(self.chunks))
            if not allow_norm or str(self.chunk_doc_types[i]).strip().lower() in allow_norm
        ]
        if not idxs:
            return ""
        allowed_set = set(idxs)  # O(1) membership instead of scanning the list per candidate

        cand = []
        try:
            if self.bm25_index is not None and self.chunks:
                tokenized = combined_query.lower().split(" ")
                bm_scores = self.bm25_index.get_scores(tokenized)
                bm_top = np.argsort(bm_scores)[::-1][:min(30, len(bm_scores))]
                cand.extend(bm_top.tolist())
        except Exception as e:
            print(f"[WARN] BM25 retrieval failed in _build_summary_context: {e}")

        try:
            if self.faiss_index is not None and self.chunks:
                q_emb = self.embedding_model.encode([combined_query], convert_to_numpy=True)
                q_emb = np.ascontiguousarray(q_emb, dtype=np.float32)
                faiss.normalize_L2(q_emb)
                _, fa_top = self.faiss_index.search(q_emb, min(30, len(idxs)))
                cand.extend(fa_top[0].tolist())
        except Exception as e:
            print(f"[WARN] FAISS retrieval failed in _build_summary_context: {e}")

        # De-duplicate and keep allowed chunks only (this also drops FAISS -1 padding).
        cand = [i for i in dict.fromkeys(cand) if i in allowed_set]
        if not cand:
            cand = idxs[:min(120, len(idxs))]

        top_ids = self._rerank_doc_ids(combined_query, cand, top_k=max(4, min(max_chunks, len(cand))))
        return "\n---\n".join(self.chunks[i] for i in top_ids)


    def _calculate_confidence_score(self, extracted_kpis: dict) -> float:
        """
        Return an overall confidence in [0, 1].

        Half comes from the source types listed in ``analysis_state['sources_provided']``
        (sustainability 0.6, financial 0.3, controversies 0.1, capped at 1). The other half is
        the share of KPI_FRAMEWORK KPIs whose ``values['current']`` is not None.
        """
        weight_by_source = {'sustainability': 0.6, 'financial': 0.3, 'controversies': 0.1}
        provided_sources = set(self.analysis_state.get('sources_provided', []))
        source_coverage = min(1.0, sum(weight_by_source.get(src, 0.0) for src in provided_sources))

        expected_kpis = sum(len(group) for group in KPI_FRAMEWORK.values()) or 1
        found_kpis = 0
        for pillar_data in extracted_kpis.values():
            for kpi_data in pillar_data.values():
                if kpi_data.get('values', {}).get('current') is not None:
                    found_kpis += 1

        blended = 0.5 * source_coverage + 0.5 * (found_kpis / expected_kpis)
        return float(max(0.0, min(1.0, blended)))


    def _generate_global_opinion(self, main_context: str, contro_context: str = "") -> dict:
        """
        Ask the LLM for a global ESG opinion.

        Returns:
            ``{"global_opinion", "key_risks", "controversy_comment"}`` taken from the first JSON
            object in the cleaned answer, or empty defaults when no usable JSON dict is found.
        """
        format_json = '{"global_opinion":"...", "key_risks":["...", "..."], "controversy_comment": ""}'

        # Guidance prompt: reason in <think>, opinion/risks from MAIN_CONTEXT only.
        system_message = (
            "You are an equity/ESG analyst. First, think step-by-step in a <think> block to structure your thoughts. "
            "Base 'global_opinion' and 'key_risks' ONLY on MAIN_CONTEXT. "
            "If CONTRO_CONTEXT is not empty, add a 'controversy_comment'. "
            "Then, provide your final analysis as a single, compact JSON object and nothing else."
        )
        user_message = f"MAIN_CONTEXT:\n{main_context}\n\nCONTRO_CONTEXT:\n{contro_context}\n\nFINAL ANSWER FORMAT:\n{format_json}"

        cleaned = self._strip_think(self._generate_text(system_message, user_message, max_new_tokens=550))
        empty = {"global_opinion": "", "key_risks": [], "controversy_comment": ""}

        js_str = self._extract_first_json_object(cleaned)
        if not js_str:
            return empty
        try:
            parsed = self._json_loads_lenient(js_str)
        except Exception:
            return empty
        if not (parsed and isinstance(parsed, dict)):
            return empty

        return {
            "global_opinion": parsed.get("global_opinion", ""),
            "key_risks": parsed.get("key_risks", []),
            "controversy_comment": parsed.get("controversy_comment", ""),
        }

    # --------------------------------------------------------------------------
    # Cache (préprocess) : save/load
    # --------------------------------------------------------------------------
    def _preprocess_and_save_to_cache(self, document_paths: dict):
        """
        Parse documents, build the hybrid index and persist everything to ``self.cache_files``.

        Writes chunks / doc types / page texts (JSON), tables (pickle), embeddings (``.npy``)
        and the BM25 index (pickle). Returns early, after printing an error, when parsing
        produced neither chunks nor tables.
        """
        self._load_and_process_documents(document_paths)
        if not self.chunks and not self.tables:
            print("ERROR: Preprocessing failed, no content extracted.")
            return
        self._build_index()
        print("\n--- Saving preprocessed data to cache... ---")
        with open(self.cache_files["chunks"], "w", encoding="utf-8") as f:
            json.dump({"chunks": self.chunks, "doc_types": self.chunk_doc_types, "page_texts": self.page_texts}, f)
        pd.to_pickle(self.tables, self.cache_files["tables"])

        if self.chunks:
            embeddings = self.all_embeddings
            bm25 = self.bm25_index
        else:
            # Tables-only input: _build_index builds no embeddings/BM25 without text chunks.
            # np.save(None) would write a pickled object array that np.load (allow_pickle=False
            # by default) refuses on reload, so store an empty (0, dim) float32 matrix instead.
            embeddings = np.zeros((0, int(self.embedding_dim)), dtype=np.float32)
            bm25 = None
        np.save(self.cache_files["embeddings"], embeddings)
        pd.to_pickle(bm25, self.cache_files["bm25"])
        print("--- Caching complete. ---")

    def _load_from_cache(self):
        """
        Restore chunks, tables, embeddings and the BM25 index from ``self.cache_files``, then
        rebuild the FAISS inner-product index from the cached embeddings.

        NOTE: tables and BM25 are read with ``pd.read_pickle``; only load caches this pipeline
        wrote itself (unpickling untrusted files can execute code).
        """
        with open(self.cache_files["chunks"], "r", encoding="utf-8") as fh:
            payload = json.load(fh)
        self.chunks = payload["chunks"]
        self.chunk_doc_types = payload["doc_types"]
        self.page_texts = payload.get("page_texts", [])

        self.tables = pd.read_pickle(self.cache_files["tables"])
        self.all_embeddings = np.load(self.cache_files["embeddings"])
        self.bm25_index = pd.read_pickle(self.cache_files["bm25"])

        print("  -> Rebuilding FAISS index from cached embeddings...")
        self.faiss_index = faiss.IndexFlatIP(self.embedding_dim)
        self.faiss_index.add(self.all_embeddings)
        print(f"--- Loaded {len(self.chunks)} chunks and {len(self.tables)} tables from cache. ---")

    def _purge_cache_artifacts(self, cache_prefix: str, force: bool, phase: str = "pre"):
        """
        Purge/refresh des artefacts de cache et des crops PNG pour un PDF donné (cache_prefix).
        - phase="pre":  si force=True, supprime les fichiers de cache et le dossier images/<cache_prefix>, puis recrée le dossier.
        - phase="post": supprime les PNG orphelins non référencés par self.tables (sécurité).
        """
        import shutil

        img_dir = os.path.join(self.cache_dir, "images", cache_prefix)

        if phase == "pre":
            if force:
                try:
                    for p in (self.cache_files or {}).values():
                        if p and os.path.exists(p):
                            os.remove(p)
                except Exception:
                    pass
                shutil.rmtree(img_dir, ignore_errors=True)

            os.makedirs(img_dir, exist_ok=True)
            self.image_dir = img_dir
            return

        if phase == "post":
            try:
                if not os.path.isdir(img_dir):
                    return
                used = set()
                for t in (self.tables or []):
                    p = t.get("image_crop_path")
                    if p:
                        used.add(os.path.abspath(p))
                for fname in os.listdir(img_dir):
                    if not fname.lower().endswith(".png"):
                        continue
                    fp = os.path.abspath(os.path.join(img_dir, fname))
                    if fp not in used:
                        try:
                            os.remove(fp)
                        except Exception:
                            pass
            except Exception:
                pass

    # --------------------------------------------------------------------------
    # Features → CSV (pour futur ranker)
    # --------------------------------------------------------------------------
    def _candidate_features(self, question: str, candidates: list[dict]) -> list[dict]:
        ce_evidences, evidences = [], []
        ce_answers = []

        for c in candidates:
            ev = f"{c.get('reason','')} {c.get('evidence','')}".strip()
            evidences.append(ev)
            try:
                ce_evidences.append(float(self._ce_score(question, ev)))
            except Exception:
                ce_evidences.append(0.0)

            answer_text = (c.get("answer_text") or str(c.get("value") or "")) + " " + (c.get("unit") or "")
            try:
                ce_answers.append(float(self._ce_score(question, answer_text.strip())))
            except Exception:
                ce_answers.append(0.0)

        order = list(np.argsort(ce_evidences)[::-1]) if ce_evidences else []
        top1 = ce_evidences[order[0]] if order else 0.0
        top2 = ce_evidences[order[1]] if len(order) > 1 else 0.0
        margin = float(top1 - top2)
        std_ce = float(np.std(ce_evidences)) if len(ce_evidences) > 1 else 0.0

        def close_count(idx, tol=0.5):
            v = candidates[idx].get("value")
            if v is None: return 0
            v = float(v); cnt = 0
            for j, cj in enumerate(candidates):
                if j == idx or cj.get("value") is None: continue
                if abs(float(cj["value"]) - v) <= tol: cnt += 1
            return cnt

        def unit_onehots(u: str):
            u = (u or "").lower()
            return {
                "unit_%": 1.0 if "%" in u else 0.0,
                "unit_ktco2e": 1.0 if "ktco2e" in u else 0.0,
                "unit_tco2e": 1.0 if "tco2e" in u else 0.0,
                "unit_gwh": 1.0 if "gwh" in u else 0.0,
                "unit_mwh": 1.0 if "mwh" in u else 0.0,
                "unit_m3": 1.0 if re.search(r"\bm3\b", u) else 0.0,
                "unit_tonne": 1.0 if re.search(r"\b(tonnes?|t)\b", u) else 0.0,
            }

        feats = []
        for i, c in enumerate(candidates):
            unit = c.get("unit")
            val = c.get("value")
            value_log10 = (math.log10(float(val)) if isinstance(val,(int,float)) and float(val)>0 else -10.0)
            plausible = 1.0 if self._validate_value_generic(unit, val if isinstance(val,(int,float)) else None) else 0.0

            yrs = [int(y) for y in re.findall(r"\b(20\d{2})\b", evidences[i])]
            year_in_evidence = max(yrs) if yrs else 0
            now_year = 2025
            recency_norm = max(0.0, min(1.0, (year_in_evidence - 2017) / (now_year - 2017 + 1))) if year_in_evidence else 0.0

            row = {
                "kpi_source": c.get("source",""),
                "value": val,
                "unit": unit,
                "ce_evidence": float(ce_evidences[i]),
                "ce_answer": float(ce_answers[i]),
                "ce_is_top": 1.0 if (order and i == order[0]) else 0.0,
                "ce_margin_top1_top2": margin,
                "std_ce_over_cands": std_ce,
                "consensus_close_count": close_count(i),
                "year_in_evidence": year_in_evidence,
                "recency_norm": recency_norm,
                "value_log10": value_log10,
                "is_plausible": plausible,
                "doc_type": c.get("doc_type",""),
                "page": c.get("page"),
                "bbox": c.get("bbox"),
                "evidence_preview": (c.get("evidence","") or c.get("reason",""))[:200],
                "y": ""
            }
            row.update(unit_onehots(unit))
            feats.append(row)
        return feats

    def _winner_meta(self, kpi_entry: dict) -> dict:
        w = (kpi_entry or {}).get("winner") or {}
        return {
            "source": w.get("source"),
            "unit": w.get("unit"),
            "page": w.get("page"),
            "doc_type": w.get("doc_type"),
            "answer_text": w.get("answer_text"),
            "evidence": w.get("evidence"),
            "table_fp": w.get("table_fp") or w.get("fingerprint"),
            "image_crop_path": w.get("image_crop_path"),
            "judge_confidence": w.get("judge_confidence"),
            "judge_reason": w.get("judge_reason"),
        }


    def _append_csv_rows(self, rows: list[dict], out_csv: str):
        if not rows: return
        header = []
        seen = set()
        for r in rows:
            for k in r.keys():
                if k not in seen:
                    seen.add(k); header.append(k)
        need_header = not os.path.exists(out_csv)
        with open(out_csv, "a", newline="", encoding="utf-8") as f:
            w = csv.DictWriter(f, fieldnames=header)
            if need_header: w.writeheader()
            for r in rows: w.writerow(r)


    def _generate_text(self, system_msg: str, user_msg: str, max_new_tokens: int | None = None,
                       max_input_tokens: int = 4096) -> str:
        """
        Run one greedy chat turn and return the raw decoded continuation.

        The prompt is rendered with the tokenizer's chat template (system + user + generation
        prompt). The output is decoded with ``skip_special_tokens=False`` so that a ``<think>``
        block survives; callers clean it with ``_strip_think``.

        Args:
            system_msg: system-role instructions.
            user_msg: user-role content (context, question, answer format).
            max_new_tokens: generation budget; defaults to ``self.max_new_tokens_json``.
            max_input_tokens: prompt token budget (4096, the previous hard-coded limit).
                Over-long prompts are cut in the MIDDLE: the first and last halves of the budget
                are kept. The former tokenizer-side truncation (right side by default) dropped
                the end of the user turn and the assistant generation prompt, so the model
                continued the context instead of answering.

        Returns:
            The decoded new tokens, stripped of surrounding whitespace.
        """
        messages = [
            {"role": "system", "content": system_msg},
            {"role": "user",   "content": user_msg},
        ]
        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        input_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"]

        budget = int(max_input_tokens)
        if budget > 0 and input_ids.shape[1] > budget:
            head = budget // 2
            tail = budget - head
            input_ids = torch.cat([input_ids[:, :head], input_ids[:, -tail:]], dim=1)

        input_ids = input_ids.to(self.device)
        attention_mask = torch.ones_like(input_ids)  # single unpadded sequence

        # Some tokenizers define no pad token; fall back to EOS.
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = self.tokenizer.eos_token_id

        gen_kwargs = dict(
            max_new_tokens=int(max_new_tokens or self.max_new_tokens_json),
            no_repeat_ngram_size=0,
            pad_token_id=pad_id,
            eos_token_id=self.tokenizer.eos_token_id,
            do_sample=False,
        )

        with torch.inference_mode():
            out = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_kwargs)

        gen_ids = out[:, input_ids.shape[1]:]
        # Keep the raw text (including <think>) so it can be cleaned afterwards.
        return self.tokenizer.decode(gen_ids[0], skip_special_tokens=False).strip()


    def _strip_think(self, text: str) -> str:
        if not text:
            return ""
        s = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL|re.IGNORECASE)
        # si une balise ouvrante subsiste sans fermeture, on coupe tout ce qui suit
        low = s.lower()
        if "<think>" in low and "</think>" not in low:
            s = s[:low.rfind("<think>")].strip()
        # enlève les fences ```json ... ```
        s = re.sub(r"```[\w-]*\n?|```", "", s).strip()
        return s


    def _summarize_controversies(self, contro_context: str) -> str:
        """
        Résume les controverses en 1–2 phrases en utilisant la méthode de guidage standard.
        """
        if not contro_context or not contro_context.strip():
            return ""

        format_json = '{"controversy_comment": "<1–2 concise sentences>"}'
        # Prompt de "guidage"
        system_msg = (
            "You are an ESG analyst. First, think in a <think> block to identify the key controversy. "
            "Then, provide the summary as a single, compact JSON object and nothing else."
        )
        user_msg = f"CONTRO_CONTEXT:\n{contro_context}\n\nFINAL ANSWER FORMAT:\n{format_json}"

        # Logique de génération standardisée
        raw_output = self._generate_text(system_msg, user_msg, max_new_tokens=500)
        final_answer = self._strip_think(raw_output)

        # Essaie d'extraire le commentaire du JSON
        js = self._extract_first_json_object(final_answer)
        if js:
            try:
                obj = self._json_loads_lenient(js)
                if obj and isinstance(obj, dict):
                    comment = obj.get("controversy_comment", "").strip()
                    if comment:
                        return comment
            except Exception:
                pass

        # Fallback: si pas de JSON valide, on retourne la réponse nettoyée comme meilleur effort
        return final_answer

    # --------------------------------------------------------------------------
    # Judge (Donut vs RAG)
    # --------------------------------------------------------------------------

    def _llm_judge_candidates(self, question: str, unit_hint: str | None, candidates: list[dict]) -> dict | None:
        """
        LLM-Judge (DeepSeek) : choisit le meilleur candidat parmi VQA (Donut) et RAG+LLM.
        - Réduit la liste pour éviter la verbosité.
        - Rend un JSON: {"winner_index": int, "confidence": float, "reason": "..."}.
        - Fallback: retourne None si parsing impossible (le caller gère le repli heuristique).
        """
        if not candidates:
            return None

        if len(candidates) == 1 and self._validate_value_generic(candidates[0].get('unit'), candidates[0].get('value')):
            return candidates[0]

        # 1) Réduction (max 4) : garder au plus 1 par source + top CE sur le reste
        #    -> priorité: Donut + RAG ; sinon top-CE.
        prio = []
        seen_src = set()
        # essaie d'abord de garder 1 Donut et 1 RAG
        for src in ("Donut-DocVQA", "RAG+LLM"):
            for i, c in enumerate(candidates):
                if c.get("source") == src and src not in seen_src:
                    prio.append((i, c)); seen_src.add(src); break
        # complète avec top-CE evidence sur le reste
        rest = [(i, c) for i, c in enumerate(candidates) if c.get("source") not in seen_src]
        # score CE evidence rapide
        scored_rest = []
        for i, c in rest:
            ev = " ".join([(c.get("answer_text") or ""), (c.get("reason") or ""), (c.get("evidence") or "")]).strip()
            try:
                s = float(self._ce_score(question, ev))
            except Exception:
                s = 0.0
            scored_rest.append((s, i, c))
        scored_rest.sort(key=lambda x: x[0], reverse=True)
        for s, i, c in scored_rest[:2]:
            prio.append((i, c))
        # borne finale
        prio = prio[:4]

        # 2) Vue compacte envoyée au LLM
        def _trim(x, n):
            x = (x or "").replace("\n", " ").replace("\r", " ").strip()
            return (x[:n] + "…") if len(x) > n else x

        unit_hint_norm = (unit_hint or "").strip()
        cand_view = []
        for local_idx, (original_idx, c) in enumerate(prio):
                cand_view.append({ "idx": local_idx,
                                  "src": c.get("source"),
                                   "value": c.get("value"),
                                   "unit": c.get("unit"),
                                   "year": c.get("year"),
                                   "answer": _trim(c.get("answer_text"), 250) })

        # 3) Prompt LLM — consignes strictes + JSON unique
        format_json = '{"winner_index": 0, "confidence": 0.9, "reason": "Reason for choice."}'
        system_msg = (
            "You are a meticulous ESG data auditor. Your task is to analyze several data candidates and select the single most accurate one. "
            "First, think step-by-step inside a <think> block to evaluate the candidates based on the rules. "
            "Then, after the </think> block, provide your final answer as a single, compact JSON object and nothing else."
        )
        rules = [
            "Prefer the most recent year.",
            "Evidence from a table (VQA) with clear context is more reliable than text snippets (RAG).",
            "Ensure the evidence semantically matches the question.",
            f"A unit hint of '{unit_hint}' was provided; this is a strong indicator." if unit_hint else "No unit hint was provided."
        ]
        user_msg = (
            f"QUESTION: \"{question}\"\n\n"
            f"RULES:\n- " + "\n- ".join(rules) + "\n\n"
            f"CANDIDATES (use 'idx' for your choice):\n{json.dumps(cand_view, indent=2, ensure_ascii=False)}\n\n"
            f"Based on your analysis, provide your final choice in the following JSON format ONLY:\n{format_json}"
        )

        # 3. Générer, nettoyer, et parser
        raw_output = self._generate_text(system_msg, user_msg, max_new_tokens=300)
        final_answer = self._strip_think(raw_output)
        js_str = self._extract_first_json_object(final_answer)
        if not js_str: return None

        try:
            obj = self._json_loads_lenient(js_str)
            if not isinstance(obj, dict) or "winner_index" not in obj: return None
            wi = int(obj["winner_index"])
            if not (0 <= wi < len(prio)): return None
        except (ValueError, TypeError):
            return None

        # 4. Retourner le candidat gagnant au format standard, enrichi par le jugement
        _original_idx, winner = prio[wi]
        winner["judge_reason"] = obj.get("reason", "")
        winner["judge_confidence"] = float(obj.get("confidence", 0.0))
        return winner

    # --------------------------------------------------------------------------
    # MAIN
    # --------------------------------------------------------------------------
    def run_full_analysis(self, document_paths: dict, force_preprocessing: bool = True):
        """
        Build or load the retrieval cache for a set of documents.

        Despite its name, this method currently performs only the preprocessing
        stage. It:
          1. derives the cache file names;
          2. calls the "pre" purge;
          3. either loads the cache from disk or re-runs preprocessing;
          4. calls a "post" purge.
        It returns None.

        Args:
            document_paths: mapping of document type (e.g. "sustainability",
                "controversies") to a file path. The cache key comes from the
                sustainability report, else the controversies report, else "default".
            force_preprocessing: when True (default), purge and rebuild; when False,
                reuse the cache if every cache file exists.

        Side effects:
            Sets ``self.cache_files``; reads/writes files under ``self.cache_dir``.
        """
        import hashlib
        import os

        # 1) Cache key: first 10 hex chars of the SHA-1 of the main report's *basename*.
        # NOTE(review): two reports with the same file name in different folders share a cache.
        main_report_path = document_paths.get("sustainability") or document_paths.get("controversies") or "default"
        cache_prefix = hashlib.sha1(os.path.basename(main_report_path).encode()).hexdigest()[:10]

        # 2) Cache files.
        self.cache_files = {
            "chunks": os.path.join(self.cache_dir, f"{cache_prefix}_chunks.json"),
            "tables": os.path.join(self.cache_dir, f"{cache_prefix}_tables.pkl"),
            "embeddings": os.path.join(self.cache_dir, f"{cache_prefix}_embeddings.npy"),
            "bm25": os.path.join(self.cache_dir, f"{cache_prefix}_bm25.pkl")
        }

        # 3) Targeted purge of this document's artefacts (forced when force_preprocessing).
        self._purge_cache_artifacts(cache_prefix, force=force_preprocessing, phase="pre")

        # 4) Load from cache, or redo the whole preprocessing.
        if (not force_preprocessing) and all(os.path.exists(p) for p in self.cache_files.values()):
            print("\n--- Found cached data. Loading from disk... ---")
            self._load_from_cache()
        else:
            print("\n--- No valid cache found or preprocessing forced. Running full preprocessing... ---")
            self._preprocess_and_save_to_cache(document_paths)

        # 5) Post-run cleanup pass (described in the original code as pruning orphan PNG crops).
        self._purge_cache_artifacts(cache_prefix, force=False, phase="post")



    # --------------------------------------------------------------------------
    # Divers utilitaires JSON
    # --------------------------------------------------------------------------
    def _parse_llm_output(self, llm_output_str: str) -> dict:
        llm_output_str = self._strip_think(llm_output_str or "")
        json_candidate = self._extract_first_json_object(llm_output_str)
        data = self._json_loads_lenient(json_candidate) if json_candidate else None

        current_val = None
        reasoning = "JSON parsing failed or no JSON found."
        is_present = False
        unit = None
        year = None

        if isinstance(data, dict):
            reasoning = data.get("reasoning", reasoning)
            current_val = self._coerce_float(data.get("current_value"))
            unit = data.get("unit")
            y = data.get("year")
            if isinstance(y, (int, float)): year = int(y)
            elif isinstance(y, str):
                m = re.search(r"\b20\D?(\d{2})\b", y);  # capte "20,24" etc.
                if m: year = int("20" + m.group(1))
            is_present = bool(data.get("is_present", current_val is not None))

        # Fallback si pas de JSON exploitable: tenter de repêcher la valeur dans le texte
        if current_val is None:
            rescued = self._guess_current_from_text(llm_output_str)
            if rescued is not None:
                current_val = rescued
                is_present = True

        return {"current": current_val, "reasoning": reasoning, "is_present": is_present, "unit": unit, "year": year}


    def _extract_first_json_object(self, text: str) -> str | None:
        start = text.find("{")
        if start == -1: return None
        depth = 0; in_str = False; esc = False
        for i in range(start, len(text)):
            ch = text[i]
            if ch == '"' and not esc:
                in_str = not in_str
            esc = (ch == "\\") and not esc if in_str else False
            if in_str: continue
            if ch == "{": depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0: return text[start : i + 1]
        return None

    def _json_loads_lenient(self, s: str) -> dict | None:
        if not s:
            return None
        s1 = s.strip()
        # 1) virer les fences
        s1 = re.sub(r"```[\w-]*\n?|```", "", s1)
        # 2) guillemets typographiques → ASCII
        s1 = s1.replace("“", '"').replace("”", '"').replace("’", "'")
        # 3) quoter les clés non-quotées: key: val -> "key": val
        #    (heuristique raisonnable hors chaînes)
        s1 = re.sub(r'(?<!")\b([A-Za-z_][A-Za-z0-9_]*)\b\s*:', r'"\1":', s1)
        # 4) supprimer virgules avant } ou ]
        s1 = re.sub(r",\s*(?=[}\]])", "", s1)
        # 5) Python→JSON
        s1 = s1.replace("None", "null").replace("True", "true").replace("False", "false")
        try:
            return json.loads(s1)
        except Exception:
            return None


    def _guess_current_from_text(self, text: str) -> float | None:
        m = re.search(r"current[_\s-]?value\s*[:=]\s*([0-9][0-9\.,_ ]*)", text, flags=re.IGNORECASE)
        if m: return self._coerce_float(m.group(1))
        m = re.search(r"(current year|current|latest|most recent)\D{0,40}([0-9][0-9\.,_ ]*)", text, flags=re.IGNORECASE)
        if m: return self._coerce_float(m.group(2))
        fy = re.findall(r"FY\s*([12][0-9]{3})", text, flags=re.IGNORECASE)
        if fy:
            fy_latest = max(map(int, fy))
            window = 120
            for m in re.finditer(r"FY\s*([12][0-9]{3})", text, flags=re.IGNORECASE):
                if int(m.group(1)) == fy_latest:
                    start = max(0, m.start() - window); end = min(len(text), m.end() + window)
                    near = text[start:end]; n = re.search(r"([0-9][0-9\.,_ ]*)", near)
                    if n:
                        val = self._coerce_float(n.group(1))
                        if val is not None: return val
        m = re.search(r"([0-9][0-9\.,_ ]*)\s*(ktco2e|twh|%|m3|tonnes|tons|t)\b", text, flags=re.IGNORECASE)
        if m: return self._coerce_float(m.group(1))
        m = re.search(r"\b([0-9][0-9\.,_]{2,})\b", text)
        if m: return self._coerce_float(m.group(1))
        return None




# IMPORT .py

In [None]:
# Install the local project package (the one defined by setup.py in PROJECT_ROOT).
import os

PROJECT_ROOT = "/content/drive/MyDrive/esg_rating_project"
# D) Move into the project directory and install its code as a package.
print(f"\n--- Installation du package local 'esg_rating_engine' depuis {PROJECT_ROOT} ---")
os.chdir(PROJECT_ROOT)
# `!` is an IPython/Colab shell escape (not plain Python). pip's editable mode (-e)
# links the package instead of copying it; '-q' makes the installation quiet.
!pip install -q -e .

print("\n\n✅ --- ENVIRONNEMENT DE PROJET PRÊT --- ✅")

# RUN DOC PREPROCESSING

In [None]:
# ==============================================================================
# CELLULE 4: LANCEMENT DE L'ANALYSE
# ==============================================================================
# Runs document preprocessing: builds (or refreshes) the retrieval cache used by the
# KPI extraction cell.
import os
import json
import sys
import traceback

# Make sure the working directory is the project root (the relative report paths depend on it).
PROJECT_ROOT = "/content/drive/MyDrive/esg_rating_project"
os.chdir(PROJECT_ROOT)

# Import the main class now that the package is installed.
from src.esg_engine import ESGRatingEngine

print("=============================================")
print("=           PREPROCESSING DOCUMENT          =")
print("=============================================")

# Paths of the reports to analyse (relative to PROJECT_ROOT).
reports_directory = "data/reports_to_analyze"
company_name = "Strauss Group"  # NOTE(review): the report files below are "PPF*.pdf"; confirm the company.
sustainability_report_path = os.path.join(reports_directory, "PPF.pdf") # Make sure the file is on your Drive
controversies_report_path = os.path.join(reports_directory, "PPF controversy.pdf")

document_paths = {'sustainability': sustainability_report_path, 'financial': None, 'controversies': controversies_report_path}

# Warn early about missing inputs instead of failing deep inside preprocessing.
missing_docs = [doc for doc, path in document_paths.items() if path and not os.path.exists(path)]
if missing_docs:
    print(f"[WARN] Fichier(s) introuvable(s) pour : {', '.join(missing_docs)} (dossier {reports_directory})")

try:
    engine = ESGRatingEngine()
    final_rating = engine.run_full_analysis(document_paths)
    if final_rating is None:
        # run_full_analysis currently only builds/loads the cache and returns None;
        # dumping it would print a misleading "null" final report.
        print("\n\n--- PREPROCESSING TERMINÉ --- (cache prêt pour l'extraction des KPI)")
    else:
        print("\n\n--- RAPPORT FINAL ---")
        print(json.dumps(final_rating, indent=2))

except Exception:
    print("\nUNE ERREUR CRITIQUE EST SURVENUE DURANT L'ANALYSE.")
    traceback.print_exc()

# KPI Extraction

In [None]:
# [1/3] === RAG + Extraction → enregistrement structuré des résultats (results_df) ===
import os, re, json, hashlib, numpy as np, faiss
import pandas as pd
import numpy as np
import sys
if os.getcwd() not in sys.path:
    sys.path.append(os.getcwd())

from src.esg_engine import ESGRatingEngine
from src.kpi_config import KPI_FRAMEWORK
from datetime import datetime

# -------------------- PARAMS (reprise) --------------------
# Every (pillar, KPI name) pair declared in KPI_FRAMEWORK, in declaration order.
ALL_KPIS = [(pillar_id, kpi_id)
            for pillar_id, pillar_kpis in KPI_FRAMEWORK.items()
            for kpi_id in pillar_kpis.keys()]
# Number of re-ranked chunks given one by one to the LLM for each KPI.
TOP_K_RAG_CONTEXTS = 3
# max_new_tokens budget of each extraction call (reasoning + JSON).
GEN_BUDGET = 512
# True: print the whole raw LLM output; False: a 240-character preview.
SHOW_RAW_FULL = False
# ----------------------------------------------------------

def _auto_paths():
    root = os.environ.get("ESG_PROJECT_ROOT", "/content/drive/MyDrive/esg_rating_project")
    reports_dir = os.path.join(root, "data", "reports_to_analyze")
    sust = os.path.join(reports_dir, "PPF.pdf")
    contro = os.path.join(reports_dir, "PPF controversy.pdf")
    return {"sustainability": sust if os.path.exists(sust) else None,
            "financial": None, "controversies": contro if os.path.exists(contro) else None}

_doc_paths = _auto_paths()
print("CHEMINS DES DOCUMENTS:", _doc_paths)
eng = ESGRatingEngine()

# ---- Cache ----
# Same cache key as ESGRatingEngine.run_full_analysis: the first 10 hex chars of the
# SHA-1 of the main report's basename. The main report is the sustainability report,
# else the controversies report, else "default".
main_report_path = _doc_paths.get("sustainability") or _doc_paths.get("controversies") or "default"
cache_prefix = hashlib.sha1(os.path.basename(main_report_path).encode()).hexdigest()[:10]
eng.cache_files = {
    "chunks": os.path.join(eng.cache_dir, f"{cache_prefix}_chunks.json"),
    "tables": os.path.join(eng.cache_dir, f"{cache_prefix}_tables.pkl"),
    "embeddings": os.path.join(eng.cache_dir, f"{cache_prefix}_embeddings.npy"),
    "bm25": os.path.join(eng.cache_dir, f"{cache_prefix}_bm25.pkl"),
}
missing_cache = [name for name, path in eng.cache_files.items() if not os.path.exists(path)]
if missing_cache:
    raise RuntimeError(
        f"Cache manquant ({', '.join(missing_cache)}). Lancez d'abord "
        "engine.run_full_analysis(document_paths, force_preprocessing=True)."
    )
print("[DEBUG] Chargement depuis le cache…")
eng._load_from_cache()

# ---- LLM info (read-only)
tok, model = eng.tokenizer, eng.model
tpl = (tok.chat_template or "")
print("\n[LLM/Tokenizer INFO]")
print("  • Model:", type(model).__name__)
print("  • Device:", eng.device)
print("  • engine.max_new_tokens_json:", eng.max_new_tokens_json)
print("  • eos_token_id:", tok.eos_token_id, "| pad_token_id:", tok.pad_token_id)
print("  • ChatML présent ?",
      ("<|im_start|>" in str(tpl) and "<|im_end|>" in str(tpl)))

# ============ Format de sortie LLM (inchangé) ============
# Answer schema shown verbatim in every extraction prompt.
JSON_FORMAT = '{"reasoning":"<short evidence>", "current_value": 1234.5, "unit": "tCO2e|tons|m3|null", "year": 2024, "is_present": true}'
STRICT_RULES = (
    "Rules:\n"
    "1) If you cannot find an explicit number for the QUESTION in CONTEXT, return exactly: "
    '{"reasoning":"not found", "current_value": null, "unit": null, "year": null, "is_present": false}\n'
    "2) Copy digits **verbatim** from CONTEXT. Do not reorder, do not guess. If you write 15,571 in reasoning, current_value must be 15571 (commas/spaces removed).\n"
    "3) Year must be a 4-digit year **seen in CONTEXT near the value**.\n"
    # SYS_THINK asks for a <think> block first, so the "first char" constraint
    # must apply to the answer after </think>, not to the whole output.
    "4) After the </think> block, output ONE JSON object only. Its first char must be '{'. No extra text."
)
SYS_THINK = (
    "You are a precise data extraction expert.\n"
    "First, think step-by-step inside a <think> block (≤60 tokens) to locate the number *in the CONTEXT*.\n"
    "Then output exactly ONE JSON object and nothing else.\n" + STRICT_RULES
)
def _strip_think(s: str) -> str:
    return re.sub(r"<think>.*?</think>", "", s, flags=re.DOTALL | re.IGNORECASE).strip()

_num_pat  = re.compile(r"\b\d{1,3}(?:[ ,.\u00A0]\d{3})*(?:[.,]\d+)?\b")
_year_pat = re.compile(r"\b20\d{2}\b")
def _numbers_in_ctx(ctx):
    spans = _num_pat.findall(ctx)
    def norm(x):
        return float(re.sub(r"[,\s\u00A0]", "", x).replace(" ", ""))
    values = []
    for s in spans:
        try:
            v = norm(s)
            values.append((s, v))
        except:
            pass
    return values
def _sanity_check(ctx, parsed, allowed_units=(None, "tCO2e", "ktCO2e", "tons", "m3", "%")):
    """
    Check that an extraction is grounded in the chunk it was extracted from.

    Returns ``(ok, reasons)``. ``reasons`` lists the failed checks:
      - "no current_value": nothing was extracted (returned immediately);
      - "value_not_in_context": the value is not among the numbers of ``ctx``, even
        allowing a t ↔ kt CO2e conversion (factor 1000);
      - "bad_unit": the unit is not in ``allowed_units``;
      - "year_not_in_context": the reported year does not appear in ``ctx``
        (prompt rule 3), rather than merely "some 20xx appears".

    Args:
        ctx: chunk text given to the LLM.
        parsed: dict from ``engine._parse_llm_output`` (keys "current", "year", "unit").
        allowed_units: accepted unit values. The default covers the units offered by
            JSON_FORMAT ("tCO2e|tons|m3|null"), ktCO2e (handled by the conversion
            below) and "%" for percentage KPIs.
    """
    import math

    ok = True
    reasons = []
    cv = parsed.get("current")
    yr = parsed.get("year")
    unit = parsed.get("unit")
    if cv is None:
        return False, ["no current_value"]

    present_nums = [v for _, v in _numbers_in_ctx(ctx)]

    def _seen(x):
        # Tolerant comparison: e.g. 15.571 * 1000 is not exactly 15571.0 in floating point.
        return any(math.isclose(x, v, rel_tol=1e-9, abs_tol=1e-9) for v in present_nums)

    if not (_seen(cv)
            or (unit == "ktCO2e" and _seen(cv * 1000))
            or (unit == "tCO2e" and _seen(cv / 1000))):
        ok = False; reasons.append("value_not_in_context")
    if unit not in allowed_units:
        ok = False; reasons.append("bad_unit")
    # Digit-bounded search so "FY2024" counts as 2024 but "20245" does not.
    if yr is not None and not re.search(rf"(?<!\d){re.escape(str(yr))}(?!\d)", ctx or ""):
        ok = False; reasons.append("year_not_in_context")
    return ok, reasons

def _run(engine, ctx_text, question, sys_msg, label):
    """
    Run one extraction on a single context chunk and print diagnostics.

    Steps:
      1. build the user prompt (CONTEXT + QUESTION + JSON_FORMAT);
      2. generate with ``engine._generate_text`` (budget GEN_BUDGET);
      3. strip the reasoning;
      4. parse the first JSON object with ``engine._parse_llm_output``;
      5. print the raw/JSON/parsed output and the grounding check.

    Returns:
        ``(ok, parsed)``. ``ok`` is True when ``_sanity_check`` passes AND
        ``engine._validate_value_generic`` accepts the value. ``parsed`` is the dict
        from ``_parse_llm_output``.
    """
    user_msg = f"CONTEXT:\n{ctx_text}\n\nQUESTION: {question}\n\nFINAL ANSWER FORMAT:\n{JSON_FORMAT}"
    raw = engine._generate_text(sys_msg, user_msg, max_new_tokens=GEN_BUDGET)
    final = _strip_think(raw)
    js = engine._extract_first_json_object(final)
    parsed = engine._parse_llm_output(js or final)

    preview = raw if SHOW_RAW_FULL or len(raw) <= 240 else raw[:240] + "…"
    print(f"\n[{label}] RAW:", preview or "(empty)")
    print(f"[{label}] found_json={bool(js)}")
    if js:
        print(f"[{label}] JSON:", js)
    print(f"[{label}] PARSED:", parsed)

    ok_digits, why = _sanity_check(ctx_text, parsed)
    if not ok_digits:
        print(f"[{label}] SANITY: FAIL ->", why)
    else:
        print(f"[{label}] SANITY: OK")
    val = parsed.get("current")
    # Validate with the engine passed in, not a notebook global.
    ok_schema = (val is not None) and engine._validate_value_generic(None, val)
    return (ok_digits and ok_schema), parsed

# -------------------- RUN & CAPTURE --------------------
rows = []
docs_paths = _auto_paths()
RUN_ID = datetime.now().strftime("%Y%m%d_%H%M%S")

for pillar, kpi_name in ALL_KPIS:
    kpi_cfg = KPI_FRAMEWORK[pillar][kpi_name]
    question_long = kpi_cfg.get("question", kpi_name)
    search_query  = kpi_cfg.get("search_query", question_long)

    print(f"\n===== KPI: {pillar} - {kpi_name} =====")
    print("Question (extraction):", question_long)
    print("Search Query (retrieval):", search_query)

    # KPIs of pillar "C" are searched only in the controversies document.
    allowed_docs = {"controversies"} if pillar == "C" else {"sustainability"}

    # Candidate chunks: BM25 top hits first, then FAISS (dense) top hits.
    indices = []
    if eng.bm25_index is not None and eng.chunks:
        toks = [t for t in search_query.lower().split() if t]
        bm_scores = eng.bm25_index.get_scores(toks)
        bm_top = np.argsort(bm_scores)[::-1][:eng.rag_top_candidates]
        indices.extend(bm_top.tolist())

    if eng.faiss_index is not None and eng.all_embeddings is not None:
        q = eng.embedding_model.encode([search_query], convert_to_numpy=True)
        q = np.ascontiguousarray(q, dtype=np.float32)
        faiss.normalize_L2(q)
        _, fa_top = eng.faiss_index.search(q, eng.rag_top_candidates)
        # FAISS pads results with -1 when the index holds fewer than k vectors; a raw -1
        # would silently select the LAST chunk through Python negative indexing.
        indices.extend(i for i in fa_top[0].tolist() if i >= 0)

    # Order-preserving dedup & document-type filter.
    indices = list(dict.fromkeys(indices))
    if allowed_docs:
        indices = [i for i in indices if eng.chunk_doc_types[i] in allowed_docs]

    if not indices:
        print("→ Aucun contexte admissible après filtrage (docs autorisés).")
        # Still log an "empty candidate" for traceability.
        rows.append({
            "pillar": pillar, "kpi": kpi_name, "question": question_long, "search_query": search_query,
            "variant": "with-think / search_query", "context_rank": None,
            "chunk_id": None, "doc_type": ",".join(sorted(allowed_docs)), "doc_path": None,
            "ok": False, "sanity_ok": False, "sanity_why": "no_context",
            "value": None, "unit": None, "year": None, "is_present": False,
            "reasoning": "no admissible context after filtering", "ctx_preview": ""
        })
        continue

    top_ids = eng._rerank_doc_ids(search_query, indices, top_k=TOP_K_RAG_CONTEXTS)
    print("\n--- [DIAGNOSTIC RAG] Contexte sélectionné ---")
    print(f"-> {len(top_ids)} chunk(s) : {top_ids}")
    for i, cid in enumerate(top_ids, 1):
        doc_t = eng.chunk_doc_types[cid]
        preview = eng.chunks[cid][:500].replace("\n", " ")
        print(f"  [Chunk {i} | id={cid} | doc={doc_t}] {preview}…")

    # ---- One extraction per selected chunk (with-think / search_query variant) ----
    for rank, idx in enumerate(top_ids, 1):
        ctx = eng.chunks[idx]
        print(f"\n\n===== CONTEXTE #{rank} (id={idx}, doc={eng.chunk_doc_types[idx]}) =====")
        print("PREVIEW:", (ctx[:300] + "…") if len(ctx) > 300 else ctx)

        ok, parsed = _run(eng, ctx, search_query, SYS_THINK, "with-think / search_query")
        sanity_ok, why = _sanity_check(ctx, parsed)

        rows.append({
            "pillar": pillar,
            "kpi": kpi_name,
            "question": question_long,
            "search_query": search_query,
            "variant": "with-think / search_query",
            "context_rank": rank,
            "chunk_id": idx,
            "doc_type": eng.chunk_doc_types[idx],
            "doc_path": docs_paths.get(eng.chunk_doc_types[idx]),
            "ok": bool(ok),
            "sanity_ok": bool(sanity_ok),
            "sanity_why": ";".join(why) if not sanity_ok else "",
            "value": parsed.get("current"),
            "unit": parsed.get("unit"),
            "year": parsed.get("year"),
            "is_present": parsed.get("is_present"),
            "reasoning": parsed.get("reasoning"),
            "ctx_preview": ctx[:600].replace("\n", " "),
        })

# ---- DataFrame & save ----
results_df = pd.DataFrame(rows)
RES_DIR = os.path.join(eng.cache_dir, "human_in_the_loop")
os.makedirs(RES_DIR, exist_ok=True)
base_fname = f"extractions_{cache_prefix}_{RUN_ID}"
csv_path = os.path.join(RES_DIR, base_fname + ".csv")
parq_path = os.path.join(RES_DIR, base_fname + ".parquet")
results_df.to_csv(csv_path, index=False)
try:
    results_df.to_parquet(parq_path, index=False)
except Exception as e:
    # Parquet needs an optional engine (pyarrow/fastparquet); the CSV is enough.
    print(f"[WARN] Parquet non écrit ({e}). CSV disponible.")

print(f"\n[OK] {len(results_df)} candidats enregistrés →")
print("     ", csv_path)
print("     ", parq_path if os.path.exists(parq_path) else "(parquet non écrit)")

# Keep results_df in memory for the review UI cell.
print("\nresults_df.head():")
display(results_df.head())


# Human in the loop

In [None]:
#  [2/3] === UI de revue & sélection par KPI → DataFrame features prêt pour le modèle ===
import os, json
import numpy as np
import pandas as pd
import ipywidgets as w
from IPython.display import display, HTML, clear_output

# --------- Charger les résultats (si non présents en mémoire) ----------
if "results_df" not in globals() or results_df is None or results_df.empty:
    # Reload the most recent extractions file written by cell [1/3].
    import glob
    RES_DIR = os.path.join(eng.cache_dir, "human_in_the_loop")
    candidates = glob.glob(os.path.join(RES_DIR, "extractions_*.csv"))
    if not candidates:
        raise FileNotFoundError("Aucun fichier d'extractions trouvé. Exécutez d'abord la cellule [1/3].")
    # Names are extractions_<cache_prefix>_<RUN_ID>.csv with RUN_ID = YYYYmmdd_HHMMSS
    # (15 chars). Sorting whole names orders by cache prefix first, so with several
    # documents it would not return the newest run; compare the RUN_ID suffix instead.
    latest_csv = max(candidates, key=lambda path: os.path.splitext(os.path.basename(path))[0][-15:])
    results_df = pd.read_csv(latest_csv)
    print(f"[INFO] results_df chargé depuis {latest_csv} ({len(results_df)} lignes)")

# --------- Feature columns expected by the model (adapt if needed) ----------
FEATURE_COLS_TARGET = [
    "ISSUER_CNTRY_DOMICILE",
    "CARBON_EMISSIONS_SCOPE_1",
    "CARBON_EMISSIONS_SCOPE_2",
    "CARBON_EMISSIONS_SCOPE_3",
    "PCT_NONRENEW_CONSUMP_PROD",
    "HAZARD_WASTE_METRIC_TON",
    "WATER_FRESH_CON",
    "EMP_TURNOVER_ANNUAL_PCT_RECENT",
    "TRIR",
    "HLTH_SAFETY_FATALITIES_YEAR_RECENT",
    "WOMEN_EXEC_MGMT_RECENT",
    "FEMALE_DIRECTORS_PCT",
    "BOARD_INDEP_PCT",
    "PROF_DEV_TRAIN_HOURS_PER_EMP_RECENT",
    "ENVIRONMENT_CONTROVERSY_SCORE",
    "CUSTOMER_CONTROVERSY_SCORE",
    "HUMAN_RIGHTS_CONTROVERSY_SCORE",
    "LABOR_RIGHTS_CONTROVERSY_SCORE",
    "GOVERNANCE_CONTROVERSY_SCORE",
    "IVA_INDUSTRY",
]

# --------- Default KPI → feature mapping (align with your KPI_FRAMEWORK keys) ----------
# Tip: the UI lets the reviewer re-map KPI → feature when needed.
KPI_TO_FEATURE_COL = {
    "scope_1_emissions": "CARBON_EMISSIONS_SCOPE_1",
    "scope_2_emissions": "CARBON_EMISSIONS_SCOPE_2",
    "scope_3_emissions": "CARBON_EMISSIONS_SCOPE_3",
    "hazardous_waste": "HAZARD_WASTE_METRIC_TON",
    # NOTE(review): a *renewable* share feeds a *non-renewable* column; the value presumably
    # needs a 100 - x conversion before export — confirm against the model's training data.
    "renewable_energy_pct": "PCT_NONRENEW_CONSUMP_PROD",
    "water_fresh_consumption": "WATER_FRESH_CON",
    "employee_turnover_rate": "EMP_TURNOVER_ANNUAL_PCT_RECENT",
    "trir": "TRIR",
    "health_safety_fatalities": "HLTH_SAFETY_FATALITIES_YEAR_RECENT",
    "women_exec_mgmt": "WOMEN_EXEC_MGMT_RECENT",
    "female_directors_pct": "FEMALE_DIRECTORS_PCT",
    "board_independence_pct": "BOARD_INDEP_PCT",
    "training_hours_per_emp": "PROF_DEV_TRAIN_HOURS_PER_EMP_RECENT",
    "env_controversy_score": "ENVIRONMENT_CONTROVERSY_SCORE",
    "customers_controversy_score": "CUSTOMER_CONTROVERSY_SCORE",
    "human_rights_community_controversy_score": "HUMAN_RIGHTS_CONTROVERSY_SCORE",
    "labor_rights_supply_chain_controversy_score": "LABOR_RIGHTS_CONTROVERSY_SCORE",
    "governance_controversy_score": "GOVERNANCE_CONTROVERSY_SCORE",
}

# Fonction heuristique simple si une clé n'est pas mappée
def guess_feature_col_from_kpi(k):
    """
    Guess the model feature column for a KPI key missing from KPI_TO_FEATURE_COL.

    Matching is a case-insensitive substring test on ``str(k)``. Rules are checked in
    order and the first hit wins (e.g. "waste" is tested before "water"). Returns None
    when no rule matches.
    """
    name = str(k).lower()

    def has(*parts):
        return all(part in name for part in parts)

    ordered_rules = [
        (has("scope_1"), "CARBON_EMISSIONS_SCOPE_1"),
        (has("scope_2"), "CARBON_EMISSIONS_SCOPE_2"),
        (has("scope_3"), "CARBON_EMISSIONS_SCOPE_3"),
        (has("renewable"), "PCT_NONRENEW_CONSUMP_PROD"),
        (has("waste"), "HAZARD_WASTE_METRIC_TON"),
        (has("water"), "WATER_FRESH_CON"),
        (has("turnover"), "EMP_TURNOVER_ANNUAL_PCT_RECENT"),
        (has("trir"), "TRIR"),
        (has("fatal"), "HLTH_SAFETY_FATALITIES_YEAR_RECENT"),
        (has("women") and (has("exec") or has("management")), "WOMEN_EXEC_MGMT_RECENT"),
        (has("female", "director"), "FEMALE_DIRECTORS_PCT"),
        (has("indep"), "BOARD_INDEP_PCT"),
        (has("train"), "PROF_DEV_TRAIN_HOURS_PER_EMP_RECENT"),
        (has("env", "controvers"), "ENVIRONMENT_CONTROVERSY_SCORE"),
        (has("customer", "controvers"), "CUSTOMER_CONTROVERSY_SCORE"),
        (has("human", "right"), "HUMAN_RIGHTS_CONTROVERSY_SCORE"),
        (has("labor") or has("labour"), "LABOR_RIGHTS_CONTROVERSY_SCORE"),
        (has("governance", "controvers"), "GOVERNANCE_CONTROVERSY_SCORE"),
    ]
    return next((column for matched, column in ordered_rules if matched), None)

# --------- KPI groups ---------
# One dropdown entry per distinct (pillar, KPI) pair of results_df, sorted by pillar then
# KPI key; each "pillar • kpi" label maps back to its pair.
results_df["kpi_key"] = results_df["kpi"].astype(str)
kpi_list = (results_df.loc[:, ["pillar", "kpi_key"]]
            .drop_duplicates()
            .sort_values(by=["pillar", "kpi_key"])
            .to_numpy()
            .tolist())
kpi_options = [f"{pil} • {kpi}" for pil, kpi in kpi_list]
kpi_index_by_label = {label: (pil, kpi) for label, (pil, kpi) in zip(kpi_options, kpi_list)}

# --------- UI styles (cards and badges) ---------
# A bare HTML(...) expression in the middle of a cell is evaluated and discarded (only a
# cell's last expression is auto-displayed), so it must be displayed explicitly.
display(HTML("""
<style>
.hitl-card { border:1px solid #e5e7eb; border-radius:10px; padding:12px; margin:8px 0; background:#fafafa; }
.hitl-title { font-weight:600; font-size:16px; }
.small { color:#6b7280; font-size:12px; }
.ok-badge { background:#10b981; color:white; padding:2px 8px; border-radius:999px; font-size:11px; }
.warn-badge { background:#f59e0b; color:white; padding:2px 8px; border-radius:999px; font-size:11px; }
.err-badge { background:#ef4444; color:white; padding:2px 8px; border-radius:999px; font-size:11px; }
</style>
"""))

# --------- Header widgets (issuer metadata; IVA_INDUSTRY is mandatory) ---------
issuer_input = w.Text(
    description="Entreprise",
    placeholder="Nom de l'émetteur",
    layout=w.Layout(width="40%"),
)
country_input = w.Text(
    description="ISSUER_CNTRY_DOMICILE",
    placeholder="FR, DE, US…",
    layout=w.Layout(width="40%"),
)
sector_input = w.Text(
    description="IVA_INDUSTRY *",
    placeholder="Secteur (obligatoire)",
    layout=w.Layout(width="40%"),
)
header_box = w.HBox(children=[issuer_input, country_input, sector_input])

# --------- KPI selector and "missing value" toggle (enabled only in manual mode) ---------
kpi_select = w.Dropdown(
    options=kpi_options,
    description="KPI",
    layout=w.Layout(width="60%"),
)
missing_value = w.Checkbox(
    value=False,
    description="Valeur manquante (NaN)",
    disabled=True,
)

# --------- Candidate table output area ---------
table_out = w.Output(layout={"border": "1px solid #e5e7eb"})

# --------- Candidate choice and manual entry (fields disabled until manual mode) ---------
radio_candidates = w.RadioButtons(
    description="Candidats",
    options=[],
    layout=w.Layout(width="100%"),
)
use_custom = w.Checkbox(value=False, description="Saisir une valeur manuellement")
val_input = w.FloatText(description="Valeur", disabled=True)
unit_input = w.Text(description="Unité", disabled=True)
year_input = w.IntText(description="Année", disabled=True)
isp_input = w.Checkbox(description="Présence (is_present)", value=True, disabled=True)
reas_input = w.Textarea(
    description="Reasoning",
    layout=w.Layout(width="100%", height="70px"),
    disabled=True,
)

# --------- KPI → feature column mapping ---------
feature_dropdown = w.Dropdown(
    options=FEATURE_COLS_TARGET,
    description="Colonne feature",
    layout=w.Layout(width="50%"),
)

save_btn = w.Button(description="Valider ce KPI", button_style="success", icon="check")
status_out = w.Output()

# --------- Summary & export ---------
summary_out = w.Output()
export_btn = w.Button(description="Exporter ➜ df_features_ready", button_style="primary", icon="save")
audit_out = w.Output()

# --------- Validated selections kept in memory ---------
# (pillar, kpi_key) -> dict(value, unit, year, is_present, reasoning, feature_col, source_row)
selected_map = dict()

def _rank_candidates(dfk):
    # Tri : ok DESC, sanity_ok DESC, is_present DESC, context_rank ASC
    cols_present = {c for c in ["ok","sanity_ok","is_present","context_rank"] if c in dfk.columns}
    if {"ok","sanity_ok","is_present","context_rank"} <= cols_present:
        return dfk.sort_values(["ok","sanity_ok","is_present","context_rank"],
                               ascending=[False, False, False, True]).reset_index(drop=True)
    return dfk.reset_index(drop=True)

def refresh_candidates(_=None):
    """Rebuild the candidate view for the KPI currently chosen in ``kpi_select``.

    Steps:
      * sync the enabled/disabled state of the manual-entry widgets,
      * pick a default feature column for the KPI (even when it has no
        candidates, so a manual entry saves under the right column),
      * filter ``results_df`` to the KPI, rank it with ``_rank_candidates``,
        render a compact table in ``table_out`` and fill ``radio_candidates``
        (option value = positional index in the ranked frame, which
        ``on_save_clicked`` relies on).

    Args:
        _: Unused; lets the function serve as an ``observe`` callback.
    """
    status_out.clear_output()
    table_out.clear_output()
    radio_candidates.options = []

    manual_off = not use_custom.value
    unit_input.disabled = year_input.disabled = isp_input.disabled = reas_input.disabled = manual_off
    # Keep the value field locked while "missing value" is ticked, consistent
    # with on_missing_change (it used to be re-enabled here).
    val_input.disabled = manual_off or missing_value.value
    missing_value.disabled = manual_off

    p, k = kpi_index_by_label[kpi_select.value]

    # Default feature column: explicit mapping, then heuristic guess, then the
    # first target column. Only values that exist in the dropdown are used,
    # since assigning an unknown option to a Dropdown raises.
    default_feature = KPI_TO_FEATURE_COL.get(k)
    if not (default_feature and default_feature in FEATURE_COLS_TARGET):
        default_feature = guess_feature_col_from_kpi(k)
    if not (default_feature and default_feature in FEATURE_COLS_TARGET):
        default_feature = FEATURE_COLS_TARGET[0]
    feature_dropdown.value = default_feature

    dfk = results_df[(results_df["pillar"] == p) & (results_df["kpi_key"] == k)].copy()
    if dfk.empty:
        with table_out:
            display(HTML(f"<div class='hitl-card'><span class='err-badge'>VIDE</span> Aucun candidat pour {p} • {k}</div>"))
        return

    dfk = _rank_candidates(dfk)

    # Compact display; only columns actually present are shown (a missing
    # optional column such as ctx_preview would otherwise raise KeyError).
    wanted_cols = ["context_rank", "value", "unit", "year", "is_present", "ok",
                   "sanity_ok", "doc_type", "chunk_id", "reasoning", "ctx_preview"]
    shown_cols = [c for c in wanted_cols if c in dfk.columns]
    with table_out:
        display(HTML(f"<div class='hitl-card'><div class='hitl-title'>Candidats pour <b>{p} • {k}</b></div>"
                     f"<div class='small'>Triés par qualité puis rang de contexte</div></div>"))
        display(dfk[shown_cols])

    def _flag(v):
        # NaN/None count as False; a plain bool(float("nan")) would be True.
        return (not pd.isna(v)) and bool(v)

    # Radio options: (label, positional index in the ranked frame)
    opts = []
    for idx, row in dfk.iterrows():
        tag = "OK" if _flag(row.get("ok")) else ("WARN" if _flag(row.get("sanity_ok")) else "NOK")
        label = f"[{tag}] rank={row.get('context_rank')} | {row.get('value')} {row.get('unit')} ({row.get('year')}) — {row.get('doc_type')}# {row.get('chunk_id')}"
        opts.append((label, idx))
    radio_candidates.options = opts
    if opts:
        radio_candidates.value = opts[0][1]

def on_use_custom_change(change):
    """Toggle the manual-entry widgets when the "manual value" box changes.

    All manual fields (value, unit, year, presence, reasoning) and the
    "missing value" checkbox follow the new state of ``use_custom``. Turning
    manual entry on also clears the "missing value" flag.
    """
    locked = not change["new"]
    for widget in (val_input, unit_input, year_input, isp_input, reas_input, missing_value):
        widget.disabled = locked
    if not locked:
        # Start every manual entry with "missing value" unticked.
        missing_value.value = False


use_custom.observe(on_use_custom_change, names="value")


def on_missing_change(change):
    """Lock the numeric value field while "missing value" is ticked.

    The field is also locked whenever manual entry (``use_custom``) is off.
    """
    lock = change["new"] or not use_custom.value
    val_input.disabled = lock


missing_value.observe(on_missing_change, names="value")


def _hitl_na_to_none(v):
    """Return ``None`` when ``v`` is a missing scalar (None/NaN/NA), else ``v``.

    Inputs for which ``pd.isna`` has no single truth value are returned as-is.
    """
    try:
        return None if pd.isna(v) else v
    except (TypeError, ValueError):
        return v


def _hitl_flag(v):
    """Interpret a candidate flag; a missing value (NaN/None) counts as False.

    ``bool(float("nan"))`` is True, so a plain ``bool()`` would mark rows with
    an unknown flag as OK / present.
    """
    v = _hitl_na_to_none(v)
    return False if v is None else bool(v)


def _hitl_to_int(v):
    """Return ``int(v)``, or ``None`` when ``v`` is missing."""
    v = _hitl_na_to_none(v)
    return None if v is None else int(v)


def _hitl_to_value(v):
    """Coerce a KPI value to ``float``; missing -> ``None``; non-numeric kept raw.

    Keeping non-numeric payloads raw mirrors the float-or-raw handling applied
    in on_export_clicked.
    """
    v = _hitl_na_to_none(v)
    if v is None:
        return None
    try:
        return float(v)
    except (TypeError, ValueError):
        return v


def _hitl_record_from_row(row, feature_col):
    """Build a selection record from one ranked candidate row."""
    return dict(
        value = _hitl_to_value(row.get("value")),
        unit  = _hitl_na_to_none(row.get("unit")),
        year  = _hitl_to_int(row.get("year")),
        is_present = _hitl_flag(row.get("is_present")),
        reasoning  = _hitl_na_to_none(row.get("reasoning")),
        feature_col = feature_col,
        source = {
            "context_rank": _hitl_to_int(row.get("context_rank")),
            "doc_type": _hitl_na_to_none(row.get("doc_type")),
            "chunk_id": _hitl_to_int(row.get("chunk_id")),
            "ok": _hitl_flag(row.get("ok")),
            "sanity_ok": _hitl_flag(row.get("sanity_ok")),
            "sanity_why": _hitl_na_to_none(row.get("sanity_why")),
        },
    )


def _hitl_record_from_manual(feature_col):
    """Build a selection record from the manual-entry widgets."""
    if missing_value.value:
        # Stored as None; on_export_clicked skips None, so the feature stays NaN.
        value = None
    else:
        raw = val_input.value
        value = None if raw is None else float(raw)
    return dict(
        value = value,
        unit  = (unit_input.value or None),
        year  = None if (year_input.value is None or year_input.value == 0) else int(year_input.value),
        is_present = bool(isp_input.value),
        reasoning  = (reas_input.value or "manual input"),
        feature_col = feature_col,
        source = {"context_rank": None, "doc_type": "manual", "chunk_id": None, "ok": True, "sanity_ok": True, "sanity_why": ""},
    )


def on_save_clicked(_):
    """Store the reviewer's choice for the current KPI in ``selected_map``.

    In manual mode (``use_custom`` ticked) the record comes from the manual
    widgets, honouring the "missing value" checkbox; this also works for KPIs
    that have no extracted candidates. Otherwise the record comes from the
    candidate selected in ``radio_candidates``. A confirmation is shown in
    ``status_out`` and the summary is refreshed.

    Args:
        _: The clicked button (unused; ``on_click`` callback signature).
    """
    p, k = kpi_index_by_label[kpi_select.value]

    if use_custom.value:
        record = _hitl_record_from_manual(feature_dropdown.value)
    else:
        dfk = results_df[(results_df["pillar"] == p) & (results_df["kpi_key"] == k)].copy()
        if dfk.empty:
            return
        # Same ranking as refresh_candidates, so radio values (positions) line up.
        dfk = _rank_candidates(dfk)
        sel_idx = radio_candidates.value
        if sel_idx is None or not 0 <= sel_idx < len(dfk):
            with status_out:
                clear_output()
                display(HTML("<div class='hitl-card'><span class='err-badge'>ERREUR</span> "
                             "Aucun candidat sélectionné.</div>"))
            return
        record = _hitl_record_from_row(dfk.iloc[sel_idx], feature_dropdown.value)

    selected_map[(p, k)] = record

    # Feedback
    with status_out:
        clear_output()
        display(HTML(f"<div class='hitl-card'><span class='ok-badge'>ENREGISTRÉ</span> {p} • <b>{k}</b> → "
                     f"<b>{record['value']}</b> {record['unit'] or ''} ({record['year'] or '—'}) "
                     f"→ feature <b>{record['feature_col']}</b></div>"))
    refresh_summary()

def refresh_summary():
    """Redraw ``summary_out`` with one row per KPI saved in ``selected_map``.

    The output area is always cleared; nothing is drawn while no KPI has been
    validated yet. Rows are ordered by pillar, then KPI key.
    """
    summary_out.clear_output()
    if not selected_map:
        return

    def _as_row(key, rec):
        pillar, kpi = key
        return {
            "pillar": pillar,
            "kpi": kpi,
            "value": rec["value"],
            "unit": rec["unit"],
            "year": rec["year"],
            "is_present": rec["is_present"],
            "feature_col": rec["feature_col"],
            "reasoning": rec["reasoning"],
            "source_doc": rec["source"].get("doc_type"),
            "ctx_rank": rec["source"].get("context_rank"),
        }

    table = pd.DataFrame([_as_row(key, rec) for key, rec in selected_map.items()])
    table = table.sort_values(["pillar", "kpi"])
    with summary_out:
        display(HTML("<div class='hitl-title'>Sélections en cours</div>"))
        display(table)

def on_export_clicked(_):
    """Export the reviewed KPIs as an audit table and a one-row feature frame.

    Requires the IVA_INDUSTRY metadata field. Builds ``df_kpi_selected`` (one
    audit row per saved KPI, provenance JSON-encoded) and ``df_features_ready``
    (a single row with columns ``FEATURE_COLS_TARGET``; unselected features
    stay NaN). Both are published as notebook globals for the scoring cell,
    then written as timestamped CSVs under ``<eng.cache_dir>/human_in_the_loop``.
    A write failure is reported in ``audit_out`` instead of aborting silently.

    Args:
        _: The clicked button (unused; ``on_click`` callback signature).
    """
    audit_out.clear_output()
    # Basic checks
    if not sector_input.value.strip():
        with audit_out:
            display(HTML("<div class='hitl-card'><span class='err-badge'>ERREUR</span> Le champ <b>IVA_INDUSTRY</b> est obligatoire.</div>"))
        return

    # Build df_kpi_selected (audit) and df_features_ready (model features)
    audit_cols = ["pillar", "kpi", "value", "unit", "year", "is_present",
                  "feature_col", "reasoning", "source"]
    audit_rows = []
    feat = {c: np.nan for c in FEATURE_COLS_TARGET}

    # Required / useful metadata
    feat["ISSUER_CNTRY_DOMICILE"] = (country_input.value or "").strip() or np.nan
    feat["IVA_INDUSTRY"] = sector_input.value.strip()

    for (p, k), r in selected_map.items():
        audit_rows.append({
            "pillar": p, "kpi": k, "value": r["value"], "unit": r["unit"], "year": r["year"],
            "is_present": r["is_present"], "feature_col": r["feature_col"],
            "reasoning": r["reasoning"], "source": json.dumps(r["source"], ensure_ascii=False)
        })
        # Inject into the feature column (when it exists); None stays NaN.
        col = r["feature_col"]
        if col in feat and r["value"] is not None:
            try:
                feat[col] = float(r["value"])
            except (TypeError, ValueError):
                # Non-numeric payloads are kept as-is.
                feat[col] = r["value"]

    # Explicit columns keep sort_values working when nothing was selected yet
    # (an empty frame would otherwise have no 'pillar' column -> KeyError).
    df_kpi_selected = (pd.DataFrame(audit_rows, columns=audit_cols)
                       .sort_values(["pillar", "kpi"])
                       .reset_index(drop=True))
    df_features_ready = pd.DataFrame([feat], columns=FEATURE_COLS_TARGET)

    # Expose in the global namespace for cell [3/3] before touching the disk,
    # so a write failure does not block scoring.
    globals()["df_kpi_selected"] = df_kpi_selected
    globals()["df_features_ready"] = df_features_ready

    # Save files
    RES_DIR = os.path.join(eng.cache_dir, "human_in_the_loop")
    rid = datetime.now().strftime("%Y%m%d_%H%M%S")
    p_audit = os.path.join(RES_DIR, f"selected_kpis_{cache_prefix}_{rid}.csv")
    p_feat  = os.path.join(RES_DIR, f"features_{cache_prefix}_{rid}.csv")
    try:
        os.makedirs(RES_DIR, exist_ok=True)
        df_kpi_selected.to_csv(p_audit, index=False)
        df_features_ready.to_csv(p_feat, index=False)
    except OSError as err:
        with audit_out:
            display(HTML(f"<div class='hitl-card'><span class='err-badge'>ERREUR</span> Écriture des CSV impossible : {err}</div>"))
            display(HTML("<div class='hitl-title'>Aperçu df_features_ready</div>"))
            display(df_features_ready)
        return

    with audit_out:
        display(HTML("<div class='hitl-title'>Export terminé</div>"))
        display(HTML(f"<div class='hitl-card'><span class='ok-badge'>OK</span> Audit → {p_audit}</div>"))
        display(HTML(f"<div class='hitl-card'><span class='ok-badge'>OK</span> Features → {p_feat}</div>"))
        display(HTML("<div class='hitl-title'>Aperçu df_features_ready</div>"))
        display(df_features_ready)

# --------- Event bindings ---------
# observe() passes the change dict, which refresh_candidates takes as its unused `_`.
kpi_select.observe(refresh_candidates, names="value")
save_btn.on_click(on_save_clicked)
export_btn.on_click(on_export_clicked)

# --------- Layout ---------
# Left column (65%): KPI picker, candidate table, candidate choice / manual
# entry fields, feature-column mapping and the save button.
left_col  = w.VBox([
    w.HTML("<h3>🧭 Revue par KPI</h3>"),
    kpi_select,
    table_out,
    w.HTML("<hr>"),
    w.HTML("<b>Choix du candidat</b>"),
    radio_candidates,
    use_custom,
    w.VBox([
        w.HBox([val_input, unit_input, year_input]),
        missing_value
    ]),
    isp_input,
    reas_input,
    w.HTML("<hr>"),
    feature_dropdown,
    save_btn,
    status_out
], layout=w.Layout(width="65%"))

# Right column (35%): metadata (header_box is built earlier in the notebook —
# presumably it holds sector_input / country_input; TODO confirm), the running
# selection summary and the export controls.
right_col = w.VBox([
    w.HTML("<h3>🏷️ Métadonnées</h3>"),
    header_box,
    w.HTML("<hr>"),
    w.HTML("<h3>✅ Sélections</h3>"),
    summary_out,
    w.HTML("<hr>"),
    export_btn,
    audit_out
], layout=w.Layout(width="35%"))

display(w.HBox([left_col, right_col]))

# Initialise the view: candidates for the currently selected KPI, and the
# summary (empty at this point since selected_map was just reset).
refresh_candidates()
refresh_summary()


# Rating

In [None]:
#  [3/3] === Inférence XGBoost + correction secteur × bin(WA, pas=0.5) ===
import json, os, numpy as np, pandas as pd, xgboost as xgb
from packaging import version

# 0) Preconditions: the one-row feature frame exported by cell [2/3].
assert "df_features_ready" in globals(), "df_features_ready manquant. Utilisez la cellule [2/3] pour exporter les features."
df_input = df_features_ready.copy()
# -1 is treated as a "missing" sentinel in every column.
df_input = df_input.replace(-1, np.nan)
# 1) Load artefacts (from Google Drive)
OUT_DIR = "/content/drive/MyDrive/esg_rating_project/rating_module"

mdl = xgb.Booster()
mdl.load_model(f"{OUT_DIR}/xgb_model.json")

# `with` closes the handle deterministically (a bare open() leaked it).
with open(f"{OUT_DIR}/feature_columns.json", "r", encoding="utf-8") as _fh:
    feat_meta = json.load(_fh)
feature_cols_inf = feat_meta["feature_cols"]
cat_cols_inf     = set(feat_meta.get("cat_cols", []))

# 1bis) Sector x bin(WA) bias head — columns: sector, wa_bin, bias[, n_sb].
# wa_bin >= 0 rows are per-bin biases; wa_bin == -1 rows are sector-only
# fallbacks, with sector "__GLOBAL__" holding the global fallback.
path_s = f"{OUT_DIR}/sector_bias.csv"
assert os.path.exists(path_s), "sector_bias.csv introuvable. Exécutez la création de la tête binned d'abord."
_bias_df = pd.read_csv(path_s)

if not {"sector", "wa_bin", "bias"}.issubset(_bias_df.columns):
    raise ValueError("sector_bias.csv doit contenir les colonnes: 'sector', 'wa_bin', 'bias'.")

# Normalise key types to match the inference-side keys built below: sectors as
# stripped strings (read_csv may infer numbers; empty -> "__MISSING__" like
# the inference side), bins rounded to one decimal (step 0.5) so float
# equality in the merge is reliable.
_bias_df["sector"] = (_bias_df["sector"].fillna("__MISSING__").astype(str)
                      .str.strip().replace({"": "__MISSING__"}))
_bias_df["wa_bin"] = pd.to_numeric(_bias_df["wa_bin"], errors="coerce").round(1)
_bias_df["bias"]   = pd.to_numeric(_bias_df["bias"], errors="coerce")

# Lookup tables. NaN biases are dropped so the nearest-bin fallback can land on
# a usable neighbour; duplicate keys are dropped so left merges keep one row
# per scored issuer.
_valid_bias = _bias_df["bias"].notna()
_df_exact = (_bias_df[_valid_bias & (_bias_df["wa_bin"] >= 0)][["sector", "wa_bin", "bias"]]
             .drop_duplicates(["sector", "wa_bin"], keep="first"))
_df_fb_s  = _bias_df[_valid_bias & (_bias_df["wa_bin"] == -1)]

_global_row = _df_fb_s[_df_fb_s["sector"] == "__GLOBAL__"]
global_bias = float(_global_row["bias"].iloc[0]) if len(_global_row) else 0.0
_df_fb_s = (_df_fb_s[_df_fb_s["sector"] != "__GLOBAL__"][["sector", "bias"]]
            .drop_duplicates("sector", keep="first")
            .rename(columns={"bias": "bias_sector"}))

# Structures for the nearest-bin fallback (within the same sector).
_dict_exact = {(s, float(b)): float(v)
               for s, b, v in _df_exact.itertuples(index=False, name=None)}
_bins_by_sector = {
    s: np.sort(g["wa_bin"].astype(float).unique())
    for s, g in _df_exact.groupby("sector")
}

# 2) Align inference columns with the training feature list
for c in feature_cols_inf:
    if c not in df_input.columns:
        df_input[c] = np.nan

# Categorical columns
use_native_cat = version.parse(xgb.__version__) >= version.parse("1.7.0")
for c in cat_cols_inf:
    if c in df_input.columns:
        df_input[c] = df_input[c].astype("category")

X_new = df_input[feature_cols_inf].copy()
dnew = xgb.DMatrix(X_new, enable_categorical=use_native_cat)

# 3) Raw prediction (predicted WA)
pred = mdl.predict(dnew)  # WA_pred

# 4) Additive SECTOR x BIN(WA, step 0.5) correction; the bin comes from
#    ESG_SCORE_PRED when available.
sec_col = "IVA_INDUSTRY"
if sec_col in df_input.columns:
    # Map NaN to "" *before* astype(str): astype(str) turns NaN into the
    # literal "nan", so a later fillna never fires. astype(object) first
    # because the column may be categorical at this point.
    _sec_raw = df_input[sec_col].astype(object)
    sector_new = (
        _sec_raw.where(_sec_raw.notna(), "")
        .astype(str).str.strip()
        .replace({"": "__MISSING__"})
    )
else:
    sector_new = pd.Series(["__MISSING__"] * len(df_input), index=df_input.index)

# Bin source: ESG_SCORE_PRED when present (NaN entries fall back to pred).
# Series.fillna rejects ndarrays, so the fallback is combined element-wise.
if "ESG_SCORE_PRED" in df_input.columns:
    _prev_wa = pd.to_numeric(df_input["ESG_SCORE_PRED"], errors="coerce").to_numpy(dtype=float, na_value=np.nan)
    wa_for_binning = np.where(np.isnan(_prev_wa), pred, _prev_wa)
else:
    wa_for_binning = pred

# Step-0.5 bins -> {0.0, 0.5, 1.0, ..., 9.5}; float64 to match the CSV keys.
WA_clip = np.clip(wa_for_binning, 0.0, 9.999)
wa_bin  = np.round(np.floor(WA_clip * 2) / 2.0, 1).astype(float)

lk = pd.DataFrame({"sector": sector_new.values, "wa_bin": wa_bin}, index=df_input.index)

# 4.1) Exact join on (sector, wa_bin) — adds column 'bias'.
lk = lk.merge(_df_exact, on=["sector", "wa_bin"], how="left")

# 4.2) Fallback nearest-bin (dans le même secteur)
def _nearest_bias(row):
    """Resolve the bias for one lookup row, falling back to the nearest bin.

    Returns ``(bias, "sector+bin")`` when the row already carries an exact
    bias. Otherwise the closest known bin of the same sector
    (``_bins_by_sector``) is looked up in ``_dict_exact`` and
    ``(bias, "nearest-bin")`` is returned. If neither yields a value, the
    result is ``(nan, None)``.
    """
    if pd.notna(row.get("bias")):
        return row["bias"], "sector+bin"

    sector_key = str(row["sector"])
    target_bin = float(row["wa_bin"])
    known_bins = _bins_by_sector.get(sector_key)
    if known_bins is None or len(known_bins) == 0:
        return np.nan, None

    closest_bin = float(known_bins[np.abs(known_bins - target_bin).argmin()])
    found = _dict_exact.get((sector_key, closest_bin))
    if found is None:
        return np.nan, None
    return found, "nearest-bin"

# Row-wise nearest-bin fallback; _nearest_bias returns (bias, source) pairs,
# expanded into columns 0 and 1.
tmp = lk.apply(_nearest_bias, axis=1, result_type="expand")
lk["bias_nearest"]     = pd.to_numeric(tmp[0], errors="coerce")
lk["bias_nearest_src"] = tmp[1]

# 4.3) Sector-only fallback (wa_bin == -1 rows) — adds 'bias_sector'.
# Sectors are de-duplicated so a repeated row cannot multiply lk's rows.
lk = lk.merge(_df_fb_s.drop_duplicates("sector"), on="sector", how="left")

# The lookup must stay one row per scored issuer; duplicated keys in
# sector_bias.csv would otherwise surface later as an opaque length mismatch.
if len(lk) != len(df_input):
    raise ValueError(
        "sector_bias.csv contient des clés dupliquées : la jointure des biais "
        f"a produit {len(lk)} lignes pour {len(df_input)} ligne(s) à scorer."
    )

# 4.4) Final bias vector (priority: exact > nearest-bin > sector-only > global > 0)
bias_vec = lk["bias"].copy()
bias_vec = bias_vec.fillna(lk["bias_nearest"])
bias_vec = bias_vec.fillna(lk["bias_sector"])
bias_vec = bias_vec.fillna(global_bias)
bias_vec = bias_vec.fillna(0.0).astype(float).values

# 4.5) Record which source was used. A NaN global bias contributes nothing
# (the final fillna(0.0) applies), so it is labelled "zero", not "__GLOBAL__".
_uses_global = pd.notna(global_bias) and global_bias != 0.0
bias_source = np.where(~pd.isna(lk["bias"]), "sector+bin",
                np.where(~pd.isna(lk["bias_nearest"]), "nearest-bin",
                np.where(~pd.isna(lk["bias_sector"]), "sector-only",
                "__GLOBAL__" if _uses_global else "zero")))

# 5) Result
pred_corrected = pred + bias_vec

df_scored = df_input.copy()
df_scored["SECTOR_NORM"]          = sector_new.values
df_scored["WA_BIN"]               = wa_bin
df_scored["BIAS_USED"]            = bias_vec
df_scored["BIAS_SOURCE"]          = bias_source
df_scored["ESG_SCORE_PRED"]       = pred                # WA_pred
df_scored["ESG_SCORE_PRED_FINAL"] = pred_corrected      # IA_hat

print("[OK] Scoring terminé. Aperçu :")
display(df_scored[["SECTOR_NORM","WA_BIN","BIAS_USED","BIAS_SOURCE","ESG_SCORE_PRED","ESG_SCORE_PRED_FINAL"]].head(3))


# CUDA FLUSH


In [None]:
# %% === PATCH & FLUSH CUDA GPU CACHE (à exécuter avant (re)load des modèles) ===
import os, gc, time
try:
    import torch
    has_cuda = torch.cuda.is_available()
except Exception:
    has_cuda = False

# Patch the PyTorch CUDA caching-allocator settings to limit fragmentation.
# PyTorch reads PYTORCH_CUDA_ALLOC_CONF when its CUDA allocator is initialised,
# so this only takes effect if CUDA has not been used yet in this process;
# otherwise the runtime must be restarted for it to apply.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,garbage_collection_threshold:0.6,max_split_size_mb:128"
if has_cuda and torch.cuda.is_initialized():
    print("[WARN] CUDA déjà initialisé dans ce processus : PYTORCH_CUDA_ALLOC_CONF "
          "ne prendra effet qu'après redémarrage de l'environnement d'exécution.")

def gpu_mem_report(tag=""):
    """Print a one-line CUDA memory snapshot labelled ``[tag]``.

    Figures are decimal gigabytes (bytes / 1e9): free and total device memory
    from ``torch.cuda.mem_get_info()``, then the reserved, allocated and peak
    allocated bytes reported by ``torch.cuda``. When the module-level
    ``has_cuda`` flag is false, only a short "unavailable" line is printed.
    Returns ``None``.
    """
    if has_cuda:
        to_gb = 1e9
        free_b, total_b = torch.cuda.mem_get_info()
        reserved_gb = torch.cuda.memory_reserved() / to_gb
        allocated_gb = torch.cuda.memory_allocated() / to_gb
        peak_gb = torch.cuda.max_memory_allocated() / to_gb
        print(
            f"[{tag}] GPU: free={free_b / to_gb:.2f}GB / total={total_b / to_gb:.2f}GB | "
            f"reserved={reserved_gb:.2f}GB | "
            f"allocated={allocated_gb:.2f}GB | "
            f"max_alloc={peak_gb:.2f}GB"
        )
    else:
        print(f"[{tag}] CUDA indisponible.")

# Report memory, drop known heavy globals, run GC, flush the CUDA caching
# allocator, then report memory again.
gpu_mem_report("AVANT")

# Optional: remove the global bindings of known heavy objects (adjust the list
# to your notebook). `del` only drops the name; memory is released only once no
# other reference remains (closures or IPython output history may still hold
# one). NOTE: on_export_clicked reads eng.cache_dir, so `eng` must be
# re-created before exporting again.
for name in ("eng","model","embedding_model","donut_model","reranker","tokenizer"):
    if name in globals():
        try:
            del globals()[name]
            print(f"del {name}")
        except Exception:
            pass

# Python GC first, so tensors that just lost their last reference are freed
# before the CUDA cache is flushed.
gc.collect()
time.sleep(0.1)
if has_cuda:
    try:
        torch.cuda.empty_cache()        # release unused cached blocks held by the caching allocator
        torch.cuda.ipc_collect()        # reclaim memory released via CUDA IPC
        torch.cuda.reset_peak_memory_stats()
        # (optional) allow TF32 for matmul / cuDNN on recent NVIDIA GPUs
        try:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
        except Exception:
            pass
    except Exception as e:
        print("CUDA flush warning:", e)

time.sleep(0.1)
gpu_mem_report("APRÈS")

print("\n✅ Patch OK. Relance maintenant le chargement de tes modèles (HF/LLM/Donut).")
