In [1]:
import matplotlib.pyplot as plt
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchvision.utils as vutils
import random
import os
import shutil
import pdb

In [2]:
# Initialization
# Hyper-parameters shared by the data pipeline, the models and the training loop.
num_channels = 1    # image channels fed to the discriminator / produced by the generator
num_classes = 10    # real classes; the discriminator head outputs num_classes + 1 values
latent_size = 100   # channels of the generator's input noise tensor
labeled_rate = 0.1  # presumably the fraction of training labels treated as known (unused so far)
num_epochs = 100

log_path = './SSL_GAN_log.csv'
model_path = './SSL_GAN_model.ckpt'

# Seed every RNG the notebook relies on (Python, NumPy, PyTorch -- the latter also
# drives DataLoader shuffling and weight init) so Restart & Run All is reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
# Base directory for the MNIST files (the dataset itself lives in its "dataset" subfolder).
DATA_FOLDER = "./torch_data/MNIST"

In [4]:
def mnist_data(train=True, download=True):
    """Load MNIST as a torchvision dataset of normalized image tensors.

    Images go through ``ToTensor`` (float values in [0, 1]) and then
    ``Normalize(mean=0.5, std=0.5)``, which maps them to [-1, 1]. MNIST is
    single-channel, so the statistics are 1-tuples. With 3-tuples, current
    torchvision broadcasts the mean to three channels and then fails on the
    in-place normalization of the 1-channel tensor.

    Args:
        train: load the training split (default) or, if False, the test split.
        download: download the files into ``{DATA_FOLDER}/dataset`` when they
            are not already present (default True).

    Returns:
        A ``torchvision.datasets.MNIST`` instance with the transform attached.
    """
    compose = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    out_dir = '{}/dataset'.format(DATA_FOLDER)
    return datasets.MNIST(root=out_dir, train=train, transform=compose, download=download)

In [5]:
# Load data 
data = mnist_data()
# Create loader with data, so that we can iterate over it
data_loader = torch.utils.data.DataLoader(data, batch_size=32, shuffle=True)

In [12]:
class DiscriminatorNet(torch.nn.Module):
    """Convolutional discriminator / classifier for the semi-supervised GAN.

    Maps a batch of images shaped ``(B, in_channels, H, W)`` to a
    ``(B, n_classes + 1)`` tensor of probabilities (Softmax over dim 1).
    Presumably the extra output is the "fake" class -- TODO confirm against
    the loss once it is written.

    The four stride-2 convolutions (kernel 4, padding 1) shrink a 28x28 input
    to 1x1 (28 -> 14 -> 7 -> 3 -> 1), so flattening yields exactly ``d * 8``
    features for the linear head. Inputs that do not collapse to 1x1
    (e.g. 32x32 -> 2x2) produce more features, and the head raises a shape
    error.

    Args:
        in_channels: image channels; ``None`` (default) uses the notebook-level
            ``num_channels``.
        n_classes: number of real classes; ``None`` (default) uses the
            notebook-level ``num_classes``.
        d: base channel width of the first convolution (default 16).
        dropout_rate: drop probability of the ``Dropout2d`` layers (default 0.25).
    """

    def __init__(self, in_channels=None, n_classes=None, d=16, dropout_rate=0.25):
        super(DiscriminatorNet, self).__init__()

        # Fall back to the notebook globals only when no explicit value is given,
        # so DiscriminatorNet() builds exactly the same network as before.
        in_channels = num_channels if in_channels is None else in_channels
        n_classes = num_classes if n_classes is None else n_classes

        # Conv operations: each block roughly halves the spatial size (k=4, s=2, p=1).
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=d, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(dropout_rate)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=d, out_channels=d*2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(num_features=d*2),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels=d*2, out_channels=d*4, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(num_features=d*4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(dropout_rate)
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(in_channels=d*4, out_channels=d*8, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True)
        )

        # Classification head. Despite its name it does not flatten (forward() does);
        # the attribute name is kept so state_dict keys stay stable.
        self.flatten = nn.Sequential(
            nn.Linear(in_features=d*8, out_features=(n_classes + 1)),
            # Explicit class axis; the implicit-dim Softmax() form is deprecated.
            # NOTE(review): this yields probabilities, not logits. nn.CrossEntropyLoss
            # expects logits, so pair this output with log + NLLLoss (or drop the
            # Softmax) when the loss is added.
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        """Return class probabilities of shape ``(B, n_classes + 1)`` for images ``x``."""
        # Convolutional Operations
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)

        # Flatten (B, d*8, 1, 1) -> (B, d*8) and classify.
        x = x.view(x.size(0), -1)
        return self.flatten(x)

