In [5]:
%%writefile check_setup.py
"""
Pre-flight Check Script (Final Fix - Solves Double Counting & sample_id)
========================================================================
"""
import os
import sys
import glob
import json
import argparse
from pathlib import Path

# ============================================================================
# GLOBAL VARIABLES
# ============================================================================
# Directory expected to hold the dataset images; assigned by main() from --images_dir.
IMAGES_DIR = None
# Directory searched for *.jsonl annotation files; assigned by main() from --jsonl_dir.
JSONL_ROOT_DIR = None

def print_section(title):
    """Print ``title`` framed by two 80-character ``=`` rules, with a blank line before and after."""
    rule = "=" * 80
    print("\n" + rule, title, rule + "\n", sep="\n")

def check_dependencies():
    """Try to import every required package and report the ones that fail.

    Returns:
        True when all packages import successfully, False otherwise.
    """
    print_section("CHECKING DEPENDENCIES")
    # Maps the importable module name to the label shown in the report.
    required = {
        'torch': 'PyTorch',
        'transformers': 'Hugging Face Transformers',
        'PIL': 'Pillow',
        'Levenshtein': 'python-Levenshtein',
        'sentencepiece': 'SentencePiece',
    }
    not_found = []
    for import_name, label in required.items():
        try:
            __import__(import_name)
        except ImportError:
            print(f"[X]  {label} NOT installed")
            not_found.append(label)
        else:
            print(f"[OK] {label} installed")

    if not not_found:
        print("\n[OK] All dependencies installed!")
        return True

    print(f"\n[!] Missing: {', '.join(not_found)}")
    return False

def check_gpu():
    """Report whether a CUDA device is visible to PyTorch.

    Returns:
        True when ``torch.cuda.is_available()`` is true; False when it is not
        or when importing/querying torch raises.
    """
    print_section("CHECKING GPU")
    try:
        import torch
        if not torch.cuda.is_available():
            print("[X]  No GPU detected (Expected for local check)")
            return False
        device_name = torch.cuda.get_device_name(0)
        print(f"[OK] GPU: {device_name}")
        return True
    except Exception as err:
        print(f"[X]  Error: {err}")
        return False

def check_paths():
    """Verify that the images and JSONL directories exist and count the images.

    Reads the module-level ``IMAGES_DIR`` and ``JSONL_ROOT_DIR``.

    Returns:
        True when both paths are existing directories and the images
        directory contains at least one image file; False otherwise.
    """
    print_section("CHECKING PATHS")
    valid = True
    image_suffixes = ('.png', '.jpg', '.jpeg', '.tiff', '.bmp')

    # isdir (not exists) so that a regular file passed by mistake is rejected;
    # the truthiness test keeps an unset (None) path from raising TypeError.
    if IMAGES_DIR and os.path.isdir(IMAGES_DIR):
        # A single directory scan with a case-insensitive suffix test: every
        # file is visited exactly once, so nothing is double counted on
        # case-insensitive filesystems, and mixed-case suffixes such as
        # ".Jpg" are still recognised on case-sensitive ones.
        with os.scandir(IMAGES_DIR) as listing:
            count = sum(
                1
                for item in listing
                if item.is_file()
                and not item.name.startswith('.')
                and item.name.lower().endswith(image_suffixes)
            )

        print(f"[OK] Images dir exists")
        print(f"     Found {count} unique images")

        if count == 0:
            print("  [!] WARNING: No images found!")
            valid = False
    else:
        print(f"[X]  Images dir NOT found: {IMAGES_DIR}")
        valid = False

    if JSONL_ROOT_DIR and os.path.isdir(JSONL_ROOT_DIR):
        print(f"[OK] JSONL dir exists")
    else:
        print(f"[X]  JSONL dir NOT found: {JSONL_ROOT_DIR}")
        valid = False
    return valid

def _image_name_for_entry(entry, images_dir):
    """Return the image file name a JSONL entry refers to, or None if it names none."""
    # ntpath splits on both '/' and '\\', so paths written on Windows still
    # reduce to a bare file name when this script runs on Linux/macOS.
    import ntpath

    if 'sample_id' in entry:
        sid = entry['sample_id']
        # Prefer .jpg, then .png.
        for ext in ('.jpg', '.png'):
            candidate = f"datasheet_{sid}{ext}"
            if os.path.exists(os.path.join(images_dir, candidate)):
                return candidate
        # Neither exists: return the expected .jpg name so the caller counts
        # this entry as a missing image.
        return f"datasheet_{sid}.jpg"

    # Legacy keys, checked in priority order; the first key present wins.
    for key in ('image_path', 'image', 'filename'):
        if key in entry:
            value = entry[key]
            return ntpath.basename(value) if isinstance(value, str) else None
    return None


