# Projet Text to Handwriting – Préparation du Dataset IAM et Implémentation d'un cGAN
Ce notebook combine la préparation des données du dataset IAM et l'implémentation d'un Conditional GAN (cGAN) pour générer des images manuscrites à partir de texte.

## 1. Préparation des données

In [4]:
# Importer les bibliothèques nécessaires
import os
import pandas as pd
import cv2
import numpy as np
from sklearn.model_selection import train_test_split
from PIL import Image
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import string

# Définir le chemin du dossier IAM
data_dir = "IAM_dataset"
lines_file = os.path.join(data_dir, "ascii", "lines.txt")
lines_img_dir = os.path.join(data_dir, "lines")  # Dossier des images

# Vérifier si les dossiers existent
if not os.path.exists(data_dir):
    print("❌ Erreur : Le dossier IAM_dataset n'existe pas.")
elif not os.path.exists(lines_img_dir):
    print("❌ Erreur : Le dossier lines n'existe pas dans IAM_dataset.")
else:
    print(f"✅ Dossier trouvé : {data_dir}")
    print(f"📁 Dossier des images : {lines_img_dir}")

# Lecture et parsing du fichier lines.txt
lines_data = []
with open(lines_file, 'r', encoding='utf-8') as f:
    for line in f:
        if line.startswith('#') or not line.strip():
            continue
        parts = line.strip().split()
        if len(parts) >= 9:
            img_id = parts[0]
            status = parts[1]
            if status == 'ok':
                text = ' '.join(parts[8:]).replace('|', ' ')
                img_base = img_id.split('-')[0]
                sub_dir = '-'.join(img_id.split('-')[:2])
                img_name = f"{img_id}.png"
                img_path = os.path.join(lines_img_dir, img_base, sub_dir, img_name)
                if os.path.exists(img_path):
                    lines_data.append((img_path, text))
                else:
                    print(f"⚠️ Image non trouvée : {img_path}")

# Créer un DataFrame
df_lines = pd.DataFrame(lines_data, columns=['image_path', 'text'])

# Diviser en ensembles d'entraînement, validation et test
train_df, temp_df = train_test_split(df_lines, test_size=0.3, random_state=42)
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)

# Encodage du texte
vocab = string.ascii_letters + string.digits + string.punctuation + " "
vocab_size = len(vocab)
char_to_idx = {char: idx for idx, char in enumerate(vocab)}

def text_to_one_hot(text, max_len=50):
    one_hot = torch.zeros(max_len, vocab_size)
    for i, char in enumerate(text[:max_len]):
        if char in char_to_idx:
            one_hot[i, char_to_idx[char]] = 1
    return one_hot

# Prétraitement des images
def preprocess_image(image):
    img = np.array(image)
    _, img = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return Image.fromarray(img)

