# Étape Suivante : Création d'un Dataset PyTorch et Entraînement Basique d'un Modèle Text-to-Handwritten

Ce notebook utilise les paires préparées dans `01_Prepare_IAM_Dataset.ipynb` pour entraîner un modèle cGAN simple. Objectifs :
1. Charger les paires (texte tokenisé + images) à partir de `iam_lines_pairs.npz`.
2. Créer un Dataset et DataLoader PyTorch.
3. Implémenter un cGAN basique (Générateur + Discriminateur conditionnés par texte).
4. Entraîner le modèle sur GPU si disponible.
5. Générer et visualiser des samples.

**Prérequis** :
- Installez : `pip install torch torchvision torchtext scikit-learn tqdm numpy pandas matplotlib pillow`.
- Utilisez un GPU (Colab recommandé : Runtime > Change runtime type > GPU).
- Le fichier `iam_lines_pairs.npz` doit exister (généré dans le notebook précédent, ou recréé par la première cellule de l'Étape 1).

Exécutez les cellules dans l'ordre, une par une. Durée : quelques minutes pour une démo, plusieurs heures pour le dataset complet.

In [None]:
!pip install torchtext

In [None]:
!pip show torch torchtext

In [None]:
!pip install --upgrade torch torchvision torchaudio
!pip install --upgrade torchtext

In [None]:
!pip uninstall torchtext
!pip install torchtext

In [None]:
!pip install torchtext==0.17.0  

In [None]:
!pip install scikit-learn


In [None]:
# Imports nécessaires
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import random

# Reproducibility: seed every RNG this notebook draws from (Python, NumPy, PyTorch).
# 42 matches the random_state already used for the train/val/test split.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Device (GPU when available)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Paths.
# PAIRS_NPZ is relative to the working directory so that it designates the same
# file the preparation cell writes ("iam_lines_pairs.npz") instead of a
# machine-specific absolute location. Adjust here if the archive lives elsewhere.
PAIRS_NPZ = 'iam_lines_pairs.npz'
TRAIN_CSV = 'iam_train.csv'
VAL_CSV = 'iam_val.csv'
TEST_CSV = 'iam_test.csv'

## Étape 1 : Chargement des paires préparées et création de splits

In [None]:
import os
import xml.etree.ElementTree as ET
import numpy as np
from PIL import Image
import torch
from torchvision import transforms

# Image preprocessing: every line image is brought to one fixed size so that
# samples can be stacked into batches.
transform = transforms.Compose([
    transforms.Resize((128, 1024)),  # unified target size (height, width)
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])  # (x - 0.5) / 0.5: maps [0, 1] to [-1, 1]
])

# Input folders (machine-specific: adjust to the local copy of the IAM dataset)
xml_dir = r"C:\Users\Hp\Desktop\GEN AI\IAM_dataset\xml"
img_dir = r"C:\Users\Hp\Desktop\GEN AI\IAM_dataset\lines"

pairs = []

# sorted(): os.listdir returns entries in arbitrary order, which would make the
# order of `pairs` (and therefore any seeded split of it) differ between runs.
for file in sorted(os.listdir(xml_dir)):
    if not file.endswith(".xml"):
        continue

    xml_path = os.path.join(xml_dir, file)
    try:
        tree = ET.parse(xml_path)
    except ET.ParseError:
        print(f"Erreur de parsing XML pour {xml_path}")
        continue
    root = tree.getroot()

    form_id = file.replace(".xml", "")

    for line in root.findall(".//line"):
        line_id = line.get("id")
        text = line.get("text")
        # Skip lines without a transcription or without an id: the id is needed
        # to locate the image, and a missing one used to abort the whole scan.
        if not text or not line_id:
            continue

        # The image path is derived from the first two dash-separated id fields:
        # <img_dir>/<p0>/<p0>-<p1>/<line_id>.png
        parts = line_id.split('-')
        if len(parts) < 2:
            continue
        img_path = os.path.join(img_dir, parts[0], f"{parts[0]}-{parts[1]}", f"{line_id}.png")

        if not os.path.exists(img_path):
            continue

        try:
            # Context manager: release the file handle as soon as the pixels have
            # been converted, instead of keeping one handle open per image.
            with Image.open(img_path) as raw_image:
                image = raw_image.convert("L")
            image_tensor = transform(image)
        except Exception as e:
            # Best effort: report the unreadable image and keep going.
            print(f"Erreur de traitement pour l'image {img_path}: {e}")
            continue

        pairs.append({
            "id": line_id,
            "text": text,
            "form_id": form_id,
            "image": image_tensor
        })