def check_jsonl_content():
    """Validate every JSONL line and check that its image exists on disk.

    Reads the module-level ``JSONL_ROOT_DIR`` and ``IMAGES_DIR``. JSONL files
    are looked up directly in the root directory first and recursively only
    when none are found there.

    Returns:
        False when no JSONL files are found, no entries were parsed, or no
        entry maps to an existing image; True otherwise (missing images
        alone only produce a warning).
    """
    print_section("CHECKING JSONL CONTENT & MAPPING")

    files = glob.glob(os.path.join(JSONL_ROOT_DIR, "*.jsonl"))
    if not files:
        files = glob.glob(os.path.join(JSONL_ROOT_DIR, "**", "*.jsonl"), recursive=True)

    if not files:
        print(f"[X] No .jsonl files found")
        return False

    print(f"[OK] Found {len(files)} JSONL files")

    total_entries = 0
    total_mapped = 0
    total_missing = 0
    total_json_errors = 0

    for path in files:
        print(f"\nScanning: {os.path.basename(path)}")
        file_entries = 0

        with open(path, 'r', encoding='utf-8') as fp:
            for line_no, line in enumerate(fp, 1):
                if not line.strip():
                    continue

                try:
                    entry = json.loads(line)
                except json.JSONDecodeError:
                    print(f"  [X] Line {line_no}: Invalid JSON")
                    total_json_errors += 1
                    continue

                # Valid JSON that is not an object (e.g. a number or a list)
                # cannot describe a sample; report it instead of crashing.
                if not isinstance(entry, dict):
                    print(f"  [X] Line {line_no}: Expected a JSON object")
                    total_json_errors += 1
                    continue

                file_entries += 1

                # 1. Work out which image this entry refers to.
                image_name = _image_name_for_entry(entry, IMAGES_DIR)
                if not image_name:
                    # Genuine error: there is no way to tell which image is meant.
                    print(f"  [X] Line {line_no}: No 'sample_id' or image path found")
                    total_json_errors += 1
                    continue

                # 2. Verify the image exists.
                if os.path.exists(os.path.join(IMAGES_DIR, image_name)):
                    total_mapped += 1
                else:
                    if total_missing < 5:  # only the first 5 missing images are listed
                        print(f"  [!] Missing image: {image_name}")
                    total_missing += 1

        total_entries += file_entries
        print(f"  Entries: {file_entries}")

    print(f"\n{'='*80}")
    print(f"TOTAL SUMMARY:")
    print(f"  Total JSON Entries: {total_entries}")
    print(f"  Successfully Mapped: {total_mapped}")
    print(f"  Missing Images:      {total_missing}")
    print(f"  JSON/Format Errors:  {total_json_errors}")

    # Pass / fail decision
    if total_entries == 0:
        print("\n[X] No data found.")
        return False

    if total_mapped == 0:
        print("\n[X] Critical: 0 images mapped. Check folder paths.")
        return False

    if total_missing > 0:
        print(f"\n[!] Warning: {total_missing} images are missing.")
        print("    The training script will simply skip these. This is acceptable.")

    print(f"\n[OK] Data validation passed ({total_mapped} valid samples ready)")
    return True

def main():
    """Parse the CLI arguments, run every pre-flight check and print the verdict.

    A check that raises is reported as a crash and counted as failed; the
    remaining checks still run.
    """
    global IMAGES_DIR, JSONL_ROOT_DIR

    cli = argparse.ArgumentParser()
    cli.add_argument("--images_dir", required=True)
    cli.add_argument("--jsonl_dir", required=True)
    parsed = cli.parse_args()

    IMAGES_DIR, JSONL_ROOT_DIR = parsed.images_dir, parsed.jsonl_dir

    stages = (
        ("Dependencies", check_dependencies),
        ("GPU", check_gpu),
        ("Paths", check_paths),
        ("Content & Mapping", check_jsonl_content),
    )

    any_failed = False
    for label, run_check in stages:
        try:
            passed = run_check()
        except Exception as err:
            print(f"[X] Crash in {label}: {err}")
            passed = False
        if not passed:
            any_failed = True

    print_section("FINAL STATUS")
    if any_failed:
        print("[X] FIX ERRORS ABOVE")
    else:
        print("[OK] READY TO TRAIN")
        print("     (Remember to run on a machine with GPU!)")


if __name__ == "__main__":
    main()

Overwriting check_setup.py


הרצת הבדיקה

In [1]:
!python check_setup.py --images_dir "C:\Users\nivsa\Generation of Synthetic Training Data\embedded\final_dataset\images" --jsonl_dir "C:\Users\nivsa\Generation of Synthetic Training Data\embedded\final_dataset"


