In [0]:
from __future__ import print_function
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
from torch.optim import Adam

import torchvision
from torchvision import datasets, transforms
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader


import matplotlib.pyplot as plt
import pdb
import pickle

In [0]:
import os
os.getcwd()

'/content'

In [0]:
USE_CUDA = True if torch.cuda.is_available() else False
print(USE_CUDA)

class Mnist:
  
  def __init__(self, batch_size):
    dataset_transform = transform.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.1307,), (0.3081,))
                        ])
    train_dataset = datasets.Mnist('/mnist_data/', train = True, download = True, transform = dataset_transform)
    test_dataset = datasets.Mnist('mnist_data/', train = True, download = True, transform = dataset_transform)
    self.trainloader = torch.utils.data.DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
    self.testloader = torch.utils.data.DataLoader(test_dataset, batch_size = batch_size, shuffle = True)    


True


In [0]:
class ConvLayer(nn.Module):

    def __init__(self, in_channels=1, out_channels=256, kernel_size=9):

        super(ConvLayer, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=kernel_size,
                               stride=1
                             )

    def forward(self, x):
        return F.relu(self.conv(x))

In [0]:
class PrimaryCaps(nn.Module):

    def __init__(self, size_of_capsule=8, in_channels=256, out_channels=32, kernel_size=9):
        super(PrimaryCaps, self).__init__()
        self.size_of_capsule=size_of_capsule
        self.capsules = nn.ModuleList([
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=2, padding=0) 
                          for _ in range(size_of_capsule)])    

    def forward(self, x):
        num_of_capsules=32*6*6
        u = [capsule(x) for capsule in self.capsules]
        u = torch.stack(u, dim=1)
        u = u.view(x.size(0), num_of_capsules, self.size_of_capsule)

        return self.squash(u)

    def squash(self, input_tensor):
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        output_tensor = squared_norm *  input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm))
        return output_tensor

In [0]:
class DigitCaps(nn.Module):

    def __init__(self, num_capsules=10, num_routes=32 * 6 * 6, in_channels=8, out_channels=16, use_cuda=False):
        super(DigitCaps, self).__init__()
        self.in_channels = in_channels
        self.num_routes = num_routes
        self.num_capsules = num_capsules
        self.use_cuda = USE_CUDA
        self.W = nn.Parameter(torch.randn(1, num_routes, num_capsules, out_channels, in_channels))

    def forward(self, x):
        batch_size = x.size(0)
        x = torch.stack([x] * self.num_capsules, dim=2).unsqueeze(4)
        W = torch.cat([self.W] * batch_size, dim=0)
        u_hat = torch.matmul(W, x)
        b_ij = Variable(torch.zeros(1, self.num_routes, self.num_capsules, 1))
        
        if USE_CUDA:
            b_ij = b_ij.cuda()
            
        num_iterations = 3
        for iteration in range(num_iterations):
            c_ij = F.softmax(b_ij)
            c_ij = torch.cat([c_ij] * batch_size, dim=0).unsqueeze(4)
            if USE_CUDA:
                c_ij = c_ij.cuda()
            s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)
            v_j = self.squash(s_j)
            
            if iteration < num_iterations - 1:
                a_ij = torch.matmul(u_hat.transpose(3, 4), torch.cat([v_j] * self.num_routes, dim=1))
                b_ij = b_ij + a_ij.squeeze(4).mean(dim=0, keepdim=True)

        return v_j.squeeze(1)


    def squash(self, input_tensor):
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        output_tensor = squared_norm *  input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm))
        return output_tensor

In [0]:
class Decoder(nn.Module):

    def __init__(self):
        super(Decoder, self).__init__()
        self.reconstraction_layers = nn.Sequential(
            nn.Linear(16 * 10, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 784),
            nn.Sigmoid()
        )

        

    def forward(self, x, data):
        classes = torch.sqrt((x ** 2).sum(2))
        classes = F.softmax(classes)
        _, max_length_indices = classes.max(dim=1)
        masked = Variable(torch.sparse.torch.eye(10))

        if USE_CUDA:
            masked = masked.cuda()
        masked = masked.index_select(dim=0, index=Variable(max_length_indices.squeeze(1).data))
        reconstructions = self.reconstraction_layers((x * masked[:, :, None, None]).view(x.size(0), -1))
        reconstructions = reconstructions.view(-1, 1, 28, 28)

        return reconstructions, masked

In [0]:
def init_weight(m):
    classname = m.__class__.__name__
    
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0., 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1., 0.02)
        m.bias.data.fill_(0.)

