# In [None]:
# LayoutLMv3 Document Type Classification - 17 Document Types (FIXED VERSION)

# This notebook implements a document type classifier using LayoutLMv3 that can handle 17 different document types.
# Key fixes applied:
# - Fixed OCR integration and error handling
# - Improved tensor handling and device management
# - Fixed collate function and data loading
# - Better memory management and batch processing
# - Added proper validation and error recovery

import os
import cv2
import torch
import numpy as np
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import (
    LayoutLMv3Processor, 
    LayoutLMv3ForSequenceClassification,
    LayoutLMv3Config,
    get_linear_schedule_with_warmup
)
from PIL import Image, ImageDraw, ImageFont
import easyocr
import json
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
import warnings
warnings.filterwarnings('ignore')

# Enable optimizations
# cudnn.benchmark asks cuDNN to auto-tune its kernels for the input shapes it sees.
torch.backends.cudnn.benchmark = True
# Turn off the tokenizers library's internal parallelism.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

# Report the runtime environment once at start-up.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Only the first visible CUDA device (index 0) is reported.
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

# Configuration for 17 document types.
# main() reads every key below except MODEL_NAME/USE_VISUAL-independent defaults;
# DocumentTypeDataset additionally reads DOC_TYPES to build its label mapping.
CONFIG = {
    # Dataset root; collect_document_samples() expects one sub-directory per document type.
    'ROOT_DIR': "/home/hasan/datasets/classify/test",  # Update this path to your dataset
    # Class names; a name's position in this list is its integer label.
    'DOC_TYPES': [
        "budget", "ID", "invoice", "form", "memo", "letter",
        "advertisement", "receipt", "scientific_report", "email", "scientific_publication",
        "handwritten", "news_article", "presentation", "resume", "questionnaire", "specification"
    ],  # 17 document types
    'BATCH_SIZE': 2,  # Reduced for better memory management
    'MAX_LENGTH': 512,  # Token sequence length every sample is padded/truncated to
    'EPOCHS': 20,  # Upper bound on epochs; early stopping may end training sooner
    'LEARNING_RATE': 2e-5,  # AdamW learning rate
    'USE_VISUAL': True,  # When True, pixel_values are produced alongside text and boxes
    'WARMUP_STEPS': 500,  # Warm-up steps for the linear LR schedule
    'PATIENCE': 5,  # Epochs without validation-accuracy improvement before stopping
    'MODEL_NAME': "microsoft/layoutlmv3-base"  # Checkpoint name passed to from_pretrained
}

