# RepText Arabic Training - Quick Start Guide

This notebook demonstrates how to pretrain RepText for Arabic text generation.

## Prerequisites
- NVIDIA GPU with at least 24GB VRAM
- Python 3.8+
- CUDA 11.7+
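
The checks below are a minimal sketch of how these prerequisites might be confirmed before installing anything. The helper name `check_prerequisites` is hypothetical, and finding `nvidia-smi` on the `PATH` only suggests that an NVIDIA driver is installed; it does not confirm the VRAM size or the CUDA version.

```python
import shutil
import sys


def check_prerequisites(min_python=(3, 8)):
    """Return a dict of basic environment checks (hypothetical helper)."""
    return {
        # The guide asks for Python 3.8 or newer.
        "python_ok": sys.version_info >= min_python,
        # nvidia-smi on PATH hints at an NVIDIA driver; it does not prove 24GB of VRAM.
        "nvidia_smi_found": shutil.which("nvidia-smi") is not None,
    }


if __name__ == "__main__":
    for name, ok in check_prerequisites().items():
        print(f"{name}: {ok}")
```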

## Step 1: Install Dependencies

In [None]:
!pip install -r requirements.txt

## Step 2: Download Arabic Fonts

In [None]:
# Download recommended Arabic fonts from Google Fonts
!python download_arabic_fonts.py

In [None]:
# Verify fonts were downloaded
import os
fonts = [f for f in os.listdir('arabic_fonts') if f.endswith(('.ttf', '.otf'))]
print(f"Found {len(fonts)} fonts:")
for font in fonts:
    print(f"  - {font}")

## Step 3: Test Font Rendering

In [None]:
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt

# Test Arabic text rendering
test_text = "مرحبا بكم في RepText"
font_path = "./arabic_fonts/Amiri-Regular.ttf"

img = Image.new('RGB', (600, 150), color='white')
draw = ImageDraw.Draw(img)
font = ImageFont.truetype(font_path, 60)
draw.text((50, 40), test_text, font=font, fill='black')

plt.figure(figsize=(10, 3))
plt.imshow(img)
plt.axis('off')
plt.title('Arabic Text Rendering Test')
plt.show()

print("✓ Arabic text rendering works correctly!")
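
The success message above is printed whether or not the glyphs were shaped correctly. Pillow relies on its optional Raqm layout engine for complex scripts such as Arabic, and without it the letters may come out unjoined or in left-to-right order. The sketch below asks the installed Pillow whether Raqm support is present; `raqm_available` is a hypothetical helper name, and it returns `None` when Pillow itself cannot be imported.

```python
def raqm_available():
    """Return True/False for Raqm support, or None if Pillow is not installed."""
    try:
        from PIL import features
    except ImportError:
        return None
    # features.check("raqm") reports whether Pillow was built with libraqm.
    return bool(features.check("raqm"))


status = raqm_available()
if status is None:
    print("Pillow is not installed")
elif status:
    print("Raqm found: Arabic shaping and right-to-left layout should work")
else:
    print("Raqm missing: Arabic may render unjoined or in the wrong order")
```

If Raqm is missing, it is worth inspecting the rendered test image by eye before generating a full dataset.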

## Step 4: Prepare Training Dataset

This will generate synthetic training samples with:
- Glyph images (rendered Arabic text)
- Position maps (location heatmaps)
- Binary masks (text regions)
- Canny edges

In [None]:
# For quick testing, use a small number of samples
# For actual training, use 10000-50000
NUM_SAMPLES = 100  # Change to 10000 for real training

!python prepare_arabic_dataset.py \
    --output_dir ./arabic_training_data \
    --font_dir ./arabic_fonts \
    --text_file ./arabic_texts.txt \
    --num_samples {NUM_SAMPLES} \
    --width 1024 \
    --height 1024 \
    --min_font_size 60 \
    --max_font_size 120
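
After generation finishes, a quick completeness check can catch samples that were only partly written. This is a sketch under assumptions: it expects the `sample_*` directory layout and the five file names that the visualization cell in Step 5 loads, and `find_incomplete_samples` is a hypothetical helper, not part of `prepare_arabic_dataset.py`.

```python
from pathlib import Path

# File names taken from the visualization cell in Step 5.
EXPECTED_FILES = ("glyph.png", "position.png", "mask.png", "canny.png", "metadata.json")


def find_incomplete_samples(data_dir):
    """Map each sample directory name to the expected files it is missing."""
    incomplete = {}
    for sample_dir in sorted(Path(data_dir).glob("sample_*")):
        if not sample_dir.is_dir():
            continue
        missing = [name for name in EXPECTED_FILES if not (sample_dir / name).exists()]
        if missing:
            incomplete[sample_dir.name] = missing
    return incomplete


if __name__ == "__main__":
    problems = find_incomplete_samples("./arabic_training_data")
    print(f"{len(problems)} incomplete sample(s)")
```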

## Step 5: Visualize Training Samples

In [None]:
import json
from PIL import Image
import matplotlib.pyplot as plt

# Load a sample
sample_dir = "./arabic_training_data/sample_000000"

# Load images
glyph = Image.open(f"{sample_dir}/glyph.png")
position = Image.open(f"{sample_dir}/position.png")
mask = Image.open(f"{sample_dir}/mask.png")
canny = Image.open(f"{sample_dir}/canny.png")

# Load metadata
with open(f"{sample_dir}/metadata.json", 'r', encoding='utf-8') as f:
    metadata = json.load(f)

# Visualize
fig, axes = plt.subplots(2, 2, figsize=(12, 12))

axes[0, 0].imshow(glyph)
axes[0, 0].set_title(f"Glyph: {metadata['text']}")
axes[0, 0].axis('off')

axes[0, 1].imshow(position)
axes[0, 1].set_title("Position Map")
axes[0, 1].axis('off')

axes[1, 0].imshow(mask, cmap='gray')
axes[1, 0].set_title("Text Mask")
axes[1, 0].axis('off')

axes[1, 1].imshow(canny)
axes[1, 1].set_title("Canny Edges")
axes[1, 1].axis('off')

plt.suptitle(f"Training Sample - Font Size: {metadata['font_size']}px")
plt.tight_layout()
plt.show()

## Step 6: Test Dataset Loading

In [None]:
from arabic_dataset import create_dataloaders

# Create dataloaders
train_loader, val_loader = create_dataloaders(
    data_dir='./arabic_training_data',
    batch_size=2,
    num_workers=0,
    image_size=(1024, 1024),
    train_ratio=0.9
)

# Test loading a batch
batch = next(iter(train_loader))

print("Batch contents:")
print(f"  Glyph shape: {batch['glyph'].shape}")
print(f"  Position shape: {batch['position'].shape}")
print(f"  Mask shape: {batch['mask'].shape}")
print(f"  Canny shape: {batch['canny'].shape}")
print(f"  Text samples: {batch['text']}")
print(f"  Font sizes: {batch['font_size']}")
print("\n✓ Dataset loading works correctly!")

## Step 7: Configure Accelerate for Training

In [None]:
!accelerate config

## Step 8: Review Training Configuration

In [None]:
import yaml

with open('train_config.yaml', 'r') as f:
    config = yaml.safe_load(f)

print("Training Configuration:")
print(yaml.dump(config, default_flow_style=False))
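
Beyond printing the config, it can help to confirm that the keys other cells in this notebook read are actually present. The sketch below works on the already loaded dictionary; the key list is collected from the cells in this notebook and is probably not exhaustive, and `missing_config_keys` is a hypothetical helper.

```python
# Keys read elsewhere in this notebook; your train_config.yaml may define more.
REQUIRED_KEYS = [
    ("model", "base_model"),
    ("model", "pretrained_controlnet"),
    ("data", "data_dir"),
    ("data", "batch_size"),
    ("training", "num_epochs"),
    ("training", "learning_rate"),
    ("output", "output_dir"),
]


def missing_config_keys(config, required=REQUIRED_KEYS):
    """Return 'section.key' strings that are absent from a nested config dict."""
    missing = []
    for section, key in required:
        if key not in (config.get(section) or {}):
            missing.append(f"{section}.{key}")
    return missing


example = {"model": {"base_model": "black-forest-labs/FLUX.1-dev"}, "training": {"num_epochs": 10}}
print(missing_config_keys(example))
```

In the notebook, the same function could be called as `missing_config_keys(config)` right after `yaml.safe_load`.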

## Step 9: Launch Training

**Note:** Training takes a long time, so you may want to run it in a terminal instead of the notebook.

For terminal:
```bash
accelerate launch train_arabic.py --config train_config.yaml
```

Or use the automated script:
```bash
./train_arabic.sh
```

In [None]:
# For notebook training (not recommended for long training runs)
# Uncomment to run:
# !accelerate launch train_arabic.py --config train_config.yaml

print("It's recommended to run training in a terminal or tmux session:")
print("  accelerate launch train_arabic.py --config train_config.yaml")

## Step 9.5: Monitor Training Progress (RunPod)

⏳ **Training Status: IN PROGRESS (Epoch 2/100)**

Your training started successfully! You can see it processing batches per epoch in the terminal.

**Understanding the Progress:**
- Progress shows: `23/90 [00:33<01:37...]` = 23 of 90 batches processed in the current epoch
- Each epoch covers the full training set: 90 batches, which corresponds to 90 training samples if the batch size is 1
- After all 90 batches complete, the next epoch begins
- **Total:** 100 epochs to complete (configured in `train_config.yaml`)
- **Time per epoch:** ~2.2 minutes (90 batches × 1.45 s/it ≈ 130 s, matching 00:33 elapsed + 01:37 remaining)
- **Total training time:** ~3.6 hours (100 epochs × ~130 s)
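
The arithmetic behind these estimates can be sketched as follows, using the numbers from the progress bar: 1.45 s per iteration and 90 iterations per epoch. The batch size of 1 and the rounding rule for the train split are assumptions here, since the actual split logic lives in `create_dataloaders`. With these inputs the estimate comes out at about 2.2 minutes per epoch and roughly 3.6 hours for 100 epochs.

```python
import math


def estimate_training_time(num_samples, train_ratio, batch_size, seconds_per_it, num_epochs):
    """Estimate iterations per epoch and wall-clock time from progress-bar numbers."""
    # Assumption: the train split is num_samples * train_ratio, rounded to an integer.
    num_train = round(num_samples * train_ratio)
    its_per_epoch = math.ceil(num_train / batch_size)
    seconds_per_epoch = its_per_epoch * seconds_per_it
    total_hours = seconds_per_epoch * num_epochs / 3600
    return its_per_epoch, seconds_per_epoch, total_hours


# 100 samples, 0.9 train ratio, assumed batch size 1, 1.45 s/it, 100 epochs.
its, sec, hours = estimate_training_time(100, 0.9, 1, 1.45, 100)
print(f"{its} it/epoch, {sec / 60:.1f} min/epoch, {hours:.1f} h total")
```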

**Key Metrics to Watch:**
```
Epoch 2: 26% | 23/90 [00:33<01:37, 1.45s/it, diffusion_loss=0, loss=0, lr=4e-8]
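
Epoch 2       current epoch
26%           progress through the current epoch
23/90         batches completed / batches per epoch
00:33<01:37   elapsed time < estimated time remaining
1.45s/it      seconds per batch
lr=4e-8       current learning rate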
```
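
If you redirect training output to a log file, the same fields can be pulled out programmatically. The pattern below is a sketch written against the exact line format shown above; a real progress bar usually also draws bar characters between the `|` marks, so the pattern may need adjusting. `parse_progress` is a hypothetical helper.

```python
import re

PROGRESS_RE = re.compile(
    r"Epoch (?P<epoch>\d+):\s+(?P<percent>\d+)%\s*\|\s*"
    r"(?P<done>\d+)/(?P<total>\d+)\s+"
    r"\[(?P<elapsed>[\d:]+)<(?P<remaining>[\d:]+),\s*(?P<rate>[\d.]+)s/it"
    r"(?:,\s*(?P<postfix>.*))?\]"
)


def parse_progress(line):
    """Split one progress-bar line into a dict, or return None if it does not match."""
    match = PROGRESS_RE.search(line)
    if match is None:
        return None
    info = match.groupdict()
    # The postfix holds comma-separated key=value pairs such as loss and lr.
    metrics = {}
    for pair in (info.pop("postfix") or "").split(","):
        if "=" in pair:
            key, value = pair.split("=", 1)
            metrics[key.strip()] = float(value)
    info["metrics"] = metrics
    return info


line = "Epoch 2: 26% | 23/90 [00:33<01:37, 1.45s/it, diffusion_loss=0, loss=0, lr=4e-8]"
print(parse_progress(line))
```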

**Monitor Training:**

Run in your RunPod terminal:
```bash
# Check if training is still running
ps aux | grep train_arabic.py

# Check checkpoint directory growth (as training progresses)
watch -n 30 "ls -lh output/arabic_reptext/ && echo '---' && du -sh output/arabic_reptext/"

# Check if intermediate checkpoints are being saved
ls -lh output/arabic_reptext/checkpoint-*/
```
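
From Python, the most recent intermediate checkpoint can be located with a few lines. This sketch assumes the directories are named `checkpoint-<step>` with a numeric step, which the `checkpoint-*` glob above suggests but does not confirm; `latest_checkpoint` is a hypothetical helper.

```python
import re
from pathlib import Path


def latest_checkpoint(output_dir):
    """Return the checkpoint-<step> directory with the highest step, or None."""
    best_step, best_path = -1, None
    for path in Path(output_dir).glob("checkpoint-*"):
        match = re.fullmatch(r"checkpoint-(\d+)", path.name)
        if path.is_dir() and match and int(match.group(1)) > best_step:
            best_step, best_path = int(match.group(1)), path
    return best_path


if __name__ == "__main__":
    print(latest_checkpoint("output/arabic_reptext"))
```

Comparing steps as integers avoids the string-sorting trap where `checkpoint-500` would sort after `checkpoint-1000`.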

**When Training Completes:**
- Final checkpoint saved to: `output/arabic_reptext/final_model/`
- Should contain:
  - `config.json` (with `in_channels: 64`)
  - `diffusion_pytorch_model.safetensors` (~5GB)
- This checkpoint should be **compatible** with inference ✅
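
The checklist above can be turned into a small verification step. The sketch below only inspects file names and the `in_channels` value in `config.json`; it does not load the weights, and it assumes the weights are saved as a single `diffusion_pytorch_model.safetensors` file as listed above. `check_final_model` is a hypothetical helper.

```python
import json
from pathlib import Path


def check_final_model(model_dir, expected_in_channels=64):
    """Return a list of problems found in a saved model directory (empty if none)."""
    model_dir = Path(model_dir)
    problems = []
    config_path = model_dir / "config.json"
    weights_path = model_dir / "diffusion_pytorch_model.safetensors"
    if not weights_path.exists():
        problems.append(f"missing {weights_path.name}")
    if not config_path.exists():
        problems.append("missing config.json")
    else:
        config = json.loads(config_path.read_text(encoding="utf-8"))
        if config.get("in_channels") != expected_in_channels:
            problems.append(f"in_channels is {config.get('in_channels')}, expected {expected_in_channels}")
    return problems


if __name__ == "__main__":
    print(check_final_model("output/arabic_reptext/final_model") or "final_model looks complete")
```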

---


## Step 10: Monitor Training (Optional - with W&B)

In [None]:
# Install wandb
# !pip install wandb
# !wandb login

# Then run with --use_wandb flag:
# !accelerate launch train_arabic.py --config train_config.yaml --use_wandb

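## Step 11: Fine-Tune RepText on Arabic Dataset

This step will:
1. Load the pretrained RepText model
2. Fine-tune for 10 epochs on your Arabic dataset, one epoch per subprocess run
3. Generate a test image after each epoch to track progress
4. Collect the progress images into a comparison grid

The training cell calls `generate_test_image`, which is defined in the baseline-generation cell further down, so run that cell first.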

In [None]:
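# Run training one epoch at a time and generate an image after each epoch.
# Requires generate_test_image() from the baseline-generation cell to be defined.
import subprocess
import time
import os
import yaml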

test_text = "السلام عليكم ورحمة الله وبركاته"

# Start training (each epoch runs as a blocking subprocess)
print("🚀 Starting fine-tuning training...")
print(f"📊 Test text for tracking: {test_text}")
print("="*60)

# Run one epoch at a time and generate images
for epoch in range(1, 11):  # 10 epochs
    print(f"\n{'='*60}")
    print(f"EPOCH {epoch}/10")
    print(f"{'='*60}\n")
    
    # Update config to train for 1 epoch
    with open("train_config.yaml", 'r') as f:
        config = yaml.safe_load(f)
    
    config['training']['num_epochs'] = 1
    
    # Set checkpoint to resume from if not first epoch
    if epoch > 1:
        config['model']['pretrained_controlnet'] = f"./output/arabic_reptext/epoch_{epoch-1}"
    else:
        config['model']['pretrained_controlnet'] = "Shakker-Labs/RepText"
    
    with open("train_config.yaml", 'w') as f:
        yaml.dump(config, f, default_flow_style=False, allow_unicode=True)
    
    # Run training for this epoch
    try:
        result = subprocess.run(
            ["accelerate", "launch", "train_arabic.py", "--config", "train_config.yaml"],
            capture_output=True,
            text=True,
            timeout=3600  # 1 hour timeout per epoch
        )
        
        if result.returncode != 0:
            print(f"❌ Training failed at epoch {epoch}")
            print(result.stderr)
            break
        
        print(f"✅ Epoch {epoch} training complete")
        
    except subprocess.TimeoutExpired:
        print(f"⚠️ Epoch {epoch} timed out")
        break
    except Exception as e:
        print(f"❌ Error during epoch {epoch}: {e}")
        break
    
    # Generate test image after this epoch
    print(f"\n📸 Generating test image after epoch {epoch}...")
    
    try:
        checkpoint_path = f"./output/arabic_reptext/epoch_{epoch}"
        if not os.path.exists(checkpoint_path):
            # Use final_model if epoch checkpoint doesn't exist
            checkpoint_path = "./output/arabic_reptext/final_model"
        
        if os.path.exists(checkpoint_path):
            generate_test_image(checkpoint_path, f"epoch_{epoch}", test_text)
        else:
            print(f"⚠️ Checkpoint not found: {checkpoint_path}")
    
    except Exception as e:
        print(f"❌ Failed to generate image after epoch {epoch}: {e}")
        import traceback
        traceback.print_exc()
    
    # Small delay between epochs
    time.sleep(5)

print("\n" + "="*60)
print("🎉 TRAINING COMPLETE!")
print("="*60)
print("\nGenerated images saved in: ./training_progress/")
print("Model checkpoints saved in: ./output/arabic_reptext/")
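
Note that the loop above rewrites `train_config.yaml` in place, so the original `num_epochs` value is overwritten on the first iteration. One alternative is to build each per-epoch config in memory and write it to a separate file. The sketch below shows only the in-memory part, with the same resume logic as the loop; `config_for_epoch` is a hypothetical helper, and the `epoch_N` directory naming is taken from the loop above.

```python
import copy


def config_for_epoch(base_config, epoch, output_dir="./output/arabic_reptext",
                     initial_weights="Shakker-Labs/RepText"):
    """Return a copy of base_config set up to train exactly one epoch."""
    config = copy.deepcopy(base_config)  # leave the caller's config untouched
    config.setdefault("training", {})["num_epochs"] = 1
    # Epoch 1 starts from the released weights; later epochs resume from the previous one.
    start = initial_weights if epoch == 1 else f"{output_dir}/epoch_{epoch - 1}"
    config.setdefault("model", {})["pretrained_controlnet"] = start
    return config


base = {"training": {"num_epochs": 100}, "model": {}}
print(config_for_epoch(base, 3)["model"]["pretrained_controlnet"])
```

The returned dictionary could then be dumped to a per-epoch file and passed through `--config`, leaving the original file unchanged.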

In [None]:
import matplotlib.pyplot as plt
from PIL import Image
import os

# Collect all progress images
progress_dir = "./training_progress"
image_files = []

# Add baseline
if os.path.exists(f"{progress_dir}/baseline_original.jpg"):
    image_files.append(("Baseline\n(Original RepText)", f"{progress_dir}/baseline_original.jpg"))

# Add all epochs
for epoch in range(1, 11):
    img_path = f"{progress_dir}/epoch_{epoch}.jpg"
    if os.path.exists(img_path):
        image_files.append((f"Epoch {epoch}", img_path))

# Create grid visualization
if len(image_files) > 0:
    # Calculate grid size
    n_images = len(image_files)
    n_cols = 3
    n_rows = (n_images + n_cols - 1) // n_cols
    
    # squeeze=False always returns a 2D array of axes, so flatten() works for any grid size
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(18, 6 * n_rows), squeeze=False)
    axes = axes.flatten()
    
    for idx, (title, img_path) in enumerate(image_files):
        img = Image.open(img_path)
        axes[idx].imshow(img)
        axes[idx].set_title(title, fontsize=14, fontweight='bold')
        axes[idx].axis('off')
    
    # Hide unused subplots
    for idx in range(n_images, len(axes)):
        axes[idx].axis('off')
    
    plt.suptitle(f'Training Progress: {test_text}', 
                 fontsize=16, fontweight='bold', y=0.995)
    plt.tight_layout()
    plt.savefig('./training_progress/comparison_grid.jpg', dpi=150, bbox_inches='tight')
    plt.show()
    
    print(f"\n✅ Visualized {len(image_files)} images")
    print(f"📊 Comparison grid saved to: ./training_progress/comparison_grid.jpg")
else:
    print("❌ No progress images found. Please run training first.")

## Summary

This notebook demonstrates:

1. ✅ **Baseline Generation**: Generated image with original RepText model
2. ✅ **Fine-Tuning**: Trained for 10 epochs on Arabic dataset
3. ✅ **Progress Tracking**: Generated test image after each epoch
4. ✅ **Comparison**: Visualized improvement across training

### Key Results

- **Test Text**: `السلام عليكم ورحمة الله وبركاته`
- **Prompt**: `a street sign in city`
- **Total Epochs**: 10
- **Model Checkpoints**: Saved in `./output/arabic_reptext/epoch_X/`
- **Progress Images**: Saved in `./training_progress/`

### Next Steps

1. Compare baseline vs final epoch to see improvement
2. Use the best checkpoint for your Arabic text generation tasks
3. Experiment with different prompts and text
4. Fine-tune for more epochs if needed

### Files Generated

- `./output/arabic_reptext/epoch_1/` through `epoch_10/` - Model checkpoints
- `./output/arabic_reptext/final_model/` - Final trained model
- `./training_progress/baseline_original.jpg` - Original RepText output
- `./training_progress/epoch_X.jpg` - Output after each epoch
- `./training_progress/comparison_grid.jpg` - Side-by-side comparison

You can now use the fine-tuned model for Arabic text generation!

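## Step 12: Create a Training Script with Per-Epoch Checkpoints

The first cell writes `train_arabic_with_tracking.py`, a variant of the training loop that saves a checkpoint after every epoch. The second cell sets `train_config.yaml` to 10 epochs starting from `Shakker-Labs/RepText`.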

In [None]:
# Create a custom training script that generates images after each epoch
training_script_content = '''
import os
import yaml
import torch
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from diffusers import DDPMScheduler, AutoencoderKL
from diffusers.optimization import get_scheduler
from controlnet_flux import FluxControlNetModel
from arabic_dataset import create_dataloaders
from train_arabic import train_one_epoch
import sys

logger = get_logger(__name__)

def main():
    # Load config
    with open("train_config.yaml", 'r') as f:
        config = yaml.safe_load(f)
    
    # Setup accelerator
    accelerator_project_config = ProjectConfiguration(
        project_dir=config['output']['output_dir'],
        logging_dir=config['output']['logging_dir']
    )
    
    accelerator = Accelerator(
        gradient_accumulation_steps=config['training']['gradient_accumulation_steps'],
        mixed_precision=config['training']['mixed_precision'],
        project_config=accelerator_project_config
    )
    
    set_seed(42)
    
    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    
    os.makedirs(config['output']['output_dir'], exist_ok=True)
    
    logger.info(f"Loading base model: {config['model']['base_model']}")
    
    # Load VAE
    vae = AutoencoderKL.from_pretrained(
        config['model']['base_model'],
        subfolder="vae",
        torch_dtype=torch.bfloat16
    )
    vae.requires_grad_(False)
    
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    
    # Load pretrained ControlNet
    pretrained_controlnet = config['model'].get('pretrained_controlnet')
    logger.info(f"Loading pretrained ControlNet: {pretrained_controlnet}")
    controlnet = FluxControlNetModel.from_pretrained(
        pretrained_controlnet,
        torch_dtype=torch.bfloat16
    )
    
    try:
        controlnet.enable_gradient_checkpointing()
    except Exception as e:
        logger.warning(f"Could not enable gradient checkpointing: {e}")
    
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    
    # Initialize noise scheduler
    noise_scheduler = DDPMScheduler.from_pretrained(
        config['model']['base_model'],
        subfolder="scheduler"
    )
    
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    
    # Create dataloaders
    logger.info("Creating dataloaders...")
    train_loader, val_loader = create_dataloaders(
        data_dir=config['data']['data_dir'],
        batch_size=config['data']['batch_size'],
        num_workers=config['data']['num_workers'],
        image_size=tuple(config['data']['image_size']),
        train_ratio=config['data']['train_ratio']
    )
    
    # Initialize optimizer
    optimizer = torch.optim.AdamW(
        controlnet.parameters(),
        lr=config['training']['learning_rate'],
        betas=(config['training']['adam_beta1'], config['training']['adam_beta2']),
        weight_decay=config['training']['adam_weight_decay'],
        eps=config['training']['adam_epsilon']
    )
    
    # Initialize LR scheduler
    lr_scheduler = get_scheduler(
        config['training']['lr_scheduler'],
        optimizer=optimizer,
        num_warmup_steps=config['training']['lr_warmup_steps'],
        num_training_steps=config['training']['num_epochs'] * len(train_loader)
    )
    
    # Prepare with accelerator
    controlnet, optimizer, train_loader, val_loader, lr_scheduler = accelerator.prepare(
        controlnet, optimizer, train_loader, val_loader, lr_scheduler
    )
    
    vae.to(accelerator.device)
    
    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_loader.dataset)}")
    logger.info(f"  Num Epochs = {config['training']['num_epochs']}")
    
    global_step = 0
    
    # Training loop with epoch-wise image generation
    for epoch in range(config['training']['num_epochs']):
        global_step = train_one_epoch(
            accelerator=accelerator,
            controlnet=controlnet,
            vae=vae,
            noise_scheduler=noise_scheduler,
            optimizer=optimizer,
            train_loader=train_loader,
            text_perceptual_loss=None,
            config=config,
            epoch=epoch,
            global_step=global_step
        )
        
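        # NOTE: the scheduler is stepped once per epoch here, while num_training_steps
        # above counts batches; align the two if warmup or decay looks too slow.
        lr_scheduler.step()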
        
        # Save checkpoint after each epoch
        if accelerator.is_main_process:
            save_path = os.path.join(config['output']['output_dir'], f"epoch_{epoch+1}")
            os.makedirs(save_path, exist_ok=True)
            
            unwrapped_controlnet = accelerator.unwrap_model(controlnet)
            unwrapped_controlnet.save_pretrained(save_path)
            
            logger.info(f"Saved epoch {epoch+1} checkpoint to {save_path}")
            
            # Test images are generated from the notebook once this checkpoint exists
            print(f"\\n{'='*60}")
            print(f"Checkpoint for epoch {epoch+1} is ready for test image generation")
            print(f"{'='*60}\\n")
        
        # Synchronize all processes
        accelerator.wait_for_everyone()
    
    # Save final model
    if accelerator.is_main_process:
        save_path = os.path.join(config['output']['output_dir'], "final_model")
        os.makedirs(save_path, exist_ok=True)
        
        unwrapped_controlnet = accelerator.unwrap_model(controlnet)
        unwrapped_controlnet.save_pretrained(save_path)
        
        logger.info(f"Training complete! Model saved to {save_path}")

if __name__ == '__main__':
    main()
'''

# Save custom training script
with open("train_arabic_with_tracking.py", "w") as f:
    f.write(training_script_content)

print("✅ Created custom training script with epoch tracking")

In [None]:
import yaml
import subprocess
import sys

# Update config for 10 epochs with epoch-wise checkpointing
config_path = "./train_config.yaml"
with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

# Set training parameters
config['training']['num_epochs'] = 10
config['training']['save_steps'] = 1000  # Save during epochs
config['model']['pretrained_controlnet'] = "Shakker-Labs/RepText"

# Save updated config
with open(config_path, 'w') as f:
    yaml.dump(config, f, default_flow_style=False, allow_unicode=True)

print("✅ Config updated for 10-epoch training with RepText fine-tuning")
print(f"   - Pretrained ControlNet: {config['model']['pretrained_controlnet']}")
print(f"   - Epochs: {config['training']['num_epochs']}")
print(f"   - Output: {config['output']['output_dir']}")

## Step 13: Generate Baseline Image (Before Training)

Test the original RepText model before fine-tuning to compare results. This cell also defines `generate_test_image`, which the epoch-by-epoch training cell reuses.

In [None]:
import torch
from controlnet_flux import FluxControlNetModel
from pipeline_flux_controlnet import FluxControlNetPipeline
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import cv2
import os
from IPython.display import display

def canny(img):
    low_threshold = 50
    high_threshold = 100
    img = cv2.Canny(img, low_threshold, high_threshold)
    img = img[:, :, None]
    img = 255 - np.concatenate([img, img, img], axis=2)
    return img

def generate_test_image(controlnet_model_path, epoch_name="baseline", test_text="السلام عليكم ورحمة الله وبركاته"):
    """Generate a test image to track training progress"""
    
    # Load models
    base_model = "black-forest-labs/FLUX.1-dev"
    
    controlnet = FluxControlNetModel.from_pretrained(
        controlnet_model_path, 
        torch_dtype=torch.bfloat16
    )
    pipe = FluxControlNetPipeline.from_pretrained(
        base_model, 
        controlnet=controlnet, 
        torch_dtype=torch.bfloat16
    ).to("cuda")
    
    # Set resolution
    width, height = 1024, 512
    
    # Set font
    font_path = "./arabic_fonts/Amiri-Regular.ttf"
    font_size = 80
    font = ImageFont.truetype(font_path, font_size)
    
    # Configure text
    text_list = [test_text]
    text_position_list = [(200, 200)]
    text_color_list = [(255, 255, 255)]
    
    # Set controlnet conditions
    control_image_list = []
    control_position_list = []
    control_mask_list = []
    control_glyph_all = np.zeros([height, width, 3], dtype=np.uint8)
    
    # Handle each line of text
    for text, text_position, text_color in zip(text_list, text_position_list, text_color_list):
        # Glyph image
        control_image_glyph = Image.new("RGB", (width, height), (0, 0, 0))
        draw = ImageDraw.Draw(control_image_glyph)
        draw.text(text_position, text, font=font, fill=text_color)
        
        # Get bbox
        bbox = draw.textbbox(text_position, text, font=font)
        
        # Position condition
        control_position = np.zeros([height, width], dtype=np.uint8)
        control_position[bbox[1]:bbox[3], bbox[0]:bbox[2]] = 255
        control_position = Image.fromarray(control_position.astype(np.uint8))
        control_position_list.append(control_position)
        
        # Regional mask
        control_mask_np = np.zeros([height, width], dtype=np.uint8)
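        # Pad the box by 5px, clamping at 0 so a negative index cannot wrap around
        top, left = max(bbox[1] - 5, 0), max(bbox[0] - 5, 0)
        control_mask_np[top:bbox[3]+5, left:bbox[2]+5] = 255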
        control_mask = Image.fromarray(control_mask_np.astype(np.uint8))
        control_mask_list.append(control_mask)
        
        # Accumulate glyph
        control_glyph = np.array(control_image_glyph)
        control_glyph_all += control_glyph
        
        # Canny condition
        control_image = canny(cv2.cvtColor(np.array(control_image_glyph), cv2.COLOR_RGB2BGR))
        control_image = Image.fromarray(cv2.cvtColor(control_image, cv2.COLOR_BGR2RGB))
        control_image_list.append(control_image)
    
    control_glyph_all = Image.fromarray(control_glyph_all.astype(np.uint8))
    control_glyph_all = control_glyph_all.convert("RGB")
    
    # Set prompt
    prompt = f"a street sign in city, '{test_text}', filmfotos, film grain, reversal film photography"
    
    # Generate
    generator = torch.Generator(device="cuda").manual_seed(42)
    
    image = pipe(
        prompt,
        control_image=control_image_list,
        control_position=control_position_list,
        control_mask=control_mask_list,
        control_glyph=control_glyph_all,
        controlnet_conditioning_scale=1.0,
        controlnet_conditioning_step=30,
        width=width,
        height=height,
        num_inference_steps=30,
        guidance_scale=3.5,
        generator=generator,
    ).images[0]
    
    # Save
    os.makedirs("./training_progress", exist_ok=True)
    output_path = f"./training_progress/{epoch_name}.jpg"
    image.save(output_path)
    
    print(f"✅ Generated image: {output_path}")
    display(image)
    
    # Clean up to free memory
    del pipe, controlnet
    torch.cuda.empty_cache()
    
    return image

# Generate baseline image with original RepText
print("Generating baseline image with original RepText model...")
baseline_image = generate_test_image("Shakker-Labs/RepText", "baseline_original")
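
The regional mask in `generate_test_image` grows the text bounding box by 5 pixels on each side. When text sits close to an image edge, the padded coordinates can go negative, and NumPy slicing treats a negative start as counting from the end of the axis, which produces an empty or misplaced mask. The sketch below shows the clamping in isolation; `padded_box` is a hypothetical helper, not part of RepText.

```python
def padded_box(bbox, pad, width, height):
    """Grow (left, top, right, bottom) by pad pixels, clamped to the image bounds."""
    left, top, right, bottom = bbox
    return (
        max(left - pad, 0),
        max(top - pad, 0),
        min(right + pad, width),
        min(bottom + pad, height),
    )


# Text near the top-left corner: unclamped padding would give negative indices.
print(padded_box((2, 3, 400, 90), pad=5, width=1024, height=512))
```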

## Step 14: Test Your Trained Model

After training completes, load the fine-tuned ControlNet and test it:

In [None]:
import torch
import yaml
from pathlib import Path
from controlnet_flux import FluxControlNetModel
from pipeline_flux_controlnet import FluxControlNetPipeline
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import cv2
import os

# Load training config to match inference with training architecture
config_path = "./train_config.yaml"
if not Path(config_path).exists():
    print(f"❌ Error: {config_path} not found!")
    print("Make sure you've trained the model first.")
else:
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)
    
    # Get the trained model path
    checkpoint_dir = config['output']['output_dir']
    model_path = os.path.join(checkpoint_dir, "final_model")
    
    if not os.path.exists(model_path):
        print(f"❌ Error: Trained model not found at {model_path}")
        print(f"Available checkpoints in {checkpoint_dir}:")
        if os.path.exists(checkpoint_dir):
            for item in os.listdir(checkpoint_dir):
                print(f"  - {item}")
    else:
        print(f"✓ Loading trained ControlNet from {model_path}")
        
        try:
            base_model = config['model']['base_model']
            
            # Load ControlNet WITHOUT config overrides (use what's in checkpoint)
            print("Loading ControlNet...")
            controlnet = FluxControlNetModel.from_pretrained(
                model_path,
                torch_dtype=torch.bfloat16
            )
            
            # Load inference pipeline
            print("Loading FLUX pipeline...")
            pipe = FluxControlNetPipeline.from_pretrained(
                base_model,
                controlnet=controlnet,
                torch_dtype=torch.bfloat16
            ).to("cuda")
            
            print("✅ Models loaded successfully!")
            print(f"   Base Model: {base_model}")
            print(f"   ControlNet In Channels: {controlnet.config.in_channels}")
            print(f"   ControlNet X-Embedder In Features: {controlnet.x_embedder.in_features}")
            print(f"   ControlNet Layers: {controlnet.config.num_layers}")
            print(f"   ControlNet Single Layers: {controlnet.config.num_single_layers}")
            
            # Check if x_embedder matches in_channels (both in packed format)
            print(f"\n📋 Architecture Check:")
            print(f"   Config in_channels: {controlnet.config.in_channels}")
            print(f"   X-Embedder input features: {controlnet.x_embedder.in_features}")
            print(f"   Note: in_channels = VAE channels (16) × 4 from 2×2 spatial packing")
            
            # Check compatibility
            if controlnet.x_embedder.in_features == controlnet.config.in_channels:
                print(f"\n✅ COMPATIBLE - Ready for inference!")
                print(f"   Both use the same {controlnet.config.in_channels}-dimensional packed format.")
            else:
                print(f"\n⚠️  WARNING: Possible mismatch detected!")
                print(f"   X-Embedder expects: {controlnet.x_embedder.in_features} features")
                print(f"   Config specifies: {controlnet.config.in_channels} channels")
                print(f"   This might cause errors during inference.")
            
        except Exception as e:
            print(f"❌ Error loading models: {e}")
            print("\nMake sure:")
            print("  1. Training has completed")
            print("  2. The model checkpoint exists")
            print("  3. Your GPU has enough memory (24GB+)")
            import traceback
            traceback.print_exc()

In [None]:
# Quick Compatibility Check
if 'controlnet' in globals():
    print("=" * 70)
    print("QUICK COMPATIBILITY CHECK")
    print("=" * 70)
    
    x_embedder_in = controlnet.x_embedder.in_features
    config_in_channels = controlnet.config.in_channels
    
    print(f"\nControlNet Configuration:")
    print(f"  • X-Embedder input features: {x_embedder_in}")
    print(f"  • Config in_channels: {config_in_channels}")
    print(f"  • (in_channels = 16 VAE channels × 4 from 2×2 spatial packing)")
    
    if x_embedder_in == config_in_channels:
        print(f"\n✅ COMPATIBLE - Inference should work perfectly!")
        print(f"   Both training and inference use the same {config_in_channels}-dimensional")
        print(f"   packed latent format with 2×2 spatial packing.")
    else:
        print(f"\n❌ INCOMPATIBLE - Dimension mismatch detected!")
        print(f"   X-Embedder expects: {x_embedder_in}")
        print(f"   Config provides: {config_in_channels}")
        print(f"\n   SOLUTION: Retrain with correct config")
        print(f"     accelerate launch train_arabic.py --config train_config.yaml")
    
    print("=" * 70)
else:
    print("ControlNet not loaded. Run model loading cell first.")
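
The `in_channels = 64` figure checked above follows from the packing arithmetic: 16 VAE latent channels times a 2×2 spatial patch. The sketch below works through the numbers for a given image size. The VAE downsampling factor of 8 is an assumption not stated in this notebook, and `packed_latent_shape` is a hypothetical helper.

```python
def packed_latent_shape(height, width, vae_channels=16, vae_downsample=8, patch=2):
    """Return (num_tokens, features_per_token) after patch x patch latent packing."""
    # Assumption: the VAE reduces each spatial side by vae_downsample.
    latent_h, latent_w = height // vae_downsample, width // vae_downsample
    if latent_h % patch or latent_w % patch:
        raise ValueError("latent size must be divisible by the patch size")
    num_tokens = (latent_h // patch) * (latent_w // patch)
    features = vae_channels * patch * patch  # 16 channels x 2 x 2 = 64
    return num_tokens, features


print(packed_latent_shape(1024, 1024))
```

The feature count depends only on the channel count and patch size, so it stays at 64 for any resolution, while the token count scales with image area.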

In [None]:
def prepare_glyph_and_controls(text, font_path="./arabic_fonts/Amiri-Regular.ttf", 
                               font_size=80, width=1024, height=1024):
    """
    Prepare glyph, canny, and position maps for inference.
    
    Args:
        text: Arabic text string
        font_path: Path to Arabic font
        font_size: Font size for rendering
        width: Image width
        height: Image height
    
    Returns:
        glyph: Rendered text image (RGB)
        canny: Canny edge detection of glyph (RGB)
        position: Position heatmap (grayscale for pipeline compatibility)
    """
    # Render text (glyph)
    glyph_img = Image.new('RGB', (width, height), color='white')
    draw = ImageDraw.Draw(glyph_img)
    
    try:
        font = ImageFont.truetype(font_path, font_size)
    except Exception as e:
        print(f"Warning: Could not load font {font_path}: {e}")
        print("Using default font instead")
        font = ImageFont.load_default()
    
    # Draw text
    draw.text((50, (height - font_size) // 2), text, font=font, fill='black')
    
    # Create canny edges
    glyph_array = np.array(glyph_img.convert('L'))
    canny_edges = cv2.Canny(glyph_array, 50, 100)
    canny_img = Image.fromarray(np.dstack([255 - canny_edges, 255 - canny_edges, 255 - canny_edges]))
    
    # Create position map as GRAYSCALE (single channel)
    # The pipeline will expand it to 3 channels automatically
    position_array = np.zeros((height, width), dtype=np.uint8)
    position_array[100:height-100, 50:width-50] = 200
    position_img = Image.fromarray(position_array)  # Grayscale image
    
    return glyph_img, canny_img, position_img

def generate_image(text, prompt="", num_inference_steps=50, controlnet_conditioning_step=30, 
                   output_path="./results"):
    """
    Generate an image with Arabic text using the trained ControlNet.
    
    Args:
        text: Arabic text to render
        prompt: Optional prompt for FLUX (default: empty for unconditional generation)
        num_inference_steps: Number of inference steps (more = higher quality but slower)
        controlnet_conditioning_step: When to stop applying ControlNet (0-num_inference_steps)
        output_path: Directory to save output image
    """
    
    if 'pipe' not in globals():
        print("❌ Pipeline not loaded. Please run the model loading cell first.")
        return None
    
    try:
        os.makedirs(output_path, exist_ok=True)
        
        print(f"Generating image for: {text}")
        
        # Prepare conditioning
        glyph, canny, position = prepare_glyph_and_controls(text)
        
        # Resize to match model input
        glyph = glyph.resize((512, 512))
        position = position.resize((512, 512))
        
        # Use empty prompt for unconditional generation
        if not prompt:
            prompt = ""
        
        print(f"Prompt: '{prompt}'")
        print(f"Inference steps: {num_inference_steps}")
        print(f"ControlNet conditioning until step: {controlnet_conditioning_step}")
        
        # Generate image with glyph as the control anchor and position as spatial guide
        # control_glyph: used to initialize latents from rendered text
        # control_image: primary spatial control (glyph)
        # control_position: position guidance (passed as grayscale, expanded to 3ch by pipeline)
        with torch.no_grad():
            generator = torch.Generator(device="cuda").manual_seed(42)
            
            image = pipe(
                prompt,
                height=512,
                width=512,
                num_inference_steps=num_inference_steps,
                guidance_scale=0.0,  # No guidance for unconditional
                controlnet_conditioning_scale=1.0,
                controlnet_conditioning_step=controlnet_conditioning_step,
                control_image=[glyph],  # RGB glyph as main control
                control_position=[position],  # Grayscale position map (expanded to 3ch by pipeline)
                control_glyph=glyph,  # used for latent initialization
                control_mask=None,
                generator=generator,
            ).images[0]
        
        # Save image
        output_file = os.path.join(output_path, f"generated_{text[:10]}.png")
        image.save(output_file)
        print(f"✅ Image saved to {output_file}")
        
        return image
        
    except Exception as e:
        print(f"❌ Error during generation: {e}")
        import traceback
        traceback.print_exc()
        return None

# Test generation
print("Ready to generate images with Arabic text!")
print("\nExample usage:")
print('  image = generate_image("السلام عليكم")')
print('  image = generate_image("مرحبا", prompt="A beautiful greeting", num_inference_steps=50)')

In [None]:
# Example: Generate Arabic text images
print("=" * 60)
print("INFERENCE EXAMPLE - Generate Arabic Text Images")
print("=" * 60)

# Check if models are loaded
if 'pipe' not in globals() or 'controlnet' not in globals():
    print("❌ ERROR: Models not loaded!")
    print("Please run the model loading cell (Step 11) first.")
else:
    # Try generating images
    test_texts = [
        "السلام عليكم",  # Hello/Peace be upon you
        "مرحبا",  # Hello
        "شكرا",  # Thank you
    ]

    results = []
    for text in test_texts:
        print(f"\n{'='*60}")
        try:
            image = generate_image(
                text=text,
                prompt="",  # Empty prompt for unconditional generation
                num_inference_steps=30,  # Reduced for faster testing
                controlnet_conditioning_step=20,
                output_path="./results_after_training"
            )
            if image is not None:
                results.append((text, image))
                # Display the image
                from IPython.display import display
                display(image)
        except RuntimeError as e:
            error_msg = str(e)
            if "shapes cannot be multiplied" in error_msg:
                print(f"❌ Architecture Mismatch Error: {e}")
                print("\nThis error occurs when the checkpoint architecture doesn't match")
                print("what the pipeline expects. This typically means:")
                print("  1. The checkpoint was trained with different settings")
                print("  2. The in_channels or packing configuration doesn't match")
                print("\n→ CHECK THE DIAGNOSTIC CELL ABOVE to see the mismatch")
                print("→ YOU MAY NEED TO RETRAIN THE MODEL with the current config")
            else:
                print(f"❌ Error during generation: {e}")
                import traceback
                traceback.print_exc()
        except Exception as e:
            print(f"❌ Unexpected error: {e}")
            import traceback
            traceback.print_exc()

    print(f"\n{'='*60}")
    print(f"✅ Generated {len(results)}/{len(test_texts)} images successfully!")
    
    if len(results) < len(test_texts):
        print("\n⚠️  Some generations failed. See errors above for details.")

## Troubleshooting Inference

### Current Status: Training in Progress ✅

**What's happening on RunPod:**
```bash
Loaded 90 samples for train split
Loaded 10 samples for val split
# Training running with in_channels: 64 (packed format)
```

Your training is now running with the **correct configuration** (`in_channels: 64`), so the checkpoint it saves will match the input size the inference pipeline provides.

### Why the Old Checkpoint Failed

The old checkpoint was trained with `in_channels: 16`, but inference provides 64-dimensional packed latents:
- **Old checkpoint**: `x_embedder` configured for 16 features
- **Inference pipeline**: Provides 64-dimensional packed latents (16 VAE channels × 4 from 2×2 spatial packing)
- **Result**: Shape mismatch error `1024x64 and 16x3072` (1024 tokens with 64 features each, fed to a layer that expects 16 input features)
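
The arithmetic behind the mismatch can be sketched in a few lines. This is a minimal illustration, not code from the pipeline: `packed_latent_shape` is a hypothetical helper, and it assumes the FLUX VAE downsamples by a factor of 8 and that latents are packed in 2×2 patches, as described above.

```python
def packed_latent_shape(height, width, vae_channels=16, vae_scale=8, patch=2):
    """Return (num_tokens, features_per_token) for packed latents.

    Assumes an 8x VAE downsampling factor and 2x2 spatial packing.
    """
    latent_h, latent_w = height // vae_scale, width // vae_scale
    num_tokens = (latent_h // patch) * (latent_w // patch)
    features = vae_channels * patch * patch
    return num_tokens, features


tokens, features = packed_latent_shape(512, 512)
print(f"Packed latents: {tokens} tokens x {features} features")

# A linear layer built with in_channels=16 cannot consume 64-feature tokens.
for in_channels in (16, 64):
    status = "OK" if in_channels == features else "shape mismatch"
    print(f"x_embedder in_features={in_channels}: {status}")
```

The token count depends on the image size, but the per-token feature width depends only on the VAE channel count and the packing factor, which is why `in_channels` has to match regardless of resolution.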

### What to Watch During Training

**Expected Training Metrics:**
```
Loss: Should decrease gradually, starting somewhere in the 0-1.0 range
Learning rate: Starts low during warmup, increases to 1e-5 by step 500
Epoch time: ~1-2 min per epoch depending on GPU
```
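
A warmup schedule like the one described can be sketched as follows. The linear shape is an assumption (the actual scheduler is determined by the training script and config), and `warmup_lr` is a hypothetical helper; only the 1e-5 peak and the 500-step warmup come from the metrics above.

```python
def warmup_lr(step, peak_lr=1e-5, warmup_steps=500):
    """Linear warmup to peak_lr, then constant (assumed schedule shape)."""
    if step >= warmup_steps:
        return peak_lr
    return peak_lr * step / warmup_steps


for step in (0, 100, 250, 500, 1000):
    print(f"step {step:>4}: lr = {warmup_lr(step):.2e}")
```

If the logged learning rate stays at zero or jumps straight to the peak, the warmup settings in the config are worth checking.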

**Example of healthy progress:**
- Epoch 0-10: Loss gradually decreasing
- Epoch 10-50: Loss continuing to decrease, may stabilize
- Epoch 50+: Loss should be near minimum (~0.01 or lower)
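
Logged loss values are noisy, so progress is easier to judge from a smoothed curve. The sketch below is one way to do that under stated assumptions: `moving_average` and `is_improving` are hypothetical helpers, the window size is arbitrary, and the loss values are placeholders for illustration rather than real training output.

```python
def moving_average(values, window=5):
    """Trailing moving average; early entries average over what is available."""
    smoothed = []
    for i in range(len(values)):
        chunk = values[max(0, i - window + 1): i + 1]
        smoothed.append(sum(chunk) / len(chunk))
    return smoothed


def is_improving(losses, window=5):
    """True if the latest smoothed loss is below the one `window` entries back."""
    if len(losses) <= window:
        return True  # too little history to judge
    smoothed = moving_average(losses, window)
    return smoothed[-1] < smoothed[-1 - window]


# Placeholder values for illustration only, not real training output.
example_losses = [0.9, 0.7, 0.75, 0.5, 0.45, 0.3, 0.32, 0.2, 0.15, 0.1]
print(is_improving(example_losses))
```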

### What Happens After Training Completes ✅

Once training finishes, you'll have:
1. ✅ `output/arabic_reptext/final_model/` - The trained checkpoint
2. ✅ Checkpoint built with `in_channels: 64` (perfect match for inference)
3. ✅ Ready for inference - no more architecture mismatches!

**Then run inference in the notebook:**
```python
# Run the model loading cell → it will load final_model
# Run the inference cell → should work without errors!
image = generate_image("السلام عليكم")
```

### Critical Issue (RESOLVED): "mat1 and mat2 shapes cannot be multiplied (1024x64 and 16x3072)"

**Status: FIXED** ✅

**What was wrong:**
- Your checkpoint was trained with `in_channels: 16`
- The inference pipeline always provides 64-dimensional packed latents
- Dimension mismatch caused the error

**What you did to fix it:**
1. Deleted old checkpoint: `rm -rf output/arabic_reptext/`
2. Started fresh training with current config: `in_channels: 64`
3. New checkpoint will be compatible with inference

**You should not see this error again** as long as training and inference use the same `in_channels` setting.

### Other Common Issues

**Issue: Out of memory during inference**
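- Solution: Reduce the image size in `prepare_glyph_and_controls` (and the matching `height`/`width` passed to `pipe`)
- Note: Lowering `num_inference_steps` (for example 20-30 instead of 50) shortens generation time but does not reduce peak memory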

**Issue: Poor quality results**
- Solution: Increase `num_inference_steps` (try 50-100)
- Make sure training completed with loss converging to low values

**Issue: "Model not found" error**
- Make sure training completed: check `./output/arabic_reptext/final_model/`
- Or check for intermediate checkpoints: `./output/arabic_reptext/checkpoint-XXXX/`
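
To see which checkpoint is available, a small helper can look for `final_model` first and fall back to the newest `checkpoint-XXXX` directory. This is a sketch: `find_checkpoint` is a hypothetical name, and it assumes the `XXXX` suffix is a numeric step count.

```python
import re
from pathlib import Path


def find_checkpoint(output_dir="./output/arabic_reptext"):
    """Return final_model if present, else the highest-numbered checkpoint-N, else None."""
    root = Path(output_dir)
    final_model = root / "final_model"
    if final_model.is_dir():
        return final_model

    candidates = []
    if root.is_dir():
        for path in root.iterdir():
            match = re.fullmatch(r"checkpoint-(\d+)", path.name)
            if path.is_dir() and match:
                candidates.append((int(match.group(1)), path))
    if not candidates:
        return None
    # Pick the checkpoint with the largest step number.
    return max(candidates, key=lambda item: item[0])[1]


checkpoint = find_checkpoint()
print(checkpoint if checkpoint else "No checkpoint found yet")
```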

### Debug Information

## Checkpoint Architecture Diagnostics

Use this cell to verify what the loaded ControlNet expects vs what the pipeline provides.

**If you see a mismatch warning, you likely need to retrain the model** with the current `train_config.yaml` to ensure compatibility.
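
The comparison itself can be automated. The sketch below uses a hypothetical helper (`check_embedder_compat` is not part of the notebook). It assumes the pipeline supplies 64-feature packed latents, as described in the troubleshooting notes, and that the model exposes `x_embedder.in_features`, the attribute the diagnostic cell reads.

```python
from types import SimpleNamespace


def check_embedder_compat(model, expected_features=64):
    """Compare model.x_embedder.in_features with the packed latent width.

    Returns (ok, message). expected_features=64 is an assumption
    (16 VAE channels x 2x2 packing).
    """
    actual = model.x_embedder.in_features
    if actual == expected_features:
        return True, f"OK: x_embedder expects {actual} features"
    return False, (
        f"Mismatch: x_embedder expects {actual} features, "
        f"pipeline provides {expected_features}; retraining is likely needed"
    )


# Stand-in object for demonstration; in the notebook, pass `controlnet` instead.
fake_model = SimpleNamespace(x_embedder=SimpleNamespace(in_features=16))
ok, message = check_embedder_compat(fake_model)
print(message)
```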

In [None]:
# Check ControlNet Architecture
if 'controlnet' in globals():
    print("=" * 60)
    print("ControlNet Architecture")
    print("=" * 60)
    print(f"Input Channels (in_channels): {controlnet.config.in_channels}")
    print(f"Output Channels (out_channels): {controlnet.out_channels}")
    print(f"Inner Dim: {controlnet.inner_dim}")
    print(f"Transformer Blocks: {controlnet.config.num_layers}")
    print(f"Single Transformer Blocks: {controlnet.config.num_single_layers}")
    
    # Check x_embedder dimensions
    print(f"\nX-Embedder Layer (input embedder):")
    print(f"  Input Features: {controlnet.x_embedder.in_features}")
    print(f"  Output Features: {controlnet.x_embedder.out_features}")
    print(f"\nContext Embedder Layer:")
    print(f"  Input Features: {controlnet.context_embedder.in_features}")
    print(f"  Output Features: {controlnet.context_embedder.out_features}")
    print("=" * 60)
else:
    print("❌ ControlNet not loaded yet. Run the model loading cell first.")


In [None]:
# Debug Information
import os
from pathlib import Path

import torch
import yaml

print("=" * 60)
print("INFERENCE DEBUG INFO")
print("=" * 60)

# Check config
print("\n1. Training Config:")
if Path("train_config.yaml").exists():
    with open("train_config.yaml", 'r') as f:
        config = yaml.safe_load(f)
    print(f"   ✓ Config found")
    print(f"   - Image size: {config['data']['image_size']}")
    print(f"   - Batch size: {config['data']['batch_size']}")
    print(f"   - ControlNet layers: {config['model']['controlnet_config']['num_layers']}")
    print(f"   - ControlNet single layers: {config['model']['controlnet_config']['num_single_layers']}")
    print(f"   - Text seq len: {config['model'].get('text_seq_len', 'Not set')}")
else:
    print(f"   ✗ Config not found")

# Check model checkpoint
print("\n2. Model Checkpoint:")
output_dir = config['output']['output_dir'] if 'config' in locals() else "./output/arabic_reptext"
if os.path.exists(output_dir):
    items = os.listdir(output_dir)
    print(f"   ✓ Output directory found: {output_dir}")
    print(f"   - Contents: {items}")
    
    final_model = os.path.join(output_dir, "final_model")
    if os.path.exists(final_model):
        print(f"   ✓ Final model found: {final_model}")
    else:
        print(f"   ⚠ Final model not found (training may still be in progress)")
else:
    print(f"   ✗ Output directory not found: {output_dir}")

# Check GPU
print("\n3. GPU Status:")
print(f"   - CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"   - GPU: {torch.cuda.get_device_name()}")
    print(f"   - Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    print(f"   - Memory used: {torch.cuda.memory_allocated() / 1e9:.1f} GB")

# Check fonts
print("\n4. Arabic Fonts:")
if os.path.exists("./arabic_fonts"):
    fonts = [f for f in os.listdir("./arabic_fonts") if f.endswith(('.ttf', '.otf'))]
    print(f"   ✓ Fonts directory found")
    print(f"   - Number of fonts: {len(fonts)}")
    if fonts:
        print(f"   - Sample: {fonts[0]}")
else:
    print(f"   ✗ Fonts directory not found")

print("\n" + "=" * 60)

## Summary & Next Steps

### Completed ✅
- ✅ Installed dependencies
- ✅ Downloaded Arabic fonts
- ✅ Prepared training dataset
- ✅ Configured training with proper ControlNet architecture
- ✅ Set up inference with matching model config
- ✅ Created simplified inference scripts

### What You've Learned
1. How to prepare Arabic text training data
2. How to configure and train ControlNet with FLUX
3. How to handle GPU memory constraints (48GB setup)
4. How to run inference with the trained model

### To TRAIN the Model
Run in terminal:
```bash
# Single GPU
accelerate launch train_arabic.py --config train_config.yaml

# Dual GPU (Recommended for 2x48GB)
accelerate config  # Set Number of processes: 2
accelerate launch --num_processes 2 train_arabic.py --config train_config.yaml
```

### To RUN INFERENCE
Option 1 - Use Notebook (Already configured):
```python
# Run the cells above to load model and generate images
image = generate_image("السلام عليكم")
```

Option 2 - Use Terminal Script:
```bash
python infer_simple.py --text "السلام عليكم" --num_steps 50
```

### Key Files
- `train_config.yaml` - Training configuration (single source of truth)
- `train_arabic.py` - Training script
- `infer_simple.py` - Simplified inference script
- `arabic_training_quickstart.ipynb` - This notebook

### Tips for Best Results
1. **Training**: Use 2 GPUs for faster training
2. **Inference**: Increase the step count to 50-100 for higher quality (`num_inference_steps` in the notebook, `--num_steps` for `infer_simple.py`)
3. **Memory**: If OOM, reduce `image_size` or `batch_size` in config
4. **Fonts**: Use diverse Arabic fonts for better generalization
5. **Dataset**: More training samples = better results

### Resources
- [RepText Paper](https://arxiv.org/abs/2504.19724)
- [TRAINING_GUIDE.md](TRAINING_GUIDE.md) - Detailed documentation
- [TRAINING_CONFIG_GUIDE.md](TRAINING_CONFIG_GUIDE.md) - Configuration guide