In [3]:
# some setup code
import torch
import torch.nn as nn
from torch.nn import init
from torch.autograd import Variable
import torchvision
import torchvision.transforms as T
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import torchvision.datasets as dset

import numpy as np

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

import layers
import params

%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

def show_images(images):
    """Draw a batch of images on a square grid of subplots.

    The batch is flattened to (batch_size, D). The grid has
    ceil(sqrt(batch_size)) cells per side, and every row is drawn as a
    ceil(sqrt(D)) x ceil(sqrt(D)) image, so the reshape only succeeds when
    D is a perfect square.

    Returns None; the figure is left as the current matplotlib figure.
    """
    flat = np.reshape(images, [images.shape[0], -1])  # (batch_size, D)
    grid_side = int(np.ceil(np.sqrt(flat.shape[0])))
    img_side = int(np.ceil(np.sqrt(flat.shape[1])))

    plt.figure(figsize=(grid_side, grid_side))
    grid = gridspec.GridSpec(grid_side, grid_side)
    grid.update(wspace=0.05, hspace=0.05)

    for cell, pixels in enumerate(flat):
        ax = plt.subplot(grid[cell])
        # Hide all axis decoration so only the image itself is visible.
        ax.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(pixels.reshape([img_side, img_side]))

def preprocess_img(x):
    """Apply the affine map x -> 2*x - 1 (e.g. sends 0 to -1 and 1 to 1)."""
    doubled = 2 * x
    return doubled - 1.0

def deprocess_img(x):
    """Apply the affine map x -> (x + 1) / 2 (e.g. sends -1 to 0 and 1 to 1)."""
    shifted = x + 1.0
    return shifted / 2.0

def rel_error(x,y):
    """Return the largest element-wise relative error between x and y.

    Each element's error is |x - y| / max(1e-8, |x| + |y|); the floor of
    1e-8 on the denominator avoids dividing by zero when both are zero.
    """
    abs_diff = np.abs(x - y)
    magnitude = np.abs(x) + np.abs(y)
    denom = np.maximum(1e-8, magnitude)
    return np.max(abs_diff / denom)

def count_params(model):
    """Count the total number of scalar parameters in a PyTorch model.

    Args:
        model: any object exposing ``parameters()`` that yields tensors
            (e.g. an ``nn.Module``).

    Returns:
        int: the sum of ``numel()`` over all parameters; 0 when the model
        has no parameters.
    """
    # numel() is an exact integer count, so the total is always an int --
    # including for parameter-less models and 0-dim parameters.
    return sum(p.numel() for p in model.parameters())

USE_GPU = True
print(torch.__version__)
dtype = torch.float32 # we will be using float throughout this tutorial

# Use CUDA only when it is both requested and actually available;
# otherwise fall back to the CPU.
device = torch.device('cuda' if USE_GPU and torch.cuda.is_available() else 'cpu')

# Constant to control how frequently we print train loss
print_every = 100

print('using device:', device)

0.4.0
using device: cuda


In [None]:
# Image geometry is read from the shared `params` module.
img_size, img_channels = params.img_size, params.img_channels

In [4]:
class SEnc(nn.Module):
    """Stub module: it registers no layers and defines no forward pass.

    The constructor accepts ``in_channel``, ``channel_1``, ``channel_2``,
    ``dense_channel`` and ``s_dim`` but does not use any of them.
    """
    def __init__(self, in_channel, channel_1, channel_2, dense_channel, s_dim):
        # Only the nn.Module base class is initialised here.
        super(SEnc, self).__init__()
def Encoder(img_size, in_channel, conv_channel, filter_size, latent_dim, bn, dense_size=None):
    """Build the convolutional encoder as an ``nn.Sequential``.

    The stack is two stride-2 ``layers.ConvLayer`` blocks, one stride-1
    ``layers.ConvLayer`` block, ``layers.Flatten`` and two ``layers.Dense``
    layers, the last of which outputs ``latent_dim`` features.

    Args:
        img_size: side length used to size the first Dense layer; the
            flattened size is computed as if the image were square.
        in_channel: input channels of the first conv layer.
        conv_channel: output channels of the first conv layer; the
            remaining conv layers use ``conv_channel // 2``.
        filter_size: passed as the third positional argument of every
            ConvLayer.
        latent_dim: output features of the final Dense layer.
        bn: forwarded as ``bn=`` to every ConvLayer.
        dense_size: width of the hidden Dense layer. When None, the
            module-level global ``dense_size`` is used, which is the name
            this function previously read implicitly.

    Returns:
        nn.Sequential: the assembled encoder.

    Raises:
        NameError: if ``dense_size`` is neither passed nor defined as a
            module-level global.
    """
    if dense_size is None:
        # Backward compatibility: older callers relied on a global.
        try:
            dense_size = globals()['dense_size']
        except KeyError:
            raise NameError(
                "name 'dense_size' is not defined; pass dense_size=... to Encoder"
            ) from None

    # Floor division keeps channel counts and feature sizes integral;
    # true division would hand floats to the layer constructors.
    inner_conv_channel = conv_channel // 2
    if img_size%4 != 0:
        print("WARNING: image size mod 4 != 0, may produce bug.")
    # Assumes each stride-2 conv halves the spatial size, so the feature map
    # is (img_size // 4) per side -- TODO confirm against layers.ConvLayer.
    reduced_side = img_size // 4
    flatten_img_size = inner_conv_channel * reduced_side * reduced_side
    model = nn.Sequential(
        layers.ConvLayer(in_channel,        conv_channel,       filter_size, stride=2, bn=bn),
        layers.ConvLayer(conv_channel,      inner_conv_channel, filter_size, stride=2, bn=bn),
        layers.ConvLayer(inner_conv_channel,inner_conv_channel, filter_size, stride=1, bn=bn),
        layers.Flatten(),
        layers.Dense(flatten_img_size, dense_size),
        layers.Dense(dense_size,       latent_dim)
    )
    return model