class DocumentTypeDataset(Dataset):
    """Dataset of ``(image_path, doc_type)`` pairs encoded for LayoutLMv3.

    Each item is loaded from disk, run through EasyOCR to obtain words and
    bounding boxes, and encoded with the supplied LayoutLMv3 processor.

    Because this dataset supplies its own words and boxes, the processor must
    have been created with ``apply_ocr=False`` when ``use_visual`` is True;
    otherwise the processor rejects the ``boxes`` argument and every item
    degrades to the all-zero fallback returned by ``__getitem__``.

    Args:
        samples: Sequence of ``(image_path, doc_type)`` tuples.
        processor: A ``LayoutLMv3Processor`` (its ``tokenizer`` attribute is
            used directly in text-only mode).
        max_length: Token length every sample is padded/truncated to.
        use_visual: When True, ``pixel_values`` are included in each item.
    """

    # One EasyOCR reader shared by every instance: train/val/test datasets
    # would otherwise each load their own copy of the OCR models.
    _shared_ocr_reader = None

    def __init__(self, samples, processor, max_length=512, use_visual=True):
        self.samples = samples
        self.processor = processor
        self.max_length = max_length
        self.use_visual = use_visual

        # Initialize (or reuse) the OCR reader; None means OCR is unavailable
        self.ocr_reader = self._get_shared_ocr_reader()

        # Document type mapping for 17 types
        self.doc_type_mapping = {
            doc_type: idx for idx, doc_type in enumerate(CONFIG['DOC_TYPES'])
        }

        print(f"Document type mapping: {self.doc_type_mapping}")

    @classmethod
    def _get_shared_ocr_reader(cls):
        """Return the shared EasyOCR reader, creating it on first use.

        Returns None (after printing a warning) when the reader cannot be
        created; a later call will try again.
        """
        if DocumentTypeDataset._shared_ocr_reader is None:
            try:
                DocumentTypeDataset._shared_ocr_reader = easyocr.Reader(
                    ['en'], gpu=torch.cuda.is_available()
                )
                print("OCR reader initialized successfully")
            except Exception as e:
                print(f"Warning: OCR initialization failed: {e}")
        return DocumentTypeDataset._shared_ocr_reader

    def __len__(self):
        return len(self.samples)

    @staticmethod
    def _normalize_box(x1, y1, x2, y2, width, height):
        """Clamp a pixel box to the image and rescale it to the 0-1000 grid.

        LayoutLMv3 expects box coordinates relative to the page on a 0-1000
        scale. The returned box always has at least one unit of width and
        height and never leaves the 0-1000 range.
        """
        def scale(value, extent):
            if extent <= 0:
                return 0
            clamped = max(0.0, min(float(value), float(extent)))
            return int(1000 * clamped / extent)

        nx1, ny1 = scale(x1, width), scale(y1, height)
        nx2, ny2 = scale(x2, width), scale(y2, height)

        # Degenerate boxes get one unit of extent, shifted inwards at the edge.
        if nx2 <= nx1:
            nx1 = min(nx1, 999)
            nx2 = nx1 + 1
        if ny2 <= ny1:
            ny1 = min(ny1, 999)
            ny2 = ny1 + 1

        return [nx1, ny1, nx2, ny2]

    def extract_text_and_boxes(self, image):
        """Run OCR on an image and return words with 0-1000 normalized boxes.

        Args:
            image: A PIL image or a numpy array (grayscale, RGB or RGBA).

        Returns:
            ``(words, boxes)`` where ``boxes[i]`` is ``[x1, y1, x2, y2]`` on
            the 0-1000 scale for ``words[i]``. Detections with confidence at
            or below 0.3, or with blank text, are dropped. When OCR is
            unavailable, finds nothing, or fails, a single placeholder
            word/box pair is returned so callers always get non-empty lists.
        """
        try:
            if self.ocr_reader is None:
                # Fallback: create dummy text and boxes
                return ["sample text"], [[50, 50, 150, 70]]

            # Convert anything that is not already an array (e.g. a PIL image)
            image_np = image if isinstance(image, np.ndarray) else np.array(image)

            # Ensure image is 3-channel RGB
            is_three_dim = image_np.ndim == 3
            if is_three_dim and image_np.shape[2] == 4:  # RGBA
                image_np = cv2.cvtColor(image_np, cv2.COLOR_RGBA2RGB)
            elif not (is_three_dim and image_np.shape[2] == 3):  # not RGB -> treat as gray
                image_np = cv2.cvtColor(image_np, cv2.COLOR_GRAY2RGB)

            # Get OCR results: iterable of (quad_points, text, confidence)
            results = self.ocr_reader.readtext(image_np)

            words = []
            boxes = []

            image_height, image_width = image_np.shape[:2]

            for (bbox, text, confidence) in results:
                text = text.strip()
                if confidence <= 0.3 or not text:
                    continue

                # Axis-aligned hull of the detected quadrilateral
                x_coords = [point[0] for point in bbox]
                y_coords = [point[1] for point in bbox]

                words.append(text)
                boxes.append(self._normalize_box(
                    min(x_coords), min(y_coords),
                    max(x_coords), max(y_coords),
                    image_width, image_height,
                ))

            # Ensure we have at least one word/box
            if not words:
                words = [""]
                boxes = [[0, 0, 100, 20]]

            return words, boxes

        except Exception as e:
            print(f"OCR extraction failed: {e}")
            return [""], [[0, 0, 100, 20]]  # Fallback

    def preprocess_image(self, image_path):
        """Load an image as RGB, downscaling it so neither side exceeds 1000 px.

        Returns a blank white 224x224 RGB image (after printing a message)
        when the file is missing or cannot be read.
        """
        try:
            # Check if file exists
            if not os.path.exists(image_path):
                print(f"File not found: {image_path}")
                return Image.new('RGB', (224, 224), color='white')

            # convert() returns a fully loaded copy, so the file handle can be
            # released as soon as the with-block ends.
            with Image.open(image_path) as opened:
                image = opened.convert('RGB')

            # Resize if too large (for memory efficiency)
            max_size = 1000
            if image.size[0] > max_size or image.size[1] > max_size:
                # Calculate new size maintaining aspect ratio
                ratio = min(max_size / image.size[0], max_size / image.size[1])
                new_size = (int(image.size[0] * ratio), int(image.size[1] * ratio))
                image = image.resize(new_size, Image.Resampling.LANCZOS)

            return image

        except Exception as e:
            print(f"Error loading image {image_path}: {e}")
            # Create a blank image as fallback
            return Image.new('RGB', (224, 224), color='white')

    def __getitem__(self, idx):
        """Return the encoded sample at ``idx`` as a dict of tensors.

        The dict holds the processor/tokenizer outputs with the leading batch
        dimension removed, plus ``labels`` (a scalar long tensor). Unknown
        document types map to label 0. If anything fails, an all-zero
        fallback sample with label 0 is returned and the error is printed.
        """
        try:
            image_path, doc_type = self.samples[idx]

            # Load and preprocess image
            image = self.preprocess_image(image_path)

            # Extract text and bounding boxes
            words, boxes = self.extract_text_and_boxes(image)

            # Limit number of words/boxes to prevent memory issues
            max_words = 200
            if len(words) > max_words:
                words = words[:max_words]
                boxes = boxes[:max_words]

            text_kwargs = dict(
                text=words,
                boxes=boxes,
                padding='max_length',
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt",
            )

            # Prepare inputs for LayoutLMv3
            if self.use_visual:
                encoding = self.processor(images=image, **text_kwargs)
            else:
                # Text-only mode: the processor itself requires an image, so
                # go straight to its tokenizer for words + boxes.
                encoding = self.processor.tokenizer(**text_kwargs)

            # Get label
            label = self.doc_type_mapping.get(doc_type, 0)  # Default to 0 if not found

            # Prepare return dictionary
            result = {}

            for key, value in encoding.items():
                if value is None:
                    continue
                if isinstance(value, torch.Tensor) and value.dim() > 1 and value.size(0) == 1:
                    # Remove the batch dimension added by return_tensors="pt"
                    value = value.squeeze(0)
                result[key] = value

            result['labels'] = torch.tensor(label, dtype=torch.long)

            return result

        except Exception as e:
            print(f"Error processing sample {idx}: {e}")
            # Return fallback encoding
            fallback = {
                'input_ids': torch.zeros(self.max_length, dtype=torch.long),
                'attention_mask': torch.ones(self.max_length, dtype=torch.long),
                'bbox': torch.zeros((self.max_length, 4), dtype=torch.long),
                'labels': torch.tensor(0, dtype=torch.long)
            }

            if self.use_visual:
                fallback['pixel_values'] = torch.zeros((3, 224, 224), dtype=torch.float)

            return fallback

