# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torchvision.models import resnet18
from transformers import BertTokenizer, BertModel

# Text Embedding using BERT
class TextEncoder(nn.Module):
    """Encode text into one fixed-size vector per example using BERT.

    Each example's vector is the mean of BERT's last hidden state over its
    real tokens.  Padding positions are excluded via the attention mask, so
    short captions in a padded batch are not pulled toward the padding
    embedding.

    Args:
        model_name: Hugging Face checkpoint name used for both the tokenizer
            and the model.
        max_length: Maximum number of tokens kept per example; ``None`` lets
            the tokenizer fall back to the model's maximum length.
    """

    def __init__(self, model_name='bert-base-uncased', max_length=None):
        super(TextEncoder, self).__init__()
        # BertModel consumes token ids, not strings; the tokenizer produces
        # the input_ids / attention_mask tensors it needs.
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.bert = BertModel.from_pretrained(model_name)
        self.max_length = max_length

    def forward(self, text):
        """Return a ``(batch, hidden_size)`` tensor of sentence embeddings.

        Args:
            text: A single string, a sequence of strings, or an already
                tokenized tensor of input ids.  A tensor is used as-is with
                every position treated as a real token (no padding mask).
        """
        device = next(self.bert.parameters()).device
        if isinstance(text, torch.Tensor):
            input_ids = text.to(device)
            attention_mask = torch.ones_like(input_ids)
        else:
            if isinstance(text, str):
                # Avoid list("abc") splitting a lone string into characters.
                text = [text]
            encoded = self.tokenizer(
                list(text),
                padding=True,
                truncation=True,
                max_length=self.max_length,
                return_tensors='pt',
            )
            input_ids = encoded['input_ids'].to(device)
            attention_mask = encoded['attention_mask'].to(device)

        # Element [0] of the model output is the last hidden state.
        hidden = self.bert(input_ids=input_ids, attention_mask=attention_mask)[0]
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
        # Masked mean over the token axis; the clamp guards against dividing
        # by zero for a row with no real tokens.
        return (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)

