In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt

# Define hyperparameters
BATCH_SIZE = 64
EMBEDDING_DIM = 512  # Width of the token embeddings and of the pooled text feature vector
NUM_HEADS = 8
NUM_ENCODER_LAYERS = 6
DROPOUT = 0.1
FFN_HID_DIM = 2048
MAX_SEQ_LEN = 3  # Number of positions the positional encoding covers; prompts such as "Draw one" use 2 tokens
IMG_DIM = 28 * 28  # MNIST images are 28x28
IMG_SIZE = 28
NUM_EPOCHS = 20
LEARNING_RATE = 1e-4
Z_DIM = 100  # Dimension of the noise vector

# Define text-to-digit mapping
text_to_digit = {
    "Draw one": 1,
    "Draw two": 2,
    "Draw three": 3,
    "Draw four": 4,
    "Draw five": 5,
    "Draw six": 6,
    "Draw seven": 7,
    "Draw eight": 8,
    "Draw nine": 9,
    "Draw zero": 0
}

# Number of times each prompt is repeated in the text dataset. Every prompt is
# repeated the same number of times, so the dataset is balanced across digits.
SAMPLES_PER_PROMPT = 100

# Create a text dataset with descriptions like "Draw one", "Draw two", etc.
# `text_data[i]` is the prompt and `labels[i]` the digit it asks for.
text_data = []
labels = []
for text, digit in text_to_digit.items():
    text_data.extend([text] * SAMPLES_PER_PROMPT)
    labels.extend([digit] * SAMPLES_PER_PROMPT)

# Load the MNIST training split. ToTensor followed by Normalize((0.5,), (0.5,))
# is applied to every image as it is read from the dataset.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
mnist_data = datasets.MNIST(root='mnist_data', train=True, download=True, transform=transform)

# Keep the (transformed) images whose label is one of the digits that has a
# text prompt; images and labels are stored in two parallel lists.
mnist_images, mnist_labels = [], []
for image, label in mnist_data:
    if label not in text_to_digit.values():
        continue
    mnist_images.append(image)
    mnist_labels.append(label)

# Create custom dataset class
class TextToImageDataset(Dataset):
    """Pairs each text prompt with a randomly drawn image of the prompted digit.

    Args:
        text_data: sequence of whitespace-separated prompt strings.
        labels: sequence of integer digits, one per prompt (``labels[i]`` is
            the digit requested by ``text_data[i]``).
        mnist_images: sequence of image tensors.
        mnist_labels: sequence of integer digits, one per image.

    Attributes:
        vocab: maps every distinct word found in ``text_data`` to a token id.

    Raises:
        ValueError: if ``text_data`` and ``labels`` differ in length, or (on
            item access) if no image exists for the requested digit.

    Note:
        The positions of the images of each digit are computed once at
        construction time, so ``mnist_labels`` must not be mutated afterwards.
    """

    def __init__(self, text_data, labels, mnist_images, mnist_labels):
        if len(text_data) != len(labels):
            raise ValueError(
                f"text_data and labels must have the same length, "
                f"got {len(text_data)} and {len(labels)}"
            )
        self.text_data = text_data
        self.labels = labels
        self.mnist_images = mnist_images
        self.mnist_labels = mnist_labels
        # Token ids follow set iteration order. That order is not guaranteed to
        # be the same in another interpreter session, so anything that tokenizes
        # text for a model trained on this dataset should reuse `self.vocab`.
        self.vocab = {word: i for i, word in enumerate(set(" ".join(text_data).split()))}

        # Precompute, for every digit, the positions of its images. Doing this
        # here avoids converting and scanning the whole label list on every
        # __getitem__ call.
        label_array = np.array(mnist_labels)
        self._indices_by_digit = {
            int(digit): np.where(label_array == digit)[0]
            for digit in set(label_array.tolist())
        }

    def __len__(self):
        return len(self.text_data)

    def __getitem__(self, idx):
        """Return ``(token_ids, image)`` for the prompt at position ``idx``.

        ``token_ids`` is a 1-D ``torch.long`` tensor with one id per word of
        the prompt; ``image`` is picked uniformly at random (via
        ``np.random.choice``) among the images labelled with the prompt's digit.
        """
        text = self.text_data[idx]
        text_tensor = torch.tensor([self.vocab[word] for word in text.split()], dtype=torch.long)
        digit = self.labels[idx]
        candidates = self._indices_by_digit.get(int(digit))
        if candidates is None or len(candidates) == 0:
            raise ValueError(f"no image available for digit {digit!r}")
        image_idx = np.random.choice(candidates)
        image = self.mnist_images[image_idx]
        return text_tensor, image