def collect_document_samples(root_dir, doc_types):
    """Collect ``(image_path, doc_type)`` samples from a class-per-folder tree.

    For every ``doc_type`` the following directories are scanned (when they
    exist), in this order: ``<type>/Real/Images``, ``<type>/Forged/Images``,
    ``<type>/Images`` and ``<type>`` itself. Only regular files with an image
    extension are kept. File names are visited in sorted order so the result,
    and any seeded split derived from it, is reproducible across machines.

    Args:
        root_dir: Dataset root containing one sub-directory per document type.
        doc_types: Iterable of document type names (sub-directory names).

    Returns:
        List of ``(path, doc_type)`` tuples; empty if ``root_dir`` is missing.
    """
    samples = []

    if not os.path.exists(root_dir):
        print(f"Root directory does not exist: {root_dir}")
        return samples

    image_exts = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.tif')
    counts = {}

    for doc_type in doc_types:
        doc_dir = os.path.join(root_dir, doc_type)
        if not os.path.exists(doc_dir):
            print(f"Document type directory not found: {doc_dir}")
            continue

        # Check different possible structures
        candidate_dirs = (
            os.path.join(doc_dir, "Real", "Images"),
            os.path.join(doc_dir, "Forged", "Images"),
            os.path.join(doc_dir, "Images"),
            doc_dir,
        )

        for img_dir in candidate_dirs:
            # isdir (not exists) so a stray file with this name is skipped
            if not os.path.isdir(img_dir):
                continue
            for fname in sorted(os.listdir(img_dir)):
                if not fname.lower().endswith(image_exts):
                    continue
                path = os.path.join(img_dir, fname)
                if os.path.isfile(path):  # Ensure it's a file
                    samples.append((path, doc_type))
                    counts[doc_type] = counts.get(doc_type, 0) + 1

    print(f"Collected {len(samples)} samples")
    for doc_type in doc_types:
        count = counts.get(doc_type, 0)
        if count > 0:
            print(f"  {doc_type}: {count} samples")

    return samples

