In [2]:
import torch
from torch import nn,optim
from torch.autograd import Variable
from torchvision import transforms,datasets
from utils.Logger import Logger



In [18]:
def mnist_data():
    compose = transforms.Compose(
        [transforms.ToTensor(),
        transforms.Normalize((.5),(.5))
        ])
    out_dir = './dataset'
    return datasets.MNIST(root=out_dir,train=True,transform=compose,download=True)

In [19]:
# Load data
data = mnist_data()
# Create loader with data, so that we can iterate over it
data_loader = torch.utils.data.DataLoader(data,batch_size=100,shuffle=True)
# Num batches
num_batches = len(data_loader)

In [172]:
class DiscriminatorNet(torch.nn.Module):
    """
    A three hidden-layer discriminative neural network
    """
    def __init__(self):
        super(DiscriminatorNet, self).__init__()
        n_features = 784
        n_out = 1

        self.conv = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=2, stride=2),
            nn.LeakyReLU(0.2, inplace=True),


            nn.Conv2d(64, 128, kernel_size=2, stride=2),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, kernel_size=2, stride=2),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(256, 1, kernel_size=7, stride=1),
            nn.Sigmoid()


        )

    def forward(self, x):
        print(f"Discriminator input shape: {x.shape}")
        x = self.conv(x)


        return x
discriminator = DiscriminatorNet()

In [173]:
def images_to_vectors(images):
    return images.view(images.size(0), 784)

def vectors_to_images(vectors):
    return vectors.view(vectors.size(0), 1, 28, 28)

In [174]:
class GeneratorNet(torch.nn.Module):
    """
    A three hidden-layer generative neural network
    """
    def __init__(self):
        super(GeneratorNet, self).__init__()
        n_features = 100
        n_out = 784

        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(n_features, 256, kernel_size=4, stride=1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),

            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2),
            nn.BatchNorm2d(128),
            nn.ReLU(True),

            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU(True),

            nn.ConvTranspose2d(64, 1, kernel_size=7, stride=1),
            nn.Tanh()
        )


    def forward(self, x):
        x = x.view(x.size(0), 100, 1, 1)
        print(f"Generator input shape: {x.shape}")
        x = self.deconv(x)
        return x
generator = GeneratorNet()

In [175]:
def noise(size):
    '''
    Generates a 1-d vector of gaussian sampled random values
    '''
    n = Variable(torch.randn(size,1, 100))
    return n
from matplotlib import pyplot as plt

In [176]:
d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002)
g_optimizer = optim.Adam(generator.parameters(), lr=0.0002)

In [177]:
loss = nn.BCELoss()

In [178]:
def ones_target(size):
    '''
    Tensor containing ones, with shape = size
    '''
    data = (torch.ones(size, 1))
    return data

def zeros_target(size):
    '''
    Tensor containing zeros, with shape = size
    '''
    data = (torch.zeros(size, 1))
    return data

In [179]:
def train_discriminator(optimizer, real_data, fake_data):
    N = real_data.size(0)
    # Reset gradients
    optimizer.zero_grad()

    # 1.1 Train on Real Data
    prediction_real = discriminator(real_data)
    # Calculate error and backpropagate
    error_real = loss(prediction_real, ones_target(N) )
    error_real.backward()

    # 1.2 Train on Fake Data
    prediction_fake = discriminator(fake_data)
    # Calculate error and backpropagate
    error_fake = loss(prediction_fake, zeros_target(N))
    error_fake.backward()

    # 1.3 Update weights with gradients
    optimizer.step()

    # Return error and predictions for real and fake inputs
    return error_real + error_fake, prediction_real, prediction_fake

In [180]:
def train_generator(optimizer, fake_data):
    N = fake_data.size(0)
    # Reset gradients
    optimizer.zero_grad()
    # Sample noise and generate fake data
    prediction = discriminator(fake_data)
    # Calculate error and backpropagate
    error = loss(prediction, ones_target(N))
    error.backward()
    # Update weights with gradients
    optimizer.step()
    # Return error
    return error

In [181]:
num_test_samples = 16
test_noise = noise(num_test_samples)

In [183]:
# Create logger instance
from IPython.display import clear_output

logger = Logger(model_name='VGAN', data_name='MNIST')
# Total number of epochs to train
num_epochs = 200
for epoch in range(num_epochs):
    for n_batch, (real_batch,_) in enumerate(data_loader):
        N = real_batch.size(0)
        # 1. Train Discriminator
        real_data = torch.tensor(real_batch, dtype=torch.float32)
        # Generate fake data and detach
        # (so gradients are not calculated for generator)
        fake_data = generator(noise(N)).detach()
        # Train D
        d_error, d_pred_real, d_pred_fake = \
              train_discriminator(d_optimizer, real_data, fake_data)

        # 2. Train Generator
        # Generate fake data
        fake_data = generator(noise(N))
        # Train G
        g_error = train_generator(g_optimizer, fake_data)
        # Log batch error
        logger.log(d_error, g_error, epoch, n_batch, num_batches)
        # Display Progress every few batches
        if (n_batch) % 100 == 0:
            test_images = vectors_to_images(generator(test_noise))
            test_images = test_images.data
            logger.log_images(
                test_images, num_test_samples,
                epoch, n_batch, num_batches
            );
            # Display status Logs
            # clear_output(True);
            logger.display_status(
                epoch, num_epochs, n_batch, num_batches,
                d_error, g_error, d_pred_real, d_pred_fake
            )
            clear_output(wait=True)


Generator input shape: torch.Size([100, 100, 1, 1])
Discriminator input shape: torch.Size([100, 1, 28, 28])


  real_data = torch.tensor(real_batch, dtype=torch.float32)


RuntimeError: Calculated padded input size per channel: (3 x 3). Kernel size: (7 x 7). Kernel size can't be greater than actual input size