In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.utils import save_image

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Generator
class Generator(nn.Module):
    def __init__(self, noise_dim, text_embed_dim, img_size):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(noise_dim + text_embed_dim, 256), nn.ReLU(),
            nn.Linear(256, 512), nn.ReLU(),
            nn.Linear(512, img_size * img_size * 3), nn.Tanh()
        )
        self.img_size = img_size

    def forward(self, noise, text_embedding):
        x = torch.cat((noise, text_embedding), dim=1)
        x = self.model(x)
        return x.view(-1, 3, self.img_size, self.img_size)

# Discriminator
class Discriminator(nn.Module):
    def __init__(self, img_size, text_embed_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(img_size * img_size * 3 + text_embed_dim, 512), nn.ReLU(),
            nn.Linear(512, 256), nn.ReLU(),
            nn.Linear(256, 1), nn.Sigmoid()
        )

    def forward(self, img, text_embedding):
        img = img.view(img.size(0), -1)
        x = torch.cat((img, text_embedding), dim=1)
        return self.model(x)

# Hyperparameters
noise_dim, text_embed_dim, img_size = 100, 128, 64
batch_size, lr, num_epochs = 32, 0.0002, 100

# Models, Loss, Optimizers
G, D = Generator(noise_dim, text_embed_dim, img_size).to(device), Discriminator(img_size, text_embed_dim).to(device)
criterion = nn.BCELoss()
optimizer_G, optimizer_D = optim.Adam(G.parameters(), lr=lr), optim.Adam(D.parameters(), lr=lr)

# Data generation
def generate_noise(batch_size, noise_dim):
    return torch.randn(batch_size, noise_dim).to(device)

def generate_text_embeddings(batch_size, text_embed_dim):
    return torch.randn(batch_size, text_embed_dim).to(device)

# Training
for epoch in range(num_epochs):
    for _ in range(batch_size):
        real_imgs = torch.randn(batch_size, 3, img_size, img_size).to(device)
        real_text = generate_text_embeddings(batch_size, text_embed_dim)
        noise = generate_noise(batch_size, noise_dim)
        fake_text = generate_text_embeddings(batch_size, text_embed_dim)

        # Discriminator
        fake_imgs = G(noise, fake_text)
        d_loss = criterion(D(real_imgs, real_text), torch.ones(batch_size, 1).to(device)) + \
                 criterion(D(fake_imgs.detach(), fake_text), torch.zeros(batch_size, 1).to(device))
        optimizer_D.zero_grad(); d_loss.backward(); optimizer_D.step()

        # Generator
        g_loss = criterion(D(fake_imgs, fake_text), torch.ones(batch_size, 1).to(device))
        optimizer_G.zero_grad(); g_loss.backward(); optimizer_G.step()

    print(f'Epoch [{epoch + 1}/{num_epochs}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}')
    if (epoch + 1) % 10 == 0:
        save_image(fake_imgs[:25], f'generated_images_{epoch + 1}.png', nrow=5, normalize=True)

# Save models
torch.save(G.state_dict(), 'generator.pth')
torch.save(D.state_dict(), 'discriminator.pth')


Epoch [1/100], d_loss: 0.6261, g_loss: 0.9230
Epoch [2/100], d_loss: 0.6429, g_loss: 0.9324
Epoch [3/100], d_loss: 0.2378, g_loss: 1.9684
Epoch [4/100], d_loss: 1.4385, g_loss: 4.2193
Epoch [5/100], d_loss: 3.4767, g_loss: 4.6963
Epoch [6/100], d_loss: 1.1495, g_loss: 2.3880
Epoch [7/100], d_loss: 0.4785, g_loss: 2.9321
Epoch [8/100], d_loss: 0.7469, g_loss: 0.8995
Epoch [9/100], d_loss: 1.4935, g_loss: 0.4295
Epoch [10/100], d_loss: 1.1057, g_loss: 3.9364
Epoch [11/100], d_loss: 1.3725, g_loss: 5.9522
Epoch [12/100], d_loss: 1.1523, g_loss: 4.9878
Epoch [13/100], d_loss: 0.6431, g_loss: 5.2546
Epoch [14/100], d_loss: 1.0472, g_loss: 6.0711
Epoch [15/100], d_loss: 0.2665, g_loss: 6.8469
Epoch [16/100], d_loss: 0.5779, g_loss: 7.6480
Epoch [17/100], d_loss: 0.3265, g_loss: 8.3808
Epoch [18/100], d_loss: 0.4605, g_loss: 10.4592
Epoch [19/100], d_loss: 0.5311, g_loss: 18.6118
Epoch [20/100], d_loss: 0.3360, g_loss: 47.6461
Epoch [21/100], d_loss: 0.1350, g_loss: 9.5121
Epoch [22/100], d_l

In [15]:
pip install torchvision


Collecting torchvision
  Downloading torchvision-0.20.1-cp312-cp312-macosx_11_0_arm64.whl.metadata (6.1 kB)
Downloading torchvision-0.20.1-cp312-cp312-macosx_11_0_arm64.whl (1.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m471.2 kB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: torchvision
Successfully installed torchvision-0.20.1
Note: you may need to restart the kernel to use updated packages.