def _pad_and_stack(key, values):
    """Stack same-rank tensors, zero-padding trailing edges to a common shape.

    Raises ValueError when the tensors do not all have the same number of
    dimensions, since there is no meaningful way to pad across ranks.
    """
    shapes = [v.shape for v in values]
    if len(set(shapes)) == 1:  # All shapes are the same
        return torch.stack(values)

    print(f"Warning: Inconsistent shapes for {key}: {shapes}")
    if len({len(s) for s in shapes}) != 1:
        raise ValueError(f"cannot pad tensors of different rank for {key}")

    max_shape = [max(dims) for dims in zip(*shapes)]
    padded_values = []
    for v in values:
        # F.pad takes (left, right) pairs starting from the LAST dimension.
        pad_sizes = []
        for dim in range(v.dim() - 1, -1, -1):
            pad_sizes.extend([0, max_shape[dim] - v.shape[dim]])
        padded_values.append(torch.nn.functional.pad(v, pad_sizes))
    return torch.stack(padded_values)


def improved_collate_fn(batch, use_visual=True):
    """Collate a list of sample dicts into one batch dict.

    Every key that appears in any sample is batched. Samples that lack a key
    (or hold None for it) receive a default: zeros shaped like the first
    available tensor for that key, or the first available value for
    non-tensor entries, so every batched entry has ``len(batch)`` rows.
    Tensors of the same rank but different sizes are zero-padded to the
    largest size before stacking. Non-tensor values are returned as a list.

    If stacking still fails (for example tensors of different rank), the
    error is printed and the first sample's tensor is repeated for the whole
    batch as a last-resort fallback.

    Args:
        batch: List of dicts as produced by the dataset's ``__getitem__``.
        use_visual: Accepted for interface compatibility; not used here.

    Returns:
        Dict mapping each key to a stacked tensor or a list; ``{}`` for an
        empty batch.
    """
    if not batch:
        return {}

    # Ordered union of keys (first-seen order) so the output is deterministic.
    all_keys = list(dict.fromkeys(key for item in batch for key in item))

    batch_dict = {}

    for key in all_keys:
        raw = [item.get(key) for item in batch]
        template = next((v for v in raw if v is not None), None)
        if template is None:
            continue  # no sample has a usable value for this key

        if not isinstance(template, torch.Tensor):
            # Use first valid value as default for missing entries
            batch_dict[key] = [template if v is None else v for v in raw]
            continue

        values = [torch.zeros_like(template) if v is None else v for v in raw]
        try:
            batch_dict[key] = _pad_and_stack(key, values)
        except Exception as e:
            print(f"Error stacking {key}: {e}")
            # Use first value and repeat
            first = values[0]
            batch_dict[key] = first.unsqueeze(0).repeat(len(batch), *[1] * first.dim())

    return batch_dict