# Dataset personnalisé
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((64, 256)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

class HandwrittenDataset(Dataset):
    def __init__(self, dataframe, transform=None, max_text_len=50):
        self.dataframe = dataframe
        self.transform = transform
        self.max_text_len = max_text_len

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        img_path = self.dataframe.iloc[idx]['image_path']
        text = self.dataframe.iloc[idx]['text']
        image = Image.open(img_path).convert('L')
        image = preprocess_image(image)
        if self.transform:
            image = self.transform(image)
        text_encoded = text_to_one_hot(text, self.max_text_len)
        return image, text_encoded, text

# Créer les DataLoaders
train_dataset = HandwrittenDataset(train_df, transform=transform)
val_dataset = HandwrittenDataset(val_df, transform=transform)
test_dataset = HandwrittenDataset(test_df, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

print(f"Nombre de batches dans train_loader : {len(train_loader)}")
print(f"Nombre de batches dans val_loader : {len(val_loader)}")
print(f"Nombre de batches dans test_loader : {len(test_loader)}")
for batch_imgs, batch_texts_encoded, batch_texts in train_loader:
    print(f"Batch images shape : {batch_imgs.shape}")
    print(f"Batch texts encoded shape : {batch_texts_encoded.shape}")
    print(f"Exemple de texte : {batch_texts[0]}")
    break

✅ Dossier trouvé : IAM_dataset
📁 Dossier des images : IAM_dataset\lines
Nombre de batches dans train_loader : 249
Nombre de batches dans val_loader : 54
Nombre de batches dans test_loader : 54
Batch images shape : torch.Size([32, 1, 64, 256])
Batch texts encoded shape : torch.Size([32, 50, 95])
Exemple de texte : she was a lady who , like her Uncle Charles ,


## 2. Implémentation du cGAN

In [9]:
import torch.nn as nn
import torch.optim as optim
from torchvision.utils import save_image
import numpy as np
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr

# Paramètres
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
noise_dim = 100
text_dim = 50 * 95  # 50 caractères max, vocabulaire de 95
img_channels = 1
img_size = 64
epochs = 200
batch_size = 32
lr = 0.0002
beta1 = 0.5
lambda_gp = 10  # Gradient penalty coefficient for WGAN-GP

# Générateur
class Generator(nn.Module):
    def __init__(self, noise_dim, text_dim, img_channels=1, hidden_dim=256):
        super(Generator, self).__init__()
        self.text_embedding = nn.LSTM(input_size=text_dim, hidden_size=hidden_dim, batch_first=True)
        self.main = nn.Sequential(
            nn.ConvTranspose2d(noise_dim + hidden_dim, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, img_channels, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, noise, text):
        text = text.view(-1, 50, 95)
        text_embed, _ = self.text_embedding(text)
        text_embed = text_embed[:, -1, :]
        text_embed = text_embed.unsqueeze(2).unsqueeze(3)
        x = torch.cat([noise, text_embed.expand(-1, -1, 4, 4)], dim=1)
        return self.main(x)

# Discriminateur
class Discriminator(nn.Module):
    def __init__(self, img_channels=1, text_dim=95, feature_maps=64):
        super(Discriminator, self).__init__()
        self.leakyrelu = nn.LeakyReLU(0.2, inplace=True)
        self.text_fc = nn.Linear(text_dim, feature_maps * 4)

        self.conv = nn.Sequential(
            nn.Conv2d(img_channels, feature_maps, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(feature_maps, feature_maps * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(feature_maps * 2, feature_maps * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 4),
            nn.LeakyReLU(0.2, inplace=True),
        )

        self.output = nn.Conv2d(feature_maps * 8, 1, 4, 1, 0, bias=False)

    def forward(self, img, text):
        # ✅ Correction du texte
        if text.dim() == 3:
            text_embed = text.mean(dim=1)
        elif text.dim() == 2:
            text_embed = text
        else:
            raise ValueError(f"Forme inattendue pour le texte : {text.shape}")

        # Projeter le texte
        text_feat = self.leakyrelu(self.text_fc(text_embed))

        # Extraire les features image
        img_feat = self.conv(img)

        # Étendre le texte pour concaténer
        text_feat = text_feat.unsqueeze(2).unsqueeze(3)
        text_feat = text_feat.expand(-1, -1, img_feat.size(2), img_feat.size(3))

        combined = torch.cat((img_feat, text_feat), dim=1)
        validity = self.output(combined)
        return validity.view(-1)


# Fonction de perte WGAN-GP
def gradient_penalty(discriminator, real_img, fake_img, text, device):
    batch_size = real_img.size(0)
    alpha = torch.rand(batch_size, 1, 1, 1).to(device)
    interpolates = alpha * real_img + (1 - alpha) * fake_img
    interpolates = interpolates.requires_grad_(True)
    d_interpolates = discriminator(interpolates, text)
    gradients = torch.autograd.grad(outputs=d_interpolates, inputs=interpolates,
                                  grad_outputs=torch.ones_like(d_interpolates),
                                  create_graph=True, retain_graph=True)[0]
    gradients = gradients.view(batch_size, -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

# Fonction de calcul des métriques
def calculate_metrics(real_img, fake_img, device):
    real_img = real_img.cpu().numpy().transpose(0, 2, 3, 1)
    fake_img = fake_img.cpu().numpy().transpose(0, 2, 3, 1)
    ssim_score = ssim(real_img[0, ..., 0], fake_img[0, ..., 0], data_range=1.0)
    psnr_score = psnr(real_img[0, ..., 0], fake_img[0, ..., 0], data_range=1.0)
    return 0, ssim_score, psnr_score  # FID placeholder

## 3. Boucle d'entraînement

In [10]:
import os
import torch
from torchvision.utils import save_image

os.makedirs("generated", exist_ok=True)

for epoch in range(epochs):
    for i, (real_imgs, text_encoded, _) in enumerate(train_loader):
        batch_size = real_imgs.size(0)
        real_imgs = real_imgs.to(device)
        text_encoded = text_encoded.to(device)

        # 🔧 Ajuster la forme du texte (si besoin)
        if text_encoded.dim() > 2:
            # Parfois, le DataLoader retourne (batch, seq_len, emb_dim)
            # On prend la moyenne sur la séquence pour avoir (batch, emb_dim)
            text_encoded = text_encoded.mean(dim=1)

        # 🔧 S'assurer que noise est de la bonne forme
        noise = torch.randn(batch_size, noise_dim).to(device)

        # ============================================================
        # 🔹 1. Entraîner le Discriminateur
        # ============================================================
        optim_d.zero_grad()
        fake_imgs = generator(noise, text_encoded)

        real_validity = discriminator(real_imgs, text_encoded)
        fake_validity = discriminator(fake_imgs.detach(), text_encoded)

        gp = gradient_penalty(discriminator, real_imgs, fake_imgs.detach(), text_encoded, device)
        d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gp
        d_loss.backward()
        optim_d.step()

        # ============================================================
        # 🔹 2. Entraîner le Générateur
        # ============================================================
        optim_g.zero_grad()
        fake_validity = discriminator(fake_imgs, text_encoded)
        g_loss = -torch.mean(fake_validity)
        g_loss.backward()
        optim_g.step()

        # ============================================================
        # 🔹 3. Affichage et Sauvegarde intermédiaire
        # ============================================================
        if i % 100 == 0:
            with torch.no_grad():
                fid, ssim_score, psnr_score = calculate_metrics(real_imgs, fake_imgs, device)
                print(f"[Epoch {epoch+1}/{epochs}] [Batch {i}/{len(train_loader)}] "
                      f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}] "
                      f"[SSIM: {ssim_score:.4f}] [PSNR: {psnr_score:.4f}]")

                # Sauvegarder un échantillon d’images
                save_image(fake_imgs.data[:25],
                           f"generated/epoch_{epoch+1}_batch_{i}.png",
                           nrow=5, normalize=True)

    # ============================================================
    # 🔹 4. Validation à la fin de chaque époque
    # ============================================================
    generator.eval()
    val_d_loss = 0
    val_g_loss = 0

    with torch.no_grad():
        for real_imgs, text_encoded, _ in val_loader:
            real_imgs = real_imgs.to(device)
            text_encoded = text_encoded.to(device)

            if text_encoded.dim() > 2:
                text_encoded = text_encoded.mean(dim=1)

            batch_size = real_imgs.size(0)
            noise = torch.randn(batch_size, noise_dim).to(device)
            fake_imgs = generator(noise, text_encoded)

            real_validity = discriminator(real_imgs, text_encoded)
            fake_validity = discriminator(fake_imgs, text_encoded)
            gp = gradient_penalty(discriminator, real_imgs, fake_imgs, text_encoded, device)

            val_d_loss += (-torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gp).item()
            val_g_loss += (-torch.mean(fake_validity)).item()

        val_d_loss /= len(val_loader)
        val_g_loss /= len(val_loader)

        print(f"✅ Validation [Epoch {epoch+1}/{epochs}] "
              f"[D loss: {val_d_loss:.4f}] [G loss: {val_g_loss:.4f}]")

    generator.train()


RuntimeError: shape '[-1, 50, 95]' is invalid for input of size 3040

## 4. Évaluation finale sur l'ensemble de test

In [None]:
# Évaluation finale
generator.eval()
test_fid, test_ssim, test_psnr = 0, 0, 0
test_batches = 0
with torch.no_grad():
    for real_imgs, text_encoded, _ in test_loader:
        real_imgs = real_imgs.to(device)
        text_encoded = text_encoded.to(device)
        batch_size = real_imgs.size(0)
        noise = torch.randn(batch_size, noise_dim, 1, 1).to(device)
        fake_imgs = generator(noise, text_encoded)
        fid_batch, ssim_batch, psnr_batch = calculate_metrics(real_imgs, fake_imgs, device)
        test_fid += fid_batch
        test_ssim += ssim_batch
        test_psnr += psnr_batch
        test_batches += 1
test_fid /= test_batches
test_ssim /= test_batches
test_psnr /= test_batches
print(f"Test Metrics - FID: {test_fid:.4f}, SSIM: {test_ssim:.4f}, PSNR: {test_psnr:.4f}")