# HiGAN+ Model Evaluation and Analysis

This notebook provides comprehensive evaluation of the pre-trained HiGAN+ handwriting generation model using the `epoch_20.pth` checkpoint.

## Evaluation Pipeline:
1. **Load Pre-trained Model** - Restore all network components from checkpoint
2. **Quantitative Metrics** - FID, IS, CER, WER, writer-identification accuracy, SSIM, PSNR, style-feature consistency and diversity
3. **Qualitative Analysis** - Visualizations, style transfer, interpolation
4. **Performance Report** - Comprehensive summary and comparison

**Model Checkpoint:** `B:\College\DL\handwriting_autocomplete_system\Higan+ from Scratch\server_files\epoch_20.pth`

## 1. Load Pre-trained Model and Configuration

In [None]:
import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
import h5py
import cv2
from datetime import datetime
import json
import pandas as pd
from pathlib import Path

# Pillow >= 10 removed the ANTIALIAS constant; alias it to LANCZOS so legacy resize code keeps working.
_pillow_has_antialias = hasattr(Image, "ANTIALIAS")
if not _pillow_has_antialias:
    Image.ANTIALIAS = Image.Resampling.LANCZOS

# Put the HiGAN+ project root first on sys.path so `lib` and `networks` resolve to the local sources.
project_path = r'B:\College\DL\handwriting_autocomplete_system\Higan+ from Scratch'
sys.path.insert(0, project_path)

# Import HiGAN+ modules
from lib.datasets import get_dataset, get_collect_fn, Hdf5Dataset
from lib.alphabet import strLabelConverter, Alphabets, get_lexicon, get_true_alphabet
from lib.utils import yaml2config, draw_image, AverageMeter
from networks import get_model
from networks.BigGAN_networks import Generator, Discriminator, PatchDiscriminator
from networks.module import Recognizer, WriterIdentifier, StyleEncoder, StyleBackbone
from torch.utils.data import DataLoader

# Report the runtime environment and pick the compute device.
print(f"PyTorch version: {torch.__version__}")
_cuda_ok = torch.cuda.is_available()
print(f"CUDA available: {_cuda_ok}")
if _cuda_ok:
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

device = torch.device('cuda' if _cuda_ok else 'cpu')
print(f"\nUsing device: {device}")

# Every figure / report produced below is written under <project>/evaluation_results.
output_dir = Path(project_path).joinpath("evaluation_results")
output_dir.mkdir(exist_ok=True)
print(f"Results will be saved to: {output_dir}")

In [None]:
# Read the training configuration that matches the checkpoint.
config_path = os.path.join(project_path, 'configs', 'gan_iam.yml')
cfg = yaml2config(config_path)

# Evaluation-time overrides: run on the selected device with modest batch sizes.
cfg.device = str(device)
cfg.training.batch_size = 16
cfg.training.eval_batch_size = 32

print("Configuration loaded:")
for _cfg_label, _cfg_value in (
    ("Dataset", cfg.dataset),
    ("Model", cfg.model),
    ("Image height", cfg.img_height),
    ("Character width", cfg.char_width),
):
    print(f"  {_cfg_label}: {_cfg_value}")

In [None]:
# Restore the training checkpoint saved after epoch 20.
checkpoint_path = r"B:\College\DL\handwriting_autocomplete_system\Higan+ from Scratch\server_files\epoch_20.pth"

print(f"\nLoading checkpoint from: {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location=device)

# List the top-level entries (per-network state dicts, the epoch counter, ...).
print("\nCheckpoint contents:")
for entry_name in checkpoint.keys():
    print(f"  • {entry_name}")

trained_epoch = checkpoint['epoch'] if 'epoch' in checkpoint else 'Unknown'
print(f"\nTrained Epoch: {trained_epoch}")

In [None]:
# Initialize model architectures and restore the networks stored in the checkpoint.
print("\nInitializing model architectures...")

# Generator
generator = Generator(**cfg.GenModel).to(device)
generator.load_state_dict(checkpoint['generator'])
generator.eval()
print(f"✓ Generator loaded - Style dim: {generator.style_dim}")

# Discriminators
discriminator = Discriminator(**cfg.DiscModel).to(device)
discriminator.load_state_dict(checkpoint['discriminator'])
discriminator.eval()

patch_discriminator = PatchDiscriminator(**cfg.PatchDiscModel).to(device)
patch_discriminator.load_state_dict(checkpoint['patch_discriminator'])
patch_discriminator.eval()
print(f"✓ Discriminators loaded")

# Style Encoder & Backbone
# The backbone is not restored from this checkpoint here; its weights come from
# cfg.training.pretrained_w in the auxiliary-model cell (when that file exists).
style_backbone = StyleBackbone(**cfg.StyBackbone).to(device)
# Switch to eval mode now: the backbone is used for style extraction during style-guided
# generation, and any BatchNorm/Dropout layers must not run in training mode there.
style_backbone.eval()
style_encoder = StyleEncoder(**cfg.EncModel).to(device)
style_encoder.load_state_dict(checkpoint['style_encoder'])
style_encoder.eval()
print(f"✓ Style Encoder loaded")

# Recognizer (OCR) - weights are loaded from cfg.training.pretrained_r in the next cell
recognizer = Recognizer(**cfg.OcrModel).to(device)
recognizer.eval()
print(f"✓ Recognizer initialized")

# Writer Identifier
writer_identifier = WriterIdentifier(**cfg.WidModel).to(device)
writer_identifier.load_state_dict(checkpoint['writer_identifier'])
writer_identifier.eval()
print(f"✓ Writer Identifier loaded")

# Count parameters
def count_parameters(model):
    """Return how many scalar parameters of *model* are trainable (``requires_grad=True``)."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Trainable-parameter counts for the main restored networks.
print("\n--- Model Parameters ---")
for _net_label, _net in (
    ("Generator", generator),
    ("Discriminator", discriminator),
    ("Style Encoder", style_encoder),
    ("Writer ID", writer_identifier),
):
    print(f"{_net_label}: {count_parameters(_net):,}")

In [None]:
# Load the pretrained OCR recognizer and writer-ID style backbone referenced by the config.
# Loading is best-effort: failures are reported and the affected network keeps its
# current (possibly random) weights, which makes the corresponding metrics unreliable.
print("\nLoading pretrained auxiliary models...")

try:
    if os.path.exists(cfg.training.pretrained_r):
        r_dict = torch.load(cfg.training.pretrained_r, map_location=device)
        recognizer.load_state_dict(r_dict['Recognizer'])
        print(f"✓ Loaded pretrained OCR from {cfg.training.pretrained_r}")
    else:
        print(f"⚠ Pretrained OCR not found: {cfg.training.pretrained_r}")

    if os.path.exists(cfg.training.pretrained_w):
        w_dict = torch.load(cfg.training.pretrained_w, map_location=device)
        if 'StyleBackbone' in w_dict:
            style_backbone.load_state_dict(w_dict['StyleBackbone'])
            print(f"✓ Loaded pretrained Writer ID from {cfg.training.pretrained_w}")
        else:
            # Previously this case still reported success although nothing was loaded.
            print(f"⚠ No 'StyleBackbone' entry in {cfg.training.pretrained_w}; "
                  f"style backbone keeps its current weights")
    else:
        print(f"⚠ Pretrained Writer ID not found: {cfg.training.pretrained_w}")

except Exception as e:
    print(f"⚠ Error loading pretrained models: {e}")

recognizer.eval()
style_backbone.eval()
print("\n✓ All models loaded and set to evaluation mode")

## 2. Setup Evaluation Dataset

In [None]:
# Build the evaluation (test) loader plus a shuffled training loader kept for reference.
print("Loading evaluation dataset...")

collect_fn = get_collect_fn(cfg.training.sort_input, sort_style=True)

try:
    # Shared dataset options: no recognizer / writer-ID augmentation, style images enabled.
    _dset_opts = dict(recogn_aug=False, wid_aug=False, process_style=True)

    test_dataset = get_dataset(cfg.valid.dset_name, cfg.valid.dset_split, **_dset_opts)
    test_loader = DataLoader(
        test_dataset,
        batch_size=cfg.training.eval_batch_size,
        shuffle=False,
        collate_fn=collect_fn,
        num_workers=0,
        drop_last=False,
    )
    print(f"✓ Test samples: {len(test_dataset)}")
    print(f"✓ Test batches: {len(test_loader)}")

    train_dataset = get_dataset(cfg.dataset, cfg.training.dset_split, **_dset_opts)
    train_loader = DataLoader(
        train_dataset,
        batch_size=cfg.training.eval_batch_size,
        shuffle=True,
        collate_fn=collect_fn,
        num_workers=0,
        drop_last=False,
    )
    print(f"✓ Train samples (for reference): {len(train_dataset)}")

except Exception as e:
    print(f"❌ Error loading dataset: {e}")
    raise

In [None]:
# Label converter plus the word lexicon used to sample random texts.
label_converter = strLabelConverter('all')

lexicon = get_lexicon(
    cfg.training.lexicon,
    get_true_alphabet(cfg.dataset),
    max_length=cfg.training.max_word_len,
)

if not lexicon:
    # No lexicon file: harvest lower-cased words straight from the training labels.
    print("Building fallback lexicon from dataset...")
    alphabet = get_true_alphabet(cfg.dataset)
    dataset_words = set()
    for seek_pos, n_chars in zip(train_dataset.lb_seek_idxs, train_dataset.lb_lens):
        n_chars = int(n_chars)
        # Keep only labels strictly between 1 and max_word_len characters.
        if not (1 < n_chars < cfg.training.max_word_len):
            continue
        char_codes = train_dataset.lbs[seek_pos:seek_pos + n_chars]
        word = ''.join(c for c in map(chr, char_codes) if c in alphabet)
        if len(word) > 1:
            dataset_words.add(word.lower())
    lexicon = sorted(dataset_words)

print(f"✓ Lexicon loaded: {len(lexicon)} words")
print(f"Sample words: {lexicon[:10]}")

In [None]:
# Preview up to eight style images (with their transcriptions) from the first test batch.
sample_batch = next(iter(test_loader))

imgs, lbs, lb_lens = (sample_batch[key][:8] for key in ('style_imgs', 'lbs', 'lb_lens'))
texts = label_converter.decode(lbs, lb_lens)

fig, axes = plt.subplots(2, 4, figsize=(16, 6))
axes = axes.ravel()

_n_shown = min(8, imgs.size(0))
for ax, img_t, caption in zip(axes[:_n_shown], imgs, texts):
    # (1 - x) / 2 maps [-1, 1] pixels to an inverted [0, 1] range for display (assumes [-1, 1] input).
    ax.imshow((1 - img_t.squeeze().numpy()) / 2, cmap='gray')
    ax.set_title(f'"{caption}"', fontsize=10)
    ax.axis('off')

plt.suptitle('Sample Test Dataset Images', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(output_dir / 'sample_test_images.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"✓ Sample visualization saved")

## 3. Generate Sample Images

Generate diverse samples for evaluation using different methods.

In [None]:
from networks.rand_dist import prepare_z_dist, prepare_y_dist
from networks.utils import idx_to_words

print("Setting up random distributions for sampling...")

# Seeded samplers sized by eval_batch_size: z_dist draws style vectors, y_dist draws lexicon indices.
z_dist = prepare_z_dist(cfg.training.eval_batch_size, cfg.EncModel.style_dim, device, seed=42)
y_dist = prepare_y_dist(cfg.training.eval_batch_size, len(lexicon), device, seed=42)

# Report the sampler shapes; z_dist.mean() is a 0-d scalar whose shape is always torch.Size([]).
print(f"✓ Style distribution: {z_dist.shape}")
print(f"✓ Text distribution: {y_dist.shape}")

In [None]:
# Generate samples with random styles and random lexicon words.
print("Generating random style samples...")

n_random_samples = 100
random_generated_imgs = []
random_generated_texts = []
max_label_len = 25

# Ceil-divide so at least n_random_samples images are produced (100 // 32 only gave 96).
n_random_batches = -(-n_random_samples // cfg.training.eval_batch_size)

with torch.no_grad():
    for _ in tqdm(range(n_random_batches), desc="Random Generation"):
        # Sample random texts
        y_dist.sample_()
        sampled_words = idx_to_words(y_dist, lexicon, max_label_len, 0.3, 0.0)
        fake_lbs, fake_lb_lens = label_converter.encode(sampled_words, max_label_len)
        fake_lbs = fake_lbs.to(device)
        fake_lb_lens = fake_lb_lens.to(device)

        # Sample random styles
        z_dist.sample_()

        # Generate images
        fake_imgs = generator(z_dist, fake_lbs, fake_lb_lens)

        random_generated_imgs.append(fake_imgs.cpu())
        random_generated_texts.extend(sampled_words)

# Image width can differ between batches, so right-pad every batch to a common width before
# concatenating (torch.cat needs matching non-batch dims); padding is a no-op when widths agree.
# NOTE(review): pad value -1.0 assumes the background is -1 in the [-1, 1] pixel range
# (consistent with the (1 - img) / 2 display mapping used for the test images) — confirm.
_rand_w = max(t.size(-1) for t in random_generated_imgs)
random_generated_imgs = torch.cat(
    [F.pad(t, (0, _rand_w - t.size(-1)), value=-1.0) for t in random_generated_imgs], dim=0
)
random_generated_imgs = random_generated_imgs[:n_random_samples]
random_generated_texts = random_generated_texts[:n_random_samples]
print(f"✓ Generated {random_generated_imgs.size(0)} random style images")

In [None]:
# Generate style-guided samples: re-render each test word in the style encoded from its own image.
print("Generating style-guided samples...")

style_guided_imgs = []
style_guided_texts = []
reference_imgs = []

with torch.no_grad():
    for batch_idx, batch in enumerate(tqdm(test_loader, desc="Style-Guided Generation")):
        if batch_idx >= 5:  # Limit to 5 batches
            break

        real_imgs = batch['style_imgs'].to(device)
        real_img_lens = batch['style_img_lens'].to(device)
        real_lbs = batch['lbs'].to(device)
        real_lb_lens = batch['lb_lens'].to(device)

        # Extract style (vae_mode=False; presumably the non-sampled style code — verify in StyleEncoder)
        enc_z = style_encoder(real_imgs, real_img_lens, style_backbone, vae_mode=False)

        # Generate with same content
        fake_imgs = generator(enc_z, real_lbs, real_lb_lens)

        style_guided_imgs.append(fake_imgs.cpu())
        reference_imgs.append(real_imgs.cpu())
        texts = label_converter.decode(real_lbs, real_lb_lens)
        style_guided_texts.extend(texts)

# Widths differ across batches (real images are padded only within a batch; generated width
# follows the batch's labels), so right-pad each list to its own max width before torch.cat.
# NOTE(review): pad value -1.0 assumes the background is -1 in the [-1, 1] pixel range — confirm.
_sg_w = max(t.size(-1) for t in style_guided_imgs)
style_guided_imgs = torch.cat(
    [F.pad(t, (0, _sg_w - t.size(-1)), value=-1.0) for t in style_guided_imgs], dim=0
)
_ref_w = max(t.size(-1) for t in reference_imgs)
reference_imgs = torch.cat(
    [F.pad(t, (0, _ref_w - t.size(-1)), value=-1.0) for t in reference_imgs], dim=0
)
print(f"✓ Generated {style_guided_imgs.size(0)} style-guided images")

In [None]:
# Render each fixed word with n_style_variations random style vectors.
print("Generating fixed texts with varying styles...")

fixed_texts = ["Hello", "World", "Python", "Deep", "Learning", "Neural", "Network", "Artificial"]
fixed_text_imgs = []

n_style_variations = 10

# Dedicated seeded RNG so the style variations are reproducible (the z/y samplers in this
# notebook are seeded with 42 as well).
_style_rng = torch.Generator().manual_seed(42)

with torch.no_grad():
    for text in tqdm(fixed_texts, desc="Fixed Text Generation"):
        # The label encoding does not depend on the style: encode once, repeat per variation.
        lbs, lb_lens = label_converter.encode([text], max_label_len)
        lbs = lbs.to(device).repeat(n_style_variations, 1)
        lb_lens = lb_lens.to(device).repeat(n_style_variations)

        # One random style per variation, all generated in a single batch.
        z = torch.randn(n_style_variations, generator.style_dim, generator=_style_rng).to(device)
        fixed_text_imgs.append(generator(z, lbs, lb_lens).cpu())

print(f"✓ Generated {len(fixed_texts)} texts with {n_style_variations} style variations each")

## 4. Calculate FID (Fréchet Inception Distance)

FID measures the quality and diversity of generated images by comparing the mean and covariance of Inception-v3 features extracted from real and generated images.

In [None]:
# Install pytorch-fid if not available
try:
    from pytorch_fid import fid_score
    from pytorch_fid.inception import InceptionV3
    print("✓ pytorch-fid available")
except ImportError:
    print("Installing pytorch-fid...")
    import importlib
    import subprocess
    subprocess.check_call([sys.executable, "-m", "pip", "install", "pytorch-fid"])
    # The import system caches directory listings; refresh them so the package installed
    # into this running interpreter can be found.
    importlib.invalidate_caches()
    from pytorch_fid import fid_score
    from pytorch_fid.inception import InceptionV3
    print("✓ pytorch-fid installed")

In [None]:
from scipy import linalg
import torchvision.transforms as transforms

def calculate_activation_statistics(imgs, model, batch_size=50, dims=2048, device='cpu',
                                    value_range=(-1.0, 1.0)):
    """Compute the mean and covariance of Inception features for a set of images.

    Args:
        imgs: CPU tensor (N, C, H, W). Single-channel images are replicated to 3 channels.
        model: pytorch-fid ``InceptionV3`` wrapper; ``model(batch)[0]`` is the selected feature map.
        batch_size: number of images per forward pass.
        dims: feature dimensionality of the selected block (2048 for pool3).
        device: device the forward passes run on.
        value_range: (low, high) pixel range of *imgs*. Defaults to [-1, 1], the range the rest
            of this notebook assumes (it maps images to [0, 1] with (x + 1) / 2).

    Returns:
        (mu, sigma): feature mean of shape (dims,) and covariance of shape (dims, dims).
    """
    model.eval()

    # Prepare images (convert grayscale to RGB)
    if imgs.size(1) == 1:
        imgs = imgs.repeat(1, 3, 1, 1)

    # pytorch-fid's InceptionV3 (normalize_input=True by default) expects inputs in [0, 1] and
    # rescales them to [-1, 1] internally. The previous (x - 0.5) / 0.5 turned [-1, 1] images
    # into [-3, 1] before that second rescaling.
    low, high = value_range
    imgs = ((imgs - low) / (high - low)).clamp(0.0, 1.0)

    n_samples = imgs.size(0)
    n_batches = (n_samples + batch_size - 1) // batch_size

    pred_arr = np.empty((n_samples, dims))

    with torch.no_grad():
        for i in tqdm(range(n_batches), desc="Extracting features"):
            start = i * batch_size
            end = min(start + batch_size, n_samples)
            batch = imgs[start:end].to(device)

            # Resize to 299x299 for InceptionV3
            batch = F.interpolate(batch, size=(299, 299), mode='bilinear', align_corners=False)

            pred = model(batch)[0]

            # Spatially pool if the selected block is not already 1x1
            if pred.size(2) != 1 or pred.size(3) != 1:
                pred = F.adaptive_avg_pool2d(pred, output_size=(1, 1))

            pred_arr[start:end] = pred.squeeze(3).squeeze(2).cpu().numpy()

    mu = np.mean(pred_arr, axis=0)
    sigma = np.cov(pred_arr, rowvar=False)

    return mu, sigma

def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Return the Fréchet distance between N(mu1, sigma1) and N(mu2, sigma2).

    d^2 = ||mu1 - mu2||^2 + Tr(sigma1) + Tr(sigma2) - 2 * Tr(sqrt(sigma1 @ sigma2))

    Args:
        mu1, mu2: mean vectors (scalars are promoted to length-1 vectors).
        sigma1, sigma2: covariance matrices (scalars are promoted to 1x1 matrices).
        eps: diagonal offset added to both covariances when the matrix square root of their
            product is not finite (near-singular product).

    Returns:
        The distance as a Python float.

    Raises:
        ValueError: if the square root keeps a non-negligible imaginary diagonal component.
    """
    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)
    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    diff = mu1 - mu2

    # sqrtm's `disp` keyword is deprecated in recent SciPy releases. Called without it, sqrtm
    # returns just the root in every version; the error estimate was discarded here anyway.
    covmean = linalg.sqrtm(sigma1.dot(sigma2))
    if not np.isfinite(covmean).all():
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError(f'Imaginary component {m}')
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    return float(diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean)

print("Setting up Inception V3 model for FID calculation...")
# pool3 (2048-d) activations are the standard FID embedding.
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
inception_model = InceptionV3(output_blocks=[block_idx]).to(device)
inception_model.eval()
print("✓ Inception V3 ready")

In [None]:
# Calculate FID score between real test images and style-guided generations.
print("\nCalculating FID score...")

# Collect real images (first 10 test batches)
real_imgs_for_fid = []
with torch.no_grad():
    for batch_idx, batch in enumerate(tqdm(test_loader, desc="Collecting real images")):
        if batch_idx >= 10:  # Limit samples
            break
        real_imgs_for_fid.append(batch['style_imgs'])

# Batches are padded only to their own widest image, so right-pad every batch to a common width
# before concatenating (torch.cat needs matching non-batch dims). Inception resizes to 299x299.
# NOTE(review): pad value -1.0 assumes the background is -1 in the [-1, 1] pixel range — confirm.
_fid_w = max(t.size(-1) for t in real_imgs_for_fid)
real_imgs_for_fid = torch.cat(
    [F.pad(t, (0, _fid_w - t.size(-1)), value=-1.0) for t in real_imgs_for_fid], dim=0
)
print(f"Real images for FID: {real_imgs_for_fid.shape}")

# Use generated images (the style-guided set may hold fewer images than the real set)
fake_imgs_for_fid = style_guided_imgs[:real_imgs_for_fid.size(0)]
print(f"Fake images for FID: {fake_imgs_for_fid.shape}")
if fake_imgs_for_fid.size(0) < real_imgs_for_fid.size(0):
    print(f"⚠ Only {fake_imgs_for_fid.size(0)} generated vs {real_imgs_for_fid.size(0)} real images; "
          f"FID estimates are biased for small or unequal sample sizes")

# Calculate statistics
print("\nCalculating real image statistics...")
mu_real, sigma_real = calculate_activation_statistics(
    real_imgs_for_fid, inception_model, batch_size=32, device=device
)

print("Calculating generated image statistics...")
mu_fake, sigma_fake = calculate_activation_statistics(
    fake_imgs_for_fid, inception_model, batch_size=32, device=device
)

# Calculate FID
fid_value = calculate_frechet_distance(mu_real, sigma_real, mu_fake, sigma_fake)

print(f"\n{'='*60}")
print(f"FID Score: {fid_value:.4f}")
print(f"{'='*60}")
print("(Lower is better. FID < 50 is good, FID < 20 is excellent)")

## 5. Calculate Inception Score (IS)

IS measures both quality and diversity of generated images. Because it relies on an ImageNet classifier, it is only a rough indicator for handwriting images.

In [None]:
from torchvision.models import inception_v3

def calculate_inception_score(imgs, model, batch_size=32, splits=10, device='cpu'):
    """Compute the Inception Score of a set of images.

    Args:
        imgs: CPU tensor (N, C, H, W) with pixels in [-1, 1] (the range used throughout this
            notebook). Single-channel images are replicated to 3 channels.
        model: torchvision ``inception_v3`` classifier built with ``transform_input=False``.
        batch_size: images per forward pass.
        splits: number of disjoint splits the score is averaged over (capped at N).
        device: device the forward passes run on.

    Returns:
        (mean, std) over splits of exp(mean_x KL(p(y|x) || p(y))).

    Raises:
        ValueError: if *imgs* is empty.
    """
    n_samples = imgs.size(0)
    if n_samples == 0:
        raise ValueError("calculate_inception_score needs at least one image")

    model.eval()

    # Prepare images
    if imgs.size(1) == 1:
        imgs = imgs.repeat(1, 3, 1, 1)

    # With transform_input=False, torchvision's Inception-v3 consumes inputs already scaled to
    # [-1, 1] (transform_input=True exists to convert ImageNet-normalized inputs into that range).
    # The images are already in [-1, 1], so only clamp; the former (x - 0.5) / 0.5 gave [-3, 1].
    imgs = imgs.clamp(-1.0, 1.0)

    n_batches = (n_samples + batch_size - 1) // batch_size
    preds = []

    with torch.no_grad():
        for i in tqdm(range(n_batches), desc="Computing IS"):
            start = i * batch_size
            end = min(start + batch_size, n_samples)
            batch = imgs[start:end].to(device)

            # Resize to 299x299
            batch = F.interpolate(batch, size=(299, 299), mode='bilinear', align_corners=False)

            pred = model(batch)
            preds.append(F.softmax(pred, dim=1).cpu().numpy())

    preds = np.concatenate(preds, axis=0)

    # Split boundaries k*N//S .. (k+1)*N//S cover every sample (the former fixed N//S stride
    # dropped the remainder), and S <= N guarantees no split is empty.
    n_splits = max(1, min(splits, n_samples))
    split_scores = []
    for k in range(n_splits):
        part = preds[k * n_samples // n_splits:(k + 1) * n_samples // n_splits]
        py = np.mean(part, axis=0, keepdims=True)
        kl = np.sum(part * np.log(part / py + 1e-10), axis=1)
        split_scores.append(np.exp(np.mean(kl)))

    return np.mean(split_scores), np.std(split_scores)

print("Loading Inception V3 for IS calculation...")
try:
    # torchvision >= 0.13 replaced the deprecated `pretrained=` flag with the `weights=` API.
    from torchvision.models import Inception_V3_Weights
    inception_v3_model = inception_v3(weights=Inception_V3_Weights.IMAGENET1K_V1,
                                      transform_input=False)
except ImportError:
    # Older torchvision without the weights enums.
    inception_v3_model = inception_v3(pretrained=True, transform_input=False)
inception_v3_model = inception_v3_model.to(device)
inception_v3_model.eval()
print("✓ Inception V3 loaded")

In [None]:
# Score (at most) the first 500 randomly-styled generations.
print("\nCalculating Inception Score...")

# Slicing is a no-op when 500 or fewer images are available.
is_imgs = random_generated_imgs[:500]

is_mean, is_std = calculate_inception_score(
    is_imgs, inception_v3_model, batch_size=32, splits=10, device=device
)

_banner = '=' * 60
for _line in (
    f"\n{_banner}",
    f"Inception Score: {is_mean:.4f} ± {is_std:.4f}",
    _banner,
    "(Higher is better. IS > 5 is good, IS > 10 is excellent)",
):
    print(_line)

## 6. Calculate OCR Metrics (CER & WER)

Evaluate text readability using Character Error Rate and Word Error Rate.

In [None]:
from distance import levenshtein
from networks.utils import ctc_greedy_decoder

def evaluate_ocr_metrics(model_rec, generated_imgs, generated_labels, generated_label_lens,
                        label_conv, ctc_len_scale, char_width, device):
    """Transcribe generated word images with an OCR model and score them against their labels.

    Image widths are taken as ``label_len * char_width``. CTC outputs are trimmed to
    ``width // ctc_len_scale`` steps and greedily decoded.

    Returns:
        (cer, wer, predictions, ground_truths), where CER is total Levenshtein distance divided
        by total ground-truth characters, and WER is the fraction of words with any character
        error. Both are 0 when nothing is scored.
    """
    model_rec.eval()

    chunk = 32
    n_total = generated_imgs.size(0)
    predictions, ground_truths = [], []

    with torch.no_grad():
        for lo in tqdm(range(0, n_total, chunk), desc="Running OCR"):
            hi = min(lo + chunk, n_total)
            lens = generated_label_lens[lo:hi]
            widths = lens * char_width

            probs = F.softmax(
                model_rec(generated_imgs[lo:hi].to(device), widths.to(device)), dim=2
            ).cpu().numpy()

            for prob_seq, width in zip(probs, widths.cpu().numpy()):
                decoded = ctc_greedy_decoder(prob_seq[:width // ctc_len_scale])
                predictions.append(label_conv.decode(decoded))

            ground_truths.extend(label_conv.decode(generated_labels[lo:hi], lens))

    pairs = list(zip(predictions, ground_truths))
    edit_dists = [levenshtein(pred, gt) for pred, gt in pairs]
    n_chars = sum(len(gt) for _, gt in pairs)
    n_words = len(pairs)
    n_wrong_words = sum(1 for dist in edit_dists if dist > 0)

    cer = sum(edit_dists) / n_chars if n_chars > 0 else 0
    wer = n_wrong_words / n_words if n_words > 0 else 0

    return cer, wer, predictions, ground_truths

# Ratio between input image width and recognizer output length, used to trim CTC outputs.
print("Preparing for OCR evaluation...")
ctc_len_scale = recognizer.len_scale
print("CTC length scale: {}".format(ctc_len_scale))

In [None]:
# Evaluate OCR on style-guided samples
print("\nEvaluating OCR on style-guided samples...")

# Re-read the labels of the same 5 test batches (test_loader is not shuffled, so the order
# matches the style-guided images).
style_guided_lbs = []
style_guided_lb_lens = []

with torch.no_grad():
    for batch_idx, batch in enumerate(test_loader):
        if batch_idx >= 5:
            break
        style_guided_lbs.append(batch['lbs'])
        style_guided_lb_lens.append(batch['lb_lens'])

# 2-D label tensors are padded only to each batch's longest word; right-pad with 0 to a common
# length before concatenating. Entries beyond lb_lens are never decoded, so the pad value is inert.
if style_guided_lbs and style_guided_lbs[0].dim() == 2:
    _max_lb = max(t.size(1) for t in style_guided_lbs)
    style_guided_lbs = [F.pad(t, (0, _max_lb - t.size(1))) for t in style_guided_lbs]
style_guided_lbs = torch.cat(style_guided_lbs, dim=0)
style_guided_lb_lens = torch.cat(style_guided_lb_lens, dim=0)

if style_guided_lb_lens.size(0) != style_guided_imgs.size(0):
    print(f"⚠ Label count ({style_guided_lb_lens.size(0)}) does not match "
          f"style-guided image count ({style_guided_imgs.size(0)})")

cer_style, wer_style, preds_style, gts_style = evaluate_ocr_metrics(
    recognizer,
    style_guided_imgs,
    style_guided_lbs,
    style_guided_lb_lens,
    label_converter,
    ctc_len_scale,
    cfg.char_width,
    device
)

print(f"\n{'='*60}")
print(f"OCR Metrics (Style-Guided Generation):")
print(f"  Character Error Rate (CER): {cer_style*100:.2f}%")
print(f"  Word Error Rate (WER): {wer_style*100:.2f}%")
print(f"  Character Accuracy: {(1-cer_style)*100:.2f}%")
print(f"  Word Accuracy: {(1-wer_style)*100:.2f}%")
print(f"{'='*60}")

In [None]:
# Show some OCR predictions vs ground truth
print("\nSample OCR Predictions vs Ground Truth:")
print("-" * 80)

n_samples = min(20, len(preds_style))
correct = 0

for i in range(n_samples):
    is_match = preds_style[i] == gts_style[i]
    match = "✓" if is_match else "✗"
    if is_match:
        correct += 1
    print(f"{match} GT: '{gts_style[i]:15s}' | Pred: '{preds_style[i]:15s}'")

print("-" * 80)
# Guard against ZeroDivisionError when no predictions were produced.
if n_samples:
    print(f"Accuracy in sample: {correct}/{n_samples} ({100*correct/n_samples:.1f}%)")
else:
    print("Accuracy in sample: no OCR predictions available")

## 7. Calculate Writer Identification Accuracy

Evaluate how well the model preserves writer-specific characteristics.

In [None]:
def evaluate_writer_identification(model_wid, model_backbone, generated_imgs,
                                   generated_lb_lens, true_wids, char_width, device):
    """Predict the writer of each generated image and compare with the reference writer ids.

    Image widths are taken as ``label_len * char_width``. Images are processed in chunks of 32
    under ``torch.no_grad()``, and the prediction is the argmax over the writer-ID logits.

    Returns:
        (accuracy, all_preds, all_true): accuracy is 0 when nothing is evaluated; the two lists
        hold the predicted and reference ids as Python ints.
    """
    for net in (model_wid, model_backbone):
        net.eval()

    step = 32
    n_total = generated_imgs.size(0)
    all_preds, all_true = [], []
    n_hits = 0
    n_seen = 0

    with torch.no_grad():
        for lo in tqdm(range(0, n_total, step), desc="Writer ID Evaluation"):
            hi = min(lo + step, n_total)
            batch_imgs = generated_imgs[lo:hi].to(device)
            widths = generated_lb_lens[lo:hi] * char_width
            expected = true_wids[lo:hi]

            logits = model_wid(batch_imgs, widths.to(device), model_backbone)
            predicted = logits.argmax(dim=1).cpu()

            n_hits += int((predicted == expected).sum())
            n_seen += batch_imgs.size(0)

            all_preds.extend(predicted.tolist())
            all_true.extend(expected.tolist())

    accuracy = n_hits / n_seen if n_seen > 0 else 0

    return accuracy, all_preds, all_true

print("Evaluating writer identification accuracy...")

# Writer ids of the same 5 (unshuffled) test batches used for style-guided generation.
# NOTE(review): if the test-split writers are disjoint from the writers the identifier was
# trained on, this accuracy is not meaningful — confirm the dataset split.
with torch.no_grad():
    style_guided_wids = torch.cat(
        [batch['wids'] for _, batch in zip(range(5), test_loader)], dim=0
    )

wid_accuracy, wid_preds, wid_true = evaluate_writer_identification(
    writer_identifier,
    style_backbone,
    style_guided_imgs,
    style_guided_lb_lens,
    style_guided_wids,
    cfg.char_width,
    device,
)

_banner = '=' * 60
print(f"\n{_banner}")
print(f"Writer Identification Accuracy: {wid_accuracy*100:.2f}%")
print(_banner)
print(f"Total samples evaluated: {len(wid_preds)}")
print(f"Correct predictions: {sum(p == t for p, t in zip(wid_preds, wid_true))}")

## 8. Calculate Image Quality Metrics (SSIM, PSNR)

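Compute structural similarity (SSIM) and peak signal-to-noise ratio (PSNR) between each real test image and its style-guided counterpart, i.e. the same word generated from the style encoded from that image. These are not pixel-exact reconstructions, so the scores indicate only coarse similarity.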

In [None]:
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr

def calculate_image_quality_metrics(real_imgs, generated_imgs):
    """Calculate SSIM and PSNR between index-paired real and generated images.

    The first ``min(len(real_imgs), len(generated_imgs))`` pairs are scored. Each image is
    squeezed to 2-D and cropped to the common height/width of its pair: real images carry batch
    padding while generated width presumably follows the label length, so shapes rarely match,
    and skimage requires equal shapes. Pixels are then mapped from [-1, 1] to [0, 1] and scored
    with ``data_range=1.0``.

    Returns:
        (ssim_mean, ssim_std, psnr_mean, psnr_std)
    """
    ssim_scores = []
    psnr_scores = []

    n_samples = min(real_imgs.size(0), generated_imgs.size(0))

    for i in tqdm(range(n_samples), desc="Calculating SSIM/PSNR"):
        # Convert to numpy
        real_img = real_imgs[i].squeeze().cpu().numpy()
        gen_img = generated_imgs[i].squeeze().cpu().numpy()

        # Crop both images to their shared extent so the metrics compare equal-shaped arrays.
        h = min(real_img.shape[0], gen_img.shape[0])
        w = min(real_img.shape[1], gen_img.shape[1])
        real_img = real_img[:h, :w]
        gen_img = gen_img[:h, :w]

        # Normalize to [0, 1]
        real_img = (real_img + 1) / 2
        gen_img = (gen_img + 1) / 2

        ssim_scores.append(ssim(real_img, gen_img, data_range=1.0))
        psnr_scores.append(psnr(real_img, gen_img, data_range=1.0))

    return np.mean(ssim_scores), np.std(ssim_scores), np.mean(psnr_scores), np.std(psnr_scores)

print("Calculating SSIM and PSNR...")

# Compare each reference test image with its style-guided generation.
ssim_mean, ssim_std, psnr_mean, psnr_std = calculate_image_quality_metrics(
    reference_imgs, style_guided_imgs
)

_banner = '=' * 60
for _line in (
    f"\n{_banner}",
    "Image Quality Metrics:",
    f"  SSIM: {ssim_mean:.4f} ± {ssim_std:.4f}",
    f"  PSNR: {psnr_mean:.2f} ± {psnr_std:.2f} dB",
    _banner,
    "SSIM: 1.0 is perfect, > 0.7 is good",
    "PSNR: Higher is better, > 30 dB is good",
):
    print(_line)

## 9. Style Consistency Analysis

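Measure how much the style encoder's features vary across the randomly styled renderings of each fixed text (every word in `fixed_texts` is generated with `n_style_variations` random style vectors). Because the styles are random, this quantifies the spread of encoded styles for a fixed word, not whether a single style stays consistent across different words.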

In [None]:
def extract_style_features(model_encoder, model_backbone, imgs, img_lens, device):
    """Encode *imgs* into style vectors and return them on the CPU.

    Both networks are switched to eval mode. The encoder is called with ``vae_mode=False``
    inside ``torch.no_grad()``.
    """
    for net in (model_encoder, model_backbone):
        net.eval()

    with torch.no_grad():
        style_vecs = model_encoder(
            imgs.to(device), img_lens.to(device), model_backbone, vae_mode=False
        )

    return style_vecs.cpu()

def calculate_style_consistency(features):
    """Summarize how spread out a set of style feature vectors is.

    Args:
        features: CPU tensor of shape (n, d), one feature row per image.

    Returns:
        (variance, mean_similarity, std_similarity):
        - variance: mean over dimensions of the per-dimension variance (lower = rows closer).
        - mean/std of the pairwise cosine similarities over all pairs i < j (higher mean =
          more similar). Both are NaN when fewer than two rows are given.
    """
    features_np = features.numpy()

    variance = np.mean(np.var(features_np, axis=0))

    # Pairwise cosine similarity. Zero-norm rows keep a divisor of 1 so they score 0 against
    # every row, which matches sklearn.metrics.pairwise.cosine_similarity.
    norms = np.linalg.norm(features_np, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    unit_rows = features_np / norms
    similarities = unit_rows @ unit_rows.T

    n = similarities.shape[0]
    if n < 2:
        # No pairs to compare; avoid np.mean([]) warnings.
        return variance, float('nan'), float('nan')

    # Upper triangle, excluding the diagonal (self-similarity).
    pairwise_sims = similarities[np.triu_indices(n, k=1)]

    return variance, np.mean(pairwise_sims), np.std(pairwise_sims)

print("Analyzing style consistency...")

# One record per fixed text, built from the style features of all its random-style renderings.
consistency_results = []

for text_idx, text in enumerate(tqdm(fixed_texts, desc="Style Consistency")):
    text_imgs = fixed_text_imgs[text_idx]
    # Every rendering of a word shares the full tensor width.
    full_widths = torch.full((text_imgs.size(0),), text_imgs.size(3), dtype=torch.int)

    style_feats = extract_style_features(style_encoder, style_backbone, text_imgs, full_widths, device)
    var_val, sim_mean, sim_std = calculate_style_consistency(style_feats)

    consistency_results.append(dict(
        text=text,
        variance=var_val,
        mean_similarity=sim_mean,
        std_similarity=sim_std,
    ))

_banner = '=' * 60
print(f"\n{_banner}")
print("Style Consistency Results:")
print(_banner)

for rec in consistency_results:
    print(f"Text: '{rec['text']:10s}' | Variance: {rec['variance']:.4f} | "
          f"Similarity: {rec['mean_similarity']:.4f}±{rec['std_similarity']:.4f}")

avg_variance = np.mean([rec['variance'] for rec in consistency_results])
avg_similarity = np.mean([rec['mean_similarity'] for rec in consistency_results])

print(_banner)
print(f"Average Variance: {avg_variance:.4f} (lower = more consistent)")
print(f"Average Similarity: {avg_similarity:.4f} (higher = more consistent)")
print(_banner)

## 10. Diversity Metrics

Evaluate output diversity to ensure the model doesn't suffer from mode collapse.

In [None]:
def calculate_diversity_metrics(imgs, bins=50):
    """Compute pixel-space diversity statistics for a batch of images.

    Args:
        imgs: batch of images whose first axis indexes samples. Either a
            ``torch.Tensor`` (any device; gradients allowed) or anything
            ``np.asarray`` accepts.
        bins: number of histogram bins used for the pixel-value entropy
            (default 50, matching the original hard-coded value).

    Returns:
        (mean_dist, std_dist, entropy):
            mean_dist, std_dist -- mean and standard deviation of the Euclidean
                distances between every unordered pair of flattened samples;
                NaN when fewer than two samples are given.
            entropy -- Shannon entropy (nats) of the histogram of all pixel
                values, computed from bin probabilities (counts / total).

    Raises:
        ValueError: if the batch contains no samples.
    """
    from scipy.spatial.distance import pdist

    # Detach/move tensors so CUDA or grad-tracking inputs do not make .numpy() fail.
    if isinstance(imgs, torch.Tensor):
        imgs_np = imgs.detach().cpu().numpy()
    else:
        imgs_np = np.asarray(imgs)

    n_samples = imgs_np.shape[0]
    if n_samples == 0:
        raise ValueError("calculate_diversity_metrics needs at least one image")

    # Reshape to (n_samples, -1)
    imgs_flat = imgs_np.reshape(n_samples, -1).astype(np.float64)

    # pdist returns the condensed upper triangle (pairs i < j) directly,
    # without materializing the full n x n distance matrix.
    pairwise_dists = pdist(imgs_flat, metric='euclidean')
    if pairwise_dists.size == 0:
        # A single sample has no pairs; report NaN explicitly instead of
        # letting np.mean warn on an empty array.
        mean_dist = std_dist = float('nan')
    else:
        mean_dist = float(np.mean(pairwise_dists))
        std_dist = float(np.std(pairwise_dists))

    # Entropy must be computed from probabilities. density=True values are
    # divided by the bin width, so -sum(d * log d) would depend on the pixel
    # range and can even be negative.
    counts, _ = np.histogram(imgs_flat.ravel(), bins=bins)
    probs = counts[counts > 0] / counts.sum()
    entropy = float(-np.sum(probs * np.log(probs)))

    return mean_dist, std_dist, entropy

print("Calculating diversity metrics...")

# Intra-class diversity: pairwise spread within the batch of generated images
# stored for each fixed text (fixed_text_imgs[i] belongs to fixed_texts[i]).
intra_diversities = []
for text_idx in range(len(fixed_texts)):
    imgs = fixed_text_imgs[text_idx]
    mean_dist, std_dist, entropy = calculate_diversity_metrics(imgs)
    intra_diversities.append({
        'text': fixed_texts[text_idx],
        'mean_dist': mean_dist,
        'std_dist': std_dist,
        'entropy': entropy
    })

# Inter-class diversity: pairwise spread across (up to) the first 100 images
# generated from random styles.
inter_mean_dist, inter_std_dist, inter_entropy = calculate_diversity_metrics(random_generated_imgs[:100])

print(f"\n{'='*60}")
print("Diversity Metrics:")
print(f"{'='*60}")

print("\nIntra-class Diversity (per fixed text, across its generated variations):")
for result in intra_diversities:
    print(f"  Text '{result['text']:10s}': Distance={result['mean_dist']:.4f}±{result['std_dist']:.4f}, "
          f"Entropy={result['entropy']:.4f}")

avg_intra_dist = np.mean([r['mean_dist'] for r in intra_diversities])
avg_intra_entropy = np.mean([r['entropy'] for r in intra_diversities])

print(f"\nAverage Intra-class Distance: {avg_intra_dist:.4f}")
print(f"Average Intra-class Entropy: {avg_intra_entropy:.4f}")

print("\nInter-class Diversity (different random styles):")
print(f"  Mean Distance: {inter_mean_dist:.4f}±{inter_std_dist:.4f}")
print(f"  Entropy: {inter_entropy:.4f}")

print(f"\n{'='*60}")
# A zero or NaN intra-class distance makes the ratio meaningless (inf/NaN).
if avg_intra_dist > 0:
    print(f"Diversity Ratio (Inter/Intra): {inter_mean_dist / avg_intra_dist:.2f}")
else:
    print("Diversity Ratio (Inter/Intra): undefined (average intra-class distance is 0 or NaN)")
print(f"{'='*60}")
print("(Higher ratio indicates better diversity across different styles)")

## 11. Qualitative Visualizations - Style Transfer

Visualize each reference image next to its style-guided reconstruction, together with an OCR check of the reconstruction.

In [None]:
# Create comprehensive style transfer visualization.
# Cap the row count at the number of available samples so that a short
# evaluation run does not raise IndexError.
n_examples = min(8, len(reference_imgs), len(style_guided_imgs), len(style_guided_texts))

# squeeze=False keeps `axes` two-dimensional even when n_examples == 1.
fig, axes = plt.subplots(n_examples, 3, figsize=(15, 3*n_examples), squeeze=False)

for i in range(n_examples):
    # Reference image (1 - x inverts intensities for display)
    ref_img = (1 - reference_imgs[i].squeeze().numpy())
    axes[i, 0].imshow(ref_img, cmap='gray')
    axes[i, 0].set_title(f'Reference: "{style_guided_texts[i]}"', fontsize=10)
    axes[i, 0].axis('off')

    # Reconstruction
    recon_img = (1 - style_guided_imgs[i].squeeze().numpy())
    axes[i, 1].imshow(recon_img, cmap='gray')
    axes[i, 1].set_title('Reconstruction', fontsize=10)
    axes[i, 1].axis('off')

    # OCR prediction (empty strings if the OCR lists are shorter)
    pred_text = preds_style[i] if i < len(preds_style) else ""
    gt_text = gts_style[i] if i < len(gts_style) else ""
    verdict = "✓" if pred_text == gt_text else "✗"

    axes[i, 2].text(0.5, 0.5, f'{verdict}\nGT: "{gt_text}"\nPred: "{pred_text}"',
                    ha='center', va='center', fontsize=9,
                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
    axes[i, 2].axis('off')

plt.suptitle('Style Transfer: Reference → Reconstruction → OCR Verification',
             fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(output_dir / 'visualization_style_transfer.png', dpi=150, bbox_inches='tight')
plt.show()

print("✓ Style transfer visualization saved")

## 12. Qualitative Visualizations - Random Generation

Display a grid of randomly generated samples.

In [None]:
# Random generation grid
n_rows = 4
n_cols = 6
n_total = n_rows * n_cols

fig, axes = plt.subplots(n_rows, n_cols, figsize=(18, 10))
axes = axes.ravel()

n_available = random_generated_imgs.size(0)
for i, ax in enumerate(axes):
    # Hide every frame first so cells left empty (fewer generated images
    # than grid slots) do not show bare axes.
    ax.axis('off')
    if i >= n_available:
        continue

    img = (1 - random_generated_imgs[i].squeeze().numpy())
    text = random_generated_texts[i] if i < len(random_generated_texts) else ""

    ax.imshow(img, cmap='gray')
    ax.set_title(f'"{text}"', fontsize=8)

plt.suptitle('Random Style Generation - Diverse Samples', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(output_dir / 'visualization_random_generation.png', dpi=150, bbox_inches='tight')
plt.show()

print("✓ Random generation visualization saved")

## 13. Qualitative Visualizations - Style Interpolation

Show smooth transitions between different writing styles.

In [None]:
# Generate style interpolation sequences: each row blends linearly between
# two random style vectors while rendering the same text.
interpolation_text = "Interpolate"
n_interpolations = 3
n_steps = 8

# squeeze=False keeps `axes` 2-D even if n_interpolations or n_steps is 1.
fig, axes = plt.subplots(n_interpolations, n_steps, figsize=(20, 3*n_interpolations), squeeze=False)

with torch.no_grad():
    # The text is identical for every row, so encode it once.
    text_lbs, text_lb_lens = label_converter.encode([interpolation_text], 25)
    text_lbs = text_lbs.to(device)
    text_lb_lens = text_lb_lens.to(device)

    for interp_idx in range(n_interpolations):
        # Two random styles
        style_a = torch.randn(1, generator.style_dim).to(device)
        style_b = torch.randn(1, generator.style_dim).to(device)

        for step_idx in range(n_steps):
            # Evenly spaced weights in [0, 1]; a single step stays at style_a.
            alpha = step_idx / (n_steps - 1) if n_steps > 1 else 0.0
            style_interp = (1 - alpha) * style_a + alpha * style_b

            gen_img = generator(style_interp, text_lbs, text_lb_lens)

            img_np = (1 - gen_img.squeeze().cpu().numpy())
            axes[interp_idx, step_idx].imshow(img_np, cmap='gray')
            axes[interp_idx, step_idx].set_title(f'α={alpha:.2f}', fontsize=8)
            axes[interp_idx, step_idx].axis('off')

plt.suptitle(f'Style Interpolation: "{interpolation_text}"', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(output_dir / 'visualization_interpolation.png', dpi=150, bbox_inches='tight')
plt.show()

print("✓ Style interpolation visualization saved")

## 14. Qualitative Visualizations - Text Variation

Show how the same style renders different text content.

In [None]:
# Render several different strings with one shared random style vector.
varied_texts = ["Hello", "World", "123", "CAPS", "lower", "Mixed", "Special!", "PyTorch"]

fig, axes = plt.subplots(2, 4, figsize=(16, 6))
axes = axes.ravel()

with torch.no_grad():
    # A single style vector reused for every string below
    consistent_style = torch.randn(1, generator.style_dim).to(device)

    for i, text in enumerate(varied_texts):
        lbs, lb_lens = label_converter.encode([text], 25)
        gen_img = generator(consistent_style, lbs.to(device), lb_lens.to(device))

        panel = axes[i]
        panel.imshow(1 - gen_img.squeeze().cpu().numpy(), cmap='gray')
        panel.set_title(f'"{text}"', fontsize=11, fontweight='bold')
        panel.axis('off')

plt.suptitle('Same Writing Style - Different Text Content', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(output_dir / 'visualization_text_variation.png', dpi=150, bbox_inches='tight')
plt.show()

print("✓ Text variation visualization saved")

## 15. Cross-Writer Style Transfer Evaluation

Transfer multiple writer styles to the same text content.

In [None]:
# Cross-writer style transfer
target_text = "Transfer"
n_writers = 8

# Collect the first sample of up to n_writers distinct writers from the test set.
writer_samples = []
writer_ids = []

for batch in test_loader:
    for i in range(batch['style_imgs'].size(0)):
        if len(writer_ids) >= n_writers:
            break
        wid = batch['wids'][i].item()
        if wid in writer_ids:
            continue
        writer_ids.append(wid)
        writer_samples.append({
            'img': batch['style_imgs'][i:i+1],
            'img_len': batch['style_img_lens'][i:i+1],
            'wid': wid,
            'text': label_converter.decode(batch['lbs'][i:i+1], batch['lb_lens'][i:i+1])[0]
        })

    if len(writer_ids) >= n_writers:
        break

if len(writer_samples) < n_writers:
    print(f"⚠ Only {len(writer_samples)} distinct writers found in test_loader (requested {n_writers})")

if writer_samples:
    # Size the figure by the writers actually found; squeeze=False keeps
    # `axes` 2-D even for a single writer.
    n_found = len(writer_samples)
    fig, axes = plt.subplots(n_found, 3, figsize=(12, 3*n_found), squeeze=False)

    with torch.no_grad():
        # Encode target text once; it is shared by every writer row.
        target_lbs, target_lb_lens = label_converter.encode([target_text], 25)
        target_lbs = target_lbs.to(device)
        target_lb_lens = target_lb_lens.to(device)

        for i, sample in enumerate(writer_samples):
            # Reference image
            ref_img = (1 - sample['img'].squeeze().numpy())
            axes[i, 0].imshow(ref_img, cmap='gray')
            axes[i, 0].set_title(f'Writer {sample["wid"]}\n"{sample["text"]}"', fontsize=9)
            axes[i, 0].axis('off')

            # Extract style from the reference image
            style = style_encoder(sample['img'].to(device), sample['img_len'].to(device),
                                  style_backbone, vae_mode=False)

            # Generate target text in that style
            gen_img = generator(style, target_lbs, target_lb_lens)
            gen_img_np = (1 - gen_img.squeeze().cpu().numpy())
            axes[i, 1].imshow(gen_img_np, cmap='gray')
            axes[i, 1].set_title(f'Generated\n"{target_text}"', fontsize=9)
            axes[i, 1].axis('off')

            # Info
            axes[i, 2].text(0.5, 0.5, f'Writer ID: {sample["wid"]}\nStyle → "{target_text}"',
                            ha='center', va='center', fontsize=9,
                            bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.5))
            axes[i, 2].axis('off')

    plt.suptitle(f'Cross-Writer Style Transfer: Multiple Writers → "{target_text}"',
                 fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(output_dir / 'visualization_cross_writer.png', dpi=150, bbox_inches='tight')
    plt.show()

    print("✓ Cross-writer style transfer visualization saved")

## 16. Feature Map Visualization

Visualize the channel-averaged outputs of the style backbone's convolutional layers to see which image regions respond most strongly during style extraction.

In [None]:
import torch.nn as nn

def extract_feature_maps(model_backbone, img, img_len):
    """Capture the output of every ``nn.Conv2d`` layer in *model_backbone*.

    Temporary forward hooks are registered on each ``nn.Conv2d`` submodule
    (subclasses included). One forward pass is then run in eval mode under
    ``torch.no_grad()``, and the captured outputs are returned.

    Args:
        model_backbone: ``nn.Module`` to probe; it is called as
            ``model_backbone(img)``.
        img: input tensor. It is moved to the device of the backbone's first
            parameter, or left on its own device if the backbone has none.
        img_len: kept for call-site compatibility; not passed to the forward call.

    Returns:
        list of tensors: one entry per Conv2d invocation, in call order,
        detached and moved to CPU.
    """
    feature_maps = []

    # Hook to capture intermediate outputs
    def hook_fn(module, inputs, output):
        feature_maps.append(output.detach().cpu())

    # Run on the device the backbone actually lives on instead of a notebook global.
    first_param = next(model_backbone.parameters(), None)
    target_device = first_param.device if first_param is not None else img.device

    was_training = model_backbone.training
    model_backbone.eval()

    # Register hooks on convolutional layers
    hooks = [module.register_forward_hook(hook_fn)
             for module in model_backbone.modules()
             if isinstance(module, nn.Conv2d)]
    try:
        with torch.no_grad():
            model_backbone(img.to(target_device))
    finally:
        # Always remove the hooks, even if the forward pass raises, so they do
        # not accumulate on the model across calls. Restore the caller's
        # train/eval mode as well.
        for hook in hooks:
            hook.remove()
        model_backbone.train(was_training)

    return feature_maps

# Visualize feature maps for sample images
n_samples = 3
sample_batch = next(iter(test_loader))

fig = plt.figure(figsize=(18, 4*n_samples))

for sample_idx in range(n_samples):
    img = sample_batch['style_imgs'][sample_idx:sample_idx+1]
    img_len = sample_batch['style_img_lens'][sample_idx:sample_idx+1]
    text = label_converter.decode(sample_batch['lbs'][sample_idx:sample_idx+1], 
                                  sample_batch['lb_lens'][sample_idx:sample_idx+1])[0]
    
    # Extract feature maps
    feature_maps = extract_feature_maps(style_backbone, img, img_len)
    
    # Select a few layers to visualize
    n_layers = min(5, len(feature_maps))
    layer_indices = np.linspace(0, len(feature_maps)-1, n_layers, dtype=int)
    
    for i, layer_idx in enumerate(layer_indices):
        feat_map = feature_maps[layer_idx][0]  # First sample in batch
        
        # Average across channels
        feat_map_avg = torch.mean(feat_map, dim=0).numpy()
        
        ax = plt.subplot(n_samples, n_layers + 1, sample_idx * (n_layers + 1) + i + 1)
        
        if i == 0:
            # Show original image
            orig_img = (1 - img.squeeze().numpy())
            ax.imshow(orig_img, cmap='gray')
            ax.set_title(f'Input: "{text}"', fontsize=9)
        else:
            # Show feature map
            im = ax.imshow(feat_map_avg, cmap='viridis')
            ax.set_title(f'Layer {layer_idx}', fontsize=8)
            plt.colorbar(im, ax=ax, fraction=0.046)
        
        ax.axis('off')

plt.suptitle('Style Encoder Feature Maps Visualization', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(output_dir / 'visualization_feature_maps.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"✓ Feature maps visualization saved")

## 17. Failure Case Analysis

Identify and analyze failure modes of the model.

In [None]:
# Rank generated samples by per-sample character error rate (CER) of the OCR.
print("Analyzing failure cases...")

individual_errors = []
for idx, (pred, gt) in enumerate(zip(preds_style, gts_style)):
    n_chars = len(gt)
    edits = levenshtein(pred, gt)
    individual_errors.append({
        'index': idx,
        'gt': gt,
        'pred': pred,
        # An empty ground truth is scored as 0 CER rather than dividing by zero.
        'cer': edits / n_chars if n_chars > 0 else 0,
        'char_error': edits,
        'word_length': n_chars,
    })

# Worst first (in-place, stable sort)
individual_errors.sort(key=lambda entry: entry['cer'], reverse=True)

n_worst = 12
worst_cases = individual_errors[:n_worst]

print(f"\nTop {n_worst} Failure Cases (Highest CER):")
print("-" * 80)
for rank, case in enumerate(worst_cases, start=1):
    print(f"{rank:2d}. GT: '{case['gt']:15s}' | Pred: '{case['pred']:15s}' | "
          f"CER: {case['cer']*100:.1f}% | Len: {case['word_length']}")

In [None]:
# Visualize failure cases on a 4x3 grid, one panel per entry of worst_cases.
fig, axes = plt.subplots(4, 3, figsize=(15, 12))
axes = axes.ravel()

# Hide every frame first so panels without a case stay blank.
for ax in axes:
    ax.axis('off')

# zip stops at the shorter sequence, so a larger n_worst cannot overrun the grid.
for ax, case in zip(axes, worst_cases):
    idx = case['index']
    if idx >= style_guided_imgs.size(0):
        continue

    img = (1 - style_guided_imgs[idx].squeeze().numpy())
    ax.imshow(img, cmap='gray')
    ax.set_title(f"GT: '{case['gt']}'\nPred: '{case['pred']}'\nCER: {case['cer']*100:.1f}%",
                 fontsize=9, color='red')

plt.suptitle('Failure Case Analysis - Highest OCR Errors', fontsize=14, fontweight='bold', color='red')
plt.tight_layout()
plt.savefig(output_dir / 'analysis_failure_cases.png', dpi=150, bbox_inches='tight')
plt.show()

print("✓ Failure case visualization saved")

In [None]:
# Analyze error patterns
print("\n" + "="*80)
print("Error Pattern Analysis:")
print("="*80)

# By word length. Bins are half-open [lo, hi), so the label '3-5' covers
# lengths 3 and 4; lengths >= 100 are not counted.
length_bins = [0, 3, 5, 7, 10, 100]
length_errors = {f'{length_bins[i]}-{length_bins[i+1]}': []
                 for i in range(len(length_bins)-1)}

for case in individual_errors:
    for i in range(len(length_bins)-1):
        if length_bins[i] <= case['word_length'] < length_bins[i+1]:
            length_errors[f'{length_bins[i]}-{length_bins[i+1]}'].append(case['cer'])
            break

print("\nCER by Word Length:")
for length_range, cers in length_errors.items():
    if cers:
        print(f"  {length_range:8s}: {np.mean(cers)*100:.2f}% ± {np.std(cers)*100:.2f}% "
              f"(n={len(cers)})")

# Character-level analysis: each occurrence of a character in a GT string
# contributes that sample's CER (repeated characters count every time).
char_errors = {}
for case in individual_errors:
    for char in case['gt']:
        char_errors.setdefault(char, []).append(case['cer'])

# Most problematic characters (all characters, sorted by mean CER)
sorted_chars = sorted(char_errors.items(), key=lambda x: np.mean(x[1]), reverse=True)

# Apply the frequency filter BEFORE taking the top 10. Filtering after the
# slice could print fewer than 10 characters even when enough qualify.
min_char_occurrences = 5
problematic_chars = [(char, cers) for char, cers in sorted_chars
                     if len(cers) >= min_char_occurrences][:10]

print("\nTop 10 Most Problematic Characters:")
for char, cers in problematic_chars:
    print(f"  '{char}': {np.mean(cers)*100:.2f}% (n={len(cers)})")

## 18. Comprehensive Performance Report

Generate final summary with all metrics, charts, and export results.