In [None]:
import tensorflow as tf
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, TensorBoard
from mltu.preprocessors import ImageReader
from mltu.transformers import ImageResizer, LabelIndexer, LabelPadding
from mltu.augmentors import RandomBrightness, RandomRotate, RandomErodeDilate, RandomSharpen
from mltu.annotations.images import CVImage
from mltu.tensorflow.dataProvider import DataProvider
from mltu.tensorflow.losses import CTCloss
from mltu.tensorflow.callbacks import Model2onnx, TrainLogger
from mltu.tensorflow.metrics import CWERMetric

import mltu
import os
import tarfile
from tqdm import tqdm
from urllib.request import urlopen
from io import BytesIO
from zipfile import ZipFile
from datetime import datetime

from mltu.configs import BaseModelConfigs

from keras import layers
from keras.models import Model
from mltu.tensorflow.model_utils import residual_block

In [None]:
# Training configuration: every tunable hyperparameter is gathered in one place.
class ModelConfigs(BaseModelConfigs):
    """Hyperparameters and output location for the handwriting-recognition model.

    ``vocab`` and ``max_text_length`` are initialised to empty placeholders here
    and are expected to be filled in after the dataset has been scanned.
    """

    def __init__(self):
        super().__init__()

        # Directory (relative to the working directory) that receives run artifacts.
        output_root = "kaggle/working/Models/03_handwriting_recognition"
        self.model_path = os.path.join(output_root, "improved_model")

        # Placeholder; populated from the labels found in the dataset.
        self.vocab = ""

        # Target image size used when resizing inputs.
        self.height = 32
        self.width = 128

        # Placeholder; populated with the longest label length in the dataset.
        self.max_text_length = 0

        self.batch_size = 16
        self.learning_rate = 5e-4
        # NOTE(review): a single epoch only; raise this for a real training run.
        self.train_epochs = 1
        self.train_workers = 4

In [None]:
# Custom normalization layer
class NormalizeLayer(layers.Layer):
    """Keras layer that applies ``tf.image.per_image_standardization`` to its input."""

    def call(self, inputs):
        """Return ``inputs`` after per-image standardization."""
        standardized = tf.image.per_image_standardization(inputs)
        return standardized

# Improved model
def train_model(input_dim, output_dim, activation="relu", dropout=0.3):
    """Build the residual-CNN + bidirectional-LSTM recognition network.

    Args:
        input_dim: Shape of one input sample, passed to ``layers.Input``.
        output_dim: Number of vocabulary symbols; the softmax head has
            ``output_dim + 1`` units (presumably the extra unit is the CTC
            blank/padding class -- confirm against the loss in use).
        activation: Activation name forwarded to every residual block.
        dropout: Dropout rate used in the residual blocks and after the LSTM.

    Returns:
        An uncompiled ``Model`` mapping the input image to per-step class
        probabilities.
    """
    inputs = layers.Input(shape=input_dim, name="input")
    features = NormalizeLayer()(inputs)

    # One entry per residual block, in creation order: (filters, strides, skip_conv).
    block_specs = [
        (16, 1, True),
        (32, 2, True),
        (32, 1, False),
        (64, 2, True),
        (64, 1, False),
    ]
    for filters, strides, skip_conv in block_specs:
        features = residual_block(
            features,
            filters,
            activation=activation,
            skip_conv=skip_conv,
            strides=strides,
            dropout=dropout,
        )

    # Merge the two axes preceding the channel axis into one sequence axis.
    rows, cols, channels = features.shape[-3], features.shape[-2], features.shape[-1]
    sequence = layers.Reshape((rows * cols, channels))(features)

    # 256 units per direction.
    recurrent = layers.Bidirectional(layers.LSTM(256, return_sequences=True))(sequence)
    recurrent = layers.Dropout(dropout)(recurrent)

    output = layers.Dense(output_dim + 1, activation="softmax", name="output")(recurrent)
    return Model(inputs=inputs, outputs=output)

# Download and extract a zip archive, reporting any failure before re-raising
def download_and_unzip(url, extract_to="Datasets"):
    """Download the zip archive at ``url`` and extract it into ``extract_to``.

    The whole response body is read into memory before extraction. Both the
    HTTP response and the archive are opened with context managers so their
    handles are released even when reading or extraction fails.

    Args:
        url: Location of the zip archive; anything ``urlopen`` accepts.
        extract_to: Destination directory for the extracted members.

    Raises:
        Exception: Any error raised while downloading or extracting is printed
            and then re-raised unchanged.
    """
    try:
        with urlopen(url) as http_response:
            payload = http_response.read()
        with ZipFile(BytesIO(payload)) as archive:
            archive.extractall(path=extract_to)
    except Exception as e:
        print(f"Erreur lors du téléchargement : {e}")
        raise

