# CRNN OCR 训练（Scheme B，自包含 Notebook 版，Windows/Notebook 适配）

本 Notebook 实现了无需直接运行 code/train.py 的完整训练流程：
- 使用 Windows/Notebook 友好设置（DataLoader num_workers=0；无 multiprocessing_context='fork'）。
- 日志目录 runs/，检查点目录 checkpoints/。
- 支持断点续训、早停、最佳模型保存。
- 可选的数据集字符覆盖检查。
- 可选推理示例。

请确保：
- 你已经准备好了 data/train.lmdb 和可选的 data/val.lmdb。
- models/model.py 与 dataset.py 可被导入（本仓库已有）。
- requirements.txt 已安装依赖。


In [11]:
# 可选：安装依赖（如已安装可跳过）
# 在某些环境下，建议手工在终端执行：pip install -r requirements.txt
# !pip -q install -r requirements.txt
import sys, os
print('Python:', sys.version)
print('CWD:', os.getcwd())


Python: 3.12.11 | packaged by Anaconda, Inc. | (main, Jun  5 2025, 12:58:53) [MSC v.1929 64 bit (AMD64)]
CWD: C:\workspace\mys-ocr\notebooks


In [12]:
# Imports 与字符集、工具函数
import os
import re
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter

# 字符集设置（与 code/train.py 保持一致，必要时按你的数据修改）
BLANK = '-'
CHARS = BLANK + "OP12/403ADC@E"

# 重复字符检查
if len(set(CHARS)) != len(CHARS):
    duplicates = set([c for c in CHARS if CHARS.count(c) > 1])
    raise ValueError(f'字符集 CHARS 存在重复字符: {sorted(duplicates)}，请检查并去重！')

nclass = len(CHARS)
char2idx = {c: i for i, c in enumerate(CHARS)}
idx2char = {i: c for i, c in enumerate(CHARS)}

def text_to_indices(text):
    return [char2idx[c] for c in text if c in char2idx]

def collate_fn(batch):
    images, labels = zip(*batch)
    images = torch.stack(images)
    label_indices = [torch.tensor(text_to_indices(label), dtype=torch.long) for label in labels]
    label_lengths = torch.tensor([len(l) for l in label_indices], dtype=torch.long)
    labels_concat = torch.cat(label_indices)
    return images, labels_concat, label_lengths

@torch.no_grad()
def decode(preds):
    # preds: (T, N, C) logit 或 log_softmax 输出
    preds = preds.argmax(2)
    preds = preds.permute(1, 0)  # (batch, seq)
    texts = []
    for pred in preds:
        char_list = []
        prev_idx = 0
        for idx in pred:
            idx = idx.item()
            if idx != 0 and idx != prev_idx:
                char_list.append(idx2char[idx])
            prev_idx = idx
        texts.append(''.join(char_list))
    return texts


In [13]:
# 导入模型与数据集（复用仓库实现）
import sys, os
repo_root = os.path.dirname(os.getcwd()) if os.path.basename(os.getcwd()).lower() == 'notebooks' else os.getcwd()
code_dir = os.path.join(repo_root, 'code')
if code_dir not in sys.path:
    sys.path.append(code_dir)

# 兼容多种项目结构的导入尝试
try:
    from model import CRNN          # 优先 code/model.py
    from dataset import OCRDataset  # 优先 code/dataset.py
except ModuleNotFoundError:
    try:
        from models.model import CRNN   # 备选：models/model.py
        from dataset import OCRDataset
    except ModuleNotFoundError:
        # 最后尝试：从仓库根路径直接导入（若用户将文件放到了根路径）
        if repo_root not in sys.path:
            sys.path.append(repo_root)
        from model import CRNN
        from dataset import OCRDataset
print('已导入 CRNN 与 OCRDataset')


已导入 CRNN 与 OCRDataset


In [14]:
# 训练配置（Windows/Notebook 友好）
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

# 模型超参数
imgH = 32
nh = 256
nc = 1

# 训练超参数（可按需调整）
finetune = False
pretrained_model = None  # 示例: os.path.join(repo_root, 'checkpoints', 'crnn_best.pth')
batch_size = 64
n_epoch = 50
lr = 5e-4
weight_decay = 1e-4
scheduler_step = 30
scheduler_gamma = 0.7
patience = 10

# 数据路径（LMDB）
train_lmdb = os.path.join(repo_root, 'data', 'train.lmdb')
val_lmdb = os.path.join(repo_root, 'data', 'val.lmdb')
use_val = os.path.isdir(val_lmdb) and any(os.scandir(val_lmdb)) if os.path.exists(val_lmdb) else False

# 日志与检查点目录
log_dir = os.path.join(repo_root, 'runs')
ckpt_dir = os.path.join(repo_root, 'checkpoints')
os.makedirs(log_dir, exist_ok=True)
os.makedirs(ckpt_dir, exist_ok=True)