# Wrap the prompt list and the MNIST images in the paired dataset, then serve
# it in shuffled mini-batches of BATCH_SIZE.
dataset = TextToImageDataset(
    text_data=text_data,
    labels=labels,
    mnist_images=mnist_images,
    mnist_labels=mnist_labels,
)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)


In [None]:
# Positional Encoding
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position information to token embeddings.

    A table of shape ``(maxlen, 1, emb_size)`` is built once and stored as the
    non-trainable buffer ``pos_embedding``: even feature columns hold sines and
    odd feature columns hold cosines of ``position * frequency``.

    Args:
        emb_size: embedding width.
        dropout: dropout probability applied to the sum of embedding and
            positional table.
        maxlen: number of positions in the table.
    """

    def __init__(self, emb_size: int, dropout: float, maxlen: int):
        super().__init__()
        # One frequency per (sin, cos) column pair.
        frequencies = torch.exp(-torch.arange(0, emb_size, 2) * (np.log(10000.0) / emb_size))
        positions = torch.arange(0, maxlen).reshape(maxlen, 1)

        table = torch.zeros((maxlen, emb_size))
        table[:, 0::2] = torch.sin(positions * frequencies)
        table[:, 1::2] = torch.cos(positions * frequencies)

        self.dropout = nn.Dropout(p=dropout)
        # The singleton middle axis lets the table broadcast over the batch axis.
        self.register_buffer('pos_embedding', table.unsqueeze(1))

    def forward(self, token_embedding: torch.Tensor):
        """Add the first ``seq_len`` table rows to a ``(seq_len, batch, emb)`` input."""
        seq_len, _batch, _emb = token_embedding.size()
        positional = self.pos_embedding[:seq_len]
        return self.dropout(token_embedding + positional)

# Text Encoder (Transformer Encoder)
class TransformerEncoder(nn.Module):
    """Encodes a batch of token-id sequences into one feature vector per sequence.

    Token ids are embedded, scaled by ``sqrt(emb_size)``, combined with
    positional encodings, passed through a stack of ``nn.TransformerEncoderLayer``
    and averaged over the sequence axis.

    Args:
        emb_size: embedding width and size of the returned feature vector.
        nhead: number of attention heads per encoder layer.
        num_encoder_layers: number of stacked encoder layers.
        dim_feedforward: hidden width of each layer's feed-forward sub-network.
        dropout: dropout probability used by the positional encoding and the
            encoder layers.
        maxlen: number of positions covered by the positional encoding.
        vocab_size: number of rows in the embedding table; every token id must
            be smaller than this value. Defaults to 100, the size that was
            previously hard-coded.
    """

    def __init__(self, emb_size: int, nhead: int, num_encoder_layers: int, dim_feedforward: int, dropout: float, maxlen: int, vocab_size: int = 100):
        super().__init__()
        self.emb_size = emb_size

        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(emb_size, dropout, maxlen)

        encoder_layers = nn.TransformerEncoderLayer(emb_size, nhead, dim_feedforward, dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_encoder_layers)

    def forward(self, src: torch.Tensor):
        """Map token ids of shape ``(seq_len, batch)`` to features of shape ``(batch, emb_size)``."""
        src = self.embedding(src) * np.sqrt(self.emb_size)
        src = self.positional_encoding(src)

        output = self.transformer_encoder(src)
        output = output.mean(dim=0)  # Global average pooling over the sequence axis
        return output

# Generator
class Generator(nn.Module):
    """Maps a noise vector plus a text feature vector to a 1-channel image.

    The concatenated input is projected to a 128x7x7 feature map, which two
    stride-2 transposed convolutions upsample to 14x14 and then 28x28. The
    final Tanh keeps pixel values in [-1, 1].

    Args:
        z_dim: size of the noise vector.
        text_feature_dim: size of the text feature vector.
        img_size: stored on the instance; the layer sizes do not depend on it.
    """

    def __init__(self, z_dim, text_feature_dim, img_size):
        super().__init__()
        self.z_dim = z_dim
        self.text_feature_dim = text_feature_dim
        self.img_size = img_size

        # Dense projection of [noise, text] onto a flattened 128x7x7 map.
        self.fc = nn.Sequential(
            nn.Linear(z_dim + text_feature_dim, 128 * 7 * 7),
            nn.BatchNorm1d(128 * 7 * 7),
            nn.ReLU(inplace=True),
        )

        # Two upsampling stages: 7x7 -> 14x14 -> 28x28.
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, z, text_features):
        """Return images for noise ``z`` (batch, z_dim) and ``text_features`` (batch, text_feature_dim)."""
        conditioned = torch.cat((z, text_features), dim=1)
        projected = self.fc(conditioned)
        feature_map = projected.view(projected.size(0), 128, 7, 7)
        return self.deconv(feature_map)

# Discriminator
class Discriminator(nn.Module):
    """Scores how likely an image is real, given the text features it should match.

    Three stride-2 convolutions reduce a 1-channel image to 128 feature maps;
    the flattened maps are concatenated with the text features and passed
    through a two-layer classifier ending in a Sigmoid. The classifier's input
    width assumes the maps are 3x3, which is what a 28x28 input produces.

    Args:
        img_size: stored on the instance; the layer sizes do not depend on it.
        text_feature_dim: size of the text feature vector.
    """

    def __init__(self, img_size, text_feature_dim):
        super().__init__()
        self.img_size = img_size
        self.text_feature_dim = text_feature_dim

        # Downsampling stack: 28x28 -> 14x14 -> 7x7 -> 3x3.
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # Classifier over [flattened image features, text features].
        self.fc = nn.Sequential(
            nn.Linear(128 * 3 * 3 + text_feature_dim, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1),
            nn.Sigmoid(),
        )

    def forward(self, img, text_features):
        """Return a ``(batch, 1)`` tensor of probabilities in [0, 1]."""
        image_features = self.conv(img)
        flattened = image_features.view(image_features.size(0), -1)
        joint = torch.cat((flattened, text_features), dim=1)
        return self.fc(joint)

# Build the three networks from the hyperparameters defined in the first cell.
text_encoder = TransformerEncoder(
    emb_size=EMBEDDING_DIM,
    nhead=NUM_HEADS,
    num_encoder_layers=NUM_ENCODER_LAYERS,
    dim_feedforward=FFN_HID_DIM,
    dropout=DROPOUT,
    maxlen=MAX_SEQ_LEN,
)
generator = Generator(z_dim=Z_DIM, text_feature_dim=EMBEDDING_DIM, img_size=IMG_SIZE)
discriminator = Discriminator(img_size=IMG_SIZE, text_feature_dim=EMBEDDING_DIM)

# Binary cross-entropy on the discriminator's probability output.
adversarial_loss = nn.BCELoss()

# One Adam optimizer per adversary. Only the generator's and the
# discriminator's parameters are handed to an optimizer here; the text
# encoder's parameters are not.
optimizer_G = optim.Adam(params=generator.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
optimizer_D = optim.Adam(params=discriminator.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))



In [None]:
# Peek at the token tensor of one batch. Dedicated names are used so that the
# module-level `text_data` list of prompts is not replaced by a tensor.
sample_tokens, sample_images = next(iter(dataloader))
sample_tokens.size()

torch.Size([64, 2])

In [None]:
# Training loop
import os

from torchvision.utils import save_image

# save_image does not create missing directories, so make sure the target exists.
os.makedirs('generated_images', exist_ok=True)

# Only the generator and the discriminator are updated by an optimizer in this
# loop. The text encoder is therefore used as a fixed feature extractor: eval
# mode disables its dropout so a prompt always maps to the same features.
text_encoder.eval()
generator.train()
discriminator.train()

for epoch in range(NUM_EPOCHS):
    for batch_idx, (text_tokens, real_imgs) in enumerate(dataloader):
        batch_size = real_imgs.size(0)

        # Adversarial ground truths
        valid = torch.ones(batch_size, 1)
        fake = torch.zeros(batch_size, 1)

        # Encode the prompts. The loader yields (batch, seq_len) token ids while
        # the encoder expects (seq_len, batch), hence the transpose. No autograd
        # graph is recorded because nothing updates the encoder, which also means
        # the two backward passes below no longer share a graph.
        with torch.no_grad():
            text_features = text_encoder(text_tokens.transpose(0, 1))

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise as generator input
        z = torch.randn(batch_size, Z_DIM)

        # Generate a batch of images
        gen_imgs = generator(z, text_features)

        # Loss measures generator's ability to fool the discriminator
        g_loss = adversarial_loss(discriminator(gen_imgs, text_features), valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        # Also clears the discriminator gradients accumulated by g_loss.backward().
        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples
        real_loss = adversarial_loss(discriminator(real_imgs, text_features), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach(), text_features), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        if batch_idx % 100 == 0:
            print(f'Epoch: {epoch+1}/{NUM_EPOCHS}, Batch: {batch_idx}, D Loss: {d_loss.item()}, G Loss: {g_loss.item()}')
            save_image(gen_imgs.detach()[:25], f'generated_images/{epoch+1}_{batch_idx}.png', nrow=5, normalize=True)

print("Training complete")

Epoch: 1/20, Batch: 0, D Loss: 0.7128311991691589, G Loss: 0.6560519933700562
Epoch: 2/20, Batch: 0, D Loss: 0.11783580482006073, G Loss: 2.1648099422454834
Epoch: 3/20, Batch: 0, D Loss: 0.040208667516708374, G Loss: 3.9405386447906494
Epoch: 4/20, Batch: 0, D Loss: 0.011169323697686195, G Loss: 4.486865520477295
Epoch: 5/20, Batch: 0, D Loss: 0.005038658156991005, G Loss: 5.286547660827637
Epoch: 6/20, Batch: 0, D Loss: 0.0030209682881832123, G Loss: 5.792292594909668
Epoch: 7/20, Batch: 0, D Loss: 0.0023541913833469152, G Loss: 6.275784015655518
Epoch: 8/20, Batch: 0, D Loss: 0.001976217608898878, G Loss: 6.345122337341309
Epoch: 9/20, Batch: 0, D Loss: 0.0017750205006450415, G Loss: 6.431353569030762


In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

# Ensure text_encoder and generator definitions are correct and instantiated

def generate_image_from_text(text, text_encoder, generator, vocab, z_dim=100):
    """Generate one image for a text prompt.

    Args:
        text: whitespace-separated prompt; every word must be a key of ``vocab``.
        text_encoder: callable mapping ``(seq_len, batch)`` token ids to
            ``(batch, feature_dim)`` text features.
        generator: callable mapping ``(noise, text_features)`` to a batch of images.
        vocab: word -> token id mapping; it must be the mapping the models were
            trained with.
        z_dim: size of the noise vector.

    Returns:
        The generated image with the batch axis removed, detached and on the CPU.

    Raises:
        ValueError: if ``text`` contains no words.
        KeyError: if a word of ``text`` is not in ``vocab``.
    """
    words = text.split()
    if not words:
        raise ValueError("text must contain at least one word")
    unknown_words = [word for word in words if word not in vocab]
    if unknown_words:
        raise KeyError(f"words missing from vocab: {unknown_words}")

    # Shape (seq_len, 1): one sequence holding all words of the prompt. This is
    # passed to the encoder as is; transposing it would present every word as a
    # separate one-token sequence.
    text_tokens = torch.tensor([vocab[word] for word in words], dtype=torch.long).unsqueeze(1)

    # A batch of one cannot go through BatchNorm in training mode (it raises
    # "Expected more than 1 value per channel"), and dropout would add noise,
    # so switch to eval mode for the forward pass and restore the mode after.
    modules = [m for m in (text_encoder, generator) if isinstance(m, torch.nn.Module)]
    previous_modes = [m.training for m in modules]
    for module in modules:
        module.eval()
    try:
        with torch.no_grad():
            # The encoder already pools over the sequence -> (1, feature_dim).
            text_features = text_encoder(text_tokens)

            # Sample noise as generator input
            z = torch.randn(1, z_dim)

            # Generate image from text features and noise
            gen_img = generator(z, text_features)
    finally:
        for module, was_training in zip(modules, previous_modes):
            module.train(was_training)

    return gen_img.detach().squeeze(0).cpu()

def show_generated_image(image_tensor):
    """Display an image whose values lie in [-1, 1] as a grayscale plot without axes."""
    # Map [-1, 1] onto [0, 1] and drop singleton axes before plotting.
    pixels = ((image_tensor + 1) / 2).squeeze()

    plt.imshow(pixels, cmap='gray')
    plt.axis('off')
    plt.show()

# Reuse the vocabulary built by the training dataset so the token ids fed to the
# text encoder are exactly the ids used during training. It is not rebuilt from
# a fresh word list, and the module-level `text_data` is left untouched.
vocab = dataset.vocab

# Text descriptions for evaluation
text_descriptions = [
    "Draw one",
    "Draw two",
    "Draw three",
    "Draw four",
    "Draw five",
    "Draw six",
    "Draw seven",
    "Draw eight",
    "Draw nine",
    "Draw zero"
]

# Generate and display images for each text description
for text in text_descriptions:
    gen_img = generate_image_from_text(text, text_encoder, generator, vocab, z_dim=Z_DIM)
    print(f"Generated image for text: '{text}'")
    show_generated_image(gen_img)

ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 6272])

In [None]:
# The `text_encoder` and `generator` objects used below are not re-created in
# this cell; the cell relies on the instances that already exist in the kernel.
Z_DIM = 100  # Same value as the Z_DIM defined with the other hyperparameters

def generate_image_from_text(text, text_encoder, generator, vocab, z_dim=100):
    """Generate one image for a text prompt.

    Args:
        text: whitespace-separated prompt; every word must be a key of ``vocab``.
        text_encoder: callable mapping ``(seq_len, batch)`` token ids to
            ``(batch, feature_dim)`` text features.
        generator: callable mapping ``(noise, text_features)`` to a batch of images.
        vocab: word -> token id mapping; it must be the mapping the models were
            trained with.
        z_dim: size of the noise vector.

    Returns:
        The generated image with the batch axis removed, detached and on the CPU.

    Raises:
        ValueError: if ``text`` contains no words.
        KeyError: if a word of ``text`` is not in ``vocab``.
    """
    words = text.split()
    if not words:
        raise ValueError("text must contain at least one word")
    unknown_words = [word for word in words if word not in vocab]
    if unknown_words:
        raise KeyError(f"words missing from vocab: {unknown_words}")

    # Shape (seq_len, 1): a batch holding the single prompt.
    text_tokens = torch.tensor([vocab[word] for word in words], dtype=torch.long).unsqueeze(1)

    # A batch of one cannot go through BatchNorm in training mode (it raises
    # "Expected more than 1 value per channel"), and dropout would add noise,
    # so switch to eval mode for the forward pass and restore the mode after.
    modules = [m for m in (text_encoder, generator) if isinstance(m, torch.nn.Module)]
    previous_modes = [m.training for m in modules]
    for module in modules:
        module.eval()
    try:
        with torch.no_grad():
            # Encode the text to get text features
            text_features = text_encoder(text_tokens)

            # Sample noise as generator input
            z = torch.randn(1, z_dim)

            # Generate image from text features and noise
            gen_img = generator(z, text_features)
    finally:
        for module, was_training in zip(modules, previous_modes):
            module.train(was_training)

    return gen_img.detach().squeeze(0).cpu()

def show_generated_image(image_tensor):
    """Plot an image with values in [-1, 1] in grayscale, hiding the axes."""
    # Shift and scale from [-1, 1] to [0, 1].
    unit_range = (image_tensor + 1) / 2
    # Remove singleton axes so imshow receives a 2-D array.
    pixels = unit_range.squeeze()

    plt.imshow(pixels, cmap='gray')
    plt.axis('off')
    plt.show()

# Reuse the vocabulary built by the training dataset so the token ids fed to the
# text encoder are exactly the ids used during training. It is not rebuilt from
# a fresh word list, and the module-level `text_data` is left untouched.
vocab = dataset.vocab

# Text descriptions for evaluation
text_descriptions = [
    "Draw one",
    "Draw two",
    "Draw three",
    "Draw four",
    "Draw five",
    "Draw six",
    "Draw seven",
    "Draw eight",
    "Draw nine",
    "Draw zero"
]

# Generate and display images for each text description
for text in text_descriptions:
    gen_img = generate_image_from_text(text, text_encoder, generator, vocab, z_dim=Z_DIM)
    print(f"Generated image for text: '{text}'")
    show_generated_image(gen_img)

ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 6272])