In [81]:
import torch
import torch
import torch.nn as nn
import torchvision
import os
import pickle
import scipy.io
import numpy as np
import imageio
import torch.nn.functional as F
from torch.autograd import Variable
from torch import optim
from torchvision import datasets
from torchvision import transforms

In [82]:
batch_size = 16
image_size = 32
conv_dim = 64
lr = 0.0002
beta1 = 0.5
beta2 = 0.999
train_iters = 400 #40000
use_reconst_loss = True
log_step = 10
sample_step = 5 #500
sample_path = './samples'

In [83]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [84]:
def get_loader():
    """Builds and returns Dataloader for MNIST and SVHN dataset."""

    transform = transforms.Compose([
                    transforms.Resize(image_size),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5,), (0.5,))])

    svhn = datasets.SVHN(root='./svhn', download=True, transform=transform, split='train')
    mnist = datasets.MNIST(root='./mnist', download=True, transform=transform, train=True)

    svhn_test = datasets.SVHN(root='./svhn', download=True, transform=transform, split='test')
    mnist_test = datasets.MNIST(root='./mnist', download=True, transform=transform, train=False)

    svhn_loader = torch.utils.data.DataLoader(dataset=svhn,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=2)

    mnist_loader = torch.utils.data.DataLoader(dataset=mnist,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=2)


    svhn_test_loader = torch.utils.data.DataLoader(dataset=svhn_test,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              num_workers=2)

    mnist_test_loader = torch.utils.data.DataLoader(dataset=mnist_test,
                                               batch_size=batch_size,
                                               shuffle=False,
                                               num_workers=2)

    return svhn_loader, mnist_loader, svhn_test_loader, mnist_test_loader

In [85]:
svhn_loader, mnist_loader, svhn_test_loader, mnist_test_loader = get_loader()

Using downloaded and verified file: ./svhn/train_32x32.mat
Using downloaded and verified file: ./svhn/test_32x32.mat


In [86]:
def deconv(c_in, c_out, k_size, stride=2, pad=1, bn=True):
    """Custom deconvolutional layer for simplicity."""
    layers = []
    layers.append(nn.ConvTranspose2d(c_in, c_out, k_size, stride, pad, bias=False))
    if bn:
        layers.append(nn.BatchNorm2d(c_out))
    return nn.Sequential(*layers)

def conv(c_in, c_out, k_size, stride=2, pad=1, bn=True):
    """Custom convolutional layer for simplicity."""
    layers = []
    layers.append(nn.Conv2d(c_in, c_out, k_size, stride, pad, bias=False))
    if bn:
        layers.append(nn.BatchNorm2d(c_out))
    return nn.Sequential(*layers)

In [87]:
class G12(nn.Module):
    """Generator for transfering from mnist to svhn"""
    def __init__(self, conv_dim=64, svhn_input=None):
        super(G12, self).__init__()

        # encoding blocks
        self.conv1 = conv(1, conv_dim, 4)
        self.conv2 = conv(conv_dim, conv_dim*2, 4)

        # residual blocks
        self.conv3 = conv(conv_dim*2, conv_dim*2, 3, 1, 1)
        self.conv4 = conv(conv_dim*2, conv_dim*2, 3, 1, 1)

        # decoding blocks
        self.deconv1 = deconv(conv_dim*2, conv_dim, 4)
        self.deconv2 = deconv(conv_dim, 3, 4, bn=False)

    def forward(self, x):
        out_1 = F.leaky_relu(self.conv1(x), 0.05)      # (?, 64, 16, 16)
        out_2 = F.leaky_relu(self.conv2(out_1), 0.05)    # (?, 128, 8, 8)

        out_3 = F.leaky_relu(self.conv3(out_2), 0.05)    # ( " )
        out_4 = F.leaky_relu(self.conv4(out_3), 0.05)    # ( " )

        out_5 = F.leaky_relu(self.deconv1(out_4), 0.05)  # (?, 64, 16, 16)
        out = F.tanh(self.deconv2(out_5))              # (?, 3, 32, 32)

        return out

