### Prep data

In [10]:
import os
import pandas as pd
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset

# Dataset locations on disk; everything is resolved relative to `root_dir`.
root_dir = '../data/IAM'
# Directory that holds the form images.
images_dir = os.path.join(root_dir, 'forms/formsA-D')
# Metadata index that is parsed line by line further down.
forms_file = os.path.join(root_dir, 'ascii/forms.txt')

# parse 'forms.txt' to extract metadata
def parse_forms_file(forms_file):
    """Parse a whitespace-delimited forms index into a DataFrame.

    Every line that is neither blank nor a comment (first non-blank
    character ``#``) is split on whitespace. Field 0 becomes the image file
    name (with ``.png`` appended), field 1 the writer id, and all remaining
    fields are re-joined with single spaces into ``text``.

    Args:
        forms_file: Path of the index file to read.

    Returns:
        A ``pandas.DataFrame`` with the columns ``image``, ``writer`` and
        ``text`` (all strings). The columns are present even when the file
        contains no data lines, so downstream column lookups keep working.

    Raises:
        OSError: If the file cannot be opened.
        IndexError: If a data line has fewer than two fields.
    """
    columns = ['image', 'writer', 'text']
    records = []
    with open(forms_file, 'r') as f:
        # Stream the file instead of materialising every line up front.
        for raw_line in f:
            line = raw_line.strip()
            # Blank lines used to crash on fields[0]; skip them together
            # with comment lines.
            if not line or line.startswith('#'):
                continue
            fields = line.split()
            img_file, writer_id, text = fields[0], fields[1], ' '.join(fields[2:])
            records.append({'image': img_file + '.png', 'writer': writer_id, 'text': text})
    return pd.DataFrame(records, columns=columns)

# create a filtered dataset for a single writer
def filter_dataset_by_writer(dataframe, writer_id):
    """Return the rows of `dataframe` whose 'writer' column equals `writer_id`.

    The original row index is kept; no rows are copied or renumbered beyond
    what boolean-mask selection does.
    """
    is_requested_writer = dataframe['writer'] == writer_id
    return dataframe[is_requested_writer]

# Load the full metadata table, then keep only the forms written by the
# writer of the first listed form.
data = parse_forms_file(forms_file)
writer_id = data['writer'].iat[0]
filtered_data = filter_dataset_by_writer(dataframe=data, writer_id=writer_id)

### Dataset Class

In [11]:
import torch

class IAMDataset(Dataset):
    """Dataset yielding ``(image, text_indices, writer)`` triples.

    Args:
        dataframe: Table with the columns ``image`` (file name inside
            ``images_dir``), ``writer`` and ``text``.
        images_dir: Directory containing the image files.
        transform: Optional callable applied to the square-cropped PIL image.
        max_text_len: Fixed length every encoded text is truncated/padded to.

    The character vocabulary is built from ``dataframe['text']``. Index
    ``PAD_IDX`` (0) is reserved for padding and for characters that are not
    in the vocabulary; real characters are numbered from 1, so no character
    is indistinguishable from padding. ``vocab_size`` counts the reserved
    index, which makes it the right size for an embedding table.
    """

    # Index reserved for padding and out-of-vocabulary characters.
    PAD_IDX = 0

    def __init__(self, dataframe, images_dir, transform=None, max_text_len=50):
        self.dataframe = dataframe
        self.images_dir = images_dir
        self.transform = transform
        self.max_text_len = max_text_len

        # Create character vocabulary from the dataset
        self.chars = set()
        for text in dataframe['text']:
            self.chars.update(text)
        # Start at 1: index 0 is PAD_IDX and must not also denote a character.
        self.char_to_idx = {
            char: idx for idx, char in enumerate(sorted(self.chars), start=1)
        }
        self.idx_to_char = {idx: char for char, idx in self.char_to_idx.items()}
        # +1 for the reserved padding/unknown index.
        self.vocab_size = len(self.char_to_idx) + 1

    def text_to_tensor(self, text):
        """Encode `text` as a 1-D long tensor of exactly ``max_text_len`` indices.

        Text longer than ``max_text_len`` is truncated; shorter text is
        right-padded with ``PAD_IDX``. Unknown characters map to ``PAD_IDX``.
        """
        indices = [
            self.char_to_idx.get(c, self.PAD_IDX) for c in text[:self.max_text_len]
        ]
        # Pad if necessary
        indices += [self.PAD_IDX] * (self.max_text_len - len(indices))
        return torch.tensor(indices, dtype=torch.long)

    def __getitem__(self, idx):
        """Return ``(image, text_tensor, writer)`` for the row at position `idx`."""
        row = self.dataframe.iloc[idx]
        img_path = os.path.join(self.images_dir, row['image'])
        # Close the file handle deterministically; convert() returns a new,
        # fully loaded grayscale image that outlives the `with` block.
        with Image.open(img_path) as src:
            img = src.convert('L')

        # Crop to the largest centered square so that a later square resize
        # does not stretch the content.
        w, h = img.size
        min_dim = min(w, h)
        img = transforms.CenterCrop(min_dim)(img)

        if self.transform:
            img = self.transform(img)

        text_tensor = self.text_to_tensor(row['text'])
        return img, text_tensor, row['writer']

    def __len__(self):
        """Number of rows in the underlying table."""
        return len(self.dataframe)

