In [21]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
from nltk.tokenize import word_tokenize
import nltk
import torch.optim as optim
import torch.nn as nn

nltk.download('punkt')

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\hardi\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [22]:
def load_glove_embeddings(glove_file):
    embeddings_index = {}
    with open(glove_file, 'r', encoding='utf-8') as f:
        for line in f:
            values = line.split()
            word = values[0]
            coefs = np.asarray(values[1:], dtype='float32')
            embeddings_index[word] = coefs
    return embeddings_index

# Load GloVe embeddings
glove_file = 'goolove\glove.6B.300d.txt'
glove_model = load_glove_embeddings(glove_file)

  glove_file = 'goolove\glove.6B.300d.txt'


In [23]:
class StackGANDataset(Dataset):
    def __init__(self, root_dir, transform=None, img_size=(64, 64), glove_model=None):
        self.root_dir = root_dir
        self.transform = transform
        self.data = []
        self.img_size = img_size
        self.glove_model = glove_model

        for subdir in os.listdir(root_dir):
            subdir_path = os.path.join(root_dir, subdir)
            if os.path.isdir(subdir_path):
                try:
                    img_file = [f for f in os.listdir(subdir_path) if f.endswith('.jpg')][0]
                    txt_file = [f for f in os.listdir(subdir_path) if f.endswith('.txt')][0]
                    self.data.append((os.path.join(subdir_path, img_file), os.path.join(subdir_path, txt_file)))
                except IndexError as e:
                    print(f"Error: {e}, in directory {subdir_path}")
                except Exception as e:
                    print(f"Unexpected error: {e}, in directory {subdir_path}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        try:
            img_path, txt_path = self.data[idx]
            image = Image.open(img_path).convert('RGB')

            image = image.resize(self.img_size)  # Resize the image
            if self.transform:
                image = self.transform(image)

            with open(txt_path, 'r') as f:
                text = f.read().strip()

            text_embedding = self.text_to_embedding(text)
            return image, text_embedding
        except Exception as e:
            print(f"Error loading item at index {idx}: {e}")
            raise

    def text_to_embedding(self, text):
        words = word_tokenize(text.lower())
        embeddings = [self.glove_model[word] for word in words if word in self.glove_model]
        if embeddings:
            text_embedding = np.mean(embeddings, axis=0)
        else:
            text_embedding = np.zeros(len(next(iter(self.glove_model.values()))))
        return text_embedding


In [24]:
transform = transforms.Compose([
    transforms.Resize((64, 64)),  # for Stage-I
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Create dataset and dataloader
dataset = StackGANDataset('stackgan_input', transform=transform, img_size=(128, 128), glove_model=glove_model)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=0)  # Changed num_workers to 0


In [25]:

class ConditionalAugmentation(nn.Module):
    def __init__(self, text_dim, projected_dim):
        super(ConditionalAugmentation, self).__init__()
        self.proj = nn.Linear(text_dim, projected_dim * 2)

    def forward(self, text_embedding):
        mu_logvar = self.proj(text_embedding)
        mu, logvar = mu_logvar.chunk(2, dim=1)
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

class Generator_Stage1(nn.Module):
    def __init__(self, noise_dim, text_dim, projected_dim):
        super(Generator_Stage1, self).__init__()
        self.ca = ConditionalAugmentation(text_dim, projected_dim)
        self.fc = nn.Linear(noise_dim + projected_dim, 256 * 4 * 4)
        self.main = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, 2, 1), # output: 8x8
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1), # output: 16x16
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 32, 4, 2, 1), # output: 32x32
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            nn.ConvTranspose2d(32, 3, 4, 2, 1), # output: 64x64
            nn.Tanh()
        )

    def forward(self, noise, text_embedding):
        cond_code = self.ca(text_embedding)
        z = torch.cat([noise, cond_code], dim=1)
        out = self.fc(z)
        out = out.view(-1, 256, 4, 4)
        out = self.main(out)
        return out

class Discriminator_Stage1(nn.Module):
    def __init__(self, text_dim):
        super(Discriminator_Stage1, self).__init__()
        self.img_encoder = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.text_proj = nn.Linear(text_dim, 256)
        self.classifier = nn.Sequential(
            nn.Conv2d(512, 512, 3, 1, 1),  # Changed kernel size to 3
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, 4, 1, 0),
            nn.Sigmoid()
        )

    def forward(self, img, text_embedding):
        noise = torch.randn_like(img) * 0.1  # Add small noise to images
        img_features = self.img_encoder(img + noise)
        # img_features = self.img_encoder(img)
        text_features = self.text_proj(text_embedding)

        text_features = text_features.view(-1, 256, 1, 1)
        text_features = text_features.repeat(1, 1, img_features.size(2), img_features.size(3))

        features = torch.cat([img_features, text_features], dim=1)
        out = self.classifier(features)
        
        # Average over the spatial dimensions
        out = out.view(out.size(0), -1).mean(dim=1, keepdim=True)
        return out


In [26]:

# Model initialization
noise_dim = 100
projected_dim = 128
text_dim = 300

netG1 = Generator_Stage1(noise_dim, text_dim, projected_dim)
netD1 = Discriminator_Stage1(text_dim)

In [27]:
# Optimizers
optimizerD1 = optim.Adam(netD1.parameters(), lr=0.00004, betas=(0.5, 0.999))
optimizerG1 = optim.Adam(netG1.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Loss function
# criterion = nn.BCELoss()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
netG1.to(device)
netD1.to(device)
# criterion.to(device)

Discriminator_Stage1(
  (img_encoder): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (text_proj): Linear(in_features=300, out_features=256, bias=True)
  (classifier): Sequential(
    (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
    (3): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1))
    

In [29]:
num_epochs=10
for epoch in range(num_epochs):
    for i, (real_imgs, text_embeddings) in enumerate(dataloader):
        real_imgs = real_imgs.to(device)
        text_embeddings = torch.tensor(text_embeddings, dtype=torch.float32).to(device)

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizerD1.zero_grad()

        noise = torch.randn(len(real_imgs), noise_dim).to(device)
        fake_imgs = netG1(noise, text_embeddings)

        real_validity = netD1(real_imgs, text_embeddings)
        fake_validity = netD1(fake_imgs.detach(), text_embeddings)

        # gradient_penalty = compute_gradient_penalty(netD1, real_imgs.data, fake_imgs.data, text_embeddings)
        d_loss = -torch.mean(real_validity) + torch.mean(fake_validity)
        # d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty

        d_loss.backward()
        optimizerD1.step()

        # ------------------
        #  Train Generator
        # ------------------

        optimizerG1.zero_grad()

        gen_validity = netD1(fake_imgs, text_embeddings)
        g_loss = -torch.mean(gen_validity)

        g_loss.backward()
        optimizerG1.step()
        print(f"Epoch [{epoch+1}/{num_epochs}] Batch [{i+1}/{len(dataloader)}] Loss D: {d_loss.item()}, loss G: {g_loss.item()}")

  text_embeddings = torch.tensor(text_embeddings, dtype=torch.float32).to(device)


Epoch [1/10] Batch [1/469] Loss D: -0.40161705017089844, loss G: -0.24863529205322266
Epoch [1/10] Batch [2/469] Loss D: -0.37660855054855347, loss G: -0.11486691981554031
Epoch [1/10] Batch [3/469] Loss D: -0.38887912034988403, loss G: -0.187670037150383
Epoch [1/10] Batch [4/469] Loss D: -0.35497623682022095, loss G: -0.25086838006973267
Epoch [1/10] Batch [5/469] Loss D: -0.29792386293411255, loss G: -0.13597102463245392
Epoch [1/10] Batch [6/469] Loss D: -0.2904597520828247, loss G: -0.22203028202056885
Epoch [1/10] Batch [7/469] Loss D: -0.36134710907936096, loss G: -0.1591770350933075
Epoch [1/10] Batch [8/469] Loss D: -0.34976255893707275, loss G: -0.10442692041397095
Epoch [1/10] Batch [9/469] Loss D: -0.4185994863510132, loss G: -0.22135594487190247
Epoch [1/10] Batch [10/469] Loss D: -0.4643121361732483, loss G: -0.12373197078704834
Epoch [1/10] Batch [11/469] Loss D: -0.4379749596118927, loss G: -0.2481486201286316
Epoch [1/10] Batch [12/469] Loss D: -0.430651992559433, loss

In [30]:
# Save the trained Stage-I models
torch.save(netG1.state_dict(), 'netG1.pth')
torch.save(netD1.state_dict(), 'netD1.pth')

In [31]:
def load_glove_embeddings(glove_file):
    embeddings_index = {}
    with open(glove_file, 'r', encoding='utf-8') as f:
        for line in f:
            values = line.split()
            word = values[0]
            coefs = np.asarray(values[1:], dtype='float32')
            embeddings_index[word] = coefs
    return embeddings_index

glove_file = 'goolove\glove.6B.300d.txt'
glove_model = load_glove_embeddings(glove_file)
text_dim = 300  # GloVe embedding dimension
projected_dim = 128
noise_dim = 100

  glove_file = 'goolove\glove.6B.300d.txt'


In [34]:

# Load the trained models
netG1 = Generator_Stage1(noise_dim, text_dim, projected_dim)
netG1.load_state_dict(torch.load('netG1.pth'))
netG1.eval()

# netG2 = Generator_Stage2(text_dim, projected_dim)
# netG2.load_state_dict(torch.load('netG2.pth'))
# netG2.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
netG1.to(device)
# netG2.to(device)

nltk.download('punkt')

# Function to get text embeddings using GloVe
def get_text_embedding(text, glove_model):
    words = word_tokenize(text.lower())
    embeddings = [glove_model[word] for word in words if word in glove_model]
    if embeddings:
        text_embedding = np.mean(embeddings, axis=0)
    else:
        text_embedding = np.zeros(len(next(iter(glove_model.values()))))
    text_embedding = torch.tensor(text_embedding, dtype=torch.float32).unsqueeze(0)  # Add batch dimension
    return text_embedding.to(device)

# Function to generate and save image from text
def generate_image_from_text(text, noise_dim, glove_model):
    text_embedding = get_text_embedding(text, glove_model)
    
    noise = torch.randn(1, noise_dim).to(device)
    with torch.no_grad():
        fake_img_stage1 = netG1(noise, text_embedding)
        # fake_img_stage2 = netG2(fake_img_stage1, text_embedding)

    # Convert the generated image to a PIL image and save
    img = fake_img_stage1.squeeze().cpu().numpy()
    img = np.transpose(img, (1, 2, 0))
    img = (img + 1) / 2.0 * 255  # Rescale to [0, 255]
    img = img.astype(np.uint8)
    img = Image.fromarray(img)
    
    # img.save('generated_image.png')
    return img

# Example usage
text_input = "digit illustr kabuto rock'n'roll watertyp dodo pok√©mon character brown shell larg black eye cerise pupil promin yellow claw"
generated_image = generate_image_from_text(text_input, noise_dim, glove_model)
generated_image.show()

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\hardi\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
