In [28]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image
from torchvision import transforms
from PIL import Image

In [29]:
class Generator(nn.Module):
    """Fully connected conditional generator.

    A noise vector, a text embedding and a style embedding are concatenated
    and pushed through an MLP whose final ``Tanh`` keeps every output value
    in ``[-1, 1]``. The flat output is reshaped into single-channel square
    images.

    Args:
        noise_dim: Width of the noise vector.
        text_dim: Width of the text embedding.
        style_dim: Width of the style embedding.
        img_size: Side length of the generated square image.
    """

    def __init__(self, noise_dim, text_dim, style_dim, img_size):
        super().__init__()
        self.img_size = img_size

        # Hidden layers widen progressively; each Linear is followed by a ReLU.
        width_in = noise_dim + text_dim + style_dim
        stack = []
        for width_out in (128, 256, 512):
            stack.append(nn.Linear(width_in, width_out))
            stack.append(nn.ReLU())
            width_in = width_out

        # Projection to one value per pixel, squashed to [-1, 1].
        stack.append(nn.Linear(width_in, img_size * img_size))
        stack.append(nn.Tanh())

        self.fc = nn.Sequential(*stack)

    def forward(self, noise, text_emb, style_emb):
        """Generate a batch of images.

        Args:
            noise: Tensor of shape ``(batch, noise_dim)``.
            text_emb: Tensor of shape ``(batch, text_dim)``.
            style_emb: Tensor of shape ``(batch, style_dim)``.

        Returns:
            Tensor of shape ``(batch, 1, img_size, img_size)``.
        """
        conditioned = torch.cat((noise, text_emb, style_emb), dim=1)
        pixels = self.fc(conditioned)
        side = self.img_size
        return pixels.view(-1, 1, side, side)

In [30]:
class Discriminator(nn.Module):
    """Fully connected conditional discriminator.

    The image is flattened, concatenated with the text and style embeddings,
    and scored by an MLP ending in a ``Sigmoid``, so each output is a value
    in ``[0, 1]``.

    Args:
        text_dim: Width of the text embedding.
        style_dim: Width of the style embedding.
        img_size: Side length of the (single-channel, square) input image.
    """

    def __init__(self, text_dim, style_dim, img_size):
        super(Discriminator, self).__init__()
        self.fc = nn.Sequential(
            # Input width: one channel of img_size x img_size pixels plus both embeddings.
            nn.Linear(1 * img_size * img_size + text_dim + style_dim, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()  # Output probability
        )
        self.img_size = img_size

    def forward(self, img, text_emb, style_emb):
        """Score a batch of images given their conditioning embeddings.

        Args:
            img: Tensor whose first dimension is the batch and whose remaining
                dimensions hold ``img_size * img_size`` values per sample.
            text_emb: Tensor of shape ``(batch, text_dim)``.
            style_emb: Tensor of shape ``(batch, style_dim)``.

        Returns:
            Tensor of shape ``(batch, 1)`` with values in ``[0, 1]``.
        """
        # reshape (not view) so non-contiguous inputs, e.g. transposed or
        # permuted image batches, are flattened instead of raising.
        img_flat = img.reshape(img.size(0), -1)
        x = torch.cat([img_flat, text_emb, style_emb], dim=1)
        return self.fc(x)


In [31]:
class HandwritingDataset(Dataset):
    """Dataset pairing image files with a per-sample label and style entry.

    Args:
        img_paths: Sequence of image file paths; its length is the dataset size.
        labels: Indexable collection; ``labels[idx]`` is returned for sample ``idx``.
        styles: Indexable collection; ``styles[idx]`` is returned for sample ``idx``.
        transform: Optional callable applied to the opened PIL image.
    """

    def __init__(self, img_paths, labels, styles, transform=None):
        self.img_paths = img_paths
        self.labels = labels
        self.styles = styles
        self.transform = transform

    def __len__(self):
        """Return the number of image paths."""
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            Tuple ``(img, label, style)`` where ``img`` is the transformed
            image if a transform was given, otherwise the loaded PIL image.
        """
        # PIL opens files lazily. Use a context manager so the file handle is
        # always released, and force the pixel data into memory first so the
        # image stays usable after the handle is closed.
        with Image.open(self.img_paths[idx]) as opened:
            opened.load()
            img = opened
            if self.transform:
                img = self.transform(opened)
        label = self.labels[idx]
        style = self.styles[idx]
        return img, label, style

In [35]:
# Hyperparameters
img_size = 64      # side length of the square images
noise_dim = 100    # width of the generator's noise input
text_dim = 50      # width of the text embedding
style_dim = 50     # width of the style embedding
batch_size = 64
num_epochs = 100
lr = 2e-4

# Networks (generator is built first, then discriminator)
generator = Generator(
    noise_dim=noise_dim, text_dim=text_dim, style_dim=style_dim, img_size=img_size
)
discriminator = Discriminator(
    text_dim=text_dim, style_dim=style_dim, img_size=img_size
)

# One Adam optimizer per network, both with the same learning rate and betas
optimizer_G, optimizer_D = (
    optim.Adam(net.parameters(), lr=lr, betas=(0.5, 0.999))
    for net in (generator, discriminator)
)

# Binary cross-entropy on the discriminator's sigmoid output
adversarial_loss = nn.BCELoss()

In [34]:
# Example usage: build a dataset / dataloader from a single image file.

# Preprocessing: grayscale -> fixed square size -> tensor -> rescale to [-1, 1]
transform = transforms.Compose(
    [
        transforms.Grayscale(),
        transforms.Resize(size=(img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

img_paths = ['../data/a.png']

# Random placeholder embeddings, one row per image
text_embeddings = torch.randn(len(img_paths), text_dim)
style_embeddings = torch.randn(len(img_paths), style_dim)

dataset = HandwritingDataset(
    img_paths,
    labels=text_embeddings,
    styles=style_embeddings,
    transform=transform,
)
dataloader = DataLoader(dataset, shuffle=True, batch_size=batch_size)