In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.autograd import Variable

In [2]:
# Generator network
class Generator(nn.Module):
    """Two-layer MLP that maps a noise vector to a flattened image.

    Layout (held in ``self.fc``): Linear -> ReLU -> Linear -> Tanh, so every
    output feature lies in [-1, 1].

    Args:
        input_size: length of each input noise vector.
        hidden_size: width of the single hidden layer.
        output_size: number of output features per sample
            (this notebook uses 28*28, a flattened MNIST image).
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        layers = [
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
            # Tanh keeps outputs in [-1, 1], the same range as the training
            # images after Normalize((0.5,), (0.5,)).
            nn.Tanh(),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Return the generated features for the noise batch ``x``."""
        generated = self.fc(x)
        return generated

# Discriminator network
class Discriminator(nn.Module):
    """Two-layer MLP that scores flattened images as real vs. fake.

    Layout (held in ``self.fc``): Linear -> ReLU -> Linear -> Sigmoid, so every
    output lies in [0, 1] and can be fed directly to ``nn.BCELoss``.

    Args:
        input_size: number of input features per sample (28*28 in this notebook).
        hidden_size: width of the single hidden layer.
        output_size: number of scores per sample (1 in this notebook).
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        layers = [
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
            # Sigmoid turns the logit into a probability for binary classification.
            nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Return per-sample probabilities that the inputs ``x`` are real."""
        scores = self.fc(x)
        return scores

In [3]:
# Hyperparameters
batch_size = 64
lr = 0.0002
z_size = 100  # Size of the random noise vector
hidden_size = 128
seed = 0  # Fixed RNG seed so a Restart & Run All reproduces the same run

# Seed torch's global RNG before anything random happens in this notebook:
# network weight initialisation below and the DataLoader's shuffling both draw from it.
torch.manual_seed(seed)

# Load dataset (assuming MNIST for simplicity)
# ToTensor scales pixels to [0, 1]; Normalize(mean=0.5, std=0.5) then maps them to [-1, 1],
# which matches the range of the generator's final Tanh.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_loader = DataLoader(datasets.MNIST('../data', train=True, download=True, transform=transform),
                          batch_size=batch_size, shuffle=True)

# Initialize networks and optimizers.
# Images are flattened to 28*28 features, so that is the generator's output size
# and the discriminator's input size; the discriminator emits one score per image.
generator = Generator(z_size, hidden_size, 28*28)
discriminator = Discriminator(28*28, hidden_size, 1)
optimizer_G = optim.Adam(generator.parameters(), lr=lr)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)

0.3%

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data\MNIST\raw\train-images-idx3-ubyte.gz


100.0%


Extracting ../data\MNIST\raw\train-images-idx3-ubyte.gz to ../data\MNIST\raw


100.0%


Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data\MNIST\raw\train-labels-idx1-ubyte.gz
Extracting ../data\MNIST\raw\train-labels-idx1-ubyte.gz to ../data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data\MNIST\raw\t10k-images-idx3-ubyte.gz



100.0%
100.0%


Extracting ../data\MNIST\raw\t10k-images-idx3-ubyte.gz to ../data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data\MNIST\raw\t10k-labels-idx1-ubyte.gz
Extracting ../data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ../data\MNIST\raw



In [None]:
# Training loop
num_epochs = 50
log_every_epochs = 10    # print progress only on every N-th epoch ...
log_every_batches = 250  # ... and, within such an epoch, on every M-th batch

# BCELoss holds no state between calls, so one shared instance replaces the
# three fresh nn.BCELoss() objects previously created on every batch.
criterion = nn.BCELoss()

for epoch in range(num_epochs):
    for batch, (real_images, _) in enumerate(train_loader):
        # ---- Train Discriminator ----
        optimizer_D.zero_grad()

        # Flatten each image to a 28*28 feature vector for the MLP discriminator.
        real_images = real_images.view(-1, 28*28)
        # The last batch may be smaller than the configured size. Keep the count in a
        # local name instead of rebinding the global `batch_size` hyperparameter,
        # which would silently leak the last batch's size into later cells.
        cur_batch_size = real_images.size(0)
        real_labels = torch.ones(cur_batch_size, 1)
        fake_labels = torch.zeros(cur_batch_size, 1)

        # Discriminator loss on real images (target = 1).
        output_real = discriminator(real_images)
        loss_real = criterion(output_real, real_labels)

        # Generate fake images from fresh noise. Plain tensors suffice;
        # torch.autograd.Variable is deprecated.
        noise = torch.randn(cur_batch_size, z_size)
        fake_images = generator(noise)

        # Discriminator loss on fake images (target = 0). Detach so this backward
        # pass does not propagate into the generator.
        output_fake = discriminator(fake_images.detach())
        loss_fake = criterion(output_fake, fake_labels)

        # Backpropagation
        loss_d = loss_real + loss_fake
        loss_d.backward()
        optimizer_D.step()

        # ---- Train Generator ----
        optimizer_G.zero_grad()

        # Re-score the (non-detached) fakes with the just-updated discriminator.
        # The generator wants them classified as real. Gradients that also land on the
        # discriminator's parameters here are cleared by optimizer_D.zero_grad() on the
        # next batch before the discriminator steps again.
        output_fake = discriminator(fake_images)
        loss_g = criterion(output_fake, real_labels)

        # Backpropagation
        loss_g.backward()
        optimizer_G.step()

        # Print loss. Compare to 0 explicitly: a bare `epoch % 10 and ...` parses as
        # `(epoch % 10) and ...` and would log only on epochs NOT divisible by 10.
        if epoch % log_every_epochs == 0 and batch % log_every_batches == 0:
            print(f'Epoch [{epoch}/{num_epochs}], Batch [{batch}/{len(train_loader)}], '
                  f'D Loss: {loss_d.item():.4f}, G Loss: {loss_g.item():.4f}')

In [None]:
# After training, you can generate new images using the trained generator.
# This is inference only: torch.no_grad() skips building an autograd graph, and plain
# tensors replace the deprecated torch.autograd.Variable wrapper.
num_samples = 16
with torch.no_grad():
    noise = torch.randn(num_samples, z_size)
    # Flat rows of 28*28 features (the generator's configured output size), each in [-1, 1] from Tanh.
    generated_images = generator(noise)