CHECKING DEPENDENCIES

[OK] PyTorch installed
[OK] Hugging Face Transformers installed
[OK] Pillow installed
[OK] python-Levenshtein installed
[OK] SentencePiece installed

[OK] All dependencies installed!

CHECKING GPU

[X]  No GPU detected (Expected for local check)

CHECKING PATHS

[OK] Images dir exists
     Found 6737 unique images
[OK] JSONL dir exists

CHECKING JSONL CONTENT & MAPPING

[OK] Found 4 JSONL files

Scanning: production_20260208_001537.jsonl
  Entries: 2000

Scanning: production_20260208_035545.jsonl
  Entries: 2000

Scanning: production_20260208_070204.jsonl
  Entries: 2000

Scanning: production_20260208_090359.jsonl
  [!] Missing image: datasheet_d0cf48ff.jpg
  [!] Missing image: datasheet_d6495cf3.jpg
  [!] Missing image: datasheet_36d8ac6f.jpg
  [!] Missing image: datasheet_6af0fa73.jpg
  [!] Missing image: datasheet_ec9e09e7.jpg
  Entries: 755

TOTAL SUMMARY:
  Total JSON Entries: 6755
  Successfully Mapped: 6742
  Missing Images:      13
  JSON/Format Errors: 

In [5]:
%%writefile train.py
"""
Production-Ready Donut Fine-tuning Script for Electronic Datasheets
====================================================================
Author: Senior Computer Vision & NLP Engineer
Purpose: Fine-tune naver-clova-ix/donut-base on synthetic electronic datasheets
         with dense tables, small text, technical units, and charts.

Key Features:
- Multi-JSONL file aggregation with robust path handling
- Automatic special token discovery and injection
- High-resolution image processing (1600x1200) for small text
- Mixed precision training (FP16)
- Levenshtein distance metric for evaluation
- Stratified Splitting (Critical for balanced training)
"""

import os
import glob
import json
import random
import argparse
from typing import Dict, List, Any, Tuple
from pathlib import Path
from collections import Counter

import torch
from torch.utils.data import Dataset
from PIL import Image
from transformers import (
    VisionEncoderDecoderModel,
    DonutProcessor,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)
# from datasets import Dataset as HFDataset # Not strictly needed here
import Levenshtein


# ============================================================================
# CONFIGURATION
# ============================================================================

class Config:
    """Centralized configuration for training.

    The class attributes are defaults. main() overwrites IMAGES_DIR,
    JSONL_ROOT_DIR, OUTPUT_DIR, TRAIN_BATCH_SIZE, NUM_EPOCHS and
    LEARNING_RATE from the command-line arguments before they are used.
    """
    
    # Data paths - left as None here and set from --images_dir / --jsonl_dir
    IMAGES_DIR = None
    JSONL_ROOT_DIR = None
    
    # Model configuration: pretrained checkpoint to start from and the
    # directory the fine-tuned model and processor are saved to
    MODEL_NAME = "naver-clova-ix/donut-base"
    OUTPUT_DIR = "./donut-datasheets-finetuned"
    
    # High-resolution settings for small text in datasheets
    # (1600, 1200) is a sweet spot for 16GB-24GB GPUs
    # main() passes index 1 as the processor's height and index 0 as its width.
    IMAGE_SIZE = (1600, 1200)  # (width, height)
    
    # Generation settings - long context for verbose datasheets
    # Used both as the label tokenization length (padding/truncation) and as
    # model.config.max_length.
    MAX_LENGTH = 1024
    
    # Training hyperparameters
    TRAIN_BATCH_SIZE = 2  # also reused as the per-device eval batch size
    GRADIENT_ACCUMULATION_STEPS = 8  # Effective batch size = 16
    LEARNING_RATE = 2e-5
    NUM_EPOCHS = 10
    WARMUP_STEPS = 500
    
    # Data split: fraction of each stratification group that goes to training
    TRAIN_SPLIT_RATIO = 0.9
    
    # Evaluation / checkpointing cadence, in optimizer steps
    # NOTE(review): main() sets load_best_model_at_end=True, which presumably
    # needs SAVE_STEPS aligned with EVAL_STEPS - confirm before changing one alone.
    EVAL_STEPS = 500
    SAVE_STEPS = 500
    SAVE_TOTAL_LIMIT = 2
    
    # Seed for reproducibility (passed to the stratified train/val split)
    SEED = 42


# ============================================================================
# DATA LOADING AND PREPROCESSING
# ============================================================================