def create_dataloaders(root_dir, doc_types, processor, batch_size=2, max_length=512, use_visual=True):
    """Build the train, validation and test DataLoaders for the dataset tree.

    Samples are gathered with collect_document_samples(), classes with fewer
    than three samples are dropped, and the remainder is split 70/15/15 with
    stratification (seed 42). If scikit-learn cannot stratify, a seeded
    random split with the same proportions is used instead.

    Raises:
        ValueError: if no samples are found, or no class has enough samples.

    Returns:
        ``(train_loader, val_loader, test_loader)``
    """
    # Gather every labelled image path
    all_samples = collect_document_samples(root_dir, doc_types)
    if not all_samples:
        raise ValueError("No samples found! Check your data directory structure.")

    # Per-class tallies, needed to decide which classes can be split
    sample_counts = {}
    for _, label in all_samples:
        sample_counts[label] = sample_counts.get(label, 0) + 1

    # Keep only classes with enough members for a train/val/test split
    min_samples = 3
    valid_samples = [
        entry for entry in all_samples if sample_counts[entry[1]] >= min_samples
    ]
    if not valid_samples:
        raise ValueError("No classes have enough samples for splitting!")

    print(f"Using {len(valid_samples)} samples from {len(set(s[1] for s in valid_samples))} classes")

    # Stratified 70/30 split, then the 30% is halved into val and test
    try:
        train_samples, temp_samples = train_test_split(
            valid_samples,
            test_size=0.3,
            stratify=[entry[1] for entry in valid_samples],
            random_state=42,
        )
        val_samples, test_samples = train_test_split(
            temp_samples,
            test_size=0.5,
            stratify=[entry[1] for entry in temp_samples],
            random_state=42,
        )
    except ValueError as e:
        print(f"Stratification failed: {e}")
        # Seeded in-place shuffle followed by a plain positional split
        np.random.seed(42)
        np.random.shuffle(valid_samples)
        train_end = int(0.7 * len(valid_samples))
        val_end = train_end + int(0.15 * len(valid_samples))

        train_samples = valid_samples[:train_end]
        val_samples = valid_samples[train_end:val_end]
        test_samples = valid_samples[val_end:]

    print(f"Train: {len(train_samples)}, Val: {len(val_samples)}, Test: {len(test_samples)}")

    # Datasets are built in train, val, test order
    train_dataset, val_dataset, test_dataset = (
        DocumentTypeDataset(subset, processor, max_length, use_visual)
        for subset in (train_samples, val_samples, test_samples)
    )

    def build_loader(dataset, shuffle, pin_memory):
        # Single-process loading (num_workers=0) avoids multiprocessing issues
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=0,
            pin_memory=pin_memory,
            collate_fn=lambda batch: improved_collate_fn(batch, use_visual),
        )

    train_loader = build_loader(train_dataset, True, torch.cuda.is_available())
    val_loader = build_loader(val_dataset, False, torch.cuda.is_available())
    # The test loader keeps DataLoader's default of no pinned memory
    test_loader = build_loader(test_dataset, False, False)

    return train_loader, val_loader, test_loader