In [88]:
class G21(nn.Module):
    """Generator for transfering from svhn to mnist"""
    def __init__(self,  conv_dim=64, svhn_input=None):
        super(G21, self).__init__()

        # encoding blocks
        self.conv1 = conv(3, conv_dim, 4)
        self.conv2 = conv(conv_dim, conv_dim*2, 4)

        # residual blocks
        self.conv3 = conv(conv_dim*2, conv_dim*2, 3, 1, 1)
        self.conv4 = conv(conv_dim*2, conv_dim*2, 3, 1, 1)

        # decoding blocks
        self.deconv1 = deconv(conv_dim*2, conv_dim, 4)
        self.deconv2 = deconv(conv_dim, 1, 4, bn=False)

    def forward(self, x):
        out_1 = F.leaky_relu(self.conv1(x), 0.05)      # (?, 64, 16, 16)
        out_2 = F.leaky_relu(self.conv2(out_1), 0.05)    # (?, 128, 8, 8)

        out_3 = F.leaky_relu(self.conv3(out_2), 0.05)    # ( " )
        out_4 = F.leaky_relu(self.conv4(out_3), 0.05)    # ( " )

        out_5 = F.leaky_relu(self.deconv1(out_4), 0.05)  # (?, 64, 16, 16)
        out = F.tanh(self.deconv2(out_5))              # (?, 1, 32, 32)

        return out

In [89]:
class D1(nn.Module):
    """Discriminator for mnist."""
    def __init__(self, conv_dim=64):
        super(D1, self).__init__()
        self.conv1 = conv(1, conv_dim, 4, bn=False)
        self.conv2 = conv(conv_dim, conv_dim*2, 4)
        self.conv3 = conv(conv_dim*2, conv_dim*4, 4)
        self.fc = conv(conv_dim*4, 1, 4, 1, 0, False)

    def forward(self, x):
        out = F.leaky_relu(self.conv1(x), 0.05)    # (?, 64, 16, 16)
        out = F.leaky_relu(self.conv2(out), 0.05)  # (?, 128, 8, 8)
        out = F.leaky_relu(self.conv3(out), 0.05)  # (?, 256, 4, 4)
        out = self.fc(out).squeeze()
        return out

In [90]:
class D2(nn.Module):
    """Discriminator for svhn."""
    def __init__(self, conv_dim=64):
        super(D2, self).__init__()
        self.conv1 = conv(3, conv_dim, 4, bn=False)
        self.conv2 = conv(conv_dim, conv_dim*2, 4)
        self.conv3 = conv(conv_dim*2, conv_dim*4, 4)
        self.fc = conv(conv_dim*4, 1, 4, 1, 0, False)

    def forward(self, x):
        out = F.leaky_relu(self.conv1(x), 0.05)    # (?, 64, 16, 16)
        out = F.leaky_relu(self.conv2(out), 0.05)  # (?, 128, 8, 8)
        out = F.leaky_relu(self.conv3(out), 0.05)  # (?, 256, 4, 4)
        out = self.fc(out).squeeze()
        return out

