# Vision Transformer for Document Understanding - Rapid Prototype Development
// Discussion placeholder //

// Outline //

* Document Understanding Model Implementation
* Data Preparation Pipeline
* Multi-Lingual Text Processing
* Implementation (end-to-end usage)

In [None]:
# DOCUMENT UNDERSTANDING MODEL IMPLEMENTATION

import tensorflow as tf
from transformers import DonutProcessor, VisionEncoderDecoderModel
import numpy as np

class DocumentExtractor(tf.keras.Model):
    """Keras model that feeds Donut encoder features to three dense heads.

    The heads produce outputs for the estimate number (32 units), the
    address (64 units) and the total amount (1 unit).

    NOTE(review): ``VisionEncoderDecoderModel`` is the PyTorch class of
    Hugging Face Transformers while the heads are Keras layers. The tensors
    flowing between the two still have to be converted by the caller (or the
    encoder swapped for a TensorFlow one) -- confirm before training.

    Args:
        num_classes: Stored on the instance as ``num_classes``. No layer is
            currently sized from it.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes

        # Load pre-trained Donut model
        self.processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")
        self.base_model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base")

        # Freeze the encoder. The encoder is a PyTorch module, so freezing is
        # done through ``requires_grad`` rather than Keras' ``layer.trainable``.
        for param in self.base_model.encoder.parameters():
            param.requires_grad = False

        # Custom layers for specific field extraction
        self.estimate_number_head = tf.keras.Sequential([
            tf.keras.layers.Dense(256, activation='relu'),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(128, activation='relu'),
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dense(32)  # Output dimension for estimate number
        ])

        self.address_head = tf.keras.Sequential([
            tf.keras.layers.Dense(512, activation='relu'),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(256, activation='relu'),
            tf.keras.layers.Dense(128, activation='relu'),
            tf.keras.layers.Dense(64)  # Output dimension for address
        ])

        self.total_amount_head = tf.keras.Sequential([
            tf.keras.layers.Dense(128, activation='relu'),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dense(32, activation='relu'),
            tf.keras.layers.Dense(1)  # Single output for total amount
        ])

        # Loss objects are built once instead of on every training step.
        # The classification heads end in a linear Dense layer, i.e. they
        # emit logits, hence ``from_logits=True``.
        self._estimate_loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        self._address_loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        self._amount_loss_fn = tf.keras.losses.MeanSquaredError()

    def preprocess_image(self, image_path):
        """Read a JPEG file and return the Donut processor's ``pixel_values``.

        Args:
            image_path: Path of the JPEG image to load.

        Returns:
            The ``pixel_values`` produced by the Donut processor for the
            image after it has been resized to 224x224.
        """
        image = tf.io.read_file(image_path)
        image = tf.image.decode_jpeg(image, channels=3)
        image = tf.image.resize(image, [224, 224])
        pixel_values = self.processor(image, return_tensors="tf").pixel_values
        return pixel_values

    def call(self, inputs, training=False):
        """Run the encoder and apply the three field heads to its features.

        Args:
            inputs: Pixel values accepted by the Donut encoder.
            training: Forwarded to the heads so their Dropout layers are
                only active during training.

        Returns:
            Dict with keys ``'estimate_number'``, ``'address'`` and
            ``'total_amount'`` holding the output of the matching head.
        """
        encoder_output = self.base_model.encoder(inputs)
        # Hugging Face encoders usually return an output object; the heads
        # need the hidden-state tensor it carries. Fall back to the raw value
        # when the encoder already returns a plain tensor.
        features = getattr(encoder_output, "last_hidden_state", encoder_output)

        return {
            'estimate_number': self.estimate_number_head(features, training=training),
            'address': self.address_head(features, training=training),
            'total_amount': self.total_amount_head(features, training=training)
        }

    def train_step(self, data):
        """Perform one optimisation step over a batch of ``(images, labels)``.

        Args:
            data: Tuple ``(images, labels)`` where ``labels`` is a dict with
                the keys ``'estimate_number'``, ``'address'`` and
                ``'total_amount'``.

        Returns:
            Dict of metrics: ``'loss'`` is the sum of the three head losses,
            ``'estimate_loss'`` and ``'address_loss'`` are the two
            cross-entropy terms and ``'total_loss'`` is the mean squared
            error of the total-amount head.
        """
        images, labels = data

        with tf.GradientTape() as tape:
            predictions = self(images, training=True)

            estimate_loss = self._estimate_loss_fn(
                labels['estimate_number'], predictions['estimate_number'])
            address_loss = self._address_loss_fn(
                labels['address'], predictions['address'])
            # Kept in its own variable so the combined loss below does not
            # overwrite the per-head value reported as 'total_loss'.
            amount_loss = self._amount_loss_fn(
                labels['total_amount'], predictions['total_amount'])

            combined_loss = estimate_loss + address_loss + amount_loss

        # Compute gradients and update weights
        gradients = tape.gradient(combined_loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        return {
            'loss': combined_loss,
            'estimate_loss': estimate_loss,
            'address_loss': address_loss,
            'total_loss': amount_loss
        }

In [None]:
# DATA PREPARATION PIPELINE
import pandas as pd
import tensorflow as tf
from pathlib import Path
import json

class DocumentDataPreprocessor:
    """Turn a folder of document images plus a label CSV into training data.

    The CSV is read with pandas and must provide the columns
    ``image_filename``, ``estimate_number``, ``address`` and
    ``total_amount``. Two Keras tokenizers are fitted on the
    ``estimate_number`` and ``address`` columns at construction time.

    Args:
        image_dir: Directory that contains the image files.
        labels_csv: Path of the CSV file with one row per image.
    """

    def __init__(self, image_dir, labels_csv):
        self.image_dir = Path(image_dir)
        self.labels_df = pd.read_csv(labels_csv)

        # Create vocabularies for text fields
        self.estimate_tokenizer = self._create_estimate_tokenizer()
        self.address_tokenizer = self._create_address_tokenizer()

    def _fit_tokenizer(self, column, **options):
        """Fit a 10000-word Keras tokenizer with an ``<UNK>`` token on a column."""
        texts = self.labels_df[column].astype(str).tolist()
        tokenizer = tf.keras.preprocessing.text.Tokenizer(
            num_words=10000,  # Adjust based on your dataset
            oov_token='<UNK>',
            **options
        )
        tokenizer.fit_on_texts(texts)
        return tokenizer

    def _create_estimate_tokenizer(self):
        """Tokenizer for estimate numbers: no character filtering, case kept."""
        return self._fit_tokenizer('estimate_number', filters='', lower=False)

    def _create_address_tokenizer(self):
        """Tokenizer for addresses: punctuation stripped, text lower-cased."""
        return self._fit_tokenizer(
            'address',
            filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n',
            lower=True
        )

    def _load_image(self, image_path):
        """Read a JPEG, resize it to 224x224 and apply ImageNet preprocessing.

        Uses TensorFlow ops only, so it can run both eagerly and inside
        ``tf.data.Dataset.map``.
        """
        raw = tf.io.read_file(image_path)
        image = tf.image.decode_jpeg(raw, channels=3)
        image = tf.image.resize(image, [224, 224])
        return tf.keras.applications.imagenet_utils.preprocess_input(image)

    def preprocess_single_example(self, image_path, labels):
        """Preprocess one image and encode its labels.

        Tokenisation and the ``float`` conversion are plain Python, so this
        method has to be called eagerly with Python values; it must not be
        used inside ``tf.data.Dataset.map``.

        Args:
            image_path: Path of the JPEG image.
            labels: Mapping with the keys ``'estimate_number'``,
                ``'address'`` and ``'total_amount'``.

        Returns:
            Dict with ``'image'`` (the preprocessed image tensor) and
            ``'labels'`` (dict holding the two token-id lists and the total
            amount as ``float``).
        """
        estimate_tokens = self.estimate_tokenizer.texts_to_sequences(
            [str(labels['estimate_number'])])[0]
        address_tokens = self.address_tokenizer.texts_to_sequences(
            [str(labels['address'])])[0]

        return {
            'image': self._load_image(image_path),
            'labels': {
                'estimate_number': estimate_tokens,
                'address': address_tokens,
                'total_amount': float(labels['total_amount'])
            }
        }

    def create_dataset(self, batch_size=32):
        """Build a batched ``tf.data.Dataset`` of ``(image, labels)`` pairs.

        Rows whose image file does not exist are skipped. Labels are encoded
        in Python up front because the Keras tokenizers cannot run on
        symbolic tensors; token sequences are right-padded with 0 to the
        longest sequence so that they can be batched.

        Args:
            batch_size: Number of examples per batch.

        Returns:
            Dataset yielding ``(images, labels)`` where ``labels`` is a dict
            with the keys ``'estimate_number'``, ``'address'`` and
            ``'total_amount'``.

        Raises:
            ValueError: If none of the images listed in the CSV exists.
        """
        image_paths, estimates, addresses, amounts = [], [], [], []

        for _, row in self.labels_df.iterrows():
            image_path = self.image_dir / f"{row['image_filename']}"
            if not image_path.exists():
                continue
            image_paths.append(str(image_path))
            estimates.append(str(row['estimate_number']))
            addresses.append(str(row['address']))
            amounts.append(float(row['total_amount']))

        if not image_paths:
            raise ValueError(
                f"No image listed in the labels file was found in {self.image_dir}")

        pad_sequences = tf.keras.preprocessing.sequence.pad_sequences
        labels = {
            'estimate_number': pad_sequences(
                self.estimate_tokenizer.texts_to_sequences(estimates), padding='post'),
            'address': pad_sequences(
                self.address_tokenizer.texts_to_sequences(addresses), padding='post'),
            'total_amount': tf.constant(amounts, dtype=tf.float32)
        }

        # A dict of equally long arrays can be sliced per example, unlike a
        # list of per-example dicts.
        dataset = tf.data.Dataset.from_tensor_slices((image_paths, labels))
        dataset = dataset.map(
            lambda path, label: (self._load_image(path), label),
            num_parallel_calls=tf.data.AUTOTUNE)
        dataset = dataset.batch(batch_size)
        dataset = dataset.prefetch(tf.data.AUTOTUNE)

        return dataset

    def save_vocabularies(self, output_dir):
        """Write both tokenizer vocabularies to JSON files.

        Creates ``output_dir`` if needed and writes ``estimate_vocab.json``
        and ``address_vocab.json`` (UTF-8, non-ASCII characters kept as-is).

        Args:
            output_dir: Directory the two JSON files are written to.
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        vocabularies = {
            'estimate_vocab.json': self.estimate_tokenizer.word_index,
            'address_vocab.json': self.address_tokenizer.word_index,
        }
        for filename, word_index in vocabularies.items():
            with open(output_dir / filename, 'w', encoding='utf-8') as f:
                json.dump(word_index, f, ensure_ascii=False)

