# RepText Arabic Training - Quick Start Guide

This notebook demonstrates how to pretrain RepText for Arabic text generation.

## Prerequisites
- NVIDIA GPU with at least 24GB VRAM
- Python 3.8+
- CUDA 11.7+

## Step 1: Install Dependencies

In [None]:
!pip install -r requirements.txt

## Step 2: Download Arabic Fonts

In [None]:
# Download recommended Arabic fonts from Google Fonts
!python download_arabic_fonts.py

In [None]:
# Verify fonts were downloaded
import os
fonts = [f for f in os.listdir('arabic_fonts') if f.endswith(('.ttf', '.otf'))]
print(f"Found {len(fonts)} fonts:")
for font in fonts:
    print(f"  - {font}")
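Steps 3 and 12 load `./arabic_fonts/Amiri-Regular.ttf` by name, so it helps to confirm that this specific file was downloaded before going further. Below is a small sketch. The helpers `list_fonts` and `check_required_font` are illustrative and are not part of the project's scripts.

```python
import os
from typing import List

# Steps 3 and 12 load this file by name; change it if you render with another font.
REQUIRED_FONT = "Amiri-Regular.ttf"


def list_fonts(font_dir: str) -> List[str]:
    """Sorted .ttf/.otf file names in font_dir (empty list if the directory is missing)."""
    if not os.path.isdir(font_dir):
        return []
    return sorted(f for f in os.listdir(font_dir) if f.lower().endswith((".ttf", ".otf")))


def check_required_font(font_dir: str, required: str = REQUIRED_FONT) -> bool:
    """Print and return whether the font used by the rendering cells is present."""
    present = required in list_fonts(font_dir)
    print(f"{required}: {'✓ found' if present else '✗ missing'} in {font_dir}")
    return present


check_required_font("arabic_fonts")
```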

## Step 3: Test Font Rendering

In [None]:
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt

# Test Arabic text rendering
test_text = "مرحبا بكم في RepText"
font_path = "./arabic_fonts/Amiri-Regular.ttf"

img = Image.new('RGB', (600, 150), color='white')
draw = ImageDraw.Draw(img)
font = ImageFont.truetype(font_path, 60)
draw.text((50, 40), test_text, font=font, fill='black')

plt.figure(figsize=(10, 3))
plt.imshow(img)
plt.axis('off')
plt.title('Arabic Text Rendering Test')
plt.show()

print("✓ Rendered the test image: check that the Arabic letters are joined and read right to left.")
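The test image only shows that the font loads. Pillow shapes Arabic correctly (joining letters and laying them out right to left) only when its Raqm text-layout engine is available. Without Raqm, it draws isolated letter forms in left-to-right order, so the training glyph images would not contain real Arabic script. Below is a minimal sketch of that check. It assumes a Pillow version that provides `PIL.features.check`, and `has_raqm` is a hypothetical helper name.

```python
def has_raqm() -> bool:
    """Return True if Pillow can use the Raqm layout engine (needed for Arabic shaping).

    Returns False if Pillow itself cannot be imported.
    """
    try:
        from PIL import features
    except ImportError:
        return False
    # features.check() can return None when Pillow's FreeType module is missing;
    # bool() normalizes that to False.
    return bool(features.check("raqm"))


if has_raqm():
    print("✓ Raqm available: Arabic shaping and right-to-left layout are supported")
else:
    print("⚠ Raqm not available: Arabic text will render unshaped and left-to-right")
```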

## Step 4: Prepare Training Dataset

This will generate synthetic training samples with:
- Glyph images (rendered Arabic text)
- Position maps (location heatmaps)
- Binary masks (text regions)
- Canny edges

In [None]:
# For quick testing, use a small number of samples
# For actual training, use 10000-50000
NUM_SAMPLES = 100  # Change to 10000 for real training

!python prepare_arabic_dataset.py \
    --output_dir ./arabic_training_data \
    --font_dir ./arabic_fonts \
    --text_file ./arabic_texts.txt \
    --num_samples {NUM_SAMPLES} \
    --width 1024 \
    --height 1024 \
    --min_font_size 60 \
    --max_font_size 120

## Step 5: Visualize Training Samples

In [None]:
import json
from PIL import Image
import matplotlib.pyplot as plt

# Load a sample
sample_dir = "./arabic_training_data/sample_000000"

# Load images
glyph = Image.open(f"{sample_dir}/glyph.png")
position = Image.open(f"{sample_dir}/position.png")
mask = Image.open(f"{sample_dir}/mask.png")
canny = Image.open(f"{sample_dir}/canny.png")

# Load metadata
with open(f"{sample_dir}/metadata.json", 'r', encoding='utf-8') as f:
    metadata = json.load(f)

# Visualize
fig, axes = plt.subplots(2, 2, figsize=(12, 12))

axes[0, 0].imshow(glyph)
axes[0, 0].set_title(f"Glyph: {metadata['text']}")
axes[0, 0].axis('off')

axes[0, 1].imshow(position)
axes[0, 1].set_title("Position Map")
axes[0, 1].axis('off')

axes[1, 0].imshow(mask, cmap='gray')
axes[1, 0].set_title("Text Mask")
axes[1, 0].axis('off')

axes[1, 1].imshow(canny)
axes[1, 1].set_title("Canny Edges")
axes[1, 1].axis('off')

plt.suptitle(f"Training Sample - Font Size: {metadata['font_size']}px")
plt.tight_layout()
plt.show()
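The cell above inspects a single sample. To confirm that every generated sample has the files and metadata fields this notebook reads, here is a sketch. It assumes the `sample_*` directory naming seen above, and the helper names are hypothetical.

```python
import json
import os
from typing import Dict, List

# File names loaded by the visualization cell above.
EXPECTED_FILES = ["glyph.png", "position.png", "mask.png", "canny.png", "metadata.json"]
# Metadata keys the visualization cell reads.
EXPECTED_KEYS = ["text", "font_size"]


def validate_sample(sample_dir: str) -> List[str]:
    """Return the problems found in one sample directory (empty list means OK)."""
    problems = [
        f"missing {name}"
        for name in EXPECTED_FILES
        if not os.path.isfile(os.path.join(sample_dir, name))
    ]
    meta_path = os.path.join(sample_dir, "metadata.json")
    if os.path.isfile(meta_path):
        with open(meta_path, "r", encoding="utf-8") as f:
            metadata = json.load(f)
        problems += [f"metadata missing '{k}'" for k in EXPECTED_KEYS if k not in metadata]
    return problems


def validate_dataset(data_dir: str) -> Dict[str, List[str]]:
    """Map each failing sample_* directory name to its list of problems."""
    results = {}
    for name in sorted(os.listdir(data_dir)):
        path = os.path.join(data_dir, name)
        if name.startswith("sample_") and os.path.isdir(path):
            problems = validate_sample(path)
            if problems:
                results[name] = problems
    return results


# Usage (after Step 4):
# bad = validate_dataset("./arabic_training_data")
# print(f"{len(bad)} samples with problems", bad)
```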

## Step 6: Test Dataset Loading

In [None]:
from arabic_dataset import create_dataloaders

# Create dataloaders
train_loader, val_loader = create_dataloaders(
    data_dir='./arabic_training_data',
    batch_size=2,
    num_workers=0,
    image_size=(1024, 1024),
    train_ratio=0.9
)

# Test loading a batch
batch = next(iter(train_loader))

print("Batch contents:")
print(f"  Glyph shape: {batch['glyph'].shape}")
print(f"  Position shape: {batch['position'].shape}")
print(f"  Mask shape: {batch['mask'].shape}")
print(f"  Canny shape: {batch['canny'].shape}")
print(f"  Text samples: {batch['text']}")
print(f"  Font sizes: {batch['font_size']}")
print("\n✓ Dataset loading works correctly!")
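The train/validation split determines how many iterations each epoch runs, and that count is what the training progress bar in Step 9.5 reports. Below is a sketch of the arithmetic. It assumes that `create_dataloaders` puts `int(num_samples * train_ratio)` samples in the training split and keeps the last partial batch. Check `arabic_dataset.py` for the actual rule; both helper names are hypothetical.

```python
import math


def split_sizes(num_samples: int, train_ratio: float):
    """Train/val sizes, ASSUMING the training split is int(num_samples * train_ratio)."""
    n_train = int(num_samples * train_ratio)
    return n_train, num_samples - n_train


def iterations_per_epoch(n_train: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of batches a DataLoader yields per epoch over n_train samples."""
    return n_train // batch_size if drop_last else math.ceil(n_train / batch_size)


n_train, n_val = split_sizes(100, 0.9)
print(n_train, n_val)                    # 90 10
print(iterations_per_epoch(n_train, 2))  # 45 (batch_size=2, as in the cell above)
print(iterations_per_epoch(n_train, 1))  # 90
```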

## Step 7: Configure Accelerate for Training

In [None]:
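# `accelerate config` asks questions interactively, which a notebook `!` command generally can't answer.
# Run it in a terminal instead, or write a default single-machine config non-interactively:
!accelerate config default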

## Step 8: Review Training Configuration

In [None]:
import yaml

with open('train_config.yaml', 'r') as f:
    config = yaml.safe_load(f)

print("Training Configuration:")
print(yaml.dump(config, default_flow_style=False))
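Step 12 reads and rewrites three keys of this config. Here is a sketch for checking that they exist before you launch a long run. The key list comes from the Step 12 cell, not from a published schema for `train_config.yaml`, and `missing_keys` is a hypothetical helper.

```python
from typing import Any, Dict, List

# Nested keys read or written by the Step 12 cell.
REQUIRED_KEYS = [
    ("model", "pretrained_controlnet"),
    ("training", "num_epochs"),
    ("output", "output_dir"),
]


def missing_keys(config: Dict[str, Any], required=REQUIRED_KEYS) -> List[str]:
    """Return dotted paths of required nested keys that are absent from config."""
    missing = []
    for path in required:
        node = config
        for key in path:
            if not isinstance(node, dict) or key not in node:
                missing.append(".".join(path))
                break
            node = node[key]
    return missing


# Usage with the `config` dict loaded above:
# print(missing_keys(config) or "✓ all keys used by Step 12 are present")
```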

## Step 9: Launch Training

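**Note:** Training takes a long time. Consider running it in a terminal, ideally inside `tmux` or `screen`, rather than in the notebook: restarting or interrupting the kernel stops a run launched from a cell.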

For terminal:
```bash
accelerate launch train_arabic.py --config train_config.yaml
```

Or use the automated script:
```bash
./train_arabic.sh
```

In [None]:
# For notebook training (not recommended for long training runs)
# Uncomment to run:
# !accelerate launch train_arabic.py --config train_config.yaml

print("It's recommended to run training in a terminal or tmux session:")
print("  accelerate launch train_arabic.py --config train_config.yaml")

## Step 9.5: Monitor Training Progress (RunPod)

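⏳ **Training Status: IN PROGRESS (Epoch 2/100)**

Your training started successfully! You can see it processing batches in the terminal.

**Understanding the Progress:**
- The progress bar `23/90 [00:33<01:37...]` counts iterations (batches): 23 of this epoch's 90 iterations are done
- With the 90-sample training split (100 samples × `train_ratio=0.9`), 90 iterations per epoch means one sample per iteration
- After all 90 iterations complete, the next epoch begins
- **Total:** 100 epochs (configured in `train_config.yaml`)
- **Time per epoch:** ~2 minutes at the rate shown (90 × ~1.45 s/it ≈ 130 s, matching 00:33 elapsed + 01:37 remaining)
- **Total training time:** ~3.6 hours for 100 epochs (100 × ~130 s), plus any validation or checkpoint-saving time

**Key Metrics to Watch:**
```
Epoch 2: 26% | 23/90 [00:33<01:37, 1.45s/it, diffusion_loss=0, loss=0, lr=4e-8]
```
- `Epoch 2`: current epoch
- `26%` and `23/90`: progress through the current epoch, as iterations completed / iterations per epoch
- `00:33<01:37`: time elapsed < estimated time remaining for this epoch
- `1.45s/it`: seconds per iteration
- `diffusion_loss`, `loss`: training loss. A loss that stays at exactly `0` is not expected for diffusion training and is worth investigating (for example, a loss term that is not being computed, or logging code that rounds small values to zero)
- `lr=4e-8`: current learning rate; compare it with the learning rate and schedule set in `train_config.yaml`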
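The time estimates above come straight from the numbers in the progress bar. Below is a small sketch that turns one copied progress line into a total-time estimate. It assumes tqdm's default `elapsed<remaining` format and that every epoch takes as long as the sampled one (validation or checkpoint saving would add time). The helper names are hypothetical.

```python
import re

# Matches tqdm's "done/total [elapsed<remaining" fragment, e.g. "23/90 [00:33<01:37".
TQDM_RE = re.compile(
    r"(?P<done>\d+)/(?P<total>\d+)\s*\[(?P<elapsed>[\d:]+)<(?P<remaining>[\d:]+)"
)


def to_seconds(hms: str) -> int:
    """Convert tqdm's 'MM:SS' or 'H:MM:SS' into seconds."""
    seconds = 0
    for part in hms.split(":"):
        seconds = seconds * 60 + int(part)
    return seconds


def estimate_total_hours(line: str, num_epochs: int) -> float:
    """Estimate total run time, assuming every epoch takes as long as this one."""
    m = TQDM_RE.search(line)
    if m is None:
        raise ValueError("no tqdm progress fragment found")
    epoch_seconds = to_seconds(m["elapsed"]) + to_seconds(m["remaining"])
    return epoch_seconds * num_epochs / 3600


line = "Epoch 2: 26% | 23/90 [00:33<01:37, 1.45s/it, diffusion_loss=0, loss=0, lr=4e-8]"
print(f"{estimate_total_hours(line, 100):.2f} h")  # 130 s/epoch × 100 epochs → 3.61 h
```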

**Monitor Training:**

Run in your RunPod terminal:
```bash
# Check if training is still running (pgrep, unlike `ps aux | grep`, won't match itself)
pgrep -af train_arabic.py

# Check checkpoint directory growth (as training progresses)
watch -n 30 "ls -lh output/arabic_reptext/ && echo '---' && du -sh output/arabic_reptext/"

# Check if intermediate checkpoints are being saved
ls -lh output/arabic_reptext/checkpoint-*/
```
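`ls` lists `checkpoint-*` directories alphabetically, so `checkpoint-1000` sorts before `checkpoint-500`. If you want to pick the newest checkpoint from Python, here is a sketch that sorts numerically. It assumes the `checkpoint-<step>` naming used by the `ls` command above, and both helper names are hypothetical.

```python
import os
import re
from typing import List, Optional

CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")


def list_checkpoints(output_dir: str) -> List[str]:
    """Return checkpoint-<N> directory names sorted by N (numeric, not alphabetical)."""
    if not os.path.isdir(output_dir):
        return []
    found = []
    for name in os.listdir(output_dir):
        m = CHECKPOINT_RE.match(name)
        if m and os.path.isdir(os.path.join(output_dir, name)):
            found.append((int(m.group(1)), name))
    return [name for _, name in sorted(found)]


def latest_checkpoint(output_dir: str) -> Optional[str]:
    """Path of the highest-numbered checkpoint, or None if there is none yet."""
    names = list_checkpoints(output_dir)
    return os.path.join(output_dir, names[-1]) if names else None


# Usage:
# print(list_checkpoints("output/arabic_reptext"))
# print(latest_checkpoint("output/arabic_reptext"))
```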

**When Training Completes:**
- Final checkpoint saved to: `output/arabic_reptext/final_model/`
- Should contain:
  - `config.json` (with `in_channels: 64`)
  - `diffusion_pytorch_model.safetensors` (~5GB)
- This is the directory layout that `FluxControlNetModel.from_pretrained()` loads in Step 12, so the final checkpoint can be used directly for inference ✅
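To confirm the final checkpoint matches the description above before moving on, here is a sketch that checks only those two files. The helper `check_final_model` is hypothetical. It assumes an unsharded `diffusion_pytorch_model.safetensors`, as listed above, and does not load the weights.

```python
import json
import os
from typing import List


def check_final_model(model_dir: str, expected_in_channels: int = 64) -> List[str]:
    """Return problems with a saved ControlNet directory (empty list = looks complete)."""
    problems = []
    config_path = os.path.join(model_dir, "config.json")
    weights_path = os.path.join(model_dir, "diffusion_pytorch_model.safetensors")
    if not os.path.isfile(config_path):
        problems.append("config.json not found")
    else:
        with open(config_path, "r", encoding="utf-8") as f:
            in_channels = json.load(f).get("in_channels")
        if in_channels != expected_in_channels:
            problems.append(f"in_channels is {in_channels}, expected {expected_in_channels}")
    if not os.path.isfile(weights_path):
        problems.append("diffusion_pytorch_model.safetensors not found")
    else:
        # Report the size so it can be compared with the ~5GB noted above
        print(f"weights: {os.path.getsize(weights_path) / 1024**3:.2f} GB")
    return problems


# Usage:
# print(check_final_model("output/arabic_reptext/final_model") or "✓ final_model looks complete")
```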

---


## Step 10: Monitor Training (Optional - with W&B)

In [None]:
# Install wandb
# !pip install wandb
# !wandb login

# Then run with --use_wandb flag:
# !accelerate launch train_arabic.py --config train_config.yaml --use_wandb

## Step 11: Test Your Trained Model

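After training completes, compare the generated images. This cell reads the baseline and per-epoch images that Step 12 saves to `./training_progress/`, so run it after Step 12: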

In [None]:
import matplotlib.pyplot as plt
from PIL import Image
import os

# Collect all progress images
progress_dir = "./training_progress"
image_files = []

# Add baseline
if os.path.exists(f"{progress_dir}/baseline_original.jpg"):
    image_files.append(("Baseline\n(Original RepText)", f"{progress_dir}/baseline_original.jpg"))

# Add epoch images (every 2 epochs: 2, 4, 6, 8, 10)
for epoch in range(2, 11, 2):
    img_path = f"{progress_dir}/epoch_{epoch}.jpg"
    if os.path.exists(img_path):
        image_files.append((f"Epoch {epoch}", img_path))

# Create comparison grid
if len(image_files) > 0:
    n_images = len(image_files)
    n_cols = 3
    n_rows = (n_images + n_cols - 1) // n_cols
    
    # squeeze=False always returns a 2D array of Axes, so flatten() works for any grid size
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(18, 6 * n_rows), squeeze=False)
    axes = axes.flatten()

    for idx, (title, img_path) in enumerate(image_files):
        img = Image.open(img_path)
        axes[idx].imshow(img)
        axes[idx].set_title(title, fontsize=14, fontweight='bold')
        axes[idx].axis('off')

    # Hide unused subplots
    for idx in range(n_images, len(axes)):
        axes[idx].axis('off')

    # matplotlib's default text rendering does not apply Arabic shaping or right-to-left
    # ordering, so the Arabic in this title may appear as disconnected letters
    test_text = "السلام عليكم ورحمة الله وبركاته"
    plt.suptitle(f'Training Progress: {test_text}', fontsize=16, fontweight='bold', y=0.995)
    plt.tight_layout()
    plt.savefig('./training_progress/comparison_grid.jpg', dpi=150, bbox_inches='tight')
    plt.show()

    print(f"\n✅ Compared {len(image_files)} images")
    print("📊 Grid saved: ./training_progress/comparison_grid.jpg")
else:
    print("❌ No images found. Run Step 12 first to generate them.")

## Step 12: Train with Progress Monitoring

This cell will:

1. Generate a baseline image with the original RepText model
2. Train for 10 epochs, generating a test image every 2 epochs
3. Save the final model
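The cell below splits the run into separate `accelerate launch` calls of `epochs_per_step` epochs each. Each launch starts from the previous launch's `final_model`. Below is a sketch of just the scheduling arithmetic, mirroring the cell's `while` loop; the helper name `plan_training_steps` is hypothetical.

```python
from typing import List, Tuple


def plan_training_steps(total_epochs: int, epochs_per_step: int) -> List[Tuple[int, int]]:
    """(first_epoch, last_epoch) for each training launch, mirroring the cell's while loop."""
    plan = []
    done = 0
    while done < total_epochs:
        n = min(epochs_per_step, total_epochs - done)
        plan.append((done + 1, done + n))
        done += n
    return plan


for step, (first, last) in enumerate(plan_training_steps(10, 2), start=1):
    # After each launch the cell saves a test image named epoch_<last>.jpg
    print(f"step {step}: epochs {first}-{last} -> training_progress/epoch_{last}.jpg")
```

With the defaults this prints five launches ending at epochs 2, 4, 6, 8 and 10, which are the `epoch_<N>.jpg` files the comparison cell in Step 11 looks for.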

In [None]:
import os
import yaml
import torch
import numpy as np
import cv2
import subprocess
from PIL import Image, ImageDraw, ImageFont
from IPython.display import display, clear_output
from controlnet_flux import FluxControlNetModel
from pipeline_flux_controlnet import FluxControlNetPipeline

# Configuration
test_text = "السلام عليكم ورحمة الله وبركاته"
test_prompt = "a street sign in city"
total_epochs = 10
epochs_per_step = 2  # Generate image every 2 epochs

print("="*80)
print("REPTEXT ARABIC FINE-TUNING WITH PROGRESS MONITORING")
print("="*80)
print(f"Test text: {test_text}")
print(f"Total epochs: {total_epochs}")
print(f"Image generation interval: every {epochs_per_step} epochs")
print("="*80)

# Helper functions
def canny(img):
    low_threshold = 50
    high_threshold = 100
    img = cv2.Canny(img, low_threshold, high_threshold)
    img = img[:, :, None]
    img = 255 - np.concatenate([img, img, img], axis=2)
    return img

def generate_test_image(controlnet_path, epoch_label, text=test_text, prompt=test_prompt):
    """Generate test image with current model"""
    print(f"\n📸 Generating image for {epoch_label}...")
    
    try:
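        # Load models
        # FLUX.1-dev's 12B-parameter transformer alone takes ~24 GB in bf16, so the full
        # pipeline does not fit on a 24 GB GPU. If this runs out of memory and the pipeline
        # class inherits diffusers' DiffusionPipeline, replace .to("cuda") with
        # pipe.enable_model_cpu_offload().
        base_model = "black-forest-labs/FLUX.1-dev"
        controlnet = FluxControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.bfloat16)
        pipe = FluxControlNetPipeline.from_pretrained(
            base_model, controlnet=controlnet, torch_dtype=torch.bfloat16
        ).to("cuda")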
        
        # Setup
        width, height = 1024, 512
        font_path = "./arabic_fonts/Amiri-Regular.ttf"
        font_size = 80
        font = ImageFont.truetype(font_path, font_size)
        
        # Prepare control images
        text_position = (200, 200)
        text_color = (255, 255, 255)
        
        control_image_glyph = Image.new("RGB", (width, height), (0, 0, 0))
        draw = ImageDraw.Draw(control_image_glyph)
        draw.text(text_position, text, font=font, fill=text_color)
        bbox = draw.textbbox(text_position, text, font=font)
        
        # Position map
        control_position = np.zeros([height, width], dtype=np.uint8)
        control_position[bbox[1]:bbox[3], bbox[0]:bbox[2]] = 255
        control_position = Image.fromarray(control_position)
        
        # Regional mask
        control_mask_np = np.zeros([height, width], dtype=np.uint8)
        control_mask_np[bbox[1]-5:bbox[3]+5, bbox[0]-5:bbox[2]+5] = 255
        control_mask = Image.fromarray(control_mask_np)
        
        # Glyph
        control_glyph = np.array(control_image_glyph)
        control_glyph = Image.fromarray(control_glyph).convert("RGB")
        
        # Canny
        control_image = canny(cv2.cvtColor(np.array(control_image_glyph), cv2.COLOR_RGB2BGR))
        control_image = Image.fromarray(cv2.cvtColor(control_image, cv2.COLOR_BGR2RGB))
        
        # Generate
        full_prompt = f"{prompt}, '{text}', filmfotos, film grain, reversal film photography"
        generator = torch.Generator(device="cuda").manual_seed(42)
        
        image = pipe(
            full_prompt,
            control_image=[control_image],
            control_position=[control_position],
            control_mask=[control_mask],
            control_glyph=control_glyph,
            controlnet_conditioning_scale=1.0,
            controlnet_conditioning_step=30,
            width=width,
            height=height,
            num_inference_steps=30,
            guidance_scale=3.5,
            generator=generator,
        ).images[0]
        
        # Save
        os.makedirs("./training_progress", exist_ok=True)
        output_path = f"./training_progress/{epoch_label}.jpg"
        image.save(output_path)
        
        print(f"   Saved: {output_path}")
        display(image)
        
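        return image
    except Exception as e:
        print(f"   ⚠️ Error generating image: {e}")
        return None
    finally:
        # Release the FLUX pipeline even if generation failed, so its VRAM is
        # free before the next training subprocess starts.
        import gc
        pipe = controlnet = None
        gc.collect()
        torch.cuda.empty_cache()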

# STEP 1: Generate baseline with original RepText
print("\n" + "="*80)
print("STEP 1: BASELINE - Original RepText Model")
print("="*80)
generate_test_image("Shakker-Labs/RepText", "baseline_original")

# STEP 2: Update config for training
print("\n" + "="*80)
print("STEP 2: PREPARING TRAINING CONFIG")
print("="*80)

import shutil

CONFIG_PATH = "train_config.yaml"
CONFIG_BACKUP = "train_config.yaml.bak"

# Back up the config file as-is. The `finally` block below restores it even if training
# fails or the cell is interrupted; a file copy also keeps the comments and key order
# that a yaml.safe_load/yaml.dump round trip would drop.
shutil.copyfile(CONFIG_PATH, CONFIG_BACKUP)

epoch_count = 0
try:
    with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    config['model']['pretrained_controlnet'] = "Shakker-Labs/RepText"
    config['training']['num_epochs'] = epochs_per_step

    with open(CONFIG_PATH, 'w', encoding='utf-8') as f:
        yaml.dump(config, f, default_flow_style=False, allow_unicode=True)

    print(f"✅ Config updated: {epochs_per_step} epochs per training step")

    output_dir = config['output']['output_dir']
    checkpoint_src = os.path.join(output_dir, "final_model")  # written by each training launch
    checkpoint_dst = os.path.join(output_dir, "checkpoint")   # starting weights for the next launch

    # STEP 3: Training loop with progress monitoring
    print("\n" + "="*80)
    print(f"STEP 3: TRAINING FOR {total_epochs} EPOCHS")
    print("="*80)

    while epoch_count < total_epochs:
        current_step = epoch_count // epochs_per_step + 1
        epochs_this_step = min(epochs_per_step, total_epochs - epoch_count)

        print(f"\n{'='*80}")
        print(f"TRAINING STEP {current_step}: Epochs {epoch_count+1}-{epoch_count+epochs_this_step}")
        print(f"{'='*80}")

        # Update config for this step
        with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)
        config['training']['num_epochs'] = epochs_this_step

        # After the first step, start from the previous step's weights instead of the
        # pretrained model. Only weights are passed on this way; optimizer and LR-schedule
        # state presumably start fresh with each launch.
        if epoch_count > 0 and os.path.exists(checkpoint_dst):
            config['model']['pretrained_controlnet'] = checkpoint_dst

        with open(CONFIG_PATH, 'w', encoding='utf-8') as f:
            yaml.dump(config, f, default_flow_style=False, allow_unicode=True)

        # Run training (output streams to the cell)
        print("\n🚀 Launching training...")
        result = subprocess.run(
            ["accelerate", "launch", "train_arabic.py", "--config", CONFIG_PATH]
        )
        if result.returncode != 0:
            print(f"\n❌ Training failed at step {current_step}")
            break

        epoch_count += epochs_this_step

        # Keep this step's final model as the starting point for the next step
        if os.path.exists(checkpoint_src):
            if os.path.exists(checkpoint_dst):
                shutil.rmtree(checkpoint_dst)
            shutil.copytree(checkpoint_src, checkpoint_dst)

        # Generate test image
        if os.path.exists(checkpoint_dst):
            generate_test_image(checkpoint_dst, f"epoch_{epoch_count}")

        print(f"\n✅ Completed {epoch_count}/{total_epochs} epochs")
finally:
    # Restore the original config (os.replace overwrites the modified file)
    os.replace(CONFIG_BACKUP, CONFIG_PATH)

# STEP 4: Summary
print("\n" + "="*80)
if epoch_count == total_epochs:
    print("🎉 TRAINING COMPLETE!")
else:
    print(f"⚠️ TRAINING STOPPED after {epoch_count}/{total_epochs} epochs")
print("="*80)
print("\n📁 Results:")
print("   Baseline: ./training_progress/baseline_original.jpg")
for i in range(epochs_per_step, epoch_count + 1, epochs_per_step):
    print(f"   Epoch {i}: ./training_progress/epoch_{i}.jpg")
print(f"\n📦 Final model: {checkpoint_src}/")
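In `generate_test_image`, the regional mask is the text bbox grown by 5 px through `control_mask_np[bbox[1]-5:bbox[3]+5, bbox[0]-5:bbox[2]+5]`. With the fixed position `(200, 200)` this is safe. If the text is moved near an edge, though, a negative start index makes NumPy count from the end of the array, and the slice can come out empty. Below is a sketch of a clamped version; the helper `pad_and_clamp` is hypothetical. Its result could be used as `l, t, r, b = pad_and_clamp(bbox, 5, width, height)` followed by `control_mask_np[t:b, l:r] = 255`.

```python
from typing import Tuple

Box = Tuple[int, int, int, int]  # (left, top, right, bottom), as returned by ImageDraw.textbbox


def pad_and_clamp(bbox: Box, pad: int, width: int, height: int) -> Box:
    """Grow a bbox by `pad` pixels on each side, clamped to the image bounds."""
    left, top, right, bottom = bbox
    return (
        max(left - pad, 0),
        max(top - pad, 0),
        min(right + pad, width),
        min(bottom + pad, height),
    )


# Near the top-left corner, unclamped top - 5 would be negative
print(pad_and_clamp((2, 3, 400, 90), 5, 1024, 512))       # (0, 0, 405, 95)
# Near the bottom-right corner, the padded box is cut at the image edge
print(pad_and_clamp((900, 450, 1030, 515), 5, 1024, 512))  # (895, 445, 1024, 512)
```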

## Step 13: Visualize All Progress Images

Compare baseline vs all training checkpoints.