In [3]:
# %% [markdown]
# # FinTabNet Training with GTE-Inspired Joint Detection Model
# ## Objective: Train a single Detectron2 model to detect both tables and cells jointly,
# ## incorporating a GTE-inspired cell containment loss.

# %% [markdown]
# ## 1. Imports and Setup

# %%
import json
import os
import datetime
import pickle
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from PIL import Image, ImageDraw
from glob import glob
import matplotlib.pyplot as plt
from IPython.core.display import display, HTML
import copy
import sys
import cv2 # Needed for cropping and visualization
import random
import logging # Import standard logging
import fitz # Import PyMuPDF

# Setup logger
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# --- Detectron2 Imports ---
from detectron2.structures import BoxMode, Instances, Boxes, ImageList, pairwise_iou
from detectron2.engine import DefaultTrainer, DefaultPredictor, hooks
from detectron2.config import get_cfg, CfgNode, configurable
from detectron2 import model_zoo
from detectron2.data import (
    DatasetCatalog,
    MetadataCatalog,
    build_detection_test_loader,
    build_detection_train_loader,
    DatasetMapper,
    detection_utils as utils,
    transforms as T
)
from detectron2.data.detection_utils import SizeMismatchError # Import the specific error
from detectron2.evaluation import COCOEvaluator, inference_on_dataset
from detectron2.modeling import (
    build_model,
    build_backbone,
    build_proposal_generator,
    build_roi_heads,
    META_ARCH_REGISTRY,
    ROI_HEADS_REGISTRY,
    StandardROIHeads # Use standard ROI heads initially
)
from detectron2.modeling.postprocessing import detector_postprocess
from detectron2.checkpoint import DetectionCheckpointer
import detectron2.utils.comm as comm
from detectron2.utils.events import EventStorage, get_event_storage
from detectron2.utils.visualizer import Visualizer, ColorMode # For drawing predictions
from detectron2.layers import cat

  from IPython.core.display import display, HTML


In [5]:
# %% [markdown]
# ## 2. Constants and Configuration Setup

# %%
# --- Constants ---
# Define categories for the *joint* model
# IMPORTANT: Ensure these IDs match the ones assigned during data conversion
categories_joint = ["table", "cell"]
TABLE_CAT_ID = 0
CELL_CAT_ID = 1
NUM_CLASSES = len(categories_joint)

# Colors for visualization (BGR) - Ensure order matches categories_joint
colors = [(0, 0, 255), (0, 255, 0)] # Red for table, Green for cell

# --- Configuration ---
# Recommend using FinTabNet.c if available, adjust paths accordingly
# For now, using original paths provided
BASE_DIR = 'fintabnet'
#!! IMPORTANT: Ensure this points to the folder with ORIGINAL PDFs!!
PDF_FOLDER = os.path.join(BASE_DIR, 'pdf')
#!! Consider switching to FinTabNet.c JSONL files if available!!
TRAIN_JSONL = os.path.join(BASE_DIR, 'FinTabNet_1.0.0_table_train.jsonl')
VAL_JSONL = os.path.join(BASE_DIR, 'FinTabNet_1.0.0_table_val.jsonl')
TEST_JSONL = os.path.join(BASE_DIR, 'FinTabNet_1.0.0_table_test.jsonl')

# Cache directory for processed data (increment version if logic changes)
CACHE_DIR = os.path.join(BASE_DIR, "converted_cache_gte_v1")
os.makedirs(CACHE_DIR, exist_ok=True)
os.makedirs(PDF_FOLDER, exist_ok=True) # Ensure PDF folder exists

# Output directory for model checkpoints and logs
OUTPUT_DIR = "./output_fintabnet_gte_model_v1"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# --- Base Model Config ---
cfg = get_cfg()
cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Use Faster R-CNN with ResNet-50 FPN as a baseline
cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_101_FPN_3x.yaml"))
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-Detection/faster_rcnn_R_101_FPN_3x.yaml")

# --- Dataset Configuration ---
# Use FinTabNet.c names if switching dataset source
cfg.DATASETS.TRAIN = ("fintabnet_gte_train",)
cfg.DATASETS.TEST = ("fintabnet_gte_val",) # Use validation set for evaluation during training

# --- Dataloader Config ---
cfg.DATALOADER.NUM_WORKERS = 4 # Adjust based on system cores/memory
cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS = True # Keep default

# --- Model Configuration ---
cfg.MODEL.META_ARCHITECTURE = "GTE_MetaArch" # Register custom MetaArch
cfg.MODEL.ROI_HEADS.NUM_CLASSES = NUM_CLASSES # Set to 2 (table, cell)
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 64 # Adjust based on GPU memory (default 512)

# Anchor generator settings (adjust if needed, especially for small cells)
# Consider adding smaller anchors or different aspect ratios if cell detection is poor
cfg.MODEL.ANCHOR_GENERATOR.SIZES = [[8], [16], [32], [64], [128]] # Added smaller size
cfg.MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS = [[0.25, 0.5, 1.0, 2.0, 4.0]] # Added wider ratio

# --- Custom GTE Configuration ---
cfg.MODEL.GTE = CfgNode()
#!! CRITICAL HYPERPARAMETER: Tune this value!!
cfg.MODEL.GTE.CONTAINMENT_LOSS_WEIGHT = 1.0 # Start with 1.0, experiment with 0.1, 0.5, 2.0, 5.0 etc.
# Optional: Threshold for IoU to consider a cell 'contained' in loss calculation
cfg.MODEL.GTE.CONTAINMENT_IOU_THRESH = 0.5

