In [42]:
import os
import re
import logging
import warnings
from pathlib import Path
from typing import Any, Dict, List, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image, ImageFile
from torch.utils.data import Dataset, DataLoader
from transformers import (
    VisionEncoderDecoderModel,
    ViTImageProcessor,
    AutoTokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    TrainerCallback,
    EvalPrediction,
)
import sacrebleu

# Silence every Python warning for the rest of the process; this also hides
# deprecation notices raised by the imported libraries.
warnings.filterwarnings("ignore")
# Send INFO-level (and higher) records through the root logger.
logging.basicConfig(level=logging.INFO)
# Let PIL load image files whose data is truncated instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True


class Config:
    """Constants for the initial fine-tuning run (ViT encoder, BART decoder)."""

    # Dataset layout on disk.
    ROOT_DIR = Path("./TableBank/Recognition")
    IMAGE_DIR = ROOT_DIR.joinpath("Images")
    ANNOTATION_DIR = ROOT_DIR.joinpath("Annotations")
    SOURCE_TYPE = "all"  # fills the source-type slot of the annotation file names

    # Per-split caps on the number of samples kept by the dataset class.
    TRAIN_SIZE = 8000
    VAL_SIZE = 1024
    TEST_SIZE = 1024

    # Pretrained checkpoints the encoder-decoder model is assembled from.
    ENCODER_MODEL = "google/vit-base-patch16-224-in21k"
    DECODER_MODEL = "facebook/bart-base"

    # Optimisation settings forwarded to the training arguments.
    NUM_EPOCHS = 8
    BATCH_SIZE = 4
    LEARNING_RATE = 5e-5
    GRAD_ACCUMULATION_STEPS = 2
    SEED = 42

    # Sequence and image sizes.
    MAX_TARGET_LENGTH = 512  # truncation length used when tokenizing labels
    GENERATION_MAX_LENGTH = 512  # max_length used when generating predictions
    IMAGE_SIZE = 224  # side of the blank placeholder used when an image fails to load

    # Output locations.
    MODEL_OUTPUT_DIR = "tsr_vit_tablebank_v_hr"
    CHECKPOINT_DIR = "tsr_vit_tablebank_checkpoints_hr"

    # When true, train_model() builds one small batch and logs its tensor shapes.
    DEBUG_MODE = False


def set_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value applied to every generator.
    """
    # Imported locally because the file-level imports do not provide ``random``.
    import random

    # Seed the stdlib generator too, so code relying on ``random`` is repeatable.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def normalize_html_label(s: str) -> str:
    """Canonicalise the whitespace of an HTML structure label.

    Leading/trailing whitespace is removed, every whitespace run becomes one
    space, and a single space sitting between ``>`` and ``<`` is dropped.

    Args:
        s: Raw label text; ``None`` is passed through unchanged.

    Returns:
        The normalised label, or ``None`` when ``s`` is ``None``.
    """
    if s is None:
        return s
    single_spaced = re.sub(r"\s+", " ", s.strip())
    # Splitting on "> <" and re-joining with "><" removes the gap between tags.
    return "><".join(single_spaced.split("> <"))


def load_split_annotations(
    annotation_dir: Path, source_type: str, split: str
) -> Tuple[List[str], List[str]]:
    """Read the image list and the normalised labels of one dataset split.

    The files are ``src-<source_type>_<split>.txt`` (one image name per line,
    blank lines skipped) and ``tgt-<source_type>_<split>.txt`` (one label per
    line, every line kept).

    Args:
        annotation_dir: Directory holding the annotation files.
        source_type: Middle part of the file names.
        split: Split suffix of the file names.

    Returns:
        ``(image_paths, texts)`` with equal lengths.

    Raises:
        FileNotFoundError: If either file does not exist.
        ValueError: If the two files yield different numbers of entries.
    """
    src_pattern = f"src-{source_type}_{split}.txt"
    tgt_pattern = f"tgt-{source_type}_{split}.txt"
    src_file = annotation_dir / src_pattern
    tgt_file = annotation_dir / tgt_pattern

    if not src_file.exists():
        raise FileNotFoundError(f"Source file not found: {src_file}")
    if not tgt_file.exists():
        raise FileNotFoundError(f"Target file not found: {tgt_file}")

    image_paths = []
    with open(src_file, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            entry = raw_line.strip()
            if entry:
                image_paths.append(entry)

    texts = []
    with open(tgt_file, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            # Only the newline is removed here; empty labels are kept.
            texts.append(normalize_html_label(raw_line.rstrip("\n")))

    image_count, label_count = len(image_paths), len(texts)
    if image_count != label_count:
        raise ValueError(
            f"Mismatch in {src_pattern}/{tgt_pattern}: "
            f"{image_count} images vs {label_count} labels"
        )

    logging.info(f"Loaded {split} split: {len(image_paths)} samples from {src_pattern}")
    return image_paths, texts


class TableBankDataset(Dataset):
    """Pairs of table images and label strings, capped per split.

    Each item is a dict with the RGB image under ``"image"`` and its label
    under ``"text"``.
    """

    def __init__(
        self,
        image_paths: List[str],
        texts: List[str],
        config: Config,
        split: str = "train",
    ):
        """Store the samples and truncate them to the cap configured for ``split``.

        Args:
            image_paths: Image file names relative to ``config.IMAGE_DIR``.
            texts: Labels aligned with ``image_paths``.
            config: Provides ``IMAGE_DIR``, ``IMAGE_SIZE`` and the ``*_SIZE`` caps.
            split: ``"train"``, ``"val"`` or ``"test"``; other values apply no cap.
        """
        self.config = config
        self.split = split
        self.image_paths = image_paths
        self.texts = texts

        # Name of the config attribute holding the cap for this split, if any.
        cap_attribute = {
            "train": "TRAIN_SIZE",
            "val": "VAL_SIZE",
            "test": "TEST_SIZE",
        }.get(split)
        if cap_attribute is not None:
            cap = getattr(config, cap_attribute)
            if len(image_paths) > cap:
                self.image_paths = image_paths[:cap]
                self.texts = texts[:cap]

        logging.info(f"{split.upper()} dataset: {len(self.image_paths)} samples")

    def __len__(self) -> int:
        """Return the number of samples kept after capping."""
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Load sample ``idx``; a white placeholder replaces an unreadable image."""
        image_path = self.config.IMAGE_DIR / self.image_paths[idx]

        try:
            image = Image.open(image_path).convert("RGB")
        except Exception as e:
            logging.error(f"Failed to load image {image_path}: {e}")
            placeholder_size = (self.config.IMAGE_SIZE, self.config.IMAGE_SIZE)
            image = Image.new("RGB", placeholder_size, color=(255, 255, 255))

        return {"image": image, "text": self.texts[idx]}


class DataCollator:
    """Convert a list of ``{"image", "text"}`` samples into model input tensors."""

    def __init__(
        self, image_processor: ViTImageProcessor, tokenizer: AutoTokenizer, cfg: Config
    ):
        self.image_processor = image_processor
        self.tokenizer = tokenizer
        self.cfg = cfg

    def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Build ``pixel_values``, ``labels`` and ``decoder_input_ids`` for a batch.

        Args:
            batch: Samples carrying an ``"image"`` and a ``"text"`` entry.

        Returns:
            Dict with the processed images, the label ids (padding replaced by
            -100) and the label ids shifted one position right behind the
            tokenizer's BOS id.
        """
        images = [sample["image"] for sample in batch]
        texts = [sample["text"] for sample in batch]

        pixel_values = self.image_processor(images, return_tensors="pt")["pixel_values"]

        target_ids = self.tokenizer(
            text_target=texts,
            padding="longest",
            truncation=True,
            max_length=self.cfg.MAX_TARGET_LENGTH,
            return_tensors="pt",
        ).input_ids

        # -100 marks padded positions in the labels.
        labels = target_ids.clone()
        padding_mask = labels == self.tokenizer.pad_token_id
        labels[padding_mask] = -100

        # Shift right by one: column 0 is BOS, column j holds target column j-1.
        decoder_input_ids = torch.empty_like(target_ids)
        decoder_input_ids[:, 1:] = target_ids[:, :-1]
        decoder_input_ids[:, 0] = self.tokenizer.bos_token_id

        return {
            "pixel_values": pixel_values,
            "labels": labels,
            "decoder_input_ids": decoder_input_ids,
        }


def compute_metrics(pred: EvalPrediction, tokenizer: AutoTokenizer) -> Dict[str, float]:
    """Compute corpus BLEU between decoded predictions and decoded labels.

    Args:
        pred: Evaluation output exposing ``predictions`` and ``label_ids``.
        tokenizer: Tokenizer used to decode both id arrays.

    Returns:
        ``{"bleu": score}``; the score is 0.0 when there are no predictions or
        when the BLEU computation raises.
    """
    generated = pred.predictions
    if generated is None:
        return {"bleu": 0.0}
    if isinstance(generated, tuple):
        generated = generated[0]

    pad_id = tokenizer.pad_token_id
    # Ids outside [0, vocab size) cannot be decoded; map them to the pad id.
    undecodable = (generated < 0) | (generated >= len(tokenizer))
    generated = np.where(undecodable, pad_id, generated)
    decoded_preds = tokenizer.batch_decode(generated, skip_special_tokens=True)

    reference_ids = np.array(pred.label_ids)
    # Restore the pad id where the labels carry the -100 marker.
    reference_ids = np.where(reference_ids == -100, pad_id, reference_ids)
    decoded_labels = tokenizer.batch_decode(reference_ids, skip_special_tokens=True)

    try:
        score = float(sacrebleu.corpus_bleu(decoded_preds, [decoded_labels]).score)
    except Exception as e:
        logging.warning(f"BLEU computation failed: {e}")
        score = 0.0

    return {"bleu": score}


class PredictionPrinterCallback(TrainerCallback):
    """Log the model's prediction for the first validation sample after each evaluation."""

    def __init__(
        self,
        tokenizer: AutoTokenizer,
        image_processor: ViTImageProcessor,
        val_dataset: TableBankDataset,
        cfg: Config,
    ):
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.val_dataset = val_dataset
        self.cfg = cfg

    def on_evaluate(self, args, state, control, model=None, **kwargs):
        """Generate for ``val_dataset[0]`` and log it next to the ground truth.

        Does nothing when no model is supplied, when the dataset is empty, or
        when generation returns ``None``.
        """
        if model is None or len(self.val_dataset) == 0:
            return

        sample = self.val_dataset[0]
        model.eval()
        device = next(model.parameters()).device

        pixel_values = self.image_processor(sample["image"], return_tensors="pt")[
            "pixel_values"
        ].to(device)

        generation_options = {
            "max_length": self.cfg.GENERATION_MAX_LENGTH,
            "num_beams": 4,
            "early_stopping": True,
            "decoder_start_token_id": model.config.decoder_start_token_id,
            "eos_token_id": model.config.eos_token_id,
            "pad_token_id": model.config.pad_token_id,
        }
        with torch.no_grad():
            pred_ids = model.generate(pixel_values=pixel_values, **generation_options)

        if pred_ids is None:
            return

        pred_text = self.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)[0]
        gt_text = sample["text"]

        # The epoch can be unset; report 0.00 in that case.
        epoch = 0.0 if state.epoch is None else state.epoch
        rule = "=" * 60

        logging.info("\n" + rule)
        logging.info(f"SAMPLE PREDICTION AFTER EPOCH {epoch:.2f}")
        logging.info(rule)
        logging.info(f"MODEL OUTPUT:\n{pred_text}")
        logging.info(f"\nGROUND TRUTH:\n{gt_text}")
        logging.info(rule + "\n")