# Data transforms
transform = transforms.Compose([
    transforms.Resize((imgH, 100)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# 数据集与 DataLoader（Windows: num_workers=0）
train_dataset = OCRDataset(lmdb_path=train_lmdb, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                          collate_fn=collate_fn, num_workers=0,
                          pin_memory=torch.cuda.is_available())

val_loader = None
if use_val:
    val_dataset = OCRDataset(lmdb_path=val_lmdb, transform=transform)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                            collate_fn=collate_fn, num_workers=0,
                            pin_memory=torch.cuda.is_available())

# 模型
model = CRNN(imgH, nc, nclass, nh).to(device)

# 微调：加载预训练
if finetune and pretrained_model and os.path.exists(pretrained_model):
    ckpt = torch.load(pretrained_model, map_location=device)
    if isinstance(ckpt, dict) and 'model_state_dict' in ckpt:
        model.load_state_dict(ckpt['model_state_dict'])
    else:
        model.load_state_dict(ckpt)
    print('Loaded pretrained weights from', pretrained_model)

criterion = nn.CTCLoss(blank=0, zero_infinity=True)
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()) if finetune else model.parameters(),
                       lr=lr, weight_decay=weight_decay, betas=(0.9, 0.95) if finetune else (0.9, 0.999))
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma)
writer = SummaryWriter(log_dir=log_dir)
print('Dataloaders ready. Model initialized.')


Using device: cuda


Error: C:\workspace\mys-ocr\data\train.lmdb: ϵͳ�Ҳ���ָ����·����


In [None]:
# 可选：数据集字符覆盖检查
@torch.no_grad()
def check_dataset_chars(dataset, char2idx):
    dataset_chars = set()
    print('正在检查数据集字符...')
    for i in range(len(dataset)):
        if i % 1000 == 0:
            print(f'已检查 {i}/{len(dataset)} 个样本')
        _, label = dataset[i]
        for char in label:
            dataset_chars.add(char)
    missing_chars = dataset_chars - set(char2idx.keys())
    extra_chars = set(char2idx.keys()) - dataset_chars - {BLANK}
    if missing_chars:
        print('数据集中存在字符集未包含的字符：', sorted(missing_chars))
        raise ValueError('请修改 CHARS 以覆盖全部数据字符')
    if extra_chars:
        print('提示：字符集中存在数据集中未出现的字符：', sorted(extra_chars))
    print(f'数据集字符检查完成，共发现 {len(dataset_chars)} 个字符')
    return True

# 按需启用：
# check_dataset_chars(train_dataset, char2idx)
# if val_loader is not None:
#     check_dataset_chars(val_dataset, char2idx)


In [None]:
# 训练循环（无交互 input，含早停与checkpoint）
best_val_loss = float('inf')
best_epoch = 0
early_stop_counter = 0
start_epoch = 0
resume_from = None  # 可填写 checkpoint 路径以断点续训

# 断点恢复（可选）
if resume_from and os.path.exists(resume_from):
    print('Resuming from', resume_from)
    checkpoint = torch.load(resume_from, map_location=device)
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        start_epoch = checkpoint.get('epoch', 0)
        best_val_loss = checkpoint.get('best_val_loss', float('inf'))
        best_epoch = checkpoint.get('best_epoch', 0)
        early_stop_counter = checkpoint.get('early_stop_counter', 0)
    else:
        model.load_state_dict(checkpoint)
        m = re.search(r'epoch(\d+)', os.path.basename(resume_from))
        start_epoch = int(m.group(1)) if m else 0

