<a href="https://colab.research.google.com/github/Anspire/Notebooks/blob/master/GAN_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import torch
from torch import nn, optim
from torch.autograd.variable import Variable

import torchvision
import torchvision.transforms as transforms

In [0]:
# Preprocessing: ToTensor() yields pixel values in [0, 1]; Normalize with
# mean 0.5 and std 0.5 rescales them to [-1, 1], the same range as the
# Generator's final Tanh.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# Training data: the MNIST training split, stored under the current directory
# (download=True fetches it first if it is not already there).
train_set = torchvision.datasets.MNIST(
    root='.',
    train=True,
    transform=transform,
    download=True,
)
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=32,
    shuffle=True,
)

# Labels: the ten digit classes as strings. The training loop below ignores
# the labels, so this is informational only.
classes = list(map(str, range(10)))
print(classes)

0it [00:00, ?it/s]

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


100%|█████████▉| 9895936/9912422 [00:34<00:00, 220629.76it/s]

Extracting ./MNIST/raw/train-images-idx3-ubyte.gz



0it [00:00, ?it/s][A

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz



  0%|          | 0/28881 [00:00<?, ?it/s][A
 57%|█████▋    | 16384/28881 [00:00<00:00, 86224.52it/s][A
32768it [00:00, 56551.03it/s]                           [A
0it [00:00, ?it/s][A

Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz



  0%|          | 0/1648877 [00:00<?, ?it/s][A
  1%|          | 16384/1648877 [00:00<00:21, 76704.32it/s][A
  2%|▏         | 40960/1648877 [00:00<00:18, 88085.69it/s][A
  4%|▍         | 73728/1648877 [00:00<00:15, 103981.97it/s][A
  7%|▋         | 114688/1648877 [00:01<00:12, 123942.73it/s][A
  9%|▉         | 147456/1648877 [00:01<00:10, 136665.55it/s][A
 11%|█▏        | 188416/1648877 [00:01<00:09, 154801.32it/s][A
 14%|█▍        | 229376/1648877 [00:01<00:08, 170700.20it/s][A
 16%|█▋        | 270336/1648877 [00:01<00:07, 183817.93it/s][A
 19%|█▉        | 319488/1648877 [00:02<00:06, 203140.33it/s][A
 22%|██▏       | 360448/1648877 [00:02<00:06, 208989.97it/s][A
 25%|██▍       | 409600/1648877 [00:02<00:05, 223885.43it/s][A
 28%|██▊       | 466944/1648877 [00:02<00:04, 245055.52it/s][A
 31%|███▏      | 516096/1648877 [00:02<00:04, 251744.70it/s][A
 35%|███▍      | 573440/1648877 [00:02<00:04, 267687.53it/s][A
 39%|███▉      | 638976/1648877 [00:03<00:03, 290167.69it/s]

Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz




  0%|          | 0/4542 [00:00<?, ?it/s][A[A

8192it [00:00, 21184.36it/s]            [A[A

Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz
Processing...
Done!
['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']


In [0]:
# Our Discriminator class
class Discriminator(nn.Module):
    """MLP that scores flattened images with the probability that they are real.

    Args:
        img_dim: number of input features per sample after flattening
            (784 = 28 * 28 for MNIST, the original hard-coded value).
        dropout: dropout probability applied after each hidden layer.
    """

    def __init__(self, img_dim: int = 784, dropout: float = 0.3):
        super().__init__()
        self.img_dim = img_dim
        self.model = nn.Sequential(
            nn.Linear(img_dim, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return a (batch, 1) tensor of probabilities in [0, 1].

        ``x`` is flattened to (batch, img_dim), so both image batches and
        already-flat Generator outputs are accepted.
        """
        # reshape (not view) also accepts non-contiguous inputs. No .cuda() here:
        # the output stays on the model/input device, so CPU-only runs work too.
        return self.model(x.reshape(x.size(0), self.img_dim))


discriminator = Discriminator()

In [0]:
# Our Generator class
class Generator(nn.Module):
    """MLP that maps latent noise vectors to flattened images in [-1, 1].

    Args:
        latent_dim: size of each input noise vector (100, the original
            hard-coded value).
        img_dim: number of output features per sample (784 = 28 * 28 for MNIST).
    """

    def __init__(self, latent_dim: int = 100, img_dim: int = 784):
        super().__init__()
        self.latent_dim = latent_dim
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, img_dim),
            # Tanh keeps outputs in [-1, 1], the range the real images are normalized to.
            nn.Tanh(),
        )

    def forward(self, x):
        """Return a (batch, img_dim) tensor of generated pixel values in [-1, 1].

        ``x`` is reshaped to (batch, latent_dim).
        """
        # No .cuda() here: the output stays on the model/input device, so the
        # class also runs on CPU-only machines.
        return self.model(x.reshape(x.size(0), self.latent_dim))


generator = Generator()

In [0]:
# Move both networks to the GPU when CUDA is available.
if torch.cuda.is_available():
    print("Using CUDA")
    for net in (discriminator, generator):
        net.cuda()

# Training hyperparameters.
lr = 1e-4
num_epochs = 40
num_batches = len(train_loader)  # batches per epoch (not used by the cells shown)

# Binary cross-entropy on the Discriminator's sigmoid output, plus one Adam
# optimizer per network (Discriminator first, then Generator).
criterion = nn.BCELoss()
d_optimizer, g_optimizer = (
    torch.optim.Adam(net.parameters(), lr=lr) for net in (discriminator, generator)
)

Using CUDA


In [0]:
# Convenience function for training our Discriminator
def train_discriminator(discriminator, real_images, real_labels, fake_images, fake_labels,
                        optimizer=None, loss_fn=None):
    """Run one optimisation step of the Discriminator on a real and a fake batch.

    Args:
        discriminator: the model being trained; it scores both batches.
        real_images: batch of real images, scored against ``real_labels``.
        real_labels: targets for the real batch; must have the same shape as
            the discriminator output (``nn.BCELoss`` rejects mismatched sizes).
        fake_images: batch produced by the Generator, scored against ``fake_labels``.
        fake_labels: targets for the fake batch, same shape rule as ``real_labels``.
        optimizer: optimizer stepped after backward; defaults to the global
            ``d_optimizer``.
        loss_fn: loss applied to each prediction/target pair; defaults to the
            global ``criterion``.

    Returns:
        (d_loss, real_score, fake_score): the summed loss tensor and the raw
        discriminator predictions on the real and the fake batch.
    """
    if optimizer is None:
        optimizer = d_optimizer
    if loss_fn is None:
        loss_fn = criterion

    discriminator.zero_grad()

    # Predictions and loss on the real images.
    real_score = discriminator(real_images)
    real_loss = loss_fn(real_score, real_labels)

    # Predictions and loss on the fake images. detach() stops this step from
    # back-propagating through (and accumulating gradients in) the Generator.
    fake_score = discriminator(fake_images.detach())
    fake_loss = loss_fn(fake_score, fake_labels)

    # Total loss, backward pass, and optimizer update.
    d_loss = real_loss + fake_loss
    d_loss.backward()
    optimizer.step()
    return d_loss, real_score, fake_score

# Convenience function for training our Generator
def train_generator(generator, discriminator_outputs, real_labels, optimizer=None, loss_fn=None):
    """Run one optimisation step of the Generator.

    Args:
        generator: the model being trained; only its gradients are zeroed here.
        discriminator_outputs: discriminator predictions on freshly generated
            images, still attached to the Generator's graph.
        real_labels: "real" targets, so the loss rewards fooling the discriminator;
            must have the same shape as ``discriminator_outputs``.
        optimizer: optimizer stepped after backward; defaults to the global
            ``g_optimizer``.
        loss_fn: loss comparing outputs and targets; defaults to the global
            ``criterion``.

    Returns:
        The scalar generator loss tensor.

    Note: backward() also fills the discriminator's ``.grad`` buffers. They are
    not stepped here and must be zeroed before the next discriminator update.
    """
    if optimizer is None:
        optimizer = g_optimizer
    if loss_fn is None:
        loss_fn = criterion

    generator.zero_grad()

    # Loss, backward pass, and optimizer update.
    g_loss = loss_fn(discriminator_outputs, real_labels)
    g_loss.backward()
    optimizer.step()
    return g_loss
    
# Training loop: each batch runs one Discriminator step, then one Generator step.
# Data goes to whatever device the models were placed on earlier, so the loop
# also works on CPU-only machines (the original hard-coded .cuda()).
device = next(discriminator.parameters()).device
latent_dim = 100  # must match the Generator's expected noise size

for epoch in range(num_epochs):
    for images, _ in train_loader:
        batch_size = images.size(0)

        # (1) Prepare the real data for the Discriminator. Labels are shaped
        # (batch, 1) to match the Discriminator output: nn.BCELoss raises a
        # ValueError when input and target sizes differ.
        real_images = images.to(device)
        real_labels = torch.ones(batch_size, 1, device=device)

        # (2) + (3) Prepare fake data for the Discriminator. No gradient is needed
        # through the Generator for this step, so build the batch under no_grad.
        noise = torch.randn(batch_size, latent_dim, device=device)
        with torch.no_grad():
            fake_images = generator(noise)
        fake_labels = torch.zeros(batch_size, 1, device=device)

        # (4) Train the discriminator on real and fake data.
        d_loss, real_score, fake_score = train_discriminator(discriminator,
                                                             real_images, real_labels,
                                                             fake_images, fake_labels)

        # (5a) Generate new fake images, this time keeping the Generator's graph.
        # (5b) Get the Discriminator's predictions on that fake data.
        noise = torch.randn(batch_size, latent_dim, device=device)
        fake_images = generator(noise)
        outputs = discriminator(fake_images)

        # (6) Train the generator to make the Discriminator predict "real".
        g_loss = train_generator(generator, outputs, real_labels)

    # Report the last batch's losses and mean scores so progress is visible.
    print(f"epoch {epoch}: d_loss={d_loss.item():.4f}, g_loss={g_loss.item():.4f}, "
          f"D(x)={real_score.mean().item():.3f}, D(G(z))={fake_score.mean().item():.3f}")