# --- Solver Configuration ---
cfg.SOLVER.IMS_PER_BATCH = 2 # Adjust based on GPU memory (try 2, 4, 8, 16...)
# Adjust LR based on batch size (Linear Scaling Rule: new_lr = base_lr * new_batch_size / base_batch_size)
# Base LR for IMS_PER_BATCH=16 is often 0.02 for SGD
cfg.SOLVER.BASE_LR = 0.005 # Example for IMS_PER_BATCH=4 (0.02 * 4 / 16 = 0.005) - TUNE THIS
cfg.SOLVER.LR_SCHEDULER_NAME = "WarmupMultiStepLR"
# Increase iterations significantly for large datasets like FinTabNet
cfg.SOLVER.MAX_ITER = 90000 # Example: ~10 epochs if dataset size ~70k and batch size 4
# Adjust steps based on MAX_ITER (e.g., decay at 66% and 88% of total iterations)
cfg.SOLVER.STEPS = (120000, 160000)
cfg.SOLVER.GAMMA = 0.1 # LR decay factor
cfg.SOLVER.WARMUP_ITERS = 1000
cfg.SOLVER.WARMUP_FACTOR = 1.0 / cfg.SOLVER.WARMUP_ITERS # Standard linear warmup
cfg.SOLVER.WEIGHT_DECAY = 0.0001
cfg.SOLVER.CHECKPOINT_PERIOD = 5000 # Save checkpoints less frequently for long runs
cfg.TEST.EVAL_PERIOD = 200 # Evaluate on validation set periodically

cfg.OUTPUT_DIR = OUTPUT_DIR

# Freeze config to prevent accidental changes
cfg.freeze()

In [None]:
# %% [markdown]
# ## 3. Data Loading and Preparation (with DPI Scaling)

# %%
# Keep the parse_fintabnet_jsonl function as is (from previous code)
def parse_fintabnet_jsonl(jsonl_path, data_folder):
    """Parses FinTabNet JSONL file. Expects PNGs in data_folder."""
    images_data = {}
    if not os.path.exists(jsonl_path):
        logging.error(f"JSONL file not found: {jsonl_path}")
        return {}
    logging.info(f"Parsing {os.path.basename(jsonl_path)}...")
    line_count, parsed_count, skipped_missing_file = 0, 0, 0
    processed_files = set() # Track processed PDF filenames to count records correctly

    with open(jsonl_path, 'r') as fp:
        for line in fp:
            line_count += 1
            try:
                sample = json.loads(line)
                filename_pdf = sample['filename']
                png_filename = filename_pdf.replace(".pdf", ".png")
                png_filepath = os.path.join(data_folder, png_filename)

                if not os.path.exists(png_filepath):
                    if skipped_missing_file % 1000 == 0:
                         logging.warning(f"PNG file not found (logged once per 1000): {png_filepath}")
                    skipped_missing_file += 1
                    continue

                # Use filename_pdf as the primary key
                if filename_pdf not in images_data:
                    try:
                        with Image.open(png_filepath) as img:
                            width, height = img.size
                    except Exception as img_e:
                        logging.error(f"Failed to open/read dimensions for {png_filepath}: {img_e}")
                        skipped_missing_file += 1
                        continue
                    images_data[filename_pdf] = {'filepath_png': png_filepath, 'width': width, 'height': height, 'annotations':[]}

                annotations = images_data[filename_pdf]['annotations']
                added_annotation = False

                # Check if it's a table annotation entry
                # Heuristic: if 'html' key exists OR it's a simple entry (filename, bbox, split)
                is_table_entry = 'html' in sample or len(sample.keys()) <= 3

                if "bbox" in sample and is_table_entry:
                    annotations.append({"category_id": TABLE_CAT_ID, "bbox": sample["bbox"], "is_table": True}) # Use TABLE_CAT_ID
                    added_annotation = True

                # Check for cell annotations within the html structure
                if "html" in sample and "cells" in sample["html"]:
                    for token in sample["html"]["cells"]:
                        if "bbox" in token:
                            annotations.append({"category_id": CELL_CAT_ID, "bbox": token["bbox"], "is_table": False}) # Use CELL_CAT_ID
                            added_annotation = True

                if added_annotation:
                    processed_files.add(filename_pdf)

            except json.JSONDecodeError as e:
                logging.warning(f"Warning: Invalid JSON in line {line_count} in {jsonl_path}: {e}")
            except Exception as e:
                logging.warning(f"Warning: Error processing line {line_count} in {jsonl_path}: {e}")

    final_images_data = {k: v for k, v in images_data.items() if v.get('annotations')}
    parsed_count = len(final_images_data)

    logging.info(f"Finished parsing {os.path.basename(jsonl_path)}. Lines: {line_count}, Parsed Records with Annotations: {parsed_count}, Skipped (Missing/Unreadable PNG): {skipped_missing_file}")
    return final_images_data

