<a href="https://colab.research.google.com/github/Sangyups/VanillaGAN/blob/main/Vanilla_GAN(with_MNIST).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from __future__ import print_function

import os
import time

import torch
import torch.nn.functional as F
from torch import nn, optim, cuda
from torch.utils import data
from torchvision import datasets, transforms, utils

# Training settings
batch_size = 100
device = 'cuda' if cuda.is_available() else 'cpu'
print(f'Training MNIST Model on {device}\n{"=" * 44}')

# Preprocessing shared by both splits.
# ToTensor() alone yields pixel values in [0, 1], but the Generator's last
# layer is Tanh, so generated images lie in [-1, 1]. Normalize((0.5,), (0.5,))
# computes (x - 0.5) / 0.5 and therefore maps real images to [-1, 1] as well,
# so the discriminator sees real and fake samples in the same value range.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# MNIST Dataset
train_dataset = datasets.MNIST(root='./mnist_data/',
                               train=True,
                               transform=mnist_transform,
                               download=True)

# download=True here as well so this split does not rely on the training
# split having been fetched first.
test_dataset = datasets.MNIST(root='./mnist_data/',
                              train=False,
                              transform=mnist_transform,
                              download=True)

# Data Loader (Input Pipeline)
train_loader = data.DataLoader(dataset=train_dataset,
                               batch_size=batch_size,
                               shuffle=True)

test_loader = data.DataLoader(dataset=test_dataset,
                              batch_size=batch_size,
                              shuffle=False)

print(train_dataset)

Training MNIST Model on cuda
Dataset MNIST
    Number of datapoints: 60000
    Root location: ./mnist_data/
    Split: Train
    StandardTransform
Transform: ToTensor()


In [None]:
class Generator(nn.Module):
    """Fully connected generator: latent noise vector -> 1x28x28 image.

    The network is four Linear layers (n_features -> 256 -> 512 -> 1024 -> 784)
    with ReLU between them and Tanh at the output, so every output value lies
    in [-1, 1].

    Args:
        n_features: Length of the latent noise vector expected by the first
            layer. Defaults to 128, the value that used to be hard-coded, so
            ``Generator()`` builds the same architecture as before.
    """

    def __init__(self, n_features=128):
        super(Generator, self).__init__()
        # Latent size; read by callers to draw noise of the right width.
        self.n_features = n_features
        # Number of output units: 784 = 1 * 28 * 28, see the view() in forward().
        self.n_out = 784
        self.fc0 = nn.Sequential(
            nn.Linear(self.n_features, 256),
            nn.ReLU()
        )
        self.fc1 = nn.Sequential(
            nn.Linear(256, 512),
            nn.ReLU()
        )
        self.fc2 = nn.Sequential(
            nn.Linear(512, 1024),
            nn.ReLU()
        )
        self.fc3 = nn.Sequential(
            nn.Linear(1024, self.n_out),
            nn.Tanh()
        )

    def forward(self, x):
        """Map a batch of noise vectors to a batch of images.

        Args:
            x: Tensor whose last dimension is ``n_features``.

        Returns:
            Tensor viewed as ``(-1, 1, 28, 28)`` with values in [-1, 1].
        """
        x = self.fc0(x)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        # Fold the flat 784-vector into a single-channel 28x28 image.
        x = x.view(-1, 1, 28, 28)
        return x

class Discriminator(nn.Module):
    """Fully connected discriminator: image -> probability in [0, 1].

    The network is four Linear layers (784 -> 1024 -> 512 -> 256 -> 1). The
    three hidden layers use ReLU followed by Dropout(0.5); the output layer
    uses Sigmoid, so the result can be fed to a binary cross-entropy loss
    that expects probabilities.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        # Flattened input size (784 = 28 * 28) and number of output units.
        self.n_in = 784
        self.n_out = 1
        self.fc0 = nn.Sequential(
            nn.Linear(self.n_in, 1024),
            nn.ReLU(),
            nn.Dropout(0.5)
        )
        self.fc1 = nn.Sequential(
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.5)
        )
        self.fc2 = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.5)
        )
        self.fc3 = nn.Sequential(
            nn.Linear(256, self.n_out),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Score a batch of images.

        Args:
            x: Tensor whose elements can be regrouped into rows of ``n_in``
                values, e.g. shape ``(N, 1, 28, 28)``.

        Returns:
            Tensor of shape ``(N, 1)`` with values in [0, 1].
        """
        # reshape() instead of view(): view() raises on non-contiguous input
        # (e.g. a transposed batch), reshape() copies only when it has to.
        # self.n_in replaces the duplicated literal 784.
        x = x.reshape(-1, self.n_in)
        x = self.fc0(x)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x


In [None]:
# Build the two networks (generator first, then discriminator) and move
# them to the selected device.
G = Generator().to(device)
D = Discriminator().to(device)