In [91]:
class Solver(object):
    def __init__(self, svhn_loader, mnist_loader):
        self.svhn_loader = svhn_loader
        self.mnist_loader = mnist_loader
        self.g12 = None
        self.g21 = None
        self.d1 = None
        self.d2 = None
        self.g_optimizer = None
        self.d_optimizer = None
        self.build_model()

    def build_model(self):
        """Builds a generator and a discriminator."""
        self.g12 = G12(conv_dim=conv_dim)
        self.g21 = G21(conv_dim=conv_dim)
        self.d1 = D1(conv_dim=conv_dim)
        self.d2 = D2(conv_dim=conv_dim)

        g_params = list(self.g12.parameters()) + list(self.g21.parameters())
        d_params = list(self.d1.parameters()) + list(self.d2.parameters())

        self.g_optimizer = optim.Adam(g_params, lr, [beta1, beta2])
        self.d_optimizer = optim.Adam(d_params, lr, [beta1, beta2])

        if torch.cuda.is_available():
            self.g12.cuda()
            self.g21.cuda()
            self.d1.cuda()
            self.d2.cuda()

    def merge_images(self, sources, targets, k=10):
        _, _, h, w = sources.shape
        row = int(np.sqrt(batch_size))
        merged = np.zeros([3, row*h, row*w*2])
        for idx, (s, t) in enumerate(zip(sources, targets)):
            i = idx // row
            j = idx % row
            merged[:, i*h:(i+1)*h, (j*2)*h:(j*2+1)*h] = s
            merged[:, i*h:(i+1)*h, (j*2+1)*h:(j*2+2)*h] = t
        return merged.transpose(1, 2, 0)

    def to_var(self, x):
        """Converts numpy to variable."""
        if torch.cuda.is_available():
            x = x.cuda()
        return Variable(x)

    def to_data(self, x):
        """Converts variable to numpy."""
        if torch.cuda.is_available():
            x = x.cpu()
        return x.data.numpy()

    def reset_grad(self):
        """Zeros the gradient buffers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def as_np(self, data):
        return data.cpu().data.numpy()

    def train(self, svhn_test_loader, mnist_test_loader):
        svhn_iter = iter(self.svhn_loader)
        mnist_iter = iter(self.mnist_loader)
        iter_per_epoch = min(len(svhn_iter), len(mnist_iter)) -1

        # fixed mnist and svhn for sampling
        svhn_test_iter = iter(svhn_test_loader)
        mnist_test_iter = iter(mnist_test_loader)
        fixed_svhn = next(svhn_test_iter)
        fixed_mnist = next(mnist_test_iter)

        fixed_svhn = fixed_svhn[0].to(device,dtype=torch.float)
        fixed_mnist = fixed_mnist[0].to(device,dtype=torch.float)

        for step in range(train_iters+1):
            # reset data_iter for each epoch
            if (step+1) % iter_per_epoch == 0:
                mnist_iter = iter(self.mnist_loader)
                svhn_iter = iter(self.svhn_loader)

            # load svhn and mnist dataset
            svhn, s_labels = next(svhn_iter)
            svhn, s_labels = self.to_var(svhn), self.to_var(s_labels).long().squeeze()
            mnist, m_labels = next(mnist_iter)
            mnist, m_labels = self.to_var(mnist), self.to_var(m_labels)

            #============ train D ============#

            # train with real images
            self.reset_grad()
            out = self.d1(mnist)
            d1_loss = torch.mean((out-1)**2)

            out = self.d2(svhn)
            d2_loss = torch.mean((out-1)**2)

            d_mnist_loss = d1_loss
            d_svhn_loss = d2_loss
            d_real_loss = d1_loss + d2_loss
            d_real_loss.backward()
            self.d_optimizer.step()

            # train with fake images
            self.reset_grad()
            fake_svhn = self.g12(mnist)
            out = self.d2(fake_svhn)
            d2_loss = torch.mean(out**2)

            fake_mnist = self.g21(svhn)
            out = self.d1(fake_mnist)
            d1_loss = torch.mean(out**2)

            d_fake_loss = d1_loss + d2_loss
            d_fake_loss.backward()
            self.d_optimizer.step()

            #============ train G ============#

            # train mnist-svhn-mnist cycle
            self.reset_grad()
            fake_svhn = self.g12(mnist)
            out_svhn = self.d2(fake_svhn)
            reconst_mnist = self.g21(fake_svhn)

            gen_loss_A = torch.mean((out_svhn-1)**2)
            g_loss = gen_loss_A

            if use_reconst_loss:
                reconst_loss_A = torch.mean((mnist - reconst_mnist) ** 2)
                g_loss += reconst_loss_A

            g_loss.backward()
            self.g_optimizer.step()

            # train svhn-mnist-svhn cycle
            self.reset_grad()
            fake_mnist  = self.g21(svhn)
            out_mnist = self.d1(fake_mnist)
            reconst_svhn = self.g12(fake_mnist)

            gen_loss_B = torch.mean((out_mnist - 1) ** 2)
            g_loss = gen_loss_B

            if use_reconst_loss:
                reconst_loss_B = torch.mean((svhn - reconst_svhn) ** 2)
                g_loss += reconst_loss_B

            g_loss.backward()
            self.g_optimizer.step()

            # print the log info
            if (step+1) % log_step == 0:

                print('Step [%d/%d], d_real_loss: %.4f, d_mnist_loss: %.4f, d_svhn_loss: %.4f, '
                      'd_fake_loss: %.4f, gen_loss_A: %.4f, gen_loss_B: %.4f,'
                      %(step+1, train_iters, d_real_loss.item(), d_mnist_loss.item(),
                        d_svhn_loss.item(), d_fake_loss.item(), gen_loss_A.item(),gen_loss_B.item()))

                if use_reconst_loss:
                    print ('reconst_loss_A: %.4f, recons_loss_B: %.4f, ' %
                           (reconst_loss_A.item(), reconst_loss_B.item()))


            # save the sampled images
            if (step+1) % sample_step == 0:
                fake_svhn = self.g12(fixed_mnist)
                fake_mnist = self.g21(fixed_svhn)

                mnist, fake_mnist = self.to_data(fixed_mnist), self.to_data(fake_mnist)
                svhn , fake_svhn = self.to_data(fixed_svhn), self.to_data(fake_svhn)

                merged = self.merge_images(mnist, fake_svhn)
                merged = (merged * 255).astype('uint8')
                path = os.path.join(sample_path, 'sample-%d-m-s.png' %(step+1))
                imageio.imwrite(path, merged)
                print ('saved %s' %path)

                merged = self.merge_images(svhn, fake_mnist)
                merged = (merged * 255).astype('uint8')
                path = os.path.join(sample_path, 'sample-%d-s-m.png' %(step+1))
                imageio.imwrite(path, merged)
                print ('saved %s' %path)

In [92]:
# create directories if not exist
if not os.path.exists(sample_path):
      os.makedirs(sample_path)


In [93]:
solver = Solver(svhn_loader, mnist_loader)
solver.train(svhn_test_loader, mnist_test_loader)

Step [10/4000], d_real_loss: 1.2894, d_mnist_loss: 0.2837, d_svhn_loss: 1.0057, d_fake_loss: 3.0810, gen_loss_A: 9.6788, gen_loss_B: 1.3250,
reconst_loss_A: 0.9470, recons_loss_B: 0.4832, 
Step [20/4000], d_real_loss: 0.4207, d_mnist_loss: 0.2436, d_svhn_loss: 0.1770, d_fake_loss: 2.6705, gen_loss_A: 6.4560, gen_loss_B: 2.0398,
reconst_loss_A: 0.7151, recons_loss_B: 0.4402, 
Step [30/4000], d_real_loss: 0.3593, d_mnist_loss: 0.2838, d_svhn_loss: 0.0755, d_fake_loss: 2.8598, gen_loss_A: 4.8810, gen_loss_B: 2.2877,
reconst_loss_A: 0.5394, recons_loss_B: 0.4561, 
Step [40/4000], d_real_loss: 0.1860, d_mnist_loss: 0.1036, d_svhn_loss: 0.0823, d_fake_loss: 1.4992, gen_loss_A: 2.8970, gen_loss_B: 2.8059,
reconst_loss_A: 0.3781, recons_loss_B: 0.4634, 
Step [50/4000], d_real_loss: 0.1857, d_mnist_loss: 0.0691, d_svhn_loss: 0.1166, d_fake_loss: 1.5544, gen_loss_A: 6.0815, gen_loss_B: 1.1481,
reconst_loss_A: 0.3254, recons_loss_B: 0.4682, 
saved ./samples/sample-50-m-s.png
saved ./samples/sampl