# In [None]:
import os
import torch
import numpy as np
from PIL import Image
from torch import nn
from torch.utils.data import Dataset as TorchDataset, DataLoader
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, f1_score
import matplotlib.pyplot as plt
import seaborn as sns
import easyocr
import gc
import pickle
from tqdm import tqdm
import argparse
from datetime import datetime
import json
import tempfile

# Check transformers version and import appropriate LayoutLM model.
# The library import is checked on its own so that a missing LayoutLM class
# (old transformers) is not reported as "transformers not installed".
try:
    import transformers
except ImportError as e:
    print(f"❌ Transformers library not found: {e}")
    # Quoted so the shell does not treat ">" as an output redirection.
    print('📝 Please install transformers: pip install "transformers>=4.21.0"')
    raise

print(f"Transformers version: {transformers.__version__}")

# Prefer LayoutLMv3 (requires transformers >= 4.21.0); fall back to LayoutLMv2.
# MODEL_NAME / PROCESSOR_CLASS / MODEL_CLASS are consumed by the rest of the file.
try:
    from transformers import LayoutLMv3Processor, LayoutLMv3ForSequenceClassification
    MODEL_NAME = "microsoft/layoutlmv3-base"
    PROCESSOR_CLASS = LayoutLMv3Processor
    MODEL_CLASS = LayoutLMv3ForSequenceClassification
    print("✅ Using LayoutLMv3")
except ImportError:
    print("❌ LayoutLMv3 not available, trying LayoutLMv2...")
    try:
        # NOTE(review): LayoutLMv2's visual backbone additionally needs
        # detectron2 when the model is constructed -- confirm before relying
        # on this fallback.
        from transformers import LayoutLMv2Processor, LayoutLMv2ForSequenceClassification
        MODEL_NAME = "microsoft/layoutlmv2-base-uncased"
        PROCESSOR_CLASS = LayoutLMv2Processor
        MODEL_CLASS = LayoutLMv2ForSequenceClassification
        print("✅ Using LayoutLMv2")
    except ImportError:
        print("❌ Neither LayoutLMv3 nor LayoutLMv2 available")
        print('📝 Please upgrade transformers: pip install "transformers>=4.21.0"')
        raise ImportError("LayoutLM models not available. Please upgrade transformers library.")

# ---------------------- CONFIG ----------------------
DATA_DIR = "/home/hasan/datasets/classify/test"  # root folder, one sub-directory per class
MAX_LENGTH = 128          # token length every encoded sample is padded/truncated to
BATCH_SIZE = 2            # DataLoader batch size (kept small to limit memory)
EPOCHS = 3
LEARNING_RATE = 2e-5
USE_GPU = torch.cuda.is_available()
DEVICE = torch.device("cuda") if USE_GPU else torch.device("cpu")
CACHE_DIR = "./cache"     # pickled chunks of processed samples live here
PROCESS_BATCH_SIZE = 50   # number of samples written per cache chunk
SAVE_DIR = "./models"     # trained models + metadata are written here

# Synchronous CUDA kernel launches: a debugging aid that makes device-side
# errors surface at the failing call, at the cost of GPU throughput.
# NOTE(review): only effective if CUDA has not been initialised yet -- confirm.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

os.makedirs(CACHE_DIR, exist_ok=True)
os.makedirs(SAVE_DIR, exist_ok=True)

# One print of newline-joined lines emits exactly what seven prints would.
print("\n".join([
    "🔧 Configuration:",
    f"   Model: {MODEL_NAME}",
    f"   Device: {DEVICE}",
    f"   Batch size: {BATCH_SIZE}",
    f"   Max length: {MAX_LENGTH}",
    f"   Learning rate: {LEARNING_RATE}",
    f"   Epochs: {EPOCHS}",
]))

# ---------------------- INIT ----------------------
# Processor for the selected LayoutLM variant. apply_ocr=False: words and
# boxes are supplied by EasyOCR in process_single() instead.
try:
    processor = PROCESSOR_CLASS.from_pretrained(MODEL_NAME, apply_ocr=False)