In [12]:
from torch.utils.data import DataLoader

# Preprocessing for every cropped form: shrink to 64x64, convert to a tensor,
# then normalize with mean 0.5 / std 0.5.
transform = transforms.Compose([
    transforms.Resize(size=(64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),  # normalize to [-1, 1]
])

# Single-writer dataset and a shuffled loader over it.
dataset = IAMDataset(
    dataframe=filtered_data,
    images_dir=images_dir,
    transform=transform,
)
dataloader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)


### Define Models

In [13]:
import torch.nn as nn
import torch.nn.functional as F

class Generator(nn.Module):
    """Conditional image generator.

    Takes a noise vector, a batch of character-index sequences and a style
    vector, and produces a single-channel 64x64 image with values in
    [-1, 1] (the final activation is ``Tanh``).

    Args:
        noise_dim: Width of the noise vector.
        text_dim: Number of rows in the character embedding table.
        style_dim: Width of the style input vector.
        channels: Base channel count; the first feature map has
            ``channels * 8`` channels.
    """

    def __init__(self, noise_dim, text_dim, style_dim, channels=64):
        super().__init__()

        # Conditioning branches, each producing a 256-wide vector.
        self.text_embedding = nn.Embedding(text_dim, 256)
        self.style_embedding = nn.Linear(style_dim, 256)

        # Projects [noise | text | style] onto a (channels * 8) x 4 x 4 map.
        self.fc = nn.Linear(noise_dim + 256 + 256, 4 * 4 * channels * 8)

        # Three upsampling stages (4 -> 8 -> 16 -> 32), each halving the
        # channel count, followed by the output stage (32 -> 64).
        widths = [channels * 8, channels * 4, channels * 2, channels]
        layers = []
        for in_ch, out_ch in zip(widths[:-1], widths[1:]):
            layers += [
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        layers += [
            nn.ConvTranspose2d(channels, 1, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        ]
        self.conv_blocks = nn.Sequential(*layers)

    def forward(self, noise, text, style):
        """Generate one image per batch element from noise, text and style."""
        # Collapse the per-character embeddings into one vector per sample.
        text_vec = self.text_embedding(text).mean(dim=1)
        style_vec = self.style_embedding(style)

        latent = torch.cat([noise, text_vec, style_vec], dim=1)

        # Dense projection, reshaped into a 4x4 feature map for the convs.
        projected = self.fc(latent)
        feature_map = projected.view(projected.size(0), -1, 4, 4)
        return self.conv_blocks(feature_map)

In [14]:
class Discriminator(nn.Module):
    """Conditional discriminator scoring an image given its text and style.

    The image is reduced by four stride-2 convolutions (64 -> 4 per side per
    the layer settings), flattened, concatenated with a 256-wide text vector
    and a 256-wide style vector, and mapped to a single sigmoid probability.

    Args:
        text_dim: Number of rows in the character embedding table.
        style_dim: Width of the style input vector.
        channels: Base channel count of the first convolution.
    """

    def __init__(self, text_dim, style_dim, channels=64):
        super().__init__()

        self.text_embedding = nn.Embedding(text_dim, 256)
        self.style_embedding = nn.Linear(style_dim, 256)

        # First stage has no batch norm; the next three double the channels
        # while halving the spatial size (32 -> 16 -> 8 -> 4).
        layers = [
            nn.Conv2d(1, channels, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        widths = [channels, channels * 2, channels * 4, channels * 8]
        for in_ch, out_ch in zip(widths[:-1], widths[1:]):
            layers.extend([
                nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        self.conv_blocks = nn.Sequential(*layers)

        # Flattened image features plus the two 256-wide condition vectors.
        self.classifier = nn.Sequential(
            nn.Linear(channels * 8 * 4 * 4 + 256 + 256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img, text, style):
        """Return a (batch, 1) tensor of probabilities in [0, 1]."""
        conv_out = self.conv_blocks(img)
        flat_img = conv_out.view(conv_out.size(0), -1)

        # Average the character embeddings over the sequence dimension.
        text_vec = self.text_embedding(text).mean(dim=1)
        style_vec = self.style_embedding(style)

        joint = torch.cat([flat_img, text_vec, style_vec], dim=1)
        return self.classifier(joint)

### Training

In [15]:
# Training configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_epochs = 100
batch_size = 32
lr = 0.0002
b1 = 0.5
b2 = 0.999
noise_dim = 100
style_dim = 10  # width of the one-hot style vector fed to both networks

# Initialize models
generator = Generator(
    noise_dim=noise_dim,
    text_dim=dataset.vocab_size,
    style_dim=style_dim,
).to(device)
discriminator = Discriminator(
    text_dim=dataset.vocab_size,
    style_dim=style_dim,
).to(device)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

# Loss functions
adversarial_loss = nn.BCELoss()

for epoch in range(num_epochs):
    for i, (real_imgs, texts, style_ids) in enumerate(dataloader):
        # Size of *this* batch (the last one may be smaller). Kept in its own
        # name so the `batch_size` setting above is not overwritten.
        cur_batch_size = real_imgs.size(0)
        real_imgs = real_imgs.to(device)
        texts = texts.to(device)

        # Convert style_ids to one-hot vectors. The ids arrive as strings and
        # are parsed as integers; F.one_hot raises if any id >= style_dim.
        style_emb = F.one_hot(
            torch.tensor([int(s) for s in style_ids]),
            num_classes=style_dim,
        ).float().to(device)

        # Ground truths, allocated directly on the target device
        valid = torch.ones(cur_batch_size, 1, device=device)
        fake = torch.zeros(cur_batch_size, 1, device=device)

        # -----------------
        #  Train Generator
        # -----------------
        optimizer_G.zero_grad()

        # Generate images
        z = torch.randn(cur_batch_size, noise_dim).to(device)
        gen_imgs = generator(z, texts, style_emb)

        # Generator wants the discriminator to label its images as real
        g_loss = adversarial_loss(discriminator(gen_imgs, texts, style_emb), valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------
        # Also clears the discriminator gradients left by g_loss.backward()
        optimizer_D.zero_grad()

        # Real images
        real_loss = adversarial_loss(discriminator(real_imgs, texts, style_emb), valid)

        # Fake images; detach so no gradient flows back into the generator
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach(), texts, style_emb), fake)

        # Total discriminator loss
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        if i % 100 == 0:
            print(
                f"[Epoch {epoch}/{num_epochs}] "
                f"[Batch {i}/{len(dataloader)}] "
                f"[D loss: {d_loss.item():.4f}] "
                f"[G loss: {g_loss.item():.4f}]"
            )

[Epoch 0/100] [Batch 0/2] [D loss: 0.7405] [G loss: 0.6848]
[Epoch 1/100] [Batch 0/2] [D loss: 0.0321] [G loss: 2.8305]
[Epoch 2/100] [Batch 0/2] [D loss: 0.0312] [G loss: 2.8433]
[Epoch 3/100] [Batch 0/2] [D loss: 0.0148] [G loss: 3.6401]
[Epoch 4/100] [Batch 0/2] [D loss: 0.0069] [G loss: 4.6671]
[Epoch 5/100] [Batch 0/2] [D loss: 0.1548] [G loss: 1.3468]
[Epoch 6/100] [Batch 0/2] [D loss: 0.0067] [G loss: 4.8182]
[Epoch 7/100] [Batch 0/2] [D loss: 0.0042] [G loss: 5.5506]
[Epoch 8/100] [Batch 0/2] [D loss: 0.0315] [G loss: 2.8618]
[Epoch 9/100] [Batch 0/2] [D loss: 0.0147] [G loss: 3.6471]
[Epoch 10/100] [Batch 0/2] [D loss: 0.0229] [G loss: 3.1662]
[Epoch 11/100] [Batch 0/2] [D loss: 0.0120] [G loss: 3.8659]
[Epoch 12/100] [Batch 0/2] [D loss: 0.0156] [G loss: 3.5671]
[Epoch 13/100] [Batch 0/2] [D loss: 0.0119] [G loss: 3.8232]
[Epoch 14/100] [Batch 0/2] [D loss: 0.0100] [G loss: 4.0199]
[Epoch 15/100] [Batch 0/2] [D loss: 0.0094] [G loss: 4.0849]
[Epoch 16/100] [Batch 0/2] [D loss

In [18]:
def generate_handwriting(generator, text, style_id, dataset, device,
                         num_styles=10, latent_dim=None):
    """Generate one image for `text` in the style `style_id`.

    Args:
        generator: Module called as ``generator(noise, text, style)``.
        text: String to render; encoded with ``dataset.text_to_tensor``.
        style_id: Integer style index, one-hot encoded over `num_styles`.
        dataset: Object providing ``text_to_tensor(text)``.
        device: Device the inputs are moved to.
        num_styles: Width of the one-hot style vector (default 10, the value
            that was previously hard-coded).
        latent_dim: Width of the noise vector. When ``None`` the
            notebook-level ``noise_dim`` setting is used, as before.

    Returns:
        The result of ``transforms.ToPILImage()`` applied to the generated
        image after rescaling it from [-1, 1] to [0, 1].

    The generator is switched to eval mode only for the duration of the
    call; its previous train/eval mode is restored afterwards, so calling
    this between training runs does not leave it stuck in eval mode.
    """
    if latent_dim is None:
        latent_dim = noise_dim  # notebook-level training setting

    # Convert text to tensor using dataset's vocabulary; add a batch axis
    text_tensor = dataset.text_to_tensor(text).unsqueeze(0).to(device)

    # Create style embedding
    style_emb = F.one_hot(torch.tensor([style_id]), num_classes=num_styles).float().to(device)

    z = torch.randn(1, latent_dim).to(device)

    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            gen_img = generator(z, text_tensor, style_emb)
    finally:
        # Restore whichever mode the caller had the generator in.
        generator.train(was_training)

    # Drop only the batch axis, then convert to a PIL image
    gen_img = gen_img.squeeze(0).cpu()
    gen_img = (gen_img + 1) / 2  # Denormalize from [-1, 1] to [0, 1]
    return transforms.ToPILImage()(gen_img)

# Example usage: render a short phrase in style 1 and open the result.
sample_text = "Hello World"
style_id = 1
generated_image = generate_handwriting(
    generator=generator,
    text=sample_text,
    style_id=style_id,
    dataset=dataset,
    device=device,
)
generated_image.show()