In [None]:
from transformers import BertTokenizer, BertModel
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms
import pandas as pd
import os

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Pretrained BERT (tokenizer + encoder) used to embed the bird class names.
bert_model_name = 'bert-base-uncased'
tokenizer, bert_model = (
    BertTokenizer.from_pretrained(bert_model_name),
    BertModel.from_pretrained(bert_model_name),
)

class TextEmbeddingBERT(nn.Module):
    """Frozen BERT text encoder producing one mean-pooled vector per input string.

    Args:
        bert_model: a ``transformers`` BERT model whose output exposes ``last_hidden_state``.
        tokenizer: optional matching tokenizer; when omitted, the notebook-level
            ``tokenizer`` is used (looked up at call time).

    ``forward(text)`` accepts a string or a list of strings and returns a tensor of
    shape (batch, hidden_size) on the model's device, computed without gradients.
    """

    def __init__(self, bert_model, tokenizer=None):
        super(TextEmbeddingBERT, self).__init__()
        self.bert_model = bert_model
        self.tokenizer = tokenizer

    def forward(self, text):
        tok = self.tokenizer if self.tokenizer is not None else tokenizer
        encoded = tok(text, padding=True, truncation=True, return_tensors='pt')
        # The tokenizer builds CPU tensors; move them to wherever the model's weights
        # live (e.g. after `.to(device)`), otherwise the embedding lookup fails on GPU.
        model_device = next(self.bert_model.parameters()).device
        input_ids = encoded['input_ids'].to(model_device)
        attention_mask = encoded['attention_mask'].to(model_device)
        with torch.no_grad():
            bert_output = self.bert_model(input_ids, attention_mask=attention_mask)
        hidden = bert_output.last_hidden_state
        # Average over real tokens only: padded positions (mask == 0) would otherwise
        # dilute the mean whenever a batch mixes strings of different lengths.
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
        text_embedding = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        return text_embedding

# Build the frozen text encoder and place it on the same device as the GAN.
text_embedding_model = TextEmbeddingBERT(bert_model)
text_embedding_model.to(device)


# Custom dataset class using the combined dataframe
class BirdsDataset(Dataset):
    """CUB-200 images paired with a text embedding of their class name.

    Args:
        image_folder: root directory that the ``image_name`` paths are relative to.
        full_image_data: DataFrame with at least ``image_name`` and ``class_name`` columns.
        transform: optional callable applied to the RGB PIL image.
        text_encoder: optional callable mapping a class-name string to an embedding
            tensor; defaults to the notebook-level ``text_embedding_model``.

    Each item is ``(image, text_embedding)``. Every image of a class shares the same
    class-name text, so embeddings are computed once per distinct name and cached.
    """

    def __init__(self, image_folder, full_image_data, transform=None, text_encoder=None):
        self.image_folder = image_folder
        self.full_image_data = full_image_data
        self.transform = transform
        self.text_encoder = text_encoder
        self._embedding_cache = {}  # class_name -> embedding tensor

    def __len__(self):
        return len(self.full_image_data)

    def _class_embedding(self, class_name):
        """Return the embedding for ``class_name``, encoding it only on first use."""
        if class_name not in self._embedding_cache:
            encoder = self.text_encoder if self.text_encoder is not None else text_embedding_model
            with torch.no_grad():
                self._embedding_cache[class_name] = encoder(class_name)
        return self._embedding_cache[class_name]

    def __getitem__(self, idx):
        row = self.full_image_data.iloc[idx]
        image_path = os.path.join(self.image_folder, row['image_name'])
        # The context manager closes the file handle; convert('RGB') returns a loaded
        # copy and turns any non-RGB file (e.g. grayscale) into 3 channels so that
        # every sample has the same shape and batches can be stacked.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)

        text_embedding = self._class_embedding(row['class_name'])
        return image, text_embedding

# Load dataset files and merge dataframes
DATA_ROOT = '/content/drive/MyDrive/CUB_200_2011/CUB_200_2011'
# Must equal the generator's output resolution: its ConvTranspose stack (1 -> 4 -> 7 -> 14 -> 28)
# produces 28x28 images, and the critic's conv stack reduces 28x28 to a single score.
IMAGE_SIZE = 28
BATCH_SIZE = 32