In [14]:
class GeneratorNet(torch.nn.Module):
    """DCGAN-style generator that maps latent noise to MNIST-sized images.

    Input: noise shaped ``(B, latent_dim, 1, 1)`` or ``(B, latent_dim)``
    (the latter is reshaped). Output: ``(B, out_channels, 28, 28)`` with
    values in [-1, 1] from the final Tanh.

    Spatial sizes through the transposed convolutions:
    1x1 -> 4x4 (k=4, s=1, p=0) -> 7x7 (k=3, s=2, p=1)
    -> 14x14 (k=4, s=2, p=1) -> 28x28 (k=4, s=2, p=1).
    Using k=4, s=2, p=1 in every layer would give 1 -> 2 -> 4 -> 8 -> 16,
    i.e. 16x16 images that do not match the 28x28 real data.

    Args:
        latent_dim: noise channels; ``None`` (default) uses the notebook-level
            ``latent_size``.
        out_channels: image channels; ``None`` (default) uses the
            notebook-level ``num_channels``.
        d: base channel width (default 16); the layers use d*8, d*4, d*2.
    """

    def __init__(self, latent_dim=None, out_channels=None, d=16):
        super(GeneratorNet, self).__init__()

        # Fall back to the notebook globals only when no explicit value is given.
        latent_dim = latent_size if latent_dim is None else latent_dim
        out_channels = num_channels if out_channels is None else out_channels

        # Deconv operations
        self.deconv1 = nn.Sequential(
            # 1x1 -> 4x4
            nn.ConvTranspose2d(in_channels=latent_dim, out_channels=d*8, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(num_features=d*8),
            nn.ReLU(inplace=True)
        )
        self.deconv2 = nn.Sequential(
            # 4x4 -> 7x7
            nn.ConvTranspose2d(in_channels=d*8, out_channels=d*4, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(num_features=d*4),
            nn.ReLU(inplace=True)
        )
        self.deconv3 = nn.Sequential(
            # 7x7 -> 14x14
            nn.ConvTranspose2d(in_channels=d*4, out_channels=d*2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(num_features=d*2),
            nn.ReLU(inplace=True)
        )
        self.deconv4 = nn.Sequential(
            # 14x14 -> 28x28, squashed to [-1, 1]
            nn.ConvTranspose2d(in_channels=d*2, out_channels=out_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        """Generate images of shape ``(B, out_channels, 28, 28)`` from noise ``x``."""
        # Accept flat noise (B, latent_dim) as well as (B, latent_dim, 1, 1).
        if x.dim() == 2:
            x = x.view(x.size(0), x.size(1), 1, 1)

        # Deconvolutional Operations
        x = self.deconv1(x)
        x = self.deconv2(x)
        x = self.deconv3(x)
        x = self.deconv4(x)

        return x

In [7]:
# Initialize parameters
# Adam hyper-parameters for the discriminator optimiser.
lr = 0.0001
b1 = 0.5    # adam: decay rate of the first-moment (mean) estimate of the gradient
b2 = 0.999  # adam: decay rate of the second-moment (uncentred variance) estimate of the gradient

In [8]:
# Build the discriminator and, when a GPU is present, move its parameters onto it.
discriminator = DiscriminatorNet()
if torch.cuda.is_available():
    discriminator = discriminator.cuda()  # Module.cuda() moves in place and returns the module

In [9]:
# Adam optimiser over every discriminator parameter, using the learning rate
# and moment-decay betas defined in the hyper-parameter cell.
optimizer_D = torch.optim.Adam(
    params=discriminator.parameters(),
    lr=lr,
    betas=(b1, b2),
)

In [11]:
# Start Training
# NOTE(review): this loop is still a stub -- it only runs the discriminator's
# forward pass and discards the result. TODO: compute the semi-supervised loss
# (presumably a supervised term on the fraction of labels given by `labeled_rate`
# plus real/fake terms), then optimizer_D.zero_grad(), loss.backward(),
# optimizer_D.step(); add the generator update once a GeneratorNet is built.

# DataLoader batches are produced on the CPU, but the discriminator may have been
# moved to the GPU, so send every batch to the device its parameters live on.
discriminator_device = next(discriminator.parameters()).device

for epoch in range(num_epochs):
    for i, (image, label) in enumerate(data_loader):

        ###############################################
        #              Train Discriminator            #
        ###############################################
        image = image.to(discriminator_device)
        label = label.to(discriminator_device)
        d_out = discriminator(image)  # (batch, num_classes + 1) probabilities

> <ipython-input-6-cb27f01792d4>(39)forward()
-> x = self.conv1(x)
(Pdb) n
> <ipython-input-6-cb27f01792d4>(40)forward()
-> x = self.conv2(x)
(Pdb) n
> <ipython-input-6-cb27f01792d4>(41)forward()
-> x = self.conv3(x)
(Pdb) n
> <ipython-input-6-cb27f01792d4>(42)forward()
-> x = self.conv4(x)
(Pdb) n
> <ipython-input-6-cb27f01792d4>(45)forward()
-> x = x.view(x.size(0), -1)
(Pdb) n
--Return--
> <ipython-input-6-cb27f01792d4>(45)forward()->None
-> x = x.view(x.size(0), -1)
(Pdb) n
> /Users/poorvarane/anaconda/lib/python2.7/site-packages/torch/nn/modules/module.py(478)__call__()
-> for hook in self._forward_hooks.values():
(Pdb) n
> /Users/poorvarane/anaconda/lib/python2.7/site-packages/torch/nn/modules/module.py(484)__call__()
-> if len(self._backward_hooks) > 0:
(Pdb) n
> /Users/poorvarane/anaconda/lib/python2.7/site-packages/torch/nn/modules/module.py(497)__call__()
-> return result
(Pdb) n
--Return--
> /Users/poorvarane/anaconda/lib/python2.7/site-packages/torch/nn/modules/module.py(49

BdbQuit: 