# %%
# Modified conversion function to produce a single dataset dict list
# with correctly scaled coordinates for both tables and cells.
def convert_to_detectron_gte(images_dict, pdf_base_folder):
    """
    Converts parsed data to Detectron2 format for the GTE joint model,
    correctly scaling PDF coordinates to rendered image pixel coordinates.

    Args:
        images_dict (dict): Dictionary mapping PDF filenames to parsed data.
        pdf_base_folder (str): Base directory containing original PDF files.

    Returns:
        list[dict]: A list of dataset dictionaries for Detectron2.
    """
    dataset_dicts = []
    file_count = len(images_dict)
    processed_count, skipped_img_error, skipped_pdf_error = 0, 0, 0
    logging.info("Converting data to Detectron2 format (Joint GTE) with DPI scaling...")

    sorted_items = sorted(images_dict.items())

    for idx, (filename_pdf, data) in enumerate(sorted_items):
        processed_count += 1
        png_path = data.get('filepath_png')
        rendered_width = data.get('width')
        rendered_height = data.get('height')
        pdf_path = os.path.join(pdf_base_folder, filename_pdf)

        if not png_path or not rendered_width or not rendered_height:
            logging.warning(f"Missing essential PNG data for {filename_pdf}. Skipping.")
            skipped_img_error += 1
            continue

        if not os.path.exists(pdf_path):
            logging.warning(f"Original PDF file not found at {pdf_path}. Skipping.")
            skipped_pdf_error += 1
            continue

        # --- Get PDF Page Dimensions and Calculate Scaling ---
        try:
            with fitz.open(pdf_path) as doc:
                if not doc or len(doc) == 0: raise ValueError("PDF empty.")
                page = doc.load_page(0) # Assuming first page
                pdf_rect = page.rect
                pdf_page_width = pdf_rect.width
                pdf_page_height = pdf_rect.height
                if pdf_page_width <= 0 or pdf_page_height <= 0: raise ValueError("Invalid PDF dims.")
                scale_x = rendered_width / pdf_page_width
                scale_y = rendered_height / pdf_page_height
        except Exception as e:
            logging.error(f"Error processing PDF {pdf_path}: {e}. Skipping.")
            skipped_pdf_error += 1
            continue
        # --- ---
 
        record = {
            "file_name": png_path,
            "image_id": idx,
            "height": rendered_height,
            "width": rendered_width,
            "annotations":[]
        }

        # --- Process ALL Annotations (Scale and Flip) ---
        for ann in data.get("annotations",): # Added default empty list
            if "bbox" not in ann or len(ann["bbox"])!= 4: continue
            category_id = ann.get("category_id") # Should be TABLE_CAT_ID or CELL_CAT_ID
            VALID_CATEGORY_IDS = [0,1,"0","1"]
            # *** CORRECTED CHECK ***
            if category_id not in VALID_CATEGORY_IDS:
                 logging.warning(f"Skipping annotation with unexpected category_id {category_id} in {filename_pdf}")
                 continue

            pdf_x1, pdf_y1, pdf_x2, pdf_y2 = ann["bbox"]
            if not all(isinstance(coord, (int, float)) for coord in [pdf_x1, pdf_y1, pdf_x2, pdf_y2]): continue

            # Apply scaling and vertical flip
            pixel_x1 = pdf_x1 * scale_x
            pixel_x2 = pdf_x2 * scale_x
            scaled_pdf_y1 = pdf_y1 * scale_y
            scaled_pdf_y2 = pdf_y2 * scale_y
            pixel_y1 = rendered_height - scaled_pdf_y2
            pixel_y2 = rendered_height - scaled_pdf_y1

            # Clip coordinates and ensure valid box
            pixel_x1_c = max(0.0, min(pixel_x1, pixel_x2))
            pixel_x2_c = min(float(rendered_width), max(pixel_x1, pixel_x2))
            pixel_y1_c = max(0.0, min(pixel_y1, pixel_y2))
            pixel_y2_c = min(float(rendered_height), max(pixel_y1, pixel_y2))

            if pixel_x2_c <= pixel_x1_c or pixel_y2_c <= pixel_y1_c: continue # Skip invalid boxes

            record["annotations"].append({
                "bbox": [pixel_x1_c, pixel_y1_c, pixel_x2_c, pixel_y2_c],
                "bbox_mode": BoxMode.XYXY_ABS,
                "category_id": category_id, # Use the original category ID (0 or 1)
            })

        # Only add record if it has valid annotations
        if record["annotations"]:
            dataset_dicts.append(record)

        print(f"Converting: {processed_count}/{file_count} | Records: {len(dataset_dicts)} | Skipped (Img Err): {skipped_img_error} | Skipped (PDF Err): {skipped_pdf_error}   ", end="\r")

    print() # Newline after progress indicator
    logging.info(f"Finished Detectron2 conversion (Joint GTE). Created {len(dataset_dicts)} records. Skipped images: {skipped_img_error}, Skipped PDFs: {skipped_pdf_error}")
    return dataset_dicts

# %%
# Modified caching function
def load_or_convert_gte_dataset(dataset_name, jsonl_path, image_folder, pdf_folder, cache_dir):
    """Loads GTE joint dataset from cache or parses/converts/caches."""
    # Use a distinct cache filename for the GTE joint data
    cache_path = os.path.join(cache_dir, f"converted_{dataset_name}_gte_joint.pkl")

    if os.path.exists(cache_path):
        logging.info(f"Loading cached GTE joint dataset '{dataset_name}' from {cache_dir}...")
        try:
            with open(cache_path, 'rb') as f:
                data = pickle.load(f)
            if not isinstance(data, list): raise TypeError("Cached data not list.")
            logging.info(f"Loaded {len(data)} records.")
            if data and not isinstance(data, dict): raise TypeError("Cache item not dict.")
            return data
        except Exception as e:
            logging.error(f"Error loading cache from {cache_path}: {e}. Re-processing.")

    logging.info(f"Parsing and converting GTE joint dataset '{dataset_name}' from {jsonl_path}...")
    # Step 1: Parse the JSONL to get raw data including PNG paths and PDF coords
    raw_data = parse_fintabnet_jsonl(jsonl_path, image_folder) # image_folder has PNGs
    if not raw_data:
        logging.error(f"Error: No data parsed from {jsonl_path}. Cannot proceed.")
        return

    # Step 2: Convert using the function that handles PDF scaling
    converted_data = convert_to_detectron_gte(raw_data, pdf_folder) # pdf_folder has PDFs

    # Step 3: Cache the result
    if converted_data:
        try:
            logging.info(f"Saving converted GTE joint dataset '{dataset_name}' to cache: {cache_path}")
            os.makedirs(os.path.dirname(cache_path), exist_ok=True)
            with open(cache_path, 'wb') as f:
                pickle.dump(converted_data, f)
        except Exception as e:
            logging.error(f"Error saving cache to {cache_path}: {e}")
    else:
        logging.warning(f"No data converted for {dataset_name}, skipping cache saving.")

    return converted_data

# %%
# --- Load and Register Datasets ---
#!! IMPORTANT: Clear cache directory (CACHE_DIR) if conversion logic changed!!
logging.info(f"Using PDF folder: {PDF_FOLDER}")
logging.info(f"Using Cache folder: {CACHE_DIR}")

# Load data using the new caching function
train_dataset = load_or_convert_gte_dataset("train", TRAIN_JSONL, PDF_FOLDER, PDF_FOLDER, CACHE_DIR)
val_dataset = load_or_convert_gte_dataset("val", VAL_JSONL, PDF_FOLDER, PDF_FOLDER, CACHE_DIR)
test_dataset = load_or_convert_gte_dataset("test", TEST_JSONL, PDF_FOLDER, PDF_FOLDER, CACHE_DIR)

# Register the combined datasets
def register_fintabnet_gte_datasets(train_data, val_data, test_data):
    datasets_to_register = {
        "fintabnet_gte_train": train_data,
        "fintabnet_gte_val": val_data,
        "fintabnet_gte_test": test_data,
    }
    print("\nRegistering GTE joint datasets...")
    for name, data in datasets_to_register.items():
        if name in DatasetCatalog.list():
            logging.warning(f"Dataset '{name}' already registered. Removing and re-registering.")
            DatasetCatalog.remove(name)
            if name in MetadataCatalog.list():
                MetadataCatalog.remove(name)

        if data and isinstance(data, list) and len(data) > 0:
            logging.info(f"Registering {name} with {len(data)} records.")
            # Lambda function ensures data is loaded only when accessed
            DatasetCatalog.register(name, lambda d=data: d)
            # Set metadata (classes) for this dataset
            MetadataCatalog.get(name).set(thing_classes=categories_joint)
            MetadataCatalog.get(name).set(thing_colors=colors) # Optional: for visualization consistency
        else:
            logging.warning(f"Skipping registration for {name} (dataset empty or invalid).")