# Generator Model (cGAN)
class Generator(nn.Module):
    """Conditional generator: maps (noise, text embedding) to an image.

    Args:
        z_dim: Length of the noise vector ``z``.
        text_emb_dim: Length of the text embedding.
        img_channels: Channels of the generated image (default 3).
        img_size: Height and width of the generated image (default 64).

    ``forward`` returns a tensor of shape
    ``(batch, img_channels, img_size, img_size)`` with values in [-1, 1].
    """

    def __init__(self, z_dim, text_emb_dim, img_channels=3, img_size=64):
        super(Generator, self).__init__()
        self.img_shape = (img_channels, img_size, img_size)
        img_dim = img_channels * img_size * img_size
        self.fc = nn.Sequential(
            nn.Linear(z_dim + text_emb_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            # Output layer emits one value per pixel so the result can be
            # reshaped into an image.
            nn.Linear(1024, img_dim),
            # Tanh applied directly to the linear output covers [-1, 1];
            # a ReLU in front of it would clamp every pixel to [0, 1).
            nn.Tanh(),
        )

    def forward(self, z, text_emb):
        x = torch.cat([z, text_emb], dim=1)  # Concatenate noise and text embedding
        # Reshape with the known batch size so a size mismatch raises
        # instead of silently changing the batch dimension.
        return self.fc(x).view(z.size(0), *self.img_shape)

# Discriminator Model (cGAN)
class Discriminator(nn.Module):
    """Conditional discriminator scoring (image, text embedding) pairs.

    The final Sigmoid makes the output a probability in (0, 1) of shape
    ``(batch, 1)``, which is what ``nn.BCELoss`` expects.

    Args:
        text_emb_dim: Length of the text embedding.
        img_channels: Channels of the input image (default 3).
        img_size: Height and width of the input image (default 64).
    """

    def __init__(self, text_emb_dim, img_channels=3, img_size=64):
        super(Discriminator, self).__init__()
        img_dim = img_channels * img_size * img_size
        self.fc = nn.Sequential(
            nn.Linear(text_emb_dim + img_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 1),
            nn.Sigmoid(),
        )

    def forward(self, img, text_emb):
        # flatten(1) works on non-contiguous inputs (e.g. permuted NHWC
        # images), where view() would raise.
        x = torch.cat([img.flatten(start_dim=1), text_emb], dim=1)
        return self.fc(x)

# Training Setup
def train_gan(generator, discriminator, dataloader, text_encoder, device, epochs=10, z_dim=100):
    """Train a text-conditional GAN with the standard BCE adversarial losses.

    Args:
        generator: Module called as ``generator(z, text_emb)``; its output is
            fed to the discriminator as a fake image.
        discriminator: Module called as ``discriminator(img, text_emb)``;
            must return probabilities of shape ``(batch, 1)`` because the loss
            is ``nn.BCELoss``.
        dataloader: Iterable yielding ``(images, texts)`` batches.
        text_encoder: Module mapping ``texts`` to embeddings.  It is used as a
            frozen feature extractor: switched to eval mode and run without
            gradient tracking (its weights are never optimized here).
        device: Device to train on.
        epochs: Number of passes over ``dataloader``.
        z_dim: Length of the noise vector; must match the generator's input.

    Prints the last batch's D and G losses after each epoch.  If the
    dataloader yields no batches, prints a notice and returns.
    """
    criterion = nn.BCELoss()
    optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

    generator.train()
    discriminator.train()
    # The encoder is not trained; eval() turns off any dropout so the same
    # text always maps to the same embedding.
    text_encoder.eval()

    for epoch in range(epochs):
        d_loss = g_loss = None
        for images, texts in dataloader:
            images = images.to(device)
            batch_size = images.size(0)

            # text_emb is reused by both the D and the G backward passes. If it
            # carried the encoder's autograd graph, d_loss.backward() would
            # free that graph and g_loss.backward() would fail with "Trying to
            # backward through the graph a second time".
            with torch.no_grad():
                text_emb = text_encoder(texts).to(device)

            # Generate fake images from random noise
            z = torch.randn(batch_size, z_dim, device=device)
            fake_images = generator(z, text_emb)

            # Real images (label as 1) and fake images (label as 0)
            real_labels = torch.ones(batch_size, 1, device=device)
            fake_labels = torch.zeros(batch_size, 1, device=device)

            # Train Discriminator; detach() keeps this step from updating G.
            optimizer_D.zero_grad()
            real_out = discriminator(images, text_emb)
            fake_out = discriminator(fake_images.detach(), text_emb)
            d_loss = criterion(real_out, real_labels) + criterion(fake_out, fake_labels)
            d_loss.backward()
            optimizer_D.step()

            # Train Generator: reward fakes the discriminator scores as real.
            optimizer_G.zero_grad()
            g_out = discriminator(fake_images, text_emb)
            g_loss = criterion(g_out, real_labels)
            g_loss.backward()
            optimizer_G.step()

        if d_loss is None:
            # No batches: there are no losses to report and nothing to train on.
            print(f"Epoch [{epoch+1}/{epochs}]: dataloader yielded no batches; stopping.")
            return

        print(f"Epoch [{epoch+1}/{epochs}], D Loss: {d_loss.item()}, G Loss: {g_loss.item()}")

# Example of usage
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# BERT is used as a frozen feature extractor: train_gan never optimizes its
# weights. Disable gradient tracking (and dropout) up front so its outputs
# carry no autograd graph into the GAN's backward passes.
text_encoder = TextEncoder().to(device)
text_encoder.requires_grad_(False)
text_encoder.eval()

# Read the embedding width from the loaded model instead of hard-coding 768.
text_emb_dim = text_encoder.bert.config.hidden_size
# train_gan samples 100-dimensional noise by default, so the generator must
# accept z_dim=100.
Z_DIM = 100
generator = Generator(z_dim=Z_DIM, text_emb_dim=text_emb_dim).to(device)
discriminator = Discriminator(text_emb_dim=text_emb_dim).to(device)

# Example DataLoader and training (Replace with your dataset)
dataloader = []  # Replace with your image-text dataset
if isinstance(dataloader, list) and not dataloader:
    # Still the empty placeholder: there is nothing to train on.
    print("dataloader is empty: replace it with an image-text dataset to train.")
else:
    train_gan(generator, discriminator, dataloader, text_encoder, device)