def load_all_jsonl_files(jsonl_root_dir: str) -> List[Dict[str, Any]]:
    """Load and aggregate every JSONL entry found under ``jsonl_root_dir``.

    Files directly inside the directory are used when any exist; otherwise
    the directory is searched recursively. Files are read in sorted path
    order so the resulting entry order (and therefore any seeded shuffle or
    split applied later) does not depend on filesystem enumeration order.

    Blank lines are ignored. Lines that are not valid JSON, or that are valid
    JSON but not an object, are skipped with a warning.

    Args:
        jsonl_root_dir: Directory containing ``*.jsonl`` files.

    Returns:
        All parsed entries (dicts), concatenated across files.

    Raises:
        FileNotFoundError: If no ``*.jsonl`` file is found.
    """
    print(f"\n{'='*80}")
    print(f"Loading JSONL files from: {jsonl_root_dir}")
    print(f"{'='*80}")

    # Find all JSONL files using glob (handles Windows paths correctly)
    jsonl_files = glob.glob(os.path.join(jsonl_root_dir, "*.jsonl"))
    if not jsonl_files:
        # Nothing at the top level: fall back to a recursive search
        jsonl_files = glob.glob(
            os.path.join(jsonl_root_dir, "**", "*.jsonl"), recursive=True
        )

    # glob makes no ordering guarantee; sort for run-to-run determinism
    jsonl_files.sort()

    print(f"Found {len(jsonl_files)} JSONL file(s):")
    for jsonl_file in jsonl_files:
        print(f"  - {os.path.basename(jsonl_file)}")

    if not jsonl_files:
        raise FileNotFoundError(
            f"No JSONL files found in {jsonl_root_dir}. "
            f"Please check the path and ensure *.jsonl files exist."
        )

    all_entries = []

    for jsonl_file in jsonl_files:
        print(f"\nProcessing: {jsonl_file}")
        file_entries = 0

        with open(jsonl_file, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    continue

                try:
                    entry = json.loads(line)
                except json.JSONDecodeError as e:
                    print(f"  WARNING: Skipping invalid JSON at line {line_num}: {e}")
                    continue

                # Downstream code treats every entry as a dict; a bare number,
                # string or list would crash it, so drop such lines here.
                if not isinstance(entry, dict):
                    print(f"  WARNING: Skipping non-object JSON at line {line_num}")
                    continue

                all_entries.append(entry)
                file_entries += 1

        print(f"  Loaded {file_entries} entries")

    print(f"\n{'='*80}")
    print(f"Total entries loaded: {len(all_entries)}")
    print(f"{'='*80}\n")

    return all_entries


def _resolve_image_name(entry: Dict, images_dir: str):
    """Return the image file name ``entry`` refers to, or None if it names none."""
    # ntpath splits on both '/' and '\\', so a path recorded on Windows still
    # reduces to its file name when training runs on Linux.
    import ntpath

    if 'sample_id' in entry:
        sid = entry['sample_id']
        # Try jpg first, then png
        for ext in ('.jpg', '.png'):
            candidate = f"datasheet_{sid}{ext}"
            if os.path.exists(os.path.join(images_dir, candidate)):
                return candidate
        # Fallback: the caller will record this name as a missing image
        return f"datasheet_{sid}.jpg"

    # Legacy keys, in priority order; the first key present wins
    for key in ('image_path', 'image', 'filename'):
        if key in entry:
            value = entry[key]
            return ntpath.basename(value) if isinstance(value, str) else None
    return None


def construct_image_paths(entries: List[Dict], images_dir: str) -> List[Dict]:
    """Attach a verified image path to each entry and drop unusable entries.

    An entry is dropped when it names no image or when the named image does
    not exist in ``images_dir``; both cases are counted and reported.

    Args:
        entries: Raw JSONL entries (dicts).
        images_dir: Directory that holds the image files.

    Returns:
        A list of dicts with the keys ``image_path`` (existing file),
        ``ground_truth`` (the entry's ``ground_truth`` or ``{}``) and
        ``component_type``. The component type is taken from the entry's top
        level, then from inside ``ground_truth``, and is ``'unknown'`` only
        when neither provides one, so that stratification can still see a
        label stored inside the ground truth.
    """
    print("Constructing full image paths...")

    processed_entries = []
    missing_images = []
    unnamed_count = 0

    for entry in entries:
        image_name = _resolve_image_name(entry, images_dir)
        if not image_name:
            unnamed_count += 1
            continue

        full_image_path = os.path.join(images_dir, image_name)

        # Verify image exists
        if not os.path.exists(full_image_path):
            missing_images.append(full_image_path)
            continue

        ground_truth = entry.get('ground_truth', {})

        # Keep the stratification label: top level first, then ground truth
        component_type = entry.get('component_type')
        if not component_type and isinstance(ground_truth, dict):
            component_type = ground_truth.get('component_type')

        processed_entries.append({
            'image_path': full_image_path,
            'ground_truth': ground_truth,
            'component_type': component_type or 'unknown',
        })

    print(f"  Successfully processed: {len(processed_entries)} entries")

    if unnamed_count:
        print(f"  WARNING: {unnamed_count} entries had no image reference and were skipped")

    if missing_images:
        print(f"  WARNING: {len(missing_images)} images not found (first 5):")
        for img_path in missing_images[:5]:
            print(f"    - {img_path}")

    return processed_entries


def extract_unique_keys(entries: List[Dict]) -> List[str]:
    """Collect every dict key that appears anywhere inside the entries' ground truth.

    Nested dicts and lists are walked to any depth. The keys are returned
    sorted; the first 20 are also printed as a preview.
    """
    print("\nExtracting unique keys from ground truth data...")

    seen = set()
    # Explicit work list instead of recursion: each popped node contributes
    # its keys (if a dict) and queues its children for inspection.
    pending = [record.get('ground_truth', {}) for record in entries]
    while pending:
        node = pending.pop()
        if isinstance(node, dict):
            seen.update(node)
            pending.extend(node.values())
        elif isinstance(node, list):
            pending.extend(node)

    ordered = sorted(seen)

    print(f"  Found {len(ordered)} unique keys:")
    for name in ordered[:20]:
        print(f"    - {name}")
    hidden = len(ordered) - 20
    if hidden > 0:
        print(f"    ... and {hidden} more")

    return ordered


def split_train_val_stratified(
    entries: List[Dict], 
    train_ratio: float = 0.9, 
    seed: int = 42,
    stratify_key: str = "component_type"
) -> Tuple[List[Dict], List[Dict]]:
    """
    Split entries into train/validation lists, group by group.

    Entries are bucketed by ``stratify_key`` (read from the entry itself,
    else from its ``ground_truth``). Each bucket is shuffled and cut at
    ``train_ratio``; a bucket with a single entry goes entirely to train.
    Entries with no label are shuffled and cut the same way as one pool.
    The global ``random`` generator is seeded with ``seed``.
    """
    print(f"\nSplitting data with STRATIFICATION on '{stratify_key}'")
    print(f"  Train ratio: {train_ratio*100:.0f}%")
    print(f"  Validation ratio: {(1-train_ratio)*100:.0f}%")

    random.seed(seed)

    # Bucket the entries by their stratification label
    buckets = {}
    unlabeled = []
    for record in entries:
        label = record.get(stratify_key) or record.get('ground_truth', {}).get(stratify_key)
        if not label:
            unlabeled.append(record)
            continue
        buckets.setdefault(label, []).append(record)

    print(f"\nFound {len(buckets)} categories:")
    by_size = sorted(buckets.items(), key=lambda pair: len(pair[1]), reverse=True)
    for label, members in by_size:
        print(f"  - {label}: {len(members)} samples")

    if unlabeled:
        print(f"  - [No {stratify_key}]: {len(unlabeled)} samples")

    train_entries = []
    val_entries = []

    # Cut every bucket independently
    for label, members in buckets.items():
        random.shuffle(members)

        if len(members) == 1:
            # A lone sample stays in train so the category is always learnable
            train_entries.extend(members)
            n_train, n_val = 1, 0
        else:
            cut = max(1, int(len(members) * train_ratio))  # at least 1 in train
            head, tail = members[:cut], members[cut:]
            train_entries.extend(head)
            val_entries.extend(tail)
            n_train, n_val = len(head), len(tail)

        print(f"  {label}: {n_train} train, {n_val} val")

    # Unlabeled entries are split as a single pool
    if unlabeled:
        random.shuffle(unlabeled)
        cut = int(len(unlabeled) * train_ratio)
        train_entries.extend(unlabeled[:cut])
        val_entries.extend(unlabeled[cut:])

    random.shuffle(train_entries)
    random.shuffle(val_entries)

    print(f"\n{'='*80}")
    print(f"FINAL SPLIT:")
    print(f"  Training samples: {len(train_entries)}")
    print(f"  Validation samples: {len(val_entries)}")
    print(f"{'='*80}")

    return train_entries, val_entries


# ============================================================================
# CUSTOM DATASET CLASS
# ============================================================================

class DonutDatasheetDataset(Dataset):
    """Torch dataset pairing datasheet images with their JSON ground truth.

    Each item is produced lazily: the image is opened and run through the
    processor, and the ground truth is serialized to a JSON string wrapped
    as ``<task_start_token>{json}</s>`` and tokenized to ``max_length``.
    """

    def __init__(
        self,
        entries: List[Dict],
        processor: DonutProcessor,
        max_length: int = 1024,
        task_start_token: str = "<s_datasheet>",
    ):
        """Store the entries and tokenization settings.

        Args:
            entries: Dicts providing ``image_path`` and ``ground_truth``.
            processor: Processor used for both image preprocessing and tokenization.
            max_length: Length every label sequence is padded/truncated to.
            task_start_token: Token placed at the start of each target sequence.
        """
        self.entries = entries
        self.processor = processor
        self.max_length = max_length
        self.task_start_token = task_start_token

    def __len__(self) -> int:
        """Return the number of entries."""
        return len(self.entries)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return ``pixel_values`` and ``labels`` tensors for entry ``idx``."""
        entry = self.entries[idx]

        # Load image; on failure substitute a blank white image so that a
        # single unreadable file does not abort the run (best effort).
        image_path = entry['image_path']
        try:
            image = Image.open(image_path).convert("RGB")
        except Exception as e:
            print(f"Error loading image {image_path}: {e}")
            image = Image.new("RGB", (100, 100), color=(255, 255, 255))

        # Drop only the leading batch dimension added by return_tensors="pt".
        # A bare squeeze() would also remove any other size-1 dimension.
        pixel_values = self.processor(image, return_tensors="pt").pixel_values.squeeze(0)

        # Serialize the ground truth without escaping non-ASCII characters
        ground_truth = entry['ground_truth']
        ground_truth_json = json.dumps(ground_truth, ensure_ascii=False)

        # Create target sequence
        target_sequence = f"{self.task_start_token}{ground_truth_json}</s>"

        # Tokenize to a fixed length; squeeze(0) for the same reason as above
        labels = self.processor.tokenizer(
            target_sequence,
            add_special_tokens=False,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).input_ids.squeeze(0)

        # Ignore padding in loss
        labels[labels == self.processor.tokenizer.pad_token_id] = -100

        return {
            "pixel_values": pixel_values,
            "labels": labels,
        }


# ============================================================================
# METRICS
# ============================================================================

def compute_metrics(pred, processor: DonutProcessor) -> Dict[str, float]:
    """Compute the corpus-level normalized Levenshtein distance.

    Args:
        pred: Evaluation output exposing ``predictions`` (generated token
            ids) and ``label_ids``.
        processor: Processor whose tokenizer decodes the token ids.

    Returns:
        ``{"normalized_edit_distance": d}`` where ``d`` is the summed edit
        distance divided by the summed label length in characters (the
        denominator is at least 1).
    """
    pad_id = processor.tokenizer.pad_token_id

    predictions = pred.predictions
    # Some models return a tuple whose first element holds the token ids
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    def restore_padding(sequences):
        # -100 is the ignore index; the tokenizer cannot decode it, so map it
        # back to the pad token. Labels carry it by construction, and
        # predictions can carry it too when batches of different lengths are
        # concatenated during evaluation.
        return [[tok if tok != -100 else pad_id for tok in seq] for seq in sequences]

    decoded_preds = processor.tokenizer.batch_decode(
        restore_padding(predictions), skip_special_tokens=True
    )
    decoded_labels = processor.tokenizer.batch_decode(
        restore_padding(pred.label_ids), skip_special_tokens=True
    )

    total_distance = 0
    total_length = 0

    for pred_str, label_str in zip(decoded_preds, decoded_labels):
        total_distance += Levenshtein.distance(pred_str, label_str)
        total_length += len(label_str)

    normalized_distance = total_distance / max(total_length, 1)

    return {"normalized_edit_distance": normalized_distance}


# ============================================================================
# MAIN
# ============================================================================

def main():
    """Fine-tune Donut on the datasheet dataset described by the CLI arguments.

    Steps: load and filter the JSONL entries, optionally subsample them,
    derive the special tokens, make the stratified train/val split, load the
    model and processor, train, and save both to the output directory.

    Raises:
        ValueError: If fewer than 2 usable samples remain, since both a
            training and a validation sample are required.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--images_dir", type=str, required=True)
    parser.add_argument("--jsonl_dir", type=str, required=True)
    # Defaults are read from Config so each value is defined in one place
    parser.add_argument("--output_dir", type=str, default=Config.OUTPUT_DIR)
    parser.add_argument("--batch_size", type=int, default=Config.TRAIN_BATCH_SIZE)
    parser.add_argument("--epochs", type=int, default=Config.NUM_EPOCHS)
    parser.add_argument("--learning_rate", type=float, default=Config.LEARNING_RATE)
    parser.add_argument("--max_samples", type=int, default=None)

    args = parser.parse_args()

    Config.IMAGES_DIR = args.images_dir
    Config.JSONL_ROOT_DIR = args.jsonl_dir
    Config.OUTPUT_DIR = args.output_dir
    Config.TRAIN_BATCH_SIZE = args.batch_size
    Config.NUM_EPOCHS = args.epochs
    Config.LEARNING_RATE = args.learning_rate

    print("\n" + "="*80)
    print("DONUT FINE-TUNING (Production Version)")
    print("="*80 + "\n")

    # 1. Load Data
    all_entries = load_all_jsonl_files(Config.JSONL_ROOT_DIR)
    all_entries = construct_image_paths(all_entries, Config.IMAGES_DIR)

    if args.max_samples:
        print(f"\n[!] Limiting dataset to {args.max_samples} samples for testing")
        # Seed before shuffling so the chosen subset is the same on every run
        random.seed(Config.SEED)
        random.shuffle(all_entries)
        all_entries = all_entries[:args.max_samples]

    # Fail early with a clear message instead of crashing later on an empty
    # training or validation set.
    if len(all_entries) < 2:
        raise ValueError(
            f"Need at least 2 usable samples (one for training, one for "
            f"validation) but found {len(all_entries)}. Check --images_dir, "
            f"--jsonl_dir and --max_samples."
        )

    # 2. Prepare Tokens
    unique_keys = extract_unique_keys(all_entries)
    special_tokens = [f"<s_{key}>" for key in unique_keys]
    new_tokens = ["<s_datasheet>"] + special_tokens

    # 3. Stratified Split
    train_entries, val_entries = split_train_val_stratified(
        all_entries,
        train_ratio=Config.TRAIN_SPLIT_RATIO,
        seed=Config.SEED
    )

    # Validation check to prevent crash
    if len(val_entries) == 0:
        print("[!] Warning: Validation set empty. Moving 1 sample from train to val.")
        val_entries = [train_entries.pop()]

    # 4. Model & Processor
    print("\nLoading model...")
    processor = DonutProcessor.from_pretrained(Config.MODEL_NAME)
    # IMAGE_SIZE is stored as (width, height)
    processor.image_processor.size = {"height": Config.IMAGE_SIZE[1], "width": Config.IMAGE_SIZE[0]}
    processor.image_processor.do_align_long_axis = False

    processor.tokenizer.add_tokens(new_tokens)

    model = VisionEncoderDecoderModel.from_pretrained(Config.MODEL_NAME)
    # The vocabulary grew, so the decoder embedding matrix must grow with it
    model.decoder.resize_token_embeddings(len(processor.tokenizer))
    model.config.max_length = Config.MAX_LENGTH
    model.config.pad_token_id = processor.tokenizer.pad_token_id
    model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids(["<s_datasheet>"])[0]

    # 5. Datasets
    train_ds = DonutDatasheetDataset(train_entries, processor, Config.MAX_LENGTH)
    val_ds = DonutDatasheetDataset(val_entries, processor, Config.MAX_LENGTH)

    # 6. Training Arguments
    training_args = Seq2SeqTrainingArguments(
        output_dir=Config.OUTPUT_DIR,
        num_train_epochs=Config.NUM_EPOCHS,
        per_device_train_batch_size=Config.TRAIN_BATCH_SIZE,
        per_device_eval_batch_size=Config.TRAIN_BATCH_SIZE,
        gradient_accumulation_steps=Config.GRADIENT_ACCUMULATION_STEPS,
        learning_rate=Config.LEARNING_RATE,
        warmup_steps=Config.WARMUP_STEPS,
        fp16=torch.cuda.is_available(),

        eval_strategy="steps",  # named evaluation_strategy in older transformers releases
        eval_steps=Config.EVAL_STEPS,
        save_strategy="steps",
        save_steps=Config.SAVE_STEPS,
        save_total_limit=Config.SAVE_TOTAL_LIMIT,

        predict_with_generate=True,
        # State the evaluation generation length explicitly rather than
        # relying on it being picked up from model.config
        generation_max_length=Config.MAX_LENGTH,
        logging_steps=100,
        remove_unused_columns=False,
        load_best_model_at_end=True,
        metric_for_best_model="normalized_edit_distance",
        greater_is_better=False
    )

    # 7. Trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        compute_metrics=lambda p: compute_metrics(p, processor)
    )

    print("\nStarting training...")
    trainer.train()

    print(f"\nSaving model to {Config.OUTPUT_DIR}")
    trainer.save_model(Config.OUTPUT_DIR)
    processor.save_pretrained(Config.OUTPUT_DIR)
    print("DONE!")

