# Fine-Tuning Stable Diffusion for Text Generation using TextOCR Dataset

## Overview
This notebook demonstrates fine-tuning Stable Diffusion v1.5 with LoRA adapters to improve text rendering capabilities using the TextOCR dataset.

**Flow:**
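1. Image + caption pairs (built from TextOCR annotations) are fed to the model as training input
2. The UNet is fine-tuned with LoRA adapters
3. Give a text prompt → the fine-tuned model generates an image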
4. Evaluate with OCR (Exact Match + Character Accuracy)

**Dataset:** TextOCR (facebook/textocr)
**Model:** Stable Diffusion v1.5 with LoRA fine-tuning
**Evaluation:** OCR-based readability metrics

## Cell 1: Installation and Setup

In [None]:
print("="*80)
print("SETUP: Installing Required Packages")
print("="*80)

!pip install -q diffusers transformers accelerate peft datasets pillow pytesseract python-Levenshtein opencv-python-headless

# Install Tesseract OCR for Colab
!apt-get install -y tesseract-ocr

import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from datasets import load_dataset
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, StableDiffusionPipeline
from transformers import CLIPTextModel, CLIPTokenizer
from peft import LoraConfig, get_peft_model
from accelerate import Accelerator
import cv2
import os
from tqdm import tqdm
import json
import warnings
warnings.filterwarnings('ignore')

# Configuration
# Hub id used for every Stable Diffusion component loaded in this notebook.
MODEL_NAME = "runwayml/stable-diffusion-v1-5"
# Directory that receives the LoRA adapter, training metadata and reports.
OUTPUT_DIR = "./text-render-lora-textocr"
# Token length that captions are padded/truncated to by the tokenizer call.
MAX_LENGTH = 77
# Training images are resized to IMAGE_SIZE x IMAGE_SIZE pixels.
IMAGE_SIZE = 512
LEARNING_RATE = 1e-4
TRAIN_STEPS = 800  # Increased for better convergence
# NOTE(review): BATCH_SIZE is not read anywhere in the visible code; the
# training loop feeds one sample per step via unsqueeze(0).
BATCH_SIZE = 1

# exist_ok=True makes re-running the cell safe when the directory exists.
os.makedirs(OUTPUT_DIR, exist_ok=True)
accelerator = Accelerator()

print("✓ Setup complete")
print(f"  Device: {accelerator.device}")
print(f"  PyTorch version: {torch.__version__}")

## Cell 2: Exploratory Data Analysis - TextOCR Dataset

In [None]:
print("\n" + "="*80)
print("STEP 1: EXPLORATORY DATA ANALYSIS - TextOCR Dataset")
print("="*80)

# Open the TextOCR stream. If the train split cannot be opened, report the
# error and retry with the validation split instead.
print("\nLoading TextOCR dataset...")
try:
    dataset = load_dataset("facebook/textocr", split="train", streaming=True)
except Exception as e:
    print(f"Note: {e}")
    print("Attempting alternative loading method...")
    dataset = load_dataset("facebook/textocr", split="validation", streaming=True)
else:
    print("✓ TextOCR dataset loaded (streaming mode)")

# Pull the first 30 streamed examples into memory for the analysis below,
# reporting progress after every 10th one.
print("\nSampling examples for analysis...")
samples = []
for position, example in enumerate(dataset, start=1):
    if position > 30:
        break
    samples.append(example)
    if position % 10 == 0:
        print(f"  Loaded {position} samples...")

print(f"\n✓ Sampled {len(samples)} examples")

# Analyze dataset structure
print("\n" + "-"*80)
print("Dataset Structure:")
print("-"*80)
# Guard the lookup: an empty stream would otherwise raise IndexError here.
if samples:
    print(f"Available keys: {samples[0].keys()}")
else:
    print("No samples were loaded - nothing to analyze")

# Display sample data: keys, image size and any text/caption fields of the
# first five sampled examples.
print("\nSample Examples:")
print("-"*80)
for i, sample in enumerate(samples[:5]):
    print(f"\nSample {i+1}:")
    print(f"  Keys: {sample.keys()}")

    if 'image' in sample:
        img = sample['image']
        print(f"  Image size: {img.size if hasattr(img, 'size') else 'N/A'}")

    if 'text' in sample:
        texts = sample['text']
        if isinstance(texts, list):
            print(f"  Texts found: {len(texts)} annotations")
            print(f"  Sample texts: {texts[:3]}")
        else:
            # Non-list text payloads are shown truncated to 100 characters.
            print(f"  Text: {str(texts)[:100]}")

    if 'caption' in sample:
        print(f"  Caption: {str(sample['caption'])[:100]}")

# Visualize samples: show up to six sampled images in a 2x3 grid.
print("\nVisualizing sample images...")
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

# Switch every panel off up front so that panels without an image (fewer than
# six samples, or a sample lacking an 'image' key) stay blank instead of
# showing an empty tick frame.
for ax in axes:
    ax.axis('off')

for ax, sample in zip(axes, samples):
    if 'image' not in sample:
        continue

    ax.imshow(sample['image'])

    # Title: first two annotations when 'text' is a non-empty list, otherwise
    # the truncated string form of whatever the field holds.
    title = "TextOCR Sample"
    if 'text' in sample:
        texts = sample['text']
        if isinstance(texts, list) and len(texts) > 0:
            title = f"Texts: {', '.join(str(t) for t in texts[:2])}"
        else:
            title = f"Text: {str(texts)[:30]}"

    ax.set_title(title, fontsize=8)

plt.tight_layout()
plt.savefig('textocr_samples_eda.png', dpi=150, bbox_inches='tight')
plt.show()

print("\n✓ EDA complete - samples saved to 'textocr_samples_eda.png'")

## Cell 3: Data Filtering and Preprocessing

In [None]:
print("\n" + "="*80)
print("STEP 2: DATA FILTERING AND PREPROCESSING")
print("="*80)

# Tokenizer from the same checkpoint as the model; used by
# preprocess_textocr_sample to turn captions into input ids.
tokenizer = CLIPTokenizer.from_pretrained(MODEL_NAME, subfolder="tokenizer")

def extract_text_from_sample(sample):
    """Collect the text strings carried by a TextOCR sample.

    Reads the optional 'text' field (a list of annotations or a single
    value) followed by the optional 'caption' field. Falsy entries are
    dropped and everything kept is converted with ``str``.

    Returns:
        A list of strings: 'text' entries first, then the caption.
    """
    collected = []

    raw_text = sample['text'] if 'text' in sample else None
    if isinstance(raw_text, list):
        collected += [str(item) for item in raw_text if item]
    elif raw_text:
        collected.append(str(raw_text))

    raw_caption = sample['caption'] if 'caption' in sample else None
    if raw_caption:
        collected.append(str(raw_caption))

    return collected

def is_good_textocr_sample(sample):
    """Decide whether a sample is suitable for text-rendering training.

    A sample qualifies when it carries between one and five text strings and
    at least one of them is 2-30 characters long, is not whitespace-only and
    contains an alphabetic character.

    Returns:
        True if the sample passes the filter, False otherwise.
    """
    candidates = extract_text_from_sample(sample)

    # Reject samples with no text at all or with too many annotations.
    if not 1 <= len(candidates) <= 5:
        return False

    return any(
        2 <= len(text) <= 30
        and text.strip()
        and any(ch.isalpha() for ch in text)
        for text in candidates
    )

def create_caption_from_textocr(sample):
    """Build a training caption that quotes the sample's longest text.

    One of four fixed phrasings is picked with ``np.random.choice``, so
    repeated calls on the same sample may return different captions.

    Returns:
        The chosen caption, or "An image with text" when the sample carries
        no text.
    """
    candidates = extract_text_from_sample(sample)
    if len(candidates) == 0:
        return "An image with text"

    # The longest string is the one quoted in the caption.
    longest = max(candidates, key=len)

    phrasings = [
        f"An image with the text '{longest}'",
        f"A photo showing '{longest}' written on it",
        f"Text that says '{longest}'",
        f"An image containing the word '{longest}'",
    ]
    return np.random.choice(phrasings)

def preprocess_textocr_sample(sample):
    """Convert a TextOCR sample into training tensors plus its caption.

    The image is converted to RGB, resized to IMAGE_SIZE x IMAGE_SIZE and
    scaled from 0..255 to -1..1 in channels-first layout. The caption is
    tokenized with padding/truncation to MAX_LENGTH tokens.

    Returns:
        A tuple ``(pixel_values, input_ids, caption, ground_truth)`` where
        ``ground_truth`` is the text quoted inside ``caption`` (empty string
        when the sample carries no text).
    """
    image = sample['image'].convert('RGB').resize((IMAGE_SIZE, IMAGE_SIZE))

    # Map 0..255 pixel values to -1..1 and reorder HWC -> CHW.
    image_array = np.array(image).astype(np.float32) / 127.5 - 1.0
    pixel_values = torch.from_numpy(image_array).permute(2, 0, 1)

    caption = create_caption_from_textocr(sample)

    inputs = tokenizer(
        caption,
        padding="max_length",
        max_length=MAX_LENGTH,
        truncation=True,
        return_tensors="pt"
    )
    # The tokenizer returns a batch of one; keep the single row.
    input_ids = inputs.input_ids[0]

    # The caption quotes the longest text (see create_caption_from_textocr),
    # so apply the same selection rule here. Taking texts[0] would store a
    # ground truth that differs from the text the caption asks for whenever
    # the first annotation is not the longest one.
    texts = extract_text_from_sample(sample)
    ground_truth = max(texts, key=len) if texts else ""

    return pixel_values, input_ids, caption, ground_truth

# Filter and collect good samples
print("\nFiltering TextOCR dataset for text-focused samples...")
print("This may take a few minutes...")

dataset_stream = load_dataset("facebook/textocr", split="train", streaming=True)

TARGET_SAMPLES = 300      # stop once this many samples passed the filter
MAX_SEARCHED = 3000       # give up after examining this many samples

filtered_samples = []
processed_count = 0

for example in dataset_stream:
    # Count a sample only when it is actually examined. The limits are
    # checked after the filter runs, so processed_count always equals the
    # number of samples inspected and the last counted sample is never
    # skipped.
    processed_count += 1

    if is_good_textocr_sample(example):
        filtered_samples.append(example)

        if len(filtered_samples) % 50 == 0:
            print(f"  Found {len(filtered_samples)} suitable samples (searched {processed_count})...")

    if len(filtered_samples) >= TARGET_SAMPLES:
        break

    if processed_count >= MAX_SEARCHED:
        print(f"  Searched {processed_count} samples, found {len(filtered_samples)}")
        break

print(f"\n✓ Filtered dataset: {len(filtered_samples)} text-focused samples from {processed_count} total")

# Show filtered examples
# NOTE: captions are drawn at random on every call, so the caption printed
# here can differ from the one plotted below and from the one used in training.
print("\nFiltered Sample Captions:")
print("-"*80)
for i, sample in enumerate(filtered_samples[:5]):
    caption = create_caption_from_textocr(sample)
    texts = extract_text_from_sample(sample)
    print(f"{i+1}. Caption: {caption}")
    print(f"   Ground truth texts: {texts}")

# Visualize filtered samples in a 2x4 grid
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
axes = axes.flatten()

# Blank every panel first so unused ones (fewer than eight samples) do not
# show empty tick frames, matching how the generated-image grid hides extras.
for ax in axes:
    ax.axis('off')

for ax, sample in zip(axes, filtered_samples):
    img = sample['image']
    caption = create_caption_from_textocr(sample)
    texts = extract_text_from_sample(sample)

    ax.imshow(img)
    ax.set_title(f"{caption}\nTexts: {', '.join(texts[:2])}", fontsize=7)

plt.tight_layout()
plt.savefig('textocr_filtered_samples.png', dpi=150, bbox_inches='tight')
plt.show()

print("\n✓ Filtered samples visualization saved")

# Preprocess all samples
print("\nPreprocessing training data...")
preprocessed_data = []
failed_count = 0

for sample in tqdm(filtered_samples, desc="Preprocessing"):
    try:
        pixel_values, input_ids, caption, ground_truth = preprocess_textocr_sample(sample)
    except Exception as e:
        # Best effort: one bad sample must not abort the run, but the failure
        # is counted and the first few are shown so a systematic problem is
        # not hidden behind a silently shrinking dataset.
        failed_count += 1
        if failed_count <= 3:
            tqdm.write(f"  Skipping sample: {type(e).__name__}: {e}")
        continue

    preprocessed_data.append({
        'pixel_values': pixel_values,
        'input_ids': input_ids,
        'caption': caption,
        'ground_truth': ground_truth
    })

print(f"\n✓ {len(preprocessed_data)} samples preprocessed and ready for training")
if failed_count:
    print(f"  ({failed_count} samples skipped because preprocessing failed)")

## Cell 4: Model Setup

In [None]:
print("\n" + "="*80)
print("STEP 3: MODEL INITIALIZATION")
print("="*80)

# Load Core Components
# Each component is loaded from its own subfolder of the MODEL_NAME checkpoint.
print("Loading Stable Diffusion v1.5 components...")
vae = AutoencoderKL.from_pretrained(MODEL_NAME, subfolder="vae")
text_encoder = CLIPTextModel.from_pretrained(MODEL_NAME, subfolder="text_encoder")
noise_scheduler = DDPMScheduler.from_pretrained(MODEL_NAME, subfolder="scheduler")
unet = UNet2DConditionModel.from_pretrained(MODEL_NAME, subfolder="unet")

print("✓ Base models loaded")

# Freeze original weights (only train LoRA adapters)
# This must happen before get_peft_model below, so that the adapter weights
# added afterwards are the only parameters left with requires_grad=True.
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.requires_grad_(False)

print("✓ Base model weights frozen")

# Configure LoRA
# target_modules lists the module-name suffixes that receive adapters
# (presumably the UNet attention projections - confirm against the diffusers
# UNet definition).
lora_config = LoraConfig(
    r=32,
    lora_alpha=64,
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
    lora_dropout=0.05,
    bias="none",
)

# Add LoRA adapters to UNet
unet = get_peft_model(unet, lora_config)

print("✓ LoRA adapters added to UNet")

# Move to device
# The UNet keeps the dtype it was loaded with (no dtype is passed), while the
# frozen VAE and text encoder are cast to float16.
device = accelerator.device
unet.to(device)
vae.to(device, dtype=torch.float16)
text_encoder.to(device, dtype=torch.float16)

# Count parameters
trainable_params = sum(p.numel() for p in unet.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in unet.parameters())

print("\n" + "-"*80)
print("Model Statistics:")
print("-"*80)
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters (LoRA): {trainable_params:,}")
print(f"  Trainable percentage: {100 * trainable_params / total_params:.4f}%")
print(f"  Device: {device}")
print(f"  VAE scaling factor: {vae.config.scaling_factor}")
print("-"*80)

## Cell 5: Training Loop

In [None]:
print("\n" + "="*80)
print(f"STEP 4: TRAINING - {TRAIN_STEPS} steps")
print("="*80)

# Fail early with a clear message: the code below divides by and indexes
# modulo len(preprocessed_data), which would otherwise surface as an opaque
# ZeroDivisionError.
if not preprocessed_data:
    raise RuntimeError(
        "No preprocessed samples available - rerun the data filtering/preprocessing cell"
    )

# Hand only the parameters that require gradients (the LoRA adapters) to the
# optimizer and to gradient clipping; the frozen base weights never get a grad.
trainable_parameters = [p for p in unet.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_parameters, lr=LEARNING_RATE)
unet.train()

print(f"\nTraining Configuration:")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Training steps: {TRAIN_STEPS}")
print(f"  Training samples: {len(preprocessed_data)}")
print(f"  Epochs (approx): {TRAIN_STEPS / len(preprocessed_data):.1f}")

# Training loop
losses = []
unet_dtype = next(unet.parameters()).dtype

print("\nStarting training...")
progress_bar = tqdm(range(TRAIN_STEPS), desc="Training")

for step in progress_bar:
    # Get data (cycle through dataset in order, one sample per step)
    idx = step % len(preprocessed_data)
    data_item = preprocessed_data[idx]

    pixel_values = data_item['pixel_values']
    input_ids = data_item['input_ids']

    # Move to GPU and add batch dimension; float16 matches the VAE's dtype
    pixel_values = pixel_values.unsqueeze(0).to(device, dtype=torch.float16)
    input_ids = input_ids.unsqueeze(0).to(device)

    # VAE ENCODING - Convert image to latent space
    with torch.no_grad():
        latents = vae.encode(pixel_values).latent_dist.sample()
        # CRITICAL: Scale latents by VAE scaling factor
        latents = latents * vae.config.scaling_factor

    # Create random noise for diffusion training
    noise = torch.randn_like(latents)

    # Sample random timestep
    timesteps = torch.randint(
        0,
        noise_scheduler.config.num_train_timesteps,
        (1,),
        device=device
    ).long()

    # Add noise to latents (forward diffusion process)
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

    # Get text embeddings (text encoder is frozen, so no gradients needed)
    with torch.no_grad():
        encoder_hidden_states = text_encoder(input_ids)[0]

    # Ensure dtype consistency for LoRA: the UNet inputs come from float16
    # modules and are cast to the UNet's own parameter dtype
    noisy_latents = noisy_latents.to(dtype=unet_dtype)
    encoder_hidden_states = encoder_hidden_states.to(dtype=unet_dtype)

    # Predict noise using UNet
    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

    # Calculate loss: MSE between predicted and injected noise, in float32
    loss = torch.nn.functional.mse_loss(
        model_pred.float(),
        noise.float(),
        reduction="mean"
    )

    # Backward pass
    loss.backward()

    # Gradient clipping for stability
    torch.nn.utils.clip_grad_norm_(trainable_parameters, 1.0)

    optimizer.step()
    optimizer.zero_grad()

    # Record loss
    loss_value = loss.item()
    losses.append(loss_value)

    # Update progress bar with a running average over up to 50 steps
    if step % 10 == 0:
        avg_loss = np.mean(losses[-50:]) if len(losses) >= 50 else np.mean(losses)
        progress_bar.set_postfix({'loss': f'{avg_loss:.4f}'})

    # Detailed logging every 100 steps
    if (step + 1) % 100 == 0:
        avg_loss = np.mean(losses[-100:])
        print(f"\n  Step {step+1}/{TRAIN_STEPS}: Avg Loss (last 100) = {avg_loss:.4f}")

# Persist the trained LoRA adapter to OUTPUT_DIR
print("\nSaving model...")
unet.save_pretrained(OUTPUT_DIR)
print(f"✓ Model saved to {OUTPUT_DIR}")

# Summarise the run in a small JSON file next to the adapter.
# The average covers the last 100 steps, or every step for shorter runs.
recent_losses = losses[-100:] if len(losses) >= 100 else losses
last_loss = losses[-1] if losses else None

metadata = {
    'model': MODEL_NAME,
    'training_steps': TRAIN_STEPS,
    'learning_rate': LEARNING_RATE,
    'num_samples': len(preprocessed_data),
    'final_loss': last_loss,
    'avg_loss_last_100': np.mean(recent_losses)
}

with open(f"{OUTPUT_DIR}/training_metadata.json", "w") as f:
    json.dump(metadata, f, indent=2)

print("✓ Training metadata saved")

# Plot the training curve: raw per-step loss on the left, moving average on
# the right.
fig, (raw_ax, smooth_ax) = plt.subplots(1, 2, figsize=(12, 5))

# Left panel: raw loss
raw_ax.plot(losses, alpha=0.3, label='Raw Loss')
raw_ax.set_xlabel('Training Step')
raw_ax.set_ylabel('Loss')
raw_ax.set_title('Training Loss (Raw)')
raw_ax.legend()
raw_ax.grid(True, alpha=0.3)

# Right panel: moving average over `window` steps; runs shorter than the
# window are plotted unsmoothed.
window = 50
if len(losses) < window:
    smooth_ax.plot(losses, linewidth=2, label='Loss')
else:
    smoothed = np.convolve(losses, np.ones(window)/window, mode='valid')
    smooth_ax.plot(range(window-1, len(losses)), smoothed, linewidth=2, label='Smoothed Loss', color='red')
smooth_ax.set_xlabel('Training Step')
smooth_ax.set_ylabel('Loss')
smooth_ax.set_title(f'Training Loss (Smoothed, window={window})')
smooth_ax.legend()
smooth_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('textocr_training_loss.png', dpi=150, bbox_inches='tight')
plt.show()

print("\n" + "="*80)
print("TRAINING COMPLETE")
print("="*80)
print(f"Final average loss: {np.mean(losses[-100:]):.4f}")

## Cell 6: Image Generation

In [None]:
print("\n" + "="*80)
print("STEP 5: GENERATING TEST IMAGES")
print("="*80)

# Load the fine-tuned model
print("Loading fine-tuned model...")
# Use the same checkpoint the adapter was trained against instead of a
# duplicated literal, so the two cannot drift apart.
model_id = MODEL_NAME
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)

# The adapter in OUTPUT_DIR was written by the PEFT wrapper's
# save_pretrained(), i.e. in PEFT's adapter format. pipe.load_lora_weights()
# expects diffusers-style LoRA checkpoints and may not apply a PEFT-format
# adapter, so load it with PEFT and merge it into the UNet weights.
from peft import PeftModel

pipe.unet = PeftModel.from_pretrained(pipe.unet, OUTPUT_DIR).merge_and_unload()
pipe.to("cuda")

print("✓ Fine-tuned model loaded with LoRA weights")

# Test prompts based on common text patterns.
# Each entry is (prompt, expected text, output filename).
_test_specs = [
    ("An image with the text 'STOP' written on it", "STOP", "test_1_stop.png"),
    ("A photo showing 'OPEN' written on it", "OPEN", "test_2_open.png"),
    ("Text that says 'EXIT'", "EXIT", "test_3_exit.png"),
    ("An image containing the word 'SALE'", "SALE", "test_4_sale.png"),
    ("An image with the text 'CAFE' written on it", "CAFE", "test_5_cafe.png"),
]
test_cases = [
    {'prompt': prompt, 'expected': expected, 'filename': filename}
    for prompt, expected, filename in _test_specs
]

negative_prompt = "blurry, distorted, messy text, unclear letters, multiple objects, people, complex scene, watermark"

print("\nGenerating test images...")
generated_images = []

for position, test_case in enumerate(test_cases, start=1):
    print(f"  {position}/{len(test_cases)}: {test_case['prompt']}")

    # A failing prompt is reported and skipped; only cases whose image was
    # written to disk are kept for display and evaluation.
    try:
        output = pipe(
            test_case['prompt'],
            negative_prompt=negative_prompt,
            num_inference_steps=50,
            guidance_scale=7.5,
            height=512,
            width=512
        )
        image = output.images[0]
        image.save(test_case['filename'])
        generated_images.append(test_case)
    except Exception as e:
        print(f"    Error generating image: {e}")

print(f"\n✓ Generated {len(generated_images)} test images")

# Display generated images in a grid that is three panels wide
num_images = len(generated_images)
cols = 3
rows = (num_images + cols - 1) // cols

if num_images == 0:
    # A zero-row grid cannot be created, so skip plotting when every
    # generation failed instead of crashing the cell.
    print("No images were generated - skipping results visualization")
else:
    # squeeze=False always returns a 2-D array of axes, so a single-row grid
    # needs no special handling before flattening.
    fig, axes = plt.subplots(rows, cols, figsize=(15, 5*rows), squeeze=False)
    axes = axes.flatten()

    # Hide every panel's frame up front; this also covers the extra panels
    # left over when num_images is not a multiple of cols.
    for ax in axes:
        ax.axis('off')

    for ax, test_case in zip(axes, generated_images):
        # Close the file handle once the pixels have been handed to matplotlib
        with Image.open(test_case['filename']) as img:
            ax.imshow(img)
        ax.set_title(f"Expected: '{test_case['expected']}'\n{test_case['prompt'][:40]}...", fontsize=9)

    plt.tight_layout()
    plt.savefig('textocr_generated_results.png', dpi=150, bbox_inches='tight')
    plt.show()

    print("✓ Results visualization saved")

## Cell 7: OCR Evaluation

In [None]:
print("\n" + "="*80)
print("STEP 6: QUANTITATIVE EVALUATION - OCR-based")
print("="*80)

# OCR engine wrapper and edit-distance library used by the evaluation helpers
# defined in this cell.
import pytesseract
import Levenshtein

def preprocess_for_ocr(image_path):
    """Produce two binarized variants of an image for OCR.

    The image is converted to grayscale, upscaled 2x with cubic
    interpolation and denoised; the result is then binarized once with an
    adaptive Gaussian threshold and once with Otsu's threshold after a
    5x5 Gaussian blur.

    Returns:
        ``(adaptive, otsu)`` images, or ``(None, None)`` when the file
        cannot be read.
    """
    bgr = cv2.imread(image_path)
    if bgr is None:
        return None, None

    upscaled = cv2.resize(
        cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY),
        None, fx=2, fy=2, interpolation=cv2.INTER_CUBIC
    )
    cleaned = cv2.fastNlMeansDenoising(upscaled, None, 10, 7, 21)

    adaptive = cv2.adaptiveThreshold(
        cleaned, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY, 11, 2
    )

    # cv2.threshold returns (threshold value, image); only the image is kept.
    blurred = cv2.GaussianBlur(cleaned, (5, 5), 0)
    otsu = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]

    return adaptive, otsu

def _run_ocr(image, config):
    """Run Tesseract on *image*; return stripped upper-case text or None on failure."""
    try:
        return pytesseract.image_to_string(image, config=config).strip().upper()
    except Exception:
        # Best effort: a failing strategy is skipped so the others can still
        # contribute. Unlike a bare ``except``, this lets KeyboardInterrupt
        # and SystemExit propagate.
        return None


def extract_text_multiple_strategies(image_path):
    """Try several OCR strategies on an image and return the best result.

    Three Tesseract page-segmentation modes are run on the original image,
    then single-word mode is run on the adaptive-threshold and Otsu variants
    from ``preprocess_for_ocr``. The result with the most alphanumeric
    characters wins; ties go to the strategy tried first.

    Returns:
        ``(text, strategy_description)``; ``("", "No strategy worked")``
        when every strategy failed. Text is stripped and upper-cased.
    """
    results = []

    psm_modes = [
        ('--psm 6', 'Block of text'),
        ('--psm 8', 'Single word'),
        ('--psm 11', 'Sparse text'),
    ]

    # The context manager releases the file handle once OCR on the raw image
    # is finished.
    with Image.open(image_path) as img:
        for config, desc in psm_modes:
            text = _run_ocr(img, config)
            if text is not None:
                results.append((text, desc))

    thresh, otsu = preprocess_for_ocr(image_path)

    for processed, desc in ((thresh, 'Adaptive threshold'), (otsu, 'Otsu threshold')):
        if processed is None:
            continue
        text = _run_ocr(processed, '--psm 8')
        if text is not None:
            results.append((text, desc))

    if not results:
        return "", "No strategy worked"

    best_text, best_desc = max(results, key=lambda x: sum(c.isalnum() for c in x[0]))
    return best_text, best_desc

def evaluate_text_generation(image_path, expected_text):
    """Score the text OCR recovers from an image against the expected text.

    Both strings are reduced to their alphanumeric characters before
    comparison, and the expected text is upper-cased first.

    Returns:
        ``(exact_match, char_accuracy, recovered_text, strategy)``.
        ``exact_match`` is 1.0 or 0.0; ``char_accuracy`` is one minus the
        Levenshtein distance divided by the expected length, floored at 0.
        Both scores are 0.0 when the expected text has no alphanumeric
        characters.
    """
    recovered_text, strategy = extract_text_multiple_strategies(image_path)

    target = "".join(ch for ch in expected_text.upper() if ch.isalnum())
    if not target:
        return 0.0, 0.0, recovered_text, strategy

    observed = "".join(ch for ch in recovered_text if ch.isalnum())

    exact_match = 1.0 if observed == target else 0.0
    edits = Levenshtein.distance(observed, target)
    char_accuracy = max(0, 1 - (edits / len(target)))

    return exact_match, char_accuracy, recovered_text, strategy

# Evaluate all generated images
print("\nRunning OCR evaluation on generated images...")
print("="*100)
print(f"{'Filename':<25} | {'Expected':<10} | {'OCR Result':<20} | {'Strategy':<20} | {'Exact':<6} | {'Char Acc'}")
print("-"*100)

results = []

for test_case in generated_images:
    image_path = test_case['filename']
    expected = test_case['expected']

    if not os.path.exists(image_path):
        # Make a missing file visible instead of silently dropping the case.
        print(f"{image_path:<25} | {expected:<10} | (image file not found - skipped)")
        continue

    exact, acc, recovered, strategy = evaluate_text_generation(image_path, expected)

    results.append({
        'filename': test_case['filename'],
        'prompt': test_case['prompt'],
        'expected': expected,
        'recovered': recovered,
        'exact': exact,
        'accuracy': acc,
        'strategy': strategy
    })

    # Table shows only the alphanumeric part of the OCR text, cut to 20 chars.
    clean_recovered = "".join(filter(str.isalnum, recovered))[:20]

    print(f"{image_path:<25} | {expected:<10} | {clean_recovered:<20} | {strategy:<20} | {exact:<6.0f} | {acc:>7.1%}")

# Summary statistics
# Defaults keep avg_exact / avg_acc defined when nothing was evaluated, so
# later cells that read them do not fail with NameError.
avg_exact = 0.0
avg_acc = 0.0

print("-"*100)
if results:
    avg_exact = np.mean([r['exact'] for r in results])
    avg_acc = np.mean([r['accuracy'] for r in results])
    print(f"{'AVERAGE':<25} | {'':<10} | {'':<20} | {'':<20} | {avg_exact:<6.2f} | {avg_acc:>7.1%}")
else:
    print("No results to evaluate")

print("="*100)

# Save detailed results
results_file = f"{OUTPUT_DIR}/evaluation_results.json"
with open(results_file, 'w') as f:
    json.dump(results, f, indent=2)

print(f"\n✓ Detailed results saved to {results_file}")

# Visualize OCR preprocessing steps: one row per evaluated image showing the
# generated image, its adaptive-threshold variant and its Otsu variant.
if results:
    n_rows = len(results)
    # squeeze=False yields a 2-D axes array even for a single row.
    fig, axes = plt.subplots(n_rows, 3, figsize=(15, 5*n_rows), squeeze=False)

    for row_axes, result in zip(axes, results):
        image_path = result['filename']
        original_ax, adaptive_ax, otsu_ax = row_axes

        img = Image.open(image_path)
        original_ax.imshow(img)
        original_ax.set_title(f"Generated Image\nExpected: '{result['expected']}'", fontsize=10)
        original_ax.axis('off')

        thresh, otsu = preprocess_for_ocr(image_path)

        if thresh is not None:
            adaptive_ax.imshow(thresh, cmap='gray')
            adaptive_ax.set_title("Adaptive Threshold", fontsize=10)
            adaptive_ax.axis('off')

        if otsu is not None:
            otsu_ax.imshow(otsu, cmap='gray')
            otsu_ax.set_title(f"Otsu Threshold\nOCR: '{result['recovered'][:20]}'", fontsize=10)
            otsu_ax.axis('off')

    plt.tight_layout()
    plt.savefig('textocr_ocr_evaluation.png', dpi=150, bbox_inches='tight')
    plt.show()

    print("✓ OCR preprocessing visualization saved")

## Cell 8: Final Summary and Comparison

In [None]:
print("\n" + "="*80)
print("STEP 7: FINAL ANALYSIS AND COMPARISON")
print("="*80)

# Generate comparison with base model (no fine-tuning)
print("\nGenerating comparison with BASE model (no fine-tuning)...")

# Use the checkpoint the adapter was trained against as the baseline.
base_pipe = StableDiffusionPipeline.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16
)
base_pipe.to("cuda")

comparison_prompt = "An image with the text 'STOP' written on it"
expected_text = "STOP"
# Both models are sampled from the same seed with the same prompts and
# settings, so the only difference between the two images is the LoRA adapter.
COMPARISON_SEED = 42

print(f"  Prompt: {comparison_prompt}")
print(f"  Expected text: {expected_text}")

# Generate with base model
print("\n  Generating with BASE model...")
base_image = base_pipe(
    comparison_prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=50,
    guidance_scale=7.5,
    generator=torch.Generator(device="cuda").manual_seed(COMPARISON_SEED)
).images[0]
base_image.save("comparison_base.png")

# Evaluate base model
base_exact, base_acc, base_recovered, _ = evaluate_text_generation("comparison_base.png", expected_text)

# Generate with fine-tuned model
print("  Generating with FINE-TUNED model...")
finetuned_image = pipe(
    comparison_prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=50,
    guidance_scale=7.5,
    generator=torch.Generator(device="cuda").manual_seed(COMPARISON_SEED)
).images[0]
finetuned_image.save("comparison_finetuned.png")

# Evaluate fine-tuned model
ft_exact, ft_acc, ft_recovered, _ = evaluate_text_generation("comparison_finetuned.png", expected_text)

# Display comparison side by side
fig, axes = plt.subplots(1, 2, figsize=(14, 7))

axes[0].imshow(base_image)
axes[0].set_title(f"Base SD 1.5 (No Fine-tuning)\nOCR: '{base_recovered[:20]}'\nAccuracy: {base_acc:.1%}",
                  fontsize=11, fontweight='bold')
axes[0].axis('off')

axes[1].imshow(finetuned_image)
axes[1].set_title(f"After LoRA Fine-tuning (TextOCR)\nOCR: '{ft_recovered[:20]}'\nAccuracy: {ft_acc:.1%}",
                  fontsize=11, fontweight='bold')
axes[1].axis('off')

plt.suptitle(f"Prompt: '{comparison_prompt}' | Expected: '{expected_text}'", fontsize=13, y=0.98)
plt.tight_layout()
plt.savefig('textocr_base_vs_finetuned.png', dpi=150, bbox_inches='tight')
plt.show()

print("\n" + "="*80)
print("COMPARISON RESULTS")
print("="*80)
print(f"Base Model:")
print(f"  OCR Result: '{base_recovered}'")
print(f"  Exact Match: {base_exact}")
print(f"  Character Accuracy: {base_acc:.1%}")
print(f"\nFine-tuned Model:")
print(f"  OCR Result: '{ft_recovered}'")
print(f"  Exact Match: {ft_exact}")
print(f"  Character Accuracy: {ft_acc:.1%}")
print(f"\nImprovement:")
print(f"  Accuracy Delta: {(ft_acc - base_acc)*100:+.1f} percentage points")
print("="*80)

# Final Summary
print("\n" + "="*80)
print("FINAL SUMMARY - TextOCR Fine-tuning")
print("="*80)

# Format the metrics up front and fall back to "N/A" when there is nothing to
# report, so the summary still prints after a run with no recorded losses or
# no evaluated images instead of raising IndexError/NameError.
final_loss_text = f"{losses[-1]:.4f}" if losses else "N/A"
recent_loss_text = f"{np.mean(losses[-100:]):.4f}" if losses else "N/A"
if results:
    exact_match_text = f"{np.mean([r['exact'] for r in results]):.2f}"
    char_accuracy_text = f"{np.mean([r['accuracy'] for r in results]):.1%}"
else:
    exact_match_text = "N/A"
    char_accuracy_text = "N/A"

summary = f"""
Dataset: TextOCR (facebook/textocr)
Training Samples: {len(preprocessed_data)}
Training Steps: {TRAIN_STEPS}
Final Loss: {final_loss_text}
Average Loss (last 100 steps): {recent_loss_text}

Evaluation Results:
- Test Cases: {len(results)}
- Average Exact Match: {exact_match_text}
- Average Character Accuracy: {char_accuracy_text}

Key Observations:
1. TextOCR contains real-world images with text annotations
2. The model learned from crowd-sourced text-image pairs
3. Fine-tuning adapted the model to generate images with text
4. OCR evaluation provides quantitative metrics

Challenges:
- TextOCR images are complex real-world scenes
- Text is often embedded in context (signs, products, etc.)
- Model learns both scene generation AND text rendering
- Limited training steps due to computational constraints

Next Steps:
- Increase training steps (2000+) for better convergence
- Use more filtered samples (focus on simple text)
- Experiment with different LoRA configurations
- Try higher resolution (768x768) if GPU allows
"""

print(summary)

# Save summary
with open(f"{OUTPUT_DIR}/final_summary.txt", "w") as f:
    f.write(summary)

print(f"\n✓ Summary saved to {OUTPUT_DIR}/final_summary.txt")
print("\n" + "="*80)
print("TEXTOCR TRAINING COMPLETE!")
print("="*80)