<a href="https://colab.research.google.com/github/Yyzhang2000/learning-generative-models/blob/main/gan/01_mnist_gan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torchvision import datasets, transforms

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

In [3]:
batch_size = 128
data_path = "./data"

# MNIST Dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5), std=(0.5))])

train_dataset = datasets.MNIST(root=data_path, train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root=data_path, train=False, transform=transform, download=False)

# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

100%|██████████| 9.91M/9.91M [00:11<00:00, 900kB/s] 
100%|██████████| 28.9k/28.9k [00:00<00:00, 65.6kB/s]
100%|██████████| 1.65M/1.65M [00:06<00:00, 241kB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 6.29MB/s]


In [4]:
class Generator(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()

        self.fc1 = nn.Linear(input_dim, 256)
        self.fc2 = nn.Linear(256, 256 * 2)
        self.fc3 = nn.Linear(256 * 2, 256 * 4)
        self.fc4 = nn.Linear(256 * 4, output_dim)

        self.act = nn.LeakyReLU(0.2)

    def forward(self, x):
        x = self.act(self.fc1(x))
        x = self.act(self.fc2(x))
        x = self.act(self.fc3(x))
        x = self.act(self.fc4(x))
        return x

class Discriminator(nn.Module):
    def __init__(self, input_dim):
        super().__init__()

        self.fc1 = nn.Linear(input_dim, 1024)
        self.fc2 = nn.Linear(1024, 1024 // 2)
        self.fc3 = nn.Linear(1024 // 2, 1024 // 4)
        self.fc4 = nn.Linear(1024 // 4, 1)

        self.act = nn.LeakyReLU(0.2)
        self.dropout = nn.Dropout(0.3)


    def forward(self, x):
        x = self.dropout(self.act(self.fc1(x)))
        x = self.dropout(self.act(self.fc2(x)))
        x = self.dropout(self.act(self.fc3(x)))

        return F.sigmoid(self.fc4(x))


In [5]:
mnist_dim = train_dataset.train_data.size(1) * train_dataset.train_data.size(2)
mnist_dim



784

In [6]:
z_dim = 100

generator = Generator(
    100,
    mnist_dim
).to(device)

discriminator = Discriminator(mnist_dim).to(device)


In [7]:
criterion = nn.BCELoss()

lr = 1e-4
g_optimizer = optim.Adam(generator.parameters(), lr = lr )
d_optimizer = optim.Adam(discriminator.parameters(), lr = lr )

In [8]:
def train_discrimator(model, data):
    B = data.shape[0]
    x_real, y_real = data.view(-1, mnist_dim), torch.ones(B, 1)
    x_real = x_real.to(device)
    y_real = y_real.to(device)

    model_out = model(x_real)
    model_real_loss = criterion(model_out, y_real)
    model_real_score = model_out

    # Training Discrimnator on fake image
    z = torch.randn(B, z_dim).to(device)
    x_fake, y_fake = generator(z), torch.zeros(B, 1).to(device)

    model_out = model(x_fake)
    model_fake_loss = criterion(model_out, y_fake)
    model_fake_score = model_out

    model_loss = model_real_loss + model_fake_loss
    d_optimizer.zero_grad()
    model_loss.backward()
    d_optimizer.step()

    return model_loss.item()


In [9]:
def train_generator(model, data):
    B = data.shape[0]
    z = torch.randn(B, z_dim).to(device)
    y = torch.ones(B, 1).to(device)

    model_output = generator(z)
    distrimator_output = discriminator(model_output)
    model_loss = criterion(distrimator_output, y)

    g_optimizer.zero_grad()
    model_loss.backward()
    g_optimizer.step()

    return model_loss.item()

In [10]:
n_epochs = 200

def train():
    for epoch in range(1, n_epochs + 1):
        d_losses, g_losses = [], []

        for batch_idx, (x, _) in enumerate(train_loader):
            d_loss = train_discrimator(discriminator, x)
            g_loss = train_generator(generator, x)

            d_losses.append(d_loss)
            g_losses.append(g_loss)

        print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (
            (epoch), n_epochs, torch.mean(torch.FloatTensor(d_losses)), torch.mean(torch.FloatTensor(g_losses))))


In [11]:
train()

[1/200]: loss_d: 0.687, loss_g: 1.964
[2/200]: loss_d: 0.697, loss_g: 2.440
[3/200]: loss_d: 0.501, loss_g: 3.622
[4/200]: loss_d: 0.318, loss_g: 4.245
[5/200]: loss_d: 0.216, loss_g: 4.437
[6/200]: loss_d: 0.178, loss_g: 5.341
[7/200]: loss_d: 0.202, loss_g: 5.333
[8/200]: loss_d: 0.179, loss_g: 5.205


KeyboardInterrupt: 

In [None]:
from torchvision.utils import save_image
with torch.no_grad():
    test_z = torch.randn(10, z_dim).to(device)
    generated = generator(test_z)

    save_image(generated.view(generated.size(0), 1, 28, 28), './samples/sample_' + '.png')