print(f"Nombre total de paires (texte, image) trouvées : {len(pairs)}")

# Save the pairs as an .npz archive (the list of dicts is stored as pickled objects,
# which is why readers must pass allow_pickle=True).
np.savez("iam_lines_pairs.npz", pairs=pairs)
print("✅ Fichier iam_lines_pairs.npz recréé avec succès !")

In [None]:
# Manual character-level tokenisation
def create_vocab(texts):
    """Build a character-level vocabulary from a collection of strings.

    Args:
        texts: iterable of strings (a generator is accepted).

    Returns:
        A 4-tuple ``(vocab, char_to_idx, idx_to_char, max_text_len)``:
        ``vocab`` is ``['<pad>', '<unk>']`` followed by every distinct character
        in sorted order, the two dicts map characters to indices and back, and
        ``max_text_len`` is the length of the longest string (0 when ``texts``
        is empty).
    """
    texts = list(texts)  # materialise once: the collection is traversed twice below
    all_chars = set(''.join(texts))
    vocab = ['<pad>', '<unk>'] + sorted(all_chars)
    char_to_idx = {char: idx for idx, char in enumerate(vocab)}
    idx_to_char = dict(enumerate(vocab))
    # default=0 keeps an empty corpus from raising ValueError in max()
    max_text_len = max((len(t) for t in texts), default=0)
    return vocab, char_to_idx, idx_to_char, max_text_len

# Load the prepared pairs and collect their transcriptions.
# The context manager closes the archive's file handle once the array has been
# read; a bare np.load(...)[key] leaves it open.
# NOTE(security): allow_pickle=True unpickles arbitrary Python objects, so only
# load archives produced by this project.
with np.load(PAIRS_NPZ, allow_pickle=True) as archive:
    data = archive['pairs']
pairs = list(data)
texts = [pair['text'] for pair in pairs]

# Build the vocabulary
vocab, char_to_idx, idx_to_char, max_text_len = create_vocab(texts)
vocab_size = len(vocab)
print(f'Vocab size: {vocab_size}, Max text len: {max_text_len}')

# Text encoding helper
def encode_text(text):
    """Encode a string as a fixed-length tensor of character indices.

    Reads the module-level ``char_to_idx`` mapping and ``max_text_len``.
    Characters missing from the mapping become ``'<unk>'``; the result is
    right-padded with ``'<pad>'`` and, if the string is longer than
    ``max_text_len``, truncated to it.

    Args:
        text: the string to encode.

    Returns:
        A ``torch.long`` tensor of exactly ``max_text_len`` elements.
    """
    pad_idx = char_to_idx['<pad>']
    unk_idx = char_to_idx['<unk>']
    # Truncate before encoding: without it a longer string produced a longer
    # tensor, which cannot be batched with the others nor fed to layers sized
    # from max_text_len.
    encoded = [char_to_idx.get(c, unk_idx) for c in text[:max_text_len]]
    encoded += [pad_idx] * (max_text_len - len(encoded))
    return torch.tensor(encoded, dtype=torch.long)

## Étape 2 : Tokenisation du texte et création d'un vocabulaire

In [None]:
# All transcriptions from the loaded pairs
texts = [pair['text'] for pair in pairs]