register_fintabnet_gte_datasets(train_dataset, val_dataset, test_dataset)

# Verify registration (optional)
print("Registered datasets:", DatasetCatalog.list())
if "fintabnet_gte_train" in MetadataCatalog.list():
    metadata = MetadataCatalog.get("fintabnet_gte_train")
    print("Train metadata:", metadata)
else:
    metadata = None # Needed for visualization later



2025-04-29 00:17:56,136 - INFO - Using PDF folder: fintabnet/pdf
2025-04-29 00:17:56,137 - INFO - Using Cache folder: fintabnet/converted_cache_gte_v1
2025-04-29 00:17:56,139 - INFO - Loading cached GTE joint dataset 'train' from fintabnet/converted_cache_gte_v1...
2025-04-29 00:18:03,222 - INFO - Loaded 47985 records.
2025-04-29 00:18:03,222 - ERROR - Error loading cache from fintabnet/converted_cache_gte_v1/converted_train_gte_joint.pkl: Cache item not dict.. Re-processing.
2025-04-29 00:18:03,223 - INFO - Parsing and converting GTE joint dataset 'train' from fintabnet/FinTabNet_1.0.0_table_train.jsonl...
2025-04-29 00:18:03,223 - INFO - Parsing FinTabNet_1.0.0_table_train.jsonl...
2025-04-29 00:18:18,943 - INFO - Finished parsing FinTabNet_1.0.0_table_train.jsonl. Lines: 61801, Parsed Records with Annotations: 48001, Skipped (Missing/Unreadable PNG): 0
2025-04-29 00:18:18,947 - INFO - Converting data to Detectron2 format (Joint GTE) with DPI scaling...


Converting: 47993/48001 | Records: 47977 | Skipped (Img Err): 0 | Skipped (PDF Err): 0   

2025-04-29 00:18:54,247 - INFO - Finished Detectron2 conversion (Joint GTE). Created 47985 records. Skipped images: 0, Skipped PDFs: 0
2025-04-29 00:18:54,330 - INFO - Saving converted GTE joint dataset 'train' to cache: fintabnet/converted_cache_gte_v1/converted_train_gte_joint.pkl


Converting: 48001/48001 | Records: 47985 | Skipped (Img Err): 0 | Skipped (PDF Err): 0   


2025-04-29 00:19:01,081 - INFO - Loading cached GTE joint dataset 'val' from fintabnet/converted_cache_gte_v1...
2025-04-29 00:19:01,307 - INFO - Loaded 5943 records.
2025-04-29 00:19:01,308 - ERROR - Error loading cache from fintabnet/converted_cache_gte_v1/converted_val_gte_joint.pkl: Cache item not dict.. Re-processing.
2025-04-29 00:19:01,309 - INFO - Parsing and converting GTE joint dataset 'val' from fintabnet/FinTabNet_1.0.0_table_val.jsonl...
2025-04-29 00:19:01,310 - INFO - Parsing FinTabNet_1.0.0_table_val.jsonl...
2025-04-29 00:19:02,564 - INFO - Finished parsing FinTabNet_1.0.0_table_val.jsonl. Lines: 7191, Parsed Records with Annotations: 5943, Skipped (Missing/Unreadable PNG): 0
2025-04-29 00:19:02,565 - INFO - Converting data to Detectron2 format (Joint GTE) with DPI scaling...


Converting: 5755/5943 | Records: 5755 | Skipped (Img Err): 0 | Skipped (PDF Err): 0   

2025-04-29 00:19:05,086 - INFO - Finished Detectron2 conversion (Joint GTE). Created 5943 records. Skipped images: 0, Skipped PDFs: 0
2025-04-29 00:19:05,087 - INFO - Saving converted GTE joint dataset 'val' to cache: fintabnet/converted_cache_gte_v1/converted_val_gte_joint.pkl


Converting: 5943/5943 | Records: 5943 | Skipped (Img Err): 0 | Skipped (PDF Err): 0   


2025-04-29 00:19:05,354 - INFO - Loading cached GTE joint dataset 'test' from fintabnet/converted_cache_gte_v1...
2025-04-29 00:19:05,513 - INFO - Loaded 5903 records.
2025-04-29 00:19:05,514 - ERROR - Error loading cache from fintabnet/converted_cache_gte_v1/converted_test_gte_joint.pkl: Cache item not dict.. Re-processing.
2025-04-29 00:19:05,514 - INFO - Parsing and converting GTE joint dataset 'test' from fintabnet/FinTabNet_1.0.0_table_test.jsonl...
2025-04-29 00:19:05,515 - INFO - Parsing FinTabNet_1.0.0_table_test.jsonl...
2025-04-29 00:19:07,912 - INFO - Finished parsing FinTabNet_1.0.0_table_test.jsonl. Lines: 7085, Parsed Records with Annotations: 5903, Skipped (Missing/Unreadable PNG): 0
2025-04-29 00:19:07,913 - INFO - Converting data to Detectron2 format (Joint GTE) with DPI scaling...


Converting: 5633/5903 | Records: 5633 | Skipped (Img Err): 0 | Skipped (PDF Err): 0   

2025-04-29 00:19:11,146 - INFO - Finished Detectron2 conversion (Joint GTE). Created 5903 records. Skipped images: 0, Skipped PDFs: 0
2025-04-29 00:19:11,149 - INFO - Saving converted GTE joint dataset 'test' to cache: fintabnet/converted_cache_gte_v1/converted_test_gte_joint.pkl


Converting: 5903/5903 | Records: 5903 | Skipped (Img Err): 0 | Skipped (PDF Err): 0   


2025-04-29 00:19:11,484 - INFO - Registering fintabnet_gte_train with 47985 records.
2025-04-29 00:19:11,485 - INFO - Registering fintabnet_gte_val with 5943 records.
2025-04-29 00:19:11,485 - INFO - Registering fintabnet_gte_test with 5903 records.