if __name__ == "__main__":
    main()

Overwriting train.py


הרצת קוד האימון

In [None]:
!python train.py --images_dir "C:\Users\nivsa\Generation of Synthetic Training Data\embedded\final_dataset\images" --jsonl_dir "C:\Users\nivsa\Generation of Synthetic Training Data\embedded\final_dataset" --batch_size 2 --epochs 1 --max_samples 10

בדיקת התוצאות

In [None]:
# ----------------------------------------------------------------------------
# Donut Inference for Jupyter Notebook
# ----------------------------------------------------------------------------

# --- Imports ---
import json
import os
import re
import torch
from pathlib import Path
from PIL import Image
from transformers import VisionEncoderDecoderModel, DonutProcessor
import matplotlib.pyplot as plt

# ============================================================================
# CONFIGURATION (Change these paths!)
# ============================================================================

# Folder holding the fine-tuned model and processor; this is the default
# --output_dir that train.py saves to.
MODEL_PATH = "./donut-datasheets-finetuned"

# Image to run inference on (raw string so the Windows backslashes are kept literally)
IMAGE_PATH = r"C:\Users\nivsa\Generation of Synthetic Training Data\embedded\final_dataset\images\datasheet_00eb2cd6.jpg"

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ============================================================================
# FUNCTIONS
# ============================================================================