# Character-level vocabulary. Built with create_vocab (defined in the previous
# step) so that there is a single implementation of the vocabulary rules
# instead of two copies that can drift apart.
vocab, char_to_idx, idx_to_char, max_text_len = create_vocab(texts)
vocab_size = len(vocab)
print(f'Vocab size: {vocab_size}, Max text len: {max_text_len}')

# Text encoding helper
def encode_text(text):
    """Encode a string as a fixed-length tensor of character indices.

    Reads the module-level ``char_to_idx`` mapping and ``max_text_len``.
    Characters missing from the mapping become ``'<unk>'``; the result is
    right-padded with ``'<pad>'`` and, if the string is longer than
    ``max_text_len``, truncated to it.

    Args:
        text: the string to encode.

    Returns:
        A ``torch.long`` tensor of exactly ``max_text_len`` elements.
    """
    pad_idx = char_to_idx['<pad>']
    unk_idx = char_to_idx['<unk>']
    # Truncate before encoding: without it a longer string produced a longer
    # tensor, which cannot be batched with the others nor fed to layers sized
    # from max_text_len.
    encoded = [char_to_idx.get(c, unk_idx) for c in text[:max_text_len]]
    encoded += [pad_idx] * (max_text_len - len(encoded))
    return torch.tensor(encoded, dtype=torch.long)

## Étape 3 : Custom Dataset PyTorch

In [None]:
import numpy as np
from sklearn.model_selection import train_test_split

# Load the pairs from PAIRS_NPZ, i.e. the same archive the vocabulary was built
# from: reading a different file here could yield texts that the vocabulary and
# max_text_len do not cover. The context manager closes the archive afterwards.
# NOTE(security): allow_pickle=True unpickles arbitrary Python objects, so only
# load archives produced by this project.
with np.load(PAIRS_NPZ, allow_pickle=True) as archive:
    data = archive['pairs']

# Train / val / test split: 70 % / 15 % / 15 % (30 % held out, then halved)
pairs_train, pairs_temp = train_test_split(data, test_size=0.3, random_state=42)
pairs_val, pairs_test = train_test_split(pairs_temp, test_size=0.5, random_state=42)

print(f"Train: {len(pairs_train)}  |  Val: {len(pairs_val)}  |  Test: {len(pairs_test)}")


In [None]:
class IAMHandwritingDataset(Dataset):
    """Map-style dataset over prepared (text, image) pairs.

    Each element of ``pairs`` is indexed with the keys ``'text'`` and
    ``'image'``; the image is returned as stored, the text is additionally
    encoded with ``encode_text``.
    """

    def __init__(self, pairs):
        # Kept by reference: no copy, no preprocessing at construction time.
        self.pairs = pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        sample = self.pairs[idx]
        raw_text = sample['text']
        return {
            'text': encode_text(raw_text),  # fixed-length index tensor
            'image': sample['image'],       # stored image, passed through untouched
            'raw_text': raw_text,           # original string, handy for display
        }

# Datasets and dataloaders
dataset_train = IAMHandwritingDataset(pairs_train)
dataloader_train = DataLoader(dataset_train, batch_size=16, shuffle=True)

# Always bind the validation/test names. When a split is empty they stay None,
# so later checks such as `if dataloader_val:` evaluate to False instead of
# raising NameError.
dataset_val = dataloader_val = None
dataset_test = dataloader_test = None

# len(...) rather than truthiness: the splits may be NumPy arrays, whose truth
# value is ambiguous.
if len(pairs_val) > 0:
    dataset_val = IAMHandwritingDataset(pairs_val)
    dataloader_val = DataLoader(dataset_val, batch_size=16, shuffle=False)
if len(pairs_test) > 0:
    dataset_test = IAMHandwritingDataset(pairs_test)
    dataloader_test = DataLoader(dataset_test, batch_size=16, shuffle=False)

## Étape 4 : Définition du Modèle cGAN Simple

In [None]:
# Installs/upgrades torch to the pinned version 2.5.1.
# NOTE(review): torchvision and torchaudio are not re-pinned by this command —
# confirm the installed versions are compatible with torch 2.5.1.
!pip install torch==2.5.1 --upgrade