except Exception as err:
    print(f"❌ Failed to load processor: {err}")
    raise
else:
    print("✅ Processor loaded successfully")

# English EasyOCR reader; uses the GPU when one is available and keeps its
# model files under ./easyocr_models.
try:
    reader = easyocr.Reader(['en'], gpu=USE_GPU, model_storage_directory='./easyocr_models')
except Exception as err:
    print(f"❌ Failed to initialize EasyOCR: {err}")
    print("📝 Please install EasyOCR: pip install easyocr")
    raise
else:
    print("✅ EasyOCR initialized successfully")

# ---------------------- MEMORY MANAGEMENT ----------------------
def clear_memory():
    """Release cached CUDA allocator blocks (if CUDA exists) and run the GC."""
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.empty_cache()
    gc.collect()

def get_memory_usage():
    """Return tensor memory allocated on the current CUDA device, in GiB.

    Returns ``0`` when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return 0
    allocated_bytes = torch.cuda.memory_allocated()
    return allocated_bytes / 1024**3

# ---------------------- LOAD DATA ----------------------
def load_data(data_dir, extensions=(".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff")):
    """Collect image samples from a ``data_dir/<label>/<image>`` layout.

    Every sub-directory of ``data_dir`` becomes a class. Ids are assigned in
    sorted directory-name order, and files are visited in sorted order.
    ``os.listdir`` order is arbitrary, and the downstream cache chunks and
    splits depend on the position of each sample, so the sample list must be
    identical on every run and filesystem.

    Args:
        data_dir: root directory containing one sub-directory per class.
        extensions: filename suffix (str) or iterable of suffixes, compared
            against the lower-cased filename.

    Returns:
        ``(samples, label_map)``. ``samples`` is a list of dicts with
        ``image_path``, ``label`` (int id) and ``filename``. ``label_map``
        maps class name -> id.

    Raises:
        FileNotFoundError: ``data_dir`` does not exist.
        ValueError: no matching image files were found.
    """
    if not os.path.exists(data_dir):
        raise FileNotFoundError(f"Data directory not found: {data_dir}")

    suffixes = (extensions,) if isinstance(extensions, str) else tuple(extensions)
    samples = []
    label_map = {}

    for label_name in sorted(os.listdir(data_dir)):
        label_path = os.path.join(data_dir, label_name)
        if not os.path.isdir(label_path):
            continue

        # Directory names are unique, so each gets the next free id.
        # A class directory with no images still receives an id.
        label_id = label_map.setdefault(label_name, len(label_map))

        for file in sorted(os.listdir(label_path)):
            if file.lower().endswith(suffixes):
                samples.append({
                    "image_path": os.path.join(label_path, file),
                    "label": label_id,
                    "filename": file
                })

    print(f"📊 Dataset Summary:")
    print(f"   Unique labels: {sorted(label_map.items(), key=lambda x: x[1])}")
    print(f"   Total samples: {len(samples)}")

    if not samples:
        raise ValueError("No image files found in the dataset directory")

    return samples, label_map

# ---------------------- ROBUST CACHE MANAGEMENT ----------------------
def safe_save_cache(data, filepath):
    """Pickle ``data`` to ``filepath`` atomically.

    The payload goes to a temporary file in the destination directory, which
    is then moved over ``filepath`` with ``os.replace``. Readers therefore see
    either the old file or the complete new one, never a partial write.

    Args:
        data: any picklable object.
        filepath: destination path. An existing file is overwritten.

    Returns:
        bool: True on success. False if anything failed; the error is printed
        and the temporary file is removed.
    """
    temp_file = None
    try:
        # A bare filename has dirname "", so fall back to the current directory.
        target_dir = os.path.dirname(filepath) or "."
        with tempfile.NamedTemporaryFile(dir=target_dir, delete=False, suffix='.tmp') as f:
            temp_file = f.name
            pickle.dump(data, f)
            # Push the bytes to disk before the rename. This reduces the
            # chance that an OS crash leaves a complete-looking but
            # truncated chunk.
            f.flush()
            os.fsync(f.fileno())

        # os.replace overwrites on every platform (os.rename fails on Windows
        # when the destination already exists).
        os.replace(temp_file, filepath)
        return True
    except Exception as e:
        print(f"Failed to save cache {filepath}: {e}")
        if temp_file and os.path.exists(temp_file):
            try:
                os.unlink(temp_file)
            except OSError:
                pass
        return False

def safe_load_cache(filepath):
    """Load a pickled cache file, deleting it if it cannot be read.

    SECURITY: uses ``pickle.load``, which can execute arbitrary code. Only
    point this at cache files written by this script, never at external data.

    Args:
        filepath: path to the pickle file.

    Returns:
        The unpickled object, or None if the file is missing or unreadable.
        An unreadable file is removed so it is regenerated.
    """
    if not os.path.exists(filepath):
        return None

    try:
        with open(filepath, 'rb') as f:
            return pickle.load(f)
    except Exception as e:  # EOFError, UnpicklingError, ... all mean "corrupted"
        print(f"Cache file corrupted: {filepath}, error: {e}")
        print("Deleting corrupted cache file...")
        try:
            os.unlink(filepath)
        except OSError:
            # Best effort: the file may already be gone or be read-only.
            pass
        return None

# ---------------------- PROCESS SAMPLE ----------------------
def process_single(sample, max_retries=2):
    """Process a single sample with error handling and retries"""
    for attempt in range(max_retries + 1):
        try:
            image_path = sample["image_path"]
            
            # Load and resize image if too large
            image = Image.open(image_path).convert("RGB")
            
            # Resize very large images to prevent memory issues
            max_size = 2048
            if max(image.size) > max_size:
                ratio = max_size / max(image.size)
                new_size = tuple(int(dim * ratio) for dim in image.size)
                image = image.resize(new_size, Image.Resampling.LANCZOS)
            
            img_np = np.array(image)
            
            # OCR with error handling
            try:
                results = reader.readtext(img_np)
            except Exception as ocr_error:
                print(f"OCR failed for {sample['filename']}: {ocr_error}")
                results = []
            
            words, boxes = [], []
            
            for box, text, conf in results:
                if conf > 0.5 and text.strip():
                    words.append(text.strip())
                    x0 = min(pt[0] for pt in box)
                    y0 = min(pt[1] for pt in box)
                    x1 = max(pt[0] for pt in box)
                    y1 = max(pt[1] for pt in box)
                    boxes.append([int(x0), int(y0), int(x1), int(y1)])

            # Fallback for images with no text
            if not words:
                words = ["[EMPTY]"]
                boxes = [[0, 0, 50, 20]]

            # Limit number of words to prevent memory issues
            if len(words) > 100:
                words = words[:100]
                boxes = boxes[:100]

            # Process with the appropriate processor
            encoded = processor(
                images=image,
                text=words,
                boxes=boxes,
                truncation=True,
                padding="max_length",
                max_length=MAX_LENGTH,
                return_tensors="pt"
            )

            # Process encoded data
            result = {}
            for k, v in encoded.items():
                if isinstance(v, torch.Tensor):
                    result[k] = v.squeeze(0)
                else:
                    result[k] = v

            result["labels"] = torch.tensor(sample["label"], dtype=torch.long)
            
            # Clean up
            del image, img_np, encoded
            clear_memory()
            
            return result

        except Exception as e:
            print(f"Attempt {attempt + 1} failed for {sample['filename']}: {e}")
            clear_memory()
            
            if attempt == max_retries:
                # Return dummy sample as last resort
                print(f"Creating dummy sample for {sample['filename']}")
                dummy_image = Image.new("RGB", (224, 224), color="white")
                dummy = processor(
                    images=dummy_image,
                    text=["[ERROR]"],
                    boxes=[[0, 0, 50, 20]],
                    truncation=True,
                    padding="max_length",
                    max_length=MAX_LENGTH,
                    return_tensors="pt"
                )
                dummy = {k: v.squeeze(0) for k, v in dummy.items()}
                dummy["labels"] = torch.tensor(sample["label"], dtype=torch.long)
                return dummy

# ---------------------- IMPROVED BATCH PROCESSING ----------------------
def _chunk_index(chunk_path):
    """Return N from a ``chunk_N.pkl`` path, or -1 if the name has no number."""
    stem = os.path.splitext(os.path.basename(chunk_path))[0]
    digits = stem.partition("_")[2]
    return int(digits) if digits.isdigit() else -1


def process_samples_in_batches(samples, batch_size=PROCESS_BATCH_SIZE):
    """Encode ``samples`` with ``process_single``, caching results in chunks.

    Existing ``CACHE_DIR/chunk_*.pkl`` files are loaded first, in numeric
    order. If they hold at least ``len(samples)`` entries, the first
    ``len(samples)`` are returned. Otherwise processing resumes at position
    ``len(loaded)``, and every batch of ``batch_size`` is saved as a new chunk.

    NOTE: resumption is purely positional. It assumes the cached chunks hold a
    prefix of ``samples`` in the same order. If a corrupted chunk was skipped,
    or ``samples`` was reordered since the cache was written, the resumed
    slice will not line up with what is cached.

    Returns:
        list of encoded sample dicts. Samples whose processing raised are
        left out, so the list can be shorter than ``samples``.
    """
    import glob

    processed_samples = []
    # Numeric sort keeps chunk order right even beyond chunk_9999.
    existing_chunks = sorted(glob.glob(os.path.join(CACHE_DIR, "chunk_*.pkl")),
                             key=_chunk_index)
    next_chunk = 0

    if existing_chunks:
        print(f"📦 Found {len(existing_chunks)} existing cache chunks")
        print("🔄 Loading from cache chunks...")

        for chunk_file in tqdm(existing_chunks, desc="Loading cache chunks"):
            chunk_data = safe_load_cache(chunk_file)
            # An empty list is a valid (if useless) chunk, not a corrupted one.
            if chunk_data is not None:
                processed_samples.extend(chunk_data)
            else:
                print(f"Skipping corrupted chunk: {os.path.basename(chunk_file)}")

        print(f"✅ Loaded {len(processed_samples)} samples from cache")

        if len(processed_samples) >= len(samples):
            return processed_samples[:len(samples)]

        print(f"⚠️  Cache incomplete ({len(processed_samples)}/{len(samples)}), processing remaining...")
        # Number new chunks after the highest existing one. Deleted corrupted
        # chunks leave gaps, so len(existing_chunks) could collide with (and
        # overwrite) a chunk that is still on disk.
        next_chunk = max(len(existing_chunks),
                         max(_chunk_index(p) for p in existing_chunks) + 1)
    else:
        print(f"🔄 Processing {len(samples)} samples in batches of {batch_size}...")

    remaining_samples = samples[len(processed_samples):]

    for i in tqdm(range(0, len(remaining_samples), batch_size), desc="Processing batches"):
        batch = remaining_samples[i:i + batch_size]
        batch_processed = []

        for j, sample in enumerate(batch):
            try:
                batch_processed.append(process_single(sample))

                current_idx = len(processed_samples) + j + 1
                if current_idx % 100 == 0:
                    mem_usage = get_memory_usage()
                    print(f"Processed {current_idx}/{len(samples)} samples. GPU memory: {mem_usage:.2f}GB")

            except Exception as e:
                print(f"Failed to process sample {len(processed_samples) + j}: {e}")
                continue

        # Save the chunk immediately so an interrupted run can resume.
        chunk_num = next_chunk + (i // batch_size)
        chunk_file = os.path.join(CACHE_DIR, f"chunk_{chunk_num:04d}.pkl")

        if safe_save_cache(batch_processed, chunk_file):
            print(f"💾 Saved chunk {chunk_num} with {len(batch_processed)} samples")
        else:
            print(f"❌ Failed to save chunk {chunk_num}")

        processed_samples.extend(batch_processed)
        clear_memory()

    return processed_samples

# ---------------------- CUSTOM DATASET ----------------------
class MemoryEfficientDataset(TorchDataset):
    """Map-style dataset over already-encoded sample dicts.

    ``__getitem__`` builds a fresh dict for each access. Tensor values are
    cloned, so callers can modify or move them without touching the stored
    copy. Any other value is wrapped with ``torch.tensor``.
    """

    def __init__(self, samples):
        # List of encoded sample dicts; main() passes the output of
        # process_samples_in_batches().
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    @staticmethod
    def _as_tensor(value):
        """Return an independent tensor holding ``value``."""
        if isinstance(value, torch.Tensor):
            return value.clone().detach()
        return torch.tensor(value)

    def __getitem__(self, idx):
        stored = self.samples[idx]
        return {key: self._as_tensor(value) for key, value in stored.items()}

# ---------------------- TRAIN & EVAL FUNCTIONS ----------------------
def train_epoch(model, dataloader, optimizer, epoch):
    """Run one optimisation pass over ``dataloader`` and return the mean loss.

    Each batch dict is moved to ``DEVICE``, the loss from ``model(**batch)``
    is back-propagated, gradients are clipped to norm 1.0, and the optimizer
    steps. A batch is skipped, and left out of the average, if its loss is
    NaN/inf or if it raises. On a CUDA out-of-memory RuntimeError the
    allocator cache is cleared first.
    """
    model.train()
    running_loss = 0
    counted = 0

    bar = tqdm(dataloader, desc=f"Training Epoch {epoch}")

    for step, batch in enumerate(bar):
        try:
            for name in batch:
                batch[name] = batch[name].to(DEVICE, non_blocking=True)

            optimizer.zero_grad()
            loss = model(**batch).loss

            if torch.isnan(loss) or torch.isinf(loss):
                print(f"Skipping batch {step} due to invalid loss: {loss}")
                continue

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            step_loss = loss.item()
            running_loss += step_loss
            counted += 1
            bar.set_postfix({'loss': f'{step_loss:.4f}'})

            # Periodic cache flush every 20 batches.
            if step % 20 == 0:
                clear_memory()

        except RuntimeError as err:
            if "out of memory" in str(err):
                print(f"GPU OOM in batch {step}, clearing cache...")
                clear_memory()
            else:
                print(f"Runtime error in batch {step}: {err}")
        except Exception as err:
            print(f"Error in batch {step}: {err}")

    avg_loss = running_loss / max(counted, 1)
    print(f"Average training loss: {avg_loss:.4f}")
    return avg_loss

def evaluate_model(model, dataloader, label_map, title="Validation"):
    model.eval()
    preds, trues = [], []
    
    progress_bar = tqdm(dataloader, desc=f"Evaluating {title}")
    
    with torch.no_grad():
        for batch_idx, batch in enumerate(progress_bar):
            try:
                labels = batch['labels'].cpu().numpy()
                
                for k in batch:
                    batch[k] = batch[k].to(DEVICE, non_blocking=True)
                
                outputs = model(**batch)
                logits = outputs.logits
                predictions = torch.argmax(logits, dim=-1).cpu().numpy()
                
                preds.extend(predictions)
                trues.extend(labels)
                
                if batch_idx % 10 == 0:
                    clear_memory()
                    
            except Exception as e:
                print(f"Error in evaluation batch {batch_idx}: {e}")
                continue

    if len(preds) == 0:
        print(f"No valid predictions for {title}")
        return

    # Calculate metrics
    accuracy = accuracy_score(trues, preds)
    f1 = f1_score(trues, preds, average='weighted', zero_division=0)
    
    print(f"\n📊 {title} Metrics:")
    print(f"   Accuracy: {accuracy:.4f}")
    print(f"   F1-Score (weighted): {f1:.4f}")
    
    print(f"\n📋 Classification Report ({title}):")
    print(classification_report(trues, preds, target_names=list(label_map.keys()), zero_division=0))

    # Create confusion matrix plot
    cm = confusion_matrix(trues, preds)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt="d", 
                xticklabels=list(label_map.keys()), 
                yticklabels=list(label_map.keys()), 
                cmap="Blues")
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title(f"Confusion Matrix ({title})")
    plt.tight_layout()
    
    # Save confusion matrix
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    plt.savefig(f"confusion_matrix_{title.lower()}_{timestamp}.png", dpi=300, bbox_inches='tight')
    plt.show()
    plt.close()  # Important: close the figure to free memory
    
    return accuracy, f1

def split_data(samples, train_ratio=0.8, val_ratio=0.1, seed=None):
    """Shuffle ``samples`` in place and split it into train/validation/test.

    The in-place shuffle is part of the contract: ``main`` processes
    ``samples`` after this call and slices the processed list by position
    using the split sizes returned here.

    Args:
        samples: mutable list to shuffle and split.
        train_ratio: fraction of samples used for training.
        val_ratio: fraction used for validation. The remainder is the test set.
        seed: optional int for a reproducible shuffle. ``None`` keeps using
            NumPy's global random state.

    Returns:
        ``(train_samples, val_samples, test_samples)`` as consecutive slices
        of the shuffled list.

    Raises:
        ValueError: a ratio is outside [0, 1] or the two ratios sum above 1.
    """
    if not (0.0 <= train_ratio <= 1.0 and 0.0 <= val_ratio <= 1.0
            and train_ratio + val_ratio <= 1.0 + 1e-9):
        raise ValueError(
            f"Invalid split ratios: train={train_ratio}, val={val_ratio}"
        )

    if seed is None:
        np.random.shuffle(samples)
    else:
        np.random.RandomState(seed).shuffle(samples)

    n = len(samples)
    train_end = int(n * train_ratio)
    val_end = int(n * (train_ratio + val_ratio))

    train_samples = samples[:train_end]
    val_samples = samples[train_end:val_end]
    test_samples = samples[val_end:]

    print(f"📊 Data Split:")
    print(f"   Training: {len(train_samples)} samples")
    print(f"   Validation: {len(val_samples)} samples")
    print(f"   Test: {len(test_samples)} samples")

    return train_samples, val_samples, test_samples

def save_model_and_results(model, label_map, results, timestamp):
    """Persist the model, the global ``processor`` and a ``metadata.json``.

    Everything goes to ``SAVE_DIR/layoutlm_model_<timestamp>``. The metadata
    ``config`` section records the module-level constants (MAX_LENGTH,
    BATCH_SIZE, EPOCHS, LEARNING_RATE), not necessarily the values passed to
    ``main``. Any failure is printed and swallowed, so the caller continues.
    """
    try:
        model_path = os.path.join(SAVE_DIR, f"layoutlm_model_{timestamp}")
        model.save_pretrained(model_path)
        processor.save_pretrained(model_path)

        run_config = {
            'max_length': MAX_LENGTH,
            'batch_size': BATCH_SIZE,
            'epochs': EPOCHS,
            'learning_rate': LEARNING_RATE
        }
        metadata = {
            'label_map': label_map,
            'results': results,
            'model_name': MODEL_NAME,
            'config': run_config
        }

        # default=str keeps the dump from failing on non-JSON values.
        with open(os.path.join(model_path, 'metadata.json'), 'w') as handle:
            json.dump(metadata, handle, indent=2, default=str)

        print(f"💾 Model and metadata saved to: {model_path}")
    except Exception as err:
        print(f"❌ Failed to save model: {err}")

# ---------------------- MAIN FUNCTION ----------------------
def _snapshot_state_dict(model):
    """Return an independent CPU copy of ``model.state_dict()``.

    ``state_dict()`` returns references to the live parameter and buffer
    tensors, and ``dict.copy()`` is shallow. A checkpoint kept that way is
    overwritten by every later in-place optimizer step. Cloning to CPU also
    keeps the snapshot off the GPU.
    """
    return {name: tensor.detach().cpu().clone()
            for name, tensor in model.state_dict().items()}


def _plot_training_progress(training_losses, validation_accuracies, validation_f1_scores, timestamp):
    """Plot per-epoch training loss and any validation metrics.

    The figure is saved as ``training_progress_<timestamp>.png``.
    """
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.plot(range(1, len(training_losses) + 1), training_losses, 'b-', label='Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss')
    plt.legend()
    plt.grid(True)

    if validation_accuracies:
        plt.subplot(1, 2, 2)
        plt.plot(range(1, len(validation_accuracies) + 1), validation_accuracies, 'g-', label='Validation Accuracy')
        plt.plot(range(1, len(validation_f1_scores) + 1), validation_f1_scores, 'r-', label='Validation F1')
        plt.xlabel('Epoch')
        plt.ylabel('Score')
        plt.title('Validation Metrics')
        plt.legend()
        plt.grid(True)

    plt.tight_layout()
    plt.savefig(f"training_progress_{timestamp}.png", dpi=300, bbox_inches='tight')
    plt.show()
    plt.close()


def main(data_dir=None, epochs=None, batch_size=None, learning_rate=None, no_cache=False):
    """Run the full pipeline: load, OCR/encode (cached), train, evaluate, save.

    Can be called directly (e.g. from a notebook) or via ``main_with_args``.

    Args:
        data_dir, epochs, batch_size, learning_rate: falsy values fall back to
            DATA_DIR / EPOCHS / BATCH_SIZE / LEARNING_RATE.
        no_cache: delete existing cache chunks before processing.

    Returns:
        dict: the results that are also written into the saved metadata,
        including the effective run ``config``.
    """
    import glob

    data_dir = data_dir or DATA_DIR
    epochs = epochs or EPOCHS
    batch_size = batch_size or BATCH_SIZE
    learning_rate = learning_rate or LEARNING_RATE

    print("🚀 Starting LayoutLM Document Classification Training")
    print(f"   Data directory: {data_dir}")
    print(f"   Epochs: {epochs}")
    print(f"   Batch size: {batch_size}")
    print(f"   Learning rate: {learning_rate}")

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    print("\n📂 Loading dataset...")
    samples, label_map = load_data(data_dir)

    # split_data shuffles `samples` in place. The processed list below follows
    # that shuffled order and is sliced by the same split sizes.
    train_samples, val_samples, test_samples = split_data(samples)

    if no_cache:
        print("🗑️  Clearing cache as requested...")
        for cache_file in glob.glob(os.path.join(CACHE_DIR, "chunk_*.pkl")):
            try:
                os.unlink(cache_file)
            except OSError as err:
                print(f"Could not delete cache file {cache_file}: {err}")

    print("🔄 Processing samples with robust caching...")
    all_processed = process_samples_in_batches(samples)

    if len(all_processed) != len(samples):
        print(f"⚠️  Warning: Only {len(all_processed)}/{len(samples)} samples processed successfully")

    n_train = len(train_samples)
    n_val = len(val_samples)
    train_processed = all_processed[:n_train]
    val_processed = all_processed[n_train:n_train + n_val]
    test_processed = all_processed[n_train + n_val:]

    print("\n🔧 Creating datasets and dataloaders...")
    train_loader = DataLoader(MemoryEfficientDataset(train_processed), batch_size=batch_size,
                              shuffle=True, num_workers=0, pin_memory=USE_GPU)
    val_loader = DataLoader(MemoryEfficientDataset(val_processed), batch_size=batch_size,
                            shuffle=False, num_workers=0, pin_memory=USE_GPU)
    test_loader = DataLoader(MemoryEfficientDataset(test_processed), batch_size=batch_size,
                             shuffle=False, num_workers=0, pin_memory=USE_GPU)

    print("\n🤖 Initializing model...")
    model = MODEL_CLASS.from_pretrained(MODEL_NAME, num_labels=len(label_map))
    model.to(DEVICE)

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    print("\n🎯 Starting training...")
    training_losses = []
    validation_accuracies = []
    validation_f1_scores = []

    best_f1 = 0.0
    best_model_state = None

    for epoch in range(1, epochs + 1):
        print(f"\n{'='*50}")
        print(f"Epoch {epoch}/{epochs}")
        print(f"{'='*50}")

        train_loss = train_epoch(model, train_loader, optimizer, epoch)
        training_losses.append(train_loss)

        if val_processed:
            metrics = evaluate_model(model, val_loader, label_map, f"Validation Epoch {epoch}")
            # Guard against a None result (no predictions) so unpacking cannot fail.
            val_acc, val_f1 = metrics if metrics is not None else (0.0, 0.0)
            validation_accuracies.append(val_acc)
            validation_f1_scores.append(val_f1)

            if val_f1 > best_f1:
                best_f1 = val_f1
                # Deep snapshot: a shallow state_dict copy would track later updates.
                best_model_state = _snapshot_state_dict(model)
                print(f"🏆 New best model! F1: {best_f1:.4f}")

        clear_memory()

    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print(f"\n🏆 Loaded best model (F1: {best_f1:.4f}) for final evaluation")

    if test_processed:
        print(f"\n{'='*50}")
        print("FINAL TEST EVALUATION")
        print(f"{'='*50}")
        metrics = evaluate_model(model, test_loader, label_map, "Test Set")
        test_acc, test_f1 = metrics if metrics is not None else (0.0, 0.0)
    else:
        test_acc, test_f1 = 0.0, 0.0

    if len(training_losses) > 1:
        _plot_training_progress(training_losses, validation_accuracies,
                                validation_f1_scores, timestamp)

    results = {
        'training_losses': training_losses,
        'validation_accuracies': validation_accuracies,
        'validation_f1_scores': validation_f1_scores,
        'test_accuracy': test_acc,
        'test_f1': test_f1,
        'best_validation_f1': best_f1,
        'label_map': label_map,
        'timestamp': timestamp,
        # Effective hyper-parameters of this run (may differ from module defaults).
        'config': {
            'data_dir': data_dir,
            'epochs': epochs,
            'batch_size': batch_size,
            'learning_rate': learning_rate,
        },
    }

    save_model_and_results(model, label_map, results, timestamp)

    print("\n🎉 Training completed!")
    print(f"   Best validation F1: {best_f1:.4f}")
    print(f"   Test accuracy: {test_acc:.4f}")
    print(f"   Test F1: {test_f1:.4f}")
    print(f"   Model and results saved with timestamp: {timestamp}")

    return results

def main_with_args():
    """Parse command-line options and forward them positionally to ``main``.

    Defaults come from the module-level DATA_DIR, EPOCHS, BATCH_SIZE and
    LEARNING_RATE constants.
    """
    parser = argparse.ArgumentParser(description='Train LayoutLM for document classification')

    # (flag, type, default, help) -- registered in this order so --help matches.
    typed_options = (
        ('--data_dir', str, DATA_DIR, 'Path to dataset directory'),
        ('--epochs', int, EPOCHS, 'Number of training epochs'),
        ('--batch_size', int, BATCH_SIZE, 'Batch size'),
        ('--learning_rate', float, LEARNING_RATE, 'Learning rate'),
    )
    for flag, value_type, default_value, help_text in typed_options:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)
    parser.add_argument('--no_cache', action='store_true', help='Skip using cached processed samples')

    parsed = parser.parse_args()
    main(parsed.data_dir, parsed.epochs, parsed.batch_size, parsed.learning_rate, parsed.no_cache)

# For command line usage
if __name__ == "__main__":
    # Choose the entry point first, outside any try block. Otherwise a
    # NameError raised during training would be mistaken for "not running
    # under IPython" and trigger argparse inside a notebook kernel.
    try:
        get_ipython()  # noqa: F821 - only defined under IPython/Jupyter
        _running_in_notebook = True
    except NameError:
        _running_in_notebook = False

    if _running_in_notebook:
        print("🔍 Detected Jupyter environment - running with default parameters")
        print("💡 To use custom parameters, call: main(epochs=5, batch_size=8, etc.)")
        main()  # Run with defaults in Jupyter
    else:
        # Regular Python script: take parameters from the command line.
        main_with_args()