<a href="https://colab.research.google.com/github/Sangyups/VanillaGAN/blob/main/Vanilla_GAN(with_MNIST).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
from __future__ import print_function
from torch import nn, optim, cuda
from torch.utils import data
from torchvision import datasets, transforms
import torch.nn.functional as F
import time

# Training settings
batch_size = 64
device = 'cuda' if cuda.is_available() else 'cpu'
print(f'Training MNIST Model on {device}\n{"=" * 44}')

# MNIST Dataset
train_dataset = datasets.MNIST(root='./mnist_data/',
                               train=True,
                               transform=transforms.ToTensor(),
                               download=True)

test_dataset = datasets.MNIST(root='./mnist_data/',
                              train=False,
                              transform=transforms.ToTensor())

# Data Loader (Input Pipeline)
train_loader = data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

test_loader = data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)


Training MNIST Model on cuda
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=9912422.0), HTML(value='')))


Extracting ./mnist_data/MNIST/raw/train-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=28881.0), HTML(value='')))


Extracting ./mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=1648877.0), HTML(value='')))


Extracting ./mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=4542.0), HTML(value='')))


Extracting ./mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw

Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [4]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.n_features = 128
        self.n_out = 784
        self.fc0 = nn.Sequential(
            nn.Linear(self.n_features, 256),
            nn.LeakyReLU(0.2)
        )
        self.fc1 = nn.Sequential(
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2)
        )
        self.fc2 = nn.Sequential(
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2)
        )
        self.fc3 = nn.Sequential(
            nn.Linear(1024, self.n_out),
            nn.Tanh()
        )
    def forward(self, x):
        x = self.fc0(x)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.n_in = 784
        self.n_out = 1
        self.fc0 = nn.Sequential(
            nn.Linear(self.n_in, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )
        self.fc1 = nn.Sequential(
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )
        self.fc2 = nn.Sequential(
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )
        self.fc3 = nn.Sequential(
            nn.Linear(256, self.n_out),
            nn.Sigmoid()
        )
    def forward(self, x):
        x = x.view(-1, 784)
        x = self.fc0(x)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x


In [5]:
import torch

G = Generator().to(device)
D = Discriminator().to(device)

loss = torch.nn.BCELoss()

optimizer_G = optim.Adam(G.parameters(), lr=1e-4)
optimizer_D = optim.Adam(D.parameters(), lr=1e-4)


In [48]:
for epoch in range(10):

    # train
    G.train()
    D.train()
    for real_data, target in train_loader:
        real_data = real_data.to(device)
        real_label = torch.ones(real_data.shape[0], 1).to(device)
        fake_label = torch.zeros(real_data.shape[0], 1).to(device)
        
        noise = torch.randn(real_data.shape[0], 128).to(device)
        fake_data = G(noise)
        output = D(fake_data)
        loss_g = loss(output, real_label)
        optimizer_G.zero_grad()
        loss_g.backward()
        optimizer_G.step()

        fake_data = fake_data.detach()

        output_real = D(real_data)
        loss_d_real = loss(output_real, real_label)
        output_fake = D(fake_data)
        loss_d_fake = loss(output_fake, fake_label)
        loss_d_final = loss_d_real + loss_d_fake
        optimizer_D.zero_grad()
        loss_d_final.backward()
        optimizer_D.step()
    
    # test
    G.eval()
    D.eval()
    test_G_loss = 0
    test_D_loss = 0
    correct_real = 0
    correct_fake = 0
    for real_data, target in test_loader:
        real_data = real_data.to(device)
        real_label = torch.ones(real_data.shape[0], 1).to(device)
        fake_label = torch.zeros(real_data.shape[0], 1).to(device)

        noise = torch.randn(real_data.shape[0], 128).to(device)
        fake_data = G(noise)
        output = D(fake_data)
        test_G_loss += loss(output, real_label).item()

        fake_data = fake_data.detach()

        test_output_real = D(real_data)
        test_loss_d_real = loss(test_output_real, real_label)
        test_output_fake = D(fake_data)
        test_loss_d_fake = loss(test_output_fake, fake_label)
        test_loss_d_final = test_loss_d_real + test_loss_d_fake
        test_D_loss += test_loss_d_final
        correct_real += (test_output_real > 0.5).sum().item()
        correct_fake += (test_output_fake <= 0.5).sum().item()

    test_G_loss /= len(test_loader.dataset)
    test_D_loss /= len(test_loader.dataset)
    print("============epoch: ",epoch,"==========")
    print("Generator Loss:", loss_g.item())
    print("Discriminator Loss:", loss_d_final.item())
    print(f"Test set: Average Generator loss: {test_G_loss}, Average Discrminator loss: {test_D_loss}")
    print(f"Accuracy for real image: {correct_real}/{len(test_loader.dataset)} ({100. * correct_real / len(test_loader.dataset):.0f}%)")
    print(f"Accuracy for fake image: {correct_fake}/{len(test_loader.dataset)} ({100. * correct_fake / len(test_loader.dataset):.0f}%)")

Generator Loss: 4.388454437255859
Discriminator Loss: 0.11580228060483932
Test set: Average Generator loss: 0.06924490418434143, Average Discrminator loss: 0.0015436699613928795
Accuracy for real image: 9798/10000 (98%)
Accuracy for fake image: 9899/10000 (99%)
Generator Loss: 5.253950595855713
Discriminator Loss: 0.14912806451320648
Test set: Average Generator loss: 0.08664440369606018, Average Discrminator loss: 0.0013231942430138588
Accuracy for real image: 9731/10000 (97%)
Accuracy for fake image: 9982/10000 (100%)
Generator Loss: 5.708548069000244
Discriminator Loss: 0.4468453526496887
Test set: Average Generator loss: 0.07956724653244018, Average Discrminator loss: 0.0014663340989500284
Accuracy for real image: 9684/10000 (97%)
Accuracy for fake image: 9996/10000 (100%)
Generator Loss: 3.2601046562194824
Discriminator Loss: 0.3553532361984253
Test set: Average Generator loss: 0.06482096400260925, Average Discrminator loss: 0.0013215368380770087
Accuracy for real image: 9866/10000