In [None]:
import torch
# Smoke test: create a one-element float tensor on the CPU and display it.
x = torch.ones(1, device='cpu')
print(x)


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim


In [None]:
# Fallback values for running the model cells on their own.
# They apply only when the name is not defined yet: assigning them
# unconditionally overwrote the vocabulary size, text length and device
# computed from the real data earlier in the notebook.
vocab_size = globals().get('vocab_size', 1000)
max_text_len = globals().get('max_text_len', 20)
device = globals().get('device', 'cpu')  # or 'cuda' when a GPU is available


In [None]:
# Reinstall the PyTorch stack from the CUDA 12.1 wheel index: remove the
# installed packages without prompting (-y), empty pip's cache, then install
# torch, torchvision and torchaudio again.
# NOTE(review): modules already imported by the running kernel are presumably
# not replaced until the kernel is restarted — confirm.
!pip uninstall torch torchvision torchaudio -y
!pip cache purge
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121


In [None]:
!pip install torch==2.4.0 torchvision==0.15.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cpu


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

# Smoke test for the optimizer: one forward pass, one backward pass and one
# Adam update on a tiny linear layer.
linear = nn.Linear(in_features=10, out_features=1)
opt = optim.Adam(params=linear.parameters())
x = torch.randn(1, 10)
loss = torch.sum(linear(x))  # reduce to a scalar so backward() needs no argument
loss.backward()
opt.step()
print("Adam fonctionne !")


In [None]:
# Text embedding
class TextEmbedding(nn.Module):
    """Learned embedding lookup for token indices.

    Args:
        vocab_size: number of rows in the embedding table.
        embed_dim: size of each embedding vector.
    """

    def __init__(self, vocab_size, embed_dim):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim)

    def forward(self, text):
        # Index tensor in, same shape with a trailing embed_dim axis out:
        # (batch, seq_len) -> (batch, seq_len, embed_dim)
        embedded = self.embedding(text)
        return embedded

# Simple generator
class Generator(nn.Module):
    """Text-conditioned generator producing one-channel 128 x 1024 images.

    The flattened character embeddings are concatenated with the noise vector,
    projected to a (128, 16, 4) feature map, upsampled x8 in both directions to
    128 x 32, then upsampled x32 along the width only to reach 128 x 1024 —
    the size the training images are resized to. (The network previously
    stopped at 128 x 32, so generated and real images had different sizes.)

    Reads the module-level ``vocab_size`` and ``max_text_len`` at construction.

    Args:
        embed_dim: size of each character embedding.
        noise_dim: length of the noise vector expected by ``forward``.
    """

    def __init__(self, embed_dim=256, noise_dim=100):
        super().__init__()
        self.text_embed = TextEmbedding(vocab_size, embed_dim)

        layers = [
            nn.Linear(max_text_len * embed_dim + noise_dim, 128 * 16 * 4),
            nn.ReLU(),
            nn.Unflatten(1, (128, 16, 4)),
            # kernel 4 / stride 2 / padding 1 doubles height and width
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # -> (64, 32, 8)
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),   # -> (32, 64, 16)
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),   # -> (16, 128, 32)
            nn.ReLU(),
        ]
        # kernel (3, 4) / stride (1, 2) / padding (1, 1) keeps the height and
        # doubles the width: 32 -> 64 -> 128 -> 256 -> 512
        for _ in range(4):
            layers.append(nn.ConvTranspose2d(16, 16, kernel_size=(3, 4), stride=(1, 2), padding=(1, 1)))
            layers.append(nn.ReLU())
        # Last doubling (512 -> 1024) also collapses to a single channel;
        # Tanh bounds the pixels to [-1, 1].
        layers.append(nn.ConvTranspose2d(16, 1, kernel_size=(3, 4), stride=(1, 2), padding=(1, 1)))
        layers.append(nn.Tanh())
        self.model = nn.Sequential(*layers)

    def forward(self, text, noise):
        """Generate images.

        Args:
            text: index tensor of shape (batch, max_text_len).
            noise: tensor of shape (batch, noise_dim).

        Returns:
            Tensor of shape (batch, 1, 128, 1024) with values in [-1, 1].
        """
        # Flatten (batch, seq_len, embed_dim) to (batch, seq_len * embed_dim)
        text_emb = self.text_embed(text).view(text.size(0), -1)
        gen_input = torch.cat([text_emb, noise], dim=1)
        return self.model(gen_input)