image_folder = os.path.join(DATA_ROOT, 'images')
classes_df = pd.read_csv(os.path.join(DATA_ROOT, 'classes.txt'), sep=' ', names=['class_id', 'class_name'], index_col='class_id')
images_df = pd.read_csv(os.path.join(DATA_ROOT, 'images.txt'), sep=' ', names=['image_id', 'image_name'])
image_class_map_df = pd.read_csv(os.path.join(DATA_ROOT, 'image_class_labels.txt'), sep=' ', names=['image_id', 'class_id'])

# Merge to get one row per image with its file name and class name
# (classes_df is indexed by class_id, which `on=` also matches).
full_image_data = image_class_map_df.merge(images_df, on='image_id').merge(classes_df, on='class_id')
assert len(full_image_data) == len(images_df), "some images lost their label/class during the merge"

# Define image transformations
transform = transforms.Compose([
    # Guarantee 3 channels even for grayscale source files so samples can be batched.
    transforms.Lambda(lambda img: img.convert('RGB')),
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    # [0, 1] -> [-1, 1], the range of the generator's Tanh output (and what display_images assumes).
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Create dataset and dataloader
dataset = BirdsDataset(image_folder, full_image_data, transform=transform)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# Mount Google Drive so the CUB-200-2011 files under /content/drive are readable.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm.auto import tqdm


if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

def get_noise(n_samples, noise_dim, device='cpu'):
    '''
    Sample a batch of standard-normal noise vectors laid out as 1x1 feature maps.

    Args:
        n_samples: number of noise vectors to draw (typically the batch size)
        noise_dim: length of each noise vector (becomes the channel dimension)
        device: device on which to allocate the tensor, e.g. 'cpu' or 'cuda'

    Returns:
        Tensor of shape (n_samples, noise_dim, 1, 1).
    '''
    shape = (n_samples, noise_dim, 1, 1)
    return torch.randn(shape, device=device)

class Generator(nn.Module):
    """DCGAN-style generator conditioned on a text embedding.

    The noise vector and the text embedding are concatenated along the channel
    dimension and treated as a 1x1 feature map. Four transposed convolutions then
    upsample it 1 -> 4 -> 7 -> 14 -> 28, so the output is
    (N, no_of_channels, 28, 28) with values in [-1, 1] (Tanh).

    Args:
        noise_dim: length of the noise vector.
        text_embedding_dim: length of the conditioning text embedding.
        gen_dim: base width of the hidden layers.
        no_of_channels: number of output image channels.
    """

    def __init__(self, noise_dim, text_embedding_dim, gen_dim, no_of_channels=1):
        super(Generator, self).__init__()
        self.noise_dim = noise_dim
        self.text_embedding_dim = text_embedding_dim
        # Derive the input width from the constructor arguments (not a notebook global)
        # so the first layer always matches what forward() concatenates.
        in_channels = noise_dim + text_embedding_dim
        self.network = nn.Sequential(
            nn.ConvTranspose2d(in_channels, gen_dim * 4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(gen_dim * 4),
            nn.ReLU(True),

            nn.ConvTranspose2d(gen_dim * 4, gen_dim * 2, 3, 2, 1, bias=False),
            nn.BatchNorm2d(gen_dim * 2),
            nn.ReLU(True),

            nn.ConvTranspose2d(gen_dim * 2, gen_dim, 4, 2, 1, bias=False),
            nn.BatchNorm2d(gen_dim),
            nn.ReLU(True),

            nn.ConvTranspose2d(gen_dim, no_of_channels, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, noise, text_embedding):
        """Generate images from ``noise`` (N, noise_dim[, 1, 1]) and ``text_embedding`` (N, D[, 1, 1]).

        Raises:
            ValueError: if the concatenated channel count differs from
                ``noise_dim + text_embedding_dim``.
        """
        # Accept flat (N, D) vectors as well as (N, D, 1, 1) feature maps.
        if noise.dim() == 2:
            noise = noise[:, :, None, None]
        if text_embedding.dim() == 2:
            text_embedding = text_embedding[:, :, None, None]
        combined_input = torch.cat((noise, text_embedding), 1)
        expected = self.noise_dim + self.text_embedding_dim
        if combined_input.shape[1] != expected:
            raise ValueError(
                f"Generator expects {self.noise_dim} noise + {self.text_embedding_dim} text "
                f"= {expected} input channels, got {combined_input.shape[1]} "
                f"(noise {tuple(noise.shape)}, text_embedding {tuple(text_embedding.shape)})"
            )
        return self.network(combined_input)

class Discriminator(nn.Module):
    """WGAN critic: three stride-2 conv blocks followed by a 4x4 valid conv to one score.

    No output activation is applied (raw critic scores). For a 28x28 input the
    spatial size goes 28 -> 14 -> 7 -> 4 -> 1, giving an output of shape (N, 1, 1, 1).
    """

    def __init__(self, no_of_channels=1, disc_dim=32):
        super(Discriminator, self).__init__()

        def down_block(in_ch, out_ch, kernel_size, normalize):
            # Conv (stride 2, pad 1) -> optional affine InstanceNorm -> LeakyReLU(0.2)
            block = [nn.Conv2d(in_ch, out_ch, kernel_size, 2, 1, bias=False)]
            if normalize:
                block.append(nn.InstanceNorm2d(out_ch, affine=True))
            block.append(nn.LeakyReLU(0.2, inplace=True))
            return block

        layers = (
            down_block(no_of_channels, disc_dim, 4, normalize=False)
            + down_block(disc_dim, disc_dim * 2, 4, normalize=True)
            + down_block(disc_dim * 2, disc_dim * 4, 3, normalize=True)
            + [nn.Conv2d(disc_dim * 4, 1, 4, 1, 0, bias=False)]
        )
        self.network = nn.Sequential(*layers)

    def forward(self, input):
        return self.network(input)

noise_dim = 100  # Size of the noise vector
# Must equal the width of the text embedding fed to the generator: mean-pooled
# bert-base-uncased hidden states are 768-d (the logged combined input of 868
# channels = 100 noise + 768 text, which did not match the old 100 + 256).
text_embedding_dim = 768
expected_channels = noise_dim + text_embedding_dim
image_channels = 3  # dataset images are RGB, so generator output and critic input must be too
gen = Generator(noise_dim=noise_dim, text_embedding_dim=text_embedding_dim, gen_dim=32, no_of_channels=image_channels).to(device)
critic = Discriminator(no_of_channels=image_channels).to(device)


In [None]:
def weights_init(m):
    """Initialise conv / transposed-conv / batch-norm parameters in place (DCGAN-style).

    Meant to be used as ``module.apply(weights_init)``:
      * Conv2d / ConvTranspose2d weights ~ N(0, 0.02)
      * BatchNorm2d weights ~ N(0, 0.02), biases set to 0
    Every other module type is left untouched.
    """
    conv_types = (nn.Conv2d, nn.ConvTranspose2d)
    if isinstance(m, conv_types):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
        nn.init.zeros_(m.bias)
gen = gen.apply(weights_init)
critic = critic.apply(weights_init)

# The training loop optimises the WGAN-GP objective (gradient penalty, no weight
# clipping). WGAN-GP (Gulrajani et al., 2017, Algorithm 1) uses Adam with
# lr=1e-4, betas=(0.0, 0.9); RMSprop(lr=5e-5) is the weight-clipping WGAN setting.
lr = 1e-4
ADAM_BETAS = (0.0, 0.9)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=ADAM_BETAS)
critic_opt = torch.optim.Adam(critic.parameters(), lr=lr, betas=ADAM_BETAS)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def display_images(images_tensor):
    """Show a batch of images side by side in a single row.

    Args:
        images_tensor: tensor of shape (N, C, H, W) with values in [-1, 1]
            (e.g. Tanh generator output). Single-channel images are drawn with a
            gray colormap. An empty batch draws nothing.
    """
    # Detach before moving so no autograd history is carried to the CPU copy.
    images = images_tensor.detach().cpu().numpy()
    if len(images) == 0:
        return

    images = np.transpose(images, (0, 2, 3, 1))  # (N, C, H, W) -> (N, H, W, C)
    # [-1, 1] -> [0, 1]; clip so slightly out-of-range values don't trigger imshow clipping warnings.
    images = np.clip((images + 1) / 2, 0.0, 1.0)

    # squeeze=False keeps a 2-D axes array even for a single image (otherwise a bare Axes).
    fig, axs = plt.subplots(1, len(images), figsize=(15, 15), squeeze=False)
    for ax, img in zip(axs[0], images):
        if img.shape[-1] == 1:  # channel axis is last after the transpose
            ax.imshow(img[..., 0], cmap='gray')
        else:
            ax.imshow(img)
        ax.axis('off')
    plt.show()


In [None]:
n_epochs = 5
cur_step = 0
display_step = 500
z_dim = noise_dim  # keep the sampled noise size in sync with the generator's input
CRITIC_ITERATIONS = 5
LAMBDA_GP = 10
N_DISPLAY = 8  # images shown per row when logging


def gradient_penalty(critic, real, fake, device="cpu"):
    """WGAN-GP penalty: mean of (||grad_x critic(x)||_2 - 1)^2 over random real/fake interpolates.

    Args:
        critic: network mapping an image batch to scores.
        real, fake: image batches of identical shape (N, C, H, W).
        device: device on which the interpolation weights are sampled.

    Returns:
        Scalar tensor kept in the autograd graph (create_graph=True) so it can be
        added to the critic loss and back-propagated into the critic's weights.
    """
    batch_size = real.shape[0]
    epsilon = torch.rand(batch_size, 1, 1, 1, device=device)
    # Interpolates are treated as fresh inputs: the penalty must not reach the generator.
    interpolated = (epsilon * real.detach() + (1 - epsilon) * fake.detach()).requires_grad_(True)
    scores = critic(interpolated)
    gradients = torch.autograd.grad(
        outputs=scores,
        inputs=interpolated,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient_norm = gradients.reshape(batch_size, -1).norm(2, dim=1)
    return torch.mean((gradient_norm - 1) ** 2)


for epoch in range(n_epochs):
    for real_image, text_embeddings in tqdm(dataloader):
        cur_batch_size = real_image.shape[0]
        real_image = real_image.to(device)
        # Flatten whatever the dataset returned into (B, D, 1, 1) so it concatenates with the noise on dim 1.
        text_embeddings = text_embeddings.to(device).view(cur_batch_size, -1, 1, 1)

        # --- Critic: several updates per generator step ---
        for _ in range(CRITIC_ITERATIONS):
            fake_noise = get_noise(cur_batch_size, z_dim, device=device)
            fake = gen(fake_noise, text_embeddings)
            # Detach for the critic update: otherwise critic_loss.backward() also walks
            # (and frees) the generator graph that the generator update below reuses.
            fake_detached = fake.detach()

            critic_real_pred = critic(real_image).reshape(-1)
            critic_fake_pred = critic(fake_detached).reshape(-1)

            gp = gradient_penalty(critic, real_image, fake_detached, device)
            critic_loss = -(torch.mean(critic_real_pred) - torch.mean(critic_fake_pred)) + LAMBDA_GP * gp

            critic_opt.zero_grad()
            critic_loss.backward()
            critic_opt.step()

        # --- Generator: uses the last (non-detached) fake batch, whose graph is intact ---
        gen_fake = critic(fake).reshape(-1)
        gen_loss = -torch.mean(gen_fake)
        gen_opt.zero_grad()
        gen_loss.backward()
        gen_opt.step()

        # Visualization and Logging
        if cur_step % display_step == 0 and cur_step > 0:
            print(f"Step {cur_step}: Generator loss: {gen_loss.item():.4f}, Critic loss: {critic_loss.item():.4f}")
            display_images(fake[:N_DISPLAY])
            display_images(real_image[:N_DISPLAY])
        cur_step += 1

  0%|          | 0/369 [00:00<?, ?it/s]

torch.Size([32, 768, 1, 1])
Combined input shape: torch.Size([32, 868, 1, 1])


RuntimeError: ignored