In [0]:
class DCGenerator(nn.Module):

    def __init__(self, convs):
        super(DCGenerator, self).__init__()
        self.convs = nn.ModuleList()
        in_channels = 1

        for i, (out_channels, kernel_size, stride, padding) in enumerate(convs):
            self.convs.append(nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False))
            if i < len(convs)-1:
                self.convs.append(nn.BatchNorm2d(out_channels))
                self.convs.append(nn.ReLU())
            else:
                self.convs.append(nn.Tanh())
            in_channels = out_channels
            
        self.apply(init_weight)


    def forward(self, input):
        out = input
        for module in self.convs:
            out = module(out)
            
        return out

In [0]:
class CapsNet_Discriminator(nn.Module):

    def __init__(self):
        super(CapsNet_Discriminator, self).__init__()
        self.conv_layer = ConvLayer()
        self.primary_capsules = PrimaryCaps()
        self.prediction_capsule = DigitCaps(num_capsules=1)
        self.decoder = Decoder()        
        self.mse_loss = nn.MSELoss()
        

    def forward(self, data):
        
        output = self.prediction_capsule(self.primary_capsules(self.conv_layer(data)))
        #reconstructions, masked = self.decoder(output, data) 

        return output

    def loss(self, data, x, target, reconstructions=False):
        return self.margin_loss(x, target) #+ self.reconstruction_loss(data, reconstructions)
    

    def margin_loss(self, x, labels, size_average=True):
        batch_size = x.size(0)
        v_c = torch.sqrt((x**2).sum(dim=2, keepdim=True))
        left = F.relu(0.9 - v_c).view(batch_size, -1)
        right = F.relu(v_c - 0.1).view(batch_size, -1)
        loss = labels * left + 0.5 * (1.0 - labels) * right
        loss = loss.sum(dim=1).mean()

        return loss

    
    def reconstruction_loss(self, data, reconstructions):
        loss = self.mse_loss(reconstructions.view(reconstructions.size(0), -1), data.view(reconstructions.size(0), -1))
        return loss * 0.0005

In [0]:
class Discriminator(nn.Module):

    def __init__(self, convs):
        super(Discriminator, self).__init__()
        self.convs = nn.ModuleList()
        in_channels = 1

        for i, (out_channels, kernel_size, stride, padding) in enumerate(convs):
            self.convs.append(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False))
            if i != 0 and i != len(convs)-1:
                self.convs.append(nn.BatchNorm2d(out_channels))

            if i != len(convs)-1:
                self.convs.append(nn.LeakyReLU(0.2))
                in_channels = out_channels
        #self.cls = nn.Linear(out_channels*in_width*in_height, nout)

        self.apply(init_weight)


    def forward(self, input):
        out = input
        for layer in self.convs:
            out = layer(out)
        out = out.view(out.size(0), -1)
        out = F.sigmoid(out)
        return out

In [0]:
def sample_noise(batch_size, channels):
  return torch.randn(batch_size, channels, 1, 1).float()

max_iter = 25
download = True

trans = transforms.Compose([transforms.ToTensor(),
                          transforms.Normalize([0.5, ], [0.5, ])])

mnist = datasets.MNIST('./', train = True, transform = trans, download = download)

batch_size = 64

use_cuda = USE_CUDA

