In [1]:
import sys
import os

# Ruta absoluta a la carpeta src
sys.path.append(os.path.abspath("../src"))

from utils_TATR import outputs_to_objects, build_grid_with_spans, fill_grid_with_ocr, fill_grid_from_global_ocr_centered, draw_tatr_overlays_multi


In [2]:
import torch
from transformers import TableTransformerForObjectDetection
from torchvision import transforms  # 👈 aquí está 'transforms'

#from tatr_ocr.transforms_tatr import make_structure_transform, to_model_batch



  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from PIL import Image

cropped_table = Image.open("tabla.jpg").convert("RGB")


In [5]:
from typing import Sequence
from dataclasses import dataclass

@dataclass
class MaxResize:
    """
    Redimensiona una imagen manteniendo aspecto para que el lado mayor sea `max_size`.
    A diferencia de la versión "segura", esta implementación **también hace upscaling**
    si la imagen es más chica que `max_size`.

    Args:
        max_size (int): Tamaño máximo del lado mayor (la imagen resultante siempre tendrá
            su lado mayor igual a este valor).
        resample (int): Filtro de remuestreo de PIL (por defecto Image.BILINEAR).

    Returns:
        Image.Image: Imagen redimensionada con el nuevo tamaño.

    Raises:
        TypeError: Si la entrada no es una instancia de PIL.Image.Image.
        ValueError: Si la imagen tiene dimensiones inválidas (<= 0).

    Examples:
        >>> img = Image.open("ejemplo.jpg")
        >>> transform = MaxResize(max_size=800)
        >>> out = transform(img)
        >>> out.size
        (800, 533)  # si la original era 1200x800
    """
    max_size: int = 800
    resample: int = Image.BILINEAR

    def __call__(self, image: Image.Image) -> Image.Image:
        if not isinstance(image, Image.Image):
            raise TypeError(f"Se esperaba PIL.Image.Image, recibido: {type(image)}")

        width, height = image.size
        if width <= 0 or height <= 0:
            raise ValueError(f"Tamaño de imagen inválido: {image.size}")

        current_max = max(width, height)
        scale = self.max_size / float(current_max)
        new_w = max(1, int(round(scale * width)))
        new_h = max(1, int(round(scale * height)))

        return image.resize((new_w, new_h), resample=self.resample)