# Binary cross-entropy; applied to the discriminator's sigmoid output.
loss = nn.BCELoss()

# One Adam optimizer per network, both with learning rate 1e-4.
optimizer_G, optimizer_D = (
    optim.Adam(net.parameters(), lr=1e-4) for net in (G, D)
)


In [None]:
# Sample grids are written to Google Drive, so Drive has to be mounted at
# /content/gdrive (the drive.mount cell) and this folder has to exist before
# training starts.
output_dir = "/content/gdrive/My Drive/Colab Notebooks/VanillaGAN_result"
if not os.path.isdir(output_dir):
    # Fail immediately instead of after the first ten epochs of training.
    raise FileNotFoundError(
        "Output directory not found: %s. Mount Google Drive and create the "
        "folder before running the training loop." % output_dir
    )

for epoch in range(200):

    # train
    G.train()
    D.train()
    for real_data, _ in train_loader:
        real_data = real_data.to(device)
        n_samples = real_data.shape[0]
        # Targets and noise are allocated directly on the device (no CPU copy).
        real_label = torch.ones(n_samples, 1, device=device)
        fake_label = torch.zeros(n_samples, 1, device=device)
        noise = torch.randn(n_samples, G.n_features, device=device)

        # --- Generator step: make D label the fakes as real. ---
        fake_data = G(noise)
        output = D(fake_data)
        loss_g = loss(output, real_label)
        optimizer_G.zero_grad()
        loss_g.backward()
        optimizer_G.step()

        # Cut the graph so the discriminator step does not backpropagate
        # into the generator.
        fake_data = fake_data.detach()

        # --- Discriminator step: real -> 1, fake -> 0. ---
        output_real = D(real_data)
        loss_d_real = loss(output_real, real_label)
        output_fake = D(fake_data)
        loss_d_fake = loss(output_fake, fake_label)
        loss_d_final = loss_d_real + loss_d_fake
        # zero_grad() also clears the gradients that loss_g.backward() left
        # on D's parameters during the generator step.
        optimizer_D.zero_grad()
        loss_d_final.backward()
        optimizer_D.step()

    print("===========epoch:",epoch,"===========")
    if (epoch+1) % 10 == 0:
        # fake_data (last training batch) already has shape (N, 1, 28, 28)
        # from G.forward, so it is passed as is. Reshaping with the global
        # batch_size would fail whenever the last batch is smaller.
        img_grid = utils.make_grid(fake_data, nrow=10, normalize=True)
        utils.save_image(img_grid, "%s/%d.png" % (output_dir, epoch+1))
        print("image saved at epoch: ", epoch)
        
    
    # test
    # G.eval()
    # D.eval()
    # test_G_loss = 0
    # test_D_loss = 0
    # correct_real = 0
    # correct_fake = 0
    # for real_data, target in test_loader:
    #     real_data = real_data.to(device)
    #     real_label = torch.ones(real_data.shape[0], 1).to(device)
    #     fake_label = torch.zeros(real_data.shape[0], 1).to(device)

    #     noise = torch.randn(real_data.shape[0], 128).to(device)
    #     fake_data = G(noise)
    #     output = D(fake_data)
    #     test_G_loss += loss(output, real_label).item()

    #     fake_data = fake_data.detach()

    #     test_output_real = D(real_data)
    #     test_loss_d_real = loss(test_output_real, real_label)
    #     test_output_fake = D(fake_data)
    #     test_loss_d_fake = loss(test_output_fake, fake_label)
    #     test_loss_d_final = test_loss_d_real + test_loss_d_fake
    #     test_D_loss += test_loss_d_final
    #     correct_real += (test_output_real > 0.5).sum().item()
    #     correct_fake += (test_output_fake <= 0.5).sum().item()

    # test_G_loss /= len(test_loader.dataset)
    # test_D_loss /= len(test_loader.dataset)
    # print("============epoch: ",epoch,"==========")
    # print("Generator Loss:", loss_g.item())
    # print("Discriminator Loss:", loss_d_final.item())
    # print(f"Test set: Average Generator loss: {test_G_loss}, Average Discrminator loss: {test_D_loss}")
    # print(f"Accuracy for real image: {correct_real}/{len(test_loader.dataset)} ({100. * correct_real / len(test_loader.dataset):.0f}%)")
    # print(f"Accuracy for fake image: {correct_fake}/{len(test_loader.dataset)} ({100. * correct_fake / len(test_loader.dataset):.0f}%)")

image saved at epoch:  9
image saved at epoch:  19
image saved at epoch:  29
image saved at epoch:  39
image saved at epoch:  49


KeyboardInterrupt: ignored

In [None]:
# Mount Google Drive at /content/gdrive.
# NOTE(review): the training cell saves its sample images under
# "/content/gdrive/My Drive/...", so this cell has to be executed before the
# training cell. In a top-to-bottom run the mount happens too late; consider
# moving this cell to the top of the notebook.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive
