### Prepare Data

In [1]:
import os
import pandas as pd
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset

# paths to the dataset
# root_dir is a relative path, so it is resolved against the current working
# directory of the process running this notebook.
root_dir = 'data/IAM'
# metadata index file that parse_forms_file reads
forms_file = os.path.join(root_dir, 'ascii/forms.txt')
# directory that is joined with each row's 'image' filename when loading images
images_dir = os.path.join(root_dir, 'forms')

# parse 'forms.txt' to extract metadata
def parse_forms_file(forms_file):
    """Parse a whitespace-delimited forms index file into a DataFrame.

    Blank lines and lines whose first non-whitespace character is ``#`` are
    skipped. For every other line, the first field becomes the image filename
    (with ``.png`` appended), the second field becomes the writer id, and all
    remaining fields are re-joined with single spaces into ``text``.

    Args:
        forms_file: Path of the index file to read.

    Returns:
        A pandas DataFrame with the columns ``image``, ``writer`` and
        ``text``. The columns are present even when the file contains no
        data lines, so column lookups on the result never fail.

    Raises:
        IndexError: If a non-blank, non-comment line has fewer than two fields.
    """
    columns = ['image', 'writer', 'text']
    records = []
    with open(forms_file, 'r') as f:
        # Iterate lazily instead of materialising the whole file first.
        for raw_line in f:
            line = raw_line.strip()
            # Strip before testing so indented comments and whitespace-only
            # lines are ignored rather than parsed as data.
            if not line or line.startswith('#'):
                continue
            fields = line.split()
            img_file, writer_id = fields[0], fields[1]
            # NOTE(review): 'text' is simply every field after the writer id;
            # whether that is a transcription or other per-form metadata
            # depends on the file format - confirm against forms.txt.
            records.append({
                'image': img_file + '.png',
                'writer': writer_id,
                'text': ' '.join(fields[2:]),
            })
    # Passing columns explicitly keeps the schema stable for empty inputs.
    return pd.DataFrame(records, columns=columns)

# create a filtered dataset for a single writer
def filter_dataset_by_writer(dataframe, writer_id):
    """Select the rows of ``dataframe`` written by ``writer_id``.

    Args:
        dataframe: Table with a 'writer' column.
        writer_id: Value compared for equality against the 'writer' column.

    Returns:
        The subset of ``dataframe`` whose 'writer' equals ``writer_id``;
        the original index labels are kept and the input is not modified.
    """
    same_writer = dataframe['writer'] == writer_id
    return dataframe[same_writer]

# parse and filter dataset
data = parse_forms_file(forms_file)
# NOTE(review): writer ids are kept as the exact strings found in the file and
# compared with ==, so '0001' only matches if ids are written with that
# padding - confirm against forms.txt. A non-matching id leaves filtered_data
# empty.
filtered_data = filter_dataset_by_writer(data, writer_id='0001')

### Dataset Class

In [None]:
class IAMDataset(Dataset):
    """Dataset yielding ``(image, text, writer)`` triples from a metadata table.

    Each item is built from one row of ``dataframe``: the file named by the
    row's 'image' value is loaded from ``images_dir``, converted to PIL mode
    'L', and optionally passed through ``transform``.
    """

    def __init__(self, dataframe, images_dir, transform=None):
        """Store the metadata table, image directory and optional transform.

        Args:
            dataframe: Table with 'image', 'text' and 'writer' columns; rows
                are addressed by position, so any index is accepted.
            images_dir: Directory containing the files named in 'image'.
            transform: Optional callable applied to the loaded PIL image.
        """
        self.dataframe = dataframe
        self.images_dir = images_dir
        self.transform = transform

    def __len__(self):
        """Return the number of rows (samples) in the metadata table."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Load the sample at position ``idx``.

        Args:
            idx: Zero-based row position in the metadata table.

        Returns:
            Tuple ``(img, label, style_id)`` where ``img`` is the (optionally
            transformed) mode-'L' image, ``label`` is the row's 'text' value
            and ``style_id`` is the row's 'writer' value.
        """
        row = self.dataframe.iloc[idx]
        img_path = os.path.join(self.images_dir, row['image'])
        # Use a context manager so the underlying file is closed
        # deterministically; convert() returns a separate in-memory image
        # that stays valid after the source is closed.
        with Image.open(img_path) as source:
            img = source.convert('L')
        # Compare against None explicitly so a callable that happens to be
        # falsy (e.g. an empty container of transforms) is still applied.
        if self.transform is not None:
            img = self.transform(img)
        label = row['text']
        # The raw writer id is returned; turning it into a style embedding
        # (one-hot or latent vector) is left to the consumer.
        style_id = row['writer']
        return img, label, style_id

In [None]:
from torch.utils.data import DataLoader

# Preprocessing applied to every image the dataset loads.
transform = transforms.Compose([
    transforms.Resize((64, 64)),  # fixed 64x64 output, same size the models are built with
    transforms.ToTensor(),  # PIL image -> tensor
    transforms.Normalize([0.5], [0.5])  # normalize to [-1, 1]
])

# Only the rows selected in filtered_data are exposed to training.
# NOTE(review): each item also carries the raw 'text' and 'writer' values from
# the dataframe; they are not tensors and still need to be embedded before
# being fed to the models - confirm where that happens.
dataset = IAMDataset(filtered_data, images_dir, transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)


### Define Models

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Generator(nn.Module):
    """Fully connected conditional generator.

    Maps the concatenation of a noise vector, a text embedding and a style
    embedding to a single-channel ``img_size`` x ``img_size`` image whose
    values lie in [-1, 1] (final Tanh).
    """

    def __init__(self, noise_dim, text_dim, style_dim, img_size):
        """Build the MLP.

        Args:
            noise_dim: Number of features in the noise input.
            text_dim: Number of features in the text embedding.
            style_dim: Number of features in the style embedding.
            img_size: Side length of the square output image.
        """
        super().__init__()
        self.img_size = img_size

        # Hidden widths widen from the conditioned input to the pixel output.
        widths = (noise_dim + text_dim + style_dim, 128, 256, 512)
        stack = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())
        stack.append(nn.Linear(widths[-1], img_size * img_size))
        stack.append(nn.Tanh())  # squash pixels into [-1, 1]
        self.fc = nn.Sequential(*stack)

    def forward(self, noise, text_emb, style_emb):
        """Generate images from noise and conditioning vectors.

        All three inputs are concatenated along dim 1, so they must share the
        same batch size.

        Returns:
            Tensor of shape (batch, 1, img_size, img_size).
        """
        conditioned = torch.cat((noise, text_emb, style_emb), dim=1)
        pixels = self.fc(conditioned)
        side = self.img_size
        return pixels.view(-1, 1, side, side)

In [None]:
class Discriminator(nn.Module):
    """Fully connected conditional discriminator.

    Scores a flattened single-channel image together with its text and style
    embeddings, producing one value in [0, 1] per sample (final Sigmoid).
    """

    def __init__(self, text_dim, style_dim, img_size):
        """Build the MLP.

        Args:
            text_dim: Number of features in the text embedding.
            style_dim: Number of features in the style embedding.
            img_size: Side length of the square single-channel input image.
        """
        super().__init__()
        self.img_size = img_size

        # One channel of img_size x img_size pixels plus both conditioning vectors.
        in_features = 1 * img_size * img_size + text_dim + style_dim
        stack = [
            nn.Linear(in_features, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # output probability
        ]
        self.fc = nn.Sequential(*stack)

    def forward(self, img, text_emb, style_emb):
        """Score a batch of images given their conditioning vectors.

        The image is flattened to (batch, -1) and concatenated with both
        embeddings along dim 1.

        Returns:
            Tensor of shape (batch, 1).
        """
        batch = img.size(0)
        flat_img = img.view(batch, -1)
        joined = torch.cat((flat_img, text_emb, style_emb), dim=1)
        return self.fc(joined)

### Training

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# init models
# Both networks are built with the same conditioning sizes (text_dim=50,
# style_dim=10) and the same 64x64 image size, then moved to the chosen device.
generator = Generator(noise_dim=100, text_dim=50, style_dim=10, img_size=64).to(device)
discriminator = Discriminator(text_dim=50, style_dim=10, img_size=64).to(device)