def prepare_model_and_tokenizer(cfg: Config):
    """Load the image processor and tokenizer, then assemble the encoder-decoder model.

    The tokenizer receives fallback pad/bos/eos tokens when any is missing,
    plus the HTML structure tags as extra tokens. The decoder embeddings are
    resized to the enlarged vocabulary.

    Args:
        cfg: Provides ``ENCODER_MODEL`` and ``DECODER_MODEL``.

    Returns:
        ``(image_processor, tokenizer, model)``.
    """
    image_processor = ViTImageProcessor.from_pretrained(cfg.ENCODER_MODEL)
    tokenizer = AutoTokenizer.from_pretrained(cfg.DECODER_MODEL)

    # Only fill in a special token when the pretrained tokenizer lacks it.
    fallback_specials = (
        ("pad_token", "<pad>"),
        ("bos_token", "<s>"),
        ("eos_token", "</s>"),
    )
    for attribute, token in fallback_specials:
        if getattr(tokenizer, attribute) is None:
            tokenizer.add_special_tokens({attribute: token})

    html_tokens = [
        "<table>", "</table>",
        "<thead>", "</thead>",
        "<tbody>", "</tbody>",
        "<tr>", "</tr>",
        "<td>", "</td>",
        "<th>", "</th>",
    ]
    tokenizer.add_tokens(html_tokens)

    model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
        cfg.ENCODER_MODEL, cfg.DECODER_MODEL
    )
    vocab_size = len(tokenizer)
    model.decoder.resize_token_embeddings(vocab_size)

    model_config = model.config
    model_config.decoder_start_token_id = tokenizer.bos_token_id
    model_config.eos_token_id = tokenizer.eos_token_id
    model_config.pad_token_id = tokenizer.pad_token_id
    model_config.vocab_size = vocab_size

    if hasattr(model.decoder, "config"):
        decoder_config = model.decoder.config
        decoder_config.is_decoder = True
        decoder_config.add_cross_attention = True

    return image_processor, tokenizer, model


def train_model():
    """Fine-tune the encoder-decoder model on the train split and save the result.

    Steps: seed the generators, build processor/tokenizer/model, load the
    train and val annotations, optionally log one debug batch, train with
    ``Seq2SeqTrainer`` (evaluating and checkpointing every epoch, keeping the
    best BLEU), then save model, tokenizer and image processor to
    ``Config.MODEL_OUTPUT_DIR``.

    Raises:
        FileNotFoundError: Re-raised after logging when an annotation file is missing.
        RuntimeError: Re-raised after logging when ``trainer.train()`` fails.
    """
    cfg = Config()
    set_seed(cfg.SEED)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logging.info(f"Using device: {device}")

    image_processor, tokenizer, model = prepare_model_and_tokenizer(cfg)

    try:
        train_image_paths, train_texts = load_split_annotations(
            cfg.ANNOTATION_DIR, cfg.SOURCE_TYPE, "train"
        )
        val_image_paths, val_texts = load_split_annotations(
            cfg.ANNOTATION_DIR, cfg.SOURCE_TYPE, "val"
        )
    except FileNotFoundError as e:
        logging.error(f"Annotation file error: {e}")
        # The loader opens "src-<type>_<split>.txt" / "tgt-<type>_<split>.txt"
        # (hyphen after the prefix), so the hint must name the files that way.
        logging.error(
            f"Please verify that files like 'src-{cfg.SOURCE_TYPE}_train.txt' and "
            f"'tgt-{cfg.SOURCE_TYPE}_train.txt' exist in {cfg.ANNOTATION_DIR}"
        )
        raise

    train_dataset = TableBankDataset(train_image_paths, train_texts, cfg, split="train")
    val_dataset = TableBankDataset(val_image_paths, val_texts, cfg, split="val")

    data_collator = DataCollator(image_processor, tokenizer, cfg)

    if cfg.DEBUG_MODE:
        # Collate one tiny batch up front so shape problems surface before training.
        debug_loader = DataLoader(
            train_dataset,
            batch_size=min(2, len(train_dataset)),
            shuffle=False,
            collate_fn=data_collator,
        )
        try:
            batch = next(iter(debug_loader))
            logging.info(f"pixel_values shape: {batch['pixel_values'].shape}")
            logging.info(f"labels shape: {batch['labels'].shape}")
            logging.info(f"decoder_input_ids shape: {batch['decoder_input_ids'].shape}")
        except Exception as e:
            # Debug output is best-effort; a failure here must not stop training.
            logging.error(f"Debug batch creation failed: {e}")

    training_args = Seq2SeqTrainingArguments(
        output_dir=cfg.CHECKPOINT_DIR,
        num_train_epochs=cfg.NUM_EPOCHS,
        per_device_train_batch_size=cfg.BATCH_SIZE,
        per_device_eval_batch_size=cfg.BATCH_SIZE,
        learning_rate=cfg.LEARNING_RATE,
        gradient_accumulation_steps=cfg.GRAD_ACCUMULATION_STEPS,
        fp16=torch.cuda.is_available(),
        dataloader_num_workers=0,
        dataloader_pin_memory=torch.cuda.is_available(),
        eval_strategy="epoch",
        save_strategy="epoch",
        predict_with_generate=True,
        generation_max_length=cfg.GENERATION_MAX_LENGTH,
        logging_strategy="steps",
        logging_steps=100,
        load_best_model_at_end=True,
        metric_for_best_model="bleu",
        greater_is_better=True,
        # The datasets yield raw dicts handled by the collator, so no columns
        # may be pruned by the trainer.
        remove_unused_columns=False,
        save_total_limit=2,
    )

    prediction_callback = PredictionPrinterCallback(
        tokenizer, image_processor, val_dataset, cfg
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=lambda p: compute_metrics(p, tokenizer),
        callbacks=[prediction_callback],
    )

    logging.info("Commencing training")
    try:
        trainer.train()
    except RuntimeError as e:
        logging.error(f"Training failed: {e}")
        if "out of memory" in str(e).lower():
            logging.error(
                "OUT OF MEMORY. Reduce BATCH_SIZE or IMAGE_SIZE or "
                "increase GRAD_ACCUMULATION_STEPS."
            )
        raise

    os.makedirs(cfg.MODEL_OUTPUT_DIR, exist_ok=True)
    trainer.save_model(cfg.MODEL_OUTPUT_DIR)
    tokenizer.save_pretrained(cfg.MODEL_OUTPUT_DIR)
    image_processor.save_pretrained(cfg.MODEL_OUTPUT_DIR)
    logging.info(f"Model saved to {cfg.MODEL_OUTPUT_DIR}")


# Start training only when this code runs as the main program, not on import.
if __name__ == "__main__":
    train_model()

