In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np

# Set random seed for reproducibility
torch.manual_seed(42)

# Hyperparameters
latent_dim = 100
num_classes = 10
batch_size = 64
num_epochs = 1
lr = 0.0002
beta1 = 0.5

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# MNIST Dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

mnist = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
dataloader = DataLoader(mnist, batch_size=batch_size, shuffle=True)

# Generator
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)
        
        self.model = nn.Sequential(
            nn.Linear(latent_dim + num_classes, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 784),
            nn.Tanh()
        )

    def forward(self, z, labels):
        z = torch.cat((self.label_emb(labels), z), -1)
        img = self.model(z)
        img = img.view(img.size(0), 1, 28, 28)
        return img

# Discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)
        
        self.model = nn.Sequential(
            nn.Linear(784 + num_classes, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img, labels):
        img_flat = img.view(img.size(0), -1)
        x = torch.cat((img_flat, self.label_emb(labels)), -1)
        validity = self.model(x)
        return validity

# Initialize generator and discriminator
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Loss function
adversarial_loss = nn.BCELoss()

# Optimizers
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))

# Function to generate and save images
def save_generated_images(epoch, generator, n_samples=25):
    z = torch.randn(n_samples, latent_dim).to(device)
    labels = torch.LongTensor(np.array([i for i in range(10) for _ in range(n_samples // 10)])).to(device)
    generated_imgs = generator(z, labels).cpu().detach()
    
    fig, axs = plt.subplots(5, 5, figsize=(10, 10))
    for i in range(5):
        for j in range(5):
            axs[i, j].imshow(generated_imgs[i*5 + j].squeeze(), cmap='gray')
            axs[i, j].set_title(f"Label: {labels[i*5 + j].item()}")
            axs[i, j].axis('off')
    
    plt.tight_layout()
    plt.savefig(f'generated_images_epoch_{epoch}.png')
    plt.close()

# Lists to store loss values
G_losses = []
D_losses = []

# Training loop
for epoch in range(num_epochs):
    G_loss_epoch = 0
    D_loss_epoch = 0
    
    progress_bar = tqdm(enumerate(dataloader), total=len(dataloader), desc=f"Epoch {epoch+1}/{num_epochs}")
    for i, (imgs, labels) in progress_bar:
        batch_size = imgs.shape[0]

        # Adversarial ground truths
        real = torch.ones(batch_size, 1).to(device)
        fake = torch.zeros(batch_size, 1).to(device)

        # Configure input
        real_imgs = imgs.to(device)
        labels = labels.to(device)

        # Train Generator
        optimizer_G.zero_grad()
        z = torch.randn(batch_size, latent_dim).to(device)
        gen_labels = torch.randint(0, num_classes, (batch_size,)).to(device)
        gen_imgs = generator(z, gen_labels)
        g_loss = adversarial_loss(discriminator(gen_imgs, gen_labels), real)
        g_loss.backward()
        optimizer_G.step()

        # Train Discriminator
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(real_imgs, labels), real)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach(), gen_labels), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # Update progress bar
        G_loss_epoch += g_loss.item()
        D_loss_epoch += d_loss.item()
        progress_bar.set_postfix({'D loss': f"{d_loss.item():.4f}", 'G loss': f"{g_loss.item():.4f}"})

    # Calculate average losses for the epoch
    G_loss_epoch /= len(dataloader)
    D_loss_epoch /= len(dataloader)
    G_losses.append(G_loss_epoch)
    D_losses.append(D_loss_epoch)

    # Generate and save images
    if (epoch + 1) % 10 == 0 or epoch == 0:
        save_generated_images(epoch + 1, generator)

print("Training finished!")

# Plot loss curves
plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="G")
plt.plot(D_losses, label="D")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.savefig('loss_curve.png')
plt.close()

# Generate final set of images
save_generated_images(num_epochs, generator, n_samples=100)

print("Results have been saved!")

Epoch 1/1: 100%|██████████| 938/938 [01:36<00:00,  9.71it/s, D loss=0.1950, G loss=1.5907]


RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 20 but got size 25 for tensor number 1 in the list.