Registering GTE joint datasets...
Registered datasets: ['coco_2014_train', 'coco_2014_val', 'coco_2014_minival', 'coco_2014_valminusminival', 'coco_2017_train', 'coco_2017_val', 'coco_2017_test', 'coco_2017_test-dev', 'coco_2017_val_100', 'keypoints_coco_2014_train', 'keypoints_coco_2014_val', 'keypoints_coco_2014_minival', 'keypoints_coco_2014_valminusminival', 'keypoints_coco_2017_train', 'keypoints_coco_2017_val', 'keypoints_coco_2017_val_100', 'coco_2017_train_panoptic_separated', 'coco_2017_train_panoptic_stuffonly', 'coco_2017_train_panoptic', 'coco_2017_val_panoptic_separated', 'coco_2017_val_panoptic_stuffonly', 'coco_2017_val_panoptic', 'coco_2017_val_100_panoptic_separated', 'coco_2017_val_100_panoptic_stuffonly', 'coco_2017_val_100_panoptic', 'lvis_v1_train', 'lvis_v1_val', 'lvis_v1_test_dev', 'lvis_v1_test_challenge', 'lvis_v0.5_train', 'lvis_v0.5_val', 'lvis_v0.5_val_rand_100', 'lvis_v0.5_test', 'lvis_v0.5_train_cocofied', 'lvis_v0.5_val_cocofied', 'cityscapes_fine_instan

In [None]:
from detectron2.modeling.meta_arch import META_ARCH_REGISTRY

if "GTE_MetaArch" in META_ARCH_REGISTRY._obj_map:
    del META_ARCH_REGISTRY._obj_map["GTE_MetaArch"]
    print("Successfully unregistered GTE_MetaArch!")
else:
    print("GTE_MetaArch not found in registry.")

Successfully unregistered GTE_MetaArch!


In [None]:

# --- Custom Meta Architecture ---
@META_ARCH_REGISTRY.register()
class GTE_MetaArch(nn.Module):
    """
    Meta architecture inspired by GTE, using standard Detectron2 components
    but adding a custom containment loss during training.
    Assumes a single ROI head predicts both table and cell classes.
    """
    @configurable
    def __init__(
        self,
        *,
        backbone,
        proposal_generator,
        roi_heads,
        pixel_mean,
        pixel_std,
        input_format=None,
        vis_period=0,
        containment_loss_weight=1.0,
        containment_iou_thresh=0.5
    ):
        super().__init__()

        self.backbone = backbone
        self.proposal_generator = proposal_generator
        self.roi_heads = roi_heads # Single head predicting both classes

        self.input_format = input_format
        self.vis_period = vis_period
        if vis_period > 0:
            assert input_format is not None, "input_format is required for visualization!"

        self.register_buffer("pixel_mean", torch.tensor(pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.tensor(pixel_std).view(-1, 1, 1), False)

        # GTE specific parameters
        self.containment_loss_weight = containment_loss_weight
        self.containment_iou_thresh = containment_iou_thresh
        logging.info(f" Containment Loss Weight: {self.containment_loss_weight}")
        logging.info(f" Containment IoU Threshold: {self.containment_iou_thresh}")


    @classmethod
    def from_config(cls, cfg):
        backbone = build_backbone(cfg)
        return {
            "backbone": backbone,
            "proposal_generator": build_proposal_generator(cfg, backbone.output_shape()),
            "roi_heads": build_roi_heads(cfg, backbone.output_shape()), # Build the single ROI head
            "input_format": cfg.INPUT.FORMAT,
            "vis_period": cfg.VIS_PERIOD,
            "pixel_mean": cfg.MODEL.PIXEL_MEAN,
            "pixel_std": cfg.MODEL.PIXEL_STD,
            "containment_loss_weight": cfg.MODEL.GTE.CONTAINMENT_LOSS_WEIGHT,
            "containment_iou_thresh": cfg.MODEL.GTE.CONTAINMENT_IOU_THRESH,
        }

    @property
    def device(self):
        return self.pixel_mean.device

    def preprocess_image(self, batched_inputs):
        """Normalize, pad and batch the input images."""
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.backbone.size_divisibility)
        return images
    
    def forward(self, batched_inputs):
        """
        Forward pass for training and inference.
        """
        if not self.training:
            return self.inference(batched_inputs)

        # --- Check for Ground Truth ---
        for x in batched_inputs:
            if "instances" not in x:
                 raise ValueError("Ground truth instances are required for training! Missing 'instances' key.")

        images = self.preprocess_image(batched_inputs)
        gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
        features = self.backbone(images.tensor)

        # --- RPN ---
        if self.proposal_generator is not None:
            if not gt_instances:
                 raise ValueError("gt_instances are required for proposal_generator during training.")
            proposals, proposal_losses = self.proposal_generator(images, features, gt_instances)
        else:
            assert "proposals" in batched_inputs, "Proposals required if proposal_generator is None"
            proposals = [x["proposals"].to(self.device) for x in batched_inputs]
            proposal_losses = {}

        # --- ROI Heads (Loss Calculation) ---
        # Standard call to calculate losses
        # The first returned value (often None or {}) is ignored here for loss calculation.
        _, detector_losses = self.roi_heads(
            images=images,
            features=features,
            proposals=proposals,
            targets=gt_instances
        )

        # --- Combine Standard Losses ---
        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)

        # --- Custom Containment Loss ---
        # Requires running ROI heads in inference mode to get predictions
        if self.containment_loss_weight > 0:
            # Store current training mode
            original_training_mode = self.roi_heads.training
            # Set to eval mode for inference pass
            self.roi_heads.eval()
            with torch.no_grad(): # No gradients needed for this inference pass
                # Run inference pass using the same features and proposals
                # targets=None indicates inference mode for ROI heads
                pred_instances_list, _ = self.roi_heads(
                    images=images,
                    features=features,
                    proposals=proposals,
                    targets=None # Ensures inference path is taken
                )
            # Restore original training mode
            self.roi_heads.train(original_training_mode)

            # Now calculate containment loss using the predictions
            if pred_instances_list is not None:
                 # Check if the output is a list of Instances
                 if isinstance(pred_instances_list, list) and all(isinstance(inst, Instances) for inst in pred_instances_list):
                     containment_loss = self.compute_containment_loss(pred_instances_list)
                     losses["loss_containment"] = containment_loss * self.containment_loss_weight
                 else:
                     logging.warning("ROI heads inference output format unexpected for containment loss. Skipping loss.")
                     losses["loss_containment"] = torch.tensor(0.0, device=self.device)
            else:
                 losses["loss_containment"] = torch.tensor(0.0, device=self.device) # No predictions
        else:
            losses["loss_containment"] = torch.tensor(0.0, device=self.device) # Loss weight is 0

        # --- Visualization (Training) ---
        # (Optional: visualize proposals or predictions if needed)
        # if self.vis_period > 0:
        #     storage = get_event_storage()
        #     if storage.iter % self.vis_period == 0:
        #         self.visualize_training(batched_inputs, proposals) # Assuming this method exists

        return losses

    def inference(self, batched_inputs, detected_instances=None, do_postprocess=True):
        """
        Run inference on the component models.
        """
        assert not self.training

        images = self.preprocess_image(batched_inputs) # Get the images object
        features = self.backbone(images.tensor)

        if detected_instances is None:
            if self.proposal_generator is not None:
                proposals, _ = self.proposal_generator(images, features, None)
            else:
                assert "proposals" in batched_inputs
                proposals = [x["proposals"].to(self.device) for x in batched_inputs]

            # *** CORRECTED CALL: Pass 'images' as the first argument ***
            results, _ = self.roi_heads(images, features, proposals, None)
            # *** END CORRECTION ***
        else:
            # NOTE: If you use the forward_with_given_boxes path,
            # ensure it also receives the correct arguments (likely including 'images').
            # For simplicity, let's assume the standard path is used for now.
            detected_instances = [x.to(self.device) for x in detected_instances]
            # Using standard path even if instances are provided (might need adjustment if specific logic is needed)
            if self.proposal_generator is not None:
                 proposals, _ = self.proposal_generator(images, features, None)
            else:
                 # Handle case where proposals are needed but not generated/provided
                 raise ValueError("Proposal generator is None, but proposals not found in input for inference with detected_instances.")

            results, _ = self.roi_heads(images, features, proposals, targets=None) # Pass images here too


        if do_postprocess:
            assert not torch.jit.is_scripting(), "Scripting is not supported for postprocess."
            # Assuming GTE_MetaArch has _postprocess defined correctly
            return GTE_MetaArch._postprocess(results, batched_inputs, images.image_sizes)
        else:
            return results

    @staticmethod
    def _postprocess(instances, batched_inputs, image_sizes):
        """Rescale predictions to original image size."""
        processed_results = []
        for results_per_image, input_per_image, image_size in zip(
            instances, batched_inputs, image_sizes
        ):
            height = input_per_image.get("height", image_size)
            width = input_per_image.get("width", image_size[1])
            r = detector_postprocess(results_per_image, height, width)
            processed_results.append({"instances": r})
        return processed_results

    def compute_containment_loss(self, pred_instances_list):
        """
        Computes the GTE-inspired containment loss.

        Args:
            pred_instances_list (list[Instances]): Predicted Instances per image.
                Instances contain 'pred_boxes' and 'pred_classes'.

        Returns:
            torch.Tensor: Scalar containment loss for the batch.
        """
        total_loss = 0.0
        num_images_processed = 0

        for instances in pred_instances_list:
            if len(instances) == 0:
                continue

            pred_boxes = instances.pred_boxes # Boxes object
            pred_classes = instances.pred_classes # Tensor of class IDs

            # Separate predicted tables and cells
            table_indices = (pred_classes == TABLE_CAT_ID).nonzero().squeeze(1)
            cell_indices = (pred_classes == CELL_CAT_ID).nonzero().squeeze(1)

            if len(table_indices) == 0 or len(cell_indices) == 0:
                continue # Need both tables and cells predicted in an image for this loss

            pred_table_boxes = pred_boxes[table_indices] # Boxes object for tables
            pred_cell_boxes = pred_boxes[cell_indices]   # Boxes object for cells

            # Calculate pairwise IoU between predicted cells and predicted tables
            # Shape: (num_pred_cells, num_pred_tables)
            ious = pairwise_iou(pred_cell_boxes, pred_table_boxes)

            # --- Loss Component 1: Penalize cells "outside" tables ---
            # For each cell, find the max IoU with any table
            max_iou_per_cell, _ = torch.max(ious, dim=1)
            # Cells with low max IoU are considered "outside"
            # Penalty could be 1 - max_iou, or based on a threshold
            # Simple penalty: average (1 - max_iou) for cells below threshold
            outside_penalty = 1.0 - max_iou_per_cell
            # Apply penalty only if max IoU is below threshold (optional, can make loss less noisy)
            # outside_penalty = outside_penalty[max_iou_per_cell < self.containment_iou_thresh]
            loss_outside = outside_penalty.mean() if len(outside_penalty) > 0 else torch.tensor(0.0, device=self.device)


            # --- Loss Component 2: Penalize tables "not containing" cells ---
            # For each table, count how many cells have IoU > threshold with it
            contained_mask = ious > self.containment_iou_thresh
            cells_contained_per_table = torch.sum(contained_mask, dim=0).float() # Shape: (num_pred_tables,)

            # Simple penalty: encourage tables to contain at least one cell
            # Penalize tables with zero contained cells (using 1 / (count + eps) encourages higher counts)
            # Avoid division by zero with a small epsilon
            epsilon = 1e-6
            table_penalty = 1.0 / (cells_contained_per_table + epsilon)
            loss_table_containment = table_penalty.mean() if len(table_penalty) > 0 else torch.tensor(0.0, device=self.device)

            # Combine penalties (simple sum for now, could be weighted)
            image_loss = loss_outside + loss_table_containment
            total_loss += image_loss
            num_images_processed += 1

        # Average loss over images that contributed
        return total_loss / num_images_processed if num_images_processed > 0 else torch.tensor(0.0, device=self.device)