RuntimeError: Error(s) in loading state_dict for CRNNModel:
	Missing key(s) in state_dict: "cnn.3.weight", "cnn.3.bias", "cnn.6.weight", "cnn.6.bias", "cnn.7.weight", "cnn.7.bias", "cnn.7.running_mean", "cnn.7.running_var", "cnn.13.weight", "cnn.13.bias", "cnn.13.running_mean", "cnn.13.running_var", "fc.weight", "fc.bias". 
	Unexpected key(s) in state_dict: "map_to_rnn.weight", "map_to_rnn.bias", "linear.weight", "linear.bias", "cnn.1.weight", "cnn.1.bias", "cnn.1.running_mean", "cnn.1.running_var", "cnn.1.num_batches_tracked", "cnn.4.weight", "cnn.4.bias", "cnn.5.weight", "cnn.5.bias", "cnn.5.running_mean", "cnn.5.running_var", "cnn.5.num_batches_tracked", "cnn.8.weight", "cnn.8.bias", "cnn.9.running_mean", "cnn.9.running_var", "cnn.9.num_batches_tracked", "cnn.11.weight", "cnn.11.bias", "cnn.12.running_mean", "cnn.12.running_var", "cnn.12.num_batches_tracked", "cnn.16.weight", "cnn.16.bias", "cnn.16.running_mean", "cnn.16.running_var", "cnn.16.num_batches_tracked", "cnn.19.weight", "cnn.19.bias", "cnn.19.running_mean", "cnn.19.running_var", "cnn.19.num_batches_tracked", "rnn.weight_ih_l2", "rnn.weight_hh_l2", "rnn.bias_ih_l2", "rnn.bias_hh_l2", "rnn.weight_ih_l2_reverse", "rnn.weight_hh_l2_reverse", "rnn.bias_ih_l2_reverse", "rnn.bias_hh_l2_reverse". 
	size mismatch for cnn.0.weight: copying a param with shape torch.Size([64, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 1, 3, 3]).
	size mismatch for cnn.9.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
	size mismatch for cnn.12.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512, 256, 3, 3]).
	size mismatch for cnn.12.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for cnn.15.weight: copying a param with shape torch.Size([512, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).
	size mismatch for cnn.18.weight: copying a param with shape torch.Size([512, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 512, 2, 2]).
	size mismatch for rnn.weight_ih_l0: copying a param with shape torch.Size([1024, 256]) from checkpoint, the shape in current model is torch.Size([1024, 512]).
	size mismatch for rnn.weight_ih_l0_reverse: copying a param with shape torch.Size([1024, 256]) from checkpoint, the shape in current model is torch.Size([1024, 512]).

In [1]:
import os
import re
import logging
import warnings
from pathlib import Path
from typing import Any, Dict, List, Tuple

import numpy as np
import torch
from PIL import Image, ImageFile
from torch.utils.data import Dataset, DataLoader
from transformers import (
    VisionEncoderDecoderModel,
    ViTImageProcessor,
    AutoTokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    TrainerCallback,
    EvalPrediction,
)
import sacrebleu

# Silence every Python warning for the rest of the process; this also hides
# deprecation notices raised by the imported libraries.
warnings.filterwarnings("ignore")
# Send INFO-level (and higher) records through the root logger.
logging.basicConfig(level=logging.INFO)
# Let PIL load image files whose data is truncated instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True


class Config:
    """Constants for continuing training from a previously saved model directory."""

    # Dataset layout on disk.
    ROOT_DIR = Path("./TableBank/Recognition")
    IMAGE_DIR = ROOT_DIR.joinpath("Images")
    ANNOTATION_DIR = ROOT_DIR.joinpath("Annotations")
    SOURCE_TYPE = "all"  # fills the source-type slot of the annotation file names

    # Per-split sample caps.
    TRAIN_SIZE = 8000
    VAL_SIZE = 1024
    TEST_SIZE = 1024

    # Where the model is loaded from and where the results are written.
    MODEL_LOAD_DIR = "tsr_vit_tablebank_v_hr"
    CONTINUE_OUTPUT_DIR = "tsr_vit_tablebank_dos"
    CHECKPOINT_DIR = "tsr_vit_tablebank_dos_ckpt"

    # Optimisation settings forwarded to the training arguments.
    NUM_EPOCHS = 6
    BATCH_SIZE = 4
    LEARNING_RATE = 2e-5
    GRAD_ACCUMULATION_STEPS = 2
    SEED = 42

    # Sequence and image sizes.
    MAX_TARGET_LENGTH = 512  # truncation length used when tokenizing labels
    GENERATION_MAX_LENGTH = 512  # max_length used when generating predictions
    IMAGE_SIZE = 224  # side of the blank placeholder used when an image fails to load


def set_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value applied to every generator.
    """
    # Imported locally because the file-level imports do not provide ``random``.
    import random

    # Seed the stdlib generator too, so code relying on ``random`` is repeatable.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def normalize_html_label(s: str) -> str:
    """Trim the label, collapse whitespace runs and drop single spaces between tags.

    Args:
        s: Raw label text.

    Returns:
        The label with every whitespace run reduced to one space and each
        ``"> <"`` sequence rewritten to ``"><"``.
    """
    trimmed = s.strip()
    single_spaced = re.sub(r"\s+", " ", trimmed)
    # Split on the tag gap and re-join without it.
    fragments = single_spaced.split("> <")
    return "><".join(fragments)


def load_split_annotations(annotation_dir: Path, source_type: str, split: str) -> Tuple[List[str], List[str]]:
    """Read the image names and normalised labels of one split.

    Args:
        annotation_dir: Directory holding the annotation files.
        source_type: Middle part of the ``src-``/``tgt-`` file names.
        split: Split suffix of the file names.

    Returns:
        ``(images, texts)`` with equal lengths.

    Raises:
        ValueError: If the two files yield different numbers of entries.
    """
    stem = f"{source_type}_{split}.txt"

    images = []
    with open(annotation_dir / f"src-{stem}", encoding="utf-8") as handle:
        for raw_line in handle:
            name = raw_line.strip()
            if name:  # blank lines carry no image name
                images.append(name)

    texts = []
    with open(annotation_dir / f"tgt-{stem}", encoding="utf-8") as handle:
        for raw_line in handle:
            texts.append(normalize_html_label(raw_line.rstrip("\n")))

    if len(images) == len(texts):
        return images, texts
    raise ValueError("Annotation and image counts mismatch")


class TableBankDataset(Dataset):
    """Pairs of table images and label strings, capped for the train/val splits."""

    def __init__(self, paths: List[str], texts: List[str], cfg: Config, split: str):
        """Store the samples, keeping at most the configured number per split.

        Args:
            paths: Image file names relative to ``cfg.IMAGE_DIR``.
            texts: Labels aligned with ``paths``.
            cfg: Provides ``IMAGE_DIR``, ``IMAGE_SIZE``, ``TRAIN_SIZE`` and ``VAL_SIZE``.
            split: ``"train"`` or ``"val"``; any other value applies no cap.
        """
        self.paths = paths
        self.texts = texts
        self.cfg = cfg
        if split == "train":
            self.paths = self.paths[: cfg.TRAIN_SIZE]
            self.texts = self.texts[: cfg.TRAIN_SIZE]
        elif split == "val":
            self.paths = self.paths[: cfg.VAL_SIZE]
            self.texts = self.texts[: cfg.VAL_SIZE]

    def __len__(self) -> int:
        """Return the number of samples kept."""
        return len(self.paths)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Load sample ``idx``; a white placeholder replaces an unreadable image."""
        p = self.cfg.IMAGE_DIR / self.paths[idx]
        try:
            img = Image.open(p).convert("RGB")
        except Exception as e:
            # Still best-effort, but leave a trace: otherwise unreadable files
            # silently become blank training samples.
            logging.warning(f"Failed to load image {p}: {e}")
            img = Image.new("RGB", (self.cfg.IMAGE_SIZE, self.cfg.IMAGE_SIZE), (255, 255, 255))
        return {"image": img, "text": self.texts[idx]}


class DataCollator:
    """Convert a list of ``{"image", "text"}`` samples into model input tensors."""

    def __init__(self, image_processor: ViTImageProcessor, tokenizer: AutoTokenizer, cfg: Config):
        self.processor = image_processor
        self.tokenizer = tokenizer
        self.cfg = cfg

    def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Build ``pixel_values``, ``labels`` and ``decoder_input_ids`` for a batch.

        Labels are the tokenized texts with padding replaced by -100; decoder
        inputs are the same ids shifted one position right behind the BOS id.
        """
        images = [sample["image"] for sample in batch]
        texts = [sample["text"] for sample in batch]

        pixel_values = self.processor(images, return_tensors="pt")["pixel_values"]
        encoded = self.tokenizer(
            text_target=texts,
            padding="longest",
            truncation=True,
            max_length=self.cfg.MAX_TARGET_LENGTH,
            return_tensors="pt",
        )
        target_ids = encoded.input_ids

        labels = target_ids.clone()
        is_padding = labels == self.tokenizer.pad_token_id
        labels[is_padding] = -100

        # Column 0 becomes BOS; every other column holds the previous target id.
        shifted = target_ids.new_empty(target_ids.shape)
        shifted[:, 1:] = target_ids[:, :-1]
        shifted[:, 0] = self.tokenizer.bos_token_id

        return {"pixel_values": pixel_values, "labels": labels, "decoder_input_ids": shifted}


def compute_metrics(pred: EvalPrediction, tokenizer: AutoTokenizer) -> Dict[str, float]:
    """Compute corpus BLEU between decoded predictions and decoded labels.

    Args:
        pred: Evaluation output exposing ``predictions`` and ``label_ids``.
        tokenizer: Tokenizer used to decode both id arrays.

    Returns:
        ``{"bleu": score}``; the score falls back to 0.0 (with a logged
        warning) when the BLEU computation raises.
    """
    pr = pred.predictions
    if isinstance(pr, tuple):
        pr = pr[0]
    # Ids outside [0, vocab size) cannot be decoded; map them to the pad id.
    pr = np.where(pr < 0, tokenizer.pad_token_id, pr)
    pr = np.where(pr >= len(tokenizer), tokenizer.pad_token_id, pr)
    dp = tokenizer.batch_decode(pr, skip_special_tokens=True)
    # Restore the pad id where the labels carry the -100 marker.
    gt = np.where(pred.label_ids == -100, tokenizer.pad_token_id, pred.label_ids)
    dl = tokenizer.batch_decode(gt, skip_special_tokens=True)
    try:
        b = sacrebleu.corpus_bleu(dp, [dl]).score
    except Exception as e:
        # Keep the 0.0 fallback but make it visible: this metric drives
        # best-model selection, so a silent zero would go unnoticed.
        logging.warning(f"BLEU computation failed: {e}")
        b = 0.0
    return {"bleu": float(b)}


class EpochPredictionCallback(TrainerCallback):
    """Log the model's prediction for the first dataset sample after each evaluation."""

    def __init__(self, tokenizer: AutoTokenizer, processor: ViTImageProcessor, dataset: TableBankDataset, cfg: Config):
        self.tokenizer = tokenizer
        self.processor = processor
        self.dataset = dataset
        self.cfg = cfg

    def on_evaluate(self, args, state, control, model=None, **kwargs):
        """Generate for ``dataset[0]`` and log it next to the ground truth.

        Does nothing when no model is supplied or the dataset is empty.
        """
        if model is None or len(self.dataset) == 0:
            return
        s = self.dataset[0]
        img = s["image"]
        device = next(model.parameters()).device
        x = self.processor(img, return_tensors="pt")["pixel_values"].to(device)
        with torch.no_grad():
            y = model.generate(
                pixel_values=x,
                max_length=self.cfg.GENERATION_MAX_LENGTH,
                num_beams=4,
                early_stopping=True,
                decoder_start_token_id=model.config.decoder_start_token_id,
                eos_token_id=model.config.eos_token_id,
                pad_token_id=model.config.pad_token_id,
            )
        pred = self.tokenizer.batch_decode(y, skip_special_tokens=True)[0]
        gt = s["text"]
        # state.epoch may be None (e.g. evaluation run before any training);
        # formatting None with ".2f" would raise TypeError, so report 0.00.
        epoch = state.epoch if state.epoch is not None else 0.0
        logging.info("=" * 60)
        logging.info(f"Epoch {epoch:.2f} Prediction")
        logging.info("Model Output:")
        logging.info(pred)
        logging.info("Ground Truth:")
        logging.info(gt)
        logging.info("=" * 60)


def prepare(cfg: Config):
    """Load the image processor, tokenizer and model saved under ``cfg.MODEL_LOAD_DIR``.

    Returns:
        ``(processor, tokenizer, model)``.
    """
    source_dir = cfg.MODEL_LOAD_DIR
    image_processor = ViTImageProcessor.from_pretrained(source_dir)
    text_tokenizer = AutoTokenizer.from_pretrained(source_dir)
    seq2seq_model = VisionEncoderDecoderModel.from_pretrained(source_dir)
    return image_processor, text_tokenizer, seq2seq_model


def train():
    """Continue training the saved model and store the result.

    Loads processor/tokenizer/model from ``Config.MODEL_LOAD_DIR``, trains on
    the train split while evaluating on the val split every epoch, then writes
    model, tokenizer and processor to ``Config.CONTINUE_OUTPUT_DIR``.
    """
    cfg = Config()
    set_seed(cfg.SEED)
    processor, tokenizer, model = prepare(cfg)

    train_names, train_labels = load_split_annotations(cfg.ANNOTATION_DIR, cfg.SOURCE_TYPE, "train")
    val_names, val_labels = load_split_annotations(cfg.ANNOTATION_DIR, cfg.SOURCE_TYPE, "val")

    train_ds = TableBankDataset(train_names, train_labels, cfg, "train")
    val_ds = TableBankDataset(val_names, val_labels, cfg, "val")

    collator = DataCollator(processor, tokenizer, cfg)

    argument_values = {
        "output_dir": cfg.CHECKPOINT_DIR,
        "num_train_epochs": cfg.NUM_EPOCHS,
        "per_device_train_batch_size": cfg.BATCH_SIZE,
        "per_device_eval_batch_size": cfg.BATCH_SIZE,
        "learning_rate": cfg.LEARNING_RATE,
        "gradient_accumulation_steps": cfg.GRAD_ACCUMULATION_STEPS,
        "fp16": torch.cuda.is_available(),
        "dataloader_num_workers": 0,
        "eval_strategy": "epoch",
        "save_strategy": "epoch",
        "predict_with_generate": True,
        "generation_max_length": cfg.GENERATION_MAX_LENGTH,
        "logging_strategy": "steps",
        "logging_steps": 100,
        "load_best_model_at_end": True,
        "metric_for_best_model": "bleu",
        "greater_is_better": True,
        "save_total_limit": 2,
        # The datasets yield raw dicts handled by the collator.
        "remove_unused_columns": False,
    }
    args = Seq2SeqTrainingArguments(**argument_values)

    sample_logger = EpochPredictionCallback(tokenizer, processor, val_ds, cfg)

    def bleu_metrics(eval_prediction):
        # Bind the tokenizer so the trainer can call this with one argument.
        return compute_metrics(eval_prediction, tokenizer)

    trainer = Seq2SeqTrainer(
        model=model,
        args=args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        data_collator=collator,
        tokenizer=tokenizer,
        compute_metrics=bleu_metrics,
        callbacks=[sample_logger],
    )

    trainer.train()

    destination = cfg.CONTINUE_OUTPUT_DIR
    os.makedirs(destination, exist_ok=True)
    for save in (trainer.save_model, tokenizer.save_pretrained, processor.save_pretrained):
        save(destination)
    logging.info(f"Model saved to {cfg.CONTINUE_OUTPUT_DIR}")


# Start training only when this code runs as the main program, not on import.
if __name__ == "__main__":
    train()





`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.


Epoch,Training Loss,Validation Loss,Bleu
1,0.1043,0.116102,62.036618
2,0.0951,0.102269,57.477146
3,0.0911,0.105682,52.188096
4,0.0756,0.105377,59.171404
5,0.0667,0.10877,59.041188
6,0.0568,0.108921,60.120745


INFO:root:Epoch 1.00 Prediction
INFO:root:Model Output:
INFO:root:<thead><tr><td><td><td></tr><td><tr><td><table><td></tr></td><table><td><td></table><td><tr></tr><td><td><tr></td><td></tr><tbody></td><td><table></tr><td><table>
INFO:root:Ground Truth:
INFO:root:<table><tbody><tr><td><td><td><td><td><td><td><td><td><td><td><td><td><td><td><td></tr><tr><td><td><td><td><td><td><td><td><td><td><td><td><td><td><td><td></tr></tbody></table>
INFO:root:Epoch 2.00 Prediction
INFO:root:Model Output:
INFO:root:<thead><tr><td><td><td><table><td><tr></tr><td></table><td></table><table><td></tr></td><table><td><td><tr><td><tr><tr></tr></td><td><tr></td><td><td></tr><tbody><table><td><table></tr><td><tr><table></tr><table><td></thead><tr><td></tr><td><table><table><td></tbody>
INFO:root:Ground Truth:
INFO:root:<table><tbody><tr><td><td><td><td><td><td><td><td><td><td><td><td><td><td><td><td></tr><tr><td><td><td><td><td><td><td><td><td><td><td><td><td><td><td><td></tr></tbody></table>
INFO:root:Epoch

In [7]:
import os
import re
import logging
from pathlib import Path
from typing import Any, Dict, List, Tuple
from dataclasses import dataclass

import numpy as np
import torch
from PIL import Image, ImageFile
from torch.utils.data import Dataset, Subset
from transformers import (
    VisionEncoderDecoderModel,
    ViTImageProcessor,
    AutoTokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    TrainerCallback,
    EvalPrediction,
)
import sacrebleu

# Configure the root logger: INFO level, each record prefixed with its
# timestamp and level name.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
# Let PIL load image files whose data is truncated instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True


@dataclass
class TrainingConfig:
    """Paths and hyper-parameters for the grammar-token ("structural") training run.

    ``image_dir`` and ``annotation_dir`` are derived from ``root_dir`` after
    construction and are not constructor arguments.
    """

    # Locations.
    root_dir: Path = Path("./TableBank/Recognition")
    model_load_dir: str = "tsr_vit_tablebank_v_hr"
    output_dir: str = "tsr_vit_tablebank_structural"
    checkpoint_dir: str = "tsr_vit_tablebank_structural_ckpt"

    # Optimisation, data caps and sequence/image sizes.
    num_epochs: int = 8
    batch_size: int = 4
    learning_rate: float = 2e-5
    grad_accumulation_steps: int = 2
    train_size: int = 8000
    val_size: int = 1024
    max_target_length: int = 512
    generation_max_length: int = 512
    seed: int = 42
    image_size: int = 224
    source_type: str = "all"

    # Structural-loss and encoder-freezing knobs; how they are applied is
    # decided by the code that consumes this config.
    struct_penalty_weight: float = 2.0
    start_penalty: float = 2.0
    end_penalty: float = 2.0
    freeze_encoder_epochs: int = 2

    # None is a sentinel replaced in __post_init__, which gives every instance
    # its own list.
    curriculum_stages: List[Tuple[int, int]] = None

    def __post_init__(self):
        """Fill the default curriculum and derive the dataset sub-directories."""
        if self.curriculum_stages is None:
            default_stages = [(0, 2), (1, 4), (2, 9999)]
            self.curriculum_stages = default_stages
        root = self.root_dir
        self.image_dir = root / "Images"
        self.annotation_dir = root / "Annotations"

    @property
    def grammar_tokens(self) -> List[str]:
        """Return the six grammar tokens as START/END pairs for table, row and cell."""
        tokens = ["<TABLE_START>", "<TABLE_END>"]
        tokens += ["<ROW_START>", "<ROW_END>"]
        tokens += ["<CELL_START>", "<CELL_END>"]
        return tokens


class GrammarConverter:
    """Translate between HTML table tags and the grammar-token vocabulary."""

    # Plain substrings (not regular expressions), replaced in this order.
    # thead/tbody wrappers are removed; th and td both become cell tokens.
    HTML_TO_GRAMMAR_MAP = {
        "<table>": "<TABLE_START>",
        "</table>": "<TABLE_END>",
        "<tbody>": "",
        "</tbody>": "",
        "<thead>": "",
        "</thead>": "",
        "<tr>": "<ROW_START>",
        "</tr>": "<ROW_END>",
        "<td>": "<CELL_START>",
        "</td>": "<CELL_END>",
        "<th>": "<CELL_START>",
        "</th>": "<CELL_END>",
    }

    # The reverse direction always emits a tbody wrapper and td cells.
    GRAMMAR_TO_HTML_MAP = {
        "<TABLE_START>": "<table><tbody>",
        "<TABLE_END>": "</tbody></table>",
        "<ROW_START>": "<tr>",
        "<ROW_END>": "</tr>",
        "<CELL_START>": "<td>",
        "<CELL_END>": "</td>",
    }

    @classmethod
    def html_to_grammar(cls, html_string: str) -> str:
        """Convert an HTML structure string to grammar tokens.

        Whitespace is normalised first, the tags are substituted, and the
        result is wrapped in ``<TABLE_START>`` / ``<TABLE_END>`` when it does
        not already start / end with them.

        Args:
            html_string: HTML structure text, or ``None``.

        Returns:
            The grammar string; ``""`` when ``html_string`` is ``None``.
        """
        if html_string is None:
            return ""

        text = re.sub(r"\s+", " ", html_string.strip()).replace("> <", "><")

        for html_tag in cls.HTML_TO_GRAMMAR_MAP:
            text = text.replace(html_tag, cls.HTML_TO_GRAMMAR_MAP[html_tag])

        opener, closer = "<TABLE_START>", "<TABLE_END>"
        if not text.startswith(opener):
            text = opener + text
        return text if text.endswith(closer) else text + closer

    @classmethod
    def grammar_to_html(cls, grammar_string: str) -> str:
        """Convert a grammar-token string back to HTML tags.

        Args:
            grammar_string: Grammar text, or ``None``.

        Returns:
            The HTML string; ``""`` when ``grammar_string`` is ``None``.
        """
        if grammar_string is None:
            return ""

        html = grammar_string
        for grammar_tag in cls.GRAMMAR_TO_HTML_MAP:
            html = html.replace(grammar_tag, cls.GRAMMAR_TO_HTML_MAP[grammar_tag])
        return html


class AnnotationLoader:
    """Reads a split's image list and converts its labels to grammar strings."""

    @staticmethod
    def load_split(annotation_dir: Path, source_type: str, split: str) -> Tuple[List[str], List[str]]:
        """Load one split from its ``src-``/``tgt-`` annotation files.

        Args:
            annotation_dir: Directory holding the annotation files.
            source_type: Middle part of the file names.
            split: Split suffix of the file names.

        Returns:
            ``(images, texts)``: non-blank image names and one grammar string
            per label line.

        Raises:
            FileNotFoundError: If either annotation file is missing.
            ValueError: If the two files yield different numbers of entries.
        """
        src_file = annotation_dir / f"src-{source_type}_{split}.txt"
        tgt_file = annotation_dir / f"tgt-{source_type}_{split}.txt"

        if not (src_file.exists() and tgt_file.exists()):
            raise FileNotFoundError(
                f"Annotation files not found: {src_file} or {tgt_file}"
            )

        images = []
        with open(src_file, encoding="utf-8") as handle:
            for raw_line in handle:
                name = raw_line.strip()
                if name:  # blank lines carry no image name
                    images.append(name)

        texts = []
        with open(tgt_file, encoding="utf-8") as handle:
            for raw_line in handle:
                texts.append(GrammarConverter.html_to_grammar(raw_line.rstrip("\n")))

        if len(images) == len(texts):
            return images, texts
        raise ValueError(
            f"Mismatch: {len(images)} images vs {len(texts)} annotations"
        )


class TableStructureDataset(Dataset):
    """Table images paired with grammar strings and a per-sample cell count."""

    def __init__(
        self,
        image_paths: List[str],
        grammar_texts: List[str],
        config: TrainingConfig,
        split: str
    ):
        """Store the samples, cap them for the split and count cells per sample.

        Args:
            image_paths: Image file names relative to ``config.image_dir``.
            grammar_texts: Grammar strings aligned with ``image_paths``.
            config: Provides ``image_dir``, ``image_size`` and the size caps.
            split: ``"train"`` or ``"val"``; other values apply no cap.
        """
        self.image_paths = image_paths
        self.grammar_texts = grammar_texts
        self.config = config
        self.split = split

        self._apply_size_limit()
        self._precompute_complexities()

    def _apply_size_limit(self):
        """Truncate both lists when they exceed the cap of the current split."""
        if self.split == "train":
            cap = self.config.train_size
        elif self.split == "val":
            cap = self.config.val_size
        else:
            return
        if len(self.image_paths) > cap:
            self.image_paths = self.image_paths[:cap]
            self.grammar_texts = self.grammar_texts[:cap]

    def _precompute_complexities(self):
        """Record, per sample, how many ``<CELL_START>`` tokens its text contains."""
        cell_counts = []
        for grammar in self.grammar_texts:
            cell_counts.append(grammar.count("<CELL_START>"))
        self.complexities = cell_counts

    def __len__(self) -> int:
        """Return the number of samples kept after capping."""
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Load sample ``idx``; a white placeholder replaces an image that raises OSError."""
        image_path = self.config.image_dir / self.image_paths[idx]

        try:
            image = Image.open(image_path).convert("RGB")
        except OSError as e:  # IOError is an alias of OSError
            logging.warning(f"Failed to load {image_path}: {e}")
            side = self.config.image_size
            image = Image.new("RGB", (side, side), (255, 255, 255))

        sample = {"image": image}
        sample["text"] = self.grammar_texts[idx]
        sample["complexity"] = self.complexities[idx]
        return sample

    def get_indices_by_complexity(self, max_cells: int) -> List[int]:
        """Return the indices of samples whose cell count is at most ``max_cells``."""
        selected = []
        for position, cell_count in enumerate(self.complexities):
            if cell_count <= max_cells:
                selected.append(position)
        return selected


class TableDataCollator:
    """Turns a list of dataset items into model-ready tensors.

    Images go through the image processor, target strings through the
    tokenizer. Padding positions in the labels are replaced by -100, and the
    decoder inputs are the target ids shifted one step to the right with
    ``<TABLE_START>`` in the first column.
    """

    def __init__(
        self,
        image_processor: ViTImageProcessor,
        tokenizer: AutoTokenizer,
        config: TrainingConfig
    ):
        self.config = config
        self.tokenizer = tokenizer
        self.image_processor = image_processor

    def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """Collate ``batch`` into ``pixel_values``, ``labels`` and ``decoder_input_ids``."""
        images, texts = [], []
        for sample in batch:
            images.append(sample["image"])
            texts.append(sample["text"])

        processed_images = self.image_processor(images, return_tensors="pt")
        pixel_values = processed_images["pixel_values"]

        encoded_targets = self.tokenizer(
            text_target=texts,
            padding="longest",
            truncation=True,
            max_length=self.config.max_target_length,
            return_tensors="pt"
        )
        target_ids = encoded_targets.input_ids

        # Work on a copy: the unmasked ids are still needed for the decoder inputs.
        labels = target_ids.clone()
        pad_positions = labels == self.tokenizer.pad_token_id
        labels[pad_positions] = -100

        return {
            "pixel_values": pixel_values,
            "labels": labels,
            "decoder_input_ids": self._create_decoder_inputs(target_ids)
        }

    def _create_decoder_inputs(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Shift ``input_ids`` right by one and put ``<TABLE_START>`` first."""
        # roll() returns a new tensor; the wrapped-around last column lands in
        # column 0 and is overwritten by the start token right below.
        shifted = input_ids.roll(shifts=1, dims=1)
        shifted[:, 0] = self.tokenizer.convert_tokens_to_ids("<TABLE_START>")
        return shifted


class StructuralLossTrainer(Seq2SeqTrainer):
    """Seq2SeqTrainer that adds a start/end-token penalty to the training loss.

    While the model is in training mode and ``struct_weight`` is non-zero, the
    loss is ``outputs.loss + struct_weight * penalty``. The penalty is the
    batch mean of ``start_penalty`` for every sequence whose first greedy
    prediction is not ``start_token_id`` plus ``end_penalty`` for every
    sequence whose prediction at the last labelled position is not
    ``end_token_id``.
    """

    def __init__(
        self,
        start_token_id: int,
        end_token_id: int,
        start_penalty: float,
        end_penalty: float,
        struct_weight: float,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.struct_weight = struct_weight
        self.start_token_id = start_token_id
        self.start_penalty = start_penalty
        self.end_token_id = end_token_id
        self.end_penalty = end_penalty

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Run the model and return its loss, plus the structural penalty when training."""
        labels = inputs.get("labels")
        outputs = model(
            pixel_values=inputs.get("pixel_values"),
            decoder_input_ids=inputs.get("decoder_input_ids"),
            labels=labels
        )

        loss = outputs.loss
        if model.training and self.struct_weight != 0:
            penalty = self._compute_structural_penalty(labels, outputs.logits)
            loss = loss + self.struct_weight * penalty

        if return_outputs:
            return loss, outputs
        return loss

    def _compute_structural_penalty(
        self,
        labels: torch.Tensor,
        logits: torch.Tensor
    ) -> torch.Tensor:
        """Return the mean start/end-token penalty over the batch.

        NOTE(review): the penalty is derived from ``argmax`` token ids, so it
        is a constant with respect to the logits and adds no gradient; it only
        shifts the reported loss value. Confirm whether that is intended.
        """
        predicted_ids = logits.argmax(dim=-1)

        if predicted_ids.size(1) == 0:
            return torch.tensor(0.0, device=predicted_ids.device)

        # Index of the last non-ignored label per row (0 when a row has none).
        last_positions = ((labels != -100).sum(dim=1) - 1).clamp(min=0)
        rows = torch.arange(predicted_ids.size(0), device=predicted_ids.device)

        wrong_start = (predicted_ids[:, 0] != self.start_token_id).float()
        wrong_end = (
            predicted_ids[rows, last_positions] != self.end_token_id
        ).float()

        per_sample = wrong_start * self.start_penalty + wrong_end * self.end_penalty
        return per_sample.mean()


class CurriculumLearningCallback(TrainerCallback):
    """Selects a complexity-limited subset of the training set at each epoch start.

    ``config.curriculum_stages`` is read as a sequence of
    ``(start_epoch, max_cells)`` pairs; a ``max_cells`` of 9999 or more means
    "use every sample".
    """

    def __init__(self, config: TrainingConfig, base_dataset: TableStructureDataset):
        self.base_dataset = base_dataset
        self.config = config
        self.current_subset = None

    def on_epoch_begin(self, args, state, control, train_dataloader, **kwargs):
        """Build the subset for the current stage and hand it to the dataloader."""
        epoch = 0 if state.epoch is None else int(state.epoch)

        stage = self._determine_stage(epoch)
        cell_limit = self.config.curriculum_stages[stage][1]

        if cell_limit < 9999:
            selected = self.base_dataset.get_indices_by_complexity(cell_limit)
        else:
            selected = list(range(len(self.base_dataset)))

        self.current_subset = Subset(self.base_dataset, selected)

        # NOTE(review): verify that replacing `dataset` on an already-built
        # dataloader really changes what the Trainer iterates over.
        if hasattr(train_dataloader, 'dataset'):
            train_dataloader.dataset = self.current_subset

        logging.info(
            f"Epoch {epoch}: Curriculum stage {stage}, "
            f"max_cells={cell_limit}, samples={len(selected)}"
        )

    def _determine_stage(self, current_epoch: int) -> int:
        """Return the index of the last stage whose start epoch has been reached (0 if none)."""
        reached = [
            idx
            for idx, (start_epoch, _) in enumerate(self.config.curriculum_stages)
            if current_epoch >= start_epoch
        ]
        return reached[-1] if reached else 0


class EncoderUnfreezeCallback(TrainerCallback):
    """Re-enables gradients on the encoder once enough epochs have finished.

    NOTE(review): parameters unfrozen here are only updated if the optimizer
    already holds them; confirm how the Trainer builds its parameter groups
    when the encoder is frozen before training starts.
    """

    def __init__(self, config: TrainingConfig, model: VisionEncoderDecoderModel):
        self.model = model
        self.config = config
        self.unfrozen = False

    def on_epoch_end(self, args, state, control, **kwargs):
        """Unfreeze the encoder the first time the epoch count reaches ``freeze_encoder_epochs``."""
        if self.unfrozen:
            return

        epoch = 0 if state.epoch is None else int(state.epoch)
        if epoch < self.config.freeze_encoder_epochs:
            return

        for encoder_param in self.model.encoder.parameters():
            encoder_param.requires_grad = True
        self.unfrozen = True
        logging.info(f"Encoder unfrozen at epoch {epoch}")


class SamplePredictionCallback(TrainerCallback):
    """Logs one generated prediction next to its ground truth after each evaluation.

    The first validation sample is pushed through ``model.generate``; both the
    prediction and the reference are converted with
    ``GrammarConverter.grammar_to_html`` before being logged.
    """

    def __init__(
        self,
        tokenizer: AutoTokenizer,
        processor: ViTImageProcessor,
        val_dataset: TableStructureDataset,
        config: TrainingConfig
    ):
        self.tokenizer = tokenizer
        self.processor = processor
        self.val_dataset = val_dataset
        self.config = config

    def on_evaluate(self, args, state, control, model=None, **kwargs):
        """Generate and log a prediction for the first validation sample."""
        if model is None or len(self.val_dataset) == 0:
            return

        sample = self.val_dataset[0]
        prediction = self._generate_prediction(model, sample["image"])

        self._log_comparison(
            state.epoch,
            prediction,
            sample["text"]
        )

    def _generate_prediction(self, model, image: Image.Image) -> str:
        """Beam-search decode ``image`` and return the decoded text.

        The model's train/eval mode is restored afterwards so that calling
        this helper never leaves the model in a different mode than it found.
        """
        device = next(model.parameters()).device
        pixel_values = self.processor(
            image,
            return_tensors="pt"
        )["pixel_values"].to(device)

        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                generated_ids = model.generate(
                    pixel_values=pixel_values,
                    max_length=self.config.generation_max_length,
                    num_beams=4,
                    early_stopping=True,
                    no_repeat_ngram_size=3,
                    repetition_penalty=2.0
                )
        finally:
            model.train(was_training)

        prediction = self.tokenizer.batch_decode(
            generated_ids,
            skip_special_tokens=True
        )[0]

        return prediction

    def _log_comparison(self, epoch: float, prediction: str, ground_truth: str):
        """Log prediction and ground truth as HTML.

        ``epoch`` may be ``None`` (the trainer state has no epoch before
        training starts); it is then rendered as ``n/a`` instead of raising
        on the float format spec.
        """
        pred_html = GrammarConverter.grammar_to_html(prediction)
        gt_html = GrammarConverter.grammar_to_html(ground_truth)
        epoch_label = f"{epoch:.2f}" if epoch is not None else "n/a"

        logging.info("=" * 80)
        logging.info(f"Epoch {epoch_label} Sample Prediction")
        logging.info("-" * 80)
        logging.info("Prediction:")
        logging.info(pred_html)
        logging.info("-" * 80)
        logging.info("Ground Truth:")
        logging.info(gt_html)
        logging.info("=" * 80)


class MetricsComputer:
    """Callable that turns an ``EvalPrediction`` into a corpus-level BLEU score."""

    def __init__(self, tokenizer: AutoTokenizer):
        self.tokenizer = tokenizer

    def __call__(self, eval_prediction: EvalPrediction) -> Dict[str, float]:
        """Decode predictions and labels and return ``{"bleu": score}``.

        Returns a score of 0.0 when there are no predictions at all.
        """
        predictions = eval_prediction.predictions
        if predictions is None:
            return {"bleu": 0.0}
        if isinstance(predictions, tuple):
            predictions = predictions[0]

        hypotheses = self._decode_predictions(predictions)
        references = self._decode_labels(eval_prediction.label_ids)

        return {"bleu": float(self._compute_bleu(hypotheses, references))}

    def _decode_predictions(self, predictions: np.ndarray) -> List[str]:
        """Decode predicted ids, mapping ids outside the vocabulary to the pad token."""
        pad_id = self.tokenizer.pad_token_id
        out_of_vocab = (predictions < 0) | (predictions >= len(self.tokenizer))
        cleaned = np.where(out_of_vocab, pad_id, predictions)
        return self.tokenizer.batch_decode(cleaned, skip_special_tokens=True)

    def _decode_labels(self, label_ids: np.ndarray) -> List[str]:
        """Decode label ids after turning the -100 ignore marker back into padding."""
        restored = np.where(
            label_ids == -100,
            self.tokenizer.pad_token_id,
            label_ids
        )
        return self.tokenizer.batch_decode(restored, skip_special_tokens=True)

    def _compute_bleu(
        self,
        predictions: List[str],
        references: List[str]
    ) -> float:
        """Return the sacreBLEU corpus score, or 0.0 on empty input or failure."""
        if not (predictions and references):
            return 0.0

        try:
            score = sacrebleu.corpus_bleu(predictions, [references]).score
        except Exception as e:
            # Metric failures are logged and scored as zero, not raised.
            logging.warning(f"BLEU computation failed: {e}")
            return 0.0
        return score


class ModelInitializer:
    """Loads processor, tokenizer and model from ``config.model_load_dir``.

    Grammar tokens missing from the tokenizer are added, the decoder
    embeddings are resized to the tokenizer length, and the model is
    configured to start decoding with ``<TABLE_START>`` and to stop on
    ``<TABLE_END>``.
    """

    def __init__(self, config: TrainingConfig):
        self.config = config

    def initialize(self) -> Tuple[
        ViTImageProcessor,
        AutoTokenizer,
        VisionEncoderDecoderModel,
        int,
        int
    ]:
        """Load and configure all components.

        Returns:
            ``(image_processor, tokenizer, model, start_token_id, end_token_id)``.
        """
        image_processor = ViTImageProcessor.from_pretrained(
            self.config.model_load_dir
        )
        tokenizer = AutoTokenizer.from_pretrained(self.config.model_load_dir)

        # Tokens must be added before resizing so the embedding matrix covers them.
        self._add_grammar_tokens(tokenizer)

        model = VisionEncoderDecoderModel.from_pretrained(
            self.config.model_load_dir
        )
        model.decoder.resize_token_embeddings(len(tokenizer))

        start_token_id = tokenizer.convert_tokens_to_ids("<TABLE_START>")
        end_token_id = tokenizer.convert_tokens_to_ids("<TABLE_END>")

        self._configure_model(model, tokenizer, start_token_id, end_token_id)

        return image_processor, tokenizer, model, start_token_id, end_token_id

    def _add_grammar_tokens(self, tokenizer: AutoTokenizer):
        """Add every entry of ``config.grammar_tokens`` the tokenizer does not know yet."""
        existing_tokens = set(tokenizer.get_vocab())
        new_tokens = [
            token for token in self.config.grammar_tokens
            if token not in existing_tokens
        ]

        if new_tokens:
            num_added = tokenizer.add_tokens(new_tokens)
            logging.info(f"Added {num_added} grammar tokens to tokenizer")

    def _configure_model(
        self,
        model: VisionEncoderDecoderModel,
        tokenizer: AutoTokenizer,
        start_token_id: int,
        end_token_id: int
    ):
        """Write the special-token ids and vocabulary size into the model configuration."""
        model.config.decoder_start_token_id = start_token_id
        model.config.eos_token_id = end_token_id
        model.config.pad_token_id = tokenizer.pad_token_id
        model.config.vocab_size = len(tokenizer)

        # generate() reads its defaults from generation_config. A checkpoint
        # saved with its own generation config would otherwise keep whatever
        # ids it was saved with, so mirror the three ids there as well.
        generation_config = getattr(model, "generation_config", None)
        if generation_config is not None:
            generation_config.decoder_start_token_id = start_token_id
            generation_config.eos_token_id = end_token_id
            generation_config.pad_token_id = tokenizer.pad_token_id

        if hasattr(model.decoder, "config"):
            model.decoder.config.is_decoder = True
            model.decoder.config.add_cross_attention = True


class TrainingOrchestrator:
    """Wires model, data, trainer and callbacks together and runs training."""

    def __init__(self, config: TrainingConfig):
        self.config = config
        self._set_seed()
        self._setup_device()

    def _set_seed(self):
        """Seed NumPy and torch (CPU and, when present, all CUDA devices)."""
        np.random.seed(self.config.seed)
        torch.manual_seed(self.config.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(self.config.seed)

    def _setup_device(self):
        """Record and log whether CUDA or the CPU will be used."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logging.info(f"Using device: {self.device}")

    def execute(self):
        """Run the whole pipeline: load, prepare data, train and save."""
        initializer = ModelInitializer(self.config)
        image_processor, tokenizer, model, start_id, end_id = initializer.initialize()

        train_dataset, val_dataset = self._prepare_datasets()

        self._freeze_encoder(model)

        data_collator = TableDataCollator(image_processor, tokenizer, self.config)
        metrics_computer = MetricsComputer(tokenizer)

        training_args = self._create_training_arguments()

        trainer = StructuralLossTrainer(
            start_token_id=start_id,
            end_token_id=end_id,
            start_penalty=self.config.start_penalty,
            end_penalty=self.config.end_penalty,
            struct_weight=self.config.struct_penalty_weight,
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            data_collator=data_collator,
            tokenizer=tokenizer,
            compute_metrics=metrics_computer,
        )

        self._attach_callbacks(trainer, model, train_dataset, val_dataset, tokenizer, image_processor)

        trainer.train()

        self._save_artifacts(trainer, model, tokenizer, image_processor)

    def _prepare_datasets(self) -> Tuple[TableStructureDataset, TableStructureDataset]:
        """Load the train and val annotations and wrap them in datasets."""
        train_paths, train_texts = AnnotationLoader.load_split(
            self.config.annotation_dir,
            self.config.source_type,
            "train"
        )
        val_paths, val_texts = AnnotationLoader.load_split(
            self.config.annotation_dir,
            self.config.source_type,
            "val"
        )

        train_dataset = TableStructureDataset(
            train_paths,
            train_texts,
            self.config,
            "train"
        )
        val_dataset = TableStructureDataset(
            val_paths,
            val_texts,
            self.config,
            "val"
        )

        logging.info(f"Train samples: {len(train_dataset)}")
        logging.info(f"Validation samples: {len(val_dataset)}")

        return train_dataset, val_dataset

    def _freeze_encoder(self, model: VisionEncoderDecoderModel):
        """Disable gradients on every encoder parameter."""
        for param in model.encoder.parameters():
            param.requires_grad = False
        logging.info("Encoder frozen for initial training")

    def _create_training_arguments(self) -> Seq2SeqTrainingArguments:
        """Build the trainer arguments from the configuration.

        Evaluation and checkpointing both happen once per epoch, and the best
        checkpoint by BLEU is reloaded at the end of training.
        """
        return Seq2SeqTrainingArguments(
            output_dir=self.config.checkpoint_dir,
            num_train_epochs=self.config.num_epochs,
            per_device_train_batch_size=self.config.batch_size,
            per_device_eval_batch_size=self.config.batch_size,
            learning_rate=self.config.learning_rate,
            gradient_accumulation_steps=self.config.grad_accumulation_steps,
            fp16=torch.cuda.is_available(),
            dataloader_num_workers=0,
            dataloader_pin_memory=torch.cuda.is_available(),
            eval_strategy="epoch",
            save_strategy="epoch",
            predict_with_generate=True,
            generation_max_length=self.config.generation_max_length,
            logging_strategy="steps",
            logging_steps=100,
            load_best_model_at_end=True,
            metric_for_best_model="bleu",
            greater_is_better=True,
            remove_unused_columns=False,
            save_total_limit=2,
            warmup_ratio=0.05,
            no_cuda=not torch.cuda.is_available(),
            # The trainer seeds the RNGs itself from its own `seed` argument;
            # pass the configured value so it is not replaced by the default.
            seed=self.config.seed,
        )

    def _attach_callbacks(
        self,
        trainer: StructuralLossTrainer,
        model: VisionEncoderDecoderModel,
        train_dataset: TableStructureDataset,
        val_dataset: TableStructureDataset,
        tokenizer: AutoTokenizer,
        image_processor: ViTImageProcessor
    ):
        """Register the curriculum, encoder-unfreeze and sample-prediction callbacks."""
        curriculum_callback = CurriculumLearningCallback(self.config, train_dataset)
        unfreeze_callback = EncoderUnfreezeCallback(self.config, model)
        prediction_callback = SamplePredictionCallback(
            tokenizer,
            image_processor,
            val_dataset,
            self.config
        )

        trainer.add_callback(curriculum_callback)
        trainer.add_callback(unfreeze_callback)
        trainer.add_callback(prediction_callback)

    def _save_artifacts(
        self,
        trainer: StructuralLossTrainer,
        model: VisionEncoderDecoderModel,
        tokenizer: AutoTokenizer,
        image_processor: ViTImageProcessor
    ):
        """Save model, tokenizer and image processor into ``config.output_dir``.

        ``model`` is accepted for signature compatibility; the model is saved
        through the trainer.
        """
        os.makedirs(self.config.output_dir, exist_ok=True)

        trainer.save_model(self.config.output_dir)
        tokenizer.save_pretrained(self.config.output_dir)
        image_processor.save_pretrained(self.config.output_dir)

        logging.info(f"Training complete. Model saved to {self.config.output_dir}")


def main():
    """Run the training pipeline with a default ``TrainingConfig``."""
    TrainingOrchestrator(TrainingConfig()).execute()


# Only start training when executed as a script, not on import.
if __name__ == "__main__":
    main()

INFO:root:Using device: cuda
INFO:root:Added 6 grammar tokens to tokenizer
INFO:root:Train samples: 8000
INFO:root:Validation samples: 1024
INFO:root:Encoder frozen for initial training
INFO:root:Epoch 0: Curriculum stage 0, max_cells=2, samples=15


Epoch,Training Loss,Validation Loss,Bleu
1,8.2775,0.137434,35.702048
2,8.2089,0.114808,53.793561
3,8.2152,0.118351,52.746965
4,8.2007,0.115533,53.597275
5,0.0,,0.0
6,0.0,,0.0
7,0.0,,0.0
8,0.0,,0.0


INFO:root:Epoch 1.00 Sample Prediction
INFO:root:--------------------------------------------------------------------------------
INFO:root:Prediction:
INFO:root:<tr><td><td><td></tbody></table><td><td></tr><td><td><table><tbody><td><td></td><td><td><td><td><td><tr><td></tr>
INFO:root:--------------------------------------------------------------------------------
INFO:root:Ground Truth:
INFO:root:<table><tbody><tr><td><td><td><td><td><td><td><td><td><td><td><td><td><td><td><td></tr><tr><td><td><td><td><td><td><td><td><td><td><td><td><td><td><td><td></tr></tbody></table>
INFO:root:Epoch 1: Curriculum stage 1, max_cells=4, samples=37
INFO:root:Encoder unfrozen at epoch 2
INFO:root:Epoch 2.00 Sample Prediction
INFO:root:--------------------------------------------------------------------------------
INFO:root:Prediction:
INFO:root:<tr><td><td><td><table><tbody><td><td></tr><td><td></tbody></table><td><tr><td><tr></tr><td></tr></td><td><tr><table><tbody><td><tr><tr></tr></tr></tr><td><tab

In [3]:
import tkinter as tk
from tkinter import filedialog, messagebox
from PIL import Image, ImageTk

import torch
from PIL import Image
from transformers import (
    VisionEncoderDecoderModel,
    ViTImageProcessor,
    GPT2TokenizerFast
)

# Inference settings: directory holding the saved model, tokenizer and image
# processor, plus the generation limits passed to model.generate().
MODEL_DIR = "tsr_vit_tablebank_v1" 
MAX_LENGTH = 128
NUM_BEAMS = 4

print("Loading model...")
# Prefer the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# All three components are loaded from the same directory.
# NOTE(review): the tokenizer class is hard-coded to GPT2TokenizerFast —
# confirm it matches the tokenizer that was saved into MODEL_DIR.
model = VisionEncoderDecoderModel.from_pretrained(MODEL_DIR).to(device)
tokenizer = GPT2TokenizerFast.from_pretrained(MODEL_DIR)
image_processor = ViTImageProcessor.from_pretrained(MODEL_DIR)

# Switch to inference mode.
model.eval()
print("Model loaded.")

def run_inference(image_path):
    """Generate the table-structure token string for one image file.

    Args:
        image_path: Path of the image to read.

    Returns:
        The decoded prediction with special tokens removed.
    """
    # Release the file handle deterministically; convert() returns an
    # independent in-memory image that outlives the closed file.
    with Image.open(image_path) as source:
        image = source.convert("RGB")

    pixel_values = image_processor(images=image, return_tensors="pt").pixel_values.to(device)

    output_ids = model.generate(
        pixel_values,
        max_length=MAX_LENGTH,
        num_beams=NUM_BEAMS,
        early_stopping=True
    )

    prediction = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return prediction


def choose_image():
    """Ask for an image file, run inference on it and show the result.

    The result window displays the selected path, a thumbnail of at most
    300x300 pixels and the predicted text. The path and the prediction are
    also printed. Any failure is reported in an error dialog.
    """
    filepath = filedialog.askopenfilename(
        title="Select Table Image",
        filetypes=[("Image Files", "*.png *.jpg *.jpeg *.bmp *.tiff")]
    )

    # An empty path means the dialog was dismissed without a selection.
    if not filepath:
        return

    try:
        result = run_inference(filepath)

        result_window = tk.Toplevel(root)
        result_window.title("Extracted Table Structure")

        print(f"\n=== Selected Image ===\n{filepath}")

        file_label = tk.Label(result_window, text=f"Selected Image:\n{filepath}")
        file_label.pack(pady=5)

        # PhotoImage copies the pixels when it is constructed, so the file can
        # be closed as soon as the thumbnail has been turned into a Tk image.
        with Image.open(filepath) as img:
            img.thumbnail((300, 300))
            img_tk = ImageTk.PhotoImage(img)

        img_label = tk.Label(result_window, image=img_tk)
        # Keep a reference on the widget so the image is not garbage collected.
        img_label.image = img_tk
        img_label.pack(pady=10)

        text_widget = tk.Text(result_window, wrap="word", height=20, width=80)
        text_widget.pack(padx=10, pady=10)
        text_widget.insert(tk.END, result)

        print("\nExtracted HTML Tokens\n")
        print(result)

    except Exception as e:
        # Top-level UI boundary: show the problem instead of losing it in the
        # Tk callback machinery.
        messagebox.showerror("Error", str(e))


# Main window: a single button that opens the file picker via choose_image,
# plus a short usage hint below it.
root = tk.Tk()
root.title("Table Structure Recognition - TSR Inference")

select_btn = tk.Button(root, text="Select Table Image", command=choose_image, width=30)
select_btn.pack(pady=20)

info_label = tk.Label(root, text="Choose an image to extract its table HTML structure.")
info_label.pack()

# Hand control to the Tk event loop.
root.mainloop()


Loading model...
Model loaded.

=== Selected Image ===
C:/Users/ahmed/Dropbox/PC/Desktop/Ahmed Sajid/Office - NCV/NCV - HTR/TableBank/Recognition/images/1410.7223.table_0.png

Extracted HTML Tokens

                                                                                                                               

=== Selected Image ===
C:/Users/ahmed/Dropbox/PC/Desktop/Ahmed Sajid/Office - NCV/NCV - HTR/TableBank/Recognition/images/1404.2843.table_0.png

Extracted HTML Tokens

                                                                                                                               


In [15]:
import os
import re
import math
import logging
from dataclasses import dataclass
from typing import List, Dict, Any, Optional, Tuple

import torch
from torch.utils.data import Dataset, DataLoader

from PIL import Image, ImageOps
from transformers import (
    VisionEncoderDecoderModel,
    ViTImageProcessor,
    AutoTokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)
import sacrebleu

# --------------------------
# Configuration
# --------------------------
@dataclass
class Config:
    """Settings for the ViT + ByT5 training script, with their defaults."""

    # Dataset root and the names of its image/annotation sub-directories,
    # followed by the output locations.
    ROOT_DIR: str = "./TableBank/Recognition"
    IMAGES_DIR_NAME: str = "Images"
    ANNOTATIONS_DIR_NAME: str = "Annotations"
    MODEL_OUTPUT_DIR: str = "./tsr_vit_byt5_out"
    CHECKPOINT_DIR: str = "./checkpoints"

    # Pretrained checkpoints for the two halves of the model.
    # NOTE(review): ByT5 is a T5-style encoder-decoder checkpoint — confirm it
    # can be loaded as the decoder half of a VisionEncoderDecoderModel.
    ENCODER_MODEL: str = "google/vit-base-patch16-224-in21k"
    DECODER_MODEL: str = "google/byt5-small"   # byt5-small used as tokenizer+decoder

    NUM_EPOCHS: int = 5
    BATCH_SIZE: int = 8
    LEARNING_RATE: float = 5e-5
    GRAD_ACCUMULATION_STEPS: int = 1
    EVAL_BATCH_SIZE: int = 8

    SEED: int = 42
    MAX_TARGET_LENGTH: int = 512
    MAX_DECODING_LENGTH: int = 512
    NUM_BEAMS: int = 4

    # Increased subset sizes to give more training signal (adjust to your hardware/time).
    TRAIN_SIZE: Optional[int] = 30000
    VAL_SIZE: Optional[int] = 2000
    TEST_SIZE: Optional[int] = 1000

    # image handling
    IMG_RESIZE_SIDE: int = 224  # square side in pixels; matches the 224 in the ViT checkpoint name (its patch size is 16)
    ASPECT_PRESERVE: bool = True  # when True the dataset resizes and pads images to a square itself

    # training stability
    FREEZE_ENCODER_EPOCHS: int = 1  # freeze encoder for 1 epoch (helps when dataset relatively small)
    WARMUP_STEPS: int = 2000
    MAX_GRAD_NORM: float = 1.0

    # Evaluated once, when the class body is executed.
    DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"

# Module-level instance read by the dataset and the training code.
config = Config()

# --------------------------
# Logging
# --------------------------
# Configure root logging (INFO and above, timestamped) and create the
# module-level logger used by the rest of this script.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)
logger = logging.getLogger(__name__)

# --------------------------
# Utils: aspect preserving resize + pad
# --------------------------
def resize_and_pad_to_square(image: Image.Image, size: int) -> Image.Image:
    """
    Shrink an image to fit a `size` x `size` box while keeping its aspect
    ratio, then centre it on a white square canvas of side `size`.

    Images that already fit inside the box are not enlarged, only padded.
    The input image is never modified.

    Args:
        image: Source image in any mode; it is converted to RGB if needed.
        size: Side length of the returned square image, in pixels.

    Returns:
        A new RGB image of size (`size`, `size`).
    """
    if image.mode != "RGB":
        # convert() already yields a new image object.
        image = image.convert("RGB")
    else:
        # thumbnail() resizes in place; copy first so the caller's image
        # keeps its original dimensions.
        image = image.copy()

    # Preserve aspect ratio
    image.thumbnail((size, size), Image.Resampling.LANCZOS)

    # Create white background and paste centered
    canvas = Image.new("RGB", (size, size), (255, 255, 255))
    offset = ((size - image.width) // 2, (size - image.height) // 2)
    canvas.paste(image, offset)
    return canvas

# --------------------------
# Dataset
# --------------------------
class TableBankRecognitionDataset(Dataset):
    """TableBank recognition split: image file names paired with target strings.

    The annotation directory is searched for a source file (one image file
    name per line) and a target file (one target string per line) under
    several naming variants. Items are returned as processed pixel values
    plus tokenized labels whose padding positions are set to -100.
    """

    def __init__(
        self,
        root_dir: str,
        split: str,
        image_processor: ViTImageProcessor,
        tokenizer: Any,
        max_target_length: int = 256,
        img_dir_name: str = "Images",
        annotations_dir_name: str = "Annotations",
        img_exts: Optional[List[str]] = None,
        subset_size: Optional[int] = None,
    ):
        """Locate and read the annotation files for ``split``.

        Args:
            root_dir: Dataset root directory.
            split: Split name used to build the annotation file names.
            image_processor: Processor applied to every image.
            tokenizer: Tokenizer applied to every target string.
            max_target_length: Length the labels are truncated/padded to.
            img_dir_name: Image sub-directory below ``root_dir``.
            annotations_dir_name: Annotation sub-directory below ``root_dir``.
            img_exts: Accepted for backward compatibility; not used. The
                default is ``None`` instead of a shared mutable list.
            subset_size: If given, keep at most this many samples.

        Raises:
            FileNotFoundError: If no source or no target file is found.
        """
        self.root_dir = root_dir
        self.split = split
        self.img_processor = image_processor
        self.tokenizer = tokenizer
        self.max_target_length = max_target_length
        self.img_dir = os.path.join(root_dir, img_dir_name)
        self.ann_dir = os.path.join(root_dir, annotations_dir_name)

        # Candidates are tried in order; the first existing file wins.
        possible_src_files = [
            f"src-all_{split}.txt",
            f"rc-all_{split}.txt",
            f"all_{split}.txt",
            f"src_all_{split}.txt",
            f"rc_all_{split}.txt",
        ]
        possible_tgt_files = [
            f"tgt-all_{split}.txt",
            f"tgt_all_{split}.txt",
            f"tgt-all.{split}.txt",
            f"tgt_{split}.txt",
        ]

        src_path = self._find_first_existing(self.ann_dir, possible_src_files)
        if src_path is None:
            raise FileNotFoundError(f"No source file found for split {split} in {self.ann_dir}")

        tgt_path = self._find_first_existing(self.ann_dir, possible_tgt_files)
        if tgt_path is None:
            raise FileNotFoundError(f"No target file found for split {split} in {self.ann_dir}")

        logger.info(f"Using src file: {src_path}")
        logger.info(f"Using tgt file: {tgt_path}")

        # NOTE(review): blank lines are dropped from the source file only; if
        # the target file can contain blank lines the pairing may shift —
        # confirm against the annotation format.
        with open(src_path, "r", encoding="utf-8") as f:
            src_lines = [line.strip() for line in f if line.strip()]
        with open(tgt_path, "r", encoding="utf-8") as f:
            tgt_lines = [line.rstrip("\n") for line in f]

        if len(src_lines) != len(tgt_lines):
            logger.warning(f"src ({len(src_lines)}) and tgt ({len(tgt_lines)}) line counts differ.")

        # Keep only the pairs both files cover, optionally capped by subset_size.
        n = min(len(src_lines), len(tgt_lines))
        if subset_size is not None:
            n = min(n, subset_size)

        self.image_filenames = src_lines[:n]
        self.targets = tgt_lines[:n]
        logger.info(f"Loaded {len(self.image_filenames)} samples for split '{split}'")

    @staticmethod
    def _find_first_existing(directory: str, candidates: List[str]) -> Optional[str]:
        """Return the path of the first candidate file present in ``directory``, else None."""
        for file_name in candidates:
            path = os.path.join(directory, file_name)
            if os.path.exists(path):
                return path
        return None

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.image_filenames)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Load, preprocess and tokenize sample ``idx``.

        Returns:
            Dict with ``pixel_values``, ``labels`` (padding set to -100), the
            raw target ``text`` and the ``image_id`` (annotation file name).

        Raises:
            FileNotFoundError: If the image exists neither at its listed
                relative path nor under its base name in the image directory.
        """
        filename = self.image_filenames[idx]
        img_path = os.path.join(self.img_dir, filename)
        if not os.path.exists(img_path):
            # Fall back to the bare file name directly inside the image directory.
            fallback_path = os.path.join(self.img_dir, os.path.basename(filename))
            if not os.path.exists(fallback_path):
                raise FileNotFoundError(f"Image not found: {img_path} (original: {filename})")
            img_path = fallback_path

        # Release the file handle right away; convert() returns an
        # independent in-memory image.
        with Image.open(img_path) as source:
            image = source.convert("RGB")

        # Aspect-preserving resize + pad to square (ViT expects square)
        if config.ASPECT_PRESERVE:
            image = resize_and_pad_to_square(image, config.IMG_RESIZE_SIDE)

        px = self.img_processor(image, return_tensors="pt")
        pixel_values = px["pixel_values"].squeeze(0)

        target = self.targets[idx].strip()
        tok = self.tokenizer(
            target,
            truncation=True,
            max_length=self.max_target_length,
            padding="max_length",
            return_tensors="pt",
        )
        labels = tok["input_ids"].squeeze(0)
        # set pad token ids to -100 for ignoring in loss
        labels[labels == self.tokenizer.pad_token_id] = -100

        return {"pixel_values": pixel_values, "labels": labels, "text": target, "image_id": filename}


def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
    """Stack per-sample ``pixel_values`` and ``labels`` along a new first axis.

    Any other keys of the samples (such as ``text`` or ``image_id``) are dropped.
    """
    pixel_rows = []
    label_rows = []
    for sample in batch:
        pixel_rows.append(sample["pixel_values"])
        label_rows.append(sample["labels"])
    return {
        "pixel_values": torch.stack(pixel_rows),
        "labels": torch.stack(label_rows),
    }


# --------------------------
# HTML structure metric
# --------------------------
# Matches the start of an opening or closing tag and captures the tag name.
TAG_RE = re.compile(r"</?([a-zA-Z0-9_\-]+)")

def extract_tag_sequence(html: str) -> List[str]:
    """
    Return the lower-cased tag names of `html` in document order.

    Attributes and text content are ignored; closing tags are reported with
    a leading slash, e.g.
    "<table><tr><td>1</td></tr></table>" -> ["table", "tr", "td", "/td", "/tr", "/table"]
    """
    return [
        f"/{match.group(1).lower()}"
        if match.group(0).startswith("</")
        else match.group(1).lower()
        for match in TAG_RE.finditer(html)
    ]

def html_structure_accuracy_batch(preds: List[str], refs: List[str]) -> float:
    """
    Return the fraction of prediction/reference pairs whose tag sequences
    are identical.

    This is a strict measure: a pair only counts when every tag matches in
    order. An empty batch scores 0.0. Both lists must have the same length.
    """
    assert len(preds) == len(refs)
    total = len(preds)
    if total == 0:
        return 0.0
    exact = sum(
        1
        for pred, ref in zip(preds, refs)
        if extract_tag_sequence(pred) == extract_tag_sequence(ref)
    )
    return exact / total

# --------------------------
# Metrics used by Trainer
# --------------------------
def compute_metrics(eval_pred) -> Dict[str, float]:
    """
    Compute corpus BLEU and exact tag-sequence accuracy for one evaluation pass.

    eval_pred: tuple(predictions, label_ids). The predictions are decoded as
    token ids, i.e. the sequences produced when the Trainer runs with
    predict_with_generate=True; a tuple of outputs is reduced to its first
    element.

    Returns a dict with the keys "bleu" and "html_structure_acc".
    """
    preds, labels = eval_pred

    if isinstance(preds, tuple):
        preds = preds[0]

    # Generated sequences gathered across batches can be padded with -100,
    # which the tokenizer cannot decode. Map negative ids to the pad token,
    # exactly as is done for the labels below. Copy first so the caller's
    # array is left untouched.
    preds = preds.copy()
    preds[preds < 0] = tokenizer.pad_token_id
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

    # Labels use -100 as the ignore marker; turn it back into padding.
    labels = labels.copy()
    labels[labels == -100] = tokenizer.pad_token_id
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # strip
    decoded_preds = [p.strip() for p in decoded_preds]
    decoded_labels = [r.strip() for r in decoded_labels]

    # BLEU
    bleu = sacrebleu.corpus_bleu(decoded_preds, [decoded_labels])
    bleu_score = bleu.score

    # HTML structure accuracy
    struct_acc = html_structure_accuracy_batch(decoded_preds, decoded_labels)

    return {"bleu": bleu_score, "html_structure_acc": struct_acc}


# --------------------------
# Main training pipeline
# --------------------------
def main():
    """
    Train a ViT-encoder / text-decoder table-structure model and save it.

    Steps: load the image processor and tokenizer, build the
    VisionEncoderDecoderModel, configure generation, build the train/val
    datasets, train with Seq2SeqTrainer (optionally keeping the encoder frozen
    for the first config.FREEZE_ENCODER_EPOCHS epochs), then write the model,
    tokenizer and image processor to config.MODEL_OUTPUT_DIR.

    Side effect: rebinds the module-level `image_processor` and `tokenizer`
    (compute_metrics decodes with the latter).
    """
    torch.manual_seed(config.SEED)

    global image_processor, tokenizer

    logger.info("Loading image processor and tokenizer / decoder...")
    image_processor = ViTImageProcessor.from_pretrained(config.ENCODER_MODEL)
    # AutoTokenizer resolves the tokenizer class matching the decoder checkpoint.
    tokenizer = AutoTokenizer.from_pretrained(config.DECODER_MODEL, use_fast=True)

    # Padding and generation need pad/eos/bos tokens; add whichever are missing.
    add_special_tokens = {}
    if tokenizer.pad_token is None:
        add_special_tokens["pad_token"] = "<pad>"
    if tokenizer.eos_token is None:
        add_special_tokens["eos_token"] = "</s>"
    if tokenizer.bos_token is None:
        add_special_tokens["bos_token"] = "<s>"

    special_added = bool(add_special_tokens)
    if special_added:
        tokenizer.add_special_tokens(add_special_tokens)
        logger.info(f"Added special tokens to tokenizer: {add_special_tokens}")

    logger.info("Creating VisionEncoderDecoderModel from pretrained encoder+decoder...")
    # NOTE: the decoder is loaded through AutoModelForCausalLM, so
    # config.DECODER_MODEL must be an architecture that class supports
    # (e.g. BART, GPT-2, RoBERTa). T5-family checkpoints such as ByT5 fail here
    # with "Unrecognized configuration class ... T5Config".
    model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
        config.ENCODER_MODEL, config.DECODER_MODEL
    )

    # New special tokens enlarge the vocabulary, so the decoder embeddings must grow too.
    if special_added:
        model.decoder.resize_token_embeddings(len(tokenizer))

    # Configure generation and decoding
    model.config.decoder_start_token_id = (
        tokenizer.bos_token_id
        if tokenizer.bos_token_id is not None
        else tokenizer.eos_token_id
    )
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id
    # Mirror the decoder's (possibly resized) vocab size on the top-level config.
    model.config.vocab_size = model.decoder.config.vocab_size

    model.config.max_length = config.MAX_DECODING_LENGTH
    model.config.no_repeat_ngram_size = 3
    model.config.early_stopping = True
    model.config.num_beams = config.NUM_BEAMS

    device = torch.device(config.DEVICE)
    model.to(device)

    # Prepare datasets
    logger.info("Preparing datasets...")
    train_ds = TableBankRecognitionDataset(
        root_dir=config.ROOT_DIR,
        split="train",
        image_processor=image_processor,
        tokenizer=tokenizer,
        max_target_length=config.MAX_TARGET_LENGTH,
        subset_size=config.TRAIN_SIZE,
    )
    val_ds = TableBankRecognitionDataset(
        root_dir=config.ROOT_DIR,
        split="val",
        image_processor=image_processor,
        tokenizer=tokenizer,
        max_target_length=config.MAX_TARGET_LENGTH,
        subset_size=config.VAL_SIZE,
    )

    # Optionally freeze encoder for initial epochs
    if config.FREEZE_ENCODER_EPOCHS > 0:
        logger.info(f"Freezing encoder parameters for first {config.FREEZE_ENCODER_EPOCHS} epoch(s).")
        for param in model.encoder.parameters():
            param.requires_grad = False

    # Training args
    training_args = Seq2SeqTrainingArguments(
        output_dir=config.MODEL_OUTPUT_DIR,
        num_train_epochs=config.NUM_EPOCHS,
        per_device_train_batch_size=config.BATCH_SIZE,
        per_device_eval_batch_size=config.EVAL_BATCH_SIZE,
        predict_with_generate=True,
        eval_strategy="epoch",
        save_strategy="epoch",
        logging_strategy="steps",
        logging_steps=100,
        save_total_limit=3,
        learning_rate=config.LEARNING_RATE,
        gradient_accumulation_steps=config.GRAD_ACCUMULATION_STEPS,
        fp16=torch.cuda.is_available(),
        remove_unused_columns=False,
        load_best_model_at_end=True,
        metric_for_best_model="html_structure_acc",
        greater_is_better=True,
        dataloader_pin_memory=True,
        warmup_steps=config.WARMUP_STEPS,
        fp16_opt_level="O1" if torch.cuda.is_available() else None,
        max_grad_norm=config.MAX_GRAD_NORM,
        # Generation settings used during evaluation. These are the fields
        # Seq2SeqTrainingArguments actually defines; there is no
        # `predict_with_generate_kwargs` argument.
        generation_max_length=config.MAX_DECODING_LENGTH,
        generation_num_beams=config.NUM_BEAMS,
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        data_collator=collate_fn,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )

    logger.info("Starting training...")

    if config.FREEZE_ENCODER_EPOCHS <= 0:
        trainer.train()
    else:
        # Two-phase schedule on the same Trainer: frozen-encoder epochs first,
        # then the remaining epochs with the encoder trainable.
        initial_epochs = min(config.FREEZE_ENCODER_EPOCHS, config.NUM_EPOCHS)
        remaining_epochs = config.NUM_EPOCHS - initial_epochs

        # Phase 1: temporarily shorten the run to the frozen epochs.
        saved_num_epochs = trainer.args.num_train_epochs
        trainer.args.num_train_epochs = initial_epochs
        trainer.train()

        # Unfreeze encoder
        logger.info("Unfreezing encoder parameters for remaining epochs.")
        for param in model.encoder.parameters():
            param.requires_grad = True

        # Phase 2: train the remaining epochs.
        trainer.args.num_train_epochs = remaining_epochs
        if remaining_epochs > 0:
            # The phase-1 optimizer was built while the encoder was frozen and
            # so does not hold the encoder parameters. Drop it (and its
            # scheduler) so train() rebuilds both over every trainable parameter.
            trainer.optimizer = None
            trainer.lr_scheduler = None
            trainer.train()

        # restore original
        trainer.args.num_train_epochs = saved_num_epochs

    logger.info("Saving model and tokenizer...")
    os.makedirs(config.MODEL_OUTPUT_DIR, exist_ok=True)
    trainer.save_model(config.MODEL_OUTPUT_DIR)
    tokenizer.save_pretrained(config.MODEL_OUTPUT_DIR)
    image_processor.save_pretrained(config.MODEL_OUTPUT_DIR)

    logger.info("Training complete. Model saved to %s", config.MODEL_OUTPUT_DIR)


# Run the training pipeline only when this module is executed directly, not when imported.
if __name__ == "__main__":
    main()


2025-11-30 17:17:42,751 - INFO - __main__ - Loading image processor and tokenizer / decoder...


tokenizer_config.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/698 [00:00<?, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

2025-11-30 17:17:46,187 - INFO - __main__ - Added special tokens to tokenizer: {'bos_token': '<s>'}
2025-11-30 17:17:46,188 - INFO - __main__ - Creating VisionEncoderDecoderModel from pretrained encoder+decoder...


ValueError: Unrecognized configuration class <class 'transformers.models.t5.configuration_t5.T5Config'> for this kind of AutoModel: AutoModelForCausalLM.
Model type should be one of AriaTextConfig, BambaConfig, BartConfig, BertConfig, BertGenerationConfig, BigBirdConfig, BigBirdPegasusConfig, BioGptConfig, BlenderbotConfig, BlenderbotSmallConfig, BloomConfig, CamembertConfig, LlamaConfig, CodeGenConfig, CohereConfig, Cohere2Config, CpmAntConfig, CTRLConfig, Data2VecTextConfig, DbrxConfig, DeepseekV3Config, DiffLlamaConfig, ElectraConfig, Emu3Config, ErnieConfig, FalconConfig, FalconMambaConfig, FuyuConfig, GemmaConfig, Gemma2Config, Gemma3Config, Gemma3TextConfig, GitConfig, GlmConfig, GotOcr2Config, GPT2Config, GPT2Config, GPTBigCodeConfig, GPTNeoConfig, GPTNeoXConfig, GPTNeoXJapaneseConfig, GPTJConfig, GraniteConfig, GraniteMoeConfig, GraniteMoeSharedConfig, HeliumConfig, JambaConfig, JetMoeConfig, LlamaConfig, Llama4Config, Llama4TextConfig, MambaConfig, Mamba2Config, MarianConfig, MBartConfig, MegaConfig, MegatronBertConfig, MistralConfig, MixtralConfig, MllamaConfig, MoshiConfig, MptConfig, MusicgenConfig, MusicgenMelodyConfig, MvpConfig, NemotronConfig, OlmoConfig, Olmo2Config, OlmoeConfig, OpenLlamaConfig, OpenAIGPTConfig, OPTConfig, PegasusConfig, PersimmonConfig, PhiConfig, Phi3Config, Phi4MultimodalConfig, PhimoeConfig, PLBartConfig, ProphetNetConfig, QDQBertConfig, Qwen2Config, Qwen2MoeConfig, Qwen3Config, Qwen3MoeConfig, RecurrentGemmaConfig, ReformerConfig, RemBertConfig, RobertaConfig, RobertaPreLayerNormConfig, RoCBertConfig, RoFormerConfig, RwkvConfig, Speech2Text2Config, StableLmConfig, Starcoder2Config, TransfoXLConfig, TrOCRConfig, WhisperConfig, XGLMConfig, XLMConfig, XLMProphetNetConfig, XLMRobertaConfig, XLMRobertaXLConfig, XLNetConfig, XmodConfig, ZambaConfig, Zamba2Config.

In [19]:
import tkinter as tk
from tkinter import filedialog, messagebox
from PIL import Image, ImageTk
import torch
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, GPT2TokenizerFast

# Directory the model, image processor and tokenizer are all loaded from.
MODEL_DIR = "./tsr_vit_gpt2_out/checkpoint-1250"

print("Loading model...")
# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model = VisionEncoderDecoderModel.from_pretrained(MODEL_DIR)
model = model.to(device)
image_processor = ViTImageProcessor.from_pretrained(MODEL_DIR)
tokenizer = GPT2TokenizerFast.from_pretrained(MODEL_DIR)

# Switch to evaluation mode for inference.
model.eval()
print("Model loaded.")

def predict_structure(image_path):
    """Run the TSR model on a single image and return the decoded prediction.

    Args:
        image_path: Path of the image file to read.

    Returns:
        The decoded model output with surrounding whitespace stripped, or the
        string "Error: Could not open image." when the file cannot be read as
        an image.
    """
    try:
        # The with-block closes the file handle; convert() yields a separate,
        # fully loaded RGB image that stays usable afterwards.
        with Image.open(image_path) as raw:
            image = raw.convert("RGB")
    except Exception:
        # Best effort for any read/decode failure, but unlike a bare `except:`
        # this lets KeyboardInterrupt and SystemExit propagate.
        return "Error: Could not open image."

    pixel_values = image_processor(image, return_tensors="pt").pixel_values.to(device)

    # Beam search, capped at 256 tokens.
    output_ids = model.generate(
        pixel_values,
        max_length=256,
        num_beams=4,
        early_stopping=True
    )

    pred = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return pred.strip()

class TSRGUI:
    """Small Tk window: choose an image, preview it, and show the predicted structure."""

    def __init__(self, master):
        """Build the widgets inside `master` (the Tk root or another container)."""
        self.master = master
        master.title("Table Structure Recognition - ViT-GPT2")

        # Shows a hint at first, then the preview of the selected image.
        self.img_label = tk.Label(master, text="Select an image to begin.")
        self.img_label.pack(pady=10)

        self.select_button = tk.Button(master, text="Choose Image", command=self.select_image)
        self.select_button.pack()

        # Receives the model prediction for the selected image.
        self.output_text = tk.Text(master, height=10, width=80)
        self.output_text.pack(pady=10)

    def select_image(self):
        """Ask for an image file, preview it, and display the model prediction.

        Does nothing if the dialog is cancelled. If the file cannot be loaded
        for preview, an error dialog is shown and no prediction is run.
        """
        file_path = filedialog.askopenfilename(
            title="Select an image",
            filetypes=[("Images", "*.png *.jpg *.jpeg *.bmp *.tiff")]
        )

        if not file_path:
            return

        try:
            # Build the preview inside the with-block so the file handle is
            # closed as soon as the thumbnail has been created.
            with Image.open(file_path) as img:
                img.thumbnail((400, 400))
                img_tk = ImageTk.PhotoImage(img)
        except Exception as exc:
            # Report the failure to the user instead of letting it escape into
            # the Tk callback, where it would go unnoticed.
            messagebox.showerror(
                "Cannot open image",
                f"Could not load {file_path}:\n{exc}",
            )
            return

        self.img_label.configure(image=img_tk)
        # Keep a reference on the widget so the PhotoImage is not garbage-collected.
        self.img_label.image = img_tk

        prediction = predict_structure(file_path)

        # Replace any previous output with the new prediction.
        self.output_text.delete("1.0", tk.END)
        self.output_text.insert(tk.END, prediction)


# Launch the GUI only when this module is executed directly, not when imported.
if __name__ == "__main__":
    root = tk.Tk()
    gui = TSRGUI(root)
    root.mainloop()  # hand control to the Tk event loop


Loading model...
Model loaded.