# Discriminator
class Discriminator(nn.Module):
    """Text-conditioned discriminator returning one probability per sample.

    Image features are extracted by two strided convolutions; the text is
    summarised by a max over the sequence axis of its embeddings, broadcast
    over the feature map and concatenated channel-wise. The per-location scores
    are averaged into a single value before the sigmoid, so the output has
    shape (batch, 1, 1, 1) whatever the image size. (Previously the output was
    a spatial map, which does not match (batch, 1, 1, 1) labels in BCELoss.)

    Reads the module-level ``vocab_size`` at construction.

    Args:
        embed_dim: size of each character embedding.
    """

    def __init__(self, embed_dim=256):
        super().__init__()
        self.text_embed = TextEmbedding(vocab_size, embed_dim)
        self.img_model = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=4, stride=2),
            nn.LeakyReLU(0.2),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.LeakyReLU(0.2),
        )
        self.joint_model = nn.Sequential(
            nn.Conv2d(64 + embed_dim, 128, kernel_size=3),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 1, kernel_size=3),
            nn.AdaptiveAvgPool2d(1),  # average the score map down to 1 x 1
            nn.Sigmoid()
        )

    def forward(self, image, text):
        """Score image/text pairs.

        Args:
            image: tensor of shape (batch, 1, H, W).
            text: index tensor of shape (batch, seq_len).

        Returns:
            Tensor of shape (batch, 1, 1, 1) with values in (0, 1).
        """
        img_feat = self.img_model(image)
        # (batch, seq_len, embed_dim) -> max over seq_len -> (batch, embed_dim, 1, 1)
        text_emb = self.text_embed(text).max(dim=1)[0].unsqueeze(2).unsqueeze(3)
        # Repeat the text vector at every spatial location of the image features
        text_emb = text_emb.expand(-1, -1, img_feat.size(2), img_feat.size(3))
        joint_input = torch.cat([img_feat, text_emb], dim=1)
        return self.joint_model(joint_input)

# Build both networks on the selected device (generator first, then discriminator)
gen, disc = Generator().to(device), Discriminator().to(device)

# One Adam optimizer per network, with identical hyper-parameters
gen_opt, disc_opt = (
    optim.Adam(net.parameters(), lr=0.0002, betas=(0.5, 0.999))
    for net in (gen, disc)
)

# Binary cross-entropy on probabilities (the discriminator ends with a Sigmoid)
criterion = nn.BCELoss()

## Étape 5 : Boucle d'Entraînement

In [None]:
num_epochs = 5  # raise to 50+ for a real training run
noise_dim = 100