In [None]:
# %% [markdown]
# ## 4. Custom Model (GTE_MetaArch) and Trainer

# --- Custom Trainer ---
class GTETrainer(DefaultTrainer):
    """
    Custom trainer that uses the GTE_MetaArch and potentially custom data loaders/evaluators.
    """
    @classmethod
    def build_model(cls, cfg):
        """Builds the GTE_MetaArch."""
        model = META_ARCH_REGISTRY.get(cfg.MODEL.META_ARCHITECTURE)(cfg)
        model.to(cfg.MODEL.DEVICE)
        logging.info(f"Model:\n{model}")
        return model

    @classmethod
    def build_train_loader(cls, cfg):
        """Builds the training data loader."""
        # Use standard DatasetMapper for now, assuming combined GT is handled
        mapper = DatasetMapper(cfg, is_train=True)
        return build_detection_train_loader(cfg, mapper=mapper)

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        """Builds the evaluator for the validation set."""
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        # Use standard COCOEvaluator
        # Ensure the inference output format matches what COCOEvaluator expects
        return COCOEvaluator(dataset_name, output_dir=output_folder)

# %% [markdown]
# ## 5. Training

# %%
# --- Start Training ---
do_train = True
if do_train:
    print(f"\n--- Starting GTE Joint Model Training ---")
    print(f"Output directory: {cfg.OUTPUT_DIR}")
    print(f"Configured MAX_ITER: {cfg.SOLVER.MAX_ITER}")
    print(f"Containment Loss Weight: {cfg.MODEL.GTE.CONTAINMENT_LOSS_WEIGHT}")

    # Check if the configured training dataset exists and is registered
    if not cfg.DATASETS.TRAIN: # Check if the tuple is empty first
        logging.error("No training dataset specified in cfg.DATASETS.TRAIN. Aborting training.")
        do_train = False # Prevent proceeding
    else:
        train_dataset_name_str = cfg.DATASETS.TRAIN[0] # Get the first element (the string name)

        if train_dataset_name_str not in DatasetCatalog.list():
            logging.error(f"Training dataset '{train_dataset_name_str}' is not registered in DatasetCatalog. Aborting training.")
            do_train = False # Prevent proceeding
        # Check if the actual data getter function returns something (is not empty)
        # This ensures the data loading/conversion didn't fail silently earlier
        elif not DatasetCatalog.get(train_dataset_name_str):
            logging.error(f"Training dataset '{train_dataset_name_str}' is registered but appears empty (data loading/conversion might have failed). Aborting training.")
            do_train = False # Prevent proceeding

    # Proceed only if the checks passed
    if do_train:
        trainer = GTETrainer(cfg)
        # Load last checkpoint if available, otherwise start from MODEL.WEIGHTS
        trainer.resume_or_load(resume=True) # Set resume=True to continue training if checkpoint exists
        print(f"GTE Model device: {next(trainer.model.parameters()).device}")
        try:
            trainer.train()
            print("--- GTE Joint Model Training Finished ---")
        except Exception as e:
            print(f"\nAn error occurred during GTE training: {e}")
            import traceback
            traceback.print_exc()