def make_structure_transform(
    max_size: int = 1000,
    mean: Sequence[float] = (0.485, 0.456, 0.406),
    std: Sequence[float] = (0.229, 0.224, 0.225),
):
    """
    Construye un pipeline `transforms.Compose` para preparar imágenes de estructura/tablas.

    El pipeline incluye:
    - MaxResize: asegura que el lado mayor quede exactamente en `max_size`.
    - ToTensor: convierte a tensor (C,H,W) en [0,1].
    - Normalize: normaliza con `mean` y `std` (por defecto, ImageNet).

    Args:
        max_size (int): Tamaño máximo del lado mayor tras redimensionar.
        mean (Sequence[float]): Medias para normalización.
        std (Sequence[float]): Desvíos estándar para normalización.

    Returns:
        transforms.Compose: Transformación compuesta lista para usar.

    Raises:
        RuntimeError: Si torchvision no está disponible en el entorno.
    """
    if transforms is None:
        raise RuntimeError("torchvision no está disponible en el entorno.")

    return transforms.Compose([
        MaxResize(max_size=max_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])


In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# cargar modelo en el notebook
structure_model = TableTransformerForObjectDetection.from_pretrained(
    "microsoft/table-structure-recognition-v1.1-all"
).to(device)

structure_transform = transforms.Compose([
    MaxResize(1000),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


In [7]:
pixel_values = structure_transform(cropped_table).unsqueeze(0)
pixel_values = pixel_values.to(device)
print(pixel_values.shape)

torch.Size([1, 3, 690, 1000])


In [8]:
# forward pass
with torch.no_grad():
  outputs = structure_model(pixel_values)

id2label me tira con las opciones que me puede dar la tabla que son:

* table
* table header
* etc

In [9]:
# armás el id2label según tu proyecto
structure_id2label = structure_model.config.id2label
structure_id2label[len(structure_id2label)] = "no object"
# si en tu caso no hace falta "no object", no lo agregás

cells = outputs_to_objects(outputs, cropped_table.size, structure_id2label)

print(cells)

[{'label': 'table column', 'score': 0.9999914169311523, 'bbox': [71.0396499633789, 18.333629608154297, 372.3055419921875, 692.390380859375]}, {'label': 'table row', 'score': 0.9998055100440979, 'bbox': [8.559820175170898, 167.4796600341797, 1030.1776123046875, 271.1231994628906]}, {'label': 'table column', 'score': 0.9999634027481079, 'bbox': [835.1494140625, 19.05427360534668, 938.7198486328125, 692.7470092773438]}, {'label': 'table column', 'score': 0.9999054670333862, 'bbox': [492.745361328125, 18.53667449951172, 746.7467041015625, 692.6864013671875]}, {'label': 'table column', 'score': 0.9998608827590942, 'bbox': [747.0014038085938, 18.84082794189453, 833.9036865234375, 692.5086059570312]}, {'label': 'table row', 'score': 0.9994695782661438, 'bbox': [8.504619598388672, 75.78985595703125, 1029.8843994140625, 113.44821166992188]}, {'label': 'table column', 'score': 0.9999246597290039, 'bbox': [938.3115234375, 19.059001922607422, 1030.29296875, 692.8472290039062]}, {'label': 'table co

In [10]:
import numpy as np
# 3) Construir grilla con merges de spans
pack = build_grid_with_spans(cells)

# 4) Completar OCR por celda (sobre la imagen original)
pack = fill_grid_with_ocr(
    grid_pack=pack,
    image_path="tabla.jpg",           # ruta a la misma imagen usada para el modelo
    tess_cfg="--oem 3 --psm 6",       # ajustá PSM si hace falta
    skip_headers=False,               # poné True si querés saltear headers
)

# 4) Acceder a resultados
print("Rows:", pack["n_rows"])
print("Cols:", pack["n_cols"])
print("Cells (counted):", pack["cells_counted"])
print("Header rows:", pack["header_rows"])
print("Row header cols:", pack["row_header_cols"])

# 5) Ejemplo: recorrer la grilla
grid = pack["grid"]
for r in range(pack["n_rows"]):
    row_texts = [cell["text"] for cell in grid[r] if not cell["covered"]]
    print(f"Row {r}: {row_texts}")

Rows: 11
Cols: 7
Cells (counted): 77
Header rows: [0]
Row header cols: []
Row 0: ['Rank', 'Protein domains', 'Relevant pathway', 'Cellular component', 'Gene count', 'Quality”', 'P-value']
Row 1: ['1', 'Phospholipase A2', 'hsa0592', 'Extracellular region (63%)', '20', 'O878', '434E-08']
Row 2: ['>', 'THF dehydrogenase; formy] transferase; SHMT;', 'hsa00670', 'Mitochondrion (44%)', '', '0891', '1.79E-07']
Row 3: ['3', 'Sialyltransferase; GTF, family 31; GTF, family 11; GH, family 20', 'hsa00603', 'Golgi apparatus (76%)', '17', '0.756', '2.7TE-06']
Row 4: ['4', 'Four-helical cytokine, core: IL-4; IL-17, TNF 2; Peroxidases heam-ligand binding site: Toll-ILR MHC class I, a chain, al and a2', 'hsa05310', 'Extracellular space (52%)', '27', '0598', '3.72E-06']
Row 5: ['5', 'Immunoglobulin C-Type', 'hsa05310', 'MHC class II protein complex (10046)', '9', '0.793', 'S.98E-04']
Row 6: ['6', 'GPCR, rhodopsin-like superfamily', 'NA', 'Integral to plasma membrane (75%)', '12', 'OS5T8', '7 86E-04']
Ro

In [11]:
# 3) Armar grilla (con spans/headers)
pack = build_grid_with_spans(cells, iou_th=0.6, overlap_th=0.5)

# 4) OCR global → asignación por centro
pack = fill_grid_from_global_ocr_centered(
    grid_pack=pack,
    image_path="tabla.jpg",
    tess_cfg="--oem 3 --psm 6",
    min_conf=0,          # subí si hay ruido (por ej., 40-60)
    joiner=" ",
    skip_headers=False,
)

# 5) Visualizar en Colab/Jupyter
# 5) Ejemplo: recorrer la grilla
grid = pack["grid"]
for r in range(pack["n_rows"]):
    row_texts = [cell["text"] for cell in grid[r] if not cell["covered"]]
    print(f"Row {r}: {row_texts}")

Row 0: ['Rank', 'Protein domains', 'Relevant pathway', 'Cellular component', 'Gene count', 'Quality”', 'P-value']
Row 1: ['1', 'Phospholipase A2', 'hsa0592', 'Extracellular region (63%)', '20', 'O878', '34E-08 4']
Row 2: ['2', 'THF dehydrogenase; formy] transferase; SHMT;', 'hsal0670—-', 'Mitochondrion (44%)', '18', '0891', '1.79B-07']
Row 3: ['3', 'Sialyltransferase; GTF, family 31; GTF, family 11; GH, family 20', 'hsa00603', 'Golgi (76%) apparatus', '17', '0.756', '2.77TE-06']
Row 4: ['4', 'Four-helical cytokine, core: IL-4: IL-17, TNF 2; Peroxidases heam-ligand binding site: Toll-ILR a MHC class I, chain, and «2 acl', 'hsa05310', 'Extracellular (52%) space', '27', '0598', '3.72E-06']
Row 5: ['3', 'Immunoglobulin C-Type', 'hsa05310', 'MHC class II protein complex (100%)', '9', '0.793', 'S.98E-04']
Row 6: ['6', 'GPCR, rhodopsin-like superfamily', 'nan', 'Integral plasma to membrane (75%)', '12', 'O578', '7R6E-O4']
Row 7: ['7', 'nan', 'nan', 'Synaptic vesicle (43%)', '8', '0.701', '0.0

## Visualización de imagenes por tipo (opcional)

In [None]:
from TATR_OCR import draw_tatr_overlays_multi

# detections es la lista de dicts que obtuviste con outputs_to_objects(...)
detections = cells  

# ruta de la imagen original que procesaste con el modelo
image_path = "table.jpg"

# carpeta de salida (ej: "../temp" si está fuera de notebooks/)
out_dir = "../temp"

# generar las imágenes con overlays
paths = draw_tatr_overlays_multi(
    image_path=image_path,
    detections=detections,
    out_dir=out_dir,
    alpha=0.3,       # transparencia (0=sin relleno, 1=totalmente opaco)
    thickness=2,     # grosor de los bordes
    put_labels=True  # si mostrar etiquetas de clase y score
)

print("Imágenes generadas:")
for k, v in paths.items():
    print(f"{k}: {v}")


## Entrenamiento

Se van a sacar las siguientes estadísticas para el estudio de 100 imágenes random del dataset de pubtabnet:

* Tamaño de la imagen
* Cantidad de filas reales
* Cantidad de filas predichas
* Cantidad de columnas reales
* Cantidad de columnas predichas
* Cantidad de celdas reales
* Cantidad de celdas predichas
* Row Precision
* Column Precision
* Cell Precision
* WRC Global
* WCC Global
* CER Promedio
* CER Global

$\text{precision} = 1 - \dfrac{|\text{pred} - \text{gt}|}{\text{gt}}$


In [12]:
# Funciones para tratar con el ground truth en formato json
from typing import Dict, Literal, Any, List
from jiwer import cer

def count_structure_from_pubtabnet(gt: Dict) -> Dict[str, int]:
    """Cuenta rows, cols y cells usando html.structure.tokens de PubTabNet-like.

    Reglas:
    - rows: cantidad de <tr>
    - cols: máximo nº de celdas por fila (cuenta <td> y <th> entre <tr> ... </tr>)
    - cells: total de celdas (<td> + <th>) en toda la tabla

    Args:
        gt: dict con clave 'html' -> {'structure': {'tokens': [...]}}.

    Returns:
        {'rows': int, 'cols': int, 'cells': int}
    """
    tokens: List[str] = gt["html"]["structure"]["tokens"]
    rows = 0
    max_cols = 0
    total_cells = 0

    in_row = False
    cells_in_current_row = 0

    i = 0
    while i < len(tokens):
        tok = tokens[i]

        if tok == "<tr>":
            # cerrar fila previa si quedó abierta (por robustez)
            if in_row:
                rows += 1
                max_cols = max(max_cols, cells_in_current_row)
                cells_in_current_row = 0
            in_row = True

        elif tok in ("</tr>",):
            if in_row:
                rows += 1
                max_cols = max(max_cols, cells_in_current_row)
                cells_in_current_row = 0
                in_row = False

        elif tok in ("<td>", "<th>"):
            total_cells += 1
            if in_row:
                cells_in_current_row += 1

        # ignoramos otros tokens (<thead>, </thead>, <tbody>, </tbody>, </td>, </th>, etc.)
        i += 1

    # si terminó el stream con una fila abierta (sin </tr>)
    if in_row:
        rows += 1
        max_cols = max(max_cols, cells_in_current_row)

    return {"rows": rows, "cols": max_cols, "cells": total_cells}

def structure_precision_counts(
    pred_stats: Dict[str, int],
    gt_stats: Dict[str, int],
) -> Dict[str, Dict[str, float]]:
    """Compara conteos de estructura (rows/cols/cells) entre pred y GT y calcula precisión.

    La métrica se define como:
        precision = 1 - |pred - gt| / gt

    Args:
        pred_stats: {'rows': int, 'cols': int, 'cells': int} predichos.
        gt_stats:   {'rows': int, 'cols': int, 'cells': int} ground truth.

    Returns:
        Dict con claves 'rows', 'cols', 'cells'. Cada una contiene:
            - gt: int
            - pred: int
            - delta: int (pred - gt)
            - precision: float entre 0 y 1
    """
    out: Dict[str, Dict[str, float]] = {}
    for key in ("rows", "cols", "cells"):
        gt_val = int(gt_stats.get(key, 0))
        pred_val = int(pred_stats.get(key, 0))
        delta = pred_val - gt_val

        if gt_val == 0:
            # si el GT no tiene valor, definimos precision = 1 si pred==0, si no 0
            precision = 1.0 if pred_val == 0 else 0.0
        else:
            precision = 1 - abs(delta) / gt_val
            precision = max(0.0, min(1.0, precision))  # clamp a [0,1]

        out[key] = {
            "gt": gt_val,
            "pred": pred_val,
            "delta": delta,
            "precision": precision,
        }
    return out


In [13]:

def count_words_from_pubtabnet(gt: Dict[str, Any]) -> int:
    """Reconstruye texto de celdas y cuenta palabras en todo el GT.

    Args:
        gt: dict con clave 'html' -> {'cells': [{'tokens': [...]}]}.

    Returns:
        Número total de palabras.
    """
    total_words = 0
    for cell in gt["html"]["cells"]:
        tokens = cell.get("tokens", [])
        # unir tokens en un string completo
        text = "".join(
            t for t in tokens 
            if not (t.startswith("<") and t.endswith(">"))  # descartar tags
        )
        # dividir por espacios
        words = [w for w in text.split(" ") if w.strip()]
        total_words += len(words)
    return total_words

def reconstruct_gt_cells(gt: Dict[str, Any]) -> List[str]:
    """Reconstruye el texto de cada celda en GT PubTabNet."""
    cells = []
    for cell in gt["html"]["cells"]:
        tokens = cell.get("tokens", [])
        text = "".join(t for t in tokens if not (t.startswith("<") and t.endswith(">")))
        text = " ".join(text.split())  # normalizar espacios
        cells.append(text)
    return cells

def flatten_pred_rows(pred_rows: List[List[str]]) -> List[str]:
    """Convierte predicción en lista plana de celdas."""
    return [c.strip() for row in pred_rows for c in row]

def wrc_single(n_gt: int, n_pred: int) -> float:
    """Word Rate Count para un par de conteos."""
    if n_gt == 0:
        return 1.0 if n_pred == 0 else 0.0
    wrc = 1 - abs(n_pred - n_gt) / n_gt
    return max(0.0, min(1.0, wrc))

def wrc_global(gt_cells: List[str], pred_cells: List[str]) -> float:
    """WRC considerando todas las palabras de la tabla."""
    n_gt = sum(len(c.split()) for c in gt_cells)
    n_pred = sum(len(c.split()) for c in pred_cells)
    return wrc_single(n_gt, n_pred)

def wrc_cellwise(gt_cells: List[str], pred_cells: List[str]) -> float:
    """WRC promedio celda a celda."""
    n = len(gt_cells)
    if n == 0:
        return 1.0
    scores = []
    for i in range(n):
        gt_text = gt_cells[i] if i < len(gt_cells) else ""
        pred_text = pred_cells[i] if i < len(pred_cells) else ""
        scores.append(wrc_single(len(gt_text.split()), len(pred_text.split())))
    return sum(scores) / len(scores)

def cer_pair(gt: str, pred: str) -> float:
    """CER entre dos textos normalizados."""
    gt_norm = " ".join(gt.split())
    pred_norm = " ".join(pred.split())
    if not gt_norm:
        return 0.0 if not pred_norm else 1.0
    return float(cer(gt_norm, pred_norm))

def cer_global(gt_cells: List[str], pred_cells: List[str]) -> float:
    """CER considerando toda la tabla concatenada."""
    gt_text = " ".join(" ".join(c.split()) for c in gt_cells).strip()
    pred_text = " ".join(" ".join(c.split()) for c in pred_cells).strip()
    return cer_pair(gt_text, pred_text)

def cer_cellwise(gt_cells: List[str], pred_cells: List[str]) -> float:
    """CER promedio celda a celda."""
    n = len(gt_cells)
    if n == 0:
        return 0.0
    scores = []
    for i in range(n):
        gt_text = gt_cells[i] if i < len(gt_cells) else ""
        pred_text = pred_cells[i] if i < len(pred_cells) else ""
        scores.append(cer_pair(gt_text, pred_text))
    return sum(scores) / len(scores)

In [None]:
import json
from pathlib import Path
from PIL import Image
import pandas as pd

# --- Asumimos que ya tenés estas funciones definidas ---
# count_structure_from_pubtabnet
# structure_precision_counts
# reconstruct_gt_cells
# flatten_pred_rows
# wrc_global, wrc_cellwise
# cer_global, cer_cellwise

# --- Paths ---
image_dir = Path("..\\data\\regions\\table")               # carpeta con imágenes
gt_path = Path("..\\data\\annotations\\ocr_table_labels.json")      # ground truth PubTabNet-style

# --- Cargar GT en un diccionario: filename -> objeto completo ---
gt_map = {}
with gt_path.open("r", encoding="utf-8") as f:
    for line in f:
        if not line.strip():
            continue
        obj = json.loads(line)
        filename = obj.get("filename")
        if filename:
            gt_map[filename] = obj

# --- Iterar imágenes ---
rows_out = []
exts = {".png", ".jpg", ".jpeg"}

for img_path in sorted(image_dir.iterdir()):
    if img_path.suffix.lower() not in exts:
        continue

    filename = img_path.name
    gt = gt_map.get(filename)
    if gt is None:
        print(f"⚠️ No GT para {filename}, se saltea")
        continue

    # --- Tamaño de la imagen ---
    img = Image.open(img_path).convert("RGB")
    img_w, img_h = img.size

    # --- Conteos reales (GT) ---
    gt_counts = count_structure_from_pubtabnet(gt)
    gt_cells = reconstruct_gt_cells(gt)

    # --- Conteos predichos y textos ---
    # acá usás tu pipeline para obtener pack["grid"]
    # 1) Aplico el modelo de TATR a la imagen
    structure_model = TableTransformerForObjectDetection.from_pretrained(
        "microsoft/table-structure-recognition-v1.1-all"
    ).to(device)

    structure_transform = transforms.Compose([
        MaxResize(1000),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    
    pixel_values = structure_transform(img).unsqueeze(0)
    pixel_values = pixel_values.to(device)
    
    with torch.no_grad():
        outputs = structure_model(pixel_values)
    
    # 2) Paso los datos crudos del tatr a lo legible
    structure_id2label = structure_model.config.id2label
    structure_id2label[len(structure_id2label)] = "no object"
    cells = outputs_to_objects(outputs, img.size, structure_id2label)
    
    # 3) Construir grilla con merges de spans
    pack = build_grid_with_spans(cells)

    # 4) Completar OCR por celda (sobre la imagen original)
    pack = fill_grid_with_ocr(
        grid_pack=pack,
        image_path="tabla.jpg",           # ruta a la misma imagen usada para el modelo
        tess_cfg="--oem 3 --psm 6",       # ajustá PSM si hace falta
        skip_headers=False,               # poné True si querés saltear headers
    )
    grid = pack["grid"]

    pred_rows = []
    for r in range(pack["n_rows"]):
        row_texts = [cell["text"] for cell in grid[r] if not cell["covered"]]
        pred_rows.append(row_texts)

    pred_cells = flatten_pred_rows(pred_rows)
    pred_counts = {
        "rows": pack["n_rows"],
        "cols": pack["n_cols"],
        "cells": pack["cells_counted"],  # asumimos que tu pack trae esto
    }

    # --- Precisiones estructurales ---
    comp = structure_precision_counts(pred_counts, gt_counts)

    # --- Métricas WRC y CER ---
    wrc_g = wrc_global(gt_cells, pred_cells)
    wrc_c = wrc_cellwise(gt_cells, pred_cells)
    cer_g = cer_global(gt_cells, pred_cells)
    cer_c = cer_cellwise(gt_cells, pred_cells)

    # --- Registrar resultados ---
    rows_out.append({
        "filename": filename,
        "img_w": img_w,
        "img_h": img_h,
        "gt_rows": gt_counts["rows"],
        "pred_rows": pred_counts["rows"],
        "gt_cols": gt_counts["cols"],
        "pred_cols": pred_counts["cols"],
        "gt_cells": gt_counts["cells"],
        "pred_cells": pred_counts["cells"],
        "row_precision": comp["rows"]["precision"],
        "col_precision": comp["cols"]["precision"],
        "cell_precision": comp["cells"]["precision"],
        "wrc_global": wrc_g,
        "wrc_avg": wrc_c,
        "cer_global": cer_g,
        "cer_avg": cer_c,
    })

# --- DataFrame final ---
df_results_1 = pd.DataFrame(rows_out)
df_results_1.to_csv("metrics_results_1.csv", index=False)
print(df_results.head())


⚠️ No GT para PMC1180437_003_01.png, se saltea
⚠️ No GT para PMC1215488_007_00.png, se saltea
⚠️ No GT para PMC1796903_013_00.png, se saltea
⚠️ No GT para PMC2174470_003_00.png, se saltea
⚠️ No GT para PMC2654114_006_00.png, se saltea
⚠️ No GT para PMC2679760_005_00.png, se saltea
⚠️ No GT para PMC2688351_005_00.png, se saltea
⚠️ No GT para PMC2741432_008_00.png, se saltea
⚠️ No GT para PMC2781010_008_00.png, se saltea
⚠️ No GT para PMC2837001_002_01.png, se saltea
⚠️ No GT para PMC3042936_004_00.png, se saltea
⚠️ No GT para PMC3103444_003_00.png, se saltea
⚠️ No GT para PMC3213066_006_00.png, se saltea
⚠️ No GT para PMC3284427_006_01.png, se saltea
⚠️ No GT para PMC3296643_002_00.png, se saltea
⚠️ No GT para PMC3349608_002_00.png, se saltea
⚠️ No GT para PMC3399222_010_00.png, se saltea
⚠️ No GT para PMC3414058_003_00.png, se saltea
⚠️ No GT para PMC3426469_004_00.png, se saltea
⚠️ No GT para PMC3442972_003_00.png, se saltea
⚠️ No GT para PMC3448500_004_00.png, se saltea
⚠️ No GT para

In [15]:
df_results

Unnamed: 0,filename,img_w,img_h,gt_rows,pred_rows,gt_cols,pred_cols,gt_cells,pred_cells,row_precision,col_precision,cell_precision,wrc_global,wrc_avg,cer_global,cer_avg
0,PMC1232864_004_01.png,503,157,13,13,4,4,44,51,1.000000,1.0,0.840909,0.170732,0.176087,0.932292,1.012587
1,PMC1310919_004_00.png,344,103,10,10,6,6,60,60,1.000000,1.0,1.000000,0.065789,0.062500,0.965035,0.994815
2,PMC1534032_004_01.png,503,373,34,38,2,2,68,76,0.882353,1.0,0.882353,0.408451,0.044118,0.883234,2.230949
3,PMC1559691_002_00.png,503,557,43,50,6,6,258,300,0.837209,1.0,0.837209,0.308793,0.250111,0.880734,0.932266
4,PMC1570152_006_00.png,246,137,11,11,7,7,77,77,1.000000,1.0,1.000000,0.093023,0.071429,0.961864,1.010823
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,PMC5876629_005_00.png,317,110,7,7,3,3,17,19,1.000000,1.0,0.882353,0.291667,0.131579,0.864662,1.391604
96,PMC5928814_004_01.png,243,212,19,18,3,3,54,54,0.947368,1.0,1.000000,0.189542,0.100649,0.908405,1.043463
97,PMC5977059_004_00.png,502,194,17,17,9,9,141,153,1.000000,1.0,0.914894,0.165957,0.244681,0.933740,0.897559
98,PMC6025574_003_00.png,404,337,27,27,7,7,183,184,1.000000,1.0,0.994536,0.482270,0.300901,0.903651,0.795572


In [16]:
# mean de todas las columnas numéricas
df_mean = df_results.mean(numeric_only=True)

print(df_mean)



img_w             390.360000
img_h             221.510000
gt_rows            15.480000
pred_rows          15.840000
gt_cols             4.830000
pred_cols           4.890000
gt_cells           65.870000
pred_cells         72.870000
row_precision       0.968008
col_precision       0.980726
cell_precision      0.906645
wrc_global          0.286189
wrc_avg             0.188910
cer_global          0.898016
cer_avg             1.105488
dtype: float64


Mismo ejercicio pero con un scale a la tabla antes de aplicar OCR.

In [None]:
# --- Paths ---
image_dir = Path("..\\data\\regions\\table")               # carpeta con imágenes
gt_path = Path("..\\data\\annotations\\ocr_table_labels.json")      # ground truth PubTabNet-style

# --- Cargar GT en un diccionario: filename -> objeto completo ---
gt_map = {}
with gt_path.open("r", encoding="utf-8") as f:
    for line in f:
        if not line.strip():
            continue
        obj = json.loads(line)
        filename = obj.get("filename")
        if filename:
            gt_map[filename] = obj

# --- Iterar imágenes ---
rows_out = []
exts = {".png", ".jpg", ".jpeg"}

for img_path in sorted(image_dir.iterdir()):
    if img_path.suffix.lower() not in exts:
        continue

    filename = img_path.name
    gt = gt_map.get(filename)
    if gt is None:
        print(f"⚠️ No GT para {filename}, se saltea")
        continue

    # --- Tamaño de la imagen ---
    # --- Tamaño original con PIL ---
    img_pil = Image.open(img_path).convert("RGB")
    orig_w, orig_h = img_pil.size

    # --- Convertir a numpy (para usar cv2) ---
    img_np = np.array(img_pil)

    # --- Escalar con cv2 ---
    img_np = cv2.resize(
        img_np,
        (orig_w * 2, orig_h * 2),
        interpolation=cv2.INTER_CUBIC   # o cv2.INTER_LANCZOS4 para más calidad
    )

    # --- Volver a PIL para usar con TATR ---
    img = Image.fromarray(img_np)

    # --- Nuevo tamaño ---
    img_w, img_h = img.size

    # --- Conteos reales (GT) ---
    gt_counts = count_structure_from_pubtabnet(gt)
    gt_cells = reconstruct_gt_cells(gt)

    # --- Conteos predichos y textos ---
    # acá usás tu pipeline para obtener pack["grid"]
    # 1) Aplico el modelo de TATR a la imagen
    structure_model = TableTransformerForObjectDetection.from_pretrained(
        "microsoft/table-structure-recognition-v1.1-all"
    ).to(device)

    structure_transform = transforms.Compose([
        MaxResize(1000),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    
    pixel_values = structure_transform(img).unsqueeze(0)
    pixel_values = pixel_values.to(device)
    
    with torch.no_grad():
        outputs = structure_model(pixel_values)
    
    # 2) Paso los datos crudos del tatr a lo legible
    structure_id2label = structure_model.config.id2label
    structure_id2label[len(structure_id2label)] = "no object"
    cells = outputs_to_objects(outputs, img.size, structure_id2label)
    
    # 3) Construir grilla con merges de spans
    pack = build_grid_with_spans(cells)

    # 4) Completar OCR por celda (sobre la imagen original)
    pack = fill_grid_with_ocr(
        grid_pack=pack,
        image_path="tabla.jpg",           # ruta a la misma imagen usada para el modelo
        tess_cfg="--oem 3 --psm 6",       # ajustá PSM si hace falta
        skip_headers=False,               # poné True si querés saltear headers
    )
    grid = pack["grid"]

    pred_rows = []
    for r in range(pack["n_rows"]):
        row_texts = [cell["text"] for cell in grid[r] if not cell["covered"]]
        pred_rows.append(row_texts)

    pred_cells = flatten_pred_rows(pred_rows)
    pred_counts = {
        "rows": pack["n_rows"],
        "cols": pack["n_cols"],
        "cells": pack["cells_counted"],  # asumimos que tu pack trae esto
    }

    # --- Precisiones estructurales ---
    comp = structure_precision_counts(pred_counts, gt_counts)

    # --- Métricas WRC y CER ---
    wrc_g = wrc_global(gt_cells, pred_cells)
    wrc_c = wrc_cellwise(gt_cells, pred_cells)
    cer_g = cer_global(gt_cells, pred_cells)
    cer_c = cer_cellwise(gt_cells, pred_cells)

    # --- Registrar resultados ---
    rows_out.append({
        "filename": filename,
        "img_w": img_w,
        "img_h": img_h,
        "gt_rows": gt_counts["rows"],
        "pred_rows": pred_counts["rows"],
        "gt_cols": gt_counts["cols"],
        "pred_cols": pred_counts["cols"],
        "gt_cells": gt_counts["cells"],
        "pred_cells": pred_counts["cells"],
        "row_precision": comp["rows"]["precision"],
        "col_precision": comp["cols"]["precision"],
        "cell_precision": comp["cells"]["precision"],
        "wrc_global": wrc_g,
        "wrc_avg": wrc_c,
        "cer_global": cer_g,
        "cer_avg": cer_c,
    })

# --- DataFrame final ---
df_results_1 = pd.DataFrame(rows_out)
df_results_1.to_csv("metrics_results_1.csv", index=False)
print(df_results_1.head())


⚠️ No GT para PMC1180437_003_01.png, se saltea
⚠️ No GT para PMC1215488_007_00.png, se saltea
⚠️ No GT para PMC1796903_013_00.png, se saltea
⚠️ No GT para PMC2174470_003_00.png, se saltea
⚠️ No GT para PMC2654114_006_00.png, se saltea
⚠️ No GT para PMC2679760_005_00.png, se saltea
⚠️ No GT para PMC2688351_005_00.png, se saltea
⚠️ No GT para PMC2741432_008_00.png, se saltea
⚠️ No GT para PMC2781010_008_00.png, se saltea
⚠️ No GT para PMC2837001_002_01.png, se saltea
⚠️ No GT para PMC3042936_004_00.png, se saltea
⚠️ No GT para PMC3103444_003_00.png, se saltea
⚠️ No GT para PMC3213066_006_00.png, se saltea
⚠️ No GT para PMC3284427_006_01.png, se saltea
⚠️ No GT para PMC3296643_002_00.png, se saltea
⚠️ No GT para PMC3349608_002_00.png, se saltea
⚠️ No GT para PMC3399222_010_00.png, se saltea
⚠️ No GT para PMC3414058_003_00.png, se saltea
⚠️ No GT para PMC3426469_004_00.png, se saltea
⚠️ No GT para PMC3442972_003_00.png, se saltea
⚠️ No GT para PMC3448500_004_00.png, se saltea
⚠️ No GT para

In [18]:
# mean de todas las columnas numéricas
df_mean = df_results_1.mean(numeric_only=True)

print(df_mean)

img_w             780.720000
img_h             443.020000
gt_rows            15.480000
pred_rows          15.820000
gt_cols             4.830000
pred_cols           4.890000
gt_cells           65.870000
pred_cells         72.980000
row_precision       0.965003
col_precision       0.980726
cell_precision      0.902942
wrc_global          0.655685
wrc_avg             0.293468
cer_global          0.851446
cer_avg             1.597833
dtype: float64