# Data preparation: download and unpack the IAM words dataset on first run only
dataset_path = os.path.join("Datasets", "IAM_Words")


if not os.path.exists(dataset_path):
    download_and_unzip("https://git.io/J0fjL", extract_to="Datasets")
    # The context manager closes the tar handle even if extraction fails.
    with tarfile.open(os.path.join(dataset_path, "words.tgz")) as words_archive:
        words_archive.extractall(os.path.join(dataset_path, "words"))

In [None]:
dataset, vocab, max_len = [], set(), 0

# Read the annotation index in one go so the file handle is closed before the loop.
with open(os.path.join(dataset_path, "words.txt"), "r") as words_file:
    words = words_file.readlines()

for line in tqdm(words):
    # Header/comment lines carry no sample.
    if line.startswith("#"):
        continue

    line_split = line.split(" ")
    # Blank or truncated lines have no status field to inspect.
    if len(line_split) < 2:
        continue
    # Only the second field flags a bad segmentation. Testing the whole line
    # for "err" would also discard valid words that merely contain those letters.
    if line_split[1] == "err":
        continue

    # Images are stored as words/<first 3 chars>/<first two dash-parts>/<id>.png
    folder1, folder2, file_name = line_split[0][:3], "-".join(line_split[0].split("-")[:2]), f"{line_split[0]}.png"
    label = line_split[-1].rstrip("\n")

    rel_path = os.path.join(dataset_path, "words", folder1, folder2, file_name)
    if not os.path.exists(rel_path):
        print(f"Fichier non trouvé : {rel_path}")
        continue

    dataset.append([rel_path, label])
    vocab.update(list(label))
    max_len = max(max_len, len(label))

# Update the configuration with what was learned from the dataset
configs = ModelConfigs()
# Sorted so the character-to-index mapping is identical from run to run
# (set iteration order is not stable across interpreter sessions).
configs.vocab = "".join(sorted(vocab))
configs.max_text_length = max_len
configs.save()

# Load the data with the preprocessing transformations
data_provider = DataProvider(
    dataset=dataset,
    skip_validation=True,
    batch_size=configs.batch_size,
    data_preprocessors=[ImageReader(CVImage)],
    transformers=[
        ImageResizer(configs.width, configs.height, keep_aspect_ratio=False),
        LabelIndexer(configs.vocab),
        # Labels are padded up to the longest label using index len(vocab).
        LabelPadding(max_word_length=configs.max_text_length, padding_value=len(configs.vocab)),
    ],
)
# 90% of the samples go to training, the remainder to validation.
train_data_provider, val_data_provider = data_provider.split(split=0.9)


In [None]:
# Data augmentation, attached to the training provider only
train_data_provider.augmentors = [
    RandomBrightness(),
    RandomErodeDilate(),
    RandomSharpen(),
    RandomRotate(angle=15),
]

# Build and compile the model
vocab_size = len(configs.vocab)
model = train_model(input_dim=(configs.height, configs.width, 3), output_dim=vocab_size)
adam = tf.keras.optimizers.Adam(learning_rate=configs.learning_rate)
model.compile(
    optimizer=adam,
    loss=CTCloss(),
    # Same padding index the label pipeline uses (len(vocab)).
    metrics=[CWERMetric(padding_token=vocab_size)],
)
model.summary(line_length=120)

# Callbacks: the metric-driven ones all watch validation CER (lower is better)
model_file = f"{configs.model_path}/model.keras"
cer_watch = dict(monitor="val_CER", mode="min", verbose=1)

earlystopper = EarlyStopping(patience=10, **cer_watch)
checkpoint = ModelCheckpoint(filepath=model_file, save_best_only=True, **cer_watch)
trainLogger = TrainLogger(configs.model_path)
tb_callback = TensorBoard(log_dir=f"{configs.model_path}/logs", update_freq=1)
reduceLROnPlat = ReduceLROnPlateau(factor=0.5, patience=5, **cer_watch)
model2onnx = Model2onnx(model_file)

# Train the model
training_callbacks = [earlystopper, checkpoint, trainLogger, reduceLROnPlat, tb_callback, model2onnx]
model.fit(
    train_data_provider,
    validation_data=val_data_provider,
    epochs=configs.train_epochs,
    callbacks=training_callbacks,
)

# Persist the train/validation splits next to the model artifacts
for provider, csv_name in ((train_data_provider, "train.csv"), (val_data_provider, "val.csv")):
    provider.to_csv(os.path.join(configs.model_path, csv_name))
