In [28]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image

latent_size = 64
hidden_size = 256
image_size = 784
num_epochs = 50
batch_size = 100
learning_rate = 0.0002


In [20]:

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))
])


In [21]:

mnist = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
data_loader = torch.utils.data.DataLoader(dataset=mnist, batch_size=batch_size, shuffle=True)

In [22]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(latent_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, image_size)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.tanh(self.fc3(x))
        return x

In [23]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(image_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, 1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.sigmoid(self.fc3(x))
        return x



In [24]:

generator = Generator()
discriminator = Discriminator()


In [25]:
criterion = nn.BCELoss()
d_optimizer = optim.Adam(discriminator.parameters(), lr=learning_rate)
g_optimizer = optim.Adam(generator.parameters(), lr=learning_rate)

In [29]:
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(data_loader):
        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)

        images = images.reshape(batch_size, -1)

        outputs = discriminator(images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs

        z = torch.randn(batch_size, latent_size)
        fake_images = generator(z)
        outputs = discriminator(fake_images)
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs

        d_loss = d_loss_real + d_loss_fake
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        z = torch.randn(batch_size, latent_size)
        fake_images = generator(z)
        outputs = discriminator(fake_images)

        g_loss = criterion(outputs, real_labels)

        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

    print(f'Epoch [{epoch}/{num_epochs}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}, ')

    if (epoch + 1) % 10 == 0:
        save_image(fake_images.reshape(fake_images.size(0), 1, 28, 28), f'fake_images_{epoch + 1}.png')



Epoch [0/50], d_loss: 0.1387, g_loss: 3.5160, 
Epoch [1/50], d_loss: 0.2875, g_loss: 4.0993, 
Epoch [2/50], d_loss: 0.4883, g_loss: 2.4538, 
Epoch [3/50], d_loss: 0.6047, g_loss: 2.4206, 
Epoch [4/50], d_loss: 0.8132, g_loss: 2.7440, 
Epoch [5/50], d_loss: 0.3209, g_loss: 3.7389, 
Epoch [6/50], d_loss: 0.5135, g_loss: 2.8085, 
Epoch [7/50], d_loss: 0.4841, g_loss: 2.6978, 
Epoch [8/50], d_loss: 0.1945, g_loss: 3.2825, 
Epoch [9/50], d_loss: 0.3565, g_loss: 2.0777, 
Epoch [10/50], d_loss: 0.3243, g_loss: 3.5435, 
Epoch [11/50], d_loss: 0.3373, g_loss: 4.6024, 
Epoch [12/50], d_loss: 0.3366, g_loss: 4.1664, 
Epoch [13/50], d_loss: 0.4799, g_loss: 4.3709, 
Epoch [14/50], d_loss: 0.4109, g_loss: 3.3456, 
Epoch [15/50], d_loss: 0.2784, g_loss: 3.8478, 
Epoch [16/50], d_loss: 0.2923, g_loss: 3.8609, 
Epoch [17/50], d_loss: 0.5372, g_loss: 2.7150, 
Epoch [18/50], d_loss: 0.6222, g_loss: 4.2521, 
Epoch [19/50], d_loss: 0.2586, g_loss: 3.5641, 
Epoch [20/50], d_loss: 0.1828, g_loss: 4.1899, 
Ep