for epoch in range(num_epochs):
    gen.train()
    disc.train()
    total_disc_loss = 0
    total_gen_loss = 0
    num_batches = 0

    for batch in tqdm(dataloader_train):
        images = batch['image'].to(device)  # (batch, 1, H, W)
        # Deliberately not named `texts`: that name holds the list of raw
        # strings used to build the vocabulary and must not be overwritten.
        text_batch = batch['text'].to(device)
        batch_size = images.size(0)

        # --- Discriminator step ---
        # Targets are created with the shape of each prediction (ones_like /
        # zeros_like): BCELoss requires input and target to have the same
        # shape, whatever shape the discriminator outputs.
        disc_opt.zero_grad()
        real_pred = disc(images, text_batch)
        disc_real_loss = criterion(real_pred, torch.ones_like(real_pred))

        noise = torch.randn(batch_size, noise_dim).to(device)
        fake_images = gen(text_batch, noise)
        # detach(): the discriminator update must not back-propagate into the generator
        fake_pred = disc(fake_images.detach(), text_batch)
        disc_fake_loss = criterion(fake_pred, torch.zeros_like(fake_pred))

        disc_loss = disc_real_loss + disc_fake_loss
        disc_loss.backward()
        disc_opt.step()

        # --- Generator step ---
        # The generator is rewarded when its images are classified as real.
        gen_opt.zero_grad()
        fake_pred = disc(fake_images, text_batch)
        gen_loss = criterion(fake_pred, torch.ones_like(fake_pred))
        gen_loss.backward()
        gen_opt.step()

        total_disc_loss += disc_loss.item()
        total_gen_loss += gen_loss.item()
        num_batches += 1

    # max(..., 1): an empty dataloader reports 0.0 instead of dividing by zero
    avg_disc_loss = total_disc_loss / max(num_batches, 1)
    avg_gen_loss = total_gen_loss / max(num_batches, 1)
    print(f'Epoch {epoch+1}/{num_epochs} - Disc Loss: {avg_disc_loss:.4f}, Gen Loss: {avg_gen_loss:.4f}')

    # Checkpoint both networks (file names use the 0-based epoch index)
    torch.save(gen.state_dict(), f'gen_epoch_{epoch}.pth')
    torch.save(disc.state_dict(), f'disc_epoch_{epoch}.pth')

## Étape 6 : Génération et Visualisation

In [None]:
# Inference on one custom sentence
gen.eval()
test_text = "Hello world"  # custom text
test_encoded = encode_text(test_text)[None].to(device)  # [None] adds the batch axis
noise = torch.randn(1, noise_dim).to(device)
with torch.no_grad():
    fake_img = gen(test_encoded, noise)
fake_img = fake_img.cpu().squeeze(0).numpy()  # drop the batch axis -> (1, H, W)

# Undo the [-1, 1] range of the Tanh output / Normalize transform
fake_img = (fake_img + 1) / 2  # from [-1, 1] to [0, 1]
plt.imshow(fake_img[0], cmap='gray')
plt.title(f'Generated: {test_text}')
plt.axis('off')
plt.show()

In [None]:
# Improved training loop: same adversarial updates, plus a validation pass per epoch
num_epochs = 20  # increased for better results
noise_dim = 100

# Look the validation loader up defensively: the name is only bound when a
# validation split exists, and referencing it directly would raise NameError.
val_loader = globals().get('dataloader_val')