for epoch in range(start_epoch, n_epoch):
    model.train()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    for images, labels, label_lengths in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        label_lengths = label_lengths.to(device)

        preds = model(images)                                # (T, N, C)
        preds_log_softmax = nn.functional.log_softmax(preds, 2)
        input_lengths = torch.full((images.size(0),), preds.size(0), dtype=torch.long, device=device)
        ctc_loss = criterion(preds_log_softmax, labels, input_lengths, label_lengths)

        # 轻度长度惩罚（可选）
        pred_texts_for_penalty = decode(preds_log_softmax)
        pred_lengths = torch.tensor([len(p) for p in pred_texts_for_penalty], dtype=torch.float32, device=device)
        target_length = 21.0
        length_penalty = nn.functional.mse_loss(pred_lengths, torch.full_like(pred_lengths, target_length))
        loss = ctc_loss + 0.1 * length_penalty

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += ctc_loss.item()

        # 训练准确率（严格字符串完全匹配）
        pred_texts = decode(preds_log_softmax)
        batch_size_eff = images.size(0)
        labels_cpu = labels.detach().cpu().numpy()
        label_lengths_cpu = label_lengths.detach().cpu().numpy()
        gt_texts = []
        start = 0
        for l in label_lengths_cpu:
            gt_texts.append(''.join([idx2char[i] for i in labels_cpu[start:start+l]]))
            start += l
        for p, g in zip(pred_texts, gt_texts):
            if p == g:
                total_correct += 1
        total_samples += batch_size_eff

    avg_train_loss = total_loss / max(1, len(train_loader))
    train_acc = total_correct / max(1, total_samples)

    writer.add_scalar('Loss/train', avg_train_loss, epoch + 1)
    writer.add_scalar('Acc/train', train_acc, epoch + 1)
    print(f'Epoch {epoch+1}/{n_epoch}, Loss: {avg_train_loss:.4f}, Acc: {train_acc:.4f}')

    # 验证
    if val_loader is not None:
        model.eval()
        with torch.no_grad():
            val_loss = 0.0
            val_correct = 0
            val_samples = 0
            for images, labels, label_lengths in val_loader:
                images = images.to(device)
                labels = labels.to(device)
                label_lengths = label_lengths.to(device)
                preds = model(images)
                preds_log_softmax = nn.functional.log_softmax(preds, 2)
                input_lengths = torch.full((images.size(0),), preds.size(0), dtype=torch.long, device=device)
                loss = criterion(preds_log_softmax, labels, input_lengths, label_lengths)
                val_loss += loss.item()

                pred_texts = decode(preds_log_softmax)
                batch_size_eff = images.size(0)
                labels_cpu = labels.detach().cpu().numpy()
                label_lengths_cpu = label_lengths.detach().cpu().numpy()
                gt_texts = []
                start = 0
                for l in label_lengths_cpu:
                    gt_texts.append(''.join([idx2char[i] for i in labels_cpu[start:start+l]]))
                    start += l
                for p, g in zip(pred_texts, gt_texts):
                    if p == g:
                        val_correct += 1
                val_samples += batch_size_eff

            avg_val_loss = val_loss / max(1, len(val_loader))
            val_acc = val_correct / max(1, val_samples)
            writer.add_scalar('Loss/val', avg_val_loss, epoch + 1)
            writer.add_scalar('Acc/val', val_acc, epoch + 1)
            print(f'Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.4f}')
    else:
        avg_val_loss = avg_train_loss
        val_acc = train_acc

    # 调度与日志
    scheduler.step()
    current_lr = optimizer.param_groups[0]['lr']
    writer.add_scalar('Learning_Rate', current_lr, epoch + 1)
    print(f'Learning Rate: {current_lr:.6f}')

    # 早停与保存
    improved = (avg_val_loss < best_val_loss)
    if improved:
        best_val_loss = avg_val_loss
        best_epoch = epoch + 1
        early_stop_counter = 0
        torch.save(model.state_dict(), os.path.join(ckpt_dir, 'crnn_best.pth'))
        print(f'保存最佳模型 (Epoch {best_epoch}, Loss: {best_val_loss:.4f})')
    else:
        early_stop_counter += 1
        print(f'验证损失未改善，早停计数器: {early_stop_counter}/{patience}')

    # 保存完整 checkpoint
    checkpoint = {
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'best_val_loss': best_val_loss,
        'best_epoch': best_epoch,
        'early_stop_counter': early_stop_counter,
        'loss': avg_train_loss,
        'val_loss': avg_val_loss,
        'acc': train_acc,
        'val_acc': val_acc,
        'finetune': finetune
    }
    torch.save(checkpoint, os.path.join(ckpt_dir, f'crnn_epoch{epoch+1}.pth'))

    # 早停
    if early_stop_counter >= patience:
        print(f'早停触发！最佳模型在 Epoch {best_epoch}, Loss: {best_val_loss:.4f}')
        break

writer.close()
print(f'训练完成！最佳模型在 Epoch {best_epoch}, Loss: {best_val_loss:.4f}')


In [None]:
# 推理/可视化示例（可选）
import os
from PIL import Image

best_pth = os.path.join(ckpt_dir, 'crnn_best.pth')
if os.path.exists(best_pth):
    model.load_state_dict(torch.load(best_pth, map_location=device))
    model.eval()
    print('最佳权重已加载:', best_pth)
else:
    print('未找到最佳权重文件:', best_pth)

# 示例：读取单张测试图并解码（请替换路径）
# img_path = r'C:\\path\\to\\your\\test.jpg'
# if os.path.exists(img_path):
#     img = Image.open(img_path).convert('L')
#     img = img.resize((100, imgH))
#     img = transforms.ToTensor()(img)
#     img = transforms.Normalize((0.5,), (0.5,))(img)
#     img = img.unsqueeze(0).to(device)
#     with torch.no_grad():
#         preds = model(img)
#         preds_log_softmax = nn.functional.log_softmax(preds, 2)
#     texts = decode(preds_log_softmax)
#     print('识别结果:', texts)
# else:
#     print('请设置有效的测试图像路径以运行推理示例')