In [0]:
Global_loss = 100000
min_loss_per_epoch = 10000
if __name__ == '__main__':

    #d_convs = [(32, 4, 2, 1), (64, 4, 2, 1), (1, 7, 1, 0)]
    discriminator = CapsNet_Discriminator()                  #discriminator = Discriminator(d_convs)
    g_convs = [(64, 7, 1, 0), (32, 4, 2, 1), (1, 4, 2, 1)]
    generator = DCGenerator(g_convs)
    print(discriminator)
    print(generator)

    if use_cuda:
        discriminator, generator = discriminator.cuda(), generator.cuda()

    dataloader = DataLoader(mnist, batch_size=batch_size, shuffle=True)
    optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_g = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))

    real_label, fake_label = 1, 0
    criterion = nn.BCELoss()

    if use_cuda:
        criterion = criterion.cuda()

    fixed_noise = sample_noise(batch_size, 1)

    if use_cuda:
        fixed_noise = fixed_noise.cuda()

    fixed_noise = Variable(fixed_noise, volatile=True)

    for epoch in range(1, max_iter+1):
        for i, (x, _) in enumerate(dataloader):
            batch_size = x.size(0)

            optimizer_d.zero_grad() 
            x = Variable(x)
            if use_cuda:
                x = x.cuda()
            
            output=discriminator(x)
            real_v = Variable(torch.Tensor(batch_size).fill_(real_label).float())

            if use_cuda:
                real_v = real_v.cuda()

            loss_d = discriminator.loss(x,output, real_v,False)
      
            loss_d.backward()

            Dx = output.data.mean(dim=0)[0]
            Dx= Dx.sum(dim=1)

            z = sample_noise(batch_size, 1)
            z = Variable(z)
            if use_cuda:
                z = z.cuda()
                
            fake = generator(z)
            output = discriminator(fake.detach())
            fake_v = Variable(torch.Tensor(batch_size).fill_(fake_label).float())
            
            if use_cuda:
                fake_v = fake_v.cuda()

            loss_g = discriminator.loss(x,output, fake_v, False)
            
            loss_g.backward()
            optimizer_d.step()

            err_D = loss_d.data + loss_g.data


            optimizer_g.zero_grad()
            output = discriminator(fake)
            real_v = Variable(torch.Tensor(batch_size).fill_(real_label).float())

            if use_cuda:
                real_v = real_v.cuda()


            loss = discriminator.loss(x,output, real_v)
            loss.backward()
            optimizer_g.step()
            err_G = loss.data

            DGz = output.data.mean(dim=0)[0]
            DGz= DGz.sum(dim=1)
            print('[{:02d}/{:02d}],[{:03d}/{:03d}], errD: {:.4f}, errG: {:.4f}'.format(
                  epoch, max_iter, i, len(dataloader), err_D, err_G))
            
            if err_D <= min_loss_per_epoch:
              min_loss_per_epoch = err_D
        if min_loss_per_epoch <= Global_loss:
          torch.save(discriminator.state_dict(), 'discriminator2.pickle')
          torch.save(generator.state_dict(), 'generator2.pickle')
          print("saved model! at epoch:", epoch)
    
        fake = generator(fixed_noise)
        save_image(fake.data, './mnist-fake2-{:02d}.png'.format(epoch),
                   normalize=True)

CapsNet_Discriminator(
  (conv_layer): ConvLayer(
    (conv): Conv2d(1, 256, kernel_size=(9, 9), stride=(1, 1))
  )
  (primary_capsules): PrimaryCaps(
    (capsules): ModuleList(
      (0): Conv2d(256, 32, kernel_size=(9, 9), stride=(2, 2))
      (1): Conv2d(256, 32, kernel_size=(9, 9), stride=(2, 2))
      (2): Conv2d(256, 32, kernel_size=(9, 9), stride=(2, 2))
      (3): Conv2d(256, 32, kernel_size=(9, 9), stride=(2, 2))
      (4): Conv2d(256, 32, kernel_size=(9, 9), stride=(2, 2))
      (5): Conv2d(256, 32, kernel_size=(9, 9), stride=(2, 2))
      (6): Conv2d(256, 32, kernel_size=(9, 9), stride=(2, 2))
      (7): Conv2d(256, 32, kernel_size=(9, 9), stride=(2, 2))
    )
  )
  (prediction_capsule): DigitCaps()
  (decoder): Decoder(
    (reconstraction_layers): Sequential(
      (0): Linear(in_features=160, out_features=512, bias=True)
      (1): ReLU(inplace)
      (2): Linear(in_features=512, out_features=1024, bias=True)
      (3): ReLU(inplace)
      (4): Linear(in_features=1024, o



[01/25],[000/938], errD: 57.5685, errG: 57.4275
[01/25],[001/938], errD: 51.8955, errG: 56.6496
[01/25],[002/938], errD: 43.2048, errG: 55.3249
[01/25],[003/938], errD: 34.4675, errG: 54.3647
[01/25],[004/938], errD: 28.6357, errG: 53.2144
[01/25],[005/938], errD: 24.8567, errG: 53.2675
[01/25],[006/938], errD: 21.2489, errG: 53.6676
[01/25],[007/938], errD: 21.8109, errG: 53.8564
[01/25],[008/938], errD: 18.8072, errG: 54.2745
[01/25],[009/938], errD: 16.9954, errG: 54.7258
[01/25],[010/938], errD: 18.0264, errG: 54.4659
[01/25],[011/938], errD: 15.8856, errG: 54.1997
[01/25],[012/938], errD: 16.2785, errG: 53.6193
[01/25],[013/938], errD: 15.3113, errG: 52.6825
[01/25],[014/938], errD: 16.9313, errG: 51.5540
[01/25],[015/938], errD: 15.3503, errG: 51.2775
[01/25],[016/938], errD: 14.3153, errG: 50.8099
[01/25],[017/938], errD: 16.9786, errG: 50.2609
[01/25],[018/938], errD: 16.3790, errG: 50.9790
[01/25],[019/938], errD: 15.1706, errG: 51.9637
[01/25],[020/938], errD: 15.4172, errG: 