class LayoutLMv3DocumentClassifier:
    """LayoutLMv3 sequence classifier with training, evaluation and checkpointing.

    The processor is loaded with ``apply_ocr=False`` because the data
    pipeline supplies its own OCR words and bounding boxes; with the default
    (``apply_ocr=True``) the processor refuses caller-provided boxes.

    Args:
        num_classes: Size of the classification head.
        model_name: Hugging Face checkpoint to load.
        use_visual: Stored and saved with checkpoints; not otherwise used by
            this class.

    Attributes:
        processor: The LayoutLMv3Processor to hand to the dataset.
        model: The LayoutLMv3ForSequenceClassification on ``device``.
        device: CUDA when available, otherwise CPU.
    """

    def __init__(self, num_classes=18, model_name="microsoft/layoutlmv3-base", use_visual=True):
        self.num_classes = num_classes
        self.model_name = model_name
        self.use_visual = use_visual
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        print(f"Initializing model on {self.device}")

        # Initialize processor and model
        try:
            # apply_ocr=False: words/boxes come from the dataset's own OCR
            self.processor = LayoutLMv3Processor.from_pretrained(model_name, apply_ocr=False)

            # Configure model for classification
            config = LayoutLMv3Config.from_pretrained(model_name)
            config.num_labels = num_classes

            self.model = LayoutLMv3ForSequenceClassification.from_pretrained(
                model_name,
                config=config,
                ignore_mismatched_sizes=True  # Handle size mismatches
            ).to(self.device)

            print("Model loaded successfully")
            print(f"Number of parameters: {sum(p.numel() for p in self.model.parameters()):,}")

        except Exception as e:
            print(f"Error initializing model: {e}")
            raise

    def train(self, train_loader, val_loader, epochs=10, learning_rate=2e-5, warmup_steps=500, patience=3):
        """Fine-tune the model with AdamW, linear warm-up/decay and early stopping.

        After each epoch the model is evaluated on ``val_loader``; whenever
        validation accuracy improves, a checkpoint is written to
        ``best_layoutlmv3_document_classifier_17types.pt``. Training stops
        early after ``patience`` epochs without improvement. Batches that
        raise or yield a NaN loss are skipped and excluded from the reported
        average training loss.

        Returns:
            ``(train_losses, val_losses, train_accs, val_accs)``, one entry
            per completed epoch.
        """

        # Setup optimizer and scheduler
        optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay=0.01,
            eps=1e-8
        )

        total_steps = len(train_loader) * epochs
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps
        )

        # Training tracking
        train_losses = []
        val_losses = []
        train_accs = []
        val_accs = []

        best_val_acc = 0
        wait = 0

        print("Starting training...")

        for epoch in range(epochs):
            # Training phase
            self.model.train()
            total_train_loss = 0
            processed_batches = 0  # batches that actually contributed a loss
            train_predictions = []
            train_labels = []

            progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")

            for batch_idx, batch in enumerate(progress_bar):
                try:
                    # Move batch to device
                    batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                            for k, v in batch.items()}

                    optimizer.zero_grad()

                    # Forward pass
                    outputs = self.model(**batch)
                    loss = outputs.loss

                    # Check for NaN loss
                    if torch.isnan(loss):
                        print(f"NaN loss detected at batch {batch_idx}, skipping...")
                        continue

                    # Backward pass
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                    optimizer.step()
                    scheduler.step()

                    total_train_loss += loss.item()
                    processed_batches += 1

                    # Get predictions
                    preds = torch.argmax(outputs.logits, dim=-1)
                    train_predictions.extend(preds.cpu().numpy())
                    train_labels.extend(batch['labels'].cpu().numpy())

                    progress_bar.set_postfix({
                        'loss': f'{loss.item():.4f}',
                        'lr': f'{scheduler.get_last_lr()[0]:.2e}'
                    })

                except Exception as e:
                    print(f"Error in training batch {batch_idx}: {e}")
                    continue

            if len(train_loader) == 0:
                print("No valid batches in training loader!")
                break

            # Average over batches that were really trained on, so skipped
            # batches do not drag the reported loss towards zero.
            avg_train_loss = total_train_loss / max(processed_batches, 1)
            train_acc = accuracy_score(train_labels, train_predictions) if train_predictions else 0

            # Validation phase
            try:
                val_loss, val_acc = self.evaluate(val_loader)
            except Exception as e:
                print(f"Validation error: {e}")
                val_loss, val_acc = float('inf'), 0

            # Store metrics
            train_losses.append(avg_train_loss)
            val_losses.append(val_loss)
            train_accs.append(train_acc)
            val_accs.append(val_acc)

            print(f"Epoch {epoch+1}/{epochs}")
            print(f"Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.4f}")
            print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
            print("-" * 50)

            # Early stopping
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                try:
                    self.save_model("best_layoutlmv3_document_classifier_17types.pt")
                except Exception as e:
                    print(f"Error saving model: {e}")
                wait = 0
            else:
                wait += 1
                if wait >= patience:
                    print("Early stopping triggered")
                    break

        # Plot training curves
        try:
            self.plot_training_curves(train_losses, val_losses, train_accs, val_accs)
        except Exception as e:
            print(f"Error plotting training curves: {e}")

        return train_losses, val_losses, train_accs, val_accs

    def evaluate(self, dataloader):
        """Compute average loss and accuracy over ``dataloader`` without gradients.

        Batches that raise or produce a NaN loss are skipped.

        Returns:
            ``(avg_loss, accuracy)``; ``(inf, 0)`` when no batch was usable.
        """
        self.model.eval()
        total_loss = 0
        predictions = []
        labels = []
        valid_batches = 0

        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Evaluating", leave=False):
                try:
                    batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                            for k, v in batch.items()}

                    outputs = self.model(**batch)
                    loss = outputs.loss

                    if not torch.isnan(loss):
                        total_loss += loss.item()
                        valid_batches += 1

                        preds = torch.argmax(outputs.logits, dim=-1)
                        predictions.extend(preds.cpu().numpy())
                        labels.extend(batch['labels'].cpu().numpy())

                except Exception as e:
                    print(f"Error in evaluation batch: {e}")
                    continue

        if valid_batches == 0:
            return float('inf'), 0

        avg_loss = total_loss / valid_batches
        accuracy = accuracy_score(labels, predictions) if predictions else 0

        return avg_loss, accuracy

    def save_model(self, path):
        """Write the model weights and constructor settings to ``path``.

        Failures are printed rather than raised.
        """
        try:
            torch.save({
                'model_state_dict': self.model.state_dict(),
                'num_classes': self.num_classes,
                'model_name': self.model_name,
                'use_visual': self.use_visual
            }, path)
            print(f"Model saved to {path}")
        except Exception as e:
            print(f"Error saving model: {e}")

    def load_model(self, path):
        """Load weights saved by save_model() into the current model.

        Only ``model_state_dict`` is restored. Failures are printed rather
        than raised, in which case the model keeps its current weights.
        """
        try:
            checkpoint = torch.load(path, map_location=self.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            print(f"Model loaded from {path}")
        except Exception as e:
            print(f"Error loading model: {e}")

    def plot_training_curves(self, train_losses, val_losses, train_accs, val_accs):
        """Show loss and accuracy curves (train vs. validation) side by side.

        Failures are printed rather than raised.
        """
        try:
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

            # Plot losses
            ax1.plot(train_losses, label='Train Loss')
            ax1.plot(val_losses, label='Val Loss')
            ax1.set_title('Training and Validation Loss')
            ax1.set_xlabel('Epoch')
            ax1.set_ylabel('Loss')
            ax1.legend()
            ax1.grid(True)

            # Plot accuracies
            ax2.plot(train_accs, label='Train Accuracy')
            ax2.plot(val_accs, label='Val Accuracy')
            ax2.set_title('Training and Validation Accuracy')
            ax2.set_xlabel('Epoch')
            ax2.set_ylabel('Accuracy')
            ax2.legend()
            ax2.grid(True)

            plt.tight_layout()
            plt.show()
        except Exception as e:
            print(f"Error plotting curves: {e}")

# Main execution
def main():
    """Build the classifier and data loaders, smoke-test one batch, then train.

    Any exception is reported with a full traceback instead of propagating.
    """
    try:
        # Model first: its processor is needed to build the datasets
        classifier = LayoutLMv3DocumentClassifier(
            num_classes=len(CONFIG['DOC_TYPES']),
            model_name=CONFIG['MODEL_NAME'],
            use_visual=CONFIG['USE_VISUAL'],
        )

        # Data loaders for the three splits
        loaders = create_dataloaders(
            CONFIG['ROOT_DIR'],
            CONFIG['DOC_TYPES'],
            classifier.processor,
            CONFIG['BATCH_SIZE'],
            CONFIG['MAX_LENGTH'],
            CONFIG['USE_VISUAL'],
        )
        train_loader, val_loader, test_loader = loaders

        print("Testing data loading...")
        # Pull a single batch and describe its contents
        first_batch = next(iter(train_loader), None)
        if first_batch is not None:
            print(f"Batch keys: {first_batch.keys()}")
            for name, entry in first_batch.items():
                summary = entry.shape if isinstance(entry, torch.Tensor) else type(entry)
                print(f"{name}: {summary}")

        # Run the training loop with the configured hyper-parameters
        train_losses, val_losses, train_accs, val_accs = classifier.train(
            train_loader,
            val_loader,
            epochs=CONFIG['EPOCHS'],
            learning_rate=CONFIG['LEARNING_RATE'],
            warmup_steps=CONFIG['WARMUP_STEPS'],
            patience=CONFIG['PATIENCE'],
        )

        print("Training completed successfully!")

    except Exception as e:
        print(f"Error in main execution: {e}")
        import traceback
        traceback.print_exc()

# Run the training pipeline when this file is executed as a script
# (the guard keeps main() from running on import).
if __name__ == "__main__":
    main()