In [None]:
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.29.2-py3-none-any.whl (7.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.1/7.1 MB[0m [31m52.9 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.14.1 (from transformers)
  Downloading huggingface_hub-0.15.1-py3-none-any.whl (236 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m236.8/236.8 kB[0m [31m29.2 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers)
  Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m67.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.15.1 tokenizers-0.13.3 transformers-4.29.2


In [None]:
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
import os
import pandas as pd
import cv2
from transformers import AutoTokenizer
from google.colab.patches import cv2_imshow
from torchvision.utils import save_image
from torch.autograd import Variable
from PIL import Image

In [None]:
# Load the pretrained "bert-base-uncased" tokenizer once; it is passed to the
# dataset below and reused in the inference cell to encode text.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

In [None]:
class TextToImageDataset(Dataset):
    """Pairs every sentence found in a folder of text files with one image.

    Each regular file in ``text_path`` is read and split on ``'.'``; every
    non-empty, stripped sentence becomes one sample that points at
    ``{image_path}/{stem}.jpg``, where ``stem`` is the text file's name up to
    its first dot.

    Args:
        text_path: Directory containing the caption text files.
        image_path: Directory containing the matching ``.jpg`` images.
        transform: Optional callable applied to the loaded RGB PIL image.
        tokenizer: Object exposing ``encode(text)``; required by
            ``__getitem__``.

    Attributes:
        df: ``pandas.DataFrame`` with columns ``img_path`` and ``text``, one
            row per sample.
    """

    def __init__(self, text_path, image_path, transform=None, tokenizer= None):
        self.text_path = text_path
        self.image_path = image_path
        self.transform = transform
        self.tokenizer = tokenizer

        data_list = []
        for filename in sorted(os.listdir(self.text_path)):
            file_path = f'{self.text_path}/{filename}'
            # Skip sub-directories (e.g. Jupyter's ``.ipynb_checkpoints``),
            # which would otherwise make ``open`` fail.
            if not os.path.isfile(file_path):
                continue

            with open(file_path, 'r') as f:
                sentences = [sentence.strip() for sentence in f.read().split('.')]

            stem = filename.split('.')[0]
            img_path = f'{self.image_path}/{stem}.jpg'
            for sentence in sentences:
                # After ``strip()`` a blank sentence is '' (never ' '), so test
                # for emptiness instead of comparing against a single space.
                if sentence:
                    data_list.append({'img_path': img_path, 'text': sentence})

        # Explicit columns keep the frame well-formed even with zero samples.
        self.df = pd.DataFrame(data_list, columns=['img_path', 'text'])

    def __len__(self):
        """Return the number of (sentence, image) samples."""
        return len(self.df)

    def __getitem__(self, index):
        """Return ``(encoded_text, image)`` for the sample at ``index``.

        ``encoded_text`` is a 1-D float tensor built from
        ``tokenizer.encode(text)``; ``image`` is the RGB image after the
        optional ``transform``.
        """
        row = self.df.iloc[index]

        # Force three channels so grayscale/CMYK/RGBA files do not break a
        # 3-channel transform, and close the file handle deterministically.
        with Image.open(row['img_path']) as raw_image:
            image = raw_image.convert('RGB')
        if self.transform:
            image = self.transform(image)

        encoded_text = torch.tensor(self.tokenizer.encode(row['text'])).float()
        return encoded_text, image

    
def collate_fn(batch, max_text_length=50):
    """Collate ``(encoded_text, image)`` samples into two batched tensors.

    Every 1-D text tensor is truncated, or right-padded with zeros, to exactly
    ``max_text_length`` entries so that all texts can be stacked. The length is
    a fixed value, not the longest text in the batch.

    Args:
        batch: Iterable of ``(text, image)`` pairs, where ``text`` is a 1-D
            tensor and all ``image`` tensors share one shape.
        max_text_length: Length every text is padded/truncated to. The default
            of 50 matches the default ``input_size`` of ``Generator``.

    Returns:
        Tuple ``(padded_texts, images)``: a float tensor of shape
        ``(len(batch), max_text_length)`` and the stacked image tensor.
    """
    padded_texts = []
    images = []
    for text, image in batch:
        clipped = text[:max_text_length]
        padded_text = torch.zeros(max_text_length, dtype=torch.float)
        padded_text[:len(clipped)] = clipped
        padded_texts.append(padded_text)
        images.append(image)

    # Stack the equally sized texts and images along a new batch dimension.
    return torch.stack(padded_texts), torch.stack(images)

In [None]:
# Image preprocessing pipeline handed to the dataset: resize with size 64,
# center-crop to 64, convert to a tensor, then normalize each of the three
# channels with mean 0.5 and std 0.5.
# NOTE(review): the Discriminator flattens to 256 * 8 * 8 after three stride-2
# convolutions, which presumably relies on these 64x64 inputs - confirm before
# changing the size here.
transform = transforms.Compose([
  transforms.Resize(64),
  transforms.CenterCrop(64),
  transforms.ToTensor(),
  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

In [6]:
# Relative folders holding the caption text files and the matching images.
text_path = './texts'
image_path = './images'
# One sample per sentence; collate_fn pads/truncates every encoded text to a
# fixed length so the shuffled batches of 64 can be stacked into tensors.
dataset = TextToImageDataset(text_path, image_path, tokenizer = tokenizer,  transform=transform)
dataloader = DataLoader(dataset, batch_size=64, num_workers=4, shuffle=True,collate_fn=collate_fn)



In [7]:
class Generator(nn.Module):
    """Turns a fixed-length text vector into an image tensor.

    The input is projected by a linear layer to ``256 * 8 * 8`` features,
    viewed as ``(N, 256, 8, 8)``, and pushed through three transposed
    convolutions (kernel 4, stride 2, padding 1). The first two are followed
    by batch norm and ReLU, the last one by tanh.

    Args:
        input_size: Number of features in the input vector.
        num_channels: Number of channels of the produced image.
    """

    def __init__(self, input_size=50, num_channels=3):
        super().__init__()

        # Projection from the text vector to the initial 256 x 8 x 8 map.
        self.fc = nn.Linear(input_size, 256 * 8 * 8)

        # Upsampling stack; layer names are kept stable for saved models.
        self.conv1 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(128)
        self.conv2 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.ConvTranspose2d(64, num_channels, kernel_size=4, stride=2, padding=1)

        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Map ``x`` to an image tensor whose values are squashed by tanh."""
        feature_map = self.fc(x).view(-1, 256, 8, 8)
        for deconv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            feature_map = self.relu(norm(deconv(feature_map)))
        return self.tanh(self.conv3(feature_map))

# Define the Discriminator network
class Discriminator(nn.Module):
    """Scores an image batch with one raw logit per image.

    Three convolutions (kernel 4, stride 2, padding 1) with LeakyReLU(0.2)
    reduce the input to 256 channels; batch norm follows the second and third
    convolution. The result is flattened to ``256 * 8 * 8`` features per row
    and mapped to a single output by a linear layer. No sigmoid is applied.

    Args:
        num_channels: Number of channels of the input images.
    """

    def __init__(self, num_channels=3):
        super().__init__()

        # Downsampling stack; layer names are kept stable for saved models.
        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(256)

        # Final linear head producing the logit.
        self.fc = nn.Linear(256 * 8 * 8, 1)

        self.leakyrelu = nn.LeakyReLU(0.2)

    def forward(self, x):
        """Return the logits for the image batch ``x``."""
        activate = self.leakyrelu
        hidden = activate(self.conv1(x))
        hidden = activate(self.bn2(self.conv2(hidden)))
        hidden = activate(self.bn3(self.conv3(hidden)))
        flattened = hidden.view(-1, 256 * 8 * 8)
        return self.fc(flattened)

In [10]:
def train(generator,discriminator,dataloader, num_epochs=2, batch_size=64, learning_rate=0.0002):
    """Adversarially train ``generator`` against ``discriminator``.

    Each step first updates the discriminator on the real images (label 1) and
    on detached generated images (label 0), then updates the generator so that
    its images are scored as real. Both use ``BCEWithLogitsLoss`` and Adam with
    betas ``(0.5, 0.999)``. Progress is printed after every step.

    Args:
        generator: Module mapping a batch of text vectors to images.
        discriminator: Module mapping images to logits of shape ``(N, 1)``.
        dataloader: Sized iterable yielding ``(texts, images)`` batches.
        num_epochs: Number of passes over ``dataloader``.
        batch_size: Unused; kept for backward compatibility. The actual size
            is read from every batch, so a smaller last batch works.
        learning_rate: Learning rate of both Adam optimizers.

    Returns:
        List of ``(d_loss, g_loss)`` float pairs, one per training step.
    """
    criterion = nn.BCEWithLogitsLoss()
    optimizerG = optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
    optimizerD = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))

    # Work on the device the generator already lives on, so inputs and labels
    # never end up on a different device than the models.
    first_param = next(generator.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    history = []
    for epoch in range(num_epochs):
        for i, (texts, images) in enumerate(dataloader):
            texts = texts.to(device)
            real_images = images.to(device)
            current_batch_size = real_images.size(0)
            real_labels = torch.ones(current_batch_size, 1, device=device)
            fake_labels = torch.zeros(current_batch_size, 1, device=device)

            # Train Discriminator
            discriminator.zero_grad()

            # Detach so the discriminator loss does not backpropagate into
            # the generator.
            fake_images = generator(texts).detach()

            d_loss_real = criterion(discriminator(real_images), real_labels)
            d_loss_fake = criterion(discriminator(fake_images), fake_labels)

            d_loss = d_loss_real + d_loss_fake
            d_loss.backward()
            optimizerD.step()

            # Train Generator: it is rewarded when its images are labelled real.
            generator.zero_grad()
            fake_images = generator(texts)
            g_loss = criterion(discriminator(fake_images), real_labels)
            g_loss.backward()
            optimizerG.step()

            history.append((d_loss.item(), g_loss.item()))
            print('Epoch [{}/{}], Step [{}/{}], Discriminator Loss: {:.4f}, Generator Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, len(dataloader), d_loss.item(), g_loss.item()))

    return history

In [11]:
# Instantiate both networks with their default arguments
# (Generator: input_size=50, num_channels=3; Discriminator: num_channels=3).
generator = Generator()
discriminator = Discriminator()

In [12]:
# Train with the default hyperparameters of train(), then persist both
# networks. torch.save is given the module objects themselves (not their
# state_dict), and the inference cell reads 'generator.pt' back with torch.load.
train(generator,discriminator,dataloader)
torch.save(generator, 'generator.pt')
torch.save(discriminator, 'discriminator.pt')



Epoch [1/2], Step [1/316], Discriminator Loss: 1.5381, Generator Loss: 1.7766
Epoch [1/2], Step [2/316], Discriminator Loss: 1.2705, Generator Loss: 1.9738
Epoch [1/2], Step [3/316], Discriminator Loss: 1.0981, Generator Loss: 2.4921
Epoch [1/2], Step [4/316], Discriminator Loss: 0.8341, Generator Loss: 2.6639
Epoch [1/2], Step [5/316], Discriminator Loss: 0.7644, Generator Loss: 2.7193
Epoch [1/2], Step [6/316], Discriminator Loss: 0.6763, Generator Loss: 3.1020
Epoch [1/2], Step [7/316], Discriminator Loss: 0.5863, Generator Loss: 3.1520
Epoch [1/2], Step [8/316], Discriminator Loss: 0.5949, Generator Loss: 2.9215
Epoch [1/2], Step [9/316], Discriminator Loss: 0.4637, Generator Loss: 3.1208
Epoch [1/2], Step [10/316], Discriminator Loss: 0.5483, Generator Loss: 3.2639
Epoch [1/2], Step [11/316], Discriminator Loss: 0.4181, Generator Loss: 3.6745
Epoch [1/2], Step [12/316], Discriminator Loss: 0.4134, Generator Loss: 3.6087
Epoch [1/2], Step [13/316], Discriminator Loss: 0.3287, Gener

In [13]:
from PIL import Image
from torchvision.transforms import ToPILImage, Compose, Normalize, ToTensor

# NOTE(review): this unpickles a complete module object, so only load files
# produced by this notebook; never point it at an untrusted file.
generator = torch.load('generator.pt')

# Inference mode: BatchNorm has to use its running statistics rather than the
# statistics of the single sample generated here.
generator.eval()
device = next(generator.parameters()).device

input_text = "she is smiling and has blue eyes"
encoded_text = torch.tensor(tokenizer.encode(input_text)).float()
# Same fixed-length padding/truncation (50) that collate_fn applies in training.
padded_text = torch.zeros(50, dtype=torch.float)
padded_text[:len(encoded_text)] = encoded_text[:50]

# Add an explicit batch dimension and skip gradient tracking for inference.
with torch.no_grad():
    fake_image = generator(padded_text.unsqueeze(0).to(device))

# Convert the generated image tensor to a PIL image
fake_image = fake_image.detach().cpu().squeeze(0)
# The generator ends in tanh, so values lie in [-1, 1]; ToPILImage expects
# float tensors in [0, 1], so undo the (0.5, 0.5) normalization first.
fake_image = ((fake_image + 1) / 2).clamp(0, 1)
fake_image = ToPILImage()(fake_image)

# Save the generated image to disk
fake_image.save('generated_image.png')