for epoch in range(num_epochs):
    gen.train()
    disc.train()
    total_disc_loss = 0
    total_gen_loss = 0
    num_batches = 0

    # Training
    for batch in tqdm(dataloader_train, desc=f'Epoch {epoch+1}/{num_epochs} - Train'):
        images = batch['image'].to(device)
        # Deliberately not named `texts`: that name holds the list of raw
        # strings used to build the vocabulary and must not be overwritten.
        text_batch = batch['text'].to(device)
        batch_size = images.size(0)

        # --- Discriminator step ---
        # Targets take the shape of each prediction: BCELoss requires input and
        # target to have the same shape, whatever the discriminator outputs.
        disc_opt.zero_grad()
        real_pred = disc(images, text_batch)
        disc_real_loss = criterion(real_pred, torch.ones_like(real_pred))

        noise = torch.randn(batch_size, noise_dim).to(device)
        fake_images = gen(text_batch, noise)
        # detach(): the discriminator update must not back-propagate into the generator
        fake_pred = disc(fake_images.detach(), text_batch)
        disc_fake_loss = criterion(fake_pred, torch.zeros_like(fake_pred))

        disc_loss = disc_real_loss + disc_fake_loss
        disc_loss.backward()
        disc_opt.step()

        # --- Generator step ---
        gen_opt.zero_grad()
        fake_pred = disc(fake_images, text_batch)
        gen_loss = criterion(fake_pred, torch.ones_like(fake_pred))
        gen_loss.backward()
        gen_opt.step()

        total_disc_loss += disc_loss.item()
        total_gen_loss += gen_loss.item()
        num_batches += 1

    # max(..., 1): an empty dataloader reports 0.0 instead of dividing by zero
    avg_disc_loss = total_disc_loss / max(num_batches, 1)
    avg_gen_loss = total_gen_loss / max(num_batches, 1)

    # Validation (only when validation data exists)
    if val_loader is not None and len(val_loader) > 0:
        gen.eval()
        disc.eval()  # both networks are put back in train mode at the next epoch start
        total_val_disc_loss = 0
        total_val_gen_loss = 0
        val_num_batches = 0

        with torch.no_grad():
            for batch in val_loader:
                images = batch['image'].to(device)
                text_batch = batch['text'].to(device)
                batch_size = images.size(0)

                real_pred = disc(images, text_batch)
                val_disc_real_loss = criterion(real_pred, torch.ones_like(real_pred))

                noise = torch.randn(batch_size, noise_dim).to(device)
                fake_images = gen(text_batch, noise)
                fake_pred = disc(fake_images, text_batch)
                val_disc_fake_loss = criterion(fake_pred, torch.zeros_like(fake_pred))

                val_disc_loss = val_disc_real_loss + val_disc_fake_loss
                # Same predictions scored against "real" targets give the generator loss
                val_gen_loss = criterion(fake_pred, torch.ones_like(fake_pred))

                total_val_disc_loss += val_disc_loss.item()
                total_val_gen_loss += val_gen_loss.item()
                val_num_batches += 1

        avg_val_disc_loss = total_val_disc_loss / max(val_num_batches, 1)
        avg_val_gen_loss = total_val_gen_loss / max(val_num_batches, 1)
        print(f'Epoch {epoch+1}/{num_epochs} - Train Disc Loss: {avg_disc_loss:.4f}, Train Gen Loss: {avg_gen_loss:.4f}, '
              f'Val Disc Loss: {avg_val_disc_loss:.4f}, Val Gen Loss: {avg_val_gen_loss:.4f}')
    else:
        print(f'Epoch {epoch+1}/{num_epochs} - Train Disc Loss: {avg_disc_loss:.4f}, Train Gen Loss: {avg_gen_loss:.4f}')

    # Checkpoint both networks (file names use the 0-based epoch index)
    torch.save(gen.state_dict(), f'gen_epoch_{epoch}.pth')
    torch.save(disc.state_dict(), f'disc_epoch_{epoch}.pth')

In [None]:
# Evaluation and visualisation
gen.eval()
sample_texts = ["Hello world", "This is a test", "Machine learning", "Handwritten text", "Good day!"]
num_samples = len(sample_texts)  # derived from the list so the two cannot disagree

# Generate every sample exactly once. The images written to disk below are then
# the very ones displayed; drawing fresh noise for the save step produced
# different images from those shown.
generated = []
for text in sample_texts:
    test_encoded = encode_text(text).unsqueeze(0).to(device)
    noise = torch.randn(1, noise_dim).to(device)
    with torch.no_grad():
        fake_img = gen(test_encoded, noise).cpu().squeeze(0).numpy()
    generated.append((fake_img + 1) / 2)  # from [-1, 1] to [0, 1]

# One row per sample
plt.figure(figsize=(15, 3 * num_samples))
for i, (text, fake_img) in enumerate(zip(sample_texts, generated)):
    plt.subplot(num_samples, 1, i+1)
    plt.imshow(fake_img[0], cmap='gray')
    plt.title(f'Generated: {text}')
    plt.axis('off')
plt.tight_layout()
plt.show()

# Save the displayed images (optional)
for text, fake_img in zip(sample_texts, generated):
    # clip guards the uint8 conversion against values marginally outside [0, 1]
    pixels = (np.clip(fake_img[0], 0.0, 1.0) * 255).astype(np.uint8)
    img = Image.fromarray(pixels)
    img.save(f'generated_{text.replace(" ", "_")}.png')