else:
    print("\nSkipping GTE Joint Model training (do_train=False).")



--- Starting GTE Joint Model Training ---
Output directory: ./output_fintabnet_gte_model_v1
Configured MAX_ITER: 90000
Containment Loss Weight: 1.0


2025-04-29 00:40:31,123 - INFO -  Containment Loss Weight: 1.0
2025-04-29 00:40:31,123 - INFO -  Containment IoU Threshold: 0.5
2025-04-29 00:40:31,173 - INFO - Model:
GTE_MetaArch(
  (backbone): FPN(
    (fpn_lateral2): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral3): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral4): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral5): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (top_block): LastLevelMaxPool()
    (bottom_up): ResNet(
      (stem): BasicStem(
        (conv1): Conv2d(
          3, 64, kernel_size=(7, 7), stride=(2, 2)

GTE Model device: cuda:0


2025-04-29 00:41:00,677 - INFO -  eta: 19:39:49  iter: 19  total_loss: 5.163e+05  loss_cls: 0.8325  loss_box_reg: 0.01731  loss_rpn_cls: 0.6661  loss_rpn_loc: 0.4972  loss_containment: 5.163e+05    time: 0.9504  last_time: 0.7525  data_time: 0.0647  last_data_time: 0.0055   lr: 9.9905e-05  max_mem: 6245M
2025-04-29 00:41:20,682 - INFO -  eta: 19:41:39  iter: 39  total_loss: 1.701  loss_cls: 0.4184  loss_box_reg: 0.1037  loss_rpn_cls: 0.5153  loss_rpn_loc: 0.3322  loss_containment: 0    time: 0.9762  last_time: 0.7782  data_time: 0.0042  last_data_time: 0.0040   lr: 0.0001998  max_mem: 6245M
2025-04-29 00:41:39,981 - INFO -  eta: 20:03:15  iter: 59  total_loss: 1.371  loss_cls: 0.2928  loss_box_reg: 0.1381  loss_rpn_cls: 0.4186  loss_rpn_loc: 0.466  loss_containment: 0    time: 0.9715  last_time: 0.7907  data_time: 0.0043  last_data_time: 0.0041   lr: 0.0002997  max_mem: 6245M
2025-04-29 00:41:59,655 - INFO -  eta: 20:11:29  iter: 79  total_loss: 1.153  loss_cls: 0.2746  loss_box_reg: 0

KeyboardInterrupt: 

In [None]:


# %% [markdown]
# ## 6. Inference and Visualization

# %%
# --- Custom Predictor (Optional, DefaultPredictor might work) ---
class GTEPredictor:
    """Simple predictor wrapper for the trained GTE model."""
    def __init__(self, cfg):
        self.cfg = cfg.clone() # Clone cfg to avoid modifying original
        self.cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
        # Set threshold for inference
        self.cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 # Adjust as needed
        self.cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = GTETrainer.build_model(self.cfg) # Build the custom model
        self.model.eval()
        checkpointer = DetectionCheckpointer(self.model)
        checkpointer.load(self.cfg.MODEL.WEIGHTS)

        # Get metadata for visualization
        self.metadata = MetadataCatalog.get(cfg.DATASETS.TEST if cfg.DATASETS.TEST else "__unused")
        if not self.metadata.thing_classes:
             logging.warning("Metadata missing thing_classes, setting default for GTE.")
             self.metadata.thing_classes = categories_joint
             self.metadata.thing_colors = colors


    def __call__(self, original_image_bgr):
        """
        Args:
            original_image_bgr (np.ndarray): An image in BGR format.

        Returns:
            predictions (dict): The model's predictions in Detectron2 format.
        """
        with torch.no_grad():
            height, width = original_image_bgr.shape[:2]
            image = torch.as_tensor(original_image_bgr.astype("float32").transpose(2, 0, 1))
            inputs = {"image": image, "height": height, "width": width}
            predictions = self.model([inputs]) # Run model's inference method
            return predictions

# %%
# --- Initialize Predictor ---
predictor_ready = False
gte_predictor = None
final_model_path = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")

if os.path.exists(final_model_path):
    logging.info(f"\nInitializing GTEPredictor with weights from {final_model_path}...")
    try:
        predictor_cfg = cfg.clone() # Use the same config used for training
        gte_predictor = GTEPredictor(predictor_cfg)
        predictor_ready = True
        logging.info("GTEPredictor initialized successfully.")
    except Exception as e:
        logging.error(f"\nError initializing GTEPredictor: {e}")
        import traceback
        traceback.print_exc()
else:
    logging.warning(f"\nCannot initialize predictor: Final model weights not found at {final_model_path}")

# %%
# --- Evaluation ---
# Run evaluation using the custom trainer's method
if predictor_ready and cfg.DATASETS.TEST:
    test_dataset_name = cfg.DATASETS.TEST
    if test_dataset_name in DatasetCatalog.list():
        logging.info(f"\nRunning evaluation on GTE test set ({test_dataset_name})...")
        eval_output_dir = os.path.join(cfg.OUTPUT_DIR, "inference_test_final")
        os.makedirs(eval_output_dir, exist_ok=True)
        evaluator = GTETrainer.build_evaluator(cfg, test_dataset_name, output_folder=eval_output_dir)
        test_loader = GTETrainer.build_test_loader(cfg, test_dataset_name)
        try:
            results = inference_on_dataset(gte_predictor.model, test_loader, evaluator)
            print("\n--- GTE Model Evaluation Results ---")
            print(results)
        except Exception as e:
            print(f"\nError during GTE evaluation: {e}")
            import traceback
            traceback.print_exc()
    else:
        print(f"\nSkipping GTE evaluation: Test dataset '{test_dataset_name}' not found.")
else:
    print("\nSkipping GTE evaluation (predictor or test dataset not ready).")


# %%
# --- Visualization ---
# Keep the visualize_predictions function, but ensure it uses the correct metadata
def visualize_predictions(dataset, predictor, metadata, num_samples=5, title_prefix="Predictions"):
    """Visualizes predictions from a predictor on a dataset."""
    if not dataset: logging.warning(f"Cannot visualize: Dataset for '{title_prefix}' is empty."); return
    if not predictor: logging.warning(f"Cannot visualize: Predictor for '{title_prefix}' is None."); return
    if not metadata: logging.warning(f"Cannot visualize: Metadata for '{title_prefix}' is None."); return
    if not hasattr(metadata, 'thing_classes') or not metadata.thing_classes:
        logging.error(f"Metadata for '{title_prefix}' missing 'thing_classes'. Cannot visualize."); return

    logging.info(f"\nVisualizing {num_samples} {title_prefix}...")
    actual_num_samples = min(num_samples, len(dataset))
    if actual_num_samples <= 0: logging.info("No samples in dataset."); return

    try:
        # Ensure dataset is a list of dicts before sampling
        if isinstance(dataset, list) and all(isinstance(item, dict) for item in dataset):
            indices = random.sample(range(len(dataset)), actual_num_samples)
            samples = [dataset[i] for i in indices]
        else:
             logging.warning("Dataset is not list of dicts, attempting to take first samples.")
             samples = [item for i, item in enumerate(dataset) if i < actual_num_samples and isinstance(item, dict)]
             if not samples: raise ValueError("Could not get valid samples.")
    except Exception as e:
        logging.error(f"Cannot sample from dataset for '{title_prefix}': {e}"); return

    for sample in samples:
        img_path = sample.get("file_name")
        if not img_path or not os.path.exists(img_path):
            logging.warning(f"Image path invalid in sample: {sample.get('image_id', 'Unknown ID')}. Skipping."); continue
        try:
            img_bgr = cv2.imread(img_path)
            if img_bgr is None: logging.warning(f"Cannot read image: {img_path}. Skipping."); continue
        except Exception as read_e:
            logging.warning(f"Exception reading image {img_path}: {read_e}. Skipping."); continue

        logging.info(f"Running predictor for: {os.path.basename(img_path)}")
        try:
            outputs = predictor(img_bgr) # Use the predictor instance directly
        except Exception as e:
            logging.error(f"Error during prediction for {os.path.basename(img_path)}: {e}"); continue

        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        v = Visualizer(img_rgb, metadata, scale=1.0, instance_mode=ColorMode.IMAGE) # Use provided metadata

        out_vis = None
        if isinstance(outputs, dict) and "instances" in outputs:
            instances = outputs["instances"].to("cpu")
            if len(instances) > 0:
                out_vis = v.draw_instance_predictions(instances)
            else:
                logging.info(f"No instances detected for {os.path.basename(img_path)}.")
        else:
            logging.warning(f"Predictor output format not recognized for visualization."); continue

        if out_vis is not None:
            output_image = out_vis.get_image()
            plt.figure(figsize=(15, 15)); plt.imshow(output_image)
            plt.title(f"{title_prefix} - {os.path.basename(img_path)}"); plt.axis('off'); plt.show()
        else: # Show original image if no detections
            plt.figure(figsize=(12, 12)); plt.imshow(img_rgb)
            plt.title(f"{title_prefix} - {os.path.basename(img_path)} (No Detections)"); plt.axis('off'); plt.show()

# %%
# --- Run Visualization ---
if predictor_ready and test_dataset:
    # Get metadata associated with the registered test dataset
    test_metadata = MetadataCatalog.get(cfg.DATASETS.TEST) if cfg.DATASETS.TEST else None
    if test_metadata and test_metadata.thing_classes:
         visualize_predictions(dataset=test_dataset, predictor=gte_predictor, metadata=test_metadata, num_samples=10, title_prefix="GTE Joint Model Predictions")
    else:
         print("\nSkipping visualization: Metadata not found or incomplete for test dataset.")
else:
    print("\nSkipping visualization (predictor or test dataset not ready).")

NameError: name 'os' is not defined