def Classifier(input_dim, dense_size, s_classes, bn):
    """Build a three-layer dense classifier ending in a softmax.

    Args:
        input_dim: input features of the first Dense layer.
        dense_size: width of the two hidden Dense layers.
        s_classes: output features of the last Dense layer (one per class).
        bn: forwarded as ``bn=`` to every Dense layer.

    Returns:
        nn.Sequential: three ``layers.Dense`` modules followed by
        ``nn.Softmax`` over the last axis.
    """
    model = nn.Sequential(
        layers.Dense(input_dim,  dense_size, bn=bn),
        layers.Dense(dense_size, dense_size, bn=bn),
        layers.Dense(dense_size, s_classes,  bn=bn),
        # Explicit dim: the implicit-dim form is deprecated, emits a warning
        # and picks the axis from the input rank. The class scores are
        # presumably on the last axis of the Dense output -- confirm against
        # layers.Dense.
        nn.Softmax(dim=-1)
    )
    return model
def Decoder(img_size, in_channel, conv_channel, filter_size, bn,
            s_dim=None, latent_dim=None, dense_size=None):
    """Build the (work-in-progress) decoder as an ``nn.Sequential``.

    Args:
        img_size: side length used to size the Dense layer after Flatten.
        in_channel: currently unused by the layer stack.
        conv_channel: output channels of the first conv layer; the
            remaining conv layers use ``conv_channel // 2``.
        filter_size: passed as the third positional argument of every
            ConvLayer.
        bn: forwarded as ``bn=`` to every ConvLayer.
        s_dim, latent_dim, dense_size: sizes that this function previously
            read as implicit module-level globals. Each may now be passed
            explicitly; when left as None the same-named global is used.

    Returns:
        nn.Sequential: the assembled stack.

    Raises:
        NameError: if one of ``s_dim`` / ``latent_dim`` / ``dense_size`` is
            neither passed nor defined as a module-level global.

    NOTE(review): the layer stack is still essentially a copy of the
    encoder's: a Dense output is fed into a ConvLayer whose input-channel
    count is ``input_dim``, and the last layer maps to ``latent_dim`` rather
    than to an image. The architecture needs a redesign before use.
    """
    def _resolve(name, value):
        # An explicit argument wins; otherwise fall back to the global the
        # original code depended on.
        if value is not None:
            return value
        try:
            return globals()[name]
        except KeyError:
            raise NameError(
                "name '%s' is not defined; pass %s=... to Decoder" % (name, name)
            ) from None

    s_dim = _resolve('s_dim', s_dim)
    latent_dim = _resolve('latent_dim', latent_dim)
    dense_size = _resolve('dense_size', dense_size)

    # essentially the mirror version of Encoder
    # Floor division keeps channel counts and feature sizes integral.
    inner_conv_channel = conv_channel // 2
    flatten_img_size = inner_conv_channel * img_size*img_size

    input_dim = s_dim + latent_dim

    model = nn.Sequential(
        layers.Dense(input_dim, dense_size),
        layers.ConvLayer(input_dim,         conv_channel,       filter_size, bn=bn),
        layers.ConvLayer(conv_channel,      inner_conv_channel, filter_size, bn=bn),
        layers.ConvLayer(inner_conv_channel,inner_conv_channel, filter_size, bn=bn),
        layers.Flatten(),
        layers.Dense(flatten_img_size, dense_size),
        layers.Dense(dense_size,       latent_dim)
    )
    return model


In [None]:
def test_Encoder(bn, latent_dim):
    # bn: whether use batch normalization in dense layer
    conv_channel= params.enc_conv_channel
    filter_size = params.enc_conv_filter_size
    in_channel  = img_channels
    latent_size = params.s_dim
    
    x = torch.zeros((64, img_channels, img_size, img_size), dtype=dtype)
    model = Encoder(img_size, in_channel, conv_channel, latent_dim, filter_size, bn)
    scores = model(x)
    print(scores.size())  # you should see [64, 10]
def test_Decoder():
    """Placeholder for the decoder smoke test; it currently does nothing."""
    # TODO
    return None

# test S encoder
test_Encoder(params.s_enc_bn, params.s_enc_dim)
# test z encoder
test_Encoder(params.z_enc_bn, params.z_enc_dim)
# test decoder -- call the decoder test itself rather than repeating the
# z-encoder call; test_Decoder is still a stub, so this is a no-op for now.
test_Decoder()