In [None]:
import os
import zipfile
import random
import json
import paddle
import sys
import numpy as np
from PIL import Image
from paddle.vision import transforms
import matplotlib.pyplot as plt
from paddle.io import Dataset

In [None]:
# Training configuration. Later cells read entries by key, and the data-list
# cell writes "label_dict" and "class_dim" back into this dict.
train_parameters = dict(
    # Cropped word-image size as [channels, height, width] (height 64, width 256).
    input_size=[3, 64, 256],
    # Number of classes; placeholder, filled in from the languages found in gt.txt.
    class_dim=0,
    # Zip archives holding the training word images (three parts).
    train_img_dirs=[
        "/home/aistudio/data/data336725/ch8_training_word_images_gt_part_1.zip",
        "/home/aistudio/data/data336725/ch8_training_word_images_gt_part_2.zip",
        "/home/aistudio/data/data336725/ch8_training_word_images_gt_part_3.zip",
    ],
    # Zip archive holding the validation word images (a single path despite the key name).
    eval_img_dirs="/home/aistudio/data/data336725/ch8_validation_word_images_gt.zip",
    # Directory the archives are extracted into.
    target_path="/home/aistudio/data",
    # Paths of the generated train / eval list files.
    train_list_path="/home/aistudio/data/train.txt",
    eval_list_path="/home/aistudio/data/eval.txt",
    # Language name -> class id; populated when the data lists are built.
    label_dict={},
    num_epochs=10,         # number of training epochs
    train_batch_size=32,   # batch size for training
    eval_batch_size=32,    # batch size for evaluation
    skip_steps=10,         # print a log line every N batches
    save_steps=300,        # save the model every N batches
    learning_strategy=dict(
        lr=0.0001,         # learning rate
    ),
    checkpoints="/home/aistudio/work/checkpoints",  # model checkpoint directory
)


In [None]:
# 解压数据

def unzip_data(zip_path, target_path, kind):
    """Extract one or more zip archives into ``<target_path>/<kind>``.

    If ``<target_path>/<kind>`` already exists as a directory, nothing is
    extracted. Re-running the cell therefore does not unpack the data again.

    Args:
        zip_path: Path of a single zip archive, or an iterable of archive paths.
            This notebook passes a list for "train" and a single path for "eval";
            both forms are accepted for any ``kind``.
        target_path: Parent directory under which the ``kind`` subdirectory is created.
        kind: Name of the subdirectory to extract into (e.g. "train" or "eval").
    """
    # Join with a separator. Plain concatenation ("/home/aistudio/data" + "train")
    # checks a path that never exists, so the archives were re-extracted every run.
    out_dir = os.path.join(target_path, kind)
    if os.path.isdir(out_dir):
        return

    # A bare path becomes a one-element list; iterating a str would walk its characters.
    if isinstance(zip_path, (str, os.PathLike)):
        archives = [zip_path]
    else:
        archives = list(zip_path)

    os.makedirs(out_dir, exist_ok=True)
    for archive in archives:
        # `with` closes every archive (previously only the last one was closed).
        # NOTE(review): all archives extract into the same directory, so a file name
        # present in several parts (e.g. a per-part gt.txt) would be overwritten by the
        # later part. Confirm the archive layout.
        with zipfile.ZipFile(archive, 'r') as zf:
            zf.extractall(path=out_dir)

# Read the archive locations from the config, then extract the train and eval splits.
train_img_dirs, eval_img_dirs, target_path = (
    train_parameters["train_img_dirs"],
    train_parameters["eval_img_dirs"],
    train_parameters["target_path"],
)

unzip_data(train_img_dirs, target_path, "train")
unzip_data(eval_img_dirs, target_path, "eval")

In [None]:
# 生成随机数据列表

def get_data_list(target_path, kind, train_parameters, seed=None):
    """Build ``<target_path>/<kind>.txt`` from ``<target_path>/<kind>/gt.txt`` and record class info.

    Each non-empty gt.txt line is parsed as ``<image name>,<language>[,<rest>]``.
    Only the first two comma-separated fields are used; ``maxsplit=2`` leaves any
    commas in the rest untouched. Each sample is written as
    ``<image path>\\t<label id>``, and the lines are written in shuffled order.

    Label ids are shared across calls. The function starts from
    ``train_parameters["label_dict"]`` and only adds ids for languages it has not
    seen yet. Calling it for "train" and then "eval" therefore gives the same
    language the same id in both list files. Previously every call restarted at 0,
    so eval ids could disagree with train ids, and the eval mapping replaced the
    train mapping.

    Args:
        target_path: Root directory that contains the ``<kind>`` subdirectory.
        kind: Split name, e.g. "train" or "eval".
        train_parameters: Config dict. Its "label_dict" is read, and
            "label_dict" and "class_dim" are overwritten with the merged mapping
            and its size.
        seed: Optional seed for the shuffle. ``None`` (default) uses the global
            ``random`` state, as before.

    Raises:
        OSError: if gt.txt cannot be opened or the output file cannot be written.
        ValueError: if a non-empty line has fewer than two comma-separated fields.
    """
    gt_path = os.path.join(target_path, kind, 'gt.txt')
    img_dir = os.path.join(target_path, kind)
    output_txt = os.path.join(target_path, f"{kind}.txt")

    # Work on a copy so a failure part-way through leaves the config untouched.
    label_dict = dict(train_parameters.get("label_dict") or {})
    next_id = max(label_dict.values(), default=-1) + 1
    lines = []

    # utf-8-sig drops a leading BOM if one is present; otherwise the BOM would be
    # glued onto the first image name. Files without a BOM decode exactly as with utf-8.
    with open(gt_path, 'r', encoding='utf-8-sig') as f:
        for line_no, raw in enumerate(f, start=1):
            line = raw.strip()
            if not line:
                continue
            fields = line.split(',', 2)
            if len(fields) < 2:
                raise ValueError(
                    f"{gt_path}:{line_no}: expected '<image>,<language>,...', got {line!r}"
                )
            img_name, lang = fields[0].strip(), fields[1].strip()

            # Assign the next free id the first time a language is seen.
            if lang not in label_dict:
                label_dict[lang] = next_id
                next_id += 1

            lines.append(f"{os.path.join(img_dir, img_name)}\t{label_dict[lang]}")

    shuffler = random.Random(seed) if seed is not None else random
    shuffler.shuffle(lines)

    with open(output_txt, 'w', encoding='utf-8') as f:
        f.writelines(f"{entry}\n" for entry in lines)

    # Store the merged mapping. The reported class count covers all splits seen so far.
    train_parameters["label_dict"] = label_dict
    train_parameters["class_dim"] = len(label_dict)

    print(f"[完成] 已生成 {kind}.txt, 共 {len(lines)} 条数据, {len(label_dict)} 种类别")


target_path = train_parameters["target_path"]

# Write train.txt first, then eval.txt. Each call also rewrites
# train_parameters["label_dict"] and train_parameters["class_dim"].
get_data_list(target_path, "train", train_parameters)
get_data_list(target_path, "eval", train_parameters)

# Show the language -> class-id mapping currently stored in the config.
print("语言标签映射字典 (label_dict):")
for lang, idx in train_parameters["label_dict"].items():
    print(lang, idx, sep=": ")