def load_model(model_path, device="cuda"):
    """Load the processor and model from ``model_path`` and prepare the model for inference.

    Returns:
        ``(model, processor)`` with the model moved to ``device`` and put in
        eval mode, or ``(None, None)`` when loading raises ``OSError``.
    """
    print(f"Loading model from: {model_path}")
    try:
        loaded_processor = DonutProcessor.from_pretrained(model_path)
        loaded_model = VisionEncoderDecoderModel.from_pretrained(model_path)
    except OSError as err:
        print(f"[X] Error: Could not load model from {model_path}")
        print(f"  Details: {err}")
        return None, None

    loaded_model.to(device)
    loaded_model.eval()

    resolution = loaded_processor.image_processor.size
    print(f"[OK] Model loaded on {device}")
    print(f"[OK] Input Resolution: {resolution}")
    return loaded_model, loaded_processor

def _parse_generated_json(text):
    """Parse decoded model output as JSON, trying two repairs before giving up."""
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    # Fallback 1: extract the outermost {...} span from surrounding text
    match = re.search(r'\{.*\}', text, re.DOTALL)
    if match:
        try:
            return json.loads(match.group(0))
        except json.JSONDecodeError:
            pass

    # Fallback 2: the output opens an object but lacks the final closing brace
    if text.startswith("{") and not text.endswith("}"):
        try:
            return json.loads(text + "}")
        except json.JSONDecodeError:
            pass

    return {"raw_output": text, "error": "Failed to parse JSON"}