In [None]:
# MULTI-LINGUAL TEXT PROCESSING
import spacy
import tensorflow as tf
from transformers import MarianMTModel, MarianTokenizer

class MultilingualTextProcessor:
    """Detect Spanish/English text, translate Spanish to English and extract entities.

    Loads the spaCy pipelines ``es_core_news_sm`` and ``en_core_web_sm`` and
    the ``Helsinki-NLP/opus-mt-es-en`` translation model with its tokenizer.
    """

    def __init__(self):
        # Load Spanish and English language models
        self.nlp_es = spacy.load('es_core_news_sm')
        self.nlp_en = spacy.load('en_core_web_sm')

        # Load translation model
        self.translator_es_en = MarianMTModel.from_pretrained('Helsinki-NLP/opus-mt-es-en')
        self.translator_tokenizer = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-es-en')

    @staticmethod
    def _stop_word_ratio(nlp, text):
        """Return the share of alphabetic tokens that are stop words for ``nlp``."""
        # Only the tokenizer is needed for the lexical ``is_stop`` flag.
        words = [token for token in nlp.make_doc(text) if token.is_alpha]
        if not words:
            return 0.0
        return sum(1 for token in words if token.is_stop) / len(words)

    def detect_language(self, text):
        """Guess whether ``text`` is Spanish or English.

        The text is tokenised with both pipelines and the language whose
        stop-word list covers the larger share of alphabetic tokens wins.
        Summing ``token.prob`` is not used: it adds negative log
        probabilities, so the result mostly reflects how many tokens each
        tokenizer produced.

        Args:
            text: Text to classify.

        Returns:
            ``'es'`` or ``'en'``. Ties, empty and whitespace-only text
            yield ``'en'``.
        """
        if not text or not text.strip():
            return 'en'

        es_score = self._stop_word_ratio(self.nlp_es, text)
        en_score = self._stop_word_ratio(self.nlp_en, text)

        return 'es' if es_score > en_score else 'en'

    def _translate_es_en(self, text):
        """Translate ``text`` from Spanish to English without language detection."""
        inputs = self.translator_tokenizer(
            text, return_tensors="pt", padding=True, truncation=True)
        translated = self.translator_es_en.generate(**inputs)
        return self.translator_tokenizer.decode(translated[0], skip_special_tokens=True)

    def translate_to_english(self, text):
        """Return the English translation of ``text`` if it is detected as Spanish.

        Args:
            text: Text that may be Spanish or English.

        Returns:
            The translated text when ``detect_language`` returns ``'es'``,
            otherwise ``text`` unchanged.
        """
        if self.detect_language(text) == 'es':
            return self._translate_es_en(text)
        return text

    def process_description(self, text):
        """Process a description field that may be Spanish or English.

        The language is detected once; Spanish text is translated and the
        English text is then run through the English spaCy pipeline to
        collect entities.

        Args:
            text: Raw description text.

        Returns:
            Dict with the keys ``'original_text'``, ``'translated_text'``
            (``None`` unless the text was detected as Spanish),
            ``'main_language'``, ``'entities'`` (lists for ``'locations'``,
            ``'quantities'`` and ``'dates'``) and ``'normalized_text'`` (the
            English tokens joined by single spaces).
        """
        main_lang = self.detect_language(text)
        is_spanish = main_lang == 'es'

        # Translate directly to avoid a second round of language detection.
        english_text = self._translate_es_en(text) if is_spanish else text

        doc = self.nlp_en(english_text)

        # Extract relevant entities
        entities = {
            'locations': [ent.text for ent in doc.ents if ent.label_ in ['LOC', 'GPE']],
            'quantities': [ent.text for ent in doc.ents if ent.label_ == 'QUANTITY'],
            'dates': [ent.text for ent in doc.ents if ent.label_ == 'DATE']
        }

        return {
            'original_text': text,
            'translated_text': english_text if is_spanish else None,
            'main_language': main_lang,
            'entities': entities,
            'normalized_text': ' '.join([token.text for token in doc])
        }

    def preprocess_multilingual_field(self, text, field_type):
        """Preprocess a field according to its type.

        Args:
            text: Raw field value.
            field_type: ``'description'``, ``'address'``, ``'amount'`` or
                any other value.

        Returns:
            For ``'description'`` the dict from ``process_description``; for
            ``'amount'`` the text with every ``$`` removed and surrounding
            whitespace stripped; for every other type the stripped text.
        """
        if field_type == 'description':
            return self.process_description(text)
        if field_type == 'amount':
            # Handle numeric values consistently
            return text.replace('$', '').strip()
        # Addresses usually don't need translation; unknown types are only trimmed
        return text.strip()

In [None]:
# IMPLEMENTATION

# Build the training dataset from the image folder and its label CSV
preprocessor = DocumentDataPreprocessor(image_dir='path/to/images', labels_csv='labels.csv')
train_dataset = preprocessor.create_dataset(batch_size=32)

# Instantiate the extractor, attach the Adam optimizer and train for 10 epochs
# (YOUR_NUM_CLASSES is a placeholder that is not defined in this notebook)
model = DocumentExtractor(num_classes=YOUR_NUM_CLASSES)
model.compile(optimizer='adam')
model.fit(x=train_dataset, epochs=10)

# Run a description through the multilingual pipeline
# (description_text is a placeholder that is not defined in this notebook)
text_processor = MultilingualTextProcessor()
processed_text = text_processor.process_description(text=description_text)