def extract_json(image_path, model, processor, device="cuda", max_length=1024):
    """Run the model on one image and return the extracted data.

    The image is also displayed inline with matplotlib.

    Args:
        image_path: Path of the image to process.
        model: Loaded model exposing ``generate``.
        processor: Matching processor (image preprocessing and tokenizer).
        device: Device the input tensors are moved to.
        max_length: Maximum length passed to ``generate``.

    Returns:
        The parsed JSON output; ``{"error": ...}`` when the image cannot be
        opened; or ``{"raw_output": ..., "error": ...}`` when the generated
        text cannot be parsed as JSON.
    """
    print(f"\nProcessing: {os.path.basename(image_path)}")

    try:
        image = Image.open(image_path).convert("RGB")
    except Exception as e:
        return {"error": str(e)}

    # Display image in notebook
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    plt.axis('off')
    plt.show()

    # Prepare input: pixel tensor plus the task prompt as the decoder prefix
    pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)
    task_prompt = "<s_datasheet>"
    decoder_input_ids = processor.tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids.to(device)

    # Generate
    print("Generating output...")
    with torch.no_grad():
        outputs = model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=max_length,
            early_stopping=True,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            use_cache=True,
            return_dict_in_generate=True,
        )

    # Decode and strip the task prompt, which may survive decoding
    seq = processor.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)[0]
    clean_seq = seq.replace(task_prompt, "").strip()

    return _parse_generated_json(clean_seq)

# ============================================================================
# MAIN EXECUTION
# ============================================================================

# 1. Load Model once per kernel session. Checking for None (rather than mere
#    presence in globals) lets a re-run retry after an earlier failed load,
#    which leaves `model` defined but set to None.
if globals().get('model') is None:
    model, processor = load_model(MODEL_PATH, DEVICE)

# 2. Run Inference (skipped when the model could not be loaded)
if model is not None:
    result = extract_json(IMAGE_PATH, model, processor, DEVICE)
    
    print("\n" + "="*60)
    print("EXTRACTION RESULT")
    print("="*60)
    print(json.dumps(result, indent=2, ensure_ascii=False))